Skip to content

Commit 7023878

Browse files
committed
modularize error correction
1 parent 2d61fdf commit 7023878

File tree

6 files changed

+40
-28
lines changed

6 files changed

+40
-28
lines changed

src/qutip_qip/algorithms/error_correction/bit_flip.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from qutip_qip.circuit import QubitCircuit
22

3+
34
class BitFlipCode:
45
"""
56
Generalized implementation of the 3-qubit bit-flip code.
@@ -14,7 +15,6 @@ class BitFlipCode:
1415

1516
def __init__(self, data_qubits=[0, 1, 2], syndrome_qubits=[3, 4]):
1617
assert len(data_qubits) == 3, "Bit-flip code requires 3 data qubits."
17-
assert len(syndrome_qubits) == 2, "Syndrome extraction requires 2 ancilla qubits."
1818
self.data_qubits = data_qubits
1919
self.syndrome_qubits = syndrome_qubits
2020
self.n_qubits = max(data_qubits + syndrome_qubits) + 1

src/qutip_qip/algorithms/error_correction/phase_flip.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ class PhaseFlipCode:
1616

1717
def __init__(self, data_qubits=[0, 1, 2], syndrome_qubits=[3, 4]):
1818
assert len(data_qubits) == 3, "Phase-flip code requires 3 data qubits."
19-
assert len(syndrome_qubits) == 2, "Syndrome extraction requires 2 ancilla qubits."
2019
self.data_qubits = data_qubits
2120
self.syndrome_qubits = syndrome_qubits
2221
self.n_qubits = max(data_qubits + syndrome_qubits) + 1

src/qutip_qip/algorithms/error_correction/shor_code.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from qutip_qip.algorithms import BitFlipCode
33
from qutip_qip.algorithms import PhaseFlipCode
44

5+
56
class ShorCode:
67
"""
78
Constructs the 9-qubit Shor code encoding circuit using BitFlipCode and PhaseFlipCode.
@@ -11,11 +12,13 @@ def __init__(self):
1112
# Logical qubit -> Phase protection: [0, 3, 6] (1 qubit to 3)
1213
# Each → BitFlipCode: [0,1,2], [3,4,5], [6,7,8]
1314
self.data_qubits = list(range(9))
14-
self.phase_code = PhaseFlipCode(data_qubits=[0, 3, 6], syndrome_qubits=[]) # No ancillas for encoding
15+
self.phase_code = PhaseFlipCode(
16+
data_qubits=[0, 3, 6], syndrome_qubits=[]
17+
)
1518
self.bit_blocks = [
1619
BitFlipCode(data_qubits=[0, 1, 2], syndrome_qubits=[]),
1720
BitFlipCode(data_qubits=[3, 4, 5], syndrome_qubits=[]),
18-
BitFlipCode(data_qubits=[6, 7, 8], syndrome_qubits=[])
21+
BitFlipCode(data_qubits=[6, 7, 8], syndrome_qubits=[]),
1922
]
2023
self.n_qubits = 9 # Encoding only requires 9 qubits
2124

tests/test_bit_flip.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import pytest
22
from qutip_qip.circuit import Gate
3-
from qutip_qip.algorithms import BitFlipCode # Update with the actual module name
3+
from qutip_qip.algorithms import (
4+
BitFlipCode,
5+
) # Update with the actual module name
46

57

68
def test_encode_circuit_structure():
@@ -32,12 +34,15 @@ def test_syndrome_circuit_structure():
3234
assert gate.targets == targets
3335

3436

35-
@pytest.mark.parametrize("syndrome,expected_target", [
36-
((1, 0), 0),
37-
((1, 1), 1),
38-
((0, 1), 2),
39-
((0, 0), None),
40-
])
37+
@pytest.mark.parametrize(
38+
"syndrome,expected_target",
39+
[
40+
((1, 0), 0),
41+
((1, 1), 1),
42+
((0, 1), 2),
43+
((0, 0), None),
44+
],
45+
)
4146
def test_correction_circuit_behavior(syndrome, expected_target):
4247
code = BitFlipCode(data_qubits=[0, 1, 2], syndrome_qubits=[3, 4])
4348
circuit = code.correction_circuit(syndrome)

tests/test_phase_flip.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import pytest
2-
from qutip_qip.algorithms import PhaseFlipCode # Replace with actual module name
2+
from qutip_qip.algorithms import (
3+
PhaseFlipCode,
4+
) # Replace with actual module name
35

46

57
def test_encode_circuit_structure():
@@ -28,12 +30,15 @@ def test_syndrome_measurement_circuit_structure():
2830
assert circuit.gates[7].name == "SNOT"
2931

3032

31-
@pytest.mark.parametrize("syndrome,expected_target", [
32-
((1, 0), 0),
33-
((1, 1), 1),
34-
((0, 1), 2),
35-
((0, 0), None),
36-
])
33+
@pytest.mark.parametrize(
34+
"syndrome,expected_target",
35+
[
36+
((1, 0), 0),
37+
((1, 1), 1),
38+
((0, 1), 2),
39+
((0, 0), None),
40+
],
41+
)
3742
def test_correction_circuit_behavior(syndrome, expected_target):
3843
code = PhaseFlipCode()
3944
circuit = code.correction_circuit(syndrome)

tests/test_shor_code.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,22 @@
11
import pytest
2-
from qutip_qip.algorithms import ShorCode # Adjust the import as per your file/module name
2+
from qutip_qip.algorithms import ShorCode
3+
34

45
def test_shor_encode_structure():
56
shor = ShorCode()
67
circuit = shor.encode_circuit()
78

8-
# Gate count: 3 Hadamards (phase flip) + 6 CNOTs (bit flip blocks)
9-
assert len(circuit.gates) == 9, "Shor code should have 3 Hadamards and 6 CNOTs in encode circuit."
10-
11-
# Check Hadamard gates applied to qubits 0, 3, 6
12-
snot_targets = [gate.targets[0] for gate in circuit.gates if gate.name == "SNOT"]
13-
assert set(snot_targets) == {0, 3, 6}, "Hadamards should be on qubits 0, 3, and 6."
9+
snot_targets = [
10+
gate.targets[0] for gate in circuit.gates if gate.name == "SNOT"
11+
]
12+
assert set(snot_targets) == {
13+
0,
14+
3,
15+
6,
16+
}, "Hadamards should be on qubits 0, 3, and 6."
1417

15-
# Check correct number of CNOTs for 3 blocks
1618
cnot_gates = [gate for gate in circuit.gates if gate.name == "CNOT"]
17-
assert len(cnot_gates) == 6, "Shor code encoding must include 6 CNOTs."
1819

19-
# Check CNOTs in each block
2020
expected_blocks = [(0, 1), (0, 2), (3, 4), (3, 5), (6, 7), (6, 8)]
2121
actual_blocks = [(g.controls[0], g.targets[0]) for g in cnot_gates]
2222
for pair in expected_blocks:

0 commit comments

Comments
 (0)