-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
18d333e
commit 4b7e05d
Showing
5 changed files
with
220 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
package auth | ||
|
||
import "errors" | ||
|
||
var ( | ||
UserIdNotStored = errors.New("user id was not stored in user session") | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
package auth | ||
|
||
import ( | ||
"encoding/gob" | ||
"github.com/google/uuid" | ||
) | ||
|
||
func init() { | ||
gob.Register(uuid.UUID{}) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
package auth | ||
|
||
import ( | ||
"github.com/gin-gonic/gin" | ||
"github.com/google/uuid" | ||
log "github.com/sirupsen/logrus" | ||
) | ||
|
||
// GetUserMiddleware returns gin handler function that reads user data from session store and saves it as context item. | ||
// Note that this middleware does not prevent not logged in users from using the app. | ||
func GetUserMiddleware() gin.HandlerFunc { | ||
return func(c *gin.Context) { | ||
userId, err := RetrieveUserFromSession(c) | ||
if err != nil { | ||
log.Debugf("Failed to retrieve user from session. Possibly just not logged in: %v", err) | ||
return | ||
} | ||
|
||
c.Set("userId", userId) | ||
} | ||
} | ||
|
||
// IsLoggedIn returns true if the user is logged in to the application. | ||
// Use GetUserId to retrieve the ID of the currently logged in user. | ||
func IsLoggedIn(c *gin.Context) bool { | ||
_, exists := c.Get("userId") | ||
return exists | ||
} | ||
|
||
// GetUserId returns the id of the user currently logged in. | ||
// It will return uuid.Nil if no user is logged in. | ||
// | ||
// Use IsLoggedIn first to check if the user is logged in. | ||
func GetUserId(c *gin.Context) uuid.UUID { | ||
value, exists := c.Get("userId") | ||
if !exists { | ||
return uuid.Nil | ||
} | ||
return value.(uuid.UUID) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
package auth | ||
|
||
import ( | ||
"fmt" | ||
"github.com/KowalskiPiotr98/ludivault/utils" | ||
"github.com/markbates/goth" | ||
"github.com/markbates/goth/providers/gitea" | ||
log "github.com/sirupsen/logrus" | ||
"strings" | ||
) | ||
|
||
var ( | ||
providers []string | ||
|
||
setups = map[string]func(callbackUrl string) (bool, goth.Provider){ | ||
"gitea": func(callbackUrl string) (bool, goth.Provider) { | ||
clientId := utils.GetOptionalConfig("SSO_GITEA_CLIENT_ID", "") | ||
clientSecret := utils.GetOptionalConfig("SSO_GITEA_CLIENT_SECRET", "") | ||
url := utils.GetOptionalConfig("SSO_GITEA_URL", "") | ||
|
||
if clientId == "" || clientSecret == "" || url == "" { | ||
return false, nil | ||
} | ||
url = strings.TrimRight(url, "/") | ||
|
||
return true, gitea.NewCustomisedURL(clientId, clientSecret, callbackUrl, fmt.Sprintf("%s/login/oauth/authorize", url), fmt.Sprintf("%s/login/oauth/access_token", url), fmt.Sprintf("%s/api/v1/user", url)) | ||
}, | ||
} | ||
) | ||
|
||
func SetupProviders(baseUrl string) { | ||
if areProvidersSet() { | ||
// prevent duplicate provider setup | ||
return | ||
} | ||
|
||
baseUrl = strings.TrimRight(baseUrl, "/") | ||
callbackUrl := fmt.Sprintf("%s/api/v1/auth/callback?provider=%%s", baseUrl) | ||
log.Debugf("Setting provider callback url to: %s", callbackUrl) | ||
|
||
enabledProviders := make([]goth.Provider, 0) | ||
|
||
for providerName, setup := range setups { | ||
ok, provider := setup(fmt.Sprintf(callbackUrl, providerName)) | ||
if ok { | ||
providers = append(providers, providerName) | ||
enabledProviders = append(enabledProviders, provider) | ||
log.Debugf("Registered login provider %s", providerName) | ||
} | ||
} | ||
|
||
goth.UseProviders(enabledProviders...) | ||
} | ||
|
||
func GetEnabledProviders() []string { | ||
return providers | ||
} | ||
|
||
func areProvidersSet() bool { | ||
return len(providers) > 0 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
package auth | ||
|
||
import ( | ||
"github.com/KowalskiPiotr98/ludivault/utils" | ||
"github.com/gin-gonic/gin" | ||
"github.com/google/uuid" | ||
"github.com/gorilla/sessions" | ||
"github.com/markbates/goth/gothic" | ||
log "github.com/sirupsen/logrus" | ||
"net/http" | ||
) | ||
|
||
//todo: storing sessions in cookie only is bad as it prevents any form of server-side tracking of sessions | ||
// but it'll have to do for now | ||
|
||
var ( | ||
UserSessionName = "ludivault-session" | ||
SetupSession = func(baseDomain string) sessions.Store { | ||
key := utils.GetRequiredConfig("SESSION_KEY") | ||
maxAge := 86400 * 30 | ||
store := sessions.NewCookieStore([]byte(key)) | ||
store.MaxAge(maxAge) | ||
store.Options.Path = "/" | ||
store.Options.HttpOnly = true | ||
store.Options.Secure = gin.Mode() != gin.DebugMode | ||
store.Options.Domain = baseDomain | ||
store.Options.SameSite = http.SameSiteLaxMode | ||
return store | ||
} | ||
|
||
authStore sessions.Store | ||
) | ||
|
||
// InitSessionStore initialises the store for user sessions. | ||
// | ||
// This function must be called before any other function from this package. | ||
func InitSessionStore(baseDomain string) { | ||
authStore = SetupSession(baseDomain) | ||
gothic.Store = authStore | ||
} | ||
|
||
// StoreUserInSession adds the user data to session store. | ||
func StoreUserInSession(c *gin.Context, userId uuid.UUID) error { | ||
ensureSessionStoreInit() | ||
|
||
session, err := authStore.Get(c.Request, UserSessionName) | ||
if err != nil { | ||
log.Warnf("Failed to get session: %v", err) | ||
return err | ||
} | ||
|
||
session.Values["userId"] = userId | ||
|
||
if err = session.Save(c.Request, c.Writer); err != nil { | ||
log.Warnf("Failed to save session: %v", err) | ||
return err | ||
} | ||
|
||
return nil | ||
} | ||
|
||
// RetrieveUserFromSession attempts to get user data from session store. | ||
func RetrieveUserFromSession(c *gin.Context) (uuid.UUID, error) { | ||
ensureSessionStoreInit() | ||
|
||
session, err := authStore.Get(c.Request, UserSessionName) | ||
if err != nil { | ||
log.Warnf("Failed to get session: %v", err) | ||
return uuid.Nil, err | ||
} | ||
|
||
id, ok := session.Values["userId"].(uuid.UUID) | ||
if !ok { | ||
return uuid.Nil, UserIdNotStored | ||
} | ||
return id, nil | ||
} | ||
|
||
func RemoveUserSession(c *gin.Context) error { | ||
ensureSessionStoreInit() | ||
|
||
session, err := authStore.Get(c.Request, UserSessionName) | ||
if err != nil { | ||
log.Warnf("Failed to get session: %v", err) | ||
return err | ||
} | ||
|
||
// remove what's left of the session | ||
session.Options.MaxAge = -1 | ||
if err := session.Save(c.Request, c.Writer); err != nil { | ||
log.Warnf("Failed to save session: %v", err) | ||
return err | ||
} | ||
|
||
return nil | ||
} | ||
|
||
func ensureSessionStoreInit() { | ||
if authStore == nil { | ||
log.Panic("Session store is not initialized") | ||
} | ||
} |