# Source code for trak.projectors

"""
Projectors are used to project gradients to a lower-dimensional space. This 1) allows
us to compute TRAK scores in a *much* more efficient manner, and 2) turns out to
act as a useful regularizer (see Appendix E.1 in our paper).

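Here, we provide four implementations of the projector:
- :class:`NoOpProjector` (no-op)
- :class:`BasicSingleBlockProjector` (bare-bones, inefficient implementation)
- :class:`BasicProjector` (block-wise implementation)
- :class:`CudaProjector` (a fast implementation with a custom CUDA kernel)

In addition, :class:`ChunkedCudaProjector` packs the per-parameter gradients into
chunks, projects each chunk with its own projector, and sums the results.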
"""
from abc import ABC, abstractmethod
from typing import Union
from enum import Enum
import math
from torch import Tensor
import torch

from .utils import vectorize


# Short alias for torch; most tensor calls in this module go through ``ch``.
ch = torch


class ProjectionType(str, Enum):
    """Distribution of the entries of the random projection (JL) matrix.

    Members compare equal to their string values, so ``"normal"`` and
    ``ProjectionType.normal`` can be used interchangeably in comparisons.
    """

    # iid Gaussian entries
    normal: str = "normal"
    # iid Rademacher (+1/-1) entries
    rademacher: str = "rademacher"
class AbstractProjector(ABC):
    """Base class for gradient projectors.

    Implementations of the Projector class must implement the
    :meth:`AbstractProjector.project` method, which takes in model gradients
    and returns the projected (lower-dimensional) gradients.
    """

    @abstractmethod
    def __init__(
        self,
        grad_dim: int,
        proj_dim: int,
        seed: int,
        proj_type: Union[str, ProjectionType],
        device: Union[str, torch.device],
    ) -> None:
        """Initializes hyperparameters for the projection.

        Args:
            grad_dim (int):
                number of parameters in the model (dimension of the gradient
                vectors)
            proj_dim (int):
                dimension after the projection
            seed (int):
                random seed for the generation of the sketching (projection)
                matrix
            proj_type (Union[str, ProjectionType]):
                the random projection (JL transform) guarantees that distances
                will be approximately preserved for a variety of choices of the
                random matrix (see e.g. https://arxiv.org/abs/1411.2404). Here,
                we provide an implementation for matrices with iid Gaussian
                entries and iid Rademacher entries.
            device (Union[str, torch.device]):
                device on which the projection is performed
        """
        self.grad_dim = grad_dim
        self.proj_dim = proj_dim
        self.seed = seed
        self.proj_type = proj_type
        self.device = device

    @abstractmethod
    def project(self, grads: Tensor, model_id: int) -> Tensor:
        """Performs the random projection.

        Model ID is included so that we generate different projection matrices
        for every model ID.

        Args:
            grads (Tensor): a batch of gradients to be projected
            model_id (int): a unique ID for a checkpoint

        Returns:
            Tensor: the projected gradients
        """

    def free_memory(self):
        """Frees up memory used by the projector.

        The default implementation does nothing; subclasses that hold large
        buffers override it.
        """
class NoOpProjector(AbstractProjector):
    """
    Identity "projector": implements :code:`projector.project(grad) = grad`.

    The only transformation applied is flattening a dict of per-parameter
    gradients into a single tensor (via ``vectorize``).
    """

    def __init__(
        self,
        grad_dim: int = 0,
        proj_dim: int = 0,
        seed: int = 0,
        proj_type: Union[str, ProjectionType] = "na",
        device: Union[str, torch.device] = "cuda",
        *args,
        **kwargs,
    ) -> None:
        # Extra positional / keyword arguments are accepted and ignored.
        super().__init__(grad_dim, proj_dim, seed, proj_type, device)

    def project(self, grads: Tensor, model_id: int) -> Tensor:
        """Returns the gradients unchanged (dicts are vectorized first).

        Args:
            grads (Tensor): a batch of gradients to be "projected"; a dict of
                per-parameter gradients is also accepted
            model_id (int): a unique ID for a checkpoint (unused)

        Returns:
            Tensor: the (non-)projected gradients
        """
        return (
            vectorize(grads, device=self.device) if isinstance(grads, dict) else grads
        )

    def free_memory(self):
        """Nothing to release: this projector holds no buffers."""
class BasicSingleBlockProjector(AbstractProjector):
    """
    A bare-bones, inefficient implementation of the projection, which simply
    calls torch's matmul for the projection step.

    Note: the full ``grad_dim x proj_dim`` projection matrix is materialized at
    once, so for most model sizes (e.g. even for ResNet18) and even moderate
    projection dimensions (e.g. anything > 100) this method will OOM on an
    A100. Unless you have a good reason to use this class (I cannot think of
    one, I added this only for testing purposes), use instead the
    CudaProjector or BasicProjector.
    """

    def __init__(
        self,
        grad_dim: int,
        proj_dim: int,
        seed: int,
        proj_type: ProjectionType,
        device,
        dtype=ch.float32,
        model_id=0,
        *args,
        **kwargs,
    ) -> None:
        super().__init__(grad_dim, proj_dim, seed, proj_type, device)

        self.model_id = model_id
        self.proj_type = proj_type
        self.generator = ch.Generator(device=self.device)
        # The matrix is a deterministic function of (seed, model_id).
        self.generator = self.generator.manual_seed(
            self.seed + int(1e4) * self.model_id
        )
        self.dtype = dtype

        self.proj_matrix = ch.empty(
            self.grad_dim, self.proj_dim, dtype=self.dtype, device=self.device
        )
        self.proj_matrix_available = True

        self.generate_sketch_matrix()  # updates self.proj_matrix

    def free_memory(self):
        """Releases the projection matrix.

        It is regenerated (identically) on the next :meth:`project` call.
        """
        del self.proj_matrix
        self.proj_matrix_available = False

    def generate_sketch_matrix(self):
        """Fills ``self.proj_matrix`` in place from ``self.generator``.

        Re-allocates the matrix first if it was released by
        :meth:`free_memory`.

        Raises:
            KeyError: if ``self.proj_type`` is not a known projection type.
        """
        if not self.proj_matrix_available:
            self.proj_matrix = ch.empty(
                self.grad_dim, self.proj_dim, dtype=self.dtype, device=self.device
            )
            self.proj_matrix_available = True

        if self.proj_type == ProjectionType.normal or self.proj_type == "normal":
            self.proj_matrix.normal_(generator=self.generator)
        elif (
            self.proj_type == ProjectionType.rademacher
            or self.proj_type == "rademacher"
        ):
            self.proj_matrix.bernoulli_(p=0.5, generator=self.generator)
            # going from Bernoulli {0, 1} to Rademacher {-1, 1}
            self.proj_matrix *= 2.0
            self.proj_matrix -= 1.0
        else:
            raise KeyError(f"Projection type {self.proj_type} not recognized.")

    def project(self, grads: Tensor, model_id: int) -> Tensor:
        """Projects ``grads`` with a single dense matmul.

        Args:
            grads (Tensor): a batch of gradients to be projected; a dict of
                per-parameter gradients is vectorized first
            model_id (int): a unique ID for a checkpoint; a new ID triggers a
                new projection matrix

        Returns:
            Tensor: the projected gradients, in ``self.dtype``
        """
        if isinstance(grads, dict):
            grads = vectorize(grads, device=self.device)
        grads = grads.to(dtype=self.dtype)

        # Regenerate when the checkpoint changes, and also after free_memory()
        # (otherwise self.proj_matrix would no longer exist). Re-seeding first
        # makes the regenerated matrix identical for the same model_id.
        if model_id != self.model_id or not self.proj_matrix_available:
            self.model_id = model_id
            self.generator = self.generator.manual_seed(
                self.seed + int(1e4) * self.model_id
            )
            self.generate_sketch_matrix()  # updates self.proj_matrix

        return grads @ self.proj_matrix
class BasicProjector(AbstractProjector):
    """
    A simple block-wise implementation of the projection. The projection
    matrix is generated on-device in blocks of (at most) ``block_size``
    columns. The accumulated result across blocks is returned.

    Note: This class will be significantly slower and have a larger memory
    footprint than the CudaProjector. It is recommended that you use this
    method only if the CudaProjector is not available to you -- e.g. if you
    don't have a CUDA-enabled device with compute capability >=7.0 (see
    https://developer.nvidia.com/cuda-gpus).
    """

    def __init__(
        self,
        grad_dim: int,
        proj_dim: int,
        seed: int,
        proj_type: ProjectionType,
        device: torch.device,
        block_size: int = 100,
        dtype: torch.dtype = ch.float32,
        model_id=0,
        *args,
        **kwargs,
    ) -> None:
        super().__init__(grad_dim, proj_dim, seed, proj_type, device)

        self.block_size = min(self.proj_dim, block_size)
        self.num_blocks = math.ceil(self.proj_dim / self.block_size)
        self.dtype = dtype
        self.proj_type = proj_type
        self.model_id = model_id

        # Only one block of the projection matrix is held in memory at a time.
        self.proj_matrix = ch.empty(
            self.grad_dim, self.block_size, dtype=self.dtype, device=self.device
        )
        self.proj_matrix_available = True

        self.generator = ch.Generator(device=self.device)

        self.get_generator_states()
        self.generate_sketch_matrix(self.generator_states[0])

    def free_memory(self):
        """Releases the current block of the projection matrix.

        It is regenerated from the saved generator states when needed.
        """
        del self.proj_matrix
        self.proj_matrix_available = False

    def get_generator_states(self):
        """Records one generator state per block for the current ``model_id``.

        Block ``i`` is seeded with ``seed + 1e3 * i + 1e5 * model_id``, so
        every block can be regenerated deterministically on demand.
        """
        self.generator_states = []
        self.seeds = []
        self.jl_size = self.grad_dim * self.block_size

        for i in range(self.num_blocks):
            s = self.seed + int(1e3) * i + int(1e5) * self.model_id
            self.seeds.append(s)
            self.generator = self.generator.manual_seed(s)
            self.generator_states.append(self.generator.get_state())

    def generate_sketch_matrix(self, generator_state):
        """Fills ``self.proj_matrix`` with the block for ``generator_state``.

        Re-allocates the block first if it was released by :meth:`free_memory`.

        Raises:
            KeyError: if ``self.proj_type`` is not a known projection type.
        """
        if not self.proj_matrix_available:
            self.proj_matrix = ch.empty(
                self.grad_dim, self.block_size, dtype=self.dtype, device=self.device
            )
            self.proj_matrix_available = True

        self.generator.set_state(generator_state)
        if self.proj_type == ProjectionType.normal or self.proj_type == "normal":
            self.proj_matrix.normal_(generator=self.generator)
        elif (
            self.proj_type == ProjectionType.rademacher
            or self.proj_type == "rademacher"
        ):
            self.proj_matrix.bernoulli_(p=0.5, generator=self.generator)
            # going from Bernoulli {0, 1} to Rademacher {-1, 1}
            self.proj_matrix *= 2.0
            self.proj_matrix -= 1.0
        else:
            raise KeyError(f"Projection type {self.proj_type} not recognized.")

    def project(self, grads: Tensor, model_id: int) -> Tensor:
        """Projects ``grads`` block by block.

        Args:
            grads (Tensor): a batch of gradients to be projected; a dict of
                per-parameter gradients is vectorized first
            model_id (int): a unique ID for a checkpoint; a new ID triggers new
                projection blocks

        Returns:
            Tensor: the projected gradients, of shape
            ``(grads.size(0), proj_dim)`` and dtype ``self.dtype``
        """
        if isinstance(grads, dict):
            grads = vectorize(grads, device=self.device)
        grads = grads.to(dtype=self.dtype)

        sketch = ch.zeros(
            size=(grads.size(0), self.proj_dim), dtype=self.dtype, device=self.device
        )

        model_changed = model_id != self.model_id
        if model_changed:
            self.model_id = model_id
            self.get_generator_states()  # regenerate random seeds for new model_id

        if self.num_blocks == 1:
            # The single block is cached between calls; rebuild it when the
            # checkpoint changed or when free_memory() released it.
            if model_changed or not self.proj_matrix_available:
                self.generate_sketch_matrix(self.generator_states[0])
            ch.matmul(grads.data, self.proj_matrix, out=sketch)
        else:
            for ind, state in enumerate(self.generator_states):
                self.generate_sketch_matrix(state)

                st = ind * self.block_size
                ed = min((ind + 1) * self.block_size, self.proj_dim)
                sketch[:, st:ed] = grads @ self.proj_matrix[:, : (ed - st)]

        # grads was cast to self.dtype above, so sketch already has its dtype.
        return sketch
class CudaProjector(AbstractProjector):
    """
    A performant implementation of the projection for CUDA with compute
    capability >= 7.0.
    """

    def __init__(
        self,
        grad_dim: int,
        proj_dim: int,
        seed: int,
        proj_type: ProjectionType,
        device,
        max_batch_size: int,
        *args,
        **kwargs,
    ) -> None:
        """
        Args:
            grad_dim (int):
                Number of parameters
            proj_dim (int):
                Dimension we project *to* during the projection step
            seed (int):
                Random seed
            proj_type (ProjectionType):
                Type of randomness to use for projection matrix (rademacher or
                normal); the plain strings "rademacher" / "normal" are also
                accepted
            device:
                CUDA device
            max_batch_size (int):
                Explicitly constrains the batch size the CudaProjector is going
                to use for projection. Set this if you get a 'The batch size of
                the CudaProjector is too large for your GPU' error. Must be
                either 8, 16, or 32.

        Raises:
            ValueError:
                When attempting to use this on a non-CUDA device
            ModuleNotFoundError:
                When fast_jl is not installed

        """
        super().__init__(grad_dim, proj_dim, seed, proj_type, device)
        self.max_batch_size = max_batch_size

        if isinstance(device, str):
            device = ch.device(device)

        if device.type != "cuda":
            err = "CudaProjector only works on a CUDA device; Either switch to a CUDA device, or use the BasicProjector"
            raise ValueError(err)

        self.num_sms = ch.cuda.get_device_properties(device.index).multi_processor_count

        try:
            import fast_jl

            # test run to catch at init time if projection goes through
            fast_jl.project_rademacher_8(
                ch.zeros(8, 1_000, device="cuda"), 512, 0, self.num_sms
            )
        except ImportError as e:
            err = (
                "You should make sure to install the CUDA projector for traker "
                "(called fast_jl). See the installation FAQs for more details."
            )
            raise ModuleNotFoundError(err) from e

    def project(
        self,
        grads: Union[dict, Tensor],
        model_id: int,
    ) -> Tensor:
        """Projects ``grads`` with the fast_jl CUDA kernel.

        Args:
            grads (Union[dict, Tensor]): a batch of gradients; a dict of
                per-parameter gradients is vectorized first
            model_id (int): a unique ID for a checkpoint; it is folded into the
                kernel seed as ``seed + 1e4 * model_id``

        Returns:
            Tensor: the projected gradients

        Raises:
            RuntimeError: with a hint to lower ``proj_max_batch_size`` when the
                kernel requests too many resources; other RuntimeErrors are
                re-raised unchanged.
        """
        if isinstance(grads, dict):
            grads = vectorize(grads, device=self.device)

        # fast_jl exposes one kernel per supported batch size (8, 16, 32).
        batch_size = grads.shape[0]
        effective_batch_size = 32
        if batch_size <= 8:
            effective_batch_size = 8
        elif batch_size <= 16:
            effective_batch_size = 16
        effective_batch_size = min(self.max_batch_size, effective_batch_size)

        # proj_type may be a ProjectionType member or its plain string value;
        # normalizing avoids an AttributeError on `.value` for strings.
        proj_type = ProjectionType(self.proj_type).value
        function_name = f"project_{proj_type}_{effective_batch_size}"

        import fast_jl

        fn = getattr(fast_jl, function_name)

        try:
            result = fn(
                grads, self.proj_dim, self.seed + int(1e4) * model_id, self.num_sms
            )
        except RuntimeError as e:
            if "CUDA error: too many resources requested for launch" in str(e):
                # provide a more helpful error message, keeping the original
                raise RuntimeError(
                    "The batch size of the CudaProjector is too large for your GPU. "
                    "Reduce it by using the proj_max_batch_size argument of the TRAKer.\n"
                    f"Original error: {e}"
                ) from e
            raise

        return result

    def free_memory(self):
        """A no-op method: this projector holds no large buffers."""
class ChunkedCudaProjector:
    """Projects gradients chunk by chunk, with one projector per chunk.

    Per-parameter gradients are packed (in dict iteration order) into a reusable
    input buffer of width ``max_chunk_size``. Whenever the next parameter would
    overflow the buffer, the buffered chunk is projected with the matching
    projector from ``projector_per_chunk`` and accumulated into the output. The
    result is the sum of all per-chunk projections.
    """

    def __init__(
        self,
        projector_per_chunk: list,
        max_chunk_size: int,
        params_per_chunk: list,
        feat_bs: int,
        device: torch.device,
        dtype: torch.dtype,
    ):
        """
        Args:
            projector_per_chunk (list): one projector per chunk; the output
                dimension and projection type are taken from the first one
            max_chunk_size (int): width (number of parameter columns) of the
                input buffer
            params_per_chunk (list): expected number of parameter columns in
                each chunk; checked against the actual packing
            feat_bs (int): number of rows of the input/output buffers, i.e. the
                largest batch size that can be projected
            device (torch.device): device of the buffers
            dtype (torch.dtype): dtype of the buffers
        """
        self.projector_per_chunk = projector_per_chunk
        self.proj_dim = self.projector_per_chunk[0].proj_dim
        self.proj_type = self.projector_per_chunk[0].proj_type
        self.params_per_chunk = params_per_chunk

        self.max_chunk_size = max_chunk_size
        self.feat_bs = feat_bs
        self.device = device
        self.dtype = dtype
        self.input_allocated = False

    def allocate_input(self):
        """Allocates the ``(feat_bs, max_chunk_size)`` input buffer if needed."""
        if self.input_allocated:
            return

        self.ch_input = ch.zeros(
            size=(self.feat_bs, self.max_chunk_size),
            device=self.device,
            dtype=self.dtype,
        )

        self.input_allocated = True

    def free_memory(self):
        """Releases the input buffer; it is re-allocated on the next project()."""
        if not self.input_allocated:
            return

        del self.ch_input
        self.input_allocated = False

    def project(self, grads, model_id):
        """Projects a dict of per-parameter gradients.

        Args:
            grads (dict): gradient tensors whose first dimension is the batch
                dimension; 1-D tensors are treated as one column per row
            model_id (int): forwarded to every per-chunk projector

        Returns:
            Tensor: projected gradients of shape ``(batch_size, proj_dim)``,
            where ``batch_size = min(feat_bs, leading dimension of grads)``

        Raises:
            ValueError: if ``grads`` is empty
        """
        self.allocate_input()
        ch_output = ch.zeros(
            size=(self.feat_bs, self.proj_dim), device=self.device, dtype=self.dtype
        )
        pointer = 0  # number of parameter columns currently packed in ch_input
        projector_index = 0
        actual_bs = None

        # iterate over params, keep a counter of params so far, and when prev
        # chunk reaches max_chunk_size, project and accumulate
        for p in grads.values():
            if p.dim() < 2:
                p_flat = p.data.unsqueeze(-1)
            else:
                p_flat = p.data.flatten(start_dim=1)

            param_size = p_flat.size(1)
            # Only the first actual_bs rows of the reused buffer hold data for
            # this call (assumes all grads share the same leading dimension).
            actual_bs = min(self.ch_input.size(0), p_flat.size(0))

            if pointer + param_size > self.max_chunk_size:
                self._project_chunk(
                    ch_output, projector_index, pointer, actual_bs, model_id
                )
                # reset counter
                pointer = 0
                projector_index += 1

            # continue accumulation
            self.ch_input[:actual_bs, pointer : pointer + param_size].copy_(p_flat)
            pointer += param_size

        if actual_bs is None:
            raise ValueError("ChunkedCudaProjector.project received no gradients.")

        # at the end, we need to project remaining items
        self._project_chunk(ch_output, projector_index, pointer, actual_bs, model_id)
        return ch_output[:actual_bs]

    def _project_chunk(self, ch_output, projector_index, pointer, actual_bs, model_id):
        """Projects the buffered chunk and adds it to ``ch_output`` in place.

        Uses the first ``pointer`` columns and first ``actual_bs`` rows of the
        input buffer, so stale rows from earlier, larger batches are never
        projected.
        """
        assert pointer == self.params_per_chunk[projector_index], (
            f"chunk {projector_index} holds {pointer} params, expected "
            f"{self.params_per_chunk[projector_index]}"
        )
        ch_output[:actual_bs].add_(
            self.projector_per_chunk[projector_index].project(
                self.ch_input[:actual_bs, :pointer].contiguous(),
                model_id=model_id,
            )
        )