Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: OAuth login with ORCID #154

Merged
merged 11 commits into from
Nov 24, 2021
6 changes: 4 additions & 2 deletions afidsvalidator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from config import *

from afidsvalidator.views import validator
from afidsvalidator.model import db
from afidsvalidator.model import db, login_manager
from afidsvalidator.orcid import orcid_blueprint


class ConfigException(Exception):
Expand Down Expand Up @@ -41,7 +42,6 @@ def __init__(self, message):

def create_app():
app = Flask(__name__)
app.register_blueprint(validator)

app.config.from_object(config_settings)

Expand All @@ -57,7 +57,9 @@ def create_app():
app.config["SQLALCHEMY_DATABASE_URI"] = heroku_uri
app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
db.init_app(app)
login_manager.init_app(app)
app.register_blueprint(validator)
app.register_blueprint(orcid_blueprint)

return app

Expand Down
29 changes: 28 additions & 1 deletion afidsvalidator/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from pkg_resources import parse_version
from sqlalchemy.orm import composite
from flask_sqlalchemy import SQLAlchemy
from flask_login import LoginManager, UserMixin
from flask_dance.consumer.storage.sqla import OAuthConsumerMixin


EXPECTED_LABELS = [str(x + 1) for x in range(32)]
Expand Down Expand Up @@ -49,6 +51,29 @@
EXPECTED_MAP = dict(zip(EXPECTED_LABELS, EXPECTED_DESCS))

db = SQLAlchemy()
login_manager = LoginManager()


@login_manager.user_loader
def load_user(user_id):
return User.query.get(user_id)


class User(UserMixin, db.Model):
id = db.Column(db.Integer, primary_key=True)
name = db.Column(db.String, nullable=True)
oauths = db.relationship("OAuth", backref="user", lazy=True)
human_fiducial_sets = db.relationship(
"HumanFiducialSet", backref="user", lazy=True
)

def __repr__(self):
return f"<email={self.email}>"


class OAuth(OAuthConsumerMixin, db.Model):
user_id = db.Column(db.Integer, db.ForeignKey("user.id"), nullable=False)
provider_user_id = db.Column(db.String(20), nullable=False)


class FiducialPosition(object):
Expand Down Expand Up @@ -130,7 +155,9 @@ class HumanFiducialSet(FiducialSet, db.Model):
__name__ = "HumanFiducialSet"

id = db.Column(db.Integer, primary_key=True)
user_id = db.Column(db.String)
afids_user_id = db.Column(
db.Integer, db.ForeignKey("user.id"), nullable=True
)
date = db.Column(db.Date)
template = db.Column(db.String)

