"""
Here we provide an abstract "model output" class AbstractModelOutput, together
with a number of subclasses for particular applications (vision, language, etc):
- :class:`.ImageClassificationModelOutput`
- :class:`.CLIPModelOutput`
- :class:`.TextClassificationModelOutput`
- :class:`.IterativeImageClassificationModelOutput`
These classes implement methods that transform input batches to the desired
model output (e.g. logits, loss, etc). See Sections 2 & 3 of `our paper
<https://arxiv.org/abs/2303.14186>`_ for more details on what model output
functions are in the context of TRAK and how to use & design them.
See, e.g. `this tutorial
<https://trak.readthedocs.io/en/latest/modeloutput.html>`_ for an example on how
to subclass :code:`AbstractModelOutput` for a task of your choice.
"""
from abc import ABC, abstractmethod
from typing import Iterable
from torch import Tensor
from torch.nn import Module
import torch as ch
[docs]
class AbstractModelOutput(ABC):
    r"""Interface that every TRAK model output function implements.

    See, e.g. `this tutorial <https://trak.readthedocs.io/en/latest/clip.html>`_
    for an example of subclassing :code:`AbstractModelOutput` for a task of
    your choice.

    A concrete subclass has to provide:

    - :code:`get_output`: maps a batch of inputs and the model weights to the
      outputs that TRAK will be trained to predict. In the notation of the
      paper this is :math:`f(z,\theta)`.
    - :code:`get_out_to_loss_grad`: maps a batch of inputs and the model
      weights to the gradient of the function that turns the model outputs
      above into the loss, taken with respect to the batch. In the notation of
      the paper these are (the entries along the diagonal of) :math:`Q`.
    """

    @abstractmethod
    def __init__(self) -> None:
        """Abstract constructor; does nothing besides being overridable."""

    @abstractmethod
    def get_output(self, model, batch: Iterable[Tensor]) -> Tensor:
        """Evaluate the model output function on a batch.

        See Sections 2 & 3 of `our paper <https://arxiv.org/abs/2303.14186>`_
        for what model output functions are in the context of TRAK and how to
        use & design them.

        Args:
            model (torch.nn.Module):
                model
            batch (Iterable[Tensor]):
                input batch

        Returns:
            Tensor:
                model output function
        """

    @abstractmethod
    def get_out_to_loss_grad(self, model, batch: Iterable[Tensor]) -> Tensor:
        """Evaluate the gradient of the out-to-loss function on a batch.

        See Sections 2 & 3 of `our paper <https://arxiv.org/abs/2303.14186>`_
        for what the out-to-loss functions (in the notation of the paper,
        :math:`Q`) are in the context of TRAK and how to use & design them.

        Args:
            model (torch.nn.Module):
                model
            batch (Iterable[Tensor]):
                input batch

        Returns:
            Tensor:
                gradient of the out-to-loss function
        """
