-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy path: mgda.py
More file actions
90 lines (74 loc) · 3.33 KB
/
mgda.py
File metadata and controls
90 lines (74 loc) · 3.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import torch
from torch import Tensor
from ._gramian_utils import compute_gramian
from .bases import _WeightedAggregator, _Weighting
class MGDA(_WeightedAggregator):
    r"""
    :class:`~torchjd.aggregation.bases.Aggregator` performing the gradient aggregation step of
    `Multiple-gradient descent algorithm (MGDA) for multiobjective optimization
    <https://www.sciencedirect.com/science/article/pii/S1631073X12000738>`_. The implementation is
    based on Algorithm 2 of `Multi-Task Learning as Multi-Objective Optimization
    <https://proceedings.neurips.cc/paper_files/paper/2018/file/432aca3a1e345e339f35a30c8f65edce-Paper.pdf>`_.

    :param epsilon: The value of :math:`\hat{\gamma}` below which we stop the optimization.
    :param max_iters: The maximum number of iterations of the optimization loop.

    .. admonition::
        Example

        Use MGDA to aggregate a matrix.

        >>> from torch import tensor
        >>> from torchjd.aggregation import MGDA
        >>>
        >>> A = MGDA()
        >>> J = tensor([[-4., 1., 1.], [6., 1., 1.]])
        >>>
        >>> A(J)
        tensor([1.1921e-07, 1.0000e+00, 1.0000e+00])
    """

    def __init__(self, epsilon: float = 0.001, max_iters: int = 100):
        # All of the actual work is delegated to the weighting; this class only
        # wires it into the generic weighted-aggregation machinery.
        weighting = _MGDAWeighting(epsilon=epsilon, max_iters=max_iters)
        super().__init__(weighting=weighting)

    def __repr__(self) -> str:
        # Reconstruct the constructor call from the parameters stored on the weighting.
        cls_name = type(self).__name__
        return f"{cls_name}(epsilon={self.weighting.epsilon}, max_iters={self.weighting.max_iters})"
class _MGDAWeighting(_Weighting):
    r"""
    :class:`~torchjd.aggregation.bases._Weighting` that extracts weights using Algorithm
    2 of `Multi-Task Learning as Multi-Objective Optimization
    <https://proceedings.neurips.cc/paper_files/paper/2018/file/432aca3a1e345e339f35a30c8f65edce-Paper.pdf>`_.

    :param epsilon: The value of :math:`\hat{\gamma}` below which we stop the optimization.
    :param max_iters: The maximum number of iterations of the optimization loop.
    """

    def __init__(self, epsilon: float, max_iters: int):
        super().__init__()
        # Step-size threshold under which the Frank-Wolfe loop is considered converged.
        self.epsilon = epsilon
        # Hard cap on the number of Frank-Wolfe iterations.
        self.max_iters = max_iters

    def _compute_from_gramian(self, gramian: Tensor) -> Tensor:
        """
        Frank-Wolfe solver corresponding to Algorithm 2 of `Multi-Task Learning as
        Multi-Objective Optimization
        <https://proceedings.neurips.cc/paper_files/paper/2018/file/432aca3a1e345e339f35a30c8f65edce-Paper.pdf>`_.
        """
        m = gramian.shape[0]
        # Start from the uniform point of the probability simplex.
        weights = torch.ones(m, device=gramian.device, dtype=gramian.dtype) / m
        for _ in range(self.max_iters):
            # Simplex vertex minimizing the linearization of the quadratic objective.
            best = torch.argmin(gramian @ weights)
            vertex = torch.zeros_like(weights)
            vertex[best] = 1.0
            # Coefficients of the exact line search between `weights` and `vertex`.
            cross = weights @ (gramian @ vertex)
            current = weights @ (gramian @ weights)
            target = vertex @ (gramian @ vertex)
            if target <= cross:
                step = 1.0
            elif current <= cross:
                step = 0.0
            else:
                # Unconstrained minimizer of the 1-D quadratic, guaranteed in (0, 1) here.
                step = (current - cross) / (current + target - 2 * cross)
            weights = (1 - step) * weights + step * vertex
            # A tiny step means the iterate barely moved: stop early.
            if step < self.epsilon:
                break
        return weights

    def forward(self, matrix: Tensor) -> Tensor:
        # The weights only depend on the matrix through its Gramian J J^T.
        return self._compute_from_gramian(compute_gramian(matrix))