Skip to content

Raise IterationError on StopIteration #473

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
6 changes: 6 additions & 0 deletions toolz/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@

__all__ = ('IterationError',)


class IterationError(RuntimeError):
pass
34 changes: 25 additions & 9 deletions toolz/itertoolz.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from random import Random
from collections.abc import Sequence
from toolz.utils import no_default
from toolz.exceptions import IterationError


__all__ = ('remove', 'accumulate', 'groupby', 'merge_sorted', 'interleave',
Expand Down Expand Up @@ -373,7 +374,9 @@ def first(seq):
>>> first('ABC')
'A'
"""
return next(iter(seq))
for rv in seq:
return rv
raise IterationError("Received empty sequence")


def second(seq):
Expand All @@ -382,9 +385,10 @@ def second(seq):
>>> second('ABC')
'B'
"""
seq = iter(seq)
next(seq)
return next(seq)
try:
return first(itertools.islice(seq, 1, None))
except IterationError as exc:
raise IterationError("Lenth of seq is < 2") from exc


def nth(n, seq):
Expand All @@ -396,7 +400,10 @@ def nth(n, seq):
if isinstance(seq, (tuple, list, Sequence)):
return seq[n]
else:
return next(itertools.islice(seq, n, None))
try:
return first(itertools.islice(seq, n, None))
except IterationError as exc:
raise IterationError("Length of seq is < %d" % n) from exc


def last(seq):
Expand Down Expand Up @@ -531,8 +538,11 @@ def interpose(el, seq):
[1, 'a', 2, 'a', 3]
"""
inposed = concat(zip(itertools.repeat(el), seq))
next(inposed)
return inposed
try:
next(inposed)
return inposed
except StopIteration:
raise IterationError("Received empty sequence")


def frequencies(seq):
Expand Down Expand Up @@ -722,13 +732,16 @@ def partition_all(n, seq):
"""
args = [iter(seq)] * n
it = zip_longest(*args, fillvalue=no_pad)

try:
prev = next(it)
except StopIteration:
return

for item in it:
yield prev
prev = item

if prev[-1] is no_pad:
try:
# If seq defines __len__, then
Expand Down Expand Up @@ -997,8 +1010,11 @@ def peek(seq):
[0, 1, 2, 3, 4]
"""
iterator = iter(seq)
item = next(iterator)
return item, itertools.chain((item,), iterator)
try:
item = next(iterator)
return item, itertools.chain((item,), iterator)
except StopIteration:
raise IterationError("Received empty sequence")


def peekn(n, seq):
Expand Down
12 changes: 8 additions & 4 deletions toolz/tests/test_itertoolz.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
sliding_window, count, partition,
partition_all, take_nth, pluck, join,
diff, topk, peek, peekn, random_sample)
from operator import add, mul

from toolz.exceptions import IterationError

from operator import add, mul

# is comparison will fail between this and no_default
no_default2 = loads(dumps('__no__default__'))
Expand Down Expand Up @@ -127,7 +129,7 @@ def test_nth():
assert nth(2, iter('ABCDE')) == 'C'
assert nth(1, (3, 2, 1)) == 2
assert nth(0, {'foo': 'bar'}) == 'foo'
assert raises(StopIteration, lambda: nth(10, {10: 'foo'}))
assert raises(IterationError, lambda: nth(10, {10: 'foo'}))
assert nth(-2, 'ABCDE') == 'D'
assert raises(ValueError, lambda: nth(-2, iter('ABCDE')))

Expand All @@ -136,12 +138,14 @@ def test_first():
assert first('ABCDE') == 'A'
assert first((3, 2, 1)) == 3
assert isinstance(first({0: 'zero', 1: 'one'}), int)
assert raises(IterationError, lambda: first([]))


def test_second():
assert second('ABCDE') == 'B'
assert second((3, 2, 1)) == 2
assert isinstance(second({0: 'zero', 1: 'one'}), int)
assert raises(IterationError, lambda: second([1]))


def test_last():
Expand Down Expand Up @@ -228,6 +232,7 @@ def test_interpose():
assert "tXaXrXzXaXn" == "".join(interpose("X", "tarzan"))
assert list(interpose(0, itertools.repeat(1, 4))) == [1, 0, 1, 0, 1, 0, 1]
assert list(interpose('.', ['a', 'b', 'c'])) == ['a', '.', 'b', '.', 'c']
assert raises(IterationError, lambda: interpose('a', []))


def test_frequencies():
Expand Down Expand Up @@ -510,8 +515,7 @@ def test_peek():
element, blist = peek(alist)
assert element == alist[0]
assert list(blist) == alist

assert raises(StopIteration, lambda: peek([]))
assert raises(IterationError, lambda: peek([]))


def test_peekn():
Expand Down