Skip to content

Commit

Permalink
fix tests to just use jax
Browse files Browse the repository at this point in the history
  • Loading branch information
yallup committed Jul 5, 2024
1 parent a72fe3c commit 83a2515
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 7 deletions.
Empty file removed tests/___init__.py
Empty file.
16 changes: 9 additions & 7 deletions tests/test_class.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import numpy as np
import jax.numpy as jnp
import pytest
from jax import nn
from jax import nn, random

from clax import Classifier, ClassifierSamples

rng = random.PRNGKey(0)

# @pytest.mark.parametrize("n_classes", [2, 10])
# class TestClassifier:
# @pytest.fixture
Expand All @@ -19,18 +21,18 @@
@pytest.mark.parametrize("n_classes", [1, 2, 10])
def test_classifier(n_classes):
classifier = Classifier(n_classes)
data_x = np.random.rand(100, 10)
data_y = np.random.randint(0, n_classes, 100)
data_x = random.uniform(rng, (100, 10))
data_y = random.randint(rng, (100,), 0, n_classes)
classifier.fit(data_x, data_y)
y = classifier.predict(data_x)
assert y.shape == (100, n_classes)
assert np.isclose(nn.softmax(y).sum(axis=-1), 1).all()
assert jnp.isclose(nn.softmax(y).sum(axis=-1), 1).all()


def test_conditional_classifier():
classifier = ClassifierSamples()
data_x = np.random.rand(100, 10)
data_y = np.random.rand(100, 10)
data_x = random.normal(rng, (100, 10))
data_y = random.normal(rng, (100, 10))
classifier.fit(data_x, data_y)
y = classifier.predict(data_x)
assert y.shape == (100, 1)

0 comments on commit 83a2515

Please sign in to comment.