Parallelize TRAK scoring with SLURM

Often we would like to compute TRAK scores from multiple checkpoints of the same model.


Check our paper to see why using multiple checkpoints helps improve TRAK’s performance.

This means that we need to run TRAKer.featurize() for all training examples for each checkpoint. But fortunately, this is a highly parallelizable process!

Below, we sketch a simple way of parallelizing featurize() and score() across checkpoints. We’ll use SLURM — a popular job scheduling system.


You can find all the code for this example here. We’ll skip some details in the post to highlight the main ideas behind using TRAK with SLURM.

Overall, we’ll write three files:

  • featurize_and_score.py
  • run.sbatch
  • gather_scores.py

We will use run.sbatch to run different instances of featurize_and_score.py in parallel, and get the final TRAK scores using gather_scores.py.


In terms of MapReduce, you can think of featurize_and_score.py as the map function and gather_scores.py as the reduce function.
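To make the MapReduce analogy concrete, here is a toy pure-Python sketch of that structure. The function bodies are stand-ins, not real TRAK calls: the "map" step runs independently per checkpoint, and the "reduce" step combines the partial results.

```python
def featurize_and_score(model_id):
    # "map": runs independently for each checkpoint, producing partial
    # scores for (here) three hypothetical validation examples
    return [model_id * 0.1] * 3

def gather(partials):
    # "reduce": combine the per-checkpoint results, as finalize_scores()
    # does (here, by averaging across checkpoints)
    n = len(partials)
    return [sum(col) / n for col in zip(*partials)]

# SLURM runs each of these calls as a separate parallel job:
partials = [featurize_and_score(m) for m in range(10)]
scores = gather(partials)
print(scores)
```

The key property that makes parallelization easy is that each map call depends only on its own `model_id`, so the jobs need no communication until the final gather.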

1. Featurizing each checkpoint

Everything needed for scoring prior to finalize_scores() will go in featurize_and_score.py. For example, it can be as follows:

 1from argparse import ArgumentParser
 2from trak import TRAKer
 3import torch
 4def main(model_id):
 5    model, loader_train, loader_val = ...
 6    # use model_id here to load the respective checkpoint, e.g.:
 7    ckpt = torch.load(f'/path/to/checkpoints/ckpt_{model_id}.pt')
 8
 9    traker = TRAKer(model=model,
10                    task='image_classification',
11                    train_set_size=len(loader_train.dataset))
12
13    traker.load_checkpoint(ckpt, model_id=model_id)
14    for batch in loader_train:
15        traker.featurize(batch=batch, ...)
16    traker.finalize_features(model_ids=[model_id])
17
18    traker.start_scoring_checkpoint(exp_name=..., checkpoint=ckpt, model_id=model_id, ...)
19    for batch in loader_val:
20        traker.score(batch=batch, ...)
21
22    # This will be called from gather_scores.py instead:
23    # scores = traker.finalize_scores()
24
25if __name__ == "__main__":
26    parser = ArgumentParser()
27    parser.add_argument('--model_id', required=True, type=int)
28    args = parser.parse_args()
29    main(args.model_id)

2. Run featurize in parallel

Now we can run the above script in parallel using run.sbatch. Here is a minimal example:

#!/bin/bash
#SBATCH --nodes=1
#SBATCH --cpus-per-task=8
#SBATCH --gres=gpu:a100:1
#SBATCH --array=0-9
#SBATCH --job-name=trak

python featurize_and_score.py --model_id $SLURM_ARRAY_TASK_ID

The above script will submit 10 jobs in parallel for us: this is specified by the #SBATCH --array=0-9 directive. Each job passes its array task ID (which SLURM exposes as $SLURM_ARRAY_TASK_ID) as the model ID for TRAK. To learn more about SBATCH options, check out SLURM’s docs.
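To see how each array job selects its own checkpoint, here is a small sketch; the path scheme mirrors the hypothetical one used in the featurize script above.

```python
import os

def checkpoint_path(base_dir, model_id):
    # Mirrors the (hypothetical) checkpoint layout used in the
    # featurize script: one file per model_id.
    return os.path.join(base_dir, f'ckpt_{model_id}.pt')

# SLURM exports each array job's index as SLURM_ARRAY_TASK_ID; the job
# resolves its own checkpoint from it (defaulting to 0 outside SLURM):
task_id = int(os.environ.get('SLURM_ARRAY_TASK_ID', '0'))
print(checkpoint_path('/path/to/checkpoints', task_id))
```

Because the task ID doubles as the model ID, the ten array jobs touch ten disjoint checkpoints and never conflict with one another.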

Note that on line 16 of the example above, we call finalize_features() with model_ids=[model_id]. This is important — if we don’t specify this, TRAK by default attempts to finalize the features for all model_ids (checkpoints) in the save_dir of the current TRAKer instance.


sbatch run.sbatch

in the terminal will populate the specified save_dir with all the intermediate results we need to compute the final TRAK scores.
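One practical detail: the gather step must not run until every array job has finished. Assuming a SLURM version that supports it, one simple option is sbatch’s --wait flag, which blocks until all tasks in the array terminate (the gather script name here is the one assumed throughout this example):

```shell
# Block until all 10 array tasks finish, then gather the scores:
sbatch --wait run.sbatch
python gather_scores.py
```

Alternatively, you can submit the gather script as its own job with a --dependency=afterok:<jobid> constraint on the array job.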

3. Gather final scores

The only thing left to do is call TRAKer.finalize_scores(). This method combines the scores across checkpoints (think of it as a gather operation). This is what gather_scores.py will do:

from trak import TRAKer

model = ...

traker = TRAKer(model=model, task='image_classification', ...)
scores = traker.finalize_scores(exp_name=...)

That’s it!


Ease of parallelization was a priority for us when we designed TRAK. The above example uses SLURM to achieve parallelization, but it is by no means the only option; for example, you should have no problem integrating TRAK with torch.distributed.