diff --git a/models/bert/model.py b/models/bert/model.py
index 5456fbd..f2519a0 100644
--- a/models/bert/model.py
+++ b/models/bert/model.py
@@ -33,13 +33,12 @@
 logger = logging.getLogger(__name__)
 
 PRETRAINED_MODEL_ARCHIVE_MAP = {
-    'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz",
-    'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz",
-    'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz",
-    'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz",
-    'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz",
-    'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz",
-    'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz",
+    'bert-base-uncased': os.path.join(os.pardir, 'hedwig-data', 'models', 'bert_pretrained', 'bert-base-uncased.tar.gz'),
+    'bert-large-uncased': os.path.join(os.pardir, 'hedwig-data', 'models', 'bert_pretrained', 'bert-large-uncased.tar.gz'),
+    'bert-base-cased': os.path.join(os.pardir, 'hedwig-data', 'models', 'bert_pretrained', 'bert-base-cased.tar.gz'),
+    'bert-large-cased': os.path.join(os.pardir, 'hedwig-data', 'models', 'bert_pretrained', 'bert-large-cased.tar.gz'),
+    'bert-base-multilingual-uncased': os.path.join(os.pardir, 'hedwig-data', 'models', 'bert_pretrained', 'bert-base-multilingual-uncased.tar.gz'),
+    'bert-base-multilingual-cased': os.path.join(os.pardir, 'hedwig-data', 'models', 'bert_pretrained', 'bert-base-multilingual-cased.tar.gz')
 }
 CONFIG_NAME = 'bert_config.json'
 WEIGHTS_NAME = 'pytorch_model.bin'
diff --git a/utils/io.py b/utils/io.py
index 190c625..d531a50 100644
--- a/utils/io.py
+++ b/utils/io.py
@@ -15,21 +15,11 @@
 
 from __future__ import absolute_import, division, print_function, unicode_literals
 
-import json
 import logging
 import os
-import shutil
 import sys
-import tempfile
-from functools import wraps
-from hashlib import sha256
 from io import open
 
-import boto3
-import requests
-from botocore.exceptions import ClientError
-from tqdm import tqdm
-
 from urllib.parse import urlparse
 
 try:
@@ -43,50 +33,6 @@
 logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
 
-def url_to_filename(url, etag=None):
-    """
-    Convert `url` into a hashed filename in a repeatable way.
-    If `etag` is specified, append its hash to the url's, delimited
-    by a period.
-    """
-    url_bytes = url.encode('utf-8')
-    url_hash = sha256(url_bytes)
-    filename = url_hash.hexdigest()
-
-    if etag:
-        etag_bytes = etag.encode('utf-8')
-        etag_hash = sha256(etag_bytes)
-        filename += '.' + etag_hash.hexdigest()
-
-    return filename
-
-
-def filename_to_url(filename, cache_dir=None):
-    """
-    Return the url and etag (which may be ``None``) stored for `filename`.
-    Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
-    """
-    if cache_dir is None:
-        cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
-    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
-        cache_dir = str(cache_dir)
-
-    cache_path = os.path.join(cache_dir, filename)
-    if not os.path.exists(cache_path):
-        raise EnvironmentError("file {} not found".format(cache_path))
-
-    meta_path = cache_path + '.json'
-    if not os.path.exists(meta_path):
-        raise EnvironmentError("file {} not found".format(meta_path))
-
-    with open(meta_path, encoding="utf-8") as meta_file:
-        metadata = json.load(meta_file)
-    url = metadata['url']
-    etag = metadata['etag']
-
-    return url, etag
-
-
 def cached_path(url_or_filename, cache_dir=None):
     """
     Given something that might be a URL (or might be a local path),
@@ -103,10 +49,7 @@ def cached_path(url_or_filename, cache_dir=None):
 
     parsed = urlparse(url_or_filename)
 
-    if parsed.scheme in ('http', 'https', 's3'):
-        # URL, so get it from the cache (downloading if necessary)
-        return get_from_cache(url_or_filename, cache_dir)
-    elif os.path.exists(url_or_filename):
+    if os.path.exists(url_or_filename):
         # File, and it exists.
         return url_or_filename
     elif parsed.scheme == '':
@@ -117,127 +60,6 @@ def cached_path(url_or_filename, cache_dir=None):
         raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
 
 
-def split_s3_path(url):
-    """Split a full s3 path into the bucket name and path."""
-    parsed = urlparse(url)
-    if not parsed.netloc or not parsed.path:
-        raise ValueError("bad s3 path {}".format(url))
-    bucket_name = parsed.netloc
-    s3_path = parsed.path
-    # Remove '/' at beginning of path.
-    if s3_path.startswith("/"):
-        s3_path = s3_path[1:]
-    return bucket_name, s3_path
-
-
-def s3_request(func):
-    """
-    Wrapper function for s3 requests in order to create more helpful error
-    messages.
-    """
-
-    @wraps(func)
-    def wrapper(url, *args, **kwargs):
-        try:
-            return func(url, *args, **kwargs)
-        except ClientError as exc:
-            if int(exc.response["Error"]["Code"]) == 404:
-                raise EnvironmentError("file {} not found".format(url))
-            else:
-                raise
-
-    return wrapper
-
-
-@s3_request
-def s3_etag(url):
-    """Check ETag on S3 object."""
-    s3_resource = boto3.resource("s3")
-    bucket_name, s3_path = split_s3_path(url)
-    s3_object = s3_resource.Object(bucket_name, s3_path)
-    return s3_object.e_tag
-
-
-@s3_request
-def s3_get(url, temp_file):
-    """Pull a file directly from S3."""
-    s3_resource = boto3.resource("s3")
-    bucket_name, s3_path = split_s3_path(url)
-    s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
-
-
-def http_get(url, temp_file):
-    req = requests.get(url, stream=True)
-    content_length = req.headers.get('Content-Length')
-    total = int(content_length) if content_length is not None else None
-    progress = tqdm(unit="B", total=total)
-    for chunk in req.iter_content(chunk_size=1024):
-        if chunk:  # filter out keep-alive new chunks
-            progress.update(len(chunk))
-            temp_file.write(chunk)
-    progress.close()
-
-
-def get_from_cache(url, cache_dir=None):
-    """
-    Given a URL, look for the corresponding dataset in the local cache.
-    If it's not there, download it. Then return the path to the cached file.
-    """
-    if cache_dir is None:
-        cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
-    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
-        cache_dir = str(cache_dir)
-
-    if not os.path.exists(cache_dir):
-        os.makedirs(cache_dir)
-
-    # Get eTag to add to filename, if it exists.
-    if url.startswith("s3://"):
-        etag = s3_etag(url)
-    else:
-        response = requests.head(url, allow_redirects=True)
-        if response.status_code != 200:
-            raise IOError("HEAD request failed for url {} with status code {}"
-                          .format(url, response.status_code))
-        etag = response.headers.get("ETag")
-
-    filename = url_to_filename(url, etag)
-
-    # get cache path to put the file
-    cache_path = os.path.join(cache_dir, filename)
-
-    if not os.path.exists(cache_path):
-        # Download to temporary file, then copy to cache dir once finished.
-        # Otherwise you get corrupt cache entries if the download gets interrupted.
-        with tempfile.NamedTemporaryFile() as temp_file:
-            logger.info("%s not found in cache, downloading to %s", url, temp_file.name)
-
-            # GET file object
-            if url.startswith("s3://"):
-                s3_get(url, temp_file)
-            else:
-                http_get(url, temp_file)
-
-            # we are copying the file before closing it, so flush to avoid truncation
-            temp_file.flush()
-            # shutil.copyfileobj() starts at the current position, so go to the start
-            temp_file.seek(0)
-
-            logger.info("copying %s to cache at %s", temp_file.name, cache_path)
-            with open(cache_path, 'wb') as cache_file:
-                shutil.copyfileobj(temp_file, cache_file)
-
-            logger.info("creating metadata file for %s", cache_path)
-            meta = {'url': url, 'etag': etag}
-            meta_path = cache_path + '.json'
-            with open(meta_path, 'w', encoding="utf-8") as meta_file:
-                json.dump(meta, meta_file)
-
-            logger.info("removing temp file %s", temp_file.name)
-
-    return cache_path
-
-
 def read_set_from_file(filename):
     """
     Extract a de-duped collection (set) of text from a file.
diff --git a/utils/tokenization.py b/utils/tokenization.py
index 8761998..d589a2d 100644
--- a/utils/tokenization.py
+++ b/utils/tokenization.py
@@ -25,13 +25,12 @@
 logger = logging.getLogger(__name__)
 
 PRETRAINED_VOCAB_ARCHIVE_MAP = {
-    'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
-    'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
-    'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt",
-    'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt",
-    'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt",
-    'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
-    'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
+    'bert-base-uncased': os.path.join(os.pardir, 'hedwig-data', 'models', 'bert_pretrained', 'bert-base-uncased-vocab.txt'),
+    'bert-large-uncased': os.path.join(os.pardir, 'hedwig-data', 'models', 'bert_pretrained', 'bert-large-uncased-vocab.txt'),
+    'bert-base-cased': os.path.join(os.pardir, 'hedwig-data', 'models', 'bert_pretrained', 'bert-base-cased-vocab.txt'),
+    'bert-large-cased': os.path.join(os.pardir, 'hedwig-data', 'models', 'bert_pretrained', 'bert-large-cased-vocab.txt'),
+    'bert-base-multilingual-uncased': os.path.join(os.pardir, 'hedwig-data', 'models', 'bert_pretrained', 'bert-base-multilingual-uncased-vocab.txt'),
+    'bert-base-multilingual-cased': os.path.join(os.pardir, 'hedwig-data', 'models', 'bert_pretrained', 'bert-base-multilingual-cased-vocab.txt')
 }
 PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
     'bert-base-uncased': 512,
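For reference, a minimal sketch (not part of the diff) of how resolution behaves after this change. The `hedwig-data` layout is taken from the mappings above; the condensed `cached_path` below elides the `urlparse` branching of the real function, and the driver snippet is hypothetical:

```python
import os

# Layout assumed by the new mappings, resolved relative to the working directory:
#   ../hedwig-data/models/bert_pretrained/bert-base-uncased.tar.gz
#   ../hedwig-data/models/bert_pretrained/bert-base-uncased-vocab.txt
ARCHIVE = os.path.join(os.pardir, 'hedwig-data', 'models',
                       'bert_pretrained', 'bert-base-uncased.tar.gz')

def cached_path(url_or_filename, cache_dir=None):
    # Post-diff behavior, condensed: an existing local path is returned
    # unchanged; anything else is an error. There is no longer any
    # download-and-cache fallback for http/https/s3 URLs.
    if os.path.exists(url_or_filename):
        return url_or_filename
    raise EnvironmentError("file {} not found".format(url_or_filename))

try:
    print(cached_path(ARCHIVE))
except EnvironmentError as e:
    print(e)  # raised when hedwig-data is not checked out next to the repo
```

Note that `os.pardir` makes these entries relative to the current working directory, so the change assumes scripts are launched from the repository root with `hedwig-data` sitting alongside it.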