Skip to content

Commit

Permalink
Slight restructures, changed the default device parameter.
Browse files Browse the repository at this point in the history
Added tests for the deprecated interfaces and some error cases.
  • Loading branch information
thomgrand committed Mar 24, 2024
1 parent 3181b4c commit eabee6c
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 7 deletions.
5 changes: 2 additions & 3 deletions fimpy/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
"""Function responsible for creating the correct Fast Iterative Method solver
"""
def create_fim_solver(points : available_arr_t, elems : available_arr_t, metrics : available_arr_t =None,
precision=np.float32, device='gpu', use_active_list=True) -> FIMBase:
precision=np.float32, device='cpu', use_active_list=True) -> FIMBase:
"""Creates a Fast Iterative Method solver for solving the anisotropic eikonal equation
.. math::
Expand Down Expand Up @@ -52,8 +52,7 @@ def create_fim_solver(points : available_arr_t, elems : available_arr_t, metrics
FIMBase
Returns a Fast Iterative Method solver
"""
if not cupy_available:
device='cpu'
assert not device == 'gpu' or cupy_available, "Requested GPU which is not available"

if device == 'cpu':
return (FIMNPAL(points, elems, metrics, precision) if use_active_list else FIMNP(points, elems, metrics, precision))
Expand Down
21 changes: 17 additions & 4 deletions tests/test_fim_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pytest
#from .
from fimpy.solver import create_fim_solver
from fimpy.solver import FIMPY #Deprecated interface
import numpy as np
import os
import scipy.io as sio
Expand All @@ -26,15 +27,19 @@ def test_init(self, dims, init_D, precision, use_active_list, device):
if device == 'gpu' and not cupy_enabled:
pytest.skip(reason='Cupy could not be imported. GPU tests unavailable')

points = np.tile(np.linspace(0, 1, num=4)[:(dims+1)][:, np.newaxis], [1, dims])
elems = np.arange(points.shape[0])[np.newaxis]
points, elems = self.dummy_mesh(dims)
D = None
if init_D:
D = np.eye(dims)[np.newaxis]

fim_solver = create_fim_solver(points, elems, D, device='cpu', precision=precision, use_active_list=use_active_list)
fim_solver = create_fim_solver(points, elems, D, device=device, precision=precision, use_active_list=use_active_list)
return fim_solver

def dummy_mesh(self, dims):
points = np.tile(np.linspace(0, 1, num=4)[:(dims+1)][:, np.newaxis], [1, dims])
elems = np.arange(points.shape[0])[np.newaxis]
return points, elems

@pytest.mark.parametrize('precision', [np.float32, np.float64])
def test_error_init(self, precision, device='cpu'):
points = np.array([0.])
Expand Down Expand Up @@ -82,7 +87,15 @@ def test_error_init(self, precision, device='cpu'):
def test_error_init_gpu(self, precision):
self.test_error_init(precision, 'gpu')


def test_error_init_wrong_device(self):
with pytest.raises(AssertionError):
self.test_init(3, True, np.float32, use_active_list=False, device='undefined_device')

def test_error_deprecated(self):
solver2 = FIMPY.create_fim_solver(*self.dummy_mesh(2))
solver = create_fim_solver(*self.dummy_mesh(2))
assert pickle.dumps(solver) == pickle.dumps(solver2) #Serialized objects should be exactly the same

@pytest.mark.parametrize('init_D', [True, False])
@pytest.mark.parametrize('dims', [1, 2, 3])
@pytest.mark.parametrize('precision', [np.float32, np.float64])
Expand Down

0 comments on commit eabee6c

Please sign in to comment.