feat: Download session headers (#1429)
* feat: some changes to provide more precise statistics

* feat: update headers for proper stats
IgnatovFedor authored Apr 20, 2021
1 parent 9dc3587 commit e3283fd
Showing 2 changed files with 31 additions and 19 deletions.
deeppavlov/core/data/utils.py (12 additions, 8 deletions):

```diff
@@ -37,7 +37,7 @@
 tqdm.monitor_interval = 0


-def _get_download_token() -> str:
+def get_download_token() -> str:
     """Return a download token from ~/.deeppavlov/token file.

     If token file does not exists, creates the file and writes to it a random URL-safe text string
```
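The renamed function's body lies outside the visible hunk. A minimal sketch of what its docstring describes (the real implementation may differ in details):

```python
import secrets
from pathlib import Path


def get_download_token_sketch() -> str:
    """Read ~/.deeppavlov/token, creating it with a random URL-safe
    string of 32 bytes on first use, as the docstring describes."""
    token_file = Path.home() / '.deeppavlov' / 'token'
    if not token_file.exists():
        token_file.parent.mkdir(parents=True, exist_ok=True)
        token_file.write_text(secrets.token_urlsafe(32), encoding='utf8')
    return token_file.read_text(encoding='utf8').strip()
```

Dropping the leading underscore makes the token readable from `deeppavlov/download.py`, which now assembles the headers itself (see the second file below).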
```diff
@@ -78,14 +78,15 @@ def s3_download(url: str, destination: str) -> None:
         file_object.download_file(destination, Callback=pbar.update)


-def simple_download(url: str, destination: Union[Path, str]) -> None:
+def simple_download(url: str, destination: Union[Path, str], headers: Optional[dict] = None) -> None:
     """Download a file from URL to target location.

     Displays a progress bar to the terminal during the download process.

     Args:
         url: The source URL.
         destination: Path to the file destination (including file name).
+        headers: Headers for file server.

     """
     destination = Path(destination)
@@ -99,7 +100,6 @@ def simple_download(url: str, destination: Union[Path, str]) -> None:
     chunk_size = 32 * 1024
     temporary = destination.with_suffix(destination.suffix + '.part')

-    headers = {'dp-token': _get_download_token()}
     r = requests.get(url, stream=True, headers=headers)
     if r.status_code != 200:
         raise RuntimeError(f'Got status code {r.status_code} when trying to download {url}')
```
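`simple_download` no longer injects the `dp-token` header itself; the caller supplies whatever headers the file server should see. A hypothetical call (URL and destination are placeholders):

```python
from deeppavlov.core.data.utils import get_download_token, simple_download

# Placeholder URL and path, for illustration only.
simple_download(
    'http://files.deeppavlov.ai/datasets/example.txt',
    '/tmp/example.txt',
    headers={'dp-token': get_download_token()},
)
```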
```diff
@@ -137,13 +137,15 @@ def simple_download(url: str, destination: Union[Path, str]) -> None:
     temporary.rename(destination)


-def download(dest_file_path: [List[Union[str, Path]]], source_url: str, force_download: bool = True) -> None:
+def download(dest_file_path: [List[Union[str, Path]]], source_url: str, force_download: bool = True,
+             headers: Optional[dict] = None) -> None:
     """Download a file from URL to one or several target locations.

     Args:
         dest_file_path: Path or list of paths to the file destination (including file name).
         source_url: The source URL.
         force_download: Download file if it already exists, or not.
+        headers: Headers for file server.

     """

@@ -173,7 +175,7 @@ def download(dest_file_path: [List[Union[str, Path]]], source_url: str, force_do
         if not cached_exists:
             first_dest_path.parent.mkdir(parents=True, exist_ok=True)

-            simple_download(source_url, first_dest_path)
+            simple_download(source_url, first_dest_path, headers)
         else:
             log.info(f'Found cached {source_url} in {first_dest_path}')
```
```diff
@@ -223,7 +225,8 @@ def ungzip(file_path: Union[Path, str], extract_path: Optional[Union[Path, str]]

 def download_decompress(url: str,
                         download_path: Union[Path, str],
-                        extract_paths: Optional[Union[List[Union[Path, str]], Path, str]] = None) -> None:
+                        extract_paths: Optional[Union[List[Union[Path, str]], Path, str]] = None,
+                        headers: Optional[dict] = None) -> None:
     """Download and extract .tar.gz or .gz file to one or several target locations.

     The archive is deleted if extraction was successful.
@@ -232,6 +235,7 @@ def download_decompress(url: str,
         url: URL for file downloading.
         download_path: Path to the directory where downloaded file will be stored until the end of extraction.
         extract_paths: Path or list of paths where contents of archive will be extracted.
+        headers: Headers for file server.

     """
     file_name = Path(urlparse(url).path).name
@@ -253,15 +257,15 @@ def download_decompress(url: str,
         extracted_path = cache_dir / (url_hash + '_extracted')
         extracted = extracted_path.exists()
         if not extracted and not arch_file_path.exists():
-            simple_download(url, arch_file_path)
+            simple_download(url, arch_file_path, headers)
         else:
             if extracted:
                 log.info(f'Found cached and extracted {url} in {extracted_path}')
             else:
                 log.info(f'Found cached {url} in {arch_file_path}')
     else:
         arch_file_path = download_path / file_name
-        simple_download(url, arch_file_path)
+        simple_download(url, arch_file_path, headers)
         extracted_path = extract_paths.pop()

     if not extracted:
```
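`download_decompress` passes the headers through on both of its download paths, with and without the cache directory. A hypothetical call:

```python
from deeppavlov.core.data.utils import download_decompress, get_download_token

# Placeholder URL and paths: fetch an archive and extract it to two locations.
download_decompress(
    'http://files.deeppavlov.ai/models/example.tar.gz',
    '/tmp/downloads',
    extract_paths=['/tmp/models_a', '/tmp/models_b'],
    headers={'dp-token': get_download_token()},
)
```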
deeppavlov/download.py (19 additions, 11 deletions):

```diff
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import secrets
 import shutil
 import sys
 from argparse import ArgumentParser, Namespace
@@ -20,14 +21,13 @@
 from pathlib import Path
 from typing import Union, Optional, Dict, Iterable, Set, Tuple, List
 from urllib.parse import urlparse
-
 import requests
 from filelock import FileLock

 import deeppavlov
 from deeppavlov.core.commands.utils import expand_path, parse_config
 from deeppavlov.core.data.utils import download, download_decompress, get_all_elems_from_json, file_md5, \
-    set_query_parameter, path_set_md5
+    set_query_parameter, path_set_md5, get_download_token

 log = getLogger(__name__)
```
```diff
@@ -77,7 +77,7 @@ def get_configs_downloads(config: Optional[Union[str, Path, dict]] = None) -> Di
     return all_downloads


-def check_md5(url: str, dest_paths: List[Path]) -> bool:
+def check_md5(url: str, dest_paths: List[Path], headers: Optional[dict] = None) -> bool:
     url_md5 = path_set_md5(url)

     try:
@@ -89,7 +89,7 @@ def check_md5(url: str, dest_paths: List[Path]) -> bool:
             obj = s3.Object(bucket, key)
             data = obj.get()['Body'].read().decode('utf8')
         else:
-            r = requests.get(url_md5)
+            r = requests.get(url_md5, headers=headers)
             if r.status_code != 200:
                 return False
             data = r.text
```
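`check_md5` compares local files against a server-side checksum file and now sends the session headers on that request too, presumably so that skipped downloads also show up in the statistics. A hedged sketch of the HTTP branch, assuming `path_set_md5` maps a file URL to a companion `*.md5` URL with md5sum-style contents:

```python
import requests

# Placeholder URL and headers, for illustration only.
url_md5 = 'http://files.deeppavlov.ai/models/example.tar.gz.md5'
headers = {'dp-token': 'example-token'}

r = requests.get(url_md5, headers=headers)
if r.status_code == 200:
    for line in r.text.splitlines():
        # Assumed md5sum-style line format: "<hash> <filename>".
        md5_hash, file_name = line.split(maxsplit=1)
        print(file_name, md5_hash)
```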
```diff
@@ -126,21 +126,21 @@ def check_md5(url: str, dest_paths: List[Path]) -> bool:
     return True


-def download_resource(url: str, dest_paths: Iterable[Union[Path, str]]) -> None:
+def download_resource(url: str, dest_paths: Iterable[Union[Path, str]], headers: Optional[dict] = None) -> None:
     dest_paths = [Path(dest) for dest in dest_paths]
     download_path = dest_paths[0].parent
     download_path.mkdir(parents=True, exist_ok=True)
     file_name = urlparse(url).path.split('/')[-1]
     lockfile = download_path / f'.{file_name}.lock'

     with FileLock(lockfile).acquire(poll_intervall=10):
-        if check_md5(url, dest_paths):
+        if check_md5(url, dest_paths, headers):
             log.info(f'Skipped {url} download because of matching hashes')
         elif any(ext in url for ext in ('.tar.gz', '.gz', '.zip')):
-            download_decompress(url, download_path, dest_paths)
+            download_decompress(url, download_path, dest_paths, headers=headers)
         else:
             dest_files = [dest_path / file_name for dest_path in dest_paths]
-            download(dest_files, url)
+            download(dest_files, url, headers=headers)


 def download_resources(args: Namespace) -> None:
```
```diff
@@ -159,11 +159,19 @@ def download_resources(args: Namespace) -> None:

 def deep_download(config: Union[str, Path, dict]) -> None:
     downloads = get_configs_downloads(config)
-
-    for url, dest_paths in downloads.items():
+    last_id = len(downloads) - 1
+    session_id = secrets.token_urlsafe(32)
+
+    for file_id, (url, dest_paths) in enumerate(downloads.items()):
+        headers = {
+            'dp-token': get_download_token(),
+            'dp-session': session_id,
+            'dp-file-id': str(last_id - file_id),
+            'dp-version': deeppavlov.__version__
+        }
         if not url.startswith('s3://') and not isinstance(config, dict):
             url = set_query_parameter(url, 'config', Path(config).stem)
-        download_resource(url, dest_paths)
+        download_resource(url, dest_paths, headers)


 def main(args: Optional[List[str]] = None) -> None:
```
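`deep_download` is where the session headers are assembled: every file of one run shares a `dp-session` id, and `dp-file-id` counts down to 0 on the last file, so the server can tell whether a session ran to completion; this is presumably the "more precise statistics" from the commit message. A standalone sketch of the scheme (token, version string, and URLs are placeholders):

```python
import secrets

DOWNLOAD_TOKEN = 'example-token'  # the real code calls get_download_token()
DP_VERSION = '0.15.0'             # the real code uses deeppavlov.__version__

downloads = {                     # hypothetical result of get_configs_downloads(config)
    'http://files.deeppavlov.ai/a.tar.gz': ['/tmp/a'],
    'http://files.deeppavlov.ai/b.tar.gz': ['/tmp/b'],
}

last_id = len(downloads) - 1
session_id = secrets.token_urlsafe(32)  # one id shared by the whole run

for file_id, url in enumerate(downloads):
    headers = {
        'dp-token': DOWNLOAD_TOKEN,            # stable per-user token
        'dp-session': session_id,              # groups the files of one run
        'dp-file-id': str(last_id - file_id),  # files left after this one; '0' on the last
        'dp-version': DP_VERSION,              # client version for the stats backend
    }
    print(url, headers)
```

A server that never receives `dp-file-id: 0` for a given `dp-session` can infer the run was interrupted, which a single per-file token could not reveal.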
