Quickstart — get TRAK scores for CIFAR
Note
Follow along in this Jupyter notebook. If you want to browse pre-computed TRAK scores instead, check out this Colab notebook.
In this tutorial, we’ll show you how to use the TRAK API to compute data attribution scores for ResNet-9 models trained on CIFAR-10. While we use a particular model architecture and dataset, the code in this tutorial can be easily adapted to any classification task.
Overall, this tutorial will show you how to:
- Load model checkpoints
- Set up the TRAKer class
- Compute TRAK features for the training data
- Compute TRAK scores for target examples
Let’s get started!
Load model checkpoints
First, you need models to apply TRAK to. You can either use the script below to train ResNet-9 models on CIFAR-10 and save intermediate checkpoints (i.e., state_dict()s), or use your own checkpoints [1]. (In fact, in this tutorial you can replace ResNet-9 + CIFAR-10 with any architecture + classification task of your choice.)
Training code for CIFAR-10
import os
from tqdm import tqdm

import numpy as np
import torch
from torch.cuda.amp import GradScaler, autocast
from torch.nn import CrossEntropyLoss
from torch.optim import SGD, lr_scheduler
import torchvision

# ResNet-9
class Mul(torch.nn.Module):
    def __init__(self, weight):
        super(Mul, self).__init__()
        self.weight = weight
    def forward(self, x): return x * self.weight

class Flatten(torch.nn.Module):
    def forward(self, x): return x.view(x.size(0), -1)

class Residual(torch.nn.Module):
    def __init__(self, module):
        super(Residual, self).__init__()
        self.module = module
    def forward(self, x): return x + self.module(x)

def construct_rn9(num_classes=10):
    def conv_bn(channels_in, channels_out, kernel_size=3, stride=1, padding=1, groups=1):
        return torch.nn.Sequential(
            torch.nn.Conv2d(channels_in, channels_out, kernel_size=kernel_size,
                            stride=stride, padding=padding, groups=groups, bias=False),
            torch.nn.BatchNorm2d(channels_out),
            torch.nn.ReLU(inplace=True)
        )
    model = torch.nn.Sequential(
        conv_bn(3, 64, kernel_size=3, stride=1, padding=1),
        conv_bn(64, 128, kernel_size=5, stride=2, padding=2),
        Residual(torch.nn.Sequential(conv_bn(128, 128), conv_bn(128, 128))),
        conv_bn(128, 256, kernel_size=3, stride=1, padding=1),
        torch.nn.MaxPool2d(2),
        Residual(torch.nn.Sequential(conv_bn(256, 256), conv_bn(256, 256))),
        conv_bn(256, 128, kernel_size=3, stride=1, padding=0),
        torch.nn.AdaptiveMaxPool2d((1, 1)),
        Flatten(),
        torch.nn.Linear(128, num_classes, bias=False),
        Mul(0.2)
    )
    return model

def get_dataloader(batch_size=256, num_workers=8, split='train', shuffle=False, augment=True):
    if augment:
        transforms = torchvision.transforms.Compose([
            torchvision.transforms.RandomHorizontalFlip(),
            torchvision.transforms.RandomAffine(0),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465),
                                             (0.2023, 0.1994, 0.201))])
    else:
        transforms = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465),
                                             (0.2023, 0.1994, 0.201))])

    is_train = (split == 'train')
    dataset = torchvision.datasets.CIFAR10(root='/tmp/cifar/',
                                           download=True,
                                           train=is_train,
                                           transform=transforms)
    loader = torch.utils.data.DataLoader(dataset=dataset,
                                         shuffle=shuffle,
                                         batch_size=batch_size,
                                         num_workers=num_workers)
    return loader

def train(model, loader, lr=0.4, epochs=24, momentum=0.9,
          weight_decay=5e-4, lr_peak_epoch=5, label_smoothing=0.0, model_id=0):
    opt = SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
    iters_per_epoch = len(loader)
    # Cyclic LR with single triangle
    lr_schedule = np.interp(np.arange((epochs + 1) * iters_per_epoch),
                            [0, lr_peak_epoch * iters_per_epoch, epochs * iters_per_epoch],
                            [0, 1, 0])
    scheduler = lr_scheduler.LambdaLR(opt, lr_schedule.__getitem__)
    scaler = GradScaler()
    loss_fn = CrossEntropyLoss(label_smoothing=label_smoothing)

    for ep in range(epochs):
        for it, (ims, labs) in enumerate(loader):
            ims = ims.cuda()
            labs = labs.cuda()
            opt.zero_grad(set_to_none=True)
            with autocast():
                out = model(ims)
                loss = loss_fn(out, labs)

            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()
            scheduler.step()
        # Save intermediate checkpoints to apply TRAK to
        if ep in [12, 15, 18, 21, 23]:
            torch.save(model.state_dict(), f'./checkpoints/sd_{model_id}_epoch_{ep}.pt')
    return model

os.makedirs('./checkpoints', exist_ok=True)
loader_for_training = get_dataloader(batch_size=512, split='train', shuffle=True)

# you can modify the for loop below to train more models
for i in tqdm(range(1), desc='Training models..'):
    model = construct_rn9().to(memory_format=torch.channels_last).cuda()
    model = train(model, loader_for_training, model_id=i)
For the remaining steps, we’ll assume you have N model checkpoints in ./checkpoints:
import torch
from pathlib import Path

# Sort the checkpoint files so model IDs are assigned consistently across runs
ckpt_files = sorted(Path('./checkpoints').rglob('*.pt'))
ckpts = [torch.load(ckpt, map_location='cpu') for ckpt in ckpt_files]
Set up the TRAKer class
The TRAKer class is the entry point to the TRAK API. Construct it by calling __init__() with three arguments:
- a model (a torch.nn.Module instance): the model architecture/class that you want to compute attributions for. Note that the model you pass in does not need to be initialized (we’ll do that separately below).
- a task (a string or an AbstractModelOutput instance): the type of learning task you want to attribute with TRAK, e.g. image classification, language modeling, CLIP-style contrastive learning, etc. Internally, the task tells TRAKer how to evaluate a given batch of data.
- a train_set_size (an integer): the size of the training set you want to keep TRAK of.
Let’s set up our model and dataset:
# Replace with your choice of model constructor
model = construct_rn9().to(memory_format=torch.channels_last).cuda().eval()
# Replace with your choice of data loader (should have deterministic ordering;
# we also disable augmentation so that featurizing is deterministic)
loader_train = get_dataloader(batch_size=128, split='train', augment=False)
Now we are ready to start TRAKing our model on the dataset of choice. Let’s initialize the TRAKer object.
from trak import TRAKer
traker = TRAKer(model=model,
task='image_classification',
train_set_size=len(loader_train.dataset))
By default, all metadata and arrays created by TRAKer are stored in ./trak_results. You can override this by specifying a custom save_dir to TRAKer.
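For example, to write everything to a different directory (the path below is just illustrative):
traker = TRAKer(..., save_dir='./my_trak_results')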
In addition, you can specify the dimension of the features used by TRAK with the proj_dim argument, e.g.,
traker = TRAKer(..., proj_dim=4096)  # default dimension is 2048
(For the curious, this corresponds to the dimension of the output of the random projections in our algorithm. We recommend a proj_dim between 1,000 and 40,000.)
For more customizations, check out the API reference.
Compute TRAK features for training data
Now that we have constructed a TRAKer object, let’s use it to process the training data. We process the training examples by calling featurize():
for model_id, ckpt in enumerate(tqdm(ckpts)):
    # TRAKer loads the provided checkpoint and also associates
    # the provided (unique) model_id with the checkpoint.
    traker.load_checkpoint(ckpt, model_id=model_id)

    for batch in loader_train:
        batch = [x.cuda() for x in batch]
        # TRAKer computes features corresponding to the batch of examples,
        # using the checkpoint loaded above.
        traker.featurize(batch=batch, num_samples=batch[0].shape[0])

# Tells TRAKer that we've given it all the information, at which point
# TRAKer does some post-processing to get ready for the next step
# (scoring target examples).
traker.finalize_features()
Note
Here we assume that the data loader we are using is not shuffled, so we only need to specify how many samples are in the batch. Alternatively, we can use a shuffled data loader and pass in inds instead of num_samples to featurize(). In that case, inds should be an array of the same length as the batch, specifying the indices of the examples in the batch within the training dataset.
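For example, here is a minimal sketch of the shuffled-loader alternative. IndexedDataset is a hypothetical helper (not part of the TRAK API) that makes each batch also carry the dataset indices of its examples:
import torch

# Hypothetical wrapper: return (x, y, index) instead of (x, y)
class IndexedDataset(torch.utils.data.Dataset):
    def __init__(self, base):
        self.base = base

    def __len__(self):
        return len(self.base)

    def __getitem__(self, i):
        x, y = self.base[i]
        return x, y, i

shuffled_loader = torch.utils.data.DataLoader(
    IndexedDataset(loader_train.dataset), batch_size=128, shuffle=True)

for ims, labs, inds in shuffled_loader:
    batch = (ims.cuda(), labs.cuda())
    # inds specifies where each example in the batch lives in the training set
    traker.featurize(batch=batch, inds=inds.numpy())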
Above, we sequentially iterate over multiple model checkpoints.
Note
While you can still compute TRAK scores with a single checkpoint, using multiple checkpoints significantly improves TRAK’s performance. See our paper for more details.
But you can also parallelize this step across multiple jobs, and we recommend that you do. All you have to do is initialize a different TRAKer object with the same save_dir within each job and specify the appropriate model_id when calling load_checkpoint().
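As a rough sketch, a single such job might look as follows. Using the SLURM_ARRAY_TASK_ID environment variable to pick a checkpoint is just one hypothetical indexing scheme; any way of assigning one checkpoint per job works:
import os
from trak import TRAKer

# Each parallel job featurizes exactly one checkpoint
job_id = int(os.environ['SLURM_ARRAY_TASK_ID'])

traker = TRAKer(model=model,
                task='image_classification',
                train_set_size=len(loader_train.dataset),
                save_dir='./trak_results')  # the same save_dir in every job

traker.load_checkpoint(ckpts[job_id], model_id=job_id)
for batch in loader_train:
    batch = [x.cuda() for x in batch]
    traker.featurize(batch=batch, num_samples=batch[0].shape[0])

# Once all jobs have finished, call traker.finalize_features()
# a single time (e.g., in a final dependent job).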
For more details, check out how to Parallelize TRAK scoring with SLURM.
Compute TRAK scores for target examples
Finally, we are ready to compute attribution scores. To do this, you need to choose a set of target examples that you want to attribute. For the purpose of this tutorial, let’s make the targets be the entire validation set:
loader_targets = get_dataloader(batch_size=128, split='val', augment=False)
As before, we iterate over checkpoints and batches of data:
for model_id, ckpt in enumerate(tqdm(ckpts)):
    traker.start_scoring_checkpoint(exp_name='quickstart',
                                    checkpoint=ckpt,
                                    model_id=model_id,
                                    num_targets=len(loader_targets.dataset))
    for batch in loader_targets:
        batch = [x.cuda() for x in batch]
        traker.score(batch=batch, num_samples=batch[0].shape[0])

scores = traker.finalize_scores(exp_name='quickstart')
Here, start_scoring_checkpoint() plays a role similar to the load_checkpoint() we used when featurizing the training set; it prepares the TRAKer by loading the checkpoint and initializing internal data structures. The score() method is analogous to featurize(); it processes the target batch and computes the corresponding features.
Note
Be careful to provide the same model_id for each checkpoint as in the featurizing step; TRAK will not check that you did that. If you use the wrong model_ids, TRAK will silently fail. P.S.: If you know of a clean, robust way to hash model parameters to detect a changed checkpoint, open an issue on GitHub and we can add an assert to check for model_id consistency.
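(For illustration only, a naive version of such a hash might look like the sketch below. It is not robust to things like dtype, device, or memory-layout changes, which is exactly why the note above asks for something better.)
import hashlib

def hash_state_dict(sd):
    # Hash parameter names and raw tensor bytes in a fixed order
    h = hashlib.sha256()
    for key in sorted(sd.keys()):
        h.update(key.encode())
        h.update(sd[key].cpu().numpy().tobytes())
    return h.hexdigest()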
The final line above returns the TRAK scores as a numpy.array from the finalize_scores() method. Additionally, finalize_scores() saves the scores to disk as a memory-mapped file (.mmap format).
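For example, assuming scores has shape (train_set_size, num_targets) with higher scores indicating stronger attribution, you could inspect the most influential training examples for a given target like this:
import numpy as np

# scores[i, j]: attribution of training example i to target example j (assumed layout)
target_idx = 0  # an arbitrary target example
top_train = np.argsort(scores[:, target_idx])[::-1][:5]
print('Training examples with the highest TRAK scores for this target:', top_train)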
We can visualize some of the top scoring TRAK images from the scores array we just computed:
[Figure: for a few target examples, the training images with the highest TRAK scores.]
That’s it! Once you have your model(s) and your data, just a few API calls to TRAK let you compute data attribution scores.