Explorar el Código

feat: 限流中间件

runningwater hace 7 meses
padre
commit
0159c29cb7
Se han modificado 6 ficheros con 168 adiciones y 6 borrados
  1. 83 0
      app/http/middlewares/limit.go
  2. 2 0
      go.mod
  3. 4 0
      go.sum
  4. 1 1
      gohub.http
  5. 67 0
      pkg/limiter/limiter.go
  6. 11 5
      routes/api.go

+ 83 - 0
app/http/middlewares/limit.go

@@ -0,0 +1,83 @@
+package middlewares
+
+import (
+	"net/http"
+
+	"github.com/gin-gonic/gin"
+	"github.com/runningwater/gohub/pkg/app"
+	"github.com/runningwater/gohub/pkg/limiter"
+	"github.com/runningwater/gohub/pkg/logger"
+	"github.com/runningwater/gohub/pkg/response"
+	"github.com/spf13/cast"
+)
+
+// LimitIP 全局限流中间件,针对 IP 进行限流
+// limit 为格式化字符串,如 "5-S" ,示例:
+//
+// * 5 reqs/second: "5-S"
+// * 10 reqs/minute: "10-M"
+// * 1000 reqs/hour: "1000-H"
+// * 2000 reqs/day: "2000-D"
+func LimitIP(limit string) gin.HandlerFunc {
+	if app.IsTesting() {
+		limit = "1000000-H"
+	}
+
+	return func(c *gin.Context) {
+		// 针对 IP 限流
+		key := limiter.GetKeyIP(c)
+		if ok := limitHandler(c, key, limit); !ok {
+			return
+		}
+		c.Next()
+	}
+}
+
+// LimitPerRoute 限流中间件,用在单独的路由中
+func LimitPerRoute(limit string) gin.HandlerFunc {
+	if app.IsTesting() {
+		limit = "1000000-H"
+	}
+	return func(c *gin.Context) {
+
+		// 针对单个路由,增加访问次数
+		c.Set("limiter-once", false)
+
+		// 针对 IP + 路由进行限流
+		key := limiter.GetKeyRouteWithIP(c)
+		if ok := limitHandler(c, key, limit); !ok {
+			return
+		}
+		c.Next()
+	}
+}
+
+func limitHandler(c *gin.Context, key string, limit string) bool {
+
+	// 获取超额的情况
+	rate, err := limiter.CheckRate(c, key, limit)
+	if err != nil {
+		logger.LogIf(err)
+		response.Abort500(c)
+		return false
+	}
+
+	// ---- 设置标头信息-----
+	// X-RateLimit-Limit :10000 最大访问次数
+	// X-RateLimit-Remaining :9993 剩余的访问次数
+	// X-RateLimit-Reset :1513784506 到该时间点,访问次数会重置为 X-RateLimit-Limit
+	c.Header("X-RateLimit-Limit", cast.ToString(rate.Limit))
+	c.Header("X-RateLimit-Remaining", cast.ToString(rate.Remaining))
+	c.Header("X-RateLimit-Reset", cast.ToString(rate.Reset))
+
+	// 超额
+	if rate.Reached {
+		// 提示用户超额了
+		c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
+			"message": "接口请求太频繁",
+		})
+		return false
+	}
+
+	return true
+}

+ 2 - 0
go.mod

