1
0
Pārlūkot izejas kodu

Update CORS middleware

Dmitriy Gnatenko 4 mēneši atpakaļ
vecāks
revīzija
75d5d8e157
2 mainītis faili ar 46 papildinājumiem un 48 dzēšanām
  1. 40 38
      http/middleware/cors/cors.go
  2. 6 10
      http/middleware/cors/readme.md

+ 40 - 38
http/middleware/cors/cors.go

@@ -5,7 +5,7 @@ import (
 	"strings"
 )
 
-type CORS struct {
+type cors struct {
 	checkOrigin  bool
 	checkMethods bool
 	origin       string
@@ -14,38 +14,10 @@ type CORS struct {
 	methodsList  map[string]struct{}
 }
 
-func NewCORSMiddleware(c Config) (*CORS, error) {
-	cors := CORS{}
+func Handle(config Config, next http.Handler) http.Handler {
+	instance := new(config)
 
-	if len(c.origin) > 0 {
-		cors.checkOrigin = true
-		cors.origin = c.origin
-
-		origins := strings.Split(cors.origin, ",")
-		cors.originsList = make(map[string]struct{}, len(origins))
-
-		for i := range origins {
-			cors.originsList[strings.ToLower(origins[i])] = struct{}{}
-		}
-	}
-
-	if len(c.methods) > 0 {
-		cors.checkMethods = true
-		cors.methods = c.methods
-
-		methods := strings.Split(cors.methods, ",")
-		cors.methodsList = make(map[string]struct{}, len(methods))
-
-		for i := range methods {
-			cors.methodsList[strings.ToUpper(methods[i])] = struct{}{}
-		}
-	}
-
-	return &cors, nil
-}
-
-func (c CORS) Handle(next http.Handler) http.Handler {
-	if !c.checkOrigin && !c.checkMethods {
+	if !instance.checkOrigin && !instance.checkMethods {
 		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 			next.ServeHTTP(w, r)
 		})
@@ -53,16 +25,46 @@ func (c CORS) Handle(next http.Handler) http.Handler {
 
 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		if r.Method == http.MethodOptions && r.Header.Get("Access-Control-Request-Method") != "" {
-			c.handleOptions(w, r)
+			instance.handleOptions(w, r)
 			w.WriteHeader(http.StatusOK)
 		} else {
-			c.handleRequest(w, r)
+			instance.handleRequest(w, r)
 			next.ServeHTTP(w, r)
 		}
 	})
 }
 
-func (c CORS) handleOptions(w http.ResponseWriter, r *http.Request) {
+func new(config Config) cors {
+	c := cors{}
+
+	if len(config.origin) > 0 {
+		c.checkOrigin = true
+		c.origin = config.origin
+
+		origins := strings.Split(config.origin, ",")
+		c.originsList = make(map[string]struct{}, len(origins))
+
+		for i := range origins {
+			c.originsList[strings.ToLower(strings.TrimSpace(origins[i]))] = struct{}{}
+		}
+	}
+
+	if len(config.methods) > 0 {
+		c.checkMethods = true
+		c.methods = config.methods
+
+		methods := strings.Split(config.methods, ",")
+		c.methodsList = make(map[string]struct{}, len(methods))
+
+		for i := range methods {
+			c.methodsList[strings.ToUpper(strings.TrimSpace(methods[i]))] = struct{}{}
+		}
+	}
+
+	return c
+}
+
+func (c cors) handleOptions(w http.ResponseWriter, r *http.Request) {
 	responseHeaders := w.Header()
 	origin := r.Header.Get("Origin")
 
@@ -91,7 +93,7 @@ func (c CORS) handleOptions(w http.ResponseWriter, r *http.Request) {
 	}
 }
 
-func (c CORS) handleRequest(w http.ResponseWriter, r *http.Request) {
+func (c cors) handleRequest(w http.ResponseWriter, r *http.Request) {
 	responseHeaders := w.Header()
 	origin := r.Header.Get("Origin")
 	responseHeaders.Add("Vary", "Origin")
@@ -119,7 +121,7 @@ func (c CORS) handleRequest(w http.ResponseWriter, r *http.Request) {
 	}
 }
 
-func (c CORS) isOriginAllowed(origin string) bool {
+func (c cors) isOriginAllowed(origin string) bool {
 	if !c.checkOrigin {
 		return true
 	}
@@ -131,7 +133,7 @@ func (c CORS) isOriginAllowed(origin string) bool {
 	return false
 }
 
-func (c CORS) isMethodAllowed(method string) bool {
+func (c cors) isMethodAllowed(method string) bool {
 	if !c.checkMethods {
 		return true
 	}

+ 6 - 10
http/middleware/cors/readme.md

@@ -1,20 +1,16 @@
 ## Usage example
 
 ```
-corsMiddleware, err := cors.NewCORSMiddleware(
-    cors.NewConfig(
-        cors.WithOrigin("test.com"),
-        cors.WithMethods("GET,POST"),        
-    ),
-)
+mux := http.NewServeMux()
 
-if err != nil {
-    // TODO
-}
+config := cors.NewConfig(
+    cors.WithOrigin("test.com"),
+    cors.WithMethods("GET,POST"),        
+)
 
 srv = &http.Server{
     Addr:    ":8080",
-    Handler: corsMiddleware.Handle(mux),
+    Handler: cors.Handle(config, mux),
 }
 
 ```