Applying TRAK to a custom task #3: CLIP¶
In this tutorial, we'll show another example of applying TRAK to a new custom task: CLIP. If you haven't already, you should first check out Applying TRAK to a custom task #1: Classification to familiarize yourself with the notion of a model output function and how we implement it inside TRAK.
CLIP overview¶
We’ll assume that you’re familiar with how CLIP works (having only a rough idea will be sufficient). For a given image-caption pair \((x, y)\), CLIP outputs an image embedding \(\phi(x)\) and a caption embedding \(\psi(y)\).
The CLIP training loss tries to align the image embeddings with their corresponding caption embeddings. In particular, given a batch of \(n\) examples \(\{(x_1,y_1),...,(x_n,y_n)\}\), it computes all \(n \times n\) pairwise cosine similarities between the image and caption embeddings, \(S_{ij} := \phi(x_i)\cdot\psi(y_j)\), and then aims to maximize the \(S_{ii}\) terms while minimizing the \(S_{ij}\) terms for \(i\neq j\):

\[L_\text{CLIP} = -\sum_{i=1}^{n}\left(\log\frac{\exp(S_{ii})}{\sum_{j=1}^{n}\exp(S_{ij})} + \log\frac{\exp(S_{ii})}{\sum_{j=1}^{n}\exp(S_{ji})}\right)\]

(In practice, CLIP additionally scales these similarities by a learned temperature; this temperature will reappear below when we compute the output-to-loss gradient.)
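To make this concrete, here is a minimal PyTorch sketch of this symmetric contrastive loss (an illustration, not part of TRAK), assuming img_embs and txt_embs are already \(\ell_2\)-normalized and omitting the temperature scaling:

import torch as ch
import torch.nn.functional as F

def clip_loss(img_embs: ch.Tensor, txt_embs: ch.Tensor) -> ch.Tensor:
    # img_embs, txt_embs: [n, d] tensors of (normalized) image / caption embeddings
    S = img_embs @ txt_embs.T                            # [n, n] similarities S_ij
    targets = ch.arange(S.shape[0], device=S.device)     # pair i matches pair i
    # classify each image against all captions (rows of S) and each caption
    # against all images (columns of S)
    return (F.cross_entropy(S, targets, reduction='sum')
            + F.cross_entropy(S.T, targets, reduction='sum'))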
Implementing the model output function¶
As in our earlier examples, to apply TRAK to this setting, we just need to define an appropriate model output function. In our paper, we choose the following model output function:

\[f_\text{CLIP}(x_i, y_i) = -\log\left(\sum_{j=1}^{n}\exp\left(S_{ij} - S_{ii}\right)\right) - \log\left(\sum_{j=1}^{n}\exp\left(S_{ji} - S_{ii}\right)\right)\]

where the sums range over the examples in the batch (in the implementation below, they are approximated using a random subset of precomputed embeddings).
Note

Intuitively, this choice is motivated by viewing the CLIP loss as a sum of two classification problems (one matching images to their correct captions, and one matching captions to their correct images). Check Section 5.1.1 of our paper for details.
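Concretely, the \(i\)-th term of the CLIP loss above is the sum of two cross-entropy losses:

\[\underbrace{-\log\frac{\exp(S_{ii})}{\sum_{j}\exp(S_{ij})}}_{\text{classify caption }y_i\text{ given image }x_i} \; \underbrace{-\;\log\frac{\exp(S_{ii})}{\sum_{j}\exp(S_{ji})}}_{\text{classify image }x_i\text{ given caption }y_i}\]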
Note that unlike in the classification case, this model output evaluated at an example now depends on other examples in the batch.
To get the CLIP embeddings for all the image-caption pairs in the batch, we implement an additional utility method get_embeddings(). Here, let's just assume we have access to the arrays all_img_embeddings and all_txt_embeddings (referred to as all_im_embs and all_txt_embs in the code below).
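For intuition, here is a rough sketch of what such a precomputation could look like. This is a hypothetical helper, not TRAK's actual CLIPModelOutput.get_embeddings() implementation; the argument names mirror the call in the final example below:

import torch as ch

def precompute_clip_embeddings(model, dataset, size, embedding_dim,
                               preprocess_fn_img, preprocess_fn_txt,
                               device='cuda'):
    # Hypothetical sketch: embed the first `size` (image, caption) pairs so that
    # get_output() can later sample from these arrays.
    all_img_embeddings = ch.zeros(size, embedding_dim, device=device)
    all_txt_embeddings = ch.zeros(size, embedding_dim, device=device)
    with ch.no_grad():
        for i, (image, caption) in enumerate(dataset):
            if i >= size:
                break
            # open_clip models return (image_features, text_features, logit_scale)
            img_emb, txt_emb, _ = model(preprocess_fn_img(image),
                                        preprocess_fn_txt(caption))
            all_img_embeddings[i] = img_emb.squeeze(0)
            all_txt_embeddings[i] = txt_emb.squeeze(0)
    return all_img_embeddings, all_txt_embeddings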
Now we are ready to implement CLIPModelOutput.get_output():
from typing import Iterable

import torch as ch
from torch import Tensor

def get_output(model,
               weights: Iterable[Tensor],
               buffers: Iterable[Tensor],
               image: Tensor,
               label: Tensor):
    # tailored for open_clip
    # https://github.com/mlfoundations/open_clip/blob/fb72f4db1b17133befd6c67c9cf32a533b85a321/src/open_clip/model.py#L242-L245
    clip_inputs = {"image": image.unsqueeze(0), "text": label.unsqueeze(0)}
    image_embeddings, text_embeddings, _ = ch.func.functional_call(model,
                                                                   (weights, buffers),
                                                                   args=(),
                                                                   kwargs=clip_inputs)
    # sample sim_bs of the N precomputed embeddings; N, sim_bs, all_im_embs, and
    # all_txt_embs are populated beforehand by get_embeddings()
    ii = ch.multinomial(input=ch.arange(N).float(),
                        num_samples=sim_bs,
                        replacement=False)
    # -logsumexp_j(S_ij - S_ii) - logsumexp_j(S_ji - S_ii), approximated over the sampled subset
    result = -ch.logsumexp(-image_embeddings @ (text_embeddings - all_txt_embs[ii]).T, dim=1) +\
             -ch.logsumexp(-text_embeddings @ (image_embeddings - all_im_embs[ii]).T, dim=1)
    return result.sum()  # shape of result should be [1]
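For context, TRAK differentiates this output with respect to the model parameters to form per-example gradient features. A rough single-example sketch of that step (assuming the precomputed embeddings above are in place, and representing weights and buffers as dictionaries of named parameters and buffers) might look like the following:

import torch as ch

weights = {k: v.detach() for k, v in model.named_parameters()}
buffers = {k: v.detach() for k, v in model.named_buffers()}
# image: an unbatched [3, H, W] tensor; label: an unbatched tokenized caption
# (get_output() adds the batch dimension itself)
grad_fn = ch.func.grad(get_output, argnums=1)   # differentiate w.r.t. weights
per_example_grads = grad_fn(model, weights, buffers, image, label)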
Finally, to compute the output-to-loss gradient term, we observe in our paper that we can reduce it to the classification case and compute the corresponding correct-pair probabilities:
def get_out_to_loss_grad(self, model, weights, buffers, batch):
    image, label = batch
    clip_inputs = {'image': image, 'text': label}
    image_embeddings, text_embeddings, temp = ch.func.functional_call(model,
                                                                      (weights, buffers),
                                                                      args=(),
                                                                      kwargs=clip_inputs)
    if self.temperature is None:
        self.temperature = temp
    # temperature-scaled pairwise similarities for the batch
    res = self.temperature * image_embeddings @ text_embeddings.T
    # average of the image-to-caption and caption-to-image matching probabilities
    # for each correct pair (the diagonals of the two softmax matrices)
    ps = (self.softmax(res) + self.softmax(res.T)).diag() / 2.
    return (1 - ps).clone().detach()
Note, again, that we are directly implementing the gradient, instead of using automatic differentiation.
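Concretely, writing \(\tau\) for the learned temperature returned by the model, the value returned for the \(i\)-th image-caption pair in the batch is

\[1 - p_i, \qquad p_i = \frac{1}{2}\left(\frac{\exp(\tau S_{ii})}{\sum_{j}\exp(\tau S_{ij})} + \frac{\exp(\tau S_{ii})}{\sum_{j}\exp(\tau S_{ji})}\right),\]

which mirrors the \(1 - p\) term from the classification case.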
Putting it together¶
Using the above CLIPModelOutput implementation, we can compute TRAK scores for open_clip models as follows:
import open_clip
from tqdm import tqdm

from trak import TRAKer
from trak.modelout_functions import CLIPModelOutput

device = 'cuda'

model, _, preprocess = open_clip.create_model_and_transforms(...)
tokenizer = ...
ds_train = ...  # raw (image, caption) training set used to precompute embeddings
loader_train, loader_val = ...

traker = TRAKer(model=model,
                task=CLIPModelOutput,  # you can also just pass in "clip"
                train_set_size=TRAIN_SET_SIZE,
                save_dir=args.out,
                device=device,
                proj_dim=1024)

traker.task.get_embeddings(model, ds_train, batch_size=1, size=600, embedding_dim=1024,
                           preprocess_fn_img=lambda x: preprocess(x).to(device).unsqueeze(0),
                           preprocess_fn_txt=lambda x: tokenizer(x[0]).to(device))

traker.load_checkpoint(model.state_dict(), model_id=0)
for (img, captions) in tqdm(loader_train, desc='Featurizing...'):
    x = preprocess(img).to(device).unsqueeze(0)
    y = tokenizer(captions).to(device)
    traker.featurize(batch=(x, y), num_samples=x.shape[0])
traker.finalize_features()

traker.start_scoring_checkpoint(exp_name='clip_example',
                                checkpoint=model.state_dict(),
                                model_id=0,
                                num_targets=VAL_SET_SIZE)
for (img, captions) in tqdm(loader_val, desc='Scoring...'):
    x = preprocess(img).to(device).unsqueeze(0)
    y = tokenizer(captions).to(device)
    traker.score(batch=(x, y), num_samples=x.shape[0])
scores = traker.finalize_scores(exp_name='clip_example')
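Here, scores has one row per training example and one column per scoring target, so entry \((i, j)\) estimates the influence of the \(i\)-th training pair on the model's output for the \(j\)-th validation pair.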
That's all! You're now ready to adapt TRAK to your own custom tasks.