diff --git a/backend/app.py b/backend/app.py index 54272a57..93b9c27a 100644 --- a/backend/app.py +++ b/backend/app.py @@ -1,83 +1,8 @@ -from flask import Flask, request -from flask_cors import CORS from waitress import serve -from http import HTTPStatus -from backend.request import ClassificationRequest -from backend.response import ClassificationResponse -from backend.spectrogram_generator import SpectrogramGenerator -from classification.parser import get_raw_array -from classification.exceptions import ClassificationError -from classification.config.constants import Sex, ALLOWED_FILE_EXTENSIONS -from classification.model import SleepStagesClassifier -from classification.features.preprocessing import preprocess +from web import App -app = Flask(__name__) -sleep_stage_classifier = SleepStagesClassifier() - - -def allowed_file(filename): - return filename.lower().endswith(ALLOWED_FILE_EXTENSIONS) - - -@app.route("/") -def status(): - return "" - - -@app.route('/analyze-sleep', methods=['POST']) -def analyze_sleep(): - """ - Request payload example - { - "file": File(...), - "device": "CYTON", - "sex": "F", - "age": "23", - "stream_start": 1602895800000, - "bedtime": 1602898320000, - "wakeup": 1602931800000 - } - """ - if 'file' not in request.files: - return 'Missing file', HTTPStatus.BAD_REQUEST - file = request.files['file'] - - if file.filename == '': - return 'No selected file', HTTPStatus.BAD_REQUEST - - if not allowed_file(file.filename): - return 'File format not allowed', HTTPStatus.BAD_REQUEST - - form_data = request.form.to_dict() - raw_array = get_raw_array(file) - - try: - classification_request = ClassificationRequest( - age=int(form_data['age']), - sex=Sex[form_data['sex']], - stream_start=int(form_data['stream_start']), - bedtime=int(form_data['bedtime']), - wakeup=int(form_data['wakeup']), - raw_eeg=raw_array, - ) - except (KeyError, ValueError, ClassificationError): - return 'Missing or invalid request parameters', HTTPStatus.BAD_REQUEST - - preprocessed_epochs = preprocess(classification_request) - predictions = sleep_stage_classifier.predict(preprocessed_epochs, classification_request) - spectrogram_generator = SpectrogramGenerator(preprocessed_epochs) - classification_response = ClassificationResponse( - classification_request, predictions, spectrogram_generator.generate() - ) - - return classification_response.response - - -CORS(app, - resources={r'/*': {"origins": '*'}}, - allow_headers='Content-Type') -app.config['CORS_HEADERS'] = 'Content-Type' +app = App() if __name__ == '__main__': serve(app, host='0.0.0.0', port=8080) diff --git a/backend/classification/parser/constants.py b/backend/classification/parser/constants.py index aa403d2e..28191d22 100644 --- a/backend/classification/parser/constants.py +++ b/backend/classification/parser/constants.py @@ -8,6 +8,4 @@ FILE_COLUMN_OFFSET = 1 -SESSION_FILE_HEADER_NB_LINES = 4 - RETAINED_COLUMNS = tuple(range(FILE_COLUMN_OFFSET, len(EEG_CHANNELS) + 1)) diff --git a/backend/classification/parser/csv.py b/backend/classification/parser/csv.py index 512059f0..1a98bb2c 100644 --- a/backend/classification/parser/csv.py +++ b/backend/classification/parser/csv.py @@ -1,11 +1,12 @@ import pandas as pd +from io import StringIO from classification.exceptions import ClassificationError -def read_csv(file, rows_to_skip=0, columns_to_read=None): +def read_csv(file_content, rows_to_skip=0, columns_to_read=None): try: - raw_array = pd.read_csv(file, + raw_array = pd.read_csv(StringIO(file_content), skiprows=rows_to_skip, usecols=columns_to_read ).to_numpy() diff --git a/backend/classification/parser/file_type.py b/backend/classification/parser/file_type.py index 0f847a33..85472c43 100644 --- a/backend/classification/parser/file_type.py +++ b/backend/classification/parser/file_type.py @@ -12,11 +12,10 @@ def __init__(self, parser): self.parser = parser -def detect_file_type(file) -> FileType: +def detect_file_type(file_content) -> FileType: """Detects file type - file: received as an input file Returns: - FileType of the input file """ - first_line = file.readline().decode("utf-8") - return FileType.SessionFile if "EEG Data" in first_line else FileType.SDFile + return FileType.SessionFile if "EEG Data" in file_content else FileType.SDFile diff --git a/backend/classification/parser/sample_rate.py b/backend/classification/parser/sample_rate.py index b8a53cd5..30c62ad3 100644 --- a/backend/classification/parser/sample_rate.py +++ b/backend/classification/parser/sample_rate.py @@ -1,28 +1,22 @@ import re from classification.parser.file_type import FileType -from classification.parser.constants import SESSION_FILE_HEADER_NB_LINES from classification.exceptions import ClassificationError OPENBCI_CYTON_SD_DEFAULT_SAMPLE_RATE = 250 SAMPLE_RATE_STRING = "Sample Rate" -SAMPLE_RATE_REGEX = fr"^%{SAMPLE_RATE_STRING} = (\d+)" +SAMPLE_RATE_REGEX = fr"%{SAMPLE_RATE_STRING} = (\d+)" -def detect_sample_rate(file, filetype): +def detect_sample_rate(file_content, filetype): if filetype == FileType.SDFile: return OPENBCI_CYTON_SD_DEFAULT_SAMPLE_RATE - for _ in range(SESSION_FILE_HEADER_NB_LINES): - line = file.readline().decode("utf-8") - if SAMPLE_RATE_STRING not in line: - continue - - try: - sample_rate_raw = re.search(SAMPLE_RATE_REGEX, line).group(1) - return int(sample_rate_raw) - except BaseException: - raise ClassificationError('Invalid sample rate') + try: + sample_rate_raw = re.search(SAMPLE_RATE_REGEX, file_content).group(1) + return int(sample_rate_raw) + except BaseException: + raise ClassificationError('Invalid sample rate') raise ClassificationError("Couldn't find sample rate") diff --git a/backend/requirements.txt b/backend/requirements.txt index 83805be4..4987346c 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -1,5 +1,4 @@ -Flask==1.1.2 -Flask-Cors==1.10.3 +falcon==3.0.0a2 waitress==1.4.4 mne==0.21.0 diff --git a/backend/web/__init__.py b/backend/web/__init__.py new file mode 100644 index 00000000..2cbef41b --- /dev/null +++ b/backend/web/__init__.py @@ -0,0 +1,16 @@ +import falcon + +from web.ping import Ping +from web.analyze_sleep import AnalyzeSleep + + +def App(): + app = falcon.App(cors_enable=True) + + ping = Ping() + app.add_route('/', ping) + + analyze = AnalyzeSleep() + app.add_route('/analyze-sleep', analyze) + + return app diff --git a/backend/web/analyze_sleep.py b/backend/web/analyze_sleep.py new file mode 100644 index 00000000..8915d853 --- /dev/null +++ b/backend/web/analyze_sleep.py @@ -0,0 +1,81 @@ +import json +import falcon + +from backend.request import ClassificationRequest +from backend.response import ClassificationResponse +from backend.spectrogram_generator import SpectrogramGenerator +from classification.parser import get_raw_array +from classification.exceptions import ClassificationError +from classification.config.constants import Sex, ALLOWED_FILE_EXTENSIONS +from classification.model import SleepStagesClassifier +from classification.features.preprocessing import preprocess + +sleep_stage_classifier = SleepStagesClassifier() + + +class AnalyzeSleep: + @staticmethod + def _validate_file(file_content): + if file_content is None: + raise ClassificationError("Missing file") + + @staticmethod + def _validate_filename(filename): + if not filename.lower().endswith(ALLOWED_FILE_EXTENSIONS): + raise ClassificationError('File format not allowed') + + @staticmethod + def _parse_form(form): + form_data = {} + file_content = None + + for part in form: + if part.name == 'file': + AnalyzeSleep._validate_filename(part.filename) + file_content = part.stream.read().decode('utf-8') + else: + form_data[part.name] = part.text + + AnalyzeSleep._validate_file(file_content) + + return form_data, file_content + + def on_post(self, request, response): + """ + Request payload example + { + "file": File(...), + "device": "CYTON", + "sex": "F", + "age": "23", + "stream_start": 1602895800000, + "bedtime": 1602898320000, + "wakeup": 1602931800000 + } + """ + + try: + form_data, file = self._parse_form(request.get_media()) + raw_array = get_raw_array(file) + classification_request = ClassificationRequest( + age=int(form_data['age']), + sex=Sex[form_data['sex']], + stream_start=int(form_data['stream_start']), + bedtime=int(form_data['bedtime']), + wakeup=int(form_data['wakeup']), + raw_eeg=raw_array, + ) + except (KeyError, ValueError, ClassificationError): + response.status = falcon.HTTP_400 + response.content_type = falcon.MEDIA_TEXT + response.body = 'Missing or invalid request parameters' + return + + preprocessed_epochs = preprocess(classification_request) + predictions = sleep_stage_classifier.predict(preprocessed_epochs, classification_request) + spectrogram_generator = SpectrogramGenerator(preprocessed_epochs) + classification_response = ClassificationResponse( + classification_request, predictions, spectrogram_generator.generate() + ) + + response.body = json.dumps(classification_response.response) diff --git a/backend/web/ping.py b/backend/web/ping.py new file mode 100644 index 00000000..58db3272 --- /dev/null +++ b/backend/web/ping.py @@ -0,0 +1,7 @@ +import falcon + + +class Ping: + def on_get(self, request, response): + response.content_type = falcon.MEDIA_TEXT + response.body = '' diff --git a/web/src/views/analyze_sleep/upload_form/index.js b/web/src/views/analyze_sleep/upload_form/index.js index 8d8f868b..5bfb0df6 100644 --- a/web/src/views/analyze_sleep/upload_form/index.js +++ b/web/src/views/analyze_sleep/upload_form/index.js @@ -166,7 +166,7 @@ const UploadForm = () => { // prettier-ignore const streamStart = new Date(`${getValues('stream_start_date')} ${getValues('stream_start_time')}`); const bedTime = new Date(`${getValues('bedtime_date')} ${getValues('bedtime_time')}`); - if (streamStart > bedTime) { + if (streamStart >= bedTime) { return 'Stream start must be prior to bedtime.'; } },