"""OAuth2 authorization-server wiring backed by the ``.models`` tables and LDAP.

Defines the grant types, the token-persistence hooks, a bearer-token
validator and an RFC 7009 revocation endpoint, and exposes
:func:`config_oauth` to attach all of them to a Flask application.
"""
from datetime import datetime, timedelta

from authlib.flask.oauth2 import AuthorizationServer, ResourceProtector
from authlib.oauth2.rfc6749 import grants
from authlib.oidc.core import grants as oidgrants
from authlib.oauth2.rfc6750 import BearerTokenValidator
from authlib.oauth2.rfc7009 import RevocationEndpoint
from werkzeug.security import gen_salt

from .models import db
from .models import Client, Grant, Token
from .ldap import LDAPUser, check_credentials


class AuthorizationCodeGrant(grants.AuthorizationCodeGrant):
    """Authorization-code grant that persists codes as ``Grant`` rows."""

    def create_authorization_code(self, client, user, request):
        """Store a new authorization code for *user*/*client* and return it.

        The code is ``gen_salt(48)`` and expires 100 seconds after creation.
        """
        code = gen_salt(48)
        expires = datetime.utcnow() + timedelta(seconds=100)
        item = Grant(
            code=code,
            client_id=client.client_id,
            redirect_uri=request.redirect_uri,
            _scopes=request.scope,
            user=user.username,
            expires=expires,
        )
        db.session.add(item)
        db.session.commit()
        return code

    def parse_authorization_code(self, code, client):
        """Return the unexpired ``Grant`` for *code* issued to *client*, else None."""
        item = Grant.query.filter_by(
            code=code, client_id=client.client_id).first()
        if item and not item.is_expired():
            return item
        return None

    def delete_authorization_code(self, authorization_code):
        """Delete a consumed ``Grant`` row so the code cannot be reused."""
        db.session.delete(authorization_code)
        db.session.commit()

    def authenticate_user(self, authorization_code):
        """Resolve the LDAP user recorded on the ``Grant`` row."""
        return LDAPUser.by_login(authorization_code.user)


class PasswordGrant(grants.ResourceOwnerPasswordCredentialsGrant):
    """Resource-owner password grant that checks credentials against LDAP."""

    def authenticate_user(self, username, password):
        """Return the LDAP user when *username*/*password* are valid, else None."""
        if not check_credentials(username, password):
            return None
        return LDAPUser.by_login(username)


class RefreshTokenGrant(grants.RefreshTokenGrant):
    """Refresh-token grant looking up tokens in the ``Token`` table."""

    def authenticate_refresh_token(self, refresh_token):
        """Return the unexpired ``Token`` owning *refresh_token*, else None."""
        token = Token.query.filter_by(refresh_token=refresh_token).first()
        if token and not token.is_expired():
            return token
        return None

    def authenticate_user(self, credential):
        """Resolve the LDAP user that owns the ``Token`` *credential*.

        ``save_token`` stores the owner's username in the ``user`` field,
        so that is the value looked up here.
        """
        return LDAPUser.by_login(credential.user)


def query_client(client_id):
    """Return the ``Client`` row with *client_id*, or None if unknown."""
    return Client.query.filter_by(client_id=client_id).first()


def save_token(token, request):
    """Persist a freshly issued *token* dict as a ``Token`` row.

    :param token: token mapping produced by the authorization server; must
        contain ``token_type`` and ``expires_in``.
    :param request: the OAuth2 request; ``request.user`` may be None
        (e.g. for client-credentials grants).
    """
    user_id = request.user.username if request.user else None
    client = request.client
    # HACK: convert to model field names
    fields = dict(
        token_type=token['token_type'],
        access_token=token.get('access_token'),
        refresh_token=token.get('refresh_token'),
        expires=datetime.utcnow() + timedelta(seconds=token['expires_in']),
        # ``scope`` is absent when no scope was requested; store None then,
        # just like ``Grant`` stores a None ``request.scope``.
        _scopes=token.get('scope'),
    )
    item = Token(client_id=client.client_id, user=user_id, **fields)
    db.session.add(item)
    db.session.commit()


authorization = AuthorizationServer(
    query_client=query_client,
    save_token=save_token,
)
require_oauth = ResourceProtector()


class _BearerTokenValidator(BearerTokenValidator):
    """Validate bearer access tokens against the ``Token`` table."""

    def authenticate_token(self, token_string):
        """Return the ``Token`` whose access token is *token_string*, or None."""
        return Token.query.filter_by(access_token=token_string).first()

    def request_invalid(self, request):
        """Accept every request; no extra request-level checks are applied."""
        return False

    def token_revoked(self, token):
        """Report tokens as never revoked.

        Revocation deletes the row (see ``_RevocationEndpoint``), so a
        revoked token is simply not found by ``authenticate_token``.
        """
        return False


class _RevocationEndpoint(RevocationEndpoint):
    """RFC 7009 revocation endpoint that deletes revoked ``Token`` rows."""

    def query_token(self, token, token_type_hint, client):
        """Find the ``Token`` owned by *client* matching the *token* string.

        The hinted token type is tried first. If that lookup fails, or no
        hint was given, the other type is tried as well, because RFC 7009
        requires extending the search beyond the hint.
        """
        query = Token.query.filter_by(client_id=client.client_id)
        if token_type_hint == 'refresh_token':
            columns = ('refresh_token', 'access_token')
        else:
            columns = ('access_token', 'refresh_token')
        for column in columns:
            item = query.filter_by(**{column: token}).first()
            if item is not None:
                return item
        return None

    def revoke_token(self, token):
        """Revoke *token* by deleting its row from the database."""
        db.session.delete(token)
        db.session.commit()


def config_oauth(app):
    """Attach the authorization server, grants, revocation and validator to *app*."""
    authorization.init_app(app)

    # support all grants
    authorization.register_grant(grants.ImplicitGrant)
    authorization.register_grant(grants.ClientCredentialsGrant)
    authorization.register_grant(AuthorizationCodeGrant)
    authorization.register_grant(PasswordGrant)
    authorization.register_grant(RefreshTokenGrant)

    # support revocation
    authorization.register_endpoint(_RevocationEndpoint)

    # protect resource
    require_oauth.register_token_validator(_BearerTokenValidator())