Torchscript model to pnnx model error #4598

Closed

Plutoisme opened this issue Mar 30, 2023 · 7 comments

@Plutoisme
error log

model

  1. original model
import torch
import torch.nn as nn
import torch.nn.functional as F

class stftLayer(nn.Module):
    '''
    input: shape=[batchsize, T] audio time signal.
    output: mag, phase. which shape=[batchsize, F, T] audio STFT.
    '''
    def __init__(self, frameLength, hopLength):
        super(stftLayer, self).__init__()
        self.eps = torch.finfo(torch.float32).eps
        self.frameLength = frameLength
        self.hopLength = hopLength

    def forward(self, input):
        y = torch.stft(input, n_fft=self.frameLength,
                       hop_length=self.hopLength,
                       win_length=self.frameLength,
                       return_complex=True,
                       center=False)
        realPart = y.real
        imagPart = y.imag
        mag = torch.clamp(realPart ** 2 + imagPart ** 2, self.eps) ** 0.5
        phase = torch.atan2(imagPart + self.eps, realPart + self.eps)
        return mag, phase
    
class NormLayer(nn.Module):
    def __init__(self, channels):
        super(NormLayer, self).__init__()
        self.eps = 1e-7
        self.gamma = nn.Parameter(torch.ones(1, 1, channels), requires_grad=True)
        self.beta = nn.Parameter(torch.zeros(1, 1, channels), requires_grad=True)
        self.register_parameter("gamma", self.gamma)
        self.register_parameter("beta", self.beta)

    def forward(self, inputs):
        meanOfFrames = torch.mean(inputs, dim=-1, keepdim=True)
        varianceOfFrames = torch.mean(torch.square(inputs - meanOfFrames), dim=-1, keepdim=True)
        stdOfFrames = torch.sqrt(varianceOfFrames + self.eps)

        outputs = (inputs - meanOfFrames) / stdOfFrames
        outputs = outputs * self.gamma
        outputs = outputs + self.beta
        return outputs


class seperationOfLSTMs_Stateful_aec(nn.Module):
    '''
    input: [N, T, F]
    output: mask, shape [N, T, input_size/2]; outStates, shape [2, N, hidden_size, 2]
    '''
    def __init__(self, input_size=257*2, hidden_size=128, dropout=0.2):
        super(seperationOfLSTMs_Stateful_aec,self).__init__()
        self.lstm1 = nn.LSTM(input_size = input_size,
                             hidden_size = hidden_size,
                             num_layers = 1,
                             batch_first = True,
                             dropout = 0, # has no effect when num_layers == 1
                             bidirectional = False
                             )
        self.lstm2 = nn.LSTM(input_size = hidden_size,
                             hidden_size = hidden_size,
                             num_layers = 1,
                             batch_first = True,
                             dropout = 0,
                             bidirectional = False
                             )
        self.drop = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_size, int(input_size/2))
        self.sigmoid = nn.Sigmoid()

    def forward(self, input, in_states):
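        # in_states has shape [num_layers_total=2, N, hidden_size, 2]; dim 0
        # selects which LSTM's state, and the last dim packs (h, c).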
        h1_in, c1_in = in_states[:1, :, :, 0], in_states[:1, :, :, 1]
        h2_in, c2_in = in_states[1:, :, :, 0], in_states[1:, :, :, 1]

        h1_in = h1_in.contiguous()
        c1_in = c1_in.contiguous()
        h2_in = h2_in.contiguous()
        c2_in = c2_in.contiguous()

        x1, (h1, c1) = self.lstm1(input, (h1_in, c1_in))
        x1 = self.drop(x1)
        x2, (h2, c2) = self.lstm2(x1, (h2_in, c2_in))
        x2 = self.drop(x2)
    
        mask = self.fc(x2)
        mask = self.sigmoid(mask)

        h = torch.cat((h1, h2), dim=0)
        c = torch.cat((c1, c2), dim=0)
        outStates = torch.stack((h,c), dim=-1)
        return mask, outStates

class DTLN_modelForAec_Stateful(nn.Module):
    def __init__(self, frameLength = 512, 
                hopLength = 256,
                hiddenSizeOfRNN = 128,
                typeOfRNN = 'LSTM',
                numOfConv_Features = 256):
        super(DTLN_modelForAec_Stateful, self).__init__()
        self.frameLength = frameLength
        self.hopLength = hopLength
        self.stftLayer = stftLayer(frameLength=frameLength, hopLength=hopLength)
        self.NormLayerStage1_ref = NormLayer(frameLength//2 + 1)
        self.NormLayerStage1_mic = NormLayer(frameLength//2 + 1)
        self.seperationStage1 = seperationOfLSTMs_Stateful_aec(input_size = 2 * (frameLength//2 + 1),
                                                           hidden_size = hiddenSizeOfRNN,
                                                           dropout = 0.2)
        numOfConvOutChannels = numOfConv_Features
        self.seperationStage2 = seperationOfLSTMs_Stateful_aec(input_size = 2 * numOfConvOutChannels,
                                                            hidden_size = hiddenSizeOfRNN,
                                                            dropout = 0.2)
        
        self.Conv1dOne = nn.Conv1d(in_channels = frameLength,
                                    out_channels = numOfConvOutChannels,
                                    kernel_size = 1,
                                    stride = 1,
                                    bias = False)
        self.Conv1dOne_ref = nn.Conv1d(in_channels = frameLength,
                                        out_channels = numOfConvOutChannels,
                                        kernel_size = 1,
                                        stride = 1,
                                        bias = False)
        
        self.Conv1d_forDecode = nn.Conv1d(in_channels = numOfConvOutChannels,
                                        out_channels = frameLength,
                                        kernel_size = 1,
                                        stride = 1,
                                        bias = False)
        
        self.NormLayerStage2_ref = NormLayer(channels = numOfConvOutChannels)
        self.NormLayerStage2_mic = NormLayer(channels = numOfConvOutChannels)

    def segment(self,audio):
        # split audio into frames according to the frameLength and hopLength
        batchsize, numSample = audio.shape
        numOfFrames = (numSample - (self.frameLength - self.hopLength)) // self.hopLength
        audioSegment = torch.zeros((batchsize, self.frameLength, numOfFrames)).to(audio.device)
        for i in range(batchsize):
            for j in range(numOfFrames):
                audioSegment[i,:,j] = audio[i,j*self.hopLength:j*self.hopLength+self.frameLength]
                
        # audioSegment.shape should be [B, frameLength, numOfFrames], Test the shape
        a,b,c = audioSegment.shape
        assert (a==batchsize) and (b==self.frameLength) and (c==numOfFrames), 'shape wrong!'
        return audioSegment
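
    # Note: a vectorized equivalent of the double loop above (a sketch, not in
    # the original code) avoids per-frame Python iteration, which
    # torch.jit.trace would otherwise unroll into a fixed-length graph:
    #   audioSegment = audio.unfold(1, self.frameLength, self.hopLength).permute(0, 2, 1)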
    
    def forward(self, micAudio, refAudio, inState1, inState2):
        refSegment = self.segment(refAudio) # refSegment.shape [B, frameLength, numOfFrames]
        mag_micAudioStft, phase_micAudioStft = self.stftLayer(micAudio)
        mag_refAudioStft, _ = self.stftLayer(refAudio) # the information of phase in refaudio is not needed.

        mag_micAudioStft = mag_micAudioStft.permute(0,2,1) # [B, F, T] -> [B, T, F]
        mag_refAudioStft = mag_refAudioStft.permute(0,2,1)
        phase_micAudioStft = phase_micAudioStft.permute(0,2,1)

        # Normalize in the F domain.
        mag_micAudioStft_Normed = self.NormLayerStage1_mic(mag_micAudioStft)
        mag_refAudioStft_Normed = self.NormLayerStage1_ref(mag_refAudioStft)


        # Concat the mag_mic and mag_ref in the F domain
        mag_Concated = torch.concatenate((mag_refAudioStft_Normed,mag_micAudioStft_Normed), dim = -1)

        # NetWork predicts the mask.
        mask1, outState1 = self.seperationStage1(mag_Concated, inState1) # mask.shape [B, T, F]
        magEstimated = mask1 * mag_micAudioStft
        out_StftOfStage1 = magEstimated * torch.exp((1j * phase_micAudioStft))
        out_ifft = torch.fft.irfft(out_StftOfStage1, dim=-1) # [N, T, F] -> [N, T, frameLength]
        out_ifft = out_ifft.permute(0, 2, 1) # [N, T, frameLength] -> [N, frameLength, T]

        # encode in frameLength domain
        feaEncoded = self.Conv1dOne(out_ifft) # [N, frameLength, T] -> [N, numOfFeas, T]
        feaEncoded = feaEncoded.permute(0, 2, 1) # [N, numOfFeas, T] -> [N, T, numOfFeas] , needs to multiply output later.
        feaEncoded_Normed = self.NormLayerStage2_mic(feaEncoded)

        # encode the segment of reference signal
        refEncoded = self.Conv1dOne_ref(refSegment) # [N, frameLength, T] -> [N, numOfFeas, T]
        refEncoded = refEncoded.permute(0, 2, 1) # [N, numOfFeas, T] -> [N, T, numOfFeas]
        refEncoded_Normed = self.NormLayerStage2_ref(refEncoded)

        # Concat the refFeatures and magFeatures in the F domain
        Fea_Concated = torch.concatenate((refEncoded_Normed, feaEncoded_Normed), dim = -1)# [N, T, numOfFeas] -> [N, T, 2*numOfFeas]
        mask2, outState2 = self.seperationStage2(Fea_Concated, inState2) # [N, T, 2*numOfFeas] -> [N, T, numOfFeas]
        out = mask2 * feaEncoded

        # decode Feas to time value.
        out = out.permute(0, 2, 1) # [N, T, numOfFeas] -> [N, numOfFeas, T]
        outDecoded = self.Conv1d_forDecode(out) # [N, numOfFeas, T] -> [N, frameLength, T]

        # if use real-time inference, we don't need overlap-add process.
        # Overlap-Add Process
        audioDataEnhanced = F.fold(outDecoded,
                                (micAudio.shape[1], 1),
                                kernel_size = (self.frameLength, 1),
                                padding = (0, 0),
                                stride = (self.hopLength, 1))
        audioDataEnhanced = audioDataEnhanced.reshape(micAudio.shape[0], micAudio.shape[1])
        return audioDataEnhanced, outState1, outState2

Hello author, the above is my model. After initializing it, I export it like this:

model_path = "/home/lizhinan/project/lightse/DTLNPytorch/models/DTLN-aec Trainfinetune_singletalk 0327_stateful/checkpoints/model_0200.pth"
input_1 = torch.randn(1,2048)
input_2 = torch.randn(1,2048)
in_state1 = torch.randn(2,1,512,2)
in_state2 = torch.randn(2,1,512,2)

model = DTLN_modelForAec_Stateful(frameLength=2048,
                                hopLength=512,
                                hiddenSizeOfRNN=512)
model.load_state_dict(torch.load(model_path, map_location='cpu'))
model.eval()

jit_trace_model = torch.jit.trace(model, (input_1, input_2, in_state1, in_state2))
jit_trace_model.save("DTLN_aec.pth")
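
Before running pnnx, a quick sanity check (a minimal sketch, not part of the original report) is to run the traced module once and compare its outputs against the eager model:

with torch.no_grad():
    ref_out = model(input_1, input_2, in_state1, in_state2)
    jit_out = jit_trace_model(input_1, input_2, in_state1, in_state2)
for r, j in zip(ref_out, jit_out):
    assert torch.allclose(r, j, atol=1e-5)  # traced graph should match eager outputs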

I want to go torchscript -> pnnx -> ncnn and finally run inference on mobile, but I got stuck at the very first step. When I run the following command:

pnnx-20220418-ubuntu/pnnx DTLN_aec.pth inputshape1=[1,2048]f32 inputshape2=[1,2048]f32 inputshape3=[2,1,512,2]f32 inputshape4=[2,1,512,2]f32 device=cpu

I get this error:

pnnxparam = DTLN_aec.pnnx.param
pnnxbin = DTLN_aec.pnnx.bin
pnnxpy = DTLN_aec_pnnx.py
ncnnparam = DTLN_aec.ncnn.param
ncnnbin = DTLN_aec.ncnn.bin
ncnnpy = DTLN_aec_ncnn.py
optlevel = 2
device = cpu
inputshape = 
inputshape2 = [1,2048]f32
customop = 
moduleop = 
############# pass_level0
inline module = modules.dtlnModel.NormLayer
inline module = modules.dtlnModel.stftLayer
inline module = modules.dtlnModel_aec_Stateful.seperationOfLSTMs_Stateful_aec
############# pass_level1
unknown Parameter value kind prim::Constant of TensorType, t.dim = 0
unknown Parameter value kind prim::Constant
no attribute value
no attribute value
unknown Parameter value kind prim::Constant
no attribute value
no attribute value
no attribute value
no attribute value
no attribute value
no attribute value
no attribute value
no attribute value
no attribute value
no attribute value
no attribute value
no attribute value
no attribute value
no attribute value
no attribute value
no attribute value
no attribute value
no attribute value
############# pass_level2
############# pass_level3
assign unique operator name pnnx_unique_0 to NormLayerStage1_mic
assign unique operator name pnnx_unique_1 to NormLayerStage1_ref
assign unique operator name pnnx_unique_2 to seperationStage1.drop
assign unique operator name pnnx_unique_3 to NormLayerStage2_mic
assign unique operator name pnnx_unique_4 to NormLayerStage2_ref
assign unique operator name pnnx_unique_5 to seperationStage2.drop
############# pass_level4
############# pass_level5
make_slice_expression input 2
make_slice_expression input 2
make_slice_expression input 3
make_slice_expression input 3
############# pass_ncnn
Segmentation fault (core dumped)

At this point seven files are produced: debug.bin, debug.param, debug2.bin, debug2.param, DTLN_aec.pnnx.py, DTLN_aec.pnnx.bin, and DTLN_aec_pnnx.param. Below is the content of DTLN_aec.pnnx.py:

import os
import numpy as np
import tempfile, zipfile
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

        self.seperationStage1_lstm1 = nn.LSTM(batch_first=True, bias=True, bidirectional=False, hidden_size=512, input_size=2050, num_layers=1)
        self.seperationStage1_lstm2 = nn.LSTM(batch_first=True, bias=True, bidirectional=False, hidden_size=512, input_size=512, num_layers=1)
        self.seperationStage1_fc = nn.Linear(bias=True, in_features=512, out_features=1025)
        self.seperationStage1_sigmoid = nn.Sigmoid()
        self.Conv1dOne = nn.Conv1d(bias=False, dilation=(1,), groups=1, in_channels=2048, kernel_size=(1,), out_channels=256, padding=(0,), padding_mode='zeros', stride=(1,))
        self.Conv1dOne_ref = nn.Conv1d(bias=False, dilation=(1,), groups=1, in_channels=2048, kernel_size=(1,), out_channels=256, padding=(0,), padding_mode='zeros', stride=(1,))
        self.seperationStage2_lstm1 = nn.LSTM(batch_first=True, bias=True, bidirectional=False, hidden_size=512, input_size=512, num_layers=1)
        self.seperationStage2_lstm2 = nn.LSTM(batch_first=True, bias=True, bidirectional=False, hidden_size=512, input_size=512, num_layers=1)
        self.seperationStage2_fc = nn.Linear(bias=True, in_features=512, out_features=256)
        self.seperationStage2_sigmoid = nn.Sigmoid()
        self.Conv1d_forDecode = nn.Conv1d(bias=False, dilation=(1,), groups=1, in_channels=256, kernel_size=(1,), out_channels=2048, padding=(0,), padding_mode='zeros', stride=(1,))

        archive = zipfile.ZipFile('DTLN_aec.pnnx.bin', 'r')
        self.seperationStage1_lstm1.bias_hh_l0 = self.load_pnnx_bin_as_parameter(archive, 'seperationStage1.lstm1.bias_hh_l0', (2048), 'float32')
        self.seperationStage1_lstm1.bias_ih_l0 = self.load_pnnx_bin_as_parameter(archive, 'seperationStage1.lstm1.bias_ih_l0', (2048), 'float32')
        self.seperationStage1_lstm1.weight_hh_l0 = self.load_pnnx_bin_as_parameter(archive, 'seperationStage1.lstm1.weight_hh_l0', (2048,512), 'float32')
        self.seperationStage1_lstm1.weight_ih_l0 = self.load_pnnx_bin_as_parameter(archive, 'seperationStage1.lstm1.weight_ih_l0', (2048,2050), 'float32')
        self.seperationStage1_lstm2.bias_hh_l0 = self.load_pnnx_bin_as_parameter(archive, 'seperationStage1.lstm2.bias_hh_l0', (2048), 'float32')
        self.seperationStage1_lstm2.bias_ih_l0 = self.load_pnnx_bin_as_parameter(archive, 'seperationStage1.lstm2.bias_ih_l0', (2048), 'float32')
        self.seperationStage1_lstm2.weight_hh_l0 = self.load_pnnx_bin_as_parameter(archive, 'seperationStage1.lstm2.weight_hh_l0', (2048,512), 'float32')
        self.seperationStage1_lstm2.weight_ih_l0 = self.load_pnnx_bin_as_parameter(archive, 'seperationStage1.lstm2.weight_ih_l0', (2048,512), 'float32')
        self.seperationStage1_fc.bias = self.load_pnnx_bin_as_parameter(archive, 'seperationStage1.fc.bias', (1025), 'float32')
        self.seperationStage1_fc.weight = self.load_pnnx_bin_as_parameter(archive, 'seperationStage1.fc.weight', (1025,512), 'float32')
        self.Conv1dOne.weight = self.load_pnnx_bin_as_parameter(archive, 'Conv1dOne.weight', (256,2048,1), 'float32')
        self.Conv1dOne_ref.weight = self.load_pnnx_bin_as_parameter(archive, 'Conv1dOne_ref.weight', (256,2048,1), 'float32')
        self.seperationStage2_lstm1.bias_hh_l0 = self.load_pnnx_bin_as_parameter(archive, 'seperationStage2.lstm1.bias_hh_l0', (2048), 'float32')
        self.seperationStage2_lstm1.bias_ih_l0 = self.load_pnnx_bin_as_parameter(archive, 'seperationStage2.lstm1.bias_ih_l0', (2048), 'float32')
        self.seperationStage2_lstm1.weight_hh_l0 = self.load_pnnx_bin_as_parameter(archive, 'seperationStage2.lstm1.weight_hh_l0', (2048,512), 'float32')
        self.seperationStage2_lstm1.weight_ih_l0 = self.load_pnnx_bin_as_parameter(archive, 'seperationStage2.lstm1.weight_ih_l0', (2048,512), 'float32')
        self.seperationStage2_lstm2.bias_hh_l0 = self.load_pnnx_bin_as_parameter(archive, 'seperationStage2.lstm2.bias_hh_l0', (2048), 'float32')
        self.seperationStage2_lstm2.bias_ih_l0 = self.load_pnnx_bin_as_parameter(archive, 'seperationStage2.lstm2.bias_ih_l0', (2048), 'float32')
        self.seperationStage2_lstm2.weight_hh_l0 = self.load_pnnx_bin_as_parameter(archive, 'seperationStage2.lstm2.weight_hh_l0', (2048,512), 'float32')
        self.seperationStage2_lstm2.weight_ih_l0 = self.load_pnnx_bin_as_parameter(archive, 'seperationStage2.lstm2.weight_ih_l0', (2048,512), 'float32')
        self.seperationStage2_fc.bias = self.load_pnnx_bin_as_parameter(archive, 'seperationStage2.fc.bias', (256), 'float32')
        self.seperationStage2_fc.weight = self.load_pnnx_bin_as_parameter(archive, 'seperationStage2.fc.weight', (256,512), 'float32')
        self.Conv1d_forDecode.weight = self.load_pnnx_bin_as_parameter(archive, 'Conv1d_forDecode.weight', (2048,256,1), 'float32')
        self.NormLayerStage1_mic_beta = self.load_pnnx_bin_as_parameter(archive, 'NormLayerStage1_mic.beta', (1,1,1025), 'float32')
        self.pnnx_unique_0_gamma = self.load_pnnx_bin_as_parameter(archive, 'pnnx_unique_0.gamma', (1,1,1025), 'float32')
        self.NormLayerStage1_ref_beta = self.load_pnnx_bin_as_parameter(archive, 'NormLayerStage1_ref.beta', (1,1,1025), 'float32')
        self.pnnx_unique_1_gamma = self.load_pnnx_bin_as_parameter(archive, 'pnnx_unique_1.gamma', (1,1,1025), 'float32')
        self.NormLayerStage2_mic_beta = self.load_pnnx_bin_as_parameter(archive, 'NormLayerStage2_mic.beta', (1,1,256), 'float32')
        self.pnnx_unique_3_gamma = self.load_pnnx_bin_as_parameter(archive, 'pnnx_unique_3.gamma', (1,1,256), 'float32')
        self.NormLayerStage2_ref_beta = self.load_pnnx_bin_as_parameter(archive, 'NormLayerStage2_ref.beta', (1,1,256), 'float32')
        self.pnnx_unique_4_gamma = self.load_pnnx_bin_as_parameter(archive, 'pnnx_unique_4.gamma', (1,1,256), 'float32')
        archive.close()

    def load_pnnx_bin_as_parameter(self, archive, key, shape, dtype, requires_grad=True):
        return nn.Parameter(self.load_pnnx_bin_as_tensor(archive, key, shape, dtype), requires_grad)

    def load_pnnx_bin_as_tensor(self, archive, key, shape, dtype):
        _, tmppath = tempfile.mkstemp()
        tmpf = open(tmppath, 'wb')
        with archive.open(key) as keyfile:
            tmpf.write(keyfile.read())
        tmpf.close()
        m = np.memmap(tmppath, dtype=dtype, mode='r', shape=shape).copy()
        os.remove(tmppath)
        return torch.from_numpy(m)

    def forward(self, v_0, v_1, v_2, v_3):
        v_4 = [int(v_1.size(0)), 2048, int(((v_1.size(1) - 1536) // 512))]
        v_5 = torch.zeros(size=v_4)
        v_6 = True
        v_7 = False
        v_8 = None
        v_9 = 2048
        v_10 = 512
        v_11 = aten::stft(v_0, v_9, v_10, v_9, v_8, v_7, v_8, v_6)
        v_12 = aten::resolve_conj(v_11)
        v_13 = aten::real(v_12)
        v_14 = aten::imag(v_12)
        v_15 = (v_13.pow(2) + v_14.pow(2))
        v_16 = torch.clamp(input=v_15, max=None, min=0.000000)
        v_17 = v_16.pow(5.000000e-01)
        v_18 = (v_14 + 1.192093e-07)
        v_19 = (v_13 + 1.192093e-07)
        v_20 = aten::atan2(v_18, v_19)
        v_21 = aten::stft(v_1, v_9, v_10, v_9, v_8, v_7, v_8, v_6)
        v_22 = aten::resolve_conj(v_21)
        v_23 = aten::real(v_22)
        v_24 = aten::imag(v_22)
        v_25 = (v_23.pow(2) + v_24.pow(2))
        v_26 = torch.clamp(input=v_25, max=None, min=0.000000)
        v_27 = v_26.pow(5.000000e-01)
        v_28 = self.NormLayerStage1_mic_beta
        v_29 = self.pnnx_unique_0_gamma
        v_30 = torch.permute(input=v_17, dims=(0,2,1))
        v_31 = torch.mean(input=v_30, dim=(-1,), keepdim=True)
        v_32 = (v_30 - v_31)
        v_33 = aten::square(v_32)
        v_34 = torch.mean(input=v_33, dim=(-1,), keepdim=True)
        v_35 = ((((v_30 - v_31) / torch.sqrt((v_34 + 1.000000e-07))) * v_29) + v_28)
        v_36 = self.NormLayerStage1_ref_beta
        v_37 = self.pnnx_unique_1_gamma
        v_38 = torch.permute(input=v_27, dims=(0,2,1))
        v_39 = torch.mean(input=v_38, dim=(-1,), keepdim=True)
        v_40 = (v_38 - v_39)
        v_41 = aten::square(v_40)
        v_42 = torch.mean(input=v_41, dim=(-1,), keepdim=True)
        v_43 = ((((v_38 - v_39) / torch.sqrt((v_42 + 1.000000e-07))) * v_37) + v_36)
        v_44 = v_2[:1]
        v_45 = v_44.select(dim=3, index=1)
        v_46 = v_45.contiguous(memory_format=torch.contiguous_format)
        v_47 = v_44.select(dim=3, index=0)
        v_48 = v_47.contiguous(memory_format=torch.contiguous_format)
        v_49 = torch.cat((v_43, v_35), dim=-1)
        v_50, (v_51, v_52) = self.seperationStage1_lstm1(v_49, (v_48, v_46))
        v_53 = v_2[1:]
        v_54 = v_53.select(dim=3, index=1)
        v_55 = v_54.contiguous(memory_format=torch.contiguous_format)
        v_56 = v_53.select(dim=3, index=0)
        v_57 = v_56.contiguous(memory_format=torch.contiguous_format)
        v_58, (v_59, v_60) = self.seperationStage1_lstm2(v_50, (v_57, v_55))
        v_61 = self.seperationStage1_fc(v_58)
        v_62 = self.seperationStage1_sigmoid(v_61)
        v_63 = torch.cat((v_52, v_60), dim=0)
        v_64 = torch.cat((v_51, v_59), dim=0)
        v_65 = torch.stack((v_64, v_63), dim=-1)
        v_66 = torch.permute(input=v_20, dims=(0,2,1))
        v_67 = (v_66 * None)
        v_68 = aten::exp(v_67)
        v_69 = ((v_62 * v_30) * v_68)
        v_70 = -1
        v_71 = aten::fft_irfft(v_69, v_8, v_70, v_8)
        v_72 = torch.permute(input=v_71, dims=(0,2,1))
        v_73 = self.Conv1dOne(v_72)
        v_74 = self.NormLayerStage2_mic_beta
        v_75 = self.pnnx_unique_3_gamma
        v_76 = torch.permute(input=v_73, dims=(0,2,1))
        v_77 = torch.mean(input=v_76, dim=(-1,), keepdim=True)
        v_78 = (v_76 - v_77)
        v_79 = aten::square(v_78)
        v_80 = torch.mean(input=v_79, dim=(-1,), keepdim=True)
        v_81 = ((((v_76 - v_77) / torch.sqrt((v_80 + 1.000000e-07))) * v_75) + v_74)
        v_82 = self.Conv1dOne_ref(v_5)
        v_83 = self.NormLayerStage2_ref_beta
        v_84 = self.pnnx_unique_4_gamma
        v_85 = torch.permute(input=v_82, dims=(0,2,1))
        v_86 = torch.mean(input=v_85, dim=(-1,), keepdim=True)
        v_87 = (v_85 - v_86)
        v_88 = aten::square(v_87)
        v_89 = torch.mean(input=v_88, dim=(-1,), keepdim=True)
        v_90 = ((((v_85 - v_86) / torch.sqrt((v_89 + 1.000000e-07))) * v_84) + v_83)
        v_91 = v_3[:1]
        v_92 = v_91.select(dim=3, index=1)
        v_93 = v_92.contiguous(memory_format=torch.contiguous_format)
        v_94 = v_91.select(dim=3, index=0)
        v_95 = v_94.contiguous(memory_format=torch.contiguous_format)
        v_96 = torch.cat((v_90, v_81), dim=-1)
        v_97, (v_98, v_99) = self.seperationStage2_lstm1(v_96, (v_95, v_93))
        v_100 = v_3[1:]
        v_101 = v_100.select(dim=3, index=1)
        v_102 = v_101.contiguous(memory_format=torch.contiguous_format)
        v_103 = v_100.select(dim=3, index=0)
        v_104 = v_103.contiguous(memory_format=torch.contiguous_format)
        v_105, (v_106, v_107) = self.seperationStage2_lstm2(v_97, (v_104, v_102))
        v_108 = self.seperationStage2_fc(v_105)
        v_109 = self.seperationStage2_sigmoid(v_108)
        v_110 = torch.cat((v_99, v_107), dim=0)
        v_111 = torch.cat((v_98, v_106), dim=0)
        v_112 = torch.stack((v_111, v_110), dim=-1)
        v_113 = (v_109 * v_76)
        v_114 = torch.permute(input=v_113, dims=(0,2,1))
        v_115 = self.Conv1d_forDecode(v_114)
        v_116 = [int(v_0.size(1)), 1]
        v_117 = [2048, 1]
        v_118 = [1, 1]
        v_119 = [0, 0]
        v_120 = [512, 1]
        v_121 = aten::col2im(v_115, v_116, v_117, v_118, v_119, v_120)
        v_122 = [int(v_0.size(0)), int(v_0.size(1))]
        v_123 = v_121.reshape(*v_122)
        v_124 = (v_123, v_65, v_112, )
        return v_124

def export_torchscript():
    net = Model()
    net.eval()

    torch.manual_seed(0)
    v_0 = torch.rand(dtype=null)
    v_1 = torch.rand(dtype=null)
    v_2 = torch.rand(dtype=null)
    v_3 = torch.rand(dtype=null)

    mod = torch.jit.trace(net, (v_0, v_1, v_2, v_3))
    mod.save("DTLN_aec_pnnx.py.pt")

def test_inference():
    net = Model()
    net.eval()

    torch.manual_seed(0)
    v_0 = torch.rand(dtype=null)
    v_1 = torch.rand(dtype=null)
    v_2 = torch.rand(dtype=null)
    v_3 = torch.rand(dtype=null)

    return net(v_0, v_1, v_2, v_3)

I suspect the problem may be the stft layer I defined (a short-time Fourier transform layer), which the framework may not support yet. I also tried exporting to the ONNX intermediate representation and found that ONNX does not support the torch.complex data type either.

I'm a loyal fan of nihui on Zhihu. You once said that pnnx was named so because 'p' comes right after 'o', meaning it would be better than onnx. Could you please take a look at what's going wrong here? Perhaps this case can help improve the framework. Thank you!

@Baiyuetribe
Contributor

pnnx-20220418-ubuntu/pnnx DTLN_aec.pth inputshape1=[1,2048]f32 inputshape2=[1,2048]f32 inputshape3=[2,1,512,2]f32 inputshape4=[2,1,512,2]f32 device=cpu is wrong. When both inputshape and inputshape2 are given, they describe a dynamic input size range; there are no further inputshapeN options.
The corrected pnnx conversion command is:
pnnx DTLN_aec.pth inputshape=[1,2048],[1,2048],[2,1,512,2],[2,1,512,2]

@Plutoisme
Author

I made the change you suggested, but I still get the same error. Is there anything else I can try? Thanks!

lizhinan@ml-d3090:~/project/lightse/DTLNPytorch$ pnnx-20220418-ubuntu/pnnx DTLN_aec.pth inputshape=[1,2048],[1,2048],[2,1,512,2],[2,1,512,2]
pnnxparam = DTLN_aec.pnnx.param
pnnxbin = DTLN_aec.pnnx.bin
pnnxpy = DTLN_aec_pnnx.py
ncnnparam = DTLN_aec.ncnn.param
ncnnbin = DTLN_aec.ncnn.bin
ncnnpy = DTLN_aec_ncnn.py
optlevel = 2
device = cpu
inputshape = [1,2048]f32,[1,2048]f32,[2,1,512,2]f32,[2,1,512,2]f32
inputshape2 = 
customop = 
moduleop = 
############# pass_level0
inline module = modules.dtlnModel.NormLayer
inline module = modules.dtlnModel.stftLayer
inline module = modules.dtlnModel_aec_Stateful.seperationOfLSTMs_Stateful_aec
inline module = modules.dtlnModel.NormLayer
inline module = modules.dtlnModel.stftLayer
inline module = modules.dtlnModel_aec_Stateful.seperationOfLSTMs_Stateful_aec
123  40  37  batchsize.1  numSample.1  39  numOfFrames.1  53  audioSegment.1  70  72  74  77  79  82  219  inp.2  y.2  realPart.2  imagPart.2  233  234  235  236  mag_micAudioStft.1  238  463  239  phase_micAudioStft.1  86  87  inp.1  y.1  realPart.1  imagPart.1  255  256  257  258  259  inputs.1  inputs0.1  phase_micAudioStft.3  260  beta.2  gamma.2  meanOfFrames.2  269  270  varianceOfFrames.2  273  stdOfFrames.2  275  outputs.2  outputs0.2  278  279  beta.4  gamma.4  meanOfFrames.4  288  289  varianceOfFrames.4  292  stdOfFrames.4  294  outputs.4  outputs0.4  297  input.1  309  310  311  h1_in.2  313  314  315  c1_in.2  317  318  319  h2_in.2  321  322  323  c2_in.2  hx.2  hx0.2  hx1.2  hx2.2  330  331  332  333  335  336  337  338  339  340  h.2  c.2  346  117  118  magEstimated.1  124  127  out_StftOfStage1.1  out_ifft.1  input0.1  139  inputs1.1  348  beta.6  gamma.6  meanOfFrames.6  357  358  varianceOfFrames.6  361  stdOfFrames.6  363  outputs.6  outputs0.6  366  147  inputs2.1  367  beta.1  gamma.1  meanOfFrames.1  376  377  varianceOfFrames.1  380  stdOfFrames.1  382  outputs.1  outputs0.1  385  input1.1  397  398  399  h1_in.1  401  402  403  c1_in.1  405  406  407  h2_in.1  409  410  411  c2_in.1  hx.1  hx0.1  hx1.1  hx2.1  418  419  420  421  423  424  425  426  427  428  h.1  c.1  434  163  164  out.1  input2.1  173  176  audioDataEnhanced.1  190  196  203  
----------------

[W BinaryOps.cpp:601] Warning: floor_divide is deprecated, and will be removed in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values.
To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor'). (function operator())
foldable_constant gamma.2
foldable_constant beta.1
foldable_constant gamma.1
foldable_constant beta.6
foldable_constant gamma.6
foldable_constant beta.2
foldable_constant beta.4
foldable_constant gamma.4
############# pass_level1
unknown Parameter value kind prim::Constant of TensorType, t.dim = 0
unknown Parameter value kind prim::Constant
no attribute value
no attribute value
unknown Parameter value kind prim::Constant
no attribute value
no attribute value
no attribute value
no attribute value
no attribute value
no attribute value
no attribute value
no attribute value
no attribute value
no attribute value
no attribute value
no attribute value
no attribute value
no attribute value
no attribute value
no attribute value
no attribute value
no attribute value
############# pass_level2
############# pass_level3
assign unique operator name pnnx_unique_0 to NormLayerStage1_mic
assign unique operator name pnnx_unique_1 to NormLayerStage1_ref
assign unique operator name pnnx_unique_2 to seperationStage1.drop
assign unique operator name pnnx_unique_3 to NormLayerStage2_mic
assign unique operator name pnnx_unique_4 to NormLayerStage2_ref
assign unique operator name pnnx_unique_5 to seperationStage2.drop
############# pass_level4
############# pass_level5
make_slice_expression input 2
make_slice_expression input 2
make_slice_expression input 3
make_slice_expression input 3
############# pass_ncnn
Segmentation fault (core dumped)

@nihui
Member

nihui commented Apr 14, 2023

There are quite a few problems converting this model: several operators are not yet supported, and ncnn currently has no support for complex number types either.

out_StftOfStage1 = magEstimated * torch.exp((1j * phase_micAudioStft))

This line causes the crash; the 1j literal is not recognized.
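
One possible workaround (a sketch, not a fix confirmed by the maintainers) is to avoid the bare 1j literal and carry the real and imaginary parts as separate real tensors:

# mag * exp(1j * phase) == mag*cos(phase) + 1j * mag*sin(phase)
realPart = magEstimated * torch.cos(phase_micAudioStft)
imagPart = magEstimated * torch.sin(phase_micAudioStft)
# torch.fft.irfft still needs a complex input, so rebuild it explicitly...
out_StftOfStage1 = torch.complex(realPart, imagPart)
out_ifft = torch.fft.irfft(out_StftOfStage1, dim=-1)
# ...or, for a fully real-valued graph that ncnn can ingest, replace irfft
# with a fixed nn.Linear holding the inverse-DFT basis (a common trick in
# DTLN deployments).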

@nihui
Member

nihui commented Apr 14, 2023

#4627

@Plutoisme
Author

#4627

Haha, nihui, you're the best! I hope one day to become someone as awesome and lovable as you!

@nihui
Member

nihui commented Aug 5, 2024

Given the various problems with onnx model conversion, it is recommended to use the latest pnnx tool to convert your model to ncnn:

pip install pnnx
pnnx model.onnx inputshape=[1,3,224,224]

Detailed reference documentation:
https://github.com/pnnx/pnnx
https://github.com/Tencent/ncnn/wiki/use-ncnn-with-pytorch-or-onnx#how-to-use-pnnx

@nihui
Member

nihui commented Nov 19, 2024

#5779
