-
Notifications
You must be signed in to change notification settings - Fork 9
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ADBench LSTM test #152
ADBench LSTM test #152
Conversation
5fdce84
to
e667ba4
Compare
This collection of code contains a C++ wrapper for |
@toelli-msft It appears you have a simple LSTM implementation. I added a sequence LSTM implementation. As I am not sure what the purpose here is, lstm2.py may not be what you wanted, but having it in there will help discussion. I think we should test the sort of LSTM I have pushed as a minimum to make any speed claims. |
gates = np.concatenate((inp, hidden, inp, hidden), 0) * weight + bias | ||
hidden_size = hidden.shape[0] | ||
|
||
forget = sigmoid(gates[0:hidden_size]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe assert something here about size/shape of hidden
vs size of inp
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm trying to change the original source code as little as possible. See the new top-level comment.
ypred, new_state = lstm_predict(main_params, extra_params, all_states[t], _input) | ||
all_states.append(new_state) | ||
ynorm = ypred - np.log(sum(np.exp(ypred), 2)) | ||
ygold = sequence[t + 1] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
super-nit: is that yg_old or y_gold?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The latter, but: I'm trying to change the original source code as little as possible. See the new top-level comment.
@@ -0,0 +1,40 @@ | |||
# There's a lot of duplication between this and | |||
# build_and_test_mnistcnn.sh, but we will follow the Rule of Three |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add similar comment to build_and_test_mnistcnn.sh? (I think the Wikipedia reference is probably unnecessary/OTT but it is humorous to include it ;) )
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK, good idea
src/python/ksc/adbench_lstm/lstm2.py
Outdated
|
||
""" | ||
Ther are many formulations of LSTMs. This code follows the formulation from | ||
https://cs224d.stanford.edu/lecture_notes/LectureNotes4.pdf with some simplifications |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe a comment what this is for? Do I understand correctly that this is not what you are testing against, it's for eyeball comparison only or something?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Alan - here's the reference for what this is for.
src/python/ksc/adbench_lstm/test.py
Outdated
from ksc.adbench_lstm.lstm import ( | ||
lstm_model, lstm_predict, lstm_objective, sigmoid) | ||
|
||
ten = np.ndarray |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this ten
used anywhere? I don't see it. (And d
?)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, it's not used. (d
is used in several places though)
src/python/ksc/adbench_lstm/test.py
Outdated
import random | ||
import numpy as np | ||
|
||
from ksc.adbench_lstm.lstm import ( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Personal preference but I would probably import ksc.adbench_lstm.lstm as k
and then I'm testing a.lstm_model
against k.lstm_model
etc. Up to you...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I like that idea, thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, a few nits/suggestions.
in terms of a Python reference implementation
311b34e
to
0813ccc
Compare
Thanks Alan and Pashmina. For those who are wondering why this is not a standard LSTM implementation, we have to copy exactly whatever ADBench does so that we are comparing like-for-like. ADBench explicitly titles its graph with "D-LSTM" (Diagonal LSTM) to make it clear that it is not a standard one. We definitely want ADBench to have one, it is an open ADBench issue to implement a standard one, but as yet it is not one. Hopefully the new comment on top of Pashmina, thanks for the reference implementation of a standard LSTM. Perhaps ADBench can use it as a reference. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the context, Tom.
0813ccc
to
c20e98c
Compare
No description provided.