Skip to content

Commit

Permalink
[MINOR] MM Specializations
Browse files Browse the repository at this point in the history
This commit adds specializations for matrix multiplication with the following:

1. dense-sparse inputs with sparse output
2. ultra-sparse output with dense-dense inputs
3. sparse output when the right-hand side input is a sparse vector

Furthermore, I modified the call stack to branch to the native matrix
multiplication inside LibMatrixMult. This allows CLA to gain native support
simply by calling LibMatrixMult, instead of having to go through a MatrixBlock.

Closes #2212
  • Loading branch information
Baunsgaard committed Feb 4, 2025
1 parent e022eaf commit fd1ba7c
Show file tree
Hide file tree
Showing 4 changed files with 163 additions and 24 deletions.
143 changes: 136 additions & 7 deletions src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,13 @@ public static MatrixBlock matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock
* @return ret Matrix Block
*/
public static MatrixBlock matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int k) {
	// Prefer the native (BLAS-backed) path whenever the native library is loaded;
	// otherwise fall back to the pure-Java multi-threaded implementation.
	return NativeHelper.isNativeLibraryLoaded() //
		? LibMatrixNative.matrixMult(m1, m2, ret, k) //
		: matrixMult(m1, m2, ret, false, k);
}

/**
 * Matrix multiplication that always uses the Java implementation in this class,
 * bypassing the native-library dispatch done by the four-argument matrixMult.
 * LibMatrixNative calls this as its fallback, which avoids mutual recursion
 * between the two classes when the native path is unavailable or invalid.
 *
 * @param m1  left-hand side matrix
 * @param m2  right-hand side matrix
 * @param ret output matrix block (presumably may be null and allocated by the callee — TODO confirm)
 * @param k   degree of parallelism (number of threads)
 * @return the result matrix block
 */
public static MatrixBlock matrixMultNonNative(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int k) {
	return matrixMult(m1, m2, ret, false, k);
}

