Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Replace pyjwt with authlib in org.matrix.login.jwt #13011

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/13011.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Replaced usage of PyJWT with methods from Authlib in `org.matrix.login.jwt`. Contributed by Hannes Lerchl.
35 changes: 23 additions & 12 deletions docs/jwt.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,27 +37,27 @@ As with other login types, there are additional fields (e.g. `device_id` and
## Preparing Synapse

The JSON Web Token integration in Synapse uses the
[`PyJWT`](https://pypi.org/project/pyjwt/) library, which must be installed
[`Authlib`](https://docs.authlib.org/en/latest/index.html) library, which must be installed
as follows:

* The relevant libraries are included in the Docker images and Debian packages
provided by `matrix.org` so no further action is needed.
* The relevant libraries are included in the Docker images and Debian packages
provided by `matrix.org` so no further action is needed.

* If you installed Synapse into a virtualenv, run `/path/to/env/bin/pip
install synapse[pyjwt]` to install the necessary dependencies.
* If you installed Synapse into a virtualenv, run `/path/to/env/bin/pip
install synapse[jwt]` to install the necessary dependencies.

* For other installation mechanisms, see the documentation provided by the
maintainer.
* For other installation mechanisms, see the documentation provided by the
maintainer.

To enable the JSON web token integration, you should then add an `jwt_config` section
To enable the JSON web token integration, you should then add a `jwt_config` section
to your configuration file (or uncomment the `enabled: true` line in the
existing section). See [sample_config.yaml](./sample_config.yaml) for some
sample settings.

## How to test JWT as a developer

Although JSON Web Tokens are typically generated from an external server, the
examples below use [PyJWT](https://pyjwt.readthedocs.io/en/latest/) directly.
example below uses a locally generated JWT.

1. Configure Synapse with JWT logins, note that this example uses a pre-shared
secret and an algorithm of HS256:
Expand All @@ -70,10 +70,21 @@ examples below use [PyJWT](https://pyjwt.readthedocs.io/en/latest/) directly.
```
2. Generate a JSON web token:

```bash
$ pyjwt --key=my-secret-token --alg=HS256 encode sub=test-user
eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJ0ZXN0LXVzZXIifQ.Ag71GT8v01UO3w80aqRPTeuVPBIBZkYhNTJJ-_-zQIc
You can use the following short Python snippet to generate a JWT
protected by an HMAC.
Take care that the `secret` and the algorithm given in the `header` match
the entries from `jwt_config` above.

```python
from authlib.jose import jwt

header = {"alg": "HS256"}
payload = {"sub": "user1", "aud": ["audience"]}
secret = "my-secret-token"
result = jwt.encode(header, payload, secret)
print(result.decode("ascii"))
```

3. Query for the login types and ensure `org.matrix.login.jwt` is there:

```bash
Expand Down
6 changes: 4 additions & 2 deletions docs/usage/configuration/config_documentation.md
Original file line number Diff line number Diff line change
Expand Up @@ -2946,8 +2946,10 @@ Additional sub-options for this setting include:
tokens. Defaults to false.
* `secret`: This is either the private shared secret or the public key used to
decode the contents of the JSON web token. Required if `enabled` is set to true.
* `algorithm`: The algorithm used to sign the JSON web token. Supported algorithms are listed at
https://pyjwt.readthedocs.io/en/latest/algorithms.html Required if `enabled` is set to true.
* `algorithm`: The algorithm used to sign (or HMAC) the JSON web token.
Supported algorithms are listed
[here (section JWS)](https://docs.authlib.org/en/latest/specs/rfc7518.html).
Required if `enabled` is set to true.
* `subject_claim`: Name of the claim containing a unique identifier for the user.
Optional, defaults to `sub`.
* `issuer`: The issuer to validate the "iss" claim against. Optional. If provided the
Expand Down
8 changes: 4 additions & 4 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 2 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,6 @@ lxml = { version = ">=4.2.0", optional = true }
sentry-sdk = { version = ">=0.7.2", optional = true }
opentracing = { version = ">=2.2.0", optional = true }
jaeger-client = { version = ">=4.0.0", optional = true }
pyjwt = { version = ">=1.6.4", optional = true }
txredisapi = { version = ">=1.4.7", optional = true }
hiredis = { version = "*", optional = true }
Pympler = { version = "*", optional = true }
Expand All @@ -196,7 +195,7 @@ systemd = ["systemd-python"]
url_preview = ["lxml"]
sentry = ["sentry-sdk"]
opentracing = ["jaeger-client", "opentracing"]
jwt = ["pyjwt"]
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
jwt = ["authlib"]
# hiredis is not a *strict* dependency, but it makes things much faster.
# (if it is not installed, we fall back to slow code.)
redis = ["txredisapi", "hiredis"]
Expand All @@ -222,16 +221,14 @@ all = [
"psycopg2", "psycopg2cffi", "psycopg2cffi-compat",
# saml2
"pysaml2",
# oidc
# oidc and jwt
"authlib",
# url_preview
"lxml",
# sentry
"sentry-sdk",
# opentracing
"jaeger-client", "opentracing",
# jwt
"pyjwt",
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
# redis
"txredisapi", "hiredis",
# cache_memory
Expand Down
10 changes: 5 additions & 5 deletions synapse/config/jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@

from ._base import Config, ConfigError

MISSING_JWT = """Missing jwt library. This is required for jwt login.
MISSING_AUTHLIB = """Missing authlib library. This is required for jwt login.

Install by running:
pip install pyjwt
pip install synapse[jwt]
"""


Expand All @@ -43,11 +43,11 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
self.jwt_audiences = jwt_config.get("audiences")

try:
import jwt
from authlib.jose import JsonWebToken

jwt # To stop unused lint.
JsonWebToken # To stop unused lint.
except ImportError:
raise ConfigError(MISSING_JWT)
raise ConfigError(MISSING_AUTHLIB)
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
else:
self.jwt_enabled = False
self.jwt_secret = None
Expand Down
46 changes: 38 additions & 8 deletions synapse/rest/client/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,25 +420,55 @@ async def _do_jwt_login(
403, "Token field for JWT is missing", errcode=Codes.FORBIDDEN
)

import jwt
from authlib.jose import JsonWebToken, JWTClaims
from authlib.jose.errors import BadSignatureError, InvalidClaimError, JoseError

jwt = JsonWebToken([self.jwt_algorithm])
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
claim_options = {}
if self.jwt_issuer is not None:
claim_options["iss"] = {"value": self.jwt_issuer, "essential": True}
if self.jwt_audiences is not None:
claim_options["aud"] = {"values": self.jwt_audiences, "essential": True}
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved

try:
payload = jwt.decode(
claims = jwt.decode(
token,
self.jwt_secret,
algorithms=[self.jwt_algorithm],
issuer=self.jwt_issuer,
audience=self.jwt_audiences,
key=self.jwt_secret,
claims_cls=JWTClaims,
claims_options=claim_options,
)
except BadSignatureError:
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
# We handle this case separately to provide a better error message
raise LoginError(
403,
"JWT validation failed: Signature verification failed",
errcode=Codes.FORBIDDEN,
)
except jwt.PyJWTError as e:
except JoseError as e:
# A JWT error occurred, return some info back to the client.
raise LoginError(
403,
"JWT validation failed: %s" % (str(e),),
errcode=Codes.FORBIDDEN,
)

user = payload.get(self.jwt_subject_claim, None)
try:
claims.validate(leeway=120) # allows 2 min of clock skew

# Enforce the old behavior which is rolled out in productive
# servers: if the JWT contains an 'aud' claim but none is
# configured, the login attempt will fail
if claims.get("aud") is not None:
if self.jwt_audiences is None or len(self.jwt_audiences) == 0:
raise InvalidClaimError("aud")
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
except JoseError as e:
raise LoginError(
403,
"JWT validation failed: %s" % (str(e),),
errcode=Codes.FORBIDDEN,
)
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved

user = claims.get(self.jwt_subject_claim, None)
if user is None:
raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN)

Expand Down
44 changes: 23 additions & 21 deletions tests/rest/client/test_login.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import json
import time
import urllib.parse
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional
from unittest.mock import Mock
from urllib.parse import urlencode

Expand All @@ -41,7 +41,7 @@
from tests.unittest import HomeserverTestCase, override_config, skip_unless

try:
import jwt
from authlib.jose import jwk, jwt

HAS_JWT = True
except ImportError:
Expand Down Expand Up @@ -841,7 +841,7 @@ def test_deactivated_user(self) -> None:
self.assertIn(b"SSO account deactivated", channel.result["body"])


@skip_unless(HAS_JWT, "requires jwt")
@skip_unless(HAS_JWT, "requires authlib")
class JWTTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
Expand All @@ -866,11 +866,9 @@ def default_config(self) -> Dict[str, Any]:
return config

def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str:
# PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
result: Union[str, bytes] = jwt.encode(payload, secret, self.jwt_algorithm)
if isinstance(result, bytes):
return result.decode("ascii")
return result
header = {"alg": self.jwt_algorithm}
result: bytes = jwt.encode(header, payload, secret)
return result.decode("ascii")

def jwt_login(self, *args: Any) -> FakeChannel:
params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
Expand Down Expand Up @@ -902,7 +900,8 @@ def test_login_jwt_expired(self) -> None:
self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"], "JWT validation failed: Signature has expired"
channel.json_body["error"],
"JWT validation failed: expired_token: The token is expired",
)

def test_login_jwt_not_before(self) -> None:
Expand All @@ -912,7 +911,7 @@ def test_login_jwt_not_before(self) -> None:
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
"JWT validation failed: The token is not yet valid (nbf)",
"JWT validation failed: invalid_token: The token is not valid yet",
)

def test_login_no_sub(self) -> None:
Expand All @@ -934,7 +933,8 @@ def test_login_iss(self) -> None:
self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"], "JWT validation failed: Invalid issuer"
channel.json_body["error"],
'JWT validation failed: invalid_claim: Invalid claim "iss"',
)

# Not providing an issuer.
Expand All @@ -943,7 +943,7 @@ def test_login_iss(self) -> None:
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
'JWT validation failed: Token is missing the "iss" claim',
'JWT validation failed: missing_claim: Missing "iss" claim',
)

def test_login_iss_no_config(self) -> None:
Expand All @@ -965,7 +965,8 @@ def test_login_aud(self) -> None:
self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"], "JWT validation failed: Invalid audience"
channel.json_body["error"],
'JWT validation failed: invalid_claim: Invalid claim "aud"',
)

# Not providing an audience.
Expand All @@ -974,7 +975,7 @@ def test_login_aud(self) -> None:
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
'JWT validation failed: Token is missing the "aud" claim',
'JWT validation failed: missing_claim: Missing "aud" claim',
)

def test_login_aud_no_config(self) -> None:
Expand All @@ -983,7 +984,8 @@ def test_login_aud_no_config(self) -> None:
self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"], "JWT validation failed: Invalid audience"
channel.json_body["error"],
'JWT validation failed: invalid_claim: Invalid claim "aud"',
)

def test_login_default_sub(self) -> None:
Expand All @@ -1010,7 +1012,7 @@ def test_login_no_token(self) -> None:
# The JWTPubKeyTestCase is a complement to JWTTestCase where we instead use
# RSS256, with a public key configured in synapse as "jwt_secret", and tokens
# signed by the private key.
@skip_unless(HAS_JWT, "requires jwt")
@skip_unless(HAS_JWT, "requires authlib")
class JWTPubKeyTestCase(unittest.HomeserverTestCase):
servlets = [
login.register_servlets,
Expand Down Expand Up @@ -1071,11 +1073,11 @@ def default_config(self) -> Dict[str, Any]:
return config

def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str:
# PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
result: Union[bytes, str] = jwt.encode(payload, secret, "RS256")
if isinstance(result, bytes):
return result.decode("ascii")
return result
header = {"alg": "RS256"}
if secret.startswith("-----BEGIN RSA PRIVATE KEY-----"):
secret = jwk.dumps(secret, kty="RSA")
result: bytes = jwt.encode(header, payload, secret)
return result.decode("ascii")

def jwt_login(self, *args: Any) -> FakeChannel:
params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
Expand Down