Skip to content

Commit

Permalink
DatView: Fix zero() (#727)
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck authored Sep 4, 2024
1 parent 5f18075 commit b84b770
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 18 deletions.
33 changes: 15 additions & 18 deletions pyop2/types/dat.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,6 +681,7 @@ def __init__(self, dat, index):
if not (0 <= i < d):
raise ex.IndexValueError("Can't create DatView with index %s for Dat with shape %s" % (index, dat.dim))
self.index = index
self._idx = (slice(None), *index)
self._parent = dat
# Point at underlying data
super(DatView, self).__init__(dat.dataset,
Expand Down Expand Up @@ -720,41 +721,37 @@ def halo_valid(self):
def halo_valid(self, value):
self._parent.halo_valid = value

@property
def dat_version(self):
return self._parent.dat_version

@property
def _data(self):
return self._parent._data[self._idx]

@property
def data(self):
full = self._parent.data
idx = (slice(None), *self.index)
return full[idx]
return self._parent.data[self._idx]

@property
def data_ro(self):
full = self._parent.data_ro
idx = (slice(None), *self.index)
return full[idx]
return self._parent.data_ro[self._idx]

@property
def data_wo(self):
full = self._parent.data_wo
idx = (slice(None), *self.index)
return full[idx]
return self._parent.data_wo[self._idx]

@property
def data_with_halos(self):
full = self._parent.data_with_halos
idx = (slice(None), *self.index)
return full[idx]
return self._parent.data_with_halos[self._idx]

@property
def data_ro_with_halos(self):
full = self._parent.data_ro_with_halos
idx = (slice(None), *self.index)
return full[idx]
return self._parent.data_ro_with_halos[self._idx]

@property
def data_wo_with_halos(self):
full = self._parent.data_wo_with_halos
idx = (slice(None), *self.index)
return full[idx]
return self._parent.data_wo_with_halos[self._idx]


class Dat(AbstractDat, VecAccessMixin):
Expand Down
64 changes: 64 additions & 0 deletions test/unit/test_dats.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,16 @@ def mdat(d1):
return op2.MixedDat([d1, d1])


@pytest.fixture(scope='module')
def s2(s):
return op2.DataSet(s, 2)


@pytest.fixture
def vdat(s2):
return op2.Dat(s2, np.zeros(2 * nelems), dtype=np.float64)


class TestDat:

"""
Expand Down Expand Up @@ -254,6 +264,60 @@ def test_accessing_data_with_halos_increments_dat_version(self, d1):
assert d1.dat_version == 1


class TestDatView():

def test_dat_view_assign(self, vdat):
vdat.data[:, 0] = 3
vdat.data[:, 1] = 4
comp = op2.DatView(vdat, 1)
comp.data[:] = 7
assert not vdat.halo_valid
assert not comp.halo_valid

expected = np.zeros_like(vdat.data)
expected[:, 0] = 3
expected[:, 1] = 7
assert all(comp.data == expected[:, 1])
assert all(vdat.data[:, 0] == expected[:, 0])
assert all(vdat.data[:, 1] == expected[:, 1])

def test_dat_view_zero(self, vdat):
vdat.data[:, 0] = 3
vdat.data[:, 1] = 4
comp = op2.DatView(vdat, 1)
comp.zero()
assert vdat.halo_valid
assert comp.halo_valid

expected = np.zeros_like(vdat.data)
expected[:, 0] = 3
expected[:, 1] = 0
assert all(comp.data == expected[:, 1])
assert all(vdat.data[:, 0] == expected[:, 0])
assert all(vdat.data[:, 1] == expected[:, 1])

def test_dat_view_halo_valid(self, vdat):
"""Check halo validity for DatView"""
comp = op2.DatView(vdat, 1)
assert vdat.halo_valid
assert comp.halo_valid
assert vdat.dat_version == 0
assert comp.dat_version == 0

comp.data_ro_with_halos
assert vdat.halo_valid
assert comp.halo_valid
assert vdat.dat_version == 0
assert comp.dat_version == 0

# accessing comp.data_with_halos should mark the parent halo as dirty
comp.data_with_halos
assert not vdat.halo_valid
assert not comp.halo_valid
assert vdat.dat_version == 1
assert comp.dat_version == 1


if __name__ == '__main__':
import os
pytest.main(os.path.abspath(__file__))

0 comments on commit b84b770

Please sign in to comment.