[docs]
class ImageClassificationModelOutput(AbstractModelOutput):
    r"""Correct-class margin for (multiclass) image classification.

    See Section 3.3 of `our paper <https://arxiv.org/abs/2303.14186>`_ for
    more details.
    """

    def __init__(self, temperature: float = 1.0) -> None:
        """Create the softmax and store the out-to-loss temperature.

        Args:
            temperature (float, optional):
                Value the logits are divided by before the softmax in
                :meth:`get_out_to_loss_grad`. Defaults to 1.
        """
        super().__init__()
        self.loss_temperature = temperature
        self.softmax = ch.nn.Softmax(-1)

    @staticmethod
    def get_output(
        model: Module,
        weights: Iterable[Tensor],
        buffers: Iterable[Tensor],
        image: Tensor,
        label: Tensor,
    ) -> Tensor:
        r"""Compute the margin of a single image-label pair.

        For an input :math:`z=(x, y)` and model parameters :math:`\theta`, let
        :math:`p(z, \theta)` be the softmax probability of the correct class.
        The value returned is

        .. math::

            \log(\frac{p(z, \theta)}{1 - p(z, \theta)}),

        evaluated as the correct-class logit minus the log-sum-exp of all the
        other logits.

        The model is called through torch.func (previously functorch) so that
        per-sample gradients can be computed (much) faster. For background on
        functional models see https://pytorch.org/docs/stable/func.html and
        https://pytorch.org/functorch/stable/notebooks/per_sample_grads.html.

        Args:
            model (torch.nn.Module):
                torch model
            weights (Iterable[Tensor]):
                functorch model weights
            buffers (Iterable[Tensor]):
                functorch model buffers
            image (Tensor):
                input image, should not have batch dimension
            label (Tensor):
                input label, should not have batch dimension

        Returns:
            Tensor:
                model output for the given image-label pair :math:`z` and
                weights & buffers :math:`\theta`.
        """
        # The sample arrives unbatched; give it a leading axis of size 1.
        batched_label = label.unsqueeze(0)
        logits = ch.func.functional_call(
            model, (weights, buffers), image.unsqueeze(0)
        )
        rows = ch.arange(logits.shape[0], device=logits.device)
        correct_logit = logits[rows, batched_label]

        # Overwrite the correct-class entry with -inf in a copy, so that it
        # contributes exp(-inf) = 0 to the logsumexp over the other classes.
        neg_inf = ch.tensor(-ch.inf, device=logits.device, dtype=logits.dtype)
        other_logits = logits.clone()
        other_logits[rows, batched_label] = neg_inf

        margin = correct_logit - ch.logsumexp(other_logits, dim=-1)
        return margin.sum()

    def get_out_to_loss_grad(
        self, model, weights, buffers, batch: Iterable[Tensor]
    ) -> Tensor:
        """Compute the out-to-loss term (reweighting term Q in the paper).

        The gradient is written out in closed form as one minus the
        (temperature-scaled) softmax probability of the correct class, rather
        than obtained through autodiff.

        Args:
            model (torch.nn.Module):
                torch model
            weights (Iterable[Tensor]):
                functorch model weights
            buffers (Iterable[Tensor]):
                functorch model buffers
            batch (Iterable[Tensor]):
                input batch, unpacked as ``(images, labels)``

        Returns:
            Tensor:
                out-to-loss (reweighting term) for the input batch, detached
                and with a trailing axis of size 1
        """
        images, labels = batch
        logits = ch.func.functional_call(model, (weights, buffers), images)

        probs = self.softmax(logits / self.loss_temperature)
        rows = ch.arange(logits.size(0))
        p_correct = probs[rows, labels]

        out_to_loss = (1 - p_correct).clone().detach()
        return out_to_loss.unsqueeze(-1)
