Source code for trak.score_computers

"""
Computing scores for the TRAK algorithm from pre-computed projected gradients
involves a number of matrix multiplications. This module contains classes that
perform these operations. The :code:`AbstractScoreComputer` class defines the
interface for score computers. Then, we provide two implementations:

- :class:`BasicSingleBlockScoreComputer`: A bare-bones implementation, mostly for
  testing purposes.
- :class:`BasicScoreComputer`: A more sophisticated implementation that does
  block-wise matrix multiplications to avoid OOM errors.

"""
from abc import ABC, abstractmethod
import logging
from torch import Tensor
import torch

from .utils import get_matrix_mult

ch = torch


class AbstractScoreComputer(ABC):
    r"""Interface that every TRAK score computer must satisfy.

    Implementations of the ScoreComputer class must implement the abstract
    methods declared here:

    - :code:`__init__` (declared abstract as well; a subclass may simply
      delegate to ``super().__init__(dtype, device)``)
    - :code:`get_xtx`
    - :code:`get_x_xtx_inv`
    - :code:`get_scores`
    """

    # NOTE: the docstrings in this class are raw strings on purpose. They
    # contain LaTeX such as ``\top`` and ``\Phi``; in a regular string literal
    # ``\t`` would silently become a TAB character and ``\P`` is an invalid
    # escape sequence.

    @abstractmethod
    def __init__(self, dtype, device) -> None:
        r"""Store the dtype and device on the instance.

        Although abstract, this method has a body so that subclasses can call
        ``super().__init__(dtype, device)`` from their own ``__init__``.

        Args:
            dtype: stored as :code:`self.dtype`.
            device: stored as :code:`self.device`.
        """
        self.dtype = dtype
        self.device = device

    @abstractmethod
    def get_xtx(self, grads: Tensor) -> Tensor:
        r"""Computes :math:`X^\top X`, where :math:`X` is the matrix of projected
        gradients. Here, the shape of :math:`X` is :code:`(n, p)`, where
        :math:`n` is the number of training examples and :math:`p` is the
        dimension of the projection.

        Args:
            grads (Tensor): projected gradients of shape :code:`(n, p)`.

        Returns:
            Tensor: :math:`X^\top X` of shape :code:`(p, p)`.
        """

    @abstractmethod
    def get_x_xtx_inv(self, grads: Tensor, xtx: Tensor) -> Tensor:
        r"""Computes :math:`X(X^\top X)^{-1}`, where :math:`X` is the matrix of
        projected gradients. Here, the shape of :math:`X` is :code:`(n, p)`,
        where :math:`n` is the number of training examples and :math:`p` is the
        dimension of the projection. This function takes as input the
        pre-computed :math:`X^\top X` matrix, which is computed by the
        :code:`get_xtx` method.

        Args:
            grads (Tensor): projected gradients :math:`X` of shape
                :code:`(n, p)`.
            xtx (Tensor): :math:`X^\top X` of shape :code:`(p, p)`.

        Returns:
            Tensor: :math:`X(X^\top X)^{-1}` of shape :code:`(n, p)`.
        """

    @abstractmethod
    def get_scores(
        self, features: Tensor, target_grads: Tensor, accumulator: Tensor
    ) -> None:
        r"""Computes the scores for a given set of features and target gradients.

        In particular, this function takes in a matrix of features
        :math:`\Phi=X(X^\top X)^{-1}`, computed by the :code:`get_x_xtx_inv`
        method, and a matrix of target (projected) gradients
        :math:`X_{target}`. Then, it computes the scores as
        :math:`\Phi X_{target}^\top`. The resulting matrix has shape
        :code:`(n, m)`, where :math:`n` is the number of training examples and
        :math:`m` is the number of target examples.

        The :code:`accumulator` argument is used to store the result of the
        computation. This is useful when computing scores for multiple model
        checkpoints, as it allows us to re-use the same memory for the score
        matrix.

        Args:
            features (Tensor): features :math:`\Phi` of shape :code:`(n, p)`.
            target_grads (Tensor): target projected gradients
                :math:`X_{target}` of shape :code:`(m, p)`.
            accumulator (Tensor): accumulator of shape :code:`(n, m)`.
        """
