diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d0764eec..6bc2461c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,48 +1,17 @@ # See https://pre-commit.com for more information # See https://pre-commit.com/hooks.html for more hooks repos: - - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.4.3 - hooks: - - id: ruff - args: ['--fix', '--exit-non-zero-on-fix'] - # Extra args, only after removing flake8 and yesqa: '--extend-select', 'RUF100' - repo: https://github.com/asottile/pyupgrade rev: v3.15.2 hooks: - id: pyupgrade args: ['--keep-runtime-typing', '--py311-plus'] - - repo: https://github.com/asottile/yesqa - rev: v1.5.0 - hooks: - - id: yesqa - additional_dependencies: &flake8deps - - bandit - - flake8-assertive - - flake8-blind-except - - flake8-bugbear - - flake8-builtins - - flake8-comprehensions - - flake8-docstrings - - flake8-isort - - flake8-logging-format - - flake8-mutable - - flake8-plugin-utils - - flake8-print - - flake8-pytest-style - - pep8-naming - - toml - - tomli - - repo: https://github.com/PyCQA/isort - rev: 5.13.2 - hooks: - - id: isort - additional_dependencies: - - toml - - repo: https://github.com/psf/black - rev: 24.4.2 + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.4.3 hooks: - - id: black + - id: ruff + args: ['--fix', '--exit-non-zero-on-fix'] + - id: ruff-format - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.6.0 hooks: @@ -87,11 +56,6 @@ repos: - types-requests - types-setuptools - typing-extensions - - repo: https://github.com/PyCQA/flake8 - rev: 7.0.0 - hooks: - - id: flake8 - additional_dependencies: *flake8deps - repo: https://github.com/PyCQA/pylint rev: v3.1.0 hooks: diff --git a/boxoffice/__init__.py b/boxoffice/__init__.py index 63d37520..371a3e9d 100644 --- a/boxoffice/__init__.py +++ b/boxoffice/__init__.py @@ -9,12 +9,12 @@ from flask_rq2 import RQ from pytz import timezone +import coaster.app from baseframe import Version, baseframe from baseframe.utils import JSONProvider from coaster.assets import WebpackManifest from flask_lastuser import Lastuser from flask_lastuser.sqlalchemy import UserManager -import coaster.app from ._version import __version__ diff --git a/boxoffice/cli.py b/boxoffice/cli.py index 627695e2..08c70b42 100644 --- a/boxoffice/cli.py +++ b/boxoffice/cli.py @@ -2,7 +2,7 @@ @app.cli.command('dbconfig') -def dbconfig(): +def dbconfig() -> None: """Show required database configuration.""" print( # noqa: T201 ''' diff --git a/boxoffice/data/__init__.py b/boxoffice/data/__init__.py index d6f78bc1..ac12f9ad 100644 --- a/boxoffice/data/__init__.py +++ b/boxoffice/data/__init__.py @@ -1,52 +1,56 @@ -from typing_extensions import TypedDict +from dataclasses import dataclass -class StateList(TypedDict): +@dataclass +class GstState: name: str - short_code: int - short_code_text: str + title: str + code: int -indian_states: list[StateList] = [ - {'short_code': 1, 'name': 'Jammu and Kashmir', 'short_code_text': 'JK'}, - {'short_code': 2, 'name': 'Himachal Pradesh', 'short_code_text': 'HP'}, - {'short_code': 3, 'name': 'Punjab', 'short_code_text': 'PB'}, - {'short_code': 4, 'name': 'Chandigarh', 'short_code_text': 'CH'}, - {'short_code': 5, 'name': 'Uttarakhand', 'short_code_text': 'UT'}, - {'short_code': 6, 'name': 'Haryana', 'short_code_text': 'HR'}, - {'short_code': 7, 'name': 'Delhi', 'short_code_text': 'DL'}, - {'short_code': 8, 'name': 'Rajasthan', 'short_code_text': 'RJ'}, - {'short_code': 9, 'name': 'Uttar Pradesh', 'short_code_text': 'UP'}, - {'short_code': 10, 'name': 'Bihar', 'short_code_text': 'BR'}, - {'short_code': 11, 'name': 'Sikkim', 'short_code_text': 'SK'}, - {'short_code': 12, 'name': 'Arunachal Pradesh', 'short_code_text': 'AR'}, - {'short_code': 13, 'name': 'Nagaland', 'short_code_text': 'NL'}, - {'short_code': 14, 'name': 'Manipur', 'short_code_text': 'MN'}, - {'short_code': 15, 'name': 'Mizoram', 'short_code_text': 'MZ'}, - {'short_code': 16, 'name': 'Tripura', 'short_code_text': 'TR'}, - {'short_code': 17, 'name': 'Meghalaya', 'short_code_text': 'ML'}, - {'short_code': 18, 'name': 'Assam', 'short_code_text': 'AS'}, - {'short_code': 19, 'name': 'West Bengal', 'short_code_text': 'WB'}, - {'short_code': 20, 'name': 'Jharkhand', 'short_code_text': 'JH'}, - {'short_code': 21, 'name': 'Odisha', 'short_code_text': 'OR'}, - {'short_code': 22, 'name': 'Chhattisgarh', 'short_code_text': 'CT'}, - {'short_code': 23, 'name': 'Madhya Pradesh', 'short_code_text': 'MP'}, - {'short_code': 24, 'name': 'Gujarat', 'short_code_text': 'GJ'}, - {'short_code': 25, 'name': 'Daman and Diu', 'short_code_text': 'DD'}, - {'short_code': 26, 'name': 'Dadra and Nagar Haveli', 'short_code_text': 'DN'}, - {'short_code': 27, 'name': 'Maharashtra', 'short_code_text': 'MH'}, - {'short_code': 28, 'name': 'Andhra Pradesh (old)', 'short_code_text': 'AP'}, - {'short_code': 29, 'name': 'Karnataka', 'short_code_text': 'KA'}, - {'short_code': 30, 'name': 'Goa', 'short_code_text': 'GA'}, - {'short_code': 31, 'name': 'Lakshadweep', 'short_code_text': 'LD'}, - {'short_code': 32, 'name': 'Kerala', 'short_code_text': 'KL'}, - {'short_code': 33, 'name': 'Tamil Nadu', 'short_code_text': 'TN'}, - {'short_code': 34, 'name': 'Puducherry', 'short_code_text': 'PY'}, - {'short_code': 35, 'name': 'Andaman and Nicobar Islands', 'short_code_text': 'AN'}, - {'short_code': 36, 'name': 'Telangana', 'short_code_text': 'TG'}, - {'short_code': 37, 'name': 'Andhra Pradesh', 'short_code_text': 'AD'}, +indian_states: list[GstState] = [ + GstState(code=1, name='JK', title='Jammu and Kashmir'), + GstState(code=2, name='HP', title='Himachal Pradesh'), + GstState(code=3, name='PB', title='Punjab'), + GstState(code=4, name='CH', title='Chandigarh'), + GstState(code=5, name='UT', title='Uttarakhand'), + GstState(code=6, name='HR', title='Haryana'), + GstState(code=7, name='DL', title='Delhi'), + GstState(code=8, name='RJ', title='Rajasthan'), + GstState(code=9, name='UP', title='Uttar Pradesh'), + GstState(code=10, name='BR', title='Bihar'), + GstState(code=11, name='SK', title='Sikkim'), + GstState(code=12, name='AR', title='Arunachal Pradesh'), + GstState(code=13, name='NL', title='Nagaland'), + GstState(code=14, name='MN', title='Manipur'), + GstState(code=15, name='MZ', title='Mizoram'), + GstState(code=16, name='TR', title='Tripura'), + GstState(code=17, name='ML', title='Meghalaya'), + GstState(code=18, name='AS', title='Assam'), + GstState(code=19, name='WB', title='West Bengal'), + GstState(code=20, name='JH', title='Jharkhand'), + GstState(code=21, name='OR', title='Odisha'), + GstState(code=22, name='CG', title='Chattisgarh'), + GstState(code=23, name='MP', title='Madhya Pradesh'), + GstState(code=24, name='GJ', title='Gujarat'), + GstState(code=25, name='DD', title='Daman and Diu (old)'), + GstState(code=26, name='DN', title='Dadra, Nagar Haveli, Daman and Diu'), + GstState(code=27, name='MH', title='Maharashtra'), + GstState(code=28, name='AP', title='Andhra Pradesh (old)'), + GstState(code=29, name='KA', title='Karnataka'), + GstState(code=30, name='GA', title='Goa'), + GstState(code=31, name='LD', title='Lakshadweep'), + GstState(code=32, name='KL', title='Kerala'), + GstState(code=33, name='TN', title='Tamil Nadu'), + GstState(code=34, name='PY', title='Puducherry'), + GstState(code=35, name='AN', title='Andaman and Nicobar Islands'), + GstState(code=36, name='TG', title='Telangana'), + GstState(code=37, name='AD', title='Andhra Pradesh'), + GstState(code=38, name='LA', title='Ladakh'), ] +indian_states.sort(key=lambda s: (s.title, s.code) if s.code < 90 else ('ZZ', s.code)) -indian_states_dict = {d["short_code_text"]: d for d in indian_states} -short_codes = [state['short_code'] for state in indian_states] +indian_states_dict = {state.name: state for state in indian_states} + +codes = [state.code for state in indian_states] diff --git a/boxoffice/extapi/razorpay.py b/boxoffice/extapi/razorpay.py index 73d479ef..c2aaa3c8 100644 --- a/boxoffice/extapi/razorpay.py +++ b/boxoffice/extapi/razorpay.py @@ -1,3 +1,7 @@ +from datetime import tzinfo +from decimal import Decimal +from typing import Any, TypedDict + import requests from baseframe import localize_timezone @@ -9,35 +13,38 @@ base_url = 'https://api.razorpay.com/v1' -def capture_payment(paymentid, amount): +class YearMonth(TypedDict): + year: int + month: int + + +def capture_payment(paymentid: str, amount: Decimal) -> requests.Response: """Attempt to capture the payment from Razorpay.""" verify_https = app.config.get('VERIFY_RAZORPAY_HTTPS', True) url = f'{base_url}/payments/{paymentid}/capture' # Razorpay requires the amount to be in paisa and of type integer - resp = requests.post( + return requests.post( url, data={'amount': int(amount * 100)}, auth=(app.config['RAZORPAY_KEY_ID'], app.config['RAZORPAY_KEY_SECRET']), verify=verify_https, timeout=30, ) - return resp -def refund_payment(paymentid, amount): +def refund_payment(paymentid: str, amount: Decimal) -> requests.Response: """Send a POST request to Razorpay to initiate a refund.""" url = f'{base_url}/payments/{paymentid}/refund' # Razorpay requires the amount to be in paisa and of type integer - resp = requests.post( + return requests.post( url, data={'amount': int(amount * 100)}, auth=(app.config['RAZORPAY_KEY_ID'], app.config['RAZORPAY_KEY_SECRET']), timeout=30, ) - return resp -def get_settlements(date_range): +def get_settlements(date_range: YearMonth) -> Any: url = f'{base_url}/settlements/recon/combined' resp = requests.get( url, @@ -48,7 +55,9 @@ def get_settlements(date_range): return resp.json() -def get_settled_transactions(date_range, tz=None): +def get_settled_transactions( + date_range: YearMonth, tz: str | tzinfo | None = None +) -> tuple[list[str], list]: if not tz: tz = app.config['TIMEZONE'] settled_transactions = get_settlements(date_range) @@ -96,6 +105,7 @@ def get_settled_transactions(date_range, tz=None): ).one_or_none() if payment: order = payment.order + assert order.paid_at is not None # noqa: S101 # nosec B101 rows.append( { 'settlement_id': settled_transaction['settlement_id'], diff --git a/boxoffice/forms/menu.py b/boxoffice/forms/menu.py index fef06201..70c68ec0 100644 --- a/boxoffice/forms/menu.py +++ b/boxoffice/forms/menu.py @@ -27,7 +27,7 @@ class MenuForm(forms.Form): __("State"), description=__("State of supply"), coerce=int, - default=indian_states_dict['KA']['short_code'], + default=indian_states_dict['KA'].code, validators=[forms.validators.DataRequired(__("Please select a state"))], ) place_supply_country_code = forms.SelectField( @@ -39,10 +39,9 @@ class MenuForm(forms.Form): def __post_init__(self) -> None: self.place_supply_state_code.choices = [(0, '')] + [ - (state['short_code'], state['name']) - for state in sorted(indian_states, key=lambda k: k['name']) + (state.code, state.title) for state in indian_states ] - self.place_supply_country_code.choices = [('', '')] + localized_country_list() + self.place_supply_country_code.choices = [('', ''), *localized_country_list()] def validate_place_supply_state_code(self, field: forms.Field) -> None: if field.data <= 0: diff --git a/boxoffice/forms/order.py b/boxoffice/forms/order.py index b31aa6d5..9571c009 100644 --- a/boxoffice/forms/order.py +++ b/boxoffice/forms/order.py @@ -1,13 +1,13 @@ from __future__ import annotations from collections.abc import Callable -from typing import Any +from typing import Any, Self from werkzeug.datastructures import ImmutableMultiDict from baseframe import __, forms -from ..data import indian_states_dict, short_codes +from ..data import codes as gst_codes, indian_states_dict __all__ = ['LineItemForm', 'BuyerForm', 'OrderSessionForm', 'InvoiceForm'] @@ -36,7 +36,7 @@ class LineItemForm(forms.Form): ) @classmethod - def process_list(cls, line_items_json: list[Any]): + def process_list(cls, line_items_json: list[Any]) -> list[Self]: """ Return a list of LineItemForm objects. @@ -88,18 +88,18 @@ class OrderSessionForm(forms.Form): utm_term = forms.StringField(__("UTM Term"), filters=[trim(250)]) utm_content = forms.StringField(__("UTM Content"), filters=[trim(250)]) utm_id = forms.StringField(__("UTM Id"), filters=[trim(250)]) - gclid = forms.StringField(__("Gclid"), filters=[trim(250)]) + gclid = forms.StringField(__("Google Click Id"), filters=[trim(250)]) referrer = forms.StringField(__("Referrer"), filters=[trim(2083)]) host = forms.StringField(__("Host"), filters=[trim(2083)]) -def validate_state_code(form, field: forms.Field) -> None: +def validate_state_code(form: InvoiceForm, field: forms.Field) -> None: # Note: state_code is only a required field if the chosen country is India if form.country_code.data == "IN" and field.data.upper() not in indian_states_dict: raise forms.validators.StopValidation(__("Please select a state")) -def validate_gstin(_form, field: forms.Field) -> None: +def validate_gstin(_form: forms.Form, field: forms.Field) -> None: """ Raise a StopValidation exception if the supplied field's data is not a valid GSTIN. @@ -115,7 +115,7 @@ def validate_gstin(_form, field: forms.Field) -> None: # 15 length, first 2 digits, valid pan, checksum if ( len(field.data) != 15 - or int(field.data[:2]) not in short_codes + or int(field.data[:2]) not in gst_codes or not field.data[2:12].isalnum() or not field.data[-1].isalnum() ): diff --git a/boxoffice/forms/ticket.py b/boxoffice/forms/ticket.py index 6af39f05..b8eb360d 100644 --- a/boxoffice/forms/ticket.py +++ b/boxoffice/forms/ticket.py @@ -81,7 +81,7 @@ class TicketForm(forms.Form): __("State"), description=__("State of supply"), coerce=int, - default=indian_states_dict['KA']['short_code'], + default=indian_states_dict['KA'].code, validators=[forms.validators.DataRequired(__("Please select a state"))], ) place_supply_country_code = forms.SelectField( @@ -93,10 +93,9 @@ class TicketForm(forms.Form): def __post_init__(self) -> None: self.place_supply_state_code.choices = [(0, '')] + [ - (state['short_code'], state['name']) - for state in sorted(indian_states, key=lambda k: k['name']) + (state.code, state.title) for state in indian_states ] - self.place_supply_country_code.choices = [('', '')] + localized_country_list() + self.place_supply_country_code.choices = [('', ''), *localized_country_list()] self.category.query = ( Category.query.join(Menu, Category.menu_id == Menu.id) .filter(Category.menu == self.edit_parent) diff --git a/boxoffice/mailclient.py b/boxoffice/mailclient.py index 45ac48ef..32bd3b65 100644 --- a/boxoffice/mailclient.py +++ b/boxoffice/mailclient.py @@ -18,14 +18,15 @@ def send_receipt_mail( order_id: UUID, subject: str | None = None, template: str = 'order_confirmation_mail.html.jinja2', -): +) -> None: """Send buyer a link to fill attendee details and get cash receipt.""" with app.test_request_context(): if subject is None: subject = _("Thank you for your order!") order = Order.query.get(order_id) if order is None: - raise ValueError(f"Unable to find Order with id={order_id!r}") + err = f"Unable to find Order with id={order_id!r}" + raise ValueError(err) msg = Message( subject=subject, recipients=[order.buyer_email], @@ -59,13 +60,14 @@ def send_participant_assignment_mail( menu_title: str, team_member: str, subject: str | None = None, -): +) -> None: with app.test_request_context(): if subject is None: subject = _("Please tell us who's coming!") order = Order.query.get(order_id) if order is None: - raise ValueError(f"Unable to find Order with id={order_id!r}") + err = f"Unable to find Order with id={order_id!r}" + raise ValueError(err) msg = Message( subject=subject, recipients=[order.buyer_email], @@ -89,13 +91,14 @@ def send_participant_assignment_mail( @rq.job('boxoffice') def send_line_item_cancellation_mail( line_item_id: UUID, refund_amount: Decimal, subject: str | None = None -): +) -> None: with app.test_request_context(): if subject is None: subject = _("Ticket Cancellation") line_item = LineItem.query.get(line_item_id) if line_item is None: - raise ValueError(f"Unable to find LineItem with id={line_item_id!r}") + err = f"Unable to find LineItem with id={line_item_id!r}" + raise ValueError(err) ticket_title = line_item.ticket.title order = line_item.order is_paid = line_item.final_amount > Decimal('0') @@ -126,11 +129,12 @@ def send_line_item_cancellation_mail( @rq.job('boxoffice') def send_order_refund_mail( order_id: UUID, refund_amount: Decimal, note_to_user: MarkdownComposite -): +) -> None: with app.test_request_context(): order = Order.query.get(order_id) if order is None: - raise ValueError(f"Unable to find Order with id={order_id!r}") + err = f"Unable to find Order with id={order_id!r}" + raise ValueError(err) subject = _("{menu_title}: Refund for receipt no. {receipt_no}").format( menu_title=order.menu.title, receipt_no=order.receipt_no, @@ -158,16 +162,16 @@ def send_order_refund_mail( @rq.job('boxoffice') -def send_ticket_assignment_mail(line_item_id: UUID): +def send_ticket_assignment_mail(line_item_id: UUID) -> None: """Send a confirmation email when ticket has been assigned.""" with app.test_request_context(): line_item = LineItem.query.get(line_item_id) if line_item is None: - raise ValueError(f"Unable to find LineItem with id={line_item_id!r}") + err = f"Unable to find LineItem with id={line_item_id!r}" + raise ValueError(err) if line_item.assignee is None: - raise ValueError( - f"LineItem.assignee is None for LineItem.id={line_item_id!r}" - ) + err = f"LineItem.assignee is None for LineItem.id={line_item_id!r}" + raise ValueError(err) order = line_item.order subject = _("{title}: Here's your ticket").format(title=order.menu.title) msg = Message( @@ -192,17 +196,18 @@ def send_ticket_assignment_mail(line_item_id: UUID): @rq.job('boxoffice') def send_ticket_reassignment_mail( line_item_id: UUID, old_assignee_id: UUID, new_assignee_id: UUID -): +) -> None: """Send notice of reassignment of ticket.""" with app.test_request_context(): line_item = LineItem.query.get(line_item_id) old_assignee = Assignee.query.get(old_assignee_id) new_assignee = Assignee.query.get(new_assignee_id) if line_item is None or old_assignee is None or new_assignee is None: - raise ValueError( + err = ( f"Unexpected None value in line_item={line_item!r}," f" old_assignee={old_assignee!r}, new_assignee={new_assignee!r}" ) + raise ValueError(err) order = line_item.order subject = _("{title}: Your ticket has been transferred to someone else").format( diff --git a/boxoffice/models/__init__.py b/boxoffice/models/__init__.py index 9b67abd8..4126f000 100644 --- a/boxoffice/models/__init__.py +++ b/boxoffice/models/__init__.py @@ -3,10 +3,10 @@ from datetime import datetime from typing import Annotated, TypeAlias +import sqlalchemy as sa from flask_sqlalchemy import SQLAlchemy from sqlalchemy.dialects import postgresql from sqlalchemy.orm import DeclarativeBase, Mapped -import sqlalchemy as sa from coaster.sqlalchemy import ( AppenderQuery, diff --git a/boxoffice/models/category.py b/boxoffice/models/category.py index 462bb8f4..fb83f271 100644 --- a/boxoffice/models/category.py +++ b/boxoffice/models/category.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, ClassVar from uuid import UUID from coaster.sqlalchemy import role_check @@ -30,10 +30,12 @@ class Category(BaseScopedNameMixin[int, User], Model): ) parent: Mapped[Menu] = sa.orm.synonym('menu') - __roles__ = {'category_owner': {'read': {'id', 'name', 'title', 'menu_id'}}} + __roles__: ClassVar = { + 'category_owner': {'read': {'id', 'name', 'title', 'menu_id'}} + } @role_check('category_owner') - def has_category_owner_role(self, actor: User | None, _anchors=()) -> bool: + def has_category_owner_role(self, actor: User | None, _anchors: Any = ()) -> bool: return ( actor is not None and self.menu.organization.userid in actor.organizations_owned_ids() diff --git a/boxoffice/models/discount_policy.py b/boxoffice/models/discount_policy.py index 8fdeb5d1..a0f82f13 100644 --- a/boxoffice/models/discount_policy.py +++ b/boxoffice/models/discount_policy.py @@ -2,12 +2,12 @@ from __future__ import annotations +import secrets +import string from collections.abc import Sequence from dataclasses import dataclass -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, ClassVar, Self from uuid import UUID -import secrets -import string from itsdangerous import BadSignature, Signer from sqlalchemy.orm.exc import MultipleResultsFound @@ -114,7 +114,7 @@ class DiscountPolicy(BaseScopedNameMixin[UUID, User], Model): lazy='dynamic', back_populates='discount_policy' ) - __roles__ = { + __roles__: ClassVar = { 'dp_owner': { 'read': { 'id', @@ -133,7 +133,7 @@ class DiscountPolicy(BaseScopedNameMixin[UUID, User], Model): } @role_check('dp_owner') - def has_dp_owner_role(self, actor: User | None, _anchors=()) -> bool: + def has_dp_owner_role(self, actor: User | None, _anchors: Any = ()) -> bool: return ( actor is not None and self.organization.userid in actor.organizations_owned_ids() @@ -159,7 +159,8 @@ def gen_signed_code(self, identifier: str | None = None) -> str: Format: ``discount_code_base.randint.signature`` """ if not self.secret: - raise TypeError("DiscountPolicy.secret is unset") + msg = "DiscountPolicy.secret is unset" + raise TypeError(msg) if not identifier: identifier = secrets.token_urlsafe(16) signer = Signer(self.secret) @@ -167,12 +168,14 @@ def gen_signed_code(self, identifier: str | None = None) -> str: return signer.sign(key).decode('utf-8') @staticmethod - def is_signed_code_format(code) -> bool: + def is_signed_code_format(code: str) -> bool: """Check if the code is in the {x.y.z} format.""" return len(code.split('.')) == 3 if code else False @classmethod - def get_from_signed_code(cls, code, organization_id) -> DiscountPolicy | None: + def get_from_signed_code( + cls, code: str, organization_id: int + ) -> DiscountPolicy | None: """Return a discount policy given a valid signed code, None otherwise.""" if not cls.is_signed_code_format(code): return None @@ -190,7 +193,7 @@ def get_from_signed_code(cls, code, organization_id) -> DiscountPolicy | None: return None @classmethod - def make_bulk(cls, discount_code_base, **kwargs) -> DiscountPolicy: + def make_bulk(cls, discount_code_base: str | None, **kwargs) -> Self: """Return a discount policy for bulk discount coupons.""" return cls( discount_type=DiscountTypeEnum.COUPON, @@ -201,7 +204,7 @@ def make_bulk(cls, discount_code_base, **kwargs) -> DiscountPolicy: @classmethod def get_from_ticket( - cls, ticket: Ticket, qty, coupon_codes: Sequence[str] = () + cls, ticket: Ticket, qty: int, coupon_codes: Sequence[str] = () ) -> list[PolicyCoupon]: """ Return a list of (discount_policy, discount_coupon) tuples. @@ -250,13 +253,13 @@ def get_from_ticket( return policies @property - def line_items_count(self): + def line_items_count(self) -> int: return self.line_items.filter( LineItem.status == LineItemStatus.CONFIRMED ).count() @classmethod - def is_valid_access_coupon(cls, ticket: Ticket, code_list): + def is_valid_access_coupon(cls, ticket: Ticket, code_list: list[str]) -> bool: """ Check if any of code_list is a valid access code for the specified ticket. @@ -292,12 +295,17 @@ def is_valid_access_coupon(cls, ticket: Ticket, code_list): @sa.event.listens_for(DiscountPolicy, 'before_update') @sa.event.listens_for(DiscountPolicy, 'before_insert') -def validate_price_based_discount(_mapper, _connection, target: DiscountPolicy): +def validate_price_based_discount( + _mapper: Any, _connection: Any, target: DiscountPolicy +) -> None: if target.is_price_based and len(target.tickets) > 1: - raise ValueError("Price-based discounts MUST have only one associated ticket") + msg = "Price-based discounts MUST have only one associated ticket" + raise ValueError(msg) -def generate_coupon_code(size=6, chars=string.ascii_uppercase + string.digits): +def generate_coupon_code( + size: int = 6, chars: str = string.ascii_uppercase + string.digits +) -> str: return ''.join(secrets.choice(chars) for _ in range(size)) @@ -305,7 +313,7 @@ class DiscountCoupon(IdMixin[UUID], Model): __tablename__ = 'discount_coupon' __table_args__ = (sa.UniqueConstraint('discount_policy_id', 'code'),) - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: self.id = uuid1mc() super().__init__(*args, **kwargs) @@ -324,7 +332,7 @@ def __init__(self, *args, **kwargs): line_items: Mapped[list[LineItem]] = relationship(back_populates='discount_coupon') @classmethod - def is_signed_code_usable(cls, policy, code): + def is_signed_code_usable(cls, policy: DiscountPolicy, code: str) -> bool: obj = cls.query.filter( cls.discount_policy == policy, cls.code == code, @@ -334,10 +342,10 @@ def is_signed_code_usable(cls, policy, code): return False return True - def update_used_count(self): + def update_used_count(self) -> None: self.used_count = ( sa.select(sa.func.count()) - .where(LineItem.discount_coupon == self) + .where(LineItem.discount_coupon_id == self.id) .where(LineItem.status == LineItemStatus.CONFIRMED) .as_scalar() ) diff --git a/boxoffice/models/enums.py b/boxoffice/models/enums.py index 8c6bf9f7..87f8967e 100644 --- a/boxoffice/models/enums.py +++ b/boxoffice/models/enums.py @@ -3,6 +3,7 @@ from __future__ import annotations from enum import IntEnum, StrEnum +from typing import Final from baseframe import __ @@ -23,10 +24,10 @@ class DiscountTypeEnum(IntEnum): AUTOMATIC = 0 COUPON = 1 - __titles__ = {AUTOMATIC: __("Automatic"), COUPON: __("Coupon")} + __titles__: Final = {AUTOMATIC: __("Automatic"), COUPON: __("Coupon")} - def __init__(self, value: int): - self.title = self.__titles__[value] + def __init__(self, value: int) -> None: + self.title = self.__titles__[value] # pylint: disable=unsubscriptable-object class InvoiceStatus(IntEnum): diff --git a/boxoffice/models/invoice.py b/boxoffice/models/invoice.py index af827fe9..59dec3b9 100644 --- a/boxoffice/models/invoice.py +++ b/boxoffice/models/invoice.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from datetime import datetime +from typing import TYPE_CHECKING, Any, ClassVar from uuid import UUID from baseframe import _ @@ -15,12 +16,14 @@ __all__ = ['Invoice'] -def gen_invoice_no(organization, jurisdiction, invoice_dt): +def gen_invoice_no( + organization: Organization, jurisdiction: str, invoice_dt: datetime +) -> sa.ScalarSelect[int]: """Generate a sequential invoice number for the organization and financial year.""" fy_start_at, fy_end_at = get_fiscal_year(jurisdiction, invoice_dt) return ( sa.select(sa.func.coalesce(sa.func.max(Invoice.invoice_no + 1), 1)) - .where(Invoice.organization == organization) + .where(Invoice.organization_id == organization.id) .where(Invoice.invoiced_at >= fy_start_at) .where(Invoice.invoiced_at < fy_end_at) .scalar_subquery() @@ -44,7 +47,7 @@ class Invoice(UuidMixin, BaseMixin[UUID, User], Model): invoice_no: Mapped[int | None] fy_start_at: Mapped[timestamptz] fy_end_at: Mapped[timestamptz] - invoiced_at: Mapped[timestamptz | None] + invoiced_at: Mapped[timestamptz] street_address_1: Mapped[str | None] = sa.orm.mapped_column(sa.Unicode(255)) street_address_2: Mapped[str | None] = sa.orm.mapped_column(sa.Unicode(255)) city: Mapped[str | None] = sa.orm.mapped_column(sa.Unicode(255)) @@ -71,7 +74,7 @@ class Invoice(UuidMixin, BaseMixin[UUID, User], Model): ) organization: Mapped[Organization] = relationship(back_populates='invoices') - __roles__ = { + __roles__: ClassVar = { 'invoicer': { 'read': { 'status', @@ -92,29 +95,35 @@ class Invoice(UuidMixin, BaseMixin[UUID, User], Model): } @role_check('invoicer') - def has_invoicer_role(self, actor: User | None, _anchors=()) -> bool: + def has_invoicer_role(self, actor: User | None, _anchors: Any = ()) -> bool: return ( actor is not None and self.organization.userid in actor.organizations_owned_ids() ) - def __init__(self, *args, **kwargs): - organization = kwargs.get('organization') - country_code = kwargs.get('country_code') + def __init__(self, *args, **kwargs) -> None: + organization = kwargs.pop('organization', None) + country_code = kwargs.pop('country_code', None) + if kwargs.pop('invoiced_at', None): + msg = "Don't pass an invoiced_at value" + raise ValueError(msg) if not country_code: # Default to India country_code = 'IN' if not organization: - raise ValueError("Invoice MUST be initialized with an organization") + msg = "Invoice MUST be initialized with an organization" + raise ValueError(msg) + super().__init__( + *args, organization=organization, country_code=country_code, **kwargs + ) self.invoiced_at = utcnow() self.fy_start_at, self.fy_end_at = get_fiscal_year( country_code, self.invoiced_at ) self.invoice_no = gen_invoice_no(organization, country_code, self.invoiced_at) - super().__init__(*args, **kwargs) @property - def is_final(self): + def is_final(self) -> bool: return self.status == InvoiceStatus.FINAL @sa.orm.validates( @@ -135,11 +144,12 @@ def is_final(self): 'customer_order_id', 'organization_id', ) - def validate_immutable_final_invoice(self, key, val): + def validate_immutable_final_invoice(self, key: str, val: Any) -> Any: if self.status == InvoiceStatus.FINAL: - raise ValueError( - _("`{attr}` cannot be modified in a finalized invoice").format(attr=key) + msg = _("`{attr}` cannot be modified in a finalized invoice").format( + attr=key ) + raise ValueError(msg) return val diff --git a/boxoffice/models/line_item.py b/boxoffice/models/line_item.py index b90df7e3..8f4d6a6e 100644 --- a/boxoffice/models/line_item.py +++ b/boxoffice/models/line_item.py @@ -19,13 +19,13 @@ from .enums import LineItemStatus from .user import User -__all__ = ['LineItem', 'Assignee'] +__all__ = ['LineItemTuple', 'LineItem', 'Assignee'] class LineItemTuple(NamedTuple): """Duck-type for LineItem.""" - id: UUID | None # noqa: A003 + id: UUID | None ticket_id: UUID base_amount: Decimal | None discount_policy_id: UUID | None = None @@ -116,7 +116,7 @@ class LineItem(BaseMixin[UUID, User], Model): cascade='all, delete-orphan', lazy='dynamic', back_populates='line_item' ) - def permissions(self, actor, inherited=None): + def permissions(self, actor: User, inherited: set[str] | None = None) -> set[str]: perms = super().permissions(actor, inherited) if self.order.organization.userid in actor.organizations_owned_ids(): perms.add('org_admin') @@ -177,7 +177,7 @@ def calculate( return calculated_line_items - def confirm(self): + def confirm(self) -> None: self.status = LineItemStatus.CONFIRMED assignee: Mapped[Assignee | None] = relationship( @@ -198,7 +198,7 @@ def current_assignee(self) -> Assignee | None: return self.assignees.filter(Assignee.current.is_(True)).one_or_none() @property - def is_transferable(self): + def is_transferable(self) -> bool: tz = current_app.config['tz'] now = localize_timezone(utcnow(), tz) if self.assignee is None: @@ -214,27 +214,27 @@ def is_transferable(self): ) @property - def is_confirmed(self): + def is_confirmed(self) -> bool: return self.status == LineItemStatus.CONFIRMED @property - def is_cancelled(self): + def is_cancelled(self) -> bool: return self.status == LineItemStatus.CANCELLED @property - def is_free(self): + def is_free(self) -> bool: return self.final_amount == Decimal('0') - def cancel(self): + def cancel(self) -> None: """Set status and cancelled_at.""" self.status = LineItemStatus.CANCELLED self.cancelled_at = sa.func.utcnow() - def make_void(self): + def make_void(self) -> None: self.status = LineItemStatus.VOID self.cancelled_at = sa.func.utcnow() - def is_cancellable(self): + def is_cancellable(self) -> bool: tz = current_app.config['tz'] now = localize_timezone(utcnow(), tz) return self.is_confirmed and ( @@ -325,7 +325,7 @@ def sales_by_date( def calculate_weekly_sales( menu_ids: Sequence[str | UUID], user_tz: str | tzinfo, year: int -): +) -> OrderedDict[int, int]: """Calculate weekly sales for a year in the given menu_ids.""" ordered_week_sales = OrderedDict() for year_week in Week.weeks_of_year(year): @@ -363,14 +363,14 @@ def calculate_weekly_sales( return ordered_week_sales -def sales_delta(user_tz: tzinfo, ticket_ids: Sequence[str]): +def sales_delta(user_tz: tzinfo, ticket_ids: Sequence[str]) -> Decimal: """Calculate the percentage difference in sales between today and yesterday.""" today = utcnow().date() yesterday = today - timedelta(days=1) today_sales = sales_by_date(today, ticket_ids, user_tz) yesterday_sales = sales_by_date(yesterday, ticket_ids, user_tz) if not today_sales or not yesterday_sales: - return 0 + return Decimal('0') return round(Decimal('100') * (today_sales - yesterday_sales) / yesterday_sales, 2) diff --git a/boxoffice/models/line_item_discounter.py b/boxoffice/models/line_item_discounter.py index f02a399a..a9bba840 100644 --- a/boxoffice/models/line_item_discounter.py +++ b/boxoffice/models/line_item_discounter.py @@ -1,9 +1,9 @@ from __future__ import annotations +import itertools from collections.abc import Sequence from decimal import Decimal from typing import cast -import itertools from .discount_policy import DiscountCoupon, DiscountPolicy, PolicyCoupon from .line_item import LineItemTuple @@ -21,7 +21,8 @@ def get_discounted_line_items( if not line_items: return [] if len({line_item.ticket_id for line_item in line_items}) > 1: - raise ValueError("line_items must be of the same ticket_id") + msg = "line_items must be of the same ticket_id" + raise ValueError(msg) valid_discounts = self.get_valid_discounts(line_items, coupons) if len(valid_discounts) > 1: diff --git a/boxoffice/models/menu.py b/boxoffice/models/menu.py index 150ea570..4f2a571d 100644 --- a/boxoffice/models/menu.py +++ b/boxoffice/models/menu.py @@ -1,7 +1,7 @@ from __future__ import annotations from decimal import Decimal -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING, Any, ClassVar, cast from uuid import UUID from sqlalchemy.ext.orderinglist import ordering_list @@ -61,10 +61,10 @@ class Menu(BaseScopedNameMixin[UUID, User], Model): cascade='all, delete-orphan', lazy='dynamic', back_populates='menu' ) - __roles__ = {'ic_owner': {'read': {'id', 'name', 'title', 'description'}}} + __roles__: ClassVar = {'ic_owner': {'read': {'id', 'name', 'title', 'description'}}} @role_check('ic_owner') - def has_ic_owner_role(self, actor: User | None, _anchors=()) -> bool: + def has_ic_owner_role(self, actor: User | None, _anchors: Any = ()) -> bool: return ( actor is not None and self.organization.userid in actor.organizations_owned_ids() @@ -163,12 +163,12 @@ def fetch_all_details(self) -> HeadersAndDataTuple: db.session.execute(line_item_query).fetchall(), ) - def fetch_assignee_details(self): + def fetch_assignee_details(self) -> HeadersAndDataTuple: """ Return assignee details for all ordered tickets in the menu. Includes receipt_no, ticket title, assignee fullname, assignee email, assignee - phone and assignee details for all the ordered lineitem in a given menu as a + phone and assignee details for all the ordered line items in a given menu as a tuple of (keys, rows). """ line_item_join = ( @@ -195,7 +195,7 @@ def fetch_assignee_details(self): ) .select_from(line_item_join) .where(LineItem.status == LineItemStatus.CONFIRMED) - .where(Order.menu == self) + .where(Order.menu_id == self.id) .order_by(LineItem.ordered_at) ) return HeadersAndDataTuple( diff --git a/boxoffice/models/order.py b/boxoffice/models/order.py index 013bf4db..6aa2b8ef 100644 --- a/boxoffice/models/order.py +++ b/boxoffice/models/order.py @@ -1,11 +1,11 @@ from __future__ import annotations -from collections import namedtuple +import secrets +from dataclasses import dataclass from decimal import Decimal from functools import partial -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, NamedTuple from uuid import UUID -import secrets from sqlalchemy.ext.orderinglist import ordering_list @@ -28,17 +28,27 @@ __all__ = ['Order', 'OrderSession'] -OrderAmounts = namedtuple( - 'OrderAmounts', - ['base_amount', 'discounted_amount', 'final_amount', 'confirmed_amount'], -) + +class OrderAmounts(NamedTuple): + base_amount: Decimal + discounted_amount: Decimal + final_amount: Decimal + confirmed_amount: Decimal -def gen_receipt_no(organization): +@dataclass +class LineItemGroup: + # This dataclass uses Any types as the actual types are not imported yet + ticket: Any + count: int + total_price: Decimal + + +def gen_receipt_no(organization: Organization) -> sa.ScalarSelect[int]: """Generate a sequential invoice number for an order, given an organization.""" return ( sa.select(sa.func.coalesce(sa.func.max(Order.receipt_no + 1), 1)) - .where(Order.organization == organization) + .where(Order.organization_id == organization.id) .scalar_subquery() ) @@ -130,10 +140,7 @@ class Order(BaseMixin[UUID, User], Model): viewonly=True, ) - # These 3 properties are defined below the LineItem model - - # confirmed_line_items, initial_line_items, confirmed_and_cancelled_line_items - - def permissions(self, actor, inherited=None) -> set: + def permissions(self, actor: User, inherited: set[str] | None = None) -> set: perms = super().permissions(actor, inherited) if self.organization.userid in actor.organizations_owned_ids(): perms.add('org_admin') @@ -154,7 +161,7 @@ def invoice(self) -> None: self.invoiced_at = utcnow() self.status = OrderStatus.INVOICE - def get_amounts(self, line_item_status) -> OrderAmounts: + def get_amounts(self, line_item_status: LineItemStatus) -> OrderAmounts: """Calculate and return the order's amounts as an OrderAmounts tuple.""" base_amount = Decimal(0) discounted_amount = Decimal(0) diff --git a/boxoffice/models/payment.py b/boxoffice/models/payment.py index 74fe12ab..4fdbb36d 100644 --- a/boxoffice/models/payment.py +++ b/boxoffice/models/payment.py @@ -1,6 +1,8 @@ from __future__ import annotations from collections import OrderedDict +from collections.abc import Iterable +from datetime import tzinfo from decimal import Decimal from uuid import UUID @@ -97,7 +99,9 @@ class PaymentTransaction(BaseMixin[UUID, User], Model): pg_refundid: Mapped[str | None] = sa.orm.mapped_column(sa.Unicode(80), unique=True) -def calculate_weekly_refunds(menu_ids, user_tz, year): +def calculate_weekly_refunds( + menu_ids: Iterable[UUID], user_tz: str | tzinfo, year: int +) -> OrderedDict[int, int]: """Calculate refunds per week of the year for given menu_ids.""" ordered_week_refunds = OrderedDict() for year_week in Week.weeks_of_year(year): diff --git a/boxoffice/models/ticket.py b/boxoffice/models/ticket.py index d2d7acd3..d24e1d19 100644 --- a/boxoffice/models/ticket.py +++ b/boxoffice/models/ticket.py @@ -1,9 +1,9 @@ from __future__ import annotations -from collections.abc import Sequence -from datetime import date +from collections.abc import Iterable, Sequence +from datetime import date, datetime from decimal import Decimal -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, ClassVar, NamedTuple from uuid import UUID from sqlalchemy.ext.hybrid import hybrid_property @@ -33,6 +33,12 @@ __all__ = ['Ticket', 'Price'] +class AvailabilityData(NamedTuple): + title: str + quantity_total: int + line_item_count: int + + class Ticket(BaseScopedNameMixin[UUID, User], Model): __tablename__ = 'item' __table_args__ = (sa.UniqueConstraint('item_collection_id', 'name'),) @@ -70,7 +76,7 @@ class Ticket(BaseScopedNameMixin[UUID, User], Model): cascade='all, delete-orphan', lazy='dynamic', back_populates='ticket' ) - __roles__ = { + __roles__: ClassVar = { 'item_owner': { 'read': { 'id', @@ -86,7 +92,7 @@ class Ticket(BaseScopedNameMixin[UUID, User], Model): } @role_check('item_owner') - def has_item_owner_role(self, actor: User | None, _anchors=()) -> bool: + def has_item_owner_role(self, actor: User | None, _anchors: Any = ()) -> bool: return ( actor is not None and self.menu.organization.userid in actor.organizations_owned_ids() @@ -96,7 +102,7 @@ def current_price(self) -> Price | None: """Return the current price object for a ticket.""" return self.price_at(utcnow()) - def has_higher_price(self, current_price) -> bool: + def has_higher_price(self, current_price: Price) -> bool: """Check if ticket has a higher price than the given current price.""" return Price.query.filter( Price.end_at > current_price.end_at, @@ -104,7 +110,7 @@ def has_higher_price(self, current_price) -> bool: Price.discount_policy_id.is_(None), ).notempty() - def discounted_price(self, discount_policy) -> Price | None: + def discounted_price(self, discount_policy: DiscountPolicy) -> Price | None: """Return the discounted price for a ticket.""" return Price.query.filter( Price.ticket == self, Price.discount_policy == discount_policy @@ -115,7 +121,7 @@ def standard_prices(self) -> Query[Price]: Price.ticket == self, Price.discount_policy_id.is_(None) ).order_by(Price.start_at.desc()) - def price_at(self, timestamp) -> Price | None: + def price_at(self, timestamp: datetime) -> Price | None: """Return the price object for a ticket at a given time.""" return ( Price.query.filter( @@ -129,7 +135,7 @@ def price_at(self, timestamp) -> Price | None: ) @classmethod - def get_by_category(cls, category) -> Query[Ticket]: + def get_by_category(cls, category: Category) -> Query[Ticket]: return cls.query.filter(Ticket.category == category).order_by(cls.seq) @hybrid_property @@ -160,15 +166,15 @@ def confirmed_line_items(self) -> AppenderQuery[LineItem]: return self.line_items.filter(LineItem.status == LineItemStatus.CONFIRMED) @with_roles(call={'item_owner'}) - def sold_count(self): + def sold_count(self) -> int: return self.confirmed_line_items.filter(LineItem.final_amount > 0).count() @with_roles(call={'item_owner'}) - def free_count(self): + def free_count(self) -> int: return self.confirmed_line_items.filter(LineItem.final_amount == 0).count() @with_roles(call={'item_owner'}) - def cancelled_count(self): + def cancelled_count(self) -> int: return self.line_items.filter( LineItem.status == LineItemStatus.CANCELLED ).count() @@ -185,16 +191,19 @@ def net_sales(self) -> Decimal: ) @classmethod - def get_availability(cls, item_ids): + def get_availability(cls, item_ids: Iterable[UUID]) -> dict[str, AvailabilityData]: """ Return an availability dict. {'ticket_id': ('ticket)title', 'quantity_total', 'line_item_count')} """ items_dict = {} - item_tups = ( + rows = ( db.session.query( - cls.id, cls.title, cls.quantity_total, sa.func.count(cls.id) + cls.id.label('id'), + cls.title.label('title'), + cls.quantity_total.label('quantity_total'), + sa.func.count(cls.id).label('ticket_count'), ) .join(LineItem) .filter( @@ -204,8 +213,10 @@ def get_availability(cls, item_ids): .group_by(cls.id) .all() ) - for item_tup in item_tups: - items_dict[str(item_tup[0])] = item_tup[1:] + for row in rows: + items_dict[str(row.id)] = AvailabilityData( + row.title, row.quantity_total, row.ticket_count + ) return items_dict def demand_curve(self) -> Sequence[sa.engine.Row[tuple[Decimal, int]]]: @@ -247,7 +258,7 @@ class Price(BaseScopedNameMixin[UUID, User], Model): amount: Mapped[Decimal] = sa.orm.mapped_column(default=Decimal(0)) currency: Mapped[str] = sa.orm.mapped_column(sa.Unicode(3), default='INR') - __roles__ = { + __roles__: ClassVar = { 'price_owner': { 'read': { 'id', @@ -262,18 +273,18 @@ class Price(BaseScopedNameMixin[UUID, User], Model): } @role_check('price_owner') - def has_price_owner_role(self, actor: User | None, _anchors=()) -> bool: + def has_price_owner_role(self, actor: User | None, _anchors: Any = ()) -> bool: return ( actor is not None and self.ticket.menu.organization.userid in actor.organizations_owned_ids() ) @property - def discount_policy_title(self): + def discount_policy_title(self) -> str | None: return self.discount_policy.title if self.discount_policy else None @with_roles(call={'price_owner'}) - def tense(self): + def tense(self) -> str: now = utcnow() if self.end_at < now: return _("past") diff --git a/boxoffice/models/user.py b/boxoffice/models/user.py index 70067854..dffb05ec 100644 --- a/boxoffice/models/user.py +++ b/boxoffice/models/user.py @@ -5,7 +5,7 @@ from coaster.sqlalchemy import JsonDict from flask_lastuser.sqlalchemy import ProfileBase, UserBase2 -from . import DynamicMapped, Mapped, Model, db, relationship, sa +from . import DynamicMapped, Mapped, Model, Query, db, relationship, sa from .utils import HeadersAndDataTuple __all__ = ['User', 'Organization'] @@ -19,12 +19,12 @@ class User(UserBase2, Model): ) orders: Mapped[list[Order]] = relationship(cascade='all, delete-orphan') - def __repr__(self): + def __repr__(self) -> str: """Return a representation.""" return str(self.fullname) @property - def orgs(self): + def orgs(self) -> Query[Organization]: return Organization.query.filter( Organization.userid.in_(self.organizations_owned_ids()) ) @@ -73,13 +73,13 @@ class Organization(ProfileBase, Model): cascade='all, delete-orphan', lazy='dynamic', back_populates='organization' ) - def permissions(self, actor, inherited=None): + def permissions(self, actor: User, inherited: set[str] | None = None) -> set[str]: perms = super().permissions(actor, inherited) if self.userid in actor.organizations_owned_ids(): perms.add('org_admin') return perms - def fetch_invoices(self, filters: dict | None = None): + def fetch_invoices(self, filters: dict | None = None) -> HeadersAndDataTuple: """Return invoices for an organization as a tuple of (row_headers, rows).""" # pylint: disable=import-outside-toplevel from .invoice import Invoice @@ -146,7 +146,9 @@ def fetch_invoices(self, filters: dict | None = None): headers, db.session.execute(invoices_query).fetchall() ) - def fetch_invoice_line_items(self, filters: dict | None = None): + def fetch_invoice_line_items( + self, filters: dict | None = None + ) -> HeadersAndDataTuple: """ Return invoice line items for import into Zoho Books. diff --git a/boxoffice/models/utils.py b/boxoffice/models/utils.py index 670794b0..4117adf7 100644 --- a/boxoffice/models/utils.py +++ b/boxoffice/models/utils.py @@ -24,10 +24,7 @@ def naive_to_utc(dt: datetime, timezone: str | tzinfo | None = None) -> datetime """ tz: tzinfo if timezone: - if isinstance(timezone, str): - tz = pytz.timezone(timezone) - else: - tz = timezone + tz = pytz.timezone(timezone) if isinstance(timezone, str) else timezone elif isinstance(dt, datetime) and dt.tzinfo: tz = dt.tzinfo else: @@ -43,17 +40,14 @@ def get_fiscal_year(jurisdiction: str, dt: datetime) -> tuple[datetime, datetime Return the financial year for a given jurisdiction and timestamp. Returns start and end dates as tuple of timestamps. Recognizes April 1 as the start - date for India (jurisfiction code: 'in'), January 1 everywhere else. + date for India (jurisdiction code: 'in'), January 1 everywhere else. Example:: get_fiscal_year('IN', utcnow()) """ if jurisdiction.lower() == 'in': - if dt.month < 4: - start_year = dt.year - 1 - else: - start_year = dt.year + start_year = dt.year - 1 if dt.month < 4 else dt.year # starts on April 1 XXXX fy_start = datetime(start_year, 4, 1) # ends on April 1 XXXX + 1 diff --git a/boxoffice/views/admin.py b/boxoffice/views/admin.py index ade10c8e..fa0bf170 100644 --- a/boxoffice/views/admin.py +++ b/boxoffice/views/admin.py @@ -1,7 +1,9 @@ +from collections.abc import Mapping from datetime import date +from typing import Any -from flask import g, jsonify, request import pytz +from flask import Response, g, jsonify, request from baseframe import _ from coaster.utils import getbool @@ -14,7 +16,7 @@ from .utils import api_error, api_success, check_api_access -def jsonify_dashboard(data): +def jsonify_dashboard(data: Mapping[str, Any]) -> Response: return jsonify( orgs=[ { @@ -37,9 +39,9 @@ def index() -> ReturnRenderWith: return {'user': g.user} -def jsonify_org(data): +def jsonify_org(data: Mapping[str, Any]) -> Response: menu_list = ( - Menu.query.filter(Menu.organization == data['org']) + Menu.query.filter(Menu.organization_id == data['org'].id) .order_by(Menu.created_at.desc()) .all() ) @@ -60,7 +62,7 @@ def org(organization: Organization) -> ReturnRenderWith: @app.route('/api/1/organization//weekly_revenue', methods=['GET', 'OPTIONS']) @load_models((Organization, {'name': 'org'}, 'organization')) -def org_revenue(organization: Organization): +def org_revenue(organization: Organization) -> Response: check_api_access(organization.details.get('access_token')) if not request.args.get('year'): diff --git a/boxoffice/views/admin_category.py b/boxoffice/views/admin_category.py index 5f5984b7..563fc8ba 100644 --- a/boxoffice/views/admin_category.py +++ b/boxoffice/views/admin_category.py @@ -1,4 +1,7 @@ -from flask import jsonify, request +from collections.abc import Mapping +from typing import Any + +from flask import Response, jsonify, request from baseframe import _ from baseframe.forms import render_form @@ -10,15 +13,15 @@ from .utils import api_error, api_success -def jsonify_new_category(data_dict): +def jsonify_new_category(data_dict: Mapping[str, Any]) -> Response: menu = data_dict['menu'] category_form = CategoryForm(parent=menu) if request.method == 'GET': return jsonify( form_template=render_form( form=category_form, - title="New Ticket", - submit="Create", + title=_("New Ticket"), + submit=_("Create"), with_chrome=False, ).get_data(as_text=True) ) @@ -51,15 +54,15 @@ def admin_new_category(menu: Menu) -> ReturnRenderWith: return {'menu': menu} -def jsonify_edit_category(data_dict): +def jsonify_edit_category(data_dict: Mapping[str, Any]) -> Response: category = data_dict['category'] category_form = CategoryForm(obj=category) if request.method == 'GET': return jsonify( form_template=render_form( form=category_form, - title="Edit category", - submit="Update", + title=_("Edit category"), + submit=_("Update"), with_chrome=False, ).get_data(as_text=True) ) diff --git a/boxoffice/views/admin_discount.py b/boxoffice/views/admin_discount.py index cdb0e3f3..501ba0fe 100644 --- a/boxoffice/views/admin_discount.py +++ b/boxoffice/views/admin_discount.py @@ -1,6 +1,7 @@ -from typing import TYPE_CHECKING +from collections.abc import Mapping +from typing import TYPE_CHECKING, Any -from flask import jsonify, request +from flask import Response, jsonify, request from baseframe import _, forms from baseframe.forms import render_form @@ -26,7 +27,7 @@ from .utils import api_error, api_success, xhr_only -def jsonify_discount_policy(discount_policy: DiscountPolicy): +def jsonify_discount_policy(discount_policy: DiscountPolicy) -> dict[str, Any]: details = dict(discount_policy.current_access()) details['price_details'] = {} if discount_policy.is_price_based: @@ -42,7 +43,7 @@ def jsonify_discount_policy(discount_policy: DiscountPolicy): return details -def jsonify_discount_policies(data_dict): +def jsonify_discount_policies(data_dict: Mapping[str, Any]) -> Response: discount_policies_list = [] for discount_policy in data_dict['discount_policies']: discount_policies_list.append(jsonify_discount_policy(discount_policy)) @@ -65,7 +66,10 @@ def jsonify_discount_policies(data_dict): @load_models((Organization, {'name': 'org'}, 'organization'), permission='org_admin') @requestargs('search', ('page', int), ('size', int)) def admin_discount_policies( - organization: Organization, search: str | None = None, page=1, size=None + organization: Organization, + search: str | None = None, + page: int = 1, + size: int | None = None, ) -> ReturnRenderWith: results_per_page = size or 20 @@ -95,7 +99,7 @@ def admin_discount_policies( @lastuser.requires_login @xhr_only @load_models((Organization, {'name': 'org'}, 'organization'), permission='org_admin') -def admin_new_discount_policy(organization: Organization): +def admin_new_discount_policy(organization: Organization) -> Response: discount_policy = DiscountPolicy(organization=organization) discount_policy_form = DiscountPolicyForm(model=DiscountPolicy) discount_policy_form.populate_obj(discount_policy) @@ -176,7 +180,7 @@ def admin_new_discount_policy(organization: Organization): (DiscountPolicy, {'id': 'discount_policy_id'}, 'discount_policy'), permission='org_admin', ) -def admin_edit_discount_policy(discount_policy: DiscountPolicy): +def admin_edit_discount_policy(discount_policy: DiscountPolicy) -> Response: discount_policy_error_msg = _( "The discount could not be updated. Please rectify the indicated issues" ) @@ -243,14 +247,14 @@ def admin_edit_discount_policy(discount_policy: DiscountPolicy): (DiscountPolicy, {'id': 'discount_policy_id'}, 'discount_policy'), permission='org_admin', ) -def admin_delete_discount_policy(discount_policy: DiscountPolicy): +def admin_delete_discount_policy(discount_policy: DiscountPolicy) -> Response: form = forms.Form() if request.method == 'GET': return jsonify( form_template=render_form( form=form, - title="Delete discount policy", - submit="Delete", + title=_("Delete discount policy"), + submit=_("Delete"), with_chrome=False, ).get_data(as_text=True) ) @@ -264,7 +268,7 @@ def admin_delete_discount_policy(discount_policy: DiscountPolicy): db.session.delete(discount_policy) db.session.commit() - return api_success(result={}, doc="Discount policy deleted.", status_code=200) + return api_success(result={}, doc=_("Discount policy deleted."), status_code=200) @app.route('/admin/discount_policy//coupons/new', methods=['POST']) @@ -274,7 +278,7 @@ def admin_delete_discount_policy(discount_policy: DiscountPolicy): (DiscountPolicy, {'id': 'discount_policy_id'}, 'discount_policy'), permission='org_admin', ) -def admin_new_coupon(discount_policy: DiscountPolicy): +def admin_new_coupon(discount_policy: DiscountPolicy) -> Response: coupon_form = DiscountCouponForm(parent=discount_policy) if not coupon_form.validate_on_submit(): @@ -309,7 +313,7 @@ def admin_new_coupon(discount_policy: DiscountPolicy): (DiscountPolicy, {'id': 'discount_policy_id'}, 'discount_policy'), permission='org_admin', ) -def admin_discount_coupons(discount_policy: DiscountPolicy): +def admin_discount_coupons(discount_policy: DiscountPolicy) -> Response: coupons_list = [ { 'code': coupon.code, diff --git a/boxoffice/views/admin_menu.py b/boxoffice/views/admin_menu.py index 28ee1a29..10730d3a 100644 --- a/boxoffice/views/admin_menu.py +++ b/boxoffice/views/admin_menu.py @@ -1,4 +1,7 @@ -from flask import g, jsonify, request +from collections.abc import Mapping +from typing import Any + +from flask import Response, g, jsonify, request from baseframe import _, localize_timezone from baseframe.forms import render_form @@ -13,7 +16,7 @@ from .utils import api_error, api_success -def jsonify_menu(menu_dict): +def jsonify_menu(menu_dict: Mapping[str, Any]) -> Response: return jsonify( account_name=menu_dict['menu'].organization.name, account_title=menu_dict['menu'].organization.title, @@ -65,14 +68,14 @@ def admin_menu(menu: Menu) -> ReturnRenderWith: } -def jsonify_new_menu(menu_dict): +def jsonify_new_menu(menu_dict: Mapping[str, Any]) -> Response: ic_form = MenuForm() if request.method == 'GET': return jsonify( form_template=render_form( form=ic_form, - title="New menu", - submit="Create", + title=_("New menu"), + submit=_("Create"), ajax=False, with_chrome=False, ).get_data(as_text=True) @@ -104,15 +107,15 @@ def admin_new_ic(organization: Organization) -> ReturnRenderWith: return {'organization': organization} -def jsonify_edit_menu(menu_dict): +def jsonify_edit_menu(menu_dict: Mapping[str, Any]) -> Response: menu = menu_dict['menu'] ic_form = MenuForm(obj=menu) if request.method == 'GET': return jsonify( form_template=render_form( form=ic_form, - title="Edit menu", - submit="Save", + title=_("Edit menu"), + submit=_("Save"), ajax=False, with_chrome=False, ).get_data(as_text=True) diff --git a/boxoffice/views/admin_order.py b/boxoffice/views/admin_order.py index 64a78757..f272d6db 100644 --- a/boxoffice/views/admin_order.py +++ b/boxoffice/views/admin_order.py @@ -1,9 +1,15 @@ -from flask import jsonify, url_for +from collections.abc import Mapping +from decimal import Decimal +from typing import Any, TypedDict +from uuid import UUID + +from flask import Response, jsonify, url_for from coaster.views import ReturnRenderWith, load_models, render_with from .. import app, lastuser from ..models import ( + Assignee, CurrencySymbol, InvoiceStatus, LineItem, @@ -16,7 +22,33 @@ from .utils import check_api_access, json_date_format, xhr_only -def format_assignee(assignee): +class AssigneeDict(TypedDict): + id: int + fullname: str + email: str + phone: str | None + details: dict[str, Any] + + +class LineItemDict(TypedDict): + title: str + seq: int + id: UUID + category: str + description: str | None + description_html: str | None + currency: str + base_amount: Decimal + discounted_amount: Decimal + final_amount: Decimal + discount_policy: str + discount_coupon: str + cancelled_at: str + assignee_details: AssigneeDict | None + cancel_ticket_url: str + + +def format_assignee(assignee: Assignee | None) -> AssigneeDict | None: if assignee: return { 'id': assignee.id, @@ -28,8 +60,8 @@ def format_assignee(assignee): return None -def format_line_items(line_items): - line_item_dicts = [] +def format_line_items(line_items: list[LineItem]) -> list[LineItemDict]: + line_item_dicts: list[LineItemDict] = [] for line_item in line_items: line_item_dicts.append( { @@ -65,7 +97,7 @@ def format_line_items(line_items): return line_item_dicts -def jsonify_admin_orders(data_dict): +def jsonify_admin_orders(data_dict: Mapping[str, Any]) -> Response: menu_id = data_dict['menu'].id order_dicts = [] for order in data_dict['orders']: @@ -113,7 +145,7 @@ def admin_orders(menu: Menu) -> ReturnRenderWith: @lastuser.requires_login @xhr_only @load_models((Order, {'id': 'order_id'}, 'order'), permission='org_admin') -def admin_order(order: Order): +def admin_order(order: Order) -> Response: line_items = LineItem.query.filter( LineItem.order == order, LineItem.status.in_( @@ -123,7 +155,7 @@ def admin_order(order: Order): return jsonify(line_items=format_line_items(line_items)) -def jsonify_order(order_dict): +def jsonify_order(order_dict: Mapping[str, Any]) -> Response: org = {"title": order_dict['org'].title, "name": order_dict['org'].name} order = { 'id': order_dict['order'].id, @@ -165,7 +197,7 @@ def admin_org_order(org: Organization, order: Order) -> ReturnRenderWith: return {'org': org, 'order': order, 'line_items': line_items} -def get_order_details(order): +def get_order_details(order: Order) -> Response: line_items_list = [ { 'title': li.ticket.title, @@ -245,7 +277,7 @@ def get_order_details(order): (Organization, {'name': 'org_name'}, 'org'), (Order, {'organization': 'org', 'receipt_no': 'receipt_no'}, 'order'), ) -def order_api(org: Organization, order: Order): +def order_api(org: Organization, order: Order) -> Response: check_api_access(org.details.get('access_token')) return get_order_details(order) @@ -257,6 +289,6 @@ def order_api(org: Organization, order: Order): (Organization, {'name': 'org_name'}, 'org'), (Order, {'organization': 'org', 'id': 'order_id'}, 'order'), ) -def order_id_api(org: Organization, order: Order): +def order_id_api(org: Organization, order: Order) -> Response: check_api_access(org.details.get('access_token')) return get_order_details(order) diff --git a/boxoffice/views/admin_report.py b/boxoffice/views/admin_report.py index ed07dbc0..d542207a 100644 --- a/boxoffice/views/admin_report.py +++ b/boxoffice/views/admin_report.py @@ -1,8 +1,9 @@ +from collections.abc import Mapping from datetime import date, datetime from typing import Any from babel.dates import format_datetime -from flask import g, jsonify, request, url_for +from flask import Response, g, jsonify, request, url_for from flask_babel import get_locale from baseframe import localize_timezone @@ -14,7 +15,7 @@ from .utils import api_error, check_api_access, csv_response -def jsonify_report(data_dict): +def jsonify_report(data_dict: Mapping[str, Any]) -> Response: return jsonify( account_name=data_dict['menu'].organization.name, account_title=data_dict['menu'].organization.title, @@ -31,7 +32,7 @@ def admin_report(menu: Menu) -> ReturnRenderWith: return {'menu': menu} -def jsonify_org_report(data_dict): +def jsonify_org_report(data_dict: Mapping[str, Any]) -> Response: return jsonify(account_title=data_dict['organization'].title) @@ -48,11 +49,11 @@ def admin_org_report(organization: Organization) -> ReturnRenderWith: @app.route('/admin/menu//tickets.csv') @lastuser.requires_login @load_models((Menu, {'id': 'menu_id'}, 'menu'), permission='org_admin') -def tickets_report(menu: Menu): +def tickets_report(menu: Menu) -> Response: headers, rows = menu.fetch_all_details() assignee_url_index = headers.index('assignee_url') - def row_handler(row): + def row_handler(row: list[Any]) -> list[Any]: # localize datetime row_list = [ ( @@ -78,14 +79,14 @@ def row_handler(row): @app.route('/admin/menu//attendees.csv') @lastuser.requires_login @load_models((Menu, {'id': 'menu_id'}, 'menu'), permission='org_admin') -def attendees_report(menu: Menu): +def attendees_report(menu: Menu) -> Response: # Generated a unique list of headers for all 'assignee_details' keys in all items in # this menu. This flattens the 'assignee_details' dict. This will need to # be updated if we add additional dicts to our csv export. attendee_details_headers = [] for ticket in menu.tickets: if ticket.assignee_details: - for detail in ticket.assignee_details.keys(): + for detail in ticket.assignee_details: attendee_detail_prefixed = 'attendee_details_' + detail # Eliminate duplicate headers across attendee_details across items. For # example, if 't-shirt' and 'hoodie' are two items with a 'size' key, @@ -101,13 +102,13 @@ def attendees_report(menu: Menu): else: attendee_details_index = -1 - def row_handler(row): + def row_handler(row: list[dict | datetime | str]) -> dict[str, Any]: # Convert row to a dict dict_row = {} for idx, item in enumerate(row): # 'assignee_details' is a dict already, so copy and include prefixes if idx == attendee_details_index and isinstance(item, dict): - for key in item.keys(): + for key in item: dict_row['attendee_details_' + key] = item[key] # Item is a datetime object, so format and add to dict elif isinstance(item, datetime): @@ -136,7 +137,7 @@ def row_handler(row): 'menu', ), ) -def orders_api(organization: Organization, menu: Menu): +def orders_api(organization: Organization, menu: Menu) -> Response: check_api_access(organization.details.get('access_token')) # Generated a unique list of headers for all 'assignee_details' keys in all items in @@ -145,7 +146,7 @@ def orders_api(organization: Organization, menu: Menu): attendee_details_headers = [] for ticket in menu.tickets: if ticket.assignee_details: - for detail in ticket.assignee_details.keys(): + for detail in ticket.assignee_details: attendee_detail_prefixed = 'attendee_details_' + detail # Eliminate duplicate headers across attendee_details across tickets. # For example, if 't-shirt' and 'hoodie' are two tickets with a 'size' @@ -160,13 +161,13 @@ def orders_api(organization: Organization, menu: Menu): else: attendee_details_index = -1 - def row_handler(row): + def row_handler(row: list[dict | datetime | str]) -> dict[str, Any]: # Convert row to a dict dict_row = {} for idx, item in enumerate(row): # 'assignee_details' is a dict already, so copy and include prefixes if idx == attendee_details_index and isinstance(item, dict): - for key in item.keys(): + for key in item: dict_row['attendee_details_' + key] = item[key] # Item is a datetime object, so format and add to dict elif isinstance(item, datetime): @@ -191,7 +192,7 @@ def row_handler(row): @load_models( (Organization, {'name': 'org_name'}, 'organization'), permission='org_admin' ) -def invoices_report(organization: Organization): +def invoices_report(organization: Organization) -> Response: today = date.today() period_type = request.args.get('type', 'all') invoice_filter: dict[str, Any] = {} @@ -242,12 +243,13 @@ def invoices_report(organization: Organization): buyer_email_index = headers.index('buyer_taxid') headers.insert(buyer_email_index, 'buyer_email') - def row_handler(row): - order = Order.query.filter(Order.id == row[order_id_index]).first() + def row_handler(row: list) -> dict[str, Any]: + order = Order.query.get(row[order_id_index]) + assert order is not None # noqa: S101 # nosec B101 row = list(row) row.insert(buyer_name_index, order.buyer_fullname) row.insert(buyer_email_index, order.buyer_email) - dict_row = dict(list(zip(headers, row))) + dict_row = dict(list(zip(headers, row, strict=False))) for enum_member in InvoiceStatus: if dict_row.get('status') == enum_member.value: dict_row['status'] = enum_member.name @@ -268,7 +270,7 @@ def row_handler(row): @load_models( (Organization, {'name': 'org_name'}, 'organization'), permission='org_admin' ) -def invoices_report_zb(organization: Organization): +def invoices_report_zb(organization: Organization) -> Response: today = date.today() period_type = request.args.get('type', 'all') invoice_filter: dict[str, Any] = {} @@ -326,8 +328,8 @@ def invoices_report_zb(organization: Organization): new_rows = [] new_rows_index: dict[str, int] = {} - for row in rows: - row = list(row) + for _row in rows: + row = list(_row) inv_date = localize_timezone(row[invoice_date_index]) row[invoice_date_index] = inv_date.strftime('%Y-%m-%d') fy_base = inv_date.year - int(inv_date.month < 4) @@ -366,8 +368,8 @@ def invoices_report_zb(organization: Organization): headers.pop() headers.append('Invoice Currency') - def row_handler(row): - order = Order.query.filter(Order.id == row[order_id_index]).first() + def row_handler(row: list[Any]) -> dict[str, Any]: + order = Order.query.filter(Order.id == row[order_id_index]).one() if row[customer_name_index] is None: row[customer_name_index] = order.buyer_fullname fullname = row[customer_name_index].split(' ') @@ -395,7 +397,7 @@ def row_handler(row): row.insert(gst_treatment_index, gst_treatment) row.pop() row.append(row[headers.index('Currency Code')]) - dict_row = dict(list(zip(headers, row))) + dict_row = dict(list(zip(headers, row, strict=False))) for enum_member in InvoiceStatus: if dict_row.get('Invoice Status on Boxoffice') == enum_member.value: dict_row['Invoice Status on Boxoffice'] = enum_member.name @@ -408,7 +410,7 @@ def row_handler(row): @app.route('/admin/o//settlements.csv') @lastuser.requires_permission('siteadmin') @load_models((Organization, {'name': 'org_name'}, 'organization')) -def settled_transactions(organization: Organization): +def settled_transactions(organization: Organization) -> Response: # noqa: ARG001 # FIXME: This report is NOT filtered by organization; it has everything! today = date.today() year = int(request.args.get('year', today.year)) diff --git a/boxoffice/views/admin_ticket.py b/boxoffice/views/admin_ticket.py index 4567a5e0..b182fb8b 100644 --- a/boxoffice/views/admin_ticket.py +++ b/boxoffice/views/admin_ticket.py @@ -1,4 +1,7 @@ -from flask import jsonify, request +from collections.abc import Mapping +from typing import Any, TypedDict + +from flask import Response, jsonify, request from baseframe import _ from baseframe.forms import render_form @@ -16,7 +19,7 @@ @xhr_only @load_models((Organization, {'name': 'org'}, 'organization'), permission='org_admin') @requestargs('search') -def tickets(organization: Organization, search: str | None = None): +def tickets(organization: Organization, search: str | None = None) -> Response: if search: filtered_tickets = ( # FIXME: Query one, join the other @@ -47,20 +50,25 @@ def tickets(organization: Organization, search: str | None = None): for ticket_tuple in filtered_tickets ] }, - doc="Filtered tickets", + doc=_("Filtered tickets"), status_code=200, ) return api_error(message=_("Missing search query"), status_code=400) -def jsonify_price(price): +def jsonify_price(price: Price) -> dict[str, Any]: price_details = dict(price.current_access()) price_details['tense'] = price.tense() return price_details -def format_demand_curve(ticket: Ticket): - result = {} +class DemandData(TypedDict): + quantity_demanded: int + demand: int + + +def format_demand_curve(ticket: Ticket) -> dict[str, DemandData]: + result: dict[str, DemandData] = {} demand_counter = 0 for amount, quantity_demanded in reversed(ticket.demand_curve()): @@ -72,7 +80,7 @@ def format_demand_curve(ticket: Ticket): return result -def format_ticket_details(ticket: Ticket): +def format_ticket_details(ticket: Ticket) -> dict[str, Any]: ticket_details = dict(ticket.current_access()) ticket_details['sold_count'] = ticket.sold_count() ticket_details['free_count'] = ticket.free_count() @@ -83,7 +91,7 @@ def format_ticket_details(ticket: Ticket): return ticket_details -def jsonify_item(data_dict): +def jsonify_item(data_dict: Mapping[str, Any]) -> Response: ticket = data_dict['ticket'] discount_policies_list = [] for policy in ticket.discount_policies: @@ -116,13 +124,16 @@ def admin_item(ticket: Ticket) -> ReturnRenderWith: return {'ticket': ticket} -def jsonify_new_ticket(data_dict): +def jsonify_new_ticket(data_dict: Mapping[str, Any]) -> Response: menu = data_dict['menu'] ticket_form = TicketForm(parent=menu) if request.method == 'GET': return jsonify( form_template=render_form( - form=ticket_form, title="New ticket", submit="Create", with_chrome=False + form=ticket_form, + title=_("New ticket"), + submit=_("Create"), + with_chrome=False, ).get_data(as_text=True) ) if ticket_form.validate_on_submit(): @@ -152,15 +163,15 @@ def admin_new_item(menu: Menu) -> ReturnRenderWith: return {'menu': menu} -def jsonify_edit_ticket(data_dict): +def jsonify_edit_ticket(data_dict: Mapping[str, Any]) -> Response: ticket = data_dict['ticket'] ticket_form = TicketForm(obj=ticket) if request.method == 'GET': return jsonify( form_template=render_form( form=ticket_form, - title="Update ticket", - submit="Update", + title=_("Update ticket"), + submit=_("Update"), with_chrome=False, ).get_data(as_text=True) ) @@ -189,13 +200,16 @@ def admin_edit_item(ticket: Ticket) -> ReturnRenderWith: return {'ticket': ticket} -def jsonify_new_price(data_dict): +def jsonify_new_price(data_dict: Mapping[str, Any]) -> Response: ticket = data_dict['ticket'] price_form = PriceForm(parent=ticket) if request.method == 'GET': return jsonify( form_template=render_form( - form=price_form, title="New price", submit="Save", with_chrome=False + form=price_form, + title=_("New price"), + submit=_("Save"), + with_chrome=False, ).get_data(as_text=True) ) if price_form.validate_on_submit(): @@ -226,13 +240,16 @@ def admin_new_price(ticket: Ticket) -> ReturnRenderWith: return {'ticket': ticket} -def jsonify_edit_price(data_dict): +def jsonify_edit_price(data_dict: Mapping[str, Any]) -> Response: price = data_dict['price'] price_form = PriceForm(obj=price) if request.method == 'GET': return jsonify( form_template=render_form( - form=price_form, title="Update price", submit="Save", with_chrome=False + form=price_form, + title=_("Update price"), + submit=_("Save"), + with_chrome=False, ).get_data(as_text=True) ) if price_form.validate_on_submit(): diff --git a/boxoffice/views/custom_exceptions.py b/boxoffice/views/custom_exceptions.py index d8c84b64..0eaaa4d8 100644 --- a/boxoffice/views/custom_exceptions.py +++ b/boxoffice/views/custom_exceptions.py @@ -1,10 +1,10 @@ -from flask import jsonify, make_response +from flask import Response, jsonify, make_response from .. import app class PaymentGatewayError(Exception): - def __init__(self, message, status_code, response_message): + def __init__(self, message: str, status_code: int, response_message: str) -> None: super().__init__() self.message = message self.status_code = status_code @@ -12,7 +12,7 @@ def __init__(self, message, status_code, response_message): @app.errorhandler(PaymentGatewayError) -def handle_api_error(error): +def handle_api_error(error: PaymentGatewayError) -> Response: app.logger.error("Boxoffice Payment Gateway Error: %s", error.message) return make_response( jsonify( diff --git a/boxoffice/views/login.py b/boxoffice/views/login.py index 1735a6f2..800ca9e8 100644 --- a/boxoffice/views/login.py +++ b/boxoffice/views/login.py @@ -1,4 +1,5 @@ from flask import flash, redirect +from flask.typing import ResponseReturnValue from markupsafe import Markup, escape from baseframe import _ @@ -11,20 +12,20 @@ @app.route('/login') @lastuser.login_handler -def login(): +def login() -> dict[str, str]: return {'scope': 'id email phone organizations'} @app.route('/logout') @lastuser.logout_handler -def logout(): +def logout() -> ResponseReturnValue: flash(_("You are now logged out"), category='success') return get_next_url() @app.route('/login/redirect') @lastuser.auth_handler -def lastuserauth(): +def lastuserauth() -> ResponseReturnValue: return redirect(get_next_url()) @@ -38,7 +39,9 @@ def lastusernotify(_user: User) -> None: @lastuser.auth_error_handler -def lastuser_error(error, error_description=None, error_uri=None): +def lastuser_error( + error: str, error_description: str | None = None, error_uri: str | None = None +) -> ResponseReturnValue: if error == 'access_denied': flash(_("You denied the request to login"), category='error') return redirect(get_next_url()) diff --git a/boxoffice/views/menu.py b/boxoffice/views/menu.py index 8050bff7..377aa5e7 100644 --- a/boxoffice/views/menu.py +++ b/boxoffice/views/menu.py @@ -1,4 +1,9 @@ -from flask import jsonify, render_template, request +from datetime import datetime +from decimal import Decimal +from typing import TypedDict +from uuid import UUID + +from flask import Response, jsonify, render_template, request from markupsafe import Markup from baseframe import localized_country_list @@ -18,7 +23,39 @@ from .utils import cors, sanitize_coupons, xhr_only -def jsonify_ticket(ticket: Ticket): +class DiscountPolicyDict(TypedDict): + id: UUID + title: str + is_automatic: bool + + +class TicketDict(TypedDict): + name: str + title: str + id: UUID + description: str | None + description_html: str | None + quantity_available: int + is_available: bool + quantity_total: int + category_id: int + menu_id: UUID + price: Decimal + price_category: str + price_valid_upto: datetime + has_higher_price: bool + discount_policies: list[DiscountPolicyDict] + + +class CategoryDict(TypedDict): + id: int + title: str + name: str + menu_id: UUID + tickets: list[TicketDict] + + +def jsonify_ticket(ticket: Ticket) -> TicketDict | None: if ticket.restricted_entry: code_list = ( sanitize_coupons(request.args.getlist('code')) @@ -60,7 +97,7 @@ def jsonify_ticket(ticket: Ticket): } -def jsonify_category(category: Category): +def jsonify_category(category: Category) -> CategoryDict | None: category_items = [] for ticket in Ticket.get_by_category(category): ticket_json = jsonify_ticket(ticket) @@ -77,15 +114,12 @@ def jsonify_category(category: Category): return None -def render_boxoffice_js(): +def render_boxoffice_js() -> str: return render_template( 'boxoffice.js.jinja2', base_url=request.url_root.rstrip('/'), razorpay_key_id=app.config['RAZORPAY_KEY_ID'], - states=[ - {'name': state['name'], 'code': state['short_code_text']} - for state in sorted(indian_states, key=lambda k: k['name']) - ], + states=[{'name': state.title, 'code': state.name} for state in indian_states], countries=[ {'name': name, 'code': code} for code, name in localized_country_list() ], @@ -94,7 +128,7 @@ def render_boxoffice_js(): @app.route('/api/1/boxoffice.js') @cors -def boxofficejs(): +def boxofficejs() -> Response: return jsonify({'script': render_boxoffice_js()}) @@ -102,7 +136,7 @@ def boxofficejs(): @xhr_only @cors @load_models((Menu, {'id': 'menu_id'}, 'menu')) -def view_menu(menu: Menu): +def view_menu(menu: Menu) -> Response: categories_json = [] for category in menu.categories: category_json = jsonify_category(category) @@ -125,7 +159,7 @@ def view_menu(menu: Menu): 'menu', ), ) -def menu_listing(organization: Organization, menu: Menu): +def menu_listing(organization: Organization, menu: Menu) -> str: show_title = getbool(request.args.get('show_title', True)) return render_template( 'item_collection_listing.html.jinja2', diff --git a/boxoffice/views/order.py b/boxoffice/views/order.py index 62827196..077270fc 100644 --- a/boxoffice/views/order.py +++ b/boxoffice/views/order.py @@ -1,8 +1,9 @@ +from collections.abc import Iterable, Mapping from decimal import Decimal -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING, Any, Literal, NotRequired, TypedDict, cast, overload +from uuid import UUID -from flask import abort, jsonify, render_template, request, url_for -from sqlalchemy.sql import func +from flask import Response, abort, jsonify, render_template, request, url_for from werkzeug.datastructures import ImmutableMultiDict from baseframe import _, localized_country_list @@ -34,6 +35,7 @@ Invoice, LineItem, LineItemStatus, + LineItemTuple, Menu, OnlinePayment, Order, @@ -58,7 +60,25 @@ ) -def jsonify_line_items(line_items): +class LineItemData(TypedDict): + is_available: bool + quantity: int + base_amount: NotRequired[Decimal] + final_amount: Decimal | None + discounted_amount: Decimal | None + discount_policy_ids: list[UUID] + quantity_available: NotRequired[int] + + +class AssigneeDict(TypedDict): + id: int + fullname: str + email: str + phone: str | None + details: dict + + +def jsonify_line_items(line_items: Iterable[LineItemTuple]) -> dict[str, LineItemData]: """ Serialize and return line items. @@ -69,15 +89,16 @@ def jsonify_line_items(line_items): 'quantity': Y, 'final_amount': Z, 'discounted_amount': Z, - 'discount_policy_ids': ['d1', 'd2'] + 'discount_policy_ids': ['d1', 'd2'], } } """ - items_json = {} + items_json: dict[str, LineItemData] = {} for line_item in line_items: ticket = Ticket.query.get_or_404(line_item.ticket_id) - if not items_json.get(str(line_item.ticket_id)): - items_json[str(line_item.ticket_id)] = { + ticket_id = str(line_item.ticket_id) + if ticket_id not in items_json: + items_json[ticket_id] = { 'is_available': ticket.is_available, 'quantity': 0, 'final_amount': Decimal(0), @@ -85,33 +106,40 @@ def jsonify_line_items(line_items): 'discount_policy_ids': [], } if line_item.base_amount is not None: - items_json[str(line_item.ticket_id)]['base_amount'] = line_item.base_amount - items_json[str(line_item.ticket_id)]['final_amount'] += ( - line_item.base_amount - line_item.discounted_amount - ) - items_json[str(line_item.ticket_id)][ - 'discounted_amount' - ] += line_item.discounted_amount + items_json[ticket_id]['base_amount'] = line_item.base_amount + items_json[ticket_id]['final_amount'] = ( + items_json[ticket_id]['final_amount'] or Decimal('0') + ) + (line_item.base_amount - (line_item.discounted_amount or Decimal('0'))) + items_json[ticket_id]['discounted_amount'] = ( + items_json[ticket_id]['discounted_amount'] or Decimal('0') + ) + (line_item.discounted_amount or Decimal('0')) + items_json[ticket_id]['quantity_available'] = ticket.quantity_available else: - items_json[str(line_item.ticket_id)]['final_amount'] = None - items_json[str(line_item.ticket_id)]['discounted_amount'] = None - items_json[str(line_item.ticket_id)]['quantity'] += 1 - items_json[str(line_item.ticket_id)][ - 'quantity_available' - ] = ticket.quantity_available + items_json[ticket_id]['final_amount'] = None + items_json[ticket_id]['discounted_amount'] = None + + items_json[ticket_id]['quantity'] += 1 if ( line_item.discount_policy_id and line_item.discount_policy_id - not in items_json[str(line_item.ticket_id)]['discount_policy_ids'] + not in items_json[ticket_id]['discount_policy_ids'] ): - items_json[str(line_item.ticket_id)]['discount_policy_ids'].append( + items_json[ticket_id]['discount_policy_ids'].append( line_item.discount_policy_id ) return items_json -def jsonify_assignee(assignee): - if assignee: +@overload +def jsonify_assignee(assignee: None) -> None: ... + + +@overload +def jsonify_assignee(assignee: Assignee) -> AssigneeDict: ... + + +def jsonify_assignee(assignee: Assignee | None) -> AssigneeDict | None: + if assignee is not None: return { 'id': assignee.id, 'fullname': assignee.fullname, @@ -122,7 +150,7 @@ def jsonify_assignee(assignee): return None -def jsonify_order(data): +def jsonify_order(data: Mapping[str, Any]) -> Response: order = data['order'] line_items = [] for line_item in order.line_items: @@ -161,7 +189,7 @@ def jsonify_order(data): @app.route('/order/kharcha', methods=['OPTIONS', 'POST']) @xhr_only @cors -def kharcha(): +def kharcha() -> Response: """ Calculate rates for an order of items and quantities. @@ -175,7 +203,7 @@ def kharcha(): "quantity": Y, "final_amount": Z, "discounted_amount": Z, - "discount_policy_ids": ["d1", "d2"] + "discount_policy_ids": ["d1", "d2"], } } """ @@ -207,7 +235,7 @@ def kharcha(): @xhr_only @cors @load_models((Menu, {'id': 'menu_id'}, 'menu')) -def create_order(menu: Menu): +def create_order(menu: Menu) -> Response: """ Create an order. @@ -218,17 +246,19 @@ def create_order(menu: Menu): and the URL to be used to register a payment against the order. """ if not request.json or not request.json.get('line_items'): - return api_error(message="Missing line items", status_code=400) + return api_error(message=_("Missing line items"), status_code=400) line_item_forms = LineItemForm.process_list(request.json.get('line_items', [])) if not line_item_forms: - return api_error(message="Invalid line items", status_code=400) + return api_error(message=_("Invalid line items"), status_code=400) # See comment in LineItemForm about CSRF buyer_form = BuyerForm( formdata=ImmutableMultiDict(request.json.get('buyer')), meta={'csrf': False} ) if not buyer_form.validate(): return api_error( - message="Invalid buyer details", status_code=400, errors=buyer_form.errors + message=_("Invalid buyer details"), + status_code=400, + errors=buyer_form.errors, ) invalid_quantity_error_msg = _( @@ -359,7 +389,7 @@ def create_order(menu: Menu): @xhr_only @cors @load_models((Order, {'id': 'order'}, 'order')) -def free(order: Order): +def free(order: Order) -> Response: """Complete a order which has a final_amount of 0.""" order_amounts = order.get_amounts(LineItemStatus.PURCHASE_ORDER) if order_amounts.final_amount == 0: @@ -386,14 +416,14 @@ def free(order: Order): status_code=201, ) - return api_error(message="Free order confirmation failed", status_code=422) + return api_error(message=_("Free order confirmation failed"), status_code=422) @app.route('/order//payment', methods=['GET', 'OPTIONS', 'POST']) @xhr_only @cors @load_models((Order, {'id': 'order'}, 'order')) -def capture_payment(order: Order): +def capture_payment(order: Order) -> Response: """ Capture a payment. @@ -407,7 +437,7 @@ def capture_payment(order: Order): if TYPE_CHECKING: assert request.json is not None # nosec B101 if not request.json.get('pg_paymentid'): - return api_error(message="Missing payment id", status_code=400) + return api_error(message=_("Missing payment id"), status_code=400) order_amounts = order.get_amounts(LineItemStatus.PURCHASE_ORDER) online_payment = OnlinePayment( @@ -473,7 +503,7 @@ def capture_payment(order: Order): @app.route('/order//receipt', methods=['GET']) @load_models((Order, {'access_token': 'access_token'}, 'order')) -def receipt(order: Order): +def receipt(order: Order) -> str: if not order.is_confirmed: abort(404) line_items = LineItem.query.filter( @@ -488,7 +518,22 @@ def receipt(order: Order): ) -def jsonify_invoice(invoice): +class InvoiceDict(TypedDict): + id: UUID + buyer_taxid: str | None + invoicee_name: str | None + invoicee_company: str | None + invoicee_email: str | None + street_address_1: str | None + street_address_2: str | None + city: str | None + postcode: str | None + country_code: str | None + state_code: str | None + state: str | None + + +def jsonify_invoice(invoice: Invoice) -> InvoiceDict: return { 'id': invoice.id, 'buyer_taxid': invoice.buyer_taxid, @@ -509,8 +554,8 @@ def jsonify_invoice(invoice): @xhr_only @cors @load_models((Order, {'access_token': 'access_token'}, 'order')) -def edit_invoice_details(order: Order): - """Update invoice with buyer's address and taxid.""" +def edit_invoice_details(order: Order) -> Response: + """Update invoice with buyer's address and tax id.""" if TYPE_CHECKING: assert request.json is not None # nosec B101 if not order.is_confirmed: @@ -544,17 +589,14 @@ def edit_invoice_details(order: Order): ) -def jsonify_invoices(data_dict): +def jsonify_invoices(data_dict: Mapping[str, Any]) -> Response: invoices_list = [] for invoice in data_dict['invoices']: invoices_list.append(jsonify_invoice(invoice)) return jsonify( invoices=invoices_list, access_token=data_dict['order'].access_token, - states=[ - {'name': state['name'], 'code': state['short_code_text']} - for state in sorted(indian_states, key=lambda k: k['name']) - ], + states=[{'name': state.title, 'code': state.name} for state in indian_states], countries=[ {'name': name, 'code': code} for code, name in localized_country_list() ], @@ -588,24 +630,41 @@ def order_ticket(order: Order) -> ReturnRenderWith: return {'order': order, 'org': order.organization} -def jsonify_orders(orders): - api_orders = [] +class TicketDict(TypedDict): + title: str + + +class LineItemDict(TypedDict): + assignee: dict[str, Any] + line_item_seq: int + line_item_status: Literal['confirmed', 'cancelled'] + ticket: TicketDict - def format_assignee(line_item): - if not line_item.current_assignee: - return {} - assignee = { - 'fullname': line_item.current_assignee.fullname, - 'email': line_item.current_assignee.email, - 'phone': line_item.current_assignee.phone, - } - for key in line_item.ticket.assignee_details: - assignee[key] = line_item.current_assignee.details.get(key) - return assignee +class OrdersDict(TypedDict): + receipt_no: int | None + line_items: list[LineItemDict] + + +def format_assignee(line_item: LineItem) -> dict[str, Any]: + if not line_item.current_assignee: + return {} + assignee = { + 'fullname': line_item.current_assignee.fullname, + 'email': line_item.current_assignee.email, + 'phone': line_item.current_assignee.phone, + } + + for key in line_item.ticket.assignee_details: + assignee[key] = line_item.current_assignee.details.get(key) + return assignee + + +def jsonify_orders(orders: Iterable[Order]) -> list[OrdersDict]: + api_orders = [] for order in orders: - order_dict = {'receipt_no': order.receipt_no, 'line_items': []} + order_dict: OrdersDict = {'receipt_no': order.receipt_no, 'line_items': []} for line_item in order.line_items: order_dict['line_items'].append( { @@ -621,7 +680,7 @@ def format_assignee(line_item): return api_orders -def get_coupon_codes_from_line_items(line_items): +def get_coupon_codes_from_line_items(line_items: Iterable[LineItem]) -> list[str]: coupon_ids = [ line_item.discount_coupon.id for line_item in line_items @@ -639,9 +698,16 @@ def get_coupon_codes_from_line_items(line_items): def regenerate_line_item( - order: Order, original_line_item: LineItem, updated_line_item_tup, line_item_seq -): + order: Order, + original_line_item: LineItem, + updated_line_item_tup: LineItemTuple, + line_item_seq: int, +) -> LineItem: """Update a line item by marking the original as void and creating a replacement.""" + assert updated_line_item_tup.base_amount is not None # noqa: S101 # nosec B101 + assert ( # noqa: S101 # nosec B101 + updated_line_item_tup.discounted_amount is not None + ) original_line_item.make_void() ticket = Ticket.query.get_or_404(updated_line_item_tup.ticket_id) if updated_line_item_tup.discount_policy_id: @@ -674,9 +740,11 @@ def regenerate_line_item( def update_order_on_line_item_cancellation( - order, pre_cancellation_line_items, cancelled_line_item -): - """Cancel the given line item and updates the order.""" + order: Order, + pre_cancellation_line_items: Iterable[LineItem], + cancelled_line_item: LineItem, +) -> Order: + """Cancel the given line item and update the order.""" active_line_items = [ pre_cancellation_line_item for pre_cancellation_line_item in pre_cancellation_line_items @@ -692,11 +760,11 @@ def update_order_on_line_item_cancellation( recalculated_line_item_tups, start=last_line_item_seq + 1 ): # Fetch the line item object - pre_cancellation_line_item = [ + pre_cancellation_line_item = next( pre_cancellation_line_item for pre_cancellation_line_item in pre_cancellation_line_items if pre_cancellation_line_item.id == line_item_tup.id - ][0] + ) # Check if the line item's amount has changed post-cancellation if line_item_tup.final_amount != pre_cancellation_line_item.final_amount: # Amount has changed, void this line item and regenerate the line item @@ -719,7 +787,7 @@ def update_order_on_line_item_cancellation( return order -def process_line_item_cancellation(line_item): +def process_line_item_cancellation(line_item: LineItem) -> Decimal: order = line_item.order # initialize refund_amount to 0 refund_amount = Decimal(0) @@ -765,7 +833,7 @@ def process_line_item_cancellation(line_item): online_payment=payment, amount=refund_amount, currency=CurrencyEnum.INR, - refunded_at=func.utcnow(), + refunded_at=sa.func.utcnow(), refund_description=_("Refund: {line_item_title}").format( line_item_title=line_item.ticket.title ), @@ -795,7 +863,7 @@ def process_line_item_cancellation(line_item): @app.route('/line_item//cancel', methods=['POST']) @lastuser.requires_login @load_models((LineItem, {'id': 'line_item_id'}, 'line_item'), permission='org_admin') -def cancel_line_item(line_item: LineItem): +def cancel_line_item(line_item: LineItem) -> Response: if not line_item.is_cancellable(): return api_error( message='This ticket is not cancellable', @@ -812,7 +880,7 @@ def cancel_line_item(line_item: LineItem): ) -def process_partial_refund_for_order(data_dict): +def process_partial_refund_for_order(data_dict: Mapping[str, Any]) -> Response: order = data_dict['order'] form = data_dict['form'] request_method = data_dict['request_method'] @@ -820,7 +888,10 @@ def process_partial_refund_for_order(data_dict): if request_method == 'GET': return jsonify( form_template=render_form( - form=form, title="Partial refund", submit="Refund", with_chrome=False + form=form, + title=_("Partial refund"), + submit=_("Refund"), + with_chrome=False, ).get_data(as_text=True) ) if form.validate_on_submit(): @@ -837,7 +908,7 @@ def process_partial_refund_for_order(data_dict): online_payment=payment, currency=CurrencyEnum.INR, pg_refundid=rp_refund['id'], - refunded_at=func.utcnow(), + refunded_at=sa.func.utcnow(), ) form.populate_obj(transaction) db.session.add(transaction) @@ -888,7 +959,7 @@ def partial_refund_order(order: Order) -> ReturnRenderWith: @app.route('/api/1/ic//orders', methods=['GET', 'OPTIONS']) @app.route('/api/1/menu//orders', methods=['GET', 'OPTIONS']) @load_models((Menu, {'id': 'menu_id'}, 'menu')) -def menu_orders(menu: Menu): +def menu_orders(menu: Menu) -> Response: organization = menu.organization # TODO: Replace with a better authentication system if not request.args.get('access_token') or request.args.get( diff --git a/boxoffice/views/participant.py b/boxoffice/views/participant.py index 3ea7bbb5..05f02ede 100644 --- a/boxoffice/views/participant.py +++ b/boxoffice/views/participant.py @@ -67,7 +67,7 @@ def assign(order: Order) -> ReturnRenderWith: ticket_assignee_details = line_item.ticket.assignee_details assignee_details = {} if ticket_assignee_details: - for key in ticket_assignee_details.keys(): + for key in ticket_assignee_details: assignee_details[key] = assignee_dict.get(key) if ( line_item.current_assignee diff --git a/boxoffice/views/utils.py b/boxoffice/views/utils.py index 6ca8e295..e0be6f4d 100644 --- a/boxoffice/views/utils.py +++ b/boxoffice/views/utils.py @@ -1,27 +1,39 @@ +import csv +from collections.abc import Callable, Sequence +from datetime import datetime from functools import wraps from io import StringIO +from typing import Any, Literal, ParamSpec, TypeVar, overload from urllib.parse import urlparse -import csv from flask import Response, abort, jsonify, make_response, request +from werkzeug.wrappers import Response as BaseResponse from baseframe import localize_timezone, request_is_xhr from .. import app +_R_co = TypeVar('_R_co', covariant=True) +_P = ParamSpec('_P') + -def sanitize_coupons(coupons): +def sanitize_coupons(coupons: Any) -> list[str]: if not isinstance(coupons, list): return [] # Remove falsy values and coerce the valid values into unicode return [str(coupon_code) for coupon_code in coupons if coupon_code] -def xhr_only(f): +def request_wants_json() -> bool: + """Request wants a JSON response.""" + return request.accept_mimetypes.best == 'application/json' + + +def xhr_only(f: Callable[_P, _R_co]) -> Callable[_P, _R_co]: """Abort if a request does not have the XMLHttpRequest header set.""" @wraps(f) - def wrapper(*args, **kwargs): + def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R_co: if request.method != 'OPTIONS' and not request_is_xhr(): abort(400) return f(*args, **kwargs) @@ -29,22 +41,32 @@ def wrapper(*args, **kwargs): return wrapper -def check_api_access(api_token): +def check_api_access(api_token: Any | None) -> None: """Abort if a request does not have the correct api_token.""" if not request.args.get('api_token') or request.args.get('api_token') != api_token: abort(401) -def json_date_format(dt): - return localize_timezone(dt).isoformat() +@overload +def json_date_format(dt: None) -> None: ... + + +@overload +def json_date_format(dt: datetime) -> str: ... + + +def json_date_format(dt: datetime | None) -> str | None: + if dt is not None: + return localize_timezone(dt).isoformat() + return None @app.template_filter('longdate') -def longdate(date): +def longdate(date: datetime) -> str: return localize_timezone(date).strftime('%e %B %Y') -def basepath(url): +def basepath(url: str) -> str: """ Return the base path of a given URL. @@ -57,11 +79,12 @@ def basepath(url): """ parsed_url = urlparse(url) if not (parsed_url.scheme or parsed_url.netloc): - raise ValueError("Invalid URL") + msg = "Invalid URL" + raise ValueError(msg) return f'{parsed_url.scheme}://{parsed_url.netloc}' -def cors(f): +def cors(f: Callable[_P, BaseResponse]) -> Callable[_P, BaseResponse]: """ Add CORS headers to the decorated view function. @@ -69,27 +92,26 @@ def cors(f): of permitted domains. Eg: app.config['ALLOWED_ORIGINS'] = ['https://example.com'] """ - def add_headers(resp, origin): + def add_headers(resp: BaseResponse, origin: str) -> BaseResponse: resp.headers['Access-Control-Allow-Origin'] = origin resp.headers['Access-Control-Allow-Methods'] = 'POST, OPTIONS, GET' # echo the request's headers - resp.headers['Access-Control-Allow-Headers'] = request.headers.get( - 'Access-Control-Request-Headers' - ) + if allow_headers := request.headers.get('Access-Control-Request-Headers'): + resp.headers['Access-Control-Allow-Headers'] = allow_headers # debugging only if app.debug: resp.headers['Access-Control-Max-Age'] = '1' return resp @wraps(f) - def wrapper(*args, **kwargs): + def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> BaseResponse: origin = request.headers.get('Origin') if not origin: - # Firefox doesn't send the Origin header, so read the Referer header instead - # TODO: Remove this conditional when Firefox starts adding an Origin header - referer = request.referrer - if referer: - origin = basepath(referer) + # Firefox doesn't send the Origin header in some contexts, so read the + # Referrer header instead. + # https://wiki.mozilla.org/Security/Origin#Privacy-Sensitive_Contexts + referrer = request.referrer + origin = basepath(referrer) if referrer else 'null' if ( request.method == 'POST' @@ -100,7 +122,7 @@ def wrapper(*args, **kwargs): if request.method == 'OPTIONS': # pre-flight request, check CORS headers directly - resp = app.make_default_options_response() + resp: BaseResponse = app.make_default_options_response() else: resp = f(*args, **kwargs) return add_headers(resp, origin) @@ -108,7 +130,12 @@ def wrapper(*args, **kwargs): return wrapper -def csv_response(headers, rows, row_type=None, row_handler=None): +def csv_response( + headers: list[str], + rows: Sequence[Any], + row_type: Literal['dict'] | None = None, + row_handler: Callable[[list[Any]], list[Any] | dict[str, Any]] | None = None, +) -> Response: """ Return a CSV response given a list of headers and rows of data. @@ -117,6 +144,7 @@ def csv_response(headers, rows, row_type=None, row_handler=None): Accepts an optional row_handler function that can be used to transform the row. """ + csv_writer: Any stream = StringIO() if row_type == 'dict': csv_writer = csv.DictWriter( @@ -133,27 +161,28 @@ def csv_response(headers, rows, row_type=None, row_handler=None): return Response(stream.getvalue(), mimetype='text/csv') -def api_error(message, status_code, errors=()): +def api_error( + message: str, status_code: int = 400, errors: Sequence[str] = () +) -> Response: """ Generate a HTTP response as a JSON object for a failure scenario. - :param string message: Human readable error message to be included as part of the - JSON response - :param string message: Error message to be included as part of the JSON response - :param list errors: Error messages to be included as part of the JSON response - :param int status_code: HTTP status code to be used for the response + :param message: Human readable error message to be included as part of the JSON + response + :param errors: Error messages to be included as part of the JSON response + :param status_code: HTTP status code to be used for the response """ return make_response( jsonify(status='error', errors=errors, message=message), status_code ) -def api_success(result, doc, status_code): +def api_success(result: Any, doc: str, status_code: int = 200) -> Response: """ Generate a HTTP response as a JSON object for a success scenario. - :param any result: Top-level data to be encoded as JSON - :param string doc: Documentation to be included as part of the JSON response - :param int status_code: HTTP status code to be used for the response + :param result: Top-level data to be encoded as JSON + :param doc: Documentation to be included as part of the JSON response + :param status_code: HTTP status code to be used for the response """ return make_response(jsonify(status='ok', doc=doc, result=result), status_code) diff --git a/console.py b/console.py index 2ee8eecd..45f6a726 100644 --- a/console.py +++ b/console.py @@ -1,17 +1,17 @@ """Console script.""" -from collections.abc import Iterable -from decimal import Decimal -from typing import TypeAlias, Union -from uuid import UUID import csv import datetime import logging +from collections.abc import Iterable +from decimal import Decimal +from typing import TypeAlias +from uuid import UUID +import IPython from flask.cli import load_dotenv from flask.typing import ResponseReturnValue from isoweek import Week -import IPython load_dotenv() @@ -44,7 +44,7 @@ logging.basicConfig() logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO) -Timezone: TypeAlias = Union[str, datetime.tzinfo] +Timezone: TypeAlias = str | datetime.tzinfo def sales_by_date( @@ -260,7 +260,8 @@ def partial_refund( } order = Order.query.get(order_id) if order is None: - raise ValueError("Unknown order") + msg = "Unknown order" + raise ValueError(msg) with app.test_request_context(): process_partial_refund_for_order({'order': order, 'form': form_dict}) @@ -287,7 +288,8 @@ def resend_attendee_details_email( ) -> None: menu = Menu.query.get(menu_id) if menu is None: - raise ValueError("Unknown item collection") + msg = "Unknown item collection" + raise ValueError(msg) headers, rows = menu.fetch_all_details() attendee_name_index = headers.index('attendee_fullname') order_id_index = headers.index('order_id') @@ -305,7 +307,8 @@ def resend_attendee_details_email( def order_report(org_name: str) -> None: org = Organization.query.filter_by(name=org_name).one_or_none() if org is None: - raise ValueError("Unknown organization") + msg = "Unknown organization" + raise ValueError(msg) with open('order_report.csv', 'wb', encoding='utf-8') as csvfile: order_writer = csv.writer(csvfile, quoting=csv.QUOTE_ALL) diff --git a/dev_requirements.txt b/dev_requirements.txt index 4b2adb30..7176020b 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -1,4 +1,6 @@ +bandit mypy +ruff types-python-dateutil types-pytz types-requests diff --git a/instance/settings-sample.py b/instance/settings-sample.py index 560fae20..0510385c 100644 --- a/instance/settings-sample.py +++ b/instance/settings-sample.py @@ -5,7 +5,7 @@ #: Database backend SQLALCHEMY_DATABASE_URI = 'sqlite:///test.db' #: Secret key -SECRET_KEY = 'make this something random' # nosec +SECRET_KEY = 'make this something random' # nosec # noqa: S105 #: Timezone for displayed datetimes TIMEZONE = 'Asia/Kolkata' #: Mail settings diff --git a/instance/testing.py b/instance/testing.py index 107eb7ec..0f920ee6 100644 --- a/instance/testing.py +++ b/instance/testing.py @@ -1,7 +1,7 @@ import os #: Database backend -SECRET_KEY = 'testkey' # nosec +SECRET_KEY = 'testkey' # noqa: S105 # nosec SQLALCHEMY_DATABASE_URI = 'postgresql+psycopg:///boxoffice_testing' # nosec SERVER_NAME = 'boxoffice.test:6500' BASE_URL = 'http://' + SERVER_NAME diff --git a/migrations/env.py b/migrations/env.py index 53cc024a..cb51c23e 100644 --- a/migrations/env.py +++ b/migrations/env.py @@ -1,7 +1,9 @@ -from logging.config import fileConfig import logging +from logging.config import fileConfig +from typing import Any from alembic import context +from alembic.operations import MigrationScript from flask import current_app from sqlalchemy import engine_from_config, pool @@ -33,8 +35,9 @@ # ... etc. -def run_migrations_offline(): - """Run migrations in 'offline' mode. +def run_migrations_offline() -> None: + """ + Run migrations in 'offline' mode. This configures the context with just a URL and not an Engine, though an Engine is acceptable @@ -52,8 +55,9 @@ def run_migrations_offline(): context.run_migrations() -def run_migrations_online(): - """Run migrations in 'online' mode. +def run_migrations_online() -> None: + """ + Run migrations in 'online' mode. In this scenario we need to create an Engine and associate a connection with the context. @@ -63,17 +67,18 @@ def run_migrations_online(): # this callback is used to prevent an auto-migration from being generated # when there are no changes to the schema # reference: http://alembic.readthedocs.org/en/latest/cookbook.html - def process_revision_directives(context, revision, directives): + def process_revision_directives( + _context: Any, _revision: Any, directives: list[MigrationScript] + ) -> None: if getattr(config.cmd_opts, 'autogenerate', False): script = directives[0] - if script.upgrade_ops.is_empty(): + if not script.upgrade_ops or script.upgrade_ops.is_empty(): directives[:] = [] logger.info('No changes in schema detected.') + use_config = config.get_section(config.config_ini_section) or {} engine = engine_from_config( - config.get_section(config.config_ini_section), - prefix='sqlalchemy.', - poolclass=pool.NullPool, + use_config, prefix='sqlalchemy.', poolclass=pool.NullPool ) connection = engine.connect() diff --git a/migrations/versions/10ac78260434_add_index_to_line_item.py b/migrations/versions/10ac78260434_add_index_to_line_item.py index 2bd81bd5..99fec7e2 100644 --- a/migrations/versions/10ac78260434_add_index_to_line_item.py +++ b/migrations/versions/10ac78260434_add_index_to_line_item.py @@ -1,4 +1,5 @@ -"""add index to line_item. +""" +add index to line_item. Revision ID: 10ac78260434 Revises: 50f34bb47dc4 @@ -13,7 +14,7 @@ down_revision = '50f34bb47dc4' -def upgrade(): +def upgrade() -> None: op.create_index( op.f('ix_line_item_customer_order_id'), 'line_item', @@ -37,7 +38,7 @@ def upgrade(): ) -def downgrade(): +def downgrade() -> None: op.drop_index(op.f('ix_line_item_customer_order_id'), table_name='line_item') op.drop_index(op.f('ix_line_item_item_id'), table_name='line_item') op.drop_index(op.f('ix_line_item_discount_policy_id'), table_name='line_item') diff --git a/migrations/versions/11bc47d6d60b_add_cancellable_until.py b/migrations/versions/11bc47d6d60b_add_cancellable_until.py index 0e7fe6d0..22e7c658 100644 --- a/migrations/versions/11bc47d6d60b_add_cancellable_until.py +++ b/migrations/versions/11bc47d6d60b_add_cancellable_until.py @@ -1,4 +1,5 @@ -"""add cancellable_until. +""" +add cancellable_until. Revision ID: 11bc47d6d60b Revises: dadc5748932 @@ -6,17 +7,17 @@ """ -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision = '11bc47d6d60b' down_revision = 'dadc5748932' -def upgrade(): +def upgrade() -> None: op.add_column('item', sa.Column('cancellable_until', sa.DateTime(), nullable=True)) -def downgrade(): +def downgrade() -> None: op.drop_column('item', 'cancellable_until') diff --git a/migrations/versions/171fcb171759_add_previous_to_line_item.py b/migrations/versions/171fcb171759_add_previous_to_line_item.py index 9932550f..d2794115 100644 --- a/migrations/versions/171fcb171759_add_previous_to_line_item.py +++ b/migrations/versions/171fcb171759_add_previous_to_line_item.py @@ -1,4 +1,5 @@ -"""add_previous_to_line_item. +""" +add_previous_to_line_item. Revision ID: 171fcb171759 Revises: 81f30d00706f @@ -6,16 +7,16 @@ """ +import sqlalchemy as sa from alembic import op from sqlalchemy.dialects import postgresql -import sqlalchemy as sa # revision identifiers, used by Alembic. revision = '171fcb171759' down_revision = '81f30d00706f' -def upgrade(): +def upgrade() -> None: op.add_column( 'line_item', sa.Column('previous_id', postgresql.UUID(), nullable=True), @@ -28,7 +29,7 @@ def upgrade(): ) -def downgrade(): +def downgrade() -> None: op.drop_constraint('line_item_id_fkey', 'line_item', type_='foreignkey') op.drop_index(op.f('ix_line_item_previous_id'), table_name='line_item') op.drop_column('line_item', 'previous_id') diff --git a/migrations/versions/18576fdffd86_add_notes_and_refund_timestamp_to_.py b/migrations/versions/18576fdffd86_add_notes_and_refund_timestamp_to_.py index 37184119..53a683a7 100644 --- a/migrations/versions/18576fdffd86_add_notes_and_refund_timestamp_to_.py +++ b/migrations/versions/18576fdffd86_add_notes_and_refund_timestamp_to_.py @@ -1,4 +1,5 @@ -"""add_notes_and_refund_timestamp_to_transaction. +""" +add_notes_and_refund_timestamp_to_transaction. Revision ID: 18576fdffd86 Revises: 4246213b032b @@ -6,15 +7,15 @@ """ -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision = '18576fdffd86' down_revision = '4246213b032b' -def upgrade(): +def upgrade() -> None: op.add_column( 'payment_transaction', sa.Column('internal_note', sa.Unicode(length=250), nullable=True), @@ -36,7 +37,7 @@ def upgrade(): ) -def downgrade(): +def downgrade() -> None: op.drop_column('payment_transaction', 'refunded_at') op.drop_column('payment_transaction', 'note_to_user_text') op.drop_column('payment_transaction', 'note_to_user_html') diff --git a/migrations/versions/1a22f5035244_create_invoice.py b/migrations/versions/1a22f5035244_create_invoice.py index f33f1510..47e2a09f 100644 --- a/migrations/versions/1a22f5035244_create_invoice.py +++ b/migrations/versions/1a22f5035244_create_invoice.py @@ -1,4 +1,5 @@ -"""create_invoice. +""" +create_invoice. Revision ID: 1a22f5035244 Revises: 36f458047cfd @@ -6,16 +7,16 @@ """ +import sqlalchemy as sa from alembic import op from sqlalchemy.dialects import postgresql -import sqlalchemy as sa # revision identifiers, used by Alembic. revision = '1a22f5035244' down_revision = '36f458047cfd' -def upgrade(): +def upgrade() -> None: op.create_table( 'invoice', sa.Column('created_at', sa.DateTime(), nullable=False), @@ -64,7 +65,7 @@ def upgrade(): ) -def downgrade(): +def downgrade() -> None: op.drop_column('item_collection', 'tax_type') op.drop_index(op.f('ix_invoice_customer_order_id'), table_name='invoice') op.drop_table('invoice') diff --git a/migrations/versions/1ea1e8070ac8_add_assignee.py b/migrations/versions/1ea1e8070ac8_add_assignee.py index 5e690f63..059cfdb4 100644 --- a/migrations/versions/1ea1e8070ac8_add_assignee.py +++ b/migrations/versions/1ea1e8070ac8_add_assignee.py @@ -1,4 +1,5 @@ -"""add assignee. +""" +add assignee. Revision ID: 1ea1e8070ac8 Revises: adb90a264e3 @@ -6,16 +7,16 @@ """ +import sqlalchemy as sa from alembic import op from sqlalchemy.dialects import postgresql -import sqlalchemy as sa # revision identifiers, used by Alembic. revision = '1ea1e8070ac8' down_revision = 'adb90a264e3' -def upgrade(): +def upgrade() -> None: op.create_table( 'assignee', sa.Column('created_at', sa.DateTime(), nullable=False), @@ -51,7 +52,7 @@ def upgrade(): op.create_index(op.f('assignee_email_key'), 'assignee', ['email'], unique=True) -def downgrade(): +def downgrade() -> None: op.drop_index(op.f('assignee_email_key'), table_name='assignee') op.drop_constraint('line_item_assignee_id_fkey', 'line_item', type_='foreignkey') op.drop_column('item', 'assignee_details') diff --git a/migrations/versions/23fc9e293ac3_add_invoice_organization.py b/migrations/versions/23fc9e293ac3_add_invoice_organization.py index 7de5bdb2..4a9757c7 100644 --- a/migrations/versions/23fc9e293ac3_add_invoice_organization.py +++ b/migrations/versions/23fc9e293ac3_add_invoice_organization.py @@ -1,4 +1,5 @@ -"""add_invoice_organization. +""" +add_invoice_organization. Revision ID: 23fc9e293ac3 Revises: 66b67130c901 @@ -6,15 +7,15 @@ """ -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision = '23fc9e293ac3' down_revision = '66b67130c901' -def upgrade(): +def upgrade() -> None: op.add_column('organization', sa.Column('invoicer_id', sa.Integer(), nullable=True)) op.create_foreign_key( 'organization_invoicer_id_id_fkey', @@ -25,7 +26,7 @@ def upgrade(): ) -def downgrade(): +def downgrade() -> None: op.drop_constraint( 'organization_invoicer_id_id_fkey', 'organization', type_='foreignkey' ) diff --git a/migrations/versions/253e7b76eb8e_modify_assignee.py b/migrations/versions/253e7b76eb8e_modify_assignee.py index 6d287488..3aba4b7d 100644 --- a/migrations/versions/253e7b76eb8e_modify_assignee.py +++ b/migrations/versions/253e7b76eb8e_modify_assignee.py @@ -1,4 +1,5 @@ -"""modify assignee. +""" +modify assignee. Revision ID: 253e7b76eb8e Revises: 1ea1e8070ac8 @@ -6,16 +7,16 @@ """ +import sqlalchemy as sa from alembic import op from sqlalchemy.dialects import postgresql -import sqlalchemy as sa # revision identifiers, used by Alembic. revision = '253e7b76eb8e' down_revision = '1ea1e8070ac8' -def upgrade(): +def upgrade() -> None: op.add_column('assignee', sa.Column('current', sa.Boolean(), nullable=True)) op.create_check_constraint('assignee_current_check', 'assignee', "current != '0'") op.add_column( @@ -39,7 +40,7 @@ def upgrade(): op.drop_column('line_item', 'assignee_id') -def downgrade(): +def downgrade() -> None: op.add_column( 'line_item', sa.Column('assignee_id', sa.INTEGER(), autoincrement=False, nullable=True), diff --git a/migrations/versions/27b5ed98d7d0_add_contact_email_to_order.py b/migrations/versions/27b5ed98d7d0_add_contact_email_to_order.py index 5a3f8aee..4b46cd81 100644 --- a/migrations/versions/27b5ed98d7d0_add_contact_email_to_order.py +++ b/migrations/versions/27b5ed98d7d0_add_contact_email_to_order.py @@ -1,4 +1,5 @@ -"""add contact_email to order. +""" +add contact_email to order. Revision ID: 27b5ed98d7d0 Revises: 4ffee334e82e @@ -6,15 +7,15 @@ """ -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision = '27b5ed98d7d0' down_revision = '4ffee334e82e' -def upgrade(): +def upgrade() -> None: op.add_column( 'organization', sa.Column('contact_email', sa.Unicode(length=254), nullable=True), @@ -24,6 +25,6 @@ def upgrade(): ) -def downgrade(): +def downgrade() -> None: op.drop_constraint('organization_contact_email_key', 'organization', type_='unique') op.drop_column('organization', 'contact_email') diff --git a/migrations/versions/32abb3608d9a_simplify_discounts.py b/migrations/versions/32abb3608d9a_simplify_discounts.py index d66cfefa..5fe9f16c 100644 --- a/migrations/versions/32abb3608d9a_simplify_discounts.py +++ b/migrations/versions/32abb3608d9a_simplify_discounts.py @@ -1,4 +1,5 @@ -"""simplify_discounts. +""" +simplify_discounts. Revision ID: 32abb3608d9a Revises: 45de268cd444 @@ -6,16 +7,16 @@ """ +import sqlalchemy as sa from alembic import op from sqlalchemy.sql import column, table -import sqlalchemy as sa # revision identifiers, used by Alembic. revision = '32abb3608d9a' down_revision = '45de268cd444' -def upgrade(): +def upgrade() -> None: discount_coupon = table( 'discount_coupon', column('quantity_available', sa.Integer()), @@ -43,7 +44,7 @@ def upgrade(): op.drop_column('discount_policy', 'item_quantity_max') -def downgrade(): +def downgrade() -> None: discount_coupon = table( 'discount_coupon', column('quantity_total', sa.Integer()), diff --git a/migrations/versions/35952a56c31b_make_is_price_based_non_nullable.py b/migrations/versions/35952a56c31b_make_is_price_based_non_nullable.py index cc460bb1..bc55f25b 100644 --- a/migrations/versions/35952a56c31b_make_is_price_based_non_nullable.py +++ b/migrations/versions/35952a56c31b_make_is_price_based_non_nullable.py @@ -1,4 +1,5 @@ -"""make is_price_based non nullable. +""" +make is_price_based non nullable. Revision ID: 35952a56c31b Revises: 32abb3608d9a @@ -6,9 +7,9 @@ """ +import sqlalchemy as sa from alembic import op from sqlalchemy.sql import column, table -import sqlalchemy as sa # revision identifiers, used by Alembic. revision = '35952a56c31b' @@ -18,7 +19,7 @@ discount_policy = table('discount_policy', column('is_price_based', sa.Boolean())) -def upgrade(): +def upgrade() -> None: op.execute( discount_policy.update() .where(discount_policy.c.is_price_based.is_(None)) @@ -29,7 +30,7 @@ def upgrade(): ) -def downgrade(): +def downgrade() -> None: op.alter_column( 'discount_policy', 'is_price_based', existing_type=sa.BOOLEAN(), nullable=True ) diff --git a/migrations/versions/36f458047cfd_update_index_on_discount_title.py b/migrations/versions/36f458047cfd_update_index_on_discount_title.py index f46b057e..05fc873a 100644 --- a/migrations/versions/36f458047cfd_update_index_on_discount_title.py +++ b/migrations/versions/36f458047cfd_update_index_on_discount_title.py @@ -1,4 +1,5 @@ -"""update_index_on_discount_policy_discount_code_base. +""" +update_index_on_discount_policy_discount_code_base. Revision ID: 36f458047cfd Revises: 3a585b8d5f8d @@ -13,7 +14,7 @@ down_revision = '3a585b8d5f8d' -def upgrade(): +def upgrade() -> None: op.drop_constraint( 'discount_policy_discount_code_base_key', 'discount_policy', type_='unique' ) @@ -24,7 +25,7 @@ def upgrade(): ) -def downgrade(): +def downgrade() -> None: op.drop_constraint( 'discount_policy_organization_id_discount_code_base_key', 'discount_policy', diff --git a/migrations/versions/3a585b8d5f8d_add_trigram_index_for_discount_policy_.py b/migrations/versions/3a585b8d5f8d_add_trigram_index_for_discount_policy_.py index da8d5831..2fa57d1b 100644 --- a/migrations/versions/3a585b8d5f8d_add_trigram_index_for_discount_policy_.py +++ b/migrations/versions/3a585b8d5f8d_add_trigram_index_for_discount_policy_.py @@ -1,4 +1,5 @@ -"""add_trigram_index_for_discount_policy_title. +""" +add_trigram_index_for_discount_policy_title. Revision ID: 3a585b8d5f8d Revises: 4246213b032b @@ -6,15 +7,15 @@ """ -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision = '3a585b8d5f8d' down_revision = '18576fdffd86' -def upgrade(): +def upgrade() -> None: op.execute( sa.DDL( ''' @@ -26,5 +27,5 @@ def upgrade(): ) -def downgrade(): +def downgrade() -> None: op.drop_index('idx_discount_policy_title_trgm', 'discount_policy') diff --git a/migrations/versions/4246213b032b_bulk_discount_coupon_usage_limit_check.py b/migrations/versions/4246213b032b_bulk_discount_coupon_usage_limit_check.py index 0d5b1ad5..ed5aa390 100644 --- a/migrations/versions/4246213b032b_bulk_discount_coupon_usage_limit_check.py +++ b/migrations/versions/4246213b032b_bulk_discount_coupon_usage_limit_check.py @@ -1,4 +1,5 @@ -"""bulk_discount_coupon_usage_limit_check. +""" +bulk_discount_coupon_usage_limit_check. Revision ID: 4246213b032b Revises: 58f4f3c4fb01 @@ -6,9 +7,9 @@ """ +import sqlalchemy as sa from alembic import op from sqlalchemy.sql import column, table -import sqlalchemy as sa # revision identifiers, used by Alembic. revision = '4246213b032b' @@ -22,7 +23,7 @@ ) -def upgrade(): +def upgrade() -> None: op.execute( discount_policy.update() .where(discount_policy.c.discount_type == 1) @@ -37,7 +38,7 @@ def upgrade(): ) -def downgrade(): +def downgrade() -> None: op.drop_constraint( 'discount_policy_bulk_coupon_usage_limit_check', 'discount_policy' ) diff --git a/migrations/versions/45de268cd444_add_discount_policy_id_to_price.py b/migrations/versions/45de268cd444_add_discount_policy_id_to_price.py index 567539e9..e15b77dc 100644 --- a/migrations/versions/45de268cd444_add_discount_policy_id_to_price.py +++ b/migrations/versions/45de268cd444_add_discount_policy_id_to_price.py @@ -1,4 +1,5 @@ -"""add discount_policy_id to price. +""" +add discount_policy_id to price. Revision ID: 45de268cd444 Revises: 4d7f840202d2 @@ -6,16 +7,16 @@ """ +import sqlalchemy as sa from alembic import op from sqlalchemy.dialects import postgresql -import sqlalchemy as sa # revision identifiers, used by Alembic. revision = '45de268cd444' down_revision = '4d7f840202d2' -def upgrade(): +def upgrade() -> None: op.add_column( 'price', sa.Column('discount_policy_id', postgresql.UUID(), nullable=True), @@ -40,7 +41,7 @@ def upgrade(): ) -def downgrade(): +def downgrade() -> None: op.drop_constraint('price_discount_policy_id_fkey', 'price', type_='foreignkey') op.drop_constraint('price_item_id_discount_policy_id_key', 'price', type_='unique') op.drop_column('price', 'discount_policy_id') diff --git a/migrations/versions/48e571c759cb_add_order_session.py b/migrations/versions/48e571c759cb_add_order_session.py index cb7ec468..7d830686 100644 --- a/migrations/versions/48e571c759cb_add_order_session.py +++ b/migrations/versions/48e571c759cb_add_order_session.py @@ -1,4 +1,5 @@ -"""add_order_session. +""" +add_order_session. Revision ID: 48e571c759cb Revises: 510d2db2b008 @@ -6,16 +7,16 @@ """ +import sqlalchemy as sa from alembic import op from sqlalchemy.dialects import postgresql -import sqlalchemy as sa # revision identifiers, used by Alembic. revision = '48e571c759cb' down_revision = '510d2db2b008' -def upgrade(): +def upgrade() -> None: op.create_table( 'order_session', sa.Column('created_at', sa.DateTime(), nullable=False), @@ -68,7 +69,7 @@ def upgrade(): ) -def downgrade(): +def downgrade() -> None: op.drop_index(op.f('ix_order_session_utm_source'), table_name='order_session') op.drop_index(op.f('ix_order_session_utm_medium'), table_name='order_session') op.drop_index(op.f('ix_order_session_utm_id'), table_name='order_session') diff --git a/migrations/versions/4d7f840202d2_make_contact_email_non_nullable.py b/migrations/versions/4d7f840202d2_make_contact_email_non_nullable.py index b57f6cf5..0312dfb2 100644 --- a/migrations/versions/4d7f840202d2_make_contact_email_non_nullable.py +++ b/migrations/versions/4d7f840202d2_make_contact_email_non_nullable.py @@ -1,4 +1,5 @@ -"""make contact_email non nullable. +""" +make contact_email non nullable. Revision ID: 4d7f840202d2 Revises: 27b5ed98d7d0 @@ -6,15 +7,15 @@ """ -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision = '4d7f840202d2' down_revision = '27b5ed98d7d0' -def upgrade(): +def upgrade() -> None: op.alter_column( 'organization', 'contact_email', @@ -23,7 +24,7 @@ def upgrade(): ) -def downgrade(): +def downgrade() -> None: op.alter_column( 'organization', 'contact_email', diff --git a/migrations/versions/4ffee334e82e_init_models.py b/migrations/versions/4ffee334e82e_init_models.py index d7e8590a..036c6eee 100644 --- a/migrations/versions/4ffee334e82e_init_models.py +++ b/migrations/versions/4ffee334e82e_init_models.py @@ -1,4 +1,5 @@ -"""init models. +""" +init models. Revision ID: 4ffee334e82e Revises: None @@ -6,15 +7,15 @@ """ +import sqlalchemy as sa from alembic import op from sqlalchemy.dialects import postgresql -import sqlalchemy as sa revision = '4ffee334e82e' down_revision: str | None = None -def upgrade(): +def upgrade() -> None: op.create_table( 'organization', sa.Column('created_at', sa.DateTime(), nullable=False), @@ -300,7 +301,7 @@ def upgrade(): ) -def downgrade(): +def downgrade() -> None: op.drop_table('price') op.drop_table('line_item') op.drop_table('item_discount_policy') diff --git a/migrations/versions/50b7f36bb7eb_add_host_to_order_session.py b/migrations/versions/50b7f36bb7eb_add_host_to_order_session.py index c39d3065..c3ef3123 100644 --- a/migrations/versions/50b7f36bb7eb_add_host_to_order_session.py +++ b/migrations/versions/50b7f36bb7eb_add_host_to_order_session.py @@ -1,4 +1,5 @@ -"""add_host_to_order_session. +""" +add_host_to_order_session. Revision ID: 50b7f36bb7eb Revises: ca40e4eda72c @@ -6,17 +7,17 @@ """ -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision = '50b7f36bb7eb' down_revision = 'ca40e4eda72c' -def upgrade(): +def upgrade() -> None: op.add_column('order_session', sa.Column('host', sa.UnicodeText(), nullable=True)) -def downgrade(): +def downgrade() -> None: op.drop_column('order_session', 'host') diff --git a/migrations/versions/50f34bb47dc4_add_usage_limit_to_discount_coupon.py b/migrations/versions/50f34bb47dc4_add_usage_limit_to_discount_coupon.py index afcb3132..394e1f53 100644 --- a/migrations/versions/50f34bb47dc4_add_usage_limit_to_discount_coupon.py +++ b/migrations/versions/50f34bb47dc4_add_usage_limit_to_discount_coupon.py @@ -1,4 +1,5 @@ -"""add usage limit to discount coupon. +""" +add usage limit to discount coupon. Revision ID: 50f34bb47dc4 Revises: 2f6a3bb460b8 @@ -6,10 +7,10 @@ """ +import sqlalchemy as sa from alembic import op from sqlalchemy.dialects import postgresql from sqlalchemy.sql import column, table -import sqlalchemy as sa # revision identifiers, used by Alembic. revision = '50f34bb47dc4' @@ -30,7 +31,7 @@ ) -def upgrade(): +def upgrade() -> None: op.add_column( 'discount_coupon', sa.Column('usage_limit', sa.Integer(), nullable=False, server_default='1'), @@ -49,7 +50,7 @@ def upgrade(): op.drop_column('discount_coupon', 'used') -def downgrade(): +def downgrade() -> None: op.add_column( 'discount_coupon', sa.Column('used', sa.Boolean(), nullable=False, server_default='0'), diff --git a/migrations/versions/510d2db2b008_add_signed_code_support.py b/migrations/versions/510d2db2b008_add_signed_code_support.py index 813b5337..852cdc57 100644 --- a/migrations/versions/510d2db2b008_add_signed_code_support.py +++ b/migrations/versions/510d2db2b008_add_signed_code_support.py @@ -1,4 +1,5 @@ -"""add signed code support. +""" +add signed code support. Revision ID: 510d2db2b008 Revises: 74770336785 @@ -6,15 +7,15 @@ """ -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision = '510d2db2b008' down_revision = '74770336785' -def upgrade(): +def upgrade() -> None: # Note: discount_code_base was added in the initial migration (4ffee334e82e), # but was removed from the model for a period of time op.create_unique_constraint( @@ -33,7 +34,7 @@ def upgrade(): ) -def downgrade(): +def downgrade() -> None: op.drop_column('discount_policy', 'secret') op.drop_constraint( 'discount_policy_discount_code_base_key', 'discount_policy', type_='unique' diff --git a/migrations/versions/58f4f3c4fb01_add_bulk_coupon_usage_limit.py b/migrations/versions/58f4f3c4fb01_add_bulk_coupon_usage_limit.py index 94c41e29..f8ca45d7 100644 --- a/migrations/versions/58f4f3c4fb01_add_bulk_coupon_usage_limit.py +++ b/migrations/versions/58f4f3c4fb01_add_bulk_coupon_usage_limit.py @@ -1,4 +1,5 @@ -"""add_bulk_coupon_usage_limit. +""" +add_bulk_coupon_usage_limit. Revision ID: 58f4f3c4fb01 Revises: 48e571c759cb @@ -6,20 +7,20 @@ """ -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision = '58f4f3c4fb01' down_revision = '48e571c759cb' -def upgrade(): +def upgrade() -> None: op.add_column( 'discount_policy', sa.Column('bulk_coupon_usage_limit', sa.Integer(), nullable=True), ) -def downgrade(): +def downgrade() -> None: op.drop_column('discount_policy', 'bulk_coupon_usage_limit') diff --git a/migrations/versions/59d274a1682f_added_transaction_ref.py b/migrations/versions/59d274a1682f_added_transaction_ref.py index 309e6874..20e76477 100644 --- a/migrations/versions/59d274a1682f_added_transaction_ref.py +++ b/migrations/versions/59d274a1682f_added_transaction_ref.py @@ -1,4 +1,5 @@ -"""added_transaction_ref. +""" +added_transaction_ref. Revision ID: 59d274a1682f Revises: 11bc47d6d60b @@ -6,20 +7,20 @@ """ -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision = '59d274a1682f' down_revision = '11bc47d6d60b' -def upgrade(): +def upgrade() -> None: op.add_column( 'payment_transaction', sa.Column('transaction_ref', sa.Unicode(length=80), nullable=True), ) -def downgrade(): +def downgrade() -> None: op.drop_column('payment_transaction', 'transaction_ref') diff --git a/migrations/versions/66b67130c901_retroactively_migrate_previous_id_for_.py b/migrations/versions/66b67130c901_retroactively_migrate_previous_id_for_.py index c4352642..f1896f5e 100644 --- a/migrations/versions/66b67130c901_retroactively_migrate_previous_id_for_.py +++ b/migrations/versions/66b67130c901_retroactively_migrate_previous_id_for_.py @@ -1,4 +1,5 @@ -"""retroactively_migrate_previous_id_for_line_item. +""" +retroactively_migrate_previous_id_for_line_item. Revision ID: 66b67130c901 Revises: 171fcb171759 @@ -7,11 +8,14 @@ """ from collections import OrderedDict +from collections.abc import Iterable +from datetime import datetime +from typing import Any, Final +import sqlalchemy as sa from alembic import op from sqlalchemy.dialects import postgresql from sqlalchemy.sql import column, table -import sqlalchemy as sa # revision identifiers, used by Alembic. revision = '66b67130c901' @@ -31,19 +35,20 @@ class ORDER_STATUS: # noqa: N801 INVOICE = 2 CANCELLED = 3 - TRANSACTION = {SALES_ORDER, INVOICE, CANCELLED} + TRANSACTION: Final = {SALES_ORDER, INVOICE, CANCELLED} -def find_nearest_timestamp(lst, timestamp): +def find_nearest_timestamp(lst: list[datetime], timestamp: datetime) -> datetime | None: if not lst: return None nearest_ts = min(lst, key=lambda ts: abs(ts - timestamp).total_seconds()) if abs(nearest_ts - timestamp).total_seconds() < 1: return nearest_ts + return None -def set_previous_keys_for_line_items(line_items): - timestamped_line_items = OrderedDict() +def set_previous_keys_for_line_items(line_items: Iterable[Any]) -> list[dict]: + timestamped_line_items: OrderedDict[datetime, list[Any]] = OrderedDict() # Assemble the `timestamped_line_items` dictionary with the timestamp at which the # line items were created as the key, and the line items that were created at that @@ -72,7 +77,7 @@ def set_previous_keys_for_line_items(line_items): # timestamp with a void status with the same item_id. Find it and set it used_line_item_ids = set() for idx, (timestamp, _line_item_dicts) in enumerate( - timestamped_line_items.items()[1:] + list(timestamped_line_items.items())[1:] ): # 0th timestamps are root line items, so they're skipped since they don't need # their `previous_id` to be updated @@ -80,7 +85,7 @@ def set_previous_keys_for_line_items(line_items): # timestamped_line_items.keys()[idx] and not # timestamped_line_items.keys()[idx-1] because the # timestamped_line_items.items() list is enumerated from index 1 - previous_li_dict = [ + previous_li_dict = next( previous_li_dict for previous_li_dict in timestamped_line_items[ list(timestamped_line_items.keys())[idx] @@ -88,7 +93,7 @@ def set_previous_keys_for_line_items(line_items): if previous_li_dict['item_id'] == li_dict['item_id'] and previous_li_dict['id'] not in used_line_item_ids and previous_li_dict['status'] == LINE_ITEM_STATUS.VOID - ][0] + ) li_dict['previous_id'] = previous_li_dict['id'] used_line_item_ids.add(previous_li_dict['id']) @@ -114,7 +119,7 @@ def set_previous_keys_for_line_items(line_items): ) -def upgrade(): +def upgrade() -> None: conn = op.get_bind() orders = conn.execute( sa.select(order_table.c.id) @@ -124,12 +129,10 @@ def upgrade(): for order_id in [order.id for order in orders]: line_items = conn.execute( sa.select( - [ - line_item_table.c.id, - line_item_table.c.item_id, - line_item_table.c.status, - line_item_table.c.created_at, - ] + line_item_table.c.id, + line_item_table.c.item_id, + line_item_table.c.status, + line_item_table.c.created_at, ) .where(line_item_table.c.customer_order_id == order_id) .order_by(sa.text('created_at')) @@ -145,7 +148,7 @@ def upgrade(): ) -def downgrade(): +def downgrade() -> None: conn = op.get_bind() orders = conn.execute( sa.select(order_table.c.id) diff --git a/migrations/versions/6c04555d7d94_add_restricted_entry_to_item.py b/migrations/versions/6c04555d7d94_add_restricted_entry_to_item.py index 84f15d09..62070846 100644 --- a/migrations/versions/6c04555d7d94_add_restricted_entry_to_item.py +++ b/migrations/versions/6c04555d7d94_add_restricted_entry_to_item.py @@ -1,4 +1,5 @@ -"""add_restricted_entry_to_item. +""" +add_restricted_entry_to_item. Revision ID: 6c04555d7d94 Revises: 829f42c03de3 @@ -6,15 +7,15 @@ """ +import sqlalchemy as sa from alembic import op from sqlalchemy.sql import column, table -import sqlalchemy as sa revision = '6c04555d7d94' down_revision = '829f42c03de3' -def upgrade(): +def upgrade() -> None: item_table = table('item', column('restricted_entry', sa.Boolean())) op.add_column('item', sa.Column('restricted_entry', sa.Boolean(), nullable=True)) @@ -24,5 +25,5 @@ def upgrade(): ) -def downgrade(): +def downgrade() -> None: op.drop_column('item', 'restricted_entry') diff --git a/migrations/versions/74770336785_add_seq_to_item.py b/migrations/versions/74770336785_add_seq_to_item.py index 9856993d..c3a4102b 100644 --- a/migrations/versions/74770336785_add_seq_to_item.py +++ b/migrations/versions/74770336785_add_seq_to_item.py @@ -1,4 +1,5 @@ -"""add seq to item. +""" +add seq to item. Revision ID: 74770336785 Revises: 59d274a1682f @@ -10,10 +11,10 @@ revision = '74770336785' down_revision = '59d274a1682f' +import sqlalchemy as sa from alembic import op from sqlalchemy.dialects import postgresql from sqlalchemy.sql import column, table -import sqlalchemy as sa item_collection = table('item_collection', column('id', postgresql.UUID())) @@ -26,7 +27,7 @@ ) -def upgrade(): +def upgrade() -> None: op.add_column('item', sa.Column('seq', sa.Integer(), nullable=True)) connection = op.get_bind() item_collections = connection.execute(sa.select(item_collection.c.id)) @@ -45,5 +46,5 @@ def upgrade(): op.alter_column('item', 'seq', existing_type=sa.Integer(), nullable=False) -def downgrade(): +def downgrade() -> None: op.drop_column('item', 'seq') diff --git a/migrations/versions/7d180b95fcbe_sync_models_and_database.py b/migrations/versions/7d180b95fcbe_sync_models_and_database.py index c3a19b81..387305ea 100644 --- a/migrations/versions/7d180b95fcbe_sync_models_and_database.py +++ b/migrations/versions/7d180b95fcbe_sync_models_and_database.py @@ -1,4 +1,5 @@ -"""Sync models and database. +""" +Sync models and database. Revision ID: 7d180b95fcbe Revises: f78ca4cad5d6 @@ -6,8 +7,8 @@ """ -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision = '7d180b95fcbe' @@ -26,7 +27,7 @@ ] -def upgrade(): +def upgrade() -> None: for table, oldname, newname in renames: op.execute( sa.DDL( @@ -71,7 +72,7 @@ def upgrade(): ) -def downgrade(): +def downgrade() -> None: op.drop_constraint('line_item_previous_id_key', 'line_item', type_='unique') op.create_index( 'ix_line_item_previous_id', 'line_item', ['previous_id'], unique=True diff --git a/migrations/versions/81f30d00706f_add_pg_refundid_to_transaction.py b/migrations/versions/81f30d00706f_add_pg_refundid_to_transaction.py index 009f86d4..20040d04 100644 --- a/migrations/versions/81f30d00706f_add_pg_refundid_to_transaction.py +++ b/migrations/versions/81f30d00706f_add_pg_refundid_to_transaction.py @@ -1,4 +1,5 @@ -"""add_pg_refundid_to_transaction. +""" +add_pg_refundid_to_transaction. Revision ID: 81f30d00706f Revises: 1a22f5035244 @@ -6,15 +7,15 @@ """ -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision = '81f30d00706f' down_revision = '1a22f5035244' -def upgrade(): +def upgrade() -> None: op.add_column( 'payment_transaction', sa.Column('pg_refundid', sa.Unicode(length=80), nullable=True), @@ -33,7 +34,7 @@ def upgrade(): ) -def downgrade(): +def downgrade() -> None: op.drop_index( op.f('ix_payment_transaction_pg_refundid'), table_name='payment_transaction' ) diff --git a/migrations/versions/829f42c03de3_add_fy_to_invoice.py b/migrations/versions/829f42c03de3_add_fy_to_invoice.py index 915231c1..d471a2a6 100644 --- a/migrations/versions/829f42c03de3_add_fy_to_invoice.py +++ b/migrations/versions/829f42c03de3_add_fy_to_invoice.py @@ -1,4 +1,5 @@ -"""add_fy_to_invoice. +""" +add_fy_to_invoice. Revision ID: 829f42c03de3 Revises: 23fc9e293ac3 @@ -8,37 +9,36 @@ from datetime import datetime +import pytz +import sqlalchemy as sa from alembic import op from sqlalchemy.dialects import postgresql from sqlalchemy.sql import column, table -import pytz -import sqlalchemy as sa # revision identifiers, used by Alembic. revision = '829f42c03de3' down_revision = '23fc9e293ac3' -def naive_to_utc(dt, timezone=None): +def naive_to_utc( + dt: datetime, timezone: str | pytz.BaseTzInfo | None = None +) -> datetime: """ Return a UTC datetime for a given naive datetime or date object. Localizes it to the given timezone and converts it into a UTC datetime """ if timezone: - if isinstance(timezone, str): - tz = pytz.timezone(timezone) - else: - tz = timezone + tz = pytz.timezone(timezone) if isinstance(timezone, str) else timezone elif isinstance(dt, datetime) and dt.tzinfo: - tz = dt.tzinfo + tz = dt.tzinfo # type: ignore[assignment] else: tz = pytz.UTC return tz.localize(dt).astimezone(tz).astimezone(pytz.UTC) -def get_fiscal_year(jurisdiction, dt): +def get_fiscal_year(jurisdiction: str, dt: datetime) -> tuple[datetime, datetime]: """ Return the financial year for a given jurisdiction and timestamp. @@ -50,10 +50,7 @@ def get_fiscal_year(jurisdiction, dt): get_fiscal_year('IN', utcnow()) """ if jurisdiction.lower() == 'in': - if dt.month < 4: - start_year = dt.year - 1 - else: - start_year = dt.year + start_year = dt.year - 1 if dt.month < 4 else dt.year # starts on April 1 XXXX fy_start = datetime(start_year, 4, 1) # ends on April 1 XXXX + 1 @@ -76,7 +73,7 @@ def get_fiscal_year(jurisdiction, dt): ) -def upgrade(): +def upgrade() -> None: conn = op.get_bind() op.add_column('invoice', sa.Column('fy_end_at', sa.DateTime(), nullable=True)) op.add_column('invoice', sa.Column('fy_start_at', sa.DateTime(), nullable=True)) @@ -106,7 +103,7 @@ def upgrade(): op.alter_column('invoice', 'fy_end_at', existing_type=sa.DateTime(), nullable=False) -def downgrade(): +def downgrade() -> None: conn = op.get_bind() op.alter_column('invoice', 'fy_end_at', existing_type=sa.DateTime(), nullable=True) op.alter_column( diff --git a/migrations/versions/adb90a264e3_updated_purchase_order_line_items.py b/migrations/versions/adb90a264e3_updated_purchase_order_line_items.py index 06e74139..5159006e 100644 --- a/migrations/versions/adb90a264e3_updated_purchase_order_line_items.py +++ b/migrations/versions/adb90a264e3_updated_purchase_order_line_items.py @@ -1,4 +1,5 @@ -"""updated purchase order line items. +""" +updated purchase order line items. Revision ID: adb90a264e3 Revises: 10ac78260434 @@ -10,10 +11,10 @@ revision = 'adb90a264e3' down_revision = '10ac78260434' +import sqlalchemy as sa from alembic import op from sqlalchemy.dialects import postgresql from sqlalchemy.sql import column, table -import sqlalchemy as sa order = table( 'customer_order', @@ -29,7 +30,7 @@ ) -def upgrade(): +def upgrade() -> None: purchase_order_query = sa.select(order.c.id).where(order.c.status == 0) op.execute( line_item.update() @@ -38,5 +39,5 @@ def upgrade(): ) -def downgrade(): +def downgrade() -> None: op.execute(line_item.update().where(line_item.c.status == 2).values({'status': 0})) diff --git a/migrations/versions/ca40e4eda72c_add_transferable_until_to_item.py b/migrations/versions/ca40e4eda72c_add_transferable_until_to_item.py index df991e36..3e93f326 100644 --- a/migrations/versions/ca40e4eda72c_add_transferable_until_to_item.py +++ b/migrations/versions/ca40e4eda72c_add_transferable_until_to_item.py @@ -1,4 +1,5 @@ -"""add transferable until to item. +""" +add transferable until to item. Revision ID: ca40e4eda72c Revises: cdb214cf1e06 @@ -6,20 +7,20 @@ """ -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision = 'ca40e4eda72c' down_revision = 'cdb214cf1e06' -def upgrade(): +def upgrade() -> None: op.add_column( 'item', sa.Column('transferable_until', sa.TIMESTAMP(timezone=True), nullable=True), ) -def downgrade(): +def downgrade() -> None: op.drop_column('item', 'transferable_until') diff --git a/migrations/versions/cdb214cf1e06_switch_to_timestamptz.py b/migrations/versions/cdb214cf1e06_switch_to_timestamptz.py index ea7d5ce5..3e67fd33 100644 --- a/migrations/versions/cdb214cf1e06_switch_to_timestamptz.py +++ b/migrations/versions/cdb214cf1e06_switch_to_timestamptz.py @@ -1,4 +1,5 @@ -"""Switch to timestamptz. +""" +Switch to timestamptz. Revision ID: cdb214cf1e06 Revises: 7d180b95fcbe @@ -6,8 +7,8 @@ """ -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision = 'cdb214cf1e06' @@ -62,7 +63,7 @@ ] -def upgrade(): +def upgrade() -> None: for table, column in migrate_table_columns: op.execute( sa.DDL( @@ -73,7 +74,7 @@ def upgrade(): ) -def downgrade(): +def downgrade() -> None: for table, column in reversed(migrate_table_columns): op.execute( sa.DDL( diff --git a/migrations/versions/dadc5748932_rm_quantity_available.py b/migrations/versions/dadc5748932_rm_quantity_available.py index e37aa354..5ad8ece3 100644 --- a/migrations/versions/dadc5748932_rm_quantity_available.py +++ b/migrations/versions/dadc5748932_rm_quantity_available.py @@ -1,4 +1,5 @@ -"""rm quantity_available. +""" +rm quantity_available. Revision ID: dadc5748932 Revises: 253e7b76eb8e @@ -6,10 +7,10 @@ """ +import sqlalchemy as sa from alembic import op from sqlalchemy.dialects import postgresql from sqlalchemy.sql import column, table -import sqlalchemy as sa # revision identifiers, used by Alembic. revision = 'dadc5748932' @@ -30,12 +31,12 @@ ) -def upgrade(): +def upgrade() -> None: op.drop_constraint('item_quantity_available_lte_quantity_total_check', 'item') op.drop_column('item', 'quantity_available') -def downgrade(): +def downgrade() -> None: op.add_column( 'item', sa.Column( @@ -46,9 +47,11 @@ def downgrade(): item.update().values( { 'quantity_available': item.c.quantity_total - - line_item.count() + - sa.select(sa.func.count()) + .select_from(line_item) .where(line_item.c.item_id == item.c.id) .where(line_item.c.status == 0) + .scalar_subquery() } ) ) diff --git a/migrations/versions/f78ca4cad5d6_add_place_of_supply.py b/migrations/versions/f78ca4cad5d6_add_place_of_supply.py index 59b3bba8..594ee288 100644 --- a/migrations/versions/f78ca4cad5d6_add_place_of_supply.py +++ b/migrations/versions/f78ca4cad5d6_add_place_of_supply.py @@ -1,4 +1,5 @@ -"""add place of supply to item collection and item. +""" +add place of supply to item collection and item. Revision ID: f78ca4cad5d6 Revises: 6c04555d7d94 @@ -6,15 +7,15 @@ """ -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision = 'f78ca4cad5d6' down_revision = '6c04555d7d94' -def upgrade(): +def upgrade() -> None: op.add_column( 'item', sa.Column('place_supply_country_code', sa.Unicode(length=2), nullable=True), @@ -34,7 +35,7 @@ def upgrade(): op.add_column('item', sa.Column('event_date', sa.Date(), nullable=True)) -def downgrade(): +def downgrade() -> None: op.drop_column('item_collection', 'place_supply_state_code') op.drop_column('item_collection', 'place_supply_country_code') op.drop_column('item', 'place_supply_state_code') diff --git a/pyproject.toml b/pyproject.toml index f60547a3..5374307a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,13 +25,16 @@ exclude = ''' ''' [tool.isort] +# Some isort functionality is replicated in ruff, which should have matching config +profile = 'black' multi_line_output = 3 include_trailing_comma = true line_length = 88 order_by_type = true use_parentheses = true -from_first = true -known_future_library = ['__future__', 'six'] +from_first = false +combine_as_imports = true +split_on_trailing_comma = false known_repo = ['boxoffice'] known_first_party = ['baseframe', 'coaster', 'flask_lastuser'] default_section = 'THIRDPARTY' @@ -164,59 +167,7 @@ skips = ['**/*_test.py', '**/test_*.py', '**/conftest.py'] # This is a slight customisation of the default rules # 1. Rule E402 (module-level import not top-level) is disabled as isort handles it # 2. Rule E501 (line too long) is left to Black; some strings are worse for wrapping - -# Enable pycodestyle (`E`) and Pyflakes (`F`) codes by default. -lint.select = ["E", "F"] -lint.ignore = ["E402"] - -# Allow autofix for all enabled rules (when `--fix`) is provided. -lint.fixable = [ - "A", - "B", - "C", - "D", - "E", - "F", - "G", - "I", - "N", - "Q", - "S", - "T", - "W", - "ANN", - "ARG", - "BLE", - "COM", - "DJ", - "DTZ", - "EM", - "ERA", - "EXE", - "FBT", - "ICN", - "INP", - "ISC", - "NPY", - "PD", - "PGH", - "PIE", - "PL", - "PT", - "PTH", - "PYI", - "RET", - "RSE", - "RUF", - "SIM", - "SLF", - "TCH", - "TID", - "TRY", - "UP", - "YTT", -] -lint.unfixable = [] +# 3. Rule S101 is disruptive in tests, so it's left to Bandit # Exclude a variety of commonly ignored directories. exclude = [ @@ -245,15 +196,108 @@ exclude = [ # Same as Black. line-length = 88 +# Target Python 3.11 +target-version = "py311" + +[tool.ruff.format] +docstring-code-format = true +quote-style = "preserve" + +[tool.ruff.lint] # Allow unused variables when underscore-prefixed. -lint.dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" +dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" -# Target Python 3.7 -target-version = "py311" +select = [ + "A", # flake8-builtins + "ANN", # flake8-annotations + "ARG", # flake8-unused-arguments + "ASYNC", # flake8-async + "B", # flake8-bugbear + "BLE", # flake8-blind-except + "C", # pylint convention + "D", # pydocstyle + "C4", # flake8-comprehensions + "C90", # mccabe + "E", # Error + "EM", # flake8-errmsg + "EXE", # flake8-executable + "F", # pyflakes + "FA", # flake8-future-annotations + "G", # flake8-logging-format + "I", # isort + "INP", # flake8-no-pep420 + "INT", # flake8-gettext + "ISC", # flake8-implicit-str-concat + "N", # pep8-naming + "PIE", # flake8-pie + "PT", # flake8-pytest-style + "PYI", # flake8-pyi + "RET", # flake8-return + "RUF", # Ruff + "S", # flake8-bandit + "SIM", # flake8-simplify + "SLOT", # flake8-slots + "T20", # flake8-print + "TRIO", # flake8-trio + "UP", # pyupgrade + "W", # Warnings + "YTT", # flake8-2020 +] +ignore = [ + "ANN002", + "ANN003", + "ANN101", + "ANN102", + "ANN401", + "C901", + "D100", + "D101", + "D102", + "D103", + "D104", + "D105", + "D107", + "D203", + "D212", + "E402", + "E501", + "ISC001", +] + +# Allow autofix for all enabled rules (when `--fix`) is provided. +fixable = ["ALL"] +unfixable = [] + +# Allow these characters in strings +allowed-confusables = ["‘", "’"] -[tool.ruff.lint.mccabe] # Unlike Flake8, default to a complexity level of 10. -max-complexity = 10 +mccabe.max-complexity = 10 + +[tool.ruff.lint.extend-per-file-ignores] +"__init__.py" = ["E402"] # Allow non-top-level imports +"tests/**.py" = ["S101", "ANN001"] # Allow assert; don't require fixture typing +"migrations/**.py" = ["INP001"] # Not a package +"scripts/**.py" = ["INP001", "ANN"] # Not a package; typing optional +"instance/**.py" = ["INP001"] # Not a package + +[tool.ruff.lint.isort] +# These config options should match isort config above under [tool.isort] +combine-as-imports = true +split-on-trailing-comma = false +relative-imports-order = 'furthest-to-closest' +known-first-party = ['baseframe', 'coaster', 'flask_lastuser'] +section-order = [ + 'future', + 'standard-library', + 'third-party', + 'first-party', + 'repo', + 'local-folder', +] + +[tool.ruff.lint.isort.sections] +repo = ['boxoffice'] [tool.ruff.lint.flake8-pytest-style] fixture-parentheses = false diff --git a/runtestserver.py b/runtestserver.py old mode 100644 new mode 100755 index 6dd15ee5..a782919e --- a/runtestserver.py +++ b/runtestserver.py @@ -35,4 +35,4 @@ def test_page() -> str: """ -app.run('0.0.0.0', 6500, debug=True) # nosec +app.run('0.0.0.0', 6500, debug=True) # nosec # noqa: S104 diff --git a/scripts/razorpay_refund.py b/scripts/razorpay_refund.py index 181e78d5..a4475b13 100644 --- a/scripts/razorpay_refund.py +++ b/scripts/razorpay_refund.py @@ -71,7 +71,7 @@ def get_refunds(date_ranges): return refunds -def write_refunds(filename, rows): +def write_refunds(filename, rows) -> None: with open(filename, 'w', encoding='utf-8') as csvfile: fieldnames = [ 'transaction_id', diff --git a/scripts/razorpay_settlement.py b/scripts/razorpay_settlement.py index 5476d4b6..ba073aa2 100644 --- a/scripts/razorpay_settlement.py +++ b/scripts/razorpay_settlement.py @@ -1,5 +1,5 @@ -from decimal import Decimal import csv +from decimal import Decimal import requests @@ -188,7 +188,7 @@ def get_settled_orders(date_ranges=(), filenames=()): # ) # ) - except Exception as error_msg: # noqa: B902 # pylint: disable=W0718 + except Exception as error_msg: # noqa: BLE001 # pylint: disable=W0718 print(error_msg) # noqa: T201 settlement_refund_ids = [ @@ -237,7 +237,7 @@ def get_settled_orders(date_ranges=(), filenames=()): ) ) ) - except: # noqa: B001, E722 # pylint: disable=bare-except + except: # noqa: E722 # pylint: disable=bare-except # FIXME: Add correct exception cancelled_line_item = LineItem.query.filter( LineItem.order == order, @@ -263,7 +263,7 @@ def get_settled_orders(date_ranges=(), filenames=()): return settled_orders -def write_settled_orders(filename, rows): +def write_settled_orders(filename, rows) -> None: with open(filename, 'w', encoding='utf-8') as csvfile: fieldnames = [ 'settlement_id', @@ -389,7 +389,7 @@ def get_settled_order_transactions(date_ranges=(), filenames=()): # ) # ) - except Exception as error_msg: # noqa: B902 # pylint: disable=W0718 + except Exception as error_msg: # noqa: BLE001 # pylint: disable=W0718 print(error_msg) # noqa: T201 settlement_refund_ids = [ @@ -440,7 +440,7 @@ def get_settled_order_transactions(date_ranges=(), filenames=()): ) ) refunded_line_item_ids.append(cancelled_line_item.id) - except: # noqa: B001, E722 # pylint: disable=bare-except + except: # noqa: E722 # pylint: disable=bare-except # FIXME: Add correct exception cancelled_line_item = LineItem.query.filter( ~LineItem.id.in_(refunded_line_item_ids), diff --git a/scripts/razorpay_settlement_old.py b/scripts/razorpay_settlement_old.py index 7b7b5f5e..aa046d16 100644 --- a/scripts/razorpay_settlement_old.py +++ b/scripts/razorpay_settlement_old.py @@ -1,5 +1,5 @@ -from decimal import Decimal import csv +from decimal import Decimal import requests @@ -160,7 +160,7 @@ def get_settled_orders(date_ranges=(), filenames=()): # ) # ) - except Exception as error_msg: # noqa: B902 # pylint: disable=W0718 + except Exception as error_msg: # noqa: BLE001 # pylint: disable=W0718 print(error_msg) # noqa: T201 settlement_refund_ids = [ @@ -209,7 +209,7 @@ def get_settled_orders(date_ranges=(), filenames=()): ) ) ) - except: # noqa: B001, E722 # pylint: disable=bare-except + except: # noqa: E722 # pylint: disable=bare-except # FIXME: Add the correct exception cancelled_line_item = LineItem.query.filter( LineItem.order == order, @@ -235,7 +235,7 @@ def get_settled_orders(date_ranges=(), filenames=()): return settled_orders -def write_settled_orders(filename, rows): +def write_settled_orders(filename, rows) -> None: with open(filename, 'w', encoding='utf-8') as csvfile: fieldnames = [ 'settlement_id', diff --git a/scripts/settled_order.py b/scripts/settled_order.py index 17c1a7f6..5ada8f6a 100644 --- a/scripts/settled_order.py +++ b/scripts/settled_order.py @@ -1,5 +1,5 @@ -from decimal import Decimal import csv +from decimal import Decimal from boxoffice.models import ( LineItem, @@ -178,7 +178,7 @@ def format_line_item(settlement_id, payment_id, line_item, payment_status): ) ) ) - except: # noqa: B001, E722 # pylint: disable=bare-except + except: # noqa: E722 # pylint: disable=bare-except # FIXME: Add correct exception print("Multiple line items found") # noqa: T201 print(payment.pg_paymentid) # noqa: T201 @@ -206,7 +206,7 @@ def format_line_item(settlement_id, payment_id, line_item, payment_status): return settled_orders -def write_settled_orders(filename, rows): +def write_settled_orders(filename, rows) -> None: with open(filename, 'w', encoding='utf-8') as csvfile: fieldnames = [ 'settlement_id', diff --git a/scripts/settlement.py b/scripts/settlement.py index 0cafd957..3b3c76fb 100644 --- a/scripts/settlement.py +++ b/scripts/settlement.py @@ -1,7 +1,8 @@ # settlements = {id: [{line_item_id, line_item_title, base_amount, final_amount}]} -from decimal import Decimal import csv +from decimal import Decimal +from typing import Any from pytz import timezone, utc @@ -22,7 +23,7 @@ def csv_to_rows(csv_file, skip_header=True, delimiter=','): return list(reader) -def rows_to_csv(rows, filename): +def rows_to_csv(rows, filename) -> bool: with open(filename, 'wb', encoding='utf-8') as csvfile: writer = csv.writer( csvfile, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL @@ -233,7 +234,8 @@ def get_line_items(filename): return line_items -def get_orders(settlement_filename): +def get_orders(settlement_filename: str) -> None: + payment_transaction: PaymentTransaction | None settlement_dicts = [] # Load input data with open(settlement_filename, encoding='utf-8') as csvfile: @@ -242,7 +244,7 @@ def get_orders(settlement_filename): settlement_dicts.append(row) # parse through input data - payment_orders = [] + payment_orders: list[dict[str, Any]] = [] for settlement_dict in settlement_dicts: if settlement_dict['type'] == 'payment' and settlement_dict[ 'entity_id' @@ -254,6 +256,7 @@ def get_orders(settlement_filename): payment_transaction = PaymentTransaction.query.filter_by( online_payment=payment, transaction_type=TransactionTypeEnum.PAYMENT ).one() + assert payment_transaction is not None # noqa: S101 # nosec B101 payment_orders.append( { 'entity_id': settlement_dict['entity_id'], @@ -310,7 +313,7 @@ def get_orders(settlement_filename): print( # noqa: T201 "no transaction for " + settlement_dict['entity_id'] ) - except: # noqa: B001, E722 # pylint: disable=bare-except + except: # noqa: E722 # pylint: disable=bare-except # FIXME: Trap the correct exception print( # noqa: T201 "no payment found for " diff --git a/scripts/settlement1.py b/scripts/settlement1.py index 76f6da3a..ed4259e9 100644 --- a/scripts/settlement1.py +++ b/scripts/settlement1.py @@ -1,7 +1,7 @@ # settlements = {id: [{line_item_id, line_item_title, base_amount, final_amount}]} -from decimal import Decimal import csv +from decimal import Decimal from pytz import timezone, utc @@ -19,7 +19,7 @@ def csv_to_rows(csv_file, skip_header=True, delimiter=','): return list(reader) -def rows_to_csv(rows, filename): +def rows_to_csv(rows, filename) -> bool: with open(filename, 'wb', encoding='utf-8') as csvfile: writer = csv.writer( csvfile, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL @@ -54,9 +54,9 @@ def get_settlements(filename): transaction_type=TransactionTypeEnum.PAYMENT, ).one_or_none() # Get settlement - settlement_amount = [ + settlement_amount = next( tr for tr in transactions if tr[0] == trans[11] - ][0][4] + )[4] settlements[trans[11]].append( { 'settlement_amount': settlement_amount, @@ -83,9 +83,9 @@ def get_settlements(filename): online_payment=payment, transaction_type=TransactionTypeEnum.REFUND, ).all() - settlement_amount = [ + settlement_amount = next( tr for tr in transactions if tr[0] == trans[11] - ][0][4] + )[4] for rt in refund_transactions: settlements[trans[11]].append( { diff --git a/scripts/sync_razorpay_refunds.py b/scripts/sync_razorpay_refunds.py index c9c01426..8c342655 100644 --- a/scripts/sync_razorpay_refunds.py +++ b/scripts/sync_razorpay_refunds.py @@ -1,5 +1,5 @@ -from decimal import Decimal import datetime +from decimal import Decimal import requests @@ -39,7 +39,7 @@ def epoch_dt(dt): return (dt - epoch).total_seconds() -def sync_refunds(): +def sync_refunds() -> None: def calc_correspongding_rp_refund(possible_rp_refunds, refund_epoch_dt): return min( possible_rp_refunds, @@ -69,13 +69,14 @@ def calc_correspongding_rp_refund(possible_rp_refunds, refund_epoch_dt): or amount_in_paise(refund.amount) != correspongding_rp_refund['amount'] ): - raise RuntimeError(f"Oops! No refund found for {refund.id}") + msg = f"Oops! No refund found for {refund.id}" + raise RuntimeError(msg) refund.pg_refundid = correspongding_rp_refund['id'] used_pg_refundids.append(correspongding_rp_refund['id']) db.session.commit() -def remove_duplicate_payments(): +def remove_duplicate_payments() -> None: orders = Order.query.all() for order in orders: payments = OnlinePayment.query.filter(OnlinePayment.order == order).all() @@ -109,7 +110,7 @@ def get_duplicate_payments(): return dupes -def import_missing_refunds(): +def import_missing_refunds() -> None: payments = OnlinePayment.query.all() for payment in payments: rp_refunds = get_refunds(payment.pg_paymentid) diff --git a/setup.cfg b/setup.cfg index 4b638655..075df2ae 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [flake8] -ignore = I100, I201, E124, E128, E203, E402, E704, W503, D100, D101, D102, D103, D104, D105, D107, D202, S101 +ignore = I100, I201, E124, E128, E203, E402, E704, W503, D100, D101, D102, D103, D104, D105, D107, D202, S101, ANN101, ANN102, ANN401 max-line-length = 88 exclude = node_modules enable-extensions = G diff --git a/tests/conftest.py b/tests/conftest.py index 98b4c936..66474755 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,12 +1,17 @@ # pylint: disable=redefined-outer-name +from collections.abc import Generator from datetime import date from types import SimpleNamespace +from typing import Self +from wsgiref.types import WSGIEnvironment +import pytest +import sqlalchemy as sa from dateutil.relativedelta import relativedelta +from flask.testing import FlaskClient +from flask_sqlalchemy import SQLAlchemy from sqlalchemy.orm import close_all_sessions from werkzeug.test import EnvironBuilder -import pytest -import sqlalchemy as sa from coaster.utils import utcnow @@ -26,13 +31,13 @@ @pytest.fixture(scope='session') -def database(request): +def database(request: pytest.FixtureRequest) -> SQLAlchemy: """Provide a database structure.""" with app.app_context(): db.create_all() @request.addfinalizer - def drop_tables(): + def drop_tables() -> None: with app.app_context(): db.drop_all() @@ -40,7 +45,7 @@ def drop_tables(): @pytest.fixture(scope='session') -def db_connection(database): +def db_connection(database: SQLAlchemy) -> Generator[sa.Connection, None, None]: """Return a database connection.""" with app.app_context(): yield database.engine.connect() @@ -49,21 +54,25 @@ def db_connection(database): class RemoveIsRollback: """Change session.remove() to session.rollback().""" - def __init__(self, session, rollback_provider): + def __init__(self, session, rollback_provider) -> None: self.session = session self.original_remove = session.remove self.rollback_provider = rollback_provider - def __enter__(self): + def __enter__(self) -> Self: """Replace ``session.remove`` during the `with` context.""" + return self - def __exit__(self, exc_type, exc_value, traceback): + def __exit__(self, exc_type, exc_value, traceback) -> None: """Restore ``session.remove`` after the `with` context.""" self.session.remove = self.original_remove @pytest.fixture -def db_session(database, db_connection): +def db_session( + database: SQLAlchemy, + db_connection: sa.Connection, # noqa: ARG001 +) -> Generator[sa.orm.scoped_session, None, None]: """Empty the database after each use of the fixture.""" with RemoveIsRollback(database.session, lambda: database.session.rollback): yield database.session @@ -71,7 +80,7 @@ def db_session(database, db_connection): # Iterate through all database engines and empty their tables with app.app_context(): - for bind in [None] + list(app.config.get('SQLALCHEMY_BINDS') or ()): + for bind in [None, *list(app.config.get('SQLALCHEMY_BINDS') or ())]: engine = database.engines[bind] with engine.begin() as connection: connection.execute( @@ -92,25 +101,25 @@ def db_session(database, db_connection): # Enable autouse to guard against tests that have implicit database access, or assume # app context without a fixture @pytest.fixture(autouse=True) -def client(request): +def client(request: pytest.FixtureRequest) -> Generator[FlaskClient, None, None]: """Provide a test client.""" if 'noclient' in request.keywords: # To use this, annotate a test with: # @pytest.mark.noclient - yield None + yield None # type: ignore[misc] with app.test_client() as test_client: yield test_client @pytest.fixture -def post_env(): +def post_env() -> WSGIEnvironment: builder = EnvironBuilder(method='POST') return builder.get_environ() @pytest.fixture -def all_data(db_session): - user = User(userid="U3_JesHfQ2OUmdihAXaAGQ", email="test@hasgeek.com") +def all_data(db_session) -> SimpleNamespace: + user = User(userid='U3_JesHfQ2OUmdihAXaAGQ', email='test@example.com') db_session.add(user) db_session.commit() @@ -118,9 +127,9 @@ def all_data(db_session): rootconf = Organization( title='Rootconf', - userid="U3_JesHfQ2OUmdihAXaAGQ", + userid='U3_JesHfQ2OUmdihAXaAGQ', status=0, - contact_email='test@gmail.com', + contact_email='test@example.net', details={ 'service_tax_no': 'xx', 'address': '

