Skip to content

Commit

Permalink
[v16] tsh: Deduplicate the list of request IDs (#47121)
Browse files Browse the repository at this point in the history
* tsh: Deduplicate the list of request IDs

It's possible to specify the same request multiple times with
tsh request create. The duplicates eventually get resolved before
we generate a certificate, but they do exist in the access request
resource. This can cause the size of the resource to exceed the
limits of a gRPC message and break listing.

* Limit access requests to a maximum of 300 resources

* Only run deduplication when the request is being created
  • Loading branch information
zmb3 authored Oct 4, 2024
1 parent eb27333 commit 18a73ad
Show file tree
Hide file tree
Showing 4 changed files with 234 additions and 162 deletions.
20 changes: 18 additions & 2 deletions lib/services/access_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ import (

const (
maxAccessRequestReasonSize = 4096
maxResourcesPerRequest = 300

// A day is sometimes 23 hours, sometimes 25 hours, usually 24 hours.
day = 24 * time.Hour
Expand All @@ -67,14 +68,17 @@ func ValidateAccessRequest(ar types.AccessRequest) error {

_, err := uuid.Parse(ar.GetName())
if err != nil {
return trace.BadParameter("invalid access request id %q", ar.GetName())
return trace.BadParameter("invalid access request ID %q", ar.GetName())
}
if len(ar.GetRequestReason()) > maxAccessRequestReasonSize {
return trace.BadParameter("access request reason is too long, max %v bytes", maxAccessRequestReasonSize)
}
if len(ar.GetResolveReason()) > maxAccessRequestReasonSize {
return trace.BadParameter("access request resolve reason is too long, max %v bytes", maxAccessRequestReasonSize)
}
if l := len(ar.GetRequestedResourceIDs()); l > maxResourcesPerRequest {
return trace.BadParameter("access request contains too many resources (%v), max %v", l, maxResourcesPerRequest)
}
return nil
}

Expand Down Expand Up @@ -1110,7 +1114,6 @@ func (m *RequestValidator) Validate(ctx context.Context, req types.AccessRequest
// need to be expanded into a list consisting of all existing roles
// that the user does not hold and is allowed to request.
if r := req.GetRoles(); len(r) == 1 && r[0] == types.Wildcard {

if !req.GetState().IsPending() {
// expansion is only permitted in pending requests. once resolved,
// a request's role list must be immutable.
Expand Down Expand Up @@ -1152,6 +1155,19 @@ func (m *RequestValidator) Validate(ctx context.Context, req types.AccessRequest
}

if m.opts.expandVars {
// deduplicate requested resource IDs
var deduplicated []types.ResourceID
seen := make(map[string]struct{})
for _, resource := range req.GetRequestedResourceIDs() {
id := types.ResourceIDToString(resource)
if _, isDuplicate := seen[id]; isDuplicate {
continue
}
seen[id] = struct{}{}
deduplicated = append(deduplicated, resource)
}
req.SetRequestedResourceIDs(deduplicated)

// determine the roles which should be requested for a resource access
// request, and write them to the request
if err := m.setRolesForResourceRequest(ctx, req); err != nil {
Expand Down
71 changes: 65 additions & 6 deletions lib/services/access_request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -792,12 +792,7 @@ func TestMaxLength(t *testing.T) {
req, err := types.NewAccessRequest("some-id", "dave", "dictator", "never")
require.NoError(t, err)

var s []byte
for i := 0; i <= maxAccessRequestReasonSize; i++ {
s = append(s, 'a')
}

req.SetRequestReason(string(s))
req.SetRequestReason(strings.Repeat("a", maxAccessRequestReasonSize))
require.Error(t, ValidateAccessRequest(req))
}

Expand Down Expand Up @@ -2155,6 +2150,70 @@ func (mcg mockClusterGetter) GetRemoteCluster(ctx context.Context, clusterName s
return nil, trace.NotFound("remote cluster %q was not found", clusterName)
}

func TestValidateDuplicateRequestedResources(t *testing.T) {
g := &mockGetter{
roles: make(map[string]types.Role),
userStates: make(map[string]*userloginstate.UserLoginState),
users: make(map[string]types.User),
nodes: make(map[string]types.Server),
kubeServers: make(map[string]types.KubeServer),
dbServers: make(map[string]types.DatabaseServer),
appServers: make(map[string]types.AppServer),
desktops: make(map[string]types.WindowsDesktop),
clusterName: "someCluster",
}

for i := 1; i < 3; i++ {
node, err := types.NewServerWithLabels(
fmt.Sprintf("resource%d", i),
types.KindNode,
types.ServerSpecV2{},
map[string]string{"foo": "bar"},
)
require.NoError(t, err)
g.nodes[node.GetName()] = node
}

searchAsRole, err := types.NewRole("searchAs", types.RoleSpecV6{
Allow: types.RoleConditions{
Logins: []string{"root"},
NodeLabels: types.Labels{"*": []string{"*"}},
},
})
require.NoError(t, err)
g.roles[searchAsRole.GetName()] = searchAsRole

testRole, err := types.NewRole("testRole", types.RoleSpecV6{
Allow: types.RoleConditions{
Request: &types.AccessRequestConditions{
SearchAsRoles: []string{searchAsRole.GetName()},
},
},
})
require.NoError(t, err)
g.roles[testRole.GetName()] = testRole

user := g.user(t, testRole.GetName())

clock := clockwork.NewFakeClock()
identity := tlsca.Identity{
Expires: clock.Now().UTC().Add(8 * time.Hour),
}

req, err := types.NewAccessRequestWithResources("name", user, nil, /* roles */
[]types.ResourceID{
{ClusterName: "someCluster", Kind: "node", Name: "resource1"},
{ClusterName: "someCluster", Kind: "node", Name: "resource1"}, // a duplicate
{ClusterName: "someCluster", Kind: "node", Name: "resource2"}, // not a duplicate
})
require.NoError(t, err)

require.NoError(t, ValidateAccessRequestForUser(context.Background(), clock, g, req, identity, ExpandVars(true)))
require.Len(t, req.GetRequestedResourceIDs(), 2)
require.Equal(t, "/someCluster/node/resource1", types.ResourceIDToString(req.GetRequestedResourceIDs()[0]))
require.Equal(t, "/someCluster/node/resource2", types.ResourceIDToString(req.GetRequestedResourceIDs()[1]))
}

func TestValidateAccessRequestClusterNames(t *testing.T) {
for _, tc := range []struct {
name string
Expand Down
1 change: 1 addition & 0 deletions tool/tsh/common/tsh.go
Original file line number Diff line number Diff line change
Expand Up @@ -2546,6 +2546,7 @@ func createAccessRequest(cf *CLIConf) (types.AccessRequest, error) {
if err != nil {
return nil, trace.Wrap(err)
}

req, err := services.NewAccessRequestWithResources(cf.Username, roles, requestedResourceIDs)
if err != nil {
return nil, trace.Wrap(err)
Expand Down
Loading

0 comments on commit 18a73ad

Please sign in to comment.