cors.go 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. package cors
  2. import (
  3. "net/http"
  4. "strings"
  5. )
  6. type cors struct {
  7. checkOrigin bool
  8. checkMethods bool
  9. origin string
  10. methods string
  11. originsList map[string]struct{}
  12. methodsList map[string]struct{}
  13. }
  14. func Handle(config Config, next http.Handler) http.Handler {
  15. instance := new(config)
  16. if !instance.checkOrigin && !instance.checkMethods {
  17. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  18. next.ServeHTTP(w, r)
  19. })
  20. }
  21. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  22. if r.Method == http.MethodOptions && r.Header.Get("Access-Control-Request-Method") != "" {
  23. instance.handleOptions(w, r)
  24. w.WriteHeader(http.StatusOK)
  25. } else {
  26. instance.handleRequest(w, r)
  27. next.ServeHTTP(w, r)
  28. }
  29. })
  30. }
  31. func new(config Config) cors {
  32. c := cors{}
  33. if len(config.origin) > 0 {
  34. c.checkOrigin = true
  35. c.origin = config.origin
  36. origins := strings.Split(config.origin, ",")
  37. c.originsList = make(map[string]struct{}, len(origins))
  38. for i := range origins {
  39. c.originsList[strings.ToLower(strings.TrimSpace(origins[i]))] = struct{}{}
  40. }
  41. }
  42. if len(config.methods) > 0 {
  43. c.checkMethods = true
  44. c.methods = config.methods
  45. methods := strings.Split(config.methods, ",")
  46. c.methodsList = make(map[string]struct{}, len(methods))
  47. for i := range methods {
  48. c.methodsList[strings.ToUpper(strings.TrimSpace(methods[i]))] = struct{}{}
  49. }
  50. }
  51. return c
  52. }
  53. func (c cors) handleOptions(w http.ResponseWriter, r *http.Request) {
  54. responseHeaders := w.Header()
  55. origin := r.Header.Get("Origin")
  56. if r.Method != http.MethodOptions {
  57. return
  58. }
  59. responseHeaders.Add("Vary", "Origin")
  60. responseHeaders.Add("Vary", "Access-Control-Request-Method")
  61. responseHeaders.Add("Vary", "Access-Control-Request-Headers")
  62. if origin == "" {
  63. return
  64. }
  65. if !c.isOriginAllowed(origin) {
  66. return
  67. }
  68. if c.checkOrigin {
  69. responseHeaders.Set("Access-Control-Allow-Origin", c.origin)
  70. }
  71. if c.checkMethods {
  72. responseHeaders.Set("Access-Control-Allow-Methods", c.methods)
  73. }
  74. }
  75. func (c cors) handleRequest(w http.ResponseWriter, r *http.Request) {
  76. responseHeaders := w.Header()
  77. origin := r.Header.Get("Origin")
  78. responseHeaders.Add("Vary", "Origin")
  79. if origin == "" {
  80. return
  81. }
  82. if !c.isOriginAllowed(origin) {
  83. w.WriteHeader(http.StatusBadRequest)
  84. return
  85. }
  86. if !c.isMethodAllowed(r.Method) {
  87. w.WriteHeader(http.StatusMethodNotAllowed)
  88. return
  89. }
  90. if c.checkOrigin {
  91. responseHeaders.Set("Access-Control-Allow-Origin", c.origin)
  92. }
  93. if c.checkMethods {
  94. responseHeaders.Set("Access-Control-Allow-Methods", c.methods)
  95. }
  96. }
  97. func (c cors) isOriginAllowed(origin string) bool {
  98. if !c.checkOrigin {
  99. return true
  100. }
  101. if _, ok := c.originsList[strings.ToLower(origin)]; ok {
  102. return true
  103. }
  104. return false
  105. }
  106. func (c cors) isMethodAllowed(method string) bool {
  107. if !c.checkMethods {
  108. return true
  109. }
  110. if _, ok := c.methodsList[strings.ToUpper(method)]; ok {
  111. return true
  112. }
  113. return false
  114. }