diff --git a/backend/docs/docs.go b/backend/docs/docs.go index 8eb132b..53d784f 100644 --- a/backend/docs/docs.go +++ b/backend/docs/docs.go @@ -602,6 +602,54 @@ const docTemplate = `{ } } }, + "/api/v1/trip-invites/{code}/join": { + "post": { + "description": "Adds the authenticated user to a trip using an invite code", + "produces": [ + "application/json" + ], + "tags": [ + "memberships" + ], + "summary": "Join trip by invite code", + "operationId": "joinTripByInvite", + "parameters": [ + { + "type": "string", + "description": "Invite code", + "name": "code", + "in": "path", + "required": true + } + ], + "responses": { + "201": { + "description": "Created", + "schema": { + "$ref": "#/definitions/models.Membership" + } + }, + "400": { + "description": "Invalid or expired invite code", + "schema": { + "$ref": "#/definitions/errs.APIError" + } + }, + "401": { + "description": "Unauthorized", + "schema": { + "$ref": "#/definitions/errs.APIError" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/errs.APIError" + } + } + } + } + }, "/api/v1/trips": { "get": { "description": "Retrieves trips with cursor-based pagination. Use limit and cursor query params.", @@ -876,6 +924,78 @@ const docTemplate = `{ } } }, + "/api/v1/trips/{tripID}/invites": { + "post": { + "description": "Creates a shareable invite for the trip. Caller must be a trip member.", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "trips" + ], + "summary": "Create a trip invite", + "operationId": "createTripInvite", + "parameters": [ + { + "type": "string", + "description": "Trip ID", + "name": "tripID", + "in": "path", + "required": true + }, + { + "description": "Optional expires_at; default 7 days", + "name": "request", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/models.CreateTripInviteRequest" + } + } + ], + "responses": { + "201": { + "description": "Created", + "schema": { + "$ref": "#/definitions/models.TripInviteAPIResponse" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/errs.APIError" + } + }, + "401": { + "description": "Unauthorized", + "schema": { + "$ref": "#/definitions/errs.APIError" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/errs.APIError" + } + }, + "422": { + "description": "Unprocessable Entity", + "schema": { + "$ref": "#/definitions/errs.APIError" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/errs.APIError" + } + } + } + } + }, "/api/v1/trips/{tripID}/memberships": { "get": { "description": "Retrieves all members of a trip", @@ -1783,6 +1903,14 @@ const docTemplate = `{ } } }, + "models.CreateTripInviteRequest": { + "type": "object", + "properties": { + "expires_at": { + "type": "string" + } + } + }, "models.CreateTripRequest": { "type": "object", "required": [ @@ -2171,6 +2299,35 @@ const docTemplate = `{ } } }, + "models.TripInviteAPIResponse": { + "type": "object", + "properties": { + "code": { + "type": "string" + }, + "created_at": { + "type": "string" + }, + "created_by": { + "type": "string" + }, + "expires_at": { + "type": "string" + }, + "id": { + "type": "string" + }, + "is_revoked": { + "type": "boolean" + }, + "join_url": { + "type": "string" + }, + "trip_id": { + "type": "string" + } + } + }, "models.UpdateCommentRequest": { "type": "object", "required": [ diff --git a/backend/docs/swagger.json b/backend/docs/swagger.json index 8b641cd..2ad7a68 100644 --- a/backend/docs/swagger.json +++ b/backend/docs/swagger.json @@ -596,6 +596,54 @@ } } }, + "/api/v1/trip-invites/{code}/join": { + "post": { + "description": "Adds the authenticated user to a trip using an invite code", + "produces": [ + "application/json" + ], + "tags": [ + "memberships" + ], + "summary": "Join trip by invite code", + "operationId": "joinTripByInvite", + "parameters": [ + { + "type": "string", + "description": "Invite code", + "name": "code", + "in": "path", + "required": true + } + ], + "responses": { + "201": { + "description": "Created", + "schema": { + "$ref": "#/definitions/models.Membership" + } + }, + "400": { + "description": "Invalid or expired invite code", + "schema": { + "$ref": "#/definitions/errs.APIError" + } + }, + "401": { + "description": "Unauthorized", + "schema": { + "$ref": "#/definitions/errs.APIError" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/errs.APIError" + } + } + } + } + }, "/api/v1/trips": { "get": { "description": "Retrieves trips with cursor-based pagination. Use limit and cursor query params.", @@ -870,6 +918,78 @@ } } }, + "/api/v1/trips/{tripID}/invites": { + "post": { + "description": "Creates a shareable invite for the trip. Caller must be a trip member.", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "trips" + ], + "summary": "Create a trip invite", + "operationId": "createTripInvite", + "parameters": [ + { + "type": "string", + "description": "Trip ID", + "name": "tripID", + "in": "path", + "required": true + }, + { + "description": "Optional expires_at; default 7 days", + "name": "request", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/models.CreateTripInviteRequest" + } + } + ], + "responses": { + "201": { + "description": "Created", + "schema": { + "$ref": "#/definitions/models.TripInviteAPIResponse" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/errs.APIError" + } + }, + "401": { + "description": "Unauthorized", + "schema": { + "$ref": "#/definitions/errs.APIError" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/errs.APIError" + } + }, + "422": { + "description": "Unprocessable Entity", + "schema": { + "$ref": "#/definitions/errs.APIError" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/errs.APIError" + } + } + } + } + }, "/api/v1/trips/{tripID}/memberships": { "get": { "description": "Retrieves all members of a trip", @@ -1777,6 +1897,14 @@ } } }, + "models.CreateTripInviteRequest": { + "type": "object", + "properties": { + "expires_at": { + "type": "string" + } + } + }, "models.CreateTripRequest": { "type": "object", "required": [ @@ -2165,6 +2293,35 @@ } } }, + "models.TripInviteAPIResponse": { + "type": "object", + "properties": { + "code": { + "type": "string" + }, + "created_at": { + "type": "string" + }, + "created_by": { + "type": "string" + }, + "expires_at": { + "type": "string" + }, + "id": { + "type": "string" + }, + "is_revoked": { + "type": "boolean" + }, + "join_url": { + "type": "string" + }, + "trip_id": { + "type": "string" + } + } + }, "models.UpdateCommentRequest": { "type": "object", "required": [ diff --git a/backend/docs/swagger.yaml b/backend/docs/swagger.yaml index 6603038..0101907 100644 --- a/backend/docs/swagger.yaml +++ b/backend/docs/swagger.yaml @@ -110,6 +110,11 @@ definitions: - trip_id - user_id type: object + models.CreateTripInviteRequest: + properties: + expires_at: + type: string + type: object models.CreateTripRequest: properties: budget_max: @@ -374,6 +379,25 @@ definitions: next_cursor: type: string type: object + models.TripInviteAPIResponse: + properties: + code: + type: string + created_at: + type: string + created_by: + type: string + expires_at: + type: string + id: + type: string + is_revoked: + type: boolean + join_url: + type: string + trip_id: + type: string + type: object models.UpdateCommentRequest: properties: content: @@ -878,6 +902,38 @@ paths: summary: Send bulk notifications tags: - notifications + /api/v1/trip-invites/{code}/join: + post: + description: Adds the authenticated user to a trip using an invite code + operationId: joinTripByInvite + parameters: + - description: Invite code + in: path + name: code + required: true + type: string + produces: + - application/json + responses: + "201": + description: Created + schema: + $ref: '#/definitions/models.Membership' + "400": + description: Invalid or expired invite code + schema: + $ref: '#/definitions/errs.APIError' + "401": + description: Unauthorized + schema: + $ref: '#/definitions/errs.APIError' + "500": + description: Internal Server Error + schema: + $ref: '#/definitions/errs.APIError' + summary: Join trip by invite code + tags: + - memberships /api/v1/trips: get: description: Retrieves trips with cursor-based pagination. Use limit and cursor @@ -1120,6 +1176,55 @@ paths: summary: Get comments tags: - comments + /api/v1/trips/{tripID}/invites: + post: + consumes: + - application/json + description: Creates a shareable invite for the trip. Caller must be a trip + member. + operationId: createTripInvite + parameters: + - description: Trip ID + in: path + name: tripID + required: true + type: string + - description: Optional expires_at; default 7 days + in: body + name: request + required: true + schema: + $ref: '#/definitions/models.CreateTripInviteRequest' + produces: + - application/json + responses: + "201": + description: Created + schema: + $ref: '#/definitions/models.TripInviteAPIResponse' + "400": + description: Bad Request + schema: + $ref: '#/definitions/errs.APIError' + "401": + description: Unauthorized + schema: + $ref: '#/definitions/errs.APIError' + "404": + description: Not Found + schema: + $ref: '#/definitions/errs.APIError' + "422": + description: Unprocessable Entity + schema: + $ref: '#/definitions/errs.APIError' + "500": + description: Internal Server Error + schema: + $ref: '#/definitions/errs.APIError' + summary: Create a trip invite + tags: + - trips /api/v1/trips/{tripID}/memberships: get: description: Retrieves all members of a trip diff --git a/backend/internal/controllers/membership.go b/backend/internal/controllers/membership.go index 0dd5e57..72f1338 100644 --- a/backend/internal/controllers/membership.go +++ b/backend/internal/controllers/membership.go @@ -25,6 +25,46 @@ func NewMembershipController(membershipService services.MembershipServiceInterfa } } +// @Summary Join trip by invite code +// @Description Adds the authenticated user to a trip using an invite code +// @Tags memberships +// @Produce json +// @Param code path string true "Invite code" +// @Success 201 {object} models.Membership +// @Failure 400 {object} errs.APIError "Invalid or expired invite code" +// @Failure 401 {object} errs.APIError +// @Failure 500 {object} errs.APIError +// @Router /api/v1/trip-invites/{code}/join [post] +// @ID joinTripByInvite +func (ctrl *MembershipController) JoinTripByInvite(c *fiber.Ctx) error { + userIDValue := c.Locals("userID") + if userIDValue == nil { + return errs.Unauthorized() + } + + userIDStr, ok := userIDValue.(string) + if !ok { + return errs.Unauthorized() + } + + userID, err := validators.ValidateID(userIDStr) + if err != nil { + return errs.Unauthorized() + } + + code := c.Params("code") + if code == "" { + return errs.BadRequest(errors.New("invite code is required")) + } + + membership, err := ctrl.membershipService.JoinTripByInviteCode(c.Context(), userID, code) + if err != nil { + return err + } + + return c.Status(http.StatusCreated).JSON(membership) +} + // @Summary Add member to trip // @Description Adds a user as a member of a trip // @Tags memberships diff --git a/backend/internal/controllers/trips.go b/backend/internal/controllers/trips.go index 046d67a..3fabfcb 100644 --- a/backend/internal/controllers/trips.go +++ b/backend/internal/controllers/trips.go @@ -180,6 +180,57 @@ func (ctrl *TripController) UpdateTrip(c *fiber.Ctx) error { return c.Status(http.StatusOK).JSON(trip) } +// @Summary Create a trip invite +// @Description Creates a shareable invite for the trip. Caller must be a trip member. +// @Tags trips +// @Accept json +// @Produce json +// @Param tripID path string true "Trip ID" +// @Param request body models.CreateTripInviteRequest true "Optional expires_at; default 7 days" +// @Success 201 {object} models.TripInviteAPIResponse +// @Failure 400 {object} errs.APIError +// @Failure 401 {object} errs.APIError +// @Failure 404 {object} errs.APIError +// @Failure 422 {object} errs.APIError +// @Failure 500 {object} errs.APIError +// @Router /api/v1/trips/{tripID}/invites [post] +// @ID createTripInvite +func (ctrl *TripController) CreateTripInvite(c *fiber.Ctx) error { + userIDValue := c.Locals("userID") + if userIDValue == nil { + return errs.Unauthorized() + } + + userIDStr, ok := userIDValue.(string) + if !ok { + return errs.Unauthorized() + } + + userID, err := validators.ValidateID(userIDStr) + if err != nil { + return errs.Unauthorized() + } + + tripID, err := validators.ValidateID(c.Params("tripID")) + if err != nil { + return errs.InvalidUUID() + } + + var req models.CreateTripInviteRequest + _ = c.BodyParser(&req) // optional body; empty or {} uses default expiry + + if err := validators.Validate(ctrl.validator, req); err != nil { + return err + } + + invite, err := ctrl.tripService.CreateTripInvite(c.Context(), tripID, userID, req) + if err != nil { + return err + } + + return c.Status(http.StatusCreated).JSON(invite) +} + // @Summary Delete a trip // @Description Deletes a trip by ID // @Tags trips diff --git a/backend/internal/migrations/20260205035655_create_trip_invite.sql b/backend/internal/migrations/20260205035655_create_trip_invite.sql new file mode 100644 index 0000000..bbb7dbc --- /dev/null +++ b/backend/internal/migrations/20260205035655_create_trip_invite.sql @@ -0,0 +1,24 @@ +-- +goose Up +-- +goose StatementBegin +CREATE TABLE trip_invites ( + id UUID PRIMARY KEY, + trip_id UUID NOT NULL REFERENCES trips(id) ON DELETE CASCADE, + created_by UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, + code TEXT NOT NULL UNIQUE, + expires_at TIMESTAMP WITH TIME ZONE NOT NULL, + is_revoked BOOLEAN NOT NULL DEFAULT false, + created_at TIMESTAMP WITH TIME ZONE DEFAULT now() +); + +CREATE INDEX idx_trip_invites_trip_id ON trip_invites(trip_id); +CREATE INDEX idx_trip_invites_code ON trip_invites(code); +CREATE INDEX idx_trip_invites_created_by ON trip_invites(created_by); +-- +goose StatementEnd + +-- +goose Down +-- +goose StatementBegin +DROP INDEX IF EXISTS idx_trip_invites_created_by; +DROP INDEX IF EXISTS idx_trip_invites_code; +DROP INDEX IF EXISTS idx_trip_invites_trip_id; +DROP TABLE IF EXISTS trip_invites; +-- +goose StatementEnd diff --git a/backend/internal/models/trip_invites.go b/backend/internal/models/trip_invites.go new file mode 100644 index 0000000..f713ce9 --- /dev/null +++ b/backend/internal/models/trip_invites.go @@ -0,0 +1,36 @@ +package models + +import ( + "time" + + "github.com/google/uuid" +) + +// TripInvite represents a shareable invite for a trip. +type TripInvite struct { + ID uuid.UUID `bun:"id,pk,type:uuid" json:"id"` + TripID uuid.UUID `bun:"trip_id,type:uuid,notnull" json:"trip_id"` + CreatedBy uuid.UUID `bun:"created_by,type:uuid,notnull" json:"created_by"` + Code string `bun:"code,notnull" json:"code"` + ExpiresAt time.Time `bun:"expires_at,nullzero,notnull" json:"expires_at"` + IsRevoked bool `bun:"is_revoked,notnull" json:"is_revoked"` + CreatedAt time.Time `bun:"created_at,nullzero" json:"created_at"` +} + +// CreateTripInviteRequest is the request body for creating a trip invite. +// If ExpiresAt is nil, a default (e.g. 7 days) is applied in the service. +type CreateTripInviteRequest struct { + ExpiresAt *time.Time `json:"expires_at" validate:"omitempty"` +} + +// TripInviteAPIResponse is the API response for a trip invite. +type TripInviteAPIResponse struct { + ID uuid.UUID `json:"id"` + TripID uuid.UUID `json:"trip_id"` + CreatedBy uuid.UUID `json:"created_by"` + Code string `json:"code"` + ExpiresAt time.Time `json:"expires_at"` + IsRevoked bool `json:"is_revoked"` + CreatedAt time.Time `json:"created_at"` + JoinURL *string `json:"join_url,omitempty"` +} diff --git a/backend/internal/repository/repository.go b/backend/internal/repository/repository.go index 81ded4c..e1a525c 100644 --- a/backend/internal/repository/repository.go +++ b/backend/internal/repository/repository.go @@ -10,13 +10,14 @@ import ( ) type Repository struct { - User UserRepository - Health HealthRepository - Image ImageRepository - Comment CommentRepository - Membership MembershipRepository - Trip TripRepository - db *bun.DB + User UserRepository + Health HealthRepository + Image ImageRepository + Comment CommentRepository + Membership MembershipRepository + Trip TripRepository + TripInvite TripInviteRepository + db *bun.DB } func NewRepository(db *bun.DB) *Repository { @@ -27,6 +28,7 @@ func NewRepository(db *bun.DB) *Repository { Comment: &commentRepository{db: db}, Trip: &tripRepository{db: db}, Membership: &membershipRepository{db: db}, + TripInvite: newTripInviteRepository(db), db: db, } } @@ -58,6 +60,12 @@ type TripRepository interface { Delete(ctx context.Context, id uuid.UUID) error } +type TripInviteRepository interface { + Create(ctx context.Context, invite *models.TripInvite) (*models.TripInvite, error) + FindByID(ctx context.Context, id uuid.UUID) (*models.TripInvite, error) + FindByCode(ctx context.Context, code string) (*models.TripInvite, error) +} + type MembershipRepository interface { Create(ctx context.Context, membership *models.Membership) (*models.Membership, error) Find(ctx context.Context, userID, tripID uuid.UUID) (*models.MembershipDatabaseResponse, error) diff --git a/backend/internal/repository/trip_invite.go b/backend/internal/repository/trip_invite.go new file mode 100644 index 0000000..6d3e274 --- /dev/null +++ b/backend/internal/repository/trip_invite.go @@ -0,0 +1,66 @@ +package repository + +import ( + "context" + "database/sql" + "errors" + "toggo/internal/errs" + "toggo/internal/models" + + "github.com/google/uuid" + "github.com/uptrace/bun" +) + +var _ TripInviteRepository = (*tripInviteRepository)(nil) + +type tripInviteRepository struct { + db *bun.DB +} + +func newTripInviteRepository(db *bun.DB) TripInviteRepository { + return &tripInviteRepository{db: db} +} + +// Create inserts a new trip invite. +func (r *tripInviteRepository) Create(ctx context.Context, invite *models.TripInvite) (*models.TripInvite, error) { + _, err := r.db.NewInsert(). + Model(invite). + Returning("*"). + Exec(ctx) + if err != nil { + return nil, err + } + return invite, nil +} + +// FindByID returns a trip invite by id. +func (r *tripInviteRepository) FindByID(ctx context.Context, id uuid.UUID) (*models.TripInvite, error) { + invite := &models.TripInvite{} + err := r.db.NewSelect(). + Model(invite). + Where("id = ?", id). + Scan(ctx) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, errs.ErrNotFound + } + return nil, err + } + return invite, nil +} + +// FindByCode returns a trip invite by its code. +func (r *tripInviteRepository) FindByCode(ctx context.Context, code string) (*models.TripInvite, error) { + invite := &models.TripInvite{} + err := r.db.NewSelect(). + Model(invite). + Where("code = ?", code). + Scan(ctx) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, errs.ErrNotFound + } + return nil, err + } + return invite, nil +} diff --git a/backend/internal/server/routers/membership.go b/backend/internal/server/routers/membership.go index cb26c89..bc35554 100644 --- a/backend/internal/server/routers/membership.go +++ b/backend/internal/server/routers/membership.go @@ -17,6 +17,9 @@ func MembershipRoutes(apiGroup fiber.Router, routeParams types.RouteParams) fibe membershipGroup := apiGroup.Group("/memberships") membershipGroup.Post("", membershipController.AddMember) + // /api/v1/trip-invites/:code/join + apiGroup.Post("/trip-invites/:code/join", membershipController.JoinTripByInvite) + // /api/v1/trips/:tripID/memberships tripMembershipGroup := apiGroup.Group("/trips/:tripID/memberships") tripMembershipGroup.Use(middlewares.TripMemberRequired(routeParams.ServiceParams.Repository)) diff --git a/backend/internal/server/routers/trips.go b/backend/internal/server/routers/trips.go index 747cb00..0cf62ef 100644 --- a/backend/internal/server/routers/trips.go +++ b/backend/internal/server/routers/trips.go @@ -28,6 +28,7 @@ func TripRoutes(apiGroup fiber.Router, routeParams types.RouteParams) fiber.Rout tripIDGroup.Get("", tripController.GetTrip) tripIDGroup.Patch("", tripController.UpdateTrip) tripIDGroup.Delete("", tripController.DeleteTrip) + tripIDGroup.Post("/invites", tripController.CreateTripInvite) return tripGroup } diff --git a/backend/internal/services/membership.go b/backend/internal/services/membership.go index 63cceb2..6a55a4a 100644 --- a/backend/internal/services/membership.go +++ b/backend/internal/services/membership.go @@ -14,6 +14,7 @@ import ( type MembershipServiceInterface interface { AddMember(ctx context.Context, req models.CreateMembershipRequest) (*models.Membership, error) + JoinTripByInviteCode(ctx context.Context, userID uuid.UUID, code string) (*models.Membership, error) GetMembership(ctx context.Context, tripID, userID uuid.UUID) (*models.MembershipAPIResponse, error) GetTripMembers(ctx context.Context, tripID uuid.UUID, limit int, cursorToken string) (*models.MembershipCursorPageResult, error) GetUserTrips(ctx context.Context, userID uuid.UUID) ([]*models.Membership, error) @@ -84,6 +85,78 @@ func (s *MembershipService) AddMember(ctx context.Context, req models.CreateMemb return s.Membership.Create(ctx, membership) } +// JoinTripByInviteCode adds the authenticated user to a trip using an invite code. +// - If the code is invalid -> error +// - If the invite is expired or revoked -> error +// - If the user is already a member -> returns existing membership (no error) +func (s *MembershipService) JoinTripByInviteCode(ctx context.Context, userID uuid.UUID, code string) (*models.Membership, error) { + invite, err := s.TripInvite.FindByCode(ctx, code) + if err != nil { + if errors.Is(err, errs.ErrNotFound) { + return nil, errs.BadRequest(errors.New("invalid invite code")) + } + return nil, err + } + + now := time.Now().UTC() + if invite.IsRevoked || invite.ExpiresAt.Before(now) { + return nil, errs.BadRequest(errors.New("invite link has expired")) + } + + // If already a member, return existing membership. + existingMembership, err := s.Membership.Find(ctx, userID, invite.TripID) + if err == nil { + return &models.Membership{ + UserID: existingMembership.UserID, + TripID: existingMembership.TripID, + IsAdmin: existingMembership.IsAdmin, + BudgetMin: existingMembership.BudgetMin, + BudgetMax: existingMembership.BudgetMax, + Availability: existingMembership.Availability, + CreatedAt: existingMembership.CreatedAt, + UpdatedAt: existingMembership.UpdatedAt, + }, nil + } + if !errors.Is(err, errs.ErrNotFound) { + return nil, err + } + + // Not a member yet; create a basic membership. + membership := &models.Membership{ + UserID: userID, + TripID: invite.TripID, + IsAdmin: false, + BudgetMin: 0, + BudgetMax: 0, + CreatedAt: now, + UpdatedAt: now, + } + + created, err := s.Membership.Create(ctx, membership) + if err != nil { + // If there was a race and the membership already exists, treat as success. + if errors.Is(err, errs.ErrDuplicate) { + existingMembership, findErr := s.Membership.Find(ctx, userID, invite.TripID) + if findErr != nil { + return nil, findErr + } + return &models.Membership{ + UserID: existingMembership.UserID, + TripID: existingMembership.TripID, + IsAdmin: existingMembership.IsAdmin, + BudgetMin: existingMembership.BudgetMin, + BudgetMax: existingMembership.BudgetMax, + Availability: existingMembership.Availability, + CreatedAt: existingMembership.CreatedAt, + UpdatedAt: existingMembership.UpdatedAt, + }, nil + } + return nil, err + } + + return created, nil +} + func (s *MembershipService) GetMembership(ctx context.Context, tripID, userID uuid.UUID) (*models.MembershipAPIResponse, error) { membership, err := s.Membership.Find(ctx, userID, tripID) if err != nil { diff --git a/backend/internal/services/trips.go b/backend/internal/services/trips.go index e1da75a..52fbf36 100644 --- a/backend/internal/services/trips.go +++ b/backend/internal/services/trips.go @@ -2,8 +2,13 @@ package services import ( "context" + "crypto/rand" + "encoding/hex" "errors" "log" + "os" + "strings" + "time" "toggo/internal/errs" "toggo/internal/models" "toggo/internal/realtime" @@ -20,6 +25,7 @@ type TripServiceInterface interface { GetTripsWithCursor(ctx context.Context, userID uuid.UUID, limit int, cursorToken string) (*models.TripCursorPageResult, error) UpdateTrip(ctx context.Context, tripID uuid.UUID, req models.UpdateTripRequest) (*models.Trip, error) DeleteTrip(ctx context.Context, userID, tripID uuid.UUID) error + CreateTripInvite(ctx context.Context, tripID uuid.UUID, createdBy uuid.UUID, req models.CreateTripInviteRequest) (*models.TripInviteAPIResponse, error) } var _ TripServiceInterface = (*TripService)(nil) @@ -252,3 +258,72 @@ func (s *TripService) toAPIResponse(ctx context.Context, tripData *models.TripDa UpdatedAt: tripData.UpdatedAt, }, nil } + +const defaultInviteExpiry = 7 * 24 * time.Hour + +// generateInviteCode returns a URL-safe hex string (e.g. 12 chars). +func generateInviteCode() (string, error) { + b := make([]byte, 6) + if _, err := rand.Read(b); err != nil { + return "", err + } + return hex.EncodeToString(b), nil +} + +func (s *TripService) CreateTripInvite(ctx context.Context, tripID uuid.UUID, createdBy uuid.UUID, req models.CreateTripInviteRequest) (*models.TripInviteAPIResponse, error) { + expiresAt := time.Now().UTC().Add(defaultInviteExpiry) + if req.ExpiresAt != nil { + expiresAt = *req.ExpiresAt + if expiresAt.Before(time.Now().UTC()) { + return nil, errs.BadRequest(errors.New("expires_at must be in the future")) + } + } + + code, err := generateInviteCode() + if err != nil { + return nil, err + } + + invite := &models.TripInvite{ + ID: uuid.New(), + TripID: tripID, + CreatedBy: createdBy, + Code: code, + ExpiresAt: expiresAt, + IsRevoked: false, + } + + created, err := s.TripInvite.Create(ctx, invite) + if err != nil { + if errors.Is(err, errs.ErrDuplicate) { + code, err = generateInviteCode() + if err != nil { + return nil, err + } + invite.Code = code + created, err = s.TripInvite.Create(ctx, invite) + } + if err != nil { + return nil, err + } + } + + var joinURL *string + baseURL := os.Getenv("APP_PUBLIC_URL") + if baseURL != "" { + trimmed := strings.TrimRight(baseURL, "/") + u := trimmed + "/invites/" + created.Code + joinURL = &u + } + + return &models.TripInviteAPIResponse{ + ID: created.ID, + TripID: created.TripID, + CreatedBy: created.CreatedBy, + Code: created.Code, + ExpiresAt: created.ExpiresAt, + IsRevoked: created.IsRevoked, + CreatedAt: created.CreatedAt, + JoinURL: joinURL, + }, nil +} diff --git a/backend/internal/tests/invites_test.go b/backend/internal/tests/invites_test.go new file mode 100644 index 0000000..9b3b607 --- /dev/null +++ b/backend/internal/tests/invites_test.go @@ -0,0 +1,181 @@ +package tests + +import ( + "context" + "fmt" + "net/http" + "testing" + "time" + "toggo/internal/models" + testkit "toggo/internal/tests/testkit/builders" + "toggo/internal/tests/testkit/fakes" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" +) + +func TestTripInvitesJoinWorkflow(t *testing.T) { + t.Run("create invite and join trip", func(t *testing.T) { + app := fakes.GetSharedTestApp() + + owner := createUser(t, app) + trip := createTrip(t, app, owner) + + inviteResp := testkit.New(t). + Request(testkit.Request{ + App: app, + Route: fmt.Sprintf("/api/v1/trips/%s/invites", trip), + Method: testkit.POST, + UserID: &owner, + }). + AssertStatus(http.StatusCreated). + GetBody() + + code, ok := inviteResp["code"].(string) + require.True(t, ok, "expected code to be a string") + require.NotEmpty(t, code) + + joiner := createUser(t, app) + + testkit.New(t). + Request(testkit.Request{ + App: app, + Route: fmt.Sprintf("/api/v1/trip-invites/%s/join", code), + Method: testkit.POST, + UserID: &joiner, + }). + AssertStatus(http.StatusCreated). + AssertField("user_id", joiner). + AssertField("trip_id", trip). + AssertField("is_admin", false). + AssertField("budget_min", float64(0)). + AssertField("budget_max", float64(0)) + + // Verify the joiner can now access their membership (TripMemberRequired should pass) + testkit.New(t). + Request(testkit.Request{ + App: app, + Route: fmt.Sprintf("/api/v1/trips/%s/memberships/%s", trip, joiner), + Method: testkit.GET, + UserID: &joiner, + }). + AssertStatus(http.StatusOK). + AssertField("user_id", joiner). + AssertField("trip_id", trip) + + // Joining again should be idempotent (no error) + testkit.New(t). + Request(testkit.Request{ + App: app, + Route: fmt.Sprintf("/api/v1/trip-invites/%s/join", code), + Method: testkit.POST, + UserID: &joiner, + }). + AssertStatus(http.StatusCreated). + AssertField("user_id", joiner). + AssertField("trip_id", trip) + }) + + t.Run("invalid code returns 400", func(t *testing.T) { + app := fakes.GetSharedTestApp() + + user := createUser(t, app) + + testkit.New(t). + Request(testkit.Request{ + App: app, + Route: "/api/v1/trip-invites/does-not-exist/join", + Method: testkit.POST, + UserID: &user, + }). + AssertStatus(http.StatusBadRequest). + AssertField("message", "invalid invite code") + }) + + t.Run("expired invite returns 400", func(t *testing.T) { + app := fakes.GetSharedTestApp() + db := fakes.GetSharedDB() + + owner := createUser(t, app) + trip := createTrip(t, app, owner) + + code := "expired-" + uuid.NewString() + expired := time.Now().UTC().Add(-1 * time.Hour) + + invite := &models.TripInvite{ + ID: uuid.New(), + TripID: uuid.MustParse(trip), + CreatedBy: uuid.MustParse(owner), + Code: code, + ExpiresAt: expired, + IsRevoked: false, + CreatedAt: time.Now().UTC(), + } + + _, err := db.NewInsert().Model(invite).Exec(context.Background()) + require.NoError(t, err) + + user := createUser(t, app) + + testkit.New(t). + Request(testkit.Request{ + App: app, + Route: fmt.Sprintf("/api/v1/trip-invites/%s/join", code), + Method: testkit.POST, + UserID: &user, + }). + AssertStatus(http.StatusBadRequest). + AssertField("message", "invite link has expired") + }) + + t.Run("revoked invite returns 400", func(t *testing.T) { + app := fakes.GetSharedTestApp() + db := fakes.GetSharedDB() + + owner := createUser(t, app) + trip := createTrip(t, app, owner) + + code := "revoked-" + uuid.NewString() + future := time.Now().UTC().Add(24 * time.Hour) + + invite := &models.TripInvite{ + ID: uuid.New(), + TripID: uuid.MustParse(trip), + CreatedBy: uuid.MustParse(owner), + Code: code, + ExpiresAt: future, + IsRevoked: true, + CreatedAt: time.Now().UTC(), + } + + _, err := db.NewInsert().Model(invite).Exec(context.Background()) + require.NoError(t, err) + + user := createUser(t, app) + + testkit.New(t). + Request(testkit.Request{ + App: app, + Route: fmt.Sprintf("/api/v1/trip-invites/%s/join", code), + Method: testkit.POST, + UserID: &user, + }). + AssertStatus(http.StatusBadRequest). + AssertField("message", "invite link has expired") + }) + + t.Run("unauthenticated join returns 401", func(t *testing.T) { + app := fakes.GetSharedTestApp() + auth := false + + testkit.New(t). + Request(testkit.Request{ + App: app, + Route: "/api/v1/trip-invites/whatever/join", + Method: testkit.POST, + Auth: &auth, + }). + AssertStatus(http.StatusUnauthorized) + }) +} + diff --git a/backend/internal/tests/mocks/mock_S3PresignClient.go b/backend/internal/tests/mocks/mock_S3PresignClient.go index 6310cc4..711c8fc 100644 --- a/backend/internal/tests/mocks/mock_S3PresignClient.go +++ b/backend/internal/tests/mocks/mock_S3PresignClient.go @@ -7,7 +7,7 @@ package mocks import ( "context" - v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" + "github.com/aws/aws-sdk-go-v2/aws/signer/v4" "github.com/aws/aws-sdk-go-v2/service/s3" mock "github.com/stretchr/testify/mock" )