TensorBoard projector for contrastive learning embedding visualization #217

Open · wants to merge 5 commits into main
viscy/data/triplet.py: 21 additions, 5 deletions
@@ -234,6 +234,7 @@ def __getitem__(self, index: int) -> TripletSample:
        if self.time_interval == "any":
            positive_patch = anchor_patch.clone()
            positive_norm = anchor_norm
            positive_row = anchor_row
        else:
            positive_row = self._sample_positive(anchor_row)
            positive_patch, positive_norm = self._slice_patch(positive_row)
@@ -261,14 +262,29 @@ def __getitem__(self, index: int) -> TripletSample:
            patch=anchor_patch,
            norm_meta=anchor_norm,
        )
        sample = {"anchor": anchor_patch}
        sample = {
            "anchor": anchor_patch,
            "anchor_metadata": anchor_row[
                INDEX_COLUMNS
            ].to_dict(),  # Always include metadata
        }
        if self.fit:
            if self.return_negative:
                sample.update({"positive": positive_patch, "negative": negative_patch})
                sample.update(
                    {
                        "positive": positive_patch,
                        "negative": negative_patch,
                        "positive_metadata": positive_row[INDEX_COLUMNS].to_dict(),
                        "negative_metadata": negative_row[INDEX_COLUMNS].to_dict(),
                    }
                )
            else:
                sample.update({"positive": positive_patch})
        else:
            sample.update({"index": anchor_row[INDEX_COLUMNS].to_dict()})
                sample.update(
                    {
                        "positive": positive_patch,
                        "positive_index": positive_row[INDEX_COLUMNS].to_dict(),
                    }
                )
        return sample


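With these changes every sample always carries the anchor's tracking metadata, plus the positive/negative metadata during fitting. A minimal sketch of what a single fitting-time sample could then look like with `return_negative=True`; the patch shape and the `fov_name`/`id`/`t` values are illustrative assumptions, not taken from a real dataset:

```python
import torch

# Illustrative only: key names follow the diff above; shapes and metadata values are made up.
patch_shape = (2, 15, 256, 256)  # (C, D, H, W)
sample = {
    "anchor": torch.rand(patch_shape),
    "anchor_metadata": {"fov_name": "B/4/000000", "id": 3, "t": 12},
    "positive": torch.rand(patch_shape),
    "positive_metadata": {"fov_name": "B/4/000000", "id": 3, "t": 13},
    "negative": torch.rand(patch_shape),
    "negative_metadata": {"fov_name": "B/3/000001", "id": 7, "t": 12},
}
print({k: v.shape if isinstance(v, torch.Tensor) else v for k, v in sample.items()})
```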
viscy/data/typing.py: 4 additions, 0 deletions
@@ -66,6 +66,7 @@ class TrackingIndex(TypedDict):

    fov_name: OneOrSeq[str]
    id: OneOrSeq[int]
    t: OneOrSeq[int]


class TripletSample(TypedDict):
@@ -77,3 +78,6 @@
    anchor: Tensor
    positive: NotRequired[Tensor]
    negative: NotRequired[Tensor]
    anchor_metadata: NotRequired[TrackingIndex]
    positive_metadata: NotRequired[TrackingIndex]
    negative_metadata: NotRequired[TrackingIndex]
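`OneOrSeq` reflects that each field is a scalar per sample but becomes a sequence once the DataLoader collates a batch: string fields collate to lists and integer fields to tensors, which is why the metadata-formatting code in the engine checks for tensors before converting to strings. A small sketch of that behavior, assuming PyTorch's default collate function and made-up values:

```python
import torch
from torch.utils.data import default_collate

# Two hypothetical per-sample metadata dicts, as returned by the dataset above.
samples = [
    {"fov_name": "B/4/000000", "id": 3, "t": 12},
    {"fov_name": "B/3/000001", "id": 7, "t": 13},
]
batch = default_collate(samples)
print(batch["fov_name"])        # ['B/4/000000', 'B/3/000001'] -> list of str
print(batch["id"], batch["t"])  # tensor([3, 7]) tensor([12, 13]) -> int tensors
```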
viscy/representation/engine.py: 201 additions, 27 deletions
@@ -7,7 +7,6 @@
from lightning.pytorch import LightningModule
from pytorch_metric_learning.losses import NTXentLoss
from torch import Tensor, nn
from umap import UMAP

from viscy.data.typing import TrackingIndex, TripletSample
from viscy.representation.contrastive import ContrastiveEncoder
@@ -16,6 +15,7 @@
_logger = logging.getLogger("lightning.pytorch")


# TODO: log the embeddings every other epoch? expose a variable to control this
class ContrastivePrediction(TypedDict):
    features: Tensor
    projections: Tensor
@@ -35,7 +35,7 @@ def __init__(
        schedule: Literal["WarmupCosine", "Constant"] = "Constant",
        log_batches_per_epoch: int = 8,
        log_samples_per_batch: int = 1,
        log_embeddings: bool = False,
        embedding_log_interval: int = 1,  # Log embeddings every N epochs
        example_input_array_shape: Sequence[int] = (1, 2, 15, 256, 256),
    ) -> None:
        super().__init__()
@@ -48,7 +48,7 @@ def __init__(
        self.example_input_array = torch.rand(*example_input_array_shape)
        self.training_step_outputs = []
        self.validation_step_outputs = []
        self.log_embeddings = log_embeddings
        self.embedding_log_interval = embedding_log_interval

    def forward(self, x: Tensor) -> Tensor:
        "Only return projected embeddings for training and validation."
@@ -121,19 +121,103 @@ def _log_step_samples(self, batch_idx, samples, stage: Literal["train", "val"]):
        )
        output_list.extend(detach_sample(samples, self.log_samples_per_batch))

    def log_embedding_umap(self, embeddings: Tensor, tag: str):
        _logger.debug(f"Computing UMAP for {tag} embeddings.")
        umap = UMAP(n_components=2)
        embeddings_np = embeddings.detach().cpu().numpy()
        umap_embeddings = umap.fit_transform(embeddings_np)
    def log_embedding_tensorboard(
        self,
        embeddings: Tensor,
        images: Tensor,
        metadata: Sequence[list],
        tag: str,
        metadata_header: Sequence[str],
        global_step: int = 0,
    ):
        """Log embeddings with their corresponding images and metadata to the TensorBoard Embedding Projector.

        Args:
            embeddings: Tensor of embeddings to visualize
            images: Corresponding images for the embeddings, (B, C, D, H, W) or (B, C, H, W),
                where D is the depth dimension
            metadata: List of lists with the metadata for each embedding
            tag: Name tag for the embedding visualization
            metadata_header: List of strings with the header for each metadata column
            global_step: Current training step
        """
        _logger.debug(
            f"Logging embeddings to TensorBoard Embedding Projector for {tag}"
        )
        # Store original embeddings tensor for norm calculations
        embeddings_tensor = embeddings.detach()
        # Convert to numpy only for visualization
        embeddings_numpy = embeddings_tensor.cpu().numpy()
        # Take middle slice of 3D images for visualization
        images = images.detach().cpu()
        if images.ndim == 5:  # (B, C, D, H, W)
            middle_d = images.shape[2] // 2
            images = images[:, :, middle_d]  # Now (B, C, H, W)

        # Handle different channel configurations
        if images.shape[1] > 1:
            # Create a list to store normalized channels
            normalized_channels = []
            for ch in range(images.shape[1]):
                # Slice out a single channel
                ch_images = images[:, ch : ch + 1]
                # Normalize each channel independently
                ch_images = (ch_images - ch_images.min()) / (
                    ch_images.max() - ch_images.min()
                )
                normalized_channels.append(ch_images)

            # Combine channels: first channel to red, second to green, rest averaged for blue
            combined_images = torch.zeros(
                images.shape[0], 3, images.shape[2], images.shape[3]
            )
            combined_images[:, 0] = normalized_channels[0].squeeze(1)  # Red channel
            combined_images[:, 1] = (
                normalized_channels[1].squeeze(1)
                if len(normalized_channels) > 1
                else normalized_channels[0].squeeze(1)
            )  # Green channel
            if len(normalized_channels) > 2:
                combined_images[:, 2] = (
                    torch.stack(normalized_channels[2:]).mean(dim=0).squeeze(1)
                )  # Blue channel: average of remaining channels
            else:
                combined_images[:, 2] = normalized_channels[0].squeeze(1)
        else:
            # For single channel, repeat to create grayscale
            combined_images = images.repeat(1, 3, 1, 1)
            combined_images = (combined_images - combined_images.min()) / (
                combined_images.max() - combined_images.min()
            )

        # Log UMAP embeddings to TensorBoard
        # Log a single embedding visualization with the combined image
        self.logger.experiment.add_embedding(
            umap_embeddings,
            global_step=self.current_epoch,
            tag=f"{tag}_umap",
            embeddings_numpy,
            metadata=metadata,
            label_img=combined_images,
            global_step=global_step,
            tag=tag,
            metadata_header=metadata_header,
        )

        # Log statistics using the original tensor
        self.log(
            f"{tag}_mean_norm",
            torch.norm(embeddings_tensor, dim=1).mean(),
            on_epoch=True,
        )
        self.log(
            f"{tag}_std_norm",
            torch.norm(embeddings_tensor, dim=1).std(),
            on_epoch=True,
        )

    def _format_metadata(self, index: TrackingIndex | None) -> str:
        """Format tracking index into a metadata string."""
        if index is None:
            return "unknown"
        return f"track_{index.get('id', 'unknown')}:fov_{index.get('fov_name', 'unknown')}"

    def training_step(self, batch: TripletSample, batch_idx: int) -> Tensor:
        anchor_img = batch["anchor"]
        pos_img = batch["positive"]
@@ -168,37 +252,122 @@ def training_step(self, batch: TripletSample, batch_idx: int) -> Tensor:
    def on_train_epoch_end(self) -> None:
        super().on_train_epoch_end()
        self._log_samples("train_samples", self.training_step_outputs)
        # Log UMAP embeddings for validation
        if self.log_embeddings:
            embeddings = torch.cat(
                [output["embeddings"] for output in self.validation_step_outputs]
            )
            self.log_embedding_umap(embeddings, tag="train")
        self.training_step_outputs = []

    def _prepare_embedding_metadata(
        self,
        anchor_metadata: dict,
        positive_metadata: dict | None = None,
        negative_metadata: dict | None = None,
        include_positive: bool = False,
        include_negative: bool = False,
    ) -> tuple[list[list[str]], list[str]]:
        """Prepare metadata for embedding visualization.

        Args:
            anchor_metadata: Metadata for anchor samples
            positive_metadata: Metadata for positive samples (optional)
            negative_metadata: Metadata for negative samples (optional)
            include_positive: Whether to include positive samples in metadata
            include_negative: Whether to include negative samples in metadata

        Returns:
            tuple containing:
                - metadata: List of lists containing metadata values
                - metadata_header: List of metadata field names
        """
        metadata_header = ["fov_name", "t", "id", "type"]

        def process_field(x, field):
            if field in ["t", "id"] and isinstance(x, torch.Tensor):
                return str(x.detach().cpu().item())
            return str(x)

        # Create lists for each metadata field
        metadata = [
            [str(x) for x in anchor_metadata["fov_name"]],
            [process_field(x, "t") for x in anchor_metadata["t"]],
            [process_field(x, "id") for x in anchor_metadata["id"]],
            ["anchor"] * len(anchor_metadata["fov_name"]),  # type field for anchors
        ]

        # If including positive samples, extend metadata
        if include_positive and positive_metadata is not None:
            for i, field in enumerate(metadata_header[:-1]):  # Exclude 'type' field
                metadata[i].extend(
                    [process_field(x, field) for x in positive_metadata[field]]
                )
            # Add 'positive' type for positive samples
            metadata[-1].extend(["positive"] * len(positive_metadata["fov_name"]))

        # If including negative samples, extend metadata
        if include_negative and negative_metadata is not None:
            for i, field in enumerate(metadata_header[:-1]):  # Exclude 'type' field
                metadata[i].extend(
                    [process_field(x, field) for x in negative_metadata[field]]
                )
            # Add 'negative' type for negative samples
            metadata[-1].extend(["negative"] * len(negative_metadata["fov_name"]))

        return metadata, metadata_header

    def validation_step(self, batch: TripletSample, batch_idx: int) -> Tensor:
        """Validation step of the model."""
        anchor = batch["anchor"]
        pos_img = batch["positive"]
        anchor_projection = self(anchor)
        positive_projection = self(pos_img)
        negative_projection = None

        if isinstance(self.loss_function, NTXentLoss):
            indices = torch.arange(
                0, anchor_projection.size(0), device=anchor_projection.device
            )
            batch_size = anchor.size(0)
            indices = torch.arange(0, batch_size, device=anchor_projection.device)
            labels = torch.cat((indices, indices))
            # Note: we assume the two augmented views are the anchor and positive samples
            embeddings = torch.cat((anchor_projection, positive_projection))
            loss = self.loss_function(embeddings, labels)
            self._log_step_samples(batch_idx, (anchor, pos_img), "val")

            # Store embeddings for visualization
            if self.current_epoch % self.embedding_log_interval == 0 and batch_idx == 0:
Review comment (Contributor, author): @ziw-liu I was only logging the first batch and only from the validation.

Reply (Collaborator): Would that always be the first time point(s) then?
                # Must include positive samples since we're concatenating embeddings
                metadata, metadata_header = self._prepare_embedding_metadata(
                    batch["anchor_metadata"],
                    batch["positive_metadata"],
                    include_positive=True,  # Required since we concatenate embeddings
                )
                self.val_embedding_outputs = {
                    "embeddings": embeddings.detach(),
                    "images": torch.cat((anchor, pos_img)).detach(),
                    "metadata": list(zip(*metadata)),
                    "metadata_header": metadata_header,
                }
        else:
            neg_img = batch["negative"]
            negative_projection = self(neg_img)
            loss = self.loss_function(
                anchor_projection, positive_projection, negative_projection
            )
            self._log_step_samples(batch_idx, (anchor, pos_img, neg_img), "val")

            # Store embeddings for visualization
            if self.current_epoch % self.embedding_log_interval == 0 and batch_idx == 0:
                metadata, metadata_header = self._prepare_embedding_metadata(
                    batch["anchor_metadata"],
                    batch["positive_metadata"],
                    batch["negative_metadata"],
                    include_positive=True,  # Required since we concatenate embeddings
                    include_negative=True,
                )
                self.val_embedding_outputs = {
                    "embeddings": torch.cat(
                        (anchor_projection, positive_projection, negative_projection)
                    ).detach(),
                    "images": torch.cat((anchor, pos_img, neg_img)).detach(),
                    "metadata": list(zip(*metadata)),
                    "metadata_header": metadata_header,
                }

        self._log_metrics(
            loss=loss,
            anchor=anchor_projection,
@@ -211,13 +380,18 @@ def validation_step(self, batch: TripletSample, batch_idx: int) -> Tensor:
    def on_validation_epoch_end(self) -> None:
        super().on_validation_epoch_end()
        self._log_samples("val_samples", self.validation_step_outputs)
        # Log UMAP embeddings for training
        if self.log_embeddings:
            embeddings = torch.cat(
                [output["embeddings"] for output in self.training_step_outputs]
            )
            self.log_embedding_umap(embeddings, tag="val")

        # Log embeddings for validation on interval epochs
        if hasattr(self, "val_embedding_outputs"):
            self.log_embedding_tensorboard(
                self.val_embedding_outputs["embeddings"],
                self.val_embedding_outputs["images"],
                self.val_embedding_outputs["metadata"],
                tag="embeddings",
                metadata_header=self.val_embedding_outputs["metadata_header"],
                global_step=self.current_epoch,
            )
            delattr(self, "val_embedding_outputs")
        self.validation_step_outputs = []

    def configure_optimizers(self):
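For reference, the projector logging added above ultimately goes through `SummaryWriter.add_embedding`, which takes a 2-D matrix of embeddings, one metadata row per sample (with an optional header), and an optional `label_img` batch used as thumbnails. A standalone sketch with synthetic data and a hypothetical log directory, not the PR's code path:

```python
import torch
from torch.utils.tensorboard import SummaryWriter

# Synthetic stand-ins: 8 samples, 32-D projections, 3-channel 32x32 thumbnails.
embeddings = torch.randn(8, 32)
label_img = torch.rand(8, 3, 32, 32)
metadata_header = ["fov_name", "t", "id", "type"]
# One row per sample, mirroring list(zip(*metadata)) in the validation step.
metadata = [[f"B/4/00000{i}", str(i), str(i % 3), "anchor"] for i in range(8)]

writer = SummaryWriter(log_dir="logs/projector_demo")  # hypothetical directory
writer.add_embedding(
    embeddings,
    metadata=metadata,
    label_img=label_img,
    global_step=0,
    tag="embeddings",
    metadata_header=metadata_header,
)
writer.close()
# Inspect with: tensorboard --logdir logs/projector_demo  (Projector tab)
```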