# setting.py
""" Current most general Setting in the Reinforcement Learning side of the tree.
"""
import difflib
import json
import textwrap
import warnings
from dataclasses import dataclass, fields
from functools import partial
from pathlib import Path
from typing import Any, Callable, ClassVar, Dict, List, Optional, Type, Union
import gym
import numpy as np
from gym import spaces
from gym.envs.registration import EnvSpec, registry
from gym.utils import colorize
from gym.wrappers import TimeLimit
from simple_parsing import choice, field, list_field
from simple_parsing.helpers import dict_field
try:
from stable_baselines3.common.atari_wrappers import AtariWrapper as SB3AtariWrapper
except ImportError:
class SB3AtariWrapper:
pass
from gym.wrappers.atari_preprocessing import AtariPreprocessing as GymAtariWrapper
import wandb
from sequoia.common import Config
from sequoia.common.gym_wrappers import (
AddDoneToObservation,
MultiTaskEnvironment,
RenderEnvWrapper,
SmoothTransitions,
TransformObservation,
TransformReward,
)
from sequoia.common.gym_wrappers.action_limit import ActionLimit
from sequoia.common.gym_wrappers.convert_tensors import add_tensor_support
from sequoia.common.gym_wrappers.env_dataset import EnvDataset
from sequoia.common.gym_wrappers.episode_limit import EpisodeLimit
from sequoia.common.gym_wrappers.pixel_observation import ImageObservations
from sequoia.common.gym_wrappers.utils import is_atari_env
from sequoia.common.spaces import Sparse, TypedDictSpace
from sequoia.common.transforms import Transforms
from sequoia.settings.assumptions.continual import ContinualAssumption
from sequoia.settings.base import Method
from sequoia.settings.rl import ActiveEnvironment, RLSetting
from sequoia.settings.rl.wrappers import (
HideTaskLabelsWrapper,
MeasureRLPerformanceWrapper,
TypedObjectsWrapper,
)
from sequoia.utils import get_logger
from sequoia.utils.generic_functions import move
from sequoia.utils.utils import flag, pairwise
from .environment import GymDataLoader
from .make_env import make_batched_env
from .objects import Actions, Observations, Rewards # type: ignore
from .results import ContinualRLResults
from .tasks import ContinuousTask, TaskSchedule, is_supported, make_continuous_task, names_match
from .test_environment import ContinualRLTestEnvironment
logger = get_logger(__name__)
# Type alias for the Environment returned by `train/val/test_dataloader`.
Environment = ActiveEnvironment[
"ContinualRLSetting.Observations",
"ContinualRLSetting.Observations",
"ContinualRLSetting.Rewards",
]
# NOTE: Takes about 0.2 seconds to check for all compatible envs (with loading), and
# only happens once.
supported_envs: Dict[str, EnvSpec] = {
spec.id: spec for env_id, spec in registry.env_specs.items() if is_supported(env_id)
}
available_datasets: Dict[str, str] = {env_id: env_id for env_id in supported_envs}
# available_datasets.update(
# {camel_case(env_id.split("-v")[0]): env_id for env_id in supported_envs}
# )
@dataclass
class ContinualRLSetting(RLSetting, ContinualAssumption):
"""Reinforcement Learning Setting where the environment changes over time.
This is an Active setting which uses gym environments as sources of data.
These environments' attributes could change over time following a task
schedule. An example of this could be that the gravity increases over time
in cartpole, making the task progressively harder as the agent interacts with
the environment.
"""
# (NOTE: commenting out SLSetting.Observations as it is the same class
# as Setting.Observations, and we want a consistent method resolution order.)
Observations: ClassVar[Type[Observations]] = Observations
Actions: ClassVar[Type[Actions]] = Actions
Rewards: ClassVar[Type[Rewards]] = Rewards
# The type of results returned by an RL experiment.
Results: ClassVar[Type[Results]] = ContinualRLResults
# The type wrapper used to wrap the test environment, and which produces the
# results.
TestEnvironment: ClassVar[Type[TestEnvironment]] = ContinualRLTestEnvironment
# Dict of all available options for the 'dataset' field below.
available_datasets: ClassVar[Dict[str, Union[str, Any]]] = available_datasets
# The function used to create the tasks for the chosen env.
_task_sampling_function: ClassVar[Callable[..., ContinuousTask]] = make_continuous_task
# Which environment (a.k.a. "dataset") to learn on.
# The dataset could be either a string (env id or a key from the
# available_datasets dict), a gym.Env, or a callable that returns a
# single environment.
dataset: str = choice(available_datasets, default="CartPole-v0")
# The number of "tasks" that will be created for the training, valid and test
# environments.
# NOTE: In the case of settings with smooth task boundaries, this is the number of
# "base" tasks which are created, and the task space consists of interpolations
# between these base tasks.
# When left unset, will use a default value that makes sense
# (something like 5).
nb_tasks: int = field(5, alias=["n_tasks", "num_tasks"])
# Environment/dataset to use for training. Defaults to the same as `dataset`.
train_dataset: Optional[str] = None
# Environment/dataset to use for validation. Defaults to the same as `dataset`.
val_dataset: Optional[str] = None
# Environment/dataset to use for testing. Defaults to the same as `dataset`.
test_dataset: Optional[str] = None
# Whether the task boundaries are smooth or sudden.
smooth_task_boundaries: bool = True
# Whether the tasks are sampled uniformly. (This is set to True in MultiTaskRLSetting
# and below)
stationary_context: bool = False
# Max number of training steps in total. (Also acts as the "length" of the training
# and validation "Datasets")
train_max_steps: int = 100_000
# Maximum number of episodes in total.
# TODO: Add tests for this 'max episodes' and 'episodes_per_task'.
train_max_episodes: Optional[int] = None
# Total number of steps in the test loop. (Also acts as the "length" of the testing
# environment.)
test_max_steps: int = 10_000
test_max_episodes: Optional[int] = None
# Standard deviation of the multiplicative Gaussian noise that is used to
# create the values of the env attributes for each task.
task_noise_std: float = 0.2
# NOTE: THIS ARG IS DEPRECATED! Only keeping it here so previous config yaml files
# don't cause a crash.
observe_state_directly: Optional[bool] = None
# NOTE: Removing those, in favor of just using the registered Pixel<...>-v? variant.
# force_pixel_observations: bool = False
# """ Wether to use the "pixel" version of `self.dataset`.
# When `False`, does nothing.
# When `True`, will do one of the following, depending on the choice of environment:
# - For classic control envs, it adds a `PixelObservationsWrapper` to the env.
# - For atari envs:
# - If `self.dataset` is a regular atari env (e.g. "ALE/Breakout-v5"), does nothing.
# - if `self.dataset` is the 'RAM' version of an atari env, raises an error.
# - For mujoco envs, this raises a NotImplementedError, as we don't yet know how to
# make a pixel-version of the Mujoco Envs.
# - For other envs:
# - If the environment's observation space appears to be image-based, an error
# will be raised.
# - If the environment's observation space doesn't seem to be image-based, does
# nothing.
# """
# force_state_observations: bool = False
# """ Wether to use the "state" version of `self.dataset`.
# When `False`, does nothing.
# When `True`, will do one of the following, depending on the choice of environment:
# - For classic control envs, it does nothing, as they are already state-based.
# - TODO: For atari envs, the 'RAM' version of the chosen env will be used.
# - For mujoco envs, it does nothing, as they are already state-based.
# - For other envs, if this is set to True, then
# - If the environment's observation space appears to be image-based, an error
# will be raised.
# - If the environment's observation space doesn't seem to be image-based, does
# nothing.
# """
# NOTE: Removing this from the continual setting.
# By default 1 for this setting, meaning that the context is a linear interpolation
# between the start context (usually the default task for the environment) and a
# randomly sampled task.
# nb_tasks: int = field(5, alias=["n_tasks", "num_tasks"])
# Whether to convert the observations / actions / rewards of the envs (and their
# spaces) such that they return Tensors rather than numpy arrays.
# TODO: Maybe switch this to True by default?
prefer_tensors: bool = False
# Path to a json file from which to read the train task schedule.
train_task_schedule_path: Optional[Path] = None
# Path to a json file from which to read the validation task schedule.
val_task_schedule_path: Optional[Path] = None
# Path to a json file from which to read the test task schedule.
test_task_schedule_path: Optional[Path] = None
# Whether observations from the environments should include
# the end-of-episode signal. Only really useful if your method will iterate
# over the environments in the dataloader style
# (as does the baseline method).
add_done_to_observations: bool = False
# The maximum number of steps per episode. When None, there is no limit.
max_episode_steps: Optional[int] = None
# Transforms to be applied by default to the observations of the train/valid/test
# environments.
transforms: List[Transforms] = list_field()
# Transforms to be applied to the training environment, in addition to those already
# in `transforms`.
train_transforms: List[Transforms] = list_field()
# Transforms to be applied to the validation environment, in addition to those
# already in `transforms`.
val_transforms: List[Transforms] = list_field()
# Transforms to be applied to the testing environment, in addition to those already
# in `transforms`.
test_transforms: List[Transforms] = list_field()
# When True, a Monitor-like wrapper will be applied to the training environment
# and monitor the 'online' performance during training. Note that in SL, this will
# also cause the Rewards (y) to be withheld until actions are passed to the `send`
# method of the Environment.
monitor_training_performance: bool = flag(True)
#
# -------- Fields below don't have corresponding command-line arguments. -----------
#
train_task_schedule: Dict[int, Dict[str, float]] = dict_field(cmd=False)
val_task_schedule: Dict[int, Dict[str, float]] = dict_field(cmd=False)
test_task_schedule: Dict[int, Dict[str, float]] = dict_field(cmd=False)
# TODO: Naming is a bit inconsistent, using `valid` here, whereas we use `val`
# elsewhere.
train_wrappers: List[Callable[[gym.Env], gym.Env]] = list_field(cmd=False)
val_wrappers: List[Callable[[gym.Env], gym.Env]] = list_field(cmd=False)
test_wrappers: List[Callable[[gym.Env], gym.Env]] = list_field(cmd=False)
# keyword arguments to be passed to the base environment through gym.make(base_env, **kwargs).
base_env_kwargs: Dict = dict_field(cmd=False)
batch_size: Optional[int] = field(default=None, cmd=False)
num_workers: Optional[int] = field(default=None, cmd=False)
# Maximum number of training steps per task.
# NOTE: In this particular setting there aren't clear 'tasks' to speak of.
train_steps_per_task: Optional[int] = None
# Number of test steps per task.
# NOTE: In this particular setting there aren't clear 'tasks' to speak of.
test_steps_per_task: Optional[int] = None
# # Deprecated: use `train_max_steps` instead.
# max_steps: Optional[int] = deprecated_property(redirects_to="train_max_steps")
# # Deprecated: use `test_max_steps` instead.
# test_steps: Optional[int] = deprecated_property(redirects_to="test_max_steps")
# # Deprecated, use `train_steps_per_task` instead.
# steps_per_task: Optional[int] = deprecated_property(redirects_to="train_steps_per_task")
def __post_init__(self):
defaults = {f.name: f.default for f in fields(self)}
super().__post_init__()
# TODO: Fix annoying little issues with this trio of fields that are interlinked:
if self.test_steps_per_task is not None:
# We need to set the value of self.test_max_steps and self.test_steps_per_task
if self.test_task_schedule and max(self.test_task_schedule) != len(
self.test_task_schedule
):
self.test_max_steps = max(self.test_task_schedule)
elif self.test_max_steps == defaults["test_max_steps"]:
self.test_max_steps = self.nb_tasks * self.test_steps_per_task
else:
self.nb_tasks = self.test_max_steps // self.test_steps_per_task
# if self.max_steps is not None:
# warnings.warn(DeprecationWarning("'max_steps' is deprecated, use 'train_max_steps' instead."))
# self.train_max_steps = self.max_steps
# if self.test_steps is not None:
# warnings.warn(DeprecationWarning("'test_steps' is deprecated, use 'test_max_steps' instead."))
if self.dataset and self.dataset not in self.available_datasets.values():
try:
self.dataset = find_matching_dataset(self.available_datasets, self.dataset)
except NotImplementedError as e:
logger.info(f"Will try to use custom dataset {self.dataset}.")
except Exception as e:
if getattr(self, "train_envs", []):
logger.info(f"Using custom environments / datasets.")
else:
raise gym.error.UnregisteredEnv(
f"({e}) The chosen dataset/environment ({self.dataset}) isn't in the dict of "
f"available datasets/environments, and a task schedule was not passed, "
f"so this Setting ({type(self).__name__}) doesn't know how to create "
f"tasks for that env!\n"
f"Supported envs:\n"
+ ("\n".join(f"- {k}: {v}" for k, v in self.available_datasets.items()))
)
# The ids of the train/valid/test environments.
self.train_dataset: Union[str, Callable[[], gym.Env]] = self.train_dataset or self.dataset
self.val_dataset: Union[str, Callable[[], gym.Env]] = self.val_dataset or self.dataset
self.test_dataset: Union[str, Callable[[], gym.Env]] = self.test_dataset or self.dataset
logger.info(f"Chosen dataset: {textwrap.shorten(str(self.train_dataset), 50)}")
# # The environment 'ID' associated with each 'simple name'.
# self.train_dataset_id: str = self._get_dataset_id(self.train_dataset)
# self.val_dataset_id: str = self._get_dataset_id(self.val_dataset)
# self.train_dataset_id: str = self._get_dataset_id(self.train_dataset)
# Set the number of tasks depending on the increment, and vice-versa.
# (as only one of the two should be used).
assert self.train_max_steps, "assuming this should always be set, for now."
# Load the task schedules from the corresponding files, if present.
if self.train_task_schedule_path:
self.train_task_schedule = _load_task_schedule(self.train_task_schedule_path)
self.nb_tasks = len(self.train_task_schedule) - 1
if self.val_task_schedule_path:
self.val_task_schedule = _load_task_schedule(self.val_task_schedule_path)
if self.test_task_schedule_path:
self.test_task_schedule = _load_task_schedule(self.test_task_schedule_path)
self.train_env: gym.Env
self.valid_env: gym.Env
self.test_env: gym.Env
# Temporary environments which are created and used only for creating the task
# schedules and the 'base' observation spaces, and then closed right after.
self._temp_train_env: Optional[gym.Env] = self._make_env(self.train_dataset)
self._temp_val_env: Optional[gym.Env] = None
self._temp_test_env: Optional[gym.Env] = None
# Create the task schedules, using the 'task sampling' function from `tasks.py`.
# TODO: PLEASE HELP I'm going mad because of the validation logic for these
# fields!!
if not self.train_task_schedule:
self.train_task_schedule = self.create_train_task_schedule()
elif max(self.train_task_schedule) == len(self.train_task_schedule) - 1:
# If the keys correspond to the task ids rather than the steps:
if self.nb_tasks in [defaults["nb_tasks"], None]:
self.nb_tasks = len(self.train_task_schedule) - 1
if self.nb_tasks < 1:
raise RuntimeError(f"Need at least 2 entries in the task schedule!")
logger.info(
f"Assuming that the last entry in the provided task schedule is "
f"the final state, and that there are {self.nb_tasks} tasks. "
)
self.train_steps_per_task = (
self.train_steps_per_task or self.train_max_steps // self.nb_tasks
)
new_keys = np.linspace(
0, self.train_max_steps, self.nb_tasks + 1, endpoint=True, dtype=int
).tolist()
assert len(new_keys) == len(self.train_task_schedule)
self.train_task_schedule = type(self.train_task_schedule)(
{
new_key: self.train_task_schedule[old_key]
for new_key, old_key in zip(new_keys, sorted(self.train_task_schedule.keys()))
}
)
elif self.smooth_task_boundaries:
# We have a task schedule for Continual RL.
if self.train_max_steps == defaults["train_max_steps"]:
self.train_max_steps = max(self.train_task_schedule)
if self.smooth_task_boundaries:
# NOTE: Need to have an entry at the final step
last_task_step = max(self.train_task_schedule.keys())
last_task = self.train_task_schedule[last_task_step]
if self.train_max_steps not in self.train_task_schedule:
# FIXME Duplicating the last task for now?
self.train_task_schedule[self.train_max_steps] = last_task
if 0 not in self.train_task_schedule.keys():
raise RuntimeError(
"`train_task_schedule` needs an entry at key 0, as the initial state"
)
if self.train_max_steps != max(self.train_task_schedule):
if self.train_max_steps in [defaults["train_max_steps"], None]:
# TODO: This might be wrong no?
self.train_max_steps = max(self.train_task_schedule)
logger.info(f"Setting `train_max_steps` to {self.train_max_steps}")
elif self.smooth_task_boundaries:
raise RuntimeError(
f"For now, the train task schedule needs to have a value at key "
f"`train_max_steps` ({self.train_max_steps})."
)
else:
last_task_step = max(self.train_task_schedule)
last_task = self.train_task_schedule[last_task_step]
logger.debug("Using the last task as the final state.")
self.train_task_schedule[self.train_max_steps] = last_task
if not self.val_task_schedule:
# Avoid creating an additional env, just reuse `self._temp_train_env`.
self._temp_val_env = (
self._temp_train_env
if self.val_dataset == self.train_dataset
else self._make_env(self.val_dataset)
)
self.val_task_schedule = self.create_val_task_schedule()
elif max(self.val_task_schedule) == len(self.val_task_schedule) - 1:
# If the keys correspond to the task ids rather than the transition steps
expected_nb_tasks = len(self.val_task_schedule)
old_keys = sorted(self.val_task_schedule.keys())
new_keys = np.linspace(
0, self.train_max_steps, self.nb_tasks + 1, endpoint=True, dtype=int
).tolist()
assert len(new_keys) == len(self.train_task_schedule)
self.val_task_schedule = type(self.val_task_schedule)(
{
new_key: self.val_task_schedule[old_key]
for new_key, old_key in zip(new_keys, old_keys)
}
)
if not self.test_task_schedule:
self._temp_test_env = (
self._temp_train_env
if self.test_dataset == self.train_dataset
else self._make_env(self.test_dataset)
)
self.test_task_schedule = self.create_test_task_schedule()
elif max(self.test_task_schedule) == len(self.test_task_schedule) - 1:
# If the keys correspond to the task ids rather than the transition steps
old_keys = sorted(self.test_task_schedule.keys())
new_keys = np.linspace(
0, self.test_max_steps, self.nb_tasks + 1, endpoint=True, dtype=int
).tolist()
self.test_task_schedule = type(self.test_task_schedule)(
{
new_key: self.test_task_schedule[old_key]
for new_key, old_key in zip(new_keys, old_keys)
}
)
if 0 not in self.test_task_schedule.keys():
raise RuntimeError("`test_task_schedule` needs an entry at key 0, as the initial state")
if self.test_max_steps != max(self.test_task_schedule):
if self.test_max_steps == defaults["test_max_steps"]:
self.test_max_steps = max(self.test_task_schedule)
logger.info(f"Setting `test_max_steps` to {self.test_max_steps}")
elif self.smooth_task_boundaries:
raise RuntimeError(
f"For now, the test task schedule needs to have a value at key "
f"`test_max_steps` ({self.test_max_steps}). "
)
# Close the temporary environments.
# NOTE: Avoid closing the envs for now in case 'live' envs were passed to the Setting.
if self._temp_train_env:
# self._temp_train_env.close()
pass
if self._temp_val_env and self._temp_val_env is not self._temp_train_env:
# self._temp_val_env.close()
pass
if self._temp_test_env and self._temp_test_env is not self._temp_train_env:
# self._temp_test_env.close()
pass
train_task_lengths: List[int] = [
task_b_step - task_a_step
for task_a_step, task_b_step in pairwise(sorted(self.train_task_schedule.keys()))
]
# TODO: This will crash if nb_tasks is 1, right?
# train_max_steps = train_last_boundary + train_task_lengths[-1]
test_task_lengths: List[int] = [
task_b_step - task_a_step
for task_a_step, task_b_step in pairwise(sorted(self.test_task_schedule.keys()))
]
if not (
len(self.train_task_schedule)
== len(self.test_task_schedule)
== len(self.val_task_schedule)
):
raise RuntimeError(
"Training, validation and testing task schedules should have the same "
"number of items for now."
)
train_last_boundary = max(set(self.train_task_schedule.keys()) - {self.train_max_steps})
test_last_boundary = max(set(self.test_task_schedule.keys()) - {self.test_max_steps})
# TODO: Really annoying validation logic for these fields needs to be simplified
# somehow.
# if self.train_steps_per_task is None:
# # if self.nb_tasks
# train_steps_per_task = self.train_max_steps // self.nb_tasks
# if self.train_task_schedule:
# task_lengths = [
# b - a for a, b in pairwise(self.train_task_schedule.keys())
# ]
# if any(
# abs(task_length - train_steps_per_task) > 1
# for task_length in task_lengths
# ):
# raise RuntimeError(
# f"Trying to set a value for `train_steps_per_task`, but "
# f"the keys of the task schedule are either uneven, or not "
# f"equal to {train_steps_per_task}: "
# f"task schedule keys: {self.train_task_schedule.keys()}"
# )
# self.train_steps_per_task = train_steps_per_task
# FIXME: This is quite confusing:
expected_nb_tasks = len(self.train_task_schedule) - 1
# if (
# self.train_max_steps not in [defaults["train_max_steps"], None]
# and self.train_max_steps == max(self.train_task_schedule)
# ) or self.smooth_task_boundaries:
# expected_nb_tasks -= 1
if self.nb_tasks != expected_nb_tasks:
if self.nb_tasks in [None, defaults["nb_tasks"]]:
assert len(self.train_task_schedule) == len(self.test_task_schedule)
self.nb_tasks = len(self.train_task_schedule) - 1
logger.info(f"`nb_tasks` set to {self.nb_tasks} based on the task schedule")
else:
raise RuntimeError(
f"The passed number of tasks ({self.nb_tasks}) is inconsistent "
f"with train_max_steps ({self.train_max_steps}) and the "
f"passed task schedule (with keys "
f"{self.train_task_schedule.keys()}): "
f"Expected nb_tasks to be None or {expected_nb_tasks}."
)
if not train_task_lengths:
assert not test_task_lengths
assert expected_nb_tasks == 1
assert self.train_max_steps > 0
assert self.test_max_steps > 0
train_max_steps = self.train_max_steps
test_max_steps = self.test_max_steps
else:
train_max_steps = sum(train_task_lengths)
test_max_steps = sum(test_task_lengths)
# train_max_steps = round(train_last_boundary + train_task_lengths[-1])
# test_max_steps = round(test_last_boundary + test_task_lengths[-1])
if self.train_max_steps != train_max_steps:
if self.train_max_steps == defaults["train_max_steps"]:
self.train_max_steps = train_max_steps
else:
raise RuntimeError(
f"Value of train_max_steps ({self.train_max_steps}) is "
f"inconsistent with the given train task schedule, which has "
f"the last task boundary at step {train_last_boundary}, with "
f"task lengths of {train_task_lengths}, as it suggests the maximum "
f"total number of steps to be {train_last_boundary} + "
f"{train_task_lengths[-1]} => {train_max_steps}!"
)
if self.test_max_steps != test_max_steps:
if self.test_max_steps == defaults["test_max_steps"]:
self.test_max_steps = test_max_steps
else:
raise RuntimeError(
f"Value of test_max_steps ({self.test_max_steps}) is "
f"inconsistent with the given test task schedule (which has keys "
f"{self.test_task_schedule.keys()}). Expected the last key to be "
f"{test_max_steps}"
)
if self.train_steps_per_task is None:
self.train_steps_per_task = self.train_max_steps // self.nb_tasks
# TODO: Fix these annoying interactions once and for all.
assert self.train_max_steps // self.nb_tasks == self.train_steps_per_task, (
self.train_max_steps,
self.nb_tasks,
self.train_steps_per_task,
self.train_task_schedule.keys(),
)
if self.test_steps_per_task is None:
self.test_steps_per_task = self.test_max_steps // self.nb_tasks
assert self.test_max_steps // self.nb_tasks == self.test_steps_per_task, (
self.test_max_steps,
self.nb_tasks,
self.test_steps_per_task,
self.test_task_schedule.keys(),
)
def create_train_task_schedule(self) -> TaskSchedule:
# change_steps = [0, self.train_max_steps]
# Ex: nb_tasks == 5, train_max_steps = 10_000:
# change_steps = [0, 2_000, 4_000, 6_000, 8_000, 10_000]
if self.train_steps_per_task is not None:
train_max_steps = self.train_steps_per_task * self.nb_tasks
# if self.smooth_task_boundaries:
# train_max_steps = self.train_steps_per_task * self.nb_tasks
# else:
# train_max_steps = self.train_steps_per_task * self.nb_tasks
else:
train_max_steps = self.train_max_steps
assert self.nb_tasks is not None
task_schedule_keys = np.linspace(
0, train_max_steps, self.nb_tasks + 1, endpoint=True, dtype=int
).tolist()
return self.create_task_schedule(
temp_env=self._temp_train_env,
change_steps=task_schedule_keys,
# # TODO: Add properties for the train/valid/test seeds?
seed=self.config.seed if self.config else 123,
)
def create_val_task_schedule(self) -> TaskSchedule:
# Always the same as train task schedule for now.
return self.train_task_schedule.copy()
def create_test_task_schedule(self) -> TaskSchedule[ContinuousTask]:
# Re-scale the steps in the task schedule based on self.test_max_steps
# NOTE: Using the same task schedule as in training and validation for now.
if self.train_task_schedule:
nb_tasks = len(self.train_task_schedule) - 1
else:
nb_tasks = self.nb_tasks
# TODO: Do we want to re-allow the `test_steps_per_task` argument?
if self.test_steps_per_task is not None:
test_max_steps = self.test_steps_per_task * nb_tasks
else:
test_max_steps = self.test_max_steps
test_task_schedule_keys = np.linspace(
0, test_max_steps, nb_tasks + 1, endpoint=True, dtype=int
).tolist()
return {
step: task
for step, task in zip(test_task_schedule_keys, self.train_task_schedule.values())
}
def create_task_schedule(
self,
temp_env: gym.Env,
change_steps: List[int],
seed: int = None,
) -> Dict[int, Dict]:
"""Create the task schedule, which maps from a step to the changes that
will occur in the environment when that step is reached.
Uses the provided `temp_env` to generate the random tasks at the steps
given in `change_steps` (a list of integers).
Returns a dictionary mapping from integers (the steps) to the changes
that will occur in the env at that step.
TODO: For now in ContinualRL we use an interpolation of a dict of attributes
to be set on the unwrapped env, but in IncrementalRL it is possible to pass
callables to be applied on the environment at a given timestep.
"""
task_schedule: Dict[int, Dict] = {}
# TODO: Make it possible to use something other than steps as keys in the task
# schedule, e.g. a NamedTuple[int, DeltaType] such as Episodes(10) or Steps(10),
# something along those lines!
# IDEA: Even fancier, we could use a TimeDelta to say "do one hour of task 0"!!
for step in change_steps:
# TODO: Pass whether it's for training/validation/testing?
task = type(self)._task_sampling_function(
temp_env,
step=step,
change_steps=change_steps,
seed=seed,
)
task_schedule[step] = task
return task_schedule
@property
def observation_space(self) -> TypedDictSpace:
"""The un-batched observation space, based on the choice of dataset and
the transforms at `self.transforms` (which apply to the train/valid/test
environments).
The returned space is a TypedDictSpace, with the following properties/items:
- `x`: observation space (e.g. `Image` space)
- `task_labels`: Union[Discrete, Sparse[Discrete]]
The task labels for each sample when task labels are available,
otherwise the task labels space is `Sparse`, and entries will be `None`.
"""
# TODO: Is it right that we set the observation space on the Setting to be the
# observation space of the current train environment?
# In what situation could there be any difference between those?
# - Changing the 'transforms' attributes after training?
# if self.train_env is not None:
# # assert self._observation_space == self.train_env.observation_space
# return self.train_env.observation_space
if isinstance(self._temp_train_env.observation_space, TypedDictSpace):
x_space = self._temp_train_env.observation_space.x
task_label_space = self._temp_train_env.observation_space.task_labels
else:
x_space = self._temp_train_env.observation_space
# apply the transforms to the observation space.
for transform in self.transforms:
x_space = transform(x_space)
task_label_space = self.task_label_space
done_space = spaces.Box(0, 1, shape=(), dtype=bool)
if not self.add_done_to_observations:
done_space = Sparse(done_space, sparsity=1)
observation_space = TypedDictSpace(
x=x_space,
task_labels=task_label_space,
done=done_space,
dtype=self.Observations,
)
if self.prefer_tensors:
observation_space = add_tensor_support(observation_space)
assert isinstance(observation_space, TypedDictSpace)
return observation_space
@property
def task_label_space(self) -> gym.Space:
# TODO: Explore an alternative design for the task sampling, based more around
# gym spaces rather than the generic function approach that's currently used?
# FIXME: This isn't really elegant, there isn't a `nb_tasks` attribute on the
# ContinualRLSetting anymore, so we have to do a bit of a hack.. Would be
# cleaner to maybe put this in the assumption class, under
# `self.task_label_space`?
task_label_space = spaces.Box(0.0, 1.0, shape=())
if not self.task_labels_at_train_time or not self.task_labels_at_test_time:
sparsity = 1
if self.task_labels_at_train_time ^ self.task_labels_at_test_time:
# We have task labels "50%" of the time, ish:
sparsity = 0.5
task_label_space = Sparse(task_label_space, sparsity=sparsity)
return task_label_space
@property
def action_space(self) -> gym.Space:
# TODO: Convert the action/reward spaces so they also use TypedDictSpace (even
# if they just have one item), so that it correctly reflects the objects that
# the envs accept.
y_pred_space = self._temp_train_env.action_space
# action_space = TypedDictSpace(y_pred=y_pred_space, dtype=self.Actions)
return y_pred_space
@property
def reward_space(self) -> gym.Space:
reward_range = self._temp_train_env.reward_range
return getattr(
self._temp_train_env,
"reward_space",
spaces.Box(reward_range[0], reward_range[1], shape=()),
)
def apply(self, method: Method, config: Config = None) -> "ContinualRLSetting.Results":
"""Apply the given method on this setting to producing some results."""
# Use the supplied config, or parse one from the arguments that were
# used to create `self`.
self.config = config or self._setup_config(method)
logger.debug(f"Config: {self.config}")
# TODO: Test to make sure that this doesn't cause any other bugs with respect to
# the display of stuff:
# Call this method, which creates a virtual display if necessary.
self.config.get_display()
# TODO: Should we really overwrite the method's 'config' attribute here?
if not getattr(method, "config", None):
method.config = self.config
# TODO: Remove `Setting.configure(method)` entirely, from everywhere,
# and use the `prepare_data` or `setup` methods instead (since these
# `configure` methods aren't using the `method` anyway.)
method.configure(setting=self)
# BUG This won't work if the task schedule uses callables as the values (as
# they aren't json-serializable.)
if self.stationary_context:
logger.info(
"Train tasks: " + json.dumps(list(self.train_task_schedule.values()), indent="\t")
)
else:
try:
logger.info(
"Train task schedule:" + json.dumps(self.train_task_schedule, indent="\t")
)
# BUG: Sometimes the task schedule isn't json-serializable!
except TypeError:
logger.info("Train task schedule: ")
for key, value in self.train_task_schedule.items():
logger.info(f"{key}: {value}")
if self.config.debug:
logger.debug("Test task schedule:" + json.dumps(self.test_task_schedule, indent="\t"))
# Run the Training loop (which is defined in ContinualAssumption).
results = self.main_loop(method)
logger.info("Results summary:")
logger.info(results.to_log_dict())
logger.info(results.summary())
method.receive_results(self, results=results)
return results
# Run the Test loop (which is defined in IncrementalAssumption).
# results: RlResults = self.test_loop(method)
def setup(self, stage: str = None) -> None:
# Called before the start of each task during training, validation and
# testing.
super().setup(stage=stage)
if stage in {"fit", None}:
self.train_wrappers = self.create_train_wrappers()
if stage in {"validate", None}:
self.valid_wrappers = self.create_valid_wrappers()
if stage in {"test", None}:
self.test_wrappers = self.create_test_wrappers()
def prepare_data(self, *args, **kwargs) -> None:
# We don't really download anything atm.
if self.config is None:
self.config = Config()
super().prepare_data(*args, **kwargs)
def train_dataloader(
self, batch_size: int = None, num_workers: int = None
) -> ActiveEnvironment:
"""Create a training gym.Env/DataLoader for the current task.
Parameters
----------
batch_size : int, optional
The batch size, which in this case is the number of environments to
run in parallel. When `None`, the env won't be vectorized. Defaults
to None.
num_workers : int, optional
The number of workers (processes) to use in the vectorized env. When
None, the envs are run in sequence, which could be very slow. Only
applies when `batch_size` is not None. Defaults to None.
Returns
-------
GymDataLoader
A (possibly vectorized) environment/dataloader for the current task.
"""
if not self.has_prepared_data:
self.prepare_data()
# NOTE: We actually want to call setup every time, so we re-create the
# wrappers for each task.
self.setup("fit")
batch_size = batch_size or self.batch_size
num_workers = num_workers if num_workers is not None else self.num_workers
train_seed = self.config.seed if self.config else None
env_factory = partial(
self._make_env,
base_env=self.train_dataset,
wrappers=self.train_wrappers,
**self.base_env_kwargs,
)
env_dataloader = self._make_env_dataloader(
env_factory,
batch_size=batch_size,
num_workers=num_workers,
max_steps=self.steps_per_phase,
max_episodes=self.train_max_episodes,
seed=train_seed,
)
if self.monitor_training_performance:
# NOTE: It doesn't always make sense to log stuff with the current task ID!
wandb_prefix = "Train"
if self.known_task_boundaries_at_train_time:
wandb_prefix += f"/Task {self.current_task_id}"
env_dataloader = MeasureRLPerformanceWrapper(env_dataloader, wandb_prefix=wandb_prefix)
if self.config.render and batch_size is None:
env_dataloader = RenderEnvWrapper(env_dataloader)
self.train_env = env_dataloader
# BUG: There is a mismatch between the train env's observation space and the
# shape of its observations.
# self.observation_space = self.train_env.observation_space
return self.train_env
def val_dataloader(self, batch_size: int = None, num_workers: int = None) -> Environment:
"""Create a validation gym.Env/DataLoader for the current task.
Parameters
----------
batch_size : int, optional
The batch size, which in this case is the number of environments to
run in parallel. When `None`, the env won't be vectorized. Defaults
to None.
num_workers : int, optional
The number of workers (processes) to use in the vectorized env. When
None, the envs are run in sequence, which could be very slow. Only
applies when `batch_size` is not None. Defaults to None.
Returns
-------
GymDataLoader
A (possibly vectorized) environment/dataloader for the current task.
"""
if not self.has_prepared_data:
self.prepare_data()
# Need to force this to happen every time, because the wrappers might change
# between tasks.
self._has_setup_validate = False
self.setup("validate")
env_factory = partial(
self._make_env,
base_env=self.val_dataset,
wrappers=self.valid_wrappers,
**self.base_env_kwargs,
)
valid_seed = self.config.seed if self.config else None
env_dataloader = self._make_env_dataloader(
env_factory,
batch_size=batch_size or self.batch_size,
num_workers=num_workers if num_workers is not None else self.num_workers,
max_steps=self.steps_per_phase,
# TODO: Create a new property to limit validation episodes?
max_episodes=self.train_max_episodes,
seed=valid_seed,
)
if self.monitor_training_performance:
# NOTE: We also add it here, just so it logs metrics to wandb.
# NOTE: It doesn't always make sense to log stuff with the current task ID!
wandb_prefix = "Valid"
if self.known_task_boundaries_at_train_time:
wandb_prefix += f"/Task {self.current_task_id}"
env_dataloader = MeasureRLPerformanceWrapper(env_dataloader, wandb_prefix=wandb_prefix)
self.val_env = env_dataloader
return self.val_env
def test_dataloader(self, batch_size: int = None, num_workers: int = None) -> TestEnvironment:
"""Create the test 'dataloader/gym.Env' for all tasks.
NOTE: This test environment isn't just for the current task, it actually
contains the sequence of all tasks. This is different from the train or
validation environments, since if the task labels are available at train
time, then calling `train/valid_dataloader` returns the envs for the
current task only, and the `.fit` method is called once per task.
This environment is also different in that it is wrapped with a Monitor,
which we might eventually use to save the results/gifs/logs of the
testing runs.
Parameters
----------
batch_size : int, optional
The batch size, which in this case is the number of environments to
run in parallel. When `None`, the env won't be vectorized. Defaults
to None.
num_workers : int, optional
The number of workers (processes) to use in the vectorized env. When
None, the envs are run in sequence, which could be very slow. Only
applies when `batch_size` is not None. Defaults to None.
Returns
-------
TestEnvironment
A testing environment which keeps track of the performance of the
actor and accumulates logs/statistics that are used to eventually
create the 'Result' object.
"""
if not self.has_prepared_data:
self.prepare_data()
# NOTE: New for PL: The call doesn't go through if self._has_setup_test is True
# Need to force this to happen every time, because the wrappers might change
# between tasks.
self._has_setup_test = False
self.setup("test")
# BUG: gym.wrappers.Monitor doesn't want to play nice when applied to
# Vectorized env, it seems..
# FIXME: Remove this when the Monitor class works correctly with
# batched environments.
batch_size = batch_size or self.batch_size
if batch_size is not None:
logger.warning(
UserWarning(
colorize(
f"WIP: Only support batch size of `None` (i.e., a single env) "
f"for the test environments of RL Settings at the moment, "
f"because the Monitor class from gym doesn't work with "
f"VectorEnvs. (batch size was {batch_size})",