Skip to content

Commit

Permalink
Merge pull request #1200 from OceanParcels/summedvectorfield_eval_imp…
Browse files Browse the repository at this point in the history
…lementation

Implementing fieldset.UV.eval for SummedVectorFields too
  • Loading branch information
erikvansebille authored Aug 29, 2022
2 parents 85a45b0 + ed36fdd commit 7f66104
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 7 deletions.
44 changes: 39 additions & 5 deletions parcels/compilation/codegenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,17 +118,29 @@ def __init__(self, fields, args, var):


class SummedVectorFieldNode(IntrinsicNode):
def __getattr__(self, attr):
if attr == "eval":
return SummedVectorFieldEvalCallNode(self)

def __getitem__(self, attr):
return SummedVectorFieldEvalNode(self.obj, attr)


class SummedVectorFieldEvalCallNode(IntrinsicNode):
def __init__(self, field):
self.field = field
self.obj = field.obj
self.ccode = ""


class SummedVectorFieldEvalNode(IntrinsicNode):
def __init__(self, fields, args, var, var2, var3):
def __init__(self, fields, args, var, var2, var3, convert=True):
self.fields = fields
self.args = args
self.var = var # the variable in which the interpolated field is written
self.var2 = var2 # second variable for UV interpolation
self.var3 = var3 # third variable for UVW interpolation
self.convert = convert # whether to convert the result (like field.applyConversion)


class NestedFieldNode(IntrinsicNode):
Expand Down Expand Up @@ -450,6 +462,28 @@ def visit_Call(self, node):
return ast.Tuple([ast.Name(id=tmp1), ast.Name(id=tmp2), ast.Name(id=tmp3)], ast.Load())
else:
return ast.Tuple([ast.Name(id=tmp1), ast.Name(id=tmp2)], ast.Load())

elif isinstance(node.func, SummedVectorFieldEvalCallNode):
# get a temporary value to assign result to
tmp = [self.get_tmp() for _ in range(len(node.func.obj))]
tmp2 = [self.get_tmp() for _ in range(len(node.func.obj))]
tmp3 = [self.get_tmp() if list.__getitem__(node.func.obj, 0).vector_type == '3D' else None for _ in range(len(node.func.obj))]
# whether to convert
convert = True
if "applyConversion" in node.keywords:
k = node.keywords["applyConversion"]
if isinstance(k, ast.NameConstant):
convert = k.value

# convert args to Index(Tuple(*args))
args = ast.Index(value=ast.Tuple(node.args, ast.Load()))

self.stmt_stack += [SummedVectorFieldEvalNode(node.func.field, args, tmp, tmp2, tmp3, convert)]
if all(tmp3):
return ast.Tuple([ast.Name(id='+'.join(tmp)), ast.Name(id='+'.join(tmp2)), ast.Name(id='+'.join(tmp3))], ast.Load())
else:
return ast.Tuple([ast.Name(id='+'.join(tmp)), ast.Name(id='+'.join(tmp2))], ast.Load())

return node


