Skip to content

Commit

Permalink
more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Baunsgaard committed Jan 26, 2025
1 parent fc442b9 commit 5124ca5
Show file tree
Hide file tree
Showing 4 changed files with 309 additions and 114 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,12 @@ public long getNumberNonZerosWithReference(int[] counts, double[] reference, int

@Override
public boolean containsValueWithReference(double pattern, double[] reference) {
if(Double.isNaN(pattern)){
for(int i = 0 ; i < reference.length; i++)
if(Double.isNaN(reference[i]))
return true;
return containsValue(pattern);
}
return getMBDict().containsValueWithReference(pattern, reference);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -726,6 +726,8 @@ public boolean containsValue(double pattern) {

@Override
public boolean containsValueWithReference(double pattern, double[] reference) {
if(Double.isNaN(pattern))
return super.containsValueWithReference(pattern, reference);
final int nCol = reference.length;
for(int i = 0; i < _values.length; i++)
if(_values[i] + reference[i % nCol] == pattern)
Expand Down Expand Up @@ -913,46 +915,7 @@ public IDictionary replaceWithReference(double pattern, double replace, double[]
final int nCol = reference.length;
final int nRow = _values.length / nCol;
if(Util.eq(pattern, Double.NaN)) {
Set<Integer> colsWithNan = null;
for(int i = 0; i < reference.length; i++) {
if(Util.eq(reference[i], Double.NaN)) {
if(colsWithNan == null)
colsWithNan = new HashSet<>();
colsWithNan.add(i);
reference[i] = replace;
}
}

if(colsWithNan != null) {
final double[] retV = new double[_values.length];
for(int i = 0; i < nRow; i++) {
final int off = i * reference.length;
for(int j = 0; j < nCol; j++) {
final int cell = off + j;
if(colsWithNan.contains(j))
retV[cell] = 0;
else if(Util.eq(_values[cell], Double.NaN))
retV[cell] = replace - reference[j];
else
retV[cell] = _values[cell];
}
}
return create(retV);
}
else {
final double[] retV = new double[_values.length];
for(int i = 0; i < nRow; i++) {
final int off = i * reference.length;
for(int j = 0; j < nCol; j++) {
final int cell = off + j;
if(Util.eq(_values[cell], Double.NaN))
retV[cell] = replace - reference[j];
else
retV[cell] = _values[cell] ;
}
}
return create(retV);
}
return replaceWithReferenceNaN(replace, reference, nCol, nRow);
}
else {
final double[] retV = new double[_values.length];
Expand All @@ -969,6 +932,62 @@ else if(Util.eq(_values[cell], Double.NaN))
}
}

private IDictionary replaceWithReferenceNaN(double replace, double[] reference, final int nCol, final int nRow) {
final Set<Integer> colsWithNan = getColsWithNan(replace, reference);
final double[] retV;
if(colsWithNan != null) {
if(colsWithNan.size() == nCol && replace == 0)
return null;
retV = new double[_values.length];
replaceWithReferenceNanDenseWithNanCols(replace, reference, nRow, nCol, colsWithNan, _values, retV);
}
else {
retV = new double[_values.length];
replaceWithReferenceNanDenseWithoutNanCols(replace, reference, nRow, nCol, retV, _values);
}
return create(retV);
}

protected static Set<Integer> getColsWithNan(double replace, double[] reference) {
Set<Integer> colsWithNan = null;
for(int i = 0; i < reference.length; i++) {
if(Util.eq(reference[i], Double.NaN)) {
if(colsWithNan == null)
colsWithNan = new HashSet<>();
colsWithNan.add(i);
reference[i] = replace;
}
}
return colsWithNan;
}

protected static void replaceWithReferenceNanDenseWithoutNanCols(final double replace, final double[] reference,
final int nRow, final int nCol, final double[] retV, final double[] values) {
int off = 0;
for(int i = 0; i < nRow; i++) {
for(int j = 0; j < nCol; j++) {
final double v = values[off];
retV[off++] = Util.eq(Double.NaN, v) ? replace - reference[j] : v;
}
}
}

protected static void replaceWithReferenceNanDenseWithNanCols(final double replace, final double[] reference,
final int nRow, final int nCol, Set<Integer> colsWithNan, final double[] values, final double[] retV) {
int off = 0;
for(int i = 0; i < nRow; i++) {
for(int j = 0; j < nCol; j++) {
final double v = values[off];
if(colsWithNan.contains(j))
retV[off++] = 0;
else if(Util.eq(v, Double.NaN))
retV[off++] = replace - reference[j];
else
retV[off++] = v;
}
}
}

@Override
public void product(double[] ret, int[] counts, int nCol) {
if(ret[0] == 0)
Expand Down Expand Up @@ -1024,17 +1043,22 @@ public void productWithReference(double[] ret, int[] counts, double[] reference,
if(ret[0] == 0)
return;
final MathContext cont = MathContext.DECIMAL128;
final int len = counts.length;
final int nRow = counts.length;
final int nCol = reference.length;

BigDecimal tmp = BigDecimal.ONE;
int off = 0;
for(int i = 0; i < len; i++) {
for(int i = 0; i < nRow; i++) {
for(int j = 0; j < nCol; j++) {
final double v = _values[off++] + reference[j];
if(v == 0) {
ret[0] = 0;
return;
}
else if(!Double.isFinite(v)) {
ret[0] = v;
return;
}
tmp = tmp.multiply(new BigDecimal(v).pow(counts[i], cont), cont);
}
}
Expand All @@ -1044,6 +1068,7 @@ public void productWithReference(double[] ret, int[] counts, double[] reference,
ret[0] = 0;
else if(!Double.isInfinite(ret[0]))
ret[0] = new BigDecimal(ret[0]).multiply(tmp, MathContext.DECIMAL128).doubleValue();

}

@Override
Expand Down Expand Up @@ -1192,7 +1217,7 @@ public void TSMMToUpperTriangleSparseScaling(SparseBlock left, IColIndex rowsLef
public boolean equals(IDictionary o) {
if(o instanceof Dictionary)
return Arrays.equals(_values, ((Dictionary) o)._values);
else if (o != null)
else if(o != null)
return o.equals(this);
return false;
}
Expand All @@ -1219,7 +1244,7 @@ public IDictionary reorder(int[] reorder) {
return ret;
}

@Override
@Override
protected IDictionary rightMMPreAggSparseSelectedCols(int numVals, SparseBlock b, IColIndex thisCols,
IColIndex aggregateColumns) {

Expand Down Expand Up @@ -1264,7 +1289,7 @@ private void sparseAddSelected(int sPos, int sEnd, int aggColSize, IColIndex agg
retIdx = 0;
}

@Override
@Override
protected IDictionary rightMMPreAggSparseAllColsRight(int numVals, SparseBlock b, IColIndex thisCols,
int nColRight) {
final int thisColsSize = thisCols.size();
Expand All @@ -1291,15 +1316,14 @@ protected IDictionary rightMMPreAggSparseAllColsRight(int numVals, SparseBlock b
return Dictionary.create(ret);
}

private void SparseAdd(int sPos, int sEnd, double[] ret, int offOut, int[] sIdx, double[] sVals, double v) {
private void SparseAdd(int sPos, int sEnd, double[] ret, int offOut, int[] sIdx, double[] sVals, double v) {
if(v != 0) {
for(int k = sPos; k < sEnd; k++) { // cols right with value
ret[offOut + sIdx[k]] += v * sVals[k];
}
}
}


@Override
public IDictionary append(double[] row) {
double[] retV = new double[_values.length + row.length];
Expand All @@ -1308,5 +1332,4 @@ public IDictionary append(double[] row) {
return new Dictionary(retV);
}


}
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import java.math.BigDecimal;
import java.math.MathContext;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;

import org.apache.commons.lang3.NotImplementedException;
Expand All @@ -36,6 +35,7 @@
import org.apache.sysds.runtime.compress.colgroup.indexes.SingleIndex;
import org.apache.sysds.runtime.compress.colgroup.indexes.TwoIndex;
import org.apache.sysds.runtime.compress.utils.Util;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.DenseBlockFP64;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.data.SparseBlockCSR;
Expand Down Expand Up @@ -1495,6 +1495,8 @@ public boolean containsValue(double pattern) {

@Override
public boolean containsValueWithReference(double pattern, double[] reference) {
if(Double.isNaN(pattern))
return super.containsValueWithReference(pattern, reference);
if(_data.isInSparseFormat()) {
final SparseBlock sb = _data.getSparseBlock();
for(int i = 0; i < _data.getNumRows(); i++) {
Expand Down Expand Up @@ -2059,9 +2061,8 @@ public IDictionary replace(double pattern, double replace, int nCol) {

@Override
public IDictionary replaceWithReference(double pattern, double replace, double[] reference) {
if(Util.eq(pattern, Double.NaN)) {
if(Util.eq(pattern, Double.NaN))
return replaceWithReferenceNan(replace, reference);
}

final int nRow = _data.getNumRows();
final int nCol = _data.getNumColumns();
Expand Down Expand Up @@ -2108,27 +2109,19 @@ public IDictionary replaceWithReference(double pattern, double replace, double[]
}

private IDictionary replaceWithReferenceNan(double replace, double[] reference) {

final Set<Integer> colsWithNan = Dictionary.getColsWithNan(replace, reference);
final int nRow = _data.getNumRows();
final int nCol = _data.getNumColumns();
if(colsWithNan != null && colsWithNan.size() == nCol && replace == 0)
return null;

final MatrixBlock ret = new MatrixBlock(nRow, nCol, false);
ret.allocateDenseBlock();

Set<Integer> colsWithNan = null;
for(int i = 0; i < reference.length; i++) {
if(Util.eq(reference[i], Double.NaN)) {
if(colsWithNan == null)
colsWithNan = new HashSet<>();
colsWithNan.add(i);
reference[i] = replace;
}
}
final double[] retV = ret.getDenseBlockValues();

if(colsWithNan == null) {

final double[] retV = ret.getDenseBlockValues();
int off = 0;
if(_data.isInSparseFormat()) {
final DenseBlock db = ret.getDenseBlock();
final SparseBlock sb = _data.getSparseBlock();
for(int i = 0; i < nRow; i++) {
if(sb.isEmpty(i))
Expand All @@ -2137,30 +2130,22 @@ private IDictionary replaceWithReferenceNan(double replace, double[] reference)
final int apos = sb.pos(i);
final int alen = sb.size(i) + apos;
final double[] avals = sb.values(i);
final int[] aix = sb.indexes(i);
int j = 0;
int off = db.pos(i);
for(int k = apos; k < alen; k++) {
final double v = avals[k];
retV[off++] = Util.eq(Double.NaN, v) ? replace - reference[j] : v;
retV[off + aix[k]] = Util.eq(Double.NaN, v) ? replace - reference[j] : v;
}
}
}
else {
final double[] values = _data.getDenseBlockValues();
for(int i = 0; i < nRow; i++) {
for(int j = 0; j < nCol; j++) {
final double v = values[off];
retV[off++] = Util.eq(Double.NaN, v) ? replace - reference[j] : v;
}
}
Dictionary.replaceWithReferenceNanDenseWithoutNanCols(replace, reference, nRow, nCol, retV, values);
}

ret.recomputeNonZeros();
ret.examSparsity();
return MatrixBlockDictionary.create(ret);
}
else {

final double[] retV = ret.getDenseBlockValues();
if(_data.isInSparseFormat()) {
final SparseBlock sb = _data.getSparseBlock();
for(int i = 0; i < nRow; i++) {
Expand All @@ -2170,10 +2155,10 @@ private IDictionary replaceWithReferenceNan(double replace, double[] reference)
final int apos = sb.pos(i);
final int alen = sb.size(i) + apos;
final double[] avals = sb.values(i);
final int[] aidx = sb.indexes(i);
final int[] aix = sb.indexes(i);
for(int k = apos; k < alen; k++) {
final int c = aidx[k];
final int outIdx = off + aidx[k];
final int c = aix[k];
final int outIdx = off + aix[k];
final double v = avals[k];
if(colsWithNan.contains(c))
retV[outIdx] = 0;
Expand All @@ -2185,27 +2170,16 @@ else if(Util.eq(v, Double.NaN))
}
}
else {
int off = 0;
final double[] values = _data.getDenseBlockValues();
for(int i = 0; i < nRow; i++) {
for(int j = 0; j < nCol; j++) {
final double v = values[off];

if(colsWithNan.contains(j))
retV[off++] = 0;
else if(Util.eq(v, Double.NaN))
retV[off++] = replace - reference[j];
else
retV[off++] = v;
}
}
Dictionary.replaceWithReferenceNanDenseWithNanCols(replace, reference, nRow, nCol, colsWithNan, values,
retV);
}

ret.recomputeNonZeros();
ret.examSparsity();
return MatrixBlockDictionary.create(ret);
}

ret.recomputeNonZeros();
ret.examSparsity();
return MatrixBlockDictionary.create(ret);
}

@Override
Expand Down Expand Up @@ -2277,6 +2251,7 @@ public void productWithReference(double[] ret, int[] counts, double[] reference,
}
else
values = _data.getDenseBlockValues();

BigDecimal tmp = BigDecimal.ONE;
int off = 0;
for(int i = 0; i < nRow; i++) {
Expand All @@ -2286,6 +2261,10 @@ public void productWithReference(double[] ret, int[] counts, double[] reference,
ret[0] = 0;
return;
}
else if(!Double.isFinite(v)) {
ret[0] = v;
return;
}
tmp = tmp.multiply(new BigDecimal(v).pow(counts[i], cont), cont);
}
}
Expand All @@ -2294,7 +2273,8 @@ public void productWithReference(double[] ret, int[] counts, double[] reference,
if(Math.abs(tmp.doubleValue()) == 0)
ret[0] = 0;
else if(!Double.isInfinite(ret[0]))
ret[0] = new BigDecimal(ret[0]).multiply(tmp, MathContext.DECIMAL128).doubleValue();
ret[0] = new BigDecimal(ret[0]).multiply(tmp, cont).doubleValue();

}

@Override
Expand Down
Loading

0 comments on commit 5124ca5

Please sign in to comment.