-
Notifications
You must be signed in to change notification settings - Fork 12
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Update _get_weights()
method for SpatialAccessor
and TemporalAccessor
#252
Conversation
elif isinstance(weights, xr.DataArray): | ||
dv_weights = weights | ||
|
||
self._validate_weights(dv, axis, dv_weights) | ||
dataset[dv.name] = self._averager(dv, axis, dv_weights) | ||
return dataset | ||
|
||
def get_weights( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I assume this is the only line that changed in this file (aside from calls to _get_weights
)? Github is highlighting this whole function...probably because it got moved?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, this method was moved and renamed. Nothing else was changed in this file.
tests/test_spatial.py
Outdated
# FIXME: ValueError when domain bounds contains lower bound larger than | ||
# upper bound |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this should throw an error? I'm kind of confused why it wasn't before based on this check. Should this test have region bounds such that lon_bounds=np.array([350, 20])
or something like that?
I think region bounds should be able to accept a larger right hand bound, but I don't think this should be true for domain bounds.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I ended up deleting this test because it is faulty and we already have a test for raising a ValueError if the domain bounds has a larger lower bound vs. upper bound.
), | ||
"lon": xr.DataArray( | ||
name="lon_wts", | ||
data=np.array([1, 2, 3, 4]), | ||
coords={"lon": self.ds.lon}, | ||
dims=["lon"], | ||
), | ||
} | ||
|
||
def test_weights_for_single_axis_are_identical(self): | ||
axis_weights = self.axis_weights | ||
del axis_weights["lon"] | ||
|
||
result = self.ds.spatial._combine_weights(axis_weights=self.axis_weights) | ||
expected = self.axis_weights["lat"] | ||
|
||
assert result.identical(expected) | ||
|
||
def test_weights_for_multiple_axis_is_the_product_of_matrix_multiplication(self): | ||
result = self.ds.spatial._combine_weights(axis_weights=self.axis_weights) | ||
expected = xr.DataArray( | ||
data=np.array([[1, 2, 3, 4], [2, 4, 6, 8], [3, 6, 9, 12], [4, 8, 12, 16]]), | ||
coords={"lat": self.ds.lat, "lon": self.ds.lon}, | ||
dims=["lat", "lon"], | ||
) | ||
|
||
assert result.identical(expected) | ||
|
||
|
||
class TestAverager: | ||
@pytest.fixture(autouse=True) | ||
def setup(self): | ||
self.ds = generate_dataset(cf_compliant=True, has_bounds=True) | ||
|
||
@requires_dask | ||
def test_chunked_weighted_avg_over_lat_and_lon_axes(self): | ||
ds = self.ds.copy().chunk(2) | ||
|
||
weights = xr.DataArray( | ||
data=np.array([[1, 2, 3, 4], [2, 4, 6, 8], [3, 6, 9, 12], [4, 8, 12, 16]]), | ||
coords={"lat": ds.lat, "lon": ds.lon}, | ||
dims=["lat", "lon"], | ||
) | ||
|
||
result = ds.spatial._averager(ds.ts, axis=["X", "Y"], weights=weights) | ||
expected = xr.DataArray( | ||
name="ts", data=np.ones(15), coords={"time": ds.time}, dims=["time"] | ||
) | ||
|
||
assert result.identical(expected) | ||
|
||
def test_weighted_avg_over_lat_axis(self): | ||
weights = xr.DataArray( | ||
name="lat_wts", | ||
data=np.array([1, 2, 3, 4]), | ||
coords={"lat": self.ds.lat}, | ||
dims=["lat"], | ||
) | ||
|
||
result = self.ds.spatial._averager(self.ds.ts, axis=["Y"], weights=weights) | ||
expected = xr.DataArray( | ||
name="ts", | ||
data=np.ones((15, 4)), | ||
coords={"time": self.ds.time, "lon": self.ds.lon}, | ||
dims=["time", "lon"], | ||
) | ||
|
||
assert result.identical(expected) | ||
|
||
def test_weighted_avg_over_lon_axis(self): | ||
weights = xr.DataArray( | ||
name="lon_wts", | ||
data=np.array([1, 2, 3, 4]), | ||
coords={"lon": self.ds.lon}, | ||
dims=["lon"], | ||
) | ||
|
||
result = self.ds.spatial._averager(self.ds.ts, axis=["X"], weights=weights) | ||
expected = xr.DataArray( | ||
name="ts", | ||
data=np.ones((15, 4)), | ||
coords={"time": self.ds.time, "lat": self.ds.lat}, | ||
dims=["time", "lat"], | ||
) | ||
|
||
assert result.identical(expected) | ||
|
||
def test_weighted_avg_over_lat_and_lon_axis(self): | ||
weights = xr.DataArray( | ||
data=np.array([[1, 2, 3, 4], [2, 4, 6, 8], [3, 6, 9, 12], [4, 8, 12, 16]]), | ||
coords={"lat": self.ds.lat, "lon": self.ds.lon}, | ||
dims=["lat", "lon"], | ||
) | ||
|
||
result = self.ds.spatial._averager(self.ds.ts, axis=["X", "Y"], weights=weights) | ||
expected = xr.DataArray( | ||
name="ts", data=np.ones(15), coords={"time": self.ds.time}, dims=["time"] | ||
) | ||
|
||
assert result.identical(expected) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I kind of remember you commenting on this...is this all removed because they are private methods (and we don't need to test them)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same for TestGetWeights
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Correct, we should test the public method which will cover the private methods (implementation details).
No need to review any of these changes here so I am just porting some tests over from the private methods to the public methods.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have no objection to making get_weights()
public, though I didn't fully understand the the changes to the unit tests. Should I view the changes in tests/
as a reorganization or something that I should review more carefully?
I think the error for test_weights_for_region_in_lon_domain_with_both_spanning_p_meridian
is correct and we should represent the problematic line as ds.lon_bnds.data[:] = np.array([[-1, 1], [1, 90], [90, 180], [180, 359]])
.
I addressed this comment in one of the PR review comments.
I ended up removing tests where the domain bounds has a larger left hand value than right hand value because they are supposed to throw an error, rather than pass. There's already a test to check that an error is thrown for this case. I kept |
Codecov Report
@@ Coverage Diff @@
## main #252 +/- ##
=========================================
Coverage 100.00% 100.00%
=========================================
Files 9 8 -1
Lines 742 736 -6
=========================================
- Hits 742 736 -6
Continue to review full report at Codecov.
|
3c6c5a3
to
2d82d14
Compare
Hi @pochedls, @lee1043, and @chengzhuzhang, instead of making
Example for Lines 332 to 339 in 44620c0
Lines 742 to 745 in 44620c0
Example for Lines 465 to 472 in 44620c0
Lines 624 to 629 in 44620c0
Lines 1369 to 1402 in 44620c0
If this looks good, we can implement |
Hi @tomvothecoder Thank you for the discussion. I don't recall that I have a use case that requires to get/keep the weights to pass down to another operation. With the exception that, they are useful to be examed during xcdat development/validation. So I don't have strong opinion on the implementation. I will defer to Steve and Jiwoo. |
I like I think it would be good to keep |
There is only one Keeping weights with Climatology
Keeping weights with Departures
Sounds good, I added
Hmmm, we can address this if users open up an issue about it, unless you think it is worth addressing now. |
Got it – thanks Tom. |
- Add kwarg `keep_weights` to both methods - Make `_get_weights()` a public method for `SpatialAccessor` - Delete redundant tests for private methods - Delete faulty tests with domain bounds:
225df77
to
634dca4
Compare
SpatialAccessor._get_weights()
and TemporalAccessor._get_weights()
public_get_weights()
method for SpatialAccessor
and TemporalAccessor
Description
Summary of Changes
_get_weights()
method forSpatialAccessor
andTemporalAccessor
#251None
value forlat_bounds
andlon_bounds
keyword args toSpatialAccessor.get_weights()
keep_weights
to both methods_get_weights()
a public method forSpatialAccessor
Checklist
If applicable: