remote_trainer.py
import os
import pickle
import uuid
from copy import deepcopy
from typing import List

import ray
from ray.rllib.agents import with_common_config
from ray.rllib.agents.ppo.ppo import DEFAULT_CONFIG as config_ppo
from ray.rllib.agents.ppo.ppo_tf_policy import PPOTFPolicy
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.common import STEPS_TRAINED_COUNTER, LEARNER_INFO, \
    WORKER_UPDATE_TIMER, LEARN_ON_BATCH_TIMER, LOAD_BATCH_TIMER, \
    _get_global_vars, _check_sample_batch_type, _get_shared_metrics
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches, \
    StandardizeFields, SelectExperiences
from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
    MultiAgentBatch
from ray.rllib.utils.typing import PolicyID, SampleBatchType

from ap_rllib.helpers import filter_pickleable, dict_get_any_value, \
    save_gym_space, unlink_ignore_error
from frankenstein.remote_communicator import RemoteHTTPPickleCommunicator
from gym_compete_rllib.load_gym_compete_policy import nets_to_weights, \
    load_weights_from_vars

# the function that does the training
from frankenstein.remote_trainer_with_communicator import train_external
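# expected interface, inferred from the call site in ExternalTrainOp below
# (the return structure is an assumption, based on how `info` is stored as
# LEARNER_INFO): train_external(policies: dict, samples: MultiAgentBatch,
# config: dict) -> dict of per-policy learner statistics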


# copied from TrainTFMultiGPU and modified
class ExternalTrainOp:
    """Train externally via the train_external function imported above."""

    def __init__(self,
                 workers: WorkerSet,
                 config: dict,
                 policies: List[PolicyID] = frozenset([])):
        self.workers = workers
        self.policies = policies or workers.local_worker().policies_to_train
        self.config = config

    def __call__(self,
                 samples: SampleBatchType) -> (SampleBatchType, List[dict]):
        _check_sample_batch_type(samples)

        # handle everything as if multi-agent
        if isinstance(samples, SampleBatch):
            samples = MultiAgentBatch({
                DEFAULT_POLICY_ID: samples
            }, samples.count)

        metrics = _get_shared_metrics()
        load_timer = metrics.timers[LOAD_BATCH_TIMER]
        learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER]

        # call train_external to train the selected policies with Stable Baselines
        policies = {k: self.workers.local_worker().get_policy(k)
                    for k in self.policies}
        info = train_external(policies=policies, samples=samples,
                              config=self.config)

        load_timer.push_units_processed(samples.count)
        learn_timer.push_units_processed(samples.count)

        fetches = info
        metrics.counters[STEPS_TRAINED_COUNTER] += samples.count
        metrics.info[LEARNER_INFO] = fetches

        # broadcast the updated local weights to the remote workers
        if self.workers.remote_workers():
            with metrics.timers[WORKER_UPDATE_TIMER]:
                weights = ray.put(self.workers.local_worker().get_weights(
                    self.policies))
                for e in self.workers.remote_workers():
                    e.set_weights.remote(weights, _get_global_vars())

        # also update global vars of the local worker
        self.workers.local_worker().set_global_vars(_get_global_vars())
        return samples, fetches


def execution_plan(workers, config):
    """Execution plan which calls ExternalTrainOp."""
    rollouts = ParallelRollouts(workers, mode="bulk_sync")

    # collect large batches of relevant experiences & standardize advantages
    rollouts = rollouts.for_each(
        SelectExperiences(workers.trainable_policies()))
    rollouts = rollouts.combine(
        ConcatBatches(min_batch_size=config["train_batch_size"]))
    rollouts = rollouts.for_each(StandardizeFields(["advantages"]))

    train_op = rollouts.for_each(
        ExternalTrainOp(workers=workers, config=config))

    return StandardMetricsReporting(train_op, workers, config)


# creating ExternalTrainer on top of PPO's default config
DEFAULT_CONFIG = deepcopy(config_ppo)

# default values; each can be overridden via the usual RLlib config dict
DEFAULT_CONFIG.update({'http_remote_port': "http://127.0.0.1:50001",
                       'run_uid': 'aba',
                       'tmp_dir': '/tmp/'})
DEFAULT_CONFIG = with_common_config(DEFAULT_CONFIG)
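
# A minimal sketch of overriding these defaults through the RLlib config dict
# (the port, uid, and directory below are illustrative assumptions):
#
#   custom_config = dict(DEFAULT_CONFIG)
#   custom_config.update({'http_remote_port': "http://127.0.0.1:50002",
#                         'run_uid': str(uuid.uuid1()),
#                         'tmp_dir': '/tmp/external_trainer/'})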

ExternalTrainer = build_trainer(
    name="External",
    default_config=DEFAULT_CONFIG,
    default_policy=PPOTFPolicy,
    execution_plan=execution_plan)
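
# A minimal usage sketch (the env name and training loop below are
# illustrative assumptions, not part of this module):
#
#   import ray
#   ray.init()
#   trainer = ExternalTrainer(config={"train_batch_size": 4096},
#                             env="some-registered-env")
#   for _ in range(10):
#       result = trainer.train()  # rollouts -> train_external -> weight sync
#       print(result["episode_reward_mean"])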