diff --git a/utilmeta/core/orm/fields/field.py b/utilmeta/core/orm/fields/field.py index 7f49c12..ba1c24f 100644 --- a/utilmeta/core/orm/fields/field.py +++ b/utilmeta/core/orm/fields/field.py @@ -47,6 +47,13 @@ def __init__(self, model: "ModelAdaptor" = None, **kwargs): def reconstruct(self, model: "ModelAdaptor"): return self.__class__(model, **self._kwargs) + def check_schema_cls(self, schema_cls): + if isinstance(schema_cls, type): + if self.model.qualify(schema_cls): + raise TypeError(f'You are using a model class: {schema_cls} to used as schema query class, ' + f'which is invalid, you should make a schema class using ' + f'orm.Schema[{schema_cls.__name__}]') + def get_query_schema(self): parser = None schema = None @@ -69,6 +76,8 @@ def get_query_schema(self): parser = cls_parser schema = arg break + else: + self.check_schema_cls(arg) else: if ( self.type.__origin__ @@ -83,6 +92,8 @@ def get_query_schema(self): if cls_parser: parser = cls_parser schema = arg + else: + self.check_schema_cls(arg) else: self.related_single = True @@ -93,7 +104,8 @@ def get_query_schema(self): parser = cls_parser schema = origin break - + else: + self.check_schema_cls(origin) if parser: if isinstance(parser, SchemaClassParser): if parser.model: diff --git a/utilmeta/core/response/backends/base.py b/utilmeta/core/response/backends/base.py index 751336c..7ee252e 100644 --- a/utilmeta/core/response/backends/base.py +++ b/utilmeta/core/response/backends/base.py @@ -9,6 +9,8 @@ class ResponseAdaptor(BaseAdaptor): json_decoder_cls = json.JSONDecoder def __init__(self, response): + if not self.qualify(response): + raise TypeError(f'Invalid response: {response}') self.response = response # self.request = request self._context = {} diff --git a/utilmeta/core/server/backends/base.py b/utilmeta/core/server/backends/base.py index c570494..4dafd9a 100644 --- a/utilmeta/core/server/backends/base.py +++ b/utilmeta/core/server/backends/base.py @@ -3,7 +3,7 @@ if TYPE_CHECKING: from utilmeta import UtilMeta from utilmeta.core.api import API -from utilmeta.utils import BaseAdaptor, exceptions, import_obj +from utilmeta.utils import BaseAdaptor, exceptions, import_obj, Error import re import inspect from utilmeta.core.request import Request @@ -24,6 +24,9 @@ def process_request(self, request: Request): def process_response(self, response: Response): pass + # def handle_error(self, error: Error): + # pass + class ServerAdaptor(BaseAdaptor): # __backends_route__ = 'backends' diff --git a/utilmeta/core/server/backends/starlette.py b/utilmeta/core/server/backends/starlette.py index 2635438..510656b 100644 --- a/utilmeta/core/server/backends/starlette.py +++ b/utilmeta/core/server/backends/starlette.py @@ -4,6 +4,7 @@ from starlette.requests import Request as StarletteRequest from starlette.applications import Starlette from starlette.concurrency import iterate_in_threadpool +from starlette.responses import Response as StarletteResponse from starlette.middleware.base import _StreamingResponse from .base import ServerAdaptor from utilmeta.core.response import Response @@ -11,10 +12,9 @@ from utilmeta.core.response.backends.starlette import StarletteResponseAdaptor from utilmeta.core.api import API from utilmeta.core.request import Request -from utilmeta.utils import HAS_BODY_METHODS, RequestType, exceptions +from utilmeta.utils import HAS_BODY_METHODS, RequestType, exceptions, pop, Error import contextvars from typing import Optional -from urllib.parse import urlparse _current_request = contextvars.ContextVar("_starlette.request") # _current_response = contextvars.ContextVar('_starlette.response') @@ -100,6 +100,7 @@ def setup_middlewares(self): self.app.add_middleware( BaseHTTPMiddleware, dispatch=self.get_middleware_func() # noqa ) + # self.app.add_exception_handler(Exception, handler=self.get_exception_handler) @classmethod async def get_response_body(cls, starlette_response: _StreamingResponse) -> bytes: @@ -107,10 +108,63 @@ async def get_response_body(cls, starlette_response: _StreamingResponse) -> byte starlette_response.body_iterator = iterate_in_threadpool(iter(response_body)) return b"".join(response_body) + # async def get_exception_handler(self, req, exc): + # # async def http_error_handler(_: Request, exc) -> JSONResponse: + # # return JSONResponse({"errors": [exc.detail]}, status_code=exc.status_code) + # handlers = dict(self.app.exception_handlers or {}) + # pop(handlers, Exception) + # request = _current_request.get(None) or Request(self.request_adaptor_cls(req)) + # error = Error(exc, request=request) + # for middleware in self.middlewares: + # middleware.handle_error(error) + # err_handler = None + # for cls in type(exc).__mro__: + # if cls in handlers: + # err_handler = handlers[cls] + # break + # if not err_handler: + # raise exc from exc + # from starlette.concurrency import run_in_threadpool + # try: + # if inspect.iscoroutinefunction(err_handler): + # resp = await err_handler(req, exc) + # else: + # resp = await run_in_threadpool(err_handler, req, exc) + # except Exception as handler_exc: + # handler_error = Error(handler_exc, request=request) + # for middleware in self.middlewares: + # middleware.handle_error(handler_error) + # middleware.process_response(Response(error=handler_error)) + # # there is probably not response + # raise + # + # adaptor = self.response_adaptor_cls(resp) + # if ( + # adaptor.content_length or 0 + # ) <= self.RECORD_RESPONSE_BODY_LENGTH_LTE: + # body = await self.get_response_body(resp) + # resp.body = body + # # set body + # response = Response(response=adaptor, request=request) + # response_updated = False + # for middleware in self.middlewares: + # _response = middleware.process_response(response) + # if inspect.isawaitable(_response): + # _response = await _response + # if isinstance(_response, Response): + # response = _response + # response_updated = True + # + # if not resp or response_updated: + # resp = self.response_adaptor_cls.reconstruct(response) + # + # return resp + def get_middleware_func(self): async def utilmeta_middleware(starlette_request: StarletteRequest, call_next): response = None starlette_response = None + exc = None request = Request(self.request_adaptor_cls(starlette_request)) for middleware in self.middlewares: @@ -135,11 +189,38 @@ async def utilmeta_middleware(starlette_request: StarletteRequest, call_next): # and you cannot read it after response is generated _current_request.set(request) - starlette_response: Optional[_StreamingResponse] = await call_next( - starlette_request - ) + try: + starlette_response: Optional[_StreamingResponse] = await call_next( + starlette_request + ) + except Exception as e: + handlers = dict(self.app.exception_handlers or {}) # noqa + err_handler = None + for cls in type(e).__mro__: + if cls in handlers: + err_handler = handlers[cls] + break + error = Error(e, request=request) + if err_handler: + from starlette.concurrency import run_in_threadpool + if inspect.iscoroutinefunction(err_handler): + starlette_response = await err_handler(starlette_request, e) + else: + starlette_response = await run_in_threadpool(err_handler, starlette_request, e) + adaptor = self.response_adaptor_cls(starlette_response) + if ( + adaptor.content_length or 0 + ) <= self.RECORD_RESPONSE_BODY_LENGTH_LTE: + body = await self.get_response_body(starlette_response) + starlette_response.body = body + response = Response(response=adaptor, error=error) + else: + starlette_response = StarletteResponse(status_code=500) # noqa: placeholder response + response = Response(response=response, error=error, request=request) + exc = e + _current_request.set(None) - response = request.adaptor.get_context("response") + response = response or request.adaptor.get_context("response") # response = _current_response.get(None) # _current_response.set(None) @@ -170,6 +251,9 @@ async def utilmeta_middleware(starlette_request: StarletteRequest, call_next): response = _response response_updated = True + if exc: + raise exc from exc + if not starlette_response or response_updated: starlette_response = self.response_adaptor_cls.reconstruct(response) diff --git a/utilmeta/ops/config.py b/utilmeta/ops/config.py index 3cc997c..8eadf86 100644 --- a/utilmeta/ops/config.py +++ b/utilmeta/ops/config.py @@ -205,14 +205,6 @@ def __init__( self.local_scope = list(local_scope or []) self.report_disabled = report_disabled - if base_url: - parsed = urlsplit(base_url) - if not parsed.scheme: - raise ValueError( - f"Operations base_url should be an absolute url, got {base_url}" - ) - self._base_url = self.parse_base_url(base_url) - if self.HOST not in self.trusted_hosts: self.trusted_hosts.append(self.HOST) if not isinstance(monitor, self.Monitor): @@ -239,12 +231,22 @@ def __init__( ) self.proxy = proxy - @classmethod - def parse_base_url(cls, url: str): + if base_url: + parsed = urlsplit(base_url) + if not parsed.scheme: + raise ValueError( + f"Operations base_url should be an absolute url, got {base_url}" + ) + self._base_url = self.parse_base_url(base_url) + + def parse_base_url(self, url: str): if not url: return url if "$IP" in url: - url = url.replace("$IP", get_server_ip()) + ip = get_server_ip(private_only=bool(self.proxy)) + if self.proxy: + ip = ip or get_server_ip() or '127.0.0.1' + url = url.replace("$IP", ip) return url def load_openapi(self, no_store: bool = False): @@ -688,9 +690,13 @@ def proxy_base_url(self): except ImportError: return None origin = self.proxy_origin - if not service.adaptor.backend_views_empty: - return origin - return url_join(origin, service.root_url) + route = service.root_url + if self._base_url: + parsed = urlsplit(self._base_url) + if parsed.scheme: + # is url + route = parsed.path + return url_join(origin, route) # def check_host(self): # parsed = urlsplit(self.ops_api) diff --git a/utilmeta/ops/log.py b/utilmeta/ops/log.py index cec6f6b..714222e 100644 --- a/utilmeta/ops/log.py +++ b/utilmeta/ops/log.py @@ -493,6 +493,12 @@ def process_response(self, response: Response): if len(_responses_queue) >= self.config.max_backlog: threading.Thread(target=batch_save_logs, kwargs=dict(close=True)).start() + # def handle_error(self, error: Error, response=None): + # logger: Logger = _logger.get(None) + # if not logger: + # raise error.throw() + # logger.commit_error(error) + class Logger(Property): __context__ = ContextProperty(_logger) diff --git a/utilmeta/utils/functional/sys.py b/utilmeta/utils/functional/sys.py index c75d429..77b42fd 100644 --- a/utilmeta/utils/functional/sys.py +++ b/utilmeta/utils/functional/sys.py @@ -4,6 +4,7 @@ from datetime import datetime from typing import Optional, List, Union, Tuple, Dict, Set from .. import constant +from .data import distinct_add from ipaddress import ip_address, ip_network posix_os = os.name == "posix" @@ -169,16 +170,16 @@ def get_network_ip(ifname: str): return None -def get_server_ips(max_devices: int = 3) -> Set[str]: +def get_server_ips(max_devices: int = 3) -> List[str]: ip = socket.gethostbyname(socket.gethostname()) s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - ips = set() + ips = [] for i in range(0, max_devices): if_ip = get_network_ip(f"eth{i}") if if_ip: if not if_ip.startswith("127."): - ips.add(if_ip) + distinct_add(ips, if_ip) else: break @@ -187,10 +188,10 @@ def get_server_ips(max_devices: int = 3) -> Set[str]: s.connect(("8.8.8.8", 53)) ip = str(s.getsockname()[0]) if ip: - ips.add(ip) + distinct_add(ips, ip) s.close() else: - ips.add(ip) + distinct_add(ips, ip) return ips @@ -226,7 +227,7 @@ def get_server_ip(private_only: bool = False) -> Optional[str]: except ValueError: continue - _SERVER_IP = ips.pop() if ips else constant.LOCAL_IP + _SERVER_IP = ips[0] if ips else constant.LOCAL_IP if private_only: _SERVER_PRIVATE_IP = constant.LOCAL_IP return _SERVER_PRIVATE_IP