TRAK: Attributing Model Behavior at Scale

See also

Check out our paper and blog post!


This is a PyTorch-based API for our method TRAK: an effective, efficient data attribution method for gradient-based learning algorithms. We designed TRAK’s API around the following guiding principles:

Ease of use

You can apply TRAK in just a few lines of code (see the quickstart guide).


Our API comes with fast, custom CUDA kernels. Getting state-of-the-art attribution for BERT-base on QNLI takes ~2 hours on an 8xA100 node.


Our API is lightweight: the entire codebase is fewer than 1000 lines of code. It is also quite modular, making it painless to adapt any component to your needs.


Applying TRAK to a custom task/modality is easy (check, e.g., how to adapt TRAK to CLIP).


The PyTorch-only version of our package can be installed using

pip install traker

To install the version of our package which contains a fast, custom CUDA kernel, use

pip install "traker[fast]"

See the Installation FAQs for more details.
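To confirm that the install succeeded, a small check like the one below can help. It is a minimal sketch using only the Python standard library. It assumes the distribution name is `traker` (as in the pip commands above) and that the importable module is named `trak`; the module name is an assumption not stated on this page. The helper names `installed_version` and `is_importable` are hypothetical and are not part of the TRAK API.

```python
from importlib import metadata
from importlib.util import find_spec
from typing import Optional


def installed_version(dist_name: str = "traker") -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent.

    "traker" is the name used in the pip commands on this page.
    """
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def is_importable(module_name: str = "trak") -> bool:
    """Report whether a top-level module can be found on the import path.

    The default "trak" is an assumed module name; adjust it if it differs.
    """
    return find_spec(module_name) is not None


if __name__ == "__main__":
    version = installed_version()
    if version is None:
        print("traker is not installed in this environment")
    else:
        print(f"traker {version} is installed; importable: {is_importable()}")
```

This only reports whether the package is present in the current environment. It does not tell you whether the custom CUDA kernel from the `[fast]` extra was built correctly; the Installation FAQs are the place to look for that.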


TRAK is under active development. The package is still at version 0.x.x, so the API may change between releases.

