Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYSTEMDS-3824] Decompressing Transpose #2204

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions scripts/builtin/kmeans.dml
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ m_kmeans = function(Matrix[Double] X, Integer k = 10, Integer runs = 10, Integer
P = D <= minD;
# If some records belong to multiple centroids, share them equally
P = P / rowSums (P);
# P = table(seq(1,num_records), rowIndexMin(D), num_records, num_centroids)
# Compute the column normalization factor for P
P_denom = colSums (P);
# Compute new centroids as weighted averages over the records
Expand Down
2 changes: 2 additions & 0 deletions src/main/java/org/apache/sysds/hops/TernaryOp.java
Original file line number Diff line number Diff line change
Expand Up @@ -651,6 +651,8 @@ public boolean isSequenceRewriteApplicable( boolean left )

try
{
// TODO: the rewrite is currently not triggered if outdim are given --> getInput().size()>=3
// currently disabled due to performance decrease
if( getInput().size()==2 || (getInput().size()==3 && getInput().get(2).getDataType()==DataType.SCALAR) )
{
Hop input1 = getInput().get(0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
import org.apache.sysds.runtime.compress.lib.CLALibMatrixMult;
import org.apache.sysds.runtime.compress.lib.CLALibMerge;
import org.apache.sysds.runtime.compress.lib.CLALibReplace;
import org.apache.sysds.runtime.compress.lib.CLALibReorg;
import org.apache.sysds.runtime.compress.lib.CLALibReshape;
import org.apache.sysds.runtime.compress.lib.CLALibRexpand;
import org.apache.sysds.runtime.compress.lib.CLALibScalar;
Expand Down Expand Up @@ -633,21 +634,7 @@ public MatrixBlock replaceOperations(MatrixValue result, double pattern, double

@Override
public MatrixBlock reorgOperations(ReorgOperator op, MatrixValue ret, int startRow, int startColumn, int length) {
if(op.fn instanceof SwapIndex && this.getNumColumns() == 1) {
MatrixBlock tmp = decompress(op.getNumThreads());
long nz = tmp.setNonZeros(tmp.getNonZeros());
tmp = new MatrixBlock(tmp.getNumColumns(), tmp.getNumRows(), tmp.getDenseBlockValues());
tmp.setNonZeros(nz);
return tmp;
}
else {
// Allow transpose to be compressed output. In general we need to have a transposed flag on
// the compressed matrix. https://issues.apache.org/jira/browse/SYSTEMDS-3025
String message = op.getClass().getSimpleName() + " -- " + op.fn.getClass().getSimpleName();
MatrixBlock tmp = getUncompressed(message, op.getNumThreads());
return tmp.reorgOperations(op, ret, startRow, startColumn, length);
}

return CLALibReorg.reorg(this, op, (MatrixBlock) ret, startRow, startColumn, length);
}

public boolean isOverlapping() {
Expand Down Expand Up @@ -1311,7 +1298,7 @@ public void allocateAndResetSparseBlock(boolean clearNNZ, SparseBlock.Type stype

@Override
public MatrixBlock transpose(int k) {
return getUncompressed().transpose(k);
return CLALibReorg.reorg(this, new ReorgOperator(SwapIndex.getSwapIndexFnObject(), k), null, 0, 0, 0);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,21 @@ protected void decompressToSparseBlockDenseDictionary(SparseBlock ret, int rl, i

@Override
protected void decompressToDenseBlockTransposedSparseDictionary(DenseBlock db, int rl, int ru, SparseBlock sb) {
throw new NotImplementedException();
for(int i = rl; i < ru; i++) {
final int vr = _data.getIndex(i);
if(sb.isEmpty(vr))
continue;
final int apos = sb.pos(vr);
final int alen = sb.size(vr) + apos;
final int[] aix = sb.indexes(vr);
final double[] aval = sb.values(vr);
for(int j = apos; j < alen; j++) {
final int rowOut = _colIndexes.get(aix[j]);
final double[] c = db.values(rowOut);
final int off = db.pos(rowOut);
c[off + i] += aval[j];
}
}
}

@Override
Expand Down
158 changes: 158 additions & 0 deletions src/main/java/org/apache/sysds/runtime/compress/lib/CLALibReorg.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.sysds.runtime.compress.lib;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.data.SparseBlockMCSR;
import org.apache.sysds.runtime.functionobjects.SwapIndex;
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
import org.apache.sysds.runtime.util.CommonThreadPool;

/**
 * Reorganization operations on compressed matrix blocks.
 *
 * Transpose (a {@link ReorgOperator} whose function object is {@link SwapIndex}) is handled by decompressing the
 * column groups directly into a transposed dense or sparse output. Every other reorg operation falls back to
 * decompressing the input and delegating to the uncompressed implementation.
 */
public class CLALibReorg {

	protected static final Log LOG = LogFactory.getLog(CLALibReorg.class.getName());

	/**
	 * Set to true the first time the generic decompressing fallback in {@link #reorg} is taken. Once set, subsequent
	 * fallbacks pass a null message to getUncompressed. Plain (unsynchronized) flag.
	 */
	public static boolean warned = false;

	/**
	 * Apply a reorg operation to a compressed matrix block.
	 *
	 * @param cmb         The compressed input
	 * @param op          The reorg operator; SwapIndex is treated as transpose
	 * @param ret         Optional output block, forwarded to the transpose and uncompressed paths (may be null)
	 * @param startRow    Forwarded unchanged to the uncompressed reorgOperations
	 * @param startColumn Forwarded unchanged to the uncompressed reorgOperations
	 * @param length      Forwarded unchanged to the uncompressed reorgOperations
	 * @return The reorganized matrix block
	 */
	public static MatrixBlock reorg(CompressedMatrixBlock cmb, ReorgOperator op, MatrixBlock ret, int startRow,
		int startColumn, int length) {
		final boolean isTranspose = op.fn instanceof SwapIndex;

		if(isTranspose && cmb.getNumColumns() == 1) {
			// Single column input: decompress and swap the dimensions.
			final MatrixBlock tmp = cmb.decompress(op.getNumThreads());
			if(tmp.isInSparseFormat())
				return LibMatrixReorg.transpose(tmp); // edge case...

			// Dense: build a block with swapped dimensions from the decompressed dense values.
			final long nz = tmp.getNonZeros();
			final MatrixBlock out = new MatrixBlock(tmp.getNumColumns(), tmp.getNumRows(), tmp.getDenseBlockValues());
			out.setNonZeros(nz);
			return out;
		}
		else if(isTranspose) {
			// Prefer an already available decompressed version if one is cached.
			final MatrixBlock cached = cmb.getCachedDecompressed();
			if(cached != null)
				return cached.reorgOperations(op, ret, startRow, startColumn, length);
			// Allow transpose to be compressed output. In general we need to have a transposed flag on
			// the compressed matrix. https://issues.apache.org/jira/browse/SYSTEMDS-3025
			return transpose(cmb, ret, op.getNumThreads());
		}
		else {
			// Generic fallback: decompress fully, and only provide the message the first time.
			final String message = !warned ? op.getClass().getSimpleName() + " -- "
				+ op.fn.getClass().getSimpleName() : null;
			final MatrixBlock tmp = cmb.getUncompressed(message, op.getNumThreads());
			warned = true;
			return tmp.reorgOperations(op, ret, startRow, startColumn, length);
		}
	}

	/**
	 * Transpose by decompressing into a sparse or dense output, chosen via
	 * {@link MatrixBlock#evalSparseFormatInMemory} on the transposed dimensions and the compressed nnz.
	 */
	private static MatrixBlock transpose(CompressedMatrixBlock cmb, MatrixBlock ret, int k) {
		final long nnz = cmb.getNonZeros();
		final int nRow = cmb.getNumRows();
		final int nCol = cmb.getNumColumns();
		final boolean sparseOut = MatrixBlock.evalSparseFormatInMemory(nCol, nRow, nnz);
		return sparseOut ? //
			transposeSparse(cmb, ret, k, nRow, nCol, nnz) : //
			transposeDense(cmb, ret, k, nRow, nCol, nnz);
	}

	/**
	 * Decompress all column groups into a transposed MCSR sparse output of dimensions nCol x nRow. Runs in parallel
	 * over column groups when k > 1 and more than one group exists.
	 */
	private static MatrixBlock transposeSparse(CompressedMatrixBlock cmb, MatrixBlock ret, int k, int nRow, int nCol,
		long nnz) {
		if(ret == null)
			ret = new MatrixBlock(nCol, nRow, true, nnz);
		else
			ret.reset(nCol, nRow, true, nnz);

		// Note: called with clearNNZ=true, so the non zero count has to be set again after decompression.
		ret.allocateAndResetSparseBlock(true, SparseBlock.Type.MCSR);

		final int nColOut = ret.getNumColumns();
		final SparseBlockMCSR sb = (SparseBlockMCSR) ret.getSparseBlock();
		final List<AColGroup> groups = cmb.getColGroups();

		if(k > 1 && groups.size() > 1)
			decompressToTransposedSparseParallel(sb, groups, nColOut, k);
		else
			decompressToTransposedSparseSingleThread(sb, groups, nColOut);

		assignNonZeros(ret, nnz);
		return ret;
	}

	/**
	 * Decompress all column groups into a transposed dense output of dimensions nCol x nRow (single threaded).
	 */
	private static MatrixBlock transposeDense(CompressedMatrixBlock cmb, MatrixBlock ret, int k, int nRow, int nCol,
		long nnz) {
		if(ret == null)
			ret = new MatrixBlock(nCol, nRow, false, nnz);
		else
			ret.reset(nCol, nRow, false, nnz);

		// TODO: parallelize
		ret.allocateDenseBlock();

		decompressToTransposedDense(ret.getDenseBlock(), cmb.getColGroups(), nRow, 0, nRow);

		assignNonZeros(ret, nnz);
		return ret;
	}

	/**
	 * Set the non zero count of a freshly decompressed output. The column groups only see the underlying dense/sparse
	 * block, so the MatrixBlock metadata must be maintained here (as done in the single column branch of reorg).
	 *
	 * @param ret The decompressed, transposed output
	 * @param nnz The non zero count of the compressed input; a negative value is treated as unknown and recomputed
	 */
	private static void assignNonZeros(MatrixBlock ret, long nnz) {
		if(nnz >= 0)
			ret.setNonZeros(nnz);
		else
			ret.recomputeNonZeros();
	}

	/** Let each column group decompress the row range [rl, ru) into the transposed dense block. */
	private static void decompressToTransposedDense(DenseBlock ret, List<AColGroup> groups, int rlen, int rl, int ru) {
		for(AColGroup g : groups)
			g.decompressToDenseBlockTransposed(ret, rl, ru);
	}

	/** Let each column group in turn decompress into the transposed sparse block. */
	private static void decompressToTransposedSparseSingleThread(SparseBlockMCSR ret, List<AColGroup> groups,
		int nColOut) {
		for(AColGroup g : groups)
			g.decompressToSparseBlockTransposed(ret, nColOut);
	}

	/**
	 * Submit one task per column group that decompresses into the shared transposed sparse block, and wait for all.
	 * NOTE(review): presumably safe because distinct groups write distinct output rows -- confirm for overlapping
	 * column groups.
	 *
	 * @throws DMLCompressionException if any task fails or the waiting thread is interrupted
	 */
	private static void decompressToTransposedSparseParallel(SparseBlockMCSR ret, List<AColGroup> groups, int nColOut,
		int k) {
		final ExecutorService pool = CommonThreadPool.get(k);
		try {
			final List<Future<?>> tasks = new ArrayList<>(groups.size());
			for(final AColGroup g : groups)
				tasks.add(pool.submit(() -> g.decompressToSparseBlockTransposed(ret, nColOut)));

			for(Future<?> f : tasks)
				f.get();
		}
		catch(InterruptedException e) {
			// Restore the interrupt status so callers further up can observe the interruption.
			Thread.currentThread().interrupt();
			throw new DMLCompressionException("Failed to parallel decompress transpose sparse", e);
		}
		catch(Exception e) {
			throw new DMLCompressionException("Failed to parallel decompress transpose sparse", e);
		}
		finally {
			pool.shutdown();
		}
	}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory;
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.Pair;
import org.apache.sysds.runtime.util.CommonThreadPool;
import org.apache.sysds.runtime.util.UtilFunctions;

Expand Down Expand Up @@ -71,19 +72,23 @@ public static MatrixBlock rexpand(int seqHeight, MatrixBlock A, int nColOut, int

try {
final int[] map = new int[seqHeight];
int maxCol = constructInitialMapping(map, A, k);
Pair<Integer, Integer> meta = constructInitialMapping(map, A, k, nColOut);
int maxCol = meta.getKey();
int nZeros = meta.getValue();
boolean containsNull = maxCol < 0;
maxCol = Math.abs(maxCol);

boolean cutOff = false;
if(nColOut == -1)
nColOut = maxCol;
else if(nColOut < maxCol)
throw new DMLRuntimeException("invalid nColOut, requested: " + nColOut + " but have to be : " + maxCol);
cutOff = true;

final int nNulls = containsNull ? correctNulls(map, nColOut) : 0;
if(containsNull)
correctNulls(map, nColOut);
if(nColOut == 0) // edge case of empty zero dimension block.
return new MatrixBlock(seqHeight, 0, 0.0);
return createCompressedReturn(map, nColOut, seqHeight, nNulls, containsNull, k);
return createCompressedReturn(map, nColOut, seqHeight, nZeros, containsNull || cutOff, k);
}
catch(Exception e) {
throw new RuntimeException("Failed table seq operator", e);
Expand Down Expand Up @@ -139,7 +144,7 @@ private static int correctNulls(int[] map, int nColOut) {
return nNulls;
}

private static int constructInitialMapping(int[] map, MatrixBlock A, int k) {
private static Pair<Integer,Integer> constructInitialMapping(int[] map, MatrixBlock A, int k, int maxOutCol) {
if(A.isEmpty() || A.isInSparseFormat())
throw new DMLRuntimeException("not supported empty or sparse construction of seq table");
final MatrixBlock Ac;
Expand All @@ -155,20 +160,23 @@ private static int constructInitialMapping(int[] map, MatrixBlock A, int k) {
try {

int blkz = Math.max((map.length / k), 1000);
List<Future<Integer>> tasks = new ArrayList<>();
List<Future<Pair<Integer,Integer>>> tasks = new ArrayList<>();
for(int i = 0; i < map.length; i += blkz) {
final int start = i;
final int end = Math.min(i + blkz, map.length);
tasks.add(pool.submit(() -> partialMapping(map, Ac, start, end)));
tasks.add(pool.submit(() -> partialMapping(map, Ac, start, end, maxOutCol)));
}

int maxCol = 0;
for(Future<Integer> f : tasks) {
int tmp = f.get();
if(Math.abs(tmp) > Math.abs(maxCol))
maxCol = tmp;
int zeros = 0;
for(Future<Pair<Integer,Integer>> f : tasks) {
int tmpMaxCol = f.get().getKey();
int tmpZeros = f.get().getValue();
if(Math.abs(tmpMaxCol) > Math.abs(maxCol))
maxCol = tmpMaxCol;
zeros += tmpZeros;
}
return maxCol;
return new Pair<Integer,Integer>(maxCol, zeros);
}
catch(Exception e) {
throw new DMLRuntimeException(e);
Expand All @@ -179,33 +187,32 @@ private static int constructInitialMapping(int[] map, MatrixBlock A, int k) {

}

private static int partialMapping(int[] map, MatrixBlock A, int start, int end) {
private static Pair<Integer, Integer> partialMapping(int[] map, MatrixBlock A, int start, int end, int maxOutCol) {

int maxCol = 0;
boolean containsNull = false;

int zeros = 0;
final double[] aVals = A.getDenseBlockValues();

for(int i = start; i < end; i++) {
final double v2 = aVals[i];
if(Double.isNaN(v2)) {
map[i] = -1; // assign temporarily to -1
containsNull = true;
}
else {
// safe casts to long for consistent behavior with indexing
int col = UtilFunctions.toInt(v2);
if(col <= 0)
throw new DMLRuntimeException(
final int colUnsafe = UtilFunctions.toInt(v2);
if(!Double.isNaN(v2) && colUnsafe < 0)
throw new DMLRuntimeException(
"Erroneous input while computing the contingency table (value <= zero): " + v2);
// Boolean to int conversion to avoid branch
final int invalid = Double.isNaN(v2) || (maxOutCol != -1 && colUnsafe > maxOutCol) ? 1 : 0;
// if invalid -> maxOutCol else -> colUnsafe - 1
final int colSafe = maxOutCol*invalid + (colUnsafe - 1)*(1 - invalid);
zeros += invalid;
maxCol = Math.max(colUnsafe, maxCol);
map[i] = colSafe;
}

map[i] = col - 1;
// maintain max seen col
maxCol = Math.max(col, maxCol);
}
if (maxOutCol == -1 && zeros > 0){
maxCol *= -1;
}

return containsNull ? maxCol * -1 : maxCol;
return new Pair<Integer, Integer>(maxCol, zeros);
}

public static boolean compressedTableSeq() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,9 @@ private void processSimpleCompressInstruction(ExecutionContext ec) {
else if(ec.isMatrixObject(input1.getName()))
processMatrixBlockCompression(ec, ec.getMatrixInput(input1.getName()), _numThreads, root);
else {
throw new NotImplementedException("Not supported other types of input for compression than frame and matrix");
LOG.warn("Compression on Scalar should not happen");
ScalarObject Scalar = ec.getScalarInput(input1);
ec.setScalarOutput(output.getName(),Scalar);
}
}

Expand Down
Loading
Loading