runningwater 4 gadi atpakaļ
vecāks
revīzija
5bea1225b8
7 mainītis faili ar 394 papildinājumiem un 10 dzēšanām
  1. 67 8
      framework/core.go
  2. 44 0
      framework/group.go
  3. 168 0
      framework/trie.go
  4. 84 0
      framework/trie_test.go
  5. 15 2
      route.go
  6. 8 0
      subject_controller.go
  7. 8 0
      user_controller.go

+ 67 - 8
framework/core.go

@@ -3,33 +3,92 @@ package framework
 import (
     "log"
     "net/http"
+    "strings"
 )
 
 // 核心框架
 type Core struct {
-    router map[string]ControllerHandler
+    router map[string]*Tree
 }
 
 // 初始化框架核心结构
 func NewCore() *Core {
-    return &Core{router: map[string]ControllerHandler{}}
+
+    // 初始化路由
+    router := map[string]*Tree{}
+    router["GET"] = NewTree()
+    router["POST"] = NewTree()
+    router["PUT"] = NewTree()
+    router["DELETE"] = NewTree()
+
+    return &Core{router: router}
 }
 
+// 对应 Method = GET
 func (c *Core) Get(url string, handler ControllerHandler) {
-    c.router[url] = handler
+   if err := c.router["GET"].AddRouter(url, handler); err != nil {
+       log.Fatal("add router error: ", err)
+   }
+}
+
+// 对应 Method = POST
+func (c *Core) Post(url string, handler ControllerHandler) {
+   if err := c.router["POST"].AddRouter(url, handler); err != nil {
+       log.Fatal("add router error: ", err)
+   }
+}
+
+// 对应 Method = PUT
+func (c *Core) Put(url string, handler ControllerHandler) {
+    if err := c.router["PUT"].AddRouter(url, handler); err != nil {
+        log.Fatal("add router error: ", err)
+    }
+}
+
+// 对应 Method = DELETE
+func (c *Core) Delete(url string, handler ControllerHandler) {
+    if err := c.router["DELETE"].AddRouter(url, handler); err != nil {
+        log.Fatal("add router error: ", err)
+    }
+}
+
+// 匹配路由,如果没有匹配到,返回 nil
+func (c *Core) FindRouteByRequest(request *http.Request) ControllerHandler {
+    // uri 和 method 全部转换为大写, 保证大小写不敏感
+    uri := request.URL.Path
+    method := request.Method
+    upperMethod := strings.ToUpper(method)
+
+    // 查找第一层 map
+    if methodHandlers, ok := c.router[upperMethod]; ok {
+        // 查找第二层 map
+        return methodHandlers.FindHandler(uri)
+    }
+    return nil
+}
+
+// 初始化 Group
+func (c *Core) Group(prefix string) IGroup {
+    return NewGroup(c, prefix)
 }
 
 // 框架核心结构实现 handler 接口
 func (c *Core) ServeHTTP(response http.ResponseWriter, request *http.Request) {
     log.Println("core.ServeHTTP")
-    ctx := NewContext(request,response)
+    // 封装自定义 context
+    ctx := NewContext(request, response)
 
-    // 一个简单的路由选择器,这里直接写死为测试路由 foo
-    router:= c.router["foo"]
+    // 寻找路由
+    router := c.FindRouteByRequest(request)
     if router == nil {
+        ctx.Json(404, "not found")
         return
     }
-    log.Println("core.router")
 
-    router(ctx)
+    // 调用跌幅函数,如果返回 err 代表存在内部错误,返回 500 状态码
+    if err := router(ctx); err != nil {
+        ctx.Json(500, "inner error")
+        return
+    }
+    log.Println("core.router")
 }

+ 44 - 0
framework/group.go

