Skip to content

Commit

Permalink
fix: v1.0 version fallback checks if v1.0 is supported by client
Browse files Browse the repository at this point in the history
Signed-off-by: Pierre-Henri Symoneaux <[email protected]>
  • Loading branch information
phsym committed Mar 10, 2025
1 parent cb4fafd commit 0a5893a
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 13 deletions.
9 changes: 6 additions & 3 deletions kmipclient/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,7 @@ func (c *Client) CloneCtx(ctx context.Context) (*Client, error) {
dialer: c.dialer,
middlewares: slices.Clone(c.middlewares),
conn: stream,
addr: c.addr,
}, nil
}

Expand Down Expand Up @@ -379,10 +380,12 @@ func (c *Client) negotiateVersion(ctx context.Context) error {
return errors.New("Unexpected batch item count")
}
bi := resp.BatchItem[0]
if bi.ResultStatus == kmip.ResultStatusOperationFailed &&
(bi.ResultReason == kmip.ResultReasonOperationNotSupported /*|| bi.ResultReason == kmip.ReasonInvalidMessage && bi.Operation == 0x00*/) {
if bi.ResultStatus == kmip.ResultStatusOperationFailed && bi.ResultReason == kmip.ResultReasonOperationNotSupported {
// If the discover opertion is not supported, then fallbacks to kmip v1.0
// TODO: Check that v1.0 is in the client's supported version list and return an error if not
// but also check that v1.0 is in the client's supported version list and return an error if not.
if !slices.Contains(c.supportedVersions, kmip.V1_0) {
return errors.New("Protocol version negotiation failed. No common version found")
}
c.version = &kmip.V1_0
return nil
}
Expand Down
83 changes: 73 additions & 10 deletions kmipclient/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,21 @@ package kmipclient_test

import (
"context"
"os"
"sync"
"testing"
"time"

"github.com/ovh/kmip-go"
"github.com/ovh/kmip-go/kmipclient"
"github.com/ovh/kmip-go/kmipserver"
"github.com/ovh/kmip-go/kmiptest"
"github.com/ovh/kmip-go/payloads"
"github.com/ovh/kmip-go/ttlv"

"github.com/stretchr/testify/require"
)

// func testClientRequest[Req, Resp kmip.OperationPayload](t *testing.T, tf func(*kmipclient.Client) *kmipclient.Executor[Req, Resp], f func(*testing.T, Req) (Resp, error)) (Resp, error) {
// mux := kmipserver.NewBatchExecutor()
// client := kmiptest.NewClientAndServer(t, mux)
// req := tf(client)
// mux.Route(req.RequestPayload().Operation(), kmipserver.HandleFunc(func(ctx context.Context, pl Req) (Resp, error) {
// return f(t, pl)
// }))
// return req.Exec()
// }

func TestRequest_ContextTimeout(t *testing.T) {
mux := kmipserver.NewBatchExecutor()
client := kmiptest.NewClientAndServer(t, mux)
Expand Down Expand Up @@ -215,3 +208,73 @@ func TestClone(t *testing.T) {
_, err = client3.Request(context.Background(), &payloads.DiscoverVersionsRequestPayload{})
require.NoError(t, err)
}

func TestVersionNegociation(t *testing.T) {
router := kmipserver.NewBatchExecutor()
router.Route(kmip.OperationDiscoverVersions, kmipserver.HandleFunc(func(ctx context.Context, pl *payloads.DiscoverVersionsRequestPayload) (*payloads.DiscoverVersionsResponsePayload, error) {
return &payloads.DiscoverVersionsResponsePayload{
ProtocolVersion: []kmip.ProtocolVersion{
kmip.V1_3, kmip.V1_2,
},
}, nil
}))
addr, ca := kmiptest.NewServer(t, router)
client, err := kmipclient.Dial(
addr,
kmipclient.WithRootCAPem([]byte(ca)),
kmipclient.WithMiddlewares(
kmipclient.DebugMiddleware(os.Stderr, ttlv.MarshalXML),
),
kmipclient.WithKmipVersions(kmip.V1_2, kmip.V1_3),
)
require.NoError(t, err)
require.NotNil(t, client)
require.EqualValues(t, client.Version(), kmip.V1_3)
}

func TestVersionNegociation_NoCommon(t *testing.T) {
router := kmipserver.NewBatchExecutor()
router.Route(kmip.OperationDiscoverVersions, kmipserver.HandleFunc(func(ctx context.Context, pl *payloads.DiscoverVersionsRequestPayload) (*payloads.DiscoverVersionsResponsePayload, error) {
return &payloads.DiscoverVersionsResponsePayload{
ProtocolVersion: []kmip.ProtocolVersion{},
}, nil
}))
addr, ca := kmiptest.NewServer(t, router)
client, err := kmipclient.Dial(
addr,
kmipclient.WithRootCAPem([]byte(ca)),
kmipclient.WithMiddlewares(
kmipclient.DebugMiddleware(os.Stderr, ttlv.MarshalXML),
),
kmipclient.WithKmipVersions(kmip.V1_1, kmip.V1_2),
)
require.Error(t, err)
require.Nil(t, client)
}

func TestVersionNegociation_v1_0_Fallback(t *testing.T) {
router := kmipserver.NewBatchExecutor()
router.Route(kmip.OperationDiscoverVersions, kmipserver.HandleFunc(func(ctx context.Context, pl *payloads.DiscoverVersionsRequestPayload) (*payloads.DiscoverVersionsResponsePayload, error) {
return nil, kmipserver.ErrOperationNotSupported
}))
client := kmiptest.NewClientAndServer(t, router)
require.EqualValues(t, client.Version(), kmip.V1_0)
}

func TestVersionNegociation_v1_0_Fallback_unsupported(t *testing.T) {
router := kmipserver.NewBatchExecutor()
router.Route(kmip.OperationDiscoverVersions, kmipserver.HandleFunc(func(ctx context.Context, pl *payloads.DiscoverVersionsRequestPayload) (*payloads.DiscoverVersionsResponsePayload, error) {
return nil, kmipserver.ErrOperationNotSupported
}))
addr, ca := kmiptest.NewServer(t, router)
client, err := kmipclient.Dial(
addr,
kmipclient.WithRootCAPem([]byte(ca)),
kmipclient.WithMiddlewares(
kmipclient.DebugMiddleware(os.Stderr, ttlv.MarshalXML),
),
kmipclient.WithKmipVersions(kmip.V1_3, kmip.V1_4),
)
require.Error(t, err)
require.Nil(t, client)
}

0 comments on commit 0a5893a

Please sign in to comment.