Skip to content

Commit 30b65a0

Browse files
authored
feat(aggregation): Add getters and setters to UPGrad parameters (#658)
* Add reg_eps getter and setter to UPGrad and UPGradWeighting
* Add norm_eps getter and setter to UPGrad and UPGradWeighting
* Add pref_vector getter and setter to UPGrad and UPGradWeighting
* Add tests for setters
* Add changelog entry
1 parent 4597af8 commit 30b65a0

3 files changed

Lines changed: 115 additions & 10 deletions

File tree

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,12 @@ changelog does not include internal changes that do not affect the user.
88

99
## [Unreleased]
1010

11+  ### Added
12+
13+  - Added `pref_vector`, `norm_eps`, and `reg_eps` getters and setters to `UPGrad` and
14+    `UPGradWeighting`. The setters for `norm_eps` and `reg_eps` validate that the assigned value is
15+    non-negative.
16+
1117
## [0.10.0] - 2026-04-16
1218

1319
### Added

src/torchjd/aggregation/_upgrad.py

Lines changed: 61 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,7 @@ def __init__(
3434
solver: SUPPORTED_SOLVER = "quadprog",
3535
) -> None:
3636
super().__init__()
37-
self._pref_vector = pref_vector
38-
self.weighting = pref_vector_to_weighting(pref_vector, default=MeanWeighting())
37+
self.pref_vector = pref_vector
3938
self.norm_eps = norm_eps
4039
self.reg_eps = reg_eps
4140
self.solver: SUPPORTED_SOLVER = solver
@@ -46,6 +45,39 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor:
4645
W = project_weights(U, G, self.solver)
4746
return torch.sum(W, dim=0)
4847

48+
@property
49+
def pref_vector(self) -> Tensor | None:
50+
return self._pref_vector
51+
52+
@pref_vector.setter
53+
def pref_vector(self, value: Tensor | None) -> None:
54+
self.weighting = pref_vector_to_weighting(value, default=MeanWeighting())
55+
self._pref_vector = value
56+
57+
@property
58+
def norm_eps(self) -> float:
59+
return self._norm_eps
60+
61+
@norm_eps.setter
62+
def norm_eps(self, value: float) -> None:
63+
64+
if value < 0:
65+
raise ValueError(f"norm_eps must be non-negative, but got {value}.")
66+
67+
self._norm_eps = value
68+
69+
@property
70+
def reg_eps(self) -> float:
71+
return self._reg_eps
72+
73+
@reg_eps.setter
74+
def reg_eps(self, value: float) -> None:
75+
76+
if value < 0:
77+
raise ValueError(f"reg_eps must be non-negative, but got {value}.")
78+
79+
self._reg_eps = value
80+
4981

5082
class UPGrad(GramianWeightedAggregator):
5183
r"""
@@ -73,9 +105,6 @@ def __init__(
73105
reg_eps: float = 0.0001,
74106
solver: SUPPORTED_SOLVER = "quadprog",
75107
) -> None:
76-
self._pref_vector = pref_vector
77-
self._norm_eps = norm_eps
78-
self._reg_eps = reg_eps
79108
self._solver: SUPPORTED_SOLVER = solver
80109

81110
super().__init__(
@@ -85,11 +114,35 @@ def __init__(
85114
# This prevents considering the computed weights as constant w.r.t. the matrix.
86115
self.register_full_backward_pre_hook(raise_non_differentiable_error)
87116

117+
@property
118+
def pref_vector(self) -> Tensor | None:
119+
return self.gramian_weighting.pref_vector
120+
121+
@pref_vector.setter
122+
def pref_vector(self, value: Tensor | None) -> None:
123+
self.gramian_weighting.pref_vector = value
124+
125+
@property
126+
def norm_eps(self) -> float:
127+
return self.gramian_weighting.norm_eps
128+
129+
@norm_eps.setter
130+
def norm_eps(self, value: float) -> None:
131+
self.gramian_weighting.norm_eps = value
132+
133+
@property
134+
def reg_eps(self) -> float:
135+
return self.gramian_weighting.reg_eps
136+
137+
@reg_eps.setter
138+
def reg_eps(self, value: float) -> None:
139+
self.gramian_weighting.reg_eps = value
140+
88141
def __repr__(self) -> str:
89142
return (
90-
f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)}, norm_eps="
91-
f"{self._norm_eps}, reg_eps={self._reg_eps}, solver={repr(self._solver)})"
143+
f"{self.__class__.__name__}(pref_vector={repr(self.pref_vector)}, norm_eps="
144+
f"{self.norm_eps}, reg_eps={self.reg_eps}, solver={repr(self._solver)})"
92145
)
93146

94147
def __str__(self) -> str:
95-
return f"UPGrad{pref_vector_to_str_suffix(self._pref_vector)}"
148+
return f"UPGrad{pref_vector_to_str_suffix(self.pref_vector)}"

tests/unit/aggregation/test_upgrad.py

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import torch
2-
from pytest import mark
2+
from pytest import mark, raises
33
from torch import Tensor
44
from utils.tensors import ones_
55

6-
from torchjd.aggregation import UPGrad
6+
from torchjd.aggregation import ConstantWeighting, UPGrad
7+
from torchjd.aggregation._upgrad import UPGradWeighting
78

89
from ._asserts import (
910
assert_expected_structure,
@@ -67,3 +68,48 @@ def test_representations() -> None:
6768
"solver='quadprog')"
6869
)
6970
assert str(A) == "UPGrad([1., 2., 3.])"
71+
72+
73+
def test_pref_vector_setter_updates_value() -> None:
74+
A = UPGrad()
75+
new_pref = torch.tensor([1.0, 2.0, 3.0])
76+
A.pref_vector = new_pref
77+
assert A.pref_vector is new_pref
78+
assert isinstance(A.gramian_weighting.weighting, ConstantWeighting)
79+
assert A.gramian_weighting.weighting.weights is new_pref
80+
81+
82+
def test_norm_eps_setter_updates_value() -> None:
83+
A = UPGrad()
84+
A.norm_eps = 0.25
85+
assert A.norm_eps == 0.25
86+
87+
88+
def test_reg_eps_setter_updates_value() -> None:
89+
A = UPGrad()
90+
A.reg_eps = 0.25
91+
assert A.reg_eps == 0.25
92+
93+
94+
def test_norm_eps_setter_rejects_negative() -> None:
95+
A = UPGrad()
96+
with raises(ValueError, match="norm_eps"):
97+
A.norm_eps = -1e-9
98+
99+
100+
def test_reg_eps_setter_rejects_negative() -> None:
101+
A = UPGrad()
102+
with raises(ValueError, match="reg_eps"):
103+
A.reg_eps = -1e-9
104+
105+
106+
def test_weighting_norm_eps_setter_rejects_negative() -> None:
107+
W = UPGradWeighting()
108+
with raises(ValueError, match="norm_eps"):
109+
W.norm_eps = -1e-9
110+
111+
112+
def test_weighting_reg_eps_setter_rejects_negative() -> None:
113+
W = UPGradWeighting()
114+
with raises(ValueError, match="reg_eps"):
115+
W.reg_eps = -1e-9

0 commit comments

Comments (0)