From 139edacfbaf62e799d610d5066478155bf6be02d Mon Sep 17 00:00:00 2001 From: Adam Hines Date: Wed, 13 Dec 2023 16:02:39 +1000 Subject: [PATCH 1/8] Adding in query specific places, not just the same as model training amount --- .gitignore | 1 + main.py | 10 +++++++-- vprtempo/VPRTempo.py | 45 ++++++++++++++++++++++++++++++--------- vprtempo/VPRTempoTrain.py | 2 ++ 4 files changed, 46 insertions(+), 12 deletions(-) diff --git a/.gitignore b/.gitignore index 9a70cc7..7baed65 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,4 @@ vprtempo/dataset/winter/ vprtempo/dataset/event.csv vprtempo/output/ vprtempo/src/__pycache__/ +vprtempo/models/VPRTempo3136627250045.pth diff --git a/main.py b/main.py index 753163e..31889d4 100644 --- a/main.py +++ b/main.py @@ -25,9 +25,11 @@ ''' import sys import argparse +import torch import torch.quantization as quantization +from tqdm import tqdm from vprtempo.VPRTempo import VPRTempo, run_inference from vprtempo.src.loggers import model_logger, model_logger_quant from vprtempo.VPRTempoQuant import VPRTempoQuant, run_inference_quant @@ -61,6 +63,7 @@ def initialize_and_run_model(args,dims): for _ in range(args.num_modules): # Initialize the model model = VPRTempoTrain(args, dims, logger) + model.to(torch.device('cpu')) models.append(model) # Generate the model name model_name = generate_model_name(model) @@ -90,9 +93,10 @@ def initialize_and_run_model(args,dims): else: models = [] logger = model_logger() - for _ in range(args.num_modules): + for _ in tqdm(range(args.num_modules), desc="Initializing modules"): # Initialize the model model = VPRTempo(dims, args, logger) + model.to(torch.device('cpu')) models.append(model) # Generate the model name model_name = generate_model_name(model) @@ -111,7 +115,9 @@ def parse_network(use_quantize=False, train_new_model=False): parser.add_argument('--data_dir', type=str, default='./vprtempo/dataset/', help="Directory where dataset files are stored") parser.add_argument('--num_places', type=int, default=500, - help="Number of places to use for training and/or inferencing") + help="Number of places to use for training") + parser.add_argument('--query_places', type=int, default=500, + help="Number of places to use for inferencing") parser.add_argument('--num_modules', type=int, default=1, help="Number of expert modules to use split images into") parser.add_argument('--max_module', type=int, default=500, diff --git a/vprtempo/VPRTempo.py b/vprtempo/VPRTempo.py index 6a0a2f0..0198cef 100644 --- a/vprtempo/VPRTempo.py +++ b/vprtempo/VPRTempo.py @@ -108,7 +108,7 @@ def evaluate(self, models, test_loader, layers=None): :param layers: Layers to pass data through """ # Initialize the tqdm progress bar - pbar = tqdm(total=self.num_places, + pbar = tqdm(total=self.query_places, desc="Running the test network", position=0) self.inferences = [] @@ -121,12 +121,15 @@ def evaluate(self, models, test_loader, layers=None): nn.Hardtanh(0, 0.9), nn.ReLU() )) + self.inferences[-1].to(torch.device(self.device)) # Initiliaze the output spikes variable out = [] + labels = [] # Run inference for the specified number of timesteps - for spikes, labels in test_loader: + for spikes, label in test_loader: # Set device - spikes, labels = spikes.to(self.device), labels.to(self.device) + spikes = spikes.to(self.device) + labels.append(label.detach().cpu().item()) # Forward pass spikes = self.forward(spikes) # Add output spikes to list @@ -135,17 +138,17 @@ def evaluate(self, models, test_loader, layers=None): # Close the tqdm progress bar pbar.close() - 
+ # Rehsape output spikes into a similarity matrix - out = np.reshape(np.array(out),(model.num_places,model.num_places)) + out = np.reshape(np.array(out),(model.query_places,model.num_places)) # Recall@N N = [1,5,10,15,20,25] # N values to calculate R = [] # Recall@N values # Create GT matrix - GT = np.zeros((model.num_places,model.num_places), dtype=int) - for n in range(len(GT)): - GT[n,n] = 1 + GT = np.zeros((model.query_places,model.num_places), dtype=int) + for n, ndx in enumerate(labels): + GT[n,ndx] = 1 # Calculate Recall@N for n in N: R.append(round(recallAtK(out,GThard=GT,K=n),2)) @@ -155,6 +158,28 @@ def evaluate(self, models, test_loader, layers=None): table.add_row(["Recall", R[0], R[1], R[2], R[3], R[4], R[5]]) print(table) + import matplotlib.pyplot as plt + + # Increasing the figure size significantly + plt.figure(figsize=(20, 30)) # This is a large figure size, which can be adjusted as needed + + # Plotting "Spiking output" + plt.subplot(2, 1, 1) # First plot + plt.imshow(out, aspect='auto', interpolation='nearest') # 'nearest' interpolation for clear points + plt.title("Spiking output") + plt.colorbar(shrink=0.5) + plt.grid(False) # Disabling grid lines + + # Plotting "Ground truth" + plt.subplot(2, 1, 2) # Second plot + plt.imshow(GT, aspect='auto', interpolation='nearest') + plt.title("Ground truth") + plt.colorbar(shrink=0.5) + plt.grid(False) # Disabling grid lines + + # Displaying the plots + plt.show() + def forward(self, spikes): """ Compute the forward pass of the model. @@ -198,7 +223,7 @@ def run_inference(models, model_name): """ # Initialize the image transforms and datasets image_transform = ProcessImage(models[0].dims, models[0].patches) - max_samples=models[0].num_places + max_samples=models[0].query_places test_dataset = CustomImageDataset(annotations_file=models[0].dataset_file, base_dir=models[0].data_dir, @@ -209,7 +234,7 @@ def run_inference(models, model_name): # Initialize the data loader test_loader = DataLoader(test_dataset, batch_size=1, - shuffle=False, + shuffle=True, num_workers=8, persistent_workers=True) diff --git a/vprtempo/VPRTempoTrain.py b/vprtempo/VPRTempoTrain.py index de3d449..e0bac5c 100644 --- a/vprtempo/VPRTempoTrain.py +++ b/vprtempo/VPRTempoTrain.py @@ -264,6 +264,7 @@ def train_new_model(models, model_name): # Retrieve the layer object for i, model in enumerate(models): model.train() + model.to(torch.device(model.device)) layer = (getattr(model, layer_name)) img_range=user_input_ranges[i] train_dataset = CustomImageDataset(annotations_file=models[0].dataset_file, @@ -282,6 +283,7 @@ def train_new_model(models, model_name): persistent_workers=True) # Train the layers model.train_model(train_loader, layer, model, i, prev_layers=trained_layers) + model.to(torch.device("cpu")) # After training the current layer, add it to the list of trained layers trained_layers.append(layer_name) # Convert the model to evaluation mode From f1c960346a749c0d91951d7f3528ff94810ced92 Mon Sep 17 00:00:00 2001 From: Adam Hines Date: Fri, 15 Dec 2023 10:24:05 +1000 Subject: [PATCH 2/8] Added in GT tolerance, adding 4 full dataset models, need to fix up metrics and GT generation --- main.py | 111 ++++++++++++++++++++++++++++++-------- vprtempo/VPRTempo.py | 97 +++++++++++++++++---------------- vprtempo/VPRTempoTrain.py | 45 +++++++++++----- vprtempo/src/dataset.py | 2 +- vprtempo/src/loggers.py | 4 ++ vprtempo/src/metrics.py | 4 ++ 6 files changed, 180 insertions(+), 83 deletions(-) diff --git a/main.py b/main.py index 31889d4..dbfcd99 100644 --- a/main.py 
+++ b/main.py @@ -34,7 +34,18 @@ from vprtempo.src.loggers import model_logger, model_logger_quant from vprtempo.VPRTempoQuant import VPRTempoQuant, run_inference_quant from vprtempo.VPRTempoQuantTrain import VPRTempoQuantTrain, generate_model_name_quant, train_new_model_quant -from vprtempo.VPRTempoTrain import VPRTempoTrain, generate_model_name, check_pretrained_model, train_new_model +from vprtempo.VPRTempoTrain import VPRTempoTrain, check_pretrained_model, train_new_model + +def generate_model_name(model): + """ + Generate the model name based on its parameters. + """ + return ("VPRTempo" + + str(model.input) + + str(model.feature) + + str(model.database_places) + + ''.join(model.database_dirs) + + '.pth') def initialize_and_run_model(args,dims): # If user wants to train a new network @@ -45,7 +56,7 @@ def initialize_and_run_model(args,dims): logger = model_logger_quant() # Get the quantization config qconfig = quantization.get_default_qat_qconfig('fbgemm') - for _ in range(args.num_modules): + for _ in tqdm(range(args.num_modules), desc="Initializing modules"): # Initialize the model model = VPRTempoQuantTrain(args, dims, logger) model.train() @@ -57,20 +68,52 @@ def initialize_and_run_model(args,dims): check_pretrained_model(model_name) # Train the model train_new_model_quant(models, model_name, qconfig) - else: # Normal model + + # Base model + else: models = [] - logger = model_logger() - for _ in range(args.num_modules): - # Initialize the model - model = VPRTempoTrain(args, dims, logger) - model.to(torch.device('cpu')) - models.append(model) + logger = model_logger() # Initialize the logger + places = args.database_places # Copy out number of database places + + # Determine how many modules the network needs to create + num_modules = 1 + while places > args.max_module: + places -= args.max_module + num_modules += 1 + + # If the final module has less than max_module, reduce the dim of the output layer + remainder = args.database_places % args.max_module + + # Check if number of modules and database images works + if args.filter * (((num_modules-1)*args.max_module)+remainder) > args.database_places: + print("Error: Too many modules or too few images for the given filter") + sys.exit() + + # Modify final module output layer neuron count according to remainder + if remainder != 0: # There are remainders, adjust output neuron count in final module + out_dim = int((args.database_places - remainder) / (num_modules - 1)) + final_out_dim = remainder + else: # No remainders, all modules are even + out_dim = int(args.database_places / num_modules) + final_out_dim = out_dim + + # Create the modules + final_out = None + for mod in tqdm(range(num_modules), desc="Initializing modules"): + model = VPRTempoTrain(args, dims, logger, num_modules, out_dim, out_dim_remainder=final_out) # Initialize the model + model.to(torch.device('cpu')) # Move module to CPU for storage (necessary for large models) + models.append(model) # Create module list + if mod == num_modules - 2: + final_out = final_out_dim + # Generate the model name model_name = generate_model_name(model) + print(f"Model name: {model_name}") # Check if the model has been trained before check_pretrained_model(model_name) # Train the model train_new_model(models, model_name) + # Run the inference network else: # Set the quantization configuration @@ -78,7 +121,7 @@ def initialize_and_run_model(args,dims): models = [] logger = model_logger_quant() qconfig = quantization.get_default_qat_qconfig('fbgemm') - for _ in range(args.num_modules): + 
for _ in tqdm(range(args.num_modules), desc="Initializing modules"): # Initialize the model model = VPRTempoQuant(dims, args, logger) model.eval() @@ -92,14 +135,36 @@ def initialize_and_run_model(args,dims): run_inference_quant(models, model_name, qconfig) else: models = [] - logger = model_logger() - for _ in tqdm(range(args.num_modules), desc="Initializing modules"): - # Initialize the model - model = VPRTempo(dims, args, logger) - model.to(torch.device('cpu')) - models.append(model) + logger = model_logger() # Initialize the logger + places = args.database_places # Copy out number of database places + + # Determine how many modules the network needs to create + num_modules = 1 + while places > args.max_module: + places -= args.max_module + num_modules += 1 + + # If the final module has less than max_module, reduce the dim of the output layer + remainder = args.database_places % args.max_module + if remainder != 0: # There are remainders, adjust output neuron count in final module + out_dim = int((args.database_places - remainder) / (num_modules - 1)) + final_out_dim = remainder + else: # No remainders, all modules are even + out_dim = int(args.database_places / num_modules) + final_out_dim = out_dim + + # Create the modules + final_out = None + for mod in tqdm(range(num_modules), desc="Initializing modules"): + model = VPRTempo(args, dims, logger, num_modules, out_dim, out_dim_remainder=final_out) # Initialize the model + model.eval() + model.to(torch.device('cpu')) # Move module to CPU for storage (necessary for large models) + models.append(model) # Create module list + if mod == num_modules - 2: + final_out = final_out_dim # Generate the model name model_name = generate_model_name(model) + print(f"Model name: {model_name}") # Run the inference model run_inference(models, model_name) @@ -114,21 +179,23 @@ def parse_network(use_quantize=False, train_new_model=False): help="Dataset to use for training and/or inferencing") parser.add_argument('--data_dir', type=str, default='./vprtempo/dataset/', help="Directory where dataset files are stored") - parser.add_argument('--num_places', type=int, default=500, + parser.add_argument('--database_places', type=int, default=500, help="Number of places to use for training") parser.add_argument('--query_places', type=int, default=500, help="Number of places to use for inferencing") - parser.add_argument('--num_modules', type=int, default=1, - help="Number of expert modules to use split images into") parser.add_argument('--max_module', type=int, default=500, help="Maximum number of images per module") - parser.add_argument('--database_dirs', nargs='+', default=['spring', 'fall'], + parser.add_argument('--database_dirs', type=str, default='spring, fall', help="Directories to use for training") - parser.add_argument('--query_dir', nargs='+', default=['summer'], + parser.add_argument('--query_dir', type=str, default='summer', help="Directories to use for testing") + parser.add_argument('--shuffle', action='store_true', + help="Shuffle input images during query") + parser.add_argument('--GT_tolerance', type=int, default=2, + help="Ground truth tolerance for matching") # Define training parameters - parser.add_argument('--filter', type=int, default=8, + parser.add_argument('--filter', type=int, default=1, help="Images to skip for training and/or inferencing") parser.add_argument('--epoch', type=int, default=4, help="Number of epochs to train the model") diff --git a/vprtempo/VPRTempo.py b/vprtempo/VPRTempo.py index 0198cef..a64f9b3 100644 --- 
a/vprtempo/VPRTempo.py +++ b/vprtempo/VPRTempo.py @@ -33,12 +33,12 @@ from tqdm import tqdm from prettytable import PrettyTable -from torch.utils.data import DataLoader +from torch.utils.data import DataLoader, Subset from vprtempo.src.metrics import recallAtK from vprtempo.src.dataset import CustomImageDataset, ProcessImage class VPRTempo(nn.Module): - def __init__(self, dims, args=None, logger=None): + def __init__(self, args, dims, logger, num_modules, out_dim, out_dim_remainder=None): super(VPRTempo, self).__init__() # Set the arguments @@ -53,17 +53,22 @@ def __init__(self, dims, args=None, logger=None): self.device = "cpu" self.logger = logger + self.num_modules = num_modules + # Set the dataset file self.dataset_file = os.path.join('./vprtempo/dataset', self.dataset + '.csv') - + self.query_dir = [dir.strip() for dir in self.query_dir.split(',')] # Layer dict to keep track of layer names and their order self.layer_dict = {} self.layer_counter = 0 - + self.database_dirs = [dir.strip() for dir in self.database_dirs.split(',')] # Define layer architecture self.input = int(self.dims[0]*self.dims[1]) self.feature = int(self.input * 2) - self.output = int(args.num_places / args.num_modules) + if not out_dim_remainder is None: + self.output = out_dim_remainder + else: + self.output = out_dim """ Define trainable layers here @@ -125,60 +130,56 @@ def evaluate(self, models, test_loader, layers=None): # Initiliaze the output spikes variable out = [] labels = [] + num_corr = 0 # Run inference for the specified number of timesteps for spikes, label in test_loader: # Set device - spikes = spikes.to(self.device) + spikes, label = spikes.to(self.device), label.to(self.device) labels.append(label.detach().cpu().item()) # Forward pass spikes = self.forward(spikes) # Add output spikes to list - out.append(spikes.detach().cpu().tolist()) + #out.append(spikes.detach().cpu().tolist()) pbar.update(1) - + if torch.argmax(spikes) == label: + num_corr += 1 + print(f"Accuracy: {num_corr/len(test_loader)}") # Close the tqdm progress bar pbar.close() # Rehsape output spikes into a similarity matrix - out = np.reshape(np.array(out),(model.query_places,model.num_places)) + #out = np.reshape(np.array(out),(model.query_places,model.database_places)) # Recall@N - N = [1,5,10,15,20,25] # N values to calculate - R = [] # Recall@N values + # N = [1,5,10,15,20,25] # N values to calculate + # R = [] # Recall@N values # Create GT matrix - GT = np.zeros((model.query_places,model.num_places), dtype=int) - for n, ndx in enumerate(labels): - GT[n,ndx] = 1 + # GT = np.zeros((model.query_places,model.database_places), dtype=int) + # for n, ndx in enumerate(labels): + # if model.filter !=1: + # ndx = ndx//model.filter + # GT[n,ndx] = 1 + # Create GT soft matrix + # if model.GT_tolerance > 0: + # GTsoft = np.zeros((model.query_places, model.database_places), dtype=int) + # for n, ndx in enumerate(labels): + # if model.filter != 1: + # ndx = ndx // model.filter + # GTsoft[n, ndx] = 1 + # # Apply tolerance + # for i in range(max(0, n - model.GT_tolerance), min(model.query_places, n + model.GT_tolerance + 1)): + # GTsoft[i, ndx] = 1 + # else: + # GTsoft = None + # Calculate Recall@N - for n in N: - R.append(round(recallAtK(out,GThard=GT,K=n),2)) + #for n in N: + # R.append(round(recallAtK(out,GThard=GT,GTsoft=GTsoft,K=n),2)) # Print the results - table = PrettyTable() - table.field_names = ["N", "1", "5", "10", "15", "20", "25"] - table.add_row(["Recall", R[0], R[1], R[2], R[3], R[4], R[5]]) - print(table) - - import 
matplotlib.pyplot as plt - - # Increasing the figure size significantly - plt.figure(figsize=(20, 30)) # This is a large figure size, which can be adjusted as needed - - # Plotting "Spiking output" - plt.subplot(2, 1, 1) # First plot - plt.imshow(out, aspect='auto', interpolation='nearest') # 'nearest' interpolation for clear points - plt.title("Spiking output") - plt.colorbar(shrink=0.5) - plt.grid(False) # Disabling grid lines - - # Plotting "Ground truth" - plt.subplot(2, 1, 2) # Second plot - plt.imshow(GT, aspect='auto', interpolation='nearest') - plt.title("Ground truth") - plt.colorbar(shrink=0.5) - plt.grid(False) # Disabling grid lines - - # Displaying the plots - plt.show() + #table = PrettyTable() + #table.field_names = ["N", "1", "5", "10", "15", "20", "25"] + #table.add_row(["Recall", R[0], R[1], R[2], R[3], R[4], R[5]]) + #self.logger.info(table) def forward(self, spikes): """ @@ -190,10 +191,11 @@ def forward(self, spikes): Returns: - Tensor: Output after processing. """ - + # Initialize the output spikes tensor in_spikes = spikes.detach().clone() outputs = [] # List to collect output tensors + # Run inferencing across modules for inference in self.inferences: out_spikes = inference(in_spikes) outputs.append(out_spikes) # Append the output tensor to the list @@ -223,7 +225,7 @@ def run_inference(models, model_name): """ # Initialize the image transforms and datasets image_transform = ProcessImage(models[0].dims, models[0].patches) - max_samples=models[0].query_places + max_samples=models[0].database_places test_dataset = CustomImageDataset(annotations_file=models[0].dataset_file, base_dir=models[0].data_dir, @@ -231,10 +233,13 @@ def run_inference(models, model_name): transform=image_transform, skip=models[0].filter, max_samples=max_samples) + indices = torch.randperm(models[0].database_places).tolist() + subset_indicies = indices[:models[0].query_places] + subset = Subset(test_dataset, subset_indicies) # Initialize the data loader - test_loader = DataLoader(test_dataset, + test_loader = DataLoader(subset, batch_size=1, - shuffle=True, + shuffle=models[0].shuffle, num_workers=8, persistent_workers=True) diff --git a/vprtempo/VPRTempoTrain.py b/vprtempo/VPRTempoTrain.py index e0bac5c..812b371 100644 --- a/vprtempo/VPRTempoTrain.py +++ b/vprtempo/VPRTempoTrain.py @@ -27,6 +27,7 @@ import os import gc import torch +import sys import numpy as np import torch.nn as nn @@ -39,7 +40,7 @@ from vprtempo.src.dataset import CustomImageDataset, ProcessImage class VPRTempoTrain(nn.Module): - def __init__(self, args, dims, logger): + def __init__(self, args, dims, logger, num_modules, out_dim, out_dim_remainder=None): super(VPRTempoTrain, self).__init__() # Set the arguments @@ -52,6 +53,7 @@ def __init__(self, args, dims, logger): else: self.device = "cpu" self.logger = logger + self.num_modules = num_modules # Set the dataset file self.dataset_file = os.path.join('./vprtempo/dataset', self.dataset + '.csv') @@ -63,11 +65,18 @@ def __init__(self, args, dims, logger): # Define layer architecture self.input = int(dims[0]*dims[1]) self.feature = int(self.input * 2) - self.output = int(args.num_places / args.num_modules) + if not out_dim_remainder is None: + self.output = out_dim_remainder + else: + self.output = out_dim # Set the total timestep count - self.location_repeat = len(args.database_dirs) # Number of times to repeat the locations - self.T = int((self.num_places/self.num_modules) * self.location_repeat * self.epoch) + self.database_dirs = [dir.strip() for dir in 
self.database_dirs.split(',')] + self.location_repeat = len(self.database_dirs) # Number of times to repeat the locations + if not out_dim_remainder is None: + self.T = int(out_dim_remainder * self.location_repeat * self.epoch) + else: + self.T = int(self.max_module * self.location_repeat * self.epoch) """ Define trainable layers here @@ -138,7 +147,7 @@ def train_model(self, train_loader, layer, model, model_num, prev_layers=None): """ # Initialize the tqdm progress bar - pbar = tqdm(total=int(self.T), + pbar = tqdm(total=self.T, desc=f"Module {model_num+1}", position=0) @@ -219,18 +228,21 @@ def generate_model_name(model): str(model.input) + str(model.feature) + str(model.output) + - str(model.num_modules) + + ''.join(model.database_dirs) + '.pth') def check_pretrained_model(model_name): """ Check if a pre-trained model exists and prompt the user to retrain if desired. """ - if os.path.exists(os.path.join('./models', model_name)): + if os.path.exists(os.path.join('./vprtempo/models', model_name)): prompt = "A network with these parameters exists, re-train network? (y/n):\n" retrain = input(prompt).strip().lower() - return retrain == 'n' - return False + if retrain == 'y': + return True + elif retrain == 'n': + print('Training new model cancelled') + sys.exit() def train_new_model(models, model_name): """ @@ -247,15 +259,12 @@ def train_new_model(models, model_name): # Automatically generate user_input_ranges user_input_ranges = [] start_idx = 0 - + # Generate the image ranges for each module for _ in range(models[0].num_modules): range_temp = [start_idx, start_idx+((models[0].max_module-1)*models[0].filter)] user_input_ranges.append(range_temp) start_idx = range_temp[1] + models[0].filter - if models[0].num_places < models[0].max_module: - max_samples=models[0].num_places - else: - max_samples = models[0].max_module + # Keep track of trained layers to pass data through them trained_layers = [] # Training each layer @@ -266,6 +275,14 @@ def train_new_model(models, model_name): model.train() model.to(torch.device(model.device)) layer = (getattr(model, layer_name)) + # Determine the maximum samples for the DataLoader + if model.database_places < model.max_module: + max_samples = model.database_places + elif model.output < model.max_module: + max_samples = model.output + else: + max_samples = model.max_module + # Initialize new dataset with unique range for each module img_range=user_input_ranges[i] train_dataset = CustomImageDataset(annotations_file=models[0].dataset_file, base_dir=models[0].data_dir, diff --git a/vprtempo/src/dataset.py b/vprtempo/src/dataset.py index 976dbf2..209549d 100644 --- a/vprtempo/src/dataset.py +++ b/vprtempo/src/dataset.py @@ -163,7 +163,7 @@ def __init__(self, annotations_file, base_dir, img_dirs, transform=None, target_ self.target_transform = target_transform self.skip = skip self.img_range = img_range - + # Load image labels from each directory, apply the skip and max_samples, and concatenate self.img_labels = [] for img_dir in img_dirs: diff --git a/vprtempo/src/loggers.py b/vprtempo/src/loggers.py index 3081a02..d870de1 100644 --- a/vprtempo/src/loggers.py +++ b/vprtempo/src/loggers.py @@ -48,6 +48,8 @@ def model_logger(): logger.info('Current device is: CPU') logger.info('') + return logger + def model_logger_quant(): """ Configure the logger @@ -88,3 +90,5 @@ def model_logger_quant(): logger.info('Quantization enabled') logger.info('Current device is: CPU') logger.info('') + + return logger \ No newline at end of file diff --git a/vprtempo/src/metrics.py 
b/vprtempo/src/metrics.py index 9a85b84..4f2849a 100644 --- a/vprtempo/src/metrics.py +++ b/vprtempo/src/metrics.py @@ -149,9 +149,13 @@ def recallAtK(S_in, GThard, GTsoft=None, K=1): # ensure logical datatype in GT and GTsoft GT = GThard.astype('bool') + if GTsoft is not None: + GTsoft = GTsoft.astype('bool') # copy S and set elements that are only true in GTsoft to min(S) to ignore them during evaluation S = S_in.copy() + if GTsoft is not None: + S[GTsoft & ~GT] = S.min() # discard all query images without an actually matching database image j = GT.sum(0) > 0 # columns with matches From c4e5dff62d06455d074c1a7b338667f2c7bfa666 Mon Sep 17 00:00:00 2001 From: Adam Hines Date: Mon, 18 Dec 2023 14:36:03 +1000 Subject: [PATCH 3/8] Fixed up GT and matching, runs subsets for q image counts less than db --- docs/.gitignore | 5 ++ docs/404.html | 25 ++++++ docs/Gemfile | 35 ++++++++ docs/_config.yml | 55 +++++++++++++ .../2023-12-18-welcome-to-jekyll.markdown | 29 +++++++ docs/about.markdown | 18 +++++ docs/index.markdown | 6 ++ vprtempo/VPRTempo.py | 80 ++++++++++--------- 8 files changed, 217 insertions(+), 36 deletions(-) create mode 100644 docs/.gitignore create mode 100644 docs/404.html create mode 100644 docs/Gemfile create mode 100644 docs/_config.yml create mode 100644 docs/_posts/2023-12-18-welcome-to-jekyll.markdown create mode 100644 docs/about.markdown create mode 100644 docs/index.markdown diff --git a/docs/.gitignore b/docs/.gitignore new file mode 100644 index 0000000..f40fbd8 --- /dev/null +++ b/docs/.gitignore @@ -0,0 +1,5 @@ +_site +.sass-cache +.jekyll-cache +.jekyll-metadata +vendor diff --git a/docs/404.html b/docs/404.html new file mode 100644 index 0000000..086a5c9 --- /dev/null +++ b/docs/404.html @@ -0,0 +1,25 @@ +--- +permalink: /404.html +layout: default +--- + + + +
+<div class="container">
+  <h1>404</h1>
+
+  <p><strong>Page not found :(</strong></p>
+  <p>The requested page could not be found.</p>
+</div>
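
The module-partitioning arithmetic that PATCH 2 introduces in main.py is easy to misread in diff form, so the sketch below walks through the same remainder logic in isolation. The function name partition_places and the example figures (500 database places, a 150-place module cap) are hypothetical and chosen only for illustration; the patch itself inlines this logic in initialize_and_run_model.

def partition_places(database_places, max_module):
    """Split a place database across expert modules (mirrors PATCH 2's logic)."""
    # Count how many modules are needed to cover every place
    num_modules = 1
    places = database_places
    while places > max_module:
        places -= max_module
        num_modules += 1

    # If the split is uneven, the final module only holds the remainder
    remainder = database_places % max_module
    if remainder != 0:
        out_dim = (database_places - remainder) // (num_modules - 1)
        final_out_dim = remainder
    else:
        out_dim = database_places // num_modules
        final_out_dim = out_dim
    return num_modules, out_dim, final_out_dim

# 500 places capped at 150 per module -> three modules of 150 and a final module of 50
print(partition_places(500, 150))  # (4, 150, 50)
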
diff --git a/docs/Gemfile b/docs/Gemfile new file mode 100644 index 0000000..60c2833 --- /dev/null +++ b/docs/Gemfile @@ -0,0 +1,35 @@ +source "https://rubygems.org" +# Hello! This is where you manage which Jekyll version is used to run. +# When you want to use a different version, change it below, save the +# file and run `bundle install`. Run Jekyll with `bundle exec`, like so: +# +# bundle exec jekyll serve +# +# This will help ensure the proper Jekyll version is running. +# Happy Jekylling! +# gem "jekyll", "~> 4.3.2" +# This is the default theme for new Jekyll sites. You may change this to anything you like. +gem "minima", "~> 2.5" +# If you want to use GitHub Pages, remove the "gem "jekyll"" above and +# uncomment the line below. To upgrade, run `bundle update github-pages`. +# gem "github-pages", group: :jekyll_plugins +# If you have any plugins, put them here! +group :jekyll_plugins do + gem "jekyll-feed", "~> 0.12" +end + +# Windows and JRuby does not include zoneinfo files, so bundle the tzinfo-data gem +# and associated library. +platforms :mingw, :x64_mingw, :mswin, :jruby do + gem "tzinfo", ">= 1", "< 3" + gem "tzinfo-data" +end + +# Performance-booster for watching directories on Windows +gem "wdm", "~> 0.1.1", :platforms => [:mingw, :x64_mingw, :mswin] + +# Lock `http_parser.rb` gem to `v0.6.x` on JRuby builds since newer versions of the gem +# do not have a Java counterpart. +gem "http_parser.rb", "~> 0.6.0", :platforms => [:jruby] + +gem "github-pages", "~> 228", group: :jekyll_plugins diff --git a/docs/_config.yml b/docs/_config.yml new file mode 100644 index 0000000..ef7ba7c --- /dev/null +++ b/docs/_config.yml @@ -0,0 +1,55 @@ +# Welcome to Jekyll! +# +# This config file is meant for settings that affect your whole blog, values +# which you are expected to set up once and rarely edit after that. If you find +# yourself editing this file very often, consider using Jekyll's data files +# feature for the data you need to update frequently. +# +# For technical reasons, this file is *NOT* reloaded automatically when you use +# 'bundle exec jekyll serve'. If you change this file, please restart the server process. +# +# If you need help with YAML syntax, here are some quick references for you: +# https://learn-the-web.algonquindesign.ca/topics/markdown-yaml-cheat-sheet/#yaml +# https://learnxinyminutes.com/docs/yaml/ +# +# Site settings +# These are used to personalize your new site. If you look in the HTML files, +# you will see them accessed via {{ site.title }}, {{ site.email }}, and so on. +# You can create any custom variable you would like, and they will be accessible +# in the templates via {{ site.myvariable }}. + +title: Your awesome title +email: your-email@example.com +description: >- # this means to ignore newlines until "baseurl:" + Write an awesome description for your new site here. You can edit this + line in _config.yml. It will appear in your document head meta (for + Google search results) and in your feed.xml site description. +baseurl: "" # the subpath of your site, e.g. /blog +url: "" # the base hostname & protocol for your site, e.g. http://example.com +twitter_username: jekyllrb +github_username: jekyll + +# Build settings +theme: minima +plugins: + - jekyll-feed + +# Exclude from processing. +# The following items will not be processed, by default. +# Any item listed under the `exclude:` key here will be automatically added to +# the internal "default list". 
+# +# Excluded items can be processed by explicitly listing the directories or +# their entries' file path in the `include:` list. +# +# exclude: +# - .sass-cache/ +# - .jekyll-cache/ +# - gemfiles/ +# - Gemfile +# - Gemfile.lock +# - node_modules/ +# - vendor/bundle/ +# - vendor/cache/ +# - vendor/gems/ +# - vendor/ruby/ diff --git a/docs/_posts/2023-12-18-welcome-to-jekyll.markdown b/docs/_posts/2023-12-18-welcome-to-jekyll.markdown new file mode 100644 index 0000000..163a0a6 --- /dev/null +++ b/docs/_posts/2023-12-18-welcome-to-jekyll.markdown @@ -0,0 +1,29 @@ +--- +layout: post +title: "Welcome to Jekyll!" +date: 2023-12-18 10:49:10 +1000 +categories: jekyll update +--- +You’ll find this post in your `_posts` directory. Go ahead and edit it and re-build the site to see your changes. You can rebuild the site in many different ways, but the most common way is to run `jekyll serve`, which launches a web server and auto-regenerates your site when a file is updated. + +Jekyll requires blog post files to be named according to the following format: + +`YEAR-MONTH-DAY-title.MARKUP` + +Where `YEAR` is a four-digit number, `MONTH` and `DAY` are both two-digit numbers, and `MARKUP` is the file extension representing the format used in the file. After that, include the necessary front matter. Take a look at the source for this post to get an idea about how it works. + +Jekyll also offers powerful support for code snippets: + +{% highlight ruby %} +def print_hi(name) + puts "Hi, #{name}" +end +print_hi('Tom') +#=> prints 'Hi, Tom' to STDOUT. +{% endhighlight %} + +Check out the [Jekyll docs][jekyll-docs] for more info on how to get the most out of Jekyll. File all bugs/feature requests at [Jekyll’s GitHub repo][jekyll-gh]. If you have questions, you can ask them on [Jekyll Talk][jekyll-talk]. + +[jekyll-docs]: https://jekyllrb.com/docs/home +[jekyll-gh]: https://github.com/jekyll/jekyll +[jekyll-talk]: https://talk.jekyllrb.com/ diff --git a/docs/about.markdown b/docs/about.markdown new file mode 100644 index 0000000..8b4e0b2 --- /dev/null +++ b/docs/about.markdown @@ -0,0 +1,18 @@ +--- +layout: page +title: About +permalink: /about/ +--- + +This is the base Jekyll theme. You can find out more info about customizing your Jekyll theme, as well as basic Jekyll usage documentation at [jekyllrb.com](https://jekyllrb.com/) + +You can find the source code for Minima at GitHub: +[jekyll][jekyll-organization] / +[minima](https://github.com/jekyll/minima) + +You can find the source code for Jekyll at GitHub: +[jekyll][jekyll-organization] / +[jekyll](https://github.com/jekyll/jekyll) + + +[jekyll-organization]: https://github.com/jekyll diff --git a/docs/index.markdown b/docs/index.markdown new file mode 100644 index 0000000..0671507 --- /dev/null +++ b/docs/index.markdown @@ -0,0 +1,6 @@ +--- +# Feel free to add content and custom Front Matter to this file. 
+# To modify the layout, see https://jekyllrb.com/docs/themes/#overriding-theme-defaults + +layout: home +--- diff --git a/vprtempo/VPRTempo.py b/vprtempo/VPRTempo.py index a64f9b3..f3bf489 100644 --- a/vprtempo/VPRTempo.py +++ b/vprtempo/VPRTempo.py @@ -26,6 +26,7 @@ import os import torch +import random import numpy as np import torch.nn as nn @@ -130,7 +131,6 @@ def evaluate(self, models, test_loader, layers=None): # Initiliaze the output spikes variable out = [] labels = [] - num_corr = 0 # Run inference for the specified number of timesteps for spikes, label in test_loader: # Set device @@ -139,47 +139,46 @@ def evaluate(self, models, test_loader, layers=None): # Forward pass spikes = self.forward(spikes) # Add output spikes to list - #out.append(spikes.detach().cpu().tolist()) + out.append(spikes.detach().cpu().tolist()) pbar.update(1) - if torch.argmax(spikes) == label: - num_corr += 1 - print(f"Accuracy: {num_corr/len(test_loader)}") + # Close the tqdm progress bar pbar.close() # Rehsape output spikes into a similarity matrix - #out = np.reshape(np.array(out),(model.query_places,model.database_places)) + out = np.reshape(np.array(out),(model.query_places,model.database_places)) # Recall@N - # N = [1,5,10,15,20,25] # N values to calculate - # R = [] # Recall@N values + N = [1,5,10,15,20,25] # N values to calculate + R = [] # Recall@N values # Create GT matrix - # GT = np.zeros((model.query_places,model.database_places), dtype=int) - # for n, ndx in enumerate(labels): - # if model.filter !=1: - # ndx = ndx//model.filter - # GT[n,ndx] = 1 + GT = np.zeros((model.query_places,model.database_places), dtype=int) + for n, ndx in enumerate(labels): + if model.filter !=1: + ndx = ndx//model.filter + GT[n,ndx] = 1 + # Create GT soft matrix - # if model.GT_tolerance > 0: - # GTsoft = np.zeros((model.query_places, model.database_places), dtype=int) - # for n, ndx in enumerate(labels): - # if model.filter != 1: - # ndx = ndx // model.filter - # GTsoft[n, ndx] = 1 - # # Apply tolerance - # for i in range(max(0, n - model.GT_tolerance), min(model.query_places, n + model.GT_tolerance + 1)): - # GTsoft[i, ndx] = 1 - # else: - # GTsoft = None + if model.GT_tolerance > 0: + GTsoft = np.zeros((model.query_places,model.database_places), dtype=int) + for n, ndx in enumerate(labels): + if model.filter !=1: + ndx = ndx//model.filter + GTsoft[n, ndx] = 1 + # Apply tolerance + for i in range(max(0, n - model.GT_tolerance), min(model.query_places, n + model.GT_tolerance + 1)): + GTsoft[i, ndx] = 1 + else: + GTsoft = None # Calculate Recall@N - #for n in N: - # R.append(round(recallAtK(out,GThard=GT,GTsoft=GTsoft,K=n),2)) + for n in N: + R.append(round(recallAtK(out,GThard=GT,GTsoft=GTsoft,K=n),2)) # Print the results - #table = PrettyTable() - #table.field_names = ["N", "1", "5", "10", "15", "20", "25"] - #table.add_row(["Recall", R[0], R[1], R[2], R[3], R[4], R[5]]) - #self.logger.info(table) + table = PrettyTable() + table.field_names = ["N", "1", "5", "10", "15", "20", "25"] + table.add_row(["Recall", R[0], R[1], R[2], R[3], R[4], R[5]]) + self.logger.info(table) def forward(self, spikes): """ @@ -225,19 +224,29 @@ def run_inference(models, model_name): """ # Initialize the image transforms and datasets image_transform = ProcessImage(models[0].dims, models[0].patches) - max_samples=models[0].database_places + # Determine input range + if models[0].query_places == models[0].database_places: + max_samples=models[0].query_places + subset = False + elif models[0].query_places < models[0].database_places: + 
max_samples=models[0].database_places + subset = True + else: + raise ValueError("The number of query places must be less than or equal to the number of database places.") + # Initialize the test dataset test_dataset = CustomImageDataset(annotations_file=models[0].dataset_file, base_dir=models[0].data_dir, img_dirs=models[0].query_dir, transform=image_transform, skip=models[0].filter, max_samples=max_samples) - indices = torch.randperm(models[0].database_places).tolist() - subset_indicies = indices[:models[0].query_places] - subset = Subset(test_dataset, subset_indicies) + # If the number of query places is less than the number of database places, then subset the database + if subset: + test_dataset = Subset(test_dataset, random.sample(range(len(test_dataset)), models[0].query_places)) + # Initialize the data loader - test_loader = DataLoader(subset, + test_loader = DataLoader(test_dataset, batch_size=1, shuffle=models[0].shuffle, num_workers=8, @@ -245,7 +254,6 @@ def run_inference(models, model_name): # Load the model models[0].load_model(models, os.path.join('./vprtempo/models', model_name)) - # Retrieve layer names for inference layer_names = list(models[0].layer_dict.keys()) From d0ff33cd0f770acd948566e93c77d80fefc2ae3f Mon Sep 17 00:00:00 2001 From: Adam Hines Date: Tue, 19 Dec 2023 15:26:26 +1000 Subject: [PATCH 4/8] Altered model naming schema --- docs/404.html | 25 --------- docs/Gemfile | 35 ------------ docs/_config.yml | 55 ------------------- .../2023-12-18-welcome-to-jekyll.markdown | 29 ---------- docs/about.markdown | 18 ------ docs/index.markdown | 6 -- main.py | 14 ++--- vprtempo/VPRTempo.py | 38 +++++++++---- vprtempo/VPRTempoTrain.py | 10 ---- vprtempo/src/metrics.py | 17 +++--- 10 files changed, 42 insertions(+), 205 deletions(-) delete mode 100644 docs/404.html delete mode 100644 docs/Gemfile delete mode 100644 docs/_config.yml delete mode 100644 docs/_posts/2023-12-18-welcome-to-jekyll.markdown delete mode 100644 docs/about.markdown delete mode 100644 docs/index.markdown diff --git a/docs/404.html b/docs/404.html deleted file mode 100644 index 086a5c9..0000000 --- a/docs/404.html +++ /dev/null @@ -1,25 +0,0 @@ ---- -permalink: /404.html -layout: default ---- - - - -
-<div class="container">
-  <h1>404</h1>
-
-  <p><strong>Page not found :(</strong></p>
-  <p>The requested page could not be found.</p>
-</div>
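
The ground-truth handling that PATCH 3 settles on is worth restating outside the diff: the hard matrix marks each query's exact database index, while the soft matrix widens that match across neighbouring query rows so that matches within --GT_tolerance places are not counted as errors. Below is a minimal sketch, assuming the labels have already been mapped to database indices (divided by the frame filter, as evaluate() does); build_gt is a hypothetical name used only for illustration.

import numpy as np

def build_gt(labels, query_places, database_places, tolerance):
    """Hard and soft ground-truth matrices, mirroring evaluate() in this series."""
    GT = np.zeros((query_places, database_places), dtype=int)
    GTsoft = np.zeros((query_places, database_places), dtype=int) if tolerance > 0 else None
    for n, ndx in enumerate(labels):
        GT[n, ndx] = 1  # exact match for query n
        if GTsoft is not None:
            # Accept the same database place for neighbouring queries within the tolerance
            for i in range(max(0, n - tolerance), min(query_places, n + tolerance + 1)):
                GTsoft[i, ndx] = 1
    return GT, GTsoft

# Five sequential queries over five places with a tolerance of one place
GT, GTsoft = build_gt([0, 1, 2, 3, 4], 5, 5, tolerance=1)
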
diff --git a/docs/Gemfile b/docs/Gemfile deleted file mode 100644 index 60c2833..0000000 --- a/docs/Gemfile +++ /dev/null @@ -1,35 +0,0 @@ -source "https://rubygems.org" -# Hello! This is where you manage which Jekyll version is used to run. -# When you want to use a different version, change it below, save the -# file and run `bundle install`. Run Jekyll with `bundle exec`, like so: -# -# bundle exec jekyll serve -# -# This will help ensure the proper Jekyll version is running. -# Happy Jekylling! -# gem "jekyll", "~> 4.3.2" -# This is the default theme for new Jekyll sites. You may change this to anything you like. -gem "minima", "~> 2.5" -# If you want to use GitHub Pages, remove the "gem "jekyll"" above and -# uncomment the line below. To upgrade, run `bundle update github-pages`. -# gem "github-pages", group: :jekyll_plugins -# If you have any plugins, put them here! -group :jekyll_plugins do - gem "jekyll-feed", "~> 0.12" -end - -# Windows and JRuby does not include zoneinfo files, so bundle the tzinfo-data gem -# and associated library. -platforms :mingw, :x64_mingw, :mswin, :jruby do - gem "tzinfo", ">= 1", "< 3" - gem "tzinfo-data" -end - -# Performance-booster for watching directories on Windows -gem "wdm", "~> 0.1.1", :platforms => [:mingw, :x64_mingw, :mswin] - -# Lock `http_parser.rb` gem to `v0.6.x` on JRuby builds since newer versions of the gem -# do not have a Java counterpart. -gem "http_parser.rb", "~> 0.6.0", :platforms => [:jruby] - -gem "github-pages", "~> 228", group: :jekyll_plugins diff --git a/docs/_config.yml b/docs/_config.yml deleted file mode 100644 index ef7ba7c..0000000 --- a/docs/_config.yml +++ /dev/null @@ -1,55 +0,0 @@ -# Welcome to Jekyll! -# -# This config file is meant for settings that affect your whole blog, values -# which you are expected to set up once and rarely edit after that. If you find -# yourself editing this file very often, consider using Jekyll's data files -# feature for the data you need to update frequently. -# -# For technical reasons, this file is *NOT* reloaded automatically when you use -# 'bundle exec jekyll serve'. If you change this file, please restart the server process. -# -# If you need help with YAML syntax, here are some quick references for you: -# https://learn-the-web.algonquindesign.ca/topics/markdown-yaml-cheat-sheet/#yaml -# https://learnxinyminutes.com/docs/yaml/ -# -# Site settings -# These are used to personalize your new site. If you look in the HTML files, -# you will see them accessed via {{ site.title }}, {{ site.email }}, and so on. -# You can create any custom variable you would like, and they will be accessible -# in the templates via {{ site.myvariable }}. - -title: Your awesome title -email: your-email@example.com -description: >- # this means to ignore newlines until "baseurl:" - Write an awesome description for your new site here. You can edit this - line in _config.yml. It will appear in your document head meta (for - Google search results) and in your feed.xml site description. -baseurl: "" # the subpath of your site, e.g. /blog -url: "" # the base hostname & protocol for your site, e.g. http://example.com -twitter_username: jekyllrb -github_username: jekyll - -# Build settings -theme: minima -plugins: - - jekyll-feed - -# Exclude from processing. -# The following items will not be processed, by default. -# Any item listed under the `exclude:` key here will be automatically added to -# the internal "default list". 
-# -# Excluded items can be processed by explicitly listing the directories or -# their entries' file path in the `include:` list. -# -# exclude: -# - .sass-cache/ -# - .jekyll-cache/ -# - gemfiles/ -# - Gemfile -# - Gemfile.lock -# - node_modules/ -# - vendor/bundle/ -# - vendor/cache/ -# - vendor/gems/ -# - vendor/ruby/ diff --git a/docs/_posts/2023-12-18-welcome-to-jekyll.markdown b/docs/_posts/2023-12-18-welcome-to-jekyll.markdown deleted file mode 100644 index 163a0a6..0000000 --- a/docs/_posts/2023-12-18-welcome-to-jekyll.markdown +++ /dev/null @@ -1,29 +0,0 @@ ---- -layout: post -title: "Welcome to Jekyll!" -date: 2023-12-18 10:49:10 +1000 -categories: jekyll update ---- -You’ll find this post in your `_posts` directory. Go ahead and edit it and re-build the site to see your changes. You can rebuild the site in many different ways, but the most common way is to run `jekyll serve`, which launches a web server and auto-regenerates your site when a file is updated. - -Jekyll requires blog post files to be named according to the following format: - -`YEAR-MONTH-DAY-title.MARKUP` - -Where `YEAR` is a four-digit number, `MONTH` and `DAY` are both two-digit numbers, and `MARKUP` is the file extension representing the format used in the file. After that, include the necessary front matter. Take a look at the source for this post to get an idea about how it works. - -Jekyll also offers powerful support for code snippets: - -{% highlight ruby %} -def print_hi(name) - puts "Hi, #{name}" -end -print_hi('Tom') -#=> prints 'Hi, Tom' to STDOUT. -{% endhighlight %} - -Check out the [Jekyll docs][jekyll-docs] for more info on how to get the most out of Jekyll. File all bugs/feature requests at [Jekyll’s GitHub repo][jekyll-gh]. If you have questions, you can ask them on [Jekyll Talk][jekyll-talk]. - -[jekyll-docs]: https://jekyllrb.com/docs/home -[jekyll-gh]: https://github.com/jekyll/jekyll -[jekyll-talk]: https://talk.jekyllrb.com/ diff --git a/docs/about.markdown b/docs/about.markdown deleted file mode 100644 index 8b4e0b2..0000000 --- a/docs/about.markdown +++ /dev/null @@ -1,18 +0,0 @@ ---- -layout: page -title: About -permalink: /about/ ---- - -This is the base Jekyll theme. You can find out more info about customizing your Jekyll theme, as well as basic Jekyll usage documentation at [jekyllrb.com](https://jekyllrb.com/) - -You can find the source code for Minima at GitHub: -[jekyll][jekyll-organization] / -[minima](https://github.com/jekyll/minima) - -You can find the source code for Jekyll at GitHub: -[jekyll][jekyll-organization] / -[jekyll](https://github.com/jekyll/jekyll) - - -[jekyll-organization]: https://github.com/jekyll diff --git a/docs/index.markdown b/docs/index.markdown deleted file mode 100644 index 0671507..0000000 --- a/docs/index.markdown +++ /dev/null @@ -1,6 +0,0 @@ ---- -# Feel free to add content and custom Front Matter to this file. 
-# To modify the layout, see https://jekyllrb.com/docs/themes/#overriding-theme-defaults - -layout: home ---- diff --git a/main.py b/main.py index dbfcd99..ddde21e 100644 --- a/main.py +++ b/main.py @@ -33,19 +33,19 @@ from vprtempo.VPRTempo import VPRTempo, run_inference from vprtempo.src.loggers import model_logger, model_logger_quant from vprtempo.VPRTempoQuant import VPRTempoQuant, run_inference_quant -from vprtempo.VPRTempoQuantTrain import VPRTempoQuantTrain, generate_model_name_quant, train_new_model_quant from vprtempo.VPRTempoTrain import VPRTempoTrain, check_pretrained_model, train_new_model +from vprtempo.VPRTempoQuantTrain import VPRTempoQuantTrain, generate_model_name_quant, train_new_model_quant def generate_model_name(model): """ Generate the model name based on its parameters. """ - return ("VPRTempo" + - str(model.input) + - str(model.feature) + - str(model.database_places) + - ''.join(model.database_dirs) + - '.pth') + return (''.join(model.database_dirs)+"_"+ + "VPRTempo_" + + "IN"+str(model.input)+"_" + + "FN"+str(model.feature)+"_" + + "DB"+str(model.database_places) + + ".pth") def initialize_and_run_model(args,dims): # If user wants to train a new network diff --git a/vprtempo/VPRTempo.py b/vprtempo/VPRTempo.py index f3bf489..0089810 100644 --- a/vprtempo/VPRTempo.py +++ b/vprtempo/VPRTempo.py @@ -30,12 +30,13 @@ import numpy as np import torch.nn as nn +import matplotlib.pyplot as plt import vprtempo.src.blitnet as bn from tqdm import tqdm from prettytable import PrettyTable from torch.utils.data import DataLoader, Subset -from vprtempo.src.metrics import recallAtK +from vprtempo.src.metrics import recallAtK, createPR from vprtempo.src.dataset import CustomImageDataset, ProcessImage class VPRTempo(nn.Module): @@ -131,6 +132,7 @@ def evaluate(self, models, test_loader, layers=None): # Initiliaze the output spikes variable out = [] labels = [] + # Run inference for the specified number of timesteps for spikes, label in test_loader: # Set device @@ -144,26 +146,22 @@ def evaluate(self, models, test_loader, layers=None): # Close the tqdm progress bar pbar.close() - # Rehsape output spikes into a similarity matrix out = np.reshape(np.array(out),(model.query_places,model.database_places)) - # Recall@N - N = [1,5,10,15,20,25] # N values to calculate - R = [] # Recall@N values # Create GT matrix GT = np.zeros((model.query_places,model.database_places), dtype=int) for n, ndx in enumerate(labels): - if model.filter !=1: - ndx = ndx//model.filter + #if model.filter !=1: + # ndx = ndx//model.filter GT[n,ndx] = 1 # Create GT soft matrix if model.GT_tolerance > 0: GTsoft = np.zeros((model.query_places,model.database_places), dtype=int) for n, ndx in enumerate(labels): - if model.filter !=1: - ndx = ndx//model.filter + #if model.filter !=1: + # ndx = ndx//model.filter GTsoft[n, ndx] = 1 # Apply tolerance for i in range(max(0, n - model.GT_tolerance), min(model.query_places, n + model.GT_tolerance + 1)): @@ -171,6 +169,17 @@ def evaluate(self, models, test_loader, layers=None): else: GTsoft = None + # Create PR curve + P, R = createPR(out, GThard=GT, GTsoft=GTsoft, matching='single', n_thresh=100) + # Plot PR curve + plt.plot(R,P) + plt.xlabel('Recall') + plt.ylabel('Precision') + plt.title('Precision-Recall Curve') + plt.show() + # Recall@N + N = [1,5,10,15,20,25] # N values to calculate + R = [] # Recall@N values # Calculate Recall@N for n in N: R.append(round(recallAtK(out,GThard=GT,GTsoft=GTsoft,K=n),2)) @@ -239,11 +248,18 @@ def run_inference(models, model_name): 
base_dir=models[0].data_dir, img_dirs=models[0].query_dir, transform=image_transform, - skip=models[0].filter, max_samples=max_samples) # If the number of query places is less than the number of database places, then subset the database if subset: - test_dataset = Subset(test_dataset, random.sample(range(len(test_dataset)), models[0].query_places)) + if models[0].shuffle: + test_dataset = Subset(test_dataset, random.sample(range(len(test_dataset)), models[0].query_places)) + else: + # Generate indices with applied skip + indices = [i for i in range(models[0].database_places) if i % models[0].filter == 0] + # Limit to the desired number of queries + indices = indices[:models[0].query_places] + test_dataset = Subset(test_dataset, indices) + # Initialize the data loader test_loader = DataLoader(test_dataset, diff --git a/vprtempo/VPRTempoTrain.py b/vprtempo/VPRTempoTrain.py index 812b371..0a7049e 100644 --- a/vprtempo/VPRTempoTrain.py +++ b/vprtempo/VPRTempoTrain.py @@ -220,16 +220,6 @@ def save_model(self,models, model_out): torch.save(state_dicts, model_out) -def generate_model_name(model): - """ - Generate the model name based on its parameters. - """ - return ("VPRTempo" + - str(model.input) + - str(model.feature) + - str(model.output) + - ''.join(model.database_dirs) + - '.pth') def check_pretrained_model(model_name): """ diff --git a/vprtempo/src/metrics.py b/vprtempo/src/metrics.py index 4f2849a..9900fd4 100644 --- a/vprtempo/src/metrics.py +++ b/vprtempo/src/metrics.py @@ -47,20 +47,19 @@ def createPR(S_in, GThard, GTsoft=None, matching='multi', n_thresh=100): # copy S and set elements that are only true in GTsoft to min(S) to ignore them during evaluation S = S_in.copy() - S[S == 0] = np.nan if GTsoft is not None: S[GTsoft & ~GT] = S.min() # single-best-match or multi-match VPR if matching == 'single': - # GT-values for best match per query (i.e., per column) - GT = GT[np.nanargmax(S, axis=0), np.arange(GT.shape[1])] + # count the number of ground-truth positives (GTP) + GTP = np.count_nonzero(GT.any(0)) - # count the number of ground-truth positives (GTP) - GTP = np.count_nonzero(GT) + # GT-values for best match per query (i.e., per column) + GT = GT[np.argmax(S, axis=0), np.arange(GT.shape[1])] # similarities for best match per query (i.e., per column) - S = np.nanmax(S, axis=0) + S = np.max(S, axis=0) elif matching == 'multi': # count the number of ground-truth positives (GTP) @@ -71,8 +70,8 @@ def createPR(S_in, GThard, GTsoft=None, matching='multi', n_thresh=100): P = [1, ] # select start and end treshold - startV = np.nanmax(S) # start-value for treshold - endV = np.nanmin(S) # end-value for treshold + startV = S.max() # start-value for treshold + endV = S.min() # end-value for treshold # iterate over different thresholds for i in np.linspace(startV, endV, n_thresh): @@ -170,4 +169,4 @@ def recallAtK(S_in, GThard, GTsoft=None, K=1): # recall@K RatK = np.sum(GT.sum(0) > 0) / GT.shape[1] - return RatK + return RatK \ No newline at end of file From ab5f63cba921cd46413a4d471daea8685f66e321 Mon Sep 17 00:00:00 2001 From: Adam Hines Date: Thu, 21 Dec 2023 13:55:45 +1000 Subject: [PATCH 5/8] Fixed up some code redudancies, added in PR curve and sim mat plotting options --- .DS_Store | Bin 10244 -> 0 bytes main.py | 125 ++++++++++++++++++++-------------- tutorials/.DS_Store | Bin 6148 -> 0 bytes tutorials/mats/.DS_Store | Bin 6148 -> 0 bytes vprtempo/VPRTempo.py | 134 ++++++++++++++++++++++++------------- vprtempo/dataset/.DS_Store | Bin 14340 -> 0 bytes vprtempo/models/.DS_Store | Bin 
6148 -> 0 bytes vprtempo/src/loggers.py | 4 +- 8 files changed, 165 insertions(+), 98 deletions(-) delete mode 100644 .DS_Store delete mode 100644 tutorials/.DS_Store delete mode 100644 tutorials/mats/.DS_Store delete mode 100644 vprtempo/dataset/.DS_Store delete mode 100644 vprtempo/models/.DS_Store diff --git a/.DS_Store b/.DS_Store deleted file mode 100644 index b9183bcce3be679eb9994ae6b0c29813116d3f44..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 10244 zcmeHM&2Jk;6n~QhXH6QqNs~4RkXHDD)F{|#NELBl<2XX*qlvOnNt6_Kzv3=k@0i_n zLa`$G3j6{50o=Is2v-gq;e@y$4*UaLdf-&vd@Qxs2~+|Rux6r}-`kn@X7>H&WAnxU zK&l1uhMA#$35zolYiD^Vjmeq?M1uC>rtTO@8kZ4Nm3B(pmM3^E9 zQzAnQBFwSgkb0)XoDY>+kVZ|cXEp_J zfkFgF$U}2f-0gEI=k6qUIKL^MW1QhWhrj;U=XasoU0q&TG%9O1|73#M1Y$aOfYBK_kZ3B13u;0DWJC<-281;4st`-HX7unPq+X3P=p<%cO7c*F5(hW z=;gZlE!WjJeobO<2ero70o*^J)Vk;sX?blGKQraZ-#n3|HyLqkMa!z(oD^|B*ql(L zD8OTRbXMXSo`ubd<%7 zRHztG3@8Q^1B!uBU|?K0VX*uE&GUc%9|c#H4~hZBz&Q+%iJDcbVq+F?=k?fKyMg&G zW|o*fj_A?HA>e9?jDKXer~ V;8QN1QStv@+&JTn args.max_module: + places -= args.max_module + num_modules += 1 + + # If the final module has less than max_module, reduce the dim of the output layer + remainder = args.database_places % args.max_module + if remainder != 0: # There are remainders, adjust output neuron count in final module + out_dim = int((args.database_places - remainder) / (num_modules - 1)) + final_out_dim = remainder + else: # No remainders, all modules are even + out_dim = int(args.database_places / num_modules) + final_out_dim = out_dim + # If user wants to train a new network if args.train_new_model: # If using quantization aware training @@ -63,7 +110,7 @@ def initialize_and_run_model(args,dims): model.qconfig = qconfig models.append(model) # Generate the model name - model_name = generate_model_name_quant(model) + model_name = generate_model_name(model,args.quantize) # Check if the model has been trained before check_pretrained_model(model_name) # Train the model @@ -73,29 +120,6 @@ def initialize_and_run_model(args,dims): else: models = [] logger = model_logger() # Initialize the logger - places = args.database_places # Copy out number of database places - - # Determine how many modules the network needs to create - num_modules = 1 - while places > args.max_module: - places -= args.max_module - num_modules += 1 - - # If the final module has less than max_module, reduce the dim of the output layer - remainder = args.database_places % args.max_module - - # Check if number of modules and database images works - if args.filter * (((num_modules-1)*args.max_module)+remainder) > args.database_places: - print("Error: Too many modules or too few images for the given filter") - sys.exit() - - # Modify final module output layer neuron count according to remainder - if remainder != 0: # There are remainders, adjust output neuron count in final module - out_dim = int((args.database_places - remainder) / (num_modules - 1)) - final_out_dim = remainder - else: # No remainders, all modules are even - out_dim = int(args.database_places / num_modules) - final_out_dim = out_dim # Create the modules final_out = None @@ -135,28 +159,21 @@ def initialize_and_run_model(args,dims): run_inference_quant(models, model_name, qconfig) else: models = [] - logger = model_logger() # Initialize the logger + logger, output_folder = model_logger() # Initialize the logger places = args.database_places # Copy out number of database places - # Determine how many modules the network needs to create - num_modules = 1 - while places > 
args.max_module: - places -= args.max_module - num_modules += 1 - - # If the final module has less than max_module, reduce the dim of the output layer - remainder = args.database_places % args.max_module - if remainder != 0: # There are remainders, adjust output neuron count in final module - out_dim = int((args.database_places - remainder) / (num_modules - 1)) - final_out_dim = remainder - else: # No remainders, all modules are even - out_dim = int(args.database_places / num_modules) - final_out_dim = out_dim - # Create the modules final_out = None for mod in tqdm(range(num_modules), desc="Initializing modules"): - model = VPRTempo(args, dims, logger, num_modules, out_dim, out_dim_remainder=final_out) # Initialize the model + model = VPRTempo( + args, + dims, + logger, + num_modules, + output_folder, + out_dim, + out_dim_remainder=final_out + ) model.eval() model.to(torch.device('cpu')) # Move module to CPU for storage (necessary for large models) models.append(model) # Create module list @@ -212,6 +229,12 @@ def parse_network(use_quantize=False, train_new_model=False): parser.add_argument('--quantize', action='store_true', help="Enable/disable quantization for the model") + # Define metrics functionality + parser.add_argument('--PR_curve', action='store_true', + help="Flag to generate a Precision-Recall curve") + parser.add_argument('--sim_mat', action='store_true', + help="Flag to plot the similarity matrix, GT, and GTsoft") + # If the function is called with specific arguments, override sys.argv if use_quantize or train_new_model: sys.argv = [''] diff --git a/tutorials/.DS_Store b/tutorials/.DS_Store deleted file mode 100644 index 6cbab197bfe01d7a1bba7c4bacd81a6ed4589230..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeHKL5mYH6n@zq?bJo6P+>1Zz-w9C?y|CY8OM6qqmJl7rOxaW9h{w%&eXD73Uk&U zOHy%LiwGjF!+**m_fdTd_S9bnolU*_^$CjZ_$vDr-s@wf6wzj=1R~vrA-|=5YZ`3TR;%Zuk z@#H!8o+y>n-`P5Z=2lP62=()LPKyT!MY1 z#lT{15F-#_T!F?_*c3w;cj%?H3oOzN4v diff --git a/tutorials/mats/.DS_Store b/tutorials/mats/.DS_Store deleted file mode 100644 index b2b0f537229c8fdbfc2772807bd4fdbbb908cc6d..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeHKO;5r=5S;~-5@NzZ6ONmBB?@RT#!Ibu@M?`7)L<(_8*A54kOLv zT)wcfT6Bs|*}3-b)yN+NgJJ3g{YyGJQz{OpaS)zIy-}yKex#B?5GB37CPdvHmRwy# zNmq?LHB7px=K5yBDLJK1r8*wBn!7dG*x#Gf%SOHdGc@(fKm{VS!Tjo1u1z3RvDM0Ik zL??7DW(M`rfrD-V5DS>rhB5st5=UBeEoKIB22B_$qM-`gVhBUWytH+$#mu0ggRsqq zuvZqgLlOGwcz&tNLAVBaWCd7(p9)a#hgPBe|9F1?uZwtK1z3UqsemZ7{kDhCX7|>a xr=-1BqTisCQC(*6i-LjPim{fq;wri}%u5OoU5lAP%%Jg)fR=#=R$!qDd;$SiTIT=& diff --git a/vprtempo/VPRTempo.py b/vprtempo/VPRTempo.py index 0089810..691ffdf 100644 --- a/vprtempo/VPRTempo.py +++ b/vprtempo/VPRTempo.py @@ -25,6 +25,7 @@ ''' import os +import json import torch import random @@ -40,33 +41,41 @@ from vprtempo.src.dataset import CustomImageDataset, ProcessImage class VPRTempo(nn.Module): - def __init__(self, args, dims, logger, num_modules, out_dim, out_dim_remainder=None): + def __init__(self, args, dims, logger, num_modules, output_folder, out_dim, out_dim_remainder=None): super(VPRTempo, self).__init__() - # Set the arguments + # Set the args if args is not None: self.args = args for arg in vars(args): setattr(self, arg, getattr(args, arg)) setattr(self, 'dims', dims) + + # Set the device if torch.cuda.is_available(): self.device = "cuda:0" else: self.device = "cpu" + # Set input args self.logger = logger self.num_modules = num_modules + self.output_folder = output_folder 
         # Set the dataset file
         self.dataset_file = os.path.join('./vprtempo/dataset', self.dataset + '.csv')
         self.query_dir = [dir.strip() for dir in self.query_dir.split(',')]
+
         # Layer dict to keep track of layer names and their order
         self.layer_dict = {}
         self.layer_counter = 0
         self.database_dirs = [dir.strip() for dir in self.database_dirs.split(',')]
+
         # Define layer architecture
         self.input = int(self.dims[0]*self.dims[1])
         self.feature = int(self.input * 2)
+
+        # The output dimension changes for the final module when places are not evenly distributed
+        if out_dim_remainder is not None:
+            self.output = out_dim_remainder
+        else:
+            self.output = out_dim
@@ -107,12 +116,12 @@ def add_layer(self, name, **kwargs):
         self.layer_dict[name] = self.layer_counter
         self.layer_counter += 1
 
-    def evaluate(self, models, test_loader, layers=None):
+    def evaluate(self, models, test_loader):
         """
         Run the inferencing model and calculate the accuracy.
 
+        :param models: Models to run inference on, each model is a VPRTempo module
         :param test_loader: Testing data loader
-        :param layers: Layers to pass data through
         """
         # Initialize the tqdm progress bar
         pbar = tqdm(total=self.query_places,
@@ -136,7 +145,7 @@
         # Run inference for the specified number of timesteps
         for spikes, label in test_loader:
             # Set device
-            spikes, label = spikes.to(self.device), label.to(self.device)
+            spikes = spikes.to(self.device)
             labels.append(label.detach().cpu().item())
             # Forward pass
             spikes = self.forward(spikes)
@@ -152,31 +161,66 @@
         # Create GT matrix
         GT = np.zeros((model.query_places,model.database_places), dtype=int)
         for n, ndx in enumerate(labels):
-            #if model.filter !=1:
-            #    ndx = ndx//model.filter
+            if model.filter != 1:
+                ndx = ndx//model.filter
             GT[n,ndx] = 1
 
         # Create GT soft matrix
         if model.GT_tolerance > 0:
             GTsoft = np.zeros((model.query_places,model.database_places), dtype=int)
             for n, ndx in enumerate(labels):
-                #if model.filter !=1:
-                #    ndx = ndx//model.filter
+                if model.filter != 1:
+                    ndx = ndx//model.filter
                 GTsoft[n, ndx] = 1
                 # Apply tolerance
                 for i in range(max(0, n - model.GT_tolerance), min(model.query_places, n + model.GT_tolerance + 1)):
                     GTsoft[i, ndx] = 1
         else:
             GTsoft = None
+
+        # If user specified, generate a PR curve
+        if model.PR_curve:
+            # Create the PR curve
+            P, R = createPR(out, GThard=GT, GTsoft=GTsoft, matching='single', n_thresh=100)
+            # Combine P and R into a dictionary
+            PR_data = {
+                "Precision": P,
+                "Recall": R
+            }
+            output_file = "PR_curve_data.json"
+            # Construct the full path
+            full_path = f"{model.output_folder}/{output_file}"
+            # Write the data to a JSON file
+            with open(full_path, 'w') as file:
+                json.dump(PR_data, file)
+            # Plot the PR curve
+            plt.plot(R,P)
+            plt.xlabel('Recall')
+            plt.ylabel('Precision')
+            plt.title('Precision-Recall Curve')
+            plt.show()
+
+        if model.sim_mat:
+            # Create a figure and a set of subplots
+            fig, axs = plt.subplots(1, 3, figsize=(15, 5))
+
+            # Plot each matrix using matshow
+            cax1 = axs[0].matshow(out, cmap='viridis')
+            fig.colorbar(cax1, ax=axs[0], shrink=0.8)
+            axs[0].set_title('Similarity matrix')
+
+            cax2 = axs[1].matshow(GT, cmap='plasma')
+            fig.colorbar(cax2, ax=axs[1], shrink=0.8)
+            axs[1].set_title('GT')
+
+            cax3 = axs[2].matshow(GTsoft, cmap='inferno')
+            fig.colorbar(cax3, ax=axs[2], shrink=0.8)
+            axs[2].set_title('GT-soft')
+
+            # Adjust layout
+            plt.tight_layout()
+            plt.show()
 
-        # Create PR curve
-        P, R = createPR(out, GThard=GT, GTsoft=GTsoft, matching='single', n_thresh=100)
-        # Plot PR curve
-        plt.plot(R,P)
-        plt.xlabel('Recall')
-        plt.ylabel('Precision')
-        plt.title('Precision-Recall Curve')
-        plt.show()
 
         # Recall@N
         N = [1,5,10,15,20,25] # N values to calculate
         R = [] # Recall@N values
@@ -227,52 +271,52 @@ def run_inference(models, model_name):
     """
     Run inference on a pre-trained model.
 
-    :param model: Model to run inference on
+    :param models: Models to run inference on, each model is a VPRTempo module
     :param model_name: Name of the model to load
-    :param qconfig: Quantization configuration
     """
-    # Initialize the image transforms and datasets
-    image_transform = ProcessImage(models[0].dims, models[0].patches)
-
-    # Determine input range
-    if models[0].query_places == models[0].database_places:
-        max_samples=models[0].query_places
-        subset = False
-    elif models[0].query_places < models[0].database_places:
-        max_samples=models[0].database_places
-        subset = True
+    # Set the first model as the main model for parameters
+    model = models[0]
+    # Initialize the image transforms
+    image_transform = ProcessImage(model.dims, model.patches)
+
+    # Determine whether to query a subset of the database or the entire database
+    if model.query_places == model.database_places:
+        subset = False # Entire database
+    elif model.query_places < model.database_places:
+        subset = True # Subset of the database
     else:
         raise ValueError("The number of query places must be less than or equal to the number of database places.")
 
+    # Initialize the test dataset
-    test_dataset = CustomImageDataset(annotations_file=models[0].dataset_file,
-                                      base_dir=models[0].data_dir,
-                                      img_dirs=models[0].query_dir,
-                                      transform=image_transform,
-                                      max_samples=max_samples)
-    # If the number of query places is less than the number of database places, then subset the database
+    test_dataset = CustomImageDataset(annotations_file=model.dataset_file,
+                                      base_dir=model.data_dir,
+                                      img_dirs=model.query_dir,
+                                      transform=image_transform,
+                                      max_samples=model.database_places,
+                                      skip=model.filter)
+
+    # If using a subset of the database
     if subset:
-        if models[0].shuffle:
-            test_dataset = Subset(test_dataset, random.sample(range(len(test_dataset)), models[0].query_places))
-        else:
-            # Generate indices with applied skip
-            indices = [i for i in range(models[0].database_places) if i % models[0].filter == 0]
+        if model.shuffle: # For a randomized selection of database places
+            test_dataset = Subset(test_dataset, random.sample(range(len(test_dataset)), model.query_places))
+        else: # For a sequential selection of database places
+            indices = [i for i in range(model.database_places) if i % model.filter == 0]
             # Limit to the desired number of queries
-            indices = indices[:models[0].query_places]
+            indices = indices[:model.query_places]
+            # Create the subset
             test_dataset = Subset(test_dataset, indices)
 
     # Initialize the data loader
     test_loader = DataLoader(test_dataset,
                              batch_size=1,
-                             shuffle=models[0].shuffle,
+                             shuffle=model.shuffle,
                              num_workers=8,
                              persistent_workers=True)
     # Load the model
-    models[0].load_model(models, os.path.join('./vprtempo/models', model_name))
-    # Retrieve layer names for inference
-    layer_names = list(models[0].layer_dict.keys())
+    model.load_model(models, os.path.join('./vprtempo/models', model_name))
 
     # Use evaluate method for inference accuracy
     with torch.no_grad():
-        models[0].evaluate(models, test_loader, layers=layer_names)
\ No newline at end of file
+        model.evaluate(models, test_loader)
\ No newline at end of file
diff --git a/vprtempo/dataset/.DS_Store b/vprtempo/dataset/.DS_Store
deleted file mode 100644
index ea74ea872093a8a770552a84473018d23402d961..0000000000000000000000000000000000000000
Binary files a/vprtempo/dataset/.DS_Store and /dev/null differ
diff --git a/vprtempo/models/.DS_Store b/vprtempo/models/.DS_Store
deleted file mode 100644
index 28b3c56f896fa8c0c5ff67d4e21152c0cdbb6e3d..0000000000000000000000000000000000000000
Binary files a/vprtempo/models/.DS_Store and /dev/null differ
diff --git a/vprtempo/src/loggers.py b/vprtempo/src/loggers.py
index d870de1..8f7e5cb 100644
--- a/vprtempo/src/loggers.py
+++ b/vprtempo/src/loggers.py
@@ -48,7 +48,7 @@ def model_logger():
         logger.info('Current device is: CPU')
     logger.info('')
 
-    return logger
+    return logger, output_folder
 
 def model_logger_quant():
     """
@@ -91,4 +91,4 @@
         logger.info('Current device is: CPU')
     logger.info('')
 
-    return logger
\ No newline at end of file
+    return logger, output_folder
\ No newline at end of file

From 7e9abb68d04e7586733f9c441c314e507fb2f5ce Mon Sep 17 00:00:00 2001
From: Adam Hines
Date: Tue, 2 Jan 2024 11:54:22 +1000
Subject: [PATCH 6/8] Fixed up quantization with new module strategy, trained
 new models with new model names

---
 .gitignore                     |   1 -
 README.md                      |  24 ++--
 main.py                        |  44 +++++---
 vprtempo/VPRTempoQuant.py      | 194 ++++++++++++++++++++++++---------
 vprtempo/VPRTempoQuantTrain.py | 105 ++++++++++--------
 5 files changed, 245 insertions(+), 123 deletions(-)

diff --git a/.gitignore b/.gitignore
index 7baed65..9a70cc7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,4 +8,3 @@ vprtempo/dataset/winter/
 vprtempo/dataset/event.csv
 vprtempo/output/
 vprtempo/src/__pycache__/
-vprtempo/models/VPRTempo3136627250045.pth
diff --git a/README.md b/README.md
index a22dcc7..f66787e 100644
--- a/README.md
+++ b/README.md
@@ -64,22 +64,28 @@ If you wish to enable CUDA, please follow the instructions on the [PyTorch - Get
 Dependencies can be installed either through our provided `requirements.txt` files.
 
 ```python
-pip3 install -r requirements.txt
+pip install -r requirements.txt
 ```
 
 As above, if you wish to install CUDA please visit [PyTorch - Get Started](https://pytorch.org/get-started/locally/).
 
 ### Option 3: Conda install
 >**:heavy_exclamation_mark: Recommended:**
-> Use [Mambaforge](https://mamba.readthedocs.io/en/latest/installation.html) instead of conda.
+> Use [Mambaforge](https://mamba.readthedocs.io/en/latest/installation/mamba-installation.html) instead of conda.
+
+Requirements for VPRTempo may be installed using our [conda-forge package](https://anaconda.org/conda-forge/vprtempo).
 ```console
-# Windows/Linux - CUDA enabled
-conda create -n vprtempo -c pytorch -c nvidia python torchvision torchaudio pytorch-cuda=11.7 cudatoolkit prettytable tqdm numpy pandas scikit-learn
+# Linux/OS X
+conda create -n vprtempo -c conda-forge vprtempo
+
+# Linux CUDA enabled
+conda create -n vprtempo -c conda-forge -c pytorch -c nvidia vprtempo pytorch-cuda cudatoolkit
 
-# Windows/Linux - CPU only
-conda create -n vprtempo python pytorch torchvision torchaudio cpuonly prettytable tqdm numpy pandas scikit-learn -c pytorch
+# Windows
+conda create -n vprtempo -c pytorch python pytorch torchvision torchaudio cpuonly prettytable tqdm numpy pandas scikit-learn
+
+# Windows CUDA enabled
+conda create -n vprtempo -c pytorch -c nvidia python torchvision torchaudio pytorch-cuda=11.7 cudatoolkit prettytable tqdm numpy pandas scikit-learn
 
-# MacOS
-conda create -n vprtempo -c conda-forge python prettytable tqdm numpy pandas scikit-learn -c pytorch pytorch::pytorch torchvision torchaudio
 ```
 
 ## Datasets
@@ -142,8 +148,6 @@ python main.py --quantize
 
 Example of the quantized VPRTempo network running

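 As a usage sketch, the quantized network can also be combined with the query and metric options introduced in these patches (an illustrative invocation only; `--query_places`, `--PR_curve`, and `--sim_mat` are the flags defined in `parse_network`):
 
 ```console
 python main.py --quantize --query_places 250 --PR_curve --sim_mat
 ```
 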
-#### IDE
-You can also run VPRTempo through your IDE by running `main.py`. Change the `bool` flag for `use_quantize` to `True` if you wish to run VPRTempoQuant.
 
 ### Train new network
 If you do not wish to use the pretrained models, or you would like to train your own, you can pass the `--train_new_model` flag to `main.py`. Note that if a pretrained model already exists, you will be prompted before it is retrained.
diff --git a/main.py b/main.py
index 0e739e6..5300ae5 100644
--- a/main.py
+++ b/main.py
@@ -100,21 +100,26 @@ def initialize_and_run_model(args,dims):
     # If using quantization aware training
     if args.quantize:
         models = []
-        logger = model_logger_quant()
-        # Get the quantization config
+        logger = model_logger_quant() # Initialize the logger
         qconfig = quantization.get_default_qat_qconfig('fbgemm')
-        for _ in tqdm(range(args.num_modules), desc="Initializing modules"):
-            # Initialize the model
-            model = VPRTempoQuantTrain(args, dims, logger)
+        # Create the modules
+        final_out = None
+        for mod in tqdm(range(num_modules), desc="Initializing modules"):
+            model = VPRTempoQuantTrain(args, dims, logger, num_modules, out_dim, out_dim_remainder=final_out) # Initialize the model
             model.train()
             model.qconfig = qconfig
-            models.append(model)
+            quantization.prepare_qat(model, inplace=True)
+            models.append(model) # Create module list
+            if mod == num_modules - 2:
+                final_out = final_out_dim
         # Generate the model name
         model_name = generate_model_name(model,args.quantize)
         # Check if the model has been trained before
         check_pretrained_model(model_name)
+        # Get the quantization config
+        qconfig = quantization.get_default_qat_qconfig('fbgemm')
         # Train the model
-        train_new_model_quant(models, model_name, qconfig)
+        train_new_model_quant(models, model_name)
 
     # Base model
     else:
@@ -143,20 +148,29 @@ def initialize_and_run_model(args,dims):
     # Set the quantization configuration
     if args.quantize:
         models = []
-        logger = model_logger_quant()
+        logger, output_folder = model_logger_quant()
         qconfig = quantization.get_default_qat_qconfig('fbgemm')
-        for _ in tqdm(range(args.num_modules), desc="Initializing modules"):
+        final_out = None
+        for _ in tqdm(range(num_modules), desc="Initializing modules"):
             # Initialize the model
-            model = VPRTempoQuant(dims, args, logger)
+            model = VPRTempoQuant(
+                args,
+                dims,
+                logger,
+                num_modules,
+                output_folder,
+                out_dim,
+                out_dim_remainder=final_out
+            )
             model.eval()
             model.qconfig = qconfig
-            model = quantization.prepare(model, inplace=False)
-            model = quantization.convert(model, inplace=False)
+            quantization.prepare(model, inplace=True)
+            quantization.convert(model, inplace=True)
             models.append(model)
         # Generate the model name
-        model_name = generate_model_name_quant(model)
+        model_name = generate_model_name(model, args.quantize)
         # Run the quantized inference model
-        run_inference_quant(models, model_name, qconfig)
+        run_inference_quant(models, model_name)
     else:
         models = []
         logger, output_folder = model_logger() # Initialize the logger
@@ -208,7 +222,7 @@ def parse_network(use_quantize=False, train_new_model=False):
                             help="Directories to use for testing")
     parser.add_argument('--shuffle', action='store_true',
                             help="Shuffle input images during query")
-    parser.add_argument('--GT_tolerance', type=int, default=2,
+    parser.add_argument('--GT_tolerance', type=int, default=1,
                             help="Ground truth tolerance for matching")
 
     # Define training parameters
diff --git a/vprtempo/VPRTempoQuant.py b/vprtempo/VPRTempoQuant.py
index b667923..a752574 100644
--- a/vprtempo/VPRTempoQuant.py
+++ b/vprtempo/VPRTempoQuant.py
@@ -25,49 +25,64 @@
 '''
 
 import os
+import json
 import torch
+import random
 import numpy as np
 import torch.nn as nn
+import matplotlib.pyplot as plt
 import vprtempo.src.blitnet as bn
 
 from tqdm import tqdm
 from prettytable import PrettyTable
-from torch.utils.data import DataLoader
-from vprtempo.src.metrics import recallAtK
+from torch.utils.data import DataLoader, Subset
+from vprtempo.src.metrics import recallAtK, createPR
 from torch.ao.quantization import QuantStub, DeQuantStub
 from vprtempo.src.dataset import CustomImageDataset, ProcessImage
 #from main import parse_network
 
 class VPRTempoQuant(nn.Module):
-    def __init__(self, dims, args=None, logger=None):
+    def __init__(self, args, dims, logger, num_modules, output_folder, out_dim, out_dim_remainder=None):
         super(VPRTempoQuant, self).__init__()
 
-        # Set the arguments
-        self.args = args
-        for arg in vars(args):
-            setattr(self, arg, getattr(args, arg))
+        # Set the args
+        if args is not None:
+            self.args = args
+            for arg in vars(args):
+                setattr(self, arg, getattr(args, arg))
         setattr(self, 'dims', dims)
 
-        # Set the dataset file
-        self.dataset_file = os.path.join('./vprtempo/dataset', self.dataset + '.csv')
-        # Set the model logger and return the device
-        self.logger = logger
+        # Set the device
         self.device = "cpu"
 
-        # Add quantization stubs for Quantization Aware Training (QAT)
+        # Set input args
+        self.logger = logger
+        self.num_modules = num_modules
+        self.output_folder = output_folder
+
         self.quant = QuantStub()
-        self.dequant = DeQuantStub()
+        self.dequant = DeQuantStub()
+
+        # Set the dataset file
+        self.dataset_file = os.path.join('./vprtempo/dataset', self.dataset + '.csv')
+        self.query_dir = [dir.strip() for dir in self.query_dir.split(',')]
 
         # Layer dict to keep track of layer names and their order
         self.layer_dict = {}
         self.layer_counter = 0
+        self.database_dirs = [dir.strip() for dir in self.database_dirs.split(',')]
 
         # Define layer architecture
-        self.input = int(dims[0]*dims[1])
+        self.input = int(self.dims[0]*self.dims[1])
         self.feature = int(self.input * 2)
-        self.output = int(args.num_places / args.num_modules)
+
+        # The output dimension changes for the final module when places are not evenly distributed
+        if out_dim_remainder is not None:
+            self.output = out_dim_remainder
+        else:
+            self.output = out_dim
 
         """
         Define trainable layers here
@@ -125,15 +140,17 @@ def evaluate(self, models, test_loader, layers=None):
             nn.ReLU()
         ))
         # Initialize the tqdm progress bar
-        pbar = tqdm(total=self.num_places,
+        pbar = tqdm(total=self.query_places,
                     desc="Running the test network",
                     position=0)
         # Initialize the output spikes variable
         out = []
+        labels = []
         # Run inference for the specified number of timesteps
-        for spikes, labels in test_loader:
+        for spikes, label in test_loader:
             # Set device
-            spikes, labels = spikes.to(self.device), labels.to(self.device)
+            spikes = spikes.to(self.device)
+            labels.append(label.detach().item())
             # Pass through previous layers if they exist
             spikes = self.forward(spikes)
             # Add output spikes to list
@@ -144,21 +161,82 @@
         pbar.close()
 
         # Reshape output spikes into a similarity matrix
-        out = np.reshape(np.array(out),(self.num_places,self.num_places))
-        # Calculate and print the Recall@N
-        N = [1,5,10,15,20,25]
-        R = []
+        out = np.reshape(np.array(out),(model.query_places,model.database_places))
+
         # Create GT matrix
-        GT = np.zeros((self.num_places,self.num_places), dtype=int)
-        for n in range(len(GT)):
-            GT[n,n] = 1
+        GT = np.zeros((model.query_places,model.database_places), dtype=int)
+        for n, ndx in enumerate(labels):
+            if model.filter != 1:
+                ndx = ndx//model.filter
+            GT[n,ndx] = 1
+
+        # Create GT soft matrix
+        if model.GT_tolerance > 0:
+            GTsoft = np.zeros((model.query_places,model.database_places), dtype=int)
+            for n, ndx in enumerate(labels):
+                if model.filter != 1:
+                    ndx = ndx//model.filter
+                GTsoft[n, ndx] = 1
+                # Apply tolerance
+                for i in range(max(0, n - model.GT_tolerance), min(model.query_places, n + model.GT_tolerance + 1)):
+                    GTsoft[i, ndx] = 1
+        else:
+            GTsoft = None
+
+        # If user specified, generate a PR curve
+        if model.PR_curve:
+            # Create the PR curve
+            P, R = createPR(out, GThard=GT, GTsoft=GTsoft, matching='single', n_thresh=100)
+            # Combine P and R into a dictionary
+            PR_data = {
+                "Precision": P,
+                "Recall": R
+            }
+            output_file = "PR_curve_data.json"
+            # Construct the full path
+            full_path = f"{model.output_folder}/{output_file}"
+            # Write the data to a JSON file
+            with open(full_path, 'w') as file:
+                json.dump(PR_data, file)
+            # Plot the PR curve
+            plt.plot(R,P)
+            plt.xlabel('Recall')
+            plt.ylabel('Precision')
+            plt.title('Precision-Recall Curve')
+            plt.show()
+
+        if model.sim_mat:
+            # Create a figure and a set of subplots
+            fig, axs = plt.subplots(1, 3, figsize=(15, 5))
+
+            # Plot each matrix using matshow
+            cax1 = axs[0].matshow(out, cmap='viridis')
+            fig.colorbar(cax1, ax=axs[0], shrink=0.8)
+            axs[0].set_title('Similarity matrix')
+
+            cax2 = axs[1].matshow(GT, cmap='plasma')
+            fig.colorbar(cax2, ax=axs[1], shrink=0.8)
+            axs[1].set_title('GT')
+
+            cax3 = axs[2].matshow(GTsoft, cmap='inferno')
+            fig.colorbar(cax3, ax=axs[2], shrink=0.8)
+            axs[2].set_title('GT-soft')
+
+            # Adjust layout
+            plt.tight_layout()
+            plt.show()
+
+        # Recall@N
+        N = [1,5,10,15,20,25] # N values to calculate
+        R = [] # Recall@N values
+        # Calculate Recall@N
         for n in N:
-            R.append(recallAtK(out,GThard=GT,K=n))
+            R.append(round(recallAtK(out,GThard=GT,GTsoft=GTsoft,K=n),2))
         # Print the results
         table = PrettyTable()
         table.field_names = ["N", "1", "5", "10", "15", "20", "25"]
         table.add_row(["Recall", R[0], R[1], R[2], R[3], R[4], R[5]])
-        print(table)
+        self.logger.info(table)
 
     def forward(self, spikes):
         """
@@ -194,45 +272,57 @@ def load_model(self, models, model_path):
             model.load_state_dict(combined_state_dict[f'model_{i}'])
             model.eval()  # Set the model to inference mode
 
-
-def check_pretrained_model(model_name):
-    """
-    Check if a pre-trained model exists and tell user if it does not.
-    """
-    if not os.path.exists(os.path.join('./models', model_name)):
-        model.logger.info("A pre-trained network does not exist: please train one using VPRTempoQuant_Trainer")
-        pretrain = 'n'
-    else:
-        pretrain = 'y'
-    return pretrain
 
-def run_inference_quant(models, model_name, qconfig):
+def run_inference_quant(models, model_name):
     """
     Run inference on a pre-trained model.
-    :param model: Model to run inference on
+    :param models: Models to run inference on, each model is a VPRTempo module
     :param model_name: Name of the model to load
-    :param qconfig: Quantization configuration
     """
-    max_samples = models[0].num_places
-    # Initialize the image transforms and datasets
-    image_transform = ProcessImage(models[0].dims, models[0].patches)
-    test_dataset = CustomImageDataset(annotations_file=models[0].dataset_file,
-                                      base_dir=models[0].data_dir,
-                                      img_dirs=models[0].query_dir,
-                                      transform=image_transform,
-                                      skip=models[0].filter,
-                                      max_samples=max_samples)
+    # Set the first model as the main model for parameters
+    model = models[0]
+    # Initialize the image transforms
+    image_transform = ProcessImage(model.dims, model.patches)
+
+    # Determine whether to query a subset of the database or the entire database
+    if model.query_places == model.database_places:
+        subset = False # Entire database
+    elif model.query_places < model.database_places:
+        subset = True # Subset of the database
+    else:
+        raise ValueError("The number of query places must be less than or equal to the number of database places.")
+
+    # Initialize the test dataset
+    test_dataset = CustomImageDataset(annotations_file=model.dataset_file,
+                                      base_dir=model.data_dir,
+                                      img_dirs=model.query_dir,
+                                      transform=image_transform,
+                                      max_samples=model.database_places,
+                                      skip=model.filter)
+
+    # If using a subset of the database
+    if subset:
+        if model.shuffle: # For a randomized selection of database places
+            test_dataset = Subset(test_dataset, random.sample(range(len(test_dataset)), model.query_places))
+        else: # For a sequential selection of database places
+            indices = [i for i in range(model.database_places) if i % model.filter == 0]
+            # Limit to the desired number of queries
+            indices = indices[:model.query_places]
+            # Create the subset
+            test_dataset = Subset(test_dataset, indices)
+
+    # Initialize the data loader
     test_loader = DataLoader(test_dataset,
                              batch_size=1,
-                             shuffle=False,
+                             shuffle=model.shuffle,
                              num_workers=8,
                              persistent_workers=True)
     # Load the model
-    models[0].load_model(models, os.path.join('./vprtempo/models', model_name))
+    model.load_model(models, os.path.join('./vprtempo/models', model_name))
 
     # Use evaluate method for inference accuracy
     with torch.no_grad():
-        models[0].evaluate(models, test_loader)
\ No newline at end of file
+        model.evaluate(models, test_loader)
\ No newline at end of file
diff --git a/vprtempo/VPRTempoQuantTrain.py b/vprtempo/VPRTempoQuantTrain.py
index 451420a..ac320de 100644
--- a/vprtempo/VPRTempoQuantTrain.py
+++ b/vprtempo/VPRTempoQuantTrain.py
@@ -39,7 +39,7 @@ from vprtempo.src.dataset import CustomImageDataset, ProcessImage
 
 class VPRTempoQuantTrain(nn.Module):
-    def __init__(self, args, dims, logger):
+    def __init__(self, args, dims, logger, num_modules, out_dim, out_dim_remainder=None):
         super(VPRTempoQuantTrain, self).__init__()
 
         # Set the arguments
@@ -47,20 +47,16 @@
         for arg in vars(args):
             setattr(self, arg, getattr(args, arg))
         setattr(self, 'dims', dims)
-
-        # Only CPU available for quantization
-        self.device = "cpu"
 
-        # Set the logger
+        self.device = "cpu"
         self.logger = logger
+        self.num_modules = num_modules
+        self.quant = QuantStub()
+        self.dequant = DeQuantStub()
 
         # Set the dataset file
         self.dataset_file = os.path.join('./vprtempo/dataset', self.dataset + '.csv')
 
-        # Add quantization stubs for Quantization Aware Training (QAT)
-        self.quant = QuantStub()
-        self.dequant = DeQuantStub()
-
         # Layer dict to keep track of layer names and their order
         self.layer_dict = {}
         self.layer_counter = 0
 
@@ -68,11 +64,18 @@
         # Define layer architecture
         self.input = int(dims[0]*dims[1])
         self.feature = int(self.input * 2)
-        self.output = int(args.num_places / args.num_modules)
+        if out_dim_remainder is not None:
+            self.output = out_dim_remainder
+        else:
+            self.output = out_dim
 
         # Set the total timestep count
-        self.location_repeat = len(args.database_dirs) # Number of times to repeat the locations
-        self.T = int((self.num_places / self.num_modules) * self.location_repeat * self.epoch)
+        self.database_dirs = [dir.strip() for dir in self.database_dirs.split(',')]
+        self.location_repeat = len(self.database_dirs) # Number of times to repeat the locations
+        if out_dim_remainder is not None:
+            self.T = int(out_dim_remainder * self.location_repeat * self.epoch)
+        else:
+            self.T = int(self.max_module * self.location_repeat * self.epoch)
 
         """
         Define trainable layers here
@@ -226,60 +229,72 @@ def check_pretrained_model(model_name):
         return retrain == 'n'
     return False
 
-def train_new_model_quant(models, model_name, qconfig):
+def train_new_model_quant(models, model_name):
     """
     Train a new model.
 
     :param model: Model to train
    :param model_name: Name of the model to save after training
-    :param qconfig: Quantization configuration
     """
+    # Set the first model as the main model for parameters
+    model = models[0]
     # Initialize the image transforms and datasets
     image_transform = transforms.Compose([
-        ProcessImage(models[0].dims, models[0].patches)
+        ProcessImage(model.dims, model.patches)
     ])
     # Automatically generate user_input_ranges
     user_input_ranges = []
     start_idx = 0
-    for _ in range(models[0].num_modules):
-        range_temp = [start_idx, start_idx+((models[0].max_module-1)*models[0].filter)]
-        user_input_ranges.append(range_temp)
-        start_idx = range_temp[1] + models[0].filter
-    if models[0].num_places < models[0].max_module:
-        max_samples = models[0].num_places
-    else:
-        max_samples = models[0].max_module
+    for _ in range(model.num_modules):
+        range_temp = [start_idx, start_idx+((model.max_module-1)*model.filter)]
+        user_input_ranges.append(range_temp)
+        start_idx = range_temp[1] + models[0].filter
+    if model.query_places < model.max_module:
+        max_samples = model.query_places
+    else:
+        max_samples = model.max_module
 
     # Keep track of trained layers to pass data through them
+    trained_layers = []
+
     # Training each layer
-    trained_models = []
-    for i, model in enumerate(models):
-        trained_layers = []
-        img_range = user_input_ranges[i]
-        train_dataset = CustomImageDataset(annotations_file=models[0].dataset_file,
-                                           base_dir=models[0].data_dir,
-                                           img_dirs=models[0].database_dirs,
-                                           transform=image_transform,
-                                           skip=models[0].filter,
-                                           test=False,
-                                           img_range=img_range,
-                                           max_samples=max_samples)
-        # Initialize the data loader
-        train_loader = DataLoader(train_dataset,
-                                  batch_size=1,
-                                  shuffle=True,
-                                  num_workers=8,
-                                  persistent_workers=True)
-        model = quantization.prepare_qat(model, inplace=False)
-        for layer_name, _ in sorted(models[0].layer_dict.items(), key=lambda item: item[1]):
-            print(f"Training layer: {layer_name}")
+    for layer_name, _ in sorted(model.layer_dict.items(), key=lambda item: item[1]):
+        print(f"Training layer: {layer_name}")
+        # Retrieve the layer object
+        for i, model in enumerate(models):
+            model.train()
+            model.to(torch.device(model.device))
             layer = (getattr(model, layer_name))
-
+            # Determine the maximum samples for the DataLoader
+            if model.database_places < model.max_module:
+                max_samples = model.database_places
+            elif model.output < model.max_module:
+                max_samples = model.output
+            else:
+                max_samples = model.max_module
+            # Initialize a new dataset with a unique range for each module
+            img_range = user_input_ranges[i]
+            train_dataset = CustomImageDataset(annotations_file=model.dataset_file,
+                                               base_dir=model.data_dir,
+                                               img_dirs=model.database_dirs,
+                                               transform=image_transform,
+                                               skip=model.filter,
+                                               test=False,
+                                               img_range=img_range,
+                                               max_samples=max_samples)
+            # Initialize the data loader
+            train_loader = DataLoader(train_dataset,
+                                      batch_size=1,
+                                      shuffle=True,
+                                      num_workers=8,
+                                      persistent_workers=True)
             # Train the layers
             model.train_model(train_loader, layer, model, i, prev_layers=trained_layers)
             trained_layers.append(layer_name)
-        trained_models.append(quantization.convert(model, inplace=False))
-    # After training the current layer, add it to the list of trained layer
+
+    # Convert the model to evaluation mode
+    for model in models:
+        quantization.convert(model, inplace=True)
 
     # Save the model
-    model.save_model(trained_models,os.path.join('./vprtempo/models', model_name))
\ No newline at end of file
+    model.save_model(models,os.path.join('./vprtempo/models', model_name))
\ No newline at end of file

From 5832b774e0212445f288292679a45ed59e0894ad Mon Sep 17 00:00:00 2001
From: Adam Hines
Date: Tue, 2 Jan 2024 11:59:57 +1000
Subject: [PATCH 7/8] Touch-ups - readme and version number for v1.1.5

---
 README.md | 2 +-
 setup.py  | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index f66787e..3b13f4e 100644
--- a/README.md
+++ b/README.md
@@ -25,7 +25,7 @@ To use VPRTempo, please follow the instructions below for installation and usage
   - Quantization Aware Training (QAT) enabled to train weights in int8 space
   - Addition of tutorials in Jupyter Notebooks to learn how to use VPRTempo as well as explain the computational logic
   - Simplification of weight operations, reducing to a single weight tensor - allowing positive and negative connections to change sign during training
-  - Easier dependency installation with PyPI/pip
+  - Easier dependency installation with PyPI/pip and conda
   - And more!
 ## License & Citation
diff --git a/setup.py b/setup.py
index 112ab1f..c63fda9 100644
--- a/setup.py
+++ b/setup.py
@@ -21,7 +21,7 @@
 # define the setup
 setup(
     name="VPRTempo",
-    version="1.1.4",
+    version="1.1.5",
     description='VPRTempo: A Fast Temporally Encoded Spiking Neural Network for Visual Place Recognition',
     long_description=long_description,
     long_description_content_type='text/markdown',

From e958cda4cd5dcc762ba0456e1d78ba109938a4af Mon Sep 17 00:00:00 2001
From: Adam Hines
Date: Tue, 2 Jan 2024 12:21:35 +1000
Subject: [PATCH 8/8] Fixed issue in quant model with training layers across
 modules

---
 vprtempo/VPRTempoQuantTrain.py | 13 +++++--------
 1 file changed, 5 insertions(+), 8 deletions(-)

diff --git a/vprtempo/VPRTempoQuantTrain.py b/vprtempo/VPRTempoQuantTrain.py
index ac320de..d5b3c68 100644
--- a/vprtempo/VPRTempoQuantTrain.py
+++ b/vprtempo/VPRTempoQuantTrain.py
@@ -151,7 +151,7 @@ def train_model(self, train_loader, layer, model, model_num, prev_layers=None):
         # idx scale factor for different modules
         idx_scale = (self.max_module*self.filter)*model_num
         # Run training for the specified number of epochs
-        for epoch in range(self.epoch):
+        for _ in range(self.epoch):
             # Run training for the specified number of timesteps
             for spikes, labels in train_loader:
                 spikes, labels = spikes.to(self.device), labels.to(self.device)
@@ -245,15 +245,12 @@ def train_new_model_quant(models, model_name):
     # Automatically generate user_input_ranges
     user_input_ranges = []
     start_idx = 0
-
+    # Generate the image ranges for each module
     for _ in range(model.num_modules):
         range_temp = [start_idx, start_idx+((model.max_module-1)*model.filter)]
         user_input_ranges.append(range_temp)
-        start_idx = range_temp[1] + models[0].filter
-    if model.query_places < model.max_module:
-        max_samples = model.query_places
-    else:
-        max_samples = model.max_module
+        start_idx = range_temp[1] + model.filter
 
     # Keep track of trained layers to pass data through them
     trained_layers = []
@@ -290,7 +287,7 @@ def train_new_model_quant(models, model_name):
                                       persistent_workers=True)
             # Train the layers
             model.train_model(train_loader, layer, model, i, prev_layers=trained_layers)
-            trained_layers.append(layer_name)
+        trained_layers.append(layer_name)
 
     # Convert the model to evaluation mode
     for model in models:
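To make the module-splitting arithmetic that these patches build on concrete, here is a small standalone sketch (hypothetical values; it mirrors the remainder logic that `main.py` applies before constructing the modules):

```python
# Illustrative only: how database places are divided across expert modules.
# Values are hypothetical; the logic follows the remainder handling in main.py.
database_places = 3300
max_module = 500

# Count how many modules are needed to cover all places
num_modules = 1
places = database_places
while places > max_module:
    places -= max_module
    num_modules += 1

# If the places do not divide evenly, the final module takes the remainder
remainder = database_places % max_module
if remainder != 0:
    out_dim = int((database_places - remainder) / (num_modules - 1))  # 500 outputs per full module
    final_out_dim = remainder                                         # 300 outputs in the final module
else:
    out_dim = int(database_places / num_modules)
    final_out_dim = out_dim

print(num_modules, out_dim, final_out_dim)  # -> 7 500 300
```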