trie.go 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. package framework
  2. import (
  3. "errors"
  4. "strings"
  5. )
  6. type Tree struct {
  7. root *node // 根节点
  8. }
  9. func NewTree() *Tree {
  10. return &Tree{root: newNode()}
  11. }
  12. // 增加路由节点
  13. /*
  14. /book/list
  15. /book/:id (冲突)
  16. /book/:id/name
  17. /book/:student/age
  18. /:user/name
  19. /:user/name/:age(冲突)
  20. */
  21. func (tree *Tree) AddRouter(uri string, handler ControllerHandler) error {
  22. n := tree.root
  23. // 确认路由是否冲突
  24. if n.matchNode(uri) != nil {
  25. return errors.New("route exists: " + uri)
  26. }
  27. segments := strings.Split(uri, "/")
  28. for index, segment := range segments {
  29. if !isWildSegment(segment) {
  30. segment = strings.ToUpper(segment)
  31. }
  32. isLast := index == len(segments)-1
  33. var objNode *node
  34. childNodes := n.filterChildNodes(segment)
  35. // 如果有匹配的子节点
  36. if len(childNodes) > 0 {
  37. // 如果有segment 相同的子节点,则选择这个子节点
  38. for _, cNode := range childNodes {
  39. if cNode.segment == segment {
  40. objNode = cNode
  41. break
  42. }
  43. }
  44. }
  45. if objNode == nil {
  46. // 创建一个当前 node 的节点
  47. cnode := newNode()
  48. cnode.segment = segment
  49. if isLast {
  50. cnode.isLast = true
  51. cnode.handler = handler
  52. }
  53. n.childes = append(n.childes, cnode)
  54. objNode = cnode
  55. }
  56. n = objNode
  57. }
  58. return nil
  59. }
  60. func (tree *Tree) FindHandler(uri string) ControllerHandler {
  61. matchNode := tree.root.matchNode(uri)
  62. if matchNode == nil {
  63. return nil
  64. }
  65. return matchNode.handler
  66. }
  67. // =====================================================================================================================
  68. // 代表节点
  69. type node struct {
  70. isLast bool // 代表这个节点是否可以成为最终的路由规则。该节点是否能成为一个独立的uri, 是否自身就是一个终极节点
  71. segment string // uri 中的字符串,代表这个节点表示的路由中某个段的字符串
  72. handler ControllerHandler // 代表这个节点中包含的控制器,用于最终加载调用
  73. childes []*node // 代表这个节点下的子节点
  74. }
  75. // ---------------------------------------------------------------------------------------------------------------------
  76. func newNode() *node {
  77. return &node{
  78. isLast: false,
  79. segment: "",
  80. handler: nil,
  81. childes: nil,
  82. }
  83. }
  84. // 判断路由是否在节点的所有子节点树中存在了
  85. func (n *node) matchNode(uri string) *node {
  86. // 使用分隔符将uri 切割为两部分
  87. segments := strings.SplitN(uri, "/", 2)
  88. // 第一个部分用于匹配下一层子节点
  89. segment := segments[0]
  90. if !isWildSegment(segment) {
  91. segment = strings.ToUpper(segment)
  92. }
  93. // 匹配符合的下一层子节点
  94. cNodes := n.filterChildNodes(segment)
  95. if cNodes == nil || len(cNodes) == 0 {
  96. // 如果当前子节点没有一个符合,那么说明这个 uri 一定是之前不存在,直接返回 nil
  97. return nil
  98. }
  99. // 如果只有一个segment, 则是最后一个标记
  100. if len(segments) == 1 {
  101. // 如果segment 已经是最后一个节点,判断这些 cnode 是否有 isLast 标志
  102. for _, tn := range cNodes {
  103. if tn.isLast {
  104. return tn
  105. }
  106. }
  107. // 都不是最后一个节点
  108. return nil
  109. }
  110. // 如果有 2 个 segment, 递归每个子节点继续进行查找
  111. for _, tn := range cNodes {
  112. tnMatch := tn.matchNode(segments[1])
  113. if tnMatch != nil {
  114. return tnMatch
  115. }
  116. }
  117. return nil
  118. }
  119. // 过滤下一层满足 segment 规则的子节点
  120. func (n *node) filterChildNodes(segment string) []*node {
  121. if len(n.childes) == 0 {
  122. return nil
  123. }
  124. // 如果 segment 是通配符,则所有下一层子节点都满足需求
  125. if isWildSegment(segment) {
  126. return n.childes
  127. }
  128. nodes := make([]*node, 0, len(n.childes))
  129. // 过滤所有的下一层子节点
  130. for _, cNode := range n.childes {
  131. if isWildSegment(cNode.segment) {
  132. // 如果下一层子节点有通配符,则满足需求
  133. nodes = append(nodes, cNode)
  134. } else if cNode.segment == segment {
  135. // 如果下一层子节点没有通配符,但文本完全匹配,则满足需求
  136. nodes = append(nodes, cNode)
  137. }
  138. }
  139. return nodes
  140. }
  141. // 判断一个 segment 是否是通用 segment, 即以 :开头
  142. func isWildSegment(segment string) bool {
  143. return strings.HasPrefix(segment, ":")
  144. }