Skip to content

Commit

Permalink
isort, black
Browse files Browse the repository at this point in the history
  • Loading branch information
aflaxman committed Jan 2, 2025
1 parent d7c852f commit 4dae93f
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 61 deletions.
108 changes: 60 additions & 48 deletions src/vivarium_nih_moud/components/conditions.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import numpy as np, pandas as pd

import numpy as np
import pandas as pd
from vivarium_public_health.disease import (
DiseaseModel,
DiseaseState,
SusceptibleState,
RateTransition,
SusceptibleState,
)
from vivarium_public_health.risks.data_transformations import (
get_exposure_post_processor,
Expand All @@ -21,8 +21,9 @@ def setup(self, builder):
super(DiseaseModel, self).setup(builder)

self.state_to_risk_category_map = builder.configuration[
self.state_column].state_to_risk_category_map.to_dict() # FIXME: is to_dict right? why is it needed?

self.state_column
].state_to_risk_category_map.to_dict() # FIXME: is to_dict right? why is it needed?

# FIXME: I'm not sure why the next three lines are needed; I copy-pasted them from the LTBI class BetterDiseaseModel
self.configuration_age_start = builder.configuration.population.initialization_age_min
self.configuration_age_end = builder.configuration.population.initialization_age_max
Expand All @@ -35,18 +36,19 @@ def setup(self, builder):
requires_columns=[self.state_column],
preferred_post_processor=None,
)

def get_current_exposure(self, index: pd.Index) -> pd.Series:
pop = self.population_view.get(index)
exposure = pop[self.state_column].map(self.state_to_risk_category_map)

return exposure


