API reference#

trak.traker module#

This module contains the main TRAKer class, which is the front-facing class for TRAK. See the README and docs for example usage.

In short, methods of the TRAKer class are used to compute TRAK scores for a set of model checkpoints, a set of target samples, and a set of train samples. This is done in two stages (a condensed usage sketch follows the list):

  • Featurizing. TRAKer.featurize() and TRAKer.finalize_features() are used to compute the TRAK features for a set of model checkpoints and a set of train samples.

  • Scoring. TRAKer.start_scoring_checkpoint(), TRAKer.score(), and TRAKer.finalize_scores() are used to compute the TRAK scores for a set of target samples, given the TRAK features computed in the previous step.
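A condensed usage sketch of the two stages, using only methods documented below (model, train_loader, and val_loader are placeholders assumed to be defined elsewhere):

from trak import TRAKer

traker = TRAKer(model=model,
                task="image_classification",
                train_set_size=len(train_loader.dataset))

# Stage 1: featurizing
traker.load_checkpoint(model.state_dict(), model_id=0)
for batch in train_loader:
    # using num_samples writes sequentially into the internal store
    traker.featurize(batch=batch, num_samples=batch[0].shape[0])
traker.finalize_features()

# Stage 2: scoring
traker.start_scoring_checkpoint(exp_name="quickstart",
                                checkpoint=model.state_dict(),
                                model_id=0,
                                num_targets=len(val_loader.dataset))
for batch in val_loader:
    traker.score(batch=batch, num_samples=batch[0].shape[0])
scores = traker.finalize_scores(exp_name="quickstart")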

class trak.traker.TRAKer(model: ~torch.nn.modules.module.Module, task: ~trak.modelout_functions.AbstractModelOutput | str, train_set_size: int, save_dir: str = './trak_results', load_from_save_dir: bool = True, device: str | ~torch.device = 'cuda', gradient_computer: ~trak.gradient_computers.AbstractGradientComputer = <class 'trak.gradient_computers.FunctionalGradientComputer'>, projector: ~trak.projectors.AbstractProjector | None = None, saver: ~trak.savers.AbstractSaver | None = None, score_computer: ~trak.score_computers.AbstractScoreComputer | None = None, proj_dim: int = 2048, logging_level=20, use_half_precision: bool = True, proj_max_batch_size: int = 32, projector_seed: int = 0, grad_wrt: ~typing.Iterable[str] | None = None, lambda_reg: float = 0.0)[source]#

Bases: object

The main front-facing class for TRAK. See the README and docs for example usage.

__init__(model: ~torch.nn.modules.module.Module, task: ~trak.modelout_functions.AbstractModelOutput | str, train_set_size: int, save_dir: str = './trak_results', load_from_save_dir: bool = True, device: str | ~torch.device = 'cuda', gradient_computer: ~trak.gradient_computers.AbstractGradientComputer = <class 'trak.gradient_computers.FunctionalGradientComputer'>, projector: ~trak.projectors.AbstractProjector | None = None, saver: ~trak.savers.AbstractSaver | None = None, score_computer: ~trak.score_computers.AbstractScoreComputer | None = None, proj_dim: int = 2048, logging_level=20, use_half_precision: bool = True, proj_max_batch_size: int = 32, projector_seed: int = 0, grad_wrt: ~typing.Iterable[str] | None = None, lambda_reg: float = 0.0) None[source]#
Args:
model (torch.nn.Module):

model to use for TRAK

task (Union[AbstractModelOutput, str]):

Type of model that TRAK will be run on. Accepts either one of the following strings: 1) image_classification, 2) text_classification, 3) clip; or an instance of some implementation of the abstract class AbstractModelOutput.

train_set_size (int):

Size of the train set that TRAK is featurizing

save_dir (str, optional):

Directory to save final TRAK scores, intermediate results, and metadata. Defaults to './trak_results'.

load_from_save_dir (bool, optional):

If True, the TRAKer instance will attempt to load existing metadata from save_dir. May lead to I/O issues if multiple TRAKer instances run in parallel have this flag set to True. See the SLURM tutorial for more details.

device (Union[str, torch.device], optional):

torch device on which to do computations. Defaults to ‘cuda’.

gradient_computer (AbstractGradientComputer, optional):

Class to use to get per-example gradients. See AbstractGradientComputer for more details. Defaults to FunctionalGradientComputer.

projector (Optional[AbstractProjector], optional):

Either leave as None and set proj_dim, in which case a CudaProjector with Rademacher projections will be used, or pass a custom subclass of the AbstractProjector class and leave proj_dim as None. Defaults to None.

saver (Optional[AbstractSaver], optional):

Class to use for saving intermediate results and final TRAK scores to RAM/disk. If None, the MmapSaver will be used. Defaults to None.

score_computer (Optional[AbstractScoreComputer], optional):

Class to use for computing the final TRAK scores. If None, the BasicScoreComputer will be used. Defaults to None.

proj_dim (int, optional):

Dimension of the projected TRAK features. See Section 4.3 of our paper for more details. Defaults to 2048.

logging_level (int, optional):

Logging level for TRAK loggers. Defaults to logging.INFO.

use_half_precision (bool, optional):

If True, TRAK will use half precision (float16) for all computations and arrays will be stored in float16. Otherwise, it will use float32. Defaults to True.

proj_max_batch_size (int):

Batch size used by fast_jl if the CudaProjector is used. Must be a multiple of 8. The maximum batch size is 32 for A100 GPUs, 16 for V100 GPUs, 40 for H100 GPUs. Defaults to 32.

projector_seed (int):

Random seed used by the projector. Defaults to 0.

grad_wrt (Optional[Iterable[str]], optional):

If not None, the gradients will be computed only with respect to the parameters specified in this list. The list should contain the names of the parameters to compute gradients with respect to, as they appear in the model’s state dictionary. If None, gradients are taken with respect to all model parameters. Defaults to None.

lambda_reg (float):

The \(\ell_2\) (ridge) regularization penalty added to the \(X^\top X\) term in score computers when computing the matrix inverse \((X^\top X)^{-1}\). Defaults to 0.

featurize(batch: Iterable[Tensor], inds: Iterable[int] | None = None, num_samples: int | None = None) None[source]#

Creates TRAK features for the given batch by computing the gradient of the model output function and projecting it. In the notation of the paper, for an input pair \(z=(x,y)\), model parameters \(\theta\), and JL projection matrix \(P\), this method computes \(P^\top \nabla_\theta f(z, \theta)\). Additionally, this method computes the gradient of the out-to-loss function (in the notation of the paper, the \(Q\) term in Section 3.4).

Either inds or num_samples must be specified. Using num_samples will write sequentially into the internal store of the TRAKer class.

Args:
batch (Iterable[Tensor]):

input batch

inds (Optional[Iterable[int]], optional):

Indices of the batch samples in the train set. Defaults to None.

num_samples (Optional[int], optional):

Number of samples in the batch. Defaults to None.

finalize_features(model_ids: Iterable[int] | None = None, del_grads: bool = False) None[source]#

For a set of checkpoints \(C\) (specified by model IDs), and gradients \(\{ \Phi_c \}_{c\in C}\), this method computes \(\Phi_c (\Phi_c^\top\Phi_c)^{-1}\) for all \(c\in C\) and stores the results in the internal store of the TRAKer class.

Args:
model_ids (Iterable[int], optional):

A list of model IDs for which features should be finalized. If None, features are finalized for all model IDs in the save_dir of the TRAKer class. Defaults to None.

finalize_scores(exp_name: str, model_ids: Iterable[int] | None = None, allow_skip: bool = False) Tensor[source]#

This method computes the final TRAK scores for the given targets, train samples, and model checkpoints (specified by model IDs).

Args:
exp_name (str):

Experiment name. Each experiment should have a unique name, and it corresponds to a set of targets being scored. The experiment name is used as the name for saving the target features, as well as scores produced by this method in the save_dir of the TRAKer class.

model_ids (Iterable[int], optional):

A list of model IDs for which scores should be finalized. If None, scores are computed for all model IDs in the save_dir of the TRAKer class. Defaults to None.

allow_skip (bool, optional):

If True, raises only a warning, instead of an error, when target gradients are not computed for a given model ID. Defaults to False.

Returns:

Tensor: TRAK scores

init_projector(projector: AbstractProjector | None, proj_dim: int, proj_max_batch_size: int) None[source]#

Initializes the projector for a TRAKer instance.

Args:
projector (Optional[AbstractProjector]):

JL projector to use. If None, a CudaProjector will be used (if possible).

proj_dim (int):

Dimension of the projected gradients and TRAK features.

proj_max_batch_size (int):

Batch size used by fast_jl if the CudaProjector is used. Must be a multiple of 8. The maximum batch size is 32 for A100 GPUs, 16 for V100 GPUs, 40 for H100 GPUs.

load_checkpoint(checkpoint: Iterable[Tensor], model_id: int, _allow_featurizing_already_registered=False) None[source]#

Loads state dictionary for the given checkpoint; initializes arrays to store TRAK features for that checkpoint, tied to the model ID.

Args:
checkpoint (Iterable[Tensor]):

state_dict to load

model_id (int):

a unique ID for a checkpoint

_allow_featurizing_already_registered (bool, optional):

Only use if you want to override the default behavior that featurizing is forbidden on already registered model IDs. Defaults to False.

score(batch: Iterable[Tensor], inds: Iterable[int] | None = None, num_samples: int | None = None) None[source]#

This method computes the (intermediate per-checkpoint) TRAK scores for a batch of targets and stores them in the internal store of the TRAKer class.

Either inds or num_samples must be specified. Using num_samples will write sequentially into the internal store of the TRAKer.

Args:
batch (Iterable[Tensor]):

input batch

inds (Optional[Iterable[int]], optional):

Indices of the batch samples in the train set. Defaults to None.

num_samples (Optional[int], optional):

Number of samples in the batch. Defaults to None.

start_scoring_checkpoint(exp_name: str, checkpoint: Iterable[Tensor], model_id: int, num_targets: int) None[source]#

This method prepares the internal store of the TRAKer class to start computing scores for a set of targets.

Args:
exp_name (str):

Experiment name. Each experiment should have a unique name, and it corresponds to a set of targets being scored. The experiment name is used as the name for saving the target features, as well as scores produced by this method in the save_dir of the TRAKer class.

checkpoint (Iterable[Tensor]):

model checkpoint (state dict)

model_id (int):

a unique ID for a checkpoint

num_targets (int):

number of targets to score

trak.gradient_computers module#

Computing features for the TRAK algorithm involves computing (and projecting) per-sample gradients. This module contains classes that compute these per-sample gradients. The AbstractGradientComputer class defines the interface for such gradient computers. Then, we provide two implementations:

  • FunctionalGradientComputer: A fast implementation that uses torch.func to vectorize the computation of per-sample gradients, and thus fully leverage parallelism.

  • IterativeGradientComputer: A more naive implementation that only uses native PyTorch operations (i.e., no torch.func) and computes per-sample gradients in a for-loop. This is often much slower than the functional version, but it is useful if you cannot use torch.func, e.g., if you have an old version of PyTorch that does not support it, or if your application is not supported by torch.func.

class trak.gradient_computers.AbstractGradientComputer(model: Module, task: AbstractModelOutput, grad_dim: int | None = None, dtype: dtype | None = torch.float16, device: device | None = 'cuda')[source]#

Bases: ABC

Implementations of the GradientComputer class should allow for per-sample gradients. This behavior is enabled with three methods (a minimal subclass skeleton follows the list):

  • the load_model_params() method loads model parameters; it can be as simple as a call to self.model.load_state_dict(...)

  • the compute_per_sample_grad() method computes per-sample gradients of the chosen model output function with respect to the model’s parameters.

  • the compute_loss_grad() method computes the gradients of the loss function with respect to the model output (which should be a scalar) for every sample.
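For illustration, here is a minimal subclass skeleton (a sketch only; the method bodies are left abstract, and the attribute name model_params is a made-up example, not part of the library):

import torch
from trak.gradient_computers import AbstractGradientComputer

class MyGradientComputer(AbstractGradientComputer):
    def __init__(self, model, task, grad_dim=None,
                 dtype=torch.float16, device="cuda"):
        super().__init__(model, task, grad_dim, dtype, device)
        self.load_model_params(model)

    def load_model_params(self, model) -> None:
        # can be as simple as keeping a handle on the parameters
        self.model_params = dict(model.named_parameters())

    def compute_per_sample_grad(self, batch):
        # one gradient of the model output function per example in the batch
        raise NotImplementedError

    def compute_loss_grad(self, batch, batch_size):
        # per-sample gradient of the loss w.r.t. the (scalar) model output
        raise NotImplementedError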

abstract __init__(model: Module, task: AbstractModelOutput, grad_dim: int | None = None, dtype: dtype | None = torch.float16, device: device | None = 'cuda') None[source]#

Initializes attributes, nothing too interesting happening.

Args:
model (torch.nn.Module):

model

task (AbstractModelOutput):

task (model output function)

grad_dim (int, optional):

Size of the gradients (number of model parameters). Defaults to None.

dtype (torch.dtype, optional):

Torch dtype of the gradients. Defaults to torch.float16.

device (torch.device, optional):

Torch device where gradients will be stored. Defaults to ‘cuda’.

abstract compute_loss_grad(batch: Iterable[Tensor], batch_size: int) Tensor[source]#
abstract compute_per_sample_grad(batch: Iterable[Tensor]) Tensor[source]#
abstract load_model_params(model) None[source]#
class trak.gradient_computers.FunctionalGradientComputer(model: Module, task: AbstractModelOutput, grad_dim: int, dtype: dtype, device: device, grad_wrt: Iterable[str] | None = None)[source]#

Bases: AbstractGradientComputer

__init__(model: Module, task: AbstractModelOutput, grad_dim: int, dtype: dtype, device: device, grad_wrt: Iterable[str] | None = None) None[source]#

Initializes attributes, and loads model parameters.

Args:
grad_wrt (list[str], optional):

A list of parameter names for which to keep gradients. If None, gradients are taken with respect to all model parameters. Defaults to None.

compute_loss_grad(batch: Iterable[Tensor]) Tensor[source]#

Computes the gradient of the loss with respect to the model output

\[\partial \ell / \partial \text{(model output)}\]

Note: For all applications we considered, we analytically derived the out-to-loss gradient, thus avoiding the need to do any backward passes (let alone per-sample grads). If for your application this is not feasible, you'll need to subclass this and modify this method to have a structure similar to the one of FunctionalGradientComputer.get_output(), i.e. something like:

grad_out_to_loss = grad(self.model_out_to_loss_grad, ...)
grads = vmap(grad_out_to_loss, ...)
...
Args:
batch (Iterable[Tensor]):

batch of data

Returns:
Tensor:

The gradient of the loss with respect to the model output.

compute_per_sample_grad(batch: Iterable[Tensor]) Tensor[source]#

Uses functorch’s vmap (see https://pytorch.org/functorch/stable/generated/functorch.vmap.html#functorch.vmap for more details) to vectorize the computations of per-sample gradients.

Doesn’t use batch_size; only added to follow the abstract method signature.

Args:
batch (Iterable[Tensor]):

batch of data

Returns:
dict[str, Tensor]:

A dictionary where each key is a parameter name and the value is the gradient tensor for that parameter.
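The underlying torch.func pattern looks roughly like the following self-contained sketch (which differentiates a plain cross-entropy loss for simplicity; TRAK differentiates its model output function instead):

import torch
import torch.nn.functional as F
from torch.func import functional_call, grad, vmap

model = torch.nn.Linear(10, 2)
params = {k: v.detach() for k, v in model.named_parameters()}
buffers = {k: v.detach() for k, v in model.named_buffers()}

def compute_loss(params, buffers, image, label):
    # single-example forward pass through a "functional" view of the model
    logits = functional_call(model, (params, buffers), (image.unsqueeze(0),))
    return F.cross_entropy(logits, label.unsqueeze(0))

# vmap over the batch dimension of (image, label); params/buffers are shared
per_sample_grads = vmap(grad(compute_loss), in_dims=(None, None, 0, 0))(
    params, buffers, torch.randn(4, 10), torch.tensor([0, 1, 0, 1])
)
# per_sample_grads maps each parameter name to a [batch_size, *shape] tensor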

load_model_params(model) None[source]#

Given a torch.nn.Module model, initializes/updates the (functional) weights and buffers. See https://pytorch.org/docs/stable/func.html for more details on torch.func's functional models.

Args:
model (torch.nn.Module):

model to load

class trak.gradient_computers.IterativeGradientComputer(model, task: AbstractModelOutput, grad_dim: int, dtype: dtype, device: device, grad_wrt: Iterable[str] | None = None)[source]#

Bases: AbstractGradientComputer

__init__(model, task: AbstractModelOutput, grad_dim: int, dtype: dtype, device: device, grad_wrt: Iterable[str] | None = None) None[source]#

Initializes attributes, nothing too interesting happening.

Args:
model (torch.nn.Module):

model

task (AbstractModelOutput):

task (model output function)

grad_dim (int, optional):

Size of the gradients (number of model parameters). Defaults to None.

dtype (torch.dtype, optional):

Torch dtype of the gradients. Defaults to torch.float16.

device (torch.device, optional):

Torch device where gradients will be stored. Defaults to ‘cuda’.

compute_loss_grad(batch: Iterable[Tensor]) Tensor[source]#

Computes the gradient of the loss with respect to the model output

\[\partial \ell / \partial \text{(model output)}\]

Note: For all applications we considered, we analytically derived the out-to-loss gradient, thus avoiding the need to do any backward passes (let alone per-sample grads). If for your application this is not feasible, you'll need to subclass this and modify this method to have a structure similar to the one of IterativeGradientComputer.get_output(), i.e. something like:

out_to_loss = self.model_out_to_loss(...)
for ind in range(batch_size):
    grads[ind] = torch.autograd.grad(out_to_loss[ind], ...)

Args:
batch (Iterable[Tensor]):

batch of data

Returns:
Tensor:

The gradient of the loss with respect to the model output.

compute_per_sample_grad(batch: Iterable[Tensor]) Tensor[source]#

Computes per-sample gradients of the model output function. This method does not leverage vectorization (and is hence much slower than its equivalent in FunctionalGradientComputer). We recommend that you use this only if torch.func is not available to you, e.g. if you have a (very) old version of PyTorch.

Args:

batch (Iterable[Tensor]):

batch of data

Returns:
Tensor:

gradients of the model output function of each sample in the batch with respect to the model’s parameters.
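Conceptually, the iterative computation looks like the following sketch (the margin-style scalar output below is an illustrative choice, not the library's code):

import torch

model = torch.nn.Linear(10, 2)
images, labels = torch.randn(4, 10), torch.tensor([0, 1, 0, 1])

num_params = sum(p.numel() for p in model.parameters())
grads = torch.empty(len(images), num_params)
for ind in range(len(images)):
    out = model(images[ind].unsqueeze(0))
    model_output = out[0, labels[ind]]  # scalar output for this example
    g = torch.autograd.grad(model_output, list(model.parameters()))
    grads[ind] = torch.cat([x.reshape(-1) for x in g])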

load_model_params(model) Tensor[source]#

trak.modelout_functions module#

Here we provide an abstract "model output" class AbstractModelOutput, together with a number of subclasses for particular applications (vision, language, etc).

These classes implement methods that transform input batches to the desired model output (e.g. logits, loss, etc). See Sections 2 & 3 of our paper for more details on what model output functions are in the context of TRAK and how to use & design them.

See, e.g. this tutorial for an example on how to subclass AbstractModelOutput for a task of your choice.

class trak.modelout_functions.AbstractModelOutput[source]#

Bases: ABC

See, e.g. this tutorial for an example on how to subclass AbstractModelOutput for a task of your choice.

Subclasses must implement the following (a minimal sketch follows the list):

  • a get_output method that takes in a batch of inputs and model weights to produce outputs that TRAK will be trained to predict. In the notation of the paper, get_output should return \(f(z,\theta)\)

  • a get_out_to_loss_grad method that takes in a batch of inputs and model weights to produce the gradient of the function that transforms the model outputs above into the loss with respect to the batch. In the notation of the paper, get_out_to_loss_grad returns (entries along the diagonal of) \(Q\).
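A minimal sketch of a custom subclass (illustrative only: the squared-error output and the constant out-to-loss gradient below are assumptions for a toy regression task, not library code):

import torch
from torch.func import functional_call
from trak.modelout_functions import AbstractModelOutput

class ToyRegressionOutput(AbstractModelOutput):
    def __init__(self) -> None:
        super().__init__()

    @staticmethod
    def get_output(model, weights, buffers, x, y):
        # f(z, theta): negative squared error of a single example
        pred = functional_call(model, (weights, buffers), (x.unsqueeze(0),))
        return -((pred.squeeze() - y) ** 2)

    def get_out_to_loss_grad(self, model, weights, buffers, batch):
        x, _ = batch
        # placeholder Q term: identity reweighting for every sample
        return torch.ones(x.shape[0], 1, device=x.device)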

abstract __init__() None[source]#
abstract get_out_to_loss_grad(model, batch: Iterable[Tensor]) Tensor[source]#

See Sections 2 & 3 of our paper for more details on what the out-to-loss functions (in the notation of the paper, \(Q\)) are in the context of TRAK and how to use & design them.

Args:

model (torch.nn.Module): model

batch (Iterable[Tensor]): input batch

Returns:

Tensor: gradient of the out-to-loss function

abstract get_output(model, batch: Iterable[Tensor]) Tensor[source]#

See Sections 2 & 3 of our paper for more details on what model output functions are in the context of TRAK and how to use & design them.

Args:
model (torch.nn.Module):

model

batch (Iterable[Tensor]):

input batch

Returns:
Tensor:

model output function

class trak.modelout_functions.CLIPModelOutput(temperature: float | None = None, simulated_batch_size: int = 300)[source]#

Bases: AbstractModelOutput

Margin for multimodal contrastive learning (CLIP). See Section 5.1 of our paper for more details. Compatible with the open_clip implementation of CLIP.

Raises:

AssertionError: this model output function requires additional CLIP embeddings, which are computed using the get_embeddings() method; that method must be invoked before featurizing.

__init__(temperature: float | None = None, simulated_batch_size: int = 300) None[source]#
Args:
temperature (float, optional):

Temperature to use inside the softmax for the out-to-loss function. If None, CLIP's logit_scale is used. Defaults to None.

simulated_batch_size (int, optional):

Size of the “simulated” batch size for the model output function. See Section 5.1 of the TRAK paper for more details. Defaults to 300.

static get_embeddings(model, loader, batch_size: int, embedding_dim: int, size: int = 50000, preprocess_fn_img=None, preprocess_fn_txt=None) None[source]#

Computes (image and text) embeddings and saves them in the class attributes image_embeddings and text_embeddings.

Args:
model (torch.nn.Module):

model

loader:

data loader

batch_size (int):

input batch size

embedding_dim (int):

Dimension of the CLIP embeddings.

size (int, optional):

Maximum number of embeddings to compute. Defaults to 50_000.

preprocess_fn_img (func, optional):

Transforms to apply to images from the loader before forward pass. Defaults to None.

preprocess_fn_txt (func, optional):

Transforms to apply to text captions from the loader before the forward pass. Defaults to None.

get_out_to_loss_grad(model, weights, buffers, batch: Iterable[Tensor]) Tensor[source]#

Computes the out-to-loss gradient (the reweighting term \(Q\) in the paper).

Args:
model (torch.nn.Module):

torch model

weights (Iterable[Tensor]):

functorch model weights

buffers (Iterable[Tensor]):

functorch model buffers

batch (Iterable[Tensor]):

input batch

Returns:
Tensor:

out-to-loss (reweighting term) for the input batch

static get_output(model: Module, weights: Iterable[Tensor], buffers: Iterable[Tensor], image: Tensor, label: Tensor) Tensor[source]#

For a given input \(z=(x, y)\) and model parameters \(\theta\), let \(\phi(x, \theta)\) be the CLIP image embedding and \(\psi(y, \theta)\) be the CLIP text embedding. Last, let \(B\) be a (simulated) batch. This method implements the model output function

\[-\log(\frac{\phi(x)\cdot \psi(y)}{\sum_{(x', y')\in B} \phi(x)\cdot \psi(y')}) -\log(\frac{\phi(x)\cdot \psi(y)}{\sum_{(x', y')\in B} \phi(x')\cdot \psi(y)})\]

It uses functional models from torch.func (previously functorch) to make the per-sample gradient computations (much) faster. For more details on what functional models are, and how to use them, please refer to https://pytorch.org/docs/stable/func.html and https://pytorch.org/functorch/stable/notebooks/per_sample_grads.html.

Args:
model (torch.nn.Module):

torch model

weights (Iterable[Tensor]):

functorch model weights

buffers (Iterable[Tensor]):

functorch model buffers

image (Tensor):

input image, should not have batch dimension

label (Tensor):

input label, should not have batch dimension

Returns:
Tensor:

model output for the given image-label pair \(z\) and weights & buffers \(\theta\).

image_embeddings = None#
num_computed_embeddings = 0#
sim_batch_size = 0#
text_embeddings = None#
class trak.modelout_functions.ImageClassificationModelOutput(temperature: float = 1.0)[source]#

Bases: AbstractModelOutput

Margin for (multiclass) image classification. See Section 3.3 of our paper for more details.

__init__(temperature: float = 1.0) None[source]#
Args:

temperature (float, optional): Temperature to use inside the softmax for the out-to-loss function. Defaults to 1.

get_out_to_loss_grad(model, weights, buffers, batch: Iterable[Tensor]) Tensor[source]#

Computes the out-to-loss gradient (the reweighting term \(Q\) in the paper).

Args:
model (torch.nn.Module):

torch model

weights (Iterable[Tensor]):

functorch model weights

buffers (Iterable[Tensor]):

functorch model buffers

batch (Iterable[Tensor]):

input batch

Returns:
Tensor:

out-to-loss (reweighting term) for the input batch

static get_output(model: Module, weights: Iterable[Tensor], buffers: Iterable[Tensor], image: Tensor, label: Tensor) Tensor[source]#

For a given input \(z=(x, y)\) and model parameters \(\theta\), let \(p(z, \theta)\) be the softmax probability of the correct class. This method implements the model output function

\[\log(\frac{p(z, \theta)}{1 - p(z, \theta)}).\]

It uses functional models from torch.func (previously functorch) to make the per-sample gradient computations (much) faster. For more details on what functional models are, and how to use them, please refer to https://pytorch.org/docs/stable/func.html and https://pytorch.org/functorch/stable/notebooks/per_sample_grads.html.

Args:
model (torch.nn.Module):

torch model

weights (Iterable[Tensor]):

functorch model weights

buffers (Iterable[Tensor]):

functorch model buffers

image (Tensor):

input image, should not have batch dimension

label (Tensor):

input label, should not have batch dimension

Returns:
Tensor:

model output for the given image-label pair \(z\) and weights & buffers \(\theta\).
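For intuition, the quantity \(\log(p/(1-p))\) above equals the correct-class logit minus the logsumexp of the remaining logits, and can be computed stably as in this sketch (illustrative, not the library's exact implementation):

import torch

logits = torch.randn(8, 10)                      # [batch, num_classes]
labels = torch.randint(0, 10, (8,))

correct = logits.gather(1, labels.unsqueeze(1)).squeeze(1)
others = logits.clone()
others.scatter_(1, labels.unsqueeze(1), float("-inf"))  # mask correct class
margin = correct - torch.logsumexp(others, dim=1)       # log(p / (1 - p))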

class trak.modelout_functions.IterativeImageClassificationModelOutput(temperature: float = 1.0)[source]#

Bases: AbstractModelOutput

Margin for (multiclass) image classification. See Section 3.3 of our paper for more details.

__init__(temperature: float = 1.0) None[source]#
Args:

temperature (float, optional): Temperature to use inside the softmax for the out-to-loss function. Defaults to 1.

get_out_to_loss_grad(model, weights, buffers, batch: Iterable[Tensor]) Tensor[source]#

Computes the out-to-loss gradient (the reweighting term \(Q\) in the paper).

Args:
model (torch.nn.Module):

torch model

weights (Iterable[Tensor]):

functorch model weights

buffers (Iterable[Tensor]):

functorch model buffers

batch (Iterable[Tensor]):

input batch

Returns:
Tensor:

out-to-loss (reweighting term) for the input batch

static get_output(model: Module, weights: Iterable[Tensor], buffers: Iterable[Tensor], images: Tensor, labels: Tensor) Tensor[source]#

For a given input \(z=(x, y)\) and model parameters \(\theta\), let \(p(z, \theta)\) be the softmax probability of the correct class. This method implements the model output function

\[\log(\frac{p(z, \theta)}{1 - p(z, \theta)}).\]

It uses functional models from torch.func (previously functorch) to make the per-sample gradient computations (much) faster. For more details on what functional models are, and how to use them, please refer to https://pytorch.org/docs/stable/func.html and https://pytorch.org/functorch/stable/notebooks/per_sample_grads.html.

Args:
model (torch.nn.Module):

torch model

weights (Iterable[Tensor]):

functorch model weights (added so we don't break the abstraction)

buffers (Iterable[Tensor]):

functorch model buffers (added so we don't break the abstraction)

images (Tensor):

input images

labels (Tensor):

input labels

Returns:
Tensor:

model output for the given image-label pairs \(z\)

class trak.modelout_functions.TextClassificationModelOutput(temperature=1.0)[source]#

Bases: AbstractModelOutput

Margin for text classification models. This assumes that the model takes in input_ids, token_type_ids, and attention_mask. The margin is defined as:

\[\text{logit}[\text{correct}] - \log\left(\sum_{i \neq \text{correct}} \exp(\text{logit}[i])\right)\]
__init__(temperature=1.0) None[source]#
get_out_to_loss_grad(model, weights, buffers, batch: Iterable[Tensor]) Tensor[source]#

See Sections 2 & 3 of our paper for more details on what the out-to-loss functions (in the notation of the paper, \(Q\)) are in the context of TRAK and how to use & design them.

Args:

model (torch.nn.Module): model

batch (Iterable[Tensor]): input batch

Returns:

Tensor: gradient of the out-to-loss function

static get_output(model, weights: Iterable[Tensor], buffers: Iterable[Tensor], input_id: Tensor, token_type_id: Tensor, attention_mask: Tensor, label: Tensor) Tensor[source]#

See Sections 2 & 3 of our paper for more details on what model output functions are in the context of TRAK and how to use & design them.

Args:
model (torch.nn.Module):

model

batch (Iterable[Tensor]):

input batch

Returns:
Tensor:

model output function

trak.projectors module#

Projectors are used to project gradients to a lower-dimensional space. This 1) allows us to compute TRAK scores in a much more efficient manner, and 2) turns out to act as a useful regularizer (see Appendix E.1 in our paper).

Here, we provide four implementations of the projector:

  • NoOpProjector (no-op)

  • BasicSingleBlockProjector (bare-bones, inefficient implementation)

  • BasicProjector (block-wise implementation)

  • CudaProjector (a fast implementation with a custom CUDA kernel)

class trak.projectors.AbstractProjector(grad_dim: int, proj_dim: int, seed: int, proj_type: str | ProjectionType, device: str | device)[source]#

Bases: ABC

Implementations of the Projector class must implement the AbstractProjector.project() method, which takes in model gradients and returns the projected (lower-dimensional) gradients.

abstract __init__(grad_dim: int, proj_dim: int, seed: int, proj_type: str | ProjectionType, device: str | device) None[source]#

Initializes hyperparameters for the projection.

Args:
grad_dim (int):

number of parameters in the model (dimension of the gradient vectors)

proj_dim (int):

dimension after the projection

seed (int):

random seed for the generation of the sketching (projection) matrix

proj_type (Union[str, ProjectionType]):

The random projection (JL transform) guarantees that distances will be approximately preserved for a variety of choices of the random matrix (see e.g. https://arxiv.org/abs/1411.2404). Here, we provide an implementation for matrices with iid Gaussian entries and iid Rademacher entries.

device (Union[str, torch.device]):

CUDA device to use
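As a concrete (if naive) illustration, a dense Rademacher projection can be sketched as below; the real implementations in this module generate the matrix block-wise or on the fly, seeded per model ID:

import torch

grad_dim, proj_dim, seed = 10_000, 512, 0
g = torch.Generator().manual_seed(seed)
# +1/-1 entries with equal probability, scaled to roughly preserve norms
P = torch.randint(0, 2, (grad_dim, proj_dim), generator=g).float()
P = (P * 2 - 1) / proj_dim ** 0.5

grads = torch.randn(32, grad_dim)    # a batch of flattened gradients
projected = grads @ P                # shape [32, proj_dim]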

free_memory()[source]#

Frees up memory used by the projector.

abstract project(grads: Tensor, model_id: int) Tensor[source]#

Performs the random projection. Model ID is included so that we generate different projection matrices for every model ID.

Args:

grads (Tensor): a batch of gradients to be projected

model_id (int): a unique ID for a checkpoint

Returns:

Tensor: the projected gradients

class trak.projectors.BasicProjector(grad_dim: int, proj_dim: int, seed: int, proj_type: ProjectionType, device: device, block_size: int = 100, dtype: dtype = torch.float32, model_id=0, *args, **kwargs)[source]#

Bases: AbstractProjector

A simple block-wise implementation of the projection. The projection matrix is generated on-device in blocks. The accumulated result across blocks is returned.

Note: This class will be significantly slower and have a larger memory footprint than the CudaProjector. It is recommended that you use this class only if the CudaProjector is not available to you, e.g. if you don't have a CUDA-enabled device with compute capability >= 7.0 (see https://developer.nvidia.com/cuda-gpus).

__init__(grad_dim: int, proj_dim: int, seed: int, proj_type: ProjectionType, device: device, block_size: int = 100, dtype: dtype = torch.float32, model_id=0, *args, **kwargs) None[source]#

Initializes hyperparameters for the projection.

Args:
grad_dim (int):

number of parameters in the model (dimension of the gradient vectors)

proj_dim (int):

dimension after the projection

seed (int):

random seed for the generation of the sketching (projection) matrix

proj_type (Union[str, ProjectionType]):

The random projection (JL transform) guarantees that distances will be approximately preserved for a variety of choices of the random matrix (see e.g. https://arxiv.org/abs/1411.2404). Here, we provide an implementation for matrices with iid Gaussian entries and iid Rademacher entries.

device (Union[str, torch.device]):

CUDA device to use

free_memory()[source]#

Frees up memory used by the projector.

generate_sketch_matrix(generator_state)[source]#
get_generator_states()[source]#
project(grads: Tensor, model_id: int) Tensor[source]#

Performs the random projection. Model ID is included so that we generate different projection matrices for every model ID.

Args:

grads (Tensor): a batch of gradients to be projected

model_id (int): a unique ID for a checkpoint

Returns:

Tensor: the projected gradients

class trak.projectors.BasicSingleBlockProjector(grad_dim: int, proj_dim: int, seed: int, proj_type: ProjectionType, device, dtype=torch.float32, model_id=0, *args, **kwargs)[source]#

Bases: AbstractProjector

A bare-bones, inefficient implementation of the projection, which simply calls torch’s matmul for the projection step.

Note: for most model sizes (e.g. even ResNet18) and projection dimensions (e.g. anything > 100), this method will OOM on an A100.

Unless you have a good reason to use this class (I cannot think of one; I added it only for testing purposes), use the CudaProjector or BasicProjector instead.

__init__(grad_dim: int, proj_dim: int, seed: int, proj_type: ProjectionType, device, dtype=torch.float32, model_id=0, *args, **kwargs) None[source]#

Initializes hyperparameters for the projection.

Args:
grad_dim (int):

number of parameters in the model (dimension of the gradient vectors)

proj_dim (int):

dimension after the projection

seed (int):

random seed for the generation of the sketching (projection) matrix

proj_type (Union[str, ProjectionType]):

The random projection (JL transform) guarantees that distances will be approximately preserved for a variety of choices of the random matrix (see e.g. https://arxiv.org/abs/1411.2404). Here, we provide an implementation for matrices with iid Gaussian entries and iid Rademacher entries.

device (Union[str, torch.device]):

CUDA device to use

free_memory()[source]#

Frees up memory used by the projector.

generate_sketch_matrix()[source]#
project(grads: Tensor, model_id: int) Tensor[source]#

Performs the random projection. Model ID is included so that we generate different projection matrices for every model ID.

Args:

grads (Tensor): a batch of gradients to be projected

model_id (int): a unique ID for a checkpoint

Returns:

Tensor: the projected gradients

class trak.projectors.ChunkedCudaProjector(projector_per_chunk: list, max_chunk_size: int, params_per_chunk: list, feat_bs: int, device: device, dtype: dtype)[source]#

Bases: object

__init__(projector_per_chunk: list, max_chunk_size: int, params_per_chunk: list, feat_bs: int, device: device, dtype: dtype)[source]#
allocate_input()[source]#
free_memory()[source]#
project(grads, model_id)[source]#
class trak.projectors.CudaProjector(grad_dim: int, proj_dim: int, seed: int, proj_type: ProjectionType, device, max_batch_size: int, *args, **kwargs)[source]#

Bases: AbstractProjector

A performant implementation of the projection for CUDA with compute capability >= 7.0.

__init__(grad_dim: int, proj_dim: int, seed: int, proj_type: ProjectionType, device, max_batch_size: int, *args, **kwargs) None[source]#
Args:
grad_dim (int):

Number of parameters

proj_dim (int):

Dimension we project to during the projection step

seed (int):

Random seed

proj_type (ProjectionType):

Type of randomness to use for projection matrix (rademacher or normal)

device:

CUDA device

max_batch_size (int):

Explicitly constrains the batch size the CudaProjector is going to use for projection. Set this if you get a 'The batch size of the CudaProjector is too large for your GPU' error. Must be either 8, 16, or 32.

Raises:
ValueError:

When attempting to use this on a non-CUDA device

ModuleNotFoundError:

When fast_jl is not installed
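A hedged usage sketch using only the documented signature (requires a CUDA device and fast_jl installed; the gradient dimension below is a placeholder):

import torch
from trak.projectors import CudaProjector, ProjectionType

projector = CudaProjector(
    grad_dim=1_000_000,             # placeholder: number of model parameters
    proj_dim=2048,
    seed=0,
    proj_type=ProjectionType.rademacher,
    device="cuda",
    max_batch_size=32,              # must be 8, 16, or 32
)
grads = torch.randn(32, 1_000_000, device="cuda", dtype=torch.float16)
projected = projector.project(grads, model_id=0)   # shape [32, 2048]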

free_memory()[source]#

A no-op method.

project(grads: dict | Tensor, model_id: int) Tensor[source]#

Performs the random projection. Model ID is included so that we generate different projection matrices for every model ID.

Args:

grads (Tensor): a batch of gradients to be projected

model_id (int): a unique ID for a checkpoint

Returns:

Tensor: the projected gradients

class trak.projectors.NoOpProjector(grad_dim: int = 0, proj_dim: int = 0, seed: int = 0, proj_type: str | ProjectionType = 'na', device: str | device = 'cuda', *args, **kwargs)[source]#

Bases: AbstractProjector

A projector that returns the gradients as they are, i.e., implements projector.project(grad) = grad.

__init__(grad_dim: int = 0, proj_dim: int = 0, seed: int = 0, proj_type: str | ProjectionType = 'na', device: str | device = 'cuda', *args, **kwargs) None[source]#

Initializes hyperparameters for the projection.

Args:
grad_dim (int):

number of parameters in the model (dimension of the gradient vectors)

proj_dim (int):

dimension after the projection

seed (int):

random seed for the generation of the sketching (projection) matrix

proj_type (Union[str, ProjectionType]):

The random projection (JL transform) guarantees that distances will be approximately preserved for a variety of choices of the random matrix (see e.g. https://arxiv.org/abs/1411.2404). Here, we provide an implementation for matrices with iid Gaussian entries and iid Rademacher entries.

device (Union[str, torch.device]):

CUDA device to use

free_memory()[source]#

A no-op method.

project(grads: Tensor, model_id: int) Tensor[source]#

A no-op method.

Args:

grads (Tensor): a batch of gradients to be projected

model_id (int): a unique ID for a checkpoint

Returns:

Tensor: the (non-)projected gradients

class trak.projectors.ProjectionType(value)[source]#

Bases: str, Enum

Enumeration of the available random projection types.

normal: str = 'normal'#
rademacher: str = 'rademacher'#

trak.savers module#

This module contains classes that save TRAK results, intermediate values, and metadata to disk. The AbstractSaver class defines the interface for savers. Then, we provide one implementation:

  • MmapSaver: A saver that uses memory-mapped numpy arrays. This makes loading and saving small chunks of data (e.g., during featurizing) feasible without loading the entire file into memory.

class trak.savers.AbstractSaver(save_dir: Path | str, metadata: Iterable, load_from_save_dir: bool, logging_level: int, use_half_precision: bool)[source]#

Bases: ABC

Implementations of the Saver class must implement getters and setters for TRAK features and scores, as well as for intermediate values like gradients and the "out-to-loss" gradient.

The Saver class also handles the recording of metadata associated with each TRAK run, for example hyperparameters like the "JL dimension", i.e. the dimension used for the dimensionality-reduction (Johnson-Lindenstrauss projection) step of TRAK.

abstract __init__(save_dir: Path | str, metadata: Iterable, load_from_save_dir: bool, logging_level: int, use_half_precision: bool) None[source]#

Creates the save directory if it doesn't already exist. If the save directory already exists, it validates that the current TRAKer class has the same hyperparameters (metadata) as the one specified in the save directory. Next, this method loads any existing computed results / intermediate values in the save directory. Last, it initializes the self.current_store attributes which will be later populated with data for the "current" model ID of the TRAKer instance.

Args:
save_dir (Union[Path, str]):

directory to save TRAK results, intermediate values, and metadata

metadata (Iterable):

a dictionary containing metadata related to the TRAKer class

load_from_save_dir (bool):

If True, the Saver instance will attempt to load existing metadata from save_dir. May lead to I/O issues if multiple Saver instances run in parallel have this flag set to True. See the SLURM tutorial in our docs for more details.

logging_level (int):

logging level for the logger associated with this Saver instance

use_half_precision (bool):

If True, the Saver instance will save all results and intermediate values in half precision (float16).

abstract del_grads(model_id: int, target: bool) None[source]#

Delete the intermediate values (gradients) for a given model ID.

Args:
model_id (int):

a unique ID for a checkpoint

target (bool):

if True, delete the gradients of the target samples, otherwise delete the train set gradients.

abstract init_experiment(model_id: int) None[source]#

Initializes store for a given experiment & model ID (checkpoint).

Args:
model_id (int):

a unique ID for a checkpoint

abstract init_store(model_id: int) None[source]#

Initializes store for a given model ID (checkpoint).

Args:
model_id (int):

a unique ID for a checkpoint

abstract load_current_store(model_id: int) None[source]#

Populates the self.current_store attributes with data for the given model ID (checkpoint).

Args:
model_id (int):

a unique ID for a checkpoint

abstract register_model_id(model_id: int) None[source]#

Create metadata for a new model ID (checkpoint).

Args:
model_id (int):

a unique ID for a checkpoint

abstract save_scores(exp_name: str) None[source]#

Saves scores for a given experiment name

Args:
exp_name (str):

experiment name

abstract serialize_current_model_id_metadata() None[source]#

Write to disk / commit any updates to the metadata associated with the current model ID

class trak.savers.MmapSaver(save_dir, metadata, train_set_size, proj_dim, load_from_save_dir, logging_level, use_half_precision)[source]#

Bases: AbstractSaver

A saver that uses memory-mapped numpy arrays. This makes small reads and writes (e.g., during featurizing) feasible without loading the entire file into memory.

__init__(save_dir, metadata, train_set_size, proj_dim, load_from_save_dir, logging_level, use_half_precision) None[source]#

Creates the save directory if it doesn't already exist. If the save directory already exists, it validates that the current TRAKer class has the same hyperparameters (metadata) as the one specified in the save directory. Next, this method loads any existing computed results / intermediate values in the save directory. Last, it initializes the self.current_store attributes which will be later populated with data for the "current" model ID of the TRAKer instance.

Args:
save_dir (Union[Path, str]):

directory to save TRAK results, intermediate values, and metadata

metadata (Iterable):

a dictionary containing metadata related to the TRAKer class

train_set_size (int):

size of the train set that TRAK is featurizing

proj_dim (int):

dimension of the projected gradients and TRAK features

load_from_save_dir (bool):

If True, the Saver instance will attempt to load existing metadata from save_dir. May lead to I/O issues if multiple Saver instances run in parallel have this flag set to True. See the SLURM tutorial in our docs for more details.

logging_level (int):

logging level for the logger associated with this Saver instance

use_half_precision (bool):

If True, the Saver instance will save all results and intermediate values in half precision (float16).

del_grads(model_id)[source]#

Delete the intermediate values (gradients) for a given model ID.

Args:
model_id (int):

a unique ID for a checkpoint

init_experiment(exp_name, num_targets, model_id) None[source]#

Initializes store for a given experiment & model ID (checkpoint).

Args:
exp_name (str):

experiment name

num_targets (int):

number of targets to score

model_id (int):

a unique ID for a checkpoint

init_store(model_id) None[source]#

Initializes store for a given model ID (checkpoint).

Args:
model_id (int):

a unique ID for a checkpoint

load_current_store(model_id: int, exp_name: str | None = None, exp_num_targets: int | None = -1, mode: str | None = 'r+') None[source]#

This method uses numpy memmaps for serializing the TRAK results and intermediate values.

Args:
model_id (int):

a unique ID for a checkpoint

exp_name (str, optional):

Experiment name for which to load the features. If None, loads the train (source) features for a model ID. Defaults to None.

exp_num_targets (int, optional):

Number of targets for the experiment. Specify only when exp_name is not None. Defaults to -1.

mode (str, optional):

File mode for the underlying memory-mapped arrays. Defaults to 'r+'.

register_model_id(model_id: int, _allow_featurizing_already_registered: bool) None[source]#

This method 1) checks if the model ID already exists in the save dir; 2) if yes, raises an error, since model IDs must be unique; 3) if not, creates a metadata file for it and initializes store mmaps.

Args:
model_id (int):

a unique ID for a checkpoint

Raises:
ModelIDException:

raised if the model ID to be registered already exists

save_scores(exp_name)[source]#

Saves scores for a given experiment name

Args:
exp_name (str):

experiment name

serialize_current_model_id_metadata(already_exists=True) None[source]#

Write to disk / commit any updates to the metadata associated with the current model ID

exception trak.savers.ModelIDException[source]#

Bases: Exception

A minimal custom exception for errors related to model IDs

trak.score_computers module#

Computing scores for the TRAK algorithm from pre-computed projected gradients involves a number of matrix multiplications. This module contains classes that perform these operations. The AbstractScoreComputer class defines the interface for score computers. Then, we provide two implementations:

  • BasicSingleBlockScoreComputer: A bare-bones implementation, mostly for testing purposes.

  • BasicScoreComputer: A more sophisticated implementation that does block-wise matrix multiplications to avoid OOM errors.

class trak.score_computers.AbstractScoreComputer(dtype, device)[source]#

Bases: ABC

Implementations of the ScoreComputer class must implement three methods:

  • get_xtx

  • get_x_xtx_inv

  • get_scores

abstract __init__(dtype, device) None[source]#
abstract get_scores(features: Tensor, target_grads: Tensor, accumulator: Tensor) None[source]#

Computes the scores for a given set of features and target gradients. In particular, this function takes in a matrix of features \(\Phi=X(X^\top X)^{-1}\), computed by the get_x_xtx_inv method, and a matrix of target (projected) gradients \(X_{target}\). Then, it computes the scores as \(\Phi X_{target}^\top\). The resulting matrix has shape (n, m), where \(n\) is the number of training examples and \(m\) is the number of target examples.

The accumulator argument is used to store the result of the computation. This is useful when computing scores for multiple model checkpoints, as it allows us to re-use the same memory for the score matrix.

Args:

features (Tensor): features \(\Phi\) of shape (n, p).

target_grads (Tensor): target projected gradients \(X_{target}\) of shape (m, p).

accumulator (Tensor): accumulator of shape (n, m).

abstract get_x_xtx_inv(grads: Tensor, xtx: Tensor) Tensor[source]#

Computes \(X(X^\top X)^{-1}\), where \(X\) is the matrix of projected gradients. Here, the shape of \(X\) is (n, p), where \(n\) is the number of training examples and \(p\) is the dimension of the projection. This function takes as input the pre-computed \(X^\top X\) matrix, which is computed by the get_xtx method.

Args:

grads (Tensor): projected gradients \(X\) of shape (n, p).

xtx (Tensor): \(X^\top X\) of shape (p, p).

Returns:

Tensor: \(X(X^\top X)^{-1}\) of shape (n, p).

abstract get_xtx(grads: Tensor) Tensor[source]#

Computes \(X^\top X\), where \(X\) is the matrix of projected gradients. Here, the shape of \(X\) is (n, p), where \(n\) is the number of training examples and \(p\) is the dimension of the projection.

Args:

grads (Tensor): projected gradients of shape (n, p).

Returns:

Tensor: \(X^\top X\) of shape (p, p).
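Put together, the three operations amount to the following single-block sketch (illustrative; BasicScoreComputer performs the same computation block-wise, with lambda_reg optionally added as a ridge penalty to \(X^\top X\) before inverting):

import torch

X = torch.randn(1000, 512)          # projected train gradients, shape (n, p)
X_target = torch.randn(100, 512)    # projected target gradients, shape (m, p)

xtx = X.T @ X                                   # get_xtx: (p, p)
features = X @ torch.linalg.inv(xtx)            # get_x_xtx_inv: (n, p)
scores = features @ X_target.T                  # get_scores: (n, m)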

class trak.score_computers.BasicScoreComputer(dtype: dtype, device: device, CUDA_MAX_DIM_SIZE: int = 20000, logging_level=20, lambda_reg: float = 0.0)[source]#

Bases: AbstractScoreComputer

An implementation of ScoreComputer that computes matmuls in a block-wise manner.

__init__(dtype: dtype, device: device, CUDA_MAX_DIM_SIZE: int = 20000, logging_level=20, lambda_reg: float = 0.0) None[source]#
Args:
dtype (torch.dtype):

torch dtype for the computations

device (Union[str, torch.device]):

torch device to do matmuls on

CUDA_MAX_DIM_SIZE (int, optional):

Size of block for block-wise matmuls. Defaults to 20_000.

logging_level (int, optional):

Logging level for the logger. Defaults to logging.INFO.

lambda_reg (float):

The \(\ell_2\) (ridge) regularization penalty added to \(X^\top X\). Defaults to 0.

get_scores(features: Tensor, target_grads: Tensor, accumulator: Tensor) Tensor[source]#

Computes the scores for a given set of features and target gradients. In particular, this function takes in a matrix of features \(\Phi=X(X^\top X)^{-1}\), computed by the get_x_xtx_inv method, and a matrix of target (projected) gradients \(X_{target}\). Then, it computes the scores as \(\Phi X_{target}^\top\). The resulting matrix has shape (n, m), where \(n\) is the number of training examples and \(m\) is the number of target examples.

The accumulator argument is used to store the result of the computation. This is useful when computing scores for multiple model checkpoints, as it allows us to re-use the same memory for the score matrix.

Args:

features (Tensor): features \(\Phi\) of shape (n, p).

target_grads (Tensor): target projected gradients \(X_{target}\) of shape (m, p).

accumulator (Tensor): accumulator of shape (n, m).

get_x_xtx_inv(grads: Tensor, xtx: Tensor) Tensor[source]#

Computes \(X(X^\top X)^{-1}\), where \(X\) is the matrix of projected gradients. Here, the shape of \(X\) is (n, p), where \(n\) is the number of training examples and \(p\) is the dimension of the projection. This function takes as input the pre-computed \(X^\top X\) matrix, which is computed by the get_xtx method.

Args:

grads (Tensor): projected gradients \(X\) of shape (n, p).

xtx (Tensor): \(X^\top X\) of shape (p, p).

Returns:

Tensor: \(X(X^\top X)^{-1}\) of shape (n, p).

get_xtx(grads: Tensor) Tensor[source]#

Computes \(X^\top X\), where \(X\) is the matrix of projected gradients. Here, the shape of \(X\) is (n, p), where \(n\) is the number of training examples and \(p\) is the dimension of the projection.

Args:

grads (Tensor): projected gradients of shape (n, p).

Returns:

Tensor: \(X^\top X\) of shape (p, p).

class trak.score_computers.BasicSingleBlockScoreComputer(dtype, device)[source]#

Bases: AbstractScoreComputer

A bare-bones implementation of ScoreComputer that will likely OOM for almost all applications. Here for testing purposes only. Unless you have a good reason not to, you should use BasicScoreComputer instead.

get_scores(features: Tensor, target_grads: Tensor, accumulator: Tensor) None[source]#

Computes the scores for a given set of features and target gradients. In particular, this function takes in a matrix of features \(\Phi=X(X^\top X)^{-1}\), computed by the get_x_xtx_inv method, and a matrix of target (projected) gradients \(X_{target}\). Then, it computes the scores as \(\Phi X_{target}^\top\). The resulting matrix has shape (n, m), where \(n\) is the number of training examples and \(m\) is the number of target examples.

The accumulator argument is used to store the result of the computation. This is useful when computing scores for multiple model checkpoints, as it allows us to re-use the same memory for the score matrix.

Args:

features (Tensor): features \(\Phi\) of shape (n, p).

target_grads (Tensor): target projected gradients \(X_{target}\) of shape (m, p).

accumulator (Tensor): accumulator of shape (n, m).

get_x_xtx_inv(grads: Tensor, xtx: Tensor) Tensor[source]#

Computes \(X(X^\top X)^{-1}\), where \(X\) is the matrix of projected gradients. Here, the shape of \(X\) is (n, p), where \(n\) is the number of training examples and \(p\) is the dimension of the projection. This function takes as input the pre-computed \(X^\top X\) matrix, which is computed by the get_xtx method.

Args:

grads (Tensor): projected gradients \(X\) of shape (n, p).

xtx (Tensor): \(X^\top X\) of shape (p, p).

Returns:

Tensor: \(X(X^\top X)^{-1}\) of shape (n, p).

get_xtx(grads: Tensor) Tensor[source]#

Computes \(X^\top X\), where \(X\) is the matrix of projected gradients. Here, the shape of \(X\) is (n, p), where \(n\) is the number of training examples and \(p\) is the dimension of the projection.

Args:

grads (Tensor): projected gradients of shape (n, p).

Returns:

Tensor: \(X^\top X\) of shape (p, p).

trak.utils module#

trak.utils.get_free_memory(device)[source]#
trak.utils.get_matrix_mult(features: Tensor, target_grads: Tensor, target_dtype: dtype | None = None, batch_size: int = 8096, use_blockwise: bool = False) Tensor[source]#

Computes features @ target_grads.T. If the output matrix is too large to fit in memory, it will be computed in blocks.

Args:
features (Tensor):

The first matrix to multiply.

target_grads (Tensor):

The second matrix to multiply.

target_dtype (torch.dtype, optional):

The dtype of the output matrix. If None, defaults to the dtype of features. Defaults to None.

batch_size (int, optional):

The batch size to use for blockwise matrix multiplication. Defaults to 8096.

use_blockwise (bool, optional):

Whether or not to use blockwise matrix multiplication. Defaults to False.

trak.utils.get_matrix_mult_blockwise(features: Tensor, target_grads: Tensor, target_dtype: type, bs: int)[source]#
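A sketch of the block-wise variant (illustrative, not the library's exact implementation): the product is computed in row-chunks of features, so only one (batch_size, m) block is materialized at a time.

import torch

def matmul_blockwise(features: torch.Tensor,
                     target_grads: torch.Tensor,
                     batch_size: int = 8096) -> torch.Tensor:
    out = torch.empty(features.shape[0], target_grads.shape[0],
                      dtype=features.dtype)
    for i in range(0, features.shape[0], batch_size):
        # each iteration computes one row-block of features @ target_grads.T
        out[i : i + batch_size] = features[i : i + batch_size] @ target_grads.T
    return out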
trak.utils.get_matrix_mult_standard(features: Tensor, target_grads: Tensor, target_dtype: type)[source]#
trak.utils.get_num_params(model: Module) int[source]#
trak.utils.get_output_memory(features: Tensor, target_grads: Tensor, target_dtype: type)[source]#
trak.utils.get_parameter_chunk_sizes(model: Module, batch_size: int)[source]#

The CudaProjector supports projecting when the product of the number of parameters and the batch size is less than the max value of int32. This function computes the number of parameters that can be projected at once for a given model and batch size.

The method returns a tuple containing the maximum number of parameters that can be projected at once and a list of the actual number of parameters in each chunk (a sequence of parameter groups). Used in ChunkedCudaProjector.

trak.utils.is_not_buffer(ind, params_dict) bool[source]#
trak.utils.parameters_to_vector(parameters) Tensor[source]#

Same as https://pytorch.org/docs/stable/generated/torch.nn.utils.parameters_to_vector.html but with reshape instead of view to avoid a pesky error.

trak.utils.test_install(use_fast_jl: bool = True)[source]#
trak.utils.vectorize(g, arr=None, device='cuda') Tensor[source]#

Flattens a dictionary of per-parameter gradients into a single matrix, optionally recording the result into arr.

Gradients are given as a dict {name_w0: grad_w0, ..., name_wp: grad_wp}, where p is the number of weight matrices. Each grad_wi has shape [batch_size, ...]; this function flattens g to have shape [batch_size, num_params].
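A sketch of this flattening (illustrative, not the library's exact implementation):

import torch

def flatten_grads(g: dict) -> torch.Tensor:
    batch_size = next(iter(g.values())).shape[0]
    # concatenate each [batch_size, ...] gradient as a [batch_size, -1] slab
    return torch.cat(
        [grad_wi.reshape(batch_size, -1) for grad_wi in g.values()], dim=1
    )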

Module contents#