Applying TRAK to a custom task #1: Classification

In this tutorial, we’ll demonstrate how to apply TRAK to a custom task, using classification as an example.

Applying TRAK to a new task requires defining an appropriate model output function, which is implemented by extending AbstractModelOutput. First, we’ll conceptually go over what a model output function is. Then, we will see how it is implemented inside TRAK for the case of (image) classification.

The TRAK library already ships with an implementation of AbstractModelOutput for several standard tasks. For example, to use the one corresponding to standard classification (for tasks with a single input, e.g., image classification), you simply specify the task as follows:

traker = TRAKer(..., task="image_classification")

Prelim: Model output functions

Computing TRAK scores requires specifying a model output function that you want to attribute. Intuitively, you can think of it as some kind of loss or scoring function evaluated on an example.

More formally, given:

  • an example of interest \(z\) (e.g., an input-label pair) and

  • model parameters \(\theta\),

the model output function \(f(z;\theta)\) computes a real number based on evaluating the model on example \(z\).

For example, one choice of model output function could be the loss \(L(z)\) that the model incurs on example \(z\) (e.g., the cross-entropy loss). We motivate and derive appropriate model output functions for several standard tasks (binary and multiclass classification, CLIP loss, and some NLP tasks) in detail in our paper.

Given a model output function \(f(\cdot;\theta)\) and a target example \(z\) of interest, TRAK computes, for each training example \(z_i\), an attribution score indicating its importance to \(f(z;\theta)\).
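To make the definition concrete, here is a minimal sketch in plain Python of a model output function that returns the cross-entropy loss on a single example. The names toy_model and model_output_loss and the tiny linear classifier are hypothetical illustrations, not part of TRAK.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def toy_model(theta, x):
    """Hypothetical linear classifier: one weight row per class."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in theta]

def model_output_loss(z, theta):
    """f(z; theta) chosen as the cross-entropy loss on z = (x, y)."""
    x, y = z
    probs = softmax(toy_model(theta, x))
    return -math.log(probs[y])

theta = [[1.0, 0.0], [0.0, 1.0]]  # two classes, two features
z = ([2.0, 0.0], 0)               # input x and its label y
f_value = model_output_loss(z, theta)  # a single real number
```

Whatever the task, the key property is the same: the function maps one example and the current parameters to one scalar.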

Implementing model output functions in TRAK

In order for TRAKer to compute attribution scores, it needs access to the following two functions:

  • The model output function itself, i.e., \(f(z;\theta)\)

  • The gradient of the (training) loss w.r.t. the model output function, i.e., \(\frac{\partial L(z;\theta)}{\partial f}\). We refer to this function simply as the output-to-loss gradient.

We provide a dedicated class, AbstractModelOutput, that computes the above two functions from a model (a torch.nn.Module instance) using the following two methods:

The AbstractModelOutput.get_output() method implements the model output function: given a batch of examples, it returns a vector containing the model outputs for each example in the batch. This is the function that TRAKer computes gradients of.

The AbstractModelOutput.get_out_to_loss_grad() method implements the output-to-loss gradient. Since for all the examples in our paper we could analytically derive this term, we “hardcode” this in the get_out_to_loss_grad method, thus avoiding an additional gradient computation.


If you find yourself in the (likely rare) situation where you can’t analytically derive the output-to-loss gradient, you can implement AbstractModelOutput.get_out_to_loss_grad() by first computing the model output as in AbstractModelOutput.get_output() and using autograd to compute the output-to-loss gradient.
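A minimal sketch of that autograd fallback is shown here, assuming the loss can be written as a differentiable function of the model output. The helper names loss_from_output and out_to_loss_grad_autograd are hypothetical, and the example loss \(\log(1 + \exp(-f))\) is only an illustration.

```python
import torch

def loss_from_output(f: torch.Tensor) -> torch.Tensor:
    # Assumed example loss: L(f) = log(1 + exp(-f)), i.e. softplus(-f).
    return torch.nn.functional.softplus(-f)

def out_to_loss_grad_autograd(outputs: torch.Tensor) -> torch.Tensor:
    # Treat the model outputs as leaf variables so autograd differentiates
    # the loss with respect to them, not with respect to model parameters.
    f = outputs.detach().clone().requires_grad_(True)
    # Each loss term depends only on its own output, so the gradient of the
    # sum holds the per-example derivatives.
    loss = loss_from_output(f).sum()
    (grad,) = torch.autograd.grad(loss, f)
    return grad.unsqueeze(-1)  # shape (batch, 1)

outputs = torch.tensor([-1.0, 0.0, 2.5])
grads = out_to_loss_grad_autograd(outputs)
```

Note that autograd returns the signed derivative, so check which sign convention the rest of your pipeline expects before plugging in a function like this.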

So to apply TRAK to a new task, all you have to do is extend AbstractModelOutput and implement the above two functions, then pass in the new model output object as the task when instantiating TRAKer:

class CustomModelOutput(AbstractModelOutput):
    def get_output(...):
        # Implement

    def get_out_to_loss_grad(...):
        # Implement

traker = TRAKer(model=model,
                task=CustomModelOutput(),  # the new model output object
                ...)


If you implement an AbstractModelOutput for a common task or objective that you think may be useful to others, please make a pull request and we can include it as a default (so that you can just specify the task as a string).

Example: Classification

To illustrate how to implement AbstractModelOutput, we’ll look at the example of standard classification, where the model is optimized to minimize the cross-entropy loss:

\[L(z;\theta) = -\log(p(z;\theta))\]

where \(p(z;\theta)\) is the soft-max probability associated with the correct class \(y\) for example \(z=(x,y)\).

For classification, we use the following model output function:

\[f(z;\theta) = \log\left(\frac{p(z;\theta)}{1 - p(z;\theta)}\right)\]


This is the natural analog to the logit function in binary logistic regression. See Section 3 in our paper for an explanation of why this is an appropriate choice.

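Inverting the definition of \(f\) gives \(p(z;\theta) = \frac{1}{1 + \exp(-f)}\), so the loss can be written as \(L(z;\theta) = -\log(p(z;\theta)) = \log(1 + \exp(-f))\). The corresponding output-to-loss gradient is given by:

\[\frac{\partial L(z;\theta)}{\partial f} = \frac{\partial}{\partial f} \log(1 + \exp(-f)) = -\frac{\exp(-f)}{1 + \exp(-f)} = -(1 - p(z;\theta))\]

The implementation reproduced below returns the magnitude of this quantity, \(1 - p(z;\theta)\).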


Note that \(p\) here is the soft-max probability associated with the correct class \(y\) for example \(z=(x,y)\), and not the logit associated with the correct class. In other words, \(p(z;\theta) = \frac{\exp(g_y)}{\sum_j \exp(g_j)}\) where \(g(z;\theta)\) is the vector of logits for example \(z=(x,y)\). Thus, the model output function in terms of logits is given by:

\[\begin{aligned} f(z;\theta) &= \log\left(\frac{\exp(g_y)/\sum_j \exp(g_j)}{1-\exp(g_y)/\sum_j \exp(g_j)}\right) \\ &= \log\left(\frac{\exp(g_y)/\sum_j \exp(g_j)}{\left(\sum_j \exp(g_j)-\exp(g_y)\right)/\sum_j \exp(g_j)}\right) \\ &= \log\left(\frac{\exp(g_y)}{\sum_{j\neq y} \exp(g_j)}\right) \\ &= g_y - \log\left(\sum_{j\neq y} \exp(g_j)\right) \end{aligned}\]

We can implement the above model output function in a numerically stable way using the built-in torch logsumexp function.
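As an illustrative sketch (not TRAK's own code), the snippet here compares a direct transcription of \(\log(p/(1-p))\) with the logsumexp form on a small batch. The function names margin_naive and margin_stable are hypothetical.

```python
import torch

def margin_naive(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Direct transcription of f = log(p / (1 - p)); loses precision as p -> 1.
    p = torch.softmax(logits, dim=-1).gather(1, labels.unsqueeze(1)).squeeze(1)
    return torch.log(p / (1 - p))

def margin_stable(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # f = g_y - logsumexp over j != y; the correct class is masked with -inf
    # so that it contributes exp(-inf) = 0 to the sum.
    idx = torch.arange(logits.shape[0])
    masked = logits.clone()
    masked[idx, labels] = float("-inf")
    return logits[idx, labels] - masked.logsumexp(dim=-1)

logits = torch.tensor([[2.0, 0.5, -1.0], [0.1, 0.2, 0.3]])
labels = torch.tensor([0, 2])
stable = margin_stable(logits, labels)
naive = margin_naive(logits, labels)

# With a very confident prediction, 1 - p rounds to zero in float32, while
# the logsumexp form still evaluates to g_y - log(2).
big_logits = torch.tensor([[100.0, 0.0, 0.0]])
big_labels = torch.tensor([0])
stable_big = margin_stable(big_logits, big_labels)
```

On moderate logits the two forms agree; the difference only shows up when the correct-class probability saturates.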


For the above choice of model output function, TRAK provides a default implementation as ImageClassificationModelOutput. Below, we reproduce it so that you can see how it works. The model output function is implemented as follows:

from typing import Iterable

import torch as ch
from torch import Tensor
from torch.nn import Module


def get_output(model: Module,
               weights: Iterable[Tensor],
               buffers: Iterable[Tensor],
               image: Tensor,
               label: Tensor):
    # stateless forward pass on a single example (unsqueeze adds the batch dimension)
    logits = ch.func.functional_call(model, (weights, buffers), image.unsqueeze(0))
    bindex = ch.arange(logits.shape[0]).to(logits.device, non_blocking=False)
    # logit g_y of the correct class
    logits_correct = logits[bindex, label.unsqueeze(0)]

    cloned_logits = logits.clone()
    # remove the logits of the correct labels from the sum
    # in logsumexp by setting to -ch.inf
    cloned_logits[bindex, label.unsqueeze(0)] = ch.tensor(-ch.inf, device=logits.device, dtype=logits.dtype)

    # f = g_y - log(sum_{j != y} exp(g_j))
    margins = logits_correct - cloned_logits.logsumexp(dim=-1)
    return margins.sum()

Note that the get_output function uses torch.func’s functional_call to make a stateless forward pass.


In TRAK, we use torch.func’s vmap to make the per-sample gradient computations faster. Check out, e.g., this torch.func tutorial to learn more about how to use torch.func.
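For orientation, here is a minimal sketch of per-sample gradients with torch.func, combining functional_call, grad, and vmap. It is not TRAK's internal code: the linear model is a hypothetical stand-in, and output_fn is a simplified margin function written with torch.where so that it avoids in-place indexing.

```python
import torch
from torch.func import functional_call, grad, vmap

torch.manual_seed(0)
model = torch.nn.Linear(4, 3)  # hypothetical stand-in for a real classifier
weights = {name: p.detach() for name, p in model.named_parameters()}
buffers = {name: b.detach() for name, b in model.named_buffers()}

def output_fn(weights, buffers, x, y):
    # Stateless forward pass on one example; unsqueeze adds the batch dimension.
    logits = functional_call(model, (weights, buffers), (x.unsqueeze(0),)).squeeze(0)
    is_correct = torch.arange(logits.shape[-1]) == y
    correct_logit = (logits * is_correct).sum()
    # Mask the correct class with -inf so logsumexp runs over j != y only.
    others = torch.where(is_correct, torch.full_like(logits, float("-inf")), logits)
    return correct_logit - others.logsumexp(dim=-1)

xs = torch.randn(5, 4)
ys = torch.tensor([0, 2, 1, 1, 0])

# grad differentiates w.r.t. the first argument (the weights dict);
# vmap maps over the example dimension of xs and ys only.
per_sample_grads = vmap(grad(output_fn), in_dims=(None, None, 0, 0))(weights, buffers, xs, ys)
```

The result is a dict keyed like the model's parameters, where each entry has an extra leading dimension indexing the examples.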

Similarly, the output-to-loss gradient function is implemented as follows:

def get_out_to_loss_grad(self, model, weights, buffers, batch):
    images, labels = batch
    logits = ch.func.functional_call(model, (weights, buffers), images)
    # here we are directly implementing the gradient instead of relying on autodiff to do
    # that for us
    ps = self.softmax(logits / self.loss_temperature)[ch.arange(logits.size(0)), labels]
    return (1 - ps).clone().detach().unsqueeze(-1)

Note that we are directly implementing the gradient we analytically derived above (instead of using automatic differentiation).

That’s all! Although we showed how ImageClassificationModelOutput is implemented internally, to use it you just need to specify task="image_classification" when instantiating TRAKer.

Extending to other tasks

For more examples, see Applying TRAK to a custom task #2: Text Classification using BERT and Applying TRAK to a custom task #3: CLIP.