diff --git a/toolz/__init__.py b/toolz/__init__.py index 8e3b0b9a..c7bfa05c 100644 --- a/toolz/__init__.py +++ b/toolz/__init__.py @@ -17,7 +17,7 @@ # Aliases comp = compose -from . import curried, sandbox +from . import curried, exceptions, sandbox functoolz._sigs.create_signature_registry() diff --git a/toolz/exceptions.py b/toolz/exceptions.py new file mode 100644 index 00000000..aba77133 --- /dev/null +++ b/toolz/exceptions.py @@ -0,0 +1,6 @@ + +__all__ = ('IterationError',) + + +class IterationError(RuntimeError): + pass diff --git a/toolz/itertoolz.py b/toolz/itertoolz.py index b8165162..bc15a633 100644 --- a/toolz/itertoolz.py +++ b/toolz/itertoolz.py @@ -7,6 +7,7 @@ from random import Random from collections.abc import Sequence from toolz.utils import no_default +from toolz.exceptions import IterationError __all__ = ('remove', 'accumulate', 'groupby', 'merge_sorted', 'interleave', @@ -373,7 +374,9 @@ def first(seq): >>> first('ABC') 'A' """ - return next(iter(seq)) + for rv in seq: + return rv + raise IterationError("Received empty sequence") def second(seq): @@ -383,8 +386,14 @@ def second(seq): 'B' """ seq = iter(seq) - next(seq) - return next(seq) + for item in seq: + break + else: + raise IterationError("Received empty sequence") + for item in seq: + return item + else: + raise IterationError("Length of sequence is < 2") def nth(n, seq): @@ -396,7 +405,9 @@ def nth(n, seq): if isinstance(seq, (tuple, list, Sequence)): return seq[n] else: - return next(itertools.islice(seq, n, None)) + for rv in itertools.islice(seq, n, None): + return rv + raise IterationError("Length of seq is < %d" % n) def last(seq): @@ -531,8 +542,9 @@ def interpose(el, seq): [1, 'a', 2, 'a', 3] """ inposed = concat(zip(itertools.repeat(el), seq)) - next(inposed) - return inposed + for _ in inposed: + return inposed + raise IterationError("Received empty sequence") def frequencies(seq): @@ -722,13 +734,16 @@ def partition_all(n, seq): """ args = [iter(seq)] * n it = zip_longest(*args, fillvalue=no_pad) + try: prev = next(it) except StopIteration: return + for item in it: yield prev prev = item + if prev[-1] is no_pad: try: # If seq defines __len__, then @@ -997,8 +1012,11 @@ def peek(seq): [0, 1, 2, 3, 4] """ iterator = iter(seq) - item = next(iterator) - return item, itertools.chain((item,), iterator) + for peeked in iterator: + break + else: + raise IterationError("Received empty sequence") + return peeked, itertools.chain((peeked,), iterator) def peekn(n, seq): @@ -1016,7 +1034,7 @@ def peekn(n, seq): """ iterator = iter(seq) peeked = tuple(take(n, iterator)) - return peeked, itertools.chain(iter(peeked), iterator) + return peeked, itertools.chain(peeked, iterator) def random_sample(prob, seq, random_state=None): diff --git a/toolz/tests/test_itertoolz.py b/toolz/tests/test_itertoolz.py index 61618725..eb5e7a4b 100644 --- a/toolz/tests/test_itertoolz.py +++ b/toolz/tests/test_itertoolz.py @@ -14,8 +14,10 @@ sliding_window, count, partition, partition_all, take_nth, pluck, join, diff, topk, peek, peekn, random_sample) -from operator import add, mul +from toolz.exceptions import IterationError + +from operator import add, mul # is comparison will fail between this and no_default no_default2 = loads(dumps('__no__default__')) @@ -127,7 +129,7 @@ def test_nth(): assert nth(2, iter('ABCDE')) == 'C' assert nth(1, (3, 2, 1)) == 2 assert nth(0, {'foo': 'bar'}) == 'foo' - assert raises(StopIteration, lambda: nth(10, {10: 'foo'})) + assert raises(IterationError, lambda: nth(10, {10: 'foo'})) assert nth(-2, 'ABCDE') == 'D' assert raises(ValueError, lambda: nth(-2, iter('ABCDE'))) @@ -136,12 +138,15 @@ def test_first(): assert first('ABCDE') == 'A' assert first((3, 2, 1)) == 3 assert isinstance(first({0: 'zero', 1: 'one'}), int) + assert raises(IterationError, lambda: first([])) def test_second(): assert second('ABCDE') == 'B' assert second((3, 2, 1)) == 2 assert isinstance(second({0: 'zero', 1: 'one'}), int) + assert raises(IterationError, lambda: second([1])) + assert raises(IterationError, lambda: second([])) def test_last(): @@ -228,6 +233,7 @@ def test_interpose(): assert "tXaXrXzXaXn" == "".join(interpose("X", "tarzan")) assert list(interpose(0, itertools.repeat(1, 4))) == [1, 0, 1, 0, 1, 0, 1] assert list(interpose('.', ['a', 'b', 'c'])) == ['a', '.', 'b', '.', 'c'] + assert raises(IterationError, lambda: interpose('a', [])) def test_frequencies(): @@ -510,8 +516,7 @@ def test_peek(): element, blist = peek(alist) assert element == alist[0] assert list(blist) == alist - - assert raises(StopIteration, lambda: peek([])) + assert raises(IterationError, lambda: peek([])) def test_peekn():