Skip to content

Commit

Permalink
requested changes
Browse files Browse the repository at this point in the history
  • Loading branch information
aaTman committed Jan 30, 2025
1 parent 106780a commit 4b48b98
Show file tree
Hide file tree
Showing 8 changed files with 283 additions and 339 deletions.
226 changes: 99 additions & 127 deletions src/extremeweatherbench/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,38 +24,6 @@ def name(self) -> str:
"""Return the class name without parentheses."""
return self.__class__.__name__

def _align_observations_temporal_resolution(
    self, forecast: xr.DataArray, observation: xr.DataArray
) -> xr.DataArray:
    """Coarsen observations to the forecast time step when they are finer.

    Metrics that pick a single observed timestep fail when that timestep has
    no forecast counterpart (e.g. a 03z minimum temperature in the
    observations while the forecast only provides 00z and 06z).

    Args:
        forecast: The forecast dataarray; the mean spacing of its
            ``lead_time`` values is interpreted as hours.
        observation: The observation dataarray; the mean spacing of its
            ``time`` values is used as its resolution.

    Returns:
        The observations resampled onto the forecast step (keeping the first
        value of each bin) when the forecast is coarser; otherwise the
        observations unchanged.
    """
    observation_step = pd.to_timedelta(np.diff(observation.time).mean())
    forecast_step = pd.to_timedelta(np.diff(forecast.lead_time).mean(), unit="h")

    # Forecast is coarser: thin the observations down to the forecast step.
    if forecast_step > observation_step:
        return observation.resample(time=forecast_step).first()

    # Any other mismatch cannot be fixed by resampling, so only report it.
    if forecast_step != observation_step:
        logger.warning(
            "Observation time resolution (%s) is coarser than forecast time resolution (%s)",
            observation_step,
            forecast_step,
        )

    return observation


@dataclasses.dataclass
class RegionalRMSE(Metric):
Expand All @@ -78,117 +46,121 @@ def compute(self, forecast: xr.DataArray, observation: xr.DataArray):

@dataclasses.dataclass
class MaximumMAE(Metric):
"""Mean absolute error of forecasted maximum values."""
"""Mean absolute error of forecasted maximum values.
Attributes:
time_deviation_tolerance: amount of time in hours to allow for forecast deviation from the observed maximum
temperature timestamp.
"""

time_deviation_tolerance: int = 48

