diff --git a/.env b/.env index 2b61116..6217abb 100644 --- a/.env +++ b/.env @@ -1,5 +1,6 @@ SERVER_HOST=localhost SERVER_PORT=8080 + DB_NAME=doit DB_USER=islamghany DB_PASSWORD=secret @@ -7,4 +8,13 @@ DB_DISABLE_TLS=true DB_MAX_OPEN_CONNS=11 DB_MAX_IDLE_CONNS=11 DB_HOST=localhost -DB_PORT=5432 \ No newline at end of file +DB_PORT=5432 + +JWT_SECRET=dpkow94v4LtqqDvvSuzXCuID11AtnAfLuXbdG8VG3mc= +JWT_REFRESH_SECRET=Do7Tb1NdL2E9Fk1FGWbYdK2nfTPFLKIyd7wG8FUQq0Q= +JWT_ACCESS_TOKEN_EXP=900 +JWT_REFRESH_TOKEN_EXP=604800 + +REDIS_ADDR=localhost:6379 +REDIS_PASSWORD= +REDIS_DB=0 diff --git a/api/api.go b/api/api.go index acc085c..2b41e09 100644 --- a/api/api.go +++ b/api/api.go @@ -43,7 +43,10 @@ func Run(ctx context.Context, logger *logger.Logger, cfg *config.Config) error { defer dbPool.Close() // Starting the HTTP server with graceful shutdown - srv := NewServer(logger, cfg, dbPool) + srv, err := NewServer(logger, cfg, dbPool) + if err != nil { + return fmt.Errorf("failed to create server: %w", err) + } server := &http.Server{ Addr: fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port), diff --git a/api/server.go b/api/server.go index 382ed87..b321f84 100644 --- a/api/server.go +++ b/api/server.go @@ -5,13 +5,20 @@ import ( "doit/internal/config" "doit/internal/middlewares" "doit/internal/service" + "doit/internal/token" "doit/internal/web" "doit/pkg/database" "doit/pkg/logger" "net/http" ) -func NewServer(logger *logger.Logger, cfg *config.Config, dbPool *database.Pool) http.Handler { +func NewServer(logger *logger.Logger, cfg *config.Config, dbPool *database.Pool) (http.Handler, error) { + + // Helpers + tokenMaker, err := token.NewJWTToken(cfg.JWT.Secret) + if err != nil { + return nil, err + } // Middlewares errorMiddleware := middlewares.ErrorMiddleware(logger) @@ -20,10 +27,13 @@ func NewServer(logger *logger.Logger, cfg *config.Config, dbPool *database.Pool) app := web.NewApp(errorMiddleware) // Services - authService := service.NewAuthService(dbPool, logger) + userService := service.NewUserService(dbPool) + tokenService := service.NewTokenService(dbPool, tokenMaker, + cfg.JWT.AccessTokenExp, + cfg.JWT.RefreshTokenExp) // Handlers - authHandler := auth.NewHandler(logger, authService, cfg) + authHandler := auth.NewHandler(logger, userService, tokenService, cfg) // Routes auth.RegisterRoutes(app, authHandler) @@ -32,5 +42,5 @@ func NewServer(logger *logger.Logger, cfg *config.Config, dbPool *database.Pool) return web.RespondOK(w, r, map[string]string{"status": "ok"}) }) - return app + return app, nil } diff --git a/api/v1/auth/handler.go b/api/v1/auth/handler.go index 89d0979..99da835 100644 --- a/api/v1/auth/handler.go +++ b/api/v1/auth/handler.go @@ -13,12 +13,13 @@ import ( type Handler struct { log *logger.Logger - authService *service.AuthService + userService *service.UserService + tokenService *service.TokenService config *config.Config } -func NewHandler(log *logger.Logger, authService *service.AuthService, config *config.Config) *Handler { - return &Handler{log: log, authService: authService, config: config} +func NewHandler(log *logger.Logger, userService *service.UserService, tokenService *service.TokenService, config *config.Config) *Handler { + return &Handler{log: log, userService: userService, tokenService: tokenService, config: config} } func (h *Handler) Login(w http.ResponseWriter, r *http.Request) error { @@ -30,7 +31,7 @@ func (h *Handler) Login(w http.ResponseWriter, r *http.Request) error { return web.NewError(err, http.StatusBadRequest) } - user, err := h.authService.AuthenticateUser(ctx, model.LoginInput{ + user, err := h.userService.AuthenticateUser(ctx, model.LoginInput{ Email: input.Email, Password: input.Password, }) @@ -39,7 +40,21 @@ func (h *Handler) Login(w http.ResponseWriter, r *http.Request) error { return h.handleAuthError(err) } - return web.RespondOK(w, r, user) + deviceInfo := model.DeviceInfo{ + IPAddress: web.GetClientIP(r), + UserAgent: web.GetUserAgent(r), + DeviceName: web.GetDeviceName(r), + } + + tokens, err := h.tokenService.CreateTokenPair(ctx, *user, deviceInfo) + if err != nil { + return web.NewError(err, http.StatusInternalServerError) + } + response := map[string]interface{}{ + "user": user, + "tokens": tokens, + } + return web.RespondOK(w, r, response) } // handleAuthError maps service errors to appropriate HTTP responses @@ -48,8 +63,8 @@ func (h *Handler) handleAuthError( err error) error { switch { case errors.Is(err, service.ErrInvalidCredentials): return web.NewError(errors.New("invalid credentials"), http.StatusUnauthorized) - case errors.Is(err, service.ErrUserInactive): - return web.NewError(errors.New("user account is inactive"), http.StatusForbidden) + // case errors.Is(err, service.ErrUserInactive): + // return web.NewError(errors.New("user account is inactive"), http.StatusForbidden) default: return web.NewError(errors.New("internal server error"), http.StatusInternalServerError) } diff --git a/cmd/doit/main.go b/cmd/doit/main.go index c0167ed..b1fc171 100644 --- a/cmd/doit/main.go +++ b/cmd/doit/main.go @@ -17,7 +17,7 @@ var version = "v0.0.1" func main() { // Create a context for the application. ctx := context.Background() - + // Load configuration cfg, err := config.LoadConfig() if err != nil { diff --git a/go.mod b/go.mod index c8a2c14..2b64c64 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/go-playground/locales v0.14.1 github.com/go-playground/universal-translator v0.18.1 github.com/go-playground/validator/v10 v10.28.0 + github.com/golang-jwt/jwt/v5 v5.3.0 github.com/google/uuid v1.6.0 github.com/islamghany/enfl v1.0.0 github.com/jackc/pgx/v5 v5.7.2 diff --git a/go.sum b/go.sum index 059c1be..e725cdc 100644 --- a/go.sum +++ b/go.sum @@ -12,6 +12,8 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= github.com/go-playground/validator/v10 v10.28.0 h1:Q7ibns33JjyW48gHkuFT91qX48KG0ktULL6FgHdG688= github.com/go-playground/validator/v10 v10.28.0/go.mod h1:GoI6I1SjPBh9p7ykNE/yj3fFYbyDOpwMn5KXd+m2hUU= +github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= +github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/islamghany/enfl v1.0.0 h1:mR77orybPguSo/WZL4Ma7D/vyioD0r/Qv/rRmU1xRfY= diff --git a/internal/config/config.go b/internal/config/config.go index 4d58ae6..5dc279d 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -44,6 +44,7 @@ type RedisConfig struct { type JWTConfig struct { Secret string `env:"SECRET" flag:"jwt_secret" required:"true"` + RefreshSecret string `env:"REFRESH_SECRET" flag:"refresh_secret" required:"true"` AccessTokenExp int `env:"ACCESS_TOKEN_EXP" flag:"access_token_exp" required:"true"` RefreshTokenExp int `env:"REFRESH_TOKEN_EXP" flag:"refresh_token_exp" required:"true"` } @@ -65,7 +66,7 @@ type Config struct { Server ServerConfig `prefix:"SERVER_"` App AppConfig `prefix:"APP_"` Database DatabaseConfig `prefix:"DB_"` - // JWT JWTConfig `prefix:"JWT_"` + JWT JWTConfig `prefix:"JWT_"` // Redis RedisConfig `prefix:"REDIS_"` } diff --git a/internal/data/db/models.go b/internal/data/db/models.go index 921dd7b..bdb823b 100644 --- a/internal/data/db/models.go +++ b/internal/data/db/models.go @@ -6,6 +6,7 @@ package db import ( "database/sql/driver" + "encoding/json" "fmt" "time" @@ -101,6 +102,17 @@ func (ns NullTodoStatus) Value() (driver.Value, error) { return string(ns.TodoStatus), nil } +type RefreshToken struct { + ID uuid.UUID `db:"id" json:"id"` + UserID uuid.UUID `db:"user_id" json:"user_id"` + TokenHash string `db:"token_hash" json:"token_hash"` + ExpiresAt pgtype.Timestamp `db:"expires_at" json:"expires_at"` + IsRevoked bool `db:"is_revoked" json:"is_revoked"` + CreatedAt pgtype.Timestamp `db:"created_at" json:"created_at"` + LastUsedAt pgtype.Timestamp `db:"last_used_at" json:"last_used_at"` + DeviceInfo json.RawMessage `db:"device_info" json:"{{.Name}}"` +} + type Todo struct { ID uuid.UUID `db:"id" json:"id"` UserID uuid.UUID `db:"user_id" json:"user_id"` @@ -127,4 +139,5 @@ type User struct { LastLoginAt pgtype.Timestamptz `db:"last_login_at" json:"last_login_at"` CreatedAt time.Time `db:"created_at" json:"created_at"` UpdatedAt time.Time `db:"updated_at" json:"updated_at"` + TokenVersion *int32 `db:"token_version" json:"token_version"` } diff --git a/internal/data/db/querier.go b/internal/data/db/querier.go index 2698429..b3f9fe9 100644 --- a/internal/data/db/querier.go +++ b/internal/data/db/querier.go @@ -13,28 +13,39 @@ import ( type Querier interface { BulkUpdateTodoStatus(ctx context.Context, arg BulkUpdateTodoStatusParams) error BulkUpdateUsersMetadata(ctx context.Context, arg BulkUpdateUsersMetadataParams) error + CleanupExpiredTokens(ctx context.Context) error CompleteTodo(ctx context.Context, id uuid.UUID) (Todo, error) CountUserTodos(ctx context.Context, userID uuid.UUID) (int64, error) CountUserTodosByStatus(ctx context.Context, arg CountUserTodosByStatusParams) (int64, error) CountUsers(ctx context.Context) (int64, error) + CreateRefreshToken(ctx context.Context, arg CreateRefreshTokenParams) (RefreshToken, error) CreateTodo(ctx context.Context, arg CreateTodoParams) (Todo, error) CreateUser(ctx context.Context, arg CreateUserParams) (User, error) GetOverdueTodos(ctx context.Context, limit int32) ([]GetOverdueTodosRow, error) + GetRefreshToken(ctx context.Context, tokenHash string) (RefreshToken, error) + // For security checks (detect reuse) + GetRefreshTokenIncludingRevoked(ctx context.Context, tokenHash string) (RefreshToken, error) GetTodoByID(ctx context.Context, id uuid.UUID) (Todo, error) // Use FOR UPDATE to lock the row for updates (prevents race conditions) GetTodoByIDForUpdate(ctx context.Context, id uuid.UUID) (Todo, error) GetTodoStats(ctx context.Context, userID uuid.UUID) (GetTodoStatsRow, error) GetTodosByTags(ctx context.Context, arg GetTodosByTagsParams) ([]Todo, error) + GetUserActiveRefreshTokens(ctx context.Context, userID uuid.UUID) ([]GetUserActiveRefreshTokensRow, error) GetUserByEmail(ctx context.Context, email string) (User, error) GetUserByID(ctx context.Context, id uuid.UUID) (User, error) GetUserByUsername(ctx context.Context, username string) (User, error) + GetUserTokenVersion(ctx context.Context, id uuid.UUID) (*int32, error) HardDeleteTodo(ctx context.Context, arg HardDeleteTodoParams) error HardDeleteTodos(ctx context.Context, arg HardDeleteTodosParams) error + IncrementUserTokenVersion(ctx context.Context, id uuid.UUID) (*int32, error) ListTodosByUser(ctx context.Context, arg ListTodosByUserParams) ([]Todo, error) ListTodosByUserAndStatus(ctx context.Context, arg ListTodosByUserAndStatusParams) ([]Todo, error) ListUsers(ctx context.Context, arg ListUsersParams) ([]User, error) + RevokeAllUserRefreshTokens(ctx context.Context, userID uuid.UUID) error + RevokeRefreshToken(ctx context.Context, tokenHash string) error SearchTodosByTitle(ctx context.Context, arg SearchTodosByTitleParams) ([]Todo, error) SearchUsersByEmail(ctx context.Context, arg SearchUsersByEmailParams) ([]User, error) + UpdateRefreshTokenUsage(ctx context.Context, id uuid.UUID) error UpdateTodo(ctx context.Context, arg UpdateTodoParams) (Todo, error) UpdateUser(ctx context.Context, arg UpdateUserParams) (User, error) UpdateUserLastLogin(ctx context.Context, id uuid.UUID) error diff --git a/internal/data/db/token.sql.go b/internal/data/db/token.sql.go new file mode 100644 index 0000000..f2fdbe3 --- /dev/null +++ b/internal/data/db/token.sql.go @@ -0,0 +1,208 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 +// source: token.sql + +package db + +import ( + "context" + "encoding/json" + + "github.com/google/uuid" + "github.com/jackc/pgx/v5/pgtype" +) + +const cleanupExpiredTokens = `-- name: CleanupExpiredTokens :exec +DELETE FROM refresh_tokens +WHERE expires_at < NOW() - INTERVAL '30 days' +` + +func (q *Queries) CleanupExpiredTokens(ctx context.Context) error { + _, err := q.db.Exec(ctx, cleanupExpiredTokens) + return err +} + +const createRefreshToken = `-- name: CreateRefreshToken :one +INSERT INTO refresh_tokens ( + id, user_id, token_hash, expires_at, device_info +) VALUES ( + $1, $2, $3, $4, $5 +) RETURNING id, user_id, token_hash, expires_at, is_revoked, created_at, last_used_at, device_info +` + +type CreateRefreshTokenParams struct { + ID uuid.UUID `db:"id" json:"id"` + UserID uuid.UUID `db:"user_id" json:"user_id"` + TokenHash string `db:"token_hash" json:"token_hash"` + ExpiresAt pgtype.Timestamp `db:"expires_at" json:"expires_at"` + DeviceInfo json.RawMessage `db:"device_info" json:"{{.Name}}"` +} + +func (q *Queries) CreateRefreshToken(ctx context.Context, arg CreateRefreshTokenParams) (RefreshToken, error) { + row := q.db.QueryRow(ctx, createRefreshToken, + arg.ID, + arg.UserID, + arg.TokenHash, + arg.ExpiresAt, + arg.DeviceInfo, + ) + var i RefreshToken + err := row.Scan( + &i.ID, + &i.UserID, + &i.TokenHash, + &i.ExpiresAt, + &i.IsRevoked, + &i.CreatedAt, + &i.LastUsedAt, + &i.DeviceInfo, + ) + return i, err +} + +const getRefreshToken = `-- name: GetRefreshToken :one +SELECT id, user_id, token_hash, expires_at, is_revoked, created_at, last_used_at, device_info FROM refresh_tokens +WHERE token_hash = $1 + AND is_revoked = FALSE + AND expires_at > NOW() +LIMIT 1 +` + +func (q *Queries) GetRefreshToken(ctx context.Context, tokenHash string) (RefreshToken, error) { + row := q.db.QueryRow(ctx, getRefreshToken, tokenHash) + var i RefreshToken + err := row.Scan( + &i.ID, + &i.UserID, + &i.TokenHash, + &i.ExpiresAt, + &i.IsRevoked, + &i.CreatedAt, + &i.LastUsedAt, + &i.DeviceInfo, + ) + return i, err +} + +const getRefreshTokenIncludingRevoked = `-- name: GetRefreshTokenIncludingRevoked :one +SELECT id, user_id, token_hash, expires_at, is_revoked, created_at, last_used_at, device_info FROM refresh_tokens +WHERE token_hash = $1 +LIMIT 1 +` + +// For security checks (detect reuse) +func (q *Queries) GetRefreshTokenIncludingRevoked(ctx context.Context, tokenHash string) (RefreshToken, error) { + row := q.db.QueryRow(ctx, getRefreshTokenIncludingRevoked, tokenHash) + var i RefreshToken + err := row.Scan( + &i.ID, + &i.UserID, + &i.TokenHash, + &i.ExpiresAt, + &i.IsRevoked, + &i.CreatedAt, + &i.LastUsedAt, + &i.DeviceInfo, + ) + return i, err +} + +const getUserActiveRefreshTokens = `-- name: GetUserActiveRefreshTokens :many +SELECT id, created_at, last_used_at, device_info +FROM refresh_tokens +WHERE user_id = $1 + AND is_revoked = FALSE + AND expires_at > NOW() +ORDER BY last_used_at DESC +` + +type GetUserActiveRefreshTokensRow struct { + ID uuid.UUID `db:"id" json:"id"` + CreatedAt pgtype.Timestamp `db:"created_at" json:"created_at"` + LastUsedAt pgtype.Timestamp `db:"last_used_at" json:"last_used_at"` + DeviceInfo json.RawMessage `db:"device_info" json:"{{.Name}}"` +} + +func (q *Queries) GetUserActiveRefreshTokens(ctx context.Context, userID uuid.UUID) ([]GetUserActiveRefreshTokensRow, error) { + rows, err := q.db.Query(ctx, getUserActiveRefreshTokens, userID) + if err != nil { + return nil, err + } + defer rows.Close() + items := []GetUserActiveRefreshTokensRow{} + for rows.Next() { + var i GetUserActiveRefreshTokensRow + if err := rows.Scan( + &i.ID, + &i.CreatedAt, + &i.LastUsedAt, + &i.DeviceInfo, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getUserTokenVersion = `-- name: GetUserTokenVersion :one +SELECT token_version FROM users WHERE id = $1 +` + +func (q *Queries) GetUserTokenVersion(ctx context.Context, id uuid.UUID) (*int32, error) { + row := q.db.QueryRow(ctx, getUserTokenVersion, id) + var token_version *int32 + err := row.Scan(&token_version) + return token_version, err +} + +const incrementUserTokenVersion = `-- name: IncrementUserTokenVersion :one +UPDATE users +SET token_version = token_version + 1 +WHERE id = $1 +RETURNING token_version +` + +func (q *Queries) IncrementUserTokenVersion(ctx context.Context, id uuid.UUID) (*int32, error) { + row := q.db.QueryRow(ctx, incrementUserTokenVersion, id) + var token_version *int32 + err := row.Scan(&token_version) + return token_version, err +} + +const revokeAllUserRefreshTokens = `-- name: RevokeAllUserRefreshTokens :exec +UPDATE refresh_tokens +SET is_revoked = TRUE +WHERE user_id = $1 AND is_revoked = FALSE +` + +func (q *Queries) RevokeAllUserRefreshTokens(ctx context.Context, userID uuid.UUID) error { + _, err := q.db.Exec(ctx, revokeAllUserRefreshTokens, userID) + return err +} + +const revokeRefreshToken = `-- name: RevokeRefreshToken :exec +UPDATE refresh_tokens +SET is_revoked = TRUE +WHERE token_hash = $1 +` + +func (q *Queries) RevokeRefreshToken(ctx context.Context, tokenHash string) error { + _, err := q.db.Exec(ctx, revokeRefreshToken, tokenHash) + return err +} + +const updateRefreshTokenUsage = `-- name: UpdateRefreshTokenUsage :exec +UPDATE refresh_tokens +SET last_used_at = NOW() +WHERE id = $1 +` + +func (q *Queries) UpdateRefreshTokenUsage(ctx context.Context, id uuid.UUID) error { + _, err := q.db.Exec(ctx, updateRefreshTokenUsage, id) + return err +} diff --git a/internal/data/db/users.sql.go b/internal/data/db/users.sql.go index 441516e..a35be6f 100644 --- a/internal/data/db/users.sql.go +++ b/internal/data/db/users.sql.go @@ -49,7 +49,7 @@ INSERT INTO users ( metadata ) VALUES ($1, $2, $3, $4, $5) -RETURNING id, email, username, password_hash, email_verified, is_active, metadata, last_login_at, created_at, updated_at +RETURNING id, email, username, password_hash, email_verified, is_active, metadata, last_login_at, created_at, updated_at, token_version ` type CreateUserParams struct { @@ -80,12 +80,13 @@ func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (User, e &i.LastLoginAt, &i.CreatedAt, &i.UpdatedAt, + &i.TokenVersion, ) return i, err } const getUserByEmail = `-- name: GetUserByEmail :one -SELECT id, email, username, password_hash, email_verified, is_active, metadata, last_login_at, created_at, updated_at +SELECT id, email, username, password_hash, email_verified, is_active, metadata, last_login_at, created_at, updated_at, token_version FROM users WHERE email = $1 ` @@ -104,12 +105,13 @@ func (q *Queries) GetUserByEmail(ctx context.Context, email string) (User, error &i.LastLoginAt, &i.CreatedAt, &i.UpdatedAt, + &i.TokenVersion, ) return i, err } const getUserByID = `-- name: GetUserByID :one -SELECT id, email, username, password_hash, email_verified, is_active, metadata, last_login_at, created_at, updated_at +SELECT id, email, username, password_hash, email_verified, is_active, metadata, last_login_at, created_at, updated_at, token_version FROM users WHERE id = $1 ` @@ -128,12 +130,13 @@ func (q *Queries) GetUserByID(ctx context.Context, id uuid.UUID) (User, error) { &i.LastLoginAt, &i.CreatedAt, &i.UpdatedAt, + &i.TokenVersion, ) return i, err } const getUserByUsername = `-- name: GetUserByUsername :one -SELECT id, email, username, password_hash, email_verified, is_active, metadata, last_login_at, created_at, updated_at +SELECT id, email, username, password_hash, email_verified, is_active, metadata, last_login_at, created_at, updated_at, token_version FROM users WHERE username = $1 ` @@ -152,12 +155,13 @@ func (q *Queries) GetUserByUsername(ctx context.Context, username string) (User, &i.LastLoginAt, &i.CreatedAt, &i.UpdatedAt, + &i.TokenVersion, ) return i, err } const listUsers = `-- name: ListUsers :many -SELECT id, email, username, password_hash, email_verified, is_active, metadata, last_login_at, created_at, updated_at +SELECT id, email, username, password_hash, email_verified, is_active, metadata, last_login_at, created_at, updated_at, token_version FROM users ORDER BY created_at DESC LIMIT $1 OFFSET $2 @@ -188,6 +192,7 @@ func (q *Queries) ListUsers(ctx context.Context, arg ListUsersParams) ([]User, e &i.LastLoginAt, &i.CreatedAt, &i.UpdatedAt, + &i.TokenVersion, ); err != nil { return nil, err } @@ -200,7 +205,7 @@ func (q *Queries) ListUsers(ctx context.Context, arg ListUsersParams) ([]User, e } const searchUsersByEmail = `-- name: SearchUsersByEmail :many -SELECT id, email, username, password_hash, email_verified, is_active, metadata, last_login_at, created_at, updated_at +SELECT id, email, username, password_hash, email_verified, is_active, metadata, last_login_at, created_at, updated_at, token_version FROM users WHERE email ILIKE '%' || $1 || '%' ORDER BY created_at DESC @@ -232,6 +237,7 @@ func (q *Queries) SearchUsersByEmail(ctx context.Context, arg SearchUsersByEmail &i.LastLoginAt, &i.CreatedAt, &i.UpdatedAt, + &i.TokenVersion, ); err != nil { return nil, err } @@ -250,7 +256,7 @@ SET email = COALESCE($2, email), is_active = COALESCE($4, is_active), metadata = COALESCE($5, metadata) WHERE id = $1 -RETURNING id, email, username, password_hash, email_verified, is_active, metadata, last_login_at, created_at, updated_at +RETURNING id, email, username, password_hash, email_verified, is_active, metadata, last_login_at, created_at, updated_at, token_version ` type UpdateUserParams struct { @@ -281,6 +287,7 @@ func (q *Queries) UpdateUser(ctx context.Context, arg UpdateUserParams) (User, e &i.LastLoginAt, &i.CreatedAt, &i.UpdatedAt, + &i.TokenVersion, ) return i, err } diff --git a/internal/data/migrations/000003_create_refresh_tokens_table.down.sql b/internal/data/migrations/000003_create_refresh_tokens_table.down.sql new file mode 100644 index 0000000..35bc265 --- /dev/null +++ b/internal/data/migrations/000003_create_refresh_tokens_table.down.sql @@ -0,0 +1,11 @@ +-- Remove from users table +ALTER TABLE users DROP COLUMN token_version; +DROP INDEX idx_users_token_version; + +-- Remove refresh tokens table +DROP TABLE refresh_tokens; + +-- Remove indexes for refresh tokens +DROP INDEX idx_refresh_tokens_user; +DROP INDEX idx_refresh_tokens_lookup; +DROP INDEX idx_refresh_tokens_cleanup; \ No newline at end of file diff --git a/internal/data/migrations/000003_create_refresh_tokens_table.up.sql b/internal/data/migrations/000003_create_refresh_tokens_table.up.sql new file mode 100644 index 0000000..d027315 --- /dev/null +++ b/internal/data/migrations/000003_create_refresh_tokens_table.up.sql @@ -0,0 +1,20 @@ +-- Add to users table +ALTER TABLE users ADD COLUMN token_version INTEGER DEFAULT 1; +CREATE INDEX idx_users_token_version ON users(id, token_version); + +-- Create refresh tokens table +CREATE TABLE refresh_tokens ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, + token_hash VARCHAR(64) NOT NULL UNIQUE, + expires_at TIMESTAMP NOT NULL, + is_revoked BOOLEAN DEFAULT FALSE NOT NULL, + created_at TIMESTAMP DEFAULT NOW() NOT NULL, + last_used_at TIMESTAMP DEFAULT NOW() NOT NULL, + device_info JSONB DEFAULT '{}' NOT NULL +); + +-- Create indexes for refresh tokens +CREATE INDEX idx_refresh_tokens_user ON refresh_tokens(user_id, is_revoked, expires_at); +CREATE INDEX idx_refresh_tokens_lookup ON refresh_tokens(token_hash) WHERE is_revoked = FALSE; +CREATE INDEX idx_refresh_tokens_cleanup ON refresh_tokens(expires_at) WHERE is_revoked = FALSE; \ No newline at end of file diff --git a/internal/data/queries/token.sql b/internal/data/queries/token.sql new file mode 100644 index 0000000..a9c4e54 --- /dev/null +++ b/internal/data/queries/token.sql @@ -0,0 +1,56 @@ +-- name: CreateRefreshToken :one +INSERT INTO refresh_tokens ( + id, user_id, token_hash, expires_at, device_info +) VALUES ( + $1, $2, $3, $4, $5 +) RETURNING *; + +-- name: GetRefreshToken :one +SELECT * FROM refresh_tokens +WHERE token_hash = $1 + AND is_revoked = FALSE + AND expires_at > NOW() +LIMIT 1; + +-- name: GetRefreshTokenIncludingRevoked :one +-- For security checks (detect reuse) +SELECT * FROM refresh_tokens +WHERE token_hash = $1 +LIMIT 1; + + +-- name: UpdateRefreshTokenUsage :exec +UPDATE refresh_tokens +SET last_used_at = NOW() +WHERE id = $1; + +-- name: RevokeRefreshToken :exec +UPDATE refresh_tokens +SET is_revoked = TRUE +WHERE token_hash = $1; + +-- name: RevokeAllUserRefreshTokens :exec +UPDATE refresh_tokens +SET is_revoked = TRUE +WHERE user_id = $1 AND is_revoked = FALSE; + +-- name: GetUserActiveRefreshTokens :many +SELECT id, created_at, last_used_at, device_info +FROM refresh_tokens +WHERE user_id = $1 + AND is_revoked = FALSE + AND expires_at > NOW() +ORDER BY last_used_at DESC; + +-- name: IncrementUserTokenVersion :one +UPDATE users +SET token_version = token_version + 1 +WHERE id = $1 +RETURNING token_version; + +-- name: GetUserTokenVersion :one +SELECT token_version FROM users WHERE id = $1; + +-- name: CleanupExpiredTokens :exec +DELETE FROM refresh_tokens +WHERE expires_at < NOW() - INTERVAL '30 days'; \ No newline at end of file diff --git a/internal/model/token.go b/internal/model/token.go new file mode 100644 index 0000000..38aaf3b --- /dev/null +++ b/internal/model/token.go @@ -0,0 +1,32 @@ +package model + +import ( + "time" + + "github.com/google/uuid" +) + +// TokenPair represents access and refresh tokens +type TokenPair struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` // seconds +} + +// DeviceInfo represents client device information +type DeviceInfo struct { + IPAddress string `json:"ip_address"` + UserAgent string `json:"user_agent"` + DeviceName string `json:"device_name,omitempty"` + Location string `json:"location,omitempty"` +} + +// Session represents an active user session +type Session struct { + ID uuid.UUID `json:"id"` + CreatedAt time.Time `json:"created_at"` + LastUsedAt time.Time `json:"last_used_at"` + DeviceInfo map[string]interface{} `json:"device_info"` + IsCurrent bool `json:"is_current"` +} \ No newline at end of file diff --git a/internal/model/user.go b/internal/model/user.go index ff643c2..fbe0222 100644 --- a/internal/model/user.go +++ b/internal/model/user.go @@ -17,6 +17,7 @@ type User struct { LastLoginAt *time.Time `json:"last_login_at,omitempty"` CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` + TokenVersion int32 `json:"token_version,omitempty"` } func (u *User) IsUserActive() bool { diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go deleted file mode 100644 index 7f2db12..0000000 --- a/internal/service/auth_service.go +++ /dev/null @@ -1,59 +0,0 @@ -package service - -import ( - "context" - "doit/internal/data/db" - "doit/internal/model" - "doit/pkg/database" - "doit/pkg/logger" - "errors" - "fmt" - - "github.com/jackc/pgx/v5" - "golang.org/x/crypto/bcrypt" -) - -type AuthService struct { - pool *database.Pool - querier db.Querier - log *logger.Logger -} - -// Sentinel errors for auth service -var ( - ErrUserNotFound = errors.New("user not found") - ErrInvalidCredentials = errors.New("invalid credentials") - ErrUserInactive = errors.New("user is inactive") -) - -func NewAuthService(pool *database.Pool, log *logger.Logger) *AuthService { - return &AuthService{pool: pool, querier: db.New(pool), log: log} -} - -func (s *AuthService) AuthenticateUser(ctx context.Context, input model.LoginInput) (*model.User, error) { - user, err := s.querier.GetUserByEmail(ctx, input.Email) - if err != nil { - if err == pgx.ErrNoRows { - return nil, ErrInvalidCredentials // Don't reveal if user exists for security - } - return nil, fmt.Errorf("failed to get user: %w", err) - } - - // Check if user is active - if !user.IsActive { - return nil, ErrUserInactive - } - - // Verify password - if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(input.Password)); err != nil { - return nil, ErrInvalidCredentials - } - - // Update last login (non-critical, don't fail auth if this fails) - if err := s.querier.UpdateUserLastLogin(ctx, user.ID); err != nil { - // Log error but don't fail authentication - s.log.Error(ctx, "warning: failed to update last login", "error", err) - } - - return toUserModel(user), nil -} diff --git a/internal/service/mocks/mock_querier.go b/internal/service/mocks/mock_querier.go index 790d6a5..566af11 100644 --- a/internal/service/mocks/mock_querier.go +++ b/internal/service/mocks/mock_querier.go @@ -70,6 +70,20 @@ func (mr *MockQuerierMockRecorder) BulkUpdateUsersMetadata(ctx, arg any) *gomock return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BulkUpdateUsersMetadata", reflect.TypeOf((*MockQuerier)(nil).BulkUpdateUsersMetadata), ctx, arg) } +// CleanupExpiredTokens mocks base method. +func (m *MockQuerier) CleanupExpiredTokens(ctx context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CleanupExpiredTokens", ctx) + ret0, _ := ret[0].(error) + return ret0 +} + +// CleanupExpiredTokens indicates an expected call of CleanupExpiredTokens. +func (mr *MockQuerierMockRecorder) CleanupExpiredTokens(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupExpiredTokens", reflect.TypeOf((*MockQuerier)(nil).CleanupExpiredTokens), ctx) +} + // CompleteTodo mocks base method. func (m *MockQuerier) CompleteTodo(ctx context.Context, id uuid.UUID) (db.Todo, error) { m.ctrl.T.Helper() @@ -130,6 +144,21 @@ func (mr *MockQuerierMockRecorder) CountUsers(ctx any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountUsers", reflect.TypeOf((*MockQuerier)(nil).CountUsers), ctx) } +// CreateRefreshToken mocks base method. +func (m *MockQuerier) CreateRefreshToken(ctx context.Context, arg db.CreateRefreshTokenParams) (db.RefreshToken, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateRefreshToken", ctx, arg) + ret0, _ := ret[0].(db.RefreshToken) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateRefreshToken indicates an expected call of CreateRefreshToken. +func (mr *MockQuerierMockRecorder) CreateRefreshToken(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateRefreshToken", reflect.TypeOf((*MockQuerier)(nil).CreateRefreshToken), ctx, arg) +} + // CreateTodo mocks base method. func (m *MockQuerier) CreateTodo(ctx context.Context, arg db.CreateTodoParams) (db.Todo, error) { m.ctrl.T.Helper() @@ -175,6 +204,36 @@ func (mr *MockQuerierMockRecorder) GetOverdueTodos(ctx, limit any) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOverdueTodos", reflect.TypeOf((*MockQuerier)(nil).GetOverdueTodos), ctx, limit) } +// GetRefreshToken mocks base method. +func (m *MockQuerier) GetRefreshToken(ctx context.Context, tokenHash string) (db.RefreshToken, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetRefreshToken", ctx, tokenHash) + ret0, _ := ret[0].(db.RefreshToken) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetRefreshToken indicates an expected call of GetRefreshToken. +func (mr *MockQuerierMockRecorder) GetRefreshToken(ctx, tokenHash any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRefreshToken", reflect.TypeOf((*MockQuerier)(nil).GetRefreshToken), ctx, tokenHash) +} + +// GetRefreshTokenIncludingRevoked mocks base method. +func (m *MockQuerier) GetRefreshTokenIncludingRevoked(ctx context.Context, tokenHash string) (db.RefreshToken, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetRefreshTokenIncludingRevoked", ctx, tokenHash) + ret0, _ := ret[0].(db.RefreshToken) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetRefreshTokenIncludingRevoked indicates an expected call of GetRefreshTokenIncludingRevoked. +func (mr *MockQuerierMockRecorder) GetRefreshTokenIncludingRevoked(ctx, tokenHash any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRefreshTokenIncludingRevoked", reflect.TypeOf((*MockQuerier)(nil).GetRefreshTokenIncludingRevoked), ctx, tokenHash) +} + // GetTodoByID mocks base method. func (m *MockQuerier) GetTodoByID(ctx context.Context, id uuid.UUID) (db.Todo, error) { m.ctrl.T.Helper() @@ -235,6 +294,21 @@ func (mr *MockQuerierMockRecorder) GetTodosByTags(ctx, arg any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTodosByTags", reflect.TypeOf((*MockQuerier)(nil).GetTodosByTags), ctx, arg) } +// GetUserActiveRefreshTokens mocks base method. +func (m *MockQuerier) GetUserActiveRefreshTokens(ctx context.Context, userID uuid.UUID) ([]db.GetUserActiveRefreshTokensRow, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetUserActiveRefreshTokens", ctx, userID) + ret0, _ := ret[0].([]db.GetUserActiveRefreshTokensRow) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetUserActiveRefreshTokens indicates an expected call of GetUserActiveRefreshTokens. +func (mr *MockQuerierMockRecorder) GetUserActiveRefreshTokens(ctx, userID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserActiveRefreshTokens", reflect.TypeOf((*MockQuerier)(nil).GetUserActiveRefreshTokens), ctx, userID) +} + // GetUserByEmail mocks base method. func (m *MockQuerier) GetUserByEmail(ctx context.Context, email string) (db.User, error) { m.ctrl.T.Helper() @@ -280,6 +354,21 @@ func (mr *MockQuerierMockRecorder) GetUserByUsername(ctx, username any) *gomock. return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserByUsername", reflect.TypeOf((*MockQuerier)(nil).GetUserByUsername), ctx, username) } +// GetUserTokenVersion mocks base method. +func (m *MockQuerier) GetUserTokenVersion(ctx context.Context, id uuid.UUID) (*int32, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetUserTokenVersion", ctx, id) + ret0, _ := ret[0].(*int32) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetUserTokenVersion indicates an expected call of GetUserTokenVersion. +func (mr *MockQuerierMockRecorder) GetUserTokenVersion(ctx, id any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserTokenVersion", reflect.TypeOf((*MockQuerier)(nil).GetUserTokenVersion), ctx, id) +} + // HardDeleteTodo mocks base method. func (m *MockQuerier) HardDeleteTodo(ctx context.Context, arg db.HardDeleteTodoParams) error { m.ctrl.T.Helper() @@ -308,6 +397,21 @@ func (mr *MockQuerierMockRecorder) HardDeleteTodos(ctx, arg any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HardDeleteTodos", reflect.TypeOf((*MockQuerier)(nil).HardDeleteTodos), ctx, arg) } +// IncrementUserTokenVersion mocks base method. +func (m *MockQuerier) IncrementUserTokenVersion(ctx context.Context, id uuid.UUID) (*int32, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IncrementUserTokenVersion", ctx, id) + ret0, _ := ret[0].(*int32) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// IncrementUserTokenVersion indicates an expected call of IncrementUserTokenVersion. +func (mr *MockQuerierMockRecorder) IncrementUserTokenVersion(ctx, id any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IncrementUserTokenVersion", reflect.TypeOf((*MockQuerier)(nil).IncrementUserTokenVersion), ctx, id) +} + // ListTodosByUser mocks base method. func (m *MockQuerier) ListTodosByUser(ctx context.Context, arg db.ListTodosByUserParams) ([]db.Todo, error) { m.ctrl.T.Helper() @@ -353,6 +457,34 @@ func (mr *MockQuerierMockRecorder) ListUsers(ctx, arg any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListUsers", reflect.TypeOf((*MockQuerier)(nil).ListUsers), ctx, arg) } +// RevokeAllUserRefreshTokens mocks base method. +func (m *MockQuerier) RevokeAllUserRefreshTokens(ctx context.Context, userID uuid.UUID) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RevokeAllUserRefreshTokens", ctx, userID) + ret0, _ := ret[0].(error) + return ret0 +} + +// RevokeAllUserRefreshTokens indicates an expected call of RevokeAllUserRefreshTokens. +func (mr *MockQuerierMockRecorder) RevokeAllUserRefreshTokens(ctx, userID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RevokeAllUserRefreshTokens", reflect.TypeOf((*MockQuerier)(nil).RevokeAllUserRefreshTokens), ctx, userID) +} + +// RevokeRefreshToken mocks base method. +func (m *MockQuerier) RevokeRefreshToken(ctx context.Context, tokenHash string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RevokeRefreshToken", ctx, tokenHash) + ret0, _ := ret[0].(error) + return ret0 +} + +// RevokeRefreshToken indicates an expected call of RevokeRefreshToken. +func (mr *MockQuerierMockRecorder) RevokeRefreshToken(ctx, tokenHash any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RevokeRefreshToken", reflect.TypeOf((*MockQuerier)(nil).RevokeRefreshToken), ctx, tokenHash) +} + // SearchTodosByTitle mocks base method. func (m *MockQuerier) SearchTodosByTitle(ctx context.Context, arg db.SearchTodosByTitleParams) ([]db.Todo, error) { m.ctrl.T.Helper() @@ -383,6 +515,20 @@ func (mr *MockQuerierMockRecorder) SearchUsersByEmail(ctx, arg any) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SearchUsersByEmail", reflect.TypeOf((*MockQuerier)(nil).SearchUsersByEmail), ctx, arg) } +// UpdateRefreshTokenUsage mocks base method. +func (m *MockQuerier) UpdateRefreshTokenUsage(ctx context.Context, id uuid.UUID) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateRefreshTokenUsage", ctx, id) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateRefreshTokenUsage indicates an expected call of UpdateRefreshTokenUsage. +func (mr *MockQuerierMockRecorder) UpdateRefreshTokenUsage(ctx, id any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateRefreshTokenUsage", reflect.TypeOf((*MockQuerier)(nil).UpdateRefreshTokenUsage), ctx, id) +} + // UpdateTodo mocks base method. func (m *MockQuerier) UpdateTodo(ctx context.Context, arg db.UpdateTodoParams) (db.Todo, error) { m.ctrl.T.Helper() diff --git a/internal/service/token_service.go b/internal/service/token_service.go new file mode 100644 index 0000000..4ab95c9 --- /dev/null +++ b/internal/service/token_service.go @@ -0,0 +1,259 @@ +package service + +import ( + "context" + "crypto/sha256" + "doit/internal/data/db" + "doit/internal/model" + "doit/internal/token" + "doit/pkg/database" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "time" + + "github.com/google/uuid" + "github.com/jackc/pgx/v5/pgtype" +) + +// Sentinel errors for token service +var ( + ErrSecurityAlert = errors.New("security alert") +) + +// TokenService manages token lifecycle (creation, verification, revocation) +type TokenService struct { + tokenMaker token.TokenMaker + pool *database.Pool + querier db.Querier + accessTokenDuration int + refreshTokenDuration int +} + +func NewTokenService( + pool *database.Pool, + tokenMaker token.TokenMaker, + accessTokenDuration int, refreshTokenDuration int) *TokenService { + return &TokenService{ + tokenMaker: tokenMaker, + pool: pool, + querier: db.New(pool), + accessTokenDuration: accessTokenDuration, + refreshTokenDuration: refreshTokenDuration, + } +} + +// CreateTokenPair creates both access and refresh tokens for a user +func (s *TokenService) CreateTokenPair(ctx context.Context, user model.User, deviceInfo model.DeviceInfo) (*model.TokenPair, error) { + // 1. Create short-lived access token (stateless) + accessToken, _, err := s.tokenMaker.CreateToken(token.TokenParams{ + UserID: user.ID, + Email: user.Email, + Username: user.Username, + Version: int(user.TokenVersion), + Duration: time.Duration(s.accessTokenDuration) * time.Second, + }) + if err != nil { + return nil, fmt.Errorf("failed to create access token: %w", err) + } + + // 2. Create long-lived refresh token (stateful) + refreshTokenString, refreshPayload, err := s.tokenMaker.CreateToken(token.TokenParams{ + UserID: user.ID, + Email: user.Email, + Username: user.Username, + Version: int(user.TokenVersion), + Duration: time.Duration(s.refreshTokenDuration) * time.Second, + }) + if err != nil { + return nil, fmt.Errorf("failed to create refresh token: %w", err) + } + + // 3. Hash refresh token before storing (never store plaintext!) + tokenHash := hashToken(refreshTokenString) + + // 4. Prepare device info as JSON + deviceJSON, err := json.Marshal(deviceInfo) + if err != nil { + return nil, fmt.Errorf("marshal device info: %w", err) + } + + // 5. Store refresh token in database + _, err = s.querier.CreateRefreshToken(ctx, db.CreateRefreshTokenParams{ + ID: refreshPayload.ID, + UserID: user.ID, + TokenHash: tokenHash, + ExpiresAt: pgtype.Timestamp{Time: refreshPayload.ExpiredAt, Valid: true}, + DeviceInfo: deviceJSON, + }) + if err != nil { + return nil, fmt.Errorf("store refresh token: %w", err) + } + + return &model.TokenPair{ + AccessToken: accessToken, + RefreshToken: refreshTokenString, + TokenType: "Bearer", + ExpiresIn: s.accessTokenDuration, + }, nil +} + + +// VerifyAccessToken verifies an access token and returns the payload +// This is fast - only checks JWT signature + version (no DB lookup of token itself) +func (s *TokenService) VerifyAccessToken(ctx context.Context, tokenString string) (*token.Payload, error) { + // 1. Verify JWT signature and expiration + payload, err := s.tokenMaker.VerifyToken(tokenString) + if err != nil { + return nil, err + } + + // 2. Check token version (handles password changes, security events) + currentVersion, err := s.querier.GetUserTokenVersion(ctx, payload.UserID) + if err != nil { + return nil, fmt.Errorf("get token version: %w", err) + } + + // 3. Version mismatch = token invalidated (password changed, security event) + if payload.Version != int(*currentVersion) { + return nil, token.ErrInvalidToken + } + + return payload, nil +} + +// RefreshAccessToken exchanges a refresh token for a new access token +func (s *TokenService) RefreshAccessToken(ctx context.Context, refreshTokenString string) (*model.TokenPair, error) { +// 1. Verify refresh token JWT signature and expiration +payload, err := s.tokenMaker.VerifyToken(refreshTokenString) +if err != nil { + return nil, err +} + + + // 2. Hash the token to lookup in database + tokenHash := hashToken(refreshTokenString) + +// 3. Get refresh token from database (including revoked ones for security check) +storedToken, err := s.querier.GetRefreshTokenIncludingRevoked(ctx, tokenHash) +if err != nil { + return nil, fmt.Errorf("refresh token not found: %w", err) +} +// 4. SECURITY CHECK: Detect token reuse (revoked token being used) +if storedToken.IsRevoked { + // 🚨 CRITICAL SECURITY EVENT! + // Someone is trying to use a revoked token - possible theft/replay attack + // Log the security incident + fmt.Printf("🚨 SECURITY ALERT: Revoked token reuse detected!\n") + fmt.Printf(" User ID: %s\n", storedToken.UserID) + fmt.Printf(" Token ID: %s\n", storedToken.ID) + fmt.Printf(" Revoked At: %s\n", storedToken.LastUsedAt.Time) + fmt.Printf(" Time Since Revoke: %s\n", time.Since(storedToken.LastUsedAt.Time)) + + // Security Response: Revoke ALL user tokens immediately + err = s.RevokeAllUserTokens(ctx, storedToken.UserID) + if err != nil { + fmt.Printf("Failed to revoke all tokens: %v\n", err) + } + + // TODO: Send security alert email/notification to user + // s.notificationService.SendSecurityAlert(storedToken.UserID, ...) + + return nil, ErrSecurityAlert + } + // 5. Check if token is expired + if time.Now().After(storedToken.ExpiresAt.Time) { + return nil, token.ErrExpiredToken + } + // 6. Update last used timestamp (for session tracking) + _ = s.querier.UpdateRefreshTokenUsage(ctx, storedToken.ID) + + // 7. Get current user data (fresh from DB) + user, err := s.querier.GetUserByID(ctx, payload.UserID) + if err != nil { + return nil, fmt.Errorf("get user: %w", err) + } + + // 8. Check if user is still active (not banned/deleted) + // TODO: Add user status check if you have user.IsActive field + // if !user.IsActive { + // return nil, errors.New("user account disabled") + // } + + // 9. Create new access token with current user data and version + accessToken, _, err := s.tokenMaker.CreateToken(token.TokenParams{ + UserID: user.ID, + Email: user.Email, + Username: user.Username, + Version: int(*user.TokenVersion), + Duration: 15 * time.Minute, + }) + if err != nil { + return nil, fmt.Errorf("create access token: %w", err) + } + return &model.TokenPair{ + AccessToken: accessToken, + RefreshToken: refreshTokenString, + TokenType: "Bearer", + ExpiresIn: s.accessTokenDuration, + }, nil +} + +// Helper: Hash token using SHA256 before storage +func hashToken(token string) string { + hash := sha256.Sum256([]byte(token)) + return hex.EncodeToString(hash[:]) +} + + +// Logout revokes a specific refresh token (single device logout) +func (s *TokenService) Logout(ctx context.Context, refreshTokenString string) error { + tokenHash := hashToken(refreshTokenString) + return s.querier.RevokeRefreshToken(ctx, tokenHash) +} + + +// RevokeAllUserTokens revokes all tokens for a user (logout all devices) +// Use when: password change, security breach, admin action +func (s *TokenService) RevokeAllUserTokens(ctx context.Context, userID uuid.UUID) error { +// 1. Increment token version (invalidates ALL access tokens immediately) +_, err := s.querier.IncrementUserTokenVersion(ctx, userID) +if err != nil { + return fmt.Errorf("increment token version: %w", err) +} + +// 2. Revoke all refresh tokens (logout all devices) +err = s.querier.RevokeAllUserRefreshTokens(ctx, userID) +if err != nil { + return fmt.Errorf("revoke all refresh tokens: %w", err) +} + +return nil +} + +// GetUserSessions returns all active sessions for a user +func (s *TokenService) GetUserSessions(ctx context.Context, userID uuid.UUID, currentTokenID uuid.UUID) ([]model.Session, error) { + tokens, err := s.querier.GetUserActiveRefreshTokens(ctx, userID) + if err != nil { + return nil, fmt.Errorf("get active tokens: %w", err) + } + + sessions := make([]model.Session, len(tokens)) + for i, t := range tokens { + var deviceInfo map[string]interface{} + if len(t.DeviceInfo) > 0 { + _ = json.Unmarshal(t.DeviceInfo, &deviceInfo) + } + + sessions[i] = model.Session{ + ID: t.ID, + CreatedAt: t.CreatedAt.Time, + LastUsedAt: t.LastUsedAt.Time, + DeviceInfo: deviceInfo, + IsCurrent: t.ID == currentTokenID, + } + } + + return sessions, nil +} diff --git a/internal/service/user_service.go b/internal/service/user_service.go index a7867e7..c48de4a 100644 --- a/internal/service/user_service.go +++ b/internal/service/user_service.go @@ -18,6 +18,7 @@ import ( var ( ErrDuplicateEmail = errors.New("email already exists") ErrInvalidInput = errors.New("invalid input") + ErrInvalidCredentials = errors.New("invalid credentials") ) // UserService handles all user-related business logic @@ -128,14 +129,14 @@ func (s *UserService) AuthenticateUser(ctx context.Context, input model.LoginInp return nil, fmt.Errorf("failed to get user: %w", err) } - // Check if user is active - if !user.IsActive { - return nil, fmt.Errorf("user account is inactive") - } + // // Check if user is active + // if !user.IsActive { + // return nil, fmt.Errorf("user account is inactive") + // } // Verify password if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(input.Password)); err != nil { - return nil, fmt.Errorf("invalid credentials") + return nil, ErrInvalidCredentials } // Update last login time (non-critical, don't fail auth if this fails) @@ -281,6 +282,7 @@ func toUserModel(user db.User) *model.User { IsActive: user.IsActive, CreatedAt: user.CreatedAt, UpdatedAt: user.UpdatedAt, + TokenVersion: *user.TokenVersion, } // Handle nullable last login diff --git a/internal/token/jwt_token.go b/internal/token/jwt_token.go new file mode 100644 index 0000000..e0263b1 --- /dev/null +++ b/internal/token/jwt_token.go @@ -0,0 +1,124 @@ +package token + +import ( + "errors" + "fmt" + + "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" +) + +const ( + minSecretKeySize = 32 +) + +var ( + ErrInvalidSecretKey = errors.New("secret key must be at least 32 characters") + ErrInvalidSigningMethod = errors.New("invalid token signing method") +) + +// JWTToken is a JWT implementation of TokenService. +type JWTToken struct { + secretKey string +} + +// CustomClaims represents the structured JWT claims. +type CustomClaims struct { + ID string `json:"id"` // JTI - JWT ID + UserID string `json:"user_id"` // Subject user ID + Email string `json:"email"` // User email + Username string `json:"username"` // Username + Version int `json:"version"` // Token version + jwt.RegisteredClaims +} + +// NewJWTToken creates a new JWTToken with validation. +func NewJWTToken(secretKey string) (*JWTToken, error) { + if len(secretKey) < minSecretKeySize { + return nil, ErrInvalidSecretKey + } + + return &JWTToken{ + secretKey: secretKey, + }, nil +} + +// CreateToken creates a new JWT token with structured claims. +func (t *JWTToken) CreateToken(params TokenParams) (string, *Payload, error) { + payload := NewPayload(params) + + // Create structured claims + claims := CustomClaims{ + ID: payload.ID.String(), + UserID: payload.UserID.String(), + Email: payload.Email, + Username: payload.Username, + Version: payload.Version, + RegisteredClaims: jwt.RegisteredClaims{ + ID: payload.ID.String(), // JTI + Subject: payload.UserID.String(), + IssuedAt: jwt.NewNumericDate(payload.IssuedAt), + ExpiresAt: jwt.NewNumericDate(payload.ExpiredAt), + }, + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + + signedToken, err := token.SignedString([]byte(t.secretKey)) + if err != nil { + return "", nil, fmt.Errorf("failed to sign token: %w", err) + } + + return signedToken, payload, nil +} + +// VerifyToken verifies and parses a JWT token with proper validation. +func (t *JWTToken) VerifyToken(tokenString string) (*Payload, error) { + // Parse token with validation + token, err := jwt.ParseWithClaims(tokenString, &CustomClaims{}, func(token *jwt.Token) (interface{}, error) { + // Validate signing method to prevent algorithm confusion attacks + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("%w: %v", ErrInvalidSigningMethod, token.Header["alg"]) + } + return []byte(t.secretKey), nil + }) + + if err != nil { + return nil, fmt.Errorf("failed to parse token: %w", err) + } + + // Extract and validate claims + claims, ok := token.Claims.(*CustomClaims) + if !ok || !token.Valid { + return nil, ErrInvalidToken + } + + // Parse UUIDs with proper error handling + id, err := uuid.Parse(claims.ID) + if err != nil { + return nil, fmt.Errorf("invalid token ID: %w", err) + } + + userID, err := uuid.Parse(claims.UserID) + if err != nil { + return nil, fmt.Errorf("invalid user ID: %w", err) + } + + // Construct payload + payload := &Payload{ + ID: id, + UserID: userID, + Email: claims.Email, + Username: claims.Username, + Version: claims.Version, + IssuedAt: claims.IssuedAt.Time, + ExpiredAt: claims.ExpiresAt.Time, + } + + // Validate expiration + if err := payload.Valid(); err != nil { + return nil, err + } + + return payload, nil +} diff --git a/internal/token/jwt_token_test.go b/internal/token/jwt_token_test.go new file mode 100644 index 0000000..19e1d33 --- /dev/null +++ b/internal/token/jwt_token_test.go @@ -0,0 +1,452 @@ +package token + +import ( + "fmt" + "strings" + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewJWTToken(t *testing.T) { + tests := []struct { + name string + secretKey string + wantErr error + }{ + { + name: "valid secret key - exactly 32 characters", + secretKey: "12345678901234567890123456789012", + wantErr: nil, + }, + { + name: "valid secret key - more than 32 characters", + secretKey: "this_is_a_very_long_secret_key_that_is_secure", + wantErr: nil, + }, + { + name: "invalid secret key - empty", + secretKey: "", + wantErr: ErrInvalidSecretKey, + }, + { + name: "invalid secret key - too short (31 characters)", + secretKey: "1234567890123456789012345678901", + wantErr: ErrInvalidSecretKey, + }, + { + name: "invalid secret key - too short (16 characters)", + secretKey: "1234567890123456", + wantErr: ErrInvalidSecretKey, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + token, err := NewJWTToken(tt.secretKey) + + if tt.wantErr != nil { + assert.ErrorIs(t, err, tt.wantErr, "error should match expected") + assert.Nil(t, token, "token should be nil on error") + } else { + assert.NoError(t, err, "should not return error") + assert.NotNil(t, token, "token should not be nil") + assert.Equal(t, tt.secretKey, token.secretKey, "secret key should be stored") + } + }) + } +} + +func TestJWTToken_CreateToken(t *testing.T) { + secretKey := "test_secret_key_that_is_32_chars_long!!" + jwtToken, err := NewJWTToken(secretKey) + require.NoError(t, err) + + tests := []struct { + name string + params TokenParams + }{ + { + name: "creates token with standard parameters", + params: TokenParams{ + UserID: uuid.New(), + Email: "test@example.com", + Username: "testuser", + Version: 1, + Duration: time.Hour, + }, + }, + { + name: "creates token with different duration", + params: TokenParams{ + UserID: uuid.New(), + Email: "admin@example.com", + Username: "admin", + Version: 2, + Duration: 24 * time.Hour, + }, + }, + { + name: "creates token with version 0", + params: TokenParams{ + UserID: uuid.New(), + Email: "user@example.com", + Username: "user", + Version: 0, + Duration: 30 * time.Minute, + }, + }, + { + name: "creates token with special characters in email", + params: TokenParams{ + UserID: uuid.New(), + Email: "test+tag@example.co.uk", + Username: "special_user", + Version: 1, + Duration: time.Hour, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tokenString, payload, err := jwtToken.CreateToken(tt.params) + + require.NoError(t, err, "should not return error") + assert.NotEmpty(t, tokenString, "token string should not be empty") + assert.NotNil(t, payload, "payload should not be nil") + + // Verify payload fields + assert.NotEqual(t, uuid.Nil, payload.ID, "payload ID should not be nil") + assert.Equal(t, tt.params.UserID, payload.UserID, "user ID should match") + assert.Equal(t, tt.params.Email, payload.Email, "email should match") + assert.Equal(t, tt.params.Username, payload.Username, "username should match") + assert.Equal(t, tt.params.Version, payload.Version, "version should match") + + // Verify token string format (JWT has 3 parts separated by dots) + parts := strings.Split(tokenString, ".") + assert.Equal(t, 3, len(parts), "JWT should have 3 parts") + }) + } +} + +func TestJWTToken_VerifyToken(t *testing.T) { + secretKey := "test_secret_key_that_is_32_chars_long!!" + jwtToken, err := NewJWTToken(secretKey) + require.NoError(t, err) + + t.Run("successfully verifies valid token", func(t *testing.T) { + params := TokenParams{ + UserID: uuid.New(), + Email: "test@example.com", + Username: "testuser", + Version: 1, + Duration: time.Hour, + } + + tokenString, originalPayload, err := jwtToken.CreateToken(params) + require.NoError(t, err) + + verifiedPayload, err := jwtToken.VerifyToken(tokenString) + require.NoError(t, err) + assert.NotNil(t, verifiedPayload, "verified payload should not be nil") + + // Compare all fields + assert.Equal(t, originalPayload.ID, verifiedPayload.ID, "ID should match") + assert.Equal(t, originalPayload.UserID, verifiedPayload.UserID, "user ID should match") + assert.Equal(t, originalPayload.Email, verifiedPayload.Email, "email should match") + assert.Equal(t, originalPayload.Username, verifiedPayload.Username, "username should match") + assert.Equal(t, originalPayload.Version, verifiedPayload.Version, "version should match") + // JWT uses Unix timestamps which only have second precision, so compare times in seconds + assert.Equal(t, originalPayload.IssuedAt.Unix(), verifiedPayload.IssuedAt.Unix(), "issued at should match") + assert.Equal(t, originalPayload.ExpiredAt.Unix(), verifiedPayload.ExpiredAt.Unix(), "expired at should match") + }) + + t.Run("fails to verify token with wrong secret", func(t *testing.T) { + params := TokenParams{ + UserID: uuid.New(), + Email: "test@example.com", + Username: "testuser", + Version: 1, + Duration: time.Hour, + } + + tokenString, _, err := jwtToken.CreateToken(params) + require.NoError(t, err) + + // Create token service with different secret (must be at least 32 chars) + wrongSecretToken, err := NewJWTToken("different_secret_key_must_be_32_characters_long!") + require.NoError(t, err) + + payload, err := wrongSecretToken.VerifyToken(tokenString) + assert.Error(t, err, "should return error") + assert.Nil(t, payload, "payload should be nil") + }) + + t.Run("fails to verify expired token", func(t *testing.T) { + params := TokenParams{ + UserID: uuid.New(), + Email: "test@example.com", + Username: "testuser", + Version: 1, + Duration: -time.Hour, // Negative duration = already expired + } + + tokenString, _, err := jwtToken.CreateToken(params) + require.NoError(t, err) + + payload, err := jwtToken.VerifyToken(tokenString) + assert.Error(t, err, "should return error") + // Error is wrapped, so check if it contains token expiration error + assert.Contains(t, err.Error(), "token is expired", "should contain expired token error") + assert.Nil(t, payload, "payload should be nil") + }) + + t.Run("fails to verify malformed token", func(t *testing.T) { + tests := []struct { + name string + token string + }{ + {"empty token", ""}, + {"random string", "this_is_not_a_jwt_token"}, + {"incomplete JWT", "header.payload"}, + {"invalid base64", "not.valid.base64!!!"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + payload, err := jwtToken.VerifyToken(tt.token) + assert.Error(t, err, "should return error") + assert.Nil(t, payload, "payload should be nil") + }) + } + }) + + t.Run("fails to verify token with invalid UUID", func(t *testing.T) { + // Create a token with invalid UUID in claims + claims := CustomClaims{ + ID: "not-a-valid-uuid", + UserID: uuid.New().String(), + Email: "test@example.com", + Username: "testuser", + Version: 1, + RegisteredClaims: jwt.RegisteredClaims{ + ID: "not-a-valid-uuid", + IssuedAt: jwt.NewNumericDate(time.Now()), + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)), + }, + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + tokenString, err := token.SignedString([]byte(secretKey)) + require.NoError(t, err) + + payload, err := jwtToken.VerifyToken(tokenString) + assert.Error(t, err, "should return error") + assert.Nil(t, payload, "payload should be nil") + assert.Contains(t, err.Error(), "invalid token ID", "error should mention invalid token ID") + }) +} + +func TestJWTToken_AlgorithmValidation(t *testing.T) { + secretKey := "test_secret_key_that_is_32_chars_long!!" + jwtToken, err := NewJWTToken(secretKey) + require.NoError(t, err) + + t.Run("rejects token with different algorithm", func(t *testing.T) { + // Create a token with a different signing method (none) + claims := CustomClaims{ + ID: uuid.New().String(), + UserID: uuid.New().String(), + Email: "test@example.com", + Username: "testuser", + Version: 1, + RegisteredClaims: jwt.RegisteredClaims{ + ID: uuid.New().String(), + IssuedAt: jwt.NewNumericDate(time.Now()), + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)), + }, + } + + // Try to create a token with "none" algorithm (security vulnerability) + token := jwt.NewWithClaims(jwt.SigningMethodNone, claims) + tokenString, err := token.SignedString(jwt.UnsafeAllowNoneSignatureType) + require.NoError(t, err) + + payload, err := jwtToken.VerifyToken(tokenString) + assert.Error(t, err, "should reject token with different algorithm") + assert.Nil(t, payload, "payload should be nil") + }) +} + +func TestJWTToken_CreateAndVerifyRoundTrip(t *testing.T) { + secretKey := "test_secret_key_that_is_32_chars_long!!" + jwtToken, err := NewJWTToken(secretKey) + require.NoError(t, err) + + // Test multiple tokens to ensure consistency + numTokens := 10 + tokens := make(map[string]*Payload) + + // Create tokens + for i := 0; i < numTokens; i++ { + params := TokenParams{ + UserID: uuid.New(), + Email: fmt.Sprintf("user%d@example.com", i), + Username: fmt.Sprintf("user%d", i), + Version: i, + Duration: time.Hour, + } + + tokenString, payload, err := jwtToken.CreateToken(params) + require.NoError(t, err) + tokens[tokenString] = payload + } + + // Verify all tokens + for tokenString, originalPayload := range tokens { + verifiedPayload, err := jwtToken.VerifyToken(tokenString) + require.NoError(t, err, "should verify token successfully") + + assert.Equal(t, originalPayload.ID, verifiedPayload.ID) + assert.Equal(t, originalPayload.UserID, verifiedPayload.UserID) + assert.Equal(t, originalPayload.Email, verifiedPayload.Email) + assert.Equal(t, originalPayload.Username, verifiedPayload.Username) + assert.Equal(t, originalPayload.Version, verifiedPayload.Version) + } +} + +func TestJWTToken_ConcurrentOperations(t *testing.T) { + secretKey := "test_secret_key_that_is_32_chars_long!!" + jwtToken, err := NewJWTToken(secretKey) + require.NoError(t, err) + + // Test concurrent token creation and verification + numGoroutines := 50 + done := make(chan bool, numGoroutines) + + for i := 0; i < numGoroutines; i++ { + go func(idx int) { + params := TokenParams{ + UserID: uuid.New(), + Email: fmt.Sprintf("user%d@example.com", idx), + Username: fmt.Sprintf("user%d", idx), + Version: idx, + Duration: time.Hour, + } + + // Create token + tokenString, originalPayload, err := jwtToken.CreateToken(params) + assert.NoError(t, err) + + // Verify token + verifiedPayload, err := jwtToken.VerifyToken(tokenString) + assert.NoError(t, err) + assert.Equal(t, originalPayload.ID, verifiedPayload.ID) + + done <- true + }(i) + } + + // Wait for all goroutines to complete + for i := 0; i < numGoroutines; i++ { + <-done + } +} + +func TestJWTToken_TokenExpiration(t *testing.T) { + secretKey := "test_secret_key_that_is_32_chars_long!!" + jwtToken, err := NewJWTToken(secretKey) + require.NoError(t, err) + + t.Run("token expires after duration", func(t *testing.T) { + params := TokenParams{ + UserID: uuid.New(), + Email: "test@example.com", + Username: "testuser", + Version: 1, + Duration: 2 * time.Second, // Short duration but not too short + } + + tokenString, _, err := jwtToken.CreateToken(params) + require.NoError(t, err) + + // Verify immediately - should succeed + payload, err := jwtToken.VerifyToken(tokenString) + assert.NoError(t, err) + assert.NotNil(t, payload) + + // Wait for expiration (add buffer) + time.Sleep(3 * time.Second) + + // Verify after expiration - should fail + payload, err = jwtToken.VerifyToken(tokenString) + assert.Error(t, err, "should return error after expiration") + assert.Contains(t, err.Error(), "token is expired", "error should indicate token is expired") + assert.Nil(t, payload) + }) +} + +func TestJWTToken_RegisteredClaims(t *testing.T) { + secretKey := "test_secret_key_that_is_32_chars_long!!" + jwtToken, err := NewJWTToken(secretKey) + require.NoError(t, err) + + params := TokenParams{ + UserID: uuid.New(), + Email: "test@example.com", + Username: "testuser", + Version: 1, + Duration: time.Hour, + } + + tokenString, payload, err := jwtToken.CreateToken(params) + require.NoError(t, err) + + // Parse token to check registered claims + token, err := jwt.ParseWithClaims(tokenString, &CustomClaims{}, func(token *jwt.Token) (interface{}, error) { + return []byte(secretKey), nil + }) + require.NoError(t, err) + + claims, ok := token.Claims.(*CustomClaims) + require.True(t, ok) + + // Verify registered claims + assert.Equal(t, payload.ID.String(), claims.ID, "JTI should match payload ID") + assert.Equal(t, payload.UserID.String(), claims.Subject, "Subject should match user ID") + assert.NotNil(t, claims.IssuedAt, "IssuedAt should be set") + assert.NotNil(t, claims.ExpiresAt, "ExpiresAt should be set") +} + +func TestJWTToken_EmptyFields(t *testing.T) { + secretKey := "test_secret_key_that_is_32_chars_long!!" + jwtToken, err := NewJWTToken(secretKey) + require.NoError(t, err) + + t.Run("handles empty email and username", func(t *testing.T) { + params := TokenParams{ + UserID: uuid.New(), + Email: "", + Username: "", + Version: 1, + Duration: time.Hour, + } + + tokenString, originalPayload, err := jwtToken.CreateToken(params) + require.NoError(t, err) + + verifiedPayload, err := jwtToken.VerifyToken(tokenString) + require.NoError(t, err) + + assert.Equal(t, originalPayload.Email, verifiedPayload.Email) + assert.Equal(t, originalPayload.Username, verifiedPayload.Username) + assert.Empty(t, verifiedPayload.Email) + assert.Empty(t, verifiedPayload.Username) + }) +} + diff --git a/internal/token/token.go b/internal/token/token.go new file mode 100644 index 0000000..2ba1f3a --- /dev/null +++ b/internal/token/token.go @@ -0,0 +1,65 @@ +package token + +import ( + "errors" + "time" + + "github.com/google/uuid" +) + +var ( + ErrInvalidToken = errors.New("invalid token") + ErrExpiredToken = errors.New("expired token") +) + +// TokenParams represents the parameters for creating a token. +type TokenParams struct { + UserID uuid.UUID + Email string + Username string + Version int + Duration time.Duration +} + +// TokenMaker provides methods for managing tokens. +type TokenMaker interface { + CreateToken(params TokenParams) (string, *Payload, error) + VerifyToken(token string) (*Payload, error) +} + +// Payload represents the payload of a token. +type Payload struct { + ID uuid.UUID `json:"id"` // JTI - JWT ID for token identification/revocation + UserID uuid.UUID `json:"user_id"` // Subject user ID + Email string `json:"email"` // User email + Username string `json:"username"` // Username + Version int `json:"version"` // Token version for invalidation + IssuedAt time.Time `json:"issued_at"` // Token issue time + ExpiredAt time.Time `json:"expired_at"` // Token expiration time +} + +// NewPayload creates a new Payload. +func NewPayload( + params TokenParams, +) *Payload { + p := &Payload{ + ID: uuid.New(), + UserID: params.UserID, + Email: params.Email, + Username: params.Username, + Version: params.Version, + IssuedAt: time.Now(), + ExpiredAt: time.Now().Add(params.Duration), + } + + return p +} + +// Valid checks if the token is valid or not. +func (p *Payload) Valid() error { + if time.Now().After(p.ExpiredAt) { + return ErrExpiredToken + } + + return nil +} \ No newline at end of file diff --git a/internal/token/token_test.go b/internal/token/token_test.go new file mode 100644 index 0000000..75714f7 --- /dev/null +++ b/internal/token/token_test.go @@ -0,0 +1,273 @@ +package token + +import ( + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewPayload(t *testing.T) { + tests := []struct { + name string + params TokenParams + }{ + { + name: "creates payload with valid parameters", + params: TokenParams{ + UserID: uuid.New(), + Email: "test@example.com", + Username: "testuser", + Version: 1, + Duration: time.Hour, + }, + }, + { + name: "creates payload with different duration", + params: TokenParams{ + UserID: uuid.New(), + Email: "admin@example.com", + Username: "admin", + Version: 2, + Duration: 24 * time.Hour, + }, + }, + { + name: "creates payload with zero version", + params: TokenParams{ + UserID: uuid.New(), + Email: "user@example.com", + Username: "user", + Version: 0, + Duration: 30 * time.Minute, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + beforeTime := time.Now() + payload := NewPayload(tt.params) + afterTime := time.Now() + + // Assert basic fields + assert.NotEqual(t, uuid.Nil, payload.ID, "payload ID should not be nil") + assert.Equal(t, tt.params.UserID, payload.UserID, "user ID should match") + assert.Equal(t, tt.params.Email, payload.Email, "email should match") + assert.Equal(t, tt.params.Username, payload.Username, "username should match") + assert.Equal(t, tt.params.Version, payload.Version, "version should match") + + // Assert time fields are within reasonable bounds + assert.True(t, payload.IssuedAt.After(beforeTime.Add(-time.Second)), "issued at should be after before time") + assert.True(t, payload.IssuedAt.Before(afterTime.Add(time.Second)), "issued at should be before after time") + + // Assert expiration is correct (within reasonable bounds due to time precision) + actualDuration := payload.ExpiredAt.Sub(payload.IssuedAt) + assert.InDelta(t, tt.params.Duration.Nanoseconds(), actualDuration.Nanoseconds(), float64(time.Millisecond), "expired at should be issued at + duration") + }) + } +} + +func TestPayload_Valid(t *testing.T) { + tests := []struct { + name string + setupPayload func() *Payload + wantErr error + }{ + { + name: "valid payload - expires in future", + setupPayload: func() *Payload { + return &Payload{ + ID: uuid.New(), + UserID: uuid.New(), + Email: "test@example.com", + Username: "testuser", + Version: 1, + IssuedAt: time.Now(), + ExpiredAt: time.Now().Add(time.Hour), + } + }, + wantErr: nil, + }, + { + name: "valid payload - expires in 1 second", + setupPayload: func() *Payload { + return &Payload{ + ID: uuid.New(), + UserID: uuid.New(), + Email: "test@example.com", + Username: "testuser", + Version: 1, + IssuedAt: time.Now(), + ExpiredAt: time.Now().Add(time.Second), + } + }, + wantErr: nil, + }, + { + name: "invalid payload - expired 1 hour ago", + setupPayload: func() *Payload { + return &Payload{ + ID: uuid.New(), + UserID: uuid.New(), + Email: "test@example.com", + Username: "testuser", + Version: 1, + IssuedAt: time.Now().Add(-2 * time.Hour), + ExpiredAt: time.Now().Add(-time.Hour), + } + }, + wantErr: ErrExpiredToken, + }, + { + name: "invalid payload - expired 1 second ago", + setupPayload: func() *Payload { + return &Payload{ + ID: uuid.New(), + UserID: uuid.New(), + Email: "test@example.com", + Username: "testuser", + Version: 1, + IssuedAt: time.Now().Add(-time.Minute), + ExpiredAt: time.Now().Add(-time.Second), + } + }, + wantErr: ErrExpiredToken, + }, + { + name: "invalid payload - expires exactly now (edge case)", + setupPayload: func() *Payload { + now := time.Now() + return &Payload{ + ID: uuid.New(), + UserID: uuid.New(), + Email: "test@example.com", + Username: "testuser", + Version: 1, + IssuedAt: now.Add(-time.Hour), + ExpiredAt: now, + } + }, + // This might be valid or invalid depending on timing + // The test allows for both scenarios + wantErr: nil, // or ErrExpiredToken depending on exact timing + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + payload := tt.setupPayload() + err := payload.Valid() + + if tt.wantErr != nil { + assert.ErrorIs(t, err, tt.wantErr, "error should match expected error") + } else { + // For edge cases near expiration, accept either result + if tt.name == "invalid payload - expires exactly now (edge case)" { + if err != nil { + assert.ErrorIs(t, err, ErrExpiredToken) + } + } else { + assert.NoError(t, err, "should not return error") + } + } + }) + } +} + +func TestPayload_UniqueIDs(t *testing.T) { + // Create multiple payloads and ensure IDs are unique + params := TokenParams{ + UserID: uuid.New(), + Email: "test@example.com", + Username: "testuser", + Version: 1, + Duration: time.Hour, + } + + ids := make(map[uuid.UUID]bool) + numPayloads := 100 + + for i := 0; i < numPayloads; i++ { + payload := NewPayload(params) + require.NotEqual(t, uuid.Nil, payload.ID, "payload ID should not be nil") + require.False(t, ids[payload.ID], "payload ID should be unique") + ids[payload.ID] = true + } + + assert.Equal(t, numPayloads, len(ids), "should have unique IDs for all payloads") +} + +func TestPayload_DurationCalculation(t *testing.T) { + tests := []struct { + name string + duration time.Duration + }{ + {"1 minute", time.Minute}, + {"15 minutes", 15 * time.Minute}, + {"1 hour", time.Hour}, + {"24 hours", 24 * time.Hour}, + {"7 days", 7 * 24 * time.Hour}, + {"30 days", 30 * 24 * time.Hour}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + params := TokenParams{ + UserID: uuid.New(), + Email: "test@example.com", + Username: "testuser", + Version: 1, + Duration: tt.duration, + } + + payload := NewPayload(params) + + actualDuration := payload.ExpiredAt.Sub(payload.IssuedAt) + // Allow for small timing differences (within 1 millisecond) + assert.InDelta(t, tt.duration.Nanoseconds(), actualDuration.Nanoseconds(), float64(time.Millisecond), "duration should be calculated correctly") + + // Verify token is valid immediately after creation + assert.NoError(t, payload.Valid(), "newly created payload should be valid") + }) + } +} + +func TestPayload_ZeroDuration(t *testing.T) { + params := TokenParams{ + UserID: uuid.New(), + Email: "test@example.com", + Username: "testuser", + Version: 1, + Duration: 0, // Zero duration + } + + payload := NewPayload(params) + + // With zero duration, the token expires immediately + // It may or may not be valid depending on exact timing + err := payload.Valid() + if err != nil { + assert.ErrorIs(t, err, ErrExpiredToken, "zero duration token should be expired or about to expire") + } +} + +func TestPayload_NegativeDuration(t *testing.T) { + params := TokenParams{ + UserID: uuid.New(), + Email: "test@example.com", + Username: "testuser", + Version: 1, + Duration: -time.Hour, // Negative duration + } + + payload := NewPayload(params) + + // With negative duration, the token is already expired + err := payload.Valid() + assert.ErrorIs(t, err, ErrExpiredToken, "negative duration token should be expired") + assert.True(t, payload.ExpiredAt.Before(payload.IssuedAt), "expired at should be before issued at") +} + diff --git a/internal/web/request.go b/internal/web/request.go index 67fdafa..07d3453 100644 --- a/internal/web/request.go +++ b/internal/web/request.go @@ -74,3 +74,29 @@ func GetQueryInt(r *http.Request, key string, defaultValue int) int { } return num } + +// Helper: Extract client IP address +func GetClientIP(r *http.Request) string { + // Check X-Forwarded-For header first (proxy/load balancer) + if xff := r.Header.Get("X-Forwarded-For"); xff != "" { + return xff + } + // Check X-Real-IP header + if xri := r.Header.Get("X-Real-IP"); xri != "" { + return xri + } + // Fallback to RemoteAddr + return r.RemoteAddr +} + +func GetUserAgent(r *http.Request) string { + return r.UserAgent() +} + +func GetDeviceName(r *http.Request) string { + return r.Header.Get("X-Device-Name") +} + +func GetLocation(r *http.Request) string { + return r.Header.Get("X-Location") +} \ No newline at end of file diff --git a/pkg/database/database_test.go b/pkg/database/database_test.go index d463622..7d8c1e1 100644 --- a/pkg/database/database_test.go +++ b/pkg/database/database_test.go @@ -224,3 +224,6 @@ func BenchmarkBuildDSN(b *testing.B) { _ = BuildDSN(config) } } + + +