
Commit
Fix reshape function and optimize its performance.
Iainmon committed Mar 4, 2025
1 parent d69adec commit 147917d
Showing 7 changed files with 162 additions and 13 deletions.
2 changes: 1 addition & 1 deletion lib/Env.chpl
@@ -6,7 +6,7 @@ param developmentAndTesting = !releaseChAI;


// Minimum needed rank of dynamicTensor for *any* build
-config param minRankNeeded = 6;
+config param minRankNeeded = 4;


// Maximum needed rank of dynamicTensor for a release build
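Because minRankNeeded is a Chapel config param, builds that still need the old rank-6 support can restore it at compile time instead of editing the source. A minimal sketch, assuming the Env module is on the compiler's module search path:

    // Override the new default of 4 at compile time:
    //   chpl -sminRankNeeded=6 myProgram.chpl
    use Env;
    writeln(minRankNeeded);  // 4 by default, 6 when compiled as above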
35 changes: 25 additions & 10 deletions lib/NDArray.chpl
@@ -149,24 +149,39 @@ record ndarray : serializable {
         _domain = dom;
     }

-    proc reshape(const dom: ?t): ndarray(rank,eltType)
+    proc reshape(const dom: ?t): ndarray(dom.rank,eltType)
         where isDomainType(t)
-            && dom.rank == rank {
+            /* && dom.rank == rank */ {
         var arr = new ndarray(eltType,dom);
-        const arrDom = arr.domain;
-        const selfDom = _domain;
-
-        const inter = selfDom[arrDom];
-        arr.data[inter] = data[inter];
+
+        const arrDom = arr.domain;
+        const selfDom = this.domain;
+
+        ref arrData = arr.data;
+        const ref selfData = this.data;
+
+        const arrShape = arrDom.shape;
+        const selfShape = selfDom.shape;
+        const selfShapeDivs = util.shapeDivisors((...selfShape));
+
+        const zero: eltType = 0;
+
+        forall (i,idx) in arrDom.everyZip() {
+            const selfIdx = util.indexAtHelperMultiples(i,(...selfShapeDivs));
+            const a = if util.shapeContains(selfShape,selfIdx)
+                        then selfData[selfIdx]
+                        else zero;
+            arrData[idx] = a;
+        }
         return arr;
     }

+    /*
     proc reshape(const dom: ?t): ndarray(dom.rank,eltType)
         where isDomainType(t)
             && dom.rank != rank {
         var arr: ndarray(dom.rank,eltType) = new ndarray(eltType,dom);

-        compilerError("Testing. Don't call me.");
         const selfDom = this.domain;
         const newDom = arr.domain;
         const ref selfData = this.data;
@@ -200,7 +215,7 @@ record ndarray : serializable {
         // return arr;
     }

+    */

    // This can be optimized so that it doesn't use two heavy utility functions...
proc reshape(newShape: int ...?newRank): ndarray(newRank,eltType) {
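The rewritten reshape walks the target domain once in parallel: the element at flat position i of the result is read from flat position i of the source, located by dividing i by the source's precomputed row-major strides, and zero-filled when the target has more elements than the source. Below is a self-contained sketch of just that index arithmetic; divisors and indexAt are hypothetical stand-ins mirroring util.shapeDivisors and util.indexAtHelperMultiples, not the library code itself.

    // Row-major strides: divs(i) is the number of elements one step
    // along dimension i skips over.
    proc divisors(shape: int ...?rank): rank*int {
      var prod = 1;
      var divs: rank*int;
      for param j in 0..<rank {
        param i = rank - j - 1;
        divs(i) = prod;
        prod *= shape(i);
      }
      return divs;
    }

    // Convert a flat offset back into a multidimensional index.
    proc indexAt(n: int, divs: int ...?rank): rank*int {
      var idx: rank*int;
      var rest = n;
      for param i in 0..<rank {
        idx(i) = rest / divs(i);  // whole strides that fit
        rest %= divs(i);
      }
      return idx;
    }

    // Flat position 4 of a 2x3 row-major array lives at index (1,1):
    writeln(indexAt(4, (...divisors(2,3))));  // (1, 1)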
28 changes: 28 additions & 0 deletions lib/Utilities.chpl
@@ -193,6 +193,26 @@ module Utilities {
return idxs;
}

+    inline proc shapeDivisors(shape: int...?rank): rank*int {
+        var prod = 1;
+        var divs: rank * int;
+        for param j in 0..<rank {
+            param i = rank - j - 1;
+            divs(i) = prod;
+            prod *= shape(i);
+        }
+        return divs;
+    }
+
+    inline proc shapeProduct(shape: int...?rank): int {
+        // divs(0) is the product of all trailing dimensions, so
+        // divs(0) * shape(0) is the total element count (shape(0) when rank == 1).
+        const divs = shapeDivisors((...shape));
+        return divs(0) * shape(0);
+    }

inline proc indexAtHelperProd(n: int, prod: int, shape: int ...?rank): rank * int where rank > 1 {
var idx: rank * int;
var order = n;
@@ -304,6 +324,14 @@ module Utilities {
return arr;
}

+    // True when idx is a valid index into an array of the given shape.
+    // Only upper bounds are checked: callers pass indices derived from
+    // non-negative flat offsets, so idx(i) >= 0 is already guaranteed.
+    inline proc shapeContains(shape: ?rank*int, idx: ?idxRank*int): bool
+        where rank == idxRank {
+        var contained = true;
+        for param i in 0..<rank do
+            contained &= idx(i) < shape(i);
+        return contained;
+    }

module Standard {
private use ChplConfig;

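A quick sanity check of the new helpers: for a row-major 2x3x4 array, each dimension's stride is the product of the dimensions after it. A hedged usage sketch (assuming Utilities is compiled alongside, and with shapeProduct returning the full element count):

    import Utilities as util;

    writeln(util.shapeDivisors(2,3,4));            // (12, 4, 1)
    writeln(util.shapeProduct(2,3,4));             // 24 elements in total
    writeln(util.shapeContains((2,3,4), (1,2,3))); // true: every component in bounds
    writeln(util.shapeContains((2,3,4), (1,3,0))); // false: 3 is not < 3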
19 changes: 19 additions & 0 deletions test/correspondence/movement/reshape/reshape.chpl
@@ -0,0 +1,19 @@
use Tensor;

var a = dynamicTensor.arange(2,3);

Testing.numericPrint(a.reshape(3,2));

Testing.numericPrint(a.reshape(6));

Testing.numericPrint(a.reshape(1,2,3));

Testing.numericPrint(a.reshape(1,1,6));

Testing.numericPrint(a.reshape(6,1,1));

Testing.numericPrint(a.reshape(2,1,1,3));

Testing.numericPrint(a.reshape(3,1,1,2));

Testing.numericPrint(a.reshape(1,3,2,1));
23 changes: 23 additions & 0 deletions test/correspondence/movement/reshape/reshape.py
@@ -0,0 +1,23 @@
import torch

def test(imports):
print = imports['print_fn']

a = torch.arange(6,dtype=torch.float32).reshape(2,3)

print(a.reshape(3,2))

print(a.reshape(6))

print(a.reshape(1,2,3))

print(a.reshape(1,1,6))

print(a.reshape(6,1,1))

print(a.reshape(2,1,1,3))

print(a.reshape(3,1,1,2))

print(a.reshape(1,3,2,1))

29 changes: 27 additions & 2 deletions test/tiny/dynamic_shape.chpl
@@ -1,6 +1,8 @@
use Tensor;

use List;
import Utilities as util;

/*
var a = dynamicTensor.arange(2,3);
@@ -22,7 +24,7 @@ writeln(b.shape());
writeln(b.shape().toList());
*/

/*
var c = dynamicTensor.arange(4) + 1;
writeln(c);
@@ -36,4 +38,27 @@ writeln(d.squeeze());
writeln(d.squeeze(1));
writeln(d.reshape(new dynamicShape((2,2))));
writeln(d.reshape(dShape=new dynamicShape((4,))));
writeln(d.reshape(2,2));
*/

/*
var a = dynamicTensor.arange(2,3);
writeln(a);
writeln(a.forceRank(2).array);
writeln(a.reshape(3,2).forceRank(2).array);
writeln(a.reshape(3,2).forceRank(2).array);
*/

var a = ndarray.arange(2,3);
writeln(a);
writeln(a.reshape(3,2));
writeln(a.reshape(3,2));


writeln(a.reshape(6));
writeln(a.reshape(1,1,6));
39 changes: 39 additions & 0 deletions test/tiny/tinygradNB.ipynb
@@ -122,6 +122,45 @@
"print(t.unsqueeze(1).unsqueeze(1).squeeze(1).numpy().shape)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[0 1 2]\n",
" [3 4 5]]\n"
]
}
],
"source": [
"a = Tensor.arange(6).reshape(2,3)\n",
"print(a.numpy())"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[0 1]\n",
" [2 3]\n",
" [4 5]]\n"
]
}
],
"source": [
"b = a.reshape(3,2)\n",
"print(b.numpy())"
]
},
{
"cell_type": "code",
"execution_count": null,
