From e34484242feaf96774add82e3d9ea618685d7ce9 Mon Sep 17 00:00:00 2001 From: Xingyou Song Date: Thu, 21 Nov 2024 07:30:36 -0800 Subject: [PATCH] Rename `regression_metrics` -> `metrics` PiperOrigin-RevId: 698773892 --- .../{regression_metrics.py => metrics.py} | 0 optformer/embed_then_regress/train.py | 6 +++--- 2 files changed, 3 insertions(+), 3 deletions(-) rename optformer/embed_then_regress/{regression_metrics.py => metrics.py} (100%) diff --git a/optformer/embed_then_regress/regression_metrics.py b/optformer/embed_then_regress/metrics.py similarity index 100% rename from optformer/embed_then_regress/regression_metrics.py rename to optformer/embed_then_regress/metrics.py diff --git a/optformer/embed_then_regress/train.py b/optformer/embed_then_regress/train.py index 71c24db..303e027 100644 --- a/optformer/embed_then_regress/train.py +++ b/optformer/embed_then_regress/train.py @@ -27,7 +27,7 @@ from optformer.embed_then_regress import checkpointing as ckpt_lib from optformer.embed_then_regress import configs from optformer.embed_then_regress import icl_transformer -from optformer.embed_then_regress import regression_metrics +from optformer.embed_then_regress import metrics as metrics_lib import tensorflow as tf @@ -109,10 +109,10 @@ def loss_fn( target_mask = 1 - batch['mask'] # [B, L] target_nlogprob = nlogprob * target_mask # [B, L] - avg_nlogprob = regression_metrics.masked_mean(target_nlogprob, target_mask) + avg_nlogprob = metrics_lib.masked_mean(target_nlogprob, target_mask) loss = jnp.mean(avg_nlogprob) # [B] -> Scalar - metrics = regression_metrics.default_metrics(mean, batch['y'], target_mask) + metrics = metrics_lib.default_metrics(mean, batch['y'], target_mask) metrics['loss'] = loss return loss, metrics