Skip to content

Commit

Permalink
Merge pull request #1553 from veg/develop
Browse files Browse the repository at this point in the history
CMake fixes and a few more speed tweaks.
  • Loading branch information
spond authored Jan 3, 2023
2 parents 5658b3f + 8431caa commit d3bda6c
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 197 deletions.
6 changes: 3 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ macro(PCL_CHECK_FOR_SSE4)
set(CMAKE_REQUIRED_FLAGS)

#if(CMAKE_COMPILER_IS_GNUCC OR CMAKE_COMPILER_IS_GNUCXX OR CMAKE_COMPILER_IS_CLANG)
set(CMAKE_REQUIRED_FLAGS "-msse4")
set(CMAKE_REQUIRED_FLAGS "-msse4.1")
#endif(CMAKE_COMPILER_IS_GNUCC OR CMAKE_COMPILER_IS_GNUCXX OR CMAKE_COMPILER_IS_CLANG)

check_cxx_source_runs("
Expand Down Expand Up @@ -230,7 +230,7 @@ if (CMAKE_CXX_COMPILER_ID MATCHES "Clang")
PCL_CHECK_FOR_SSE4()
if(${HAVE_SSE4_EXTENSIONS})
add_definitions (-D_SLKP_USE_SSE_INTRINSICS)
set(DEFAULT_COMPILE_FLAGS "${DEFAULT_COMPILE_FLAGS} -msse3 ")
set(DEFAULT_COMPILE_FLAGS "${DEFAULT_COMPILE_FLAGS} -msse4.1 ")
endif(${HAVE_SSE4_EXTENSIONS})
endif(NOSSE4)
else(NOAVX)
Expand All @@ -247,7 +247,7 @@ if (CMAKE_CXX_COMPILER_ID MATCHES "Clang")
PCL_CHECK_FOR_SSE4()
if(${HAVE_SSE4_EXTENSIONS})
add_definitions (-D_SLKP_USE_SSE_INTRINSICS)
set(DEFAULT_COMPILE_FLAGS "${DEFAULT_COMPILE_FLAGS} -msse3 ")
set(DEFAULT_COMPILE_FLAGS "${DEFAULT_COMPILE_FLAGS} -msse4.1 ")
endif(${HAVE_SSE4_EXTENSIONS})
endif (${HAVE_AVX_EXTENSIONS})
endif(NOAVX)
Expand Down
316 changes: 122 additions & 194 deletions src/core/matrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2740,14 +2740,9 @@ HBLObjectRef _Matrix::Evaluate (bool replace)
ci = k - ri * vDim;

result.Store (ri,ci,formValue->Value());
//result[HashBack(i)] = formValue->Value();
//if (ri < 0 || ri >= hDim) {
// abort();
//}
if (ci == ri) diag_skip[ri] = true;
} else {
result[k] = 0;
//result[HashBack(i)] = 0.;
}
}
}
Expand Down Expand Up @@ -2775,29 +2770,6 @@ HBLObjectRef _Matrix::Evaluate (bool replace)
result.Store (i,i,diag_storage[i]);
}
}

