Skip to content

Commit 9b1bacf

Browse files
authored
Fix qubit counting counting classical bits (#1816)
Fixes #1811 - adds new helper methods that are less misleading - uses them in the qubit counting code @NoureldinYosri
1 parent 65a1e32 commit 9b1bacf

4 files changed

Lines changed: 70 additions & 8 deletions

File tree

qualtran/_infra/quantum_graph.py

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -139,14 +139,47 @@ class Connection:
139139
left: Soquet
140140
right: Soquet
141141

142+
@cached_property
143+
def num_qubits(self) -> int:
144+
"""The number of qubits in the connection.
145+
146+
This excludes classical bits.
147+
"""
148+
lq = self.left.reg.dtype.num_qubits
149+
rq = self.right.reg.dtype.num_qubits
150+
151+
if lq != rq:
152+
raise ValueError(f"Invalid Connection {self}: num_qubits mismatch: {lq} != {rq}")
153+
return lq
154+
155+
@cached_property
156+
def num_cbits(self) -> int:
157+
"""The number of classical bits in the connection."""
158+
lc = self.left.reg.dtype.num_cbits
159+
rc = self.right.reg.dtype.num_cbits
160+
161+
if lc != rc:
162+
raise ValueError(f"Invalid Connection {self}: num_cbits mismatch: {lc} != {rc}")
163+
return lc
164+
165+
@cached_property
166+
def num_bits(self) -> int:
167+
"""The number of bits in the connection (quantum + classical)."""
168+
lb = self.left.reg.dtype.num_bits
169+
rb = self.right.reg.dtype.num_bits
170+
171+
if lb != rb:
172+
raise ValueError(f"Invalid Connection {self}: shape mismatch: {lb} != {rb}")
173+
return lb
174+
142175
@cached_property
143176
def shape(self) -> int:
144-
ls = self.left.reg.bitsize
145-
rs = self.right.reg.bitsize
177+
"""The number of bits in the connection (quantum + classical).
146178
147-
if ls != rs:
148-
raise ValueError(f"Invalid Connection {self}: shape mismatch: {ls} != {rs}")
149-
return ls
179+
This is a misleading name for this property kept for backwards compatibility.
180+
Please prefer `.num_bits`.
181+
"""
182+
return self.num_bits
150183

151184
def __str__(self) -> str:
152185
return f'{self.left} -> {self.right}'

qualtran/drawing/graphviz.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ def add_binst(self, graph: pydot.Graph, binst: BloqInstance) -> pydot.Graph:
311311

312312
def cxn_label(self, cxn: Connection) -> str:
313313
"""Overridable method to return labels for connections."""
314-
return str(cxn.shape)
314+
return str(cxn.num_bits)
315315

316316
def cxn_edge(self, left_id: str, right_id: str, cxn: Connection) -> pydot.Edge:
317317
"""Overridable method to style a pydot.Edge for connecionts."""

qualtran/resource_counting/_qubit_counts.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,12 +65,12 @@ def _cbloq_max_width(
6565
# During the application of the binst, we have "observer" connections that have
6666
# width as well as the width from the binst itself. We consider the case where
6767
# the bloq may have a max_width greater than the max of its left/right registers.
68-
during_size = _bloq_max_width(binst.bloq) + sum(s.shape for s in in_play)
68+
during_size = _bloq_max_width(binst.bloq) + sum(s.num_qubits for s in in_play)
6969
max_width = smax(max_width, during_size)
7070

7171
# After the binst, its successor connections are 'in play'.
7272
in_play.update(succ_cxns)
73-
after_size = sum(s.shape for s in in_play)
73+
after_size = sum(s.num_qubits for s in in_play)
7474
max_width = smax(max_width, after_size)
7575

7676
return max_width

qualtran/resource_counting/_qubit_counts_test.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,35 @@ def test_many_alloc():
9595
assert get_cost_value(bloq, QubitCount()) == 11
9696

9797

98+
def test_qubit_count_ignores_cbits():
99+
import functools
100+
101+
import attr
102+
103+
from qualtran import Bloq, BloqBuilder, CBit, QBit, Register, Signature
104+
105+
@attr.frozen
106+
class MyBloq(Bloq):
107+
n: int
108+
m: int
109+
110+
@functools.cached_property
111+
def signature(self):
112+
return Signature(
113+
[Register('qs', QBit(), shape=(self.n,)), Register('cs', CBit(), shape=(self.m,))]
114+
)
115+
116+
def build_call_graph(self, ssa):
117+
return {}
118+
119+
def build_composite_bloq(self, bb: BloqBuilder, qs, cs):
120+
return {'qs': qs, 'cs': cs}
121+
122+
blq = MyBloq(5, 100)
123+
assert get_cost_value(blq, QubitCount()) == 5
124+
assert get_cost_value(blq.decompose_bloq(), QubitCount()) == 5
125+
126+
98127
@pytest.mark.notebook
99128
def test_notebook():
100129
qlt_testing.execute_notebook("qubit_counts")

0 commit comments

Comments
 (0)