-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy pathcagrad.py
More file actions
105 lines (79 loc) · 3.71 KB
/
cagrad.py
File metadata and controls
105 lines (79 loc) · 3.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import cvxpy as cp
import numpy as np
import torch
from torch import Tensor
from ._gramian_utils import compute_gramian, normalize
from .bases import _WeightedAggregator, _Weighting
class CAGrad(_WeightedAggregator):
    """
    :class:`~torchjd.aggregation.bases.Aggregator` as defined in Algorithm 1 of
    `Conflict-Averse Gradient Descent for Multi-task Learning
    <https://arxiv.org/pdf/2110.14048.pdf>`_.

    :param c: The scale of the radius of the ball constraint.
    :param norm_eps: A small value to avoid division by zero when normalizing.

    .. admonition::
        Example

        Use CAGrad to aggregate a matrix.

        >>> from torch import tensor
        >>> from torchjd.aggregation import CAGrad
        >>>
        >>> A = CAGrad(c=0.5)
        >>> J = tensor([[-4., 1., 1.], [6., 1., 1.]])
        >>>
        >>> A(J)
        tensor([0.1835, 1.2041, 1.2041])
    """

    def __init__(self, c: float, norm_eps: float = 0.0001):
        super().__init__(weighting=_CAGradWeighting(c=c, norm_eps=norm_eps))

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}(c={self.weighting.c}, norm_eps={self.weighting.norm_eps})"
        )

    def __str__(self) -> str:
        # Format c with "g" to drop insignificant trailing zeros without
        # corrupting the value. The previous `str(c).rstrip("0")` mangled
        # some inputs: c=10 produced "CAGrad1" and c=1.0 produced "CAGrad1.".
        # Common fractional values are unchanged (c=0.5 -> "CAGrad0.5").
        return f"CAGrad{self.weighting.c:g}"
class _CAGradWeighting(_Weighting):
    """
    :class:`~torchjd.aggregation.bases._Weighting` that extracts weights using the CAGrad
    algorithm, as defined in algorithm 1 of `Conflict-Averse Gradient Descent for Multi-task
    Learning <https://arxiv.org/pdf/2110.14048.pdf>`_.

    :param c: The scale of the radius of the ball constraint.
    :param norm_eps: A small value to avoid division by zero when normalizing.

    .. note::
        This implementation differs from the `official implementations
        <https://github.com/Cranial-XIX/CAGrad/>`_ in the way the underlying optimization problem is
        solved. This uses the `CLARABEL <https://oxfordcontrol.github.io/ClarabelDocs/stable/>`_
        solver of `cvxpy <https://www.cvxpy.org/index.html>`_ rather than the `scipy.minimize
        <https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize.html>`_
        function.
    """

    def __init__(self, c: float, norm_eps: float):
        super().__init__()
        if c < 0.0:
            raise ValueError(f"Parameter `c` should be a non-negative float. Found `c = {c}`.")
        # Scale of the radius of the ball constraint.
        self.c = c
        # Threshold used both when normalizing the gramian and when deciding
        # whether the optimal combined direction is effectively zero.
        self.norm_eps = norm_eps

    def forward(self, matrix: Tensor) -> Tensor:
        """Return the CAGrad weights for the rows of ``matrix``."""
        # Work on the normalized Gramian (matrix @ matrix.T): the CAGrad
        # objective only depends on pairwise inner products of the rows.
        gramian = normalize(compute_gramian(matrix), self.norm_eps)
        return self._compute_from_gramian(gramian)

    def _compute_from_gramian(self, gramian: Tensor) -> Tensor:
        # Factor the PSD gramian G as R @ R.T with R = U @ diag(sqrt(S)); the
        # rows of R act as reduced gradients with the same pairwise products.
        U, S, _ = torch.svd(gramian)
        reduced_matrix = U @ S.sqrt().diag()
        # cvxpy works on CPU float64 numpy arrays, so detach and convert.
        reduced_array = reduced_matrix.cpu().detach().numpy().astype(np.float64)

        dimension = gramian.shape[0]
        # Reduced average gradient g_0 (mean of the rows of R, in reduced space).
        reduced_g_0 = reduced_array.T @ np.ones(dimension) / dimension
        # sqrt(phi): the ball-constraint radius is c * ||g_0||.
        sqrt_phi = self.c * np.linalg.norm(reduced_g_0, 2)

        # Inner problem of Algorithm 1: minimize <g_w, g_0> + sqrt(phi)*||g_w||
        # over the probability simplex, where g_w = R.T @ w.
        w = cp.Variable(shape=dimension)
        cost = (reduced_array @ reduced_g_0).T @ w + sqrt_phi * cp.norm(reduced_array.T @ w, 2)
        problem = cp.Problem(objective=cp.Minimize(cost), constraints=[w >= 0, cp.sum(w) == 1])
        problem.solve(cp.CLARABEL)
        w_opt = w.value

        g_w_norm = np.linalg.norm(reduced_array.T @ w_opt)
        if g_w_norm >= self.norm_eps:
            # Combined direction g_0 + (sqrt(phi)/||g_w||) * g_w, expressed as
            # weights over the original rows: uniform average plus correction.
            weight_array = np.ones(dimension) / dimension
            weight_array += (sqrt_phi / g_w_norm) * w_opt
        else:
            # We are approximately on the pareto front
            weight_array = np.zeros(dimension)

        # Return the weights on the same device/dtype as the input gramian.
        weights = torch.from_numpy(weight_array).to(device=gramian.device, dtype=gramian.dtype)
        return weights