"""
Computing features for the TRAK algorithm involves computing (and projecting)
per-sample gradients. This module contains classes that compute these
per-sample gradients. The :code:`AbstractGradientComputer` class defines the
interface for such gradient computers. Then, we provide two implementations:
- :class:`FunctionalGradientComputer`: A fast implementation that uses
:code:`torch.func` to vectorize the computation of per-sample gradients, and
thus fully leverage parallelism.
- :class:`IterativeGradientComputer`: A more naive implementation that only uses
native pytorch operations (i.e. no :code:`torch.func`), and computes per-sample
gradients in a for-loop. This is often much slower than the functional
version, but it is useful if you cannot use :code:`torch.func`, e.g., if you
have an old version of pytorch that does not support it, or if your application
is not supported by :code:`torch.func`.
"""
from abc import ABC, abstractmethod
import logging
from typing import Dict, Iterable, Optional

import torch
from torch import Tensor

from .utils import get_num_params, parameters_to_vector
from .modelout_functions import AbstractModelOutput
ch = torch
[docs]
class AbstractGradientComputer(ABC):
    """Interface for classes that compute per-sample gradients.

    Implementations of the GradientComputer class should allow for
    per-sample gradients. This behavior is enabled with three methods:

    - the :meth:`.load_model_params` method, well, loads model parameters. It
      can be as simple as a :code:`self.model.load_state_dict(..)`
    - the :meth:`.compute_per_sample_grad` method computes per-sample
      gradients of the chosen model output function with respect to the
      model's parameters.
    - the :meth:`.compute_loss_grad` method computes the gradients of the loss
      function with respect to the model output (which should be a scalar) for
      every sample.
    """

    @abstractmethod
    def __init__(
        self,
        model: torch.nn.Module,
        task: AbstractModelOutput,
        grad_dim: Optional[int] = None,
        dtype: Optional[torch.dtype] = torch.float16,
        device: Optional[torch.device] = "cuda",
    ) -> None:
        """Initializes attributes, nothing too interesting happening.

        Although abstract, this constructor has a body: subclasses are
        expected to call it through :code:`super().__init__(...)` so that the
        attributes below get stored.

        Args:
            model (torch.nn.Module):
                model
            task (AbstractModelOutput):
                task (model output function); stored as
                :code:`self.modelout_fn`.
            grad_dim (int, optional):
                Size of the gradients (number of model parameters). Defaults to
                None.
            dtype (torch.dtype, optional):
                Torch dtype of the gradients. Defaults to torch.float16.
            device (torch.device, optional):
                Torch device where gradients will be stored. Defaults to 'cuda'.

        """
        self.model = model
        self.modelout_fn = task
        self.grad_dim = grad_dim
        self.dtype = dtype
        self.device = device

    @abstractmethod
    def load_model_params(self, model) -> None:
        """Loads (or refreshes) the parameters of :code:`model`.

        Args:
            model (torch.nn.Module):
                model whose parameters should be used from now on.

        """
        ...

    @abstractmethod
    def compute_per_sample_grad(self, batch: Iterable[Tensor]) -> Tensor:
        """Computes per-sample gradients of the model output function with
        respect to the model's parameters.

        Args:
            batch (Iterable[Tensor]):
                batch of data

        """
        ...

    @abstractmethod
    def compute_loss_grad(
        self, batch: Iterable[Tensor], batch_size: Optional[int] = None
    ) -> Tensor:
        """Computes the gradient of the loss with respect to the model output
        for every sample.

        Args:
            batch (Iterable[Tensor]):
                batch of data
            batch_size (int, optional):
                Kept in the signature for backward compatibility only. It is
                optional because the concrete gradient computers in this
                module implement this method with :code:`batch` as their sole
                argument. Defaults to None.

        """
        ...
