forked from hswaw/hscloud
112 lines
3.9 KiB
Python
112 lines
3.9 KiB
Python
from at.dhcp import DhcpdUpdater, DhcpLease
|
|
from pathlib import Path
|
|
import yaml
|
|
import grpc
|
|
import json
|
|
import re
|
|
import subprocess
|
|
import logging
|
|
from concurrent import futures
|
|
from datetime import datetime, timezone
|
|
|
|
from .tracker_pb2 import DhcpClient, DhcpClients, HwAddrResponse
|
|
from .tracker_pb2_grpc import DhcpTrackerServicer, add_DhcpTrackerServicer_to_server
|
|
|
|
import argparse
|
|
# Command-line interface: one positional YAML config path plus an optional
# verbosity switch. Parsed later by server().
parser = argparse.ArgumentParser()
parser.add_argument(
    "--verbose",
    action="store_true",
    help="output more info",
)
parser.add_argument(
    "config",
    help="input file",
    type=Path,
)

# Root logger emits INFO and above by default.
logging.basicConfig(level=logging.INFO)
|
|
|
|
def lease_to_client(lease: DhcpLease) -> DhcpClient:
    """Convert a ``DhcpLease`` into a ``DhcpClient`` protobuf message.

    ``lease.hwaddr`` is a colon-separated hex MAC string and is packed into
    raw bytes; ``lease.atime`` is a POSIX timestamp rendered as an ISO-8601
    string with an explicit UTC offset.
    """
    return DhcpClient(
        hw_address=bytes.fromhex(lease.hwaddr.replace(':', '')),
        # fromtimestamp(..., tz=utc) builds the aware UTC datetime directly;
        # utcfromtimestamp() is deprecated since Python 3.12.
        last_seen=datetime.fromtimestamp(
            lease.atime, tz=timezone.utc).isoformat(),
        client_hostname=lease.name,
        ip_address=lease.ip,
    )
|
|
|
|
class DhcpTrackerServicer(DhcpTrackerServicer):
    """gRPC servicer exposing DHCP lease data held by a ``DhcpdUpdater``.

    The class intentionally reuses the name of the generated base class it
    inherits from; the base is evaluated before the new name is bound.
    """

    def __init__(self, tracker: DhcpdUpdater, *args, **kwargs):
        self._tracker = tracker
        super().__init__(*args, **kwargs)

    def _authorize(self, context):
        """Abort the RPC unless the caller is trusted.

        Trusted callers are TLS peers whose authenticated identities include
        ``at.hackerspace.pl``, and connections whose auth context has no
        ``transport_security_type`` entry (treated as the local unix socket).
        ``context.abort`` raises, so returning normally means "authorized".
        """
        auth = context.auth_context()
        # Values in the auth context are lists of bytes; the 'local' default
        # only applies when the key is absent altogether.
        ctype = auth.get('transport_security_type', 'local')
        logging.debug('transport security type: %r', ctype)

        if ctype == [b'ssl']:
            # peer_identities() returns None for unauthenticated peers.
            identities = context.peer_identities() or []
            if b'at.hackerspace.pl' not in identities:
                context.abort(
                    grpc.StatusCode.PERMISSION_DENIED,
                    (
                        "Only at.hackerspace.pl is allowed to access raw "
                        "clients addresses"
                    )
                )
        elif ctype == 'local':
            # connection from local unix socket is trusted by default
            pass
        else:
            context.abort(
                grpc.StatusCode.PERMISSION_DENIED,
                f"Unknown transport type: {ctype}"
            )

    def GetClients(self, request, context):
        """Return every active lease known to the tracker as ``DhcpClients``."""
        self._authorize(context)
        leases = self._tracker.get_active_devices().values()
        return DhcpClients(clients=[lease_to_client(lease) for lease in leases])

    def GetHwAddr(self, request, context):
        """Look up the link-layer address for ``request.ip_address``.

        Runs ``ip -json neigh show <ip>`` and returns the first neighbour
        entry that carries an ``lladdr``; returns an empty
        ``HwAddrResponse`` when there is none.

        Raises:
            ValueError: if ``request.ip_address`` is not a plain IPv4/IPv6
                address.
            subprocess.CalledProcessError: if ``ip`` exits non-zero.
        """
        import ipaddress

        self._authorize(context)

        ip_address = str(request.ip_address)
        # The character whitelist keeps option-like strings and IPv6 zone ids
        # away from the `ip` command line; ipaddress additionally rejects
        # empty or malformed input (the empty string used to pass the check).
        try:
            ipaddress.ip_address(ip_address)
            valid = re.fullmatch(r'[0-9a-fA-F:.]+', ip_address) is not None
        except ValueError:
            valid = False
        if not valid:
            raise ValueError(f'Invalid ip address: {ip_address!r}')

        logging.info('running ip neigh on %s', ip_address)
        r = subprocess.run(
            ['ip', '-json', 'neigh', 'show', ip_address],
            check=True, capture_output=True)
        # Treat empty output as "no neighbours" instead of a JSON error.
        neighs = json.loads(r.stdout.strip() or b'[]')
        # Unresolved entries (e.g. INCOMPLETE/FAILED) have no 'lladdr' key.
        lladdr = next((n['lladdr'] for n in neighs if 'lladdr' in n), None)
        if lladdr is None:
            return HwAddrResponse(hw_address=None)
        return HwAddrResponse(hw_address=bytes.fromhex(lladdr.replace(':', '')))
