Skip to content
This repository has been archived by the owner on Mar 3, 2020. It is now read-only.

Commit

Permalink
Merge pull request #22 from edx/jsa/implicit2
Browse files Browse the repository at this point in the history
Refactor implicit flow.
  • Loading branch information
Jim Abramson committed Dec 1, 2015
2 parents 87f872c + da1cb1c commit 0382e3f
Show file tree
Hide file tree
Showing 6 changed files with 282 additions and 204 deletions.
2 changes: 1 addition & 1 deletion provider/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.3.0'
__version__ = '0.4.0'
1 change: 1 addition & 0 deletions provider/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
WRITE = 1 << 2
READ_WRITE = READ | WRITE

# NOTE that DEFAULT_SCOPES[0] (i.e. READ / 'read') is the default OAuth2 scope, per section 3.3 of rfc6749.
DEFAULT_SCOPES = (
(READ, 'read'),
(WRITE, 'write'),
Expand Down
122 changes: 95 additions & 27 deletions provider/oauth2/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
import urlparse

import ddt
from django.conf import settings
from django.contrib.auth.models import User
from django.core.urlresolvers import reverse
Expand Down Expand Up @@ -64,6 +65,7 @@ def _login_and_authorize(self, url_func=None):
self.assertTrue(self.redirect_url() in response['Location'])


@ddt.ddt
class AuthorizationTest(BaseOAuth2TestCase):
fixtures = ['test_oauth2']

Expand All @@ -89,17 +91,40 @@ def test_authorization_requires_login(self):

self.assertTrue(self.auth_url2() in response['Location'])

def test_authorization_requires_client_id(self):
@ddt.data(
('read', 'read'),
('write', 'write'),
('read+write', 'read write read+write'),
)
@ddt.unpack
def test_implicit_flow(self, requested_scope, expected_scope):
"""
End-to-end test of the implicit flow (happy path).
"""
self.login()
response = self.client.get(self.auth_url())
self.client.get(self.auth_url(), data=self.get_auth_params(response_type='token', scope=requested_scope))
response = self.client.post(self.auth_url2(), {'authorize': True})
fragment = urlparse.urlparse(response['Location']).fragment
auth_response_data = {k: v[0] for k, v in urlparse.parse_qs(fragment).items()}
self.assertEqual(auth_response_data['scope'], expected_scope)
self.assertEqual(auth_response_data['access_token'], AccessToken.objects.all()[0].token)
self.assertEqual(auth_response_data['token_type'], 'Bearer')
self.assertEqual(int(auth_response_data['expires_in']), constants.EXPIRE_DELTA.days * 60 * 60 * 24 - 1)
self.assertNotIn('refresh_token', response)

@ddt.data('code', 'token')
def test_authorization_requires_client_id(self, response_type):
self.login()
self.client.get(self.auth_url(), data={'response_type': response_type})
response = self.client.get(self.auth_url2())

self.assertEqual(400, response.status_code)
self.assertTrue("An unauthorized client tried to access your resources." in response.content)

def test_authorization_rejects_invalid_client_id(self):
@ddt.data('code', 'token')
def test_authorization_rejects_invalid_client_id(self, response_type):
self.login()
response = self.client.get(self.auth_url(), data={"client_id": 123})
response = self.client.get(self.auth_url(), data={"client_id": 123, 'response_type': response_type})
response = self.client.get(self.auth_url2())

self.assertEqual(400, response.status_code)
Expand All @@ -113,22 +138,19 @@ def test_authorization_requires_response_type(self):
self.assertEqual(400, response.status_code)
self.assertTrue(escape(u"No 'response_type' supplied.") in response.content)

def test_authorization_requires_supported_response_type(self):
@ddt.data('code', 'token', 'unsupported')
def test_authorization_requires_supported_response_type(self, response_type):
self.login()
response = self.client.get(
self.auth_url(), self.get_auth_params(response_type="unsupported"))
self.auth_url(), self.get_auth_params(response_type=response_type))
response = self.client.get(self.auth_url2())

self.assertEqual(400, response.status_code)
self.assertTrue(escape(u"'unsupported' is not a supported response type.") in response.content)
if response_type == 'unsupported':
self.assertEqual(400, response.status_code)
self.assertTrue(escape(u"'unsupported' is not a supported response type.") in response.content)

