TRAK: Attributing Model Behavior at Scale
Overview
This is a PyTorch-based API for our method TRAK: an effective, efficient data attribution method for gradient-based learning algorithms. We designed TRAK's API around the following guiding principles:
- Ease of use
You can apply TRAK in just a few lines of code (see the quickstart guide).
- Speed
Our API comes with fast, custom CUDA kernels. Getting state-of-the-art attribution for BERT-base on QNLI takes ~2 hours on an 8xA100 node.
- Simplicity
Our API is lightweight: the entire codebase is less than 1000 lines of code. It is also quite modular, making it painless to adapt any component to your needs.
- Flexibility
Applying TRAK to a custom task/modality is easy (see, e.g., how to adapt TRAK to CLIP).
See the code
Install
The PyTorch-only version of our package can be installed using
pip install traker
To install the version of our package which contains a fast, custom CUDA kernel, use
pip install traker[fast]
See the Installation FAQs for more details.
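After running either command, a quick way to confirm that the package landed in the active environment is to query its installed version. The following is a minimal sketch using only the Python standard library; the helper name installed_version is hypothetical and not part of TRAK, and the check only shows that the traker distribution is present, not whether the optional CUDA kernel was built.

```python
from importlib import metadata


def installed_version(dist_name):
    """Return the installed version of a distribution, or None if it is absent.

    Hypothetical helper for illustration; it is not part of the TRAK API.
    """
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    # "traker" is the distribution name used in the pip commands above.
    version = installed_version("traker")
    if version is None:
        print("traker is not installed; try: pip install traker")
    else:
        print(f"traker {version} is installed")
```

Note that in shells where square brackets are glob characters (zsh, for example), the extras form usually needs quoting, as in pip install 'traker[fast]'.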
Warning
TRAK is under active development. We are still in a 0.x.x version and lots of things may change.