Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix and enable mypy return check #9109

Merged
merged 47 commits into from
Jul 15, 2021
Merged
Show file tree
Hide file tree
Changes from 45 commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
1ded5ef
fix and enable mypy return check
joejuzl Jul 12, 2021
897122a
Merge remote-tracking branch 'origin/main' into 8242-mypy_return_issues
joejuzl Jul 12, 2021
b21bbd2
lint
joejuzl Jul 12, 2021
176c04e
test fixes
joejuzl Jul 12, 2021
cc43663
lint
joejuzl Jul 12, 2021
299fa7e
sanic import fix
joejuzl Jul 13, 2021
2fc4f3f
Merge remote-tracking branch 'origin/main' into 8242-mypy_return_issues
joejuzl Jul 13, 2021
29f5250
use normal OrderedDict
joejuzl Jul 13, 2021
f85c84d
ci fixes
joejuzl Jul 13, 2021
5fe31f9
enable return check again
joejuzl Jul 13, 2021
f462eb6
fix return types in `Message` by declaring new variables
wochinge Jul 6, 2021
3b0f98a
fix return types in `model_data_utils` by declaring new variables
wochinge Jul 6, 2021
1e9ca83
make type explicit to fix `return` type error
wochinge Jul 7, 2021
2161830
fix `return` value issues in `models` module
wochinge Jul 7, 2021
f921848
fix wrong type annotation
wochinge Jul 7, 2021
ab9b695
specify correct return type
wochinge Jul 7, 2021
1067054
fix typing issues in `rasa_yaml`
wochinge Jul 7, 2021
f25070e
add missing return statements
wochinge Jul 12, 2021
6f83b30
return empty training data instead of implicit `None`
wochinge Jul 12, 2021
e11b4be
cover default case
wochinge Jul 12, 2021
4545654
fix context manager related issues
wochinge Jul 12, 2021
eb52955
use correct ContextManager type annotation
wochinge Jul 13, 2021
a4229ed
fix docstring errors by making fns protected
wochinge Jul 13, 2021
787cda9
add missing docstrings
wochinge Jul 13, 2021
8e1ed49
fix remaining `index` errors
wochinge Jul 13, 2021
543ff31
fixes
joejuzl Jul 13, 2021
f72b9ed
Merge pull request #9054 from RasaHQ/8242-mypy_return_issues-tobias
wochinge Jul 13, 2021
e61dd7c
fix context manager type error
wochinge Jul 13, 2021
359641a
add cast
wochinge Jul 13, 2021
de278fa
remove unused method
wochinge Jul 13, 2021
3e6c32e
fix component return types
wochinge Jul 13, 2021
da09370
add default to avoid `None` case
wochinge Jul 14, 2021
50c78d1
use correct return type
wochinge Jul 14, 2021
613bac9
add type check
wochinge Jul 14, 2021
7ff1b0a
make code more verbose to help mypy
wochinge Jul 14, 2021
db04ffe
add metaclass workaround
wochinge Jul 14, 2021
a7f9922
black format
wochinge Jul 14, 2021
ca83b0f
black format
wochinge Jul 14, 2021
0ee9595
ignore Flask error
wochinge Jul 14, 2021
76ee0dd
explicitly add `mypy-extensions` to pyproject
wochinge Jul 14, 2021
15a71b3
remove blank line after docstring
wochinge Jul 14, 2021
bce9c2d
also enable 'return' code
wochinge Jul 14, 2021
ea2dda4
make type explicit
wochinge Jul 14, 2021
190981c
add default
wochinge Jul 14, 2021
845d4d6
fix sklearn policy load
joejuzl Jul 14, 2021
dc1a173
remove `Optional`
wochinge Jul 15, 2021
ca403a3
Merge branch 'main' into 8242-mypy_return_issues
wochinge Jul 15, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion .github/scripts/mr_generate_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,13 @@ def read_results(file):
with open(file) as json_file:
data = json.load(json_file)

keys = ["accuracy", "weighted avg", "macro avg", "micro avg", "conversation_accuracy"]
keys = [
"accuracy",
"weighted avg",
"macro avg",
"micro avg",
"conversation_accuracy",
]
result = {key: data[key] for key in keys if key in data}

return result
Expand Down
8 changes: 7 additions & 1 deletion .github/scripts/mr_publish_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,13 @@ def read_results(file):
with open(file) as json_file:
data = json.load(json_file)