[docs]
class CLIPModelOutput(AbstractModelOutput):
    r"""Margin for multimodal contrastive learning (CLIP). See Section 5.1 of
    `our paper <https://arxiv.org/abs/2303.14186>`_ for more details.
    Compatible with the open_clip implementation of CLIP.

    Raises:
        AssertionError: this model output function requires using additional
            CLIP embeddings, which are computed using the
            :func:`get_embeddings` method. This method should be invoked
            before featurizing.
    """

    # Class-level state, shared with the static methods below.
    num_computed_embeddings = 0  # number of rows in the cached embeddings
    sim_batch_size = 0  # size of the simulated batch used in get_output
    image_embeddings = None  # cached reference image embeddings
    text_embeddings = None  # cached reference text embeddings

    def __init__(
        self, temperature: float = None, simulated_batch_size: int = 300
    ) -> None:
        """
        Args:
            temperature (float, optional):
                Temperature to use inside the softmax for the out-to-loss
                function. If None, the third output of the model (CLIP's
                :code:`logit_scale` in open_clip) is used. Defaults to None.
            simulated_batch_size (int, optional):
                Size of the "simulated" batch size for the model output
                function. See Section 5.1 of the TRAK paper for more details.
                Defaults to 300.
        """
        super().__init__()
        self.softmax = ch.nn.Softmax(-1)
        self.temperature = temperature

        # Process-wide switch: turns off the memory-efficient scaled dot
        # product attention backend.
        # NOTE(review): presumably needed for torch.func transforms through
        # attention layers -- confirm.
        ch.backends.cuda.enable_mem_efficient_sdp(False)

        self.sim_batch_size = simulated_batch_size
        # get_output is a staticmethod and reads the value from the class.
        CLIPModelOutput.sim_batch_size = simulated_batch_size

    @staticmethod
    def get_embeddings(
        model,
        loader,
        batch_size: int,
        embedding_dim: int,
        size: int = 50_000,
        preprocess_fn_img=None,
        preprocess_fn_txt=None,
    ) -> None:
        """Computes (image and text) embeddings and saves them in the class
        attributes :code:`image_embeddings` and :code:`text_embeddings`.

        Embeddings are written contiguously in loader order. The number of
        rows taken from each batch is read from the model output, so a short
        final batch is handled, and at most :code:`size` rows are kept. If the
        loader runs out before :code:`size` rows were produced, only the rows
        actually computed are stored and counted in
        :code:`num_computed_embeddings`. The buffers live on the device of the
        embeddings returned by the model.

        Args:
            model (torch.nn.Module):
                model; called as ``model(images, text)`` and expected to
                return a 3-tuple whose first two entries are the image and
                text embeddings
            loader ():
                data loader yielding ``(images, text)`` pairs
            batch_size (int):
                input batch size. Kept for backward compatibility; the
                per-batch row count is taken from the model output.
            embedding_dim (int):
                Dimension of CLIP embedding.
            size (int, optional):
                Maximum number of embeddings to compute. Defaults to 50_000.
            preprocess_fn_img (func, optional):
                Transforms to apply to images from the loader before forward
                pass. Defaults to None.
            preprocess_fn_txt (func, optional):
                Transforms to apply to text from the loader before forward
                pass. Defaults to None.

        Raises:
            ValueError: if no embedding could be computed (empty loader or
                non-positive :code:`size`).
        """
        img_embs = txt_embs = None
        filled = 0
        with ch.no_grad():
            for images, text in loader:
                if preprocess_fn_img is not None:
                    images = preprocess_fn_img(images)
                if preprocess_fn_txt is not None:
                    text = preprocess_fn_txt(text)

                image_embeddings, text_embeddings, _ = model(images, text)

                if img_embs is None:
                    # Allocate lazily so the buffers follow the model's device
                    # instead of assuming CUDA is available.
                    device = image_embeddings.device
                    img_embs = ch.zeros(size, embedding_dim, device=device)
                    txt_embs = ch.zeros(size, embedding_dim, device=device)

                # Never write past `size`, and never assume a full batch.
                take = min(image_embeddings.shape[0], size - filled)
                stop = filled + take
                img_embs[filled:stop] = image_embeddings[:take].clone().detach()
                txt_embs[filled:stop] = text_embeddings[:take].clone().detach()
                filled = stop

                if filled >= size:
                    break

        if filled == 0:
            raise ValueError(
                "No CLIP embeddings were computed: the loader yielded no "
                "samples or `size` is not positive."
            )

        # Drop rows that were never filled so they cannot be sampled as
        # (all-zero) reference embeddings in get_output.
        CLIPModelOutput.image_embeddings = img_embs[:filled]
        CLIPModelOutput.text_embeddings = txt_embs[:filled]
        CLIPModelOutput.num_computed_embeddings = filled

    @staticmethod
    def get_output(
        model: Module,
        weights: Iterable[Tensor],
        buffers: Iterable[Tensor],
        image: Tensor,
        label: Tensor,
    ) -> Tensor:
        r"""For a given input :math:`z=(x, y)` and model parameters
        :math:`\theta`, let :math:`\phi(x, \theta)` be the CLIP image
        embedding and :math:`\psi(y, \theta)` be the CLIP text embedding.
        Last, let :math:`B` be a (simulated) batch drawn uniformly without
        replacement from the embeddings cached by :func:`get_embeddings`.
        This method implements the model output function

        .. math::

            -\log\sum_{(x', y')\in B}
                \exp\left(-\phi(x)\cdot(\psi(y) - \psi(y'))\right)
            -\log\sum_{(x', y')\in B}
                \exp\left(-\psi(y)\cdot(\phi(x) - \phi(x'))\right)

        It uses functional models from torch.func (previously functorch) to
        make the per-sample gradient computations (much) faster. For more
        details on what functional models are, and how to use them, please
        refer to https://pytorch.org/docs/stable/func.html and
        https://pytorch.org/functorch/stable/notebooks/per_sample_grads.html.

        Args:
            model (torch.nn.Module):
                torch model
            weights (Iterable[Tensor]):
                functorch model weights
            buffers (Iterable[Tensor]):
                functorch model buffers
            image (Tensor):
                input image, should not have batch dimension
            label (Tensor):
                input text (passed to the model as ``text``), should not have
                batch dimension

        Returns:
            Tensor:
                model output for the given image-label pair :math:`z` and
                weights & buffers :math:`\theta`.

        Raises:
            AssertionError: if :func:`get_embeddings` has not been run yet.
        """
        all_im_embs = CLIPModelOutput.image_embeddings
        all_txt_embs = CLIPModelOutput.text_embeddings
        N = CLIPModelOutput.num_computed_embeddings
        sim_bs = CLIPModelOutput.sim_batch_size

        if all_im_embs is None:
            raise AssertionError(
                "Run traker.task.get_embeddings first before featurizing!"
            )

        # tailored for open_clip
        # https://github.com/mlfoundations/open_clip/blob/fb72f4db1b17133befd6c67c9cf32a533b85a321/src/open_clip/model.py#L242-L245
        clip_inputs = {"image": image.unsqueeze(0), "text": label.unsqueeze(0)}
        image_embeddings, text_embeddings, _ = ch.func.functional_call(
            model, (weights, buffers), args=(), kwargs=clip_inputs
        )

        # `input` of multinomial holds sampling *weights*, so uniform weights
        # are needed to draw the simulated batch uniformly. (Passing
        # arange(N) would never pick index 0 and favour large indices.)
        # The simulated batch cannot be larger than the cached set.
        ii = ch.multinomial(
            input=ch.ones(N), num_samples=min(sim_bs, N), replacement=False
        )

        result = -ch.logsumexp(
            -image_embeddings @ (text_embeddings - all_txt_embs[ii]).T, dim=1
        ) + -ch.logsumexp(
            -text_embeddings @ (image_embeddings - all_im_embs[ii]).T, dim=1
        )
        return result.sum()  # shape of result should be [1]

    def get_out_to_loss_grad(
        self, model, weights, buffers, batch: Iterable[Tensor]
    ) -> Tensor:
        """Computes the (reweighting term Q in the paper)

        The term is one minus the average of the image-to-text and
        text-to-image softmax probabilities of the matching pair, taken from
        the diagonal of the temperature-scaled similarity matrix of the batch.

        Args:
            model (torch.nn.Module):
                torch model
            weights (Iterable[Tensor]):
                functorch model weights
            buffers (Iterable[Tensor]):
                functorch model buffers
            batch (Iterable[Tensor]):
                input batch, unpacked as ``(image, label)``

        Returns:
            Tensor:
                out-to-loss (reweighting term) for the input batch, detached
        """
        image, label = batch
        clip_inputs = {"image": image, "text": label}
        image_embeddings, text_embeddings, temp = ch.func.functional_call(
            model, (weights, buffers), args=(), kwargs=clip_inputs
        )

        # Resolve the temperature per call instead of caching the model's
        # value on `self`: the weights passed in may change between calls
        # (e.g. different checkpoints), and a cached value would go stale.
        temperature = temp if self.temperature is None else self.temperature

        res = temperature * image_embeddings @ text_embeddings.T
        ps = (self.softmax(res) + self.softmax(res.T)).diag() / 2.0
        return (1 - ps).clone().detach()
