Skip to content

Commit 4ca30bd

Browse files
Fix numpy scalar extraction (#372)
* extract numpy scalars properly * change to one-liner * Update src/squlearn/optimizers/approximated_gradients.py Co-authored-by: Florian Wieland <114916947+ProfessorNova@users.noreply.github.com> * Update src/squlearn/optimizers/approximated_gradients.py Co-authored-by: Florian Wieland <114916947+ProfessorNova@users.noreply.github.com> * Update src/squlearn/optimizers/approximated_gradients.py Co-authored-by: Florian Wieland <114916947+ProfessorNova@users.noreply.github.com> --------- Co-authored-by: Florian Wieland <114916947+ProfessorNova@users.noreply.github.com>
1 parent c6899ad commit 4ca30bd

2 files changed

Lines changed: 13 additions & 14 deletions

File tree

src/squlearn/kernel/lowlevel_kernel/projected_quantum_kernel.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -961,10 +961,8 @@ def __call__(
961961
param = parameters[: qnn.num_parameters]
962962
param_op = parameters[qnn.num_parameters :]
963963

964-
if len(param.shape) == 1 and len(param) == 1:
965-
param = float(param)
966-
if len(param_op.shape) == 1 and len(param_op) == 1:
967-
param_op = float(param_op)
964+
param = param.item() if param.size == 1 else param
965+
param_op = param_op.item() if param_op.size == 1 else param_op
968966

969967
x_result = qnn.evaluate(x, param, param_op, "f")["f"]
970968
if y is not None:

src/squlearn/optimizers/approximated_gradients.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -66,33 +66,34 @@ def gradient(self, x: np.ndarray) -> np.ndarray:
6666
g = np.zeros(len(x))
6767
for i in range(len(x)):
6868
dx = np.eye(1, len(x), k=i)[0] * self.eps
69-
g[i] = ((self.fun(x + dx) - f0)) / self.eps # Extract scalar
69+
g[i] = ((self.fun(x + dx) - f0) / self.eps).item()
7070

7171
elif self.formula == "backwards":
7272
f0 = self.fun(x)
7373
g = np.zeros(len(x))
7474
for i in range(len(x)):
7575
dx = np.eye(1, len(x), k=i)[0] * self.eps
76-
g[i] = ((f0 - self.fun(x - dx))) / self.eps # Extract scalar
76+
g[i] = ((f0 - self.fun(x - dx)) / self.eps).item()
7777

7878
elif self.formula == "central":
7979
g = np.zeros(len(x))
8080
for i in range(len(x)):
8181
dx = np.eye(1, len(x), k=i)[0] * self.eps
82-
g[i] = ((self.fun(x + dx) - self.fun(x - dx))) / (2.0 * self.eps) # Extract scalar
82+
g[i] = ((self.fun(x + dx) - self.fun(x - dx)) / (2.0 * self.eps)).item()
8383

8484
elif self.formula == "five-point":
8585
g = np.zeros(len(x))
8686
for i in range(len(x)):
8787
dx = np.eye(1, len(x), k=i)[0] * self.eps
8888
g[i] = (
89-
-1.0 * self.fun(x + 2.0 * dx)
90-
+ 8.0 * self.fun(x + 1.0 * dx)
91-
- 8.0 * self.fun(x - 1.0 * dx)
92-
+ 1.0 * self.fun(x - 2.0 * dx)
93-
) / (
94-
12.0 * self.eps
95-
) # Extract scalar
89+
(
90+
-1.0 * self.fun(x + 2.0 * dx)
91+
+ 8.0 * self.fun(x + 1.0 * dx)
92+
- 8.0 * self.fun(x - 1.0 * dx)
93+
+ 1.0 * self.fun(x - 2.0 * dx)
94+
)
95+
/ (12.0 * self.eps)
96+
).item()
9697
else:
9798
raise ValueError("Wrong value of type: " + self.formula)
9899

0 commit comments

Comments
 (0)