keys = ["accuracy", "weighted avg", "macro avg", "micro avg", "conversation_accuracy"]
keys = [
"accuracy",
"weighted avg",
"macro avg",
"micro avg",
"conversation_accuracy",
]
result = {key: data[key] for key in keys if key in data}

return result
Expand Down
2 changes: 1 addition & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ types-requests = "^2.25.0"
types-setuptools = "^57.0.0"
memory-profiler = "^0.58.0"
psutil = "^5.8.0"
mypy-extensions = "^0.4.3"

[tool.poetry.extras]
spacy = [ "spacy",]
Expand Down
1 change: 1 addition & 0 deletions rasa/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ def run_core_training(
rasa.utils.common.run_in_loop(
do_compare_training(args, story_file, additional_arguments)
)
return None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not very sure what do_compare_training does, but it also returns None, is there any reason why it shouldn't be return rasa.utils.common.run_in_loop(...)?



def run_nlu_training(
Expand Down
5 changes: 3 additions & 2 deletions rasa/cli/x.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import asyncio
import importlib.util
import logging
from multiprocessing import get_context, Process
from multiprocessing.process import BaseProcess
from multiprocessing import get_context
import os
import signal
import sys
Expand Down Expand Up @@ -198,7 +199,7 @@ def _is_correct_event_broker(event_broker: EndpointConfig) -> bool:

def start_rasa_for_local_rasa_x(
args: argparse.Namespace, rasa_x_token: Text
) -> Process:
) -> BaseProcess:
"""Starts the Rasa X API with Rasa as a background process."""
credentials_path, endpoints_path = _get_credentials_and_endpoints_paths(args)
endpoints = AvailableEndpoints.read_endpoints(endpoints_path)
Expand Down
4 changes: 2 additions & 2 deletions rasa/core/actions/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,7 @@ async def run(
domain: "Domain",
) -> List[Event]:
"""Runs action. Please see parent class for the full docstring."""
_events = [SessionStarted(metadata=self.metadata)]
_events: List[Event] = [SessionStarted(metadata=self.metadata)]

if domain.session_config.carry_over_slots:
_events.extend(self._slot_set_events_from_tracker(tracker))
Expand Down Expand Up @@ -690,7 +690,7 @@ async def run(

events_json = response.get("events", [])
responses = response.get("responses", [])
bot_messages = await self._utter_responses(
bot_messages: List[Event] = await self._utter_responses(
ancalita marked this conversation as resolved.
Show resolved Hide resolved
responses, output_channel, nlg, tracker
)

Expand Down
10 changes: 5 additions & 5 deletions rasa/core/actions/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ async def validate_slots(
domain: Domain,
output_channel: OutputChannel,
nlg: NaturalLanguageGenerator,
) -> List[Event]:
) -> List[Union[SlotSet, Event]]:
ancalita marked this conversation as resolved.
Show resolved Hide resolved
"""Validate the extracted slots.

If a custom action is available for validating the slots, we call it to validate
Expand All @@ -445,7 +445,7 @@ async def validate_slots(
for the validated slots.
"""
logger.debug(f"Validating extracted slots: {slot_candidates}")
events = [
events: List[Union[SlotSet, Event]] = [
SlotSet(slot_name, value) for slot_name, value in slot_candidates.items()
]

Expand Down Expand Up @@ -506,7 +506,7 @@ async def validate(
domain: Domain,
output_channel: OutputChannel,
nlg: NaturalLanguageGenerator,
) -> List[Event]:
) -> List[Union[SlotSet, Event]]:
"""Extract and validate value of requested slot.

If nothing was extracted reject execution of the form action.
Expand Down Expand Up @@ -560,9 +560,9 @@ async def request_next_slot(
output_channel: OutputChannel,
nlg: NaturalLanguageGenerator,
events_so_far: List[Event],
) -> List[Event]:
) -> List[Union[SlotSet, Event]]:
"""Request the next slot and response if needed, else return `None`."""
request_slot_events = []
request_slot_events: List[Union[SlotSet, Event]] = []

if await self.is_done(output_channel, nlg, tracker, domain, events_so_far):
# The custom action for slot validation decided to stop the form early
Expand Down
5 changes: 3 additions & 2 deletions rasa/core/actions/two_stage_fallback.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ async def deactivate(
return await self._give_up(output_channel, nlg, tracker, domain)

# revert fallback events
return [UserUtteranceReverted()] + _message_clarification(tracker)
reverted_event: List[Event] = [UserUtteranceReverted()]
return reverted_event + _message_clarification(tracker)

async def _give_up(
self,
Expand Down Expand Up @@ -137,7 +138,7 @@ def _two_fallbacks_in_a_row(tracker: DialogueStateTracker) -> bool:

def _last_n_intent_names(
tracker: DialogueStateTracker, number_of_last_intent_names: int
) -> List[Text]:
) -> List[Optional[Text]]:
intent_names = []
for i in range(number_of_last_intent_names):
message = tracker.get_last_event_for(
Expand Down
19 changes: 14 additions & 5 deletions rasa/core/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,16 @@
import shutil
import tempfile
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Text, Tuple, Union
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Text,
Tuple,
Union,
)
import uuid

import aiohttp
Expand Down Expand Up @@ -52,6 +61,8 @@
from rasa.utils.endpoints import EndpointConfig
import rasa.utils.io

from rasa.shared.core.generator import TrackerWithCachedStates

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -679,7 +690,7 @@ def _are_all_featurizers_using_a_max_history(self) -> bool:
"""Check if all featurizers are MaxHistoryTrackerFeaturizer."""

def has_max_history_featurizer(policy: Policy) -> bool:
return (
return bool(
policy.featurizer
and hasattr(policy.featurizer, "max_history")
and policy.featurizer.max_history is not None
Expand All @@ -700,9 +711,8 @@ async def load_data(
use_story_concatenation: bool = True,
debug_plots: bool = False,
exclusion_percentage: Optional[int] = None,
) -> List[DialogueStateTracker]:
) -> List["TrackerWithCachedStates"]:
"""Load training data from a resource."""

max_history = self._max_history()

if unique_last_num_states is None:
Expand Down Expand Up @@ -769,7 +779,6 @@ def _clear_model_directory(model_path: Text) -> None:
Only removes files if the directory seems to contain a previously
persisted model. Otherwise does nothing to avoid deleting
`/` by accident."""

if not os.path.exists(model_path):
return

Expand Down
1 change: 1 addition & 0 deletions rasa/core/channels/botframework.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ async def _get_headers(self) -> Optional[Dict[Text, Any]]:
return BotFramework.headers
else:
logger.error("Could not get BotFramework token")
return None
else:
return BotFramework.headers

Expand Down
2 changes: 2 additions & 0 deletions rasa/core/channels/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,8 @@ def decode_bearer_token(
except Exception:
logger.exception("Failed to decode bearer token.")

return None


class OutputChannel:
"""Output channel base class.
Expand Down
17 changes: 10 additions & 7 deletions rasa/core/channels/console.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,10 @@ def print_buttons(
rasa.shared.utils.cli.print_color(
cli_utils.button_to_string(button, idx), color=color
)
return None


def print_bot_output(
def _print_bot_output(
message: Dict[Text, Any],
is_latest_message: bool = False,
color: Text = rasa.shared.utils.io.bcolors.OKBLUE,
Expand Down Expand Up @@ -90,17 +91,19 @@ def print_bot_output(
json.dumps(message.get("custom"), indent=2), color=color
)

return None

def get_user_input(previous_response: Optional[Dict[str, Any]]) -> Optional[Text]:

def _get_user_input(previous_response: Optional[Dict[str, Any]]) -> Optional[Text]:
button_response = None
if previous_response is not None:
button_response = print_bot_output(previous_response, is_latest_message=True)
button_response = _print_bot_output(previous_response, is_latest_message=True)

if button_response is not None:
response = cli_utils.payload_from_button_question(button_response)
if response == cli_utils.FREE_TEXT_INPUT_PROMPT:
# Re-prompt user with a free text input
response = get_user_input({})
response = _get_user_input({})
else:
response = questionary.text(
"",
Expand Down Expand Up @@ -169,7 +172,7 @@ async def record_messages(
previous_response = None
await asyncio.sleep(0.5) # Wait for server to start
while not utils.is_limit_reached(num_messages, max_message_limit):
text = get_user_input(previous_response)
text = _get_user_input(previous_response)

if text == exit_text or text is None:
break
Expand All @@ -181,7 +184,7 @@ async def record_messages(
previous_response = None
async for response in bot_responses:
if previous_response is not None:
print_bot_output(previous_response)
_print_bot_output(previous_response)
previous_response = response
else:
bot_responses = await send_message_receive_block(
Expand All @@ -190,7 +193,7 @@ async def record_messages(
previous_response = None
for response in bot_responses:
if previous_response is not None:
print_bot_output(previous_response)
_print_bot_output(previous_response)
previous_response = response

num_messages += 1
Expand Down
5 changes: 3 additions & 2 deletions rasa/core/channels/hangouts.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def _text_button_card(text: Text, buttons: List) -> Union[Dict, None]:
logger.error(
"Buttons must be a list of dicts with 'title' and 'payload' as keys"
)
return
return None

hangouts_buttons.append(
{
Expand Down Expand Up @@ -246,10 +246,11 @@ def _extract_message(self, req: Request) -> Text:

@staticmethod
def _extract_room(req: Request) -> Union[Text, None]:

if req.json["space"]["type"] == "ROOM":
return req.json["space"]["displayName"]

return None

def _extract_input_channel(self) -> Text:
return self.name()

Expand Down
3 changes: 2 additions & 1 deletion rasa/core/channels/slack.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,6 @@ def _is_interactive_message(payload: Dict) -> bool:
@staticmethod
def _get_interactive_response(action: Dict) -> Optional[Text]:
"""Parse the payload for the response value."""

if action["type"] == "button":
return action.get("value")
elif action["type"] == "select":
Expand All @@ -328,6 +327,8 @@ def _get_interactive_response(action: Dict) -> Optional[Text]:
elif action["type"] == "datepicker":
return action.get("selected_date")

return None

async def process_message(
self,
request: Request,
Expand Down
5 changes: 4 additions & 1 deletion rasa/core/channels/socketio.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,12 +170,13 @@ def get_output_channel(self) -> Optional["OutputChannel"]:
"Please use a different channel for external events in these "
"scenarios."
)
return
return None
return SocketIOOutput(self.sio, self.bot_message_evt)

def blueprint(
self, on_new_message: Callable[[UserMessage], Awaitable[Any]]
) -> Blueprint:
"""Defines a Sanic blueprint."""
# Workaround so that socketio works with requests from other origins.
# https://github.com/miguelgrinberg/python-socketio/issues/205#issuecomment-493769183
sio = AsyncServer(async_mode="sanic", cors_allowed_origins=[])
Expand Down Expand Up @@ -203,10 +204,12 @@ async def connect(

if jwt_payload:
logger.debug(f"User {sid} connected to socketIO endpoint.")
return True
else:
return False
else:
logger.debug(f"User {sid} connected to socketIO endpoint.")
return True
wochinge marked this conversation as resolved.
Show resolved Hide resolved

@sio.on("disconnect", namespace=self.namespace)
async def disconnect(sid: Text) -> None:
Expand Down
4 changes: 3 additions & 1 deletion rasa/core/featurizers/tracker_featurizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,9 @@ def training_states_labels_and_entities(
domain: Domain,
omit_unset_slots: bool = False,
ignore_action_unlikely_intent: bool = False,
) -> Tuple[List[List[State]], List[List[Text]], List[List[Dict[Text, Any]]]]:
) -> Tuple[
List[List[State]], List[List[Optional[Text]]], List[List[Dict[Text, Any]]]
]:
"""Transforms trackers to states, action labels, and entity data.

Args:
Expand Down
3 changes: 3 additions & 0 deletions rasa/core/lock_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,10 @@ def get_lock(self, conversation_id: Text) -> Optional[TicketLock]:
if serialised_lock:
return TicketLock.from_dict(json.loads(serialised_lock))

return None

def delete_lock(self, conversation_id: Text) -> None:
"""Deletes lock for conversation ID."""
deletion_successful = self.red.delete(self.key_prefix + conversation_id)
self._log_deletion(conversation_id, deletion_successful)

Expand Down
2 changes: 1 addition & 1 deletion rasa/core/nlg/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ async def generate(
method="post", json=body, timeout=DEFAULT_REQUEST_TIMEOUT
)

if self.validate_response(response):
if isinstance(response, dict) and self.validate_response(response):
return response
else:
raise RasaException("NLG web endpoint returned an invalid response.")
Expand Down
Loading