This project builds a multi-channel autoencoder for the MNIST dataset using PyTorch. Images (CH1) and labels (CH2) are passed as inputs, and the labels are gradually masked during training so that the model learns not only to predict the correct label but also to generate images that can be classified with the correct label. Each channel has its own encoder, and a global encoder encodes the outputs of the channel encoders; decoding mirrors this, with a global decoder followed by one decoder per channel.
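To make the channel/global layout concrete, here is a minimal PyTorch sketch of a two-channel autoencoder. It is not the code in source/mc_autoencoder.py: the fully connected layers, the ReLU/Sigmoid activations, the label code size of 10 (the training script only exposes img_chan_size and global_chan_size) and the convention that a masked label is an all-zero vector are all assumptions made for illustration.

```python
# Minimal sketch of a two-channel autoencoder (image + label). Hypothetical:
# layer types, sizes and the "masked label = all zeros" convention are assumptions,
# not the structure of source/mc_autoencoder.py.
import torch
import torch.nn.functional as F
from torch import nn


class MultiChannelAE(nn.Module):
    def __init__(self, img_chan_size: int = 100, label_chan_size: int = 10,
                 global_chan_size: int = 50):
        super().__init__()
        # Per-channel encoders: CH1 = flattened 28x28 image, CH2 = one-hot label.
        self.img_enc = nn.Sequential(nn.Linear(28 * 28, img_chan_size), nn.ReLU())
        self.label_enc = nn.Sequential(nn.Linear(10, label_chan_size), nn.ReLU())
        # Global encoder: compresses the concatenated channel codes.
        self.global_enc = nn.Sequential(
            nn.Linear(img_chan_size + label_chan_size, global_chan_size), nn.ReLU())
        # Decoding mirrors encoding: global decoder first, then one decoder per channel.
        self.global_dec = nn.Sequential(
            nn.Linear(global_chan_size, img_chan_size + label_chan_size), nn.ReLU())
        self.img_dec = nn.Sequential(nn.Linear(img_chan_size, 28 * 28), nn.Sigmoid())
        self.label_dec = nn.Linear(label_chan_size, 10)  # raw class scores
        self.sizes = [img_chan_size, label_chan_size]

    def forward(self, img: torch.Tensor, label: torch.Tensor):
        # img: (B, 1, 28, 28) in [0, 1]; label: (B, 10) one-hot, all zeros when masked.
        codes = torch.cat([self.img_enc(img.flatten(1)), self.label_enc(label)], dim=1)
        z = self.global_enc(codes)
        # Split the global decoder output back into one slice per channel.
        h_img, h_label = torch.split(self.global_dec(z), self.sizes, dim=1)
        return self.img_dec(h_img).view(-1, 1, 28, 28), self.label_dec(h_label)


model = MultiChannelAE()
img = torch.rand(4, 1, 28, 28)
label = F.one_hot(torch.tensor([3, 1, 4, 1]), num_classes=10).float()
label[1] = 0.0  # mask the second sample's label
rec_img, rec_label = model(img, label)
print(rec_img.shape, rec_label.shape)  # torch.Size([4, 1, 28, 28]) torch.Size([4, 10])
```

In this sketch both channels are reconstructed from the shared global code, so a zeroed label still produces class scores; that is what would let such a model predict labels once they are fully masked.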
- Clone the repo
git clone https://github.com/DvdNss/mnist_encoder
- Install requirements
pip install -r requirements.txt
- Install the CUDA 11.1 build of PyTorch if you need GPU support
pip install torch==1.9.1+cu111 torchvision==0.10.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html
- example/: contains inference outputs
- model/: contains the model .pt file
- resources/: contains the repo home image
- source/:
  - mc_autoencoder.py: model structure (structure, forward pass...)
  - model.py: model methods (train, eval, save...)
  - train.py: training script
  - inference.py: eval and inference script
  - app.py: project GUI
- utils/:
  - device.py: quick script to check device availability (CPU or GPU; just run device.py)
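The training script below hardcodes device='cuda' when building the model, so it is worth checking what is available first, which is the job of utils/device.py. The function here is a hypothetical stand-in, not the contents of device.py: it reports "cuda" when PyTorch sees a GPU and otherwise falls back to "cpu", including when PyTorch is not installed.

```python
# Hypothetical device check in the spirit of utils/device.py (not its actual code).
def available_device() -> str:
    try:
        import torch
    except ImportError:
        # PyTorch is not installed: only the CPU path makes sense.
        return "cpu"
    # torch.cuda.is_available() is False on CPU-only builds or when no GPU is visible.
    return "cuda" if torch.cuda.is_available() else "cpu"


if __name__ == "__main__":
    print(f"Using device: {available_device()}")
```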
- Run the train.py script. Feel free to edit parameters like channel sizes, epochs or learning rate.
import torch
from torchvision.transforms import ToTensor
from source.model import Model
# Load data & dataloader
train_data, test_data, train_dataloader, test_dataloader = Model.load_mnist(transform=ToTensor(), batch_size=1)
# Load model
model = Model(device='cuda', img_chan_size=100, global_chan_size=50)
print(model.model)
# Loss & Optim
loss = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.model.parameters(), lr=1e-3)
epoch = 15
mask_prob = 25
# Training model
for _ in range(0, epoch):
    # Train with the current mask probability, then evaluate without and with label masking
    model.train(dataloader=train_dataloader, loss=loss, optimizer=optimizer, mask_prob=mask_prob, log_iter=60000)
    model.eval(dataloader=test_dataloader, loss=loss)
    model.eval(dataloader=test_dataloader, loss=loss, mask=True)
    # Raise the mask probability by 25 points per epoch until it reaches 100%
    mask_prob += 25 if mask_prob < 100 else 0
    print(f'Mask probability is now {mask_prob}%. ')
# Save model
model.save('model/model.pt')
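The last lines of the training loop raise mask_prob by 25 points per epoch until it reaches 100, so with epoch = 15 the model trains with 25%, 50% and 75% masking for the first three epochs and with 100% for the remaining twelve. The standard-library sketch below replays that schedule and pairs it with a hypothetical mask_label helper. How the repo actually masks labels is not shown here, so treating mask_prob as a per-sample percentage and a masked label as an all-zero vector are assumptions.

```python
import random
from typing import List


def mask_schedule(epochs: int, start: int = 25, step: int = 25, cap: int = 100) -> List[int]:
    """Mask probability (%) used in each epoch, replaying the update in train.py."""
    probs, p = [], start
    for _ in range(epochs):
        probs.append(p)
        # Same rule as train.py: only increase while the value is below the cap.
        p += step if p < cap else 0
    return probs


def mask_label(one_hot: List[float], mask_prob: int, rng: random.Random) -> List[float]:
    """Hypothetical masking: with probability mask_prob %, hide the label by zeroing it."""
    if rng.random() * 100 < mask_prob:
        return [0.0] * len(one_hot)
    return list(one_hot)


print(mask_schedule(15))
# [25, 50, 75, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100]

rng = random.Random(0)
label = [0.0] * 10
label[7] = 1.0
masked = sum(mask_label(label, 25, rng) == [0.0] * 10 for _ in range(10_000))
print(f"{masked / 100:.1f}% of labels masked at mask_prob=25")
```

Because the update only stops once the value reaches 100, a step that does not divide 75 evenly (for example 30) would overshoot to values above 100.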
- Run the inference.py script (examples will be stored in example/ as .png)
from torchvision.transforms import ToTensor
from source.model import Model
# Load data & dataloader
train_data, test_data, train_dataloader, test_dataloader = Model.load_mnist(transform=ToTensor(), batch_size=1)
# Load model
model = Model(load_model='model/model.pt', img_chan_size=100, global_chan_size=50)
print(f'Trainable parameters: {sum(p.numel() for p in model.model.parameters() if p.requires_grad)}. ')
# Quick inference
model.infer(eval_data=test_data, random=True)
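The parameter count printed by inference.py is a plain sum over weight and bias tensors. For a fully connected layer nn.Linear(n_in, n_out) that is n_in * n_out weights plus n_out biases. The layer list below is a hypothetical layout (image channel 784 to 100, label channel 10 to 10, global 110 to 50, mirrored on the decoding side) chosen only to show the arithmetic; the repo's real layers may differ, so its printed total will not necessarily match.

```python
def linear_params(n_in: int, n_out: int, bias: bool = True) -> int:
    """Parameters of one fully connected layer: weight matrix plus optional bias vector."""
    return n_in * n_out + (n_out if bias else 0)


# Hypothetical layout (not the repo's): img_chan_size=100, global_chan_size=50,
# a label code of size 10, and a decoder that mirrors the encoder.
layers = [
    (784, 100), (10, 10),  # per-channel encoders (image, label)
    (110, 50),             # global encoder on the concatenated codes
    (50, 110),             # global decoder
    (100, 784), (10, 10),  # per-channel decoders (image, label)
]
total = sum(linear_params(n_in, n_out) for n_in, n_out in layers)
print(total)  # 169064
```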
- Examples of input/output/label
- Use the model with the GUI
cd source/
streamlit run app.py
David NAISSE - @LinkedIn