]> git.feebdaed.xyz Git - 0xmirror/go.git/commitdiff
cmd/compile: more generated equality function tests
authorkhr@golang.org <khr@golang.org>
Mon, 1 Dec 2025 18:32:23 +0000 (10:32 -0800)
committerGopher Robot <gobot@golang.org>
Tue, 2 Dec 2025 20:32:35 +0000 (12:32 -0800)
Change-Id: I05cd103ea8d8bbee0ad907ff3e594de273d49e86
Reviewed-on: https://go-review.googlesource.com/c/go/+/725600
Reviewed-by: Keith Randall <khr@google.com>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
Auto-Submit: Keith Randall <khr@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Cuong Manh Le <cuong.manhle.vn@gmail.com>
src/cmd/compile/internal/test/eq_test.go [new file with mode: 0644]

diff --git a/src/cmd/compile/internal/test/eq_test.go b/src/cmd/compile/internal/test/eq_test.go
new file mode 100644 (file)
index 0000000..336ec01
--- /dev/null
@@ -0,0 +1,298 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Tests of generated equality functions.
+
+package test
+
+import (
+       "reflect"
+       "testing"
+       "unsafe"
+)
+
+//go:noinline
+func checkEq(t *testing.T, x, y any) {
+       // Make sure we don't inline the equality test.
+       if x != y {
+               t.Errorf("%#v != %#v, wanted equal", x, y)
+       }
+}
+
+//go:noinline
+func checkNe(t *testing.T, x, y any) {
+       // Make sure we don't inline the equality test.
+       if x == y {
+               t.Errorf("%#v == %#v, wanted not equal", x, y)
+       }
+}
+
+//go:noinline
+func checkPanic(t *testing.T, x, y any) {
+       defer func() {
+               if recover() == nil {
+                       t.Errorf("%#v == %#v didn't panic", x, y)
+               }
+       }()
+       _ = x == y
+}
+
+type fooComparable struct {
+       x int
+}
+
+func (f fooComparable) foo() {
+}
+
+type fooIncomparable struct {
+       b func()
+}
+
+func (i fooIncomparable) foo() {
+}
+
+type eqResult int
+
+const (
+       eq eqResult = iota
+       ne
+       panic_
+)
+
+func (x eqResult) String() string {
+       return []string{eq: "eq", ne: "ne", panic_: "panic"}[x]
+}
+
+// testEq returns eq if x==y, ne if x!=y, or panic_ if the comparison panics.
+func testEq(x, y any) (r eqResult) {
+       defer func() {
+               if e := recover(); e != nil {
+                       r = panic_
+               }
+       }()
+       r = ne
+       if x == y {
+               r = eq
+       }
+       return
+}
+
+// testCompare make two instances of struct type typ, then
+// assigns its len(vals) fields one value from each slice in vals.
+// Then it checks the results against a "manual" comparison field
+// by field.
+func testCompare(t *testing.T, typ reflect.Type, vals [][]any) {
+       if len(vals) != typ.NumField() {
+               t.Fatalf("bad test, have %d fields in the list, but %d fields in the type", len(vals), typ.NumField())
+       }
+
+       x := reflect.New(typ).Elem()
+       y := reflect.New(typ).Elem()
+       ps := powerSet(vals)    // all possible settings of fields of the test type.
+       for _, xf := range ps { // Pick fields for x
+               for _, yf := range ps { // Pick fields for y
+                       // Make x and y from their chosen fields.
+                       for i, f := range xf {
+                               x.Field(i).Set(reflect.ValueOf(f))
+                       }
+                       for i, f := range yf {
+                               y.Field(i).Set(reflect.ValueOf(f))
+                       }
+                       // Compute what we want the result to be.
+                       want := eq
+                       for i := range len(vals) {
+                               if c := testEq(xf[i], yf[i]); c != eq {
+                                       want = c
+                                       break
+                               }
+                       }
+                       // Compute actual result using generated equality function.
+                       got := testEq(x.Interface(), y.Interface())
+                       if got != want {
+                               t.Errorf("%#v == %#v, got %s want %s\n", x, y, got, want)
+                       }
+               }
+       }
+}
+
+// powerset returns all possible sequences of choosing one
+// element from each entry in s.
+// For instance, if s = {{1,2}, {a,b}}, then
+// it returns {{1,a},{1,b},{2,a},{2,b}}.
+func powerSet(s [][]any) [][]any {
+       if len(s) == 0 {
+               return [][]any{{}}
+       }
+       p := powerSet(s[:len(s)-1]) // powerset from first len(s)-1 entries
+       var r [][]any
+       for _, head := range p {
+               // add one more entry.
+               for _, v := range s[len(s)-1] {
+                       x := make([]any, 0, len(s))
+                       x = append(x, head...)
+                       x = append(x, v)
+                       r = append(r, x)
+               }
+       }
+       return r
+}
+
+func TestCompareKinds1(t *testing.T) {
+       type S struct {
+               X0 int8
+               X1 int16
+               X2 int32
+               X3 int64
+               X4 float32
+               X5 float64
+       }
+       testCompare(t, reflect.TypeOf(S{}), [][]any{
+               {int8(0), int8(1)},
+               {int16(0), int16(1), int16(1 << 14)},
+               {int32(0), int32(1), int32(1 << 30)},
+               {int64(0), int64(1), int64(1 << 62)},
+               {float32(0), float32(1.0)},
+               {0.0, 1.0},
+       })
+}
+func TestCompareKinds2(t *testing.T) {
+       type S struct {
+               X0 uint8
+               X1 uint16
+               X2 uint32
+               X3 uint64
+               X4 uintptr
+               X5 bool
+       }
+       testCompare(t, reflect.TypeOf(S{}), [][]any{
+               {uint8(0), uint8(1)},
+               {uint16(0), uint16(1), uint16(1 << 15)},
+               {uint32(0), uint32(1), uint32(1 << 31)},
+               {uint64(0), uint64(1), uint64(1 << 63)},
+               {uintptr(0), uintptr(1)},
+               {false, true},
+       })
+}
+func TestCompareKinds3(t *testing.T) {
+       type S struct {
+               X0 complex64
+               X1 complex128
+               X2 *byte
+               X3 chan int
+               X4 unsafe.Pointer
+       }
+       testCompare(t, reflect.TypeOf(S{}), [][]any{
+               {complex64(1 + 1i), complex64(1 + 2i), complex64(2 + 1i)},
+               {complex128(1 + 1i), complex128(1 + 2i), complex128(2 + 1i)},
+               {new(byte), new(byte)},
+               {make(chan int), make(chan int)},
+               {unsafe.Pointer(new(byte)), unsafe.Pointer(new(byte))},
+       })
+}
+
+func TestCompareOrdering(t *testing.T) {
+       type S struct {
+               A string
+               E any
+               B string
+       }
+
+       testCompare(t, reflect.TypeOf(S{}), [][]any{
+               {"a", "b", "cc"},
+               {3, []byte{0}, []byte{1}},
+               {"a", "b", "cc"},
+       })
+}
+func TestCompareInterfaces(t *testing.T) {
+       type S struct {
+               A any
+               B fooer
+       }
+       testCompare(t, reflect.TypeOf(S{}), [][]any{
+               {3, []byte{0}},
+               {fooComparable{x: 3}, fooIncomparable{b: nil}},
+       })
+}
+
+func TestCompareSkip(t *testing.T) {
+       type S struct {
+               A int8
+               B int16
+       }
+       type S2 struct {
+               A       int8
+               padding int8
+               B       int16
+       }
+       x := S{A: 1, B: 3}
+       y := S{A: 1, B: 3}
+       (*S2)(unsafe.Pointer(&x)).padding = 88
+       (*S2)(unsafe.Pointer(&y)).padding = 99
+
+       want := eq
+       if got := testEq(x, y); got != want {
+               t.Errorf("%#v == %#v, got %s want %s", x, y, got, want)
+       }
+}
+
+func TestCompareMemequal(t *testing.T) {
+       type S struct {
+               s1 string
+               d  [100]byte
+               s2 string
+       }
+       var x, y S
+
+       checkEq(t, x, y)
+       y.d[0] = 1
+       checkNe(t, x, y)
+       y.d[0] = 0
+       y.d[99] = 1
+       checkNe(t, x, y)
+}
+
+func TestComparePanic(t *testing.T) {
+       type S struct {
+               X0 string
+               X1 any
+               X2 string
+               X3 fooer
+               X4 string
+       }
+       testCompare(t, reflect.TypeOf(S{}), [][]any{
+               {"a", "b", "cc"}, // length equal, as well as length unequal
+               {3, []byte{1}},   // comparable and incomparable
+               {"a", "b", "cc"}, // length equal, as well as length unequal
+               {fooComparable{x: 3}, fooIncomparable{b: nil}}, // comparable and incomparable
+               {"a", "b", "cc"}, // length equal, as well as length unequal
+       })
+}
+
+func TestCompareArray(t *testing.T) {
+       type S struct {
+               X0 string
+               X1 [100]string
+               X2 string
+       }
+       x := S{X0: "a", X2: "b"}
+       y := x
+       checkEq(t, x, y)
+       x.X0 = "c"
+       checkNe(t, x, y)
+       x.X0 = "a"
+       x.X2 = "c"
+       checkNe(t, x, y)
+       x.X2 = "b"
+       checkEq(t, x, y)
+
+       for i := 0; i < 100; i++ {
+               x.X1[i] = "d"
+               checkNe(t, x, y)
+               y.X1[i] = "e"
+               checkNe(t, x, y)
+               x.X1[i] = ""
+               y.X1[i] = ""
+               checkEq(t, x, y)
+       }
+}