Skip to content

Commit

Permalink
fea: add tests for halfrank
Browse files Browse the repository at this point in the history
  • Loading branch information
CompRhys committed Dec 6, 2024
1 parent a1a17d8 commit a74e313
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 1 deletion.
2 changes: 1 addition & 1 deletion ax/modelbridge/transforms/sklearn_y.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def _fit(self, X, y=None, force_transform=False):
)

# Calculate rank quantiles
ranks = stats.rankdata(col, method="dense")
ranks = stats.rankdata(col, method="dense", nan_policy="omit")
dedup_median_index = np.searchsorted(unique_labels, median)
denominator = 2 * dedup_median_index + (
unique_labels[dedup_median_index] == median
Expand Down
94 changes: 94 additions & 0 deletions ax/modelbridge/transforms/tests/test_sklearn_y_transform.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
from ax.modelbridge.transforms.sklearn_y import (
HalfRankTransformer,
InfeasibleTransformer,
LogWarpingTransformer,
)
Expand Down Expand Up @@ -213,6 +214,99 @@ def test_p_feasible_calculation(self) -> None:
self.assertTrue(np.allclose(p_feasible_2, expected_p_feasible_2))


class TestHalfRankTransformer(TestCase):
def test_basic_transform(self) -> None:
"""Test basic transformation with simple data."""
X = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])
transformer = HalfRankTransformer()
X_transformed = transformer.fit_transform(X)

# Test shape preservation
self.assertEqual(X_transformed.shape, X.shape)

# Test that values above median are unchanged
medians = np.median(X, axis=0)
for i in range(X.shape[1]):
above_median_mask = X[:, i] >= medians[i]
self.assertTrue(
np.allclose(
X_transformed[above_median_mask, i], X[above_median_mask, i]
)
)

# Test inverse transform recovers original
X_recovered = transformer.inverse_transform(X_transformed)
self.assertTrue(np.allclose(X_recovered, X))

def test_nan_handling(self) -> None:
"""Test handling of NaN values."""
X = np.array([[1.0, np.nan], [3.0, 4.0], [np.nan, 2.0], [5.0, 6.0]])
transformer = HalfRankTransformer()
X_transformed = transformer.fit_transform(X)

# Test NaN values remain NaN
self.assertTrue(np.isnan(X_transformed[0, 1]))
self.assertTrue(np.isnan(X_transformed[2, 0]))

# Test non-NaN values are transformed
self.assertFalse(np.isnan(X_transformed[0, 0]))
self.assertFalse(np.isnan(X_transformed[1, 0]))
self.assertFalse(np.isnan(X_transformed[1, 1]))
self.assertFalse(np.isnan(X_transformed[3, 1]))

# Test inverse transform preserves NaN and recovers original values
X_recovered = transformer.inverse_transform(X_transformed)
self.assertTrue(np.allclose(X_recovered[~np.isnan(X)], X[~np.isnan(X)]))
self.assertTrue(np.isnan(X_recovered[0, 1]))
self.assertTrue(np.isnan(X_recovered[2, 0]))

def test_extrapolation(self) -> None:
"""Test extrapolation for values below minimum."""
X = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
transformer = HalfRankTransformer()
transformer.fit(X)

# Test with values below minimum
X_test = np.array([[0.0, 1.0], [2.0, 3.0]])
X_transformed = transformer.transform(X_test)

# Values below minimum should be transformed
self.assertNotEqual(X_transformed[0, 0], X_test[0, 0])

# Test inverse transform recovers original values
X_recovered = transformer.inverse_transform(X_transformed)
self.assertTrue(np.allclose(X_recovered, X_test))

def test_copy_behavior(self) -> None:
"""Test that copy parameter works as expected."""
X = np.array([[1.0, 2.0], [3.0, 4.0]])
X_orig = X.copy()

# Test with copy=True (default)
transformer = HalfRankTransformer(copy=True)
transformer.fit(X)
self.assertTrue(np.array_equal(X, X_orig)) # Original should be unchanged

# Test with copy=False
transformer = HalfRankTransformer(copy=False)
X_transform = transformer.fit_transform(X)
self.assertFalse(np.array_equal(X, X_orig)) # Original should be modified
self.assertTrue(np.array_equal(X, X_transform)) # Should be the same object

def test_input_validation(self) -> None:
"""Test input validation."""
transformer = HalfRankTransformer()

# Test 1D array raises error
with self.assertRaises(ValueError):
transformer.fit(np.array([1.0, 2.0]))

# Test wrong shape in transform after fit
transformer.fit(np.array([[1.0, 2.0], [3.0, 4.0]]))
with self.assertRaises(ValueError):
transformer.transform(np.array([[1.0], [2.0]]))


if __name__ == "__main__":
import unittest

Expand Down

0 comments on commit a74e313

Please sign in to comment.