Applying TRAK to a custom task #1: Classification
In this tutorial, we’ll demonstrate how to apply TRAK to a custom task, using classification as an example.
Applying TRAK to a new task requires defining an appropriate model output function, which is implemented by extending AbstractModelOutput.
First, we’ll conceptually go over what a model output function is. Then, we’ll see how it is implemented inside TRAK for the case of (image) classification.
The TRAK library already ships with implementations of AbstractModelOutput for several standard tasks. For example, to use the one corresponding to standard classification (for tasks with a single input, e.g., image classification), you simply specify the task as follows:
traker = TRAKer(..., task="image_classification")
Prelim: Model output functions
Computing TRAK scores requires specifying the model output function whose value you want to attribute. Intuitively, you can think of it as some kind of loss or scoring function evaluated on an example.
More formally, given:
an example of interest \(z\) (e.g., an input-label pair) and
model parameters \(\theta\),
the model output function \(f(z;\theta)\) computes a real number based on evaluating the model on example \(z\).
For example, one choice of model output function could be the loss \(L(z)\) that the model incurs on example \(z\) (e.g., the cross-entropy loss). We motivate and derive appropriate model output functions for several standard tasks (binary and multiclass classification, CLIP loss, and some NLP tasks) in detail in our paper.
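As a toy illustration of the definition above (not TRAK code), here is a model output function for a hypothetical one-feature linear classifier with parameters \(\theta = (w, b)\) and labels \(y \in \{-1, +1\}\), using its logistic loss as the output. The function name model_output and the toy model are assumptions made only for this sketch.

```python
import math


def model_output(z, theta):
    """Toy model output f(z; theta): the logistic loss of a hypothetical
    linear classifier on example z = (x, y), with label y in {-1, +1}."""
    x, y = z
    w, b = theta
    score = w * x + b                         # raw prediction of the toy model
    return math.log1p(math.exp(-y * score))   # log(1 + exp(-y * score))


theta = (2.0, -1.0)  # toy parameters (w, b)
z = (1.0, 1.0)       # toy example: input 1.0 with label +1
print(model_output(z, theta))  # ≈ 0.3133, a single real number
```

Whatever the task, the key property is the same: the function maps one example and one set of parameters to a single real number, which TRAK then attributes to training examples.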
Given a model output function \(f(\cdot;\theta)\) and a target example \(z\) of interest, TRAK computes an attribution score for each training example \(z_i\), indicating its importance to \(f(z;\theta)\).
Implementing model output functions in TRAK
In order for TRAKer to compute attribution scores, it needs access to the following two functions:
The model output function itself, i.e., \(f(z;\theta)\)
The gradient of the (training) loss w.r.t. the model output function, i.e., \(\frac{\partial L(z;\theta)}{\partial f}\). We refer to this function simply as the output-to-loss gradient.
We provide a dedicated class, AbstractModelOutput, that computes the above two functions from a model (a torch.nn.Module instance) using the following two methods:
The AbstractModelOutput.get_output() method implements the model output function: given a batch of examples, it returns a vector containing the model outputs for each example in the batch. This is the function whose gradients TRAKer computes.
The AbstractModelOutput.get_out_to_loss_grad() method implements the output-to-loss gradient. Since we could derive this term analytically for all the examples in our paper, we “hardcode” it in the get_out_to_loss_grad method, avoiding an additional gradient computation.
Note
If you find yourself in the (likely rare) situation where you can’t derive the output-to-loss gradient analytically, you can implement AbstractModelOutput.get_out_to_loss_grad() by first computing the model output as in AbstractModelOutput.get_output() and then using autograd to compute the output-to-loss gradient.
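The autograd route mentioned in the note can be sketched as follows. This assumes (as in the classification example later in this tutorial) that the per-example loss can be written in terms of the model output as \(L = \log(1 + \exp(-f))\); the helper name out_to_loss_grad_autograd is hypothetical and not part of TRAK. In a real subclass, its input would be the outputs computed as in get_output().

```python
import torch as ch
import torch.nn.functional as F


def out_to_loss_grad_autograd(outputs: ch.Tensor) -> ch.Tensor:
    """Hypothetical helper: dL/df for each example via autograd, assuming
    the per-example loss is L(f) = log(1 + exp(-f)) = softplus(-f)."""
    # Detach so only the dependence of L on f is tracked, not the model graph.
    f = outputs.detach().clone().requires_grad_(True)
    loss = F.softplus(-f)
    # Each loss term depends only on its own f, so the gradient of the sum
    # yields the per-example derivatives dL_i/df_i.
    (grad,) = ch.autograd.grad(loss.sum(), f)
    return grad.unsqueeze(-1)  # shape (batch_size, 1)


f = ch.tensor([-2.0, 0.0, 3.0])  # made-up model outputs
print(out_to_loss_grad_autograd(f))
```

For this particular loss the result matches the closed form \(-(1 - p)\) with \(p = 1/(1 + \exp(-f))\), so the extra backward pass buys nothing here; it only pays off when no closed form is available.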
So to apply TRAK to a new task, all you have to do is extend AbstractModelOutput, implement the above two methods, and pass an instance of the new class as the task when instantiating TRAKer:
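from trak import TRAKer
from trak.modelout_functions import AbstractModelOutput
class CustomModelOutput(AbstractModelOutput):
    def __init__(self):
        super().__init__()
    def get_output(self, model, weights, buffers, *batch):
        # Implement: return the model output f(z; theta)
        ...
    def get_out_to_loss_grad(self, model, weights, buffers, batch):
        # Implement: return the output-to-loss gradient dL/df
        ...
traker = TRAKer(model=model,
                task=CustomModelOutput(),
                ...)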
Note
If you implement an AbstractModelOutput for a common task or objective that you think may be useful to others, please make a pull request and we can include it as a default (so that you can just specify the task as a string).
Example: Classification
To illustrate how to implement AbstractModelOutput, we’ll look at the example of standard classification, where the model is optimized to minimize the cross-entropy loss:
\[L(z;\theta) = -\log\big(p(z;\theta)\big),\]
where \(p(z;\theta)\) is the soft-max probability associated with the correct class \(y\) for example \(z=(x,y)\).
For classification, we use the following model output function:
\[f(z;\theta) = \log\left(\frac{p(z;\theta)}{1 - p(z;\theta)}\right)\]
Note
This is the natural analog to the logit function in binary logistic regression. See Section 3 in our paper for an explanation of why this is an appropriate choice.
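The corresponding output-to-loss gradient is given by:
\[\frac{\partial L(z;\theta)}{\partial f} = \frac{\partial}{\partial f}\log\big(1 + \exp(-f)\big) = -\frac{\exp(-f)}{1 + \exp(-f)} = -\big(1 - p(z;\theta)\big),\]
where the first step uses \(p(z;\theta) = 1/\big(1 + \exp(-f(z;\theta))\big)\), so that \(L(z;\theta) = \log\big(1 + \exp(-f(z;\theta))\big)\).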
Note
Note that \(p\) here is the soft-max probability associated with the correct class \(y\) for example \(z=(x,y)\), and not the logit associated with the correct class. In other words, \(p(z;\theta) = \frac{\exp(g_y)}{\sum_j \exp(g_j)}\) where \(g(z;\theta)\) is the vector of logits for example \(z=(x,y)\). Thus, the model output function in terms of logits is given by:
\[f(z;\theta) = g_y - \log\left(\sum_{j \neq y} \exp(g_j)\right)\]
We can implement the above model output function in a numerically stable way using torch’s built-in logsumexp function.
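To see that the logit form agrees with the probability form, and why the stable form matters, here is a sketch in plain Python with made-up logits. The helper names are hypothetical, and the hand-written logsumexp is a stand-in for torch’s that subtracts the maximum before exponentiating.

```python
import math


def logsumexp(values):
    """Stable log(sum(exp(v))): subtract the max before exponentiating."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def margin_stable(logits, y):
    # f = g_y - log(sum_{j != y} exp(g_j))
    others = [g for j, g in enumerate(logits) if j != y]
    return logits[y] - logsumexp(others)


def margin_from_probs(logits, y):
    # f = log(p / (1 - p)), with p the (naively computed) softmax prob of class y
    exps = [math.exp(g) for g in logits]
    p = exps[y] / sum(exps)
    return math.log(p / (1 - p))


logits = [2.0, 0.5, -1.0]
print(margin_stable(logits, 0), margin_from_probs(logits, 0))  # equal up to rounding

big = [1000.0, 990.0, 980.0]
print(margin_stable(big, 0))  # ≈ 9.99995, finite
# margin_from_probs(big, 0) raises OverflowError, since math.exp(1000.0) overflows
```

The stable version never exponentiates a large logit directly, which is why the implementation below masks out the correct class and calls logsumexp instead of forming \(p\) and \(1 - p\).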
Implementation
For the above choice of model output function, TRAK provides a default implementation, ImageClassificationModelOutput. Below, we reproduce it so that you can see how it works.
The model output function is implemented as follows:
from typing import Iterable
import torch as ch
from torch import Tensor
from torch.nn import Module
def get_output(model: Module,
               weights: Iterable[Tensor],
               buffers: Iterable[Tensor],
               image: Tensor,
               label: Tensor):
    # stateless forward pass on a single example (add a batch dimension of 1)
    logits = ch.func.functional_call(model, (weights, buffers), image.unsqueeze(0))
    bindex = ch.arange(logits.shape[0]).to(logits.device, non_blocking=False)
    logits_correct = logits[bindex, label.unsqueeze(0)]
    cloned_logits = logits.clone()
    # remove the logits of the correct labels from the sum
    # in logsumexp by setting to -ch.inf
    cloned_logits[bindex, label.unsqueeze(0)] = ch.tensor(-ch.inf, device=logits.device, dtype=logits.dtype)
    # margin f = g_y - log(sum_{j != y} exp(g_j))
    margins = logits_correct - cloned_logits.logsumexp(dim=-1)
    return margins.sum()
Note that the get_output function uses torch.func’s functional_call to make a stateless forward pass.
Note
In TRAK, we use torch.func’s vmap to make the per-sample gradient computations faster. Check out, e.g., this torch.func tutorial to learn more about how to use torch.func.
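For readers new to torch.func, here is a minimal sketch of how a single-example output like get_output can be turned into per-example gradients with torch.func.grad and vmap. The toy linear model, the function name margin, and the random batch are assumptions for illustration; this is not TRAK’s internal featurization code. To keep the sketch simple under vmap, it masks the correct class with torch.where rather than indexing.

```python
import torch as ch
from torch.func import functional_call, grad, vmap

ch.manual_seed(0)
model = ch.nn.Linear(4, 3)  # toy 3-class classifier on 4 features
weights = {k: v.detach() for k, v in model.named_parameters()}
buffers = {k: v.detach() for k, v in model.named_buffers()}  # empty for Linear


def margin(weights, buffers, image, label):
    """Classification model output for ONE example:
    f = g_y - logsumexp_{j != y} g_j."""
    logits = functional_call(model, (weights, buffers), (image.unsqueeze(0),))[0]
    is_correct = ch.arange(logits.shape[-1]) == label  # True only at class y
    correct = ch.where(is_correct, logits, ch.zeros_like(logits)).sum()
    others = ch.where(is_correct, ch.full_like(logits, float("-inf")), logits)
    return correct - others.logsumexp(dim=-1)


# grad(margin) differentiates w.r.t. its first argument (the weights dict);
# vmap maps over the batch dimension of image and label only.
per_sample_grad = vmap(grad(margin), in_dims=(None, None, 0, 0))

images = ch.randn(5, 4)
labels = ch.tensor([0, 1, 2, 0, 1])
grads = per_sample_grad(weights, buffers, images, labels)
print({k: tuple(v.shape) for k, v in grads.items()})
# {'weight': (5, 3, 4), 'bias': (5, 3)}
```

Each entry of grads holds one gradient per example along its leading dimension, computed in a single vectorized call instead of a Python loop over the batch.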
Similarly, the output-to-loss gradient function is implemented as follows:
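import torch as ch
def get_out_to_loss_grad(self, model, weights, buffers, batch):
    images, labels = batch
    logits = ch.func.functional_call(model, (weights, buffers), images)
    # here we are directly implementing the gradient instead of relying on autodiff to do
    # that for us; self.softmax (a softmax over the class dimension) and
    # self.loss_temperature (a scalar temperature) are attributes of ImageClassificationModelOutput
    ps = self.softmax(logits / self.loss_temperature)[ch.arange(logits.size(0)), labels]
    # 1 - p for each example, returned with shape (batch_size, 1)
    return (1 - ps).clone().detach().unsqueeze(-1)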
Note that we are directly implementing the gradient we derived analytically above instead of using automatic differentiation; the code returns \(1 - p\), the magnitude of \(\frac{\partial L}{\partial f} = -(1 - p)\).
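As a quick numerical sanity check of this closed form (plain Python with made-up logits, not TRAK code), the sketch below rewrites the cross-entropy loss in terms of the model output, \(L = \log(1 + \exp(-f))\), and compares a central finite difference of it with \(-(1 - p)\). The helper names softmax and loss_of_output are hypothetical.

```python
import math


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(g - m) for g in logits]
    total = sum(exps)
    return [e / total for e in exps]


def loss_of_output(f):
    # cross-entropy -log(p) rewritten via p = 1 / (1 + exp(-f))
    return math.log1p(math.exp(-f))


logits = [2.0, 0.5, -1.0]  # made-up logits g
y = 0                      # index of the correct class
p = softmax(logits)[y]
f = math.log(p / (1 - p))  # classification model output function

h = 1e-5
numeric = (loss_of_output(f + h) - loss_of_output(f - h)) / (2 * h)
analytic = -(1 - p)
print(numeric, analytic)   # the two values agree closely
```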
That’s all!
Though we showed how ImageClassificationModelOutput is implemented internally, to use it you just need to specify task="image_classification" when instantiating TRAKer.
Extending to other tasks
For more examples, see Applying TRAK to a custom task #2: Text Classification using BERT and Applying TRAK to a custom task #3: CLIP.