From 0a5893a6838858ed01392494702d611620204f72 Mon Sep 17 00:00:00 2001 From: Pierre-Henri Symoneaux Date: Mon, 10 Mar 2025 09:54:36 +0100 Subject: [PATCH] fix: v1.0 version fallback checks if v1.0 is supported by client Signed-off-by: Pierre-Henri Symoneaux --- kmipclient/client.go | 9 +++-- kmipclient/client_test.go | 83 ++++++++++++++++++++++++++++++++++----- 2 files changed, 79 insertions(+), 13 deletions(-) diff --git a/kmipclient/client.go b/kmipclient/client.go index 3a81cac..453d669 100644 --- a/kmipclient/client.go +++ b/kmipclient/client.go @@ -294,6 +294,7 @@ func (c *Client) CloneCtx(ctx context.Context) (*Client, error) { dialer: c.dialer, middlewares: slices.Clone(c.middlewares), conn: stream, + addr: c.addr, }, nil } @@ -379,10 +380,12 @@ func (c *Client) negotiateVersion(ctx context.Context) error { return errors.New("Unexpected batch item count") } bi := resp.BatchItem[0] - if bi.ResultStatus == kmip.ResultStatusOperationFailed && - (bi.ResultReason == kmip.ResultReasonOperationNotSupported /*|| bi.ResultReason == kmip.ReasonInvalidMessage && bi.Operation == 0x00*/) { + if bi.ResultStatus == kmip.ResultStatusOperationFailed && bi.ResultReason == kmip.ResultReasonOperationNotSupported { // If the discover opertion is not supported, then fallbacks to kmip v1.0 - // TODO: Check that v1.0 is in the client's supported version list and return an error if not + // but also check that v1.0 is in the client's supported version list and return an error if not. + if !slices.Contains(c.supportedVersions, kmip.V1_0) { + return errors.New("Protocol version negotiation failed. No common version found") + } c.version = &kmip.V1_0 return nil } diff --git a/kmipclient/client_test.go b/kmipclient/client_test.go index d4de1c1..981b54f 100644 --- a/kmipclient/client_test.go +++ b/kmipclient/client_test.go @@ -2,28 +2,21 @@ package kmipclient_test import ( "context" + "os" "sync" "testing" "time" "github.com/ovh/kmip-go" + "github.com/ovh/kmip-go/kmipclient" "github.com/ovh/kmip-go/kmipserver" "github.com/ovh/kmip-go/kmiptest" "github.com/ovh/kmip-go/payloads" + "github.com/ovh/kmip-go/ttlv" "github.com/stretchr/testify/require" ) -// func testClientRequest[Req, Resp kmip.OperationPayload](t *testing.T, tf func(*kmipclient.Client) *kmipclient.Executor[Req, Resp], f func(*testing.T, Req) (Resp, error)) (Resp, error) { -// mux := kmipserver.NewBatchExecutor() -// client := kmiptest.NewClientAndServer(t, mux) -// req := tf(client) -// mux.Route(req.RequestPayload().Operation(), kmipserver.HandleFunc(func(ctx context.Context, pl Req) (Resp, error) { -// return f(t, pl) -// })) -// return req.Exec() -// } - func TestRequest_ContextTimeout(t *testing.T) { mux := kmipserver.NewBatchExecutor() client := kmiptest.NewClientAndServer(t, mux) @@ -215,3 +208,73 @@ func TestClone(t *testing.T) { _, err = client3.Request(context.Background(), &payloads.DiscoverVersionsRequestPayload{}) require.NoError(t, err) } + +func TestVersionNegociation(t *testing.T) { + router := kmipserver.NewBatchExecutor() + router.Route(kmip.OperationDiscoverVersions, kmipserver.HandleFunc(func(ctx context.Context, pl *payloads.DiscoverVersionsRequestPayload) (*payloads.DiscoverVersionsResponsePayload, error) { + return &payloads.DiscoverVersionsResponsePayload{ + ProtocolVersion: []kmip.ProtocolVersion{ + kmip.V1_3, kmip.V1_2, + }, + }, nil + })) + addr, ca := kmiptest.NewServer(t, router) + client, err := kmipclient.Dial( + addr, + kmipclient.WithRootCAPem([]byte(ca)), + kmipclient.WithMiddlewares( + kmipclient.DebugMiddleware(os.Stderr, ttlv.MarshalXML), + ), + kmipclient.WithKmipVersions(kmip.V1_2, kmip.V1_3), + ) + require.NoError(t, err) + require.NotNil(t, client) + require.EqualValues(t, client.Version(), kmip.V1_3) +} + +func TestVersionNegociation_NoCommon(t *testing.T) { + router := kmipserver.NewBatchExecutor() + router.Route(kmip.OperationDiscoverVersions, kmipserver.HandleFunc(func(ctx context.Context, pl *payloads.DiscoverVersionsRequestPayload) (*payloads.DiscoverVersionsResponsePayload, error) { + return &payloads.DiscoverVersionsResponsePayload{ + ProtocolVersion: []kmip.ProtocolVersion{}, + }, nil + })) + addr, ca := kmiptest.NewServer(t, router) + client, err := kmipclient.Dial( + addr, + kmipclient.WithRootCAPem([]byte(ca)), + kmipclient.WithMiddlewares( + kmipclient.DebugMiddleware(os.Stderr, ttlv.MarshalXML), + ), + kmipclient.WithKmipVersions(kmip.V1_1, kmip.V1_2), + ) + require.Error(t, err) + require.Nil(t, client) +} + +func TestVersionNegociation_v1_0_Fallback(t *testing.T) { + router := kmipserver.NewBatchExecutor() + router.Route(kmip.OperationDiscoverVersions, kmipserver.HandleFunc(func(ctx context.Context, pl *payloads.DiscoverVersionsRequestPayload) (*payloads.DiscoverVersionsResponsePayload, error) { + return nil, kmipserver.ErrOperationNotSupported + })) + client := kmiptest.NewClientAndServer(t, router) + require.EqualValues(t, client.Version(), kmip.V1_0) +} + +func TestVersionNegociation_v1_0_Fallback_unsupported(t *testing.T) { + router := kmipserver.NewBatchExecutor() + router.Route(kmip.OperationDiscoverVersions, kmipserver.HandleFunc(func(ctx context.Context, pl *payloads.DiscoverVersionsRequestPayload) (*payloads.DiscoverVersionsResponsePayload, error) { + return nil, kmipserver.ErrOperationNotSupported + })) + addr, ca := kmiptest.NewServer(t, router) + client, err := kmipclient.Dial( + addr, + kmipclient.WithRootCAPem([]byte(ca)), + kmipclient.WithMiddlewares( + kmipclient.DebugMiddleware(os.Stderr, ttlv.MarshalXML), + ), + kmipclient.WithKmipVersions(kmip.V1_3, kmip.V1_4), + ) + require.Error(t, err) + require.Nil(t, client) +}