forked from hswaw/hscloud
88 lines
2.3 KiB
Go
88 lines
2.3 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"strings"
|
|
|
|
"golang.org/x/net/trace"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/credentials"
|
|
"google.golang.org/grpc/peer"
|
|
"google.golang.org/grpc/status"
|
|
)
|
|
|
|
// clientPKIInfo describes the identity of a gRPC client as derived, by
// parseClientName, from the CommonName of its verified TLS client certificate.
type clientPKIInfo struct {
	// realm is the PKI realm suffix the certificate name was checked against.
	realm string
	// principal is the second dot-separated label of the name (between job
	// and realm).
	principal string
	// job is the first dot-separated label of the name.
	job string
}

// String renders the identity for tracing, with every field quoted.
func (c *clientPKIInfo) String() string {
	return fmt.Sprintf("job=%q, principal=%q, realm=%q", c.job, c.principal, c.realm)
}

// parseClientName splits a certificate name of the form
// "<job>.<principal>.<realm>" into its components.
//
// It returns an error ("invalid realm") if name does not end in "."+realm, and
// an error ("invalid service") if the part before the realm is not exactly two
// non-empty dot-separated labels.
func parseClientName(realm, name string) (*clientPKIInfo, error) {
	suffix := "." + realm
	if !strings.HasSuffix(name, suffix) {
		return nil, fmt.Errorf("invalid realm")
	}
	parts := strings.Split(strings.TrimSuffix(name, suffix), ".")
	// Require exactly "<job>.<principal>". Empty labels (".principal" or
	// "job.") would otherwise produce a blank job or principal identity.
	if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
		return nil, fmt.Errorf("invalid service")
	}
	return &clientPKIInfo{
		realm:     realm,
		principal: parts[1],
		job:       parts[0],
	}, nil
}
|
|
|
|
// pkiContextKey is the type of the context keys defined here. Because the type
// is unexported, no other package can construct a key equal to one of these,
// so values stored with context.WithValue cannot collide with theirs (unlike a
// plain string key).
type pkiContextKey string

const (
	// ctxKeyPKIInfo is the context key under which withPKIInfo stores the
	// caller's *clientPKIInfo.
	ctxKeyPKIInfo pkiContextKey = "hscloud-pki-info"
)
|
|
|
|
func withPKIInfo(ctx context.Context, c *clientPKIInfo) context.Context {
|
|
tr, ok := trace.FromContext(ctx)
|
|
if ok {
|
|
tr.LazyPrintf("PKI Peer: %s", c.String())
|
|
}
|
|
return context.WithValue(ctx, ctxKeyPKIInfo, c)
|
|
}
|
|
|
|
func (s *server) unaryInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
|
|
peer, ok := peer.FromContext(ctx)
|
|
if !ok {
|
|
s.trace(ctx, "Could not establish identity of peer.")
|
|
return nil, status.Error(codes.InvalidArgument, "no peer info")
|
|
}
|
|
|
|
authInfo, ok := peer.AuthInfo.(credentials.TLSInfo)
|
|
if !ok {
|
|
s.trace(ctx, "Could not establish TLS identity of peer.")
|
|
return nil, status.Error(codes.InvalidArgument, "no TLS peer info")
|
|
}
|
|
|
|
chains := authInfo.State.VerifiedChains
|
|
if len(chains) != 1 {
|
|
s.trace(ctx, "No trusted chain found.")
|
|
return nil, status.Error(codes.InvalidArgument, "invalid TLS certificate")
|
|
}
|
|
chain := chains[0]
|
|
|
|
certDNs := make([]string, len(chain))
|
|
for i, cert := range chain {
|
|
certDNs[i] = cert.Subject.String()
|
|
}
|
|
s.trace(ctx, "TLS chain: %s", strings.Join(certDNs, ", "))
|
|
|
|
clientInfo, err := parseClientName(s.opts.pkiRealm, chain[0].Subject.CommonName)
|
|
if err != nil {
|
|
s.trace(ctx, "Could not parse certificate DN: %v", err)
|
|
return nil, status.Error(codes.InvalidArgument, "invalid TLS CommonName")
|
|
}
|
|
ctx = withPKIInfo(ctx, clientInfo)
|
|
|
|
return handler(ctx, req)
|
|
}
|