Skip to content

Commit

Permalink
Fix layer loading
Browse files Browse the repository at this point in the history
  • Loading branch information
jwilles committed Aug 26, 2024
1 parent 9e51684 commit e79282e
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "vbll"
version = "0.2.4"
version = "0.2.5"
description = ""
authors = ["John Willes <[email protected]>"]
readme = "README.md"
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

setup(
name="vbll",
version="0.2.4",
version="0.2.5",
packages=find_packages(),
install_requires=["torch"],
)
4 changes: 2 additions & 2 deletions vbll/layers/regression.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
import torch
from dataclasses import dataclass
from vbll.utils.distributions import Normal, DenseNormal, LowRankNormal, get_parameterization
from vbll.utils.distributions import Normal, DenseNormal, LowRankNormal, DenseNormalPrec, get_parameterization
from collections.abc import Callable
import torch.nn as nn

Expand Down Expand Up @@ -103,7 +103,7 @@ def W(self):
cov_diag = torch.exp(self.W_logdiag)
if self.W_dist == Normal:
cov = self.W_dist(self.W_mean, cov_diag)
elif (self.W_dist == DenseNormal) or (self.W_dist == DenseNormalPrecision):
elif (self.W_dist == DenseNormal) or (self.W_dist == DenseNormalPrec):
tril = torch.tril(self.W_offdiag, diagonal=-1) + torch.diag_embed(cov_diag)
cov = self.W_dist(self.W_mean, tril)
elif self.W_dist == LowRankNormal:
Expand Down

0 comments on commit e79282e

Please sign in to comment.