Skip to content

Commit 0dca32f

Browse files
authored
fix: harden 'query.stream' against retriable exceptions (pylint-dev#456)
Closes pylint-dev#223.
1 parent 335e2c4 commit 0dca32f

File tree

2 files changed

+164
-7
lines changed

2 files changed

+164
-7
lines changed

google/cloud/firestore_v1/query.py

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
"""
2121
from google.cloud import firestore_v1
2222
from google.cloud.firestore_v1.base_document import DocumentSnapshot
23+
from google.api_core import exceptions # type: ignore
2324
from google.api_core import gapic_v1 # type: ignore
2425
from google.api_core import retry as retries # type: ignore
2526

@@ -208,6 +209,29 @@ def _chunkify(
208209
):
209210
return
210211

212+
def _get_stream_iterator(self, transaction, retry, timeout):
213+
"""Helper method for :meth:`stream`."""
214+
request, expected_prefix, kwargs = self._prep_stream(
215+
transaction, retry, timeout,
216+
)
217+
218+
response_iterator = self._client._firestore_api.run_query(
219+
request=request, metadata=self._client._rpc_metadata, **kwargs,
220+
)
221+
222+
return response_iterator, expected_prefix
223+
224+
def _retry_query_after_exception(self, exc, retry, transaction):
225+
"""Helper method for :meth:`stream`."""
226+
if transaction is None: # no snapshot-based retry inside transaction
227+
if retry is gapic_v1.method.DEFAULT:
228+
transport = self._client._firestore_api._transport
229+
gapic_callable = transport.run_query
230+
retry = gapic_callable._retry
231+
return retry._predicate(exc)
232+
233+
return False
234+
211235
def stream(
212236
self,
213237
transaction=None,
@@ -244,15 +268,28 @@ def stream(
244268
:class:`~google.cloud.firestore_v1.document.DocumentSnapshot`:
245269
The next document that fulfills the query.
246270
"""
247-
request, expected_prefix, kwargs = self._prep_stream(
271+
response_iterator, expected_prefix = self._get_stream_iterator(
248272
transaction, retry, timeout,
249273
)
250274

251-
response_iterator = self._client._firestore_api.run_query(
252-
request=request, metadata=self._client._rpc_metadata, **kwargs,
253-
)
275+
last_snapshot = None
276+
277+
while True:
278+
try:
279+
response = next(response_iterator, None)
280+
except exceptions.GoogleAPICallError as exc:
281+
if self._retry_query_after_exception(exc, retry, transaction):
282+
new_query = self.start_after(last_snapshot)
283+
response_iterator, _ = new_query._get_stream_iterator(
284+
transaction, retry, timeout,
285+
)
286+
continue
287+
else:
288+
raise
289+
290+
if response is None: # EOI
291+
break
254292

255-
for response in response_iterator:
256293
if self._all_descendants:
257294
snapshot = _collection_group_query_response_to_snapshot(
258295
response, self._parent
@@ -262,6 +299,7 @@ def stream(
262299
response, self._parent, expected_prefix
263300
)
264301
if snapshot is not None:
302+
last_snapshot = snapshot
265303
yield snapshot
266304

267305
def on_snapshot(self, callback: Callable) -> Watch:

tests/unit/v1/test_query.py

Lines changed: 121 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,15 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from google.cloud.firestore_v1.types.document import Document
16-
from google.cloud.firestore_v1.types.firestore import RunQueryResponse
1715
import types
1816
import unittest
1917

2018
import mock
2119
import pytest
2220

21+
from google.api_core import gapic_v1
22+
from google.cloud.firestore_v1.types.document import Document
23+
from google.cloud.firestore_v1.types.firestore import RunQueryResponse
2324
from tests.unit.v1.test_base_query import _make_credentials
2425
from tests.unit.v1.test_base_query import _make_cursor_pb
2526
from tests.unit.v1.test_base_query import _make_query_response
@@ -456,6 +457,124 @@ def test_stream_w_collection_group(self):
456457
metadata=client._rpc_metadata,
457458
)
458459

460+
def _stream_w_retriable_exc_helper(
461+
self,
462+
retry=gapic_v1.method.DEFAULT,
463+
timeout=None,
464+
transaction=None,
465+
expect_retry=True,
466+
):
467+
from google.api_core import exceptions
468+
from google.cloud.firestore_v1 import _helpers
469+
470+
if transaction is not None:
471+
expect_retry = False
472+
473+
# Create a minimal fake GAPIC.
474+
firestore_api = mock.Mock(spec=["run_query", "_transport"])
475+
transport = firestore_api._transport = mock.Mock(spec=["run_query"])
476+
stub = transport.run_query = mock.create_autospec(
477+
gapic_v1.method._GapicCallable
478+
)
479+
stub._retry = mock.Mock(spec=["_predicate"])
480+
stub._predicate = lambda exc: True # pragma: NO COVER
481+
482+
# Attach the fake GAPIC to a real client.
483+
client = _make_client()
484+
client._firestore_api_internal = firestore_api
485+
486+
# Make a **real** collection reference as parent.
487+
parent = client.collection("dee")
488+
489+
# Add a dummy response to the minimal fake GAPIC.
490+
_, expected_prefix = parent._parent_info()
491+
name = "{}/sleep".format(expected_prefix)
492+
data = {"snooze": 10}
493+
response_pb = _make_query_response(name=name, data=data)
494+
retriable_exc = exceptions.ServiceUnavailable("testing")
495+
496+
def _stream_w_exception(*_args, **_kw):
497+
yield response_pb
498+
raise retriable_exc
499+
500+
firestore_api.run_query.side_effect = [_stream_w_exception(), iter([])]
501+
kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout)
502+
503+
# Execute the query and check the response.
504+
query = self._make_one(parent)
505+
506+
get_response = query.stream(transaction=transaction, **kwargs)
507+
508+
self.assertIsInstance(get_response, types.GeneratorType)
509+
if expect_retry:
510+
returned = list(get_response)
511+
else:
512+
returned = [next(get_response)]
513+
with self.assertRaises(exceptions.ServiceUnavailable):
514+
next(get_response)
515+
516+
self.assertEqual(len(returned), 1)
517+
snapshot = returned[0]
518+
self.assertEqual(snapshot.reference._path, ("dee", "sleep"))
519+
self.assertEqual(snapshot.to_dict(), data)
520+
521+
# Verify the mock call.
522+
parent_path, _ = parent._parent_info()
523+
calls = firestore_api.run_query.call_args_list
524+
525+
if expect_retry:
526+
self.assertEqual(len(calls), 2)
527+
else:
528+
self.assertEqual(len(calls), 1)
529+
530+
if transaction is not None:
531+
expected_transaction_id = transaction.id
532+
else:
533+
expected_transaction_id = None
534+
535+
self.assertEqual(
536+
calls[0],
537+
mock.call(
538+
request={
539+
"parent": parent_path,
540+
"structured_query": query._to_protobuf(),
541+
"transaction": expected_transaction_id,
542+
},
543+
metadata=client._rpc_metadata,
544+
**kwargs,
545+
),
546+
)
547+
548+
if expect_retry:
549+
new_query = query.start_after(snapshot)
550+
self.assertEqual(
551+
calls[1],
552+
mock.call(
553+
request={
554+
"parent": parent_path,
555+
"structured_query": new_query._to_protobuf(),
556+
"transaction": None,
557+
},
558+
metadata=client._rpc_metadata,
559+
**kwargs,
560+
),
561+
)
562+
563+
def test_stream_w_retriable_exc_w_defaults(self):
564+
self._stream_w_retriable_exc_helper()
565+
566+
def test_stream_w_retriable_exc_w_retry(self):
567+
retry = mock.Mock(spec=["_predicate"])
568+
retry._predicate = lambda exc: False
569+
self._stream_w_retriable_exc_helper(retry=retry, expect_retry=False)
570+
571+
def test_stream_w_retriable_exc_w_transaction(self):
572+
from google.cloud.firestore_v1 import transaction
573+
574+
txn = transaction.Transaction(client=mock.Mock(spec=[]))
575+
txn._id = b"DEADBEEF"
576+
self._stream_w_retriable_exc_helper(transaction=txn)
577+
459578
@mock.patch("google.cloud.firestore_v1.query.Watch", autospec=True)
460579
def test_on_snapshot(self, watch):
461580
query = self._make_one(mock.sentinel.parent)

0 commit comments

Comments
 (0)