diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a95395e..d3a74a6 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -60,8 +60,8 @@ jobs: - run: | julia --project=docs -e ' using Documenter: doctest - using MAMCTS - doctest(MAMCTS)' + using FactoredValueMCTS + doctest(FactoredValueMCTS)' - run: julia --project=docs docs/make.jl env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/CITATION.bib b/CITATION.bib new file mode 100644 index 0000000..3e9f805 --- /dev/null +++ b/CITATION.bib @@ -0,0 +1,7 @@ +@inproceedings{choudhury2021scalable, + title={Scalable Anytime Planning for Multi-Agent {MDP}s}, + author={Choudhury, Shushman and Gupta, Jayesh K and Morales, Peter and Kochenderfer, Mykel}, + booktitle={International Conference on Autonomous Agents and Multiagent Systems (AAMAS)}, + year={2021}, + organization={IFAAMAS} +} \ No newline at end of file diff --git a/Project.toml b/Project.toml index 88f09b1..6a0e875 100644 --- a/Project.toml +++ b/Project.toml @@ -1,8 +1,23 @@ -name = "MAMCTS" +name = "FactoredValueMCTS" uuid = "c016a6d7-1193-47d7-896a-d9f14d6b4b26" authors = ["Stanford Intelligent Systems Laboratory"] version = "0.1.0" +[deps] +BeliefUpdaters = "8bb6e9a1-7d73-552c-a44a-e5dc5634aac4" +LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +MCTS = "e12ccd36-dcad-5f33-8774-9175229e7b33" +MultiAgentPOMDPs = "9ac5fd70-7902-42e7-9745-ff446b44e779" +POMDPLinter = "f3bd98c0-eb40-45e2-9eb1-f2763262d755" +POMDPModelTools = "08074719-1b2a-587c-a292-00f91cc44415" +POMDPPolicies = "182e52fb-cfd0-5e46-8c26-fd0667c990f4" +POMDPSimulators = "e0d0a172-29c6-5d4e-96d0-f262df5d01fd" +POMDPs = "a93abf59-7444-517b-a68a-c42f96afdd7d" +Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + [compat] julia = "1.5" diff --git a/README.md b/README.md index 629c498..f80f1c2 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ -# MAMCTS +# FactoredValueMCTS -[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://rejuvyesh.github.io/MAMCTS.jl/stable) -[![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://rejuvyesh.github.io/MAMCTS.jl/dev) -[![Build Status](https://github.com/rejuvyesh/MAMCTS.jl/workflows/CI/badge.svg)](https://github.com/rejuvyesh/MAMCTS.jl/actions) -[![Coverage](https://codecov.io/gh/rejuvyesh/MAMCTS.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/rejuvyesh/MAMCTS.jl) +[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://JuliaPOMDP.github.io/FactoredValueMCTS.jl/stable) +[![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://JuliaPOMDP.github.io/FactoredValueMCTS.jl/dev) +[![Build Status](https://github.com/JuliaPOMDP/FactoredValueMCTS.jl/workflows/CI/badge.svg)](https://github.com/JuliaPOMDP/FactoredValueMCTS.jl/actions) +[![Coverage](https://codecov.io/gh/JuliaPOMDP/FactoredValueMCTS.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/JuliaPOMDP/FactoredValueMCTS.jl) diff --git a/docs/Manifest.toml b/docs/Manifest.toml index 4053bdb..2b3e7d3 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -49,7 +49,7 @@ uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" [[Logging]] uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" -[[MAMCTS]] +[[FactoredValueMCTS]] path = ".." 
uuid = "c016a6d7-1193-47d7-896a-d9f14d6b4b26" version = "0.1.0" diff --git a/docs/Project.toml b/docs/Project.toml index c646ea9..cfa3359 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,3 +1,3 @@ [deps] Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -MAMCTS = "c016a6d7-1193-47d7-896a-d9f14d6b4b26" +FactoredValueMCTS = "c016a6d7-1193-47d7-896a-d9f14d6b4b26" diff --git a/docs/make.jl b/docs/make.jl index 4350303..4245289 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,14 +1,14 @@ -using MAMCTS +using FactoredValueMCTS using Documenter makedocs(; - modules=[MAMCTS], + modules=[FactoredValueMCTS], authors="Stanford Intelligent Systems Laboratory", - repo="https://github.com/rejuvyesh/MAMCTS.jl/blob/{commit}{path}#L{line}", - sitename="MAMCTS.jl", + repo="https://github.com/JuliaPOMDP/FactoredValueMCTS.jl/blob/{commit}{path}#L{line}", + sitename="FactoredValueMCTS.jl", format=Documenter.HTML(; prettyurls=get(ENV, "CI", "false") == "true", - canonical="https://rejuvyesh.github.io/MAMCTS.jl", + canonical="https://JuliaPOMDP.github.io/FactoredValueMCTS.jl", assets=String[], ), pages=[ @@ -17,5 +17,5 @@ makedocs(; ) deploydocs(; - repo="github.com/rejuvyesh/MAMCTS.jl", + repo="github.com/JuliaPOMDP/FactoredValueMCTS.jl", ) diff --git a/docs/src/index.md b/docs/src/index.md index f63d0e5..11deaaa 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -1,12 +1,12 @@ ```@meta -CurrentModule = MAMCTS +CurrentModule = FactoredValueMCTS ``` -# MAMCTS +# FactoredValueMCTS ```@index ``` ```@autodocs -Modules = [MAMCTS] +Modules = [FactoredValueMCTS] ``` diff --git a/src/FactoredValueMCTS.jl b/src/FactoredValueMCTS.jl new file mode 100644 index 0000000..97a9ad5 --- /dev/null +++ b/src/FactoredValueMCTS.jl @@ -0,0 +1,103 @@ +module FactoredValueMCTS + +using Random +using LinearAlgebra + +using POMDPs +using MultiAgentPOMDPs +using POMDPPolicies +using POMDPLinter: @req, @subreq, @POMDP_require +using MCTS +using LightGraphs +using BeliefUpdaters + +using MCTS: convert_estimator +import POMDPModelTools + +using POMDPSimulators: RolloutSimulator +import POMDPs + +# Patch simulate to support vector of rewards +function POMDPs.simulate(sim::RolloutSimulator, mdp::JointMDP, policy::Policy, initialstate::S) where {S} + + if sim.eps === nothing + eps = 0.0 + else + eps = sim.eps + end + + if sim.max_steps === nothing + max_steps = typemax(Int) + else + max_steps = sim.max_steps + end + + s = initialstate + + + # TODO: doesn't this add unnecessary action search? 
+    r = @gen(:r)(mdp, s, action(policy, s), sim.rng)
+    if r isa AbstractVector
+        r_total = zeros(n_agents(mdp))
+    else
+        r_total = 0.0
+    end
+    return sim_helper!(r_total, sim, mdp, policy, s, max_steps, eps)
+end
+
+function sim_helper!(r_total::AbstractVector{F}, sim, mdp, policy, s, max_steps, eps) where {F}
+    step = 1
+    disc = 1.0
+    while disc > eps && !isterminal(mdp, s) && step <= max_steps
+        a = action(policy, s)
+
+        sp, r = @gen(:sp, :r)(mdp, s, a, sim.rng)
+
+        r_total .+= disc .* r
+
+        s = sp
+
+        disc *= discount(mdp)
+        step += 1
+    end
+    return r_total
+end
+
+function sim_helper!(r_total::AbstractFloat, sim, mdp, policy, s, max_steps, eps) # a Float64 total cannot be mutated in place, so the accumulated value is returned
+    step = 1
+    disc = 1.0
+    while disc > eps && !isterminal(mdp, s) && step <= max_steps
+        a = action(policy, s)
+
+        sp, r = @gen(:sp, :r)(mdp, s, a, sim.rng)
+
+        r_total += disc * r
+
+        s = sp
+
+        disc *= discount(mdp)
+        step += 1
+    end
+    return r_total
+end
+
+
+###
+# Factored Value MCTS
+#
+
+abstract type CoordinationStatistics end
+
+include(joinpath("fvmcts", "factoredpolicy.jl"))
+include(joinpath("fvmcts", "fv_mcts_vanilla.jl"))
+include(joinpath("fvmcts", "action_coordination", "varel.jl"))
+include(joinpath("fvmcts", "action_coordination", "maxplus.jl"))
+
+export
+    FVMCTSSolver,
+    MaxPlus,
+    VarEl
+
+
+end
diff --git a/src/MAMCTS.jl b/src/MAMCTS.jl
deleted file mode 100644
index 5399134..0000000
--- a/src/MAMCTS.jl
+++ /dev/null
@@ -1,5 +0,0 @@
-module MAMCTS
-
-# Write your package code here.
-
-end
diff --git a/src/fvmcts/action_coordination/maxplus.jl b/src/fvmcts/action_coordination/maxplus.jl
new file mode 100644
index 0000000..40fe54e
--- /dev/null
+++ b/src/fvmcts/action_coordination/maxplus.jl
@@ -0,0 +1,297 @@
+# NOTE: The matrices implicitly assume all agents have the same number of actions.
+mutable struct PerStateMPStats
+    agent_action_n::Matrix{Int64} # N X A
+    agent_action_q::Matrix{Float64}
+    edge_action_n::Matrix{Int64} # |E| X A^2
+    edge_action_q::Matrix{Float64}
+end
+
+"""
+Tracks the specific information and statistics needed by Max-Plus to coordinate the
+joint action (via `coordinate_action`) in Factored-Value MCTS. Putting the parameters here is a little ugly, but `coordinate_action` cannot take them as arguments since Var-El does not use them.
+
+Fields:
+    adjmatgraph::SimpleGraph
+        The coordination graph as a LightGraphs SimpleGraph.
+
+    message_iters::Int64
+        Number of rounds of message passing.
+
+    message_norm::Bool
+        Whether to normalize the messages after message passing.
+
+    use_agent_utils::Bool
+        Whether to include the per-agent utilities while computing the best agent action (see our paper for details).
+
+    node_exploration::Bool
+        Whether to use the per-node UCB-style bonus while computing the best agent action (see our paper for details).
+
+    edge_exploration::Bool
+        Whether to use the per-edge UCB-style bonus after the message passing rounds (see our paper for details). At least one of this and node_exploration MUST be true for exploration.
+
+    all_states_stats::Dict{S,PerStateMPStats}
+        Maps each joint state in the tree to its per-state statistics.
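Illustrative construction (this struct is normally created internally by the Max-Plus
planner constructor; `S` stands for the joint state type, and the values shown are
just the MaxPlus defaults):

    using LightGraphs
    g = SimpleGraph(2)          # two agents
    add_edge!(g, 1, 2)          # one coordination edge between them
    stats = MaxPlusStatistics{S}(g, 10, true, false, true, true,
                                 Dict{S,PerStateMPStats}())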
+""" +mutable struct MaxPlusStatistics{S} <: CoordinationStatistics + adjmatgraph::SimpleGraph + message_iters::Int64 + message_norm::Bool + use_agent_utils::Bool + node_exploration::Bool + edge_exploration::Bool # NOTE: One of this or node exploration must be true + all_states_stats::Dict{S,PerStateMPStats} +end + +function clear_statistics!(mp_stats::MaxPlusStatistics) + empty!(mp_stats.all_states_stats) +end + +function update_statistics!(mdp::JointMDP{S,A}, tree::FVMCTSTree{S,A,MaxPlusStatistics{S}}, + s::S, ucb_action::A, q::AbstractFloat) where {S,A} + + update_statistics!(mdp, tree, s, ucb_action, ones(typeof(q), n_agents(mdp)) * q) +end + +""" +Take the q-value from the MCTS step and distribute the updates across the per-node and per-edge q-stats as per the formula in our paper. +""" +function update_statistics!(mdp::JointMDP{S,A}, tree::FVMCTSTree{S,A,MaxPlusStatistics{S}}, + s::S, ucb_action::A, q::AbstractVector{Float64}) where {S,A} + + state_stats = tree.coordination_stats.all_states_stats[s] + nagents = n_agents(mdp) + + # Update per agent action stats + for i = 1:nagents + ac_idx = agent_actionindex(mdp, i, ucb_action[i]) + lock(tree.lock) do + state_stats.agent_action_n[i, ac_idx] += 1 + state_stats.agent_action_q[i, ac_idx] += + (q[i] - state_stats.agent_action_q[i, ac_idx]) / state_stats.agent_action_n[i, ac_idx] + end + end + + # Now update per-edge action stats + for (idx, e) in enumerate(edges(tree.coordination_stats.adjmatgraph)) + # NOTE: Need to be careful about action ordering + # Being more general to have unequal agent actions + edge_comp = (e.src,e.dst) + edge_tup = Tuple(1:length(tree.all_agent_actions[c]) for c in edge_comp) + edge_ac_idx = LinearIndices(edge_tup)[agent_actionindex(mdp, e.src, ucb_action[e.src]), + agent_actionindex(mdp, e.dst, ucb_action[e.dst])] + q_edge_value = q[e.src] + q[e.dst] + + lock(tree.lock) do + state_stats.edge_action_n[idx, edge_ac_idx] += 1 + state_stats.edge_action_q[idx, edge_ac_idx] += + (q_edge_value - state_stats.edge_action_q[idx, edge_ac_idx]) / state_stats.edge_action_n[idx, edge_ac_idx] + end + end + + lock(tree.lock) do + tree.coordination_stats.all_states_stats[s] = state_stats + end + +end + +function init_statistics!(tree::FVMCTSTree{S,A,MaxPlusStatistics{S}}, planner::FVMCTSPlanner, + s::S) where {S,A} + + n_agents = length(s) + + # NOTE: Assuming all agents have the same actions here + n_all_actions = length(tree.all_agent_actions[1]) + + agent_action_n = zeros(Int64, n_agents, n_all_actions) + agent_action_q = zeros(Float64, n_agents, n_all_actions) + + # Loop over agents and then actions + # TODO: Need to define init_N and init_Q for single agent + for i = 1:n_agents + for (j, ac) in enumerate(tree.all_agent_actions[i]) + agent_action_n[i, j] = init_N(planner.solver.init_N, planner.mdp, s, i, ac) + agent_action_q[i, j] = init_Q(planner.solver.init_Q, planner.mdp, s, i, ac) + end + end + + n_edges = ne(tree.coordination_stats.adjmatgraph) + edge_action_n = zeros(Int64, n_edges, n_all_actions^2) + edge_action_q = zeros(Float64, n_edges, n_all_actions^2) + + # Loop over edges and then action_i \times action_j + for (idx, e) in enumerate(edges(tree.coordination_stats.adjmatgraph)) + edge_comp = (e.src, e.dst) + n_edge_actions = prod([length(tree.all_agent_actions[c]) for c in edge_comp]) + edge_tup = Tuple(1:length(tree.all_agent_actions[c]) for c in edge_comp) + + for edge_ac_idx = 1:n_edge_actions + ct_idx = CartesianIndices(edge_tup)[edge_ac_idx] + edge_action = [tree.all_agent_actions[c] for c in edge_comp] + + 
edge_action_n[idx, edge_ac_idx] = init_N(planner.solver.init_N, planner.mdp, s, edge_comp, edge_action) + edge_action_q[idx, edge_ac_idx] = init_Q(planner.solver.init_Q, planner.mdp, s, edge_comp, edge_action) + end + end + + state_stats = PerStateMPStats(agent_action_n, agent_action_q, edge_action_n, edge_action_q) + lock(tree.lock) do + tree.coordination_stats.all_states_stats[s] = state_stats + end +end + +""" +Runs Max-Plus at the current state using the per-state MaxPlusStatistics to compute the best joint action with either or both of node-wise and edge-wise exploration bonus. Rounds of message passing are followed by per-node maximization. +""" +function coordinate_action(mdp::JointMDP{S,A}, tree::FVMCTSTree{S,A,MaxPlusStatistics{S}}, s::S, + exploration_constant::Float64=0.0, node_id::Int64=0) where {S,A} + + state_stats = lock(tree.lock) do + tree.coordination_stats.all_states_stats[s] + end + adjgraphmat = lock(tree.lock) do + tree.coordination_stats.adjmatgraph + end + k = tree.coordination_stats.message_iters + message_norm = tree.coordination_stats.message_norm + + n_agents = length(s) + state_agent_actions = [agent_actions(mdp, i, si) for (i, si) in enumerate(s)] + n_all_actions = length(tree.all_agent_actions[1]) + n_edges = ne(tree.coordination_stats.adjmatgraph) + + # Init forward and backward messages and q0 + fwd_messages = zeros(Float64, n_edges, n_all_actions) + bwd_messages = zeros(Float64, n_edges, n_all_actions) + + if tree.coordination_stats.use_agent_utils + q_values = state_stats.agent_action_q / n_agents + else + q_values = zeros(size(state_stats.agent_action_q)) + end + + state_total_n = lock(tree.lock) do + (node_id > 0) ? tree.total_n[node_id] : 1 + end + + + # Iterate over passes + for t = 1:k + fnormdiff, bnormdiff = perform_message_passing!(fwd_messages, bwd_messages, mdp, tree.all_agent_actions, + adjgraphmat, state_agent_actions, n_edges, q_values, message_norm, + 0, state_stats, state_total_n) + + if !tree.coordination_stats.use_agent_utils + q_values = zeros(size(state_stats.agent_action_q)) + end + + # Update Q value with messages + for i = 1:n_agents + + # need indices of all edges that agent is involved in + nbrs = neighbors(tree.coordination_stats.adjmatgraph, i) + + edgelist = collect(edges(adjgraphmat)) + + if tree.coordination_stats.use_agent_utils + @views q_values[i, :] = state_stats.agent_action_q[i, :]/n_agents + end + for n in nbrs + if Edge(i,n) in edgelist # use backward message + q_values[i,:] += bwd_messages[findfirst(isequal(Edge(i,n)), edgelist), :] + elseif Edge(n,i) in edgelist + q_values[i,:] += fwd_messages[findfirst(isequal(Edge(n,i)), edgelist), :] + else + @warn "Neither edge found!" 
+ end + end + end + + # If converged, break + if isapprox(fnormdiff, 0.0) && isapprox(bnormdiff, 0.0) + break + end + end # for t = 1:k + + # If edge exploration flag enabled, do a final exploration bonus + if tree.coordination_stats.edge_exploration + perform_message_passing!(fwd_messages, bwd_messages, mdp, tree.all_agent_actions, + adjgraphmat, state_agent_actions, n_edges, q_values, message_norm, + exploration_constant, state_stats, state_total_n) + end # if edge_exploration + + + # Maximize q values for agents + best_action = Vector{eltype(A)}(undef, n_agents) + for i = 1:n_agents + + # NOTE: Again can't just iterate over agent actions as it may be a subset + exp_q_values = zeros(length(state_agent_actions[i])) + if tree.coordination_stats.node_exploration + for (idx, ai) in enumerate(state_agent_actions[i]) + ai_idx = agent_actionindex(mdp, i, ai) + exp_q_values[idx] = q_values[i, ai_idx] + exploration_constant*sqrt((log(state_total_n + 1.0))/(state_stats.agent_action_n[i, ai_idx] + 1.0)) + end + else + for (idx, ai) in enumerate(state_agent_actions[i]) + ai_idx = agent_actionindex(mdp, i, ai) + exp_q_values[idx] = q_values[i, ai_idx] + end + end + + # NOTE: Can now look up index in exp_q_values and then again look at state_agent_actions + _, idx = findmax(exp_q_values) + best_action[i] = state_agent_actions[i][idx] + end + + return best_action +end + +function perform_message_passing!(fwd_messages::AbstractArray{F,2}, bwd_messages::AbstractArray{F,2}, + mdp, all_agent_actions, + adjgraphmat, state_agent_actions, n_edges::Int, q_values, message_norm, + exploration_constant, state_stats, state_total_n) where {F} + # Iterate over edges + fwd_messages_old = deepcopy(fwd_messages) + bwd_messages_old = deepcopy(bwd_messages) + for (e_idx, e) in enumerate(edges(adjgraphmat)) + + i = e.src + j = e.dst + edge_tup_indices = LinearIndices(Tuple(1:length(all_agent_actions[c]) for c in (i,j))) + + # forward: maximize sender + # NOTE: Can't do enumerate as action set might be smaller + # Need to look up global index of agent action and use that + # Need to break up vectorized loop + @inbounds for aj in state_agent_actions[j] + aj_idx = agent_actionindex(mdp, j, aj) + fwd_message_vals = zeros(length(state_agent_actions[i])) + # TODO: Should we use inbounds here again? 
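#= The loop below computes, for each receiver action a_j, the forward message

       mu_{i->j}(a_j) = max_{a_i} [ q_i(a_i) - mu_{j->i}^old(a_i)
                                    + Q_e(a_i, a_j)/|E| + UCB edge bonus ]

   i.e. standard max-plus with the per-edge payoff split evenly across edges. The
   exploration term is zero during the regular message-passing rounds and only active
   in the final edge-exploration pass; the backward message mirrors this with the
   roles of i and j swapped. =#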
+ @inbounds for (idx, ai) in enumerate(state_agent_actions[i]) + ai_idx = agent_actionindex(mdp, i, ai) + fwd_message_vals[idx] = q_values[i, ai_idx] - bwd_messages_old[e_idx, ai_idx] + state_stats.edge_action_q[e_idx, edge_tup_indices[ai_idx, aj_idx]]/n_edges + exploration_constant * sqrt( (log(state_total_n + 1.0)) / (state_stats.edge_action_n[e_idx, edge_tup_indices[ai_idx, aj_idx]] + 1) ) + end + fwd_messages[e_idx, aj_idx] = maximum(fwd_message_vals) + end + + @inbounds for ai in state_agent_actions[i] + ai_idx = agent_actionindex(mdp, i, ai) + bwd_message_vals = zeros(length(state_agent_actions[j])) + @inbounds for (idx, aj) in enumerate(state_agent_actions[j]) + aj_idx = agent_actionindex(mdp, j, aj) + bwd_message_vals[idx] = q_values[j, aj_idx] - fwd_messages_old[e_idx, aj_idx] + state_stats.edge_action_q[e_idx, edge_tup_indices[ai_idx, aj_idx]]/n_edges + exploration_constant * sqrt( (log(state_total_n + 1.0))/ (state_stats.edge_action_n[e_idx, edge_tup_indices[ai_idx, aj_idx]] + 1) ) + end + bwd_messages[e_idx, ai_idx] = maximum(bwd_message_vals) + end + + # Normalize messages for better convergence + if message_norm + @views fwd_messages[e_idx, :] .-= sum(fwd_messages[e_idx, :])/length(fwd_messages[e_idx, :]) + @views bwd_messages[e_idx, :] .-= sum(bwd_messages[e_idx, :])/length(bwd_messages[e_idx, :]) + end + + end # (idx,edges) in enumerate(edges) + + # Return norm of message difference + return norm(fwd_messages - fwd_messages_old), norm(bwd_messages - bwd_messages_old) +end diff --git a/src/fvmcts/action_coordination/varel.jl b/src/fvmcts/action_coordination/varel.jl new file mode 100644 index 0000000..d2730ae --- /dev/null +++ b/src/fvmcts/action_coordination/varel.jl @@ -0,0 +1,328 @@ +""" +Tracks the specific informations and statistics we need to use Var-El to coordinate_action +the joint action in Factored-Value MCTS. + +Fields: + coord_graph_components::Vector{Vector{Int64}} + The list of coordination graph components, i.e., cliques, where each element is a list of agent IDs that are in a mutual clique. + + min_degree_ordering::Vector{Int64} + Ordering of agent IDs in increasing CG degree. This ordering is the heuristic most typically used for the elimination order in Var-El. + + n_component_stats::Dict{AbstractVector{S},Vector{Vector{Int64}}} + Maps each joint state in the tree (for which we need to compute the UCB action) to the frequency of each component's various local actions. + + q_component_stats::Dict{AbstractVector{S},Vector{Vector{Float64}}} + Maps each joint state in the tree to the accumulated q-value of each component's various local actions. +""" +mutable struct VarElStatistics{S} <: CoordinationStatistics + coord_graph_components::Vector{Vector{Int64}} + min_degree_ordering::Vector{Int64} + n_component_stats::Dict{S,Vector{Vector{Int64}}} + q_component_stats::Dict{S,Vector{Vector{Float64}}} +end + +function clear_statistics!(ve_stats::VarElStatistics) + empty!(ve_stats.n_component_stats) + empty!(ve_stats.q_component_stats) +end + + +""" +Runs variable elimination at the current state using the VarEl Statistics to compute the best joint action with the component-wise exploration bonus. +FYI: Rather complicated. 
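For intuition, consider a 3-agent chain coordination graph with cliques {1,2} and {2,3}
(min-degree ordering 1, 3, 2): eliminating agent 1 folds the {1,2} potential into a
best-response table over agent 2's actions, eliminating agent 3 does the same for {2,3},
agent 2 then maximizes the sum of what remains, and the stored best responses are read
back in reverse order to recover the full joint action.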
+""" +function coordinate_action(mdp::JointMDP{S,A}, tree::FVMCTSTree{S,A,VarElStatistics{S}}, s::S, + exploration_constant::Float64=0.0, node_id::Int64=0) where {S,A} + + n_agents = length(s) + best_action_idxs = MVector{n_agents}([-1 for i in 1:n_agents]) + + # !Note: Acquire lock so as to avoid race + state_q_stats = lock(tree.lock) do + tree.coordination_stats.q_component_stats[s] + end + state_n_stats = lock(tree.lock) do + tree.coordination_stats.n_component_stats[s] + end + state_total_n = lock(tree.lock) do + (node_id > 0) ? tree.total_n[node_id] : 0 + end + + # Maintain set of potential functions + # NOTE: Hashing a vector here + potential_fns = Dict{Vector{Int64},Vector{Float64}}() + for (comp, q_stats) in zip(tree.coordination_stats.coord_graph_components, state_q_stats) + potential_fns[comp] = q_stats + end + + # Need this for reverse process + # Maps agent to other elements in best response functions and corresponding set of actions + # E.g. Agent 2 -> (3,4) in its best response and corresponding vector of agent 2 best actions + best_response_fns = Dict{Int64,Tuple{Vector{Int64},Vector{Int64}}}() + + state_dep_actions = [agent_actions(mdp, i, si) for (i, si) in enumerate(s)] + + # Iterate over variable ordering + # Need to maintain intermediate tables + for ag_idx in tree.coordination_stats.min_degree_ordering + + # Lookup factors with agent in them and simultaneously construct + # members of new potential function, and delete old factors + agent_factors = Vector{Vector{Int64}}(undef, 0) + new_potential_members = Vector{Int64}(undef, 0) + for k in collect(keys(potential_fns)) + if ag_idx in k + + # Agent to-be-eliminated is in factor + push!(agent_factors, k) + + # Construct key for new potential as union of all others except ag_idx + for ag in k + if ag != ag_idx && ~(ag in new_potential_members) + push!(new_potential_members, ag) + end + end + end + end + + if isempty(new_potential_members) == true + # No out neighbors..either at beginning or end of ordering + @assert agent_factors == [[ag_idx]] "agent_factors $(agent_factors) is not just [ag_idx] $([ag_idx])!" 
+ best_action_idxs[ag_idx] = _best_actionindex_empty(potential_fns, + state_dep_actions, + tree.all_agent_actions, + ag_idx) + + else + + # Generate new potential function and the best response vector for eliminated agent + n_comp_actions = prod([length(tree.all_agent_actions[c]) for c in new_potential_members]) + + # NOTE: Tuples should ALWAYS use tree.all_agent_actions for indexing + comp_tup = Tuple(1:length(tree.all_agent_actions[c]) for c in new_potential_members) + + # Initialize q-stats for new potential and best response action vector + # will be inserted into corresponding dictionaries at the end + new_potential_stats = Vector{Float64}(undef, n_comp_actions) + best_response_vect = Vector{Int64}(undef, n_comp_actions) + + # Iterate over new potential joint actions and compute new payoff and best response + for comp_ac_idx = 1:n_comp_actions + + # Get joint action for other members in potential + ct_idx = CartesianIndices(comp_tup)[comp_ac_idx] + + # For maximizing over agent actions + # As before, we now need to init with -Inf + ag_ac_values = zeros(length(tree.all_agent_actions[ag_idx])) + + # TODO: Agent actions should already be in order + # Only do anything if action legal + for (ag_ac_idx, ag_ac) in enumerate(tree.all_agent_actions[ag_idx]) + + if ag_ac in state_dep_actions[ag_idx] + + # Need to look up corresponding stats from agent_factors + for factor in agent_factors + + # NOTE: Need to reconcile the ORDER of ag_idx in factor + factor_action_idxs = MVector{length(factor),Int64}(undef) + + for (idx, f) in enumerate(factor) + + # if f is ag_idx, set corresponding factor action to ag_ac + if f == ag_idx + factor_action_idxs[idx] = ag_ac_idx + else + # Lookup index for corresp. agent action in ct_idx + new_pot_idx = findfirst(isequal(f), new_potential_members) + factor_action_idxs[idx] = ct_idx[new_pot_idx] + end # f == ag_idx + end + + # NOW we can look up the stats of the factor + factor_tup = Tuple(1:length(tree.all_agent_actions[c]) for c in factor) + factor_action_linidx = LinearIndices(factor_tup)[factor_action_idxs...] + + ag_ac_values[ag_ac_idx] += potential_fns[factor][factor_action_linidx] + + # Additionally add exploration stats if factor in original set + factor_comp_idx = findfirst(isequal(factor), tree.coordination_stats.coord_graph_components) + if state_total_n > 0 && ~(isnothing(factor_comp_idx)) # NOTE: Julia1.1 + ag_ac_values[ag_ac_idx] += exploration_constant * sqrt((log(state_total_n+1.0))/(state_n_stats[factor_comp_idx][factor_action_linidx]+1.0)) + end + end # factor in agent_factors + else + ag_ac_values[ag_ac_idx] = -Inf + end # ag_ac in state_dep_actions + end # ag_ac_idx = 1:length(tree.all_agent_actions[ag_idx]) + + + # Now we lookup ag_ac_values for the best value to be put in new_potential_stats + # and the best index to be put in best_response_vect + # NOTE: The -Inf mask should ensure only legal idxs chosen + # If all ag_ac_values equal, should we sample randomly? 
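#= Note: findmax below breaks ties toward the lowest action index. If random
   tie-breaking were preferred, one possible (illustrative) variant is:

       best = findall(==(maximum(ag_ac_values)), ag_ac_values)
       best_idx = rand(best)

   The -Inf mask above keeps illegal actions out of contention either way. =#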
+ best_val, best_idx = findmax(ag_ac_values) + + new_potential_stats[comp_ac_idx] = best_val + best_response_vect[comp_ac_idx] = best_idx + end # comp_ac_idx in n_comp_actions + + # Finally, we enter new stats vector and best response vector back to dicts + potential_fns[new_potential_members] = new_potential_stats + best_response_fns[ag_idx] = (new_potential_members, best_response_vect) + end # isempty(new_potential_members) + + # Delete keys in agent_factors from potential fns since variable has been eliminated + for factor in agent_factors + delete!(potential_fns, factor) + end + end # ag_idx in min_deg_ordering + + # NOTE: At this point, best_action_idxs has at least one entry...for the last action obtained + @assert !all(isequal(-1), best_action_idxs) "best_action_idxs is still undefined!" + + # Message passing in reverse order to recover best action + for ag_idx in Base.Iterators.reverse(tree.coordination_stats.min_degree_ordering) + + # Only do something if best action already not obtained + if best_action_idxs[ag_idx] == -1 + + # Should just be able to lookup best response function + (agents, best_response_vect) = best_response_fns[ag_idx] + + # Members of agents should already have their best action defined + agent_ac_tup = Tuple(1:length(tree.all_agent_actions[c]) for c in agents) + best_agents_action_idxs = [best_action_idxs[ag] for ag in agents] + best_response_idx = LinearIndices(agent_ac_tup)[best_agents_action_idxs...] + + # Assign best action for ag_idx + best_action_idxs[ag_idx] = best_response_vect[best_response_idx] + end # isdefined + end + + # Finally, return best action by iterating over best action indices + # NOTE: best_action should use state-dep actions to reverse index + best_action = [tree.all_agent_actions[ag][idx] for (ag, idx) in enumerate(best_action_idxs)] + + return best_action +end + + +""" +Take the q-value from the MCTS step and distribute the updates across the component q-stats as per the formula in the Amato-Oliehoek paper. +""" +function update_statistics!(mdp::JointMDP{S,A}, tree::FVMCTSTree{S,A,VarElStatistics{S}}, + s::S, ucb_action::A, q::AbstractVector{Float64}) where {S,A} + + n_agents = length(s) + + for (idx, comp) in enumerate(tree.coordination_stats.coord_graph_components) + + # Create cartesian index tuple + comp_tup = Tuple(1:length(tree.all_agent_actions[c]) for c in comp) + + # RECOVER local action corresp. to ucb action + # TODO: Review this carefully. Need @req for action index for agent. + local_action = [ucb_action[c] for c in comp] + local_action_idxs = [agent_actionindex(mdp, c, a) for (a, c) in zip(local_action, comp)] + + comp_ac_idx = LinearIndices(comp_tup)[local_action_idxs...] + + # NOTE: NOW we can update stats. Could generalize incremental update more here + lock(tree.lock) do + tree.coordination_stats.n_component_stats[s][idx][comp_ac_idx] += 1 + q_comp_value = sum(q[c] for c in comp) + tree.coordination_stats.q_component_stats[s][idx][comp_ac_idx] += + (q_comp_value - tree.coordination_stats.q_component_stats[s][idx][comp_ac_idx]) / tree.coordination_stats.n_component_stats[s][idx][comp_ac_idx] + end + end +end + +# TODO: is this the correct thing to do? +# My guess is no, but not sure. 
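#= Sanity check for the scalar-q method below (illustrative): if every agent is credited
   with the same scalar return q, then for any component `comp`

       sum(fill(q, n_agents)[c] for c in comp) == q * length(comp)

   so the method below is just the vector method above specialized to fill(q, n_agents). =#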
+function update_statistics!(mdp::JointMDP{S,A}, tree::FVMCTSTree{S,A,VarElStatistics{S}}, + s::S, ucb_action::A, q::Float64) where {S, A} + n_agents = length(s) + + for (idx, comp) in enumerate(tree.coordination_stats.coord_graph_components) + + # Create cartesian index tuple + comp_tup = Tuple(1:length(tree.all_agent_actions[c]) for c in comp) + + # RECOVER local action corresp. to ucb action + # TODO: Review this carefully. Need @req for action index for agent. + local_action = [ucb_action[c] for c in comp] + local_action_idxs = [agent_actionindex(mdp, c, a) for (a, c) in zip(local_action, comp)] + + comp_ac_idx = LinearIndices(comp_tup)[local_action_idxs...] + + # NOTE: NOW we can update stats. Could generalize incremental update more here + lock(tree.lock) do + tree.coordination_stats.n_component_stats[s][idx][comp_ac_idx] += 1 + q_comp_value = q * length(comp) # Maintains equivalence with `sum(q[c] for c in comp)` + tree.coordination_stats.q_component_stats[s][idx][comp_ac_idx] += + (q_comp_value - tree.coordination_stats.q_component_stats[s][idx][comp_ac_idx]) / tree.coordination_stats.n_component_stats[s][idx][comp_ac_idx] + end + end +end + + +function init_statistics!(tree::FVMCTSTree{S,A,VarElStatistics{S}}, planner::FVMCTSPlanner, + s::S) where {S,A} + + n_comps = length(tree.coordination_stats.coord_graph_components) + n_component_stats = Vector{Vector{Int64}}(undef, n_comps) + q_component_stats = Vector{Vector{Float64}}(undef, n_comps) + + n_agents = length(s) + + # TODO: Could actually make actions state-dependent if need be + for (idx, comp) in enumerate(tree.coordination_stats.coord_graph_components) + + n_comp_actions = prod([length(tree.all_agent_actions[c]) for c in comp]) + + n_component_stats[idx] = Vector{Int64}(undef, n_comp_actions) + q_component_stats[idx] = Vector{Float64}(undef, n_comp_actions) + + comp_tup = Tuple(1:length(tree.all_agent_actions[c]) for c in comp) + + for comp_ac_idx = 1:n_comp_actions + + # Generate action subcomponent and call init_Q and init_N for it + ct_idx = CartesianIndices(comp_tup)[comp_ac_idx] # Tuple corresp to + local_action = [tree.all_agent_actions[c][ai] for (c, ai) in zip(comp, Tuple(ct_idx))] + + # NOTE: init_N and init_Q are functions of component AND local action + # TODO(jkg): init_N and init_Q need to be defined + n_component_stats[idx][comp_ac_idx] = init_N(planner.solver.init_N, planner.mdp, s, comp, local_action) + q_component_stats[idx][comp_ac_idx] = init_Q(planner.solver.init_Q, planner.mdp, s, comp, local_action) + end + end + + # Update tree member + lock(tree.lock) do + tree.coordination_stats.n_component_stats[s] = n_component_stats + tree.coordination_stats.q_component_stats[s] = q_component_stats + end +end + +@inline function _best_actionindex_empty(potential_fns, state_dep_actions, all_agent_actions, ag_idx) + # NOTE: This is inefficient but necessary for state-dep actions? + if length(state_dep_actions[ag_idx]) == length(all_agent_actions[ag_idx]) + _, best_ac_idx = findmax(potential_fns[[ag_idx]]) + else + # Now we need to choose the best index from among legal actions + # Create an array with illegal actions having -Inf and then fill legal vals + # TODO: More efficient way to do this? 
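#= A possible alternative (illustrative sketch only) that indexes over legal actions and
   maps back, instead of building the -Inf mask:

       legal = [iac for (iac, ac) in enumerate(all_agent_actions[ag_idx])
                    if ac in state_dep_actions[ag_idx]]
       _, k = findmax(view(potential_fns[[ag_idx]], legal))
       best_ac_idx = legal[k]
=#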
+        masked_action_vals = fill(-Inf, length(all_agent_actions[ag_idx]))
+        for (iac, ac) in enumerate(all_agent_actions[ag_idx])
+            if ac in state_dep_actions[ag_idx]
+                masked_action_vals[iac] = potential_fns[[ag_idx]][iac]
+            end
+        end
+        _, best_ac_idx = findmax(masked_action_vals)
+    end
+    return best_ac_idx
+end
diff --git a/src/fvmcts/factoredpolicy.jl b/src/fvmcts/factoredpolicy.jl
new file mode 100644
index 0000000..d99c775
--- /dev/null
+++ b/src/fvmcts/factoredpolicy.jl
@@ -0,0 +1,17 @@
+
+"""
+Random policy that is factored over agents: each agent's action is sampled independently,
+which avoids enumerating the exponentially large joint action space.
+"""
+struct FactoredRandomPolicy{RNG<:AbstractRNG,P<:JointMDP, U<:Updater} <: Policy
+    rng::RNG
+    problem::P
+    updater::U
+end
+
+FactoredRandomPolicy(problem::JointMDP; rng=Random.GLOBAL_RNG, updater=NothingUpdater()) = FactoredRandomPolicy(rng, problem, updater)
+
+function POMDPs.action(policy::FactoredRandomPolicy, s)
+    return [rand(policy.rng, agent_actions(policy.problem, i, si)) for (i, si) in enumerate(s)]
+end
+
+POMDPs.solve(solver::RandomSolver, problem::JointMDP) = FactoredRandomPolicy(solver.rng, problem, NothingUpdater())
\ No newline at end of file
diff --git a/src/fvmcts/fv_mcts_vanilla.jl b/src/fvmcts/fv_mcts_vanilla.jl
new file mode 100644
index 0000000..ad7810b
--- /dev/null
+++ b/src/fvmcts/fv_mcts_vanilla.jl
@@ -0,0 +1,416 @@
+using StaticArrays
+using Parameters
+using Base.Threads: @spawn
+
+abstract type AbstractCoordinationStrategy end
+
+struct VarEl <: AbstractCoordinationStrategy
+end
+
+Base.@kwdef struct MaxPlus <:AbstractCoordinationStrategy
+    message_iters::Int64 = 10
+    message_norm::Bool = true
+    use_agent_utils::Bool = false
+    node_exploration::Bool = true
+    edge_exploration::Bool = true
+end
+
+"""
+Factored Value Monte Carlo Tree Search solver data structure
+
+Fields:
+    n_iterations::Int64
+        Number of iterations during each action() call.
+        default: 100
+
+    max_time::Float64
+        Maximum CPU time to spend computing an action.
+        default: Inf
+
+    depth::Int64
+        Maximum rollout horizon and tree depth.
+        default: 10
+
+    exploration_constant::Float64
+        Specifies how much the solver should explore. In the UCB equation, Q + c*sqrt(log(t)/N), c is the exploration constant.
+        The exploration terms for FV-MCTS-Var-El and FV-MCTS-Max-Plus are different, but the role of c is the same.
+        default: 1.0
+
+    rng::AbstractRNG
+        Random number generator
+
+    estimate_value::Any (rollout policy)
+        Function, object, or number used to estimate the value at the leaf nodes.
+        If this is a function `f`, `f(mdp, s, depth)` will be called to estimate the value.
+        If this is an object `o`, `estimate_value(o, mdp, s, depth)` will be called.
+        If this is a number, the value will be set to that number.
+        default: RolloutEstimator(RandomSolver(rng))
+
+    init_Q::Any
+        Function, object, or number used to set the initial Q(s,a) value at a new node.
+        If this is a function `f`, `f(mdp, s, a)` will be called to set the value.
+        If this is an object `o`, `init_Q(o, mdp, s, a)` will be called.
+        If this is a number, Q will be set to that number.
+        default: 0.0
+
+    init_N::Any
+        Function, object, or number used to set the initial N(s,a) value at a new node.
+        If this is a function `f`, `f(mdp, s, a)` will be called to set the value.
+        If this is an object `o`, `init_N(o, mdp, s, a)` will be called.
+        If this is a number, N will be set to that number.
+        default: 0
+
+    reuse_tree::Bool
+        If this is true, the tree information is re-used for calculating the next plan.
+        Of course, clear_tree! can always be called to override this.
+ default: false + + coordination_strategy::AbstractCoordinationStrategy + The specific strategy with which to compute the best joint action from the current MCTS statistics. + default: VarEl() +""" +Base.@kwdef mutable struct FVMCTSSolver <: AbstractMCTSSolver + n_iterations::Int64 = 100 + max_time::Float64 = Inf + depth::Int64 = 10 + exploration_constant::Float64 = 1.0 + rng::AbstractRNG = Random.GLOBAL_RNG + estimate_value::Any = RolloutEstimator(RandomSolver(rng)) + init_Q::Any = 0.0 + init_N::Any = 0 + reuse_tree::Bool = false + coordination_strategy::AbstractCoordinationStrategy = VarEl() +end + + +mutable struct FVMCTSTree{S,A,CS<:CoordinationStatistics} + + # To map the multi-agent state vector to the ID of the node in the tree + state_map::Dict{S,Int64} + + # The next two vectors have one for each node ID in the tree + total_n::Vector{Int} # The number of times the node has been tried + s_labels::Vector{S} # The state corresponding to the node ID + + # List of all individual actions of each agent for coordination purposes. + all_agent_actions::Vector{A} + + coordination_stats::CS + lock::ReentrantLock +end + +function FVMCTSTree(all_agent_actions::Vector{A}, + coordination_stats::CS, + init_state::S, + lock::ReentrantLock, + sz::Int64=10000) where {S, A, CS <: CoordinationStatistics} + + return FVMCTSTree{S,A,CS}(Dict{S,Int64}(), + sizehint!(Int[], sz), + sizehint!(S[], sz), + all_agent_actions, + coordination_stats, + lock + ) +end # function + + + +Base.isempty(t::FVMCTSTree) = isempty(t.state_map) +state_nodes(t::FVMCTSTree) = (FVStateNode(t, id) for id in 1:length(t.total_n)) + +struct FVStateNode{S} + tree::FVMCTSTree{S} + id::Int64 +end + + +# Accessors for state nodes +@inline state(n::FVStateNode) = n.tree.s_labels[n.id] +@inline total_n(n::FVStateNode) = n.tree.total_n[n.id] + +## No need for `children` or ActionNode just yet + +mutable struct FVMCTSPlanner{S, A, SE, CS <: CoordinationStatistics, RNG <: AbstractRNG} <: AbstractMCTSPlanner{JointMDP{S,A}} + solver::FVMCTSSolver + mdp::JointMDP{S,A} + tree::FVMCTSTree{S,A,CS} + solved_estimate::SE + rng::RNG +end + +""" +Called internally in solve() to create the FVMCTSPlanner where Var-El is the specific action coordination strategy. +Creates VarElStatistics internally with the CG components and the minimum degree ordering heuristic. +""" +function varel_joint_mcts_planner(solver::FVMCTSSolver, + mdp::JointMDP{S,A}, + init_state::S, + ) where {S,A} + + # Get coordination graph components from maximal cliques + #adjmat = coord_graph_adj_mat(mdp) + #@assert size(adjmat)[1] == n_agents(mdp) "Adjacency Matrix does not match number of agents!" 
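#= Here coordination_graph(mdp) (a requirement on the JointMDP model; see the
   @POMDP_require block later in this file) is expected to return a LightGraphs
   SimpleGraph over agent indices. Illustrative example for a 3-agent ring:

       g = SimpleGraph(3)
       add_edge!(g, 1, 2); add_edge!(g, 2, 3); add_edge!(g, 3, 1)

   maximal_cliques(g) then yields the coordination components used by Var-El, and
   sortperm(degree(g)) the min-degree elimination ordering. =#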
+ + #adjmatgraph = SimpleGraph(adjmat) + adjmatgraph = coordination_graph(mdp) + + coord_graph_components = maximal_cliques(adjmatgraph) + min_degree_ordering = sortperm(degree(adjmatgraph)) + + # Initialize full agent actions + all_agent_actions = Vector{(actiontype(mdp))}(undef, n_agents(mdp)) + for i = 1:n_agents(mdp) + all_agent_actions[i] = agent_actions(mdp, i) + end + + ve_stats = VarElStatistics{S}(coord_graph_components, min_degree_ordering, + Dict{typeof(init_state),Vector{Vector{Int64}}}(), + Dict{typeof(init_state),Vector{Vector{Int64}}}(), + ) + + # Create tree from the current state + tree = FVMCTSTree(all_agent_actions, ve_stats, + init_state, ReentrantLock(), solver.n_iterations) + se = convert_estimator(solver.estimate_value, solver, mdp) + + return FVMCTSPlanner(solver, mdp, tree, se, solver.rng) +end # end function + +""" +Called internally in solve() to create the FVMCTSPlanner where Max-Plus is the specific action coordination strategy. +Creates MaxPlusStatistics and assumes the various MP flags are sent down from the CoordinationStrategy object given to the solver. +""" +function maxplus_joint_mcts_planner(solver::FVMCTSSolver, + mdp::JointMDP{S,A}, + init_state::S, + message_iters::Int64, + message_norm::Bool, + use_agent_utils::Bool, + node_exploration::Bool, + edge_exploration::Bool, + ) where {S,A} + + @assert (node_exploration || edge_exploration) "At least one of nodes or edges should explore!" + +#= adjmat = coord_graph_adj_mat(mdp) + @assert size(adjmat)[1] == n_agents(mdp) "Adjacency Mat does not match number of agents!" =# + + #adjmatgraph = SimpleGraph(adjmat) + adjmatgraph = coordination_graph(mdp) + @assert size(adjacency_matrix(adjmatgraph))[1] == n_agents(mdp) + + # Initialize full agent actions + # TODO(jkg): this is incorrect? Or we need to override actiontype to refer to agent actions? + all_agent_actions = Vector{(actiontype(mdp))}(undef, n_agents(mdp)) + for i = 1:n_agents(mdp) + all_agent_actions[i] = agent_actions(mdp, i) + end + + mp_stats = MaxPlusStatistics{S}(adjmatgraph, + message_iters, + message_norm, + use_agent_utils, + node_exploration, + edge_exploration, + Dict{S,PerStateMPStats}()) + + # Create tree from the current state + tree = FVMCTSTree(all_agent_actions, mp_stats, + init_state, ReentrantLock(), solver.n_iterations) + se = convert_estimator(solver.estimate_value, solver, mdp) + + return FVMCTSPlanner(solver, mdp, tree, se, solver.rng) +end + + +# Reset tree. 
+function clear_tree!(planner::FVMCTSPlanner) + + # Clear out state map dict entirely + empty!(planner.tree.state_map) + + # Empty state vectors with state hints + sz = min(planner.solver.n_iterations, 100_000) + + empty!(planner.tree.s_labels) + sizehint!(planner.tree.s_labels, planner.solver.n_iterations) + + # Don't touch all_agent_actions and coord graph component + # Just clear comp stats dict + clear_statistics!(planner.tree.coordination_stats) +end + +MCTS.init_Q(n::Number, mdp::JointMDP, s, c, a) = convert(Float64, n) +MCTS.init_N(n::Number, mdp::JointMDP, s, c, a) = convert(Int, n) + + +# No computation is done in solve; the solver is just given the mdp model that it will work with +# and in case of MaxPlus, the various flags for the MaxPlus behavior +function POMDPs.solve(solver::FVMCTSSolver, mdp::JointMDP) + if typeof(solver.coordination_strategy) == VarEl + return varel_joint_mcts_planner(solver, mdp, initialstate(mdp, solver.rng)) + elseif typeof(solver.coordination_strategy) == MaxPlus + return maxplus_joint_mcts_planner(solver, mdp, initialstate(mdp, solver.rng), + solver.coordination_strategy.message_iters, + solver.coordination_strategy.message_norm, + solver.coordination_strategy.use_agent_utils, + solver.coordination_strategy.node_exploration, + solver.coordination_strategy.edge_exploration) + else + throw(error("Not Implemented")) + end +end + + +# IMP: Overriding action for FVMCTSPlanner here +# NOTE: Hardcoding no tree reuse for now +function POMDPs.action(planner::FVMCTSPlanner, s) + clear_tree!(planner) # Always call this at the top + plan!(planner, s) + action = coordinate_action(planner.mdp, planner.tree, s) + return action +end + +function POMDPModelTools.action_info(planner::FVMCTSPlanner, s) + clear_tree!(planner) # Always call this at the top + plan!(planner, s) + action = coordinate_action(planner.mdp, planner.tree, s) + return action, nothing +end + + +function plan!(planner::FVMCTSPlanner, s) + planner.tree = build_tree(planner, s) +end + +# build_tree can be called on the assumption that no reuse AND tree is reinitialized +function build_tree(planner::FVMCTSPlanner, s::S) where S + + n_iterations = planner.solver.n_iterations + depth = planner.solver.depth + + root = insert_node!(planner.tree, planner, s) + + # Simulate can be multi-threaded + @sync for n = 1:n_iterations + @spawn simulate(planner, root, depth) + end + return planner.tree +end + +function simulate(planner::FVMCTSPlanner, node::FVStateNode, depth::Int64) + + mdp = planner.mdp + rng = planner.rng + s = state(node) + tree = node.tree + + + # once depth is zero return + if isterminal(planner.mdp, s) + return 0.0 + elseif depth == 0 + return estimate_value(planner.solved_estimate, planner.mdp, s, depth) + end + + # Choose best UCB action (NOT an action node as in vanilla MCTS) + ucb_action = coordinate_action(mdp, planner.tree, s, planner.solver.exploration_constant, node.id) + + # Monte Carlo Transition + sp, r = @gen(:sp, :r)(mdp, s, ucb_action, rng) + + spid = lock(tree.lock) do + get(tree.state_map, sp, 0) # may be non-zero even with no tree reuse + end + if spid == 0 + spn = insert_node!(tree, planner, sp) + spid = spn.id + + q = r .+ discount(mdp) * estimate_value(planner.solved_estimate, planner.mdp, sp, depth - 1) + else + q = r .+ discount(mdp) * simulate(planner, FVStateNode(tree, spid) , depth - 1) + end + + # NOTE: Not bothering with tree visualization right now + # Augment N(s) + lock(tree.lock) do + tree.total_n[node.id] += 1 + end + + # Update component statistics! 
(non-trivial) + # This is related but distinct from initialization + update_statistics!(mdp, tree, s, ucb_action, q) + + return q +end + +@POMDP_require simulate(planner::FVMCTSPlanner, s, depth::Int64) begin + mdp = planner.mdp + P = typeof(mdp) + @assert P <: JointMDP + SV = statetype(P) + @req iterate(::SV) + #@assert typeof(SV) <: AbstractVector + AV = actiontype(P) + @assert typeof(AV) <: AbstractVector + @req discount(::P) + @req isterminal(::P, ::SV) + @subreq insert_node!(planner.tree, planner, s) + @subreq estimate_value(planner.solved_estimate, mdp, s, depth) + @req gen(::P, ::SV, ::AV, ::typeof(planner.rng)) # XXX this is not exactly right - it could be satisfied with transition + + ## Requirements from MMDP Model + @req agent_actions(::P, ::Int64) + @req agent_actions(::P, ::Int64, ::eltype(SV)) + @req n_agents(::P) + @req coordination_graph(::P) + + # TODO: Should we also have this requirement for SV? + @req isequal(::S, ::S) + @req hash(::S) +end + + + +function insert_node!(tree::FVMCTSTree{S,A,CS}, planner::FVMCTSPlanner, + s::S) where {S,A,CS <: CoordinationStatistics} + + lock(tree.lock) do + push!(tree.s_labels, s) + tree.state_map[s] = length(tree.s_labels) + push!(tree.total_n, 1) + + # NOTE: Could actually make actions state-dependent if need be + init_statistics!(tree, planner, s) + end + + # length(tree.s_labels) is just an alias for the number of state nodes + ls = lock(tree.lock) do + length(tree.s_labels) + end + return FVStateNode(tree, ls) +end + +@POMDP_require insert_node!(tree::FVMCTSTree, planner::FVMCTSPlanner, s) begin + + P = typeof(planner.mdp) + AV = actiontype(P) + A = eltype(AV) + SV = typeof(s) + #S = eltype(SV) + + # TODO: Review IQ and IN + IQ = typeof(planner.solver.init_Q) + if !(IQ <: Number) && !(IQ <: Function) + @req init_Q(::IQ, ::P, ::SV, ::Vector{Int64}, ::AbstractVector{A}) + end + + IN = typeof(planner.solver.init_N) + if !(IN <: Number) && !(IN <: Function) + @req init_N(::IQ, ::P, ::SV, ::Vector{Int64}, ::AbstractVector{A}) + end + + @req isequal(::S, ::S) + @req hash(::S) +end diff --git a/test/Project.toml b/test/Project.toml new file mode 100644 index 0000000..72f58eb --- /dev/null +++ b/test/Project.toml @@ -0,0 +1,3 @@ +[deps] +MultiAgentSysAdmin = "a4538d8c-5052-4f30-aec9-286910cf67a1" +MultiUAVDelivery = "13c59af0-a5df-4589-8a68-75dc1bf2d35a" diff --git a/test/runtests.jl b/test/runtests.jl index 362e9bc..80d9e48 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,68 @@ -using MAMCTS +using FactoredValueMCTS using Test -@testset "MAMCTS.jl" begin - # Write your tests here. 
+using POMDPs +using MultiAgentSysAdmin +using MultiUAVDelivery + +@testset "FactoredValueMCTS.jl" begin + + @testset "varel" begin + @testset "sysadmin" begin + + @testset "local" begin + mdp = BiSysAdmin() + solver = FVMCTSSolver() + planner = solve(solver, mdp) + s = rand(initialstate(mdp)) + a = action(planner, s) + @test a isa actiontype(mdp) + end + + + @testset "global" begin + mdp = BiSysAdmin(;global_rewards=true) + solver = FVMCTSSolver() + planner = solve(solver, mdp) + s = rand(initialstate(mdp)) + a = action(planner, s) + @test a isa actiontype(mdp) + end + end + end + + @testset "maxplus" begin + @testset "sysadmin" begin + + @testset "local" begin + mdp = BiSysAdmin() + solver = FVMCTSSolver(;coordination_strategy=MaxPlus()) + planner = solve(solver, mdp) + s = rand(initialstate(mdp)) + a = action(planner, s) + @test a isa actiontype(mdp) + end + + + @testset "global" begin + mdp = BiSysAdmin(;global_rewards=true) + solver = FVMCTSSolver(;coordination_strategy=MaxPlus()) + planner = solve(solver, mdp) + s = rand(initialstate(mdp)) + a = action(planner, s) + @test a isa actiontype(mdp) + end + end + + @testset "uav" begin + mdp = FirstOrderMultiUAVDelivery() + solver = FVMCTSSolver(;coordination_strategy=MaxPlus()) + planner = solve(solver, mdp) + s = rand(initialstate(mdp)) + a = action(planner, s) + @test a isa actiontype(mdp) + end + + end + end
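For reference, a minimal end-to-end usage sketch mirroring the new tests above (it assumes the MultiAgentSysAdmin test dependency; the parameter values are illustrative, not recommendations):

    using POMDPs
    using FactoredValueMCTS
    using MultiAgentSysAdmin

    mdp = BiSysAdmin()                      # 2-agent SysAdmin benchmark from the test deps
    solver = FVMCTSSolver(n_iterations = 200,
                          depth = 10,
                          exploration_constant = 1.0,
                          coordination_strategy = MaxPlus(message_iters = 10))
    planner = solve(solver, mdp)            # Max-Plus FVMCTSPlanner
    s = rand(initialstate(mdp))
    a = action(planner, s)                  # joint action, one entry per agent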