From 800b0f73c63fe4c1e21e1f7b12e3e234bd2f23da Mon Sep 17 00:00:00 2001
From: Jessica Priebe
Date: Fri, 26 Jun 2026 14:56:37 +0200
Subject: [PATCH] add reshape

---
Review notes (below the "---" marker, so not part of the commit message):
 - ReshapeOOCInstruction: "(int) x % blen" casts the long to int before the
   modulo is taken; rewritten as "(int) (x % blen)" so that dimensions beyond
   the int range still yield the correct edge-block size.
 - ReshapeTest: the input matrix is written with the blen constant instead of
   a literal 1000, so the binary file and its metadata use the same value.

 .../sysds/runtime/data/DenseBlockFP64.java    |  20 +
 .../instructions/OOCInstructionParser.java    |   3 +-
 .../instructions/ooc/CachingStream.java       |   1 +
 .../instructions/ooc/OOCInstruction.java      |   2 +-
 .../instructions/ooc/ReorgOOCInstruction.java |  42 +-
 .../ooc/ReshapeOOCInstruction.java            | 434 ++++++++++++++++++
 .../sysds/test/functions/ooc/ReshapeTest.java | 160 +++++++
 .../functions/ooc/MatrixReshapeColWise.dml    |  26 ++
 .../functions/ooc/MatrixReshapeRowWise.dml    |  26 ++
 9 files changed, 673 insertions(+), 41 deletions(-)
 create mode 100644 src/main/java/org/apache/sysds/runtime/instructions/ooc/ReshapeOOCInstruction.java
 create mode 100644 src/test/java/org/apache/sysds/test/functions/ooc/ReshapeTest.java
 create mode 100644 src/test/scripts/functions/ooc/MatrixReshapeColWise.dml
 create mode 100644 src/test/scripts/functions/ooc/MatrixReshapeRowWise.dml

diff --git a/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64.java b/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64.java
index 94909444198..0a734261b9f 100644
--- a/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64.java
+++ b/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64.java
@@ -181,6 +181,26 @@ public DenseBlock set(int rl, int ru, int ol, int ou, DenseBlock db) {
 		return this;
 	}
 
