-
Notifications
You must be signed in to change notification settings - Fork 44
/
Copy pathmiddleware.py
514 lines (414 loc) · 19.7 KB
/
middleware.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
"""
Custom ASGI app middleware.
These middleware are based on [Starlette](https://www.starlette.io)'s `BaseHTTPMiddleware`.
See the specific Starlette [documentation page](https://www.starlette.io/middleware/) for more
information on its middleware implementation.
"""
import json
import re
import urllib.parse
import warnings
from collections.abc import Generator, Iterable
from typing import TextIO
from starlette.datastructures import URL as StarletteURL
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import RedirectResponse, StreamingResponse
from optimade.exceptions import BadRequest, VersionNotSupported
from optimade.models import Warnings
from optimade.server.config import CONFIG
from optimade.server.routers.utils import BASE_URL_PREFIXES, get_base_url
from optimade.warnings import (
FieldValueNotRecognized,
LocalOptimadeWarning,
OptimadeWarning,
QueryParamNotUsed,
TooManyValues,
)
class EnsureQueryParamIntegrity(BaseHTTPMiddleware):
    """Reject requests whose query string contains a parameter lacking an equal sign (`=`)."""

    @staticmethod
    def check_url(url_query: str) -> set:
        """Validate that every query parameter is written as `key=value`.

        The raw query string is broken up on both ampersands (`&`) and
        semi-colons (`;`); every non-empty piece must contain an `=`.

        Parameters:
            url_query: The raw, urllib-parsed query part of a URL.

        Raises:
            BadRequest: If any non-empty query parameter has no `=`.

        Returns:
            The set of individual query pieces (parameters and their values).
            The middleware itself does not use this value, since an invalid
            query results in a `400 Bad Request` response instead; it is
            returned to make the method easy to test.

        """
        pieces = {
            piece
            for segment in url_query.split("&")
            for piece in segment.split(";")
        }

        # Empty pieces (e.g. from a trailing `&`) are tolerated.
        if any(piece and "=" not in piece for piece in pieces):
            raise BadRequest(
                detail="A query parameter without an equal sign (=) is not supported by this server"
            )

        return pieces

    async def dispatch(self, request: Request, call_next):
        """Validate the query string, then hand the request on unchanged."""
        query = urllib.parse.urlsplit(str(request.url)).query
        if query:
            self.check_url(query)
        return await call_next(request)
class CheckWronglyVersionedBaseUrls(BaseHTTPMiddleware):
    """Answer with `553 Version Not Supported` for unsupported versioned base URLs."""

    @staticmethod
    def check_url(url: StarletteURL):
        """Inspect the path of `url` for a versioned base URL segment.

        The part of the URL that follows the (unversioned) base URL is matched
        against `/vMAJOR[.MINOR[.PATCH]]`. When such a prefix is present, it
        must be one of the implementation's `BASE_URL_PREFIXES`.

        Parameters:
            url: The full Starlette URL of the incoming request.

        Raises:
            VersionNotSupported: If the URL carries a versioned base URL whose
                version part is not served by this implementation.

        """
        base_url = get_base_url(url)
        full_path = f"{url.scheme}://{url.netloc}{url.path}"
        relative_path = full_path[len(base_url) :]

        version_match = re.match(
            r"^(?P<version>/v[0-9]+(\.[0-9]+){0,2}).*", relative_path
        )
        if version_match is None:
            # Not a versioned base URL: nothing to check.
            return

        requested = version_match.group("version")
        supported = BASE_URL_PREFIXES.values()
        if requested in supported:
            return

        raise VersionNotSupported(
            detail=(
                f"The parsed versioned base URL {requested!r} from "
                f"{url} is not supported by this implementation. "
                f"Supported versioned base URLs are: {', '.join(supported)}"
            )
        )

    async def dispatch(self, request: Request, call_next):
        """Reject unsupported versioned base URLs before routing the request."""
        if request.url.path:
            self.check_url(request.url)
        return await call_next(request)