response = self.client.get(self.auth_url(), data=self.get_auth_params())
response = self.client.get(self.auth_url2())
self.assertEqual(200, response.status_code, response.content)

response = self.client.get(self.auth_url(), data=self.get_auth_params(response_type="token"))
response = self.client.get(self.auth_url2())
self.assertEqual(200, response.status_code)
else:
self.assertEqual(200, response.status_code)

def test_token_authorization_redirects_to_correct_uri(self):
self.login()
Expand Down Expand Up @@ -212,48 +234,83 @@ def test_token_authorization_cancellation(self):

self.assertEqual(AccessToken.objects.count(), 0)

def test_authorization_requires_a_valid_redirect_uri(self):
@ddt.data('code', 'token')
def test_authorization_requires_a_valid_redirect_uri(self, response_type):
self.login()

response = self.client.get(self.auth_url(),
data=self.get_auth_params(redirect_uri=self.get_client().redirect_uri + '-invalid'))
self.client.get(
self.auth_url(),
data=self.get_auth_params(
response_type=response_type, redirect_uri=self.get_client().redirect_uri + '-invalid'
)
)
response = self.client.get(self.auth_url2())

self.assertEqual(400, response.status_code)
self.assertTrue(escape(u"The requested redirect didn't match the client settings.") in response.content)

response = self.client.get(self.auth_url(),
data=self.get_auth_params(redirect_uri=self.get_client().redirect_uri))
self.client.get(self.auth_url(), data=self.get_auth_params(
response_type=response_type, redirect_uri=self.get_client().redirect_uri))
response = self.client.get(self.auth_url2())

self.assertEqual(200, response.status_code)

def test_authorization_requires_a_valid_scope(self):
@ddt.data('code', 'token')
def test_authorization_requires_a_valid_scope(self, response_type):
self.login()

response = self.client.get(self.auth_url(), data=self.get_auth_params(scope="invalid"))
self.client.get(self.auth_url(), data=self.get_auth_params(response_type=response_type, scope="invalid"))
response = self.client.get(self.auth_url2())

self.assertEqual(400, response.status_code)
self.assertTrue(escape(u"'invalid' is not a valid scope.") in response.content,
'Expected `{0}` in {1}'.format(escape(u"'invalid' is not a valid scope."), response.content))

response = self.client.get(self.auth_url(), data=self.get_auth_params(scope=constants.SCOPES[0][1]))
self.client.get(
self.auth_url(),
data=self.get_auth_params(response_type=response_type, scope=constants.SCOPES[0][1])
)
response = self.client.get(self.auth_url2())
self.assertEqual(200, response.status_code)

def test_authorization_is_not_granted(self):
@ddt.data('code', 'token')
def test_authorization_sets_default_scope(self, response_type):

self.login()
self.client.get(self.auth_url(), data=self.get_auth_params(response_type=response_type))
response = self.client.post(self.auth_url2(), {'authorize': True})

if response_type == 'code':
# authorization code flow
response = self.client.get(self.redirect_url())
query = urlparse.urlparse(response['Location']).query
code = urlparse.parse_qs(query)['code'][0]
response = self.client.post(self.access_token_url(), {
'grant_type': 'authorization_code',
'client_id': self.get_client().client_id,
'client_secret': self.get_client().client_secret,
'code': code})
scope_str = json.loads(response.content).get('scope')
else:
# implicit flow
fragment = urlparse.urlparse(response['Location']).fragment
scope_str = urlparse.parse_qs(fragment)['scope'][0]

self.assertEqual(scope_str, constants.SCOPES[0][1])

@ddt.data('code', 'token')
def test_authorization_is_not_granted(self, response_type):
self.login()

response = self.client.get(self.auth_url(), data=self.get_auth_params(response_type="code"))
response = self.client.get(self.auth_url2())
self.client.get(self.auth_url(), data=self.get_auth_params(response_type=response_type))
self.client.get(self.auth_url2())

response = self.client.post(self.auth_url2(), {'authorize': False, 'scope': constants.SCOPES[0][1]})
self.assertEqual(302, response.status_code, response.content)
self.assertTrue(self.get_client().redirect_uri in response['Location'],
'{0} not in {1}'.format(self.redirect_url(), response['Location']))
self.assertTrue('error=access_denied' in response['Location'])
self.assertFalse('code' in response['Location'])
self.assertFalse(response_type in response['Location'])

