From 4dae93f1941460b42a2877a14cf1bbca105be2b0 Mon Sep 17 00:00:00 2001 From: Abraham Flaxman Date: Thu, 2 Jan 2025 12:31:18 -0800 Subject: [PATCH] isort, black --- .../components/conditions.py | 108 ++++++++++-------- src/vivarium_nih_moud/data/builder.py | 2 - src/vivarium_nih_moud/data/dismod_at.py | 32 ++++-- src/vivarium_nih_moud/tools/make_artifacts.py | 3 +- 4 files changed, 84 insertions(+), 61 deletions(-) diff --git a/src/vivarium_nih_moud/components/conditions.py b/src/vivarium_nih_moud/components/conditions.py index db204d4..33bfcc4 100644 --- a/src/vivarium_nih_moud/components/conditions.py +++ b/src/vivarium_nih_moud/components/conditions.py @@ -1,10 +1,10 @@ -import numpy as np, pandas as pd - +import numpy as np +import pandas as pd from vivarium_public_health.disease import ( DiseaseModel, DiseaseState, - SusceptibleState, RateTransition, + SusceptibleState, ) from vivarium_public_health.risks.data_transformations import ( get_exposure_post_processor, @@ -21,8 +21,9 @@ def setup(self, builder): super(DiseaseModel, self).setup(builder) self.state_to_risk_category_map = builder.configuration[ - self.state_column].state_to_risk_category_map.to_dict() # FIXME: is to_dict right? why is it needed? - + self.state_column + ].state_to_risk_category_map.to_dict() # FIXME: is to_dict right? why is it needed? + # FIXME: I'm not sure why the next three lines are needed; I copy-pasted them from the LTBI class BetterDiseaseModel self.configuration_age_start = builder.configuration.population.initialization_age_min self.configuration_age_end = builder.configuration.population.initialization_age_max @@ -35,18 +36,19 @@ def setup(self, builder): requires_columns=[self.state_column], preferred_post_processor=None, ) - + def get_current_exposure(self, index: pd.Index) -> pd.Series: pop = self.population_view.get(index) exposure = pop[self.state_column].map(self.state_to_risk_category_map) - + return exposure + def moud_model(): - """ Create an MOUD disease model to match this D2 diagram, + """Create an MOUD disease model to match this D2 diagram, where with_condition can be used as a risk exposure https://play.d2lang.com/?script=dM5BCgIxDIXhfU-RC3iBLrxKqW3EBzPJ0KS4EO8uIjgdndll8X_h6QJFTd04VZi2ys0iPQKRdSu8OC4TB6I7_JaKSoVDJRCpJG-cfWbxsM3pdP7pI0EKKkvh1LL_P3yT4UOkxjPMoHLcjwsifc8EgSP7YMdyb9xqrxlTb3wENxNXZb0UNvuoZ3gFAAD__w%3D%3D& - + opioid_use_disorders: { susceptible with_condition @@ -60,88 +62,98 @@ def moud_model(): } """ - cause = 'oud_consistent' + cause = "oud_consistent" # Custom data function for the not_on_treatment state prevalence def get_off_treatment_prevalence(builder, _): base_cause = cause # Use the outer cause variable - + # Load overall prevalence and treatment ratio - prevalence = builder.data.load(f'cause.{base_cause}.prevalence') - treatment_ratio = builder.data.load(f'cause.{base_cause}.treatment_ratio') - + prevalence = builder.data.load(f"cause.{base_cause}.prevalence") + treatment_ratio = builder.data.load(f"cause.{base_cause}.treatment_ratio") + # Calculate on_treatment prevalence - index_cols = ['sex', 'age_start', 'age_end', 'year_start', 'year_end'] - off_treatment_prevalence = prevalence.set_index(index_cols) * (1-treatment_ratio.set_index(index_cols)) + index_cols = ["sex", "age_start", "age_end", "year_start", "year_end"] + off_treatment_prevalence = prevalence.set_index(index_cols) * ( + 1 - treatment_ratio.set_index(index_cols) + ) return off_treatment_prevalence.reset_index() # Custom data function for the on_treatment state prevalence - def get_on_treatment_prevalence(builder, state_id): # Changed parameter name from 'cause' to 'state_id' + def get_on_treatment_prevalence( + builder, state_id + ): # Changed parameter name from 'cause' to 'state_id' # Extract the base cause name from the state_id base_cause = cause # Use the outer cause variable - + # Load overall prevalence and treatment ratio - prevalence = builder.data.load(f'cause.{base_cause}.prevalence') - treatment_ratio = builder.data.load(f'cause.{base_cause}.treatment_ratio') - + prevalence = builder.data.load(f"cause.{base_cause}.prevalence") + treatment_ratio = builder.data.load(f"cause.{base_cause}.treatment_ratio") + # Calculate on_treatment prevalence - index_cols = ['sex', 'age_start', 'age_end', 'year_start', 'year_end'] - on_treatment_prevalence = prevalence.set_index(index_cols) * treatment_ratio.set_index(index_cols) + index_cols = ["sex", "age_start", "age_end", "year_start", "year_end"] + on_treatment_prevalence = prevalence.set_index( + index_cols + ) * treatment_ratio.set_index(index_cols) return on_treatment_prevalence.reset_index() def get_zero(builder, state): return 0.0 susceptible = SusceptibleState(cause, allow_self_transition=True) - + with_condition = DiseaseState( cause, allow_self_transition=True, - get_data_functions={'prevalence': get_off_treatment_prevalence, - } + get_data_functions={ + "prevalence": get_off_treatment_prevalence, + }, ) with_condition.has_excess_mortality = True - - # Create on_treatment state with custom prevalence data function + # Create on_treatment state with custom prevalence data function on_treatment = DiseaseState( - f"on_treatment_for_{cause}", + f"on_treatment_for_{cause}", allow_self_transition=True, - get_data_functions={'prevalence': get_on_treatment_prevalence, - 'disability_weight': get_zero, - 'excess_mortality_rate': get_zero, - } + get_data_functions={ + "prevalence": get_on_treatment_prevalence, + "disability_weight": get_zero, + "excess_mortality_rate": get_zero, + }, ) on_treatment.has_excess_mortality = False # Add transitions susceptible.add_rate_transition(with_condition) with_condition.add_rate_transition(susceptible) - + with_condition.add_rate_transition( on_treatment, get_data_functions={ - 'transition_rate': lambda builder, state_1, state_2: builder.data.load( - f'cause.oud_consistent.treatment_initiation_rate') - } + "transition_rate": lambda builder, state_1, state_2: builder.data.load( + f"cause.oud_consistent.treatment_initiation_rate" + ) + }, ) - + on_treatment.add_rate_transition( with_condition, get_data_functions={ - 'transition_rate': lambda builder, state_1, state_2: builder.data.load( - f'cause.oud_consistent.treatment_failure_rate') - } + "transition_rate": lambda builder, state_1, state_2: builder.data.load( + f"cause.oud_consistent.treatment_failure_rate" + ) + }, ) - + on_treatment.add_rate_transition( susceptible, get_data_functions={ - 'transition_rate': lambda builder, state_1, state_2: builder.data.load( - f'cause.oud_consistent.treatment_success_rate') - } + "transition_rate": lambda builder, state_1, state_2: builder.data.load( + f"cause.oud_consistent.treatment_success_rate" + ) + }, ) - return RiskDiseaseModel(cause, - initial_state=susceptible, - states=[susceptible, with_condition, on_treatment]) \ No newline at end of file + return RiskDiseaseModel( + cause, initial_state=susceptible, states=[susceptible, with_condition, on_treatment] + ) diff --git a/src/vivarium_nih_moud/data/builder.py b/src/vivarium_nih_moud/data/builder.py index 879e6e1..2f580ce 100644 --- a/src/vivarium_nih_moud/data/builder.py +++ b/src/vivarium_nih_moud/data/builder.py @@ -128,5 +128,3 @@ def write_data_by_draw(artifact: Artifact, key: str, data: pd.DataFrame): data = data.reset_index(drop=True) for c in data.columns: store.put(f"{key.path}/{c}", data[c]) - - diff --git a/src/vivarium_nih_moud/data/dismod_at.py b/src/vivarium_nih_moud/data/dismod_at.py index 179100b..6b7946d 100644 --- a/src/vivarium_nih_moud/data/dismod_at.py +++ b/src/vivarium_nih_moud/data/dismod_at.py @@ -199,12 +199,20 @@ def single_location_model( ) p = at_param_w_data( - f"p_{group}", ages, years, knot_val_dict["p"], df_data[df_data.measure == "p"], + f"p_{group}", + ages, + years, + knot_val_dict["p"], + df_data[df_data.measure == "p"], method="constant", ) tx = at_param_w_data( - f"tx_{group}", ages, years, knot_val_dict["tx"], df_data[df_data.measure == "tx"], + f"tx_{group}", + ages, + years, + knot_val_dict["tx"], + df_data[df_data.measure == "tx"], method="constant", ) @@ -217,10 +225,11 @@ def ode_model(group, p, tx, i, r, ti, ts, tf, f, m, sigma, ages, years): def dismod_f(t, y, args): S, C, T = y i, r, ti, ts, tf, f, m = args - return (0 - m*S - i*S + r*C + ts*T , - 0 - m*C - f*C + i*S - r*C - ti*C + tf*T, - 0 - m*T + ti*C - ts*T - tf*T, - ) + return ( + 0 - m * S - i * S + r * C + ts * T, + 0 - m * C - f * C + i * S - r * C - ti * C + tf * T, + 0 - m * T + ti * C - ts * T - tf * T, + ) def ode_consistency_factor(at): a, t = at @@ -242,7 +251,7 @@ def ode_consistency_factor(at): ) S, C, T = solution.ys - difference = jnp.log((C+T) / (S + C + T)) - jnp.log(p(a + dt, t + dt)) + difference = jnp.log((C + T) / (S + C + T)) - jnp.log(p(a + dt, t + dt)) difference += jnp.log(T / (T + C)) - jnp.log(tx(a + dt, t + dt)) return difference @@ -353,6 +362,7 @@ def get_rate(self, param, year): ["sex", "age_start", "age_end", "year_start", "year_end"] ) + def generate_consistent_moud_rates(art, location: str, years): """Generates consistent rates for MOUD data. @@ -445,9 +455,11 @@ def write_or_replace(art, key, data): else: art.write(key, data) + if __name__ == "__main__": from vivarium import Artifact - location = 'Washington' + + location = "Washington" years = 2021 - art = Artifact('washington.hdf') - generate_consistent_moud_rates(art, location, years) \ No newline at end of file + art = Artifact("washington.hdf") + generate_consistent_moud_rates(art, location, years) diff --git a/src/vivarium_nih_moud/tools/make_artifacts.py b/src/vivarium_nih_moud/tools/make_artifacts.py index 298ad0d..7ca957c 100644 --- a/src/vivarium_nih_moud/tools/make_artifacts.py +++ b/src/vivarium_nih_moud/tools/make_artifacts.py @@ -16,9 +16,10 @@ from loguru import logger from ..constants import data_keys, metadata +from ..data import dismod_at from ..tools.app_logging import add_logging_sink, decode_status from ..utilities import sanitize_location -from ..data import dismod_at + def running_from_cluster() -> bool: import vivarium_cluster_tools as vct