Skip to content

Commit c193c48

Browse files
authored
Fix bug in controlled() method of CCX and CCZ gates (#6237)
* Fix bug in controlled() method of CCX and CCZ gates * Fix typing
1 parent 5e342c2 commit c193c48

File tree

5 files changed

+48
-11
lines changed

5 files changed

+48
-11
lines changed

cirq-core/cirq/ops/three_qubit_gates.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -206,11 +206,13 @@ def controlled(
206206
"""
207207
if num_controls == 0:
208208
return self
209+
sub_gate: 'cirq.Gate' = self
210+
if self._global_shift == 0:
211+
sub_gate = controlled_gate.ControlledGate(
212+
common_gates.ZPowGate(exponent=self._exponent), num_controls=2
213+
)
209214
return controlled_gate.ControlledGate(
210-
controlled_gate.ControlledGate(
211-
common_gates.ZPowGate(exponent=self._exponent, global_shift=self._global_shift),
212-
num_controls=2,
213-
),
215+
sub_gate,
214216
num_controls=num_controls,
215217
control_values=control_values,
216218
control_qid_shape=control_qid_shape,
@@ -518,11 +520,13 @@ def controlled(
518520
"""
519521
if num_controls == 0:
520522
return self
523+
sub_gate: 'cirq.Gate' = self
524+
if self._global_shift == 0:
525+
sub_gate = controlled_gate.ControlledGate(
526+
common_gates.XPowGate(exponent=self._exponent), num_controls=2
527+
)
521528
return controlled_gate.ControlledGate(
522-
controlled_gate.ControlledGate(
523-
common_gates.XPowGate(exponent=self._exponent, global_shift=self._global_shift),
524-
num_controls=2,
525-
),
529+
sub_gate,
526530
num_controls=num_controls,
527531
control_values=control_values,
528532
control_qid_shape=control_qid_shape,

cirq-core/cirq/testing/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,10 @@
3030

3131
from cirq.testing.consistent_channels import assert_consistent_channel, assert_consistent_mixture
3232

33-
from cirq.testing.consistent_controlled_gate_op import assert_controlled_and_controlled_by_identical
33+
from cirq.testing.consistent_controlled_gate_op import (
34+
assert_controlled_and_controlled_by_identical,
35+
assert_controlled_unitary_consistent,
36+
)
3437

3538
from cirq.testing.consistent_decomposition import (
3639
assert_decompose_ends_at_default_gateset,

cirq-core/cirq/testing/consistent_controlled_gate_op.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414

1515
from typing import Sequence, Optional, Union, Collection
1616

17-
from cirq import devices, ops
17+
from cirq import devices, ops, protocols
18+
import numpy as np
1819

1920

2021
def assert_controlled_and_controlled_by_identical(
@@ -34,6 +35,19 @@ def assert_controlled_and_controlled_by_identical(
3435
_assert_gate_consistent(gate, num_control, control_value)
3536

3637

38+
def assert_controlled_unitary_consistent(gate: ops.Gate):
39+
"""Checks that unitary of ControlledGate(gate) is consistent with gate.controlled()."""
40+
41+
u_orig = protocols.unitary(ops.ControlledGate(gate))
42+
u_controlled = protocols.unitary(gate.controlled())
43+
np.testing.assert_allclose(
44+
u_orig,
45+
u_controlled,
46+
atol=1e-6,
47+
err_msg=f"Unitary for gate.controlled() is inconsistent for {gate=}",
48+
)
49+
50+
3751
def _assert_gate_consistent(
3852
gate: ops.Gate,
3953
num_controls: int,

cirq-core/cirq/testing/consistent_controlled_gate_op_test.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,3 +76,14 @@ def test_assert_controlled_and_controlled_by_identical():
7676
cirq.testing.assert_controlled_and_controlled_by_identical(
7777
GoodGate(), num_controls=[1, 2], control_values=[(1,), (1, 1, 1)]
7878
)
79+
80+
81+
def test_assert_controlled_unitary_consistent():
82+
cirq.testing.assert_controlled_and_controlled_by_identical(
83+
GoodGate(exponent=0.5, global_shift=1 / 3)
84+
)
85+
86+
with pytest.raises(AssertionError):
87+
cirq.testing.assert_controlled_and_controlled_by_identical(
88+
BadGate(exponent=0.5, global_shift=1 / 3)
89+
)

cirq-core/cirq/testing/consistent_protocols.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,10 @@
3636
from cirq.testing.consistent_resolve_parameters import assert_consistent_resolve_parameters
3737
from cirq.testing.consistent_specified_has_unitary import assert_specifies_has_unitary_if_unitary
3838
from cirq.testing.equivalent_repr_eval import assert_equivalent_repr
39-
from cirq.testing.consistent_controlled_gate_op import assert_controlled_and_controlled_by_identical
39+
from cirq.testing.consistent_controlled_gate_op import (
40+
assert_controlled_and_controlled_by_identical,
41+
assert_controlled_unitary_consistent,
42+
)
4043
from cirq.testing.consistent_unitary import assert_unitary_is_consistent
4144

4245

@@ -167,6 +170,8 @@ def _assert_meets_standards_helper(
167170
assert_eigen_shifts_is_consistent_with_eigen_components(val)
168171
if isinstance(val, ops.Gate) and protocols.has_mixture(val):
169172
assert_controlled_and_controlled_by_identical(val)
173+
if protocols.has_unitary(val):
174+
assert_controlled_unitary_consistent(val)
170175

171176

172177
def assert_commutes_magic_method_consistent_with_unitaries(

0 commit comments

Comments
 (0)