Skip to content

Commit

Permalink
♻️ Refactor ContentTagger mechanism
Browse files Browse the repository at this point in the history
  • Loading branch information
HugoPerrier committed Dec 9, 2024
1 parent 314e927 commit 673c071
Show file tree
Hide file tree
Showing 18 changed files with 1,162 additions and 974 deletions.
2 changes: 1 addition & 1 deletion melusine/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from sklearn.base import BaseEstimator, TransformerMixin

from melusine.backend import backend
from melusine.io import IoMixin
from melusine.io_mixin import IoMixin

logger = logging.getLogger(__name__)

Expand Down
3 changes: 3 additions & 0 deletions melusine/conf/pipelines/demo_pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ demo_pipeline:
- class_name: ContentTagger
config_key: content_tagger
module: melusine.processors
- class_name: RefinedTagger
config_key: refined_tagger
module: melusine.processors
- class_name: TextExtractor
config_key: text_extractor
module: melusine.processors
Expand Down
2 changes: 1 addition & 1 deletion melusine/conf/processors/refined_tagger.yaml
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
content_tagger:
refined_tagger:
default_tag: BODY
22 changes: 4 additions & 18 deletions melusine/detectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"""

from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List

from melusine.base import MelusineDetector, MelusineItem, MelusineRegex
from melusine.message import Message
Expand Down Expand Up @@ -95,19 +95,12 @@ def pre_detect(self, row: MelusineItem, debug_mode: bool = False) -> MelusineIte
target_tags={self.BODY_PART}, stop_at={self.GREETINGS_PART}
)

# Extract the THANKS part in the last message
thanks_parts: List[Tuple[str, str]] = row[self.messages_column][0].extract_parts(target_tags={self.THANKS_PART})

# Compute THANKS text
if not thanks_parts:
thanks_text: str = ""
else:
thanks_text = "\n".join(x[1] for x in thanks_parts)
# Extract the THANKS text in the last message
thanks_text = row[self.messages_column][0].extract_text(target_tags={self.THANKS_PART})

# Save debug data
if debug_mode:
debug_dict = {
self.THANKS_PARTS_COL: thanks_parts,
self.THANKS_TEXT_COL: thanks_text,
self.HAS_BODY: has_body,
}
Expand Down Expand Up @@ -236,20 +229,13 @@ def pre_detect(self, row: MelusineItem, debug_mode: bool = False) -> MelusineIte
"""
# Last message body
last_message: Message = row[self.messages_column][0]
body_parts = last_message.extract_last_body()

if body_parts:
row[self.CONST_TEXT_COL_NAME] = "\n".join(text for tag, text in body_parts)
else:
row[self.CONST_TEXT_COL_NAME] = ""
row[self.CONST_TEXT_COL_NAME] = last_message.extract_text(target_tags=("BODY",), stop_at=("GREETINGS",))

# Prepare and save debug data
if debug_mode:
debug_dict: Dict[str, Any] = {
self.CONST_DEBUG_TEXT_KEY: row[self.CONST_TEXT_COL_NAME],
}
if self.messages_column:
debug_dict[self.CONST_DEBUG_PARTS_KEY] = body_parts
row[self.debug_dict_col].update(debug_dict)

return row
Expand Down
2 changes: 1 addition & 1 deletion melusine/io/__init__.py → melusine/io_mixin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
The melusine.io_mixin module includes classes for input/output data.
"""

from melusine.io._classes import IoMixin
from melusine.io_mixin._classes import IoMixin

__all__ = ["IoMixin"]
4 changes: 0 additions & 4 deletions melusine/io/_classes.py → melusine/io_mixin/_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,6 @@ class IoMixin:
Defines generic load methods.
"""

def __init__(self, **kwargs: Any):
"""Initialize attribute."""
self.json_exclude_list: list[str] = ["_func", "json_exclude_list"]

