Skip to content

Commit

Permalink
add reshape
Browse files Browse the repository at this point in the history
  • Loading branch information
msnh2012 committed Aug 31, 2020
1 parent 07bfba4 commit 1bdf54b
Showing 1 changed file with 105 additions and 2 deletions.
107 changes: 105 additions & 2 deletions tools/pytorch2Msnhnet/PytorchToMsnhnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,6 +718,109 @@ def _view(inData, *args):
ccc.append(x)
dataSize = inData.shape[1]*inData.shape[2]*inData.shape[3]

if inData.shape[0] != 1:
raise NotImplementedError("params error")

if len(list(args)) == 1:
if args[0] != -1:
raise NotImplementedError("params error")
msnhnet.buildView(str(x._cdata),1,1,dataSize)

if len(list(args)) == 2:
if args[0] == -1 and args[1] != -1:
if dataSize % args[1] != 0:
raise NotImplementedError("params error")
dim1 = dataSize/args[1]
dim2 = args[1]
msnhnet.buildView(str(x._cdata),1,dim1,dim2)
elif args[0] != -1 and args[1] == -1:
if dataSize % args[1] != 0:
raise NotImplementedError("params error")
dim1 = args[0]
dim2 = dataSize/args[0]
msnhnet.buildView(str(x._cdata),1,dim1,dim2)
elif args[0] != -1 and args[1] != -1:
if dataSize % (args[1]*args[0]) != 0:
raise NotImplementedError("params error")
dim1 = arg[0]
dim2 = arg[1]
msnhnet.buildView(str(x._cdata),1,dim1,dim2)
else:
raise NotImplementedError("params error")
if len(list(args)) == 3:
if args[0] == -1 and args[1] != -1 and args[2] != -1:
if dataSize % (args[1]*args[2]) != 0:
raise NotImplementedError("params error")
dim0 = dataSize /(args[1]*args[2])
dim1 = args[1]
dim2 = args[2]
msnhnet.buildView(str(x._cdata),dim0,dim1,dim2)
elif args[0] != -1 and args[1] == -1 and args[2] != -1:
if dataSize % (args[0]*args[2]) != 0:
raise NotImplementedError("params error")
dim0 = args[0]
dim1 = dataSize/(args[0]*args[2])
dim2 = args[2]
msnhnet.buildView(str(x._cdata),dim0,dim1,dim2)
elif args[0] != -1 and args[1] != -1 and args[2] == -1:
if dataSize % (args[0]*args[1]) != 0:
raise NotImplementedError("params error")
dim0 = args[0]
dim1 = args[1]
dim2 = dataSize/(args[0]*args[1])
msnhnet.buildView(str(x._cdata),dim0,dim1,dim2)
elif args[0] != -1 and args[1] != -1 and args[2] != -1:
if dataSize / (args[0]*args[1]*args[2]) != 1:
raise NotImplementedError("params error")
dim0 = args[0]
dim1 = args[1]
dim2 = args[2]
msnhnet.buildView(str(x._cdata),dim0,dim1,dim2)
if len(list(args)) == 4:
if args[0] == -1:
if dataSize/(args[1]*args[2]*args[3])==1 :
dim0 = args[1]
dim1 = args[2]
dim2 = args[3]
msnhnet.buildView(str(x._cdata),dim0,dim1,dim2)
else:
raise NotImplementedError("params error")
elif args[0] == 1:
if args[1] == -1 and args[2] != -1 and args[3] != -1:
if dataSize % (args[1]*args[2]) != 0:
raise NotImplementedError("params error")
dim0 = dataSize /(args[2]*args[3])
dim1 = args[2]
dim2 = args[3]
msnhnet.buildView(str(x._cdata),dim0,dim1,dim2)
elif args[1] != -1 and args[2] == -1 and args[3] != -1:
if dataSize % (args[1]*args[3]) != 0:
raise NotImplementedError("params error")
dim0 = args[1]
dim1 = dataSize/(args[1]*args[3])
dim2 = args[3]
msnhnet.buildView(str(x._cdata),dim0,dim1,dim2)
elif args[1] != -1 and args[2] != -1 and args[3] == -1:
if dataSize % (args[1]*args[2]) != 0:
raise NotImplementedError("params error")
dim0 = args[1]
dim1 = args[2]
dim2 = dataSize/(args[1]*args[2])
msnhnet.buildView(str(x._cdata),dim0,dim1,dim2)
elif args[1] != -1 and args[2] != -1 and args[3] != -1:
if dataSize / (args[1]*args[2]*args[3]) != 1:
raise NotImplementedError("params error")
dim0 = args[1]
dim1 = args[2]
dim2 = args[3]
msnhnet.buildView(str(x._cdata),dim0,dim1,dim2)
return x

def _reshape(inData, *args):
x=raw_reshape(inData, *args)
ccc.append(x)
dataSize = inData.shape[1]*inData.shape[2]*inData.shape[3]

if inData.shape[0] != 1:
raise NotImplementedError("params error")

Expand Down Expand Up @@ -869,6 +972,8 @@ def _expand_as(inData, *args):
for t in [torch.Tensor]:
raw_view = t.view
t.view = _view
raw_reshape = t.reshape
t.reshape = _reshape
raw_mean = t.mean
t.mean = _mean
raw__add__ = t.__add__
Expand Down Expand Up @@ -898,8 +1003,6 @@ def _expand_as(inData, *args):
raw__expand_as__ = t.expand_as
t.expand_as = _expand_as



def trans(net, inputVar, msnhnet_path, msnhbin_path):
Hook.hookInited = True
msnhnet.buildConfig(str(id(inputVar)), inputVar.size())
Expand Down

0 comments on commit 1bdf54b

Please sign in to comment.