
Commit

add mypy type checker (#746)
* add mypy
* add mypy vscode extension
* add mypy to ci
ekneg54 authored Feb 14, 2025
1 parent a4164f5 commit c896c2a
Showing 18 changed files with 92 additions and 71 deletions.
7 changes: 5 additions & 2 deletions .github/workflows/ci.yml
@@ -53,12 +53,12 @@ jobs:
uses: tj-actions/changed-files@v41
with:
files: |
- **/*.py
+ logprep/**/*.py
- name: Install dependencies
run: |
pip install --upgrade pip wheel
pip install .[dev]
- - name: check black formating
+ - name: check black formatting
run: |
black --check --diff --config ./pyproject.toml .
- name: lint helm charts
@@ -68,6 +68,9 @@
if: steps.changed-files.outputs.all_changed_files
run: |
pylint ${{ steps.changed-files.outputs.all_changed_files }}
+ - name: mypy type checking
+ if: steps.changed-files.outputs.all_changed_files
+ run: mypy --follow-imports=skip ${{ steps.changed-files.outputs.all_changed_files }}
- name: Run tests and collect coverage
run: pytest tests/unit --cov=logprep --cov-report=xml
- name: Upload coverage reports to Codecov with GitHub Action
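A note on the new CI step: mypy runs only on the files changed in the pull request, and `--follow-imports=skip` keeps it from analyzing the modules those files import; skipped imports are typed as `Any`. A minimal sketch of the effect, with two hypothetical modules:

```python
# helpers.py -- hypothetical module that is NOT passed to mypy
def parse_port(raw: str) -> int:
    return raw  # a real type error, but unseen under --follow-imports=skip

# service.py -- the "changed file" that CI hands to mypy
from helpers import parse_port  # not followed; parse_port is typed as Any

def connect(raw: str) -> int:
    return parse_port(raw)  # unchecked too: calling an Any value yields Any
```

The trade-off is speed over completeness: only errors inside the listed files are reported.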
1 change: 1 addition & 0 deletions .vscode/extensions.json
@@ -9,6 +9,7 @@
"ms-python.pylint",
"ms-python.isort",
"ms-toolsai.jupyter",
"ms-python.mypy-type-checker",
"njpwerner.autodocstring",
"ryanluker.vscode-coverage-gutters",
"streetsidesoftware.code-spell-checker"
10 changes: 5 additions & 5 deletions logprep/abc/input.py
@@ -8,7 +8,7 @@
import zlib
from abc import abstractmethod
from copy import deepcopy
- from functools import partial, cached_property
+ from functools import cached_property, partial
from hmac import HMAC
from typing import Optional, Tuple
from zoneinfo import ZoneInfo
@@ -91,9 +91,9 @@ class TimeDeltaConfig:
"""TimeDelta Configurations
Works only if the preprocessor log_arrival_time_target_field is set."""

- target_field: field(validator=[validators.instance_of(str), lambda _, __, x: bool(x)])
+ target_field: str = field(validator=(validators.instance_of(str), lambda _, __, x: bool(x)))
"""Defines the fieldname to which the time difference should be written to."""
- reference_field: field(validator=[validators.instance_of(str), lambda _, __, x: bool(x)])
+ reference_field: str = field(validator=(validators.instance_of(str), lambda _, __, x: bool(x)))
"""Defines a field with a timestamp that should be used for the time difference.
The calculation will be the arrival time minus the time of this reference field."""

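The two fields above previously put the `field(...)` call in the annotation slot itself, so mypy saw no real type, and attrs applies validators only from the assigned value, not from an annotation. The fix annotates `str` and assigns the field, passing the validators as a tuple. A minimal sketch of the corrected pattern, with `validators.min_len(1)` standing in for the original non-empty-check lambda:

```python
from attrs import define, field, validators

@define(kw_only=True)
class TimeDeltaConfigSketch:  # reduced stand-in for the class above
    # the annotation carries the type for mypy; field() carries runtime
    # metadata, and the tuple of validators is applied left to right
    target_field: str = field(
        validator=(validators.instance_of(str), validators.min_len(1))
    )
    reference_field: str = field(
        validator=(validators.instance_of(str), validators.min_len(1))
    )

cfg = TimeDeltaConfigSketch(target_field="arrival_delta", reference_field="@timestamp")
```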
@@ -233,7 +233,7 @@ def _add_env_enrichment(self):
"""Check and return if the env enrichment should be added to the event."""
return bool(self._config.preprocessing.get("enrich_by_env_variables"))

- def _get_raw_event(self, timeout: float) -> bytearray: # pylint: disable=unused-argument
+ def _get_raw_event(self, timeout: float) -> bytes | None: # pylint: disable=unused-argument
"""Implements the details how to get the raw event
Parameters
@@ -283,7 +283,7 @@ def get_next(self, timeout: float) -> dict | None:
"""
event, raw_event = self._get_event(timeout)
if event is None:
- return
+ return None
self.metrics.number_of_processed_events += 1
if not isinstance(event, dict):
raise CriticalInputError(self, "not a dict", event)
16 changes: 10 additions & 6 deletions logprep/abc/processor.py
@@ -3,9 +3,9 @@
import logging
import os
from abc import abstractmethod
- from typing import TYPE_CHECKING, ClassVar, Dict, List, Type
+ from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Type

- from attr import define, field, validators
+ from attrs import define, field, validators

from logprep.abc.component import Component
from logprep.framework.rule_tree.rule_tree import RuleTree
@@ -69,10 +69,10 @@ class ProcessorResult:
factory=list,
)
""" The warnings that occurred during processing """
- event: dict | None = field(validator=validators.instance_of((dict, type(None))), default=None)
- """ A reference to the event that was processed """
processor_name: str = field(validator=validators.instance_of(str))
""" The name of the processor """
+ event: dict = field(validator=validators.optional(validators.instance_of(dict)), default=None)
+ """ A reference to the event that was processed """


class Processor(Component):
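`ProcessorResult.event` now uses `validators.optional(...)`, the idiomatic attrs wrapper for a value that may be `None`, instead of `instance_of((dict, type(None)))`, and the defaulted field moves below the mandatory `processor_name`. The annotation stays `dict` despite the `None` default, which may be why the construction site further down carries a `# type: ignore`. A sketch of the optional-validator pattern with the annotation spelled out as `dict | None`:

```python
from attrs import define, field, validators

@define(kw_only=True)
class ResultSketch:  # hypothetical stand-in for ProcessorResult
    processor_name: str = field(validator=validators.instance_of(str))
    # optional() lets None pass and applies instance_of(dict) otherwise
    event: dict | None = field(
        validator=validators.optional(validators.instance_of(dict)), default=None
    )

ResultSketch(processor_name="labeler")                  # event defaults to None
ResultSketch(processor_name="labeler", event={"k": 1})  # validated as a dict
```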
@@ -161,7 +161,7 @@ def process(self, event: dict) -> ProcessorResult:
extra data and a list of target outputs.
"""
- self.result = ProcessorResult(processor_name=self.name, event=event)
+ self.result = ProcessorResult(processor_name=self.name, event=event) # type: ignore
logger.debug("%s processing event %s", self.describe(), event)
if self._bypass_rule_tree:
self._process_all_rules(event)
@@ -214,9 +214,13 @@ def _apply_rules_wrapper(self, event: dict, rule: "Rule"):
except ProcessingWarning as error:
self._handle_warning_error(event, rule, error)
except ProcessingCriticalError as error:
+ if self.result is None:
+ raise error
self.result.errors.append(error) # is needed to prevent wrapping it in itself
event.clear()
except Exception as error: # pylint: disable=broad-except
+ if self.result is None:
+ raise error
self.result.errors.append(ProcessingCriticalError(str(error), rule))
event.clear()
if not hasattr(rule, "delete_source_fields"):
@@ -291,7 +295,7 @@ def _has_missing_values(self, event, rule, source_field_dict):
return True
return False

- def _write_target_field(self, event: dict, rule: "Rule", result: any) -> None:
+ def _write_target_field(self, event: dict, rule: "Rule", result: Any) -> None:
add_fields_to(
event,
fields={rule.target_field: result},
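The old hint `result: any` named the builtin `any()` function, which is not a type and is rejected by mypy; the commit imports `Any` from `typing` at the top of this file instead (see the import hunk above). In isolation:

```python
from typing import Any

def write_target_field(event: dict, target_field: str, result: Any) -> None:
    # Any deliberately opts this parameter out of type checking;
    # lowercase any() is a builtin function, invalid in an annotation
    event[target_field] = result
```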
20 changes: 11 additions & 9 deletions logprep/connector/file/input.py
@@ -132,7 +132,7 @@ class FileWatcherUtil:
"""

def __init__(self, file_name: str = ""):
- self.dict = {}
+ self.dict: dict = {}
if file_name:
self.add_file(file_name)

@@ -200,7 +200,7 @@ class FileInput(Input):
"""FileInput Connector"""

_messages: queue.Queue = queue.Queue()
- _fileinfo_util: object = FileWatcherUtil()
+ _fileinfo_util: FileWatcherUtil = FileWatcherUtil()
rthread: threading.Event = None

@define(kw_only=True)
@@ -212,7 +212,7 @@ class Config(Input.Config):
format. Needs to be parsed with dissector or another processor"""

start: str = field(
- validator=[validators.instance_of(str), validators.in_(("begin", "end"))],
+ validator=(validators.instance_of(str), validators.in_(("begin", "end"))),
default="begin",
)
"""Defines the behaviour of the file monitor with the following options:
@@ -232,7 +232,9 @@ def __init__(self, name: str, configuration: "FileInput.Config"):
super().__init__(name, configuration)
self.stop_flag = threading.Event()

- def _calc_file_fingerprint(self, file_pointer: TextIO, fingerprint_length: int = None) -> tuple:
+ def _calc_file_fingerprint(
+ self, file_pointer: TextIO, fingerprint_length: int | None = None
+ ) -> tuple:
"""This function creates a crc32 fingerprint of the first 256 bytes of a given file
If the existing log file is less than 256 bytes, it will take what is there
and return also the size"""
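`fingerprint_length: int = None` relied on implicit Optional, which current mypy rejects by default: a parameter defaulting to `None` must say so in its annotation, hence `int | None`. A self-contained sketch of the same signature:

```python
import zlib
from typing import TextIO

def calc_file_fingerprint(
    file_pointer: TextIO, fingerprint_length: int | None = None
) -> tuple:
    # None is now part of the declared parameter type, not implied by the default
    length = fingerprint_length if fingerprint_length is not None else 256
    data = file_pointer.read(length)
    return zlib.crc32(data.encode("utf-8")), len(data)
```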
@@ -293,7 +295,7 @@ def _file_input_handler(self, file_name: str):
"""Put log_line as a dict to threadsafe message queue from given input file.
Depending on configuration it will continuously monitor a given file for new
appending log lines. Depending on configuration it will start to process the
- given file from the beginning or the end. Will create and continously check
+ given file from the beginning or the end. Will create and continuously check
the file fingerprints to detect file changes that typically occur on log rotation."""
with open(file_name, encoding="utf-8") as file:
if not self._fileinfo_util.get_fingerprint(file_name):
@@ -306,21 +308,21 @@
def _line_to_dict(self, input_line: str) -> dict:
"""Takes an input string and turns it into a dict without any parsing or formatting.
Only thing it does additionally is stripping the new lines away."""
- input_line: str = input_line.rstrip("\n")
+ input_line = input_line.rstrip("\n")
if len(input_line) > 0:
return {"message": input_line}
return ""
return {}

def _get_event(self, timeout: float) -> tuple:
"""Returns the first message from the threadsafe queue"""
try:
message: dict = self._messages.get(timeout=timeout)
- raw_message: str = str(message).encode("utf8")
+ raw_message: bytes = str(message).encode("utf8")
return message, raw_message
except queue.Empty:
return None, None

- def setup(self):
+ def setup(self) -> None:
"""Creates and starts the Thread that continuously monitors the given logfile.
Right now this input connector is only started in the first process.
It needs the class attribute pipeline_index before running setup in Pipeline
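Two return-type fixes in this connector: `_line_to_dict` is annotated `-> dict` yet returned `""` for empty lines, and `raw_message` was annotated `str` although `.encode()` produces `bytes`. The corrected contract, sketched:

```python
def line_to_dict(input_line: str) -> dict:
    # always return a dict; an empty (falsy) dict now signals "no message"
    input_line = input_line.rstrip("\n")
    if input_line:
        return {"message": input_line}
    return {}

message = line_to_dict("a log line\n")
raw_message: bytes = str(message).encode("utf8")  # bytes, not str
```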
14 changes: 7 additions & 7 deletions logprep/connector/opensearch/output.py
@@ -109,23 +109,23 @@ class Config(Output.Config):
"""(Optional) Timeout after :code:`message_backlog` is flushed if
:code:`message_backlog_size` is not reached."""
thread_count: int = field(
- default=4, validator=[validators.instance_of(int), validators.gt(1)]
+ default=4, validator=(validators.instance_of(int), validators.gt(1))
)
"""Number of threads to use for bulk requests."""
queue_size: int = field(
- default=4, validator=[validators.instance_of(int), validators.gt(1)]
+ default=4, validator=(validators.instance_of(int), validators.gt(1))
)
"""Number of queue size to use for bulk requests."""
chunk_size: int = field(
- default=500, validator=[validators.instance_of(int), validators.gt(1)]
+ default=500, validator=(validators.instance_of(int), validators.gt(1))
)
"""Chunk size to use for bulk requests."""
max_chunk_bytes: int = field(
- default=100 * 1024 * 1024, validator=[validators.instance_of(int), validators.gt(1)]
+ default=100 * 1024 * 1024, validator=(validators.instance_of(int), validators.gt(1))
)
"""Max chunk size to use for bulk requests. The default is 100MB."""
max_retries: int = field(
- default=3, validator=[validators.instance_of(int), validators.gt(0)]
+ default=3, validator=(validators.instance_of(int), validators.gt(0))
)
"""Max retries for all requests. Default is 3."""
desired_cluster_status: list = field(
@@ -134,7 +134,7 @@ class Config(Output.Config):
"""Desired cluster status for health check as list of strings. Default is ["green"]"""
default_op_type: str = field(
default="index",
validator=[validators.instance_of(str), validators.in_(["create", "index"])],
validator=(validators.instance_of(str), validators.in_(["create", "index"])),
)
"""Default op_type for indexing documents. Default is 'index',
Consider using 'create' for data streams or to prevent overwriting existing documents."""
@@ -174,7 +174,7 @@ def schema(self) -> str:
return "https" if self._config.ca_cert else "http"

@property
- def http_auth(self) -> tuple:
+ def http_auth(self) -> tuple | None:
"""Returns the credentials
Returns
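`http_auth` may return `None` when no credentials are configured, so the property widens to `tuple | None`. The recurring `[...]` to `(...)` switch around the validators matches the other files in this commit; tuples presumably satisfy the attrs type stubs more cleanly than heterogeneous lists. Both patterns condensed, with hypothetical field names:

```python
from attrs import define, field, validators

@define(kw_only=True)
class OutputConfigSketch:  # hypothetical, mirrors the Config above
    user: str = field(default="")
    secret_key: str = field(default="")
    thread_count: int = field(
        default=4, validator=(validators.instance_of(int), validators.gt(1))
    )

def http_auth(config: OutputConfigSketch) -> tuple | None:
    # None is a legal result, so a bare `-> tuple` annotation would be wrong
    if config.user and config.secret_key:
        return (config.user, config.secret_key)
    return None
```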
16 changes: 9 additions & 7 deletions logprep/generator/http/input.py
@@ -11,24 +11,26 @@
from functools import cached_property
from operator import itemgetter
from pathlib import Path
- from typing import Generator, List
+ from typing import Dict, Generator, List

import msgspec
- import yaml
- from attr import define, field, validators
+ from attrs import define, field, validators
+ from ruamel.yaml import YAML

from logprep.generator.http.manipulator import Manipulator

+ yaml = YAML(typ="safe")


@define(kw_only=True)
class TimestampReplacementConfig:
"""Configuration Class fot TimestampReplacement"""

- key: str = field(validator=[validators.instance_of(str)])
+ key: str = field(validator=(validators.instance_of(str)))
format: str = field(validator=validators.instance_of(str))
time_shift: str = field(
default="+0000",
- validator=[validators.instance_of(str), validators.matches_re(r"[+-]\d{4}")],
+ validator=(validators.instance_of(str), validators.matches_re(r"[+-]\d{4}")),
)
time_delta: timedelta = field(
default=None, validator=validators.optional(validators.instance_of(timedelta))
@@ -101,7 +103,7 @@ def __init__(self, config: dict):
self.events_sent = 0
self.batch_size = config.get("batch_size")
self.log = logging.getLogger("Input")
- self.log_class_manipulator_mapping = {}
+ self.log_class_manipulator_mapping: Dict = {}
self.number_events_of_dataset = 0
self.event_file_counter = 0

@@ -148,7 +150,7 @@ def _load_event_class_config(self, event_class_dir_path: str) -> EventClassConfig:
"""Load the event class specific configuration"""
config_path = os.path.join(event_class_dir_path, "config.yaml")
with open(config_path, "r", encoding="utf8") as file:
- event_class_config = yaml.safe_load(file)
+ event_class_config = yaml.load(file)
self.log.debug("Following class config was loaded: %s", event_class_config)
event_class_config = EventClassConfig(**event_class_config)
if "," in event_class_config.target_path:
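This file swaps PyYAML for ruamel.yaml: `import yaml` gives way to a module-level `YAML(typ="safe")` instance bound to the same name, and `yaml.safe_load(file)` becomes `yaml.load(file)`, since ruamel selects the safe loader at construction rather than per call. In isolation:

```python
from ruamel.yaml import YAML

yaml = YAML(typ="safe")  # safe loader chosen once, at construction time

def load_event_class_config(config_path: str) -> dict:
    with open(config_path, "r", encoding="utf8") as file:
        return yaml.load(file)  # already safe; no per-call safe_load() needed
```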
4 changes: 1 addition & 3 deletions logprep/processor/calculator/rule.py
@@ -101,9 +101,7 @@ class CalculatorRule(FieldManagerRule):
class Config(FieldManagerRule.Config):
"""Config for Calculator"""

- calc: str = field(
- validator=[validators.instance_of(str), validators.min_len(3)],
- )
+ calc: str = field(validator=(validators.instance_of(str), validators.min_len(3)))
"""The calculation expression. Fields from the event can be used by
surrounding them with :code:`${` and :code:`}`."""
source_fields: list = field(factory=list, init=False, repr=False, eq=False)
2 changes: 1 addition & 1 deletion logprep/processor/generic_resolver/rule.py
@@ -116,7 +116,7 @@ class Config(FieldManagerRule.Config):
]
)
"""Mapping in form of :code:`{SOURCE_FIELD: DESTINATION_FIELD}`"""
- resolve_list: dict = field(validator=[validators.instance_of(dict)], factory=dict)
+ resolve_list: dict = field(validator=(validators.instance_of(dict)), factory=dict)
"""lookup mapping in form of
:code:`{REGEX_PATTERN_0: ADDED_VALUE_0, ..., REGEX_PATTERN_N: ADDED_VALUE_N}`"""
resolve_from_file: dict = field(
4 changes: 2 additions & 2 deletions logprep/processor/labeler/processor.py
@@ -26,7 +26,7 @@

from typing import Optional

- from attr import define, field, validators
+ from attrs import define, field, validators

from logprep.abc.processor import Processor
from logprep.processor.labeler.labeling_schema import LabelingSchema
@@ -41,7 +41,7 @@ class Labeler(Processor):
class Config(Processor.Config):
"""Labeler Configurations"""

- schema: str = field(validator=[validators.instance_of(str)])
+ schema: str = field(validator=(validators.instance_of(str)))
"""Path to a labeling schema file. For string format see :ref:`getters`."""
include_parent_labels: Optional[bool] = field(
default=False, validator=validators.optional(validator=validators.instance_of(bool))
8 changes: 5 additions & 3 deletions logprep/processor/pre_detector/rule.py
@@ -176,19 +176,21 @@ class Config(Rule.Config): # pylint: disable=too-many-instance-attributes
timestamp_field: str = field(validator=validators.instance_of(str), default="@timestamp")
"""the field which has the given timestamp to be normalized defaults to :code:`@timestamp`"""
source_timezone: ZoneInfo = field(
- validator=[validators.instance_of(ZoneInfo)], converter=ZoneInfo, default="UTC"
+ validator=(validators.instance_of(ZoneInfo)), converter=ZoneInfo, default="UTC"
)
""" timezone of source_fields defaults to :code:`UTC`"""
target_timezone: ZoneInfo = field(
- validator=[validators.instance_of(ZoneInfo)], converter=ZoneInfo, default="UTC"
+ validator=(validators.instance_of(ZoneInfo)), converter=ZoneInfo, default="UTC"
)
""" timezone for target_field defaults to :code:`UTC`"""
failure_tags: list = field(
validator=validators.instance_of(list), default=["pre_detector_failure"]
)
""" tags to be added if processing of the rule fails"""

def __eq__(self, other: "PreDetectorRule") -> bool:
def __eq__(self, other: object) -> bool:
if not isinstance(other, PreDetectorRule):
return NotImplemented
return all(
[
super().__eq__(other),
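Typeshed declares `object.__eq__(self, other: object) -> bool`, so narrowing the parameter to `"PreDetectorRule"` was an incompatible override under mypy. The fix accepts any `object` and returns `NotImplemented` for foreign types, which lets Python fall back to the other operand's comparison. The shape of the pattern:

```python
class RuleSketch:  # reduced example of the same pattern
    def __init__(self, rule_id: str) -> None:
        self.rule_id = rule_id

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, RuleSketch):
            return NotImplemented  # defers to the other operand's __eq__
        return self.rule_id == other.rule_id
```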
8 changes: 4 additions & 4 deletions logprep/processor/pseudonymizer/processor.py
@@ -115,7 +115,7 @@ class Config(FieldManager.Config):
* /var/git/logprep-rules/pseudonymizer_rules/regex_mapping.json
"""
max_cached_pseudonyms: int = field(
- validator=[validators.instance_of(int), validators.gt(0)]
+ validator=(validators.instance_of(int), validators.gt(0))
)
"""
The maximum number of cached pseudonyms. One cache entry requires ~250 Byte, thus 10
@@ -127,19 +127,19 @@ class Config(FieldManager.Config):
entry is deleted. Has to be greater than 0.
"""
max_cached_pseudonymized_urls: int = field(
validator=[validators.instance_of(int), validators.gt(0)], default=10000
validator=(validators.instance_of(int), validators.gt(0)), default=10000
)
"""The maximum number of cached pseudonymized urls. Default is 10000.
Behaves similarly to the max_cached_pseudonyms. Has to be greater than 0."""
mode: str = field(
validator=[validators.instance_of(str), validators.in_(("GCM", "CTR"))], default="GCM"
validator=(validators.instance_of(str), validators.in_(("GCM", "CTR"))), default="GCM"
)
"""Optional mode of operation for the encryption. Can be either 'GCM' or 'CTR'.
Default is 'GCM'.
"""

@define(kw_only=True)
- class Metrics(Processor.Metrics):
+ class Metrics(Processor.Metrics): # type: ignore
"""Tracks statistics about the Pseudonymizer"""

pseudonymized_urls: CounterMetric = field(
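The bare `# type: ignore` on the `Metrics` subclass mutes every mypy diagnostic on that line. If the underlying complaint is known, a scoped ignore keeps other errors visible; a sketch, where the `[misc]` error code is an assumption to be replaced by whatever mypy actually reports:

```python
from attrs import define, field

@define(kw_only=True)
class BaseMetrics:  # hypothetical stand-in for Processor.Metrics
    number_of_processed_events: int = field(default=0)

@define(kw_only=True)
class MetricsSketch(BaseMetrics):  # type: ignore[misc]  # assumed error code
    pseudonymized_urls: int = field(default=0)
```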
