Skip to content

Optim-wip: Add CLIP loss objectives #945

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 17 commits into
base: optim-wip
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
224 changes: 223 additions & 1 deletion captum/optim/_core/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@

import torch
import torch.nn as nn
from captum.optim._utils.image.common import _dot_cossim, get_neuron_pos
from captum.optim._utils.image.common import (
_create_new_vector,
_dot_cossim,
get_neuron_pos,
)
from captum.optim._utils.typing import ModuleOutputMapping


Expand Down Expand Up @@ -837,6 +841,221 @@ def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
return activations


@loss_wrapper
class L2Mean(BaseLoss):
"""
Simple L2Loss penalty where the mean is used instead of the square root of the
sum.

Used for CLIP models in https://distill.pub/2021/multimodal-neurons/ as per the
supplementary code:
https://github.com/openai/CLIP-featurevis/blob/master/example_facets.py
"""

def __init__(
self,
target: torch.nn.Module,
channel_index: Optional[int] = None,
constant: float = 0.5,
batch_index: Optional[Union[int, List[int]]] = None,
) -> None:
"""
Args:

target (nn.Module): A target layer, transform, or image parameterization
instance.
channel_index (int, optional): Optionally only target a specific channel.
If set to ``None``, all channels with be used.
Default: ``None``
constant (float, optional): Constant value to deduct from the activations.
Default: ``0.5``
batch_index (int or list of int, optional): The index or index range of
activations to optimize if optimizing a batch of activations. If set
to ``None``, defaults to all activations in the batch. Index ranges
should be in the format of: [start, end].
Default: ``None``
"""
BaseLoss.__init__(self, target, batch_index)
self.constant = constant
self.channel_index = channel_index

def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
activations = targets_to_values[self.target][
self.batch_index[0] : self.batch_index[1]
]
if self.channel_index is not None:
activations = activations[:, self.channel_index : self.channel_index + 1]
return ((activations - self.constant) ** 2).mean()


@loss_wrapper
class VectorLoss(BaseLoss):
"""
This objective is useful for optimizing towards channel directions. This can
helpful for visualizing models like OpenAI's CLIP.

This loss objective is similar to the Direction objective, except it computes the
matrix product of the activations and vector, rather than the cosine similarity.
In addition to optimizing towards channel directions, this objective can also
perform a similar role to the ChannelActivation objective by using one-hot 1D
vectors.

See here for more details:
https://distill.pub/2021/multimodal-neurons/
https://github.com/openai/CLIP-featurevis/blob/master/example_facets.py
"""

def __init__(
self,
target: torch.nn.Module,
vec: torch.Tensor,
activation_fn: Optional[Callable] = torch.nn.functional.relu,
move_channel_dim_to_final_dim: bool = True,
batch_index: Optional[Union[int, List[int]]] = None,
) -> None:
"""
Args:

target (nn.Module): A target layer instance.
vec (torch.Tensor): A 1D channel vector with the same size as the
channel / feature dimension of the target layer instance.
activation_fn (callable, optional): An optional activation function to
apply to the activations before computing the matrix product. If set
to ``None``, then no activation function will be used.
Default: ``torch.nn.functional.relu``
move_channel_dim_to_final_dim (bool, optional): Whether or not to move the
channel dimension to the last dimension before computing the matrix
product. Set to ``False`` if the using the channels last format.
Default: ``True``
batch_index (int or list of int, optional): The index or index range of
activations to optimize if optimizing a batch of activations. If set
to ``None``, defaults to all activations in the batch. Index ranges
should be in the format of: [start, end].
Default: ``None``
"""
BaseLoss.__init__(self, target, batch_index)
assert vec.dim() == 1
self.vec = vec
self.activation_fn = activation_fn
self.move_channel_dim_to_final_dim = move_channel_dim_to_final_dim

def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
activations = targets_to_values[self.target]
activations = activations[self.batch_index[0] : self.batch_index[1]]
return _create_new_vector(
activations,
vec=self.vec,
activation_fn=self.activation_fn,
move_channel_dim_to_final_dim=self.move_channel_dim_to_final_dim,
).mean()


@loss_wrapper
class FacetLoss(BaseLoss):
"""
The Facet loss objective used for Faceted Feature Visualization as described in:
https://distill.pub/2021/multimodal-neurons/#faceted-feature-visualization
https://github.com/openai/CLIP-featurevis/blob/master/example_facets.py

The FacetLoss objective allows us to steer feature visualization towards a
particular theme / concept. This is done by using the weights from linear probes
trained on the lower layers of a model to discriminate between a certain theme or
concept and generic natural images.
"""

