Skip to content

Commit

Permalink
Merge pull request #66 from zitadel/redirect-loop
Browse files Browse the repository at this point in the history
fix: handle redirect loop
  • Loading branch information
livio-a authored Dec 7, 2023
2 parents 45a75c4 + 0afe35b commit bcc610f
Show file tree
Hide file tree
Showing 9 changed files with 54 additions and 64 deletions.
4 changes: 2 additions & 2 deletions pkg/provider/logout.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func (p *IdentityProvider) logoutHandleFunc(w http.ResponseWriter, r *http.Reque
return nil
},
func() {
http.Error(w, fmt.Errorf("failed to parse form: %w", err).Error(), http.StatusInternalServerError)
response.sendBackLogoutResponse(w, response.makeDeniedLogoutResponse(fmt.Errorf("failed to parse form: %w", err).Error(), p.timeFormat))
},
)

Expand All @@ -60,7 +60,7 @@ func (p *IdentityProvider) logoutHandleFunc(w http.ResponseWriter, r *http.Reque
return nil
},
func() {
response.sendBackLogoutResponse(w, response.makeUnsupportedlLogoutResponse(fmt.Errorf("failed to decode request: %w", err).Error(), p.timeFormat))
response.sendBackLogoutResponse(w, response.makeDeniedLogoutResponse(fmt.Errorf("failed to decode request: %w", err).Error(), p.timeFormat))
},
)

Expand Down
28 changes: 8 additions & 20 deletions pkg/provider/logout_response.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
package provider

import (
"bufio"
"bytes"
"encoding/base64"
"encoding/xml"
"html/template"
"net/http"
"time"

"github.com/zitadel/saml/pkg/provider/xml"
"github.com/zitadel/saml/pkg/provider/xml/saml"
"github.com/zitadel/saml/pkg/provider/xml/samlp"
)
Expand All @@ -31,33 +29,23 @@ type LogoutResponseForm struct {
}

