Skip to content

Commit bd2e63c

Browse files
authored
Boolean Hamiltonian gate yields fewer gates (#4386)
* Boolean Hamiltonian gate yields fewer gates * Address some of the comments * Address some of the comments * Add unit test * Expand unit test * Fix unit test * Add test that is failing but should pass * More comprehensive tests * Fix code and make unit tests pass * Address comments
1 parent 3d50921 commit bd2e63c

File tree

2 files changed

+318
-5
lines changed

2 files changed

+318
-5
lines changed

cirq-core/cirq/ops/boolean_hamiltonian.py

Lines changed: 192 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,11 @@
1919
by Stuart Hadfield, https://arxiv.org/pdf/1804.09130.pdf
2020
[2] https://www.youtube.com/watch?v=AOKM9BkweVU is a useful intro
2121
[3] https://github.com/rsln-s/IEEE_QW_2020/blob/master/Slides.pdf
22+
[4] Efficient Quantum Circuits for Diagonal Unitaries Without Ancillas by Jonathan Welch, Daniel
23+
Greenbaum, Sarah Mostame, and Alán Aspuru-Guzik, https://arxiv.org/abs/1306.3991
2224
"""
25+
import itertools
26+
import functools
2327

2428
from typing import Any, Dict, Generator, List, Sequence, Tuple
2529

@@ -112,6 +116,187 @@ def _decompose_(self):
112116
)
113117

114118

119+
def _gray_code_comparator(k1: Tuple[int, ...], k2: Tuple[int, ...], flip: bool = False) -> int:
120+
"""Compares two Gray-encoded binary numbers.
121+
122+
Args:
123+
k1: A tuple of ints, representing the bits that are one. For example, 6 would be (1, 2).
124+
k2: The second number, represented similarly as k1.
125+
flip: Whether to flip the comparison.
126+
127+
Returns:
128+
-1 if k1 < k2 (or +1 if flip is true)
129+
0 if k1 == k2
130+
+1 if k1 > k2 (or -1 if flip is true)
131+
"""
132+
max_1 = k1[-1] if k1 else -1
133+
max_2 = k2[-1] if k2 else -1
134+
if max_1 != max_2:
135+
return -1 if (max_1 < max_2) ^ flip else 1
136+
if max_1 == -1:
137+
return 0
138+
return _gray_code_comparator(k1[0:-1], k2[0:-1], not flip)
139+
140+
141+
def _simplify_commuting_cnots(
142+
cnots: List[Tuple[int, int]], flip_control_and_target: bool
143+
) -> Tuple[bool, List[Tuple[int, int]]]:
144+
"""Attempts to commute CNOTs and remove cancelling pairs.
145+
146+
Commutation relations are based on 9 (flip_control_and_target=False) or 10
147+
(flip_control_target=True) of [4]:
148+
When flip_control_target=True:
149+
150+
CNOT(j, i) @ CNOT(j, k) = CNOT(j, k) @ CNOT(j, i)
151+
───X─────── ───────X───
152+
│ │
153+
───@───@─── = ───@───@───
154+
│ │
155+
───────X─── ───X───────
156+
157+
When flip_control_target=False:
158+
159+
CNOT(i, j) @ CNOT(k, j) = CNOT(k, j) @ CNOT(i, j)
160+
───@─────── ───────@───
161+
│ │
162+
───X───X─── = ───X───X───
163+
│ │
164+
───────@─── ───@───────
165+
166+
Args:
167+
cnots: A list of CNOTS, encoded as integer tuples (control, target). The code does not make
168+
any assumption as to the order of the CNOTs, but it is likely to work better if its
169+
inputs are from Gray-sorted Hamiltonians. Regardless of the order of the CNOTs, the
170+
code is conservative and should be robust to mis-ordered inputs with the only side
171+
effect being a lack of simplification.
172+
flip_control_and_target: Whether to flip control and target.
173+
174+
Returns:
175+
A tuple containing a Boolean that tells whether a simplification has been performed and the
176+
CNOT list, potentially simplified, encoded as integer tuples (control, target).
177+
"""
178+
179+
target, control = (0, 1) if flip_control_and_target else (1, 0)
180+
181+
i = 0
182+
qubit_to_index: Dict[int, int] = {cnots[i][control]: i} if cnots else {}
183+
for j in range(1, len(cnots)):
184+
if cnots[i][target] != cnots[j][target]:
185+
# The targets (resp. control) don't match, so we reset the search.
186+
i = j
187+
qubit_to_index = {cnots[j][control]: j}
188+
continue
189+
190+
if cnots[j][control] in qubit_to_index:
191+
k = qubit_to_index[cnots[j][control]]
192+
# The controls (resp. targets) are the same, so we can simplify away.
193+
cnots = [cnots[n] for n in range(len(cnots)) if n != j and n != k]
194+
# TODO(#4532): Speed up code by not returning early.
195+
return True, cnots
196+
197+
qubit_to_index[cnots[j][control]] = j
198+
199+
return False, cnots
200+
201+
202+
def _simplify_cnots_triplets(
203+
cnots: List[Tuple[int, int]], flip_control_and_target: bool
204+
) -> Tuple[bool, List[Tuple[int, int]]]:
205+
"""Simplifies CNOT pairs according to equation 11 of [4].
206+
207+
CNOT(i, j) @ CNOT(j, k) == CNOT(j, k) @ CNOT(i, k) @ CNOT(i, j)
208+
───@─────── ───────@───@───
209+
│ │ │
210+
───X───@─── = ───@───┼───X───
211+
│ │ │
212+
───────X─── ───X───X───────
213+
214+
Args:
215+
cnots: A list of CNOTS, encoded as integer tuples (control, target).
216+
flip_control_and_target: Whether to flip control and target.
217+
218+
Returns:
219+
A tuple containing a Boolean that tells whether a simplification has been performed and the
220+
CNOT list, potentially simplified, encoded as integer tuples (control, target).
221+
"""
222+
target, control = (0, 1) if flip_control_and_target else (1, 0)
223+
224+
# We investigate potential pivots sequentially.
225+
for j in range(1, len(cnots) - 1):
226+
# First, we look back for as long as the controls (resp. targets) are the same.
227+
# They all commute, so all are potential candidates for being simplified.
228+
# prev_match_index is qubit to index in `cnots` array.
229+
prev_match_index: Dict[int, int] = {}
230+
for i in range(j - 1, -1, -1):
231+
# These CNOTs have the same target (resp. control) and though they are not candidates
232+
# for simplification, since they commute, we can keep looking for candidates.
233+
if cnots[i][target] == cnots[j][target]:
234+
continue
235+
if cnots[i][control] != cnots[j][control]:
236+
break
237+
# We take a note of the control (resp. target).
238+
prev_match_index[cnots[i][target]] = i
239+
240+
# Next, we look forward for as long as the targets (resp. controls) are the
241+
# same. They all commute, so all are potential candidates for being simplified.
242+
# post_match_index is qubit to index in `cnots` array.
243+
post_match_index: Dict[int, int] = {}
244+
for k in range(j + 1, len(cnots)):
245+
# These CNOTs have the same control (resp. target) and though they are not candidates
246+
# for simplification, since they commute, we can keep looking for candidates.
247+
if cnots[j][control] == cnots[k][control]:
248+
continue
249+
if cnots[j][target] != cnots[k][target]:
250+
break
251+
# We take a note of the target (resp. control).
252+
post_match_index[cnots[k][control]] = k
253+
254+
# Among all the candidates, find if they have a match.
255+
keys = prev_match_index.keys() & post_match_index.keys()
256+
for key in keys:
257+
# We perform the swap which removes the pivot.
258+
new_idx: List[int] = (
259+
# Anything strictly before the pivot that is not the CNOT to swap.
260+
[idx for idx in range(0, j) if idx != prev_match_index[key]]
261+
# The two swapped CNOTs.
262+
+ [post_match_index[key], prev_match_index[key]]
263+
# Anything after the pivot that is not the CNOT to swap.
264+
+ [idx for idx in range(j + 1, len(cnots)) if idx != post_match_index[key]]
265+
)
266+
# Since we removed the pivot, the length should be one fewer.
267+
cnots = [cnots[idx] for idx in new_idx]
268+
# TODO(#4532): Speed up code by not returning early.
269+
return True, cnots
270+
271+
return False, cnots
272+
273+
274+
def _simplify_cnots(cnots: List[Tuple[int, int]]) -> List[Tuple[int, int]]:
275+
"""Takes a series of CNOTs and tries to applies rule to cancel out gates.
276+
277+
Algorithm based on "Efficient quantum circuits for diagonal unitaries without ancillas" by
278+
Jonathan Welch, Daniel Greenbaum, Sarah Mostame, Alán Aspuru-Guzik
279+
https://arxiv.org/abs/1306.3991
280+
281+
Args:
282+
cnots: A list of CNOTs represented as tuples of integer (control, target).
283+
284+
Returns:
285+
The simplified list of CNOTs, encoded as integer tuples (control, target).
286+
"""
287+
288+
found_simplification = True
289+
while found_simplification:
290+
for simplify_fn, flip_control_and_target in itertools.product(
291+
[_simplify_commuting_cnots, _simplify_cnots_triplets], [False, True]
292+
):
293+
found_simplification, cnots = simplify_fn(cnots, flip_control_and_target)
294+
if found_simplification:
295+
break
296+
297+
return cnots
298+
299+
115300
def _get_gates_from_hamiltonians(
116301
hamiltonian_polynomial_list: List['cirq.PauliSum'],
117302
qubit_map: Dict[str, 'cirq.Qid'],
@@ -145,16 +330,18 @@ def _apply_cnots(prevh: Tuple[int, ...], currh: Tuple[int, ...]):
145330
cnots.extend((prevh[i], prevh[-1]) for i in range(len(prevh) - 1))
146331
cnots.extend((currh[i], currh[-1]) for i in range(len(currh) - 1))
147332

148-
# TODO(tonybruguier): At this point, some CNOT gates can be cancelled out according to:
149-
# "Efficient quantum circuits for diagonal unitaries without ancillas" by Jonathan Welch,
150-
# Daniel Greenbaum, Sarah Mostame, Alán Aspuru-Guzik
151-
# https://arxiv.org/abs/1306.3991
333+
cnots = _simplify_cnots(cnots)
152334

153335
for gate in (cirq.CNOT(qubits[c], qubits[t]) for c, t in cnots):
154336
yield gate
155337

338+
sorted_hamiltonian_keys = sorted(
339+
hamiltonians.keys(), key=functools.cmp_to_key(_gray_code_comparator)
340+
)
341+
156342
previous_h: Tuple[int, ...] = ()
157-
for h, w in hamiltonians.items():
343+
for h in sorted_hamiltonian_keys:
344+
w = hamiltonians[h]
158345
yield _apply_cnots(previous_h, h)
159346

160347
if len(h) >= 1:

cirq-core/cirq/ops/boolean_hamiltonian_test.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,17 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import functools
1415
import itertools
1516
import math
17+
import random
1618

1719
import numpy as np
1820
import pytest
1921
import sympy.parsing.sympy_parser as sympy_parser
2022

2123
import cirq
24+
import cirq.ops.boolean_hamiltonian as bh
2225

2326

2427
@pytest.mark.parametrize(
@@ -98,3 +101,126 @@ def test_with_custom_names():
98101

99102
with pytest.raises(ValueError, match='Length of replacement qubits must be the same'):
100103
original_op.with_qubits(q2)
104+
105+
106+
@pytest.mark.parametrize(
107+
'n_bits,expected_hs',
108+
[
109+
(1, [(), (0,)]),
110+
(2, [(), (0,), (0, 1), (1,)]),
111+
(3, [(), (0,), (0, 1), (1,), (1, 2), (0, 1, 2), (0, 2), (2,)]),
112+
],
113+
)
114+
def test_gray_code_sorting(n_bits, expected_hs):
115+
hs_template = []
116+
for x in range(2 ** n_bits):
117+
h = []
118+
for i in range(n_bits):
119+
if x % 2 == 1:
120+
h.append(i)
121+
x -= 1
122+
x //= 2
123+
hs_template.append(tuple(sorted(h)))
124+
125+
for seed in range(10):
126+
random.seed(seed)
127+
128+
hs = hs_template.copy()
129+
random.shuffle(hs)
130+
131+
sorted_hs = sorted(list(hs), key=functools.cmp_to_key(bh._gray_code_comparator))
132+
133+
np.testing.assert_array_equal(sorted_hs, expected_hs)
134+
135+
136+
@pytest.mark.parametrize(
137+
'seq_a,seq_b,expected',
138+
[
139+
((), (), 0),
140+
((), (0,), -1),
141+
((0,), (), 1),
142+
((0,), (0,), 0),
143+
],
144+
)
145+
def test_gray_code_comparison(seq_a, seq_b, expected):
146+
assert bh._gray_code_comparator(seq_a, seq_b) == expected
147+
148+
149+
@pytest.mark.parametrize(
150+
'input_cnots,input_flip_control_and_target,expected_simplified,expected_output_cnots',
151+
[
152+
# Empty inputs don't get simplified.
153+
([], False, False, []),
154+
([], True, False, []),
155+
# Single CNOTs don't get simplified.
156+
([(0, 1)], False, False, [(0, 1)]),
157+
([(0, 1)], True, False, [(0, 1)]),
158+
# Simplify away two CNOTs that are identical:
159+
([(0, 1), (0, 1)], False, True, []),
160+
([(0, 1), (0, 1)], True, True, []),
161+
# Also simplify away if there's another CNOT in between.
162+
([(0, 1), (2, 1), (0, 1)], False, True, [(2, 1)]),
163+
([(0, 1), (0, 2), (0, 1)], True, True, [(0, 2)]),
164+
# However, the in-between has to share the same target/control.
165+
([(0, 1), (0, 2), (0, 1)], False, False, [(0, 1), (0, 2), (0, 1)]),
166+
([(0, 1), (2, 1), (0, 1)], True, False, [(0, 1), (2, 1), (0, 1)]),
167+
# Can simplify, but violates CNOT ordering assumption
168+
([(0, 1), (2, 3), (0, 1)], False, False, [(0, 1), (2, 3), (0, 1)]),
169+
],
170+
)
171+
def test_simplify_commuting_cnots(
172+
input_cnots, input_flip_control_and_target, expected_simplified, expected_output_cnots
173+
):
174+
actual_simplified, actual_output_cnots = bh._simplify_commuting_cnots(
175+
input_cnots, input_flip_control_and_target
176+
)
177+
assert actual_simplified == expected_simplified
178+
assert actual_output_cnots == expected_output_cnots
179+
180+
181+
@pytest.mark.parametrize(
182+
'input_cnots,input_flip_control_and_target,expected_simplified,expected_output_cnots',
183+
[
184+
# Empty inputs don't get simplified.
185+
([], False, False, []),
186+
([], True, False, []),
187+
# Single CNOTs don't get simplified.
188+
([(0, 1)], False, False, [(0, 1)]),
189+
([(0, 1)], True, False, [(0, 1)]),
190+
# Simplify according to equation 11 of [4].
191+
([(2, 1), (2, 0), (1, 0)], False, True, [(1, 0), (2, 1)]),
192+
([(1, 2), (0, 2), (0, 1)], True, True, [(0, 1), (1, 2)]),
193+
# Same as above, but with a intervening CNOTs that prevent simplifications.
194+
([(2, 1), (2, 0), (100, 101), (1, 0)], False, False, [(2, 1), (2, 0), (100, 101), (1, 0)]),
195+
([(2, 1), (100, 101), (2, 0), (1, 0)], False, False, [(2, 1), (100, 101), (2, 0), (1, 0)]),
196+
# swap (2, 1) and (1, 0) around (2, 0)
197+
([(2, 1), (2, 3), (2, 0), (3, 0), (1, 0)], False, True, [(2, 3), (1, 0), (2, 1), (3, 0)]),
198+
([(2, 1), (2, 0), (2, 3), (3, 0), (1, 0)], False, True, [(1, 0), (2, 1), (2, 3), (3, 0)]),
199+
([(2, 3), (2, 1), (2, 0), (3, 0), (1, 0)], False, True, [(2, 3), (1, 0), (2, 1), (3, 0)]),
200+
([(2, 1), (2, 3), (3, 0), (2, 0), (1, 0)], False, True, [(2, 3), (3, 0), (1, 0), (2, 1)]),
201+
([(2, 1), (2, 3), (2, 0), (1, 0), (3, 0)], False, True, [(2, 3), (1, 0), (2, 1), (3, 0)]),
202+
],
203+
)
204+
def test_simplify_cnots_triplets(
205+
input_cnots, input_flip_control_and_target, expected_simplified, expected_output_cnots
206+
):
207+
actual_simplified, actual_output_cnots = bh._simplify_cnots_triplets(
208+
input_cnots, input_flip_control_and_target
209+
)
210+
assert actual_simplified == expected_simplified
211+
assert actual_output_cnots == expected_output_cnots
212+
213+
# Check that the unitaries are the same.
214+
qubit_ids = set(sum(input_cnots, ()))
215+
qubits = {qubit_id: cirq.NamedQubit(f"{qubit_id}") for qubit_id in qubit_ids}
216+
217+
target, control = (0, 1) if input_flip_control_and_target else (1, 0)
218+
219+
circuit_input = cirq.Circuit()
220+
for input_cnot in input_cnots:
221+
circuit_input.append(cirq.CNOT(qubits[input_cnot[target]], qubits[input_cnot[control]]))
222+
circuit_actual = cirq.Circuit()
223+
for actual_cnot in actual_output_cnots:
224+
circuit_actual.append(cirq.CNOT(qubits[actual_cnot[target]], qubits[actual_cnot[control]]))
225+
226+
np.testing.assert_allclose(cirq.unitary(circuit_input), cirq.unitary(circuit_actual), atol=1e-6)

0 commit comments

Comments
 (0)