Skip to content

Commit

Permalink
import: assert importer returns valid types
Browse files Browse the repository at this point in the history
  • Loading branch information
yagebu committed Dec 31, 2024
1 parent f262d5b commit 2261c27
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 59 deletions.
138 changes: 91 additions & 47 deletions src/fava/core/ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import datetime
import os
import sys
import traceback
Expand Down Expand Up @@ -33,11 +34,13 @@
from fava.util.date import local_today

if TYPE_CHECKING: # pragma: no cover
import datetime
from collections.abc import Iterable
from collections.abc import Mapping
from collections.abc import Sequence
from typing import Any
from typing import Callable
from typing import ParamSpec
from typing import TypeVar

from fava.beans.abc import Directive
from fava.beans.ingest import FileMemo
Expand All @@ -46,6 +49,9 @@
HookOutput = list[tuple[str, list[Directive]]]
Hooks = Sequence[Callable[[HookOutput, Sequence[Directive]], HookOutput]]

P = ParamSpec("P")
T = TypeVar("T")


class IngestError(BeancountError):
"""An error with one of the importers."""
Expand All @@ -60,6 +66,16 @@ def __init__(self) -> None:
)


class ImporterInvalidTypeError(FavaAPIError):
"""One of the importer methods returned an unexpected type."""

def __init__(self, attr: str, expected: type[Any], actual: Any) -> None:
super().__init__(
f"Got unexpected type from importer as {attr}:"
f" expected {expected!s}, got {type(actual)!s}:"
)


class ImporterExtractError(ImporterMethodCallError):
"""Error calling extract for importer."""

Expand Down Expand Up @@ -155,61 +171,87 @@ class FileImporters:
importers: list[FileImportInfo]


def get_name(importer: BeanImporterProtocol | Importer) -> str:
"""Get the name of an importer."""
try:
if isinstance(importer, Importer):
return importer.name
return importer.name()
except Exception as err:
raise ImporterMethodCallError from err
def _catch_any(func: Callable[P, T]) -> Callable[P, T]:
"""Helper to catch any exception that might be raised by the importer."""

def wrapper(*args: P.args, **kwds: P.kwargs) -> T:
try:
return func(*args, **kwds)
except Exception as err:
if isinstance(err, ImporterInvalidTypeError):
raise
raise ImporterMethodCallError from err

def importer_identify(
importer: BeanImporterProtocol | Importer, path: Path
) -> bool:
"""Get the name of an importer."""
try:
if isinstance(importer, Importer):
return importer.identify(str(path))
return importer.identify(get_cached_file(path))
except Exception as err:
raise ImporterMethodCallError from err
return wrapper


def file_import_info(
path: Path,
importer: BeanImporterProtocol | Importer,
) -> FileImportInfo:
"""Generate info about a file with an importer."""
filename = str(path)
try:
def _assert_type(attr: str, value: T, type_: type[T]) -> T:
"""Helper to validate types return by importer methods."""
if not isinstance(value, type_):
raise ImporterInvalidTypeError(attr, type_, value)
return value


class WrappedImporter:
"""A wrapper to safely call importer methods."""

importer: BeanImporterProtocol | Importer

def __init__(self, importer: BeanImporterProtocol | Importer) -> None:
self.importer = importer

@property
@_catch_any
def name(self) -> str:
"""Get the name of the importer."""
importer = self.importer
name = (
importer.name
if isinstance(importer, Importer)
else importer.name()
)
return _assert_type("name", name, str)

@_catch_any
def identify(self: WrappedImporter, path: Path) -> bool:
"""Whether the importer is matching the file."""
importer = self.importer
matches = (
importer.identify(str(path))
if isinstance(importer, Importer)
else importer.identify(get_cached_file(path))
)
return _assert_type("identify", matches, bool)

@_catch_any
def file_import_info(self, path: Path) -> FileImportInfo:
"""Generate info about a file with an importer."""
importer = self.importer
if isinstance(importer, Importer):
account = importer.account(filename)
date = importer.date(filename)
name = importer.filename(filename)
str_path = str(path)
account = importer.account(str_path)
date = importer.date(str_path)
filename = importer.filename(str_path)
else:
file = get_cached_file(path)
account = importer.file_account(file)
date = importer.file_date(file)
name = importer.file_name(file)
except Exception as err:
raise ImporterMethodCallError from err
filename = importer.file_name(file)

return FileImportInfo(
get_name(importer),
account or "",
date or local_today(),
name or Path(filename).name,
)
return FileImportInfo(
self.name,
_assert_type("account", account or "", str),
_assert_type("date", date or local_today(), datetime.date),
_assert_type("filename", filename or path.name, str),
)


# Copied here from beangulp to minimise the imports.
_FILE_TOO_LARGE_THRESHOLD = 8 * 1024 * 1024


