forked from gorgonia/tensor
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapi_utils.go
125 lines (116 loc) · 1.99 KB
/
api_utils.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
package tensor
import (
"log"
"math"
"math/rand"
"reflect"
"sort"
"github.com/chewxy/math32"
)
// SortIndex is similar to numpy's argsort: it returns the permutation of
// indices that would sort in into ascending order, i.e. in[out[0]],
// in[out[1]], ... is sorted.
//
// Supported inputs are []int, []float64 and any sort.Interface (ordered by its
// Less method). The input is left unmodified. The sort is stable, so equal
// elements keep their original relative order and the result is always a
// permutation of 0..n-1, even when values repeat. For []float64, NaNs order
// before every other value, matching sort.Float64s.
//
// Any other input type is logged and yields a nil slice.
func SortIndex(in interface{}) (out []int) {
	var n int
	var less func(i, j int) bool // compares elements of in by their original index
	switch list := in.(type) {
	case []int:
		n = len(list)
		less = func(i, j int) bool { return list[i] < list[j] }
	case []float64:
		n = len(list)
		less = func(i, j int) bool {
			// Same ordering as sort.Float64Slice.Less: NaN sorts first.
			return list[i] < list[j] || (math.IsNaN(list[i]) && !math.IsNaN(list[j]))
		}
	case sort.Interface:
		n = list.Len()
		less = list.Less
	default:
		log.Printf("SortIndex: unsupported input type %T", in)
		return nil
	}

	out = make([]int, n)
	for i := range out {
		out[i] = i
	}
	// Sort the index slice rather than the data itself: in is never mutated,
	// and every output position maps back to exactly one original index.
	sort.SliceStable(out, func(a, b int) bool { return less(out[a], out[b]) })
	return out
}
// SampleIndex draws a random index from the discrete distribution described
// by in, using the global math/rand source.
//
// Supported inputs:
//   - []int: integer weights (expected non-negative); index i is chosen with
//     probability list[i] / sum(list). Returns -1 if the weights do not sum
//     to a positive value.
//   - []float64: probabilities, expected to sum to 1.
//   - *Dense of kind Float64 or Float32: probabilities, expected to sum to 1,
//     read from the tensor's backing data.
//
// For the floating-point inputs, a NaN or ±Inf weight reached before the draw
// is exceeded is returned immediately. If rounding leaves the cumulative sum
// at or below the draw, the last index with a positive weight is returned.
// -1 is returned for empty input or when no weight is positive.
//
// It panics for any other input type or tensor kind.
func SampleIndex(in interface{}) int {
	switch list := in.(type) {
	case []int:
		return sampleIndexInt(list)
	case []float64:
		return sampleIndexF64(list)
	case *Dense:
		switch list.t.Kind() {
		case reflect.Float64:
			return sampleIndexF64(list.Float64s())
		case reflect.Float32:
			return sampleIndexF32(list.Float32s())
		default:
			panic("not yet implemented")
		}
	default:
		panic("Not yet implemented")
	}
}

// sampleIndexInt picks index i with probability weights[i]/sum(weights);
// it returns -1 when the total weight is not positive.
func sampleIndexInt(weights []int) int {
	total := 0
	for _, w := range weights {
		total += w
	}
	if total <= 0 {
		return -1
	}
	// Draw in [0, total) so every unit of weight is equally likely; the
	// previous rand.Int() draw almost always exceeded the total.
	r := rand.Intn(total)
	sum := 0
	for i, w := range weights {
		sum += w
		if sum > r {
			return i
		}
	}
	// Unreachable: the final cumulative sum equals total, which exceeds r.
	return -1
}

// sampleIndexF64 returns the first index whose cumulative weight exceeds a
// uniform draw from [0, 1), short-circuiting on NaN/±Inf weights.
func sampleIndexF64(weights []float64) int {
	r := rand.Float64()
	var sum float64
	lastPositive := -1
	for i, w := range weights {
		if math.IsNaN(w) || math.IsInf(w, 0) {
			return i
		}
		sum += w
		if sum > r {
			return i
		}
		if w > 0 {
			lastPositive = i
		}
	}
	// Floating-point rounding can leave sum <= r; fall back to the last index
	// that actually carries probability mass instead of indexing out of range.
	return lastPositive
}

// sampleIndexF32 is the float32 counterpart of sampleIndexF64.
func sampleIndexF32(weights []float32) int {
	r := rand.Float32()
	var sum float32
	lastPositive := -1
	for i, w := range weights {
		if math32.IsNaN(w) || math32.IsInf(w, 0) {
			return i
		}
		sum += w
		if sum > r {
			return i
		}
		if w > 0 {
			lastPositive = i
		}
	}
	// Floating-point rounding can leave sum <= r; fall back to the last index
	// that actually carries probability mass instead of indexing out of range.
	return lastPositive
}