[docs]
class FunctionalGradientComputer(AbstractGradientComputer):
    """Gradient computer that vectorizes per-sample gradients with
    :code:`torch.func` (:code:`grad` composed with :code:`vmap`).
    """

    def __init__(
        self,
        model: torch.nn.Module,
        task: AbstractModelOutput,
        grad_dim: int,
        dtype: torch.dtype,
        device: torch.device,
        grad_wrt: Optional[Iterable[str]] = None,
    ) -> None:
        """Initializes attributes, and loads model parameters.

        Args:
            model (torch.nn.Module):
                model
            task (AbstractModelOutput):
                task (model output function)
            grad_dim (int):
                Size of the gradients (number of model parameters).
            dtype (torch.dtype):
                Torch dtype of the gradients.
            device (torch.device):
                Torch device where gradients will be stored.
            grad_wrt (list[str], optional):
                A list of parameter names for which to keep gradients. If None,
                gradients are taken with respect to all model parameters.
                Defaults to None.

        """
        # the parent constructor already stores self.model (among others)
        super().__init__(model, task, grad_dim, dtype, device)
        self.num_params = get_num_params(self.model)
        self.load_model_params(model)
        self.grad_wrt = grad_wrt
        self.logger = logging.getLogger("GradientComputer")

    def load_model_params(self, model) -> None:
        """Given a torch.nn.Module model, inits/updates the (functional)
        weights and buffers. See https://pytorch.org/docs/stable/func.html
        for more details on :code:`torch.func`'s functional models.

        Args:
            model (torch.nn.Module):
                model to load

        """
        self.func_weights = dict(model.named_parameters())
        self.func_buffers = dict(model.named_buffers())

    def compute_per_sample_grad(
        self, batch: Iterable[Tensor]
    ) -> Dict[str, Tensor]:
        """Uses :code:`torch.func`'s :code:`vmap` (see
        https://pytorch.org/functorch/stable/generated/functorch.vmap.html#functorch.vmap
        for more details) to vectorize the computations of per-sample gradients.

        Args:
            batch (Iterable[Tensor]):
                batch of data. :code:`len(batch)` is taken, so it must be a
                sized collection (e.g. a list or tuple of tensors).

        Returns:
            dict[str, Tensor]:
                A dictionary where each key is a parameter name and the value is
                the gradient tensor for that parameter. If :code:`grad_wrt` was
                given, only the parameters named there are kept.

        """
        # taking the gradient wrt weights (second argument of get_output,
        # hence argnums=1)
        grads_loss = torch.func.grad(
            self.modelout_fn.get_output, has_aux=False, argnums=1
        )

        # map over batch dimensions (hence 0 for each batch dimension, and
        # None for the model, its weights and its buffers)
        grads = torch.func.vmap(
            grads_loss,
            in_dims=(None, None, None, *([0] * len(batch))),
            randomness="different",
        )(self.model, self.func_weights, self.func_buffers, *batch)

        if self.grad_wrt is not None:
            # keep only the gradients of the requested parameters
            grads = {
                param_name: grad
                for param_name, grad in grads.items()
                if param_name in self.grad_wrt
            }
        return grads

    def compute_loss_grad(self, batch: Iterable[Tensor]) -> Tensor:
        """Computes the gradient of the loss with respect to the model output

        .. math::

            \\partial \\ell / \\partial \\text{(model output)}

        Note: For all applications we considered, we analytically derived the
        out-to-loss gradient, thus avoiding the need to do any backward passes
        (let alone per-sample grads). If for your application this is not feasible,
        you'll need to subclass this and modify this method to have a structure
        similar to the one of
        :meth:`.FunctionalGradientComputer.compute_per_sample_grad`,
        i.e. something like:

        .. code-block:: python

            grad_out_to_loss = grad(self.model_out_to_loss_grad, ...)
            grads = vmap(grad_out_to_loss, ...)
            ...

        Args:
            batch (Iterable[Tensor]):
                batch of data

        Returns:
            Tensor:
                The gradient of the loss with respect to the model output.

        """
        return self.modelout_fn.get_out_to_loss_grad(
            self.model, self.func_weights, self.func_buffers, batch
        )
