Skip to content

Commit 09ba5d3

Browse files
authored
Implement serialization logic for ProjectorSums. (#623)
* Serialize Projector Sums * nits * Use Boolean to encode the basis state * Address review comments
1 parent 81df6b3 commit 09ba5d3

File tree

8 files changed

+230
-0
lines changed

8 files changed

+230
-0
lines changed

tensorflow_quantum/core/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ py_library(
1212
deps = [
1313
"//tensorflow_quantum/core/ops",
1414
"//tensorflow_quantum/core/proto:pauli_sum_py_proto",
15+
"//tensorflow_quantum/core/proto:projector_sum_py_proto",
1516
"//tensorflow_quantum/core/serialize",
1617
],
1718
)

tensorflow_quantum/core/ops/BUILD

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ cc_binary(
8282
":parse_context",
8383
# cirq cc proto
8484
# pauli sum cc proto
85+
# projector sum cc proto
8586
":tfq_simulate_utils",
8687
"//tensorflow_quantum/core/src:adj_util",
8788
"//tensorflow_quantum/core/src:circuit_parser_qsim",
@@ -204,6 +205,7 @@ cc_binary(
204205
# cirq cc proto
205206
"//tensorflow_quantum/core/proto:pauli_sum_cc_proto",
206207
"//tensorflow_quantum/core/proto:program_cc_proto",
208+
"//tensorflow_quantum/core/proto:projector_sum_cc_proto",
207209
"//tensorflow_quantum/core/src:circuit_parser_qsim",
208210
"//tensorflow_quantum/core/src:program_resolution",
209211
"//tensorflow_quantum/core/src:util_qsim",
@@ -286,6 +288,7 @@ cc_library(
286288
":tfq_simulate_utils",
287289
"//tensorflow_quantum/core/proto:pauli_sum_cc_proto",
288290
"//tensorflow_quantum/core/proto:program_cc_proto",
291+
"//tensorflow_quantum/core/proto:projector_sum_cc_proto",
289292
"//tensorflow_quantum/core/src:program_resolution",
290293
"@com_google_absl//absl/container:flat_hash_map",
291294
"@com_google_absl//absl/container:inlined_vector",
@@ -344,6 +347,7 @@ cc_binary(
344347
# cirq cc proto
345348
"//tensorflow_quantum/core/proto:pauli_sum_cc_proto",
346349
"//tensorflow_quantum/core/proto:program_cc_proto",
350+
"//tensorflow_quantum/core/proto:projector_sum_cc_proto",
347351
"//tensorflow_quantum/core/src:circuit_parser_qsim",
348352
"//tensorflow_quantum/core/src:util_qsim",
349353
"@com_google_absl//absl/container:flat_hash_map",
@@ -374,6 +378,7 @@ py_library(
374378
deps = [
375379
":load_module",
376380
# pauli sum cc proto
381+
# projector sum cc proto
377382
# tensorflow framework for wrappers
378383
],
379384
)
@@ -464,6 +469,7 @@ py_library(
464469
":batch_util",
465470
"//tensorflow_quantum/core/proto:program_py_proto",
466471
"//tensorflow_quantum/core/proto:pauli_sum_py_proto",
472+
"//tensorflow_quantum/core/proto:projector_sum_py_proto",
467473
"//tensorflow_quantum/core/serialize:serializer",
468474
],
469475
)

tensorflow_quantum/core/proto/BUILD

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,20 @@ cc_proto_library(
3838
":pauli_sum_proto",
3939
],
4040
)
41+
42+
py_proto_library(
43+
name = "projector_sum_py_proto",
44+
srcs = ["projector_sum.proto"],
45+
)
46+
47+
proto_library(
48+
name = "projector_sum_proto",
49+
srcs = ["projector_sum.proto"],
50+
)
51+
52+
cc_proto_library(
53+
name = "projector_sum_cc_proto",
54+
deps = [
55+
":projector_sum_proto",
56+
],
57+
)
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
syntax = "proto3";
2+
3+
package tfq.proto;
4+
5+
// Store the sum of simpler terms.
6+
message ProjectorSum {
7+
repeated ProjectorTerm terms = 1;
8+
}
9+
10+
// Store a term which is a coefficient and the qubits of the projection.
11+
message ProjectorTerm {
12+
float coefficient_real = 1;
13+
float coefficient_imag = 2;
14+
repeated ProjectorDictEntry projector_dict = 3;
15+
}
16+
17+
// Store a single projection.
18+
message ProjectorDictEntry {
19+
string qubit_id = 1;
20+
// False means |0> and true means |1>.
21+
bool basis_state = 2;
22+
}
23+

tensorflow_quantum/core/serialize/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ py_library(
8282
":serializable_gate_set",
8383
"//tensorflow_quantum/core/proto:pauli_sum_py_proto",
8484
"//tensorflow_quantum/core/proto:program_py_proto",
85+
"//tensorflow_quantum/core/proto:projector_sum_py_proto",
8586
],
8687
)
8788

@@ -93,5 +94,6 @@ py_test(
9394
":serializer",
9495
# cirq proto
9596
"//tensorflow_quantum/core/proto:pauli_sum_py_proto",
97+
"//tensorflow_quantum/core/proto:projector_sum_py_proto",
9698
],
9799
)

tensorflow_quantum/core/serialize/serializer.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
serializable_gate_set
2525
from tensorflow_quantum.core.proto import pauli_sum_pb2
2626
from tensorflow_quantum.core.proto import program_pb2
27+
from tensorflow_quantum.core.proto import projector_sum_pb2
2728

2829
# Needed to allow autograph to crawl AST without erroring.
2930
_CONSTANT_TRUE = lambda x: True
@@ -971,3 +972,68 @@ def _process_pauli_type(char):
971972
if char == 'Y':
972973
return cirq.Y
973974
raise ValueError("Invalid pauli type.")
975+
976+
977+
def serialize_projectorsum(projectorsum):
978+
"""Constructs a projector_sum proto from `cirq.ProjectorSum`.
979+
980+
Args:
981+
projectorsum: A `cirq.ProjectorSum` or `cirq.ProjectorString` object.
982+
983+
Returns:
984+
A projector_sum proto object.
985+
"""
986+
if isinstance(projectorsum, cirq.ProjectorString):
987+
projectorsum = cirq.ProjectorSum.from_pauli_strings(projectorsum)
988+
989+
if not isinstance(projectorsum, cirq.ProjectorSum):
990+
raise TypeError("serialize requires a cirq.ProjectorSum object."
991+
" Given: " + str(type(projectorsum)))
992+
993+
if any(not isinstance(qubit, (cirq.GridQubit, cirq.LineQubit))
994+
for qubit in projectorsum.qubits):
995+
raise ValueError("Attempted to serialize a paulisum that doesn't use "
996+
"only cirq.GridQubit or cirq.LineQubit.")
997+
998+
projectorsum_proto = projector_sum_pb2.ProjectorSum()
999+
for term in projectorsum:
1000+
projectorterm_proto = projector_sum_pb2.ProjectorTerm()
1001+
1002+
projectorterm_proto.coefficient_real = term.coefficient.real
1003+
projectorterm_proto.coefficient_imag = term.coefficient.imag
1004+
for qubit, basis_state in sorted(
1005+
term.projector_dict.items()): # sort to keep qubits ordered
1006+
projectorterm_proto.projector_dict.add(
1007+
qubit_id=op_serializer.qubit_to_proto(qubit),
1008+
basis_state=basis_state)
1009+
1010+
projectorsum_proto.terms.extend([projectorterm_proto])
1011+
1012+
return projectorsum_proto
1013+
1014+
1015+
def deserialize_projectorsum(proto):
1016+
"""Constructs a `cirq.ProjectorSum` from projector_sum proto.
1017+
1018+
Args:
1019+
proto: A projector_sum proto object.
1020+
1021+
Returns:
1022+
A `cirq.ProjectorSum` object.
1023+
"""
1024+
if not isinstance(proto, projector_sum_pb2.ProjectorSum):
1025+
raise TypeError("deserialize requires a projector_sum_pb2 object."
1026+
" Given: " + str(type(proto)))
1027+
1028+
res = cirq.ProjectorSum()
1029+
for term_proto in proto.terms:
1030+
coef = float(_round(term_proto.coefficient_real)) + \
1031+
1.0j * float(_round(term_proto.coefficient_imag))
1032+
projector_dict = {}
1033+
for projector_dict_entry in term_proto.projector_dict:
1034+
qubit = op_deserializer.qubit_from_proto(
1035+
projector_dict_entry.qubit_id)
1036+
projector_dict[qubit] = 1 if projector_dict_entry.basis_state else 0
1037+
res += cirq.ProjectorString(projector_dict, coef)
1038+
1039+
return res

tensorflow_quantum/core/serialize/serializer_test.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from absl.testing import parameterized
2323
from tensorflow_quantum.core.proto import pauli_sum_pb2
2424
from tensorflow_quantum.core.proto import program_pb2
25+
from tensorflow_quantum.core.proto import projector_sum_pb2
2526
from tensorflow_quantum.core.serialize import serializer
2627

2728

@@ -438,6 +439,37 @@ def _get_valid_pauli_proto_pairs(qubit_type='grid'):
438439
return pairs
439440

440441

442+
def _get_valid_projector_proto_pairs(qubit_type='grid'):
443+
"""Generate valid projectorsum proto pairs."""
444+
q0 = cirq.GridQubit(0, 0)
445+
q1 = cirq.GridQubit(1, 0)
446+
q0_str = '0_0'
447+
q1_str = '1_0'
448+
449+
if qubit_type == 'line':
450+
q0 = cirq.LineQubit(0)
451+
q1 = cirq.LineQubit(1)
452+
q0_str = '0'
453+
q1_str = '1'
454+
455+
pairs = [
456+
(cirq.ProjectorSum.from_projector_strings(
457+
cirq.ProjectorString(projector_dict={q0: 0})),
458+
_build_projector_proto([1.0], [[0]], [[q0_str]])),
459+
(cirq.ProjectorSum.from_projector_strings(
460+
cirq.ProjectorString(projector_dict={q0: 0}, coefficient=0.125j)),
461+
_build_projector_proto([0.125j], [[0]], [[q0_str]])),
462+
(cirq.ProjectorSum.from_projector_strings([
463+
cirq.ProjectorString(projector_dict={
464+
q0: 0,
465+
q1: 1
466+
}),
467+
]), _build_projector_proto([1.0], [[0, 1]], [[q0_str, q1_str]])),
468+
]
469+
470+
return pairs
471+
472+
441473
def _get_noise_proto_pairs(qubit_type='grid'):
442474
q0 = cirq.GridQubit(0, 0)
443475
q0_str = '0_0'
@@ -506,6 +538,26 @@ def _build_pauli_proto(coefs, ops, qubit_ids):
506538
return a
507539

508540

541+
def _build_projector_proto(coefs, basis_states, qubit_ids):
542+
"""Construct projector_sum proto explicitly."""
543+
terms = []
544+
for i in range(len(coefs)):
545+
term = projector_sum_pb2.ProjectorTerm()
546+
term.coefficient_real = coefs[i].real
547+
term.coefficient_imag = coefs[i].imag
548+
for j in range(len(qubit_ids[i])):
549+
if basis_states[i][j] == 0:
550+
term.projector_dict.add(qubit_id=qubit_ids[i][j])
551+
else:
552+
term.projector_dict.add(qubit_id=qubit_ids[i][j],
553+
basis_state=True)
554+
terms.append(term)
555+
556+
a = projector_sum_pb2.ProjectorSum()
557+
a.terms.extend(terms)
558+
return a
559+
560+
509561
class SerializerTest(tf.test.TestCase, parameterized.TestCase):
510562
"""Tests basic serializer functionality"""
511563

@@ -715,6 +767,65 @@ def test_serialize_deserialize_paulisum_consistency(self, sum_proto_pair):
715767
serializer.serialize_paulisum(sum_proto_pair[0])),
716768
sum_proto_pair[0])
717769

