Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bug fixes for #255] Added environment config files for FLP and MCP #256

Merged
merged 3 commits into from
Mar 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions configs/env/flp.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
_target_: rl4co.envs.FLPEnv
name: flp

generator_params:
num_loc: 100
min_loc: 0.0
max_loc: 1.0
loc_distribution: uniform
to_choose: 10

# data_dir: ${paths.root_dir}/data/mcp
# val_file: mcp${env.generator_params.num_loc}_val_seed4321.npz
# test_file: mcp${env.generator_params.num_loc}_test_seed1234.npz
17 changes: 17 additions & 0 deletions configs/env/mcp.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
_target_: rl4co.envs.MCPEnv
name: mcp

generator_params:
num_items: 200
num_sets: 100
min_weight: 1
max_weight: 10
min_size: 5
max_size: 15
n_sets_to_choose: 10
size_distribution: uniform
weight_distribution: uniform

# data_dir: ${paths.root_dir}/data/mcp
# val_file: mcp${env.generator_params.num_loc}_val_seed4321.npz
# test_file: mcp${env.generator_params.num_loc}_test_seed1234.npz
2 changes: 1 addition & 1 deletion configs/experiment/graph/am.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

defaults:
- override /model: am.yaml
- override /env: mcp.yaml
- override /env: flp.yaml
- override /callbacks: default.yaml
- override /trainer: default.yaml
- override /logger: wandb.yaml
Expand Down
20 changes: 20 additions & 0 deletions rl4co/models/nn/env_embeddings/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def env_context_embedding(env_name: str, config: dict) -> nn.Module:
"mdcpdp": MDCPDPContext,
"mtvrp": MTVRPContext,
"shpp": TSPContext,
"flp": FLPContext,
}

if env_name not in embedding_registry:
Expand Down Expand Up @@ -372,3 +373,22 @@ def _state_embedding(self, embeddings, td):
],
-1,
)

class FLPContext(EnvContext):
"""Context embedding for the Facility Location Problem (FLP).
"""
def __init__(self, embed_dim: int):
super(FLPContext, self).__init__(embed_dim=embed_dim)
self.embed_dim = embed_dim
# self.mlp_context = MLP(embed_dim, [embed_dim, embed_dim])
self.projection = nn.Linear(embed_dim, embed_dim, bias=True)

def forward(self, embeddings, td):
cur_dist = td["distances"].unsqueeze(-2) # (batch_size, 1, n_points)
dist_improve = cur_dist - td["orig_distances"] # (batch_size, n_points, n_points)
dist_improve = torch.clamp(dist_improve, min=0).sum(-1) # (batch_size, n_points)

# softmax
loc_best_soft = torch.softmax(dist_improve, dim=-1) # (batch_size, n_points)
embed_best = (embeddings * loc_best_soft[..., None]).sum(-2)
return embed_best
10 changes: 10 additions & 0 deletions rl4co/models/nn/env_embeddings/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def env_init_embedding(env_name: str, config: dict) -> nn.Module:
"jssp": FJSPInitEmbedding,
"mtvrp": MTVRPInitEmbedding,
"shpp": TSPInitEmbedding,
"flp": FLPInitEmbedding,
}

if env_name not in embedding_registry:
Expand Down Expand Up @@ -562,3 +563,12 @@ def forward(self, td):
)
)
return torch.cat((depot_embedding, node_embeddings), -2)

class FLPInitEmbedding(nn.Module):
def __init__(self, embed_dim: int):
super().__init__()
self.projection = nn.Linear(2, embed_dim, bias=True)

def forward(self, td: TensorDict):
hdim = self.projection(td["locs"])
return hdim