Parallelize TRAK scoring with SLURM#

Often we would like to compute TRAK scores from multiple checkpoints of the same model.

Note

Check our paper to see why using multiple checkpoints helps improve TRAK’s performance.

This means that we need to run TRAKer.featurize() on all training examples for each checkpoint. Fortunately, this is a highly parallelizable process!

Below, we sketch a simple way of parallelizing featurize() and score() across checkpoints. We’ll use SLURM — a popular job scheduling system.

Note

You can find all the code for this example here. We’ll skip some details to highlight the main ideas behind using TRAK with SLURM.

Overall, we’ll write three files:

  • featurize_and_score.py

  • run.sbatch

  • gather.py

We will use run.sbatch to run different instances of featurize_and_score.py in parallel, and get the final TRAK scores using gather.py.

Note

In terms of MapReduce, you can think of featurize_and_score as the map function and gather as the reduce function.
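The map/reduce structure can be sketched in plain Python (toy numbers, not the real TRAK computation — the actual scoring happens inside the TRAKer):

```python
# Toy sketch of the map/reduce structure behind this tutorial:
# "map" produces per-checkpoint partial scores, "reduce" combines them.

def featurize_and_score(model_id, n_train=4, n_val=2):
    """Stand-in for featurize_and_score.py: returns an (n_train x n_val)
    matrix of partial scores for one checkpoint (here: fake numbers)."""
    return [[model_id + i * 0.1 + j for j in range(n_val)]
            for i in range(n_train)]

def gather(partial_scores):
    """Stand-in for gather.py: average partial scores across checkpoints."""
    n_ckpts = len(partial_scores)
    n_train, n_val = len(partial_scores[0]), len(partial_scores[0][0])
    return [[sum(s[i][j] for s in partial_scores) / n_ckpts
             for j in range(n_val)]
            for i in range(n_train)]

# In the SLURM setup below, each "map" call runs as a separate job.
partials = [featurize_and_score(mid) for mid in range(10)]
final_scores = gather(partials)
print(len(final_scores), len(final_scores[0]))  # 4 2
```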

1. Featurizing each checkpoint#

Everything needed for scoring prior to finalize_scores() will go in featurize_and_score.py. For example, it can look as follows:

 1import torch
 2from argparse import ArgumentParser
 3from trak import TRAKer
 4def main(model_id):
 5    model, loader_train, loader_val = ...
 6    # use model_id here to load the respective checkpoint, e.g.:
 7    ckpt = torch.load(f'/path/to/checkpoints/ckpt_{model_id}.pt')
 8
 9    traker = TRAKer(model=model,
10                    task='image_classification',
11                    train_set_size=len(loader_train.dataset))
12
13    traker.load_checkpoint(ckpt, model_id=model_id)
14    for batch in loader_train:
15        traker.featurize(batch=batch, ...)
16    traker.finalize_features(model_ids=[model_id])
17
18    traker.start_scoring_checkpoint(exp_name=..., checkpoint=ckpt, model_id=model_id, ...)
19    for batch in loader_val:
20        traker.score(batch=batch, ...)
21
22    # This will be called from gather.py instead.
23    # scores = traker.finalize_scores()
24
25if __name__ == "__main__":
26    parser = ArgumentParser()
27    parser.add_argument('--model_id', required=True, type=int)
28    args = parser.parse_args()
29    main(args.model_id)

2. Run featurize in parallel#

Now we can run the above script in parallel with a run.sbatch file. Here is a minimal example:

#!/bin/bash
#SBATCH --nodes=1
#SBATCH --cpus-per-task=8
#SBATCH --gres=gpu:a100:1
#SBATCH --array=0-9
#SBATCH --job-name=trak

MODEL_ID=$SLURM_ARRAY_TASK_ID

python featurize_and_score.py --model_id $MODEL_ID

The above script will submit 10 jobs in parallel for us: this is specified by the #SBATCH --array=0-9 directive. Each job will pass its array task ID as the model ID for TRAK. To learn more about sbatch options, check out SLURM’s docs.
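As a variation (hypothetical, not part of the example scripts above), featurize_and_score.py could read the array index directly from the environment variable SLURM sets for each array job, making the --model_id flag unnecessary:

```python
import os

# SLURM exports SLURM_ARRAY_TASK_ID to every job in an array; reading
# it directly is an alternative to passing --model_id on the command
# line. We fall back to 0 so the script also runs outside of SLURM.
model_id = int(os.environ.get('SLURM_ARRAY_TASK_ID', 0))
```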

Note that on line 16 of the example featurize_and_score.py above, we call finalize_features() with model_ids=[model_id]. This is important — if we don’t specify this, TRAK by default attempts to finalize the features for all model_ids (checkpoints) in the save_dir of the current TRAKer instance.

Running

sbatch run.sbatch

in the terminal will populate the specified save_dir with all intermediate results we need to compute the final TRAK scores.

3. Gather final scores#

The only thing left to do is call TRAKer.finalize_scores(). This method combines the scores across checkpoints (think of it as the gather, or reduce, step). This is what gather.py will do:

from trak import TRAKer

model = ...

traker = TRAKer(model=model, task='image_classification', ...)
scores = traker.finalize_scores(exp_name=...)
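With the final scores in hand (roughly one row per training example and one column per target example), a natural next step is ranking training examples by influence on a given target. A toy sketch, using a hypothetical hand-written matrix in place of the real `scores`:

```python
# Toy stand-in for the (train_set_size x num_targets) score matrix
# returned by finalize_scores(); real scores come from the TRAKer.
scores = [
    [0.9, -0.2],
    [0.1,  0.7],
    [-0.5, 0.3],
]

def top_influencers(scores, target_idx, k=2):
    """Indices of the k training examples with the highest score
    for the given target (column) index."""
    col = [row[target_idx] for row in scores]
    return sorted(range(len(col)), key=lambda i: col[i], reverse=True)[:k]

print(top_influencers(scores, target_idx=0))  # [0, 1]
```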

That’s it!

Note

Ease of parallelization was a priority for us when we designed TRAK. The above example uses SLURM to achieve parallelization, but it is definitely not the only option; for example, you should have no problem integrating TRAK with torch.distributed.