diff --git a/go/pki/grpc.go b/go/pki/grpc.go index 313f4a93..44099c00 100644 --- a/go/pki/grpc.go +++ b/go/pki/grpc.go @@ -235,7 +235,15 @@ func WithServerHSPKI() []grpc.ServerOption { return []grpc.ServerOption{creds, interceptor} } -func WithClientHSPKI() grpc.DialOption { +type ClientHSPKIOption func(c *tls.Config) + +func OverrideServerName(name string) ClientHSPKIOption { + return func(c *tls.Config) { + c.ServerName = name + } +} + +func WithClientHSPKI(opts ...ClientHSPKIOption) grpc.DialOption { if !flag.Parsed() { glog.Exitf("WithServerHSPKI called before flag.Parse!") } @@ -258,9 +266,15 @@ func WithClientHSPKI() grpc.DialOption { glog.Exitf("WithClientHSPKI: cannot load service certificate/key: %v", err) } - creds := credentials.NewTLS(&tls.Config{ + config := &tls.Config{ Certificates: []tls.Certificate{clientCert}, RootCAs: certPool, - }) + } + + for _, opt := range opts { + opt(config) + } + + creds := credentials.NewTLS(config) return grpc.WithTransportCredentials(creds) } diff --git a/go/pki/locate.go b/go/pki/locate.go index e48e013b..3b4ca294 100644 --- a/go/pki/locate.go +++ b/go/pki/locate.go @@ -1,6 +1,8 @@ package pki import ( + "crypto/tls" + "crypto/x509" "fmt" "io/ioutil" "os" @@ -20,6 +22,24 @@ func DeveloperCredentialsLocation() (string, error) { return fmt.Sprintf("%s/hspki", cfgDir), nil } +// DeveloperCredentialsPrincipal returns the principal/DN for which the local +// developer credentials are provisioned. +func DeveloperCredentialsPrincipal() (string, error) { + creds, err := loadDeveloperCredentials() + if err != nil { + return "", fmt.Errorf("when loading developer credentials: %w", err) + } + pair, err := tls.X509KeyPair(creds.cert, creds.key) + if err != nil { + return "", fmt.Errorf("when loading developer client cert: %w", err) + } + cert, err := x509.ParseCertificate(pair.Certificate[0]) + if err != nil { + return "", fmt.Errorf("when parsing developer client cert: %w", err) + } + return cert.Subject.CommonName, nil +} + type creds struct { ca []byte cert []byte