|
@@ -5,7 +5,7 @@ import (
|
|
|
"strings"
|
|
|
)
|
|
|
|
|
|
-type CORS struct {
|
|
|
+type cors struct {
|
|
|
checkOrigin bool
|
|
|
checkMethods bool
|
|
|
origin string
|
|
@@ -14,38 +14,10 @@ type CORS struct {
|
|
|
methodsList map[string]struct{}
|
|
|
}
|
|
|
|
|
|
-func NewCORSMiddleware(c Config) (*CORS, error) {
|
|
|
- cors := CORS{}
|
|
|
+func Handle(config Config, next http.Handler) http.Handler {
|
|
|
+ instance := new(config)
|
|
|
|
|
|
- if len(c.origin) > 0 {
|
|
|
- cors.checkOrigin = true
|
|
|
- cors.origin = c.origin
|
|
|
-
|
|
|
- origins := strings.Split(cors.origin, ",")
|
|
|
- cors.originsList = make(map[string]struct{}, len(origins))
|
|
|
-
|
|
|
- for i := range origins {
|
|
|
- cors.originsList[strings.ToLower(origins[i])] = struct{}{}
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- if len(c.methods) > 0 {
|
|
|
- cors.checkMethods = true
|
|
|
- cors.methods = c.methods
|
|
|
-
|
|
|
- methods := strings.Split(cors.methods, ",")
|
|
|
- cors.methodsList = make(map[string]struct{}, len(methods))
|
|
|
-
|
|
|
- for i := range methods {
|
|
|
- cors.methodsList[strings.ToUpper(methods[i])] = struct{}{}
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- return &cors, nil
|
|
|
-}
|
|
|
-
|
|
|
-func (c CORS) Handle(next http.Handler) http.Handler {
|
|
|
- if !c.checkOrigin && !c.checkMethods {
|
|
|
+ if !instance.checkOrigin && !instance.checkMethods {
|
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
|
next.ServeHTTP(w, r)
|
|
|
})
|
|
@@ -53,16 +25,46 @@ func (c CORS) Handle(next http.Handler) http.Handler {
|
|
|
|
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
|
if r.Method == http.MethodOptions && r.Header.Get("Access-Control-Request-Method") != "" {
|
|
|
- c.handleOptions(w, r)
|
|
|
+ instance.handleOptions(w, r)
|
|
|
w.WriteHeader(http.StatusOK)
|
|
|
} else {
|
|
|
- c.handleRequest(w, r)
|
|
|
+ instance.handleRequest(w, r)
|
|
|
next.ServeHTTP(w, r)
|
|
|
}
|
|
|
})
|
|
|
}
|
|
|
|
|
|
-func (c CORS) handleOptions(w http.ResponseWriter, r *http.Request) {
|
|
|
+func new(config Config) cors {
|
|
|
+ c := cors{}
|
|
|
+
|
|
|
+ if len(config.origin) > 0 {
|
|
|
+ c.checkOrigin = true
|
|
|
+ c.origin = config.origin
|
|
|
+
|
|
|
+ origins := strings.Split(config.origin, ",")
|
|
|
+ c.originsList = make(map[string]struct{}, len(origins))
|
|
|
+
|
|
|
+ for i := range origins {
|
|
|
+ c.originsList[strings.ToLower(strings.TrimSpace(origins[i]))] = struct{}{}
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ if len(config.methods) > 0 {
|
|
|
+ c.checkMethods = true
|
|
|
+ c.methods = config.methods
|
|
|
+
|
|
|
+ methods := strings.Split(config.methods, ",")
|
|
|
+ c.methodsList = make(map[string]struct{}, len(methods))
|
|
|
+
|
|
|
+ for i := range methods {
|
|
|
+ c.methodsList[strings.ToUpper(strings.TrimSpace(methods[i]))] = struct{}{}
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ return c
|
|
|
+}
|
|
|
+
|
|
|
+func (c cors) handleOptions(w http.ResponseWriter, r *http.Request) {
|
|
|
responseHeaders := w.Header()
|
|
|
origin := r.Header.Get("Origin")
|
|
|
|
|
@@ -91,7 +93,7 @@ func (c CORS) handleOptions(w http.ResponseWriter, r *http.Request) {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-func (c CORS) handleRequest(w http.ResponseWriter, r *http.Request) {
|
|
|
+func (c cors) handleRequest(w http.ResponseWriter, r *http.Request) {
|
|
|
responseHeaders := w.Header()
|
|
|
origin := r.Header.Get("Origin")
|
|
|
responseHeaders.Add("Vary", "Origin")
|
|
@@ -119,7 +121,7 @@ func (c CORS) handleRequest(w http.ResponseWriter, r *http.Request) {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-func (c CORS) isOriginAllowed(origin string) bool {
|
|
|
+func (c cors) isOriginAllowed(origin string) bool {
|
|
|
if !c.checkOrigin {
|
|
|
return true
|
|
|
}
|
|
@@ -131,7 +133,7 @@ func (c CORS) isOriginAllowed(origin string) bool {
|
|
|
return false
|
|
|
}
|
|
|
|
|
|
-func (c CORS) isMethodAllowed(method string) bool {
|
|
|
+func (c cors) isMethodAllowed(method string) bool {
|
|
|
if !c.checkMethods {
|
|
|
return true
|
|
|
}
|