Skip to content

Commit 849849d

Browse files
Provide sweep-iterator API. (#4094)
Fixes #4064. This PR provides `*_iter` methods for all simulator "sweep" methods. External simulators which implement the old list-based APIs will still function properly (through the magic of `@value.alternative`), but these simulators are advised to switch to the iterator-based API for better performance. Note that the list-based API is _not_ deprecated, and we have no plans to remove it at this time.
1 parent b4c445f commit 849849d

File tree

9 files changed

+285
-61
lines changed

9 files changed

+285
-61
lines changed

cirq-core/cirq/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,7 @@
429429
chosen_angle_to_half_turns,
430430
Duration,
431431
DURATION_LIKE,
432+
GenericMetaImplementAnyOneOf,
432433
LinearDict,
433434
MEASUREMENT_KEY_SEPARATOR,
434435
MeasurementKey,

cirq-core/cirq/protocols/json_test_data/spec.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,10 @@
119119
'ThreeQubitGate',
120120
'TwoQubitGate',
121121
'ABCMetaImplementAnyOneOf',
122+
'GenericMetaImplementAnyOneOf',
123+
'SimulatesAmplitudes',
124+
'SimulatesExpectationValues',
125+
'SimulatesFinalState',
122126
# protocols:
123127
'SupportsActOn',
124128
'SupportsApplyChannel',

cirq-core/cirq/sim/simulator.py

Lines changed: 139 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,14 @@ def run_sweep(
7373
params: study.Sweepable,
7474
repetitions: int = 1,
7575
) -> List[study.Result]:
76+
return list(self.run_sweep_iter(program, params, repetitions))
77+
78+
def run_sweep_iter(
79+
self,
80+
program: 'cirq.Circuit',
81+
params: study.Sweepable,
82+
repetitions: int = 1,
83+
) -> Iterator[study.Result]:
7684
"""Runs the supplied Circuit, mimicking quantum hardware.
7785
7886
In contrast to run, this allows for sweeping over different parameter
@@ -92,7 +100,6 @@ def run_sweep(
92100

93101
_verify_unique_measurement_keys(program)
94102

95-
trial_results = [] # type: List[study.Result]
96103
for param_resolver in study.to_resolvers(params):
97104
measurements = {}
98105
if repetitions == 0:
@@ -102,12 +109,9 @@ def run_sweep(
102109
measurements = self._run(
103110
circuit=program, param_resolver=param_resolver, repetitions=repetitions
104111
)
105-
trial_results.append(
106-
study.Result.from_single_parameter_set(
107-
params=param_resolver, measurements=measurements
108-
)
112+
yield study.Result.from_single_parameter_set(
113+
params=param_resolver, measurements=measurements
109114
)
110-
return trial_results
111115

112116
@abc.abstractmethod
113117
def _run(
@@ -131,13 +135,13 @@ def _run(
131135
raise NotImplementedError()
132136

133137

134-
class SimulatesAmplitudes(metaclass=abc.ABCMeta):
138+
class SimulatesAmplitudes(metaclass=value.ABCMetaImplementAnyOneOf):
135139
"""Simulator that computes final amplitudes of given bitstrings.
136140
137141
Given a circuit and a list of bitstrings, computes the amplitudes
138142
of the given bitstrings in the state obtained by applying the circuit
139143
to the all zeros state. Implementors of this interface should implement
140-
the compute_amplitudes_sweep method.
144+
the compute_amplitudes_sweep_iter method.
141145
"""
142146

143147
def compute_amplitudes(
@@ -169,14 +173,42 @@ def compute_amplitudes(
169173
program, bitstrings, study.ParamResolver(param_resolver), qubit_order
170174
)[0]
171175

172-
@abc.abstractmethod
173176
def compute_amplitudes_sweep(
174177
self,
175178
program: 'cirq.Circuit',
176179
bitstrings: Sequence[int],
177180
params: study.Sweepable,
178181
qubit_order: ops.QubitOrderOrList = ops.QubitOrder.DEFAULT,
179182
) -> Sequence[Sequence[complex]]:
183+
"""Wraps computed amplitudes in a list.
184+
185+
Prefer overriding `compute_amplitudes_sweep_iter`.
186+
"""
187+
return list(self.compute_amplitudes_sweep_iter(program, bitstrings, params, qubit_order))
188+
189+
def _compute_amplitudes_sweep_to_iter(
190+
self,
191+
program: 'cirq.Circuit',
192+
bitstrings: Sequence[int],
193+
params: study.Sweepable,
194+
qubit_order: ops.QubitOrderOrList = ops.QubitOrder.DEFAULT,
195+
) -> Iterator[Sequence[complex]]:
196+
if type(self).compute_amplitudes_sweep == SimulatesAmplitudes.compute_amplitudes_sweep:
197+
raise RecursionError(
198+
"Must define either compute_amplitudes_sweep or compute_amplitudes_sweep_iter."
199+
)
200+
yield from self.compute_amplitudes_sweep(program, bitstrings, params, qubit_order)
201+
202+
@value.alternative(
203+
requires='compute_amplitudes_sweep', implementation=_compute_amplitudes_sweep_to_iter
204+
)
205+
def compute_amplitudes_sweep_iter(
206+
self,
207+
program: 'cirq.Circuit',
208+
bitstrings: Sequence[int],
209+
params: study.Sweepable,
210+
qubit_order: ops.QubitOrderOrList = ops.QubitOrder.DEFAULT,
211+
) -> Iterator[Sequence[complex]]:
180212
"""Computes the desired amplitudes.
181213
182214
The initial state is assumed to be the all zeros state.
@@ -193,20 +225,20 @@ def compute_amplitudes_sweep(
193225
ordering of the computational basis states.
194226
195227
Returns:
196-
List of lists of amplitudes. The outer dimension indexes the
197-
circuit parameters and the inner dimension indexes the bitstrings.
228+
An Iterator over lists of amplitudes. The outer dimension indexes
229+
the circuit parameters and the inner dimension indexes bitstrings.
198230
"""
199231
raise NotImplementedError()
200232

201233

202-
class SimulatesExpectationValues(metaclass=abc.ABCMeta):
234+
class SimulatesExpectationValues(metaclass=value.ABCMetaImplementAnyOneOf):
203235
"""Simulator that computes exact expectation values of observables.
204236
205237
Given a circuit and an observable map, computes exact (to float precision)
206238
expectation values for each observable at the end of the circuit.
207239
208240
Implementors of this interface should implement the
209-
simulate_expectation_values_sweep method.
241+
simulate_expectation_values_sweep_iter method.
210242
"""
211243

212244
def simulate_expectation_values(
@@ -257,7 +289,6 @@ def simulate_expectation_values(
257289
permit_terminal_measurements,
258290
)[0]
259291

260-
@abc.abstractmethod
261292
def simulate_expectation_values_sweep(
262293
self,
263294
program: 'cirq.Circuit',
@@ -267,6 +298,60 @@ def simulate_expectation_values_sweep(
267298
initial_state: Any = None,
268299
permit_terminal_measurements: bool = False,
269300
) -> List[List[float]]:
301+
"""Wraps computed expectation values in a list.
302+
303+
Prefer overriding `simulate_expectation_values_sweep_iter`.
304+
"""
305+
return list(
306+
self.simulate_expectation_values_sweep_iter(
307+
program,
308+
observables,
309+
params,
310+
qubit_order,
311+
initial_state,
312+
permit_terminal_measurements,
313+
)
314+
)
315+
316+
def _simulate_expectation_values_sweep_to_iter(
317+
self,
318+
program: 'cirq.Circuit',
319+
observables: Union['cirq.PauliSumLike', List['cirq.PauliSumLike']],
320+
params: 'study.Sweepable',
321+
qubit_order: ops.QubitOrderOrList = ops.QubitOrder.DEFAULT,
322+
initial_state: Any = None,
323+
permit_terminal_measurements: bool = False,
324+
) -> Iterator[List[float]]:
325+
if (
326+
type(self).simulate_expectation_values_sweep
327+
== SimulatesExpectationValues.simulate_expectation_values_sweep
328+
):
329+
raise RecursionError(
330+
"Must define either simulate_expectation_values_sweep or "
331+
"simulate_expectation_values_sweep_iter."
332+
)
333+
yield from self.simulate_expectation_values_sweep(
334+
program,
335+
observables,
336+
params,
337+
qubit_order,
338+
initial_state,
339+
permit_terminal_measurements,
340+
)
341+
342+
@value.alternative(
343+
requires='simulate_expectation_values_sweep',
344+
implementation=_simulate_expectation_values_sweep_to_iter,
345+
)
346+
def simulate_expectation_values_sweep_iter(
347+
self,
348+
program: 'cirq.Circuit',
349+
observables: Union['cirq.PauliSumLike', List['cirq.PauliSumLike']],
350+
params: 'study.Sweepable',
351+
qubit_order: ops.QubitOrderOrList = ops.QubitOrder.DEFAULT,
352+
initial_state: Any = None,
353+
permit_terminal_measurements: bool = False,
354+
) -> Iterator[List[float]]:
270355
"""Simulates the supplied circuit and calculates exact expectation
271356
values for the given observables on its final state, sweeping over the
272357
given params.
@@ -291,21 +376,23 @@ def simulate_expectation_values_sweep(
291376
ruining expectation value calculations.
292377
293378
Returns:
294-
A list of expectation-value lists. The outer index determines the
295-
sweep, and the inner index determines the observable. For instance,
296-
results[1][3] would select the fourth observable measured in the
297-
second sweep.
379+
An Iterator over expectation-value lists. The outer index determines
380+
the sweep, and the inner index determines the observable. For
381+
instance, results[1][3] would select the fourth observable measured
382+
in the second sweep.
298383
299384
Raises:
300385
ValueError if 'program' has terminal measurement(s) and
301386
'permit_terminal_measurements' is False.
302387
"""
303388

304389

305-
class SimulatesFinalState(Generic[TSimulationTrialResult], metaclass=abc.ABCMeta):
390+
class SimulatesFinalState(
391+
Generic[TSimulationTrialResult], metaclass=value.GenericMetaImplementAnyOneOf
392+
):
306393
"""Simulator that allows access to the simulator's final state.
307394
308-
Implementors of this interface should implement the simulate_sweep
395+
Implementors of this interface should implement the simulate_sweep_iter
309396
method. This simulator only returns the state of the quantum system
310397
for the final step of a simulation. This simulator state may be a state
311398
vector, the density matrix, or another representation, depending on the
@@ -342,14 +429,38 @@ def simulate(
342429
program, study.ParamResolver(param_resolver), qubit_order, initial_state
343430
)[0]
344431

345-
@abc.abstractmethod
346432
def simulate_sweep(
347433
self,
348434
program: 'cirq.Circuit',
349435
params: study.Sweepable,
350436
qubit_order: ops.QubitOrderOrList = ops.QubitOrder.DEFAULT,
351437
initial_state: Any = None,
352438
) -> List[TSimulationTrialResult]:
439+
"""Wraps computed states in a list.
440+
441+
Prefer overriding `simulate_sweep_iter`.
442+
"""
443+
return list(self.simulate_sweep_iter(program, params, qubit_order, initial_state))
444+
445+
def _simulate_sweep_to_iter(
446+
self,
447+
program: 'cirq.Circuit',
448+
params: study.Sweepable,
449+
qubit_order: ops.QubitOrderOrList = ops.QubitOrder.DEFAULT,
450+
initial_state: Any = None,
451+
) -> Iterator[TSimulationTrialResult]:
452+
if type(self).simulate_sweep == SimulatesFinalState.simulate_sweep:
453+
raise RecursionError("Must define either simulate_sweep or simulate_sweep_iter.")
454+
yield from self.simulate_sweep(program, params, qubit_order, initial_state)
455+
456+
@value.alternative(requires='simulate_sweep', implementation=_simulate_sweep_to_iter)
457+
def simulate_sweep_iter(
458+
self,
459+
program: 'cirq.Circuit',
460+
params: study.Sweepable,
461+
qubit_order: ops.QubitOrderOrList = ops.QubitOrder.DEFAULT,
462+
initial_state: Any = None,
463+
) -> Iterator[TSimulationTrialResult]:
353464
"""Simulates the supplied Circuit.
354465
355466
This method returns a result which allows access to the entire final
@@ -367,7 +478,7 @@ def simulate_sweep(
367478
documentation of the implementing class for details.
368479
369480
Returns:
370-
List of SimulationTrialResults for this run, one for each
481+
Iterator over SimulationTrialResults for this run, one for each
371482
possible parameter resolver.
372483
"""
373484
raise NotImplementedError()
@@ -391,13 +502,13 @@ class SimulatesIntermediateState(
391502
a state vector.
392503
"""
393504

394-
def simulate_sweep(
505+
def simulate_sweep_iter(
395506
self,
396507
program: 'cirq.Circuit',
397508
params: study.Sweepable,
398509
qubit_order: ops.QubitOrderOrList = ops.QubitOrder.DEFAULT,
399510
initial_state: Any = None,
400-
) -> List[TSimulationTrialResult]:
511+
) -> Iterator[TSimulationTrialResult]:
401512
"""Simulates the supplied Circuit.
402513
403514
This method returns a result which allows access to the entire
@@ -419,7 +530,6 @@ def simulate_sweep(
419530
List of SimulationTrialResults for this run, one for each
420531
possible parameter resolver.
421532
"""
422-
trial_results = []
423533
qubit_order = ops.QubitOrder.as_qubit_order(qubit_order)
424534
for param_resolver in study.to_resolvers(params):
425535
all_step_results = self.simulate_moment_steps(
@@ -429,14 +539,11 @@ def simulate_sweep(
429539
for step_result in all_step_results:
430540
for k, v in step_result.measurements.items():
431541
measurements[k] = np.array(v, dtype=np.uint8)
432-
trial_results.append(
433-
self._create_simulator_trial_result(
434-
params=param_resolver,
435-
measurements=measurements,
436-
final_simulator_state=step_result._simulator_state(),
437-
)
542+
yield self._create_simulator_trial_result(
543+
params=param_resolver,
544+
measurements=measurements,
545+
final_simulator_state=step_result._simulator_state(),
438546
)
439-
return trial_results
440547

441548
def simulate_moment_steps(
442549
self,

0 commit comments

Comments
 (0)