From 9a26e63249853ad05bbe033a07504c7f84e10404 Mon Sep 17 00:00:00 2001 From: rejuvyesh Date: Thu, 29 Oct 2020 16:03:25 -0700 Subject: [PATCH 01/17] Import implementation for public release Co-authored-by: Shushman --- Project.toml | 14 + src/MAMCTS.jl | 59 +++- src/fcmcts/fcmcts.jl | 307 ++++++++++++++++++ src/fvmcts/action_coordination/maxplus.jl | 262 +++++++++++++++ src/fvmcts/action_coordination/varel.jl | 280 +++++++++++++++++ src/fvmcts/fv_mcts_vanilla.jl | 367 ++++++++++++++++++++++ 6 files changed, 1288 insertions(+), 1 deletion(-) create mode 100644 src/fcmcts/fcmcts.jl create mode 100644 src/fvmcts/action_coordination/maxplus.jl create mode 100644 src/fvmcts/action_coordination/varel.jl create mode 100644 src/fvmcts/fv_mcts_vanilla.jl diff --git a/Project.toml b/Project.toml index 88f09b1..06d547d 100644 --- a/Project.toml +++ b/Project.toml @@ -3,6 +3,20 @@ uuid = "c016a6d7-1193-47d7-896a-d9f14d6b4b26" authors = ["Stanford Intelligent Systems Laboratory"] version = "0.1.0" +[deps] +BeliefUpdaters = "8bb6e9a1-7d73-552c-a44a-e5dc5634aac4" +LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +MAPOMDPs = "f50418f3-c642-4efe-9903-417dc09ce874" +MCTS = "e12ccd36-dcad-5f33-8774-9175229e7b33" +POMDPLinter = "f3bd98c0-eb40-45e2-9eb1-f2763262d755" +POMDPModelTools = "08074719-1b2a-587c-a292-00f91cc44415" +POMDPPolicies = "182e52fb-cfd0-5e46-8c26-fd0667c990f4" +POMDPs = "a93abf59-7444-517b-a68a-c42f96afdd7d" +Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + [compat] julia = "1.5" diff --git a/src/MAMCTS.jl b/src/MAMCTS.jl index 5399134..449ba8a 100644 --- a/src/MAMCTS.jl +++ b/src/MAMCTS.jl @@ -1,5 +1,62 @@ module MAMCTS -# Write your package code here. +using Random +using LinearAlgebra + +using Parameters +using POMDPs +using MAPOMDPs +using POMDPPolicies +using POMDPLinter +using MCTS +using LightGraphs +using BeliefUpdaters + +using MCTS: convert_estimator +import POMDPModelTools + +### +# Factored Value MCTS +# + +abstract type CoordinationStatistics end + +""" +Random Policy factored for each agent. Avoids exploding action space. 
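+Each agent's action is drawn independently from its own legal action set
+(via `get_agent_actions(problem, i, s_i)`), so the joint action is assembled per agent
+rather than by enumerating the exponential joint action space.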
+""" +struct FactoredRandomPolicy{RNG<:AbstractRNG,P<:JointMDP, U<:Updater} <: Policy + rng::RNG + problem::P + updater::U +end + +FactoredRandomPolicy(problem::JointMDP; rng=Random.GLOBAL_RNG, updater=NothingUpdater()) = FactoredRandomPolicy(rng, problem, updater) + +function POMDPs.action(policy::FactoredRandomPolicy, s) + return [rand(policy.rng, get_agent_actions(policy.problem, i, si)) for (i, si) in enumerate(s)] +end + +POMDPs.solve(solver::RandomSolver, problem::JointMDP) = FactoredRandomPolicy(solver.rng, problem, NothingUpdater()) + +include(joinpath("fvmcts", "fv_mcts_vanilla.jl")) +include(joinpath("fvmcts", "action_coordination", "varel.jl")) +include(joinpath("fvmcts", "action_coordination", "maxplus.jl")) + +export + FVMCTSSolver, + MaxPlus, + VarEl + +### + +### +# Naive Fully Connected Centralized MCTS +# + +include(joinpath("fcmcts", "fcmcts.jl")) +export + FCMCTSSolver + +### end diff --git a/src/fcmcts/fcmcts.jl b/src/fcmcts/fcmcts.jl new file mode 100644 index 0000000..749a161 --- /dev/null +++ b/src/fcmcts/fcmcts.jl @@ -0,0 +1,307 @@ + + +@with_kw mutable struct FCMCTSSolver <: AbstractMCTSSolver + n_iterations::Int64 = 100 + max_time::Float64 = Inf + depth::Int64 = 10 + exploration_constant::Float64 = 1.0 + rng::AbstractRNG = Random.GLOBAL_RNG + estimate_value::Any = RolloutEstimator(RandomSolver(rng)) + init_Q::Any = 0.0 + init_N::Any = 0 + reuse_tree::Bool = false +end + +mutable struct FCMCTSTree{S,A} + # To track if state node in tree already + # NOTE: We don't strictly need this at all if no tree reuse... + state_map::Dict{AbstractVector{S},Int64} + + # these vectors have one entry for each state node + # Only doing factored satistics (for actions), not state components + child_ids::Vector{Vector{Int}} + total_n::Vector{Int} + s_labels::Vector{AbstractVector{S}} + + # TODO(jkg): is this the best way to track stats? 
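+    # (Action nodes are stored flat across the whole tree; child_ids[state_id] holds the
+    # indices into n/q/a_labels that belong to that state's joint actions.)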
+ # these vectors have one entry for each action node + n::Vector{Int64} + q::Vector{Float64} + a_labels::Vector{AbstractVector{A}} + + lock::ReentrantLock +end + +function FCMCTSTree{S,A}(init_state::AbstractVector{S}, lock::ReentrantLock, sz::Int=1000) where {S,A} + sz = min(sz, 100_000) + return FCMCTSTree{S,A}(Dict{typeof(init_state),Int64}(), + sizehint!(Vector{Int}[], sz), + sizehint!(Int[], sz), + sizehint!(typeof(init_state)[], sz), + Int64[], + Float64[], + sizehint!(Vector{A}[], sz), + lock) +end + +Base.isempty(t::FCMCTSTree) = isempty(t.state_map) +state_nodes(t::FCMCTSTree) = (FCStateNode(t, id) for id in 1:length(t.total_n)) + +struct FCStateNode{S,A} + tree::FCMCTSTree{S,A} + id::Int64 +end + +# accessors for state nodes +@inline state(n::FCStateNode) = lock(n.tree.lock) do + n.tree.s_labels[n.id] +end +@inline total_n(n::FCStateNode) = n.tree.total_n[n.id] +@inline children(n::FCStateNode) = (FCActionNode(n.tree, id) for id in n.tree.child_ids[n.id]) + +# Adding action node info +struct FCActionNode{S,A} + tree::FCMCTSTree{S,A} + id::Int64 +end + +# accessors for action nodes +@inline POMDPs.action(n::FCActionNode) = n.tree.a_labels[n.id] + + +mutable struct FCMCTSPlanner{S,A,SE,RNG<:AbstractRNG} <: AbstractMCTSPlanner{JointMDP{S,A}} + solver::FCMCTSSolver + mdp::JointMDP{S,A} + tree::FCMCTSTree{S,A} + solved_estimate::SE + rng::RNG +end + +function FCMCTSPlanner(solver::FCMCTSSolver, mdp::JointMDP{S,A}) where {S,A} + init_state = initialstate(mdp, solver.rng) + tree = FCMCTSTree{S,A}(init_state, ReentrantLock(), solver.n_iterations) + se = convert_estimator(solver.estimate_value, solver, mdp) + return FCMCTSPlanner(solver, mdp, tree, se, solver.rng) +end + + +function clear_tree!(planner::FCMCTSPlanner) + lock(planner.tree.lock) do + # Clear out state hash dict entirely + empty!(planner.tree.state_map) + # Empty state vectors with state hints + sz = min(planner.solver.n_iterations, 100_000) + + empty!(planner.tree.s_labels) + sizehint!(planner.tree.s_labels, sz) + + empty!(planner.tree.child_ids) + sizehint!(planner.tree.child_ids, sz) + empty!(planner.tree.total_n) + sizehint!(planner.tree.total_n, sz) + + empty!(planner.tree.n) + empty!(planner.tree.q) + empty!(planner.tree.a_labels) + end +end + +function POMDPs.solve(solver::FCMCTSSolver, mdp::JointMDP) + return FCMCTSPlanner(solver, mdp) +end + +function POMDPs.action(planner::FCMCTSPlanner, s) + clear_tree!(planner) + plan!(planner, s) + s_lut = lock(planner.tree.lock) do + planner.tree.state_map[s] + end + best_anode = lock(planner.tree.lock) do + compute_best_action_node(planner.mdp, planner.tree, FCStateNode(planner.tree, s_lut)) # c = 0.0 by default + end + + best_a = lock(planner.tree.lock) do + action(best_anode) + end + return best_a +end + +function POMDPModelTools.action_info(planner::FCMCTSPlanner, s) + a = POMDPs.action(planner, s) + return a, nothing +end + + +function plan!(planner::FCMCTSPlanner, s) + planner.tree = build_tree(planner, s) +end + +function build_tree(planner::FCMCTSPlanner, s::AbstractVector{S}) where S + n_iterations = planner.solver.n_iterations + depth = planner.solver.depth + + root = insert_node!(planner.tree, planner, s) + # build the tree + @sync for n = 1:n_iterations + @spawn simulate(planner, root, depth) + end + return planner.tree +end + +function simulate(planner::FCMCTSPlanner, node::FCStateNode, depth::Int64) + + mdp = planner.mdp + rng = planner.rng + s = state(node) + tree = node.tree + + # once depth is zero return + if isterminal(planner.mdp, s) + return 0.0 + elseif 
depth == 0 + return sum(estimate_value(planner.solved_estimate, planner.mdp, s, depth)) + end + + # Choose best UCB action (NOT an action node) + ucb_action_node = lock(planner.tree.lock) do + compute_best_action_node(mdp, planner.tree, node, planner.solver.exploration_constant) + end + ucb_action = lock(planner.tree.lock) do + action(ucb_action_node) + end + + # @show ucb_action + # MC Transition + sp, r = gen(DDNOut(:sp, :r), mdp, s, ucb_action, rng) + + # NOTE(jkg): just summing up the rewards to get a single value + # TODO: should we divide by n_agents? + r = sum(r) + + spid = lock(tree.lock) do + get(tree.state_map, sp, 0) # may be non-zero even with no tree reuse + end + if spid == 0 + spn = insert_node!(tree, planner, sp) + spid = spn.id + # TODO estimate_value + # NOTE(jkg): again just summing up the values to get a single value + q = r + discount(mdp) * sum(estimate_value(planner.solved_estimate, planner.mdp, sp, depth - 1)) + else + q = r + discount(mdp) * simulate(planner, FCStateNode(tree, spid) , depth - 1) + end + + ## Not bothering with tree vis right now + # Augment N(s) + lock(tree.lock) do + tree.total_n[node.id] += 1 + tree.n[ucb_action_node.id] += 1 + tree.q[ucb_action_node.id] += (q - tree.q[ucb_action_node.id]) / tree.n[ucb_action_node.id] + end + + return q +end + +# NOTE: This is a bit different from https://github.com/JuliaPOMDP/MCTS.jl/blob/master/src/vanilla.jl#L328 +function insert_node!(tree::FCMCTSTree, planner::FCMCTSPlanner, s) + + lock(tree.lock) do + push!(tree.s_labels, s) + tree.state_map[s] = length(tree.s_labels) + push!(tree.child_ids, []) + end + + # NOTE: Doing state-dep actions here the JointMDP way + state_dep_jtactions = vec(map(collect, Iterators.product((get_agent_actions(planner.mdp, i, si) for (i, si) in enumerate(s))...))) + total_n = 0 + + for a in state_dep_jtactions + n = init_N(planner.solver.init_N, planner.mdp, s, a) + total_n += n + lock(tree.lock) do + push!(tree.n, n) + push!(tree.q, init_Q(planner.solver.init_Q, planner.mdp, s, a)) + push!(tree.a_labels, a) + push!(last(tree.child_ids), length(tree.n)) + end + end + lock(tree.lock) do + push!(tree.total_n, total_n) + end + ln = lock(tree.lock) do + length(tree.total_n) + end + return FCStateNode(tree, ln) +end + + + +# NOTE: The logic here is a bit simpler than https://github.com/JuliaPOMDP/MCTS.jl/blob/master/src/vanilla.jl#L390 +# Double check that this is still the behavior we want +function compute_best_action_node(mdp::JointMDP, tree::FCMCTSTree, node::FCStateNode, c::Float64=0.0) + best_val = -Inf # The Q value + best = first(children(node)) + + sn = total_n(node) + + child_nodes = children(node) + + for sanode in child_nodes + + val = tree.q[sanode.id] + c*sqrt(log(sn + 1)/ (tree.n[sanode.id] + 1)) + + + if val > best_val + best_val = val + best = sanode + end + end + return best +end + +POMDPLinter.@POMDP_require simulate(planner::FCMCTSPlanner, s, depth::Int64) begin + mdp = planner.mdp + P = typeof(mdp) + @assert P <: JointMDP # req does different thing? + SV = statetype(P) + @assert typeof(SV) <: AbstractVector # TODO: Is this correct? 
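+    # NOTE: SV is already a type here, so the intended check is presumably `SV <: AbstractVector`;
+    # the assert below likewise appears to mean `AV <: AbstractVector` (no `A` is bound in this block).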
+ AV = actiontype(P) + @assert typeof(A) <: AbstractVector + @req discount(::P) + @req isterminal(::P, ::SV) + @subreq insert_node!(planner.tree, planner, s) + @subreq estimate_value(planner.solved_estimate, mdp, s, depth) + @req gen(::DDNOut{(:sp, :r)}, ::P, ::SV, ::A, ::typeof(planner.rng)) + + # MMDP reqs + @req get_agent_actions(::P, ::Int64) + @req get_agent_actions(::P, ::Int64, ::eltype(SV)) + @req n_agents(::P) + + # TODO: Should we also have this requirement for SV? + @req isequal(::S, ::S) + @req hash(::S) +end + +POMDPLinter.@POMDP_require insert_node!(tree::FCMCTSTree, planner::FCMCTSPlanner, s) begin + + P = typeof(planner.mdp) + AV = actiontype(P) + A = eltype(AV) + SV = typeof(s) + S = eltype(SV) + + # TODO: Review IQ and IN + IQ = typeof(planner.solver.init_Q) + if !(IQ <: Number) && !(IQ <: Function) + @req init_Q(::IQ, ::P, ::S, ::Vector{Int64}, ::AbstractVector{A}) + end + + IN = typeof(planner.solver.init_N) + if !(IN <: Number) && !(IN <: Function) + @req init_N(::IQ, ::P, ::S, ::Vector{Int64}, ::AbstractVector{A}) + end + + @req isequal(::S, ::S) + @req hash(::S) +end diff --git a/src/fvmcts/action_coordination/maxplus.jl b/src/fvmcts/action_coordination/maxplus.jl new file mode 100644 index 0000000..d7945b1 --- /dev/null +++ b/src/fvmcts/action_coordination/maxplus.jl @@ -0,0 +1,262 @@ +## Now do for message passing +# NOTE: Matrix implicitly assumes all agents have same number of actions +mutable struct PerStateMPStats + agent_action_n::Matrix{Int64} # N X A + agent_action_q::Matrix{Float64} + edge_action_n::Matrix{Int64} # |E| X A^2 + edge_action_q::Matrix{Float64} +end + +# NOTE: Putting params here is a little ugly but coordinate_action can't have them since VarEl doesn't use those args +mutable struct MaxPlusStatistics{S} <: CoordinationStatistics + adjmatgraph::SimpleGraph + message_iters::Int64 + message_norm::Bool + use_agent_utils::Bool + node_exploration::Bool + edge_exploration::Bool # NOTE: One of this or node exploration must be true + all_states_stats::Dict{AbstractVector{S},PerStateMPStats} +end + +function clear_statistics!(mp_stats::MaxPlusStatistics) + empty!(mp_stats.all_states_stats) +end + +function update_statistics!(mdp::JointMDP{S,A}, tree::JointMCTSTree{S,A,MaxPlusStatistics{S}}, + s::AbstractVector{S}, ucb_action::AbstractVector{A}, q::AbstractVector{Float64}) where {S,A} + + state_stats = tree.coordination_stats.all_states_stats[s] + n_agents = length(s) + + # Update per agent action stats + for i = 1:n_agents + ac_idx = get_agent_actionindex(mdp, i, ucb_action[i]) + lock(tree.lock) do + state_stats.agent_action_n[i, ac_idx] += 1 + state_stats.agent_action_q[i, ac_idx] += + (q[i] - state_stats.agent_action_q[i, ac_idx]) / state_stats.agent_action_n[i, ac_idx] + end + end + + # Now update edge action stats + for (idx, e) in enumerate(edges(tree.coordination_stats.adjmatgraph)) + # NOTE: Need to be careful about action ordering + # Being more general to have unequal agent actions + edge_comp = (e.src,e.dst) + edge_tup = Tuple(1:length(tree.all_agent_actions[c]) for c in edge_comp) + edge_ac_idx = LinearIndices(edge_tup)[get_agent_actionindex(mdp, e.src, ucb_action[e.src]), + get_agent_actionindex(mdp, e.dst, ucb_action[e.dst])] + q_edge_value = q[e.src] + q[e.dst] + + lock(tree.lock) do + state_stats.edge_action_n[idx, edge_ac_idx] += 1 + state_stats.edge_action_q[idx, edge_ac_idx] += + (q_edge_value - state_stats.edge_action_q[idx, edge_ac_idx]) / state_stats.edge_action_n[idx, edge_ac_idx] + end + end + + lock(tree.lock) do + 
tree.coordination_stats.all_states_stats[s] = state_stats + end + +end + +function init_statistics!(tree::JointMCTSTree{S,A,MaxPlusStatistics{S}}, planner::JointMCTSPlanner, + s::AbstractVector{S}) where {S,A} + + n_agents = length(s) + + # NOTE: Assuming all agents have the same actions here + n_all_actions = length(tree.all_agent_actions[1]) + + agent_action_n = zeros(Int64, n_agents, n_all_actions) + agent_action_q = zeros(Float64, n_agents, n_all_actions) + + # Loop over agents and then actions + # TODO: Need to define init_N and init_Q for single agent + for i = 1:n_agents + for (j, ac) in enumerate(tree.all_agent_actions[i]) + agent_action_n[i, j] = init_N(planner.solver.init_N, planner.mdp, s, i, ac) + agent_action_q[i, j] = init_Q(planner.solver.init_Q, planner.mdp, s, i, ac) + end + end + + n_edges = ne(tree.coordination_stats.adjmatgraph) + edge_action_n = zeros(Int64, n_edges, n_all_actions^2) + edge_action_q = zeros(Float64, n_edges, n_all_actions^2) + + # Loop over edges and then action_i \times action_j + for (idx, e) in enumerate(edges(tree.coordination_stats.adjmatgraph)) + edge_comp = (e.src, e.dst) + n_edge_actions = prod([length(tree.all_agent_actions[c]) for c in edge_comp]) + edge_tup = Tuple(1:length(tree.all_agent_actions[c]) for c in edge_comp) + + for edge_ac_idx = 1:n_edge_actions + ct_idx = CartesianIndices(edge_tup)[edge_ac_idx] + edge_action = [tree.all_agent_actions[c] for c in edge_comp] + + edge_action_n[idx, edge_ac_idx] = init_N(planner.solver.init_N, planner.mdp, s, edge_comp, edge_action) + edge_action_q[idx, edge_ac_idx] = init_Q(planner.solver.init_Q, planner.mdp, s, edge_comp, edge_action) + end + end + + state_stats = PerStateMPStats(agent_action_n, agent_action_q, edge_action_n, edge_action_q) + lock(tree.lock) do + tree.coordination_stats.all_states_stats[s] = state_stats + end +end + +# NOTE: Following deepCG +function coordinate_action(mdp::JointMDP{S,A}, tree::JointMCTSTree{S,A,MaxPlusStatistics{S}}, s::AbstractVector{S}, + exploration_constant::Float64=0.0, node_id::Int64=0) where {S,A} + + state_stats = lock(tree.lock) do + tree.coordination_stats.all_states_stats[s] + end + adjgraphmat = lock(tree.lock) do + tree.coordination_stats.adjmatgraph + end + k = tree.coordination_stats.message_iters + message_norm = tree.coordination_stats.message_norm + + n_agents = length(s) + state_agent_actions = [get_agent_actions(mdp, i, si) for (i, si) in enumerate(s)] + n_all_actions = length(tree.all_agent_actions[1]) + n_edges = ne(tree.coordination_stats.adjmatgraph) + + # Init forward and backward messages and q0 + fwd_messages = zeros(Float64, n_edges, n_all_actions) + bwd_messages = zeros(Float64, n_edges, n_all_actions) + + if tree.coordination_stats.use_agent_utils + q_values = state_stats.agent_action_q / n_agents + else + q_values = zeros(size(state_stats.agent_action_q)) + end + + state_total_n = lock(tree.lock) do + (node_id > 0) ? 
tree.total_n[node_id] : 1 + end + + + # Iterate over passes + for t = 1:k + fnormdiff, bnormdiff = perform_message_passing!(fwd_messages, bwd_messages, mdp, tree.all_agent_actions, + adjgraphmat, state_agent_actions, n_edges, q_values, message_norm, + 0, state_stats, state_total_n) + + if !tree.coordination_stats.use_agent_utils + q_values = zeros(size(state_stats.agent_action_q)) + end + + # Update Q value with messages + for i = 1:n_agents + + # need indices of all edges that agent is involved in + nbrs = neighbors(tree.coordination_stats.adjmatgraph, i) + + edgelist = collect(edges(adjgraphmat)) + + if tree.coordination_stats.use_agent_utils + @views q_values[i, :] = state_stats.agent_action_q[i, :]/n_agents + end + for n in nbrs + if Edge(i,n) in edgelist # use backward message + q_values[i,:] += bwd_messages[findfirst(isequal(Edge(i,n)), edgelist), :] + elseif Edge(n,i) in edgelist + q_values[i,:] += fwd_messages[findfirst(isequal(Edge(n,i)), edgelist), :] + else + @warn "Neither edge found!" + end + end + end + + # If converged, break + if isapprox(fnormdiff, 0.0) && isapprox(bnormdiff, 0.0) + break + end + end # for t = 1:k + + # If edge exploration flag enabled, do a final exploration bonus + # TODO: Code reuse with normal message passing; consider modularizing + if tree.coordination_stats.edge_exploration + perform_message_passing!(fwd_messages, bwd_messages, mdp, tree.all_agent_actions, + adjgraphmat, state_agent_actions, n_edges, q_values, message_norm, + exploration_constant, state_stats, state_total_n) + end # if edge_exploration + + + # Maximize q values for agents + best_action = Vector{A}(undef, n_agents) + for i = 1:n_agents + + # NOTE: Again can't just iterate over agent actions as it may be a subset + exp_q_values = zeros(length(state_agent_actions[i])) + if tree.coordination_stats.node_exploration + for (idx, ai) in enumerate(state_agent_actions[i]) + ai_idx = get_agent_actionindex(mdp, i, ai) + exp_q_values[idx] = q_values[i, ai_idx] + exploration_constant*sqrt((log(state_total_n + 1.0))/(state_stats.agent_action_n[i, ai_idx] + 1.0)) + end + else + for (idx, ai) in enumerate(state_agent_actions[i]) + ai_idx = get_agent_actionindex(mdp, i, ai) + exp_q_values[idx] = q_values[i, ai_idx] + end + end + + # NOTE: Can now look up index in exp_q_values and then again look at state_agent_actions + _, idx = findmax(exp_q_values) + best_action[i] = state_agent_actions[i][idx] + end + + return best_action +end + +function perform_message_passing!(fwd_messages::AbstractArray{F,2}, bwd_messages::AbstractArray{F,2}, + mdp, all_agent_actions, + adjgraphmat, state_agent_actions, n_edges::Int, q_values, message_norm, + exploration_constant, state_stats, state_total_n) where {F} + # Iterate over edges + fwd_messages_old = deepcopy(fwd_messages) + bwd_messages_old = deepcopy(bwd_messages) + for (e_idx, e) in enumerate(edges(adjgraphmat)) + + i = e.src + j = e.dst + edge_tup_indices = LinearIndices(Tuple(1:length(all_agent_actions[c]) for c in (i,j))) + + # forward: maximize sender + # NOTE: Can't do enumerate as action set might be smaller + # Need to look up global index of agent action and use that + # Need to break up vectorized loop + @inbounds for aj in state_agent_actions[j] + aj_idx = get_agent_actionindex(mdp, j, aj) + fwd_message_vals = zeros(length(state_agent_actions[i])) + # TODO: Should we use inbounds here again? 
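+            # Forward message update computed by the inner loop:
+            #   fwd[e, a_j] = max_{a_i} [ q_i(a_i) - bwd_old[e, a_i] + Q_e(a_i, a_j)/|E|
+            #                             + c * sqrt(log(N_s + 1) / (N_e(a_i, a_j) + 1)) ]
+            # The backward message below is the symmetric update, maximizing over a_j instead.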
+ @inbounds for (idx, ai) in enumerate(state_agent_actions[i]) + ai_idx = get_agent_actionindex(mdp, i, ai) + fwd_message_vals[idx] = q_values[i, ai_idx] - bwd_messages_old[e_idx, ai_idx] + state_stats.edge_action_q[e_idx, edge_tup_indices[ai_idx, aj_idx]]/n_edges + exploration_constant * sqrt( (log(state_total_n + 1.0)) / (state_stats.edge_action_n[e_idx, edge_tup_indices[ai_idx, aj_idx]] + 1) ) + end + fwd_messages[e_idx, aj_idx] = maximum(fwd_message_vals) + end + + @inbounds for ai in state_agent_actions[i] + ai_idx = get_agent_actionindex(mdp, i, ai) + bwd_message_vals = zeros(length(state_agent_actions[j])) + @inbounds for (idx, aj) in enumerate(state_agent_actions[j]) + aj_idx = get_agent_actionindex(mdp, j, aj) + bwd_message_vals[idx] = q_values[j, aj_idx] - fwd_messages_old[e_idx, aj_idx] + state_stats.edge_action_q[e_idx, edge_tup_indices[ai_idx, aj_idx]]/n_edges + exploration_constant * sqrt( (log(state_total_n + 1.0))/ (state_stats.edge_action_n[e_idx, edge_tup_indices[ai_idx, aj_idx]] + 1) ) + end + bwd_messages[e_idx, ai_idx] = maximum(bwd_message_vals) + end + + if message_norm + @views fwd_messages[e_idx, :] .-= sum(fwd_messages[e_idx, :])/length(fwd_messages[e_idx, :]) + @views bwd_messages[e_idx, :] .-= sum(bwd_messages[e_idx, :])/length(bwd_messages[e_idx, :]) + end + + end # (idx,edges) in enumerate(edges) + + # Return norm of message difference + return norm(fwd_messages - fwd_messages_old), norm(bwd_messages - bwd_messages_old) +end diff --git a/src/fvmcts/action_coordination/varel.jl b/src/fvmcts/action_coordination/varel.jl new file mode 100644 index 0000000..3e5d887 --- /dev/null +++ b/src/fvmcts/action_coordination/varel.jl @@ -0,0 +1,280 @@ +mutable struct VarElStatistics{S} <: CoordinationStatistics + coord_graph_components::Vector{Vector{Int64}} + min_degree_ordering::Vector{Int64} + n_component_stats::Dict{AbstractVector{S},Vector{Vector{Int64}}} + q_component_stats::Dict{AbstractVector{S},Vector{Vector{Float64}}} +end + +function clear_statistics!(ve_stats::VarElStatistics) + empty!(ve_stats.n_component_stats) + empty!(ve_stats.q_component_stats) +end + + +function coordinate_action(mdp::JointMDP{S,A}, tree::JointMCTSTree{S,A,VarElStatistics{S}}, s::AbstractVector{S}, + exploration_constant::Float64=0.0, node_id::Int64=0) where {S,A} + + n_agents = length(s) + best_action_idxs = MVector{n_agents}([-1 for i in 1:n_agents]) + + # + # !Note: Acquire lock so as to avoid race + # + state_q_stats = lock(tree.lock) do + tree.coordination_stats.q_component_stats[s] + end + state_n_stats = lock(tree.lock) do + tree.coordination_stats.n_component_stats[s] + end + state_total_n = lock(tree.lock) do + (node_id > 0) ? tree.total_n[node_id] : 0 + end + + # Maintain set of potential functions + # NOTE: Hashing a vector here + potential_fns = Dict{Vector{Int64},Vector{Float64}}() + for (comp, q_stats) in zip(tree.coordination_stats.coord_graph_components, state_q_stats) + potential_fns[comp] = q_stats + end + + # Need this for reverse process + # Maps agent to other elements in best response functions and corresponding set of actions + # E.g. 
Agent 2 -> (3,4) in its best response and corresponding vector of agent 2 best actions + best_response_fns = Dict{Int64,Tuple{Vector{Int64},Vector{Int64}}}() + + state_dep_actions = [get_agent_actions(mdp, i, si) for (i, si) in enumerate(s)] + + # Iterate over variable ordering + # Need to maintain intermediate tables + for ag_idx in tree.coordination_stats.min_degree_ordering + + # Lookup factors with agent in them and simultaneously construct + # members of new potential function, and delete old factors + agent_factors = Vector{Vector{Int64}}(undef, 0) + new_potential_members = Vector{Int64}(undef, 0) + for k in collect(keys(potential_fns)) + if ag_idx in k + + # Agent to-be-eliminated is in factor + push!(agent_factors, k) + + # Construct key for new potential as union of all others + # except ag_idx + for ag in k + if ag != ag_idx && ~(ag in new_potential_members) + push!(new_potential_members, ag) + end + end + end + end + + if isempty(new_potential_members) == true + # No out neighbors..either at beginning or end of ordering + @assert agent_factors == [[ag_idx]] "agent_factors $(agent_factors) is not just [ag_idx] $([ag_idx])!" + best_action_idxs[ag_idx] = _best_actionindex_empty(potential_fns, + state_dep_actions, + tree.all_agent_actions, + ag_idx) + + else + + # Generate new potential function + # AND the best response vector for eliminated agent + n_comp_actions = prod([length(tree.all_agent_actions[c]) for c in new_potential_members]) + + # NOTE: Tuples should ALWAYS use tree.all_agent_actions for indexing + comp_tup = Tuple(1:length(tree.all_agent_actions[c]) for c in new_potential_members) + + # Initialize q-stats for new potential and best response action vector + # will be inserted into corresponding dictionaries at the end + new_potential_stats = Vector{Float64}(undef, n_comp_actions) + best_response_vect = Vector{Int64}(undef, n_comp_actions) + + # Iterate over new potential joint actions and compute new payoff and best response + for comp_ac_idx = 1:n_comp_actions + + # Get joint action for other members in potential + ct_idx = CartesianIndices(comp_tup)[comp_ac_idx] + + # For maximizing over agent actions + # As before, we now need to init with -Inf + ag_ac_values = zeros(length(tree.all_agent_actions[ag_idx])) + + # TODO: Agent actions should already be in order + # Only do anything if action legal + for (ag_ac_idx, ag_ac) in enumerate(tree.all_agent_actions[ag_idx]) + + if ag_ac in state_dep_actions[ag_idx] + + # Need to look up corresponding stats from agent_factors + for factor in agent_factors + + # NOTE: Need to reconcile the ORDER of ag_idx in factor + factor_action_idxs = MVector{length(factor),Int64}(undef) + + for (idx, f) in enumerate(factor) + + # if f is ag_idx, set corresponding factor action to ag_ac + if f == ag_idx + factor_action_idxs[idx] = ag_ac_idx + else + # Lookup index for corresp. agent action in ct_idx + new_pot_idx = findfirst(isequal(f), new_potential_members) + factor_action_idxs[idx] = ct_idx[new_pot_idx] + end # f == ag_idx + end + + # NOW we can look up the stats of the factor + factor_tup = Tuple(1:length(tree.all_agent_actions[c]) for c in factor) + factor_action_linidx = LinearIndices(factor_tup)[factor_action_idxs...] 
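+                            # factor_action_linidx is the flat index of this joint factor action
+                            # into the factor's payoff vector, used below to accumulate the factor's
+                            # value (and, when applicable, its exploration bonus) into ag_ac_values.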
+ + ag_ac_values[ag_ac_idx] += potential_fns[factor][factor_action_linidx] + + # Additionally add exploration stats if factor in original set + factor_comp_idx = findfirst(isequal(factor), tree.coordination_stats.coord_graph_components) + if state_total_n > 0 && ~(isnothing(factor_comp_idx)) # NOTE: Julia1.1 + ag_ac_values[ag_ac_idx] += exploration_constant * sqrt((log(state_total_n+1.0))/(state_n_stats[factor_comp_idx][factor_action_linidx]+1.0)) + end + end # factor in agent_factors + else + ag_ac_values[ag_ac_idx] = -Inf + end # ag_ac in state_dep_actions + end # ag_ac_idx = 1:length(tree.all_agent_actions[ag_idx]) + + + # Now we lookup ag_ac_values for the best value to be put in new_potential_stats + # and the best index to be put in best_response_vect + # NOTE: The -Inf mask should ensure only legal idxs chosen + # If all ag_ac_values equal, should we sample randomly? + best_val, best_idx = findmax(ag_ac_values) + + new_potential_stats[comp_ac_idx] = best_val + best_response_vect[comp_ac_idx] = best_idx + end # comp_ac_idx in n_comp_actions + + # Finally, we enter new stats vector and best response vector back to dicts + potential_fns[new_potential_members] = new_potential_stats + best_response_fns[ag_idx] = (new_potential_members, best_response_vect) + end # isempty(new_potential_members) + + # Delete keys in agent_factors from potential fns since variable has been eliminated + for factor in agent_factors + delete!(potential_fns, factor) + end + end # ag_idx in min_deg_ordering + + # NOTE: At this point, best_action_idxs has at least one entry...for the last action obtained + @assert !all(isequal(-1), best_action_idxs) "best_action_idxs is still undefined!" + + # Message passing in reverse order to recover best action + for ag_idx in Base.Iterators.reverse(tree.coordination_stats.min_degree_ordering) + + # Only do something if best action already not obtained + if best_action_idxs[ag_idx] == -1 + + # Should just be able to lookup best response function + (agents, best_response_vect) = best_response_fns[ag_idx] + + # Members of agents should already have their best action defined + agent_ac_tup = Tuple(1:length(tree.all_agent_actions[c]) for c in agents) + best_agents_action_idxs = [best_action_idxs[ag] for ag in agents] + best_response_idx = LinearIndices(agent_ac_tup)[best_agents_action_idxs...] + + # Assign best action for ag_idx + best_action_idxs[ag_idx] = best_response_vect[best_response_idx] + end # isdefined + end + + # Finally, return best action by iterating over best action indices + # NOTE: best_action should use state-dep actions to reverse index + best_action = [tree.all_agent_actions[ag][idx] for (ag, idx) in enumerate(best_action_idxs)] + + return best_action +end + + +function update_statistics!(mdp::JointMDP{S,A}, tree::JointMCTSTree{S,A,VarElStatistics{S}}, + s::AbstractVector{S}, ucb_action::AbstractVector{A}, q::AbstractVector{Float64}) where {S,A} + + n_agents = length(s) + + for (idx, comp) in enumerate(tree.coordination_stats.coord_graph_components) + + # Create cartesian index tuple + comp_tup = Tuple(1:length(tree.all_agent_actions[c]) for c in comp) + + # RECOVER local action corresp. to ucb action + # TODO: Review this carefully. Need @req for action index for agent. + local_action = [ucb_action[c] for c in comp] + local_action_idxs = [get_agent_actionindex(mdp, c, a) for (a, c) in zip(local_action, comp)] + + comp_ac_idx = LinearIndices(comp_tup)[local_action_idxs...] + + # NOTE: NOW we can update stats. 
Could generalize incremental update more here + lock(tree.lock) do + tree.coordination_stats.n_component_stats[s][idx][comp_ac_idx] += 1 + q_comp_value = sum(q[c] for c in comp) + tree.coordination_stats.q_component_stats[s][idx][comp_ac_idx] += + (q_comp_value - tree.coordination_stats.q_component_stats[s][idx][comp_ac_idx]) / tree.coordination_stats.n_component_stats[s][idx][comp_ac_idx] + end + end +end + + +function init_statistics!(tree::JointMCTSTree{S,A,VarElStatistics{S}}, planner::JointMCTSPlanner, + s::AbstractVector{S}) where {S,A} + + n_comps = length(tree.coordination_stats.coord_graph_components) + n_component_stats = Vector{Vector{Int64}}(undef, n_comps) + q_component_stats = Vector{Vector{Float64}}(undef, n_comps) + + n_agents = length(s) + + # TODO: Could actually make actions state-dependent if need be + for (idx, comp) in enumerate(tree.coordination_stats.coord_graph_components) + + n_comp_actions = prod([length(tree.all_agent_actions[c]) for c in comp]) + + n_component_stats[idx] = Vector{Int64}(undef, n_comp_actions) + q_component_stats[idx] = Vector{Float64}(undef, n_comp_actions) + + comp_tup = Tuple(1:length(tree.all_agent_actions[c]) for c in comp) + + for comp_ac_idx = 1:n_comp_actions + + # Generate action subcomponent and call init_Q and init_N for it + ct_idx = CartesianIndices(comp_tup)[comp_ac_idx] # Tuple corresp to + local_action = [tree.all_agent_actions[c][ai] for (c, ai) in zip(comp, Tuple(ct_idx))] + + # NOTE: init_N and init_Q are functions of component AND local action + # TODO(jkg): init_N and init_Q need to be defined + n_component_stats[idx][comp_ac_idx] = init_N(planner.solver.init_N, planner.mdp, s, comp, local_action) + q_component_stats[idx][comp_ac_idx] = init_Q(planner.solver.init_Q, planner.mdp, s, comp, local_action) + end + end + + # Update tree member + lock(tree.lock) do + tree.coordination_stats.n_component_stats[s] = n_component_stats + tree.coordination_stats.q_component_stats[s] = q_component_stats + end +end + +@inline function _best_actionindex_empty(potential_fns, state_dep_actions, all_agent_actions, ag_idx) + # NOTE: This is inefficient but necessary for state-dep actions? + if length(state_dep_actions[ag_idx]) == length(all_agent_actions[ag_idx]) + _, best_ac_idx = findmax(potential_fns[[ag_idx]]) + else + # Now we need to choose the best index from among legal actions + # Create an array with illegal actions having -Inf and then fill legal vals + # TODO: More efficient way to do this? + masked_action_vals = fill(-Inf, length(all_agent_actions[ag_idx])) + for (iac, ac) in enumerate(all_agent_actions[ag_idx]) + if ac in state_dep_actions[ag_idx] + masked_action_vals[iac] = potential_fns[[ag_idx]][iac] + end + end + _, best_ac_idx = findmax(masked_action_vals) + end + return best_ac_idx +end diff --git a/src/fvmcts/fv_mcts_vanilla.jl b/src/fvmcts/fv_mcts_vanilla.jl new file mode 100644 index 0000000..a36e1c3 --- /dev/null +++ b/src/fvmcts/fv_mcts_vanilla.jl @@ -0,0 +1,367 @@ +## We can use the following directly without modification +## 1. domain_knowledge.jl for Rollout, init_Q and init_N functions +## 2. 
FVMCTSSolver for representing the overall MCTS (underlying things will change) +using StaticArrays +using Parameters +using Base.Threads: @spawn + +abstract type AbstractCoordinationStrategy end + +struct VarEl <: AbstractCoordinationStrategy +end + +@with_kw struct MaxPlus <:AbstractCoordinationStrategy + message_iters::Int64 = 10 + message_norm::Bool = true + use_agent_utils::Bool = false + node_exploration::Bool = true + edge_exploration::Bool = true +end + +@with_kw mutable struct FVMCTSSolver <: AbstractMCTSSolver + n_iterations::Int64 = 100 + max_time::Float64 = Inf + depth::Int64 = 10 + exploration_constant::Float64 = 1.0 + rng::AbstractRNG = Random.GLOBAL_RNG + estimate_value::Any = RolloutEstimator(RandomSolver(rng)) + init_Q::Any = 0.0 + init_N::Any = 0 + reuse_tree::Bool = false + coordination_strategy::Any = VarEl() +end + + +# JointMCTS tree has to be different, to efficiently encode Q-stats +mutable struct JointMCTSTree{S,A,CS<:CoordinationStatistics} + + # To track if state node in tree already + # NOTE: We don't strictly need this at all if no tree reuse... + state_map::Dict{AbstractVector{S},Int64} + + # these vectors have one entry for each state node + # Only doing factored satistics (for actions), not state components + # Looks like we don't need child_ids + total_n::Vector{Int} + s_labels::Vector{AbstractVector{S}} + + # Track stats for all action components over the n_iterations + all_agent_actions::Vector{AbstractVector{A}} + + coordination_stats::CS + lock::ReentrantLock + # Don't need a_labels because need to do var-el for best action anyway +end + +# Just a glorified wrapper now +function JointMCTSTree(all_agent_actions::Vector{AbstractVector{A}}, + coordination_stats::CS, + init_state::AbstractVector{S}, + lock::ReentrantLock, + sz::Int64=10000) where {S, A, CS <: CoordinationStatistics} + + return JointMCTSTree{S,A,CS}(Dict{typeof(init_state),Int64}(), + sizehint!(Int[], sz), + sizehint!(typeof(init_state)[], sz), + all_agent_actions, + coordination_stats, + lock + ) +end # function + + + +Base.isempty(t::JointMCTSTree) = isempty(t.state_map) +state_nodes(t::JointMCTSTree) = (JointStateNode(t, id) for id in 1:length(t.total_n)) + +struct JointStateNode{S} + tree::JointMCTSTree{S} + id::Int64 +end + +#get_state_node(tree::JointMCTSTree, id) = JointStateNode(tree, id) + +# accessors for state nodes +@inline state(n::JointStateNode) = n.tree.s_labels[n.id] +@inline total_n(n::JointStateNode) = n.tree.total_n[n.id] + +## No need for `children` or ActionNode just yet + +mutable struct JointMCTSPlanner{S, A, SE, CS <: CoordinationStatistics, RNG <: AbstractRNG} <: AbstractMCTSPlanner{JointMDP{S,A}} + solver::FVMCTSSolver + mdp::JointMDP{S,A} + tree::JointMCTSTree{S,A,CS} + solved_estimate::SE + rng::RNG +end + +function varel_joint_mcts_planner(solver::FVMCTSSolver, + mdp::JointMDP{S,A}, + init_state::AbstractVector{S}, + ) where {S,A} + + # Get coord graph comps from maximal cliques of graph + adjmat = coord_graph_adj_mat(mdp) + @assert size(adjmat)[1] == n_agents(mdp) "Adjacency Mat does not match number of agents!" + + adjmatgraph = SimpleGraph(adjmat) + coord_graph_components = maximal_cliques(adjmatgraph) + min_degree_ordering = sortperm(degree(adjmatgraph)) + + # Initialize full agent actions + # TODO(jkg): this is incorrect? Or we need to override actiontype to refer to agent actions? 
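+    # all_agent_actions[i] caches agent i's full (state-independent) action set; the coordination
+    # statistics are indexed against these sets, while the state-legal subsets are re-queried with
+    # get_agent_actions(mdp, i, s_i) inside coordinate_action.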
+ all_agent_actions = Vector{actiontype(mdp)}(undef, n_agents(mdp)) + for i = 1:n_agents(mdp) + all_agent_actions[i] = get_agent_actions(mdp, i) + end + + ve_stats = VarElStatistics{eltype(init_state)}(coord_graph_components, min_degree_ordering, + Dict{typeof(init_state),Vector{Vector{Int64}}}(), + Dict{typeof(init_state),Vector{Vector{Int64}}}(), + ) + + # Create tree FROM CURRENT STATE + tree = JointMCTSTree(all_agent_actions, ve_stats, + init_state, ReentrantLock(), solver.n_iterations) + se = convert_estimator(solver.estimate_value, solver, mdp) + + return JointMCTSPlanner(solver, mdp, tree, se, solver.rng) +end # end JointMCTSPlanner + + +function maxplus_joint_mcts_planner(solver::FVMCTSSolver, + mdp::JointMDP{S,A}, + init_state::AbstractVector{S}, + message_iters::Int64, + message_norm::Bool, + use_agent_utils::Bool, + node_exploration::Bool, + edge_exploration::Bool, + ) where {S,A} + + @assert (node_exploration || edge_exploration) "At least one of nodes or edges should explore!" + + adjmat = coord_graph_adj_mat(mdp) + @assert size(adjmat)[1] == n_agents(mdp) "Adjacency Mat does not match number of agents!" + + adjmatgraph = SimpleGraph(adjmat) + # Initialize full agent actions + # TODO(jkg): this is incorrect? Or we need to override actiontype to refer to agent actions? + all_agent_actions = Vector{actiontype(mdp)}(undef, n_agents(mdp)) + for i = 1:n_agents(mdp) + all_agent_actions[i] = get_agent_actions(mdp, i) + end + + mp_stats = MaxPlusStatistics{eltype(init_state)}(adjmatgraph, + message_iters, + message_norm, + use_agent_utils, + node_exploration, + edge_exploration, + Dict{typeof(init_state),PerStateMPStats}()) + + # Create tree FROM CURRENT STATE + tree = JointMCTSTree(all_agent_actions, mp_stats, + init_state, ReentrantLock(), solver.n_iterations) + se = convert_estimator(solver.estimate_value, solver, mdp) + + return JointMCTSPlanner(solver, mdp, tree, se, solver.rng) +end + + +# Reset tree. +function clear_tree!(planner::JointMCTSPlanner) + + # Clear out state hash dict entirely + empty!(planner.tree.state_map) + + # Empty state vectors with state hints + sz = min(planner.solver.n_iterations, 100_000) + + empty!(planner.tree.s_labels) + sizehint!(planner.tree.s_labels, planner.solver.n_iterations) + + # Don't touch all_agent_actions and coord graph component + # Just clear comp stats dict + clear_statistics!(planner.tree.coordination_stats) +end + +# function get_state_node(tree::JointMCTSTree, s, planner::JointMCTSPlanner) +# if haskey(tree.state_map, s) +# return JointStateNode(tree, tree.state_map[s]) # Is this correct? 
Not equiv to vanilla +# else +# return insert_node!(tree, planner, s) +# end +# end + +MCTS.init_Q(n::Number, mdp::JointMDP, s, c, a) = convert(Float64, n) +MCTS.init_N(n::Number, mdp::JointMDP, s, c, a) = convert(Int, n) + + +# no computation is done in solve - the solver is just given the mdp model that it will work with +function POMDPs.solve(solver::FVMCTSSolver, mdp::JointMDP) + if typeof(solver.coordination_strategy) == VarEl + return varel_joint_mcts_planner(solver, mdp, initialstate(mdp, solver.rng)) + elseif typeof(solver.coordination_strategy) == MaxPlus + return maxplus_joint_mcts_planner(solver, mdp, initialstate(mdp, solver.rng), solver.coordination_strategy.message_iters, + solver.coordination_strategy.message_norm, solver.coordination_strategy.use_agent_utils, + solver.coordination_strategy.node_exploration, solver.coordination_strategy.edge_exploration) + else + throw(error("Not Implemented")) + end +end + + +# IMP: Overriding action for JointMCTSPlanner here +# NOTE: Hardcoding no tree reuse for now +function POMDPs.action(planner::JointMCTSPlanner, s) + clear_tree!(planner) # Always call this at the top + plan!(planner, s) + action = coordinate_action(planner.mdp, planner.tree, s) + return action +end + +function POMDPModelTools.action_info(planner::JointMCTSPlanner, s) + clear_tree!(planner) # Always call this at the top + plan!(planner, s) + action = coordinate_action(planner.mdp, planner.tree, s) + return action, nothing +end + +## Not implementing value functions right now.... +## ..Is it just the MAX of the best action, rather than argmax??? + +# Could reuse plan! from vanilla.jl. But I don't like +# calling an element of an abstract type like AbstractMCTSPlanner +function plan!(planner::JointMCTSPlanner, s) + planner.tree = build_tree(planner, s) +end + +# Build_Tree can be called on the assumption that no reuse AND tree is reinitialized +function build_tree(planner::JointMCTSPlanner, s::AbstractVector{S}) where S + + n_iterations = planner.solver.n_iterations + depth = planner.solver.depth + + root = insert_node!(planner.tree, planner, s) + # build the tree + @sync for n = 1:n_iterations + @spawn simulate(planner, root, depth) + end + return planner.tree +end + +function simulate(planner::JointMCTSPlanner, node::JointStateNode, depth::Int64) + + mdp = planner.mdp + rng = planner.rng + s = state(node) + tree = node.tree + + + # once depth is zero return + if isterminal(planner.mdp, s) + return 0.0 + elseif depth == 0 + return estimate_value(planner.solved_estimate, planner.mdp, s, depth) + end + + # Choose best UCB action (NOT an action node) + ucb_action = coordinate_action(mdp, planner.tree, s, planner.solver.exploration_constant, node.id) + + # @show ucb_action + # MC Transition + sp, r = gen(DDNOut(:sp, :r), mdp, s, ucb_action, rng) + + spid = lock(tree.lock) do + get(tree.state_map, sp, 0) # may be non-zero even with no tree reuse + end + if spid == 0 + spn = insert_node!(tree, planner, sp) + spid = spn.id + # TODO define estimate_value + q = r .+ discount(mdp) * estimate_value(planner.solved_estimate, planner.mdp, sp, depth - 1) + else + q = r .+ discount(mdp) * simulate(planner, JointStateNode(tree, spid) , depth - 1) + end + + ## Not bothering with tree vis right now + # Augment N(s) + lock(tree.lock) do + tree.total_n[node.id] += 1 + end + + # Update component statistics! 
(non-trivial) + # This is related but distinct from initialization + update_statistics!(mdp, tree, s, ucb_action, q) + + return q +end + +POMDPLinter.@POMDP_require simulate(planner::JointMCTSPlanner, s, depth::Int64) begin + mdp = planner.mdp + P = typeof(mdp) + @assert P <: JointMDP # req does different thing? + SV = statetype(P) + @assert typeof(SV) <: AbstractVector # TODO: Is this correct? + AV = actiontype(P) + @assert typeof(A) <: AbstractVector + @req discount(::P) + @req isterminal(::P, ::SV) + @subreq insert_node!(planner.tree, planner, s) + @subreq estimate_value(planner.solved_estimate, mdp, s, depth) + @req gen(::DDNOut{(:sp, :r)}, ::P, ::SV, ::A, ::typeof(planner.rng)) + + # MMDP reqs + @req get_agent_actions(::P, ::Int64) + @req get_agent_actions(::P, ::Int64, ::eltype(SV)) + @req n_agents(::P) + @req coord_graph_adj_mat(::P) + + # TODO: Should we also have this requirement for SV? + @req isequal(::S, ::S) + @req hash(::S) +end + + + +function insert_node!(tree::JointMCTSTree{S,A,CS}, planner::JointMCTSPlanner, + s::AbstractVector{S}) where {S,A,CS <: CoordinationStatistics} + + lock(tree.lock) do + push!(tree.s_labels, s) + tree.state_map[s] = length(tree.s_labels) + push!(tree.total_n, 1) + + # TODO: Could actually make actions state-dependent if need be + init_statistics!(tree, planner, s) + end + # length(tree.s_labels) is just an alias for the number of state nodes + ls = lock(tree.lock) do + length(tree.s_labels) + end + return JointStateNode(tree, ls) +end + +POMDPLinter.@POMDP_require insert_node!(tree::JointMCTSTree, planner::JointMCTSPlanner, s) begin + + P = typeof(planner.mdp) + AV = actiontype(P) + A = eltype(AV) + SV = typeof(s) + S = eltype(SV) + + # TODO: Review IQ and IN + IQ = typeof(planner.solver.init_Q) + if !(IQ <: Number) && !(IQ <: Function) + @req init_Q(::IQ, ::P, ::S, ::Vector{Int64}, ::AbstractVector{A}) + end + + IN = typeof(planner.solver.init_N) + if !(IN <: Number) && !(IN <: Function) + @req init_N(::IQ, ::P, ::S, ::Vector{Int64}, ::AbstractVector{A}) + end + + @req isequal(::S, ::S) + @req hash(::S) +end From 6a3fc7cbbb6f930b502810f69b85e512c9f9dd14 Mon Sep 17 00:00:00 2001 From: shushman Date: Thu, 29 Oct 2020 18:40:26 -0700 Subject: [PATCH 02/17] Changed Julia compat to 1.4 for now --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 06d547d..8f2a60d 100644 --- a/Project.toml +++ b/Project.toml @@ -18,7 +18,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [compat] -julia = "1.5" +julia = "1.4" [extras] Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" From eb1820e889fa6b78951fa194bbe7b41776f2c430 Mon Sep 17 00:00:00 2001 From: rejuvyesh Date: Mon, 2 Nov 2020 17:38:41 -0800 Subject: [PATCH 03/17] rename Joint -> FV --- src/fvmcts/fv_mcts_vanilla.jl | 64 +++++++++++++++++------------------ 1 file changed, 32 insertions(+), 32 deletions(-) diff --git a/src/fvmcts/fv_mcts_vanilla.jl b/src/fvmcts/fv_mcts_vanilla.jl index a36e1c3..f53a8aa 100644 --- a/src/fvmcts/fv_mcts_vanilla.jl +++ b/src/fvmcts/fv_mcts_vanilla.jl @@ -32,8 +32,8 @@ end end -# JointMCTS tree has to be different, to efficiently encode Q-stats -mutable struct JointMCTSTree{S,A,CS<:CoordinationStatistics} +# FVMCTS tree has to be different, to efficiently encode Q-stats +mutable struct FVMCTSTree{S,A,CS<:CoordinationStatistics} # To track if state node in tree already # NOTE: We don't strictly need this at all if no tree reuse... 
@@ -54,13 +54,13 @@ mutable struct JointMCTSTree{S,A,CS<:CoordinationStatistics} end # Just a glorified wrapper now -function JointMCTSTree(all_agent_actions::Vector{AbstractVector{A}}, +function FVMCTSTree(all_agent_actions::Vector{AbstractVector{A}}, coordination_stats::CS, init_state::AbstractVector{S}, lock::ReentrantLock, sz::Int64=10000) where {S, A, CS <: CoordinationStatistics} - return JointMCTSTree{S,A,CS}(Dict{typeof(init_state),Int64}(), + return FVMCTSTree{S,A,CS}(Dict{typeof(init_state),Int64}(), sizehint!(Int[], sz), sizehint!(typeof(init_state)[], sz), all_agent_actions, @@ -71,26 +71,26 @@ end # function -Base.isempty(t::JointMCTSTree) = isempty(t.state_map) -state_nodes(t::JointMCTSTree) = (JointStateNode(t, id) for id in 1:length(t.total_n)) +Base.isempty(t::FVMCTSTree) = isempty(t.state_map) +state_nodes(t::FVMCTSTree) = (FVStateNode(t, id) for id in 1:length(t.total_n)) -struct JointStateNode{S} - tree::JointMCTSTree{S} +struct FVStateNode{S} + tree::FVMCTSTree{S} id::Int64 end -#get_state_node(tree::JointMCTSTree, id) = JointStateNode(tree, id) +#get_state_node(tree::FVMCTSTree, id) = FVStateNode(tree, id) # accessors for state nodes -@inline state(n::JointStateNode) = n.tree.s_labels[n.id] -@inline total_n(n::JointStateNode) = n.tree.total_n[n.id] +@inline state(n::FVStateNode) = n.tree.s_labels[n.id] +@inline total_n(n::FVStateNode) = n.tree.total_n[n.id] ## No need for `children` or ActionNode just yet -mutable struct JointMCTSPlanner{S, A, SE, CS <: CoordinationStatistics, RNG <: AbstractRNG} <: AbstractMCTSPlanner{JointMDP{S,A}} +mutable struct FVMCTSPlanner{S, A, SE, CS <: CoordinationStatistics, RNG <: AbstractRNG} <: AbstractMCTSPlanner{JointMDP{S,A}} solver::FVMCTSSolver mdp::JointMDP{S,A} - tree::JointMCTSTree{S,A,CS} + tree::FVMCTSTree{S,A,CS} solved_estimate::SE rng::RNG end @@ -121,12 +121,12 @@ function varel_joint_mcts_planner(solver::FVMCTSSolver, ) # Create tree FROM CURRENT STATE - tree = JointMCTSTree(all_agent_actions, ve_stats, + tree = FVMCTSTree(all_agent_actions, ve_stats, init_state, ReentrantLock(), solver.n_iterations) se = convert_estimator(solver.estimate_value, solver, mdp) - return JointMCTSPlanner(solver, mdp, tree, se, solver.rng) -end # end JointMCTSPlanner + return FVMCTSPlanner(solver, mdp, tree, se, solver.rng) +end # end FVMCTSPlanner function maxplus_joint_mcts_planner(solver::FVMCTSSolver, @@ -161,16 +161,16 @@ function maxplus_joint_mcts_planner(solver::FVMCTSSolver, Dict{typeof(init_state),PerStateMPStats}()) # Create tree FROM CURRENT STATE - tree = JointMCTSTree(all_agent_actions, mp_stats, + tree = FVMCTSTree(all_agent_actions, mp_stats, init_state, ReentrantLock(), solver.n_iterations) se = convert_estimator(solver.estimate_value, solver, mdp) - return JointMCTSPlanner(solver, mdp, tree, se, solver.rng) + return FVMCTSPlanner(solver, mdp, tree, se, solver.rng) end # Reset tree. -function clear_tree!(planner::JointMCTSPlanner) +function clear_tree!(planner::FVMCTSPlanner) # Clear out state hash dict entirely empty!(planner.tree.state_map) @@ -186,9 +186,9 @@ function clear_tree!(planner::JointMCTSPlanner) clear_statistics!(planner.tree.coordination_stats) end -# function get_state_node(tree::JointMCTSTree, s, planner::JointMCTSPlanner) +# function get_state_node(tree::FVMCTSTree, s, planner::FVMCTSPlanner) # if haskey(tree.state_map, s) -# return JointStateNode(tree, tree.state_map[s]) # Is this correct? Not equiv to vanilla +# return FVStateNode(tree, tree.state_map[s]) # Is this correct? 
Not equiv to vanilla # else # return insert_node!(tree, planner, s) # end @@ -212,16 +212,16 @@ function POMDPs.solve(solver::FVMCTSSolver, mdp::JointMDP) end -# IMP: Overriding action for JointMCTSPlanner here +# IMP: Overriding action for FVMCTSPlanner here # NOTE: Hardcoding no tree reuse for now -function POMDPs.action(planner::JointMCTSPlanner, s) +function POMDPs.action(planner::FVMCTSPlanner, s) clear_tree!(planner) # Always call this at the top plan!(planner, s) action = coordinate_action(planner.mdp, planner.tree, s) return action end -function POMDPModelTools.action_info(planner::JointMCTSPlanner, s) +function POMDPModelTools.action_info(planner::FVMCTSPlanner, s) clear_tree!(planner) # Always call this at the top plan!(planner, s) action = coordinate_action(planner.mdp, planner.tree, s) @@ -233,12 +233,12 @@ end # Could reuse plan! from vanilla.jl. But I don't like # calling an element of an abstract type like AbstractMCTSPlanner -function plan!(planner::JointMCTSPlanner, s) +function plan!(planner::FVMCTSPlanner, s) planner.tree = build_tree(planner, s) end # Build_Tree can be called on the assumption that no reuse AND tree is reinitialized -function build_tree(planner::JointMCTSPlanner, s::AbstractVector{S}) where S +function build_tree(planner::FVMCTSPlanner, s::AbstractVector{S}) where S n_iterations = planner.solver.n_iterations depth = planner.solver.depth @@ -251,7 +251,7 @@ function build_tree(planner::JointMCTSPlanner, s::AbstractVector{S}) where S return planner.tree end -function simulate(planner::JointMCTSPlanner, node::JointStateNode, depth::Int64) +function simulate(planner::FVMCTSPlanner, node::FVStateNode, depth::Int64) mdp = planner.mdp rng = planner.rng @@ -282,7 +282,7 @@ function simulate(planner::JointMCTSPlanner, node::JointStateNode, depth::Int64) # TODO define estimate_value q = r .+ discount(mdp) * estimate_value(planner.solved_estimate, planner.mdp, sp, depth - 1) else - q = r .+ discount(mdp) * simulate(planner, JointStateNode(tree, spid) , depth - 1) + q = r .+ discount(mdp) * simulate(planner, FVStateNode(tree, spid) , depth - 1) end ## Not bothering with tree vis right now @@ -298,7 +298,7 @@ function simulate(planner::JointMCTSPlanner, node::JointStateNode, depth::Int64) return q end -POMDPLinter.@POMDP_require simulate(planner::JointMCTSPlanner, s, depth::Int64) begin +POMDPLinter.@POMDP_require simulate(planner::FVMCTSPlanner, s, depth::Int64) begin mdp = planner.mdp P = typeof(mdp) @assert P <: JointMDP # req does different thing? 
@@ -325,7 +325,7 @@ end -function insert_node!(tree::JointMCTSTree{S,A,CS}, planner::JointMCTSPlanner, +function insert_node!(tree::FVMCTSTree{S,A,CS}, planner::FVMCTSPlanner, s::AbstractVector{S}) where {S,A,CS <: CoordinationStatistics} lock(tree.lock) do @@ -340,10 +340,10 @@ function insert_node!(tree::JointMCTSTree{S,A,CS}, planner::JointMCTSPlanner, ls = lock(tree.lock) do length(tree.s_labels) end - return JointStateNode(tree, ls) + return FVStateNode(tree, ls) end -POMDPLinter.@POMDP_require insert_node!(tree::JointMCTSTree, planner::JointMCTSPlanner, s) begin +POMDPLinter.@POMDP_require insert_node!(tree::FVMCTSTree, planner::FVMCTSPlanner, s) begin P = typeof(planner.mdp) AV = actiontype(P) From e2b23f91808f34d370ae3a7535f0fa27773ee754 Mon Sep 17 00:00:00 2001 From: rejuvyesh Date: Mon, 2 Nov 2020 17:41:46 -0800 Subject: [PATCH 04/17] factor out the FV policy --- src/MAMCTS.jl | 18 +----------------- src/fvmcts/factoredpolicy.jl | 17 +++++++++++++++++ 2 files changed, 18 insertions(+), 17 deletions(-) create mode 100644 src/fvmcts/factoredpolicy.jl diff --git a/src/MAMCTS.jl b/src/MAMCTS.jl index 449ba8a..0706be8 100644 --- a/src/MAMCTS.jl +++ b/src/MAMCTS.jl @@ -21,23 +21,7 @@ import POMDPModelTools abstract type CoordinationStatistics end -""" -Random Policy factored for each agent. Avoids exploding action space. -""" -struct FactoredRandomPolicy{RNG<:AbstractRNG,P<:JointMDP, U<:Updater} <: Policy - rng::RNG - problem::P - updater::U -end - -FactoredRandomPolicy(problem::JointMDP; rng=Random.GLOBAL_RNG, updater=NothingUpdater()) = FactoredRandomPolicy(rng, problem, updater) - -function POMDPs.action(policy::FactoredRandomPolicy, s) - return [rand(policy.rng, get_agent_actions(policy.problem, i, si)) for (i, si) in enumerate(s)] -end - -POMDPs.solve(solver::RandomSolver, problem::JointMDP) = FactoredRandomPolicy(solver.rng, problem, NothingUpdater()) - +include(joinpath("fvmcts", "factoredpolicy.jl")) include(joinpath("fvmcts", "fv_mcts_vanilla.jl")) include(joinpath("fvmcts", "action_coordination", "varel.jl")) include(joinpath("fvmcts", "action_coordination", "maxplus.jl")) diff --git a/src/fvmcts/factoredpolicy.jl b/src/fvmcts/factoredpolicy.jl new file mode 100644 index 0000000..d1af7b9 --- /dev/null +++ b/src/fvmcts/factoredpolicy.jl @@ -0,0 +1,17 @@ + +""" +Random Policy factored for each agent. Avoids exploding action space. +""" +struct FactoredRandomPolicy{RNG<:AbstractRNG,P<:JointMDP, U<:Updater} <: Policy + rng::RNG + problem::P + updater::U +end + +FactoredRandomPolicy(problem::JointMDP; rng=Random.GLOBAL_RNG, updater=NothingUpdater()) = FactoredRandomPolicy(rng, problem, updater) + +function POMDPs.action(policy::FactoredRandomPolicy, s) + return [rand(policy.rng, get_agent_actions(policy.problem, i, si)) for (i, si) in enumerate(s)] +end + +POMDPs.solve(solver::RandomSolver, problem::JointMDP) = FactoredRandomPolicy(solver.rng, problem, NothingUpdater()) \ No newline at end of file From fd81daa9d3e195cca7a51c09c27390212a6fbe0f Mon Sep 17 00:00:00 2001 From: shushman Date: Tue, 3 Nov 2020 22:33:51 -0800 Subject: [PATCH 05/17] Docstring for FVMCTS Solver; cleaned up some bad comments --- src/fvmcts/fv_mcts_vanilla.jl | 142 +++++++++++++++++++++++----------- 1 file changed, 95 insertions(+), 47 deletions(-) diff --git a/src/fvmcts/fv_mcts_vanilla.jl b/src/fvmcts/fv_mcts_vanilla.jl index f53a8aa..ec6a781 100644 --- a/src/fvmcts/fv_mcts_vanilla.jl +++ b/src/fvmcts/fv_mcts_vanilla.jl @@ -1,6 +1,3 @@ -## We can use the following directly without modification -## 1. 
domain_knowledge.jl for Rollout, init_Q and init_N functions -## 2. FVMCTSSolver for representing the overall MCTS (underlying things will change) using StaticArrays using Parameters using Base.Threads: @spawn @@ -18,6 +15,60 @@ end edge_exploration::Bool = true end +""" +Factored Value Monte Carlo Tree Search solver datastructure + +Fields: + n_iterations::Int64 + Number of iterations during each action() call. + default: 100 + + max_time::Float64 + Maximum CPU time to spend computing an action. + default::Inf + + depth::Int64 + Number of iterations during each action() call. + default: 100 + + exploration_constant::Float64: + Specifies how much the solver should explore. In the UCB equation, Q + c*sqrt(log(t/N)), c is the exploration constant. + The exploration terms for FV-MCTS-Var-El and FV-MCTS-Max-Plus are different but the role of c is the same. + default: 1.0 + + rng::AbstractRNG: + Random number generator + + estimate_value::Any (rollout policy) + Function, object, or number used to estimate the value at the leaf nodes. + If this is a function `f`, `f(mdp, s, depth)` will be called to estimate the value. + If this is an object `o`, `estimate_value(o, mdp, s, depth)` will be called. + If this is a number, the value will be set to that number + default: RolloutEstimator(RandomSolver(rng)) + + init_Q::Any + Function, object, or number used to set the initial Q(s,a) value at a new node. + If this is a function `f`, `f(mdp, s, a)` will be called to set the value. + If this is an object `o`, `init_Q(o, mdp, s, a)` will be called. + If this is a number, Q will be set to that number + default: 0.0 + + init_N::Any + Function, object, or number used to set the initial N(s,a) value at a new node. + If this is a function `f`, `f(mdp, s, a)` will be called to set the value. + If this is an object `o`, `init_N(o, mdp, s, a)` will be called. + If this is a number, N will be set to that number + default: 0 + + reuse_tree::Bool + If this is true, the tree information is re-used for calculating the next plan. + Of course, clear_tree! can always be called to override this. + default: false + + coordination_strategy::AbstractCoordinationStrategy + The specific strategy with which to compute the best joint action from the current MCTS statistics. + default: VarEl() +""" @with_kw mutable struct FVMCTSSolver <: AbstractMCTSSolver n_iterations::Int64 = 100 max_time::Float64 = Inf @@ -28,37 +79,31 @@ end init_Q::Any = 0.0 init_N::Any = 0 reuse_tree::Bool = false - coordination_strategy::Any = VarEl() + coordination_strategy::AbstractCoordinationStrategy = VarEl() end -# FVMCTS tree has to be different, to efficiently encode Q-stats mutable struct FVMCTSTree{S,A,CS<:CoordinationStatistics} - # To track if state node in tree already - # NOTE: We don't strictly need this at all if no tree reuse... + # To map the multi-agent state vector to the ID of the node in the tree state_map::Dict{AbstractVector{S},Int64} - # these vectors have one entry for each state node - # Only doing factored satistics (for actions), not state components - # Looks like we don't need child_ids - total_n::Vector{Int} - s_labels::Vector{AbstractVector{S}} + # The next two vectors have one for each node ID in the tree + total_n::Vector{Int} # The number of times the node has been tried + s_labels::Vector{AbstractVector{S}} # The state corresponding to the node ID - # Track stats for all action components over the n_iterations + # List of all individual actions of each agent for coordination purposes. 
all_agent_actions::Vector{AbstractVector{A}} coordination_stats::CS lock::ReentrantLock - # Don't need a_labels because need to do var-el for best action anyway end -# Just a glorified wrapper now function FVMCTSTree(all_agent_actions::Vector{AbstractVector{A}}, - coordination_stats::CS, - init_state::AbstractVector{S}, - lock::ReentrantLock, - sz::Int64=10000) where {S, A, CS <: CoordinationStatistics} + coordination_stats::CS, + init_state::AbstractVector{S}, + lock::ReentrantLock, + sz::Int64=10000) where {S, A, CS <: CoordinationStatistics} return FVMCTSTree{S,A,CS}(Dict{typeof(init_state),Int64}(), sizehint!(Int[], sz), @@ -79,9 +124,8 @@ struct FVStateNode{S} id::Int64 end -#get_state_node(tree::FVMCTSTree, id) = FVStateNode(tree, id) -# accessors for state nodes +# Accessors for state nodes @inline state(n::FVStateNode) = n.tree.s_labels[n.id] @inline total_n(n::FVStateNode) = n.tree.total_n[n.id] @@ -95,14 +139,18 @@ mutable struct FVMCTSPlanner{S, A, SE, CS <: CoordinationStatistics, RNG <: Abst rng::RNG end +""" +Called internally in solve() to create the FVMCTSPlanner where Var-El is the specific action coordination strategy. +Creates VarElStatistics internally with the CG components and the minimum degree ordering heuristic. +""" function varel_joint_mcts_planner(solver::FVMCTSSolver, mdp::JointMDP{S,A}, init_state::AbstractVector{S}, ) where {S,A} - # Get coord graph comps from maximal cliques of graph + # Get coordination graph components from maximal cliques adjmat = coord_graph_adj_mat(mdp) - @assert size(adjmat)[1] == n_agents(mdp) "Adjacency Mat does not match number of agents!" + @assert size(adjmat)[1] == n_agents(mdp) "Adjacency Matrix does not match number of agents!" adjmatgraph = SimpleGraph(adjmat) coord_graph_components = maximal_cliques(adjmatgraph) @@ -120,15 +168,18 @@ function varel_joint_mcts_planner(solver::FVMCTSSolver, Dict{typeof(init_state),Vector{Vector{Int64}}}(), ) - # Create tree FROM CURRENT STATE + # Create tree from the current state tree = FVMCTSTree(all_agent_actions, ve_stats, init_state, ReentrantLock(), solver.n_iterations) se = convert_estimator(solver.estimate_value, solver, mdp) return FVMCTSPlanner(solver, mdp, tree, se, solver.rng) -end # end FVMCTSPlanner - +end # end function +""" +Called internally in solve() to create the FVMCTSPlanner where Max-Plus is the specific action coordination strategy. +Creates MaxPlusStatistics and assumes the various MP flags are sent down from the CoordinationStrategy object given to the solver. +""" function maxplus_joint_mcts_planner(solver::FVMCTSSolver, mdp::JointMDP{S,A}, init_state::AbstractVector{S}, @@ -145,6 +196,7 @@ function maxplus_joint_mcts_planner(solver::FVMCTSSolver, @assert size(adjmat)[1] == n_agents(mdp) "Adjacency Mat does not match number of agents!" adjmatgraph = SimpleGraph(adjmat) + # Initialize full agent actions # TODO(jkg): this is incorrect? Or we need to override actiontype to refer to agent actions? 
all_agent_actions = Vector{actiontype(mdp)}(undef, n_agents(mdp)) @@ -153,16 +205,16 @@ function maxplus_joint_mcts_planner(solver::FVMCTSSolver, end mp_stats = MaxPlusStatistics{eltype(init_state)}(adjmatgraph, - message_iters, - message_norm, - use_agent_utils, - node_exploration, - edge_exploration, - Dict{typeof(init_state),PerStateMPStats}()) - - # Create tree FROM CURRENT STATE + message_iters, + message_norm, + use_agent_utils, + node_exploration, + edge_exploration, + Dict{typeof(init_state),PerStateMPStats}()) + + # Create tree from the current state tree = FVMCTSTree(all_agent_actions, mp_stats, - init_state, ReentrantLock(), solver.n_iterations) + init_state, ReentrantLock(), solver.n_iterations) se = convert_estimator(solver.estimate_value, solver, mdp) return FVMCTSPlanner(solver, mdp, tree, se, solver.rng) @@ -186,26 +238,22 @@ function clear_tree!(planner::FVMCTSPlanner) clear_statistics!(planner.tree.coordination_stats) end -# function get_state_node(tree::FVMCTSTree, s, planner::FVMCTSPlanner) -# if haskey(tree.state_map, s) -# return FVStateNode(tree, tree.state_map[s]) # Is this correct? Not equiv to vanilla -# else -# return insert_node!(tree, planner, s) -# end -# end - MCTS.init_Q(n::Number, mdp::JointMDP, s, c, a) = convert(Float64, n) MCTS.init_N(n::Number, mdp::JointMDP, s, c, a) = convert(Int, n) -# no computation is done in solve - the solver is just given the mdp model that it will work with +# No computation is done in solve; the solver is just given the mdp model that it will work with +# and in case of MaxPlus, the various flags for the MaxPlus behavior function POMDPs.solve(solver::FVMCTSSolver, mdp::JointMDP) if typeof(solver.coordination_strategy) == VarEl return varel_joint_mcts_planner(solver, mdp, initialstate(mdp, solver.rng)) elseif typeof(solver.coordination_strategy) == MaxPlus - return maxplus_joint_mcts_planner(solver, mdp, initialstate(mdp, solver.rng), solver.coordination_strategy.message_iters, - solver.coordination_strategy.message_norm, solver.coordination_strategy.use_agent_utils, - solver.coordination_strategy.node_exploration, solver.coordination_strategy.edge_exploration) + return maxplus_joint_mcts_planner(solver, mdp, initialstate(mdp, solver.rng), + solver.coordination_strategy.message_iters, + solver.coordination_strategy.message_norm, + solver.coordination_strategy.use_agent_utils, + solver.coordination_strategy.node_exploration, + solver.coordination_strategy.edge_exploration) else throw(error("Not Implemented")) end From 65c80eb6cb88cdcc2de1d3fc91c7c202be8ffdfb Mon Sep 17 00:00:00 2001 From: shushman Date: Fri, 6 Nov 2020 17:13:21 -0800 Subject: [PATCH 06/17] Fleshed out vanilla and varel-mp --- src/fvmcts/action_coordination/maxplus.jl | 37 ++++++++++++++++++++--- src/fvmcts/action_coordination/varel.jl | 32 ++++++++++++++++---- src/fvmcts/fv_mcts_vanilla.jl | 29 ++++++++---------- 3 files changed, 72 insertions(+), 26 deletions(-) diff --git a/src/fvmcts/action_coordination/maxplus.jl b/src/fvmcts/action_coordination/maxplus.jl index d7945b1..a0a72f1 100644 --- a/src/fvmcts/action_coordination/maxplus.jl +++ b/src/fvmcts/action_coordination/maxplus.jl @@ -1,4 +1,3 @@ -## Now do for message passing # NOTE: Matrix implicitly assumes all agents have same number of actions mutable struct PerStateMPStats agent_action_n::Matrix{Int64} # N X A @@ -7,7 +6,32 @@ mutable struct PerStateMPStats edge_action_q::Matrix{Float64} end -# NOTE: Putting params here is a little ugly but coordinate_action can't have them since VarEl doesn't use those 
args +""" +Tracks the specific informations and statistics we need to use Max-Plus to coordinate_action +the joint action in Factored-Value MCTS. Putting parameters here is a little ugly but coordinate_action can't have them since VarEl doesn't use those args. + +Fields: + adjmatgraph::SimpleGraph + The coordination graph as a LightGraphs SimpleGraph. + + message_iters::Int64 + Number of rounds of message passing. + + message_norm::Bool + Whether to normalize the messages or not after message passing. + + use_agent_utils::Bool + Whether to include the per-agent utilities while computing the best agent action (see our paper for details) + + node_exploration::Bool + Whether to use the per-node UCB style bonus while computing the best agent action (see our paper for details) + + edge_exploration::Bool + Whether to use the per-edge UCB style bonus after the message passing rounds (see our paper for details). One of this or node_exploration MUST be true for exploration. + + all_states_stats::Dict{AbstractVector{S},PerStateMPStats} + Maps each joint state in the tree to the per-state statistics. +""" mutable struct MaxPlusStatistics{S} <: CoordinationStatistics adjmatgraph::SimpleGraph message_iters::Int64 @@ -22,6 +46,9 @@ function clear_statistics!(mp_stats::MaxPlusStatistics) empty!(mp_stats.all_states_stats) end +""" +Take the q-value from the MCTS step and distribute the updates across the per-node and per-edge q-stats as per the formula in our paper. +""" function update_statistics!(mdp::JointMDP{S,A}, tree::JointMCTSTree{S,A,MaxPlusStatistics{S}}, s::AbstractVector{S}, ucb_action::AbstractVector{A}, q::AbstractVector{Float64}) where {S,A} @@ -38,7 +65,7 @@ function update_statistics!(mdp::JointMDP{S,A}, tree::JointMCTSTree{S,A,MaxPlusS end end - # Now update edge action stats + # Now update per-edge action stats for (idx, e) in enumerate(edges(tree.coordination_stats.adjmatgraph)) # NOTE: Need to be careful about action ordering # Being more general to have unequal agent actions @@ -106,7 +133,9 @@ function init_statistics!(tree::JointMCTSTree{S,A,MaxPlusStatistics{S}}, planner end end -# NOTE: Following deepCG +""" +Runs Max-Plus at the current state using the per-state MaxPlusStatistics to compute the best joint action with either or both of node-wise and edge-wise exploration bonus. Rounds of message passing are followed by per-node maximization. +""" function coordinate_action(mdp::JointMDP{S,A}, tree::JointMCTSTree{S,A,MaxPlusStatistics{S}}, s::AbstractVector{S}, exploration_constant::Float64=0.0, node_id::Int64=0) where {S,A} diff --git a/src/fvmcts/action_coordination/varel.jl b/src/fvmcts/action_coordination/varel.jl index 3e5d887..d9646a2 100644 --- a/src/fvmcts/action_coordination/varel.jl +++ b/src/fvmcts/action_coordination/varel.jl @@ -1,3 +1,20 @@ +""" +Tracks the specific informations and statistics we need to use Var-El to coordinate_action +the joint action in Factored-Value MCTS. + +Fields: + coord_graph_components::Vector{Vector{Int64}} + The list of coordination graph components, i.e., cliques, where each element is a list of agent IDs that are in a mutual clique. + + min_degree_ordering::Vector{Int64} + Ordering of agent IDs in increasing CG degree. This ordering is the heuristic most typically used for the elimination order in Var-El. + + n_component_stats::Dict{AbstractVector{S},Vector{Vector{Int64}}} + Maps each joint state in the tree (for which we need to compute the UCB action) to the frequency of each component's various local actions. 
+ + q_component_stats::Dict{AbstractVector{S},Vector{Vector{Float64}}} + Maps each joint state in the tree to the accumulated q-value of each component's various local actions. +""" mutable struct VarElStatistics{S} <: CoordinationStatistics coord_graph_components::Vector{Vector{Int64}} min_degree_ordering::Vector{Int64} @@ -11,15 +28,17 @@ function clear_statistics!(ve_stats::VarElStatistics) end +""" +Runs variable elimination at the current state using the VarEl Statistics to compute the best joint action with the component-wise exploration bonus. +FYI: Rather complicated. +""" function coordinate_action(mdp::JointMDP{S,A}, tree::JointMCTSTree{S,A,VarElStatistics{S}}, s::AbstractVector{S}, exploration_constant::Float64=0.0, node_id::Int64=0) where {S,A} n_agents = length(s) best_action_idxs = MVector{n_agents}([-1 for i in 1:n_agents]) - # # !Note: Acquire lock so as to avoid race - # state_q_stats = lock(tree.lock) do tree.coordination_stats.q_component_stats[s] end @@ -58,8 +77,7 @@ function coordinate_action(mdp::JointMDP{S,A}, tree::JointMCTSTree{S,A,VarElStat # Agent to-be-eliminated is in factor push!(agent_factors, k) - # Construct key for new potential as union of all others - # except ag_idx + # Construct key for new potential as union of all others except ag_idx for ag in k if ag != ag_idx && ~(ag in new_potential_members) push!(new_potential_members, ag) @@ -78,8 +96,7 @@ function coordinate_action(mdp::JointMDP{S,A}, tree::JointMCTSTree{S,A,VarElStat else - # Generate new potential function - # AND the best response vector for eliminated agent + # Generate new potential function and the best response vector for eliminated agent n_comp_actions = prod([length(tree.all_agent_actions[c]) for c in new_potential_members]) # NOTE: Tuples should ALWAYS use tree.all_agent_actions for indexing @@ -193,6 +210,9 @@ function coordinate_action(mdp::JointMDP{S,A}, tree::JointMCTSTree{S,A,VarElStat end +""" +Take the q-value from the MCTS step and distribute the updates across the component q-stats as per the formula in the Amato-Oliehoek paper. +""" function update_statistics!(mdp::JointMDP{S,A}, tree::JointMCTSTree{S,A,VarElStatistics{S}}, s::AbstractVector{S}, ucb_action::AbstractVector{A}, q::AbstractVector{Float64}) where {S,A} diff --git a/src/fvmcts/fv_mcts_vanilla.jl b/src/fvmcts/fv_mcts_vanilla.jl index ec6a781..f1b42c2 100644 --- a/src/fvmcts/fv_mcts_vanilla.jl +++ b/src/fvmcts/fv_mcts_vanilla.jl @@ -224,7 +224,7 @@ end # Reset tree. function clear_tree!(planner::FVMCTSPlanner) - # Clear out state hash dict entirely + # Clear out state map dict entirely empty!(planner.tree.state_map) # Empty state vectors with state hints @@ -276,23 +276,20 @@ function POMDPModelTools.action_info(planner::FVMCTSPlanner, s) return action, nothing end -## Not implementing value functions right now.... -## ..Is it just the MAX of the best action, rather than argmax??? -# Could reuse plan! from vanilla.jl. 
But I don't like -# calling an element of an abstract type like AbstractMCTSPlanner function plan!(planner::FVMCTSPlanner, s) planner.tree = build_tree(planner, s) end -# Build_Tree can be called on the assumption that no reuse AND tree is reinitialized +# build_tree can be called on the assumption that no reuse AND tree is reinitialized function build_tree(planner::FVMCTSPlanner, s::AbstractVector{S}) where S n_iterations = planner.solver.n_iterations depth = planner.solver.depth root = insert_node!(planner.tree, planner, s) - # build the tree + + # Simulate can be multi-threaded @sync for n = 1:n_iterations @spawn simulate(planner, root, depth) end @@ -314,11 +311,10 @@ function simulate(planner::FVMCTSPlanner, node::FVStateNode, depth::Int64) return estimate_value(planner.solved_estimate, planner.mdp, s, depth) end - # Choose best UCB action (NOT an action node) + # Choose best UCB action (NOT an action node as in vanilla MCTS) ucb_action = coordinate_action(mdp, planner.tree, s, planner.solver.exploration_constant, node.id) - # @show ucb_action - # MC Transition + # Monte Carlo Transition sp, r = gen(DDNOut(:sp, :r), mdp, s, ucb_action, rng) spid = lock(tree.lock) do @@ -327,13 +323,13 @@ function simulate(planner::FVMCTSPlanner, node::FVStateNode, depth::Int64) if spid == 0 spn = insert_node!(tree, planner, sp) spid = spn.id - # TODO define estimate_value + q = r .+ discount(mdp) * estimate_value(planner.solved_estimate, planner.mdp, sp, depth - 1) else q = r .+ discount(mdp) * simulate(planner, FVStateNode(tree, spid) , depth - 1) end - ## Not bothering with tree vis right now + # NOTE: Not bothering with tree visualization right now # Augment N(s) lock(tree.lock) do tree.total_n[node.id] += 1 @@ -349,9 +345,9 @@ end POMDPLinter.@POMDP_require simulate(planner::FVMCTSPlanner, s, depth::Int64) begin mdp = planner.mdp P = typeof(mdp) - @assert P <: JointMDP # req does different thing? + @assert P <: JointMDP SV = statetype(P) - @assert typeof(SV) <: AbstractVector # TODO: Is this correct? 
+ @assert typeof(SV) <: AbstractVector AV = actiontype(P) @assert typeof(A) <: AbstractVector @req discount(::P) @@ -360,7 +356,7 @@ POMDPLinter.@POMDP_require simulate(planner::FVMCTSPlanner, s, depth::Int64) beg @subreq estimate_value(planner.solved_estimate, mdp, s, depth) @req gen(::DDNOut{(:sp, :r)}, ::P, ::SV, ::A, ::typeof(planner.rng)) - # MMDP reqs + ## Requirements from MMDP Model @req get_agent_actions(::P, ::Int64) @req get_agent_actions(::P, ::Int64, ::eltype(SV)) @req n_agents(::P) @@ -381,9 +377,10 @@ function insert_node!(tree::FVMCTSTree{S,A,CS}, planner::FVMCTSPlanner, tree.state_map[s] = length(tree.s_labels) push!(tree.total_n, 1) - # TODO: Could actually make actions state-dependent if need be + # NOTE: Could actually make actions state-dependent if need be init_statistics!(tree, planner, s) end + # length(tree.s_labels) is just an alias for the number of state nodes ls = lock(tree.lock) do length(tree.s_labels) From 8c1fc0bf26238ce58f8beabcf914d36f335ae9a1 Mon Sep 17 00:00:00 2001 From: rejuvyesh Date: Thu, 31 Dec 2020 02:23:49 +0000 Subject: [PATCH 07/17] Fixes for new api --- Project.toml | 4 +- src/MAMCTS.jl | 5 +- src/fcmcts/fcmcts.jl | 33 +++++----- src/fvmcts/action_coordination/maxplus.jl | 32 ++++----- src/fvmcts/action_coordination/varel.jl | 48 +++++++++++--- src/fvmcts/factoredpolicy.jl | 2 +- src/fvmcts/fv_mcts_vanilla.jl | 79 ++++++++++++----------- 7 files changed, 117 insertions(+), 86 deletions(-) diff --git a/Project.toml b/Project.toml index 8f2a60d..99b9304 100644 --- a/Project.toml +++ b/Project.toml @@ -7,8 +7,8 @@ version = "0.1.0" BeliefUpdaters = "8bb6e9a1-7d73-552c-a44a-e5dc5634aac4" LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -MAPOMDPs = "f50418f3-c642-4efe-9903-417dc09ce874" MCTS = "e12ccd36-dcad-5f33-8774-9175229e7b33" +MultiAgentPOMDPs = "9ac5fd70-7902-42e7-9745-ff446b44e779" POMDPLinter = "f3bd98c0-eb40-45e2-9eb1-f2763262d755" POMDPModelTools = "08074719-1b2a-587c-a292-00f91cc44415" POMDPPolicies = "182e52fb-cfd0-5e46-8c26-fd0667c990f4" @@ -18,7 +18,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [compat] -julia = "1.4" +julia = "1.5" [extras] Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/src/MAMCTS.jl b/src/MAMCTS.jl index 0706be8..fb23d85 100644 --- a/src/MAMCTS.jl +++ b/src/MAMCTS.jl @@ -3,11 +3,10 @@ module MAMCTS using Random using LinearAlgebra -using Parameters using POMDPs -using MAPOMDPs +using MultiAgentPOMDPs using POMDPPolicies -using POMDPLinter +using POMDPLinter: @req, @subreq, @POMDP_require using MCTS using LightGraphs using BeliefUpdaters diff --git a/src/fcmcts/fcmcts.jl b/src/fcmcts/fcmcts.jl index 749a161..61f463a 100644 --- a/src/fcmcts/fcmcts.jl +++ b/src/fcmcts/fcmcts.jl @@ -1,6 +1,6 @@ -@with_kw mutable struct FCMCTSSolver <: AbstractMCTSSolver +Base.@kwdef mutable struct FCMCTSSolver <: AbstractMCTSSolver n_iterations::Int64 = 100 max_time::Float64 = Inf depth::Int64 = 10 @@ -15,26 +15,26 @@ end mutable struct FCMCTSTree{S,A} # To track if state node in tree already # NOTE: We don't strictly need this at all if no tree reuse... - state_map::Dict{AbstractVector{S},Int64} + state_map::Dict{S,Int64} # these vectors have one entry for each state node # Only doing factored satistics (for actions), not state components child_ids::Vector{Vector{Int}} total_n::Vector{Int} - s_labels::Vector{AbstractVector{S}} + s_labels::Vector{S} # TODO(jkg): is this the best way to track stats? 
# these vectors have one entry for each action node n::Vector{Int64} q::Vector{Float64} - a_labels::Vector{AbstractVector{A}} + a_labels::Vector{A} lock::ReentrantLock end -function FCMCTSTree{S,A}(init_state::AbstractVector{S}, lock::ReentrantLock, sz::Int=1000) where {S,A} +function FCMCTSTree{S,A}(init_state::S, lock::ReentrantLock, sz::Int=1000) where {S,A} sz = min(sz, 100_000) - return FCMCTSTree{S,A}(Dict{typeof(init_state),Int64}(), + return FCMCTSTree{S,A}(Dict{S,Int64}(), sizehint!(Vector{Int}[], sz), sizehint!(Int[], sz), sizehint!(typeof(init_state)[], sz), @@ -172,7 +172,7 @@ function simulate(planner::FCMCTSPlanner, node::FCStateNode, depth::Int64) # @show ucb_action # MC Transition - sp, r = gen(DDNOut(:sp, :r), mdp, s, ucb_action, rng) + sp, r = @gen(:sp, :r)(mdp, s, ucb_action, rng) # NOTE(jkg): just summing up the rewards to get a single value # TODO: should we divide by n_agents? @@ -212,7 +212,7 @@ function insert_node!(tree::FCMCTSTree, planner::FCMCTSPlanner, s) end # NOTE: Doing state-dep actions here the JointMDP way - state_dep_jtactions = vec(map(collect, Iterators.product((get_agent_actions(planner.mdp, i, si) for (i, si) in enumerate(s))...))) + state_dep_jtactions = vec(map(collect, Iterators.product((agent_actions(planner.mdp, i, si) for (i, si) in enumerate(s))...))) total_n = 0 for a in state_dep_jtactions @@ -259,23 +259,23 @@ function compute_best_action_node(mdp::JointMDP, tree::FCMCTSTree, node::FCState return best end -POMDPLinter.@POMDP_require simulate(planner::FCMCTSPlanner, s, depth::Int64) begin +@POMDP_require simulate(planner::FCMCTSPlanner, s, depth::Int64) begin mdp = planner.mdp P = typeof(mdp) @assert P <: JointMDP # req does different thing? - SV = statetype(P) - @assert typeof(SV) <: AbstractVector # TODO: Is this correct? + #SV = statetype(P) + #@assert typeof(SV) <: AbstractVector # TODO: Is this correct? AV = actiontype(P) - @assert typeof(A) <: AbstractVector + @assert typeof(AV) <: AbstractVector @req discount(::P) @req isterminal(::P, ::SV) @subreq insert_node!(planner.tree, planner, s) @subreq estimate_value(planner.solved_estimate, mdp, s, depth) - @req gen(::DDNOut{(:sp, :r)}, ::P, ::SV, ::A, ::typeof(planner.rng)) + @req gen(::DDNOut{(:sp, :r)}, ::P, ::SV, ::AV, ::typeof(planner.rng)) # MMDP reqs - @req get_agent_actions(::P, ::Int64) - @req get_agent_actions(::P, ::Int64, ::eltype(SV)) + @req agent_actions(::P, ::Int64) + @req agent_actions(::P, ::Int64, ::eltype(SV)) # TODO should this be eltype? @req n_agents(::P) # TODO: Should we also have this requirement for SV? @@ -283,7 +283,7 @@ POMDPLinter.@POMDP_require simulate(planner::FCMCTSPlanner, s, depth::Int64) beg @req hash(::S) end -POMDPLinter.@POMDP_require insert_node!(tree::FCMCTSTree, planner::FCMCTSPlanner, s) begin +@POMDP_require insert_node!(tree::FCMCTSTree, planner::FCMCTSPlanner, s) begin P = typeof(planner.mdp) AV = actiontype(P) @@ -292,6 +292,7 @@ POMDPLinter.@POMDP_require insert_node!(tree::FCMCTSTree, planner::FCMCTSPlanner S = eltype(SV) # TODO: Review IQ and IN + # Should this be ::S or ::SV? We can have global state that's not a vector. 
IQ = typeof(planner.solver.init_Q) if !(IQ <: Number) && !(IQ <: Function) @req init_Q(::IQ, ::P, ::S, ::Vector{Int64}, ::AbstractVector{A}) diff --git a/src/fvmcts/action_coordination/maxplus.jl b/src/fvmcts/action_coordination/maxplus.jl index a0a72f1..4bf7244 100644 --- a/src/fvmcts/action_coordination/maxplus.jl +++ b/src/fvmcts/action_coordination/maxplus.jl @@ -39,7 +39,7 @@ mutable struct MaxPlusStatistics{S} <: CoordinationStatistics use_agent_utils::Bool node_exploration::Bool edge_exploration::Bool # NOTE: One of this or node exploration must be true - all_states_stats::Dict{AbstractVector{S},PerStateMPStats} + all_states_stats::Dict{S,PerStateMPStats} end function clear_statistics!(mp_stats::MaxPlusStatistics) @@ -49,15 +49,15 @@ end """ Take the q-value from the MCTS step and distribute the updates across the per-node and per-edge q-stats as per the formula in our paper. """ -function update_statistics!(mdp::JointMDP{S,A}, tree::JointMCTSTree{S,A,MaxPlusStatistics{S}}, - s::AbstractVector{S}, ucb_action::AbstractVector{A}, q::AbstractVector{Float64}) where {S,A} +function update_statistics!(mdp::JointMDP{S,A}, tree::FVMCTSTree{S,A,MaxPlusStatistics{S}}, + s::S, ucb_action::A, q::AbstractVector{Float64}) where {S,A} state_stats = tree.coordination_stats.all_states_stats[s] n_agents = length(s) # Update per agent action stats for i = 1:n_agents - ac_idx = get_agent_actionindex(mdp, i, ucb_action[i]) + ac_idx = agent_actionindex(mdp, i, ucb_action[i]) lock(tree.lock) do state_stats.agent_action_n[i, ac_idx] += 1 state_stats.agent_action_q[i, ac_idx] += @@ -71,8 +71,8 @@ function update_statistics!(mdp::JointMDP{S,A}, tree::JointMCTSTree{S,A,MaxPlusS # Being more general to have unequal agent actions edge_comp = (e.src,e.dst) edge_tup = Tuple(1:length(tree.all_agent_actions[c]) for c in edge_comp) - edge_ac_idx = LinearIndices(edge_tup)[get_agent_actionindex(mdp, e.src, ucb_action[e.src]), - get_agent_actionindex(mdp, e.dst, ucb_action[e.dst])] + edge_ac_idx = LinearIndices(edge_tup)[agent_actionindex(mdp, e.src, ucb_action[e.src]), + agent_actionindex(mdp, e.dst, ucb_action[e.dst])] q_edge_value = q[e.src] + q[e.dst] lock(tree.lock) do @@ -88,8 +88,8 @@ function update_statistics!(mdp::JointMDP{S,A}, tree::JointMCTSTree{S,A,MaxPlusS end -function init_statistics!(tree::JointMCTSTree{S,A,MaxPlusStatistics{S}}, planner::JointMCTSPlanner, - s::AbstractVector{S}) where {S,A} +function init_statistics!(tree::FVMCTSTree{S,A,MaxPlusStatistics{S}}, planner::FVMCTSPlanner, + s::S) where {S,A} n_agents = length(s) @@ -136,7 +136,7 @@ end """ Runs Max-Plus at the current state using the per-state MaxPlusStatistics to compute the best joint action with either or both of node-wise and edge-wise exploration bonus. Rounds of message passing are followed by per-node maximization. 
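As a rough illustration of the message-passing core, consider a single edge between agents i and j, ignoring the exploration bonuses and message normalization (all names below are illustrative, not fields of this package):

    # messages along edge (i, j); q_i, q_j are per-agent utilities, q_ij the per-edge payoffs
    msg_i_to_j = [maximum(q_i[ai] + q_ij[ai, aj] for ai in 1:n_ai) for aj in 1:n_aj]
    msg_j_to_i = [maximum(q_j[aj] + q_ij[ai, aj] for aj in 1:n_aj) for ai in 1:n_ai]

    # after the message rounds, each agent maximizes its own utility plus incoming messages
    best_ai = argmax([q_i[ai] + msg_j_to_i[ai] for ai in 1:n_ai])
    best_aj = argmax([q_j[aj] + msg_i_to_j[aj] for aj in 1:n_aj])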
""" -function coordinate_action(mdp::JointMDP{S,A}, tree::JointMCTSTree{S,A,MaxPlusStatistics{S}}, s::AbstractVector{S}, +function coordinate_action(mdp::JointMDP{S,A}, tree::FVMCTSTree{S,A,MaxPlusStatistics{S}}, s::S, exploration_constant::Float64=0.0, node_id::Int64=0) where {S,A} state_stats = lock(tree.lock) do @@ -149,7 +149,7 @@ function coordinate_action(mdp::JointMDP{S,A}, tree::JointMCTSTree{S,A,MaxPlusSt message_norm = tree.coordination_stats.message_norm n_agents = length(s) - state_agent_actions = [get_agent_actions(mdp, i, si) for (i, si) in enumerate(s)] + state_agent_actions = [agent_actions(mdp, i, si) for (i, si) in enumerate(s)] n_all_actions = length(tree.all_agent_actions[1]) n_edges = ne(tree.coordination_stats.adjmatgraph) @@ -223,12 +223,12 @@ function coordinate_action(mdp::JointMDP{S,A}, tree::JointMCTSTree{S,A,MaxPlusSt exp_q_values = zeros(length(state_agent_actions[i])) if tree.coordination_stats.node_exploration for (idx, ai) in enumerate(state_agent_actions[i]) - ai_idx = get_agent_actionindex(mdp, i, ai) + ai_idx = agent_actionindex(mdp, i, ai) exp_q_values[idx] = q_values[i, ai_idx] + exploration_constant*sqrt((log(state_total_n + 1.0))/(state_stats.agent_action_n[i, ai_idx] + 1.0)) end else for (idx, ai) in enumerate(state_agent_actions[i]) - ai_idx = get_agent_actionindex(mdp, i, ai) + ai_idx = agent_actionindex(mdp, i, ai) exp_q_values[idx] = q_values[i, ai_idx] end end @@ -259,21 +259,21 @@ function perform_message_passing!(fwd_messages::AbstractArray{F,2}, bwd_messages # Need to look up global index of agent action and use that # Need to break up vectorized loop @inbounds for aj in state_agent_actions[j] - aj_idx = get_agent_actionindex(mdp, j, aj) + aj_idx = agent_actionindex(mdp, j, aj) fwd_message_vals = zeros(length(state_agent_actions[i])) # TODO: Should we use inbounds here again? 
@inbounds for (idx, ai) in enumerate(state_agent_actions[i]) - ai_idx = get_agent_actionindex(mdp, i, ai) + ai_idx = agent_actionindex(mdp, i, ai) fwd_message_vals[idx] = q_values[i, ai_idx] - bwd_messages_old[e_idx, ai_idx] + state_stats.edge_action_q[e_idx, edge_tup_indices[ai_idx, aj_idx]]/n_edges + exploration_constant * sqrt( (log(state_total_n + 1.0)) / (state_stats.edge_action_n[e_idx, edge_tup_indices[ai_idx, aj_idx]] + 1) ) end fwd_messages[e_idx, aj_idx] = maximum(fwd_message_vals) end @inbounds for ai in state_agent_actions[i] - ai_idx = get_agent_actionindex(mdp, i, ai) + ai_idx = agent_actionindex(mdp, i, ai) bwd_message_vals = zeros(length(state_agent_actions[j])) @inbounds for (idx, aj) in enumerate(state_agent_actions[j]) - aj_idx = get_agent_actionindex(mdp, j, aj) + aj_idx = agent_actionindex(mdp, j, aj) bwd_message_vals[idx] = q_values[j, aj_idx] - fwd_messages_old[e_idx, aj_idx] + state_stats.edge_action_q[e_idx, edge_tup_indices[ai_idx, aj_idx]]/n_edges + exploration_constant * sqrt( (log(state_total_n + 1.0))/ (state_stats.edge_action_n[e_idx, edge_tup_indices[ai_idx, aj_idx]] + 1) ) end bwd_messages[e_idx, ai_idx] = maximum(bwd_message_vals) diff --git a/src/fvmcts/action_coordination/varel.jl b/src/fvmcts/action_coordination/varel.jl index d9646a2..0946c32 100644 --- a/src/fvmcts/action_coordination/varel.jl +++ b/src/fvmcts/action_coordination/varel.jl @@ -18,8 +18,8 @@ Fields: mutable struct VarElStatistics{S} <: CoordinationStatistics coord_graph_components::Vector{Vector{Int64}} min_degree_ordering::Vector{Int64} - n_component_stats::Dict{AbstractVector{S},Vector{Vector{Int64}}} - q_component_stats::Dict{AbstractVector{S},Vector{Vector{Float64}}} + n_component_stats::Dict{S,Vector{Vector{Int64}}} + q_component_stats::Dict{S,Vector{Vector{Float64}}} end function clear_statistics!(ve_stats::VarElStatistics) @@ -32,7 +32,7 @@ end Runs variable elimination at the current state using the VarEl Statistics to compute the best joint action with the component-wise exploration bonus. FYI: Rather complicated. """ -function coordinate_action(mdp::JointMDP{S,A}, tree::JointMCTSTree{S,A,VarElStatistics{S}}, s::AbstractVector{S}, +function coordinate_action(mdp::JointMDP{S,A}, tree::FVMCTSTree{S,A,VarElStatistics{S}}, s::S, exploration_constant::Float64=0.0, node_id::Int64=0) where {S,A} n_agents = length(s) @@ -61,7 +61,7 @@ function coordinate_action(mdp::JointMDP{S,A}, tree::JointMCTSTree{S,A,VarElStat # E.g. Agent 2 -> (3,4) in its best response and corresponding vector of agent 2 best actions best_response_fns = Dict{Int64,Tuple{Vector{Int64},Vector{Int64}}}() - state_dep_actions = [get_agent_actions(mdp, i, si) for (i, si) in enumerate(s)] + state_dep_actions = [agent_actions(mdp, i, si) for (i, si) in enumerate(s)] # Iterate over variable ordering # Need to maintain intermediate tables @@ -213,8 +213,8 @@ end """ Take the q-value from the MCTS step and distribute the updates across the component q-stats as per the formula in the Amato-Oliehoek paper. """ -function update_statistics!(mdp::JointMDP{S,A}, tree::JointMCTSTree{S,A,VarElStatistics{S}}, - s::AbstractVector{S}, ucb_action::AbstractVector{A}, q::AbstractVector{Float64}) where {S,A} +function update_statistics!(mdp::JointMDP{S,A}, tree::FVMCTSTree{S,A,VarElStatistics{S}}, + s::S, ucb_action::A, q::AbstractVector{Float64}) where {S,A} n_agents = length(s) @@ -226,23 +226,51 @@ function update_statistics!(mdp::JointMDP{S,A}, tree::JointMCTSTree{S,A,VarElSta # RECOVER local action corresp. 
to ucb action # TODO: Review this carefully. Need @req for action index for agent. local_action = [ucb_action[c] for c in comp] - local_action_idxs = [get_agent_actionindex(mdp, c, a) for (a, c) in zip(local_action, comp)] + local_action_idxs = [agent_actionindex(mdp, c, a) for (a, c) in zip(local_action, comp)] comp_ac_idx = LinearIndices(comp_tup)[local_action_idxs...] # NOTE: NOW we can update stats. Could generalize incremental update more here lock(tree.lock) do tree.coordination_stats.n_component_stats[s][idx][comp_ac_idx] += 1 - q_comp_value = sum(q[c] for c in comp) + q_comp_value = sum(q[c] for c in comp) # TODO!!! tree.coordination_stats.q_component_stats[s][idx][comp_ac_idx] += (q_comp_value - tree.coordination_stats.q_component_stats[s][idx][comp_ac_idx]) / tree.coordination_stats.n_component_stats[s][idx][comp_ac_idx] end end end +# TODO: is this the correct thing to do? +# My guess is no, but not sure. +function update_statistics!(mdp::JointMDP{S,A}, tree::FVMCTSTree{S,A,VarElStatistics{S}}, + s::S, ucb_action::A, q::Float64) where {S, A} + n_agents = length(s) + + for (idx, comp) in enumerate(tree.coordination_stats.coord_graph_components) + + # Create cartesian index tuple + comp_tup = Tuple(1:length(tree.all_agent_actions[c]) for c in comp) + + # RECOVER local action corresp. to ucb action + # TODO: Review this carefully. Need @req for action index for agent. + local_action = [ucb_action[c] for c in comp] + local_action_idxs = [agent_actionindex(mdp, c, a) for (a, c) in zip(local_action, comp)] + + comp_ac_idx = LinearIndices(comp_tup)[local_action_idxs...] + + # NOTE: NOW we can update stats. Could generalize incremental update more here + lock(tree.lock) do + tree.coordination_stats.n_component_stats[s][idx][comp_ac_idx] += 1 + q_comp_value = q # TODO!!! 
+ tree.coordination_stats.q_component_stats[s][idx][comp_ac_idx] += + (q_comp_value - tree.coordination_stats.q_component_stats[s][idx][comp_ac_idx]) / tree.coordination_stats.n_component_stats[s][idx][comp_ac_idx] + end + end +end + -function init_statistics!(tree::JointMCTSTree{S,A,VarElStatistics{S}}, planner::JointMCTSPlanner, - s::AbstractVector{S}) where {S,A} +function init_statistics!(tree::FVMCTSTree{S,A,VarElStatistics{S}}, planner::FVMCTSPlanner, + s::S) where {S,A} n_comps = length(tree.coordination_stats.coord_graph_components) n_component_stats = Vector{Vector{Int64}}(undef, n_comps) diff --git a/src/fvmcts/factoredpolicy.jl b/src/fvmcts/factoredpolicy.jl index d1af7b9..d99c775 100644 --- a/src/fvmcts/factoredpolicy.jl +++ b/src/fvmcts/factoredpolicy.jl @@ -11,7 +11,7 @@ end FactoredRandomPolicy(problem::JointMDP; rng=Random.GLOBAL_RNG, updater=NothingUpdater()) = FactoredRandomPolicy(rng, problem, updater) function POMDPs.action(policy::FactoredRandomPolicy, s) - return [rand(policy.rng, get_agent_actions(policy.problem, i, si)) for (i, si) in enumerate(s)] + return [rand(policy.rng, agent_actions(policy.problem, i, si)) for (i, si) in enumerate(s)] end POMDPs.solve(solver::RandomSolver, problem::JointMDP) = FactoredRandomPolicy(solver.rng, problem, NothingUpdater()) \ No newline at end of file diff --git a/src/fvmcts/fv_mcts_vanilla.jl b/src/fvmcts/fv_mcts_vanilla.jl index f1b42c2..c71c6c9 100644 --- a/src/fvmcts/fv_mcts_vanilla.jl +++ b/src/fvmcts/fv_mcts_vanilla.jl @@ -7,7 +7,7 @@ abstract type AbstractCoordinationStrategy end struct VarEl <: AbstractCoordinationStrategy end -@with_kw struct MaxPlus <:AbstractCoordinationStrategy +Base.@kwdef struct MaxPlus <:AbstractCoordinationStrategy message_iters::Int64 = 10 message_norm::Bool = true use_agent_utils::Bool = false @@ -69,7 +69,7 @@ Fields: The specific strategy with which to compute the best joint action from the current MCTS statistics. default: VarEl() """ -@with_kw mutable struct FVMCTSSolver <: AbstractMCTSSolver +Base.@kwdef mutable struct FVMCTSSolver <: AbstractMCTSSolver n_iterations::Int64 = 100 max_time::Float64 = Inf depth::Int64 = 10 @@ -86,28 +86,28 @@ end mutable struct FVMCTSTree{S,A,CS<:CoordinationStatistics} # To map the multi-agent state vector to the ID of the node in the tree - state_map::Dict{AbstractVector{S},Int64} + state_map::Dict{S,Int64} # The next two vectors have one for each node ID in the tree total_n::Vector{Int} # The number of times the node has been tried - s_labels::Vector{AbstractVector{S}} # The state corresponding to the node ID + s_labels::Vector{S} # The state corresponding to the node ID # List of all individual actions of each agent for coordination purposes. 
- all_agent_actions::Vector{AbstractVector{A}} + all_agent_actions::Vector{A} coordination_stats::CS lock::ReentrantLock end -function FVMCTSTree(all_agent_actions::Vector{AbstractVector{A}}, +function FVMCTSTree(all_agent_actions::Vector{A}, coordination_stats::CS, - init_state::AbstractVector{S}, + init_state::S, lock::ReentrantLock, sz::Int64=10000) where {S, A, CS <: CoordinationStatistics} - return FVMCTSTree{S,A,CS}(Dict{typeof(init_state),Int64}(), + return FVMCTSTree{S,A,CS}(Dict{S,Int64}(), sizehint!(Int[], sz), - sizehint!(typeof(init_state)[], sz), + sizehint!(S[], sz), all_agent_actions, coordination_stats, lock @@ -145,25 +145,26 @@ Creates VarElStatistics internally with the CG components and the minimum degree """ function varel_joint_mcts_planner(solver::FVMCTSSolver, mdp::JointMDP{S,A}, - init_state::AbstractVector{S}, + init_state::S, ) where {S,A} # Get coordination graph components from maximal cliques - adjmat = coord_graph_adj_mat(mdp) - @assert size(adjmat)[1] == n_agents(mdp) "Adjacency Matrix does not match number of agents!" + #adjmat = coord_graph_adj_mat(mdp) + #@assert size(adjmat)[1] == n_agents(mdp) "Adjacency Matrix does not match number of agents!" - adjmatgraph = SimpleGraph(adjmat) + #adjmatgraph = SimpleGraph(adjmat) + adjmatgraph = coordination_graph(mdp) + coord_graph_components = maximal_cliques(adjmatgraph) min_degree_ordering = sortperm(degree(adjmatgraph)) # Initialize full agent actions - # TODO(jkg): this is incorrect? Or we need to override actiontype to refer to agent actions? - all_agent_actions = Vector{actiontype(mdp)}(undef, n_agents(mdp)) + all_agent_actions = Vector{(actiontype(mdp))}(undef, n_agents(mdp)) for i = 1:n_agents(mdp) - all_agent_actions[i] = get_agent_actions(mdp, i) + all_agent_actions[i] = agent_actions(mdp, i) end - ve_stats = VarElStatistics{eltype(init_state)}(coord_graph_components, min_degree_ordering, + ve_stats = VarElStatistics{S}(coord_graph_components, min_degree_ordering, Dict{typeof(init_state),Vector{Vector{Int64}}}(), Dict{typeof(init_state),Vector{Vector{Int64}}}(), ) @@ -182,7 +183,7 @@ Creates MaxPlusStatistics and assumes the various MP flags are sent down from th """ function maxplus_joint_mcts_planner(solver::FVMCTSSolver, mdp::JointMDP{S,A}, - init_state::AbstractVector{S}, + init_state::S, message_iters::Int64, message_norm::Bool, use_agent_utils::Bool, @@ -192,25 +193,27 @@ function maxplus_joint_mcts_planner(solver::FVMCTSSolver, @assert (node_exploration || edge_exploration) "At least one of nodes or edges should explore!" - adjmat = coord_graph_adj_mat(mdp) - @assert size(adjmat)[1] == n_agents(mdp) "Adjacency Mat does not match number of agents!" +#= adjmat = coord_graph_adj_mat(mdp) + @assert size(adjmat)[1] == n_agents(mdp) "Adjacency Mat does not match number of agents!" =# - adjmatgraph = SimpleGraph(adjmat) + #adjmatgraph = SimpleGraph(adjmat) + adjmatgraph = coordination_graph(mdp) + @assert size(adjacency_matrix(adjmatgraph))[1] == n_agents(mdp) # Initialize full agent actions # TODO(jkg): this is incorrect? Or we need to override actiontype to refer to agent actions? 
- all_agent_actions = Vector{actiontype(mdp)}(undef, n_agents(mdp)) + all_agent_actions = Vector{(actiontype(mdp))}(undef, n_agents(mdp)) for i = 1:n_agents(mdp) - all_agent_actions[i] = get_agent_actions(mdp, i) + all_agent_actions[i] = agent_actions(mdp, i) end - mp_stats = MaxPlusStatistics{eltype(init_state)}(adjmatgraph, + mp_stats = MaxPlusStatistics{S}(adjmatgraph, message_iters, message_norm, use_agent_utils, node_exploration, edge_exploration, - Dict{typeof(init_state),PerStateMPStats}()) + Dict{S,PerStateMPStats}()) # Create tree from the current state tree = FVMCTSTree(all_agent_actions, mp_stats, @@ -282,7 +285,7 @@ function plan!(planner::FVMCTSPlanner, s) end # build_tree can be called on the assumption that no reuse AND tree is reinitialized -function build_tree(planner::FVMCTSPlanner, s::AbstractVector{S}) where S +function build_tree(planner::FVMCTSPlanner, s::S) where S n_iterations = planner.solver.n_iterations depth = planner.solver.depth @@ -315,7 +318,7 @@ function simulate(planner::FVMCTSPlanner, node::FVStateNode, depth::Int64) ucb_action = coordinate_action(mdp, planner.tree, s, planner.solver.exploration_constant, node.id) # Monte Carlo Transition - sp, r = gen(DDNOut(:sp, :r), mdp, s, ucb_action, rng) + sp, r = @gen(:sp, :r)(mdp, s, ucb_action, rng) spid = lock(tree.lock) do get(tree.state_map, sp, 0) # may be non-zero even with no tree reuse @@ -342,12 +345,12 @@ function simulate(planner::FVMCTSPlanner, node::FVStateNode, depth::Int64) return q end -POMDPLinter.@POMDP_require simulate(planner::FVMCTSPlanner, s, depth::Int64) begin +@POMDP_require simulate(planner::FVMCTSPlanner, s, depth::Int64) begin mdp = planner.mdp P = typeof(mdp) @assert P <: JointMDP - SV = statetype(P) - @assert typeof(SV) <: AbstractVector + #SV = statetype(P) + #@assert typeof(SV) <: AbstractVector AV = actiontype(P) @assert typeof(A) <: AbstractVector @req discount(::P) @@ -357,10 +360,10 @@ POMDPLinter.@POMDP_require simulate(planner::FVMCTSPlanner, s, depth::Int64) beg @req gen(::DDNOut{(:sp, :r)}, ::P, ::SV, ::A, ::typeof(planner.rng)) ## Requirements from MMDP Model - @req get_agent_actions(::P, ::Int64) - @req get_agent_actions(::P, ::Int64, ::eltype(SV)) + @req agent_actions(::P, ::Int64) + @req agent_actions(::P, ::Int64, ::eltype(SV)) @req n_agents(::P) - @req coord_graph_adj_mat(::P) + @req coordination_graph(::P) # TODO: Should we also have this requirement for SV? 
@req isequal(::S, ::S) @@ -370,7 +373,7 @@ end function insert_node!(tree::FVMCTSTree{S,A,CS}, planner::FVMCTSPlanner, - s::AbstractVector{S}) where {S,A,CS <: CoordinationStatistics} + s::S) where {S,A,CS <: CoordinationStatistics} lock(tree.lock) do push!(tree.s_labels, s) @@ -388,23 +391,23 @@ function insert_node!(tree::FVMCTSTree{S,A,CS}, planner::FVMCTSPlanner, return FVStateNode(tree, ls) end -POMDPLinter.@POMDP_require insert_node!(tree::FVMCTSTree, planner::FVMCTSPlanner, s) begin +@POMDP_require insert_node!(tree::FVMCTSTree, planner::FVMCTSPlanner, s) begin P = typeof(planner.mdp) AV = actiontype(P) A = eltype(AV) SV = typeof(s) - S = eltype(SV) + #S = eltype(SV) # TODO: Review IQ and IN IQ = typeof(planner.solver.init_Q) if !(IQ <: Number) && !(IQ <: Function) - @req init_Q(::IQ, ::P, ::S, ::Vector{Int64}, ::AbstractVector{A}) + @req init_Q(::IQ, ::P, ::SV, ::Vector{Int64}, ::AbstractVector{A}) end IN = typeof(planner.solver.init_N) if !(IN <: Number) && !(IN <: Function) - @req init_N(::IQ, ::P, ::S, ::Vector{Int64}, ::AbstractVector{A}) + @req init_N(::IQ, ::P, ::SV, ::Vector{Int64}, ::AbstractVector{A}) end @req isequal(::S, ::S) From 7ee8e1ce1c9aba8711b4930d5a1e351819044113 Mon Sep 17 00:00:00 2001 From: rejuvyesh Date: Mon, 4 Jan 2021 03:13:19 +0000 Subject: [PATCH 08/17] allow scalar global rewards --- src/fvmcts/action_coordination/maxplus.jl | 12 +++++++++--- src/fvmcts/action_coordination/varel.jl | 4 ++-- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/src/fvmcts/action_coordination/maxplus.jl b/src/fvmcts/action_coordination/maxplus.jl index 4bf7244..7990e7a 100644 --- a/src/fvmcts/action_coordination/maxplus.jl +++ b/src/fvmcts/action_coordination/maxplus.jl @@ -46,6 +46,12 @@ function clear_statistics!(mp_stats::MaxPlusStatistics) empty!(mp_stats.all_states_stats) end +function update_statistics!(mdp::JointMDP{S,A}, tree::FVMCTSTree{S,A,MaxPlusStatistics{S}}, + s::S, ucb_action::A, q::AbstractFloat) where {S,A} + + update_statistics!(mdp, tree, s, ucb_action, ones(typeof(q), n_agents(mdp)) * q) +end + """ Take the q-value from the MCTS step and distribute the updates across the per-node and per-edge q-stats as per the formula in our paper. 
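Schematically, each per-agent entry gets a running-mean update, and each per-edge entry is updated the same way using q[src] + q[dst] (indexing simplified here for illustration):

    agent_action_n[i, a] += 1
    agent_action_q[i, a] += (q[i] - agent_action_q[i, a]) / agent_action_n[i, a]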
""" @@ -53,10 +59,10 @@ function update_statistics!(mdp::JointMDP{S,A}, tree::FVMCTSTree{S,A,MaxPlusStat s::S, ucb_action::A, q::AbstractVector{Float64}) where {S,A} state_stats = tree.coordination_stats.all_states_stats[s] - n_agents = length(s) + nagents = n_agents(mdp) # Update per agent action stats - for i = 1:n_agents + for i = 1:nagents ac_idx = agent_actionindex(mdp, i, ucb_action[i]) lock(tree.lock) do state_stats.agent_action_n[i, ac_idx] += 1 @@ -207,7 +213,6 @@ function coordinate_action(mdp::JointMDP{S,A}, tree::FVMCTSTree{S,A,MaxPlusStati end # for t = 1:k # If edge exploration flag enabled, do a final exploration bonus - # TODO: Code reuse with normal message passing; consider modularizing if tree.coordination_stats.edge_exploration perform_message_passing!(fwd_messages, bwd_messages, mdp, tree.all_agent_actions, adjgraphmat, state_agent_actions, n_edges, q_values, message_norm, @@ -279,6 +284,7 @@ function perform_message_passing!(fwd_messages::AbstractArray{F,2}, bwd_messages bwd_messages[e_idx, ai_idx] = maximum(bwd_message_vals) end + # Normalize messages for better convergence if message_norm @views fwd_messages[e_idx, :] .-= sum(fwd_messages[e_idx, :])/length(fwd_messages[e_idx, :]) @views bwd_messages[e_idx, :] .-= sum(bwd_messages[e_idx, :])/length(bwd_messages[e_idx, :]) diff --git a/src/fvmcts/action_coordination/varel.jl b/src/fvmcts/action_coordination/varel.jl index 0946c32..d2730ae 100644 --- a/src/fvmcts/action_coordination/varel.jl +++ b/src/fvmcts/action_coordination/varel.jl @@ -233,7 +233,7 @@ function update_statistics!(mdp::JointMDP{S,A}, tree::FVMCTSTree{S,A,VarElStatis # NOTE: NOW we can update stats. Could generalize incremental update more here lock(tree.lock) do tree.coordination_stats.n_component_stats[s][idx][comp_ac_idx] += 1 - q_comp_value = sum(q[c] for c in comp) # TODO!!! + q_comp_value = sum(q[c] for c in comp) tree.coordination_stats.q_component_stats[s][idx][comp_ac_idx] += (q_comp_value - tree.coordination_stats.q_component_stats[s][idx][comp_ac_idx]) / tree.coordination_stats.n_component_stats[s][idx][comp_ac_idx] end @@ -261,7 +261,7 @@ function update_statistics!(mdp::JointMDP{S,A}, tree::FVMCTSTree{S,A,VarElStatis # NOTE: NOW we can update stats. Could generalize incremental update more here lock(tree.lock) do tree.coordination_stats.n_component_stats[s][idx][comp_ac_idx] += 1 - q_comp_value = q # TODO!!! 
+ q_comp_value = q * length(comp) # Maintains equivalence with `sum(q[c] for c in comp)` tree.coordination_stats.q_component_stats[s][idx][comp_ac_idx] += (q_comp_value - tree.coordination_stats.q_component_stats[s][idx][comp_ac_idx]) / tree.coordination_stats.n_component_stats[s][idx][comp_ac_idx] end From 83897fe15cfaddbe56ac96f25082ed74c3be487d Mon Sep 17 00:00:00 2001 From: rejuvyesh Date: Mon, 4 Jan 2021 03:36:33 +0000 Subject: [PATCH 09/17] Get everything working --- Project.toml | 1 + src/MAMCTS.jl | 40 +++++++++++++++++++++++ src/fvmcts/action_coordination/maxplus.jl | 2 +- 3 files changed, 42 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 99b9304..9d69639 100644 --- a/Project.toml +++ b/Project.toml @@ -12,6 +12,7 @@ MultiAgentPOMDPs = "9ac5fd70-7902-42e7-9745-ff446b44e779" POMDPLinter = "f3bd98c0-eb40-45e2-9eb1-f2763262d755" POMDPModelTools = "08074719-1b2a-587c-a292-00f91cc44415" POMDPPolicies = "182e52fb-cfd0-5e46-8c26-fd0667c990f4" +POMDPSimulators = "e0d0a172-29c6-5d4e-96d0-f262df5d01fd" POMDPs = "a93abf59-7444-517b-a68a-c42f96afdd7d" Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" diff --git a/src/MAMCTS.jl b/src/MAMCTS.jl index fb23d85..c9f6086 100644 --- a/src/MAMCTS.jl +++ b/src/MAMCTS.jl @@ -14,6 +14,46 @@ using BeliefUpdaters using MCTS: convert_estimator import POMDPModelTools +using POMDPSimulators: RolloutSimulator +import POMDPs + +function POMDPs.simulate(sim::RolloutSimulator, mdp::JointMDP, policy::Policy, initialstate::S) where {S} + + if sim.eps == nothing + eps = 0.0 + else + eps = sim.eps + end + + if sim.max_steps == nothing + max_steps = typemax(Int) + else + max_steps = sim.max_steps + end + + s = initialstate + + disc = 1.0 + r_total = zeros(n_agents(mdp)) + step = 1 + + while disc > eps && !isterminal(mdp, s) && step <= max_steps + a = action(policy, s) + + sp, r = @gen(:sp, :r)(mdp, s, a, sim.rng) + + r_total .+= disc.*r + + s = sp + + disc *= discount(mdp) + step += 1 + end + + return r_total +end + + ### # Factored Value MCTS # diff --git a/src/fvmcts/action_coordination/maxplus.jl b/src/fvmcts/action_coordination/maxplus.jl index 7990e7a..40fe54e 100644 --- a/src/fvmcts/action_coordination/maxplus.jl +++ b/src/fvmcts/action_coordination/maxplus.jl @@ -221,7 +221,7 @@ function coordinate_action(mdp::JointMDP{S,A}, tree::FVMCTSTree{S,A,MaxPlusStati # Maximize q values for agents - best_action = Vector{A}(undef, n_agents) + best_action = Vector{eltype(A)}(undef, n_agents) for i = 1:n_agents # NOTE: Again can't just iterate over agent actions as it may be a subset From 4a91e275880d53bf52b948d87f4a8db0b817ddf0 Mon Sep 17 00:00:00 2001 From: rejuvyesh Date: Mon, 4 Jan 2021 03:36:45 +0000 Subject: [PATCH 10/17] add tests Won't work until other stuff registered first --- test/Project.toml | 3 +++ test/runtests.jl | 64 ++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 66 insertions(+), 1 deletion(-) create mode 100644 test/Project.toml diff --git a/test/Project.toml b/test/Project.toml new file mode 100644 index 0000000..72f58eb --- /dev/null +++ b/test/Project.toml @@ -0,0 +1,3 @@ +[deps] +MultiAgentSysAdmin = "a4538d8c-5052-4f30-aec9-286910cf67a1" +MultiUAVDelivery = "13c59af0-a5df-4589-8a68-75dc1bf2d35a" diff --git a/test/runtests.jl b/test/runtests.jl index 362e9bc..0e47049 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,68 @@ using MAMCTS using Test +using POMDPs +using MultiAgentSysAdmin +using MultiUAVDelivery + @testset "MAMCTS.jl" begin - # 
Write your tests here. + + @testset "varel" begin + @testset "sysadmin" begin + + @testset "local" begin + mdp = BiSysAdmin() + solver = FVMCTSSolver() + planner = solve(solver, mdp) + s = rand(initialstate(mdp)) + a = action(planner, s) + @test a isa actiontype(mdp) + end + + + @testset "global" begin + mdp = BiSysAdmin(;global_rewards=true) + solver = FVMCTSSolver() + planner = solve(solver, mdp) + s = rand(initialstate(mdp)) + a = action(planner, s) + @test a isa actiontype(mdp) + end + end + end + + @testset "maxplus" begin + @testset "sysadmin" begin + + @testset "local" begin + mdp = BiSysAdmin() + solver = FVMCTSSolver(;coordination_strategy=MaxPlus()) + planner = solve(solver, mdp) + s = rand(initialstate(mdp)) + a = action(planner, s) + @test a isa actiontype(mdp) + end + + + @testset "global" begin + mdp = BiSysAdmin(;global_rewards=true) + solver = FVMCTSSolver(;coordination_strategy=MaxPlus()) + planner = solve(solver, mdp) + s = rand(initialstate(mdp)) + a = action(planner, s) + @test a isa actiontype(mdp) + end + end + + @testset "uav" begin + mdp = FirstOrderMultiUAVDelivery() + solver = FVMCTSSolver(;coordination_strategy=MaxPlus()) + planner = solve(solver, mdp) + s = rand(initialstate(mdp)) + a = action(planner, s) + @test a isa actiontype(mdp) + end + + end + end From 6a5962628f01e49d6918220716f717875fb37db8 Mon Sep 17 00:00:00 2001 From: rejuvyesh Date: Mon, 4 Jan 2021 03:50:51 +0000 Subject: [PATCH 11/17] ugly monkeypatch for RolloutSimulator --- src/MAMCTS.jl | 37 +++++++++++++++++++++++++++++++++---- 1 file changed, 33 insertions(+), 4 deletions(-) diff --git a/src/MAMCTS.jl b/src/MAMCTS.jl index c9f6086..55ad847 100644 --- a/src/MAMCTS.jl +++ b/src/MAMCTS.jl @@ -17,6 +17,7 @@ import POMDPModelTools using POMDPSimulators: RolloutSimulator import POMDPs +# Patch simulate to support vector of rewards function POMDPs.simulate(sim::RolloutSimulator, mdp::JointMDP, policy::Policy, initialstate::S) where {S} if sim.eps == nothing @@ -33,10 +34,21 @@ function POMDPs.simulate(sim::RolloutSimulator, mdp::JointMDP, policy::Policy, i s = initialstate - disc = 1.0 - r_total = zeros(n_agents(mdp)) - step = 1 + + # TODO: doesn't this add unnecessary action search? 
+ r = @gen(:r)(mdp, s, action(policy, s), sim.rng) + if r isa AbstractVector + r_total = zeros(n_agents(mdp)) + else + r_total = 0.0 + end + sim_helper!(r_total, sim, mdp, policy, s, max_steps, eps) + return r_total +end +function sim_helper!(r_total::AbstractVector{F}, sim, mdp, policy, s, max_steps, eps) where {F} + step = 1 + disc = 1.0 while disc > eps && !isterminal(mdp, s) && step <= max_steps a = action(policy, s) @@ -50,7 +62,24 @@ function POMDPs.simulate(sim::RolloutSimulator, mdp::JointMDP, policy::Policy, i step += 1 end - return r_total +end + +function sim_helper!(r_total::AbstractFloat, sim, mdp, policy, s, max_steps, eps) + step = 1 + disc = 1.0 + while disc > eps && !isterminal(mdp, s) && step <= max_steps + a = action(policy, s) + + sp, r = @gen(:sp, :r)(mdp, s, a, sim.rng) + + r_total += disc.*r + + s = sp + + disc *= discount(mdp) + step += 1 + end + end From 0ab8bd745b11ef96a1890418fe9ce2c98c8e887b Mon Sep 17 00:00:00 2001 From: rejuvyesh Date: Mon, 4 Jan 2021 17:52:40 +0000 Subject: [PATCH 12/17] small fixes as requested for `@req` --- src/fcmcts/fcmcts.jl | 4 ++-- src/fvmcts/fv_mcts_vanilla.jl | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/fcmcts/fcmcts.jl b/src/fcmcts/fcmcts.jl index 61f463a..8fdc5b5 100644 --- a/src/fcmcts/fcmcts.jl +++ b/src/fcmcts/fcmcts.jl @@ -263,7 +263,7 @@ end mdp = planner.mdp P = typeof(mdp) @assert P <: JointMDP # req does different thing? - #SV = statetype(P) + SV = statetype(P) #@assert typeof(SV) <: AbstractVector # TODO: Is this correct? AV = actiontype(P) @assert typeof(AV) <: AbstractVector @@ -271,7 +271,7 @@ end @req isterminal(::P, ::SV) @subreq insert_node!(planner.tree, planner, s) @subreq estimate_value(planner.solved_estimate, mdp, s, depth) - @req gen(::DDNOut{(:sp, :r)}, ::P, ::SV, ::AV, ::typeof(planner.rng)) + @req gen(::P, ::SV, ::AV, ::typeof(planner.rng)) # XXX this is not exactly right - it could be satisfied with transition # MMDP reqs @req agent_actions(::P, ::Int64) diff --git a/src/fvmcts/fv_mcts_vanilla.jl b/src/fvmcts/fv_mcts_vanilla.jl index c71c6c9..107c1bd 100644 --- a/src/fvmcts/fv_mcts_vanilla.jl +++ b/src/fvmcts/fv_mcts_vanilla.jl @@ -349,7 +349,7 @@ end mdp = planner.mdp P = typeof(mdp) @assert P <: JointMDP - #SV = statetype(P) + SV = statetype(P) #@assert typeof(SV) <: AbstractVector AV = actiontype(P) @assert typeof(A) <: AbstractVector @@ -357,7 +357,7 @@ end @req isterminal(::P, ::SV) @subreq insert_node!(planner.tree, planner, s) @subreq estimate_value(planner.solved_estimate, mdp, s, depth) - @req gen(::DDNOut{(:sp, :r)}, ::P, ::SV, ::A, ::typeof(planner.rng)) + @req gen(::P, ::SV, ::A, ::typeof(planner.rng)) # XXX this is not exactly right - it could be satisfied with transition ## Requirements from MMDP Model @req agent_actions(::P, ::Int64) From b0249d447fc04552e3d967897c5acc911894cc50 Mon Sep 17 00:00:00 2001 From: rejuvyesh Date: Mon, 4 Jan 2021 21:30:51 +0000 Subject: [PATCH 13/17] require state space to be iterable --- src/fcmcts/fcmcts.jl | 1 + src/fvmcts/fv_mcts_vanilla.jl | 1 + 2 files changed, 2 insertions(+) diff --git a/src/fcmcts/fcmcts.jl b/src/fcmcts/fcmcts.jl index 8fdc5b5..be666ba 100644 --- a/src/fcmcts/fcmcts.jl +++ b/src/fcmcts/fcmcts.jl @@ -264,6 +264,7 @@ end P = typeof(mdp) @assert P <: JointMDP # req does different thing? SV = statetype(P) + @req iterate(::SV) #@assert typeof(SV) <: AbstractVector # TODO: Is this correct? 
AV = actiontype(P) @assert typeof(AV) <: AbstractVector diff --git a/src/fvmcts/fv_mcts_vanilla.jl b/src/fvmcts/fv_mcts_vanilla.jl index 107c1bd..160be3f 100644 --- a/src/fvmcts/fv_mcts_vanilla.jl +++ b/src/fvmcts/fv_mcts_vanilla.jl @@ -350,6 +350,7 @@ end P = typeof(mdp) @assert P <: JointMDP SV = statetype(P) + @req iterate(::SV) #@assert typeof(SV) <: AbstractVector AV = actiontype(P) @assert typeof(A) <: AbstractVector From c0753db7f862efb3ac60ab009c44c644d4488ecd Mon Sep 17 00:00:00 2001 From: rejuvyesh Date: Tue, 5 Jan 2021 04:21:08 +0000 Subject: [PATCH 14/17] Rename project, only have FactoredValueMCTS --- .github/workflows/ci.yml | 4 +- Project.toml | 2 +- README.md | 10 +- docs/Manifest.toml | 2 +- docs/Project.toml | 2 +- docs/make.jl | 12 +- docs/src/index.md | 6 +- src/{MAMCTS.jl => FactoredValueMCTS.jl} | 13 +- src/fcmcts/fcmcts.jl | 309 ------------------------ test/runtests.jl | 4 +- 10 files changed, 22 insertions(+), 342 deletions(-) rename src/{MAMCTS.jl => FactoredValueMCTS.jl} (93%) delete mode 100644 src/fcmcts/fcmcts.jl diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a95395e..d3a74a6 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -60,8 +60,8 @@ jobs: - run: | julia --project=docs -e ' using Documenter: doctest - using MAMCTS - doctest(MAMCTS)' + using FactoredValueMCTS + doctest(FactoredValueMCTS)' - run: julia --project=docs docs/make.jl env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/Project.toml b/Project.toml index 9d69639..6a0e875 100644 --- a/Project.toml +++ b/Project.toml @@ -1,4 +1,4 @@ -name = "MAMCTS" +name = "FactoredValueMCTS" uuid = "c016a6d7-1193-47d7-896a-d9f14d6b4b26" authors = ["Stanford Intelligent Systems Laboratory"] version = "0.1.0" diff --git a/README.md b/README.md index 629c498..f80f1c2 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ -# MAMCTS +# FactoredValueMCTS -[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://rejuvyesh.github.io/MAMCTS.jl/stable) -[![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://rejuvyesh.github.io/MAMCTS.jl/dev) -[![Build Status](https://github.com/rejuvyesh/MAMCTS.jl/workflows/CI/badge.svg)](https://github.com/rejuvyesh/MAMCTS.jl/actions) -[![Coverage](https://codecov.io/gh/rejuvyesh/MAMCTS.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/rejuvyesh/MAMCTS.jl) +[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://JuliaPOMDP.github.io/FactoredValueMCTS.jl/stable) +[![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://JuliaPOMDP.github.io/FactoredValueMCTS.jl/dev) +[![Build Status](https://github.com/JuliaPOMDP/FactoredValueMCTS.jl/workflows/CI/badge.svg)](https://github.com/JuliaPOMDP/FactoredValueMCTS.jl/actions) +[![Coverage](https://codecov.io/gh/JuliaPOMDP/FactoredValueMCTS.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/JuliaPOMDP/FactoredValueMCTS.jl) diff --git a/docs/Manifest.toml b/docs/Manifest.toml index 4053bdb..2b3e7d3 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -49,7 +49,7 @@ uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" [[Logging]] uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" -[[MAMCTS]] +[[FactoredValueMCTS]] path = ".." 
uuid = "c016a6d7-1193-47d7-896a-d9f14d6b4b26" version = "0.1.0" diff --git a/docs/Project.toml b/docs/Project.toml index c646ea9..cfa3359 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,3 +1,3 @@ [deps] Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -MAMCTS = "c016a6d7-1193-47d7-896a-d9f14d6b4b26" +FactoredValueMCTS = "c016a6d7-1193-47d7-896a-d9f14d6b4b26" diff --git a/docs/make.jl b/docs/make.jl index 4350303..4245289 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,14 +1,14 @@ -using MAMCTS +using FactoredValueMCTS using Documenter makedocs(; - modules=[MAMCTS], + modules=[FactoredValueMCTS], authors="Stanford Intelligent Systems Laboratory", - repo="https://github.com/rejuvyesh/MAMCTS.jl/blob/{commit}{path}#L{line}", - sitename="MAMCTS.jl", + repo="https://github.com/JuliaPOMDP/FactoredValueMCTS.jl/blob/{commit}{path}#L{line}", + sitename="FactoredValueMCTS.jl", format=Documenter.HTML(; prettyurls=get(ENV, "CI", "false") == "true", - canonical="https://rejuvyesh.github.io/MAMCTS.jl", + canonical="https://JuliaPOMDP.github.io/FactoredValueMCTS.jl", assets=String[], ), pages=[ @@ -17,5 +17,5 @@ makedocs(; ) deploydocs(; - repo="github.com/rejuvyesh/MAMCTS.jl", + repo="github.com/JuliaPOMDP/FactoredValueMCTS.jl", ) diff --git a/docs/src/index.md b/docs/src/index.md index f63d0e5..11deaaa 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -1,12 +1,12 @@ ```@meta -CurrentModule = MAMCTS +CurrentModule = FactoredValueMCTS ``` -# MAMCTS +# FactoredValueMCTS ```@index ``` ```@autodocs -Modules = [MAMCTS] +Modules = [FactoredValueMCTS] ``` diff --git a/src/MAMCTS.jl b/src/FactoredValueMCTS.jl similarity index 93% rename from src/MAMCTS.jl rename to src/FactoredValueMCTS.jl index 55ad847..f37aaf5 100644 --- a/src/MAMCTS.jl +++ b/src/FactoredValueMCTS.jl @@ -1,4 +1,4 @@ -module MAMCTS +module FactoredValueMCTS using Random using LinearAlgebra @@ -99,16 +99,5 @@ export MaxPlus, VarEl -### - -### -# Naive Fully Connected Centralized MCTS -# - -include(joinpath("fcmcts", "fcmcts.jl")) -export - FCMCTSSolver - -### end diff --git a/src/fcmcts/fcmcts.jl b/src/fcmcts/fcmcts.jl deleted file mode 100644 index be666ba..0000000 --- a/src/fcmcts/fcmcts.jl +++ /dev/null @@ -1,309 +0,0 @@ - - -Base.@kwdef mutable struct FCMCTSSolver <: AbstractMCTSSolver - n_iterations::Int64 = 100 - max_time::Float64 = Inf - depth::Int64 = 10 - exploration_constant::Float64 = 1.0 - rng::AbstractRNG = Random.GLOBAL_RNG - estimate_value::Any = RolloutEstimator(RandomSolver(rng)) - init_Q::Any = 0.0 - init_N::Any = 0 - reuse_tree::Bool = false -end - -mutable struct FCMCTSTree{S,A} - # To track if state node in tree already - # NOTE: We don't strictly need this at all if no tree reuse... - state_map::Dict{S,Int64} - - # these vectors have one entry for each state node - # Only doing factored satistics (for actions), not state components - child_ids::Vector{Vector{Int}} - total_n::Vector{Int} - s_labels::Vector{S} - - # TODO(jkg): is this the best way to track stats? 
- # these vectors have one entry for each action node - n::Vector{Int64} - q::Vector{Float64} - a_labels::Vector{A} - - lock::ReentrantLock -end - -function FCMCTSTree{S,A}(init_state::S, lock::ReentrantLock, sz::Int=1000) where {S,A} - sz = min(sz, 100_000) - return FCMCTSTree{S,A}(Dict{S,Int64}(), - sizehint!(Vector{Int}[], sz), - sizehint!(Int[], sz), - sizehint!(typeof(init_state)[], sz), - Int64[], - Float64[], - sizehint!(Vector{A}[], sz), - lock) -end - -Base.isempty(t::FCMCTSTree) = isempty(t.state_map) -state_nodes(t::FCMCTSTree) = (FCStateNode(t, id) for id in 1:length(t.total_n)) - -struct FCStateNode{S,A} - tree::FCMCTSTree{S,A} - id::Int64 -end - -# accessors for state nodes -@inline state(n::FCStateNode) = lock(n.tree.lock) do - n.tree.s_labels[n.id] -end -@inline total_n(n::FCStateNode) = n.tree.total_n[n.id] -@inline children(n::FCStateNode) = (FCActionNode(n.tree, id) for id in n.tree.child_ids[n.id]) - -# Adding action node info -struct FCActionNode{S,A} - tree::FCMCTSTree{S,A} - id::Int64 -end - -# accessors for action nodes -@inline POMDPs.action(n::FCActionNode) = n.tree.a_labels[n.id] - - -mutable struct FCMCTSPlanner{S,A,SE,RNG<:AbstractRNG} <: AbstractMCTSPlanner{JointMDP{S,A}} - solver::FCMCTSSolver - mdp::JointMDP{S,A} - tree::FCMCTSTree{S,A} - solved_estimate::SE - rng::RNG -end - -function FCMCTSPlanner(solver::FCMCTSSolver, mdp::JointMDP{S,A}) where {S,A} - init_state = initialstate(mdp, solver.rng) - tree = FCMCTSTree{S,A}(init_state, ReentrantLock(), solver.n_iterations) - se = convert_estimator(solver.estimate_value, solver, mdp) - return FCMCTSPlanner(solver, mdp, tree, se, solver.rng) -end - - -function clear_tree!(planner::FCMCTSPlanner) - lock(planner.tree.lock) do - # Clear out state hash dict entirely - empty!(planner.tree.state_map) - # Empty state vectors with state hints - sz = min(planner.solver.n_iterations, 100_000) - - empty!(planner.tree.s_labels) - sizehint!(planner.tree.s_labels, sz) - - empty!(planner.tree.child_ids) - sizehint!(planner.tree.child_ids, sz) - empty!(planner.tree.total_n) - sizehint!(planner.tree.total_n, sz) - - empty!(planner.tree.n) - empty!(planner.tree.q) - empty!(planner.tree.a_labels) - end -end - -function POMDPs.solve(solver::FCMCTSSolver, mdp::JointMDP) - return FCMCTSPlanner(solver, mdp) -end - -function POMDPs.action(planner::FCMCTSPlanner, s) - clear_tree!(planner) - plan!(planner, s) - s_lut = lock(planner.tree.lock) do - planner.tree.state_map[s] - end - best_anode = lock(planner.tree.lock) do - compute_best_action_node(planner.mdp, planner.tree, FCStateNode(planner.tree, s_lut)) # c = 0.0 by default - end - - best_a = lock(planner.tree.lock) do - action(best_anode) - end - return best_a -end - -function POMDPModelTools.action_info(planner::FCMCTSPlanner, s) - a = POMDPs.action(planner, s) - return a, nothing -end - - -function plan!(planner::FCMCTSPlanner, s) - planner.tree = build_tree(planner, s) -end - -function build_tree(planner::FCMCTSPlanner, s::AbstractVector{S}) where S - n_iterations = planner.solver.n_iterations - depth = planner.solver.depth - - root = insert_node!(planner.tree, planner, s) - # build the tree - @sync for n = 1:n_iterations - @spawn simulate(planner, root, depth) - end - return planner.tree -end - -function simulate(planner::FCMCTSPlanner, node::FCStateNode, depth::Int64) - - mdp = planner.mdp - rng = planner.rng - s = state(node) - tree = node.tree - - # once depth is zero return - if isterminal(planner.mdp, s) - return 0.0 - elseif depth == 0 - return 
sum(estimate_value(planner.solved_estimate, planner.mdp, s, depth)) - end - - # Choose best UCB action (NOT an action node) - ucb_action_node = lock(planner.tree.lock) do - compute_best_action_node(mdp, planner.tree, node, planner.solver.exploration_constant) - end - ucb_action = lock(planner.tree.lock) do - action(ucb_action_node) - end - - # @show ucb_action - # MC Transition - sp, r = @gen(:sp, :r)(mdp, s, ucb_action, rng) - - # NOTE(jkg): just summing up the rewards to get a single value - # TODO: should we divide by n_agents? - r = sum(r) - - spid = lock(tree.lock) do - get(tree.state_map, sp, 0) # may be non-zero even with no tree reuse - end - if spid == 0 - spn = insert_node!(tree, planner, sp) - spid = spn.id - # TODO estimate_value - # NOTE(jkg): again just summing up the values to get a single value - q = r + discount(mdp) * sum(estimate_value(planner.solved_estimate, planner.mdp, sp, depth - 1)) - else - q = r + discount(mdp) * simulate(planner, FCStateNode(tree, spid) , depth - 1) - end - - ## Not bothering with tree vis right now - # Augment N(s) - lock(tree.lock) do - tree.total_n[node.id] += 1 - tree.n[ucb_action_node.id] += 1 - tree.q[ucb_action_node.id] += (q - tree.q[ucb_action_node.id]) / tree.n[ucb_action_node.id] - end - - return q -end - -# NOTE: This is a bit different from https://github.com/JuliaPOMDP/MCTS.jl/blob/master/src/vanilla.jl#L328 -function insert_node!(tree::FCMCTSTree, planner::FCMCTSPlanner, s) - - lock(tree.lock) do - push!(tree.s_labels, s) - tree.state_map[s] = length(tree.s_labels) - push!(tree.child_ids, []) - end - - # NOTE: Doing state-dep actions here the JointMDP way - state_dep_jtactions = vec(map(collect, Iterators.product((agent_actions(planner.mdp, i, si) for (i, si) in enumerate(s))...))) - total_n = 0 - - for a in state_dep_jtactions - n = init_N(planner.solver.init_N, planner.mdp, s, a) - total_n += n - lock(tree.lock) do - push!(tree.n, n) - push!(tree.q, init_Q(planner.solver.init_Q, planner.mdp, s, a)) - push!(tree.a_labels, a) - push!(last(tree.child_ids), length(tree.n)) - end - end - lock(tree.lock) do - push!(tree.total_n, total_n) - end - ln = lock(tree.lock) do - length(tree.total_n) - end - return FCStateNode(tree, ln) -end - - - -# NOTE: The logic here is a bit simpler than https://github.com/JuliaPOMDP/MCTS.jl/blob/master/src/vanilla.jl#L390 -# Double check that this is still the behavior we want -function compute_best_action_node(mdp::JointMDP, tree::FCMCTSTree, node::FCStateNode, c::Float64=0.0) - best_val = -Inf # The Q value - best = first(children(node)) - - sn = total_n(node) - - child_nodes = children(node) - - for sanode in child_nodes - - val = tree.q[sanode.id] + c*sqrt(log(sn + 1)/ (tree.n[sanode.id] + 1)) - - - if val > best_val - best_val = val - best = sanode - end - end - return best -end - -@POMDP_require simulate(planner::FCMCTSPlanner, s, depth::Int64) begin - mdp = planner.mdp - P = typeof(mdp) - @assert P <: JointMDP # req does different thing? - SV = statetype(P) - @req iterate(::SV) - #@assert typeof(SV) <: AbstractVector # TODO: Is this correct? 
- AV = actiontype(P) - @assert typeof(AV) <: AbstractVector - @req discount(::P) - @req isterminal(::P, ::SV) - @subreq insert_node!(planner.tree, planner, s) - @subreq estimate_value(planner.solved_estimate, mdp, s, depth) - @req gen(::P, ::SV, ::AV, ::typeof(planner.rng)) # XXX this is not exactly right - it could be satisfied with transition - - # MMDP reqs - @req agent_actions(::P, ::Int64) - @req agent_actions(::P, ::Int64, ::eltype(SV)) # TODO should this be eltype? - @req n_agents(::P) - - # TODO: Should we also have this requirement for SV? - @req isequal(::S, ::S) - @req hash(::S) -end - -@POMDP_require insert_node!(tree::FCMCTSTree, planner::FCMCTSPlanner, s) begin - - P = typeof(planner.mdp) - AV = actiontype(P) - A = eltype(AV) - SV = typeof(s) - S = eltype(SV) - - # TODO: Review IQ and IN - # Should this be ::S or ::SV? We can have global state that's not a vector. - IQ = typeof(planner.solver.init_Q) - if !(IQ <: Number) && !(IQ <: Function) - @req init_Q(::IQ, ::P, ::S, ::Vector{Int64}, ::AbstractVector{A}) - end - - IN = typeof(planner.solver.init_N) - if !(IN <: Number) && !(IN <: Function) - @req init_N(::IQ, ::P, ::S, ::Vector{Int64}, ::AbstractVector{A}) - end - - @req isequal(::S, ::S) - @req hash(::S) -end diff --git a/test/runtests.jl b/test/runtests.jl index 0e47049..80d9e48 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,11 +1,11 @@ -using MAMCTS +using FactoredValueMCTS using Test using POMDPs using MultiAgentSysAdmin using MultiUAVDelivery -@testset "MAMCTS.jl" begin +@testset "FactoredValueMCTS.jl" begin @testset "varel" begin @testset "sysadmin" begin From 8b7098dbdbfe7bceda5ea26814a9c4ec741660fc Mon Sep 17 00:00:00 2001 From: rejuvyesh Date: Tue, 5 Jan 2021 04:25:58 +0000 Subject: [PATCH 15/17] fix a few typos --- src/fvmcts/fv_mcts_vanilla.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/fvmcts/fv_mcts_vanilla.jl b/src/fvmcts/fv_mcts_vanilla.jl index 160be3f..ad7810b 100644 --- a/src/fvmcts/fv_mcts_vanilla.jl +++ b/src/fvmcts/fv_mcts_vanilla.jl @@ -353,12 +353,12 @@ end @req iterate(::SV) #@assert typeof(SV) <: AbstractVector AV = actiontype(P) - @assert typeof(A) <: AbstractVector + @assert typeof(AV) <: AbstractVector @req discount(::P) @req isterminal(::P, ::SV) @subreq insert_node!(planner.tree, planner, s) @subreq estimate_value(planner.solved_estimate, mdp, s, depth) - @req gen(::P, ::SV, ::A, ::typeof(planner.rng)) # XXX this is not exactly right - it could be satisfied with transition + @req gen(::P, ::SV, ::AV, ::typeof(planner.rng)) # XXX this is not exactly right - it could be satisfied with transition ## Requirements from MMDP Model @req agent_actions(::P, ::Int64) From 5a6cdf3b41a8eb088ec5fee2e65e7907b48f98a8 Mon Sep 17 00:00:00 2001 From: rejuvyesh Date: Tue, 5 Jan 2021 04:58:24 +0000 Subject: [PATCH 16/17] make linter happy --- src/FactoredValueMCTS.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/FactoredValueMCTS.jl b/src/FactoredValueMCTS.jl index f37aaf5..97a9ad5 100644 --- a/src/FactoredValueMCTS.jl +++ b/src/FactoredValueMCTS.jl @@ -20,13 +20,13 @@ import POMDPs # Patch simulate to support vector of rewards function POMDPs.simulate(sim::RolloutSimulator, mdp::JointMDP, policy::Policy, initialstate::S) where {S} - if sim.eps == nothing + if sim.eps === nothing eps = 0.0 else eps = sim.eps end - if sim.max_steps == nothing + if sim.max_steps === nothing max_steps = typemax(Int) else max_steps = sim.max_steps From c973fb4879c367e35f85250c38a1329a0b14a3d0 Mon Sep 17 00:00:00 
2001
From: rejuvyesh
Date: Tue, 5 Jan 2021 17:53:00 +0000
Subject: [PATCH 17/17] Add citation for our paper

---
 CITATION.bib | 7 +++++++
 1 file changed, 7 insertions(+)
 create mode 100644 CITATION.bib

diff --git a/CITATION.bib b/CITATION.bib
new file mode 100644
index 0000000..3e9f805
--- /dev/null
+++ b/CITATION.bib
@@ -0,0 +1,7 @@
+@inproceedings{choudhury2021scalable,
+  title={Scalable Anytime Planning for Multi-Agent {MDP}s},
+  author={Choudhury, Shushman and Gupta, Jayesh K and Morales, Peter and Kochenderfer, Mykel},
+  booktitle={International Conference on Autonomous Agents and Multiagent Systems (AAMAS)},
+  year={2021},
+  organization={IFAAMAS}
+}
\ No newline at end of file
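
Note on the `simulate` split in PATCH 11/17: the two `sim_helper!` methods use Julia's multiple dispatch to pick a per-agent (vector) or team (scalar) reward accumulator. The sketch below is a stripped-down, self-contained illustration of that pattern, not the package code; the names `accumulate_reward` and `rollout_total` are invented for this example. It also shows why each method has to hand the accumulated value back to the caller: `r_total += ...` rebinds a local variable, so neither a `Float64` nor a freshly allocated vector would otherwise reach the caller.

    # Stripped-down illustration (not package code) of dispatching on the reward type:
    # per-agent rewards arrive as vectors, team rewards as plain floats.

    # Vector rewards: accumulate elementwise and return the rebound result.
    function accumulate_reward(r_total::AbstractVector, r, disc)
        return r_total .+ disc .* r
    end

    # Scalar rewards: Float64 is immutable, so the updated value must be returned.
    function accumulate_reward(r_total::AbstractFloat, r, disc)
        return r_total + disc * r
    end

    function rollout_total(rewards; discount=0.95)
        # Pick the accumulator type from the first reward, as the patch does.
        r_total = first(rewards) isa AbstractVector ? zeros(length(first(rewards))) : 0.0
        disc = 1.0
        for r in rewards
            r_total = accumulate_reward(r_total, r, disc)  # rebind with the returned value
            disc *= discount
        end
        return r_total
    end

    rollout_total([1.0, 0.5, 0.25])                      # team reward       -> 1.700625
    rollout_total([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])  # per-agent rewards -> [1.9025, 1.8525]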
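
Note on PATCH 13/17 (require state space to be iterable): the planners walk per-agent state components with `enumerate(s)`, for example in `insert_node!`, so a joint state only needs to support iteration rather than being an `AbstractVector`. Below is a minimal sketch of what that means for a custom, non-array joint state type; `DroneState` and `JointDroneState` are invented names, not types from this series.

    # Sketch of a non-array joint state made iterable so it satisfies @req iterate(::SV).
    # DroneState and JointDroneState are invented names for illustration only.
    struct DroneState
        x::Int
        y::Int
    end

    struct JointDroneState
        agents::Vector{DroneState}
    end

    Base.iterate(s::JointDroneState, i=1) =
        i > length(s.agents) ? nothing : (s.agents[i], i + 1)
    Base.length(s::JointDroneState) = length(s.agents)

    js = JointDroneState([DroneState(0, 0), DroneState(3, 1)])
    for (i, si) in enumerate(js)   # the per-agent traversal the requirement enables
        println((i, si))
    end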
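
Note on PATCH 16/17 (make linter happy): the `sim.eps === nothing` and `sim.max_steps === nothing` checks use identity comparison because `===` always returns a `Bool` and cannot be overloaded, while `==` participates in `missing` propagation. A short illustration, again just a sketch rather than repository code:

    # Why the linter prefers identity comparison for nothing-checks:
    nothing == nothing     # true
    missing == nothing     # missing  (== propagates missing)
    missing === nothing    # false
    nothing === nothing    # true     (=== always returns a Bool)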