Skip to content

Commit e73b0f2

Browse files
committed
Make dependence on gramian explicit in IMTL-G
1 parent 67b0c89 commit e73b0f2

1 file changed

Lines changed: 7 additions & 2 deletions

File tree

src/torchjd/aggregation/imtl_g.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,13 @@ class _IMTLGWeighting(_Weighting):
3838
"""
3939

4040
def forward(self, matrix: Tensor) -> Tensor:
41-
d = torch.linalg.norm(matrix, dim=1)
42-
v = torch.linalg.pinv(matrix @ matrix.T) @ d
41+
gramian = matrix @ matrix.T
42+
return self._compute_from_gramian(gramian)
43+
44+
@staticmethod
45+
def _compute_from_gramian(gramian: Tensor) -> Tensor:
46+
d = torch.sqrt(torch.diagonal(gramian))
47+
v = torch.linalg.pinv(gramian) @ d
4348
v_sum = v.sum()
4449

4550
if v_sum.abs() < 1e-12:

0 commit comments

Comments
 (0)