class HandleApiHint(BaseHTTPMiddleware):
    """Handle the `api_hint` query parameter (OPTIMADE version negotiation).

    For an unversioned base URL, a usable `api_hint` results in a redirect to
    the corresponding versioned base URL, with `api_hint` removed from the
    query. For a versioned base URL, `api_hint` is ignored and a warning is
    issued instead.
    """

    @staticmethod
    def handle_api_hint(api_hint: list[str]) -> None | str:
        """Handle `api_hint` parameter value.

        There are several scenarios that can play out, when handling the `api_hint`
        query parameter:

        If several `api_hint` query parameters have been used, or a "standard" JSON
        list (`,`-separated value) has been supplied, a warning will be added to the
        response and the `api_hint` query parameter will not be applied.

        If the passed value does not comply with the rules set out in
        [the specification](https://github.com/Materials-Consortia/OPTIMADE/blob/v1.0.0/optimade.rst#version-negotiation),
        a warning will be added to the response and the `api_hint` query parameter
        will not be applied.

        If the value is part of the implementation's accepted versioned base URLs,
        it will be returned as is.

        If the value represents a major version that is newer than what is supported
        by the implementation, a `553 Version Not Supported` response will be returned,
        as is stated by [the specification](https://github.com/Materials-Consortia/OPTIMADE/blob/v1.0.0/optimade.rst#version-negotiation).

        On the other hand, if the value represents a major version equal to or lower
        than the implementation's supported major version, then the implementation's
        supported major version will be returned and tried for the request.

        Parameters:
            api_hint: The urllib-parsed query parameter values for `api_hint`.
                An empty list is treated as "no hint" and yields `None`.

        Raises:
            VersionNotSupported: If the requested major version is newer than the
                supported major version of the implementation.

        Returns:
            Either a valid `api_hint` value (a versioned base URL prefix) or `None`.

        """
        # Split every value by `,` in case a single value was provided in a
        # JSON-type "list" format.
        _api_hint = [part for value in api_hint for part in value.split(",")]

        if not _api_hint:
            # Guard against direct calls with an empty list; the middleware
            # itself always passes at least one value.
            return None

        if len(_api_hint) > 1:
            warnings.warn(
                TooManyValues(
                    detail="`api_hint` should only be supplied once, with a single value."
                )
            )
            return None

        api_hint_str: str = f"/{_api_hint[0]}"
        if re.match(r"^/v[0-9]+(\.[0-9]+)?$", api_hint_str) is None:
            warnings.warn(
                FieldValueNotRecognized(
                    detail=f"{api_hint_str[1:]!r} is not recognized as a valid `api_hint` value."
                )
            )
            return None

        if api_hint_str in BASE_URL_PREFIXES.values():
            return api_hint_str

        major_api_hint = int(re.findall(r"/v([0-9]+)", api_hint_str)[0])
        major_implementation = int(BASE_URL_PREFIXES["major"][len("/v") :])

        if major_api_hint <= major_implementation:
            # If less than:
            # Use the current implementation in hope that it can still handle older requests
            #
            # If equal:
            # Go to /v<MAJOR>, since this should point to the latest available
            return BASE_URL_PREFIXES["major"]

        # Let's not try to handle a request for a newer major version
        raise VersionNotSupported(
            detail=(
                f"The provided `api_hint` ({api_hint_str[1:]!r}) is not supported by this implementation. "
                f"Supported versions include: {', '.join(BASE_URL_PREFIXES.values())}"
            )
        )

    @staticmethod
    def is_versioned_base_url(url: str) -> bool:
        """Determine whether a request is for a versioned base URL.

        First, simply check whether a `/vMAJOR(.MINOR.PATCH)` part exists in the URL.
        If not, return `False`, else, remove unversioned base URL from the URL and check again.
        Return `bool` of final result.

        Parameters:
            url: The full URL to check.

        Returns:
            Whether or not the full URL represents an OPTIMADE versioned base URL.

        """
        if not re.findall(r"(/v[0-9]+(\.[0-9]+){0,2})", url):
            return False

        base_url = get_base_url(url)
        return bool(re.findall(r"(/v[0-9]+(\.[0-9]+){0,2})", url[len(base_url) :]))

    async def dispatch(self, request: Request, call_next):
        """Redirect unversioned requests carrying a usable `api_hint`.

        Requests to a versioned base URL only get a `QueryParamNotUsed`
        warning; all other requests are passed on unchanged.
        """
        parsed_query = urllib.parse.parse_qs(request.url.query, keep_blank_values=True)

        if "api_hint" in parsed_query:
            if self.is_versioned_base_url(str(request.url)):
                warnings.warn(
                    QueryParamNotUsed(
                        detail=(
                            "`api_hint` provided with value{:s} '{:s}' for a versioned base URL. "
                            "In accordance with the specification, this will not be handled by "
                            "the implementation.".format(
                                "s" if len(parsed_query["api_hint"]) > 1 else "",
                                "', '".join(parsed_query["api_hint"]),
                            )
                        )
                    )
                )
            else:
                version_path = self.handle_api_hint(parsed_query["api_hint"])

                if version_path:
                    base_url = get_base_url(request.url)
                    new_request = (
                        f"{base_url}{version_path}{str(request.url)[len(base_url) :]}"
                    )
                    url = urllib.parse.urlsplit(new_request)

                    # `parse_qsl` percent-decodes keys and values, so they must be
                    # re-encoded: joining them verbatim would corrupt values that
                    # contain `&`, `#`, `+` or `%` (e.g. in `filter`).
                    q = urllib.parse.urlencode(
                        [
                            (key, value)
                            for key, value in urllib.parse.parse_qsl(
                                url.query, keep_blank_values=True
                            )
                            if key != "api_hint"
                        ]
                    )

                    # The client's *request* headers are deliberately not copied
                    # onto the redirect *response*; the client re-sends its own
                    # headers when following the redirect.
                    return RedirectResponse(
                        request.url.replace(path=url.path, query=q),
                    )

                    # A non-URL-changing alternative would be to rewrite
                    # `request.scope["path"]` in place and continue with
                    # `call_next` on a new `Request` built from that scope.

        response = await call_next(request)
        return response
