Skip to content

Commit

Permalink
isort, black
Browse files Browse the repository at this point in the history
  • Loading branch information
aflaxman committed Jan 5, 2025
1 parent 7b95f0d commit 6d2067f
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 79 deletions.
21 changes: 11 additions & 10 deletions src/vivarium_nih_moud/components/locations.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,31 @@
from typing import Dict, Callable
from typing import Callable, Dict

import numpy as np
import pandas as pd

from vivarium.framework.engine import Builder
from vivarium_public_health.disease.state import BaseDiseaseState

from .conditions import RiskDiseaseModel


class HousingState(BaseDiseaseState):
def __init__(self, state_id: str, allow_self_transition: bool = True):
super().__init__(state_id, allow_self_transition)

def add_rate_transition(self, output_state):
rate = f'quarters.{self.state_id}_to_{output_state.state_id}.transition_rate' # key for this transition rate in the artifact
get_data_functions={
'transition_rate': lambda builder, i, o: builder.data.load(rate)
rate = f"quarters.{self.state_id}_to_{output_state.state_id}.transition_rate" # key for this transition rate in the artifact
get_data_functions = {
"transition_rate": lambda builder, i, o: builder.data.load(rate)
}
return super().add_rate_transition(output_state, get_data_functions)


def quarters_model():
cause = 'quarters'
cause = "quarters"

housed = HousingState('housed')
unhoused = HousingState('unhoused')
incarcerated = HousingState('incarcerated')
housed = HousingState("housed")
unhoused = HousingState("unhoused")
incarcerated = HousingState("incarcerated")

housed.add_rate_transition(unhoused)
housed.add_rate_transition(incarcerated)
Expand All @@ -39,4 +40,4 @@ def quarters_model():
cause,
initial_state=housed,
states=[housed, unhoused, incarcerated],
)
)
13 changes: 6 additions & 7 deletions src/vivarium_nih_moud/data/dismod_at.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
import numpy as np
import numpyro
import pandas as pd
import utils
from diffrax import Dopri5, ODETerm, SaveAt, diffeqsolve
from numpyro import distributions as dist
from numpyro import infer

import utils
from utils import write_or_replace


def transform_to_data(param, df_in, sex, ages, years):
"""Convert artifact data to a format suitable for DisMod-AT-NumPyro."""
t = df_in.loc[sex]
Expand Down Expand Up @@ -444,10 +444,10 @@ def oud_data(sex):
ages,
[2021],
),
transform_to_data("ti", utils.generate_constant_data(0.0), sex, ages, [2021]),
transform_to_data("ts", utils.generate_constant_data(0.0), sex, ages, [2021]),
# transform_to_data("tf", utils.generate_constant_data(1.0), sex, ages, [2021]),
# transform_to_data("tx", utils.generate_constant_data(0.0), sex, ages, [2021]),
transform_to_data("ti", utils.generate_constant_data(0.0), sex, ages, [2021]),
transform_to_data("ts", utils.generate_constant_data(0.0), sex, ages, [2021]),
# transform_to_data("tf", utils.generate_constant_data(1.0), sex, ages, [2021]),
# transform_to_data("tx", utils.generate_constant_data(0.0), sex, ages, [2021]),
]
)
return df_data
Expand Down Expand Up @@ -480,7 +480,6 @@ def get_rates(model_dict, rate_type, year):
write_or_replace(art, rate_name, df_out)



if __name__ == "__main__":
from vivarium import Artifact

Expand Down
19 changes: 10 additions & 9 deletions src/vivarium_nih_moud/data/quarters_data.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import numpy as np, pandas as pd

import numpy as np
import pandas as pd
import utils


