package limiter import ( "strings" "github.com/gin-gonic/gin" "github.com/runningwater/gohub/pkg/config" "github.com/runningwater/gohub/pkg/logger" "github.com/runningwater/gohub/pkg/redis" "github.com/ulule/limiter/v3" sredis "github.com/ulule/limiter/v3/drivers/store/redis" ) // GetKeyIP 获取 Limiter 的 Key, IP func GetKeyIP(c *gin.Context) string { return c.ClientIP() } // GetKeyRouteWithIP Limitor 的 Key,路由+IP,针对单个路由做限流 func GetKeyRouteWithIP(c *gin.Context) string { return routeToKeyString(c.FullPath()) + c.ClientIP() } // CheckRate 检测请求是否超额 func CheckRate(c *gin.Context, key, formatted string) (limiter.Context, error) { var context limiter.Context // 1. 获取限流器 rate, err := limiter.NewRateFromFormatted(formatted) if err != nil { logger.LogIf(err) return context, err } store, err := sredis.NewStoreWithOptions( redis.Redis.Client, limiter.StoreOptions{ // 为 limiter 设置前缀 Prefix: config.GetString("app.name") + ":limiter", }, ) if err != nil { logger.LogIf(err) return context, err } // 使用上面的初始化的 limiter.Rate 对象和存储对象 limiterObj := limiter.New(store, rate) if c.GetBool("limiter-once") { // peek 方法获取当前状态, 不增加访问次数 return limiterObj.Peek(c, key) } else { // 确保多个路由组里调用 LimitIP 进行限流时,只增加一次访问次数。 c.Set("limiter-once", true) // 获取 limiter 的上下文, 包含当前的访问次数和剩余次数 return limiterObj.Get(c, key) } } // routeToKeyString 辅助方法, 将 URL 中的 / 格式为 - func routeToKeyString(routeName string) string { routeName = strings.ReplaceAll(routeName, "/", "-") routeName = strings.ReplaceAll(routeName, ":", "_") return routeName }