class AddWarnings(BaseHTTPMiddleware):
    """
    Add [`OptimadeWarning`][optimade.warnings.OptimadeWarning]s to the response.

    All sub-classes of [`OptimadeWarning`][optimade.warnings.OptimadeWarning]
    will also be added to the response's
    [`meta.warnings`][optimade.models.optimade_json.ResponseMeta.warnings] list.

    By overriding the `warnings.showwarning()` function with the
    [`showwarning` method][optimade.server.middleware.AddWarnings.showwarning],
    all usages of `warnings.warn()` will result in the regular printing of the
    warning message to `stderr`, but also its addition to an in-memory list of
    warnings.
    This middleware will, after the URL request has been handled, add the list of
    accumulated warnings to the JSON response under the
    [`meta.warnings`][optimade.models.optimade_json.ResponseMeta.warnings] field.
    Bodies that are not a JSON object (e.g. HTML pages) are passed through
    unchanged, as are all bodies when no warnings were collected.

    To make sure the last part happens correctly and a Starlette `StreamingResponse`
    is returned, as is expected from a `BaseHTTPMiddleware` sub-class, one is
    instantiated with the updated `Content-Length` header (measured in bytes), as
    well as making sure the response's body content is actually streamable, by
    breaking it down into chunks of the original response's chunk size.

    !!! warning "Important"
        It is **recommended** to add this middleware as the _last one_ to your application.

        This is to ensure it is invoked _first_, updating `warnings.showwarning()` and
        catching all warnings that should be added to the response.

        This can be achieved by applying `AddWarnings` _after_ all
        other middleware with the `.add_middleware()` method, or by
        initialising the app with a middleware list in which `AddWarnings`
        appears _first_. More information can be found in the docstring of
        [`OPTIMADE_MIDDLEWARE`][optimade.server.middleware.OPTIMADE_MIDDLEWARE].

    Attributes:
        _warnings (List[dict]): Serialised [`Warnings`][optimade.models.optimade_json.Warnings]
            objects (output of `Warnings.model_dump(exclude_unset=True)`) added through
            usages of `warnings.warn()` via [`showwarning`][optimade.server.middleware.AddWarnings.showwarning].

    """

    # NOTE(review): this list is stored on the middleware instance and reset at
    # the start of every `dispatch`. If one instance serves concurrent requests
    # (usual for Starlette middleware; confirm), their warnings can interleave.
    # Consider per-request storage, e.g. a `contextvars.ContextVar`.
    _warnings: list[dict]

    def showwarning(
        self,
        message: Warning | str,
        category: type[Warning],
        filename: str,
        lineno: int,
        file: TextIO | None = None,
        line: str | None = None,
    ) -> None:
        """
        Hook to write a warning to a file using the built-in `warnings` lib.

        In [the documentation](https://docs.python.org/3/library/warnings.html)
        for the built-in `warnings` library, there are a few recommended ways of
        customizing the printing of warning messages.

        This method can override the `warnings.showwarning` function,
        which is called as part of the `warnings` library's workflow to print
        warning messages, e.g., when using `warnings.warn()`.
        Originally, it prints warning messages to `stderr`.
        This method will also print warning messages to `stderr` by calling
        `warnings._showwarning_orig()` or `warnings._showwarnmsg_impl()`.
        The first function will be called if the issued warning is not recognized
        as an [`OptimadeWarning`][optimade.warnings.OptimadeWarning].
        This is equivalent to "standard behaviour".
        The second function will be called _after_ an
        [`OptimadeWarning`][optimade.warnings.OptimadeWarning] has been handled.
        A [`LocalOptimadeWarning`][optimade.warnings.LocalOptimadeWarning] is
        neither added to the response nor printed.

        An [`OptimadeWarning`][optimade.warnings.OptimadeWarning] will be
        translated into an OPTIMADE Warnings JSON object in accordance with
        [the specification](https://github.com/Materials-Consortia/OPTIMADE/blob/v1.0.0/optimade.rst#json-response-schema-common-fields).
        This process is similar to the [Exception handlers][optimade.server.exception_handlers].

        Parameters:
            message: The `Warning` object to show and possibly handle.
            category: `Warning` type being warned about. This amounts to `type(message)`.
            filename: Name of the file, where the warning was issued.
            lineno: Line number in the file, where the warning was issued.
            file: A file-like object to which the warning should be written.
            line: Source content of the line that issued the warning.

        """
        assert isinstance(message, Warning), (
            "'message' is expected to be a Warning or subclass thereof."
        )

        if not isinstance(message, OptimadeWarning):
            # If the Warning is not an OptimadeWarning or subclass thereof,
            # use the regular 'showwarning' function.
            warnings._showwarning_orig(message, category, filename, lineno, file, line)  # type: ignore[attr-defined]
            return

        if isinstance(message, LocalOptimadeWarning):
            return

        # Format warning
        try:
            title = str(message.title)
        except AttributeError:
            title = str(message.__class__.__name__)

        try:
            detail = str(message.detail)
        except AttributeError:
            detail = str(message)

        if CONFIG.debug:
            if line is None:
                # All this is taken directly from the warnings library.
                # See 'warnings._formatwarnmsg_impl()' for the original code.
                try:
                    import linecache

                    line = linecache.getline(filename, lineno)
                except Exception:
                    # When a warning is logged during Python shutdown, linecache
                    # and the import machinery don't work anymore
                    line = None
            meta = {
                "filename": filename,
                "lineno": lineno,
            }
            if line:
                meta["line"] = line.strip()
            new_warning = Warnings(title=title, detail=detail, meta=meta)
        else:
            new_warning = Warnings(title=title, detail=detail)

        # Add new warning to self._warnings
        self._warnings.append(new_warning.model_dump(exclude_unset=True))

        # Show warning message as normal in sys.stderr
        warnings._showwarnmsg_impl(  # type: ignore[attr-defined]
            warnings.WarningMessage(message, category, filename, lineno, file, line)
        )

    @staticmethod
    def chunk_it_up(content: str | bytes, chunk_size: int) -> Generator:
        """Return generator for string or bytes in chunks of size `chunk_size`.

        Parameters:
            content: String or bytes content to separate into chunks.
            chunk_size: The size of the chunks, i.e. the length of the chunks.
                Non-positive values are treated as 1.

        Returns:
            A Python generator to be converted later to an `asyncio` generator.

        """
        if chunk_size <= 0:
            chunk_size = 1
        return (content[i : chunk_size + i] for i in range(0, len(content), chunk_size))

    def _add_warnings_to_body(self, body: bytes, charset: str) -> bytes:
        """Return `body` re-serialised with `self._warnings` set as `meta.warnings`.

        Bodies that cannot be decoded as a JSON object (e.g. HTML pages or
        binary content) are returned unchanged rather than failing the request;
        the warnings have already been printed to `stderr` by `showwarning`.
        If the object has no dict-valued `meta` key, the warnings are not added.
        """
        try:
            payload = json.loads(body.decode(charset))
        except (UnicodeDecodeError, json.JSONDecodeError):
            return body

        if not isinstance(payload, dict):
            return body

        meta = payload.get("meta")
        if isinstance(meta, dict):
            meta["warnings"] = self._warnings

        return json.dumps(payload).encode(charset)

    async def dispatch(self, request: Request, call_next):
        """Collect warnings while handling the request and inject them into the body."""
        self._warnings = []

        # Route OptimadeWarnings through `self.showwarning` for this request.
        warnings.simplefilter(action="default", category=OptimadeWarning)
        warnings.showwarning = self.showwarning

        response = await call_next(request)

        status = response.status_code
        headers = response.headers
        media_type = response.media_type
        background = response.background
        charset = response.charset

        # Join the chunks once at the end; repeated `bytes +=` is quadratic.
        chunks: list[bytes] = []
        chunk_size = 0
        async for chunk in response.body_iterator:
            chunk_size = chunk_size or len(chunk)
            if not isinstance(chunk, bytes):
                chunk = chunk.encode(charset)
            chunks.append(chunk)
        body = b"".join(chunks)

        if self._warnings:
            body = self._add_warnings_to_body(body, charset)
            if "content-length" in headers:
                # Content-Length counts bytes, not characters.
                headers["content-length"] = str(len(body))

        response = StreamingResponse(
            content=self.chunk_it_up(body, chunk_size),
            status_code=status,
            headers=headers,
            media_type=media_type,
            background=background,
        )

        return response
