From 6405b3d8bc7ee244e8b51b5fcdc6c424154521cb Mon Sep 17 00:00:00 2001 From: Anton Chernyatevich Date: Sat, 14 Nov 2020 21:34:18 +0200 Subject: [PATCH] Fix enum classes and their inheritance --- arrayfire/__init__.py | 2 +- arrayfire/array.py | 6 +- arrayfire/device.py | 5 +- arrayfire/library.py | 492 ++++++++++++++++++++++-------------------- arrayfire/opencl.py | 29 +-- arrayfire/util.py | 6 +- 6 files changed, 275 insertions(+), 265 deletions(-) diff --git a/arrayfire/__init__.py b/arrayfire/__init__.py index 595d9615c..4ef4f0cfe 100644 --- a/arrayfire/__init__.py +++ b/arrayfire/__init__.py @@ -348,6 +348,7 @@ from .library import get_size_of # noqa : E401 from .library import safe_call # noqa : E401 from .library import set_backend # noqa : E401 +from .library import to_str # noqa : E401 # ============================================================================= # Machine Learning (ML) module # ============================================================================= @@ -445,7 +446,6 @@ from .util import number_dtype # noqa : E401 from .util import to_c_type # noqa : E401 from .util import to_dtype # noqa : E401 -from .util import to_str # noqa : E401 from .util import to_typecode # noqa : E401 try: diff --git a/arrayfire/array.py b/arrayfire/array.py index acbf14230..444060bd0 100644 --- a/arrayfire/array.py +++ b/arrayfire/array.py @@ -20,9 +20,9 @@ from .library import backend, safe_call from .library import ( Dtype, Source, c_bool_t, c_char_ptr_t, c_dim_t, c_double_t, c_int_t, c_longlong_t, c_pointer, c_size_t, c_uint_t, - c_ulonglong_t, c_void_ptr_t) + c_ulonglong_t, c_void_ptr_t, to_str) from .util import ( - _is_number, dim4, dim4_to_tuple, implicit_dtype, to_c_type, to_dtype, to_str, to_typecode, to_typename) + _is_number, dim4, dim4_to_tuple, implicit_dtype, to_c_type, to_dtype, to_typecode, to_typename) _is_running_in_py_charm = "PYCHARM_HOSTED" in os.environ @@ -1490,5 +1490,3 @@ def read_array(filename, index=None, key=None): elif key is not None: safe_call(backend.get().af_read_array_key(c_pointer(out.arr), filename.encode('utf-8'), key.encode('utf-8'))) return out - - diff --git a/arrayfire/device.py b/arrayfire/device.py index e12f74c41..262294c49 100644 --- a/arrayfire/device.py +++ b/arrayfire/device.py @@ -12,8 +12,8 @@ """ from .array import Array -from .library import backend, safe_call, c_bool_t, c_char_t, c_dim_t, c_int_t, c_pointer, c_size_t, c_void_ptr_t -from .util import to_str +from .library import ( + backend, c_bool_t, c_char_t, c_dim_t, c_int_t, c_pointer, c_size_t, c_void_ptr_t, safe_call, to_str) def init(): @@ -506,4 +506,3 @@ def free_pinned(ptr): """ cptr = c_void_ptr_t(ptr) safe_call(backend.get().af_free_pinned(cptr)) - diff --git a/arrayfire/library.py b/arrayfire/library.py index f4031cc5b..7d71e1c02 100644 --- a/arrayfire/library.py +++ b/arrayfire/library.py @@ -11,31 +11,34 @@ Module containing enums and other constants. """ -import platform import ctypes as ct import os +import platform import traceback +from enum import Enum -c_float_t = ct.c_float -c_double_t = ct.c_double -c_int_t = ct.c_int -c_uint_t = ct.c_uint -c_longlong_t = ct.c_longlong +c_float_t = ct.c_float +c_double_t = ct.c_double +c_int_t = ct.c_int +c_uint_t = ct.c_uint +c_longlong_t = ct.c_longlong c_ulonglong_t = ct.c_ulonglong -c_char_t = ct.c_char -c_bool_t = ct.c_bool -c_uchar_t = ct.c_ubyte -c_short_t = ct.c_short -c_ushort_t = ct.c_ushort -c_pointer = ct.pointer -c_void_ptr_t = ct.c_void_p -c_char_ptr_t = ct.c_char_p -c_size_t = ct.c_size_t -c_cast = ct.cast +c_char_t = ct.c_char +c_bool_t = ct.c_bool +c_uchar_t = ct.c_ubyte +c_short_t = ct.c_short +c_ushort_t = ct.c_ushort +c_pointer = ct.pointer +c_void_ptr_t = ct.c_void_p +c_char_ptr_t = ct.c_char_p +c_size_t = ct.c_size_t +c_cast = ct.cast + class af_cfloat_t(ct.Structure): _fields_ = [("real", ct.c_float), ("imag", ct.c_float)] + class af_cdouble_t(ct.Structure): _fields_ = [("real", ct.c_double), ("imag", ct.c_double)] @@ -56,146 +59,137 @@ class af_cdouble_t(ct.Structure): platform.machine()[0:3] == 'arm')): c_dim_t = c_int_t -try: - from enum import Enum as _Enum - def _Enum_Type(v): - return v -except ImportError: - class _MetaEnum(type): - def __init__(cls, name, bases, attrs): - for attrname, attrvalue in attrs.iteritems(): - if name != '_Enum' and isinstance(attrvalue, _Enum_Type): - attrvalue.__class__ = cls - attrs[attrname] = attrvalue - - class _Enum(object): - __metaclass__ = _MetaEnum - class _Enum_Type(object): - def __init__(self, v): - self.value = v - -class ERR(_Enum): +class ERR(Enum): """ Error values. For internal use only. """ - NONE = _Enum_Type(0) + NONE = 0 - #100-199 Errors in environment - NO_MEM = _Enum_Type(101) - DRIVER = _Enum_Type(102) - RUNTIME = _Enum_Type(103) + # 100-199 Errors in environment + NO_MEM = 101 + DRIVER = 102 + RUNTIME = 103 # 200-299 Errors in input parameters - INVALID_ARRAY = _Enum_Type(201) - ARG = _Enum_Type(202) - SIZE = _Enum_Type(203) - TYPE = _Enum_Type(204) - DIFF_TYPE = _Enum_Type(205) - BATCH = _Enum_Type(207) - DEVICE = _Enum_Type(208) + INVALID_ARRAY = 201 + ARG = 202 + SIZE = 203 + TYPE = 204 + DIFF_TYPE = 205 + BATCH = 207 + DEVICE = 208 # 300-399 Errors for missing software features - NOT_SUPPORTED = _Enum_Type(301) - NOT_CONFIGURED = _Enum_Type(302) - NONFREE = _Enum_Type(303) + NOT_SUPPORTED = 301 + NOT_CONFIGURED = 302 + NONFREE = 303 # 400-499 Errors for missing hardware features - NO_DBL = _Enum_Type(401) - NO_GFX = _Enum_Type(402) - NO_HALF = _Enum_Type(403) + NO_DBL = 401 + NO_GFX = 402 + NO_HALF = 403 # 500-599 Errors specific to the heterogeneous API - LOAD_LIB = _Enum_Type(501) - LOAD_SYM = _Enum_Type(502) - ARR_BKND_MISMATCH = _Enum_Type(503) + LOAD_LIB = 501 + LOAD_SYM = 502 + ARR_BKND_MISMATCH = 503 # 900-999 Errors from upstream libraries and runtimes - INTERNAL = _Enum_Type(998) - UNKNOWN = _Enum_Type(999) + INTERNAL = 998 + UNKNOWN = 999 + -class Dtype(_Enum): +class Dtype(Enum): """ Error values. For internal use only. """ - f32 = _Enum_Type(0) - c32 = _Enum_Type(1) - f64 = _Enum_Type(2) - c64 = _Enum_Type(3) - b8 = _Enum_Type(4) - s32 = _Enum_Type(5) - u32 = _Enum_Type(6) - u8 = _Enum_Type(7) - s64 = _Enum_Type(8) - u64 = _Enum_Type(9) - s16 = _Enum_Type(10) - u16 = _Enum_Type(11) - f16 = _Enum_Type(12) - -class Source(_Enum): + f32 = 0 + c32 = 1 + f64 = 2 + c64 = 3 + b8 = 4 + s32 = 5 + u32 = 6 + u8 = 7 + s64 = 8 + u64 = 9 + s16 = 10 + u16 = 11 + f16 = 12 + + +class Source(Enum): """ Source of the pointer """ - device = _Enum_Type(0) - host = _Enum_Type(1) + device = 0 + host = 1 + -class INTERP(_Enum): +class INTERP(Enum): """ Interpolation method """ - NEAREST = _Enum_Type(0) - LINEAR = _Enum_Type(1) - BILINEAR = _Enum_Type(2) - CUBIC = _Enum_Type(3) - LOWER = _Enum_Type(4) - LINEAR_COSINE = _Enum_Type(5) - BILINEAR_COSINE = _Enum_Type(6) - BICUBIC = _Enum_Type(7) - CUBIC_SPLINE = _Enum_Type(8) - BICUBIC_SPLINE = _Enum_Type(9) + NEAREST = 0 + LINEAR = 1 + BILINEAR = 2 + CUBIC = 3 + LOWER = 4 + LINEAR_COSINE = 5 + BILINEAR_COSINE = 6 + BICUBIC = 7 + CUBIC_SPLINE = 8 + BICUBIC_SPLINE = 9 -class PAD(_Enum): + +class PAD(Enum): """ Edge padding types """ - ZERO = _Enum_Type(0) - SYM = _Enum_Type(1) - CLAMP_TO_EDGE = _Enum_Type(2) - PERIODIC = _Enum_Type(3) + ZERO = 0 + SYM = 1 + CLAMP_TO_EDGE = 2 + PERIODIC = 3 + -class CONNECTIVITY(_Enum): +class CONNECTIVITY(Enum): """ Neighborhood connectivity """ - FOUR = _Enum_Type(4) - EIGHT = _Enum_Type(8) + FOUR = 4 + EIGHT = 8 + -class CONV_MODE(_Enum): +class CONV_MODE(Enum): """ Convolution mode """ - DEFAULT = _Enum_Type(0) - EXPAND = _Enum_Type(1) + DEFAULT = 0 + EXPAND = 1 -class CONV_DOMAIN(_Enum): + +class CONV_DOMAIN(Enum): """ Convolution domain """ - AUTO = _Enum_Type(0) - SPATIAL = _Enum_Type(1) - FREQ = _Enum_Type(2) + AUTO = 0 + SPATIAL = 1 + FREQ = 2 + -class CONV_GRADIENT(_Enum): +class CONV_GRADIENT(Enum): """ Convolution gradient type """ - DEFAULT = _Enum_Type(0) - FILTER = _Enum_Type(1) - DATA = _Enum_Type(2) - BIAS = _Enum_Type(3) + DEFAULT = 0 + FILTER = 1 + DATA = 2 + BIAS = 3 -class MATCH(_Enum): + +class MATCH(Enum): """ Match type """ @@ -203,67 +197,69 @@ class MATCH(_Enum): """ Sum of absolute differences """ - SAD = _Enum_Type(0) + SAD = 0 """ Zero mean SAD """ - ZSAD = _Enum_Type(1) + ZSAD = 1 """ Locally scaled SAD """ - LSAD = _Enum_Type(2) + LSAD = 2 """ Sum of squared differences """ - SSD = _Enum_Type(3) + SSD = 3 """ Zero mean SSD """ - ZSSD = _Enum_Type(4) + ZSSD = 4 """ Locally scaled SSD """ - LSSD = _Enum_Type(5) + LSSD = 5 """ Normalized cross correlation """ - NCC = _Enum_Type(6) + NCC = 6 """ Zero mean NCC """ - ZNCC = _Enum_Type(7) + ZNCC = 7 """ Sum of hamming distances """ - SHD = _Enum_Type(8) + SHD = 8 -class YCC_STD(_Enum): +class YCC_STD(Enum): """ YCC Standard formats """ - BT_601 = _Enum_Type(601) - BT_709 = _Enum_Type(709) - BT_2020 = _Enum_Type(2020) + BT_601 = 601 + BT_709 = 709 + BT_2020 = 2020 + -class CSPACE(_Enum): +class CSPACE(Enum): """ Colorspace formats """ - GRAY = _Enum_Type(0) - RGB = _Enum_Type(1) - HSV = _Enum_Type(2) - YCbCr= _Enum_Type(3) + GRAY = 0 + RGB = 1 + HSV = 2 + YCbCr = 3 -class MATPROP(_Enum): + +class MATPROP(Enum): """ Matrix properties """ @@ -271,225 +267,241 @@ class MATPROP(_Enum): """ None, general. """ - NONE = _Enum_Type(0) + NONE = 0 """ Transposed. """ - TRANS = _Enum_Type(1) + TRANS = 1 """ Conjugate transposed. """ - CTRANS = _Enum_Type(2) + CTRANS = 2 """ Upper triangular matrix. """ - UPPER = _Enum_Type(32) + UPPER = 32 """ Lower triangular matrix. """ - LOWER = _Enum_Type(64) + LOWER = 64 """ Treat diagonal as units. """ - DIAG_UNIT = _Enum_Type(128) + DIAG_UNIT = 128 """ Symmetric matrix. """ - SYM = _Enum_Type(512) + SYM = 512 """ Positive definite matrix. """ - POSDEF = _Enum_Type(1024) + POSDEF = 1024 """ Orthogonal matrix. """ - ORTHOG = _Enum_Type(2048) + ORTHOG = 2048 """ Tri diagonal matrix. """ - TRI_DIAG = _Enum_Type(4096) + TRI_DIAG = 4096 """ Block diagonal matrix. """ - BLOCK_DIAG = _Enum_Type(8192) + BLOCK_DIAG = 8192 + -class NORM(_Enum): +class NORM(Enum): """ Norm types """ - VECTOR_1 = _Enum_Type(0) - VECTOR_INF = _Enum_Type(1) - VECTOR_2 = _Enum_Type(2) - VECTOR_P = _Enum_Type(3) - MATRIX_1 = _Enum_Type(4) - MATRIX_INF = _Enum_Type(5) - MATRIX_2 = _Enum_Type(6) - MATRIX_L_PQ = _Enum_Type(7) - EUCLID = VECTOR_2 + VECTOR_1 = 0 + VECTOR_INF = 1 + VECTOR_2 = 2 + VECTOR_P = 3 + MATRIX_1 = 4 + MATRIX_INF = 5 + MATRIX_2 = 6 + MATRIX_L_PQ = 7 + EUCLID = VECTOR_2 -class COLORMAP(_Enum): + +class COLORMAP(Enum): """ Colormaps """ - DEFAULT = _Enum_Type(0) - SPECTRUM = _Enum_Type(1) - COLORS = _Enum_Type(2) - RED = _Enum_Type(3) - MOOD = _Enum_Type(4) - HEAT = _Enum_Type(5) - BLUE = _Enum_Type(6) + DEFAULT = 0 + SPECTRUM = 1 + COLORS = 2 + RED = 3 + MOOD = 4 + HEAT = 5 + BLUE = 6 + -class IMAGE_FORMAT(_Enum): +class IMAGE_FORMAT(Enum): """ Image Formats """ - BMP = _Enum_Type(0) - ICO = _Enum_Type(1) - JPEG = _Enum_Type(2) - JNG = _Enum_Type(3) - PNG = _Enum_Type(13) - PPM = _Enum_Type(14) - PPMRAW = _Enum_Type(15) - TIFF = _Enum_Type(18) - PSD = _Enum_Type(20) - HDR = _Enum_Type(26) - EXR = _Enum_Type(29) - JP2 = _Enum_Type(31) - RAW = _Enum_Type(34) - -class HOMOGRAPHY(_Enum): + BMP = 0 + ICO = 1 + JPEG = 2 + JNG = 3 + PNG = 13 + PPM = 14 + PPMRAW = 15 + TIFF = 18 + PSD = 20 + HDR = 26 + EXR = 29 + JP2 = 31 + RAW = 34 + + +class HOMOGRAPHY(Enum): """ Homography Types """ - RANSAC = _Enum_Type(0) - LMEDS = _Enum_Type(1) + RANSAC = 0 + LMEDS = 1 + -class BACKEND(_Enum): +class BACKEND(Enum): """ Backend libraries """ - DEFAULT = _Enum_Type(0) - CPU = _Enum_Type(1) - CUDA = _Enum_Type(2) - OPENCL = _Enum_Type(4) + DEFAULT = 0 + CPU = 1 + CUDA = 2 + OPENCL = 4 -class MARKER(_Enum): + +class MARKER(Enum): """ Markers used for different points in graphics plots """ - NONE = _Enum_Type(0) - POINT = _Enum_Type(1) - CIRCLE = _Enum_Type(2) - SQUARE = _Enum_Type(3) - TRIANGE = _Enum_Type(4) - CROSS = _Enum_Type(5) - PLUS = _Enum_Type(6) - STAR = _Enum_Type(7) + NONE = 0 + POINT = 1 + CIRCLE = 2 + SQUARE = 3 + TRIANGE = 4 + CROSS = 5 + PLUS = 6 + STAR = 7 + -class MOMENT(_Enum): +class MOMENT(Enum): """ Image Moments types """ - M00 = _Enum_Type(1) - M01 = _Enum_Type(2) - M10 = _Enum_Type(4) - M11 = _Enum_Type(8) - FIRST_ORDER = _Enum_Type(15) + M00 = 1 + M01 = 2 + M10 = 4 + M11 = 8 + FIRST_ORDER = 15 + -class BINARYOP(_Enum): +class BINARYOP(Enum): """ Binary Operators """ - ADD = _Enum_Type(0) - MUL = _Enum_Type(1) - MIN = _Enum_Type(2) - MAX = _Enum_Type(3) + ADD = 0 + MUL = 1 + MIN = 2 + MAX = 3 -class RANDOM_ENGINE(_Enum): + +class RANDOM_ENGINE(Enum): """ Random engine types """ - PHILOX_4X32_10 = _Enum_Type(100) - THREEFRY_2X32_16 = _Enum_Type(200) - MERSENNE_GP11213 = _Enum_Type(300) - PHILOX = PHILOX_4X32_10 - THREEFRY = THREEFRY_2X32_16 - DEFAULT = PHILOX + PHILOX_4X32_10 = 100 + THREEFRY_2X32_16 = 200 + MERSENNE_GP11213 = 300 + PHILOX = PHILOX_4X32_10 + THREEFRY = THREEFRY_2X32_16 + DEFAULT = PHILOX + -class STORAGE(_Enum): +class STORAGE(Enum): """ Matrix Storage types """ - DENSE = _Enum_Type(0) - CSR = _Enum_Type(1) - CSC = _Enum_Type(2) - COO = _Enum_Type(3) + DENSE = 0 + CSR = 1 + CSC = 2 + COO = 3 -class CANNY_THRESHOLD(_Enum): + +class CANNY_THRESHOLD(Enum): """ Canny Edge Threshold types """ - MANUAL = _Enum_Type(0) - AUTO_OTSU = _Enum_Type(1) + MANUAL = 0 + AUTO_OTSU = 1 + -class FLUX(_Enum): +class FLUX(Enum): """ Flux functions """ - DEFAULT = _Enum_Type(0) - QUADRATIC = _Enum_Type(1) - EXPONENTIAL = _Enum_Type(2) + DEFAULT = 0 + QUADRATIC = 1 + EXPONENTIAL = 2 + -class DIFFUSION(_Enum): +class DIFFUSION(Enum): """ Diffusion equations """ - DEFAULT = _Enum_Type(0) - GRAD = _Enum_Type(1) - MCDE = _Enum_Type(2) + DEFAULT = 0 + GRAD = 1 + MCDE = 2 -class TOPK(_Enum): + +class TOPK(Enum): """ Top-K ordering """ - DEFAULT = _Enum_Type(0) - MIN = _Enum_Type(1) - MAX = _Enum_Type(2) + DEFAULT = 0 + MIN = 1 + MAX = 2 + -class ITERATIVE_DECONV(_Enum): +class ITERATIVE_DECONV(Enum): """ Iterative deconvolution algorithm """ - DEFAULT = _Enum_Type(0) - LANDWEBER = _Enum_Type(1) - RICHARDSONLUCY = _Enum_Type(2) + DEFAULT = 0 + LANDWEBER = 1 + RICHARDSONLUCY = 2 -class INVERSE_DECONV(_Enum): + +class INVERSE_DECONV(Enum): """ Inverse deconvolution algorithm """ - DEFAULT = _Enum_Type(0) - TIKHONOV = _Enum_Type(1) + DEFAULT = 0 + TIKHONOV = 1 + -class VARIANCE(_Enum): +class VARIANCE(Enum): """ Variance bias type """ - DEFAULT = _Enum_Type(0) - SAMPLE = _Enum_Type(1) - POPULATION = _Enum_Type(2) + DEFAULT = 0 + SAMPLE = 1 + POPULATION = 2 -from .util import to_str AF_VER_MAJOR = "3" FORGE_VER_MAJOR = "1" @@ -817,3 +829,7 @@ def safe_call(af_error): err_len = c_dim_t(0) backend.get().af_get_last_error(c_pointer(err_str), c_pointer(err_len)) raise RuntimeError(to_str(err_str)) + + +def to_str(c_str): + return str(c_str.value.decode('utf-8')) diff --git a/arrayfire/opencl.py b/arrayfire/opencl.py index 5c21e5cb8..06fce9ae3 100644 --- a/arrayfire/opencl.py +++ b/arrayfire/opencl.py @@ -12,31 +12,32 @@ This module provides interoperability with other OpenCL libraries. """ +from enum import Enum -from .library import _Enum, _Enum_Type, c_int_t, c_pointer, c_void_ptr_t +from .library import c_int_t, c_pointer, c_void_ptr_t -class DEVICE_TYPE(_Enum): +class DEVICE_TYPE(Enum): """ ArrayFire wrapper for CL_DEVICE_TYPE """ - CPU = _Enum_Type(1 << 1) - GPU = _Enum_Type(1 << 2) - ACC = _Enum_Type(1 << 3) - UNKNOWN = _Enum_Type(-1) + CPU = 1 << 1 + GPU = 1 << 2 + ACC = 1 << 3 + UNKNOWN = -1 -class PLATFORM(_Enum): +class PLATFORM(Enum): """ ArrayFire enum for common platforms """ - AMD = _Enum_Type(0) - APPLE = _Enum_Type(1) - INTEL = _Enum_Type(2) - NVIDIA = _Enum_Type(3) - BEIGNET = _Enum_Type(4) - POCL = _Enum_Type(5) - UNKNOWN = _Enum_Type(-1) + AMD = 0 + APPLE = 1 + INTEL = 2 + NVIDIA = 3 + BEIGNET = 4 + POCL = 5 + UNKNOWN = -1 def get_context(retain=False): diff --git a/arrayfire/util.py b/arrayfire/util.py index bf01639a0..17431cd32 100644 --- a/arrayfire/util.py +++ b/arrayfire/util.py @@ -15,7 +15,7 @@ from .library import ( Dtype, c_char_t, c_dim_t, c_double_t, c_float_t, c_int_t, c_longlong_t, c_short_t, c_uchar_t, c_uint_t, - c_ulonglong_t, c_ushort_t) + c_ulonglong_t, c_ushort_t, to_str) def dim4(d0=1, d1=1, d2=1, d3=1): @@ -79,10 +79,6 @@ def dim4_to_tuple(dims, default=1): return tuple(out) -def to_str(c_str): - return str(c_str.value.decode('utf-8')) - - def get_version(): """ Function to get the version of arrayfire.