[docs]
class TextClassificationModelOutput(AbstractModelOutput):
    r"""Margin for text classification models. This assumes that the model
    takes in input_ids, token_type_ids, and attention_mask.

    .. math::

        \text{logit}[\text{correct}] - \log\left(\sum_{i \neq
        \text{correct}} \exp(\text{logit}[i])\right)
    """

    def __init__(self, temperature=1.0) -> None:
        """Create the softmax and store the out-to-loss temperature.

        Args:
            temperature (float, optional):
                Value the logits are divided by before the softmax in
                :meth:`get_out_to_loss_grad`. Defaults to 1.
        """
        super().__init__()
        self.loss_temperature = temperature
        self.softmax = ch.nn.Softmax(-1)

    @staticmethod
    def get_output(
        model,
        weights: Iterable[Tensor],
        buffers: Iterable[Tensor],
        input_id: Tensor,
        token_type_id: Tensor,
        attention_mask: Tensor,
        label: Tensor,
    ) -> Tensor:
        """Compute the margin of a single tokenized example.

        Each input gets a leading axis of size 1 and the model is called
        through :code:`torch.func.functional_call` with the keyword arguments
        ``input_ids``, ``token_type_ids`` and ``attention_mask``. The result
        is the correct-class logit minus the log-sum-exp of the other logits.

        Args:
            model (torch.nn.Module):
                torch model
            weights (Iterable[Tensor]):
                functorch model weights
            buffers (Iterable[Tensor]):
                functorch model buffers
            input_id (Tensor):
                input ids of one example, without batch dimension
            token_type_id (Tensor):
                token type ids of one example, without batch dimension
            attention_mask (Tensor):
                attention mask of one example, without batch dimension
            label (Tensor):
                label of the example, without batch dimension

        Returns:
            Tensor:
                margin of the example
        """
        batched_label = label.unsqueeze(0)
        model_kwargs = {
            "input_ids": input_id.unsqueeze(0),
            "token_type_ids": token_type_id.unsqueeze(0),
            "attention_mask": attention_mask.unsqueeze(0),
        }
        logits = ch.func.functional_call(
            model, (weights, buffers), args=(), kwargs=model_kwargs
        )
        rows = ch.arange(logits.shape[0], device=logits.device)
        correct_logit = logits[rows, batched_label]

        # Overwrite the correct-class entry with -inf in a copy, so that it
        # contributes exp(-inf) = 0 to the logsumexp over the other classes.
        neg_inf = ch.tensor(-ch.inf, device=logits.device, dtype=logits.dtype)
        other_logits = logits.clone()
        other_logits[rows, batched_label] = neg_inf

        margin = correct_logit - ch.logsumexp(other_logits, dim=-1)
        return margin.sum()

    def get_out_to_loss_grad(
        self, model, weights, buffers, batch: Iterable[Tensor]
    ) -> Tensor:
        """Compute the out-to-loss term (reweighting term) for a batch.

        The term is one minus the (temperature-scaled) softmax probability of
        the correct class.

        Args:
            model (torch.nn.Module):
                torch model
            weights (Iterable[Tensor]):
                functorch model weights
            buffers (Iterable[Tensor]):
                functorch model buffers
            batch (Iterable[Tensor]):
                input batch, unpacked as ``(input_ids, token_type_ids,
                attention_mask, labels)``

        Returns:
            Tensor:
                out-to-loss (reweighting term) for the input batch, detached
                and with a trailing axis of size 1
        """
        input_ids, token_type_ids, attention_mask, labels = batch
        model_kwargs = {
            "input_ids": input_ids,
            "token_type_ids": token_type_ids,
            "attention_mask": attention_mask,
        }
        logits = ch.func.functional_call(
            model, (weights, buffers), args=(), kwargs=model_kwargs
        )

        probs = self.softmax(logits / self.loss_temperature)
        rows = ch.arange(logits.size(0))
        p_correct = probs[rows, labels]

        out_to_loss = (1 - p_correct).clone().detach()
        return out_to_loss.unsqueeze(-1)
