diff --git a/migrations/versions/dd58bc95a904_add_client_owner_id.py b/migrations/versions/dd58bc95a904_add_client_owner_id.py index 4382435..74df6ec 100644 --- a/migrations/versions/dd58bc95a904_add_client_owner_id.py +++ b/migrations/versions/dd58bc95a904_add_client_owner_id.py @@ -10,19 +10,21 @@ import sqlalchemy as sa # revision identifiers, used by Alembic. -revision = 'dd58bc95a904' -down_revision = '5d43eb9bfe78' +revision = "dd58bc95a904" +down_revision = "5d43eb9bfe78" branch_labels = None depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! ### - op.add_column('oauth2_client', sa.Column('owner_id', sa.String(length=40), nullable=True)) + op.add_column( + "oauth2_client", sa.Column("owner_id", sa.String(length=40), nullable=True) + ) # ### end Alembic commands ### def downgrade(): # ### commands auto generated by Alembic - please adjust! ### - op.drop_column('oauth2_client', 'owner_id') + op.drop_column("oauth2_client", "owner_id") # ### end Alembic commands ### diff --git a/sso/__init__.py b/sso/__init__.py index 78ba4fb..4206b9c 100644 --- a/sso/__init__.py +++ b/sso/__init__.py @@ -33,7 +33,7 @@ def create_app(): app.config.get("PROXYFIX_NUM_PROXIES"), ) - if app.config.get('LOGGING_LEVEL'): - logging.basicConfig(level=app.config['LOGGING_LEVEL']) + if app.config.get("LOGGING_LEVEL"): + logging.basicConfig(level=app.config["LOGGING_LEVEL"]) return app diff --git a/sso/directory.py b/sso/directory.py index bd0d276..92ceb91 100644 --- a/sso/directory.py +++ b/sso/directory.py @@ -14,7 +14,7 @@ def connect_to_ldap(): def check_credentials(username, password): - if app.config.get("TESTING") == True: + if app.config.get("TESTING"): return True conn = ldap.initialize(app.config["LDAP_URL"]) @@ -31,7 +31,7 @@ class LDAPUserProxy(object): self.is_authenticated = True self.is_anonymous = False - if app.config.get("TESTING") == True: + if app.config.get("TESTING"): self.gecos = "Testing User" self.mifare_hashes = [] self.phone = "123456789" @@ -44,8 +44,10 @@ class LDAPUserProxy(object): ldap.SCOPE_SUBTREE, app.config["LDAP_UID_FILTER"] % self.username, ) + if len(res) != 1: raise Exception("No such username.") + dn, data = res[0] self.username = data.get("uid", [b""])[0].decode() or None diff --git a/sso/forms.py b/sso/forms.py index 1c320f0..499d13a 100644 --- a/sso/forms.py +++ b/sso/forms.py @@ -45,7 +45,10 @@ class ClientForm(FlaskForm): client_name = StringField("Client name", validators=[DataRequired()]) client_uri = StringField("Client URI", validators=[DataRequired(), URL()]) redirect_uris = FieldList( - StringField("Redirect URI", validators=[DataRequired(), URL(require_tld=False)]), min_entries=1 + StringField( + "Redirect URI", validators=[DataRequired(), URL(require_tld=False)] + ), + min_entries=1, ) grant_types = MultiCheckboxField( "Grant types", @@ -62,9 +65,13 @@ class ClientForm(FlaskForm): token_endpoint_auth_method = RadioField( "Token endpoint authentication method", - choices=[("client_secret_basic", "Basic"), ("client_secret_post", "POST"), ("client_secret_get", "Query args (DEPRECATED)")], + choices=[ + ("client_secret_basic", "Basic"), + ("client_secret_post", "POST"), + ("client_secret_get", "Query args (DEPRECATED)"), + ], validators=[DataRequired()], - default='client_secret_post', + default="client_secret_post", ) scope = MultiCheckboxField( diff --git a/sso/models.py b/sso/models.py index f62c957..9c9e8fa 100644 --- a/sso/models.py +++ b/sso/models.py @@ -33,7 +33,9 @@ class Client(db.Model, OAuth2ClientMixin): def revoke_tokens(self): """Revoke all active access/refresh tokens and authorization codes""" Token.query.filter(Token.client_id == self.client_id).delete() - AuthorizationCode.query.filter(AuthorizationCode.client_id == self.client_id).delete() + AuthorizationCode.query.filter( + AuthorizationCode.client_id == self.client_id + ).delete() class AuthorizationCode(db.Model, OAuth2AuthorizationCodeMixin): diff --git a/sso/oauth2.py b/sso/oauth2.py index cfad0c8..f2660ab 100644 --- a/sso/oauth2.py +++ b/sso/oauth2.py @@ -125,24 +125,20 @@ def _validate_client(query_client, client_id, state=None, status_code=400): return client + def authenticate_client_secret_get(query_client, request): """Authenticates clients providing their secret via query args (either via GET or POST) request""" data = request.args - client_id = data.get('client_id') - client_secret = data.get('client_secret') + client_id = data.get("client_id") + client_secret = data.get("client_secret") if client_id and client_secret: client = _validate_client(query_client, client_id, request.state) - if client.check_token_endpoint_auth_method('client_secret_get') \ - and client.check_client_secret(client_secret): - log.debug( - 'Authenticate %s via "client_secret_get" ' - 'success', client_id - ) + if client.check_token_endpoint_auth_method( + "client_secret_get" + ) and client.check_client_secret(client_secret): + log.debug('Authenticate %s via "client_secret_get" ' "success", client_id) return client - log.debug( - 'Authenticate %s via "client_secret_get" ' - 'failed', client_id - ) + log.debug('Authenticate %s via "client_secret_get" ' "failed", client_id) def save_token(token, request): @@ -167,9 +163,12 @@ def save_token(token, request): class CustomAuthorizationCodeGrant(AuthorizationCodeGrant): # kill me (inventory) - TOKEN_ENDPOINT_HTTP_METHODS = ['GET', 'POST'] + TOKEN_ENDPOINT_HTTP_METHODS = ["GET", "POST"] TOKEN_ENDPOINT_AUTH_METHODS = [ - 'client_secret_basic', 'client_secret_post', 'client_secret_get', 'none' + "client_secret_basic", + "client_secret_post", + "client_secret_get", + "none", ] def validate_token_request(self): @@ -178,15 +177,20 @@ class CustomAuthorizationCodeGrant(AuthorizationCodeGrant): return super(CustomAuthorizationCodeGrant, self).validate_token_request() + class CustomResourceProtector(ResourceProtector): - def validate_request(self, scope, request, scope_operator='AND'): + def validate_request(self, scope, request, scope_operator="AND"): # damn you gerrit args = dict(url_decode(urlparse.urlparse(request.uri).query)) - if args.get('access_token'): - token_string = args.get('access_token') - return self._token_validators['bearer'](token_string, scope, request, scope_operator) + if args.get("access_token"): + token_string = args.get("access_token") + return self._token_validators["bearer"]( + token_string, scope, request, scope_operator + ) - return super(CustomResourceProtector, self).validate_request(scope, request, scope_operator) + return super(CustomResourceProtector, self).validate_request( + scope, request, scope_operator + ) authorization = AuthorizationServer() @@ -196,7 +200,9 @@ require_oauth = CustomResourceProtector() def config_oauth(app): query_client = create_query_client_func(db.session, Client) authorization.init_app(app, query_client=query_client, save_token=save_token) - authorization.register_client_auth_method('client_secret_get', authenticate_client_secret_get) + authorization.register_client_auth_method( + "client_secret_get", authenticate_client_secret_get + ) # support all openid grants authorization.register_grant( diff --git a/sso/settings.py b/sso/settings.py index 8477de0..9f157db 100644 --- a/sso/settings.py +++ b/sso/settings.py @@ -39,8 +39,8 @@ LDAP_BIND_DN = env.str( ) LDAP_BIND_PASSWORD = env.str("LDAP_BIND_PASSWORD", default="insert password here") -PROXYFIX_ENABLE = env.bool('PROXYFIX_ENABLE', default=True) -PROXYFIX_NUM_PROXIES = env.int('PROXYFIX_NUM_PROXIES', default=1) +PROXYFIX_ENABLE = env.bool("PROXYFIX_ENABLE", default=True) +PROXYFIX_NUM_PROXIES = env.int("PROXYFIX_NUM_PROXIES", default=1) JWT_CONFIG = { "key": env.str("JWT_SECRET_KEY", default=SECRET_KEY), @@ -49,4 +49,4 @@ JWT_CONFIG = { "exp": env.int("JWT_EXP", default=3600), } -LOGGING_LEVEL = env.str('LOGGING_LEVEL', default=None) +LOGGING_LEVEL = env.str("LOGGING_LEVEL", default=None) diff --git a/sso/views.py b/sso/views.py index fd34a63..ce2e82f 100644 --- a/sso/views.py +++ b/sso/views.py @@ -87,7 +87,7 @@ def client_create(): db.session.add(client) db.session.commit() - flash('Client has been created.', 'success') + flash("Client has been created.", "success") return redirect(url_for(".client_edit", client_id=client.id)) return render_template("client_edit.html", form=form) @@ -105,7 +105,7 @@ def client_edit(client_id): if form.validate_on_submit(): client.set_client_metadata(form.data) db.session.commit() - flash('Client has been changed.', 'success') + flash("Client has been changed.", "success") return redirect(url_for(".client_edit", client_id=client.id)) return render_template("client_edit.html", client=client, form=form) @@ -117,12 +117,12 @@ def client_destroy(client_id): Client, Client.id == client_id, Client.owner_id == current_user.get_user_id() ) - if request.method == 'POST': + if request.method == "POST": db.session.delete(client) client.revoke_tokens() db.session.commit() - flash('Client destroyed.', 'success') - return redirect(url_for('.profile')) + flash("Client destroyed.", "success") + return redirect(url_for(".profile")) return render_template("confirm_destroy.html", client=client) @@ -133,16 +133,16 @@ def client_regenerate_secret(client_id): Client, Client.id == client_id, Client.owner_id == current_user.get_user_id() ) - if request.method == 'POST': + if request.method == "POST": print(request.form) client.client_secret = generate_token() - if request.form.get('revoke') == 'yes': + if request.form.get("revoke") == "yes": client.revoke_tokens() db.session.commit() - flash('Client secret regenerated.', 'success') - return redirect(url_for('.client_edit', client_id=client.id)) + flash("Client secret regenerated.", "success") + return redirect(url_for(".client_edit", client_id=client.id)) return render_template("confirm_regenerate.html", client=client) @@ -166,8 +166,11 @@ def authorize(): return authorization.create_authorization_response(grant_user=current_user) return render_template( - "oauthorize.html", user=current_user, grant=grant, client=grant.client, - scopes=grant.request.scope.split() + "oauthorize.html", + user=current_user, + grant=grant, + client=grant.client, + scopes=grant.request.scope.split(), ) if request.form["confirm"]: