Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix mypy var annotated #10956

Merged
merged 13 commits into from
Mar 4, 2022
1 change: 1 addition & 0 deletions changelog/9094.misc.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Enable `mypy` `var-annotated` check and fix any resulting errors.
2 changes: 1 addition & 1 deletion rasa/core/actions/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -1179,7 +1179,7 @@ async def run(
) -> List[Event]:
"""Runs action. Please see parent class for the full docstring."""
slot_events: List[Event] = []
executed_custom_actions = set()
executed_custom_actions: Set[Text] = set()

user_slots = [
slot for slot in domain.slots if slot.name not in DEFAULT_SLOT_NAMES
Expand Down
6 changes: 3 additions & 3 deletions rasa/core/actions/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,8 @@ def _create_unique_entity_mappings(self, domain: Domain) -> Set[Text]:
Returns:
A set of json dumps of unique mappings of type `from_entity`.
"""
unique_entity_slot_mappings = set()
duplicate_entity_slot_mappings = set()
unique_entity_slot_mappings: Set[Text] = set()
duplicate_entity_slot_mappings: Set[Text] = set()
domain_slots = domain.as_dict().get(KEY_SLOTS, {})
for slot in domain.required_slots_for_form(self.name()):
for slot_mapping in domain_slots.get(slot, {}).get(SLOT_MAPPINGS, []):
Expand Down Expand Up @@ -360,7 +360,7 @@ def _get_slot_extractions(
events_since_last_user_uttered = FormAction._get_events_since_last_user_uttered(
tracker
)
slot_values = {}
slot_values: Dict[Text, Any] = {}

required_slots = self._add_dynamic_slots_requested_by_dynamic_forms(
tracker, domain
Expand Down
4 changes: 3 additions & 1 deletion rasa/core/channels/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,10 +330,12 @@ class CollectingOutputChannel(OutputChannel):
(doesn't send them anywhere, just collects them)."""

def __init__(self) -> None:
self.messages = []
"""Initialise list to collect messages."""
self.messages: List[Dict[Text, Any]] = []

@classmethod
def name(cls) -> Text:
"""Name of the channel."""
return "collector"

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion rasa/core/channels/hangouts.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def name(cls) -> Text:

def __init__(self) -> None:
"""Starts messages as empty dictionary."""
self.messages = {}
self.messages: Dict[Text, Any] = {}

@staticmethod
def _text_card(message: Dict[Text, Any]) -> Dict:
Expand Down
2 changes: 1 addition & 1 deletion rasa/core/channels/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def stream_response(
metadata: Optional[Dict[Text, Any]],
) -> Callable[[Any], Awaitable[None]]:
async def stream(resp: Any) -> None:
q = Queue()
q: Queue = Queue()
task = asyncio.ensure_future(
self.on_message_wrapper(
on_new_message, text, q, sender_id, input_channel, metadata
Expand Down
2 changes: 1 addition & 1 deletion rasa/core/evaluation/marker_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,7 @@ def _collect_yaml_files_from_path(path: Union[Text, Path]) -> List[Text]:

@staticmethod
def _collect_configs_from_yaml_files(yaml_files: List[Text]) -> Dict[Text, Dict]:
marker_names = set()
marker_names: Set[Text] = set()
loaded_configs: Dict[Text, Dict] = {}
for yaml_file in yaml_files:
loaded_config = rasa.shared.utils.io.read_yaml_file(yaml_file)
Expand Down
2 changes: 1 addition & 1 deletion rasa/core/evaluation/marker_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def _add_num_user_turns_str_to(stat_name: Text) -> Text:
def __init__(self) -> None:
"""Creates a new marker statistics object."""
# to ensure consistency of processed rows
self._marker_names = []
self._marker_names: List[Text] = []

# (1) For collecting the per-session analysis:
# NOTE: we could stream / compute them later instead of collecting them...
Expand Down
6 changes: 3 additions & 3 deletions rasa/core/featurizers/single_state_featurizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ class SingleStateFeaturizer:

def __init__(self) -> None:
"""Initialize the single state featurizer."""
self._default_feature_states = {}
self.action_texts = []
self.entity_tag_specs = []
self._default_feature_states: Dict[Text, Any] = {}
self.action_texts: List[Text] = []
self.entity_tag_specs: List[EntityTagSpec] = []

def _create_entity_tag_specs(
self, bilou_tagging: bool = False
Expand Down
2 changes: 1 addition & 1 deletion rasa/core/featurizers/tracker_featurizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1062,7 +1062,7 @@ def _extract_examples(
tracker_states[:label_index], self.max_history
)
label = [event.intent_name or event.text]
entities = [{}]
entities: List[Dict[Text, Any]] = [{}]

yield sliced_states, label, entities

Expand Down
8 changes: 6 additions & 2 deletions rasa/core/lock_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os

from async_generator import asynccontextmanager
from typing import Text, Union, Optional, AsyncGenerator
from typing import AsyncGenerator, Dict, Optional, Text, Union

from rasa.shared.exceptions import RasaException, ConnectionException
import rasa.shared.utils.common
Expand Down Expand Up @@ -274,19 +274,23 @@ class InMemoryLockStore(LockStore):
"""In-memory store for ticket locks."""

def __init__(self) -> None:
self.conversation_locks = {}
"""Initialise dictionary of locks."""
self.conversation_locks: Dict[Text, TicketLock] = {}
super().__init__()

def get_lock(self, conversation_id: Text) -> Optional[TicketLock]:
"""Get lock for conversation if it exists."""
return self.conversation_locks.get(conversation_id)

def delete_lock(self, conversation_id: Text) -> None:
"""Delete lock for conversation."""
deleted_lock = self.conversation_locks.pop(conversation_id, None)
self._log_deletion(
conversation_id, deletion_successful=deleted_lock is not None
)

def save_lock(self, lock: TicketLock) -> None:
"""Save lock in store."""
self.conversation_locks[lock.conversation_id] = lock


Expand Down
8 changes: 4 additions & 4 deletions rasa/core/migrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,9 +213,9 @@ def _migrate_domain_files(
backup_location: where to backup all domain files
out_path: location where to store the migrated files
"""
slots = {}
forms = {}
entities = []
slots: Dict[Text, Any] = {}
forms: Dict[Text, Any] = {}
entities: List[Any] = []

domain_files = [
file for file in domain_path.iterdir() if Domain.is_domain_file(file)
Expand Down Expand Up @@ -264,7 +264,7 @@ def _migrate_domain_files(

slots.update(original_content.get(KEY_SLOTS, {}))
forms.update(original_content.get(KEY_FORMS, {}))
entities.extend(original_content.get(KEY_ENTITIES, {}))
entities.extend(original_content.get(KEY_ENTITIES, []))

if not slots or not forms:
raise RasaException(
Expand Down
2 changes: 1 addition & 1 deletion rasa/core/policies/memoization.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def _create_lookup_from_states(
Returns:
lookup dictionary
"""
lookup = {}
lookup: Dict[Text, Text] = {}

if not trackers_as_states:
return lookup
Expand Down
30 changes: 20 additions & 10 deletions rasa/core/policies/rule_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,9 @@ def __init__(
self._enable_fallback_prediction = config["enable_fallback_prediction"]
self._check_for_contradictions = config["check_for_contradictions"]

self._rules_sources = defaultdict(list)
self._rules_sources: defaultdict[Text, List[Tuple[Text, Text]]] = defaultdict(
list
)

@classmethod
def raise_if_incompatible_with_domain(
Expand Down Expand Up @@ -190,7 +192,7 @@ def _is_rule_snippet_state(state: State) -> bool:
return prev_action_name == RULE_SNIPPET_ACTION_NAME

def _create_feature_key(self, states: List[State]) -> Optional[Text]:
new_states = []
new_states: List[State] = []
for state in reversed(states):
if self._is_rule_snippet_state(state):
# remove all states before RULE_SNIPPET_ACTION_NAME
Expand Down Expand Up @@ -493,7 +495,7 @@ def _collect_sources(
tracker: TrackerWithCachedStates,
predicted_action_name: Optional[Text],
gold_action_name: Text,
prediction_source: Optional[Text],
prediction_source: Text,
) -> None:
# we need to remember which action should be predicted by the rule
# in order to correctly output the names of the contradicting rules
Expand Down Expand Up @@ -566,7 +568,14 @@ def _check_prediction(
gold_action_name: Text,
prediction_source: Optional[Text],
) -> List[Text]:
if not predicted_action_name or predicted_action_name == gold_action_name:
# FIXME: `predicted_action_name` and `prediction_source` are
# either None together or defined together. This could be improved
# by better typing in this class, but requires some refactoring
if (
not predicted_action_name
or not prediction_source
or predicted_action_name == gold_action_name
):
return []

if self._should_delete(prediction_source, tracker, predicted_action_name):
Expand Down Expand Up @@ -638,12 +647,13 @@ def _run_prediction_on_trackers(
running_tracker, domain, gold_action_name
)
if collect_sources:
self._collect_sources(
running_tracker,
predicted_action_name,
gold_action_name,
prediction_source,
)
if prediction_source:
self._collect_sources(
running_tracker,
predicted_action_name,
gold_action_name,
prediction_source,
)
else:
# to be able to remove only rules turns from the dialogue history
# for ML policies,
Expand Down
2 changes: 1 addition & 1 deletion rasa/core/policies/unexpected_intent_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,7 +780,7 @@ def _collect_label_id_grouped_scores(
if LABEL_PAD_ID in unique_label_ids:
unique_label_ids.remove(LABEL_PAD_ID)

label_id_scores = {
label_id_scores: Dict[int, Dict[Text, List[float]]] = {
label_id: {POSITIVE_SCORES_KEY: [], NEGATIVE_SCORES_KEY: []}
for label_id in unique_label_ids
}
Expand Down
2 changes: 1 addition & 1 deletion rasa/core/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ async def handle_reminder(
)
else:
intent = reminder_event.intent
entities = reminder_event.entities or {}
entities: Union[List[Dict], Dict] = reminder_event.entities or {}
await self.trigger_external_user_uttered(
intent, entities, tracker, output_channel
)
Expand Down
2 changes: 1 addition & 1 deletion rasa/core/tracker_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def __init__(
event_broker: Optional[EventBroker] = None,
**kwargs: Dict[Text, Any],
) -> None:
self.store = {}
self.store: Dict[Text, Text] = {}
super().__init__(domain, event_broker, **kwargs)

def save(self, tracker: DialogueStateTracker) -> None:
Expand Down
8 changes: 3 additions & 5 deletions rasa/core/training/interactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@
OTHER_ACTION = uuid.uuid4().hex
NEW_ACTION = uuid.uuid4().hex

NEW_RESPONSES = {}
NEW_RESPONSES: Dict[Text, List[Dict[Text, Any]]] = {}

MAX_NUMBER_OF_TRAINING_STORIES_FOR_VISUALIZATION = 200

Expand Down Expand Up @@ -318,9 +318,8 @@ async def _ask_questions(
is_abort: Callable[[Dict[Text, Any]], bool] = lambda x: False,
) -> Any:
"""Ask the user a question, if Ctrl-C is pressed provide user with menu."""

should_retry = True
answers = {}
answers: Any = {}

while should_retry:
answers = questions.ask()
Expand All @@ -335,7 +334,6 @@ def _selection_choices_from_intent_prediction(
predictions: List[Dict[Text, Any]]
) -> List[Dict[Text, Any]]:
"""Given a list of ML predictions create a UI choice list."""

sorted_intents = sorted(
predictions, key=lambda k: (-k["confidence"], k[INTENT_NAME_KEY])
)
Expand Down Expand Up @@ -923,7 +921,7 @@ def _write_domain_to_file(

messages = _collect_messages(events)
actions = _collect_actions(events)
responses = NEW_RESPONSES # type: Dict[Text, List[Dict[Text, Any]]]
responses = NEW_RESPONSES

# TODO for now there is no way to distinguish between action and form
collected_actions = list(
Expand Down
4 changes: 2 additions & 2 deletions rasa/core/training/story_conflict.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(self, sliced_states: List[State]) -> None:

self._sliced_states = sliced_states
# A list of actions that all follow from the same state.
self._conflicting_actions = defaultdict(
self._conflicting_actions: defaultdict[Text, List[Text]] = defaultdict(
list
) # {"action": ["story_1", ...], ...}

Expand Down Expand Up @@ -196,7 +196,7 @@ def _find_conflicting_states(
"""
# Create a 'state -> list of actions' dict, where the state is
# represented by its hash
state_action_mapping = defaultdict(list)
state_action_mapping: defaultdict[int, List[int]] = defaultdict(list)

for element in _sliced_states_iterator(trackers, domain, max_history, tokenizer):
hashed_state = element.sliced_states_hash
Expand Down
2 changes: 1 addition & 1 deletion rasa/core/training/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def create_action_fingerprints(

# take into account only featurized slots
featurized_slots = {slot.name for slot in domain.slots if slot.has_features()}
action_fingerprints = defaultdict(dict)
action_fingerprints: defaultdict[Text, Dict[Text, List[Text]]] = defaultdict(dict)
for action_name, events_after_action in events_after_actions.items():
slots = list(
set(
Expand Down
2 changes: 1 addition & 1 deletion rasa/engine/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def as_dict(self) -> Dict[Text, Any]:
Returns:
The graph schema in a format which can be dumped as JSON or other formats.
"""
serializable_graph_schema = {"nodes": {}}
serializable_graph_schema: Dict[Text, Dict[Text, Any]] = {"nodes": {}}
for node_name, node in self.nodes.items():
serializable = dataclasses.asdict(node)

Expand Down
6 changes: 3 additions & 3 deletions rasa/engine/storage/local_model_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from contextlib import contextmanager
from datetime import datetime
from pathlib import Path
from typing import Text, ContextManager, Tuple, Union
from typing import Text, Generator, Tuple, Union

import rasa.utils.common
import rasa.shared.utils.io
Expand Down Expand Up @@ -112,7 +112,7 @@ def _load_metadata(directory: Path) -> ModelMetadata:
return ModelMetadata.from_dict(serialized_metadata)

@contextmanager
def write_to(self, resource: Resource) -> ContextManager[Path]:
def write_to(self, resource: Resource) -> Generator[Path, None, None]:
"""Persists data for a resource (see parent class for full docstring)."""
logger.debug(f"Resource '{resource.name}' was requested for writing.")
directory = self._directory_for_resource(resource)
Expand All @@ -128,7 +128,7 @@ def _directory_for_resource(self, resource: Resource) -> Path:
return self._storage_path / resource.name

@contextmanager
def read_from(self, resource: Resource) -> ContextManager[Path]:
def read_from(self, resource: Resource) -> Generator[Path, None, None]:
"""Provides the data of a `Resource` (see parent class for full docstring)."""
logger.debug(f"Resource '{resource.name}' was requested for reading.")
directory = self._directory_for_resource(resource)
Expand Down
6 changes: 3 additions & 3 deletions rasa/engine/storage/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Tuple, Union, Text, ContextManager, Dict, Any, Optional
from typing import Tuple, Union, Text, Generator, Dict, Any, Optional
from packaging import version

from rasa.constants import MINIMUM_COMPATIBLE_VERSION
Expand Down Expand Up @@ -74,7 +74,7 @@ def metadata_from_archive(

@contextmanager
@abc.abstractmethod
def write_to(self, resource: Resource) -> ContextManager[Path]:
def write_to(self, resource: Resource) -> Generator[Path, None, None]:
"""Persists data for a given resource.

This `Resource` can then be accessed in dependent graph nodes via
Expand All @@ -90,7 +90,7 @@ def write_to(self, resource: Resource) -> ContextManager[Path]:

@contextmanager
@abc.abstractmethod
def read_from(self, resource: Resource) -> ContextManager[Path]:
def read_from(self, resource: Resource) -> Generator[Path, None, None]:
"""Provides the data of a persisted `Resource`.

Args:
Expand Down
2 changes: 1 addition & 1 deletion rasa/model_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ async def compare_nlu_models(
bases = [os.path.basename(nlu_config) for nlu_config in configs]
model_names = [os.path.splitext(base)[0] for base in bases]

f1_score_results = {
f1_score_results: Dict[Text, List[List[float]]] = {
model_name: [[] for _ in range(runs)] for model_name in model_names
}

Expand Down
Loading