+	public DenseBlock setPartialRow(DenseBlock row, int rIdx, int srcOffset, int destOffset, int length) {
+		if(destOffset + length > _odims[0])
+			throw new RuntimeException(
+				"Partial row assignment exceeds row length: " + (destOffset + length) + " > " + _odims[0]);
+		System.arraycopy(row.valuesAt(0), srcOffset, _data, this.pos(rIdx, destOffset), length);
+		return this;
+	}
+
+	public DenseBlock setPartialCol(DenseBlock col, int cIdx, int srcOffset, int destOffset, int length) {
+		if(destOffset + length > _rlen)
+			throw new RuntimeException(
+				"Partial column assignment exceeds column length: " + (destOffset + length) + " > " + _rlen);
+		int destPos = this.pos(destOffset, cIdx);
+		double[] src = col.valuesAt(0);
+		for(int i = 0; i < length; i++) {
+			_data[destPos + i * _odims[0]] = src[srcOffset + i];
+		}
+		return this;
+	}
+
 	@Override
 	public DenseBlock set(int r, double[] v) {
 		System.arraycopy(v, 0, _data, pos(r), _odims[0]);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java
index ae41639687b..98a454283e2 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java
@@ -43,6 +43,7 @@
 import org.apache.sysds.runtime.instructions.ooc.ReorgOOCInstruction;
 import org.apache.sysds.runtime.instructions.ooc.TeeOOCInstruction;
 import org.apache.sysds.runtime.instructions.ooc.AppendOOCInstruction;
+import org.apache.sysds.runtime.instructions.ooc.ReshapeOOCInstruction;
 import org.apache.sysds.runtime.instructions.ooc.QuaternaryOOCInstruction;
 
 public class OOCInstructionParser extends InstructionParser {
@@ -97,7 +98,7 @@ else if(parts.length == 4)
 			case Reorg:
 				return ReorgOOCInstruction.parseInstruction(str);
 			case Reshape:
-				return ReorgOOCInstruction.parseInstruction(str);
+				return ReshapeOOCInstruction.parseInstruction(str);
 			case Tee:
 				return TeeOOCInstruction.parseInstruction(str);
 			case CentralMoment:
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java
index 56c265fe5e6..38929dcafcc 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java
@@ -99,6 +99,7 @@ public CachingStream(OOCStream source, long streamId) {
 			// Capture a short context to help identify origin
 			OOCWatchdog.registerOpen(_watchdogId, toString(), getCtxMsg(), this);
 		}
+		activateIndexing();
 		_downstreamRelays = null;
 		source.setSubscriber(tmp -> {
 			try(tmp) {
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
index 859bca42dfe..80d71231646 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
@@ -80,7 +80,7 @@ public abstract class OOCInstruction extends Instruction {
 
 	public enum OOCType {
 		Reblock, Tee, Binary, Ternary, Unary, AggregateUnary, AggregateBinary, AggregateTernary, MAPMM, MMTSJ,
-		MAPMMCHAIN, Reorg, CM, Ctable, MatrixIndexing, ParameterizedBuiltin, Rand, Append, Quaternary
+		MAPMMCHAIN, Reorg, CM, Ctable, MatrixIndexing, ParameterizedBuiltin, Rand, Append, Quaternary, Reshape
 	}
 
 	protected final OOCInstruction.OOCType _ooctype;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReorgOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReorgOOCInstruction.java
index cf77d559727..273d33341ab 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReorgOOCInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReorgOOCInstruction.java
@@ -41,31 +41,17 @@ public class ReorgOOCInstruction extends ComputationOOCInstruction {
 	private final CPOperand _col;
 	private final CPOperand _desc;
 	private final CPOperand _ixret;
-	// reshape-specific attributes
-	private final CPOperand _opRows;
-	private final CPOperand _opCols;
-	//private final CPOperand _opDims;
-	private final CPOperand _opByRow;
 
 	protected ReorgOOCInstruction(ReorgOperator op, CPOperand in1, CPOperand out, String opcode, String istr) {
-		this(op, in1, out, null, null, null, null, null, null, null, opcode, istr);
-	}
-
-	private ReorgOOCInstruction(Operator op, CPOperand in, CPOperand out, CPOperand opRows, CPOperand opCols,
-		CPOperand opDims, CPOperand opByRow, String opcode, String istr) {
-		this(op, in, out, null, null, null, opRows, opCols, opDims, opByRow, opcode, istr);
+		this(op, in1, out, null, null, null, opcode, istr);
 	}
 
 	private ReorgOOCInstruction(Operator op, CPOperand in, CPOperand out, CPOperand col, CPOperand desc, CPOperand ixret,
-		CPOperand opRows, CPOperand opCols, CPOperand opDims, CPOperand opByRow, String opcode, String istr) {
+		String opcode, String istr) {
 		super(OOCType.Reorg, op, in, out, opcode, istr);
 		_col = col;
 		_desc = desc;
 		_ixret = ixret;
-		_opRows = opRows;
-		_opCols = opCols;
-		//_opDims = opDims;
-		_opByRow = opByRow;
 	}
 
 	public static ReorgOOCInstruction parseInstruction(String str) {
@@ -92,35 +78,13 @@ else if(opcode.equalsIgnoreCase(Opcodes.SORT.toString())) {
 			CPOperand ixret = new CPOperand(parts[4]);
 			int k = Integer.parseInt(parts[6]);
 			return new ReorgOOCInstruction(new ReorgOperator(new SortIndex(1, false, false), k),
-				in, out, col, desc, ixret, null, null, null, null, opcode, str);
-		}
-		else if(opcode.equalsIgnoreCase(Opcodes.RESHAPE.toString())) {
-			InstructionUtils.checkNumFields(parts, 6);
-			in.split(parts[1]);
-			CPOperand rows = new CPOperand(parts[2]);
-			CPOperand cols = new CPOperand(parts[3]);
-			CPOperand dims = new CPOperand(parts[4]);
-			CPOperand byRow = new CPOperand(parts[5]);
-			out.split(parts[6]);
-			return new ReorgOOCInstruction(new Operator(true), in, out, rows, cols, dims, byRow, opcode, str);
+				in, out, col, desc, ixret, opcode, str);
 		}
 		else
 			throw new NotImplementedException();
 	}
 
 	public void processInstruction( ExecutionContext ec ) {
-		if(getOpcode().equalsIgnoreCase(Opcodes.RESHAPE.toString())) {
-			// TODO Make reshape truly out-of-core
-			int rows = (int) ec.getScalarInput(_opRows).getLongValue();
-			int cols = (int) ec.getScalarInput(_opCols).getLongValue();
-			boolean byRow = ec.getScalarInput(_opByRow).getBooleanValue();
-			MatrixBlock in = ec.getMatrixInput(input1.getName());
-			MatrixBlock out = in.reshape(rows, cols, byRow);
-			ec.releaseMatrixInput(input1.getName());
-			ec.setMatrixOutput(output.getName(), out);
-			return;
-		}
-
 		// Create thread and process the transpose/sort operation
 		MatrixObject min = ec.getMatrixObject(input1);
 		ReorgOperator r_op = (ReorgOperator) _optr;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReshapeOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReshapeOOCInstruction.java
new file mode 100644
index 00000000000..4fdca081a77
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReshapeOOCInstruction.java
@@ -0,0 +1,434 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.ooc;
+
+import org.apache.sysds.common.Opcodes;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.data.DenseBlockFP64;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
+import org.apache.sysds.runtime.matrix.operators.Operator;
+
+import java.util.ArrayList;
+import java.util.concurrent.CompletableFuture;
+
+public class ReshapeOOCInstruction extends ComputationOOCInstruction {
+	private final CPOperand _opRows;
+	private final CPOperand _opCols;
+	// private final CPOperand _opDims;
+	private final CPOperand _opByRow;
+
+	private ReshapeOOCInstruction(Operator op, CPOperand in, CPOperand out, CPOperand rows, CPOperand cols,
+		CPOperand dims, CPOperand byRow, String opcode, String istr) {
+		super(OOCType.Reshape, op, in, out, opcode, istr);
+		_opRows = rows;
+		_opCols = cols;
+		// _opDims = dims;
+		_opByRow = byRow;
+	}
+
+	public static ReshapeOOCInstruction parseInstruction(String str) {
+		String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+		InstructionUtils.checkNumFields(parts, 6);
+		String opcode = parts[0];
+
+		if(!opcode.equalsIgnoreCase(Opcodes.RESHAPE.toString()))
+			throw new DMLRuntimeException("Unknown opcode while parsing ReshapeInstruction: " + str);
+
+		CPOperand in = new CPOperand(parts[1]);
+		CPOperand rows = new CPOperand(parts[2]);
+		CPOperand cols = new CPOperand(parts[3]);
+		CPOperand dims = new CPOperand(parts[4]);
+		CPOperand byRow = new CPOperand(parts[5]);
+		CPOperand out = new CPOperand(parts[6]);
+
+		return new ReshapeOOCInstruction(new Operator(true), in, out, rows, cols, dims, byRow, opcode, str);
+	}
+
+	public void processInstruction(ExecutionContext ec) {
+		long rows = ec.getScalarInput(_opRows).getLongValue();
+		long cols = ec.getScalarInput(_opCols).getLongValue();
+		boolean byRow = ec.getScalarInput(_opByRow).getBooleanValue();
+
+		OOCStream qOut = createWritableStream();
+		ec.getMatrixObject(output).setStreamHandle(qOut);
+
+		MatrixObject in = ec.getMatrixObject(input1);
+		OOCStream qIn = in.getStreamHandle();
+		int blen = in.getBlocksize();
+		long rlen = in.getNumRows();
+		long clen = in.getNumColumns();
+
+		if(rlen * clen != rows * cols)
+			throw new DMLRuntimeException("Reshape matrix requires consistent numbers of input/output cells (" + rlen
+				+ ":" + clen + ", " + rows + ":" + cols + ").");
+
+		if(rlen == rows) {
+			mapOOC(qIn, qOut, tmp -> tmp);
+			return;
+		}
+
+		if(clen <= blen && rlen <= blen && cols <= blen && rows <= blen) {
+			mapOOC(qIn, qOut, tmp -> {
+				MatrixBlock res = ((MatrixBlock) tmp.getValue()).reshape((int) rows, (int) cols, byRow);
+				return new IndexedMatrixValue(tmp.getIndexes(), res);
+			});
+			return;
+		}
+
+		int numBlocksPerRowIn = (int) Math.ceil((double) clen / blen);
+		int numBlocksPerColIn = (int) Math.ceil((double) rlen / blen);
+		int numBlocksPerRowOut = (int) Math.ceil((double) cols / blen);
+		int numBlocksPerColOut = (int) Math.ceil((double) rows / blen);
+
+		if(byRow) {
+			OOCStream singleRowBlocks = new SubscribableTaskQueue<>();
+			// split blocks into single rows and adapt index
+			CompletableFuture f = expandOOC(qIn, singleRowBlocks, tmp -> {
+				ArrayList out = new ArrayList<>();
+				MatrixBlock blk = (MatrixBlock) tmp.getValue();
+				for(int i = 0; i < blk.getNumRows(); i++) {
+					MatrixBlock slice = blk.slice(i, i);
+					long r = tmp.getIndexes().getRowIndex();
+					long c = tmp.getIndexes().getColumnIndex();
+					r = (r - 1) * blen + i + 1;
+					MatrixIndexes idx = new MatrixIndexes(r, c);
+					out.add(new IndexedMatrixValue(idx, slice));
+				}
+				return out;
+			});
+
+			if(clen % blen == 0 && cols % blen == 0) {
+				// singleRowBlocks do not need to be split
+				if(rows == 1) {
+					// result is one single row
+					mapOOC(singleRowBlocks.getReadStream(), qOut, tmp -> {
+						long r = tmp.getIndexes().getRowIndex();
+						long c = tmp.getIndexes().getColumnIndex();
+						// adapt index to new position in row
+						return new IndexedMatrixValue(new MatrixIndexes(1, (r - 1) * numBlocksPerRowIn + c), tmp.getValue());
+					});
+				}
+				else {
+					f.join();
+					reshapeFullColBlocks(rows, cols, blen, numBlocksPerRowIn, numBlocksPerRowOut, numBlocksPerColOut, singleRowBlocks, qOut);
+				}
+			}
+			else {
+				f.join();
+				reshapePartialColBlocks(rlen, clen, rows, cols, blen, numBlocksPerRowIn, numBlocksPerRowOut, numBlocksPerColOut, singleRowBlocks, qOut);
+			}
+		}
+		else {
+			OOCStream singleColBlocks = new SubscribableTaskQueue<>();
+			// split blocks into single cols and adapt index
+			CompletableFuture f = expandOOC(qIn, singleColBlocks, tmp -> {
+				ArrayList out = new ArrayList<>();
+				MatrixBlock blk = (MatrixBlock) tmp.getValue();
+				for(int i = 0; i < blk.getNumColumns(); i++) {
+					MatrixBlock slice = blk.slice(0, blk.getNumRows() - 1, i, i);
+					long r = tmp.getIndexes().getRowIndex();
+					long c = tmp.getIndexes().getColumnIndex();
+					c = (c - 1) * blen + i + 1;
+					MatrixIndexes idx = new MatrixIndexes(r, c);
+					out.add(new IndexedMatrixValue(idx, slice));
+				}
+				return out;
+			});
+
+			if(rlen % blen == 0 && rows % blen == 0) {
+				// cols do not need to be split
+				if(cols == 1) {
+					// result is one single col
+					mapOOC(singleColBlocks.getReadStream(), qOut, tmp -> {
+						long r = tmp.getIndexes().getRowIndex();
+						long c = tmp.getIndexes().getColumnIndex();
+						// adapt index to new position in col
+						return new IndexedMatrixValue(new MatrixIndexes((c - 1) * numBlocksPerColIn + r, 1), tmp.getValue());
+					});
+				}
+				else {
+					f.join();
+					reshapeFullRowBlocks(rows, cols, blen, numBlocksPerRowOut, numBlocksPerColIn, numBlocksPerColOut, singleColBlocks, qOut);
+				}
+			}
+			else {
+				f.join();
+				reshapePartialRowBlocks(rlen, clen, rows, cols, blen, numBlocksPerRowOut, numBlocksPerColIn, numBlocksPerColOut, singleColBlocks, qOut);
+			}
+		}
+	}
+
+	private void reshapeFullColBlocks(long rows, long cols, int blen, int numBlocksPerRowIn, int numBlocksPerRowOut,
+		int numBlocksPerColOut, OOCStream singleRowBlocks, OOCStream qOut) {
+		// use cache for accessing input rows by index
+		CachingStream singleRowBlockCache = new CachingStream(singleRowBlocks);
+		singleRowBlockCache.incrSubscriberCount(1);
+		singleRowBlockCache.scheduleDeletion();
+
+		// totalRowIdx corresponds to index of row block when all aligned in one row
+		// br * numBlocksPerRowOut * blen + b + r * numBlocksPerRowOut;
+		// with numBlocksPerRowOut * blen = cols
+		long totalIdx = -cols - 1 - numBlocksPerRowOut;
+
+		// iterate through rows of output blocks
+		for(int br = 0; br < numBlocksPerColOut; br++) {
+			totalIdx += cols;
+			long tmp = totalIdx;
+			// for each block in row
+			for(int b = 0; b < numBlocksPerRowOut; b++) {
+				totalIdx += 1;
+				int localRows = (br == numBlocksPerColOut - 1 && rows % blen != 0) ? (int) (rows % blen) : blen;
+				MatrixBlock res = new MatrixBlock(localRows, blen, false);
+				long tmp2 = totalIdx;
+				// for each row in block
+				for(int r = 0; r < blen && r < localRows; r++) {
+					totalIdx += numBlocksPerRowOut;
+					// calc col idx for input
+					long colBlockIn = totalIdx % numBlocksPerRowIn + 1;
+					// calc row idx for input
+					long rowBlockIn = totalIdx / numBlocksPerRowIn + 1;
+
+					try(OOCStream.QueueCallback cb = singleRowBlockCache
+						.findCached(new MatrixIndexes(rowBlockIn, colBlockIn))) {
+						MatrixBlock blk = (MatrixBlock) cb.get().getValue();
+						res.setRow(r, blk.getDenseBlockValues());
+					}
+				}
+				totalIdx = tmp2;
+				qOut.enqueue(new IndexedMatrixValue(new MatrixIndexes(br + 1, b + 1), res));
+			}
+			totalIdx = tmp;
+		}
+		qOut.closeInput();
+	}
+
+	private void reshapePartialColBlocks(long rlen, long clen, long rows, long cols, int blen, int numBlocksPerRowIn,
+		int numBlocksPerRowOut, int numBlocksPerColOut, OOCStream singleRowBlocks, OOCStream qOut) {
+		// use cache for accessing input rows by index
+		CachingStream singleRowBlockCache = new CachingStream(singleRowBlocks);
+		singleRowBlockCache.incrSubscriberCount(1);
+		singleRowBlockCache.scheduleDeletion();
+
+		int br = 0;
+		int bc = 0;
+		int r = 0;
+
+		// allocate row of output blocks
+		MatrixBlock[] outputBlockRow = allocateSliceBlocks(br, rows, cols, blen, numBlocksPerRowOut, numBlocksPerColOut,true);
+
+		int offsetOut = 0;
+		int localColsOut = (cols > blen) ? blen : (int) cols;
+		// iterate through input rows and add to row of output blocks
+		for(int i = 1; i <= rlen; i++) {
+			for(int j = 1; j <= numBlocksPerRowIn; j++) {
+				try(OOCStream.QueueCallback qcb = singleRowBlockCache.findCached(new MatrixIndexes(i, j))) {
+					MatrixBlock blk = (MatrixBlock) qcb.get().getValue();
+
+					int offsetIn = 0;
+					int localColsIn = (j == numBlocksPerRowIn && clen % blen != 0) ? (int) (clen % blen) : blen;
+					while(offsetIn < localColsIn) {
+						// until input row fully processed
+						int remIn = localColsIn - offsetIn;
+						int remOut = localColsOut - offsetOut;
+						if(remIn < remOut) {
+							// next input
+							setOutputEntries(blk, outputBlockRow[bc], r, offsetIn, offsetOut, remIn, true);
+							offsetIn += remIn;
+							offsetOut += remIn;
+							continue;
+						}
+						else if(remIn == remOut) {
+							// next input and next row
+							setOutputEntries(blk, outputBlockRow[bc], r, offsetIn, offsetOut, remIn, true);
+							offsetIn += remIn;
+						}
+						else {
+							// next row
+							setOutputEntries(blk, outputBlockRow[bc], r, offsetIn, offsetOut, remOut, true);
+							offsetIn += remOut;
+						}
+						bc++;
+						offsetOut = 0;
+						if(bc == numBlocksPerRowOut) {
+							// next row
+							r++;
+							if(r == outputBlockRow[0].getNumRows()) {
+								// enqueue filled output blocks and allocate new ones
+								for(int b = 0; b < outputBlockRow.length; b++)
+									qOut.enqueue(new IndexedMatrixValue(new MatrixIndexes(br + 1, b + 1), outputBlockRow[b]));
+								br++;
+								// allocate new block row
+								outputBlockRow = allocateSliceBlocks(br, rows, cols, blen, numBlocksPerRowOut, numBlocksPerColOut, true);
+								r = 0;
+							}
+							bc = 0;
+						}
+						localColsOut = (bc == numBlocksPerRowOut - 1 && cols % blen != 0) ? (int) (cols % blen) : blen;
+					}
+				}
+			}
+		}
+		qOut.closeInput();
+	}
+
+	private void reshapeFullRowBlocks(long rows, long cols, int blen, int numBlocksPerRowOut, int numBlocksPerColIn,
+		int numBlocksPerColOut, OOCStream singleColBlocks, OOCStream qOut) {
+		// use cache for accessing input cols by index
+		CachingStream singleColBlockCache = new CachingStream(singleColBlocks);
+		singleColBlockCache.incrSubscriberCount(1);
+		singleColBlockCache.scheduleDeletion();
+
+		// totalColIdx corresponds to index of col block when all aligned in one col
+		// bc * numBlocksPerColOut * blen + b + c * numBlocksPerColOut;
+		// with numBlocksPerColOut * blen = rows
+		long totalIdx = -rows - 1 - numBlocksPerColOut;
+
+		// iterate through cols of output blocks
+		for(int bc = 0; bc < numBlocksPerRowOut; bc++) {
+			totalIdx += rows;
+			long tmp = totalIdx;
+			// for each block in col
+			for(int b = 0; b < numBlocksPerColOut; b++) {
+				totalIdx += 1;
+				int localCols = (bc == numBlocksPerRowOut - 1 && cols % blen != 0) ? (int) (cols % blen) : blen;
+				MatrixBlock res = new MatrixBlock(blen, localCols, false);
+				res.allocateDenseBlock();
+				long tmp2 = totalIdx;
+				// for each col in block
+				for(int c = 0; c < blen && c < localCols; c++) {
+					totalIdx += numBlocksPerColOut;
+					// calc col idx for input
+					long colBlockIn = totalIdx / numBlocksPerColIn + 1;
+					// calc row idx for input
+					long rowBlockIn = totalIdx % numBlocksPerColIn + 1;
+
+					try(OOCStream.QueueCallback cb = singleColBlockCache
+						.findCached(new MatrixIndexes(rowBlockIn, colBlockIn))) {
+						MatrixBlock blk = (MatrixBlock) cb.get().getValue();
+						res.getDenseBlock().set(0, blen, c, c + 1, blk.getDenseBlock());
+					}
+				}
+				totalIdx = tmp2;
+				qOut.enqueue(new IndexedMatrixValue(new MatrixIndexes(b + 1, bc + 1), res));
+			}
+			totalIdx = tmp;
+		}
+		qOut.closeInput();
+	}
+
+	private void reshapePartialRowBlocks(long rlen, long clen, long rows, long cols, int blen, int numBlocksPerRowOut,
+		int numBlocksPerColIn, int numBlocksPerColOut, OOCStream singleColBlocks, OOCStream qOut) {
+		// use cache for accessing input cols by index
+		CachingStream singleRowBlockCache = new CachingStream(singleColBlocks);
+		singleRowBlockCache.incrSubscriberCount(1);
+		singleRowBlockCache.scheduleDeletion();
+
+		int br = 0;
+		int bc = 0;
+		int c = 0;
+
+		// allocate col of output blocks
+		MatrixBlock[] outputBlockCol = allocateSliceBlocks(bc, rows, cols, blen, numBlocksPerRowOut, numBlocksPerColOut, false);
+
+		int offsetOut = 0;
+		int localRowsOut = (rows > blen) ? blen : (int) rows;
+		// iterate through input cols and add to col of output blocks
+		for(int j = 1; j <= clen; j++) {
+			for(int i = 1; i <= numBlocksPerColIn; i++) {
+				try(OOCStream.QueueCallback qcb = singleRowBlockCache.findCached(new MatrixIndexes(i, j))) {
+					MatrixBlock blk = (MatrixBlock) qcb.get().getValue();
+
+					int offsetIn = 0;
+					int localRowsIn = (i == numBlocksPerColIn && rlen % blen != 0) ? (int) (rlen % blen) : blen;
+					while(offsetIn < localRowsIn) {
+						// until input col fully processed
+						int remIn = localRowsIn - offsetIn;
+						int remOut = localRowsOut - offsetOut;
+						if(remIn < remOut) {
+							// next input
+							setOutputEntries(blk, outputBlockCol[br], c, offsetIn, offsetOut, remIn, false);
+							offsetIn += remIn;
+							offsetOut += remIn;
+							continue;
+						}
+						else if(remIn == remOut) {
+							// next input and next col
+							setOutputEntries(blk, outputBlockCol[br], c, offsetIn, offsetOut, remIn, false);
+							offsetIn += remIn;
+						}
+						else {
+							// next col
+							setOutputEntries(blk, outputBlockCol[br], c, offsetIn, offsetOut, remOut, false);
+							offsetIn += remOut;
+						}
+						br++;
+						offsetOut = 0;
+						if(br == numBlocksPerColOut) {
+							// next col
+							c++;
+							if(c == outputBlockCol[0].getNumColumns()) {
+								// enqueue filled output blocks and allocate new ones
+								for(int b = 0; b < outputBlockCol.length; b++)
+									qOut.enqueue(new IndexedMatrixValue(new MatrixIndexes(b + 1, bc + 1), outputBlockCol[b]));
+								bc++;
+								// allocate new block col
+								outputBlockCol = allocateSliceBlocks(bc, rows, cols, blen, numBlocksPerRowOut, numBlocksPerColOut, false);
+								c = 0;
+							}
+							br = 0;
+						}
+						localRowsOut = (br == numBlocksPerColOut - 1 && rows % blen != 0) ? (int) (rows % blen) : blen;
+					}
+				}
+			}
+		}
+		qOut.closeInput();
+	}
+
+	private MatrixBlock[] allocateSliceBlocks(int idx, long rows, long cols, int blen, int numBlocksPerRow, int numBlocksPerCol, boolean isBlockRowSlice) {
+		int num = isBlockRowSlice ? numBlocksPerRow : numBlocksPerCol;
+		MatrixBlock[] res = new MatrixBlock[num];
+
+		// full inner blocks, adjust for outer blocks (modulo is taken on the long value before narrowing to int)
+		int localRows = ((!isBlockRowSlice || idx == numBlocksPerCol - 1) && rows % blen != 0) ? (int) (rows % blen) : blen;
+		int localCols = ((isBlockRowSlice || idx == numBlocksPerRow - 1) && cols % blen != 0) ? (int) (cols % blen) : blen;
+
+		for(int k = 0; k < num - 1; k++) {
+			res[k] = isBlockRowSlice ? new MatrixBlock(localRows, blen, false) : new MatrixBlock(blen, localCols, false);
+			res[k].allocateDenseBlock();
+		}
+		res[num - 1] = new MatrixBlock(localRows, localCols, false);
+		res[num - 1].allocateDenseBlock();
+		return res;
+	}
+
+	private void setOutputEntries(MatrixBlock src, MatrixBlock dest, int idx, int srcOffset, int destOffset, int length, boolean rowWise) {
+		if(rowWise)
+			((DenseBlockFP64) dest.getDenseBlock()).setPartialRow(src.getDenseBlock(), idx, srcOffset, destOffset, length);
+		else
+			((DenseBlockFP64) dest.getDenseBlock()).setPartialCol(src.getDenseBlock(), idx, srcOffset, destOffset, length);
+	}
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/ReshapeTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/ReshapeTest.java
new file mode 100644
index 00000000000..770c5b7c5bf
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/ooc/ReshapeTest.java
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.ooc;
+
+import org.apache.sysds.common.Opcodes;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.instructions.Instruction;
+import org.apache.sysds.runtime.io.MatrixWriter;
+import org.apache.sysds.runtime.io.MatrixWriterFactory;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.DataConverter;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import java.util.ArrayList;
+
+@RunWith(Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class ReshapeTest extends AutomatedTestBase {
+	private final static String TEST_NAME1 = "MatrixReshapeRowWise";
+	private final static String TEST_NAME2 = "MatrixReshapeColWise";
+	private final static String TEST_DIR = "functions/ooc/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + ReshapeTest.class.getSimpleName() + "/";
+	private static final String INPUT_NAME = "X";
+	private static final String OUTPUT_NAME = "Y";
+	private static final double eps = 1e-8;
+	private static final int blen = 1000;
+
+	private final int rlen;
+	private final int clen;
+	private final int rows;
+	private final int cols;
+	private final boolean rowWise;
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1));
+		addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2));
+	}
+
+	public ReshapeTest(int rlen, int clen, int rows, int cols, boolean rowWise) {
+		this.rlen = rlen;
+		this.clen = clen;
+		this.rows = rows;
+		this.cols = cols;
+		this.rowWise = rowWise;
+	}
+
+	@Parameterized.Parameters(name = "{0}x{1} {2}x{3} rowWise {4}")
+	public static Iterable getParams() {
+
+		int[][][] dims = {
+			{{1000, 1000}, {1, 1000000}}, // single row/col
+			{{3000, 4000}, {1500, 8000}}, // partialBlocks
+			{{2400, 1400}, {800, 4200}} // fullBlocks
+		};
+
+		ArrayList params = new ArrayList<>();
+
+		for(int[][] d : dims) {
+			params.add(new Object[] {d[0][0], d[0][1], d[1][0], d[1][1], true});
+			params.add(new Object[] {d[1][0], d[1][1], d[0][0], d[0][1], true});
+
+			params.add(new Object[] {d[0][1], d[0][0], d[1][1], d[1][0], false});
+			params.add(new Object[] {d[1][1], d[1][0], d[0][1], d[0][0], false});
+		}
+
+		for(boolean rowWise : new boolean[] {true, false}) {
+			// single block
+			params.add(new Object[] {400, 300, 300, 400, rowWise});
+			// non matching dims
+			params.add(new Object[] {1400, 1000, 5000, 1, rowWise});
+			// no change
+			params.add(new Object[] {300, 400, 300, 400, rowWise});
+		}
+
+		return params;
+	}
+
+	@Test
+	public void runTestMatrixReshapeOOC() {
+		ExecMode platformOld = setExecMode(ExecMode.SINGLE_NODE);
+
+		try {
+			String TEST_NAME = (rowWise) ? TEST_NAME1 : TEST_NAME2;
+			getAndLoadTestConfiguration(TEST_NAME);
+
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + TEST_NAME + ".dml";
+
+			double[][] X = getRandomMatrix(rlen, clen, 0, 1, 1, 7);
+			MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY);
+			writer.writeMatrixToHDFS(DataConverter.convertToMatrixBlock(X), input(INPUT_NAME), rlen, clen, blen, rlen * clen);
+			HDFSTool.writeMetaDataFile(input(INPUT_NAME + ".mtd"), Types.ValueType.FP64,
+				new MatrixCharacteristics(rlen, clen, blen, rlen * clen), Types.FileFormat.BINARY);
+
+			programArgs = new String[] {"-explain", "-stats", "-ooc", "-args", input(INPUT_NAME), String.valueOf(rlen),
+				String.valueOf(clen), String.valueOf(rows), String.valueOf(cols), output(OUTPUT_NAME)};
+
+			if(rlen * clen != rows * cols) {
+				runTest(true, true, DMLRuntimeException.class, -1);
+				return;
+			}
+
+			runTest(true, false, null, -1);
+			if(rlen != rows)
+				Assert.assertTrue("OOC wasn't used for reshape",
+					heavyHittersContainsString(Instruction.OOC_INST_PREFIX + Opcodes.RESHAPE));
+			else
+				Assert.assertTrue("OOC RBLK wasn't used for unchanged dimensions",
+					heavyHittersContainsString(Instruction.OOC_INST_PREFIX + Opcodes.RBLK));
+
+			// rerun without ooc flag
+			programArgs = new String[] {"-explain", "-stats", "-args", input(INPUT_NAME), String.valueOf(rlen),
+				String.valueOf(clen), String.valueOf(rows), String.valueOf(cols), output(OUTPUT_NAME + "_target")};
+			runTest(true, false, null, -1);
+
+			// compare results
+			MatrixBlock actual = DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME),
+				Types.FileFormat.BINARY, rows, cols, blen);
+			MatrixBlock expected = DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME + "_target"),
+				Types.FileFormat.BINARY, rows, cols, blen);
+
+			TestUtils.compareMatrices(expected, actual, eps);
+		}
+		catch(Exception ex) {
+			Assert.fail(ex.getMessage());
+		}
+		finally {
+			rtplatform = platformOld;
+		}
+	}
+}
diff --git a/src/test/scripts/functions/ooc/MatrixReshapeColWise.dml b/src/test/scripts/functions/ooc/MatrixReshapeColWise.dml
new file mode 100644
index 00000000000..3af5463b65f
--- /dev/null
+++ b/src/test/scripts/functions/ooc/MatrixReshapeColWise.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+X = read($1, rows=$2, cols=$3);
+
+Y = matrix(X, rows=$4, cols=$5, byrow=FALSE);
+write(Y, $6, format="binary");
diff --git a/src/test/scripts/functions/ooc/MatrixReshapeRowWise.dml b/src/test/scripts/functions/ooc/MatrixReshapeRowWise.dml
new file mode 100644
index 00000000000..edaaf258d90
--- /dev/null
+++ b/src/test/scripts/functions/ooc/MatrixReshapeRowWise.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+X = read($1, rows=$2, cols=$3);
+
+Y = matrix(X, rows=$4, cols=$5, byrow=TRUE);
+write(Y, $6, format="binary");