-
Notifications
You must be signed in to change notification settings - Fork 9
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ADBench LSTM test #152
ADBench LSTM test #152
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
# A reference implementation of D-LSTM, so that we can test our | ||
# Knossos implementation. | ||
# | ||
# Please note that this is an LSTM with diagonal weight matrices, not | ||
# a full LSTM. Its purpose is to calculate the exact same thing that | ||
# ADBench's D-LSTM implementation calculates, not to implement a full | ||
# LSTM. If and when ADBench changes its implementation to a full | ||
# LSTM, we will too. | ||
# | ||
# This code is an as-close-as-possible port of ADBench's Tensorflow | ||
# implementation to numpy. The only reason it is changed at all is | ||
# that I want to avoid picking up a Tensorflow dependency. See | ||
# | ||
# https://github.com/awf/ADBench/blob/e5f72ab5dcb453b1bb72dd000e5add6b90502ec4/src/python/modules/Tensorflow/TensorflowLSTM.py | ||
|
||
import numpy as np | ||
|
||
def sigmoid(x): | ||
return 1.0 / (1.0 + np.exp(-x)) | ||
|
||
def lstm_model(weight, bias, hidden, cell, inp): | ||
gates = np.concatenate((inp, hidden, inp, hidden), 0) * weight + bias | ||
hidden_size = hidden.shape[0] | ||
|
||
forget = sigmoid(gates[0:hidden_size]) | ||
ingate = sigmoid(gates[hidden_size:2*hidden_size]) | ||
outgate = sigmoid(gates[2*hidden_size:3*hidden_size]) | ||
change = np.tanh(gates[3*hidden_size:]) | ||
|
||
cell = cell * forget + ingate * change | ||
hidden = outgate * np.tanh(cell) | ||
|
||
return (hidden, cell) | ||
|
||
def lstm_predict(w, w2, s, x): | ||
s2 = s.copy() | ||
# NOTE not sure if this should be element-wise or matrix multiplication | ||
x = x * w2[0] | ||
for i in range(0, len(s), 2): | ||
(s2[i], s2[i + 1]) = lstm_model(w[i], w[i + 1], s[i], s[i + 1], x) | ||
x = s2[i] | ||
return (x * w2[1] + w2[2], s2) | ||
|
||
def lstm_objective(main_params, extra_params, state, sequence, _range=None): | ||
if _range is None: | ||
_range = range(0, len(sequence) - 1) | ||
|
||
total = 0.0 | ||
count = 0 | ||
_input = sequence[_range[0]] | ||
all_states = [state] | ||
for t in _range: | ||
ypred, new_state = lstm_predict(main_params, extra_params, all_states[t], _input) | ||
all_states.append(new_state) | ||
ynorm = ypred - np.log(sum(np.exp(ypred), 2)) | ||
ygold = sequence[t + 1] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. super-nit: is that yg_old or y_gold? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The latter, but: I'm trying to change the original source code as little as possible. See the new top-level comment. |
||
total += sum(ygold * ynorm) | ||
count += ygold.shape[0] | ||
_input = ygold | ||
return -total / count |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,169 @@ | ||
# Remarkably, adbench-lstm.ks agrees exactly with our reference | ||
# implementation, so we set the almost equal check to check to a very | ||
# large number of decimal places. | ||
|
||
import adbench_lstm as a | ||
import ksc.adbench_lstm.lstm as k | ||
import random | ||
import numpy as np | ||
|
||
d = a.vec_double | ||
|
||
def r(): | ||
return random.random() * 2 - 1 | ||
|
||
def rv(n): | ||
return [r() for _ in range(n)] | ||
|
||
def rvv(n, m): | ||
return [rv(m) for _ in range(n)] | ||
|
||
def concat(l): | ||
return sum(l, []) | ||
|
||
# The ks::vec __iter__ method that is automatically generated by | ||
# pybind11 is one that keeps going off the end of the vec and never | ||
# stops. Until I get time to dig into how to make it generate a | ||
# better one, here's a handy utility function. | ||
def to_list(x): | ||
return [x[i] for i in range(len(x))] | ||
|
||
def main(): | ||
assert_equal_model() | ||
assert_equal_predict_and_objective() | ||
print("The assertions didn't throw any errors, so " | ||
"everything must be good!") | ||
|
||
def assert_equal_model(): | ||
h = 2 | ||
|
||
w1 = rv(h) | ||
w2 = rv(h) | ||
w3 = rv(h) | ||
w4 = rv(h) | ||
|
||
b1 = rv(h) | ||
b2 = rv(h) | ||
b3 = rv(h) | ||
b4 = rv(h) | ||
|
||
hidden = rv(h) | ||
cell = rv(h) | ||
input_ = rv(h) | ||
|
||
weight = concat([w1, w2, w3, w4]) | ||
bias = concat([b1, b2, b3, b4]) | ||
|
||
(ao0, ao1) = a.lstm_model(d(w1), | ||
d(b1), | ||
d(w2), | ||
d(b2), | ||
d(w3), | ||
d(b3), | ||
d(w4), | ||
d(b4), | ||
d(hidden), | ||
d(cell), | ||
d(input_)) | ||
|
||
ao0l = to_list(ao0) | ||
ao1l = to_list(ao1) | ||
|
||
nd_weight = np.array(weight) | ||
|
||
(mo0, mo1) = k.lstm_model(nd_weight, | ||
np.array(bias), | ||
np.array(hidden), | ||
np.array(cell), | ||
np.array(input_)) | ||
|
||
print(mo0) | ||
print(ao0l) | ||
print(mo1) | ||
print(ao1l) | ||
|
||
np.testing.assert_almost_equal(ao0l, mo0, decimal=12, err_msg="Model 1") | ||
np.testing.assert_almost_equal(ao1l, mo1, decimal=12, err_msg="Model 2") | ||
|
||
def assert_equal_predict_and_objective(): | ||
l = 2 | ||
h = 10 | ||
|
||
w1 = rvv(l, h) | ||
w2 = rvv(l, h) | ||
w3 = rvv(l, h) | ||
w4 = rvv(l, h) | ||
|
||
b1 = rvv(l, h) | ||
b2 = rvv(l, h) | ||
b3 = rvv(l, h) | ||
b4 = rvv(l, h) | ||
|
||
hidden = rvv(l, h) | ||
cell = rvv(l, h) | ||
|
||
input_ = rv(h) | ||
|
||
input_weight = rv(h) | ||
output_weight = rv(h) | ||
output_bias = rv(h) | ||
|
||
tww = np.array(concat([concat([w1i, w2i, w3i, w4i]), | ||
concat([b1i, b2i, b3i, b4i])] | ||
for (w1i, w2i, w3i, w4i, b1i, b2i, b3i, b4i) | ||
in zip(w1, w2, w3, w4, b1, b2, b3, b4))) | ||
|
||
ts = np.array(concat(([hiddeni, celli] | ||
for (hiddeni, celli) | ||
in zip(hidden, cell)))) | ||
|
||
tww2 = np.array([input_weight, output_weight, output_bias]) | ||
|
||
tinput_ = np.array(input_) | ||
|
||
print(tww.shape) | ||
print(tww2.shape) | ||
print(ts.shape) | ||
print(tinput_.shape) | ||
|
||
(tp0, tp1) = k.lstm_predict(tww, tww2, ts, tinput_) | ||
|
||
|
||
tp0l = tp0.tolist() | ||
tp1l = tp1.tolist() | ||
|
||
wf_etc = [tuple(d(i) for i in tu) | ||
for tu in zip(w1, b1, w2, b2, w3, b3, w4, b4, hidden, cell)] | ||
|
||
(v, vtvv) = a.lstm_predict(a.vec_tuple_vec10(wf_etc), | ||
d(input_weight), | ||
d(output_weight), | ||
d(output_bias), | ||
d(input_)) | ||
|
||
vl = to_list(v) | ||
vtvvl = concat([to_list(v1), to_list(v2)] for (v1, v2) in to_list(vtvv)) | ||
|
||
to = k.lstm_objective(tww, tww2, ts, [tinput_, tinput_]) | ||
tol = to.tolist() | ||
|
||
print(tol) | ||
|
||
aol = a.lstm_objective(a.vec_tuple_vec10(wf_etc), | ||
d(input_weight), | ||
d(output_weight), | ||
d(output_bias), | ||
a.vec_tuple_vec2([(d(input_), d(input_))])) | ||
|
||
print(tp0l) | ||
print(vl) | ||
print(tp1l) | ||
print(vtvvl) | ||
print(tol) | ||
print(aol) | ||
|
||
np.testing.assert_almost_equal(tp0l, vl, decimal=12, err_msg="Predict 1") | ||
np.testing.assert_almost_equal(tp1l, vtvvl, decimal=12, err_msg="Predict 2") | ||
np.testing.assert_almost_equal(tol, aol, decimal=12, err_msg="Objective") | ||
|
||
if __name__ == '__main__': main() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
# There's a lot of duplication between this and | ||
# build_and_test_mnistcnn.sh, but we will follow the Rule of Three | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add similar comment to build_and_test_mnistcnn.sh? (I think the Wikipedia reference is probably unnecessary/OTT but it is humorous to include it ;) ) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK, good idea |
||
# | ||
# https://en.wikipedia.org/wiki/Rule_of_three_(computer_programming) | ||
|
||
set -e | ||
|
||
KNOSSOS=$1 | ||
PYBIND11=$2 | ||
|
||
RUNTIME=$KNOSSOS/src/runtime | ||
OBJ=$KNOSSOS/obj/test/ksc | ||
PYBIND11_INCLUDE=$PYBIND11/include | ||
|
||
PYTHON3_CONFIG_EXTENSION_SUFFIX=$(python3-config --extension-suffix) | ||
|
||
MODULE_NAME=adbench_lstm | ||
MODULE_FILE="$OBJ/$MODULE_NAME$PYTHON3_CONFIG_EXTENSION_SUFFIX" | ||
|
||
echo Compiling... | ||
|
||
g++-7 -fmax-errors=5 \ | ||
-fdiagnostics-color=always \ | ||
-Wall \ | ||
-Wno-unused \ | ||
-Wno-maybe-uninitialized \ | ||
-I$RUNTIME \ | ||
-I$OBJ \ | ||
-I$PYBIND11_INCLUDE \ | ||
$(PYTHONPATH=$PYBIND11 python3 -m pybind11 --includes) \ | ||
-O3 \ | ||
-std=c++17 \ | ||
-shared \ | ||
-fPIC \ | ||
-o $MODULE_FILE \ | ||
-DMNISTCNNCPP_MODULE_NAME=$MODULE_NAME \ | ||
$KNOSSOS/test/ksc/adbench-lstmpy.cpp | ||
|
||
KSCPY=$KNOSSOS/src/python | ||
PYTHONPATH=$OBJ:$KSCPY python3 -m ksc.adbench_lstm.test |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
/* There's a lot of duplication between this and mnistcnnpy.cpp, but | ||
* we will follow the Rule of Three | ||
* | ||
* https://en.wikipedia.org/wiki/Rule_of_three_(computer_programming) | ||
*/ | ||
|
||
#include <pybind11/pybind11.h> | ||
#include <pybind11/stl.h> | ||
#include <pybind11/operators.h> | ||
|
||
namespace py = pybind11; | ||
|
||
#include "adbench-lstm.cpp" | ||
|
||
int ks::main() { return 0; }; | ||
|
||
template<typename T> | ||
void declare_vec(py::module &m, std::string typestr) { | ||
using Class = ks::vec<T>; | ||
std::string pyclass_name = std::string("vec_") + typestr; | ||
py::class_<Class>(m, pyclass_name.c_str()) | ||
.def(py::init<>()) | ||
.def(py::init<std::vector<T> const&>()) | ||
.def("is_zero", &Class::is_zero) | ||
.def("__getitem__", [](const ks::vec<T> &a, const int &b) { | ||
return a[b]; | ||
}) | ||
.def("__len__", [](const ks::vec<T> &a) { return a.size(); }); | ||
} | ||
|
||
// In the future it might make more sense to move the vec type | ||
// definitions to a general Knossos CPP types Python module. | ||
// | ||
// I don't know how to make a single Python type that works for vecs | ||
// of many different sorts of contents. It seems like it must be | ||
// possible because Python tuples map to std::tuples regardless of | ||
// their contents. I'll look into it later. For now I'll just have a | ||
// bunch of verbose replication. | ||
PYBIND11_MODULE(MNISTCNNCPP_MODULE_NAME, m) { | ||
declare_vec<double>(m, std::string("double")); | ||
declare_vec<std::tuple<ks::vec<double>, ks::vec<double>, ks::vec<double>, ks::vec<double>, ks::vec<double>, ks::vec<double>, ks::vec<double>, ks::vec<double>, ks::vec<double>, ks::vec<double>>>(m, std::string("tuple_vec10")); | ||
declare_vec<std::tuple<ks::vec<double>, ks::vec<double>>>(m, std::string("tuple_vec2")); | ||
declare_vec<ks::vec<double> >(m, std::string("vec_double")); | ||
declare_vec<ks::vec<ks::vec<double> > >(m, std::string("vec_vec_double")); | ||
declare_vec<ks::vec<ks::vec<ks::vec<double> > > >(m, std::string("vec_vec_vec_double")); | ||
declare_vec<ks::vec<ks::vec<ks::vec<ks::vec<double> > > > >(m, std::string("vec_vec_vec_vec_double")); | ||
m.def("sigmoid", &ks::sigmoid); | ||
m.def("logsumexp", &ks::logsumexp); | ||
m.def("lstm_model", &ks::lstm_model); | ||
m.def("lstm_predict", &ks::lstm_predict); | ||
m.def("lstm_objective", &ks::lstm_objective); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe assert something here about size/shape of
hidden
vs size ofinp
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm trying to change the original source code as little as possible. See the new top-level comment.