Skip to content

Commit 11ece2c

Browse files
committed
Merge branch 'master' of https://github.com/quantumlib/cirq into controlled_op_gate_fix
2 parents 7bace85 + cea9eec commit 11ece2c

File tree

18 files changed

+288
-367
lines changed

18 files changed

+288
-367
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -191,11 +191,6 @@ jobs:
191191
- name: Install requirements
192192
run: |
193193
pip install -r dev_tools/requirements/deps/protos.txt
194-
sudo apt install curl gnupg
195-
curl -fsSL https://bazel.build/bazel-release.pub.gpg | gpg --dearmor > bazel.gpg
196-
sudo mv bazel.gpg /etc/apt/trusted.gpg.d/
197-
echo "deb [arch=amd64] https://storage.googleapis.com/bazel-apt stable jdk1.8" | sudo tee /etc/apt/sources.list.d/bazel.list
198-
sudo apt update && sudo apt install bazel
199194
- name: Build protos
200195
run: check/build-changed-protos
201196
coverage:
@@ -335,4 +330,4 @@ jobs:
335330
- name: Install npm dependencies
336331
run: check/npm ci
337332
- name: Test
338-
run: check/ts-coverage
333+
run: check/ts-coverage

WORKSPACE

Lines changed: 0 additions & 62 deletions
This file was deleted.

check/build-changed-protos

Lines changed: 2 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -71,30 +71,9 @@ if [[ ! -z "$changed" ]]; then
7171
done
7272
fi
7373

74-
# Check if any of the build files have changed.
75-
build_changed=$(git diff --name-only ${rev} -- \
76-
| grep "google/api/.*BUILD$" \
77-
| sed -e "s/BUILD//")
78-
79-
if [[ -z "$build_changed" ]]; then
80-
echo -e "No BUILD files changed."
81-
else
82-
# Build all the targets for the changed BUILD rule.
83-
echo -e "BUILD file changed in directories: '${build_changed[@]}'"
84-
for build_prefix in $build_changed
85-
do
86-
(set -x; bazel build "${build_prefix}...:*")
87-
done
88-
fi
89-
90-
# We potentially will attempt to build a second time, but this is not
91-
# a problem as bazel returns immediately.
74+
# We potentially will attempt to build a second time.
9275
for base in $changed
9376
do
94-
(set -x; bazel build "${base}_proto"; \
95-
bazel build "${base}_py_proto"; \
96-
bazel build "${base}_cc_proto")
97-
9877
python -m grpc_tools.protoc -I=cirq-google --python_out=cirq-google --mypy_out=cirq-google ${base}.proto
9978
done
10079

@@ -109,4 +88,4 @@ if [[ ! -z "$untracked" ]]; then
10988
echo -e "\033[31m ${generated}\033[0m"
11089
done
11190
exit 1
112-
fi
91+
fi

check/build-protos

Lines changed: 0 additions & 30 deletions
This file was deleted.

cirq-core/cirq/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,7 @@
259259
QubitOrderOrList,
260260
QubitPermutationGate,
261261
reset,
262+
reset_each,
262263
ResetChannel,
263264
riswap,
264265
Rx,