def moud_model():
""" Create an MOUD disease model to match this D2 diagram,
"""Create an MOUD disease model to match this D2 diagram,
where with_condition can be used as a risk exposure
https://play.d2lang.com/?script=dM5BCgIxDIXhfU-RC3iBLrxKqW3EBzPJ0KS4EO8uIjgdndll8X_h6QJFTd04VZi2ys0iPQKRdSu8OC4TB6I7_JaKSoVDJRCpJG-cfWbxsM3pdP7pI0EKKkvh1LL_P3yT4UOkxjPMoHLcjwsifc8EgSP7YMdyb9xqrxlTb3wENxNXZb0UNvuoZ3gFAAD__w%3D%3D&
opioid_use_disorders: {
susceptible
with_condition
Expand All @@ -60,88 +62,98 @@ def moud_model():
}
"""

cause = 'oud_consistent'
cause = "oud_consistent"

# Custom data function for the not_on_treatment state prevalence
def get_off_treatment_prevalence(builder, _):
base_cause = cause # Use the outer cause variable

# Load overall prevalence and treatment ratio
prevalence = builder.data.load(f'cause.{base_cause}.prevalence')
treatment_ratio = builder.data.load(f'cause.{base_cause}.treatment_ratio')
prevalence = builder.data.load(f"cause.{base_cause}.prevalence")
treatment_ratio = builder.data.load(f"cause.{base_cause}.treatment_ratio")

# Calculate on_treatment prevalence
index_cols = ['sex', 'age_start', 'age_end', 'year_start', 'year_end']
off_treatment_prevalence = prevalence.set_index(index_cols) * (1-treatment_ratio.set_index(index_cols))
index_cols = ["sex", "age_start", "age_end", "year_start", "year_end"]
off_treatment_prevalence = prevalence.set_index(index_cols) * (
1 - treatment_ratio.set_index(index_cols)
)
return off_treatment_prevalence.reset_index()

# Custom data function for the on_treatment state prevalence
def get_on_treatment_prevalence(builder, state_id): # Changed parameter name from 'cause' to 'state_id'
def get_on_treatment_prevalence(
builder, state_id
): # Changed parameter name from 'cause' to 'state_id'
# Extract the base cause name from the state_id
base_cause = cause # Use the outer cause variable

# Load overall prevalence and treatment ratio
prevalence = builder.data.load(f'cause.{base_cause}.prevalence')
treatment_ratio = builder.data.load(f'cause.{base_cause}.treatment_ratio')
prevalence = builder.data.load(f"cause.{base_cause}.prevalence")
treatment_ratio = builder.data.load(f"cause.{base_cause}.treatment_ratio")

# Calculate on_treatment prevalence
index_cols = ['sex', 'age_start', 'age_end', 'year_start', 'year_end']
on_treatment_prevalence = prevalence.set_index(index_cols) * treatment_ratio.set_index(index_cols)
index_cols = ["sex", "age_start", "age_end", "year_start", "year_end"]
on_treatment_prevalence = prevalence.set_index(
index_cols
) * treatment_ratio.set_index(index_cols)
return on_treatment_prevalence.reset_index()

def get_zero(builder, state):
return 0.0

susceptible = SusceptibleState(cause, allow_self_transition=True)

with_condition = DiseaseState(
cause,
allow_self_transition=True,
get_data_functions={'prevalence': get_off_treatment_prevalence,
}
get_data_functions={
"prevalence": get_off_treatment_prevalence,
},
)
with_condition.has_excess_mortality = True


# Create on_treatment state with custom prevalence data function
# Create on_treatment state with custom prevalence data function
on_treatment = DiseaseState(
f"on_treatment_for_{cause}",
f"on_treatment_for_{cause}",
allow_self_transition=True,
get_data_functions={'prevalence': get_on_treatment_prevalence,
'disability_weight': get_zero,
'excess_mortality_rate': get_zero,
}
get_data_functions={
"prevalence": get_on_treatment_prevalence,
"disability_weight": get_zero,
"excess_mortality_rate": get_zero,
},
)
on_treatment.has_excess_mortality = False

# Add transitions
susceptible.add_rate_transition(with_condition)
with_condition.add_rate_transition(susceptible)

with_condition.add_rate_transition(
on_treatment,
get_data_functions={
'transition_rate': lambda builder, state_1, state_2: builder.data.load(
f'cause.oud_consistent.treatment_initiation_rate')
}
"transition_rate": lambda builder, state_1, state_2: builder.data.load(
f"cause.oud_consistent.treatment_initiation_rate"
)
},
)

on_treatment.add_rate_transition(
with_condition,
get_data_functions={
'transition_rate': lambda builder, state_1, state_2: builder.data.load(
f'cause.oud_consistent.treatment_failure_rate')
}
"transition_rate": lambda builder, state_1, state_2: builder.data.load(
f"cause.oud_consistent.treatment_failure_rate"
)
},
)

on_treatment.add_rate_transition(
susceptible,
get_data_functions={
'transition_rate': lambda builder, state_1, state_2: builder.data.load(
f'cause.oud_consistent.treatment_success_rate')
}
"transition_rate": lambda builder, state_1, state_2: builder.data.load(
f"cause.oud_consistent.treatment_success_rate"
)
},
)

return RiskDiseaseModel(cause,
initial_state=susceptible,
states=[susceptible, with_condition, on_treatment])
return RiskDiseaseModel(
cause, initial_state=susceptible, states=[susceptible, with_condition, on_treatment]
)
2 changes: 0 additions & 2 deletions src/vivarium_nih_moud/data/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,5 +128,3 @@ def write_data_by_draw(artifact: Artifact, key: str, data: pd.DataFrame):
data = data.reset_index(drop=True)
for c in data.columns:
store.put(f"{key.path}/{c}", data[c])


32 changes: 22 additions & 10 deletions src/vivarium_nih_moud/data/dismod_at.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,12 +199,20 @@ def single_location_model(
)

p = at_param_w_data(
f"p_{group}", ages, years, knot_val_dict["p"], df_data[df_data.measure == "p"],
f"p_{group}",
ages,
years,
knot_val_dict["p"],
df_data[df_data.measure == "p"],
method="constant",
)

tx = at_param_w_data(
f"tx_{group}", ages, years, knot_val_dict["tx"], df_data[df_data.measure == "tx"],
f"tx_{group}",
ages,
years,
knot_val_dict["tx"],
df_data[df_data.measure == "tx"],
method="constant",
)

Expand All @@ -217,10 +225,11 @@ def ode_model(group, p, tx, i, r, ti, ts, tf, f, m, sigma, ages, years):
def dismod_f(t, y, args):
S, C, T = y
i, r, ti, ts, tf, f, m = args
return (0 - m*S - i*S + r*C + ts*T ,
0 - m*C - f*C + i*S - r*C - ti*C + tf*T,
0 - m*T + ti*C - ts*T - tf*T,
)
return (
0 - m * S - i * S + r * C + ts * T,
0 - m * C - f * C + i * S - r * C - ti * C + tf * T,
0 - m * T + ti * C - ts * T - tf * T,
)

def ode_consistency_factor(at):
a, t = at
Expand All @@ -242,7 +251,7 @@ def ode_consistency_factor(at):
)

S, C, T = solution.ys
difference = jnp.log((C+T) / (S + C + T)) - jnp.log(p(a + dt, t + dt))
difference = jnp.log((C + T) / (S + C + T)) - jnp.log(p(a + dt, t + dt))
difference += jnp.log(T / (T + C)) - jnp.log(tx(a + dt, t + dt))
return difference

Expand Down Expand Up @@ -353,6 +362,7 @@ def get_rate(self, param, year):
["sex", "age_start", "age_end", "year_start", "year_end"]
)


def generate_consistent_moud_rates(art, location: str, years):
"""Generates consistent rates for MOUD data.
Expand Down Expand Up @@ -445,9 +455,11 @@ def write_or_replace(art, key, data):
else:
art.write(key, data)


if __name__ == "__main__":
from vivarium import Artifact
location = 'Washington'

location = "Washington"
years = 2021
art = Artifact('washington.hdf')
generate_consistent_moud_rates(art, location, years)
art = Artifact("washington.hdf")
generate_consistent_moud_rates(art, location, years)
3 changes: 2 additions & 1 deletion src/vivarium_nih_moud/tools/make_artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@
from loguru import logger

from ..constants import data_keys, metadata
from ..data import dismod_at
from ..tools.app_logging import add_logging_sink, decode_status
from ..utilities import sanitize_location
from ..data import dismod_at


def running_from_cluster() -> bool:
import vivarium_cluster_tools as vct
Expand Down

0 comments on commit 4dae93f

Please sign in to comment.