138 lines
4.2 KiB
Python
138 lines
4.2 KiB
Python
from datetime import datetime, timedelta
|
|
|
|
from authlib.flask.oauth2 import AuthorizationServer, ResourceProtector
|
|
from authlib.oauth2.rfc6749 import grants
|
|
from authlib.oidc.core import grants as oidgrants
|
|
from authlib.oauth2.rfc6750 import BearerTokenValidator
|
|
from authlib.oauth2.rfc7009 import RevocationEndpoint
|
|
from werkzeug.security import gen_salt
|
|
|
|
from .models import db
|
|
from .models import Client, Grant, Token
|
|
from .ldap import LDAPUser, check_credentials
|
|
|
|
|
|
class AuthorizationCodeGrant(grants.AuthorizationCodeGrant):
    """Authorization-code grant whose codes are persisted as ``Grant`` rows."""

    def create_authorization_code(self, client, user, request):
        """Create, store and return a fresh authorization code.

        The code is a 48-character value from ``gen_salt`` and the stored
        ``Grant`` row expires 100 seconds after creation.
        """
        code = gen_salt(48)
        expires_at = datetime.utcnow() + timedelta(seconds=100)

        grant = Grant(
            code=code,
            client_id=client.client_id,
            redirect_uri=request.redirect_uri,
            _scopes=request.scope,
            user=user.username,
            expires=expires_at,
        )
        session = db.session
        session.add(grant)
        session.commit()
        return code

    def parse_authorization_code(self, code, client):
        """Return the unexpired ``Grant`` for ``code`` and ``client``, else None."""
        lookup = Grant.query.filter_by(code=code, client_id=client.client_id)
        grant = lookup.first()
        if not grant or grant.is_expired():
            return None
        return grant

    def delete_authorization_code(self, authorization_code):
        """Remove a used authorization code from the database."""
        session = db.session
        session.delete(authorization_code)
        session.commit()

    def authenticate_user(self, authorization_code):
        """Look up the LDAP user whose login is stored on the grant."""
        login = authorization_code.user
        return LDAPUser.by_login(login)
|
|
|
|
|
|
class PasswordGrant(grants.ResourceOwnerPasswordCredentialsGrant):
    """Resource-owner password grant validated via ``check_credentials``."""

    def authenticate_user(self, username, password):
        """Return the LDAP user for ``username`` if the password checks out.

        Returns None when ``check_credentials`` rejects the pair.
        """
        if check_credentials(username, password):
            return LDAPUser.by_login(username)
        return None
|
|
|
|
|
|
class RefreshTokenGrant(grants.RefreshTokenGrant):
    """Refresh-token grant backed by the ``Token`` table."""

    def authenticate_refresh_token(self, refresh_token):
        """Return the unexpired ``Token`` matching ``refresh_token``, else None."""
        token = Token.query.filter_by(refresh_token=refresh_token).first()
        if token and not token.is_expired():
            return token
        return None

    def authenticate_user(self, credential):
        """Return the LDAP user that owns the given token credential.

        ``credential`` is the object returned by
        ``authenticate_refresh_token``.
        """
        # Fixed: the original read ``credentials.user_id`` (undefined name),
        # which raised NameError on every refresh request.
        # NOTE(review): tokens are created with a ``user=`` field in
        # save_token; confirm the Token model also exposes ``user_id``.
        return LDAPUser.by_login(credential.user_id)
|
|
|
|
|
|
def query_client(client_id):
    """Return the first ``Client`` whose ``client_id`` matches, or None."""
    matching = Client.query.filter_by(client_id=client_id)
    return matching.first()
|
|
|
|
|
|
def save_token(token, request):
    """Persist an issued token dict as a ``Token`` row.

    ``token`` must contain ``token_type``, ``expires_in`` and ``scope``;
    ``access_token`` and ``refresh_token`` are optional. The row is tied to
    ``request.client`` and, when ``request.user`` is truthy, to that user's
    ``username``.
    """
    owner = request.user
    user_id = owner.username if owner else None

    client = request.client

    # HACK: convert to model field names
    lifetime = timedelta(seconds=token['expires_in'])
    fields = {
        'token_type': token['token_type'],
        'access_token': token.get('access_token'),
        'refresh_token': token.get('refresh_token'),
        'expires': datetime.utcnow() + lifetime,
        '_scopes': token['scope'],
    }

    record = Token(client_id=client.client_id, user=user_id, **fields)
    session = db.session
    session.add(record)
    session.commit()
|
|
|
|
|
|
# Module-level OAuth2 server; it is bound to the Flask app in config_oauth().
authorization = AuthorizationServer(query_client=query_client, save_token=save_token)

# Resource protector; config_oauth() registers the bearer token validator on it.
require_oauth = ResourceProtector()
|
|
|
|
|
|
class _BearerTokenValidator(BearerTokenValidator):
    """Bearer token validator that looks tokens up in the ``Token`` table."""

    def authenticate_token(self, token_string):
        """Return the ``Token`` row whose access token matches, or None."""
        matches = Token.query.filter_by(access_token=token_string)
        return matches.first()

    def request_invalid(self, request):
        """Always report the request as valid (no extra request checks)."""
        return False

    def token_revoked(self, token):
        """Always report the token as not revoked."""
        return False
|
|
|
|
|
|
class _RevocationEndpoint(RevocationEndpoint):
    """Token revocation endpoint backed by the ``Token`` table."""

    def query_token(self, token, token_type_hint, client):
        """Find the ``Token`` row of ``client`` matching the token string.

        With a ``token_type_hint`` of ``'access_token'`` or
        ``'refresh_token'`` only that column is searched. With any other
        hint the access-token column is tried first, then the refresh-token
        column. Returns None when nothing matches.
        """
        q = Token.query.filter_by(client_id=client.client_id)

        if token_type_hint == 'access_token':
            return q.filter_by(access_token=token).first()
        elif token_type_hint == 'refresh_token':
            return q.filter_by(refresh_token=token).first()

        # Fixed: the original rebound ``token`` to the query result and
        # returned it only when it was None, so an access-token hit was never
        # returned and the refresh-token lookup ran against a Token object
        # instead of the token string.
        item = q.filter_by(access_token=token).first()
        if item is not None:
            return item

        return q.filter_by(refresh_token=token).first()

    def revoke_token(self, token):
        """Revoke ``token`` by calling its ``delete()`` method."""
        token.delete()
|
|
|
|
|
|
def config_oauth(app):
    """Wire the OAuth2 server and resource protector into the Flask ``app``."""
    authorization.init_app(app)

    # support all grants (registered in the order listed)
    supported_grants = (
        grants.ImplicitGrant,
        grants.ClientCredentialsGrant,
        AuthorizationCodeGrant,
        PasswordGrant,
        RefreshTokenGrant,
    )
    for grant_cls in supported_grants:
        authorization.register_grant(grant_cls)

    # support revocation
    authorization.register_endpoint(_RevocationEndpoint)

    # protect resource
    validator = _BearerTokenValidator()
    require_oauth.register_token_validator(validator)
|