Skip to content

Commit 81df6b3

Browse files
Finish support for LineQubit everywhere. (#641)
* format. * Jae feedback.
1 parent 2773a6d commit 81df6b3

14 files changed

+70
-16
lines changed

tensorflow_quantum/core/ops/circuit_execution_ops_test.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,8 @@ def test_supported_gates_consistent(self, op_and_sim):
175175
"""Ensure that supported gates are consistent across backends."""
176176
op = op_and_sim[0]
177177
sim = op_and_sim[1]
178-
qubits = cirq.GridQubit.rect(1, 5)
178+
# mix qubit types.
179+
qubits = cirq.GridQubit.rect(1, 4) + [cirq.LineQubit(10)]
179180
circuit_batch = []
180181

181182
gate_ref = util.get_supported_gates()
@@ -345,7 +346,7 @@ def test_analytical_expectation(self, op_and_sim, n_qubits, symbol_names,
345346
op = op_and_sim[0]
346347
sim = op_and_sim[1]
347348

348-
qubits = cirq.GridQubit.rect(1, n_qubits)
349+
qubits = cirq.LineQubit.range(n_qubits - 1) + [cirq.GridQubit(0, 0)]
349350
circuit_batch, resolver_batch = \
350351
util.random_symbol_circuit_resolver_batch(
351352
qubits, symbol_names, BATCH_SIZE)

tensorflow_quantum/core/ops/math_ops/fidelity_op_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def test_correctness_with_symbols(self, n_qubits, batch_size,
113113
def test_correctness_without_symbols(self, n_qubits, batch_size,
114114
inner_dim_size):
115115
"""Tests that inner_product works without symbols."""
116-
qubits = cirq.GridQubit.rect(1, n_qubits)
116+
qubits = cirq.LineQubit.range(n_qubits)
117117
circuit_batch, _ = \
118118
util.random_circuit_resolver_batch(
119119
qubits, batch_size)

tensorflow_quantum/core/ops/math_ops/inner_product_grad_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ def test_correctness_with_symbols(self, n_qubits, batch_size,
235235
"""Tests that inner_product works with symbols."""
236236
symbol_names = ['alpha', 'beta', 'gamma']
237237
n_params = len(symbol_names)
238-
qubits = cirq.GridQubit.rect(1, n_qubits)
238+
qubits = cirq.LineQubit.range(n_qubits)
239239
circuit_batch, resolver_batch = \
240240
util.random_symbol_circuit_resolver_batch(
241241
qubits, symbol_names, batch_size)

tensorflow_quantum/core/ops/math_ops/inner_product_op_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ def test_correctness_with_symbols(self, n_qubits, batch_size,
232232
inner_dim_size):
233233
"""Tests that inner_product works with symbols."""
234234
symbol_names = ['alpha', 'beta', 'gamma']
235-
qubits = cirq.GridQubit.rect(1, n_qubits)
235+
qubits = cirq.LineQubit.range(n_qubits)
236236
circuit_batch, resolver_batch = \
237237
util.random_symbol_circuit_resolver_batch(
238238
qubits, symbol_names, batch_size)

tensorflow_quantum/core/ops/noise/noisy_expectation_op_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ def test_single_channel(self, channel):
273273
symbol_names = []
274274
batch_size = 5
275275
n_qubits = 6
276-
qubits = cirq.GridQubit.rect(1, n_qubits)
276+
qubits = cirq.LineQubit.range(n_qubits)
277277

278278
circuit_batch, resolver_batch = \
279279
util.random_circuit_resolver_batch(

tensorflow_quantum/core/ops/noise/noisy_sampled_expectation_op_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ def test_single_channel(self, channel):
273273
symbol_names = []
274274
batch_size = 5
275275
n_qubits = 6
276-
qubits = cirq.GridQubit.rect(1, n_qubits)
276+
qubits = cirq.LineQubit.range(n_qubits)
277277

278278
circuit_batch, resolver_batch = \
279279
util.random_circuit_resolver_batch(

tensorflow_quantum/core/ops/noise/noisy_samples_op_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ def test_simulate_samples_inputs(self):
164164
def test_simulate_consistency(self, batch_size, n_qubits, noisy):
165165
"""Test consistency with batch_util.py simulation."""
166166
symbol_names = ['alpha', 'beta']
167-
qubits = cirq.GridQubit.rect(1, n_qubits)
167+
qubits = cirq.LineQubit.range(n_qubits)
168168

169169
circuit_batch, resolver_batch = \
170170
util.random_symbol_circuit_resolver_batch(

tensorflow_quantum/core/ops/tfq_adj_grad_op_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,7 @@ def test_calculate_adj_grad_simple_case_single(self):
359359
n_qubits = 2
360360
batch_size = 1
361361
symbol_names = ['alpha', 'beta', 'gamma']
362-
qubits = cirq.GridQubit.rect(1, n_qubits)
362+
qubits = cirq.LineQubit.range(n_qubits)
363363
circuit_batch, resolver_batch = \
364364
[cirq.Circuit(cirq.X(qubits[0]) ** sympy.Symbol('alpha'),
365365
cirq.Y(qubits[1]) ** sympy.Symbol('alpha'),

tensorflow_quantum/core/ops/tfq_ps_util_ops_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ def test_moment_preservation(self):
259259
"""Test Moment-structure preservation."""
260260
t = sympy.Symbol('t')
261261
r = sympy.Symbol('r')
262-
qubits = cirq.GridQubit.rect(1, 6)
262+
qubits = cirq.LineQubit.range(6)
263263
circuit_batch = [
264264
cirq.Circuit(
265265
cirq.Moment([cirq.H(q) for q in qubits]),
@@ -471,7 +471,7 @@ def test_weight_coefficient(self):
471471

472472
def test_simple_pad(self):
473473
"""Test simple padding."""
474-
bit = cirq.GridQubit(0, 0)
474+
bit = cirq.LineQubit(1)
475475
circuit = cirq.Circuit(
476476
cirq.X(bit)**sympy.Symbol('alpha'),
477477
cirq.Y(bit)**sympy.Symbol('alpha'),
@@ -720,7 +720,7 @@ def test_error(self):
720720

721721
def test_many_values(self):
722722
"""Ensure that padding with few symbols and many values works."""
723-
bit = cirq.GridQubit(0, 0)
723+
bit = cirq.LineQubit(1)
724724
circuits = [
725725
cirq.Circuit(
726726
cirq.X(bit)**(sympy.Symbol('alpha') * 2.0),

tensorflow_quantum/core/ops/tfq_unitary_op_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ def test_calculate_unitary_consistency_symbol_free(self, n_qubits,
170170
unitary_op):
171171
"""Test calculate_unitary works without symbols."""
172172
unitary_op = tfq_unitary_op.get_unitary_op()
173-
qubits = cirq.GridQubit.rect(1, n_qubits)
173+
qubits = cirq.LineQubit.range(n_qubits)
174174
circuit_batch, _ = util.random_circuit_resolver_batch(qubits, 25)
175175

176176
tfq_results = unitary_op(util.convert_to_tensor(circuit_batch), [],

tensorflow_quantum/core/ops/tfq_utility_ops_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def test_append_circuit(self, max_n_bits, symbols, n_circuits):
9090
base_circuits = []
9191
circuits_to_append = []
9292
qubits = cirq.GridQubit.rect(1, max_n_bits)
93-
other_qubits = cirq.GridQubit.rect(2, max_n_bits)
93+
other_qubits = cirq.GridQubit.rect(2, max_n_bits) + [cirq.LineQubit(10)]
9494

9595
base_circuits, _ = util.random_symbol_circuit_resolver_batch(
9696
qubits,

tensorflow_quantum/core/src/program_resolution.cc

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,16 @@ using tfq::proto::PauliTerm;
4040
using tfq::proto::Program;
4141
using tfq::proto::Qubit;
4242

43+
inline absl::string_view IntMaxStr() {
44+
static constexpr char kMaxVal[] = "2147483647";
45+
return kMaxVal;
46+
}
47+
4348
Status RegisterQubits(
4449
const std::string& qb_string,
4550
absl::flat_hash_set<std::pair<std::pair<int, int>, std::string>>* id_set) {
4651
// Inserts qubits found in qb_string into id_set.
52+
// Supported GridQubit wire formats and line qubit wire formats.
4753

4854
if (qb_string.empty()) {
4955
return Status::OK(); // no control-default value specified in serializer.py
@@ -52,7 +58,11 @@ Status RegisterQubits(
5258
const std::vector<absl::string_view> qb_list = absl::StrSplit(qb_string, ',');
5359
for (auto qb : qb_list) {
5460
int r, c;
55-
const std::vector<absl::string_view> splits = absl::StrSplit(qb, '_');
61+
std::vector<absl::string_view> splits = absl::StrSplit(qb, '_');
62+
if (splits.size() == 1) { // Pad the front of linequbit with INTMAX.
63+
splits.insert(splits.begin(), IntMaxStr());
64+
}
65+
5666
if (splits.size() != 2) {
5767
return Status(tensorflow::error::INVALID_ARGUMENT,
5868
absl::StrCat("Unable to parse qubit: ", qb));

tensorflow_quantum/core/src/program_resolution_test.cc

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,29 @@ const std::string valid_program = R"(
5454
}
5555
)";
5656

57+
const std::string valid_line_program = R"(
58+
circuit {
59+
moments {
60+
operations {
61+
args {
62+
key: "control_qubits"
63+
value {
64+
arg_value {
65+
string_value: "0_1"
66+
}
67+
}
68+
}
69+
qubits {
70+
id: "1"
71+
}
72+
qubits {
73+
id: "2"
74+
}
75+
}
76+
}
77+
}
78+
)";
79+
5780
const std::string valid_psum = R"(
5881
terms {
5982
coefficient_real: 1.0
@@ -123,6 +146,26 @@ TEST(ProgramResolutionTest, ResolveQubitIdsValid) {
123146
"0");
124147
}
125148

149+
TEST(ProgramResolutionTest, ResolveQubitIdsValidLine) {
150+
Program program;
151+
unsigned int qubit_count;
152+
ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(valid_line_program,
153+
&program));
154+
155+
EXPECT_EQ(ResolveQubitIds(&program, &qubit_count), Status::OK());
156+
EXPECT_EQ(qubit_count, 3);
157+
EXPECT_EQ(program.circuit().moments(0).operations(0).qubits(0).id(), "1");
158+
EXPECT_EQ(program.circuit().moments(0).operations(0).qubits(1).id(), "2");
159+
EXPECT_EQ(program.circuit()
160+
.moments(0)
161+
.operations(0)
162+
.args()
163+
.at("control_qubits")
164+
.arg_value()
165+
.string_value(),
166+
"0");
167+
}
168+
126169
TEST(ProgramResolutionTest, ResolveQubitIdsInvalidControlQubit) {
127170
Program program;
128171
unsigned int qubit_count;

tensorflow_quantum/python/util_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def _exponential(theta, op):
3939
return np.eye(op_mat.shape[0]) * np.cos(theta) - 1j * op_mat * np.sin(theta)
4040

4141

42-
BITS = list(cirq.GridQubit.rect(1, 10))
42+
BITS = list(cirq.GridQubit.rect(1, 10) + cirq.LineQubit.range(2))
4343

4444

4545
def _items_to_tensorize():

0 commit comments

Comments
 (0)