Skip to content

Commit

Permalink
Merge pull request #197 from croessner/features
Browse files Browse the repository at this point in the history
Fix: Refactor `AuthState` variable naming for clarity.
  • Loading branch information
croessner authored Jan 2, 2025
2 parents 7b2a34d + 850fe38 commit 4baccf4
Showing 1 changed file with 46 additions and 46 deletions.
92 changes: 46 additions & 46 deletions server/core/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -920,7 +920,7 @@ func (a *AuthState) GetUniqueUserID() string {
return ""
}

if webAuthnUserID, okay := a.Attributes[*a.UniqueUserIDField]; okay {
if webAuthnUserID, okay := a.Attributes[a.GetUniqueUserIDField()]; okay {
if value, assertOk := webAuthnUserID[definitions.LDAPSingleValue].(string); assertOk {
return value
}
Expand All @@ -943,7 +943,7 @@ func (a *AuthState) GetDisplayName() string {
return ""
}

if account, okay := a.Attributes[*a.DisplayNameField]; okay {
if account, okay := a.Attributes[a.GetDisplayNameField()]; okay {
if value, assertOk := account[definitions.SliceWithOneElement].(string); assertOk {
return value
}
Expand Down Expand Up @@ -986,12 +986,12 @@ func (a *AuthState) AuthOK(ctx *gin.Context) {
// It sets the "Auth-Status" header to "OK" and the "X-Nauthilus-Session" header to the GUID of the AuthState.
// If the AuthState's Service is not definitions.ServBasic, and the HaveAccountField flag is true,
// it retrieves the account from the AuthState and sets the "Auth-User" header
func setCommonHeaders(ctx *gin.Context, a *AuthState) {
func setCommonHeaders(ctx *gin.Context, auth *AuthState) {
ctx.Header("Auth-Status", "OK")
ctx.Header("X-Nauthilus-Session", *a.GUID)
ctx.Header("X-Nauthilus-Session", *auth.GUID)

if a.Service != definitions.ServBasic {
if account, found := a.GetAccountOk(); found {
if auth.Service != definitions.ServBasic {
if account, found := auth.GetAccountOk(); found {
ctx.Header("Auth-User", account)
}
}
Expand All @@ -1012,18 +1012,18 @@ func setCommonHeaders(ctx *gin.Context, a *AuthState) {
// If the Protocol is definitions.ProtoSMTP, it sets the "Auth-Server" header to the SMTPBackendAddress and the "Auth-Port" header to the SMTPBackendPort.
// If the Protocol is definitions.ProtoIMAP, it sets the "Auth-Server" header to the IMAPBackendAddress and the "Auth-Port" header to the IMAPBackendPort.
// If the Protocol is definitions.ProtoPOP3, it sets the "Auth-Server" header to the POP3BackendAddress and the "Auth-Port" header to the POP3BackendPort.
func setNginxHeaders(ctx *gin.Context, a *AuthState) {
func setNginxHeaders(ctx *gin.Context, auth *AuthState) {
if config.LoadableConfig.HasFeature(definitions.FeatureBackendServersMonitoring) {
if BackendServers.GetTotalServers() == 0 {
ctx.Header("Auth-Status", "Internal failure")
} else {
if a.UsedBackendIP != "" && a.UsedBackendPort > 0 {
ctx.Header("Auth-Server", a.UsedBackendIP)
ctx.Header("Auth-Port", fmt.Sprintf("%d", a.UsedBackendPort))
if auth.UsedBackendIP != "" && auth.UsedBackendPort > 0 {
ctx.Header("Auth-Server", auth.UsedBackendIP)
ctx.Header("Auth-Port", fmt.Sprintf("%d", auth.UsedBackendPort))
}
}
} else {
switch a.Protocol.Get() {
switch auth.Protocol.Get() {
case definitions.ProtoSMTP:
ctx.Header("Auth-Server", config.EnvConfig.SMTPBackendAddress)
ctx.Header("Auth-Port", fmt.Sprintf("%d", config.EnvConfig.SMTPBackendPort))
Expand Down Expand Up @@ -1057,9 +1057,9 @@ func setNginxHeaders(ctx *gin.Context, a *AuthState) {
// Resulting headers in ctx:
// - X-Nauthilus-Attribute1: "Value1"
// - X-Nauthilus-Attribute2: "Value2_1,Value2_2"
func setHeaderHeaders(ctx *gin.Context, a *AuthState) {
if a.Attributes != nil && len(a.Attributes) > 0 {
for name, value := range a.Attributes {
func setHeaderHeaders(ctx *gin.Context, auth *AuthState) {
if auth.Attributes != nil && len(auth.Attributes) > 0 {
for name, value := range auth.Attributes {
handleAttributeValue(ctx, name, value)
}
}
Expand Down Expand Up @@ -1115,22 +1115,22 @@ func formatValues(values []any) []string {
}

// sendAuthResponse sends a JSON response with the appropriate headers and content based on the AuthState.
func sendAuthResponse(ctx *gin.Context, a *AuthState) {
ctx.JSON(a.StatusCodeOK, &backend.PositivePasswordCache{
AccountField: a.AccountField,
TOTPSecretField: a.TOTPSecretField,
Backend: a.SourcePassDBBackend,
Attributes: a.Attributes,
func sendAuthResponse(ctx *gin.Context, auth *AuthState) {
ctx.JSON(auth.StatusCodeOK, &backend.PositivePasswordCache{
AccountField: auth.AccountField,
TOTPSecretField: auth.TOTPSecretField,
Backend: auth.SourcePassDBBackend,
Attributes: auth.Attributes,
})
}

// handleLogging logs information about the authentication request if the verbosity level is greater than LogLevelWarn.
// It uses the log.Logger to log the information.
// The logged information includes the result of the a.LogLineTemplate() function, which returns either "ok" or an empty string depending on the value of a.NoAuth,
// and the path of the request URL obtained from ctx.Request.URL.Path.
func handleLogging(ctx *gin.Context, a *AuthState) {
level.Info(log.Logger).Log(a.LogLineTemplate(func() string {
if !a.NoAuth {
func handleLogging(ctx *gin.Context, auth *AuthState) {
level.Info(log.Logger).Log(auth.LogLineTemplate(func() string {
if !auth.NoAuth {
return "ok"
}

Expand Down Expand Up @@ -1366,10 +1366,10 @@ func (a *AuthState) verifyPassword(passDBs []*PassDBMap) (*PassDBResult, error)
// logDebugModule(a, passDB, passDBResult)
//
// This function uses the util.DebugModule function from the package to log the debug information.
func logDebugModule(a *AuthState, passDB *PassDBMap, passDBResult *PassDBResult) {
func logDebugModule(auth *AuthState, passDB *PassDBMap, passDBResult *PassDBResult) {
util.DebugModule(
definitions.DbgAuth,
definitions.LogKeyGUID, a.GUID,
definitions.LogKeyGUID, auth.GUID,
"passdb", passDB.backend.String(),
"result", fmt.Sprintf("%v", passDBResult))
}
Expand All @@ -1379,23 +1379,23 @@ func logDebugModule(a *AuthState, passDB *PassDBMap, passDBResult *PassDBResult)
// If all password databases have been processed and there are configuration errors, it calls the checkAllBackends function.
// If the error is not a configuration error, it logs the error using the Logger.
// It returns the error unchanged.
func handleBackendErrors(passDBIndex int, passDBs []*PassDBMap, passDB *PassDBMap, err error, a *AuthState, configErrors map[definitions.Backend]error) error {
func handleBackendErrors(passDBIndex int, passDBs []*PassDBMap, passDB *PassDBMap, err error, auth *AuthState, configErrors map[definitions.Backend]error) error {
if stderrors.Is(err, errors.ErrLDAPConfig) || stderrors.Is(err, errors.ErrLuaConfig) {
configErrors[passDB.backend] = err

// After all password databases were running, check if SQL, LDAP and Lua backends have configuration errors.
if passDBIndex == len(passDBs)-1 {
err = checkAllBackends(configErrors, a)
err = checkAllBackends(configErrors, auth)
}
} else {
level.Error(log.Logger).Log(definitions.LogKeyGUID, a.GUID, "passdb", passDB.backend.String(), definitions.LogKeyMsg, err)
level.Error(log.Logger).Log(definitions.LogKeyGUID, auth.GUID, "passdb", passDB.backend.String(), definitions.LogKeyMsg, err)
}

return err
}

// After all password databases were running, check if SQL, LDAP and Lua backends have configuration errors.
func checkAllBackends(configErrors map[definitions.Backend]error, a *AuthState) (err error) {
func checkAllBackends(configErrors map[definitions.Backend]error, auth *AuthState) (err error) {
var allConfigErrors = true

for _, err = range configErrors {
Expand All @@ -1409,7 +1409,7 @@ func checkAllBackends(configErrors map[definitions.Backend]error, a *AuthState)
// If all (real) Database backends failed, we must return with a temporary failure
if allConfigErrors {
err = errors.ErrAllBackendConfigError
level.Error(log.Logger).Log(definitions.LogKeyGUID, a.GUID, "passdb", "all", definitions.LogKeyMsg, err)
level.Error(log.Logger).Log(definitions.LogKeyGUID, auth.GUID, "passdb", "all", definitions.LogKeyMsg, err)
}

return err
Expand All @@ -1422,20 +1422,20 @@ func checkAllBackends(configErrors map[definitions.Backend]error, a *AuthState)
// Next, it calls the updateAuthentication function to update the fields of a based on the values in passDBResult.
// If the UserFound field of passDBResult is true, it sets the UserFound field of a to true.
// Finally, it returns the updated passDBResult and nil error.
func processPassDBResult(passDBResult *PassDBResult, a *AuthState, passDB *PassDBMap) error {
func processPassDBResult(passDBResult *PassDBResult, auth *AuthState, passDB *PassDBMap) error {
if passDBResult == nil {
return errors.ErrNoPassDBResult
}

util.DebugModule(
definitions.DbgAuth,
definitions.LogKeyGUID, a.GUID,
definitions.LogKeyGUID, auth.GUID,
"passdb", passDB.backend.String(),
definitions.LogKeyUsername, a.Username,
definitions.LogKeyUsername, auth.Username,
"passdb_result", fmt.Sprintf("%+v", *passDBResult),
)

updateAuthentication(a, passDBResult, passDB)
updateAuthentication(auth, passDBResult, passDB)

return nil
}
Expand All @@ -1444,32 +1444,32 @@ func processPassDBResult(passDBResult *PassDBResult, a *AuthState, passDB *PassD
// It checks if each field in passDBResult is not nil and if it is not nil, it updates the corresponding field in the AuthState struct.
// It also updates the SourcePassDBBackend and UsedPassDBBackend fields of the AuthState struct with the values from passDBResult.Backend and passDB.backend respectively.
// It returns the updated PassDBResult struct.
func updateAuthentication(a *AuthState, passDBResult *PassDBResult, passDB *PassDBMap) {
func updateAuthentication(auth *AuthState, passDBResult *PassDBResult, passDB *PassDBMap) {
if passDBResult.UserFound {
a.UserFound = true
auth.UserFound = true

a.SourcePassDBBackend = passDBResult.Backend
a.UsedPassDBBackend = passDB.backend
auth.SourcePassDBBackend = passDBResult.Backend
auth.UsedPassDBBackend = passDB.backend
}

if passDBResult.AccountField != nil {
a.AccountField = passDBResult.AccountField
auth.AccountField = passDBResult.AccountField
}

if passDBResult.TOTPSecretField != nil {
a.TOTPSecretField = passDBResult.TOTPSecretField
auth.TOTPSecretField = passDBResult.TOTPSecretField
}

if passDBResult.UniqueUserIDField != nil {
a.UniqueUserIDField = passDBResult.UniqueUserIDField
auth.UniqueUserIDField = passDBResult.UniqueUserIDField
}

if passDBResult.DisplayNameField != nil {
a.DisplayNameField = passDBResult.DisplayNameField
auth.DisplayNameField = passDBResult.DisplayNameField
}

if passDBResult.Attributes != nil && len(passDBResult.Attributes) > 0 {
a.Attributes = passDBResult.Attributes
auth.Attributes = passDBResult.Attributes
}
}

Expand Down Expand Up @@ -2659,10 +2659,10 @@ func (a *AuthState) processClaim(claimName string, claimValue string, claims map
}

// Custom logic to apply string claims
func applyClaim(claimKey string, attributeKey string, a *AuthState, claims map[string]any, claimHandlers []ClaimHandler) {
func applyClaim(claimKey string, attributeKey string, auth *AuthState, claims map[string]any, claimHandlers []ClaimHandler) {
var success bool

if attributeValue, found := a.Attributes[attributeKey]; found {
if attributeValue, found := auth.Attributes[attributeKey]; found {
for _, handler := range claimHandlers {
if t := reflect.TypeOf(attributeValue).Kind(); t == handler.Type {
success = handler.ApplyFunc(attributeValue, claims, claimKey)
Expand All @@ -2675,7 +2675,7 @@ func applyClaim(claimKey string, attributeKey string, a *AuthState, claims map[s

if !success {
level.Warn(log.Logger).Log(
definitions.LogKeyGUID, a.GUID,
definitions.LogKeyGUID, auth.GUID,
definitions.LogKeyMsg, fmt.Sprintf("Claim '%s' malformed or not returned from Database", claimKey),
)
}
Expand Down

0 comments on commit 4baccf4

Please sign in to comment.