Skip to content

Commit

Permalink
🎨 Improv localhost address validation
Browse files Browse the repository at this point in the history
  • Loading branch information
Zuoqiu-Yingyi committed Nov 11, 2023
1 parent 4d218cd commit 640149d
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 44 deletions.
72 changes: 28 additions & 44 deletions kernel/model/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ func LoginAuth(c *gin.Context) {

if err := session.Save(c); nil != err {
logging.LogErrorf("save session failed: " + err.Error())
c.Status(500)
c.Status(http.StatusInternalServerError)
return
}
return
Expand All @@ -109,7 +109,7 @@ func LoginAuth(c *gin.Context) {
workspaceSession.Captcha = gulu.Rand.String(7)
if err := session.Save(c); nil != err {
logging.LogErrorf("save session failed: " + err.Error())
c.Status(500)
c.Status(http.StatusInternalServerError)
return
}
}
Expand All @@ -123,7 +123,7 @@ func GetCaptcha(c *gin.Context) {
})
if nil != err {
logging.LogErrorf("generates captcha failed: " + err.Error())
c.Status(500)
c.Status(http.StatusInternalServerError)
return
}

Expand All @@ -132,16 +132,16 @@ func GetCaptcha(c *gin.Context) {
workspaceSession.Captcha = img.Text
if err = session.Save(c); nil != err {
logging.LogErrorf("save session failed: " + err.Error())
c.Status(500)
c.Status(http.StatusInternalServerError)
return
}

if err = img.WriteImage(c.Writer); nil != err {
logging.LogErrorf("writes captcha image failed: " + err.Error())
c.Status(500)
c.Status(http.StatusInternalServerError)
return
}
c.Status(200)
c.Status(http.StatusOK)
}

func CheckReadonly(c *gin.Context) {
Expand All @@ -150,46 +150,29 @@ func CheckReadonly(c *gin.Context) {
result.Code = -1
result.Msg = Conf.Language(34)
result.Data = map[string]interface{}{"closeTimeout": 5000}
c.JSON(200, result)
c.JSON(http.StatusOK, result)
c.Abort()
return
}
}

