diff --git a/README b/README index 940db3f..2147adb 100644 --- a/README +++ b/README @@ -43,6 +43,7 @@ The following flags are automatically registered: - `-listen_address` (default: `127.0.0.1:4200`): where to listen for gRPC requests - `-debug_address` (default: `127.0.0.1:4201`): where to listen for debug HTTP requests + - `-debug_allow_all` (default: false): whether to allow all IP address (vs. localhost) to connect to debug endpoint Since this library also includes [hspki](https://code.hackerspace.pl/q3k/hspki), you also get all the typical `-hspki_{...}` flags included. diff --git a/mirko.go b/mirko.go index 4a5f01f..90959ae 100644 --- a/mirko.go +++ b/mirko.go @@ -19,11 +19,13 @@ import ( var ( flagListenAddress string flagDebugAddress string + flagDebugAllowAll bool ) func init() { flag.StringVar(&flagListenAddress, "listen_address", "127.0.0.1:4200", "gRPC listen address") flag.StringVar(&flagDebugAddress, "debug_address", "127.0.0.1:4201", "HTTP debug/status listen address") + flag.StringVar(&flagDebugAllowAll, "debug_allow_all", false, "HTTP debug/status available to everyone") flag.Set("logtostderr", "true") } @@ -39,8 +41,28 @@ func New() *Mirko { return &Mirko{} } +func authRequest(req *http.Request) (any, sensitive bool) { + host, _, err := net.SplitHostPort(req.RemoteAddr) + if err != nil { + host = req.RemoteAddr + } + + if flagDebugAllowAll { + return true, true + } + + switch host { + case "localhost", "127.0.0.1", "::1": + return true, true + default: + return false, false + } +} + func (m *Mirko) Listen() error { grpc.EnableTracing = true + trace.AuthRequest = authRequest + grpcLis, err := net.Listen("tcp", flagListenAddress) if err != nil { return fmt.Errorf("net.Listen: %v", err) @@ -56,7 +78,14 @@ func (m *Mirko) Listen() error { m.httpMux = http.NewServeMux() // Canonical URLs - m.httpMux.HandleFunc("/debug/status", statusz.StatusHandler) + m.httpMux.HandleFunc("/debug/status", func(w http.ResponseWriter, r *http.Request) { + any, sensitive := authRequest(r) + if !any { + http.Error(w, "not allowed", http.StatusUnauthorized) + return + } + statusz.StatusHandler(w, r) + }) m.httpMux.HandleFunc("/debug/requests", trace.Traces) // -z legacy URLs