Merge pull request #12 from QVPR/module_memory
Module memory
AdamDHines authored Jan 2, 2024
2 parents 307fa6c + e958cda commit 4ae753c
Showing 16 changed files with 552 additions and 217 deletions.
Binary file removed .DS_Store
26 changes: 15 additions & 11 deletions README.md
@@ -25,7 +25,7 @@ To use VPRTempo, please follow the instructions below for installation and usage
- Quantization Aware Training (QAT) enabled to train weights in int8 space (see the sketch after this list)
- Addition of tutorials in Jupyter Notebooks to learn how to use VPRTempo as well as explain the computational logic
- Simplification of weight operations, reducing to a single weight tensor - allowing positive and negative connections to change sign during training
- Easier dependency installation with PyPI/pip
- Easier dependency installation with PyPI/pip and conda
- And more!
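
The QAT support follows PyTorch's eager-mode quantization workflow, using the same calls that appear in `main.py` below (`get_default_qat_qconfig`, `prepare_qat`, `convert`). A minimal sketch, with hypothetical layer sizes rather than the shipped network:

```python
import torch.nn as nn
import torch.quantization as quantization

# Hypothetical stand-in network; the layer sizes are illustrative only
model = nn.Sequential(nn.Linear(3136, 6272), nn.ReLU(), nn.Linear(6272, 500))
model.train()
model.qconfig = quantization.get_default_qat_qconfig('fbgemm')  # int8 QAT config
quantization.prepare_qat(model, inplace=True)  # insert fake-quantization observers
# ... train as usual; weights learn with simulated int8 rounding ...
model.eval()
quantization.convert(model, inplace=True)  # produce the quantized int8 model
```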

## License & Citation
@@ -64,22 +64,28 @@ If you wish to enable CUDA, please follow the instructions on the [PyTorch - Get
Dependencies can be installed using our provided `requirements.txt` file.

```console
pip3 install -r requirements.txt
pip install -r requirements.txt
```
As above, if you wish to install CUDA please visit [PyTorch - Get Started](https://pytorch.org/get-started/locally/).
### Option 3: Conda install
>**:heavy_exclamation_mark: Recommended:**
> Use [Mambaforge](https://mamba.readthedocs.io/en/latest/installation.html) instead of conda.
> Use [Mambaforge](https://mamba.readthedocs.io/en/latest/installation/mamba-installation.html) instead of conda.
Requirements for VPRTempo may be installed using our [conda-forge package](https://anaconda.org/conda-forge/vprtempo).

```console
# Windows/Linux - CUDA enabled
conda create -n vprtempo -c pytorch -c nvidia python torchvision torchaudio pytorch-cuda=11.7 cudatoolkit prettytable tqdm numpy pandas scikit-learn
# Linux/OS X
conda create -n vprtempo -c conda-forge vprtempo

# Linux CUDA enabled
conda create -n vprtempo -c conda-forge -c pytorch -c nvidia vprtempo pytorch-cuda cudatoolkit

# Windows/Linux - CPU only
conda create -n vprtempo python pytorch torchvision torchaudio cpuonly prettytable tqdm numpy pandas scikit-learn -c pytorch
# Windows
conda create -n vprtempo -c pytorch python pytorch torchvision torchaudio cpuonly prettytable tqdm numpy pandas scikit-learn

# Windows CUDA enabled
conda create -n vprtempo -c pytorch -c nvidia python torchvision torchaudio pytorch-cuda=11.7 cudatoolkit prettytable tqdm numpy pandas scikit-learn

# macOS
conda create -n vprtempo -c conda-forge python prettytable tqdm numpy pandas scikit-learn -c pytorch pytorch::pytorch torchvision torchaudio
```

## Datasets
@@ -142,8 +148,6 @@ python main.py --quantize
<img src="./assets/mainquant_example.gif" alt="Example of the quantized VPRTempo networking running"/>
</p>

#### IDE
You can also run VPRTempo through your IDE by running `main.py`. Change the `bool` flag for `use_quantize` to `True` if you wish to run VPRTempoQuant.

### Train new network
If you do not wish to use the pretrained models, or you would like to train your own, you can pass the `--train_new_model` flag to `main.py`. Note: if a pretrained model already exists, you will be prompted to confirm whether you want to retrain it.
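
For example, a minimal training run with the default dataset settings might look like:

```console
python main.py --train_new_model
```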
5 changes: 5 additions & 0 deletions docs/.gitignore
@@ -0,0 +1,5 @@
_site
.sass-cache
.jekyll-cache
.jekyll-metadata
vendor
180 changes: 145 additions & 35 deletions main.py
@@ -23,79 +23,179 @@
'''
Imports
'''
import os
import sys
import torch
import argparse

import torch.quantization as quantization

from tqdm import tqdm
from vprtempo.VPRTempo import VPRTempo, run_inference
from vprtempo.VPRTempoTrain import VPRTempoTrain, train_new_model
from vprtempo.src.loggers import model_logger, model_logger_quant
from vprtempo.VPRTempoQuant import VPRTempoQuant, run_inference_quant
from vprtempo.VPRTempoQuantTrain import VPRTempoQuantTrain, generate_model_name_quant, train_new_model_quant
from vprtempo.VPRTempoTrain import VPRTempoTrain, generate_model_name, check_pretrained_model, train_new_model
from vprtempo.VPRTempoQuantTrain import VPRTempoQuantTrain, train_new_model_quant

def generate_model_name(model, quant=False):
"""
Generate the model name based on its parameters.
"""
if quant:
model_name = (''.join(model.database_dirs)+"_"+
"VPRTempoQuant_" +
"IN"+str(model.input)+"_" +
"FN"+str(model.feature)+"_" +
"DB"+str(model.database_places) +
".pth")
else:
model_name = (''.join(model.database_dirs)+"_"+
"VPRTempo_" +
"IN"+str(model.input)+"_" +
"FN"+str(model.feature)+"_" +
"DB"+str(model.database_places) +
".pth")
return model_name
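# Example (hypothetical values): database_dirs=['spring', 'fall'], input=3136,
# feature=6272, and database_places=500 would yield
# "springfall_VPRTempo_IN3136_FN6272_DB500.pth"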

def check_pretrained_model(model_name):
"""
Check if a pre-trained model exists and prompt the user to retrain if desired.
"""
if os.path.exists(os.path.join('./vprtempo/models', model_name)):
prompt = "A network with these parameters exists, re-train network? (y/n):\n"
retrain = input(prompt).strip().lower()
if retrain == 'y':
return True
elif retrain == 'n':
print('Training new model cancelled')
sys.exit()

def initialize_and_run_model(args, dims):
"""
Run the VPRTempo/VPRTempoQuant training or inference models.
:param args: Arguments set for the network
:param dims: Dimensions of the network
"""
# Determine number of modules to generate based on user input
places = args.database_places # Copy out number of database places

# Calculate number of modules
num_modules = 1
while places > args.max_module:
places -= args.max_module
num_modules += 1

# If the final module would hold fewer than max_module places, shrink its output layer
remainder = args.database_places % args.max_module
if remainder != 0: # There are remainders, adjust output neuron count in final module
out_dim = int((args.database_places - remainder) / (num_modules - 1))
final_out_dim = remainder
else: # No remainders, all modules are even
out_dim = int(args.database_places / num_modules)
final_out_dim = out_dim
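# Worked example: database_places=1200 and max_module=500 gives num_modules=3,
# remainder=200, so out_dim=(1200-200)/2=500 and final_out_dim=200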

# If user wants to train a new network
if args.train_new_model:
# If using quantization aware training
if args.quantize:
models = []
logger = model_logger_quant()
# Get the quantization config
logger = model_logger_quant() # Initialize the logger
qconfig = quantization.get_default_qat_qconfig('fbgemm')
for _ in range(args.num_modules):
# Initialize the model
model = VPRTempoQuantTrain(args, dims, logger)
# Create the modules
final_out = None
for mod in tqdm(range(num_modules), desc="Initializing modules"):
model = VPRTempoQuantTrain(args, dims, logger, num_modules, out_dim, out_dim_remainder=final_out) # Initialize the model
model.train()
model.qconfig = qconfig
models.append(model)
quantization.prepare_qat(model, inplace=True)
models.append(model) # Create module list
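# On the penultimate iteration, switch the output size so that the
# final module is built with the remainder dimension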
if mod == num_modules - 2:
final_out = final_out_dim
# Generate the model name
model_name = generate_model_name_quant(model)
model_name = generate_model_name(model, args.quantize)
# Check if the model has been trained before
check_pretrained_model(model_name)
# Get the quantization config
qconfig = quantization.get_default_qat_qconfig('fbgemm')
# Train the model
train_new_model_quant(models, model_name, qconfig)
else: # Normal model
train_new_model_quant(models, model_name)

# Base model
else:
models = []
logger = model_logger()
for _ in range(args.num_modules):
# Initialize the model
model = VPRTempoTrain(args, dims, logger)
models.append(model)
logger = model_logger() # Initialize the logger

# Create the modules
final_out = None
for mod in tqdm(range(num_modules), desc="Initializing modules"):
model = VPRTempoTrain(args, dims, logger, num_modules, out_dim, out_dim_remainder=final_out) # Initialize the model
model.to(torch.device('cpu')) # Move module to CPU for storage (necessary for large models)
models.append(model) # Create module list
if mod == num_modules - 2:
final_out = final_out_dim

# Generate the model name
model_name = generate_model_name(model)
print(f"Model name: {model_name}")
# Check if the model has been trained before
check_pretrained_model(model_name)
# Train the model
train_new_model(models, model_name)

# Run the inference network
else:
# Set the quantization configuration
if args.quantize:
models = []
logger = model_logger_quant()
logger, output_folder = model_logger_quant()
qconfig = quantization.get_default_qat_qconfig('fbgemm')
for _ in range(args.num_modules):
final_out = None
for _ in tqdm(range(num_modules), desc="Initializing modules"):
# Initialize the model
model = VPRTempoQuant(dims, args, logger)
model = VPRTempoQuant(
args,
dims,
logger,
num_modules,
output_folder,
out_dim,
out_dim_remainder=final_out
)
model.eval()
model.qconfig = qconfig
model = quantization.prepare(model, inplace=False)
model = quantization.convert(model, inplace=False)
quantization.prepare(model, inplace=True)
quantization.convert(model, inplace=True)
models.append(model)
# Generate the model name
model_name = generate_model_name_quant(model)
model_name = generate_model_name(model, args.quantize)
# Run the quantized inference model
run_inference_quant(models, model_name, qconfig)
run_inference_quant(models, model_name)
else:
models = []
logger = model_logger()
for _ in range(args.num_modules):
# Initialize the model
model = VPRTempo(dims, args, logger)
models.append(model)
logger, output_folder = model_logger() # Initialize the logger
places = args.database_places # Copy out number of database places

# Create the modules
final_out = None
for mod in tqdm(range(num_modules), desc="Initializing modules"):
model = VPRTempo(
args,
dims,
logger,
num_modules,
output_folder,
out_dim,
out_dim_remainder=final_out
)
model.eval()
model.to(torch.device('cpu')) # Move module to CPU for storage (necessary for large models)
models.append(model) # Create module list
if mod == num_modules - 2:
final_out = final_out_dim
# Generate the model name
model_name = generate_model_name(model)
print(f"Model name: {model_name}")
# Run the inference model
run_inference(models, model_name)

@@ -110,19 +210,23 @@ def parse_network(use_quantize=False, train_new_model=False):
help="Dataset to use for training and/or inferencing")
parser.add_argument('--data_dir', type=str, default='./vprtempo/dataset/',
help="Directory where dataset files are stored")
parser.add_argument('--num_places', type=int, default=500,
help="Number of places to use for training and/or inferencing")
parser.add_argument('--num_modules', type=int, default=1,
help="Number of expert modules to use split images into")
parser.add_argument('--database_places', type=int, default=500,
help="Number of places to use for training")
parser.add_argument('--query_places', type=int, default=500,
help="Number of places to use for inferencing")
parser.add_argument('--max_module', type=int, default=500,
help="Maximum number of images per module")
parser.add_argument('--database_dirs', nargs='+', default=['spring', 'fall'],
parser.add_argument('--database_dirs', type=str, default='spring, fall',
help="Directories to use for training")
parser.add_argument('--query_dir', nargs='+', default=['summer'],
parser.add_argument('--query_dir', type=str, default='summer',
help="Directories to use for testing")
parser.add_argument('--shuffle', action='store_true',
help="Shuffle input images during query")
parser.add_argument('--GT_tolerance', type=int, default=1,
help="Ground truth tolerance for matching")

# Define training parameters
parser.add_argument('--filter', type=int, default=8,
parser.add_argument('--filter', type=int, default=1,
help="Images to skip for training and/or inferencing")
parser.add_argument('--epoch', type=int, default=4,
help="Number of epochs to train the model")
Expand All @@ -139,6 +243,12 @@ def parse_network(use_quantize=False, train_new_model=False):
parser.add_argument('--quantize', action='store_true',
help="Enable/disable quantization for the model")

# Define metrics functionality
parser.add_argument('--PR_curve', action='store_true',
help="Flag to generate a Precision-Recall curve")
parser.add_argument('--sim_mat', action='store_true',
help="Flag to plot the similarity matrix, GT, and GTsoft")

# If the function is called with specific arguments, override sys.argv
if use_quantize or train_new_model:
sys.argv = ['']
2 changes: 1 addition & 1 deletion setup.py
@@ -21,7 +21,7 @@
# define the setup
setup(
name="VPRTempo",
version="1.1.4",
version="1.1.5",
description='VPRTempo: A Fast Temporally Encoded Spiking Neural Network for Visual Place Recognition',
long_description=long_description,
long_description_content_type='text/markdown',
Binary file removed tutorials/.DS_Store
Binary file removed tutorials/mats/.DS_Store