/*for (long i = 0; i<hDim; i++) {
long k = Hash(i,i);
if ((k>=0)&&theFormulas[k]->IsEmpty()) {
hyFloat *st = &result[k];
*st=0;
for (long j = 0; j<vDim; j++) {
if (j==i) {
continue;
}
*st-=result(i,j);
}
} else if (k<0) {
hyFloat *st = &result[i*vDim+i];
*st=0;
for (long j = 0; j<vDim; j++) {
if (j==i) {
continue;
}
*st-=result(i,j);
}
}
}*/
}
} else {
for (long i = 0; i<lDim; i++) {
Expand Down Expand Up @@ -4096,180 +4068,136 @@ void _Matrix::Multiply (_Matrix& storage, _Matrix const& secondArg) const

} else if (theIndex && !secondArg.theIndex) { // sparse multiplied by non-sparse
if (storageType == 1 && secondArg.storageType ==1) { // both numeric

if ( vDim == hDim && secondArg.vDim==secondArg.hDim) { // both square and same dimension
/*
break out a special case for universal code
If the sparse LHS matrix has a non-zero entry (i,k), it will contribute to
cells in the i-th row of the result matrix (via products (i,k)*(k,j))
This will, however, have poor memory locality for accessing the k-th row of the second matrix
over and over again.
*/

if (vDim == 61L) {
if (compressedIndex) {
#ifdef _SLKP_USE_ARM_NEON
hyFloat * _hprestrict_ res = storage.theData;
long currentXIndex = 0L;
for (long i = 0; i < 61; i++) {
long up = compressedIndex[i];

auto handle_chunk3 = [&](int o1, int o2, int o3) -> void {
float64x2x4_t R1 = vld4q_f64 (res + o1),
R2 = vld4q_f64 (res + o2),
R3 = vld4q_f64 (res + o3);

for (long cxi = currentXIndex; cxi < up; cxi++) {
long currentXColumn = compressedIndex[cxi + 61];
hyFloat * secArg = secondArg.theData + currentXColumn*61;
float64x2_t value_op = vdupq_n_f64 (theData[cxi]);

float64x2x4_t C1 = vld4q_f64 (secArg + o1),
C2 = vld4q_f64 (secArg + o2),
C3 = vld4q_f64 (secArg + o3);

R1.val[0] = vfmaq_f64 (R1.val[0], value_op, C1.val[0]);
R1.val[1] = vfmaq_f64 (R1.val[1], value_op, C1.val[1]);
R1.val[2] = vfmaq_f64 (R1.val[2], value_op, C1.val[2]);
R1.val[3] = vfmaq_f64 (R1.val[3], value_op, C1.val[3]);

R2.val[0] = vfmaq_f64 (R2.val[0], value_op, C2.val[0]);
R2.val[1] = vfmaq_f64 (R2.val[1], value_op, C2.val[1]);
R2.val[2] = vfmaq_f64 (R2.val[2], value_op, C2.val[2]);
R2.val[3] = vfmaq_f64 (R2.val[3], value_op, C2.val[3]);

R3.val[0] = vfmaq_f64 (R3.val[0], value_op, C3.val[0]);
R3.val[1] = vfmaq_f64 (R3.val[1], value_op, C3.val[1]);
R3.val[2] = vfmaq_f64 (R3.val[2], value_op, C3.val[2]);
R3.val[3] = vfmaq_f64 (R3.val[3], value_op, C3.val[3]);
if (compressedIndex) {


}
vst4q_f64 (res + o1, R1);
vst4q_f64 (res + o2, R2);
vst4q_f64 (res + o3, R3);
};


if (currentXIndex < up) {
//printf ("%d %d %d\n", i, currentXIndex, up);

handle_chunk3 (0,8,16);
handle_chunk3 (24,32,40);

float64x2x4_t R1 = vld4q_f64 (res + 48);
float64x2x2_t R2 = vld2q_f64 (res + 56);
double r60 = res[60];

for (long cxi = currentXIndex; cxi < up; cxi++) {
long currentXColumn = compressedIndex[cxi + 61];
// go into the second matrix and look up all the non-zero entries in the currentXColumn row

hyFloat value = theData[cxi];
hyFloat * secArg = secondArg.theData + currentXColumn*61;

// go into the second matrix and look up all the non-zero entries in the currentXColumn row
float64x2_t value_op = vdupq_n_f64 (value);

float64x2x4_t C1 = vld4q_f64 (secArg+48);
float64x2x2_t C2 = vld2q_f64 (secArg+56);

R1.val[0] = vfmaq_f64 (R1.val[0], value_op, C1.val[0]);
R1.val[1] = vfmaq_f64 (R1.val[1], value_op, C1.val[1]);
R1.val[2] = vfmaq_f64 (R1.val[2], value_op, C1.val[2]);
R1.val[3] = vfmaq_f64 (R1.val[3], value_op, C1.val[3]);
R2.val[0] = vfmaq_f64 (R2.val[0], value_op, C2.val[0]);
R2.val[1] = vfmaq_f64 (R2.val[1], value_op, C2.val[1]);

r60 += value * secArg[60];
}

vst4q_f64 (res + 48, R1);
vst2q_f64 (res + 56, R2);
res[60] = r60;
}

res += 61;
currentXIndex = up;
}

#ifdef _SLKP_USE_ARM_NEON


hyFloat * _hprestrict_ res = storage.theData;
long currentXIndex = 0L;
for (long i = 0; i < 61; i++) {
long up = compressedIndex[i];

if (currentXIndex < up) {
float64x2x2_t R[15]; // store 60 elements of this row

#pragma unroll 3
for (int k = 0; k < 15; k++) {
R[k] = vld2q_f64 (res + (k<<2));
}

hyFloat r60 = res[60]; // and the 61st element

for (long cxi = currentXIndex; cxi < up; cxi++) {
long currentXColumn = compressedIndex[cxi + 61];
hyFloat * secArg = secondArg.theData + currentXColumn*61;

hyFloat value = theData[cxi];
float64x2_t value_op = vdupq_n_f64 (value);

for (int k = 0; k < 5; k++) {
int k12 = k*12,
k3 = k*3;

float64x2x2_t C1 = vld2q_f64 (secArg + k12),
C2 = vld2q_f64 (secArg + k12 + 4),
C3 = vld2q_f64 (secArg + k12 + 8);

R[k3].val[0] = vfmaq_f64 (R[k3].val[0], value_op, C1.val[0]);
R[k3].val[1] = vfmaq_f64 (R[k3].val[1], value_op, C1.val[1]);

R[k3+1].val[0] = vfmaq_f64 (R[k3+1].val[0], value_op, C2.val[0]);
R[k3+1].val[1] = vfmaq_f64 (R[k3+1].val[1], value_op, C2.val[1]);

R[k3+2].val[0] = vfmaq_f64 (R[k3+2].val[0], value_op, C3.val[0]);
R[k3+2].val[1] = vfmaq_f64 (R[k3+2].val[1], value_op, C3.val[1]);

}
r60 += value * secArg[60];



}

#pragma unroll 3
for (int k = 0; k < 15; k++) {
vst2q_f64 (res + (k<<2), R[k]);
}

res[60] = r60;
}
res += 61;
currentXIndex = up;
}


#elif defined _SLKP_USE_AVX_INTRINSICS
hyFloat * _hprestrict_ res = storage.theData;
long currentXIndex = 0L;
for (long i = 0; i < 61; i++) {
long up = compressedIndex[i];

auto handle_chunk4 = [&](int o1, int o2) -> void {
__m256d R1 = _mm256_loadu_pd (res + o1),
R2 = _mm256_loadu_pd (res + o1 + 4),
R3 = _mm256_loadu_pd (res + o2),
R4 = _mm256_loadu_pd (res + o2 + 4);


for (long cxi = currentXIndex; cxi < up; cxi++) {
long currentXColumn = compressedIndex[cxi + 61];
hyFloat * secArg = secondArg.theData + currentXColumn*61;
__m256d value_op = _mm256_broadcast_sd (theData + cxi);

__m256d C1 = _mm256_loadu_pd (secArg + o1),
C2 = _mm256_loadu_pd (secArg + o1 + 4),
C3 = _mm256_loadu_pd (secArg + o2),
C4 = _mm256_loadu_pd (secArg + o2 + 4);

R1 = _hy_matrix_handle_axv_mfma (R1, value_op,C1);
R2 = _hy_matrix_handle_axv_mfma (R2, value_op,C2);
R3 = _hy_matrix_handle_axv_mfma (R3, value_op,C3);
R4 = _hy_matrix_handle_axv_mfma (R4, value_op,C4);

}
_mm256_storeu_pd (res + o1, R1);
_mm256_storeu_pd (res + o1 + 4, R2);
_mm256_storeu_pd (res + o2, R3);
_mm256_storeu_pd (res + o2 + 4, R4);
};
auto handle_chunk3 = [&](int o1) -> void {
__m256d R1 = _mm256_loadu_pd (res + o1),
R2 = _mm256_loadu_pd (res + o1 + 4),
R3 = _mm256_loadu_pd (res + o1 + 8);


for (long cxi = currentXIndex; cxi < up; cxi++) {
long currentXColumn = compressedIndex[cxi + 61];
hyFloat * secArg = secondArg.theData + currentXColumn*61;
__m256d value_op = _mm256_broadcast_sd (theData + cxi);

__m256d C1 = _mm256_loadu_pd (secArg + o1),
C2 = _mm256_loadu_pd (secArg + o1 + 4),
C3 = _mm256_loadu_pd (secArg + o1 + 8);

R1 = _hy_matrix_handle_axv_mfma (R1, value_op,C1);
R2 = _hy_matrix_handle_axv_mfma (R2, value_op,C2);
R3 = _hy_matrix_handle_axv_mfma (R3, value_op,C3);

}
_mm256_storeu_pd (res + o1, R1);
_mm256_storeu_pd (res + o1 + 4, R2);
_mm256_storeu_pd (res + o1 + 8, R3);
};

if (currentXIndex < up) {
//printf ("%d %d %d\n", i, currentXIndex, up);

handle_chunk4 (0,8);
handle_chunk4 (16,24);
handle_chunk4 (32,40);
handle_chunk3 (48);

double r60 = res[60];

for (long cxi = currentXIndex; cxi < up; cxi++) {
long currentXColumn = compressedIndex[cxi + 61];
// go into the second matrix and look up all the non-zero entries in the currentXColumn row

hyFloat * secArg = secondArg.theData + currentXColumn*61;

r60 += theData[cxi] * secArg[60];
}

res[60] = r60;
}

res += 61;
currentXIndex = up;
}
hyFloat * _hprestrict_ res = storage.theData;
long currentXIndex = 0L;
for (long i = 0; i < 61; i++) {
long up = compressedIndex[i];

if (currentXIndex < up) {
__m256d R[15]; // store 60 elements of this row

#pragma unroll 3
for (int k = 0; k < 15; k++) {
R[k] = _mm256_loadu_pd (res + (k<<2));
}

hyFloat r60 = res[60]; // and the 61st element

for (long cxi = currentXIndex; cxi < up; cxi++) {
long currentXColumn = compressedIndex[cxi + 61];
hyFloat * secArg = secondArg.theData + currentXColumn*61;

hyFloat value = theData[cxi];
__m256d value_op = _mm256_set1_pd (value);

for (int k = 0; k < 3; k++) {
int k12 = k*20,
k3 = k*5;


R[k3] = _hy_matrix_handle_axv_mfma (R[k3], value_op, _mm256_loadu_pd (secArg + k12));
R[k3+1] = _hy_matrix_handle_axv_mfma (R[k3+1], value_op, _mm256_loadu_pd (secArg + k12 + 4));
R[k3+2] = _hy_matrix_handle_axv_mfma (R[k3+2], value_op, _mm256_loadu_pd (secArg + k12 + 8));
R[k3+3] = _hy_matrix_handle_axv_mfma (R[k3+3], value_op, _mm256_loadu_pd (secArg + k12 + 12));
R[k3+4] = _hy_matrix_handle_axv_mfma (R[k3+4], value_op, _mm256_loadu_pd (secArg + k12 + 16));

}
r60 += value * secArg[60];



}

#pragma unroll 3
for (int k = 0; k < 15; k++) {
_mm256_storeu_pd (res + (k<<2), R[k]);
}

res[60] = r60;
}
res += 61;
currentXIndex = up;
}
#else
long currentXIndex = 0L;
hyFloat * _hprestrict_ res = storage.theData;
Expand Down

0 comments on commit d3bda6c

Please sign in to comment.