Skip to content

Commit

Permalink
feat: add file lock for remote files download to local path when mult…
Browse files Browse the repository at this point in the history
…iple thread environment. (#1887)
  • Loading branch information
Tridu33 authored Dec 24, 2024
1 parent b4132be commit 2d01edf
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,14 @@ def tokenize(text):
model = AutoModelForSequenceClassification.from_pretrained('bert-base-cased', num_labels=2)

from mindnlp.engine import TrainingArguments

training_args = TrainingArguments(
output_dir="bert_imdb_finetune_cpu",
save_strategy="epoch",
logging_strategy="epoch",
num_train_epochs=2.0,
learning_rate=2e-5
)
training_args = training_args.set_optimizer(name="adamw", beta1=0.8) # 手动指定优化器,OptimizerNames.SGD
training_args = training_args.set_optimizer(name="adamw", beta1=0.8) # Manually specify the optimizer, OptimizerNames.SGD

trainer = Trainer(
model=model,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ echo "==========================================="
EXEC_PATH=$(pwd)
if [ ! -d "${EXEC_PATH}/data" ]; then
if [ ! -f "${EXEC_PATH}/emotion_detection.tar.gz" ]; then
wget wget https://baidu-nlp.bj.bcebos.com/emotion_detection-dataset-1.0.0.tar.gz -O emotion_detection.tar.gz
wget https://baidu-nlp.bj.bcebos.com/emotion_detection-dataset-1.0.0.tar.gz -O emotion_detection.tar.gz
fi
tar xvf emotion_detection.tar.gz
fi
Expand Down
30 changes: 25 additions & 5 deletions mindnlp/utils/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import json
import types
import functools
import sys
import tempfile
import time
from typing import Union, Optional, Dict, Any
Expand Down Expand Up @@ -91,7 +92,7 @@ def download_url(url, proxies=None):
Returns:
`str`: The location of the temporary file where the url was downloaded.
"""
return http_get(url, tempfile.gettempdir(), download_file_name='tmp_' + url.split('/')[-1], proxies=proxies)
return threads_exclusive_http_get(url, tempfile.gettempdir(), download_file_name='tmp_' + url.split('/')[-1], proxies=proxies)

def copy_func(f):
"""Returns a copy of a function f."""
Expand Down Expand Up @@ -142,6 +143,25 @@ def get_cache_path():
return cache_dir


def threads_exclusive_http_get(url, storage_folder=None, md5sum=None, download_file_name=None, proxies=None, headers=None):
pointer_path = os.path.join(storage_folder, download_file_name)
lock_file_path = pointer_path + ".lock"
if sys.platform != "win32":
import fcntl # pylint: disable=import-error
else:
import winfcntlock as fcntl # pylint: disable=import-error
with open(lock_file_path, 'w') as lock_file:
fd = lock_file.fileno()
try:
fcntl.flock(fd, fcntl.LOCK_EX)
file_path = http_get(url, path=storage_folder, download_file_name=download_file_name, proxies=proxies, headers=headers)
return file_path
except Exception as exp:
raise exp
finally:
fcntl.flock(fd, fcntl.LOCK_UN)


def http_get(url, path=None, md5sum=None, download_file_name=None, proxies=None, headers=None):
r"""
Download from given url, save to path.
Expand Down Expand Up @@ -628,11 +648,11 @@ def download(
else:
headers = {}
try:
pointer_path = http_get(url, storage_folder, download_file_name=relative_filename, proxies=proxies, headers=headers)
except Exception:
pointer_path = threads_exclusive_http_get(url, storage_folder, download_file_name=relative_filename, proxies=proxies, headers=headers)
except Exception as exp:
# Otherwise, our Internet connection is down.
# etag is None
raise
raise exp

return pointer_path

Expand Down Expand Up @@ -723,7 +743,7 @@ def get_from_cache(
if os.path.exists(file_path) and check_md5(file_path, md5sum):
return file_path
try:
path = http_get(url, cache_dir, md5sum, download_file_name=filename, proxies=proxies)
path = threads_exclusive_http_get(url, cache_dir, md5sum, download_file_name=filename, proxies=proxies)
return path
except (ProxyError, SSLError) as exc:
raise exc
Expand Down
32 changes: 32 additions & 0 deletions mindnlp/utils/winfcntlock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""fcntl replacement for Windows."""
import win32con # pylint: disable=import-error
import pywintypes # pylint: disable=import-error
import win32file # pylint: disable=import-error


LOCK_EX = win32con.LOCKFILE_EXCLUSIVE_LOCK
LOCK_SH = 0 # The default value
LOCK_NB = win32con.LOCKFILE_FAIL_IMMEDIATELY
__overlapped = pywintypes.OVERLAPPED()

def lock(file, flags):
hfile = win32file._get_osfhandle(file.fileno())
win32file.LockFileEx(hfile, flags, 0, 0xffff0000, __overlapped)

def unlock(file):
hfile = win32file._get_osfhandle(file.fileno())
win32file.UnlockFileEx(hfile, 0, 0xffff0000, __overlapped)

0 comments on commit 2d01edf

Please sign in to comment.