[docs]
class IterativeImageClassificationModelOutput(AbstractModelOutput):
    r"""Margin for (multiclass) image classification, evaluated on whole
    batches by calling the model directly.

    See Section 3.3 of `our paper <https://arxiv.org/abs/2303.14186>`_ for
    more details.
    """

    def __init__(self, temperature: float = 1.0) -> None:
        """Create the softmax and store the out-to-loss temperature.

        Args:
            temperature (float, optional):
                Value the logits are divided by before the softmax in
                :meth:`get_out_to_loss_grad`. Defaults to 1.
        """
        super().__init__()
        self.loss_temperature = temperature
        self.softmax = ch.nn.Softmax(-1)

    @staticmethod
    def get_output(
        model: Module,
        weights: Iterable[Tensor],
        buffers: Iterable[Tensor],
        images: Tensor,
        labels: Tensor,
    ) -> Tensor:
        r"""Compute the margins of a batch of image-label pairs.

        For an input :math:`z=(x, y)` and model parameters :math:`\theta`, let
        :math:`p(z, \theta)` be the softmax probability of the correct class.
        For every sample in the batch this method evaluates

        .. math::

            \log(\frac{p(z, \theta)}{1 - p(z, \theta)}),

        as the correct-class logit minus the log-sum-exp of all the other
        logits. The model is invoked directly as ``model(images)`` and the
        per-sample margins are returned without being summed.

        Args:
            model (torch.nn.Module):
                torch model
            weights (Iterable[Tensor]):
                functorch model weights (unused; accepted so we don't break
                abstraction)
            buffers (Iterable[Tensor]):
                functorch model buffers (unused; accepted so we don't break
                abstraction)
            images (Tensor):
                input images
            labels (Tensor):
                input labels

        Returns:
            Tensor:
                model output for each image-label pair :math:`z` in the batch
        """
        logits = model(images)
        rows = ch.arange(logits.shape[0], device=logits.device)
        correct_logits = logits[rows, labels]

        # Overwrite the correct-class entries with -inf in a copy, so that
        # they contribute exp(-inf) = 0 to the logsumexp over other classes.
        neg_inf = ch.tensor(-ch.inf, device=logits.device, dtype=logits.dtype)
        other_logits = logits.clone()
        other_logits[rows, labels] = neg_inf

        return correct_logits - ch.logsumexp(other_logits, dim=-1)

    def get_out_to_loss_grad(
        self, model, weights, buffers, batch: Iterable[Tensor]
    ) -> Tensor:
        """Compute the out-to-loss term (reweighting term Q in the paper).

        The gradient is written out in closed form as one minus the
        (temperature-scaled) softmax probability of the correct class, rather
        than obtained through autodiff. The model is invoked directly as
        ``model(images)``; ``weights`` and ``buffers`` are not used.

        Args:
            model (torch.nn.Module):
                torch model
            weights (Iterable[Tensor]):
                functorch model weights
            buffers (Iterable[Tensor]):
                functorch model buffers
            batch (Iterable[Tensor]):
                input batch, unpacked as ``(images, labels)``

        Returns:
            Tensor:
                out-to-loss (reweighting term) for the input batch, detached
                and with a trailing axis of size 1
        """
        images, labels = batch
        logits = model(images)

        probs = self.softmax(logits / self.loss_temperature)
        rows = ch.arange(logits.size(0))
        p_correct = probs[rows, labels]

        out_to_loss = (1 - p_correct).clone().detach()
        return out_to_loss.unsqueeze(-1)
# Registry mapping a task-name string to the model output class that
# implements it.
TASK_TO_MODELOUT = {
"image_classification": ImageClassificationModelOutput,
"clip": CLIPModelOutput,
"text_classification": TextClassificationModelOutput,
"iterative_image_classification": IterativeImageClassificationModelOutput,
}