Skip to content

Commit

Permalink
Merge pull request #173 from purdue-arc/autonomy_migration
Browse files Browse the repository at this point in the history
Merge autonomy_migration into ros-2-complete
  • Loading branch information
jcrm1 authored Nov 12, 2023
2 parents 31ea08e + 63fd225 commit 53b9cda
Show file tree
Hide file tree
Showing 26 changed files with 1,764 additions and 0 deletions.
Empty file added log/COLCON_IGNORE
Empty file.
1 change: 1 addition & 0 deletions log/latest
1 change: 1 addition & 0 deletions log/latest_list
35 changes: 35 additions & 0 deletions src/rktl_autonomy/config/rocket_league.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
reward:
# reward to be given each frame
# constant: 0.0
# # reward as a function of squared distance from the ball
# ball_dist_sq: -0.5
# # reward as a function of squared distance from the ball to the goal
# goal_dist_sq: -1.0
# reward given when car changes velocity direction
direction_change: 0
# reward given when the episode ends with the car scoring the proper goal
win: [100, 200, 300]
# reward given when the episode ends with the car scoring on the wrong goal
loss: [100, 100, 100]
# reward given each frame when the car is in reverse
# reverse: -25
# # reward given each frame when the car is near the walls
# walls:
# # actual reward value
# value: -50
# # threshold distance (meters)
# threshold: 0.25

# duration when the episode will be terminated. Unit is seconds (sim time)
max_episode_time: 15

log:
base_dir: "~/catkin_ws/data/rocket_league/"
# frequency to save progress plot. Unit is episodes
plot_freq: 50
# variables to display in progress plot
basic:
- "duration"
- "goals"
advanced:
- "net_reward"
47 changes: 47 additions & 0 deletions src/rktl_autonomy/launch/rocket_league_agent_launch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from launch import LaunchDescription
from launch.actions import DeclareLaunchArgument, LogInfo
from launch.conditions import IfCondition
from launch.substitutions import LaunchConfiguration
from launch_ros.actions import Node

def generate_launch_description():
plot_log = LaunchConfiguration('plot_log') # True if plotting is enabled. Else False
weights = LaunchConfiguration('weights') # Filepath to weights

plot_log_launch_arg = DeclareLaunchArgument(
'plot_log',
default_value='false',
description='Set to true to enable logging for plotting performance.'
)

weights_launch_arg = DeclareLaunchArgument(
'weights',
default_value='~/catkin_ws/data/rocket_league/model',
description='filepath to weights.'
)

plot_log_info = LogInfo(condition=IfCondition(plot_log), msg="Enabling performance plotting...")

agent_node = Node(
package='rktl_autonomy',
executable='rocket_league_agent',
name='rocket_league_agent',
output='screen',
parameters=[{'weights': weights}],
)

plotter_node = Node(
condition=IfCondition(plot_log),
package='rktl_autonomy',
executable='plotter',
name='plotter',
output='screen',
remappings=[('~log', 'rocket_league_agent/log')],
)

return LaunchDescription([plot_log_launch_arg,
weights_launch_arg,
plot_log_info,
agent_node,
plotter_node
])
79 changes: 79 additions & 0 deletions src/rktl_autonomy/launch/rocket_league_train_launch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from launch import LaunchDescription
from launch.actions import DeclareLaunchArgument, LogInfo
from launch.conditions import IfCondition
from launch.substitutions import LaunchConfiguration
from launch_ros.actions import Node
from launch.actions import IncludeLaunchDescription
from launch.launch_description_sources import PythonLaunchDescriptionSource

