# save_net_flat.py
import torch
from torch.nn import functional as F
import numpy as np


def save_model_weights(model, filename):
    # Flatten every parameter into a 1-D array, in named_parameters() order,
    # and write the concatenation to disk as a raw binary blob.
    weights = []
    for name, param in model.named_parameters():
        weights.append(param.data.cpu().numpy().flatten())
        print(name, param.shape, param.data.cpu().numpy().ravel()[0])

    weights = np.concatenate(weights)
    print('Num weights:', len(weights))
    weights.tofile(filename)

    # Save the model architecture (you may want to adjust this based on your specific model)
    #with open(filename + "_architecture.txt", "w") as f:
    #    for name, param in model.named_parameters():
    #        f.write(f"{name}: {param.shape}\n")


def test_model(model):
    # Debug helper: push a deterministic ramp observation through the policy and
    # print a few sums so the same values can be checked from another runtime.
    model = model.cpu().policy
    batch_size = 16
    obs_window = 11
    obs_window_channels = 4
    obs_flat = 26

    # Deterministic fake observation batch: a ramp taken modulo 16.
    x = torch.arange(
        0, batch_size*(obs_window*obs_window*obs_window_channels + obs_flat)
    ).reshape(batch_size, -1) % 16

    # Split off the map portion and expand it into one-hot tile channels plus
    # three extra channels scaled to [0, 1], matching the CNN input layout.
    cnn_features = x[:, :-obs_flat].view(
        batch_size, obs_window, obs_window, obs_window_channels).long()
    map_features = F.one_hot(cnn_features[:, :, :, 0], 16).permute(0, 3, 1, 2).float()
    extra_map_features = (cnn_features[:, :, :, -3:].float() / 255.0).permute(0, 3, 1, 2)
    cnn_features = torch.cat([map_features, extra_map_features], dim=1)

    # Overwrite with another deterministic ramp and probe the first conv layer.
    cnn = model.policy.cnn
    cnn_features = torch.from_numpy(
        np.arange(batch_size*11*11*19).reshape(
            batch_size, 19, obs_window, obs_window)
    ).float()
    conv1_out = cnn[0](cnn_features)
    #(cnn[0].weight[0] * cnn_features[0, :, :5, :5]).sum() + cnn[0].bias[0]
    breakpoint()

    hidden = model.encoder(x)
    output = model.decoder(hidden)
    atn = output.argmax(dim=1)

    print('Encode weight sum:', model.encoder.weight.sum())
    print('Encode/decode weight and bias sum:',
          model.encoder.weight.sum() + model.encoder.bias.sum()
          + model.decoder.weight.sum() + model.decoder.bias.sum())
    print('X sum:', x.sum())
    print('Hidden sum:', hidden.sum())
    print('Hidden 1-10:', hidden[0, :10])
    print('Output sum:', output.sum())
    print('Atn sum:', atn.sum())
    breakpoint()
    exit(0)


def test_lstm():
    # Debug helper: run one step of a single-layer LSTM with deterministic ramp
    # weights and inputs, then recompute the same step by hand from the gate
    # equations so both results can be compared against a ported implementation.
    batch_size = 16
    input_size = 128
    hidden_size = 128

    input = torch.arange(batch_size*input_size).reshape(1, batch_size, -1).float() / 100000
    state = (
        torch.arange(batch_size*hidden_size).reshape(1, batch_size, -1).float() / 100000,
        torch.arange(batch_size*hidden_size).reshape(1, batch_size, -1).float() / 100000,
    )
    weights_input = torch.arange(4*hidden_size*input_size).reshape(4*hidden_size, -1).float() / 100000
    weights_state = torch.arange(4*hidden_size*hidden_size).reshape(4*hidden_size, -1).float() / 100000
    bias_input = torch.arange(4*hidden_size).float() / 100000
    bias_state = torch.arange(4*hidden_size).float() / 100000

    lstm = torch.nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=1)
    lstm.weight_ih_l0.data = weights_input
    lstm.weight_hh_l0.data = weights_state
    lstm.bias_ih_l0.data = bias_input
    lstm.bias_hh_l0.data = bias_state

    output, new_state = lstm(input, state)

    # Manual recomputation of the same step. PyTorch stores the stacked gate
    # weights in (input, forget, cell, output) order, so the split below matches.
    input = input.squeeze(0)
    h, c = state
    buffer = (
        torch.matmul(input, weights_input.T) + bias_input
        + torch.matmul(h, weights_state.T) + bias_state
    )[0]
    i, f, g, o = torch.split(buffer, hidden_size, dim=1)
    i = torch.sigmoid(i)
    f = torch.sigmoid(f)
    g = torch.tanh(g)
    o = torch.sigmoid(o)
    c = f*c + i*g
    h = o*torch.tanh(c)
    breakpoint()

    print('Output:', output)
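
    # Added sanity check (sketch): the hand-rolled gate math above should
    # reproduce what nn.LSTM computed for this single step; output has shape
    # (seq=1, batch, hidden) and new_state[1] is the updated cell state.
    print('Manual h matches nn.LSTM output:', torch.allclose(h, output, atol=1e-6))
    print('Manual c matches nn.LSTM cell state:', torch.allclose(c, new_state[1], atol=1e-6))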


def test_model_forward(model):
    # Debug helper: forward a mostly-zero batch through the full model.
    data = torch.arange(10*(11*11*4 + 26)) % 16
    data[(11*11*4 + 26):] = 0  # zero everything past the first observation
    data = data.reshape(10, -1).float()
    output = model(data)
    breakpoint()


if __name__ == '__main__':
    #test_lstm()
    model = torch.load('snake.pt', map_location='cpu')
    #test_model_forward(model)
    #test_model(model)
    save_model_weights(model, 'snake_weights.bin')
    print('saved')
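
    # Optional round-trip check (added sketch, assumes float32 parameters):
    # re-read the flat file and confirm the element count matches the model.
    reloaded = np.fromfile('snake_weights.bin', dtype=np.float32)
    print('Round-trip element count matches:',
          len(reloaded) == sum(p.numel() for p in model.parameters()))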