diff --git a/src/main/java/org/apache/sysds/hops/estim/MMNode.java b/src/main/java/org/apache/sysds/hops/estim/MMNode.java index 89c706fd87e..7483dba0223 100644 --- a/src/main/java/org/apache/sysds/hops/estim/MMNode.java +++ b/src/main/java/org/apache/sysds/hops/estim/MMNode.java @@ -47,6 +47,15 @@ public MMNode(MatrixBlock in) { _op = null; _misc = null; } + + public MMNode(DataCharacteristics mc) { + _m1 = null; + _m2 = null; + _data = null; + _mc = mc; + _op = null; + _misc = null; + } public MMNode(MMNode left, MMNode right, OpCode op, long[] misc) { _m1 = left; @@ -112,7 +121,7 @@ public MMNode getRight() { } public boolean isLeaf() { - return _data != null; + return _op == null; } public MatrixBlock getData() { diff --git a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriteStatus.java b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriteStatus.java index 0c86aab59db..089d65e509e 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriteStatus.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriteStatus.java @@ -19,8 +19,6 @@ package org.apache.sysds.hops.rewrite; -import org.apache.sysds.runtime.controlprogram.LocalVariableMap; - public class ProgramRewriteStatus { //status of applied rewrites @@ -30,7 +28,6 @@ public class ProgramRewriteStatus //current context private boolean _inParforCtx = false; - private LocalVariableMap _vars = null; public ProgramRewriteStatus() { _rmBranches = false; @@ -38,11 +35,6 @@ public ProgramRewriteStatus() { _injectCheckpoints = false; } - public ProgramRewriteStatus(LocalVariableMap vars) { - this(); - _vars = vars; - } - public void setRemovedBranches(){ _rmBranches = true; } @@ -74,8 +66,4 @@ public void setInjectedCheckpoints(){ public boolean getInjectedCheckpoints(){ return _injectCheckpoints; } - - public LocalVariableMap getVariables() { - return _vars; - } } diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteMatrixMultChainOptimizationSparse.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteMatrixMultChainOptimizationSparse.java index 48b457f759d..2ae679d3fb5 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteMatrixMultChainOptimizationSparse.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteMatrixMultChainOptimizationSparse.java @@ -23,17 +23,11 @@ import java.util.List; import org.apache.commons.lang3.mutable.MutableInt; -import org.apache.sysds.common.Types.OpOpData; import org.apache.sysds.hops.Hop; -import org.apache.sysds.hops.HopsException; +import org.apache.sysds.hops.OptimizerUtils; import org.apache.sysds.hops.estim.MMNode; -import org.apache.sysds.hops.estim.EstimatorMatrixHistogram; -import org.apache.sysds.hops.estim.EstimatorMatrixHistogram.MatrixHistogram; +import org.apache.sysds.hops.estim.EstimatorBasicAvg; import org.apache.sysds.hops.estim.SparsityEstimator.OpCode; -import org.apache.sysds.runtime.controlprogram.LocalVariableMap; -import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; -import org.apache.sysds.runtime.instructions.cp.Data; -import org.apache.sysds.runtime.matrix.data.MatrixBlock; /** * Rule: Determine the optimal order of execution for a chain of @@ -53,9 +47,8 @@ protected void optimizeMMChain(Hop hop, List mmChain, List mmOperators double[] dimsArray = new double[mmChain.size() + 1]; boolean dimsKnown = getDimsArray( hop, mmChain, dimsArray ); MMNode[] sketchArray = new MMNode[mmChain.size() + 1]; - boolean inputsAvail = getInputMatrices(hop, mmChain, sketchArray, state); - - if( dimsKnown && inputsAvail ) { + boolean inputMetaAvail = getInputMatrixCharacteristics(hop, mmChain, sketchArray, state); + if(dimsKnown && inputMetaAvail) { // Step 3: clear the links among Hops within the identified chain clearLinksWithinChain ( hop, mmOperators ); @@ -92,7 +85,7 @@ private static int[][] mmChainDPSparse(double[] dimArray, MMNode[] sketchArray, } //compute cost-optimal chains for increasing chain sizes - EstimatorMatrixHistogram estim = new EstimatorMatrixHistogram(true); + EstimatorBasicAvg estim = new EstimatorBasicAvg(); for( int l = 2; l <= size; l++ ) { // chain length for( int i = 0; i < size - l + 1; i++ ) { int j = i + l - 1; @@ -102,14 +95,14 @@ private static int[][] mmChainDPSparse(double[] dimArray, MMNode[] sketchArray, { //construct estimation nodes (w/ lazy propagation and memoization) MMNode tmp = new MMNode(dpMatrixS[i][k], dpMatrixS[k+1][j], OpCode.MM); - estim.estim(tmp, false); - MatrixHistogram lhs = (MatrixHistogram) dpMatrixS[i][k].getSynopsis(); - MatrixHistogram rhs = (MatrixHistogram) dpMatrixS[k+1][j].getSynopsis(); - + estim.estim(tmp); + //recursive cost computation - double cost = dpMatrix[i][k] + dpMatrix[k + 1][j] - + dotProduct(lhs.getColCounts(), rhs.getRowCounts()); - + double cost = dpMatrix[i][k] + dpMatrix[k+1][j] + + OptimizerUtils.getSparsity(tmp.getLeft().getDataCharacteristics()) * + OptimizerUtils.getSparsity(tmp.getRight().getDataCharacteristics()) * + tmp.getLeft().getRows() * tmp.getLeft().getCols() * tmp.getRight().getCols(); + //prune suboptimal if( cost < dpMatrix[i][j] ) { dpMatrix[i][j] = cost; @@ -118,41 +111,29 @@ private static int[][] mmChainDPSparse(double[] dimArray, MMNode[] sketchArray, } } - if( LOG.isTraceEnabled() ){ - LOG.trace("mmchainopt [i="+(i+1)+",j="+(j+1)+"]: costs = "+dpMatrix[i][j]+", split = "+(split[i][j]+1)); - } + if(LOG.isTraceEnabled()) + LOG.trace("mmchainoptsparse [i="+(i+1)+",j="+(j+1)+"]: costs = "+dpMatrix[i][j]+", split = "+(split[i][j]+1)); } } return split; } - - private static boolean getInputMatrices(Hop hop, List chain, MMNode[] sketchArray, ProgramRewriteStatus state) { - boolean inputsAvail = true; - LocalVariableMap vars = state.getVariables(); - - for( int i=0; i chain, MMNode[] sketchArray, ProgramRewriteStatus state) { + boolean inputMetaAvail = true; + + for(int counter = 0; counter < chain.size(); counter++ ) { + Hop currentHop = chain.get(counter); + inputMetaAvail &= currentHop.isMatrix(); + inputMetaAvail &= !currentHop.isFederated(); + inputMetaAvail &= (currentHop.getDataCharacteristics().getNonZeros() != -1); + if(inputMetaAvail) { + sketchArray[counter] = new MMNode(currentHop.getDataCharacteristics()); + } + else break; } - - return inputsAvail; - } - - private static MatrixBlock getMatrix(String name, LocalVariableMap vars) { - Data dat = vars.get(name); - if( !(dat instanceof MatrixObject) ) - throw new HopsException("Input '"+name+"' not a matrix: "+dat.getDataType()); - return ((MatrixObject)dat).acquireReadAndRelease(); - } - - private static double dotProduct(int[] h1cNnz, int[] h2rNnz) { - long fp = 0; - for( int j=0; j x.contains(str)); } diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteMatrixMultChainOptSparseTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteMatrixMultChainOptSparseTest.java index 6f323b6aa99..f32cf52186c 100644 --- a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteMatrixMultChainOptSparseTest.java +++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteMatrixMultChainOptSparseTest.java @@ -19,6 +19,9 @@ package org.apache.sysds.test.functions.rewrite; +import org.apache.log4j.Level; +import org.apache.log4j.Logger; +import org.apache.log4j.spi.LoggingEvent; import org.apache.sysds.common.Opcodes; import org.apache.sysds.hops.OptimizerUtils; import org.apache.sysds.hops.recompile.Recompiler; @@ -26,26 +29,65 @@ import org.apache.sysds.test.AutomatedTestBase; import org.apache.sysds.test.TestConfiguration; import org.apache.sysds.test.TestUtils; +import org.apache.sysds.test.LoggingUtils; +import org.apache.sysds.test.LoggingUtils.TestAppender; + import org.junit.Assert; +import org.junit.runners.Parameterized; import org.junit.Test; +import org.junit.runner.RunWith; +import java.util.Arrays; +import java.util.Collection; import java.util.HashMap; +import java.util.List; +import java.util.stream.DoubleStream; +import java.util.stream.Stream; +@RunWith(value = Parameterized.class) +@net.jcip.annotations.NotThreadSafe public class RewriteMatrixMultChainOptSparseTest extends AutomatedTestBase { private static final String TEST_NAME = "RewriteMatrixMultChainOpSparse"; private static final String TEST_DIR = "functions/rewrite/"; private static final String TEST_CLASS_DIR = TEST_DIR + RewriteMatrixMultChainOptSparseTest.class.getSimpleName() + "/"; + private static final String PACKAGE = "org.apache.sysds.hops.rewrite.HopRewriteRule"; + private static Level _oldLevel; + + @Parameterized.Parameter(0) + public int rows; + + @Parameterized.Parameter(1) + public int cols; - private static final int rows = 1000; - private static final int cols = 300; - private static final double eps = Math.pow(10, -10); + @Parameterized.Parameter(2) + public double[] sparsities; + + @Parameterized.Parameter(3) + public double eps; + + @Parameterized.Parameters + public static Collection data() { + return Arrays.asList(new Object[][] { + // {rows, cols, sparsities, eps}, + {1000, 300, new double[]{0.10d, 0.10d}, Math.pow(10, -10)}, + // {2, 300, new double[]{0.005, 1}, Math.pow(10, -10)}, + }); + } @Override public void setUp() { TestUtils.clearAssertionInformation(); addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R"})); + _oldLevel = Logger.getLogger(PACKAGE).getLevel(); + Logger.getLogger(PACKAGE).setLevel(Level.TRACE); + } + + @Override + public void tearDown() { + super.tearDown(); + Logger.getLogger(PACKAGE).setLevel(_oldLevel); } @Test @@ -74,13 +116,19 @@ private void testRewriteMatrixMultChainOpSparse(boolean rewrites) { OptimizerUtils.ALLOW_ADVANCED_MMCHAIN_REWRITES = rewrites; OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites; - double[][] X = getRandomMatrix(rows, cols, -1, 1, 0.10d, 7); - double[][] Y = getRandomMatrix(cols, 1, -1, 1, 0.10d, 3); - writeInputMatrixWithMTD("X", X, true); - writeInputMatrixWithMTD("Y", Y, true); + double[][] X = getRandomMatrix(rows, cols, -1, 1, sparsities[0], 7); + double[][] Y = getRandomMatrix(cols, 1, -1, 1, sparsities[1], 3); + long X_nnz = Stream.of(X).mapToLong(row -> DoubleStream.of(row).filter(val -> val != 0).count()).sum(); + long Y_nnz = Stream.of(Y).mapToLong(row -> DoubleStream.of(row).filter(val -> val != 0).count()).sum(); + writeInputMatrixWithMTD("X", X, X_nnz, true); + writeInputMatrixWithMTD("Y", Y, Y_nnz, true); + //execute tests + TestAppender appender = LoggingUtils.overwrite(); // capture log output runTest(true, false, null, -1); + List log_out = LoggingUtils.reinsert(appender); // revert the logger to print to stdout + runRScript(true); //compare matrices @@ -89,12 +137,16 @@ private void testRewriteMatrixMultChainOpSparse(boolean rewrites) { TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); if(rewrites) { + Assert.assertTrue(log_out.stream().anyMatch( + l -> l.getMessage().toString().contains("mmchainoptsparse"))); Assert.assertTrue(heavyHittersContainsSubString(Opcodes.MMCHAIN.toString()) || heavyHittersContainsSubString("sp_mapmmchain")); } else { + Assert.assertFalse(log_out.stream().anyMatch( + l -> l.getMessage().toString().contains("mmchainoptsparse"))); Assert.assertFalse(heavyHittersContainsSubString(Opcodes.MMCHAIN.toString()) || - heavyHittersContainsSubString("sp_mapmmchain")); + heavyHittersContainsSubString("sp_mapmmchain")); } } finally {