# Applying `TRAK` to a custom task #1: Classification

In this tutorial, we’ll demonstrate how to apply `TRAK` to a custom task, using classification as an example.

Applying `TRAK` to a new task requires defining an appropriate **model output function**, which is implemented by extending `AbstractModelOutput`.
First, we’ll conceptually go over what a model output function is. Then, we will see how it is implemented inside `TRAK` for the case of (image) classification.

The `TRAK` library already ships with implementations of `AbstractModelOutput` for several standard tasks. For example, to use the one corresponding to standard classification (for tasks with a single input, e.g., image classification), you simply specify the task as follows:

```
traker = TRAKer(..., task="image_classification")
```

## Prelim: Model output functions

Computing `TRAK` scores requires specifying a **model output function** that you want to attribute. Intuitively, you can think of it as some kind of loss or scoring function evaluated on an example.

More formally, given:

- an example of interest \(z\) (e.g., an input-label pair) and
- model parameters \(\theta\),

the model output function \(f(z;\theta)\) computes a real number based on evaluating the model on example \(z\).

For example, one choice of model output function could be the *loss* \(L(z)\)
that the model incurs on example \(z\) (e.g., the cross-entropy loss).
We motivate and derive appropriate model output
functions for several standard tasks (binary and multiclass classification, CLIP loss,
and some NLP tasks) in detail in our paper.

Given a model output function \(f(\cdot;\theta)\) and a target example \(z\) of interest, `TRAK` computes an *attribution score* for each training example \(z_i\), indicating its importance to \(f(z;\theta)\).
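To make the notion of a model output function concrete, here is a minimal pure-Python sketch. It does not use `TRAK`: the toy linear classifier, the weights `theta`, and the function name `cross_entropy_output` are all hypothetical and only illustrate that \(f(z;\theta)\) maps an example and a set of parameters to a single real number (here, the cross-entropy loss).

```python
import math


def cross_entropy_output(z, theta):
    """A model output function f(z; theta): cross-entropy loss of a toy linear classifier on z."""
    x, y = z
    # One logit per class: dot product of that class's weight row with the input.
    logits = [sum(w * xi for w, xi in zip(row, x)) for row in theta]
    # Stable log-sum-exp of the logits.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(g - m) for g in logits))
    # -log softmax(logits)[y]
    return log_z - logits[y]


theta = [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]  # hypothetical weights for 3 classes
z = ([2.0, 0.5], 0)                             # example z = (input, label)
value = cross_entropy_output(z, theta)
print(value)  # a single real number
```

Any other scalar function of the example and the parameters could play the same role; the rest of this tutorial uses a different choice that is better suited to classification.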

## Implementing model output functions in `TRAK`

In order for `TRAKer` to compute attribution scores, it needs access to the following two functions:

- The model output function itself, i.e., \(f(z;\theta)\).
- The gradient of the (training) loss with respect to the model output function, i.e., \(\frac{\partial L(z;\theta)}{\partial f}\). We refer to this function simply as the *output-to-loss gradient*.

We provide a dedicated class, `AbstractModelOutput`, that computes the above two functions from a model (a `torch.nn.Module` instance) using the following two methods:

- The `AbstractModelOutput.get_output()` method implements the model output function: given a batch of examples, it returns a vector containing the model outputs for each example in the batch. This is the function that `TRAKer` computes gradients of.
- The `AbstractModelOutput.get_out_to_loss_grad()` method implements the output-to-loss gradient. Since we could derive this term analytically for all the examples in our paper, we “hardcode” it in the `get_out_to_loss_grad` method, thus avoiding an additional gradient computation.

Note

If you find yourself in the (likely rare) situation where you can’t analytically derive the output-to-loss gradient, you can implement `AbstractModelOutput.get_out_to_loss_grad()` by first computing the model output as in `AbstractModelOutput.get_output()` and using `autograd` to compute the output-to-loss gradient.
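As an illustration of the autograd route described in the note above, here is a minimal sketch for the classification setup covered later in this tutorial. It is not part of `TRAK`: the helper names `margin_output` and `out_to_loss_grad_autograd` are hypothetical, and the sketch assumes the cross-entropy loss \(L = -\log p\) and the model output \(f = \log(p/(1-p))\), for which \(L = \log(1 + \exp(-f))\). It works on raw logits rather than on a model, to keep the example self-contained.

```python
import torch
import torch.nn.functional as F


def margin_output(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Per-example model output f = g_y - logsumexp_{j != y} g_j."""
    idx = torch.arange(logits.shape[0], device=logits.device)
    correct = logits[idx, labels]
    masked = logits.clone()
    masked[idx, labels] = float("-inf")  # drop the correct class from the sum
    return correct - masked.logsumexp(dim=-1)


def out_to_loss_grad_autograd(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """dL/df computed with autograd, where L = softplus(-f) = log(1 + exp(-f))."""
    # Treat f as a leaf tensor so the gradient is taken w.r.t. the model output itself.
    f = margin_output(logits, labels).detach().requires_grad_(True)
    loss = F.softplus(-f).sum()
    (grad,) = torch.autograd.grad(loss, f)
    return grad.unsqueeze(-1)


torch.manual_seed(0)
logits = torch.randn(4, 3)           # hypothetical batch: 4 examples, 3 classes
labels = torch.tensor([0, 2, 1, 1])
grad = out_to_loss_grad_autograd(logits, labels)

# Compare against the closed form -(1 - p), with p the soft-max probability of the label.
p = torch.softmax(logits, dim=-1)[torch.arange(4), labels]
print(torch.allclose(grad.squeeze(-1), -(1 - p), atol=1e-5))
```

For a loss without a convenient closed form in terms of \(f\), the same pattern applies as long as the loss can be written as a differentiable function of the model output.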

So to apply `TRAK` to a new task, all you have to do is extend `AbstractModelOutput` and implement the above two functions, then pass in the new model output object as the `task` when instantiating `TRAKer`:

```
class CustomModelOutput(AbstractModelOutput):
    def get_output(...):
        # Implement

    def get_out_to_loss_grad(...):
        # Implement


traker = TRAKer(model=model,
                task=CustomModelOutput,
                ...)
```

Note

If you implement an `AbstractModelOutput` for a common task or objective that you think may be useful to others, please open a pull request and we can include it as a default (so that you can just specify the `task` as a string).

## Example: Classification

To illustrate how to implement `AbstractModelOutput`, we’ll look at the example of standard classification, where the model is optimized to minimize the cross-entropy loss:

\[L(z;\theta) = -\log(p(z;\theta)),\]

where \(p(z;\theta)\) is the soft-max probability associated with the correct class \(y\) for example \(z=(x,y)\).

For classification, we use the following model output function:

\[f(z;\theta) = \log\left(\frac{p(z;\theta)}{1 - p(z;\theta)}\right)\]

Note

This is the natural analog to the logit function in binary logistic regression. See Section 3 in our paper for an explanation of why this is an appropriate choice.

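Since \(p(z;\theta) = (1 + \exp(-f))^{-1}\) for this choice of \(f\), the loss can be written as \(L = \log(1 + \exp(-f))\). The corresponding output-to-loss gradient is given by:

\[\frac{\partial L(z;\theta)}{\partial f} = \frac{\partial}{\partial f} \log(1 + \exp(-f)) = -\frac{\exp(-f)}{1 + \exp(-f)} = -(1 - p(z;\theta))\]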

Note

Note that \(p\) here is the soft-max probability associated with the correct class \(y\) for example \(z=(x,y)\), and *not* the logit associated with the correct class. In other words, \(p(z;\theta) = \frac{\exp(g_y)}{\sum_j \exp(g_j)}\) where \(g(z;\theta)\) is the vector of logits for example \(z=(x,y)\). Thus, the model output function in terms of logits is given by:

\[f(z;\theta) = \log\left(\frac{p(z;\theta)}{1 - p(z;\theta)}\right) = g_y - \log\sum_{j \neq y} \exp(g_j)\]

We can implement the above model output function in a numerically stable way using `torch`’s built-in `logsumexp` function.
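As a quick sanity check of the logit form, the following pure-Python sketch computes the model output as \(g_y - \log\sum_{j \neq y}\exp(g_j)\) with a stabilized log-sum-exp and compares it against \(\log(p/(1-p))\) computed directly from the soft-max probability. It is independent of `TRAK`; the function names and the example logits are hypothetical.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def margin_from_logits(logits, label):
    """Model output f = g_y - log(sum_{j != y} exp(g_j))."""
    others = [g for j, g in enumerate(logits) if j != label]
    return logits[label] - logsumexp(others)


def margin_from_softmax(logits, label):
    """Same quantity, computed as log(p / (1 - p)) from the soft-max probability."""
    p = math.exp(logits[label] - logsumexp(logits))
    return math.log(p / (1.0 - p))


logits = [2.0, -1.0, 0.5]
label = 0
f_logits = margin_from_logits(logits, label)
f_softmax = margin_from_softmax(logits, label)
print(f_logits, f_softmax)  # the two values agree up to floating-point error

# The logit form stays finite even when the correct-class logit dominates.
f_large = margin_from_logits([1000.0, 0.0, -5.0], 0)
print(f_large)
```

The second call shows why the logit form is preferable in practice: for those logits \(p\) rounds to exactly 1 in floating point, so evaluating \(\log(p/(1-p))\) directly would divide by zero.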

### Implementation

For the above choice of model output function, `TRAK` provides a default implementation as `ImageClassificationModelOutput`.
Below, we reproduce the implementation so that you can see how it works.
The model output function is implemented as follows:

```
from typing import Iterable

import torch as ch  # `ch` is the alias for torch used in this snippet
from torch import Tensor
from torch.nn import Module


def get_output(model: Module,
               weights: Iterable[Tensor],
               buffers: Iterable[Tensor],
               image: Tensor,
               label: Tensor):
    # stateless forward pass on a single example (unsqueeze adds the batch dimension)
    logits = ch.func.functional_call(model, (weights, buffers), image.unsqueeze(0))
    bindex = ch.arange(logits.shape[0]).to(logits.device, non_blocking=False)
    logits_correct = logits[bindex, label.unsqueeze(0)]

    cloned_logits = logits.clone()
    # remove the logits of the correct labels from the sum
    # in logsumexp by setting to -ch.inf
    cloned_logits[bindex, label.unsqueeze(0)] = ch.tensor(-ch.inf, device=logits.device, dtype=logits.dtype)

    margins = logits_correct - cloned_logits.logsumexp(dim=-1)
    return margins.sum()
```

Note that the `get_output` function uses `torch.func`’s `functional_call` to make a stateless forward pass.

Note

In `TRAK`, we use `torch.func`’s `vmap` to make the per-sample gradient computations faster. Check out, e.g., this torch.func tutorial to learn more about how to use `torch.func`.

Similarly, the output-to-loss gradient function is implemented as follows:

```
def get_out_to_loss_grad(self, model, weights, buffers, batch):
    images, labels = batch
    logits = ch.func.functional_call(model, (weights, buffers), images)
    # here we are directly implementing the gradient instead of relying on autodiff to do
    # that for us
    ps = self.softmax(logits / self.loss_temperature)[ch.arange(logits.size(0)), labels]
    return (1 - ps).clone().detach().unsqueeze(-1)
```

Note that we are directly implementing the gradient we analytically derived above (instead of using automatic differentiation).
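The analytic gradient can be checked numerically. The sketch below is a standalone pure-Python check, not part of `TRAK`, and its function names are hypothetical. It assumes the cross-entropy loss written in terms of the model output, \(L = \log(1 + \exp(-f))\), and compares the closed form \(\partial L / \partial f = -(1 - p)\) against a central finite difference. Note that the snippet above returns `1 - ps`, which is the same quantity without the leading minus sign when `loss_temperature` is 1.

```python
import math


def loss_from_output(f):
    """Cross-entropy loss in terms of the model output: L = log(1 + exp(-f))."""
    return math.log1p(math.exp(-f))


def analytic_out_to_loss_grad(f):
    """Closed-form derivative dL/df = -(1 - p), where p = 1 / (1 + exp(-f))."""
    p = 1.0 / (1.0 + math.exp(-f))
    return -(1.0 - p)


def numeric_out_to_loss_grad(f, eps=1e-6):
    """Central finite-difference approximation of dL/df."""
    return (loss_from_output(f + eps) - loss_from_output(f - eps)) / (2.0 * eps)


test_points = (-3.0, 0.0, 1.3, 4.0)
for f in test_points:
    analytic = analytic_out_to_loss_grad(f)
    numeric = numeric_out_to_loss_grad(f)
    print(f"f={f:+.1f}  analytic={analytic:+.6f}  numeric={numeric:+.6f}")
```

The two columns should agree to several decimal places, which is a cheap way to validate a hand-derived output-to-loss gradient before hardcoding it.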

That’s all!
Though we showed how `ImageClassificationModelOutput` is implemented internally, to use it you just need to specify `task="image_classification"` when instantiating `TRAKer`.

## Extending to other tasks

For more examples, see Applying TRAK to a custom task #2: Text Classification using BERT and Applying TRAK to a custom task #3: CLIP.