Skip to content

Commit

Permalink
[SYSTEMDS-2748] CLA TSMM new version multithreaded
Browse files Browse the repository at this point in the history
This commit makes the TSMM parallel again. While parallelizing, the
underlying preAggregate instructions have also been optimized, reducing
the number of instructions (as measured by Perf) from:
Infini-Mnist:
4 Trillion -> 2.7 Trillion ( 430 sec single thread -> 60 sec Multi)
Census:
0.3 Trillion -> 0.3 Trillion ( 22 sec -> 20 sec)

Next step is to optimize the cost function subject to the new performance
characteristics of compressed matrix multiplication.

Additionally:

- Fix a bug in materialize sort when the number of elements was 256 (byte)
- Fix CPAggregateBinaryInstruction to not release matrix block twice

Closes #1543
  • Loading branch information
Baunsgaard committed Feb 11, 2022
1 parent 2439f75 commit 2fffa15
Show file tree
Hide file tree
Showing 40 changed files with 2,109 additions and 601 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.apache.commons.logging.LogFactory;
import org.apache.commons.math3.random.Well1024a;
import org.apache.sysds.common.Types.CorrectionLocationType;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.lops.MMTSJ.MMTSJType;
import org.apache.sysds.lops.MapMultChain.ChainType;
import org.apache.sysds.runtime.DMLRuntimeException;
Expand Down Expand Up @@ -507,9 +508,11 @@ public MatrixBlock transposeSelfMatrixMultOperations(MatrixBlock out, MMTSJType

@Override
public MatrixBlock replaceOperations(MatrixValue result, double pattern, double replacement) {
if(isOverlapping())
if(isOverlapping()){

return getUncompressed("replaceOperations " + pattern + " -> " + replacement).replaceOperations(result,
pattern, replacement);
}
else {
CompressedMatrixBlock ret = new CompressedMatrixBlock(getNumRows(), getNumColumns());
final List<AColGroup> prev = getColGroups();
Expand Down Expand Up @@ -934,9 +937,12 @@ public MatrixBlock getUncompressed() {
MatrixBlock d_compressed = getCachedDecompressed();
if(d_compressed != null)
return d_compressed;
if(isEmpty())
else if(isEmpty())
return new MatrixBlock(getNumRows(), getNumColumns(), true);
return this.decompress(InfrastructureAnalyzer.getLocalParallelism());
else if(ConfigurationManager.isParallelMatrixOperations())
return this.decompress(InfrastructureAnalyzer.getLocalParallelism());
else
return this.decompress(1);
}

public MatrixBlock getUncompressed(String operation) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,20 @@ protected static List<CompressedSizeInfoColGroup> join(List<CompressedSizeInfoCo
private static List<CompressedSizeInfoColGroup> coCodeBruteForce(List<CompressedSizeInfoColGroup> inputColumns,
ICostEstimate cEst, Memorizer mem, int k) {

List<ColIndexes> workSet = new ArrayList<>(inputColumns.size());

final List<ColIndexes> workSet = new ArrayList<>(inputColumns.size());
final boolean workloadCost = cEst instanceof ComputationCostEstimator;

// assume that we can at max reduce 90 % of cost if joined
// assume that we can max reduce 65% of compute cost if joined
final double costFilterThreshold = (workloadCost ? 0.65 : 0.9);

for(int i = 0; i < inputColumns.size(); i++)
workSet.add(new ColIndexes(inputColumns.get(i).getColumns()));

parallelFirstJoin(workSet, mem, cEst, k);
if(k > 1)
parallelFirstJoin(workSet, mem, cEst, costFilterThreshold, k);

// process merging iterations until no more change
// Process merging iterations until no more change
while(workSet.size() > 1) {
double changeInCost = 0;
CompressedSizeInfoColGroup tmp = null;
Expand All @@ -87,22 +91,20 @@ private static List<CompressedSizeInfoColGroup> coCodeBruteForce(List<Compressed
// it still does not improve compression.
// In the case of workload we relax the requirement for the filter.
// if(-Math.min(costC1, costC2) > changeInCost)
if(-Math.min(costC1, costC2) * (workloadCost ? 0.7 : 1) > changeInCost)
if(-Math.min(costC1, costC2) * costFilterThreshold > changeInCost)
continue;

// Join the two column groups.
// and Memorize the new join.
final CompressedSizeInfoColGroup c1c2Inf = mem.getOrCreate(c1, c2);
final double costC1C2 = cEst.getCostOfColumnGroup(c1c2Inf);

final double newSizeChangeIfSelected = costC1C2 - costC1 - costC2;
final double newCostIfJoined = costC1C2 - costC1 - costC2;

// Select the best join of either the currently selected
// or keep the old one.
if((tmp == null && newSizeChangeIfSelected < changeInCost) ||
tmp != null && (newSizeChangeIfSelected < changeInCost || newSizeChangeIfSelected == changeInCost &&
c1c2Inf.getColumns().length < tmp.getColumns().length)) {
changeInCost = newSizeChangeIfSelected;
if((tmp == null && newCostIfJoined < changeInCost) || tmp != null && (newCostIfJoined < changeInCost ||
newCostIfJoined == changeInCost && c1c2Inf.getColumns().length < tmp.getColumns().length)) {
changeInCost = newCostIfJoined;
tmp = c1c2Inf;
selected1 = c1;
selected2 = c2;
Expand All @@ -119,35 +121,27 @@ private static List<CompressedSizeInfoColGroup> coCodeBruteForce(List<Compressed
else
break;
}

if(LOG.isDebugEnabled())
LOG.debug("Memorizer stats:" + mem.stats());
mem.resetStats();

List<CompressedSizeInfoColGroup> ret = new ArrayList<>(workSet.size());

for(ColIndexes w : workSet)
ret.add(mem.get(w));

return ret;
}

protected static void parallelFirstJoin(List<ColIndexes> workSet, Memorizer mem, ICostEstimate cEst, int k) {
protected static void parallelFirstJoin(List<ColIndexes> workSet, Memorizer mem, ICostEstimate cEst,
double costFilterThreshold, int k) {
try {
ExecutorService pool = CommonThreadPool.get(k);
List<JoinTask> tasks = new ArrayList<>();
for(int i = 0; i < workSet.size(); i++)
for(int j = i + 1; j < workSet.size(); j++) {
final ColIndexes c1 = workSet.get(i);
final ColIndexes c2 = workSet.get(j);

final int csi1 = mem.get(c1).getNumVals();
final int csi2 = mem.get(c2).getNumVals();

if(csi1 * csi2 > 10000)
continue;

final ExecutorService pool = CommonThreadPool.get(k);
final List<JoinTask> tasks = new ArrayList<>();
final int size = workSet.size();
for(int i = 0; i < size; i++)
for(int j = i + 1; j < size; j++)
tasks.add(new JoinTask(workSet.get(i), workSet.get(j), mem));
}

for(Future<Object> t : pool.invokeAll(tasks))
t.get();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,18 +123,19 @@ public final void leftMultByMatrix(MatrixBlock matrix, MatrixBlock result, int r
public final Dictionary preAggregateThatIndexStructure(APreAgg that) {
int outputLength = that._colIndexes.length * this.getNumValues();
Dictionary ret = new Dictionary(new double[outputLength]);
String cThis = this.getClass().getSimpleName();
String cThat = that.getClass().getSimpleName();

if(that instanceof ColGroupDDC)
preAggregateThatDDCStructure((ColGroupDDC) that, ret);
else if(that instanceof ColGroupSDCSingleZeros)
preAggregateThatSDCSingleZerosStructure((ColGroupSDCSingleZeros) that, ret);
else if(that instanceof ColGroupSDCZeros)
preAggregateThatSDCZerosStructure((ColGroupSDCZeros) that, ret);
else
else {
final String cThis = this.getClass().getSimpleName();
final String cThat = that.getClass().getSimpleName();
throw new NotImplementedException(
"Not supported pre aggregate using index structure of :" + cThat + " in " + cThis);
}
return ret;
}

Expand All @@ -150,19 +151,10 @@ else if(that instanceof ColGroupSDCZeros)

protected abstract boolean sameIndexStructure(AColGroupCompressed that);

public int getPreAggregateSize(){
public int getPreAggregateSize() {
return getNumValues();
}


private final ADictionary preAggLeft(APreAgg lhs) {
return lhs.preAggregateThatIndexStructure(this);
}

private final ADictionary preAggRight(APreAgg lhs) {
return this.preAggregateThatIndexStructure(lhs);
}

private void tsmmAColGroupValue(APreAgg lg, MatrixBlock result) {
final int[] rightIdx = this._colIndexes;
final int[] leftIdx = lg._colIndexes;
Expand All @@ -179,10 +171,12 @@ private void tsmmAColGroupValue(APreAgg lg, MatrixBlock result) {
boolean left = !shouldPreAggregateLeft(lg);
if(left) {
l = lg._dict.getValues();
r = preAggLeft(lg).getValues();
// leftAgg
r = lg.preAggregateThatIndexStructure(this).getValues();
}
else {
l = preAggRight(lg).getValues();
// rightAgg
l = this.preAggregateThatIndexStructure(lg).getValues();
r = _dict.getValues();
}
MMDenseToUpperTriangle(l, r, leftIdx, rightIdx, result);
Expand All @@ -203,10 +197,10 @@ private void leftMultByColGroupValue(APreAgg lhs, MatrixBlock result) {
MMDictsWithScaling(lDict, rDict, leftIdx, rightIdx, result, c);
}
else {
if(shouldPreAggregateLeft(lhs))
MMDicts(lDict, preAggLeft(lhs), leftIdx, rightIdx, result);
else
MMDicts(preAggRight(lhs), rDict, leftIdx, rightIdx, result);
if(shouldPreAggregateLeft(lhs)) // left preAgg
MMDicts(lDict, lhs.preAggregateThatIndexStructure(this), leftIdx, rightIdx, result);
else // right preAgg
MMDicts(this.preAggregateThatIndexStructure(lhs), rDict, leftIdx, rightIdx, result);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,32 +180,23 @@ private void preAggregateSparse(SparseBlock sb, MatrixBlock preAgg, int rl, int

@Override
public void preAggregateThatDDCStructure(ColGroupDDC that, Dictionary ret) {
_data.preAggregateDDC(that._data, that._dict, ret, that._colIndexes.length);
_data.preAggregateDDC_DDC(that._data, that._dict, ret, that._colIndexes.length);
}

@Override
public void preAggregateThatSDCZerosStructure(ColGroupSDCZeros that, Dictionary ret) {
final AIterator itThat = that._indexes.getIterator();
final int nCol = that._colIndexes.length;
final int finalOff = that._indexes.getOffsetToLast();
while(true) {
final int to = _data.getIndex(itThat.value());
final int fr = that._data.getIndex(itThat.getDataIndex());
that._dict.addToEntry(ret, fr, to, nCol);
if(itThat.value() == finalOff)
break;
itThat.next();
}
_data.preAggregateDDC_SDCZ(that._data, that._dict, that._indexes, ret, that._colIndexes.length);
}

@Override
public void preAggregateThatSDCSingleZerosStructure(ColGroupSDCSingleZeros that, Dictionary ret) {
final AIterator itThat = that._indexes.getIterator();
final int nCol = that._colIndexes.length;
final int finalOff = that._indexes.getOffsetToLast();
final double[] v = ret.getValues();
while(true) {
final int to = _data.getIndex(itThat.value());
that._dict.addToEntry(ret, 0, to, nCol);
that._dict.addToEntry(v, 0, to, nCol);
if(itThat.value() == finalOff)
break;
itThat.next();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ private static AColGroup directCompressDDCColGroup(int[] colIndexes, MatrixBlock
final int fill = data.getUpperBoundValue();
data.fill(fill);

DblArrayCountHashMap map = new DblArrayCountHashMap(cg.getNumVals());
DblArrayCountHashMap map = new DblArrayCountHashMap(cg.getNumVals(), colIndexes.length);
boolean extra;
if(rlen < CompressionSettings.PAR_DDC_THRESHOLD || k == 1)
extra = readToMapDDC(colIndexes, raw, map, cs, data, 0, rlen, fill);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -422,10 +422,11 @@ public void preAggregateThatDDCStructure(ColGroupDDC that, Dictionary ret) {
final AIterator itThis = _indexes.getIterator();
final int nCol = that._colIndexes.length;
final int finalOffThis = _indexes.getOffsetToLast();
final double[] v = ret.getValues();

while(true) {
final int fr = that._data.getIndex(itThis.value());
that._dict.addToEntry(ret, fr, 0, nCol);
that._dict.addToEntry(v, fr, 0, nCol);
if(itThis.value() >= finalOffThis)
break;
else
Expand All @@ -440,10 +441,11 @@ public void preAggregateThatSDCZerosStructure(ColGroupSDCZeros that, Dictionary
final int nCol = that._colIndexes.length;
final int finalOffThis = _indexes.getOffsetToLast();
final int finalOffThat = that._indexes.getOffsetToLast();
final double[] v = ret.getValues();

while(true) {
if(itThat.value() == itThis.value()) {
that._dict.addToEntry(ret, that._data.getIndex(itThat.getDataIndex()), 0, nCol);
that._dict.addToEntry(v, that._data.getIndex(itThat.getDataIndex()), 0, nCol);
if(itThat.value() >= finalOffThat)
break;
else
Expand Down Expand Up @@ -475,10 +477,11 @@ public void preAggregateThatSDCSingleZerosStructure(ColGroupSDCSingleZeros that,
final AIterator itThis = _indexes.getIterator();
final int finalOffThis = _indexes.getOffsetToLast();
final int finalOffThat = that._indexes.getOffsetToLast();
final double[] v = ret.getValues();

while(true) {
if(itThat.value() == itThis.value()) {
that._dict.addToEntry(ret, 0, 0, nCol);
that._dict.addToEntry(v, 0, 0, nCol);
if(itThat.value() >= finalOffThat)
break;
else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -514,57 +514,12 @@ public boolean sameIndexStructure(AColGroupCompressed that) {

@Override
public void preAggregateThatDDCStructure(ColGroupDDC that, Dictionary ret) {
final AIterator itThis = _indexes.getIterator();
final int nCol = that._colIndexes.length;

final int finalOffThis = _indexes.getOffsetToLast();
while(true) {
final int fr = that._data.getIndex(itThis.value());
final int to = _data.getIndex(itThis.getDataIndex());
that._dict.addToEntry(ret, fr, to, nCol);
if(itThis.value() >= finalOffThis)
break;
else
itThis.next();
}
_data.preAggregateSDCZ_DDC(that._data, that._dict, _indexes, ret, that._colIndexes.length);
}

@Override
public void preAggregateThatSDCZerosStructure(ColGroupSDCZeros that, Dictionary ret) {
final AIterator itThat = that._indexes.getIterator();
final AIterator itThis = _indexes.getIterator();

final int finalOffThis = _indexes.getOffsetToLast();
final int finalOffThat = that._indexes.getOffsetToLast();

final int nCol = that._colIndexes.length;
while(true) {
if(itThat.value() == itThis.value()) {
final int fr = that._data.getIndex(itThat.getDataIndex());
final int to = _data.getIndex(itThis.getDataIndex());
that._dict.addToEntry(ret, fr, to, nCol);
if(itThat.value() >= finalOffThat)
break;
else
itThat.next();
if(itThis.value() >= finalOffThis)
break;
else
itThis.next();
}
else if(itThat.value() < itThis.value()) {
if(itThat.value() >= finalOffThat)
break;
else
itThat.next();
}
else {
if(itThis.value() >= finalOffThis)
break;
else
itThis.next();
}
}
_data.preAggregateSDCZ_SDCZ(that._data, that._dict, that._indexes, _indexes, ret, that._colIndexes.length);
}

@Override
Expand All @@ -575,11 +530,12 @@ public void preAggregateThatSDCSingleZerosStructure(ColGroupSDCSingleZeros that,

final int finalOffThis = _indexes.getOffsetToLast();
final int finalOffThat = that._indexes.getOffsetToLast();
final double[] v = ret.getValues();

while(true) {
if(itThat.value() == itThis.value()) {
final int to = _data.getIndex(itThis.getDataIndex());
that._dict.addToEntry(ret, 0, to, nCol);
that._dict.addToEntry(v, 0, to, nCol);
if(itThat.value() >= finalOffThat)
break;
else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -457,24 +457,18 @@ public abstract ADictionary binOpRightWithReference(BinaryOperator op, double[]
*/
public abstract long getNumberNonZerosWithReference(int[] counts, double[] reference, int nRows);

/**
* Single column version of copy add to entry.
*
* @param d target dictionary to add to
* @param fr Take from this index
* @param to put into index in d.
*/
public abstract void addToEntry(Dictionary d, int fr, int to);

/**
* Copies and adds the dictionary entry from this dictionary to the d dictionary
*
* @param d the target dictionary
* @param v the target dictionary (dense double array)
* @param fr the from index
* @param to the to index
* @param nCol the number of columns
*/
public abstract void addToEntry(Dictionary d, int fr, int to, int nCol);
public abstract void addToEntry(double[] v, int fr, int to, int nCol);

public abstract void addToEntryVectorized(double[] v, int f1, int f2, int f3, int f4, int f5, int f6, int f7, int f8, int t1,
int t2, int t3, int t4, int t5, int t6, int t7, int t8, int nCol);

/**
* Allocate a new dictionary where the tuple given is subtracted from all tuples in the previous dictionary.
Expand Down
Loading

0 comments on commit 2fffa15

Please sign in to comment.