cors.go 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. package cors
  2. import (
  3. "net/http"
  4. "strings"
  5. )
  6. type CORS struct {
  7. checkOrigin bool
  8. checkMethods bool
  9. origin string
  10. methods string
  11. originsList map[string]struct{}
  12. methodsList map[string]struct{}
  13. }
  14. func NewCORSMiddleware(c Config) (*CORS, error) {
  15. cors := CORS{}
  16. if len(c.origin) > 0 {
  17. cors.checkOrigin = true
  18. cors.origin = c.origin
  19. origins := strings.Split(cors.origin, ",")
  20. cors.originsList = make(map[string]struct{}, len(origins))
  21. for i := range origins {
  22. cors.originsList[strings.ToLower(origins[i])] = struct{}{}
  23. }
  24. }
  25. if len(c.methods) > 0 {
  26. cors.checkMethods = true
  27. cors.methods = c.methods
  28. methods := strings.Split(cors.methods, ",")
  29. cors.methodsList = make(map[string]struct{}, len(methods))
  30. for i := range methods {
  31. cors.methodsList[strings.ToUpper(methods[i])] = struct{}{}
  32. }
  33. }
  34. return &cors, nil
  35. }
  36. func (c CORS) Handle(next http.Handler) http.Handler {
  37. if !c.checkOrigin && !c.checkMethods {
  38. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  39. next.ServeHTTP(w, r)
  40. })
  41. }
  42. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  43. if r.Method == http.MethodOptions && r.Header.Get("Access-Control-Request-Method") != "" {
  44. c.handleOptions(w, r)
  45. w.WriteHeader(http.StatusOK)
  46. } else {
  47. c.handleRequest(w, r)
  48. next.ServeHTTP(w, r)
  49. }
  50. })
  51. }
  52. func (c CORS) handleOptions(w http.ResponseWriter, r *http.Request) {
  53. responseHeaders := w.Header()
  54. origin := r.Header.Get("Origin")
  55. if r.Method != http.MethodOptions {
  56. return
  57. }
  58. responseHeaders.Add("Vary", "Origin")
  59. responseHeaders.Add("Vary", "Access-Control-Request-Method")
  60. responseHeaders.Add("Vary", "Access-Control-Request-Headers")
  61. if origin == "" {
  62. return
  63. }
  64. if !c.isOriginAllowed(origin) {
  65. return
  66. }
  67. if c.checkOrigin {
  68. responseHeaders.Set("Access-Control-Allow-Origin", c.origin)
  69. }
  70. if c.checkMethods {
  71. responseHeaders.Set("Access-Control-Allow-Methods", c.methods)
  72. }
  73. }
  74. func (c CORS) handleRequest(w http.ResponseWriter, r *http.Request) {
  75. responseHeaders := w.Header()
  76. origin := r.Header.Get("Origin")
  77. responseHeaders.Add("Vary", "Origin")
  78. if origin == "" {
  79. return
  80. }
  81. if !c.isOriginAllowed(origin) {
  82. w.WriteHeader(http.StatusBadRequest)
  83. return
  84. }
  85. if !c.isMethodAllowed(r.Method) {
  86. w.WriteHeader(http.StatusMethodNotAllowed)
  87. return
  88. }
  89. if c.checkOrigin {
  90. responseHeaders.Set("Access-Control-Allow-Origin", c.origin)
  91. }
  92. if c.checkMethods {
  93. responseHeaders.Set("Access-Control-Allow-Methods", c.methods)
  94. }
  95. }
  96. func (c CORS) isOriginAllowed(origin string) bool {
  97. if !c.checkOrigin {
  98. return true
  99. }
  100. if _, ok := c.originsList[strings.ToLower(origin)]; ok {
  101. return true
  102. }
  103. return false
  104. }
  105. func (c CORS) isMethodAllowed(method string) bool {
  106. if !c.checkMethods {
  107. return true
  108. }
  109. if _, ok := c.methodsList[strings.ToUpper(method)]; ok {
  110. return true
  111. }
  112. return false
  113. }