"""OAuth2 server wiring.

Defines the grant classes, token persistence, bearer-token validation and
the revocation endpoint, and registers them all in ``config_oauth``.
"""
from datetime import datetime, timedelta

from authlib.integrations.flask_oauth2 import AuthorizationServer, ResourceProtector
from authlib.oauth2.rfc6749 import grants
from authlib.oidc.core import grants as oidgrants
from authlib.oauth2.rfc6750 import BearerTokenValidator
from authlib.oauth2.rfc7009 import RevocationEndpoint
from werkzeug.security import gen_salt

from .models import db
from .models import Client, Grant, Token
from .ldap import LDAPUser, check_credentials


class AuthorizationCodeGrant(grants.AuthorizationCodeGrant):
    """Authorization-code grant that stores codes as ``Grant`` rows."""

    def create_authorization_code(self, client, user, request):
        """Create, persist and return a new authorization code.

        The code is a 48-character random salt and the stored ``Grant``
        expires 100 seconds after creation.
        """
        code = gen_salt(48)
        expires = datetime.utcnow() + timedelta(seconds=100)
        item = Grant(
            code=code,
            client_id=client.client_id,
            redirect_uri=request.redirect_uri,
            _scopes=request.scope,
            user=user.username,
            expires=expires,
        )
        db.session.add(item)
        db.session.commit()
        return code

    def parse_authorization_code(self, code, client):
        """Return the unexpired ``Grant`` for *code* and *client*, else None."""
        item = Grant.query.filter_by(
            code=code, client_id=client.client_id).first()
        if item and not item.is_expired():
            return item
        return None

    def delete_authorization_code(self, authorization_code):
        """Remove a ``Grant`` row from the database."""
        db.session.delete(authorization_code)
        db.session.commit()

    def authenticate_user(self, authorization_code):
        """Look up the user recorded on the grant via ``LDAPUser.by_login``."""
        return LDAPUser.by_login(authorization_code.user)


class PasswordGrant(grants.ResourceOwnerPasswordCredentialsGrant):
    """Password grant that verifies credentials with ``check_credentials``."""

    def authenticate_user(self, username, password):
        """Return the user for *username* if the credentials check passes.

        Returns None when ``check_credentials`` rejects the pair.
        """
        if not check_credentials(username, password):
            return None
        return LDAPUser.by_login(username)


class RefreshTokenGrant(grants.RefreshTokenGrant):
    """Refresh-token grant backed by the ``Token`` table."""

    def authenticate_refresh_token(self, refresh_token):
        """Return the unexpired ``Token`` matching *refresh_token*, else None."""
        token = Token.query.filter_by(refresh_token=refresh_token).first()
        if token and not token.is_expired():
            return token
        return None

    def authenticate_user(self, credential):
        """Look up the user that owns the refresh-token *credential*."""
        # Fixed: the body previously read the undefined name `credentials`,
        # which raised NameError on every refresh-token exchange.
        # NOTE(review): save_token() stores the owner under `user=`; confirm
        # that the Token model also exposes `user_id`.
        return LDAPUser.by_login(credential.user_id)


def query_client(client_id):
    """Return the ``Client`` row with the given *client_id*, or None."""
    return Client.query.filter_by(client_id=client_id).first()


def save_token(token, request):
    """Persist an issued token dict as a ``Token`` row.

    *token* is the dict produced by the token generator; *request* supplies
    the client and, when present, the user the token was issued to.
    """
    user_id = request.user.username if request.user else None
    client = request.client
    # HACK: convert to model field names
    fields = dict(
        token_type=token['token_type'],
        access_token=token.get('access_token'),
        refresh_token=token.get('refresh_token'),
        expires=datetime.utcnow() + timedelta(seconds=token['expires_in']),
        # 'scope' can be missing from the token dict (e.g. no scope was
        # granted); store an empty scope instead of raising KeyError.
        _scopes=token.get('scope', ''),
    )
    item = Token(client_id=client.client_id, user=user_id, **fields)
    db.session.add(item)
    db.session.commit()


authorization = AuthorizationServer(
    query_client=query_client,
    save_token=save_token,
)
require_oauth = ResourceProtector()


class _BearerTokenValidator(BearerTokenValidator):
    """Bearer-token validator that resolves tokens from the ``Token`` table."""

    def authenticate_token(self, token_string):
        """Return the ``Token`` whose access token equals *token_string*."""
        return Token.query.filter_by(access_token=token_string).first()

    def request_invalid(self, request):
        """No extra request checks are performed; always returns False."""
        return False

    def token_revoked(self, token):
        """Always returns False; no revocation flag is consulted here."""
        return False


class _RevocationEndpoint(RevocationEndpoint):
    """Revocation endpoint operating on the client's own ``Token`` rows."""

    def query_token(self, token, token_type_hint, client):
        """Find the ``Token`` row for the token string *token*.

        With a ``token_type_hint`` of ``'access_token'`` or
        ``'refresh_token'`` only that column is searched. Without a usable
        hint the access-token column is tried first, then the refresh-token
        column. Returns None when nothing matches.
        """
        q = Token.query.filter_by(client_id=client.client_id)
        if token_type_hint == 'access_token':
            return q.filter_by(access_token=token).first()
        if token_type_hint == 'refresh_token':
            return q.filter_by(refresh_token=token).first()
        # No hint: try both columns. Fixed: the old code rebound `token` to
        # the query result, returned None on an access-token miss without
        # trying the refresh-token column, and on a hit passed the model
        # object as the refresh_token filter value.
        item = q.filter_by(access_token=token).first()
        if item is not None:
            return item
        return q.filter_by(refresh_token=token).first()

    def revoke_token(self, token):
        """Revoke *token* by calling its ``delete()`` method."""
        # NOTE(review): relies on the Token model providing delete(); that
        # method is not visible in this file - confirm it also commits.
        token.delete()


def config_oauth(app):
    """Bind the authorization server to *app* and register all components."""
    authorization.init_app(app)

    # support all grants
    authorization.register_grant(grants.ImplicitGrant)
    authorization.register_grant(grants.ClientCredentialsGrant)
    authorization.register_grant(AuthorizationCodeGrant)
    authorization.register_grant(PasswordGrant)
    authorization.register_grant(RefreshTokenGrant)

    # support revocation
    authorization.register_endpoint(_RevocationEndpoint)

    # protect resource
    require_oauth.register_token_validator(_BearerTokenValidator())