From dfbfabc82fcc4ed80011fdb5d963a6f7706a4771 Mon Sep 17 00:00:00 2001 From: Alex Bostock Date: Tue, 20 Feb 2024 19:22:17 +0000 Subject: [PATCH] feat!: validate MD5 checksums on upload and export (#165) * feat!: validate MD5 checksums on upload and export * Remove print * Refactor file handling for style * Fix style / lint issues * Format to style guide * Refactor for linter / cognitive complexity rule --- scribemi/ScribeMi.py | 31 ++++++++++++++++++++++++------- 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/scribemi/ScribeMi.py b/scribemi/ScribeMi.py index c9491cd..5630dce 100644 --- a/scribemi/ScribeMi.py +++ b/scribemi/ScribeMi.py @@ -6,6 +6,8 @@ from aws_requests_auth.aws_auth import AWSRequestsAuth from datetime import datetime from typing_extensions import TypedDict, Optional, List +from hashlib import md5 +from base64 import b64encode class Env(TypedDict): @@ -416,6 +418,8 @@ def fetch_model(self, task: MITask): ) res = requests.get(modelUrl) if res.status_code == 200: + verify_etag_checksum(res) + return json.loads(res.text) elif res.status_code == 401 or res.status_code == 403: raise UnauthenticatedException( @@ -462,14 +466,20 @@ def submit_task( if isinstance(file_or_filename, str) and params.get("filename") == None: params["filename"] = file_or_filename - post_res = self.call_endpoint("POST", "/tasks", params) - put_url = post_res["url"] - if isinstance(file_or_filename, str): with open(file_or_filename, "rb") as file: - upload_file(file, put_url) + file_content = file.read() else: - return upload_file(file_or_filename, put_url) + file_content = file_or_filename.read() + + hash = md5(file_content, usedforsecurity=False) + md5checksum = b64encode(hash.digest()).decode() + params["md5checksum"] = md5checksum + + post_res = self.call_endpoint("POST", "/tasks", params) + put_url = post_res["url"] + + upload_file(file_content, md5checksum, put_url) return post_res["jobid"] @@ -486,7 +496,14 @@ def delete_task(self, task: MITask): return self.call_endpoint("DELETE", "/tasks/{}".format(task["jobid"])) -def upload_file(file, url): - res = requests.put(url, data=file) +def upload_file(file, md5checksum, url): + res = requests.put(url, data=file, headers={"Content-MD5": md5checksum}) if res.status_code != 200: raise Exception("Error uploading file: {}".format(res.status_code)) + + +def verify_etag_checksum(res: requests.Response): + md5checksum_expected = res.headers["ETag"].replace('"', "") + md5checksum = md5(res.text.encode(), usedforsecurity=False).hexdigest() + if md5checksum != md5checksum_expected: + raise Exception("Integrity Error: invalid checksum. Please retry.")