def generate_quarters_data(art, location: str, years):
# copy metadata
for key in [
Expand All @@ -12,31 +13,31 @@ def generate_quarters_data(art, location: str, years):
utils.write_or_replace(art, key.replace("opioid_use_disorders", "quarters"), data)

# add stand-in data for rates
key = 'cause.quarters.cause_specific_mortality_rate'
key = "cause.quarters.cause_specific_mortality_rate"
data = utils.generate_constant_data(0.0)
utils.write_or_replace(art, key, data)

key = 'quarters.housed_to_unhoused.transition_rate'
key = "quarters.housed_to_unhoused.transition_rate"
data = utils.generate_constant_data(0.05)
utils.write_or_replace(art, key, data)

key = 'quarters.housed_to_incarcerated.transition_rate'
key = "quarters.housed_to_incarcerated.transition_rate"
data = utils.generate_constant_data(0.05)
utils.write_or_replace(art, key, data)

key = 'quarters.unhoused_to_housed.transition_rate'
key = "quarters.unhoused_to_housed.transition_rate"
data = utils.generate_constant_data(0.50)
utils.write_or_replace(art, key, data)

key = 'quarters.unhoused_to_incarcerated.transition_rate'
key = "quarters.unhoused_to_incarcerated.transition_rate"
data = utils.generate_constant_data(0.50)
utils.write_or_replace(art, key, data)

key = 'quarters.incarcerated_to_housed.transition_rate'
key = "quarters.incarcerated_to_housed.transition_rate"
data = utils.generate_constant_data(0.50)
utils.write_or_replace(art, key, data)

key = 'quarters.incarcerated_to_unhoused.transition_rate'
key = "quarters.incarcerated_to_unhoused.transition_rate"
data = utils.generate_constant_data(0.50)
utils.write_or_replace(art, key, data)

Expand Down
29 changes: 18 additions & 11 deletions src/vivarium_nih_moud/data/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import numpy as np, pandas as pd
import numpy as np
import pandas as pd


def write_or_replace(art, key, data):
if key in art.keys:
Expand All @@ -7,22 +9,27 @@ def write_or_replace(art, key, data):
else:
art.write(key, data)


def generate_constant_data(data_value):
ages = [20]
years = [2021]
sexes = ['Male', 'Female']
sexes = ["Male", "Female"]
data = []
for age in ages:
for year in years:
for sex in sexes:
data.append({
'age_start': age,
'age_end': age+5,
'year_start': year,
'year_end': year+1,
'sex': sex,
})
data.append(
{
"age_start": age,
"age_end": age + 5,
"year_start": year,
"year_end": year + 1,
"sex": sex,
}
)
for i in range(1_000):
data[-1][f'draw_{i}'] = np.clip(data_value+np.random.uniform(0,.1), 0, 1)
data[-1][f"draw_{i}"] = np.clip(
data_value + np.random.uniform(0, 0.1), 0, 1
)
data = pd.DataFrame(data)
return data.set_index(['sex', 'age_start', 'age_end', 'year_start', 'year_end'])
return data.set_index(["sex", "age_start", "age_end", "year_start", "year_end"])
108 changes: 66 additions & 42 deletions tests/visual_test.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,63 @@
# visual_test.py
import numpy as np, matplotlib.pyplot as plt, pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import vivarium as vi


def add_age_group(pop):
"""
Add an age group column to the population DataFrame.
Parameters:
pop (pandas.DataFrame): The population DataFrame.
"""
pop['age_group'] = pd.cut(pop['age'], bins=np.arange(0,101,5), right=False)
pop["age_group"] = pd.cut(pop["age"], bins=np.arange(0, 101, 5), right=False)


def validation_plots(sim, art):
"""
Create validation plots for incidence, prevalence, and cause-specific mortality rate (CSMR) over time.
Plots males and females separately, with rows for each metric (incidence, prevalence, and CSMR).
Adds artifact values to each plot for comparison.
Parameters:
sim (vivarium.framework.interactive_context.InteractiveContext): The simulation context.
art (vivarium.framework.artifact.Artifact): The artifact object.
"""
cause = 'oud_consistent'
key_pt = f'person_time_{cause}'
key_transition = f'transition_count_{cause}'
cause = "oud_consistent"
key_pt = f"person_time_{cause}"
key_transition = f"transition_count_{cause}"

years = 12*7/365
years = 12 * 7 / 365
# import pdb; pdb.set_trace()
sim.run_for(pd.Timedelta(days=years*365))
sim.run_for(pd.Timedelta(days=years * 365))

pop = sim.get_population()
add_age_group(pop)

df_pt = pop.groupby(['age_group', 'sex']).tracked.count() * years
df_susceptible_pt = pop[pop[cause] == f'susceptible_to_{cause}'].groupby(['age_group', 'sex']).tracked.count() * years
df_prevalent_cases = pop[pop[cause] != f'susceptible_to_{cause}'].groupby(['age_group', 'sex']).tracked.count()
df_pt = pop.groupby(["age_group", "sex"]).tracked.count() * years
df_susceptible_pt = (
pop[pop[cause] == f"susceptible_to_{cause}"]
.groupby(["age_group", "sex"])
.tracked.count()
* years
)
df_prevalent_cases = (
pop[pop[cause] != f"susceptible_to_{cause}"]
.groupby(["age_group", "sex"])
.tracked.count()
)
df_with_condition_pt = df_prevalent_cases * years
df_deaths = pop[pop.cause_of_death == cause].groupby(['age_group', 'sex']).tracked.count() * years
df_incident_cases = pop.groupby(['age_group', 'sex']).susceptible_to_oud_consistent_event_count.sum()
df_remission_cases = pop.groupby(['age_group', 'sex']).susceptible_to_oud_consistent_event_count.sum()*0
df_deaths = (
pop[pop.cause_of_death == cause].groupby(["age_group", "sex"]).tracked.count() * years
)
df_incident_cases = pop.groupby(
["age_group", "sex"]
).susceptible_to_oud_consistent_event_count.sum()
df_remission_cases = (
pop.groupby(["age_group", "sex"]).susceptible_to_oud_consistent_event_count.sum() * 0
)

# Compute CSMR, Prevalence, and Incidence
df_csmr = 100_000 * df_deaths / df_pt
Expand All @@ -47,28 +66,28 @@ def validation_plots(sim, art):
df_remission = 100_000 * df_remission_cases / df_with_condition_pt
df_excess_mortality = 100_000 * df_deaths / df_with_condition_pt


# Load artifact data
df_art_csmr = art.load(f'cause.{cause}.cause_specific_mortality_rate')
df_art_prevalence = art.load(f'cause.{cause}.prevalence')
df_art_incidence = art.load(f'cause.{cause}.incidence_rate')
df_art_remission = art.load(f'cause.{cause}.remission_rate')
df_art_excess_mortality = art.load(f'cause.{cause}.excess_mortality_rate')
df_art_csmr = art.load(f"cause.{cause}.cause_specific_mortality_rate")
df_art_prevalence = art.load(f"cause.{cause}.prevalence")
df_art_incidence = art.load(f"cause.{cause}.incidence_rate")
df_art_remission = art.load(f"cause.{cause}.remission_rate")
df_art_excess_mortality = art.load(f"cause.{cause}.excess_mortality_rate")

# Set up the plotting grid: two columns (Males, Females), three rows (CSMR, Prevalence, Incidence)
fig, axes = plt.subplots(5, 2, figsize=(14, 9), sharex=True)
fig.suptitle(f'Validation Plots for {cause.capitalize()}', fontsize=16)
metrics = [('Cause Specific Mortality Rate', df_csmr, df_art_csmr),
('Excess Mortality', df_excess_mortality, df_art_excess_mortality),
('Prevalence', df_prevalence, df_art_prevalence),
('Incidence', df_incidence, df_art_incidence),
('Remission', df_remission, df_art_remission),
]
sexes = ['Male', 'Female']
fig.suptitle(f"Validation Plots for {cause.capitalize()}", fontsize=16)
metrics = [
("Cause Specific Mortality Rate", df_csmr, df_art_csmr),
("Excess Mortality", df_excess_mortality, df_art_excess_mortality),
("Prevalence", df_prevalence, df_art_prevalence),
("Incidence", df_incidence, df_art_incidence),
("Remission", df_remission, df_art_remission),
]
sexes = ["Male", "Female"]
years = [2021]

def age_x(age_group):
return (age_group.left+age_group.right)/2
return (age_group.left + age_group.right) / 2

for row, (metric_name, data, artifact_data) in enumerate(metrics):
for col, sex in enumerate(sexes):
Expand All @@ -78,31 +97,36 @@ def age_x(age_group):
data_to_plot.index = data_to_plot.index.map(age_x).astype(float)
data_to_plot = data_to_plot.sort_index()
t1 = data_to_plot
t1[sex].plot(ax=ax, label=f'{year}', marker='o', linestyle='none')
t1[sex].plot(ax=ax, label=f"{year}", marker="o", linestyle="none")

# Plot artifact data
artifact_to_plot = artifact_data.loc[sex]*100_000
artifact_to_plot.index = artifact_to_plot.eval('.5*(age_start+age_end)').astype(float)
artifact_to_plot = artifact_data.loc[sex] * 100_000
artifact_to_plot.index = artifact_to_plot.eval(".5*(age_start+age_end)").astype(
float
)
artifact_to_plot = artifact_to_plot.sort_index()
artifact_mean = artifact_to_plot.mean(axis=1)
artifact_mean.plot(ax=ax, label=f'Artifact {sex}', color=f'k', linestyle='-', alpha=.75)
artifact_mean.plot(
ax=ax, label=f"Artifact {sex}", color=f"k", linestyle="-", alpha=0.75
)

ax.set_title(f'{metric_name} ({sex})')
ax.set_ylabel(f'{metric_name} (Per 100,000 PY)')
ax.set_title(f"{metric_name} ({sex})")
ax.set_ylabel(f"{metric_name} (Per 100,000 PY)")
ax.grid(True)
# ax.set_yscale('log')
ax.legend(loc='upper left')
# ax.set_yscale('log')
ax.legend(loc="upper left")

plt.xlabel('Age Group')
plt.xlabel("Age Group")
plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.show()

path = 'src/vivarium_nih_moud/model_specifications/model_spec.yaml'

path = "src/vivarium_nih_moud/model_specifications/model_spec.yaml"
sim = vi.InteractiveContext(path)

pop = sim.get_population()

artifact_path = sim.configuration.input_data.artifact_path
art = vi.Artifact(artifact_path)

validation_plots(sim, art)
validation_plots(sim, art)

0 comments on commit 6d2067f

Please sign in to comment.