diff --git a/.DS_Store b/.DS_Store
deleted file mode 100644
index b9183bc..0000000
Binary files a/.DS_Store and /dev/null differ
diff --git a/README.md b/README.md
index a22dcc7..3b13f4e 100644
--- a/README.md
+++ b/README.md
@@ -25,7 +25,7 @@ To use VPRTempo, please follow the instructions below for installation and usage
 - Quantization Aware Training (QAT) enabled to train weights in int8 space
 - Addition of tutorials in Jupyter Notebooks to learn how to use VPRTempo as well as explain the computational logic
 - Simplification of weight operations, reducing to a single weight tensor - allowing positive and negative connections to change sign during training
- - Easier dependency installation with PyPi/pip
+ - Easier dependency installation with PyPI/pip and conda
 - And more!
 
 ## License & Citation
@@ -64,22 +64,28 @@ If you wish to enable CUDA, please follow the instructions on the [PyTorch - Get
 
 Dependencies can be installed through our provided `requirements.txt` file.
 
 ```python
-pip3 install -r requirements.txt
+pip install -r requirements.txt
 ```
 
 As above, if you wish to install CUDA please visit [PyTorch - Get Started](https://pytorch.org/get-started/locally/).
 
 ### Option 3: Conda install
 >**:heavy_exclamation_mark: Recommended:**
-> Use [Mambaforge](https://mamba.readthedocs.io/en/latest/installation.html) instead of conda.
+> Use [Mambaforge](https://mamba.readthedocs.io/en/latest/installation/mamba-installation.html) instead of conda.
+
+Requirements for VPRTempo may be installed using our [conda-forge package](https://anaconda.org/conda-forge/vprtempo).
 
 ```console
-# Windows/Linux - CUDA enabled
-conda create -n vprtempo -c pytorch -c nvidia python torchvision torchaudio pytorch-cuda=11.7 cudatoolkit prettytable tqdm numpy pandas scikit-learn
+# Linux/OS X
+conda create -n vprtempo -c conda-forge vprtempo
+
+# Linux CUDA enabled
+conda create -n vprtempo -c conda-forge -c pytorch -c nvidia vprtempo pytorch-cuda cudatoolkit
 
-# Windows/Linux - CPU only
-conda create -n vprtempo python pytorch torchvision torchaudio cpuonly prettytable tqdm numpy pandas scikit-learn -c pytorch
+# Windows
+conda create -n vprtempo -c pytorch python pytorch torchvision torchaudio cpuonly prettytable tqdm numpy pandas scikit-learn
+
+# Windows CUDA enabled
+conda create -n vprtempo -c pytorch -c nvidia python torchvision torchaudio pytorch-cuda=11.7 cudatoolkit prettytable tqdm numpy pandas scikit-learn
-
-# MacOS
-conda create -n vprtempo -c conda-forge python prettytable tqdm numpy pandas scikit-learn -c pytorch pytorch::pytorch torchvision torchaudio
 ```
 
 ## Datasets
@@ -142,8 +148,6 @@ python main.py --quantize
 
 Example of the quantized VPRTempo network running
 
-#### IDE
-You can also run VPRTempo through your IDE by running `main.py`. Change the `bool` flag for `use_quantize` to `True` if you wish to run VPRTempoQuant.
 ### Train new network
 If you do not wish to use the pretrained models or you would like to train your own, you can pass the `--train_new_model` flag to `main.py`. Note: if a pretrained model already exists, you will be prompted to retrain it.
diff --git a/docs/.gitignore b/docs/.gitignore
new file mode 100644
index 0000000..f40fbd8
--- /dev/null
+++ b/docs/.gitignore
@@ -0,0 +1,5 @@
+_site
+.sass-cache
+.jekyll-cache
+.jekyll-metadata
+vendor
diff --git a/main.py b/main.py
index 753163e..5300ae5 100644
--- a/main.py
+++ b/main.py
@@ -23,79 +23,179 @@
 '''
 Imports
 '''
+import os
 import sys
+import torch
 import argparse
 import torch.quantization as quantization
+from tqdm import tqdm
 
 from vprtempo.VPRTempo import VPRTempo, run_inference
+from vprtempo.VPRTempoTrain import VPRTempoTrain, train_new_model
 from vprtempo.src.loggers import model_logger, model_logger_quant
 from vprtempo.VPRTempoQuant import VPRTempoQuant, run_inference_quant
-from vprtempo.VPRTempoQuantTrain import VPRTempoQuantTrain, generate_model_name_quant, train_new_model_quant
-from vprtempo.VPRTempoTrain import VPRTempoTrain, generate_model_name, check_pretrained_model, train_new_model
+from vprtempo.VPRTempoQuantTrain import VPRTempoQuantTrain, train_new_model_quant
+
+def generate_model_name(model, quant=False):
+    """
+    Generate the model name based on its parameters.
+    """
+    if quant:
+        model_name = (''.join(model.database_dirs)+"_"+
+                      "VPRTempoQuant_" +
+                      "IN"+str(model.input)+"_" +
+                      "FN"+str(model.feature)+"_" +
+                      "DB"+str(model.database_places) +
+                      ".pth")
+    else:
+        model_name = (''.join(model.database_dirs)+"_"+
+                      "VPRTempo_" +
+                      "IN"+str(model.input)+"_" +
+                      "FN"+str(model.feature)+"_" +
+                      "DB"+str(model.database_places) +
+                      ".pth")
+    return model_name
+
+def check_pretrained_model(model_name):
+    """
+    Check if a pre-trained model exists and prompt the user to retrain if desired.
+    """
+    if os.path.exists(os.path.join('./vprtempo/models', model_name)):
+        prompt = "A network with these parameters exists, re-train network? (y/n):\n"
+        retrain = input(prompt).strip().lower()
+        if retrain == 'y':
+            return True
+        elif retrain == 'n':
+            print('Training new model cancelled')
+            sys.exit()
 
 def initialize_and_run_model(args,dims):
+    """
+    Run the VPRTempo/VPRTempoQuant training or inference models.
+
+    :param args: Arguments set for the network
+    :param dims: Dimensions of the network
+    """
+    # Determine number of modules to generate based on user input
+    places = args.database_places # Copy out number of database places
+
+    # Calculate number of modules
+    num_modules = 1
+    while places > args.max_module:
+        places -= args.max_module
+        num_modules += 1
+
+    # If the final module has fewer than max_module places, reduce the dim of the output layer
+    remainder = args.database_places % args.max_module
+    if remainder != 0: # There are remainders, adjust output neuron count in final module
+        out_dim = int((args.database_places - remainder) / (num_modules - 1))
+        final_out_dim = remainder
+    else: # No remainders, all modules are even
+        out_dim = int(args.database_places / num_modules)
+        final_out_dim = out_dim
+
     # If user wants to train a new network
     if args.train_new_model:
         # If using quantization aware training
         if args.quantize:
             models = []
-            logger = model_logger_quant()
-            # Get the quantization config
+            logger = model_logger_quant() # Initialize the logger
             qconfig = quantization.get_default_qat_qconfig('fbgemm')
-            for _ in range(args.num_modules):
-                # Initialize the model
-                model = VPRTempoQuantTrain(args, dims, logger)
+            # Create the modules
+            final_out = None
+            for mod in tqdm(range(num_modules), desc="Initializing modules"):
+                model = VPRTempoQuantTrain(args, dims, logger, num_modules, out_dim, out_dim_remainder=final_out) # Initialize the model
                 model.train()
                 model.qconfig = qconfig
-                models.append(model)
+                quantization.prepare_qat(model, inplace=True)
+                models.append(model) # Create module list
+                if mod == num_modules - 2:
+                    final_out = final_out_dim
             # Generate the model name
-            model_name = generate_model_name_quant(model)
+            model_name = generate_model_name(model, args.quantize)
             # Check if the model has been trained before
             check_pretrained_model(model_name)
             # Train the model
-            train_new_model_quant(models, model_name, qconfig)
-        else: # Normal model
+            train_new_model_quant(models, model_name)
+
+        # Base model
+        else:
             models = []
-            logger = model_logger()
-            for _ in range(args.num_modules):
-                # Initialize the model
-                model = VPRTempoTrain(args, dims, logger)
-                models.append(model)
+            logger = model_logger() # Initialize the logger
+
+            # Create the modules
+            final_out = None
+            for mod in tqdm(range(num_modules), desc="Initializing modules"):
+                model = VPRTempoTrain(args, dims, logger, num_modules, out_dim, out_dim_remainder=final_out) # Initialize the model
+                model.to(torch.device('cpu')) # Move module to CPU for storage (necessary for large models)
+                models.append(model) # Create module list
+                if mod == num_modules - 2:
+                    final_out = final_out_dim
+
             # Generate the model name
             model_name = generate_model_name(model)
+            print(f"Model name: {model_name}")
             # Check if the model has been trained before
             check_pretrained_model(model_name)
             # Train the model
             train_new_model(models, model_name)
 
+    # Run the inference network
     else:
         # Set the quantization configuration
         if args.quantize:
             models = []
-            logger = model_logger_quant()
+            logger, output_folder = model_logger_quant()
             qconfig = quantization.get_default_qat_qconfig('fbgemm')
-            for _ in range(args.num_modules):
+            final_out = None
+            for _ in tqdm(range(num_modules), desc="Initializing modules"):
                 # Initialize the model
-                model = VPRTempoQuant(dims, args, logger)
+                model = VPRTempoQuant(
+                    args,
+                    dims,
+                    logger,
+                    num_modules,
+                    output_folder,
+                    out_dim,
+                    out_dim_remainder=final_out
+                )
                 model.eval()
                 model.qconfig = qconfig
-                model = quantization.prepare(model, inplace=False)
-                model = quantization.convert(model, inplace=False)
+                quantization.prepare(model, inplace=True)
+                quantization.convert(model, inplace=True)
                 models.append(model)
             # Generate the model name
-            model_name = generate_model_name_quant(model)
+            model_name = generate_model_name(model, args.quantize)
             # Run the quantized inference model
-            run_inference_quant(models, model_name, qconfig)
+            run_inference_quant(models, model_name)
         else:
             models = []
-            logger = model_logger()
-            for _ in range(args.num_modules):
-                # Initialize the model
-                model = VPRTempo(dims, args, logger)
-                models.append(model)
+            logger, output_folder = model_logger() # Initialize the logger
+            places = args.database_places # Copy out number of database places
+
+            # Create the modules
+            final_out = None
+            for mod in tqdm(range(num_modules), desc="Initializing modules"):
+                model = VPRTempo(
+                    args,
+                    dims,
+                    logger,
+                    num_modules,
+                    output_folder,
+                    out_dim,
+                    out_dim_remainder=final_out
+                )
+                model.eval()
+                model.to(torch.device('cpu')) # Move module to CPU for storage (necessary for large models)
+                models.append(model) # Create module list
+                if mod == num_modules - 2:
+                    final_out = final_out_dim
             # Generate the model name
             model_name = generate_model_name(model)
+            print(f"Model name: {model_name}")
             # Run the inference model
             run_inference(models, model_name)
@@ -110,19 +210,23 @@ def parse_network(use_quantize=False, train_new_model=False):
                             help="Dataset to use for training and/or inferencing")
     parser.add_argument('--data_dir', type=str, default='./vprtempo/dataset/',
                             help="Directory where dataset files are stored")
-    parser.add_argument('--num_places', type=int, default=500,
-                            help="Number of places to use for training and/or inferencing")
-    parser.add_argument('--num_modules', type=int, default=1,
-                            help="Number of expert modules to use split images into")
+    parser.add_argument('--database_places', type=int, default=500,
+                            help="Number of places to use for training")
+    parser.add_argument('--query_places', type=int, default=500,
+                            help="Number of places to use for inferencing")
     parser.add_argument('--max_module', type=int, default=500,
                             help="Maximum number of images per module")
-    parser.add_argument('--database_dirs', nargs='+', default=['spring', 'fall'],
+    parser.add_argument('--database_dirs', type=str, default='spring, fall',
                             help="Directories to use for training")
-    parser.add_argument('--query_dir', nargs='+', default=['summer'],
+    parser.add_argument('--query_dir', type=str, default='summer',
                             help="Directories to use for testing")
+    parser.add_argument('--shuffle', action='store_true',
+                            help="Shuffle input images during query")
+    parser.add_argument('--GT_tolerance', type=int, default=1,
+                            help="Ground truth tolerance for matching")
 
     # Define training parameters
-    parser.add_argument('--filter', type=int, default=8,
+    parser.add_argument('--filter', type=int, default=1,
                             help="Images to skip for training and/or inferencing")
     parser.add_argument('--epoch', type=int, default=4,
                             help="Number of epochs to train the model")
@@ -139,6 +243,12 @@ def parse_network(use_quantize=False, train_new_model=False):
     parser.add_argument('--quantize', action='store_true',
                             help="Enable/disable quantization for the model")
 
+    # Define metrics functionality
+    parser.add_argument('--PR_curve', action='store_true',
+                            help="Flag to generate a Precision-Recall curve")
+    parser.add_argument('--sim_mat', action='store_true',
+                            help="Flag to plot the similarity matrix, GT, and GTsoft")
+
     # If the function is called with specific arguments, override sys.argv
     if use_quantize or train_new_model:
         sys.argv = ['']
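Reviewer note: the module-splitting bookkeeping added to `initialize_and_run_model` above is easiest to verify in isolation. Below is a minimal sketch of the same arithmetic; the function name and example numbers are mine, not part of the patch.

```python
# Sketch of the module-splitting arithmetic from initialize_and_run_model.
# A database of `database_places` is divided into modules holding at most
# `max_module` places; an uneven split yields a smaller final module.

def split_modules(database_places: int, max_module: int):
    places = database_places
    num_modules = 1
    while places > max_module:          # count how many modules are needed
        places -= max_module
        num_modules += 1

    remainder = database_places % max_module
    if remainder != 0:                  # uneven split: shrink the last module
        out_dim = (database_places - remainder) // (num_modules - 1)
        final_out_dim = remainder
    else:                               # even split: all modules identical
        out_dim = database_places // num_modules
        final_out_dim = out_dim
    return num_modules, out_dim, final_out_dim

print(split_modules(1300, 500))  # (3, 500, 300): two 500-place modules, one 300-place module
print(split_modules(1000, 500))  # (2, 500, 500): an even split
```

This also explains why `final_out` is only assigned once the loop index reaches `num_modules - 2`: every module except the last is built with `out_dim`, and only the final module receives `out_dim_remainder`.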
diff --git a/setup.py b/setup.py
index 112ab1f..c63fda9 100644
--- a/setup.py
+++ b/setup.py
@@ -21,7 +21,7 @@
 # define the setup
 setup(
     name="VPRTempo",
-    version="1.1.4",
+    version="1.1.5",
     description='VPRTempo: A Fast Temporally Encoded Spiking Neural Network for Visual Place Recognition',
     long_description=long_description,
     long_description_content_type='text/markdown',
diff --git a/tutorials/.DS_Store b/tutorials/.DS_Store
deleted file mode 100644
index 6cbab19..0000000
Binary files a/tutorials/.DS_Store and /dev/null differ
diff --git a/tutorials/mats/.DS_Store b/tutorials/mats/.DS_Store
deleted file mode 100644
index b2b0f53..0000000
Binary files a/tutorials/mats/.DS_Store and /dev/null differ
diff --git a/vprtempo/VPRTempo.py b/vprtempo/VPRTempo.py
index 6a0a2f0..691ffdf 100644
--- a/vprtempo/VPRTempo.py
+++ b/vprtempo/VPRTempo.py
@@ -25,45 +25,61 @@
 '''
 
 import os
+import json
 import torch
+import random
 import numpy as np
 import torch.nn as nn
+import matplotlib.pyplot as plt
 import vprtempo.src.blitnet as bn
 
 from tqdm import tqdm
 from prettytable import PrettyTable
-from torch.utils.data import DataLoader
-from vprtempo.src.metrics import recallAtK
+from torch.utils.data import DataLoader, Subset
+from vprtempo.src.metrics import recallAtK, createPR
 from vprtempo.src.dataset import CustomImageDataset, ProcessImage
 
 class VPRTempo(nn.Module):
-    def __init__(self, dims, args=None, logger=None):
+    def __init__(self, args, dims, logger, num_modules, output_folder, out_dim, out_dim_remainder=None):
         super(VPRTempo, self).__init__()
 
-        # Set the arguments
+        # Set the args
         if args is not None:
             self.args = args
             for arg in vars(args):
                 setattr(self, arg, getattr(args, arg))
         setattr(self, 'dims', dims)
+
+        # Set the device
         if torch.cuda.is_available():
             self.device = "cuda:0"
         else:
             self.device = "cpu"
 
+        # Set input args
         self.logger = logger
+        self.num_modules = num_modules
+        self.output_folder = output_folder
 
         # Set the dataset file
         self.dataset_file = os.path.join('./vprtempo/dataset', self.dataset + '.csv')
+        self.query_dir = [dir.strip() for dir in self.query_dir.split(',')]
 
         # Layer dict to keep track of layer names and their order
         self.layer_dict = {}
         self.layer_counter = 0
+        self.database_dirs = [dir.strip() for dir in self.database_dirs.split(',')]
 
         # Define layer architecture
         self.input = int(self.dims[0]*self.dims[1])
         self.feature = int(self.input * 2)
-        self.output = int(args.num_places / args.num_modules)
+
+        # Output dimension changes for final module if not an even distribution of places
+        if out_dim_remainder is not None:
+            self.output = out_dim_remainder
+        else:
+            self.output = out_dim
 
         """
         Define trainable layers here
@@ -100,15 +116,15 @@ def add_layer(self, name, **kwargs):
         self.layer_dict[name] = self.layer_counter
         self.layer_counter += 1
 
-    def evaluate(self, models, test_loader, layers=None):
+    def evaluate(self, models, test_loader):
         """
         Run the inferencing model and calculate the accuracy.
 
+        :param models: Models to run inference on, each model is a VPRTempo module
         :param test_loader: Testing data loader
-        :param layers: Layers to pass data through
         """
         # Initialize the tqdm progress bar
-        pbar = tqdm(total=self.num_places,
+        pbar = tqdm(total=self.query_places,
                     desc="Running the test network",
                     position=0)
         self.inferences = []
@@ -121,12 +137,16 @@
                 nn.Hardtanh(0, 0.9),
                 nn.ReLU()
             ))
+            self.inferences[-1].to(torch.device(self.device))
 
         # Initialize the output spikes variable
         out = []
+        labels = []
+
         # Run inference for the specified number of timesteps
-        for spikes, labels in test_loader:
+        for spikes, label in test_loader:
             # Set device
-            spikes, labels = spikes.to(self.device), labels.to(self.device)
+            spikes = spikes.to(self.device)
+            labels.append(label.detach().cpu().item())
             # Forward pass
             spikes = self.forward(spikes)
             # Add output spikes to list
@@ -135,25 +155,83 @@
 
         # Close the tqdm progress bar
         pbar.close()
 
-        # Rehsape output spikes into a similarity matrix
-        out = np.reshape(np.array(out),(model.num_places,model.num_places))
+        out = np.reshape(np.array(out),(model.query_places,model.database_places))
+
+        # Create GT matrix
+        GT = np.zeros((model.query_places,model.database_places), dtype=int)
+        for n, ndx in enumerate(labels):
+            if model.filter != 1:
+                ndx = ndx//model.filter
+            GT[n,ndx] = 1
+
+        # Create GT soft matrix
+        if model.GT_tolerance > 0:
+            GTsoft = np.zeros((model.query_places,model.database_places), dtype=int)
+            for n, ndx in enumerate(labels):
+                if model.filter != 1:
+                    ndx = ndx//model.filter
+                GTsoft[n, ndx] = 1
+                # Apply tolerance
+                for i in range(max(0, n - model.GT_tolerance), min(model.query_places, n + model.GT_tolerance + 1)):
+                    GTsoft[i, ndx] = 1
+        else:
+            GTsoft = None
+
+        # If user specified, generate a PR curve
+        if model.PR_curve:
+            # Create PR curve
+            P, R = createPR(out, GThard=GT, GTsoft=GTsoft, matching='single', n_thresh=100)
+            # Combine P and R into a list of lists
+            PR_data = {
+                    "Precision": P,
+                    "Recall": R
+                }
+            output_file = "PR_curve_data.json"
+            # Construct the full path
+            full_path = f"{model.output_folder}/{output_file}"
+            # Write the data to a JSON file
+            with open(full_path, 'w') as file:
+                json.dump(PR_data, file)
+            # Plot PR curve
+            plt.plot(R,P)
+            plt.xlabel('Recall')
+            plt.ylabel('Precision')
+            plt.title('Precision-Recall Curve')
+            plt.show()
+
+        if model.sim_mat:
+            # Create a figure and a set of subplots
+            fig, axs = plt.subplots(1, 3, figsize=(15, 5))
+
+            # Plot each matrix using matshow
+            cax1 = axs[0].matshow(out, cmap='viridis')
+            fig.colorbar(cax1, ax=axs[0], shrink=0.8)
+            axs[0].set_title('Similarity matrix')
+
+            cax2 = axs[1].matshow(GT, cmap='plasma')
+            fig.colorbar(cax2, ax=axs[1], shrink=0.8)
+            axs[1].set_title('GT')
+
+            cax3 = axs[2].matshow(GTsoft, cmap='inferno')
+            fig.colorbar(cax3, ax=axs[2], shrink=0.8)
+            axs[2].set_title('GT-soft')
+
+            # Adjust layout
+            plt.tight_layout()
+            plt.show()
 
         # Recall@N
         N = [1,5,10,15,20,25] # N values to calculate
         R = [] # Recall@N values
-        # Create GT matrix
-        GT = np.zeros((model.num_places,model.num_places), dtype=int)
-        for n in range(len(GT)):
-            GT[n,n] = 1
         # Calculate Recall@N
         for n in N:
-            R.append(round(recallAtK(out,GThard=GT,K=n),2))
+            R.append(round(recallAtK(out,GThard=GT,GTsoft=GTsoft,K=n),2))
         # Print the results
         table = PrettyTable()
         table.field_names = ["N", "1", "5", "10", "15", "20", "25"]
         table.add_row(["Recall", R[0], R[1], R[2], R[3], R[4], R[5]])
-        print(table)
+        self.logger.info(table)
 
     def forward(self, spikes):
         """
@@ -165,10 +243,11 @@
         Returns:
         - Tensor: Output after processing.
         """
-
+        # Initialize the output spikes tensor
         in_spikes = spikes.detach().clone()
         outputs = []  # List to collect output tensors
 
+        # Run inferencing across modules
         for inference in self.inferences:
             out_spikes = inference(in_spikes)
             outputs.append(out_spikes)  # Append the output tensor to the list
@@ -192,33 +271,52 @@
     """
     Run inference on a pre-trained model.
 
-    :param model: Model to run inference on
+    :param models: Models to run inference on, each model is a VPRTempo module
     :param model_name: Name of the model to load
-    :param qconfig: Quantization configuration
     """
-    # Initialize the image transforms and datasets
-    image_transform = ProcessImage(models[0].dims, models[0].patches)
-    max_samples=models[0].num_places
+    # Set first index model as the main model for parameters
+    model = models[0]
+    # Initialize the image transforms
+    image_transform = ProcessImage(model.dims, model.patches)
 
-    test_dataset = CustomImageDataset(annotations_file=models[0].dataset_file,
-                                      base_dir=models[0].data_dir,
-                                      img_dirs=models[0].query_dir,
+    # Determine if querying a subset of the database or the entire database
+    if model.query_places == model.database_places:
+        subset = False # Entire database
+    elif model.query_places < model.database_places:
+        subset = True # Subset of the database
+    else:
+        raise ValueError("The number of query places must be less than or equal to the number of database places.")
+
+    # Initialize the test dataset
+    test_dataset = CustomImageDataset(annotations_file=model.dataset_file,
+                                      base_dir=model.data_dir,
+                                      img_dirs=model.query_dir,
                                       transform=image_transform,
-                                      skip=models[0].filter,
-                                      max_samples=max_samples)
+                                      max_samples=model.database_places,
+                                      skip=model.filter)
+
+    # If using a subset of the database
+    if subset:
+        if model.shuffle: # For a randomized selection of database places
+            test_dataset = Subset(test_dataset, random.sample(range(len(test_dataset)), model.query_places))
+        else: # For a sequential selection of database places
+            indices = [i for i in range(model.database_places) if i % model.filter == 0]
+            # Limit to the desired number of queries
+            indices = indices[:model.query_places]
+            # Create the subset
+            test_dataset = Subset(test_dataset, indices)
+
+    # Initialize the data loader
     test_loader = DataLoader(test_dataset,
                              batch_size=1,
-                             shuffle=False,
+                             shuffle=model.shuffle,
                              num_workers=8,
                              persistent_workers=True)
     # Load the model
-    models[0].load_model(models, os.path.join('./vprtempo/models', model_name))
-
-    # Retrieve layer names for inference
-    layer_names = list(models[0].layer_dict.keys())
+    model.load_model(models, os.path.join('./vprtempo/models', model_name))
 
     # Use evaluate method for inference accuracy
     with torch.no_grad():
-        models[0].evaluate(models, test_loader, layers=layer_names)
\ No newline at end of file
+        model.evaluate(models, test_loader)
\ No newline at end of file
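The GT/GTsoft construction added to `evaluate` is the heart of the new tolerant matching. Here is a self-contained numpy sketch of the same logic; the function name and toy values are illustrative only, not from the patch.

```python
import numpy as np

# Sketch of the ground-truth construction in evaluate(): GT marks the exact
# database index for each query; GTsoft additionally marks +/- `tolerance`
# neighbouring query rows as acceptable matches.
def build_gt(labels, query_places, database_places, frame_filter=1, tolerance=1):
    GT = np.zeros((query_places, database_places), dtype=int)
    GTsoft = np.zeros((query_places, database_places), dtype=int)
    for n, ndx in enumerate(labels):
        if frame_filter != 1:
            ndx = ndx // frame_filter   # map raw image index to place index
        GT[n, ndx] = 1
        GTsoft[n, ndx] = 1
        for i in range(max(0, n - tolerance), min(query_places, n + tolerance + 1)):
            GTsoft[i, ndx] = 1          # widen by the tolerance window
    return GT, GTsoft

GT, GTsoft = build_gt(labels=[0, 2, 4], query_places=3, database_places=3, frame_filter=2)
print(GT)      # identity matrix: query n matches place n after filtering
print(GTsoft)  # each match smeared one row above and below
```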
diff --git a/vprtempo/VPRTempoQuant.py b/vprtempo/VPRTempoQuant.py
index b667923..a752574 100644
--- a/vprtempo/VPRTempoQuant.py
+++ b/vprtempo/VPRTempoQuant.py
@@ -25,49 +25,64 @@
 '''
 
 import os
+import json
 import torch
+import random
 import numpy as np
 import torch.nn as nn
+import matplotlib.pyplot as plt
 import vprtempo.src.blitnet as bn
 
 from tqdm import tqdm
 from prettytable import PrettyTable
-from torch.utils.data import DataLoader
-from vprtempo.src.metrics import recallAtK
+from torch.utils.data import DataLoader, Subset
+from vprtempo.src.metrics import recallAtK, createPR
 from torch.ao.quantization import QuantStub, DeQuantStub
 from vprtempo.src.dataset import CustomImageDataset, ProcessImage
 #from main import parse_network
 
 class VPRTempoQuant(nn.Module):
-    def __init__(self, dims, args=None, logger=None):
+    def __init__(self, args, dims, logger, num_modules, output_folder, out_dim, out_dim_remainder=None):
         super(VPRTempoQuant, self).__init__()
 
-        # Set the arguments
-        self.args = args
-        for arg in vars(args):
-            setattr(self, arg, getattr(args, arg))
+        # Set the args
+        if args is not None:
+            self.args = args
+            for arg in vars(args):
+                setattr(self, arg, getattr(args, arg))
         setattr(self, 'dims', dims)
 
-        # Set the dataset file
-        self.dataset_file = os.path.join('./vprtempo/dataset', self.dataset + '.csv')
-        # Set the model logger and return the device
-        self.logger = logger
+        # Set the device
         self.device = "cpu"
 
-        # Add quantization stubs for Quantization Aware Training (QAT)
+        # Set input args
+        self.logger = logger
+        self.num_modules = num_modules
+        self.output_folder = output_folder
+
         self.quant = QuantStub()
-        self.dequant = DeQuantStub()
+        self.dequant = DeQuantStub()
+
+        # Set the dataset file
+        self.dataset_file = os.path.join('./vprtempo/dataset', self.dataset + '.csv')
+        self.query_dir = [dir.strip() for dir in self.query_dir.split(',')]
 
         # Layer dict to keep track of layer names and their order
         self.layer_dict = {}
         self.layer_counter = 0
+        self.database_dirs = [dir.strip() for dir in self.database_dirs.split(',')]
 
         # Define layer architecture
-        self.input = int(dims[0]*dims[1])
+        self.input = int(self.dims[0]*self.dims[1])
         self.feature = int(self.input * 2)
-        self.output = int(args.num_places / args.num_modules)
+
+        # Output dimension changes for final module if not an even distribution of places
+        if out_dim_remainder is not None:
+            self.output = out_dim_remainder
+        else:
+            self.output = out_dim
 
         """
         Define trainable layers here
@@ -125,15 +140,17 @@ def evaluate(self, models, test_loader, layers=None):
                 nn.ReLU()
             ))
         # Initialize the tqdm progress bar
-        pbar = tqdm(total=self.num_places,
+        pbar = tqdm(total=self.query_places,
                     desc="Running the test network",
                     position=0)
 
         # Initialize the output spikes variable
         out = []
+        labels = []
 
         # Run inference for the specified number of timesteps
-        for spikes, labels in test_loader:
+        for spikes, label in test_loader:
             # Set device
-            spikes, labels = spikes.to(self.device), labels.to(self.device)
+            spikes = spikes.to(self.device)
+            labels.append(label.detach().item())
             # Pass through previous layers if they exist
             spikes = self.forward(spikes)
             # Add output spikes to list
@@ -144,21 +161,82 @@
         pbar.close()
 
         # Reshape output spikes into a similarity matrix
-        out = np.reshape(np.array(out),(self.num_places,self.num_places))
-        # Calculate and print the Recall@N
-        N = [1,5,10,15,20,25]
-        R = []
+        out = np.reshape(np.array(out),(model.query_places,model.database_places))
+
         # Create GT matrix
-        GT = np.zeros((self.num_places,self.num_places), dtype=int)
-        for n in range(len(GT)):
-            GT[n,n] = 1
+        GT = np.zeros((model.query_places,model.database_places), dtype=int)
+        for n, ndx in enumerate(labels):
+            if model.filter != 1:
+                ndx = ndx//model.filter
+            GT[n,ndx] = 1
+
+        # Create GT soft matrix
+        if model.GT_tolerance > 0:
+            GTsoft = np.zeros((model.query_places,model.database_places), dtype=int)
+            for n, ndx in enumerate(labels):
+                if model.filter != 1:
+                    ndx = ndx//model.filter
+                GTsoft[n, ndx] = 1
+                # Apply tolerance
+                for i in range(max(0, n - model.GT_tolerance), min(model.query_places, n + model.GT_tolerance + 1)):
+                    GTsoft[i, ndx] = 1
+        else:
+            GTsoft = None
+
+        # If user specified, generate a PR curve
+        if model.PR_curve:
+            # Create PR curve
+            P, R = createPR(out, GThard=GT, GTsoft=GTsoft, matching='single', n_thresh=100)
+            # Combine P and R into a list of lists
+            PR_data = {
+                    "Precision": P,
+                    "Recall": R
+                }
+            output_file = "PR_curve_data.json"
+            # Construct the full path
+            full_path = f"{model.output_folder}/{output_file}"
+            # Write the data to a JSON file
+            with open(full_path, 'w') as file:
+                json.dump(PR_data, file)
+            # Plot PR curve
+            plt.plot(R,P)
+            plt.xlabel('Recall')
+            plt.ylabel('Precision')
+            plt.title('Precision-Recall Curve')
+            plt.show()
+
+        if model.sim_mat:
+            # Create a figure and a set of subplots
+            fig, axs = plt.subplots(1, 3, figsize=(15, 5))
+
+            # Plot each matrix using matshow
+            cax1 = axs[0].matshow(out, cmap='viridis')
+            fig.colorbar(cax1, ax=axs[0], shrink=0.8)
+            axs[0].set_title('Similarity matrix')
+
+            cax2 = axs[1].matshow(GT, cmap='plasma')
+            fig.colorbar(cax2, ax=axs[1], shrink=0.8)
+            axs[1].set_title('GT')
+
+            cax3 = axs[2].matshow(GTsoft, cmap='inferno')
+            fig.colorbar(cax3, ax=axs[2], shrink=0.8)
+            axs[2].set_title('GT-soft')
+
+            # Adjust layout
+            plt.tight_layout()
+            plt.show()
+
+        # Recall@N
+        N = [1,5,10,15,20,25] # N values to calculate
+        R = [] # Recall@N values
+
         # Calculate Recall@N
         for n in N:
-            R.append(recallAtK(out,GThard=GT,K=n))
+            R.append(round(recallAtK(out,GThard=GT,GTsoft=GTsoft,K=n),2))
         # Print the results
         table = PrettyTable()
         table.field_names = ["N", "1", "5", "10", "15", "20", "25"]
         table.add_row(["Recall", R[0], R[1], R[2], R[3], R[4], R[5]])
-        print(table)
+        self.logger.info(table)
 
     def forward(self, spikes):
         """
@@ -194,45 +272,57 @@
             model.load_state_dict(combined_state_dict[f'model_{i}'])
             model.eval()  # Set the model to inference mode
-
-def check_pretrained_model(model_name):
-    """
-    Check if a pre-trained model exists and tell user if it does not.
-    """
-    if not os.path.exists(os.path.join('./models', model_name)):
-        model.logger.info("A pre-trained network does not exist: please train one using VPRTempoQuant_Trainer")
-        pretrain = 'n'
-    else:
-        pretrain = 'y'
-    return pretrain
 
-def run_inference_quant(models, model_name, qconfig):
+def run_inference_quant(models, model_name):
     """
     Run inference on a pre-trained model.
 
-    :param model: Model to run inference on
+    :param models: Models to run inference on, each model is a VPRTempo module
    :param model_name: Name of the model to load
-    :param qconfig: Quantization configuration
     """
-    max_samples = models[0].num_places
-    # Initialize the image transforms and datasets
-    image_transform = ProcessImage(models[0].dims, models[0].patches)
-    test_dataset = CustomImageDataset(annotations_file=models[0].dataset_file,
-                                      base_dir=models[0].data_dir,
-                                      img_dirs=models[0].query_dir,
+    # Set first index model as the main model for parameters
+    model = models[0]
+    # Initialize the image transforms
+    image_transform = ProcessImage(model.dims, model.patches)
+
+    # Determine if querying a subset of the database or the entire database
+    if model.query_places == model.database_places:
+        subset = False # Entire database
+    elif model.query_places < model.database_places:
+        subset = True # Subset of the database
+    else:
+        raise ValueError("The number of query places must be less than or equal to the number of database places.")
+
+    # Initialize the test dataset
+    test_dataset = CustomImageDataset(annotations_file=model.dataset_file,
+                                      base_dir=model.data_dir,
+                                      img_dirs=model.query_dir,
                                       transform=image_transform,
-                                      skip=models[0].filter,
-                                      max_samples=max_samples)
+                                      max_samples=model.database_places,
+                                      skip=model.filter)
+
+    # If using a subset of the database
+    if subset:
+        if model.shuffle: # For a randomized selection of database places
+            test_dataset = Subset(test_dataset, random.sample(range(len(test_dataset)), model.query_places))
+        else: # For a sequential selection of database places
+            indices = [i for i in range(model.database_places) if i % model.filter == 0]
+            # Limit to the desired number of queries
+            indices = indices[:model.query_places]
+            # Create the subset
+            test_dataset = Subset(test_dataset, indices)
+
+    # Initialize the data loader
     test_loader = DataLoader(test_dataset,
                              batch_size=1,
-                             shuffle=False,
+                             shuffle=model.shuffle,
                              num_workers=8,
                              persistent_workers=True)
     # Load the model
-    models[0].load_model(models, os.path.join('./vprtempo/models', model_name))
+    model.load_model(models, os.path.join('./vprtempo/models', model_name))
 
     # Use evaluate method for inference accuracy
     with torch.no_grad():
-        models[0].evaluate(models, test_loader)
\ No newline at end of file
+        model.evaluate(models, test_loader)
\ No newline at end of file
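The quantized inference path now prepares and converts modules in place, while the trainer uses `prepare_qat`. A minimal eager-mode round trip showing the same sequence; the toy `TinyNet` below is my own stand-in, not the repo's layer stack.

```python
import torch
import torch.nn as nn
import torch.ao.quantization as quantization

# Toy module standing in for a VPRTempo layer stack; QuantStub/DeQuantStub
# mark the float <-> int8 boundaries, as in VPRTempoQuant.
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = quantization.QuantStub()
        self.fc = nn.Linear(8, 4)
        self.dequant = quantization.DeQuantStub()

    def forward(self, x):
        return self.dequant(self.fc(self.quant(x)))

model = TinyNet()
model.train()
model.qconfig = quantization.get_default_qat_qconfig('fbgemm')
quantization.prepare_qat(model, inplace=True)   # insert fake-quant observers

# ... a training loop would run here so the observers can collect ranges ...
model(torch.randn(2, 8))

model.eval()
quantization.convert(model, inplace=True)       # swap in quantized int8 ops
print(type(model.fc))                           # a quantized Linear module
```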
diff --git a/vprtempo/VPRTempoQuantTrain.py b/vprtempo/VPRTempoQuantTrain.py
index 451420a..d5b3c68 100644
--- a/vprtempo/VPRTempoQuantTrain.py
+++ b/vprtempo/VPRTempoQuantTrain.py
@@ -39,7 +39,7 @@ from vprtempo.src.dataset import CustomImageDataset, ProcessImage
 
 class VPRTempoQuantTrain(nn.Module):
-    def __init__(self, args, dims, logger):
+    def __init__(self, args, dims, logger, num_modules, out_dim, out_dim_remainder=None):
         super(VPRTempoQuantTrain, self).__init__()
 
         # Set the arguments
@@ -47,20 +47,16 @@ def __init__(self, args, dims, logger):
         for arg in vars(args):
             setattr(self, arg, getattr(args, arg))
         setattr(self, 'dims', dims)
-
-        # Only CPU available for quantization
-        self.device = "cpu"
 
-        # Set the logger
+        self.device = "cpu"
         self.logger = logger
+        self.num_modules = num_modules
+        self.quant = QuantStub()
+        self.dequant = DeQuantStub()
 
         # Set the dataset file
         self.dataset_file = os.path.join('./vprtempo/dataset', self.dataset + '.csv')
 
-        # Add quantization stubs for Quantization Aware Training (QAT)
-        self.quant = QuantStub()
-        self.dequant = DeQuantStub()
-
         # Layer dict to keep track of layer names and their order
         self.layer_dict = {}
         self.layer_counter = 0
@@ -68,11 +64,18 @@ def __init__(self, args, dims, logger):
         # Define layer architecture
         self.input = int(dims[0]*dims[1])
         self.feature = int(self.input * 2)
-        self.output = int(args.num_places / args.num_modules)
+        if out_dim_remainder is not None:
+            self.output = out_dim_remainder
+        else:
+            self.output = out_dim
 
         # Set the total timestep count
-        self.location_repeat = len(args.database_dirs) # Number of times to repeat the locations
-        self.T = int((self.num_places / self.num_modules) * self.location_repeat * self.epoch)
+        self.database_dirs = [dir.strip() for dir in self.database_dirs.split(',')]
+        self.location_repeat = len(self.database_dirs) # Number of times to repeat the locations
+        if out_dim_remainder is not None:
+            self.T = int(out_dim_remainder * self.location_repeat * self.epoch)
+        else:
+            self.T = int(self.max_module * self.location_repeat * self.epoch)
 
         """
         Define trainable layers here
@@ -148,7 +151,7 @@ def train_model(self, train_loader, layer, model, model_num, prev_layers=None):
         # idx scale factor for different modules
         idx_scale = (self.max_module*self.filter)*model_num
         # Run training for the specified number of epochs
-        for epoch in range(self.epoch):
+        for _ in range(self.epoch):
             # Run training for the specified number of timesteps
             for spikes, labels in train_loader:
                 spikes, labels = spikes.to(self.device), labels.to(self.device)
@@ -226,60 +229,69 @@ def check_pretrained_model(model_name):
         return retrain == 'n'
     return False
 
-def train_new_model_quant(models, model_name, qconfig):
+def train_new_model_quant(models, model_name):
     """
     Train a new model.
 
     :param model: Model to train
     :param model_name: Name of the model to save after training
-    :param qconfig: Quantization configuration
     """
+    # Set first index model as the main model for parameters
+    model = models[0]
     # Initialize the image transforms and datasets
     image_transform = transforms.Compose([
-        ProcessImage(models[0].dims, models[0].patches)
+        ProcessImage(model.dims, model.patches)
     ])
 
     # Automatically generate user_input_ranges
     user_input_ranges = []
     start_idx = 0
-
-    for _ in range(models[0].num_modules):
-        range_temp = [start_idx, start_idx+((models[0].max_module-1)*models[0].filter)]
+    # Generate the image ranges for each module
+    for _ in range(model.num_modules):
+        range_temp = [start_idx, start_idx+((model.max_module-1)*model.filter)]
         user_input_ranges.append(range_temp)
-        start_idx = range_temp[1] + models[0].filter
-    if models[0].num_places < models[0].max_module:
-        max_samples=models[0].num_places
-    else:
-        max_samples = models[0].max_module
+        start_idx = range_temp[1] + model.filter
 
+    # Keep track of trained layers to pass data through them
+    trained_layers = []
+
     # Training each layer
-    trained_models = []
-    for i, model in enumerate(models):
-        trained_layers = []
-        img_range=user_input_ranges[i]
-        train_dataset = CustomImageDataset(annotations_file=models[0].dataset_file,
-                                           base_dir=models[0].data_dir,
-                                           img_dirs=models[0].database_dirs,
-                                           transform=image_transform,
-                                           skip=models[0].filter,
-                                           test=False,
-                                           img_range=img_range,
-                                           max_samples=max_samples)
-        # Initialize the data loader
-        train_loader = DataLoader(train_dataset,
-                                  batch_size=1,
-                                  shuffle=True,
-                                  num_workers=8,
-                                  persistent_workers=True)
-        model = quantization.prepare_qat(model, inplace=False)
-        for layer_name, _ in sorted(models[0].layer_dict.items(), key=lambda item: item[1]):
-            print(f"Training layer: {layer_name}")
+    for layer_name, _ in sorted(model.layer_dict.items(), key=lambda item: item[1]):
+        print(f"Training layer: {layer_name}")
+        # Retrieve the layer object
+        for i, model in enumerate(models):
+            model.train()
+            model.to(torch.device(model.device))
             layer = (getattr(model, layer_name))
-
+            # Determine the maximum samples for the DataLoader
+            if model.database_places < model.max_module:
+                max_samples = model.database_places
+            elif model.output < model.max_module:
+                max_samples = model.output
+            else:
+                max_samples = model.max_module
+            # Initialize new dataset with unique range for each module
+            img_range=user_input_ranges[i]
+            train_dataset = CustomImageDataset(annotations_file=model.dataset_file,
+                                               base_dir=model.data_dir,
+                                               img_dirs=model.database_dirs,
+                                               transform=image_transform,
+                                               skip=model.filter,
+                                               test=False,
+                                               img_range=img_range,
+                                               max_samples=max_samples)
+            # Initialize the data loader
+            train_loader = DataLoader(train_dataset,
+                                      batch_size=1,
+                                      shuffle=True,
+                                      num_workers=8,
+                                      persistent_workers=True)
             # Train the layers
             model.train_model(train_loader, layer, model, i, prev_layers=trained_layers)
-            trained_layers.append(layer_name)
-        trained_models.append(quantization.convert(model, inplace=False))
-        # After training the current layer, add it to the list of trained layer
+        trained_layers.append(layer_name)
+
+    # Convert the model to evaluation mode
+    for model in models:
+        quantization.convert(model, inplace=True)
     # Save the model
-    model.save_model(trained_models,os.path.join('./vprtempo/models', model_name))
\ No newline at end of file
+    model.save_model(models,os.path.join('./vprtempo/models', model_name))
\ No newline at end of file
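Both trainers now derive each module's image range with the same start/step arithmetic. A standalone sketch for reference; the function name and example values are mine:

```python
# Sketch of the per-module image-range generation used by both trainers.
# Each module trains on a contiguous, non-overlapping slice of the database,
# with `img_filter` controlling how many raw frames each place spans.

def module_ranges(num_modules: int, max_module: int, img_filter: int):
    ranges, start_idx = [], 0
    for _ in range(num_modules):
        end_idx = start_idx + (max_module - 1) * img_filter
        ranges.append([start_idx, end_idx])
        start_idx = end_idx + img_filter
    return ranges

print(module_ranges(3, 500, 1))  # [[0, 499], [500, 999], [1000, 1499]]
print(module_ranges(2, 500, 8))  # [[0, 3992], [4000, 7992]]
```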
diff --git a/vprtempo/VPRTempoTrain.py b/vprtempo/VPRTempoTrain.py
index de3d449..0a7049e 100644
--- a/vprtempo/VPRTempoTrain.py
+++ b/vprtempo/VPRTempoTrain.py
@@ -27,6 +27,7 @@
 import os
 import gc
 import torch
+import sys
 import numpy as np
 import torch.nn as nn
 
@@ -39,7 +40,7 @@ from vprtempo.src.dataset import CustomImageDataset, ProcessImage
 
 class VPRTempoTrain(nn.Module):
-    def __init__(self, args, dims, logger):
+    def __init__(self, args, dims, logger, num_modules, out_dim, out_dim_remainder=None):
         super(VPRTempoTrain, self).__init__()
 
         # Set the arguments
@@ -52,6 +53,7 @@ def __init__(self, args, dims, logger):
         else:
             self.device = "cpu"
         self.logger = logger
+        self.num_modules = num_modules
 
         # Set the dataset file
         self.dataset_file = os.path.join('./vprtempo/dataset', self.dataset + '.csv')
@@ -63,11 +65,18 @@ def __init__(self, args, dims, logger):
         # Define layer architecture
         self.input = int(dims[0]*dims[1])
         self.feature = int(self.input * 2)
-        self.output = int(args.num_places / args.num_modules)
+        if out_dim_remainder is not None:
+            self.output = out_dim_remainder
+        else:
+            self.output = out_dim
 
         # Set the total timestep count
-        self.location_repeat = len(args.database_dirs) # Number of times to repeat the locations
-        self.T = int((self.num_places/self.num_modules) * self.location_repeat * self.epoch)
+        self.database_dirs = [dir.strip() for dir in self.database_dirs.split(',')]
+        self.location_repeat = len(self.database_dirs) # Number of times to repeat the locations
+        if out_dim_remainder is not None:
+            self.T = int(out_dim_remainder * self.location_repeat * self.epoch)
+        else:
+            self.T = int(self.max_module * self.location_repeat * self.epoch)
 
         """
         Define trainable layers here
@@ -138,7 +147,7 @@ def train_model(self, train_loader, layer, model, model_num, prev_layers=None):
         """
 
         # Initialize the tqdm progress bar
-        pbar = tqdm(total=int(self.T),
+        pbar = tqdm(total=self.T,
                     desc=f"Module {model_num+1}",
                     position=0)
 
@@ -211,26 +220,19 @@
 
         torch.save(state_dicts, model_out)
 
-def generate_model_name(model):
-    """
-    Generate the model name based on its parameters.
-    """
-    return ("VPRTempo" +
-            str(model.input) +
-            str(model.feature) +
-            str(model.output) +
-            str(model.num_modules) +
-            '.pth')
 
 def check_pretrained_model(model_name):
     """
     Check if a pre-trained model exists and prompt the user to retrain if desired.
     """
-    if os.path.exists(os.path.join('./models', model_name)):
+    if os.path.exists(os.path.join('./vprtempo/models', model_name)):
         prompt = "A network with these parameters exists, re-train network? (y/n):\n"
         retrain = input(prompt).strip().lower()
-        return retrain == 'n'
-    return False
+        if retrain == 'y':
+            return True
+        elif retrain == 'n':
+            print('Training new model cancelled')
+            sys.exit()
 
 def train_new_model(models, model_name):
     """
@@ -247,15 +249,12 @@
     # Automatically generate user_input_ranges
     user_input_ranges = []
     start_idx = 0
-
+    # Generate the image ranges for each module
     for _ in range(models[0].num_modules):
        range_temp = [start_idx, start_idx+((models[0].max_module-1)*models[0].filter)]
         user_input_ranges.append(range_temp)
         start_idx = range_temp[1] + models[0].filter
-    if models[0].num_places < models[0].max_module:
-        max_samples=models[0].num_places
-    else:
-        max_samples = models[0].max_module
+
     # Keep track of trained layers to pass data through them
     trained_layers = []
     # Training each layer
@@ -264,7 +263,16 @@
         # Retrieve the layer object
         for i, model in enumerate(models):
             model.train()
+            model.to(torch.device(model.device))
             layer = (getattr(model, layer_name))
+            # Determine the maximum samples for the DataLoader
+            if model.database_places < model.max_module:
+                max_samples = model.database_places
+            elif model.output < model.max_module:
+                max_samples = model.output
+            else:
+                max_samples = model.max_module
+            # Initialize new dataset with unique range for each module
             img_range=user_input_ranges[i]
             train_dataset = CustomImageDataset(annotations_file=models[0].dataset_file,
                                                base_dir=models[0].data_dir,
@@ -282,6 +290,7 @@
                                                persistent_workers=True)
             # Train the layers
             model.train_model(train_loader, layer, model, i, prev_layers=trained_layers)
+            model.to(torch.device("cpu"))
             # After training the current layer, add it to the list of trained layers
             trained_layers.append(layer_name)
     # Convert the model to evaluation mode
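The timestep budget in both trainers now follows from `max_module` (or the remainder module's size) rather than `num_places / num_modules`. A worked check of `T = places_in_module * location_repeat * epoch`, using the patch defaults of two database traverses (`spring, fall`) and four epochs:

```python
# Worked check of the new timestep count in VPRTempoTrain/VPRTempoQuantTrain.
# A full module of 500 places trains for 4000 timesteps; a 300-place
# remainder module trains for 2400.
max_module, location_repeat, epoch = 500, 2, 4
print(max_module * location_repeat * epoch)  # 4000

remainder = 300
print(remainder * location_repeat * epoch)   # 2400
```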
diff --git a/vprtempo/dataset/.DS_Store b/vprtempo/dataset/.DS_Store
deleted file mode 100644
index ea74ea8..0000000
Binary files a/vprtempo/dataset/.DS_Store and /dev/null differ
diff --git a/vprtempo/models/.DS_Store b/vprtempo/models/.DS_Store
deleted file mode 100644
index 28b3c56..0000000
Binary files a/vprtempo/models/.DS_Store and /dev/null differ
diff --git a/vprtempo/src/dataset.py b/vprtempo/src/dataset.py
index 976dbf2..209549d 100644
--- a/vprtempo/src/dataset.py
+++ b/vprtempo/src/dataset.py
@@ -163,7 +163,7 @@ def __init__(self, annotations_file, base_dir, img_dirs, transform=None, target_
         self.target_transform = target_transform
         self.skip = skip
         self.img_range = img_range
-        
+
         # Load image labels from each directory, apply the skip and max_samples, and concatenate
         self.img_labels = []
         for img_dir in img_dirs:
diff --git a/vprtempo/src/loggers.py b/vprtempo/src/loggers.py
index 3081a02..8f7e5cb 100644
--- a/vprtempo/src/loggers.py
+++ b/vprtempo/src/loggers.py
@@ -48,6 +48,8 @@ def model_logger():
     logger.info('Current device is: CPU')
     logger.info('')
 
+    return logger, output_folder
+
 def model_logger_quant():
     """
     Configure the logger
@@ -88,3 +90,5 @@ def model_logger_quant():
     logger.info('Quantization enabled')
     logger.info('Current device is: CPU')
     logger.info('')
+
+    return logger, output_folder
\ No newline at end of file
diff --git a/vprtempo/src/metrics.py b/vprtempo/src/metrics.py
index 9a85b84..9900fd4 100644
--- a/vprtempo/src/metrics.py
+++ b/vprtempo/src/metrics.py
@@ -47,20 +47,19 @@ def createPR(S_in, GThard, GTsoft=None, matching='multi', n_thresh=100):
 
     # copy S and set elements that are only true in GTsoft to min(S) to ignore them during evaluation
     S = S_in.copy()
-    S[S == 0] = np.nan
     if GTsoft is not None:
         S[GTsoft & ~GT] = S.min()
 
     # single-best-match or multi-match VPR
     if matching == 'single':
-        # GT-values for best match per query (i.e., per column)
-        GT = GT[np.nanargmax(S, axis=0), np.arange(GT.shape[1])]
+        # count the number of ground-truth positives (GTP)
+        GTP = np.count_nonzero(GT.any(0))
 
-        # count the number of ground-truth positives (GTP)
-        GTP = np.count_nonzero(GT)
+        # GT-values for best match per query (i.e., per column)
+        GT = GT[np.argmax(S, axis=0), np.arange(GT.shape[1])]
 
         # similarities for best match per query (i.e., per column)
-        S = np.nanmax(S, axis=0)
+        S = np.max(S, axis=0)
 
     elif matching == 'multi':
         # count the number of ground-truth positives (GTP)
@@ -71,8 +70,8 @@ def createPR(S_in, GThard, GTsoft=None, matching='multi', n_thresh=100):
     P = [1, ]
 
     # select start and end threshold
-    startV = np.nanmax(S) # start-value for treshold
-    endV = np.nanmin(S) # end-value for treshold
+    startV = S.max() # start-value for threshold
+    endV = S.min() # end-value for threshold
 
     # iterate over different thresholds
     for i in np.linspace(startV, endV, n_thresh):
@@ -149,9 +148,13 @@ def recallAtK(S_in, GThard, GTsoft=None, K=1):
 
     # ensure logical datatype in GT and GTsoft
     GT = GThard.astype('bool')
+    if GTsoft is not None:
+        GTsoft = GTsoft.astype('bool')
 
     # copy S and set elements that are only true in GTsoft to min(S) to ignore them during evaluation
     S = S_in.copy()
+    if GTsoft is not None:
+        S[GTsoft & ~GT] = S.min()
 
     # discard all query images without an actually matching database image
     j = GT.sum(0) > 0 # columns with matches
@@ -166,4 +169,4 @@ def recallAtK(S_in, GThard, GTsoft=None, K=1):
 
     # recall@K
     RatK = np.sum(GT.sum(0) > 0) / GT.shape[1]
 
-    return RatK
+    return RatK
\ No newline at end of file
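Finally, a toy numpy check of the GTsoft masking semantics added to `recallAtK` and shared with `createPR`: entries that are true only in GTsoft are reset to `S.min()` before scoring, so tolerated near-matches count neither as hits nor as errors. The matrix values below are invented for illustration.

```python
import numpy as np

# Toy check of the soft-ground-truth masking step in recallAtK/createPR.
S = np.array([[0.9, 0.2, 0.1],
              [0.8, 0.7, 0.2],   # 0.8 is a "soft-only" match
              [0.1, 0.3, 0.6]])
GThard = np.eye(3, dtype=bool)
GTsoft = GThard.copy()
GTsoft[1, 0] = True             # tolerate the neighbouring frame

masked = S.copy()
masked[GTsoft & ~GThard] = masked.min()
print(masked[1, 0])             # 0.1 -> the soft-only match is ignored during scoring
```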