*: blacken

This commit is contained in:
q3k 2024-07-07 21:38:17 +02:00
parent f4c6620007
commit 1f51fbea85
17 changed files with 468 additions and 293 deletions

View file

@ -3,7 +3,5 @@ import wtforms
from typing import Protocol from typing import Protocol
class FlaskForm(wtforms.Form): class FlaskForm(wtforms.Form):
def validate_on_submit(self) -> bool: def validate_on_submit(self) -> bool: ...
... def is_submitted(self) -> bool: ...
def is_submitted(self) -> bool:
...

View file

@ -1,2 +1 @@
def changePassword(principal: str, current: str, new: str) -> bool: def changePassword(principal: str, current: str, new: str) -> bool: ...
...

View file

@ -1,21 +1,19 @@
from enum import Enum from enum import Enum
from typing import Protocol, List, Tuple, Dict, Optional from typing import Protocol, List, Tuple, Dict, Optional
class Scope(Enum): class Scope(Enum):
BASE = 'base' BASE = "base"
ONELEVEL = 'onelevel' ONELEVEL = "onelevel"
SUBTREE = 'subtree' SUBTREE = "subtree"
SCOPE_BASE = Scope.BASE SCOPE_BASE = Scope.BASE
SCOPE_ONELEVEL = Scope.ONELEVEL SCOPE_ONELEVEL = Scope.ONELEVEL
SCOPE_SUBTREE = Scope.SUBTREE SCOPE_SUBTREE = Scope.SUBTREE
class Mod(Enum): class Mod(Enum):
ADD = 'add' ADD = "add"
DELETE = 'delete' DELETE = "delete"
REPLACE = 'replace' REPLACE = "replace"
MOD_ADD = Mod.ADD MOD_ADD = Mod.ADD
MOD_DELETE = Mod.DELETE MOD_DELETE = Mod.DELETE
@ -23,21 +21,20 @@ MOD_REPLACE = Mod.REPLACE
ModEntry = Tuple[Mod, str, bytes | List[bytes]] ModEntry = Tuple[Mod, str, bytes | List[bytes]]
class LDAPError(Exception): class LDAPError(Exception): ...
... class NO_SUCH_OBJECT(LDAPError): ...
class NO_SUCH_OBJECT(LDAPError):
...
class ldapobject(Protocol): class ldapobject(Protocol):
def start_tls_s(self) -> None: def start_tls_s(self) -> None: ...
... def simple_bind_s(self, dn: str, password: str) -> None: ...
def simple_bind_s(self, dn: str, password: str) -> None: def search_s(
... self,
def search_s(self, base: str, scope: Scope, filterstr: str = '(objectClass=*)', attrlist: Optional[List[str]] = None, attrsonly: int = 0) -> List[Tuple[str, Dict[str, List[bytes]]]]: base: str,
... scope: Scope,
def modify_s(self, dn: str, modlist: List[ModEntry]) -> None: filterstr: str = "(objectClass=*)",
... attrlist: Optional[List[str]] = None,
attrsonly: int = 0,
) -> List[Tuple[str, Dict[str, List[bytes]]]]: ...
def modify_s(self, dn: str, modlist: List[ModEntry]) -> None: ...
def initialize(url: str) -> ldapobject: def initialize(url: str) -> ldapobject: ...
...

View file

@ -17,32 +17,36 @@ class App(flask.Flask):
super().__init__(name) super().__init__(name)
if hasattr(config, "secret_key"): if hasattr(config, "secret_key"):
self.secret_key = config.secret_key self.secret_key = config.secret_key
if hasattr(config, "debug"): if hasattr(config, "debug"):
self.debug = config.debug self.debug = config.debug
def get_dn(self) -> Optional[str]: def get_dn(self) -> Optional[str]:
return flask.session.get('dn') return flask.session.get("dn")
def get_connection(self, dn: Optional[str] = None) -> Optional[ldap.ldapobject]: def get_connection(self, dn: Optional[str] = None) -> Optional[ldap.ldapobject]:
dn = dn or self.get_dn() dn = dn or self.get_dn()
if dn is None: if dn is None:
return None return None
return self.connections[dn] return self.connections[dn]
def get_admin_connection(self) -> ldap.ldapobject: def get_admin_connection(self) -> ldap.ldapobject:
conn = self.connections[config.ldap_admin_dn] conn = self.connections[config.ldap_admin_dn]
if not conn: if not conn:
conn = self.connections.bind(config.ldap_admin_dn, config.ldap_admin_password) conn = self.connections.bind(
config.ldap_admin_dn, config.ldap_admin_password
)
return conn return conn
def get_profile(self, dn: Optional[str] = None) -> Optional[context.Profile]: def get_profile(self, dn: Optional[str] = None) -> Optional[context.Profile]:
dn = dn or self.get_dn() dn = dn or self.get_dn()
if dn is None: if dn is None:
return None return None
return self.profiles.get(dn) return self.profiles.get(dn)
def refresh_profile(self, conn: ldap.ldapobject, dn: Optional[str] = None) -> Optional[context.Profile]: def refresh_profile(
self, conn: ldap.ldapobject, dn: Optional[str] = None
) -> Optional[context.Profile]:
dn = dn or self.get_dn() dn = dn or self.get_dn()
if dn is None: if dn is None:
return None return None
@ -65,12 +69,12 @@ def inject_hackerspace_name() -> Dict[str, Any]:
return dict(hackerspace_name=config.hackerspace_name) return dict(hackerspace_name=config.hackerspace_name)
@app.template_filter('first') @app.template_filter("first")
def ldap_first(v: str) -> str: def ldap_first(v: str) -> str:
return v and v[0] return v and v[0]
@app.template_filter('readable') @app.template_filter("readable")
def readable_tf(n: str) -> str: def readable_tf(n: str) -> str:
return config.readable_names.get(n, n) return config.readable_names.get(n, n)
@ -81,15 +85,19 @@ def start() -> None:
from webapp import views from webapp import views
from webapp import auth, admin, avatar, vcard, passwd from webapp import auth, admin, avatar, vcard, passwd
for module in (auth, admin, avatar, vcard, passwd): for module in (auth, admin, avatar, vcard, passwd):
app.register_blueprint(module.bp) app.register_blueprint(module.bp)
app.connections = pools.LDAPConnectionPool(config.ldap_url, timeout=300.0) app.connections = pools.LDAPConnectionPool(config.ldap_url, timeout=300.0)
def drop_profile(dn: str) -> None: def drop_profile(dn: str) -> None:
if dn != config.ldap_admin_dn: if dn != config.ldap_admin_dn:
del app.profiles[dn] del app.profiles[dn]
app.connections.register_callback('drop', drop_profile)
app.connections.register_callback("drop", drop_profile)
app.connections.start() app.connections.start()
app.profiles = {} app.profiles = {}
start() start()

View file

