Remove model caching mechanism for bert and hbert #42

Merged
merged 32 commits on Nov 1, 2019
Changes from all commits
32 commits
2b7923e
Integrate BERT into Hedwig (#29)
achyudh Apr 14, 2019
cb14201
Resolve conflicts in the dev fork
achyudh Apr 14, 2019
8346514
Merge branch 'karkaroff-master'
achyudh Apr 14, 2019
fff8e0a
Resolve merge conflicts in README.md
achyudh Apr 14, 2019
0979f77
Add TREC relevance datasets
achyudh Apr 19, 2019
e5f2ee0
Add relevance transfer trainer and evaluator
achyudh Apr 19, 2019
57f0680
Add re-ranking module
achyudh Apr 19, 2019
7d26d71
Add ImbalancedDatasetSampler
achyudh Apr 19, 2019
eab4fc2
Add relevance transfer package
achyudh Apr 19, 2019
a08b2d1
Fix import in classification trainer
achyudh Apr 19, 2019
cb3ca31
Merge remote-tracking branch 'castorini/master'
achyudh Apr 19, 2019
0890eae
Remove unwanted args from models/bert
achyudh Apr 29, 2019
a8de77c
Merge remote-tracking branch 'castorini/master'
achyudh Apr 29, 2019
1116c64
Fix bug where model wasn't in training mode every epoch
achyudh May 2, 2019
8c36691
Merge remote-tracking branch 'castorini/master'
achyudh May 2, 2019
0f34aa0
Add Robust45 preprocessor for BERT
achyudh May 5, 2019
7bed0f1
Add support for BERT for relevance transfer
achyudh May 5, 2019
6c8c728
Add hierarchical BERT model
achyudh Jul 3, 2019
615fa27
Remove tensorboardX logging
achyudh Jul 7, 2019
b40cccb
Add hierarchical BERT for relevance transfer
achyudh Jul 7, 2019
70ec667
Merge remote-tracking branch 'castorini/master'
achyudh Jul 7, 2019
1b031a8
Add learning rate multiplier
achyudh Sep 1, 2019
a987e2c
Merge branch 'master' of github.com:castorini/hedwig
achyudh Sep 1, 2019
e81cfff
Add lr multiplier for relevance transfer
achyudh Sep 2, 2019
4758607
Add MLP model
achyudh Sep 7, 2019
289cde0
Add fastText model
achyudh Sep 8, 2019
12a09da
Add Reuters bag-of-words dataset class
achyudh Sep 8, 2019
bcf1dca
Add input dropout for MLP
achyudh Sep 8, 2019
7aeded5
Merge branch 'master' of github.com:castorini/hedwig
achyudh Sep 8, 2019
448b087
Remove duplicate README files
achyudh Sep 8, 2019
71a2df3
Remove model caching mechanism for bert and hbert
achyudh Nov 1, 2019
7899780
Merge branch 'master' of github.com:castorini/hedwig
achyudh Nov 1, 2019
13 changes: 6 additions & 7 deletions models/bert/model.py
@@ -33,13 +33,12 @@
logger = logging.getLogger(__name__)

PRETRAINED_MODEL_ARCHIVE_MAP = {
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz",
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz",
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz",
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz",
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz",
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz",
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz",
'bert-base-uncased': os.path.join(os.pardir, 'hedwig-data', 'models', 'bert_pretrained', 'bert-base-uncased.tar.gz'),
'bert-large-uncased': os.path.join(os.pardir, 'hedwig-data', 'models', 'bert_pretrained', 'bert-large-uncased.tar.gz'),
'bert-base-cased': os.path.join(os.pardir, 'hedwig-data', 'models', 'bert_pretrained', 'bert-base-cased.tar.gz'),
'bert-large-cased': os.path.join(os.pardir, 'hedwig-data', 'models', 'bert_pretrained', 'bert-large-cased.tar.gz'),
'bert-base-multilingual-uncased': os.path.join(os.pardir, 'hedwig-data', 'models', 'bert_pretrained', 'bert-base-multilingual-uncased.tar.gz'),
'bert-base-multilingual-cased': os.path.join(os.pardir, 'hedwig-data', 'models', 'bert_pretrained', 'bert-base-multilingual-cased.tar.gz')
}
CONFIG_NAME = 'bert_config.json'
WEIGHTS_NAME = 'pytorch_model.bin'
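With this change, PRETRAINED_MODEL_ARCHIVE_MAP points each model name at a tarball expected to sit under ../hedwig-data/models/bert_pretrained relative to the working directory, instead of at an S3 URL. A minimal sketch of how a name now resolves to a local archive; the resolve_pretrained helper and the __main__ check below are illustrative and not part of the diff:

import os

# Mirrors one entry of the new PRETRAINED_MODEL_ARCHIVE_MAP: a local tarball
# under ../hedwig-data/models/bert_pretrained rather than an S3 URL.
PRETRAINED_MODEL_ARCHIVE_MAP = {
    'bert-base-uncased': os.path.join(os.pardir, 'hedwig-data', 'models',
                                      'bert_pretrained', 'bert-base-uncased.tar.gz'),
}

def resolve_pretrained(model_name):
    # Hypothetical helper: look up the archive and fail fast if hedwig-data is missing.
    archive = PRETRAINED_MODEL_ARCHIVE_MAP[model_name]
    if not os.path.exists(archive):
        raise EnvironmentError("pretrained archive {} not found; "
                               "fetch hedwig-data first".format(archive))
    return archive

if __name__ == '__main__':
    print(resolve_pretrained('bert-base-uncased'))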
180 changes: 1 addition & 179 deletions utils/io.py
@@ -15,21 +15,11 @@

from __future__ import absolute_import, division, print_function, unicode_literals

import json
import logging
import os
import shutil
import sys
import tempfile
from functools import wraps
from hashlib import sha256
from io import open

import boto3
import requests
from botocore.exceptions import ClientError
from tqdm import tqdm

from urllib.parse import urlparse

try:
@@ -43,50 +33,6 @@
logger = logging.getLogger(__name__) # pylint: disable=invalid-name


def url_to_filename(url, etag=None):
"""
Convert `url` into a hashed filename in a repeatable way.
If `etag` is specified, append its hash to the url's, delimited
by a period.
"""
url_bytes = url.encode('utf-8')
url_hash = sha256(url_bytes)
filename = url_hash.hexdigest()

if etag:
etag_bytes = etag.encode('utf-8')
etag_hash = sha256(etag_bytes)
filename += '.' + etag_hash.hexdigest()

return filename


def filename_to_url(filename, cache_dir=None):
"""
Return the url and etag (which may be ``None``) stored for `filename`.
Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
"""
if cache_dir is None:
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
cache_dir = str(cache_dir)

cache_path = os.path.join(cache_dir, filename)
if not os.path.exists(cache_path):
raise EnvironmentError("file {} not found".format(cache_path))

meta_path = cache_path + '.json'
if not os.path.exists(meta_path):
raise EnvironmentError("file {} not found".format(meta_path))

with open(meta_path, encoding="utf-8") as meta_file:
metadata = json.load(meta_file)
url = metadata['url']
etag = metadata['etag']

return url, etag


def cached_path(url_or_filename, cache_dir=None):
"""
Given something that might be a URL (or might be a local path),
@@ -103,10 +49,7 @@ def cached_path(url_or_filename, cache_dir=None):

parsed = urlparse(url_or_filename)

if parsed.scheme in ('http', 'https', 's3'):
# URL, so get it from the cache (downloading if necessary)
return get_from_cache(url_or_filename, cache_dir)
elif os.path.exists(url_or_filename):
if os.path.exists(url_or_filename):
# File, and it exists.
return url_or_filename
elif parsed.scheme == '':
@@ -117,127 +60,6 @@
raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))


def split_s3_path(url):
"""Split a full s3 path into the bucket name and path."""
parsed = urlparse(url)
if not parsed.netloc or not parsed.path:
raise ValueError("bad s3 path {}".format(url))
bucket_name = parsed.netloc
s3_path = parsed.path
# Remove '/' at beginning of path.
if s3_path.startswith("/"):
s3_path = s3_path[1:]
return bucket_name, s3_path


def s3_request(func):
"""
Wrapper function for s3 requests in order to create more helpful error
messages.
"""

@wraps(func)
def wrapper(url, *args, **kwargs):
try:
return func(url, *args, **kwargs)
except ClientError as exc:
if int(exc.response["Error"]["Code"]) == 404:
raise EnvironmentError("file {} not found".format(url))
else:
raise

return wrapper


@s3_request
def s3_etag(url):
"""Check ETag on S3 object."""
s3_resource = boto3.resource("s3")
bucket_name, s3_path = split_s3_path(url)
s3_object = s3_resource.Object(bucket_name, s3_path)
return s3_object.e_tag


@s3_request
def s3_get(url, temp_file):
"""Pull a file directly from S3."""
s3_resource = boto3.resource("s3")
bucket_name, s3_path = split_s3_path(url)
s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)


def http_get(url, temp_file):
req = requests.get(url, stream=True)
content_length = req.headers.get('Content-Length')
total = int(content_length) if content_length is not None else None
progress = tqdm(unit="B", total=total)
for chunk in req.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks
progress.update(len(chunk))
temp_file.write(chunk)
progress.close()


def get_from_cache(url, cache_dir=None):
"""
Given a URL, look for the corresponding dataset in the local cache.
If it's not there, download it. Then return the path to the cached file.
"""
if cache_dir is None:
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
cache_dir = str(cache_dir)

if not os.path.exists(cache_dir):
os.makedirs(cache_dir)

# Get eTag to add to filename, if it exists.
if url.startswith("s3://"):
etag = s3_etag(url)
else:
response = requests.head(url, allow_redirects=True)
if response.status_code != 200:
raise IOError("HEAD request failed for url {} with status code {}"
.format(url, response.status_code))
etag = response.headers.get("ETag")

filename = url_to_filename(url, etag)

# get cache path to put the file
cache_path = os.path.join(cache_dir, filename)

if not os.path.exists(cache_path):
# Download to temporary file, then copy to cache dir once finished.
# Otherwise you get corrupt cache entries if the download gets interrupted.
with tempfile.NamedTemporaryFile() as temp_file:
logger.info("%s not found in cache, downloading to %s", url, temp_file.name)

# GET file object
if url.startswith("s3://"):
s3_get(url, temp_file)
else:
http_get(url, temp_file)

# we are copying the file before closing it, so flush to avoid truncation
temp_file.flush()
# shutil.copyfileobj() starts at the current position, so go to the start
temp_file.seek(0)

logger.info("copying %s to cache at %s", temp_file.name, cache_path)
with open(cache_path, 'wb') as cache_file:
shutil.copyfileobj(temp_file, cache_file)

logger.info("creating metadata file for %s", cache_path)
meta = {'url': url, 'etag': etag}
meta_path = cache_path + '.json'
with open(meta_path, 'w', encoding="utf-8") as meta_file:
json.dump(meta, meta_file)

logger.info("removing temp file %s", temp_file.name)

return cache_path


def read_set_from_file(filename):
"""
Extract a de-duped collection (set) of text from a file.
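With the downloading and S3 helpers gone, cached_path is reduced to resolving paths that already exist on disk. A rough sketch of the surviving behaviour, pieced together from the lines visible in this diff; the not-found branch follows the upstream pytorch-pretrained-bert layout and is an assumption here, not a verbatim copy of the file:

import os
from urllib.parse import urlparse

def cached_path(url_or_filename, cache_dir=None):
    # cache_dir is kept in the signature for compatibility; its handling is elided here.
    parsed = urlparse(url_or_filename)
    if os.path.exists(url_or_filename):
        # File, and it exists.
        return url_or_filename
    elif parsed.scheme == '':
        # File, but it doesn't exist.
        raise EnvironmentError("file {} not found".format(url_or_filename))
    else:
        # Neither an existing file nor a bare path.
        raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))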
13 changes: 6 additions & 7 deletions utils/tokenization.py
@@ -25,13 +25,12 @@
logger = logging.getLogger(__name__)

PRETRAINED_VOCAB_ARCHIVE_MAP = {
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt",
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt",
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt",
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
'bert-base-uncased': os.path.join(os.pardir, 'hedwig-data', 'models', 'bert_pretrained', 'bert-base-uncased-vocab.txt'),
'bert-large-uncased': os.path.join(os.pardir, 'hedwig-data', 'models', 'bert_pretrained', 'bert-large-uncased-vocab.txt'),
'bert-base-cased': os.path.join(os.pardir, 'hedwig-data', 'models', 'bert_pretrained', 'bert-base-cased-vocab.txt'),
'bert-large-cased': os.path.join(os.pardir, 'hedwig-data', 'models', 'bert_pretrained', 'bert-large-cased-vocab.txt'),
'bert-base-multilingual-uncased': os.path.join(os.pardir, 'hedwig-data', 'models', 'bert_pretrained', 'bert-base-multilingual-uncased-vocab.txt'),
'bert-base-multilingual-cased': os.path.join(os.pardir, 'hedwig-data', 'models', 'bert_pretrained', 'bert-base-multilingual-cased-vocab.txt')
}
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
'bert-base-uncased': 512,
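Both maps now assume the process is launched from a directory whose parent contains hedwig-data, since every path starts with os.pardir. A purely illustrative pre-flight check for that layout; the helper name and the __main__ guard are not part of the PR:

import os

BERT_PRETRAINED_DIR = os.path.join(os.pardir, 'hedwig-data', 'models', 'bert_pretrained')

def check_bert_assets(model_name='bert-base-uncased'):
    # Hypothetical sanity check: fail early if the local archive or vocab file is missing.
    expected = [
        os.path.join(BERT_PRETRAINED_DIR, model_name + '.tar.gz'),
        os.path.join(BERT_PRETRAINED_DIR, model_name + '-vocab.txt'),
    ]
    missing = [path for path in expected if not os.path.exists(path)]
    if missing:
        raise EnvironmentError("missing pretrained BERT assets: {}".format(missing))

if __name__ == '__main__':
    check_bert_assets()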