Skip to content
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ linters:
- errname
- errorlint
- exptostd
- forcetypeassert
- govet
- staticcheck
- testifylint
Expand Down
79 changes: 76 additions & 3 deletions packages/api/internal/auth/constants.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,79 @@
package auth

const (
TeamContextKey string = "team"
UserIDContextKey string = "user_id"
import (
"errors"
"fmt"

"github.com/gin-gonic/gin"
"github.com/google/uuid"
"go.uber.org/zap"

authcache "github.com/e2b-dev/infra/packages/api/internal/cache/auth"
)

var (
ErrNotFoundInContext = errors.New("not found in context")
ErrInvalidType = errors.New("unexpected type")
)

type ginContextValueHelper[T any] struct {
contextKey string
}

func (g *ginContextValueHelper[T]) set(c *gin.Context, val T) {
c.Set(g.contextKey, val)
}

func (g *ginContextValueHelper[T]) get(c *gin.Context) (T, error) {
var t T

v := c.Value(g.contextKey)
if v == nil {
return t, ErrNotFoundInContext
}

t, ok := v.(T)
if !ok {
return t, fmt.Errorf("%w: wanted %T, got %T",
ErrInvalidType, t, v)
}

return t, nil
}

func (g *ginContextValueHelper[T]) safeGet(c *gin.Context) T {
v, err := g.get(c)
if err != nil {
zap.L().Warn("failed to "+g.contextKey, zap.Error(err))
}
return v
}

var (
teamInfoHelper = ginContextValueHelper[authcache.AuthTeamInfo]{"team"}
userIDHelper = ginContextValueHelper[uuid.UUID]{"user_id"}
)

func setTeamInfo(c *gin.Context, teamInfo authcache.AuthTeamInfo) {
teamInfoHelper.set(c, teamInfo)
}

func GetTeamInfo(c *gin.Context) (authcache.AuthTeamInfo, error) {
return teamInfoHelper.get(c)
}

func SafeGetTeamInfo(c *gin.Context) authcache.AuthTeamInfo {
return teamInfoHelper.safeGet(c)
}

func setUserID(c *gin.Context, userID uuid.UUID) {
userIDHelper.set(c, userID)
}

func GetUserID(c *gin.Context) (uuid.UUID, error) {
return userIDHelper.get(c)
}

func SafeGetUserID(c *gin.Context) uuid.UUID {
return userIDHelper.safeGet(c)
}
24 changes: 14 additions & 10 deletions packages/api/internal/auth/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ type commonAuthenticator[T any] struct {
securitySchemeName string
headerKey headerKey
validationFunction func(context.Context, string) (T, *api.APIError)
contextKey string
setContext func(ginContext *gin.Context, result T)
errorMessage string
}

Expand Down Expand Up @@ -107,8 +107,8 @@ func (a *commonAuthenticator[T]) Authenticate(ctx context.Context, input *openap
telemetry.ReportEvent(ctx, "api key validated")

// Set the property on the gin context
if a.contextKey != "" {
middleware.GetGinContext(ctx).Set(a.contextKey, result)
if a.setContext != nil {
a.setContext(middleware.GetGinContext(ctx), result)
}

return nil
Expand Down Expand Up @@ -146,7 +146,7 @@ func CreateAuthenticationFunc(
removePrefix: "",
},
validationFunction: teamValidationFunction,
contextKey: TeamContextKey,
setContext: setTeamInfo,
errorMessage: "Invalid API key, please visit https://e2b.dev/docs/api-key for more information.",
},
&commonAuthenticator[uuid.UUID]{
Expand All @@ -157,7 +157,7 @@ func CreateAuthenticationFunc(
removePrefix: "Bearer ",
},
validationFunction: userValidationFunction,
contextKey: UserIDContextKey,
setContext: setUserID,
errorMessage: "Invalid Access token, try to login again by running `e2b auth login`.",
},
&commonAuthenticator[uuid.UUID]{
Expand All @@ -168,7 +168,7 @@ func CreateAuthenticationFunc(
removePrefix: "",
},
validationFunction: supabaseTokenValidationFunction,
contextKey: UserIDContextKey,
setContext: setUserID,
errorMessage: "Invalid Supabase token.",
},
&commonAuthenticator[authcache.AuthTeamInfo]{
Expand All @@ -179,7 +179,7 @@ func CreateAuthenticationFunc(
removePrefix: "",
},
validationFunction: supabaseTeamValidationFunction,
contextKey: TeamContextKey,
setContext: setTeamInfo,
errorMessage: "Invalid Supabase token teamID.",
},
&commonAuthenticator[struct{}]{
Expand All @@ -190,15 +190,19 @@ func CreateAuthenticationFunc(
removePrefix: "",
},
validationFunction: adminValidationFunction,
contextKey: "",
setContext: nil,
errorMessage: "Invalid Access token.",
},
}

return func(ctx context.Context, input *openapi3filter.AuthenticationInput) error {
ginContext := ctx.Value(middleware.GinContextKey).(*gin.Context)
requestContext := ginContext.Request.Context()
value := ctx.Value(middleware.GinContextKey)
ginContext, ok := value.(*gin.Context)
if !ok {
return fmt.Errorf("%w: received %T", ErrInvalidType, value)
}

requestContext := ginContext.Request.Context()
_, span := tracer.Start(requestContext, "authenticate")
defer span.End()

Expand Down
5 changes: 3 additions & 2 deletions packages/api/internal/handlers/accesstoken.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/google/uuid"

"github.com/e2b-dev/infra/packages/api/internal/api"
"github.com/e2b-dev/infra/packages/api/internal/auth"
"github.com/e2b-dev/infra/packages/api/internal/utils"
"github.com/e2b-dev/infra/packages/shared/pkg/keys"
"github.com/e2b-dev/infra/packages/shared/pkg/models/accesstoken"
Expand All @@ -17,7 +18,7 @@ import (
func (a *APIStore) PostAccessTokens(c *gin.Context) {
ctx := c.Request.Context()

userID := a.GetUserID(c)
userID := auth.SafeGetUserID(c)

body, err := utils.ParseBody[api.NewAccessToken](ctx, c)
if err != nil {
Expand Down Expand Up @@ -74,7 +75,7 @@ func (a *APIStore) PostAccessTokens(c *gin.Context) {
func (a *APIStore) DeleteAccessTokensAccessTokenID(c *gin.Context, accessTokenID string) {
ctx := c.Request.Context()

userID := a.GetUserID(c)
userID := auth.SafeGetUserID(c)

accessTokenIDParsed, err := uuid.Parse(accessTokenID)
if err != nil {
Expand Down
11 changes: 6 additions & 5 deletions packages/api/internal/handlers/apikey.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"go.uber.org/zap"

"github.com/e2b-dev/infra/packages/api/internal/api"
"github.com/e2b-dev/infra/packages/api/internal/auth"
"github.com/e2b-dev/infra/packages/api/internal/team"
"github.com/e2b-dev/infra/packages/api/internal/utils"
"github.com/e2b-dev/infra/packages/db/queries"
Expand Down Expand Up @@ -39,7 +40,7 @@ func (a *APIStore) PatchApiKeysApiKeyID(c *gin.Context, apiKeyID string) {
return
}

teamID := a.GetTeamInfo(c).Team.ID
teamID := auth.SafeGetTeamInfo(c).Team.ID

now := time.Now()
_, err = a.sqlcDB.UpdateTeamApiKey(ctx, queries.UpdateTeamApiKeyParams{
Expand All @@ -64,7 +65,7 @@ func (a *APIStore) PatchApiKeysApiKeyID(c *gin.Context, apiKeyID string) {
func (a *APIStore) GetApiKeys(c *gin.Context) {
ctx := c.Request.Context()

teamID := a.GetTeamInfo(c).Team.ID
teamID := auth.SafeGetTeamInfo(c).Team.ID

apiKeysDB, err := a.db.Client.TeamAPIKey.
Query().
Expand Down Expand Up @@ -116,7 +117,7 @@ func (a *APIStore) DeleteApiKeysApiKeyID(c *gin.Context, apiKeyID string) {
return
}

teamID := a.GetTeamInfo(c).Team.ID
teamID := auth.SafeGetTeamInfo(c).Team.ID

err = a.db.Client.TeamAPIKey.DeleteOneID(apiKeyIDParsed).Where(teamapikey.TeamID(teamID)).Exec(ctx)
if models.IsNotFound(err) {
Expand All @@ -135,8 +136,8 @@ func (a *APIStore) DeleteApiKeysApiKeyID(c *gin.Context, apiKeyID string) {
func (a *APIStore) PostApiKeys(c *gin.Context) {
ctx := c.Request.Context()

userID := a.GetUserID(c)
teamID := a.GetTeamInfo(c).Team.ID
userID := auth.SafeGetUserID(c)
teamID := auth.SafeGetTeamInfo(c).Team.ID

body, err := utils.ParseBody[api.NewTeamAPIKey](ctx, c)
if err != nil {
Expand Down
16 changes: 7 additions & 9 deletions packages/api/internal/handlers/auth.go
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
package handlers

import (
"errors"
"fmt"

"github.com/gin-gonic/gin"
"github.com/google/uuid"

"github.com/e2b-dev/infra/packages/api/internal/auth"
authcache "github.com/e2b-dev/infra/packages/api/internal/cache/auth"
"github.com/e2b-dev/infra/packages/db/queries"
)

func (a *APIStore) GetUserID(c *gin.Context) uuid.UUID {
return c.Value(auth.UserIDContextKey).(uuid.UUID)
}
var ErrNoUserIDInContext = errors.New("no user id in context")

func (a *APIStore) GetUserAndTeams(c *gin.Context) (*uuid.UUID, []queries.GetTeamsWithUsersTeamsWithTierRow, error) {
userID := a.GetUserID(c)
userID, err := auth.GetUserID(c)
if err != nil {
return nil, nil, ErrNoUserIDInContext
}

ctx := c.Request.Context()

teams, err := a.sqlcDB.GetTeamsWithUsersTeamsWithTier(ctx, userID)
Expand All @@ -26,7 +28,3 @@ func (a *APIStore) GetUserAndTeams(c *gin.Context) (*uuid.UUID, []queries.GetTea

return &userID, teams, err
}

func (a *APIStore) GetTeamInfo(c *gin.Context) authcache.AuthTeamInfo {
return c.Value(auth.TeamContextKey).(authcache.AuthTeamInfo)
}
5 changes: 2 additions & 3 deletions packages/api/internal/handlers/sandbox_create.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (

"github.com/e2b-dev/infra/packages/api/internal/api"
"github.com/e2b-dev/infra/packages/api/internal/auth"
authcache "github.com/e2b-dev/infra/packages/api/internal/cache/auth"
"github.com/e2b-dev/infra/packages/api/internal/cache/instance"
"github.com/e2b-dev/infra/packages/api/internal/middleware/otel/metrics"
"github.com/e2b-dev/infra/packages/api/internal/utils"
Expand Down Expand Up @@ -42,8 +41,8 @@ var mostUsedTemplates = map[string]struct{}{
func (a *APIStore) PostSandboxes(c *gin.Context) {
ctx := c.Request.Context()

// Get team from context, use TeamContextKey
teamInfo := c.Value(auth.TeamContextKey).(authcache.AuthTeamInfo)
// Get team from context
teamInfo := auth.SafeGetTeamInfo(c)

c.Set("teamID", teamInfo.Team.ID.String())

Expand Down
4 changes: 1 addition & 3 deletions packages/api/internal/handlers/sandbox_get.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (

"github.com/e2b-dev/infra/packages/api/internal/api"
"github.com/e2b-dev/infra/packages/api/internal/auth"
authcache "github.com/e2b-dev/infra/packages/api/internal/cache/auth"
"github.com/e2b-dev/infra/packages/api/internal/cache/instance"
"github.com/e2b-dev/infra/packages/api/internal/utils"
"github.com/e2b-dev/infra/packages/db/queries"
Expand All @@ -21,8 +20,7 @@ import (
func (a *APIStore) GetSandboxesSandboxID(c *gin.Context, id string) {
ctx := c.Request.Context()

teamInfo := c.Value(auth.TeamContextKey).(authcache.AuthTeamInfo)
team := teamInfo.Team
team := auth.SafeGetTeamInfo(c).Team

telemetry.ReportEvent(ctx, "get sandbox")

Expand Down
3 changes: 1 addition & 2 deletions packages/api/internal/handlers/sandbox_kill.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"go.opentelemetry.io/otel/attribute"

"github.com/e2b-dev/infra/packages/api/internal/auth"
authcache "github.com/e2b-dev/infra/packages/api/internal/cache/auth"
"github.com/e2b-dev/infra/packages/api/internal/cache/instance"
template_manager "github.com/e2b-dev/infra/packages/api/internal/template-manager"
"github.com/e2b-dev/infra/packages/api/internal/utils"
Expand Down Expand Up @@ -72,7 +71,7 @@ func (a *APIStore) DeleteSandboxesSandboxID(
ctx := c.Request.Context()
sandboxID = utils.ShortID(sandboxID)

team := c.Value(auth.TeamContextKey).(authcache.AuthTeamInfo).Team
team := auth.SafeGetTeamInfo(c).Team
teamID := team.ID

telemetry.SetAttributes(ctx,
Expand Down
9 changes: 4 additions & 5 deletions packages/api/internal/handlers/sandbox_logs.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (

"github.com/e2b-dev/infra/packages/api/internal/api"
"github.com/e2b-dev/infra/packages/api/internal/auth"
authcache "github.com/e2b-dev/infra/packages/api/internal/cache/auth"
"github.com/e2b-dev/infra/packages/api/internal/utils"
apiedge "github.com/e2b-dev/infra/packages/shared/pkg/http/edge"
"github.com/e2b-dev/infra/packages/shared/pkg/telemetry"
Expand All @@ -21,17 +20,17 @@ func (a *APIStore) GetSandboxesSandboxIDLogs(c *gin.Context, sandboxID string, p
ctx := c.Request.Context()
sandboxID = utils.ShortID(sandboxID)

team := c.Value(auth.TeamContextKey).(authcache.AuthTeamInfo).Team
team := auth.SafeGetTeamInfo(c).Team

telemetry.SetAttributes(ctx,
attribute.String("instance.id", sandboxID),
telemetry.WithTeamID(team.ID.String()),
)

/// Sandboxes living in a cluster
sbxLogs, err := a.getClusterSandboxLogs(ctx, sandboxID, team.ID.String(), utils.WithClusterFallback(team.ClusterID), params.Limit, params.Start)
if err != nil {
a.sendAPIStoreError(c, int(err.Code), err.Message)
sbxLogs, apiErr := a.getClusterSandboxLogs(ctx, sandboxID, team.ID.String(), utils.WithClusterFallback(team.ClusterID), params.Limit, params.Start)
if apiErr != nil {
a.sendAPIStoreError(c, int(apiErr.Code), apiErr.Message)
return
}

Expand Down
3 changes: 1 addition & 2 deletions packages/api/internal/handlers/sandbox_metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (

"github.com/e2b-dev/infra/packages/api/internal/api"
"github.com/e2b-dev/infra/packages/api/internal/auth"
authcache "github.com/e2b-dev/infra/packages/api/internal/cache/auth"
"github.com/e2b-dev/infra/packages/api/internal/utils"
clickhouse "github.com/e2b-dev/infra/packages/clickhouse/pkg"
featureflags "github.com/e2b-dev/infra/packages/shared/pkg/feature-flags"
Expand All @@ -25,7 +24,7 @@ func (a *APIStore) GetSandboxesSandboxIDMetrics(c *gin.Context, sandboxID string
defer span.End()
sandboxID = utils.ShortID(sandboxID)

team := c.Value(auth.TeamContextKey).(authcache.AuthTeamInfo).Team
team := auth.SafeGetTeamInfo(c).Team

metricsReadFlag, err := a.featureFlags.BoolFlag(ctx, featureflags.MetricsReadFlagName,
featureflags.SandboxContext(sandboxID))
Expand Down
4 changes: 2 additions & 2 deletions packages/api/internal/handlers/sandbox_pause.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"go.uber.org/zap"

"github.com/e2b-dev/infra/packages/api/internal/api"
"github.com/e2b-dev/infra/packages/api/internal/auth"
"github.com/e2b-dev/infra/packages/api/internal/cache/instance"
"github.com/e2b-dev/infra/packages/api/internal/utils"
"github.com/e2b-dev/infra/packages/db/queries"
Expand All @@ -22,8 +23,7 @@ func (a *APIStore) PostSandboxesSandboxIDPause(c *gin.Context, sandboxID api.San
ctx := c.Request.Context()
// Get team from context, use TeamContextKey

teamID := a.GetTeamInfo(c).Team.ID

teamID := auth.SafeGetTeamInfo(c).Team.ID
sandboxID = utils.ShortID(sandboxID)

span := trace.SpanFromContext(ctx)
Expand Down
Loading
Loading