Skip to content

Commit

Permalink
Make asset download multi-threaded
Browse files Browse the repository at this point in the history
  • Loading branch information
thesadru committed Dec 30, 2023
1 parent 1335940 commit 9fe7dec
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 32 deletions.
2 changes: 1 addition & 1 deletion arkprts/assets/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
parser.add_argument("--allow", type=str, default="gamedata/excel/*", help="Files allowed to be downloaded.")
parser.add_argument("--force", action="store_true", default=False, help="Force new files to be downloaded")
parser.add_argument("--log-level", type=str, default="INFO", help="Logging level")
parser.add_argument("--server", type=str, default="en", help="Server to use, global only")
parser.add_argument("--server", type=str, default="en", help="Server to use, can be 'all'")
parser.add_argument("--normalize", action="store_true", help="Reformat files into a normalized expanded format")


Expand Down
87 changes: 58 additions & 29 deletions arkprts/assets/bundle.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
"""
from __future__ import annotations

import asyncio
import concurrent.futures
import fnmatch
import functools
import io
import json
import logging
Expand All @@ -31,6 +34,8 @@
UnityPyAsset = typing.Any
UnityPyObject = typing.Any

UPDATED_FBS = {"cn": False, "yostar": False}


def asset_path_to_server_filename(path: str) -> str:
"""Take a path to a zipped unity asset and return its filename on the server."""
Expand All @@ -47,12 +52,18 @@ def unzip_only_file(stream: io.BytesIO | bytes) -> bytes:
return archive.read(archive.namelist()[0])


def resolve_unity_asset_cache(filename: str, server: netn.ArknightsServer) -> pathlib.Path:
    """Resolve a path to a cached arknights ab file.

    The cache lives in ``<tempdir>/ArknightsUnity/<server>/`` so that assets
    sharing a filename on different servers do not overwrite each other.
    The parent directory of the returned path is created if it is missing;
    the cached file itself is not created here.

    ``filename`` is the asset path relative to the cache root; its suffix is
    replaced with ``.ab``. If ``server`` is falsy the per-server directory is
    omitted and the path falls back to the shared cache root.
    """
    cache_root = pathlib.Path(tempfile.gettempdir()) / "ArknightsUnity"
    if server:
        # Keep each server's bundles apart: the same asset path can exist on
        # several servers with different contents.
        cache_root = cache_root / server

    path = cache_root / filename
    path.parent.mkdir(parents=True, exist_ok=True)
    return path.with_suffix(".ab")


def load_unity_file(stream: io.BytesIO | bytes) -> bytes:
"""Load a zipped arknights unity .ab file."""
"""Load an unzipped arknights unity .ab file."""
import UnityPy

ab_data = unzip_only_file(stream)
env: typing.Any = UnityPy.load(io.BytesIO(ab_data)) # pyright: ignore
env: typing.Any = UnityPy.load(io.BytesIO(stream)) # pyright: ignore

bundle_file, *_ = env.files.values()
assert not _
Expand Down Expand Up @@ -129,6 +140,10 @@ def resolve_fbs_schema_directory(server: typing.Literal["cn", "yostar"]) -> path
async def update_fbs_schema(*, force: bool = False) -> None:
    """Download or otherwise update FBS files.

    Each server's schema repository is updated at most once per process,
    tracked in the module-level ``UPDATED_FBS`` mapping, unless ``force`` is
    set. If an update fails, the server is marked as not updated again so the
    next call retries; the original exception is re-raised.
    """
    for server, branch in [("cn", "main"), ("yostar", "YoStar")]:
        if UPDATED_FBS[server] and not force:
            continue

        # Mark before awaiting so a call that overlaps this one does not start
        # a second update of the same checkout.
        UPDATED_FBS[server] = True
        try:
            directory = resolve_fbs_schema_directory(server).parent
            await git.update_repository("MooncellWiki/OpenArknightsFBS", directory, branch=branch, force=force)
        except BaseException:
            # The update did not finish (error or cancellation): clear the
            # flag so later calls retry rather than silently skipping.
            UPDATED_FBS[server] = False
            raise

Expand Down Expand Up @@ -354,43 +369,49 @@ def _get_current_hot_update_list(self, server: netn.ArknightsServer) -> typing.A
with path.open("r") as file:
return json.load(file)

async def _download_unity_asset(
async def _download_unity_file(
self,
path: str,
*,
save: bool = True,
server: netn.ArknightsServer | None = None,
) -> UnityPyAsset:
"""Download an asset as a UnityPy asset."""
LOGGER.debug("Downloading and extracting asset %s for server %s", path, server)
data = await self._download_asset(path, server=server)
return load_unity_file(data)
) -> bytes:
"""Download an asset and return it unzipped."""
LOGGER.debug("Downloading and unzipping asset %s for server %s", path, server)
zipped_data = await self._download_asset(path, server=server)
data = unzip_only_file(zipped_data)
if save:
p = resolve_unity_asset_cache(path, server=server or self.default_server)
p.write_bytes(data)

async def _download_and_save(
return data

def _parse_and_save(
self,
path: str,
data: bytes,
*,
target_container: str | None = None,
server: netn.ArknightsServer | None = None,
normalize: bool = False,
) -> typing.AsyncIterable[tuple[str, bytes]]:
) -> typing.Iterable[tuple[str, bytes]]:
"""Download and extract an asset."""
server = server or self.default_server

asset = await self._download_unity_asset(path, server=server)
asset = load_unity_file(data)

fetched_any = False
for fetched_any, (path, data) in enumerate(
for fetched_any, (unpacked_rel_path, unpacked_data) in enumerate(
unpack_assets(asset, target_container, server=server, normalize=normalize),
1,
):
savepath = self.directory / server / path
savepath = self.directory / server / unpacked_rel_path
savepath.parent.mkdir(exist_ok=True, parents=True)
savepath.write_bytes(data)
savepath.write_bytes(unpacked_data)

yield (path, data)
yield (unpacked_rel_path, unpacked_data)

if not fetched_any:
warnings.warn(f"Unpacking {path} (container: {target_container}) yielded no results")
warnings.warn(f"Unpacking yielded no results (container: {target_container}) ")

async def update_assets(
self,
Expand All @@ -407,7 +428,7 @@ async def update_assets(
server = server or self.default_server or "all"
if server == "all":
for server in netn.NETWORK_ROUTES:
await self.update_assets(allow, server=server)
await self.update_assets(allow, server=server, force=force, normalize=normalize)

return

Expand All @@ -422,15 +443,23 @@ async def update_assets(
if any("gamedata" in name for name in requested_names):
await update_fbs_schema()

# sequential doesn't matter since most of the time is spent unpacking
# Fix this once images come into play (threadpoolexecutor and such)
# first download all .ab files in a temporary directory then start extracting them.
for name in requested_names:
try:
async for path, _ in self._download_and_save(name, server=server, normalize=normalize):
LOGGER.debug("Downloaded asset %s from %s for server %s", path, name, server)
except Exception as e:
LOGGER.exception("Failed to download asset %s for server %s", name, server, exc_info=e)
datas = await asyncio.gather(*(self._download_unity_file(name, server=server) for name in requested_names))
loop = asyncio.get_event_loop()
# this should be a ProcessPoolExecutor but pickling is a problem in classes
with concurrent.futures.ThreadPoolExecutor() as executor:
futures = [
loop.run_in_executor(
executor,
functools.partial(self._parse_and_save, d, server=server, normalize=normalize),
)
for d in datas
]
for name, f in zip(requested_names, asyncio.as_completed(futures)):
try:
for path, _ in await f:
LOGGER.debug("Extracted asset %s from %s for server %s", path, name, server)
except Exception as e:
LOGGER.exception("Failed to extract asset %s for server %s", name, server, exc_info=e)

hot_update_list_path = self.directory / server / "hot_update_list.json"
hot_update_list_path.parent.mkdir(parents=True, exist_ok=True)
Expand All @@ -452,7 +481,7 @@ async def aget_file(self, path: str, *, server: netn.ArknightsServer | None = No
raise ValueError("No viable asset path found, please load all assets and use get_file.")

for potential_asset_path in asset_paths:
asset = await self._download_unity_asset(potential_asset_path, server=server)
asset = load_unity_file(await self._download_unity_file(potential_asset_path, server=server))
for output_path, data in unpack_assets(asset, path, server=server):
if save:
savepath = self.directory / server / output_path
Expand Down
2 changes: 1 addition & 1 deletion arkprts/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def __init__(self, client: CoreClient | None = None, **kwargs: typing.Any) -> No
if client:
_set_recursively(self, "client", client)

@pydantic.model_validator(mode="before") # pyright: ignore[reportUnknownMemberType]
@pydantic.model_validator(mode="before") # pyright: ignore
def _fix_amiya(cls, value: typing.Any, info: pydantic.ValidationInfo) -> typing.Any:
"""Flatten Amiya to only keep her selected form if applicable."""
if value and value.get("tmpl"):
Expand Down
2 changes: 1 addition & 1 deletion dev-requirements/pytest.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
pytest
pytest-asyncio
pytest-asyncio==0.21.1
pytest-dotenv
pytest-cov
coverage[toml]

0 comments on commit 9fe7dec

Please sign in to comment.