Skip to content

Commit

Permalink
align terminology of BPD with experiment scripts
Browse files — browse the repository at this point in the history
Authored by rradules, committed May 29, 2024
1 parent 4b04824 commit 3dd6dfb
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 10 deletions.
8 changes: 4 additions & 4 deletions momaland/learning/iql/tabular_bpd.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ def normalize_objective_rewards(self, reward, reward_scheme):
np.array: the normalized reward
"""
# Set the normalization constants
if reward_scheme == "local":
if reward_scheme == "individual":
cap_min, cap_max, mix_min, mix_max = self.l_cap_min, self.l_cap_max, self.l_mix_min, self.l_mix_max
elif reward_scheme == "global":
elif reward_scheme == "team":
cap_min, cap_max, mix_min, mix_max = self.g_cap_min, self.g_cap_max, self.g_mix_min, self.g_mix_max
else:
raise ValueError(f"Unknown reward scheme: {reward_scheme}")
Expand Down Expand Up @@ -108,15 +108,15 @@ def step(self, actions):
section_agent_types[self._state[i]][self._types[i]] += 1
g_capacity = _global_capacity_reward(self.resource_capacities, section_consumptions)
g_mixture = _global_mixture_reward(section_agent_types)
g_capacity_norm, g_mixture_norm = self.normalize_objective_rewards(np.array([g_capacity, g_mixture]), "global")
g_capacity_norm, g_mixture_norm = self.normalize_objective_rewards(np.array([g_capacity, g_mixture]), "team")
infos = {
agent: {"g_cap": g_capacity, "g_mix": g_mixture, "g_cap_norm": g_capacity_norm, "g_mix_norm": g_mixture_norm}
for agent in self.possible_agents
}

# Normalize the rewards
for agent in self.possible_agents:
rewards[agent] = self.normalize_objective_rewards(rewards[agent], self.reward_scheme)
rewards[agent] = self.normalize_objective_rewards(rewards[agent], self.reward_mode)

return observations, rewards, terminations, truncations, infos

Expand Down
12 changes: 6 additions & 6 deletions momaland/learning/iql/train_iql_bpd.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def compute_normalization_constants(num_agents, sections, capacity, type_distrib
# Maximum local capacity is achieved when there are 'capacity' agents in the section
max_cap_local = _local_capacity_reward(capacity, capacity)
cap_min = 0.0
cap_max = max_cap_local if reward_scheme == "local" else max_cap_global
cap_max = max_cap_local if reward_scheme == "individual" else max_cap_global

# Mixture
# Maximum global mixture: one agent of each type in each section, except one where all other agents are
Expand All @@ -52,7 +52,7 @@ def compute_normalization_constants(num_agents, sections, capacity, type_distrib
# Maximum local mixture is achieved when there is one agent of each type in the section
max_mix_local = _local_mixture_reward([1, 1])
mix_min = 0.0
mix_max = max_mix_local if reward_scheme == "local" else max_mix_global
mix_max = max_mix_local if reward_scheme == "individual" else max_mix_global

return cap_min, cap_max, mix_min, mix_max

Expand Down Expand Up @@ -94,7 +94,7 @@ def parse_args():
parser.add_argument('--position-distribution', type=float, nargs=5, default=[0., 0.5, 0., 0.5, 0.], )
parser.add_argument('--sections', type=int, default=5, )
parser.add_argument('--capacity', type=int, default=3, )
parser.add_argument('--reward-scheme', type=str, default="local", help="the reward scheme to use")
parser.add_argument('--reward-scheme', type=str, default="individual", help="the reward scheme to use")

args = parser.parse_args()
args.time = time.time()
Expand All @@ -112,13 +112,13 @@ def parse_args():
"position_distribution": args.position_distribution,
"sections": args.sections,
"capacity": args.capacity,
"reward_scheme": args.reward_scheme,
"reward_mode": args.reward_scheme,
# Normalization constants
"local_constants": compute_normalization_constants(
args.num_agents, args.sections, args.capacity, args.type_distribution, "local"
args.num_agents, args.sections, args.capacity, args.type_distribution, "individual"
),
"global_constants": compute_normalization_constants(
args.num_agents, args.sections, args.capacity, args.type_distribution, "global"
args.num_agents, args.sections, args.capacity, args.type_distribution, "team"
),
}

Expand Down

0 comments on commit 3dd6dfb

Please sign in to comment.