Commit 5aa59fb: lint
OlaRonning committed Dec 2, 2024
1 parent 8f89738 commit 5aa59fb
Showing 2 changed files with 8 additions and 6 deletions.
9 changes: 5 additions & 4 deletions pyro/distributions/sine_bivariate_von_mises.py
@@ -45,9 +45,9 @@ class SineBivariateVonMises(TorchDistribution):
\frac{\rho^2}{\kappa_1\kappa_2} \rightarrow 1
because the distribution becomes increasingly bimodal. To avoid inefficient sampling use the
`weighted_correlation` parameter with a skew away from one (e.g.,
`TransformedDistribution(Beta(5,5), AffineTransform(loc=-1, scale=2))`). The `weighted_correlation`
should be in [-1,1].
.. note:: The correlation and weighted_correlation params are mutually exclusive.
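
A minimal sketch of the pattern this docstring recommends, i.e. drawing `weighted_correlation` from a Beta prior rescaled to [-1, 1] and skewed away from the endpoints; the site names, locations, and concentrations below are illustrative placeholders, not part of this file:

import torch
import pyro
import pyro.distributions as dist
from torch.distributions import AffineTransform

def model():
    # Correlation weight in [-1, 1], concentrated away from +/-1,
    # following the docstring's advice for efficient sampling.
    w_corr = pyro.sample(
        "weighted_corr",
        dist.TransformedDistribution(
            dist.Beta(5.0, 5.0), AffineTransform(loc=-1.0, scale=2.0)
        ),
    )
    # Placeholder locations (0, 0) and concentrations (5, 5) for the two angles.
    return pyro.sample(
        "angles",
        dist.SineBivariateVonMises(0.0, 0.0, 5.0, 5.0, weighted_correlation=w_corr),
    )

Calling model() directly draws a pair of dependent angles; under inference the same skewed prior keeps the sampler out of the near-bimodal regime described above.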
@@ -141,7 +141,8 @@ def norm_const(self):
            - m * torch.log(4 * torch.prod(conc, dim=-1))
        )
        num_I1terms = torch.maximum(
-            torch.tensor(501), torch.max(self.phi_concentration) + torch.max(self.psi_concentration)
+            torch.tensor(501),
+            torch.max(self.phi_concentration) + torch.max(self.psi_concentration),
        ).int()

        fs += log_I1(m.max(), conc, num_I1terms).sum(-1)
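
The reformatted lines keep the same truncation rule for the `log_I1` series: at least 501 terms, growing with the largest phi/psi concentrations. A standalone illustration with made-up concentration values (not from this commit):

import torch

phi_concentration = torch.tensor([10.0, 250.0])  # hypothetical values
psi_concentration = torch.tensor([400.0])        # hypothetical values

# max(501, max(phi_concentration) + max(psi_concentration)), cast to an integer count of terms.
num_I1terms = torch.maximum(
    torch.tensor(501),
    torch.max(phi_concentration) + torch.max(psi_concentration),
).int()
print(num_I1terms)  # tensor(650, dtype=torch.int32)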
5 changes: 3 additions & 2 deletions tests/distributions/test_sine_bivariate_von_mises.py
@@ -131,14 +131,15 @@ def guide(data):

        assert_equal(expected[k].squeeze(), actual.squeeze(), 9e-2)

+
@pytest.mark.parametrize("conc", [1.0, 10.0, 1000.0, 10000.0])
def test_sine_bivariate_von_mises_norm(conc):
    dist = SineBivariateVonMises(0, 0, conc, conc, 0.0)
    num_samples = 500
    x = torch.linspace(-torch.pi, torch.pi, num_samples)
    y = torch.linspace(-torch.pi, torch.pi, num_samples)
-    mesh = torch.stack(torch.meshgrid(x, y, indexing='ij'), axis=-1)
+    mesh = torch.stack(torch.meshgrid(x, y, indexing="ij"), axis=-1)
    integral_torus = (
        torch.exp(dist.log_prob(mesh)) * (2 * torch.pi) ** 2 / num_samples**2
    ).sum()
-    assert torch.allclose(integral_torus, torch.tensor(1.0), rtol=1e-2)
+    assert torch.allclose(integral_torus, torch.tensor(1.0), rtol=1e-2)
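
This test checks the normalization constant numerically: it evaluates the density on a 500 x 500 grid over the torus and forms a Riemann sum, which should be close to 1 for a correctly normalized density. In symbols,

    \int_{-\pi}^{\pi} \int_{-\pi}^{\pi} p(\phi, \psi) \, d\phi \, d\psi
        \approx \frac{(2\pi)^2}{N^2} \sum_{i=1}^{N} \sum_{j=1}^{N} p(\phi_i, \psi_j) \approx 1,
        \qquad N = 500,

which is the quantity stored in integral_torus and compared to 1 with rtol=1e-2.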