@classmethod
def from_config(
cls: type[T],
Expand Down
56 changes: 46 additions & 10 deletions melusine/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import re
from datetime import datetime
from typing import Any, Dict, Iterable, List, Optional, Tuple
from typing import Any, Dict, Iterable, List, Optional

from melusine import config

Expand All @@ -22,6 +22,7 @@ class Message:
DEFAULT_STR_TAG_NAME_LENGTH = 22
MAIN_TAG_TYPE = "refined_tag"
FALLBACK_TAG_TYPE = "base_tag"
MAIN_TEXT_TYPE = "base_text"

def __init__(
self,
Expand Down Expand Up @@ -65,6 +66,9 @@ def __init__(
self.clean_header: str = ""
self.clean_text: str = ""

self.effective_tag_key = "base_tag"
self.effective_text_key = "base_text"

@property
def str_tag_name_length(self) -> int:
"""
Expand All @@ -89,7 +93,7 @@ def extract_parts(
self,
target_tags: Optional[Iterable[str]] = None,
stop_at: Optional[Iterable[str]] = None,
tag_type: str = MAIN_TAG_TYPE,
tag_type: Optional[str] = None,
) -> List[Dict[str, Any]]:
"""
Function to extract target tags from the message.
Expand All @@ -110,13 +114,11 @@ def extract_parts(
if not self.tags:
return []

if tag_type is None:
tag_type = self.effective_tag_key

# List of tags in the message
try:
tag_name_list: List[str] = [x[tag_type] for x in self.tags]
# If tag_type is not available, fall back on base_tag
except KeyError:
tag_type = self.FALLBACK_TAG_TYPE
tag_name_list: List[str] = [x[tag_type] for x in self.tags]
tag_name_list: List[str] = [x[tag_type] for x in self.tags]

if target_tags is None:
target_tags = tag_name_list
Expand All @@ -135,11 +137,42 @@ def extract_parts(

return [x for x in effective_tags if x[tag_type] in target_tags]

def extract_text(
self,
target_tags: Optional[Iterable[str]] = None,
stop_at: Optional[Iterable[str]] = None,
tag_type: Optional[str] = None,
text_type: str = MAIN_TEXT_TYPE,
separator: str = "\n",
) -> str:
"""
Function to extract the text of the target tags from the message as a single string.
The parts are selected with `extract_parts`, then the `text_type` field of each part is joined.
Parameters
----------
target_tags:
Tags to be extracted.
stop_at:
Tags for which extraction should stop.
tag_type:
Type of tags to consider.
When None, `extract_parts` falls back on `self.effective_tag_key`.
text_type:
Key of the text field read from each extracted part (default: MAIN_TEXT_TYPE, i.e. "base_text").
separator:
Separator to join the extracted texts.
Returns
-------
_: Extracted texts joined with `separator` (empty string when no part is extracted).
"""
# Delegate the tag selection (target_tags / stop_at / tag_type) to extract_parts
parts = self.extract_parts(target_tags=target_tags, stop_at=stop_at, tag_type=tag_type)
return separator.join([x[text_type] for x in parts])

def extract_last_body(
self,
target_tags: Iterable[str] = ("BODY",),
stop_at: Iterable[str] = ("GREETINGS",),
tag_type: str = MAIN_TAG_TYPE,
tag_type: Optional[str] = None,
) -> List[Dict[str, Any]]:
"""
Extract the BODY parts of the last message in the email.
Expand All @@ -160,7 +193,7 @@ def has_tags(
self,
target_tags: Iterable[str] = ("BODY",),
stop_at: Optional[Iterable[str]] = None,
tag_type: str = MAIN_TAG_TYPE,
tag_type: Optional[str] = None,
) -> bool:
"""
Function to check if input tags are present in the message.
Expand All @@ -182,6 +215,9 @@ def has_tags(
if self.tags is None:
return False

if tag_type is None:
tag_type = self.effective_tag_key

if not stop_at:
stop_at = set()

Expand Down
2 changes: 1 addition & 1 deletion melusine/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from melusine.backend import backend
from melusine.backend.base_backend import Any
from melusine.base import MelusineTransformer
from melusine.io import IoMixin
from melusine.io_mixin import IoMixin

T = TypeVar("T")

Expand Down
Loading

0 comments on commit 673c071

Please sign in to comment.