Skip to content

Commit

Permalink
Update VPRTempoQuant for v1.1.6
Browse files Browse the repository at this point in the history
  • Loading branch information
AdamDHines committed Aug 15, 2024
1 parent c65f293 commit a2a1052
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 31 deletions.
36 changes: 9 additions & 27 deletions vprtempo/VPRTempoQuant.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(self, args, dims, logger, num_modules, output_folder, out_dim, out_
self.dequant = DeQuantStub()

# Set the dataset file
self.dataset_file = os.path.join('./vprtempo/dataset', self.dataset + '.csv')
self.dataset_file = os.path.join('./vprtempo/dataset', f'{self.dataset}-{self.query_dir}' + '.csv')
self.query_dir = [dir.strip() for dir in self.query_dir.split(',')]

# Layer dict to keep track of layer names and their order
Expand Down Expand Up @@ -133,9 +133,7 @@ def evaluate(self, models, test_loader, layers=None):
for model in models:
self.inferences.append(nn.Sequential(
model.feature_layer.w,
nn.Hardtanh(0, maxSpike),
model.output_layer.w,
nn.Hardtanh(0, maxSpike)
))
# Initialize the tqdm progress bar
pbar = tqdm(total=self.query_places,
Expand Down Expand Up @@ -164,6 +162,8 @@ def evaluate(self, models, test_loader, layers=None):
# Create GT matrix
GT = np.zeros((model.query_places,model.database_places), dtype=int)
for n, ndx in enumerate(labels):
if model.skip != 0 and not model.query_places < model.database_places:
ndx = ndx - model.skip
if model.filter !=1:
ndx = ndx//model.filter
GT[n,ndx] = 1
Expand All @@ -172,6 +172,8 @@ def evaluate(self, models, test_loader, layers=None):
if model.GT_tolerance > 0:
GTsoft = np.zeros((model.query_places,model.database_places), dtype=int)
for n, ndx in enumerate(labels):
if model.skip != 0 and not model.query_places < model.database_places:
ndx = ndx - model.skip
if model.filter !=1:
ndx = ndx//model.filter
GTsoft[n, ndx] = 1
Expand Down Expand Up @@ -282,39 +284,19 @@ def run_inference_quant(models, model_name):
model = models[0]
# Initialize the image transforms
image_transform = ProcessImage(model.dims, model.patches)

# Determines if querying a subset of the database or the entire database
if model.query_places == model.database_places:
subset = False # Entire database
elif model.query_places < model.database_places:
subset = True # Subset of the database
else:
raise ValueError("The number of query places must be less than or equal to the number of database places.")

# Initialize the test dataset
test_dataset = CustomImageDataset(annotations_file=model.dataset_file,
base_dir=model.data_dir,
img_dirs=model.query_dir,
transform=image_transform,
max_samples=model.database_places,
skip=model.filter)
max_samples=model.query_places,
filter=model.filter,
skip=model.skip)

# If using a subset of the database
if subset:
if model.shuffle: # For a randomized selection of database places
test_dataset = Subset(test_dataset, random.sample(range(len(test_dataset)), model.query_places))
else: # For a sequential selection of database places
indices = [i for i in range(model.database_places) if i % model.filter == 0]
# Limit to the desired number of queries
indices = indices[:model.query_places]
# Create the subset
test_dataset = Subset(test_dataset, indices)


# Initialize the data loader
test_loader = DataLoader(test_dataset,
batch_size=1,
shuffle=model.shuffle,
batch_size=1,
num_workers=8,
persistent_workers=True)

Expand Down
13 changes: 10 additions & 3 deletions vprtempo/VPRTempoQuantTrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,13 @@ def __init__(self, args, dims, logger, num_modules, out_dim, out_dim_remainder=N
self.dequant = DeQuantStub()

# Set the dataset file
self.dataset_file = os.path.join('./vprtempo/dataset', self.dataset + '.csv')
fields = self.database_dirs.split(',')
if len(fields) > 1:
self.dataset_file = []
for field in fields:
self.dataset_file.append(os.path.join('./vprtempo/dataset', f'{self.dataset}-{field}' + '.csv'))
else:
self.dataset_file = os.path.join('./vprtempo/dataset', f'{self.dataset}-{self.database_dirs}' + '.csv')

# Layer dict to keep track of layer names and their order
self.layer_dict = {}
Expand Down Expand Up @@ -96,7 +102,7 @@ def __init__(self, args, dims, logger, num_modules, out_dim, out_dim_remainder=N
ip_rate=0.15,
stdp_rate=0.005,
spk_force=True,
p=[0.25, 0.75],
p=[1.0, 1.0],
device=self.device
)

Expand Down Expand Up @@ -275,7 +281,8 @@ def train_new_model_quant(models, model_name):
base_dir=model.data_dir,
img_dirs=model.database_dirs,
transform=image_transform,
skip=model.filter,
filter=models[0].filter,
skip=models[0].skip,
test=False,
img_range=img_range,
max_samples=max_samples)
Expand Down
Binary file not shown.
Binary file not shown.
2 changes: 1 addition & 1 deletion vprtempo/src/loggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def model_logger_quant():
logger.info(' ╚████╔╝ ██║ ██║ ██║ ██║ ███████╗██║ ╚═╝ ██║██║ ╚██████╔╝ ╚██████╔╝╚██████╔╝██║ ██║██║ ╚████║ ██║')
logger.info(' ╚═══╝ ╚═╝ ╚═╝ ╚═╝ ╚═╝ ╚══════╝╚═╝ ╚═╝╚═╝ ╚═════╝ ╚══▀▀═╝ ╚═════╝ ╚═╝ ╚═╝╚═╝ ╚═══╝ ╚═╝')
logger.info('-----------------------------------------------------------------------')
logger.info('Temporally Encoded Spiking Neural Network for Visual Place Recognition v1.1.0')
logger.info('Temporally Encoded Spiking Neural Network for Visual Place Recognition v1.1.6')
logger.info('Queensland University of Technology, Centre for Robotics')
logger.info('')
logger.info('© 2023 Adam D Hines, Peter G Stratton, Michael Milford, Tobias Fischer')
Expand Down

0 comments on commit a2a1052

Please sign in to comment.