Skip to content

Commit e14a11f

Browse files
Use reflect.Value.Pointer() to compare pointers
Fixes #1076
1 parent 2fc4e39 commit e14a11f

File tree

2 files changed

+79
-17
lines changed

2 files changed

+79
-17
lines changed

assert/assertions.go

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -491,12 +491,12 @@ func validateEqualArgs(expected, actual interface{}) error {
491491
return nil
492492
}
493493

494-
// Same asserts that two pointers reference the same object.
494+
// Same asserts that two arguments reference the same object.
495495
//
496-
// assert.Same(t, ptr1, ptr2)
496+
// assert.Same(t, arg1, arg2)
497497
//
498-
// Both arguments must be pointer variables. Pointer variable sameness is
499-
// determined based on the equality of both type and value.
498+
// Both arguments can be pointers, channels, functions, maps, slices or strings.
499+
// Argument sameness is determined based on the equality of both type and value.
500500
func Same(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool {
501501
if h, ok := t.(tHelper); ok {
502502
h.Helper()
@@ -511,12 +511,12 @@ func Same(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) b
511511
return true
512512
}
513513

514-
// NotSame asserts that two pointers do not reference the same object.
514+
// NotSame asserts that two arguments do not reference the same object.
515515
//
516-
// assert.NotSame(t, ptr1, ptr2)
516+
// assert.NotSame(t, arg1, arg2)
517517
//
518-
// Both arguments must be pointer variables. Pointer variable sameness is
519-
// determined based on the equality of both type and value.
518+
// Both arguments can be pointers, channels, functions, maps, slices or strings.
519+
// Argument sameness is determined based on the equality of both type and value.
520520
func NotSame(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool {
521521
if h, ok := t.(tHelper); ok {
522522
h.Helper()
@@ -534,17 +534,13 @@ func NotSame(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}
534534
// they point to the same object
535535
func samePointers(first, second interface{}) bool {
536536
firstPtr, secondPtr := reflect.ValueOf(first), reflect.ValueOf(second)
537-
if firstPtr.Kind() != reflect.Ptr || secondPtr.Kind() != reflect.Ptr {
538-
return false
539-
}
540537

541-
firstType, secondType := reflect.TypeOf(first), reflect.TypeOf(second)
542-
if firstType != secondType {
538+
switch firstPtr.Kind() {
539+
case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.Slice, reflect.String:
540+
return firstPtr.Kind() == secondPtr.Kind() && firstPtr.Pointer() == secondPtr.Pointer()
541+
default:
543542
return false
544543
}
545-
546-
// compare pointer addresses
547-
return first == second
548544
}
549545

550546
// formatUnequalValues takes two values of arbitrary types and returns string

assert/assertions_test.go

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -599,6 +599,12 @@ func TestNotSame(t *testing.T) {
599599

600600
func Test_samePointers(t *testing.T) {
601601
p := ptr(2)
602+
c1, c2 := make(chan int), make(chan int)
603+
f1, f2 := func() {}, func() {}
604+
m1, m2 := map[int]int{1: 2}, map[int]int{1: 2}
605+
p1, p2 := ptr(3), ptr(3)
606+
s1, s2 := []int{4, 5}, []int{4, 5}
607+
str1, str2 := "6", "6_"[:1] // ensure strings use different backing arrays
602608

603609
type args struct {
604610
first interface{}
@@ -634,6 +640,66 @@ func Test_samePointers(t *testing.T) {
634640
args: args{first: [2]int{1, 2}, second: []int{1, 2}},
635641
assertion: False,
636642
},
643+
{
644+
name: "chan(1) == chan(1)",
645+
args: args{first: c1, second: c1},
646+
assertion: True,
647+
},
648+
{
649+
name: "func(1) == func(1)",
650+
args: args{first: f1, second: f1},
651+
assertion: True,
652+
},
653+
{
654+
name: "map(1) == map(1)",
655+
args: args{first: m1, second: m1},
656+
assertion: True,
657+
},
658+
{
659+
name: "ptr(1) == ptr(1)",
660+
args: args{first: p1, second: p1},
661+
assertion: True,
662+
},
663+
{
664+
name: "slice(1) == slice(1)",
665+
args: args{first: s1, second: s1},
666+
assertion: True,
667+
},
668+
{
669+
name: "string(1) == string(1)",
670+
args: args{first: str1, second: str1},
671+
assertion: True,
672+
},
673+
{
674+
name: "chan(1) != chan(2)",
675+
args: args{first: c1, second: c2},
676+
assertion: False,
677+
},
678+
{
679+
name: "func(1) != func(2)",
680+
args: args{first: f1, second: f2},
681+
assertion: False,
682+
},
683+
{
684+
name: "map(1) != map(2)",
685+
args: args{first: m1, second: m2},
686+
assertion: False,
687+
},
688+
{
689+
name: "ptr(1) != ptr(2)",
690+
args: args{first: p1, second: p2},
691+
assertion: False,
692+
},
693+
{
694+
name: "slice(1) != slice(2)",
695+
args: args{first: s1, second: s2},
696+
assertion: False,
697+
},
698+
{
699+
name: "string(1) != string(2)",
700+
args: args{first: str1, second: str2},
701+
assertion: False,
702+
},
637703
}
638704
for _, tt := range tests {
639705
t.Run(tt.name, func(t *testing.T) {
@@ -2505,7 +2571,7 @@ Diff:
25052571
@@ -1,2 +1,2 @@
25062572
-(time.Time) 2020-09-24 00:00:00 +0000 UTC
25072573
+(time.Time) 2020-09-25 00:00:00 +0000 UTC
2508-
2574+
25092575
`
25102576

25112577
actual = diff(

0 commit comments

Comments
 (0)