sso/website/oauth2.py

139 lines
4.2 KiB
Python

from datetime import datetime, timedelta
from authlib.integrations.flask_oauth2 import AuthorizationServer, ResourceProtector
from authlib.oauth2.rfc6749 import grants
from authlib.oidc.core import grants as oidgrants
from authlib.oauth2.rfc6750 import BearerTokenValidator
from authlib.oauth2.rfc7009 import RevocationEndpoint
from werkzeug.security import gen_salt
from .models import db
from .models import Client, Grant, Token
from .ldap import LDAPUser, check_credentials
class AuthorizationCodeGrant(grants.AuthorizationCodeGrant):
    """Authorization-code grant that persists codes as ``Grant`` rows.

    Codes are random 48-character salts valid for 100 seconds; the
    resource owner is resolved through ``LDAPUser`` when a code is redeemed.
    """

    def create_authorization_code(self, client, user, request):
        """Issue a new authorization code for *user*, store it, and return it."""
        new_code = gen_salt(48)
        valid_until = datetime.utcnow() + timedelta(seconds=100)
        db.session.add(Grant(
            code=new_code,
            client_id=client.client_id,
            redirect_uri=request.redirect_uri,
            _scopes=request.scope,
            user=user.username,
            expires=valid_until,
        ))
        db.session.commit()
        return new_code

    def parse_authorization_code(self, code, client):
        """Return the stored, unexpired ``Grant`` for *code* and *client*, else ``None``."""
        match = Grant.query.filter_by(
            code=code,
            client_id=client.client_id,
        ).first()
        if not match or match.is_expired():
            return None
        return match

    def delete_authorization_code(self, authorization_code):
        """Remove a redeemed authorization code from the database."""
        session = db.session
        session.delete(authorization_code)
        session.commit()

    def authenticate_user(self, authorization_code):
        """Look up the LDAP user the authorization code was issued to."""
        login = authorization_code.user
        return LDAPUser.by_login(login)
class PasswordGrant(grants.ResourceOwnerPasswordCredentialsGrant):
    """Resource-owner password grant whose credentials are checked against LDAP."""

    def authenticate_user(self, username, password):
        """Return the LDAP user for valid *username*/*password*, otherwise ``None``."""
        if check_credentials(username, password):
            return LDAPUser.by_login(username)
        return None
class RefreshTokenGrant(grants.RefreshTokenGrant):
    """Refresh-token grant that looks tokens up in the ``Token`` table."""

    def authenticate_refresh_token(self, refresh_token):
        """Return the stored ``Token`` for *refresh_token* if it has not expired.

        Returns ``None`` when no token matches or the matching token is expired.
        """
        token = Token.query.filter_by(refresh_token=refresh_token).first()
        if token and not token.is_expired():
            return token
        return None

    def authenticate_user(self, credential):
        """Resolve the LDAP user that owns *credential*.

        :param credential: the ``Token`` returned by ``authenticate_refresh_token``.
        :return: the ``LDAPUser`` for the token's owner.
        """
        # Tokens are created by save_token with ``user=<username>``, the same
        # column name AuthorizationCodeGrant reads from its Grant rows.
        return LDAPUser.by_login(credential.user)
def query_client(client_id):
    """Return the registered ``Client`` whose id is *client_id*, or ``None``."""
    matches = Client.query.filter_by(client_id=client_id)
    return matches.first()
def save_token(token, request):
    """Persist a freshly issued token as a ``Token`` row.

    :param token: token dict produced by Authlib. ``token_type`` and
        ``expires_in`` are required; ``access_token``, ``refresh_token``
        and ``scope`` may be absent.
    :param request: the OAuth2 request. ``request.client`` is the issuing
        client; ``request.user`` is the resource owner and may be falsy, in
        which case no user is recorded on the token.
    """
    user_id = request.user.username if request.user else None
    # Map Authlib's token keys onto the model's column names.
    item = Token(
        client_id=request.client.client_id,
        user=user_id,
        token_type=token['token_type'],
        access_token=token.get('access_token'),
        refresh_token=token.get('refresh_token'),
        expires=datetime.utcnow() + timedelta(seconds=token['expires_in']),
        # Authlib leaves 'scope' out of the token dict when no scope was
        # granted; store an empty scope string instead of raising KeyError.
        _scopes=token.get('scope', ''),
    )
    db.session.add(item)
    db.session.commit()
# Guard for protected resources; config_oauth() registers its bearer-token
# validator.
require_oauth = ResourceProtector()

# Shared OAuth2 authorization server. Clients are resolved with query_client
# and issued tokens are persisted with save_token; config_oauth() attaches
# the grants and endpoints.
authorization = AuthorizationServer(
    save_token=save_token,
    query_client=query_client,
)
class _BearerTokenValidator(BearerTokenValidator):
    """Validate bearer tokens against the ``Token`` table."""

    def authenticate_token(self, token_string):
        """Return the ``Token`` whose access token equals *token_string*, or ``None``."""
        lookup = Token.query.filter_by(access_token=token_string)
        return lookup.first()

    def request_invalid(self, request):
        """Never reject a request on request-level grounds."""
        return False

    def token_revoked(self, token):
        """Always report tokens as not revoked.

        Revoked tokens are expected to be removed by the revocation
        endpoint, so they fail ``authenticate_token`` instead.
        """
        return False
class _RevocationEndpoint(RevocationEndpoint):
    """Token revocation (RFC 7009) backed by the ``Token`` table."""

    def query_token(self, token, token_type_hint, client):
        """Find the token *client* asked to revoke.

        :param token: raw token string from the revocation request.
        :param token_type_hint: ``'access_token'``, ``'refresh_token'``, or
            any other value (including ``None``) meaning "unknown".
        :param client: the authenticated client; only its own tokens match.
        :return: the matching ``Token`` or ``None``.
        """
        q = Token.query.filter_by(client_id=client.client_id)
        if token_type_hint == 'access_token':
            return q.filter_by(access_token=token).first()
        if token_type_hint == 'refresh_token':
            return q.filter_by(refresh_token=token).first()
        # No usable hint: try the string as an access token first, then fall
        # back to treating it as a refresh token. Use a separate name so the
        # raw token string is still available for the fallback lookup.
        item = q.filter_by(access_token=token).first()
        if item is not None:
            return item
        return q.filter_by(refresh_token=token).first()

    def revoke_token(self, token):
        """Delete *token* from the database so it can no longer be used."""
        # Mapped model instances have no delete() method; remove the row
        # through the session, as delete_authorization_code does.
        db.session.delete(token)
        db.session.commit()
def config_oauth(app):
    """Wire the OAuth2 authorization server and resource protector into *app*.

    Registers every supported grant type, the token revocation endpoint, and
    the bearer-token validator used by ``require_oauth``.
    """
    authorization.init_app(app)

    # Every grant flow this server offers, in registration order.
    # NOTE(review): the implicit and password grants are discouraged by the
    # OAuth 2.0 Security BCP; confirm they are still required.
    supported_grants = (
        grants.ImplicitGrant,
        grants.ClientCredentialsGrant,
        AuthorizationCodeGrant,
        PasswordGrant,
        RefreshTokenGrant,
    )
    for grant_cls in supported_grants:
        authorization.register_grant(grant_cls)

    # Token revocation endpoint.
    authorization.register_endpoint(_RevocationEndpoint)

    # Bearer-token validation for protected resources.
    require_oauth.register_token_validator(_BearerTokenValidator())