From c30fb5c00809419d3c6de3f6260a3705a14ceb9a Mon Sep 17 00:00:00 2001 From: John Readey Date: Fri, 25 Oct 2024 09:55:39 -0500 Subject: [PATCH] add support for field assignments (#227) * add supprot for field assignments * fix flake8 error * remove debug log statement --- h5pyd/_hl/dataset.py | 127 ++++++++++++++++++++++++-------- test/hl/test_complex_numbers.py | 11 ++- test/hl/test_dataset.py | 12 --- test/hl/test_dataset_multi.py | 1 - test/hl/test_datatype.py | 4 +- test/hl/test_folder.py | 2 - 6 files changed, 105 insertions(+), 52 deletions(-) diff --git a/h5pyd/_hl/dataset.py b/h5pyd/_hl/dataset.py index e1d0c88..7958fa6 100644 --- a/h5pyd/_hl/dataset.py +++ b/h5pyd/_hl/dataset.py @@ -297,9 +297,9 @@ def make_new_dset( return dset_id -class AstypeWrapper(object): - """Wrapper to convert data on reading from a dataset.""" - +class AstypeWrapper: + """Wrapper to convert data on reading from a dataset. + """ def __init__(self, dset, dtype): self._dset = dset self._dtype = numpy.dtype(dtype) @@ -307,26 +307,25 @@ def __init__(self, dset, dtype): def __getitem__(self, args): return self._dset.__getitem__(args, new_dtype=self._dtype) - def __enter__(self): - # pylint: disable=protected-access - print( - "Using astype() as a context manager is deprecated. " - "Slice the returned object instead, like: ds.astype(np.int32)[:10]" - ) - self._dset._local.astype = self._dtype - return self - - def __exit__(self, *args): - # pylint: disable=protected-access - self._dset._local.astype = None - def __len__(self): - """Get the length of the underlying dataset + """ Get the length of the underlying dataset >>> length = len(dataset.astype('f8')) """ return len(self._dset) + def __array__(self, dtype=None, copy=True): + if copy is False: + raise ValueError( + f"AstypeWrapper.__array__ received {copy=} " + f"but memory allocation cannot be avoided on read" + ) + + data = self[:] + if dtype is not None: + return data.astype(dtype, copy=False) + return data + class AsStrWrapper: """Wrapper to decode strings on reading the dataset""" @@ -361,6 +360,43 @@ def __len__(self): return len(self._dset) +class FieldsWrapper: + """Wrapper to extract named fields from a dataset with a struct dtype""" + extract_field = None + + def __init__(self, dset, prior_dtype, names): + self._dset = dset + if isinstance(names, str): + self.extract_field = names + names = [names] + self.read_dtype = readtime_dtype(prior_dtype, names) + + def __array__(self, dtype=None, copy=True): + if copy is False: + raise ValueError( + f"FieldsWrapper.__array__ received {copy=} " + f"but memory allocation cannot be avoided on read" + ) + data = self[:] + if dtype is not None: + return data.astype(dtype, copy=False) + else: + return data + + def __getitem__(self, args): + data = self._dset.__getitem__(args, new_dtype=self.read_dtype) + if self.extract_field is not None: + data = data[self.extract_field] + return data + + def __len__(self): + """ Get the length of the underlying dataset + + >>> length = len(dataset.fields(['x', 'y'])) + """ + return len(self._dset) + + class ChunkIterator(object): """ Class to iterate through list of chunks of a given dataset @@ -486,6 +522,19 @@ def asstr(self, encoding=None, errors="strict"): return AsStrWrapper(self, encoding, errors=errors) + def fields(self, names, *, _prior_dtype=None): + """Get a wrapper to read a subset of fields from a compound data type: + + >>> 2d_coords = dataset.fields(['x', 'y'])[:] + + If names is a string, a single field is extracted, and the resulting + arrays will have that dtype. Otherwise, it should be an iterable, + and the read data will have a compound dtype. + """ + if _prior_dtype is None: + _prior_dtype = self.dtype + return FieldsWrapper(self, _prior_dtype, names) + @property def dims(self): from .dims import DimensionManager @@ -890,7 +939,7 @@ def __getitem__(self, args, new_dtype=None): * Boolean "mask" array indexing """ if new_dtype is not None: - self.log.warning("new_dtype is not supported") + self.log.debug(f"getitem.new_dtype: {new_dtype}") args = args if isinstance(args, tuple) else (args,) self.log.debug("dataset.__getitem__") for arg in args: @@ -906,8 +955,20 @@ def __getitem__(self, args, new_dtype=None): # Sort field indices from the rest of the args. names = tuple(x for x in args if isinstance(x, str)) - args = tuple(x for x in args if not isinstance(x, str)) + if names: + self.log.debug(f"names: {names}") + # Read a subset of the fields in this structured dtype + if len(names) == 1: + names = names[0] # Read with simpler dtype of this field + args = tuple(x for x in args if not isinstance(x, str)) + return self.fields(names, _prior_dtype=new_dtype)[args] + + if new_dtype is None: + new_dtype = self.dtype + else: + self.log.debug(f"new_dtype: {new_dtype}") + """ new_dtype = getattr(self._local, "astype", None) if new_dtype is not None: new_dtype = readtime_dtype(new_dtype, names) @@ -916,6 +977,7 @@ def __getitem__(self, args, new_dtype=None): # discards the array information at the top level. new_dtype = readtime_dtype(self.dtype, names) self.log.debug(f"new_dtype: {new_dtype}") + """ if new_dtype.kind == "S" and check_dtype(ref=self.dtype): new_dtype = special_dtype(ref=Reference) @@ -1015,14 +1077,14 @@ def __getitem__(self, args, new_dtype=None): self.log.debug(f"dataset shape: {self._shape}") self.log.debug(f"mshape: {mshape}") - self.log.debug(f"single_element: {single_element}") + # Perfom the actual read rsp = None req = "/datasets/" + self.id.uuid + "/value" params = {} - if len(names) > 0: - params["fields"] = ":".join(names) + if mtype.names != self.dtype.names: + params["fields"] = ":".join(mtype.names) if self.id._http_conn.mode == "r" and self.id._http_conn.cache_on: # enables lambda to be used on server @@ -1152,7 +1214,6 @@ def __getitem__(self, args, new_dtype=None): # got binary response # TBD - check expected number of bytes self.log.info(f"binary response, {len(rsp)} bytes") - # arr1d = numpy.frombuffer(rsp, dtype=mtype) arr1d = bytesToArray(rsp, mtype, page_mshape) page_arr = numpy.reshape(arr1d, page_mshape) else: @@ -1328,7 +1389,7 @@ def __setitem__(self, args, val): # get the val dtype if we're passed a numpy array try: - msg = f"val dtype: {val.dtype}, shape: {val.shape} metadata: {val.dtype.metadata}" + msg = f"val dtype: {val.dtype}, shape: {val.shape} kind: {val.dtype.kind} metadata: {val.dtype.metadata}" self.log.debug(msg) if numpy.prod(val.shape) == 0: self.log.info("no elements in numpy array, skipping write") @@ -1360,6 +1421,7 @@ def __setitem__(self, args, val): # For h5pyd, do extra check and convert type on client side for efficiency vlen_base_class = check_dtype(vlen=self.dtype) if vlen_base_class is not None and vlen_base_class not in (bytes, str): + self.log.debug(f"asarray to base_class: {vlen_base_class}") try: # Attempt to directly convert the input array of vlen data to its base class val = numpy.asarray(val, dtype=vlen_base_class) @@ -1417,6 +1479,7 @@ def __setitem__(self, args, val): # TBD: Do we need something like the following in the above if condition: # (self.dtype.str != val.dtype.str) # for cases where the val is a numpy array but different type than self? + if len(names) == 1 and self.dtype.fields is not None: # Single field selected for write, from a non-array source if not names[0] in self.dtype.fields: @@ -1427,9 +1490,12 @@ def __setitem__(self, args, val): dtype = self.dtype cast_compound = False - val = numpy.asarray(val, dtype=dtype, order="C") + self.log.debug(f"asarray dtype: {dtype}, cast_compound: {cast_compound}") + val = numpy.asarray(val, dtype=dtype.base, order="C") if cast_compound: - val = val.astype(numpy.dtype([(names[0], dtype)])) + # val = val.astype(numpy.dtype([(names[0], dtype)])) + val = val.view(numpy.dtype([(names[0], dtype)])) + val = val.reshape(val.shape[:len(val.shape) - len(dtype.shape)]) elif isinstance(val, numpy.ndarray): # convert array if needed @@ -1447,17 +1513,16 @@ def __setitem__(self, args, val): # Check for array dtype compatibility and convert mshape = None - """ - # TBD.. + self.log.debug(f"self.dtype.subdtype: {self.dtype.subdtype}") if self.dtype.subdtype is not None: shp = self.dtype.subdtype[1] # type shape valshp = val.shape[-len(shp):] if valshp != shp: # Last dimension has to match raise TypeError(f"When writing to array types,\ last N dimensions have to match (got {valshp}, but should be {shp})") - mtype = h5t.py_create(numpy.dtype((val.dtype, shp))) - mshape = val.shape[0:len(val.shape)-len(shp)] - """ + mtype = numpy.dtype((val.dtype, shp)) + self.log.debug(f"mtype for subdtype: {mtype}") + mshape = val.shape[0:len(val.shape) - len(shp)] # Check for field selection if len(names) != 0: diff --git a/test/hl/test_complex_numbers.py b/test/hl/test_complex_numbers.py index c81bede..49f6cb7 100644 --- a/test/hl/test_complex_numbers.py +++ b/test/hl/test_complex_numbers.py @@ -43,9 +43,14 @@ def test_complex_dset(self): val = dset[0] self.assertEqual(val.shape, ()) - self.assertEqual(val.dtype.kind, 'c') - self.assertEqual(val.real, 1.0) - self.assertEqual(val.imag, 0.) + if config.get('use_h5py'): + self.assertEqual(val.dtype.kind, 'c') + self.assertEqual(val.real, 1.0) + self.assertEqual(val.imag, 0.) + else: + self.assertEqual(val.dtype.kind, 'V') + self.assertEqual(val['r'], 1.0) + self.assertEqual(val['i'], 0.) def test_complex_attr(self): """Read and wrtie complex numbers in attributes""" diff --git a/test/hl/test_dataset.py b/test/hl/test_dataset.py index 9c9c31b..7a4739b 100644 --- a/test/hl/test_dataset.py +++ b/test/hl/test_dataset.py @@ -1396,13 +1396,7 @@ def test_rt(self): self.assertTrue(np.all(outdata == testdata)) self.assertEqual(outdata.dtype, testdata.dtype) - @ut.expectedFailure def test_assign(self): - # Expected failure on HSDS; skip with h5py - if config.get('use_h5py'): - self.assertTrue(False) - - # TBD: field assignment not working dt = np.dtype([('weight', (np.float64, 3)), ('endpoint_type', np.uint8), ]) @@ -1419,13 +1413,7 @@ def test_assign(self): self.assertTrue(np.all(outdata == testdata)) self.assertEqual(outdata.dtype, testdata.dtype) - @ut.expectedFailure def test_fields(self): - # Expected failure on HSDS; skip with h5py - if config.get('use_h5py'): - self.assertTrue(False) - - # TBD: field assignment not working dt = np.dtype([ ('x', np.float64), ('y', np.float64), diff --git a/test/hl/test_dataset_multi.py b/test/hl/test_dataset_multi.py index a505f15..92e9400 100644 --- a/test/hl/test_dataset_multi.py +++ b/test/hl/test_dataset_multi.py @@ -36,7 +36,6 @@ def test_multi_read_scalar_dataspaces(self): """ filename = self.getFileName("multi_read_scalar_dataspaces") print("filename:", filename) - print(f"numpy version: {np.version.version}") f = h5py.File(filename, 'w') shape = () count = 3 diff --git a/test/hl/test_datatype.py b/test/hl/test_datatype.py index b5dceac..2fae645 100644 --- a/test/hl/test_datatype.py +++ b/test/hl/test_datatype.py @@ -122,8 +122,7 @@ def test_read(self): np.testing.assert_array_equal(outdata, testdata[key]) self.assertEqual(outdata.dtype, testdata[key].dtype) - """ - TBD + @ut.expectedFailure def test_nested_compound_vlen(self): dt_inner = np.dtype([('a', h5py.vlen_dtype(np.int32)), ('b', h5py.vlen_dtype(np.int32))]) @@ -147,7 +146,6 @@ def test_nested_compound_vlen(self): # Specifying check_alignment=False because vlen fields have 8 bytes of padding # because the vlen datatype in hdf5 occupies 16 bytes self.assertArrayEqual(out, data, check_alignment=False) - """ if __name__ == '__main__': diff --git a/test/hl/test_folder.py b/test/hl/test_folder.py index c4643ed..f9e7c3b 100644 --- a/test/hl/test_folder.py +++ b/test/hl/test_folder.py @@ -29,8 +29,6 @@ def test_list(self): # Folders not supported for h5py return - # loglevel = logging.DEBUG - # logging.basicConfig( format='%(asctime)s %(message)s', level=loglevel) test_domain = self.getFileName("folder_test") filepath = self.getPathFromDomain(test_domain)