@@ -0,0 +1,44 @@
+package framework
+
+// IGroup 代表前缀分组
+type IGroup interface {
+    Get(string, ControllerHandler)
+    Post(string, ControllerHandler)
+    Put(string, ControllerHandler)
+    Delete(string, ControllerHandler)
+}
+
+// ==========================================================
+
+type Group struct {
+    core   *Core
+    prefix string
+}
+
+// 初始化 Group
+func NewGroup(core *Core, prefix string) *Group {
+    return &Group{
+        core:   core,
+        prefix: prefix,
+    }
+}
+
+func (g *Group) Get(uri string, handler ControllerHandler) {
+    uri = g.prefix + uri
+    g.core.Get(uri, handler)
+}
+
+func (g *Group) Post(uri string, handler ControllerHandler) {
+    uri = g.prefix + uri
+    g.core.Post(uri, handler)
+}
+
+func (g *Group) Put(uri string, handler ControllerHandler) {
+    uri = g.prefix + uri
+    g.core.Put(uri, handler)
+}
+
+func (g *Group) Delete(uri string, handler ControllerHandler) {
+    uri = g.prefix + uri
+    g.core.Delete(uri, handler)
+}

+ 168 - 0
framework/trie.go

@@ -0,0 +1,168 @@
+package framework
+
+import (
+    "errors"
+    "strings"
+)
+
+type Tree struct {
+    root *node // 根节点
+}
+
+func NewTree() *Tree {
+    return &Tree{root: newNode()}
+}
+
+// 增加路由节点
+/*
+/book/list
+/book/:id (冲突)
+/book/:id/name
+/book/:student/age
+/:user/name
+/:user/name/:age(冲突)
+*/
+func (tree *Tree) AddRouter(uri string, handler ControllerHandler) error {
+    n := tree.root
+    // 确认路由是否冲突
+    if n.matchNode(uri) != nil {
+        return errors.New("route exists: " + uri)
+    }
+
+    segments := strings.Split(uri, "/")
+    for index, segment := range segments {
+        if !isWildSegment(segment) {
+            segment = strings.ToUpper(segment)
+        }
+        isLast := index == len(segments)-1
+
+        var objNode *node
+
+        childNodes := n.filterChildNodes(segment)
+        // 如果有匹配的子节点
+        if len(childNodes) > 0 {
+            // 如果有segment 相同的子节点,则选择这个子节点
+            for _, cNode := range childNodes {
+                if cNode.segment == segment {
+                    objNode = cNode
+                    break
+                }
+            }
+        }
+
+        if objNode == nil {
+            // 创建一个当前 node 的节点
+            cnode := newNode()
+            cnode.segment = segment
+            if isLast {
+                cnode.isLast = true
+                cnode.handler = handler
+            }
+            n.childes = append(n.childes, cnode)
+            objNode = cnode
+        }
+
+        n = objNode
+    }
+
+    return nil
+}
+
+func (tree *Tree) FindHandler(uri string) ControllerHandler {
+    matchNode := tree.root.matchNode(uri)
+    if matchNode == nil {
+        return nil
+    }
+    return matchNode.handler
+}
+
+// =====================================================================================================================
+
+// 代表节点
+type node struct {
+    isLast  bool              // 代表这个节点是否可以成为最终的路由规则。该节点是否能成为一个独立的uri, 是否自身就是一个终极节点
+    segment string            // uri 中的字符串,代表这个节点表示的路由中某个段的字符串
+    handler ControllerHandler // 代表这个节点中包含的控制器,用于最终加载调用
+    childes []*node           // 代表这个节点下的子节点
+}
+
+// ---------------------------------------------------------------------------------------------------------------------
+
+func newNode() *node {
+    return &node{
+        isLast:  false,
+        segment: "",
+        handler: nil,
+        childes: nil,
+    }
+}
+
+// 判断路由是否在节点的所有子节点树中存在了
+func (n *node) matchNode(uri string) *node {
+    // 使用分隔符将uri 切割为两部分
+    segments := strings.SplitN(uri, "/", 2)
+    // 第一个部分用于匹配下一层子节点
+    segment := segments[0]
+    if !isWildSegment(segment) {
+        segment = strings.ToUpper(segment)
+    }
+
+    // 匹配符合的下一层子节点
+    cNodes := n.filterChildNodes(segment)
+    if cNodes == nil || len(cNodes) == 0 {
+        // 如果当前子节点没有一个符合,那么说明这个 uri 一定是之前不存在,直接返回 nil
+        return nil
+    }
+
+    // 如果只有一个segment, 则是最后一个标记
+    if len(segments) == 1 {
+        // 如果segment 已经是最后一个节点,判断这些 cnode 是否有 isLast 标志
+        for _, tn := range cNodes {
+            if tn.isLast {
+                return tn
+            }
+        }
+        // 都不是最后一个节点
+        return nil
+    }
+
+    // 如果有 2 个 segment, 递归每个子节点继续进行查找
+    for _, tn := range cNodes {
+        tnMatch := tn.matchNode(segments[1])
+        if tnMatch != nil {
+            return tnMatch
+        }
+    }
+
+    return nil
+}
+
+// 过滤下一层满足 segment 规则的子节点
+func (n *node) filterChildNodes(segment string) []*node {
+    if len(n.childes) == 0 {
+        return nil
+    }
+
+    // 如果 segment 是通配符,则所有下一层子节点都满足需求
+    if isWildSegment(segment) {
+        return n.childes
+    }
+
+    nodes := make([]*node, 0, len(n.childes))
+    // 过滤所有的下一层子节点
+    for _, cNode := range n.childes {
+        if isWildSegment(cNode.segment) {
+            // 如果下一层子节点有通配符,则满足需求
+            nodes = append(nodes, cNode)
+        } else if cNode.segment == segment {
+            // 如果下一层子节点没有通配符,但文本完全匹配,则满足需求
+            nodes = append(nodes, cNode)
+        }
+    }
+    return nodes
+}
+
+// 判断一个 segment 是否是通用 segment, 即以 :开头
+func isWildSegment(segment string) bool {
+    return strings.HasPrefix(segment, ":")
+}