def __init__(
self,
vec: torch.Tensor,
ultimate_target: torch.nn.Module,
layer_target: Union[torch.nn.Module, List[torch.nn.Module]],
facet_weights: torch.Tensor,
strength: Optional[Union[float, List[float]]] = None,
batch_index: Optional[Union[int, List[int]]] = None,
) -> None:
"""
Args:

vec (torch.Tensor): A 1D channel vector with the same size as the
channel / feature dimension of ultimate_target.
ultimate_target (nn.Module): The main target layer that we are
visualizing targets from. This is normally the penultimate layer of
the model.
layer_target (nn.Module): A layer that we have facet_weights for. This
target layer should be below the ``ultimate_target`` layer in the
model.
facet_weights (torch.Tensor): Weighting that steers the objective
towards a particular theme or concept. These weight values should
come from linear probes trained on ``layer_target``.
strength (float, list of float, optional): A single float or list of floats
to use for batch dimension weighting. If using a single value, then it
will be applied to all batch dimensions equally. Otherwise a list of
floats with a shape of: [start, end] should be used for
:func:`torch.linspace` to calculate the step values in between. Default
is set to ``None`` for no weighting.
Default: ``None``
batch_index (int or list of int, optional): The index or index range of
activations to optimize if optimizing a batch of activations. If set
to ``None``, defaults to all activations in the batch. Index ranges
should be in the format of: [start, end].
Default: ``None``
"""
BaseLoss.__init__(self, [ultimate_target, layer_target], batch_index)
self.ultimate_target = ultimate_target
self.layer_target = layer_target
assert vec.dim() == 1
self.vec = vec
if isinstance(strength, (tuple, list)):
assert len(strength) == 2
self.strength = strength
assert facet_weights.dim() == 4 or facet_weights.dim() == 2
self.facet_weights = facet_weights

def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
activations_ultimate = targets_to_values[self.ultimate_target]
activations_ultimate = activations_ultimate[
self.batch_index[0] : self.batch_index[1]
]
new_vec = _create_new_vector(activations_ultimate, self.vec)
target_activations = targets_to_values[self.layer_target]

layer_grad = torch.autograd.grad(
outputs=new_vec,
inputs=target_activations,
grad_outputs=torch.ones_like(new_vec),
retain_graph=True,
)[0].detach()[self.batch_index[0] : self.batch_index[1]]
layer = target_activations[self.batch_index[0] : self.batch_index[1]]

flat_attr = layer * torch.nn.functional.relu(layer_grad)
if self.facet_weights.dim() == 2 and flat_attr.dim() == 4:
flat_attr = torch.sum(flat_attr, dim=(2, 3))

if self.strength:
if isinstance(self.strength, (tuple, list)):
strength_t = torch.linspace(
self.strength[0],
self.strength[1],
steps=flat_attr.shape[0],
device=flat_attr.device,
).reshape(flat_attr.shape[0], *[1] * (flat_attr.dim() - 1))
else:
strength_t = self.strength
flat_attr = strength_t * flat_attr

if (
self.facet_weights.dim() == 4
and layer.dim() == 4
and self.facet_weights.shape[2:] != layer.shape[2:]
):
facet_weights = torch.nn.functional.interpolate(
self.facet_weights, size=layer.shape[2:]
)
else:
facet_weights = self.facet_weights

return torch.sum(flat_attr * facet_weights)


def sum_loss_list(
loss_list: List,
to_scalar_fn: Callable[[torch.Tensor], torch.Tensor] = torch.mean,
Expand Down Expand Up @@ -908,6 +1127,9 @@ def default_loss_summarize(loss_value: torch.Tensor) -> torch.Tensor:
"AngledNeuronDirection",
"TensorDirection",
"ActivationWeights",
"L2Mean",
"VectorLoss",
"FacetLoss",
"sum_loss_list",
"default_loss_summarize",
]
55 changes: 54 additions & 1 deletion captum/optim/_utils/image/common.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import math
from typing import List, Optional, Tuple, Union
from typing import Callable, List, Optional, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -363,3 +363,56 @@ def hex2base10(x: str) -> float:
* ((1 - (-x - 0.5) * 2) * color_list[1] + (-x - 0.5) * 2 * color_list[0])
).permute(2, 0, 1)
return color_tensor


def _create_new_vector(
x: torch.Tensor,
vec: torch.Tensor,
activation_fn: Optional[
Callable[[torch.Tensor], torch.Tensor]
] = torch.nn.functional.relu,
move_channel_dim_to_final_dim: bool = True,
) -> torch.Tensor:
"""
Create a vector using a given set of activations and another vector.
This function is intended for use in CLIP related loss objectives.

https://distill.pub/2021/multimodal-neurons/
https://github.com/openai/CLIP-featurevis/blob/master/example_facets.py
The einsum equation: "ijkl,j->ikl", used by the paper's associated code is the
same thing as: "[..., C] @ vec", where vec has a shape of 'C'.

Args:

x (torch.Tensor): A set of 2d or 4d activations.
vec (torch.Tensor): A 1D direction vector to use, with a compatible shape for
computing the matrix product of the activations. See torch.matmul for
See torch.matmul for more details on compatible shapes:
https://pytorch.org/docs/stable/generated/torch.matmul.html
By default, ``vec`` is expected to share the same size as the channel or
feature dimension of the activations.
activation_fn (Callable, optional): An optional activation function to
apply to the activations before computing the matrix product. If set
to None, then no activation function will be used.
Default: ``torch.nn.functional.relu``
move_channel_dim_to_final_dim (bool, optional): Whether or not to move the
channel dimension to the last dimension before computing the matrix
product.
Default: ``True``

Returns
x (torch.Tensor): A vector created from the input activations and the
stored vector.
"""
assert x.device == vec.device
assert x.dim() > 1 and vec.dim() == 1
if activation_fn:
x = activation_fn(x)
if x.dim() > 2:
if move_channel_dim_to_final_dim:
permute_vals = [0] + list(range(x.dim()))[2:] + [1]
x = x.permute(*permute_vals)
mean_vals = list(range(1, x.dim() - 1))
return torch.mean(x @ vec, mean_vals)
else:
return (x @ vec)[:, None]
Loading