Skip to content

Commit

Permalink
Add tests and fix some functions.
Browse files Browse the repository at this point in the history
  • Loading branch information
Iainmon committed Aug 31, 2024
1 parent 3dea84b commit 13d644e
Show file tree
Hide file tree
Showing 8 changed files with 66 additions and 31 deletions.
8 changes: 6 additions & 2 deletions lib/NDArray.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -549,11 +549,15 @@ record ndarray : serializable {
}


proc type ndarray.arange(to: int,type eltType = real(64),shape: ?rank*int): ndarray(rank,eltType) {

proc type ndarray.arange(type eltType = real(32),shape: ?rank*int): ndarray(rank,eltType) {
const dom = util.domainFromShape((...shape));
const A: [dom] eltType = foreach (_,x) in zip(dom,0..<to) do x:eltType;
const A: [dom] eltType = foreach (i,_) in dom.everyZip() do i : eltType;
return new ndarray(A);
}
proc type ndarray.arange(shape: int...?rank): ndarray(rank,real(32)) do
return ndarray.arange(eltType=real(32), shape);



operator =(ref lhs: ndarray(?rank,?eltType), const rhs: ndarray(rank,eltType)) {
Expand Down
4 changes: 1 addition & 3 deletions lib/Network.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ class Module {
if var p = m : borrowed Parameter(eltType)? {
const paramName = name[(moduleName.size + 1)..];
const paramPath = modelPath + paramName + ".chdata";
writeln("Loading ",paramName," from ", paramPath);
if debug then writeln("Loading ",paramName," from ", paramPath);
var loaded = Tensor.load(paramPath) : eltType;
p!.data = loaded;
}
Expand Down Expand Up @@ -666,10 +666,8 @@ class Dropout : Module(?) {


proc chain(m: borrowed Module(?), modNames: string...?n, input: Tensor(?eltType)) {
writeln("layer 0");
var output = m.mod(modNames(0))(input);
for param i in 1..<n {
writeln("layer ", i);
output = m.mod(modNames(i))(output);
}
return output;
Expand Down
14 changes: 14 additions & 0 deletions lib/SimpleDomain.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,20 @@ inline proc computeSize(shape: ?rank*int): int {
return s;
}

// Scalar overload: a plain integer index is already a linear offset,
// so this is the identity mapping.
inline proc computeAtIndex(const idx: int): int {
    return idx;
}

// Flattens the multi-dimensional index `idx` into a linear (row-major)
// offset within an array of extents `shape`, using the strides computed
// by `computeStrides`. Returns the scalar offset.
inline proc computeAtIndex(const shape: ?rank*int, const idx: rank*int): int {
    if rank == 1 then
        // BUG FIX: `idx` is a 1-tuple here; Chapel does not implicitly
        // convert a 1*int tuple to int, so returning `idx` directly is
        // ill-typed. Unwrap the single component instead.
        return idx(0);
    const strides = computeStrides(shape);
    // Dot product of index tuple and stride tuple, unrolled at compile time.
    var i: int;
    for param j in 0..<rank do
        i += idx(j) * strides(j);
    return i;
}



record rect : serializable {
Expand Down
2 changes: 2 additions & 0 deletions lib/StaticTensor.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ use Autograd;
import Utilities as util;
use Utilities.Standard;

type tensor = staticTensor(?);

record staticTensor : serializable {
param rank: int;
type eltType = real(64);
Expand Down
12 changes: 10 additions & 2 deletions lib/Utilities.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ module Utilities {
// import ChapelDomain;
import Math;

import SimpleDomain;


param debugPrint = false;

Expand Down Expand Up @@ -76,6 +78,12 @@ module Utilities {
}
}

// Total element count implied by a variadic shape (the product of all
// extents), delegated to SimpleDomain's size computation.
inline proc product(tup: int...?rank) {
    return SimpleDomain.computeSize(tup);
}

// Row-major linear offset of the tuple index `idx` inside an array of
// extents `shape`; thin forwarding wrapper over SimpleDomain.
inline proc linearIdx(shape: ?rank*int, idx: rank*int) {
    return SimpleDomain.computeAtIndex(shape, idx);
}

inline proc normalizeArray(arr: []) {
const arrDom = arr.domain;
const normalDomain = normalizeDomain(arrDom);
Expand Down Expand Up @@ -608,7 +616,7 @@ module Utilities {
}
}
}

/*
inline iter _domain.everyZip(param tag: iterKind) where tag == iterKind.leader {
const shape = this.fastShape;
if CHPL_LOCALE_MODEL != "gpu" {
Expand Down Expand Up @@ -712,7 +720,7 @@ module Utilities {
}
}
}
}
}*/

inline proc _domain.indexAt(n: int) where rank == 1 {
return n;
Expand Down
12 changes: 12 additions & 0 deletions src/ChAI.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,18 @@ module ChAI {
public import Network;
public import Tensor;

public import NDArray;
public import StaticTensor;
public import DynamicTensor;
public import Utilities.Types;
public import Utilities as util;
public import Utilities.Standard;

public import Remote;

// Expose common types:
public use NDArray;

proc main() {
writeln(Tensor.Tensor.arange(3,5));

Expand Down
15 changes: 15 additions & 0 deletions test/importTest.chpl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// Smoke test for the ChAI top-level module: checks that `use ChAI`
// brings the re-exported public symbols (here `ndarray`) into scope.
use UnitTest;

use ChAI;

// NOTE(review): never read in this file — presumably reserved for
// future toggling; confirm before removing.
config const testParam: bool = true;

// Presumably discovered and invoked by UnitTest.main() because it takes
// a `borrowed Test` and throws — TODO confirm against UnitTest docs.
proc myTest(test: borrowed Test) throws {

    // Exercises the variadic-shape `ndarray.arange` overload; printing
    // it proves the symbol resolves and the call runs without error.
    writeln(ndarray.arange(1,2,3));
    test.assertTrue(true);


}

UnitTest.main();
30 changes: 6 additions & 24 deletions test/loadFromSpec.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -9,39 +9,21 @@ use UnitTest;
proc myTest2(test: borrowed Test) throws {

// Construct the model from specification.
var model: owned Module(real) = modelFromSpecFile("scripts/models/cnn/specification.json");

// Print the model's structure.
writeln(model.signature);

// Load the weights into the model.
model.loadPyTorchDump("scripts/models/cnn/");
var model: owned Module(real(32)) = Network.loadModel(specFile="scripts/models/cnn/specification.json",
weightsFolder="scripts/models/cnn/",
dtype=real(32));

// Load an array of images.
const numImages = 10;
var images = forall i in 0..<numImages do Tensor.load("data/datasets/mnist/image_idx_" + i:string + ".chdata");
var images = forall i in 0..<numImages do Tensor.load("examples/data/datasets/mnist/image_idx_" + i:string + ".chdata") : real(32);

// Create array of output results.
var preds: [0..<numImages] int;

const numTimes = 1;
var time: real;
for i in 0..<numTimes {
var st = new Time.stopwatch();

st.start();
forall (img,pred) in zip(images, preds) {
pred = model(img).argmax();
}
st.stop();

const tm = st.elapsed();
writeln("Time: ", tm, " seconds.");
time += tm;
forall (img,pred) in zip(images, preds) {
pred = model(img).argmax();
}

time /= numTimes;

test.assertTrue(preds[0] == 7);
test.assertTrue(preds[1] == 2);

Expand Down

0 comments on commit 13d644e

Please sign in to comment.