Skip to content

Commit

Permalink
Update tests for data type changes
Browse files Browse the repository at this point in the history
  • Loading branch information
bsweger committed Jan 14, 2025
1 parent e0e1af1 commit e820c94
Showing 1 changed file with 23 additions and 14 deletions.
37 changes: 23 additions & 14 deletions src/get_target_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,17 +629,19 @@ def test_target_data_integration(caplog, tmp_path):
assert len(modeled_clades) == len(ts_clades)
assert set(ts_clades) == (set(modeled_clades))

assert ts.get_column("tree_as_of").unique().to_list() == ["2024-09-09"]
assert ts.get_column("tree_as_of").unique().to_list() == [
datetime(2024, 9, 9).date()
]

# check time series column data types
ts_schema_dict = ts.schema.to_python()
assert ts_schema_dict.get("location") is str
assert ts_schema_dict.get("target_date") is date
assert ts_schema_dict.get("clade") is str
assert ts_schema_dict.get("observation") is int
assert ts_schema_dict.get("nowcast_date") is str
assert ts_schema_dict.get("sequence_as_of") is str
assert ts_schema_dict.get("tree_as_of") is str
assert ts_schema_dict.get("observation") is float
assert ts_schema_dict.get("nowcast_date") is date
assert ts_schema_dict.get("sequence_as_of") is date
assert ts_schema_dict.get("tree_as_of") is date

# time series rows should = total target dates * total locations * total clades
len(target_dates) * len(state_list) * len(modeled_clades) == ts.height
Expand All @@ -665,22 +667,29 @@ def test_target_data_integration(caplog, tmp_path):
assert len(modeled_clades) == len(oracle_clades)
assert set(oracle_clades) == (set(modeled_clades))

# check oracle column data types
# check oracle column data types on Polars dataframe
oracle_schema_dict = oracle.schema.to_python()
assert oracle_schema_dict.get("nowcast_date") is str
assert oracle_schema_dict.get("nowcast_date") is date
assert oracle_schema_dict.get("location") is str
assert oracle_schema_dict.get("target_date") is date
assert oracle_schema_dict.get("clade") is str
assert oracle_schema_dict.get("oracle_value") is int
assert oracle_schema_dict.get("oracle_value") is float

# string columns also used as hive partition keys should have a datatype of
# string (instead of large_string) when read by Arrow. otherwise, Hubverse
# tools will throw a schema mismatch error when reading the target data
# check data types when reading target data with Arrow
ts_arrow = ds.dataset(str(ts_path), format="parquet")
ts_schema = ts_arrow.schema
assert ts_schema.field("nowcast_date").type == pa.string()
assert ts_schema.field("sequence_as_of").type == pa.string()
assert ts_schema.field("nowcast_date").type == pa.date32()
assert ts_schema.field("location").type == pa.string()
assert ts_schema.field("clade").type == pa.string()
assert ts_schema.field("observation").type == pa.float64()
assert ts_schema.field("target_date").type == pa.date32()
assert ts_schema.field("sequence_as_of").type == pa.date32()
assert ts_schema.field("tree_as_of").type == pa.date32()

oracle_arrow = ds.dataset(str(oracle_path), format="parquet")
oracle_schema = oracle_arrow.schema
assert oracle_schema.field("nowcast_date").type == pa.string()
assert oracle_schema.field("nowcast_date").type == pa.date32()
assert oracle_schema.field("location").type == pa.string()
assert oracle_schema.field("clade").type == pa.string()
assert oracle_schema.field("oracle_value").type == pa.float64()
assert oracle_schema.field("target_date").type == pa.date32()

0 comments on commit e820c94

Please sign in to comment.