-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #173 from purdue-arc/autonomy_migration
Merge autonomy_migration into ros-2-complete
- Loading branch information
Showing
26 changed files
with
1,764 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
latest_list |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
list_2023-10-22_17-00-09 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
reward: | ||
# reward to be given each frame | ||
# constant: 0.0 | ||
# # reward as a function of squared distance from the ball | ||
# ball_dist_sq: -0.5 | ||
# # reward as a function of squared distance from the ball to the goal | ||
# goal_dist_sq: -1.0 | ||
# reward given when car changes velocity direction | ||
direction_change: 0 | ||
# reward given when the episode ends with the car scoring the proper goal | ||
win: [100, 200, 300] | ||
# reward given when the episode ends with the car scoring on the wrong goal | ||
loss: [100, 100, 100] | ||
# reward given each frame when the car is in reverse | ||
# reverse: -25 | ||
# # reward given each frame when the car is near the walls | ||
# walls: | ||
# # actual reward value | ||
# value: -50 | ||
# # threshold distance (meters) | ||
# threshold: 0.25 | ||
|
||
# duration when the episode will be terminated. Unit is seconds (sim time) | ||
max_episode_time: 15 | ||
|
||
log: | ||
base_dir: "~/catkin_ws/data/rocket_league/" | ||
# frequency to save progress plot. Unit is episodes | ||
plot_freq: 50 | ||
# variables to display in progress plot | ||
basic: | ||
- "duration" | ||
- "goals" | ||
advanced: | ||
- "net_reward" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
from launch import LaunchDescription | ||
from launch.actions import DeclareLaunchArgument, LogInfo | ||
from launch.conditions import IfCondition | ||
from launch.substitutions import LaunchConfiguration | ||
from launch_ros.actions import Node | ||
|
||
def generate_launch_description(): | ||
plot_log = LaunchConfiguration('plot_log') # True if plotting is enabled. Else False | ||
weights = LaunchConfiguration('weights') # Filepath to weights | ||
|
||
plot_log_launch_arg = DeclareLaunchArgument( | ||
'plot_log', | ||
default_value='false', | ||
description='Set to true to enable logging for plotting performance.' | ||
) | ||
|
||
weights_launch_arg = DeclareLaunchArgument( | ||
'weights', | ||
default_value='~/catkin_ws/data/rocket_league/model', | ||
description='filepath to weights.' | ||
) | ||
|
||
plot_log_info = LogInfo(condition=IfCondition(plot_log), msg="Enabling performance plotting...") | ||
|
||
agent_node = Node( | ||
package='rktl_autonomy', | ||
executable='rocket_league_agent', | ||
name='rocket_league_agent', | ||
output='screen', | ||
parameters=[{'weights': weights}], | ||
) | ||
|
||
plotter_node = Node( | ||
condition=IfCondition(plot_log), | ||
package='rktl_autonomy', | ||
executable='plotter', | ||
name='plotter', | ||
output='screen', | ||
remappings=[('~log', 'rocket_league_agent/log')], | ||
) | ||
|
||
return LaunchDescription([plot_log_launch_arg, | ||
weights_launch_arg, | ||
plot_log_info, | ||
agent_node, | ||
plotter_node | ||
]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
from launch import LaunchDescription | ||
from launch.actions import DeclareLaunchArgument, LogInfo | ||
from launch.conditions import IfCondition | ||
from launch.substitutions import LaunchConfiguration | ||
from launch_ros.actions import Node | ||
from launch.actions import IncludeLaunchDescription | ||
from launch.launch_description_sources import PythonLaunchDescriptionSource | ||
|
||
def generate_launch_description(): | ||
return LaunchDescription([ | ||
DeclareLaunchArgument( | ||
'plot_log', | ||
default_value='true', | ||
description='Set to true to enable logging for plotting performance.' | ||
), | ||
DeclareLaunchArgument( | ||
'agent_name', | ||
default_value='rocket_league_agent', | ||
description='Name of the agent.' | ||
), | ||
DeclareLaunchArgument( | ||
'log_file', | ||
default_value='rocket_league_agent/log', | ||
description='Filepath for logger output'), | ||
DeclareLaunchArgument( | ||
'render', | ||
default_value='false', | ||
description='Set to true to enable rendering.' | ||
), | ||
DeclareLaunchArgument( | ||
'sim_mode', | ||
default_value='ideal', | ||
description='Simulation mode.' | ||
), | ||
|
||
DeclareLaunchArgument( | ||
'rate', | ||
default_value='10.0', | ||
description='Rate parameter.' | ||
), | ||
DeclareLaunchArgument( | ||
'agent_type', | ||
default_value='none', | ||
description='Agent type.' | ||
), | ||
|
||
LogInfo( | ||
condition=IfCondition(LaunchConfiguration('plot_log')), | ||
msg="Enabling performance plotting..." | ||
), | ||
|
||
Node( | ||
package='rktl_autonomy', | ||
executable='rocket_league_agent', | ||
name=LaunchConfiguration('agent_name'), | ||
output='screen', | ||
parameters=[{'rate': LaunchConfiguration('rate')}], | ||
namespace=LaunchConfiguration('agent_name') | ||
), | ||
|
||
IncludeLaunchDescription( | ||
PythonLaunchDescriptionSource(['$(find rktl_launch)/launch/rocket_league_sim_launch.py']), # TODO: Replace with the path to the launch file | ||
launch_arguments={ | ||
'render': LaunchConfiguration('render'), | ||
'sim_mode': LaunchConfiguration('sim_mode'), | ||
'agent_type': LaunchConfiguration('agent_type'), | ||
}.items(), | ||
), | ||
|
||
Node( | ||
condition=IfCondition(LaunchConfiguration('plot_log')), | ||
package='rktl_autonomy', | ||
executable='plotter', | ||
name='plotter', | ||
output='screen', | ||
remappings=[('~log', LaunchConfiguration('log_file'))], | ||
namespace=LaunchConfiguration('agent_name') | ||
) | ||
]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,182 @@ | ||
#!/usr/bin/env python3 | ||
"""Convenience node to plot training progress. | ||
License: | ||
BSD 3-Clause License | ||
Copyright (c) 2023, Autonomous Robotics Club of Purdue (Purdue ARC) | ||
All rights reserved. | ||
""" | ||
|
||
# import rospy | ||
import rclpy | ||
from rclpy.node import Node | ||
from rclpy.parameter import Parameter | ||
from diagnostic_msgs.msg import DiagnosticStatus | ||
from mpl_toolkits.axes_grid1 import host_subplot | ||
import mpl_toolkits.axisartist as AA | ||
import matplotlib | ||
matplotlib.use('PS') | ||
import matplotlib.pyplot as plt | ||
from numpy import append | ||
from os.path import expanduser, normpath | ||
class Plotter(rclpy.Node): | ||
"""Plot progress during training.""" | ||
def __init__(self): | ||
#rospy.init_node('model_progress_plotter') | ||
super().__init__('model_progress_plotter') | ||
|
||
# Constants | ||
self.LOG_DIR = normpath(expanduser(self.get_parameter('~log/base_dir').get_parameter_value().string_value)) | ||
#self.PLOT_FREQ = rospy.get_param('~log/plot_freq', 25) | ||
self.PLOT_FREQ = self.get_parameter_or('~log/plot_freq', Parameter(25)).get_parameter_value().integer_value | ||
#self.BASIC_VARS = rospy.get_param('~log/basic', ["duration"]) | ||
self.BASIC_VARS = self.get_parameter_or('~log/basic', Parameter(["duration"])).get_parameter_value().string_value | ||
#self.ADVANCED_VARS = rospy.get_param('~log/advanced', ["net_reward"]) | ||
self.ADVANCED_VARS = self.get_parameter_or('~log/advanced', Parameter(["net_reward"])).get_parameter_value().string_value | ||
|
||
self.KEYS = ["episode"] | ||
self.KEYS.append(self.BASIC_VARS) | ||
self.KEYS.append(self.ADVANCED_VARS) | ||
|
||
# Subscribers | ||
#rospy.Subscriber('~log', DiagnosticStatus, self.progress_cb) | ||
self.create_subscription(DiagnosticStatus, '~log', self.progress_cb, qos_profile=10) | ||
|
||
self.history = None | ||
self.LOG_NAME = None | ||
self.next_plot_episode = self.PLOT_FREQ | ||
self.init_plot() | ||
rclpy.spin(self) | ||
|
||
def init_plot(self): | ||
"""Initialize the plot and all axes.""" | ||
labels = self.BASIC_VARS + self.ADVANCED_VARS | ||
|
||
# create host | ||
plt.figure(figsize=(11,8.5)) | ||
self.host = host_subplot(111, axes_class=AA.Axes) | ||
|
||
plt.title("Training Log") | ||
self.host.set_xlabel("Episode") | ||
self.host.set_ylabel(labels[0]) | ||
|
||
self.axes = {labels[0]:self.host} | ||
|
||
# create exta axes | ||
extras = len(labels) - 1 | ||
plt.subplots_adjust(right = 0.96 - 0.07*extras) | ||
offset = 0 | ||
for label in labels[1:]: | ||
axis = self.host.twinx() | ||
axis.axis["right"] = axis.get_grid_helper().new_fixed_axis(loc="right", axes=axis, offset=(offset, 0)) | ||
axis.axis["right"].toggle(all=True) | ||
axis.set_ylabel(label) | ||
|
||
self.axes[label] = axis | ||
offset += 60 | ||
|
||
# create lines | ||
self.lines = {} | ||
|
||
for var in self.BASIC_VARS: | ||
line, = self.axes[var].plot(-1, 0, label=var) | ||
self.lines[var] = line | ||
|
||
for var in self.ADVANCED_VARS: | ||
line_max, = self.axes[var].plot(-1, 0, ':', label=var+"/max") | ||
line_avg, = self.axes[var].plot(-1, 0, '-', label=var+"/avg", color=line_max.get_color()) | ||
line_min, = self.axes[var].plot(-1, 0, ':', label=var+"/min", color=line_max.get_color()) | ||
|
||
self.lines[var+"/max"] = line_max | ||
self.lines[var] = line_avg | ||
self.lines[var+"/min"] = line_min | ||
|
||
# create legend | ||
self.host.legend() | ||
|
||
# color y axes | ||
for var, axis in self.axes.items(): | ||
axis.axis["left"].label.set_color(self.lines[var].get_color()) | ||
axis.axis["right"].label.set_color(self.lines[var].get_color()) | ||
|
||
def progress_cb(self, progress_msg): | ||
"""Track training progress and save when configured to.""" | ||
if self.LOG_NAME is None and progress_msg.hardware_id: | ||
self.LOG_NAME = '/' + progress_msg.hardware_id.replace(':', '/plot_') + '.png' | ||
|
||
data = {} | ||
|
||
for item in progress_msg.values: | ||
if item.key in self.KEYS: | ||
data[item.key] = float(item.value) | ||
|
||
if data["episode"] is not None: | ||
if self.history is None: | ||
self.history = [data] | ||
else: | ||
self.history.append(data) | ||
|
||
if data["episode"] >= self.next_plot_episode: | ||
self.plot() | ||
self.next_plot_episode += self.PLOT_FREQ | ||
self.history = None | ||
else: | ||
#rospy.logerr("Bad progress message.") | ||
self.get_logger().warn("Bad progress message.") | ||
|
||
def plot(self): | ||
"""Add new data to plot, show, and save""" | ||
# calculate the avgs, mins, maxs of all variables | ||
sums = {key:0.0 for key in self.BASIC_VARS+self.ADVANCED_VARS} | ||
mins = {key:float("inf") for key in self.ADVANCED_VARS} | ||
maxs = {key:float("-inf") for key in self.ADVANCED_VARS} | ||
|
||
for episode in self.history: | ||
for var, value in episode.items(): | ||
if var in sums: | ||
sums[var] += value | ||
if var in mins: | ||
mins[var] = min(mins[var], value) | ||
if var in maxs: | ||
maxs[var] = max(maxs[var], value) | ||
|
||
# calculate the avg of all variables | ||
avgs = {key:sums[key]/len(self.history) for key in self.BASIC_VARS+self.ADVANCED_VARS} | ||
|
||
# update lines | ||
episode = self.history[-1]["episode"] | ||
|
||
for var in self.BASIC_VARS: | ||
line = self.lines[var] | ||
line.set_xdata(append(line.get_xdata(), episode)) | ||
line.set_ydata(append(line.get_ydata(), avgs[var])) | ||
|
||
for var in self.ADVANCED_VARS: | ||
# avg | ||
line = self.lines[var] | ||
line.set_xdata(append(line.get_xdata(), episode)) | ||
line.set_ydata(append(line.get_ydata(), avgs[var])) | ||
|
||
# max | ||
line = self.lines[var+"/max"] | ||
line.set_xdata(append(line.get_xdata(), episode)) | ||
line.set_ydata(append(line.get_ydata(), maxs[var])) | ||
|
||
# min | ||
line = self.lines[var+"/min"] | ||
line.set_xdata(append(line.get_xdata(), episode)) | ||
line.set_ydata(append(line.get_ydata(), mins[var])) | ||
|
||
# update plot | ||
for var, axis in self.axes.items(): | ||
axis.relim() | ||
axis.autoscale() | ||
self.host.set_xlim(0, episode) | ||
plt.draw() | ||
|
||
# update file | ||
#rospy.loginfo(f"Saving training progress to {self.LOG_DIR}{self.LOG_NAME}") | ||
self.get_logger().info("Saving training progress to {self.LOG_DIR}{self.LOG_NAME}") | ||
plt.savefig(self.LOG_DIR + self.LOG_NAME) | ||
|
||
if __name__ == "__main__": | ||
Plotter() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
#!/usr/bin/env python3 | ||
"""Real-time evaluation of the agent trained for the Rocket League project. | ||
License: | ||
BSD 3-Clause License | ||
Copyright (c) 2023, Autonomous Robotics Club of Purdue (Purdue ARC) | ||
All rights reserved. | ||
""" | ||
|
||
from rktl_autonomy import RocketLeagueInterface | ||
from stable_baselines3 import PPO | ||
from os.path import expanduser | ||
# import rospy | ||
import rclpy | ||
from rclpy.exceptions import ROSInterruptException | ||
|
||
|
||
|
||
# create interface (and init ROS) | ||
env = RocketLeagueInterface(eval=True) | ||
|
||
# load the model | ||
# weights = expanduser(rospy.get_param('~weights')) | ||
weights = expanduser(env.node.get_parameter('~weights')) | ||
model = PPO.load(weights) | ||
|
||
# evaluate in real-time | ||
obs = env.reset() | ||
while True: | ||
action, __ = model.predict(obs) | ||
try: | ||
obs, __, __, __ = env.step(action) | ||
# except rospy.ROSInterruptException: | ||
except ROSInterruptException: | ||
exit() |
Oops, something went wrong.