Skip to content

Commit adf8817

Browse files
authored
PYTHON-4584 Add length option to Cursor.to_list for motor compat (#1791)
1 parent f2f75fc commit adf8817

File tree

9 files changed

+164
-28
lines changed

9 files changed

+164
-28
lines changed

gridfs/asynchronous/grid_file.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1892,8 +1892,16 @@ async def next(self) -> AsyncGridOut:
18921892
next_file = await super().next()
18931893
return AsyncGridOut(self._root_collection, file_document=next_file, session=self.session)
18941894

1895-
async def to_list(self) -> list[AsyncGridOut]:
1896-
return [x async for x in self] # noqa: C416,RUF100
1895+
async def to_list(self, length: Optional[int] = None) -> list[AsyncGridOut]:
1896+
"""Convert the cursor to a list."""
1897+
if length is None:
1898+
return [x async for x in self] # noqa: C416,RUF100
1899+
if length < 1:
1900+
raise ValueError("to_list() length must be greater than 0")
1901+
ret = []
1902+
for _ in range(length):
1903+
ret.append(await self.next())
1904+
return ret
18971905

18981906
__anext__ = next
18991907

gridfs/synchronous/grid_file.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1878,8 +1878,16 @@ def next(self) -> GridOut:
18781878
next_file = super().next()
18791879
return GridOut(self._root_collection, file_document=next_file, session=self.session)
18801880

1881-
def to_list(self) -> list[GridOut]:
1882-
return [x for x in self] # noqa: C416,RUF100
1881+
def to_list(self, length: Optional[int] = None) -> list[GridOut]:
1882+
"""Convert the cursor to a list."""
1883+
if length is None:
1884+
return [x for x in self] # noqa: C416,RUF100
1885+
if length < 1:
1886+
raise ValueError("to_list() length must be greater than 0")
1887+
ret = []
1888+
for _ in range(length):
1889+
ret.append(self.next())
1890+
return ret
18831891

18841892
__next__ = next
18851893

pymongo/asynchronous/command_cursor.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -346,13 +346,17 @@ async def _try_next(self, get_more_allowed: bool) -> Optional[_DocumentType]:
346346
else:
347347
return None
348348

349-
async def _next_batch(self, result: list) -> bool:
350-
"""Get all available documents from the cursor."""
349+
async def _next_batch(self, result: list, total: Optional[int] = None) -> bool:
350+
"""Get all or some available documents from the cursor."""
351351
if not len(self._data) and not self._killed:
352352
await self._refresh()
353353
if len(self._data):
354-
result.extend(self._data)
355-
self._data.clear()
354+
if total is None:
355+
result.extend(self._data)
356+
self._data.clear()
357+
else:
358+
for _ in range(min(len(self._data), total)):
359+
result.append(self._data.popleft())
356360
return True
357361
else:
358362
return False
@@ -381,21 +385,32 @@ async def __aenter__(self) -> AsyncCommandCursor[_DocumentType]:
381385
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
382386
await self.close()
383387

384-
async def to_list(self) -> list[_DocumentType]:
388+
async def to_list(self, length: Optional[int] = None) -> list[_DocumentType]:
385389
"""Converts the contents of this cursor to a list more efficiently than ``[doc async for doc in cursor]``.
386390
387391
To use::
388392
389393
>>> await cursor.to_list()
390394
395+
Or, to read at most n items from the cursor::
396+
397+
>>> await cursor.to_list(n)
398+
391399
If the cursor is empty or has no more results, an empty list will be returned.
392400
393401
.. versionadded:: 4.9
394402
"""
395403
res: list[_DocumentType] = []
404+
remaining = length
405+
if isinstance(length, int) and length < 1:
406+
raise ValueError("to_list() length must be greater than 0")
396407
while self.alive:
397-
if not await self._next_batch(res):
408+
if not await self._next_batch(res, remaining):
398409
break
410+
if length is not None:
411+
remaining = length - len(res)
412+
if remaining == 0:
413+
break
399414
return res
400415

401416

pymongo/asynchronous/cursor.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1260,16 +1260,20 @@ async def next(self) -> _DocumentType:
12601260
else:
12611261
raise StopAsyncIteration
12621262

1263-
async def _next_batch(self, result: list) -> bool:
1264-
"""Get all available documents from the cursor."""
1263+
async def _next_batch(self, result: list, total: Optional[int] = None) -> bool:
1264+
"""Get all or some documents from the cursor."""
12651265
if not self._exhaust_checked:
12661266
self._exhaust_checked = True
12671267
await self._supports_exhaust()
12681268
if self._empty:
12691269
return False
12701270
if len(self._data) or await self._refresh():
1271-
result.extend(self._data)
1272-
self._data.clear()
1271+
if total is None:
1272+
result.extend(self._data)
1273+
self._data.clear()
1274+
else:
1275+
for _ in range(min(len(self._data), total)):
1276+
result.append(self._data.popleft())
12731277
return True
12741278
else:
12751279
return False
@@ -1286,21 +1290,32 @@ async def __aenter__(self) -> AsyncCursor[_DocumentType]:
12861290
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
12871291
await self.close()
12881292

1289-
async def to_list(self) -> list[_DocumentType]:
1293+
async def to_list(self, length: Optional[int] = None) -> list[_DocumentType]:
12901294
"""Converts the contents of this cursor to a list more efficiently than ``[doc async for doc in cursor]``.
12911295
12921296
To use::
12931297
12941298
>>> await cursor.to_list()
12951299
1300+
Or, to read at most n items from the cursor::
1301+
1302+
>>> await cursor.to_list(n)
1303+
12961304
If the cursor is empty or has no more results, an empty list will be returned.
12971305
12981306
.. versionadded:: 4.9
12991307
"""
13001308
res: list[_DocumentType] = []
1309+
remaining = length
1310+
if isinstance(length, int) and length < 1:
1311+
raise ValueError("to_list() length must be greater than 0")
13011312
while self.alive:
1302-
if not await self._next_batch(res):
1313+
if not await self._next_batch(res, remaining):
13031314
break
1315+
if length is not None:
1316+
remaining = length - len(res)
1317+
if remaining == 0:
1318+
break
13041319
return res
13051320

13061321

pymongo/synchronous/command_cursor.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -346,13 +346,17 @@ def _try_next(self, get_more_allowed: bool) -> Optional[_DocumentType]:
346346
else:
347347
return None
348348

349-
def _next_batch(self, result: list) -> bool:
350-
"""Get all available documents from the cursor."""
349+
def _next_batch(self, result: list, total: Optional[int] = None) -> bool:
350+
"""Get all or some available documents from the cursor."""
351351
if not len(self._data) and not self._killed:
352352
self._refresh()
353353
if len(self._data):
354-
result.extend(self._data)
355-
self._data.clear()
354+
if total is None:
355+
result.extend(self._data)
356+
self._data.clear()
357+
else:
358+
for _ in range(min(len(self._data), total)):
359+
result.append(self._data.popleft())
356360
return True
357361
else:
358362
return False
@@ -381,21 +385,32 @@ def __enter__(self) -> CommandCursor[_DocumentType]:
381385
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
382386
self.close()
383387

384-
def to_list(self) -> list[_DocumentType]:
388+
def to_list(self, length: Optional[int] = None) -> list[_DocumentType]:
385389
"""Converts the contents of this cursor to a list more efficiently than ``[doc for doc in cursor]``.
386390
387391
To use::
388392
389393
>>> cursor.to_list()
390394
395+
Or, to read at most n items from the cursor::
396+
397+
>>> cursor.to_list(n)
398+
391399
If the cursor is empty or has no more results, an empty list will be returned.
392400
393401
.. versionadded:: 4.9
394402
"""
395403
res: list[_DocumentType] = []
404+
remaining = length
405+
if isinstance(length, int) and length < 1:
406+
raise ValueError("to_list() length must be greater than 0")
396407
while self.alive:
397-
if not self._next_batch(res):
408+
if not self._next_batch(res, remaining):
398409
break
410+
if length is not None:
411+
remaining = length - len(res)
412+
if remaining == 0:
413+
break
399414
return res
400415

401416

pymongo/synchronous/cursor.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1258,16 +1258,20 @@ def next(self) -> _DocumentType:
12581258
else:
12591259
raise StopIteration
12601260

1261-
def _next_batch(self, result: list) -> bool:
1262-
"""Get all available documents from the cursor."""
1261+
def _next_batch(self, result: list, total: Optional[int] = None) -> bool:
1262+
"""Get all or some documents from the cursor."""
12631263
if not self._exhaust_checked:
12641264
self._exhaust_checked = True
12651265
self._supports_exhaust()
12661266
if self._empty:
12671267
return False
12681268
if len(self._data) or self._refresh():
1269-
result.extend(self._data)
1270-
self._data.clear()
1269+
if total is None:
1270+
result.extend(self._data)
1271+
self._data.clear()
1272+
else:
1273+
for _ in range(min(len(self._data), total)):
1274+
result.append(self._data.popleft())
12711275
return True
12721276
else:
12731277
return False
@@ -1284,21 +1288,32 @@ def __enter__(self) -> Cursor[_DocumentType]:
12841288
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
12851289
self.close()
12861290

1287-
def to_list(self) -> list[_DocumentType]:
1291+
def to_list(self, length: Optional[int] = None) -> list[_DocumentType]:
12881292
"""Converts the contents of this cursor to a list more efficiently than ``[doc for doc in cursor]``.
12891293
12901294
To use::
12911295
12921296
>>> cursor.to_list()
12931297
1298+
Or, to read at most n items from the cursor::
1299+
1300+
>>> cursor.to_list(n)
1301+
12941302
If the cursor is empty or has no more results, an empty list will be returned.
12951303
12961304
.. versionadded:: 4.9
12971305
"""
12981306
res: list[_DocumentType] = []
1307+
remaining = length
1308+
if isinstance(length, int) and length < 1:
1309+
raise ValueError("to_list() length must be greater than 0")
12991310
while self.alive:
1300-
if not self._next_batch(res):
1311+
if not self._next_batch(res, remaining):
13011312
break
1313+
if length is not None:
1314+
remaining = length - len(res)
1315+
if remaining == 0:
1316+
break
13021317
return res
13031318

13041319

test/asynchronous/test_cursor.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1401,6 +1401,20 @@ async def test_to_list_empty(self):
14011401
docs = await c.to_list()
14021402
self.assertEqual([], docs)
14031403

1404+
async def test_to_list_length(self):
1405+
coll = self.db.test
1406+
await coll.insert_many([{} for _ in range(5)])
1407+
self.addCleanup(coll.drop)
1408+
c = coll.find()
1409+
docs = await c.to_list(3)
1410+
self.assertEqual(len(docs), 3)
1411+
1412+
c = coll.find(batch_size=2)
1413+
docs = await c.to_list(3)
1414+
self.assertEqual(len(docs), 3)
1415+
docs = await c.to_list(3)
1416+
self.assertEqual(len(docs), 2)
1417+
14041418
@async_client_context.require_change_streams
14051419
async def test_command_cursor_to_list(self):
14061420
# Set maxAwaitTimeMS=1 to speed up the test.
@@ -1417,6 +1431,19 @@ async def test_command_cursor_to_list_empty(self):
14171431
docs = await c.to_list()
14181432
self.assertEqual([], docs)
14191433

1434+
@async_client_context.require_change_streams
1435+
async def test_command_cursor_to_list_length(self):
1436+
db = self.db
1437+
await db.drop_collection("test")
1438+
await db.test.insert_many([{"foo": 1}, {"foo": 2}])
1439+
1440+
pipeline = {"$project": {"_id": False, "foo": True}}
1441+
result = await db.test.aggregate([pipeline])
1442+
self.assertEqual(len(await result.to_list()), 2)
1443+
1444+
result = await db.test.aggregate([pipeline])
1445+
self.assertEqual(len(await result.to_list(1)), 1)
1446+
14201447

14211448
class TestRawBatchCursor(AsyncIntegrationTest):
14221449
async def test_find_raw(self):

test/test_cursor.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1392,6 +1392,20 @@ def test_to_list_empty(self):
13921392
docs = c.to_list()
13931393
self.assertEqual([], docs)
13941394

1395+
def test_to_list_length(self):
1396+
coll = self.db.test
1397+
coll.insert_many([{} for _ in range(5)])
1398+
self.addCleanup(coll.drop)
1399+
c = coll.find()
1400+
docs = c.to_list(3)
1401+
self.assertEqual(len(docs), 3)
1402+
1403+
c = coll.find(batch_size=2)
1404+
docs = c.to_list(3)
1405+
self.assertEqual(len(docs), 3)
1406+
docs = c.to_list(3)
1407+
self.assertEqual(len(docs), 2)
1408+
13951409
@client_context.require_change_streams
13961410
def test_command_cursor_to_list(self):
13971411
# Set maxAwaitTimeMS=1 to speed up the test.
@@ -1408,6 +1422,19 @@ def test_command_cursor_to_list_empty(self):
14081422
docs = c.to_list()
14091423
self.assertEqual([], docs)
14101424

1425+
@client_context.require_change_streams
1426+
def test_command_cursor_to_list_length(self):
1427+
db = self.db
1428+
db.drop_collection("test")
1429+
db.test.insert_many([{"foo": 1}, {"foo": 2}])
1430+
1431+
pipeline = {"$project": {"_id": False, "foo": True}}
1432+
result = db.test.aggregate([pipeline])
1433+
self.assertEqual(len(result.to_list()), 2)
1434+
1435+
result = db.test.aggregate([pipeline])
1436+
self.assertEqual(len(result.to_list(1)), 1)
1437+
14111438

14121439
class TestRawBatchCursor(IntegrationTest):
14131440
def test_find_raw(self):

test/test_gridfs.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,12 @@ def test_gridfs_find(self):
440440
gout = next(cursor)
441441
self.assertEqual(b"test2+", gout.read())
442442
self.assertRaises(StopIteration, cursor.__next__)
443+
cursor.rewind()
444+
items = cursor.to_list()
445+
self.assertEqual(len(items), 2)
446+
cursor.rewind()
447+
items = cursor.to_list(1)
448+
self.assertEqual(len(items), 1)
443449
cursor.close()
444450
self.assertRaises(TypeError, self.fs.find, {}, {"_id": True})
445451

0 commit comments

Comments
 (0)