Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion src/main/java/org/apache/sysds/hops/estim/MMNode.java
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,15 @@ public MMNode(MatrixBlock in) {
_op = null;
_misc = null;
}

/**
 * Creates a node that carries only data characteristics (metadata).
 * No child nodes, no matrix data, no operation, and no misc values are
 * attached; all of those fields are set to null.
 *
 * @param mc data characteristics stored as this node's metadata
 */
public MMNode(DataCharacteristics mc) {
	// the only populated field: the given metadata
	_mc = mc;
	// no matrix block backs this node
	_data = null;
	// no children and no operation
	_m1 = null;
	_m2 = null;
	_op = null;
	_misc = null;
}

public MMNode(MMNode left, MMNode right, OpCode op, long[] misc) {
_m1 = left;
Expand Down Expand Up @@ -112,7 +121,7 @@ public MMNode getRight() {
}

/**
 * Indicates whether this node is a leaf, i.e., a node without an operation.
 * <p>
 * The check is based on the operation rather than on the presence of data:
 * a node constructed from data characteristics alone has no data, so a
 * data-based check would not recognize it as a leaf. Both single-argument
 * constructors set the operation to null.
 *
 * @return true if no operation is associated with this node, false otherwise
 */
public boolean isLeaf() {
	return _op == null;
}

public MatrixBlock getData() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@

package org.apache.sysds.hops.rewrite;

import org.apache.sysds.runtime.controlprogram.LocalVariableMap;

public class ProgramRewriteStatus
{
//status of applied rewrites
Expand All @@ -30,19 +28,13 @@ public class ProgramRewriteStatus

//current context
private boolean _inParforCtx = false;
private LocalVariableMap _vars = null;

/**
 * Creates a status object with all flags cleared: no removed branches,
 * no injected checkpoints, and not inside a parfor context.
 */
public ProgramRewriteStatus() {
	this._rmBranches = false;
	this._inParforCtx = false;
	this._injectCheckpoints = false;
}

/**
 * Creates a status object with all flags cleared (via the no-argument
 * constructor) and stores the given variable map.
 *
 * @param vars variable map later returned by {@code getVariables()}
 */
public ProgramRewriteStatus(LocalVariableMap vars) {
	this();
	this._vars = vars;
}

/**
 * Sets the removed-branches flag to true.
 */
public void setRemovedBranches() {
	this._rmBranches = true;
}
Expand Down Expand Up @@ -74,8 +66,4 @@ public void setInjectedCheckpoints(){
/**
 * Returns the current value of the injected-checkpoints flag.
 *
 * @return the injected-checkpoints flag
 */
public boolean getInjectedCheckpoints() {
	return this._injectCheckpoints;
}

/**
 * Returns the variable map held by this status object.
 *
 * @return the stored variable map; null unless one was passed at construction
 */
public LocalVariableMap getVariables() {
	return this._vars;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,11 @@
import java.util.List;

import org.apache.commons.lang3.mutable.MutableInt;
import org.apache.sysds.common.Types.OpOpData;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.HopsException;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.estim.MMNode;
import org.apache.sysds.hops.estim.EstimatorMatrixHistogram;
import org.apache.sysds.hops.estim.EstimatorMatrixHistogram.MatrixHistogram;
import org.apache.sysds.hops.estim.EstimatorBasicAvg;
import org.apache.sysds.hops.estim.SparsityEstimator.OpCode;
import org.apache.sysds.runtime.controlprogram.LocalVariableMap;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;

/**
* Rule: Determine the optimal order of execution for a chain of
Expand All @@ -53,9 +47,8 @@ protected void optimizeMMChain(Hop hop, List<Hop> mmChain, List<Hop> mmOperators
double[] dimsArray = new double[mmChain.size() + 1];
boolean dimsKnown = getDimsArray( hop, mmChain, dimsArray );
MMNode[] sketchArray = new MMNode[mmChain.size() + 1];
boolean inputsAvail = getInputMatrices(hop, mmChain, sketchArray, state);

if( dimsKnown && inputsAvail ) {
boolean inputMetaAvail = getInputMatrixCharacteristics(hop, mmChain, sketchArray, state);
if(dimsKnown && inputMetaAvail) {
// Step 3: clear the links among Hops within the identified chain
clearLinksWithinChain ( hop, mmOperators );

Expand Down Expand Up @@ -92,7 +85,7 @@ private static int[][] mmChainDPSparse(double[] dimArray, MMNode[] sketchArray,
}

//compute cost-optimal chains for increasing chain sizes
EstimatorMatrixHistogram estim = new EstimatorMatrixHistogram(true);
EstimatorBasicAvg estim = new EstimatorBasicAvg();
for( int l = 2; l <= size; l++ ) { // chain length
for( int i = 0; i < size - l + 1; i++ ) {
int j = i + l - 1;
Expand All @@ -102,14 +95,14 @@ private static int[][] mmChainDPSparse(double[] dimArray, MMNode[] sketchArray,
{
//construct estimation nodes (w/ lazy propagation and memoization)
MMNode tmp = new MMNode(dpMatrixS[i][k], dpMatrixS[k+1][j], OpCode.MM);
estim.estim(tmp, false);
MatrixHistogram lhs = (MatrixHistogram) dpMatrixS[i][k].getSynopsis();
MatrixHistogram rhs = (MatrixHistogram) dpMatrixS[k+1][j].getSynopsis();

estim.estim(tmp);

//recursive cost computation
double cost = dpMatrix[i][k] + dpMatrix[k + 1][j]
+ dotProduct(lhs.getColCounts(), rhs.getRowCounts());

double cost = dpMatrix[i][k] + dpMatrix[k+1][j] +
OptimizerUtils.getSparsity(tmp.getLeft().getDataCharacteristics()) *
OptimizerUtils.getSparsity(tmp.getRight().getDataCharacteristics()) *
tmp.getLeft().getRows() * tmp.getLeft().getCols() * tmp.getRight().getCols();

//prune suboptimal
if( cost < dpMatrix[i][j] ) {
dpMatrix[i][j] = cost;
Expand All @@ -118,41 +111,29 @@ private static int[][] mmChainDPSparse(double[] dimArray, MMNode[] sketchArray,
}
}

if( LOG.isTraceEnabled() ){
LOG.trace("mmchainopt [i="+(i+1)+",j="+(j+1)+"]: costs = "+dpMatrix[i][j]+", split = "+(split[i][j]+1));
}
if(LOG.isTraceEnabled())
LOG.trace("mmchainoptsparse [i="+(i+1)+",j="+(j+1)+"]: costs = "+dpMatrix[i][j]+", split = "+(split[i][j]+1));
}
}

return split;
}

private static boolean getInputMatrices(Hop hop, List<Hop> chain, MMNode[] sketchArray, ProgramRewriteStatus state) {
boolean inputsAvail = true;
LocalVariableMap vars = state.getVariables();

for( int i=0; i<chain.size(); i++ ) {
inputsAvail &= HopRewriteUtils.isData(chain.get(0), OpOpData.TRANSIENTREAD);
if( inputsAvail )
sketchArray[i] = new MMNode(getMatrix(chain.get(i).getName(), vars));
else

private static boolean getInputMatrixCharacteristics(Hop hop, List<Hop> chain, MMNode[] sketchArray, ProgramRewriteStatus state) {
boolean inputMetaAvail = true;

for(int counter = 0; counter < chain.size(); counter++ ) {
Hop currentHop = chain.get(counter);
inputMetaAvail &= currentHop.isMatrix();
inputMetaAvail &= !currentHop.isFederated();
inputMetaAvail &= (currentHop.getDataCharacteristics().getNonZeros() != -1);
if(inputMetaAvail) {
sketchArray[counter] = new MMNode(currentHop.getDataCharacteristics());
}
else
break;
}

return inputsAvail;
}

private static MatrixBlock getMatrix(String name, LocalVariableMap vars) {
Data dat = vars.get(name);
if( !(dat instanceof MatrixObject) )
throw new HopsException("Input '"+name+"' not a matrix: "+dat.getDataType());
return ((MatrixObject)dat).acquireReadAndRelease();
}

private static double dotProduct(int[] h1cNnz, int[] h2rNnz) {
long fp = 0;
for( int j=0; j<h1cNnz.length; j++ )
fp += (long)h1cNnz[j] * h2rNnz[j];
return fp;

return inputMetaAvail;
}
}
2 changes: 2 additions & 0 deletions src/test/java/org/apache/sysds/test/AutomatedTestBase.java
Original file line number Diff line number Diff line change
Expand Up @@ -2261,6 +2261,8 @@ public void tearDown() {
}

/**
 * Checks whether any line of the buffered output contains the given text.
 * The buffer content is split on {@code "\n"} and each resulting line is
 * searched for {@code str} as a substring.
 *
 * @param buffer captured output; a null buffer yields false
 * @param str    substring to look for within each line
 * @return true if at least one line contains {@code str}, false otherwise
 */
public boolean bufferContainsString(ByteArrayOutputStream buffer, String str) {
	if(buffer != null) {
		String[] outputLines = buffer.toString().split("\n");
		for(String line : outputLines) {
			if(line.contains(str))
				return true;
		}
	}
	return false;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,33 +19,75 @@

package org.apache.sysds.test.functions.rewrite;

import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.log4j.spi.LoggingEvent;
import org.apache.sysds.common.Opcodes;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.recompile.Recompiler;
import org.apache.sysds.runtime.matrix.data.MatrixValue;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
import org.apache.sysds.test.LoggingUtils;
import org.apache.sysds.test.LoggingUtils.TestAppender;

import org.junit.Assert;
import org.junit.runners.Parameterized;
import org.junit.Test;
import org.junit.runner.RunWith;

import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.stream.DoubleStream;
import java.util.stream.Stream;

@RunWith(value = Parameterized.class)
@net.jcip.annotations.NotThreadSafe
public class RewriteMatrixMultChainOptSparseTest extends AutomatedTestBase {

private static final String TEST_NAME = "RewriteMatrixMultChainOpSparse";
private static final String TEST_DIR = "functions/rewrite/";
private static final String TEST_CLASS_DIR =
TEST_DIR + RewriteMatrixMultChainOptSparseTest.class.getSimpleName() + "/";
private static final String PACKAGE = "org.apache.sysds.hops.rewrite.HopRewriteRule";
private static Level _oldLevel;

@Parameterized.Parameter(0)
public int rows;

@Parameterized.Parameter(1)
public int cols;

private static final int rows = 1000;
private static final int cols = 300;
private static final double eps = Math.pow(10, -10);
@Parameterized.Parameter(2)
public double[] sparsities;

@Parameterized.Parameter(3)
public double eps;

@Parameterized.Parameters
public static Collection<Object[]> data() {
return Arrays.asList(new Object[][] {
// {rows, cols, sparsities, eps},
{1000, 300, new double[]{0.10d, 0.10d}, Math.pow(10, -10)},
// {2, 300, new double[]{0.005, 1}, Math.pow(10, -10)},

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since you changed the test to be parameterized, enable the other test case as well.

});
}

/**
 * Prepares a test run: clears recorded assertion information, registers
 * the test configuration with output "R", and switches the logger for
 * {@code PACKAGE} to TRACE after remembering its previous level in
 * {@code _oldLevel}.
 */
@Override
public void setUp() {
	TestUtils.clearAssertionInformation();

	TestConfiguration config = new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R"});
	addTestConfiguration(TEST_NAME, config);

	// remember the current level before raising verbosity
	Logger rewriteLogger = Logger.getLogger(PACKAGE);
	_oldLevel = rewriteLogger.getLevel();
	rewriteLogger.setLevel(Level.TRACE);
}

/**
 * Runs the inherited tear-down and then sets the logger for
 * {@code PACKAGE} back to the level stored in {@code _oldLevel}.
 */
@Override
public void tearDown() {
	super.tearDown();

	Logger rewriteLogger = Logger.getLogger(PACKAGE);
	rewriteLogger.setLevel(_oldLevel);
}

@Test
Expand Down Expand Up @@ -74,13 +116,19 @@ private void testRewriteMatrixMultChainOpSparse(boolean rewrites) {

OptimizerUtils.ALLOW_ADVANCED_MMCHAIN_REWRITES = rewrites;
OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites;
double[][] X = getRandomMatrix(rows, cols, -1, 1, 0.10d, 7);
double[][] Y = getRandomMatrix(cols, 1, -1, 1, 0.10d, 3);
writeInputMatrixWithMTD("X", X, true);
writeInputMatrixWithMTD("Y", Y, true);
double[][] X = getRandomMatrix(rows, cols, -1, 1, sparsities[0], 7);
double[][] Y = getRandomMatrix(cols, 1, -1, 1, sparsities[1], 3);
long X_nnz = Stream.of(X).mapToLong(row -> DoubleStream.of(row).filter(val -> val != 0).count()).sum();
long Y_nnz = Stream.of(Y).mapToLong(row -> DoubleStream.of(row).filter(val -> val != 0).count()).sum();
writeInputMatrixWithMTD("X", X, X_nnz, true);
writeInputMatrixWithMTD("Y", Y, Y_nnz, true);


//execute tests
TestAppender appender = LoggingUtils.overwrite(); // capture log output
runTest(true, false, null, -1);
List<LoggingEvent> log_out = LoggingUtils.reinsert(appender); // revert the logger to print to stdout

runRScript(true);

//compare matrices
Expand All @@ -89,12 +137,16 @@ private void testRewriteMatrixMultChainOpSparse(boolean rewrites) {
TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R");

if(rewrites) {
Assert.assertTrue(log_out.stream().anyMatch(
l -> l.getMessage().toString().contains("mmchainoptsparse")));
Comment on lines +140 to +141

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be good to check more of the logging string.

Assert.assertTrue(heavyHittersContainsSubString(Opcodes.MMCHAIN.toString()) ||
heavyHittersContainsSubString("sp_mapmmchain"));
}
else {
Assert.assertFalse(log_out.stream().anyMatch(
l -> l.getMessage().toString().contains("mmchainoptsparse")));
Assert.assertFalse(heavyHittersContainsSubString(Opcodes.MMCHAIN.toString()) ||
heavyHittersContainsSubString("sp_mapmmchain"));
heavyHittersContainsSubString("sp_mapmmchain"));
}
}
finally {
Expand Down
Loading