From 139edacfbaf62e799d610d5066478155bf6be02d Mon Sep 17 00:00:00 2001
From: Adam Hines
Date: Wed, 13 Dec 2023 16:02:39 +1000
Subject: [PATCH 1/8] Adding in query-specific places, not just the same as the
 model training amount
---
.gitignore | 1 +
main.py | 10 +++++++--
vprtempo/VPRTempo.py | 45 ++++++++++++++++++++++++++++++---------
vprtempo/VPRTempoTrain.py | 2 ++
4 files changed, 46 insertions(+), 12 deletions(-)
diff --git a/.gitignore b/.gitignore
index 9a70cc7..7baed65 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,3 +8,4 @@ vprtempo/dataset/winter/
vprtempo/dataset/event.csv
vprtempo/output/
vprtempo/src/__pycache__/
+vprtempo/models/VPRTempo3136627250045.pth
diff --git a/main.py b/main.py
index 753163e..31889d4 100644
--- a/main.py
+++ b/main.py
@@ -25,9 +25,11 @@
'''
import sys
import argparse
+import torch
import torch.quantization as quantization
+from tqdm import tqdm
from vprtempo.VPRTempo import VPRTempo, run_inference
from vprtempo.src.loggers import model_logger, model_logger_quant
from vprtempo.VPRTempoQuant import VPRTempoQuant, run_inference_quant
@@ -61,6 +63,7 @@ def initialize_and_run_model(args,dims):
for _ in range(args.num_modules):
# Initialize the model
model = VPRTempoTrain(args, dims, logger)
+ model.to(torch.device('cpu'))
models.append(model)
# Generate the model name
model_name = generate_model_name(model)
@@ -90,9 +93,10 @@ def initialize_and_run_model(args,dims):
else:
models = []
logger = model_logger()
- for _ in range(args.num_modules):
+ for _ in tqdm(range(args.num_modules), desc="Initializing modules"):
# Initialize the model
model = VPRTempo(dims, args, logger)
+ model.to(torch.device('cpu'))
models.append(model)
# Generate the model name
model_name = generate_model_name(model)
@@ -111,7 +115,9 @@ def parse_network(use_quantize=False, train_new_model=False):
parser.add_argument('--data_dir', type=str, default='./vprtempo/dataset/',
help="Directory where dataset files are stored")
parser.add_argument('--num_places', type=int, default=500,
- help="Number of places to use for training and/or inferencing")
+ help="Number of places to use for training")
+ parser.add_argument('--query_places', type=int, default=500,
+ help="Number of places to use for inferencing")
parser.add_argument('--num_modules', type=int, default=1,
help="Number of expert modules to split the images into")
parser.add_argument('--max_module', type=int, default=500,
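Editor's sketch of what the new flag buys (illustrative values, not from the patch): query count and trained-place count are now decoupled, so the evaluation matrix can be rectangular.

```python
from argparse import Namespace

# Illustrative values only; attribute names mirror the argparse flags above.
args = Namespace(num_places=500, query_places=100)

# The similarity matrix produced at evaluation time is now rectangular:
# one row per query, one column per trained place.
sim_shape = (args.query_places, args.num_places)
print(sim_shape)  # (100, 500)
```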
diff --git a/vprtempo/VPRTempo.py b/vprtempo/VPRTempo.py
index 6a0a2f0..0198cef 100644
--- a/vprtempo/VPRTempo.py
+++ b/vprtempo/VPRTempo.py
@@ -108,7 +108,7 @@ def evaluate(self, models, test_loader, layers=None):
:param layers: Layers to pass data through
"""
# Initialize the tqdm progress bar
- pbar = tqdm(total=self.num_places,
+ pbar = tqdm(total=self.query_places,
desc="Running the test network",
position=0)
self.inferences = []
@@ -121,12 +121,15 @@ def evaluate(self, models, test_loader, layers=None):
nn.Hardtanh(0, 0.9),
nn.ReLU()
))
+ self.inferences[-1].to(torch.device(self.device))
# Initialize the output spikes variable
out = []
+ labels = []
# Run inference for the specified number of timesteps
- for spikes, labels in test_loader:
+ for spikes, label in test_loader:
# Set device
- spikes, labels = spikes.to(self.device), labels.to(self.device)
+ spikes = spikes.to(self.device)
+ labels.append(label.detach().cpu().item())
# Forward pass
spikes = self.forward(spikes)
# Add output spikes to list
@@ -135,17 +138,17 @@ def evaluate(self, models, test_loader, layers=None):
# Close the tqdm progress bar
pbar.close()
-
+
# Reshape output spikes into a similarity matrix
- out = np.reshape(np.array(out),(model.num_places,model.num_places))
+ out = np.reshape(np.array(out),(model.query_places,model.num_places))
# Recall@N
N = [1,5,10,15,20,25] # N values to calculate
R = [] # Recall@N values
# Create GT matrix
- GT = np.zeros((model.num_places,model.num_places), dtype=int)
- for n in range(len(GT)):
- GT[n,n] = 1
+ GT = np.zeros((model.query_places,model.num_places), dtype=int)
+ for n, ndx in enumerate(labels):
+ GT[n,ndx] = 1
# Calculate Recall@N
for n in N:
R.append(round(recallAtK(out,GThard=GT,K=n),2))
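For reference, a self-contained sketch of the label-indexed ground truth built above, with toy sizes (recallAtK itself is not reimplemented here):

```python
import numpy as np

# Toy example: three queries against a five-place database.
labels = [3, 0, 2]                 # database index returned by the loader
GT = np.zeros((3, 5), dtype=int)   # (query_places, num_places)
for n, ndx in enumerate(labels):
    GT[n, ndx] = 1                 # mark the true match per query row
print(GT)
```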
@@ -155,6 +158,28 @@ def evaluate(self, models, test_loader, layers=None):
table.add_row(["Recall", R[0], R[1], R[2], R[3], R[4], R[5]])
print(table)
+ import matplotlib.pyplot as plt
+
+ # Increasing the figure size significantly
+ plt.figure(figsize=(20, 30)) # This is a large figure size, which can be adjusted as needed
+
+ # Plotting "Spiking output"
+ plt.subplot(2, 1, 1) # First plot
+ plt.imshow(out, aspect='auto', interpolation='nearest') # 'nearest' interpolation for clear points
+ plt.title("Spiking output")
+ plt.colorbar(shrink=0.5)
+ plt.grid(False) # Disabling grid lines
+
+ # Plotting "Ground truth"
+ plt.subplot(2, 1, 2) # Second plot
+ plt.imshow(GT, aspect='auto', interpolation='nearest')
+ plt.title("Ground truth")
+ plt.colorbar(shrink=0.5)
+ plt.grid(False) # Disabling grid lines
+
+ # Displaying the plots
+ plt.show()
+
def forward(self, spikes):
"""
Compute the forward pass of the model.
@@ -198,7 +223,7 @@ def run_inference(models, model_name):
"""
# Initialize the image transforms and datasets
image_transform = ProcessImage(models[0].dims, models[0].patches)
- max_samples=models[0].num_places
+ max_samples=models[0].query_places
test_dataset = CustomImageDataset(annotations_file=models[0].dataset_file,
base_dir=models[0].data_dir,
@@ -209,7 +234,7 @@ def run_inference(models, model_name):
# Initialize the data loader
test_loader = DataLoader(test_dataset,
batch_size=1,
- shuffle=False,
+ shuffle=True,
num_workers=8,
persistent_workers=True)
diff --git a/vprtempo/VPRTempoTrain.py b/vprtempo/VPRTempoTrain.py
index de3d449..e0bac5c 100644
--- a/vprtempo/VPRTempoTrain.py
+++ b/vprtempo/VPRTempoTrain.py
@@ -264,6 +264,7 @@ def train_new_model(models, model_name):
# Retrieve the layer object
for i, model in enumerate(models):
model.train()
+ model.to(torch.device(model.device))
layer = (getattr(model, layer_name))
img_range=user_input_ranges[i]
train_dataset = CustomImageDataset(annotations_file=models[0].dataset_file,
@@ -282,6 +283,7 @@ def train_new_model(models, model_name):
persistent_workers=True)
# Train the layers
model.train_model(train_loader, layer, model, i, prev_layers=trained_layers)
+ model.to(torch.device("cpu"))
# After training the current layer, add it to the list of trained layers
trained_layers.append(layer_name)
# Convert the model to evaluation mode
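The to(device)/to("cpu") pairing added in this patch follows a standard offload pattern; a minimal sketch with stand-in modules (the real layers are the VPRTempo expert modules):

```python
import torch
import torch.nn as nn

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
modules = [nn.Linear(8, 8) for _ in range(4)]  # stand-ins for expert modules

for m in modules:
    m.to(device)   # move one module onto the training device
    # ... train this module ...
    m.to("cpu")    # park it on CPU so the next module has memory headroom
```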
From f1c960346a749c0d91951d7f3528ff94810ced92 Mon Sep 17 00:00:00 2001
From: Adam Hines
Date: Fri, 15 Dec 2023 10:24:05 +1000
Subject: [PATCH 2/8] Added in GT tolerance and 4 full-dataset models; still need
 to fix up metrics and GT generation
---
main.py | 111 ++++++++++++++++++++++++++++++--------
vprtempo/VPRTempo.py | 97 +++++++++++++++++----------------
vprtempo/VPRTempoTrain.py | 45 +++++++++++-----
vprtempo/src/dataset.py | 2 +-
vprtempo/src/loggers.py | 4 ++
vprtempo/src/metrics.py | 4 ++
6 files changed, 180 insertions(+), 83 deletions(-)
diff --git a/main.py b/main.py
index 31889d4..dbfcd99 100644
--- a/main.py
+++ b/main.py
@@ -34,7 +34,18 @@
from vprtempo.src.loggers import model_logger, model_logger_quant
from vprtempo.VPRTempoQuant import VPRTempoQuant, run_inference_quant
from vprtempo.VPRTempoQuantTrain import VPRTempoQuantTrain, generate_model_name_quant, train_new_model_quant
-from vprtempo.VPRTempoTrain import VPRTempoTrain, generate_model_name, check_pretrained_model, train_new_model
+from vprtempo.VPRTempoTrain import VPRTempoTrain, check_pretrained_model, train_new_model
+
+def generate_model_name(model):
+ """
+ Generate the model name based on its parameters.
+ """
+ return ("VPRTempo" +
+ str(model.input) +
+ str(model.feature) +
+ str(model.database_places) +
+ ''.join(model.database_dirs) +
+ '.pth')
def initialize_and_run_model(args,dims):
# If user wants to train a new network
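A quick check of what the hoisted helper produces, assuming the repository's default 56x56 input (3136 input neurons, 6272 feature neurons, an assumption of this note) and the default directories:

```python
def demo_name(input_n, feature_n, places, dirs):
    # mirrors generate_model_name above
    return ("VPRTempo" + str(input_n) + str(feature_n) +
            str(places) + ''.join(dirs) + '.pth')

print(demo_name(3136, 6272, 500, ['spring', 'fall']))
# VPRTempo31366272500springfall.pth
```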
@@ -45,7 +56,7 @@ def initialize_and_run_model(args,dims):
logger = model_logger_quant()
# Get the quantization config
qconfig = quantization.get_default_qat_qconfig('fbgemm')
- for _ in range(args.num_modules):
+ for _ in tqdm(range(args.num_modules), desc="Initializing modules"):
# Initialize the model
model = VPRTempoQuantTrain(args, dims, logger)
model.train()
@@ -57,20 +68,52 @@ def initialize_and_run_model(args,dims):
check_pretrained_model(model_name)
# Train the model
train_new_model_quant(models, model_name, qconfig)
- else: # Normal model
+
+ # Base model
+ else:
models = []
- logger = model_logger()
- for _ in range(args.num_modules):
- # Initialize the model
- model = VPRTempoTrain(args, dims, logger)
- model.to(torch.device('cpu'))
- models.append(model)
+ logger = model_logger() # Initialize the logger
+ places = args.database_places # Copy out number of database places
+
+ # Determine how many modules the network needs to create
+ num_modules = 1
+ while places > args.max_module:
+ places -= args.max_module
+ num_modules += 1
+
+ # If the final module has less than max_module, reduce the dim of the output layer
+ remainder = args.database_places % args.max_module
+
+ # Check if number of modules and database images works
+ if args.filter * (((num_modules-1)*args.max_module)+remainder) > args.database_places:
+ print("Error: Too many modules or too few images for the given filter")
+ sys.exit()
+
+ # Modify final module output layer neuron count according to remainder
+ if remainder != 0: # There are remainders, adjust output neuron count in final module
+ out_dim = int((args.database_places - remainder) / (num_modules - 1))
+ final_out_dim = remainder
+ else: # No remainders, all modules are even
+ out_dim = int(args.database_places / num_modules)
+ final_out_dim = out_dim
+
+ # Create the modules
+ final_out = None
+ for mod in tqdm(range(num_modules), desc="Initializing modules"):
+ model = VPRTempoTrain(args, dims, logger, num_modules, out_dim, out_dim_remainder=final_out) # Initialize the model
+ model.to(torch.device('cpu')) # Move module to CPU for storage (necessary for large models)
+ models.append(model) # Create module list
+ if mod == num_modules - 2:
+ final_out = final_out_dim
+
# Generate the model name
model_name = generate_model_name(model)
+ print(f"Model name: {model_name}")
# Check if the model has been trained before
check_pretrained_model(model_name)
# Train the model
train_new_model(models, model_name)
+
# Run the inference network
else:
# Set the quantization configuration
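A worked run of the module-splitting arithmetic above, with assumed sizes:

```python
database_places, max_module = 1100, 500  # assumed sizes

num_modules, places = 1, database_places
while places > max_module:
    places -= max_module
    num_modules += 1                      # -> 3 modules

remainder = database_places % max_module  # -> 100
if remainder != 0:
    out_dim = (database_places - remainder) // (num_modules - 1)  # 500
    final_out_dim = remainder                                     # 100
else:
    out_dim = final_out_dim = database_places // num_modules

print(num_modules, out_dim, final_out_dim)  # 3 500 100
```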
@@ -78,7 +121,7 @@ def initialize_and_run_model(args,dims):
models = []
logger = model_logger_quant()
qconfig = quantization.get_default_qat_qconfig('fbgemm')
- for _ in range(args.num_modules):
+ for _ in tqdm(range(args.num_modules), desc="Initializing modules"):
# Initialize the model
model = VPRTempoQuant(dims, args, logger)
model.eval()
@@ -92,14 +135,36 @@ def initialize_and_run_model(args,dims):
run_inference_quant(models, model_name, qconfig)
else:
models = []
- logger = model_logger()
- for _ in tqdm(range(args.num_modules), desc="Initializing modules"):
- # Initialize the model
- model = VPRTempo(dims, args, logger)
- model.to(torch.device('cpu'))
- models.append(model)
+ logger = model_logger() # Initialize the logger
+ places = args.database_places # Copy out number of database places
+
+ # Determine how many modules the network needs to create
+ num_modules = 1
+ while places > args.max_module:
+ places -= args.max_module
+ num_modules += 1
+
+ # If the final module has less than max_module, reduce the dim of the output layer
+ remainder = args.database_places % args.max_module
+ if remainder != 0: # There are remainders, adjust output neuron count in final module
+ out_dim = int((args.database_places - remainder) / (num_modules - 1))
+ final_out_dim = remainder
+ else: # No remainders, all modules are even
+ out_dim = int(args.database_places / num_modules)
+ final_out_dim = out_dim
+
+ # Create the modules
+ final_out = None
+ for mod in tqdm(range(num_modules), desc="Initializing modules"):
+ model = VPRTempo(args, dims, logger, num_modules, out_dim, out_dim_remainder=final_out) # Initialize the model
+ model.eval()
+ model.to(torch.device('cpu')) # Move module to CPU for storage (necessary for large models)
+ models.append(model) # Create module list
+ if mod == num_modules - 2:
+ final_out = final_out_dim
# Generate the model name
model_name = generate_model_name(model)
+ print(f"Model name: {model_name}")
# Run the inference model
run_inference(models, model_name)
@@ -114,21 +179,23 @@ def parse_network(use_quantize=False, train_new_model=False):
help="Dataset to use for training and/or inferencing")
parser.add_argument('--data_dir', type=str, default='./vprtempo/dataset/',
help="Directory where dataset files are stored")
- parser.add_argument('--num_places', type=int, default=500,
+ parser.add_argument('--database_places', type=int, default=500,
help="Number of places to use for training")
parser.add_argument('--query_places', type=int, default=500,
help="Number of places to use for inferencing")
- parser.add_argument('--num_modules', type=int, default=1,
- help="Number of expert modules to split the images into")
parser.add_argument('--max_module', type=int, default=500,
help="Maximum number of images per module")
- parser.add_argument('--database_dirs', nargs='+', default=['spring', 'fall'],
+ parser.add_argument('--database_dirs', type=str, default='spring, fall',
help="Directories to use for training")
- parser.add_argument('--query_dir', nargs='+', default=['summer'],
+ parser.add_argument('--query_dir', type=str, default='summer',
help="Directories to use for testing")
+ parser.add_argument('--shuffle', action='store_true',
+ help="Shuffle input images during query")
+ parser.add_argument('--GT_tolerance', type=int, default=2,
+ help="Ground truth tolerance for matching")
# Define training parameters
- parser.add_argument('--filter', type=int, default=8,
+ parser.add_argument('--filter', type=int, default=1,
help="Images to skip for training and/or inferencing")
parser.add_argument('--epoch', type=int, default=4,
help="Number of epochs to train the model")
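Note the directory flags switch from nargs='+' lists to plain comma-separated strings; the models split them back out in their constructors (see the VPRTempo and VPRTempoTrain changes in this patch). A one-line sketch:

```python
database_dirs = 'spring, fall'                        # as passed on the CLI
dirs = [d.strip() for d in database_dirs.split(',')]  # ['spring', 'fall']
print(dirs)
```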
diff --git a/vprtempo/VPRTempo.py b/vprtempo/VPRTempo.py
index 0198cef..a64f9b3 100644
--- a/vprtempo/VPRTempo.py
+++ b/vprtempo/VPRTempo.py
@@ -33,12 +33,12 @@
from tqdm import tqdm
from prettytable import PrettyTable
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, Subset
from vprtempo.src.metrics import recallAtK
from vprtempo.src.dataset import CustomImageDataset, ProcessImage
class VPRTempo(nn.Module):
- def __init__(self, dims, args=None, logger=None):
+ def __init__(self, args, dims, logger, num_modules, out_dim, out_dim_remainder=None):
super(VPRTempo, self).__init__()
# Set the arguments
@@ -53,17 +53,22 @@ def __init__(self, dims, args=None, logger=None):
self.device = "cpu"
self.logger = logger
+ self.num_modules = num_modules
+
# Set the dataset file
self.dataset_file = os.path.join('./vprtempo/dataset', self.dataset + '.csv')
-
+ self.query_dir = [dir.strip() for dir in self.query_dir.split(',')]
# Layer dict to keep track of layer names and their order
self.layer_dict = {}
self.layer_counter = 0
-
+ self.database_dirs = [dir.strip() for dir in self.database_dirs.split(',')]
# Define layer architecture
self.input = int(self.dims[0]*self.dims[1])
self.feature = int(self.input * 2)
- self.output = int(args.num_places / args.num_modules)
+ if not out_dim_remainder is None:
+ self.output = out_dim_remainder
+ else:
+ self.output = out_dim
"""
Define trainable layers here
@@ -125,60 +130,56 @@ def evaluate(self, models, test_loader, layers=None):
# Initialize the output spikes variable
out = []
labels = []
+ num_corr = 0
# Run inference for the specified number of timesteps
for spikes, label in test_loader:
# Set device
- spikes = spikes.to(self.device)
+ spikes, label = spikes.to(self.device), label.to(self.device)
labels.append(label.detach().cpu().item())
# Forward pass
spikes = self.forward(spikes)
# Add output spikes to list
- out.append(spikes.detach().cpu().tolist())
+ #out.append(spikes.detach().cpu().tolist())
pbar.update(1)
-
+ if torch.argmax(spikes) == label:
+ num_corr += 1
+ print(f"Accuracy: {num_corr/len(test_loader)}")
# Close the tqdm progress bar
pbar.close()
# Reshape output spikes into a similarity matrix
- out = np.reshape(np.array(out),(model.query_places,model.num_places))
+ #out = np.reshape(np.array(out),(model.query_places,model.database_places))
# Recall@N
- N = [1,5,10,15,20,25] # N values to calculate
- R = [] # Recall@N values
+ # N = [1,5,10,15,20,25] # N values to calculate
+ # R = [] # Recall@N values
# Create GT matrix
- GT = np.zeros((model.query_places,model.num_places), dtype=int)
- for n, ndx in enumerate(labels):
- GT[n,ndx] = 1
+ # GT = np.zeros((model.query_places,model.database_places), dtype=int)
+ # for n, ndx in enumerate(labels):
+ # if model.filter !=1:
+ # ndx = ndx//model.filter
+ # GT[n,ndx] = 1
+ # Create GT soft matrix
+ # if model.GT_tolerance > 0:
+ # GTsoft = np.zeros((model.query_places, model.database_places), dtype=int)
+ # for n, ndx in enumerate(labels):
+ # if model.filter != 1:
+ # ndx = ndx // model.filter
+ # GTsoft[n, ndx] = 1
+ # # Apply tolerance
+ # for i in range(max(0, n - model.GT_tolerance), min(model.query_places, n + model.GT_tolerance + 1)):
+ # GTsoft[i, ndx] = 1
+ # else:
+ # GTsoft = None
+
# Calculate Recall@N
- for n in N:
- R.append(round(recallAtK(out,GThard=GT,K=n),2))
+ #for n in N:
+ # R.append(round(recallAtK(out,GThard=GT,GTsoft=GTsoft,K=n),2))
# Print the results
- table = PrettyTable()
- table.field_names = ["N", "1", "5", "10", "15", "20", "25"]
- table.add_row(["Recall", R[0], R[1], R[2], R[3], R[4], R[5]])
- print(table)
-
- import matplotlib.pyplot as plt
-
- # Increasing the figure size significantly
- plt.figure(figsize=(20, 30)) # This is a large figure size, which can be adjusted as needed
-
- # Plotting "Spiking output"
- plt.subplot(2, 1, 1) # First plot
- plt.imshow(out, aspect='auto', interpolation='nearest') # 'nearest' interpolation for clear points
- plt.title("Spiking output")
- plt.colorbar(shrink=0.5)
- plt.grid(False) # Disabling grid lines
-
- # Plotting "Ground truth"
- plt.subplot(2, 1, 2) # Second plot
- plt.imshow(GT, aspect='auto', interpolation='nearest')
- plt.title("Ground truth")
- plt.colorbar(shrink=0.5)
- plt.grid(False) # Disabling grid lines
-
- # Displaying the plots
- plt.show()
+ #table = PrettyTable()
+ #table.field_names = ["N", "1", "5", "10", "15", "20", "25"]
+ #table.add_row(["Recall", R[0], R[1], R[2], R[3], R[4], R[5]])
+ #self.logger.info(table)
def forward(self, spikes):
"""
@@ -190,10 +191,11 @@ def forward(self, spikes):
Returns:
- Tensor: Output after processing.
"""
-
+ # Initialize the output spikes tensor
in_spikes = spikes.detach().clone()
outputs = [] # List to collect output tensors
+ # Run inferencing across modules
for inference in self.inferences:
out_spikes = inference(in_spikes)
outputs.append(out_spikes) # Append the output tensor to the list
@@ -223,7 +225,7 @@ def run_inference(models, model_name):
"""
# Initialize the image transforms and datasets
image_transform = ProcessImage(models[0].dims, models[0].patches)
- max_samples=models[0].query_places
+ max_samples=models[0].database_places
test_dataset = CustomImageDataset(annotations_file=models[0].dataset_file,
base_dir=models[0].data_dir,
@@ -231,10 +233,13 @@ def run_inference(models, model_name):
transform=image_transform,
skip=models[0].filter,
max_samples=max_samples)
+ indices = torch.randperm(models[0].database_places).tolist()
+ subset_indices = indices[:models[0].query_places]
+ subset = Subset(test_dataset, subset_indices)
# Initialize the data loader
- test_loader = DataLoader(test_dataset,
+ test_loader = DataLoader(subset,
batch_size=1,
- shuffle=True,
+ shuffle=models[0].shuffle,
num_workers=8,
persistent_workers=True)
diff --git a/vprtempo/VPRTempoTrain.py b/vprtempo/VPRTempoTrain.py
index e0bac5c..812b371 100644
--- a/vprtempo/VPRTempoTrain.py
+++ b/vprtempo/VPRTempoTrain.py
@@ -27,6 +27,7 @@
import os
import gc
import torch
+import sys
import numpy as np
import torch.nn as nn
@@ -39,7 +40,7 @@
from vprtempo.src.dataset import CustomImageDataset, ProcessImage
class VPRTempoTrain(nn.Module):
- def __init__(self, args, dims, logger):
+ def __init__(self, args, dims, logger, num_modules, out_dim, out_dim_remainder=None):
super(VPRTempoTrain, self).__init__()
# Set the arguments
@@ -52,6 +53,7 @@ def __init__(self, args, dims, logger):
else:
self.device = "cpu"
self.logger = logger
+ self.num_modules = num_modules
# Set the dataset file
self.dataset_file = os.path.join('./vprtempo/dataset', self.dataset + '.csv')
@@ -63,11 +65,18 @@ def __init__(self, args, dims, logger):
# Define layer architecture
self.input = int(dims[0]*dims[1])
self.feature = int(self.input * 2)
- self.output = int(args.num_places / args.num_modules)
+ if not out_dim_remainder is None:
+ self.output = out_dim_remainder
+ else:
+ self.output = out_dim
# Set the total timestep count
- self.location_repeat = len(args.database_dirs) # Number of times to repeat the locations
- self.T = int((self.num_places/self.num_modules) * self.location_repeat * self.epoch)
+ self.database_dirs = [dir.strip() for dir in self.database_dirs.split(',')]
+ self.location_repeat = len(self.database_dirs) # Number of times to repeat the locations
+ if not out_dim_remainder is None:
+ self.T = int(out_dim_remainder * self.location_repeat * self.epoch)
+ else:
+ self.T = int(self.max_module * self.location_repeat * self.epoch)
"""
Define trainable layers here
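With the defaults above (max_module of 500, two database directories, four epochs, all assumed here), the timestep count works out as:

```python
max_module, epoch = 500, 4
database_dirs = ['spring', 'fall']
location_repeat = len(database_dirs)      # each place is seen once per traversal
T = max_module * location_repeat * epoch  # 500 * 2 * 4 = 4000 timesteps
print(T)
```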
@@ -138,7 +147,7 @@ def train_model(self, train_loader, layer, model, model_num, prev_layers=None):
"""
# Initialize the tqdm progress bar
- pbar = tqdm(total=int(self.T),
+ pbar = tqdm(total=self.T,
desc=f"Module {model_num+1}",
position=0)
@@ -219,18 +228,21 @@ def generate_model_name(model):
str(model.input) +
str(model.feature) +
str(model.output) +
- str(model.num_modules) +
+ ''.join(model.database_dirs) +
'.pth')
def check_pretrained_model(model_name):
"""
Check if a pre-trained model exists and prompt the user to retrain if desired.
"""
- if os.path.exists(os.path.join('./models', model_name)):
+ if os.path.exists(os.path.join('./vprtempo/models', model_name)):
prompt = "A network with these parameters exists, re-train network? (y/n):\n"
retrain = input(prompt).strip().lower()
- return retrain == 'n'
- return False
+ if retrain == 'y':
+ return True
+ elif retrain == 'n':
+ print('Training new model cancelled')
+ sys.exit()
def train_new_model(models, model_name):
"""
@@ -247,15 +259,12 @@ def train_new_model(models, model_name):
# Automatically generate user_input_ranges
user_input_ranges = []
start_idx = 0
-
+ # Generate the image ranges for each module
for _ in range(models[0].num_modules):
range_temp = [start_idx, start_idx+((models[0].max_module-1)*models[0].filter)]
user_input_ranges.append(range_temp)
start_idx = range_temp[1] + models[0].filter
- if models[0].num_places < models[0].max_module:
- max_samples=models[0].num_places
- else:
- max_samples = models[0].max_module
+
# Keep track of trained layers to pass data through them
trained_layers = []
# Training each layer
@@ -266,6 +275,14 @@ def train_new_model(models, model_name):
model.train()
model.to(torch.device(model.device))
layer = (getattr(model, layer_name))
+ # Determine the maximum samples for the DataLoader
+ if model.database_places < model.max_module:
+ max_samples = model.database_places
+ elif model.output < model.max_module:
+ max_samples = model.output
+ else:
+ max_samples = model.max_module
+ # Initialize new dataset with unique range for each module
img_range=user_input_ranges[i]
train_dataset = CustomImageDataset(annotations_file=models[0].dataset_file,
base_dir=models[0].data_dir,
diff --git a/vprtempo/src/dataset.py b/vprtempo/src/dataset.py
index 976dbf2..209549d 100644
--- a/vprtempo/src/dataset.py
+++ b/vprtempo/src/dataset.py
@@ -163,7 +163,7 @@ def __init__(self, annotations_file, base_dir, img_dirs, transform=None, target_
self.target_transform = target_transform
self.skip = skip
self.img_range = img_range
-
+
# Load image labels from each directory, apply the skip and max_samples, and concatenate
self.img_labels = []
for img_dir in img_dirs:
diff --git a/vprtempo/src/loggers.py b/vprtempo/src/loggers.py
index 3081a02..d870de1 100644
--- a/vprtempo/src/loggers.py
+++ b/vprtempo/src/loggers.py
@@ -48,6 +48,8 @@ def model_logger():
logger.info('Current device is: CPU')
logger.info('')
+ return logger
+
def model_logger_quant():
"""
Configure the logger
@@ -88,3 +90,5 @@ def model_logger_quant():
logger.info('Quantization enabled')
logger.info('Current device is: CPU')
logger.info('')
+
+ return logger
\ No newline at end of file
diff --git a/vprtempo/src/metrics.py b/vprtempo/src/metrics.py
index 9a85b84..4f2849a 100644
--- a/vprtempo/src/metrics.py
+++ b/vprtempo/src/metrics.py
@@ -149,9 +149,13 @@ def recallAtK(S_in, GThard, GTsoft=None, K=1):
# ensure logical datatype in GT and GTsoft
GT = GThard.astype('bool')
+ if GTsoft is not None:
+ GTsoft = GTsoft.astype('bool')
# copy S and set elements that are only true in GTsoft to min(S) to ignore them during evaluation
S = S_in.copy()
+ if GTsoft is not None:
+ S[GTsoft & ~GT] = S.min()
# discard all query images without an actually matching database image
j = GT.sum(0) > 0 # columns with matches
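A self-contained sketch of the soft-GT masking added above, with toy matrices: entries that are soft-only matches are floored to min(S), so a within-tolerance match is neither rewarded nor punished.

```python
import numpy as np

S = np.array([[0.9, 0.2],
              [0.8, 0.7]])
GT = np.array([[1, 0],
               [0, 1]], dtype=bool)
GTsoft = np.array([[1, 0],
                   [1, 1]], dtype=bool)  # tolerance widens column 0

S = S.copy()
S[GTsoft & ~GT] = S.min()  # S[1, 0] is soft-only and gets neutralized
print(S)
```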
From c4e5dff62d06455d074c1a7b338667f2c7bfa666 Mon Sep 17 00:00:00 2001
From: Adam Hines
Date: Mon, 18 Dec 2023 14:36:03 +1000
Subject: [PATCH 3/8] Fixed up GT and matching, runs subsets for query image
 counts less than database
---
docs/.gitignore | 5 ++
docs/404.html | 25 ++++++
docs/Gemfile | 35 ++++++++
docs/_config.yml | 55 +++++++++++++
.../2023-12-18-welcome-to-jekyll.markdown | 29 +++++++
docs/about.markdown | 18 +++++
docs/index.markdown | 6 ++
vprtempo/VPRTempo.py | 80 ++++++++++---------
8 files changed, 217 insertions(+), 36 deletions(-)
create mode 100644 docs/.gitignore
create mode 100644 docs/404.html
create mode 100644 docs/Gemfile
create mode 100644 docs/_config.yml
create mode 100644 docs/_posts/2023-12-18-welcome-to-jekyll.markdown
create mode 100644 docs/about.markdown
create mode 100644 docs/index.markdown
diff --git a/docs/.gitignore b/docs/.gitignore
new file mode 100644
index 0000000..f40fbd8
--- /dev/null
+++ b/docs/.gitignore
@@ -0,0 +1,5 @@
+_site
+.sass-cache
+.jekyll-cache
+.jekyll-metadata
+vendor
diff --git a/docs/404.html b/docs/404.html
new file mode 100644
index 0000000..086a5c9
--- /dev/null
+++ b/docs/404.html
@@ -0,0 +1,25 @@
+---
+permalink: /404.html
+layout: default
+---
+
+<style type="text/css" media="screen">
+  .container { margin: 10px auto; max-width: 600px; text-align: center; }
+  h1 { margin: 30px 0; font-size: 4em; line-height: 1; letter-spacing: -1px; }
+</style>
+
+<div class="container">
+  <h1>404</h1>
+
+  <p><strong>Page not found :(</strong></p>
+  <p>The requested page could not be found.</p>
+</div>
diff --git a/docs/Gemfile b/docs/Gemfile
new file mode 100644
index 0000000..60c2833
--- /dev/null
+++ b/docs/Gemfile
@@ -0,0 +1,35 @@
+source "https://rubygems.org"
+# Hello! This is where you manage which Jekyll version is used to run.
+# When you want to use a different version, change it below, save the
+# file and run `bundle install`. Run Jekyll with `bundle exec`, like so:
+#
+# bundle exec jekyll serve
+#
+# This will help ensure the proper Jekyll version is running.
+# Happy Jekylling!
+# gem "jekyll", "~> 4.3.2"
+# This is the default theme for new Jekyll sites. You may change this to anything you like.
+gem "minima", "~> 2.5"
+# If you want to use GitHub Pages, remove the "gem "jekyll"" above and
+# uncomment the line below. To upgrade, run `bundle update github-pages`.
+# gem "github-pages", group: :jekyll_plugins
+# If you have any plugins, put them here!
+group :jekyll_plugins do
+ gem "jekyll-feed", "~> 0.12"
+end
+
+# Windows and JRuby does not include zoneinfo files, so bundle the tzinfo-data gem
+# and associated library.
+platforms :mingw, :x64_mingw, :mswin, :jruby do
+ gem "tzinfo", ">= 1", "< 3"
+ gem "tzinfo-data"
+end
+
+# Performance-booster for watching directories on Windows
+gem "wdm", "~> 0.1.1", :platforms => [:mingw, :x64_mingw, :mswin]
+
+# Lock `http_parser.rb` gem to `v0.6.x` on JRuby builds since newer versions of the gem
+# do not have a Java counterpart.
+gem "http_parser.rb", "~> 0.6.0", :platforms => [:jruby]
+
+gem "github-pages", "~> 228", group: :jekyll_plugins
diff --git a/docs/_config.yml b/docs/_config.yml
new file mode 100644
index 0000000..ef7ba7c
--- /dev/null
+++ b/docs/_config.yml
@@ -0,0 +1,55 @@
+# Welcome to Jekyll!
+#
+# This config file is meant for settings that affect your whole blog, values
+# which you are expected to set up once and rarely edit after that. If you find
+# yourself editing this file very often, consider using Jekyll's data files
+# feature for the data you need to update frequently.
+#
+# For technical reasons, this file is *NOT* reloaded automatically when you use
+# 'bundle exec jekyll serve'. If you change this file, please restart the server process.
+#
+# If you need help with YAML syntax, here are some quick references for you:
+# https://learn-the-web.algonquindesign.ca/topics/markdown-yaml-cheat-sheet/#yaml
+# https://learnxinyminutes.com/docs/yaml/
+#
+# Site settings
+# These are used to personalize your new site. If you look in the HTML files,
+# you will see them accessed via {{ site.title }}, {{ site.email }}, and so on.
+# You can create any custom variable you would like, and they will be accessible
+# in the templates via {{ site.myvariable }}.
+
+title: Your awesome title
+email: your-email@example.com
+description: >- # this means to ignore newlines until "baseurl:"
+ Write an awesome description for your new site here. You can edit this
+ line in _config.yml. It will appear in your document head meta (for
+ Google search results) and in your feed.xml site description.
+baseurl: "" # the subpath of your site, e.g. /blog
+url: "" # the base hostname & protocol for your site, e.g. http://example.com
+twitter_username: jekyllrb
+github_username: jekyll
+
+# Build settings
+theme: minima
+plugins:
+ - jekyll-feed
+
+# Exclude from processing.
+# The following items will not be processed, by default.
+# Any item listed under the `exclude:` key here will be automatically added to
+# the internal "default list".
+#
+# Excluded items can be processed by explicitly listing the directories or
+# their entries' file path in the `include:` list.
+#
+# exclude:
+# - .sass-cache/
+# - .jekyll-cache/
+# - gemfiles/
+# - Gemfile
+# - Gemfile.lock
+# - node_modules/
+# - vendor/bundle/
+# - vendor/cache/
+# - vendor/gems/
+# - vendor/ruby/
diff --git a/docs/_posts/2023-12-18-welcome-to-jekyll.markdown b/docs/_posts/2023-12-18-welcome-to-jekyll.markdown
new file mode 100644
index 0000000..163a0a6
--- /dev/null
+++ b/docs/_posts/2023-12-18-welcome-to-jekyll.markdown
@@ -0,0 +1,29 @@
+---
+layout: post
+title: "Welcome to Jekyll!"
+date: 2023-12-18 10:49:10 +1000
+categories: jekyll update
+---
+You’ll find this post in your `_posts` directory. Go ahead and edit it and re-build the site to see your changes. You can rebuild the site in many different ways, but the most common way is to run `jekyll serve`, which launches a web server and auto-regenerates your site when a file is updated.
+
+Jekyll requires blog post files to be named according to the following format:
+
+`YEAR-MONTH-DAY-title.MARKUP`
+
+Where `YEAR` is a four-digit number, `MONTH` and `DAY` are both two-digit numbers, and `MARKUP` is the file extension representing the format used in the file. After that, include the necessary front matter. Take a look at the source for this post to get an idea about how it works.
+
+Jekyll also offers powerful support for code snippets:
+
+{% highlight ruby %}
+def print_hi(name)
+ puts "Hi, #{name}"
+end
+print_hi('Tom')
+#=> prints 'Hi, Tom' to STDOUT.
+{% endhighlight %}
+
+Check out the [Jekyll docs][jekyll-docs] for more info on how to get the most out of Jekyll. File all bugs/feature requests at [Jekyll’s GitHub repo][jekyll-gh]. If you have questions, you can ask them on [Jekyll Talk][jekyll-talk].
+
+[jekyll-docs]: https://jekyllrb.com/docs/home
+[jekyll-gh]: https://github.com/jekyll/jekyll
+[jekyll-talk]: https://talk.jekyllrb.com/
diff --git a/docs/about.markdown b/docs/about.markdown
new file mode 100644
index 0000000..8b4e0b2
--- /dev/null
+++ b/docs/about.markdown
@@ -0,0 +1,18 @@
+---
+layout: page
+title: About
+permalink: /about/
+---
+
+This is the base Jekyll theme. You can find out more info about customizing your Jekyll theme, as well as basic Jekyll usage documentation at [jekyllrb.com](https://jekyllrb.com/)
+
+You can find the source code for Minima at GitHub:
+[jekyll][jekyll-organization] /
+[minima](https://github.com/jekyll/minima)
+
+You can find the source code for Jekyll at GitHub:
+[jekyll][jekyll-organization] /
+[jekyll](https://github.com/jekyll/jekyll)
+
+
+[jekyll-organization]: https://github.com/jekyll
diff --git a/docs/index.markdown b/docs/index.markdown
new file mode 100644
index 0000000..0671507
--- /dev/null
+++ b/docs/index.markdown
@@ -0,0 +1,6 @@
+---
+# Feel free to add content and custom Front Matter to this file.
+# To modify the layout, see https://jekyllrb.com/docs/themes/#overriding-theme-defaults
+
+layout: home
+---
diff --git a/vprtempo/VPRTempo.py b/vprtempo/VPRTempo.py
index a64f9b3..f3bf489 100644
--- a/vprtempo/VPRTempo.py
+++ b/vprtempo/VPRTempo.py
@@ -26,6 +26,7 @@
import os
import torch
+import random
import numpy as np
import torch.nn as nn
@@ -130,7 +131,6 @@ def evaluate(self, models, test_loader, layers=None):
# Initialize the output spikes variable
out = []
labels = []
- num_corr = 0
# Run inference for the specified number of timesteps
for spikes, label in test_loader:
# Set device
@@ -139,47 +139,46 @@ def evaluate(self, models, test_loader, layers=None):
# Forward pass
spikes = self.forward(spikes)
# Add output spikes to list
- #out.append(spikes.detach().cpu().tolist())
+ out.append(spikes.detach().cpu().tolist())
pbar.update(1)
- if torch.argmax(spikes) == label:
- num_corr += 1
- print(f"Accuracy: {num_corr/len(test_loader)}")
+
# Close the tqdm progress bar
pbar.close()
# Reshape output spikes into a similarity matrix
- #out = np.reshape(np.array(out),(model.query_places,model.database_places))
+ out = np.reshape(np.array(out),(model.query_places,model.database_places))
# Recall@N
- # N = [1,5,10,15,20,25] # N values to calculate
- # R = [] # Recall@N values
+ N = [1,5,10,15,20,25] # N values to calculate
+ R = [] # Recall@N values
# Create GT matrix
- # GT = np.zeros((model.query_places,model.database_places), dtype=int)
- # for n, ndx in enumerate(labels):
- # if model.filter !=1:
- # ndx = ndx//model.filter
- # GT[n,ndx] = 1
+ GT = np.zeros((model.query_places,model.database_places), dtype=int)
+ for n, ndx in enumerate(labels):
+ if model.filter !=1:
+ ndx = ndx//model.filter
+ GT[n,ndx] = 1
+
# Create GT soft matrix
- # if model.GT_tolerance > 0:
- # GTsoft = np.zeros((model.query_places, model.database_places), dtype=int)
- # for n, ndx in enumerate(labels):
- # if model.filter != 1:
- # ndx = ndx // model.filter
- # GTsoft[n, ndx] = 1
- # # Apply tolerance
- # for i in range(max(0, n - model.GT_tolerance), min(model.query_places, n + model.GT_tolerance + 1)):
- # GTsoft[i, ndx] = 1
- # else:
- # GTsoft = None
+ if model.GT_tolerance > 0:
+ GTsoft = np.zeros((model.query_places,model.database_places), dtype=int)
+ for n, ndx in enumerate(labels):
+ if model.filter !=1:
+ ndx = ndx//model.filter
+ GTsoft[n, ndx] = 1
+ # Apply tolerance
+ for i in range(max(0, n - model.GT_tolerance), min(model.query_places, n + model.GT_tolerance + 1)):
+ GTsoft[i, ndx] = 1
+ else:
+ GTsoft = None
# Calculate Recall@N
- #for n in N:
- # R.append(round(recallAtK(out,GThard=GT,GTsoft=GTsoft,K=n),2))
+ for n in N:
+ R.append(round(recallAtK(out,GThard=GT,GTsoft=GTsoft,K=n),2))
# Print the results
- #table = PrettyTable()
- #table.field_names = ["N", "1", "5", "10", "15", "20", "25"]
- #table.add_row(["Recall", R[0], R[1], R[2], R[3], R[4], R[5]])
- #self.logger.info(table)
+ table = PrettyTable()
+ table.field_names = ["N", "1", "5", "10", "15", "20", "25"]
+ table.add_row(["Recall", R[0], R[1], R[2], R[3], R[4], R[5]])
+ self.logger.info(table)
def forward(self, spikes):
"""
@@ -225,19 +224,29 @@ def run_inference(models, model_name):
"""
# Initialize the image transforms and datasets
image_transform = ProcessImage(models[0].dims, models[0].patches)
- max_samples=models[0].database_places
+ # Determine input range
+ if models[0].query_places == models[0].database_places:
+ max_samples=models[0].query_places
+ subset = False
+ elif models[0].query_places < models[0].database_places:
+ max_samples=models[0].database_places
+ subset = True
+ else:
+ raise ValueError("The number of query places must be less than or equal to the number of database places.")
+ # Initialize the test dataset
test_dataset = CustomImageDataset(annotations_file=models[0].dataset_file,
base_dir=models[0].data_dir,
img_dirs=models[0].query_dir,
transform=image_transform,
skip=models[0].filter,
max_samples=max_samples)
- indices = torch.randperm(models[0].database_places).tolist()
- subset_indices = indices[:models[0].query_places]
- subset = Subset(test_dataset, subset_indices)
+ # If the number of query places is less than the number of database places, then subset the database
+ if subset:
+ test_dataset = Subset(test_dataset, random.sample(range(len(test_dataset)), models[0].query_places))
+
# Initialize the data loader
- test_loader = DataLoader(subset,
+ test_loader = DataLoader(test_dataset,
batch_size=1,
shuffle=models[0].shuffle,
num_workers=8,
@@ -245,7 +254,6 @@ def run_inference(models, model_name):
# Load the model
models[0].load_model(models, os.path.join('./vprtempo/models', model_name))
-
# Retrieve layer names for inference
layer_names = list(models[0].layer_dict.keys())
From d0ff33cd0f770acd948566e93c77d80fefc2ae3f Mon Sep 17 00:00:00 2001
From: Adam Hines
Date: Tue, 19 Dec 2023 15:26:26 +1000
Subject: [PATCH 4/8] Altered model naming schema
---
docs/404.html | 25 ---------
docs/Gemfile | 35 ------------
docs/_config.yml | 55 -------------------
.../2023-12-18-welcome-to-jekyll.markdown | 29 ----------
docs/about.markdown | 18 ------
docs/index.markdown | 6 --
main.py | 14 ++---
vprtempo/VPRTempo.py | 38 +++++++++----
vprtempo/VPRTempoTrain.py | 10 ----
vprtempo/src/metrics.py | 17 +++---
10 files changed, 42 insertions(+), 205 deletions(-)
delete mode 100644 docs/404.html
delete mode 100644 docs/Gemfile
delete mode 100644 docs/_config.yml
delete mode 100644 docs/_posts/2023-12-18-welcome-to-jekyll.markdown
delete mode 100644 docs/about.markdown
delete mode 100644 docs/index.markdown
diff --git a/docs/404.html b/docs/404.html
deleted file mode 100644
index 086a5c9..0000000
--- a/docs/404.html
+++ /dev/null
@@ -1,25 +0,0 @@
----
-permalink: /404.html
-layout: default
----
-<style type="text/css" media="screen">
-  .container { margin: 10px auto; max-width: 600px; text-align: center; }
-  h1 { margin: 30px 0; font-size: 4em; line-height: 1; letter-spacing: -1px; }
-</style>
-
-<div class="container">
-  <h1>404</h1>
-
-  <p><strong>Page not found :(</strong></p>
-  <p>The requested page could not be found.</p>
-</div>
diff --git a/docs/Gemfile b/docs/Gemfile
deleted file mode 100644
index 60c2833..0000000
--- a/docs/Gemfile
+++ /dev/null
@@ -1,35 +0,0 @@
-source "https://rubygems.org"
-# Hello! This is where you manage which Jekyll version is used to run.
-# When you want to use a different version, change it below, save the
-# file and run `bundle install`. Run Jekyll with `bundle exec`, like so:
-#
-# bundle exec jekyll serve
-#
-# This will help ensure the proper Jekyll version is running.
-# Happy Jekylling!
-# gem "jekyll", "~> 4.3.2"
-# This is the default theme for new Jekyll sites. You may change this to anything you like.
-gem "minima", "~> 2.5"
-# If you want to use GitHub Pages, remove the "gem "jekyll"" above and
-# uncomment the line below. To upgrade, run `bundle update github-pages`.
-# gem "github-pages", group: :jekyll_plugins
-# If you have any plugins, put them here!
-group :jekyll_plugins do
- gem "jekyll-feed", "~> 0.12"
-end
-
-# Windows and JRuby does not include zoneinfo files, so bundle the tzinfo-data gem
-# and associated library.
-platforms :mingw, :x64_mingw, :mswin, :jruby do
- gem "tzinfo", ">= 1", "< 3"
- gem "tzinfo-data"
-end
-
-# Performance-booster for watching directories on Windows
-gem "wdm", "~> 0.1.1", :platforms => [:mingw, :x64_mingw, :mswin]
-
-# Lock `http_parser.rb` gem to `v0.6.x` on JRuby builds since newer versions of the gem
-# do not have a Java counterpart.
-gem "http_parser.rb", "~> 0.6.0", :platforms => [:jruby]
-
-gem "github-pages", "~> 228", group: :jekyll_plugins
diff --git a/docs/_config.yml b/docs/_config.yml
deleted file mode 100644
index ef7ba7c..0000000
--- a/docs/_config.yml
+++ /dev/null
@@ -1,55 +0,0 @@
-# Welcome to Jekyll!
-#
-# This config file is meant for settings that affect your whole blog, values
-# which you are expected to set up once and rarely edit after that. If you find
-# yourself editing this file very often, consider using Jekyll's data files
-# feature for the data you need to update frequently.
-#
-# For technical reasons, this file is *NOT* reloaded automatically when you use
-# 'bundle exec jekyll serve'. If you change this file, please restart the server process.
-#
-# If you need help with YAML syntax, here are some quick references for you:
-# https://learn-the-web.algonquindesign.ca/topics/markdown-yaml-cheat-sheet/#yaml
-# https://learnxinyminutes.com/docs/yaml/
-#
-# Site settings
-# These are used to personalize your new site. If you look in the HTML files,
-# you will see them accessed via {{ site.title }}, {{ site.email }}, and so on.
-# You can create any custom variable you would like, and they will be accessible
-# in the templates via {{ site.myvariable }}.
-
-title: Your awesome title
-email: your-email@example.com
-description: >- # this means to ignore newlines until "baseurl:"
- Write an awesome description for your new site here. You can edit this
- line in _config.yml. It will appear in your document head meta (for
- Google search results) and in your feed.xml site description.
-baseurl: "" # the subpath of your site, e.g. /blog
-url: "" # the base hostname & protocol for your site, e.g. http://example.com
-twitter_username: jekyllrb
-github_username: jekyll
-
-# Build settings
-theme: minima
-plugins:
- - jekyll-feed
-
-# Exclude from processing.
-# The following items will not be processed, by default.
-# Any item listed under the `exclude:` key here will be automatically added to
-# the internal "default list".
-#
-# Excluded items can be processed by explicitly listing the directories or
-# their entries' file path in the `include:` list.
-#
-# exclude:
-# - .sass-cache/
-# - .jekyll-cache/
-# - gemfiles/
-# - Gemfile
-# - Gemfile.lock
-# - node_modules/
-# - vendor/bundle/
-# - vendor/cache/
-# - vendor/gems/
-# - vendor/ruby/
diff --git a/docs/_posts/2023-12-18-welcome-to-jekyll.markdown b/docs/_posts/2023-12-18-welcome-to-jekyll.markdown
deleted file mode 100644
index 163a0a6..0000000
--- a/docs/_posts/2023-12-18-welcome-to-jekyll.markdown
+++ /dev/null
@@ -1,29 +0,0 @@
----
-layout: post
-title: "Welcome to Jekyll!"
-date: 2023-12-18 10:49:10 +1000
-categories: jekyll update
----
-You’ll find this post in your `_posts` directory. Go ahead and edit it and re-build the site to see your changes. You can rebuild the site in many different ways, but the most common way is to run `jekyll serve`, which launches a web server and auto-regenerates your site when a file is updated.
-
-Jekyll requires blog post files to be named according to the following format:
-
-`YEAR-MONTH-DAY-title.MARKUP`
-
-Where `YEAR` is a four-digit number, `MONTH` and `DAY` are both two-digit numbers, and `MARKUP` is the file extension representing the format used in the file. After that, include the necessary front matter. Take a look at the source for this post to get an idea about how it works.
-
-Jekyll also offers powerful support for code snippets:
-
-{% highlight ruby %}
-def print_hi(name)
- puts "Hi, #{name}"
-end
-print_hi('Tom')
-#=> prints 'Hi, Tom' to STDOUT.
-{% endhighlight %}
-
-Check out the [Jekyll docs][jekyll-docs] for more info on how to get the most out of Jekyll. File all bugs/feature requests at [Jekyll’s GitHub repo][jekyll-gh]. If you have questions, you can ask them on [Jekyll Talk][jekyll-talk].
-
-[jekyll-docs]: https://jekyllrb.com/docs/home
-[jekyll-gh]: https://github.com/jekyll/jekyll
-[jekyll-talk]: https://talk.jekyllrb.com/
diff --git a/docs/about.markdown b/docs/about.markdown
deleted file mode 100644
index 8b4e0b2..0000000
--- a/docs/about.markdown
+++ /dev/null
@@ -1,18 +0,0 @@
----
-layout: page
-title: About
-permalink: /about/
----
-
-This is the base Jekyll theme. You can find out more info about customizing your Jekyll theme, as well as basic Jekyll usage documentation at [jekyllrb.com](https://jekyllrb.com/)
-
-You can find the source code for Minima at GitHub:
-[jekyll][jekyll-organization] /
-[minima](https://github.com/jekyll/minima)
-
-You can find the source code for Jekyll at GitHub:
-[jekyll][jekyll-organization] /
-[jekyll](https://github.com/jekyll/jekyll)
-
-
-[jekyll-organization]: https://github.com/jekyll
diff --git a/docs/index.markdown b/docs/index.markdown
deleted file mode 100644
index 0671507..0000000
--- a/docs/index.markdown
+++ /dev/null
@@ -1,6 +0,0 @@
----
-# Feel free to add content and custom Front Matter to this file.
-# To modify the layout, see https://jekyllrb.com/docs/themes/#overriding-theme-defaults
-
-layout: home
----
diff --git a/main.py b/main.py
index dbfcd99..ddde21e 100644
--- a/main.py
+++ b/main.py
@@ -33,19 +33,19 @@
from vprtempo.VPRTempo import VPRTempo, run_inference
from vprtempo.src.loggers import model_logger, model_logger_quant
from vprtempo.VPRTempoQuant import VPRTempoQuant, run_inference_quant
-from vprtempo.VPRTempoQuantTrain import VPRTempoQuantTrain, generate_model_name_quant, train_new_model_quant
from vprtempo.VPRTempoTrain import VPRTempoTrain, check_pretrained_model, train_new_model
+from vprtempo.VPRTempoQuantTrain import VPRTempoQuantTrain, generate_model_name_quant, train_new_model_quant
def generate_model_name(model):
"""
Generate the model name based on its parameters.
"""
- return ("VPRTempo" +
- str(model.input) +
- str(model.feature) +
- str(model.database_places) +
- ''.join(model.database_dirs) +
- '.pth')
+ return (''.join(model.database_dirs)+"_"+
+ "VPRTempo_" +
+ "IN"+str(model.input)+"_" +
+ "FN"+str(model.feature)+"_" +
+ "DB"+str(model.database_places) +
+ ".pth")
def initialize_and_run_model(args,dims):
# If user wants to train a new network
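With the same assumed defaults as before (56x56 input, 500 database places, spring plus fall), the revised schema yields a more readable name:

```python
def demo_name(dirs, input_n, feature_n, places):
    # mirrors the revised generate_model_name above
    return (''.join(dirs) + "_VPRTempo_" +
            "IN" + str(input_n) + "_" +
            "FN" + str(feature_n) + "_" +
            "DB" + str(places) + ".pth")

print(demo_name(['spring', 'fall'], 3136, 6272, 500))
# springfall_VPRTempo_IN3136_FN6272_DB500.pth
```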
diff --git a/vprtempo/VPRTempo.py b/vprtempo/VPRTempo.py
index f3bf489..0089810 100644
--- a/vprtempo/VPRTempo.py
+++ b/vprtempo/VPRTempo.py
@@ -30,12 +30,13 @@
import numpy as np
import torch.nn as nn
+import matplotlib.pyplot as plt
import vprtempo.src.blitnet as bn
from tqdm import tqdm
from prettytable import PrettyTable
from torch.utils.data import DataLoader, Subset
-from vprtempo.src.metrics import recallAtK
+from vprtempo.src.metrics import recallAtK, createPR
from vprtempo.src.dataset import CustomImageDataset, ProcessImage
class VPRTempo(nn.Module):
@@ -131,6 +132,7 @@ def evaluate(self, models, test_loader, layers=None):
# Initialize the output spikes variable
out = []
labels = []
+
# Run inference for the specified number of timesteps
for spikes, label in test_loader:
# Set device
@@ -144,26 +146,22 @@ def evaluate(self, models, test_loader, layers=None):
# Close the tqdm progress bar
pbar.close()
-
# Reshape output spikes into a similarity matrix
out = np.reshape(np.array(out),(model.query_places,model.database_places))
- # Recall@N
- N = [1,5,10,15,20,25] # N values to calculate
- R = [] # Recall@N values
# Create GT matrix
GT = np.zeros((model.query_places,model.database_places), dtype=int)
for n, ndx in enumerate(labels):
- if model.filter !=1:
- ndx = ndx//model.filter
+ #if model.filter !=1:
+ # ndx = ndx//model.filter
GT[n,ndx] = 1
# Create GT soft matrix
if model.GT_tolerance > 0:
GTsoft = np.zeros((model.query_places,model.database_places), dtype=int)
for n, ndx in enumerate(labels):
- if model.filter !=1:
- ndx = ndx//model.filter
+ #if model.filter !=1:
+ # ndx = ndx//model.filter
GTsoft[n, ndx] = 1
# Apply tolerance
for i in range(max(0, n - model.GT_tolerance), min(model.query_places, n + model.GT_tolerance + 1)):
@@ -171,6 +169,17 @@ def evaluate(self, models, test_loader, layers=None):
else:
GTsoft = None
+ # Create PR curve
+ P, R = createPR(out, GThard=GT, GTsoft=GTsoft, matching='single', n_thresh=100)
+ # Plot PR curve
+ plt.plot(R,P)
+ plt.xlabel('Recall')
+ plt.ylabel('Precision')
+ plt.title('Precision-Recall Curve')
+ plt.show()
+ # Recall@N
+ N = [1,5,10,15,20,25] # N values to calculate
+ R = [] # Recall@N values
# Calculate Recall@N
for n in N:
R.append(round(recallAtK(out,GThard=GT,GTsoft=GTsoft,K=n),2))
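One common follow-up to the PR curve plotted above (a suggestion of this note, not part of the patch) is to reduce it to a single area-under-curve score:

```python
import numpy as np

# Toy precision/recall lists standing in for createPR's output.
P = [1.0, 0.9, 0.8, 0.6]
R = [0.0, 0.4, 0.7, 1.0]
auc = np.trapz(P, R)  # trapezoidal area under the PR curve
print(round(auc, 3))
```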
@@ -239,11 +248,18 @@ def run_inference(models, model_name):
base_dir=models[0].data_dir,
img_dirs=models[0].query_dir,
transform=image_transform,
- skip=models[0].filter,
max_samples=max_samples)
# If the number of query places is less than the number of database places, then subset the database
if subset:
- test_dataset = Subset(test_dataset, random.sample(range(len(test_dataset)), models[0].query_places))
+ if models[0].shuffle:
+ test_dataset = Subset(test_dataset, random.sample(range(len(test_dataset)), models[0].query_places))
+ else:
+ # Generate indices with applied skip
+ indices = [i for i in range(models[0].database_places) if i % models[0].filter == 0]
+ # Limit to the desired number of queries
+ indices = indices[:models[0].query_places]
+ test_dataset = Subset(test_dataset, indices)
+
# Initialize the data loader
test_loader = DataLoader(test_dataset,
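A worked pass through the ordered-subset branch above, with assumed sizes (here 'skip' plays the role of args.filter, renamed to avoid shadowing the builtin):

```python
database_places, skip, query_places = 500, 8, 20

indices = [i for i in range(database_places) if i % skip == 0]
indices = indices[:query_places]
print(indices[:5])   # [0, 8, 16, 24, 32]
print(len(indices))  # 20
```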
diff --git a/vprtempo/VPRTempoTrain.py b/vprtempo/VPRTempoTrain.py
index 812b371..0a7049e 100644
--- a/vprtempo/VPRTempoTrain.py
+++ b/vprtempo/VPRTempoTrain.py
@@ -220,16 +220,6 @@ def save_model(self,models, model_out):
torch.save(state_dicts, model_out)
-def generate_model_name(model):
- """
- Generate the model name based on its parameters.
- """
- return ("VPRTempo" +
- str(model.input) +
- str(model.feature) +
- str(model.output) +
- ''.join(model.database_dirs) +
- '.pth')
def check_pretrained_model(model_name):
"""
diff --git a/vprtempo/src/metrics.py b/vprtempo/src/metrics.py
index 4f2849a..9900fd4 100644
--- a/vprtempo/src/metrics.py
+++ b/vprtempo/src/metrics.py
@@ -47,20 +47,19 @@ def createPR(S_in, GThard, GTsoft=None, matching='multi', n_thresh=100):
# copy S and set elements that are only true in GTsoft to min(S) to ignore them during evaluation
S = S_in.copy()
- S[S == 0] = np.nan
if GTsoft is not None:
S[GTsoft & ~GT] = S.min()
# single-best-match or multi-match VPR
if matching == 'single':
- # GT-values for best match per query (i.e., per column)
- GT = GT[np.nanargmax(S, axis=0), np.arange(GT.shape[1])]
+ # count the number of ground-truth positives (GTP)
+ GTP = np.count_nonzero(GT.any(0))
- # count the number of ground-truth positives (GTP)
- GTP = np.count_nonzero(GT)
+ # GT-values for best match per query (i.e., per column)
+ GT = GT[np.argmax(S, axis=0), np.arange(GT.shape[1])]
# similarities for best match per query (i.e., per column)
- S = np.nanmax(S, axis=0)
+ S = np.max(S, axis=0)
elif matching == 'multi':
# count the number of ground-truth positives (GTP)
@@ -71,8 +70,8 @@ def createPR(S_in, GThard, GTsoft=None, matching='multi', n_thresh=100):
P = [1, ]
# select start and end threshold
- startV = np.nanmax(S) # start-value for threshold
- endV = np.nanmin(S) # end-value for threshold
+ startV = S.max() # start-value for threshold
+ endV = S.min() # end-value for threshold
# iterate over different thresholds
for i in np.linspace(startV, endV, n_thresh):
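The reordering above matters: GTP must be counted over full columns before GT is reduced to the best match per query, otherwise a query whose best-scoring row is a miss would vanish from the positive count. A toy check (assumed data):

```python
import numpy as np

# Rows = database, columns = queries (as in createPR's 'single' branch).
GT = np.array([[1, 0, 0],
               [0, 0, 1]], dtype=bool)
S = np.array([[0.2, 0.1, 0.2],
              [0.9, 0.5, 0.8]])

GTP = np.count_nonzero(GT.any(0))  # 2: queries 0 and 2 have a true match
best = GT[np.argmax(S, axis=0), np.arange(GT.shape[1])]
print(GTP, best)                   # 2 [False False  True]
# Counting after the reduction would give GTP = 1 here, because query 0's
# best-scoring row misses its match, which would inflate recall.
```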
@@ -170,4 +169,4 @@ def recallAtK(S_in, GThard, GTsoft=None, K=1):
# recall@K
RatK = np.sum(GT.sum(0) > 0) / GT.shape[1]
- return RatK
+ return RatK
\ No newline at end of file
From ab5f63cba921cd46413a4d471daea8685f66e321 Mon Sep 17 00:00:00 2001
From: Adam Hines
Date: Thu, 21 Dec 2023 13:55:45 +1000
Subject: [PATCH 5/8] Fixed up some code redundancies, added in PR curve and sim
mat plotting options
---
.DS_Store | Bin 10244 -> 0 bytes
main.py | 125 ++++++++++++++++++++--------------
tutorials/.DS_Store | Bin 6148 -> 0 bytes
tutorials/mats/.DS_Store | Bin 6148 -> 0 bytes
vprtempo/VPRTempo.py | 134 ++++++++++++++++++++++++-------------
vprtempo/dataset/.DS_Store | Bin 14340 -> 0 bytes
vprtempo/models/.DS_Store | Bin 6148 -> 0 bytes
vprtempo/src/loggers.py | 4 +-
8 files changed, 165 insertions(+), 98 deletions(-)
delete mode 100644 .DS_Store
delete mode 100644 tutorials/.DS_Store
delete mode 100644 tutorials/mats/.DS_Store
delete mode 100644 vprtempo/dataset/.DS_Store
delete mode 100644 vprtempo/models/.DS_Store
diff --git a/.DS_Store b/.DS_Store
deleted file mode 100644
index b9183bcce3be679eb9994ae6b0c29813116d3f44..0000000000000000000000000000000000000000
GIT binary patch
(binary .DS_Store data omitted)
diff --git a/main.py b/main.py
--- a/main.py
+++ b/main.py
@@ ... @@ def initialize_and_run_model(args,dims):
+ places = args.database_places # Copy out number of database places
+
+ # Determine how many modules the network needs to create
+ num_modules = 1
+ while places > args.max_module:
+ places -= args.max_module
+ num_modules += 1
+
+ # If the final module has less than max_module, reduce the dim of the output layer
+ remainder = args.database_places % args.max_module
+ if remainder != 0: # There are remainders, adjust output neuron count in final module
+ out_dim = int((args.database_places - remainder) / (num_modules - 1))
+ final_out_dim = remainder
+ else: # No remainders, all modules are even
+ out_dim = int(args.database_places / num_modules)
+ final_out_dim = out_dim
+
# If user wants to train a new network
if args.train_new_model:
# If using quantization aware training
@@ -63,7 +110,7 @@ def initialize_and_run_model(args,dims):
model.qconfig = qconfig
models.append(model)
# Generate the model name
- model_name = generate_model_name_quant(model)
+ model_name = generate_model_name(model,args.quantize)
# Check if the model has been trained before
check_pretrained_model(model_name)
# Train the model
@@ -73,29 +120,6 @@ def initialize_and_run_model(args,dims):
else:
models = []
logger = model_logger() # Initialize the logger
- places = args.database_places # Copy out number of database places
-
- # Determine how many modules the network needs to create
- num_modules = 1
- while places > args.max_module:
- places -= args.max_module
- num_modules += 1
-
- # If the final module has less than max_module, reduce the dim of the output layer
- remainder = args.database_places % args.max_module
-
- # Check if number of modules and database images works
- if args.filter * (((num_modules-1)*args.max_module)+remainder) > args.database_places:
- print("Error: Too many modules or too few images for the given filter")
- sys.exit()
-
- # Modify final module output layer neuron count according to remainder
- if remainder != 0: # There are remainders, adjust output neuron count in final module
- out_dim = int((args.database_places - remainder) / (num_modules - 1))
- final_out_dim = remainder
- else: # No remainders, all modules are even
- out_dim = int(args.database_places / num_modules)
- final_out_dim = out_dim
# Create the modules
final_out = None
@@ -135,28 +159,21 @@ def initialize_and_run_model(args,dims):
run_inference_quant(models, model_name, qconfig)
else:
models = []
- logger = model_logger() # Initialize the logger
+ logger, output_folder = model_logger() # Initialize the logger
places = args.database_places # Copy out number of database places
- # Determine how many modules the network needs to create
- num_modules = 1
- while places > args.max_module:
- places -= args.max_module
- num_modules += 1
-
- # If the final module has less than max_module, reduce the dim of the output layer
- remainder = args.database_places % args.max_module
- if remainder != 0: # There are remainders, adjust output neuron count in final module
- out_dim = int((args.database_places - remainder) / (num_modules - 1))
- final_out_dim = remainder
- else: # No remainders, all modules are even
- out_dim = int(args.database_places / num_modules)
- final_out_dim = out_dim
-
# Create the modules
final_out = None
for mod in tqdm(range(num_modules), desc="Initializing modules"):
- model = VPRTempo(args, dims, logger, num_modules, out_dim, out_dim_remainder=final_out) # Initialize the model
+ model = VPRTempo(
+ args,
+ dims,
+ logger,
+ num_modules,
+ output_folder,
+ out_dim,
+ out_dim_remainder=final_out
+ )
model.eval()
model.to(torch.device('cpu')) # Move module to CPU for storage (necessary for large models)
models.append(model) # Create module list
@@ -212,6 +229,12 @@ def parse_network(use_quantize=False, train_new_model=False):
parser.add_argument('--quantize', action='store_true',
help="Enable/disable quantization for the model")
+ # Define metrics functionality
+ parser.add_argument('--PR_curve', action='store_true',
+ help="Flag to generate a Precision-Recall curve")
+ parser.add_argument('--sim_mat', action='store_true',
+ help="Flag to plot the similarity matrix, GT, and GTsoft")
+
# If the function is called with specific arguments, override sys.argv
if use_quantize or train_new_model:
sys.argv = ['']
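With the two flags above in place, the metrics are opt-in at the command line. A hypothetical invocation (the place count is illustrative, not a default):

```console
python main.py --query_places 250 --PR_curve --sim_mat
```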
diff --git a/tutorials/.DS_Store b/tutorials/.DS_Store
deleted file mode 100644
index 6cbab197bfe01d7a1bba7c4bacd81a6ed4589230..0000000000000000000000000000000000000000
diff --git a/tutorials/mats/.DS_Store b/tutorials/mats/.DS_Store
deleted file mode 100644
index b2b0f537229c8fdbfc2772807bd4fdbbb908cc6d..0000000000000000000000000000000000000000
diff --git a/vprtempo/VPRTempo.py b/vprtempo/VPRTempo.py
index 0089810..691ffdf 100644
--- a/vprtempo/VPRTempo.py
+++ b/vprtempo/VPRTempo.py
@@ -25,6 +25,7 @@
'''
import os
+import json
import torch
import random
@@ -40,33 +41,41 @@
from vprtempo.src.dataset import CustomImageDataset, ProcessImage
class VPRTempo(nn.Module):
- def __init__(self, args, dims, logger, num_modules, out_dim, out_dim_remainder=None):
+ def __init__(self, args, dims, logger, num_modules, output_folder, out_dim, out_dim_remainder=None):
super(VPRTempo, self).__init__()
- # Set the arguments
+ # Set the args
if args is not None:
self.args = args
for arg in vars(args):
setattr(self, arg, getattr(args, arg))
setattr(self, 'dims', dims)
+
+ # Set the device
if torch.cuda.is_available():
self.device = "cuda:0"
else:
self.device = "cpu"
+ # Set input args
self.logger = logger
self.num_modules = num_modules
+ self.output_folder = output_folder
# Set the dataset file
self.dataset_file = os.path.join('./vprtempo/dataset', self.dataset + '.csv')
self.query_dir = [dir.strip() for dir in self.query_dir.split(',')]
+
# Layer dict to keep track of layer names and their order
self.layer_dict = {}
self.layer_counter = 0
self.database_dirs = [dir.strip() for dir in self.database_dirs.split(',')]
+
# Define layer architecture
self.input = int(self.dims[0]*self.dims[1])
self.feature = int(self.input * 2)
+
+ # Output dimension changes for final module if not an even distribution of places
        if out_dim_remainder is not None:
self.output = out_dim_remainder
else:
@@ -107,12 +116,12 @@ def add_layer(self, name, **kwargs):
self.layer_dict[name] = self.layer_counter
self.layer_counter += 1
- def evaluate(self, models, test_loader, layers=None):
+ def evaluate(self, models, test_loader):
"""
Run the inferencing model and calculate the accuracy.
+ :param models: Models to run inference on, each model is a VPRTempo module
:param test_loader: Testing data loader
- :param layers: Layers to pass data through
"""
# Initialize the tqdm progress bar
pbar = tqdm(total=self.query_places,
@@ -136,7 +145,7 @@ def evaluate(self, models, test_loader, layers=None):
# Run inference for the specified number of timesteps
for spikes, label in test_loader:
# Set device
- spikes, label = spikes.to(self.device), label.to(self.device)
+ spikes = spikes.to(self.device)
labels.append(label.detach().cpu().item())
# Forward pass
spikes = self.forward(spikes)
@@ -152,31 +161,66 @@ def evaluate(self, models, test_loader, layers=None):
# Create GT matrix
GT = np.zeros((model.query_places,model.database_places), dtype=int)
for n, ndx in enumerate(labels):
- #if model.filter !=1:
- # ndx = ndx//model.filter
+ if model.filter !=1:
+ ndx = ndx//model.filter
GT[n,ndx] = 1
# Create GT soft matrix
if model.GT_tolerance > 0:
GTsoft = np.zeros((model.query_places,model.database_places), dtype=int)
for n, ndx in enumerate(labels):
- #if model.filter !=1:
- # ndx = ndx//model.filter
+ if model.filter !=1:
+ ndx = ndx//model.filter
GTsoft[n, ndx] = 1
# Apply tolerance
for i in range(max(0, n - model.GT_tolerance), min(model.query_places, n + model.GT_tolerance + 1)):
GTsoft[i, ndx] = 1
else:
GTsoft = None
+
+ # If user specified, generate a PR curve
+ if model.PR_curve:
+ # Create PR curve
+ P, R = createPR(out, GThard=GT, GTsoft=GTsoft, matching='single', n_thresh=100)
+ # Combine P and R into a list of lists
+ PR_data = {
+ "Precision": P,
+ "Recall": R
+ }
+ output_file = "PR_curve_data.json"
+ # Construct the full path
+ full_path = f"{model.output_folder}/{output_file}"
+ # Write the data to a JSON file
+ with open(full_path, 'w') as file:
+ json.dump(PR_data, file)
+ # Plot PR curve
+ plt.plot(R,P)
+ plt.xlabel('Recall')
+ plt.ylabel('Precision')
+ plt.title('Precision-Recall Curve')
+ plt.show()
+
+ if model.sim_mat:
+ # Create a figure and a set of subplots
+ fig, axs = plt.subplots(1, 3, figsize=(15, 5))
+
+ # Plot each matrix using matshow
+ cax1 = axs[0].matshow(out, cmap='viridis')
+ fig.colorbar(cax1, ax=axs[0], shrink=0.8)
+ axs[0].set_title('Similarity matrix')
+
+ cax2 = axs[1].matshow(GT, cmap='plasma')
+ fig.colorbar(cax2, ax=axs[1], shrink=0.8)
+ axs[1].set_title('GT')
+
+ cax3 = axs[2].matshow(GTsoft, cmap='inferno')
+ fig.colorbar(cax3, ax=axs[2], shrink=0.8)
+ axs[2].set_title('GT-soft')
+
+ # Adjust layout
+ plt.tight_layout()
+ plt.show()
- # Create PR curve
- P, R = createPR(out, GThard=GT, GTsoft=GTsoft, matching='single', n_thresh=100)
- # Plot PR curve
- plt.plot(R,P)
- plt.xlabel('Recall')
- plt.ylabel('Precision')
- plt.title('Precision-Recall Curve')
- plt.show()
# Recall@N
N = [1,5,10,15,20,25] # N values to calculate
R = [] # Recall@N values
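The GT/GTsoft construction above first divides each query label by the frame filter to recover its database index, then widens each hard match by `GT_tolerance` neighbouring query rows. A self-contained sketch with hypothetical sizes:

```python
import numpy as np

# Hypothetical setup: 5 queries against a 5-place database, filter already applied
query_places, database_places, GT_tolerance = 5, 5, 1
labels = [0, 1, 2, 3, 4]  # database index matched by each query

# Hard ground truth: exact query -> database matches
GT = np.zeros((query_places, database_places), dtype=int)
for n, ndx in enumerate(labels):
    GT[n, ndx] = 1

# Soft ground truth: widen each match by +/- GT_tolerance query rows
GTsoft = np.zeros((query_places, database_places), dtype=int)
for n, ndx in enumerate(labels):
    GTsoft[n, ndx] = 1
    for i in range(max(0, n - GT_tolerance),
                   min(query_places, n + GT_tolerance + 1)):
        GTsoft[i, ndx] = 1

print(GTsoft)  # a band of width 3 around the diagonal
```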
@@ -227,52 +271,52 @@ def run_inference(models, model_name):
"""
Run inference on a pre-trained model.
- :param model: Model to run inference on
+ :param models: Models to run inference on, each model is a VPRTempo module
:param model_name: Name of the model to load
- :param qconfig: Quantization configuration
"""
- # Initialize the image transforms and datasets
- image_transform = ProcessImage(models[0].dims, models[0].patches)
-
- # Determine input range
- if models[0].query_places == models[0].database_places:
- max_samples=models[0].query_places
- subset = False
- elif models[0].query_places < models[0].database_places:
- max_samples=models[0].database_places
- subset = True
+    # Use the first model in the list as the reference for shared parameters
+ model = models[0]
+ # Initialize the image transforms
+ image_transform = ProcessImage(model.dims, model.patches)
+
+    # Determine whether a subset of the database or the entire database is being queried
+ if model.query_places == model.database_places:
+ subset = False # Entire database
+ elif model.query_places < model.database_places:
+ subset = True # Subset of the database
else:
raise ValueError("The number of query places must be less than or equal to the number of database places.")
+
# Initialize the test dataset
- test_dataset = CustomImageDataset(annotations_file=models[0].dataset_file,
- base_dir=models[0].data_dir,
- img_dirs=models[0].query_dir,
+ test_dataset = CustomImageDataset(annotations_file=model.dataset_file,
+ base_dir=model.data_dir,
+ img_dirs=model.query_dir,
transform=image_transform,
- max_samples=max_samples)
- # If the number of query places is less than the number of database places, then subset the database
+ max_samples=model.database_places,
+ skip=model.filter)
+
+ # If using a subset of the database
if subset:
- if models[0].shuffle:
- test_dataset = Subset(test_dataset, random.sample(range(len(test_dataset)), models[0].query_places))
- else:
- # Generate indices with applied skip
- indices = [i for i in range(models[0].database_places) if i % models[0].filter == 0]
+ if model.shuffle: # For a randomized selection of database places
+ test_dataset = Subset(test_dataset, random.sample(range(len(test_dataset)), model.query_places))
+ else: # For a sequential selection of database places
+ indices = [i for i in range(model.database_places) if i % model.filter == 0]
# Limit to the desired number of queries
- indices = indices[:models[0].query_places]
+ indices = indices[:model.query_places]
+ # Create the subset
test_dataset = Subset(test_dataset, indices)
# Initialize the data loader
test_loader = DataLoader(test_dataset,
batch_size=1,
- shuffle=models[0].shuffle,
+ shuffle=model.shuffle,
num_workers=8,
persistent_workers=True)
# Load the model
- models[0].load_model(models, os.path.join('./vprtempo/models', model_name))
- # Retrieve layer names for inference
- layer_names = list(models[0].layer_dict.keys())
+ model.load_model(models, os.path.join('./vprtempo/models', model_name))
# Use evaluate method for inference accuracy
with torch.no_grad():
- models[0].evaluate(models, test_loader, layers=layer_names)
\ No newline at end of file
+ model.evaluate(models, test_loader)
\ No newline at end of file
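When `query_places < database_places` and shuffling is off, the subset above reduces to a filtered, truncated index list. A worked sketch with hypothetical numbers:

```python
# Hypothetical: 500 database places, keep every 2nd frame, query 100 of them
database_places, frame_filter, query_places = 500, 2, 100

# Keep every frame_filter-th index, then truncate to the query count
indices = [i for i in range(database_places) if i % frame_filter == 0]
indices = indices[:query_places]

print(indices[:3], indices[-1])  # -> [0, 2, 4] 198
```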
diff --git a/vprtempo/dataset/.DS_Store b/vprtempo/dataset/.DS_Store
deleted file mode 100644
index ea74ea872093a8a770552a84473018d23402d961..0000000000000000000000000000000000000000
diff --git a/vprtempo/models/.DS_Store b/vprtempo/models/.DS_Store
deleted file mode 100644
index 28b3c56f896fa8c0c5ff67d4e21152c0cdbb6e3d..0000000000000000000000000000000000000000
diff --git a/vprtempo/src/loggers.py b/vprtempo/src/loggers.py
index d870de1..8f7e5cb 100644
--- a/vprtempo/src/loggers.py
+++ b/vprtempo/src/loggers.py
@@ -48,7 +48,7 @@ def model_logger():
logger.info('Current device is: CPU')
logger.info('')
- return logger
+ return logger, output_folder
def model_logger_quant():
"""
@@ -91,4 +91,4 @@ def model_logger_quant():
logger.info('Current device is: CPU')
logger.info('')
- return logger
\ No newline at end of file
+ return logger, output_folder
\ No newline at end of file
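Both loggers now hand back the run's output folder alongside the logger, so callers unpack a pair. A sketch of the assumed calling pattern (the folder is presumed to be created inside `model_logger`):

```python
from vprtempo.src.loggers import model_logger

# Assumed pattern after this change: unpack the logger and the output folder
logger, output_folder = model_logger()
logger.info(f"Run artifacts (e.g. PR_curve_data.json) will go to {output_folder}")
```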
From 7e9abb68d04e7586733f9c441c314e507fb2f5ce Mon Sep 17 00:00:00 2001
From: Adam Hines
Date: Tue, 2 Jan 2024 11:54:22 +1000
Subject: [PATCH 6/8] Fixed up quantization with new module strategy, trained
new models with new model names
---
.gitignore | 1 -
README.md | 24 ++--
main.py | 44 +++++---
vprtempo/VPRTempoQuant.py | 194 ++++++++++++++++++++++++---------
vprtempo/VPRTempoQuantTrain.py | 105 ++++++++++--------
5 files changed, 245 insertions(+), 123 deletions(-)
diff --git a/.gitignore b/.gitignore
index 7baed65..9a70cc7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,4 +8,3 @@ vprtempo/dataset/winter/
vprtempo/dataset/event.csv
vprtempo/output/
vprtempo/src/__pycache__/
-vprtempo/models/VPRTempo3136627250045.pth
diff --git a/README.md b/README.md
index a22dcc7..f66787e 100644
--- a/README.md
+++ b/README.md
@@ -64,22 +64,28 @@ If you wish to enable CUDA, please follow the instructions on the [PyTorch - Get
Dependencies can be installed using our provided `requirements.txt` file.
```python
-pip3 install -r requirements.txt
+pip install -r requirements.txt
```
As above, if you wish to install CUDA please visit [PyTorch - Get Started](https://pytorch.org/get-started/locally/).
### Option 3: Conda install
>**:heavy_exclamation_mark: Recommended:**
-> Use [Mambaforge](https://mamba.readthedocs.io/en/latest/installation.html) instead of conda.
+> Use [Mambaforge](https://mamba.readthedocs.io/en/latest/installation/mamba-installation.html) instead of conda.
+
+Requirements for VPRTempo may be installed using our [conda-forge package](https://anaconda.org/conda-forge/vprtempo).
```console
-# Windows/Linux - CUDA enabled
-conda create -n vprtempo -c pytorch -c nvidia python torchvision torchaudio pytorch-cuda=11.7 cudatoolkit prettytable tqdm numpy pandas scikit-learn
+# Linux/OS X
+conda create -n vprtempo -c conda-forge vprtempo
+
+# Linux CUDA enabled
+conda create -n vprtempo -c conda-forge -c pytorch -c nvidia vprtempo pytorch-cuda cudatoolkit
-# Windows/Linux - CPU only
-conda create -n vprtempo python pytorch torchvision torchaudio cpuonly prettytable tqdm numpy pandas scikit-learn -c pytorch
+# Windows
+conda create -n vprtempo -c pytorch python pytorch torchvision torchaudio cpuonly prettytable tqdm numpy pandas scikit-learn
+
+# Windows CUDA enabled
+conda create -n vprtempo -c pytorch -c nvidia python torchvision torchaudio pytorch-cuda=11.7 cudatoolkit prettytable tqdm numpy pandas scikit-learn
-# MacOS
-conda create -n vprtempo -c conda-forge python prettytable tqdm numpy pandas scikit-learn -c pytorch pytorch::pytorch torchvision torchaudio
```
## Datasets
@@ -142,8 +148,6 @@ python main.py --quantize
-#### IDE
-You can also run VPRTempo through your IDE by running `main.py`. Change the `bool` flag for `use_quantize` to `True` if you wish to run VPRTempoQuant.
### Train new network
If you do not wish to use the pretrained models, or you would like to train your own, you can pass the `--train_new_model` flag to `main.py`. Note: if a pretrained model already exists, you will be prompted before retraining it.
diff --git a/main.py b/main.py
index 0e739e6..5300ae5 100644
--- a/main.py
+++ b/main.py
@@ -100,21 +100,26 @@ def initialize_and_run_model(args,dims):
# If using quantization aware training
if args.quantize:
models = []
- logger = model_logger_quant()
- # Get the quantization config
+ logger = model_logger_quant() # Initialize the logger
qconfig = quantization.get_default_qat_qconfig('fbgemm')
- for _ in tqdm(range(args.num_modules), desc="Initializing modules"):
- # Initialize the model
- model = VPRTempoQuantTrain(args, dims, logger)
+ # Create the modules
+ final_out = None
+ for mod in tqdm(range(num_modules), desc="Initializing modules"):
+ model = VPRTempoQuantTrain(args, dims, logger, num_modules, out_dim, out_dim_remainder=final_out) # Initialize the model
model.train()
model.qconfig = qconfig
- models.append(model)
+ quantization.prepare_qat(model, inplace=True)
+ models.append(model) # Create module list
+ if mod == num_modules - 2:
+ final_out = final_out_dim
# Generate the model name
model_name = generate_model_name(model,args.quantize)
# Check if the model has been trained before
check_pretrained_model(model_name)
+ # Get the quantization config
+ qconfig = quantization.get_default_qat_qconfig('fbgemm')
# Train the model
- train_new_model_quant(models, model_name, qconfig)
+ train_new_model_quant(models, model_name)
# Base model
else:
@@ -143,20 +148,29 @@ def initialize_and_run_model(args,dims):
# Set the quantization configuration
if args.quantize:
models = []
- logger = model_logger_quant()
+ logger, output_folder = model_logger_quant()
qconfig = quantization.get_default_qat_qconfig('fbgemm')
- for _ in tqdm(range(args.num_modules), desc="Initializing modules"):
+ final_out = None
+ for _ in tqdm(range(num_modules), desc="Initializing modules"):
# Initialize the model
- model = VPRTempoQuant(dims, args, logger)
+ model = VPRTempoQuant(
+ args,
+ dims,
+ logger,
+ num_modules,
+ output_folder,
+ out_dim,
+ out_dim_remainder=final_out
+ )
model.eval()
model.qconfig = qconfig
- model = quantization.prepare(model, inplace=False)
- model = quantization.convert(model, inplace=False)
+ quantization.prepare(model, inplace=True)
+ quantization.convert(model, inplace=True)
models.append(model)
# Generate the model name
- model_name = generate_model_name_quant(model)
+ model_name = generate_model_name(model, args.quantize)
# Run the quantized inference model
- run_inference_quant(models, model_name, qconfig)
+ run_inference_quant(models, model_name)
else:
models = []
logger, output_folder = model_logger() # Initialize the logger
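The in-place `prepare`/`convert` sequence above follows PyTorch's eager-mode quantization workflow. A minimal runnable sketch on a toy module (not the VPRTempo network, and using the standard post-training config here for brevity rather than the QAT config whose trained weights the patch loads):

```python
import torch
import torch.nn as nn
import torch.quantization as quantization

class Toy(nn.Module):
    """Toy stand-in for a VPRTempo module."""
    def __init__(self):
        super().__init__()
        self.quant = quantization.QuantStub()      # float -> int8 boundary
        self.fc = nn.Linear(8, 4)
        self.dequant = quantization.DeQuantStub()  # int8 -> float boundary

    def forward(self, x):
        return self.dequant(self.fc(self.quant(x)))

model = Toy().eval()
model.qconfig = quantization.get_default_qconfig('fbgemm')
quantization.prepare(model, inplace=True)   # insert observers
model(torch.randn(4, 8))                    # calibrate on sample data
quantization.convert(model, inplace=True)   # swap in quantized int8 ops
print(model(torch.randn(1, 8)).shape)       # torch.Size([1, 4])
```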
@@ -208,7 +222,7 @@ def parse_network(use_quantize=False, train_new_model=False):
help="Directories to use for testing")
parser.add_argument('--shuffle', action='store_true',
help="Shuffle input images during query")
- parser.add_argument('--GT_tolerance', type=int, default=2,
+ parser.add_argument('--GT_tolerance', type=int, default=1,
help="Ground truth tolerance for matching")
# Define training parameters
diff --git a/vprtempo/VPRTempoQuant.py b/vprtempo/VPRTempoQuant.py
index b667923..a752574 100644
--- a/vprtempo/VPRTempoQuant.py
+++ b/vprtempo/VPRTempoQuant.py
@@ -25,49 +25,64 @@
'''
import os
+import json
import torch
+import random
import numpy as np
import torch.nn as nn
+import matplotlib.pyplot as plt
import vprtempo.src.blitnet as bn
from tqdm import tqdm
from prettytable import PrettyTable
-from torch.utils.data import DataLoader
-from vprtempo.src.metrics import recallAtK
+from torch.utils.data import DataLoader, Subset
+from vprtempo.src.metrics import recallAtK, createPR
from torch.ao.quantization import QuantStub, DeQuantStub
from vprtempo.src.dataset import CustomImageDataset, ProcessImage
#from main import parse_network
class VPRTempoQuant(nn.Module):
- def __init__(self, dims, args=None, logger=None):
+ def __init__(self, args, dims, logger, num_modules, output_folder, out_dim, out_dim_remainder=None):
super(VPRTempoQuant, self).__init__()
- # Set the arguments
- self.args = args
- for arg in vars(args):
- setattr(self, arg, getattr(args, arg))
+ # Set the args
+ if args is not None:
+ self.args = args
+ for arg in vars(args):
+ setattr(self, arg, getattr(args, arg))
setattr(self, 'dims', dims)
- # Set the dataset file
- self.dataset_file = os.path.join('./vprtempo/dataset', self.dataset + '.csv')
- # Set the model logger and return the device
- self.logger = logger
+ # Set the device
self.device = "cpu"
- # Add quantization stubs for Quantization Aware Training (QAT)
+ # Set input args
+ self.logger = logger
+ self.num_modules = num_modules
+ self.output_folder = output_folder
+
self.quant = QuantStub()
- self.dequant = DeQuantStub()
+ self.dequant = DeQuantStub()
+
+ # Set the dataset file
+ self.dataset_file = os.path.join('./vprtempo/dataset', self.dataset + '.csv')
+ self.query_dir = [dir.strip() for dir in self.query_dir.split(',')]
# Layer dict to keep track of layer names and their order
self.layer_dict = {}
self.layer_counter = 0
+ self.database_dirs = [dir.strip() for dir in self.database_dirs.split(',')]
# Define layer architecture
- self.input = int(dims[0]*dims[1])
+ self.input = int(self.dims[0]*self.dims[1])
self.feature = int(self.input * 2)
- self.output = int(args.num_places / args.num_modules)
+
+ # Output dimension changes for final module if not an even distribution of places
+        if out_dim_remainder is not None:
+ self.output = out_dim_remainder
+ else:
+ self.output = out_dim
"""
Define trainable layers here
@@ -125,15 +140,17 @@ def evaluate(self, models, test_loader, layers=None):
nn.ReLU()
))
# Initialize the tqdm progress bar
- pbar = tqdm(total=self.num_places,
+ pbar = tqdm(total=self.query_places,
desc="Running the test network",
position=0)
    # Initialize the output spikes variable
out = []
+ labels = []
# Run inference for the specified number of timesteps
- for spikes, labels in test_loader:
+ for spikes, label in test_loader:
# Set device
- spikes, labels = spikes.to(self.device), labels.to(self.device)
+ spikes = spikes.to(self.device)
+ labels.append(label.detach().item())
# Pass through previous layers if they exist
spikes = self.forward(spikes)
# Add output spikes to list
@@ -144,21 +161,82 @@ def evaluate(self, models, test_loader, layers=None):
pbar.close()
    # Reshape output spikes into a similarity matrix
- out = np.reshape(np.array(out),(self.num_places,self.num_places))
- # Calculate and print the Recall@N
- N = [1,5,10,15,20,25]
- R = []
+ out = np.reshape(np.array(out),(model.query_places,model.database_places))
+
# Create GT matrix
- GT = np.zeros((self.num_places,self.num_places), dtype=int)
- for n in range(len(GT)):
- GT[n,n] = 1
+ GT = np.zeros((model.query_places,model.database_places), dtype=int)
+ for n, ndx in enumerate(labels):
+ if model.filter !=1:
+ ndx = ndx//model.filter
+ GT[n,ndx] = 1
+
+ # Create GT soft matrix
+ if model.GT_tolerance > 0:
+ GTsoft = np.zeros((model.query_places,model.database_places), dtype=int)
+ for n, ndx in enumerate(labels):
+ if model.filter !=1:
+ ndx = ndx//model.filter
+ GTsoft[n, ndx] = 1
+ # Apply tolerance
+ for i in range(max(0, n - model.GT_tolerance), min(model.query_places, n + model.GT_tolerance + 1)):
+ GTsoft[i, ndx] = 1
+ else:
+ GTsoft = None
+
+ # If user specified, generate a PR curve
+ if model.PR_curve:
+ # Create PR curve
+ P, R = createPR(out, GThard=GT, GTsoft=GTsoft, matching='single', n_thresh=100)
+ # Combine P and R into a list of lists
+ PR_data = {
+ "Precision": P,
+ "Recall": R
+ }
+ output_file = "PR_curve_data.json"
+ # Construct the full path
+ full_path = f"{model.output_folder}/{output_file}"
+ # Write the data to a JSON file
+ with open(full_path, 'w') as file:
+ json.dump(PR_data, file)
+ # Plot PR curve
+ plt.plot(R,P)
+ plt.xlabel('Recall')
+ plt.ylabel('Precision')
+ plt.title('Precision-Recall Curve')
+ plt.show()
+
+ if model.sim_mat:
+ # Create a figure and a set of subplots
+ fig, axs = plt.subplots(1, 3, figsize=(15, 5))
+
+ # Plot each matrix using matshow
+ cax1 = axs[0].matshow(out, cmap='viridis')
+ fig.colorbar(cax1, ax=axs[0], shrink=0.8)
+ axs[0].set_title('Similarity matrix')
+
+ cax2 = axs[1].matshow(GT, cmap='plasma')
+ fig.colorbar(cax2, ax=axs[1], shrink=0.8)
+ axs[1].set_title('GT')
+
+ cax3 = axs[2].matshow(GTsoft, cmap='inferno')
+ fig.colorbar(cax3, ax=axs[2], shrink=0.8)
+ axs[2].set_title('GT-soft')
+
+ # Adjust layout
+ plt.tight_layout()
+ plt.show()
+
+ # Recall@N
+ N = [1,5,10,15,20,25] # N values to calculate
+ R = [] # Recall@N values
+ # Calculate Recall@N
for n in N:
- R.append(recallAtK(out,GThard=GT,K=n))
+ R.append(round(recallAtK(out,GThard=GT,GTsoft=GTsoft,K=n),2))
# Print the results
table = PrettyTable()
table.field_names = ["N", "1", "5", "10", "15", "20", "25"]
table.add_row(["Recall", R[0], R[1], R[2], R[3], R[4], R[5]])
- print(table)
+ self.logger.info(table)
def forward(self, spikes):
"""
@@ -194,45 +272,57 @@ def load_model(self, models, model_path):
model.load_state_dict(combined_state_dict[f'model_{i}'])
model.eval() # Set the model to inference mode
-
-def check_pretrained_model(model_name):
- """
- Check if a pre-trained model exists and tell user if it does not.
- """
- if not os.path.exists(os.path.join('./models', model_name)):
- model.logger.info("A pre-trained network does not exist: please train one using VPRTempoQuant_Trainer")
- pretrain = 'n'
- else:
- pretrain = 'y'
- return pretrain
-def run_inference_quant(models, model_name, qconfig):
+def run_inference_quant(models, model_name):
"""
Run inference on a pre-trained model.
- :param model: Model to run inference on
+ :param models: Models to run inference on, each model is a VPRTempo module
:param model_name: Name of the model to load
- :param qconfig: Quantization configuration
"""
- max_samples = models[0].num_places
- # Initialize the image transforms and datasets
- image_transform = ProcessImage(models[0].dims, models[0].patches)
- test_dataset = CustomImageDataset(annotations_file=models[0].dataset_file,
- base_dir=models[0].data_dir,
- img_dirs=models[0].query_dir,
+    # Use the first model in the list as the reference for shared parameters
+ model = models[0]
+ # Initialize the image transforms
+ image_transform = ProcessImage(model.dims, model.patches)
+
+    # Determine whether a subset of the database or the entire database is being queried
+ if model.query_places == model.database_places:
+ subset = False # Entire database
+ elif model.query_places < model.database_places:
+ subset = True # Subset of the database
+ else:
+ raise ValueError("The number of query places must be less than or equal to the number of database places.")
+
+ # Initialize the test dataset
+ test_dataset = CustomImageDataset(annotations_file=model.dataset_file,
+ base_dir=model.data_dir,
+ img_dirs=model.query_dir,
transform=image_transform,
- skip=models[0].filter,
- max_samples=max_samples)
+ max_samples=model.database_places,
+ skip=model.filter)
+
+ # If using a subset of the database
+ if subset:
+ if model.shuffle: # For a randomized selection of database places
+ test_dataset = Subset(test_dataset, random.sample(range(len(test_dataset)), model.query_places))
+ else: # For a sequential selection of database places
+ indices = [i for i in range(model.database_places) if i % model.filter == 0]
+ # Limit to the desired number of queries
+ indices = indices[:model.query_places]
+ # Create the subset
+ test_dataset = Subset(test_dataset, indices)
+
+
# Initialize the data loader
test_loader = DataLoader(test_dataset,
batch_size=1,
- shuffle=False,
+ shuffle=model.shuffle,
num_workers=8,
persistent_workers=True)
# Load the model
- models[0].load_model(models, os.path.join('./vprtempo/models', model_name))
+ model.load_model(models, os.path.join('./vprtempo/models', model_name))
# Use evaluate method for inference accuracy
with torch.no_grad():
- models[0].evaluate(models, test_loader)
\ No newline at end of file
+ model.evaluate(models, test_loader)
\ No newline at end of file
diff --git a/vprtempo/VPRTempoQuantTrain.py b/vprtempo/VPRTempoQuantTrain.py
index 451420a..ac320de 100644
--- a/vprtempo/VPRTempoQuantTrain.py
+++ b/vprtempo/VPRTempoQuantTrain.py
@@ -39,7 +39,7 @@
from vprtempo.src.dataset import CustomImageDataset, ProcessImage
class VPRTempoQuantTrain(nn.Module):
- def __init__(self, args, dims, logger):
+ def __init__(self, args, dims, logger, num_modules, out_dim, out_dim_remainder=None):
super(VPRTempoQuantTrain, self).__init__()
# Set the arguments
@@ -47,20 +47,16 @@ def __init__(self, args, dims, logger):
for arg in vars(args):
setattr(self, arg, getattr(args, arg))
setattr(self, 'dims', dims)
-
- # Only CPU available for quantization
- self.device = "cpu"
- # Set the logger
+ self.device = "cpu"
self.logger = logger
+ self.num_modules = num_modules
+ self.quant = QuantStub()
+ self.dequant = DeQuantStub()
# Set the dataset file
self.dataset_file = os.path.join('./vprtempo/dataset', self.dataset + '.csv')
- # Add quantization stubs for Quantization Aware Training (QAT)
- self.quant = QuantStub()
- self.dequant = DeQuantStub()
-
# Layer dict to keep track of layer names and their order
self.layer_dict = {}
self.layer_counter = 0
@@ -68,11 +64,18 @@ def __init__(self, args, dims, logger):
# Define layer architecture
self.input = int(dims[0]*dims[1])
self.feature = int(self.input * 2)
- self.output = int(args.num_places / args.num_modules)
+        if out_dim_remainder is not None:
+ self.output = out_dim_remainder
+ else:
+ self.output = out_dim
# Set the total timestep count
- self.location_repeat = len(args.database_dirs) # Number of times to repeat the locations
- self.T = int((self.num_places / self.num_modules) * self.location_repeat * self.epoch)
+ self.database_dirs = [dir.strip() for dir in self.database_dirs.split(',')]
+ self.location_repeat = len(self.database_dirs) # Number of times to repeat the locations
+        if out_dim_remainder is not None:
+ self.T = int(out_dim_remainder * self.location_repeat * self.epoch)
+ else:
+ self.T = int(self.max_module * self.location_repeat * self.epoch)
"""
Define trainable layers here
@@ -226,60 +229,72 @@ def check_pretrained_model(model_name):
return retrain == 'n'
return False
-def train_new_model_quant(models, model_name, qconfig):
+def train_new_model_quant(models, model_name):
"""
Train a new model.
:param model: Model to train
:param model_name: Name of the model to save after training
- :param qconfig: Quantization configuration
"""
+    # Use the first model in the list as the reference for shared parameters
+ model = models[0]
# Initialize the image transforms and datasets
image_transform = transforms.Compose([
- ProcessImage(models[0].dims, models[0].patches)
+ ProcessImage(model.dims, model.patches)
])
# Automatically generate user_input_ranges
user_input_ranges = []
start_idx = 0
- for _ in range(models[0].num_modules):
- range_temp = [start_idx, start_idx+((models[0].max_module-1)*models[0].filter)]
+ for _ in range(model.num_modules):
+ range_temp = [start_idx, start_idx+((model.max_module-1)*model.filter)]
user_input_ranges.append(range_temp)
start_idx = range_temp[1] + models[0].filter
- if models[0].num_places < models[0].max_module:
- max_samples=models[0].num_places
+ if model.query_places < model.max_module:
+ max_samples=model.query_places
else:
- max_samples = models[0].max_module
+ max_samples = model.max_module
# Keep track of trained layers to pass data through them
+ trained_layers = []
+
# Training each layer
- trained_models = []
- for i, model in enumerate(models):
- trained_layers = []
- img_range=user_input_ranges[i]
- train_dataset = CustomImageDataset(annotations_file=models[0].dataset_file,
- base_dir=models[0].data_dir,
- img_dirs=models[0].database_dirs,
- transform=image_transform,
- skip=models[0].filter,
- test=False,
- img_range=img_range,
- max_samples=max_samples)
- # Initialize the data loader
- train_loader = DataLoader(train_dataset,
- batch_size=1,
- shuffle=True,
- num_workers=8,
- persistent_workers=True)
- model = quantization.prepare_qat(model, inplace=False)
- for layer_name, _ in sorted(models[0].layer_dict.items(), key=lambda item: item[1]):
- print(f"Training layer: {layer_name}")
+ for layer_name, _ in sorted(model.layer_dict.items(), key=lambda item: item[1]):
+ print(f"Training layer: {layer_name}")
+ # Retrieve the layer object
+ for i, model in enumerate(models):
+ model.train()
+ model.to(torch.device(model.device))
layer = (getattr(model, layer_name))
-
+ # Determine the maximum samples for the DataLoader
+ if model.database_places < model.max_module:
+ max_samples = model.database_places
+ elif model.output < model.max_module:
+ max_samples = model.output
+ else:
+ max_samples = model.max_module
+ # Initialize new dataset with unique range for each module
+ img_range=user_input_ranges[i]
+ train_dataset = CustomImageDataset(annotations_file=model.dataset_file,
+ base_dir=model.data_dir,
+ img_dirs=model.database_dirs,
+ transform=image_transform,
+ skip=model.filter,
+ test=False,
+ img_range=img_range,
+ max_samples=max_samples)
+ # Initialize the data loader
+ train_loader = DataLoader(train_dataset,
+ batch_size=1,
+ shuffle=True,
+ num_workers=8,
+ persistent_workers=True)
# Train the layers
model.train_model(train_loader, layer, model, i, prev_layers=trained_layers)
trained_layers.append(layer_name)
- trained_models.append(quantization.convert(model, inplace=False))
- # After training the current layer, add it to the list of trained layer
+
+    # Convert the models to their quantized form
+ for model in models:
+ quantization.convert(model, inplace=True)
# Save the model
- model.save_model(trained_models,os.path.join('./vprtempo/models', model_name))
\ No newline at end of file
+ model.save_model(models,os.path.join('./vprtempo/models', model_name))
\ No newline at end of file
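The restructured trainer flips the loop order: the outer loop walks the layers and the inner loop trains that layer across every module before any module advances. A schematic of the control flow (module and layer names are illustrative stand-ins, not the real training call):

```python
# Layer-major training order: finish a layer across all modules, then move on
layer_names = ["feature_layer", "output_layer"]
modules = ["module_0", "module_1", "module_2"]

trained_layers = []
for layer_name in layer_names:         # outer: one layer at a time
    for module in modules:             # inner: each expert module
        print(f"{module}: training {layer_name}, forwarding through {trained_layers}")
    trained_layers.append(layer_name)  # now usable for forward passes
```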
From 5832b774e0212445f288292679a45ed59e0894ad Mon Sep 17 00:00:00 2001
From: Adam Hines
Date: Tue, 2 Jan 2024 11:59:57 +1000
Subject: [PATCH 7/8] Touch-ups - readme and version number for v1.1.5
---
README.md | 2 +-
setup.py | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/README.md b/README.md
index f66787e..3b13f4e 100644
--- a/README.md
+++ b/README.md
@@ -25,7 +25,7 @@ To use VPRTempo, please follow the instructions below for installation and usage
- Quantization Aware Training (QAT) enabled to train weights in int8 space
- Addition of tutorials in Jupyter Notebooks to learn how to use VPRTempo as well as explain the computational logic
- Simplification of weight operations, reducing to a single weight tensor - allowing positive and negative connections to change sign during training
- - Easier dependency installation with PyPi/pip
+ - Easier dependency installation with PyPI/pip and conda
- And more!
## License & Citation
diff --git a/setup.py b/setup.py
index 112ab1f..c63fda9 100644
--- a/setup.py
+++ b/setup.py
@@ -21,7 +21,7 @@
# define the setup
setup(
name="VPRTempo",
- version="1.1.4",
+ version="1.1.5",
description='VPRTempo: A Fast Temporally Encoded Spiking Neural Network for Visual Place Recognition',
long_description=long_description,
long_description_content_type='text/markdown',
From e958cda4cd5dcc762ba0456e1d78ba109938a4af Mon Sep 17 00:00:00 2001
From: Adam Hines
Date: Tue, 2 Jan 2024 12:21:35 +1000
Subject: [PATCH 8/8] Fixed issue in quant model with training layers across
modules
---
vprtempo/VPRTempoQuantTrain.py | 13 +++++--------
1 file changed, 5 insertions(+), 8 deletions(-)
diff --git a/vprtempo/VPRTempoQuantTrain.py b/vprtempo/VPRTempoQuantTrain.py
index ac320de..d5b3c68 100644
--- a/vprtempo/VPRTempoQuantTrain.py
+++ b/vprtempo/VPRTempoQuantTrain.py
@@ -151,7 +151,7 @@ def train_model(self, train_loader, layer, model, model_num, prev_layers=None):
# idx scale factor for different modules
idx_scale = (self.max_module*self.filter)*model_num
# Run training for the specified number of epochs
- for epoch in range(self.epoch):
+ for _ in range(self.epoch):
# Run training for the specified number of timesteps
for spikes, labels in train_loader:
spikes, labels = spikes.to(self.device), labels.to(self.device)
@@ -245,15 +245,12 @@ def train_new_model_quant(models, model_name):
# Automatically generate user_input_ranges
user_input_ranges = []
start_idx = 0
-
+ # Generate the image ranges for each module
for _ in range(model.num_modules):
range_temp = [start_idx, start_idx+((model.max_module-1)*model.filter)]
user_input_ranges.append(range_temp)
- start_idx = range_temp[1] + models[0].filter
- if model.query_places < model.max_module:
- max_samples=model.query_places
- else:
- max_samples = model.max_module
+ start_idx = range_temp[1] + model.filter
+
# Keep track of trained layers to pass data through them
trained_layers = []
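The range generation above allots each module a contiguous, filter-aware slice of the image sequence. A worked sketch with hypothetical values (2 modules of 250 places, keeping every 2nd frame):

```python
# Hypothetical: 2 modules, 250 places per module, frame filter of 2
num_modules, max_module, frame_filter = 2, 250, 2

user_input_ranges = []
start_idx = 0
for _ in range(num_modules):
    # Inclusive index range spanning max_module filtered frames
    range_temp = [start_idx, start_idx + (max_module - 1) * frame_filter]
    user_input_ranges.append(range_temp)
    start_idx = range_temp[1] + frame_filter

print(user_input_ranges)  # -> [[0, 498], [500, 998]]
```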
@@ -290,7 +287,7 @@ def train_new_model_quant(models, model_name):
persistent_workers=True)
# Train the layers
model.train_model(train_loader, layer, model, i, prev_layers=trained_layers)
- trained_layers.append(layer_name)
+ trained_layers.append(layer_name)
    # Convert the models to their quantized form
for model in models: