Quickstart: get TRAK scores for CIFAR


Follow along in this Jupyter notebook. If you want to browse pre-computed TRAK scores instead, check out this Colab notebook.

In this tutorial, we’ll show you how to use the TRAK API to compute data attribution scores for ResNet-9 models trained on CIFAR-10. While we use a particular model architecture and dataset, the code in this tutorial can be easily adapted to any classification task.

Overall, this tutorial will show you how to:

  1. Load model checkpoints

  2. Set up the TRAKer class

  3. Compute TRAK features for training data

  4. Compute TRAK scores for target examples

Let’s get started!

Load model checkpoints

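First, you need models to apply TRAK to. You can either use the script below to train ResNet-9 models on CIFAR-10 and save their checkpoints (e.g., state_dict()s), or use your own checkpoints [1]. As written, the script trains a single model and saves five checkpoints along its training run (at epoch indices 12, 15, 18, 21, and 23); increase the range of the final for loop to train more models. (In fact, in this tutorial you can replace ResNet-9 + CIFAR-10 with any architecture + classification task of your choice.)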

Training code for CIFAR-10
import os
from pathlib import Path
import wget
from tqdm import tqdm
import numpy as np
import torch
from torch.cuda.amp import GradScaler, autocast
from torch.nn import CrossEntropyLoss, Conv2d, BatchNorm2d
from torch.optim import SGD, lr_scheduler
import torchvision

# Resnet9
class Mul(torch.nn.Module):
    def __init__(self, weight):
        super(Mul, self).__init__()
        self.weight = weight
    def forward(self, x): return x * self.weight

class Flatten(torch.nn.Module):
    def forward(self, x): return x.view(x.size(0), -1)

class Residual(torch.nn.Module):
    def __init__(self, module):
        super(Residual, self).__init__()
        self.module = module
    def forward(self, x): return x + self.module(x)

def construct_rn9(num_classes=10):
    # Conv -> BatchNorm -> ReLU building block
    def conv_bn(channels_in, channels_out, kernel_size=3, stride=1, padding=1, groups=1):
        return torch.nn.Sequential(
                torch.nn.Conv2d(channels_in, channels_out, kernel_size=kernel_size,
                            stride=stride, padding=padding, groups=groups, bias=False),
                torch.nn.BatchNorm2d(channels_out),
                torch.nn.ReLU(inplace=True)
        )
    model = torch.nn.Sequential(
        conv_bn(3, 64, kernel_size=3, stride=1, padding=1),
        conv_bn(64, 128, kernel_size=5, stride=2, padding=2),
        Residual(torch.nn.Sequential(conv_bn(128, 128), conv_bn(128, 128))),
        conv_bn(128, 256, kernel_size=3, stride=1, padding=1),
        torch.nn.MaxPool2d(2),
        Residual(torch.nn.Sequential(conv_bn(256, 256), conv_bn(256, 256))),
        conv_bn(256, 128, kernel_size=3, stride=1, padding=0),
        torch.nn.AdaptiveMaxPool2d((1, 1)),
        Flatten(),  # (N, 128, 1, 1) -> (N, 128) before the linear head
        torch.nn.Linear(128, num_classes, bias=False),
        Mul(0.2)    # scale the logits
    )
    return model

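def get_dataloader(batch_size=256, num_workers=8, split='train', shuffle=False, augment=True):
    if augment:
        # Training-time augmentation (random horizontal flips)
        transforms = torchvision.transforms.Compose(
                        [torchvision.transforms.RandomHorizontalFlip(),
                         torchvision.transforms.ToTensor(),
                         torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465),
                                                          (0.2023, 0.1994, 0.201))])
    else:
        transforms = torchvision.transforms.Compose([
                         torchvision.transforms.ToTensor(),
                         torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465),
                                                          (0.2023, 0.1994, 0.201))])

    # Any split other than 'train' loads the CIFAR-10 test set
    is_train = (split == 'train')
    dataset = torchvision.datasets.CIFAR10(root='/tmp/cifar/',
                                           download=True,
                                           train=is_train,
                                           transform=transforms)

    loader = torch.utils.data.DataLoader(dataset=dataset,
                                         shuffle=shuffle,
                                         batch_size=batch_size,
                                         num_workers=num_workers)

    return loader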

def train(model, loader, lr=0.4, epochs=24, momentum=0.9,
          weight_decay=5e-4, lr_peak_epoch=5, label_smoothing=0.0, model_id=0):

    opt = SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
    iters_per_epoch = len(loader)
    # Cyclic LR with single triangle
    lr_schedule = np.interp(np.arange((epochs+1) * iters_per_epoch),
                            [0, lr_peak_epoch * iters_per_epoch, epochs * iters_per_epoch],
                            [0, 1, 0])
    scheduler = lr_scheduler.LambdaLR(opt, lr_schedule.__getitem__)
    scaler = GradScaler()
    loss_fn = CrossEntropyLoss(label_smoothing=label_smoothing)

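    for ep in range(epochs):
        for it, (ims, labs) in enumerate(loader):
            ims = ims.cuda()
            labs = labs.cuda()
            with autocast():
                out = model(ims)
                loss = loss_fn(out, labs)

            # Mixed-precision backward pass, optimizer step, and LR schedule step
            opt.zero_grad(set_to_none=True)
            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()
            scheduler.step()

        if ep in [12, 15, 18, 21, 23]:
            torch.save(model.state_dict(), f'./checkpoints/sd_{model_id}_epoch_{ep}.pt')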

    return model

os.makedirs('./checkpoints', exist_ok=True)
loader_for_training = get_dataloader(batch_size=512, split='train', shuffle=True)

# you can modify the for loop below to train more models
for i in tqdm(range(1), desc='Training models..'):
    model = construct_rn9().to(memory_format=torch.channels_last).cuda()
    model = train(model, loader_for_training, model_id=i)
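The cyclic learning-rate schedule in train() is a piecewise-linear interpolation, so it can be inspected without a GPU. The sketch below rebuilds the same multiplier array with np.interp for a hypothetical loader with 10 iterations per epoch (the helper name is illustrative); LambdaLR multiplies the base lr by these values at each scheduler step.

```python
import numpy as np

def triangle_lr_multipliers(epochs, iters_per_epoch, lr_peak_epoch):
    """Per-iteration LR multipliers: ramp 0 -> 1 until lr_peak_epoch, then 1 -> 0 at `epochs`."""
    steps = np.arange((epochs + 1) * iters_per_epoch)
    return np.interp(steps,
                     [0, lr_peak_epoch * iters_per_epoch, epochs * iters_per_epoch],
                     [0, 1, 0])

# Hypothetical sizes: 24 epochs, 10 iterations per epoch, peak at epoch 5
mult = triangle_lr_multipliers(epochs=24, iters_per_epoch=10, lr_peak_epoch=5)
base_lr = 0.4
print(mult[0], mult[50], mult[240])   # start, peak, end of the triangle
print(base_lr * mult[25])             # halfway up the ramp
```

Steps past the last breakpoint stay at 0, because np.interp clamps to the final value outside the given range.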

For the remaining steps, we’ll assume you have N model checkpoints in ./checkpoints:

import torch
from pathlib import Path

# Sort so that checkpoint order (and hence each model_id) is reproducible across runs
ckpt_files = sorted(Path('./checkpoints').rglob('*.pt'))
ckpts = [torch.load(ckpt, map_location='cpu') for ckpt in ckpt_files]
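Because each checkpoint's model_id must match between featurizing and scoring, it helps to order checkpoints deterministically. The training script above names files sd_{model_id}_epoch_{ep}.pt, so a sketch like the following (a hypothetical helper, not part of TRAK) sorts them numerically by training run and epoch, and can optionally keep only the last saved epoch of each run.

```python
import re
from pathlib import Path

# Matches the naming used by the training script: sd_{model_id}_epoch_{ep}.pt
CKPT_PATTERN = re.compile(r"sd_(\d+)_epoch_(\d+)\.pt$")

def order_checkpoints(paths, last_epoch_only=False):
    """Return checkpoint paths sorted by (training run, epoch).

    Hypothetical helper: assumes file names follow CKPT_PATTERN; other files are skipped.
    """
    parsed = []
    for p in paths:
        m = CKPT_PATTERN.search(Path(p).name)
        if m:
            parsed.append((int(m.group(1)), int(m.group(2)), Path(p)))
    parsed.sort()  # numeric sort, so epoch 9 comes before epoch 12
    if last_epoch_only:
        latest = {}
        for run, ep, p in parsed:
            latest[run] = (run, ep, p)  # later epochs overwrite earlier ones
        parsed = sorted(latest.values())
    return [p for _, _, p in parsed]

# Usage with the directory from the training script:
# ckpt_files = order_checkpoints(Path('./checkpoints').rglob('*.pt'))
names = ["sd_1_epoch_12.pt", "sd_0_epoch_23.pt", "sd_0_epoch_12.pt", "notes.txt"]
print([p.name for p in order_checkpoints(names)])
```

Sorting on parsed integers rather than raw file names avoids lexical surprises such as "epoch_12" sorting before "epoch_9".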

Set up the TRAKer class

The TRAKer class is the entry point to the TRAK API. Construct it by calling __init__() with three arguments:

  • a model (a torch.nn.Module instance): the model architecture/class that you want to compute attributions for. Note that the model you pass in does not need to have its trained weights loaded yet (we’ll load checkpoints separately below).

  • a task (a string or an AbstractModelOutput instance): the type of learning task you want to attribute with TRAK, e.g., image classification, language modeling, or CLIP-style contrastive learning. Internally, the task tells TRAKer how to evaluate a given batch of data.

  • a train_set_size (an integer): the size of the training set you want to keep track of.
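To make "how to evaluate a given batch" concrete: for classification, the TRAK paper evaluates each example with a margin-style output, the log-odds of the correct class, log(p / (1 - p)), which equals the correct-class logit minus the log-sum-exp of the other logits. The NumPy sketch below checks that identity on made-up logits; it illustrates the idea and is not TRAK's AbstractModelOutput code.

```python
import numpy as np

def margin(logits, label):
    """Correct-class logit minus a numerically stable log-sum-exp of the other logits."""
    others = np.delete(logits, label)
    m = others.max()
    return logits[label] - (m + np.log(np.exp(others - m).sum()))

def log_odds(logits, label):
    """log(p / (1 - p)), where p is the softmax probability of the correct class."""
    z = logits - logits.max()
    p = np.exp(z) / np.exp(z).sum()
    return np.log(p[label] / (1.0 - p[label]))

logits = np.array([2.0, -1.0, 0.5, 3.0])  # made-up logits for one 4-class example
print(margin(logits, 3), log_odds(logits, 3))  # the two values agree
```

The margin is positive exactly when the correct class has probability above 1/2.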

Let’s set up our model and dataset:

# Replace with your choice of model constructor
model = construct_rn9().to(memory_format=torch.channels_last).cuda().eval()

# Replace with your choice of data loader (should be deterministic ordering,
# and without random augmentation so each example's features are reproducible)
loader_train = get_dataloader(batch_size=128, split='train', augment=False)

Now we are ready to start TRAKing our model on the dataset of choice. Let’s initialize the TRAKer object.

from trak import TRAKer

traker = TRAKer(model=model,
                task='image_classification',
                train_set_size=len(loader_train.dataset))

By default, all metadata and arrays created by TRAKer are stored in ./trak_results. You can override this by specifying a custom save_dir to TRAKer.

In addition, you can specify the dimension of the features used by TRAK with the proj_dim argument, e.g.,

traker = TRAKer(..., proj_dim=4096)  # default dimension is 2048

(For the curious, this corresponds to the dimension of the output of random projections in our algorithm. We recommend proj_dim between 1,000 and 40,000.)
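For intuition about proj_dim: a random Gaussian projection approximately preserves norms and inner products once the target dimension is large enough, which is why compressed gradient features remain useful. The NumPy sketch below uses small, made-up sizes (d = 2000 "parameters", k = 512 projected dimensions) and a dense Gaussian matrix for clarity; TRAK's own projectors are optimized and may use a different random distribution.

```python
import numpy as np

rng = np.random.default_rng(0)
d, k = 2000, 512  # illustrative "gradient" size and projection dimension

# Entries ~ N(0, 1/k), so E[||P x||^2] = ||x||^2
P = rng.standard_normal((k, d)) / np.sqrt(k)

g1 = rng.standard_normal(d)            # stand-ins for two per-example gradients
g2 = g1 + 0.5 * rng.standard_normal(d)

def cosine(a, b):
    return a @ b / (np.linalg.norm(a) * np.linalg.norm(b))

norm_ratio = np.linalg.norm(P @ g1) / np.linalg.norm(g1)
print(f"norm ratio after projection: {norm_ratio:.3f}")
print(f"cosine before/after: {cosine(g1, g2):.3f} / {cosine(P @ g1, P @ g2):.3f}")
```

The distortion shrinks roughly like 1/sqrt(k), so a larger proj_dim buys accuracy at the cost of memory and compute.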

For more customizations, check out the API reference.

Compute TRAK features for training data

Now that we have constructed a TRAKer object, let’s use it to process the training data. We process the training examples by calling featurize():

from tqdm import tqdm

for model_id, ckpt in enumerate(tqdm(ckpts)):
    # TRAKer loads the provided checkpoint and also associates
    # the provided (unique) model_id with the checkpoint.
    traker.load_checkpoint(ckpt, model_id=model_id)

    for batch in loader_train:
        batch = [x.cuda() for x in batch]
        # TRAKer computes features corresponding to the batch of examples,
        # using the checkpoint loaded above.
        traker.featurize(batch=batch, num_samples=batch[0].shape[0])

# Tells TRAKer that we've given it all the information, at which point
# TRAKer does some post-processing to get ready for the next step
# (scoring target examples).
traker.finalize_features()


Here we assume that the data loader is not shuffled, so we only need to specify how many samples are in the batch. Alternatively, we can use a shuffled data loader and pass inds instead of num_samples to featurize(). In that case, inds should be an array of the same length as the batch, giving the index of each example in the batch within the training dataset.
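If you do use a shuffled loader, one way to obtain inds is to have the dataset return each example's index alongside the example. The wrapper below is a hypothetical helper (not part of TRAK or torchvision) that works with any map-style dataset; the commented PyTorch lines show where the indices would be passed to featurize() as inds.

```python
import random

class IndexedDataset:
    """Wrap a map-style dataset so each item also returns its index in the dataset."""

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample = self.dataset[idx]
        if not isinstance(sample, tuple):
            sample = (sample,)
        return (*sample, idx)

# With PyTorch (not run here), the index arrives as the last element of each batch:
# loader = torch.utils.data.DataLoader(IndexedDataset(loader_train.dataset),
#                                      batch_size=128, shuffle=True)
# for *batch, inds in loader:
#     batch = [x.cuda() for x in batch]
#     traker.featurize(batch=batch, inds=inds.numpy())

# Plain-Python check: shuffled access still reports each item's original position
toy = [("img_a", 0), ("img_b", 1), ("img_c", 2)]
wrapped = IndexedDataset(toy)
order = list(range(len(wrapped)))
random.Random(0).shuffle(order)
seen = [wrapped[i] for i in order]
print(seen)
```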

Above, we sequentially iterate over multiple model checkpoints.


While you can still compute TRAK with a single checkpoint, using multiple checkpoints significantly improves TRAK’s performance. See our

You can also (and we recommend that you do) parallelize this step across multiple jobs. All you have to do is initialize a separate TRAKer object with the same save_dir in each job and pass the appropriate model_id when calling load_checkpoint(). For more details, check out how to Parallelize TRAK scoring with SLURM.
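A minimal sketch of how each job could pick its share of checkpoints, assuming a SLURM job array (SLURM sets SLURM_ARRAY_TASK_ID for array jobs; the helper name and the round-robin split are illustrative choices, not part of TRAK). The key point is that model_id is the checkpoint's position in one shared, deterministic ordering, so every job agrees on it.

```python
import os

def checkpoints_for_job(ckpt_files, job_idx, num_jobs):
    """Round-robin split: return the (model_id, path) pairs assigned to this job.

    model_id is the position in the full sorted list, so it is identical in every job.
    Any fixed ordering works, as long as all jobs use the same one.
    """
    ordered = sorted(ckpt_files)
    return [(model_id, path) for model_id, path in enumerate(ordered)
            if model_id % num_jobs == job_idx]

# Inside a job (hypothetical setup): read this job's index from the environment
job_idx = int(os.environ.get("SLURM_ARRAY_TASK_ID", 0))
num_jobs = 2
files = ["sd_0_epoch_23.pt", "sd_1_epoch_23.pt", "sd_2_epoch_23.pt"]
for model_id, path in checkpoints_for_job(files, job_idx, num_jobs):
    print(model_id, path)
    # traker = TRAKer(..., save_dir=shared_save_dir)  # hypothetical: same save_dir in every job
    # traker.load_checkpoint(torch.load(path), model_id=model_id)
```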

Compute TRAK scores for target examples

Finally, we are ready to compute attribution scores. To do this, you need to choose a set of target examples that you want to attribute. For the purpose of this tutorial, let’s make the targets be the entire validation set:

# split='val' (anything other than 'train') loads the CIFAR-10 test set
loader_targets = get_dataloader(batch_size=128, split='val', augment=False)

As before, we iterate over checkpoints and batches of data:

for model_id, ckpt in enumerate(tqdm(ckpts)):
    traker.start_scoring_checkpoint(exp_name='quickstart',
                                    checkpoint=ckpt,
                                    model_id=model_id,
                                    num_targets=len(loader_targets.dataset))
    for batch in loader_targets:
        batch = [x.cuda() for x in batch]
        traker.score(batch=batch, num_samples=batch[0].shape[0])

scores = traker.finalize_scores(exp_name='quickstart')

Here, start_scoring_checkpoint() plays a role similar to load_checkpoint() in the featurizing step: it prepares the TRAKer by loading the checkpoint and initializing internal data structures. The score() method is analogous to featurize(): it processes the target batch and computes the corresponding features.
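For intuition about what finalize_scores() combines, here is a heavily simplified NumPy sketch of the estimator described in the TRAK paper: for each checkpoint, the training features Φ (one row per training example) and the target features are combined through the kernel (ΦᵀΦ)⁻¹, and the per-checkpoint results are averaged. It omits parts of the real method (such as the per-example reweighting and sparsification), uses random stand-in features, and is not TRAK's implementation.

```python
import numpy as np

def simplified_trak_scores(train_feats, target_feats):
    """Average over checkpoints of  Phi (Phi^T Phi)^{-1} phi_target^T.

    train_feats:  list of (n_train, k) arrays, one per checkpoint
    target_feats: list of (n_targets, k) arrays, in the same checkpoint order
    Returns an (n_train, n_targets) array.
    """
    per_ckpt = []
    for Phi, phi_t in zip(train_feats, target_feats):
        kernel = Phi.T @ Phi                               # (k, k)
        # Solve a linear system instead of forming the inverse explicitly
        per_ckpt.append(Phi @ np.linalg.solve(kernel, phi_t.T))
    return np.mean(per_ckpt, axis=0)

rng = np.random.default_rng(0)
n_train, n_targets, k, n_ckpts = 100, 5, 16, 3             # toy sizes
train_feats = [rng.standard_normal((n_train, k)) for _ in range(n_ckpts)]
target_feats = [rng.standard_normal((n_targets, k)) for _ in range(n_ckpts)]

scores_sketch = simplified_trak_scores(train_feats, target_feats)
print(scores_sketch.shape)
```

The zip over checkpoints pairs training and target features from the same model, which is one way to see why the model_id used in score() has to match the one used in featurize().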


Make sure you provide the same model_id for each checkpoint as in the featurizing step: TRAK does not check this for you. If you use the wrong model_ids, TRAK will fail silently.

P.S.: If you know of a clean, robust way to hash model parameters to detect a changed checkpoint, open an issue on GitHub and we can add an assert to check for model_id consistency.
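One possible approach is to fingerprint a checkpoint by hashing its parameter bytes in a fixed key order and recording the fingerprint next to each model_id. The helper below is hypothetical (TRAK does not provide it); it uses only hashlib and NumPy and accepts NumPy arrays or PyTorch tensors (converted via .detach().cpu().numpy(), so bfloat16 tensors would need a .float() first). Only bit-identical weights produce the same fingerprint.

```python
import hashlib
import numpy as np

def fingerprint_state_dict(state_dict):
    """SHA-256 over parameter names, dtypes, shapes, and raw bytes, in sorted key order."""
    h = hashlib.sha256()
    for name in sorted(state_dict):
        value = state_dict[name]
        if hasattr(value, "detach"):  # PyTorch tensor -> NumPy array
            value = value.detach().cpu().numpy()
        arr = np.ascontiguousarray(value)
        h.update(name.encode())
        h.update(str(arr.dtype).encode())
        h.update(str(arr.shape).encode())
        h.update(arr.tobytes())
    return h.hexdigest()

sd = {"conv.weight": np.ones((2, 2), dtype=np.float32),
      "fc.bias": np.zeros(3, dtype=np.float32)}
fp = fingerprint_state_dict(sd)
changed = {**sd, "fc.bias": np.full(3, 1e-3, dtype=np.float32)}
print(fp == fingerprint_state_dict(dict(reversed(list(sd.items())))))  # key order does not matter
print(fp == fingerprint_state_dict(changed))                            # changed weights differ
```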

The final line above returns the TRAK scores from finalize_scores() as a numpy.array. finalize_scores() also saves the scores to disk as a memory-mapped file (.mmap format).

We can visualize some of the top-scoring TRAK images from the scores array we just computed:

Figure: Top-scoring TRAK images
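To choose which training images to show, you can take the highest-scoring training indices for a given target. The sketch below assumes scores has shape (train_set_size, num_targets), i.e., one column per target example, and uses a small random array in place of the real scores; the helper name is hypothetical.

```python
import numpy as np

def top_k_train_indices(scores, target_idx, k=5):
    """Indices of the k training examples with the highest score for one target.

    Assumes scores has shape (train_set_size, num_targets).
    """
    column = scores[:, target_idx]
    top = np.argpartition(column, -k)[-k:]       # unordered top-k in linear time
    return top[np.argsort(column[top])[::-1]]    # order those k, highest first

rng = np.random.default_rng(0)
fake_scores = rng.standard_normal((1000, 10))    # stand-in for the real scores array
print(top_k_train_indices(fake_scores, target_idx=3, k=5))
```

For display, torchvision's CIFAR10 keeps the raw uint8 images in its .data attribute, so loader_train.dataset.data[i] gives an unnormalized image for each returned index i.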

That’s it! Once you have your model(s) and your data, just a few API calls to TRAK let you compute data attribution scores.