class BasicSingleBlockScoreComputer(AbstractScoreComputer):
    """A bare-bones implementation of :code:`ScoreComputer` that will likely
    OOM for almost all applications. Here for testing purposes only. Unless you
    have a good reason not to, you should use :class:`BasicScoreComputer`
    instead.

    Every operation is performed as a single (non-blocked) matrix product on
    whatever device the input tensors already live on.
    """

    def __init__(self, dtype, device) -> None:
        """Store the dtype and device on the instance.

        The base-class ``__init__`` is declared abstract, so a concrete
        override is required here; without it this class would itself remain
        abstract and instantiating it would raise ``TypeError``.

        Args:
            dtype: dtype the inverse of ``xtx`` is cast to in
                :meth:`get_x_xtx_inv`.
            device: stored as :code:`self.device` (not otherwise used by this
                class).
        """
        super().__init__(dtype, device)

    def get_xtx(self, grads: Tensor) -> Tensor:
        """Return ``grads.T @ grads`` computed in a single matmul.

        Args:
            grads (Tensor): projected gradients of shape :code:`(n, p)`.

        Returns:
            Tensor: matrix of shape :code:`(p, p)`.
        """
        return grads.T @ grads

    def get_x_xtx_inv(self, grads: Tensor, xtx: Tensor) -> Tensor:
        """Return ``grads @ inv(xtx)`` computed in a single matmul.

        Args:
            grads (Tensor): projected gradients of shape :code:`(n, p)`.
            xtx (Tensor): matrix of shape :code:`(p, p)` to invert.

        Returns:
            Tensor: matrix of shape :code:`(n, p)`.
        """
        # torch.linalg.inv does not support float16, so invert in float32 and
        # cast the inverse back to the configured dtype afterwards.
        xtx_inv = ch.linalg.inv(xtx.float()).to(self.dtype)
        return grads @ xtx_inv

    def get_scores(
        self, features: Tensor, target_grads: Tensor, accumulator: Tensor
    ) -> None:
        """Add ``features @ target_grads.T`` to ``accumulator`` in place.

        The product is detached and moved to the CPU before being accumulated.

        Args:
            features (Tensor): features of shape :code:`(n, p)`.
            target_grads (Tensor): target projected gradients of shape
                :code:`(m, p)`.
            accumulator (Tensor): accumulator of shape :code:`(n, m)`; modified
                in place.
        """
        accumulator += (features @ target_grads.T).detach().cpu()
