Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Advanced indexing #25

Merged
merged 1 commit into from
Jan 8, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions TensorLib/Broadcast.lean
Original file line number Diff line number Diff line change
@@ -5,7 +5,6 @@ Authors: Jean-Baptiste Tristan, Paul Govereau, Sean McLaughlin
-/

import Aesop
import Batteries.Data.List
import TensorLib.Common

/-!
@@ -39,6 +38,8 @@ Rule 2
A: (3, 2, 7)
B: (3, 2, 7)
Theorem to prove: If we can broadcast s1 to s2, then given an array with shape s1, then s1.reshape s2 succeeds
-/

namespace TensorLib
@@ -81,7 +82,7 @@ private def matchPairs (b : Broadcast) : Option Shape :=
else if x == 1 then some y
else if y == 1 then some x
else none
let dims := (b.left.val.zip b.right.val).traverse f
let dims := (b.left.val.zip b.right.val).mapM f
dims.map Shape.mk

--! Returns the shape resulting from broadcast the arguments
@@ -103,4 +104,23 @@ def canBroadcast (b : Broadcast) : Bool := (broadcast b).isSome
broadcast b2 == broadcast b1 &&
broadcast b2 == .some (Shape.mk [1, 2, 3])

def broadcastList (shapes : List Shape) : Option Shape := Id.run do
match shapes with
| [] => none
| shape :: shapes =>
let mut shape := shape
for s in shapes do
let b := Broadcast.mk shape s
match b.broadcast with
| .none => return .none
| .some s =>
shape := s
return shape

#guard
let x1 := Shape.mk [1, 2, 3]
let x2 := Shape.mk [2, 3]
let x3 := Shape.mk []
broadcastList [x1, x2, x3] == .some x1

end Broadcast
67 changes: 60 additions & 7 deletions TensorLib/Common.lean
Original file line number Diff line number Diff line change
@@ -5,7 +5,6 @@ Authors: Jean-Baptiste Tristan, Paul Govereau, Sean McLaughlin
-/

import Std.Tactic.BVDecide

namespace TensorLib

--! The error monad for TensorLib
@@ -27,6 +26,7 @@ def natDivCeil (num denom : Nat) : Nat := (num + denom - 1) / denom

def natProd (shape : List Nat) : Nat := shape.foldl (fun x y => x * y) 1


-- We generally have large tensors, so don't show them by default
instance ByteArrayRepr : Repr ByteArray where
reprPrec x _ :=
@@ -45,11 +45,41 @@ inductive ByteOrder where
| bigEndian
deriving BEq, Repr, Inhabited

namespace ByteOrder

@[simp]
def ByteOrder.isMultiByte (x : ByteOrder) : Bool := match x with
def isMultiByte (x : ByteOrder) : Bool := match x with
| .oneByte => false
| .littleEndian | .bigEndian => true

def bytesToInt (order : ByteOrder) (bytes : ByteArray) : Int := Id.run do
let mut n : Nat := 0
let nbytes := bytes.size
let signByte := match order with
| .littleEndian => bytes.get! (nbytes - 1)
| .bigEndian | oneByte => bytes.get! 0
let negative := 128 <= signByte
for i in [0:nbytes] do
let v : UInt8 := bytes.get! i
let v := if negative then UInt8.complement v else v
let p := match order with
| .oneByte => 0 -- nbytes = 1
| .littleEndian => i
| .bigEndian => nbytes - 1 - i
n := n + Pow.pow 2 (8 * p) * v.toNat
return if 128 <= signByte then -(n + 1) else n

#guard bytesToInt .littleEndian (ByteArray.mk #[1, 1]) == 257
#guard bytesToInt .bigEndian (ByteArray.mk #[1, 1]) == 257
#guard bytesToInt .littleEndian (ByteArray.mk #[0, 1]) == 256
#guard bytesToInt .bigEndian (ByteArray.mk #[0, 1]) == 1
#guard bytesToInt .littleEndian (ByteArray.mk #[0xFF, 0xFF]) == -1
#guard bytesToInt .bigEndian (ByteArray.mk #[0xFF, 0xFF]) == -1
#guard bytesToInt .bigEndian (ByteArray.mk #[0x80, 0]) == -32768
#guard bytesToInt .littleEndian (ByteArray.mk #[0x80, 0]) == 0x80

