Skip to content

Commit

Permalink
Merge pull request #256 from bokveizen/fanchen_250306
Browse files Browse the repository at this point in the history
[Bug fixes for #255] Added environment config files for FLP and MCP
  • Loading branch information
fedebotu authored Mar 6, 2025
2 parents cb72927 + ce7e96c commit d1d238f
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 1 deletion.
13 changes: 13 additions & 0 deletions configs/env/flp.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
_target_: rl4co.envs.FLPEnv
name: flp

generator_params:
num_loc: 100
min_loc: 0.0
max_loc: 1.0
loc_distribution: uniform
to_choose: 10

# data_dir: ${paths.root_dir}/data/mcp
# val_file: mcp${env.generator_params.num_loc}_val_seed4321.npz
# test_file: mcp${env.generator_params.num_loc}_test_seed1234.npz
17 changes: 17 additions & 0 deletions configs/env/mcp.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
_target_: rl4co.envs.MCPEnv
name: mcp

generator_params:
num_items: 200
num_sets: 100
min_weight: 1
max_weight: 10
min_size: 5
max_size: 15
n_sets_to_choose: 10
size_distribution: uniform
weight_distribution: uniform

# data_dir: ${paths.root_dir}/data/mcp
# val_file: mcp${env.generator_params.num_loc}_val_seed4321.npz
# test_file: mcp${env.generator_params.num_loc}_test_seed1234.npz
2 changes: 1 addition & 1 deletion configs/experiment/graph/am.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

defaults:
- override /model: am.yaml
- override /env: mcp.yaml
- override /env: flp.yaml
- override /callbacks: default.yaml
- override /trainer: default.yaml
- override /logger: wandb.yaml
Expand Down
20 changes: 20 additions & 0 deletions rl4co/models/nn/env_embeddings/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def env_context_embedding(env_name: str, config: dict) -> nn.Module:
"mdcpdp": MDCPDPContext,
"mtvrp": MTVRPContext,
"shpp": TSPContext,
"flp": FLPContext,
}

if env_name not in embedding_registry:
Expand Down Expand Up @@ -372,3 +373,22 @@ def _state_embedding(self, embeddings, td):
],
-1,
)

class FLPContext(EnvContext):
"""Context embedding for the Facility Location Problem (FLP).
"""
def __init__(self, embed_dim: int):
super(FLPContext, self).__init__(embed_dim=embed_dim)
self.embed_dim = embed_dim
# self.mlp_context = MLP(embed_dim, [embed_dim, embed_dim])
self.projection = nn.Linear(embed_dim, embed_dim, bias=True)

def forward(self, embeddings, td):
cur_dist = td["distances"].unsqueeze(-2) # (batch_size, 1, n_points)
dist_improve = cur_dist - td["orig_distances"] # (batch_size, n_points, n_points)
dist_improve = torch.clamp(dist_improve, min=0).sum(-1) # (batch_size, n_points)

# softmax
loc_best_soft = torch.softmax(dist_improve, dim=-1) # (batch_size, n_points)
embed_best = (embeddings * loc_best_soft[..., None]).sum(-2)
return embed_best
10 changes: 10 additions & 0 deletions rl4co/models/nn/env_embeddings/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def env_init_embedding(env_name: str, config: dict) -> nn.Module:
"jssp": FJSPInitEmbedding,
"mtvrp": MTVRPInitEmbedding,
"shpp": TSPInitEmbedding,
"flp": FLPInitEmbedding,
}

if env_name not in embedding_registry:
Expand Down Expand Up @@ -562,3 +563,12 @@ def forward(self, td):
)
)
return torch.cat((depot_embedding, node_embeddings), -2)

class FLPInitEmbedding(nn.Module):
def __init__(self, embed_dim: int):
super().__init__()
self.projection = nn.Linear(2, embed_dim, bias=True)

def forward(self, td: TensorDict):
hdim = self.projection(td["locs"])
return hdim

0 comments on commit d1d238f

Please sign in to comment.