class BasicScoreComputer(AbstractScoreComputer):
    """An implementation of :code:`ScoreComputer` that computes matmuls in a
    block-wise manner.

    Inputs are split along their first dimension into blocks of at most
    :code:`CUDA_MAX_DIM_SIZE` rows; each block is moved to :code:`self.device`
    before it is multiplied.
    """

    def __init__(
        self,
        dtype: torch.dtype,
        device: torch.device,
        CUDA_MAX_DIM_SIZE: int = 20_000,
        logging_level=logging.INFO,
        lambda_reg: float = 0.0,
    ) -> None:
        """
        Args:
            dtype (torch.dtype): dtype of the buffers allocated by this class
                and of the inverse used in :meth:`get_x_xtx_inv`.
            device (Union[str, torch.device]): torch device to do matmuls on
            CUDA_MAX_DIM_SIZE (int, optional): Size of block for block-wise
                matmuls. Defaults to 20_000.
            logging_level (logging level, optional): Logging level for the
                logger. Defaults to logging.INFO.
            lambda_reg (float, optional): regularization term for l2 reg on
                xtx. Defaults to 0.0.
        """
        super().__init__(dtype, device)
        self.CUDA_MAX_DIM_SIZE = CUDA_MAX_DIM_SIZE
        # NOTE: this is a named (shared) logger, so setting the level here
        # affects every user of the "ScoreComputer" logger.
        self.logger = logging.getLogger("ScoreComputer")
        self.logger.setLevel(logging_level)
        self.lambda_reg = lambda_reg

    def get_xtx(self, grads: Tensor) -> Tensor:
        """Accumulate ``grads.T @ grads`` block by block on ``self.device``.

        Also records the projection dimension in :code:`self.proj_dim`.

        Args:
            grads (Tensor): projected gradients of shape :code:`(n, p)`.

        Returns:
            Tensor: matrix of shape :code:`(p, p)` with dtype
            :code:`self.dtype`, located on :code:`self.device`.
        """
        self.proj_dim = grads.shape[1]
        result = ch.zeros(
            self.proj_dim, self.proj_dim, dtype=self.dtype, device=self.device
        )
        blocks = ch.split(grads, split_size_or_sections=self.CUDA_MAX_DIM_SIZE, dim=0)
        for block in blocks:
            # Transfer each block once and reuse it for both operands.
            block = block.to(self.device)
            result += block.T @ block
        return result

    def get_x_xtx_inv(self, grads: Tensor, xtx: Tensor) -> Tensor:
        """Multiply ``grads`` block-wise by a regularized, rescaled inverse of
        ``xtx``.

        ``lambda_reg * I`` is added to ``xtx`` before inversion. The inverse is
        computed in float32, divided by the mean of its absolute entries, and
        only then cast to :code:`self.dtype`. Consequently the returned matrix
        is a scalar multiple of ``grads @ inv(xtx + lambda_reg * I)``.

        Args:
            grads (Tensor): projected gradients of shape :code:`(n, p)`.
            xtx (Tensor): matrix of shape :code:`(p, p)`, e.g. the output of
                :meth:`get_xtx`.

        Returns:
            Tensor: matrix of shape :code:`(n, p)` with dtype
            :code:`self.dtype`, located on :code:`self.device`.
        """
        blocks = ch.split(grads, split_size_or_sections=self.CUDA_MAX_DIM_SIZE, dim=0)

        identity = torch.eye(xtx.size(dim=0), device=xtx.device, dtype=xtx.dtype)
        xtx_reg = xtx + self.lambda_reg * identity
        xtx_inv = ch.linalg.inv(xtx_reg.to(ch.float32))

        # center X^TX inverse a bit to avoid numerical issues when going to float16
        xtx_inv /= xtx_inv.abs().mean()
        # Blocks are moved to self.device below, so the inverse has to live
        # there as well for the matmul to be valid.
        xtx_inv = xtx_inv.to(device=self.device, dtype=self.dtype)

        result = ch.empty(
            grads.shape[0], xtx_inv.shape[1], dtype=self.dtype, device=self.device
        )
        offset = 0
        for block in blocks:
            rows = block.shape[0]
            result[offset : offset + rows] = block.to(self.device) @ xtx_inv
            offset += rows
        return result

    def get_scores(
        self, features: Tensor, target_grads: Tensor, accumulator: Tensor
    ) -> None:
        """Add the scores for ``features`` and ``target_grads`` to
        ``accumulator`` in place.

        The product is delegated to :code:`get_matrix_mult` (imported from
        ``.utils``; presumably a block-wise ``features @ target_grads.T`` --
        confirm against that helper). Its result is detached and moved to the
        CPU before being accumulated.

        Args:
            features (Tensor): features of shape :code:`(n, p)`.
            target_grads (Tensor): target projected gradients of shape
                :code:`(m, p)`.
            accumulator (Tensor): accumulator of shape :code:`(n, m)`; modified
                in place.
        """
        train_dim = features.shape[0]
        target_dim = target_grads.shape[0]

        # Lazy %-formatting: the message is only built if DEBUG is enabled.
        self.logger.debug("train_dim=%r, target_dim=%r", train_dim, target_dim)

        accumulator += (
            get_matrix_mult(features=features, target_grads=target_grads).detach().cpu()
        )