Skip to content

Commit

Permalink
Add two new colums to oracle-output target files
Browse files Browse the repository at this point in the history
Resolves #296

This changeset adds sequence_as_of and tree_as_of columns to the
oracle output target data files.
  • Loading branch information
bsweger committed Feb 28, 2025
1 parent a945767 commit 0cc86b7
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 4 deletions.
2 changes: 2 additions & 0 deletions src/.ruff.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[format]
quote-style = "double"
20 changes: 16 additions & 4 deletions src/get_target_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,18 @@
# ///

import json
from pathlib import Path
import logging
import sys
from datetime import date, datetime, timedelta, timezone
from pathlib import Path

import click
import polars as pl
import pyarrow as pa # type: ignore
import pyarrow.dataset as ds # type: ignore
import pyarrow.parquet as pq # type: ignore
from click.testing import CliRunner
from click import Context, Option
from click.testing import CliRunner

from cladetime import Clade, CladeTime, sequence # type: ignore

Expand Down Expand Up @@ -360,7 +360,13 @@ def create_target_data(
pl.col("target_date")
>= datetime.fromisoformat(nowcast_string) - timedelta(days=31)
)
.with_columns(pl.lit(nowcast_string).alias("nowcast_date"))
.with_columns(
pl.lit(nowcast_string).alias("nowcast_date"),
pl.lit(sequence_as_of_string).alias("sequence_as_of"),
pl.lit(assignments.meta["tree_as_of"].strftime("%Y-%m-%d")).alias(
"tree_as_of"
),
)
.rename({"observation": "oracle_value"})
)

Expand Down Expand Up @@ -429,6 +435,8 @@ def write_target_data(
("clade", pa.string()),
("oracle_value", pa.float64()),
("nowcast_date", pa.date32()),
("sequence_as_of", pa.date32()),
("tree_as_of", pa.date32()),
]
)
oracle_arrow = oracle_arrow.cast(oracle_schema)
Expand Down Expand Up @@ -582,7 +590,7 @@ def test_target_data():

oracle = oracle.collect()
expected_oracle_cols = set(
["nowcast_date", "location", "target_date", "clade", "oracle_value"]
["nowcast_date", "location", "target_date", "clade", "oracle_value", "sequence_as_of", "tree_as_of"]
)
assert set(oracle.columns) == expected_oracle_cols
assert oracle.height == ts.height
Expand Down Expand Up @@ -689,6 +697,8 @@ def test_target_data_integration(caplog, tmp_path):
assert oracle_schema_dict.get("target_date") is date
assert oracle_schema_dict.get("clade") is str
assert oracle_schema_dict.get("oracle_value") is float
assert oracle_schema_dict.get("sequence_as_of") is date
assert oracle_schema_dict.get("tree_as_of") is date

# check data types when reading target data with Arrow
ts_arrow = ds.dataset(str(ts_path), format="parquet")
Expand All @@ -708,3 +718,5 @@ def test_target_data_integration(caplog, tmp_path):
assert oracle_schema.field("clade").type == pa.string()
assert oracle_schema.field("oracle_value").type == pa.float64()
assert oracle_schema.field("target_date").type == pa.date32()
assert oracle_schema.field("sequence_as_of").type == pa.date32()
assert oracle_schema.field("tree_as_of").type == pa.date32()

0 comments on commit 0cc86b7

Please sign in to comment.