end ByteOrder

/-!
The strides are how many bytes you need to skip to get to the next element in that
"row". For example, in an array of 8-byte data with shape 2, 3, the strides are (24, 8).
@@ -144,6 +174,9 @@ deriving BEq, Repr, Inhabited

namespace Shape

instance : ToString Shape where
toString := reprStr

def empty : Shape := Shape.mk []

--! The number of elements in a tensor. All that's needed is the shape for this calculation.
@@ -155,6 +188,10 @@ def ndim (shape : Shape) : Nat := shape.val.length

def map (shape : Shape) (f : List Nat -> List Nat) : Shape := Shape.mk (f shape.val)

def dimIndexInRange (shape : Shape) (dimIndex : DimIndex) :=
shape.ndim == dimIndex.length &&
(shape.val.zip dimIndex).all fun (n, i) => i < n

/-!
Strides can be computed from the shape by figuring out how many elements you
need to jump over to get to the next spot and mulitplying by the bytes in each
@@ -205,6 +242,7 @@ def positionToDimIndex (strides : Strides) (n : Position) : DimIndex :=
let (_, idx) := strides.foldl foldFn (n, [])
idx.reverse

-- TODO: Return `Err Offset` for when the strides and index have different lengths?
def dimIndexToOffset (strides : Strides) (index : DimIndex) : Offset := dot strides (index.map Int.ofNat)

#guard positionToDimIndex [3, 1] 4 == [1, 1]
@@ -223,6 +261,21 @@ def allDimIndices (shape : Shape) : List DimIndex := Id.run do
#guard allDimIndices (Shape.mk [5]) == [[0], [1], [2], [3], [4]]
#guard allDimIndices (Shape.mk [3, 2]) == [[0, 0], [0, 1], [1, 0], [1, 1], [2, 0], [2, 1]]

-- NumPy supports negative indices, which simply wrap around. E.g. `x[.., -1, ..] = x[.., n-1, ..]` where `n` is the
-- dimension in question. It only supports `-n` to `n`.
def intIndexToDimIndex (shape : Shape) (index : List Int) : Err DimIndex := do
if shape.ndim != index.length then .error "intsToDimIndex length mismatch" else
let conv (dim : Nat) (ind : Int) : Err Nat :=
if 0 <= ind then
if ind < dim then .ok ind.toNat
else .error "index out of bounds"
else if ind < -dim then .error "index out of bounds"
else .ok (dim + ind).toNat
(shape.val.zip index).mapM (fun (dim, ind) => conv dim ind)

#guard intIndexToDimIndex (Shape.mk [1, 2, 3]) [0, -1, -1] == (.ok [0, 1, 2])
#guard intIndexToDimIndex (Shape.mk [1, 2, 3]) [0, 1, -2] == (.ok [0, 1, 1])

end Shape

