Skip to content
52 changes: 52 additions & 0 deletions src/main/java/org/runimo/runimo/auth/domain/RefreshToken.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package org.runimo.runimo.auth.domain;

import jakarta.persistence.Column;
import jakarta.persistence.Entity;
import jakarta.persistence.GeneratedValue;
import jakarta.persistence.GenerationType;
import jakarta.persistence.Id;
import jakarta.persistence.Table;
import lombok.AccessLevel;
import lombok.Builder;
import lombok.Getter;
import lombok.NoArgsConstructor;
import org.runimo.runimo.auth.exceptions.UserJwtException;
import org.runimo.runimo.common.CreateUpdateAuditEntity;
import org.runimo.runimo.user.enums.UserHttpResponseCode;
import org.springframework.context.annotation.Profile;

@Profile({"prod", "dev"})
@Table(name = "user_refresh_token")
@Entity
@Getter
@NoArgsConstructor(access = AccessLevel.PROTECTED)
public class RefreshToken extends CreateUpdateAuditEntity {

@Id
@GeneratedValue(strategy = GenerationType.IDENTITY)
private Long id;
@Column(name = "user_id", nullable = false, unique = true)
private Long userId;
@Column(name = "refresh_token", nullable = false)
private String refreshToken;

@Builder
private RefreshToken(Long userId, String refreshToken) {
this.userId = userId;
this.refreshToken = refreshToken;
}

public static RefreshToken of(Long userId, String refreshToken) {
return RefreshToken.builder()
.userId(userId)
.refreshToken(refreshToken)
.build();
}

public void update(String refreshToken) {
if (refreshToken == null || refreshToken.isEmpty()) {
throw UserJwtException.of(UserHttpResponseCode.TOKEN_REFRESH_FAIL);
}
this.refreshToken = refreshToken;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package org.runimo.runimo.auth.repository;

import java.time.LocalDateTime;
import java.time.temporal.ChronoUnit;
import java.util.Optional;
import lombok.RequiredArgsConstructor;
import org.runimo.runimo.auth.domain.RefreshToken;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Profile;
import org.springframework.stereotype.Repository;

@Profile({"prod", "dev"})
@Repository
@RequiredArgsConstructor
public class DatabaseTokenRepository implements JwtTokenRepository {

private final RefreshTokenJpaRepository refreshTokenJpaRepository;
@Value("${jwt.refresh.expiration}")
private Long refreshTokenExpiryMillis;

@Override
public Optional<String> findRefreshTokenByUserId(final Long userId) {
return refreshTokenJpaRepository.findByUserId(userId)
.map(RefreshToken::getRefreshToken);
}

@Override
public void saveRefreshTokenWithUserId(final Long userId, final String refreshToken) {
LocalDateTime REPLACE_CUTOFF_TIME = LocalDateTime.now()
.minus(refreshTokenExpiryMillis, ChronoUnit.MILLIS);

RefreshToken updatedRefreshToken = refreshTokenJpaRepository.findByUserIdAfterCutoffTime(userId,
REPLACE_CUTOFF_TIME)
.map(existingToken -> {
existingToken.update(refreshToken);
return existingToken;
})
.orElseGet(() ->
RefreshToken.of(userId, refreshToken)
);

refreshTokenJpaRepository.save(updatedRefreshToken);
}


}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package org.runimo.runimo.auth.repository;

import java.time.Duration;
import java.util.Optional;
import lombok.RequiredArgsConstructor;
import org.runimo.runimo.common.cache.InMemoryCache;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Profile;
import org.springframework.stereotype.Repository;

@Profile({"test", "local"})
@Repository
@RequiredArgsConstructor
public class InMemoryTokenRepository implements JwtTokenRepository {

private final InMemoryCache<Long, String> refreshTokenCache;
@Value("${jwt.refresh.expiration}")
private Long refreshTokenExpiry;


@Override
public Optional<String> findRefreshTokenByUserId(Long userId) {
return refreshTokenCache.get(userId);
}

@Override
public void saveRefreshTokenWithUserId(Long userId, String refreshToken) {
refreshTokenCache.put(userId, refreshToken, Duration.ofMillis(refreshTokenExpiry));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package org.runimo.runimo.auth.repository;

import java.util.Optional;

public interface JwtTokenRepository {

Optional<String> findRefreshTokenByUserId(Long userId);

void saveRefreshTokenWithUserId(Long userId, String refreshToken);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package org.runimo.runimo.auth.repository;

import java.time.LocalDateTime;
import java.util.Optional;
import org.runimo.runimo.auth.domain.RefreshToken;
import org.springframework.context.annotation.Profile;
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.data.jpa.repository.Query;

@Profile({"prod", "dev"})
public interface RefreshTokenJpaRepository extends JpaRepository<RefreshToken, String> {

Optional<RefreshToken> findByUserId(Long id);

@Query("select distinct r from RefreshToken r " +
"where r.userId = :userId and r.updatedAt > :cutOffDateTime")
Optional<RefreshToken> findByUserIdAfterCutoffTime(Long userId, LocalDateTime cutOffDateTime);
}
Original file line number Diff line number Diff line change
@@ -1,49 +1,50 @@
package org.runimo.runimo.auth.service;

import java.time.Duration;
import lombok.RequiredArgsConstructor;
import org.runimo.runimo.auth.exceptions.UserJwtException;
import org.runimo.runimo.auth.jwt.JwtResolver;
import org.runimo.runimo.auth.jwt.JwtTokenFactory;
import org.runimo.runimo.auth.repository.JwtTokenRepository;
import org.runimo.runimo.auth.service.dto.TokenPair;
import org.runimo.runimo.common.cache.InMemoryCache;
import org.runimo.runimo.user.domain.User;
import org.runimo.runimo.user.enums.UserHttpResponseCode;
import org.springframework.beans.factory.annotation.Value;
import org.runimo.runimo.user.service.UserFinder;
import org.springframework.stereotype.Service;

@Service
@RequiredArgsConstructor
public class TokenRefreshService {

private final JwtResolver jwtResolver;
private final InMemoryCache<String, String> refreshTokenCache;
private final JwtTokenFactory jwtTokenFactory;
@Value("${jwt.refresh.expiration}")
private Long refreshTokenExpiry;

public void putRefreshToken(String userId, String refreshToken) {
refreshTokenCache.put(userId, refreshToken, Duration.ofMillis(refreshTokenExpiry));
private final JwtResolver jwtResolver;
private final JwtTokenRepository jwtTokenRepository;
private final JwtTokenFactory jwtTokenFactory;
private final UserFinder userFinder;

public void putRefreshToken(String userPublicId, String refreshToken) {
User user = userFinder.findUserByPublicId(userPublicId)
.orElseThrow(() -> UserJwtException.of(UserHttpResponseCode.TOKEN_REFRESH_FAIL));
jwtTokenRepository.saveRefreshTokenWithUserId(user.getId(), refreshToken);
}

public TokenPair refreshAccessToken(String refreshToken) {
String userPublicId;
try {
jwtResolver.verifyJwtToken(refreshToken);
userPublicId = jwtResolver.getUserIdFromJwtToken(refreshToken);
} catch (Exception e) {
throw UserJwtException.of(UserHttpResponseCode.TOKEN_REFRESH_FAIL);
}

public TokenPair refreshAccessToken(String refreshToken) {
String userId;
try {
jwtResolver.verifyJwtToken(refreshToken);
userId = jwtResolver.getUserIdFromJwtToken(refreshToken);
} catch (Exception e) {
throw UserJwtException.of(UserHttpResponseCode.TOKEN_REFRESH_FAIL);
}

String storedToken = refreshTokenCache.get(userId).orElse(null);
if (storedToken == null || !storedToken.equals(refreshToken)) {
throw new IllegalArgumentException("Refresh token mismatch");
}

String newAccessToken = jwtTokenFactory.generateAccessToken(userId);
String newRefreshToken = jwtTokenFactory.generateRefreshToken(userId);

// 갱신한 리프레시 토큰 저장 (기존 토큰 갱신)
refreshTokenCache.put(userId, newRefreshToken, Duration.ofMillis(refreshTokenExpiry));
return new TokenPair(newAccessToken, newRefreshToken);
User user = userFinder.findUserByPublicId(userPublicId)
.orElseThrow(() -> UserJwtException.of(UserHttpResponseCode.TOKEN_REFRESH_FAIL));

String storedToken = jwtTokenRepository.findRefreshTokenByUserId(user.getId())
.orElseThrow(() -> UserJwtException.of(UserHttpResponseCode.REFRESH_EXPIRED));
if (!storedToken.equals(refreshToken)) {
throw UserJwtException.of(UserHttpResponseCode.TOKEN_INVALID);
}

String newAccessToken = jwtTokenFactory.generateAccessToken(userPublicId);
return new TokenPair(newAccessToken, refreshToken);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package org.runimo.runimo.common;

import jakarta.persistence.MappedSuperclass;
import java.time.LocalDateTime;
import lombok.Getter;
import lombok.NoArgsConstructor;
import org.hibernate.annotations.CreationTimestamp;
import org.hibernate.annotations.UpdateTimestamp;

@Getter
@NoArgsConstructor
@MappedSuperclass
public abstract class CreateUpdateAuditEntity {

@CreationTimestamp
protected LocalDateTime createdAt;

@UpdateTimestamp
protected LocalDateTime updatedAt;

}
60 changes: 30 additions & 30 deletions src/main/java/org/runimo/runimo/config/CacheConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,38 +13,38 @@
@Configuration
public class CacheConfig {

@Value("${cache.cleanup.interval:300}")
private int cleanupIntervalSeconds;
@Value("${cache.cleanup.interval:300}")
private int cleanupIntervalSeconds;

@Value("${cache.cleanup.thread-pool-size:1}")
private int cleanupThreadPoolSize;
@Value("${cache.cleanup.thread-pool-size:1}")
private int cleanupThreadPoolSize;

@Bean
public TaskScheduler cacheCleanupScheduler() {
ThreadPoolTaskScheduler scheduler = new ThreadPoolTaskScheduler();
scheduler.setPoolSize(cleanupThreadPoolSize);
scheduler.setThreadNamePrefix("cache-cleanup-");
scheduler.setDaemon(true);
scheduler.setWaitForTasksToCompleteOnShutdown(true);
scheduler.setAwaitTerminationSeconds(10);
return scheduler;
}
@Bean
public TaskScheduler cacheCleanupScheduler() {
ThreadPoolTaskScheduler scheduler = new ThreadPoolTaskScheduler();
scheduler.setPoolSize(cleanupThreadPoolSize);
scheduler.setThreadNamePrefix("cache-cleanup-");
scheduler.setDaemon(true);
scheduler.setWaitForTasksToCompleteOnShutdown(true);
scheduler.setAwaitTerminationSeconds(10);
return scheduler;
}

@Bean
public InMemoryCache<String, String> refreshTokenCache(TaskScheduler cacheCleanupScheduler) {
return new SpringInMemoryCache<>(
cacheCleanupScheduler,
Duration.ofSeconds(cleanupIntervalSeconds)
);
}
@Bean
public InMemoryCache<Long, String> refreshTokenCache(TaskScheduler cacheCleanupScheduler) {
return new SpringInMemoryCache<>(
cacheCleanupScheduler,
Duration.ofSeconds(cleanupIntervalSeconds)
);
}

@Bean
public InMemoryCache<String, UserPrincipal> userPrincipalCache(
TaskScheduler cacheCleanupScheduler
) {
return new SpringInMemoryCache<>(
cacheCleanupScheduler,
Duration.ofSeconds(cleanupIntervalSeconds)
);
}
@Bean
public InMemoryCache<String, UserPrincipal> userPrincipalCache(
TaskScheduler cacheCleanupScheduler
) {
return new SpringInMemoryCache<>(
cacheCleanupScheduler,
Duration.ofSeconds(cleanupIntervalSeconds)
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ public enum UserHttpResponseCode implements CustomResponseCode {
SIGNIN_FAIL_ALREADY_EXIST(HttpStatus.CONFLICT, "로그인 실패 - 이미 존재하는 사용자", "로그인 실패 - 이미 존재하는 사용자"),
JWT_TOKEN_BROKEN(HttpStatus.BAD_REQUEST, "JWT 토큰이 손상되었습니다", "JWT 토큰이 손상되었습니다"),
TOKEN_REFRESH_FAIL(HttpStatus.FORBIDDEN, "토큰 재발급 실패", "Refresh 토큰이 유효하지 않습니다."),
TOKEN_INVALID(HttpStatus.UNAUTHORIZED, "인증 실패", "JWT 토큰 인증 실패");
TOKEN_INVALID(HttpStatus.UNAUTHORIZED, "인증 실패", "JWT 토큰 인증 실패"),
REFRESH_EXPIRED(HttpStatus.FORBIDDEN, "리프레시 토큰 만료", "리프레시 토큰 만료");

private final HttpStatus code;
private final String clientMessage;
Expand Down
10 changes: 10 additions & 0 deletions src/main/resources/sql/schema.sql
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ DROP TABLE IF EXISTS runimo;
DROP TABLE IF EXISTS item_activity;
DROP TABLE IF EXISTS runimo_definition;
DROP TABLE IF EXISTS item;
DROP TABLE IF EXISTS user_refresh_token;
DROP TABLE IF EXISTS users;
DROP TABLE IF EXISTS user_love_point;
DROP TABLE IF EXISTS incubating_egg;
Expand Down Expand Up @@ -197,6 +198,15 @@ CREATE TABLE `runimo`
`deleted_at` TIMESTAMP NULL
);

CREATE TABLE `user_refresh_token`
(
`id` BIGINT PRIMARY KEY AUTO_INCREMENT,
`user_id` BIGINT NOT NULL UNIQUE ,
`refresh_token` TEXT NOT NULL,
`created_at` TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
`updated_at` TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
);

ALTER TABLE `user_token`
ADD FOREIGN KEY (`user_id`) REFERENCES `users` (`id`);

Expand Down
Loading
Loading