sso/website/oauth2.py

139 lines
4.2 KiB
Python

from datetime import datetime, timedelta
from authlib.flask.oauth2 import AuthorizationServer, ResourceProtector
from authlib.oauth2.rfc6749 import grants
from authlib.oidc.core import grants as oidgrants
from authlib.oauth2.rfc6750 import BearerTokenValidator
from authlib.oauth2.rfc7009 import RevocationEndpoint
from werkzeug.security import gen_salt
from .models import db
from .models import Client, Grant, Token
from .ldap import LDAPUser, check_credentials
class AuthorizationCodeGrant(grants.AuthorizationCodeGrant):
    """Authorization-code grant that keeps its codes in ``Grant`` rows."""

    def create_authorization_code(self, client, user, request):
        """Create, store and return a new authorization code.

        The stored ``Grant`` records the client id, the redirect URI and
        scope taken from ``request``, the user's ``username``, and an
        expiry 100 seconds after creation (naive UTC).
        """
        new_code = gen_salt(48)
        valid_until = datetime.utcnow() + timedelta(seconds=100)
        grant = Grant(
            code=new_code,
            client_id=client.client_id,
            redirect_uri=request.redirect_uri,
            _scopes=request.scope,
            user=user.username,
            expires=valid_until,
        )
        session = db.session
        session.add(grant)
        session.commit()
        return new_code

    def parse_authorization_code(self, code, client):
        """Return the matching, unexpired ``Grant`` for ``client``, else None."""
        lookup = Grant.query.filter_by(code=code, client_id=client.client_id)
        grant = lookup.first()
        if not grant or grant.is_expired():
            return None
        return grant

    def delete_authorization_code(self, authorization_code):
        """Delete the given grant row and commit the session."""
        session = db.session
        session.delete(authorization_code)
        session.commit()

    def authenticate_user(self, authorization_code):
        """Resolve the grant's stored ``user`` login via ``LDAPUser.by_login``."""
        login = authorization_code.user
        return LDAPUser.by_login(login)
class PasswordGrant(grants.ResourceOwnerPasswordCredentialsGrant):
    """Password grant that validates credentials with ``check_credentials``."""

    def authenticate_user(self, username, password):
        """Return the LDAP user for ``username`` if the credentials check out.

        Returns None when ``check_credentials`` rejects the pair; otherwise
        returns whatever ``LDAPUser.by_login(username)`` yields.
        """
        if check_credentials(username, password):
            return LDAPUser.by_login(username)
        return None
class RefreshTokenGrant(grants.RefreshTokenGrant):
    """Refresh-token grant backed by the ``Token`` model and LDAP users."""

    def authenticate_refresh_token(self, refresh_token):
        """Return the unexpired ``Token`` for ``refresh_token``, else None."""
        token = Token.query.filter_by(refresh_token=refresh_token).first()
        if token and not token.is_expired():
            return token
        return None

    def authenticate_user(self, credential):
        """Return the LDAP user that owns the token ``credential``.

        ``credential`` is the ``Token`` returned by
        ``authenticate_refresh_token``.
        """
        # The parameter is ``credential``; the previous code referenced the
        # undefined name ``credentials`` and raised NameError on every
        # refresh. ``save_token`` stores the owner's username in the
        # ``user`` field, so that is the login handed to LDAP here.
        # NOTE(review): if Token also defines ``user_id``, confirm both
        # attributes carry the same login.
        return LDAPUser.by_login(credential.user)
def query_client(client_id):
    """Return the first ``Client`` row whose ``client_id`` matches."""
    matching = Client.query.filter_by(client_id=client_id)
    return matching.first()
def save_token(token, request):
    """Persist an issued token dict as a ``Token`` row and commit.

    ``token`` is the dict produced by the authorization server; its
    ``token_type``, ``expires_in`` and ``scope`` keys are required here,
    while ``access_token`` and ``refresh_token`` are optional. The owner
    is ``request.user.username`` when ``request.user`` is truthy,
    otherwise None.
    """
    owner = request.user.username if request.user else None
    issuing_client = request.client
    # HACK: translate the token dict keys into the model's field names.
    # NOTE(review): a token dict without 'scope' or 'expires_in' raises
    # KeyError here - confirm the token generator always sets both.
    fields = {
        'token_type': token['token_type'],
        'access_token': token.get('access_token'),
        'refresh_token': token.get('refresh_token'),
        'expires': datetime.utcnow() + timedelta(seconds=token['expires_in']),
        '_scopes': token['scope'],
    }
    record = Token(client_id=issuing_client.client_id, user=owner, **fields)
    session = db.session
    session.add(record)
    session.commit()
# Module-level authorization server, wired to this module's client lookup
# and token persistence callbacks; bound to the Flask app in config_oauth().
authorization = AuthorizationServer(query_client=query_client, save_token=save_token)

# Module-level resource protector; its token validator is registered in
# config_oauth().
require_oauth = ResourceProtector()
class _BearerTokenValidator(BearerTokenValidator):
    """Bearer-token validator that looks tokens up in the ``Token`` table."""

    def authenticate_token(self, token_string):
        """Return the first ``Token`` whose access token is ``token_string``."""
        lookup = Token.query.filter_by(access_token=token_string)
        return lookup.first()

    def request_invalid(self, request):
        """Never flag the request as invalid: always returns False."""
        return False

    def token_revoked(self, token):
        """Never flag the token as revoked: always returns False."""
        return False
class _RevocationEndpoint(RevocationEndpoint):
    """Token revocation endpoint backed by the ``Token`` model."""

    def query_token(self, token, token_type_hint, client):
        """Find the ``Token`` row of ``client`` identified by ``token``.

        ``token`` is the token string sent by the client. With a hint of
        ``'access_token'`` or ``'refresh_token'`` only that column is
        searched. Without a recognised hint the string is tried as an
        access token first and then as a refresh token.

        Returns the matching row, or None when nothing matches.
        """
        q = Token.query.filter_by(client_id=client.client_id)
        if token_type_hint == 'access_token':
            return q.filter_by(access_token=token).first()
        if token_type_hint == 'refresh_token':
            return q.filter_by(refresh_token=token).first()
        # No usable hint. The previous code overwrote ``token`` with the
        # lookup result and returned early when it was None, so refresh
        # tokens were never found and a hit was re-queried with a model
        # object instead of the token string.
        item = q.filter_by(access_token=token).first()
        if item is not None:
            return item
        return q.filter_by(refresh_token=token).first()

    def revoke_token(self, token):
        """Revoke ``token`` by calling its ``delete()`` method."""
        # NOTE(review): relies on the Token model providing delete(); that
        # method is not visible in this file - confirm it exists and commits.
        token.delete()
def config_oauth(app):
    """Attach the OAuth2 server to ``app`` and register its components.

    Initialises the module-level ``authorization`` server with ``app``,
    registers every supported grant and the revocation endpoint on it,
    then registers the bearer-token validator on ``require_oauth``.
    """
    authorization.init_app(app)

    # support all grants, registered in this order
    supported_grants = (
        grants.ImplicitGrant,
        grants.ClientCredentialsGrant,
        AuthorizationCodeGrant,
        PasswordGrant,
        RefreshTokenGrant,
    )
    for grant_cls in supported_grants:
        authorization.register_grant(grant_cls)

    # support revocation
    authorization.register_endpoint(_RevocationEndpoint)

    # protect resource
    validator = _BearerTokenValidator()
    require_oauth.register_token_validator(validator)