Hi, great paper!

I implemented the regression Colab example (or at least the first `VBLLMLP` example) in JAX. I wrote the equivalent of `distributions.py` by subclassing `numpyro.distributions` and implemented the `Regression` and `VBLLMLP` classes in Flax. The model trains, but the uncertainty bands are a bit of a mess.

Are there any plans for a JAX implementation? I'd be keen to help out a little if there are. I'd also like to track down the errors in my Colab...

Here is the Colab: https://colab.research.google.com/drive/1Rh895u0jP9xEpK7eMOz9JHUX_2CluyLO?usp=sharing

Thanks,
Conor

I have the same problem with the classification task: the UQ is a bit of a mess. Which parameter do you think matters most?

This is my setting, where 7232 is the total number of samples:

`self.output = vbll.DiscClassification(64, 2, 1.0 / 7232, parameterization='diagonal', prior_scale=1.0)`

Best,
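This is a point of comparison for debugging a port. It is a minimal sketch in plain `jax.numpy` (no Flax or NumPyro) of a scalar-output variational last layer for regression. It assumes the following:

- a diagonal Gaussian posterior `q(w) = N(m, diag(v))` over the last-layer weights;
- a zero-mean isotropic Gaussian prior with variance `prior_var`;
- a learned point estimate of the noise variance `sigma^2`.

The objective is the closed-form expected log-likelihood `log N(y | m^T phi, sigma^2) - phi^T S phi / (2 sigma^2)`, averaged over the batch. A KL term scaled by `reg_weight` is added, for example 1/N, analogous to the `1.0 / 7232` setting above. All function names and the parameter layout are hypothetical, and this is not the `vbll` library's code.

```python
import jax
import jax.numpy as jnp


def init_vbll_regression(key, feat_dim):
    # Hypothetical layout: diagonal Gaussian q(w) = N(w_mean, diag(exp(w_logvar)))
    # over a scalar-output last layer, plus a log noise variance (point estimate).
    return {
        "w_mean": 0.1 * jax.random.normal(key, (feat_dim,)),
        "w_logvar": jnp.full((feat_dim,), -3.0),
        "noise_logvar": jnp.array(0.0),
    }


def vbll_regression_loss(params, phi, y, reg_weight, prior_var=1.0):
    # phi: (batch, feat_dim) backbone features, y: (batch,) targets.
    mean = params["w_mean"]
    var = jnp.exp(params["w_logvar"])
    noise_var = jnp.exp(params["noise_logvar"])
    pred_mean = phi @ mean
    feat_var = (phi ** 2) @ var  # phi^T S phi for a diagonal S
    # Closed-form E_q[log N(y | w^T phi, noise_var)] for Gaussian q(w).
    exp_loglik = (
        -0.5 * jnp.log(2.0 * jnp.pi * noise_var)
        - 0.5 * (y - pred_mean) ** 2 / noise_var
        - 0.5 * feat_var / noise_var
    )
    # KL(q(w) || N(0, prior_var * I)) between diagonal Gaussians.
    kl = 0.5 * jnp.sum(
        var / prior_var + mean ** 2 / prior_var - 1.0 - jnp.log(var / prior_var)
    )
    return -jnp.mean(exp_loglik) + reg_weight * kl


def vbll_regression_predictive(params, phi):
    # Predictive N(m^T phi, phi^T S phi + noise_var): weight uncertainty plus noise.
    mean = phi @ params["w_mean"]
    var = (phi ** 2) @ jnp.exp(params["w_logvar"]) + jnp.exp(params["noise_logvar"])
    return mean, var


# Usage on synthetic features (stand-ins for a backbone's output).
k_init, k_x, k_noise = jax.random.split(jax.random.PRNGKey(0), 3)
n, d = 256, 8
phi = jax.random.normal(k_x, (n, d))
y = phi @ (jnp.arange(1.0, d + 1.0) / d) + 0.1 * jax.random.normal(k_noise, (n,))

params = init_vbll_regression(k_init, d)
reg_weight = 1.0 / n  # KL weight 1/N
grad_fn = jax.jit(jax.value_and_grad(vbll_regression_loss))
initial_loss = vbll_regression_loss(params, phi, y, reg_weight)
for _ in range(200):
    loss, grads = grad_fn(params, phi, y, reg_weight)
    # Plain gradient descent with a small step size.
    params = jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)
final_loss = vbll_regression_loss(params, phi, y, reg_weight)
pred_mean, pred_var = vbll_regression_predictive(params, phi)
```

When comparing a port against something like this, check that the predictive variance includes both the weight-uncertainty term `phi^T S phi` and the noise variance. Dropping either one changes the width of the bands.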
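This sketch shows where arguments like those in the `DiscClassification(64, 2, 1.0 / 7232, parameterization='diagonal', prior_scale=1.0)` setting could enter a training objective. It is a minimal JAX version of a discriminative variational last-layer classification loss, with a diagonal Gaussian posterior over the last-layer weight matrix. It makes three assumptions:

- the lower bound `E_q[log softmax(W phi)_y] >= m_y^T phi - log sum_k exp(m_k^T phi + 0.5 phi^T S_k phi)`, which follows from Jensen's inequality and the Gaussian moment-generating function;
- a zero-mean isotropic prior whose variance `prior_var` stands in for `prior_scale`;
- a Monte Carlo predictive.

All names are hypothetical, and this is not necessarily the objective the `vbll` package uses.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def init_vbll_disc(key, feat_dim, num_classes):
    # Hypothetical layout: diagonal Gaussian posterior over a (K, D) weight matrix.
    return {
        "w_mean": 0.1 * jax.random.normal(key, (num_classes, feat_dim)),
        "w_logvar": jnp.full((num_classes, feat_dim), -3.0),
    }


def vbll_disc_loss(params, phi, labels, reg_weight, prior_var=1.0):
    mean = params["w_mean"]            # (K, D)
    var = jnp.exp(params["w_logvar"])  # (K, D)
    logit_mean = phi @ mean.T          # (B, K): m_k^T phi
    logit_var = (phi ** 2) @ var.T     # (B, K): phi^T S_k phi for diagonal S_k
    true_logit = jnp.take_along_axis(logit_mean, labels[:, None], axis=1)[:, 0]
    # Lower bound on E_q[log softmax(W phi)_y] (Jensen + Gaussian MGF).
    bound = true_logit - logsumexp(logit_mean + 0.5 * logit_var, axis=1)
    # KL(q(W) || N(0, prior_var * I)) between diagonal Gaussians.
    kl = 0.5 * jnp.sum(
        var / prior_var + mean ** 2 / prior_var - 1.0 - jnp.log(var / prior_var)
    )
    return -jnp.mean(bound) + reg_weight * kl


def mc_predictive(params, phi, key, num_samples=64):
    # Monte Carlo estimate of E_q[softmax(W phi)] (an assumed predictive).
    mean, std = params["w_mean"], jnp.exp(0.5 * params["w_logvar"])
    w = mean + std * jax.random.normal(key, (num_samples,) + mean.shape)  # (S, K, D)
    logits = jnp.einsum("bd,skd->sbk", phi, w)
    return jax.nn.softmax(logits, axis=-1).mean(axis=0)  # (B, K)


# Usage on synthetic 2-class features (stand-ins for backbone output).
k_init, k_x, k_pred = jax.random.split(jax.random.PRNGKey(1), 3)
n, d, num_classes = 200, 4, 2
phi = jax.random.normal(k_x, (n, d))
labels = (phi[:, 0] > 0).astype(jnp.int32)

params = init_vbll_disc(k_init, d, num_classes)
reg_weight = 1.0 / n  # same form as 1.0 / 7232 with N = 7232
step = jax.jit(jax.value_and_grad(vbll_disc_loss))
initial_loss = vbll_disc_loss(params, phi, labels, reg_weight)
for _ in range(200):
    loss, grads = step(params, phi, labels, reg_weight)
    params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
final_loss = vbll_disc_loss(params, phi, labels, reg_weight)
probs = mc_predictive(params, phi, k_pred)
accuracy = float(jnp.mean(jnp.argmax(probs, axis=1) == labels))
```

In this sketch, `reg_weight` multiplies the KL term and `prior_var` sets where that term pulls the posterior. With a weight of 1/N and a large N, the data term dominates. That term pushes the posterior variances down, and only the KL term pushes them back toward `prior_var`.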