diff --git a/google/cloud/bigquery/client.py b/google/cloud/bigquery/client.py
index bb4d80c73..811e9ef03 100644
--- a/google/cloud/bigquery/client.py
+++ b/google/cloud/bigquery/client.py
@@ -2045,6 +2045,7 @@ def _get_query_results(
         location: Optional[str] = None,
         timeout: TimeoutType = DEFAULT_TIMEOUT,
         page_size: int = 0,
+        start_index: Optional[int] = None,
     ) -> _QueryResults:
         """Get the query results object for a query job.
 
@@ -2063,9 +2064,12 @@
                 before using ``retry``. If set, this connection timeout may be
                 increased to a minimum value. This prevents retries on what
                 would otherwise be a successful response.
-            page_size (int):
+            page_size (Optional[int]):
                 Maximum number of rows in a single response. See maxResults in
                 the jobs.getQueryResults REST API.
+            start_index (Optional[int]):
+                Zero-based index of the starting row. See startIndex in the
+                jobs.getQueryResults REST API.
 
         Returns:
             google.cloud.bigquery.query._QueryResults:
@@ -2095,6 +2099,9 @@
         if location is not None:
             extra_params["location"] = location
 
+        if start_index is not None:
+            extra_params["startIndex"] = start_index
+
         path = "/projects/{}/queries/{}".format(project, job_id)
 
         # This call is typically made in a polling loop that checks whether the
diff --git a/google/cloud/bigquery/job/query.py b/google/cloud/bigquery/job/query.py
index 954a46963..4d95f0e71 100644
--- a/google/cloud/bigquery/job/query.py
+++ b/google/cloud/bigquery/job/query.py
@@ -1409,6 +1409,7 @@ def _reload_query_results(
         retry: "retries.Retry" = DEFAULT_RETRY,
         timeout: Optional[float] = None,
         page_size: int = 0,
+        start_index: Optional[int] = None,
     ):
         """Refresh the cached query results unless already cached and complete.
 
@@ -1421,6 +1422,9 @@
             page_size (int):
                 Maximum number of rows in a single response. See maxResults in
                 the jobs.getQueryResults REST API.
+            start_index (Optional[int]):
+                Zero-based index of the starting row. See startIndex in the
+                jobs.getQueryResults REST API.
         """
         # Optimization: avoid a call to jobs.getQueryResults if it's already
         # been fetched, e.g. from jobs.query first page of results.
@@ -1468,6 +1472,7 @@
             location=self.location,
             timeout=transport_timeout,
             page_size=page_size,
+            start_index=start_index,
         )
 
     def result(  # type: ignore  # (incompatible with supertype)
@@ -1570,6 +1575,9 @@ def result(  # type: ignore  # (incompatible with supertype)
         if page_size is not None:
             reload_query_results_kwargs["page_size"] = page_size
 
+        if start_index is not None:
+            reload_query_results_kwargs["start_index"] = start_index
+
         try:
             retry_do_query = getattr(self, "_retry_do_query", None)
             if retry_do_query is not None:
diff --git a/google/cloud/bigquery/table.py b/google/cloud/bigquery/table.py
index 3ffd5ca56..861f806b4 100644
--- a/google/cloud/bigquery/table.py
+++ b/google/cloud/bigquery/table.py
@@ -1987,12 +1987,19 @@ def _get_next_page_response(self):
             return response
 
         params = self._get_query_params()
+
+        # If the user has provided page_size and start_index, we need to pass
+        # start_index for the first page, but for all subsequent pages, we
+        # should not pass start_index. We make a shallow copy of params and do
+        # not alter the original, so if the user iterates the results again,
+        # start_index is preserved.
+        params_copy = copy.copy(params)
         if self._page_size is not None:
             if self.page_number and "startIndex" in params:
-                del params["startIndex"]
+                del params_copy["startIndex"]
 
         return self.api_request(
-            method=self._HTTP_METHOD, path=self.path, query_params=params
+            method=self._HTTP_METHOD, path=self.path, query_params=params_copy
         )
 
     @property
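The table.py hunk above implements a simple rule: build each page's request parameters from a shallow copy, keep startIndex only for the first page, and leave the caller's dict untouched so a second iteration of the same RowIterator starts from start_index again. A minimal standalone sketch of that rule follows; `next_page_params` is a hypothetical helper for illustration, not part of the library, and the guard on `page_size` mirrors the `self._page_size is not None` check in the hunk.

```python
import copy
from typing import Optional


def next_page_params(params: dict, page_number: int, page_size: Optional[int]) -> dict:
    """Build one page's query params without mutating the caller's dict."""
    params_copy = copy.copy(params)  # shallow copy; `params` is left intact
    if page_size is not None and page_number and "startIndex" in params_copy:
        # Pages after the first are addressed by pageToken, and the backend
        # accepts only one of startIndex / pageToken per request.
        del params_copy["startIndex"]
    return params_copy


base = {"startIndex": 1, "maxResults": 3}
assert "startIndex" in next_page_params(base, page_number=0, page_size=3)
assert "startIndex" not in next_page_params(base, page_number=1, page_size=3)
assert base == {"startIndex": 1, "maxResults": 3}  # re-iteration starts at 1 again
```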
diff --git a/tests/unit/job/test_query.py b/tests/unit/job/test_query.py
index 1df65279d..46b802aa3 100644
--- a/tests/unit/job/test_query.py
+++ b/tests/unit/job/test_query.py
@@ -1682,6 +1682,78 @@ def test_result_with_start_index(self):
             tabledata_list_request[1]["query_params"]["maxResults"], page_size
         )
 
+    def test_result_with_start_index_multi_page(self):
+        # When the results span multiple pages and the user has set
+        # start_index, we should supply start_index to the server in the
+        # first request only. Subsequent requests pass page_token instead,
+        # because the server allows only one of start_index and page_token
+        # per request.
+        from google.cloud.bigquery.table import RowIterator
+
+        query_resource = {
+            "jobComplete": True,
+            "jobReference": {"projectId": self.PROJECT, "jobId": self.JOB_ID},
+            "schema": {"fields": [{"name": "col1", "type": "STRING"}]},
+            "totalRows": "7",
+        }
+
+        # Although the result has 7 rows, only 6 are returned in total,
+        # because start_index is 1.
+        tabledata_resource_1 = {
+            "totalRows": "7",
+            "pageToken": "page_token_1",
+            "rows": [
+                {"f": [{"v": "abc"}]},
+                {"f": [{"v": "def"}]},
+                {"f": [{"v": "ghi"}]},
+            ],
+        }
+        tabledata_resource_2 = {
+            "totalRows": "7",
+            "pageToken": None,
+            "rows": [
+                {"f": [{"v": "jkl"}]},
+                {"f": [{"v": "mno"}]},
+                {"f": [{"v": "pqr"}]},
+            ],
+        }
+
+        connection = make_connection(
+            query_resource, tabledata_resource_1, tabledata_resource_2
+        )
+        client = _make_client(self.PROJECT, connection=connection)
+        resource = self._make_resource(ended=True)
+        job = self._get_target_class().from_api_repr(resource, client)
+
+        start_index = 1
+        page_size = 3
+
+        result = job.result(page_size=page_size, start_index=start_index)
+
+        self.assertIsInstance(result, RowIterator)
+        self.assertEqual(result.total_rows, 7)
+
+        rows = list(result)
+
+        self.assertEqual(len(rows), 6)
+        self.assertEqual(len(connection.api_request.call_args_list), 3)
+
+        # The first tabledata call has both startIndex and maxResults.
+        tabledata_list_request_1 = connection.api_request.call_args_list[1]
+        self.assertEqual(
+            tabledata_list_request_1[1]["query_params"]["startIndex"], start_index
+        )
+        self.assertEqual(
+            tabledata_list_request_1[1]["query_params"]["maxResults"], page_size
+        )
+
+        # The second tabledata call has maxResults but no startIndex.
+        tabledata_list_request_2 = connection.api_request.call_args_list[2]
+        self.assertFalse("startIndex" in tabledata_list_request_2[1]["query_params"])
+        self.assertEqual(
+            tabledata_list_request_2[1]["query_params"]["maxResults"], page_size
+        )
+
     def test_result_error(self):
         from google.cloud import exceptions
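For reference, this is how the fix surfaces to users: a result set read with both `page_size` and `start_index` now pages correctly past the first response, which is exactly what the test above asserts. A minimal usage sketch, assuming default credentials; the table name is a placeholder:

```python
from google.cloud import bigquery

client = bigquery.Client()
job = client.query("SELECT name FROM `my-project.my_dataset.my_table`")

# startIndex is sent only on the first jobs.getQueryResults call; later
# pages are fetched with pageToken, so every row from start_index onward
# comes back exactly once.
for row in job.result(page_size=3, start_index=1):
    print(row["name"])
```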