diff --git a/pkg/portal/middleware/aad.go b/pkg/portal/middleware/aad.go index a8fecc1df84..294f7645f5e 100644 --- a/pkg/portal/middleware/aad.go +++ b/pkg/portal/middleware/aad.go @@ -29,10 +29,11 @@ import ( const ( SessionName = "session" // Expiration time in unix format - SessionKeyExpires = "expires" - sessionKeyState = "state" - SessionKeyUsername = "user_name" - SessionKeyGroups = "groups" + SessionKeyExpires = "expires" + sessionKeyState = "state" + sessionKeyRedirectUri = "redirect_uri" + SessionKeyUsername = "user_name" + SessionKeyGroups = "groups" ) // AAD is responsible for ensuring that we have a valid login session with AAD. @@ -175,7 +176,11 @@ func (a *aad) CheckAuthentication(h http.Handler) http.Handler { ctx := r.Context() if ctx.Value(ContextKeyUsername) == nil { if r.URL != nil { - http.Redirect(w, r, "/api/login", http.StatusTemporaryRedirect) + redirect := "/api/login" + if r.URL.Path != "" { + redirect += "?" + sessionKeyRedirectUri + "=" + r.URL.Path + } + http.Redirect(w, r, redirect, http.StatusTemporaryRedirect) return } http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden) @@ -223,6 +228,10 @@ func (a *aad) redirect(w http.ResponseWriter, r *http.Request) { sessionKeyState: state, } + if r.URL.Query().Has(sessionKeyRedirectUri) { + session.Values[sessionKeyRedirectUri] = r.URL.Query().Get(sessionKeyRedirectUri) + } + err = session.Save(r, w) if err != nil { a.internalServerError(w, err) @@ -308,13 +317,19 @@ func (a *aad) callback(w http.ResponseWriter, r *http.Request) { session.Values[SessionKeyGroups] = groupsIntersect session.Values[SessionKeyExpires] = a.now().Add(a.sessionTimeout).Unix() + redirectUri := "/" + if v, ok := session.Values[sessionKeyRedirectUri]; ok { + redirectUri = v.(string) + delete(session.Values, sessionKeyRedirectUri) + } + err = session.Save(r, w) if err != nil { a.internalServerError(w, err) return } - http.Redirect(w, r, "/", http.StatusTemporaryRedirect) + http.Redirect(w, r, redirectUri, http.StatusTemporaryRedirect) } // clientAssertion adds a JWT client assertion according to diff --git a/pkg/portal/middleware/aad_test.go b/pkg/portal/middleware/aad_test.go index 1aaa3fa8213..daf91436949 100644 --- a/pkg/portal/middleware/aad_test.go +++ b/pkg/portal/middleware/aad_test.go @@ -201,6 +201,7 @@ func TestCheckAuthentication(t *testing.T) { name string request func(*aad) (*http.Request, error) wantStatusCode int + wantRedirectTo string wantAuthenticated bool }{ { @@ -220,6 +221,7 @@ func TestCheckAuthentication(t *testing.T) { return http.NewRequestWithContext(ctx, http.MethodGet, "/api/info", nil) }, wantStatusCode: http.StatusTemporaryRedirect, + wantRedirectTo: "/api/login?redirect_uri=/api/info", }, { name: "not authenticated", @@ -228,6 +230,7 @@ func TestCheckAuthentication(t *testing.T) { return http.NewRequestWithContext(ctx, http.MethodGet, "/callback", nil) }, wantStatusCode: http.StatusTemporaryRedirect, + wantRedirectTo: "/api/login?redirect_uri=/callback", }, { name: "invalid cookie", @@ -275,6 +278,14 @@ func TestCheckAuthentication(t *testing.T) { t.Error(w.Code, tt.wantStatusCode) } + if tt.wantRedirectTo != "" { + redirectLocation := w.Result().Header["Location"] + + if redirectLocation == nil || len(redirectLocation) != 1 || redirectLocation[0] != tt.wantRedirectTo { + t.Error(redirectLocation, tt.wantRedirectTo) + } + } + if authenticated != tt.wantAuthenticated { t.Fatal(authenticated) } diff --git a/portal/v2/src/Request.tsx b/portal/v2/src/Request.tsx index 7ea07fbf4f7..c123104203f 100644 --- a/portal/v2/src/Request.tsx +++ b/portal/v2/src/Request.tsx @@ -4,7 +4,11 @@ import { convertTimeToHours } from "./ClusterDetailListComponents/Statistics/Gra const OnError = (err: AxiosResponse): AxiosResponse | null => { if (err.status === 403) { - document.location.href = "/api/login" + var href = "/api/login" + if (document.location.pathname !== "/") { + href += "?redirect_uri=" + document.location.pathname + } + document.location.href = href return null } else { return err