Skip to content
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 6 additions & 9 deletions internal/ghmcp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,13 @@ func createGitHubClients(cfg MCPServerConfig, apiHost apiHost) (*githubClients,

// Construct GraphQL client
// We use NewEnterpriseClient unconditionally since we already parsed the API host
// Layer transports: DefaultTransport -> bearerAuthTransport -> GraphQLFeaturesTransport
gqlHTTPClient := &http.Client{
Transport: &bearerAuthTransport{
transport: http.DefaultTransport,
token: cfg.Token,
Transport: &github.GraphQLFeaturesTransport{
Transport: &bearerAuthTransport{
transport: http.DefaultTransport,
token: cfg.Token,
},
},
}
gqlClient := githubv4.NewEnterpriseClient(apiHost.graphqlURL.String(), gqlHTTPClient)
Expand Down Expand Up @@ -622,12 +625,6 @@ type bearerAuthTransport struct {
func (t *bearerAuthTransport) RoundTrip(req *http.Request) (*http.Response, error) {
req = req.Clone(req.Context())
req.Header.Set("Authorization", "Bearer "+t.token)

// Check for GraphQL-Features in context and add header if present
if features := github.GetGraphQLFeatures(req.Context()); len(features) > 0 {
req.Header.Set("GraphQL-Features", strings.Join(features, ", "))
}

return t.transport.RoundTrip(req)
}

Expand Down
13 changes: 10 additions & 3 deletions pkg/github/issues.go
Original file line number Diff line number Diff line change
Expand Up @@ -1914,15 +1914,22 @@ func AssignCodingAgentPrompt(t translations.TranslationHelperFunc) inventory.Ser
)
}

// graphQLFeaturesKey is a context key for GraphQL feature flags
// graphQLFeaturesKey is a context key for GraphQL feature flags.
// These flags enable preview or experimental GitHub API features that are not yet GA.
type graphQLFeaturesKey struct{}

// withGraphQLFeatures adds GraphQL feature flags to the context
// withGraphQLFeatures adds GraphQL feature flags to the context.
// The flags are read by GraphQLFeaturesTransport and sent as the GraphQL-Features header.
// This is used internally by tool handlers that require experimental GitHub API features.
func withGraphQLFeatures(ctx context.Context, features ...string) context.Context {
return context.WithValue(ctx, graphQLFeaturesKey{}, features)
}

// GetGraphQLFeatures retrieves GraphQL feature flags from the context
// GetGraphQLFeatures retrieves GraphQL feature flags from the context.
// This function is exported to allow custom HTTP transports (e.g., in remote servers)
// to read feature flags and add them as the "GraphQL-Features" header.
//
// For most use cases, use GraphQLFeaturesTransport instead of calling this directly.
func GetGraphQLFeatures(ctx context.Context) []string {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot should we move all this GetGraphQLFeatures etc. to the transport file? It makes more sense there.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done! Moved graphQLFeaturesKey, withGraphQLFeatures, and GetGraphQLFeatures to transport.go where they logically belong with GraphQLFeaturesTransport. Commit: 7333546

if features, ok := ctx.Value(graphQLFeaturesKey{}).([]string); ok {
return features
Expand Down
61 changes: 61 additions & 0 deletions pkg/github/transport.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package github

import (
"net/http"
"strings"
)

// GraphQLFeaturesTransport is an http.RoundTripper that adds GraphQL-Features
// header based on context values set by withGraphQLFeatures.
//
// This transport should be used in the HTTP client chain for githubv4.Client
// to ensure GraphQL feature flags are properly sent to the GitHub API.
// Without this transport, certain GitHub API features (like Copilot assignment)
// that require feature flags will fail with schema validation errors.
//
// Example usage for local server (layering with auth):
//
// httpClient := &http.Client{
// Transport: &github.GraphQLFeaturesTransport{
// Transport: &authTransport{
// Transport: http.DefaultTransport,
// token: "ghp_...",
// },
// },
// }
// gqlClient := githubv4.NewClient(httpClient)
//
// Example usage for remote server (simple case):
//
// httpClient := &http.Client{
// Transport: &github.GraphQLFeaturesTransport{
// Transport: http.DefaultTransport,
// },
// }
// gqlClient := githubv4.NewClient(httpClient)
//
// The transport reads feature flags from request context using GetGraphQLFeatures.
// Feature flags are added to context by the tool handler via withGraphQLFeatures.
type GraphQLFeaturesTransport struct {
// Transport is the underlying http.RoundTripper. If nil, http.DefaultTransport is used.
Transport http.RoundTripper
}

// RoundTrip implements http.RoundTripper.
// It adds the GraphQL-Features header if features are present in the request context.
func (t *GraphQLFeaturesTransport) RoundTrip(req *http.Request) (*http.Response, error) {
transport := t.Transport
if transport == nil {
transport = http.DefaultTransport
}

// Clone request to avoid modifying the original
req = req.Clone(req.Context())

// Check for GraphQL-Features in context and add header if present
if features := GetGraphQLFeatures(req.Context()); len(features) > 0 {
req.Header.Set("GraphQL-Features", strings.Join(features, ", "))
}

return transport.RoundTrip(req)
}
142 changes: 142 additions & 0 deletions pkg/github/transport_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
package github

import (
"context"
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestGraphQLFeaturesTransport(t *testing.T) {
tests := []struct {
name string
features []string
expectHeader bool
expectedHeaderVal string
}{
{
name: "adds single feature to header",
features: []string{"issues_copilot_assignment_api_support"},
expectHeader: true,
expectedHeaderVal: "issues_copilot_assignment_api_support",
},
{
name: "adds multiple features to header",
features: []string{"feature1", "feature2", "feature3"},
expectHeader: true,
expectedHeaderVal: "feature1, feature2, feature3",
},
{
name: "no header when no features in context",
features: nil,
expectHeader: false,
},
{
name: "no header when empty features slice",
features: []string{},
expectHeader: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create a test server that captures the request
var capturedReq *http.Request
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedReq = r
w.WriteHeader(http.StatusOK)
}))
defer server.Close()

// Create HTTP client with GraphQLFeaturesTransport
client := &http.Client{
Transport: &GraphQLFeaturesTransport{
Transport: http.DefaultTransport,
},
}

// Create request with or without features in context
ctx := context.Background()
if tt.features != nil {
ctx = withGraphQLFeatures(ctx, tt.features...)
}

req, err := http.NewRequestWithContext(ctx, "GET", server.URL, nil)
require.NoError(t, err)

// Make request
resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()

// Verify header
if tt.expectHeader {
assert.Equal(t, tt.expectedHeaderVal, capturedReq.Header.Get("GraphQL-Features"))
} else {
assert.Empty(t, capturedReq.Header.Get("GraphQL-Features"))
}
})
}
}

