runningwater 4 лет назад
Родитель
Сommit
49daaebd9c

+ 24 - 1
framework/context.go

@@ -15,18 +15,37 @@ import (
 type Context struct {
     request        *http.Request
     responseWriter http.ResponseWriter
+    ctx            context.Context
+
     // 是否超时标记位
     hasTimeout bool
     // 写保护机制
     writeMux *sync.Mutex
+
+    // 当前请求的 handler 链条
+    handlers []ControllerHandler
+    index    int //当前请求调用到调用链的哪个节点
 }
 
 func NewContext(r *http.Request, w http.ResponseWriter) *Context {
     return &Context{
         request:        r,
         responseWriter: w,
+        ctx:            r.Context(),
         writeMux:       &sync.Mutex{},
+        index:          -1,
+    }
+}
+
+// 核心函数, 调用 context 的下一个函数
+func (ctx *Context) Next() error {
+    ctx.index++
+    if ctx.index < len(ctx.handlers) {
+        if err := ctx.handlers[ctx.index](ctx); err != nil {
+            return err
+        }
     }
+    return nil
 }
 
 // #region base function  base 封装基本的函数功能
@@ -46,6 +65,10 @@ func (ctx *Context) SetHasTimeout() {
     ctx.hasTimeout = true
 }
 
+func (ctx *Context) setHandlers(handlers []ControllerHandler) {
+    ctx.handlers = handlers
+}
+
 func (ctx *Context) HasTimeout() bool {
     return ctx.hasTimeout
 }
@@ -200,7 +223,7 @@ func (ctx *Context) HTML(status int, obj interface{}, template string) error {
     return nil
 }
 
-func (ctx *Context) Text(status int , obj string ) error {
+func (ctx *Context) Text(status int, obj string) error {
     return nil
 }
 

+ 2 - 1
framework/controller.go

@@ -1,3 +1,4 @@
 package framework
 
-type ControllerHandler func(c *Context) error
+// 中间件(控制器)
+type ControllerHandler func(c *Context) error

+ 34 - 24
framework/core.go

@@ -8,7 +8,8 @@ import (
 
 // 核心框架
 type Core struct {
-    router map[string]*Tree
+    router      map[string]*Tree
+    middlewares []ControllerHandler
 }
 
 // 初始化框架核心结构
@@ -24,36 +25,49 @@ func NewCore() *Core {
     return &Core{router: router}
 }
 
+func (c *Core) Use(middlewares ...ControllerHandler) {
+    c.middlewares = append(c.middlewares, middlewares...)
+}
+
 // 对应 Method = GET
-func (c *Core) Get(url string, handler ControllerHandler) {
-   if err := c.router["GET"].AddRouter(url, handler); err != nil {
-       log.Fatal("add router error: ", err)
-   }
+func (c *Core) Get(url string, handlers ...ControllerHandler) {
+    allHandlers := append(c.middlewares, handlers...)
+    if err := c.router["GET"].AddRouter(url, allHandlers); err != nil {
+        log.Fatal("add router error: ", err)
+    }
 }
 
 // 对应 Method = POST
-func (c *Core) Post(url string, handler ControllerHandler) {
-   if err := c.router["POST"].AddRouter(url, handler); err != nil {
-       log.Fatal("add router error: ", err)
-   }
+func (c *Core) Post(url string, handlers ...ControllerHandler) {
+    allHandlers := append(c.middlewares, handlers...)
+    if err := c.router["POST"].AddRouter(url, allHandlers); err != nil {
+        log.Fatal("add router error: ", err)
+    }
 }
 
 // 对应 Method = PUT
-func (c *Core) Put(url string, handler ControllerHandler) {
-    if err := c.router["PUT"].AddRouter(url, handler); err != nil {
+func (c *Core) Put(url string, handlers ...ControllerHandler) {
+    allHandlers := append(c.middlewares, handlers...)
+    if err := c.router["PUT"].AddRouter(url, allHandlers); err != nil {
         log.Fatal("add router error: ", err)
     }
 }
 
 // 对应 Method = DELETE
-func (c *Core) Delete(url string, handler ControllerHandler) {
-    if err := c.router["DELETE"].AddRouter(url, handler); err != nil {
+func (c *Core) Delete(url string, handlers ...ControllerHandler) {
+    allHandlers := append(c.middlewares, handlers...)
+    if err := c.router["DELETE"].AddRouter(url, allHandlers); err != nil {
         log.Fatal("add router error: ", err)
     }
 }
 
+// 初始化 Group
+func (c *Core) Group(prefix string) IGroup {
+    return NewGroup(c, prefix)
+}
+
 // 匹配路由,如果没有匹配到,返回 nil
-func (c *Core) FindRouteByRequest(request *http.Request) ControllerHandler {
+func (c *Core) FindRouteByRequest(request *http.Request) []ControllerHandler {
     // uri 和 method 全部转换为大写, 保证大小写不敏感
     uri := request.URL.Path
     method := request.Method
@@ -67,28 +81,24 @@ func (c *Core) FindRouteByRequest(request *http.Request) ControllerHandler {
     return nil
 }
 
-// 初始化 Group
-func (c *Core) Group(prefix string) IGroup {
-    return NewGroup(c, prefix)
-}
-
 // 框架核心结构实现 handler 接口
 func (c *Core) ServeHTTP(response http.ResponseWriter, request *http.Request) {
-    log.Println("core.ServeHTTP")
     // 封装自定义 context
     ctx := NewContext(request, response)
 
     // 寻找路由
-    router := c.FindRouteByRequest(request)
-    if router == nil {
+    handlers := c.FindRouteByRequest(request)
+    if handlers == nil {
         ctx.Json(404, "not found")
         return
     }
 
+    // 设置 handlers 字段
+    ctx.setHandlers(handlers)
+
     // 调用跌幅函数,如果返回 err 代表存在内部错误,返回 500 状态码
-    if err := router(ctx); err != nil {
+    if err := ctx.Next(); err != nil {
         ctx.Json(500, "inner error")
         return
     }
-    log.Println("core.router")
 }

+ 55 - 18
framework/group.go

@@ -2,43 +2,80 @@ package framework
 
 // IGroup 代表前缀分组
 type IGroup interface {
-    Get(string, ControllerHandler)
-    Post(string, ControllerHandler)
-    Put(string, ControllerHandler)
-    Delete(string, ControllerHandler)
+    Get(string, ...ControllerHandler)
+    Post(string, ...ControllerHandler)
+    Put(string, ...ControllerHandler)
+    Delete(string, ...ControllerHandler)
+
+    //嵌套 group
+    Group(string) IGroup
+
+    // 嵌套中间件
+    Use(middlewares ...ControllerHandler)
 }
 
 // ==========================================================
 
 type Group struct {
     core   *Core
+    parent *Group // 指向上一个 group
     prefix string
+
+    middlewares []ControllerHandler // 存放中间件
 }
 
 // 初始化 Group
 func NewGroup(core *Core, prefix string) *Group {
     return &Group{
-        core:   core,
-        prefix: prefix,
+        core:        core,
+        prefix:      prefix,
+        middlewares: []ControllerHandler{},
     }
 }
 
-func (g *Group) Get(uri string, handler ControllerHandler) {
-    uri = g.prefix + uri
-    g.core.Get(uri, handler)
+func (g *Group) Get(uri string, handlers ...ControllerHandler) {
+    uri = g.getAbsolutePrefix() + uri
+    allHandlers := append(g.getMiddlewares(), handlers...)
+    g.core.Get(uri, allHandlers...)
+}
+
+func (g *Group) Post(uri string, handlers ...ControllerHandler) {
+    uri = g.getAbsolutePrefix() + uri
+    g.core.Post(uri, handlers...)
+}
+
+func (g *Group) Put(uri string, handlers ...ControllerHandler) {
+    uri = g.getAbsolutePrefix() + uri
+    g.core.Put(uri, handlers...)
 }
 
-func (g *Group) Post(uri string, handler ControllerHandler) {
-    uri = g.prefix + uri
-    g.core.Post(uri, handler)
+func (g *Group) Delete(uri string, handlers ...ControllerHandler) {
+    uri = g.getAbsolutePrefix() + uri
+    g.core.Delete(uri, handlers...)
 }
 
-func (g *Group) Put(uri string, handler ControllerHandler) {
-    uri = g.prefix + uri
-    g.core.Put(uri, handler)
+func (g *Group) Group(uri string) IGroup {
+    cgroup := NewGroup(g.core, uri)
+    cgroup.parent = g
+    return cgroup
 }
 
-func (g *Group) Delete(uri string, handler ControllerHandler) {
-    uri = g.prefix + uri
-    g.core.Delete(uri, handler)
+// 注册中间件
+func (g *Group) Use(middlewares ...ControllerHandler) {
+    g.middlewares = append(g.middlewares, middlewares...)
+}
+
+// 获取当前group的绝对路径
+func (g *Group) getAbsolutePrefix() string {
+    if g.parent == nil {
+        return g.prefix
+    }
+    return g.parent.getAbsolutePrefix() + g.prefix
+}
+
+func (g *Group) getMiddlewares() []ControllerHandler {
+    if g.parent == nil {
+        return g.middlewares
+    }
+    return append(g.parent.getMiddlewares(), g.middlewares...)
 }

+ 20 - 0
framework/middleware/Cost.go

@@ -0,0 +1,20 @@
+package middleware
+
+import (
+    "coredemo/framework"
+    "log"
+    "time"
+)
+
+func Cost() framework.ControllerHandler  {
+    return func(c *framework.Context) error {
+        // 记录开始时间
+        start := time.Now()
+        c.Next()
+        end := time.Now()
+        cost := end.Sub(start)
+        log.Printf("api uri: %v, cost: %v sendond(s)", c.GetRequest().RequestURI, cost.Seconds())
+
+        return nil
+    }
+}

+ 18 - 0
framework/middleware/recovery.go

@@ -0,0 +1,18 @@
+package middleware
+
+import "coredemo/framework"
+
+func Recovery() framework.ControllerHandler {
+    return func(c *framework.Context) error {
+        // 核心在增加这个recover机制,捕获c.Next()出现的panic
+        defer func() {
+            if err := recover(); err != nil {
+                c.Json(500, err)
+            }
+        }()
+
+        c.Next()
+
+        return nil
+    }
+}

+ 39 - 0
framework/middleware/test.go

@@ -0,0 +1,39 @@
+package middleware
+
+import (
+    "coredemo/framework"
+    "fmt"
+)
+
+// 测试中间件
+
+func Test1() framework.ControllerHandler {
+    return func(c *framework.Context) error {
+        fmt.Println("middleware pre test1")
+        c.Next()
+        fmt.Println("middleware post test1")
+        return nil
+    }
+
+}
+
+func Test2() framework.ControllerHandler {
+    // 使用函数回调
+    return func(c *framework.Context) error {
+        fmt.Println("middleware pre test2")
+        c.Next()
+        // 调用Next往下调用,会自增contxt.index
+        fmt.Println("middleware post test2")
+        return nil
+    }
+}
+
+func Test3() framework.ControllerHandler {
+    // 使用函数回调
+    return func(c *framework.Context) error {
+        fmt.Println("middleware pre test3")
+        c.Next()
+        fmt.Println("middleware post test3")
+        return nil
+    }
+}

+ 49 - 0
framework/timeout.go

@@ -0,0 +1,49 @@
+package framework
+
+import (
+    "context"
+    "fmt"
+    "log"
+    "time"
+)
+
+func TimeOutHandler(fun ControllerHandler, d time.Duration) ControllerHandler {
+    // 使用函数回调
+    return func(c *Context) error {
+        finish := make(chan struct{}, 1)
+        panicChan := make(chan interface{}, 1)
+
+        // 执行业务逻辑前预操作: 初始化超时 context
+        durationCtx, cancel := context.WithTimeout(c.BaseContext(), d)
+        defer cancel()
+
+        c.request.WithContext(durationCtx)
+
+        go func() {
+            defer func() {
+                if p := recover(); p != nil {
+                    panicChan <- p
+                }
+            }()
+
+            // 执行具体业务逻辑
+            fun(c)
+
+            finish <- struct{}{}
+        }()
+
+        // 执行业务逻辑后操作
+        select {
+        case p := <-panicChan:
+            log.Println(p)
+            c.responseWriter.WriteHeader(500)
+        case <-finish:
+            fmt.Println("finish")
+        case <-durationCtx.Done():
+            c.SetHasTimeout()
+            c.responseWriter.Write([]byte("time out"))
+        }
+
+        return nil
+    }
+}

+ 10 - 10
framework/trie.go

@@ -22,7 +22,7 @@ func NewTree() *Tree {
 /:user/name
 /:user/name/:age(冲突)
 */
-func (tree *Tree) AddRouter(uri string, handler ControllerHandler) error {
+func (tree *Tree) AddRouter(uri string, handlers []ControllerHandler) error {
     n := tree.root
     // 确认路由是否冲突
     if n.matchNode(uri) != nil {
@@ -56,7 +56,7 @@ func (tree *Tree) AddRouter(uri string, handler ControllerHandler) error {
             cnode.segment = segment
             if isLast {
                 cnode.isLast = true
-                cnode.handler = handler
+                cnode.handlers = handlers
             }
             n.childes = append(n.childes, cnode)
             objNode = cnode
@@ -68,22 +68,23 @@ func (tree *Tree) AddRouter(uri string, handler ControllerHandler) error {
     return nil
 }
 
-func (tree *Tree) FindHandler(uri string) ControllerHandler {
+// 匹配 uri
+func (tree *Tree) FindHandler(uri string) []ControllerHandler {
     matchNode := tree.root.matchNode(uri)
     if matchNode == nil {
         return nil
     }
-    return matchNode.handler
+    return matchNode.handlers
 }
 
 // =====================================================================================================================
 
 // 代表节点
 type node struct {
-    isLast  bool              // 代表这个节点是否可以成为最终的路由规则。该节点是否能成为一个独立的uri, 是否自身就是一个终极节点
-    segment string            // uri 中的字符串,代表这个节点表示的路由中某个段的字符串
-    handler ControllerHandler // 代表这个节点中包含的控制器,用于最终加载调用
-    childes []*node           // 代表这个节点下的子节点
+    isLast   bool                // 代表这个节点是否可以成为最终的路由规则。该节点是否能成为一个独立的uri, 是否自身就是一个终极节点
+    segment  string              // uri 中的字符串,代表这个节点表示的路由中某个段的字符串
+    handlers []ControllerHandler // 代表这个节点中包含的控制器,用于最终加载调用 中间件+控制器
+    childes  []*node             // 代表这个节点下的子节点
 }
 
 // ---------------------------------------------------------------------------------------------------------------------
@@ -92,8 +93,7 @@ func newNode() *node {
     return &node{
         isLast:  false,
         segment: "",
-        handler: nil,
-        childes: nil,
+        childes: []*node{},
     }
 }
 

+ 0 - 84
framework/trie_test.go

@@ -1,84 +0,0 @@
-package framework
-
-import "testing"
-
-func Test_filterChildNodes(t *testing.T) {
-    root := &node{
-        isLast:  false,
-        segment: "",
-        handler: func(*Context) error { return nil },
-        childes: []*node{
-            {
-                isLast:  true,
-                segment: "FOO",
-                handler: func(*Context) error { return nil },
-                childes: nil,
-            },
-            {
-                isLast:  false,
-                segment: ":id",
-                handler: nil,
-                childes: nil,
-            },
-        },
-    }
-
-    {
-        nodes := root.filterChildNodes("FOO")
-        if len(nodes) != 2 {
-            t.Error("foo error")
-        }
-    }
-
-    {
-        nodes := root.filterChildNodes(":foo")
-        if len(nodes) != 2 {
-            t.Error(":foo error")
-        }
-    }
-
-}
-
-func Test_matchNode(t *testing.T) {
-    root := &node{
-        isLast:  false,
-        segment: "",
-        handler: func(*Context) error { return nil },
-        childes: []*node{
-            {
-                isLast:  true,
-                segment: "FOO",
-                handler: nil,
-                childes: []*node{
-                    {
-                        isLast:  true,
-                        segment: "BAR",
-                        handler: func(*Context) error { panic("not implemented") },
-                        childes: []*node{},
-                    },
-                },
-            },
-            {
-                isLast:  true,
-                segment: ":id",
-                handler: nil,
-                childes: nil,
-            },
-        },
-    }
-
-    {
-        node := root.matchNode("foo/bar")
-        if node == nil {
-            t.Error("match normal node error")
-        }
-    }
-
-    {
-        node := root.matchNode("test")
-        if node == nil {
-            t.Error("match test")
-        }
-    }
-
-}

+ 5 - 4
route.go

@@ -2,6 +2,7 @@ package main
 
 import (
     "coredemo/framework"
+    "coredemo/framework/middleware"
 )
 
 /**
@@ -9,14 +10,14 @@ import (
 */
 
 func registerRouter(core *framework.Core) {
+    core.Use(middleware.Recovery())
+    core.Use(middleware.Cost())
 
-    core.Get("/user/login", UserLoginController)
+    core.Get("/user/login", middleware.Test1(), UserLoginController, middleware.Test2())
 
     subjectApi := core.Group("/subject")
     {
-        subjectApi.Delete("/:id", SubjectController)
-        subjectApi.Put("/:id", SubjectController)
-        subjectApi.Get("/:id", SubjectController)
+        subjectApi.Get("/:id", middleware.Test3(), SubjectController)
         subjectApi.Get("/list/all", SubjectController)
     }
 }

+ 5 - 1
user_controller.go

@@ -1,8 +1,12 @@
 package main
 
-import "coredemo/framework"
+import (
+    "coredemo/framework"
+    "fmt"
+)
 
 func UserLoginController(c *framework.Context) error {
+    fmt.Println("ok, UserLoginController")
     c.Json(200, "ok, UserLoginController")
     return nil
 }