123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144 |
- package cors
- import (
- "net/http"
- "strings"
- )
- type CORS struct {
- checkOrigin bool
- checkMethods bool
- origin string
- methods string
- originsList map[string]struct{}
- methodsList map[string]struct{}
- }
- func NewCORSMiddleware(c Config) (*CORS, error) {
- cors := CORS{}
- if len(c.origin) > 0 {
- cors.checkOrigin = true
- cors.origin = c.origin
- origins := strings.Split(cors.origin, ",")
- cors.originsList = make(map[string]struct{}, len(origins))
- for i := range origins {
- cors.originsList[strings.ToLower(origins[i])] = struct{}{}
- }
- }
- if len(c.methods) > 0 {
- cors.checkMethods = true
- cors.methods = c.methods
- methods := strings.Split(cors.methods, ",")
- cors.methodsList = make(map[string]struct{}, len(methods))
- for i := range methods {
- cors.methodsList[strings.ToUpper(methods[i])] = struct{}{}
- }
- }
- return &cors, nil
- }
- func (c CORS) Handle(next http.Handler) http.Handler {
- if !c.checkOrigin && !c.checkMethods {
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- next.ServeHTTP(w, r)
- })
- }
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- if r.Method == http.MethodOptions && r.Header.Get("Access-Control-Request-Method") != "" {
- c.handleOptions(w, r)
- w.WriteHeader(http.StatusOK)
- } else {
- c.handleRequest(w, r)
- next.ServeHTTP(w, r)
- }
- })
- }
- func (c CORS) handleOptions(w http.ResponseWriter, r *http.Request) {
- responseHeaders := w.Header()
- origin := r.Header.Get("Origin")
- if r.Method != http.MethodOptions {
- return
- }
- responseHeaders.Add("Vary", "Origin")
- responseHeaders.Add("Vary", "Access-Control-Request-Method")
- responseHeaders.Add("Vary", "Access-Control-Request-Headers")
- if origin == "" {
- return
- }
- if !c.isOriginAllowed(origin) {
- return
- }
- if c.checkOrigin {
- responseHeaders.Set("Access-Control-Allow-Origin", c.origin)
- }
- if c.checkMethods {
- responseHeaders.Set("Access-Control-Allow-Methods", c.methods)
- }
- }
- func (c CORS) handleRequest(w http.ResponseWriter, r *http.Request) {
- responseHeaders := w.Header()
- origin := r.Header.Get("Origin")
- responseHeaders.Add("Vary", "Origin")
- if origin == "" {
- return
- }
- if !c.isOriginAllowed(origin) {
- w.WriteHeader(http.StatusBadRequest)
- return
- }
- if !c.isMethodAllowed(r.Method) {
- w.WriteHeader(http.StatusMethodNotAllowed)
- return
- }
- if c.checkOrigin {
- responseHeaders.Set("Access-Control-Allow-Origin", c.origin)
- }
- if c.checkMethods {
- responseHeaders.Set("Access-Control-Allow-Methods", c.methods)
- }
- }
- func (c CORS) isOriginAllowed(origin string) bool {
- if !c.checkOrigin {
- return true
- }
- if _, ok := c.originsList[strings.ToLower(origin)]; ok {
- return true
- }
- return false
- }
- func (c CORS) isMethodAllowed(method string) bool {
- if !c.checkMethods {
- return true
- }
- if _, ok := c.methodsList[strings.ToUpper(method)]; ok {
- return true
- }
- return false
- }
|