From 255c6f9b23f9b8270048941ae1d8a73c1b5df49a Mon Sep 17 00:00:00 2001 From: Rudi Giesler Date: Wed, 12 Aug 2020 13:02:16 +0200 Subject: [PATCH 1/2] Include school data in persisted slots --- dbe/actions/actions.py | 14 ++++++++++++- dbe/tests/test_actions.py | 42 ++++++++++++++++++++++++++++++++------- 2 files changed, 48 insertions(+), 8 deletions(-) diff --git a/dbe/actions/actions.py b/dbe/actions/actions.py index 340f42a..9313d60 100644 --- a/dbe/actions/actions.py +++ b/dbe/actions/actions.py @@ -1,12 +1,14 @@ from typing import Any, Dict, List, Optional, Text, Union from rasa_sdk import Tracker +from rasa_sdk.events import SlotSet from rasa_sdk.executor import CollectingDispatcher from whoosh.index import open_dir from whoosh.qparser import MultifieldParser from whoosh.query import FuzzyTerm, Term -from base.actions.actions import ActionExit, ActionSessionStart +from base.actions.actions import ActionExit +from base.actions.actions import ActionSessionStart as BaseActionSessionStart from base.actions.actions import HealthCheckForm as BaseHealthCheckForm from base.actions.actions import HealthCheckProfileForm as BaseHealthCheckProfileForm from base.actions.actions import HealthCheckTermsForm @@ -101,10 +103,20 @@ def get_eventstore_data(self, tracker: Tracker, risk: Text) -> Dict[Text, Any]: return data +class ActionSessionStart(BaseActionSessionStart): + def get_carry_over_slots(self, tracker: Tracker) -> List[Dict[Text, Any]]: + actions = super().get_carry_over_slots(tracker) + carry_over_slots = ("school", "school_confirm", "school_emis") + for slot in carry_over_slots: + actions.append(SlotSet(slot, tracker.get_slot(slot))) + return actions + + __all__ = [ "HealthCheckTermsForm", "HealthCheckProfileForm", "HealthCheckForm", "ActionSessionStart", "ActionExit", + "ActionSessionStart", ] diff --git a/dbe/tests/test_actions.py b/dbe/tests/test_actions.py index 80d82ca..bb8387a 100644 --- a/dbe/tests/test_actions.py +++ b/dbe/tests/test_actions.py @@ -1,9 +1,14 @@ from unittest import TestCase from rasa_sdk import Tracker +from rasa_sdk.events import SlotSet from rasa_sdk.executor import CollectingDispatcher -from dbe.actions.actions import HealthCheckForm, HealthCheckProfileForm +from dbe.actions.actions import ( + ActionSessionStart, + HealthCheckForm, + HealthCheckProfileForm, +) class HealthCheckProfileFormTests(TestCase): @@ -51,9 +56,7 @@ def test_validate_school_confirm_no(self): Try again getting the name of the school """ form = HealthCheckProfileForm() - tracker = Tracker( - "27820001001", {}, {}, [], False, None, {}, "action_listen" - ) + tracker = Tracker("27820001001", {}, {}, [], False, None, {}, "action_listen") dispatcher = CollectingDispatcher() response = form.validate_school_confirm("no", dispatcher, tracker, {}) self.assertEqual(response, {"school": None, "school_confirm": None}) @@ -63,9 +66,7 @@ def test_validate_school_confirm_yes(self): Confirms the name of the school """ form = HealthCheckProfileForm() - tracker = Tracker( - "27820001001", {}, {}, [], False, None, {}, "action_listen" - ) + tracker = Tracker("27820001001", {}, {}, [], False, None, {}, "action_listen") dispatcher = CollectingDispatcher() response = form.validate_school_confirm("yes", dispatcher, tracker, {}) self.assertEqual(response, {"school_confirm": "yes"}) @@ -148,3 +149,30 @@ def test_eventstore_data(self): }, }, ) + + +class ActionSessionStartTests(TestCase): + def test_school_details_copied(self): + """ + Should copy over the school details to the new session + """ + action = ActionSessionStart() + events = action.get_carry_over_slots( + Tracker( + "27820001001", + { + "school": "BERGVLIET HIGH SCHOOL", + "school_emis": "105310201", + "school_confirm": "yes", + }, + {}, + [], + False, + None, + {}, + "action_listen", + ) + ) + self.assertIn(SlotSet("school", "BERGVLIET HIGH SCHOOL"), events) + self.assertIn(SlotSet("school_emis", "105310201"), events) + self.assertIn(SlotSet("school_confirm", "yes"), events) From e051412d0be47cba08d927c7cd451840b8c93641 Mon Sep 17 00:00:00 2001 From: Rudi Giesler Date: Wed, 12 Aug 2020 13:07:12 +0200 Subject: [PATCH 2/2] Fix CI to check not fix --- test.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test.sh b/test.sh index 58e9ce2..3d908ea 100755 --- a/test.sh +++ b/test.sh @@ -12,8 +12,8 @@ if (( $# != 2 )); then exit 1 fi -black . -isort -rc . +black --check . +isort -rc -c . mypy "$1" flake8 . py.test