Expand Down
51 changes: 51 additions & 0 deletions afidsvalidator/orcid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""Set up blueprint for ORCID authentication."""

from flask_dance.consumer import OAuth2ConsumerBlueprint, oauth_authorized
from flask_dance.consumer.storage.sqla import SQLAlchemyStorage
from flask_login import current_user, login_user
from sqlalchemy.orm.exc import NoResultFound

from afidsvalidator.model import db, OAuth, User

orcid_blueprint = OAuth2ConsumerBlueprint(
"orcid",
__name__,
base_url="https://api.orcid.org/v3.0",
token_url="https://orcid.org/oauth/token",
authorization_url="https://orcid.org/oauth/authorize",
storage=SQLAlchemyStorage(
OAuth, db.session, user=current_user, user_required=True
),
scope="openid",
)
orcid_blueprint.from_config["client_id"] = "ORCID_OAUTH_CLIENT_ID"
orcid_blueprint.from_config["client_secret"] = "ORCID_OAUTH_CLIENT_SECRET"


@oauth_authorized.connect_via(orcid_blueprint)
def orcid_logged_in(blueprint, token):
"""Create/login user on successful ORCID login."""
if not token:
return False

orcid_id = token["orcid"]

try:
oauth = OAuth.query.filter_by(
provider=orcid_blueprint.name, provider_user_id=orcid_id
).one()
except NoResultFound:
oauth = OAuth(
provider=blueprint.name, provider_user_id=orcid_id, token=token
)

if oauth.user:
login_user(oauth.user)
else:
user = User(name=token["name"])
oauth.user = user
db.session.add_all([user, oauth])
db.session.commit()
login_user(user)

return False
6 changes: 6 additions & 0 deletions afidsvalidator/templates/base.html
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,15 @@
<li class="nav-item mx-4">
<a class="nav-link" href="contact.html">Contact</a>
</li>
{% if current_user.is_authenticated %}
<li class="nav-item mx-4">
<a class="nav-link" href="logout.html">Logout</a>
</li>
{% else %}
<li class="nav-item mx-4">
<a class="nav-link" href="login.html">Login</a>
</li>
{% endif %}
</ul>
</div>
</nav>
Expand Down
2 changes: 1 addition & 1 deletion afidsvalidator/templates/login.html
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
<div class="container">
<div class="row">
<div class="col text-center">
<img class="img-fluid" src="{{ url_for('static', filename='images/under-construction-gif-11.gif') }}" alt="Under construction">
<a class="btn btn-light" role="button" href="{{ url_for("orcid.login") }}">Login with ORCID</a>
</div>
</div>
</div>
Expand Down
28 changes: 24 additions & 4 deletions afidsvalidator/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,15 @@
import os
from datetime import datetime, timezone

from flask import render_template, request, jsonify, Blueprint, current_app
from flask import (
render_template,
request,
jsonify,
Blueprint,
current_app,
redirect,
)
from flask_login import logout_user, current_user
import numpy as np
import wtforms as wtf

Expand Down Expand Up @@ -48,21 +56,28 @@ def allowed_file(filename):
@validator.route("/")
def index():
"""Render the static index page."""
return render_template("index.html")
return render_template("index.html", current_user=current_user)


# Contact
@validator.route("/contact.html")
def contact():
"""Render the static contact page."""
return render_template("contact.html")
return render_template("contact.html", current_user=current_user)


# Login
@validator.route("/login.html")
def login():
"""Render the static login page."""
return render_template("login.html")
return render_template("login.html", current_user=current_user)


@validator.route("/logout.html")
def logout():
"""Log out user and render the index."""
logout_user()
return redirect("/")


# Validator
Expand Down Expand Up @@ -131,6 +146,7 @@ def validate():
index=[],
labels=labels,
distances=distances,
current_user=current_user,
)

if user_afids.validate():
Expand All @@ -154,6 +170,7 @@ def validate():
index=[],
labels=labels,
distances=distances,
current_user=current_user,
)

result = f"{result}<br>{fid_template} selected"
Expand All @@ -164,6 +181,8 @@ def validate():
template_afids = csv_to_afids(template_file.read())

if request.form.get("db_checkbox"):
if current_user.is_authenticated:
user_afids.afids_user_id = current_user.id
db.session.add(user_afids)
db.session.commit()
print("Fiducial set added")
Expand Down Expand Up @@ -201,6 +220,7 @@ def validate():
timestamp=timestamp,
scatter_html=generate_3d_scatter(template_afids, user_afids),
histogram_html=generate_histogram(template_afids, user_afids),
current_user=current_user,
)


Expand Down
3 changes: 3 additions & 0 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ class Config(object):
if SQLALCHEMY_DATABASE_URI.startswith("postgres://"):
SQLALCHEMY_DATABASE_URI.replace("postgres://", "postgresql://", 1)

ORCID_OAUTH_CLIENT_ID = os.environ.get("ORCID_OAUTH_CLIENT_ID")
ORCID_OAUTH_CLIENT_SECRET = os.environ.get("ORCID_OAUTH_CLIENT_SECRET")


class ProductionConfig(Config):
"""Config used in production"""
Expand Down
15 changes: 5 additions & 10 deletions manage.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,13 @@
"""Flask database management script."""

import os
from flask_script import Manager
from flask_migrate import Migrate, MigrateCommand
from flask_migrate import Migrate

from afidsvalidator import create_app
from afidsvalidator.model import db

# Set up app
app = create_app()

# Set up db
manager = Manager(app)
manager.add_command("db", MigrateCommand)
migrate = Migrate(render_as_batch=True, compare_type=True)

if __name__ == "__main__":
manager.run()
# Set up app
app = create_app()
migrate.init_app(app, db)
38 changes: 38 additions & 0 deletions migrations/versions/56d89145adbb_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""empty message

Revision ID: 56d89145adbb
Revises: 7b4e00130929
Create Date: 2021-10-13 16:52:58.961413

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "56d89145adbb"
down_revision = "7b4e00130929"
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("user", schema=None) as batch_op:
batch_op.add_column(sa.Column("name", sa.String(), nullable=True))
batch_op.drop_column("email")

# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("user", schema=None) as batch_op:
batch_op.add_column(
sa.Column(
"email", sa.VARCHAR(), autoincrement=False, nullable=True
)
)
batch_op.drop_column("name")

# ### end Alembic commands ###
34 changes: 34 additions & 0 deletions migrations/versions/7b4e00130929_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""empty message

Revision ID: 7b4e00130929
Revises: a0928ce2eee6
Create Date: 2021-10-13 16:49:26.122781

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "7b4e00130929"
down_revision = "a0928ce2eee6"
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("flask_dance_oauth", schema=None) as batch_op:
batch_op.add_column(
sa.Column("provider_user_id", sa.String(length=20), nullable=False)
)

# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("flask_dance_oauth", schema=None) as batch_op:
batch_op.drop_column("provider_user_id")

# ### end Alembic commands ###
Loading