diff --git a/examples/.config/model_params_onnxrt.json b/examples/.config/model_params_onnxrt.json index 5db34a114..085c7ef6c 100644 --- a/examples/.config/model_params_onnxrt.json +++ b/examples/.config/model_params_onnxrt.json @@ -55,6 +55,34 @@ "input_model": "/tf_dataset2/models/onnx/Llama-2-7b-hf-with-past", "main_script": "main.py", "batch_size": 1 - } + }, + "bert_base_MRPC": { + "model_src_dir": "nlp/bert/quantization/ptq_static", + "dataset_location": "/tf_dataset/pytorch/glue_data/MRPC", + "input_model": "/tf_dataset2/models/onnx/bert_base_MRPC/bert.onnx", + "main_script": "main.py", + "batch_size": 8 + }, + "bert_base_MRPC_dynamic": { + "model_src_dir": "nlp/bert/quantization/ptq_dynamic", + "dataset_location": "/tf_dataset/pytorch/glue_data/MRPC", + "input_model": "/tf_dataset2/models/onnx/bert_base_MRPC/bert.onnx", + "main_script": "main.py", + "batch_size": 8 + }, + "resnet50-v1-12_qdq": { + "model_src_dir": "image_recognition/resnet50/quantization/ptq_static", + "dataset_location": "/tf_dataset2/datasets/imagenet/ImagenetRaw/ILSVRC2012_img_val", + "input_model": "/tf_dataset2/models/onnx/resnet50-v1-12/resnet50-v1-13.onnx", + "main_script": "main.py", + "batch_size": 1 + }, + "resnet50-v1-12": { + "model_src_dir": "image_recognition/resnet50/quantization/ptq_static", + "dataset_location": "/tf_dataset2/datasets/imagenet/ImagenetRaw/ILSVRC2012_img_val", + "input_model": "/tf_dataset2/models/onnx/resnet50-v1-12/resnet50-v1-12.onnx", + "main_script": "main.py", + "batch_size": 1 + }, } } diff --git a/examples/image_recognition/resnet50/quantization/ptq_static/main.py b/examples/image_recognition/resnet50/quantization/ptq_static/main.py index 4218769ca..232a12912 100644 --- a/examples/image_recognition/resnet50/quantization/ptq_static/main.py +++ b/examples/image_recognition/resnet50/quantization/ptq_static/main.py @@ -119,80 +119,6 @@ def result(self): return 0 return self.num_correct / self.num_sample -class Dataloader: - def __init__(self, dataset_location, image_list, batch_size): - self.batch_size = batch_size - self.image_list = [] - self.label_list = [] - with open(image_list, 'r') as f: - for s in f: - image_name, label = re.split(r"\s+", s.strip()) - src = os.path.join(dataset_location, image_name) - if not os.path.exists(src): - continue - - self.image_list.append(src) - self.label_list.append(int(label)) - - def _preprpcess(self, src): - with Image.open(src) as image: - image = np.array(image.convert('RGB')).astype(np.float32) - image = image / 255. - image = cv2.resize(image, (256, 256), interpolation=cv2.INTER_LINEAR) - - h, w = image.shape[0], image.shape[1] - - y0 = (h - 224) // 2 - x0 = (w - 224) // 2 - image = image[y0:y0 + 224, x0:x0 + 224, :] - image = (image - [0.485, 0.456, 0.406]) / [0.229, 0.224, 0.225] - image = image.transpose((2, 0, 1)) - return image.astype('float32') - - def __iter__(self): - return self._generate_dataloader() - - def _generate_dataloader(self): - sampler = iter(range(0, len(self.image_list), 1)) - - def collate(batch): - """Puts each data field into a pd frame with outer dimension batch size""" - elem = batch[0] - if isinstance(elem, collections.abc.Mapping): - return {key: collate([d[key] for d in batch]) for key in elem} - elif isinstance(elem, collections.abc.Sequence): - batch = zip(*batch) - return [collate(samples) for samples in batch] - elif isinstance(elem, np.ndarray): - try: - return np.stack(batch) - except: - return batch - else: - return batch - - def batch_sampler(): - batch = [] - for idx in sampler: - batch.append(idx) - if len(batch) == self.batch_size: - yield batch - batch = [] - if len(batch) > 0: - yield batch - - def fetcher(ids): - data = [self._preprpcess(self.image_list[idx]) for idx in ids] - label = [self.label_list[idx] for idx in ids] - return collate(data), label - - for batched_indices in batch_sampler(): - try: - data = fetcher(batched_indices) - yield data - except StopIteration: - return - class DataReader(data_reader.CalibrationDataReader): def __init__(self, model_path, dataset_location, image_list, batch_size=1, calibration_sampling_size=-1):