Skip to content

Commit

Permalink
Merge pull request #938 from OceanParcels/bugfix_add_vectorfield
Browse files Browse the repository at this point in the history
Fixing bug when FieldSet.add_vector_field() called without adding individual Fields
  • Loading branch information
erikvansebille authored Oct 20, 2020
2 parents b1773bd + 70e0cc5 commit 9a5c0ee
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 4 deletions.
1 change: 1 addition & 0 deletions parcels/examples/tutorial_NestedFields.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@
"metadata": {},
"outputs": [],
"source": [
"fieldset = FieldSet(U, V) # Need to redefine fieldset because FieldSets need to be constructed before ParticleSets\n",
"F1 = Field('F1', np.ones((U1.grid.ydim, U1.grid.xdim), dtype=np.float32), grid=U1.grid)\n",
"F2 = Field('F2', 2*np.ones((U2.grid.ydim, U2.grid.xdim), dtype=np.float32), grid=U2.grid)\n",
"F = NestedField('F', [F1, F2])\n",
Expand Down
7 changes: 7 additions & 0 deletions parcels/fieldset.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class FieldSet(object):
"""
def __init__(self, U, V, fields=None):
self.gridset = GridSet()
self.completed = False
if U:
self.add_field(U, 'U')
self.time_origin = self.U.grid.time_origin if isinstance(self.U, Field) else self.U[0].grid.time_origin
Expand Down Expand Up @@ -140,6 +141,8 @@ def add_field(self, field, name=None):
* `Unit converters <https://nbviewer.jupyter.org/github/OceanParcels/parcels/blob/master/parcels/examples/tutorial_unitconverters.ipynb>`_
"""
if self.completed:
raise RuntimeError("FieldSet has already been completed. Are you trying to add a Field after you've created the ParticleSet?")
name = field.name if name is None else name
if hasattr(self, name): # check if Field with same name already exists when adding new Field
raise RuntimeError("FieldSet already has a Field with name '%s'" % name)
Expand Down Expand Up @@ -229,6 +232,9 @@ def add_vector_field(self, vfield):
:param vfield: :class:`parcels.field.VectorField` object to be added
"""
setattr(self, vfield.name, vfield)
for v in vfield.__dict__.values():
if isinstance(v, Field) and (v not in self.get_fields()):
self.add_field(v)
vfield.fieldset = self
if isinstance(vfield, NestedField):
for f in vfield:
Expand Down Expand Up @@ -314,6 +320,7 @@ def check_velocityfields(U, V, W):
if not f.grid.defer_load:
depth_data = f.grid.depth_field.data
f.grid.depth = depth_data if isinstance(depth_data, np.ndarray) else np.array(depth_data)
self.completed = True

@classmethod
def parse_wildcards(cls, paths, filenames, var):
Expand Down
50 changes: 50 additions & 0 deletions tests/test_fieldset.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,26 @@ def test_add_duplicate_field(dupobject):
assert error_thrown


@pytest.mark.parametrize('fieldtype', ['normal', 'vector'])
def test_add_field_after_pset(fieldtype):
data, dimensions = generate_fieldset(100, 100)
fieldset = FieldSet.from_data(data, dimensions)
pset = ParticleSet(fieldset, ScipyParticle, lon=0, lat=0) # noqa ; to trigger fieldset.check_complete
field1 = Field('field1', fieldset.U.data, lon=fieldset.U.lon, lat=fieldset.U.lat)
field2 = Field('field2', fieldset.U.data, lon=fieldset.U.lon, lat=fieldset.U.lat)
vfield = VectorField('vfield', field1, field2)
error_thrown = False
try:
if fieldtype == 'normal':
fieldset.add_field(field1)
elif fieldtype == 'vector':
fieldset.add_vector_field(vfield)
except RuntimeError:
error_thrown = True

assert error_thrown


def test_fieldset_samegrids_from_file(tmpdir, filename='test_subsets'):
""" Test for subsetting fieldset from file using indices dict. """
data, dimensions = generate_fieldset(100, 100)
Expand Down Expand Up @@ -463,6 +483,36 @@ def test_vector_fields(mode, swapUV):
assert abs(pset.lat[0] - .5) < 1e-9


@pytest.mark.parametrize('mode', ['scipy', 'jit'])
def test_add_second_vector_field(mode):
lon = np.linspace(0., 10., 12, dtype=np.float32)
lat = np.linspace(0., 10., 10, dtype=np.float32)
U = np.ones((10, 12), dtype=np.float32)
V = np.zeros((10, 12), dtype=np.float32)
data = {'U': U, 'V': V}
dimensions = {'U': {'lat': lat, 'lon': lon},
'V': {'lat': lat, 'lon': lon}}
fieldset = FieldSet.from_data(data, dimensions, mesh='flat')

data2 = {'U2': U, 'V2': V}
dimensions2 = {'lon': [ln + 0.1 for ln in lon], 'lat': [lt - 0.1 for lt in lat]}
fieldset2 = FieldSet.from_data(data2, dimensions2, mesh='flat')

UV2 = VectorField('UV2', fieldset2.U2, fieldset2.V2)
fieldset.add_vector_field(UV2)

def SampleUV2(particle, fieldset, time):
u, v = fieldset.UV2[time, particle.depth, particle.lat, particle.lon]
particle.lon += u * particle.dt
particle.lat += v * particle.dt

pset = ParticleSet(fieldset, pclass=ptype[mode], lon=0.5, lat=0.5)
pset.execute(AdvectionRK4+pset.Kernel(SampleUV2), dt=1, runtime=1)

assert abs(pset.lon[0] - 2.5) < 1e-9
assert abs(pset.lat[0] - .5) < 1e-9


@pytest.mark.parametrize('mode', ['scipy', 'jit'])
@pytest.mark.parametrize('time_periodic', [4*86400.0, False])
@pytest.mark.parametrize('field_chunksize', [False, 'auto', (1, 32, 32)])
Expand Down
7 changes: 3 additions & 4 deletions tests/test_fieldset_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,17 +603,16 @@ def test_multiple_grid_addlater_error():
lat=np.linspace(0., 1., ydim, dtype=np.float32))
fieldset = FieldSet(U, V)

pset = ParticleSet(fieldset, pclass=pclass('jit'), lon=[0.8], lat=[0.9])
pset = ParticleSet(fieldset, pclass=pclass('jit'), lon=[0.8], lat=[0.9]) # noqa ; to trigger fieldset.check_complete

P = Field('P', np.zeros((ydim*10, xdim*10), dtype=np.float32),
lon=np.linspace(0., 1., xdim*10, dtype=np.float32),
lat=np.linspace(0., 1., ydim*10, dtype=np.float32))
fieldset.add_field(P)

fail = False
try:
pset.execute(AdvectionRK4, runtime=10, dt=1)
except:
fieldset.add_field(P)
except RuntimeError:
fail = True
assert fail

Expand Down

0 comments on commit 9a5c0ee

Please sign in to comment.