API reference#
trak.traker module#
This module contains the main TRAKer class, which is the front-facing class for TRAK. See the README and docs for example usage.
In short, methods of the TRAKer class are used to compute TRAK scores for a set of model checkpoints, a set of target samples, and a set of train samples.
This is done in two stages:
- Featurizing. TRAKer.featurize() and TRAKer.finalize_features() are used to compute the TRAK features for a set of model checkpoints and a set of train samples.
- Scoring. TRAKer.start_scoring_checkpoint(), TRAKer.score(), and TRAKer.finalize_scores() are used to compute the TRAK scores for a set of target samples, given the TRAK features computed in the previous step.
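The two stages above can be sketched as a small driver function. This is a minimal sketch rather than the library's own example: `run_trak` and its arguments are hypothetical names. It assumes `traker` is an already-constructed TRAKer, `checkpoints` is a list of state dicts, the batch collections can be iterated more than once, and each batch is a sequence whose first element holds the inputs, so that `len(batch[0])` is the number of samples in the batch.

```python
def run_trak(traker, checkpoints, train_batches, target_batches,
             exp_name, num_targets):
    """Hypothetical driver: featurize the train set, then score targets."""
    # Stage 1: featurizing, once per checkpoint.
    for model_id, checkpoint in enumerate(checkpoints):
        traker.load_checkpoint(checkpoint, model_id=model_id)
        for batch in train_batches:
            # num_samples writes sequentially into the internal store.
            traker.featurize(batch=batch, num_samples=len(batch[0]))
    traker.finalize_features()

    # Stage 2: scoring, again once per checkpoint.
    for model_id, checkpoint in enumerate(checkpoints):
        traker.start_scoring_checkpoint(
            exp_name=exp_name,
            checkpoint=checkpoint,
            model_id=model_id,
            num_targets=num_targets,
        )
        for batch in target_batches:
            traker.score(batch=batch, num_samples=len(batch[0]))

    return traker.finalize_scores(exp_name=exp_name)
```

The function only relies on the methods documented in this module, so any object exposing the same methods can be passed in for a dry run.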
- class trak.traker.TRAKer(model: ~torch.nn.modules.module.Module, task: ~trak.modelout_functions.AbstractModelOutput | str, train_set_size: int, save_dir: str = './trak_results', load_from_save_dir: bool = True, device: str | ~torch.device = 'cuda', gradient_computer: ~trak.gradient_computers.AbstractGradientComputer = <class 'trak.gradient_computers.FunctionalGradientComputer'>, projector: ~trak.projectors.AbstractProjector | None = None, saver: ~trak.savers.AbstractSaver | None = None, score_computer: ~trak.score_computers.AbstractScoreComputer | None = None, proj_dim: int = 2048, logging_level=20, use_half_precision: bool = True, proj_max_batch_size: int = 32, projector_seed: int = 0, grad_wrt: ~typing.Iterable[str] | None = None, lambda_reg: float = 0.0)[source]#
Bases:
object
The main front-facing class for TRAK. See the README and docs for example usage.
- __init__(model: ~torch.nn.modules.module.Module, task: ~trak.modelout_functions.AbstractModelOutput | str, train_set_size: int, save_dir: str = './trak_results', load_from_save_dir: bool = True, device: str | ~torch.device = 'cuda', gradient_computer: ~trak.gradient_computers.AbstractGradientComputer = <class 'trak.gradient_computers.FunctionalGradientComputer'>, projector: ~trak.projectors.AbstractProjector | None = None, saver: ~trak.savers.AbstractSaver | None = None, score_computer: ~trak.score_computers.AbstractScoreComputer | None = None, proj_dim: int = 2048, logging_level=20, use_half_precision: bool = True, proj_max_batch_size: int = 32, projector_seed: int = 0, grad_wrt: ~typing.Iterable[str] | None = None, lambda_reg: float = 0.0) None [source]#
- Args:
- model (torch.nn.Module):
model to use for TRAK
- task (Union[AbstractModelOutput, str]):
Type of model that TRAK will be run on. Accepts either one of the following strings: 1) image_classification, 2) text_classification, 3) clip, or an instance of some implementation of the abstract class AbstractModelOutput.
- train_set_size (int):
Size of the train set that TRAK is featurizing
- save_dir (str, optional):
Directory to save final TRAK scores, intermediate results, and metadata. Defaults to './trak_results'.
- load_from_save_dir (bool, optional):
If True, the TRAKer instance will attempt to load existing metadata from save_dir. May lead to I/O issues if multiple TRAKer instances run in parallel have this flag set to True. See the SLURM tutorial for more details.
- device (Union[str, torch.device], optional):
torch device on which to do computations. Defaults to ‘cuda’.
- gradient_computer (AbstractGradientComputer, optional):
Class to use to get per-example gradients. See AbstractGradientComputer for more details. Defaults to FunctionalGradientComputer.
- projector (Optional[AbstractProjector], optional):
Either set proj_dim, in which case a CudaProjector Rademacher projector will be used, or give a custom subclass of AbstractProjector and leave proj_dim as None. Defaults to None.
- saver (Optional[AbstractSaver], optional):
Class to use for saving intermediate results and final TRAK scores to RAM/disk. If None, the MmapSaver will be used. Defaults to None.
- score_computer (Optional[AbstractScoreComputer], optional):
Class to use for computing the final TRAK scores. If None, the BasicScoreComputer will be used. Defaults to None.
- proj_dim (int, optional):
Dimension of the projected TRAK features. See Section 4.3 of our paper for more details. Defaults to 2048.
- logging_level (int, optional):
Logging level for TRAK loggers. Defaults to logging.INFO.
- use_half_precision (bool, optional):
If True, TRAK will use half precision (float16) for all computations and arrays will be stored in float16. Otherwise, it will use float32. Defaults to True.
- proj_max_batch_size (int):
Batch size used by fast_jl if the CudaProjector is used. Must be a multiple of 8. The maximum batch size is 32 for A100 GPUs, 16 for V100 GPUs, 40 for H100 GPUs. Defaults to 32.
- projector_seed (int):
Random seed used by the projector. Defaults to 0.
- grad_wrt (Optional[Iterable[str]], optional):
If not None, the gradients will be computed only with respect to the parameters specified in this list. The list should contain the names of the parameters to compute gradients with respect to, as they appear in the model’s state dictionary. If None, gradients are taken with respect to all model parameters. Defaults to None.
- lambda_reg (float):
The \(\ell_2\) (ridge) regularization penalty added to the \(X^\top X\) term in score computers when computing the matrix inverse \((X^\top X)^{-1}\). Defaults to 0.
- featurize(batch: Iterable[Tensor], inds: Iterable[int] | None = None, num_samples: int | None = None) None [source]#
Creates TRAK features for the given batch by computing the gradient of the model output function and projecting it. In the notation of the paper, for an input pair \(z=(x,y)\), model parameters \(\theta\), and JL projection matrix \(P\), this method computes \(P^\top \nabla_\theta f(z, \theta)\). Additionally, this method computes the gradient of the out-to-loss function (in the notation of the paper, the \(Q\) term in Section 3.4).
Either inds or num_samples must be specified. Using num_samples will write sequentially into the internal store of the TRAKer.
- Args:
- batch (Iterable[Tensor]):
input batch
- inds (Optional[Iterable[int]], optional):
Indices of the batch samples in the train set. Defaults to None.
- num_samples (Optional[int], optional):
Number of samples in the batch. Defaults to None.
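The difference between inds and num_samples can be pictured with a toy store. The sketch below is illustrative only: `ToyFeatureStore` is a hypothetical stand-in, not the TRAKer's actual storage, and it assumes that "writing sequentially" means each call fills the next num_samples free rows.

```python
import numpy as np


class ToyFeatureStore:
    """Hypothetical stand-in for the TRAKer's internal feature store."""

    def __init__(self, train_set_size, proj_dim):
        self.features = np.zeros((train_set_size, proj_dim), dtype=np.float32)
        self.next_row = 0  # cursor used only by sequential writes

    def write(self, batch_features, inds=None, num_samples=None):
        if inds is None:
            if num_samples is None:
                raise ValueError("either inds or num_samples must be given")
            # Sequential mode: fill the next num_samples rows.
            inds = np.arange(self.next_row, self.next_row + num_samples)
            self.next_row += num_samples
        self.features[np.asarray(inds)] = batch_features


store = ToyFeatureStore(train_set_size=6, proj_dim=2)
store.write(np.ones((2, 2)), num_samples=2)       # fills rows 0 and 1
store.write(2 * np.ones((2, 2)), num_samples=2)   # fills rows 2 and 3
store.write(3 * np.ones((2, 2)), inds=[5, 4])     # fills rows 5 and 4
```

Explicit indices are the safer choice when batches arrive shuffled or when several workers featurize different parts of the train set.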
- finalize_features(model_ids: Iterable[int] | None = None, del_grads: bool = False) None [source]#
For a set of checkpoints \(C\) (specified by model IDs), and gradients \(\{ \Phi_c \}_{c\in C}\), this method computes \(\Phi_c (\Phi_c^\top\Phi_c)^{-1}\) for all \(c\in C\) and stores the results in the internal store of the TRAKer class.
- Args:
- model_ids (Iterable[int], optional):
A list of model IDs for which features should be finalized. If None, features are finalized for all model IDs in the save_dir of the TRAKer class. Defaults to None.
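The quantity \(\Phi_c (\Phi_c^\top\Phi_c)^{-1}\) can be sketched in NumPy for a single checkpoint. This is an illustration under assumptions, not the library's implementation: `finalize_features_sketch` is a hypothetical name, and adding lambda_reg times the identity to \(\Phi^\top\Phi\) is one reading of how the lambda_reg argument of TRAKer enters the inverse.

```python
import numpy as np


def finalize_features_sketch(phi, lambda_reg=0.0):
    """Return phi @ inv(phi.T @ phi + lambda_reg * I) for one checkpoint.

    phi has shape (num_train_samples, proj_dim).
    """
    gram = phi.T @ phi                            # (proj_dim, proj_dim)
    gram += lambda_reg * np.eye(gram.shape[0])    # assumed ridge term
    # Solve gram @ X = phi.T instead of forming the inverse explicitly;
    # gram is symmetric, so X.T equals phi @ inv(gram).
    return np.linalg.solve(gram, phi.T).T


rng = np.random.default_rng(0)
phi = rng.standard_normal((50, 4))
finalized = finalize_features_sketch(phi)
```

With lambda_reg left at 0 the result satisfies `finalized.T @ phi == I` up to rounding, which is a quick sanity check on the shapes and the transpose.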
- finalize_scores(exp_name: str, model_ids: Iterable[int] | None = None, allow_skip: bool = False) Tensor [source]#
This method computes the final TRAK scores for the given targets, train samples, and model checkpoints (specified by model IDs).
- Args:
- exp_name (str):
Experiment name. Each experiment should have a unique name, and it corresponds to a set of targets being scored. The experiment name is used as the name for saving the target features, as well as scores produced by this method in the save_dir of the TRAKer class.
- model_ids (Iterable[int], optional):
A list of model IDs for which scores should be finalized. If None, scores are computed for all model IDs in the save_dir of the TRAKer class. Defaults to None.
- allow_skip (bool, optional):
If True, raises only a warning, instead of an error, when target gradients are not computed for a given model ID. Defaults to False.
- Returns:
Tensor: TRAK scores
- init_projector(projector: AbstractProjector | None, proj_dim: int, proj_max_batch_size: int) None [source]#
Initializes the projector for a TRAKer instance.
- Args:
- projector (Optional[AbstractProjector]):
JL projector to use. If None, a CudaProjector will be used (if possible).
- proj_dim (int):
Dimension of the projected gradients and TRAK features.
- proj_max_batch_size (int):
Batch size used by fast_jl if the CudaProjector is used. Must be a multiple of 8. The maximum batch size is 32 for A100 GPUs, 16 for V100 GPUs, 40 for H100 GPUs.
- load_checkpoint(checkpoint: Iterable[Tensor], model_id: int, _allow_featurizing_already_registered=False) None [source]#
Loads state dictionary for the given checkpoint; initializes arrays to store TRAK features for that checkpoint, tied to the model ID.
- Args:
- checkpoint (Iterable[Tensor]):
state_dict to load
- model_id (int):
a unique ID for a checkpoint
- _allow_featurizing_already_registered (bool, optional):
Only use if you want to override the default behaviour that featurize is forbidden on already registered model IDs. Defaults to False.
- score(batch: Iterable[Tensor], inds: Iterable[int] | None = None, num_samples: int | None = None) None [source]#
This method computes the (intermediate per-checkpoint) TRAK scores for a batch of targets and stores them in the internal store of the TRAKer class.
Either inds or num_samples must be specified. Using num_samples will write sequentially into the internal store of the TRAKer.
- Args:
- batch (Iterable[Tensor]):
input batch
- inds (Optional[Iterable[int]], optional):
Indices of the batch samples in the train set. Defaults to None.
- num_samples (Optional[int], optional):
Number of samples in the batch. Defaults to None.
- start_scoring_checkpoint(exp_name: str, checkpoint: Iterable[Tensor], model_id: int, num_targets: int) None [source]#
This method prepares the internal store of the TRAKer class to start computing scores for a set of targets.
- Args:
- exp_name (str):
Experiment name. Each experiment should have a unique name, and it corresponds to a set of targets being scored. The experiment name is used as the name for saving the target features, as well as scores produced by this method in the save_dir of the TRAKer class.
- checkpoint (Iterable[Tensor]):
model checkpoint (state dict)
- model_id (int):
a unique ID for a checkpoint
- num_targets (int):
number of targets to score
trak.gradient_computers module#
Computing features for the TRAK algorithm involves computing (and projecting) per-sample gradients. This module contains classes that compute these per-sample gradients. The AbstractGradientComputer class defines the interface for such gradient computers. Then, we provide two implementations:
- FunctionalGradientComputer: A fast implementation that uses torch.func to vectorize the computation of per-sample gradients, and thus fully leverage parallelism.
- IterativeGradientComputer: A more naive implementation that only uses native pytorch operations (i.e. no torch.func), and computes per-sample gradients in a for-loop. This is often much slower than the functional version, but it is useful if you cannot use torch.func, e.g., if you have an old version of pytorch that does not support it, or if your application is not supported by torch.func.
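The trade-off between the two implementations can be illustrated without torch. The NumPy sketch below is only an analogy: it uses a hypothetical toy model f(x; w) = tanh(w · x), whose gradient with respect to w is known in closed form, and compares a per-sample for-loop with a single vectorized expression. The real classes differentiate arbitrary torch.nn.Module models instead.

```python
import numpy as np


def grads_loop(w, xs):
    """Per-sample gradients of tanh(w @ x) w.r.t. w, one sample at a time."""
    out = np.empty_like(xs)
    for i, x in enumerate(xs):
        out[i] = (1.0 - np.tanh(w @ x) ** 2) * x
    return out


def grads_vectorized(w, xs):
    """Same gradients, computed for the whole batch at once."""
    scale = 1.0 - np.tanh(xs @ w) ** 2    # shape (batch,)
    return scale[:, None] * xs            # shape (batch, num_params)


rng = np.random.default_rng(0)
w = rng.standard_normal(5)
xs = rng.standard_normal((8, 5))
loop_result = grads_loop(w, xs)
vectorized_result = grads_vectorized(w, xs)
```

Both functions return one gradient row per sample; the vectorized version simply replaces the Python loop with array operations, which is the same idea vmap applies to autograd.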
- class trak.gradient_computers.AbstractGradientComputer(model: Module, task: AbstractModelOutput, grad_dim: int | None = None, dtype: dtype | None = torch.float16, device: device | None = 'cuda')[source]#
Bases:
ABC
Implementations of the GradientComputer class should allow for per-sample gradients. This behavior is enabled with three methods:
- the load_model_params() method, well, loads model parameters. It can be as simple as a self.model.load_state_dict(..)
- the compute_per_sample_grad() method computes per-sample gradients of the chosen model output function with respect to the model’s parameters.
- the compute_loss_grad() method computes the gradients of the loss function with respect to the model output (which should be a scalar) for every sample.
- abstract __init__(model: Module, task: AbstractModelOutput, grad_dim: int | None = None, dtype: dtype | None = torch.float16, device: device | None = 'cuda') None [source]#
Initializes attributes, nothing too interesting happening.
- Args:
- model (torch.nn.Module):
model
- task (AbstractModelOutput):
task (model output function)
- grad_dim (int, optional):
Size of the gradients (number of model parameters). Defaults to None.
- dtype (torch.dtype, optional):
Torch dtype of the gradients. Defaults to torch.float16.
- device (torch.device, optional):
Torch device where gradients will be stored. Defaults to ‘cuda’.
- class trak.gradient_computers.FunctionalGradientComputer(model: Module, task: AbstractModelOutput, grad_dim: int, dtype: dtype, device: device, grad_wrt: Iterable[str] | None = None)[source]#
Bases:
AbstractGradientComputer
- __init__(model: Module, task: AbstractModelOutput, grad_dim: int, dtype: dtype, device: device, grad_wrt: Iterable[str] | None = None) None [source]#
Initializes attributes, and loads model parameters.
- Args:
- grad_wrt (list[str], optional):
A list of parameter names for which to keep gradients. If None, gradients are taken with respect to all model parameters. Defaults to None.
- compute_loss_grad(batch: Iterable[Tensor]) Tensor [source]#
Computes the gradient of the loss with respect to the model output
\[\partial \ell / \partial \text{(model output)}\]
Note: For all applications we considered, we analytically derived the out-to-loss gradient, thus avoiding the need to do any backward passes (let alone per-sample grads). If for your application this is not feasible, you’ll need to subclass this and modify this method to have a structure similar to the one of FunctionalGradientComputer.get_output(), i.e. something like:
grad_out_to_loss = grad(self.model_out_to_loss_grad, ...)
grads = vmap(grad_out_to_loss, ...)
...
- Args:
- batch (Iterable[Tensor]):
batch of data
- Returns:
- Tensor:
The gradient of the loss with respect to the model output.
- compute_per_sample_grad(batch: Iterable[Tensor]) Tensor [source]#
Uses functorch’s vmap (see https://pytorch.org/functorch/stable/generated/functorch.vmap.html#functorch.vmap for more details) to vectorize the computations of per-sample gradients.
Doesn’t use batch_size; only added to follow the abstract method signature.
- Args:
- batch (Iterable[Tensor]):
batch of data
- Returns:
- dict[Tensor]:
A dictionary where each key is a parameter name and the value is the gradient tensor for that parameter.
- load_model_params(model) None [source]#
Given a torch.nn.Module model, inits/updates the (functional) weights and buffers. See https://pytorch.org/docs/stable/func.html for more details on torch.func’s functional models.
- Args:
- model (torch.nn.Module):
model to load
- class trak.gradient_computers.IterativeGradientComputer(model, task: AbstractModelOutput, grad_dim: int, dtype: dtype, device: device, grad_wrt: Iterable[str] | None = None)[source]#
Bases:
AbstractGradientComputer
- __init__(model, task: AbstractModelOutput, grad_dim: int, dtype: dtype, device: device, grad_wrt: Iterable[str] | None = None) None [source]#
Initializes attributes, nothing too interesting happening.
- Args:
- model (torch.nn.Module):
model
- task (AbstractModelOutput):
task (model output function)
- grad_dim (int, optional):
Size of the gradients (number of model parameters). Defaults to None.
- dtype (torch.dtype, optional):
Torch dtype of the gradients. Defaults to torch.float16.
- device (torch.device, optional):
Torch device where gradients will be stored. Defaults to ‘cuda’.
- compute_loss_grad(batch: Iterable[Tensor]) Tensor [source]#
Computes the gradient of the loss with respect to the model output
\[\partial \ell / \partial \text{(model output)}\]
Note: For all applications we considered, we analytically derived the out-to-loss gradient, thus avoiding the need to do any backward passes (let alone per-sample grads). If for your application this is not feasible, you’ll need to subclass this and modify this method to have a structure similar to the one of IterativeGradientComputer.get_output(), i.e. something like:
out_to_loss = self.model_out_to_loss(...)
for ind in range(batch_size):
    grads[ind] = torch.autograd.grad(out_to_loss[ind], ...)
...
- Args:
- batch (Iterable[Tensor]):
batch of data
- Returns:
- Tensor:
The gradient of the loss with respect to the model output.
- compute_per_sample_grad(batch: Iterable[Tensor]) Tensor [source]#
Computes per-sample gradients of the model output function. This method does not leverage vectorization (and is hence much slower than its equivalent in FunctionalGradientComputer). We recommend that you use this only if torch.func is not available to you, e.g. if you have a (very) old version of pytorch.
- Args:
- batch (Iterable[Tensor]):
batch of data
- Returns:
- Tensor:
gradients of the model output function of each sample in the batch with respect to the model’s parameters.
trak.modelout_functions module#
Here we provide an abstract “model output” class AbstractModelOutput, together with a number of subclasses for particular applications (vision, language, etc).
These classes implement methods that transform input batches to the desired model output (e.g. logits, loss, etc). See Sections 2 & 3 of our paper for more details on what model output functions are in the context of TRAK and how to use & design them.
See, e.g. this tutorial for an example on how to subclass AbstractModelOutput for a task of your choice.
- class trak.modelout_functions.AbstractModelOutput[source]#
Bases:
ABC
See, e.g. this tutorial for an example on how to subclass AbstractModelOutput for a task of your choice.
Subclasses must implement:
- a get_output method that takes in a batch of inputs and model weights to produce outputs that TRAK will be trained to predict. In the notation of the paper, get_output should return \(f(z,\theta)\)
- a get_out_to_loss_grad method that takes in a batch of inputs and model weights to produce the gradient of the function that transforms the model outputs above into the loss with respect to the batch. In the notation of the paper, get_out_to_loss_grad returns (entries along the diagonal of) \(Q\).
- abstract get_out_to_loss_grad(model, batch: Iterable[Tensor]) Tensor [source]#
See Sections 2 & 3 of our paper for more details on what the out-to-loss functions (in the notation of the paper, \(Q\)) are in the context of TRAK and how to use & design them.
- Args:
- model (torch.nn.Module):
model
- batch (Iterable[Tensor]):
input batch
- Returns:
Tensor: gradient of the out-to-loss function
- abstract get_output(model, batch: Iterable[Tensor]) Tensor [source]#
See Sections 2 & 3 of our paper for more details on what model output functions are in the context of TRAK and how to use & design them.
- Args:
- model (torch.nn.Module):
model
- batch (Iterable[Tensor]):
input batch
- Returns:
- Tensor:
model output function
- class trak.modelout_functions.CLIPModelOutput(temperature: float | None = None, simulated_batch_size: int = 300)[source]#
Bases:
AbstractModelOutput
Margin for multimodal contrastive learning (CLIP). See Section 5.1 of our paper for more details. Compatible with the open_clip implementation of CLIP.
- Raises:
AssertionError: this model output function requires using additional CLIP embeddings, which are computed using the
get_embeddings()
method. This method should be invoked before featurizing.
- __init__(temperature: float | None = None, simulated_batch_size: int = 300) None [source]#
- Args:
- temperature (float, optional):
Temperature to use inside the softmax for the out-to-loss function. If None, CLIP’s logit_scale is used. Defaults to None.
- simulated_batch_size (int, optional):
Size of the “simulated” batch for the model output function. See Section 5.1 of the TRAK paper for more details. Defaults to 300.
- static get_embeddings(model, loader, batch_size: int, embedding_dim: int, size: int = 50000, preprocess_fn_img=None, preprocess_fn_txt=None) None [source]#
Computes (image and text) embeddings and saves them in the class attributes image_embeddings and text_embeddings.
- Args:
- model (torch.nn.Module):
model
- loader ():
data loader
- batch_size (int):
input batch size
- size (int, optional):
Maximum number of embeddings to compute. Defaults to 50_000.
- embedding_dim (int, optional):
Dimension of CLIP embedding. Defaults to 1024.
- preprocess_fn_img (func, optional):
Transforms to apply to images from the loader before forward pass. Defaults to None.
- preprocess_fn_txt (func, optional):
Transforms to apply to text from the loader before forward pass. Defaults to None.
- get_out_to_loss_grad(model, weights, buffers, batch: Iterable[Tensor]) Tensor [source]#
Computes the reweighting term \(Q\) from the paper.
- Args:
- model (torch.nn.Module):
torch model
- weights (Iterable[Tensor]):
functorch model weights
- buffers (Iterable[Tensor]):
functorch model buffers
- batch (Iterable[Tensor]):
input batch
- Returns:
- Tensor:
out-to-loss (reweighting term) for the input batch
- static get_output(model: Module, weights: Iterable[Tensor], buffers: Iterable[Tensor], image: Tensor, label: Tensor) Tensor [source]#
For a given input \(z=(x, y)\) and model parameters \(\theta\), let \(\phi(x, \theta)\) be the CLIP image embedding and \(\psi(y, \theta)\) be the CLIP text embedding. Last, let \(B\) be a (simulated) batch. This method implements the model output function
\[-\log(\frac{\phi(x)\cdot \psi(y)}{\sum_{(x', y')\in B} \phi(x)\cdot \psi(y')}) -\log(\frac{\phi(x)\cdot \psi(y)}{\sum_{(x', y')\in B} \phi(x')\cdot \psi(y)})\]
It uses functional models from torch.func (previously functorch) to make the per-sample gradient computations (much) faster. For more details on what functional models are, and how to use them, please refer to https://pytorch.org/docs/stable/func.html and https://pytorch.org/functorch/stable/notebooks/per_sample_grads.html.
- Args:
- model (torch.nn.Module):
torch model
- weights (Iterable[Tensor]):
functorch model weights
- buffers (Iterable[Tensor]):
functorch model buffers
- image (Tensor):
input image, should not have batch dimension
- label (Tensor):
input label, should not have batch dimension
- Returns:
- Tensor:
model output for the given image-label pair \(z\) and weights & buffers \(\theta\).
- image_embeddings = None#
- num_computed_embeddings = 0#
- sim_batch_size = 0#
- text_embeddings = None#
- class trak.modelout_functions.ImageClassificationModelOutput(temperature: float = 1.0)[source]#
Bases:
AbstractModelOutput
Margin for (multiclass) image classification. See Section 3.3 of our paper for more details.
- __init__(temperature: float = 1.0) None [source]#
- Args:
temperature (float, optional): Temperature to use inside the softmax for the out-to-loss function. Defaults to 1.
- get_out_to_loss_grad(model, weights, buffers, batch: Iterable[Tensor]) Tensor [source]#
Computes the reweighting term \(Q\) from the paper.
- Args:
- model (torch.nn.Module):
torch model
- weights (Iterable[Tensor]):
functorch model weights
- buffers (Iterable[Tensor]):
functorch model buffers
- batch (Iterable[Tensor]):
input batch
- Returns:
- Tensor:
out-to-loss (reweighting term) for the input batch
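The reason this term can be derived analytically, rather than by a backward pass, is easy to check numerically. Assuming the loss is the cross-entropy \(\ell = -\log p\) and the model output is the margin \(f = \log(p / (1 - p))\), the loss can be written as \(\ell = \log(1 + e^{-f})\), which gives \(\partial \ell / \partial f = -(1 - p)\). The sketch below verifies that identity with a finite difference. The function names are hypothetical, and how the library treats the sign and the temperature is not shown here.

```python
import math


def loss_from_margin(f):
    """Cross-entropy loss -log(p) written in terms of f = log(p / (1 - p))."""
    return math.log1p(math.exp(-f))


def out_to_loss_grad(p):
    """Analytic derivative of the loss with respect to the margin f."""
    return -(1.0 - p)


p = 0.8
f = math.log(p / (1.0 - p))
eps = 1e-6
# Central finite difference of the loss around f.
numeric = (loss_from_margin(f + eps) - loss_from_margin(f - eps)) / (2 * eps)
```

Here `numeric` and `out_to_loss_grad(p)` agree to within finite-difference error, so the per-sample value depends only on the predicted probability of the correct class.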
- static get_output(model: Module, weights: Iterable[Tensor], buffers: Iterable[Tensor], image: Tensor, label: Tensor) Tensor [source]#
For a given input \(z=(x, y)\) and model parameters \(\theta\), let \(p(z, \theta)\) be the softmax probability of the correct class. This method implements the model output function
\[\log(\frac{p(z, \theta)}{1 - p(z, \theta)}).\]
It uses functional models from torch.func (previously functorch) to make the per-sample gradient computations (much) faster. For more details on what functional models are, and how to use them, please refer to https://pytorch.org/docs/stable/func.html and https://pytorch.org/functorch/stable/notebooks/per_sample_grads.html.
- Args:
- model (torch.nn.Module):
torch model
- weights (Iterable[Tensor]):
functorch model weights
- buffers (Iterable[Tensor]):
functorch model buffers
- image (Tensor):
input image, should not have batch dimension
- label (Tensor):
input label, should not have batch dimension
- Returns:
- Tensor:
model output for the given image-label pair \(z\) and weights & buffers \(\theta\).
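The model output \(\log(p / (1 - p))\) can be computed directly from logits, because it equals the correct-class logit minus the log-sum-exp of the remaining logits. A minimal NumPy sketch of that identity follows. `margin_from_logits` is a hypothetical helper, not the library's get_output, and it leaves out the model, weights, and buffers arguments.

```python
import numpy as np


def margin_from_logits(logits, label):
    """log(p / (1 - p)) for the correct class, computed stably from logits."""
    others = np.delete(logits, label)
    m = others.max()
    # log(sum(exp(others))), with the max subtracted for numerical stability
    lse_others = m + np.log(np.exp(others - m).sum())
    return logits[label] - lse_others


logits = np.array([2.0, -1.0, 0.5])
label = 0

# Reference value computed the direct way, through softmax probabilities.
probs = np.exp(logits - logits.max())
probs /= probs.sum()
direct = np.log(probs[label] / (1.0 - probs[label]))
margin = margin_from_logits(logits, label)
```

Working from logits avoids forming \(1 - p\) explicitly, which loses precision when the model is very confident and \(p\) is close to 1.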
- class trak.modelout_functions.IterativeImageClassificationModelOutput(temperature: float = 1.0)[source]#
Bases:
AbstractModelOutput
Margin for (multiclass) image classification. See Section 3.3 of our paper for more details.
- __init__(temperature: float = 1.0) None [source]#
- Args:
temperature (float, optional): Temperature to use inside the softmax for the out-to-loss function. Defaults to 1.
- get_out_to_loss_grad(model, weights, buffers, batch: Iterable[Tensor]) Tensor [source]#
Computes the reweighting term \(Q\) from the paper.
- Args:
- model (torch.nn.Module):
torch model
- weights (Iterable[Tensor]):
functorch model weights
- buffers (Iterable[Tensor]):
functorch model buffers
- batch (Iterable[Tensor]):
input batch
- Returns:
- Tensor:
out-to-loss (reweighting term) for the input batch
- static get_output(model: Module, weights: Iterable[Tensor], buffers: Iterable[Tensor], images: Tensor, labels: Tensor) Tensor [source]#
For a given input \(z=(x, y)\) and model parameters \(\theta\), let \(p(z, \theta)\) be the softmax probability of the correct class. This method implements the model output function
\[\log(\frac{p(z, \theta)}{1 - p(z, \theta)}).\]
It uses functional models from torch.func (previously functorch) to make the per-sample gradient computations (much) faster. For more details on what functional models are, and how to use them, please refer to https://pytorch.org/docs/stable/func.html and https://pytorch.org/functorch/stable/notebooks/per_sample_grads.html.
- Args:
- model (torch.nn.Module):
torch model
- weights (Iterable[Tensor]):
functorch model weights (added so we don’t break abstraction)
- buffers (Iterable[Tensor]):
functorch model buffers (added so we don’t break abstraction)
- images (Tensor):
input images
- labels (Tensor):
input labels
- Returns:
- Tensor:
model output for the given image-label pair \(z\)
- class trak.modelout_functions.TextClassificationModelOutput(temperature=1.0)[source]#
Bases:
AbstractModelOutput
Margin for text classification models. This assumes that the model takes in input_ids, token_type_ids, and attention_mask.
\[\text{logit}[\text{correct}] - \log\left(\sum_{i \neq \text{correct}} \exp(\text{logit}[i])\right)\]
- get_out_to_loss_grad(model, weights, buffers, batch: Iterable[Tensor]) Tensor [source]#
See Sections 2 & 3 of our paper for more details on what the out-to-loss functions (in the notation of the paper, \(Q\)) are in the context of TRAK and how to use & design them.
- Args:
- model (torch.nn.Module):
model
- batch (Iterable[Tensor]):
input batch
- Returns:
Tensor: gradient of the out-to-loss function
- static get_output(model, weights: Iterable[Tensor], buffers: Iterable[Tensor], input_id: Tensor, token_type_id: Tensor, attention_mask: Tensor, label: Tensor) Tensor [source]#
See Sections 2 & 3 of our paper for more details on what model output functions are in the context of TRAK and how to use & design them.
- Args:
- model (torch.nn.Module):
model
- batch (Iterable[Tensor]):
input batch
- Returns:
- Tensor:
model output function
trak.projectors module#
Projectors are used to project gradients to a lower-dimensional space. This 1) allows us to compute TRAK scores in a much more efficient manner, and 2) turns out to act as a useful regularizer (see Appendix E.1 in our paper).
Here, we provide four implementations of the projector:
- NoOpProjector (no-op)
- BasicSingleBlockProjector (bare-bones, inefficient implementation)
- BasicProjector (block-wise implementation)
- CudaProjector (a fast implementation with a custom CUDA kernel)
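Why a random projection is usable here can be seen in a few lines of NumPy. The sketch below is not any of the classes listed above: it draws a dense Rademacher matrix, scales the result by 1/sqrt(proj_dim) (a common normalization, assumed here), and measures how well the norms of the projected vectors match the originals. The dimensions are kept small so that it runs quickly.

```python
import numpy as np

rng = np.random.default_rng(0)
grad_dim, proj_dim = 2_000, 512

# Rademacher matrix: i.i.d. entries equal to +1 or -1.
proj = rng.choice(np.array([-1.0, 1.0]), size=(grad_dim, proj_dim))

# Stand-in "gradients": four random vectors of dimension grad_dim.
grads = rng.standard_normal((4, grad_dim))
projected = grads @ proj / np.sqrt(proj_dim)

# Ratio of projected to original norms; values near 1 mean little distortion.
ratios = np.linalg.norm(projected, axis=1) / np.linalg.norm(grads, axis=1)
```

Increasing proj_dim tightens the ratios around 1 at the cost of larger stored features, which is the trade-off the proj_dim argument controls.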
- class trak.projectors.AbstractProjector(grad_dim: int, proj_dim: int, seed: int, proj_type: str | ProjectionType, device: str | device)[source]#
Bases:
ABC
Implementations of the Projector class must implement the AbstractProjector.project() method, which takes in model gradients and returns the projected gradients.
- abstract __init__(grad_dim: int, proj_dim: int, seed: int, proj_type: str | ProjectionType, device: str | device) None [source]#
Initializes hyperparameters for the projection.
- Args:
- grad_dim (int):
number of parameters in the model (dimension of the gradient vectors)
- proj_dim (int):
dimension after the projection
- seed (int):
random seed for the generation of the sketching (projection) matrix
- proj_type (Union[str, ProjectionType]):
the random projection (JL transform) guarantees that distances will be approximately preserved for a variety of choices of the random matrix (see e.g. https://arxiv.org/abs/1411.2404). Here, we provide an implementation for matrices with iid Gaussian entries and iid Rademacher entries.
- device (Union[str, torch.device]):
CUDA device to use
- abstract project(grads: Tensor, model_id: int) Tensor [source]#
Performs the random projection. Model ID is included so that we generate different projection matrices for every model ID.
- Args:
- grads (Tensor):
a batch of gradients to be projected
- model_id (int):
a unique ID for a checkpoint
- Returns:
Tensor: the projected gradients
- class trak.projectors.BasicProjector(grad_dim: int, proj_dim: int, seed: int, proj_type: ProjectionType, device: device, block_size: int = 100, dtype: dtype = torch.float32, model_id=0, *args, **kwargs)[source]#
Bases:
AbstractProjector
A simple block-wise implementation of the projection. The projection matrix is generated on-device in blocks. The accumulated result across blocks is returned.
Note: This class will be significantly slower and have a larger memory footprint than the CudaProjector. It is recommended that you use this method only if the CudaProjector is not available to you – e.g. if you don’t have a CUDA-enabled device with compute capability >=7.0 (see https://developer.nvidia.com/cuda-gpus).
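The block-wise idea can be sketched as follows. This is an illustration built on assumptions rather than BasicProjector's code: it assumes blocks are taken along the projection dimension and that each block of the matrix is regenerated from a seed derived from the base seed and the block index, so the full projection matrix never has to be held in memory at once. The helper name and the seeding scheme are hypothetical.

```python
import numpy as np


def project_blockwise(grads, proj_dim, seed, block_size=100):
    """Project grads to proj_dim columns, one block of columns at a time."""
    grad_dim = grads.shape[1]
    sketch = np.zeros((grads.shape[0], proj_dim), dtype=grads.dtype)
    for block, start in enumerate(range(0, proj_dim, block_size)):
        end = min(start + block_size, proj_dim)
        # Regenerate this block of the projection matrix from its own seed.
        rng = np.random.default_rng([seed, block])
        proj_block = rng.standard_normal((grad_dim, end - start))
        sketch[:, start:end] = grads @ proj_block
    return sketch


grads = np.random.default_rng(1).standard_normal((3, 500))
sketch = project_blockwise(grads, proj_dim=250, seed=0)
```

Only one block of shape (grad_dim, block_size) exists at any time, so peak memory scales with block_size instead of proj_dim, at the price of regenerating the random entries on every call.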
- __init__(grad_dim: int, proj_dim: int, seed: int, proj_type: ProjectionType, device: device, block_size: int = 100, dtype: dtype = torch.float32, model_id=0, *args, **kwargs) None [source]#
Initializes hyperparameters for the projection.
- Args:
- grad_dim (int):
number of parameters in the model (dimension of the gradient vectors)
- proj_dim (int):
dimension after the projection
- seed (int):
random seed for the generation of the sketching (projection) matrix
- proj_type (Union[str, ProjectionType]):
the random projection (JL transform) guarantees that distances will be approximately preserved for a variety of choices of the random matrix (see e.g. https://arxiv.org/abs/1411.2404). Here, we provide an implementation for matrices with iid Gaussian entries and iid Rademacher entries.
- device (Union[str, torch.device]):
CUDA device to use
- project(grads: Tensor, model_id: int) Tensor [source]#
Performs the random projection. Model ID is included so that we generate different projection matrices for every model ID.
- Args:
- grads (Tensor):
a batch of gradients to be projected
- model_id (int):
a unique ID for a checkpoint
- Returns:
Tensor: the projected gradients
- class trak.projectors.BasicSingleBlockProjector(grad_dim: int, proj_dim: int, seed: int, proj_type: ProjectionType, device, dtype=torch.float32, model_id=0, *args, **kwargs)[source]#
Bases:
AbstractProjector
A bare-bones, inefficient implementation of the projection, which simply calls torch’s matmul for the projection step.
Note: for most model sizes (e.g. even for ResNet18), and small projection dimensions (e.g. anything > 100) this method will OOM on an A100.
Unless you have a good reason to use this class (I cannot think of one, I added this only for testing purposes), use instead the CudaProjector or BasicProjector.
- __init__(grad_dim: int, proj_dim: int, seed: int, proj_type: ProjectionType, device, dtype=torch.float32, model_id=0, *args, **kwargs) None [source]#
Initializes hyperparameters for the projection.
- Args:
- grad_dim (int):
number of parameters in the model (dimension of the gradient vectors)
- proj_dim (int):
dimension after the projection
- seed (int):
random seed for the generation of the sketching (projection) matrix
- proj_type (Union[str, ProjectionType]):
the random projection (JL transform) guarantees that distances will be approximately preserved for a variety of choices of the random matrix (see e.g. https://arxiv.org/abs/1411.2404). Here, we provide an implementation for matrices with iid Gaussian entries and iid Rademacher entries.
- device (Union[str, torch.device]):
CUDA device to use
- project(grads: Tensor, model_id: int) Tensor [source]#
Performs the random projection. Model ID is included so that we generate different projection matrices for every model ID.
- Args:
- grads (Tensor):
a batch of gradients to be projected
- model_id (int):
a unique ID for a checkpoint
- Returns:
Tensor: the projected gradients
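The projection described above can be sketched in a few lines. This is a NumPy illustration rather than the library's code: the helper names are hypothetical, and combining seed and model_id into a single RNG seed is an assumption about how a different matrix per model ID could be obtained. The matrix is left unnormalized; dividing the result by sqrt(proj_dim) would preserve squared norms in expectation.

```python
import numpy as np


def make_projection_matrix(grad_dim, proj_dim, seed, model_id, proj_type="rademacher"):
    """Build a (grad_dim, proj_dim) sketching matrix.

    Assumption: seed and model_id are combined into one RNG seed so that each
    model ID gets its own matrix; the library may derive its seeds differently.
    """
    rng = np.random.default_rng([seed, model_id])
    if proj_type == "rademacher":
        # iid entries drawn uniformly from {-1, +1}
        signs = rng.integers(0, 2, size=(grad_dim, proj_dim))
        return signs.astype(np.float32) * 2 - 1
    if proj_type == "normal":
        # iid standard Gaussian entries
        return rng.standard_normal((grad_dim, proj_dim)).astype(np.float32)
    raise ValueError(f"unknown proj_type: {proj_type}")


def project(grads, seed, model_id, proj_dim, proj_type="rademacher"):
    """Project gradients of shape (batch, grad_dim) to shape (batch, proj_dim)."""
    matrix = make_projection_matrix(grads.shape[1], proj_dim, seed, model_id, proj_type)
    # A single dense matmul: simple, but the full matrix must fit in memory.
    return grads @ matrix


grads = np.random.default_rng(0).standard_normal((4, 1000)).astype(np.float32)
projected = project(grads, seed=0, model_id=0, proj_dim=64)
```

Materializing the whole (grad_dim, proj_dim) matrix is what makes this approach memory-hungry for real models, which is why the single-block projector is recommended for testing only.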
- class trak.projectors.ChunkedCudaProjector(projector_per_chunk: list, max_chunk_size: int, params_per_chunk: list, feat_bs: int, device: device, dtype: dtype)[source]#
Bases:
object
- class trak.projectors.CudaProjector(grad_dim: int, proj_dim: int, seed: int, proj_type: ProjectionType, device, max_batch_size: int, *args, **kwargs)[source]#
Bases:
AbstractProjector
A performant implementation of the projection for CUDA with compute capability >= 7.0.
- __init__(grad_dim: int, proj_dim: int, seed: int, proj_type: ProjectionType, device, max_batch_size: int, *args, **kwargs) None [source]#
- Args:
- grad_dim (int):
Number of parameters
- proj_dim (int):
Dimension we project to during the projection step
- seed (int):
Random seed
- proj_type (ProjectionType):
Type of randomness to use for projection matrix (rademacher or normal)
- device:
CUDA device
- max_batch_size (int):
Explicitly constrains the batch size the CudaProjector is going to use for projection. Set this if you get a ‘The batch size of the CudaProjector is too large for your GPU’ error. Must be either 8, 16, or 32.
- Raises:
- ValueError:
When attempting to use this on a non-CUDA device
- ModuleNotFoundError:
When fast_jl is not installed
- project(grads: dict | Tensor, model_id: int) Tensor [source]#
Performs the random projection. Model ID is included so that we generate different projection matrices for every model ID.
- Args:
- grads (Tensor):
a batch of gradients to be projected
- model_id (int):
a unique ID for a checkpoint
- Returns:
Tensor: the projected gradients
- class trak.projectors.NoOpProjector(grad_dim: int = 0, proj_dim: int = 0, seed: int = 0, proj_type: str | ProjectionType = 'na', device: str | device = 'cuda', *args, **kwargs)[source]#
Bases:
AbstractProjector
A projector that returns the gradients as they are, i.e., implements projector.project(grad) = grad.
- __init__(grad_dim: int = 0, proj_dim: int = 0, seed: int = 0, proj_type: str | ProjectionType = 'na', device: str | device = 'cuda', *args, **kwargs) None [source]#
Initializes hyperparameters for the projection.
- Args:
- grad_dim (int):
number of parameters in the model (dimension of the gradient vectors)
- proj_dim (int):
dimension after the projection
- seed (int):
random seed for the generation of the sketching (projection) matrix
- proj_type (Union[str, ProjectionType]):
the random projection (JL transform) guarantees that distances will be approximately preserved for a variety of choices of the random matrix (see e.g. https://arxiv.org/abs/1411.2404). Here, we provide an implementation for matrices with iid Gaussian entries and iid Rademacher entries.
- device (Union[str, torch.device]):
CUDA device to use
trak.savers module#
This module contains classes that save TRAK results, intermediate values, and
metadata to disk. The AbstractSaver class defines the interface for
savers. Then, we provide one implementation:
- MmapSaver: A saver that uses memory-mapped numpy arrays. This makes loading and saving small chunks of data (e.g. during featurizing) feasible without loading the entire file into memory.
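To illustrate why memory-mapped arrays suit this access pattern, here is a minimal sketch using NumPy's open_memmap. The file name, array sizes, and dtype are illustrative assumptions and do not reflect MmapSaver's actual on-disk layout.

```python
import tempfile
from pathlib import Path

import numpy as np

train_set_size, proj_dim = 1000, 16  # illustrative sizes

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "features.mmap"  # hypothetical file name

    # Allocate the full array on disk once; a new file starts out zero-filled.
    store = np.lib.format.open_memmap(
        path, mode="w+", dtype=np.float16, shape=(train_set_size, proj_dim)
    )
    # Write a single batch of rows without touching the rest of the file.
    store[100:132] = np.ones((32, proj_dim), dtype=np.float16)
    store.flush()
    del store  # release the mapping

    # Re-open in 'r+' mode and read back only the rows that are needed.
    reopened = np.lib.format.open_memmap(path, mode="r+")
    chunk = np.array(reopened[100:132])  # copy the chunk out of the mapping
    untouched = np.array(reopened[0:4])
    del reopened
```

Slicing the mapped array only reads the pages backing those rows, so a featurizing loop can write one batch at a time even when the full feature matrix is far larger than RAM.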
- class trak.savers.AbstractSaver(save_dir: Path | str, metadata: Iterable, load_from_save_dir: bool, logging_level: int, use_half_precision: bool)[source]#
Bases:
ABC
Implementations of the Saver class must implement getters and setters for TRAK features and scores, as well as intermediate values like gradients and “out-to-loss-gradient”.
The Saver class also handles the recording of metadata associated with each TRAK run, for example hyperparameters like the “JL dimension” (the dimension used for the dimensionality reduction step of TRAK, a Johnson-Lindenstrauss projection).
- abstract __init__(save_dir: Path | str, metadata: Iterable, load_from_save_dir: bool, logging_level: int, use_half_precision: bool) None [source]#
Creates the save directory if it doesn’t already exist. If the save directory already exists, it validates that the current TRAKer class has the same hyperparameters (metadata) as the one specified in the save directory. Next, this method loads any existing computed results / intermediate values in the save directory. Last, it initializes the self.current_store attributes which will be later populated with data for the “current” model ID of the TRAKer instance.
- Args:
- save_dir (Union[Path, str]):
directory to save TRAK results, intermediate values, and metadata
- metadata (Iterable):
a dictionary containing metadata related to the TRAKer class
- load_from_save_dir (bool):
If True, the Saver instance will attempt to load existing metadata from save_dir. May lead to I/O issues if multiple Saver instances run in parallel have this flag set to True. See the SLURM tutorial in our docs for more details.
- logging_level (int):
logging level for the logger associated with this Saver instance
- use_half_precision (bool):
If True, the Saver instance will save all results and intermediate values in half precision (float16).
- abstract del_grads(model_id: int, target: bool) None [source]#
Delete the intermediate values (gradients) for a given model id
- Args:
- model_id (int):
a unique ID for a checkpoint
- target (bool):
if True, delete the gradients of the target samples, otherwise delete the train set gradients.
- abstract init_experiment(model_id: int) None [source]#
Initializes store for a given experiment & model ID (checkpoint).
- Args:
- model_id (int):
a unique ID for a checkpoint
- abstract init_store(model_id: int) None [source]#
Initializes store for a given model ID (checkpoint).
- Args:
- model_id (int):
a unique ID for a checkpoint
- abstract load_current_store(model_id: int) None [source]#
Populates the self.current_store attributes with data for the given model ID (checkpoint).
- Args:
- model_id (int):
a unique ID for a checkpoint
- abstract register_model_id(model_id: int) None [source]#
Create metadata for a new model ID (checkpoint).
- Args:
- model_id (int):
a unique ID for a checkpoint
- class trak.savers.MmapSaver(save_dir, metadata, train_set_size, proj_dim, load_from_save_dir, logging_level, use_half_precision)[source]#
Bases:
AbstractSaver
A saver that uses memory-mapped numpy arrays. This makes small reads and writes (e.g. during featurizing) feasible without loading the entire file into memory.
- __init__(save_dir, metadata, train_set_size, proj_dim, load_from_save_dir, logging_level, use_half_precision) None [source]#
Creates the save directory if it doesn’t already exist. If the save directory already exists, it validates that the current TRAKer class has the same hyperparameters (metadata) as the one specified in the save directory. Next, this method loads any existing computed results / intermediate values in the save directory. Last, it initializes the self.current_store attributes which will be later populated with data for the “current” model ID of the TRAKer instance.
- Args:
- save_dir (Union[Path, str]):
directory to save TRAK results, intermediate values, and metadata
- metadata (Iterable):
a dictionary containing metadata related to the TRAKer class
- load_from_save_dir (bool):
If True, the Saver instance will attempt to load existing metadata from save_dir. May lead to I/O issues if multiple Saver instances run in parallel have this flag set to True. See the SLURM tutorial in our docs for more details.
- logging_level (int):
logging level for the logger associated with this Saver instance
- use_half_precision (bool):
If True, the Saver instance will save all results and intermediate values in half precision (float16).
- del_grads(model_id)[source]#
Delete the intermediate values (gradients) for a given model id
- Args:
- model_id (int):
a unique ID for a checkpoint
- target (bool):
if True, delete the gradients of the target samples, otherwise delete the train set gradients.
- init_experiment(exp_name, num_targets, model_id) None [source]#
Initializes store for a given experiment & model ID (checkpoint).
- Args:
- model_id (int):
a unique ID for a checkpoint
- init_store(model_id) None [source]#
Initializes store for a given model ID (checkpoint).
- Args:
- model_id (int):
a unique ID for a checkpoint
- load_current_store(model_id: int, exp_name: str | None = None, exp_num_targets: int | None = -1, mode: str | None = 'r+') None [source]#
This method uses numpy memmaps for serializing the TRAK results and intermediate values.
- Args:
- model_id (int):
a unique ID for a checkpoint
- exp_name (str, optional):
Experiment name for which to load the features. If None, loads the train (source) features for a model ID. Defaults to None.
- exp_num_targets (int, optional):
Number of targets for the experiment. Specify only when exp_name is not None. Defaults to -1.
- mode (str, optional):
Defaults to ‘r+’.
- register_model_id(model_id: int, _allow_featurizing_already_registered: bool) None [source]#
This method checks whether the model ID already exists in the save dir. If it does, it raises an error, since model IDs must be unique. If it does not, it creates a metadata file for the model ID and initializes the store mmaps.
- Args:
- model_id (int):
a unique ID for a checkpoint
- Raises:
- ModelIDException:
raised if the model ID to be registered already exists
trak.score_computers module#
Computing scores for the TRAK algorithm from pre-computed projected gradients
involves a number of matrix multiplications. This module contains classes that
perform these operations. The AbstractScoreComputer class defines the
interface for score computers. Then, we provide two implementations:
- BasicSingleBlockScoreComputer: A bare-bones implementation, mostly for testing purposes.
- BasicScoreComputer: A more sophisticated implementation that does block-wise matrix multiplications to avoid OOM errors.
- class trak.score_computers.AbstractScoreComputer(dtype, device)[source]#
Bases:
ABC
The ScoreComputer class. Implementations of the ScoreComputer class must implement three methods: get_xtx, get_x_xtx_inv, and get_scores.
- abstract get_scores(features: Tensor, target_grads: Tensor, accumulator: Tensor) None [source]#
Computes the scores for a given set of features and target gradients. In particular, this function takes in a matrix of features \(\Phi=X(X^\top X)^{-1}\), computed by the get_x_xtx_inv method, and a matrix of target (projected) gradients \(X_{target}\). Then, it computes the scores as \(\Phi X_{target}^\top\). The resulting matrix has shape (n, m), where \(n\) is the number of training examples and \(m\) is the number of target examples.
The accumulator argument is used to store the result of the computation. This is useful when computing scores for multiple model checkpoints, as it allows us to re-use the same memory for the score matrix.
- Args:
- features (Tensor):
features \(\Phi\) of shape (n, p).
- target_grads (Tensor):
target projected gradients \(X_{target}\) of shape (m, p).
- accumulator (Tensor):
accumulator of shape (n, m).
- abstract get_x_xtx_inv(grads: Tensor, xtx: Tensor) Tensor [source]#
Computes \(X(X^\top X)^{-1}\), where \(X\) is the matrix of projected gradients. Here, the shape of \(X\) is (n, p), where \(n\) is the number of training examples and \(p\) is the dimension of the projection. This function takes as input the pre-computed \(X^\top X\) matrix, which is computed by the get_xtx method.
- Args:
- grads (Tensor):
projected gradients \(X\) of shape (n, p).
- xtx (Tensor):
\(X^\top X\) of shape (p, p).
- Returns:
Tensor: \(X(X^\top X)^{-1}\) of shape (n, p).
- abstract get_xtx(grads: Tensor) Tensor [source]#
Computes \(X^\top X\), where \(X\) is the matrix of projected gradients. Here, the shape of \(X\) is (n, p), where \(n\) is the number of training examples and \(p\) is the dimension of the projection.
- Args:
- grads (Tensor):
projected gradients of shape (n, p).
- Returns:
Tensor: \(X^\top X\) of shape (p, p).
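The three methods of the interface compose into a short pipeline. The following NumPy sketch follows the shapes given in the docstrings; it is not the library's implementation. Using a linear solve instead of an explicit inverse, and adding into the accumulator rather than overwriting it, are both assumptions.

```python
import numpy as np


def get_xtx(grads):
    """X^T X for projected gradients X of shape (n, p); result has shape (p, p)."""
    return grads.T @ grads


def get_x_xtx_inv(grads, xtx):
    """X (X^T X)^{-1} of shape (n, p), computed without forming the inverse."""
    # xtx is symmetric, so solving xtx @ Z = X^T and transposing Z
    # yields X @ inv(xtx).
    return np.linalg.solve(xtx, grads.T).T


def get_scores(features, target_grads, accumulator):
    """Add Phi @ X_target^T, of shape (n, m), into the accumulator in place."""
    # Assumption: results from successive checkpoints are summed in place.
    accumulator += features @ target_grads.T


rng = np.random.default_rng(0)
n, m, p = 50, 3, 8
X = rng.standard_normal((n, p))          # projected train gradients
X_target = rng.standard_normal((m, p))   # projected target gradients

xtx = get_xtx(X)
features = get_x_xtx_inv(X, xtx)
scores = np.zeros((n, m))
get_scores(features, X_target, scores)
```

Because the accumulator is updated in place, the same (n, m) buffer can be passed again for the next checkpoint without allocating a new score matrix.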
- class trak.score_computers.BasicScoreComputer(dtype: dtype, device: device, CUDA_MAX_DIM_SIZE: int = 20000, logging_level=20, lambda_reg: float = 0.0)[source]#
Bases:
AbstractScoreComputer
An implementation of ScoreComputer that computes matmuls in a block-wise manner.
- __init__(dtype: dtype, device: device, CUDA_MAX_DIM_SIZE: int = 20000, logging_level=20, lambda_reg: float = 0.0) None [source]#
- Args:
- dtype (torch.dtype):
- device (Union[str, torch.device]):
torch device to do matmuls on
- CUDA_MAX_DIM_SIZE (int, optional):
Size of block for block-wise matmuls. Defaults to 20_000.
- logging_level (logging level, optional):
Logging level for the logger. Defaults to logging.INFO (20).
- lambda_reg (float):
regularization term for l2 reg on xtx. Defaults to 0.0.
- get_scores(features: Tensor, target_grads: Tensor, accumulator: Tensor) Tensor [source]#
Computes the scores for a given set of features and target gradients. In particular, this function takes in a matrix of features \(\Phi=X(X^\top X)^{-1}\), computed by the get_x_xtx_inv method, and a matrix of target (projected) gradients \(X_{target}\). Then, it computes the scores as \(\Phi X_{target}^\top\). The resulting matrix has shape (n, m), where \(n\) is the number of training examples and \(m\) is the number of target examples.
The accumulator argument is used to store the result of the computation. This is useful when computing scores for multiple model checkpoints, as it allows us to re-use the same memory for the score matrix.
- Args:
- features (Tensor):
features \(\Phi\) of shape (n, p).
- target_grads (Tensor):
target projected gradients \(X_{target}\) of shape (m, p).
- accumulator (Tensor):
accumulator of shape (n, m).
- get_x_xtx_inv(grads: Tensor, xtx: Tensor) Tensor [source]#
Computes \(X(X^\top X)^{-1}\), where \(X\) is the matrix of projected gradients. Here, the shape of \(X\) is (n, p), where \(n\) is the number of training examples and \(p\) is the dimension of the projection. This function takes as input the pre-computed \(X^\top X\) matrix, which is computed by the get_xtx method.
- Args:
- grads (Tensor):
projected gradients \(X\) of shape (n, p).
- xtx (Tensor):
\(X^\top X\) of shape (p, p).
- Returns:
Tensor: \(X(X^\top X)^{-1}\) of shape (n, p).
- get_xtx(grads: Tensor) Tensor [source]#
Computes \(X^\top X\), where \(X\) is the matrix of projected gradients. Here, the shape of \(X\) is (n, p), where \(n\) is the number of training examples and \(p\) is the dimension of the projection.
- Args:
- grads (Tensor):
projected gradients of shape (n, p).
- Returns:
Tensor: \(X^\top X\) of shape (p, p).
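As a rough picture of what a block-wise computation of \(X^\top X\) with regularization might look like, the NumPy sketch below accumulates the product over row blocks. It is illustrative only: blocking over rows and applying lambda_reg as a ridge term \(X^\top X + \lambda I\) are assumptions, and the function name is hypothetical.

```python
import numpy as np


def get_xtx_blockwise(grads, block_size=20_000, lambda_reg=0.0):
    """Accumulate X^T X over row blocks of X (shape (n, p)); result is (p, p)."""
    n, p = grads.shape
    xtx = np.zeros((p, p), dtype=grads.dtype)
    for start in range(0, n, block_size):
        block = grads[start:start + block_size]  # at most block_size rows
        xtx += block.T @ block
    # Assumption: lambda_reg enters as a ridge term, X^T X + lambda * I.
    xtx += lambda_reg * np.eye(p, dtype=grads.dtype)
    return xtx


rng = np.random.default_rng(0)
X = rng.standard_normal((103, 6))
xtx = get_xtx_blockwise(X, block_size=10, lambda_reg=0.5)
```

Only one block of rows participates in each matmul, so peak memory scales with block_size rather than with the number of training examples.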
- class trak.score_computers.BasicSingleBlockScoreComputer(dtype, device)[source]#
Bases:
AbstractScoreComputer
A bare-bones implementation of ScoreComputer that will likely OOM for almost all applications. Here for testing purposes only. Unless you have a good reason not to, you should use BasicScoreComputer() instead.
- get_scores(features: Tensor, target_grads: Tensor, accumulator: Tensor) None [source]#
Computes the scores for a given set of features and target gradients. In particular, this function takes in a matrix of features \(\Phi=X(X^\top X)^{-1}\), computed by the get_x_xtx_inv method, and a matrix of target (projected) gradients \(X_{target}\). Then, it computes the scores as \(\Phi X_{target}^\top\). The resulting matrix has shape (n, m), where \(n\) is the number of training examples and \(m\) is the number of target examples.
The accumulator argument is used to store the result of the computation. This is useful when computing scores for multiple model checkpoints, as it allows us to re-use the same memory for the score matrix.
- Args:
- features (Tensor):
features \(\Phi\) of shape (n, p).
- target_grads (Tensor):
target projected gradients \(X_{target}\) of shape (m, p).
- accumulator (Tensor):
accumulator of shape (n, m).
- get_x_xtx_inv(grads: Tensor, xtx: Tensor) Tensor [source]#
Computes \(X(X^\top X)^{-1}\), where \(X\) is the matrix of projected gradients. Here, the shape of \(X\) is (n, p), where \(n\) is the number of training examples and \(p\) is the dimension of the projection. This function takes as input the pre-computed \(X^\top X\) matrix, which is computed by the get_xtx method.
- Args:
- grads (Tensor):
projected gradients \(X\) of shape (n, p).
- xtx (Tensor):
\(X^\top X\) of shape (p, p).
- Returns:
Tensor: \(X(X^\top X)^{-1}\) of shape (n, p).
- get_xtx(grads: Tensor) Tensor [source]#
Computes \(X^\top X\), where \(X\) is the matrix of projected gradients. Here, the shape of \(X\) is (n, p), where \(n\) is the number of training examples and \(p\) is the dimension of the projection.
- Args:
- grads (Tensor):
projected gradients of shape (n, p).
- Returns:
Tensor: \(X^\top X\) of shape (p, p).
trak.utils module#
- trak.utils.get_matrix_mult(features: Tensor, target_grads: Tensor, target_dtype: dtype | None = None, batch_size: int = 8096, use_blockwise: bool = False) Tensor [source]#
Computes features @ target_grads.T. If the output matrix is too large to fit in memory, it will be computed in blocks.
- Args:
- features (Tensor):
The first matrix to multiply.
- target_grads (Tensor):
The second matrix to multiply.
- target_dtype (torch.dtype, optional):
The dtype of the output matrix. If None, defaults to the dtype of features. Defaults to None.
- batch_size (int, optional):
The batch size to use for blockwise matrix multiplication. Defaults to 8096.
- use_blockwise (bool, optional):
Whether or not to use blockwise matrix multiplication. Defaults to False.
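A block-wise version of features @ target_grads.T can be sketched as follows. This is a NumPy illustration with a hypothetical function name, and tiling both output dimensions with the same batch_size is an assumption rather than the library's documented behavior.

```python
import numpy as np


def matmul_blockwise(features, target_grads, batch_size=8096, target_dtype=None):
    """Compute features @ target_grads.T one (batch_size x batch_size) tile at a time."""
    if target_dtype is None:
        target_dtype = features.dtype
    n, m = features.shape[0], target_grads.shape[0]
    out = np.empty((n, m), dtype=target_dtype)
    for i in range(0, n, batch_size):
        for j in range(0, m, batch_size):
            # Each tile only needs two row blocks of the inputs.
            tile = features[i:i + batch_size] @ target_grads[j:j + batch_size].T
            out[i:i + batch_size, j:j + batch_size] = tile
    return out


rng = np.random.default_rng(0)
features = rng.standard_normal((20, 5))
target_grads = rng.standard_normal((9, 5))
result = matmul_blockwise(features, target_grads, batch_size=7)
```

The result is identical to the single matmul; the benefit is that each intermediate product is bounded by batch_size in both dimensions.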
- trak.utils.get_matrix_mult_blockwise(features: Tensor, target_grads: Tensor, target_dtype: type, bs: int)[source]#
- trak.utils.get_matrix_mult_standard(features: Tensor, target_grads: Tensor, target_dtype: type)[source]#
- trak.utils.get_parameter_chunk_sizes(model: Module, batch_size: int)[source]#
The CudaProjector supports projecting when the product of the number of parameters and the batch size is less than the max value of int32. This function computes the number of parameters that can be projected at once for a given model and batch size.
The method returns a tuple containing the maximum number of parameters that can be projected at once and a list of the actual number of parameters in each chunk (a sequence of parameter groups). Used in ChunkedCudaProjector.
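The chunking constraint can be made concrete with a small sketch. It takes a plain list of per-tensor parameter counts instead of a model, and the greedy grouping of consecutive tensors, the function name, and the limit argument are assumptions made for illustration.

```python
INT32_MAX = 2**31 - 1


def chunk_sizes(param_counts, batch_size, limit=INT32_MAX):
    """Greedily group consecutive parameter tensors into chunks.

    param_counts: number of elements in each parameter tensor, in model order.
    Returns (max_chunk_size, params_per_chunk).
    """
    # Largest chunk size c that satisfies c * batch_size < limit.
    max_chunk_size = (limit - 1) // batch_size
    params_per_chunk = []
    current = 0
    for count in param_counts:
        if count > max_chunk_size:
            raise ValueError("a single parameter tensor exceeds the chunk limit")
        if current + count > max_chunk_size:
            # Close the current chunk before it would overflow.
            params_per_chunk.append(current)
            current = 0
        current += count
    if current > 0:
        params_per_chunk.append(current)
    return max_chunk_size, params_per_chunk


# Twenty tensors of 10M parameters each, projected with a batch size of 32.
max_chunk, chunks = chunk_sizes([10_000_000] * 20, batch_size=32)
```

With a batch size of 32, at most six such tensors fit in one chunk, so the twenty tensors split into three chunks of 60M parameters and one of 20M.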
- trak.utils.parameters_to_vector(parameters) Tensor [source]#
Same as https://pytorch.org/docs/stable/generated/torch.nn.utils.parameters_to_vector.html but with reshape instead of view to avoid a pesky error.
- trak.utils.vectorize(g, arr=None, device='cuda') Tensor [source]#
Records the result into arr.
Gradients are given as a dict (name_w0: grad_w0, ... name_wp: grad_wp), where p is the number of weight matrices. Each grad_wi has shape [batch_size, ...]. This function flattens g to have shape [batch_size, num_params].
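The flattening can be illustrated with NumPy arrays standing in for tensors. This sketch mirrors the documented shapes only; it omits the device argument, and allocating arr when it is None is an assumption.

```python
import numpy as np


def vectorize(g, arr=None):
    """Flatten a dict of per-sample gradients into shape (batch_size, num_params)."""
    first = next(iter(g.values()))
    batch_size = first.shape[0]
    num_params = sum(grad[0].size for grad in g.values())
    if arr is None:
        # Assumption: allocate the output when the caller does not supply one.
        arr = np.empty((batch_size, num_params), dtype=first.dtype)
    pointer = 0
    for grad in g.values():
        width = grad[0].size  # parameters contributed by this weight
        arr[:, pointer:pointer + width] = grad.reshape(batch_size, -1)
        pointer += width
    return arr


g = {
    "w": np.arange(12.0).reshape(2, 2, 3),  # batch of 2, weight of shape (2, 3)
    "b": np.arange(4.0).reshape(2, 2),      # batch of 2, bias of shape (2,)
}
flat = vectorize(g)
```

Each row of the result concatenates the flattened gradients of every weight for one sample, in the dict's insertion order.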