Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add more flexibility gc loss #504

Merged
merged 3 commits into from
Mar 12, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 113 additions & 4 deletions src/gnn_tracking/metrics/losses/metric_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from pytorch_lightning.core.mixins.hparams_mixin import HyperparametersMixin
from torch import Tensor as T
from torch.linalg import norm
from torch.nn.functional import relu
from torch_cluster import radius_graph

from gnn_tracking.metrics.losses import MultiLossFct, MultiLossFctReturn
Expand Down Expand Up @@ -39,10 +40,16 @@ def _hinge_loss_components(
norm_rep = rep_edges.shape[1] + eps
elif normalization == "n_hits_oi":
norm_rep = n_hits_oi + eps
elif normalization == "n_att_edges":
norm_rep = att_edges.shape[1] + eps
else:
msg = f"Normalization {normalization} not recognized."
raise ValueError(msg)
v_rep = torch.sum(r_emb_hinge - torch.pow(dists_rep, p_rep)) / norm_rep
# Note: Relu necessary for p < 1
v_rep = (
torch.sum(torch.nn.functional.relu(r_emb_hinge - torch.pow(dists_rep, p_rep)))
/ norm_rep
)

return v_att, v_rep

Expand All @@ -60,6 +67,7 @@ def __init__(
p_attr: float = 1.0,
p_rep: float = 1.0,
rep_normalization: str = "n_hits_oi",
rep_oi_only: bool = True,
):
"""Loss for graph construction using metric learning.

Expand All @@ -73,8 +81,11 @@ def __init__(
p_attr: Power for the attraction term (default 1: linear loss)
p_rep: Power for the repulsion term (default 1: linear loss)
normalization: Normalization for the repulsive term. Can be either
"n_rep_edges" (normalizes by the number of repulsive edges) or
"n_hits_oi" (normalizes by the number of hits of interest).
"n_rep_edges" (normalizes by the number of repulsive edges < r_emb) or
"n_hits_oi" (normalizes by the number of hits of interest) or
"n_att_edges" (normalizes by the number of attractive edges of interest)
rep_oi_only: Only consider repulsion between hits if at least one
of the hits is of interest
"""
super().__init__()
self.save_hyperparameters()
Expand All @@ -92,7 +103,10 @@ def _get_edges(
)
# Every edge has to start at a particle of interest, so no special
# case with noise
rep_edges = near_edges[:, mask[near_edges[0]]]
if self.hparams.rep_oi_only:
rep_edges = near_edges[:, mask[near_edges[0]]]
else:
rep_edges = near_edges
rep_edges = rep_edges[:, particle_id[rep_edges[0]] != particle_id[rep_edges[1]]]
att_edges = true_edge_index[:, mask[true_edge_index[0]]]
return att_edges, rep_edges
Expand Down Expand Up @@ -162,3 +176,98 @@ def forward(
weight_dct=weights,
extra_metrics=extra,
)


@torch.jit.script
def _old_hinge_loss_components(
*,
x: T,
edge_index: T,
particle_id: T,
pt: T,
r_emb_hinge: float,
pt_thld: float,
p_attr: float,
p_rep: float,
) -> tuple[T, T]:
true_edge = (particle_id[edge_index[0]] == particle_id[edge_index[1]]) & (
particle_id[edge_index[0]] > 0
)
true_high_pt_edge = true_edge & (pt[edge_index[0]] > pt_thld)
dists = norm(x[edge_index[0]] - x[edge_index[1]], dim=-1)
normalization = true_high_pt_edge.sum() + 1e-8
return torch.sum(
torch.pow(dists[true_high_pt_edge], p_attr)
) / normalization, torch.sum(
relu(r_emb_hinge - torch.pow(dists[~true_edge], p_rep)) / normalization
)


class OldGraphConstructionHingeEmbeddingLoss(MultiLossFct, HyperparametersMixin):
    """Old implementation of the metric-learning hinge loss for graph
    construction, kept for comparison with the newer implementation above.
    """

    # noinspection PyUnusedLocal
    def __init__(
        self,
        *,
        r_emb=1,
        max_num_neighbors: int = 256,
        attr_pt_thld: float = 0.9,
        p_attr: float = 1,
        p_rep: float = 1,
        lw_repulsive: float = 1.0,
    ):
        """Loss for graph construction using metric learning.

        Args:
            r_emb: Radius for edge construction; also used as the hinge margin
                of the repulsive term
            max_num_neighbors: Maximum number of neighbors in radius graph building.
                See https://github.com/rusty1s/pytorch_cluster#radius-graph
            attr_pt_thld: Only edges whose source hit has pt above this
                threshold contribute to the attractive term
            p_attr: Power for the attraction term (default 1: linear loss)
            p_rep: Power for the repulsion term (default 1: linear loss)
            lw_repulsive: Loss weight of the repulsive term
        """
        super().__init__()
        self.save_hyperparameters()

    def _build_graph(self, x: T, batch: T, true_edge_index: T, pt: T) -> T:
        """Union of the radius graph in embedding space and the true edges
        above the pt threshold (deduplicated).
        """
        true_edge_mask = pt[true_edge_index[0]] > self.hparams.attr_pt_thld
        near_edges = radius_graph(
            x,
            r=self.hparams.r_emb,
            batch=batch,
            loop=False,
            max_num_neighbors=self.hparams.max_num_neighbors,
        )
        return torch.unique(
            torch.cat([true_edge_index[:, true_edge_mask], near_edges], dim=-1), dim=-1
        )

    # noinspection PyUnusedLocal
    def forward(
        self, *, x: T, particle_id: T, batch: T, true_edge_index: T, pt: T, **kwargs
    ) -> MultiLossFctReturn:
        """Compute attractive and repulsive hinge losses on the built graph.

        Args:
            x: Hit embeddings
            particle_id: Particle id per hit (0 = noise)
            batch: Batch assignment per hit (for radius graph building)
            true_edge_index: Ground-truth edges
            pt: Transverse momentum per hit
            **kwargs: Ignored (accepted for interface compatibility)
        """
        edge_index = self._build_graph(
            x=x, batch=batch, true_edge_index=true_edge_index, pt=pt
        )
        attr, rep = _old_hinge_loss_components(
            x=x,
            edge_index=edge_index,
            particle_id=particle_id,
            r_emb_hinge=self.hparams.r_emb,
            pt=pt,
            pt_thld=self.hparams.attr_pt_thld,
            p_attr=self.hparams.p_attr,
            p_rep=self.hparams.p_rep,
        )
        losses = {
            "attractive": attr,
            "repulsive": rep,
        }
        weights: dict[str, float] = {
            "attractive": 1.0,
            "repulsive": self.hparams.lw_repulsive,
        }
        extra = {}
        return MultiLossFctReturn(
            loss_dct=losses,
            weight_dct=weights,
            extra_metrics=extra,
        )
Loading