Skip to content

Commit 7bace85

Browse files
committed
Override gate_op controlled_by and three_qubit gates controlled methods
1 parent 6cfdda7 commit 7bace85

File tree

7 files changed

+151
-25
lines changed

7 files changed

+151
-25
lines changed

cirq-core/cirq/interop/quirk/cells/input_rotation_cells_test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,9 @@ def test_input_rotation_cells():
6464
assert_url_to_circuit_returns(
6565
'{"cols":[["•","Z^(A/2^n)","inputA2"]]}',
6666
diagram="""
67-
0: ───@───────────
67+
0: ───@^(A/2^2)───
6868
69-
1: ───Z^(A/2^2)───
69+
1: ───@───────────
7070
7171
2: ───A0──────────
7272
@@ -89,9 +89,9 @@ def test_input_rotation_cells():
8989
assert_url_to_circuit_returns(
9090
'{"cols":[["•","X^(-A/2^n)","inputA2"]]}',
9191
diagram="""
92-
0: ───@────────────
92+
0: ───@^(-A/2^2)───
9393
94-
1: ───X^(-A/2^2)───
94+
1: ───X────────────
9595
9696
2: ───A0───────────
9797

cirq-core/cirq/interop/quirk/url_to_circuit_test.py

Lines changed: 15 additions & 15 deletions
Large diffs are not rendered by default.

cirq-core/cirq/ops/common_gates_test.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,26 @@ def test_no_specialized_control_for_global_shift_non_zero(gate, specialized_type
193193
assert not isinstance(gate.controlled(), specialized_type)
194194

195195

196+
@pytest.mark.parametrize(
197+
'input_gate, q, expected_gate',
198+
[
199+
(cirq.Z, cirq.LineQubit.range(2), cirq.CZ),
200+
(cirq.Z, cirq.LineQubit.range(3), cirq.CCZ),
201+
(cirq.X, cirq.LineQubit.range(2), cirq.CX),
202+
(cirq.X, cirq.LineQubit.range(3), cirq.CCX),
203+
(cirq.Z ** 0.5, cirq.LineQubit.range(2), cirq.CZ ** 0.5),
204+
(cirq.Z ** 0.5, cirq.LineQubit.range(3), cirq.CCZ ** 0.5),
205+
(cirq.X ** 0.5, cirq.LineQubit.range(2), cirq.CX ** 0.5),
206+
(cirq.X ** 0.5, cirq.LineQubit.range(3), cirq.CCX ** 0.5),
207+
],
208+
)
209+
def test_specialized_control_gate_and_op_same(input_gate, q, expected_gate):
210+
num_controls = len(q) - 1
211+
op1 = input_gate.controlled(num_controls).on(*q[1:], q[0]) # Way-1: gate.controlled().on()
212+
op2 = input_gate.on(q[0]).controlled_by(*q[1:]) # Way-2: gate.on().controlled_by()
213+
assert op1 == op2 == expected_gate(*q[1:], q[0])
214+
215+
196216
@pytest.mark.parametrize(
197217
'gate, matrix',
198218
[

cirq-core/cirq/ops/gate_operation.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
AbstractSet,
2020
Any,
2121
cast,
22+
Collection,
2223
Dict,
2324
FrozenSet,
2425
Iterable,
@@ -333,5 +334,17 @@ def _equal_up_to_global_phase_(
333334
return False
334335
return protocols.equal_up_to_global_phase(self.gate, other.gate, atol=atol)
335336

337+
def controlled_by(
338+
self,
339+
*control_qubits: 'cirq.Qid',
340+
control_values: Optional[Sequence[Union[int, Collection[int]]]] = None,
341+
) -> 'cirq.Operation':
342+
qubits = tuple(control_qubits)
343+
return self._gate.controlled(
344+
num_controls=len(qubits),
345+
control_values=control_values,
346+
control_qid_shape=tuple(q.dimension for q in qubits),
347+
).on(*(qubits + self._qubits))
348+
336349

337350
TV = TypeVar('TV', bound=raw_types.Gate)

cirq-core/cirq/ops/raw_types_test.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -127,13 +127,13 @@ def test_gate():
127127

128128

129129
def test_op():
130-
a, b, c = cirq.LineQubit.range(3)
130+
a, b, c, d = cirq.LineQubit.range(4)
131131
g = ValiGate()
132-
op = g(a)
133-
assert op.controlled_by() is op
134-
controlled_op = op.controlled_by(b, c)
132+
op = g(a, b)
133+
assert op.controlled_by() == op
134+
controlled_op = op.controlled_by(c, d)
135135
assert controlled_op.sub_operation == op
136-
assert controlled_op.controls == (b, c)
136+
assert controlled_op.controls == (c, d)
137137

138138

139139
def test_op_validate():

cirq-core/cirq/ops/three_qubit_gates.py

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,17 @@
1414

1515
"""Common quantum gates that target three qubits."""
1616

17-
from typing import AbstractSet, Any, List, Optional, Tuple, TYPE_CHECKING
17+
from typing import (
18+
AbstractSet,
19+
Any,
20+
Collection,
21+
List,
22+
Optional,
23+
Sequence,
24+
Tuple,
25+
TYPE_CHECKING,
26+
Union,
27+
)
1828

1929
import numpy as np
2030
import sympy
@@ -29,6 +39,7 @@
2939
gate_features,
3040
pauli_gates,
3141
swap_gates,
42+
raw_types,
3243
)
3344

3445
if TYPE_CHECKING:
@@ -167,6 +178,32 @@ def __str__(self) -> str:
167178
return 'CCZ'
168179
return f'CCZ**{self._exponent}'
169180

181+
def controlled(
182+
self,
183+
num_controls: int = None,
184+
control_values: Optional[Sequence[Union[int, Collection[int]]]] = None,
185+
control_qid_shape: Optional[Tuple[int, ...]] = None,
186+
) -> raw_types.Gate:
187+
"""
188+
Returns a controlled `ZPowGate` with two additional controls.
189+
190+
The `controlled` method of the `Gate` class, of which this class is a
191+
child, returns a `ControlledGate` with `sub_gate = self`. This method
192+
overrides this behavior to return a `ControlledGate` with
193+
`sub_gate = ZPowGate`.
194+
"""
195+
if num_controls == 0:
196+
return self
197+
return controlled_gate.ControlledGate(
198+
controlled_gate.ControlledGate(
199+
common_gates.ZPowGate(exponent=self._exponent, global_shift=self._global_shift),
200+
num_controls=2,
201+
),
202+
num_controls=num_controls,
203+
control_values=control_values,
204+
control_qid_shape=control_qid_shape,
205+
)
206+
170207

171208
@value.value_equality()
172209
class ThreeQubitDiagonalGate(gate_features.ThreeQubitGate):
@@ -424,6 +461,32 @@ def __str__(self) -> str:
424461
return 'TOFFOLI'
425462
return f'TOFFOLI**{self._exponent}'
426463

464+
def controlled(
465+
self,
466+
num_controls: int = None,
467+
control_values: Optional[Sequence[Union[int, Collection[int]]]] = None,
468+
control_qid_shape: Optional[Tuple[int, ...]] = None,
469+
) -> raw_types.Gate:
470+
"""
471+
Returns a controlled `XPowGate` with two additional controls.
472+
473+
The `controlled` method of the `Gate` class, of which this class is a
474+
child, returns a `ControlledGate` with `sub_gate = self`. This method
475+
overrides this behavior to return a `ControlledGate` with
476+
`sub_gate = XPowGate`.
477+
"""
478+
if num_controls == 0:
479+
return self
480+
return controlled_gate.ControlledGate(
481+
controlled_gate.ControlledGate(
482+
common_gates.XPowGate(exponent=self._exponent, global_shift=self._global_shift),
483+
num_controls=2,
484+
),
485+
num_controls=num_controls,
486+
control_values=control_values,
487+
control_qid_shape=control_qid_shape,
488+
)
489+
427490

428491
@value.value_equality()
429492
class CSwapGate(gate_features.ThreeQubitGate, gate_features.InterchangeableQubitsGate):
@@ -569,6 +632,29 @@ def __str__(self) -> str:
569632
def __repr__(self) -> str:
570633
return 'cirq.FREDKIN'
571634

635+
def controlled(
636+
self,
637+
num_controls: int = None,
638+
control_values: Optional[Sequence[Union[int, Collection[int]]]] = None,
639+
control_qid_shape: Optional[Tuple[int, ...]] = None,
640+
) -> raw_types.Gate:
641+
"""
642+
Returns a controlled `SWAP` with one additional control.
643+
644+
The `controlled` method of the `Gate` class, of which this class is a
645+
child, returns a `ControlledGate` with `sub_gate = self`. This method
646+
overrides this behavior to return a `ControlledGate` with
647+
`sub_gate = SWAP`.
648+
"""
649+
if num_controls == 0:
650+
return self
651+
return controlled_gate.ControlledGate(
652+
controlled_gate.ControlledGate(swap_gates.SWAP, num_controls=1),
653+
num_controls=num_controls,
654+
control_values=control_values,
655+
control_qid_shape=control_qid_shape,
656+
)
657+
572658

573659
CCZ = CCZPowGate()
574660
document(

cirq-core/cirq/ops/three_qubit_gates_test.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,3 +319,10 @@ def test_resolve(resolve_fn):
319319
diagonal_gate = resolve_fn(diagonal_gate, {'b': 19})
320320
assert diagonal_gate == cirq.ThreeQubitDiagonalGate(diagonal_angles)
321321
assert not cirq.is_parameterized(diagonal_gate)
322+
323+
324+
def test_controlled_ops_consistency():
325+
a, b, c, d = cirq.LineQubit.range(4)
326+
assert cirq.CCX(a, b, c).controlled_by(d) == cirq.CCX(d, b, c).controlled_by(a)
327+
assert cirq.CCZ(a, b, c).controlled_by(d) == cirq.CCZ(d, b, c).controlled_by(a)
328+
assert cirq.CSWAP(a, b, c).controlled_by(d) == cirq.CSWAP(d, b, c).controlled_by(a)

0 commit comments

Comments
 (0)