cirq-core/cirq/devices/noise_model.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -88,12 +88,7 @@ def is_virtual_moment(self, moment: 'cirq.Moment') -> bool:
8888
"""
8989
if not moment.operations:
9090
return False
91-
return all(
92-
[
93-
isinstance(op, ops.TaggedOperation) and ops.VirtualTag() in op.tags
94-
for op in moment.operations
95-
]
96-
)
91+
return all(ops.VirtualTag() in op.tags for op in moment)
9792

9893
def _noisy_moments_impl_moment(
9994
self, moments: 'Iterable[cirq.Moment]', system_qubits: Sequence['cirq.Qid']

cirq-core/cirq/linalg/transformations.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ def fractional_swap(target):
287287
return out
288288

289289

290-
def partial_trace(tensor: np.ndarray, keep_indices: List[int]) -> np.ndarray:
290+
def partial_trace(tensor: np.ndarray, keep_indices: Sequence[int]) -> np.ndarray:
291291
"""Takes the partial trace of a given tensor.
292292
293293
The input tensor must have shape `(d_0, ..., d_{k-1}, d_0, ..., d_{k-1})`.
@@ -620,12 +620,13 @@ def factor_density_matrix(
620620
`remainder` means the sub-matrix on the remaining axes, in the same
621621
order as the original density matrix.
622622
"""
623-
axes1 = list(axes) + [i + int(t.ndim / 2) for i in axes]
624-
extracted, remainder = factor_state_vector(t, axes1, validate=False)
623+
extracted = partial_trace(t, axes)
624+
remaining_axes = [i for i in range(t.ndim // 2) if i not in axes]
625+
remainder = partial_trace(t, remaining_axes)
625626
if validate:
626627
t1 = density_matrix_kronecker_product(extracted, remainder)
627-
axes2 = list(axes) + [i for i in range(int(t.ndim / 2)) if i not in axes]
628-
t2 = transpose_density_matrix_to_axis_order(t1, axes2)
628+
product_axes = list(axes) + remaining_axes
629+
t2 = transpose_density_matrix_to_axis_order(t1, product_axes)
629630
if not np.allclose(t2, t, atol=atol):
630631
raise ValueError('The tensor cannot be factored by the requested axes')
631632
return extracted, remainder

cirq-core/cirq/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
PhaseDampingChannel,
4646
PhaseFlipChannel,
4747
reset,
48+
reset_each,
4849
ResetChannel,
4950
)
5051

cirq-core/cirq/ops/common_channels.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
"""Quantum channels that are commonly used in the literature."""
1616

1717
import itertools
18-
from typing import Any, Dict, Iterable, Optional, Sequence, Tuple, Union, TYPE_CHECKING
18+
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union, TYPE_CHECKING
1919

2020
import numpy as np
2121

@@ -789,6 +789,11 @@ def reset(qubit: 'cirq.Qid') -> raw_types.Operation:
789789
return ResetChannel(qubit.dimension).on(qubit)
790790

791791

792+
def reset_each(*qubits: 'cirq.Qid') -> List[raw_types.Operation]:
793+
"""Returns a list of `ResetChannel` instances on the given qubits."""
794+
return [ResetChannel(q.dimension).on(q) for q in qubits]
795+
796+
792797
@value.value_equality
793798
class PhaseDampingChannel(gate_features.SingleQubitGate):
794799
"""Dampen qubit phase.

cirq-core/cirq/ops/common_channels_test.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -241,11 +241,22 @@ def test_deprecated_on_each_for_depolarizing_channel_one_qubit():
241241

242242

243243
def test_deprecated_on_each_for_depolarizing_channel_two_qubits():
244-
q0, q1 = cirq.LineQubit.range(2)
244+
q0, q1, q2, q3, q4, q5 = cirq.LineQubit.range(6)
245245
op = cirq.DepolarizingChannel(p=0.1, n_qubits=2)
246246

247-
with pytest.raises(ValueError, match="one qubit"):
247+
op.on_each([(q0, q1)])
248+
op.on_each([(q0, q1), (q2, q3)])
249+
op.on_each(zip([q0, q2, q4], [q1, q3, q5]))
250+
op.on_each((q0, q1))
251+
op.on_each([q0, q1])
252+
with pytest.raises(ValueError, match='Inputs to multi-qubit gates must be Sequence'):
248253
op.on_each(q0, q1)
254+
with pytest.raises(ValueError, match='All values in sequence should be Qids'):
255+
op.on_each([('bogus object 0', 'bogus object 1')])
256+
with pytest.raises(ValueError, match='All values in sequence should be Qids'):
257+
op.on_each(['01'])
258+
with pytest.raises(ValueError, match='All values in sequence should be Qids'):
259+
op.on_each([(False, None)])
249260

250261

251262
def test_depolarizing_channel_apply_two_qubits():
@@ -506,6 +517,16 @@ def test_reset_act_on():
506517
)
507518

508519

