Skip to content

Commit 4806c88

Browse files
committed
Make project_weights public
1 parent 0017a56 commit 4806c88

4 files changed

Lines changed: 9 additions & 9 deletions

File tree

src/torchjd/aggregation/_dual_cone_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from torch import Tensor
77

88

9-
def _project_weights(U: Tensor, G: Tensor, solver: Literal["quadprog"]) -> Tensor:
9+
def project_weights(U: Tensor, G: Tensor, solver: Literal["quadprog"]) -> Tensor:
1010
"""
1111
Computes the tensor of weights corresponding to the projection of the vectors in `U` onto the
1212
rows of a matrix whose Gramian is provided.

src/torchjd/aggregation/dualproj.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from torch import Tensor
44

5-
from ._dual_cone_utils import _project_weights
5+
from ._dual_cone_utils import project_weights
66
from ._gramian_utils import _compute_regularized_normalized_gramian
77
from ._pref_vector_utils import _pref_vector_to_str_suffix, _pref_vector_to_weighting
88
from .bases import _WeightedAggregator, _Weighting
@@ -101,5 +101,5 @@ def __init__(
101101
def forward(self, matrix: Tensor) -> Tensor:
102102
u = self.weighting(matrix)
103103
G = _compute_regularized_normalized_gramian(matrix, self.norm_eps, self.reg_eps)
104-
w = _project_weights(u, G, self.solver)
104+
w = project_weights(u, G, self.solver)
105105
return w

src/torchjd/aggregation/upgrad.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import torch
44
from torch import Tensor
55

6-
from ._dual_cone_utils import _project_weights
6+
from ._dual_cone_utils import project_weights
77
from ._gramian_utils import _compute_regularized_normalized_gramian
88
from ._pref_vector_utils import _pref_vector_to_str_suffix, _pref_vector_to_weighting
99
from .bases import _WeightedAggregator, _Weighting
@@ -97,5 +97,5 @@ def __init__(
9797
def forward(self, matrix: Tensor) -> Tensor:
9898
U = torch.diag(self.weighting(matrix))
9999
G = _compute_regularized_normalized_gramian(matrix, self.norm_eps, self.reg_eps)
100-
W = _project_weights(U, G, self.solver)
100+
W = project_weights(U, G, self.solver)
101101
return torch.sum(W, dim=0)

tests/unit/aggregation/test_dual_cone_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from pytest import mark, raises
44
from torch.testing import assert_close
55

6-
from torchjd.aggregation._dual_cone_utils import _project_weight_vector, _project_weights
6+
from torchjd.aggregation._dual_cone_utils import _project_weight_vector, project_weights
77

88

99
@mark.parametrize("shape", [(5, 7), (9, 37), (2, 14), (32, 114), (50, 100)])
@@ -33,7 +33,7 @@ def test_solution_weights(shape: tuple[int, int]):
3333
G = J @ J.T
3434
u = torch.rand(shape[0])
3535

36-
w = _project_weights(u, G, "quadprog")
36+
w = project_weights(u, G, "quadprog")
3737
dual_gap = w - u
3838

3939
# Dual feasibility
@@ -64,8 +64,8 @@ def test_tensorization_shape(shape: tuple[int, ...]):
6464

6565
G = matrix @ matrix.T
6666

67-
W_tensor = _project_weights(U_tensor, G, "quadprog")
68-
W_matrix = _project_weights(U_matrix, G, "quadprog")
67+
W_tensor = project_weights(U_tensor, G, "quadprog")
68+
W_matrix = project_weights(U_matrix, G, "quadprog")
6969

7070
assert_close(W_matrix.reshape(shape), W_tensor)
7171

0 commit comments

Comments
 (0)