forked from gorgonia/tensor
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdense_matop.go
346 lines (292 loc) · 9.02 KB
/
dense_matop.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
package tensor
import "github.com/pkg/errors"
// T records a thunked (lazy) transpose on t. Rather than moving data, it swaps in
// the access pattern computed by t.AP.T and backs up the previous one in t.old,
// together with the axes used, so that the transpose can later be undone with UT.
// Usually this is sufficient, as BLAS will take care of the rest of the transpose.
//
// If t.AP.T fails, its error is passed through handleNoOp and that result is returned.
//
// When an earlier thunked transpose is still pending (t.old is non-zero):
//   - a vector is simply untransposed via UT;
//   - if the shape of the new transform matches the pre-transpose shape at every
//     position, the pending transpose is undone via UT;
//   - otherwise Transpose is called to carry out the pending transpose before the
//     new one is recorded.
func (t *Dense) T(axes ...int) (err error) {
	var transform AP
	transform, axes, err = t.AP.T(axes...)
	if err != nil {
		return handleNoOp(err)
	}

	// A pending transpose has to be dealt with before a new one is recorded:
	// for dim >= 3 the old transposes are merely permutations of the strides.
	if !t.old.IsZero() {
		if t.IsVector() {
			// the transform computed above is not needed; restoring the backup is enough
			t.UT()
			return nil
		}

		// Does the requested transpose just bring back the backed-up shape?
		prevShape := t.oshape()
		undoesPrevious := true
		for i := 0; i < len(prevShape) && undoesPrevious; i++ {
			undoesPrevious = transform.Shape()[i] == prevShape[i]
		}
		if undoesPrevious {
			// restoring the backed-up access pattern is all that is required
			t.UT()
			return nil
		}

		// No shortcut applies, so the pending transpose must actually be performed.
		// NOTE(review): any result returned by Transpose is discarded here - confirm it cannot fail.
		t.Transpose()
	}

	// back up the current access pattern, then install the new one
	t.old = t.AP
	t.transposeWith = axes
	t.AP = transform
	return nil
}
// UT undoes a pending thunked transpose on a *Dense in one step.
// The motivation is best shown by example:
//
//	T = NewTensor(WithShape(2,3,4))
//	T.T(1,2,0)
//
// Undoing that by hand requires applying a transpose of (2,0,1), which means
// tracking and computing the inverse permutation. UT instead restores the
// backed-up access pattern directly.
//
// It does nothing if there is no pending transpose.
func (t *Dense) UT() {
	if t.old.IsZero() {
		return
	}

	// hand the recorded axes back via ReturnInts before dropping the reference
	ReturnInts(t.transposeWith)
	t.AP = t.old
	t.old.zeroOnly()
	t.transposeWith = nil
}
// SafeT behaves like T(), but leaves t alone and returns a new *Dense instead.
// The data is copied over as-is (it is not rearranged); the result carries the
// transposed access pattern, a clone of t's access pattern as its backup, and
// the axes that were used.
//
// An error from t.AP.T is passed through handleNoOp; if an error remains after
// that, it is returned along with a nil *Dense.
func (t *Dense) SafeT(axes ...int) (retVal *Dense, err error) {
	var transform AP
	transform, axes, err = t.AP.T(axes...)
	if err != nil {
		// handleNoOp may clear the error, in which case the copy below still happens
		if err = handleNoOp(err); err != nil {
			return nil, err
		}
	}

	dup := recycledDense(t.t, Shape{t.len()})
	copyDense(dup, t)
	dup.e, dup.oe = t.e, t.oe
	dup.AP = transform
	t.AP.CloneTo(&dup.old)
	dup.transposeWith = axes
	return dup, nil
}
// At returns the value stored at the given coordinate.
//
// An error is returned when the data is not natively accessible, when the number
// of coordinates differs from the number of dimensions, or when the coordinate
// cannot be converted to an index.
func (t *Dense) At(coords ...int) (interface{}, error) {
	switch {
	case !t.IsNativelyAccessible():
		return nil, errors.Errorf(inaccessibleData, t)
	case len(coords) != t.Dims():
		return nil, errors.Errorf(dimMismatch, t.Dims(), len(coords))
	}

	idx, err := t.at(coords...)
	if err != nil {
		return nil, errors.Wrap(err, "At()")
	}
	return t.Get(idx), nil
}
// MaskAt returns the value of the mask at the given coordinate.
// It returns false (valid) with a nil error if the tensor is not masked.
//
// On failure the boolean is false when the data is not natively accessible, and
// true when the coordinate count is wrong or the coordinate cannot be converted
// to an index.
func (t *Dense) MaskAt(coords ...int) (bool, error) {
	switch {
	case !t.IsMasked():
		return false, nil
	case !t.IsNativelyAccessible():
		return false, errors.Errorf(inaccessibleData, t)
	case len(coords) != t.Dims():
		return true, errors.Errorf(dimMismatch, t.Dims(), len(coords))
	}

	idx, err := t.maskAt(coords...)
	if err != nil {
		return true, errors.Wrap(err, "MaskAt()")
	}
	return t.mask[idx], nil
}
// SetAt stores v at the given coordinate.
//
// An error is returned when the data is not natively accessible, when the number
// of coordinates differs from the number of dimensions, or when the coordinate
// cannot be converted to an index.
func (t *Dense) SetAt(v interface{}, coords ...int) error {
	switch {
	case !t.IsNativelyAccessible():
		return errors.Errorf(inaccessibleData, t)
	case len(coords) != t.Dims():
		return errors.Errorf(dimMismatch, t.Dims(), len(coords))
	}

	idx, err := t.at(coords...)
	if err != nil {
		return errors.Wrap(err, "SetAt()")
	}
	t.Set(idx, v)
	return nil
}
// SetMaskAtIndex sets the value of the mask at the given flat index i.
//
// It is a no-op returning nil when the tensor is not masked. If i lies outside
// the bounds of the mask, an error is returned instead of triggering an
// index-out-of-range panic.
func (t *Dense) SetMaskAtIndex(v bool, i int) error {
	if !t.IsMasked() {
		return nil
	}
	// the method already reports failures through its error result, so a bad
	// index is surfaced that way rather than by crashing the caller
	if i < 0 || i >= len(t.mask) {
		return errors.Errorf("SetMaskAtIndex: index %d out of range for mask of length %d", i, len(t.mask))
	}
	t.mask[i] = v
	return nil
}
// SetMaskAt sets the mask value at the given coordinate.
//
// It is a no-op returning nil when the tensor is not masked. An error is returned
// when the data is not natively accessible, when the number of coordinates
// differs from the number of dimensions, or when the coordinate cannot be
// converted to a mask index.
func (t *Dense) SetMaskAt(v bool, coords ...int) error {
	if !t.IsMasked() {
		return nil
	}
	if !t.IsNativelyAccessible() {
		return errors.Errorf(inaccessibleData, t)
	}
	if len(coords) != t.Dims() {
		return errors.Errorf(dimMismatch, t.Dims(), len(coords))
	}
	at, err := t.maskAt(coords...)
	if err != nil {
		// label the error with this method's own name so it can be traced correctly
		return errors.Wrap(err, "SetMaskAt()")
	}
	t.mask[at] = v
	return nil
}
// CopyTo copies the underlying data of t into the destination *Dense, leaving t
// untouched. The destination's metadata is not taken into account. For example:
//
//	T = NewTensor(WithShape(6))
//	T2 = NewTensor(WithShape(2,3))
//	err = T.CopyTo(T2) // err == nil
//
// Copying a tensor onto itself is a no-op. An error is returned when the sizes of
// the two tensors differ, or when either of them has a non-zero viewOf, which is
// reported as not yet implemented for views.
func (t *Dense) CopyTo(other *Dense) error {
	if t == other {
		// source and destination are the same tensor. Maybe return NoOpErr?
		return nil
	}

	if other.Size() != t.Size() {
		return errors.Errorf(sizeMismatch, t.Size(), other.Size())
	}

	// TODO: use copyDenseIter
	if t.viewOf != 0 || other.viewOf != 0 {
		return errors.Errorf(methodNYI, "CopyTo", "views")
	}

	// neither side is a view: a plain copy does the job
	copyDense(other, t)
	return nil
}
// Slice slices the *Dense tensor and returns a view that shares its underlying
// memory with the original *Dense.
//
// Given:
//
//	T = NewTensor(WithShape(2,2), WithBacking(RangeFloat64(0,4)))
//	V, _ := T.Slice(nil, singleSlice(1)) // T[:, 1]
//
// any modification made to the values in V is reflected in T as well.
//
// A <nil> slice is treated as a colon slice: T.Slice(nil) is equivalent to T[:]
// in Numpy syntax.
//
// If the access pattern cannot be sliced, the error from t.AP.S is returned with
// a nil View.
func (t *Dense) Slice(slices ...Slice) (retVal View, err error) {
	var (
		slicedAP AP
		lo, hi   int
	)
	slicedAP, lo, hi, err = t.AP.S(t.len(), slices...)
	if err != nil {
		return nil, err
	}

	v := borrowDense()
	// the view carries over the parent's type, engines and flag
	v.t, v.e, v.oe, v.flag = t.t, t.e, t.oe, t.flag
	v.AP = slicedAP
	v.setParentTensor(t)
	t.sliceInto(lo, hi, &v.array)

	if t.IsMasked() {
		// the mask is sub-sliced over the same [lo, hi) range
		v.mask = t.mask[lo:hi]
	}
	return v, nil
}
// SliceInto is a convenience method that slices t into an existing *Dense.
// No values are copied: the view is pointed at the same underlying data and
// given the sliced access pattern.
// ALL metadata previously held by view is overwritten.
//
// If the access pattern cannot be sliced, the error from t.AP.S is returned with
// a nil View and view is left unmodified.
func (t *Dense) SliceInto(view *Dense, slices ...Slice) (retVal View, err error) {
	var (
		slicedAP AP
		lo, hi   int
	)
	slicedAP, lo, hi, err = t.AP.S(t.len(), slices...)
	if err != nil {
		return nil, err
	}

	// wipe whatever the view held before repopulating it
	view.AP.zero()
	view.array.v = nil

	view.t, view.e, view.oe, view.flag = t.t, t.e, t.oe, t.flag
	view.AP = slicedAP
	view.setParentTensor(t)
	t.sliceInto(lo, hi, &view.array)

	if t.IsMasked() {
		// the mask is sub-sliced over the same [lo, hi) range
		view.mask = t.mask[lo:hi]
	}
	return view, nil
}
// RollAxis rolls the axis backwards until it lies in the given position.
//
// axis must satisfy 0 <= axis < t.Dims() and start must satisfy
// 0 <= start <= t.Dims(); otherwise an error is returned. When the axis already
// lies at the requested position, t itself is returned unchanged. Otherwise the
// permutation is applied with SafeT (returning a new *Dense) when safe is true,
// or with T (modifying and returning t) when safe is false.
//
// This method was adapted from Numpy's Rollaxis. The licence for Numpy is a BSD-like licence and can be found here: https://github.com/numpy/numpy/blob/master/LICENSE.txt
//
// As a result of being adapted from Numpy, the quirks are also adapted. A good guide reducing the confusion around rollaxis can be found here: http://stackoverflow.com/questions/29891583/reason-why-numpy-rollaxis-is-so-confusing (see answer by hpaulj)
func (t *Dense) RollAxis(axis, start int, safe bool) (retVal *Dense, err error) {
	dims := t.Dims()

	if axis < 0 || axis >= dims {
		err = errors.Errorf(invalidAxis, axis, dims)
		return
	}

	if start < 0 || start > dims {
		// report the offending start value, not the (valid) axis
		err = errors.Wrap(errors.Errorf(invalidAxis, start, dims), "Start axis is wrong")
		return
	}

	// removing axis shifts every later position down by one
	if axis < start {
		start--
	}

	if axis == start {
		// already in place: nothing to roll
		retVal = t
		return
	}

	// NOTE(review): axes is handed to T/SafeT below and also returned to the pool
	// on exit - confirm that neither retains the slice after this function returns.
	axes := BorrowInts(dims)
	defer ReturnInts(axes)

	for i := 0; i < dims; i++ {
		axes[i] = i
	}
	// drop axis from the identity permutation, then re-insert it at start
	copy(axes[axis:], axes[axis+1:])
	copy(axes[start+1:], axes[start:])
	axes[start] = axis

	if safe {
		return t.SafeT(axes...)
	}
	err = t.T(axes...)
	retVal = t
	return
}
/* Private Methods */
// transposeIndex maps the flat index i of the pre-transpose layout to the
// corresponding flat index under the given transpose pattern and strides.
// It panics if i cannot be converted to a coordinate using the original shape
// and strides.
func (t *Dense) transposeIndex(i int, transposePat, strides []int) int {
	oldCoord, err := Itol(i, t.oshape(), t.ostrides())
	if err != nil {
		panic(errors.Wrapf(err, "transposeIndex ItoL failure. i %d original shape %v. original strides %v", i, t.oshape(), t.ostrides()))
	}

	// Conceptually this is equivalent to:
	//
	//	coordss, _ := Permute(transposePat, oldCoord)
	//	coords := coordss[0]
	//	expShape := t.Shape()
	//	index, _ := Ltoi(expShape, strides, coords...)
	//
	// but all of those checks slow things down, so the permutation and the
	// stride dot-product are fused into a single loop instead.
	newIndex := 0
	for dim := range transposePat {
		newIndex += oldCoord[transposePat[dim]] * strides[dim]
	}
	return newIndex
}
// at converts a coordinate into the flat index it refers to, using the tensor's
// current shape and strides.
// It encapsulates the addressing of elements in a contiguous block.
// For a 2D ndarray, ndarray.at(i,j) is
//
//	at = ndarray.strides[0]*i + ndarray.strides[1]*j
//
// which extends to any number of dimensions.
func (t *Dense) at(coords ...int) (at int, err error) {
	shape, strides := t.Shape(), t.Strides()
	return Ltoi(shape, strides, coords...)
}
// maskAt returns the mask index that the coordinate refers to.
// It currently delegates to at, so the mask is addressed exactly like the data.
func (t *Dense) maskAt(coords ...int) (at int, err error) {
	//TODO: Add check for non-masked tensor
	at, err = t.at(coords...)
	return at, err
}