@ -11,88 +11,107 @@ from webapp import app, context, config, ldaputils, email, vcard
from typing import Callable, ParamSpec, List, Tuple, Optional, Dict, Protocol from typing import Callable, ParamSpec, List, Tuple, Optional, Dict, Protocol
bp = flask.Blueprint('admin', __name__) bp = flask.Blueprint("admin", __name__)
Entry = Tuple[str, List[Tuple[str, str]]] Entry = Tuple[str, List[Tuple[str, str]]]
P = ParamSpec('P') P = ParamSpec("P")
def admin_required_impl(f: Callable[P, werkzeug.Response]) -> Callable[P, werkzeug.Response]:
def admin_required_impl(
f: Callable[P, werkzeug.Response]
) -> Callable[P, werkzeug.Response]:
@functools.wraps(f) @functools.wraps(f)
def func(*a: P.args, **kw: P.kwargs) -> werkzeug.Response: def func(*a: P.args, **kw: P.kwargs) -> werkzeug.Response:
# TODO: Actually check for admin perms # TODO: Actually check for admin perms
if not flask.session['is_admin']: if not flask.session["is_admin"]:
flask.abort(403) flask.abort(403)
return f(*a, **kw) return f(*a, **kw)
return func return func
def admin_required(f: Callable[P, werkzeug.Response]) -> Callable[P, werkzeug.Response]: def admin_required(f: Callable[P, werkzeug.Response]) -> Callable[P, werkzeug.Response]:
return webapp.auth.login_required(admin_required_impl(f)) return webapp.auth.login_required(admin_required_impl(f))
def _get_user_list(conn: ldap.ldapobject, query: str = '&') -> List[Tuple[str, str]]:
def _get_user_list(conn: ldap.ldapobject, query: str = "&") -> List[Tuple[str, str]]:
""" """
Returns List[Tuple[username, full name]] for query. Returns List[Tuple[username, full name]] for query.
""" """
all_users = [] all_users = []
results = conn.search_s(config.ldap_people, ldap.SCOPE_SUBTREE, f'(&(uid=*)(cn=*){ldaputils.wrap(query)})', attrlist=['uid', 'cn']) results = conn.search_s(
config.ldap_people,
ldap.SCOPE_SUBTREE,
f"(&(uid=*)(cn=*){ldaputils.wrap(query)})",
attrlist=["uid", "cn"],
)
for user, attrs in results: for user, attrs in results:
user_uid = attrs['uid'][0].decode() user_uid = attrs["uid"][0].decode()
user_cn = attrs['cn'][0].decode() user_cn = attrs["cn"][0].decode()
all_users.append((user_uid, user_cn)) all_users.append((user_uid, user_cn))
all_users.sort(key=lambda user: user[0].lower()) all_users.sort(key=lambda user: user[0].lower())
return all_users return all_users
def _get_groupped_user_list(conn: ldap.ldapobject) -> List[Tuple[str, List[Tuple[str, str]]]]:
def _get_groupped_user_list(
conn: ldap.ldapobject,
) -> List[Tuple[str, List[Tuple[str, str]]]]:
""" """
Returns all users (uid, full name), groupped by active groups. Returns all users (uid, full name), groupped by active groups.
""" """
groupped_users = [ groupped_users = [
(group.capitalize(), _get_user_list(conn, f'memberOf={ldaputils.group_dn(group)}')) (
group.capitalize(),
_get_user_list(conn, f"memberOf={ldaputils.group_dn(group)}"),
)
for group in config.ldap_active_groups for group in config.ldap_active_groups
] ]
inactive_filter = ldaputils._not( inactive_filter = ldaputils._not(ldaputils.member_of_any(config.ldap_active_groups))
ldaputils.member_of_any(config.ldap_active_groups)
)
groupped_users.append( groupped_users.append(("Inactive users", _get_user_list(conn, inactive_filter)))
('Inactive users', _get_user_list(conn, inactive_filter))
)
return groupped_users return groupped_users
def _get_groups_of(conn: ldap.ldapobject, dn: str) -> List[str]: def _get_groups_of(conn: ldap.ldapobject, dn: str) -> List[str]:
filter =f'(&(objectClass=groupOfUniqueNames)(uniqueMember={dn}))' filter = f"(&(objectClass=groupOfUniqueNames)(uniqueMember={dn}))"
groups = [ groups = [
attrs['cn'][0].decode() attrs["cn"][0].decode()
for group_dn, attrs in for group_dn, attrs in conn.search_s(
conn.search_s(config.ldap_base, ldap.SCOPE_SUBTREE, filter) config.ldap_base, ldap.SCOPE_SUBTREE, filter
)
] ]
return groups return groups
def _is_user_protected(conn: ldap.ldapobject, groups: List[str]) -> bool: def _is_user_protected(conn: ldap.ldapobject, groups: List[str]) -> bool:
return any(group in config.ldap_protected_groups for group in groups) return any(group in config.ldap_protected_groups for group in groups)
@bp.route('/admin/')
@bp.route("/admin/")
@admin_required @admin_required
def admin_view() -> werkzeug.Response: def admin_view() -> werkzeug.Response:
return flask.Response(flask.render_template('admin/index.html')) return flask.Response(flask.render_template("admin/index.html"))
@bp.route('/admin/users/')
@bp.route("/admin/users/")
@admin_required @admin_required
def admin_users_view() -> werkzeug.Response: def admin_users_view() -> werkzeug.Response:
conn = app.get_connection() conn = app.get_connection()
assert conn is not None assert conn is not None
groups = _get_groupped_user_list(conn) groups = _get_groupped_user_list(conn)
return flask.Response(flask.render_template('admin/users.html', groups=groups)) return flask.Response(flask.render_template("admin/users.html", groups=groups))
@bp.route('/admin/users/<username>')
@bp.route("/admin/users/<username>")
@admin_required @admin_required
def admin_user_view(username: str) -> werkzeug.Response: def admin_user_view(username: str) -> werkzeug.Response:
ldaputils.validate_name(username) ldaputils.validate_name(username)
@ -105,7 +124,11 @@ def admin_user_view(username: str) -> werkzeug.Response:
groups = _get_groups_of(conn, dn) groups = _get_groups_of(conn, dn)
is_protected = _is_user_protected(conn, groups) is_protected = _is_user_protected(conn, groups)
return flask.Response(flask.render_template('admin/user.html', profile=profile, groups=groups, is_protected=is_protected)) return flask.Response(
flask.render_template(
"admin/user.html", profile=profile, groups=groups, is_protected=is_protected
)
)
class AdminMixin: class AdminMixin:
@ -121,32 +144,42 @@ class AdminMixin:
class AdminOperationAdd(AdminMixin, vcard.OperationAdd): class AdminOperationAdd(AdminMixin, vcard.OperationAdd):
pass pass
class AdminOperationModify(AdminMixin, vcard.OperationModify): class AdminOperationModify(AdminMixin, vcard.OperationModify):
pass pass
class AdminOperationDelete(AdminMixin, vcard.OperationDelete): class AdminOperationDelete(AdminMixin, vcard.OperationDelete):
pass pass
@bp.route('/admin/users/<username>/add_mifareidhash', methods=["GET", "POST"]) @bp.route("/admin/users/<username>/add_mifareidhash", methods=["GET", "POST"])
@admin_required @admin_required
def admin_user_view_add_mifareidhash(username: str) -> werkzeug.Response: def admin_user_view_add_mifareidhash(username: str) -> werkzeug.Response:
dn = ldaputils.user_dn(username) dn = ldaputils.user_dn(username)
op = AdminOperationAdd(dn, "mifareidhash") op = AdminOperationAdd(dn, "mifareidhash")
redirect = f'/admin/users/{username}' redirect = f"/admin/users/{username}"
return op.perform(success_redirect=redirect, fatal_redirect=redirect, action=f'/admin/users/{username}/add_mifareidhash') return op.perform(
success_redirect=redirect,
fatal_redirect=redirect,
action=f"/admin/users/{username}/add_mifareidhash",
)
@bp.route('/admin/users/<username>/del_mifareidhash/<uid>', methods=["GET", "POST"]) @bp.route("/admin/users/<username>/del_mifareidhash/<uid>", methods=["GET", "POST"])
@admin_required @admin_required
def admin_user_view_del_mifareidhash(username: str, uid: str) -> werkzeug.Response: def admin_user_view_del_mifareidhash(username: str, uid: str) -> werkzeug.Response:
dn = ldaputils.user_dn(username) dn = ldaputils.user_dn(username)
op = AdminOperationDelete(dn, uid) op = AdminOperationDelete(dn, uid)
redirect = f'/admin/users/{username}' redirect = f"/admin/users/{username}"
return op.perform(success_redirect=redirect, fatal_redirect=redirect, action=f'/admin/users/{username}/del_mifareidhash/{uid}') return op.perform(
success_redirect=redirect,
fatal_redirect=redirect,
action=f"/admin/users/{username}/del_mifareidhash/{uid}",
)
@bp.route('/admin/groups/') @bp.route("/admin/groups/")
@admin_required @admin_required
def admin_groups_view() -> werkzeug.Response: def admin_groups_view() -> werkzeug.Response:
conn = app.get_connection() conn = app.get_connection()
@ -159,16 +192,20 @@ def admin_groups_view() -> werkzeug.Response:
all_uids = set([uid for uid, cn in all_users]) all_uids = set([uid for uid, cn in all_users])
groups = [ groups = [
attrs['cn'][0].decode() attrs["cn"][0].decode()
for group_dn, attrs in for group_dn, attrs in conn.search_s(
conn.search_s(config.ldap_base, ldap.SCOPE_SUBTREE, 'objectClass=groupOfUniqueNames') config.ldap_base, ldap.SCOPE_SUBTREE, "objectClass=groupOfUniqueNames"
)
] ]
filter_groups = filter((lambda cn: cn not in all_uids), groups) filter_groups = filter((lambda cn: cn not in all_uids), groups)
return flask.Response(flask.render_template('admin/groups.html', groups=filter_groups)) return flask.Response(
flask.render_template("admin/groups.html", groups=filter_groups)
)
@bp.route('/admin/groups/<name>')
@bp.route("/admin/groups/<name>")
@admin_required @admin_required
def admin_group_view(name: str) -> werkzeug.Response: def admin_group_view(name: str) -> werkzeug.Response:
ldaputils.validate_name(name) ldaputils.validate_name(name)
@ -177,8 +214,12 @@ def admin_group_view(name: str) -> werkzeug.Response:
assert conn is not None assert conn is not None
group = context.LDAPEntry(conn, dn) group = context.LDAPEntry(conn, dn)
members = _get_user_list(conn, f'memberOf={ldaputils.group_dn(name)}') members = _get_user_list(conn, f"memberOf={ldaputils.group_dn(name)}")
is_protected = name in config.ldap_protected_groups is_protected = name in config.ldap_protected_groups
return flask.Response(flask.render_template('admin/group.html', group=group, members=members, is_protected=is_protected)) return flask.Response(
flask.render_template(
"admin/group.html", group=group, members=members, is_protected=is_protected
)
)

