Skip to content

Commit b8776a5

Browse files
anurudhpmpharrigan
andauthored
Add ctrl system for LinearCombination (#1595)
* add failing tests * add ctrl flag and system for `LinearCombination` --------- Co-authored-by: Matthew Harrigan <mpharrigan@google.com>
1 parent 4510995 commit b8776a5

5 files changed

Lines changed: 66 additions & 18 deletions

File tree

qualtran/bloqs/block_encoding/chebyshev_polynomial.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,15 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
14+
from collections import Counter
1515
from functools import cached_property
1616
from typing import Dict, Tuple, TYPE_CHECKING, Union
1717

1818
import attrs
1919
import numpy as np
2020

2121
from qualtran import (
22+
Bloq,
2223
bloq_example,
2324
BloqBuilder,
2425
BloqDocSpec,
@@ -36,11 +37,7 @@
3637
from qualtran.symbolics import is_symbolic, SymbolicFloat, SymbolicInt
3738

3839
if TYPE_CHECKING:
39-
from qualtran.resource_counting import (
40-
BloqCountDictT,
41-
MutableBloqCountDictT,
42-
SympySymbolAllocator,
43-
)
40+
from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator
4441

4542

4643
@attrs.frozen
@@ -144,12 +141,11 @@ def build_composite_bloq(self, bb: BloqBuilder, **soqs: SoquetT) -> Dict[str, So
144141

145142
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
146143
n = self.order
147-
s: 'MutableBloqCountDictT' = {
148-
self.block_encoding: n // 2 + n % 2,
149-
self.block_encoding.adjoint(): n // 2,
150-
}
144+
s = Counter[Bloq]()
145+
s[self.block_encoding] += n // 2 + n % 2
146+
s[self.block_encoding.adjoint()] += n // 2
151147
if is_symbolic(self.ancilla_bitsize) or self.ancilla_bitsize > 0:
152-
s[self.reflection_bloq] = n - n % 2
148+
s[self.reflection_bloq] += n - n % 2
153149
return s
154150

155151

qualtran/bloqs/block_encoding/chebyshev_polynomial_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from qualtran.bloqs.reflections.prepare_identity import PrepareIdentity
3535
from qualtran.bloqs.state_preparation.black_box_prepare import BlackBoxPrepare
3636
from qualtran.linalg.matrix import random_hermitian_matrix
37+
from qualtran.resource_counting import get_cost_value, QECGatesCost
3738
from qualtran.symbolics import is_symbolic, SymbolicFloat, SymbolicInt
3839
from qualtran.testing import assert_equivalent_bloq_example_counts, execute_notebook
3940

@@ -156,6 +157,11 @@ def test_scaled_chebyshev_even_tensors():
156157
np.testing.assert_allclose(from_gate, from_tensors, atol=0.06)
157158

158159

160+
def test_scaled_chebyshev_even_cost():
161+
bloq = _scaled_chebyshev_poly_even()
162+
_ = get_cost_value(bloq, QECGatesCost())
163+
164+
159165
@pytest.mark.slow
160166
def test_scaled_chebyshev_odd_tensors():
161167
from_gate = t5(Hadamard().tensor_contract() * 3.14)

qualtran/bloqs/block_encoding/linear_combination.ipynb

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,8 @@
5656
" - `lambd`: Corresponding coefficients.\n",
5757
" - `lambd_bits`: Number of bits needed to represent coefficients precisely.\n",
5858
" - `prepare`: If specified, oracle preparing $\\sum_i \\sqrt{|\\lambda_i|} |i\\rangle$ (state should be normalized and can have junk).\n",
59-
" - `select`: If specified, oracle taking $|i\\rangle|\\psi\\rangle \\mapsto \\text{sgn}(\\lambda_i) |i\\rangle U_i|\\psi\\rangle$. \n",
59+
" - `select`: If specified, oracle taking $|i\\rangle|\\psi\\rangle \\mapsto \\text{sgn}(\\lambda_i) |i\\rangle U_i|\\psi\\rangle$.\n",
60+
" - `is_controlled`: if True, implements a controlled version. Defaults to False. \n",
6061
"\n",
6162
"#### Registers\n",
6263
" - `system`: The system register.\n",

qualtran/bloqs/block_encoding/linear_combination.py

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,17 +20,20 @@
2020
from typing_extensions import Self
2121

2222
from qualtran import (
23+
AddControlledT,
24+
Bloq,
2325
bloq_example,
2426
BloqBuilder,
2527
BloqDocSpec,
2628
BQUInt,
29+
CtrlSpec,
30+
DecomposeTypeError,
2731
QAny,
2832
Register,
2933
Signature,
3034
Soquet,
3135
SoquetT,
3236
)
33-
from qualtran._infra.bloq import DecomposeTypeError
3437
from qualtran.bloqs.block_encoding import BlockEncoding
3538
from qualtran.bloqs.block_encoding.lcu_block_encoding import BlackBoxPrepare, BlackBoxSelect
3639
from qualtran.bloqs.block_encoding.phase import Phase
@@ -68,6 +71,7 @@ class LinearCombination(BlockEncoding):
6871
(state should be normalized and can have junk).
6972
select: If specified, oracle taking
7073
$|i\rangle|\psi\rangle \mapsto \text{sgn}(\lambda_i) |i\rangle U_i|\psi\rangle$.
74+
is_controlled: if True, implements a controlled version. Defaults to False.
7175
7276
Registers:
7377
system: The system register.
@@ -88,6 +92,8 @@ class LinearCombination(BlockEncoding):
8892
_prepare: Optional[BlackBoxPrepare] = None
8993
_select: Optional[BlackBoxSelect] = None
9094

95+
is_controlled: bool = False
96+
9197
def __attrs_post_init__(self):
9298
if len(self._block_encodings) != len(self._lambd):
9399
raise ValueError("Must provide the same number of block encodings and coefficients.")
@@ -138,6 +144,7 @@ def rescaled_lambd(self):
138144
@cached_property
139145
def signature(self) -> Signature:
140146
return Signature.build_from_dtypes(
147+
ctrl=QAny(1 if self.is_controlled else 0),
141148
system=QAny(self.system_bitsize),
142149
ancilla=QAny(self.ancilla_bitsize),
143150
resource=QAny(self.resource_bitsize),
@@ -298,11 +305,26 @@ def build_composite_bloq(
298305
be_part = Partition(self.select.system_bitsize, tuple(be_regs))
299306

300307
prepare_soqs = bb.add_d(self.prepare, **prepare_in_soqs)
301-
select_out_soqs = bb.add_d(
302-
self.select,
303-
selection=prepare_soqs.pop("selection"),
304-
system=cast(Soquet, bb.add(evolve(be_part, partition=False), **be_system_soqs)),
305-
)
308+
309+
if not self.is_controlled:
310+
select_out_soqs = bb.add_d(
311+
self.select,
312+
selection=prepare_soqs.pop("selection"),
313+
system=cast(Soquet, bb.add(evolve(be_part, partition=False), **be_system_soqs)),
314+
)
315+
else:
316+
_, add_ctrl_select = self.select.get_ctrl_system(CtrlSpec())
317+
(ctrl,), select_out_soqs_t = add_ctrl_select(
318+
bb,
319+
[soqs.pop('ctrl')],
320+
dict(
321+
selection=prepare_soqs.pop("selection"),
322+
system=cast(Soquet, bb.add(evolve(be_part, partition=False), **be_system_soqs)),
323+
),
324+
)
325+
select_out_soqs = dict(zip(["selection", "system"], select_out_soqs_t))
326+
select_out_soqs["ctrl"] = ctrl
327+
306328
prep_adj_soqs = bb.add_d(
307329
self.prepare.adjoint(), selection=select_out_soqs.pop("selection"), **prepare_soqs
308330
)
@@ -311,6 +333,9 @@ def build_composite_bloq(
311333
be_soqs = bb.add_d(be_part, x=select_out_soqs.pop("system"))
312334
out: Dict[str, SoquetT] = {"system": be_soqs.pop("system")}
313335

336+
if self.is_controlled:
337+
out["ctrl"] = select_out_soqs.pop("ctrl")
338+
314339
# merge ancilla registers of block encoding and Prepare oracle
315340
anc_soqs = {"selection": prep_adj_soqs.pop("selection")}
316341
if self.be_ancilla_bitsize > 0:
@@ -331,6 +356,20 @@ def build_composite_bloq(
331356
def __str__(self) -> str:
332357
return f"B[{'+'.join(str(be)[2:-1] for be in self.signed_block_encodings)}]"
333358

359+
def get_ctrl_system(self, ctrl_spec: 'CtrlSpec') -> tuple['Bloq', 'AddControlledT']:
360+
from qualtran.bloqs.mcmt.specialized_ctrl import get_ctrl_system_1bit_cv_from_bloqs
361+
362+
return get_ctrl_system_1bit_cv_from_bloqs(
363+
self,
364+
ctrl_spec,
365+
current_ctrl_bit=1 if self.is_controlled else None,
366+
bloq_with_ctrl=evolve(self, is_controlled=True),
367+
ctrl_reg_name='ctrl',
368+
)
369+
370+
def adjoint(self) -> 'LinearCombination':
371+
return self
372+
334373

335374
@bloq_example
336375
def _linear_combination_block_encoding() -> LinearCombination:

qualtran/bloqs/block_encoding/linear_combination_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from qualtran.bloqs.block_encoding.unitary import Unitary
4040
from qualtran.bloqs.for_testing.matrix_gate import MatrixGate
4141
from qualtran.bloqs.reflections.prepare_identity import PrepareIdentity
42+
from qualtran.resource_counting import get_cost_value, QECGatesCost
4243
from qualtran.testing import execute_notebook
4344

4445

@@ -103,6 +104,11 @@ def test_linear_combination_tensors():
103104
np.testing.assert_allclose(from_gate, from_tensors)
104105

105106

107+
def test_linear_combination_cost():
108+
bloq = _linear_combination_block_encoding()
109+
_ = get_cost_value(bloq.controlled(), QECGatesCost())
110+
111+
106112
def run_gate_test(gates, lambd, lambd_bits=1, atol=1e-07):
107113
bloq = LinearCombination(tuple(Unitary(g) for g in gates), lambd, lambd_bits)
108114
from_gate = sum(l * g.tensor_contract() for l, g in zip(lambd, gates))

0 commit comments

Comments
 (0)