Skip to content

Commit

Permalink
Fixed an issue in the quant model with training layers across modules
Browse files Browse the repository at this point in the history
  • Loading branch information
AdamDHines committed Jan 2, 2024
1 parent 5832b77 commit e958cda
Showing 1 changed file with 5 additions and 8 deletions.
13 changes: 5 additions & 8 deletions vprtempo/VPRTempoQuantTrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def train_model(self, train_loader, layer, model, model_num, prev_layers=None):
# idx scale factor for different modules
idx_scale = (self.max_module*self.filter)*model_num
# Run training for the specified number of epochs
for epoch in range(self.epoch):
for _ in range(self.epoch):
# Run training for the specified number of timesteps
for spikes, labels in train_loader:
spikes, labels = spikes.to(self.device), labels.to(self.device)
Expand Down Expand Up @@ -245,15 +245,12 @@ def train_new_model_quant(models, model_name):
# Automatically generate user_input_ranges
user_input_ranges = []
start_idx = 0

# Generate the image ranges for each module
for _ in range(model.num_modules):
range_temp = [start_idx, start_idx+((model.max_module-1)*model.filter)]
user_input_ranges.append(range_temp)
start_idx = range_temp[1] + models[0].filter
if model.query_places < model.max_module:
max_samples=model.query_places
else:
max_samples = model.max_module
start_idx = range_temp[1] + model.filter

# Keep track of trained layers to pass data through them
trained_layers = []

Expand Down Expand Up @@ -290,7 +287,7 @@ def train_new_model_quant(models, model_name):
persistent_workers=True)
# Train the layers
model.train_model(train_loader, layer, model, i, prev_layers=trained_layers)
trained_layers.append(layer_name)
trained_layers.append(layer_name)

# Convert the model to evaluation mode
for model in models:
Expand Down

0 comments on commit e958cda

Please sign in to comment.