limit.go 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. package middlewares
  2. import (
  3. "net/http"
  4. "github.com/gin-gonic/gin"
  5. "github.com/runningwater/gohub/pkg/app"
  6. "github.com/runningwater/gohub/pkg/limiter"
  7. "github.com/runningwater/gohub/pkg/logger"
  8. "github.com/runningwater/gohub/pkg/response"
  9. "github.com/spf13/cast"
  10. )
  11. // LimitIP 全局限流中间件,针对 IP 进行限流
  12. // limit 为格式化字符串,如 "5-S" ,示例:
  13. //
  14. // * 5 reqs/second: "5-S"
  15. // * 10 reqs/minute: "10-M"
  16. // * 1000 reqs/hour: "1000-H"
  17. // * 2000 reqs/day: "2000-D"
  18. func LimitIP(limit string) gin.HandlerFunc {
  19. if app.IsTesting() {
  20. limit = "1000000-H"
  21. }
  22. return func(c *gin.Context) {
  23. // 针对 IP 限流
  24. key := limiter.GetKeyIP(c)
  25. if ok := limitHandler(c, key, limit); !ok {
  26. return
  27. }
  28. c.Next()
  29. }
  30. }
  31. // LimitPerRoute 限流中间件,用在单独的路由中
  32. func LimitPerRoute(limit string) gin.HandlerFunc {
  33. if app.IsTesting() {
  34. limit = "1000000-H"
  35. }
  36. return func(c *gin.Context) {
  37. // 针对单个路由,增加访问次数
  38. c.Set("limiter-once", false)
  39. // 针对 IP + 路由进行限流
  40. key := limiter.GetKeyRouteWithIP(c)
  41. if ok := limitHandler(c, key, limit); !ok {
  42. return
  43. }
  44. c.Next()
  45. }
  46. }
  47. func limitHandler(c *gin.Context, key string, limit string) bool {
  48. // 获取超额的情况
  49. rate, err := limiter.CheckRate(c, key, limit)
  50. if err != nil {
  51. logger.LogIf(err)
  52. response.Abort500(c)
  53. return false
  54. }
  55. // ---- 设置标头信息-----
  56. // X-RateLimit-Limit :10000 最大访问次数
  57. // X-RateLimit-Remaining :9993 剩余的访问次数
  58. // X-RateLimit-Reset :1513784506 到该时间点,访问次数会重置为 X-RateLimit-Limit
  59. c.Header("X-RateLimit-Limit", cast.ToString(rate.Limit))
  60. c.Header("X-RateLimit-Remaining", cast.ToString(rate.Remaining))
  61. c.Header("X-RateLimit-Reset", cast.ToString(rate.Reset))
  62. // 超额
  63. if rate.Reached {
  64. // 提示用户超额了
  65. c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
  66. "message": "接口请求太频繁",
  67. })
  68. return false
  69. }
  70. return true
  71. }