forked from hswaw/hscloud
Move pki.go into code.hackerspace.pl/q3k/hspki
parent
0ca40feb95
commit
f9d85cf585
4
grpc.go
4
grpc.go
|
@ -9,6 +9,7 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
"code.hackerspace.pl/q3k/hspki"
|
||||||
"github.com/golang/glog"
|
"github.com/golang/glog"
|
||||||
"github.com/q3k/statusz"
|
"github.com/q3k/statusz"
|
||||||
"golang.org/x/net/trace"
|
"golang.org/x/net/trace"
|
||||||
|
@ -25,7 +26,6 @@ type serverOpts struct {
|
||||||
tlsCAPath string
|
tlsCAPath string
|
||||||
tlsCertificatePath string
|
tlsCertificatePath string
|
||||||
tlsKeyPath string
|
tlsKeyPath string
|
||||||
pkiRealm string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type server struct {
|
type server struct {
|
||||||
|
@ -110,7 +110,7 @@ func (s *server) setupDebugHTTP(mux http.Handler) error {
|
||||||
func (s *server) serveForever() {
|
func (s *server) serveForever() {
|
||||||
grpc.EnableTracing = true
|
grpc.EnableTracing = true
|
||||||
|
|
||||||
if err := s.setupGRPC(grpc.UnaryInterceptor(s.unaryInterceptor)); err != nil {
|
if err := s.setupGRPC(hspki.WithServerHSPKI()); err != nil {
|
||||||
glog.Exitf("Could not setup GRPC server: %v", err)
|
glog.Exitf("Could not setup GRPC server: %v", err)
|
||||||
}
|
}
|
||||||
pb.RegisterAristaProxyServer(s.grpc.server, s)
|
pb.RegisterAristaProxyServer(s.grpc.server, s)
|
||||||
|
|
3
main.go
3
main.go
|
@ -15,7 +15,6 @@ var (
|
||||||
flagCAPath string
|
flagCAPath string
|
||||||
flagCertificatePath string
|
flagCertificatePath string
|
||||||
flagKeyPath string
|
flagKeyPath string
|
||||||
flagPKIRealm string
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type aristaClient struct {
|
type aristaClient struct {
|
||||||
|
@ -47,7 +46,6 @@ func main() {
|
||||||
flag.StringVar(&flagCAPath, "tls_ca_path", "pki/ca.pem", "Path to PKI CA certificate")
|
flag.StringVar(&flagCAPath, "tls_ca_path", "pki/ca.pem", "Path to PKI CA certificate")
|
||||||
flag.StringVar(&flagCertificatePath, "tls_certificate_path", "pki/service.pem", "Path to PKI service certificate")
|
flag.StringVar(&flagCertificatePath, "tls_certificate_path", "pki/service.pem", "Path to PKI service certificate")
|
||||||
flag.StringVar(&flagKeyPath, "tls_key_path", "pki/service-key.pem", "Path to PKI service private key")
|
flag.StringVar(&flagKeyPath, "tls_key_path", "pki/service-key.pem", "Path to PKI service private key")
|
||||||
flag.StringVar(&flagPKIRealm, "pki_realm", "svc.cluster.local", "PKI realm")
|
|
||||||
flag.Set("logtostderr", "true")
|
flag.Set("logtostderr", "true")
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
|
@ -61,7 +59,6 @@ func main() {
|
||||||
tlsCAPath: flagCAPath,
|
tlsCAPath: flagCAPath,
|
||||||
tlsCertificatePath: flagCertificatePath,
|
tlsCertificatePath: flagCertificatePath,
|
||||||
tlsKeyPath: flagKeyPath,
|
tlsKeyPath: flagKeyPath,
|
||||||
pkiRealm: flagPKIRealm,
|
|
||||||
}
|
}
|
||||||
server, err := newServer(opts, arista)
|
server, err := newServer(opts, arista)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
88
pki.go
88
pki.go
|
@ -1,88 +0,0 @@
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"golang.org/x/net/trace"
|
|
||||||
"google.golang.org/grpc"
|
|
||||||
"google.golang.org/grpc/codes"
|
|
||||||
"google.golang.org/grpc/credentials"
|
|
||||||
"google.golang.org/grpc/peer"
|
|
||||||
"google.golang.org/grpc/status"
|
|
||||||
)
|
|
||||||
|
|
||||||
type clientPKIInfo struct {
|
|
||||||
realm string
|
|
||||||
principal string
|
|
||||||
job string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *clientPKIInfo) String() string {
|
|
||||||
return fmt.Sprintf("job=%q, principal=%q, realm=%q", c.job, c.principal, c.realm)
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseClientName(realm, name string) (*clientPKIInfo, error) {
|
|
||||||
if !strings.HasSuffix(name, "."+realm) {
|
|
||||||
return nil, fmt.Errorf("invalid realm")
|
|
||||||
}
|
|
||||||
service := strings.TrimSuffix(name, "."+realm)
|
|
||||||
parts := strings.Split(service, ".")
|
|
||||||
if len(parts) != 2 {
|
|
||||||
return nil, fmt.Errorf("invalid service")
|
|
||||||
}
|
|
||||||
return &clientPKIInfo{
|
|
||||||
realm: realm,
|
|
||||||
principal: parts[1],
|
|
||||||
job: parts[0],
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
ctxKeyPKIInfo = "hscloud-pki-info"
|
|
||||||
)
|
|
||||||
|
|
||||||
func withPKIInfo(ctx context.Context, c *clientPKIInfo) context.Context {
|
|
||||||
tr, ok := trace.FromContext(ctx)
|
|
||||||
if ok {
|
|
||||||
tr.LazyPrintf("PKI Peer: %s", c.String())
|
|
||||||
}
|
|
||||||
return context.WithValue(ctx, ctxKeyPKIInfo, c)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *server) unaryInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
|
|
||||||
peer, ok := peer.FromContext(ctx)
|
|
||||||
if !ok {
|
|
||||||
s.trace(ctx, "Could not establish identity of peer.")
|
|
||||||
return nil, status.Error(codes.InvalidArgument, "no peer info")
|
|
||||||
}
|
|
||||||
|
|
||||||
authInfo, ok := peer.AuthInfo.(credentials.TLSInfo)
|
|
||||||
if !ok {
|
|
||||||
s.trace(ctx, "Could not establish TLS identity of peer.")
|
|
||||||
return nil, status.Error(codes.InvalidArgument, "no TLS peer info")
|
|
||||||
}
|
|
||||||
|
|
||||||
chains := authInfo.State.VerifiedChains
|
|
||||||
if len(chains) != 1 {
|
|
||||||
s.trace(ctx, "No trusted chain found.")
|
|
||||||
return nil, status.Error(codes.InvalidArgument, "invalid TLS certificate")
|
|
||||||
}
|
|
||||||
chain := chains[0]
|
|
||||||
|
|
||||||
certDNs := make([]string, len(chain))
|
|
||||||
for i, cert := range chain {
|
|
||||||
certDNs[i] = cert.Subject.String()
|
|
||||||
}
|
|
||||||
s.trace(ctx, "TLS chain: %s", strings.Join(certDNs, ", "))
|
|
||||||
|
|
||||||
clientInfo, err := parseClientName(s.opts.pkiRealm, chain[0].Subject.CommonName)
|
|
||||||
if err != nil {
|
|
||||||
s.trace(ctx, "Could not parse certificate DN: %v", err)
|
|
||||||
return nil, status.Error(codes.InvalidArgument, "invalid TLS CommonName")
|
|
||||||
}
|
|
||||||
ctx = withPKIInfo(ctx, clientInfo)
|
|
||||||
|
|
||||||
return handler(ctx, req)
|
|
||||||
}
|
|
Loading…
Reference in New Issue