def generate_launch_description():
return LaunchDescription([
DeclareLaunchArgument(
'plot_log',
default_value='true',
description='Set to true to enable logging for plotting performance.'
),
DeclareLaunchArgument(
'agent_name',
default_value='rocket_league_agent',
description='Name of the agent.'
),
DeclareLaunchArgument(
'log_file',
default_value='rocket_league_agent/log',
description='Filepath for logger output'),
DeclareLaunchArgument(
'render',
default_value='false',
description='Set to true to enable rendering.'
),
DeclareLaunchArgument(
'sim_mode',
default_value='ideal',
description='Simulation mode.'
),

DeclareLaunchArgument(
'rate',
default_value='10.0',
description='Rate parameter.'
),
DeclareLaunchArgument(
'agent_type',
default_value='none',
description='Agent type.'
),

LogInfo(
condition=IfCondition(LaunchConfiguration('plot_log')),
msg="Enabling performance plotting..."
),

Node(
package='rktl_autonomy',
executable='rocket_league_agent',
name=LaunchConfiguration('agent_name'),
output='screen',
parameters=[{'rate': LaunchConfiguration('rate')}],
namespace=LaunchConfiguration('agent_name')
),

IncludeLaunchDescription(
PythonLaunchDescriptionSource(['$(find rktl_launch)/launch/rocket_league_sim_launch.py']), # TODO: Replace with the path to the launch file
launch_arguments={
'render': LaunchConfiguration('render'),
'sim_mode': LaunchConfiguration('sim_mode'),
'agent_type': LaunchConfiguration('agent_type'),
}.items(),
),

Node(
condition=IfCondition(LaunchConfiguration('plot_log')),
package='rktl_autonomy',
executable='plotter',
name='plotter',
output='screen',
remappings=[('~log', LaunchConfiguration('log_file'))],
namespace=LaunchConfiguration('agent_name')
)
])
182 changes: 182 additions & 0 deletions src/rktl_autonomy/nodes/plotter
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
#!/usr/bin/env python3
"""Convenience node to plot training progress.
License:
BSD 3-Clause License
Copyright (c) 2023, Autonomous Robotics Club of Purdue (Purdue ARC)
All rights reserved.
"""

# import rospy
import rclpy
from rclpy.node import Node
from rclpy.parameter import Parameter
from diagnostic_msgs.msg import DiagnosticStatus
from mpl_toolkits.axes_grid1 import host_subplot
import mpl_toolkits.axisartist as AA
import matplotlib
matplotlib.use('PS')
import matplotlib.pyplot as plt
from numpy import append
from os.path import expanduser, normpath
class Plotter(rclpy.Node):
"""Plot progress during training."""
def __init__(self):
#rospy.init_node('model_progress_plotter')
super().__init__('model_progress_plotter')

# Constants
self.LOG_DIR = normpath(expanduser(self.get_parameter('~log/base_dir').get_parameter_value().string_value))
#self.PLOT_FREQ = rospy.get_param('~log/plot_freq', 25)
self.PLOT_FREQ = self.get_parameter_or('~log/plot_freq', Parameter(25)).get_parameter_value().integer_value
#self.BASIC_VARS = rospy.get_param('~log/basic', ["duration"])
self.BASIC_VARS = self.get_parameter_or('~log/basic', Parameter(["duration"])).get_parameter_value().string_value
#self.ADVANCED_VARS = rospy.get_param('~log/advanced', ["net_reward"])
self.ADVANCED_VARS = self.get_parameter_or('~log/advanced', Parameter(["net_reward"])).get_parameter_value().string_value

self.KEYS = ["episode"]
self.KEYS.append(self.BASIC_VARS)
self.KEYS.append(self.ADVANCED_VARS)

# Subscribers
#rospy.Subscriber('~log', DiagnosticStatus, self.progress_cb)
self.create_subscription(DiagnosticStatus, '~log', self.progress_cb, qos_profile=10)

self.history = None
self.LOG_NAME = None
self.next_plot_episode = self.PLOT_FREQ
self.init_plot()
rclpy.spin(self)

def init_plot(self):
"""Initialize the plot and all axes."""
labels = self.BASIC_VARS + self.ADVANCED_VARS

# create host
plt.figure(figsize=(11,8.5))
self.host = host_subplot(111, axes_class=AA.Axes)

plt.title("Training Log")
self.host.set_xlabel("Episode")
self.host.set_ylabel(labels[0])

self.axes = {labels[0]:self.host}

# create exta axes
extras = len(labels) - 1
plt.subplots_adjust(right = 0.96 - 0.07*extras)
offset = 0
for label in labels[1:]:
axis = self.host.twinx()
axis.axis["right"] = axis.get_grid_helper().new_fixed_axis(loc="right", axes=axis, offset=(offset, 0))
axis.axis["right"].toggle(all=True)
axis.set_ylabel(label)

self.axes[label] = axis
offset += 60

# create lines
self.lines = {}

for var in self.BASIC_VARS:
line, = self.axes[var].plot(-1, 0, label=var)
self.lines[var] = line

