Skip to content

Commit

Permalink
Merge pull request #99 from TomWright/cors
Browse files Browse the repository at this point in the history
Add --allow-all-origins arg
  • Loading branch information
TomWright authored Sep 29, 2022
2 parents 33f80ef + 849a3a2 commit f37d02c
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 8 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -74,5 +74,5 @@ RUN mkdir -p ./out
RUN chmod 0777 ./in
RUN chmod 0777 ./out

CMD ["./app", "--mermaid=./node_modules/.bin/mmdc", "--in=./in", "--out=./out", "--puppeteer=./puppeteer-config.json"]
CMD ["./app", "--mermaid=./node_modules/.bin/mmdc", "--in=./in", "--out=./out", "--puppeteer=./puppeteer-config.json", "--allow-all-origins=true"]

2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ While this currently serves the diagrams via HTTP, it could easily be manipulate

Run the container:
```
docker run -d --name mermaid-server -p 80:80 tomwright/mermaid-server:latest
docker run -d --name mermaid-server -p 80:80 tomwright/mermaid-server:latest --allow-all-origins=true
```

### Manually as a go command
Expand Down
3 changes: 2 additions & 1 deletion cmd/app/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ func main() {
in := flag.String("in", "", "Directory to store input files.")
out := flag.String("out", "", "Directory to store output files.")
puppeteer := flag.String("puppeteer", "", "Full path to optional puppeteer config.")
allowAllOrigins := flag.Bool("allow-all-origins", false, "True to allow all request origins")
flag.Parse()

if *mermaid == "" {
Expand All @@ -36,7 +37,7 @@ func main() {
cache := internal.NewDiagramCache()
generator := internal.NewGenerator(cache, *mermaid, *in, *out, *puppeteer)

httpRunner := internal.NewHTTPRunner(generator)
httpRunner := internal.NewHTTPRunner(generator, *allowAllOrigins)
cleanupRunner := internal.NewCleanupRunner(generator)

g.Run(httpRunner)
Expand Down
26 changes: 21 additions & 5 deletions internal/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,15 @@ import (
)

// NewHTTPRunner returns a grace runner that runs a HTTP server.
func NewHTTPRunner(generator Generator) grace.Runner {
func NewHTTPRunner(generator Generator, allowAllOrigins bool) grace.Runner {
httpHandler := generateHTTPHandler(generator)

if allowAllOrigins {
httpHandler = allowAllOriginsMiddleware(httpHandler)
}

r := http.NewServeMux()
r.Handle("/generate", http.HandlerFunc(httpHandler))
r.Handle("/generate", httpHandler)

return &gracehttpserverrunner.HTTPServerRunner{
Server: &http.Server{
Expand All @@ -29,6 +33,18 @@ func NewHTTPRunner(generator Generator) grace.Runner {
}
}

// allowAllOriginsMiddleware sets appropriate CORS headers to allow requests from any origin.
func allowAllOriginsMiddleware(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
origin := r.Header.Get("Origin")
if origin == "" {
origin = "*"
}
w.Header().Set("Access-Control-Allow-Origin", origin)
h.ServeHTTP(w, r)
})
}

func writeJSON(rw http.ResponseWriter, value interface{}, status int) {
bytes, err := json.Marshal(value)
if err != nil {
Expand Down Expand Up @@ -105,8 +121,8 @@ func getDiagramFromPOST(r *http.Request, imgType string) (*Diagram, error) {
const URLParamImageType = "type"

// generateHTTPHandler returns a HTTP handler used to generate a diagram.
func generateHTTPHandler(generator Generator) func(rw http.ResponseWriter, r *http.Request) {
return func(rw http.ResponseWriter, r *http.Request) {
func generateHTTPHandler(generator Generator) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
var diagram *Diagram

imgType := r.URL.Query().Get(URLParamImageType)
Expand Down Expand Up @@ -155,5 +171,5 @@ func generateHTTPHandler(generator Generator) func(rw http.ResponseWriter, r *ht
if err := writeImage(rw, diagramBytes, http.StatusOK, imgType); err != nil {
writeErr(rw, fmt.Errorf("could not write diagram: %w", err), http.StatusInternalServerError)
}
}
})
}

0 comments on commit f37d02c

Please sign in to comment.