Expand Down Expand Up @@ -256,7 +263,7 @@ private static void singleThreadedMatrixMult(MatrixBlock m1, MatrixBlock m2, Mat
// core matrix mult computation
if(ultraSparse && !fixedRet)
matrixMultUltraSparse(m1, m2, ret, m1Perm, 0, ru2);
else if(!m1.sparse && !m2.sparse)
else if(!m1.sparse && !m2.sparse && !ret.sparse)
matrixMultDenseDense(m1, m2, ret, tm2, pm2, 0, ru2, 0, m2.clen);
else if(m1.sparse && m2.sparse)
matrixMultSparseSparse(m1, m2, ret, pm2, sparse, 0, ru2);
Expand Down Expand Up @@ -1257,6 +1264,100 @@ public static void matrixMultDenseDenseMM(DenseBlock a, DenseBlock b, DenseBlock
}

private static void matrixMultDenseSparse(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, boolean pm2, int rl, int ru) {
	// Dispatch on the output representation first, then on the input formats.
	if(!ret.isInSparseFormat()) {
		matrixMultDenseSparseOutDense(m1, m2, ret, pm2, rl, ru);
		return;
	}
	// sparse output: choose the specialization matching the input formats
	final boolean bothDenseIn = !m1.sparse && !m2.sparse;
	if(bothDenseIn)
		matrixMultDenseDenseOutSparse(m1, m2, ret, pm2, rl, ru);
	else
		matrixMultDenseSparseOutSparse(m1, m2, ret, pm2, rl, ru);
}


/**
 * Matrix multiplication specialization for dense-dense inputs that produce a
 * sparse output block. Uses cache blocking over the output rows (i) and the
 * common dimension (k), scattering partial products into the sparse rows of ret.
 *
 * @param m1  dense left-hand side matrix
 * @param m2  dense right-hand side matrix
 * @param ret output matrix block in sparse format
 * @param pm2 if true, the parallelization range [rl,ru) covers the common
 *            dimension (rows of m2) instead of the rows of m1
 * @param rl  lower bound (inclusive) of the parallelized range
 * @param ru  upper bound (exclusive) of the parallelized range
 */
private static void matrixMultDenseDenseOutSparse(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, boolean pm2,
	int rl, int ru) {
	final DenseBlock a = m1.getDenseBlock();
	final DenseBlock b = m2.getDenseBlock();
	final SparseBlock c = ret.getSparseBlock();
	final int m = m1.rlen; // rows left
	final int cd = m1.clen; // common dim
	final int n = m2.clen;

	// select iteration ranges depending on the parallelization scheme
	final int rl1 = pm2 ? 0 : rl;
	final int ru1 = pm2 ? m : ru;
	final int rl2 = pm2 ? rl : 0;
	final int ru2 = pm2 ? ru : cd;

	final int blocksizeK = 32;
	final int blocksizeI = 32;

	for(int bi = rl1; bi < ru1; bi += blocksizeI) {
		for(int bk = rl2, bimin = Math.min(ru1, bi + blocksizeI); bk < ru2; bk += blocksizeK) {
			final int bkmin = Math.min(ru2, bk + blocksizeK);
			// core sub block matrix multiplication
			for(int i = bi; i < bimin; i++) { // rows left
				final double[] avals = a.values(i);
				final int aix = a.pos(i);
				for(int k = bk; k < bkmin; k++) { // common dimension
					final double aval = avals[aix + k];
					if(aval != 0) {
						final double[] bvals = b.values(k);
						final int bpos = b.pos(k);
						for(int j = 0; j < n; j++) {
							final double bv = bvals[bpos + j];
							// skip zeros to avoid appending explicit zero
							// entries into the sparse output rows
							if(bv != 0)
								c.add(i, j, aval * bv);
						}
					}
				}
			}
		}
	}
}


/**
 * Matrix multiplication specialization for a dense left-hand side, a sparse
 * right-hand side, and a sparse output block. Scales the sparse rows of m2 by
 * the corresponding dense values of m1 and accumulates into the sparse output,
 * using cache blocking over output rows (i) and the common dimension (k).
 *
 * @param m1  dense left-hand side matrix
 * @param m2  sparse right-hand side matrix
 * @param ret output matrix block in sparse format
 * @param pm2 if true, the range [rl,ru) parallelizes over the common dimension
 *            (rows of m2) instead of the rows of m1
 * @param rl  lower bound (inclusive) of the parallelized range
 * @param ru  upper bound (exclusive) of the parallelized range
 */
private static void matrixMultDenseSparseOutSparse(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, boolean pm2,
	int rl, int ru) {
	final DenseBlock a = m1.getDenseBlock();
	final SparseBlock b = m2.getSparseBlock();
	final SparseBlock c = ret.getSparseBlock();
	final int m = m1.rlen; // rows of the left-hand side
	final int cd = m1.clen; // common dimension

	// iteration ranges depending on the parallelization scheme
	final int iBeg = pm2 ? 0 : rl;
	final int iEnd = pm2 ? m : ru;
	final int kBeg = pm2 ? rl : 0;
	final int kEnd = pm2 ? ru : cd;

	// cache-blocking sizes for rows (i) and the common dimension (k)
	final int blkI = 32;
	final int blkK = 32;

	for(int bi = iBeg; bi < iEnd; bi += blkI) {
		final int biMax = Math.min(iEnd, bi + blkI);
		for(int bk = kBeg; bk < kEnd; bk += blkK) {
			final int bkMax = Math.min(kEnd, bk + blkK);
			// core sub-block matrix multiplication
			for(int i = bi; i < biMax; i++) {
				final double[] avals = a.values(i);
				final int apos = a.pos(i);
				for(int k = bk; k < bkMax; k++) {
					final double aval = avals[apos + k];
					if(aval != 0 && !b.isEmpty(k)) {
						final int[] bix = b.indexes(k);
						final double[] bvals = b.values(k);
						final int bpos = b.pos(k);
						final int bend = bpos + b.size(k);
						for(int j = bpos; j < bend; j++)
							c.add(i, bix[j], aval * bvals[j]);
					}
				}
			}
		}
	}
}

private static void matrixMultDenseSparseOutDense(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, boolean pm2, int rl,
int ru) {
DenseBlock a = m1.getDenseBlock();
DenseBlock c = ret.getDenseBlock();
int m = m1.rlen;
Expand Down Expand Up @@ -1907,8 +2008,10 @@ private static void matrixMultUltraSparseRight(MatrixBlock m1, MatrixBlock m2, M
if(ret.isInSparseFormat()){
if(m1.isInSparseFormat())
matrixMultUltraSparseRightSparseMCSRLeftSparseOut(m1, m2, ret, rl, ru);
else
else if (m2.isInSparseFormat())
matrixMultUltraSparseRightDenseLeftSparseOut(m1, m2, ret, rl, ru);
else
matrixMultUltraSparseDenseInput(m1, m2, ret, rl, ru);
}
else if(ret.getDenseBlock().isContiguous())
matrixMultUltraSparseRightDenseOut(m1, m2, ret, rl, ru);
Expand Down Expand Up @@ -1990,6 +2093,30 @@ private static void matrixMultUltraSparseRightDenseLeftSparseOut(MatrixBlock m1,
}
}

/**
 * Matrix multiplication specialization for ultra-sparse output with dense-dense
 * inputs. Iterates the dense left-hand side row-wise, skipping its zeros
 * (expected to be the common case), and scatters scaled rows of m2 into the
 * MCSR sparse output.
 *
 * @param m1  dense left-hand side matrix (expected to contain many zeros)
 * @param m2  dense right-hand side matrix
 * @param ret output matrix block with an MCSR sparse block
 * @param rl  lower row bound (inclusive) of the output rows to compute
 * @param ru  upper row bound (exclusive) of the output rows to compute
 */
private static void matrixMultUltraSparseDenseInput(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int rl, int ru){
	final int cd = m1.clen; // common dimension
	final int rc = m2.clen; // result columns
	final DenseBlock a = m1.denseBlock;
	final DenseBlock b = m2.denseBlock;
	final SparseBlockMCSR c = (SparseBlockMCSR) ret.sparseBlock;

	for(int i = rl; i < ru; i++) {
		// it is known that the left matrix is most likely containing many zeros.
		final double[] av = a.values(i);
		final int pos = a.pos(i);
		for(int k = 0; k < cd; k++) {
			final double v = av[pos + k];
			if(v != 0) {
				final double[] bv = b.values(k);
				final int posb = b.pos(k);
				for(int j = 0; j < rc; j++) {
					final double bval = bv[posb + j];
					// guard against appending explicit zero entries
					// into the ultra-sparse output rows
					if(bval != 0)
						c.add(i, j, bval * v);
				}
			}
		}
	}
}

private static void mmDenseMatrixSparseRow(int bpos, int blen, int[] bixs, double[] bvals, int k, int i,
DenseBlock a, SparseBlockMCSR c) {
final double[] aval = a.values(i);
Expand Down Expand Up @@ -4419,6 +4546,8 @@ public static boolean isUltraSparseMatrixMult(MatrixBlock m1, MatrixBlock m2, bo
}

public static boolean isSparseOutputMatrixMult(MatrixBlock m1, MatrixBlock m2) {
if(m2.rlen == 1 && m2.nonZeros < m2.clen / 4) // vector right ... that is sparse.
return true;
//output is a matrix (not vector), very likely sparse, and output rows fit into L1 cache
if( !(m1.sparse && m2.sparse && m1.rlen > 1 && m2.clen > 1) )
return false;
Expand Down Expand Up @@ -4551,7 +4680,7 @@ private static class MatrixMultTask implements Callable<Object>
private final boolean _pm2r; //par over m2 rows
private final boolean _pm2c; //par over m2 rows
private final boolean _m1Perm; //sparse permutation
private final boolean _sparse; //sparse output
// private final boolean _sparse; //sparse output
private final int _rl;
private final int _ru;
private final ConcurrentHashMap<double[], double[]> _cache;
Expand All @@ -4565,7 +4694,7 @@ protected MatrixMultTask( MatrixBlock m1, MatrixBlock m2, MatrixBlock ret,
_pm2r = pm2r;
_pm2c = pm2c;
_m1Perm = m1Perm;
_sparse = sparse;
// _sparse = sparse;
_rl = rl;
_ru = ru;
_cache = cache;
Expand Down Expand Up @@ -4594,14 +4723,14 @@ public Object call() {
//compute block matrix multiplication
if( _ret.sparse ) //ultra-sparse
matrixMultUltraSparse(_m1, _m2, _ret, _m1Perm, rl, ru);
else if(!_m1.sparse && !_m2.sparse)
else if(!_m1.sparse && !_m2.sparse && !_ret.sparse){
if(_m1.denseBlock instanceof DenseBlockFP64DEDUP && _m2.denseBlock.isContiguous(0,_m1.clen) && cl == 0 && cu == _m2.clen)
matrixMultDenseDenseMMDedup((DenseBlockFP64DEDUP) _m1.denseBlock, _m2.denseBlock, (DenseBlockFP64DEDUP) _ret.denseBlock, _m2.clen, _m1.clen, rl, ru, _cache);
else
matrixMultDenseDense(_m1, _m2, _ret, _tm2, _pm2r, rl, ru, cl, cu);

}
else if(_m1.sparse && _m2.sparse)
matrixMultSparseSparse(_m1, _m2, _ret, _pm2r, _sparse, rl, ru);
matrixMultSparseSparse(_m1, _m2, _ret, _pm2r, _ret.sparse, rl, ru);
else if(_m1.sparse)
matrixMultSparseDense(_m1, _m2, _ret, _pm2r, rl, ru);
else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ public static MatrixBlock matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock
else
LOG.warn("Was valid for native MM but native lib was not loaded");

return LibMatrixMult.matrixMult(m1, m2, ret, k);
return LibMatrixMult.matrixMultNonNative(m1, m2, ret, k);
}

public static void tsmm(MatrixBlock m1, MatrixBlock ret, boolean leftTrans, int k) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4994,10 +4994,7 @@ public final MatrixBlock aggregateBinaryOperations(MatrixBlock m1, MatrixBlock m
public MatrixBlock aggregateBinaryOperations(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, AggregateBinaryOperator op) {
checkAggregateBinaryOperations(m1, m2, op);
final int k = op.getNumThreads();
if(NativeHelper.isNativeLibraryLoaded())
return LibMatrixNative.matrixMult(m1, m2, ret, k);
else
return LibMatrixMult.matrixMult(m1, m2, ret, k);
return LibMatrixMult.matrixMult(m1, m2, ret, k);
}

protected void checkAggregateBinaryOperations(MatrixBlock m1, MatrixBlock m2, AggregateBinaryOperator op) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,13 @@ public class MatrixMultiplyTest {
// parallelization degree
private final int k;

public MatrixMultiplyTest(int i, int j, int k, double s, double s2, int p) {
public MatrixMultiplyTest(int i, int j, int k, double s, double s2, int p, boolean self) {
try {
this.left = TestUtils.ceil(TestUtils.generateTestMatrixBlock(i, j, -10, 10, i == 1 && j == 1 ? 1 : s, 13));
this.right = TestUtils.ceil(TestUtils.generateTestMatrixBlock(j, k, -10, 10, k == 1 && k == 1 ? 1 : s2, 14));
if(self)
this.right = left;
else
this.right = TestUtils.ceil(TestUtils.generateTestMatrixBlock(j, k, -10, 10, k == 1 && k == 1 ? 1 : s2, 14));

this.exp = multiply(left, right, 1);
this.k = p;
Expand Down Expand Up @@ -83,23 +86,33 @@ public static Collection<Object[]> data() {
for(int i = 0; i < is.length; i++) {
for(int j = 0; j < js.length; j++) {
for(int k = 0; k < ks.length; k++) {
tests.add(new Object[] {is[i], js[j], ks[k], sparsities[s], sparsities[s2], par[p]});
tests.add(new Object[] {is[i], js[j], ks[k], sparsities[s], sparsities[s2], par[p], false});
}
}
}
}
}
}

tests.add(new Object[]{1000, 100, 1000, 0.3, 0.0001, 6});
tests.add(new Object[]{1000, 100, 1000, 0.01, 0.3, 6});
tests.add(new Object[]{1000, 100, 1000, 0.3, 0.0005, 6});
tests.add(new Object[]{1000, 100, 1000, 0.005, 0.3, 6});

tests.add(new Object[]{1000, 100, 1000, 0.6, 0.0001, 6});
tests.add(new Object[]{1000, 100, 1000, 0.01, 0.6, 6});
tests.add(new Object[]{1000, 100, 1000, 0.6, 0.0005, 6});
tests.add(new Object[]{1000, 100, 1000, 0.005, 0.6, 6});
tests.add(new Object[]{1000, 100, 1000, 0.3, 0.0001, 6, false});
tests.add(new Object[]{1000, 100, 1000, 0.01, 0.3, 6, false});
tests.add(new Object[]{1000, 100, 1000, 0.3, 0.0005, 6, false});
tests.add(new Object[]{1000, 100, 1000, 0.005, 0.3, 6, false});

tests.add(new Object[]{1000, 100, 1000, 0.6, 0.0001, 6, false});
tests.add(new Object[]{1000, 100, 1000, 0.01, 0.6, 6, false});
tests.add(new Object[]{1000, 100, 1000, 0.6, 0.0005, 6, false});
tests.add(new Object[]{1000, 100, 1000, 0.005, 0.6, 6, false});

// 0.00004 ultra sparse turn point
tests.add(new Object[]{100, 100, 10000, 0.5, 0.00003, 6, false});
tests.add(new Object[]{10000, 100, 100, 0.00003, 0.6, 6, false});


tests.add(new Object[]{3, 10, 100000, 1.0, 0.00003, 6, false});
tests.add(new Object[]{100000, 10, 3, 0.00003, 1.0, 6, false});

tests.add(new Object[]{1000, 1000, 1000, 0.005, 0.6, 6, true});

}
catch(Exception e) {
Expand Down

0 comments on commit fd1ba7c

Please sign in to comment.