def test_authorization_is_granted(self):
self.login()
Expand All @@ -278,6 +335,17 @@ def test_preserving_the_state_variable(self):
self.assertTrue('code' in response['Location'])
self.assertTrue('state=abc' in response['Location'])

def test_preserving_the_state_variable_implicit(self):
self.login()

self.client.get(self.auth_url(), data=self.get_auth_params(response_type='token', state='abc'))
self.client.get(self.auth_url2())
response = self.client.post(self.auth_url2(), {'authorize': True, 'scope': constants.SCOPES[0][1]})
self.assertEqual(302, response.status_code)
self.assertFalse('error' in response['Location'])
self.assertTrue('access_token=' in response['Location'])
self.assertTrue('state=abc' in response['Location'])

def test_redirect_requires_valid_data(self):
self.login()
response = self.client.get(self.redirect_url())
Expand Down
88 changes: 46 additions & 42 deletions provider/oauth2/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,53 @@
from provider.oauth2.forms import PasswordGrantForm, RefreshTokenGrantForm
from provider.oauth2.models import Client, RefreshToken, AccessToken
from provider.utils import now
from provider.views import AccessToken as AccessTokenView, OAuthError
from provider.views import AccessToken as AccessTokenView, OAuthError, AccessTokenMixin
from provider.views import Capture, Authorize, Redirect


class OAuth2AccessTokenMixin(AccessTokenMixin):

def get_access_token(self, request, user, scope, client):
try:
# Attempt to fetch an existing access token.
at = AccessToken.objects.get(user=user, client=client,
scope=scope, expires__gt=now())
except AccessToken.DoesNotExist:
# None found... make a new one!
at = self.create_access_token(request, user, scope, client)
self.create_refresh_token(request, user, scope, at, client)
return at

def create_access_token(self, request, user, scope, client):
return AccessToken.objects.create(
user=user,
client=client,
scope=scope
)

def create_refresh_token(self, request, user, scope, access_token, client):
return RefreshToken.objects.create(
user=user,
access_token=access_token,
client=client
)

def invalidate_refresh_token(self, rt):
if constants.DELETE_EXPIRED:
rt.delete()
else:
rt.expired = True
rt.save()

def invalidate_access_token(self, at):
if constants.DELETE_EXPIRED:
at.delete()
else:
at.expires = now() - timedelta(milliseconds=1)
at.save()



class Capture(Capture):
"""
Implementation of :class:`provider.views.Capture`.
Expand All @@ -26,7 +69,7 @@ def get_redirect_url(self, request):
return reverse('oauth2:authorize')


class Authorize(Authorize):
class Authorize(Authorize, OAuth2AccessTokenMixin):
"""
Implementation of :class:`provider.views.Authorize`.
"""
Expand Down Expand Up @@ -67,7 +110,7 @@ class Redirect(Redirect):
pass


class AccessTokenView(AccessTokenView):
class AccessTokenView(AccessTokenView, OAuth2AccessTokenMixin):
"""
Implementation of :class:`provider.views.AccessToken`.
Expand Down Expand Up @@ -100,52 +143,13 @@ def get_password_grant(self, request, data, client):
raise OAuthError(form.errors)
return form.cleaned_data

def get_access_token(self, request, user, scope, client):
try:
# Attempt to fetch an existing access token.
at = AccessToken.objects.get(user=user, client=client,
scope=scope, expires__gt=now())
except AccessToken.DoesNotExist:
# None found... make a new one!
at = self.create_access_token(request, user, scope, client)
self.create_refresh_token(request, user, scope, at, client)
return at

def create_access_token(self, request, user, scope, client):
return AccessToken.objects.create(
user=user,
client=client,
scope=scope
)

def create_refresh_token(self, request, user, scope, access_token, client):
return RefreshToken.objects.create(
user=user,
access_token=access_token,
client=client
)

def invalidate_grant(self, grant):
if constants.DELETE_EXPIRED:
grant.delete()
else:
grant.expires = now() - timedelta(days=1)
grant.save()

def invalidate_refresh_token(self, rt):
if constants.DELETE_EXPIRED:
rt.delete()
else:
rt.expired = True
rt.save()

def invalidate_access_token(self, at):
if constants.DELETE_EXPIRED:
at.delete()
else:
at.expires = now() - timedelta(days=1)
at.save()


class AccessTokenDetailView(View):
"""
Expand Down
Loading

0 comments on commit 0382e3f

Please sign in to comment.