Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/consistent oud treatment #2

Open
wants to merge 23 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
cbe8740
WIP: use numpyro implementation of dismod at to derive a consistent r…
aflaxman Oct 9, 2024
17a9c3a
WIP: enhance code to the point where it fits something in a notebook
aflaxman Oct 31, 2024
acadb99
ENH: great results for consistent remission
aflaxman Nov 5, 2024
a00b0a9
ENH: add consistent rate generation to artifact building code
aflaxman Nov 6, 2024
7670a69
add requirements in the right place
aflaxman Nov 7, 2024
ac0b900
isort, black
aflaxman Nov 7, 2024
1d2c26c
ENH: light tweaks to make model run with consistent rates
aflaxman Nov 13, 2024
36f456d
isort, black
aflaxman Nov 22, 2024
5c6e8a2
oops
aflaxman Nov 22, 2024
0aeca14
BUG: make consistent data include 1000 draws
aflaxman Nov 24, 2024
e92484d
Updates to run locally
aflaxman Jan 2, 2025
8ab6169
Refactor to facilitate local development of consistent rate code --- …
aflaxman Jan 2, 2025
a5a40ac
seemingly running code for consistent model with treatment compartment
aflaxman Jan 2, 2025
112bf2c
WIP: debugging moud model
aflaxman Jan 2, 2025
2b6dfd4
moud model working (still need to debug RiskCauseModel)
aflaxman Jan 2, 2025
745d7f0
Observers are not working, but RiskDiseaseModel seems to be
aflaxman Jan 2, 2025
d7c852f
I figured out how to get the observers to work! It due to my misunde…
aflaxman Jan 2, 2025
4dae93f
isort, black
aflaxman Jan 2, 2025
ae15623
debugging attempt
aflaxman Jan 4, 2025
e8f698b
isort, black
aflaxman Jan 4, 2025
f6a353c
debugging attempt --- add more information to the consistent rates, s…
aflaxman Jan 4, 2025
c7c6fa5
try to get mcmc convergence from less precise made-up data
aflaxman Jan 4, 2025
3d89d2b
this version does not seem to have the convergence problems, but i do…
aflaxman Jan 4, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
*.hdf
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
Expand Down Expand Up @@ -128,3 +129,4 @@ dmypy.json

