diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java index b657b278830..7fa82184b77 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java @@ -279,6 +279,12 @@ public long getNumberNonZerosWithReference(int[] counts, double[] reference, int @Override public boolean containsValueWithReference(double pattern, double[] reference) { + if(Double.isNaN(pattern)){ + for(int i = 0 ; i < reference.length; i++) + if(Double.isNaN(reference[i])) + return true; + return containsValue(pattern); + } return getMBDict().containsValueWithReference(pattern, reference); } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java index ce563033c02..b59d7696ba7 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java @@ -726,6 +726,8 @@ public boolean containsValue(double pattern) { @Override public boolean containsValueWithReference(double pattern, double[] reference) { + if(Double.isNaN(pattern)) + return super.containsValueWithReference(pattern, reference); final int nCol = reference.length; for(int i = 0; i < _values.length; i++) if(_values[i] + reference[i % nCol] == pattern) @@ -913,46 +915,7 @@ public IDictionary replaceWithReference(double pattern, double replace, double[] final int nCol = reference.length; final int nRow = _values.length / nCol; if(Util.eq(pattern, Double.NaN)) { - Set colsWithNan = null; - for(int i = 0; i < reference.length; i++) { - if(Util.eq(reference[i], Double.NaN)) { - if(colsWithNan == null) - colsWithNan = new HashSet<>(); - colsWithNan.add(i); - reference[i] = replace; - } - } - - if(colsWithNan != null) { - final double[] retV = new double[_values.length]; - for(int i = 0; i < nRow; i++) { - final int off = i * reference.length; - for(int j = 0; j < nCol; j++) { - final int cell = off + j; - if(colsWithNan.contains(j)) - retV[cell] = 0; - else if(Util.eq(_values[cell], Double.NaN)) - retV[cell] = replace - reference[j]; - else - retV[cell] = _values[cell]; - } - } - return create(retV); - } - else { - final double[] retV = new double[_values.length]; - for(int i = 0; i < nRow; i++) { - final int off = i * reference.length; - for(int j = 0; j < nCol; j++) { - final int cell = off + j; - if(Util.eq(_values[cell], Double.NaN)) - retV[cell] = replace - reference[j]; - else - retV[cell] = _values[cell] ; - } - } - return create(retV); - } + return replaceWithReferenceNaN(replace, reference, nCol, nRow); } else { final double[] retV = new double[_values.length]; @@ -969,6 +932,62 @@ else if(Util.eq(_values[cell], Double.NaN)) } } + private IDictionary replaceWithReferenceNaN(double replace, double[] reference, final int nCol, final int nRow) { + final Set colsWithNan = getColsWithNan(replace, reference); + final double[] retV; + if(colsWithNan != null) { + if(colsWithNan.size() == nCol && replace == 0) + return null; + retV = new double[_values.length]; + replaceWithReferenceNanDenseWithNanCols(replace, reference, nRow, nCol, colsWithNan, _values, retV); + } + else { + retV = new double[_values.length]; + replaceWithReferenceNanDenseWithoutNanCols(replace, reference, nRow, nCol, retV, _values); + } + return create(retV); + } + + protected static Set getColsWithNan(double replace, double[] reference) { + Set colsWithNan = null; + for(int i = 0; i < reference.length; i++) { + if(Util.eq(reference[i], Double.NaN)) { + if(colsWithNan == null) + colsWithNan = new HashSet<>(); + colsWithNan.add(i); + reference[i] = replace; + } + } + return colsWithNan; + } + + protected static void replaceWithReferenceNanDenseWithoutNanCols(final double replace, final double[] reference, + final int nRow, final int nCol, final double[] retV, final double[] values) { + int off = 0; + for(int i = 0; i < nRow; i++) { + for(int j = 0; j < nCol; j++) { + final double v = values[off]; + retV[off++] = Util.eq(Double.NaN, v) ? replace - reference[j] : v; + } + } + } + + protected static void replaceWithReferenceNanDenseWithNanCols(final double replace, final double[] reference, + final int nRow, final int nCol, Set colsWithNan, final double[] values, final double[] retV) { + int off = 0; + for(int i = 0; i < nRow; i++) { + for(int j = 0; j < nCol; j++) { + final double v = values[off]; + if(colsWithNan.contains(j)) + retV[off++] = 0; + else if(Util.eq(v, Double.NaN)) + retV[off++] = replace - reference[j]; + else + retV[off++] = v; + } + } + } + @Override public void product(double[] ret, int[] counts, int nCol) { if(ret[0] == 0) @@ -1024,17 +1043,22 @@ public void productWithReference(double[] ret, int[] counts, double[] reference, if(ret[0] == 0) return; final MathContext cont = MathContext.DECIMAL128; - final int len = counts.length; + final int nRow = counts.length; final int nCol = reference.length; + BigDecimal tmp = BigDecimal.ONE; int off = 0; - for(int i = 0; i < len; i++) { + for(int i = 0; i < nRow; i++) { for(int j = 0; j < nCol; j++) { final double v = _values[off++] + reference[j]; if(v == 0) { ret[0] = 0; return; } + else if(!Double.isFinite(v)) { + ret[0] = v; + return; + } tmp = tmp.multiply(new BigDecimal(v).pow(counts[i], cont), cont); } } @@ -1044,6 +1068,7 @@ public void productWithReference(double[] ret, int[] counts, double[] reference, ret[0] = 0; else if(!Double.isInfinite(ret[0])) ret[0] = new BigDecimal(ret[0]).multiply(tmp, MathContext.DECIMAL128).doubleValue(); + } @Override @@ -1192,7 +1217,7 @@ public void TSMMToUpperTriangleSparseScaling(SparseBlock left, IColIndex rowsLef public boolean equals(IDictionary o) { if(o instanceof Dictionary) return Arrays.equals(_values, ((Dictionary) o)._values); - else if (o != null) + else if(o != null) return o.equals(this); return false; } @@ -1219,7 +1244,7 @@ public IDictionary reorder(int[] reorder) { return ret; } - @Override + @Override protected IDictionary rightMMPreAggSparseSelectedCols(int numVals, SparseBlock b, IColIndex thisCols, IColIndex aggregateColumns) { @@ -1264,7 +1289,7 @@ private void sparseAddSelected(int sPos, int sEnd, int aggColSize, IColIndex agg retIdx = 0; } - @Override + @Override protected IDictionary rightMMPreAggSparseAllColsRight(int numVals, SparseBlock b, IColIndex thisCols, int nColRight) { final int thisColsSize = thisCols.size(); @@ -1291,7 +1316,7 @@ protected IDictionary rightMMPreAggSparseAllColsRight(int numVals, SparseBlock b return Dictionary.create(ret); } - private void SparseAdd(int sPos, int sEnd, double[] ret, int offOut, int[] sIdx, double[] sVals, double v) { + private void SparseAdd(int sPos, int sEnd, double[] ret, int offOut, int[] sIdx, double[] sVals, double v) { if(v != 0) { for(int k = sPos; k < sEnd; k++) { // cols right with value ret[offOut + sIdx[k]] += v * sVals[k]; @@ -1299,7 +1324,6 @@ private void SparseAdd(int sPos, int sEnd, double[] ret, int offOut, int[] sIdx } } - @Override public IDictionary append(double[] row) { double[] retV = new double[_values.length + row.length]; @@ -1308,5 +1332,4 @@ public IDictionary append(double[] row) { return new Dictionary(retV); } - } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java index a4baecb674e..7fbfbdc17fd 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java @@ -25,7 +25,6 @@ import java.math.BigDecimal; import java.math.MathContext; import java.util.Arrays; -import java.util.HashSet; import java.util.Set; import org.apache.commons.lang3.NotImplementedException; @@ -36,6 +35,7 @@ import org.apache.sysds.runtime.compress.colgroup.indexes.SingleIndex; import org.apache.sysds.runtime.compress.colgroup.indexes.TwoIndex; import org.apache.sysds.runtime.compress.utils.Util; +import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.DenseBlockFP64; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.data.SparseBlockCSR; @@ -1495,6 +1495,8 @@ public boolean containsValue(double pattern) { @Override public boolean containsValueWithReference(double pattern, double[] reference) { + if(Double.isNaN(pattern)) + return super.containsValueWithReference(pattern, reference); if(_data.isInSparseFormat()) { final SparseBlock sb = _data.getSparseBlock(); for(int i = 0; i < _data.getNumRows(); i++) { @@ -2059,9 +2061,8 @@ public IDictionary replace(double pattern, double replace, int nCol) { @Override public IDictionary replaceWithReference(double pattern, double replace, double[] reference) { - if(Util.eq(pattern, Double.NaN)) { + if(Util.eq(pattern, Double.NaN)) return replaceWithReferenceNan(replace, reference); - } final int nRow = _data.getNumRows(); final int nCol = _data.getNumColumns(); @@ -2108,27 +2109,19 @@ public IDictionary replaceWithReference(double pattern, double replace, double[] } private IDictionary replaceWithReferenceNan(double replace, double[] reference) { - + final Set colsWithNan = Dictionary.getColsWithNan(replace, reference); final int nRow = _data.getNumRows(); final int nCol = _data.getNumColumns(); + if(colsWithNan != null && colsWithNan.size() == nCol && replace == 0) + return null; + final MatrixBlock ret = new MatrixBlock(nRow, nCol, false); ret.allocateDenseBlock(); - - Set colsWithNan = null; - for(int i = 0; i < reference.length; i++) { - if(Util.eq(reference[i], Double.NaN)) { - if(colsWithNan == null) - colsWithNan = new HashSet<>(); - colsWithNan.add(i); - reference[i] = replace; - } - } + final double[] retV = ret.getDenseBlockValues(); if(colsWithNan == null) { - - final double[] retV = ret.getDenseBlockValues(); - int off = 0; if(_data.isInSparseFormat()) { + final DenseBlock db = ret.getDenseBlock(); final SparseBlock sb = _data.getSparseBlock(); for(int i = 0; i < nRow; i++) { if(sb.isEmpty(i)) @@ -2137,30 +2130,22 @@ private IDictionary replaceWithReferenceNan(double replace, double[] reference) final int apos = sb.pos(i); final int alen = sb.size(i) + apos; final double[] avals = sb.values(i); + final int[] aix = sb.indexes(i); int j = 0; + int off = db.pos(i); for(int k = apos; k < alen; k++) { final double v = avals[k]; - retV[off++] = Util.eq(Double.NaN, v) ? replace - reference[j] : v; + retV[off + aix[k]] = Util.eq(Double.NaN, v) ? replace - reference[j] : v; } } } else { final double[] values = _data.getDenseBlockValues(); - for(int i = 0; i < nRow; i++) { - for(int j = 0; j < nCol; j++) { - final double v = values[off]; - retV[off++] = Util.eq(Double.NaN, v) ? replace - reference[j] : v; - } - } + Dictionary.replaceWithReferenceNanDenseWithoutNanCols(replace, reference, nRow, nCol, retV, values); } - ret.recomputeNonZeros(); - ret.examSparsity(); - return MatrixBlockDictionary.create(ret); } else { - - final double[] retV = ret.getDenseBlockValues(); if(_data.isInSparseFormat()) { final SparseBlock sb = _data.getSparseBlock(); for(int i = 0; i < nRow; i++) { @@ -2170,10 +2155,10 @@ private IDictionary replaceWithReferenceNan(double replace, double[] reference) final int apos = sb.pos(i); final int alen = sb.size(i) + apos; final double[] avals = sb.values(i); - final int[] aidx = sb.indexes(i); + final int[] aix = sb.indexes(i); for(int k = apos; k < alen; k++) { - final int c = aidx[k]; - final int outIdx = off + aidx[k]; + final int c = aix[k]; + final int outIdx = off + aix[k]; final double v = avals[k]; if(colsWithNan.contains(c)) retV[outIdx] = 0; @@ -2185,27 +2170,16 @@ else if(Util.eq(v, Double.NaN)) } } else { - int off = 0; final double[] values = _data.getDenseBlockValues(); - for(int i = 0; i < nRow; i++) { - for(int j = 0; j < nCol; j++) { - final double v = values[off]; - if(colsWithNan.contains(j)) - retV[off++] = 0; - else if(Util.eq(v, Double.NaN)) - retV[off++] = replace - reference[j]; - else - retV[off++] = v; - } - } + Dictionary.replaceWithReferenceNanDenseWithNanCols(replace, reference, nRow, nCol, colsWithNan, values, + retV); } - - ret.recomputeNonZeros(); - ret.examSparsity(); - return MatrixBlockDictionary.create(ret); } + ret.recomputeNonZeros(); + ret.examSparsity(); + return MatrixBlockDictionary.create(ret); } @Override @@ -2277,6 +2251,7 @@ public void productWithReference(double[] ret, int[] counts, double[] reference, } else values = _data.getDenseBlockValues(); + BigDecimal tmp = BigDecimal.ONE; int off = 0; for(int i = 0; i < nRow; i++) { @@ -2286,6 +2261,10 @@ public void productWithReference(double[] ret, int[] counts, double[] reference, ret[0] = 0; return; } + else if(!Double.isFinite(v)) { + ret[0] = v; + return; + } tmp = tmp.multiply(new BigDecimal(v).pow(counts[i], cont), cont); } } @@ -2294,7 +2273,8 @@ public void productWithReference(double[] ret, int[] counts, double[] reference, if(Math.abs(tmp.doubleValue()) == 0) ret[0] = 0; else if(!Double.isInfinite(ret[0])) - ret[0] = new BigDecimal(ret[0]).multiply(tmp, MathContext.DECIMAL128).doubleValue(); + ret[0] = new BigDecimal(ret[0]).multiply(tmp, cont).doubleValue(); + } @Override diff --git a/src/test/java/org/apache/sysds/test/component/compress/dictionary/DictionaryTests.java b/src/test/java/org/apache/sysds/test/component/compress/dictionary/DictionaryTests.java index 3803fe23afd..3b470e4f45c 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/dictionary/DictionaryTests.java +++ b/src/test/java/org/apache/sysds/test/component/compress/dictionary/DictionaryTests.java @@ -106,6 +106,7 @@ public static Collection data() { addSparse(tests, -10, 10, 10, 100, 0.1, 321); addSparse(tests, -10, 10, 2, 100, 0.04, 321); + addSparseWithNan(tests, 1, 10, 100, 100, 0.1, 321); tests.add(new Object[] {IdentityDictionary.create(2), Dictionary.create(new double[] {1, 0, 0, 1}), 2, 2}); tests.add(new Object[] {IdentityDictionary.create(2, true), // @@ -320,6 +321,20 @@ private static void addSparse(List tests, double min, double max, int tests.add(new Object[] {MatrixBlockDictionary.create(mb), Dictionary.create(dbv), rows, cols}); } + private static void addSparseWithNan(List tests, double min, double max, int rows, int cols, + double sparsity, int seed) { + + MatrixBlock mb = TestUtils.generateTestMatrixBlock(rows, cols, min, max, sparsity, seed); + + mb = TestUtils.floor(mb); + mb = mb.replaceOperations(null, min, Double.NaN); + MatrixBlock mb2 = new MatrixBlock(); + mb2.copy(mb); + mb2.sparseToDense(); + double[] dbv = mb2.getDenseBlockValues(); + tests.add(new Object[] {MatrixBlockDictionary.create(mb), Dictionary.create(dbv), rows, cols}); + } + @Test public void sum() { int[] counts = getCounts(nRow, 1324); @@ -422,19 +437,26 @@ public void productWithDoctoredReference2() { } public void productWithReference(double retV, double[] reference) { - // Shared - final int[] counts = getCounts(nRow, 1324); + try { - // A - final double[] aRet = new double[] {retV}; - a.productWithReference(aRet, counts, reference, nCol); + // Shared + final int[] counts = getCounts(nRow, 1324); - // B - final double[] bRet = new double[] {retV}; - b.productWithReference(bRet, counts, reference, nCol); + // A + final double[] aRet = new double[] {retV}; + a.productWithReference(aRet, counts, reference, nCol); - TestUtils.compareMatricesBitAvgDistance(// - aRet, bRet, 10, 10, "Not Equivalent values from product"); + // B + final double[] bRet = new double[] {retV}; + b.productWithReference(bRet, counts, reference, nCol); + + TestUtils.compareMatricesBitAvgDistance(// + aRet, bRet, 10, 10, "Not Equivalent values from product"); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } } @Test @@ -489,6 +511,119 @@ public void replaceWitReference() { assertNotEquals(before, aRep.getValue(r, c, nCol), 0.00001); } + @Test + public void replaceNaN() { + IDictionary ar = a.replace(Double.NaN, 0, nCol); + IDictionary br = b.replace(Double.NaN, 0, nCol); + compare(ar, br, nCol); + } + + @Test + public void replaceNaNWithRef() { + double[] ref1 = new double[nCol]; + IDictionary ar = a.replaceWithReference(Double.NaN, 1, ref1); + double[] ref2 = new double[nCol]; + IDictionary br = b.replaceWithReference(Double.NaN, 1, ref2); + compare(ar, br, nCol); + } + + @Test + public void replaceNaNWithRef12() { + double[] ref1 = new double[nCol]; + Arrays.fill(ref1, 1.2); + IDictionary ar = a.replaceWithReference(Double.NaN, 1, ref1); + double[] ref2 = new double[nCol]; + Arrays.fill(ref2, 1.2); + IDictionary br = b.replaceWithReference(Double.NaN, 1, ref2); + compare(ar, br, nCol); + } + + @Test + public void replaceNaNWithRefNaN() { + double[] ref1 = new double[nCol]; + ref1[0] = Double.NaN; + IDictionary ar = a.replaceWithReference(Double.NaN, 1, ref1); + double[] ref2 = new double[nCol]; + ref2[0] = Double.NaN; + IDictionary br = b.replaceWithReference(Double.NaN, 1, ref2); + compare(ar, br, nCol); + } + + @Test + public void replaceNaNWithRefNaN12() { + try { + + double[] ref1 = new double[nCol]; + Arrays.fill(ref1, 1.2); + ref1[0] = Double.NaN; + IDictionary ar = a.replaceWithReference(Double.NaN, 1, ref1); + double[] ref2 = new double[nCol]; + Arrays.fill(ref2, 1.2); + ref2[0] = Double.NaN; + IDictionary br = b.replaceWithReference(Double.NaN, 1, ref2); + compare(ar, br, nCol); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void replaceNaNWithRefNaN12ColGT2() { + try { + if(nCol > 2) { + double[] ref1 = new double[nCol]; + Arrays.fill(ref1, 1.2); + ref1[1] = Double.NaN; + IDictionary ar = a.replaceWithReference(Double.NaN, 1, ref1); + double[] ref2 = new double[nCol]; + Arrays.fill(ref2, 1.2); + ref2[1] = Double.NaN; + IDictionary br = b.replaceWithReference(Double.NaN, 1, ref2); + compare(ar, br, nCol); + } + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void replaceNaNWithRefNaNAllRefNaN() { + try { + double[] ref1 = new double[nCol]; + Arrays.fill(ref1, Double.NaN); + IDictionary ar = a.replaceWithReference(Double.NaN, 1, ref1); + double[] ref2 = new double[nCol]; + Arrays.fill(ref2, Double.NaN); + IDictionary br = b.replaceWithReference(Double.NaN, 1, ref2); + compare(ar, br, nCol); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void replaceNaNWithRefNaNAllRefNaNToZero() { + try { + double[] ref1 = new double[nCol]; + Arrays.fill(ref1, Double.NaN); + IDictionary ar = a.replaceWithReference(Double.NaN, 0, ref1); + double[] ref2 = new double[nCol]; + Arrays.fill(ref2, Double.NaN); + IDictionary br = b.replaceWithReference(Double.NaN, 0, ref2); + compare(ar, br, nCol); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + @Test public void rexpandCols() { if(nCol == 1) { @@ -659,7 +794,7 @@ public void equalsElOp() { public void opRightMinus() { BinaryOperator op = new BinaryOperator(Minus.getMinusFnObject()); double[] vals = TestUtils.generateTestVector(nCol, -1, 1, 1.0, 132L); - opRight(op, vals, ColIndexFactory.create(0, nCol)); + binOp(op, vals, ColIndexFactory.create(0, nCol)); } @Test @@ -673,7 +808,7 @@ public void opRightMinusNoCol() { public void opRightMinusZero() { BinaryOperator op = new BinaryOperator(Minus.getMinusFnObject()); double[] vals = new double[nCol]; - opRight(op, vals, ColIndexFactory.create(0, nCol)); + binOp(op, vals, ColIndexFactory.create(0, nCol)); } @Test @@ -681,22 +816,37 @@ public void opRightDivOne() { BinaryOperator op = new BinaryOperator(Divide.getDivideFnObject()); double[] vals = new double[nCol]; Arrays.fill(vals, 1); - opRight(op, vals, ColIndexFactory.create(0, nCol)); + binOp(op, vals, ColIndexFactory.create(0, nCol)); } @Test public void opRightDiv() { BinaryOperator op = new BinaryOperator(Divide.getDivideFnObject()); double[] vals = TestUtils.generateTestVector(nCol, -1, 1, 1.0, 232L); - opRight(op, vals, ColIndexFactory.create(0, nCol)); + binOp(op, vals, ColIndexFactory.create(0, nCol)); } - private void opRight(BinaryOperator op, double[] vals, IColIndex cols) { + private void binOp(BinaryOperator op, double[] vals, IColIndex cols) { try { IDictionary aa = a.binOpRight(op, vals, cols); IDictionary bb = b.binOpRight(op, vals, cols); compare(aa, bb, nRow, nCol); + + double[] ref = TestUtils.generateTestVector(nCol, 0, 10, 1.0, 33); + double[] newRef = TestUtils.generateTestVector(nCol, 0, 10, 1.0, 321); + aa = a.binOpRightWithReference(op, vals, cols, ref, newRef); + bb = b.binOpRightWithReference(op, vals, cols, ref, newRef); + compare(aa, bb, nRow, nCol); + + aa = a.binOpLeftWithReference(op, vals, cols, ref, newRef); + bb = b.binOpLeftWithReference(op, vals, cols, ref, newRef); + compare(aa, bb, nRow, nCol); + + aa = a.binOpLeft(op, vals, cols); + bb = b.binOpLeft(op, vals, cols); + compare(aa, bb, nRow, nCol); + } catch(Exception e) { e.printStackTrace(); @@ -932,8 +1082,17 @@ public void containsValueWithReference(double value, double[] reference) { } private static void compare(IDictionary a, IDictionary b, int nCol) { - assertEquals(a.getNumberOfValues(nCol), b.getNumberOfValues(nCol)); - compare(a, b, a.getNumberOfValues(nCol), nCol); + try { + + if(a == null && b == null) { + return; // all good. + } + assertEquals(a.getNumberOfValues(nCol), b.getNumberOfValues(nCol)); + compare(a, b, a.getNumberOfValues(nCol), nCol); + } + catch(NullPointerException e) { + fail("both outputs are not null: " + a + " vs " + b); + } } protected static void compare(IDictionary a, IDictionary b, int nRow, int nCol) { @@ -1217,6 +1376,19 @@ public void sumAllRowsToDoubleSqWithReference() { TestUtils.compareMatrices(aa, bb, 0.001); } + @Test + public void sumAllColsSqWithReference() { + double[] def = TestUtils.generateTestVector(nCol, 1, 10, 1.0, 3215213); + final int[] counts = getCounts(nRow, 1324); + + double[] aa = new double[nCol]; + double[] bb = new double[nCol]; + + a.colSumSqWithReference(aa, counts, ColIndexFactory.create(nCol), def); + b.colSumSqWithReference(bb, counts, ColIndexFactory.create(nCol), def); + TestUtils.compareMatrices(aa, bb, 0.001); + } + @Test public void aggColsMin() { IColIndex cols = ColIndexFactory.create(2, nCol + 2); @@ -1447,7 +1619,21 @@ public void getNNzCounts() { } long annz = a.getNumberNonZeros(counts, nCol); long bnnz = b.getNumberNonZeros(counts, nCol); + + long annzR = a.getNumberNonZerosWithReference(counts, new double[nCol], nRow); + long bnnzR = a.getNumberNonZerosWithReference(counts, new double[nCol], nRow); assertEquals(annz, bnnz); + assertEquals(annzR, bnnz); + assertEquals(annzR, bnnzR); + } + + @Test + public void getNNzCountsWithRef() { + int counts[] = getCounts(nRow, 231); + double[] ref = TestUtils.generateTestVector(nCol, -1, -1, 0.5, 23); + long annzR = a.getNumberNonZerosWithReference(counts, ref, nRow); + long bnnzR = a.getNumberNonZerosWithReference(counts, ref, nRow); + assertEquals(annzR, bnnzR); } @Test