Merge pull request #35 from brightbandtech/taylor/ope-64-metric-testing
Taylor/ope-64-metric-testing
aaTman authored Jan 25, 2025
2 parents accb13e + 4814ccd commit 4368d17
Showing 6 changed files with 322 additions and 38 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -11,4 +11,4 @@ wheels/
.ipynb_checkpoints/
.coverage*
.mypy_cache
.vscode
.vscode/
40 changes: 10 additions & 30 deletions src/extremeweatherbench/metrics.py
@@ -32,22 +32,8 @@ def align_datasets(
init_time_datetime: datetime.datetime,
):
"""Align the forecast and observation datasets."""
try:
forecast = forecast.sel(init_time=init_time_datetime)
# handles duplicate initialization times. please try to avoid this situation
except ValueError:
init_time_duplicate_length = len(
forecast.where(
forecast.init_time == init_time_datetime, drop=True
).init_time
)
if init_time_duplicate_length > 1:
logger.warning(
"init time %s has more than %d forecast associated with it, taking first only",
init_time_datetime,
init_time_duplicate_length,
)
forecast = forecast.sel(init_time=init_time_datetime).isel(init_time=0)

forecast = forecast.sel(init_time=init_time_datetime)
time = np.array(
[
init_time_datetime + pd.Timedelta(hours=int(t))
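For context on the branch removed above: when a forecast's init_time coordinate contains the same timestamp more than once, xarray's .sel returns every matching entry, and the deleted fallback reduced that to the first one with .isel(init_time=0). A minimal sketch with toy data (illustration only, not the package's datasets):

```python
import numpy as np
import pandas as pd
import xarray as xr

# Toy forecast with a duplicated init_time label.
init_times = pd.to_datetime(["2021-06-20", "2021-06-20", "2021-06-21"])
forecast = xr.DataArray(
    np.arange(3), dims=["init_time"], coords={"init_time": init_times}
)

# .sel keeps every entry matching a duplicated label, so init_time survives
# here as a dimension of length 2.
selected = forecast.sel(init_time=pd.Timestamp("2021-06-20"))

# The removed fallback collapsed that to a single initialization.
first_only = selected.isel(init_time=0)
print(selected.sizes, first_only.values)
```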
@@ -62,7 +48,7 @@ def align_datasets(

@dataclasses.dataclass
class RegionalRMSE(Metric):
"""Root mean squared error of a regional forecast evalauted against observations."""
"""Root mean squared error of a regional forecast evaluated against observations."""

def compute(self, forecast: xr.DataArray, observation: xr.DataArray):
rmse_values = []
@@ -108,10 +94,7 @@ def compute(self, forecast: xr.DataArray, observation: xr.DataArray):
coords={"lead_time": lead_time.values},
)
max_mae_values.append(max_mae_dataarray)
max_mae_full_da = xr.concat(max_mae_values, dim="lead_time")
# Reverse the lead time so that the minimum lead time is first
max_mae_full_da = max_mae_full_da.isel(lead_time=slice(None, None, -1))
max_mae_full_da = utils.expand_lead_times_to_6_hourly(max_mae_full_da)
max_mae_full_da = utils.process_dataarray_for_output(max_mae_values)
return max_mae_full_da


@@ -120,14 +103,14 @@ class MaxOfMinTempMAE(Metric):
"""Mean absolute error of forecasted highest minimum temperature values."""

def compute(self, forecast: xr.DataArray, observation: xr.DataArray):
max_mae_values = []
max_min_mae_values = []
observation_spatial_mean = observation.mean(["latitude", "longitude"])
observation_spatial_mean = observation_spatial_mean.where(
observation_spatial_mean.time.dt.hour % 6 == 0, drop=True
)
forecast_spatial_mean = forecast.mean(["latitude", "longitude"])
for init_time in forecast_spatial_mean.init_time:
if forecast.name == "air_temperature":
if forecast_spatial_mean.name == "air_temperature":
# Keep only times at 00, 06, 12, and 18Z
# Group by dayofyear and check if each day has all 4 synoptic times
valid_days = (
@@ -162,21 +145,18 @@ def compute(self, forecast: xr.DataArray, observation: xr.DataArray):
lead_time = init_forecast_spatial_mean.where(
init_forecast_spatial_mean.time == max_min_timestamp, drop=True
).lead_time
max_mae_dataarray = xr.DataArray(
max_min_mae_dataarray = xr.DataArray(
data=abs(
init_forecast_spatial_mean.max().values - max_min_value
),
dims=["lead_time"],
coords={"lead_time": lead_time.values},
)
max_mae_values.append(max_mae_dataarray)
max_min_mae_values.append(max_min_mae_dataarray)
else:
raise KeyError("Only air_temperature forecasts are supported.")
max_mae_full_da = xr.concat(max_mae_values, dim="lead_time")
# Reverse the lead time so that the minimum lead time is first
max_mae_full_da = max_mae_full_da.isel(lead_time=slice(None, None, -1))
max_mae_full_da = utils.expand_lead_times_to_6_hourly(max_mae_full_da)
return max_mae_full_da
max_min_mae_full_da = utils.process_dataarray_for_output(max_min_mae_values)
return max_min_mae_full_da
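The renamed max_min_mae_* variables track the quantity this metric compares against the forecast: the warmest daily minimum temperature over the synoptic hours. A minimal sketch of that reduction on a toy series (not the package's API; the real compute also aligns forecast valid times and lead times):

```python
import numpy as np
import pandas as pd
import xarray as xr

# Toy 3-hourly temperature series.
time = pd.date_range("2021-06-20", periods=48, freq="3h")
temps = xr.DataArray(
    20 + 5 * np.sin(np.arange(48) / 4.0), dims=["time"], coords={"time": time}
)

# Keep only the 00/06/12/18Z synoptic times, as the metric does.
synoptic = temps.where(temps.time.dt.hour % 6 == 0, drop=True)

# Daily minimum, then the maximum of those minima: the "max of min" target.
daily_min = synoptic.groupby("time.dayofyear").min()
print(float(daily_min.max()))
```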


@dataclasses.dataclass
36 changes: 34 additions & 2 deletions src/extremeweatherbench/utils.py
@@ -2,14 +2,15 @@
other specialized package.
"""

from typing import Union
from typing import Union, List, Optional
from collections import namedtuple
import fsspec
import geopandas as gpd
import numpy as np
import pandas as pd
import regionmask
import ujson
import rioxarray # noqa: F401
import xarray as xr
from kerchunk.hdf import SingleHdf5ToZarr
from shapely.geometry import box
@@ -269,5 +270,36 @@ def expand_lead_times_to_6_hourly(
final_times.append(hour)
dataarray = xr.DataArray(
data=final_data, dims=["lead_time"], coords={"lead_time": final_times}
)
).astype(float)
return dataarray


def process_dataarray_for_output(da_list: List[Optional[xr.DataArray]]):
"""Extract and format data from a list of DataArrays.
Args:
dataarray: A list of xarray DataArrays.
data: An xarray DataArray (likely unused in current implementation).
dims: Dimensions of the DataArray (likely unused in current implementation).
coords: Coordinates of the DataArray (likely unused in current implementation).
dim: Dimension name (likely unused in current implementation).
lead_time: Lead time coordinate name (likely unused in current implementation).
Returns:
An xarray DataArray with lead_time coordinate, expanded to 6-hourly intervals.
Returns a DataArray with a single NaN value if the input dataarray is empty.
"""

if len(da_list) == 0:
# create dummy nan dataarray in case no applicable forecast valid times exist
output_da = xr.DataArray(
data=[np.nan],
dims=["lead_time"],
coords={"lead_time": [0]},
)
else:
output_da = xr.concat(da_list, dim="lead_time")
# Reverse the lead time so that the minimum lead time is first
output_da = output_da.isel(lead_time=slice(None, None, -1))
output_da = expand_lead_times_to_6_hourly(output_da)
return output_da
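A short usage sketch of the new helper with toy inputs (assuming the package is importable as extremeweatherbench):

```python
from extremeweatherbench import utils

# With no applicable forecast valid times, the helper returns a placeholder
# DataArray holding a single NaN at lead_time 0.
placeholder = utils.process_dataarray_for_output([])
print(placeholder.values)            # [nan]
print(placeholder.lead_time.values)  # [0]

# With a non-empty list, the helper instead concatenates the pieces along
# lead_time, reverses them so the shortest lead time comes first, and expands
# the result to a 6-hourly grid via expand_lead_times_to_6_hourly.
```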
52 changes: 48 additions & 4 deletions tests/conftest.py
@@ -9,14 +9,16 @@
def mock_forecast_dataset():
init_time = pd.date_range("2021-06-20", periods=5)
lead_time = range(0, 241, 6)
data = np.random.rand(len(init_time), 180, 360, len(lead_time))
data = np.random.RandomState(21897820).standard_normal(
size=(len(init_time), 180, 360, len(lead_time)),
)
latitudes = np.linspace(-90, 90, 180)
longitudes = np.linspace(0, 359, 360)
dataset = xr.Dataset(
{
"air_temperature": (
["init_time", "latitude", "longitude", "lead_time"],
data,
20 + 5 * data,
),
"eastward_wind": (
["init_time", "latitude", "longitude", "lead_time"],
@@ -55,16 +57,58 @@ def mock_config():
@pytest.fixture
def mock_gridded_obs_dataset():
time = pd.date_range("2021-06-20", freq="3h", periods=200)
data = np.random.rand(len(time), 180, 360)
data = np.random.RandomState(21897820).standard_normal(size=(len(time), 180, 360))
latitudes = np.linspace(-90, 90, 180)
longitudes = np.linspace(0, 359, 360)
dataset = xr.Dataset(
{
"2m_temperature": (["time", "latitude", "longitude"], 20 + 5 * data),
"tp": (["time", "latitude", "longitude"], data),
"10m_u_component_of_wind": (["time", "latitude", "longitude"], data),
"10m_v_component_of_wind": (["time", "latitude", "longitude"], data),
},
coords={"time": time, "latitude": latitudes, "longitude": longitudes},
)
return dataset


@pytest.fixture
def mock_gridded_obs_dataset_max_in_forecast():
time = pd.date_range("2021-06-20", freq="3h", periods=200)
data = np.random.RandomState(21897820).standard_normal(size=(len(time), 180, 360))
data[10, :, :] = 5
latitudes = np.linspace(-90, 90, 180)
longitudes = np.linspace(0, 359, 360)
dataset = xr.Dataset(
{
"2m_temperature": (["time", "latitude", "longitude"], data),
"2m_temperature": (["time", "latitude", "longitude"], 20 + 5 * data),
"tp": (["time", "latitude", "longitude"], data),
"10m_u_component_of_wind": (["time", "latitude", "longitude"], data),
"10m_v_component_of_wind": (["time", "latitude", "longitude"], data),
},
coords={"time": time, "latitude": latitudes, "longitude": longitudes},
)
return dataset


@pytest.fixture
def mock_forecast_dataarray(mock_forecast_dataset):
return dataset_to_dataarray(mock_forecast_dataset)


@pytest.fixture
def mock_gridded_obs_dataarray(mock_gridded_obs_dataset):
return dataset_to_dataarray(mock_gridded_obs_dataset)


@pytest.fixture
def mock_gridded_obs_dataarray_max_in_forecast(
mock_gridded_obs_dataset_max_in_forecast,
):
return dataset_to_dataarray(mock_gridded_obs_dataset_max_in_forecast)


def dataset_to_dataarray(dataset):
"""Convert an xarray Dataset to a DataArray."""
mock_data_var = [data_var for data_var in dataset.data_vars][0]
return dataset[mock_data_var]
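A hypothetical test sketch (not part of this commit) showing how the new fixtures and the dataset_to_dataarray helper combine: each *_dataarray fixture is simply the first data variable of the corresponding mocked Dataset.

```python
def test_mock_dataarray_fixtures(mock_forecast_dataarray, mock_gridded_obs_dataarray):
    # dataset_to_dataarray picks the first data variable, so the forecast
    # fixture is "air_temperature" and the gridded obs fixture is "2m_temperature".
    assert mock_forecast_dataarray.name == "air_temperature"
    assert set(mock_forecast_dataarray.dims) == {
        "init_time", "latitude", "longitude", "lead_time"
    }
    assert mock_gridded_obs_dataarray.name == "2m_temperature"
    assert set(mock_gridded_obs_dataarray.dims) == {"time", "latitude", "longitude"}
```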
1 change: 0 additions & 1 deletion tests/test_case.py
@@ -1,6 +1,5 @@
import pytest
import extremeweatherbench.case as case
import rioxarray # noqa: F401
from extremeweatherbench.utils import Location
import datetime
