Parallelize TRAK scoring with SLURM
Often we would like to compute TRAK scores from multiple checkpoints of
the same model.
Note
Check our paper to see why using multiple checkpoints helps improve TRAK’s performance.
This means that we need to run TRAKer.featurize() for all
training examples for each checkpoint. But fortunately, this is a highly parallelizable
process!
Below, we sketch a simple way of parallelizing featurize() and
score() across checkpoints. We’ll use SLURM — a popular job scheduling
system.
Note
You can find all the code for this example here. We’ll
skip some details in the post to highlight the main ideas behind using
TRAK with SLURM.
Overall, we’ll write three files:
featurize_and_score.py
run.sbatch
gather.py
We will use run.sbatch to run different instances of featurize_and_score.py
in parallel, and get the final TRAK scores using gather.py.
Note
In terms of MapReduce, you can think of featurize_and_score as the map function and gather as the reduce function.
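To make the analogy concrete, here is a toy sketch of the same map/reduce shape in plain Python. It does not use TRAK: fake_scores_for_checkpoint and gather_scores are hypothetical stand-ins, and the simple average is only an illustration, not TRAK's actual rule for combining checkpoints.

```python
from concurrent.futures import ThreadPoolExecutor


def fake_scores_for_checkpoint(model_id):
    """Map step (hypothetical stand-in for featurize_and_score.py)."""
    # A 3 x 2 "score matrix" (3 training examples, 2 targets) whose
    # entries simply equal the model ID, so the result is easy to check.
    return [[float(model_id)] * 2 for _ in range(3)]


def gather_scores(per_checkpoint):
    """Reduce step (illustrative only): element-wise mean over checkpoints."""
    n = len(per_checkpoint)
    rows, cols = len(per_checkpoint[0]), len(per_checkpoint[0][0])
    return [
        [sum(scores[i][j] for scores in per_checkpoint) / n for j in range(cols)]
        for i in range(rows)
    ]


# Checkpoints are independent of each other, so the map step can run in parallel.
with ThreadPoolExecutor() as pool:
    per_checkpoint = list(pool.map(fake_scores_for_checkpoint, range(4)))

final_scores = gather_scores(per_checkpoint)
```

With SLURM, the thread pool is replaced by a job array and the list of per-checkpoint results by files in the TRAKer's save_dir.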
1. Featurizing each checkpoint
Everything needed for scoring prior to finalize_scores() will go in
featurize_and_score.py.
For example, featurize_and_score.py can be as follows:
from argparse import ArgumentParser
import torch
from trak import TRAKer

def main(model_id):
    model, loader_train, loader_val = ...  # build the model and data loaders
    # use model_id here to load the respective checkpoint, e.g.:
    ckpt = torch.load(f'/path/to/checkpoints/ckpt_{model_id}.pt')

    traker = TRAKer(model=model, task='image_classification',
                    train_set_size=len(loader_train.dataset))

    traker.load_checkpoint(ckpt, model_id=model_id)
    for batch in loader_train:
        traker.featurize(batch=batch, ...)
    traker.finalize_features(model_ids=[model_id])

    # Score the validation targets against this checkpoint's features.
    traker.start_scoring_checkpoint(exp_name=..., checkpoint=ckpt, model_id=model_id, ...)
    for batch in loader_val:
        traker.score(batch=batch, ...)

    # This will be called from gather.py instead.
    # scores = traker.finalize_scores()

if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument('--model_id', required=True, type=int)
    args = parser.parse_args()
    main(args.model_id)
2. Run featurize in parallel
Now we can run the above script in parallel with a run.sbatch.
Here is a minimal example:
#!/bin/bash
#SBATCH --nodes=1
#SBATCH --cpus-per-task=8
#SBATCH --gres=gpu:a100:1
#SBATCH --array=0-9
#SBATCH --job-name=trak

# Each array task uses its own index as the TRAK model ID.
MODEL_ID=$SLURM_ARRAY_TASK_ID
python featurize_and_score.py --model_id "$MODEL_ID"
The above script will submit 10 jobs in parallel for us: this is specified by the
#SBATCH --array=0-9 directive. Each job will pass in its array task ID
(SLURM_ARRAY_TASK_ID) as a model ID for TRAK. To learn more about sbatch, check out
SLURM's docs.
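As an alternative to passing the ID on the command line, the script could read the array task ID directly from the environment. The sketch below assumes a hypothetical helper, resolve_model_id, that is not part of TRAK; SLURM_ARRAY_TASK_ID is the environment variable SLURM sets for each task of a job array.

```python
import os


def resolve_model_id(cli_value=None, environ=os.environ):
    """Return the model ID from --model_id if given, else from SLURM.

    resolve_model_id is a hypothetical helper for illustration only.
    """
    if cli_value is not None:
        return int(cli_value)
    task_id = environ.get("SLURM_ARRAY_TASK_ID")
    if task_id is None:
        raise ValueError(
            "No --model_id given and SLURM_ARRAY_TASK_ID is not set; "
            "are we running inside a SLURM job array?"
        )
    return int(task_id)
```

Failing loudly when neither source is available avoids silently featurizing the wrong checkpoint when the script is launched outside a job array.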
Note that on line 16 of the example featurize_and_score.py above, we
call finalize_features() with model_ids=[model_id]. This is
important: if we don’t specify this, TRAK by default attempts to
finalize the features for all model_ids (checkpoints) in the
save_dir of the current TRAKer instance, including checkpoints that
other jobs may still be processing.
Running
sbatch run.sbatch
in the terminal will populate the specified save_dir with all
intermediate results we need to compute the final TRAK scores.
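The gather step should only start once every array task has finished. One way to arrange this, sketched below, is to submit the gather job with a SLURM dependency on the array job. This is an assumption about the workflow rather than something the example above does: gather.sbatch is a hypothetical batch script that runs gather.py, and the helper names are made up for illustration. The --parsable and --dependency=afterok options are standard sbatch flags.

```python
import subprocess


def parse_job_id(sbatch_output):
    """Extract the job ID from `sbatch --parsable` output ("id" or "id;cluster")."""
    return sbatch_output.strip().split(";")[0]


def gather_command(array_job_id, gather_script="gather.sbatch"):
    """Build an sbatch call that waits for the whole array to succeed.

    gather.sbatch is a hypothetical script wrapping `python gather.py`.
    """
    return ["sbatch", "--parsable",
            f"--dependency=afterok:{array_job_id}", gather_script]


def submit(cmd):
    """Run an sbatch command and return the ID of the submitted job."""
    result = subprocess.run(cmd, check=True, capture_output=True, text=True)
    return parse_job_id(result.stdout)


# Usage on a machine where sbatch is available:
# array_job_id = submit(["sbatch", "--parsable", "run.sbatch"])
# gather_job_id = submit(gather_command(array_job_id))
```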
3. Gather final scores
The only thing left to do is call TRAKer.finalize_scores(). This method
combines the scores across checkpoints (think of it as a gather).
This is what gather.py will do:
from trak import TRAKer
model = ...
traker = TRAKer(model=model, task='image_classification', ...)
scores = traker.finalize_scores(exp_name=...)
That’s it!
Note
Ease of parallelization was a priority for us when we designed TRAK.
The above example uses SLURM to achieve parallelization, but SLURM is
definitely not the only option. For example, you should have no problems
integrating TRAK with torch.distributed.
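As a rough sketch of what that could look like, each process in a distributed job can take a disjoint slice of the checkpoints and run the same featurize-and-score loop on its slice. The helper below is hypothetical and uses plain integers; in a real torch.distributed job, rank and world_size would typically come from torch.distributed.get_rank() and torch.distributed.get_world_size().

```python
def model_ids_for_rank(num_checkpoints, rank, world_size):
    """Round-robin assignment of checkpoint IDs to one process.

    Hypothetical helper: rank r handles checkpoints r, r + world_size, ...
    """
    if not 0 <= rank < world_size:
        raise ValueError("rank must satisfy 0 <= rank < world_size")
    return list(range(rank, num_checkpoints, world_size))


# Example: 10 checkpoints split across 4 processes.
assignment = {rank: model_ids_for_rank(10, rank, 4) for rank in range(4)}
```

Every checkpoint is assigned to exactly one rank, which plays the same role as the array task ID in the SLURM setup.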