Skip to content

Commit 86479ae

Browse files
authored
Integrate StreamManager with run_sweep() (#6233)
* Integrate StreamManager with run() methods * Added tests and improved docstring * Deferred run_batch and run_calibration stream migration * Fix test failures
1 parent 5c36dc0 commit 86479ae

File tree

7 files changed

+510
-160
lines changed

7 files changed

+510
-160
lines changed

cirq-google/cirq_google/engine/engine.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ async def run_sweep_async(
278278
job_description: Optional[str] = None,
279279
job_labels: Optional[Dict[str, str]] = None,
280280
) -> engine_job.EngineJob:
281-
"""Runs the supplied Circuit via Quantum Engine.Creates
281+
"""Runs the supplied Circuit via Quantum Engine.
282282
283283
In contrast to run, this runs across multiple parameter sweeps, and
284284
does not block until a result is returned.
@@ -312,20 +312,35 @@ async def run_sweep_async(
312312
Raises:
313313
ValueError: If no gate set is provided.
314314
"""
315-
engine_program = await self.create_program_async(
316-
program, program_id, description=program_description, labels=program_labels
317-
)
318-
return await engine_program.run_sweep_async(
319-
job_id=job_id,
320-
params=params,
321-
repetitions=repetitions,
315+
if not program_id:
316+
program_id = _make_random_id('prog-')
317+
if not job_id:
318+
job_id = _make_random_id('job-')
319+
run_context = self.context._serialize_run_context(params, repetitions)
320+
321+
stream_job_response_future = self.context.client.run_job_over_stream(
322+
project_id=self.project_id,
323+
program_id=str(program_id),
324+
program_description=program_description,
325+
program_labels=program_labels,
326+
code=self.context._serialize_program(program),
327+
job_id=str(job_id),
322328
processor_ids=processor_ids,
323-
description=job_description,
324-
labels=job_labels,
329+
run_context=run_context,
330+
job_description=job_description,
331+
job_labels=job_labels,
332+
)
333+
return engine_job.EngineJob(
334+
self.project_id,
335+
str(program_id),
336+
str(job_id),
337+
self.context,
338+
stream_job_response_future=stream_job_response_future,
325339
)
326340

327341
run_sweep = duet.sync(run_sweep_async)
328342

343+
# TODO(#5996) Migrate to stream client
329344
async def run_batch_async(
330345
self,
331346
programs: Sequence[cirq.AbstractCircuit],
@@ -406,6 +421,7 @@ async def run_batch_async(
406421

407422
run_batch = duet.sync(run_batch_async)
408423

424+
# TODO(#5996) Migrate to stream client
409425
async def run_calibration_async(
410426
self,
411427
layers: List['cirq_google.CalibrationLayer'],

cirq-google/cirq_google/engine/engine_client.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from cirq._compat import cached_property
3939
from cirq_google.cloud import quantum
4040
from cirq_google.engine.asyncio_executor import AsyncioExecutor
41+
from cirq_google.engine import stream_manager
4142

4243
_M = TypeVar('_M', bound=proto.Message)
4344
_R = TypeVar('_R')
@@ -105,6 +106,10 @@ async def make_client():
105106

106107
return self._executor.submit(make_client).result()
107108

109+
@cached_property
110+
def _stream_manager(self) -> stream_manager.StreamManager:
111+
return stream_manager.StreamManager(self.grpc_client)
112+
108113
async def _send_request_async(self, func: Callable[[_M], Awaitable[_R]], request: _M) -> _R:
109114
"""Sends a request by invoking an asyncio callable."""
110115
return await self._run_retry_async(func, request)
@@ -697,6 +702,79 @@ async def get_job_results_async(
697702

698703
get_job_results = duet.sync(get_job_results_async)
699704

705+
def run_job_over_stream(
706+
self,
707+
project_id: str,
708+
program_id: str,
709+
code: any_pb2.Any,
710+
job_id: str,
711+
processor_ids: Sequence[str],
712+
run_context: any_pb2.Any,
713+
program_description: Optional[str] = None,
714+
program_labels: Optional[Dict[str, str]] = None,
715+
priority: Optional[int] = None,
716+
job_description: Optional[str] = None,
717+
job_labels: Optional[Dict[str, str]] = None,
718+
) -> duet.AwaitableFuture[Union[quantum.QuantumResult, quantum.QuantumJob]]:
719+
"""Runs a job with the given program and job information over a stream.
720+
721+
Sends the request over the Quantum Engine QuantumRunStream bidirectional stream, and returns
722+
a future for the stream response. The future will be completed with a `QuantumResult` if
723+
the job is successful; otherwise, it will be completed with a QuantumJob.
724+
725+
Args:
726+
project_id: A project_id of the parent Google Cloud Project.
727+
program_id: Unique ID of the program within the parent project.
728+
code: Properly serialized program code.
729+
job_id: Unique ID of the job within the parent program.
730+
run_context: Properly serialized run context.
731+
processor_ids: List of processor id for running the program.
732+
program_description: An optional description to set on the program.
733+
program_labels: Optional set of labels to set on the program.
734+
priority: Optional priority to run at, 0-1000.
735+
job_description: Optional description to set on the job.
736+
job_labels: Optional set of labels to set on the job.
737+
738+
Returns:
739+
A future for the job result, or the job if the job has failed.
740+
741+
Raises:
742+
ValueError: If the priority is not between 0 and 1000.
743+
"""
744+
# Check program to run and program parameters.
745+
if priority and not 0 <= priority < 1000:
746+
raise ValueError('priority must be between 0 and 1000')
747+
748+
project_name = _project_name(project_id)
749+
750+
program_name = _program_name_from_ids(project_id, program_id)
751+
program = quantum.QuantumProgram(name=program_name, code=code)
752+
if program_description:
753+
program.description = program_description
754+
if program_labels:
755+
program.labels.update(program_labels)
756+
757+
job = quantum.QuantumJob(
758+
name=_job_name_from_ids(project_id, program_id, job_id),
759+
scheduling_config=quantum.SchedulingConfig(
760+
processor_selector=quantum.SchedulingConfig.ProcessorSelector(
761+
processor_names=[
762+
_processor_name_from_ids(project_id, processor_id)
763+
for processor_id in processor_ids
764+
]
765+
)
766+
),
767+
run_context=run_context,
768+
)
769+
if priority:
770+
job.scheduling_config.priority = priority
771+
if job_description:
772+
job.description = job_description
773+
if job_labels:
774+
job.labels.update(job_labels)
775+
776+
return self._stream_manager.submit(project_name, program, job)
777+
700778
async def list_processors_async(self, project_id: str) -> List[quantum.QuantumProcessor]:
701779
"""Returns a list of Processors that the user has visibility to in the
702780
current Engine project. The names of these processors are used to

0 commit comments

Comments
 (0)