DvdNss/mnist_encoder

About The Project

This project builds a multi-channel auto-encoder for the MNIST dataset using PyTorch. Images (CH1) and labels (CH2) are passed as inputs, and the labels are gradually masked during training so the model learns not only to predict the correct labels but also to generate images from which the correct labels can be predicted. Each channel has its own encoder, a global encoder combines the encoded channels into a single code, and the decoding path mirrors this structure.
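For illustration, here is a minimal sketch of that layout with hypothetical class and layer names; the real mc_autoencoder.py may use different depths and dimensions, and lab_chan_size is an assumed parameter.

import torch
import torch.nn as nn

class MultiChannelAutoEncoder(nn.Module):
    """Hypothetical layout: per-channel encoders, a global encoder, and a mirrored decoder."""

    def __init__(self, img_chan_size: int = 100, lab_chan_size: int = 10, global_chan_size: int = 50):
        super().__init__()
        # Channel encoders: CH1 = image (28x28), CH2 = one-hot label (10 classes)
        self.img_encoder = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, img_chan_size), nn.ReLU())
        self.lab_encoder = nn.Sequential(nn.Linear(10, lab_chan_size), nn.ReLU())
        # Global encoder over the concatenated channel codes
        self.global_encoder = nn.Linear(img_chan_size + lab_chan_size, global_chan_size)
        # Mirrored decoding path: global decoder, then per-channel decoders
        self.global_decoder = nn.Linear(global_chan_size, img_chan_size + lab_chan_size)
        self.img_decoder = nn.Sequential(nn.Linear(img_chan_size, 28 * 28), nn.Sigmoid())
        self.lab_decoder = nn.Linear(lab_chan_size, 10)
        self.img_chan_size = img_chan_size

    def forward(self, image, label_onehot):
        # Encode each channel, merge through the global encoder, then decode back to an image and a label
        merged = torch.cat([self.img_encoder(image), self.lab_encoder(label_onehot)], dim=1)
        decoded = self.global_decoder(self.global_encoder(merged))
        img_code, lab_code = decoded[:, :self.img_chan_size], decoded[:, self.img_chan_size:]
        return self.img_decoder(img_code), self.lab_decoder(lab_code)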

Test it here!

Table of Contents
  1. About The Project
  2. Getting Started
  3. Usage
  4. Contact

Getting Started

Installation

  1. Clone the repo
git clone https://github.com/DvdNss/mnist_encoder
  2. Install requirements
pip install -r requirements.txt
  3. Install the GPU build of PyTorch if needed
pip install torch==1.9.1+cu111 torchvision==0.10.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html
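To confirm the install (and the CUDA build, if you installed it) is picked up, a minimal check along these lines works; the repo's utils/device.py serves the same purpose.

import torch

# Print the installed version and whether a CUDA device is available
print(f'torch {torch.__version__} -- using {"cuda" if torch.cuda.is_available() else "cpu"}')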

Usage

Structure

  • example/: contains inference outputs
  • model/: contains model .pt file
  • resources/: contains repo home image
  • source/:
    • mc_autoencoder.py: model architecture (layers, forward pass...)
    • model.py: model methods (train, eval, save...)
    • train.py: training script
    • inference.py: eval and inference script
    • app.py: project GUI
  • utils/:
    • device.py: quick check for device availability (CPU or GPU); just run device.py

Example

  1. Run the train.py script. Feel free to edit parameters like channel sizes, epochs or learning rate.
import torch.nn
from torchvision.transforms import ToTensor

from source.model import Model

# Load data & dataloader
train_data, test_data, train_dataloader, test_dataloader = Model.load_mnist(transform=ToTensor(), batch_size=1)

# Load model
model = Model(device='cuda', img_chan_size=100, global_chan_size=50)
print(model.model)

# Loss & Optim
loss = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.model.parameters(), lr=1e-3)

# Number of epochs and initial label-masking probability (in %)
epoch = 15
mask_prob = 25

# Training model
for _ in range(0, epoch):
    model.train(dataloader=train_dataloader, loss=loss, optimizer=optimizer, mask_prob=mask_prob, log_iter=60000)
    model.eval(dataloader=test_dataloader, loss=loss)
    model.eval(dataloader=test_dataloader, loss=loss, mask=True)
    # Increase label masking by 25 points each epoch, capped at 100%
    mask_prob += 25 if mask_prob < 100 else 0
    print(f'Mask probability is now {mask_prob}%. ')

# Save model
model.save('model/model.pt')
  2. Run the inference.py script (examples will be stored in example/ as .png)
from torchvision.transforms import ToTensor

from source.model import Model

# Load data & dataloader
train_data, test_data, train_dataloader, test_dataloader = Model.load_mnist(transform=ToTensor(), batch_size=1)

# Load model
model = Model(load_model='model/model.pt', img_chan_size=100, global_chan_size=50)
print(f'Trainable parameters: {sum(p.numel() for p in model.model.parameters())}. ')

# Quick inference
model.infer(eval_data=test_data, random=True)
  3. Example of input, output and label

input: (image omitted)
output: (reconstructed image omitted)
label: 8

  4. Use the model with the GUI
cd source/
streamlit run app.py
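For reference, here is a hypothetical minimal Streamlit front-end wired to the Model API shown above; the actual app.py may differ, and the file paths are assumptions (they presume launching from the repository root, so adjust them if you launch from source/ as above).

import glob

import streamlit as st
from torchvision.transforms import ToTensor

from source.model import Model

st.title('MNIST multi-channel auto-encoder')

# Load the test split and the trained model (same calls as in inference.py above)
_, test_data, _, _ = Model.load_mnist(transform=ToTensor(), batch_size=1)
model = Model(load_model='model/model.pt', img_chan_size=100, global_chan_size=50)

if st.button('Run a random inference'):
    # infer() is expected to write its input/output images to example/ as .png (per the README)
    model.infer(eval_data=test_data, random=True)
    for path in sorted(glob.glob('example/*.png')):
        st.image(path)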

Contact

David NAISSE - @LinkedIn - [email protected]
