package framework import ( "bytes" "context" "encoding/json" "errors" "io/ioutil" "net/http" "strconv" "sync" ) // 自定义 Context type Context struct { request *http.Request responseWriter http.ResponseWriter // 是否超时标记位 hasTimeout bool // 写保护机制 writeMux *sync.Mutex } func NewContext(r *http.Request, w http.ResponseWriter) *Context { return &Context{ request: r, responseWriter: w, writeMux: &sync.Mutex{}, } } // #region base function base 封装基本的函数功能 func (ctx *Context) WriterMux() *sync.Mutex { return ctx.writeMux } func (ctx *Context) GetRequest() *http.Request { return ctx.request } func (ctx *Context) GetResponse() http.ResponseWriter { return ctx.responseWriter } func (ctx *Context) SetHasTimeout() { ctx.hasTimeout = true } func (ctx *Context) HasTimeout() bool { return ctx.hasTimeout } // #endregion func (ctx *Context) BaseContext() context.Context { return ctx.request.Context() } // #region implement context.Context // context 实现标准 Context 接口 func (ctx *Context) Done() <-chan struct{} { return ctx.BaseContext().Done() } func (ctx *Context) Err() error { return ctx.BaseContext().Err() } func (ctx *Context) Value(key interface{}) interface{} { return ctx.BaseContext().Value(key) } // #endregion // #region query url // request 封装了 http.Request 的对外接口 func (ctx *Context) QueryInt(key string, def int) int { params := ctx.QueryAll() if vals, ok := params[key]; ok { len := len(vals) if len > 0 { intVal, err := strconv.Atoi(vals[len-1]) if err != nil { return def } return intVal } } return def } func (ctx *Context) QueryString(key string, def string) string { params := ctx.QueryAll() if vals, ok := params[key]; ok { len := len(vals) if len > 0 { return vals[len-1] } } return def } func (ctx *Context) QueryArray(key string, def []string) []string { params := ctx.QueryAll() if vals, ok := params[key]; ok { return vals } return def } func (ctx *Context) QueryAll() map[string][]string { if ctx.request != nil { return map[string][]string( ctx.request.URL.Query()) } return map[string][]string{} } // #endregion // #region form post func (ctx *Context) FormInt(key string, def int) int { params := ctx.FormAll() if vals, ok := params[key]; ok { len := len(vals) if len > 0 { intVal, err := strconv.Atoi(vals[len-1]) if err != nil { return def } return intVal } } return def } func (ctx *Context) FormString(key string, def string) string { params := ctx.FormAll() if vals, ok := params[key]; ok { len := len(vals) if len > 0 { return vals[len-1] } } return def } func (ctx *Context) FormArray(key string, def []string) []string { params := ctx.FormAll() if vals, ok := params[key]; ok { return vals } return def } func (ctx *Context) FormAll() map[string][]string { if ctx.request != nil { return map[string][]string(ctx.request.PostForm) } return map[string][]string{} } // #endregion // #region application/json post //response 封装了 http.ResponseWriter 对外接口 func (ctx *Context) BindJson(obj interface{}) error { if ctx.request != nil { body, err := ioutil.ReadAll(ctx.request.Body) if err != nil { return err } ctx.request.Body = ioutil.NopCloser(bytes.NewBuffer(body)) err = json.Unmarshal(body, obj) if err != nil { return err } } else { return errors.New("ctx.request empty") } return nil } func (ctx *Context) Json(status int, obj interface{}) error { if ctx.HasTimeout() { return nil } ctx.responseWriter.Header().Set("Content-Type", "application/json") ctx.responseWriter.WriteHeader(status) byt, err := json.Marshal(obj) if err != nil { ctx.responseWriter.WriteHeader(500) return err } ctx.responseWriter.Write(byt) return nil } func (ctx *Context) HTML(status int, obj interface{}, template string) error { return nil } func (ctx *Context) Text(status int , obj string ) error { return nil } // #endregion