Skip to content

Commit 68ce834

Browse files
committed
Improve stdlib compatibility
1. Null values for primitive types no longer clear the original value in the destination object. 2. Dereference multiple levels of pointers in the destination interface{} type before unmarshaling into it. This is needed to match stdlib behavior when working with nested interface{} fields. If the destination object is a pointer to interface{} then the incoming nil value should nil out the destination object but keep the reference to that nil value on its parent object. However if the destination object is an interface{} value it should set the reference to nil but keep the original object intact. 3. Correctly handle typed nil decode destinations.
1 parent 9277257 commit 68ce834

File tree

3 files changed

+210
-57
lines changed

3 files changed

+210
-57
lines changed

feature_reflect_native.go

Lines changed: 57 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,9 @@ type intCodec struct {
3232
}
3333

3434
func (codec *intCodec) Decode(ptr unsafe.Pointer, iter *Iterator) {
35-
if iter.ReadNil() {
36-
*((*int)(ptr)) = 0
37-
return
35+
if !iter.ReadNil() {
36+
*((*int)(ptr)) = iter.ReadInt()
3837
}
39-
*((*int)(ptr)) = iter.ReadInt()
4038
}
4139

4240
func (codec *intCodec) Encode(ptr unsafe.Pointer, stream *Stream) {
@@ -55,11 +53,9 @@ type uintptrCodec struct {
5553
}
5654

5755
func (codec *uintptrCodec) Decode(ptr unsafe.Pointer, iter *Iterator) {
58-
if iter.ReadNil() {
59-
*((*uintptr)(ptr)) = 0
60-
return
56+
if !iter.ReadNil() {
57+
*((*uintptr)(ptr)) = uintptr(iter.ReadUint64())
6158
}
62-
*((*uintptr)(ptr)) = uintptr(iter.ReadUint64())
6359
}
6460

6561
func (codec *uintptrCodec) Encode(ptr unsafe.Pointer, stream *Stream) {
@@ -78,11 +74,9 @@ type int8Codec struct {
7874
}
7975

8076
func (codec *int8Codec) Decode(ptr unsafe.Pointer, iter *Iterator) {
81-
if iter.ReadNil() {
82-
*((*uint8)(ptr)) = 0
83-
return
77+
if !iter.ReadNil() {
78+
*((*int8)(ptr)) = iter.ReadInt8()
8479
}
85-
*((*int8)(ptr)) = iter.ReadInt8()
8680
}
8781

8882
func (codec *int8Codec) Encode(ptr unsafe.Pointer, stream *Stream) {
@@ -101,11 +95,9 @@ type int16Codec struct {
10195
}
10296

10397
func (codec *int16Codec) Decode(ptr unsafe.Pointer, iter *Iterator) {
104-
if iter.ReadNil() {
105-
*((*int16)(ptr)) = 0
106-
return
98+
if !iter.ReadNil() {
99+
*((*int16)(ptr)) = iter.ReadInt16()
107100
}
108-
*((*int16)(ptr)) = iter.ReadInt16()
109101
}
110102

111103
func (codec *int16Codec) Encode(ptr unsafe.Pointer, stream *Stream) {
@@ -124,11 +116,9 @@ type int32Codec struct {
124116
}
125117

126118
func (codec *int32Codec) Decode(ptr unsafe.Pointer, iter *Iterator) {
127-
if iter.ReadNil() {
128-
*((*int32)(ptr)) = 0
129-
return
119+
if !iter.ReadNil() {
120+
*((*int32)(ptr)) = iter.ReadInt32()
130121
}
131-
*((*int32)(ptr)) = iter.ReadInt32()
132122
}
133123

134124
func (codec *int32Codec) Encode(ptr unsafe.Pointer, stream *Stream) {
@@ -147,11 +137,9 @@ type int64Codec struct {
147137
}
148138

149139
func (codec *int64Codec) Decode(ptr unsafe.Pointer, iter *Iterator) {
150-
if iter.ReadNil() {
151-
*((*int64)(ptr)) = 0
152-
return
140+
if !iter.ReadNil() {
141+
*((*int64)(ptr)) = iter.ReadInt64()
153142
}
154-
*((*int64)(ptr)) = iter.ReadInt64()
155143
}
156144

157145
func (codec *int64Codec) Encode(ptr unsafe.Pointer, stream *Stream) {
@@ -170,11 +158,10 @@ type uintCodec struct {
170158
}
171159

172160
func (codec *uintCodec) Decode(ptr unsafe.Pointer, iter *Iterator) {
173-
if iter.ReadNil() {
174-
*((*uint)(ptr)) = 0
161+
if !iter.ReadNil() {
162+
*((*uint)(ptr)) = iter.ReadUint()
175163
return
176164
}
177-
*((*uint)(ptr)) = iter.ReadUint()
178165
}
179166

180167
func (codec *uintCodec) Encode(ptr unsafe.Pointer, stream *Stream) {
@@ -193,11 +180,9 @@ type uint8Codec struct {
193180
}
194181

195182
func (codec *uint8Codec) Decode(ptr unsafe.Pointer, iter *Iterator) {
196-
if iter.ReadNil() {
197-
*((*uint8)(ptr)) = 0
198-
return
183+
if !iter.ReadNil() {
184+
*((*uint8)(ptr)) = iter.ReadUint8()
199185
}
200-
*((*uint8)(ptr)) = iter.ReadUint8()
201186
}
202187

203188
func (codec *uint8Codec) Encode(ptr unsafe.Pointer, stream *Stream) {
@@ -216,11 +201,9 @@ type uint16Codec struct {
216201
}
217202

218203
func (codec *uint16Codec) Decode(ptr unsafe.Pointer, iter *Iterator) {
219-
if iter.ReadNil() {
220-
*((*uint16)(ptr)) = 0
221-
return
204+
if !iter.ReadNil() {
205+
*((*uint16)(ptr)) = iter.ReadUint16()
222206
}
223-
*((*uint16)(ptr)) = iter.ReadUint16()
224207
}
225208

226209
func (codec *uint16Codec) Encode(ptr unsafe.Pointer, stream *Stream) {
@@ -239,11 +222,9 @@ type uint32Codec struct {
239222
}
240223

241224
func (codec *uint32Codec) Decode(ptr unsafe.Pointer, iter *Iterator) {
242-
if iter.ReadNil() {
243-
*((*uint32)(ptr)) = 0
244-
return
225+
if !iter.ReadNil() {
226+
*((*uint32)(ptr)) = iter.ReadUint32()
245227
}
246-
*((*uint32)(ptr)) = iter.ReadUint32()
247228
}
248229

249230
func (codec *uint32Codec) Encode(ptr unsafe.Pointer, stream *Stream) {
@@ -262,11 +243,9 @@ type uint64Codec struct {
262243
}
263244

264245
func (codec *uint64Codec) Decode(ptr unsafe.Pointer, iter *Iterator) {
265-
if iter.ReadNil() {
266-
*((*uint64)(ptr)) = 0
267-
return
246+
if !iter.ReadNil() {
247+
*((*uint64)(ptr)) = iter.ReadUint64()
268248
}
269-
*((*uint64)(ptr)) = iter.ReadUint64()
270249
}
271250

272251
func (codec *uint64Codec) Encode(ptr unsafe.Pointer, stream *Stream) {
@@ -285,11 +264,9 @@ type float32Codec struct {
285264
}
286265

287266
func (codec *float32Codec) Decode(ptr unsafe.Pointer, iter *Iterator) {
288-
if iter.ReadNil() {
289-
*((*float32)(ptr)) = 0
290-
return
267+
if !iter.ReadNil() {
268+
*((*float32)(ptr)) = iter.ReadFloat32()
291269
}
292-
*((*float32)(ptr)) = iter.ReadFloat32()
293270
}
294271

295272
func (codec *float32Codec) Encode(ptr unsafe.Pointer, stream *Stream) {
@@ -308,11 +285,9 @@ type float64Codec struct {
308285
}
309286

310287
func (codec *float64Codec) Decode(ptr unsafe.Pointer, iter *Iterator) {
311-
if iter.ReadNil() {
312-
*((*float64)(ptr)) = 0
313-
return
288+
if !iter.ReadNil() {
289+
*((*float64)(ptr)) = iter.ReadFloat64()
314290
}
315-
*((*float64)(ptr)) = iter.ReadFloat64()
316291
}
317292

318293
func (codec *float64Codec) Encode(ptr unsafe.Pointer, stream *Stream) {
@@ -352,13 +327,39 @@ type emptyInterfaceCodec struct {
352327
}
353328

354329
func (codec *emptyInterfaceCodec) Decode(ptr unsafe.Pointer, iter *Iterator) {
355-
if iter.ReadNil() {
356-
*((*interface{})(ptr)) = nil
330+
existing := *((*interface{})(ptr))
331+
332+
// Checking for both typed and untyped nil pointers.
333+
if existing != nil &&
334+
reflect.TypeOf(existing).Kind() == reflect.Ptr &&
335+
!reflect.ValueOf(existing).IsNil() {
336+
337+
var ptrToExisting interface{}
338+
for {
339+
elem := reflect.ValueOf(existing).Elem()
340+
if elem.Kind() != reflect.Ptr || elem.IsNil() {
341+
break
342+
}
343+
ptrToExisting = existing
344+
existing = elem.Interface()
345+
}
346+
347+
if iter.ReadNil() {
348+
if ptrToExisting != nil {
349+
nilPtr := reflect.Zero(reflect.TypeOf(ptrToExisting).Elem())
350+
reflect.ValueOf(ptrToExisting).Elem().Set(nilPtr)
351+
} else {
352+
*((*interface{})(ptr)) = nil
353+
}
354+
} else {
355+
iter.ReadVal(existing)
356+
}
357+
357358
return
358359
}
359-
existing := *((*interface{})(ptr))
360-
if existing != nil && reflect.TypeOf(existing).Kind() == reflect.Ptr {
361-
iter.ReadVal(existing)
360+
361+
if iter.ReadNil() {
362+
*((*interface{})(ptr)) = nil
362363
} else {
363364
*((*interface{})(ptr)) = iter.Read()
364365
}

jsoniter_interface_test.go

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,3 +370,124 @@ func Test_omitempty_nil_interface(t *testing.T) {
370370
should.Equal(nil, err)
371371
should.Equal(string(js), str)
372372
}
373+
374+
func Test_overwrite_interface_ptr_value_with_nil(t *testing.T) {
375+
type Wrapper struct {
376+
Payload interface{} `json:"payload,omitempty"`
377+
}
378+
type Payload struct {
379+
Value int `json:"val,omitempty"`
380+
}
381+
382+
should := require.New(t)
383+
384+
payload := &Payload{}
385+
wrapper := &Wrapper{
386+
Payload: &payload,
387+
}
388+
389+
err := json.Unmarshal([]byte(`{"payload": {"val": 42}}`), &wrapper)
390+
should.Equal(nil, err)
391+
should.Equal(&payload, wrapper.Payload)
392+
should.Equal(42, (*(wrapper.Payload.(**Payload))).Value)
393+
394+
err = json.Unmarshal([]byte(`{"payload": null}`), &wrapper)
395+
should.Equal(nil, err)
396+
should.Equal(&payload, wrapper.Payload)
397+
should.Equal((*Payload)(nil), payload)
398+
399+
payload = &Payload{}
400+
wrapper = &Wrapper{
401+
Payload: &payload,
402+
}
403+
404+
err = Unmarshal([]byte(`{"payload": {"val": 42}}`), &wrapper)
405+
should.Equal(nil, err)
406+
should.Equal(&payload, wrapper.Payload)
407+
should.Equal(42, (*(wrapper.Payload.(**Payload))).Value)
408+
409+
err = Unmarshal([]byte(`{"payload": null}`), &wrapper)
410+
should.Equal(nil, err)
411+
should.Equal(&payload, wrapper.Payload)
412+
should.Equal((*Payload)(nil), payload)
413+
}
414+
415+
func Test_overwrite_interface_value_with_nil(t *testing.T) {
416+
type Wrapper struct {
417+
Payload interface{} `json:"payload,omitempty"`
418+
}
419+
type Payload struct {
420+
Value int `json:"val,omitempty"`
421+
}
422+
423+
should := require.New(t)
424+
425+
payload := &Payload{}
426+
wrapper := &Wrapper{
427+
Payload: payload,
428+
}
429+
430+
err := json.Unmarshal([]byte(`{"payload": {"val": 42}}`), &wrapper)
431+
should.Equal(nil, err)
432+
should.Equal(42, (*(wrapper.Payload.(*Payload))).Value)
433+
434+
err = json.Unmarshal([]byte(`{"payload": null}`), &wrapper)
435+
should.Equal(nil, err)
436+
should.Equal(nil, wrapper.Payload)
437+
should.Equal(42, payload.Value)
438+
439+
payload = &Payload{}
440+
wrapper = &Wrapper{
441+
Payload: payload,
442+
}
443+
444+
err = Unmarshal([]byte(`{"payload": {"val": 42}}`), &wrapper)
445+
should.Equal(nil, err)
446+
should.Equal(42, (*(wrapper.Payload.(*Payload))).Value)
447+
448+
err = Unmarshal([]byte(`{"payload": null}`), &wrapper)
449+
should.Equal(nil, err)
450+
should.Equal(nil, wrapper.Payload)
451+
should.Equal(42, payload.Value)
452+
}
453+
454+
func Test_unmarshal_into_nil(t *testing.T) {
455+
type Payload struct {
456+
Value int `json:"val,omitempty"`
457+
}
458+
type Wrapper struct {
459+
Payload interface{} `json:"payload,omitempty"`
460+
}
461+
462+
should := require.New(t)
463+
464+
var payload *Payload
465+
wrapper := &Wrapper{
466+
Payload: payload,
467+
}
468+
469+
err := json.Unmarshal([]byte(`{"payload": {"val": 42}}`), &wrapper)
470+
should.Nil(err)
471+
should.NotNil(wrapper.Payload)
472+
should.Nil(payload)
473+
474+
err = json.Unmarshal([]byte(`{"payload": null}`), &wrapper)
475+
should.Nil(err)
476+
should.Nil(wrapper.Payload)
477+
should.Nil(payload)
478+
479+
payload = nil
480+
wrapper = &Wrapper{
481+
Payload: payload,
482+
}
483+
484+
err = Unmarshal([]byte(`{"payload": {"val": 42}}`), &wrapper)
485+
should.Nil(err)
486+
should.NotNil(wrapper.Payload)
487+
should.Nil(payload)
488+
489+
err = Unmarshal([]byte(`{"payload": null}`), &wrapper)
490+
should.Nil(err)
491+
should.Nil(wrapper.Payload)
492+
should.Nil(payload)
493+
}

jsoniter_null_test.go

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@ package jsoniter
33
import (
44
"bytes"
55
"encoding/json"
6-
"github.com/stretchr/testify/require"
76
"io"
87
"testing"
8+
9+
"github.com/stretchr/testify/require"
910
)
1011

1112
func Test_read_null(t *testing.T) {
@@ -135,3 +136,33 @@ func Test_encode_nil_array(t *testing.T) {
135136
should.Nil(err)
136137
should.Equal("null", string(output))
137138
}
139+
140+
func Test_decode_nil_num(t *testing.T) {
141+
type TestData struct {
142+
Field int `json:"field"`
143+
}
144+
should := require.New(t)
145+
146+
data1 := []byte(`{"field": 42}`)
147+
data2 := []byte(`{"field": null}`)
148+
149+
// Checking stdlib behavior as well
150+
obj2 := TestData{}
151+
err := json.Unmarshal(data1, &obj2)
152+
should.Equal(nil, err)
153+
should.Equal(42, obj2.Field)
154+
155+
err = json.Unmarshal(data2, &obj2)
156+
should.Equal(nil, err)
157+
should.Equal(42, obj2.Field)
158+
159+
obj := TestData{}
160+
161+
err = Unmarshal(data1, &obj)
162+
should.Equal(nil, err)
163+
should.Equal(42, obj.Field)
164+
165+
err = Unmarshal(data2, &obj)
166+
should.Equal(nil, err)
167+
should.Equal(42, obj.Field)
168+
}

0 commit comments

Comments
 (0)