diff --git a/decoder.go b/decoder.go index 883b2f8..9a1304e 100644 --- a/decoder.go +++ b/decoder.go @@ -1,14 +1,12 @@ package render import ( - "encoding/json" "encoding/xml" "errors" + "github.com/ajg/form" "io" "io/ioutil" "net/http" - - "github.com/ajg/form" ) // Decode is a package-level variable set to our default Decoder. We do this @@ -41,7 +39,7 @@ func DefaultDecoder(r *http.Request, v interface{}) error { // DecodeJSON decodes a given reader into an interface using the json decoder. func DecodeJSON(r io.Reader, v interface{}) error { defer io.Copy(ioutil.Discard, r) //nolint:errcheck - return json.NewDecoder(r).Decode(v) + return jsonMarshaller.Decode(r, v) } // DecodeXML decodes a given reader into an interface using the xml decoder. diff --git a/marshaller.go b/marshaller.go new file mode 100644 index 0000000..14d57a4 --- /dev/null +++ b/marshaller.go @@ -0,0 +1,46 @@ +package render + +import ( + "encoding/json" + "io" +) + +var jsonMarshaller Marshaller = jsonDefaultMarshaller{} + +type Marshaller interface { + Marshal(v interface{}) ([]byte, error) + Unmarshall(data []byte, v interface{}) error + NewEncoder(w io.Writer) Encoder + Decode(r io.Reader, v interface{}) error +} + +type Encoder interface { + SetEscapeHTML(on bool) + Encode(v interface{}) error +} + +type jsonDefaultMarshaller struct{} + +func (j jsonDefaultMarshaller) NewEncoder(w io.Writer) Encoder { + return json.NewEncoder(w) +} + +func (j jsonDefaultMarshaller) Decode(r io.Reader, v interface{}) error { + return json.NewDecoder(r).Decode(v) +} + +func (j jsonDefaultMarshaller) Marshal(v interface{}) ([]byte, error) { + return json.Marshal(v) +} + +func (j jsonDefaultMarshaller) Unmarshall(data []byte, v interface{}) error { + return json.Unmarshal(data, v) +} + +func SetJsonMarshaller(m Marshaller) { + if m == nil { + return + } + + jsonMarshaller = m +} diff --git a/responder.go b/responder.go index 66d6bbf..10e4fc0 100644 --- a/responder.go +++ b/responder.go @@ -3,7 +3,6 @@ package render import ( "bytes" "context" - "encoding/json" "encoding/xml" "fmt" "net/http" @@ -92,7 +91,7 @@ func HTML(w http.ResponseWriter, r *http.Request, v string) { // Content-Type as application/json. func JSON(w http.ResponseWriter, r *http.Request, v interface{}) { buf := &bytes.Buffer{} - enc := json.NewEncoder(buf) + enc := jsonMarshaller.NewEncoder(buf) enc.SetEscapeHTML(true) if err := enc.Encode(v); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) @@ -182,7 +181,7 @@ func channelEventStream(w http.ResponseWriter, r *http.Request, v interface{}) { } } - bytes, err := json.Marshal(v) + bytes, err := jsonMarshaller.Marshal(v) if err != nil { w.Write([]byte(fmt.Sprintf("event: error\ndata: {\"error\":\"%v\"}\n\n", err))) //nolint:errcheck if f, ok := w.(http.Flusher); ok {