Applying TRAK to a custom task #3: CLIP

In this tutorial, we’ll show another example of applying TRAK to a new custom task, CLIP. If you haven’t, you should first check out Applying TRAK to a custom task #1: Classification to familiarize yourself with the notion of a model output function and how we implement it inside TRAK.

CLIP overview

We’ll assume that you’re familiar with how CLIP works (having only a rough idea will be sufficient). For a given image-caption pair \((x, y)\), CLIP outputs an image embedding \(\phi(x)\) and a caption embedding \(\psi(y)\).

The CLIP training loss tries to align the image embeddings with their corresponding caption embeddings. In particular, given a batch of \(n\) examples \(\{(x_1,y_1),...,(x_n,y_n)\}\), it computes all \(n \times n\) pairwise cosine similarities between the (normalized) image and text embeddings, \(S_{ij}:=\phi(x_i)\cdot\psi(y_j)\), and then aims to maximize the \(S_{ii}\) terms while minimizing the \(S_{ij}\) terms for \(i\neq j\):

\[L_\text{CLIP}(x_i, y_i) = -\log\left(\frac{\exp(S_{ii})}{\sum_{j\leq n} \exp(S_{ij})}\right) -\log\left(\frac{\exp(S_{ii})}{\sum_{j\leq n} \exp(S_{ji})}\right)\]
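To make the loss concrete, the NumPy sketch below (not part of TRAK) builds \(S\) from a batch of embeddings and evaluates \(L_\text{CLIP}\) for every pair. Like the formula above, it leaves out the learned temperature that CLIP applies to the similarities; the helper names logsumexp, similarity_matrix, and clip_loss are hypothetical.

```python
import numpy as np


def logsumexp(a: np.ndarray, axis: int) -> np.ndarray:
    # numerically stable log(sum(exp(a))) along `axis`
    m = a.max(axis=axis, keepdims=True)
    return (m + np.log(np.exp(a - m).sum(axis=axis, keepdims=True))).squeeze(axis)


def similarity_matrix(img_embs: np.ndarray, txt_embs: np.ndarray) -> np.ndarray:
    # S[i, j] = cosine similarity between image i and caption j
    img = img_embs / np.linalg.norm(img_embs, axis=1, keepdims=True)
    txt = txt_embs / np.linalg.norm(txt_embs, axis=1, keepdims=True)
    return img @ txt.T


def clip_loss(S: np.ndarray) -> np.ndarray:
    # per-pair loss L_CLIP(x_i, y_i), without temperature
    diag = np.diag(S)
    img_to_txt = logsumexp(S, axis=1) - diag  # -log(exp(S_ii) / sum_j exp(S_ij))
    txt_to_img = logsumexp(S, axis=0) - diag  # -log(exp(S_ii) / sum_j exp(S_ji))
    return img_to_txt + txt_to_img


# usage: matching captions on the diagonal vs. the same captions in reverse order
rng = np.random.default_rng(0)
img = rng.normal(size=(4, 8))
print(clip_loss(similarity_matrix(img, img)).mean())
print(clip_loss(similarity_matrix(img, img[::-1])).mean())  # larger
```

Reversing the captions moves every correct caption off the diagonal. Each row and column keeps the same set of similarities, so only the diagonal terms change, and the mismatched batch gets a strictly larger mean loss.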

Implementing the model output function

As in our earlier examples, to apply TRAK to this setting, we just need to define an appropriate model output function.

In our paper, we choose the following model output function:

\[f_\text{CLIP}(x_i, y_i) = -\log\sum_{j\leq n}\exp\left(-(S_{ii} - S_{ij})\right) -\log\sum_{j\leq n}\exp\left(-(S_{ii} - S_{ji})\right)\]

Intuitively, this choice is motivated by viewing the CLIP loss as a sum of two classification problems (one matching images to their correct captions, and vice versa). See Section 5.1.1 of our paper for details.
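A direct NumPy transcription of \(f_\text{CLIP}\) can help when comparing it with the loss (hypothetical helper names; temperature omitted, as in the formula). Entry (i, j) of diag[:, None] - S is the margin \(S_{ii} - S_{ij}\), so each term grows as the correct caption (or image) beats the others by a wider margin.

```python
import numpy as np


def logsumexp(a: np.ndarray, axis: int) -> np.ndarray:
    # numerically stable log(sum(exp(a))) along `axis`
    m = a.max(axis=axis, keepdims=True)
    return (m + np.log(np.exp(a - m).sum(axis=axis, keepdims=True))).squeeze(axis)


def clip_model_output(S: np.ndarray) -> np.ndarray:
    # f_CLIP(x_i, y_i) for every pair i of the batch, without temperature
    diag = np.diag(S)
    margins_img = diag[:, None] - S    # entry (i, j): S_ii - S_ij
    margins_txt = diag[:, None] - S.T  # entry (i, j): S_ii - S_ji
    return -logsumexp(-margins_img, axis=1) - logsumexp(-margins_txt, axis=1)


S = np.array([[1.0, 0.0],
              [0.0, 1.0]])
print(clip_model_output(S))
```

Because the sum runs over all \(j\leq n\), it always contains the \(j=i\) term \(\exp(0)=1\), so every output is negative and approaches 0 from below as the margins grow.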

Note that, unlike in the classification case, the model output evaluated at an example now depends on the other examples in the batch. To get CLIP embeddings for those other image-caption pairs, we implement an additional utility method, get_embeddings(), which precomputes embeddings for a pool of N training pairs; get_output() then randomly samples a subset of that pool to stand in for the rest of the batch. Here, let’s just assume we have access to the resulting arrays all_im_embs and all_txt_embs.
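The tutorial leaves get_embeddings() abstract. As a rough illustration only, the sketch below shows the idea under simplifying assumptions: encode a fixed pool of up to size image-caption pairs once, then draw sim_bs rows of that pool (a hypothetical parameter name) to stand in for the other examples in the batch. The helpers precompute_embeddings and sample_simulated_batch and the toy encoders are hypothetical, and the uniform sampling is an assumption, not TRAK's implementation.

```python
from typing import Callable, Iterable, Tuple

import numpy as np


def precompute_embeddings(pairs: Iterable[Tuple[object, object]],
                          encode_img: Callable[[object], np.ndarray],
                          encode_txt: Callable[[object], np.ndarray],
                          size: int,
                          embedding_dim: int) -> Tuple[np.ndarray, np.ndarray]:
    # Hypothetical stand-in for get_embeddings(): encode up to `size` pairs once
    # and store them row by row in two [size, embedding_dim] arrays.
    all_im_embs = np.zeros((size, embedding_dim))
    all_txt_embs = np.zeros((size, embedding_dim))
    n = 0
    for img, caption in pairs:
        if n == size:
            break
        all_im_embs[n] = encode_img(img)
        all_txt_embs[n] = encode_txt(caption)
        n += 1
    # drop unused rows if `pairs` had fewer than `size` items
    return all_im_embs[:n], all_txt_embs[:n]


def sample_simulated_batch(num_embs: int, sim_bs: int,
                           rng: np.random.Generator) -> np.ndarray:
    # Pick distinct pool indices to play the role of the other examples j;
    # uniform sampling is a simplifying assumption.
    return rng.choice(num_embs, size=min(sim_bs, num_embs), replace=False)


# usage with toy encoders that map an item k to a constant vector
rng = np.random.default_rng(0)
pairs = [(k, k) for k in range(10)]
all_im_embs, all_txt_embs = precompute_embeddings(
    pairs, lambda x: np.full(4, float(x)), lambda c: np.full(4, -float(c)),
    size=6, embedding_dim=4)
ii = sample_simulated_batch(len(all_im_embs), sim_bs=3, rng=rng)
print(all_im_embs.shape, ii)
```

Precomputing the pool once keeps each per-example output cheap: only the example's own embeddings need to be recomputed with the current weights.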

Now we are ready to implement CLIPModelOutput.get_output():

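from typing import Iterable

import torch as ch
from torch import Tensor

def get_output(model,
               weights: Iterable[Tensor],
               buffers: Iterable[Tensor],
               image: Tensor,
               label: Tensor) -> Tensor:
    # assumes all_im_embs, all_txt_embs ([N, d] precomputed embeddings) and
    # sim_bs (how many of them to sample per example) are in scope
    # tailored for open_clip: forward returns image features, text features, and the logit scale
    clip_inputs = {"image": image.unsqueeze(0), "text": label.unsqueeze(0)}
    image_embeddings, text_embeddings, _ = ch.func.functional_call(model,
                                                                   (weights, buffers),
                                                                   args=(),
                                                                   kwargs=clip_inputs)

    # randomly pick sim_bs of the N precomputed pairs to stand in for the rest of the batch
    ii = ch.multinomial(input=ch.arange(N).float(),
                        num_samples=sim_bs,
                        replacement=False)

    # (text_embeddings - all_txt_embs[ii]) broadcasts to [sim_bs, d]; the product has entries S_ii - S_ij
    result = -ch.logsumexp(-image_embeddings @ (text_embeddings - all_txt_embs[ii]).T, dim=1) +\
             -ch.logsumexp(-text_embeddings @ (image_embeddings - all_im_embs[ii]).T, dim=1)
    return result.sum()  # result has shape [1]; sum() turns it into a scalar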
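To see what the broadcasting in get_output() computes, the following NumPy mirror of its final arithmetic evaluates \(f_\text{CLIP}\) with k sampled pool rows playing the role of the other examples j. It is a sketch for checking shapes by hand: get_output_numpy is a hypothetical name, and given embeddings replace the model's forward pass.

```python
import numpy as np


def logsumexp(a: np.ndarray, axis: int) -> np.ndarray:
    # numerically stable log(sum(exp(a))) along `axis`
    m = a.max(axis=axis, keepdims=True)
    return (m + np.log(np.exp(a - m).sum(axis=axis, keepdims=True))).squeeze(axis)


def get_output_numpy(img_emb: np.ndarray, txt_emb: np.ndarray,
                     pool_im: np.ndarray, pool_txt: np.ndarray) -> float:
    # img_emb, txt_emb: [1, d]; pool_im, pool_txt: [k, d] sampled pool rows
    # (txt_emb - pool_txt) broadcasts to [k, d]; the product is [1, k] with
    # entries phi(x_i) . (psi(y_i) - psi(y_j)) = S_ii - S_ij
    margins_img = img_emb @ (txt_emb - pool_txt).T
    # entries psi(y_i) . (phi(x_i) - phi(x_j)) = S_ii - S_ji
    margins_txt = txt_emb @ (img_emb - pool_im).T
    result = -logsumexp(-margins_img, axis=1) - logsumexp(-margins_txt, axis=1)  # shape [1]
    return result.sum()
```

With img_emb of shape [1, d] and a pool of shape [k, d], both margin matrices have shape [1, k], and the log-sum-exp over axis 1 leaves a result of shape [1], matching the comment in get_output().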

Finally, to compute the output-to-loss gradient term, we observe in our paper that we can reduce to the classification case and compute the corresponding probabilities:

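def get_out_to_loss_grad(self, model, weights, buffers, batch):
    image, label = batch
    clip_inputs = {'image': image, 'text': label}
    image_embeddings, text_embeddings, temp = ch.func.functional_call(model,
                                                                      (weights, buffers),
                                                                      args=(),
                                                                      kwargs=clip_inputs)
    # cache the model's temperature (open_clip's logit scale) on first use
    if self.temperature is None:
        self.temperature = temp
    # res[i, j] = temperature * S_ij for the current batch
    res = self.temperature * image_embeddings @ text_embeddings.T
    # self.softmax is assumed to be ch.nn.Softmax(-1); average the probabilities of the
    # correct match in both directions (image->caption and caption->image)
    ps = (self.softmax(res) + self.softmax(res.T)).diag() / 2.
    return (1 - ps).clone().detach()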

Note, again, that we are directly implementing the gradient, instead of using automatic differentiation.
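The same computation can be reproduced outside of TRAK as a sanity check. In this NumPy sketch (hypothetical names; the temperature is passed in rather than read from the model), each entry is one minus the average probability that pair i is matched correctly in the two directions.

```python
import numpy as np


def softmax(a: np.ndarray) -> np.ndarray:
    # softmax over the last axis, like ch.nn.Softmax(-1)
    e = np.exp(a - a.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def out_to_loss_grad(img_embs: np.ndarray, txt_embs: np.ndarray,
                     temperature: float) -> np.ndarray:
    # res[i, j] = temperature * phi(x_i) . psi(y_j)
    res = temperature * img_embs @ txt_embs.T
    # rows of res: image -> caption; rows of res.T: caption -> image
    ps = (softmax(res) + softmax(res.T)).diagonal() / 2.0
    return 1.0 - ps
```

With temperature 0 all similarities look alike, every correct match gets probability 1/n, and each entry equals 1 - 1/n; a large temperature on well-separated embeddings drives the entries toward 0.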

Putting it together

Using the above CLIPModelOutput implementation, we can compute TRAK scores for open_clip models as follows:

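import open_clip
from tqdm import tqdm
from trak import TRAKer
from trak.modelout_functions import CLIPModelOutput  # or the implementation from above

device = 'cuda'
model, _, preprocess = open_clip.create_model_and_transforms(...)
model = model.to(device)
tokenizer = ...
loader_train, loader_val = ...  # iterables yielding one (image, captions) pair at a time

traker = TRAKer(model=model,
                task=CLIPModelOutput(),  # an instance; you can also just pass in "clip"
                train_set_size=len(loader_train),
                device=device)

# precompute the embeddings that get_output() samples its "other examples" from
traker.task.get_embeddings(model, loader_train, batch_size=1, size=600, embedding_dim=1024,
                           preprocess_fn_img=lambda x: preprocess(x).to(device).unsqueeze(0),
                           preprocess_fn_txt=lambda x: tokenizer(x[0]).to(device))

traker.load_checkpoint(model.state_dict(), model_id=0)
for (img, captions) in tqdm(loader_train, desc='Featurizing...'):
    x = preprocess(img).to(device).unsqueeze(0)
    y = tokenizer(captions).to(device)
    traker.featurize(batch=(x, y), num_samples=x.shape[0])
traker.finalize_features()

traker.start_scoring_checkpoint(exp_name='clip_example',
                                checkpoint=model.state_dict(),
                                model_id=0,
                                num_targets=len(loader_val))
for (img, captions) in tqdm(loader_val, desc='Scoring...'):
    x = preprocess(img).to(device).unsqueeze(0)
    y = tokenizer(captions).to(device)
    traker.score(batch=(x, y), num_samples=x.shape[0])

scores = traker.finalize_scores(exp_name='clip_example')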
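Once the scores are finalized, a common next step is to look up the training pairs with the highest score for a given validation pair. The sketch below assumes scores can be indexed as a NumPy array of shape [number of training examples, number of targets] (check this against your TRAK version); top_k_train_examples is a hypothetical helper.

```python
import numpy as np


def top_k_train_examples(scores: np.ndarray, target: int, k: int) -> np.ndarray:
    # assumes scores[i, j] is the score of training example i for target j;
    # returns the k training indices with the highest scores, highest first
    col = np.asarray(scores[:, target])
    return np.argsort(-col, kind="stable")[:k]
```

For the most negative scores instead, sort in ascending order (np.argsort(col)[:k]).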

That’s all! You’re now ready to adapt TRAK to your own custom tasks.