diff --git a/torchrec/metrics/ndcg.py b/torchrec/metrics/ndcg.py index 816a8e678..61be6525d 100644 --- a/torchrec/metrics/ndcg.py +++ b/torchrec/metrics/ndcg.py @@ -253,7 +253,7 @@ def _get_ndcg_states( dim=-1, index=expanded_session_ids, src=adjusted_weights, # [num_tasks, batch_size] - reduce="max", + reduce="amax", ) )