func (r *LogoutResponse) sendBackLogoutResponse(w http.ResponseWriter, resp *samlp.LogoutResponseType) {
var xmlbuff bytes.Buffer

memWriter := bufio.NewWriter(&xmlbuff)
_, err := memWriter.Write([]byte(xml.Header))
respData, err := xml.Marshal(resp)
if err != nil {
r.ErrorFunc(err)
return
}

encoder := xml.NewEncoder(memWriter)
err = encoder.Encode(resp)
if err != nil {
r.ErrorFunc(err)
if r.LogoutURL == "" {
if err := xml.Write(w, respData); err != nil {
r.ErrorFunc(err)
return
}
return
}

err = memWriter.Flush()
if err != nil {
r.ErrorFunc(err)
return
}

samlMessage := base64.StdEncoding.EncodeToString(xmlbuff.Bytes())

data := LogoutResponseForm{
RelayState: r.RelayState,
SAMLResponse: samlMessage,
SAMLResponse: base64.StdEncoding.EncodeToString(respData),
LogoutURL: r.LogoutURL,
}

Expand Down
4 changes: 2 additions & 2 deletions pkg/provider/redirect.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,12 @@ func createRedirectSignature(
idp *IdentityProvider,
response *Response,
) error {
respStr, err := xml.Marshal(samlResponse)
resp, err := xml.Marshal(samlResponse)
if err != nil {
return err
}

respData, err := xml.DeflateAndBase64([]byte(respStr))
respData, err := xml.DeflateAndBase64(resp)
if err != nil {
return err
}
Expand Down
49 changes: 25 additions & 24 deletions pkg/provider/response.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,37 @@ type Response struct {
}

func (r *Response) doResponse(request *http.Request, w http.ResponseWriter, response string) {

}

type AuthResponseForm struct {
RelayState string
SAMLResponse string
AssertionConsumerServiceURL string
}

func (r *Response) sendBackResponse(
req *http.Request,
w http.ResponseWriter,
resp *samlp.ResponseType,
) {
respData, err := xml.Marshal(resp)
if err != nil {
r.ErrorFunc(err)
return
}

if r.AcsUrl == "" {
if err := xml.Write(w, []byte(response)); err != nil {
if err := xml.Write(w, respData); err != nil {
r.ErrorFunc(err)
return
}
return
}

switch r.ProtocolBinding {
case PostBinding:
respData := base64.StdEncoding.EncodeToString([]byte(response))
respData := base64.StdEncoding.EncodeToString(respData)

data := AuthResponseForm{
r.RelayState,
Expand All @@ -63,39 +84,19 @@ func (r *Response) doResponse(request *http.Request, w http.ResponseWriter, resp
return
}
case RedirectBinding:
respData, err := xml.DeflateAndBase64([]byte(response))
respData, err := xml.DeflateAndBase64(respData)
if err != nil {
r.ErrorFunc(err)
return
}

http.Redirect(w, request, fmt.Sprintf("%s?%s", r.AcsUrl, buildRedirectQuery(string(respData), r.RelayState, r.SigAlg, r.Signature)), http.StatusFound)
http.Redirect(w, req, fmt.Sprintf("%s?%s", r.AcsUrl, buildRedirectQuery(string(respData), r.RelayState, r.SigAlg, r.Signature)), http.StatusFound)
return
default:
//TODO: no binding
}
}

type AuthResponseForm struct {
RelayState string
SAMLResponse string
AssertionConsumerServiceURL string
}

func (r *Response) sendBackResponse(
req *http.Request,
w http.ResponseWriter,
resp *samlp.ResponseType,
) {
respStr, err := xml.Marshal(resp)
if err != nil {
r.ErrorFunc(err)
return
}

r.doResponse(req, w, respStr)
}

func (r *Response) makeUnsupportedBindingResponse(
message string,
timeFormat string,
Expand Down
4 changes: 2 additions & 2 deletions pkg/provider/signature/signature_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,7 @@ func TestSignature_CreatePost(t *testing.T) {
}
resp.Signature = sig

respStr, err := saml_xml.Marshal(resp)
respData, err := saml_xml.Marshal(resp)
if err != nil {
if (err != nil) != tt.res.err {
t.Errorf("Create() marshall response for signing")
Expand All @@ -521,7 +521,7 @@ func TestSignature_CreatePost(t *testing.T) {
}

doc := etree.NewDocument()
if err := doc.ReadFromBytes([]byte(respStr)); err != nil {
if err := doc.ReadFromBytes(respData); err != nil {
if (err != nil) != tt.res.err {
t.Errorf("Cert() failed to read response")
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/provider/sso.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func (p *IdentityProvider) ssoHandleFunc(w http.ResponseWriter, r *http.Request)
return nil
},
func() {
http.Error(w, fmt.Errorf("failed to parse form: %w", err).Error(), http.StatusInternalServerError)
response.sendBackResponse(r, w, response.makeDeniedResponse(fmt.Errorf("failed to parse form").Error(), p.timeFormat))
},
)

Expand Down
4 changes: 2 additions & 2 deletions pkg/provider/sso_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -559,8 +559,8 @@ func TestSSO_ssoHandleFunc(t *testing.T) {
},
},
res{
code: 500,
state: "",
code: 200,
state: StatusCodeRequestDenied,
err: false,
}},
{
Expand Down
10 changes: 5 additions & 5 deletions pkg/provider/xml/xml.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,27 +19,27 @@ const (
EncodingDeflate = "urn:oasis:names:tc:SAML:2.0:bindings:URL-Encoding:DEFLATE"
)

func Marshal(data interface{}) (string, error) {
func Marshal(data interface{}) ([]byte, error) {
var xmlbuff bytes.Buffer

memWriter := bufio.NewWriter(&xmlbuff)
_, err := memWriter.Write([]byte(xml.Header))
if err != nil {
return "", err
return nil, err
}

encoder := xml.NewEncoder(memWriter)
err = encoder.Encode(data)
if err != nil {
return "", err
return nil, err
}

err = memWriter.Flush()
if err != nil {
return "", err
return nil, err
}

return xmlbuff.String(), nil
return xmlbuff.Bytes(), nil
}

func DeflateAndBase64(data []byte) ([]byte, error) {
Expand Down
13 changes: 7 additions & 6 deletions pkg/provider/xml/xml_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package xml_test

import (
"reflect"
"testing"

"github.com/zitadel/saml/pkg/provider/xml"
Expand All @@ -12,7 +13,7 @@ type XML struct {

func Test_XmlMarshal(t *testing.T) {
type res struct {
metadata string
metadata []byte
err bool
}

Expand All @@ -25,23 +26,23 @@ func Test_XmlMarshal(t *testing.T) {
name: "xml struct",
arg: "<test></test>",
res: res{
metadata: "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<XML><test></test></XML>",
metadata: []byte("<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<XML><test></test></XML>"),
err: false,
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
xmlStruct := XML{InnerXml: tt.arg}

xmlStr, err := xml.Marshal(xmlStruct)
xmlData, err := xml.Marshal(xmlStruct)
if (err != nil) != tt.res.err {
t.Errorf("Marshal() error: %v", err)
return
}
if xmlStr != tt.res.metadata {
t.Errorf("Marshal() error expected: %v, got %v", tt.res.metadata, xmlStr)

if !reflect.DeepEqual(tt.res.metadata, xmlData) {
t.Errorf("Marshal() error expected: %v, got %v", tt.res.metadata, xmlData)
return
}
})
Expand Down

0 comments on commit bcc610f

Please sign in to comment.