Skip to content

Commit

Permalink
Use numpy.ctypeslib
Browse files Browse the repository at this point in the history
To avoid (silent) errors when reading lists / arrays in the parameter
provider, explicitly use numpy.ctypeslib
  • Loading branch information
schmoelder committed Dec 8, 2024
1 parent 0546225 commit b7fdb98
Showing 1 changed file with 29 additions and 13 deletions.
42 changes: 29 additions & 13 deletions cadet/cadet_dll_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,9 @@ def param_provider_get_double(
except TypeError:
float_val = float(o[0])

val[0] = ctypes.c_double(float_val)
log_print(f"GET scalar [double] {n}: {float(val[0])}")
np_val = np.ctypeslib.as_array(val, shape=(1,))
np_val[0] = float_val
log_print(f"GET scalar [double] {n}: {np_val[0]}")
return 0

return -1
Expand Down Expand Up @@ -85,8 +86,9 @@ def param_provider_get_int(
except TypeError:
int_val = int(o[0])

val[0] = ctypes.c_int(int_val)
log_print(f"GET scalar [int] {n}: {int(val[0])}")
np_val = np.ctypeslib.as_array(val, shape=(1,))
np_val[0] = int_val
log_print(f"GET scalar [int] {n}: {np_val[0]}")
return 0

return -1
Expand Down Expand Up @@ -124,8 +126,9 @@ def param_provider_get_bool(
except TypeError:
int_val = int(o[0])

val[0] = ctypes.c_uint8(int_val)
log_print(f"GET scalar [bool] {n}: {bool(val[0])}")
np_val = np.ctypeslib.as_array(val, shape=(1,))
np_val[0] = int_val
log_print(f"GET scalar [bool] {n}: {bool(np_val[0])}")
return 0

return -1
Expand Down Expand Up @@ -208,8 +211,9 @@ def param_provider_get_double_array(
if n in c:
o = c[n]
if isinstance(o, list):
o = np.ascontiguousarray(o)
if not isinstance(o, np.ndarray) or o.dtype != np.double or not o.flags.c_contiguous:
o = np.ascontiguousarray(o, dtype=np.float64)

if not isinstance(o, np.ndarray) or o.dtype != np.float64 or not o.flags.c_contiguous:
return -1

n_elem[0] = ctypes.c_int(o.size)
Expand Down Expand Up @@ -251,8 +255,9 @@ def param_provider_get_int_array(
if n in c:
o = c[n]
if isinstance(o, list):
o = np.ascontiguousarray(o)
if not isinstance(o, np.ndarray) or o.dtype != int or not o.flags.c_contiguous:
o = np.ascontiguousarray(o, dtype=np.int32)

if not isinstance(o, np.ndarray) or o.dtype != np.int32 or not o.flags.c_contiguous:
return -1

n_elem[0] = ctypes.c_int(o.size)
Expand Down Expand Up @@ -339,9 +344,20 @@ def param_provider_get_int_array_item(
o = c[n]

try:
int_val = int(o)
except TypeError:
int_val = int(o[index])
# If it's a list, convert it to a NumPy array
if isinstance(o, list):
o = np.ascontiguousarray(o, dtype=np.int32)

# Validate it's a NumPy array with appropriate dtype
if isinstance(o, np.ndarray) and o.dtype == np.int32:
int_val = o[index] # Retrieve the value by index
else:
# Handle scalar values
int_val = int(o)

except (TypeError, IndexError) as e:
log_print(f"ERROR retrieving array item for {n} at index {index}: {e}")
return -1

val[0] = ctypes.c_int(int_val)
log_print(f"GET array [int] ({index}) {n}: {val[0]}")
Expand Down

0 comments on commit b7fdb98

Please sign in to comment.