Skip to content

Commit

Permalink
#56 complete restructure of stage plots
Browse files Browse the repository at this point in the history
  • Loading branch information
jess-breda committed Aug 8, 2024
1 parent 3ecc914 commit 72e5b82
Show file tree
Hide file tree
Showing 8 changed files with 865 additions and 87 deletions.
596 changes: 588 additions & 8 deletions notebooks/dj_exploratory_notebooks/fixationgrower_initial_plots.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions src/behav_viz/utils/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def set_legend(ax, legend):
### curriculum colors
ALPHA_V1_color = "#9C1D4F"
ALPHA_V2_color = "#1D9C6A"
ALPHA_PALLETTE = [ALPHA_V1_color, ALPHA_V2_color]

### Result column utilities ###

Expand Down
106 changes: 89 additions & 17 deletions src/behav_viz/visualize/FixationGrower/exp_compare_alpha.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,40 +8,112 @@
from behav_viz.visualize.df_preperation import compute_days_relative_to_stage


def plot_stage_relative_to_stage(
###################### STAGE PROGRESS OVER DATE ######################


def plot_ma_stage_compare_experiments(
df,
stage,
condition,
ylim=None,
ax=None,
title="",
**kwargs,
ylim=None,
rotate_x_labels=False,
relative_to_stage=None,
):

if ax is None:
fig, ax = pu.make_fig()

# make the df
viz.multianimal_plots.plot_ma_stage_by_condition(
df,
condition="fix_experiment",
ax=ax,
title=title,
palette=pu.ALPHA_PALLETTE,
hue_order=["V1", "V2"],
ylim=ylim,
rotate_x_labels=rotate_x_labels,
relative_to_stage=relative_to_stage,
)

return None

df = df[df["curriculum"].str.contains(condition, case=False)].copy()
plot_df = compute_days_relative_to_stage(df, stage).reset_index()
x_var = f"days_relative_to_stage_{stage}"

# plot
def plot_ma_stage_single_experiment(
df,
experiment,
ax=None,
title="",
ylim=None,
rotate_x_labels=False,
relative_to_stage=None,
):

if ax is None:
fig, ax = pu.make_fig()

viz.multi_animal.plot_ma_stage(
plot_df.query("stage >= @stage"),
plot_df = df[df["fix_experiment"].str.contains(experiment, case=False)].copy()
color = pu.ALPHA_V1_color if "1" in experiment else pu.ALPHA_V2_color

viz.multianimal_plots.plot_ma_stage(
plot_df,
ax=ax,
x_var=x_var,
x_var="date",
ylim=ylim,
**kwargs,
title=title,
rotate_x_labels=rotate_x_labels,
color=color,
relative_to_stage=relative_to_stage,
)

_ = ax.set(
xlabel=f"Days relative to stage {stage}",
title=f"Days in Stage for {condition} animals (N = {len(plot_df.animal_id.unique())})",
_ = ax.set(title=title)

return None


###################### STAGE DURATION ######################
def compare_plot_days_in_stage(
df,
ax=None,
min_stage=None,
max_stage=None,
):
"""TODO: first make a general one then specific"""

if ax is None:
fig, ax = pu.make_fig((6, 4))

days_in_stage_df = viz.df_preperation.make_days_in_stage_df(
df, min_stage, max_stage, hue_var="fix_experiment"
)

sns.boxplot(
data=days_in_stage_df,
x="stage",
y="n_days",
hue="fix_experiment",
hue_order=["V1", "V2"],
palette=[pu.ALPHA_V1_color, pu.ALPHA_V2_color],
ax=ax,
showfliers=False,
dodge=True,
)
sns.swarmplot(
data=days_in_stage_df,
x="stage",
y="n_days",
hue="fix_experiment",
hue_order=["V1", "V2"],
palette=[pu.ALPHA_V1_color, pu.ALPHA_V2_color],
alpha=0.5,
dodge=True,
ax=ax,
)

_ = ax.set(ylabel="N Days", xlabel="Stage")
sns.despine()

# Optionally adjust the legend
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles[0:2], labels[0:2], title="fix_experiment", frameon=False)

return None
2 changes: 1 addition & 1 deletion src/behav_viz/visualize/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from ..utils import plot_utils
from . import (
multi_animal,
multianimal_plots,
multiplots,
plot_days_info,
plot_trials_info,
Expand Down
32 changes: 27 additions & 5 deletions src/behav_viz/visualize/df_preperation.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def make_long_trial_dur_df(df: pd.DataFrame) -> pd.DataFrame:


def compute_days_relative_to_stage(
df: pd.DataFrame, stage: str, date_col_name: str = "date"
df: pd.DataFrame, stage: int, date_col_name: str = "date"
) -> pd.DataFrame:
"""
Compute the number of days relative to a specific stage in the dataframe.
Expand All @@ -135,7 +135,7 @@ def compute_days_relative_to_stage(
-------
df : pd.DataFrame
DataFrame containing the data
stage : str
stage : int
The specific stage to compute the relative days for
date_col_name : str, optional
The name of the column containing the dates, by default "date"
Expand Down Expand Up @@ -170,17 +170,39 @@ def compute_days_relative_to_stage(
return df


def make_days_in_stage_df(df, min_stage=None, max_stage=None):
""" """
def make_days_in_stage_df(df, min_stage=None, max_stage=None, hue_var=None):
"""
Compute the number of days spent in each stage, for each animal in the df
Parameters:
-----------
df : pd.DataFrame
DataFrame containing the data
min_stage : int, optional
The minimum stage value to include in the computation, by default None
max_stage : int, optional
The maximum stage value to include in the computation, by default None
hue_var : str, optional
The variable to use for grouping, by default None
Returns:
--------
pd.DataFrame
DataFrame with an additional column indicating the number of days relative to the stage
"""
# query the df for stage >= min_stage and stage <= max_stage
# if they are not None
if min_stage is not None:
df = df.query("stage >= @min_stage")
if max_stage is not None:
df = df.query("stage <= @max_stage")
if hue_var is None:
cols = ["animal_id", "stage"]
else:
cols = ["animal_id", "stage", hue_var]

days_in_stage_df = (
df.groupby(["animal_id", "stage"])
df.groupby(cols)
.agg(n_days=pd.NamedAgg(column="date", aggfunc="nunique"))
.reset_index()
)
Expand Down
55 changes: 0 additions & 55 deletions src/behav_viz/visualize/multi_animal.py

This file was deleted.

131 changes: 131 additions & 0 deletions src/behav_viz/visualize/multianimal_plots.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import seaborn as sns
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from behav_viz.utils import plot_utils as pu
import behav_viz.visualize as viz

###################### STAGE ######################


def plot_ma_stage(
df,
ax=None,
x_var="date",
title="",
ylim=None,
rotate_x_labels=False,
relative_to_stage=None,
**kwargs
):

if ax is None:
fig, ax = pu.make_fig()

# plot each animal as a gray line
for _, sub_df in df.groupby("animal_id"):
viz.plots.plot_stage(
sub_df,
x_var=x_var,
ax=ax,
alpha=0.5,
color="gray",
relative_to_stage=relative_to_stage,
)

# plot the mean of the animals
viz.plots.plot_stage(
df,
x_var=x_var,
ax=ax,
rotate_x_labels=rotate_x_labels,
ylim=ylim,
relative_to_stage=relative_to_stage,
**kwargs
)


def plot_ma_stage_by_condition(
df,
condition,
ax=None,
x_var="date",
title="",
palette="husl",
ylim=None,
rotate_x_labels=False,
relative_to_stage=None,
**kwargs
):

if ax is None:
fig, ax = pu.make_fig()

# hacky way of plot multi animals with respective colors
pal = sns.color_palette(palette, len(df[condition].unique()))
for ii, (cond, sub_df) in enumerate(df.groupby([condition])):

color = pal[ii]

for _, sub_sub_df in sub_df.groupby("animal_id"):
viz.plots.plot_stage(
sub_sub_df,
x_var=x_var,
ax=ax,
alpha=0.5,
color=color,
relative_to_stage=relative_to_stage,
**kwargs
)

# plot the mean of the animals
viz.plots.plot_stage(
df,
hue=condition,
x_var=x_var,
ax=ax,
rotate_x_labels=rotate_x_labels,
ylim=ylim,
palette=pal,
relative_to_stage=relative_to_stage,
**kwargs
)

_ = ax.set(
ylabel="Stage",
title=title,
)

return None


def plot_ma_days_in_stage(
df, ax=None, min_stage=None, max_stage=None, plot_individuals=True, **kwargs
):

if ax is None:
fig, ax = pu.make_fig((6, 4))

days_in_stage_df = viz.df_preperation.make_days_in_stage_df(
df, min_stage, max_stage
)

sns.boxplot(
data=days_in_stage_df,
x="stage",
y="n_days",
color="white",
**kwargs,
ax=ax,
showfliers=False
)
if plot_individuals:
sns.swarmplot(
data=days_in_stage_df, x="stage", y="n_days", label="", color="gray", ax=ax
)

_ = ax.set(ylabel="N Days", xlabel="Stage")
sns.despine()

return None
Loading

0 comments on commit 72e5b82

Please sign in to comment.