
Import Error #169

Open

BLUE811420 opened this issue Feb 7, 2024 · 40 comments

@BLUE811420

ImportError: /home/ysn/anaconda3/envs/Mamba/lib/python3.8/site-packages/causal_conv1d_cuda.cpython-38-x86_64-linux-gnu.so: undefined symbol: _ZN3c107WarningC1ENS_7variantIJNS0_11UserWarningENS0_18DeprecationWarningEEEERKNS_14SourceLocationENSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEEb
What is the reason? Thanks in advance for an answer.
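For what it's worth, the missing symbol can be demangled to see which library it belongs to; a minimal sketch, assuming `c++filt` from binutils is available:

```
# Demangle the missing symbol; it resolves to a constructor of c10::Warning,
# i.e. a PyTorch-internal (libc10) symbol, which usually means the extension
# was compiled against a different PyTorch build than the one installed.
echo '_ZN3c107WarningC1ENS_7variantIJNS0_11UserWarningENS0_18DeprecationWarningEEEERKNS_14SourceLocationENSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEEb' | c++filt
```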

@sabeaussan

It seems related to CUDA and causal_conv1d. What is the output of nvidia-smi? And what are your PyTorch and causal_conv1d versions?
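For anyone gathering the same diagnostics, a minimal sketch of the commands that answer these questions:

```
nvidia-smi                                                # GPU driver / CUDA version
python -c "import torch; print(torch.__version__, torch.version.cuda)"
pip show causal-conv1d mamba-ssm | grep -E '^(Name|Version)'
```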

@BLUE811420 (Author)

BLUE811420 commented Feb 9, 2024 via email:

Yep, I solved it; it was a causal_conv1d version error. Thank you.

@sabeaussan

Nice, I guess you should close the issue then.

@AshleyLuo001

Yep, I solved it; it was a causal_conv1d version error. Thank you.

How did you match the causal_conv1d version, please?

@BLUE811420 (Author)

BLUE811420 commented Feb 10, 2024 via email:

Both are 1.1.1, with CUDA 11.8 + torch 2.1.

@AshleyLuo001

Both are 1.1.1, with CUDA 11.8 + torch 2.1.

Thank you! I get it!

@zhoukexin-62

Both are 1.1.1, with CUDA 11.8 + torch 2.1.

[screenshot: installed package versions]
I checked the versions, but I still get a problem when testing:
selective_scan_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZN3c107WarningC1ENS_7variantIJNS0_11UserWarningENS0_18DeprecationWarningEEEERKNS_14SourceLocationENSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEEb

What can I do?

@BLUE811420 (Author)

I think the mamba_ssm ABI should be FALSE.
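For reference, the abiTRUE/abiFALSE tag in a prebuilt wheel's name refers to the C++11 ABI it was compiled with, and it has to match the installed PyTorch. A minimal check, using the public torch API:

```
# Prints True if torch uses the new cxx11 ABI, False otherwise; pick the
# mamba_ssm / causal_conv1d wheel whose cxx11abi tag matches this value.
python -c "import torch; print(torch.compiled_with_cxx11_abi())"
```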

@zhoukexin-62

I think the mamba_ssm ABI should be FALSE.

How can I solve it?

@BLUE811420 (Author)

Try downloading mamba_ssm-1.0.1+cu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl, but I'm not sure it will work.

@zhoukexin-62

I tried it before; it still doesn't work. So the mamba_ssm version should match the __init__.py file? And the causal_conv1d version should match mamba_ssm? Thanks for your reply.


@BLUE811420 (Author)

These are my versions:
[screenshot: installed package versions]

@BLUE811420 (Author)

I think you are Chinese? We could chat on WeChat.

@BLUE811420 (Author)

make it

@BLUE811420 (Author)

BLUE811420 commented Mar 1, 2024 via email:

I'm on 3.10. Check whether the .whl you downloaded is the abiFALSE build; installing the FALSE version usually works. Match the CUDA, torch, and Python versions yourself.

@CCECfgd

CCECfgd commented Mar 2, 2024

I think you are Chinese? We could chat on WeChat.

Hello, I'm also having trouble with the environment setup. Could we discuss it over WeChat? Thank you very much!

@zhoukexin-62

I think you are Chinese? We could chat on WeChat.

Hello, I'm also having trouble with the environment setup. Could we discuss it over WeChat? Thank you very much!

Please share a WeChat ID.

@jorenretel

jorenretel commented Mar 21, 2024

For anyone else landing here: the conclusion of this thread was not completely clear to me. I had basically the same problem, with

selective_scan_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol

In my case, I discovered that I had an incompatible pip cache lying around. A simple

pip uninstall mamba-ssm
pip install mamba-ssm --no-cache-dir

fixed the issue.
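The likely mechanism here: pip had cached a mamba-ssm wheel built against a different torch/ABI and kept reinstalling it. Clearing only the offending cache entries should also work; a minimal sketch, assuming pip >= 20.1 (which added the `pip cache` command):

```
# Drop cached mamba-ssm wheels so the next install refetches/rebuilds them
pip cache remove "mamba_ssm*"
pip uninstall -y mamba-ssm
pip install mamba-ssm
```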

@aaronshenhao

(quoting jorenretel's pip-cache fix above)

This worked for me! I was encountering the error in the official Colab notebook.

@zx1292982431

(quoting jorenretel's pip-cache fix above)

Thank you! It worked for me! Have a nice day! 😘

@MohitIntel

I still see this issue on nvcr.io/nvidia/pytorch:23.11-py3. The --no-cache-dir suggestion above did not work for me :(.
Here is my pip list:

Package                          Version                  Editable project location

absl-py 2.0.0
accelerate 0.27.2
aiohttp 3.8.6
aiosignal 1.3.1
annotated-types 0.6.0
apex 0.1
argon2-cffi 23.1.0
argon2-cffi-bindings 21.2.0
asttokens 2.4.1
astunparse 1.6.3
async-timeout 4.0.3
attrs 23.1.0
audioread 3.0.1
beautifulsoup4 4.12.2
bleach 6.1.0
blis 0.7.11
cachetools 5.3.2
catalogue 2.0.10
causal-conv1d 1.2.0.post2
certifi 2023.11.17
cffi 1.16.0
charset-normalizer 3.3.1
click 8.1.7
cloudpathlib 0.16.0
cloudpickle 3.0.0
cmake 3.27.7
coloredlogs 15.0.1
comm 0.2.0
confection 0.1.3
contourpy 1.2.0
cubinlinker 0.3.0+2.g711d153
cuda-python 12.3.0rc4+8.ge6f99b5
cudf 23.10.0
cugraph 23.10.0
cugraph-dgl 23.10.0
cugraph-service-client 23.10.0
cugraph-service-server 23.10.0
cuml 23.10.0
cupy-cuda12x 12.2.0
cycler 0.12.1
cymem 2.0.8
Cython 3.0.5
dask 2023.9.2
dask-cuda 23.10.0
dask-cudf 23.10.0
datasets 2.18.0
debugpy 1.8.0
decorator 5.1.1
defusedxml 0.7.1
diffusers 0.26.3
dill 0.3.8
distributed 2023.9.2
dm-tree 0.1.8
einops 0.7.0
exceptiongroup 1.1.3
execnet 2.0.2
executing 2.0.1
expecttest 0.1.3
fastjsonschema 2.19.0
fastrlock 0.8.2
filelock 3.13.1
flash-attn 2.0.4
fonttools 4.45.0
frozenlist 1.4.0
fsspec 2023.10.0
gast 0.5.4
google-auth 2.23.4
google-auth-oauthlib 0.4.6
graphsurgeon 0.4.6
grpcio 1.59.3
huggingface-hub 0.22.2
humanfriendly 10.0
hypothesis 5.35.1
idna 3.4
importlib-metadata 6.8.0
iniconfig 2.0.0
intel-openmp 2021.4.0
ipykernel 6.26.0
ipython 8.17.2
ipython-genutils 0.2.0
jedi 0.19.1
Jinja2 3.1.2
joblib 1.3.2
json5 0.9.14
jsonschema 4.20.0
jsonschema-specifications 2023.11.1
jupyter_client 8.6.0
jupyter_core 5.5.0
jupyter-tensorboard 0.2.0
jupyterlab 2.3.2
jupyterlab-pygments 0.2.2
jupyterlab-server 1.2.0
jupytext 1.15.2
kiwisolver 1.4.5
langcodes 3.3.0
lazy_loader 0.3
librosa 0.10.1
llvmlite 0.40.1
locket 1.0.0
mamba-ssm 1.2.0
Markdown 3.5.1
markdown-it-py 3.0.0
MarkupSafe 2.1.3
matplotlib 3.8.2
matplotlib-inline 0.1.6
mdit-py-plugins 0.4.0
mdurl 0.1.2
mistune 3.0.2
mkl 2021.1.1
mkl-devel 2021.1.1
mkl-include 2021.1.1
mock 5.1.0
mpmath 1.3.0
msgpack 1.0.7
multidict 6.0.4
multiprocess 0.70.16
murmurhash 1.0.10
nbclient 0.9.0
nbconvert 7.11.0
nbformat 5.9.2
nest-asyncio 1.5.8
networkx 2.6.3
ninja 1.11.1.1
notebook 6.4.10
numba 0.57.1+1.gc2aae5dd0
numpy 1.24.4
nvfuser 0.0.21+gitunknown
nvidia-dali-cuda120 1.31.0
nvidia-pyindex 1.0.9
nvtx 0.2.5
oauthlib 3.2.2
onnx 1.14.1
opencv 4.7.0
optimum 1.18.1
optimum-habana 1.11.0 /root/mdeopujari/optimum-habana-fork
optree 0.10.0
packaging 23.2
pandas 1.5.3
pandocfilters 1.5.0
parso 0.8.3
partd 1.4.1
pexpect 4.8.0
Pillow 9.2.0
pip 23.3.1
platformdirs 4.0.0
pluggy 1.3.0
ply 3.11
polygraphy 0.49.1
pooch 1.8.0
preshed 3.0.9
prettytable 3.9.0
prometheus-client 0.18.0
prompt-toolkit 3.0.41
protobuf 4.24.4
psutil 5.9.4
ptxcompiler 0.8.1+2.g4c26c4c
ptyprocess 0.7.0
pure-eval 0.2.2
pyarrow 12.0.1
pyarrow-hotfix 0.6
pyasn1 0.5.1
pyasn1-modules 0.3.0
pybind11 2.11.1
pybind11-global 2.11.1
pycocotools 2.0+nv0.8.0
pycparser 2.21
pydantic 2.5.1
pydantic_core 2.14.3
Pygments 2.17.1
pylibcugraph 23.10.0
pylibcugraphops 23.10.0
pylibraft 23.10.0
pynvml 11.4.1
pyparsing 3.1.1
pytest 7.4.3
pytest-flakefinder 1.1.0
pytest-rerunfailures 12.0
pytest-shard 0.1.2
pytest-xdist 3.4.0
python-dateutil 2.8.2
python-hostlist 1.23.0
pytorch-quantization 2.1.2
pytz 2023.3.post1
PyYAML 6.0.1
pyzmq 25.1.1
raft-dask 23.10.0
referencing 0.31.0
regex 2023.10.3
requests 2.31.0
requests-oauthlib 1.3.1
rmm 23.10.0
rpds-py 0.13.1
rsa 4.9
safetensors 0.4.2
scikit-learn 1.2.0
scipy 1.11.3
Send2Trash 1.8.2
sentencepiece 0.2.0
setuptools 68.2.2
six 1.16.0
smart-open 6.4.0
sortedcontainers 2.4.0
soundfile 0.12.1
soupsieve 2.5
soxr 0.3.7
spacy 3.7.2
spacy-legacy 3.0.12
spacy-loggers 1.0.5
sphinx-glpi-theme 0.4.1
srsly 2.4.8
stack-data 0.6.3
sympy 1.12
tabulate 0.9.0
tbb 2021.11.0
tblib 3.0.0
tensorboard 2.9.0
tensorboard-data-server 0.6.1
tensorboard-plugin-wit 1.8.1
tensorrt 8.6.1
terminado 0.18.0
thinc 8.2.1
threadpoolctl 3.2.0
thriftpy2 0.4.17
tinycss2 1.2.1
tokenizers 0.15.2
toml 0.10.2
tomli 2.0.1
toolz 0.12.0
torch 2.2.0a0+6a974be
torch-tensorrt 2.2.0a0
torchdata 0.7.0a0
torchtext 0.16.0a0
torchvision 0.17.0a0
tornado 6.3.3
tqdm 4.66.1
traitlets 5.9.0
transformer-engine 1.0.0+66d91d5
transformers 4.39.0
treelite 3.9.1
treelite-runtime 3.9.1
triton 2.1.0+6e4932c
typer 0.9.0
types-dataclasses 0.6.6
typing_extensions 4.8.0
ucx-py 0.34.0
uff 0.6.9
urllib3 1.26.18
wasabi 1.1.2
wcwidth 0.2.10
weasel 0.3.4
webencodings 0.5.1
Werkzeug 3.0.1
wheel 0.41.3
xdoctest 1.0.2
xgboost 1.7.6
xxhash 3.4.1
yarl 1.9.2
zict 3.0.0
zipp 3.17.0

@Kevin-naticl

I have the same problem. Is it a problem of matching the causal-conv1d and mamba versions?
I installed these wheels:
D:\Downloads\causal_conv1d-1.2.0.post2+cu118torch1.12cxx11abiTRUE-cp37-cp37m-linux_x86_64.whl
D:\Downloads\mamba_ssm-1.2.0.post1+cu118torch1.12cxx11abiTRUE-cp37-cp37m-linux_x86_64.whl

It still has the problem:

from mamba_ssm import mamba
  File "/home/work_nfs10/bykang/Test/mamba_ssm/__init__.py", line 3, in <module>
    from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
  File "/home/work_nfs10/bykang/Test/mamba_ssm/ops/selective_scan_interface.py", line 16, in <module>
    import selective_scan_cuda
ImportError: /home/work_nfs7/bykang/bykang_env/envs/stcm/lib/python3.7/site-packages/selective_scan_cuda.cpython-37m-x86_64-linux-gnu.so: undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationENSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE

@Kevin-naticl

And my PyTorch is torch 1.13 + cu117.
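If those filenames are the wheels actually installed, the torch tag may be the mismatch: they encode torch1.12 (plus cp37 and cxx11abiTRUE), while the environment runs torch 1.13, and every tag in a prebuilt wheel's name has to match the environment. A minimal sketch for reading off the local values to compare against the tags:

```
# Each printed value must match the corresponding tag in the wheel filename:
# cpXX <-> python, torchX.Y <-> torch, cuXXX <-> CUDA, True/False <-> cxx11abi
python -c "import sys, torch; print(sys.version_info[:2], torch.__version__, torch.version.cuda, torch.compiled_with_cxx11_abi())"
```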

@uxhao-o

uxhao-o commented Apr 18, 2024

I have it running successfully. Environment as follows:
cuda 11.8
python 3.10.13
pytorch 2.1.1
causal_conv1d 1.1.1
mamba-ssm 1.2.0.post1

pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu118
pip install causal_conv1d==1.1.1
pip install mamba-ssm==1.2.0.post1
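After an install like this, a quick smoke test confirms that the extensions load and a forward pass runs; a minimal sketch, assuming a CUDA GPU is visible (the shapes are arbitrary):

```
python - <<'EOF'
import torch
from mamba_ssm import Mamba

# Tiny forward pass: batch 2, sequence length 16, model dim 32
x = torch.randn(2, 16, 32, device="cuda")
layer = Mamba(d_model=32, d_state=16, d_conv=4, expand=2).to("cuda")
print(layer(x).shape)  # expect torch.Size([2, 16, 32])
EOF
```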

@xifen523

Thanks, it works!

@JeffreyzzZ0

My Python version is 3.8; I think the ABI needs to be FALSE.

(quoting BLUE811420's email reply above)

Hey bro, can I have your WeChat? I'm having trouble installing the environment.

@askrbayern

askrbayern commented May 7, 2024

Thank you guys! I solved my problem. I believe the problem is the (py)torch version; mamba/conv1d don't support 2.3.0 yet (check Dao's response in Issue #217).

My problem: undefined symbol for both mamba-ssm and causal_conv1d

My previous environment: CUDA 12.1, torch/pytorch 2.3.0

My solution: reinstall (py)torch with version 2.2

  • pip uninstall torch torchvision torchaudio mamba-ssm causal_conv1d
  • conda uninstall pytorch if you have pytorch version 2.3.0
  • pip install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 --index-url https://download.pytorch.org/whl/cu121
  • conda install pytorch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 pytorch-cuda=12.1 -c pytorch -c nvidia
    (find the commands for installing other versions of pytorch here)
  • pip install causal_conv1d
  • pip install mamba-ssm or pip install mamba-ssm==1.2.0.post1; both work

==============
Solutions proposed in this thread that didn't work in my case:

  • pip install mamba-ssm --no-cache-dir (made no difference for me)
  • pip install causal_conv1d==1.1.1 (could not be installed, so I installed version 1.2.0 instead)

Solutions proposed in Issue #217 that didn't work in my case:

  • pip install https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.2.0.post2/causal_conv1d-1.2.0.post2+cu118torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
  • pip install https://github.com/state-spaces/mamba/releases/download/v1.2.0.post1/mamba_ssm-1.2.0.post1+cu118torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
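Condensed into a single pinned install, this is the combination reported working above (torch 2.2.2 with cu121 wheels, causal_conv1d 1.2.0, mamba-ssm 1.2.0.post1); a minimal sketch:

```
pip uninstall -y torch torchvision torchaudio mamba-ssm causal_conv1d
pip install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 --index-url https://download.pytorch.org/whl/cu121
pip install causal_conv1d==1.2.0 mamba-ssm==1.2.0.post1
```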

@Yumo216

Yumo216 commented May 16, 2024

(quoting askrbayern's torch 2.2 solution above)

You are a real hero. You saved my life!

@elttaes

elttaes commented May 24, 2024

(quoting askrbayern's torch 2.2 solution above)

I had no choice but to directly copy my old conda environment; after trying many methods I had no idea where the problem lay. You are a real hero!!!

@MUsixG

MUsixG commented May 27, 2024

(quoting jorenretel's pip-cache fix above)

Thanks! It works for me. Have a nice day!

@Joeland4

Joeland4 commented Aug 5, 2024

(quoting uxhao-o's environment above)

Very nice!

@inari233

inari233 commented Nov 8, 2024

(quoting askrbayern's torch 2.2 solution above)

Thank you so much! I solved my problem!

@GitHub-TXZ

(quoting uxhao-o's environment above)

Worked for me, thanks a lot.

@2326261098

(quoting uxhao-o's environment above)

Why does this error occur? Here's the full error (tensor values abridged):

Traceback (most recent call last):
  File "/root/autodl-tmp/TimeMachine/TimeMachine_supervised/run_longExp.py", line 119, in <module>
    exp.train(setting)
  File "/root/autodl-tmp/TimeMachine/TimeMachine_supervised/exp/exp_main.py", line 144, in train
    outputs = self.model(batch_x)
  File "/root/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/autodl-tmp/TimeMachine/TimeMachine_supervised/models/TimeMachine.py", line 52, in forward
    x3 = self.mamba3(x)
  File "/root/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/miniconda3/lib/python3.10/site-packages/mamba_ssm/modules/mamba_simple.py", line 146, in forward
    out = mamba_inner_fn(
  File "/root/miniconda3/lib/python3.10/site-packages/mamba_ssm/ops/selective_scan_interface.py", line 317, in mamba_inner_fn
    return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
  File "/root/miniconda3/lib/python3.10/site-packages/torch/autograd/function.py", line 539, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/root/miniconda3/lib/python3.10/site-packages/torch/cuda/amp/autocast_mode.py", line 113, in decorate_fwd
    return fwd(*args, **kwargs)
  File "/root/miniconda3/lib/python3.10/site-packages/mamba_ssm/ops/selective_scan_interface.py", line 187, in forward
    conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
TypeError: causal_conv1d_fwd(): incompatible function arguments. The following argument types are supported:
    1. (arg0: torch.Tensor, arg1: torch.Tensor, arg2: Optional[torch.Tensor], arg3: Optional[torch.Tensor], arg4: bool) -> torch.Tensor

Invoked with: tensor([[[-0.4972, -0.0841, -0.7153, ...]]], device='cuda:0', requires_grad=True), tensor([[-0.5545, -0.0504], ..., [-0.6896, -0.0596]], device='cuda:0', requires_grad=True), Parameter containing: tensor([ 1.3125e-02, 5.7447e-01, ..., -4.8207e-01], device='cuda:0', requires_grad=True), None, None, None, True
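One plausible reading of this traceback: mamba_ssm calls causal_conv1d_fwd with seven arguments, while the installed causal_conv1d binary only exposes the older five-argument signature, i.e. the two packages come from incompatible release lines. If so, reinstalling the pair so their versions line up should clear it; a minimal sketch (the pin below is an assumption, not a verified fix):

```
# Hypothesis: the installed causal-conv1d (5-arg causal_conv1d_fwd) predates
# the interface that this mamba-ssm release (7-arg call) was built against.
pip uninstall -y causal-conv1d
pip install 'causal_conv1d>=1.2.0' --no-cache-dir
```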

@shanguanma

pip install mamba-ssm --no-cache-dir

I used the command below; it works for me, thanks a lot.

pip uninstall mamba-ssm
pip install mamba-ssm --no-cache-dir --no-build-isolation

@xmustu

xmustu commented Dec 20, 2024

(quoting 2326261098's causal_conv1d_fwd TypeError report above)

Hey dude, have you solved the problem?

@2326261098

2326261098 commented Dec 20, 2024 via email

@PenghaoYin

(quoting uxhao-o's environment above)

These configurations solved my problem importing mamba_ssm, which was the same as the one raised in this issue. Thank you so much!

@PenghaoYin

(quoting 2326261098's causal_conv1d_fwd TypeError report above)

I have the same problem.
