Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions src/torchjd/aggregation/aligned_mtl.py
Comment thread
ValerianRey marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -91,16 +91,14 @@ def __init__(self, weighting: _Weighting):
def forward(self, matrix: Tensor) -> Tensor:
w = self.weighting(matrix)

G = matrix.T
B = self._compute_balance_transformation(G)
M = matrix @ matrix.T
Comment thread
PierreQuinton marked this conversation as resolved.
Outdated
B = self._compute_balance_transformation(M)
alpha = B @ w

return alpha

@staticmethod
def _compute_balance_transformation(G: Tensor) -> Tensor:
M = G.T @ G

def _compute_balance_transformation(M: Tensor) -> Tensor:
lambda_, V = torch.linalg.eigh(M, UPLO="U") # More modern equivalent to torch.symeig
tol = torch.max(lambda_) * len(M) * torch.finfo().eps
rank = sum(lambda_ > tol)
Expand Down