From d556a674e37a4b75d933fc0a921e3535f7d10832 Mon Sep 17 00:00:00 2001
From: Kilian Lieret
Date: Fri, 8 Mar 2024 21:15:19 -0500
Subject: [PATCH 1/3] Add legacy options to GC loss

---
 .../metrics/losses/metric_learning.py        | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/src/gnn_tracking/metrics/losses/metric_learning.py b/src/gnn_tracking/metrics/losses/metric_learning.py
index 6ad9fece..5a46e872 100644
--- a/src/gnn_tracking/metrics/losses/metric_learning.py
+++ b/src/gnn_tracking/metrics/losses/metric_learning.py
@@ -39,6 +39,8 @@ def _hinge_loss_components(
         norm_rep = rep_edges.shape[1] + eps
     elif normalization == "n_hits_oi":
         norm_rep = n_hits_oi + eps
+    elif normalization == "n_att_edges":
+        norm_rep = att_edges.shape[1] + eps
     else:
         msg = f"Normalization {normalization} not recognized."
         raise ValueError(msg)
@@ -60,6 +62,7 @@ def __init__(
         p_attr: float = 1.0,
         p_rep: float = 1.0,
         rep_normalization: str = "n_hits_oi",
+        rep_oi_only: bool = True,
     ):
         """Loss for graph construction using metric learning.
 
@@ -73,8 +76,11 @@ def __init__(
             p_attr: Power for the attraction term (default 1: linear loss)
             p_rep: Power for the repulsion term (default 1: linear loss)
             normalization: Normalization for the repulsive term. Can be either
-                "n_rep_edges" (normalizes by the number of repulsive edges) or
-                "n_hits_oi" (normalizes by the number of hits of interest).
+                "n_rep_edges" (normalizes by the number of repulsive edges < r_emb) or
+                "n_hits_oi" (normalizes by the number of hits of interest) or
+                "n_att_edges" (normalizes by the number of attractive edges of interest)
+            rep_oi_only: Only consider repulsion between hits if at least one
+                of the hits is of interest
         """
         super().__init__()
         self.save_hyperparameters()
@@ -92,7 +98,10 @@ def _get_edges(
         )
         # Every edge has to start at a particle of interest, so no special
         # case with noise
-        rep_edges = near_edges[:, mask[near_edges[0]]]
+        if self.hparams.rep_oi_only:
+            rep_edges = near_edges[:, mask[near_edges[0]]]
+        else:
+            rep_edges = near_edges
         rep_edges = rep_edges[:, particle_id[rep_edges[0]] != particle_id[rep_edges[1]]]
         att_edges = true_edge_index[:, mask[true_edge_index[0]]]
         return att_edges, rep_edges

From 4ea78e2873aa8c0a7d75648d5c97adaa0ff9c990 Mon Sep 17 00:00:00 2001
From: Kilian Lieret
Date: Fri, 8 Mar 2024 21:23:50 -0500
Subject: [PATCH 2/3] Fix: Need relu for p_rep < 1

---
 src/gnn_tracking/metrics/losses/metric_learning.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/gnn_tracking/metrics/losses/metric_learning.py b/src/gnn_tracking/metrics/losses/metric_learning.py
index 5a46e872..611aaf5f 100644
--- a/src/gnn_tracking/metrics/losses/metric_learning.py
+++ b/src/gnn_tracking/metrics/losses/metric_learning.py
@@ -44,7 +44,11 @@ def _hinge_loss_components(
     else:
         msg = f"Normalization {normalization} not recognized."
         raise ValueError(msg)
-    v_rep = torch.sum(r_emb_hinge - torch.pow(dists_rep, p_rep)) / norm_rep
+    # Note: Relu necessary for p < 1
+    v_rep = (
+        torch.sum(torch.nn.functional.relu(r_emb_hinge - torch.pow(dists_rep, p_rep)))
+        / norm_rep
+    )
     return v_att, v_rep
 
 

From 7a2a2f364f2e538a13b318f1dd5e83f999071ca4 Mon Sep 17 00:00:00 2001
From: Kilian Lieret
Date: Sat, 9 Mar 2024 10:08:18 -0500
Subject: [PATCH 3/3] Bring back old hinge loss implementation

---
 .../metrics/losses/metric_learning.py        | 96 +++++++++++++++++++
 1 file changed, 96 insertions(+)

diff --git a/src/gnn_tracking/metrics/losses/metric_learning.py b/src/gnn_tracking/metrics/losses/metric_learning.py
index 611aaf5f..1a570475 100644
--- a/src/gnn_tracking/metrics/losses/metric_learning.py
+++ b/src/gnn_tracking/metrics/losses/metric_learning.py
@@ -2,6 +2,7 @@
 from pytorch_lightning.core.mixins.hparams_mixin import HyperparametersMixin
 from torch import Tensor as T
 from torch.linalg import norm
+from torch.nn.functional import relu
 from torch_cluster import radius_graph
 
 from gnn_tracking.metrics.losses import MultiLossFct, MultiLossFctReturn
@@ -175,3 +176,98 @@ def forward(
         weight_dct=weights,
         extra_metrics=extra,
     )
+
+
+@torch.jit.script
+def _old_hinge_loss_components(
+    *,
+    x: T,
+    edge_index: T,
+    particle_id: T,
+    pt: T,
+    r_emb_hinge: float,
+    pt_thld: float,
+    p_attr: float,
+    p_rep: float,
+) -> tuple[T, T]:
+    true_edge = (particle_id[edge_index[0]] == particle_id[edge_index[1]]) & (
+        particle_id[edge_index[0]] > 0
+    )
+    true_high_pt_edge = true_edge & (pt[edge_index[0]] > pt_thld)
+    dists = norm(x[edge_index[0]] - x[edge_index[1]], dim=-1)
+    normalization = true_high_pt_edge.sum() + 1e-8
+    return torch.sum(
+        torch.pow(dists[true_high_pt_edge], p_attr)
+    ) / normalization, torch.sum(
+        relu(r_emb_hinge - torch.pow(dists[~true_edge], p_rep)) / normalization
+    )
+
+
+class OldGraphConstructionHingeEmbeddingLoss(MultiLossFct, HyperparametersMixin):
+    # noinspection PyUnusedLocal
+    def __init__(
+        self,
+        *,
+        r_emb=1,
+        max_num_neighbors: int = 256,
+        attr_pt_thld: float = 0.9,
+        p_attr: float = 1,
+        p_rep: float = 1,
+        lw_repulsive: float = 1.0,
+    ):
+        """Loss for graph construction using metric learning.
+
+        Args:
+            r_emb: Radius for edge construction
+            max_num_neighbors: Maximum number of neighbors in radius graph building.
+                See https://github.com/rusty1s/pytorch_cluster#radius-graph
+            p_attr: Power for the attraction term (default 1: linear loss)
+            p_rep: Power for the repulsion term (default 1: linear loss)
+        """
+        super().__init__()
+        self.save_hyperparameters()
+
+    def _build_graph(self, x: T, batch: T, true_edge_index: T, pt: T) -> T:
+        true_edge_mask = pt[true_edge_index[0]] > self.hparams.attr_pt_thld
+        near_edges = radius_graph(
+            x,
+            r=self.hparams.r_emb,
+            batch=batch,
+            loop=False,
+            max_num_neighbors=self.hparams.max_num_neighbors,
+        )
+        return torch.unique(
+            torch.cat([true_edge_index[:, true_edge_mask], near_edges], dim=-1), dim=-1
+        )
+
+    # noinspection PyUnusedLocal
+    def forward(
+        self, *, x: T, particle_id: T, batch: T, true_edge_index: T, pt: T, **kwargs
+    ) -> dict[str, T]:
+        edge_index = self._build_graph(
+            x=x, batch=batch, true_edge_index=true_edge_index, pt=pt
+        )
+        attr, rep = _old_hinge_loss_components(
+            x=x,
+            edge_index=edge_index,
+            particle_id=particle_id,
+            r_emb_hinge=self.hparams.r_emb,
+            pt=pt,
+            pt_thld=self.hparams.attr_pt_thld,
+            p_attr=self.hparams.p_attr,
+            p_rep=self.hparams.p_rep,
+        )
+        losses = {
+            "attractive": attr,
+            "repulsive": rep,
+        }
+        weights: dict[str, float] = {
+            "attractive": 1.0,
+            "repulsive": self.hparams.lw_repulsive,
+        }
+        extra = {}
+        return MultiLossFctReturn(
+            loss_dct=losses,
+            weight_dct=weights,
+            extra_metrics=extra,
+        )
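
A note on PATCH 2: when the repulsive edges come from a radius graph with r = r_emb
(as _get_edges suggests) and the hinge radius equals that radius, p_rep = 1 keeps
every summand r_emb_hinge - dists_rep non-negative. For p_rep < 1 (with a hinge
radius below 1), the powered distance can exceed the hinge even for edges inside
the radius, so the unclamped sum turns negative and keeps decreasing as pairs
separate further, instead of going to zero at the hinge. The relu zeroes those
contributions. A minimal, self-contained sketch with made-up numbers (plain
PyTorch, not the library API; the normalization here is simply the edge count):

    import torch
    from torch.nn.functional import relu

    r_emb_hinge = 0.5                             # hypothetical hinge radius
    p_rep = 0.5                                   # sub-linear repulsion power
    dists_rep = torch.tensor([0.01, 0.09, 0.49])  # all inside the hinge radius

    # Without the clamp, the 0.49 edge contributes 0.5 - 0.49**0.5 = -0.2,
    # lowering the loss even though that pair is already the best separated.
    unclamped = torch.sum(r_emb_hinge - torch.pow(dists_rep, p_rep)) / len(dists_rep)

    # With the clamp (PATCH 2), edges whose powered distance exceeds the hinge
    # contribute exactly zero.
    clamped = torch.sum(relu(r_emb_hinge - torch.pow(dists_rep, p_rep))) / len(dists_rep)

    print(round(unclamped.item(), 3), round(clamped.item(), 3))  # 0.133 0.2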
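
And a usage sketch for the OldGraphConstructionHingeEmbeddingLoss restored in
PATCH 3. Tensor shapes and values are made up, the import path is inferred from
the file path in the diff, and it assumes MultiLossFct is a torch.nn.Module
subclass so the loss instance is callable; torch_cluster must be installed for
radius_graph:

    import torch
    from gnn_tracking.metrics.losses.metric_learning import (
        OldGraphConstructionHingeEmbeddingLoss,
    )

    loss_fct = OldGraphConstructionHingeEmbeddingLoss(
        r_emb=1.0,
        max_num_neighbors=256,
        attr_pt_thld=0.9,
        p_attr=1.0,
        p_rep=1.0,
        lw_repulsive=1.0,
    )

    n_hits = 1000
    x = torch.randn(n_hits, 8)                        # embedded hit coordinates
    particle_id = torch.randint(0, 200, (n_hits,))    # 0 = noise hit
    pt = 2.0 * torch.rand(n_hits)                     # transverse momenta
    batch = torch.zeros(n_hits, dtype=torch.long)     # single event
    true_edge_index = torch.randint(0, n_hits, (2, 5000))  # stand-in true edges

    out = loss_fct(
        x=x, particle_id=particle_id, batch=batch,
        true_edge_index=true_edge_index, pt=pt,
    )
    # `out` is a MultiLossFctReturn carrying the "attractive"/"repulsive" losses
    # and their weights (1.0 and lw_repulsive); reducing it to a single scalar
    # is handled by the MultiLossFct machinery outside this patch.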