Set APIVersion on the client, even when Ping fails

Refactor to support testing
Also add tests

Signed-off-by: Daniel Nephin <dnephin@docker.com>
master
Daniel Nephin 2017-09-20 16:13:03 -04:00
parent 10e292dbab
commit e828efa4ab
2 changed files with 177 additions and 26 deletions

View File

@ -17,6 +17,7 @@ import (
"github.com/docker/docker/client"
"github.com/docker/go-connections/sockets"
"github.com/docker/go-connections/tlsconfig"
"github.com/docker/notary"
"github.com/docker/notary/passphrase"
"github.com/pkg/errors"
"github.com/spf13/cobra"
@ -111,44 +112,53 @@ func (cli *DockerCli) Initialize(opts *cliflags.ClientOptions) error {
var err error
cli.client, err = NewAPIClientFromFlags(opts.Common, cli.configFile)
if tlsconfig.IsErrEncryptedKey(err) {
var (
passwd string
giveup bool
)
passRetriever := passphrase.PromptRetrieverWithInOut(cli.In(), cli.Out(), nil)
for attempts := 0; tlsconfig.IsErrEncryptedKey(err); attempts++ {
// some code and comments borrowed from notary/trustmanager/keystore.go
passwd, giveup, err = passRetriever("private", "encrypted TLS private", false, attempts)
// Check if the passphrase retriever got an error or if it is telling us to give up
if giveup || err != nil {
return errors.Wrap(err, "private key is encrypted, but could not get passphrase")
}
opts.Common.TLSOptions.Passphrase = passwd
cli.client, err = NewAPIClientFromFlags(opts.Common, cli.configFile)
newClient := func(password string) (client.APIClient, error) {
opts.Common.TLSOptions.Passphrase = password
return NewAPIClientFromFlags(opts.Common, cli.configFile)
}
cli.client, err = getClientWithPassword(passRetriever, newClient)
}
if err != nil {
return err
}
cli.initializeFromClient()
return nil
}
func (cli *DockerCli) initializeFromClient() {
cli.defaultVersion = cli.client.ClientVersion()
if ping, err := cli.client.Ping(context.Background()); err == nil {
cli.server = ServerInfo{
HasExperimental: ping.Experimental,
OSType: ping.OSType,
}
cli.client.NegotiateAPIVersionPing(ping)
} else {
ping, err := cli.client.Ping(context.Background())
if err != nil {
// Default to true if we fail to connect to daemon
cli.server = ServerInfo{HasExperimental: true}
if ping.APIVersion != "" {
cli.client.NegotiateAPIVersionPing(ping)
}
return
}
return nil
cli.server = ServerInfo{
HasExperimental: ping.Experimental,
OSType: ping.OSType,
}
cli.client.NegotiateAPIVersionPing(ping)
}
func getClientWithPassword(passRetriever notary.PassRetriever, newClient func(password string) (client.APIClient, error)) (client.APIClient, error) {
for attempts := 0; ; attempts++ {
passwd, giveup, err := passRetriever("private", "encrypted TLS private", false, attempts)
if giveup || err != nil {
return nil, errors.Wrap(err, "private key is encrypted, but could not get passphrase")
}
apiclient, err := newClient(passwd)
if !tlsconfig.IsErrEncryptedKey(err) {
return apiclient, err
}
}
}
// ServerInfo stores details about the supported features and platform of the

View File

@ -4,12 +4,18 @@ import (
"os"
"testing"
"crypto/x509"
"github.com/docker/cli/cli/config/configfile"
"github.com/docker/cli/cli/flags"
"github.com/docker/cli/internal/test/testutil"
"github.com/docker/docker/api"
"github.com/docker/docker/api/types"
"github.com/docker/docker/client"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/net/context"
)
func TestNewAPIClientFromFlags(t *testing.T) {
@ -43,7 +49,7 @@ func TestNewAPIClientFromFlagsWithAPIVersionFromEnv(t *testing.T) {
assert.Equal(t, customVersion, apiclient.ClientVersion())
}
// TODO: move to gotestyourself
// TODO: use gotestyourself/env.Patch
func patchEnvVariable(t *testing.T, key, value string) func() {
oldValue, ok := os.LookupEnv(key)
require.NoError(t, os.Setenv(key, value))
@ -55,3 +61,138 @@ func patchEnvVariable(t *testing.T, key, value string) func() {
require.NoError(t, os.Setenv(key, oldValue))
}
}
type fakeClient struct {
client.Client
pingFunc func() (types.Ping, error)
version string
negotiated bool
}
func (c *fakeClient) Ping(_ context.Context) (types.Ping, error) {
return c.pingFunc()
}
func (c *fakeClient) ClientVersion() string {
return c.version
}
func (c *fakeClient) NegotiateAPIVersionPing(types.Ping) {
c.negotiated = true
}
func TestInitializeFromClient(t *testing.T) {
defaultVersion := "v1.55"
var testcases = []struct {
doc string
pingFunc func() (types.Ping, error)
expectedServer ServerInfo
negotiated bool
}{
{
doc: "successful ping",
pingFunc: func() (types.Ping, error) {
return types.Ping{Experimental: true, OSType: "linux", APIVersion: "v1.30"}, nil
},
expectedServer: ServerInfo{HasExperimental: true, OSType: "linux"},
negotiated: true,
},
{
doc: "failed ping, no API version",
pingFunc: func() (types.Ping, error) {
return types.Ping{}, errors.New("failed")
},
expectedServer: ServerInfo{HasExperimental: true},
},
{
doc: "failed ping, with API version",
pingFunc: func() (types.Ping, error) {
return types.Ping{APIVersion: "v1.33"}, errors.New("failed")
},
expectedServer: ServerInfo{HasExperimental: true},
negotiated: true,
},
}
for _, testcase := range testcases {
t.Run(testcase.doc, func(t *testing.T) {
apiclient := &fakeClient{
pingFunc: testcase.pingFunc,
version: defaultVersion,
}
cli := &DockerCli{client: apiclient}
cli.initializeFromClient()
assert.Equal(t, defaultVersion, cli.defaultVersion)
assert.Equal(t, testcase.expectedServer, cli.server)
assert.Equal(t, testcase.negotiated, apiclient.negotiated)
})
}
}
func TestGetClientWithPassword(t *testing.T) {
expected := "password"
var testcases = []struct {
doc string
password string
retrieverErr error
retrieverGiveup bool
newClientErr error
expectedErr string
}{
{
doc: "successful connect",
password: expected,
},
{
doc: "password retriever exhausted",
retrieverGiveup: true,
retrieverErr: errors.New("failed"),
expectedErr: "private key is encrypted, but could not get passphrase",
},
{
doc: "password retriever error",
retrieverErr: errors.New("failed"),
expectedErr: "failed",
},
{
doc: "newClient error",
newClientErr: errors.New("failed to connect"),
expectedErr: "failed to connect",
},
}
for _, testcase := range testcases {
t.Run(testcase.doc, func(t *testing.T) {
passRetriever := func(_, _ string, _ bool, attempts int) (passphrase string, giveup bool, err error) {
// Always return an invalid pass first to test iteration
switch attempts {
case 0:
return "something else", false, nil
default:
return testcase.password, testcase.retrieverGiveup, testcase.retrieverErr
}
}
newClient := func(currentPassword string) (client.APIClient, error) {
if testcase.newClientErr != nil {
return nil, testcase.newClientErr
}
if currentPassword == expected {
return &client.Client{}, nil
}
return &client.Client{}, x509.IncorrectPasswordError
}
_, err := getClientWithPassword(passRetriever, newClient)
if testcase.expectedErr != "" {
testutil.ErrorContains(t, err, testcase.expectedErr)
return
}
assert.NoError(t, err)
})
}
}