Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upadate of transit spectra calculator #57

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
12 changes: 6 additions & 6 deletions gcm_toolkit/tests/test_interface.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in an ideal world, we would add these hardcoded values in the test to the test configuration in test_gcmtools_common.py

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

... but I am also fine keeping it like this.

Original file line number Diff line number Diff line change
Expand Up @@ -127,20 +127,20 @@ def test_prt_interface(petitradtrans_testdata, all_raw_testdata):

# test transit calculation
interface.chem_from_poorman("T", co_ratio=0.55, feh_ratio=0.0)
wave, spectra = interface.calc_transit_spectrum(mmw=2.33)
assert sum(spectra) == 45237213620.18512
wave, spectra, _ = interface.calc_transit_spectrum(mmw=2.33)
assert np.abs(sum(spectra) - 45398356903.221634) < 1000000

# test transit calculation
interface.chem_from_poorman("T", co_ratio=0.55, feh_ratio=0.0)
wave, spectra = interface.calc_transit_spectrum(mmw=2.33, clouds=True)
assert sum(spectra) == 45331292591.13728
wave, spectra, _ = interface.calc_transit_spectrum(mmw=2.33, clouds=True)
assert np.abs(sum(spectra) - 45491359621.84237) < 1000000

# test transit calculation
interface.chem_from_poorman("T", co_ratio=0.55, feh_ratio=0.0)
wave, spectra = interface.calc_transit_spectrum(
wave, spectra, _ = interface.calc_transit_spectrum(
mmw=2.33, clouds=True, use_bruggemann=True
)
assert sum(spectra) == 45237213620.53514
assert np.abs(sum(spectra) - 45398350666.008026) < 1000000

# Test if Pa works
interface.dsi.attrs["p_unit"] = "Pa"
Expand Down
151 changes: 119 additions & 32 deletions gcm_toolkit/utils/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,44 +238,60 @@ def _set_data_common(self, time, tag=None, regrid_lowres=False,
dsi = self.tools.get_one_model(tag).sel(time=time)

if terminator_avg:
# avarage over longitudinal opening angle. Note, this step also corrects
# for smaller areas at polar regions
ds_m = dsi.where((dsi[c['lon']] > -90 - lon_resolution/2) *
(dsi[c['lon']] < -90 + lon_resolution/2)).mean(c['lon'])
ds_e = dsi.where((dsi[c['lon']] > 90 - lon_resolution/2) *
(dsi[c['lon']] < 90 + lon_resolution/2)).mean(c['lon'])
# To perform a proper avareging over the terminator region, the data is
# sampled and regridded.

# set latitude step size
lat_step = 180 / lat_points
lat_x = np.linspace(-90+lat_step/2, 90-lat_step/2, lat_points)
# set latitude step size and latitude centers
d_lat = 180 / lat_points
c_lat = np.linspace(-90+d_lat/2, 90-d_lat/2, lat_points)

# read out values from x array to normal array input
coords = np.meshgrid(dsi[c['lon']], dsi[c['lat']])
coords = np.asarray([coords[0].flatten(), coords[1].flatten()])

# read out data that needs to be avaraged
data = []
names = []
data_nr = 0

for key in dsi.keys():
if key in ['T', 'ClAb', 'ClDs', 'ClDr'] or 'ClVf' in key:
names.append(key)
for h, zco in enumerate(dsi[c['Z']]):
data.append(dsi[key].sel(Z=zco).values.flatten())
data_nr += 1

# convert data to array
data = np.asarray(data)

# generate new dataset
ds_transit = xr.Dataset(
data_vars={},
coords={
'lat': ([c["lat"]], lat_x),
'lat': ([c["lat"]], c_lat),
'lon': ([c["lon"]], [-90, 90]),
'Z_l': ([c["Z_l"]], dsi[c['Z_l']].values),
'Z': ([c["Z"]], dsi[c['Z']].values)
},
attrs=dsi.attrs
)

# avarage in latitude space
for key in ds_e.keys():
if key in ['T', 'ClAb', 'ClDs', 'ClDr'] or 'ClVf' in key:
# empty array to fill with data
tmp = np.zeros((lat_points, 2, len(dsi[c["Z"]].values)))
for lat in range(lat_points):
tmp[lat, 0, :] = ds_m.where((ds_m[c['lat']] > lat*lat_step - 90) *
(ds_m[c['lat']] < (lat+1) * lat_step - 90)
).mean(c['lat'])[key].values
tmp[lat, 1, :] = ds_e.where((ds_e[c['lat']] > lat*lat_step - 90) *
(ds_e[c['lat']] < (lat+1)*lat_step - 90)
).mean(c['lat'])[key].values

# saving results
ds_transit[key] = ((c['lat'], c['lon'], c['Z_l']), tmp)
# loop over both limbs seperatly
tmp = np.zeros((lat_points, 2, data_nr))
for i, limb in enumerate([-90., 90.]):
# loop over each latitude point
for j, lp in enumerate(c_lat):
tmp[j, i, :] = self._terminator_slice_interpolator(
coords, data, lp - d_lat/2, lp + d_lat/2,
lon_resolution, limb)

# reshape output data and add it to the new dataset
index = 0
for key in names:
fill = np.zeros((lat_points, 2, len(dsi[c["Z"]].values)))
for h, _ in enumerate(dsi[c['Z']]):
fill[:, :, h] = tmp[:, :, index]
index += 1
ds_transit[key] = ((c['lat'], c['lon'], c['Z']), fill)

# add mark that data is ready for tranist callcuations
ds_transit.attrs['transit'] = True
Expand Down Expand Up @@ -324,6 +340,80 @@ def chem_from_poorman(self, temp_key="T", co_ratio=0.55, feh_ratio=0.0):
temp_key=temp_key, co_ratio=co_ratio, feh_ratio=feh_ratio
)

def _terminator_slice_interpolator(self, coord, data, lat_min, lat_max, opening_angle, terminator):
"""
Change to a terminator focused coordinate system for unbiased interpolation

Parameters
----------
coord : np.ndarray(2, n) [phi, theta]
coordinates of the data points known
data : np.ndarray(m, n)
m different data values for each coordinate m
lat_min : float [degree]
minimum latitude
lat_max : float [degree]
maximum latitude
opening_angle : float [degree]
Equatorial opening angle at the terminator region
terminator: float [degree]
select which terminator, use 90 for evening and -90 for morning
"""

import scipy.interpolate as ip

# convert to rad
cd = coord*np.pi/180
lmin = lat_min*np.pi/180
lmax = lat_max*np.pi/180
oa = opening_angle*np.pi/180

# check input
if oa > np.pi/6:
print('[WARN] The terminator slice interpolater assumes small opening angle. ' +
'Large opening angle lead to an under representation of data points ' +
'further away from the terminator.')
if np.abs(lmin) > np.pi/2:
raise ValueError('[EROR] lat_min needs to be within pi/2 < lat_min < pi/2')
if np.abs(lmax) > np.pi/2:
raise ValueError('[EROR] lat_max needs to be within pi/2 < lat_min < pi/2')
if lmax < lmin:
raise ValueError('[EROR] lat_max needs to be larger than lat_min')
if oa < 0:
raise ValueError('[EROR] opening_angle must be larger than 0')

# calculate theta and phi in terminator coordinate system
theta_t = np.arcsin(np.cos(cd[1])*np.cos(cd[0]))
phi_t = np.arcsin(np.sin(cd[1])/np.cos(theta_t))

# only points within the opening angle
mask = np.abs(theta_t) < oa

# only points for the given latitude range
mask *= lmin < phi_t
mask *= lmax > phi_t

# since lat_min and lat_max are directly transformed, the terminator is
# ambigouse. Here decide which terminator to pick.
mask *= terminator/cd[0] > 0

# select data and coordinates for given input
d_res = data[:, mask].T
d_gri = np.asarray([theta_t[mask], phi_t[mask]]).T

# define uniform grid to take the avarege of the physical values
grid_y, grid_x = np.meshgrid(np.linspace(lmin, lmax, 100),
np.linspace(-oa, oa, 100))

# interpolate data to new grid
out = np.zeros((len(data[:, 0]),))
for d in range(len(d_res[0])):
grid = ip.griddata(d_gri, d_res[:, d], (grid_x, grid_y), method='cubic')
out[d] = np.nanmean(grid)

# return avarage value
return out


class PrtInterface(Interface):
"""
Expand Down Expand Up @@ -673,8 +763,6 @@ def calc_transit_spectrum(
else:
wrt.write_status("INFO", "LLL mixing used")



# check if clouds are wished
do_clouds = False
if clouds is not None:
Expand Down Expand Up @@ -709,7 +797,7 @@ def calc_transit_spectrum(

# calcualte wavelengths in micron
wavelengths = 29979245800.0/self.prt.freq/1e-4

# calcualte spectra either for each limb seperatly or together
if not asymmetric:
spectra = (np.asarray(spectra_list))**2/len(np.asarray(spectra_list))
Expand All @@ -721,12 +809,11 @@ def calc_transit_spectrum(
spectra[0, :] += (np.asarray(spec))**2/len(np.asarray(spectra_list))*2
else:
spectra[1, :] += (np.asarray(spec))**2/len(np.asarray(spectra_list))*2

spectra = np.sqrt(spectra)

spectra = np.sqrt(spectra)

# return the final spectra
return wavelengths, spectra
return wavelengths, spectra, output_list

def _get_1_transit_spectra(self, lat, lon, mass_frac, gravity, mmw, rplanet,
pressure_0, do_clouds, clouds, use_bruggemann):
Expand Down