Skip to content

Commit

Permalink
fix: address subset issue highlighted in #95
Browse files: browse the repository at this point in the history
  • Loading branch information
CompRhys committed Jan 20, 2025
1 parent bfd5f16 commit f5bfcc5
Showing 1 changed file with 7 additions and 10 deletions.
17 changes: 7 additions & 10 deletions aviary/utils.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -316,11 +316,6 @@ def train_ensemble(
when early stopping. Defaults to None.
verbose (bool, optional): Whether to show progress bars for each epoch.
"""
if isinstance(train_set, Subset):
train_set = train_set.dataset
if isinstance(val_set, Subset):
val_set = val_set.dataset

train_loader = DataLoader(train_set, **data_params)
print(f"Training on {len(train_set):,} samples")

Expand Down Expand Up @@ -354,7 +349,13 @@ def train_ensemble(

for target, normalizer in normalizer_dict.items():
if normalizer is not None:
sample_target = Tensor(train_set.df[target].values)
if isinstance(train_set, Subset):
sample_target = Tensor(
train_set.dataset.df[target].iloc[train_set.indices].values
)
else:
sample_target = Tensor(train_set.df[target].values)

if not restart_params["resume"]:
normalizer.fit(sample_target)
print(f"Dummy MAE: {(sample_target - normalizer.mean).abs().mean():.4f}")
Expand Down Expand Up @@ -455,10 +456,6 @@ def results_multitask(
"------------Evaluate model on Test Set------------\n"
"~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n"
)

if isinstance(test_set, Subset):
test_set = test_set.dataset

test_loader = DataLoader(test_set, **data_params)
print(f"Testing on {len(test_set):,} samples")

Expand Down

0 comments on commit f5bfcc5

Please sign in to comment.