Skip to content

Commit

Permalink
derivatives: fixes for scattering fitting
Browse files Browse the repository at this point in the history
  • Loading branch information
kmjc committed Jul 16, 2024
1 parent 57a86a8 commit b8c550b
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions fitburst/routines/derivative.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,6 +733,11 @@ def deriv_model_wrt_scattering_index(parameters: dict, model: float, component:
num_freq, num_time, num_component = model.timediff_per_component.shape
deriv_mod_int = np.zeros((num_freq, num_time, num_component), dtype=float)

# KC add: define parameters and objects needed for mixed-derivative calculation.
burst_width = parameters["burst_width"][component]
spectral_index = parameters["spectral_index"][component]
spectral_running = parameters["spectral_running"][component]

# now loop over each model component and compute contribution to derivative.
for current_component in range(num_component):
current_amplitude = model.amplitude_per_component[:, :, current_component]
Expand All @@ -744,8 +749,8 @@ def deriv_model_wrt_scattering_index(parameters: dict, model: float, component:
scattering_timescale = parameters["scattering_timescale"][0]
current_timediff = model.timediff_per_component[:, :, current_component]
scat_times_freq = scattering_timescale * freq_ratio ** scattering_index
spectrum = 10 ** amplitude * freq_ratio ** (spectral_index + spectral_running * log_freq)
spectrum *= freq_ratio ** (-scattering_index)
# spectrum = 10 ** current_amplitude * freq_ratio ** (spectral_index + spectral_running * log_freq)
# spectrum *= freq_ratio ** (-scattering_index)

# define argument of error and scattering timescales over frequency.
spectrum = current_amplitude * freq_ratio[:, None] ** (-scattering_index)
Expand Down Expand Up @@ -1649,7 +1654,7 @@ def deriv2_model_wrt_burst_width_scattering_index(parameters: dict, model: float
# now loop over each frequency and compute mixed-derivative array per channel.
for freq in range(current_model.shape[0]):
freq_ratio = model.freqs[freq] / ref_freq
log_freq = np.log_freq_ratio
log_freq = np.log(freq_ratio)
sc_time_freq = sc_time * freq_ratio ** sc_index
spectrum = 10 ** amplitude * freq_ratio ** (spectral_index + spectral_running * np.log(freq_ratio))
spectrum *= freq_ratio ** (-sc_index)
Expand Down

0 comments on commit b8c550b

Please sign in to comment.