Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions src/torchjd/aggregation/imtl_g.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,13 @@ class _IMTLGWeighting(_Weighting):
"""

def forward(self, matrix: Tensor) -> Tensor:
d = torch.linalg.norm(matrix, dim=1)
v = torch.linalg.pinv(matrix @ matrix.T) @ d
gramian = matrix @ matrix.T
Comment thread
PierreQuinton marked this conversation as resolved.
Outdated
return self._compute_from_gramian(gramian)

@staticmethod
def _compute_from_gramian(gramian: Tensor) -> Tensor:
d = torch.sqrt(torch.diagonal(gramian))
v = torch.linalg.pinv(gramian) @ d
v_sum = v.sum()

if v_sum.abs() < 1e-12:
Expand Down