[docs]
class IterativeGradientComputer(AbstractGradientComputer):
    """Gradient computer that obtains per-sample gradients one sample at a
    time with :code:`torch.autograd.grad`, without :code:`torch.func`.
    """

    def __init__(
        self,
        model,
        task: AbstractModelOutput,
        grad_dim: int,
        dtype: torch.dtype,
        device: torch.device,
        grad_wrt: Optional[Iterable[str]] = None,
    ) -> None:
        """Initializes attributes, and loads model parameters.

        Args:
            model (torch.nn.Module):
                model
            task (AbstractModelOutput):
                task (model output function)
            grad_dim (int):
                Size of the gradients (number of model parameters).
            dtype (torch.dtype):
                Torch dtype of the gradients.
            device (torch.device):
                Torch device where gradients will be stored.
            grad_wrt (list[str], optional):
                Ignored by this class; a warning is logged if it is not None.
                Defaults to None.

        """
        super().__init__(model, task, grad_dim, dtype, device)
        self.load_model_params(model)
        self.grad_wrt = grad_wrt
        self.logger = logging.getLogger("GradientComputer")
        if self.grad_wrt is not None:
            self.logger.warning(
                "IterativeGradientComputer: ignoring grad_wrt argument."
            )

    def load_model_params(self, model) -> None:
        """Stores the model and a list of its parameters, which is what
        gradients are later taken with respect to.

        Args:
            model (torch.nn.Module):
                model to load

        """
        self.model = model
        self.model_params = list(self.model.parameters())

    def compute_per_sample_grad(self, batch: Iterable[Tensor]) -> Tensor:
        """Computes per-sample gradients of the model output function. This
        method does not leverage vectorization (and is hence much slower than
        its equivalent in :class:`.FunctionalGradientComputer`). We recommend
        that you use this only if :code:`torch.func` is not available to you,
        e.g. if you have a (very) old version of pytorch.

        Args:
            batch (Iterable[Tensor]):
                batch of data. It is indexed (:code:`batch[0]`), so it must be
                a sequence; the batch size and the output device are read from
                its first element.

        Returns:
            Tensor:
                gradients of the model output function of each sample in the
                batch with respect to the model's parameters, one row of size
                :code:`grad_dim` per sample.

        """
        batch_size = batch[0].shape[0]
        # allocate directly on the target device rather than creating the
        # buffer on CPU and copying it over.
        # NOTE: the buffer uses torch's default dtype; self.dtype is not
        # applied here.
        grads = ch.zeros(batch_size, self.grad_dim, device=batch[0].device)

        margin = self.modelout_fn.get_output(self.model, None, None, *batch)
        for ind in range(batch_size):
            # a single forward pass produced `margin`, so its graph has to be
            # retained across the per-sample backward passes
            grads[ind] = parameters_to_vector(
                ch.autograd.grad(margin[ind], self.model_params, retain_graph=True)
            )
        return grads

    def compute_loss_grad(self, batch: Iterable[Tensor]) -> Tensor:
        """Computes the gradient of the loss with respect to the model output

        .. math::

            \\partial \\ell / \\partial \\text{(model output)}

        Note: For all applications we considered, we analytically derived the
        out-to-loss gradient, thus avoiding the need to do any backward passes
        (let alone per-sample grads). If for your application this is not feasible,
        you'll need to subclass this and modify this method to have a structure
        similar to the one of
        :meth:`.IterativeGradientComputer.compute_per_sample_grad`,
        i.e. something like:

        .. code-block:: python

            out_to_loss = self.model_out_to_loss(...)
            for ind in range(batch_size):
                grads[ind] = torch.autograd.grad(out_to_loss[ind], ...)
            ...

        Args:
            batch (Iterable[Tensor]):
                batch of data

        Returns:
            Tensor:
                The gradient of the loss with respect to the model output.

        """
        return self.modelout_fn.get_out_to_loss_grad(self.model, None, None, batch)