Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Plot truncation when no smoothing #405

Merged
merged 6 commits into from
Feb 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions rlberry/manager/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def plot_writer_data(
xtag : str or None, default=None
Tag of data to plot on x-axis. If None, use 'global_step'. Another often-used x-axis is
the time elapsed `dw_time_elapsed`, in which case smooth needs to be set to True or there must be only one run.
smooth : boolean, default=True
smooth : boolean, default=False
Whether to smooth the curve with a Nadaraya-Watson Kernel smoothing.
Note that this also allows for an xtag which is not synchronized across all the simulations (e.g. elapsed time).
smoothing_bandwidth: float or array of floats or None
Expand Down Expand Up @@ -530,15 +530,25 @@ def plot_synchronized_curves(
ylabel = y
assert len(data) > 0, "dataset is empty"
n_tot_simu = int(data["n_simu"].max())
# check that every simulation have the same xs

# check that every simulation has the same xs, otherwise truncate to the xs shared by all runs
processed_df = pd.DataFrame()
for name in np.unique(data["name"]):
df_name = data.loc[data["name"] == name]
x_simu_0 = df_name.loc[df_name["n_simu"] == 0, xlabel].values.astype(float)
for n_simu in range(1, int(n_tot_simu)):
x_simu = df_name.loc[df_name["n_simu"] == n_simu, xlabel].values.astype(
float
)
assert np.all(x_simu == x_simu_0)
if len(x_simu) != len(x_simu_0):
logger.warn("x axis is not the same for all the runs, truncating.")
x_simu_0 = np.intersect1d(x_simu_0, x_simu)
df_name = df_name.loc[df_name[xlabel].apply(lambda x: x in x_simu_0)]
assert (
len(df_name) > 0
), "x_axis are incompatible across runs, you should use smoothing"
processed_df = pd.concat([processed_df, df_name], ignore_index=True)
data = processed_df

ax, styles, cmap = _prepare_ax(data, ax, linestyles)

Expand Down
6 changes: 5 additions & 1 deletion rlberry/manager/tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from rlberry.manager.plotting import plot_smoothed_curves, plot_synchronized_curves
from rlberry.agents import AgentWithSimplePolicy

np.random.seed(42)


class RandomAgent(AgentWithSimplePolicy):
name = "RandomAgent"
Expand All @@ -23,7 +25,9 @@ def __init__(self, env, **kwargs):

def fit(self, budget=100, **kwargs):
observation, info = self.env.reset()
for ep in range(budget):
for ep in range(
budget + np.random.randint(5)
): # to simulate having different sizes
action = self.policy(observation)
observation, reward, done, _, _ = self.env.step(action)

Expand Down
Loading