func CheckAuth(c *gin.Context) {
//logging.LogInfof("check auth for [%s]", c.Request.RequestURI)
localhost := util.IsLocalHost(c.Request.RemoteAddr)

// 未设置访问授权码
if "" == Conf.AccessAuthCode {
if origin := c.GetHeader("Origin"); "" != origin {
// Authenticate requests with the Origin header other than 127.0.0.1 https://github.com/siyuan-note/siyuan/issues/9180
u, parseErr := url.Parse(origin)
if nil != parseErr {
logging.LogWarnf("parse origin [%s] failed: %s", origin, parseErr)
c.JSON(401, map[string]interface{}{"code": -1, "msg": "Auth failed: parse req header [Origin] failed"})
c.Abort()
return

}

if "chrome-extension" == strings.ToLower(u.Scheme) {
c.Next()
return
}

if !strings.HasPrefix(u.Host, util.LocalHost) && !strings.HasPrefix(u.Host, "[::1]") {
c.JSON(401, map[string]interface{}{"code": -1, "msg": "Auth failed: for security reasons, please set [Access authorization code] when using non-127.0.0.1 access\n\n为安全起见,使用非 127.0.0.1 访问时请设置 [访问授权码]"})
c.Abort()
return
}
}

if !strings.HasPrefix(c.Request.RemoteAddr, util.LocalHost) && !strings.HasPrefix(c.Request.RemoteAddr, "[::1]") {
// Authenticate requests of assets other than 127.0.0.1 https://github.com/siyuan-note/siyuan/issues/9388
if strings.HasPrefix(c.Request.RequestURI, "/assets/") {
c.JSON(401, map[string]interface{}{"code": -1, "msg": "Auth failed: for security reasons, please set [Access authorization code] when using non-127.0.0.1 access\n\n为安全起见,使用非 127.0.0.1 访问时请设置 [访问授权码]"})
c.Abort()
return
}
// Authenticate requests with the Origin header other than 127.0.0.1 https://github.com/siyuan-note/siyuan/issues/9180
host := c.GetHeader("Host")
origin := c.GetHeader("Origin")
forwardedHost := c.GetHeader("X-Forwarded-Host")
if !localhost ||
("" != host && !util.IsLocalHost(host)) ||
("" != origin && !util.IsLocalOrigin(origin)) ||
("" != forwardedHost && !util.IsLocalHost(forwardedHost)) {
c.JSON(http.StatusUnauthorized, map[string]interface{}{"code": -1, "msg": "Auth failed: for security reasons, please set [Access authorization code] when using non-127.0.0.1 access\n\n为安全起见,使用非 127.0.0.1 访问时请设置 [访问授权码]"})
c.Abort()
return
}

c.Next()
Expand All @@ -206,7 +189,7 @@ func CheckAuth(c *gin.Context) {
}

// 放过来自本机的某些请求
if strings.HasPrefix(c.Request.RemoteAddr, util.LocalHost) || strings.HasPrefix(c.Request.RemoteAddr, "[::1]") {
if localhost {
if strings.HasPrefix(c.Request.RequestURI, "/assets/") {
c.Next()
return
Expand Down Expand Up @@ -234,7 +217,7 @@ func CheckAuth(c *gin.Context) {
return
}

c.JSON(401, map[string]interface{}{"code": -1, "msg": "Auth failed"})
c.JSON(http.StatusUnauthorized, map[string]interface{}{"code": -1, "msg": "Auth failed"})
c.Abort()
return
}
Expand All @@ -247,7 +230,7 @@ func CheckAuth(c *gin.Context) {
return
}

c.JSON(401, map[string]interface{}{"code": -1, "msg": "Auth failed"})
c.JSON(http.StatusUnauthorized, map[string]interface{}{"code": -1, "msg": "Auth failed"})
c.Abort()
return
}
Expand All @@ -261,7 +244,7 @@ func CheckAuth(c *gin.Context) {
userAgentHeader := c.GetHeader("User-Agent")
if strings.HasPrefix(userAgentHeader, "SiYuan/") || strings.HasPrefix(userAgentHeader, "Mozilla/") {
if "GET" != c.Request.Method {
c.JSON(401, map[string]interface{}{"code": -1, "msg": Conf.Language(156)})
c.JSON(http.StatusUnauthorized, map[string]interface{}{"code": -1, "msg": Conf.Language(156)})
c.Abort()
return
}
Expand All @@ -271,12 +254,13 @@ func CheckAuth(c *gin.Context) {
queryParams.Set("to", c.Request.URL.String())
location.RawQuery = queryParams.Encode()
location.Path = "/check-auth"
c.Redirect(302, location.String())

c.Redirect(http.StatusFound, location.String())
c.Abort()
return
}

c.JSON(401, map[string]interface{}{"code": -1, "msg": "Auth failed"})
c.JSON(http.StatusUnauthorized, map[string]interface{}{"code": -1, "msg": "Auth failed"})
c.Abort()
return
}
Expand Down Expand Up @@ -316,7 +300,7 @@ func Timing(c *gin.Context) {
func Recover(c *gin.Context) {
defer func() {
logging.Recover()
c.Status(500)
c.Status(http.StatusInternalServerError)
}()

c.Next()
Expand Down
53 changes: 53 additions & 0 deletions kernel/util/net.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package util

import (
"net"
"net/http"
"net/url"
"strings"
Expand All @@ -31,6 +32,58 @@ import (
"github.com/siyuan-note/logging"
)

func ValidOptionalPort(port string) bool {
if port == "" {
return true
}
if port[0] != ':' {
return false
}
for _, b := range port[1:] {
if b < '0' || b > '9' {
return false
}
}
return true
}

func SplitHost(host string) (hostname, port string) {
hostname = host

colon := strings.LastIndexByte(hostname, ':')
if colon != -1 && ValidOptionalPort(hostname[colon:]) {
hostname, port = hostname[:colon], hostname[colon+1:]
}

if strings.HasPrefix(hostname, "[") && strings.HasSuffix(hostname, "]") {
hostname = hostname[1 : len(hostname)-1]
}

return
}

func IsLocalHostname(hostname string) bool {
if "localhost" == hostname {
return true
}
if ip := net.ParseIP(hostname); nil != ip {
return ip.IsLoopback()
}
return false
}

func IsLocalHost(host string) bool {
hostname, _ := SplitHost(host)
return IsLocalHostname(hostname)
}

func IsLocalOrigin(origin string) bool {
if url, err := url.Parse(origin); nil == err {
return IsLocalHostname(url.Hostname())
}
return false
}

func IsOnline(checkURL string, skipTlsVerify bool) bool {
_, err := url.Parse(checkURL)
if nil != err {
Expand Down

0 comments on commit 640149d

Please sign in to comment.