-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
e67d6a6
commit 37a25a8
Showing
6 changed files
with
141 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
import torch | ||
from torch.optim import Adam | ||
import matplotlib.pyplot as plt | ||
|
||
net = torch.nn.Linear(3, 2) | ||
optimizer = Adam(net.parameters(), lr=0.01) | ||
# Move the network, target value, and training inputs to the GPU | ||
net.cuda() | ||
target = torch.tensor([[1.0, 1.0]], device='cuda') | ||
log = [] | ||
for _ in range(1000): | ||
y_batch = net(torch.randn(100, 3, device='cuda')) | ||
loss = ((y_batch - target) ** 2).sum(1).mean() | ||
log.append(loss.item()) | ||
net.zero_grad() | ||
loss.backward() | ||
optimizer.step() | ||
|
||
print(f'weight is {net.weight}\n') | ||
print(f'bias is {net.bias}\n') | ||
|
||
plt.ylabel('loss') | ||
plt.xlabel('iteration') | ||
plt.plot(log) | ||
plt.show() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import torch | ||
from torch.optim import Adam | ||
from torch.nn.functional import cross_entropy | ||
from collections import OrderedDict | ||
from torch.nn import Linear, ReLU, Sequential | ||
|
||
def classify_target(x, y): | ||
return (y > (x * 3).sin()).long() | ||
|
||
mlp = torch.nn.Sequential(OrderedDict([ | ||
('layer1', Sequential(Linear(2, 20), ReLU())), | ||
('layer2', Sequential(Linear(20, 20), ReLU())), | ||
('layer3', Sequential(Linear(20, 2))) | ||
])) | ||
|
||
mlp.cuda() | ||
|
||
optimizer = Adam(mlp.parameters(), lr=0.01) | ||
for iteration in range(1024): | ||
in_batch = torch.randn(10000, 2, device='cuda') | ||
target_batch = classify_target(in_batch[:,0], in_batch[:,1]) | ||
out_batch = mlp(in_batch) | ||
loss = cross_entropy(out_batch, target_batch) | ||
if iteration > 0: | ||
mlp.zero_grad() | ||
loss.backward() | ||
optimizer.step() | ||
if iteration == 2 ** iteration.bit_length() - 1: | ||
pred_batch = out_batch.max(1)[1] | ||
accuracy = (pred_batch == target_batch).float().sum() / len(in_batch) | ||
print(f'Iteration {iteration} accuracy: {accuracy}') |
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
import os | ||
import torch | ||
|
||
net = torch.nn.Linear(3, 2) | ||
print(net) | ||
|
||
print(net(torch.tensor([1.0, 0.0, 0.0]))) | ||
|
||
x_batch = torch.tensor([ | ||
[1.0, 0., 0.], | ||
[0., 1.0, 0.], | ||
[0., 0., 1.0], | ||
[0., 0., 0.], | ||
]) | ||
|
||
print(net(x_batch)) | ||
|
||
print("weight is: ", net.weight) | ||
print("bias is: ", net.bias) | ||
|
||
for name, param in net.named_parameters(): | ||
print(f'{name} = {param}\n') | ||
|
||
for k, v in net.state_dict().items(): | ||
print(f'{k}: {v.type()}{tuple(v.shape)}') | ||
|
||
torch.save(net.state_dict(), "linear.pth") | ||
|
||
net.load_state_dict(torch.load("linear.pth")) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import torch | ||
|
||
net = torch.nn.Linear(3, 2) | ||
|
||
x_batch = torch.tensor([ | ||
[1.0, 0., 0.], | ||
[0., 1.0, 0.], | ||
[0., 0., 1.0], | ||
[0., 0., 0.], | ||
]) | ||
|
||
y_batch = net(x_batch) | ||
|
||
loss = ((y_batch - torch.tensor([[1.0, 1.0]])) ** 2).sum(1).mean() | ||
print(f"loss is {loss}") | ||
|
||
loss.backward() | ||
print(f'weight is {net.weight} and grad is:\n{net.weight.grad}\n') | ||
print(f'bias is {net.bias} and grad is:\n{net.bias.grad}\n') | ||
|
||
log = [] | ||
for _ in range(10000): | ||
y_batch = net(x_batch) | ||
loss = ((y_batch - torch.tensor([[1.0, 1.0]])) ** 2).sum(1).mean() | ||
log.append(loss.item()) | ||
net.zero_grad() | ||
loss.backward() | ||
with torch.no_grad(): | ||
for p in net.parameters(): | ||
p[...] -= 0.01 * p.grad | ||
print(f'weight is {net.weight}\n') | ||
print(f'bias is {net.bias}\n') | ||
|
||
import matplotlib.pyplot as plt | ||
|
||
plt.ylabel('loss') | ||
plt.xlabel('iteration') | ||
plt.plot(log) | ||
plt.show() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
import torch | ||
from collections import OrderedDict | ||
from torch.nn import Linear, ReLU, Sequential | ||
|
||
mlp = torch.nn.Sequential(OrderedDict([ | ||
('layer1', Sequential(Linear(2, 20), ReLU())), | ||
('layer2', Sequential(Linear(20, 20), ReLU())), | ||
('layer3', Sequential(Linear(20, 2))) | ||
])) | ||
|
||
print(mlp) | ||
|
||
for n, c in mlp.named_modules(): | ||
print(f'{n or "The whole network"} is a {type(c).__name__}') | ||
|
||
for name, param in mlp.named_parameters(): | ||
print(f'{name} has shape {tuple(param.shape)}') |