def find_imports(
config: Sequence[BeanImporterProtocol | Importer], directory: Path
config: Sequence[WrappedImporter], directory: Path
) -> Iterable[FileImporters]:
"""Pair files and matching importers.
Expand All @@ -223,31 +265,32 @@ def find_imports(
continue

importers = [
file_import_info(path, importer)
importer.file_import_info(path)
for importer in config
if importer_identify(importer, path)
if importer.identify(path)
]
yield FileImporters(
name=str(path), basename=path.name, importers=importers
)


def extract_from_file(
importer: BeanImporterProtocol | Importer,
wrapped_importer: WrappedImporter,
path: Path,
existing_entries: Sequence[Directive],
) -> list[Directive]:
"""Import entries from a document.
Args:
importer: The importer instance to handle the document.
wrapped_importer: The importer instance to handle the document.
path: Filesystem path to the document.
existing_entries: Existing entries.
Returns:
The list of imported entries.
"""
filename = str(path)
importer = wrapped_importer.importer
if isinstance(importer, Importer):
entries = importer.extract(filename, existing=existing_entries)
else:
Expand All @@ -269,7 +312,7 @@ def extract_from_file(

def load_import_config(
module_path: Path,
) -> tuple[Mapping[str, BeanImporterProtocol | Importer], Hooks]:
) -> tuple[Mapping[str, WrappedImporter], Hooks]:
"""Load the given import config and extract importers and hooks.
Args:
Expand Down Expand Up @@ -311,7 +354,8 @@ def load_import_config(
"not satisfy importer protocol"
)
raise ImportConfigLoadError(msg)
importers[get_name(importer)] = importer
wrapped_importer = WrappedImporter(importer)
importers[wrapped_importer.name] = wrapped_importer
return importers, hooks


Expand All @@ -320,7 +364,7 @@ class IngestModule(FavaModule):

def __init__(self, ledger: FavaLedger) -> None:
super().__init__(ledger)
self.importers: Mapping[str, BeanImporterProtocol | Importer] = {}
self.importers: Mapping[str, WrappedImporter] = {}
self.hooks: Hooks = []
self.mtime: int | None = None
self.errors: list[IngestError] = []
Expand Down Expand Up @@ -359,7 +403,7 @@ def load_file(self) -> None: # noqa: D102
try:
self.importers, self.hooks = load_import_config(module_path)
self.mtime = new_mtime
except ImportConfigLoadError as error:
except FavaAPIError as error:
msg = f"Error in import config '{module_path}': {error!s}"
self._error(msg)

Expand Down
37 changes: 25 additions & 12 deletions tests/test_core_ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,13 @@
from fava.beans.abc import Note
from fava.beans.abc import Transaction
from fava.beans.ingest import BeanImporterProtocol
from fava.core.ingest import file_import_info
from fava.core.ingest import FileImportInfo
from fava.core.ingest import filepath_in_primary_imports_folder
from fava.core.ingest import get_name
from fava.core.ingest import ImportConfigLoadError
from fava.core.ingest import importer_identify
from fava.core.ingest import ImporterExtractError
from fava.core.ingest import ImporterInvalidTypeError
from fava.core.ingest import load_import_config
from fava.core.ingest import WrappedImporter
from fava.helpers import FavaAPIError
from fava.serialisation import serialise
from fava.util.date import local_today
Expand All @@ -40,7 +39,7 @@ def test_ingest_file_import_info(
assert importer

csv_path = test_data_dir / "import.csv"
info = file_import_info(csv_path, importer)
info = importer.file_import_info(csv_path)
assert info.account == "Assets:Checking"


Expand All @@ -49,7 +48,7 @@ def __init__(self, acc: str = "Assets:Checking") -> None:
self.acc = acc

def name(self) -> str:
return self.acc
return f"MinimalImporter({self.acc})"

def identify(self, file: FileMemo) -> bool:
return self.acc in file.name
Expand All @@ -61,11 +60,11 @@ def file_account(self, _file: FileMemo) -> str:
def test_ingest_file_import_info_minimal_importer(test_data_dir: Path) -> None:
csv_path = test_data_dir / "import.csv"

info = file_import_info(csv_path, MinimalImporter("rawfile"))
assert isinstance(info.account, str)
importer = WrappedImporter(MinimalImporter())
info = importer.file_import_info(csv_path)
assert info == FileImportInfo(
"rawfile",
"rawfile",
"MinimalImporter(Assets:Checking)",
"Assets:Checking",
local_today(),
"import.csv",
)
Expand All @@ -82,8 +81,9 @@ def test_ingest_file_import_info_account_method_errors(
) -> None:
csv_path = test_data_dir / "import.csv"

importer = WrappedImporter(AccountNameErrors())
with pytest.raises(FavaAPIError) as err:
file_import_info(csv_path, AccountNameErrors())
importer.file_import_info(csv_path)
assert "Some error reason..." in err.value.message


Expand All @@ -96,8 +96,9 @@ def identify(self, _file: FileMemo) -> bool:
def test_ingest_identify_errors(test_data_dir: Path) -> None:
csv_path = test_data_dir / "import.csv"

importer = WrappedImporter(IdentifyErrors())
with pytest.raises(FavaAPIError) as err:
importer_identify(IdentifyErrors(), csv_path)
importer.identify(csv_path)
assert "IDENTIFY_ERRORS" in err.value.message


Expand All @@ -108,11 +109,23 @@ def name(self) -> str:


def test_ingest_get_name_errors() -> None:
importer = WrappedImporter(ImporterNameErrors())
with pytest.raises(FavaAPIError) as err:
get_name(ImporterNameErrors())
assert importer.name
assert "GET_NAME_WILL_ERROR" in err.value.message


class ImporterNameInvalidType(MinimalImporter):
def name(self) -> str:
return False # type: ignore[return-value]


def test_ingest_get_name_invalid_type() -> None:
importer = WrappedImporter(ImporterNameInvalidType())
with pytest.raises(ImporterInvalidTypeError):
assert importer.name


@pytest.mark.skipif(
sys.platform == "win32", reason="different error on windows"
)
Expand Down

0 comments on commit 2261c27

Please sign in to comment.