770+
@parameterized.parameters([{'inp': v} for v in ['wrong', 1.0, None, []]])
771+
def test_serialize_projectorsum_wrong_type(self, inp):
772+
"""Attempt to serialize invalid object types."""
773+
with self.assertRaises(TypeError):
774+
serializer.serialize_projectorsum(inp)
775+
776+
@parameterized.parameters([{'inp': v} for v in ['wrong', 1.0, None, []]])
777+
def test_deserialize_projectorsum_wrong_type(self, inp):
778+
"""Attempt to deserialize invalid object types."""
779+
with self.assertRaises(TypeError):
780+
serializer.deserialize_projectorsum(inp)
781+
782+
def test_serialize_projectorsum_invalid(self):
783+
"""Ensure we don't support anything but GridQubits."""
784+
q0 = cirq.NamedQubit('wont work')
785+
a = cirq.ProjectorSum.from_projector_strings(
786+
cirq.ProjectorString(projector_dict={q0: 0}))
787+
with self.assertRaises(ValueError):
788+
serializer.serialize_projectorsum(a)
789+
790+
@parameterized.parameters(
791+
[{
792+
'sum_proto_pair': v
793+
} for v in _get_valid_projector_proto_pairs(qubit_type='grid') +
794+
_get_valid_projector_proto_pairs(qubit_type='line')])
795+
def test_serialize_projectorsum_simple(self, sum_proto_pair):
796+
"""Ensure serialization is correct."""
797+
self.assertProtoEquals(
798+
sum_proto_pair[1],
799+
serializer.serialize_projectorsum(sum_proto_pair[0]))
800+
801+
@parameterized.parameters(
802+
[{
803+
'sum_proto_pair': v
804+
} for v in _get_valid_projector_proto_pairs(qubit_type='grid') +
805+
_get_valid_projector_proto_pairs(qubit_type='line')])
806+
def test_deserialize_projectorsum_simple(self, sum_proto_pair):
807+
"""Ensure deserialization is correct."""
808+
self.assertEqual(serializer.deserialize_projectorsum(sum_proto_pair[1]),
809+
sum_proto_pair[0])
810+
811+
@parameterized.parameters(
812+
[{
813+
'sum_proto_pair': v
814+
} for v in _get_valid_projector_proto_pairs(qubit_type='grid') +
815+
_get_valid_projector_proto_pairs(qubit_type='line')])
816+
def test_serialize_deserialize_projectorsum_consistency(
817+
self, sum_proto_pair):
818+
"""Serialize and deserialize and ensure nothing changed."""
819+
self.assertEqual(
820+
serializer.serialize_projectorsum(
821+
serializer.deserialize_projectorsum(sum_proto_pair[1])),
822+
sum_proto_pair[1])
823+
824+
self.assertEqual(
825+
serializer.deserialize_projectorsum(
826+
serializer.serialize_projectorsum(sum_proto_pair[0])),
827+
sum_proto_pair[0])
828+
718829
@parameterized.parameters([
719830
{
720831
'gate': cirq.rx(3.0 * sympy.Symbol('alpha'))

tensorflow_quantum/core/src/BUILD

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ cc_library(
5050
deps = [
5151
"//tensorflow_quantum/core/proto:pauli_sum_cc_proto",
5252
"//tensorflow_quantum/core/proto:program_cc_proto",
53+
"//tensorflow_quantum/core/proto:projector_sum_cc_proto",
5354
"@com_google_absl//absl/container:flat_hash_map",
5455
"@local_config_tf//:libtensorflow_framework",
5556
"@local_config_tf//:tf_header_lib",
@@ -73,6 +74,7 @@ cc_test(
7374
":circuit_parser_qsim",
7475
"//tensorflow_quantum/core/proto:pauli_sum_cc_proto",
7576
"//tensorflow_quantum/core/proto:program_cc_proto",
77+
"//tensorflow_quantum/core/proto:projector_sum_cc_proto",
7678
"@com_google_absl//absl/container:flat_hash_map",
7779
"@com_google_absl//absl/strings",
7880
"@com_google_googletest//:gtest_main",
@@ -91,6 +93,7 @@ cc_library(
9193
deps = [
9294
":circuit_parser_qsim",
9395
"//tensorflow_quantum/core/proto:pauli_sum_cc_proto",
96+
"//tensorflow_quantum/core/proto:projector_sum_cc_proto",
9497
"@com_google_absl//absl/container:inlined_vector", # unclear why needed.
9598
"@local_config_tf//:libtensorflow_framework",
9699
"@local_config_tf//:tf_header_lib",
@@ -120,6 +123,7 @@ cc_library(
120123
deps = [
121124
"//tensorflow_quantum/core/proto:pauli_sum_cc_proto",
122125
"//tensorflow_quantum/core/proto:program_cc_proto",
126+
"//tensorflow_quantum/core/proto:projector_sum_cc_proto",
123127
"@com_google_absl//absl/container:flat_hash_map",
124128
"@com_google_absl//absl/container:flat_hash_set",
125129
"@com_google_absl//absl/strings",

0 commit comments

Comments
 (0)