View file

@ -8,9 +8,10 @@ from webapp import app, avatar, config, ldaputils
from typing import TypeVar, Callable, ParamSpec, Dict, Any, Optional from typing import TypeVar, Callable, ParamSpec, Dict, Any, Optional
bp = flask.Blueprint('auth', __name__) bp = flask.Blueprint("auth", __name__)
P = ParamSpec("P")
P = ParamSpec('P')
def login_required(f: Callable[P, werkzeug.Response]) -> Callable[P, werkzeug.Response]: def login_required(f: Callable[P, werkzeug.Response]) -> Callable[P, werkzeug.Response]:
@functools.wraps(f) @functools.wraps(f)
@ -18,17 +19,23 @@ def login_required(f: Callable[P, werkzeug.Response]) -> Callable[P, werkzeug.Re
conn = app.get_connection() conn = app.get_connection()
if not conn: if not conn:
flask.session.clear() flask.session.clear()
flask.flash('You must log in to continue', category='warning') flask.flash("You must log in to continue", category="warning")
return flask.redirect('/login?' + urllib.parse.urlencode({'goto': flask.request.path})) return flask.redirect(
"/login?" + urllib.parse.urlencode({"goto": flask.request.path})
)
return f(*a, **kw) return f(*a, **kw)
return func return func
def req_to_ctx() -> Dict[str, Any]: def req_to_ctx() -> Dict[str, Any]:
return dict(flask.request.form.items()) return dict(flask.request.form.items())
@bp.route('/login', methods=["GET"])
@bp.route("/login", methods=["GET"])
def login_form() -> werkzeug.Response: def login_form() -> werkzeug.Response:
return flask.Response(flask.render_template('login.html', **req_to_ctx())) return flask.Response(flask.render_template("login.html", **req_to_ctx()))
def _connect_to_ldap(dn: str, password: str) -> Optional[ldap.ldapobject]: def _connect_to_ldap(dn: str, password: str) -> Optional[ldap.ldapobject]:
try: try:
@ -37,7 +44,8 @@ def _connect_to_ldap(dn: str, password: str) -> Optional[ldap.ldapobject]:
print("Could not connect to server:", error_message) print("Could not connect to server:", error_message)
return None return None
@bp.route('/login', methods=["POST"])
@bp.route("/login", methods=["POST"])
def login_action() -> werkzeug.Response: def login_action() -> werkzeug.Response:
# LDAP usernames/DNs are case-insensitive, so we normalize them just in # LDAP usernames/DNs are case-insensitive, so we normalize them just in
# case, # case,
@ -51,27 +59,34 @@ def login_action() -> werkzeug.Response:
# Now that we have logged in, we can retrieve the 'real' username (which # Now that we have logged in, we can retrieve the 'real' username (which
# might be cased differently from the login name). # might be cased differently from the login name).
res = conn.search_s(dn, ldap.SCOPE_SUBTREE) res = conn.search_s(dn, ldap.SCOPE_SUBTREE)
for (k, vs) in res[0][1].items(): for k, vs in res[0][1].items():
if k == 'uid': if k == "uid":
username = vs[0].decode() username = vs[0].decode()
# Check if user belongs to admin group # Check if user belongs to admin group
is_admin = bool(conn.search_s(dn, ldap.SCOPE_SUBTREE, ldaputils.member_of_any(config.ldap_admin_groups))) is_admin = bool(
conn.search_s(
dn,
ldap.SCOPE_SUBTREE,
ldaputils.member_of_any(config.ldap_admin_groups),
)
)
flask.session["username"] = username flask.session["username"] = username
flask.session['dn'] = dn flask.session["dn"] = dn
flask.session['is_admin'] = is_admin flask.session["is_admin"] = is_admin
app.refresh_profile(conn) app.refresh_profile(conn)
avatar.cache.reset_user(username) avatar.cache.reset_user(username)
avatar.hash_cache.reset() avatar.hash_cache.reset()
return flask.redirect(goto) return flask.redirect(goto)
else: else:
flask.flash("Invalid credentials.", category='danger') flask.flash("Invalid credentials.", category="danger")
return login_form() return login_form()
@bp.route('/logout')
@bp.route("/logout")
@login_required @login_required
def logout_action() -> werkzeug.Response: def logout_action() -> werkzeug.Response:
app.connections.unbind(flask.session['dn']) app.connections.unbind(flask.session["dn"])
flask.session.clear() flask.session.clear()
return flask.redirect('/') return flask.redirect("/")

View file

@ -21,8 +21,8 @@ from webapp import app, ldaputils, config
from typing import List from typing import List
bp = flask.Blueprint('avatar', __name__) bp = flask.Blueprint("avatar", __name__)
log = logging.getLogger('ldap-web.avatar') log = logging.getLogger("ldap-web.avatar")
# Stolen from https://stackoverflow.com/questions/43512615/reshaping-rectangular-image-to-square # Stolen from https://stackoverflow.com/questions/43512615/reshaping-rectangular-image-to-square
@ -46,14 +46,22 @@ def resize_image(image: Image.Image, length: int) -> Image.Image:
# The image is in portrait mode. Height is bigger than width. # The image is in portrait mode. Height is bigger than width.
# This makes the width fit the LENGTH in pixels while conserving the ration. # This makes the width fit the LENGTH in pixels while conserving the ration.
resized_image = image.resize((length, int(image.size[1] * (length / image.size[0])))) resized_image = image.resize(
(length, int(image.size[1] * (length / image.size[0])))
)
# Amount of pixel to lose in total on the height of the image. # Amount of pixel to lose in total on the height of the image.
required_loss = (resized_image.size[1] - length) required_loss = resized_image.size[1] - length
# Crop the height of the image so as to keep the center part. # Crop the height of the image so as to keep the center part.
resized_image = resized_image.crop( resized_image = resized_image.crop(
box=(0, required_loss // 2, length, resized_image.size[1] - required_loss // 2)) box=(
0,
required_loss // 2,
length,
resized_image.size[1] - required_loss // 2,
)
)
# We now have a length*length pixels image. # We now have a length*length pixels image.
return resized_image return resized_image
@ -61,14 +69,22 @@ def resize_image(image: Image.Image, length: int) -> Image.Image:
# This image is in landscape mode or already squared. The width is bigger than the heihgt. # This image is in landscape mode or already squared. The width is bigger than the heihgt.
# This makes the height fit the LENGTH in pixels while conserving the ration. # This makes the height fit the LENGTH in pixels while conserving the ration.
resized_image = image.resize((int(image.size[0] * (length / image.size[1])), length)) resized_image = image.resize(
(int(image.size[0] * (length / image.size[1])), length)
)
# Amount of pixel to lose in total on the width of the image. # Amount of pixel to lose in total on the width of the image.
required_loss = resized_image.size[0] - length required_loss = resized_image.size[0] - length
# Crop the width of the image so as to keep 1080 pixels of the center part. # Crop the width of the image so as to keep 1080 pixels of the center part.
resized_image = resized_image.crop( resized_image = resized_image.crop(
box=(required_loss // 2, 0, resized_image.size[0] - required_loss // 2, length)) box=(
required_loss // 2,
0,
resized_image.size[0] - required_loss // 2,
length,
)
)
# We now have a length*length pixels image. # We now have a length*length pixels image.
return resized_image return resized_image
@ -78,7 +94,7 @@ def process_upload(data: bytes) -> bytes:
img = Image.open(io.BytesIO(data)) img = Image.open(io.BytesIO(data))
img = resize_image(img, 256) img = resize_image(img, 256)
res = io.BytesIO() res = io.BytesIO()
img.save(res, 'PNG') img.save(res, "PNG")
return base64.b64encode(res.getvalue()) return base64.b64encode(res.getvalue())
@ -90,18 +106,18 @@ def default_avatar(uid: str) -> bytes:
Create little generative avatar for people who don't have a custom one Create little generative avatar for people who don't have a custom one
configured. configured.
""" """
img = Image.new('RGBA', (256, 256), (255, 255, 255, 0)) img = Image.new("RGBA", (256, 256), (255, 255, 255, 0))
draw = ImageDraw.Draw(img) draw = ImageDraw.Draw(img)
# Deterministic rng for stable output. # Deterministic rng for stable output.
rng = random.Random(uid) rng = random.Random(uid)
# Pick a nice random neon color. # Pick a nice random neon color.
n_h, n_s, n_l = rng.random(), 0.5 + rng.random()/2.0, 0.4 + rng.random()/5.0 n_h, n_s, n_l = rng.random(), 0.5 + rng.random() / 2.0, 0.4 + rng.random() / 5.0
# Use muted version for background. # Use muted version for background.
r, g, b = [int(256*i) for i in colorsys.hls_to_rgb(n_h, n_l+0.3, n_s-0.1)] r, g, b = [int(256 * i) for i in colorsys.hls_to_rgb(n_h, n_l + 0.3, n_s - 0.1)]
draw.rectangle([(0, 0), (256, 256)], fill=(r,g,b,255)) draw.rectangle([(0, 0), (256, 256)], fill=(r, g, b, 255))
# Scale logo by randomized factor. # Scale logo by randomized factor.
factor = 0.7 + 0.1 * rng.random() factor = 0.7 + 0.1 * rng.random()
@ -112,20 +128,20 @@ def default_avatar(uid: str) -> bytes:
overlay = overlay.crop(box=(0, 0, w, w)) overlay = overlay.crop(box=(0, 0, w, w))
# Give it a little nudge. # Give it a little nudge.
overlay = overlay.rotate((rng.random() - 0.5) * 100) # type: ignore overlay = overlay.rotate((rng.random() - 0.5) * 100) # type: ignore
# Colorize with full neon color. # Colorize with full neon color.
r, g, b = [int(256*i) for i in colorsys.hls_to_rgb(n_h,n_l,n_s)] r, g, b = [int(256 * i) for i in colorsys.hls_to_rgb(n_h, n_l, n_s)]
pixels = overlay.load() # type: ignore pixels = overlay.load() # type: ignore
for x in range(img.size[0]): for x in range(img.size[0]):
for y in range(img.size[1]): for y in range(img.size[1]):
alpha = pixels[x, y][3] alpha = pixels[x, y][3]
pixels[x, y] = (r, g, b, alpha) pixels[x, y] = (r, g, b, alpha)
img.alpha_composite(overlay) # type: ignore img.alpha_composite(overlay) # type: ignore
res = io.BytesIO() res = io.BytesIO()
img.save(res, 'PNG') img.save(res, "PNG")
return res.getvalue() return res.getvalue()
@ -160,10 +176,10 @@ class AvatarCacheEntry:
res = io.BytesIO() res = io.BytesIO()
img = resize_image(img, 256) img = resize_image(img, 256)
img.save(res, 'PNG') img.save(res, "PNG")
self._converted = res.getvalue() self._converted = res.getvalue()
return flask.Response(self._converted, mimetype='image/png') return flask.Response(self._converted, mimetype="image/png")
class AvatarCache: class AvatarCache:
@ -210,10 +226,10 @@ class AvatarCache:
is_user_found = len(res) == 1 is_user_found = len(res) == 1
if is_user_found: if is_user_found:
for attr, vs in res[0][1].items(): for attr, vs in res[0][1].items():
if attr == 'jpegPhoto': if attr == "jpegPhoto":
for v in vs: for v in vs:
# Temporary workaround: treat empty jpegPhoto as no avatar. # Temporary workaround: treat empty jpegPhoto as no avatar.
if v == b'': if v == b"":
avatar = None avatar = None
break break
@ -230,7 +246,7 @@ class AvatarCache:
if avatar is None: if avatar is None:
# don't generate avatars for non-users to reduce DoS potential # don't generate avatars for non-users to reduce DoS potential
# (note: capacifier already leaks existence of users, so whatever) # (note: capacifier already leaks existence of users, so whatever)
avatar = default_avatar(uid if is_user_found else 'default') avatar = default_avatar(uid if is_user_found else "default")
# Save avatar in cache. # Save avatar in cache.
entry = AvatarCacheEntry(uid, avatar) entry = AvatarCacheEntry(uid, avatar)
@ -239,12 +255,13 @@ class AvatarCache:
# And serve the entry. # And serve the entry.
return entry.serve() return entry.serve()
cache = AvatarCache() cache = AvatarCache()
def hash_for_uid(uid: str) -> str: def hash_for_uid(uid: str) -> str:
# NOTE: Gravatar documentation says to use SHA256, but everyone passes MD5 instead # NOTE: Gravatar documentation says to use SHA256, but everyone passes MD5 instead
email = f'{uid}@hackerspace.pl'.strip().lower() email = f"{uid}@hackerspace.pl".strip().lower()
hasher = hashlib.md5() hasher = hashlib.md5()
hasher.update(email.encode()) hasher.update(email.encode())
return hasher.hexdigest() return hasher.hexdigest()
@ -253,9 +270,11 @@ def hash_for_uid(uid: str) -> str:
def get_all_user_uids(conn: ldap.ldapobject) -> List[str]: def get_all_user_uids(conn: ldap.ldapobject) -> List[str]:
all_uids = [] all_uids = []
results = conn.search_s(config.ldap_people, ldap.SCOPE_SUBTREE, 'uid=*', attrlist=['uid']) results = conn.search_s(
config.ldap_people, ldap.SCOPE_SUBTREE, "uid=*", attrlist=["uid"]
)
for user, attrs in results: for user, attrs in results:
uid = attrs['uid'][0].decode() uid = attrs["uid"][0].decode()
all_uids.append(uid) all_uids.append(uid)
return all_uids return all_uids
@ -269,7 +288,7 @@ class HashCache:
def get(self, email_hash: str) -> str: def get(self, email_hash: str) -> str:
self.rebuild_if_needed() self.rebuild_if_needed()
return self.entries.get(email_hash, 'default') return self.entries.get(email_hash, "default")
def reset(self) -> None: def reset(self) -> None:
self.entries = {} self.entries = {}
@ -285,7 +304,7 @@ class HashCache:
conn = app.get_admin_connection() conn = app.get_admin_connection()
users = get_all_user_uids(conn) users = get_all_user_uids(conn)
self.deadline = time.time() + config.avatar_cache_timeout self.deadline = time.time() + config.avatar_cache_timeout
self.entries = { hash_for_uid(uid): uid for uid in users } self.entries = {hash_for_uid(uid): uid for uid in users}
hash_cache = HashCache() hash_cache = HashCache()
@ -296,12 +315,12 @@ def sanitize_email_hash(hash: str) -> str:
lowercases, removes file extension (probably) lowercases, removes file extension (probably)
""" """
hash = hash.lower() hash = hash.lower()
if hash.endswith('.png') or hash.endswith('.jpg'): if hash.endswith(".png") or hash.endswith(".jpg"):
hash = hash[:-4] hash = hash[:-4]
return hash return hash
@bp.route('/avatar/<email_hash>', methods=['GET']) @bp.route("/avatar/<email_hash>", methods=["GET"])
def gravatar_serve(email_hash: str) -> flask.Response: def gravatar_serve(email_hash: str) -> flask.Response:
""" """
Serves avatar in a Gravatar-compatible(ish) way, i.e. by email hash, not user name. Serves avatar in a Gravatar-compatible(ish) way, i.e. by email hash, not user name.
@ -310,6 +329,6 @@ def gravatar_serve(email_hash: str) -> flask.Response:
return cache.get(uid) return cache.get(uid)
@bp.route('/avatar/user/<uid>', methods=['GET']) @bp.route("/avatar/user/<uid>", methods=["GET"])
def avatar_serve(uid: str) -> flask.Response: def avatar_serve(uid: str) -> flask.Response:
return cache.get(uid) return cache.get(uid)

View file

@ -5,7 +5,7 @@ import os
from typing import Dict, Set, List, Tuple, Any, TypeVar from typing import Dict, Set, List, Tuple, Any, TypeVar
hackerspace_name: str = 'Warsaw Hackerspace' hackerspace_name: str = "Warsaw Hackerspace"
secret_key: str = secrets.token_hex(32) secret_key: str = secrets.token_hex(32)
# Kerberos configuration # Kerberos configuration
@ -13,86 +13,104 @@ kadmin_principal_map: str = "{}@HACKERSPACE.PL"
# LDAP configuration # LDAP configuration
ldap_url: str = 'ldap://ldap.hackerspace.pl' ldap_url: str = "ldap://ldap.hackerspace.pl"
ldap_base: str = 'dc=hackerspace,dc=pl' ldap_base: str = "dc=hackerspace,dc=pl"
ldap_people: str = 'ou=people,dc=hackerspace,dc=pl' ldap_people: str = "ou=people,dc=hackerspace,dc=pl"
ldap_user_dn_format: str = 'uid={},ou=people,dc=hackerspace,dc=pl' ldap_user_dn_format: str = "uid={},ou=people,dc=hackerspace,dc=pl"
ldap_group_dn_format: str = 'cn={},ou=group,dc=hackerspace,dc=pl' ldap_group_dn_format: str = "cn={},ou=group,dc=hackerspace,dc=pl"
# LDAP user groups allowed to see /admin # LDAP user groups allowed to see /admin
ldap_admin_groups: List[str] = os.getenv('LDAPWEB_ADMIN_GROUPS', 'ldap-admin,staff,zarzad').split(',') ldap_admin_groups: List[str] = os.getenv(
"LDAPWEB_ADMIN_GROUPS", "ldap-admin,staff,zarzad"
).split(",")
# LDAP user groups indicating that a user is active # LDAP user groups indicating that a user is active
ldap_active_groups: List[str] = os.getenv('LDAPWEB_ACTIVE_GROUPS', 'fatty,starving,potato').split(',') ldap_active_groups: List[str] = os.getenv(
"LDAPWEB_ACTIVE_GROUPS", "fatty,starving,potato"
).split(",")
# LDAP service user with admin privileges (for admin listings, creating new users) # LDAP service user with admin privileges (for admin listings, creating new users)
ldap_admin_dn: str = os.getenv('LDAPWEB_ADMIN_DN', 'cn=ldapweb,ou=services,dc=hackerspace,dc=pl') ldap_admin_dn: str = os.getenv(
ldap_admin_password: str = os.getenv('LDAPWEB_ADMIN_PASSWORD', 'unused') "LDAPWEB_ADMIN_DN", "cn=ldapweb,ou=services,dc=hackerspace,dc=pl"
)
ldap_admin_password: str = os.getenv("LDAPWEB_ADMIN_PASSWORD", "unused")
# Protected LDAP user groups # Protected LDAP user groups
# These groups (and their members) cannot be modified by admin UI # These groups (and their members) cannot be modified by admin UI
ldap_protected_groups: List[str] = ( ldap_protected_groups: List[str] = "staff,zarzad,ldap-admin".split(",") + os.getenv(
'staff,zarzad,ldap-admin'.split(',') + "LDAPWEB_PROTECTED_GROUPS", ""
os.getenv('LDAPWEB_PROTECTED_GROUPS', '').split(',') ).split(",")
)
# Email notification (paper trail) configuration # Email notification (paper trail) configuration
smtp_server: str = 'mail.hackerspace.pl' smtp_server: str = "mail.hackerspace.pl"
smtp_format: str = '{}@hackerspace.pl' smtp_format: str = "{}@hackerspace.pl"
smtp_user: str = os.getenv('LDAPWEB_SMTP_USER', 'ldapweb') smtp_user: str = os.getenv("LDAPWEB_SMTP_USER", "ldapweb")
smtp_password: str = os.getenv('LDAPWEB_SMTP_PASSWORD', 'unused') smtp_password: str = os.getenv("LDAPWEB_SMTP_PASSWORD", "unused")
papertrail_recipients: str = os.getenv('LDAPWEB_PAPERTRAIL_RECIPIENTS', 'zarzad@hackerspace.pl') papertrail_recipients: str = os.getenv(
"LDAPWEB_PAPERTRAIL_RECIPIENTS", "zarzad@hackerspace.pl"
)
# Avatar server # Avatar server
avatar_cache_timeout: int = int(os.getenv('LDAPWEB_AVATAR_CACHE_TIMEOUT', '1800')) avatar_cache_timeout: int = int(os.getenv("LDAPWEB_AVATAR_CACHE_TIMEOUT", "1800"))
# LDAP attribute configuration # LDAP attribute configuration
readable_names: Dict[str, str] = { readable_names: Dict[str, str] = {
'jpegphoto': 'Avatar', "jpegphoto": "Avatar",
'commonname': 'Common Name', "commonname": "Common Name",
'givenname': 'Given Name', "givenname": "Given Name",
'gecos': 'GECOS (public name)', "gecos": "GECOS (public name)",
'surname': 'Surname', "surname": "Surname",
'loginshell': 'Shell', "loginshell": "Shell",
'telephonenumber': 'Phone Number', "telephonenumber": "Phone Number",
'mobiletelephonenumber': 'Mobile Number', "mobiletelephonenumber": "Mobile Number",
'sshpublickey': 'SSH Public Key', "sshpublickey": "SSH Public Key",
'mifareidhash': 'MIFARE ID Hash', "mifareidhash": "MIFARE ID Hash",
'mail': 'Email Adress', "mail": "Email Adress",
'mailroutingaddress': 'Email Adress (external)', "mailroutingaddress": "Email Adress (external)",
} }
full_name: Dict[str, str] = { full_name: Dict[str, str] = {
'cn': 'commonname', "cn": "commonname",
'gecos': 'gecos', "gecos": "gecos",
'sn': 'surname', "sn": "surname",
'mobile': 'mobiletelephonenumber', "mobile": "mobiletelephonenumber",
'l': 'locality', "l": "locality",
} }
can_add: Set[str] = { can_add: Set[str] = {
'jpegphoto', "jpegphoto",
'telephonenumber', "telephonenumber",
'mobiletelephonenumber', "mobiletelephonenumber",
'sshpublickey', "sshpublickey",
} }
can_delete: Set[str] = can_add can_delete: Set[str] = can_add
can_modify: Set[str] = can_add | { can_modify: Set[str] = can_add | {
'jpegphoto', "jpegphoto",
'givenname', "givenname",
'surname', "surname",
'commonname', "commonname",
'gecos', "gecos",
}
can: Dict[str, Set[str]] = {
"add": can_add,
"mod": can_modify,
"del": can_delete,
"admin": {"mifareidhash"},
} }
can: Dict[str, Set[str]] = { 'add': can_add, 'mod': can_modify, 'del': can_delete, 'admin': {'mifareidhash'} }
FormField = Tuple[type[wtforms.Field], Dict[str, Any]] FormField = Tuple[type[wtforms.Field], Dict[str, Any]]
default_field: FormField = (wtforms.fields.StringField, {}) default_field: FormField = (wtforms.fields.StringField, {})
fields: Dict[str, FormField] = { fields: Dict[str, FormField] = {
'jpegphoto': (wtforms.fields.FileField, {'validators': []}), "jpegphoto": (wtforms.fields.FileField, {"validators": []}),
'mobiletelephonenumber': (wtforms.fields.StringField, {'validators': [wtforms.validators.Regexp(r'[+0-9 ]+')]}), "mobiletelephonenumber": (
'telephonenumber': (wtforms.fields.StringField, {'validators': [wtforms.validators.Regexp(r'[+0-9 ]+')]}), wtforms.fields.StringField,
{"validators": [wtforms.validators.Regexp(r"[+0-9 ]+")]},
),
"telephonenumber": (
wtforms.fields.StringField,
{"validators": [wtforms.validators.Regexp(r"[+0-9 ]+")]},
),
} }

View file

@ -10,10 +10,12 @@ from webapp import config, validation
from typing import Optional, Dict, List from typing import Optional, Dict, List
class Attr: class Attr:
""" """
A concrete attribute (with value) on a Profile. A concrete attribute (with value) on a Profile.
""" """
name: str name: str
readable_name: Optional[str] readable_name: Optional[str]
value: bytes value: bytes
@ -26,10 +28,10 @@ class Attr:
self.name = name self.name = name
self.readable_name = config.readable_names.get(name) self.readable_name = config.readable_names.get(name)
self.value = value self.value = value
self.uid = hashlib.sha1(name.encode('utf-8') + value).hexdigest() self.uid = hashlib.sha1(name.encode("utf-8") + value).hexdigest()
def __str__(self) -> str: def __str__(self) -> str:
return self.value.decode('utf-8') return self.value.decode("utf-8")
@dataclass @dataclass
@ -37,6 +39,7 @@ class LDAPEntry:
""" """
An LDAP entry, eg. a user profile or a group. An LDAP entry, eg. a user profile or a group.
""" """
# Map from uid/hash to attr. # Map from uid/hash to attr.
fields: Dict[str, Attr] fields: Dict[str, Attr]
# DN of this entry # DN of this entry
@ -44,7 +47,7 @@ class LDAPEntry:
def __init__(self, conn: ldap.ldapobject, dn: str): def __init__(self, conn: ldap.ldapobject, dn: str):
res = conn.search_s(dn, ldap.SCOPE_SUBTREE) res = conn.search_s(dn, ldap.SCOPE_SUBTREE)
assert(len(res) == 1) assert len(res) == 1
self.dn = dn self.dn = dn
self.fields = {} self.fields = {}
for attr, vs in res[0][1].items(): for attr, vs in res[0][1].items():
@ -52,7 +55,6 @@ class LDAPEntry:
a = Attr(attr, v) a = Attr(attr, v)
self.fields[a.uid] = a self.fields[a.uid] = a
def get_attr(self, attr: str) -> Optional[Attr]: def get_attr(self, attr: str) -> Optional[Attr]:
for v in self.fields.values(): for v in self.fields.values():
if v.name == attr: if v.name == attr:
@ -68,7 +70,7 @@ class LDAPEntry:
@property @property
def name(self) -> str: def name(self) -> str:
res = self.get_attr('commonname') res = self.get_attr("commonname")
assert res is not None assert res is not None
return res.value.decode() return res.value.decode()
@ -77,6 +79,7 @@ class Profile(LDAPEntry):
""" """
A user profile. A user profile.
""" """
# Map from uid/hash to attr. # Map from uid/hash to attr.
fields: Dict[str, Attr] fields: Dict[str, Attr]
# DN of this profile # DN of this profile
@ -84,6 +87,6 @@ class Profile(LDAPEntry):
@property @property
def username(self) -> str: def username(self) -> str:
res = self.get_attr('uid') res = self.get_attr("uid")
assert res is not None assert res is not None
return res.value.decode() return res.value.decode()

View file

@ -8,6 +8,7 @@ from webapp import config, context
cached_connection: Optional[smtplib.SMTP] = None cached_connection: Optional[smtplib.SMTP] = None
def test_connection_open(conn: smtplib.SMTP) -> bool: def test_connection_open(conn: smtplib.SMTP) -> bool:
try: try:
status = conn.noop()[0] status = conn.noop()[0]
@ -15,37 +16,43 @@ def test_connection_open(conn: smtplib.SMTP) -> bool:
status = -1 status = -1
return True if status == 250 else False return True if status == 250 else False
def create_connection() -> smtplib.SMTP: def create_connection() -> smtplib.SMTP:
conn = smtplib.SMTP_SSL(config.smtp_server) conn = smtplib.SMTP_SSL(config.smtp_server)
conn.login(config.smtp_user, config.smtp_password) conn.login(config.smtp_user, config.smtp_password)
return conn return conn
def get_connection() -> smtplib.SMTP: def get_connection() -> smtplib.SMTP:
global cached_connection global cached_connection
if cached_connection is not None and test_connection_open(cached_connection): if cached_connection is not None and test_connection_open(cached_connection):
return cached_connection
cached_connection = create_connection()
return cached_connection return cached_connection
cached_connection = create_connection()
return cached_connection
def send_email(conn: smtplib.SMTP, subject: str, body: str, recipient_emails: str) -> None: def send_email(
msg = EmailMessage() conn: smtplib.SMTP, subject: str, body: str, recipient_emails: str
msg.set_content(body) ) -> None:
msg['Subject'] = subject msg = EmailMessage()
msg.set_content(body)
msg["Subject"] = subject
sender_email = config.smtp_format.format(config.smtp_user) sender_email = config.smtp_format.format(config.smtp_user)
msg['From'] = f'LDAPWeb <{sender_email}>' msg["From"] = f"LDAPWeb <{sender_email}>"
msg['To'] = recipient_emails msg["To"] = recipient_emails
conn.send_message(msg)
conn.send_message(msg)
def send_papertrail(title: str, description: str) -> None: def send_papertrail(title: str, description: str) -> None:
username = flask.session.get('username') username = flask.session.get("username")
current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M") current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M")
subject = f'[LDAPWeb] {title}' subject = f"[LDAPWeb] {title}"
body = f"Changed by {username} at {current_time}:\n\n{description or title}" body = f"Changed by {username} at {current_time}:\n\n{description or title}"
recipients = config.papertrail_recipients recipients = config.papertrail_recipients
conn = get_connection() conn = get_connection()
send_email(conn, subject, body, recipients) send_email(conn, subject, body, recipients)

View file

@ -5,45 +5,55 @@ from webapp import config
from typing import List, Tuple, Any, Dict from typing import List, Tuple, Any, Dict
def is_valid_name(name: str) -> bool: def is_valid_name(name: str) -> bool:
"""`true` if `name` is a safe ldap uid/cn""" """`true` if `name` is a safe ldap uid/cn"""
return re.match(r'^[a-zA-Z_][a-zA-Z0-9-_\.]*\Z', name) is not None return re.match(r"^[a-zA-Z_][a-zA-Z0-9-_\.]*\Z", name) is not None
def validate_name(name: str) -> None: def validate_name(name: str) -> None:
"""Raises `RuntimeError` if `name` is not a safe ldap uid/cn""" """Raises `RuntimeError` if `name` is not a safe ldap uid/cn"""
if not is_valid_name(name): if not is_valid_name(name):
raise RuntimeError('Invalid name') raise RuntimeError("Invalid name")
def user_dn(uid: str) -> str: def user_dn(uid: str) -> str:
validate_name(uid) validate_name(uid)
return config.ldap_user_dn_format.format(uid) return config.ldap_user_dn_format.format(uid)
def group_dn(cn: str) -> str: def group_dn(cn: str) -> str:
validate_name(cn) validate_name(cn)
return config.ldap_group_dn_format.format(cn) return config.ldap_group_dn_format.format(cn)
def wrap(filter: str) -> str: def wrap(filter: str) -> str:
if len(filter) and filter[0] == '(' and filter[-1] == ')': if len(filter) and filter[0] == "(" and filter[-1] == ")":
return filter return filter
else: else:
return f'({filter})' return f"({filter})"
def _or(*filters: str) -> str: def _or(*filters: str) -> str:
wrapped = ''.join(wrap(f) for f in filters) wrapped = "".join(wrap(f) for f in filters)
return f'(|{wrapped})' return f"(|{wrapped})"
def _and(*filters: str) -> str: def _and(*filters: str) -> str:
wrapped = ''.join(wrap(f) for f in filters) wrapped = "".join(wrap(f) for f in filters)
return f'(&{wrapped})' return f"(&{wrapped})"
def _not(filter: str) -> str: def _not(filter: str) -> str:
wrapped = wrap(filter) wrapped = wrap(filter)
return f'(!{wrapped})' return f"(!{wrapped})"
def member_of_any(groups: List[str]) -> str: def member_of_any(groups: List[str]) -> str:
"""Returns a filter that matches users that are a member of any of the given group names""" """Returns a filter that matches users that are a member of any of the given group names"""
return _or(*(f'memberOf={group_dn(group)}' for group in groups)) return _or(*(f"memberOf={group_dn(group)}" for group in groups))
def groups_of_user(uid: str) -> str: def groups_of_user(uid: str) -> str:
"""Returns a filter that matches groups that have the given user as a member""" """Returns a filter that matches groups that have the given user as a member"""
return f'(&(objectClass=groupOfUniqueNames)(uniqueMember={user_dn(uid)}))' return f"(&(objectClass=groupOfUniqueNames)(uniqueMember={user_dn(uid)}))"

View file

@ -7,17 +7,19 @@ import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, Generic, TypeVar, List, Callable, Optional from typing import Dict, Generic, TypeVar, List, Callable, Optional
log = logging.getLogger('ldap-web.lru') log = logging.getLogger("ldap-web.lru")
K = TypeVar('K') K = TypeVar("K")
V = TypeVar('V') V = TypeVar("V")
Cb = Callable[[K], None] Cb = Callable[[K], None]
@dataclass @dataclass
class _Entry(Generic[V]): class _Entry(Generic[V]):
c: V c: V
atime: float atime: float
class LRUPool(threading.Thread, Generic[K, V]): class LRUPool(threading.Thread, Generic[K, V]):
"""A key-value pool to store objects with a timeout. """A key-value pool to store objects with a timeout.
@ -75,7 +77,7 @@ class LRUPool(threading.Thread, Generic[K, V]):
self.lock.release() self.lock.release()
def _drop(self, key: K) -> Optional[V]: def _drop(self, key: K) -> Optional[V]:
for f in self.callbacks.get('drop', []): for f in self.callbacks.get("drop", []):
f(key) f(key)
res = self.pool.pop(key, None) res = self.pool.pop(key, None)
if res is None: if res is None:
@ -88,5 +90,3 @@ class LRUPool(threading.Thread, Generic[K, V]):
return self._drop(key) return self._drop(key)
finally: finally:
self.lock.release() self.lock.release()

View file

@ -8,35 +8,38 @@ import werkzeug
from webapp import app, context, config from webapp import app, context, config
from webapp.auth import login_required from webapp.auth import login_required
bp = flask.Blueprint('passwd', __name__) bp = flask.Blueprint("passwd", __name__)
@bp.route('/passwd', methods=["GET"])
@bp.route("/passwd", methods=["GET"])
@login_required @login_required
def passwd_form() -> werkzeug.Response: def passwd_form() -> werkzeug.Response:
return flask.Response(flask.render_template('passwd.html')) return flask.Response(flask.render_template("passwd.html"))
def _passwd_kadmin(current: str, new: str) -> bool: def _passwd_kadmin(current: str, new: str) -> bool:
username = flask.session.get('username') username = flask.session.get("username")
try: try:
principal_name = config.kadmin_principal_map.format(username) principal_name = config.kadmin_principal_map.format(username)
return kerberos.changePassword(principal_name, current, new) return kerberos.changePassword(principal_name, current, new)
except Exception as e: except Exception as e:
print('Kerberos error:', e) print("Kerberos error:", e)
logging.exception('kpasswd failed') logging.exception("kpasswd failed")
return False return False
@bp.route('/passwd', methods=["POST"]) @bp.route("/passwd", methods=["POST"])
@login_required @login_required
def passwd_action() -> werkzeug.Response: def passwd_action() -> werkzeug.Response:
current, new, confirm = (flask.request.form[n] for n in ('current', 'new', 'confirm')) current, new, confirm = (
flask.request.form[n] for n in ("current", "new", "confirm")
)
if new != confirm: if new != confirm:
flask.flash("New passwords don't match", category='danger') flask.flash("New passwords don't match", category="danger")
return flask.Response(flask.render_template('passwd.html')) return flask.Response(flask.render_template("passwd.html"))
if _passwd_kadmin(current, new): if _passwd_kadmin(current, new):
flask.flash('Password changed', category='info') flask.flash("Password changed", category="info")
else: else:
flask.flash('Wrong password', category='danger') flask.flash("Wrong password", category="danger")
return flask.Response(flask.render_template('passwd.html')) return flask.Response(flask.render_template("passwd.html"))

View file

@ -6,6 +6,7 @@ from typing import Optional, Any
Pool = lru.LRUPool[str, ldap.ldapobject] Pool = lru.LRUPool[str, ldap.ldapobject]
class LDAPConnectionPool(Pool): class LDAPConnectionPool(Pool):
def __init__(self, url: str, use_tls: bool = True, **kwargs: Any) -> None: def __init__(self, url: str, use_tls: bool = True, **kwargs: Any) -> None:
super().__init__(**kwargs) super().__init__(**kwargs)
@ -16,7 +17,7 @@ class LDAPConnectionPool(Pool):
self.lock.acquire() self.lock.acquire()
try: try:
conn = ldap.initialize(self.url) conn = ldap.initialize(self.url)
if(self.use_tls): if self.use_tls:
conn.start_tls_s() conn.start_tls_s()
conn.simple_bind_s(dn, password) conn.simple_bind_s(dn, password)
return self._insert(dn, conn) return self._insert(dn, conn)
@ -25,5 +26,3 @@ class LDAPConnectionPool(Pool):
def unbind(self, dn: str) -> Optional[ldap.ldapobject]: def unbind(self, dn: str) -> Optional[ldap.ldapobject]:
return self.drop(dn) return self.drop(dn)

View file

@ -4,13 +4,17 @@ import flask
from webapp import config from webapp import config
def sanitize_perms() -> None: def sanitize_perms() -> None:
config.can = { k: set(map(sanitize_ldap, v)) for k,v in config.can.items() } config.can = {k: set(map(sanitize_ldap, v)) for k, v in config.can.items()}
def sanitize_readable() -> None: def sanitize_readable() -> None:
config.readable_names = { sanitize_ldap(k): v for k, v in config.readable_names.items() } config.readable_names = {
sanitize_ldap(k): v for k, v in config.readable_names.items()
}
def sanitize_ldap(k: str) -> str: def sanitize_ldap(k: str) -> str:
k = k.lower() k = k.lower()
return (k in config.full_name and config.full_name[k]) or k return (k in config.full_name and config.full_name[k]) or k

View file

@ -1,7 +1,17 @@
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
from functools import reduce from functools import reduce
from typing import Optional, cast, Dict, List, TypedDict, Literal, Union, TypeVar, Generic from typing import (
Optional,
cast,
Dict,
List,
TypedDict,
Literal,
Union,
TypeVar,
Generic,
)
import ldap import ldap
import flask import flask
@ -12,7 +22,7 @@ import werkzeug
from webapp import app, context, config, validation, avatar from webapp import app, context, config, validation, avatar
from webapp.auth import login_required from webapp.auth import login_required
bp = flask.Blueprint('vcard', __name__) bp = flask.Blueprint("vcard", __name__)
# NOTE: this code is quite hairy. Simplifying it further is very possible, but # NOTE: this code is quite hairy. Simplifying it further is very possible, but
@ -31,6 +41,7 @@ class DelForm(flask_wtf.FlaskForm):
Used as is. Used as is.
""" """
attr_data: context.Attr attr_data: context.Attr
@ -40,6 +51,7 @@ class AddModifyForm(flask_wtf.FlaskForm):
Base class subclassed by dynamically generated classes in initialize_forms. Base class subclassed by dynamically generated classes in initialize_forms.
""" """
attr_data: Optional[context.Attr] attr_data: Optional[context.Attr]
value: wtforms.Field value: wtforms.Field
@ -49,11 +61,13 @@ def initialize_forms() -> Dict[str, type[AddModifyForm]]:
Create AddModifyForm subclasses keyed by attribute name, based on config. Create AddModifyForm subclasses keyed by attribute name, based on config.
""" """
forms: Dict[str, type[AddModifyForm]] = {} forms: Dict[str, type[AddModifyForm]] = {}
for f in reduce(lambda a,b: a | b, config.can.values()): for f in reduce(lambda a, b: a | b, config.can.values()):
cls, attrs = config.fields.get(f, config.default_field) cls, attrs = config.fields.get(f, config.default_field)
class AddForm(AddModifyForm): class AddForm(AddModifyForm):
value = cls(label=config.readable_names.get(f), **attrs) value = cls(label=config.readable_names.get(f), **attrs)
AddForm.__name__ == 'Add' + f
AddForm.__name__ == "Add" + f
forms[f] = AddForm forms[f] = AddForm
return forms return forms
@ -61,7 +75,8 @@ def initialize_forms() -> Dict[str, type[AddModifyForm]]:
add_modify_forms = initialize_forms() add_modify_forms = initialize_forms()
F = TypeVar('F', bound=flask_wtf.FlaskForm) F = TypeVar("F", bound=flask_wtf.FlaskForm)
class Operation(Generic[F]): class Operation(Generic[F]):
""" """
@ -73,6 +88,7 @@ class Operation(Generic[F]):
Then further subclassed by operation kind: add, delete or modify. Then further subclassed by operation kind: add, delete or modify.
""" """
kind: str kind: str
perm_error: str perm_error: str
dn: str dn: str
@ -80,7 +96,7 @@ class Operation(Generic[F]):
@property @property
def _template_path(self) -> str: def _template_path(self) -> str:
return f'ops/{self.kind}.html' return f"ops/{self.kind}.html"
def _make_form(self) -> F: def _make_form(self) -> F:
""" """
@ -106,7 +122,12 @@ class Operation(Generic[F]):
return self.perm_error return self.perm_error
return None return None
def perform(self, success_redirect: str = '/vcard', fatal_redirect: str = '/vcard', action: str = '/vcard') -> werkzeug.Response: def perform(
self,
success_redirect: str = "/vcard",
fatal_redirect: str = "/vcard",
action: str = "/vcard",
) -> werkzeug.Response:
""" """
Primary entrypoint to operation. To be called from a view. Primary entrypoint to operation. To be called from a view.
""" """
@ -116,7 +137,7 @@ class Operation(Generic[F]):
# Check permissions per config. # Check permissions per config.
perm_err = self._allowed(self.dn) perm_err = self._allowed(self.dn)
if perm_err is not None: if perm_err is not None:
flask.flash(perm_err, 'danger') flask.flash(perm_err, "danger")
return flask.redirect(fatal_redirect) return flask.redirect(fatal_redirect)
form = self._make_form() form = self._make_form()
@ -125,8 +146,8 @@ class Operation(Generic[F]):
try: try:
self._on_form_submit(self.dn, conn, form) self._on_form_submit(self.dn, conn, form)
except ldap.LDAPError as e: except ldap.LDAPError as e:
print('LDAP error:', e) print("LDAP error:", e)
flask.flash('Could not modify profile', 'danger') flask.flash("Could not modify profile", "danger")
return flask.redirect(fatal_redirect) return flask.redirect(fatal_redirect)
profile = app.refresh_profile(conn) profile = app.refresh_profile(conn)
@ -141,10 +162,12 @@ class Operation(Generic[F]):
for field, errors in form.errors.items(): for field, errors in form.errors.items():
assert field is not None assert field is not None
for error in errors: for error in errors:
flask.flash("Error in the {} field - {}".format( flask.flash(
getattr(form, field).label.text, "Error in the {} field - {}".format(
error getattr(form, field).label.text, error
), 'danger') ),
"danger",
)
return self._render_form(form, action) return self._render_form(form, action)
@ -153,12 +176,17 @@ class OperationWithAttrName(Operation[F]):
Operations acting on an attribute name, without a corresponding existing Operations acting on an attribute name, without a corresponding existing
attribute. This is used when adding profile fields/attributes. attribute. This is used when adding profile fields/attributes.
""" """
def __init__(self, dn: str, attr_name: str) -> None: def __init__(self, dn: str, attr_name: str) -> None:
self.dn = dn self.dn = dn
self.attr_name = attr_name self.attr_name = attr_name
def _render_form(self, form: F, action: str) -> werkzeug.Response: def _render_form(self, form: F, action: str) -> werkzeug.Response:
return flask.Response(flask.render_template(self._template_path, form=form, attr_name=self.attr_name, action=action)) return flask.Response(
flask.render_template(
self._template_path, form=form, attr_name=self.attr_name, action=action
)
)
class OperationWithUid(Operation[F]): class OperationWithUid(Operation[F]):
@ -167,6 +195,7 @@ class OperationWithUid(Operation[F]):
'uid'. This is used when modifying existing fields/attributes or removing 'uid'. This is used when modifying existing fields/attributes or removing
them. them.
""" """
attr: context.Attr attr: context.Attr
def __init__(self, dn: str, uid: str): def __init__(self, dn: str, uid: str):
@ -178,23 +207,33 @@ class OperationWithUid(Operation[F]):
self.attr_name = self.attr.name self.attr_name = self.attr.name
def _render_form(self, form: F, action: str) -> werkzeug.Response: def _render_form(self, form: F, action: str) -> werkzeug.Response:
return flask.Response(flask.render_template(self._template_path, form=form, attr_name=self.attr.name, uid=self.attr.uid, action=action)) return flask.Response(
flask.render_template(
self._template_path,
form=form,
attr_name=self.attr.name,
uid=self.attr.uid,
action=action,
)
)
class OperationAdd(OperationWithAttrName[AddModifyForm]): class OperationAdd(OperationWithAttrName[AddModifyForm]):
kind = 'add' kind = "add"
perm_error = 'You cannot add this attribute!' perm_error = "You cannot add this attribute!"
def _make_form(self) -> AddModifyForm: def _make_form(self) -> AddModifyForm:
form = add_modify_forms[self.attr_name]() form = add_modify_forms[self.attr_name]()
return form return form
def _on_form_submit(self, dn: str, conn: ldap.ldapobject, form: AddModifyForm) -> None: def _on_form_submit(
self, dn: str, conn: ldap.ldapobject, form: AddModifyForm
) -> None:
# Special case for jpegphoto # Special case for jpegphoto
value = form.value.data value = form.value.data
if self.attr_name == 'jpegphoto': if self.attr_name == "jpegphoto":
value = avatar.process_upload(form.value.data.read()) value = avatar.process_upload(form.value.data.read())
print(f'Uploading avatar (size: {len(value)}) for {dn}') print(f"Uploading avatar (size: {len(value)}) for {dn}")
# jpegPhoto should always be REPLACED. # jpegPhoto should always be REPLACED.
conn.modify_s(dn, [(ldap.MOD_REPLACE, self.attr_name, value)]) conn.modify_s(dn, [(ldap.MOD_REPLACE, self.attr_name, value)])
return return
@ -204,22 +243,24 @@ class OperationAdd(OperationWithAttrName[AddModifyForm]):
class OperationModify(OperationWithUid[AddModifyForm]): class OperationModify(OperationWithUid[AddModifyForm]):
kind = 'mod' kind = "mod"
perm_error = 'You cannot modify this attribute!' perm_error = "You cannot modify this attribute!"
def _make_form(self) -> AddModifyForm: def _make_form(self) -> AddModifyForm:
form = add_modify_forms[self.attr_name](value=str(self.attr)) form = add_modify_forms[self.attr_name](value=str(self.attr))
return form return form
def _on_form_submit(self, dn: str, conn: ldap.ldapobject, form: AddModifyForm) -> None: def _on_form_submit(
self, dn: str, conn: ldap.ldapobject, form: AddModifyForm
) -> None:
# Special case for jpegphoto # Special case for jpegphoto
value = form.value.data value = form.value.data
if self.attr_name == 'jpegphoto': if self.attr_name == "jpegphoto":
value = avatar.process_upload(form.value.data.read()) value = avatar.process_upload(form.value.data.read())
print(f'Uploading avatar (size: {len(value)}) for {dn}') print(f"Uploading avatar (size: {len(value)}) for {dn}")
assert value is not None assert value is not None
if self.attr_name in ['commonname']: if self.attr_name in ["commonname"]:
# Modify directly # Modify directly
conn.modify_s(dn, [(ldap.MOD_REPLACE, self.attr_name, value.encode())]) conn.modify_s(dn, [(ldap.MOD_REPLACE, self.attr_name, value.encode())])
else: else:
@ -227,24 +268,27 @@ class OperationModify(OperationWithUid[AddModifyForm]):
conn.modify_s(dn, [(ldap.MOD_DELETE, self.attr_name, self.attr.value)]) conn.modify_s(dn, [(ldap.MOD_DELETE, self.attr_name, self.attr.value)])
conn.modify_s(dn, [(ldap.MOD_ADD, self.attr_name, value.encode())]) conn.modify_s(dn, [(ldap.MOD_ADD, self.attr_name, value.encode())])
class OperationDelete(OperationWithUid[DelForm]): class OperationDelete(OperationWithUid[DelForm]):
kind = 'del' kind = "del"
perm_error = 'You cannot delete this attribute!' perm_error = "You cannot delete this attribute!"
def _make_form(self) -> DelForm: def _make_form(self) -> DelForm:
res = DelForm() res = DelForm()
res.attr_data = self.attr res.attr_data = self.attr
return res return res
def _on_form_submit(self, dn: str, conn: ldap.ldapobject, form: flask_wtf.FlaskForm) -> None: def _on_form_submit(
if self.attr_name == 'jpegphoto': self, dn: str, conn: ldap.ldapobject, form: flask_wtf.FlaskForm
) -> None:
if self.attr_name == "jpegphoto":
# We apparently can't remove these, so just set it empty. # We apparently can't remove these, so just set it empty.
conn.modify_s(dn, [(ldap.MOD_REPLACE, self.attr_name, b'')]) conn.modify_s(dn, [(ldap.MOD_REPLACE, self.attr_name, b"")])
return return
conn.modify_s(dn, [(ldap.MOD_DELETE, self.attr_name, self.attr.value)]) conn.modify_s(dn, [(ldap.MOD_DELETE, self.attr_name, self.attr.value)])
@bp.route('/vcard', methods=['GET']) @bp.route("/vcard", methods=["GET"])
@login_required @login_required
def vcard() -> werkzeug.Response: def vcard() -> werkzeug.Response:
data: Dict[str, List[context.Attr]] = {} data: Dict[str, List[context.Attr]] = {}
@ -252,29 +296,39 @@ def vcard() -> werkzeug.Response:
assert profile is not None assert profile is not None
for v in profile.fields.values(): for v in profile.fields.values():
data.setdefault(v.name, []).append(v) data.setdefault(v.name, []).append(v)
return flask.Response(flask.render_template('vcard.html', can_add=config.can['add'], return flask.Response(
can_modify=config.can['mod'], can_delete=config.can['del'], profile=data)) flask.render_template(
"vcard.html",
can_add=config.can["add"],
can_modify=config.can["mod"],
can_delete=config.can["del"],
profile=data,
)
)
@bp.route('/vcard/add/<attr_name>', methods=['GET', 'POST'])
@bp.route("/vcard/add/<attr_name>", methods=["GET", "POST"])
@login_required @login_required
def add_attr(attr_name: str) -> werkzeug.Response: def add_attr(attr_name: str) -> werkzeug.Response:
dn = app.get_dn() dn = app.get_dn()
assert dn is not None assert dn is not None
op = OperationAdd(dn, attr_name) op = OperationAdd(dn, attr_name)
return op.perform(action=f'/vcard/add/{attr_name}') return op.perform(action=f"/vcard/add/{attr_name}")
@bp.route('/vcard/delete/<uid>', methods=['GET', 'POST'])
@bp.route("/vcard/delete/<uid>", methods=["GET", "POST"])
@login_required @login_required
def del_attr(uid: str) -> werkzeug.Response: def del_attr(uid: str) -> werkzeug.Response:
dn = app.get_dn() dn = app.get_dn()
assert dn is not None assert dn is not None
op = OperationDelete(dn, uid) op = OperationDelete(dn, uid)
return op.perform(action=f'/vcard/delete/{uid}') return op.perform(action=f"/vcard/delete/{uid}")
@bp.route('/vcard/modify/<uid>', methods=['GET', 'POST'])
@bp.route("/vcard/modify/<uid>", methods=["GET", "POST"])
@login_required @login_required
def mod_attr(uid: str) -> werkzeug.Response: def mod_attr(uid: str) -> werkzeug.Response:
dn = app.get_dn() dn = app.get_dn()
assert dn is not None assert dn is not None
op = OperationModify(dn, uid) op = OperationModify(dn, uid)
return op.perform(action=f'/vcard/modify/{uid}') return op.perform(action=f"/vcard/modify/{uid}")

View file

@ -4,8 +4,8 @@ import werkzeug
from webapp import app, context, config from webapp import app, context, config
from webapp.auth import login_required from webapp.auth import login_required
@app.route("/") @app.route("/")
@login_required @login_required
def root() -> werkzeug.Response: def root() -> werkzeug.Response:
return flask.Response(flask.render_template('root.html', **flask.session)) return flask.Response(flask.render_template("root.html", **flask.session))