context.go 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208
  1. package framework
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/json"
  6. "errors"
  7. "io/ioutil"
  8. "net/http"
  9. "strconv"
  10. "sync"
  11. )
  12. // 自定义 Context
  13. type Context struct {
  14. request *http.Request
  15. responseWriter http.ResponseWriter
  16. // 是否超时标记位
  17. hasTimeout bool
  18. // 写保护机制
  19. writeMux *sync.Mutex
  20. }
  21. func NewContext(r *http.Request, w http.ResponseWriter) *Context {
  22. return &Context{
  23. request: r,
  24. responseWriter: w,
  25. writeMux: &sync.Mutex{},
  26. }
  27. }
// #region base functions: accessors for the wrapped request, response writer and timeout flag
  29. func (ctx *Context) WriterMux() *sync.Mutex {
  30. return ctx.writeMux
  31. }
  32. func (ctx *Context) GetRequest() *http.Request {
  33. return ctx.request
  34. }
  35. func (ctx *Context) GetResponse() http.ResponseWriter {
  36. return ctx.responseWriter
  37. }
  38. func (ctx *Context) SetHasTimeout() {
  39. ctx.hasTimeout = true
  40. }
  41. func (ctx *Context) HasTimeout() bool {
  42. return ctx.hasTimeout
  43. }
  44. // #endregion
  45. func (ctx *Context) BaseContext() context.Context {
  46. return ctx.request.Context()
  47. }
  48. // #region implement context.Context
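// The methods below forward to BaseContext() so a *Context can be used where a
// context.Context is expected. NOTE(review): no Deadline method appears in this
// file, so *Context only satisfies context.Context if Deadline is defined
// elsewhere in the package — confirm.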
  50. func (ctx *Context) Done() <-chan struct{} {
  51. return ctx.BaseContext().Done()
  52. }
  53. func (ctx *Context) Err() error {
  54. return ctx.BaseContext().Err()
  55. }
  56. func (ctx *Context) Value(key interface{}) interface{} {
  57. return ctx.BaseContext().Value(key)
  58. }
  59. // #endregion
  60. // #region query url
// Request helpers: typed accessors over the URL query parameters of http.Request.
  62. func (ctx *Context) QueryInt(key string, def int) int {
  63. params := ctx.QueryAll()
  64. if vals, ok := params[key]; ok {
  65. len := len(vals)
  66. if len > 0 {
  67. intVal, err := strconv.Atoi(vals[len-1])
  68. if err != nil {
  69. return def
  70. }
  71. return intVal
  72. }
  73. }
  74. return def
  75. }
  76. func (ctx *Context) QueryString(key string, def string) string {
  77. params := ctx.QueryAll()
  78. if vals, ok := params[key]; ok {
  79. len := len(vals)
  80. if len > 0 {
  81. return vals[len-1]
  82. }
  83. }
  84. return def
  85. }
  86. func (ctx *Context) QueryArray(key string, def []string) []string {
  87. params := ctx.QueryAll()
  88. if vals, ok := params[key]; ok {
  89. return vals
  90. }
  91. return def
  92. }
  93. func (ctx *Context) QueryAll() map[string][]string {
  94. if ctx.request != nil {
  95. return map[string][]string( ctx.request.URL.Query())
  96. }
  97. return map[string][]string{}
  98. }
  99. // #endregion
  100. // #region form post
  101. func (ctx *Context) FormInt(key string, def int) int {
  102. params := ctx.FormAll()
  103. if vals, ok := params[key]; ok {
  104. len := len(vals)
  105. if len > 0 {
  106. intVal, err := strconv.Atoi(vals[len-1])
  107. if err != nil {
  108. return def
  109. }
  110. return intVal
  111. }
  112. }
  113. return def
  114. }
  115. func (ctx *Context) FormString(key string, def string) string {
  116. params := ctx.FormAll()
  117. if vals, ok := params[key]; ok {
  118. len := len(vals)
  119. if len > 0 {
  120. return vals[len-1]
  121. }
  122. }
  123. return def
  124. }
  125. func (ctx *Context) FormArray(key string, def []string) []string {
  126. params := ctx.FormAll()
  127. if vals, ok := params[key]; ok {
  128. return vals
  129. }
  130. return def
  131. }
  132. func (ctx *Context) FormAll() map[string][]string {
  133. if ctx.request != nil {
  134. return map[string][]string(ctx.request.PostForm)
  135. }
  136. return map[string][]string{}
  137. }
  138. // #endregion
  139. // #region application/json post
// JSON request binding (BindJson) and response helpers wrapping http.ResponseWriter (Json, HTML, Text).
  141. func (ctx *Context) BindJson(obj interface{}) error {
  142. if ctx.request != nil {
  143. body, err := ioutil.ReadAll(ctx.request.Body)
  144. if err != nil {
  145. return err
  146. }
  147. ctx.request.Body = ioutil.NopCloser(bytes.NewBuffer(body))
  148. err = json.Unmarshal(body, obj)
  149. if err != nil {
  150. return err
  151. }
  152. } else {
  153. return errors.New("ctx.request empty")
  154. }
  155. return nil
  156. }
  157. func (ctx *Context) Json(status int, obj interface{}) error {
  158. if ctx.HasTimeout() {
  159. return nil
  160. }
  161. ctx.responseWriter.Header().Set("Content-Type", "application/json")
  162. ctx.responseWriter.WriteHeader(status)
  163. byt, err := json.Marshal(obj)
  164. if err != nil {
  165. ctx.responseWriter.WriteHeader(500)
  166. return err
  167. }
  168. ctx.responseWriter.Write(byt)
  169. return nil
  170. }
  171. func (ctx *Context) HTML(status int, obj interface{}, template string) error {
  172. return nil
  173. }
  174. func (ctx *Context) Text(status int , obj string ) error {
  175. return nil
  176. }
  177. // #endregion