|
|
|
|
def server():
    """Parse CLI arguments, start the lease tracker and serve gRPC.

    The YAML config must provide ``LEASE_FILE`` and ``TIMEOUT``. Optional
    keys: ``GRPC_TLS_ADDRESS`` (with ``GRPC_TLS_CERT_DIR`` containing
    ``key.pem``/``cert.pem`` and ``GRPC_TLS_CA_CERT``) for a mutual-TLS
    port, and ``GRPC_UNIX_SOCKET`` for an insecure unix-socket port.
    Blocks until the gRPC server terminates; if neither listener is
    configured it logs an error and returns without starting the server.
    """
    args = parser.parse_args()
    if args.verbose:
        # Honour --verbose by lowering the root logger level.
        logging.getLogger().setLevel(logging.DEBUG)

    config = yaml.safe_load(args.config.read_text())
    tracker = DhcpdUpdater(config['LEASE_FILE'], config['TIMEOUT'])
    tracker.start()

    # Named grpc_server so it does not shadow this function's own name.
    grpc_server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    add_DhcpTrackerServicer_to_server(DhcpTrackerServicer(tracker), grpc_server)

    tls_address = config.get("GRPC_TLS_ADDRESS", None)
    if tls_address:
        cert_dir = Path(config.get('GRPC_TLS_CERT_DIR', 'cert'))
        ca_cert = Path(config.get('GRPC_TLS_CA_CERT', 'ca.pem')).read_bytes()

        server_credentials = grpc.ssl_server_credentials(
            private_key_certificate_chain_pairs=((
                cert_dir.joinpath('key.pem').read_bytes(),
                cert_dir.joinpath('cert.pem').read_bytes(),
            ),),
            root_certificates=ca_cert,
            # Clients must present a certificate signed by ca_cert.
            require_client_auth=True,
        )
        # tls_address is known to be set here, so bind it directly.
        grpc_server.add_secure_port(tls_address, server_credentials)

    unix_socket = config.get('GRPC_UNIX_SOCKET', False)
    if unix_socket:
        grpc_server.add_insecure_port(f'unix://{unix_socket}')

    if not (tls_address or unix_socket):
        logging.error(
            'neither GRPC_TLS_ADDRESS nor GRPC_UNIX_SOCKET is configured; '
            'not starting grpc server')
        return

    print('starting grpc server ...')
    grpc_server.start()
    grpc_server.wait_for_termination()
|