def compute(self, forecast: xr.DataArray, observation: xr.DataArray):
if forecast.name == "air_temperature":
max_mae_values = []
observation_spatial_mean = observation.mean(["latitude", "longitude"])
forecast_spatial_mean = forecast.mean(["latitude", "longitude"])
observation_spatial_mean = self._align_observations_temporal_resolution(
forecast_spatial_mean, observation_spatial_mean
max_mae_values = []
observation_spatial_mean = observation.mean(["latitude", "longitude"])
forecast_spatial_mean = forecast.mean(["latitude", "longitude"])
observation_spatial_mean = utils.align_observations_temporal_resolution(
forecast_spatial_mean, observation_spatial_mean
)
for init_time in forecast_spatial_mean.init_time:
max_datetime = observation_spatial_mean.idxmax("time").values
max_value = observation_spatial_mean.sel(time=max_datetime).values
init_forecast_spatial_mean, _ = utils.temporal_align_dataarrays(
forecast_spatial_mean,
observation_spatial_mean,
pd.Timestamp(init_time.values).to_pydatetime(),
)
for init_time in forecast_spatial_mean.init_time:
max_datetime = observation_spatial_mean.idxmax("time").values
max_value = observation_spatial_mean.sel(time=max_datetime).values
init_forecast_spatial_mean, _ = utils.temporal_align_dataarrays(
forecast_spatial_mean,
observation_spatial_mean,
pd.Timestamp(init_time.values).to_pydatetime(),

if max_datetime in init_forecast_spatial_mean.time.values:
# Subset to +-48 hours centered on the maximum temperature timestamp
filtered_forecast = utils.center_forecast_on_time(
init_forecast_spatial_mean,
time=pd.Timestamp(max_datetime),
hours=self.time_deviation_tolerance,
)
lead_time = filtered_forecast.where(
filtered_forecast.time == max_datetime, drop=True
).lead_time
max_mae_dataarray = xr.DataArray(
data=[abs(filtered_forecast.max().values - max_value)],
dims=["lead_time"],
coords={"lead_time": lead_time.values},
)
max_mae_values.append(max_mae_dataarray)

if max_datetime in init_forecast_spatial_mean.time.values:
# Subset to +-48 hours centered on the maximum temperature timestamp
filtered_forecast = utils.center_forecast_on_time(
init_forecast_spatial_mean,
time=pd.Timestamp(max_datetime),
hours=48,
)
lead_time = filtered_forecast.where(
filtered_forecast.time == max_datetime, drop=True
).lead_time
max_mae_dataarray = xr.DataArray(
data=[abs(filtered_forecast.max().values - max_value)],
dims=["lead_time"],
coords={"lead_time": lead_time.values},
)
max_mae_values.append(max_mae_dataarray)
else:
raise NotImplementedError(
"Only air_temperature is currently supported for MaximumMAE."
)
max_mae_full_da = utils.process_dataarray_for_output(max_mae_values)
return max_mae_full_da


@dataclasses.dataclass
class MaxOfMinTempMAE(Metric):
"""Mean absolute error of forecasted highest minimum temperature values."""
"""Mean absolute error of forecasted highest minimum temperature values.
Attributes:
time_deviation_tolerance: amount of time in hours to allow for forecast deviation from the observed maximum
temperature timestamp.
"""

time_deviation_tolerance: int = 48

def compute(self, forecast: xr.DataArray, observation: xr.DataArray):
if forecast.name == "air_temperature":
max_min_mae_values = []
observation_spatial_mean = observation.mean(["latitude", "longitude"])
forecast_spatial_mean = forecast.mean(["latitude", "longitude"])
# Verify observation_spatial_mean's time resolution matches forecast_spatial_mean
observation_spatial_mean = self._align_observations_temporal_resolution(
forecast_spatial_mean, observation_spatial_mean
)
observation_spatial_mean = self._truncate_incomplete_days(
observation_spatial_mean
max_min_mae_values = []
observation_spatial_mean = observation.mean(["latitude", "longitude"])
forecast_spatial_mean = forecast.mean(["latitude", "longitude"])
# Verify observation_spatial_mean's time resolution matches forecast_spatial_mean
observation_spatial_mean = utils.align_observations_temporal_resolution(
forecast_spatial_mean, observation_spatial_mean
)
observation_spatial_mean = self._truncate_incomplete_days(
observation_spatial_mean
)
max_min_timestamp = self._return_max_min_timestamp(observation_spatial_mean)
max_min_value = observation_spatial_mean.sel(time=max_min_timestamp).values

for init_time in forecast_spatial_mean.init_time:
init_forecast_spatial_mean, _ = utils.temporal_align_dataarrays(
forecast_spatial_mean,
observation_spatial_mean,
pd.Timestamp(init_time.values).to_pydatetime(),
)
max_min_timestamp = self._return_max_min_timestamp(observation_spatial_mean)
max_min_value = observation_spatial_mean.sel(time=max_min_timestamp).values

for init_time in forecast_spatial_mean.init_time:
init_forecast_spatial_mean, _ = utils.temporal_align_dataarrays(
forecast_spatial_mean,
observation_spatial_mean,
pd.Timestamp(init_time.values).to_pydatetime(),
if max_min_timestamp in init_forecast_spatial_mean.time.values:
filtered_forecast = self._truncate_incomplete_days(
init_forecast_spatial_mean
)
if max_min_timestamp in init_forecast_spatial_mean.time.values:
filtered_forecast = self._truncate_incomplete_days(
init_forecast_spatial_mean
)
filtered_forecast = utils.center_forecast_on_time(
filtered_forecast,
time=pd.Timestamp(max_min_timestamp),
hours=48,
)
# Ensure that the forecast has a full day of data for each day
# after centering on the max of min timestamp
filtered_forecast = self._truncate_incomplete_days(
filtered_forecast
)
lead_time = filtered_forecast.where(
filtered_forecast.time == max_min_timestamp, drop=True
).lead_time
filtered_forecast_max_min = filtered_forecast.where(
filtered_forecast
== filtered_forecast.groupby("time.dayofyear").min().max(),
drop=True,
filtered_forecast = utils.center_forecast_on_time(
filtered_forecast,
time=pd.Timestamp(max_min_timestamp),
hours=self.time_deviation_tolerance,
)
# Ensure that the forecast has a full day of data for each day
# after centering on the max of min timestamp
filtered_forecast = self._truncate_incomplete_days(filtered_forecast)
lead_time = filtered_forecast.where(
filtered_forecast.time == max_min_timestamp, drop=True
).lead_time
filtered_forecast_max_min = filtered_forecast.where(
filtered_forecast
== filtered_forecast.groupby("time.dayofyear").min().max(),
drop=True,
)
# TODO: add temporal displacement error, which is
# filtered_forecast_max_min.time.values[0] - max_min_timestamp
if max_min_timestamp in filtered_forecast.time.values:
max_min_mae_dataarray = xr.DataArray(
data=abs(filtered_forecast_max_min - max_min_value),
dims=["lead_time"],
coords={"lead_time": lead_time.values},
attrs={
"description": (
"Mean absolute error of forecasted highest minimum temperature values,"
"where lead_time is the time from initialization until the highest minimum"
"observed temperature."
)
},
)
# TODO: add temporal displacement error, which is
# filtered_forecast_max_min.time.values[0] - max_min_timestamp
if max_min_timestamp in filtered_forecast.time.values:
max_min_mae_dataarray = xr.DataArray(
data=abs(filtered_forecast_max_min - max_min_value),
dims=["lead_time"],
coords={"lead_time": lead_time.values},
attrs={
"description": (
"Mean absolute error of forecasted highest minimum temperature values,"
"where lead_time is the time from initialization until the highest minimum"
"observed temperature."
)
},
)
max_min_mae_values.append(max_min_mae_dataarray)
else:
raise NotImplementedError(
"Only air_temperature is currently supported for MaxOfMinTempMAE."
)

max_min_mae_values.append(max_min_mae_dataarray)
max_min_mae_full_da = utils.process_dataarray_for_output(max_min_mae_values)
return max_min_mae_full_da

Expand Down
41 changes: 35 additions & 6 deletions src/extremeweatherbench/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
other specialized package.
"""

from typing import Union, List, Optional
from typing import Union, List
from collections import namedtuple
import fsspec
import geopandas as gpd
Expand Down Expand Up @@ -275,15 +275,14 @@ def expand_lead_times_to_6_hourly(
return dataarray


def process_dataarray_for_output(da_list: List[Optional[xr.DataArray]]) -> xr.DataArray:
def process_dataarray_for_output(da_list: List[xr.DataArray]) -> xr.DataArray:
"""Extract and format data from a list of DataArrays.
Args:
dataarray: A list of xarray DataArrays.
da_list: A list of xarray DataArrays.
Returns:
An xarray DataArray with lead_time coordinate, expanded to 6-hourly intervals.
Returns a DataArray with a single NaN value if the input dataarray is empty.
A DataArray with a sorted lead_time coordinate, expanded to 6-hourly intervals.
"""

if len(da_list) == 0:
Expand All @@ -303,13 +302,16 @@ def process_dataarray_for_output(da_list: List[Optional[xr.DataArray]]) -> xr.Da

def center_forecast_on_time(da: xr.DataArray, time: pd.Timestamp, hours: int):
    """Subset a forecast DataArray to a window centered on a given time.

    Args:
        da: The forecast DataArray to center; it is selected along its
            ``time`` coordinate.
        time: The time to center the forecast on.
        hours: Half-width of the window in hours, i.e. the selection runs
            from ``time - hours`` to ``time + hours``.

    Returns:
        The label-based ``time`` slice of ``da`` between the two window bounds.
    """
    center = pd.to_datetime(time)
    half_window = pd.Timedelta(hours=hours)
    # Only the two bounds are needed for the slice, so compute them directly
    # instead of materializing an hourly date range and reading its ends.
    return da.sel(time=slice(center - half_window, center + half_window))

Expand All @@ -320,10 +322,12 @@ def temporal_align_dataarrays(
init_time_datetime: datetime.datetime,
) -> tuple[xr.DataArray, xr.DataArray]:
"""Align the individual initialization time forecast and observation dataarrays.
Args:
forecast: The forecast dataarray to align.
observation: The observation dataarray to align.
init_time_datetime: The initialization time to subset the forecast dataarray by.
Returns:
A tuple containing the time-aligned forecast and observation dataarrays.
"""
Expand All @@ -336,3 +340,28 @@ def temporal_align_dataarrays(
forecast = forecast.swap_dims({"lead_time": "time"})
forecast, observation = xr.align(forecast, observation, join="inner")
return (forecast, observation)


def align_observations_temporal_resolution(
    forecast: xr.DataArray, observation: xr.DataArray
) -> xr.DataArray:
    """Align the observation dataarray to the forecast's temporal resolution.

    Metrics which need a singular timestep from gridded obs will fail if the
    forecasts are not aligned with the observation timestamps (e.g. a 03z
    minimum temp in observations when the forecast only has 00z and 06z
    timesteps).

    Args:
        forecast: The forecast data which will be aligned against; the mean
            spacing of its ``lead_time`` values is interpreted as hours.
        observation: The observation data to align; the mean spacing of its
            ``time`` values is used as its resolution.

    Returns:
        The observations resampled onto the forecast step (keeping the first
        value of each bin) when the forecast is coarser than the
        observations; otherwise the observations unchanged.
    """
    # A time step is only defined with at least two points on each axis.
    # Without that there is nothing to compare, so return the observations
    # untouched rather than averaging an empty difference array.
    if np.size(observation.time) < 2 or np.size(forecast.lead_time) < 2:
        return observation

    obs_time_delta = pd.to_timedelta(np.diff(observation.time).mean())
    forecast_time_delta = pd.to_timedelta(
        np.diff(forecast.lead_time).mean(), unit="h"
    )

    # Only coarsen: observations that are already at or below the forecast
    # resolution are passed through as-is.
    if forecast_time_delta > obs_time_delta:
        observation = observation.resample(time=forecast_time_delta).first()

    return observation
Loading

0 comments on commit 4b48b98

Please sign in to comment.