Fix and clean up code for mean and sum functions. (#52)
- Fix an issue with the sum and mean functions on tensors.
- Add dynamic shape functionality for `dynamicTensor`.
- Add `squeeze`, `reshape`, and `unsqueeze` methods to `dynamicTensor`.
Iainmon authored Mar 3, 2025
2 parents 9cd2da2 + 2cc91fd commit 6f329bd
Showing 8 changed files with 390 additions and 63 deletions.
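
At a glance, the reworked `dynamicTensor` surface looks like this. A minimal usage sketch, not code from the repository: `dynamicTensor.arange` follows the updated tests below, and the commented shapes assume PyTorch-style keepDim semantics.

use Tensor;

var t = dynamicTensor.arange(2,3);    // a 2 x 3 tensor of counting values

// Reductions: sum and mean now share one dispatch path.
var total    = t.sum();               // reduce over all axes; keepDim defaults to true
var colMeans = t.mean(0);             // reduce over axis 0

// New shape utilities.
var s = t.shape();                    // rank-erased dynamicShape with sizes [2, 3]
var r = t.reshape(s);                 // reshape driven by a dynamicShape
var u = t.unsqueeze(0);               // shape (1, 2, 3)
var q = u.squeeze();                  // drops all size-1 dims: back to (2, 3)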
18 changes: 18 additions & 0 deletions lib/Autograd.chpl
@@ -593,6 +593,24 @@ record multOp : serializable {
    proc spec : GradOpSpec do return new dict(("operation","Mul"));
}

// record scalarMapOp : serializable {
//     param opName: string;
//     enum scalarMapOpSide { ScalarMapOpLeft, ScalarMapOpRight }
//     param opSide: scalarMapOpSide;

//     var input: shared BaseTensorResource(?);
//     var scalar: input.eltType;

//     proc init(param opName: string, scalar: ?scalarType, input: shared BaseTensorResource(?))
//             where isNumericType(scalarType) {
//         this.opSide = ScalarMapOpLeft;
//         this.input = input;
//         this.scalar = scalar;
//     }

//     proc forward() { compilerError("TODO!"); }

// }

record reshapeOp : serializable {
    param oldRank: int;
181 changes: 154 additions & 27 deletions lib/DynamicTensor.chpl
@@ -10,6 +10,8 @@ use Utilities.Standard;

use Env;

use List only list;

import LoadNumpy;

param defaultDetachedMode = true;
@@ -330,45 +332,50 @@ operator ==(a: dynamicTensor(?eltType),b: dynamicTensor(eltType)): bool {
    }
    halt("Could not determine rank in dynamicTensor == dynamicTensor.");
}
-proc dynamicTensor.sum(axes: ?axesCount*int, keepDim: bool = true): dynamicTensor(eltType) {
-    for param rank in 1..maxRank do
-        if this.checkRank(rank) then
-            return this.forceRank(rank).sum(axes,keepDim).eraseRank();
-    halt("Could not determine rank in dynamicTensor.sum.");
-    return new dynamicTensor(eltType);
-}
-
-proc dynamicTensor.sum(keepDim: bool = true): dynamicTensor(eltType) {
-    for param rank in 1..maxRank do
-        if this.checkRank(rank) then
-            return this.forceRank(rank).sum(keepDim=true).eraseRank();
-    halt("Could not determine rank in dynamicTensor.sum.");
-    return new dynamicTensor(eltType);
-}
-
-proc dynamicTensor.sum(axes: int...?r): dynamicTensor(eltType) {
-    for param rank in 1..maxRank do
-        if this.checkRank(rank) then
-            return this.forceRank(rank).sum((...axes)).eraseRank();
-    halt("Could not determine rank in dynamicTensor.sum.");
-    return new dynamicTensor(eltType);
-}
+inline proc dynamicTensor.reduceOpAxes(param opName: string, axes: ?axesCount*int, param keepDim: bool): dynamicTensor(eltType) {
+    for param rank in 1..maxRank do
+        if this.checkRank(rank) then
+            select opName {
+                when "sum" do
+                    return this.forceRank(rank).sum(axes,keepDim=keepDim).eraseRank();
+                when "mean" do
+                    return this.forceRank(rank).mean(axes,keepDim=keepDim).eraseRank();
+            }
+    halt("Could not determine rank in dynamicTensor." + opName + ".");
+    return new dynamicTensor(eltType);
+}
+
+inline proc dynamicTensor.reduceOpNoAxes(param opName: string, param keepDim: bool): dynamicTensor(eltType) {
+    for param rank in 1..maxRank do
+        if this.checkRank(rank) then
+            select opName {
+                when "sum" do
+                    return this.forceRank(rank).sum(keepDim=keepDim).eraseRank();
+                when "mean" do
+                    return this.forceRank(rank).mean(keepDim=keepDim).eraseRank();
+            }
+    halt("Could not determine rank in dynamicTensor." + opName + ".");
+    return new dynamicTensor(eltType);
+}
+
+proc dynamicTensor.sum(axes: ?axesCount*int, param keepDim: bool): dynamicTensor(eltType) do
+    return this.reduceOpAxes("sum",axes,keepDim);
+
+proc dynamicTensor.sum(param keepDim: bool = true): dynamicTensor(eltType) do
+    return this.reduceOpNoAxes("sum",keepDim);
+
+proc dynamicTensor.sum(axes: int...?axesCount): dynamicTensor(eltType) do
+    return this.sum(axes,keepDim=true);
+
+proc dynamicTensor.mean(axes: ?axesCount*int, param keepDim: bool): dynamicTensor(eltType) do
+    return this.reduceOpAxes("mean",axes,keepDim);
+
+proc dynamicTensor.mean(param keepDim: bool = true): dynamicTensor(eltType) do
+    return this.reduceOpNoAxes("mean",keepDim);
+
+proc dynamicTensor.mean(axes: int...?axesCount): dynamicTensor(eltType) do
+    return this.mean(axes,keepDim=true);
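
The two helpers above lean on the library's rank-dispatch idiom: a `for param` loop unrolls at compile time, and whichever unrolled branch matches the runtime rank returns. A standalone sketch of that idiom; the proc and names are illustrative, not library code:

proc describeRank(runtimeRank: int): string {
    // The loop body is stamped out once per candidate rank at compile time,
    // so `r` is a compile-time constant in each branch.
    for param r in 1..3 do
        if r == runtimeRank then
            return "rank-" + (r : string);
    return "unsupported rank";
}

writeln(describeRank(2)); // rank-2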

proc dynamicTensor.relu(): dynamicTensor(eltType) {
    for param rank in 1..maxRank {
@@ -777,6 +784,126 @@ proc dynamicTensor.degenerateFlatten(): [] eltType {
    return new dynamicTensor(eltType);
}

record dynamicShape : serializable {

    var size: int;
    var sizes: [0..<size] int;

    proc init(shape: ?rank*int) {
        this.size = rank;
        init this;
        for param i in 0..<rank do
            this.sizes[i] = shape(i);
    }

    proc init(dt: staticTensor(?rank,?eltType)) do
        this.init(dt.shapeTuple());

    // Copy element-wise so the array's values (not its shape) become the sizes.
    proc init(sizes: [] int) {
        this.size = sizes.size;
        init this;
        for (i,s) in zip(0..<this.size, sizes) do
            this.sizes[i] = s;
    }

    proc init(sizes: list(int)) do
        this.init(sizes.toArray());

    proc checkRank(param rank: int): bool do
        return rank == size;

    proc toRankedShape(param rank: int): rank*int {
        var shape: rank*int;
        if this.checkRank(rank) {
            for param i in 0..<rank do
                shape(i) = this.sizes[i];
            return shape;
        }
        halt("dynamicShape does not have rank " + (rank : string) + ".");
        return shape;
    }

    proc toList(): list(int) do
        return new list(this.sizes);

    proc head: int do
        return this.sizes.first;

    proc tail: dynamicShape {
        var sizes = this.toList();
        sizes.remove(this.head);
        return new dynamicShape(sizes);
    }
}
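
A short usage sketch for the record above; the shape values are illustrative:

var ds = new dynamicShape((2,1,3));
writeln(ds.size);                  // 3
writeln(ds.toRankedShape(3));      // (2, 1, 3)
writeln(ds.head);                  // 2
writeln(ds.tail.toRankedShape(2)); // (1, 3)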

proc dynamicTensor.shape(): dynamicShape {
    for param rank in 1..maxRank do
        if this.checkRank(rank) then
            return new dynamicShape(this.forceRank(rank));
    halt("Could not determine rank in dynamicTensor.shape.");
    return new dynamicShape((0,));
}

proc dynamicTensor.reshape(dShape: dynamicShape): dynamicTensor(eltType) {
    for param rank in 1..maxRank do
        for param shapeRank in 1..maxRank do
            if this.checkRank(rank) && dShape.checkRank(shapeRank) then
                return this.forceRank(rank).reshape(dShape.toRankedShape(shapeRank)).eraseRank();
    halt("Could not determine rank in dynamicTensor.reshape.");
    return new dynamicTensor(eltType);
}

proc dynamicTensor.unsqueeze(dim: int): dynamicTensor(eltType) {
    for param rank in 1..maxRank do
        if this.checkRank(rank) then
            return this.forceRank(rank).unsqueeze(dim).eraseRank();
    halt("Could not determine rank in dynamicTensor.unsqueeze.");
    return new dynamicTensor(eltType);
}


proc dynamicTensor.squeeze(): dynamicTensor(eltType) {
    var dShape = this.shape();
    var newSizes = new list(int);
    // Keep only the dimensions whose size is not 1.
    for i in 0..<dShape.size do
        if dShape.sizes[i] != 1 then
            newSizes.pushBack(dShape.sizes[i]);
    var newDShape = new dynamicShape(newSizes);
    return this.reshape(newDShape);
}

proc dynamicTensor.squeeze(dim: int): dynamicTensor(eltType) {
    var dShape = this.shape();
    var newSizes = new list(int);
    // Drop dimension `dim` only if its size is 1; keep everything else.
    for i in 0..<dShape.size do
        if i != dim || dShape.sizes[i] != 1 then
            newSizes.pushBack(dShape.sizes[i]);
    var newDShape = new dynamicShape(newSizes);
    return this.reshape(newDShape);
}
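
Expected behavior of the two overloads, given the shape logic above (a sketch; `arange` usage follows the updated tests):

var t = dynamicTensor.arange(2,1,3);  // shape (2, 1, 3)
writeln(t.squeeze().shape().sizes);   // 2 3    -- every size-1 dim dropped
writeln(t.squeeze(1).shape().sizes);  // 2 3    -- dim 1 has size 1, so it is dropped
writeln(t.squeeze(0).shape().sizes);  // 2 1 3  -- dim 0 is not size 1; no-op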

// proc dynamicTensor.squeeze(dShape: dynamicShape): dynamicTensor(eltType) {
//     for param rank in 1..maxRank do
//         for param shapeRank in 1..maxRank do
//             if this.checkRank(rank) && dShape.checkRank(shapeRank) then
//                 return this.forceRank(rank).squeeze(dShape.toRankedShape(shapeRank)).eraseRank();
//     halt("Could not determine rank in dynamicTensor.squeeze.");
//     return new dynamicTensor(eltType);
// }

// proc dynamicTensor.squeeze(dShape: dynamicShape): dynamicTensor(eltType) {
//     if dShape.size == 1 then
//         return this.squeeze(dShape.head);
//     else
//         return (this.squeeze(dShape.tail)).squeeze(dShape.head);
// }


proc main() {

// Just some examples.
1 change: 0 additions & 1 deletion lib/NDArray.chpl
@@ -447,7 +447,6 @@ record ndarray : serializable {
        return dilated;
    }


    proc squeeze(param newRank: int): ndarray(newRank,eltType) where newRank < rank {
        // I think this will work: (a member of the Chapel team needs to review this)
        // I suspect heavy performance hits will happen when running this on CUDA.
52 changes: 21 additions & 31 deletions lib/StaticTensor.chpl
@@ -103,6 +103,12 @@ proc staticTensor.shapeArray(): [] int {
    return sa;
}

// Materialize this tensor's shape as a rank-length tuple, reading it on the
// tensor's device; dynamicShape.init over a staticTensor relies on this.
proc staticTensor.shapeTuple(): rank*int {
    var st: rank*int;
    on this.device do
        st = this.array.shape;
    return st;
}

proc tensorFromCtx(param rank: int, type eltType, ctx: ?ctxType): staticTensor(rank,eltType) {
    var newMeta = new owned TensorResource(eltType,rank,ctx);
    newMeta.forward();
@@ -192,18 +198,18 @@ operator ==(a: staticTensor(?rank,?eltType), b: staticTensor(rank,eltType)): bool {
    return flag;
}


-proc staticTensor.reshape(dom: domain(?)) {
-    param newRank = dom.rank;
-    var ctx = new reshapeOp(rank,newRank,eltType,dom.shape,meta);
-    return tensorFromCtx(newRank,eltType,ctx);
-}
-
-proc staticTensor.reshape(newShape: int ...?newRank) {
+proc staticTensor.reshape(newShape: ?newRank*int) {
     var ctx = new reshapeOp(rank,newRank,eltType,newShape,meta);
     return tensorFromCtx(newRank,eltType,ctx);
 }
+
+proc staticTensor.reshape(newShape: int ...?newRank) do
+    return this.reshape(newShape);
+
+proc staticTensor.reshape(dom: domain(?)) do
+    return this.reshape(dom.shape);
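
Net effect: the tuple overload is now canonical, and the variadic and domain forms forward to it. A sketch; the demo proc and element type are assumptions, not from the diff:

proc reshapeDemo(t: staticTensor(2,real)) {
    // Assumes t holds 6 elements so each reshape conserves the element count.
    var a = t.reshape((3,2));           // canonical: shape tuple
    var b = t.reshape(3,2);             // variadic ints, forwards to the tuple overload
    var c = t.reshape({0..<3, 0..<2});  // domain, forwards via dom.shape
}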


proc staticTensor.relu() {
    var ctx = new reluOp(meta);
    return tensorFromCtx(rank,eltType,ctx);
@@ -354,41 +360,25 @@ proc staticTensor.slice(rngs: range...rank) {
    return tensorFromCtx(rank,eltType,ctx);
}

-proc staticTensor.sum(axes: ?axesCount * int, param keepDim: bool = true) {
-    if rank - axesCount < 0 then
-        compilerError("Cannot mean more axes than rank. ");
+proc staticTensor.sum(axes: ?axesCount * int, param keepDim: bool) {
+    if rank - axesCount < 0 && !keepDim then
+        compilerError("Cannot sum more axes than rank. ");
     var ctx = new sumOp(rank,eltType,axesCount,axes,meta,keepDim);
     return tensorFromCtx(ctx.outRank,eltType,ctx);
 }
 
-proc staticTensor.sum(axes: ?axesCount * int, keepDim: bool = true) {
-    if rank - axesCount < 0 then
-        compilerError("Cannot mean more axes than rank. ");
-    const bools = (true,false);
-    for param i in 0..<bools.size do
-        if keepDim then
-            return this.sum(axes,keepDim=true);
-        else
-            return this.sum(axes,keepDim=false);
-}
-
-proc staticTensor.sum(keepDim: bool = true) {
+proc staticTensor.sum(param keepDim: bool = true) {
     const axes = this.array.nDimTuple();
-    const bools = (true,false);
-    for param i in 0..<bools.size do
-        if keepDim then
-            return this.sum(axes,keepDim=true);
-        else
-            return this.sum(axes,keepDim=false);
+    return this.sum(axes,keepDim);
 }
 
 proc staticTensor.sum(axes: int...?axesCount) {
     return this.sum(axes,keepDim=true);
 }
 
-proc staticTensor.mean(axes: ?axesCount * int, param keepDim: bool = true) {
-    if rank - axesCount < 0 then
+proc staticTensor.mean(axes: ?axesCount * int, param keepDim: bool) {
+    if rank - axesCount < 0 && !keepDim then
         compilerError("Cannot mean more axes than rank. ");
     var ctx = new meanOp(rank,eltType,axesCount,axes,meta,keepDim);
     return tensorFromCtx(ctx.outRank,eltType,ctx);
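
Making `keepDim` a `param` is what retires the removed `const bools = (true,false)` trick: a reduction's output rank is part of its return type, so it must be computable at compile time. A minimal illustration, not library code:

proc outRankOf(param rank: int, param axesCount: int, param keepDim: bool) param do
    return if keepDim then rank else rank - axesCount;

writeln(outRankOf(3, 2, true));  // 3: reduced axes kept as size-1 dims
writeln(outRankOf(3, 2, false)); // 1: reduced axes dropped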
4 changes: 2 additions & 2 deletions test/correspondence/reduction/mean/mean.chpl
@@ -1,14 +1,14 @@
use Tensor;

-var a = Tensor.arange(2,3);
+var a = dynamicTensor.arange(2,3);

Testing.numericPrint(a.mean(0));

Testing.numericPrint(a.mean(1));

Testing.numericPrint(a.mean());

-var b = Tensor.arange(2,3,4);
+var b = dynamicTensor.arange(2,3,4);

Testing.numericPrint(b.mean(0));

4 changes: 2 additions & 2 deletions test/correspondence/reduction/sum/sum.chpl
@@ -1,14 +1,14 @@
use Tensor;

-var a = Tensor.arange(2,3);
+var a = dynamicTensor.arange(2,3);

Testing.numericPrint(a.sum(0));

Testing.numericPrint(a.sum(1));

Testing.numericPrint(a.sum());

-var b = Tensor.arange(2,3,4);
+var b = dynamicTensor.arange(2,3,4);

Testing.numericPrint(b.sum(0));
