pnnx crashes when converting to ncnn #4211

Closed · zuowanbushiwo opened this issue Sep 20, 2022 · 6 comments

@zuowanbushiwo
error log

pnnxparam = DTLN.pnnx.param
pnnxbin = DTLN.pnnx.bin
pnnxpy = DTLN_pnnx.py
ncnnparam = DTLN.ncnn.param
ncnnbin = DTLN.ncnn.bin
ncnnpy = DTLN_ncnn.py
optlevel = 2
device = cpu
inputshape = [1,512]f32,[2,1,128,2]f32,[2,1,128,2]f32
inputshape2 =
customop =
moduleop =
############# pass_level0
inline module = Pytorch_InstantLayerNormalization
inline module = SeperationBlock_Stateful
inline module = Simple_STFT_Layer
inline module = Pytorch_InstantLayerNormalization
inline module = SeperationBlock_Stateful
inline module = Simple_STFT_Layer
41  92  y.1  r.1  i.1  105  106  107  108  mag.1  110  230  111  phase.1  19  20  orig_input.1  phase.3  125  126  127  hx.2  129  130  131  hx0.2  133  134  135  hx1.2  137  138  139  hx2.2  142  143  144  145  147  148  149  150  151  152  h.2  c.2  158  35  36  estimated_mag.1  42  43  s1_stft.1  y1.1  input.1  58  inputs.1  160  beta.1  gamma.1  mean.1  169  170  variance.1  173  std.1  175  outputs.1  outputs0.1  178  190  191  192  hx.1  194  195  196  hx0.1  198  199  200  hx1.1  202  203  204  hx2.1  207  208  209  210  212  213  214  215  216  217  h.1  c.1  223  68  69  estimated.1  input0.1  78
----------------

foldable_constant beta.1
foldable_constant gamma.1
############# pass_level1
no attribute value
unknown Parameter value kind prim::Constant of TensorType, t.dim = 0
no attribute value
no attribute value
no attribute value
no attribute value
no attribute value
no attribute value
############# pass_level2
############# pass_level3
assign unique operator name pnnx_unique_0 to sep1.drop
assign unique operator name pnnx_unique_1 to encoder_norm1
assign unique operator name pnnx_unique_2 to sep2.drop
eliminate_noop_math aten::add pnnx_167
eliminate_noop_math aten::mul pnnx_165
############# pass_level4
############# pass_level5
############# pass_ncnn
Segmentation fault (core dumped)

context

# -*- coding: utf-8 -*-
import os
import torch
import numpy as np
import torch.nn as nn


class Simple_STFT_Layer(nn.Module):
    def __init__(self, frame_len, frame_hop):
        super(Simple_STFT_Layer, self).__init__()
        self.eps = torch.finfo(torch.float32).eps
        self.frame_len = frame_len
        self.frame_hop = frame_hop

    def forward(self, x):
        if len(x.shape) != 2:
            print("x must be in [B, T]")
        y = torch.stft(x, n_fft=self.frame_len, hop_length=self.frame_hop,
                       win_length=self.frame_len, return_complex=True, center=False)
        r = y.real
        i = y.imag
        mag = torch.clamp(r ** 2 + i ** 2, self.eps) ** 0.5
        phase = torch.atan2(i + self.eps, r + self.eps)
        return mag, phase


class Pytorch_InstantLayerNormalization(nn.Module):

    def __init__(self, channels):
        """
            Constructor
        """
        super(Pytorch_InstantLayerNormalization, self).__init__()
        self.epsilon = 1e-7
        self.gamma = nn.Parameter(torch.ones(1, 1, channels), requires_grad=True)
        self.beta = nn.Parameter(torch.zeros(1, 1, channels), requires_grad=True)
        self.register_parameter("gamma", self.gamma)
        self.register_parameter("beta", self.beta)

    def forward(self, inputs):
        # calculate mean of each frame
        mean = torch.mean(inputs, dim=-1, keepdim=True)

        # calculate variance of each frame
        variance = torch.mean(torch.square(inputs - mean), dim=-1, keepdim=True)
        # calculate standard deviation
        std = torch.sqrt(variance + self.epsilon)
        outputs = (inputs - mean) / std
        # scale with gamma
        outputs = outputs * self.gamma
        # add the bias beta
        outputs = outputs + self.beta
        # return output
        return outputs



class SeperationBlock_Stateful(nn.Module):
    def __init__(self, input_size=257, hidden_size=128, dropout=0.25):
        super(SeperationBlock_Stateful, self).__init__()
        self.rnn1 = nn.LSTM(input_size=input_size,
                            hidden_size=hidden_size,
                            num_layers=1,
                            batch_first=True,
                            dropout=0.0,
                            bidirectional=False)
        self.rnn2 = nn.LSTM(input_size=hidden_size,
                            hidden_size=hidden_size,
                            num_layers=1,
                            batch_first=True,
                            dropout=0.0,
                            bidirectional=False)
        self.drop = nn.Dropout(dropout)

        self.dense = nn.Linear(hidden_size, input_size)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, in_states):
        """

        :param x:  [N, T, input_size]
        :param in_states: [2, N, hidden_size, 2]
        :return:
        """
        h1_in, c1_in = in_states[:1, :, :, 0], in_states[:1, :, :, 1]
        h2_in, c2_in = in_states[1:, :, :, 0], in_states[1:, :, :, 1]

        x1, (h1, c1) = self.rnn1(x, (h1_in, c1_in))
        x1 = self.drop(x1)
        x2, (h2, c2) = self.rnn2(x1, (h2_in, c2_in))
        x2 = self.drop(x2)

        mask = self.dense(x2)
        mask = self.sigmoid(mask)

        h = torch.cat((h1, h2), dim=0)
        c = torch.cat((c1, c2), dim=0)
        out_states = torch.stack((h, c), dim=-1)
        return mask, out_states


class Pytorch_DTLN_stateful(nn.Module):
    def __init__(self, frame_len=512, frame_hop=128, window='rect'):
        super(Pytorch_DTLN_stateful, self).__init__()
        self.frame_len = frame_len
        self.frame_hop = frame_hop
        self.stft = Simple_STFT_Layer(frame_len, frame_hop)

        self.sep1 = SeperationBlock_Stateful(input_size=(frame_len // 2 + 1), hidden_size=128, dropout=0.25)

        self.encoder_size = 256
        self.encoder_conv1 = nn.Conv1d(in_channels=frame_len, out_channels=self.encoder_size,
                                       kernel_size=1, stride=1, bias=False)

        self.encoder_norm1 = Pytorch_InstantLayerNormalization(channels=self.encoder_size)

        self.sep2 = SeperationBlock_Stateful(input_size=self.encoder_size, hidden_size=128, dropout=0.25)

        self.decoder_conv1 = nn.Conv1d(in_channels=self.encoder_size, out_channels=frame_len,
                                       kernel_size=1, stride=1, bias=False)

    def forward(self, x, in_state1, in_state2):
        """

        :param x:  [N, T]
        :return:
        """
        batch, n_frames = x.shape

        mag, phase = self.stft(x)
        mag = mag.permute(0, 2, 1)
        phase = phase.permute(0, 2, 1)

        # N, T, hidden_size
        mask, out_state1 = self.sep1(mag, in_state1)
        estimated_mag = mask * mag

        s1_stft = estimated_mag * torch.exp((1j * phase))
        y1 = torch.fft.irfft2(s1_stft, dim=-1)
        y1 = y1.permute(0, 2, 1)

        encoded_f = self.encoder_conv1(y1)
        encoded_f = encoded_f.permute(0, 2, 1)
        encoded_f_norm = self.encoder_norm1(encoded_f)
        mask_2, out_state2 = self.sep2(encoded_f_norm, in_state2)
        estimated = mask_2 * encoded_f
        estimated = estimated.permute(0, 2, 1)
        decoded_frame = self.decoder_conv1(estimated)
        return decoded_frame, out_state1, out_state2




def test_stateful():
    x = torch.randn(1, 512)
    in_state1 = torch.randn(2, 1, 128, 2)
    in_state2 = torch.randn(2, 1, 128, 2)
    net = Pytorch_DTLN_stateful()
    y, out_state1, out_state2 = net(x, in_state1, in_state2)
    print(y.shape)
    print(out_state1.shape)
    print(out_state2.shape)


def test():
    net = Pytorch_DTLN_stateful()
    net.eval()

    torch.manual_seed(0)
    x = torch.randn(1, 512)
    in_state1 = torch.randn(2, 1, 128, 2)
    in_state2 = torch.randn(2, 1, 128, 2)

    y, out_state1, out_state2 = net(x, in_state1, in_state2)

    # export torchscript
    mod = torch.jit.trace(net, (x, in_state1, in_state2))
    
    mod.save("DTLN.pt")

    # torchscript to pnnx
    import os
    os.system("/home/yangjie/pnnx/pnnx DTLN.pt inputshape=[1,512],[2,1,128,2],[2,1,128,2]")

    return True

if __name__ == "__main__":
    if test():
        print("successed!")
    else:
        print("failed!")

how to reproduce

more

@zuowanbushiwo (Author)

The gdb output is below. pnnx was built with cmake -DCMAKE_BUILD_TYPE=Debug, but the call stack still seems incomplete.

Type "apropos word" to search for commands related to "word"...
Reading symbols from /home/yangjie/pnnx/pnnx...done.
(gdb) set args DTLN.pt inputshape=[1,512],[2,1,128,2],[2,1,128,2]
(gdb) run
Starting program: /home/yangjie/pnnx/pnnx DTLN.pt inputshape=[1,512],[2,1,128,2],[2,1,128,2]
BFD: warning: /home/yangjie/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/lib/../../../../libgomp.so.1: unsupported GNU_PROPERTY_TYPE (5) type: 0xc0010001
BFD: warning: /home/yangjie/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/lib/../../../../libgomp.so.1: unsupported GNU_PROPERTY_TYPE (5) type: 0xc0010002
[Thread debugging using libthread_db enabled]
Using host libthread_db library "/lib/x86_64-linux-gnu/libthread_db.so.1".
pnnxparam = DTLN.pnnx.param
pnnxbin = DTLN.pnnx.bin
pnnxpy = DTLN_pnnx.py
ncnnparam = DTLN.ncnn.param
ncnnbin = DTLN.ncnn.bin
ncnnpy = DTLN_ncnn.py
optlevel = 2
device = cpu
inputshape = [1,512]f32,[2,1,128,2]f32,[2,1,128,2]f32
inputshape2 =
customop =
moduleop =
############# pass_level0
inline module = Pytorch_InstantLayerNormalization
inline module = SeperationBlock_Stateful
inline module = Simple_STFT_Layer
inline module = Pytorch_InstantLayerNormalization
inline module = SeperationBlock_Stateful
inline module = Simple_STFT_Layer
41  92  y.1  r.1  i.1  105  106  107  108  mag.1  110  230  111  phase.1  19  20  orig_input.1  phase.3  125  126  127  hx.2  129  130  131  hx0.2  133  134  135  hx1.2  137  138  139  hx2.2  142  143  144  145  147  148  149  150  151  152  h.2  c.2  158  35  36  estimated_mag.1  42  43  s1_stft.1  y1.1  input.1  58  inputs.1  160  beta.1  gamma.1  mean.1  169  170  variance.1  173  std.1  175  outputs.1  outputs0.1  178  190  191  192  hx.1  194  195  196  hx0.1  198  199  200  hx1.1  202  203  204  hx2.1  207  208  209  210  212  213  214  215  216  217  h.1  c.1  223  68  69  estimated.1  input0.1  78
----------------

[... ~160 repeated "New Thread ..." / "Thread ... exited" messages omitted ...]
foldable_constant beta.1
foldable_constant gamma.1
############# pass_level1
no attribute value
unknown Parameter value kind prim::Constant of TensorType, t.dim = 0
no attribute value
no attribute value
no attribute value
no attribute value
no attribute value
no attribute value
############# pass_level2
############# pass_level3
assign unique operator name pnnx_unique_0 to sep1.drop
assign unique operator name pnnx_unique_1 to encoder_norm1
assign unique operator name pnnx_unique_2 to sep2.drop
eliminate_noop_math aten::add pnnx_167
eliminate_noop_math aten::mul pnnx_165
############# pass_level4
############# pass_level5
############# pass_ncnn

Thread 1 "pnnx" received signal SIGSEGV, Segmentation fault.
0x00005555555b6450 in std::vector<pnnx::Operator*, std::allocator<pnnx::Operator*> >::push_back (this=0x10, __x=@0x7fffffffdd28: 0x555558acb000)
    at /usr/include/c++/7/bits/stl_vector.h:941
941             if (this->_M_impl._M_finish != this->_M_impl._M_end_of_storage)
(gdb)
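For reference, this=0x10 in the push_back frame usually indicates a member-function call made through a null object pointer: the vector's this is then just the offset of that vector member inside a null owning object. A minimal C++ sketch of the pattern, using a hypothetical Graph/Operator layout rather than pnnx's actual types:

#include <vector>

struct Operator { };

struct Graph {
    void* pad[2];                  // hypothetical fields; places 'ops' at offset 0x10 on 64-bit
    std::vector<Operator*> ops;
};

int main() {
    Graph* g = nullptr;            // e.g. a conversion pass that never obtained a valid graph
    Operator* op = new Operator;
    // Undefined behavior: &g->ops is computed as nullptr + offset-of-ops (0x10),
    // so push_back executes with this == 0x10 and segfaults, matching the frame above.
    g->ops.push_back(op);
    return 0;
}

This suggests some pass_ncnn step dereferences a graph or operator pointer that was never set, possibly related to the constructs flagged at pass_level1 (the dim-0 TensorType prim::Constant likely comes from the complex scalar in torch.exp(1j * phase)).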

@zuowanbushiwo (Author)

@nihui could you please take a look at this?

@zchrissirhcz (Contributor)

When gdb stops at the segfault, use the bt command to view the call stack, switch frames with up / down, and in each frame print variables with p for further inspection.
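For example, a session might look like this (illustrative commands only; the actual bt output for this crash is not reproduced here, and op is a hypothetical variable name):

(gdb) bt          # print the full backtrace of the crashing thread
(gdb) up          # move one frame toward the caller
(gdb) frame 3     # or jump straight to frame 3
(gdb) p op        # print a variable visible in the selected frame
(gdb) p *op       # dereference a pointer to inspect the pointed-to object

If many frames still lack source info even with -DCMAKE_BUILD_TYPE=Debug, those symbols are often from the prebuilt libtorch binaries rather than from pnnx itself.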

@Plutoisme

(quotes @zuowanbushiwo's full gdb transcript from above, omitted here)

Did you manage to solve this? I'm also hitting this problem converting the same model.

@nihui (Member) commented Nov 19, 2024

#5779
