Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Download session headers #1429

Merged
merged 2 commits into from
Apr 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions deeppavlov/core/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
tqdm.monitor_interval = 0


def _get_download_token() -> str:
def get_download_token() -> str:
"""Return a download token from ~/.deeppavlov/token file.
If token file does not exists, creates the file and writes to it a random URL-safe text string
Expand Down Expand Up @@ -78,14 +78,15 @@ def s3_download(url: str, destination: str) -> None:
file_object.download_file(destination, Callback=pbar.update)


def simple_download(url: str, destination: Union[Path, str]) -> None:
def simple_download(url: str, destination: Union[Path, str], headers: Optional[dict] = None) -> None:
oserikov marked this conversation as resolved.
Show resolved Hide resolved
"""Download a file from URL to target location.
Displays a progress bar to the terminal during the download process.
Args:
url: The source URL.
destination: Path to the file destination (including file name).
headers: Headers for file server.
"""
destination = Path(destination)
Expand All @@ -99,7 +100,6 @@ def simple_download(url: str, destination: Union[Path, str]) -> None:
chunk_size = 32 * 1024
temporary = destination.with_suffix(destination.suffix + '.part')

headers = {'dp-token': _get_download_token()}
r = requests.get(url, stream=True, headers=headers)
if r.status_code != 200:
raise RuntimeError(f'Got status code {r.status_code} when trying to download {url}')
Expand Down Expand Up @@ -137,13 +137,15 @@ def simple_download(url: str, destination: Union[Path, str]) -> None:
temporary.rename(destination)


def download(dest_file_path: [List[Union[str, Path]]], source_url: str, force_download: bool = True) -> None:
def download(dest_file_path: [List[Union[str, Path]]], source_url: str, force_download: bool = True,
headers: Optional[dict] = None) -> None:
"""Download a file from URL to one or several target locations.
Args:
dest_file_path: Path or list of paths to the file destination (including file name).
source_url: The source URL.
force_download: Download file if it already exists, or not.
headers: Headers for file server.
"""

Expand Down Expand Up @@ -173,7 +175,7 @@ def download(dest_file_path: [List[Union[str, Path]]], source_url: str, force_do
if not cached_exists:
first_dest_path.parent.mkdir(parents=True, exist_ok=True)

simple_download(source_url, first_dest_path)
simple_download(source_url, first_dest_path, headers)
else:
log.info(f'Found cached {source_url} in {first_dest_path}')

Expand Down Expand Up @@ -223,7 +225,8 @@ def ungzip(file_path: Union[Path, str], extract_path: Optional[Union[Path, str]]

def download_decompress(url: str,
oserikov marked this conversation as resolved.
Show resolved Hide resolved
download_path: Union[Path, str],
extract_paths: Optional[Union[List[Union[Path, str]], Path, str]] = None) -> None:
extract_paths: Optional[Union[List[Union[Path, str]], Path, str]] = None,
headers: Optional[dict] = None) -> None:
"""Download and extract .tar.gz or .gz file to one or several target locations.
The archive is deleted if extraction was successful.
Expand All @@ -232,6 +235,7 @@ def download_decompress(url: str,
url: URL for file downloading.
download_path: Path to the directory where downloaded file will be stored until the end of extraction.
extract_paths: Path or list of paths where contents of archive will be extracted.
headers: Headers for file server.
"""
file_name = Path(urlparse(url).path).name
Expand All @@ -253,15 +257,15 @@ def download_decompress(url: str,
extracted_path = cache_dir / (url_hash + '_extracted')
extracted = extracted_path.exists()
if not extracted and not arch_file_path.exists():
simple_download(url, arch_file_path)
simple_download(url, arch_file_path, headers)
else:
if extracted:
log.info(f'Found cached and extracted {url} in {extracted_path}')
else:
log.info(f'Found cached {url} in {arch_file_path}')
else:
arch_file_path = download_path / file_name
simple_download(url, arch_file_path)
simple_download(url, arch_file_path, headers)
extracted_path = extract_paths.pop()

if not extracted:
Expand Down
30 changes: 19 additions & 11 deletions deeppavlov/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import secrets
import shutil
import sys
from argparse import ArgumentParser, Namespace
Expand All @@ -20,14 +21,13 @@
from pathlib import Path
from typing import Union, Optional, Dict, Iterable, Set, Tuple, List
from urllib.parse import urlparse

import requests
from filelock import FileLock

import deeppavlov
from deeppavlov.core.commands.utils import expand_path, parse_config
from deeppavlov.core.data.utils import download, download_decompress, get_all_elems_from_json, file_md5, \
set_query_parameter, path_set_md5
set_query_parameter, path_set_md5, get_download_token

log = getLogger(__name__)

Expand Down Expand Up @@ -77,7 +77,7 @@ def get_configs_downloads(config: Optional[Union[str, Path, dict]] = None) -> Di
return all_downloads


def check_md5(url: str, dest_paths: List[Path]) -> bool:
def check_md5(url: str, dest_paths: List[Path], headers: Optional[dict] = None) -> bool:
url_md5 = path_set_md5(url)

try:
Expand All @@ -89,7 +89,7 @@ def check_md5(url: str, dest_paths: List[Path]) -> bool:
obj = s3.Object(bucket, key)
data = obj.get()['Body'].read().decode('utf8')
else:
r = requests.get(url_md5)
r = requests.get(url_md5, headers=headers)
if r.status_code != 200:
return False
data = r.text
Expand Down Expand Up @@ -126,21 +126,21 @@ def check_md5(url: str, dest_paths: List[Path]) -> bool:
return True


def download_resource(url: str, dest_paths: Iterable[Union[Path, str]], headers: Optional[dict] = None) -> None:
    """Fetch a single resource into one or more destination directories.

    Skips the download entirely when the remote md5 matches what is already
    on disk; otherwise downloads (and decompresses, for archive URLs) into
    every path in ``dest_paths``.

    Args:
        url: Source URL of the file or archive.
        dest_paths: Directories that should receive the downloaded content.
        headers: Optional HTTP headers forwarded to the file server.
    """
    targets = [Path(dest) for dest in dest_paths]
    base_dir = targets[0].parent
    base_dir.mkdir(parents=True, exist_ok=True)
    file_name = urlparse(url).path.split('/')[-1]
    lockfile = base_dir / f'.{file_name}.lock'

    # Serialize concurrent downloads of the same resource across processes.
    with FileLock(lockfile).acquire(poll_intervall=10):
        if check_md5(url, targets, headers):
            log.info(f'Skipped {url} download because of matching hashes')
        else:
            is_archive = any(ext in url for ext in ('.tar.gz', '.gz', '.zip'))
            if is_archive:
                download_decompress(url, base_dir, targets, headers=headers)
            else:
                download([target / file_name for target in targets], url, headers=headers)


def download_resources(args: Namespace) -> None:
Expand All @@ -159,11 +159,19 @@ def download_resources(args: Namespace) -> None:

def deep_download(config: Union[str, Path, dict]) -> None:
    """Download every resource listed in a config's download metadata.

    All files fetched by one call share a random ``dp-session`` id, and
    ``dp-file-id`` counts down to 0 so the server can recognize the final
    file of the session when it sees id 0.

    Args:
        config: Path to a config file, or an already-parsed config dict.
    """
    downloads = get_configs_downloads(config)
    last_id = len(downloads) - 1
    session_id = secrets.token_urlsafe(32)

    for file_id, (url, dest_paths) in enumerate(downloads.items()):
        # Countdown index: N-1, N-2, ..., 0 — zero marks the session's last file.
        files_left = last_id - file_id
        headers = {
            'dp-token': get_download_token(),
            'dp-session': session_id,
            'dp-file-id': str(files_left),
            'dp-version': deeppavlov.__version__
        }
        if not url.startswith('s3://') and not isinstance(config, dict):
            url = set_query_parameter(url, 'config', Path(config).stem)
        download_resource(url, dest_paths, headers)


def main(args: Optional[List[str]] = None) -> None:
Expand Down