520+
def test_reset_each():
521+
qubits = cirq.LineQubit.range(8)
522+
for n in range(len(qubits) + 1):
523+
ops = cirq.reset_each(*qubits[:n])
524+
assert len(ops) == n
525+
for i, op in enumerate(ops):
526+
assert isinstance(op.gate, cirq.ResetChannel)
527+
assert op.qubits == (qubits[i],)
528+
529+
509530
def test_phase_damping_channel():
510531
d = cirq.phase_damp(0.3)
511532
np.testing.assert_almost_equal(

cirq-core/cirq/ops/gate_features.py

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
"""
1919

2020
import abc
21-
from typing import Union, Iterable, Any, List
21+
from typing import Union, Iterable, Any, List, Sequence
2222

2323
from cirq.ops import raw_types
2424

@@ -36,20 +36,39 @@ class SupportsOnEachGate(raw_types.Gate, metaclass=abc.ABCMeta):
3636

3737
def on_each(self, *targets: Union[raw_types.Qid, Iterable[Any]]) -> List[raw_types.Operation]:
3838
"""Returns a list of operations applying the gate to all targets.
39-
4039
Args:
41-
*targets: The qubits to apply this gate to.
42-
40+
*targets: The qubits to apply this gate to. For single-qubit gates
41+
this can be provided as varargs or a combination of nested
42+
iterables. For multi-qubit gates this must be provided as an
43+
`Iterable[Sequence[Qid]]`, where each sequence has `num_qubits`
44+
qubits.
4345
Returns:
4446
Operations applying this gate to the target qubits.
45-
4647
Raises:
47-
ValueError if targets are not instances of Qid or List[Qid].
48-
ValueError if the gate operates on two or more Qids.
48+
ValueError if targets are not instances of Qid or Iterable[Qid].
49+
ValueError if the gate qubit number is incompatible.
4950
"""
51+
operations: List[raw_types.Operation] = []
5052
if self._num_qubits_() > 1:
51-
raise ValueError('This gate only supports on_each when it is a one qubit gate.')
52-
operations = [] # type: List[raw_types.Operation]
53+
iterator: Iterable = targets
54+
if len(targets) == 1:
55+
if not isinstance(targets[0], Iterable):
56+
raise TypeError(f'{targets[0]} object is not iterable.')
57+
t0 = list(targets[0])
58+
iterator = [t0] if t0 and isinstance(t0[0], raw_types.Qid) else t0
59+
for target in iterator:
60+
if not isinstance(target, Sequence):
61+
raise ValueError(
62+
f'Inputs to multi-qubit gates must be Sequence[Qid].'
63+
f' Type: {type(target)}'
64+
)
65+
if not all(isinstance(x, raw_types.Qid) for x in target):
66+
raise ValueError(f'All values in sequence should be Qids, but got {target}')
67+
if len(target) != self._num_qubits_():
68+
raise ValueError(f'Expected {self._num_qubits_()} qubits, got {target}')
69+
operations.append(self.on(*target))
70+
return operations
71+
5372
for target in targets:
5473
if isinstance(target, raw_types.Qid):
5574
operations.append(self.on(target))

cirq-core/cirq/ops/identity_test.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,28 @@ def test_identity_on_each_only_single_qubit():
6666
cirq.IdentityGate(1, (3,)).on(q0_3),
6767
cirq.IdentityGate(1, (3,)).on(q1_3),
6868
]
69-
with pytest.raises(ValueError, match='one qubit'):
70-
cirq.IdentityGate(num_qubits=2).on_each(q0, q1)
69+
70+
71+
def test_identity_on_each_two_qubits():
72+
q0, q1, q2, q3 = cirq.LineQubit.range(4)
73+
q0_3, q1_3 = q0.with_dimension(3), q1.with_dimension(3)
74+
assert cirq.IdentityGate(2).on_each([(q0, q1)]) == [cirq.IdentityGate(2)(q0, q1)]
75+
assert cirq.IdentityGate(2).on_each([(q0, q1), (q2, q3)]) == [
76+
cirq.IdentityGate(2)(q0, q1),
77+
cirq.IdentityGate(2)(q2, q3),
78+
]
79+
assert cirq.IdentityGate(2, (3, 3)).on_each([(q0_3, q1_3)]) == [
80+
cirq.IdentityGate(2, (3, 3))(q0_3, q1_3),
81+
]
82+
assert cirq.IdentityGate(2).on_each((q0, q1)) == [cirq.IdentityGate(2)(q0, q1)]
83+
with pytest.raises(ValueError, match='Inputs to multi-qubit gates must be Sequence'):
84+
cirq.IdentityGate(2).on_each(q0, q1)
85+
with pytest.raises(ValueError, match='All values in sequence should be Qids'):
86+
cirq.IdentityGate(2).on_each([[(q0, q1)]])
87+
with pytest.raises(ValueError, match='Expected 2 qubits'):
88+
cirq.IdentityGate(2).on_each([(q0,)])
89+
with pytest.raises(ValueError, match='Expected 2 qubits'):
90+
cirq.IdentityGate(2).on_each([(q0, q1, q2)])
7191

7292

7393
@pytest.mark.parametrize('num_qubits', [1, 2, 4])

cirq-core/cirq/ops/raw_types.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -425,7 +425,7 @@ def untagged(self) -> 'cirq.Operation':
425425
"""Returns the underlying operation without any tags."""
426426
return self
427427

428-
def with_tags(self, *new_tags: Hashable) -> 'cirq.TaggedOperation':
428+
def with_tags(self, *new_tags: Hashable) -> 'cirq.Operation':
429429
"""Creates a new TaggedOperation, with this op and the specified tags.
430430
431431
This method can be used to attach meta-data to specific operations
@@ -443,6 +443,8 @@ def with_tags(self, *new_tags: Hashable) -> 'cirq.TaggedOperation':
443443
Args:
444444
new_tags: The tags to wrap this operation in.
445445
"""
446+
if not new_tags:
447+
return self
446448
return TaggedOperation(self, *new_tags)
447449

448450
def transform_qubits(
@@ -608,6 +610,8 @@ def with_tags(self, *new_tags: Hashable) -> 'cirq.TaggedOperation':
608610
that has the tags of this operation combined with the new_tags
609611
specified as the parameter.
610612
"""
613+
if not new_tags:
614+
return self
611615
return TaggedOperation(self.sub_operation, *self._tags, *new_tags)
612616

613617
def __str__(self) -> str:

0 commit comments

Comments
 (0)