Skip to content

Commit 5934b5a

Browse files
authored
Defensively copy density matrices in simulator (#3109)
1 parent bbd8a2d commit 5934b5a

File tree

3 files changed

+54
-6
lines changed

3 files changed

+54
-6
lines changed

cirq/qis/states.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
# limitations under the License.
1414
"""Utility methods for creating vectors and matrices."""
1515

16-
from typing import (Any, Iterable, List, Optional, Sequence, Union,
17-
TYPE_CHECKING, Tuple, Type, cast)
16+
from typing import (Any, cast, Iterable, Optional, Sequence, TYPE_CHECKING,
17+
Tuple, Type, Union)
1818

1919
import itertools
2020

@@ -486,7 +486,8 @@ def to_valid_density_matrix(
486486
487487
Returns:
488488
A numpy matrix corresponding to the density matrix on the given number
489-
of qubits.
489+
of qubits. Note that this matrix may share memory with the input
490+
`density_matrix_rep`.
490491
491492
Raises:
492493
ValueError if the density_matrix_rep is not valid.

cirq/sim/density_matrix_simulator.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,8 @@ def _base_iterator(self,
257257
len(qid_shape),
258258
qid_shape=qid_shape,
259259
dtype=self._dtype)
260+
if np.may_share_memory(initial_matrix, initial_state):
261+
initial_matrix = initial_matrix.copy()
260262
measured = collections.defaultdict(
261263
bool) # type: Dict[Tuple[cirq.Qid, ...], bool]
262264
if len(circuit) == 0:
@@ -407,7 +409,7 @@ def set_density_matrix(self, density_matrix_repr: Union[int, np.ndarray]):
407409
density_matrix = np.reshape(density_matrix, sim_state_matrix.shape)
408410
np.copyto(dst=sim_state_matrix, src=density_matrix)
409411

410-
def density_matrix(self):
412+
def density_matrix(self, copy=True):
411413
"""Returns the density matrix at this step in the simulation.
412414
413415
The density matrix that is stored in this result is returned in the
@@ -435,9 +437,16 @@ def density_matrix(self):
435437
| 6 | 1 | 1 | 0 |
436438
| 7 | 1 | 1 | 1 |
437439
440+
Args:
441+
copy: If True, then the returned state is a copy of the density
442+
matrix. If False, then the density matrix is not copied,
443+
potentially saving memory. If one only needs to read derived
444+
parameters from the density matrix and store then using False
445+
can speed up simulation by eliminating a memory copy.
438446
"""
439447
size = np.prod(self._qid_shape, dtype=int)
440-
return np.reshape(self._density_matrix, (size, size))
448+
matrix = self._density_matrix.copy() if copy else self._density_matrix
449+
return np.reshape(matrix, (size, size))
441450

442451
def sample(self,
443452
qubits: List[ops.Qid],
@@ -528,7 +537,7 @@ def __init__(self, params: study.ParamResolver,
528537
final_simulator_state=final_simulator_state)
529538
size = np.prod(protocols.qid_shape(self), dtype=int)
530539
self.final_density_matrix = np.reshape(
531-
final_simulator_state.density_matrix, (size, size))
540+
final_simulator_state.density_matrix.copy(), (size, size))
532541

533542
def _value_equality_values_(self) -> Any:
534543
measurements = {

cirq/sim/density_matrix_simulator_test.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1112,3 +1112,41 @@ def test_simulate_noise_with_terminal_measurements():
11121112
result2 = simulator.run(circuit2, repetitions=10)
11131113

11141114
assert result1 == result2
1115+
1116+
1117+
def test_density_matrix_copy():
1118+
sim = cirq.DensityMatrixSimulator()
1119+
1120+
q = cirq.LineQubit(0)
1121+
circuit = cirq.Circuit(cirq.H(q), cirq.H(q))
1122+
1123+
matrices = []
1124+
for step in sim.simulate_moment_steps(circuit):
1125+
matrices.append(step.density_matrix(copy=True))
1126+
assert all(np.isclose(np.trace(x), 1.0) for x in matrices)
1127+
for x, y in itertools.combinations(matrices, 2):
1128+
assert not np.shares_memory(x, y)
1129+
1130+
# If the density matrix is not copied, then applying second Hadamard
1131+
# causes old state to be modified.
1132+
matrices = []
1133+
traces = []
1134+
for step in sim.simulate_moment_steps(circuit):
1135+
matrices.append(step.density_matrix(copy=False))
1136+
traces.append(np.trace(step.density_matrix(copy=False)))
1137+
assert any(not np.isclose(np.trace(x), 1.0) for x in matrices)
1138+
assert all(np.isclose(x, 1.0) for x in traces)
1139+
assert all(not np.shares_memory(x, y)
1140+
for x, y in itertools.combinations(matrices, 2))
1141+
1142+
1143+
def test_final_density_matrix_is_not_last_object():
1144+
sim = cirq.DensityMatrixSimulator()
1145+
1146+
q = cirq.LineQubit(0)
1147+
initial_state = np.array([[1, 0], [0, 0]], dtype=np.complex64)
1148+
circuit = cirq.Circuit(cirq.WaitGate(0)(q))
1149+
result = sim.simulate(circuit, initial_state=initial_state)
1150+
assert result.final_density_matrix is not initial_state
1151+
assert not np.shares_memory(result.final_density_matrix, initial_state)
1152+
np.testing.assert_equal(result.final_density_matrix, initial_state)

0 commit comments

Comments
 (0)