OPTIMADE_MIDDLEWARE: Iterable[type[BaseHTTPMiddleware]] = (
    EnsureQueryParamIntegrity,
    CheckWronglyVersionedBaseUrls,
    HandleApiHint,
    AddWarnings,
)
"""A tuple of all the middleware classes that implement certain required
features of the OPTIMADE specification, e.g. warnings and URL
versioning.

!!! note
    The order in which middleware is added to an application matters.

    As discussed in the docstring of
    [`AddWarnings`][optimade.server.middleware.AddWarnings], this
    middleware is the final entry to this list so that it is the first
    to be applied by the server.

    Any other middleware should therefore be added _before_ iterating
    through this variable.

    This is the opposite way around to the example in the
    [Starlette documentation](https://www.starlette.io/middleware/)
    which initialises the application with a pre-built middleware list
    in the _reverse_ order to `OPTIMADE_MIDDLEWARE`.

To use this variable in FastAPI app code after initialisation:

```python
from fastapi import FastAPI
app = FastAPI()
for middleware in OPTIMADE_MIDDLEWARE:
    app.add_middleware(middleware)
```

Alternatively, to use this variable on initialisation:

```python
from fastapi import FastAPI
from starlette.middleware import Middleware
app = FastAPI(
    ...,
    middleware=[Middleware(m) for m in reversed(OPTIMADE_MIDDLEWARE)]
)
```

"""