func TestGraphQLFeaturesTransport_NilTransport(t *testing.T) {
// Test that nil Transport falls back to http.DefaultTransport
var capturedReq *http.Request
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedReq = r
w.WriteHeader(http.StatusOK)
}))
defer server.Close()

client := &http.Client{
Transport: &GraphQLFeaturesTransport{
Transport: nil, // Explicitly nil
},
}

ctx := withGraphQLFeatures(context.Background(), "test_feature")
req, err := http.NewRequestWithContext(ctx, "GET", server.URL, nil)
require.NoError(t, err)

resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()

assert.Equal(t, "test_feature", capturedReq.Header.Get("GraphQL-Features"))
}

func TestGraphQLFeaturesTransport_PreservesOtherHeaders(t *testing.T) {
// Test that the transport doesn't interfere with other headers
var capturedReq *http.Request
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedReq = r
w.WriteHeader(http.StatusOK)
}))
defer server.Close()

client := &http.Client{
Transport: &GraphQLFeaturesTransport{
Transport: http.DefaultTransport,
},
}

ctx := withGraphQLFeatures(context.Background(), "feature1")
req, err := http.NewRequestWithContext(ctx, "GET", server.URL, nil)
require.NoError(t, err)

// Add custom headers
req.Header.Set("Authorization", "Bearer test-token")
req.Header.Set("User-Agent", "test-agent")

resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()

// Verify all headers are preserved
assert.Equal(t, "feature1", capturedReq.Header.Get("GraphQL-Features"))
assert.Equal(t, "Bearer test-token", capturedReq.Header.Get("Authorization"))
assert.Equal(t, "test-agent", capturedReq.Header.Get("User-Agent"))
}