+ 84 - 0
framework/trie_test.go

@@ -0,0 +1,84 @@
+package framework
+
+import "testing"
+
+func Test_filterChildNodes(t *testing.T) {
+    root := &node{
+        isLast:  false,
+        segment: "",
+        handler: func(*Context) error { return nil },
+        childes: []*node{
+            {
+                isLast:  true,
+                segment: "FOO",
+                handler: func(*Context) error { return nil },
+                childes: nil,
+            },
+            {
+                isLast:  false,
+                segment: ":id",
+                handler: nil,
+                childes: nil,
+            },
+        },
+    }
+
+    {
+        nodes := root.filterChildNodes("FOO")
+        if len(nodes) != 2 {
+            t.Error("foo error")
+        }
+    }
+
+    {
+        nodes := root.filterChildNodes(":foo")
+        if len(nodes) != 2 {
+            t.Error(":foo error")
+        }
+    }
+
+}
+
+func Test_matchNode(t *testing.T) {
+    root := &node{
+        isLast:  false,
+        segment: "",
+        handler: func(*Context) error { return nil },
+        childes: []*node{
+            {
+                isLast:  true,
+                segment: "FOO",
+                handler: nil,
+                childes: []*node{
+                    {
+                        isLast:  true,
+                        segment: "BAR",
+                        handler: func(*Context) error { panic("not implemented") },
+                        childes: []*node{},
+                    },
+                },
+            },
+            {
+                isLast:  true,
+                segment: ":id",
+                handler: nil,
+                childes: nil,
+            },
+        },
+    }
+
+    {
+        node := root.matchNode("foo/bar")
+        if node == nil {
+            t.Error("match normal node error")
+        }
+    }
+
+    {
+        node := root.matchNode("test")
+        if node == nil {
+            t.Error("match test")
+        }
+    }
+
+}

+ 15 - 2
route.go

@@ -4,6 +4,19 @@ import (
     "coredemo/framework"
 )
 
-func registerRouter(core *framework.Core)  {
-    core.Get("foo", FooControllerHandler)
+/**
+路由:制定匹配规则
+*/
+
+func registerRouter(core *framework.Core) {
+
+    core.Get("/user/login", UserLoginController)
+
+    subjectApi := core.Group("/subject")
+    {
+        subjectApi.Delete("/:id", SubjectController)
+        subjectApi.Put("/:id", SubjectController)
+        subjectApi.Get("/:id", SubjectController)
+        subjectApi.Get("/list/all", SubjectController)
+    }
 }

+ 8 - 0
subject_controller.go

@@ -0,0 +1,8 @@
+package main
+
+import "coredemo/framework"
+
+func SubjectController(c *framework.Context) error {
+    c.Json(200, "ok, SubjectController")
+    return nil
+}

+ 8 - 0
user_controller.go

@@ -0,0 +1,8 @@
+package main
+
+import "coredemo/framework"
+
+func UserLoginController(c *framework.Context) error {
+    c.Json(200, "ok, UserLoginController")
+    return nil
+}