/-
@@ -278,7 +331,7 @@ Note: I tried writing this as a `do/for` loop and in this case the recursive
one seems nicer. We are walking over two lists simultaneously, which is easy
here but with a for loop is either quadratic or awkward.
-/
def next (iter : DimsIter) : List Nat × DimsIter :=
def next (iter : DimsIter) : DimIndex × DimsIter :=
-- Invariant: `acc` is a list of 0s, so doesn't need to be reversed
let rec loop (acc ms ns : List Nat) : List Nat :=
match ms, ns with
@@ -293,7 +346,7 @@ def next (iter : DimsIter) : List Nat × DimsIter :=
let curr' := loop [] iter.dims iter.curr
(iter.curr.reverse, { iter with curr := curr' })

instance [Monad m] : ForIn m DimsIter (List Nat) where
instance [Monad m] : ForIn m DimsIter DimIndex where
forIn {α} [Monad m] (iter : DimsIter) (x : α) (f : List Nat -> α -> m (ForInStep α)) : m α := do
let mut iter := iter
let mut res := x
@@ -305,7 +358,7 @@ instance [Monad m] : ForIn m DimsIter (List Nat) where
| .done k => return k
return res

private def toList (iter : DimsIter) : List (List Nat) := Id.run do
private def toList (iter : DimsIter) : List DimIndex := Id.run do
let mut res := []
for xs in iter do
res := xs :: res
@@ -321,7 +374,7 @@ private def toList (iter : DimsIter) : List (List Nat) := Id.run do
#guard (DimsIter.make $ Shape.mk [1, 1, 2]).toList == [[0, 0, 0], [0, 0, 1]]
#guard (DimsIter.make $ Shape.mk [3, 2]).toList == [[0, 0], [0, 1], [1, 0], [1, 1], [2, 0], [2, 1]]

private def testBreak (iter : DimsIter) : List (List Nat) := Id.run do
private def testBreak (iter : DimsIter) : List DimIndex := Id.run do
let mut res := []
for xs in iter do
res := xs :: res
@@ -330,7 +383,7 @@ private def testBreak (iter : DimsIter) : List (List Nat) := Id.run do

#guard (DimsIter.make $ Shape.mk [3, 2]).testBreak == [[0, 0]]

private def testReturn (iter : DimsIter) : List (List Nat) := Id.run do
private def testReturn (iter : DimsIter) : List DimIndex := Id.run do
let mut res := []
let mut i := 0
for xs in iter do
14 changes: 14 additions & 0 deletions TensorLib/Dtype.lean
Original file line number Diff line number Diff line change
@@ -34,6 +34,16 @@ def isMultiByte (x : Name) : Bool := match x with
| bool | int8 | uint8 => false
| _ => true

def isInt (x : Name) : Bool := match x with
| int8 | int16 | int32 | int64 => true
| _ => false

def isUint (x : Name) : Bool := match x with
| uint8 | uint16 | uint32 | uint64 => true
| _ => false

def isIntLike (x : Name) : Bool := x.isInt || x.isUint

--! Number of bytes used by each element of the given dtype
def itemsize (x : Name) : Nat := match x with
| float64 | int64 | uint64 => 8
@@ -57,6 +67,10 @@ def itemsize (t : Dtype) := t.name.itemsize

def sizedStrides (dtype : Dtype) (s : Shape) : Strides := List.map (fun x => x * dtype.itemsize) s.unitStrides

def isInt (dtype : Dtype) : Bool := dtype.name.isInt
def isUint (dtype : Dtype) : Bool := dtype.name.isUint
def isIntLike (dtype : Dtype) : Bool := dtype.isInt || dtype.isUint

def int8 : Dtype := Dtype.mk Dtype.Name.int8 ByteOrder.littleEndian
def uint8 : Dtype := Dtype.mk Dtype.Name.uint8 ByteOrder.littleEndian
def uint64 : Dtype := Dtype.mk Dtype.Name.uint64 ByteOrder.littleEndian
141 changes: 117 additions & 24 deletions TensorLib/Index.lean
Original file line number Diff line number Diff line change
@@ -4,12 +4,19 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Jean-Baptiste Tristan, Paul Govereau, Sean McLaughlin
-/

import TensorLib.Broadcast
import TensorLib.Common
import TensorLib.Tensor
import TensorLib.Slice
import TensorLib.Npy
import TensorLib.Slice
import TensorLib.Tensor

/-
There are several types of indexing in NumPy.
https://numpy.org/doc/stable/user/basics.indexing.html
We handle basic indexing and some types of advanced indexing.
Theorems to prove (taken from NumPy docs):
1. Basic slicing with more than one non-: entry in the slicing tuple,
@@ -18,7 +25,7 @@ Theorems to prove (taken from NumPy docs):
Thus, x[ind1, ..., ind2,:] acts like x[ind1][..., ind2, :] under basic slicing.
2. Advanced indices always are broadcast and iterated as one:
result[i_1, ..., i_M] == x[ind_1[i_1, ..., i_M], ind_2[i_1, ..., i_M],
result[i_1, ..., i_M] == x[ind_1[i_1, ..., i_M], ind_2[i_1, ..., i_M],
..., ind_N[i_1, ..., i_M]]
3. ...TODO...
-/
@@ -251,6 +258,8 @@ private def testReturn (iter : BasicIter) : List (List Nat) := Id.run do
let slice := Slice.Iter.make Slice.all 5
testReturn (get! $ BasicIter.make shape [.slice slice, .slice slice]) == [[0, 0], [0, 1], [0, 2]]

end BasicIter

def applyWithCopy (index : NumpyBasic) (arr : Tensor) : Err Tensor := do
let itemsize := arr.itemsize
let oldShape := arr.shape
@@ -301,6 +310,94 @@ def apply (index : NumpyBasic) (arr : Tensor) : Err (Tensor × Bool) := do
}
return (res, false)

/-
For advanced indexing, the all-multidimensional-array case is relatively easy;
broadcast all arguments to the same shape, then select the elements of the original
array one by one. For example
# x = np.arange(6).reshape(2, 3)
# x
array([[0, 1, 2],
[3, 4, 5]])
# i0 = np.array([1, 0])[:, None]
# ii = np.array([1, 2, 0])[None, :]
# x[i0, i1]
array([[4, 5, 3],
[1, 2, 0]])
To obtain the later result, we simply walk through the [2, 3]-shaped indices
[[x[1, 1], x[1, 2], x[1, 0]],
[x[0, 1], x[0, 2], x[0, 0]],
This also works when the dims of the index is smaller than the dims
of the array. Each x[i, j] is just an array instead of a scalar. We do not currently
implement that, but if we need it it will be clear what to do; we just copy the (contiguous)
bytes of the sub-array.
Mixing basic and advanced indexing is complex: https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
While I can follow the individual examples, the general case is some work.
As a simple example of mixing, they give `x[..., ind, :]` where
`x.shape` is `(10, 20, 30)` and `ind` is a `(2, 5, 2)`-shaped indexing array.
The result has shape `(10, 2, 5, 2, 30)` and `result[..., i, j, k, :] = x[..., ind[i, j, k], :]`.
While this example is understandable, things get more complex, and I've not yet seen
examples we currently want to support that uses them. Therefore, we do not currently implement
mixed basic/advanced indexing.
-/

namespace Advanced

def apply (indexTensors : List Tensor) (arr : Tensor) : Err Tensor := do
if indexTensors.any fun arr => !arr.isIntLike then .error "Index arrays must have an int-like type" else
if indexTensors.length != arr.ndim then .error "advanced indexing length mismatch"
-- Reshape all the input tesnsors
match Broadcast.broadcastList (indexTensors.map fun arr => arr.shape) with
| none => .error "input shapes must be broadcastable"
| some outShape =>
let mut reshapedIndexTensors := []
for indexTensor in indexTensors do
let indexTensor <- indexTensor.reshape outShape -- should never fail since broadcastList succeeds
reshapedIndexTensors := indexTensor :: reshapedIndexTensors
reshapedIndexTensors := reshapedIndexTensors.reverse
let mut res := Tensor.zeros arr.dtype outShape
-- Now we will iterate over the output shape, computing the values one-by-one from the input array
for outDimIndex in DimsIter.make outShape do
-- Get the index for each dimension in the original array from the corresponding value of the index tensors
let mut inIntIndex : List Int := []
for indexTensor in reshapedIndexTensors do
let v <- indexTensor.intAtDimIndex outDimIndex
inIntIndex := v :: inIntIndex
let inDimIndex <- arr.shape.intIndexToDimIndex inIntIndex.reverse
let bytes <- arr.byteArrayAtDimIndex inDimIndex
res <- res.setByteArrayAtDimIndex outDimIndex bytes
return res

/-
0 1 2
3 4 5
6 7 8
9 10 11
12 12 14
15 16 17
18 19 20
21 22 23
24 25 26
-/
#guard
let ind0 := (Tensor.Element.ofList Tensor.Element.Int8Native [1, 2, 0, 0]).reshape! (Shape.mk [2, 2])
let ind1 := (Tensor.Element.ofList Tensor.Element.Int8Native [2, -2, 0, 1]).reshape! (Shape.mk [2, 2])
let ind2 := (Tensor.Element.ofList Tensor.Element.Int8Native [1, 1, -1, -1]).reshape! (Shape.mk [2, 2])
let typ := BV16
let arr := (Tensor.Element.arange typ 27).reshape! (Shape.mk [3, 3, 3])
let res := get! $ apply [ind0, ind1, ind2] arr
let tree := get! $ res.toTree typ
tree == Tensor.Format.Tree.node [.root [16, 22], .root [2, 5]]

end Advanced

section Test

#guard
let tp := BV8
let tensor := Tensor.Element.arange tp 10
@@ -335,35 +432,33 @@ def apply (index : NumpyBasic) (arr : Tensor) : Err (Tensor × Bool) := do
let tree' := Tensor.Format.Tree.root [19]
!copied && tree == tree'

section Test

-- Testing
private def numpyBasicToList (dims : List Nat) (basic : NumpyBasic) : Option (List (List Nat)) := do
let shape := Shape.mk dims
let (basic, _) <- (toBasic basic shape).toOption
let iter <- (make shape basic).toOption
let iter <- (BasicIter.make shape basic).toOption
iter.toList

#guard numpyBasicToList [] [] == some [[]]
#guard numpyBasicToList [1] [.int 0] == some [[0]]
#guard numpyBasicToList [2] [.int 0] == some [[0]]
#guard numpyBasicToList [2] [.int 1] == some [[1]]
#guard (numpyBasicToList [2] [.int 2]) == none
#guard (numpyBasicToList [2] [.int (-1)]) == some [[1]]
#guard (numpyBasicToList [2] [.int (-3)]) == none
#guard (numpyBasicToList [4] [.slice Slice.all]) == some [[0], [1], [2], [3]]
#guard (numpyBasicToList [4] [.slice $ Slice.build! .none .none (.some 2)]) == some [[0], [2]]
#guard (numpyBasicToList [4] [.slice $ Slice.build! (.some (-1)) .none (.some (-2))]) == some [[3], [1]]
#guard (numpyBasicToList [2, 2] [.int 5]) == none
#guard (numpyBasicToList [2, 2] [.int 0]) == some [[0, 0], [0, 1]]
#guard (numpyBasicToList [2, 2] [.int 0, .int 0]) == some [[0, 0]]
#guard (numpyBasicToList [2, 2] [.int 0, .int 1]) == some [[0, 1]]
#guard (numpyBasicToList [2, 2] [.int 0, .int 2]) == none
#guard (numpyBasicToList [3, 3] [.slice Slice.all, .int 2]) == some [[0, 2], [1, 2], [2, 2]]
#guard (numpyBasicToList [3, 3] [.int 2, .slice Slice.all]) == some [[2, 0], [2, 1], [2, 2]]
#guard (numpyBasicToList [2, 2] [.slice Slice.all, .slice Slice.all]) == some [[0, 0], [0, 1], [1, 0], [1, 1]]
#guard (numpyBasicToList [2, 2] [.slice (Slice.build! .none .none (.some (-1))), .slice Slice.all]) == some [[1, 0], [1, 1], [0, 0], [0, 1]]
#guard (numpyBasicToList [4, 2] [.slice (Slice.build! .none .none (.some (-2))), .slice Slice.all]) == some [[3, 0], [3, 1], [1, 0], [1, 1]]
#guard numpyBasicToList [2] [.int 2] == none
#guard numpyBasicToList [2] [.int (-1)] == some [[1]]
#guard numpyBasicToList [2] [.int (-3)] == none
#guard numpyBasicToList [4] [.slice Slice.all] == some [[0], [1], [2], [3]]
#guard numpyBasicToList [4] [.slice $ Slice.build! .none .none (.some 2)] == some [[0], [2]]
#guard numpyBasicToList [4] [.slice $ Slice.build! (.some (-1)) .none (.some (-2))] == some [[3], [1]]
#guard numpyBasicToList [2, 2] [.int 5] == none
#guard numpyBasicToList [2, 2] [.int 0] == some [[0, 0], [0, 1]]
#guard numpyBasicToList [2, 2] [.int 0, .int 0] == some [[0, 0]]
#guard numpyBasicToList [2, 2] [.int 0, .int 1] == some [[0, 1]]
#guard numpyBasicToList [2, 2] [.int 0, .int 2] == none
#guard numpyBasicToList [3, 3] [.slice Slice.all, .int 2] == some [[0, 2], [1, 2], [2, 2]]
#guard numpyBasicToList [3, 3] [.int 2, .slice Slice.all] == some [[2, 0], [2, 1], [2, 2]]
#guard numpyBasicToList [2, 2] [.slice Slice.all, .slice Slice.all] == some [[0, 0], [0, 1], [1, 0], [1, 1]]
#guard numpyBasicToList [2, 2] [.slice (Slice.build! .none .none (.some (-1))), .slice Slice.all] == some [[1, 0], [1, 1], [0, 0], [0, 1]]
#guard numpyBasicToList [4, 2] [.slice (Slice.build! .none .none (.some (-2))), .slice Slice.all] == some [[3, 0], [3, 1], [1, 0], [1, 1]]

-- Commented for easier debugging. Remove some day
-- #eval do
@@ -380,9 +475,7 @@ private def numpyBasicToList (dims : List Nat) (basic : NumpyBasic) : Option (Li
-- -- let (ns8, iter8) <- iter7.next
-- -- let (ns9, iter9) <- iter8.next
-- return (basic, iter0, ns0, iter1, ns1, iter2, ns2, iter3) -- , ns4, iter4) -- , ns5, iter5, ns6, iter6, ns7, iter7, ns8, iter8, ns9, iter9)

end Test

end BasicIter
end Index
end TensorLib
57 changes: 52 additions & 5 deletions TensorLib/Tensor.lean
Original file line number Diff line number Diff line change
@@ -4,7 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Jean-Baptiste Tristan, Paul Govereau, Sean McLaughlin
-/

import Batteries.Data.List
import Batteries.Data.List -- for `toChunks`
import TensorLib.Common
import TensorLib.Dtype
import TensorLib.Npy
@@ -108,6 +108,8 @@ def ones (dtype : Dtype) (shape : Shape) : Tensor := Id.run do
data := data.push byte
{ dtype := dtype, shape := shape, data := data }

def byteOrder (arr : Tensor) : ByteOrder := arr.dtype.order

--! number of dimensions
def ndim (x : Tensor) : Nat := x.shape.ndim

@@ -126,11 +128,35 @@ def dimIndexToOffset (x : Tensor) (i : DimIndex) : Int :=

--! Get the starting byte corresponding to a DimIndex
def dimIndexToPosition (x : Tensor) (i : DimIndex) : Nat :=
(x.dimIndexToOffset i).toNat
(x.startIndex + (x.dimIndexToOffset i)).toNat

--! number of bytes representing the entire tensor
def nbytes (x : Tensor) : Nat := x.itemsize * x.size

def isIntLike (x : Tensor) : Bool := x.dtype.isIntLike

def dimIndexInRange (arr : Tensor) (dimIndex : DimIndex) : Bool := arr.shape.dimIndexInRange dimIndex

def byteArrayAtDimIndex (arr : Tensor) (dimIndex : DimIndex) : Err ByteArray := do
if !arr.dimIndexInRange dimIndex then .error "index is incompatible with tensor shape" else
let posn := arr.dimIndexToPosition dimIndex
.ok $ arr.data.extract posn (posn + arr.itemsize)

def setByteArrayAtDimIndex (arr : Tensor) (dimIndex : DimIndex) (bytes : ByteArray) : Err Tensor := do
if !arr.dimIndexInRange dimIndex then .error "index is incompatible with tensor shape" else
if arr.itemsize != bytes.size then .error "byte size mismatch" else
let posn := arr.dimIndexToPosition dimIndex
.ok $ { arr with data := bytes.copySlice 0 arr.data posn bytes.size }

/-!
Return the integer at the dimIndex. This is useful, for example, in advanced indexing
where we must have an int/uint Tensor as an argument.
-/
def intAtDimIndex (arr : Tensor) (dimIndex : DimIndex) : Err Int := do
if !arr.isIntLike then .error "natAt expects an int tensor" else
let bytes <- byteArrayAtDimIndex arr dimIndex
.ok $ arr.byteOrder.bytesToInt bytes

/-!
Copy a Tensor's data to new, contiguous storage.
@@ -267,6 +293,16 @@ def setPosition [typ : Element a] (x : Tensor) (n : Nat) (v : a): Err Tensor :=
let posn := n * itemsize
.ok { x with data := bytes.copySlice 0 x.data posn itemsize true }

def ofList (typ : Element a) (xs : List a) : Tensor := Id.run do
let arr := Tensor.zeros typ.dtype (Shape.mk [xs.length])
let mut data := arr.data
let mut posn := 0
for x in xs do
let v := typ.toByteArray x
data := v.copySlice 0 data posn typ.itemsize
posn := posn + arr.itemsize
{ arr with data := data }

-- Since the DimIndex is independent of the dtype size, we need to recompute the strides
-- TODO: Would be better to not recompute this over and over. We should find a place to store
-- the 1-based default strides
@@ -284,8 +320,7 @@ def setDimIndex [Element a] (x : Tensor) (index : DimIndex) (v : a): Err Tensor

-- TODO: remove `Err` by proving all indices are within range
def toList (a : Type) [Tensor.Element a] (x : Tensor) : Err (List a) :=
let traverseFn ind : Err a := getDimIndex x ind
x.shape.allDimIndices.traverse traverseFn
x.shape.allDimIndices.mapM (getDimIndex x)

def toList! (a : Type) [Tensor.Element a] (x : Tensor) : List a := match toList a x with
| .error _ => []
@@ -298,6 +333,15 @@ instance BV8Native : Element BV8 where
toByteArray (x : BV8) : ByteArray := x.toByteArray
fromByteArray arr startIndex := ByteArray.toBV8 arr startIndex

instance Int8Native : Element Int8 where
dtype := Dtype.mk .int8 .oneByte
itemsize := 1
ofNat n := n.toInt8
toByteArray (x : Int8) : ByteArray := [x.toUInt8].toByteArray
fromByteArray arr startIndex := (ByteArray.toBV8 arr startIndex).map fun b => Int8.mk b.toUInt8

#guard Int8Native.fromByteArray (Int8Native.toByteArray (-5)) 0 == .ok (-5)

instance BV16Little : Element BV16 where
dtype := Dtype.mk .uint16 .littleEndian
itemsize := 2
@@ -344,7 +388,7 @@ private def toTree {a : Type} (x : List a) (strides : Strides) : Err (Tree a) :=
| [_] => .error "not a unit stride"
| stride :: strides => do
let chunks := x.toChunks stride.toNat
let res <- chunks.traverse (fun x => toTree x strides)
let res <- chunks.mapM (fun x => toTree x strides)
return .node res

private def toTree! {a : Type} (x : List a) (strides : Strides) : Tree a := match toTree x strides with
@@ -457,6 +501,9 @@ private def arr1 := Element.arange BV8 12
#guard (ones (Dtype.float64) $ Shape.mk [2, 2]).nbytes == 2 * 2 * 8
#guard (ones (Dtype.float64) $ Shape.mk [2, 2]).data.toList.count 1 == 2 * 2

#guard get! ((Element.ofList Element.BV8Native [1, 2, 3]).toTree BV8) == Format.Tree.root [1, 2, 3]
#guard get! (((Element.ofList Element.BV8Native [0, 1, 2, 3, 4, 5]).reshape! (Shape.mk [2, 3])).toTree BV8) == .node [.root [0, 1, 2], .root [3, 4, 5]]

end Test

end Tensor