From 4dd0b5c026b68ba0850495d454a22d04074d178e Mon Sep 17 00:00:00 2001 From: Mike Gouline <1960272+gouline@users.noreply.github.com> Date: Sun, 8 Sep 2024 22:48:14 +1000 Subject: [PATCH] Refactor Slack auth into separate package --- internal/pkg/format/format.go | 27 --- internal/pkg/format/format_test.go | 36 ---- internal/pkg/server/api.go | 207 ++++++-------------- internal/pkg/server/api_test.go | 147 ++++++++++++++ internal/pkg/server/auth.go | 94 --------- internal/pkg/server/pages.go | 37 +--- internal/pkg/server/server.go | 32 +-- internal/pkg/slack/context.go | 167 ++++++++++++++++ internal/pkg/slack/context_test.go | 37 ++++ internal/pkg/slack/handlers.go | 47 +++++ internal/pkg/slack/slack.go | 87 ++++++++ internal/pkg/slack/slack_test.go | 45 +++++ internal/pkg/templates/examples/README.md | 3 + internal/pkg/templates/examples/about.html | 3 + internal/pkg/templates/examples/home.html | 3 + internal/pkg/templates/examples/layout.html | 12 ++ internal/pkg/templates/templates.go | 14 +- internal/pkg/templates/templates_test.go | 49 +++++ main.go | 18 +- templates/index.html | 2 +- templates/layout.html | 8 +- 21 files changed, 706 insertions(+), 369 deletions(-) delete mode 100644 internal/pkg/format/format.go delete mode 100644 internal/pkg/format/format_test.go create mode 100644 internal/pkg/server/api_test.go delete mode 100644 internal/pkg/server/auth.go create mode 100644 internal/pkg/slack/context.go create mode 100644 internal/pkg/slack/context_test.go create mode 100644 internal/pkg/slack/handlers.go create mode 100644 internal/pkg/slack/slack.go create mode 100644 internal/pkg/slack/slack_test.go create mode 100644 internal/pkg/templates/examples/README.md create mode 100644 internal/pkg/templates/examples/about.html create mode 100644 internal/pkg/templates/examples/home.html create mode 100644 internal/pkg/templates/examples/layout.html create mode 100644 internal/pkg/templates/templates_test.go diff --git a/internal/pkg/format/format.go b/internal/pkg/format/format.go deleted file mode 100644 index 1bc03ca..0000000 --- a/internal/pkg/format/format.go +++ /dev/null @@ -1,27 +0,0 @@ -package format - -import ( - "crypto/sha1" - "fmt" - "regexp" - - "github.com/labstack/echo/v4" -) - -// NewAllSymbolsRegexp returns a compiled regular expression with -// all symbols on the keyboard available for filtering. -func NewAllSymbolsRegexp() *regexp.Regexp { - return regexp.MustCompile("[!-/:-@[-`{-~]+") -} - -// RelativeURI returns a URI relative to request host. -func RelativeURI(c echo.Context, path string) string { - return fmt.Sprintf("%s://%s%s", c.Scheme(), c.Request().Host, path) -} - -// HashToken hashes raw auth token with SHA-1. -func HashToken(token string) string { - h := sha1.New() - h.Write([]byte(token)) - return fmt.Sprintf("%x", h.Sum(nil)) -} diff --git a/internal/pkg/format/format_test.go b/internal/pkg/format/format_test.go deleted file mode 100644 index b42db7d..0000000 --- a/internal/pkg/format/format_test.go +++ /dev/null @@ -1,36 +0,0 @@ -package format - -import ( - "net/http" - "net/http/httptest" - "testing" - - "github.com/labstack/echo/v4" -) - -func TestNewAllSymbolsRegexp(t *testing.T) { - reg := NewAllSymbolsRegexp() - - if reg.ReplaceAllString("test s!t@r#i$n^g [1,2)", "") != "test string 12" { - t.Errorf("Unexpected sanitization of \"test string 12\"") - } - - if reg.ReplaceAllString("тестовая с!т@р#о$к^а (9.0]", "") != "тестовая строка 90" { - t.Errorf("Unexpected sanitization of \"тестовая строка 90\"") - } - - if reg.ReplaceAllString("测!试@字#符$串%5^6", "") != "测试字符串56" { - t.Errorf("Unexpected sanitization of \"测试字符串56\"") - } -} - -func TestRelativeURI(t *testing.T) { - e := echo.New() - req := httptest.NewRequest(http.MethodPost, "https://test.example.com/", nil) - req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - if actual := RelativeURI(c, "/context/path"); actual != "https://test.example.com/context/path" { - t.Errorf("Unexpected relative URI: %s", actual) - } -} diff --git a/internal/pkg/server/api.go b/internal/pkg/server/api.go index a33ae04..856d747 100644 --- a/internal/pkg/server/api.go +++ b/internal/pkg/server/api.go @@ -1,192 +1,102 @@ package server import ( - "fmt" "net/http" + "regexp" "strings" - "time" - "github.com/gouline/blaster/internal/pkg/format" - "github.com/gouline/blaster/internal/pkg/scache" + "github.com/gouline/blaster/internal/pkg/slack" "github.com/labstack/echo/v4" - "github.com/slack-go/slack" ) -var suggestCache = scache.New(5*time.Minute, 10*time.Minute) - // APISuggest handles /api/suggest. func (s *Server) handleAPISuggest(c echo.Context) error { - token := s.authorizedToken(c) - if token == "" { + slackCtx := s.slack.Context(c) + if !slackCtx.Authorized { return c.String(http.StatusUnauthorized, "no token") } - cacheResponse := <-buildSuggestCache(token) - if cacheResponse.Error != nil { - return c.String(http.StatusInternalServerError, cacheResponse.Error.Error()) + destinations, err := slackCtx.Destinations() + if err != nil { + return c.String(http.StatusInternalServerError, err.Error()) } - allSuggestions := cacheResponse.Value.([]suggestion) - - // Filter out all symbols from term - symbolReg := format.NewAllSymbolsRegexp() - term := strings.ToLower(" " + c.QueryParam("term")) - term = symbolReg.ReplaceAllString(term, "") - - // Filter users by term - suggestions := []suggestion{} - for _, suggestion := range allSuggestions { - if strings.Contains(suggestion.Search, term) { - suggestions = append(suggestions, suggestion) - if len(suggestions) == 10 { - break - } - } - } + suggestions := suggestDestinations(c.QueryParam("term"), destinations) return c.JSON(http.StatusOK, suggestions) } // handleAPISend handles /api/send. func (s *Server) handleAPISend(c echo.Context) error { - token := s.authorizedToken(c) - if token == "" { - return c.String(http.StatusUnauthorized, "no token") + slackCtx := s.slack.Context(c) + if !slackCtx.Authorized { + return c.NoContent(http.StatusUnauthorized) } - client := slack.New(token) - - // Bind JSON request var request sendRequest - err := c.Bind(&request) - if err != nil { + if err := c.Bind(&request); err != nil { return c.String(http.StatusBadRequest, err.Error()) } - // Open/get channel by user ID - channel, _, _, err := client.OpenConversation(&slack.OpenConversationParameters{ - Users: []string{ - request.User, - }, - }) - if err != nil { - return c.String(http.StatusInternalServerError, err.Error()) - } - - // Post message to opened channel - _, _, err = client.PostMessage( - channel.ID, - slack.MsgOptionText(request.Message, false), - slack.MsgOptionAsUser(request.AsUser), - ) - if err != nil { + if err := slackCtx.SendMessage(request.User, request.Message, request.AsUser); err != nil { return c.String(http.StatusInternalServerError, err.Error()) } return c.JSON(http.StatusOK, struct{}{}) } -func buildSuggestCache(token string) <-chan scache.Response { - return suggestCache.ResponseChan(format.HashToken(token), func(key string) (interface{}, error) { - client := slack.New(token) - - symbolReg := format.NewAllSymbolsRegexp() - - var suggestions []suggestion - - userLookup := map[string]suggestion{} - - // Get all users - users, err := client.GetUsers() - if err != nil { - return nil, err - } - - suggestions = []suggestion{} - - for _, user := range users { - if user.Deleted || user.IsBot { - continue - } - - realName := user.Profile.RealName - displayName := user.Profile.DisplayName - - // Format label based on availability - label := realName - if displayName != "" { - label += " (" + displayName + ")" +// suggestDestinations filters destionations into suggestions by search term. +func suggestDestinations(term string, destinations []*slack.Destination) []*suggestion { + term = " " + sanitizeSearchTerm(term) + + suggestions := []*suggestion{} + for _, dest := range destinations { + searchable := " " + sanitizeSearchTerm(strings.ToLower(dest.Name)+" "+strings.ToLower(dest.DisplayName)) + + if strings.Contains(searchable, term) { + children := []*suggestion{} + for _, child := range dest.Children { + children = append(children, &suggestion{ + Type: sanitizeCSV(child.Type), + Label: sanitizeCSV(suggestionLabel(child.Name, child.DisplayName)), + Value: sanitizeCSV(child.ID), + }) } - // Filter out all symbols from search string - search := fmt.Sprintf(" %s %s", strings.ToLower(realName), strings.ToLower(displayName)) - search = symbolReg.ReplaceAllString(search, "") - - // Sanitize labels and values - sanitize := func(s string) string { - return strings.Replace(s, ",", "", -1) - } + suggestions = append(suggestions, &suggestion{ + Type: sanitizeCSV(dest.Type), + Label: sanitizeCSV(suggestionLabel(dest.Name, dest.DisplayName)), + Value: sanitizeCSV(dest.ID), + Children: children, + }) - s := suggestion{ - Type: "user", - Label: sanitize(label), - Value: sanitize(user.ID), - Search: search, + if len(suggestions) == 10 { + break } - - suggestions = append(suggestions, s) - - userLookup[user.ID] = s - } - - usergroups, err := client.GetUserGroups(slack.GetUserGroupsOptionIncludeUsers(true)) - if err != nil { - return nil, err } + } - for _, usergroup := range usergroups { - if !usergroup.IsUserGroup { - continue - } - - children := []suggestion{} - - for _, userID := range usergroup.Users { - user, found := userLookup[userID] - if !found { - continue - } - - children = append(children, user) - } - - name := usergroup.Name - handle := usergroup.Handle - label := name + " (" + handle + ")" - - // Filter out all symbols from search string - search := fmt.Sprintf(" %s %s", strings.ToLower(name), strings.ToLower(handle)) - search = symbolReg.ReplaceAllString(search, "") + return suggestions +} - suggestions = append(suggestions, suggestion{ - Type: "usergroup", - Label: label, - Value: "null", - Search: search, - Children: children, - }) - } +// sanitizeSearchTerm removes lowercases search term and removes skippable characters. +func sanitizeSearchTerm(s string) string { + symbolReg := regexp.MustCompile("[!-/:-@[-`{-~]+") + return symbolReg.ReplaceAllString(strings.ToLower(s), "") +} - return suggestions, nil - }) +// sanitizeCSV removes commas for comma-separated values. +func sanitizeCSV(s string) string { + return strings.Replace(s, ",", "", -1) } -type suggestion struct { - Type string `json:"type"` - Label string `json:"label"` - Value string `json:"value"` - Children []suggestion `json:"children,omitempty"` - Search string `json:"-"` +// suggestionLabel formats label with a mandatory name and an optional display name. +func suggestionLabel(name, displayName string) string { + label := name + if displayName != "" { + label += " (" + displayName + ")" + } + return label } type sendRequest struct { @@ -194,3 +104,10 @@ type sendRequest struct { Message string `json:"message"` AsUser bool `json:"as_user"` } + +type suggestion struct { + Type string `json:"type"` + Label string `json:"label"` + Value string `json:"value"` + Children []*suggestion `json:"children,omitempty"` +} diff --git a/internal/pkg/server/api_test.go b/internal/pkg/server/api_test.go new file mode 100644 index 0000000..94d5ff0 --- /dev/null +++ b/internal/pkg/server/api_test.go @@ -0,0 +1,147 @@ +package server + +import ( + "slices" + "testing" + + "github.com/gouline/blaster/internal/pkg/slack" +) + +func Test_sanitizeSearchTerm(t *testing.T) { + for _, test := range []struct { + s string + expected string + }{ + { + s: "test s!t@r#i$n^g [1,2)", + expected: "test string 12", + }, + { + s: "тестовая с!т@р#о$к^а (9.0]", + expected: "тестовая строка 90", + }, + { + s: "测!试@字#符$串%5^6", + expected: "测试字符串56", + }, + } { + actual := sanitizeSearchTerm(test.s) + if actual != test.expected { + t.Errorf("for %s: got %s, expected %s", test.s, actual, test.expected) + } + } +} + +func Test_sanitizeCSV(t *testing.T) { + for _, test := range []struct { + s string + expected string + }{ + { + s: "", + expected: "", + }, + { + s: ",", + expected: "", + }, + { + s: ",,", + expected: "", + }, + { + s: "something, else", + expected: "something else", + }, + } { + actual := sanitizeCSV(test.s) + if actual != test.expected { + t.Errorf("for %s: got %s, expected %s", test.s, actual, test.expected) + } + } +} + +func Test_suggestionLabel(t *testing.T) { + for _, test := range []struct { + name string + displayName string + expected string + }{ + { + name: "", + displayName: "", + expected: "", + }, + { + name: "", + displayName: "Mike", + expected: " (Mike)", + }, + { + name: "mg", + displayName: "Mike", + expected: "mg (Mike)", + }, + { + name: "mg", + displayName: "", + expected: "mg", + }, + } { + actual := suggestionLabel(test.name, test.displayName) + if actual != test.expected { + t.Errorf("for %s, %s: got %s, expected %s", test.name, test.displayName, actual, test.expected) + } + } +} + +func Test_suggestDestinations(t *testing.T) { + destinations := []*slack.Destination{ + { + Type: "user", + Name: "mg", + DisplayName: "Mike", + ID: "mg", + }, + { + Type: "user", + Name: "mark", + DisplayName: "Mark", + ID: "mark", + }, + { + Type: "user", + Name: "mk", + DisplayName: "Mark Knopfler", + ID: "mk", + }, + } + + for _, test := range []struct { + term string + expectedIDs []string + }{ + { + term: "mi", + expectedIDs: []string{"mg"}, + }, + { + term: "mark", + expectedIDs: []string{"mark", "mk"}, + }, + { + term: "kno", + expectedIDs: []string{"mk"}, + }, + } { + actuals := suggestDestinations(test.term, destinations) + actualValues := []string{} + for _, actual := range actuals { + actualValues = append(actualValues, actual.Value) + } + + if !slices.Equal(actualValues, test.expectedIDs) { + t.Errorf("for %s: got %s, expected %s", test.term, actualValues, test.expectedIDs) + } + } +} diff --git a/internal/pkg/server/auth.go b/internal/pkg/server/auth.go deleted file mode 100644 index 7a6d217..0000000 --- a/internal/pkg/server/auth.go +++ /dev/null @@ -1,94 +0,0 @@ -package server - -import ( - "log" - "net/http" - "net/url" - "os" - "strings" - - "github.com/gouline/blaster/internal/pkg/format" - "github.com/labstack/echo/v4" - "github.com/slack-go/slack" -) - -const slackBaseURL = "https://slack.com" - -var ( - slackClientID = os.Getenv("SLACK_CLIENT_ID") - slackClientSecret = os.Getenv("SLACK_CLIENT_SECRET") - - slackAPIScopes = []string{ - "team:read", - "users:read", - "usergroups:read", - "im:write", - "chat:write:bot", - "chat:write:user", - } -) - -// handleAuthInitiate handles /auth/initiate. -func (s *Server) handleAuthInitiate(c echo.Context) error { - redirectURI, err := authorizeURI(format.RelativeURI(c, "/auth/complete")) - if err != nil { - log.Fatal(err) - } - return c.Redirect(http.StatusFound, redirectURI) -} - -// handleAuthComplete handles /auth/complete. -func (s *Server) handleAuthComplete(c echo.Context) error { - code := c.QueryParam("code") - - if code != "" { - response, err := slack.GetOAuthResponse(http.DefaultClient, slackClientID, slackClientSecret, code, format.RelativeURI(c, "/auth/complete")) - if err != nil { - return c.String(http.StatusUnauthorized, err.Error()) - } - - s.setAuthorizedToken(c, response.AccessToken) - } - - return c.Redirect(http.StatusFound, format.RelativeURI(c, "/")) -} - -// handleAuthLogout handles /auth/logout. -func (s *Server) handleAuthLogout(c echo.Context) error { - s.setAuthorizedToken(c, "") - - return c.Redirect(http.StatusFound, format.RelativeURI(c, "/")) -} - -func authorizeURI(redirectURI string) (string, error) { - redirectURL, err := url.Parse(slackBaseURL + "/oauth/authorize") - if err != nil { - return "", err - } - q := redirectURL.Query() - q.Set("client_id", slackClientID) - q.Set("scope", strings.Join(slackAPIScopes, ",")) - q.Set("redirect_uri", redirectURI) - redirectURL.RawQuery = q.Encode() - - return redirectURL.String(), nil -} - -func (s *Server) setAuthorizedToken(c echo.Context, token string) { - c.SetCookie(&http.Cookie{ - Name: cookiePrefix + "slacktoken", - Value: url.QueryEscape(token), - MaxAge: 86400, - Path: "/", - Secure: !s.config.Debug, - HttpOnly: true, - }) -} - -func (s *Server) authorizedToken(c echo.Context) string { - tokenCookie, err := c.Cookie(cookiePrefix + "slacktoken") - if err != nil { - return "" - } - return tokenCookie.Value -} diff --git a/internal/pkg/server/pages.go b/internal/pkg/server/pages.go index 027a1b4..9c30a6e 100644 --- a/internal/pkg/server/pages.go +++ b/internal/pkg/server/pages.go @@ -2,16 +2,10 @@ package server import ( "net/http" - "time" - "github.com/gouline/blaster/internal/pkg/format" - "github.com/gouline/blaster/internal/pkg/scache" "github.com/labstack/echo/v4" - "github.com/slack-go/slack" ) -var teamCache = scache.New(12*time.Hour, 12*time.Hour) - // handleIndex handles /. func (s *Server) handleIndex(c echo.Context) error { return c.Render(http.StatusOK, "index.html", s.baseData(c, map[string]interface{}{ @@ -28,35 +22,6 @@ func (s *Server) handleNotFound(c echo.Context) error { } func (s *Server) baseData(c echo.Context, data map[string]interface{}) map[string]interface{} { - authorized := false - teamName := "" - - if token := s.authorizedToken(c); token != "" { - authorized = true - - cacheResponse := <-teamCache.ResponseChan(format.HashToken(token), func(key string) (interface{}, error) { - client := slack.New(token) - - teamInfo, err := client.GetTeamInfo() - if err != nil { - return nil, err - } - - return teamInfo.Name, err - }) - if cacheResponse.Error == nil { - teamName = cacheResponse.Value.(string) - } - - // Build other caches - go func() { - <-buildSuggestCache(token) - }() - } - - data["debugging"] = s.config.Debug - data["authorized"] = authorized - data["teamName"] = teamName - + data["slack"] = s.slack.Context(c) return data } diff --git a/internal/pkg/server/server.go b/internal/pkg/server/server.go index b0e5f13..aad224e 100644 --- a/internal/pkg/server/server.go +++ b/internal/pkg/server/server.go @@ -3,57 +3,61 @@ package server import ( "fmt" + "github.com/gouline/blaster/internal/pkg/slack" "github.com/gouline/blaster/internal/pkg/templates" "github.com/labstack/echo/v4" ) const ( - appName = "Blaster" - cookiePrefix = "blaster_" + appName = "Blaster" ) type Config struct { Debug bool - Host string - Port string - + Host string + Port string CertFile string KeyFile string StaticRoot string TemplatesRoot string + + SlackClientID string + SlackClientSecret string } type Server struct { config Config echo *echo.Echo + slack *slack.Slack } -func NewServer(config Config) (*Server, error) { +func New(config Config) (*Server, error) { s := &Server{ config: config, echo: echo.New(), + slack: slack.New(config.SlackClientID, config.SlackClientSecret), } s.echo.Debug = config.Debug + // Slack auth + s.echo.Use(s.slack.Middleware) + s.echo.GET("/login", s.slack.HandleLogin) + s.echo.GET("/logout", s.slack.HandleLogout) + s.echo.Static("/static", config.StaticRoot) var err error - s.echo.Renderer, err = templates.NewTemplates(config.TemplatesRoot, "layout.html") + s.echo.Renderer, err = templates.New(config.TemplatesRoot, "layout.html") if err != nil { return nil, err } - s.echo.RouteNotFound("/*", s.handleNotFound) - + // Pages s.echo.GET("/", s.handleIndex) - - authGroup := s.echo.Group("/auth") - authGroup.GET("/initiate", s.handleAuthInitiate) - authGroup.GET("/complete", s.handleAuthComplete) - authGroup.GET("/logout", s.handleAuthLogout) + s.echo.RouteNotFound("/*", s.handleNotFound) // API apiGroup := s.echo.Group("/api") diff --git a/internal/pkg/slack/context.go b/internal/pkg/slack/context.go new file mode 100644 index 0000000..3656ed2 --- /dev/null +++ b/internal/pkg/slack/context.go @@ -0,0 +1,167 @@ +package slack + +import ( + "crypto/sha1" + "fmt" + "time" + + "github.com/gouline/blaster/internal/pkg/scache" + "github.com/labstack/echo/v4" + "github.com/slack-go/slack" +) + +var ( + teamCache = scache.New(12*time.Hour, 12*time.Hour) + destinationCache = scache.New(5*time.Minute, 10*time.Minute) +) + +// Context contains authenticated session information. +type Context struct { + Authorized bool + TeamName string + token string +} + +// Context retrieves current authentication context. +func (s *Slack) Context(c echo.Context) *Context { + ctx := &Context{} + + if token := s.token(c); token != "" { + ctx.Authorized = true + + cacheResponse := <-teamCache.ResponseChan(hashToken(token), func(key string) (interface{}, error) { + client := slack.New(token) + + teamInfo, err := client.GetTeamInfo() + if err != nil { + return nil, err + } + + return teamInfo.Name, err + }) + if cacheResponse.Error == nil { + ctx.TeamName = cacheResponse.Value.(string) + } + + // Build other caches + go func() { + <-buildDestinationCache(token) + }() + + ctx.token = token + } + + return ctx +} + +// SendMessage sends text message to a user by ID. +// Depending on asUser, message will be sent as your authenticated user or as the app's bot. +func (ctx *Context) SendMessage(user, message string, asUser bool) error { + client := slack.New(ctx.token) + + // Open/get channel by user ID + channel, _, _, err := client.OpenConversation(&slack.OpenConversationParameters{ + Users: []string{user}, + }) + if err != nil { + return err + } + + // Post message to opened channel + _, _, err = client.PostMessage( + channel.ID, + slack.MsgOptionText(message, false), + slack.MsgOptionAsUser(asUser), + ) + return err +} + +// Destinations retrieves a list of users and user groups that you can send messages to. +func (ctx *Context) Destinations() ([]*Destination, error) { + cacheResponse := <-buildDestinationCache(ctx.token) + if cacheResponse.Error != nil { + return []*Destination{}, cacheResponse.Error + } + return cacheResponse.Value.([]*Destination), nil +} + +func buildDestinationCache(token string) <-chan scache.Response { + return destinationCache.ResponseChan(hashToken(token), func(key string) (interface{}, error) { + client := slack.New(token) + + var destinations []*Destination + + userLookup := map[string]*Destination{} + + // Get all users + users, err := client.GetUsers() + if err != nil { + return nil, err + } + + destinations = []*Destination{} + + for _, user := range users { + if user.Deleted || user.IsBot { + continue + } + + d := &Destination{ + Type: "user", + Name: user.Profile.RealName, + DisplayName: user.Profile.DisplayName, + ID: user.ID, + } + + destinations = append(destinations, d) + userLookup[user.ID] = d + } + + usergroups, err := client.GetUserGroups(slack.GetUserGroupsOptionIncludeUsers(true)) + if err != nil { + return nil, err + } + + for _, usergroup := range usergroups { + if !usergroup.IsUserGroup { + continue + } + + children := []*Destination{} + + for _, userID := range usergroup.Users { + user, found := userLookup[userID] + if !found { + continue + } + + children = append(children, user) + } + + destinations = append(destinations, &Destination{ + Type: "usergroup", + Name: usergroup.Name, + DisplayName: usergroup.Handle, + Children: children, + }) + } + + return destinations, nil + }) +} + +// hashToken hashes raw auth token with SHA-1. +func hashToken(token string) string { + h := sha1.New() + h.Write([]byte(token)) + return fmt.Sprintf("%x", h.Sum(nil)) +} + +// Destionation represents user or user group. +type Destination struct { + Type string + Name string + DisplayName string + ID string + Children []*Destination +} diff --git a/internal/pkg/slack/context_test.go b/internal/pkg/slack/context_test.go new file mode 100644 index 0000000..d2c8efb --- /dev/null +++ b/internal/pkg/slack/context_test.go @@ -0,0 +1,37 @@ +package slack + +import "testing" + +func Test_hashToken(t *testing.T) { + for _, test := range []struct { + a string + b string + expected bool + }{ + { + a: "", + b: "", + expected: true, + }, + { + a: "392n784y9238", + b: "392n784y9238", + expected: true, + }, + { + a: "392n784y9238", + b: "392n784y923", + expected: false, + }, + { + a: "392n784y9238", + b: "", + expected: false, + }, + } { + actual := hashToken(test.a) == hashToken(test.b) + if actual != test.expected { + t.Errorf("for %s, %s: got %t, expected %t", test.a, test.b, actual, test.expected) + } + } +} diff --git a/internal/pkg/slack/handlers.go b/internal/pkg/slack/handlers.go new file mode 100644 index 0000000..487a9db --- /dev/null +++ b/internal/pkg/slack/handlers.go @@ -0,0 +1,47 @@ +package slack + +import ( + "net/http" + + "github.com/labstack/echo/v4" + "github.com/slack-go/slack" +) + +// Middleware detects 'code' query parameter and completes authentication. +func (s *Slack) Middleware(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + if code := c.QueryParam("code"); code != "" { + redirectURI := redirectURI(c, c.Request().RequestURI) + response, err := slack.GetOAuthResponse(http.DefaultClient, s.clientID, s.clientSecret, code, redirectURI) + if err != nil { + return c.String(http.StatusUnauthorized, err.Error()) + } + + s.setToken(c, response.AccessToken) + return c.Redirect(http.StatusSeeOther, redirectURI) + } + + if err := next(c); err != nil { + c.Error(err) + } + return nil + } +} + +// HandleLogin initiates Slack authorization. +func (s *Slack) HandleLogin(c echo.Context) error { + redirectURI := redirectURI(c, c.Request().Referer()) + redirectURI, err := s.authorizeURI(redirectURI) + if err != nil { + return c.String(http.StatusInternalServerError, err.Error()) + } + return c.Redirect(http.StatusFound, redirectURI) +} + +// HandleLogout clears Slack credentials. +func (s *Slack) HandleLogout(c echo.Context) error { + s.setToken(c, "") + + redirectURI := redirectURI(c, c.Request().Referer()) + return c.Redirect(http.StatusFound, redirectURI) +} diff --git a/internal/pkg/slack/slack.go b/internal/pkg/slack/slack.go new file mode 100644 index 0000000..d872c9f --- /dev/null +++ b/internal/pkg/slack/slack.go @@ -0,0 +1,87 @@ +package slack + +import ( + "net/http" + "net/url" + "strings" + + "github.com/labstack/echo/v4" +) + +const ( + baseURL = "https://slack.com" + cookiePrefix = "slack_" +) + +var ( + scopes = []string{ + "team:read", + "users:read", + "usergroups:read", + "im:write", + "chat:write:bot", + "chat:write:user", + } +) + +type Slack struct { + clientID string + clientSecret string +} + +func New(clientID, clientSecret string) *Slack { + return &Slack{ + clientID: clientID, + clientSecret: clientSecret, + } +} + +func (s *Slack) authorizeURI(redirectURI string) (string, error) { + redirectURL, err := url.Parse(baseURL + "/oauth/authorize") + if err != nil { + return "", err + } + q := redirectURL.Query() + q.Set("client_id", s.clientID) + q.Set("scope", strings.Join(scopes, ",")) + q.Set("redirect_uri", redirectURI) + redirectURL.RawQuery = q.Encode() + + return redirectURL.String(), nil +} + +// token fetches authorized token from HTTP cookie. +func (s *Slack) token(c echo.Context) string { + cookie, err := c.Cookie(cookiePrefix + "token") + if err != nil { + return "" + } + return cookie.Value +} + +// setToken sets authorized token to HTTP cookie. +func (s *Slack) setToken(c echo.Context, token string) { + c.SetCookie(&http.Cookie{ + Name: cookiePrefix + "token", + Value: url.QueryEscape(token), + MaxAge: 86400, + Path: "/", + Secure: true, + HttpOnly: true, + }) +} + +// redirectURI creates a stable URI for redirects. +// Removes query parameters and trailing slashes. +func redirectURI(c echo.Context, uri string) string { + url, _ := url.Parse(uri) + url.RawQuery = "" + if url.Scheme == "" { + url.Scheme = c.Scheme() + } + if url.Host == "" { + url.Host = c.Request().Host + } + url.Path, _ = strings.CutSuffix(url.Path, "/") + return url.String() +} diff --git a/internal/pkg/slack/slack_test.go b/internal/pkg/slack/slack_test.go new file mode 100644 index 0000000..7a85274 --- /dev/null +++ b/internal/pkg/slack/slack_test.go @@ -0,0 +1,45 @@ +package slack + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/labstack/echo/v4" +) + +func Test_redirectURI(t *testing.T) { + e := echo.New() + + for _, test := range []struct { + target string + relative string + expected string + }{ + { + target: "https://example.com/", + relative: "/context/path?test=1", + expected: "https://example.com/context/path", + }, + { + target: "https://example.com/", + relative: "/", + expected: "https://example.com", + }, + { + target: "https://example.com/", + relative: "", + expected: "https://example.com", + }, + } { + req := httptest.NewRequest(http.MethodGet, test.target, nil) + req.Header.Set(echo.HeaderContentType, echo.MIMETextPlain) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + actual := redirectURI(c, test.relative) + if actual != test.expected { + t.Errorf("for %s, %s: got %s, expected %s", test.target, test.relative, actual, test.expected) + } + } +} diff --git a/internal/pkg/templates/examples/README.md b/internal/pkg/templates/examples/README.md new file mode 100644 index 0000000..8809d0e --- /dev/null +++ b/internal/pkg/templates/examples/README.md @@ -0,0 +1,3 @@ +# Examples + +Example template structure. diff --git a/internal/pkg/templates/examples/about.html b/internal/pkg/templates/examples/about.html new file mode 100644 index 0000000..0c42451 --- /dev/null +++ b/internal/pkg/templates/examples/about.html @@ -0,0 +1,3 @@ +{{define "content"}} +About +{{end}} diff --git a/internal/pkg/templates/examples/home.html b/internal/pkg/templates/examples/home.html new file mode 100644 index 0000000..c45262e --- /dev/null +++ b/internal/pkg/templates/examples/home.html @@ -0,0 +1,3 @@ +{{define "content"}} +Home +{{end}} diff --git a/internal/pkg/templates/examples/layout.html b/internal/pkg/templates/examples/layout.html new file mode 100644 index 0000000..4722cb6 --- /dev/null +++ b/internal/pkg/templates/examples/layout.html @@ -0,0 +1,12 @@ + + + + + {{.title}} + + + + {{template "content" .}} + + + diff --git a/internal/pkg/templates/templates.go b/internal/pkg/templates/templates.go index ad09d6b..90b3aa8 100644 --- a/internal/pkg/templates/templates.go +++ b/internal/pkg/templates/templates.go @@ -5,6 +5,7 @@ import ( "html/template" "io" "io/fs" + "os" "path/filepath" "github.com/labstack/echo/v4" @@ -16,21 +17,26 @@ type Templates struct { // NewRenderer creates new renderer and parses templates directory recursively // Relative path including extension is used as template name. -func NewTemplates(root string, layout string) (*Templates, error) { +func New(root string, layout string) (*Templates, error) { t := &Templates{ templates: map[string]*template.Template{}, } - basePath := root + "/" + layout + if _, err := os.Stat(root); os.IsNotExist(err) { + return t, err + } + + layoutPath := root + "/" + layout + layoutExt := filepath.Ext(layoutPath) err := filepath.WalkDir(root, func(path string, d fs.DirEntry, err error) error { name := d.Name() - if d.IsDir() || name == layout { + if d.IsDir() || name == layout || filepath.Ext(path) != layoutExt { return nil } - t.templates[name] = template.Must(template.ParseFiles(basePath, path)) + t.templates[name] = template.Must(template.ParseFiles(layoutPath, path)) return nil }) diff --git a/internal/pkg/templates/templates_test.go b/internal/pkg/templates/templates_test.go new file mode 100644 index 0000000..5158a2a --- /dev/null +++ b/internal/pkg/templates/templates_test.go @@ -0,0 +1,49 @@ +package templates + +import ( + "net/http/httptest" + "testing" +) + +func Test_New(t *testing.T) { + templates, err := New("examples", "layout.html") + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + + for _, test := range []struct { + name string + expected bool + }{ + { + name: "README.md", + expected: false, + }, + { + name: "about.html", + expected: true, + }, + { + name: "home.html", + expected: true, + }, + } { + tmpl, ok := templates.templates[test.name] + if ok != test.expected { + if test.expected { + t.Errorf("template %s expected but not found", test.name) + } else { + t.Errorf("template %s not expected but found", test.name) + } + } + if !ok { + continue + } + + rec := httptest.NewRecorder() + err := tmpl.Execute(rec, map[string]interface{}{}) + if err != nil { + t.Errorf("template %s error: %s", test.name, err) + } + } +} diff --git a/main.go b/main.go index 545a8dd..32f4a1e 100644 --- a/main.go +++ b/main.go @@ -8,14 +8,16 @@ import ( ) func main() { - s, err := server.NewServer(server.Config{ - Debug: os.Getenv("DEBUG") == "1", - Host: os.Getenv("HOST"), - Port: os.Getenv("PORT"), - CertFile: os.Getenv("CERT_FILE"), - KeyFile: os.Getenv("KEY_FILE"), - StaticRoot: "static", - TemplatesRoot: "templates", + s, err := server.New(server.Config{ + Debug: os.Getenv("DEBUG") == "1", + Host: os.Getenv("HOST"), + Port: os.Getenv("PORT"), + CertFile: os.Getenv("CERT_FILE"), + KeyFile: os.Getenv("KEY_FILE"), + StaticRoot: "static", + TemplatesRoot: "templates", + SlackClientID: os.Getenv("SLACK_CLIENT_ID"), + SlackClientSecret: os.Getenv("SLACK_CLIENT_SECRET"), }) if err != nil { panic(fmt.Sprintf("Failed to create server: %s", err)) diff --git a/templates/index.html b/templates/index.html index 10523a9..40cb9ab 100644 --- a/templates/index.html +++ b/templates/index.html @@ -10,7 +10,7 @@
- {{if .authorized}} + {{if .slack.Authorized}}
diff --git a/templates/layout.html b/templates/layout.html index cfa0394..c3654e8 100644 --- a/templates/layout.html +++ b/templates/layout.html @@ -32,19 +32,19 @@