limiter.go 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. package limiter
  2. import (
  3. "strings"
  4. "github.com/gin-gonic/gin"
  5. "github.com/runningwater/gohub/pkg/config"
  6. "github.com/runningwater/gohub/pkg/logger"
  7. "github.com/runningwater/gohub/pkg/redis"
  8. "github.com/ulule/limiter/v3"
  9. sredis "github.com/ulule/limiter/v3/drivers/store/redis"
  10. )
  11. // GetKeyIP 获取 Limiter 的 Key, IP
  12. func GetKeyIP(c *gin.Context) string {
  13. return c.ClientIP()
  14. }
  15. // GetKeyRouteWithIP Limitor 的 Key,路由+IP,针对单个路由做限流
  16. func GetKeyRouteWithIP(c *gin.Context) string {
  17. return routeToKeyString(c.FullPath()) + c.ClientIP()
  18. }
  19. // CheckRate 检测请求是否超额
  20. func CheckRate(c *gin.Context, key, formatted string) (limiter.Context, error) {
  21. var context limiter.Context
  22. // 1. 获取限流器
  23. rate, err := limiter.NewRateFromFormatted(formatted)
  24. if err != nil {
  25. logger.LogIf(err)
  26. return context, err
  27. }
  28. store, err := sredis.NewStoreWithOptions(
  29. redis.Redis.Client,
  30. limiter.StoreOptions{
  31. // 为 limiter 设置前缀
  32. Prefix: config.GetString("app.name") + ":limiter",
  33. },
  34. )
  35. if err != nil {
  36. logger.LogIf(err)
  37. return context, err
  38. }
  39. // 使用上面的初始化的 limiter.Rate 对象和存储对象
  40. limiterObj := limiter.New(store, rate)
  41. if c.GetBool("limiter-once") {
  42. // peek 方法获取当前状态, 不增加访问次数
  43. return limiterObj.Peek(c, key)
  44. } else {
  45. // 确保多个路由组里调用 LimitIP 进行限流时,只增加一次访问次数。
  46. c.Set("limiter-once", true)
  47. // 获取 limiter 的上下文, 包含当前的访问次数和剩余次数
  48. return limiterObj.Get(c, key)
  49. }
  50. }
  51. // routeToKeyString 辅助方法, 将 URL 中的 / 格式为 -
  52. func routeToKeyString(routeName string) string {
  53. routeName = strings.ReplaceAll(routeName, "/", "-")
  54. routeName = strings.ReplaceAll(routeName, ":", "_")
  55. return routeName
  56. }