Skip to content

Commit 0700441

Browse files
use new random_pauli_list (#379)
1 parent 2b148f4 commit 0700441

5 files changed

Lines changed: 21 additions & 9 deletions

File tree

src/squlearn/qrc/base_qrc.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@ def validate_data(self, *args, **kwargs):
1717
return self._validate_data(*args, **kwargs)
1818

1919

20-
from qiskit.quantum_info.random import random_pauli_list
21-
2220
from ..observables.observable_base import ObservableBase
2321
from ..observables import CustomObservable, SinglePauli
2422
from ..encoding_circuit.encoding_circuit_base import EncodingCircuitBase
@@ -27,6 +25,8 @@ def validate_data(self, *args, **kwargs):
2725
from ..qnn.lowlevel_qnn import LowLevelQNN
2826
from ..util.serialization import SerializableModelMixin
2927

28+
from .util.random_pauli_list import random_pauli_list
29+
3030

3131
class BaseQRC(BaseEstimator, SerializableModelMixin, ABC):
3232
"""Base class for Quantum Reservoir Computing (QRC) models.
@@ -65,7 +65,7 @@ def __init__(
6565
ml_model: str = "linear",
6666
ml_model_options: Union[dict, None] = None,
6767
operators: Union[ObservableBase, list[ObservableBase], str] = "random_paulis",
68-
num_operators: int = 100,
68+
num_operators: int = None,
6969
operator_seed: int = 0,
7070
param_ini: Union[np.ndarray, None] = None,
7171
param_op_ini: Union[np.ndarray, None] = None,
@@ -89,6 +89,19 @@ def __init__(
8989
self._is_lowlevel_qnn_initialized = False
9090
self._qnn = None
9191

92+
if self.operators == "random_paulis":
93+
if (
94+
self.num_operators is not None
95+
and self.num_operators > 4**self.encoding_circuit.num_qubits
96+
):
97+
raise ValueError(
98+
f"Number of operators ({self.num_operators}) exceeds the maximum possible "
99+
f"({4**self.encoding_circuit.num_qubits}) for the given number of qubits "
100+
f"({self.encoding_circuit.num_qubits})."
101+
)
102+
elif self.num_operators is None:
103+
self.num_operators = min(100, 4**self.encoding_circuit.num_qubits)
104+
92105
self._ml_model = None
93106
self._initialize_observables()
94107
self._initialize_ml_model()
@@ -152,7 +165,6 @@ def _initialize_observables(self) -> None:
152165
self.encoding_circuit.num_qubits,
153166
self.num_operators,
154167
seed=self.operator_seed,
155-
phase=False,
156168
)
157169
self._operators = [
158170
CustomObservable(self.encoding_circuit.num_qubits, str(p)) for p in paulis

src/squlearn/qrc/qrc_classifier.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def __init__(
9494
ml_model: str = "linear",
9595
ml_model_options: Union[dict, None] = None,
9696
operators: Union[ObservableBase, list[ObservableBase], str] = "random_paulis",
97-
num_operators: int = 100,
97+
num_operators: int = None,
9898
operator_seed: int = 0,
9999
param_ini: Union[np.ndarray, None] = None,
100100
param_op_ini: Union[np.ndarray, None] = None,

src/squlearn/qrc/qrc_regressor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def __init__(
9292
ml_model: str = "linear",
9393
ml_model_options: Union[dict, None] = None,
9494
operators: Union[ObservableBase, list[ObservableBase], str] = "random_paulis",
95-
num_operators: int = 100,
95+
num_operators: int = None,
9696
operator_seed: int = 0,
9797
param_ini: Union[np.ndarray, None] = None,
9898
param_op_ini: Union[np.ndarray, None] = None,

tests/qrc/test_qrc_classifier.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def test_fit_predict(self, data, ml_model):
4949
values = qrc_classifier.predict(X)
5050

5151
referece_values = {
52-
"linear": np.array([0, 1, 0, 1, 1, 1]),
52+
"linear": np.array([0, 1, 0, 1, 0, 1]),
5353
"mlp": np.array([0, 1, 0, 1, 0, 1]),
5454
"kernel": np.array([0, 1, 0, 0, 0, 1]),
5555
}

tests/qrc/test_qrc_regressor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,10 @@ def test_fit_predict(self, data, ml_model):
5353
[2.88509523, -0.80308901, 3.76200899, -1.35995202, 8.84630756, -1.36004739]
5454
),
5555
"mlp": np.array(
56-
[3.06585455, -0.96776248, 4.0395428, -1.52636054, 8.80042369, -1.52645433]
56+
[3.04943393, -0.87369101, 3.98814234, -1.42608296, 8.61361362, -1.42617601]
5757
),
5858
"kernel": np.array(
59-
[2.87653068, -0.65406699, 3.74209315, -1.13894099, 8.15484098, -1.13902237]
59+
[2.54250631, 0.06183984, 3.13005284, -0.29406904, 6.04568976, -0.29412921]
6060
),
6161
}
6262

0 commit comments

Comments
 (0)