forked from gorgonia/tensor
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdense_test.go
116 lines (101 loc) · 3.01 KB
/
dense_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
package tensor
import (
	"math/rand"
	"runtime"
	"testing"
	"testing/quick"
	"time"
	"unsafe"

	"github.com/stretchr/testify/assert"
)
func TestDense_ShallowClone(t *testing.T) {
T := New(Of(Float64), WithBacking([]float64{1, 2, 3, 4}))
T2 := T.ShallowClone()
T2.slice(0, 2)
T2.Float64s()[0] = 1000
assert.Equal(t, T.Data().([]float64)[0:2], T2.Data())
assert.Equal(t, T.Engine(), T2.Engine())
assert.Equal(t, T.oe, T2.oe)
assert.Equal(t, T.flag, T2.flag)
}
// TestDense_Clone is a property-based test. For each *Dense produced by
// testing/quick it verifies that Clone() yields a *Dense with the same shape,
// strides, data order (o), triangle (Δ), flag, engine (e), optimized engine
// (oe), transposeWith length, mask and maskIsSoft.
//
// The RNG seed is logged (visible with -v or on failure) so a failing
// counterexample can be reproduced by hard-coding the seed.
func TestDense_Clone(t *testing.T) {
	is := assert.New(t)
	cloneChk := func(q *Dense) bool {
		a := q.Clone().(*Dense)
		if !q.Shape().Eq(a.Shape()) {
			t.Errorf("Shape Difference: %v %v", q.Shape(), a.Shape())
			return false
		}
		if len(q.Strides()) != len(a.Strides()) {
			t.Errorf("Stride Difference: %v %v", q.Strides(), a.Strides())
			return false
		}
		for i, s := range q.Strides() {
			if a.Strides()[i] != s {
				t.Errorf("Stride Difference: %v %v", q.Strides(), a.Strides())
				return false
			}
		}
		if q.o != a.o {
			t.Errorf("Data Order difference : %v %v", q.o, a.o)
			return false
		}
		if q.Δ != a.Δ {
			t.Errorf("Triangle Difference: %v %v", q.Δ, a.Δ)
			return false
		}
		if q.flag != a.flag {
			t.Errorf("Flag difference : %v %v", q.flag, a.flag)
			return false
		}
		if q.e != a.e {
			t.Errorf("Engine difference; %T %T", q.e, a.e)
			return false
		}
		if q.oe != a.oe {
			t.Errorf("Optimized Engine difference; %T %T", q.oe, a.oe)
			return false
		}
		if len(q.transposeWith) != len(a.transposeWith) {
			t.Errorf("TransposeWith difference: %v %v", q.transposeWith, a.transposeWith)
			return false
		}
		// Feed the assertion results back into the property so quick.Check
		// reports the failing input. Both comparisons run so each mismatch
		// is reported.
		maskOK := is.Equal(q.mask, a.mask, "mask difference")
		softOK := is.Equal(q.maskIsSoft, a.maskIsSoft, "mask is soft ")
		return maskOK && softOK
	}

	seed := time.Now().UnixNano()
	t.Logf("quick.Check seed: %d", seed)
	r := rand.New(rand.NewSource(seed))
	if err := quick.Check(cloneChk, &quick.Config{Rand: r}); err != nil {
		t.Error(err)
	}
}
// TestDenseMasked checks that calling ResetMask on a freshly constructed
// 3x2 Float64 tensor leaves a mask of six entries, all false.
func TestDenseMasked(t *testing.T) {
	const rows, cols = 3, 2

	dense := New(Of(Float64), WithShape(rows, cols))
	dense.ResetMask()

	want := make([]bool, rows*cols) // zero value: every entry false
	assert.Equal(t, want, dense.mask)
}
// TestFromScalar checks that a tensor built from a float64 scalar exposes
// that value as a one-element []float64 via Float64s.
func TestFromScalar(t *testing.T) {
	const scalar = 3.14

	got := New(FromScalar(scalar)).Float64s()
	assert.Equal(t, []float64{scalar}, got)
}
// TestFromMemory checks New(..., FromMemory(ptr, size)) on caller-supplied
// memory. It passes a 100-element []float64 buffer (800 bytes) as a 50x4
// Float32 tensor and expects:
//   - Float32s to have 200 elements;
//   - the data to be all zeros;
//   - IsManuallyManaged to report true.
//
// It also checks that New panics when FromMemory precedes Of in the option
// list.
func TestFromMemory(t *testing.T) {
	// dummy memory - this could be an externally malloc'd memory, or a mmap'ed file.
	// but here we're just gonna let Go manage memory.
	s := make([]float64, 100)
	ptr := uintptr(unsafe.Pointer(&s[0]))
	size := uintptr(len(s)) * unsafe.Sizeof(s[0]) // 100 * 8 bytes

	T := New(Of(Float32), WithShape(50, 4), FromMemory(ptr, size))
	if n := len(T.Float32s()); n != 200 {
		t.Errorf("expected 200 Float32s, got %d", n)
	}
	assert.Equal(t, make([]float32, 200), T.Data())
	assert.True(t, T.IsManuallyManaged(), "Unmanaged %v |%v | q: %v", ManuallyManaged, T.flag, (T.flag>>ManuallyManaged)&MemoryFlag(1))

	fail := func() { New(FromMemory(ptr, size), Of(Float32)) }
	assert.Panics(t, fail, "Expected bad New() call to panic")

	// FromMemory only receives a uintptr, which does not keep s's backing
	// array reachable. Without this, the GC could free the buffer while
	// the tensor still points into it.
	runtime.KeepAlive(s)
}
// Test_recycledDense checks that recycledDense for a Float64 scalar shape
// yields a tensor whose Data() equals float64(0). It also checks that both
// its engine (e) and optimized engine (oe) equal StdEng{}.
func Test_recycledDense(t *testing.T) {
	d := recycledDense(Float64, ScalarShape())
	assert.Equal(t, float64(0), d.Data())

	for _, eng := range []interface{}{d.e, d.oe} {
		assert.Equal(t, StdEng{}, eng)
	}
}