for var in self.ADVANCED_VARS:
line_max, = self.axes[var].plot(-1, 0, ':', label=var+"/max")
line_avg, = self.axes[var].plot(-1, 0, '-', label=var+"/avg", color=line_max.get_color())
line_min, = self.axes[var].plot(-1, 0, ':', label=var+"/min", color=line_max.get_color())

self.lines[var+"/max"] = line_max
self.lines[var] = line_avg
self.lines[var+"/min"] = line_min

# create legend
self.host.legend()

# color y axes
for var, axis in self.axes.items():
axis.axis["left"].label.set_color(self.lines[var].get_color())
axis.axis["right"].label.set_color(self.lines[var].get_color())

def progress_cb(self, progress_msg):
"""Track training progress and save when configured to."""
if self.LOG_NAME is None and progress_msg.hardware_id:
self.LOG_NAME = '/' + progress_msg.hardware_id.replace(':', '/plot_') + '.png'

data = {}

for item in progress_msg.values:
if item.key in self.KEYS:
data[item.key] = float(item.value)

if data["episode"] is not None:
if self.history is None:
self.history = [data]
else:
self.history.append(data)

if data["episode"] >= self.next_plot_episode:
self.plot()
self.next_plot_episode += self.PLOT_FREQ
self.history = None
else:
#rospy.logerr("Bad progress message.")
self.get_logger().warn("Bad progress message.")

def plot(self):
"""Add new data to plot, show, and save"""
# calculate the avgs, mins, maxs of all variables
sums = {key:0.0 for key in self.BASIC_VARS+self.ADVANCED_VARS}
mins = {key:float("inf") for key in self.ADVANCED_VARS}
maxs = {key:float("-inf") for key in self.ADVANCED_VARS}

for episode in self.history:
for var, value in episode.items():
if var in sums:
sums[var] += value
if var in mins:
mins[var] = min(mins[var], value)
if var in maxs:
maxs[var] = max(maxs[var], value)

# calculate the avg of all variables
avgs = {key:sums[key]/len(self.history) for key in self.BASIC_VARS+self.ADVANCED_VARS}

# update lines
episode = self.history[-1]["episode"]

for var in self.BASIC_VARS:
line = self.lines[var]
line.set_xdata(append(line.get_xdata(), episode))
line.set_ydata(append(line.get_ydata(), avgs[var]))

for var in self.ADVANCED_VARS:
# avg
line = self.lines[var]
line.set_xdata(append(line.get_xdata(), episode))
line.set_ydata(append(line.get_ydata(), avgs[var]))

# max
line = self.lines[var+"/max"]
line.set_xdata(append(line.get_xdata(), episode))
line.set_ydata(append(line.get_ydata(), maxs[var]))

# min
line = self.lines[var+"/min"]
line.set_xdata(append(line.get_xdata(), episode))
line.set_ydata(append(line.get_ydata(), mins[var]))

# update plot
for var, axis in self.axes.items():
axis.relim()
axis.autoscale()
self.host.set_xlim(0, episode)
plt.draw()

# update file
#rospy.loginfo(f"Saving training progress to {self.LOG_DIR}{self.LOG_NAME}")
self.get_logger().info("Saving training progress to {self.LOG_DIR}{self.LOG_NAME}")
plt.savefig(self.LOG_DIR + self.LOG_NAME)

if __name__ == "__main__":
Plotter()
34 changes: 34 additions & 0 deletions src/rktl_autonomy/nodes/rocket_league_agent
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#!/usr/bin/env python3
"""Real-time evaluation of the agent trained for the Rocket League project.
License:
BSD 3-Clause License
Copyright (c) 2023, Autonomous Robotics Club of Purdue (Purdue ARC)
All rights reserved.
"""

from rktl_autonomy import RocketLeagueInterface
from stable_baselines3 import PPO
from os.path import expanduser
# import rospy
import rclpy
from rclpy.exceptions import ROSInterruptException



# create interface (and init ROS)
env = RocketLeagueInterface(eval=True)

# load the model
# weights = expanduser(rospy.get_param('~weights'))
weights = expanduser(env.node.get_parameter('~weights'))
model = PPO.load(weights)

# evaluate in real-time
obs = env.reset()
while True:
action, __ = model.predict(obs)
try:
obs, __, __, __ = env.step(action)
# except rospy.ROSInterruptException:
except ROSInterruptException:
exit()
Loading

0 comments on commit 53b9cda

Please sign in to comment.