diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..b0c1e83 --- /dev/null +++ b/.flake8 @@ -0,0 +1,5 @@ +[flake8] +max-line-length = 120 +ignore = E402, E265, F403, W503, W504, E731 +exclude = .github, .git, venv*, docs, build +per-file-ignores = **/__init__.py:F401 diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml new file mode 100644 index 0000000..173c1e3 --- /dev/null +++ b/.github/FUNDING.yml @@ -0,0 +1,12 @@ +# These are supported funding model platforms + +github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] +patreon: # Replace with a single Patreon username +open_collective: pyronear +ko_fi: # Replace with a single Ko-fi username +tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel +community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry +liberapay: # Replace with a single Liberapay username +issuehunt: # Replace with a single IssueHunt username +otechie: # Replace with a single Otechie username +custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] diff --git a/.github/validate_headers.py b/.github/validate_headers.py new file mode 100644 index 0000000..6104874 --- /dev/null +++ b/.github/validate_headers.py @@ -0,0 +1,57 @@ +from datetime import datetime +from pathlib import Path + +shebang = ["#!usr/bin/python\n"] +blank_line = "\n" + +# Possible years +starting_year = 2022 +current_year = datetime.now().year + +year_options = [f"{current_year}"] + [f"{year}-{current_year}" for year in range(starting_year, current_year)] +copyright_notices = [[f"# Copyright (C) {year_str}, Pyronear.\n"] for year_str in year_options] +license_notice = [ + "# This program is licensed under the Apache License version 2.\n", + "# See LICENSE or go to for full license details.\n", +] + +# Define all header options +HEADERS = [ + shebang + [blank_line] + copyright_notice + [blank_line] + license_notice for copyright_notice in copyright_notices +] + [copyright_notice + [blank_line] + license_notice for copyright_notice in copyright_notices] + + +IGNORED_FILES = ["version.py", "__init__.py"] +FOLDERS = ["src/app"] + + +def main(): + + invalid_files = [] + + # For every python file in the repository + for folder in FOLDERS: + for source_path in Path(__file__).parent.parent.joinpath(folder).rglob("**/*.py"): + if source_path.name not in IGNORED_FILES: + # Parse header + header_length = max(len(option) for option in HEADERS) + current_header = [] + with open(source_path) as f: + for idx, line in enumerate(f): + current_header.append(line) + if idx == header_length - 1: + break + # Validate it + if not any( + "".join(current_header[: min(len(option), len(current_header))]) == "".join(option) + for option in HEADERS + ): + invalid_files.append(source_path) + + if len(invalid_files) > 0: + invalid_str = "\n- " + "\n- ".join(map(str, invalid_files)) + raise AssertionError(f"Invalid header in the following files:{invalid_str}") + + +if __name__ == "__main__": + main() diff --git a/.github/workflows/builds.yml b/.github/workflows/builds.yml new file mode 100644 index 0000000..357eb41 --- /dev/null +++ b/.github/workflows/builds.yml @@ -0,0 +1,34 @@ +name: builds + +on: + push: + branches: main + pull_request: + branches: main + +jobs: + install: + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-latest] + python: [3.7, 3.8] + steps: + - uses: actions/checkout@v2 + - name: Set up Python + uses: actions/setup-python@v1 + 
with: + python-version: ${{ matrix.python }} + architecture: x64 + - name: Cache python modules + uses: actions/cache@v2 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-deps-${{ matrix.python }}-${{ hashFiles('src/app/requirements.txt') }}-${{ hashFiles('**/*.py') }} + restore-keys: | + ${{ runner.os }}-deps-${{ matrix.python }}-${{ hashFiles('src/app/requirements.txt') }}- + - name: Install project + run: | + python -m pip install --upgrade pip + pip install -r src/app/requirements.txt diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml new file mode 100644 index 0000000..5a9a5a4 --- /dev/null +++ b/.github/workflows/main.yml @@ -0,0 +1,75 @@ +name: api + +on: + push: + branches: main + pull_request: + branches: main + +jobs: + docker-ready: + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + steps: + - uses: actions/checkout@v2 + - name: Build & run docker + env: + QARNOT_TOKEN: ${{ secrets.QARNOT_TOKEN }} + BUCKET_NAME: ${{ secrets.BUCKET_NAME }} + BUCKET_MEDIA_FOLDER: ${{ secrets.BUCKET_MEDIA_FOLDER }} + BUCKET_ANNOT_FOLDER: ${{ secrets.BUCKET_ANNOT_FOLDER }} + run: docker-compose up -d --build + - name: Docker sanity check + run: sleep 10 && nc -vz localhost 8080 + - name: Ping server + run: curl http://localhost:8080/docs + + pytest: + runs-on: ubuntu-latest + needs: docker-ready + steps: + - uses: actions/checkout@v2 + - name: Build & run docker + env: + QARNOT_TOKEN: ${{ secrets.QARNOT_TOKEN }} + BUCKET_NAME: ${{ secrets.BUCKET_NAME }} + BUCKET_MEDIA_FOLDER: ${{ secrets.BUCKET_MEDIA_FOLDER }} + BUCKET_ANNOT_FOLDER: ${{ secrets.BUCKET_ANNOT_FOLDER }} + run: docker-compose up -d --build + - name: Install dependencies in docker + run: | + docker-compose exec -T pyrostorage python -m pip install --upgrade pip + docker-compose exec -T pyrostorage pip install -r requirements-dev.txt + - name: Run docker test + run: | + docker-compose exec -T pyrostorage coverage --version + docker-compose exec -T pyrostorage coverage run -m pytest tests/ + docker-compose exec -T pyrostorage coverage xml + docker cp pyro-storage_pyrostorage_1:/app/coverage.xml . 
+ + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v1 + with: + file: ./coverage.xml + flags: unittests + fail_ci_if_error: true + + headers: + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python: [3.7] + steps: + - uses: actions/checkout@v2 + with: + persist-credentials: false + - name: Set up Python + uses: actions/setup-python@v1 + with: + python-version: ${{ matrix.python }} + architecture: x64 + - name: Run unittests + run: python .github/validate_headers.py diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml new file mode 100644 index 0000000..20fa01a --- /dev/null +++ b/.github/workflows/style.yml @@ -0,0 +1,75 @@ +name: style + +on: + push: + branches: main + pull_request: + branches: main + +jobs: + flake8: + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python: [3.7] + steps: + - uses: actions/checkout@v2 + - name: Set up Python + uses: actions/setup-python@v1 + with: + python-version: ${{ matrix.python }} + architecture: x64 + - name: Run flake8 + run: | + pip install flake8 + flake8 --version + flake8 ./ + + isort: + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python: [3.7] + steps: + - uses: actions/checkout@v2 + - name: Set up Python + uses: actions/setup-python@v1 + with: + python-version: ${{ matrix.python }} + architecture: x64 + - name: Run isort + run: | + pip install isort + isort --version + isort . + if [ -n "$(git status --porcelain --untracked-files=no)" ]; then exit 1; else echo "All clear"; fi + + mypy: + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python: [3.7] + steps: + - uses: actions/checkout@v2 + - name: Set up Python + uses: actions/setup-python@v1 + with: + python-version: ${{ matrix.python }} + architecture: x64 + - name: Cache python modules + uses: actions/cache@v2 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('src/app/requirements.txt') }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r src/app/requirements.txt --upgrade + pip install mypy + - name: Run mypy + run: | + mypy --version + mypy diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 0000000..070bafd --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,128 @@ +# Contributor Covenant Code of Conduct + +## Our Pledge + +We as members, contributors, and leaders pledge to make participation in our +community a harassment-free experience for everyone, regardless of age, body +size, visible or invisible disability, ethnicity, sex characteristics, gender +identity and expression, level of experience, education, socio-economic status, +nationality, personal appearance, race, religion, or sexual identity +and orientation. + +We pledge to act and interact in ways that contribute to an open, welcoming, +diverse, inclusive, and healthy community. 
+ +## Our Standards + +Examples of behavior that contributes to a positive environment for our +community include: + +* Demonstrating empathy and kindness toward other people +* Being respectful of differing opinions, viewpoints, and experiences +* Giving and gracefully accepting constructive feedback +* Accepting responsibility and apologizing to those affected by our mistakes, + and learning from the experience +* Focusing on what is best not just for us as individuals, but for the + overall community + +Examples of unacceptable behavior include: + +* The use of sexualized language or imagery, and sexual attention or + advances of any kind +* Trolling, insulting or derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or email + address, without their explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Enforcement Responsibilities + +Community leaders are responsible for clarifying and enforcing our standards of +acceptable behavior and will take appropriate and fair corrective action in +response to any behavior that they deem inappropriate, threatening, offensive, +or harmful. + +Community leaders have the right and responsibility to remove, edit, or reject +comments, commits, code, wiki edits, issues, and other contributions that are +not aligned to this Code of Conduct, and will communicate reasons for moderation +decisions when appropriate. + +## Scope + +This Code of Conduct applies within all community spaces, and also applies when +an individual is officially representing the community in public spaces. +Examples of representing our community include using an official e-mail address, +posting via an official social media account, or acting as an appointed +representative at an online or offline event. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported to the community leaders responsible for enforcement at +contact@pyronear.org. +All complaints will be reviewed and investigated promptly and fairly. + +All community leaders are obligated to respect the privacy and security of the +reporter of any incident. + +## Enforcement Guidelines + +Community leaders will follow these Community Impact Guidelines in determining +the consequences for any action they deem in violation of this Code of Conduct: + +### 1. Correction + +**Community Impact**: Use of inappropriate language or other behavior deemed +unprofessional or unwelcome in the community. + +**Consequence**: A private, written warning from community leaders, providing +clarity around the nature of the violation and an explanation of why the +behavior was inappropriate. A public apology may be requested. + +### 2. Warning + +**Community Impact**: A violation through a single incident or series +of actions. + +**Consequence**: A warning with consequences for continued behavior. No +interaction with the people involved, including unsolicited interaction with +those enforcing the Code of Conduct, for a specified period of time. This +includes avoiding interactions in community spaces as well as external channels +like social media. Violating these terms may lead to a temporary or +permanent ban. + +### 3. Temporary Ban + +**Community Impact**: A serious violation of community standards, including +sustained inappropriate behavior. 
+ +**Consequence**: A temporary ban from any sort of interaction or public +communication with the community for a specified period of time. No public or +private interaction with the people involved, including unsolicited interaction +with those enforcing the Code of Conduct, is allowed during this period. +Violating these terms may lead to a permanent ban. + +### 4. Permanent Ban + +**Community Impact**: Demonstrating a pattern of violation of community +standards, including sustained inappropriate behavior, harassment of an +individual, or aggression toward or disparagement of classes of individuals. + +**Consequence**: A permanent ban from any sort of public interaction within +the community. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], +version 2.0, available at +https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. + +Community Impact Guidelines were inspired by [Mozilla's code of conduct +enforcement ladder](https://github.com/mozilla/diversity). + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see the FAQ at +https://www.contributor-covenant.org/faq. Translations are available at +https://www.contributor-covenant.org/translations. diff --git a/README.md b/README.md index 7ccbd19..1318d2e 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,61 @@ -# pyro-storage -Management of data & annotations for wildfire detection +# Data curation API + +[![License: Apache 2.0](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](LICENSE) ![Build Status](https://github.com/pyronear/pyro-storage/workflows/api/badge.svg) [![codecov](https://codecov.io/gh/pyronear/pyro-storage/branch/main/graph/badge.svg)](https://codecov.io/gh/pyronear/pyro-storage) + +The building blocks of our data curation API. + + + +## Getting started + +### Prerequisites + +- Python 3.7 (or more recent) +- [pip](https://pip.pypa.io/en/stable/) + +### Installation + +You can clone and install the project dependencies as follows: + +```shell +git clone https://github.com/pyronear/pyro-storage.git +``` + +## Usage + +If you wish to deploy this project on a server hosted remotely, you might want to be using [Docker](https://www.docker.com/) containers. Beforehand, you will need to set a few environment variables either manually or by writing an `.env` file in the root directory of this project, like in the example below: + +``` +QARNOT_TOKEN=my_very_secret_token +BUCKET_NAME=my_storage_bucket_name +BUCKET_MEDIA_FOLDER=my/media/subfolder +BUCKET_ANNOTATIONS_FOLDER=my/annotations/subfolder + +``` + +Those values will allow your API server to connect to our cloud service provider [Qarnot Computing](https://qarnot.com/), which is mandatory for your local server to be fully operational. +Then you can run the API containers using this command: + +```shell +docker-compose up -d --build +``` + +Once completed, you will notice that you have a docker container running on the port you selected, which can process requests just like any django server. + + + +## Documentation + +The full project documentation is available [here](http://pyro-storage.herokuapp.com/redoc) for detailed specifications. The documentation was built with [ReDoc](https://redocly.github.io/redoc/). + + + +## Contributing + +Please refer to `CONTRIBUTING` if you wish to contribute to this project. + + + +## License + +Distributed under the Apache 2.0 License. See `LICENSE` for more information. 
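As a quick check that the Usage steps above worked, the running stack should answer on the port exposed in `docker-compose.yml` (the default `8080` is assumed here, and this is only an illustrative command, mirroring the ping the CI workflow performs):

```shell
curl http://localhost:8080/docs
```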
\ No newline at end of file diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..ff5cc03 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,50 @@ +version: '3.7' + +services: + pyrostorage: + build: src + command: uvicorn app.main:app --reload --workers 1 --host 0.0.0.0 --port 8080 + volumes: + - ./src/:/app/ + ports: + - 8080:8080 + environment: + - DATABASE_URL=postgresql://dummy_pg_user:dummy_pg_pwd@db/dummy_pg_db + - TEST_DATABASE_URL=postgresql://dummy_pg_tuser:dummy_pg_tpwd@db_test/dummy_pg_tdb + - SUPERUSER_LOGIN=dummy_login + - SUPERUSER_PWD=dummy_pwd + - QARNOT_TOKEN=${QARNOT_TOKEN} + - BUCKET_NAME=${BUCKET_NAME} + - BUCKET_MEDIA_FOLDER=${BUCKET_MEDIA_FOLDER} + - BUCKET_ANNOT_FOLDER=${BUCKET_ANNOT_FOLDER} + depends_on: + - db + db: + image: postgres:12.1-alpine + volumes: + - postgres_data:/var/lib/postgresql/data/ + ports: + - 5432:5432 + environment: + - POSTGRES_USER=dummy_pg_user + - POSTGRES_PASSWORD=dummy_pg_pwd + - POSTGRES_DB=dummy_pg_db + nginx: + build: nginx + ports: + - 80:80 + - 443:443 + depends_on: + - pyrostorage + db_test: + image: postgres:12.1-alpine + volumes: + - postgres_data_test:/var/lib/postgresql/data_test/ + environment: + - POSTGRES_USER=dummy_pg_tuser + - POSTGRES_PASSWORD=dummy_pg_tpwd + - POSTGRES_DB=dummy_pg_tdb + +volumes: + postgres_data: + postgres_data_test: diff --git a/nginx/Dockerfile b/nginx/Dockerfile new file mode 100644 index 0000000..17b5f68 --- /dev/null +++ b/nginx/Dockerfile @@ -0,0 +1,3 @@ +FROM nginx:latest + +COPY nginx.conf /etc/nginx/nginx.conf diff --git a/nginx/nginx.conf b/nginx/nginx.conf new file mode 100644 index 0000000..95705e9 --- /dev/null +++ b/nginx/nginx.conf @@ -0,0 +1,55 @@ + +worker_processes 1; + +events { + worker_connections 1024; # increase if you have lots of clients + accept_mutex off; # set to 'on' if nginx worker_processes > 1 +} + +http { + include mime.types; + # fallback in case we can't determine a type + default_type application/octet-stream; + access_log /var/log/nginx/access.log combined; + sendfile on; + + upstream app_server { + # fail_timeout=0 means we always retry an upstream even if it failed + # to return a good HTTP response + + # for a TCP configuration + server pyroapi:8080 fail_timeout=0; + } + + server { + # if no Host match, close the connection to prevent host spoofing + listen 80 default_server; + return 444; + } + + server { + # use 'listen 80 deferred;' for Linux + # use 'listen 80 accept_filter=httpready;' for FreeBSD + client_max_body_size 4G; + + # set the correct host(s) for your site + server_name storage.pyronear.org; + + keepalive_timeout 5; + + location / { + # checks for static file, if not found proxy to app + try_files $uri @proxy_to_app; + } + + location @proxy_to_app { + proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; + proxy_set_header X-Forwarded-Proto $scheme; + proxy_set_header Host $http_host; + # we don't want nginx trying to do something clever with + # redirects, we set the Host: header above already. 
+ proxy_redirect off; + proxy_pass http://app_server; + } + } +} \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..718f98e --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,38 @@ +[tool.mypy] +mypy_path = "src/" +files = "src/app" +show_error_codes = true +pretty = true +warn_unused_ignores = true +warn_redundant_casts = true +no_implicit_optional = true +check_untyped_defs = true +implicit_reexport = false + +[[tool.mypy.overrides]] +module = [ + "sqlalchemy.*", + "qarnot.*", + "jose.*", + "passlib.*", + "app.*", + "requests.*", +] +ignore_missing_imports = true + +[tool.isort] +line_length = 120 +src_paths = ["src/", "client/"] +skip_glob = "**/__init__.py" +known_third_party = ["fastapi"] + +[tool.pydocstyle] +select = "D300,D301,D417" +match = ".*\\.py" + +[tool.coverage.run] +source = ["src/app"] + +[tool.black] +line-length = 120 +target-version = ['py38'] diff --git a/src/.coveragerc b/src/.coveragerc new file mode 100644 index 0000000..cfa196e --- /dev/null +++ b/src/.coveragerc @@ -0,0 +1,2 @@ +[run] +source = app \ No newline at end of file diff --git a/src/Dockerfile b/src/Dockerfile new file mode 100644 index 0000000..2c37309 --- /dev/null +++ b/src/Dockerfile @@ -0,0 +1,20 @@ +FROM tiangolo/uvicorn-gunicorn-fastapi:python3.8-alpine3.10 + +WORKDIR /app + +# set environment variables +ENV PYTHONDONTWRITEBYTECODE 1 +ENV PYTHONUNBUFFERED 1 +ENV PYTHONPATH "${PYTHONPATH}:/app" + +# copy requirements file +COPY app/requirements.txt /app/requirements.txt + +# install dependencies +RUN set -eux \ + && apk add --no-cache --virtual .build-deps build-base postgresql-dev gcc libffi-dev libressl-dev musl-dev \ + && pip install -r /app/requirements.txt \ + && rm -rf /root/.cache/pip + +# copy project +COPY . /app diff --git a/src/app/api/crud/__init__.py b/src/app/api/crud/__init__.py new file mode 100644 index 0000000..1e1c5d2 --- /dev/null +++ b/src/app/api/crud/__init__.py @@ -0,0 +1,3 @@ +from .base import * +from . import accesses +from . import authorizations diff --git a/src/app/api/crud/accesses.py b/src/app/api/crud/accesses.py new file mode 100644 index 0000000..e106efc --- /dev/null +++ b/src/app/api/crud/accesses.py @@ -0,0 +1,46 @@ +# Copyright (C) 2022, Pyronear. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. 
+ +from fastapi import HTTPException, status +from sqlalchemy import Table + +from app.api import security +from app.api.crud import base +from app.api.schemas import AccessCreation, AccessRead, Cred, CredHash, Login + + +async def check_login_existence(table: Table, login: str): + """Check that the login does not already exist, raises a 400 exception if do so.""" + if await base.fetch_one(table, {"login": login}) is not None: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail=f"An entry with login='{login}' already exists.", + ) + + +async def update_login(accesses: Table, login: str, access_id: int): + """Update access login assuming access_id exists and new login does not exist.""" + return await base.update_entry(accesses, Login(login=login), access_id) + + +async def post_access(accesses: Table, login: str, password: str, scope: str) -> AccessRead: + """Insert an access entry in the accesses table, call within a transaction to reuse returned access id.""" + await check_login_existence(accesses, login) + + # Hash the password + pwd = await security.hash_password(password) + + access = AccessCreation(login=login, hashed_password=pwd, scope=scope) + entry = await base.create_entry(accesses, access) + + return AccessRead(**entry) + + +async def update_access_pwd(accesses: Table, payload: Cred, access_id: int) -> None: + """Update the access password using provided access_id.""" + # Update the access entry with the hashed password + updated_payload = CredHash(hashed_password=await security.hash_password(payload.password)) + + await base.update_entry(accesses, updated_payload, access_id) # update & check if access_id exists diff --git a/src/app/api/crud/authorizations.py b/src/app/api/crud/authorizations.py new file mode 100644 index 0000000..5987df7 --- /dev/null +++ b/src/app/api/crud/authorizations.py @@ -0,0 +1,21 @@ +# Copyright (C) 2022, Pyronear. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +from fastapi import HTTPException, status + +from app.api import crud +from app.api.schemas import AccessType +from app.db import accesses + + +async def is_admin_access(access_id: int) -> bool: + access = await crud.base.get_entry(accesses, access_id) + return access["scope"] == AccessType.admin + + +async def check_access_read(access_id: int) -> bool: + if not (await is_admin_access(access_id)): + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="This access can't read resources") + return True diff --git a/src/app/api/crud/base.py b/src/app/api/crud/base.py new file mode 100644 index 0000000..65d1c99 --- /dev/null +++ b/src/app/api/crud/base.py @@ -0,0 +1,114 @@ +# Copyright (C) 2022, Pyronear. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. 
+ +from typing import Any, Dict, List, Mapping, Optional + +from fastapi import HTTPException, Path, status +from pydantic import BaseModel +from sqlalchemy import Table + +from app.db import database + +__all__ = [ + "post", + "get", + "fetch_all", + "fetch_one", + "put", + "delete", + "create_entry", + "get_entry", + "update_entry", + "delete_entry", +] + + +async def post(payload: BaseModel, table: Table) -> int: + query = table.insert().values(**payload.dict()) + return await database.execute(query=query) + + +async def get(entry_id: int, table: Table) -> Mapping[str, Any]: + query = table.select().where(entry_id == table.c.id) + return await database.fetch_one(query=query) + + +async def fetch_all( + table: Table, + query_filters: Optional[Dict[str, Any]] = None, + exclusions: Optional[Dict[str, Any]] = None, + limit: int = 50, +) -> List[Mapping[str, Any]]: + query = table.select().order_by(table.c.id.desc()) + if isinstance(query_filters, dict): + for key, value in query_filters.items(): + query = query.where(getattr(table.c, key) == value) + + if isinstance(exclusions, dict): + for key, value in exclusions.items(): + query = query.where(getattr(table.c, key) != value) + return (await database.fetch_all(query=query.limit(limit)))[::-1] + + +async def fetch_one(table: Table, query_filters: Dict[str, Any]) -> Mapping[str, Any]: + query = table.select() + for query_filter_key, query_filter_value in query_filters.items(): + query = query.where(getattr(table.c, query_filter_key) == query_filter_value) + return await database.fetch_one(query=query) + + +async def put(entry_id: int, payload: Dict, table: Table) -> int: + query = table.update().where(entry_id == table.c.id).values(**payload).returning(table.c.id) + return await database.execute(query=query) + + +async def delete(entry_id: int, table: Table) -> None: + query = table.delete().where(entry_id == table.c.id) + await database.execute(query=query) + + +async def create_entry(table: Table, payload: BaseModel) -> Dict[str, Any]: + entry_id = await post(payload, table) + return {**payload.dict(), "id": entry_id} + + +async def get_entry(table: Table, entry_id: int = Path(..., gt=0)) -> Dict[str, Any]: + entry = await get(entry_id, table) + if entry is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=f"Table {table.name} has no entry with id={entry_id}" + ) + + return dict(entry) + + +async def update_entry( + table: Table, payload: BaseModel, entry_id: int = Path(..., gt=0), only_specified: bool = True +) -> Dict[str, Any]: + payload_dict = payload.dict() + + if only_specified: + # Dont update columns for null fields + payload_dict = {k: v for k, v in payload_dict.items() if v is not None} + + _id = await put(entry_id, payload_dict, table) + + if not isinstance(_id, int): + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=f"Table {table.name} has no entry with id={entry_id}" + ) + + if only_specified: + # Retrieve complete record values + return dict(await get(entry_id, table)) + else: + return {**payload.dict(), "id": entry_id} + + +async def delete_entry(table: Table, entry_id: int = Path(..., gt=0)) -> Dict[str, Any]: + entry = await get_entry(table, entry_id) + await delete(entry_id, table) + + return entry diff --git a/src/app/api/deps.py b/src/app/api/deps.py new file mode 100644 index 0000000..b2009e3 --- /dev/null +++ b/src/app/api/deps.py @@ -0,0 +1,68 @@ +# Copyright (C) 2022, Pyronear. + +# This program is licensed under the Apache License version 2. 
+# See LICENSE or go to for full license details. + +from fastapi import Depends, HTTPException, status +from fastapi.security import OAuth2PasswordBearer, SecurityScopes +from jose import JWTError, jwt +from pydantic import ValidationError + +import app.config as cfg +from app.api import crud +from app.api.schemas import AccessRead, AccessType, TokenPayload +from app.db import accesses + +# Scope definition +oauth2_scheme = OAuth2PasswordBearer( + tokenUrl="login/access-token", + scopes={ + AccessType.user: "Read information about the current user.", + AccessType.admin: "Admin rights on all routes.", + }, +) + + +async def get_current_access(security_scopes: SecurityScopes, token: str = Depends(oauth2_scheme)) -> AccessRead: + """Dependency to use as fastapi.security.Security with scopes. + + >>> @app.get("/users/me") + >>> async def read_users_me(current_user: User = Security(get_current_access, scopes=["me"])): + >>> return current_user + """ + + if security_scopes.scopes: + authenticate_value = f'Bearer scope="{security_scopes.scope_str}"' + else: + authenticate_value = "Bearer" + + try: + payload = jwt.decode(token, cfg.SECRET_KEY, algorithms=[cfg.JWT_ENCODING_ALGORITHM]) + except JWTError: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Token has expired.", + headers={"WWW-Authenticate": authenticate_value}, + ) + + try: + access_id = int(payload["sub"]) + token_scopes = payload.get("scopes", []) + token_data = TokenPayload(access_id=access_id, scopes=token_scopes) + except (KeyError, ValidationError): + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail="Invalid token payload.", + headers={"WWW-Authenticate": authenticate_value}, + ) + + entry = await crud.get_entry(table=accesses, entry_id=int(access_id)) + + if set(token_data.scopes).isdisjoint(security_scopes.scopes): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Your access scope is not compatible with this operation.", + headers={"WWW-Authenticate": authenticate_value}, + ) + + return AccessRead(**entry) diff --git a/src/app/api/external.py b/src/app/api/external.py new file mode 100644 index 0000000..534fac0 --- /dev/null +++ b/src/app/api/external.py @@ -0,0 +1,22 @@ +# Copyright (C) 2022, Pyronear. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +from typing import Optional + +import requests +from pydantic import BaseModel + + +def post_request(url: str, payload: Optional[BaseModel] = None) -> requests.Response: + """Performs a POST request to a given URL + + Args: + url: URL to send the POST request to + payload: payload to be sent + Returns: + HTTP response + """ + kwargs = {} if payload is None else {"json": payload} + return requests.post(url, headers={"Content-Type": "application/json"}, **kwargs) diff --git a/src/app/api/routes/accesses.py b/src/app/api/routes/accesses.py new file mode 100644 index 0000000..83a451d --- /dev/null +++ b/src/app/api/routes/accesses.py @@ -0,0 +1,63 @@ +# Copyright (C) 2022, Pyronear. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. 
+ +from typing import List + +from fastapi import APIRouter, Path, Security, status + +from app.api import crud +from app.api.deps import get_current_access +from app.api.schemas import AccessAuth, AccessRead, AccessType, Cred +from app.db import accesses + +router = APIRouter() + + +@router.post("/", response_model=AccessRead, status_code=status.HTTP_201_CREATED, summary="Create an access") +async def create_access(payload: AccessAuth, _=Security(get_current_access, scopes=[AccessType.admin])): + """ + Creates an annotation related to specific media, based on media_id as argument + + Below, click on "Schema" for more detailed information about arguments + or "Example Value" to get a concrete idea of arguments + """ + return await crud.accesses.post_access(accesses, **payload.dict()) + + +@router.get("/{access_id}/", response_model=AccessRead, summary="Get information about a specific access") +async def get_access(access_id: int = Path(..., gt=0), _=Security(get_current_access, scopes=[AccessType.admin])): + """ + Based on a access_id, retrieves information about the specified access + """ + entry = await crud.get_entry(accesses, access_id) + return AccessRead(**entry) + + +@router.get("/", response_model=List[AccessRead], summary="Get the list of all accesses") +async def fetch_accesses(_=Security(get_current_access, scopes=[AccessType.admin])): + """ + Retrieves the list of all accesses and their information + """ + entries = await crud.fetch_all(accesses) + return [AccessRead(**entry) for entry in entries] + + +@router.put("/{access_id}/", response_model=None, summary="Update information about a specific access") +async def update_access_pwd( + payload: Cred, access_id: int = Path(..., gt=0), _=Security(get_current_access, scopes=[AccessType.admin]) +): + """ + Based on a access_id, updates information about the specified access + """ + await crud.accesses.update_access_pwd(accesses, payload, access_id) + + +@router.delete("/{access_id}/", response_model=AccessRead, summary="Delete a specific access") +async def delete_access(access_id: int = Path(..., gt=0), _=Security(get_current_access, scopes=[AccessType.admin])): + """ + Based on a access_id, deletes the specified access + """ + entry = await crud.delete_entry(accesses, access_id) + return AccessRead(**entry) diff --git a/src/app/api/routes/annotations.py b/src/app/api/routes/annotations.py new file mode 100644 index 0000000..ce5c3bd --- /dev/null +++ b/src/app/api/routes/annotations.py @@ -0,0 +1,156 @@ +# Copyright (C) 2022, Pyronear. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. 
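Once a token with the `admin` scope is available, the access-management routes above can be exercised directly over HTTP. A hypothetical example (host, `$TOKEN` and the field values are placeholders taken from the schema examples, not real credentials):

```shell
curl -X POST http://localhost:8080/accesses/ \
  -H "Authorization: Bearer $TOKEN" \
  -H "Content-Type: application/json" \
  -d '{"login": "JohnDoe", "password": "PickARobustOne", "scope": "user"}'
```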
+ +from typing import Any, Dict, List + +from fastapi import APIRouter, BackgroundTasks, File, HTTPException, Path, Security, UploadFile, status + +from app.api import crud +from app.api.crud.authorizations import check_access_read, is_admin_access +from app.api.deps import get_current_access +from app.api.schemas import AccessType, AnnotationCreation, AnnotationIn, AnnotationOut, AnnotationUrl +from app.api.security import hash_content_file +from app.db import annotations +from app.services import annotations_bucket, resolve_bucket_key + +router = APIRouter() + + +async def check_annotation_registration(annotation_id: int) -> Dict[str, Any]: + """Checks whether the media is registered in the DB""" + return await crud.get_entry(annotations, annotation_id) + + +@router.post( + "/", + response_model=AnnotationOut, + status_code=status.HTTP_201_CREATED, + summary="Create an annotation related to a specific media", +) +async def create_annotation( + payload: AnnotationIn, _=Security(get_current_access, scopes=[AccessType.admin, AccessType.user]) +): + """ + Creates an annotation related to specific media, based on media_id as argument + + Below, click on "Schema" for more detailed information about arguments + or "Example Value" to get a concrete idea of arguments + """ + return await crud.create_entry(annotations, payload) + + +@router.get("/{annotation_id}/", response_model=AnnotationOut, summary="Get information about a specific annotation") +async def get_annotation( + annotation_id: int = Path(..., gt=0), + requester=Security(get_current_access, scopes=[AccessType.admin, AccessType.user]), +): + """ + Based on a annotation_id, retrieves information about the specified media + """ + await check_access_read(requester.id) + return await crud.get_entry(annotations, annotation_id) + + +@router.get("/", response_model=List[AnnotationOut], summary="Get the list of all annotations") +async def fetch_annotations( + requester=Security(get_current_access, scopes=[AccessType.admin, AccessType.user]), +): + """ + Retrieves the list of all annotations and their information + """ + if await is_admin_access(requester.id): + return await crud.fetch_all(annotations) + else: + return [] + + +@router.put("/{annotation_id}/", response_model=AnnotationOut, summary="Update information about a specific annotation") +async def update_annotation( + payload: AnnotationIn, + annotation_id: int = Path(..., gt=0), + _=Security(get_current_access, scopes=[AccessType.admin]), +): + """ + Based on a annotation_id, updates information about the specified annotation + """ + return await crud.update_entry(annotations, payload, annotation_id) + + +@router.delete("/{annotation_id}/", response_model=AnnotationOut, summary="Delete a specific annotation") +async def delete_annotation( + annotation_id: int = Path(..., gt=0), _=Security(get_current_access, scopes=[AccessType.admin]) +): + """ + Based on a annotation_id, deletes the specified annotation + """ + return await crud.delete_entry(annotations, annotation_id) + + +@router.post("/{annotation_id}/upload", response_model=AnnotationOut, status_code=200) +async def upload_annotation( + background_tasks: BackgroundTasks, + annotation_id: int = Path(..., gt=0), + file: UploadFile = File(...), +): + """ + Upload a annotation (image or video) linked to an existing annotation object in the DB + """ + + # Check in DB + entry = await check_annotation_registration(annotation_id) + + # Concatenate the first 32 chars (to avoid system interactions issues) of SHA256 hash with file 
extension + file_hash = hash_content_file(file.file.read()) + file_name = f"{file_hash[:32]}.{file.filename.rpartition('.')[-1]}" + # Reset byte position of the file (cf. https://fastapi.tiangolo.com/tutorial/request-files/#uploadfile) + await file.seek(0) + # If files are in a subfolder of the bucket, prepend the folder path + bucket_key = resolve_bucket_key(file_name, annotations_bucket.folder) + + # Upload if bucket_key is different (otherwise the content is the exact same) + if isinstance(entry["bucket_key"], str) and entry["bucket_key"] == bucket_key: + return await crud.get_entry(annotations, annotation_id) + else: + # Failed upload + if not await annotations_bucket.upload_file(bucket_key=bucket_key, file_binary=file.file): + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed upload") + # Data integrity check + uploaded_file = await annotations_bucket.get_file(bucket_key=bucket_key) + # Failed download + if uploaded_file is None: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="The data integrity check failed (unable to download media from bucket)", + ) + # Remove temp local file + background_tasks.add_task(annotations_bucket.flush_tmp_file, uploaded_file) + # Check the hash + with open(uploaded_file, "rb") as f: + upload_hash = hash_content_file(f.read()) + if upload_hash != file_hash: + # Delete corrupted file + await annotations_bucket.delete_file(bucket_key) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Data was corrupted during upload" + ) + + entry_dict = dict(**entry) + entry_dict["bucket_key"] = bucket_key + return await crud.update_entry(annotations, AnnotationCreation(**entry_dict), annotation_id) + + +@router.get("/{annotation_id}/url", response_model=AnnotationUrl, status_code=200) +async def get_annotation_url( + annotation_id: int = Path(..., gt=0), + requester=Security(get_current_access, scopes=[AccessType.admin, AccessType.user]), +): + """Resolve the temporary media image URL""" + await check_access_read(requester.id) + + # Check in DB + annotation_instance = await check_annotation_registration(annotation_id) + # Check in bucket + temp_public_url = await annotations_bucket.get_public_url(annotation_instance["bucket_key"]) + return AnnotationUrl(url=temp_public_url) diff --git a/src/app/api/routes/login.py b/src/app/api/routes/login.py new file mode 100644 index 0000000..8c4188c --- /dev/null +++ b/src/app/api/routes/login.py @@ -0,0 +1,39 @@ +# Copyright (C) 2022, Pyronear. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. 
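The login route below implements the OAuth 2.0 password flow: a client first exchanges its credentials for a bearer token, then passes that token on the protected routes checked by `get_current_access`. A hypothetical client-side sketch (host, credentials and `media_id` are the dummy values from `docker-compose.yml`, assumed for illustration only):

```python
# Hypothetical client sketch, not part of the API code.
import requests

api = "http://localhost:8080"
# Exchange credentials for a bearer token (OAuth2 password flow expects form fields)
token = requests.post(
    f"{api}/login/access-token",
    data={"username": "dummy_login", "password": "dummy_pwd"},
).json()["access_token"]
headers = {"Authorization": f"Bearer {token}"}

# Register an annotation entry for an existing media entry (media_id=1 assumed), then attach the file
annotation = requests.post(f"{api}/annotations/", json={"media_id": 1}, headers=headers).json()
with open("annotation.json", "rb") as f:
    requests.post(f"{api}/annotations/{annotation['id']}/upload", files={"file": f}, headers=headers)
```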
+ +from datetime import timedelta + +from fastapi import APIRouter, Depends, HTTPException, status +from fastapi.security import OAuth2PasswordRequestForm + +from app import config as cfg +from app.api import crud, security +from app.api.schemas import Token +from app.db import accesses + +router = APIRouter() + + +@router.post("/access-token", response_model=Token) +async def create_access_token(form_data: OAuth2PasswordRequestForm = Depends()): + """ + This API follows the OAuth 2.0 specification + + If the credentials are valid, creates a new access token + + By default, the token expires after 1 hour + """ + + # Verify credentials + entry = await crud.fetch_one(accesses, {"login": form_data.username}) + if entry is None or not await security.verify_password(form_data.password, entry["hashed_password"]): + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials.") + # create access token using user user_id/user_scopes + token_data = {"sub": str(entry["id"]), "scopes": entry["scope"].split()} + token = await security.create_access_token( + token_data, expires_delta=timedelta(minutes=cfg.ACCESS_TOKEN_EXPIRE_MINUTES) + ) + + return {"access_token": token, "token_type": "bearer"} diff --git a/src/app/api/routes/media.py b/src/app/api/routes/media.py new file mode 100644 index 0000000..004978d --- /dev/null +++ b/src/app/api/routes/media.py @@ -0,0 +1,148 @@ +# Copyright (C) 2022, Pyronear. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +from typing import Any, Dict, List + +from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, Path, Security, UploadFile, status + +from app.api import crud +from app.api.crud.authorizations import check_access_read, is_admin_access +from app.api.deps import get_current_access +from app.api.schemas import AccessType, MediaCreation, MediaIn, MediaOut, MediaUrl +from app.api.security import hash_content_file +from app.db import get_session, media +from app.services import media_bucket, resolve_bucket_key + +router = APIRouter() + + +async def check_media_registration(media_id: int) -> Dict[str, Any]: + """Checks whether the media is registered in the DB""" + return await crud.get_entry(media, media_id) + + +@router.post( + "/", + response_model=MediaOut, + status_code=status.HTTP_201_CREATED, + summary="Create a media related to a specific device", +) +async def create_media(payload: MediaIn, _=Security(get_current_access, scopes=[AccessType.admin])): + """ + Creates a media related to specific device, based on device_id as argument + + Below, click on "Schema" for more detailed information about arguments + or "Example Value" to get a concrete idea of arguments + """ + return await crud.create_entry(media, payload) + + +@router.get("/{media_id}/", response_model=MediaOut, summary="Get information about a specific media") +async def get_media( + media_id: int = Path(..., gt=0), requester=Security(get_current_access, scopes=[AccessType.admin, AccessType.user]) +): + """ + Based on a media_id, retrieves information about the specified media + """ + await check_access_read(requester.id) + + return await crud.get_entry(media, media_id) + + +@router.get("/", response_model=List[MediaOut], summary="Get the list of all media") +async def fetch_media( + requester=Security(get_current_access, scopes=[AccessType.admin, AccessType.user]), session=Depends(get_session) +): + """ + Retrieves the list of all media and their information + """ + if 
await is_admin_access(requester.id): + return await crud.fetch_all(media) + return [] + + +@router.put("/{media_id}/", response_model=MediaOut, summary="Update information about a specific media") +async def update_media( + payload: MediaIn, media_id: int = Path(..., gt=0), _=Security(get_current_access, scopes=[AccessType.admin]) +): + """ + Based on a media_id, updates information about the specified media + """ + return await crud.update_entry(media, payload, media_id) + + +@router.delete("/{media_id}/", response_model=MediaOut, summary="Delete a specific media") +async def delete_media(media_id: int = Path(..., gt=0), _=Security(get_current_access, scopes=[AccessType.admin])): + """ + Based on a media_id, deletes the specified media + """ + return await crud.delete_entry(media, media_id) + + +@router.post("/{media_id}/upload", response_model=MediaOut, status_code=200) +async def upload_media( + background_tasks: BackgroundTasks, + media_id: int = Path(..., gt=0), + file: UploadFile = File(...), +): + """ + Upload a media (image or video) linked to an existing media object in the DB + """ + + # Check in DB + entry = await check_media_registration(media_id) + + # Concatenate the first 32 chars (to avoid system interactions issues) of SHA256 hash with file extension + file_hash = hash_content_file(file.file.read()) + file_name = f"{file_hash[:32]}.{file.filename.rpartition('.')[-1]}" + # Reset byte position of the file (cf. https://fastapi.tiangolo.com/tutorial/request-files/#uploadfile) + await file.seek(0) + # If files are in a subfolder of the bucket, prepend the folder path + bucket_key = resolve_bucket_key(file_name, media_bucket.folder) + + # Upload if bucket_key is different (otherwise the content is the exact same) + if isinstance(entry["bucket_key"], str) and entry["bucket_key"] == bucket_key: + return await crud.get_entry(media, media_id) + else: + # Failed upload + if not await media_bucket.upload_file(bucket_key=bucket_key, file_binary=file.file): + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed upload") + # Data integrity check + uploaded_file = await media_bucket.get_file(bucket_key=bucket_key) + # Failed download + if uploaded_file is None: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="The data integrity check failed (unable to download media from bucket)", + ) + # Remove temp local file + background_tasks.add_task(media_bucket.flush_tmp_file, uploaded_file) + # Check the hash + with open(uploaded_file, "rb") as f: + upload_hash = hash_content_file(f.read()) + if upload_hash != file_hash: + # Delete corrupted file + await media_bucket.delete_file(bucket_key) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Data was corrupted during upload" + ) + + entry_dict = dict(**entry) + entry_dict["bucket_key"] = bucket_key + return await crud.update_entry(media, MediaCreation(**entry_dict), media_id) + + +@router.get("/{media_id}/url", response_model=MediaUrl, status_code=200) +async def get_media_url( + media_id: int = Path(..., gt=0), requester=Security(get_current_access, scopes=[AccessType.admin, AccessType.user]) +): + """Resolve the temporary media image URL""" + await check_access_read(requester.id) + + # Check in DB + media_instance = await check_media_registration(media_id) + # Check in bucket + temp_public_url = await media_bucket.get_public_url(media_instance["bucket_key"]) + return MediaUrl(url=temp_public_url) diff --git a/src/app/api/schemas.py 
b/src/app/api/schemas.py new file mode 100644 index 0000000..f5e5fe9 --- /dev/null +++ b/src/app/api/schemas.py @@ -0,0 +1,103 @@ +# Copyright (C) 2022, Pyronear. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +from datetime import datetime +from typing import List, Optional + +from pydantic import BaseModel, Field, validator + +from app.db.models import AccessType, MediaType + + +# Template classes +class _CreatedAt(BaseModel): + created_at: Optional[datetime] = None + + @staticmethod + @validator("created_at", pre=True, always=True) + def default_ts_created(v): + return v or datetime.utcnow() + + +class _Id(BaseModel): + id: int = Field(..., gt=0) + + +# Accesses +class Login(BaseModel): + login: str = Field(..., min_length=3, max_length=50, example="JohnDoe") + + +class Cred(BaseModel): + password: str = Field(..., min_length=3, example="PickARobustOne") + + +class CredHash(BaseModel): + hashed_password: str + + +class AccessBase(Login): + scope: AccessType = AccessType.user + + +class AccessAuth(AccessBase, Cred): + pass + + +class AccessCreation(AccessBase, CredHash): + pass + + +class AccessRead(AccessBase, _Id): + pass + + +# Token +class Token(BaseModel): + access_token: str = Field(..., example="eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.423fgFGTfttrvU6D1k7vF92hH5vaJHCGFYd8E") + token_type: str = Field(..., example="bearer") + + +class TokenPayload(BaseModel): + user_id: Optional[str] = None # token sub + scopes: List[AccessType] = [] + + +# Media +class BaseMedia(BaseModel): + type: MediaType = MediaType.image + + +class MediaIn(BaseMedia): + pass + + +class MediaCreation(MediaIn): + bucket_key: str = Field(...) + + +class MediaOut(MediaIn, _CreatedAt, _Id): + pass + + +class MediaUrl(BaseModel): + url: str + + +# Annotation +class AnnotationIn(BaseModel): + media_id: int = Field(..., gt=0) + + +class AnnotationCreation(AnnotationIn): + bucket_key: str = Field(...) + + +class AnnotationOut(AnnotationIn, _CreatedAt, _Id): + pass + + +class AnnotationUrl(BaseModel): + url: str diff --git a/src/app/api/security.py b/src/app/api/security.py new file mode 100644 index 0000000..79280bb --- /dev/null +++ b/src/app/api/security.py @@ -0,0 +1,40 @@ +# Copyright (C) 2022, Pyronear. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. 
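To make the link between `get_current_access` and the helpers below concrete, here is a small illustrative round trip: encode a token the way the login route does, then decode it the way the dependency does. The access id and scope are made-up values; `SECRET_KEY` and the algorithm come from `config.py`.

```python
# Illustrative only: token encode/decode round trip with hypothetical values.
import asyncio
from datetime import timedelta

from jose import jwt

from app import config as cfg
from app.api.security import create_access_token

token = asyncio.run(create_access_token({"sub": "1", "scopes": ["admin"]}, timedelta(minutes=60)))
payload = jwt.decode(token, cfg.SECRET_KEY, algorithms=[cfg.JWT_ENCODING_ALGORITHM])
assert payload["sub"] == "1" and payload["scopes"] == ["admin"]
```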
+ +import hashlib +from datetime import datetime, timedelta +from typing import Any, Dict, Optional + +from jose import jwt +from passlib.context import CryptContext + +from app import config as cfg + +pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") + + +async def create_unlimited_access_token(content: Dict[str, Any]) -> str: + # Used for devices + return await create_access_token(content, timedelta(minutes=cfg.ACCESS_TOKEN_UNLIMITED_MINUTES)) + + +async def create_access_token(content: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str: + """Encode content dict using security algorithm, setting expiration.""" + if expires_delta is None: + expires_delta = timedelta(minutes=cfg.ACCESS_TOKEN_EXPIRE_MINUTES) + expire = datetime.utcnow() + expires_delta + return jwt.encode({**content, "exp": expire}, cfg.SECRET_KEY, algorithm=cfg.JWT_ENCODING_ALGORITHM) + + +async def verify_password(plain_password: str, hashed_password: str) -> bool: + return pwd_context.verify(plain_password, hashed_password) + + +async def hash_password(password: str) -> str: + return pwd_context.hash(password) + + +def hash_content_file(content: bytes) -> str: + return hashlib.sha256(content).hexdigest() diff --git a/src/app/config.py b/src/app/config.py new file mode 100644 index 0000000..6159641 --- /dev/null +++ b/src/app/config.py @@ -0,0 +1,53 @@ +# Copyright (C) 2022, Pyronear. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +import os +import secrets +from typing import Optional + +PROJECT_NAME: str = "Pyronear - Storage API" +PROJECT_DESCRIPTION: str = "API for wildfire data curation" +API_BASE: str = "storage/" +VERSION: str = "0.1.0.dev0" +DEBUG: bool = os.environ.get("DEBUG", "") != "False" +DATABASE_URL: str = os.getenv("DATABASE_URL", "") +# Fix for SqlAlchemy 1.4+ +if DATABASE_URL.startswith("postgres://"): + DATABASE_URL = DATABASE_URL.replace("postgres://", "postgresql://", 1) + +TEST_DATABASE_URL: str = os.getenv("TEST_DATABASE_URL", "") +LOGO_URL: str = "https://pyronear.org/img/logo_letters.png" + +SECRET_KEY: str = secrets.token_urlsafe(32) +if DEBUG: + # To keep the same Auth at every app loading in debug mode and not having to redo the auth. + debug_secret_key = "000000000000000000000000000000000000" + SECRET_KEY = debug_secret_key + +ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 +ACCESS_TOKEN_UNLIMITED_MINUTES = 60 * 24 * 365 * 10 +JWT_ENCODING_ALGORITHM = "HS256" + +SUPERUSER_LOGIN: str = os.getenv("SUPERUSER_LOGIN", "") +SUPERUSER_PWD: str = os.getenv("SUPERUSER_PWD", "") + +if SUPERUSER_LOGIN is None or SUPERUSER_PWD is None: + raise ValueError( + "Missing Credentials. 
Please set 'SUPERUSER_LOGIN' and 'SUPERUSER_PWD' in your environment variables" + ) + +QARNOT_TOKEN: str = os.getenv("QARNOT_TOKEN", "") +BUCKET_NAME: str = os.getenv("BUCKET_NAME", "") +BUCKET_MEDIA_FOLDER: Optional[str] = os.getenv("BUCKET_MEDIA_FOLDER") +BUCKET_ANNOT_FOLDER: Optional[str] = os.getenv("BUCKET_ANNOT_FOLDER") +DUMMY_BUCKET_FILE = ( + "https://ec.europa.eu/jrc/sites/jrcsh/files/styles/normal-responsive/" + + "public/growing-risk-future-wildfires_adobestock_199370851.jpeg" +) + + +# Sentry +SENTRY_DSN: Optional[str] = os.getenv("SENTRY_DSN") +SERVER_NAME: Optional[str] = os.getenv("SERVER_NAME") diff --git a/src/app/db/__init__.py b/src/app/db/__init__.py new file mode 100644 index 0000000..24f106b --- /dev/null +++ b/src/app/db/__init__.py @@ -0,0 +1,13 @@ +from .tables import * +from .session import engine, database, Base, SessionLocal +from .init_db import init_db +from .models import AccessType, MediaType + + +# Dependency +def get_session(): + db = SessionLocal() + try: + yield db + finally: + db.close() diff --git a/src/app/db/init_db.py b/src/app/db/init_db.py new file mode 100644 index 0000000..9e2c44c --- /dev/null +++ b/src/app/db/init_db.py @@ -0,0 +1,26 @@ +# Copyright (C) 2022, Pyronear. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +from app import config as cfg +from app.api import crud +from app.api.schemas import AccessCreation, AccessType +from app.api.security import hash_password +from app.db import accesses + + +async def init_db(): + + login = cfg.SUPERUSER_LOGIN + + # check if access login does not already exist + entry = await crud.fetch_one(accesses, {"login": login}) + if entry is None: + + hashed_password = await hash_password(cfg.SUPERUSER_PWD) + + access = AccessCreation(login=login, hashed_password=hashed_password, scope=AccessType.admin) + await crud.create_entry(accesses, access) + + return None diff --git a/src/app/db/models.py b/src/app/db/models.py new file mode 100644 index 0000000..66b7c8c --- /dev/null +++ b/src/app/db/models.py @@ -0,0 +1,60 @@ +# Copyright (C) 2022, Pyronear. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. 
+ +import enum + +from sqlalchemy import Column, DateTime, Enum, ForeignKey, Integer, String +from sqlalchemy.orm import relationship +from sqlalchemy.sql import func + +from .session import Base + + +class AccessType(str, enum.Enum): + user: str = "user" + admin: str = "admin" + + +class Accesses(Base): + __tablename__ = "accesses" + + id = Column(Integer, primary_key=True) + login = Column(String(50), unique=True, index=True) # index for fast lookup + hashed_password = Column(String(70), nullable=False) + scope = Column(Enum(AccessType), default=AccessType.user, nullable=False) + + def __repr__(self): + return f"" + + +class MediaType(str, enum.Enum): + image: str = "image" + video: str = "video" + + +class Media(Base): + __tablename__ = "media" + + id = Column(Integer, primary_key=True) + bucket_key = Column(String(100), nullable=True) + type = Column(Enum(MediaType), default=MediaType.image) + created_at = Column(DateTime, default=func.now()) + + def __repr__(self): + return f"" + + +class Annotations(Base): + __tablename__ = "annotations" + + id = Column(Integer, primary_key=True) + media_id = Column(Integer, ForeignKey("media.id")) + bucket_key = Column(String(100), nullable=True) + created_at = Column(DateTime, default=func.now()) + + media = relationship("Media", uselist=False, back_populates="annotations") + + def __repr__(self): + return f"" diff --git a/src/app/db/session.py b/src/app/db/session.py new file mode 100644 index 0000000..b204e32 --- /dev/null +++ b/src/app/db/session.py @@ -0,0 +1,17 @@ +# Copyright (C) 2022, Pyronear. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +from databases import Database +from sqlalchemy import create_engine +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker + +from app import config as cfg + +engine = create_engine(cfg.DATABASE_URL) +database = Database(cfg.DATABASE_URL) +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + +Base = declarative_base() diff --git a/src/app/db/tables.py b/src/app/db/tables.py new file mode 100644 index 0000000..6293dcf --- /dev/null +++ b/src/app/db/tables.py @@ -0,0 +1,16 @@ +# Copyright (C) 2022, Pyronear. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + + +from .models import Accesses, Annotations, Media +from .session import Base + +__all__ = ["metadata", "accesses", "media", "annotations"] + +accesses = Accesses.__table__ +media = Media.__table__ +annotations = Annotations.__table__ + +metadata = Base.metadata diff --git a/src/app/main.py b/src/app/main.py new file mode 100644 index 0000000..2fbe5f0 --- /dev/null +++ b/src/app/main.py @@ -0,0 +1,85 @@ +# Copyright (C) 2022, Pyronear. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. 
+ +import logging +import time + +import sentry_sdk +from fastapi import FastAPI, Request +from fastapi.openapi.utils import get_openapi +from sentry_sdk.integrations.asgi import SentryAsgiMiddleware + +from app import config as cfg +from app.api.routes import accesses, annotations, login, media +from app.db import database, engine, init_db, metadata + +logger = logging.getLogger("uvicorn.error") + +metadata.create_all(bind=engine) + +# Sentry +if isinstance(cfg.SENTRY_DSN, str): + sentry_sdk.init( + cfg.SENTRY_DSN, + release=cfg.VERSION, + server_name=cfg.SERVER_NAME, + environment="production" if isinstance(cfg.SERVER_NAME, str) else None, + traces_sample_rate=1.0, + ) + logger.info(f"Sentry middleware enabled on server {cfg.SERVER_NAME}") + + +app = FastAPI(title=cfg.PROJECT_NAME, description=cfg.PROJECT_DESCRIPTION, debug=cfg.DEBUG, version=cfg.VERSION) + + +# Database connection +@app.on_event("startup") +async def startup(): + await database.connect() + await init_db() + + +@app.on_event("shutdown") +async def shutdown(): + await database.disconnect() + + +# Routing +app.include_router(login.router, prefix="/login", tags=["login"]) +app.include_router(media.router, prefix="/media", tags=["media"]) +app.include_router(annotations.router, prefix="/annotations", tags=["annotations"]) +app.include_router(accesses.router, prefix="/accesses", tags=["accesses"]) + + +# Middleware +@app.middleware("http") +async def add_process_time_header(request: Request, call_next): + start_time = time.time() + response = await call_next(request) + process_time = time.time() - start_time + response.headers["X-Process-Time"] = str(process_time) + return response + + +if isinstance(cfg.SENTRY_DSN, str): + app.add_middleware(SentryAsgiMiddleware) + + +# Docs +def custom_openapi(): + if app.openapi_schema: + return app.openapi_schema + openapi_schema = get_openapi( + title=cfg.PROJECT_NAME, + version=cfg.VERSION, + description=cfg.PROJECT_DESCRIPTION, + routes=app.routes, + ) + openapi_schema["info"]["x-logo"] = {"url": cfg.LOGO_URL} + app.openapi_schema = openapi_schema + return app.openapi_schema + + +app.openapi = custom_openapi # type: ignore[assignment] diff --git a/src/app/requirements.txt b/src/app/requirements.txt new file mode 100644 index 0000000..9553228 --- /dev/null +++ b/src/app/requirements.txt @@ -0,0 +1,11 @@ +fastapi>=0.61.1 +uvicorn>=0.11.1 +databases[postgresql]>=0.2.6,<=0.4.0 +SQLAlchemy>=1.3.12 +python-jose>=3.2.0 +passlib[bcrypt]>=1.7.4 +python-multipart==0.0.5 +aiofiles==0.6.0 +requests>=2.22.0 +sentry-sdk>=1.5.12 +qarnot>=2.2.1 diff --git a/src/app/services/__init__.py b/src/app/services/__init__.py new file mode 100644 index 0000000..dabc41a --- /dev/null +++ b/src/app/services/__init__.py @@ -0,0 +1,2 @@ +from .services import * +from .utils import * diff --git a/src/app/services/bucket/__init__.py b/src/app/services/bucket/__init__.py new file mode 100644 index 0000000..0c55cff --- /dev/null +++ b/src/app/services/bucket/__init__.py @@ -0,0 +1 @@ +from .qarnot import * diff --git a/src/app/services/bucket/qarnot.py b/src/app/services/bucket/qarnot.py new file mode 100644 index 0000000..9dace9e --- /dev/null +++ b/src/app/services/bucket/qarnot.py @@ -0,0 +1,98 @@ +# Copyright (C) 2022, Pyronear. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. 
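The application assembled in src/app/main.py above can be served with uvicorn; a minimal sketch, where port 8080 mirrors the docker-compose sanity check in the CI workflow and the host value is an assumption:

# Illustrative launcher; in this repository the API is normally run through docker-compose
import uvicorn

from app.main import app

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8080)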
+ +import logging +import os +from typing import List, Optional + +from fastapi import HTTPException +from qarnot.bucket import Bucket +from qarnot.connection import Connection + +from app import config as cfg + +__all__ = ["QarnotBucket"] + + +logger = logging.getLogger("uvicorn.warning") + + +class QarnotBucket: + """Storage bucket manipulation object on Qarnot computing""" + + _bucket: Optional[Bucket] = None + + def __init__(self, bucket_name: str, folder: Optional[str] = None) -> None: + self.bucket_name = bucket_name + self.folder = folder + self._connect_to_bucket() + + def _connect_to_bucket(self) -> None: + """Connect to the CSP bucket""" + self._conn = Connection(client_token=cfg.QARNOT_TOKEN) + self._bucket = Bucket(self._conn, self.bucket_name) + + @property + def bucket(self) -> Bucket: + if self._bucket is None: + self._connect_to_bucket() + return self._bucket + + async def get_file(self, bucket_key: str) -> Optional[str]: + """Download a file locally and returns the local temp path""" + try: + return self.bucket.get_file(bucket_key) + except Exception as e: + logger.warning(e) + return None + + async def check_file_existence(self, bucket_key: str) -> bool: + """Check whether a file exists on the bucket""" + try: + # Use boto3 head_object method using the Qarnot private connection attribute + # cf. https://github.com/qarnot/qarnot-sdk-python/blob/master/qarnot/connection.py#L188 + head_object = self._conn._s3client.head_object(Bucket=self.bucket_name, Key=bucket_key) + return head_object["ResponseMetadata"]["HTTPStatusCode"] == 200 + except Exception as e: + logger.warning(e) + return False + + async def get_public_url(self, bucket_key: str, url_expiration: int = 3600) -> str: + """Generate a temporary public URL for a bucket file""" + if not await self.check_file_existence(bucket_key): + raise HTTPException(status_code=404, detail="File cannot be found on the bucket storage") + + # Point to the bucket file + file_params = {"Bucket": self.bucket_name, "Key": bucket_key} + # Generate a public URL for it using boto3 presign URL generation + return self._conn._s3client.generate_presigned_url("get_object", Params=file_params, ExpiresIn=url_expiration) + + async def upload_file(self, bucket_key: str, file_binary: bytes) -> bool: + """Upload a file to bucket and return whether the upload succeeded""" + try: + self.bucket.add_file(file_binary, bucket_key) + except Exception as e: + logger.warning(e) + return False + return True + + async def fetch_bucket_filenames(self) -> List[str]: + """List all bucket files""" + + if isinstance(self.folder, str): + obj_summary = self.bucket.directory(self.folder) + else: + obj_summary = self.bucket.list_files() + + return [file.key for file in list(obj_summary)] + + async def flush_tmp_file(self, filename: str) -> None: + """Remove temporary file""" + if os.path.exists(filename): + os.remove(filename) + + async def delete_file(self, bucket_key: str) -> None: + """Remove bucket file and return whether the deletion succeeded""" + self.bucket.delete_file(bucket_key) diff --git a/src/app/services/services.py b/src/app/services/services.py new file mode 100644 index 0000000..8cc3a03 --- /dev/null +++ b/src/app/services/services.py @@ -0,0 +1,13 @@ +# Copyright (C) 2022, Pyronear. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. 
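A hedged usage sketch of the QarnotBucket helper defined above; the bucket name, key and payload are placeholders, and a valid QARNOT_TOKEN is assumed to be set in the environment:

# Illustrative only: upload a blob, then mint a temporary public URL for it
import asyncio

from app.services.bucket import QarnotBucket


async def demo():
    bucket = QarnotBucket("my-demo-bucket", folder="media")
    if await bucket.upload_file("media/demo.jpg", b"raw image bytes"):
        print(await bucket.get_public_url("media/demo.jpg", url_expiration=600))


asyncio.run(demo())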
+ +from app import config as cfg +from app.services.bucket import QarnotBucket + +__all__ = ["media_bucket", "annotations_bucket"] + + +media_bucket = QarnotBucket(cfg.BUCKET_NAME, cfg.BUCKET_MEDIA_FOLDER) +annotations_bucket = QarnotBucket(cfg.BUCKET_NAME, cfg.BUCKET_ANNOT_FOLDER) diff --git a/src/app/services/utils.py b/src/app/services/utils.py new file mode 100644 index 0000000..9e58d54 --- /dev/null +++ b/src/app/services/utils.py @@ -0,0 +1,13 @@ +# Copyright (C) 2022, Pyronear. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +from typing import Optional + +__all__ = ["resolve_bucket_key"] + + +def resolve_bucket_key(file_name: str, bucket_folder: Optional[str] = None) -> str: + """Prepend file name with bucket subfolder""" + return f"{bucket_folder}/{file_name}" if isinstance(bucket_folder, str) else file_name diff --git a/src/requirements-dev.txt b/src/requirements-dev.txt new file mode 100644 index 0000000..5362534 --- /dev/null +++ b/src/requirements-dev.txt @@ -0,0 +1,6 @@ +pytest>=5.3.2 +pytest-asyncio>=0.14.0 +asyncpg>=0.20.0 +coverage>=4.5.4 +aiosqlite>=0.16.0 +httpx>=0.16.1,<0.19.0 diff --git a/src/tests/conftest.py b/src/tests/conftest.py new file mode 100644 index 0000000..cc07fbc --- /dev/null +++ b/src/tests/conftest.py @@ -0,0 +1,51 @@ +# Copyright (C) 2021, Pyronear contributors. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +import pytest +from httpx import AsyncClient + +from app.api.security import create_unlimited_access_token +from app.main import app +from tests.db_utils import database as test_database +from tests.db_utils import reset_test_db + + +async def mock_hash_password(password): + return f"hashed_{password}" + + +async def mock_verify_password(plain_password, hashed_password): + return hashed_password == f"hashed_{plain_password}" + + +async def get_token(access_id, scopes): + + token_data = {"sub": str(access_id), "scopes": scopes} + token = await create_unlimited_access_token(token_data) + + return {"Authorization": f"Bearer {token}", "Content-Type": "application/json"} + + +def pytest_configure(): + # api.security patching + pytest.mock_hash_password = mock_hash_password + pytest.mock_verify_password = mock_verify_password + pytest.get_token = get_token + + +@pytest.fixture(scope="function") +async def test_app_asyncio(): + async with AsyncClient(app=app, base_url="http://test") as ac: + yield ac # testing happens here + + +@pytest.fixture(scope="function") +async def test_db(): + try: + await test_database.connect() + yield test_database + finally: + await reset_test_db() + await test_database.disconnect() diff --git a/src/tests/crud/test_authorizations.py b/src/tests/crud/test_authorizations.py new file mode 100644 index 0000000..f35db9d --- /dev/null +++ b/src/tests/crud/test_authorizations.py @@ -0,0 +1,68 @@ +# Copyright (C) 2021, Pyronear contributors. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. 
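For reference, resolve_bucket_key from src/app/services/utils.py above simply prefixes the file name when a folder is configured; a quick sketch with arbitrary names:

from app.services import resolve_bucket_key

assert resolve_bucket_key("fire.jpg", "my/bucket/subfolder") == "my/bucket/subfolder/fire.jpg"
assert resolve_bucket_key("fire.jpg") == "fire.jpg"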
+ +import pytest +from fastapi import HTTPException + +from app import db +from app.api import crud +from tests.db_utils import fill_table +from tests.utils import update_only_datetime + +ACCESS_TABLE = [ + {"id": 1, "login": "first_login", "hashed_password": "hashed_pwd", "scope": "user"}, + {"id": 2, "login": "second_login", "hashed_password": "hashed_pwd", "scope": "admin"}, +] + +MEDIA_TABLE = [ + {"id": 1, "type": "image", "created_at": "2020-10-13T08:18:45.447773"}, + {"id": 2, "type": "video", "created_at": "2020-10-13T09:18:45.447773"}, +] + + +ANNOTATIONS_TABLE = [ + {"id": 1, "media_id": 1, "created_at": "2020-10-13T08:18:45.447773"}, +] + + +MEDIA_TABLE_FOR_DB = list(map(update_only_datetime, MEDIA_TABLE)) +ANNOTATIONS_TABLE_FOR_DB = list(map(update_only_datetime, ANNOTATIONS_TABLE)) + + +@pytest.fixture(scope="function") +async def init_test_db(monkeypatch, test_db): + monkeypatch.setattr(crud.base, "database", test_db) + await fill_table(test_db, db.accesses, ACCESS_TABLE) + await fill_table(test_db, db.media, MEDIA_TABLE_FOR_DB) + await fill_table(test_db, db.annotations, ANNOTATIONS_TABLE_FOR_DB) + + +@pytest.mark.parametrize( + "access_id, expected_result", + [ + [1, False], + [2, True], + ], +) +@pytest.mark.asyncio +async def test_admin_access(test_app_asyncio, init_test_db, access_id, expected_result): + admin_access_result = await crud.authorizations.is_admin_access(access_id) + assert admin_access_result == expected_result + + +@pytest.mark.parametrize( + "access_id, should_raise", + [ + [1, True], + [2, False], # Because Admin + ], +) +@pytest.mark.asyncio +async def test_check_access_read(test_app_asyncio, init_test_db, access_id, should_raise): + if should_raise: + with pytest.raises(HTTPException): + await crud.authorizations.check_access_read(access_id) + else: + await crud.authorizations.check_access_read(access_id) diff --git a/src/tests/db_utils.py b/src/tests/db_utils.py new file mode 100644 index 0000000..2604a20 --- /dev/null +++ b/src/tests/db_utils.py @@ -0,0 +1,50 @@ +# Copyright (C) 2021, Pyronear contributors. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +import contextlib +from typing import Any, Dict, List, Mapping + +from databases import Database +from sqlalchemy import Table, create_engine +from sqlalchemy.orm import sessionmaker + +import app.config as cfg +from app.db import metadata + +SQLALCHEMY_DATABASE_URL = cfg.TEST_DATABASE_URL +engine = create_engine(SQLALCHEMY_DATABASE_URL) +metadata.create_all(bind=engine) +database = Database(SQLALCHEMY_DATABASE_URL) +TestSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + + +async def reset_test_db(): + """ + Delete all rows from the database but keeps the schemas + """ + + with contextlib.closing(engine.connect()) as con: + trans = con.begin() + for table in reversed(metadata.sorted_tables): + con.execute(table.delete()) + con.execute(f"ALTER SEQUENCE {table}_id_seq RESTART WITH 1") + trans.commit() + + +async def fill_table(test_db: Database, table: Table, entries: List[Dict[str, Any]], remove_ids: bool = True) -> None: + """ + Directly insert data into the DB table. 
Set remove_ids to True by default as the id sequence pointer + are not incremented if the "id" field is included + """ + if remove_ids: + entries = [{k: v for k, v in x.items() if k != "id"} for x in entries] + + query = table.insert().values(entries) + await test_db.execute(query=query) + + +async def get_entry(test_db: Database, table: Table, entry_id: int) -> Mapping[str, Any]: + query = table.select().where(entry_id == table.c.id) + return await test_db.fetch_one(query=query) diff --git a/src/tests/routes/test_accesses.py b/src/tests/routes/test_accesses.py new file mode 100644 index 0000000..c88dcb4 --- /dev/null +++ b/src/tests/routes/test_accesses.py @@ -0,0 +1,200 @@ +# Copyright (C) 2021, Pyronear contributors. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +import json + +import pytest + +from app import db +from app.api import crud, security +from tests.db_utils import fill_table, get_entry +from tests.utils import update_only_datetime + +ACCESS_TABLE = [ + {"id": 1, "login": "first_login", "hashed_password": "hashed_pwd", "scope": "user"}, + {"id": 2, "login": "second_login", "hashed_password": "hashed_pwd", "scope": "admin"}, +] + + +ACCESS_TABLE_FOR_DB = list(map(update_only_datetime, ACCESS_TABLE)) + + +@pytest.fixture(scope="function") +async def init_test_db(monkeypatch, test_db): + monkeypatch.setattr(crud.base, "database", test_db) + await fill_table(test_db, db.accesses, ACCESS_TABLE) + monkeypatch.setattr(security, "hash_password", pytest.mock_hash_password) + + +@pytest.mark.parametrize( + "access_idx, access_id, status_code, status_details", + [ + [None, 1, 401, "Not authenticated"], + [0, 1, 403, "Your access scope is not compatible with this operation."], + [0, 2, 403, "Your access scope is not compatible with this operation."], + [0, 3, 403, "Your access scope is not compatible with this operation."], + [1, 1, 200, None], + [1, 2, 200, None], + [1, 999, 404, "Table accesses has no entry with id=999"], + [1, 0, 422, None], + ], +) +@pytest.mark.asyncio +async def test_get_access(init_test_db, test_app_asyncio, access_idx, access_id, status_code, status_details): + + # Create a custom access token + auth = None + if isinstance(access_idx, int): + auth = await pytest.get_token(ACCESS_TABLE[access_idx]["id"], ACCESS_TABLE[access_idx]["scope"].split()) + + response = await test_app_asyncio.get(f"/accesses/{access_id}", headers=auth) + assert response.status_code == status_code + + if isinstance(status_details, str): + assert response.json()["detail"] == status_details + + if response.status_code == 200: + access = None + for _access in ACCESS_TABLE: + if _access["id"] == access_id: + access = _access + break + assert response.json() == {k: v for k, v in access.items() if k != "hashed_password"} + + +@pytest.mark.parametrize( + "access_idx, status_code, status_details", + [ + [None, 401, "Not authenticated"], + [0, 403, "Your access scope is not compatible with this operation."], + [1, 200, None], + ], +) +@pytest.mark.asyncio +async def test_fetch_accesses(init_test_db, test_app_asyncio, access_idx, status_code, status_details): + + # Create a custom access token + auth = None + if isinstance(access_idx, int): + auth = await pytest.get_token(ACCESS_TABLE[access_idx]["id"], ACCESS_TABLE[access_idx]["scope"].split()) + + response = await test_app_asyncio.get("/accesses/", headers=auth) + assert response.status_code == status_code + + if isinstance(status_details, str): + assert 
response.json()["detail"] == status_details + if response.status_code == 200: + assert response.json() == [{k: v for k, v in entry.items() if k != "hashed_password"} for entry in ACCESS_TABLE] + + +@pytest.mark.parametrize( + "access_idx, payload, status_code, status_details", + [ + [None, {"login": "dummy_login", "scope": "admin", "password": "my_pwd"}, 401, "Not authenticated"], + # non-admin can't create access + [ + 0, + {"login": "dummy_login", "scope": "user", "password": "my_pwd"}, + 403, + "Your access scope is not compatible with this operation.", + ], + [ + 0, + {"login": "dummy_login", "scope": "admin", "password": "my_pwd"}, + 403, + "Your access scope is not compatible with this operation.", + ], + [1, {"login": "dummy_login", "scope": "user", "password": "my_pwd"}, 201, None], + [1, {"login": "dummy_login", "scope": "admin", "password": "my_pwd"}, 201, None], + [1, {"login": 1, "scope": "admin", "password": "my_pwd"}, 422, None], + [1, {"login": "dummy_login", "scope": 1, "password": "my_pwd"}, 422, None], + ], +) +@pytest.mark.asyncio +async def test_create_access(test_app_asyncio, init_test_db, test_db, access_idx, payload, status_code, status_details): + + # Create a custom access token + auth = None + if isinstance(access_idx, int): + auth = await pytest.get_token(ACCESS_TABLE[access_idx]["id"], ACCESS_TABLE[access_idx]["scope"].split()) + + response = await test_app_asyncio.post("/accesses/", data=json.dumps(payload), headers=auth) + assert response.status_code == status_code + if isinstance(status_details, str): + assert response.json()["detail"] == status_details + + if response.status_code // 100 == 2: + json_response = response.json() + test_response = {"id": len(ACCESS_TABLE) + 1, "login": payload["login"], "scope": payload["scope"]} + assert json_response == test_response + + new_annotation = await get_entry(test_db, db.accesses, json_response["id"]) + new_annotation = dict(**new_annotation) + + +@pytest.mark.parametrize( + "access_idx, payload, access_id, status_code, status_details", + [ + [None, {}, 1, 401, "Not authenticated"], + [0, {"password": "my_pwd"}, 1, 403, "Your access scope is not compatible with this operation."], + [1, {"password": "my_pwd"}, 1, 200, None], + [1, {}, 1, 422, None], + [1, {"password": 1}, 1, 422, None], + [1, {"password": "my_pwd"}, 999, 404, "Table accesses has no entry with id=999"], + [1, {"password": "my_pwd"}, 2, 200, None], + ], +) +@pytest.mark.asyncio +async def test_update_access_pwd( + test_app_asyncio, init_test_db, test_db, access_idx, payload, access_id, status_code, status_details +): + + # Create a custom access token + auth = None + if isinstance(access_idx, int): + auth = await pytest.get_token(ACCESS_TABLE[access_idx]["id"], ACCESS_TABLE[access_idx]["scope"].split()) + + response = await test_app_asyncio.put(f"/accesses/{access_id}/", data=json.dumps(payload), headers=auth) + assert response.status_code == status_code + if isinstance(status_details, str): + assert response.json()["detail"] == status_details + + if response.status_code // 100 == 2: + updated_access = await get_entry(test_db, db.accesses, access_id) + updated_access = dict(**updated_access) + for k, v in updated_access.items(): + if k == "hashed_password": + assert v == f"hashed_{payload['password']}" + else: + assert v == payload.get(k, ACCESS_TABLE_FOR_DB[access_id - 1][k]) + + +@pytest.mark.parametrize( + "access_idx, access_id, status_code, status_details", + [ + [None, 1, 401, "Not authenticated"], + [0, 1, 403, "Your access scope is not 
compatible with this operation."], + [1, 1, 200, None], + [1, 999, 404, "Table accesses has no entry with id=999"], + [1, 0, 422, None], + ], +) +@pytest.mark.asyncio +async def test_delete_access(test_app_asyncio, init_test_db, access_idx, access_id, status_code, status_details): + + # Create a custom access token + auth = None + if isinstance(access_idx, int): + auth = await pytest.get_token(ACCESS_TABLE[access_idx]["id"], ACCESS_TABLE[access_idx]["scope"].split()) + + response = await test_app_asyncio.delete(f"/accesses/{access_id}/", headers=auth) + assert response.status_code == status_code + if isinstance(status_details, str): + assert response.json()["detail"] == status_details + + if response.status_code // 100 == 2: + assert response.json() == {k: v for k, v in ACCESS_TABLE[access_id - 1].items() if k != "hashed_password"} + remaining_annotation = await test_app_asyncio.get("/accesses/", headers=auth) + assert all(entry["id"] != access_id for entry in remaining_annotation.json()) diff --git a/src/tests/routes/test_annotations.py b/src/tests/routes/test_annotations.py new file mode 100644 index 0000000..1ce3858 --- /dev/null +++ b/src/tests/routes/test_annotations.py @@ -0,0 +1,264 @@ +# Copyright (C) 2022, Pyronear contributors. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +import json +import os +import tempfile +from datetime import datetime + +import pytest + +from app import db +from app.api import crud +from app.services import annotations_bucket +from tests.db_utils import TestSessionLocal, fill_table, get_entry +from tests.utils import update_only_datetime + +ACCESS_TABLE = [ + {"id": 1, "login": "first_login", "hashed_password": "hashed_pwd", "scope": "user"}, + {"id": 2, "login": "second_login", "hashed_password": "hashed_pwd", "scope": "admin"}, +] + +MEDIA_TABLE = [ + {"id": 1, "type": "image", "created_at": "2020-10-13T08:18:45.447773"}, + {"id": 2, "type": "video", "created_at": "2020-10-13T09:18:45.447773"}, +] + +ANNOTATIONS_TABLE = [ + {"id": 1, "media_id": 1, "bucket_key": "dummy_key", "created_at": "2020-10-13T08:18:45.447773"}, +] + + +MEDIA_TABLE_FOR_DB = list(map(update_only_datetime, MEDIA_TABLE)) +ANNOTATIONS_TABLE_FOR_DB = list(map(update_only_datetime, ANNOTATIONS_TABLE)) + + +@pytest.fixture(scope="function") +async def init_test_db(monkeypatch, test_db): + monkeypatch.setattr(crud.base, "database", test_db) + monkeypatch.setattr(db, "SessionLocal", TestSessionLocal) + await fill_table(test_db, db.accesses, ACCESS_TABLE) + await fill_table(test_db, db.media, MEDIA_TABLE_FOR_DB) + await fill_table(test_db, db.annotations, ANNOTATIONS_TABLE_FOR_DB) + + +@pytest.mark.parametrize( + "access_idx, annotation_id, status_code, status_details", + [ + [None, 1, 401, "Not authenticated"], + [0, 1, 403, "This access can't read resources"], + [1, 1, 200, None], + [1, 999, 404, "Table annotations has no entry with id=999"], + [1, 0, 422, None], + ], +) +@pytest.mark.asyncio +async def test_get_annotation(test_app_asyncio, init_test_db, access_idx, annotation_id, status_code, status_details): + + # Create a custom access token + auth = None + if isinstance(access_idx, int): + auth = await pytest.get_token(ACCESS_TABLE[access_idx]["id"], ACCESS_TABLE[access_idx]["scope"].split()) + + response = await test_app_asyncio.get(f"/annotations/{annotation_id}", headers=auth) + assert response.status_code == status_code + if isinstance(status_details, str): + assert response.json()["detail"] == 
status_details + + if response.status_code // 100 == 2: + assert response.json() == {k: v for k, v in ANNOTATIONS_TABLE[annotation_id - 1].items() if k != "bucket_key"} + + +@pytest.mark.parametrize( + "access_idx, status_code, status_details, expected_results", + [ + [None, 401, "Not authenticated", None], + [0, 200, None, []], + [1, 200, None, [{k: v for k, v in elt.items() if k != "bucket_key"} for elt in ANNOTATIONS_TABLE]], + ], +) +@pytest.mark.asyncio +async def test_fetch_annotations( + test_app_asyncio, init_test_db, access_idx, status_code, status_details, expected_results +): + + # Create a custom access token + auth = None + if isinstance(access_idx, int): + auth = await pytest.get_token(ACCESS_TABLE[access_idx]["id"], ACCESS_TABLE[access_idx]["scope"].split()) + + response = await test_app_asyncio.get("/annotations/", headers=auth) + assert response.status_code == status_code + if isinstance(status_details, str): + assert response.json()["detail"] == status_details + + if response.status_code // 100 == 2: + assert response.json() == expected_results + + +@pytest.mark.parametrize( + "access_idx, payload, status_code, status_details", + [ + [None, {"media_id": 1}, 401, "Not authenticated"], + [0, {"media_id": 1}, 201, None], + [1, {"media_id": 1}, 201, None], + [1, {"media_id": "alpha"}, 422, None], + [1, {}, 422, None], + ], +) +@pytest.mark.asyncio +async def test_create_annotation( + test_app_asyncio, init_test_db, test_db, access_idx, payload, status_code, status_details +): + + # Create a custom access token + auth = None + if isinstance(access_idx, int): + auth = await pytest.get_token(ACCESS_TABLE[access_idx]["id"], ACCESS_TABLE[access_idx]["scope"].split()) + + utc_dt = datetime.utcnow() + response = await test_app_asyncio.post("/annotations/", data=json.dumps(payload), headers=auth) + assert response.status_code == status_code + if isinstance(status_details, str): + assert response.json()["detail"] == status_details + + if response.status_code // 100 == 2: + json_response = response.json() + test_response = {"id": len(ANNOTATIONS_TABLE) + 1, **payload} + assert {k: v for k, v in json_response.items() if k != "created_at"} == test_response + + new_annotation = await get_entry(test_db, db.annotations, json_response["id"]) + new_annotation = dict(**new_annotation) + + # Timestamp consistency + assert new_annotation["created_at"] > utc_dt and new_annotation["created_at"] < datetime.utcnow() + + +@pytest.mark.parametrize( + "access_idx, payload, annotation_id, status_code, status_details", + [ + [None, {"media_id": 1}, 1, 401, "Not authenticated"], + [0, {"media_id": 1}, 1, 403, "Your access scope is not compatible with this operation."], + [1, {"media_id": 1}, 1, 200, None], + [1, {}, 1, 422, None], + [1, {"media_id": "alpha"}, 1, 422, None], + [1, {"media_id": 1}, 999, 404, "Table annotations has no entry with id=999"], + [1, {"media_id": 1}, 0, 422, None], + ], +) +@pytest.mark.asyncio +async def test_update_annotation( + test_app_asyncio, init_test_db, test_db, access_idx, payload, annotation_id, status_code, status_details +): + + # Create a custom access token + auth = None + if isinstance(access_idx, int): + auth = await pytest.get_token(ACCESS_TABLE[access_idx]["id"], ACCESS_TABLE[access_idx]["scope"].split()) + + response = await test_app_asyncio.put(f"/annotations/{annotation_id}/", data=json.dumps(payload), headers=auth) + assert response.status_code == status_code + if isinstance(status_details, str): + assert response.json()["detail"] == status_details + + if 
response.status_code // 100 == 2: + updated_annotation = await get_entry(test_db, db.annotations, annotation_id) + updated_annotation = dict(**updated_annotation) + for k, v in updated_annotation.items(): + if k != "bucket_key": + assert v == payload.get(k, ANNOTATIONS_TABLE_FOR_DB[annotation_id - 1][k]) + + +@pytest.mark.parametrize( + "access_idx, annotation_id, status_code, status_details", + [ + [None, 1, 401, "Not authenticated"], + [0, 1, 403, "Your access scope is not compatible with this operation."], + [1, 1, 200, None], + [1, 999, 404, "Table annotations has no entry with id=999"], + [1, 0, 422, None], + ], +) +@pytest.mark.asyncio +async def test_delete_annotation( + test_app_asyncio, init_test_db, access_idx, annotation_id, status_code, status_details +): + + # Create a custom access token + auth = None + if isinstance(access_idx, int): + auth = await pytest.get_token(ACCESS_TABLE[access_idx]["id"], ACCESS_TABLE[access_idx]["scope"].split()) + + response = await test_app_asyncio.delete(f"/annotations/{annotation_id}/", headers=auth) + assert response.status_code == status_code + if isinstance(status_details, str): + assert response.json()["detail"] == status_details + + if response.status_code // 100 == 2: + assert response.json() == {k: v for k, v in ANNOTATIONS_TABLE[annotation_id - 1].items() if k != "bucket_key"} + remaining_annotation = await test_app_asyncio.get("/annotations/", headers=auth) + assert all(entry["id"] != annotation_id for entry in remaining_annotation.json()) + + +@pytest.mark.asyncio +async def test_upload_annotation(test_app_asyncio, init_test_db, test_db, monkeypatch): + + admin_idx = 1 + # Create a custom access token + admin_auth = await pytest.get_token(ACCESS_TABLE[admin_idx]["id"], ACCESS_TABLE[admin_idx]["scope"].split()) + + # 1 - Create a annotation that will have an upload + payload = {"media_id": 2} + new_annotation_id = len(ANNOTATIONS_TABLE_FOR_DB) + 1 + response = await test_app_asyncio.post("/annotations/", data=json.dumps(payload), headers=admin_auth) + assert response.status_code == 201 + + # 2 - Upload something + async def mock_upload_file(bucket_key, file_binary): + return True + + monkeypatch.setattr(annotations_bucket, "upload_file", mock_upload_file) + + # Download and save a temporary file + local_tmp_path = os.path.join(tempfile.gettempdir(), "my_temp_annotation.json") + data = {"label": "fire"} + with open(local_tmp_path, "w") as f: + json.dump(data, f) + + async def mock_get_file(bucket_key): + return local_tmp_path + + monkeypatch.setattr(annotations_bucket, "get_file", mock_get_file) + + async def mock_delete_file(filename): + return True + + monkeypatch.setattr(annotations_bucket, "delete_file", mock_delete_file) + + # Switch content-type from JSON to multipart + del admin_auth["Content-Type"] + + with open(local_tmp_path, "r") as content: + response = await test_app_asyncio.post( + f"/annotations/{new_annotation_id}/upload", files=dict(file=content), headers=admin_auth + ) + + assert response.status_code == 200, print(response.json()["detail"]) + response_json = response.json() + updated_annotation = await get_entry(test_db, db.annotations, response_json["id"]) + updated_annotation = dict(**updated_annotation) + response_json.pop("created_at") + assert {k: v for k, v in updated_annotation.items() if k not in ("created_at", "bucket_key")} == response_json + assert updated_annotation["bucket_key"] is not None + + # 2b - Upload failing + async def failing_upload(bucket_key, file_binary): + return False + + 
monkeypatch.setattr(annotations_bucket, "upload_file", failing_upload) + response = await test_app_asyncio.post( + f"/annotations/{new_annotation_id}/upload", files=dict(file="bar"), headers=admin_auth + ) + assert response.status_code == 500 diff --git a/src/tests/routes/test_login.py b/src/tests/routes/test_login.py new file mode 100644 index 0000000..ab1c12e --- /dev/null +++ b/src/tests/routes/test_login.py @@ -0,0 +1,42 @@ +# Copyright (C) 2021, Pyronear contributors. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +import pytest + +from app import db +from app.api import crud, security +from tests.db_utils import fill_table + +ACCESS_TABLE = [ + {"id": 1, "login": "first_login", "hashed_password": "hashed_first_pwd", "scope": "user"}, + {"id": 2, "login": "second_login", "hashed_password": "hashed_second_pwd", "scope": "user"}, +] + + +@pytest.fixture(scope="function") +async def init_test_db(monkeypatch, test_db): + monkeypatch.setattr(security, "verify_password", pytest.mock_verify_password) + monkeypatch.setattr(crud.base, "database", test_db) + await fill_table(test_db, db.accesses, ACCESS_TABLE) + + +@pytest.mark.parametrize( + "payload, status_code, status_detail", + [ + [{"username": "foo"}, 422, None], + [{"password": "foo"}, 422, None], + [{"username": "unknown", "password": "foo"}, 401, "Invalid credentials."], # unknown username + [{"username": "first", "password": "second"}, 401, "Invalid credentials."], # wrong pwd + [{"username": "first_login", "password": "first_pwd"}, 200, None], # valid + ], +) +@pytest.mark.asyncio +async def test_create_access_token(test_app_asyncio, init_test_db, payload, status_code, status_detail): + + response = await test_app_asyncio.post("/login/access-token", data=payload) + + assert response.status_code == status_code, print(payload) + if isinstance(status_detail, str): + assert response.json()["detail"] == status_detail diff --git a/src/tests/routes/test_media.py b/src/tests/routes/test_media.py new file mode 100644 index 0000000..7f8ebf6 --- /dev/null +++ b/src/tests/routes/test_media.py @@ -0,0 +1,249 @@ +# Copyright (C) 2021, Pyronear contributors. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. 
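The login route exercised in src/tests/routes/test_login.py above issues bearer tokens from form-encoded credentials. A hedged sketch of the corresponding client-side flow; the base URL and credentials are placeholders, and the access_token field name assumes the usual OAuth2 response shape rather than anything shown in this patch:

# Illustrative client flow against a locally running instance
import requests

API_URL = "http://localhost:8080"

resp = requests.post(
    f"{API_URL}/login/access-token",
    data={"username": "first_login", "password": "first_pwd"},
)
token = resp.json().get("access_token")

headers = {"Authorization": f"Bearer {token}"}
print(requests.get(f"{API_URL}/media/", headers=headers).status_code)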
+ +import json +import os +import tempfile +from datetime import datetime + +import pytest +import requests + +from app import db +from app.api import crud +from app.services import media_bucket +from tests.db_utils import TestSessionLocal, fill_table, get_entry +from tests.utils import update_only_datetime + +ACCESS_TABLE = [ + {"id": 1, "login": "first_login", "hashed_password": "hashed_pwd", "scope": "user"}, + {"id": 2, "login": "second_login", "hashed_password": "hashed_pwd", "scope": "admin"}, +] + +MEDIA_TABLE = [ + {"id": 1, "type": "image", "created_at": "2020-10-13T08:18:45.447773"}, + {"id": 2, "type": "video", "created_at": "2020-10-13T09:18:45.447773"}, +] + + +MEDIA_TABLE_FOR_DB = list(map(update_only_datetime, MEDIA_TABLE)) + + +@pytest.fixture(scope="function") +async def init_test_db(monkeypatch, test_db): + monkeypatch.setattr(crud.base, "database", test_db) + monkeypatch.setattr(db, "SessionLocal", TestSessionLocal) + await fill_table(test_db, db.accesses, ACCESS_TABLE) + await fill_table(test_db, db.media, MEDIA_TABLE_FOR_DB) + + +@pytest.mark.parametrize( + "access_idx, media_id, status_code, status_details", + [ + [None, 1, 401, "Not authenticated"], + [0, 1, 403, "This access can't read resources"], + [1, 1, 200, None], + [1, 999, 404, "Table media has no entry with id=999"], + [1, 0, 422, None], + ], +) +@pytest.mark.asyncio +async def test_get_media(test_app_asyncio, init_test_db, access_idx, media_id, status_code, status_details): + + # Create a custom access token + auth = None + if isinstance(access_idx, int): + auth = await pytest.get_token(ACCESS_TABLE[access_idx]["id"], ACCESS_TABLE[access_idx]["scope"].split()) + + response = await test_app_asyncio.get(f"/media/{media_id}", headers=auth) + assert response.status_code == status_code + if isinstance(status_details, str): + assert response.json()["detail"] == status_details + + if response.status_code // 100 == 2: + assert response.json() == MEDIA_TABLE[media_id - 1] + + +@pytest.mark.parametrize( + "access_idx, status_code, status_details, expected_results", + [ + [None, 401, "Not authenticated", None], + [0, 200, None, []], + [1, 200, None, MEDIA_TABLE], + ], +) +@pytest.mark.asyncio +async def test_fetch_media(test_app_asyncio, init_test_db, access_idx, status_code, status_details, expected_results): + + # Create a custom access token + auth = None + if isinstance(access_idx, int): + auth = await pytest.get_token(ACCESS_TABLE[access_idx]["id"], ACCESS_TABLE[access_idx]["scope"].split()) + + response = await test_app_asyncio.get("/media/", headers=auth) + assert response.status_code == status_code + if isinstance(status_details, str): + assert response.json()["detail"] == status_details + + if response.status_code // 100 == 2: + assert response.json() == expected_results + + +@pytest.mark.parametrize( + "access_idx, payload, status_code, status_details", + [ + [None, {}, 401, "Not authenticated"], + [0, {"type": "video"}, 403, "Your access scope is not compatible with this operation."], + [1, {}, 201, None], + [1, {"type": "audio"}, 422, None], + ], +) +@pytest.mark.asyncio +async def test_create_media(test_app_asyncio, init_test_db, test_db, access_idx, payload, status_code, status_details): + + # Create a custom access token + auth = None + if isinstance(access_idx, int): + auth = await pytest.get_token(ACCESS_TABLE[access_idx]["id"], ACCESS_TABLE[access_idx]["scope"].split()) + + utc_dt = datetime.utcnow() + response = await test_app_asyncio.post("/media/", data=json.dumps(payload), headers=auth) + assert 
response.status_code == status_code + if isinstance(status_details, str): + assert response.json()["detail"] == status_details + + if response.status_code // 100 == 2: + json_response = response.json() + test_response = {"id": len(MEDIA_TABLE) + 1, **payload, "type": "image"} + assert {k: v for k, v in json_response.items() if k != "created_at"} == test_response + + new_media = await get_entry(test_db, db.media, json_response["id"]) + new_media = dict(**new_media) + + # Timestamp consistency + assert new_media["created_at"] > utc_dt and new_media["created_at"] < datetime.utcnow() + + +@pytest.mark.parametrize( + "access_idx, payload, media_id, status_code, status_details", + [ + [None, {}, 1, 401, "Not authenticated"], + [0, {"type": "video"}, 1, 403, "Your access scope is not compatible with this operation."], + [1, {"type": "video"}, 1, 200, None], + [1, {"type": "audio"}, 1, 422, None], + [1, {"type": "image"}, 999, 404, "Table media has no entry with id=999"], + [1, {"type": "audio"}, 1, 422, None], + [1, {"type": "image"}, 0, 422, None], + ], +) +@pytest.mark.asyncio +async def test_update_media( + test_app_asyncio, init_test_db, test_db, access_idx, payload, media_id, status_code, status_details +): + + # Create a custom access token + auth = None + if isinstance(access_idx, int): + auth = await pytest.get_token(ACCESS_TABLE[access_idx]["id"], ACCESS_TABLE[access_idx]["scope"].split()) + + response = await test_app_asyncio.put(f"/media/{media_id}/", data=json.dumps(payload), headers=auth) + assert response.status_code == status_code + if isinstance(status_details, str): + assert response.json()["detail"] == status_details + + if response.status_code // 100 == 2: + updated_media = await get_entry(test_db, db.media, media_id) + updated_media = dict(**updated_media) + for k, v in updated_media.items(): + if k != "bucket_key": + assert v == payload.get(k, MEDIA_TABLE_FOR_DB[media_id - 1][k]) + + +@pytest.mark.parametrize( + "access_idx, media_id, status_code, status_details", + [ + [None, 1, 401, "Not authenticated"], + [0, 1, 403, "Your access scope is not compatible with this operation."], + [1, 1, 200, None], + [1, 999, 404, "Table media has no entry with id=999"], + [1, 0, 422, None], + ], +) +@pytest.mark.asyncio +async def test_delete_media(test_app_asyncio, init_test_db, access_idx, media_id, status_code, status_details): + + # Create a custom access token + auth = None + if isinstance(access_idx, int): + auth = await pytest.get_token(ACCESS_TABLE[access_idx]["id"], ACCESS_TABLE[access_idx]["scope"].split()) + + response = await test_app_asyncio.delete(f"/media/{media_id}/", headers=auth) + assert response.status_code == status_code + if isinstance(status_details, str): + assert response.json()["detail"] == status_details + + if response.status_code // 100 == 2: + assert response.json() == MEDIA_TABLE[media_id - 1] + remaining_media = await test_app_asyncio.get("/media/", headers=auth) + assert all(entry["id"] != media_id for entry in remaining_media.json()) + + +@pytest.mark.asyncio +async def test_upload_media(test_app_asyncio, init_test_db, test_db, monkeypatch): + + admin_idx = 1 + # Create a custom access token + admin_auth = await pytest.get_token(ACCESS_TABLE[admin_idx]["id"], ACCESS_TABLE[admin_idx]["scope"].split()) + + # 1 - Create a media that will have an upload + payload = {} + new_media_id = len(MEDIA_TABLE_FOR_DB) + 1 + response = await test_app_asyncio.post("/media/", data=json.dumps(payload), headers=admin_auth) + assert response.status_code == 201 + + # 2 - 
Upload something + async def mock_upload_file(bucket_key, file_binary): + return True + + monkeypatch.setattr(media_bucket, "upload_file", mock_upload_file) + + # Download and save a temporary file + local_tmp_path = os.path.join(tempfile.gettempdir(), "my_temp_image.jpg") + img_content = requests.get("https://pyronear.org/img/logo_letters.png").content + with open(local_tmp_path, "wb") as f: + f.write(img_content) + + async def mock_get_file(bucket_key): + return local_tmp_path + + monkeypatch.setattr(media_bucket, "get_file", mock_get_file) + + async def mock_delete_file(filename): + return True + + monkeypatch.setattr(media_bucket, "delete_file", mock_delete_file) + + # Switch content-type from JSON to multipart + del admin_auth["Content-Type"] + + response = await test_app_asyncio.post( + f"/media/{new_media_id}/upload", files=dict(file=img_content), headers=admin_auth + ) + + assert response.status_code == 200, print(response.json()["detail"]) + response_json = response.json() + updated_media = await get_entry(test_db, db.media, response_json["id"]) + updated_media = dict(**updated_media) + response_json.pop("created_at") + assert {k: v for k, v in updated_media.items() if k not in ("created_at", "bucket_key")} == response_json + assert updated_media["bucket_key"] is not None + + # 2b - Upload failing + async def failing_upload(bucket_key, file_binary): + return False + + monkeypatch.setattr(media_bucket, "upload_file", failing_upload) + response = await test_app_asyncio.post(f"/media/{new_media_id}/upload", files=dict(file="bar"), headers=admin_auth) + assert response.status_code == 500 diff --git a/src/tests/test_deps.py b/src/tests/test_deps.py new file mode 100644 index 0000000..bcae259 --- /dev/null +++ b/src/tests/test_deps.py @@ -0,0 +1,54 @@ +# Copyright (C) 2022, Pyronear contributors. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. 
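Mirroring test_upload_media above, a file is attached to an existing media entry through a multipart POST; a rough client-side sketch in which the URL, media id, token and file path are placeholders:

# Illustrative multipart upload to the /media/{id}/upload route
import requests

headers = {"Authorization": "Bearer <token>"}  # placeholder token

with open("local_image.jpg", "rb") as f:
    resp = requests.post(
        "http://localhost:8080/media/1/upload",
        files={"file": f},
        headers=headers,
    )
print(resp.status_code)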
+ +import pytest +from fastapi import HTTPException +from fastapi.security import SecurityScopes + +from app import db +from app.api import crud, deps, security +from app.api.schemas import AccessRead +from tests.db_utils import fill_table + +ACCESS_TABLE = [ + {"id": 1, "login": "first_user", "hashed_password": "first_pwd_hashed", "scope": "user"}, + {"id": 2, "login": "connected_user", "hashed_password": "first_pwd_hashed", "scope": "user"}, + {"id": 3, "login": "first_device", "hashed_password": "first_pwd_hashed", "scope": "admin"}, + {"id": 4, "login": "second_device", "hashed_password": "second_pwd_hashed", "scope": "admin"}, +] + + +@pytest.fixture(scope="function") +async def init_test_db(monkeypatch, test_db): + monkeypatch.setattr(crud.base, "database", test_db) + await fill_table(test_db, db.accesses, ACCESS_TABLE) + + +@pytest.mark.parametrize( + "token_data, scope, expected_access, exception", + [ + [ACCESS_TABLE[0], "user", 0, False], + ["my_false_token", "admin", None, True], # Decoding failure + [{"id": 100, "scope": "admin"}, "admin", None, True], # Unable to find access in table + [ACCESS_TABLE[3], "admin", 3, False], # Correct + ], +) +@pytest.mark.asyncio +async def test_get_current_access(init_test_db, token_data, scope, expected_access, exception): + + # Create a token for the access we'll want to retrieve + if isinstance(token_data, str): + token = token_data + else: + _data = {"sub": str(token_data["id"]), "scopes": token_data["scope"].split()} + token = await security.create_access_token(_data) + # Check that we retrieve the correct access + if exception: + with pytest.raises(HTTPException): + access = await deps.get_current_access(SecurityScopes([scope]), token=token) + else: + access = await deps.get_current_access(SecurityScopes([scope]), token=token) + if isinstance(expected_access, int): + assert access.dict() == AccessRead(**ACCESS_TABLE[expected_access]).dict() diff --git a/src/tests/test_external.py b/src/tests/test_external.py new file mode 100644 index 0000000..48c5a05 --- /dev/null +++ b/src/tests/test_external.py @@ -0,0 +1,11 @@ +# Copyright (C) 2021, Pyronear contributors. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +from app.api.external import post_request + + +def test_post_request(): + response = post_request("https://httpbin.org/post") + assert response.status_code == 200 diff --git a/src/tests/test_security.py b/src/tests/test_security.py new file mode 100644 index 0000000..c916f0d --- /dev/null +++ b/src/tests/test_security.py @@ -0,0 +1,71 @@ +# Copyright (C) 2021, Pyronear contributors. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. 
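test_external above only checks that post_request reaches the public httpbin echo endpoint; calling the helper directly looks like this (the target URL is the same one used by the test):

from app.api.external import post_request

response = post_request("https://httpbin.org/post")
print(response.status_code)  # 200 when the endpoint is reachable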
+ +from datetime import datetime, timedelta + +import pytest +import requests +from jose import jwt + +from app import config as cfg +from app.api import security + + +@pytest.mark.asyncio +async def test_hash_password(): + + pwd1 = "my_password" + hash_pwd1 = await security.hash_password(pwd1) + + assert hash_pwd1 != pwd1 + assert hash_pwd1 != await security.hash_password(pwd1 + "bis") + # Check that it's non deterministic + assert hash_pwd1 != await security.hash_password(pwd1) + + +@pytest.mark.asyncio +async def test_verify_password(): + + pwd1 = "my_password" + hash_pwd1 = await security.hash_password(pwd1) + + assert await security.verify_password(pwd1, hash_pwd1) + assert not await security.verify_password("another_try", hash_pwd1) + + +def test_hash_content_file(): + + # Download a small file + file_url1 = "https://github.com/pyronear/pyro-api/releases/download/v0.1.1/pyronear_logo.png" + file_url2 = "https://github.com/pyronear/pyro-api/releases/download/v0.1.1/pyronear_logo_mini.png" + + # Hash it + hash1 = security.hash_content_file(requests.get(file_url1).content) + hash2 = security.hash_content_file(requests.get(file_url2).content) + + # Check data integrity + assert security.hash_content_file(requests.get(file_url1).content) == hash1 + assert hash1 != hash2 + + +@pytest.mark.parametrize( + "content, expiration, expected_delta", + [ + [{"data": "my_data"}, 60, 60], + [{"data": "my_data"}, None, cfg.ACCESS_TOKEN_EXPIRE_MINUTES], + ], +) +@pytest.mark.asyncio +async def test_create_access_token(content, expiration, expected_delta): + + delta = timedelta(minutes=expiration) if isinstance(expiration, int) else None + payload = await security.create_access_token(content, expires_delta=delta) + after = datetime.utcnow() + assert isinstance(payload, str) + decoded_data = jwt.decode(payload, cfg.SECRET_KEY) + # Verify data integrity + assert all(v == decoded_data[k] for k, v in content.items()) + # Check expiration + assert datetime.utcfromtimestamp(decoded_data["exp"]) - timedelta(minutes=expected_delta) < after diff --git a/src/tests/test_services.py b/src/tests/test_services.py new file mode 100644 index 0000000..437d009 --- /dev/null +++ b/src/tests/test_services.py @@ -0,0 +1,23 @@ +# Copyright (C) 2021, Pyronear contributors. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +from app.services import annotations_bucket, media_bucket, resolve_bucket_key +from app.services.bucket import QarnotBucket + + +def test_resolve_bucket_key(monkeypatch): + file_name = "myfile.jpg" + bucket_subfolder = "my/bucket/subfolder" + + # Same if the bucket folder is specified + assert resolve_bucket_key(file_name, bucket_subfolder) == f"{bucket_subfolder}/{file_name}" + + # Check that it returns the same thing when bucket folder is not set + assert resolve_bucket_key(file_name) == file_name + + +def test_bucket_service(): + assert isinstance(media_bucket, QarnotBucket) + assert isinstance(annotations_bucket, QarnotBucket) diff --git a/src/tests/utils.py b/src/tests/utils.py new file mode 100644 index 0000000..3c6c0ec --- /dev/null +++ b/src/tests/utils.py @@ -0,0 +1,28 @@ +# Copyright (C) 2021, Pyronear contributors. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. 
+ +from copy import deepcopy +from datetime import datetime + +DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%f" + + +def update_only_datetime(entity_as_dict): + to_return = deepcopy(entity_as_dict) + if isinstance(to_return.get("created_at"), str): + to_return["created_at"] = parse_time(to_return["created_at"]) + if isinstance(to_return.get("start_ts"), str): + to_return["start_ts"] = parse_time(to_return["start_ts"]) + if isinstance(to_return.get("end_ts"), str): + to_return["end_ts"] = parse_time(to_return["end_ts"]) + return to_return + + +def parse_time(d): + return datetime.strptime(d, DATETIME_FORMAT) + + +def ts_to_string(ts): + return datetime.strftime(ts, DATETIME_FORMAT)
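A small usage sketch of the helpers above: update_only_datetime converts the string timestamps of a fixture row into datetime objects, and ts_to_string round-trips them back:

from tests.utils import ts_to_string, update_only_datetime

entry = {"id": 1, "created_at": "2020-10-13T08:18:45.447773"}
converted = update_only_datetime(entry)
print(type(converted["created_at"]).__name__)   # datetime
print(ts_to_string(converted["created_at"]))    # back to the original string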