All sanity envs cythonized #98

Open
wants to merge 9 commits into dev
16 changes: 16 additions & 0 deletions config/ocean/bandit_cy.ini
@@ -0,0 +1,16 @@
[base]
package = ocean
env_name = bandit_cy
policy_name = Policy
rnn_name = None

[train]
total_timesteps = 300_000
learning_rate = 0.017
num_envs = 128
num_workers = 8
env_batch_size = 64
batch_size = 1024
minibatch_size = 1024
bptt_horizon = 4
device = cpu
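
These configs are plain INI files, so they can be read with Python's standard configparser. A minimal sketch under that assumption (pufferlib's actual config loader may merge defaults or do additional parsing; the path below is simply the file added in this PR):

# Sketch only: reads one of the new configs with the standard library.
import configparser

parser = configparser.ConfigParser()
parser.read('config/ocean/bandit_cy.ini')

env_name = parser['base']['env_name']                      # 'bandit_cy'
# int() accepts underscore digit separators, so '300_000' parses to 300000
total_timesteps = int(parser['train']['total_timesteps'])
learning_rate = float(parser['train']['learning_rate'])
print(env_name, total_timesteps, learning_rate)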
17 changes: 17 additions & 0 deletions config/ocean/continuous_cy.ini
@@ -0,0 +1,17 @@
[base]
package = ocean
env_name = continuous_cy
policy_name = Policy
rnn_name = None

[train]
total_timesteps = 300_000
learning_rate = 0.017
num_envs = 32
num_workers = 8
env_batch_size = 4
batch_size = 4096
minibatch_size = 4096
bptt_horizon = 4
update_epochs = 4
device = cpu
16 changes: 16 additions & 0 deletions config/ocean/memory_cy.ini
@@ -0,0 +1,16 @@
[base]
package = ocean
env_name = memory_cy
policy_name = Policy
rnn_name = Recurrent

[train]
total_timesteps = 600_000
learning_rate = 0.02
num_envs = 128
num_workers = 8
env_batch_size = 64
batch_size = 4096
minibatch_size = 1024
bptt_horizon = 8
device = cpu
17 changes: 17 additions & 0 deletions config/ocean/multiagent_cy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
[base]
package = ocean
env_name = multiagent_cy
policy_name = Policy
rnn_name = Recurrent

[train]
total_timesteps = 150_000
learning_rate = 0.017
num_envs = 32
num_workers = 8
env_batch_size = 4
batch_size = 4096
minibatch_size = 4096
bptt_horizon = 4
update_epochs = 4
device = cpu
17 changes: 17 additions & 0 deletions config/ocean/password_cy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
[base]
package = ocean
env_name = password_cy
policy_name = Policy
rnn_name = Recurrent

[train]
total_timesteps = 150_000
learning_rate = 0.017
num_envs = 32
num_workers = 8
env_batch_size = 4
batch_size = 4096
minibatch_size = 4096
bptt_horizon = 4
update_epochs = 4
device = cpu
16 changes: 16 additions & 0 deletions config/ocean/spaces_cy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
[base]
package = ocean
env_name = spaces_cy
policy_name = Policy
rnn_name = None

[train]
total_timesteps = 300_000
learning_rate = 0.017
num_envs = 2048
num_workers = 8
env_batch_size = 256
batch_size = 4096
minibatch_size = 4096
bptt_horizon = 4
device = cpu
17 changes: 17 additions & 0 deletions config/ocean/squared_cy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
[base]
package = ocean
env_name = squared_cy
policy_name = Policy
rnn_name = None

[train]
total_timesteps = 300_000
learning_rate = 0.017
num_envs = 32
num_workers = 8
env_batch_size = 4
batch_size = 4096
minibatch_size = 4096
bptt_horizon = 4
update_epochs = 4
device = cpu
17 changes: 17 additions & 0 deletions config/ocean/stochastic_cy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
[base]
package = ocean
env_name = stochastic_cy
policy_name = Policy
rnn_name = None

[train]
total_timesteps = 300_000
learning_rate = 0.017
num_envs = 32
num_workers = 8
env_batch_size = 4
batch_size = 4096
minibatch_size = 4096
bptt_horizon = 4
update_epochs = 4
device = cpu
Empty file.
84 changes: 84 additions & 0 deletions pufferlib/environments/ocean/bandit_cy/cy_bandit_cy.pyx
@@ -0,0 +1,84 @@
import numpy as np
cimport numpy as cnp
from libc.stdlib cimport malloc, free
from libc.stdint cimport int32_t

cdef struct CBanditEnv:
    int num_actions
    int solution_idx
    float reward_scale
    float reward_noise
    int hard_fixed_seed
    float* rewards
    int32_t* actions

cdef CBanditEnv* init_c_bandit_env(int num_actions,
                                   float reward_scale,
                                   float reward_noise,
                                   int hard_fixed_seed,
                                   float* rewards,
                                   int32_t* actions):
    # Allocate memory for the environment
    cdef CBanditEnv* env = <CBanditEnv*> malloc(sizeof(CBanditEnv))
    env.num_actions = num_actions
    env.reward_scale = reward_scale
    env.reward_noise = reward_noise
    env.hard_fixed_seed = hard_fixed_seed
    env.rewards = rewards
    env.actions = actions

    # Set up the solution
    np.random.seed(hard_fixed_seed)
    env.solution_idx = np.random.randint(0, num_actions)

    return env

cdef void reset(CBanditEnv* env):
    np.random.seed(env.hard_fixed_seed)
    env.solution_idx = np.random.randint(0, env.num_actions)

cdef void step(CBanditEnv* env):
    cdef int action = env.actions[0]
    env.rewards[0] = 0.0

    if action == env.solution_idx:
        env.rewards[0] = 1.0

    if env.reward_noise != 0.0:
        env.rewards[0] += np.random.randn() * env.reward_scale

    env.rewards[0] *= env.reward_scale

cdef void free_c_bandit_env(CBanditEnv* env):
    free(env)

# Cython wrapper class
cdef class CBanditCy:
    cdef:
        CBanditEnv* env

    def __init__(self,
                 int num_actions,
                 float reward_scale,
                 float reward_noise,
                 int hard_fixed_seed,
                 cnp.ndarray[cnp.float32_t, ndim=2] rewards,
                 cnp.ndarray[cnp.int32_t, ndim=2] actions):

        self.env = init_c_bandit_env(
            num_actions, reward_scale, reward_noise, hard_fixed_seed,
            <float*>rewards.data, <int32_t*>actions.data)

    def reset(self):
        reset(self.env)

    def step(self):
        step(self.env)

    def get_solution_idx(self):
        return self.env.solution_idx

    def __dealloc__(self):
        free_c_bandit_env(self.env)
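
The CBanditCy wrapper reads actions and writes rewards through raw pointers into caller-owned numpy buffers, so both arrays must be allocated before construction and kept alive for the wrapper's lifetime. A minimal usage sketch based on the constructor signature above, assuming the .pyx has been cythonized and built so the extension module is importable:

import numpy as np
from cy_bandit_cy import CBanditCy  # assumes the compiled extension is on the path

# Caller-owned buffers; step() reads actions[0] and writes rewards[0]
rewards = np.zeros((1, 1), dtype=np.float32)
actions = np.zeros((1, 1), dtype=np.int32)

env = CBanditCy(4, 1.0, 0.0, 42, rewards, actions)
env.reset()
actions[0, 0] = 2      # pull arm 2
env.step()
print(rewards[0, 0], env.get_solution_idx())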
64 changes: 64 additions & 0 deletions pufferlib/environments/ocean/bandit_cy/py_bandit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import numpy as np
import gymnasium
from .cy_bandit_cy import CBanditCy

class BanditCyEnv(gymnasium.Env):
    def __init__(self, num_actions=4, reward_scale=1, reward_noise=0, hard_fixed_seed=42):
        super().__init__()

        self.num_actions = num_actions
        self.reward_scale = reward_scale
        self.reward_noise = reward_noise
        self.hard_fixed_seed = hard_fixed_seed

        self.rewards = np.zeros((1, 1), dtype=np.float32)
        self.actions = np.zeros((1, 1), dtype=np.int32)

        self.c_env = CBanditCy(num_actions,
                               reward_scale,
                               reward_noise,
                               hard_fixed_seed,
                               self.rewards,
                               self.actions)

        self.observation_space = gymnasium.spaces.Box(
            low=0, high=1, shape=(1,), dtype=np.float32)
        self.action_space = gymnasium.spaces.Discrete(num_actions)

    def reset(self, seed=None):
        self.c_env.reset()
        return np.ones(1, dtype=np.float32), {}

    def step(self, action):
        self.actions[0, 0] = action
        self.c_env.step()

        solution_idx = self.c_env.get_solution_idx()

        return np.ones(1, dtype=np.float32), self.rewards[0, 0], True, False, {'score': action == solution_idx}

    def render(self):
        pass

def test_performance(num_actions=4, timeout=10, atn_cache=1024):
    import time
    env = BanditCyEnv(num_actions=num_actions)

    env.reset()

    tick = 0
    actions = np.random.randint(0, num_actions, (atn_cache, 1))

    start = time.time()

    while time.time() - start < timeout:
        atn = actions[tick % atn_cache]
        env.step(atn[0])
        tick += 1

    elapsed_time = time.time() - start
    sps = tick / elapsed_time
    print(f"SPS: {sps:.2f}")

if __name__ == '__main__':
    test_performance()
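
Since BanditCyEnv follows the standard gymnasium reset/step API, a short random rollout doubles as a sanity check on the 'score' info key returned by step(). A minimal sketch, assuming BanditCyEnv from this file is in scope (each episode is a single pull, so the env is reset after every step):

env = BanditCyEnv(num_actions=4)
obs, info = env.reset()
hits = 0
for _ in range(100):
    action = env.action_space.sample()
    obs, reward, terminated, truncated, info = env.step(action)
    hits += int(info['score'])
    obs, info = env.reset()
print(f'hit rate over 100 pulls: {hits / 100:.2f}')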
Empty file.