diff --git a/scripts/builtin/kmeans.dml b/scripts/builtin/kmeans.dml index 7fdd320a164..8b76040f641 100644 --- a/scripts/builtin/kmeans.dml +++ b/scripts/builtin/kmeans.dml @@ -148,6 +148,7 @@ m_kmeans = function(Matrix[Double] X, Integer k = 10, Integer runs = 10, Integer P = D <= minD; # If some records belong to multiple centroids, share them equally P = P / rowSums (P); + # P = table(seq(1,num_records), rowIndexMin(D), num_records, num_centroids) # Compute the column normalization factor for P P_denom = colSums (P); # Compute new centroids as weighted averages over the records diff --git a/src/main/java/org/apache/sysds/hops/TernaryOp.java b/src/main/java/org/apache/sysds/hops/TernaryOp.java index 0334dbbb2f7..dabaeb79aab 100644 --- a/src/main/java/org/apache/sysds/hops/TernaryOp.java +++ b/src/main/java/org/apache/sysds/hops/TernaryOp.java @@ -651,6 +651,8 @@ public boolean isSequenceRewriteApplicable( boolean left ) try { + // TODO: the rewrite is currently not triggered if output dimensions are given --> getInput().size()>=3 + // currently disabled due to a performance decrease if( getInput().size()==2 || (getInput().size()==3 && getInput().get(2).getDataType()==DataType.SCALAR) ) { Hop input1 = getInput().get(0); diff --git a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java index c78d651ff00..48637595741 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java +++ b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java @@ -59,6 +59,7 @@ import org.apache.sysds.runtime.compress.lib.CLALibMatrixMult; import org.apache.sysds.runtime.compress.lib.CLALibMerge; import org.apache.sysds.runtime.compress.lib.CLALibReplace; +import org.apache.sysds.runtime.compress.lib.CLALibReorg; import org.apache.sysds.runtime.compress.lib.CLALibReshape; import org.apache.sysds.runtime.compress.lib.CLALibRexpand; import
org.apache.sysds.runtime.compress.lib.CLALibScalar; @@ -633,21 +634,7 @@ public MatrixBlock replaceOperations(MatrixValue result, double pattern, double @Override public MatrixBlock reorgOperations(ReorgOperator op, MatrixValue ret, int startRow, int startColumn, int length) { - if(op.fn instanceof SwapIndex && this.getNumColumns() == 1) { - MatrixBlock tmp = decompress(op.getNumThreads()); - long nz = tmp.setNonZeros(tmp.getNonZeros()); - tmp = new MatrixBlock(tmp.getNumColumns(), tmp.getNumRows(), tmp.getDenseBlockValues()); - tmp.setNonZeros(nz); - return tmp; - } - else { - // Allow transpose to be compressed output. In general we need to have a transposed flag on - // the compressed matrix. https://issues.apache.org/jira/browse/SYSTEMDS-3025 - String message = op.getClass().getSimpleName() + " -- " + op.fn.getClass().getSimpleName(); - MatrixBlock tmp = getUncompressed(message, op.getNumThreads()); - return tmp.reorgOperations(op, ret, startRow, startColumn, length); - } - + return CLALibReorg.reorg(this, op, (MatrixBlock) ret, startRow, startColumn, length); } public boolean isOverlapping() { @@ -1311,7 +1298,7 @@ public void allocateAndResetSparseBlock(boolean clearNNZ, SparseBlock.Type stype @Override public MatrixBlock transpose(int k) { - return getUncompressed().transpose(k); + return CLALibReorg.reorg(this, new ReorgOperator(SwapIndex.getSwapIndexFnObject(), k), null, 0, 0, 0); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java index c1b9c65f229..e55a24e56f5 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java @@ -251,7 +251,21 @@ protected void decompressToSparseBlockDenseDictionary(SparseBlock ret, int rl, i @Override protected void decompressToDenseBlockTransposedSparseDictionary(DenseBlock db, int rl, int ru, 
SparseBlock sb) { - throw new NotImplementedException(); + for(int i = rl; i < ru; i++) { + final int vr = _data.getIndex(i); + if(sb.isEmpty(vr)) + continue; + final int apos = sb.pos(vr); + final int alen = sb.size(vr) + apos; + final int[] aix = sb.indexes(vr); + final double[] aval = sb.values(vr); + for(int j = apos; j < alen; j++) { + final int rowOut = _colIndexes.get(aix[j]); + final double[] c = db.values(rowOut); + final int off = db.pos(rowOut); + c[off + i] += aval[j]; + } + } } @Override diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibReorg.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibReorg.java new file mode 100644 index 00000000000..d587d26c3cb --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibReorg.java @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.runtime.compress.lib; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.runtime.compress.CompressedMatrixBlock; +import org.apache.sysds.runtime.compress.DMLCompressionException; +import org.apache.sysds.runtime.compress.colgroup.AColGroup; +import org.apache.sysds.runtime.data.DenseBlock; +import org.apache.sysds.runtime.data.SparseBlock; +import org.apache.sysds.runtime.data.SparseBlockMCSR; +import org.apache.sysds.runtime.functionobjects.SwapIndex; +import org.apache.sysds.runtime.matrix.data.LibMatrixReorg; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.operators.ReorgOperator; +import org.apache.sysds.runtime.util.CommonThreadPool; + +public class CLALibReorg { + + protected static final Log LOG = LogFactory.getLog(CLALibReorg.class.getName()); + + public static boolean warned = false; + + public static MatrixBlock reorg(CompressedMatrixBlock cmb, ReorgOperator op, MatrixBlock ret, int startRow, + int startColumn, int length) { + // SwapIndex is transpose + if(op.fn instanceof SwapIndex && cmb.getNumColumns() == 1) { + MatrixBlock tmp = cmb.decompress(op.getNumThreads()); + long nz = tmp.setNonZeros(tmp.getNonZeros()); + if(tmp.isInSparseFormat()) + return LibMatrixReorg.transpose(tmp); // edge case... + else + tmp = new MatrixBlock(tmp.getNumColumns(), tmp.getNumRows(), tmp.getDenseBlockValues()); + tmp.setNonZeros(nz); + return tmp; + } + else if(op.fn instanceof SwapIndex) { + MatrixBlock tmp = cmb.getCachedDecompressed(); + if(tmp != null) + return tmp.reorgOperations(op, ret, startRow, startColumn, length); + // Allow transpose to be compressed output. In general we need to have a transposed flag on + // the compressed matrix. 
https://issues.apache.org/jira/browse/SYSTEMDS-3025 + return transpose(cmb, ret, op.getNumThreads()); + } + else { + String message = !warned ? op.getClass().getSimpleName() + " -- " + op.fn.getClass().getSimpleName() : null; + MatrixBlock tmp = cmb.getUncompressed(message, op.getNumThreads()); + warned = true; + return tmp.reorgOperations(op, ret, startRow, startColumn, length); + } + } + + private static MatrixBlock transpose(CompressedMatrixBlock cmb, MatrixBlock ret, int k) { + + final long nnz = cmb.getNonZeros(); + final int nRow = cmb.getNumRows(); + final int nCol = cmb.getNumColumns(); + final boolean sparseOut = MatrixBlock.evalSparseFormatInMemory(nCol,nRow, nnz); + if(sparseOut) + return transposeSparse(cmb, ret, k, nRow, nCol, nnz); + else + return transposeDense(cmb, ret, k, nRow, nCol, nnz); + } + + private static MatrixBlock transposeSparse(CompressedMatrixBlock cmb, MatrixBlock ret, int k, int nRow, int nCol, + long nnz) { + if(ret == null) + ret = new MatrixBlock(nCol, nRow, true, nnz); + else + ret.reset(nCol, nRow, true, nnz); + + ret.allocateAndResetSparseBlock(true, SparseBlock.Type.MCSR); + + final int nColOut = ret.getNumColumns(); + + if(k > 1 && cmb.getColGroups().size() > 1) + decompressToTransposedSparseParallel((SparseBlockMCSR) ret.getSparseBlock(), cmb.getColGroups(), nColOut, k); + else + decompressToTransposedSparseSingleThread((SparseBlockMCSR) ret.getSparseBlock(), cmb.getColGroups(), nColOut); + + return ret; + } + + private static MatrixBlock transposeDense(CompressedMatrixBlock cmb, MatrixBlock ret, int k, int nRow, int nCol, + long nnz) { + if(ret == null) + ret = new MatrixBlock(nCol, nRow, false, nnz); + else + ret.reset(nCol, nRow, false, nnz); + + // TODO: parallelize + ret.allocateDenseBlock(); + + decompressToTransposedDense(ret.getDenseBlock(), cmb.getColGroups(), nRow, 0, nRow); + return ret; + } + + private static void decompressToTransposedDense(DenseBlock ret, List groups, int rlen, int rl, int ru) { + for(int i = 
0; i < groups.size(); i++) { + AColGroup g = groups.get(i); + g.decompressToDenseBlockTransposed(ret, rl, ru); + } + } + + private static void decompressToTransposedSparseSingleThread(SparseBlockMCSR ret, List groups, + int nColOut) { + for(int i = 0; i < groups.size(); i++) { + AColGroup g = groups.get(i); + g.decompressToSparseBlockTransposed(ret, nColOut); + } + } + + private static void decompressToTransposedSparseParallel(SparseBlockMCSR ret, List groups, int nColOut, + int k) { + final ExecutorService pool = CommonThreadPool.get(k); + try { + final List> tasks = new ArrayList<>(groups.size()); + + for(int i = 0; i < groups.size(); i++) { + final AColGroup g = groups.get(i); + tasks.add(pool.submit(() -> g.decompressToSparseBlockTransposed(ret, nColOut))); + } + + for(Future f : tasks) + f.get(); + + } + catch(Exception e) { + throw new DMLCompressionException("Failed to parallel decompress transpose sparse", e); + } + finally { + pool.shutdown(); + } + } +} diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRexpand.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRexpand.java index 34f22441112..5be508febde 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRexpand.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRexpand.java @@ -39,6 +39,7 @@ import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory; import org.apache.sysds.runtime.matrix.data.LibMatrixReorg; import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.data.Pair; import org.apache.sysds.runtime.util.CommonThreadPool; import org.apache.sysds.runtime.util.UtilFunctions; @@ -71,19 +72,23 @@ public static MatrixBlock rexpand(int seqHeight, MatrixBlock A, int nColOut, int try { final int[] map = new int[seqHeight]; - int maxCol = constructInitialMapping(map, A, k); + Pair meta = constructInitialMapping(map, A, k, nColOut); + int maxCol = meta.getKey(); + int nZeros = 
meta.getValue(); boolean containsNull = maxCol < 0; maxCol = Math.abs(maxCol); + boolean cutOff = false; if(nColOut == -1) nColOut = maxCol; else if(nColOut < maxCol) - throw new DMLRuntimeException("invalid nColOut, requested: " + nColOut + " but have to be : " + maxCol); + cutOff = true; - final int nNulls = containsNull ? correctNulls(map, nColOut) : 0; + if(containsNull) + correctNulls(map, nColOut); if(nColOut == 0) // edge case of empty zero dimension block. return new MatrixBlock(seqHeight, 0, 0.0); - return createCompressedReturn(map, nColOut, seqHeight, nNulls, containsNull, k); + return createCompressedReturn(map, nColOut, seqHeight, nZeros, containsNull || cutOff, k); } catch(Exception e) { throw new RuntimeException("Failed table seq operator", e); @@ -139,7 +144,7 @@ private static int correctNulls(int[] map, int nColOut) { return nNulls; } - private static int constructInitialMapping(int[] map, MatrixBlock A, int k) { + private static Pair constructInitialMapping(int[] map, MatrixBlock A, int k, int maxOutCol) { if(A.isEmpty() || A.isInSparseFormat()) throw new DMLRuntimeException("not supported empty or sparse construction of seq table"); final MatrixBlock Ac; @@ -155,20 +160,23 @@ private static int constructInitialMapping(int[] map, MatrixBlock A, int k) { try { int blkz = Math.max((map.length / k), 1000); - List> tasks = new ArrayList<>(); + List>> tasks = new ArrayList<>(); for(int i = 0; i < map.length; i += blkz) { final int start = i; final int end = Math.min(i + blkz, map.length); - tasks.add(pool.submit(() -> partialMapping(map, Ac, start, end))); + tasks.add(pool.submit(() -> partialMapping(map, Ac, start, end, maxOutCol))); } int maxCol = 0; - for(Future f : tasks) { - int tmp = f.get(); - if(Math.abs(tmp) > Math.abs(maxCol)) - maxCol = tmp; + int zeros = 0; + for(Future> f : tasks) { + int tmpMaxCol = f.get().getKey(); + int tmpZeros = f.get().getValue(); + if(Math.abs(tmpMaxCol) > Math.abs(maxCol)) + maxCol = tmpMaxCol; + zeros += 
tmpZeros; } - return maxCol; + return new Pair(maxCol, zeros); } catch(Exception e) { throw new DMLRuntimeException(e); @@ -179,33 +187,32 @@ private static int constructInitialMapping(int[] map, MatrixBlock A, int k) { } - private static int partialMapping(int[] map, MatrixBlock A, int start, int end) { + private static Pair partialMapping(int[] map, MatrixBlock A, int start, int end, int maxOutCol) { int maxCol = 0; - boolean containsNull = false; - + int zeros = 0; final double[] aVals = A.getDenseBlockValues(); for(int i = start; i < end; i++) { final double v2 = aVals[i]; - if(Double.isNaN(v2)) { - map[i] = -1; // assign temporarily to -1 - containsNull = true; - } - else { - // safe casts to long for consistent behavior with indexing - int col = UtilFunctions.toInt(v2); - if(col <= 0) - throw new DMLRuntimeException( + final int colUnsafe = UtilFunctions.toInt(v2); + if(!Double.isNaN(v2) && colUnsafe <= 0) // values are 1-based: zero would map to index -1, so reject it too + throw new DMLRuntimeException( "Erroneous input while computing the contingency table (value <= zero): " + v2); + // Boolean to int conversion to avoid branch + final int invalid = Double.isNaN(v2) || (maxOutCol != -1 && colUnsafe > maxOutCol) ? 1 : 0; + // if invalid -> maxOutCol else -> colUnsafe - 1 + final int colSafe = maxOutCol*invalid + (colUnsafe - 1)*(1 - invalid); + zeros += invalid; + maxCol = Math.max(colUnsafe, maxCol); + map[i] = colSafe; + } - map[i] = col - 1; - // maintain max seen col - maxCol = Math.max(col, maxCol); - } + if (maxOutCol == -1 && zeros > 0){ + maxCol *= -1; } - return containsNull ? 
maxCol * -1 : maxCol; + return new Pair(maxCol, zeros); } public static boolean compressedTableSeq() { diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java index 7d0d9f78704..4216385b722 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java @@ -138,7 +138,9 @@ private void processSimpleCompressInstruction(ExecutionContext ec) { else if(ec.isMatrixObject(input1.getName())) processMatrixBlockCompression(ec, ec.getMatrixInput(input1.getName()), _numThreads, root); else { - throw new NotImplementedException("Not supported other types of input for compression than frame and matrix"); + LOG.warn("Compression on Scalar should not happen"); + ScalarObject Scalar = ec.getScalarInput(input1); + ec.setScalarOutput(output.getName(),Scalar); } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/CtableCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/CtableCPInstruction.java index 4f508cd5b8d..52c596c33bf 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/CtableCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/CtableCPInstruction.java @@ -110,13 +110,17 @@ public void processInstruction(ExecutionContext ec) { boolean outputDimsKnown = (outputDim1 != -1 && outputDim2 != -1); if ( outputDimsKnown ) { - int inputRows = matBlock1.getNumRows(); - int inputCols = matBlock1.getNumColumns(); - boolean sparse = MatrixBlock.evalSparseFormatInMemory(outputDim1, outputDim2, inputRows*inputCols); - //only create result block if dense; it is important not to aggregate on sparse result - //blocks because it would implicitly turn the O(N) algorithm into O(N log N).
- if( !sparse ) - resultBlock = new MatrixBlock((int)outputDim1, (int)outputDim2, false); + if(_isExpand){ + resultBlock = new MatrixBlock((int)outputDim1, (int)outputDim2, true); + } else { + int inputRows = matBlock1.getNumRows(); + int inputCols = matBlock1.getNumColumns(); + boolean sparse = MatrixBlock.evalSparseFormatInMemory(outputDim1, outputDim2, inputRows*inputCols); + //only create result block if dense; it is important not to aggregate on sparse result + //blocks because it would implicitly turn the O(N) algorithm into O(N log N). + if( !sparse ) + resultBlock = new MatrixBlock((int)outputDim1, (int)outputDim2, false); + } } switch(ctableOp) { @@ -140,7 +144,8 @@ public void processInstruction(ExecutionContext ec) { } matBlock2 = ec.getMatrixInput(input2.getName()); cst1 = ec.getScalarInput(input3).getDoubleValue(); - resultBlock = LibMatrixReorg.fusedSeqRexpand(matBlock2.getNumRows(), matBlock2, cst1, resultBlock, true, _k); + resultBlock = LibMatrixReorg.fusedSeqRexpand(matBlock2.getNumRows(), matBlock2, cst1, resultBlock, + !outputDimsKnown, _k); break; case CTABLE_TRANSFORM_HISTOGRAM: //(VECTOR) // F=ctable(A,1) or F = ctable(A,1,1) diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java index 009be82de3a..29c2ecdaf2b 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java @@ -1044,11 +1044,13 @@ public static MatrixBlock fusedSeqRexpand(int seqHeight, MatrixBlock A, double w } - private static MatrixBlock fusedSeqRexpandSparse(int seqHeight, MatrixBlock A, double w, MatrixBlock ret, boolean updateClen) { + private static MatrixBlock fusedSeqRexpandSparse(int seqHeight, MatrixBlock A, double w, MatrixBlock ret, + boolean updateClen) { if(ret == null) { ret = new MatrixBlock(); updateClen = true; } + int outCols = updateClen ? 
-1 : ret.getNumColumns(); final int rlen = seqHeight; // prepare allocation of CSR sparse block final int[] rowPointers = new int[rlen + 1]; @@ -1060,14 +1062,14 @@ private static MatrixBlock fusedSeqRexpandSparse(int seqHeight, MatrixBlock A, d ret.sparse = true; ret.denseBlock = null; // construct sparse CSR block from filled arrays - SparseBlockCSR csr = new SparseBlockCSR(rowPointers, indexes, values, rlen); + SparseBlockCSR csr = new SparseBlockCSR(rowPointers, indexes, values, seqHeight); ret.sparseBlock = csr; - int blkz = Math.min(1024, rlen); + int blkz = Math.min(1024, seqHeight); int maxcol = 0; boolean containsNull = false; - for(int i = 0; i < rlen; i += blkz) { + for(int i = 0; i < seqHeight; i += blkz) { // blocked execution for earlier JIT compilation - int t = fusedSeqRexpandSparseBlock(csr, A, w, i, Math.min(i + blkz, rlen)); + int t = fusedSeqRexpandSparseBlock(csr, A, w, i, Math.min(i + blkz, seqHeight), updateClen,outCols); if(t < 0) { t = Math.abs(t); containsNull = true; @@ -1078,14 +1080,15 @@ private static MatrixBlock fusedSeqRexpandSparse(int seqHeight, MatrixBlock A, d if(containsNull) csr.compact(); - rowPointers[rlen] = rlen; + rowPointers[seqHeight] = seqHeight; ret.setNonZeros(ret.sparseBlock.size()); if(updateClen) - ret.setNumColumns(maxcol); + ret.setNumColumns(outCols == -1 ? 
maxcol : (int) outCols); return ret; } - private static int fusedSeqRexpandSparseBlock(final SparseBlockCSR csr, final MatrixBlock A, final double w, int rl, int ru) { + private static int fusedSeqRexpandSparseBlock(final SparseBlockCSR csr, final MatrixBlock A, final double w, int rl, + int ru, boolean updateClen,int maxOutCol) { // prepare allocation of CSR sparse block final int[] rowPointers = csr.rowPointers(); @@ -1096,11 +1099,9 @@ private static int fusedSeqRexpandSparseBlock(final SparseBlockCSR csr, final Ma int maxCol = 0; for(int i = rl; i < ru; i++) { - int c = rexpandSingleRow(i, A.get(i, 0), w, indexes, values); - if(c < 0) - containsNull = true; - else - maxCol = Math.max(c, maxCol); + int c = rexpandSingleRow(i, A.get(i, 0), w, indexes, values, updateClen, maxOutCol); + containsNull |= c < 0; + maxCol = Math.max(c, maxCol); rowPointers[i] = i; } @@ -1114,23 +1115,22 @@ private static void updateClenRexpand(MatrixBlock ret, int maxCol, boolean updat ret.clen = maxCol; } - public static int rexpandSingleRow(int row, double v2, double w, int[] retIx, double[] retVals) { - // If any of the values are NaN (i.e., missing) then - // we skip this tuple, proceed to the next tuple - if(Double.isNaN(v2)) - return -1; + public static int rexpandSingleRow(int row, double v2, double w, int[] retIx, double[] retVals, + boolean updateClen, int maxOutCol) { - // safe casts to long for consistent behavior with indexing - int col = UtilFunctions.toInt(v2); - if(col <= 0) - throw new DMLRuntimeException("Erroneous input while computing the contingency table (value <= zero): " + v2); + final int colUnsafe = UtilFunctions.toInt(v2); // colUnsafe = 0 for Nan + int isNan = (Double.isNaN(v2) ? 
1 : 0); // avoid branching by boolean to int conversion + int col = colUnsafe - isNan; // col = -1 for Nan - // set weight as value (expand is guaranteed to address different cells) - retIx[row] = col - 1; - retVals[row] = w; + // use branch prediction for rare case + if(!Double.isNaN(v2) && colUnsafe <= 0) + throw new DMLRuntimeException("Erroneous input while computing the contingency table (value <= zero): " + v2); - // maintain max seen col - return col; + // avoid branching again by boolean to int conversion + int valid = !Double.isNaN(v2) && (updateClen || col <= maxOutCol) ? 1 : 0; + retIx[row] = (col - 1)*valid; // use valid as switch + retVals[row] = w*valid; + return valid*col + valid - 1; // -1 if invalid else col } /** diff --git a/src/test/java/org/apache/sysds/test/component/compress/lib/SeqTableTest.java b/src/test/java/org/apache/sysds/test/component/compress/lib/SeqTableTest.java index 8f666aee487..a43a53e1501 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/lib/SeqTableTest.java +++ b/src/test/java/org/apache/sysds/test/component/compress/lib/SeqTableTest.java @@ -25,6 +25,7 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysds.runtime.compress.lib.CLALibRexpand; +import org.apache.sysds.runtime.matrix.data.LibMatrixReorg; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.test.TestUtils; import org.junit.Test; @@ -40,7 +41,7 @@ public class SeqTableTest { @Test(expected = RuntimeException.class) public void test_notSameDim() throws Exception { MatrixBlock c = new MatrixBlock(20, 1, 0.0); - CLALibRexpand.rexpand(10, c); + LibMatrixReorg.fusedSeqRexpand(10, c, 1.0); } @Test(expected = RuntimeException.class) @@ -52,7 +53,7 @@ public void test_toLow() throws Exception { @Test(expected = RuntimeException.class) public void test_toManyColumn() throws Exception { MatrixBlock c = new MatrixBlock(10, 2, -1.0); - CLALibRexpand.rexpand(10, c); + 
LibMatrixReorg.fusedSeqRexpand(10, c, 1.0); } @Test diff --git a/src/test/java/org/apache/sysds/test/functions/compress/configuration/CompressBase.java b/src/test/java/org/apache/sysds/test/functions/compress/configuration/CompressBase.java index 3d61b4942c7..1ddfc09258b 100644 --- a/src/test/java/org/apache/sysds/test/functions/compress/configuration/CompressBase.java +++ b/src/test/java/org/apache/sysds/test/functions/compress/configuration/CompressBase.java @@ -32,6 +32,8 @@ import org.apache.sysds.utils.Statistics; import org.junit.Assert; +import java.io.ByteArrayOutputStream; + public abstract class CompressBase extends AutomatedTestBase { // private static final Log LOG = LogFactory.getLog(CompressBase.class.getName()); @@ -66,7 +68,8 @@ public void compressTest(int rows, int cols, double sparsity, ExecType instType, fullDMLScriptName = SCRIPT_DIR + "/functions/compress/compress_" + name + ".dml"; programArgs = new String[] {"-stats", "100", "-nvargs", "A=" + input("A")}; - String out = runTest(null).toString(); + ByteArrayOutputStream tmp = runTest(null); + String out = tmp != null ? tmp.toString() : ""; // reuse the captured output; do not run the script a second time int decompressCount = DMLCompressionStatistics.getDecompressionCount(); long compressionCount = (instType == ExecType.SPARK) ? Statistics @@ -74,7 +77,8 @@ public void compressTest(int rows, int cols, double sparsity, ExecType instType, DMLCompressionStatistics.reset(); Assert.assertEquals(out + "\ncompression count wrong : ", compressionCount, compressionCountsExpected); - Assert.assertTrue(out + "\nDecompression count wrong : ", + Assert.assertTrue(out + "\nDecompression count wrong : " + decompressCount + + (decompressionCountExpected >= 0 ? " [expected: " + decompressionCountExpected+ "]" : ""), decompressionCountExpected >= 0 ? 
decompressionCountExpected == decompressCount : decompressCount > 1); } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java index 2f8ce8ec7fd..02bfb960d5b 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java @@ -123,7 +123,7 @@ public void federatedKmeans(Types.ExecMode execMode, boolean singleWorker) { // Run actual dml script with federated matrix fullDMLScriptName = HOME + TEST_NAME + ".dml"; - programArgs = new String[] {"-stats", "-nvargs", "in_X1=" + TestUtils.federatedAddress(port1, input("X1")), + programArgs = new String[] {"-stats","20", "-nvargs", "in_X1=" + TestUtils.federatedAddress(port1, input("X1")), "in_X2=" + TestUtils.federatedAddress(port2, input("X2")), "rows=" + rows, "cols=" + cols, "single=" + String.valueOf(singleWorker).toUpperCase(), "runs=" + String.valueOf(runs), "out=" + output("Z")};