# Version file
src/*/_version.py
*~
6 changes: 5 additions & 1 deletion artifact_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,8 @@ vivarium_cluster_tools>=2.0.0
black==22.3.0
isort
jupyterlab
matplotlib
matplotlib
jax
numpyro
diffrax
interpax
Empty file removed isort
Empty file.
8 changes: 7 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,13 @@

setup_requires = ["setuptools_scm"]

data_requirements = ["vivarium_inputs[data]>=5.0.7"]
data_requirements = [
"vivarium_inputs[data]>=5.0.7",
"jax",
"numpyro",
"diffrax",
"interpax",
]
cluster_requirements = ["vivarium_cluster_tools>=2.0.3"]
test_requirements = ["pytest"]
lint_requirements = ["black", "isort"]
Expand Down
159 changes: 159 additions & 0 deletions src/vivarium_nih_moud/components/conditions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
import numpy as np
import pandas as pd
from vivarium_public_health.disease import (
DiseaseModel,
DiseaseState,
RateTransition,
SusceptibleState,
)
from vivarium_public_health.risks.data_transformations import (
get_exposure_post_processor,
)
from vivarium_public_health.utilities import EntityString


class RiskDiseaseModel(DiseaseModel):
@property
def name(self):
return f"disease_model.{self.cause}"

def setup(self, builder):
super(DiseaseModel, self).setup(builder)

self.state_to_risk_category_map = builder.configuration[
self.state_column
].state_to_risk_category_map.to_dict() # FIXME: is to_dict right? why is it needed?

# FIXME: I'm not sure why the next three lines are needed; I copy-pasted them from the LTBI class BetterDiseaseModel
self.configuration_age_start = builder.configuration.population.initialization_age_min
self.configuration_age_end = builder.configuration.population.initialization_age_max
self.randomness = builder.randomness.get_stream(f"{self.state_column}_initial_states")

# create a pipeline for a risk exposure based on disease model state
self.exposure = builder.value.register_value_producer(
f"{self.state_column}.exposure",
source=self.get_current_exposure,
requires_columns=[self.state_column],
preferred_post_processor=None,
)

def get_current_exposure(self, index: pd.Index) -> pd.Series:
pop = self.population_view.get(index)
exposure = pop[self.state_column].map(self.state_to_risk_category_map)

return exposure


def moud_model():
"""Create an MOUD disease model to match this D2 diagram,
where with_condition can be used as a risk exposure
https://play.d2lang.com/?script=dM5BCgIxDIXhfU-RC3iBLrxKqW3EBzPJ0KS4EO8uIjgdndll8X_h6QJFTd04VZi2ys0iPQKRdSu8OC4TB6I7_JaKSoVDJRCpJG-cfWbxsM3pdP7pI0EKKkvh1LL_P3yT4UOkxjPMoHLcjwsifc8EgSP7YMdyb9xqrxlTb3wENxNXZb0UNvuoZ3gFAAD__w%3D%3D&

opioid_use_disorders: {
susceptible
with_condition
on_treatment

susceptible -> with_condition: incidence_rate
with_condition -> susceptible: remission_rate
with_condition -> on_treatment: treatment_initiation_rate
on_treatment -> with_condition: treatment_failure_rate
on_treatment -> susceptible: treatment_success_rate
}
"""

cause = "oud_consistent"

# Custom data function for the not_on_treatment state prevalence
def get_off_treatment_prevalence(builder, _):
base_cause = cause # Use the outer cause variable

# Load overall prevalence and treatment ratio
prevalence = builder.data.load(f"cause.{base_cause}.prevalence")
treatment_ratio = builder.data.load(f"cause.{base_cause}.treatment_ratio")

# Calculate on_treatment prevalence
index_cols = ["sex", "age_start", "age_end", "year_start", "year_end"]
off_treatment_prevalence = prevalence.set_index(index_cols) * (
1 - treatment_ratio.set_index(index_cols)
)
return off_treatment_prevalence.reset_index()

# Custom data function for the on_treatment state prevalence
def get_on_treatment_prevalence(
builder, state_id
): # Changed parameter name from 'cause' to 'state_id'
# Extract the base cause name from the state_id
base_cause = cause # Use the outer cause variable

# Load overall prevalence and treatment ratio
prevalence = builder.data.load(f"cause.{base_cause}.prevalence")
treatment_ratio = builder.data.load(f"cause.{base_cause}.treatment_ratio")

# Calculate on_treatment prevalence
index_cols = ["sex", "age_start", "age_end", "year_start", "year_end"]
on_treatment_prevalence = prevalence.set_index(
index_cols
) * treatment_ratio.set_index(index_cols)
return on_treatment_prevalence.reset_index()

def get_zero(builder, state):
return 0.0

susceptible = SusceptibleState(cause, allow_self_transition=True)

with_condition = DiseaseState(
cause,
allow_self_transition=True,
get_data_functions={
"prevalence": get_off_treatment_prevalence,
},
)
with_condition.has_excess_mortality = True

# Create on_treatment state with custom prevalence data function
on_treatment = DiseaseState(
f"on_treatment_for_{cause}",
allow_self_transition=True,
get_data_functions={
"prevalence": get_on_treatment_prevalence,
"disability_weight": get_zero,
"excess_mortality_rate": get_zero,
},
)
on_treatment.has_excess_mortality = False

# Add transitions
susceptible.add_rate_transition(with_condition)
with_condition.add_rate_transition(susceptible)

with_condition.add_rate_transition(
on_treatment,
get_data_functions={
"transition_rate": lambda builder, state_1, state_2: builder.data.load(
f"cause.oud_consistent.treatment_initiation_rate"
)
},
)

on_treatment.add_rate_transition(
with_condition,
get_data_functions={
"transition_rate": lambda builder, state_1, state_2: builder.data.load(
f"cause.oud_consistent.treatment_failure_rate"
)
},
)

on_treatment.add_rate_transition(
susceptible,
get_data_functions={
"transition_rate": lambda builder, state_1, state_2: builder.data.load(
f"cause.oud_consistent.treatment_success_rate"
)
},
)

return RiskDiseaseModel(
cause, initial_state=susceptible, states=[susceptible, with_condition, on_treatment]
)
5 changes: 3 additions & 2 deletions src/vivarium_nih_moud/data/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@
from pathlib import Path
from typing import Optional

import numpy as np
import pandas as pd
from loguru import logger
from vivarium.framework.artifact import Artifact, EntityKey

from vivarium_nih_moud.constants import data_keys
from vivarium_nih_moud.data import loader
from ..constants import data_keys
from ..data import loader


def open_artifact(output_path: Path, location: str) -> Artifact:
Expand Down
Loading
Loading