Skip to content

Commit

Permalink
add two examples (#4)
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 authored Aug 25, 2024
1 parent ed039c1 commit e0bfa17
Show file tree
Hide file tree
Showing 3 changed files with 284 additions and 1 deletion.
155 changes: 155 additions & 0 deletions examples/COBAHH_network.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
# -*- coding: utf-8 -*-

import time

import brainstate as bst
import brainunit as u

import braintaichi as bti

s = 1e-2
Cm = 200 * s # Membrane Capacitance [pF]
gl = 10. * s # Leak Conductance [nS]
g_Na = 20. * 1000 * s
g_Kd = 6. * 1000 * s # K Conductance [nS]
El = -60. # Resting Potential [mV]
ENa = 50. # reversal potential (Sodium) [mV]
EK = -90. # reversal potential (Potassium) [mV]
VT = -63.
V_th = -20.
taue = 5. # Excitatory synaptic time constant [ms]
taui = 10. # Inhibitory synaptic time constant [ms]
Ee = 0. # Excitatory reversal potential (mV)
Ei = -80. # Inhibitory reversal potential (Potassium) [mV]
we = 6. * s # excitatory synaptic conductance [nS]
wi = 67. * s # inhibitory synaptic conductance [nS]


class HH(bst.nn.Neuron):
def __init__(self, size, method='exp_auto'):
super(HH, self).__init__(size)

def init_state(self, *args, **kwargs):
# variables
self.V = bst.State(El + bst.random.randn(self.num) * 5 - 5.)
self.m = bst.State(u.math.zeros(self.num))
self.n = bst.State(u.math.zeros(self.num))
self.h = bst.State(u.math.zeros(self.num))
self.rate = bst.State(u.math.zeros(self.num))
self.spike = bst.State(u.math.zeros(self.num, dtype=bool))

def dV(self, V, t, m, h, n, Isyn):
Isyn = self.sum_current_inputs(self.V.value, init=Isyn) # sum projection inputs
gna = g_Na * (m * m * m) * h
n2 = n * n
gkd = g_Kd * (n2 * n2)
dVdt = (-gl * (V - El) - gna * (V - ENa) - gkd * (V - EK) + Isyn) / Cm
return dVdt

def dm(self, m, t, V, ):
m_alpha = 1.28 / u.math.exprel((13 - V + VT) / 4)
m_beta = 1.4 / u.math.exprel((V - VT - 40) / 5)
dmdt = (m_alpha * (1 - m) - m_beta * m)
return dmdt

def dh(self, h, t, V):
h_alpha = 0.128 * u.math.exp((17 - V + VT) / 18)
h_beta = 4. / (1 + u.math.exp(-(V - VT - 40) / 5))
dhdt = (h_alpha * (1 - h) - h_beta * h)
return dhdt

def dn(self, n, t, V):
n_alpha = 0.16 / u.math.exprel((15 - V + VT) / 5.)
n_beta = 0.5 * u.math.exp((10 - V + VT) / 40)
dndt = (n_alpha * (1 - n) - n_beta * n)
return dndt

def update(self, inp=0.):
t = bst.environ.get('t')
V = bst.nn.exp_euler_step(self.dV, self.V.value, t, self.m.value, self.h.value, self.n.value, inp)
m = bst.nn.exp_euler_step(self.dm, self.m.value, t, self.V.value)
n = bst.nn.exp_euler_step(self.dn, self.n.value, t, self.V.value)
h = bst.nn.exp_euler_step(self.dh, self.h.value, t, self.V.value)
self.spike.value = u.math.logical_and(self.V.value < V_th, V >= V_th)
self.m.value = m
self.h.value = h
self.n.value = n
self.V.value = V
self.rate.value += self.spike.value
return self.spike.value


class CSRLinear(bst.Module):
def __init__(self, n_pre, n_post, g_max, prob):
super().__init__()
self.g_max = g_max
self.n_pre = n_pre
self.n_post = n_post
self.prob = prob

def update(self, spk):
return bti.jitc_event_mv_prob_homo(
spk, self.g_max, conn_prob=self.prob, shape=(self.n_pre, self.n_post,), seed=123, transpose=True
)


class Exponential(bst.Projection):
def __init__(self, num_pre, post, prob, g_max, tau, E):
super().__init__()

self.proj = bst.nn.HalfProjAlignPostMg(
comm=CSRLinear(num_pre, post.num, g_max, prob),
syn=bst.nn.Expon.delayed(post.num, tau=tau),
out=bst.nn.COBA.delayed(E=E),
post=post
)

def update(self, spk):
self.proj.update(spk)


class COBA_HH_Net(bst.ModuleGroup):
def __init__(self, scale=1.):
super(COBA_HH_Net, self).__init__()
self.num_exc = int(3200 * scale)
self.num_inh = int(800 * scale)
self.num = self.num_exc + self.num_inh

self.N = HH(self.num)
self.E = Exponential(self.num_exc, self.N, prob=80 / self.num, g_max=we, tau=taue, E=Ee)
self.I = Exponential(self.num_inh, self.N, prob=80 / self.num, g_max=wi, tau=taui, E=Ei)

def update(self):
self.E(self.N.spike.value[:self.num_exc])
self.I(self.N.spike.value[self.num_exc:])
self.N()

def step_run(self, i):
with bst.environ.context(i=i, t=i * bst.environ.get_dt()):
self.update()


def run_a_simulation(scale=10, duration=1e3):
net = COBA_HH_Net(scale=scale)
bst.init_states(net)

indices = u.math.arange(int(duration / bst.environ.get_dt()))

t0 = time.time()
# if the network size is big, please turn on "progress_bar"
# otherwise, the XLA may compute wrongly
r = bst.transform.for_loop(net.step_run, indices)
t1 = time.time()

rate = net.N.rate.value.sum() / net.N.num / duration * 1e3

print(f'scale={scale}, size={net.num}, time = {t1 - t0} s, firing rate = {rate} Hz')


def check_firing_rate(x64=True, platform='cpu'):
for scale in [1, 2, 4, 6, 8, 10, 20, 30, 40, 50, 80, 100]:
run_a_simulation(scale=scale, duration=2e3)


if __name__ == '__main__':
check_firing_rate()
123 changes: 123 additions & 0 deletions examples/COBA_network.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# -*- coding: utf-8 -*-

import time

import brainstate as bst
import jax.numpy as jnp

import braintaichi as bti

taum = 20
taue = 5
taui = 10
Vt = -50
Vr = -60
El = -60
Erev_exc = 0.
Erev_inh = -80.
Ib = 20.
ref = 5.0
we = 0.6
wi = 6.7


class LIF(bst.nn.Neuron):
def __init__(self, size, V_init: callable, **kwargs):
super(LIF, self).__init__(size, **kwargs)

# parameters
self.V_rest = Vr
self.V_reset = El
self.V_th = Vt
self.tau = taum
self.tau_ref = ref

self.V_init = V_init

def init_state(self, *args, **kwargs):
# variables
self.V = bst.init.state(self.V_init, self.num)
self.spike = bst.init.state(bst.init.Constant(False, dtype=bool), self.num)
self.t_last_spike = bst.init.state(bst.init.Constant(-1e7), self.num)

def update(self, inp):
inp = self.sum_current_inputs(self.V.value, init=inp) # sum all projection inputs
refractory = (bst.environ.get('t') - self.t_last_spike.value) <= self.tau_ref
V = self.V.value + (-self.V.value + self.V_rest + inp) / self.tau * bst.environ.get_dt()
V = jnp.where(refractory, self.V.value, V)
spike = self.V_th <= V
self.t_last_spike.value = jnp.where(spike, bst.environ.get('t'), self.t_last_spike.value)
self.V.value = jnp.where(spike, self.V_reset, V)
self.spike.value = spike
return spike


class CSRLinear(bst.Module):
def __init__(self, n_pre, n_post, g_max, prob):
super().__init__()
self.g_max = g_max
self.n_pre = n_pre
self.n_post = n_post
self.prob = prob

def update(self, spk):
return bti.jitc_event_mv_prob_homo(
spk, self.g_max, conn_prob=self.prob, shape=(self.n_pre, self.n_post, ), seed=123, transpose=True
)


class Exponential(bst.Projection):
def __init__(self, num_pre, post, prob, g_max, tau, E):
super().__init__()
self.proj = bst.nn.HalfProjAlignPostMg(
comm=CSRLinear(num_pre, post.num, g_max, prob),
syn=bst.nn.Expon.delayed(post.num, tau=tau),
out=bst.nn.COBA.delayed(E=E),
post=post
)


class COBA(bst.ModuleGroup):
def __init__(self, scale):
super().__init__()
self.num_exc = int(3200 * scale)
self.num_inh = int(800 * scale)
self.N = LIF(self.num_exc + self.num_inh, V_init=bst.init.Normal(-55., 5.))
self.E = Exponential(self.num_exc, self.N, prob=80. / self.N.num, E=Erev_exc, g_max=we, tau=taue)
self.I = Exponential(self.num_inh, self.N, prob=80. / self.N.num, E=Erev_inh, g_max=wi, tau=taui)

def init_state(self, *args, **kwargs):
self.rate = bst.init.state(jnp.zeros, self.N.num)

def update(self, inp=Ib):
self.E(self.N.spike.value[:self.num_exc])
self.I(self.N.spike.value[self.num_exc:])
self.N(inp)
self.rate.value += self.N.spike.value

def step_run(self, i):
with bst.environ.context(i=i, t=i * bst.environ.get_dt()):
self.update()


def run_a_simulation(scale=10, duration=1e3, ):
net = COBA(scale=scale)
bst.init_states(net)
indices = jnp.arange(int(duration / bst.environ.get_dt()))
t0 = time.time()
bst.transform.for_loop(net.step_run, indices)
t1 = time.time()

# running
rate = net.rate.value.sum() / net.N.num / duration * 1e3
print(f'scale={scale}, size={net.N.num}, time = {t1 - t0} s, '
f'firing rate = {rate} Hz')


def check_firing_rate():
for s in [1, 2, 4, 6, 8, 10, 20, 40, 60, 80, 100]:
run_a_simulation(scale=s, duration=5e3)


if __name__ == '__main__':
check_firing_rate()
7 changes: 6 additions & 1 deletion lib/gpu_taichi_kernel_call.cu
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@ namespace brain_taichi {
void **buffers,
const char *opaque,
std::size_t opaque_len) {
cudaStreamSynchronize(stream);
taichi_kernel->set_cuda_stream(stream);
OpaqueStruct data = parseOpaque(opaque, opaque_len);

// restruct shape_list, it's a 2d array and the shape of it is (in_num+out_num, the max of dim_count)
// restruct shape_list, it's a 2d array and the shape
// of it is (in_num+out_num, the max of dim_count)
int param_total_num = data.in_num + data.out_num;
uint32_t shape_list_2d[param_total_num][8];
for (int i = 0; i < param_total_num; i++) {
Expand All @@ -19,8 +21,10 @@ namespace brain_taichi {
}
}

// Load the taichi kernel
taichi_kernel->load(data.kernel_aot_path.c_str());

// push the input data
for (int i = 0; i < data.in_num; i++) {
push_input(data.type_list[i],
buffers[i],
Expand All @@ -29,6 +33,7 @@ namespace brain_taichi {
shape_list_2d[i]);
}

// push the output data
for (int i = 0; i < data.out_num; i++) {
push_output(data.type_list[i + data.in_num],
buffers[i + data.in_num],
Expand Down

0 comments on commit e0bfa17

Please sign in to comment.