1010from ._weighting_bases import Weighting
1111
1212
class DualProjWeighting(Weighting[PSDMatrix]):
    r"""
    :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of
    :class:`~torchjd.aggregation.DualProj`.

    :param pref_vector: The preference vector to use. If not provided, defaults to
        :math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`.
    :param norm_eps: A small value to avoid division by zero when normalizing.
    :param reg_eps: A small value added to the diagonal of the gramian of the matrix. Due to
        numerical errors when computing the gramian, it might not be exactly positive definite,
        which can make the optimization fail. Adding ``reg_eps`` to the diagonal restores positive
        definiteness.
    :param solver: The solver used to optimize the underlying optimization problem.
    """

    def __init__(
        self,
        pref_vector: Tensor | None = None,
        norm_eps: float = 0.0001,
        reg_eps: float = 0.0001,
        solver: SUPPORTED_SOLVER = "quadprog",
    ) -> None:
        super().__init__()
        self._pref_vector = pref_vector
        # Fall back to a uniform (mean) weighting when no preference vector is given.
        self.weighting = pref_vector_to_weighting(pref_vector, default=MeanWeighting())
        self.norm_eps = norm_eps
        self.reg_eps = reg_eps
        self.solver: SUPPORTED_SOLVER = solver

    def forward(self, gramian: PSDMatrix, /) -> Tensor:
        # Starting weights derived from the preference vector (or the mean by default).
        anchor = self.weighting(gramian)
        # Condition the gramian: normalize it, then bump its diagonal by reg_eps so that
        # it is positive definite despite numerical error.
        conditioned = regularize(normalize(gramian, self.norm_eps), self.reg_eps)
        return project_weights(anchor, conditioned, self.solver)
47+
48+
1349class DualProj (GramianWeightedAggregator ):
1450 r"""
1551 :class:`~torchjd.aggregation._aggregator_bases.Aggregator` that averages the rows of the input
@@ -27,6 +63,8 @@ class DualProj(GramianWeightedAggregator):
2763 :param solver: The solver used to optimize the underlying optimization problem.
2864 """
2965
66+ gramian_weighting : DualProjWeighting
67+
3068 def __init__ (
3169 self ,
3270 pref_vector : Tensor | None = None ,
@@ -54,39 +92,3 @@ def __repr__(self) -> str:
5492
5593 def __str__ (self ) -> str :
5694 return f"DualProj{ pref_vector_to_str_suffix (self ._pref_vector )} "
57-
58-
59- class DualProjWeighting (Weighting [PSDMatrix ]):
60- r"""
61- :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of
62- :class:`~torchjd.aggregation.DualProj`.
63-
64- :param pref_vector: The preference vector to use. If not provided, defaults to
65- :math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`.
66- :param norm_eps: A small value to avoid division by zero when normalizing.
67- :param reg_eps: A small value to add to the diagonal of the gramian of the matrix. Due to
68- numerical errors when computing the gramian, it might not exactly be positive definite.
69- This issue can make the optimization fail. Adding ``reg_eps`` to the diagonal of the gramian
70- ensures that it is positive definite.
71- :param solver: The solver used to optimize the underlying optimization problem.
72- """
73-
74- def __init__ (
75- self ,
76- pref_vector : Tensor | None = None ,
77- norm_eps : float = 0.0001 ,
78- reg_eps : float = 0.0001 ,
79- solver : SUPPORTED_SOLVER = "quadprog" ,
80- ) -> None :
81- super ().__init__ ()
82- self ._pref_vector = pref_vector
83- self .weighting = pref_vector_to_weighting (pref_vector , default = MeanWeighting ())
84- self .norm_eps = norm_eps
85- self .reg_eps = reg_eps
86- self .solver : SUPPORTED_SOLVER = solver
87-
88- def forward (self , gramian : PSDMatrix , / ) -> Tensor :
89- u = self .weighting (gramian )
90- G = regularize (normalize (gramian , self .norm_eps ), self .reg_eps )
91- w = project_weights (u , G , self .solver )
92- return w
0 commit comments