@@ -20,6 +20,7 @@ require (
 	github.com/spf13/cobra v1.9.1
 	github.com/spf13/viper v1.20.1
 	github.com/thedevsaddam/govalidator v1.9.10
+	github.com/ulule/limiter/v3 v3.11.2
 	go.uber.org/zap v1.27.0
 	golang.org/x/crypto v0.37.0
 	gopkg.in/natefinch/lumberjack.v2 v2.2.1
@@ -62,6 +63,7 @@ require (
 	github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
 	github.com/modern-go/reflect2 v1.0.2 // indirect
 	github.com/pelletier/go-toml/v2 v2.2.3 // indirect
+	github.com/pkg/errors v0.9.1 // indirect
 	github.com/sagikazarmark/locafero v0.7.0 // indirect
 	github.com/sourcegraph/conc v0.3.0 // indirect
 	github.com/spf13/afero v1.12.0 // indirect

+ 4 - 0
go.sum

@@ -175,6 +175,8 @@ github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWb
 github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
 github.com/pelletier/go-toml/v2 v2.2.3 h1:YmeHyLY8mFWbdkNWwpr+qIL2bEqT0o95WSdkNHvL12M=
 github.com/pelletier/go-toml/v2 v2.2.3/go.mod h1:MfCQTFTvCcUyyvvwm1+G6H/jORL20Xlb6rzQu9GuUkc=
+github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
+github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
 github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
 github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
 github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
@@ -223,6 +225,8 @@ github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS
 github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
 github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
 github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
+github.com/ulule/limiter/v3 v3.11.2 h1:P4yOrxoEMJbOTfRJR2OzjL90oflzYPPmWg+dvwN2tHA=
+github.com/ulule/limiter/v3 v3.11.2/go.mod h1:QG5GnFOCV+k7lrL5Y8kgEeeflPH3+Cviqlqa8SVSQxI=
 github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
 github.com/yuin/goldmark v1.1.30/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
 github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=

+ 1 - 1
gohub.http

@@ -25,7 +25,7 @@ Content-Type: application/json
 {
   "phone": "15968875425",
   "captcha_id": "captcha_skip_test",
-  "captcha_answer": "123456"
+  "captcha_answer": "888888"
 }
 
 ### POST /verify_code/email 发送邮箱验证码

+ 67 - 0
pkg/limiter/limiter.go

@@ -0,0 +1,67 @@
+package limiter
+
+import (
+	"strings"
+
+	"github.com/gin-gonic/gin"
+	"github.com/runningwater/gohub/pkg/config"
+	"github.com/runningwater/gohub/pkg/logger"
+	"github.com/runningwater/gohub/pkg/redis"
+
+	"github.com/ulule/limiter/v3"
+	sredis "github.com/ulule/limiter/v3/drivers/store/redis"
+)
+
+// GetKeyIP 获取 Limiter 的 Key, IP
+func GetKeyIP(c *gin.Context) string {
+	return c.ClientIP()
+}
+
+// GetKeyRouteWithIP Limitor 的 Key,路由+IP,针对单个路由做限流
+func GetKeyRouteWithIP(c *gin.Context) string {
+	return routeToKeyString(c.FullPath()) + c.ClientIP()
+}
+
+// CheckRate 检测请求是否超额
+func CheckRate(c *gin.Context, key, formatted string) (limiter.Context, error) {
+
+	var context limiter.Context
+	// 1. 获取限流器
+	rate, err := limiter.NewRateFromFormatted(formatted)
+	if err != nil {
+		logger.LogIf(err)
+		return context, err
+	}
+
+	store, err := sredis.NewStoreWithOptions(
+		redis.Redis.Client,
+		limiter.StoreOptions{
+			// 为 limiter 设置前缀
+			Prefix: config.GetString("app.name") + ":limiter",
+		},
+	)
+	if err != nil {
+		logger.LogIf(err)
+		return context, err
+	}
+
+	// 使用上面的初始化的 limiter.Rate 对象和存储对象
+	limiterObj := limiter.New(store, rate)
+
+	if c.GetBool("limiter-once") {
+		// peek 方法获取当前状态, 不增加访问次数
+		return limiterObj.Peek(c, key)
+	} else {
+		// 确保多个路由组里调用 LimitIP 进行限流时,只增加一次访问次数。
+		c.Set("limiter-once", true)
+		// 获取 limiter 的上下文, 包含当前的访问次数和剩余次数
+		return limiterObj.Get(c, key)
+	}
+}
+
+// routeToKeyString 辅助方法, 将 URL 中的 / 格式为 -
+func routeToKeyString(routeName string) string {
+	routeName = strings.ReplaceAll(routeName, "/", "-")
+	routeName = strings.ReplaceAll(routeName, ":", "_")
+	return routeName
+}

+ 11 - 5
routes/api.go

@@ -5,30 +5,36 @@ import (
 	"github.com/gin-gonic/gin"
 
 	"github.com/runningwater/gohub/app/http/controllers/api/v1/auth"
+	"github.com/runningwater/gohub/app/http/middlewares"
 )
 
 // RegisterAPIRoutes 注册路由
 func RegisterAPIRoutes(router *gin.Engine) {
 	// v1 路由组,所有 v1 版本的路由都放在这里
 	v1 := router.Group("/v1")
+	// 全局限流中间件:每小时限流。这里是所有 API (根据 IP)请求加起来。
+	// 作为参考 Github API 每小时最多 60 个请求(根据 IP)。
+	// 测试时,可以调高一点。
+	v1.Use(middlewares.LimitIP("200-H"))
+
 	{
 		authGroup := v1.Group("/auth")
 		{
 			suc := new(auth.SignupController)
 			vcc := new(auth.VerifyCodeController)
 			// 注册手机号是否已存在
-			authGroup.POST("/signup/phone/exist", suc.IsPhoneExist)
+			authGroup.POST("/signup/phone/exist", middlewares.GuestJWT(), middlewares.LimitIP("60-H"), suc.IsPhoneExist)
 			// 注册邮箱是否已存在
-			authGroup.POST("/signup/email/exist", suc.IsEmailExist)
+			authGroup.POST("/signup/email/exist", middlewares.GuestJWT(), middlewares.LimitIP("60-H"), suc.IsEmailExist)
 			// 注册用户
 			authGroup.POST("/signup/using-phone", suc.SignupUsingPhone)
 			authGroup.POST("/signup/using-email", suc.SignupUsingEmail)
 			// 显示图片验证码
-			authGroup.POST("/verify_code/captcha", vcc.ShowCaptcha)
+			authGroup.POST("/verify_code/captcha", middlewares.LimitIP("50-H"), vcc.ShowCaptcha)
 			// 发送手机验证码
-			authGroup.POST("/verify_code/phone", vcc.SendUsingPhone)
+			authGroup.POST("/verify_code/phone", middlewares.LimitIP("20-H"), vcc.SendUsingPhone)
 			// 发送邮箱验证码
-			authGroup.POST("/verify_code/email", vcc.SendUsingEmail)
+			authGroup.POST("/verify_code/email", middlewares.LimitIP("20-H"), vcc.SendUsingEmail)
 
 			logc := new(auth.LoginController)
 			// 手机号登录