diff --git a/poetry.lock b/poetry.lock index d10518cb0f2b..8e3d4801c4a7 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3053,7 +3053,6 @@ version = ">=3.7.4" [[package]] category = "main" description = "Backport of pathlib-compatible object wrapper for zip files" -marker = "sys_platform != \"win32\" and python_version < \"3.8\" or python_version < \"3.8\" or python_version < \"3.7\" and python_version != \"3.4\"" name = "zipp" optional = false python-versions = ">=3.6" @@ -3071,7 +3070,7 @@ spacy = ["spacy"] transformers = ["transformers"] [metadata] -content-hash = "01760c0b388a7deeceb6127efa0dc9f8c3a495e7046e2776e3073d1919e66867" +content-hash = "7c826413e4ee6a6c55df8f25b5b41217b75f9235c9305892c1fb112beabcbe35" python-versions = ">=3.6,<3.9" [metadata.files] diff --git a/pyproject.toml b/pyproject.toml index f40d42e13482..a917d70317ce 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -150,6 +150,7 @@ toml = "^0.10.0" pep440-version-utils = "^0.3.0" pydoc-markdown = "3.3.0.post1" mypy = "^0.782" +typing-extensions = "^3.7.4" [tool.poetry.extras] spacy = [ "spacy",] diff --git a/rasa/cli/train.py b/rasa/cli/train.py index 0ffff340f074..d5bb21726216 100644 --- a/rasa/cli/train.py +++ b/rasa/cli/train.py @@ -1,9 +1,12 @@ import argparse +import asyncio import os from typing import List, Optional, Text, Dict -import rasa.cli.arguments.train as train_arguments +import rasa.train +import rasa.cli.arguments.train as train_arguments from rasa.cli.utils import get_validated_path, missing_config_keys, print_error +from rasa.core.train import do_compare_training from rasa.constants import ( DEFAULT_CONFIG_PATH, DEFAULT_DATA_PATH, @@ -52,8 +55,6 @@ def add_subparser( def train(args: argparse.Namespace) -> Optional[Text]: - import rasa - domain = get_validated_path( args.domain, "domain", DEFAULT_DOMAIN_PATH, none_is_valid=True ) @@ -65,7 +66,7 @@ def train(args: argparse.Namespace) -> Optional[Text]: for f in args.data ] - return rasa.train( + return rasa.train.train( domain=domain, config=config, training_files=training_files, @@ -81,9 +82,6 @@ def train(args: argparse.Namespace) -> Optional[Text]: def train_core( args: argparse.Namespace, train_path: Optional[Text] = None ) -> Optional[Text]: - from rasa.train import train_core - import asyncio - loop = asyncio.get_event_loop() output = train_path or args.out @@ -103,7 +101,7 @@ def train_core( config = _get_valid_config(args.config, CONFIG_MANDATORY_KEYS_CORE) - return train_core( + return rasa.train.train_core( domain=args.domain, config=config, stories=story_file, @@ -113,8 +111,6 @@ def train_core( additional_arguments=additional_arguments, ) else: - from rasa.core.train import do_compare_training - loop.run_until_complete( do_compare_training(args, story_file, additional_arguments) ) @@ -123,8 +119,6 @@ def train_core( def train_nlu( args: argparse.Namespace, train_path: Optional[Text] = None ) -> Optional[Text]: - from rasa.train import train_nlu - output = train_path or args.out config = _get_valid_config(args.config, CONFIG_MANDATORY_KEYS_NLU) @@ -132,7 +126,7 @@ def train_nlu( args.nlu, "nlu", DEFAULT_DATA_PATH, none_is_valid=True ) - return train_nlu( + return rasa.train.train_nlu( config=config, nlu_data=nlu_data, output=output, diff --git a/rasa/server.py b/rasa/server.py index ffc4d9e7b0b1..3117ebd59954 100644 --- a/rasa/server.py +++ b/rasa/server.py @@ -9,7 +9,7 @@ from functools import reduce, wraps from inspect import isawaitable from pathlib import Path -from typing import Any, Callable, List, Optional, Text, Union, Dict +from typing import Any, Callable, List, Optional, Text, Union, Dict, cast from rasa.core.training.story_writer.yaml_story_writer import YAMLStoryWriter from rasa.nlu.training_data.formats import RasaYAMLReader @@ -52,8 +52,16 @@ if typing.TYPE_CHECKING: from ssl import SSLContext + from typing_extensions import Protocol from rasa.core.processor import MessageProcessor + class SanicView(Protocol): + def __call__( + self, request: Request, *args: Any, **kwargs: Any + ) -> response.BaseHTTPResponse: + ... + + logger = logging.getLogger(__name__) JSON_CONTENT_TYPE = "application/json" @@ -123,7 +131,7 @@ def decorated(*args, **kwargs): def requires_auth(app: Sanic, token: Optional[Text] = None) -> Callable[[Any], Any]: """Wraps a request handler with token authentication.""" - def decorator(f: Callable[[Any, Any], Any]) -> Callable[[Any, Any], Any]: + def decorator(f: "SanicView") -> "SanicView": def conversation_id_from_args(args: Any, kwargs: Any) -> Optional[Text]: argnames = common_utils.arguments_of(f) @@ -153,7 +161,9 @@ def sufficient_scope(request, *args: Any, **kwargs: Any) -> Optional[bool]: return False @wraps(f) - async def decorated(request: Request, *args: Any, **kwargs: Any) -> Any: + async def decorated( + request: Request, *args: Any, **kwargs: Any + ) -> response.BaseHTTPResponse: provided = request.args.get("token", None) @@ -227,7 +237,7 @@ async def get_tracker( _validate_tracker(tracker, conversation_id) # `_validate_tracker` ensures we can't return `None` so `Optional` is not needed - return tracker # pytype: disable=bad-return-type + return cast(DialogueStateTracker, tracker) def _validate_tracker( @@ -631,7 +641,7 @@ async def execute_action(request: Request, conversation_id: Text): tracker = await get_tracker(app.agent.create_processor(), conversation_id) state = tracker.current_state(verbosity) - response_body = {"tracker": state} + response_body: Dict[Text, Any] = {"tracker": state} if isinstance(output_channel, CollectingOutputChannel): response_body["messages"] = output_channel.messages @@ -685,7 +695,7 @@ async def trigger_intent(request: Request, conversation_id: Text) -> HTTPRespons state = tracker.current_state(verbosity) - response_body = {"tracker": state} + response_body: Dict[Text, Any] = {"tracker": state} if isinstance(output_channel, CollectingOutputChannel): response_body["messages"] = output_channel.messages @@ -870,6 +880,11 @@ async def evaluate_intents(request: Request) -> HTTPResponse: model_directory = eval_agent.model_directory _, nlu_model = model.get_model_subdirectories(model_directory) + if nlu_model is None: + raise ErrorResponse( + 500, "TestingError", f"Missing NLU model directory.", + ) + try: evaluation = run_evaluation(data_path, nlu_model, disable_plotting=True) return response.json(evaluation) diff --git a/rasa/test.py b/rasa/test.py index bdeee1ff796c..cac9fefb8ead 100644 --- a/rasa/test.py +++ b/rasa/test.py @@ -147,7 +147,7 @@ def test_core( "to train a NLU model first, e.g. using `rasa train`." ) - from rasa.core.test import test + from rasa.core.test import test as core_test kwargs = utils.minimal_kwargs(additional_arguments, test, ["stories", "agent"]) @@ -157,11 +157,11 @@ def test_core( def _test_core( stories: Optional[Text], agent: "Agent", output_directory: Text, **kwargs: Any ) -> None: - from rasa.core.test import test + from rasa.core.test import test as core_test loop = asyncio.get_event_loop() loop.run_until_complete( - test(stories, agent, out_directory=output_directory, **kwargs) + core_test(stories, agent, out_directory=output_directory, **kwargs) ) diff --git a/setup.cfg b/setup.cfg index 6b79289b6d9e..81ebaaab50fc 100644 --- a/setup.cfg +++ b/setup.cfg @@ -398,9 +398,6 @@ ignore_errors = True [mypy-rasa.run] ignore_errors = True -[mypy-rasa.server] -ignore_errors = True - [mypy-rasa.test] ignore_errors = True