Skip to content

Commit f6ef84b

Browse files
authored
Merge pull request json-iterator#172 from olegshaldybin/more-stdlib-compat
Improve stdlib compatibility
2 parents 15d4ad9 + 1c6f5fc commit f6ef84b

File tree

3 files changed

+210
-57
lines changed

3 files changed

+210
-57
lines changed

feature_reflect_native.go

Lines changed: 57 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,9 @@ type intCodec struct {
3232
}
3333

3434
func (codec *intCodec) Decode(ptr unsafe.Pointer, iter *Iterator) {
35-
if iter.ReadNil() {
36-
*((*int)(ptr)) = 0
37-
return
35+
if !iter.ReadNil() {
36+
*((*int)(ptr)) = iter.ReadInt()
3837
}
39-
*((*int)(ptr)) = iter.ReadInt()
4038
}
4139

4240
func (codec *intCodec) Encode(ptr unsafe.Pointer, stream *Stream) {
@@ -55,11 +53,9 @@ type uintptrCodec struct {
5553
}
5654

5755
func (codec *uintptrCodec) Decode(ptr unsafe.Pointer, iter *Iterator) {
58-
if iter.ReadNil() {
59-
*((*uintptr)(ptr)) = 0
60-
return
56+
if !iter.ReadNil() {
57+
*((*uintptr)(ptr)) = uintptr(iter.ReadUint64())
6158
}
62-
*((*uintptr)(ptr)) = uintptr(iter.ReadUint64())
6359
}
6460

6561
func (codec *uintptrCodec) Encode(ptr unsafe.Pointer, stream *Stream) {
@@ -78,11 +74,9 @@ type int8Codec struct {
7874
}
7975

8076
func (codec *int8Codec) Decode(ptr unsafe.Pointer, iter *Iterator) {
81-
if iter.ReadNil() {
82-
*((*uint8)(ptr)) = 0
83-
return
77+
if !iter.ReadNil() {
78+
*((*int8)(ptr)) = iter.ReadInt8()
8479
}
85-
*((*int8)(ptr)) = iter.ReadInt8()
8680
}
8781

8882
func (codec *int8Codec) Encode(ptr unsafe.Pointer, stream *Stream) {
@@ -101,11 +95,9 @@ type int16Codec struct {
10195
}
10296

10397
func (codec *int16Codec) Decode(ptr unsafe.Pointer, iter *Iterator) {
104-
if iter.ReadNil() {
105-
*((*int16)(ptr)) = 0
106-
return
98+
if !iter.ReadNil() {
99+
*((*int16)(ptr)) = iter.ReadInt16()
107100
}
108-
*((*int16)(ptr)) = iter.ReadInt16()
109101
}
110102

111103
func (codec *int16Codec) Encode(ptr unsafe.Pointer, stream *Stream) {
@@ -124,11 +116,9 @@ type int32Codec struct {
124116
}
125117

126118
func (codec *int32Codec) Decode(ptr unsafe.Pointer, iter *Iterator) {
127-
if iter.ReadNil() {
128-
*((*int32)(ptr)) = 0
129-
return
119+
if !iter.ReadNil() {
120+
*((*int32)(ptr)) = iter.ReadInt32()
130121
}
131-
*((*int32)(ptr)) = iter.ReadInt32()
132122
}
133123

134124
func (codec *int32Codec) Encode(ptr unsafe.Pointer, stream *Stream) {
@@ -147,11 +137,9 @@ type int64Codec struct {
147137
}
148138

149139
func (codec *int64Codec) Decode(ptr unsafe.Pointer, iter *Iterator) {
150-
if iter.ReadNil() {
151-
*((*int64)(ptr)) = 0
152-
return
140+
if !iter.ReadNil() {
141+
*((*int64)(ptr)) = iter.ReadInt64()
153142
}
154-
*((*int64)(ptr)) = iter.ReadInt64()
155143
}
156144

157145
func (codec *int64Codec) Encode(ptr unsafe.Pointer, stream *Stream) {
@@ -170,11 +158,10 @@ type uintCodec struct {
170158
}
171159

172160
func (codec *uintCodec) Decode(ptr unsafe.Pointer, iter *Iterator) {
173-
if iter.ReadNil() {
174-
*((*uint)(ptr)) = 0
161+
if !iter.ReadNil() {
162+
*((*uint)(ptr)) = iter.ReadUint()
175163
return
176164
}
177-
*((*uint)(ptr)) = iter.ReadUint()
178165
}
179166

180167
func (codec *uintCodec) Encode(ptr unsafe.Pointer, stream *Stream) {
@@ -193,11 +180,9 @@ type uint8Codec struct {
193180
}
194181

195182
func (codec *uint8Codec) Decode(ptr unsafe.Pointer, iter *Iterator) {
196-
if iter.ReadNil() {
197-
*((*uint8)(ptr)) = 0
198-
return
183+
if !iter.ReadNil() {
184+
*((*uint8)(ptr)) = iter.ReadUint8()
199185
}
200-
*((*uint8)(ptr)) = iter.ReadUint8()
201186
}
202187

203188
func (codec *uint8Codec) Encode(ptr unsafe.Pointer, stream *Stream) {
@@ -216,11 +201,9 @@ type uint16Codec struct {
216201
}
217202

218203
func (codec *uint16Codec) Decode(ptr unsafe.Pointer, iter *Iterator) {
219-
if iter.ReadNil() {
220-
*((*uint16)(ptr)) = 0
221-
return
204+
if !iter.ReadNil() {
205+
*((*uint16)(ptr)) = iter.ReadUint16()
222206
}
223-
*((*uint16)(ptr)) = iter.ReadUint16()
224207
}
225208

226209
func (codec *uint16Codec) Encode(ptr unsafe.Pointer, stream *Stream) {
@@ -239,11 +222,9 @@ type uint32Codec struct {
239222
}
240223

241224
func (codec *uint32Codec) Decode(ptr unsafe.Pointer, iter *Iterator) {
242-
if iter.ReadNil() {
243-
*((*uint32)(ptr)) = 0
244-
return
225+
if !iter.ReadNil() {
226+
*((*uint32)(ptr)) = iter.ReadUint32()
245227
}
246-
*((*uint32)(ptr)) = iter.ReadUint32()
247228
}
248229

249230
func (codec *uint32Codec) Encode(ptr unsafe.Pointer, stream *Stream) {
@@ -262,11 +243,9 @@ type uint64Codec struct {
262243
}
263244

264245
func (codec *uint64Codec) Decode(ptr unsafe.Pointer, iter *Iterator) {
265-
if iter.ReadNil() {
266-
*((*uint64)(ptr)) = 0
267-
return
246+
if !iter.ReadNil() {
247+
*((*uint64)(ptr)) = iter.ReadUint64()
268248
}
269-
*((*uint64)(ptr)) = iter.ReadUint64()
270249
}
271250

272251
func (codec *uint64Codec) Encode(ptr unsafe.Pointer, stream *Stream) {
@@ -285,11 +264,9 @@ type float32Codec struct {
285264
}
286265

287266
func (codec *float32Codec) Decode(ptr unsafe.Pointer, iter *Iterator) {
288-
if iter.ReadNil() {
289-
*((*float32)(ptr)) = 0
290-
return
267+
if !iter.ReadNil() {
268+
*((*float32)(ptr)) = iter.ReadFloat32()
291269
}
292-
*((*float32)(ptr)) = iter.ReadFloat32()
293270
}
294271

295272
func (codec *float32Codec) Encode(ptr unsafe.Pointer, stream *Stream) {
@@ -308,11 +285,9 @@ type float64Codec struct {
308285
}
309286

310287
func (codec *float64Codec) Decode(ptr unsafe.Pointer, iter *Iterator) {
311-
if iter.ReadNil() {
312-
*((*float64)(ptr)) = 0
313-
return
288+
if !iter.ReadNil() {
289+
*((*float64)(ptr)) = iter.ReadFloat64()
314290
}
315-
*((*float64)(ptr)) = iter.ReadFloat64()
316291
}
317292

318293
func (codec *float64Codec) Encode(ptr unsafe.Pointer, stream *Stream) {
@@ -352,13 +327,39 @@ type emptyInterfaceCodec struct {
352327
}
353328

354329
func (codec *emptyInterfaceCodec) Decode(ptr unsafe.Pointer, iter *Iterator) {
355-
if iter.ReadNil() {
356-
*((*interface{})(ptr)) = nil
330+
existing := *((*interface{})(ptr))
331+
332+
// Checking for both typed and untyped nil pointers.
333+
if existing != nil &&
334+
reflect.TypeOf(existing).Kind() == reflect.Ptr &&
335+
!reflect.ValueOf(existing).IsNil() {
336+
337+
var ptrToExisting interface{}
338+
for {
339+
elem := reflect.ValueOf(existing).Elem()
340+
if elem.Kind() != reflect.Ptr || elem.IsNil() {
341+
break
342+
}
343+
ptrToExisting = existing
344+
existing = elem.Interface()
345+
}
346+
347+
if iter.ReadNil() {
348+
if ptrToExisting != nil {
349+
nilPtr := reflect.Zero(reflect.TypeOf(ptrToExisting).Elem())
350+
reflect.ValueOf(ptrToExisting).Elem().Set(nilPtr)
351+
} else {
352+
*((*interface{})(ptr)) = nil
353+
}
354+
} else {
355+
iter.ReadVal(existing)
356+
}
357+
357358
return
358359
}
359-
existing := *((*interface{})(ptr))
360-
if existing != nil && reflect.TypeOf(existing).Kind() == reflect.Ptr {
361-
iter.ReadVal(existing)
360+
361+
if iter.ReadNil() {
362+
*((*interface{})(ptr)) = nil
362363
} else {
363364
*((*interface{})(ptr)) = iter.Read()
364365
}

jsoniter_interface_test.go

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,3 +437,124 @@ func Test_marshal_nil_nonempty_interface(t *testing.T) {
437437
should.NoError(err)
438438
should.Equal(nil, obj.Field)
439439
}
440+
441+
func Test_overwrite_interface_ptr_value_with_nil(t *testing.T) {
442+
type Wrapper struct {
443+
Payload interface{} `json:"payload,omitempty"`
444+
}
445+
type Payload struct {
446+
Value int `json:"val,omitempty"`
447+
}
448+
449+
should := require.New(t)
450+
451+
payload := &Payload{}
452+
wrapper := &Wrapper{
453+
Payload: &payload,
454+
}
455+
456+
err := json.Unmarshal([]byte(`{"payload": {"val": 42}}`), &wrapper)
457+
should.Equal(nil, err)
458+
should.Equal(&payload, wrapper.Payload)
459+
should.Equal(42, (*(wrapper.Payload.(**Payload))).Value)
460+
461+
err = json.Unmarshal([]byte(`{"payload": null}`), &wrapper)
462+
should.Equal(nil, err)
463+
should.Equal(&payload, wrapper.Payload)
464+
should.Equal((*Payload)(nil), payload)
465+
466+
payload = &Payload{}
467+
wrapper = &Wrapper{
468+
Payload: &payload,
469+
}
470+
471+
err = Unmarshal([]byte(`{"payload": {"val": 42}}`), &wrapper)
472+
should.Equal(nil, err)
473+
should.Equal(&payload, wrapper.Payload)
474+
should.Equal(42, (*(wrapper.Payload.(**Payload))).Value)
475+
476+
err = Unmarshal([]byte(`{"payload": null}`), &wrapper)
477+
should.Equal(nil, err)
478+
should.Equal(&payload, wrapper.Payload)
479+
should.Equal((*Payload)(nil), payload)
480+
}
481+
482+
func Test_overwrite_interface_value_with_nil(t *testing.T) {
483+
type Wrapper struct {
484+
Payload interface{} `json:"payload,omitempty"`
485+
}
486+
type Payload struct {
487+
Value int `json:"val,omitempty"`
488+
}
489+
490+
should := require.New(t)
491+
492+
payload := &Payload{}
493+
wrapper := &Wrapper{
494+
Payload: payload,
495+
}
496+
497+
err := json.Unmarshal([]byte(`{"payload": {"val": 42}}`), &wrapper)
498+
should.Equal(nil, err)
499+
should.Equal(42, (*(wrapper.Payload.(*Payload))).Value)
500+
501+
err = json.Unmarshal([]byte(`{"payload": null}`), &wrapper)
502+
should.Equal(nil, err)
503+
should.Equal(nil, wrapper.Payload)
504+
should.Equal(42, payload.Value)
505+
506+
payload = &Payload{}
507+
wrapper = &Wrapper{
508+
Payload: payload,
509+
}
510+
511+
err = Unmarshal([]byte(`{"payload": {"val": 42}}`), &wrapper)
512+
should.Equal(nil, err)
513+
should.Equal(42, (*(wrapper.Payload.(*Payload))).Value)
514+
515+
err = Unmarshal([]byte(`{"payload": null}`), &wrapper)
516+
should.Equal(nil, err)
517+
should.Equal(nil, wrapper.Payload)
518+
should.Equal(42, payload.Value)
519+
}
520+
521+
func Test_unmarshal_into_nil(t *testing.T) {
522+
type Payload struct {
523+
Value int `json:"val,omitempty"`
524+
}
525+
type Wrapper struct {
526+
Payload interface{} `json:"payload,omitempty"`
527+
}
528+
529+
should := require.New(t)
530+
531+
var payload *Payload
532+
wrapper := &Wrapper{
533+
Payload: payload,
534+
}
535+
536+
err := json.Unmarshal([]byte(`{"payload": {"val": 42}}`), &wrapper)
537+
should.Nil(err)
538+
should.NotNil(wrapper.Payload)
539+
should.Nil(payload)
540+
541+
err = json.Unmarshal([]byte(`{"payload": null}`), &wrapper)
542+
should.Nil(err)
543+
should.Nil(wrapper.Payload)
544+
should.Nil(payload)
545+
546+
payload = nil
547+
wrapper = &Wrapper{
548+
Payload: payload,
549+
}
550+
551+
err = Unmarshal([]byte(`{"payload": {"val": 42}}`), &wrapper)
552+
should.Nil(err)
553+
should.NotNil(wrapper.Payload)
554+
should.Nil(payload)
555+
556+
err = Unmarshal([]byte(`{"payload": null}`), &wrapper)
557+
should.Nil(err)
558+
should.Nil(wrapper.Payload)
559+
should.Nil(payload)
560+
}

jsoniter_null_test.go

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@ package jsoniter
33
import (
44
"bytes"
55
"encoding/json"
6-
"github.com/stretchr/testify/require"
76
"io"
87
"testing"
8+
9+
"github.com/stretchr/testify/require"
910
)
1011

1112
func Test_read_null(t *testing.T) {
@@ -135,3 +136,33 @@ func Test_encode_nil_array(t *testing.T) {
135136
should.Nil(err)
136137
should.Equal("null", string(output))
137138
}
139+
140+
func Test_decode_nil_num(t *testing.T) {
141+
type TestData struct {
142+
Field int `json:"field"`
143+
}
144+
should := require.New(t)
145+
146+
data1 := []byte(`{"field": 42}`)
147+
data2 := []byte(`{"field": null}`)
148+
149+
// Checking stdlib behavior as well
150+
obj2 := TestData{}
151+
err := json.Unmarshal(data1, &obj2)
152+
should.Equal(nil, err)
153+
should.Equal(42, obj2.Field)
154+
155+
err = json.Unmarshal(data2, &obj2)
156+
should.Equal(nil, err)
157+
should.Equal(42, obj2.Field)
158+
159+
obj := TestData{}
160+
161+
err = Unmarshal(data1, &obj)
162+
should.Equal(nil, err)
163+
should.Equal(42, obj.Field)
164+
165+
err = Unmarshal(data2, &obj)
166+
should.Equal(nil, err)
167+
should.Equal(42, obj.Field)
168+
}

0 commit comments

Comments
 (0)