Skip to content

Commit

Permalink
Speeding up matmul()
Browse files Browse the repository at this point in the history
Still infinitely slower than the numpy version, but better than before
at least.
  • Loading branch information
iamsrp-deshaw committed Sep 30, 2024
1 parent 7cee6f0 commit f59d191
Show file tree
Hide file tree
Showing 3 changed files with 524 additions and 65 deletions.
262 changes: 233 additions & 29 deletions java/src/main/java/com/deshaw/hypercube/CubeMath.java
Original file line number Diff line number Diff line change
Expand Up @@ -11307,6 +11307,9 @@ else if (a.getNDim() == 2 && b.getNDim() == 2) {
final Dimension.Accessor<?>[] bSlice =
new Dimension.Accessor<?>[] { null, bDims[1].at(0) };
if (a.slice(aSlice).matches(b.slice(bSlice))) {
// A regular matrix multiply where we compute the dot product of
// the row and column to get their intersection coordinate's
// value.
final long[] ai = new long[2];
final long[] bi = new long[2];
final long[] ri = new long[2];
Expand All @@ -11315,14 +11318,62 @@ else if (a.getNDim() == 2 && b.getNDim() == 2) {
final IntegerHypercube da = (IntegerHypercube)a;
final IntegerHypercube db = (IntegerHypercube)b;
final IntegerHypercube dr = (IntegerHypercube)r;
for (long i=0; i < aDims[0].length(); i++) {
ai[0] = ri[0] = i;
for (long j=0; j < bDims[1].length(); j++) {
bi[1] = ri[1] = j;

// We will copy out the column from 'b' for faster access,
// if it's small enough to fit into an array. 2^30 doubles
// is 16GB for one column which is totally possible for a
// non-square matrix but, we hope, most matrices will not be
// quite that big. We could use an array of arrays to handle
// that case but this is slower for the general case.
final int[] bcol = (bDims[0].length() < 1<<30)
? new int[(int)bDims[0].length()]
: null;

// Where we start striding, see below.
ai[1] = bi[0] = 0;

// The stride through the flattened data, to walk a column
// in 'b'. We know that the format of the data is C-style in
// flattened form, so moving one row length in distance will
// step down one column index.
final long bs = bDims[1].length();

// Flipped the ordering of 'i' and 'j' since it's more cache
// efficient to copy out the column data (once) and then to
// stride through the rows each time.
for (long j=0; j < bDims[1].length(); j++) {
bi[1] = ri[1] = j;
if (bcol != null) {
long bo = db.toOffset(bi);
for (int i=0; i < bcol.length; i++, bo += bs) {
bcol[i] = db.getAt(bo);
}
}
for (long i=0; i < aDims[0].length(); i++) {
// We will stride through the two cubes pulling out
// the values for the sum directly, since we know
// their shape. This is much faster than going via
// the coordinate-based lookup. The stride in 'a' is
// 1 since it's walking along a row; in b it's the
// the row length, since it's walking along a column.
// Both will be the same number of steps so we only
// need to know when to stop talking in 'a'.
ai[0] = ri[0] = i;
long ao = da.toOffset(ai);
int sum = 0;
for (long k=0; k < bDims[0].length(); k++) {
ai[1] = bi[0] = k;
sum += da.get(ai) * db.get(bi);
if (bcol == null) {
final long ae = ao + aDims[1].length();
for (long bo = db.toOffset(bi);
ao < ae; ao++,
bo += bs)
{
sum += da.getAt(ao) * db.getAt(bo);
}
}
else {
for (int bo=0 ; bo < bcol.length; ao++, bo++) {
sum += da.getAt(ao) * bcol[bo];
}
}
dr.set(sum, ri);
}
Expand Down Expand Up @@ -13420,6 +13471,9 @@ else if (a.getNDim() == 2 && b.getNDim() == 2) {
final Dimension.Accessor<?>[] bSlice =
new Dimension.Accessor<?>[] { null, bDims[1].at(0) };
if (a.slice(aSlice).matches(b.slice(bSlice))) {
// A regular matrix multiply where we compute the dot product of
// the row and column to get their intersection coordinate's
// value.
final long[] ai = new long[2];
final long[] bi = new long[2];
final long[] ri = new long[2];
Expand All @@ -13428,14 +13482,62 @@ else if (a.getNDim() == 2 && b.getNDim() == 2) {
final LongHypercube da = (LongHypercube)a;
final LongHypercube db = (LongHypercube)b;
final LongHypercube dr = (LongHypercube)r;
for (long i=0; i < aDims[0].length(); i++) {
ai[0] = ri[0] = i;
for (long j=0; j < bDims[1].length(); j++) {
bi[1] = ri[1] = j;

// We will copy out the column from 'b' for faster access,
// if it's small enough to fit into an array. 2^30 doubles
// is 16GB for one column which is totally possible for a
// non-square matrix but, we hope, most matrices will not be
// quite that big. We could use an array of arrays to handle
// that case but this is slower for the general case.
final long[] bcol = (bDims[0].length() < 1<<30)
? new long[(int)bDims[0].length()]
: null;

// Where we start striding, see below.
ai[1] = bi[0] = 0;

// The stride through the flattened data, to walk a column
// in 'b'. We know that the format of the data is C-style in
// flattened form, so moving one row length in distance will
// step down one column index.
final long bs = bDims[1].length();

// Flipped the ordering of 'i' and 'j' since it's more cache
// efficient to copy out the column data (once) and then to
// stride through the rows each time.
for (long j=0; j < bDims[1].length(); j++) {
bi[1] = ri[1] = j;
if (bcol != null) {
long bo = db.toOffset(bi);
for (int i=0; i < bcol.length; i++, bo += bs) {
bcol[i] = db.getAt(bo);
}
}
for (long i=0; i < aDims[0].length(); i++) {
// We will stride through the two cubes pulling out
// the values for the sum directly, since we know
// their shape. This is much faster than going via
// the coordinate-based lookup. The stride in 'a' is
// 1 since it's walking along a row; in b it's the
// the row length, since it's walking along a column.
// Both will be the same number of steps so we only
// need to know when to stop talking in 'a'.
ai[0] = ri[0] = i;
long ao = da.toOffset(ai);
long sum = 0;
for (long k=0; k < bDims[0].length(); k++) {
ai[1] = bi[0] = k;
sum += da.get(ai) * db.get(bi);
if (bcol == null) {
final long ae = ao + aDims[1].length();
for (long bo = db.toOffset(bi);
ao < ae; ao++,
bo += bs)
{
sum += da.getAt(ao) * db.getAt(bo);
}
}
else {
for (int bo=0 ; bo < bcol.length; ao++, bo++) {
sum += da.getAt(ao) * bcol[bo];
}
}
dr.set(sum, ri);
}
Expand Down Expand Up @@ -15533,6 +15635,9 @@ else if (a.getNDim() == 2 && b.getNDim() == 2) {
final Dimension.Accessor<?>[] bSlice =
new Dimension.Accessor<?>[] { null, bDims[1].at(0) };
if (a.slice(aSlice).matches(b.slice(bSlice))) {
// A regular matrix multiply where we compute the dot product of
// the row and column to get their intersection coordinate's
// value.
final long[] ai = new long[2];
final long[] bi = new long[2];
final long[] ri = new long[2];
Expand All @@ -15541,14 +15646,62 @@ else if (a.getNDim() == 2 && b.getNDim() == 2) {
final FloatHypercube da = (FloatHypercube)a;
final FloatHypercube db = (FloatHypercube)b;
final FloatHypercube dr = (FloatHypercube)r;
for (long i=0; i < aDims[0].length(); i++) {
ai[0] = ri[0] = i;
for (long j=0; j < bDims[1].length(); j++) {
bi[1] = ri[1] = j;

// We will copy out the column from 'b' for faster access,
// if it's small enough to fit into an array. 2^30 doubles
// is 16GB for one column which is totally possible for a
// non-square matrix but, we hope, most matrices will not be
// quite that big. We could use an array of arrays to handle
// that case but this is slower for the general case.
final float[] bcol = (bDims[0].length() < 1<<30)
? new float[(int)bDims[0].length()]
: null;

// Where we start striding, see below.
ai[1] = bi[0] = 0;

// The stride through the flattened data, to walk a column
// in 'b'. We know that the format of the data is C-style in
// flattened form, so moving one row length in distance will
// step down one column index.
final long bs = bDims[1].length();

// Flipped the ordering of 'i' and 'j' since it's more cache
// efficient to copy out the column data (once) and then to
// stride through the rows each time.
for (long j=0; j < bDims[1].length(); j++) {
bi[1] = ri[1] = j;
if (bcol != null) {
long bo = db.toOffset(bi);
for (int i=0; i < bcol.length; i++, bo += bs) {
bcol[i] = db.getAt(bo);
}
}
for (long i=0; i < aDims[0].length(); i++) {
// We will stride through the two cubes pulling out
// the values for the sum directly, since we know
// their shape. This is much faster than going via
// the coordinate-based lookup. The stride in 'a' is
// 1 since it's walking along a row; in b it's the
// the row length, since it's walking along a column.
// Both will be the same number of steps so we only
// need to know when to stop talking in 'a'.
ai[0] = ri[0] = i;
long ao = da.toOffset(ai);
float sum = 0;
for (long k=0; k < bDims[0].length(); k++) {
ai[1] = bi[0] = k;
sum += da.get(ai) * db.get(bi);
if (bcol == null) {
final long ae = ao + aDims[1].length();
for (long bo = db.toOffset(bi);
ao < ae; ao++,
bo += bs)
{
sum += da.getAt(ao) * db.getAt(bo);
}
}
else {
for (int bo=0 ; bo < bcol.length; ao++, bo++) {
sum += da.getAt(ao) * bcol[bo];
}
}
dr.set(sum, ri);
}
Expand Down Expand Up @@ -17662,6 +17815,9 @@ else if (a.getNDim() == 2 && b.getNDim() == 2) {
final Dimension.Accessor<?>[] bSlice =
new Dimension.Accessor<?>[] { null, bDims[1].at(0) };
if (a.slice(aSlice).matches(b.slice(bSlice))) {
// A regular matrix multiply where we compute the dot product of
// the row and column to get their intersection coordinate's
// value.
final long[] ai = new long[2];
final long[] bi = new long[2];
final long[] ri = new long[2];
Expand All @@ -17670,14 +17826,62 @@ else if (a.getNDim() == 2 && b.getNDim() == 2) {
final DoubleHypercube da = (DoubleHypercube)a;
final DoubleHypercube db = (DoubleHypercube)b;
final DoubleHypercube dr = (DoubleHypercube)r;
for (long i=0; i < aDims[0].length(); i++) {
ai[0] = ri[0] = i;
for (long j=0; j < bDims[1].length(); j++) {
bi[1] = ri[1] = j;

// We will copy out the column from 'b' for faster access,
// if it's small enough to fit into an array. 2^30 doubles
// is 16GB for one column which is totally possible for a
// non-square matrix but, we hope, most matrices will not be
// quite that big. We could use an array of arrays to handle
// that case but this is slower for the general case.
final double[] bcol = (bDims[0].length() < 1<<30)
? new double[(int)bDims[0].length()]
: null;

// Where we start striding, see below.
ai[1] = bi[0] = 0;

// The stride through the flattened data, to walk a column
// in 'b'. We know that the format of the data is C-style in
// flattened form, so moving one row length in distance will
// step down one column index.
final long bs = bDims[1].length();

// Flipped the ordering of 'i' and 'j' since it's more cache
// efficient to copy out the column data (once) and then to
// stride through the rows each time.
for (long j=0; j < bDims[1].length(); j++) {
bi[1] = ri[1] = j;
if (bcol != null) {
long bo = db.toOffset(bi);
for (int i=0; i < bcol.length; i++, bo += bs) {
bcol[i] = db.getAt(bo);
}
}
for (long i=0; i < aDims[0].length(); i++) {
// We will stride through the two cubes pulling out
// the values for the sum directly, since we know
// their shape. This is much faster than going via
// the coordinate-based lookup. The stride in 'a' is
// 1 since it's walking along a row; in b it's the
// the row length, since it's walking along a column.
// Both will be the same number of steps so we only
// need to know when to stop talking in 'a'.
ai[0] = ri[0] = i;
long ao = da.toOffset(ai);
double sum = 0;
for (long k=0; k < bDims[0].length(); k++) {
ai[1] = bi[0] = k;
sum += da.get(ai) * db.get(bi);
if (bcol == null) {
final long ae = ao + aDims[1].length();
for (long bo = db.toOffset(bi);
ao < ae; ao++,
bo += bs)
{
sum += da.getAt(ao) * db.getAt(bo);
}
}
else {
for (int bo=0 ; bo < bcol.length; ao++, bo++) {
sum += da.getAt(ao) * bcol[bo];
}
}
dr.set(sum, ri);
}
Expand Down Expand Up @@ -19352,4 +19556,4 @@ private static Hypercube<Double> doubleExtract(
}
}

// [[[end]]] (checksum: e95077b77768bd9403a780ff90ce9e1d)
// [[[end]]] (checksum: 39d964579ec2256d0b844d8f015ac9d6)
Loading

0 comments on commit f59d191

Please sign in to comment.