The CPU kernels do not build correctly. #48

Open
LizhengyuSJTU opened this issue Dec 12, 2024 · 18 comments
Labels: bug (Something isn't working)

LizhengyuSJTU commented Dec 12, 2024

Running the script below fails with 'RuntimeError: The CPU kernels do not build correctly. Please check the installation of braintaichi.' on a Windows 11 system with Python 3.11.

The code
import numpy as np
import time
import brainpy as bp
import brainpy.math as bm
import matplotlib.pyplot as plt
from typing import Union, Sequence, Callable, Optional
from brainpy._src.context import share
from brainpy._src.initialize import parameter
from brainpy._src.dyn import _docs
from brainpy._src.dyn.base import SynDyn
from brainpy._src.integrators.joint_eq import JointEq
from brainpy._src.integrators.ode.generic import odeint
from brainpy._src.mixin import AlignPost, ReturnInfo
from brainpy.types import ArrayType
import types
import os
import cv2
from PIL import Image
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor
bm.set_platform('cpu')


def mkdir(fn):
    '''Create the directory, including any intermediate directories.'''
    os.makedirs(fn, exist_ok=True)


def fig_to_video(fig_paths, filename, frame_rate=24, delete_figs=False, formats=None):
    '''
    Combine figures into a video and/or GIF.
    '''
    if formats is None:
        formats = ['mp4', 'gif']

    # Create the output directory
    mkdir(os.path.dirname(filename))

    valid_fig_paths = fig_paths.copy()

    # Ensure there are valid figs to process
    if not valid_fig_paths:
        print("No valid figs to process.")
        return

    if 'mp4' in formats:
        # MP4 Video output
        video_filename = f"{filename}.mp4"
        frame = cv2.imread(valid_fig_paths[0])
        height, width, layers = frame.shape
        video_size = (width, height)
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(video_filename, fourcc,
                              frame_rate, video_size)

        for fig_path in valid_fig_paths:
            fig = cv2.imread(fig_path)
            if fig.shape[1] != video_size[0] or fig.shape[0] != video_size[1]:
                fig = cv2.resize(
                    fig, video_size, interpolation=cv2.INTER_LANCZOS4)
            out.write(fig)
        out.release()

    if 'gif' in formats:
        # GIF output
        gif_filename = f"{filename}.gif"
        figs = [Image.open(fig_path) for fig_path in valid_fig_paths]
        resized_figs = [fig.resize(
            (figs[0].width, figs[0].height), Image.LANCZOS) for fig in figs]
        resized_figs[0].save(gif_filename, save_all=True,
                               append_images=resized_figs[1:], duration=1000/frame_rate, loop=0)

    if delete_figs:
        # Delete figs after processing all formats
        for fig_path in valid_fig_paths:
            os.remove(fig_path)
        print("All valid figs deleted.")

    cv2.destroyAllWindows()


def split_list(lst, n):
    '''
    Split a list into n sublists of as equal size as possible.

    Parameters:
    - lst: list
        The list to split.
    - n: int
        The number of sublists.

    Returns:
    list: A list containing the n sublists.
    '''
    # Compute the base length of each sublist
    length = len(lst)
    size = length // n
    remainder = length % n

    # Build the sublists
    divided_list = []
    start = 0
    for i in range(n):
        # The first `remainder` sublists get one extra element
        sublist_size = size + 1 if i < remainder else size
        # Append the sublist to the result
        divided_list.append(lst[start:start + sublist_size])
        # Advance the start index of the next sublist
        start += sublist_size
    return divided_list
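
# Illustrative example of split_list:
#   split_list([1, 2, 3, 4, 5], 2) -> [[1, 2, 3], [4, 5]]
# (the first `remainder` sublists are one element longer)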


def flatten_list(input_list, level=None):
    '''
    Flatten a nested list (each level must itself be a list).

    Parameters:
    - input_list: the nested list
    - level: number of levels to flatten; 1 flattens one level from the outside in, 2 flattens two levels, and so on; None flattens all levels
    '''
    if level is None:
        level = float('inf')  # flatten all levels by default

    def flatten_recursive(lst, curr_level):
        flattened = []
        for item in lst:
            if isinstance(item, list) and curr_level < level:
                flattened.extend(flatten_recursive(item, curr_level + 1))
            else:
                flattened.append(item)
        return flattened

    return flatten_recursive(input_list, 0)
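
# Illustrative examples of flatten_list:
#   flatten_list([[1, [2]], [3]], level=1) -> [1, [2], 3]
#   flatten_list([[1, [2]], [3]])          -> [1, 2, 3]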


def multi_process(process_num, func, args_list=None, kwargs_list=None, func_name=''):
    '''
    Run a function in parallel over multiple processes.

    Parameters:
    - process_num: int, number of parallel processes (error locations are hard to read under multiprocessing, so set process_num to 1 while testing; the code then runs and reports errors in the normal way)
    - func: function, the function to run in parallel
    - args_list: list, the positional-argument tuples for each call
    - kwargs_list: list, the keyword-argument dicts for each call
    - func_name: str, name of the function (or any task label worth displaying)

    Notes:
    If args_list or kwargs_list has length 1, it is broadcast to process_num.
    Writing args_list = [(1), (2)] is wrong: each entry must be a tuple, i.e. args_list = [(1,), (2,)]; see the usage sketch after this function.
    Calling multi_process from within multi_process automatically falls back to single-process execution (args_list and kwargs_list are then flattened).
    '''
    if args_list is None:
        args_list = [()]
    if kwargs_list is None:
        kwargs_list = [{}]
    for i, args in enumerate(args_list):
        if args is None:
            args_list[i] = ()
    for i, kwargs in enumerate(kwargs_list):
        if kwargs is None:
            kwargs_list[i] = {}
    if len(args_list) != process_num:
        if len(args_list) == 1:
            args_list = args_list * process_num
        elif process_num == 1:
            args_list = flatten_list(args_list, level=1)
        else:
            raise ValueError("The length of args_list must be equal to process_num or 1.")
    if len(kwargs_list) != process_num:
        if len(kwargs_list) == 1:
            kwargs_list = kwargs_list * process_num
        elif process_num == 1:
            kwargs_list = flatten_list(kwargs_list, level=1)
        else:
            raise ValueError("The length of kwargs_list must be equal to process_num or 1.")

    if process_num != 1:
        results = []
        # Run the tasks with ProcessPoolExecutor
        with ProcessPoolExecutor(max_workers=process_num) as executor:
            # Submit the tasks
            futures = [executor.submit(func, *args, **kwargs) for args, kwargs in zip(args_list, kwargs_list)]

            # Wait for the futures in submission order and collect the results
            for future in futures:
                try:
                    # Fetch results in the order the futures were submitted, so results match the submission order
                    results.append(future.result())
                except Exception as e:
                    results.append(None)
                    print(f"An error occurred: {e}")
        return results
    elif process_num == 1:
        return [func(*args, **kwargs) for args, kwargs in zip(args_list, kwargs_list)]
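
# Usage sketch for multi_process (hypothetical function f; note the 1-tuples):
#   def f(x, scale=1.):
#       return x * scale
#   multi_process(2, f, args_list=[(1,), (2,)], kwargs_list=[{'scale': 2.}, {'scale': 2.}])
#   -> [2.0, 4.0]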


def part_list_for(func, for_list, for_idx_name, *args, **kwargs):
    results = []
    for i in for_list:
        results.append(func(*args, **{**kwargs, for_idx_name: i}))
    return results


def multi_process_list_for(process_num, func, args=None, kwargs=None, for_list=None, for_idx_name='i', func_name=''):
    '''
    Run a for loop of the form `for i in for_list` in parallel processes.

    Parameters:
    - process_num: int, number of parallel processes
    - func: function, the function to run in parallel
    - args: positional arguments of func (not recommended, since the position of the loop index within func is not fixed)
    - kwargs: keyword arguments of func
    - func_name: str, name of the function (or any task label worth displaying)

    Notes:
    Only use this when the loop iterations are independent of each other.
    For items() use multi_process_items_for; for enumerate() use multi_process_enumerate_for; zip() is not supported here, but it can be rewritten as a plain for loop.
    '''
    if args is None:
        args = ()
    if kwargs is None:
        kwargs = {}
    for_list = list(for_list)   # materialize for_list in case it is a generator, e.g. range(10)
    divided_list = split_list(for_list, process_num)
    args_list = [(func, divided, for_idx_name)+args for divided in divided_list]
    kwargs_list = [kwargs] * process_num
    return flatten_list(multi_process(process_num, part_list_for, args_list, kwargs_list, func_name), level=1)
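
# Usage sketch for multi_process_list_for (hypothetical function g):
#   def g(i, offset=0):
#       return i + offset
#   multi_process_list_for(2, g, kwargs={'offset': 10}, for_list=range(4), for_idx_name='i')
#   -> [10, 11, 12, 13]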


class Func(SynDyn, AlignPost):
  def __init__(
      self,
      size: Union[int, Sequence[int]],
      keep_size: bool = False,
      sharding: Optional[Sequence[str]] = None,
      name: Optional[str] = None,
      mode: Optional[bm.Mode] = None,
      func: Optional[Callable] = None,
  ):
    super().__init__(name=name,
                     mode=mode,
                     size=size,
                     keep_size=keep_size,
                     sharding=sharding)

    # parameters
    self.func = func

    # function
    self._current = None

    self.reset_state(self.mode)

  def reset_state(self, batch_or_mode=None, **kwargs):
    self.g = self.init_variable(bm.zeros, batch_or_mode)

  def update(self, x=None):
    self.g.value = bm.ones_like(self.g.value) * self.func(share['t'])
    return self.g.value

  def add_current(self, x):
    self.g.value += x

  def return_info(self):
    return self.g


class ExponentialCOBA(bp.Projection):
  def __init__(self, pre, post, delay, tau, E, comm):
    super().__init__()
    self.proj = bp.dyn.FullProjAlignPost(
      pre=pre,
      delay=delay,
      comm=comm,
      syn=bp.dyn.Expon(size=post.num, tau=tau),  # Exponential synapse
      out=bp.dyn.COBA(E=E),
      post=post
    )


class NormalizedDualExponV2(bp.dyn.DualExponV2):
    '''
    Adjust the default value of A (https://brainpy.readthedocs.io/en/latest/apis/generated/brainpy.dyn.DualExponV2.html) so that the whole kernel integrates to 1.

    Note: to read out g, use syntax like the following.
    Define the synapse:
    self.syn = bf.NormalizedDualExponCUBA(self.pre, self.post, delay=None, comm=bp.dnn.CSRLinear(bp.conn.FixedProb(1., pre=self.pre.num, post=self.post.num), 1.), tau_rise=2., tau_decay=20.)
    Get the synapse's two g's and its a:
    (self.syn.proj.refs['syn'].g_decay - self.syn.proj.refs['syn'].g_rise) * self.syn.proj.refs['syn'].a

    By contrast, the g of NormalizedExponCUBA can be read directly.
    '''
    def __init__(
        self,
        size: Union[int, Sequence[int]],
        keep_size: bool = False,
        sharding: Optional[Sequence[str]] = None,
        method: str = 'exp_auto',
        name: Optional[str] = None,
        mode: Optional[bm.Mode] = None,

        # synapse parameters
        tau_decay: Union[float, ArrayType, Callable] = 10.0,
        tau_rise: Union[float, ArrayType, Callable] = 1.,
        A: Optional[Union[float, ArrayType, Callable]] = None,
    ):
        super().__init__(name=name,
                            mode=mode,
                            size=size,
                            keep_size=keep_size,
                            sharding=sharding)

        def _format_dual_exp_A(self, A):
            A = parameter(A, sizes=self.varshape, allow_none=True, sharding=self.sharding)
            if A is None:
                A = 1 / (self.tau_decay - self.tau_rise)
            return A
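
        # With A = 1/(tau_decay - tau_rise), the dual-exponential kernel
        # A * (exp(-t/tau_decay) - exp(-t/tau_rise)) integrates to
        # A * (tau_decay - tau_rise) = 1 over t in [0, inf), which gives the
        # normalization described in the class docstring.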

        # parameters
        self.tau_rise = self.init_param(tau_rise)
        self.tau_decay = self.init_param(tau_decay)
        self.a = _format_dual_exp_A(self, A)

        # integrator
        self.integral = odeint(lambda g, t, tau: -g / tau, method=method)

        self.reset_state(self.mode)


class NormalizedDualExponCOBA(bp.Projection):
    def __init__(self, pre, post, delay, comm, tau_rise, tau_decay, E, out_label=None):
        super().__init__()

        self.proj = bp.dyn.FullProjAlignPostMg(
        pre=pre,
        delay=delay,
        comm=comm,
        syn=NormalizedDualExponV2.desc(post.num, tau_rise=tau_rise, tau_decay=tau_decay),
        out=bp.dyn.COBA.desc(E),
        post=post,
        out_label=out_label
        )


class FuncCUBA(bp.Projection): # CUBA: current-based synapse
  def __init__(self, pre, post, delay, func, out_label=None):
    super().__init__()
    self.proj = bp.dyn.FullProjAlignPost(
      pre=pre,
      delay=delay,
      comm=bp.dnn.AllToAll(pre.num, post.num, 1.),
      syn=Func(size=post.num, func=func),
      out=bp.dyn.CUBA(),
      post=post,
      out_label=out_label
    )


class FuncCOBA(bp.Projection): # COBA: conductance-based synapse
  def __init__(self, pre, post, delay, func, E, out_label=None):
    super().__init__()
    self.proj = bp.dyn.FullProjAlignPost(
      pre=pre,
      delay=delay,
      comm=bp.dnn.AllToAll(pre.num, post.num, 1.),
      syn=Func(size=post.num, func=func),
      out=bp.dyn.COBA(E=E),
      post=post,
      out_label=out_label
    )


def ij_conn(pre, post, pre_size, post_size):
    '''
    Build a conn with brainpy's bp.conn.IJConn.
    '''
    conn = bp.conn.IJConn(i=pre, j=post)
    conn = conn(pre_size=pre_size, post_size=post_size)
    return conn


def ij_comm(pre, post, pre_size, post_size, weight):
    '''
    Build a comm with brainpy's bp.conn.IJConn and bp.dnn.EventCSRLinear.
    '''
    conn = ij_conn(pre, post, pre_size, post_size)
    return bp.dnn.EventCSRLinear(conn, weight)
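
# Usage sketch for ij_comm (index and weight values are arbitrary placeholders):
# connect pre neuron 0 to post neurons 1 and 2 with weight 0.5 each.
#   comm = ij_comm(pre=np.array([0, 0]), post=np.array([1, 2]), pre_size=3, post_size=3, weight=np.array([0.5, 0.5]))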


class EINet(bp.DynamicalSystem):
    def __init__(self, grid_num, grid_distance, E_conn_grid_num, I_conn_grid_num, wE, wI, sigmaE, E_inp_func, I_inp_func):
        super().__init__()

        self.location, self.conn_weight = generate_conn_and_weight(grid_num, grid_distance, E_conn_grid_num, I_conn_grid_num, wE, wI, sigmaE)
        ne = self.location['E_num']
        ni = self.location['I_num']

        # Neurons
        self.E = bp.dyn.ExpIFRef(ne, V_rest=-70., V_reset=-70., V_th=-40., V_T=-60.6250, delta_T=6.5625, tau=20., R=1., tau_ref=5., V_initializer=bp.init.Normal(-55., 10.))
        self.I = bp.dyn.ExpIFRef(ni, V_rest=-70., V_reset=-70., V_th=-40., V_T=-60.6250, delta_T=6.5625, tau=20., R=1., tau_ref=5., V_initializer=bp.init.Normal(-55., 10.))
        self.E_inp = bp.dyn.InputGroup(size=1) # placeholder
        self.I_inp = bp.dyn.InputGroup(size=1) # placeholder

        # Connections
        neuron_comm = get_comm(self.location, self.conn_weight)
        self.E2E = NormalizedDualExponCOBA(pre=self.E, post=self.E, delay=0., comm=neuron_comm['E2E_comm'], tau_rise=0.3, tau_decay=2., E=0., out_label='E')
        self.E2I = NormalizedDualExponCOBA(pre=self.E, post=self.I, delay=0., comm=neuron_comm['E2I_comm'], tau_rise=0.3, tau_decay=2., E=0., out_label='E')
        self.I2E = NormalizedDualExponCOBA(pre=self.I, post=self.E, delay=0., comm=neuron_comm['I2E_comm'], tau_rise=0.3, tau_decay=3., E=-80., out_label='I')
        self.I2I = NormalizedDualExponCOBA(pre=self.I, post=self.I, delay=0., comm=neuron_comm['I2I_comm'], tau_rise=0.3, tau_decay=3., E=-80., out_label='I')

        # Additional inputs
        addition_func = lambda t: 0.001 * 20
        self.additionE2E = FuncCOBA(pre=self.E_inp, post=self.E, delay=0., func=addition_func, E=0., out_label='E')
        self.additionE2I = FuncCOBA(pre=self.I_inp, post=self.I, delay=0., func=addition_func, E=0., out_label='E')
        self.additionI2E = FuncCOBA(pre=self.E_inp, post=self.E, delay=0., func=addition_func, E=-80., out_label='I')
        self.additionI2I = FuncCOBA(pre=self.I_inp, post=self.I, delay=0., func=addition_func, E=-80., out_label='I')

        # Inputs
        self.E_inp2E = FuncCUBA(pre=self.E_inp, post=self.E, delay=0., func=E_inp_func, out_label='input')
        self.I_inp2I = FuncCUBA(pre=self.I_inp, post=self.I, delay=0., func=I_inp_func, out_label='input')

        # Poisson inputs
        self.Poisson_inp_for_E = bp.dyn.PoissonGroup(size=ne, freqs=10.)
        self.Poisson_inp_for_E2E = NormalizedDualExponCOBA(pre=self.Poisson_inp_for_E, post=self.E, delay=0., comm=bp.dnn.OneToOne(ne, ne, 0.2), tau_rise=0.3, tau_decay=2., E=0., out_label='input')
        self.Poisson_inp_for_I = bp.dyn.PoissonGroup(size=ni, freqs=10.)
        self.Poisson_inp_for_I2I = NormalizedDualExponCOBA(pre=self.Poisson_inp_for_I, post=self.I, delay=0., comm=bp.dnn.OneToOne(ni, ni, 0.2), tau_rise=0.3, tau_decay=2., E=0., out_label='input')

    def update(self):
        self.E2E()
        self.E2I()
        self.I2E()
        self.I2I()
        self.additionE2E()
        self.additionE2I()
        self.additionI2E()
        self.additionI2I()
        self.E_inp2E()
        self.I_inp2I()
        self.Poisson_inp_for_E2E()
        self.Poisson_inp_for_I2I()
        self.E()
        self.I()


def generate_location(grid_num, grid_distance):
  '''Generate grid_num*grid_num positions, with grid_distance the spacing between adjacent positions; on the mesh, E neurons sit at every grid point (step 1) and I neurons at every other grid point (step 2).'''
  grid_loc_x = np.arange(0, grid_num) * grid_distance
  grid_loc_y = np.arange(0, grid_num) * grid_distance

  grid_idx_E = np.arange(0, grid_num)
  grid_idx_I = grid_idx_E[::2]

  E_x_mesh, E_y_mesh = np.meshgrid(grid_loc_x[grid_idx_E], grid_loc_y[grid_idx_E])
  E_i_mesh, E_j_mesh = np.meshgrid(grid_idx_E, grid_idx_E)
  I_x_mesh, I_y_mesh = np.meshgrid(grid_loc_x[grid_idx_I], grid_loc_y[grid_idx_I])
  I_i_mesh, I_j_mesh = np.meshgrid(grid_idx_I, grid_idx_I)
  return E_x_mesh, E_y_mesh, E_i_mesh, E_j_mesh, I_x_mesh, I_y_mesh, I_i_mesh, I_j_mesh


def part_generate_conn_and_weight(idx, location, pre_group, pre_step, post_group, post_step, E_conn_grid_num, I_conn_grid_num, wE, wI, sigmaE):
  # Choose conn_grid_num based on the pre group
  if pre_group == 'E':
    conn_grid_num = E_conn_grid_num
  elif pre_group == 'I':
    conn_grid_num = I_conn_grid_num
  conn_weight_pre_idx = []
  conn_weight_post_idx = []
  conn_weight_weight = []
  # pre_i, pre_j and post_i, post_j are indices on the grid, so E and I have different spacings; to use brainpy they must be converted to contiguous indices
  pre_i, pre_j = location[f'{pre_group}_i_mesh'].flatten()[idx], location[f'{pre_group}_j_mesh'].flatten()[idx]
  # Only test candidates inside a square window to reduce computation (the bounds must stay integer multiples of step)
  square_range = conn_grid_num + post_step # enlarge the window slightly
  post_i_start = int((pre_i - square_range) // post_step * post_step)
  post_i_end = int((pre_i + square_range) // post_step * post_step)
  post_j_start = int((pre_j - square_range) // post_step * post_step)
  post_j_end = int((pre_j + square_range) // post_step * post_step)
  post_i_range = np.arange(post_i_start, post_i_end, step=post_step)
  post_j_range = np.arange(post_j_start, post_j_end, step=post_step)
  # Apply periodic boundary conditions (grid_num is taken from module scope)
  post_i_range = post_i_range % grid_num
  post_j_range = post_j_range % grid_num
  # Remove possible duplicates
  post_i_range = np.unique(post_i_range)
  post_j_range = np.unique(post_j_range)
  mesh_post_i, mesh_post_j = np.meshgrid(post_i_range, post_j_range)
  for post_i, post_j in zip(mesh_post_i.flatten(), mesh_post_j.flatten()):
    # Apply periodic boundary conditions
    i_distance = np.min([np.abs((pre_i - post_i)), np.abs((pre_i - post_i + grid_num)), np.abs((pre_i - post_i - grid_num))])
    j_distance = np.min([np.abs((pre_j - post_j)), np.abs((pre_j - post_j + grid_num)), np.abs((pre_j - post_j - grid_num))])
    # Compute the L2 distance
    l2_distance = np.sqrt(i_distance ** 2 + j_distance ** 2)
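    # e.g. with grid_num = 100, pre_i = 2, post_i = 98: i_distance = min(96, 196, 4) = 4 (the wrap-around distance)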
    if l2_distance <= conn_grid_num:
      conn_weight_pre_idx.append(int(round((pre_i * location[f'{pre_group}_grid_num'] + pre_j)/pre_step))) # divide by the step to get the correct contiguous index
      conn_weight_post_idx.append(int(round((post_i * location[f'{post_group}_grid_num'] + post_j)/post_step))) # divide by the step to get the correct contiguous index
      if pre_group == 'E':
        conn_weight_weight.append(wE * np.exp(- l2_distance**2 / sigmaE) * np.abs(np.random.normal(1, 0.4)))
      elif pre_group == 'I':
        conn_weight_weight.append(wI * np.abs(np.random.normal(1, 0.4)))
  return conn_weight_pre_idx, conn_weight_post_idx, conn_weight_weight


def generate_conn_and_weight(grid_num, grid_distance, E_conn_grid_num, I_conn_grid_num, wE, wI, sigmaE):
  '''Generate connection indices and connection weights'''
  location = {}
  # Generate positions
  location['E_x_mesh'], location['E_y_mesh'], location['E_i_mesh'], location['E_j_mesh'], location['I_x_mesh'], location['I_y_mesh'], location['I_i_mesh'], location['I_j_mesh'] = generate_location(grid_num, grid_distance)
  location['E_num'] = location['E_x_mesh'].size
  location['I_num'] = location['I_x_mesh'].size
  location['E_step'] = 1
  location['I_step'] = 2
  location['E_grid_num'] = grid_num // location['E_step']
  location['I_grid_num'] = grid_num // location['I_step']

  # Generate connection indices and connection weights
  conn_weight = {}
  for pre_group in ['E', 'I']:
    for post_group in ['E', 'I']:
      print(f'{pre_group}2{post_group}')
      conn_weight[f'{pre_group}2{post_group}_pre_idx'] = []
      conn_weight[f'{pre_group}2{post_group}_post_idx'] = []
      conn_weight[f'{pre_group}2{post_group}_weight'] = []
      pre_step = location[f'{pre_group}_step']
      post_step = location[f'{post_group}_step']
      # Speed up with multi_process (process_num is taken from module scope)
      r = multi_process_list_for(process_num=process_num, func=part_generate_conn_and_weight, kwargs={'location': location, 'pre_group': pre_group, 'pre_step': pre_step, 'post_group': post_group, 'post_step': post_step, 'E_conn_grid_num': E_conn_grid_num, 'I_conn_grid_num': I_conn_grid_num, 'wE': wE, 'wI': wI, 'sigmaE': sigmaE}, for_list=np.arange(location[f'{pre_group}_i_mesh'].size), for_idx_name='idx')
      # Collect the results
      for sub_r in r:
        conn_weight[f'{pre_group}2{post_group}_pre_idx'].extend(sub_r[0])
        conn_weight[f'{pre_group}2{post_group}_post_idx'].extend(sub_r[1])
        conn_weight[f'{pre_group}2{post_group}_weight'].extend(sub_r[2])
      # Convert to np.array
      conn_weight[f'{pre_group}2{post_group}_pre_idx'] = np.array(conn_weight[f'{pre_group}2{post_group}_pre_idx'])
      conn_weight[f'{pre_group}2{post_group}_post_idx'] = np.array(conn_weight[f'{pre_group}2{post_group}_post_idx'])
      conn_weight[f'{pre_group}2{post_group}_weight'] = np.array(conn_weight[f'{pre_group}2{post_group}_weight'])
      # Check for duplicates (a duplicate means the same (pre, post) pair appears more than once)
      pre_post_idx = np.stack([conn_weight[f'{pre_group}2{post_group}_pre_idx'], conn_weight[f'{pre_group}2{post_group}_post_idx']], axis=1)
      unique_pre_post_idx, unique_idx = np.unique(pre_post_idx, axis=0, return_index=True)
      if len(unique_idx) != len(pre_post_idx):
        print('Duplicate connections found')
  return location, conn_weight


def get_comm(location, conn_weight):
  '''Build the comms'''
  neuron_comm = {}
  for pre_group in ['E', 'I']:
    for post_group in ['E', 'I']:
      print(f'{pre_group}2{post_group}_comm')
      neuron_comm[f'{pre_group}2{post_group}_comm'] = ij_comm(pre=conn_weight[f'{pre_group}2{post_group}_pre_idx'], post=conn_weight[f'{pre_group}2{post_group}_post_idx'], pre_size=location[f'{pre_group}_num'], post_size=location[f'{post_group}_num'], weight=conn_weight[f'{pre_group}2{post_group}_weight'])
  return neuron_comm


def set_xylim(ax, grid_num, grid_distance):
  xlim = [0, (grid_num-1) * grid_distance]
  ylim = [0, (grid_num-1) * grid_distance]
  ax.set_xlim(xlim)
  ax.set_ylim(ylim)
  ax.set_aspect('equal')


def visualize_V_one_step(i, basedir, vmin, vmax, E_V, I_V, ts, location, grid_num, grid_distance, s):
  fig_path = os.path.join(basedir, 'V', f'{i}.png')

  fig, axes = plt.subplots(1, 2, figsize=(12, 6), sharex=True, sharey=True)

  ax = axes[0]
  sc = ax.scatter(location['E_x_mesh'].flatten(), location['E_y_mesh'].flatten(), c=E_V[i], cmap=plt.cm.jet, vmin=vmin, vmax=vmax, s=s, clip_on=False)
  cbar = plt.colorbar(sc, ax=ax)
  ax.set_title('E')
  set_xylim(ax, grid_num, grid_distance)

  ax = axes[1]
  sc = ax.scatter(location['I_x_mesh'].flatten(), location['I_y_mesh'].flatten(), c=I_V[i], cmap=plt.cm.jet, vmin=vmin, vmax=vmax, s=s, clip_on=False)
  cbar = plt.colorbar(sc, ax=ax)
  ax.set_title('I')
  set_xylim(ax, grid_num, grid_distance)

  fig.suptitle(f't={ts[i]:.3f}')
  fig.savefig(fig_path)
  plt.close(fig)
  return fig_path


class SNN_analyzer:
  def __init__(self, grid_num, grid_distance, E_conn_grid_num, I_conn_grid_num, wE, wI, sigmaE, E_inp_func, I_inp_func, time_period):
    self.grid_num = grid_num
    self.grid_distance = grid_distance
    self.E_conn_grid_num = E_conn_grid_num
    self.I_conn_grid_num = I_conn_grid_num
    self.wE = wE
    self.wI = wI
    self.sigmaE = sigmaE
    self.E_inp_func = E_inp_func
    self.I_inp_func = I_inp_func
    self.net = EINet(grid_num=grid_num, grid_distance=grid_distance, E_conn_grid_num=E_conn_grid_num, I_conn_grid_num=I_conn_grid_num, wE=wE, wI=wI, sigmaE=sigmaE, E_inp_func=E_inp_func, I_inp_func=I_inp_func)
    self.location = self.net.location
    self.conn_weight = self.net.conn_weight
    monitors = {'E.spike': self.net.E.spike, 'I.spike': self.net.I.spike, 'E.V': self.net.E.V, 'I.V': self.net.I.V}
    monitors['E.E_current'] = lambda: self.net.E.sum_current_inputs(self.net.E.V, label='E')
    monitors['E.I_current'] = lambda: self.net.E.sum_current_inputs(self.net.E.V, label='I')
    monitors['I.E_current'] = lambda: self.net.I.sum_current_inputs(self.net.I.V, label='E')
    monitors['I.I_current'] = lambda: self.net.I.sum_current_inputs(self.net.I.V, label='I')
    monitors['E.input_current'] = lambda: self.net.E.sum_current_inputs(self.net.E.V, label='input')
    monitors['I.input_current'] = lambda: self.net.I.sum_current_inputs(self.net.I.V, label='input')
    monitors['E.E2E_g'] = lambda: (self.net.E2E.proj.refs['syn'].g_decay - self.net.E2E.proj.refs['syn'].g_rise) * self.net.E2E.proj.refs['syn'].a
    monitors['E.E2I_g'] = lambda: (self.net.E2I.proj.refs['syn'].g_decay - self.net.E2I.proj.refs['syn'].g_rise) * self.net.E2I.proj.refs['syn'].a
    monitors['I.I2E_g'] = lambda: (self.net.I2E.proj.refs['syn'].g_decay - self.net.I2E.proj.refs['syn'].g_rise) * self.net.I2E.proj.refs['syn'].a
    monitors['I.I2I_g'] = lambda: (self.net.I2I.proj.refs['syn'].g_decay - self.net.I2I.proj.refs['syn'].g_rise) * self.net.I2I.proj.refs['syn'].a
    self.runner = bp.DSRunner(self.net, monitors=monitors)
    self.runner.run(duration=time_period)
    self.indices = np.arange(int(time_period / bm.get_dt()))
    self.ts = self.indices * bm.get_dt()
    self.E_spike = self.runner.mon['E.spike']
    self.I_spike = self.runner.mon['I.spike']
    self.E_V = self.runner.mon['E.V']
    self.I_V = self.runner.mon['I.V']
    basedir = '../../results'
    current_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
    self.basedir = os.path.join(basedir, current_time)
    mkdir(self.basedir)
    self.s = np.pi * (100 / self.grid_num)**2

  def visualize_V(self, start=0, stop=None, step=10, frame_rate=5, delete_figs=True):
    if stop is None:
      stop = np.min([len(self.ts), 1000])
    s = self.s
    mkdir(os.path.join(self.basedir, 'V'))
    fig_paths = multi_process_list_for(process_num=process_num, func=visualize_V_one_step, kwargs={'basedir': self.basedir, 'vmin': self.net.E.V_rest, 'vmax': self.net.E.V_th, 'E_V': self.E_V, 'I_V': self.I_V, 'ts': self.ts, 'location': self.location, 'grid_num': self.grid_num, 'grid_distance': self.grid_distance, 's': s}, for_list=np.arange(start, stop, step), for_idx_name='i')
    fig_to_video(fig_paths, os.path.join(self.basedir, 'V', 'V_video'), frame_rate=frame_rate, delete_figs=delete_figs)

  def visualize_V_one_neuron(self, neuron_group, row_idx, col_idx):
    if neuron_group == 'E':
      neuron_idx = row_idx * self.grid_num + col_idx
    elif neuron_group == 'I':
      neuron_idx = row_idx * self.grid_num//2 + col_idx
    mkdir(os.path.join(self.basedir, 'V_one_neuron'))
    fig, ax = plt.subplots(1, 1, figsize=(6, 6))
    s = self.s
    if neuron_group == 'E':
      V = self.E_V
    elif neuron_group == 'I':
      V = self.I_V
    ax.scatter(self.indices, V[:, neuron_idx], s=s, clip_on=False)
    ax.set_title(f'{neuron_group} row {row_idx} col {col_idx}')
    ax.set_xlabel('t')
    ax.set_ylabel('V')
    fig.savefig(os.path.join(self.basedir, 'V_one_neuron', f'{neuron_group}_row_{row_idx}_col_{col_idx}.png'))
    plt.close(fig)

  def visualize_spike(self):
    mkdir(os.path.join(self.basedir, 'spike'))
    fig, ax = plt.subplots(2, 1, figsize=(6, 6))
    bp.visualize.raster_plot(self.ts, self.E_spike, ax=ax[0])
    bp.visualize.raster_plot(self.ts, self.I_spike, ax=ax[1])
    ax[0].set_title('E')
    ax[1].set_title('I')
    fig.suptitle('Spike')
    fig.savefig(os.path.join(self.basedir, 'spike', 'spike.png'))
    plt.close(fig)

  def visualize_current(self, E_neuron_idx=None, I_neuron_idx=None):
    if isinstance(E_neuron_idx, int):
      E_neuron_idx = (E_neuron_idx, )
    if isinstance(I_neuron_idx, int):
      I_neuron_idx = (I_neuron_idx, )
    if E_neuron_idx is None:
      E_neuron_idx = slice(None)
    if I_neuron_idx is None:
      I_neuron_idx = slice(None)
    mkdir(os.path.join(self.basedir, 'current'))
    fig, axes = plt.subplots(3, 2, figsize=(12, 6))
    ax = axes[0, 0]
    ax.plot(self.ts, np.mean(self.runner.mon['E.E_current'][:, E_neuron_idx], axis=1), label='E')
    ax = axes[1, 0]
    ax.plot(self.ts, np.mean(self.runner.mon['E.I_current'][:, E_neuron_idx], axis=1), label='I')
    ax = axes[2, 0]
    ax.plot(self.ts, np.mean(self.runner.mon['E.input_current'][:, E_neuron_idx], axis=1), label='input')
    for ax in axes[:, 0]:
      ax.set_title('E')
      ax.legend()
    ax = axes[0, 1]
    ax.plot(self.ts, np.mean(self.runner.mon['I.E_current'][:, I_neuron_idx], axis=1), label='E')
    ax = axes[1, 1]
    ax.plot(self.ts, np.mean(self.runner.mon['I.I_current'][:, I_neuron_idx], axis=1), label='I')
    ax = axes[2, 1]
    ax.plot(self.ts, np.mean(self.runner.mon['I.input_current'][:, I_neuron_idx], axis=1), label='input')
    for ax in axes[:, 1]:
      ax.set_title('I')
      ax.legend()
    fig.suptitle('Current')
    fig.savefig(os.path.join(self.basedir, 'current', f'current_{E_neuron_idx}_{I_neuron_idx}.png'))
    plt.close(fig)

if __name__ == '__main__':
  bm.set_dt(0.1)
  process_num = 1
  grid_num = 100
  grid_distance = 6.1 * 10**(-3) # mm
  # Small-bump case
  E_conn_grid_num = 10
  I_conn_grid_num = 40
  wE = 20. * 0.2235 * 0.001 / np.sqrt(0.4) * 80
  wI = 20. * 0.0578 * 0.001 / np.sqrt(0.4) * 40
  # Case where the activity merges into one sheet (overrides the parameter values above)
  E_conn_grid_num = 36
  I_conn_grid_num = 12
  wE = 20. * 0.2235 * 0.001 / np.sqrt(0.4) * 50
  wI = 20. * 0.0578 * 0.001 / np.sqrt(0.4) * 120
  sigmaE = 18.
  time_period = 1000.  # ms
  def E_inp_func(t):
    r = np.zeros((grid_num, grid_num))
    return r.flatten()
  I_inp_func = lambda x: 0.
  snn_analyzer = SNN_analyzer(grid_num, grid_distance, E_conn_grid_num, I_conn_grid_num, wE, wI, sigmaE, E_inp_func, I_inp_func, time_period)
  snn_analyzer.visualize_spike()
  snn_analyzer.visualize_V_one_neuron('E', 2, 17)
  snn_analyzer.visualize_V_one_neuron('E', 2, 18)
  snn_analyzer.visualize_V_one_neuron('E', 2, 2)
  snn_analyzer.visualize_V_one_neuron('E', 18, 18)
  snn_analyzer.visualize_V_one_neuron('E', 3, 18)
  snn_analyzer.visualize_current()
  snn_analyzer.visualize_current(E_neuron_idx=0, I_neuron_idx=0)
  snn_analyzer.visualize_current(E_neuron_idx=grid_num//2 * grid_num + grid_num//2, I_neuron_idx=grid_num//2//2 * grid_num//2 + grid_num//2//2)
  snn_analyzer.visualize_V(start=500, stop=4000, step=25, frame_rate=15, delete_figs=True)

Traceback (most recent call last):
  File "C:\Program Files\Python311\Lib\site-packages\jax\_src\interpreters\mlir.py", line 2150, in _lower_jaxpr_to_fun_cached
    func_op = ctx.cached_primitive_lowerings[key]
              ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^
KeyError: (None, let _where = { lambda ; a:bool[10000] b:f32[] c:f32[10000]. let
    d:f32[] = convert_element_type[new_dtype=float32 weak_type=False] b
    e:f32[10000] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(10000,)
      sharding=None
    ] d
    f:f32[10000] = select_n a c e
  in (f,) } in
let _where1 = { lambda ; g:bool[2500] h:f32[] i:f32[2500]. let
    j:f32[] = convert_element_type[new_dtype=float32 weak_type=False] h
    k:f32[2500] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(2500,)
      sharding=None
    ] j
    l:f32[2500] = select_n g i k
  in (l,) } in
{ lambda ; m:f32[40530000] n:i32[40530000] o:i32[10001] p:f32[10132500] q:i32[10132500]
    r:i32[10001] s:f32[1102500] t:i32[1102500] u:i32[2501] v:f32[282500] w:i32[282500]
    x:i32[2501] y:f32[10000] z:f32[2500] ba:f32[2500] bb:f32[2500] bc:bool[10000]
    bd:f32[2500] be:f32[10000] bf:bool[2500] bg:f32[10000] bh:f32[2500] bi:f32[10000]
    bj:f32[10000] bk:f32[10000] bl:bool[10000] bm:f32[1] bn:f32[2500] bo:f32[10000]
    bp:f32[10000] bq:f32[1] br:f32[2500] bs:f32[10000] bt:f32[10000] bu:f32[2500]
    bv:f32[2500] bw:bool[2500] bx:f32[10000] by:f32[10000] bz:f32[2500] ca:f32[2500]
    cb:i32[]. let
    cc:f32[] = convert_element_type[new_dtype=float32 weak_type=True] cb
    cd:f32[] = mul cc 0.1
    ce:f32[] = add 0.0 cd
    cf:f32[10000] = braintaichi_custom_op_43[
      float_as_event=True
      outs=(ShapedArray(float32[10000]),)
      shape=(10000, 10000)
      transpose=True
    ] m n o bc
    cg:f32[10000] = add bo cf
    ch:f32[10000] = add bp cf
    ci:f32[2500] = braintaichi_custom_op_43[
      float_as_event=True
      outs=(ShapedArray(float32[2500]),)
      shape=(10000, 2500)
      transpose=True
    ] p q r bc
    cj:f32[2500] = add bh ci
    ck:f32[2500] = add br ci
    cl:f32[10000] = braintaichi_custom_op_43[
      float_as_event=True
      outs=(ShapedArray(float32[10000]),)
      shape=(2500, 10000)
      transpose=True
    ] s t u bf
    cm:f32[10000] = add bs cl
    cn:f32[10000] = add bt cl
    co:f32[2500] = braintaichi_custom_op_43[
      float_as_event=True
      outs=(ShapedArray(float32[2500]),)
      shape=(2500, 2500)
      transpose=True
    ] v w x bf
    cp:f32[2500] = add bu co
    cq:f32[2500] = add bv co
    cr:f32[10000] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(10000,)
      sharding=None
    ] 1.0
    cs:f32[10000] = mul cr 0.019999999552965164
    ct:f32[2500] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(2500,)
      sharding=None
    ] 1.0
    cu:f32[2500] = mul ct 0.019999999552965164
    cv:f32[10000] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(10000,)
      sharding=None
    ] 1.0
    cw:f32[10000] = mul cv 0.019999999552965164
    cx:f32[2500] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(2500,)
      sharding=None
    ] 1.0
    cy:f32[2500] = mul cx 0.019999999552965164
    cz:f32[10000] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(10000,)
      sharding=None
    ] 1.0
    da:f32[10000] = device_put[
      copy_semantics=[<CopySemantics.ALIAS: 1>]
      devices=[None]
      srcs=[None]
    ] y
    db:f32[10000] = mul cz da
    dc:f32[2500] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(2500,)
      sharding=None
    ] 1.0
    dd:f32[2500] = mul dc 0.0
    de:i32[10000] = convert_element_type[new_dtype=int32 weak_type=True] bl
    df:i32[10000] = mul de 10000
    dg:f32[10000] = convert_element_type[new_dtype=float32 weak_type=False] df
    dh:f32[10000] = add bx dg
    di:f32[10000] = convert_element_type[new_dtype=float32 weak_type=False] df
    dj:f32[10000] = add by di
    dk:i32[2500] = convert_element_type[new_dtype=int32 weak_type=True] bw
    dl:i32[2500] = mul dk 2500
    dm:f32[2500] = convert_element_type[new_dtype=float32 weak_type=False] dl
    dn:f32[2500] = add bz dm
    do:f32[2500] = convert_element_type[new_dtype=float32 weak_type=False] dl
    dp:f32[2500] = add ca do
    dq:f32[10000] = neg cg
    dr:f32[10000] = div dq 0.30000001192092896
    ds:f32[10000] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(10000,)
      sharding=None
    ] 1.0
    dt:f32[10000] = div ds 0.30000001192092896
    du:f32[10000] = neg dt
    dv:f32[10000] = mul 0.10000000149011612 du
    dw:f32[10000] = abs dv
    dx:bool[10000] = le dw 9.999999747378752e-06
    dy:f32[10000] = div dv 2.0
    dz:f32[10000] = add 1.0 dy
    ea:f32[10000] = mul dv dv
    eb:f32[10000] = div ea 6.0
    ec:f32[10000] = add dz eb
    ed:f32[10000] = exp dv
    ee:f32[10000] = sub ed 1.0
    ef:f32[10000] = div ee dv
    eg:f32[10000] = select_n dx ef ec
    eh:f32[10000] = mul 0.10000000149011612 eg
    ei:f32[10000] = mul eh dr
    ej:f32[10000] = add cg ei
    ek:f32[10000] = neg ch
    el:f32[10000] = div ek 2.0
    em:f32[10000] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(10000,)
      sharding=None
    ] 1.0
    en:f32[10000] = div em 2.0
    eo:f32[10000] = neg en
    ep:f32[10000] = mul 0.10000000149011612 eo
    eq:f32[10000] = abs ep
    er:bool[10000] = le eq 9.999999747378752e-06
    es:f32[10000] = div ep 2.0
    et:f32[10000] = add 1.0 es
    eu:f32[10000] = mul ep ep
    ev:f32[10000] = div eu 6.0
    ew:f32[10000] = add et ev
    ex:f32[10000] = exp ep
    ey:f32[10000] = sub ex 1.0
    ez:f32[10000] = div ey ep
    fa:f32[10000] = select_n er ez ew
    fb:f32[10000] = mul 0.10000000149011612 fa
    fc:f32[10000] = mul fb el
    fd:f32[10000] = add ch fc
    fe:f32[10000] = sub fd ej
    ff:f32[10000] = mul 0.5882353186607361 fe
    fg:f32[10000] = neg cm
    fh:f32[10000] = div fg 0.30000001192092896
    fi:f32[10000] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(10000,)
      sharding=None
    ] 1.0
    fj:f32[10000] = div fi 0.30000001192092896
    fk:f32[10000] = neg fj
    fl:f32[10000] = mul 0.10000000149011612 fk
    fm:f32[10000] = abs fl
    fn:bool[10000] = le fm 9.999999747378752e-06
    fo:f32[10000] = div fl 2.0
    fp:f32[10000] = add 1.0 fo
    fq:f32[10000] = mul fl fl
    fr:f32[10000] = div fq 6.0
    fs:f32[10000] = add fp fr
    ft:f32[10000] = exp fl
    fu:f32[10000] = sub ft 1.0
    fv:f32[10000] = div fu fl
    fw:f32[10000] = select_n fn fv fs
    fx:f32[10000] = mul 0.10000000149011612 fw
    fy:f32[10000] = mul fx fh
    fz:f32[10000] = add cm fy
    ga:f32[10000] = neg cn
    gb:f32[10000] = div ga 3.0
    gc:f32[10000] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(10000,)
      sharding=None
    ] 1.0
    gd:f32[10000] = div gc 3.0
    ge:f32[10000] = neg gd
    gf:f32[10000] = mul 0.10000000149011612 ge
    gg:f32[10000] = abs gf
    gh:bool[10000] = le gg 9.999999747378752e-06
    gi:f32[10000] = div gf 2.0
    gj:f32[10000] = add 1.0 gi
    gk:f32[10000] = mul gf gf
    gl:f32[10000] = div gk 6.0
    gm:f32[10000] = add gj gl
    gn:f32[10000] = exp gf
    go:f32[10000] = sub gn 1.0
    gp:f32[10000] = div go gf
    gq:f32[10000] = select_n gh gp gm
    gr:f32[10000] = mul 0.10000000149011612 gq
    gs:f32[10000] = mul gr gb
    gt:f32[10000] = add cn gs
    gu:f32[10000] = sub gt fz
    gv:f32[10000] = mul 0.37037035822868347 gu
    gw:f32[10000] = neg dh
    gx:f32[10000] = div gw 0.30000001192092896
    gy:f32[10000] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(10000,)
      sharding=None
    ] 1.0
    gz:f32[10000] = div gy 0.30000001192092896
    ha:f32[10000] = neg gz
    hb:f32[10000] = mul 0.10000000149011612 ha
    hc:f32[10000] = abs hb
    hd:bool[10000] = le hc 9.999999747378752e-06
    he:f32[10000] = div hb 2.0
    hf:f32[10000] = add 1.0 he
    hg:f32[10000] = mul hb hb
    hh:f32[10000] = div hg 6.0
    hi:f32[10000] = add hf hh
    hj:f32[10000] = exp hb
    hk:f32[10000] = sub hj 1.0
    hl:f32[10000] = div hk hb
    hm:f32[10000] = select_n hd hl hi
    hn:f32[10000] = mul 0.10000000149011612 hm
    ho:f32[10000] = mul hn gx
    hp:f32[10000] = add dh ho
    hq:f32[10000] = neg dj
    hr:f32[10000] = div hq 2.0
    hs:f32[10000] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(10000,)
      sharding=None
    ] 1.0
    ht:f32[10000] = div hs 2.0
    hu:f32[10000] = neg ht
    hv:f32[10000] = mul 0.10000000149011612 hu
    hw:f32[10000] = abs hv
    hx:bool[10000] = le hw 9.999999747378752e-06
    hy:f32[10000] = div hv 2.0
    hz:f32[10000] = add 1.0 hy
    ia:f32[10000] = mul hv hv
    ib:f32[10000] = div ia 6.0
    ic:f32[10000] = add hz ib
    id:f32[10000] = exp hv
    ie:f32[10000] = sub id 1.0
    if:f32[10000] = div ie hv
    ig:f32[10000] = select_n hx if ic
    ih:f32[10000] = mul 0.10000000149011612 ig
    ii:f32[10000] = mul ih hr
    ij:f32[10000] = add dj ii
    ik:f32[10000] = sub ij hp
    il:f32[10000] = mul 0.5882353186607361 ik
    im:f32[10000] = sub 0.0 be
    in:f32[10000] = mul ff im
    io:f32[10000] = add in 0.0
    ip:f32[10000] = sub -80.0 be
    iq:f32[10000] = mul gv ip
    ir:f32[10000] = add io iq
    is:f32[10000] = sub 0.0 be
    it:f32[10000] = mul cs is
    iu:f32[10000] = add ir it
    iv:f32[10000] = sub -80.0 be
    iw:f32[10000] = mul cw iv
    ix:f32[10000] = add iu iw
    iy:f32[10000] = add ix db
    iz:f32[10000] = sub 0.0 be
    ja:f32[10000] = mul il iz
    jb:f32[10000] = add iy ja
    jc:f32[10000] = sub be -60.625
    jd:f32[10000] = div jc 6.5625
    je:f32[10000] = exp jd
    jf:f32[10000] = mul 6.5625 je
    jg:f32[10000] = sub be -70.0
    jh:f32[10000] = neg jg
    ji:f32[10000] = add jh jf
    jj:f32[10000] = mul 1.0 jb
    jk:f32[10000] = add ji jj
    jl:f32[10000] = div jk 20.0
    jm:f32[10000] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(10000,)
      sharding=None
    ] 1.0
    jn:f32[10000] = div jm 20.0
    jo:f32[10000] = mul 6.5625 jn
    jp:f32[10000] = mul jo je
    jq:f32[10000] = div jp 6.5625
    jr:f32[10000] = neg jn
    js:f32[10000] = add_any jq jr
    jt:f32[10000] = mul 0.10000000149011612 js
    ju:f32[10000] = abs jt
    jv:bool[10000] = le ju 9.999999747378752e-06
    jw:f32[10000] = div jt 2.0
    jx:f32[10000] = add 1.0 jw
    jy:f32[10000] = mul jt jt
    jz:f32[10000] = div jy 6.0
    ka:f32[10000] = add jx jz
    kb:f32[10000] = exp jt
    kc:f32[10000] = sub kb 1.0
    kd:f32[10000] = div kc jt
    ke:f32[10000] = select_n jv kd ka
    kf:f32[10000] = mul 0.10000000149011612 ke
    kg:f32[10000] = mul kf jl
    kh:f32[10000] = add be kg
    ki:f32[10000] = add kh 0.0
    kj:f32[] = convert_element_type[new_dtype=float32 weak_type=False] ce
    kk:f32[10000] = sub kj bg
    kl:bool[10000] = le kk 5.0
    km:f32[10000] = pjit[
      name=_where
      jaxpr={ lambda ; kn:bool[10000] ko:f32[10000] kp:f32[10000]. let
          kq:f32[10000] = select_n kn kp ko
        in (kq,) }
    ] kl be ki
    kr:bool[10000] = ge km -40.0
    ks:f32[10000] = pjit[name=_where jaxpr=_where] kr -70.0 km
    kt:f32[10000] = pjit[name=_where jaxpr=_where] kr ce bg
    ku:f32[2500] = neg cj
    kv:f32[2500] = div ku 0.30000001192092896
    kw:f32[2500] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(2500,)
      sharding=None
    ] 1.0
    kx:f32[2500] = div kw 0.30000001192092896
    ky:f32[2500] = neg kx
    kz:f32[2500] = mul 0.10000000149011612 ky
    la:f32[2500] = abs kz
    lb:bool[2500] = le la 9.999999747378752e-06
    lc:f32[2500] = div kz 2.0
    ld:f32[2500] = add 1.0 lc
    le:f32[2500] = mul kz kz
    lf:f32[2500] = div le 6.0
    lg:f32[2500] = add ld lf
    lh:f32[2500] = exp kz
    li:f32[2500] = sub lh 1.0
    lj:f32[2500] = div li kz
    lk:f32[2500] = select_n lb lj lg
    ll:f32[2500] = mul 0.10000000149011612 lk
    lm:f32[2500] = mul ll kv
    ln:f32[2500] = add cj lm
    lo:f32[2500] = neg ck
    lp:f32[2500] = div lo 2.0
    lq:f32[2500] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(2500,)
      sharding=None
    ] 1.0
    lr:f32[2500] = div lq 2.0
    ls:f32[2500] = neg lr
    lt:f32[2500] = mul 0.10000000149011612 ls
    lu:f32[2500] = abs lt
    lv:bool[2500] = le lu 9.999999747378752e-06
    lw:f32[2500] = div lt 2.0
    lx:f32[2500] = add 1.0 lw
    ly:f32[2500] = mul lt lt
    lz:f32[2500] = div ly 6.0
    ma:f32[2500] = add lx lz
    mb:f32[2500] = exp lt
    mc:f32[2500] = sub mb 1.0
    md:f32[2500] = div mc lt
    me:f32[2500] = select_n lv md ma
    mf:f32[2500] = mul 0.10000000149011612 me
    mg:f32[2500] = mul mf lp
    mh:f32[2500] = add ck mg
    mi:f32[2500] = sub mh ln
    mj:f32[2500] = mul 0.5882353186607361 mi
    mk:f32[2500] = neg cp
    ml:f32[2500] = div mk 0.30000001192092896
    mm:f32[2500] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(2500,)
      sharding=None
    ] 1.0
    mn:f32[2500] = div mm 0.30000001192092896
    mo:f32[2500] = neg mn
    mp:f32[2500] = mul 0.10000000149011612 mo
    mq:f32[2500] = abs mp
    mr:bool[2500] = le mq 9.999999747378752e-06
    ms:f32[2500] = div mp 2.0
    mt:f32[2500] = add 1.0 ms
    mu:f32[2500] = mul mp mp
    mv:f32[2500] = div mu 6.0
    mw:f32[2500] = add mt mv
    mx:f32[2500] = exp mp
    my:f32[2500] = sub mx 1.0
    mz:f32[2500] = div my mp
    na:f32[2500] = select_n mr mz mw
    nb:f32[2500] = mul 0.10000000149011612 na
    nc:f32[2500] = mul nb ml
    nd:f32[2500] = add cp nc
    ne:f32[2500] = neg cq
    nf:f32[2500] = div ne 3.0
    ng:f32[2500] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(2500,)
      sharding=None
    ] 1.0
    nh:f32[2500] = div ng 3.0
    ni:f32[2500] = neg nh
    nj:f32[2500] = mul 0.10000000149011612 ni
    nk:f32[2500] = abs nj
    nl:bool[2500] = le nk 9.999999747378752e-06
    nm:f32[2500] = div nj 2.0
    nn:f32[2500] = add 1.0 nm
    no:f32[2500] = mul nj nj
    np:f32[2500] = div no 6.0
    nq:f32[2500] = add nn np
    nr:f32[2500] = exp nj
    ns:f32[2500] = sub nr 1.0
    nt:f32[2500] = div ns nj
    nu:f32[2500] = select_n nl nt nq
    nv:f32[2500] = mul 0.10000000149011612 nu
    nw:f32[2500] = mul nv nf
    nx:f32[2500] = add cq nw
    ny:f32[2500] = sub nx nd
    nz:f32[2500] = mul 0.37037035822868347 ny
    oa:f32[2500] = neg dn
    ob:f32[2500] = div oa 0.30000001192092896
    oc:f32[2500] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(2500,)
      sharding=None
    ] 1.0
    od:f32[2500] = div oc 0.30000001192092896
    oe:f32[2500] = neg od
    of:f32[2500] = mul 0.10000000149011612 oe
    og:f32[2500] = abs of
    oh:bool[2500] = le og 9.999999747378752e-06
    oi:f32[2500] = div of 2.0
    oj:f32[2500] = add 1.0 oi
    ok:f32[2500] = mul of of
    ol:f32[2500] = div ok 6.0
    om:f32[2500] = add oj ol
    on:f32[2500] = exp of
    oo:f32[2500] = sub on 1.0
    op:f32[2500] = div oo of
    oq:f32[2500] = select_n oh op om
    or:f32[2500] = mul 0.10000000149011612 oq
    os:f32[2500] = mul or ob
    ot:f32[2500] = add dn os
    ou:f32[2500] = neg dp
    ov:f32[2500] = div ou 2.0
    ow:f32[2500] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(2500,)
      sharding=None
    ] 1.0
    ox:f32[2500] = div ow 2.0
    oy:f32[2500] = neg ox
    oz:f32[2500] = mul 0.10000000149011612 oy
    pa:f32[2500] = abs oz
    pb:bool[2500] = le pa 9.999999747378752e-06
    pc:f32[2500] = div oz 2.0
    pd:f32[2500] = add 1.0 pc
    pe:f32[2500] = mul oz oz
    pf:f32[2500] = div pe 6.0
    pg:f32[2500] = add pd pf
    ph:f32[2500] = exp oz
    pi:f32[2500] = sub ph 1.0
    pj:f32[2500] = div pi oz
    pk:f32[2500] = select_n pb pj pg
    pl:f32[2500] = mul 0.10000000149011612 pk
    pm:f32[2500] = mul pl ov
    pn:f32[2500] = add dp pm
    po:f32[2500] = sub pn ot
    pp:f32[2500] = mul 0.5882353186607361 po
    pq:f32[2500] = sub 0.0 bb
    pr:f32[2500] = mul mj pq
    ps:f32[2500] = add pr 0.0
    pt:f32[2500] = sub -80.0 bb
    pu:f32[2500] = mul nz pt
    pv:f32[2500] = add ps pu
    pw:f32[2500] = sub 0.0 bb
    px:f32[2500] = mul cu pw
    py:f32[2500] = add pv px
    pz:f32[2500] = sub -80.0 bb
    qa:f32[2500] = mul cy pz
    qb:f32[2500] = add py qa
    qc:f32[2500] = add qb dd
    qd:f32[2500] = sub 0.0 bb
    qe:f32[2500] = mul pp qd
    qf:f32[2500] = add qc qe
    qg:f32[2500] = sub bb -60.625
    qh:f32[2500] = div qg 6.5625
    qi:f32[2500] = exp qh
    qj:f32[2500] = mul 6.5625 qi
    qk:f32[2500] = sub bb -70.0
    ql:f32[2500] = neg qk
    qm:f32[2500] = add ql qj
    qn:f32[2500] = mul 1.0 qf
    qo:f32[2500] = add qm qn
    qp:f32[2500] = div qo 20.0
    qq:f32[2500] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(2500,)
      sharding=None
    ] 1.0
    qr:f32[2500] = div qq 20.0
    qs:f32[2500] = mul 6.5625 qr
    qt:f32[2500] = mul qs qi
    qu:f32[2500] = div qt 6.5625
    qv:f32[2500] = neg qr
    qw:f32[2500] = add_any qu qv
    qx:f32[2500] = mul 0.10000000149011612 qw
    qy:f32[2500] = abs qx
    qz:bool[2500] = le qy 9.999999747378752e-06
    ra:f32[2500] = div qx 2.0
    rb:f32[2500] = add 1.0 ra
    rc:f32[2500] = mul qx qx
    rd:f32[2500] = div rc 6.0
    re:f32[2500] = add rb rd
    rf:f32[2500] = exp qx
    rg:f32[2500] = sub rf 1.0
    rh:f32[2500] = div rg qx
    ri:f32[2500] = select_n qz rh re
    rj:f32[2500] = mul 0.10000000149011612 ri
    rk:f32[2500] = mul rj qp
    rl:f32[2500] = add bb rk
    rm:f32[2500] = add rl 0.0
    rn:f32[] = convert_element_type[new_dtype=float32 weak_type=False] ce
    ro:f32[2500] = sub rn bd
    rp:bool[2500] = le ro 5.0
    rq:f32[2500] = pjit[
      name=_where
      jaxpr={ lambda ; rr:bool[2500] rs:f32[2500] rt:f32[2500]. let
          ru:f32[2500] = select_n rr rt rs
        in (ru,) }
    ] rp bb rm
    rv:bool[2500] = ge rq -40.0
    rw:f32[2500] = pjit[name=_where jaxpr=_where1] rv -70.0 rq
    rx:f32[2500] = pjit[name=_where jaxpr=_where1] rv ce bd
    ry:f32[10000] = sub 0.0 ks
    rz:f32[10000] = mul ff ry
    sa:f32[10000] = add rz 0.0
    sb:f32[10000] = sub 0.0 ks
    sc:f32[10000] = mul cs sb
    sd:f32[10000] = add sa sc
    se:f32[10000] = sub -80.0 ks
    sf:f32[10000] = mul gv se
    sg:f32[10000] = add sf 0.0
    sh:f32[10000] = sub -80.0 ks
    si:f32[10000] = mul cw sh
    sj:f32[10000] = add sg si
    sk:f32[2500] = sub 0.0 rw
    sl:f32[2500] = mul mj sk
    sm:f32[2500] = add sl 0.0
    sn:f32[2500] = sub 0.0 rw
    so:f32[2500] = mul cu sn
    sp:f32[2500] = add sm so
    sq:f32[2500] = sub -80.0 rw
    sr:f32[2500] = mul nz sq
    ss:f32[2500] = add sr 0.0
    st:f32[2500] = sub -80.0 rw
    su:f32[2500] = mul cy st
    sv:f32[2500] = add ss su
    sw:f32[10000] = add 0.0 db
    sx:f32[10000] = sub 0.0 ks
    sy:f32[10000] = mul il sx
    sz:f32[10000] = add sw sy
    ta:f32[2500] = add 0.0 dd
    tb:f32[2500] = sub 0.0 rw
    tc:f32[2500] = mul pp tb
    td:f32[2500] = add ta tc
    te:f32[10000] = sub fd ej
    tf:f32[10000] = mul te 0.5882353186607361
    tg:f32[2500] = sub mh ln
    th:f32[2500] = mul tg 0.5882353186607361
    ti:f32[10000] = sub gt fz
    tj:f32[10000] = mul ti 0.37037035822868347
    tk:f32[2500] = sub nx nd
    tl:f32[2500] = mul tk 0.37037035822868347
    debug_callback[
      callback=<function debug_callback.<locals>._flat_callback at 0x000001A9A3D22FC0>
      effect=Debug
    ]
  in (dd, cu, rw, kr, rx, ks, rv, kt, ln, cs, db, cw, bl, bm, cy, ej, fd, bq, mh,
    fz, gt, nd, nx, bw, hp, ij, ot, pn, tf, th, sd, sj, ks, sz, kr, sp, tj, tl, sv,
    rw, td, rv) }, ())

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "C:\Users\Vanilla\Desktop\03-Spatial_Embedded_SNN\code\simulate\run.py", line 722, in <module>
    snn_analyzer = SNN_analyzer(grid_num, grid_distance, E_conn_grid_num, I_conn_grid_num, wE, wI, sigmaE, E_inp_func, I_inp_func, time_period)
  File "C:\Users\Vanilla\Desktop\03-Spatial_Embedded_SNN\code\simulate\run.py", line 615, in __init__
    self.runner.predict(duration=time_period)
  File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\runners.py", line 484, in predict
    outputs, hists = self._predict(indices, *inputs, shared_args=shared_args)
  File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\runners.py", line 538, in _predict
    outs_and_mons = self._fun_predict(indices, *xs, shared_args=shared_args)
  File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\runners.py", line 666, in _fun_predict
    return bm.for_loop(self._step_func_predict,
  File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\math\object_transform\controls.py", line 891, in for_loop
    dyn_vals, out_vals = transform(operands)
  File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\math\object_transform\controls.py", line 736, in call
    return jax.lax.scan(f=fun2scan,
  File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\math\object_transform\controls.py", line 727, in fun2scan
    results = body_fun(*x, **unroll_kwargs)
  File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\runners.py", line 627, in _step_func_predict
    out = self.target(*x)
  File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\dynsys.py", line 421, in __call__
    ret = self.update(*args, **kwargs)
  File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\dynsys.py", line 370, in _compatible_update
    return update_fun(*args, **kwargs)
  File "C:\Users\Vanilla\Desktop\03-Spatial_Embedded_SNN\code\simulate\run.py", line 433, in update
    self.E2E()
  File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\dynsys.py", line 421, in __call__
    ret = self.update(*args, **kwargs)
  File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\dynsys.py", line 370, in _compatible_update
    return update_fun(*args, **kwargs)
  File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\dynsys.py", line 605, in update
    node.update(*args, **kwargs)
  File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\dynsys.py", line 370, in _compatible_update
    return update_fun(*args, **kwargs)
  File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\dyn\projections\align_post.py", line 273, in update
    current = self.comm(x)
  File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\dynsys.py", line 421, in __call__
    ret = self.update(*args, **kwargs)
  File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\dynsys.py", line 370, in _compatible_update
    return update_fun(*args, **kwargs)
  File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\dnn\linear.py", line 714, in update
    return bm.event.csrmv(self.weight, self.indices, self.indptr, x,
  File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\math\event\csr_matvec.py", line 68, in csrmv
    return bti.event_csrmv(data, indices, indptr, events, shape=shape, transpose=transpose)
  File "C:\Program Files\Python311\Lib\site-packages\braintaichi\_eventop\main.py", line 134, in event_csrmv
    return event_csrmv_taichi(data, indices, indptr, events, shape=shape, transpose=transpose)[0]
  File "C:\Program Files\Python311\Lib\site-packages\braintaichi\_eventop\csrmv.py", line 91, in event_csrmv_taichi
    return prim(
  File "C:\Program Files\Python311\Lib\site-packages\braintaichi\_primitive\_xla_custom_op.py", line 116, in __call__
    return self.primitive.bind(*ins, outs=outs, **kwargs)
jax._src.source_info_util.JaxStackTraceBeforeTransformation: RuntimeError: The CPU kernels do not build correctly. Please check the installation of braintaichi.

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "C:\Users\Vanilla\Desktop\03-Spatial_Embedded_SNN\code\simulate\run.py", line 722, in <module>
    snn_analyzer = SNN_analyzer(grid_num, grid_distance, E_conn_grid_num, I_conn_grid_num, wE, wI, sigmaE, E_inp_func, I_inp_func, time_period)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Vanilla\Desktop\03-Spatial_Embedded_SNN\code\simulate\run.py", line 615, in __init__
    self.runner.predict(duration=time_period)
  File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\runners.py", line 484, in predict
    outputs, hists = self._predict(indices, *inputs, shared_args=shared_args)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\runners.py", line 538, in _predict
    outs_and_mons = self._fun_predict(indices, *xs, shared_args=shared_args)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\runners.py", line 666, in _fun_predict
    return bm.for_loop(self._step_func_predict,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\math\object_transform\controls.py", line 891, in for_loop
    dyn_vals, out_vals = transform(operands)
                         ^^^^^^^^^^^^^^^^^^^
  File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\math\object_transform\controls.py", line 736, in call
    return jax.lax.scan(f=fun2scan,
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Program Files\Python311\Lib\site-packages\braintaichi\_primitive\_mlir_translation_rule.py", line 441, in _taichi_mlir_cpu_translation_rule
    raise RuntimeError(
RuntimeError: The CPU kernels do not build correctly. Please check the installation of braintaichi.
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
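
A note on reproducing this: the traceback bottoms out in braintaichi's event CSR matrix-vector product, so the failure can presumably be triggered without the full network. A minimal sketch, assuming the same bm.event.csrmv entry point shown in the traceback (the array values are arbitrary placeholders):

import brainpy.math as bm

bm.set_platform('cpu')

# A 2 x 2 CSR matrix with a single nonzero at (0, 1), plus a boolean event
# vector; this exercises the same bm.event.csrmv call path that fails above.
data = bm.asarray([1.0])
indices = bm.asarray([1])
indptr = bm.asarray([0, 1, 1])
events = bm.asarray([True, False])

out = bm.event.csrmv(data, indices, indptr, events, shape=(2, 2), transpose=True)
print(out)  # raises the same RuntimeError if the braintaichi CPU kernels fail to build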

Packages
Package                   Version
------------------------- ------------------
albumentations            1.3.1
altgraph                  0.17.4
asttokens                 2.4.1
brainpy                   2.6.0.post20241205
brainstate                0.1.0.post20241210
braintaichi               0.0.3
brainunit                 0.0.3.post20241211
certifi                   2023.7.22
charset-normalizer        3.3.0
colorama                  0.4.6
colorlog                  6.8.2
comm                      0.1.4
contourpy                 1.1.1
cycler                    0.12.1
debugpy                   1.8.0
decorator                 5.1.1
dictdiffer                0.9.0
dill                      0.3.9
einops                    0.7.0
entmax                    1.1
et-xmlfile                1.1.0
executing                 2.0.1
filelock                  3.12.4
fonttools                 4.43.1
fsspec                    2023.9.2
huggingface-hub           0.17.3
idna                      3.4
imageio                   2.31.5
ipykernel                 6.26.0
ipython                   8.17.1
ipywidgets                8.1.5
jax                       0.4.37
jaxlib                    0.4.36
jedi                      0.19.1
Jinja2                    3.1.2
joblib                    1.3.2
jupyter_client            8.5.0
jupyter_core              5.5.0
jupyterlab_widgets        3.0.13
kiwisolver                1.4.5
lazy_loader               0.3
llvmlite                  0.43.0
lxml                      4.9.3
markdown-it-py            3.0.0
MarkupSafe                2.1.3
matplotlib                3.8.0
matplotlib-inline         0.1.6
mdurl                     0.1.2
ml_dtypes                 0.5.0
mpmath                    1.3.0
munch                     4.0.0
nest-asyncio              1.5.8
networkx                  3.1
numba                     0.60.0
numpy                     1.24.4
opencv-python-headless    4.10.0.84
openpyxl                  3.1.2
opt_einsum                3.4.0
packaging                 24.0
pandas                    2.1.1
parse                     1.20.2
parso                     0.8.3
pefile                    2023.2.7
Pillow                    10.0.1
pip                       24.3.1
pix2tex                   0.1.2
platformdirs              3.11.0
prompt-toolkit            3.0.39
psutil                    5.9.8
pure-eval                 0.2.2
pycocotools               2.0.8
pycryptodome              3.20.0
Pygments                  2.18.0
pyinstaller               6.6.0
pyinstaller-hooks-contrib 2024.6
Pymem                     1.13.1
pynput                    1.7.6
pyparsing                 3.1.1
PyQt6                     6.7.0
PyQt6-Qt6                 6.7.0
PyQt6-sip                 13.6.0
PyQt6-WebEngine           6.5.0
PyQt6-WebEngine-Qt6       6.5.3
pyreadline3               3.4.1
PySide6                   6.5.3
PySide6-Addons            6.5.3
PySide6-Essentials        6.5.3
pystache                  0.6.5
python-dateutil           2.8.2
pytz                      2023.3.post1
pywin32                   306
pywin32-ctypes            0.2.2
PyYAML                    6.0.2
pyzmq                     25.1.1
qudida                    0.0.4
regex                     2023.10.3
requests                  2.31.0
resolvelib                1.0.1
rich                      13.9.4
ruamel.yaml               0.18.6
ruamel.yaml.clib          0.2.8
safetensors               0.4.0
scikit-image              0.22.0
scikit-learn              1.3.1
scipy                     1.14.1
screeninfo                0.8.1
setuptools                65.5.0
shiboken6                 6.5.3
six                       1.16.0
stack-data                0.6.3
sympy                     1.12
taichi                    1.7.2
threadpoolctl             3.2.0
tifffile                  2023.9.26
timm                      0.5.4
tokenizers                0.14.1
torch                     2.1.0
torchaudio                2.1.0
torchvision               0.16.0
tornado                   6.3.3
tqdm                      4.67.1
traitlets                 5.13.0
transformers              4.34.0
typing_extensions         4.12.2
tzdata                    2023.3
urllib3                   2.0.6
watchdog                  4.0.0
wcwidth                   0.2.9
widgetsnbextension        4.0.13
x-transformers            0.15.0

@LizhengyuSJTU LizhengyuSJTU added the bug Something isn't working label Dec 12, 2024
@chaoming0625
Collaborator

It seems your jax version is too low?

@Routhleck
Member

This error was caused by:

if cpu_ops is None:
    raise RuntimeError(
        'The CPU kernels do not build correctly. '
        'Please check the installation of braintaichi.'
    )

That means the related binaries of braintaichi were not loaded successfully. How did you install braintaichi?
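
For context, cpu_ops presumably comes from an import guard like the following (an illustrative sketch, not the actual braintaichi source; the module path braintaichi.cpu_ops is an assumption):

# Hypothetical loading pattern: the package tries to import its compiled
# CPU extension and falls back to None when that import fails, which is
# what later triggers the RuntimeError seen in the traceback.
try:
    import braintaichi.cpu_ops as cpu_ops  # hypothetical module path
except ImportError:
    cpu_ops = None  # the check above then raises the RuntimeError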

@chaoming0625
Collaborator

See the new version released: #50

@LizhengyuSJTU
Author

Via pip: pip install brain[cpu].

@LizhengyuSJTU
Author

LizhengyuSJTU commented Dec 13, 2024

Sad

KeyError: (None, let _where = { lambda ; a:bool[10000] b:f32[] c:f32[10000]. let
    d:f32[] = convert_element_type[new_dtype=float32 weak_type=False] b
    e:f32[10000] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(10000,)
      sharding=None
    ] d
    f:f32[10000] = select_n a c e
  in (f,) } in
let _where1 = { lambda ; g:bool[2500] h:f32[] i:f32[2500]. let
    j:f32[] = convert_element_type[new_dtype=float32 weak_type=False] h
    k:f32[2500] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(2500,)
      sharding=None
    ] j
    l:f32[2500] = select_n g i k
  in (l,) } in
{ lambda ; m:f32[40530000] n:i32[40530000] o:i32[10001] p:f32[10132500] q:i32[10132500]
    r:i32[10001] s:f32[1102500] t:i32[1102500] u:i32[2501] v:f32[282500] w:i32[282500]
    x:i32[2501] y:f32[10000] z:f32[10000] ba:bool[10000] bb:f32[1] bc:bool[2500]
    bd:f32[10000] be:f32[10000] bf:f32[10000] bg:f32[2500] bh:f32[2500] bi:f32[10000]
    bj:f32[2500] bk:f32[2500] bl:f32[1] bm:f32[2500] bn:f32[10000] bo:f32[2500] bp:f32[10000]
    bq:f32[10000] br:f32[2500] bs:f32[2500] bt:f32[10000] bu:f32[10000] bv:f32[2500]
    bw:f32[10000] bx:f32[2500] by:bool[10000] bz:f32[2500] ca:bool[2500] cb:i32[]. let
    cc:f32[] = convert_element_type[new_dtype=float32 weak_type=True] cb
    cd:f32[] = mul cc 0.1
    ce:f32[] = add 0.0 cd
    cf:f32[10000] = braintaichi_custom_op_43[
      float_as_event=True
      outs=(ShapedArray(float32[10000]),)
      shape=(10000, 10000)
      transpose=True
    ] m n o ba
    cg:f32[10000] = add be cf
    ch:f32[10000] = add bd cf
    ci:f32[2500] = braintaichi_custom_op_43[
      float_as_event=True
      outs=(ShapedArray(float32[2500]),)
      shape=(10000, 2500)
      transpose=True
    ] p q r ba
    cj:f32[2500] = add bg ci
    ck:f32[2500] = add bh ci
    cl:f32[10000] = braintaichi_custom_op_43[
      float_as_event=True
      outs=(ShapedArray(float32[10000]),)
      shape=(2500, 10000)
      transpose=True
    ] s t u ca
    cm:f32[10000] = add bt cl
    cn:f32[10000] = add bi cl
    co:f32[2500] = braintaichi_custom_op_43[
      float_as_event=True
      outs=(ShapedArray(float32[2500]),)
      shape=(2500, 2500)
      transpose=True
    ] v w x ca
    cp:f32[2500] = add bj co
    cq:f32[2500] = add bk co
    cr:f32[10000] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(10000,)
      sharding=None
    ] 1.0
    cs:f32[10000] = mul cr 0.019999999552965164
    ct:f32[2500] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(2500,)
      sharding=None
    ] 1.0
    cu:f32[2500] = mul ct 0.019999999552965164
    cv:f32[10000] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(10000,)
      sharding=None
    ] 1.0
    cw:f32[10000] = mul cv 0.019999999552965164
    cx:f32[2500] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(2500,)
      sharding=None
    ] 1.0
    cy:f32[2500] = mul cx 0.019999999552965164
    cz:f32[10000] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(10000,)
      sharding=None
    ] 1.0
    da:f32[10000] = device_put[
      copy_semantics=[<CopySemantics.ALIAS: 1>]
      devices=[None]
      srcs=[None]
    ] y
    db:f32[10000] = mul cz da
    dc:f32[2500] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(2500,)
      sharding=None
    ] 1.0
    dd:f32[2500] = mul dc 0.0
    de:i32[10000] = convert_element_type[new_dtype=int32 weak_type=True] by
    df:i32[10000] = mul de 10000
    dg:f32[10000] = convert_element_type[new_dtype=float32 weak_type=False] df
    dh:f32[10000] = add bp dg
    di:f32[10000] = convert_element_type[new_dtype=float32 weak_type=False] df
    dj:f32[10000] = add bq di
    dk:i32[2500] = convert_element_type[new_dtype=int32 weak_type=True] bc
    dl:i32[2500] = mul dk 2500
    dm:f32[2500] = convert_element_type[new_dtype=float32 weak_type=False] dl
    dn:f32[2500] = add br dm
    do:f32[2500] = convert_element_type[new_dtype=float32 weak_type=False] dl
    dp:f32[2500] = add bs do
    dq:f32[10000] = neg cg
    dr:f32[10000] = div dq 0.30000001192092896
    ds:f32[10000] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(10000,)
      sharding=None
    ] 1.0
    dt:f32[10000] = div ds 0.30000001192092896
    du:f32[10000] = neg dt
    dv:f32[10000] = mul 0.10000000149011612 du
    dw:f32[10000] = abs dv
    dx:bool[10000] = le dw 9.999999747378752e-06
    dy:f32[10000] = div dv 2.0
    dz:f32[10000] = add 1.0 dy
    ea:f32[10000] = mul dv dv
    eb:f32[10000] = div ea 6.0
    ec:f32[10000] = add dz eb
    ed:f32[10000] = exp dv
    ee:f32[10000] = sub ed 1.0
    ef:f32[10000] = div ee dv
    eg:f32[10000] = select_n dx ef ec
    eh:f32[10000] = mul 0.10000000149011612 eg
    ei:f32[10000] = mul eh dr
    ej:f32[10000] = add cg ei
    ek:f32[10000] = neg ch
    el:f32[10000] = div ek 2.0
    em:f32[10000] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(10000,)
      sharding=None
    ] 1.0
    en:f32[10000] = div em 2.0
    eo:f32[10000] = neg en
    ep:f32[10000] = mul 0.10000000149011612 eo
    eq:f32[10000] = abs ep
    er:bool[10000] = le eq 9.999999747378752e-06
    es:f32[10000] = div ep 2.0
    et:f32[10000] = add 1.0 es
    eu:f32[10000] = mul ep ep
    ev:f32[10000] = div eu 6.0
    ew:f32[10000] = add et ev
    ex:f32[10000] = exp ep
    ey:f32[10000] = sub ex 1.0
    ez:f32[10000] = div ey ep
    fa:f32[10000] = select_n er ez ew
    fb:f32[10000] = mul 0.10000000149011612 fa
    fc:f32[10000] = mul fb el
    fd:f32[10000] = add ch fc
    fe:f32[10000] = sub fd ej
    ff:f32[10000] = mul 0.5882353186607361 fe
    fg:f32[10000] = neg cm
    fh:f32[10000] = div fg 0.30000001192092896
    fi:f32[10000] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(10000,)
      sharding=None
    ] 1.0
    fj:f32[10000] = div fi 0.30000001192092896
    fk:f32[10000] = neg fj
    fl:f32[10000] = mul 0.10000000149011612 fk
    fm:f32[10000] = abs fl
    fn:bool[10000] = le fm 9.999999747378752e-06
    fo:f32[10000] = div fl 2.0
    fp:f32[10000] = add 1.0 fo
    fq:f32[10000] = mul fl fl
    fr:f32[10000] = div fq 6.0
    fs:f32[10000] = add fp fr
    ft:f32[10000] = exp fl
    fu:f32[10000] = sub ft 1.0
    fv:f32[10000] = div fu fl
    fw:f32[10000] = select_n fn fv fs
    fx:f32[10000] = mul 0.10000000149011612 fw
    fy:f32[10000] = mul fx fh
    fz:f32[10000] = add cm fy
    ga:f32[10000] = neg cn
    gb:f32[10000] = div ga 3.0
    gc:f32[10000] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(10000,)
      sharding=None
    ] 1.0
    gd:f32[10000] = div gc 3.0
    ge:f32[10000] = neg gd
    gf:f32[10000] = mul 0.10000000149011612 ge
    gg:f32[10000] = abs gf
    gh:bool[10000] = le gg 9.999999747378752e-06
    gi:f32[10000] = div gf 2.0
    gj:f32[10000] = add 1.0 gi
    gk:f32[10000] = mul gf gf
    gl:f32[10000] = div gk 6.0
    gm:f32[10000] = add gj gl
    gn:f32[10000] = exp gf
    go:f32[10000] = sub gn 1.0
    gp:f32[10000] = div go gf
    gq:f32[10000] = select_n gh gp gm
    gr:f32[10000] = mul 0.10000000149011612 gq
    gs:f32[10000] = mul gr gb
    gt:f32[10000] = add cn gs
    gu:f32[10000] = sub gt fz
    gv:f32[10000] = mul 0.37037035822868347 gu
    gw:f32[10000] = neg dh
    gx:f32[10000] = div gw 0.30000001192092896
    gy:f32[10000] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(10000,)
      sharding=None
    ] 1.0
    gz:f32[10000] = div gy 0.30000001192092896
    ha:f32[10000] = neg gz
    hb:f32[10000] = mul 0.10000000149011612 ha
    hc:f32[10000] = abs hb
    hd:bool[10000] = le hc 9.999999747378752e-06
    he:f32[10000] = div hb 2.0
    hf:f32[10000] = add 1.0 he
    hg:f32[10000] = mul hb hb
    hh:f32[10000] = div hg 6.0
    hi:f32[10000] = add hf hh
    hj:f32[10000] = exp hb
    hk:f32[10000] = sub hj 1.0
    hl:f32[10000] = div hk hb
    hm:f32[10000] = select_n hd hl hi
    hn:f32[10000] = mul 0.10000000149011612 hm
    ho:f32[10000] = mul hn gx
    hp:f32[10000] = add dh ho
    hq:f32[10000] = neg dj
    hr:f32[10000] = div hq 2.0
    hs:f32[10000] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(10000,)
      sharding=None
    ] 1.0
    ht:f32[10000] = div hs 2.0
    hu:f32[10000] = neg ht
    hv:f32[10000] = mul 0.10000000149011612 hu
    hw:f32[10000] = abs hv
    hx:bool[10000] = le hw 9.999999747378752e-06
    hy:f32[10000] = div hv 2.0
    hz:f32[10000] = add 1.0 hy
    ia:f32[10000] = mul hv hv
    ib:f32[10000] = div ia 6.0
    ic:f32[10000] = add hz ib
    id:f32[10000] = exp hv
    ie:f32[10000] = sub id 1.0
    if:f32[10000] = div ie hv
    ig:f32[10000] = select_n hx if ic
    ih:f32[10000] = mul 0.10000000149011612 ig
    ii:f32[10000] = mul ih hr
    ij:f32[10000] = add dj ii
    ik:f32[10000] = sub ij hp
    il:f32[10000] = mul 0.5882353186607361 ik
    im:f32[10000] = sub 0.0 z
    in:f32[10000] = mul ff im
    io:f32[10000] = add in 0.0
    ip:f32[10000] = sub -80.0 z
    iq:f32[10000] = mul gv ip
    ir:f32[10000] = add io iq
    is:f32[10000] = sub 0.0 z
    it:f32[10000] = mul cs is
    iu:f32[10000] = add ir it
    iv:f32[10000] = sub -80.0 z
    iw:f32[10000] = mul cw iv
    ix:f32[10000] = add iu iw
    iy:f32[10000] = add ix db
    iz:f32[10000] = sub 0.0 z
    ja:f32[10000] = mul il iz
    jb:f32[10000] = add iy ja
    jc:f32[10000] = sub z -60.625
    jd:f32[10000] = div jc 6.5625
    je:f32[10000] = exp jd
    jf:f32[10000] = mul 6.5625 je
    jg:f32[10000] = sub z -70.0
    jh:f32[10000] = neg jg
    ji:f32[10000] = add jh jf
    jj:f32[10000] = mul 1.0 jb
    jk:f32[10000] = add ji jj
    jl:f32[10000] = div jk 20.0
    jm:f32[10000] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(10000,)
      sharding=None
    ] 1.0
    jn:f32[10000] = div jm 20.0
    jo:f32[10000] = mul 6.5625 jn
    jp:f32[10000] = mul jo je
    jq:f32[10000] = div jp 6.5625
    jr:f32[10000] = neg jn
    js:f32[10000] = add_any jq jr
    jt:f32[10000] = mul 0.10000000149011612 js
    ju:f32[10000] = abs jt
    jv:bool[10000] = le ju 9.999999747378752e-06
    jw:f32[10000] = div jt 2.0
    jx:f32[10000] = add 1.0 jw
    jy:f32[10000] = mul jt jt
    jz:f32[10000] = div jy 6.0
    ka:f32[10000] = add jx jz
    kb:f32[10000] = exp jt
    kc:f32[10000] = sub kb 1.0
    kd:f32[10000] = div kc jt
    ke:f32[10000] = select_n jv kd ka
    kf:f32[10000] = mul 0.10000000149011612 ke
    kg:f32[10000] = mul kf jl
    kh:f32[10000] = add z kg
    ki:f32[10000] = add kh 0.0
    kj:f32[] = convert_element_type[new_dtype=float32 weak_type=False] ce
    kk:f32[10000] = sub kj bw
    kl:bool[10000] = le kk 5.0
    km:f32[10000] = pjit[
      name=_where
      jaxpr={ lambda ; kn:bool[10000] ko:f32[10000] kp:f32[10000]. let
          kq:f32[10000] = select_n kn kp ko
        in (kq,) }
    ] kl z ki
    kr:bool[10000] = ge km -40.0
    ks:f32[10000] = pjit[name=_where jaxpr=_where] kr -70.0 km
    kt:f32[10000] = pjit[name=_where jaxpr=_where] kr ce bw
    ku:f32[2500] = neg cj
    kv:f32[2500] = div ku 0.30000001192092896
    kw:f32[2500] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(2500,)
      sharding=None
    ] 1.0
    kx:f32[2500] = div kw 0.30000001192092896
    ky:f32[2500] = neg kx
    kz:f32[2500] = mul 0.10000000149011612 ky
    la:f32[2500] = abs kz
    lb:bool[2500] = le la 9.999999747378752e-06
    lc:f32[2500] = div kz 2.0
    ld:f32[2500] = add 1.0 lc
    le:f32[2500] = mul kz kz
    lf:f32[2500] = div le 6.0
    lg:f32[2500] = add ld lf
    lh:f32[2500] = exp kz
    li:f32[2500] = sub lh 1.0
    lj:f32[2500] = div li kz
    lk:f32[2500] = select_n lb lj lg
    ll:f32[2500] = mul 0.10000000149011612 lk
    lm:f32[2500] = mul ll kv
    ln:f32[2500] = add cj lm
    lo:f32[2500] = neg ck
    lp:f32[2500] = div lo 2.0
    lq:f32[2500] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(2500,)
      sharding=None
    ] 1.0
    lr:f32[2500] = div lq 2.0
    ls:f32[2500] = neg lr
    lt:f32[2500] = mul 0.10000000149011612 ls
    lu:f32[2500] = abs lt
    lv:bool[2500] = le lu 9.999999747378752e-06
    lw:f32[2500] = div lt 2.0
    lx:f32[2500] = add 1.0 lw
    ly:f32[2500] = mul lt lt
    lz:f32[2500] = div ly 6.0
    ma:f32[2500] = add lx lz
    mb:f32[2500] = exp lt
    mc:f32[2500] = sub mb 1.0
    md:f32[2500] = div mc lt
    me:f32[2500] = select_n lv md ma
    mf:f32[2500] = mul 0.10000000149011612 me
    mg:f32[2500] = mul mf lp
    mh:f32[2500] = add ck mg
    mi:f32[2500] = sub mh ln
    mj:f32[2500] = mul 0.5882353186607361 mi
    mk:f32[2500] = neg cp
    ml:f32[2500] = div mk 0.30000001192092896
    mm:f32[2500] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(2500,)
      sharding=None
    ] 1.0
    mn:f32[2500] = div mm 0.30000001192092896
    mo:f32[2500] = neg mn
    mp:f32[2500] = mul 0.10000000149011612 mo
    mq:f32[2500] = abs mp
    mr:bool[2500] = le mq 9.999999747378752e-06
    ms:f32[2500] = div mp 2.0
    mt:f32[2500] = add 1.0 ms
    mu:f32[2500] = mul mp mp
    mv:f32[2500] = div mu 6.0
    mw:f32[2500] = add mt mv
    mx:f32[2500] = exp mp
    my:f32[2500] = sub mx 1.0
    mz:f32[2500] = div my mp
    na:f32[2500] = select_n mr mz mw
    nb:f32[2500] = mul 0.10000000149011612 na
    nc:f32[2500] = mul nb ml
    nd:f32[2500] = add cp nc
    ne:f32[2500] = neg cq
    nf:f32[2500] = div ne 3.0
    ng:f32[2500] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(2500,)
      sharding=None
    ] 1.0
    nh:f32[2500] = div ng 3.0
    ni:f32[2500] = neg nh
    nj:f32[2500] = mul 0.10000000149011612 ni
    nk:f32[2500] = abs nj
    nl:bool[2500] = le nk 9.999999747378752e-06
    nm:f32[2500] = div nj 2.0
    nn:f32[2500] = add 1.0 nm
    no:f32[2500] = mul nj nj
    np:f32[2500] = div no 6.0
    nq:f32[2500] = add nn np
    nr:f32[2500] = exp nj
    ns:f32[2500] = sub nr 1.0
    nt:f32[2500] = div ns nj
    nu:f32[2500] = select_n nl nt nq
    nv:f32[2500] = mul 0.10000000149011612 nu
    nw:f32[2500] = mul nv nf
    nx:f32[2500] = add cq nw
    ny:f32[2500] = sub nx nd
    nz:f32[2500] = mul 0.37037035822868347 ny
    oa:f32[2500] = neg dn
    ob:f32[2500] = div oa 0.30000001192092896
    oc:f32[2500] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(2500,)
      sharding=None
    ] 1.0
    od:f32[2500] = div oc 0.30000001192092896
    oe:f32[2500] = neg od
    of:f32[2500] = mul 0.10000000149011612 oe
    og:f32[2500] = abs of
    oh:bool[2500] = le og 9.999999747378752e-06
    oi:f32[2500] = div of 2.0
    oj:f32[2500] = add 1.0 oi
    ok:f32[2500] = mul of of
    ol:f32[2500] = div ok 6.0
    om:f32[2500] = add oj ol
    on:f32[2500] = exp of
    oo:f32[2500] = sub on 1.0
    op:f32[2500] = div oo of
    oq:f32[2500] = select_n oh op om
    or:f32[2500] = mul 0.10000000149011612 oq
    os:f32[2500] = mul or ob
    ot:f32[2500] = add dn os
    ou:f32[2500] = neg dp
    ov:f32[2500] = div ou 2.0
    ow:f32[2500] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(2500,)
      sharding=None
    ] 1.0
    ox:f32[2500] = div ow 2.0
    oy:f32[2500] = neg ox
    oz:f32[2500] = mul 0.10000000149011612 oy
    pa:f32[2500] = abs oz
    pb:bool[2500] = le pa 9.999999747378752e-06
    pc:f32[2500] = div oz 2.0
    pd:f32[2500] = add 1.0 pc
    pe:f32[2500] = mul oz oz
    pf:f32[2500] = div pe 6.0
    pg:f32[2500] = add pd pf
    ph:f32[2500] = exp oz
    pi:f32[2500] = sub ph 1.0
    pj:f32[2500] = div pi oz
    pk:f32[2500] = select_n pb pj pg
    pl:f32[2500] = mul 0.10000000149011612 pk
    pm:f32[2500] = mul pl ov
    pn:f32[2500] = add dp pm
    po:f32[2500] = sub pn ot
    pp:f32[2500] = mul 0.5882353186607361 po
    pq:f32[2500] = sub 0.0 bz
    pr:f32[2500] = mul mj pq
    ps:f32[2500] = add pr 0.0
    pt:f32[2500] = sub -80.0 bz
    pu:f32[2500] = mul nz pt
    pv:f32[2500] = add ps pu
    pw:f32[2500] = sub 0.0 bz
    px:f32[2500] = mul cu pw
    py:f32[2500] = add pv px
    pz:f32[2500] = sub -80.0 bz
    qa:f32[2500] = mul cy pz
    qb:f32[2500] = add py qa
    qc:f32[2500] = add qb dd
    qd:f32[2500] = sub 0.0 bz
    qe:f32[2500] = mul pp qd
    qf:f32[2500] = add qc qe
    qg:f32[2500] = sub bz -60.625
    qh:f32[2500] = div qg 6.5625
    qi:f32[2500] = exp qh
    qj:f32[2500] = mul 6.5625 qi
    qk:f32[2500] = sub bz -70.0
    ql:f32[2500] = neg qk
    qm:f32[2500] = add ql qj
    qn:f32[2500] = mul 1.0 qf
    qo:f32[2500] = add qm qn
    qp:f32[2500] = div qo 20.0
    qq:f32[2500] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(2500,)
      sharding=None
    ] 1.0
    qr:f32[2500] = div qq 20.0
    qs:f32[2500] = mul 6.5625 qr
    qt:f32[2500] = mul qs qi
    qu:f32[2500] = div qt 6.5625
    qv:f32[2500] = neg qr
    qw:f32[2500] = add_any qu qv
    qx:f32[2500] = mul 0.10000000149011612 qw
    qy:f32[2500] = abs qx
    qz:bool[2500] = le qy 9.999999747378752e-06
    ra:f32[2500] = div qx 2.0
    rb:f32[2500] = add 1.0 ra
    rc:f32[2500] = mul qx qx
    rd:f32[2500] = div rc 6.0
    re:f32[2500] = add rb rd
    rf:f32[2500] = exp qx
    rg:f32[2500] = sub rf 1.0
    rh:f32[2500] = div rg qx
    ri:f32[2500] = select_n qz rh re
    rj:f32[2500] = mul 0.10000000149011612 ri
    rk:f32[2500] = mul rj qp
    rl:f32[2500] = add bz rk
    rm:f32[2500] = add rl 0.0
    rn:f32[] = convert_element_type[new_dtype=float32 weak_type=False] ce
    ro:f32[2500] = sub rn bx
    rp:bool[2500] = le ro 5.0
    rq:f32[2500] = pjit[
      name=_where
      jaxpr={ lambda ; rr:bool[2500] rs:f32[2500] rt:f32[2500]. let
          ru:f32[2500] = select_n rr rt rs
        in (ru,) }
    ] rp bz rm
    rv:bool[2500] = ge rq -40.0
    rw:f32[2500] = pjit[name=_where jaxpr=_where1] rv -70.0 rq
    rx:f32[2500] = pjit[name=_where jaxpr=_where1] rv ce bx
    ry:f32[10000] = sub 0.0 ks
    rz:f32[10000] = mul ff ry
    sa:f32[10000] = add rz 0.0
    sb:f32[10000] = sub 0.0 ks
    sc:f32[10000] = mul cs sb
    sd:f32[10000] = add sa sc
    se:f32[10000] = sub -80.0 ks
    sf:f32[10000] = mul gv se
    sg:f32[10000] = add sf 0.0
    sh:f32[10000] = sub -80.0 ks
    si:f32[10000] = mul cw sh
    sj:f32[10000] = add sg si
    sk:f32[2500] = sub 0.0 rw
    sl:f32[2500] = mul mj sk
    sm:f32[2500] = add sl 0.0
    sn:f32[2500] = sub 0.0 rw
    so:f32[2500] = mul cu sn
    sp:f32[2500] = add sm so
    sq:f32[2500] = sub -80.0 rw
    sr:f32[2500] = mul nz sq
    ss:f32[2500] = add sr 0.0
    st:f32[2500] = sub -80.0 rw
    su:f32[2500] = mul cy st
    sv:f32[2500] = add ss su
    sw:f32[10000] = add 0.0 db
    sx:f32[10000] = sub 0.0 ks
    sy:f32[10000] = mul il sx
    sz:f32[10000] = add sw sy
    ta:f32[2500] = add 0.0 dd
    tb:f32[2500] = sub 0.0 rw
    tc:f32[2500] = mul pp tb
    td:f32[2500] = add ta tc
    te:f32[10000] = sub fd ej
    tf:f32[10000] = mul te 0.5882353186607361
    tg:f32[2500] = sub mh ln
    th:f32[2500] = mul tg 0.5882353186607361
    ti:f32[10000] = sub gt fz
    tj:f32[10000] = mul ti 0.37037035822868347
    tk:f32[2500] = sub nx nd
    tl:f32[2500] = mul tk 0.37037035822868347
    debug_callback[
      callback=<function debug_callback.<locals>._flat_callback at 0x000001B7A602B7E0>
      effect=Debug
    ]
  in (ks, kr, bb, bc, fd, ej, cs, ln, mh, gt, nd, nx, bl, cy, db, dd, hp, ij, ot,
    pn, fz, cw, cu, kt, rx, by, rw, rv, tf, th, sd, sj, ks, sz, kr, sp, tj, tl, sv,
    rw, td, rv) }, ())

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "C:\Users\Vanilla\Desktop\03-Spatial_Embedded_SNN\code\simulate\run.py", line 722, in <module>
    snn_analyzer = SNN_analyzer(grid_num, grid_distance, E_conn_grid_num, I_conn_grid_num, wE, wI, sigmaE, E_inp_func, I_inp_func, time_period)
  File "C:\Users\Vanilla\Desktop\03-Spatial_Embedded_SNN\code\simulate\run.py", line 615, in __init__
    self.runner.predict(duration=time_period)
  File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\runners.py", line 484, in predict
    outputs, hists = self._predict(indices, *inputs, shared_args=shared_args)
  File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\runners.py", line 538, in _predict
    outs_and_mons = self._fun_predict(indices, *xs, shared_args=shared_args)
  File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\runners.py", line 666, in _fun_predict
    return bm.for_loop(self._step_func_predict,
  File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\math\object_transform\controls.py", line 891, in for_loop
    dyn_vals, out_vals = transform(operands)
  File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\math\object_transform\controls.py", line 736, in call
    return jax.lax.scan(f=fun2scan,
  File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\math\object_transform\controls.py", line 727, in fun2scan
    results = body_fun(*x, **unroll_kwargs)
  File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\runners.py", line 627, in _step_func_predict
    out = self.target(*x)
  File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\dynsys.py", line 421, in __call__
    ret = self.update(*args, **kwargs)
  File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\dynsys.py", line 370, in _compatible_update
    return update_fun(*args, **kwargs)
  File "C:\Users\Vanilla\Desktop\03-Spatial_Embedded_SNN\code\simulate\run.py", line 433, in update
    self.E2E()
  File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\dynsys.py", line 421, in __call__
    ret = self.update(*args, **kwargs)
  File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\dynsys.py", line 370, in _compatible_update
    return update_fun(*args, **kwargs)
  File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\dynsys.py", line 605, in update
    node.update(*args, **kwargs)
  File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\dynsys.py", line 370, in _compatible_update
    return update_fun(*args, **kwargs)
  File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\dyn\projections\align_post.py", line 273, in update
    current = self.comm(x)
  File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\dynsys.py", line 421, in __call__
    ret = self.update(*args, **kwargs)
  File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\dynsys.py", line 370, in _compatible_update
    return update_fun(*args, **kwargs)
  File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\dnn\linear.py", line 714, in update
    return bm.event.csrmv(self.weight, self.indices, self.indptr, x,
  File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\math\event\csr_matvec.py", line 68, in csrmv
    return bti.event_csrmv(data, indices, indptr, events, shape=shape, transpose=transpose)
  File "C:\Program Files\Python311\Lib\site-packages\braintaichi\_eventop\main.py", line 134, in event_csrmv
    return event_csrmv_taichi(data, indices, indptr, events, shape=shape, transpose=transpose)[0]
  File "C:\Program Files\Python311\Lib\site-packages\braintaichi\_eventop\csrmv.py", line 91, in event_csrmv_taichi
    return prim(
  File "C:\Program Files\Python311\Lib\site-packages\braintaichi\_primitive\_xla_custom_op.py", line 116, in __call__
    return self.primitive.bind(*ins, outs=outs, **kwargs)
jax._src.source_info_util.JaxStackTraceBeforeTransformation: RuntimeError: The CPU kernels do not build correctly. Please check the installation of braintaichi.

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "C:\Users\Vanilla\Desktop\03-Spatial_Embedded_SNN\code\simulate\run.py", line 722, in <module>
    snn_analyzer = SNN_analyzer(grid_num, grid_distance, E_conn_grid_num, I_conn_grid_num, wE, wI, sigmaE, E_inp_func, I_inp_func, time_period)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Vanilla\Desktop\03-Spatial_Embedded_SNN\code\simulate\run.py", line 615, in __init__
    self.runner.predict(duration=time_period)
  File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\runners.py", line 484, in predict
    outputs, hists = self._predict(indices, *inputs, shared_args=shared_args)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\runners.py", line 538, in _predict
    outs_and_mons = self._fun_predict(indices, *xs, shared_args=shared_args)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\runners.py", line 666, in _fun_predict
    return bm.for_loop(self._step_func_predict,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\math\object_transform\controls.py", line 891, in for_loop
    dyn_vals, out_vals = transform(operands)
                         ^^^^^^^^^^^^^^^^^^^
  File "C:\Program Files\Python311\Lib\site-packages\brainpy\_src\math\object_transform\controls.py", line 736, in call
    return jax.lax.scan(f=fun2scan,
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Program Files\Python311\Lib\site-packages\braintaichi\_primitive\_mlir_translation_rule.py", line 441, in _taichi_mlir_cpu_translation_rule
    raise RuntimeError(
RuntimeError: The CPU kernels do not build correctly. Please check the installation of braintaichi.
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

@Routhleck
Member

Try the code below:

import braintaichi
print('cpu_ops' in dir(braintaichi))

Check whether the output is True.
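
If it prints False, you can try to surface the underlying import error directly (a hedged sketch; 'braintaichi.cpu_ops' is an assumed module path and may differ in your installation):

import importlib
import traceback

# Try to import the compiled extension directly; the printed traceback
# usually names the missing DLL or the incompatible binary on Windows.
try:
    importlib.import_module('braintaichi.cpu_ops')  # hypothetical path
except Exception:
    traceback.print_exc()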

@LizhengyuSJTU
Author

False.
I don't know where cpu_ops is imported from, but I'm certain that the import fails.

@Routhleck
Member

Try reinstalling braintaichi:

pip uninstall braintaichi -y
pip install -U braintaichi
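
After reinstalling, the same check as above should print True:

python -c "import braintaichi; print('cpu_ops' in dir(braintaichi))"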

@LizhengyuSJTU
Author

More bad news: False

@Routhleck
Member

Could you provide the details of your device and operating system?

@LizhengyuSJTU
Author

LizhengyuSJTU commented Dec 16, 2024

Processor: 12th Gen Intel(R) Core(TM) i7-12700H 2.30 GHz
RAM: 16.0 GB (15.7 GB available)
System type: 64-bit operating system, x64-based processor
Edition: Windows 11 Home Chinese Edition
Version: 23H2
OS build: 22631.4602
Experience: Windows Feature Experience Pack 1000.22700.1055.0

@Routhleck
Member

Have you tried creating another Python environment and installing braintaichi there?
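
For example, a minimal sketch on Windows (the environment name is arbitrary):

python -m venv bt_env
bt_env\Scripts\activate
pip install -U braintaichi
python -c "import braintaichi; print('cpu_ops' in dir(braintaichi))"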

@LizhengyuSJTU
Author

LizhengyuSJTU commented Dec 17, 2024

Well... I will try Python 3.10

@LizhengyuSJTU
Author

Failed

WASTED

@chaoming0625
Collaborator

Maybe you can try brainstate. It has different dependencies, which may solve your problem.

@LizhengyuSJTU
Author

It doesn't seem to work; brainpy._src.runners.DSRunner.predict requires this.

@Routhleck
Member

You might consider downgrading brainpy to a version that doesn't require the braintaichi dependency. You can do this by running:

pip install brainpy==2.6.0.post20241025

Alternatively, if you'd like, you can leave your email address, and we can arrange an online meeting to discuss and troubleshoot the issue together.

@LizhengyuSJTU
Author

Thanks.
The program now raises no error, but it quits after the DSRunner.predict() call.