Expand Down Expand Up @@ -979,14 +1013,14 @@ def visit_SummedVectorFieldEvalNode(self, node):
for fld, var, var2, var3 in zip(node.fields.obj, node.var, node.var2, node.var3):
ccode_eval = fld.ccode_eval_array(var, var2, var3,
fld.U, fld.V, fld.W, *args)
if fld.U.interp_method != 'cgrid_velocity':
if node.convert and fld.U.interp_method != 'cgrid_velocity':
ccode_conv1 = fld.U.ccode_convert(*args)
ccode_conv2 = fld.V.ccode_convert(*args)
statements = [c.Statement("%s *= %s" % (var, ccode_conv1)),
c.Statement("%s *= %s" % (var2, ccode_conv2))]
else:
statements = []
if fld.vector_type == '3D':
if node.convert and fld.vector_type == '3D':
ccode_conv3 = fld.W.ccode_convert(*args)
statements.append(c.Statement("%s *= %s" % (var3, ccode_conv3)))
cstat += [c.Assign("err", ccode_eval), c.Block(statements)]
Expand Down Expand Up @@ -1129,14 +1163,14 @@ def visit_SummedVectorFieldEvalNode(self, node):
args = self._check_FieldSamplingArguments(node.args.ccode)
for fld, var, var2, var3 in zip(node.fields.obj, node.var, node.var2, node.var3):
ccode_eval = fld.ccode_eval_object(var, var2, var3, fld.U, fld.V, fld.W, *args)
if fld.U.interp_method != 'cgrid_velocity':
if node.convert and fld.U.interp_method != 'cgrid_velocity':
ccode_conv1 = fld.U.ccode_convert(*args)
ccode_conv2 = fld.V.ccode_convert(*args)
statements = [c.Statement("%s *= %s" % (var, ccode_conv1)),
c.Statement("%s *= %s" % (var2, ccode_conv2))]
else:
statements = []
if fld.vector_type == '3D':
if node.convert and fld.vector_type == '3D':
ccode_conv3 = fld.W.ccode_convert(*args)
statements.append(c.Statement("%s *= %s" % (var3, ccode_conv3)))
cstat += [c.Assign("err", ccode_eval), c.Block(statements)]
Expand Down
8 changes: 8 additions & 0 deletions parcels/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -1897,6 +1897,14 @@ def __init__(self, name, F, V=None, W=None):
self.append(VectorField(name+'_%d' % i, Fi, Vi, Wi))
self.name = name

def eval(self, time, z, y, x, particle=None, applyConversion=True):
vals = []
val = None
for iField in range(len(self)):
val = list.__getitem__(self, iField).eval(time, z, y, x, applyConversion=applyConversion)
vals.append(val)
return tuple(np.sum(vals, 0)) if isinstance(val, tuple) else np.sum(vals)

def __getitem__(self, key):
if isinstance(key, int):
return list.__getitem__(self, key)
Expand Down
8 changes: 6 additions & 2 deletions tests/test_fieldset_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -814,16 +814,20 @@ def test_summedfields(pset_mode, mode, with_W, k_sample_p, mesh):
fieldsetS.add_field((P1+P4)+(P2+P3), name='P')
assert np.allclose(fieldsetS.P[0, 0, 0, 0], 60)

def sample_UV_noconvert(particle, fieldset, time):
(particle.u, particle.v) = fieldset.UV.eval(time, particle.depth, particle.lat, particle.lon, applyConversion=False) # noqa

if with_W:
W1 = Field('W', 2*np.ones((zdim * gf, ydim * gf, xdim * gf), dtype=np.float32), grid=U1.grid)
W2 = Field('W', np.ones((zdim, ydim, xdim), dtype=np.float32), grid=U2.grid)
fieldsetS.add_field(W1+W2, name='W')
pset = pset_type[pset_mode]['pset'](fieldsetS, pclass=pclass(mode), lon=[0], lat=[0.9])
pset.execute(AdvectionRK4_3D+pset.Kernel(k_sample_p), runtime=2, dt=1)
pset.execute(AdvectionRK4_3D+pset.Kernel(k_sample_p)+sample_UV_noconvert, runtime=2, dt=1)
assert np.isclose(pset.depth[0], 6)
else:
pset = pset_type[pset_mode]['pset'](fieldsetS, pclass=pclass(mode), lon=[0], lat=[0.9])
pset.execute(AdvectionRK4+pset.Kernel(k_sample_p), runtime=2, dt=1)
pset.execute(AdvectionRK4+pset.Kernel(k_sample_p)+sample_UV_noconvert, runtime=2, dt=1)
assert np.isclose(pset.u[0], 0.3)
assert np.isclose(pset.p[0], 60)
assert np.isclose(pset.lon[0]*conv, 0.6, atol=1e-3)
assert np.isclose(pset.lat[0], 0.9)
Expand Down

0 comments on commit 7f66104

Please sign in to comment.