XYZ

Bangalore - 560034

' @@ -129,9 +138,9 @@ def all_data(db_session): 'pan': 'abc', 'website': 'https://www.test.com', 'refund_policy': '

We offer full refund.

', - 'support_email': 'test@boxoffice.com', + 'support_email': 'test@example.org', 'ticket_faq': '

To cancel your ticket, please mail test@boxoffice.com with your receipt' + 'href="mailto:test@example.org">test@example.org with your receipt' ' number.

', }, ) diff --git a/tests/integration/test_items.py b/tests/integration/test_items.py index afa359bb..24a0e6f2 100644 --- a/tests/integration/test_items.py +++ b/tests/integration/test_items.py @@ -1,5 +1,9 @@ -from datetime import timedelta import json +from datetime import timedelta +from typing import Any + +import pytest +from werkzeug.test import TestResponse from coaster.utils import utcnow @@ -7,7 +11,7 @@ from boxoffice.models import Menu, Order, OrderStatus, Ticket -def ajax_post(client, url, data): +def ajax_post(client, url: str, data: Any) -> TestResponse: return client.post( url, data=json.dumps(data), @@ -19,8 +23,9 @@ def ajax_post(client, url, data): ) -def test_assign(db_session, client, all_data) -> None: - ticket = Ticket.query.filter_by(name="conference-ticket").one() +@pytest.mark.usefixtures('all_data') +def test_assign(db_session, client) -> None: + ticket = Ticket.query.filter_by(name='conference-ticket').one() data = { 'line_items': [{'ticket_id': str(ticket.id), 'quantity': 2}], 'buyer': { @@ -42,7 +47,7 @@ def test_assign(db_session, client, all_data) -> None: li_one = order.line_items[0] li_two = order.line_items[1] - # li_one has no assingee set yet, so it should be possible to set one + # li_one has no assignee set yet, so it should be possible to set one assert li_one.assignee is None assert li_one.current_assignee is None data = { @@ -154,7 +159,7 @@ def test_assign(db_session, client, all_data) -> None: assert json.loads(resp.data)['status'] == 'ok' # but ticket transfer has a hard deadline in the - # absense of 'transferable_until' - `event_date`. + # absence of 'transferable_until' - `event_date`. # so, if `transferable_until` is not set and `event_date` is in the past, # ticket transfer should fail. assert li_two.assignee is not None diff --git a/tests/test_kharcha.py b/tests/test_kharcha.py index 5351c0ed..11abb31e 100644 --- a/tests/test_kharcha.py +++ b/tests/test_kharcha.py @@ -1,7 +1,8 @@ -from typing import cast import decimal import json +from typing import cast +import pytest from flask import url_for from coaster.utils import make_name @@ -10,7 +11,8 @@ from boxoffice.models import DiscountCoupon, DiscountPolicy, Ticket -def test_undiscounted_kharcha(client, all_data) -> None: +@pytest.mark.usefixtures('all_data') +def test_undiscounted_kharcha(client) -> None: first_item = Ticket.query.filter_by(name='conference-ticket').one() undiscounted_quantity = 2 kharcha_req = { @@ -50,7 +52,8 @@ def test_undiscounted_kharcha(client, all_data) -> None: ) -def test_expired_ticket_kharcha(client, all_data) -> None: +@pytest.mark.usefixtures('all_data') +def test_expired_ticket_kharcha(client) -> None: expired_ticket = Ticket.query.filter_by(name='expired-ticket').one() quantity = 2 kharcha_req = { @@ -75,7 +78,8 @@ def test_expired_ticket_kharcha(client, all_data) -> None: ) -def test_expired_discounted_item_kharcha(client, all_data) -> None: +@pytest.mark.usefixtures('all_data') +def test_expired_discounted_item_kharcha(client) -> None: expired_ticket = Ticket.query.filter_by(name='expired-ticket').one() quantity = 2 coupon = DiscountCoupon.query.filter_by(code='couponex').one() @@ -101,7 +105,8 @@ def test_expired_discounted_item_kharcha(client, all_data) -> None: ) -def test_discounted_bulk_kharcha(client, all_data) -> None: +@pytest.mark.usefixtures('all_data') +def test_discounted_bulk_kharcha(client) -> None: first_item = Ticket.query.filter_by(name='conference-ticket').one() discounted_quantity = 10 kharcha_req = { @@ -150,7 +155,8 @@ def test_discounted_bulk_kharcha(client, all_data) -> None: assert expected_policy_id in policy_ids -def test_discounted_coupon_kharcha(client, all_data) -> None: +@pytest.mark.usefixtures('all_data') +def test_discounted_coupon_kharcha(client) -> None: first_item = Ticket.query.filter_by(name='conference-ticket').one() coupon: DiscountCoupon = DiscountCoupon.query.filter_by(code='coupon1').one() discounted_quantity = 1 @@ -199,7 +205,8 @@ def test_discounted_coupon_kharcha(client, all_data) -> None: assert expected_policy_id in policy_ids -def test_signed_discounted_coupon_kharcha(client, all_data) -> None: +@pytest.mark.usefixtures('all_data') +def test_signed_discounted_coupon_kharcha(client) -> None: first_item = Ticket.query.filter_by(name='conference-ticket').one() signed_policy = DiscountPolicy.query.filter_by(name='signed').one() code = signed_policy.gen_signed_code() @@ -247,7 +254,8 @@ def test_signed_discounted_coupon_kharcha(client, all_data) -> None: assert expected_policy_id in policy_ids -def test_unlimited_coupon_kharcha(client, all_data) -> None: +@pytest.mark.usefixtures('all_data') +def test_unlimited_coupon_kharcha(client) -> None: first_item = Ticket.query.filter_by(name='conference-ticket').one() coupon_code = 'unlimited' discounted_quantity = 5 @@ -297,7 +305,8 @@ def test_unlimited_coupon_kharcha(client, all_data) -> None: assert str(expected_policy_id) in policy_ids -def test_coupon_limit(client, all_data) -> None: +@pytest.mark.usefixtures('all_data') +def test_coupon_limit(client) -> None: first_item = Ticket.query.filter_by(name='conference-ticket').one() coupon = DiscountCoupon.query.filter_by(code='coupon1').one() discounted_quantity = 2 @@ -344,7 +353,8 @@ def test_coupon_limit(client, all_data) -> None: assert expected_policy_id in policy_ids -def test_discounted_price_kharcha(client, all_data) -> None: +@pytest.mark.usefixtures('all_data') +def test_discounted_price_kharcha(client) -> None: first_item = Ticket.query.filter_by(name='conference-ticket').one() coupon = DiscountCoupon.query.filter_by(code='forever').one() discounted_quantity = 1 @@ -366,9 +376,9 @@ def test_discounted_price_kharcha(client, all_data) -> None: assert resp.status_code == 200 resp_json = json.loads(resp.get_data()) - discounted_price = [ + discounted_price = next( price for price in first_item.prices if price.name == 'forever-early-geek' - ][0] + ) assert resp_json.get('line_items')[str(first_item.id)].get('final_amount') == int( discounted_price.amount ) @@ -385,7 +395,8 @@ def test_discounted_price_kharcha(client, all_data) -> None: assert expected_policy_id in policy_ids -def test_discount_policy_without_price_kharcha(client, all_data) -> None: +@pytest.mark.usefixtures('all_data') +def test_discount_policy_without_price_kharcha(client) -> None: first_item = Ticket.query.filter_by(name='conference-ticket').one() coupon = DiscountCoupon.query.filter_by(code='noprice').one() discounted_quantity = 1 @@ -414,7 +425,8 @@ def test_discount_policy_without_price_kharcha(client, all_data) -> None: ) == decimal.Decimal(0) -def test_zero_discounted_price_kharcha(client, all_data) -> None: +@pytest.mark.usefixtures('all_data') +def test_zero_discounted_price_kharcha(client) -> None: first_item = Ticket.query.filter_by(name='conference-ticket').one() coupon = DiscountCoupon.query.filter_by(code='zerodi').one() discounted_quantity = 1 @@ -435,9 +447,9 @@ def test_zero_discounted_price_kharcha(client, all_data) -> None: ) assert resp.status_code == 200 resp_json = json.loads(resp.get_data()) - discounted_price = [ + discounted_price = next( price for price in first_item.prices if price.name == 'zero-discount' - ][0] + ) assert resp_json.get('line_items')[str(first_item.id)].get('final_amount') != int( discounted_price.amount ) @@ -454,7 +466,8 @@ def test_zero_discounted_price_kharcha(client, all_data) -> None: assert expected_policy_id not in policy_ids -def test_discounted_complex_kharcha(client, all_data) -> None: +@pytest.mark.usefixtures('all_data') +def test_discounted_complex_kharcha(client) -> None: first_item = Ticket.query.filter_by(name='conference-ticket').one() discounted_quantity = 9 coupon2 = DiscountCoupon.query.filter_by(code='coupon2').one() diff --git a/tests/test_menu.py b/tests/test_menu.py index 6d6eb650..a454dffd 100644 --- a/tests/test_menu.py +++ b/tests/test_menu.py @@ -2,6 +2,7 @@ import json import pytest +from werkzeug.test import TestResponse from boxoffice import app from boxoffice.models import Menu @@ -56,7 +57,7 @@ @pytest.fixture -def resp(client, all_data): +def resp(client, all_data) -> TestResponse: # noqa: ARG001 menu = Menu.query.one() return client.get( f'/menu/{menu.id}', diff --git a/tests/test_order.py b/tests/test_order.py index ced92606..5d60159c 100644 --- a/tests/test_order.py +++ b/tests/test_order.py @@ -1,10 +1,12 @@ -from unittest.mock import patch -from uuid import UUID import datetime import decimal import json +from typing import Any +from unittest.mock import patch +from uuid import UUID import pytest +from werkzeug.test import TestResponse from coaster.utils import buid @@ -32,15 +34,16 @@ class MockResponse: - def __init__(self, response_data, status_code=200): + def __init__(self, response_data: Any, status_code=200) -> None: self.response_data = response_data self.status_code = status_code - def json(self): + def json(self) -> Any: return self.response_data -def test_basic(client, all_data) -> None: +@pytest.mark.usefixtures('all_data') +def test_basic(client) -> None: ticket = Ticket.query.filter_by(name='conference-ticket').one() data = { 'line_items': [{'ticket_id': str(ticket.id), 'quantity': 2}], @@ -69,7 +72,8 @@ def test_basic(client, all_data) -> None: assert resp_data['final_amount'] == 7000 -def test_basic_with_utm_headers(client, all_data) -> None: +@pytest.mark.usefixtures('all_data') +def test_basic_with_utm_headers(client) -> None: ticket = Ticket.query.filter_by(name='conference-ticket').one() utm_campaign = 'campaign' utm_medium = 'medium' @@ -119,7 +123,8 @@ def test_basic_with_utm_headers(client, all_data) -> None: assert order_session.gclid == gclid -def test_order_with_invalid_quantity(client, all_data) -> None: +@pytest.mark.usefixtures('all_data') +def test_order_with_invalid_quantity(client) -> None: ticket = Ticket.query.filter_by(name='conference-ticket').one() data = { 'line_items': [{'ticket_id': str(ticket.id), 'quantity': 1001}], @@ -142,7 +147,8 @@ def test_order_with_invalid_quantity(client, all_data) -> None: assert resp.status_code == 400 -def test_simple_discounted_item(client, all_data) -> None: +@pytest.mark.usefixtures('all_data') +def test_simple_discounted_item(client) -> None: discounted_item = Ticket.query.filter_by(name='t-shirt').one() data = { 'line_items': [{'ticket_id': str(discounted_item.id), 'quantity': 5}], @@ -167,7 +173,8 @@ def test_simple_discounted_item(client, all_data) -> None: assert resp_data['final_amount'] == 2375 -def test_expired_ticket_order(client, all_data) -> None: +@pytest.mark.usefixtures('all_data') +def test_expired_ticket_order(client) -> None: expired_ticket = Ticket.query.filter_by(name='expired-ticket').one() quantity = 2 data = { @@ -191,7 +198,8 @@ def test_expired_ticket_order(client, all_data) -> None: assert resp.status_code == 400 -def test_signed_discounted_coupon_order(client, all_data) -> None: +@pytest.mark.usefixtures('all_data') +def test_signed_discounted_coupon_order(client) -> None: first_item = Ticket.query.filter_by(name='conference-ticket').one() signed_policy = DiscountPolicy.query.filter_by(name='signed').one() signed_code = signed_policy.gen_signed_code() @@ -232,7 +240,8 @@ def test_signed_discounted_coupon_order(client, all_data) -> None: assert line_item.discount_coupon.code == signed_code -def test_complex_discounted_item(client, all_data) -> None: +@pytest.mark.usefixtures('all_data') +def test_complex_discounted_item(client) -> None: discounted_item1 = Ticket.query.filter_by(name='t-shirt').one() discounted_item2 = Ticket.query.filter_by(name='conference-ticket').one() data = { @@ -262,7 +271,8 @@ def test_complex_discounted_item(client, all_data) -> None: assert resp_data['final_amount'] == 33875 -def test_discounted_complex_order(client, all_data) -> None: +@pytest.mark.usefixtures('all_data') +def test_discounted_complex_order(client) -> None: conf = Ticket.query.filter_by(name='conference-ticket').one() tshirt = Ticket.query.filter_by(name='t-shirt').one() conf_current_price = conf.current_price() @@ -324,7 +334,7 @@ def test_discounted_complex_order(client, all_data) -> None: ) -def make_free_order(client): +def make_free_order(client) -> TestResponse: ticket = Ticket.query.filter_by(name='conference-ticket').one() data = { 'line_items': [{'ticket_id': str(ticket.id), 'quantity': 1}], @@ -336,7 +346,7 @@ def make_free_order(client): 'discount_coupons': ['coupon2'], } menu = Menu.query.one() - resp = client.post( + return client.post( f'/menu/{menu.id}/order', data=json.dumps(data), content_type='application/json', @@ -345,10 +355,10 @@ def make_free_order(client): ('Origin', app.config['BASE_URL']), ], ) - return resp -def test_free_order(client, all_data) -> None: +@pytest.mark.usefixtures('all_data') +def test_free_order(client) -> None: resp = make_free_order(client) assert resp.status_code == 201 resp_json = json.loads(resp.data)['result'] @@ -374,7 +384,8 @@ def test_free_order(client, all_data) -> None: ) -def test_cancel_line_item_in_order(db_session, client, all_data, post_env) -> None: +@pytest.mark.usefixtures('all_data') +def test_cancel_line_item_in_order(db_session, client, post_env) -> None: original_quantity = 2 order_item = Ticket.query.filter_by(name='t-shirt').one() current_price = order_item.current_price() @@ -458,7 +469,8 @@ def test_cancel_line_item_in_order(db_session, client, all_data, post_env) -> No assert refund_transaction1.amount == expected_refund_amount -def test_cancel_line_item_in_bulk_order(db_session, client, all_data, post_env) -> None: +@pytest.mark.usefixtures('all_data') +def test_cancel_line_item_in_bulk_order(db_session, client, post_env) -> None: original_quantity = 5 discounted_item = Ticket.query.filter_by(name='t-shirt').one() current_price = discounted_item.current_price() @@ -543,7 +555,8 @@ def test_cancel_line_item_in_bulk_order(db_session, client, all_data, post_env) assert refund_transaction2.amount == second_line_item.final_amount # test failed cancellation - third_line_item = order.confirmed_line_items[0] + third_line_item = order.confirmed_line_items.first() + assert third_line_item is not None with patch('boxoffice.extapi.razorpay.refund_payment') as mock: mock.return_value = MockResponse( response_data={ @@ -584,7 +597,8 @@ def test_cancel_line_item_in_bulk_order(db_session, client, all_data, post_env) with app.request_context(post_env): process_partial_refund_for_order(partial_refund_args) - third_line_item = order.confirmed_line_items[0] + third_line_item = order.confirmed_line_items.first() + assert third_line_item is not None pre_cancellation_transactions_count = order.refund_transactions.count() cancelled_refund_amount = process_line_item_cancellation(third_line_item) assert cancelled_refund_amount == decimal.Decimal(0) @@ -601,7 +615,8 @@ def test_cancel_line_item_in_bulk_order(db_session, client, all_data, post_env) assert free_order.transactions.count() == 0 -def test_partial_refund_in_order(db_session, client, all_data, post_env) -> None: +@pytest.mark.usefixtures('all_data') +def test_partial_refund_in_order(db_session, client, post_env) -> None: original_quantity = 5 discounted_item = Ticket.query.filter_by(name='t-shirt').one() current_price = discounted_item.current_price()