Procházet zdrojové kódy

Add CORS middleware

Dima před 2 měsíci
rodič
revize
45e7f4e539

+ 144 - 0
http/middleware/cors/cors.go

@@ -0,0 +1,144 @@
+package cors
+
+import (
+	"net/http"
+	"strings"
+)
+
+type CORS struct {
+	checkOrigin  bool
+	checkMethods bool
+	origin       string
+	methods      string
+	originsList  map[string]struct{}
+	methodsList  map[string]struct{}
+}
+
+func NewCORSMiddleware(c Config) (*CORS, error) {
+	cors := CORS{}
+
+	if len(c.origin) > 0 {
+		cors.checkOrigin = true
+		cors.origin = c.origin
+
+		origins := strings.Split(cors.origin, ",")
+		cors.originsList = make(map[string]struct{}, len(origins))
+
+		for i := range origins {
+			cors.originsList[strings.ToLower(origins[i])] = struct{}{}
+		}
+	}
+
+	if len(c.methods) > 0 {
+		cors.checkMethods = true
+		cors.methods = c.methods
+
+		methods := strings.Split(cors.methods, ",")
+		cors.methodsList = make(map[string]struct{}, len(methods))
+
+		for i := range methods {
+			cors.methodsList[strings.ToUpper(methods[i])] = struct{}{}
+		}
+	}
+
+	return &cors, nil
+}
+
+func (c CORS) Handle(next http.Handler) http.Handler {
+	if !c.checkOrigin && !c.checkMethods {
+		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			next.ServeHTTP(w, r)
+		})
+	}
+
+	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		if r.Method == http.MethodOptions && r.Header.Get("Access-Control-Request-Method") != "" {
+			c.handleOptions(w, r)
+			w.WriteHeader(http.StatusOK)
+		} else {
+			c.handleRequest(w, r)
+			next.ServeHTTP(w, r)
+		}
+	})
+}
+
+func (c CORS) handleOptions(w http.ResponseWriter, r *http.Request) {
+	responseHeaders := w.Header()
+	origin := r.Header.Get("Origin")
+
+	if r.Method != http.MethodOptions {
+		return
+	}
+
+	responseHeaders.Add("Vary", "Origin")
+	responseHeaders.Add("Vary", "Access-Control-Request-Method")
+	responseHeaders.Add("Vary", "Access-Control-Request-Headers")
+
+	if origin == "" {
+		return
+	}
+
+	if !c.isOriginAllowed(origin) {
+		return
+	}
+
+	if c.checkOrigin {
+		responseHeaders.Set("Access-Control-Allow-Origin", c.origin)
+	}
+
+	if c.checkMethods {
+		responseHeaders.Set("Access-Control-Allow-Methods", c.methods)
+	}
+}
+
+func (c CORS) handleRequest(w http.ResponseWriter, r *http.Request) {
+	responseHeaders := w.Header()
+	origin := r.Header.Get("Origin")
+	responseHeaders.Add("Vary", "Origin")
+
+	if origin == "" {
+		return
+	}
+
+	if !c.isOriginAllowed(origin) {
+		w.WriteHeader(http.StatusBadRequest)
+		return
+	}
+
+	if !c.isMethodAllowed(r.Method) {
+		w.WriteHeader(http.StatusMethodNotAllowed)
+		return
+	}
+
+	if c.checkOrigin {
+		responseHeaders.Set("Access-Control-Allow-Origin", c.origin)
+	}
+
+	if c.checkMethods {
+		responseHeaders.Set("Access-Control-Allow-Methods", c.methods)
+	}
+}
+
+func (c CORS) isOriginAllowed(origin string) bool {
+	if !c.checkOrigin {
+		return true
+	}
+
+	if _, ok := c.originsList[strings.ToLower(origin)]; ok {
+		return true
+	}
+
+	return false
+}
+
+func (c CORS) isMethodAllowed(method string) bool {
+	if !c.checkMethods {
+		return true
+	}
+
+	if _, ok := c.methodsList[strings.ToUpper(method)]; ok {
+		return true
+	}
+
+	return false
+}

+ 35 - 0
http/middleware/cors/cors_config.go

@@ -0,0 +1,35 @@
+package cors
+
+type Config struct {
+	origin  string
+	methods string
+}
+
+type ConfigOption func(*Config)
+
+type ConfigOptions []ConfigOption
+
+func (s *ConfigOptions) Add(option ConfigOption) {
+	*s = append(*s, option)
+}
+
+func NewConfig(opts ...ConfigOption) Config {
+	c := &Config{}
+	for _, opt := range opts {
+		opt(c)
+	}
+
+	return *c
+}
+
+func WithOrigin(origin string) ConfigOption {
+	return func(s *Config) {
+		s.origin = origin
+	}
+}
+
+func WithMethods(methods string) ConfigOption {
+	return func(s *Config) {
+		s.methods = methods
+	}
+}

+ 20 - 0
http/middleware/cors/readme.md

@@ -0,0 +1,20 @@
+## Usage example
+
+```
+corsMiddleware, err := cors.NewCORSMiddleware(
+    cors.NewConfig(
+        cors.WithOrigin("test.com"),
+        cors.WithMethods("GET,POST"),        
+    ),
+)
+
+if err != nil {
+    // TODO
+}
+
+srv = &http.Server{
+    Addr:    ":8080",
+    Handler: corsMiddleware.Handle(mux),
+}
+
+```