Skip to content

Commit

Permalink
Reuse ExternalToolVersion
Browse files Browse the repository at this point in the history
  • Loading branch information
grihabor committed Jan 11, 2025
1 parent 4b8a446 commit 6efda67
Showing 1 changed file with 25 additions and 19 deletions.
44 changes: 25 additions & 19 deletions build-support/bin/external_tool_versions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,23 @@
pants run build-support/bin:external-tool-versions -- --tool pants.backend.k8s.kubectl_subsystem:Kubectl > list.txt
"""

import argparse
import hashlib
import importlib
import logging
import re
import xml.etree.ElementTree as ET
from collections.abc import Callable, Iterator
from dataclasses import dataclass
from multiprocessing.pool import ThreadPool
from string import Formatter
from urllib.parse import urlparse

import requests

logger = logging.getLogger(__name__)

from pants.core.util_rules.external_tool import ExternalToolVersion

@dataclass(frozen=True)
class VersionHash:
version: str
platform: str
size: int
sha256: str
logger = logging.getLogger(__name__)


def format_string_to_regex(format_string: str) -> re.Pattern:
Expand Down Expand Up @@ -80,19 +74,25 @@ def get_k8s_versions(url_template: str, pool: ThreadPool) -> Iterator[str]:
}


def fetch_version(url_template: str, version: str, platform: str) -> VersionHash | None:
url = url_template.format(version=version, platform=platform)
def fetch_version(
*,
url_template: str,
version: str,
platform: str,
platform_mapping: dict[str, str],
) -> ExternalToolVersion | None:
url = url_template.format(version=version, platform=platform_mapping[platform])
response = requests.get(url, allow_redirects=True)
if response.status_code != 200:
logger.error("failed to fetch version: %s\n%s", version, response.text)
return None

size = len(response.content)
sha256 = hashlib.sha256(response.content)
return VersionHash(
return ExternalToolVersion(
version=version,
platform=platform,
size=size,
filesize=size,
sha256=sha256.hexdigest(),
)

Expand Down Expand Up @@ -135,20 +135,26 @@ def main():

platforms = args.platforms.split(",")
platform_mapping = cls.default_url_platform_mapping
mapped_platforms = {platform_mapping.get(p) for p in platforms}

domain = urlparse(cls.default_url_template).netloc
get_versions = DOMAIN_TO_VERSIONS_MAPPING[domain]
pool = ThreadPool(processes=args.workers)
results = []
for version in get_versions(cls.default_url_template, pool):
for platform in mapped_platforms:
for platform in platforms:
logger.debug("fetching version: %s %s", version, platform)
results.append(
pool.apply_async(fetch_version, args=(cls.default_url_template, version, platform))
pool.apply_async(
fetch_version,
kwds=dict(
version=version,
platform=platform,
url_template=cls.default_url_template,
platform_mapping=platform_mapping,
),
)
)

backward_platform_mapping = {v: k for k, v in platform_mapping.items()}
for result in results:
v = result.get(60)
if v is None:
Expand All @@ -157,9 +163,9 @@ def main():
"|".join(
[
v.version,
backward_platform_mapping[v.platform],
v.platform,
v.sha256,
str(v.size),
str(v.filesize),
]
)
)
Expand Down

0 comments on commit 6efda67

Please sign in to comment.