Skip to content

Commit

Permalink
Merge pull request #75 from CHIMEFRB/chimefrb-updates
Browse files Browse the repository at this point in the history
Chimefrb updates
  • Loading branch information
emmanuelfonseca authored Oct 2, 2023
2 parents 436ba34 + 50e1b1e commit e778cf1
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 62 deletions.
5 changes: 5 additions & 0 deletions fitburst/analysis/fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ def compute_hessian(self, data: float, parameter_list: list) -> float:
least_squares() for exact calculation of the Jacobian in terms of derivatives.
"""

print("INFO: computing hessian matrix with best-fit parameters.")

# load all parameter values into a dictionary.
parameter_dict = self.model.get_parameters_dict()

Expand Down Expand Up @@ -463,6 +465,9 @@ def _compute_fit_statistics(self, spectrum_observed: float, fit_result: object)
self.covariance_approx = covariance_approx
self.covariance = covariance
self.covariance_labels = par_labels
self.hessian_approx = hessian_approx
self.hessian = hessian

self.fit_statistics["bestfit_uncertainties"] = self.load_fit_parameters_list(
uncertainties)
self.fit_statistics["bestfit_covariance"] = None # return the full matrix at some point?
Expand Down
12 changes: 8 additions & 4 deletions fitburst/analysis/peak_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,17 @@ def find_peak(self, distance: int = 5):
self.peaks_greater_rms=self.peak_mean_intensities[index_peaks_greater_rms]
self.times_peaks_greater_rms=self.peak_times[index_peaks_greater_rms]



#Now find the width of each burst
for each_peak in index_peaks_greater_rms:
for each_peak in index_peaks_greater_rms[0]:

# These are the indices of times just below and just above the peaks.
below_peak_index,above_peak_index=each_peak-2,each_peak+2
below_peak_index, above_peak_index = each_peak - 1, each_peak + 1

if below_peak_index < 0:
below_peak_index = 0

if above_peak_index >= len(self.peak_times):
above_peak_index = each_peak

# Now find the times for above and below peak
below_peak_time=self.peak_times[below_peak_index]*1000
Expand Down
78 changes: 47 additions & 31 deletions fitburst/pipelines/fitburst_example_chimefrb.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,17 @@
#! /usr/bin/env python

### import and configure logger to only report warnings or worse for non-fitburst packages.
import datetime
import logging
right_now = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
logging.basicConfig(filename=f"fitburst_run_{right_now}.log", level=logging.DEBUG)
logging.getLogger('cfod').setLevel(logging.WARNING)
logging.getLogger('chime_frb_api').setLevel(logging.WARNING)
logging.getLogger('matplotlib').setLevel(logging.WARNING)
logging.getLogger('numpy').setLevel(logging.WARNING)
logging.getLogger('urllib3').setLevel(logging.WARNING)
log = logging.getLogger("fitburst")

from fitburst.analysis.peak_finder import FindPeak
from fitburst.analysis.fitter import LSFitter
import fitburst.backend.chimefrb as chimefrb
Expand All @@ -19,18 +31,11 @@
matplotlib.use("Agg")
import matplotlib.pyplot as plt

### import and configure logger.
import datetime
import logging
right_now = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
logging.basicConfig(filename=f"fitburst_run_{right_now}.log", level=logging.DEBUG)
log = logging.getLogger("fitburst")

### import and configure argparse.
import argparse

parser = argparse.ArgumentParser(description=
"A Python3 script that uses fitburst API to read, preprocess, window, and fit CHIME/FRB data " +
"A Python3 script that uses fitburst API to read, pre-process, window, and fit CHIME/FRB data " +
"against a model of the dynamic spectrum."
)

Expand Down Expand Up @@ -229,16 +234,18 @@
for current_fit_parameter in parameters_to_fit:
if current_fit_parameter in parameters_to_fix:
parameters_to_fix.remove(current_fit_parameter)
log.info(f"the parameter '{current_fit_parameter}' is now a fit parameter")

# loop over all CHIME/FRB events supplied at command line.
for current_event_id in eventIDs:
log.info(f"now preparing to fit spectrum for {current_event_id}")

# grab initial parameters to check if pipeline-specific parameters exist.
try:
data = chimefrb.DataReader(current_event_id, beam_id=beam)

except:
log.error(f"ERROR: {current_event_id} fails, moving on to next event...")
log.error(f"ERROR: {current_event_id} fails at DB-parsing stage, moving on to next event...")
continue

initial_parameters = data.get_parameters(pipeline=pipeline)
Expand Down Expand Up @@ -285,7 +292,7 @@
log.warning(f"window size not found in file '{latest_solution_file}'")


log.info("window size adjusted to +/- {0:.1f} ms".format(window * 1e3))
log.info(f"window size for {current_event_id} adjusted to +/- {0:.1f} ms, from input JSON data".format(window * 1e3))

else:
pass
Expand Down Expand Up @@ -333,6 +340,10 @@
variance_weight=variance_weight
)

# if desired, downsample data prior to extraction.
data.downsample(factor_freq_downsample, factor_time_upsample)
log.info(f"downsampled raw data by factors of (ds_freq, ds_time) = ({factor_freq_downsample}, {factor_time_downsample})")

# if the number of RFI-flagged channels is "too large", skip this event altogether.
num_bad_freq = data.num_freq - np.sum(data.good_freq)

Expand All @@ -343,11 +354,11 @@
continue

# now compute dedisperse matrix for data, given initial DM, and grab windowed data.
print("INFO: computing dedispersion-index matrix")
print("INFO: dedispersing over freq range ({0:.3f}, {1:.3f}) MHz".format(
np.min(data.freqs), np.max(data.freqs)
)
)
freq_min = min(data.freqs)
freq_max = max(data.freqs)

log.info(f"computing dedispersion-index matrix for {current_event_id}")
log.info(f"dedispersing data for {current_event_id} over freq range ({freq_min}, {freq_max}) MHz")
params = initial_parameters.copy()#data.burst_parameters["fitburst"]["round_2"]
data.dedisperse(
params["dm"][0],
Expand All @@ -357,12 +368,15 @@

# before windowing, check that the window does not extend beyond the end of the data set.
# if it does, shrink the window to just under the maximum size that still fits.
print("INFO: window size set to +/- {0:.1f} ms".format(window * 1e3))
window_max = data.times[-1] - initial_parameters["arrival_time"][-1]
window_max = data.times[-1] - np.mean(initial_parameters["arrival_time"])

if window > window_max:
window = window_max - 0.005
print("INFO: window size adjusted to +/- {0:.1f} ms".format(window * 1e3))
log.warning(f"window size for {current_event_id} adjusted to +/- {window * 1e3} ms")

if window_max < 0.:
log.error(f"{current_event_id} has a negative widnow size, initial guess for TOA is too far off...")
continue

data_windowed, times_windowed = data.window_data(np.mean(params["arrival_time"]), window=window)

Expand All @@ -374,25 +388,23 @@
data.good_freq[current_chan] = False
weird_chan += 1

log.warning(f"WARNING: there are {weird_chan} weird channels")
plt.pcolormesh(rt.manipulate.downsample_2d(data_windowed * data.good_freq[:, None], 64, 1))
plt.savefig("test.png")
#sys.exit()
if weird_chan > 0:
log.warning(f"WARNING: there are {weird_chan} weird channels")

#plt.pcolormesh(rt.manipulate.downsample_2d(data_windowed * data.good_freq[:, None], 64, 1))
#plt.savefig("test.png")

# before defining model, adjust model parameters with peak-finding algorithm.
if peakfind_rms is not None:
print("INFO: running FindPeak to isolate burst components...")
log.info(f"running FindPeak on {current_event_id} to isolate burst components...")
peaks = FindPeak(data_windowed, times_windowed, data.freqs, rms=peakfind_rms)
peaks.find_peak(distance=peakfind_dist)
initial_parameters = peaks.get_parameters_dict(initial_parameters)

#print(initial_parameters)
#sys.exit()

# now create initial model.
# since CHIME/FRB data are in msgpack format, define a few things
# so that this version of fitburst works similarly to the original version on site.
print("INFO: initializing model")
log.info(f"initializing spectrum model for {current_event_id}")
num_components = len(initial_parameters["amplitude"])
initial_parameters["dm"] = [0.] * num_components

Expand Down Expand Up @@ -434,20 +446,25 @@
if not no_fit:

for current_iteration in range(num_iterations):
print(f"INFO: fitting model, loop #{current_iteration + 1}")
log.info(f"fitting model for {current_event_id}, loop #{current_iteration + 1}")
fitter = LSFitter(data_windowed, model, good_freq=data.good_freq, weighted_fit=True)
fitter.fix_parameter(parameters_to_fix)
start = time.time()
fitter.fit(exact_jacobian=True)

# if the fit succeeded, overload model class with best-fit parameters before the next loop iteration.
if fitter.results.success:
stop = time.time()
log.info(f"LSFitter.fit() took {stop - start} seconds to run.")
model.update_parameters(fitter.fit_statistics["bestfit_parameters"])
bestfit_model = model.compute_model(data=data_windowed) * data.good_freq[:, None]
bestfit_params = model.get_parameters_dict()
bestfit_params["dm"] = [params["dm"][0] + x for x in bestfit_params["dm"] * model.num_components]
bestfit_residuals = data_windowed - bestfit_model
fit_is_successful = True
fit_statistics = fitter.fit_statistics
plt.pcolormesh(bestfit_model)
plt.savefig("test2.png")

# TODO: for now, stash covariance data for offline comparison; remove at some point.
np.savez(
Expand All @@ -464,11 +481,10 @@
### now compute best-fit model of spectrum and plot.
if fit_is_successful or no_fit:

# create summary plot.
# create summary plot using original data.
data_grouped = ut.plotting.compute_downsampled_data(
times_windowed, data.freqs, data_windowed, data.good_freq,
spectrum_model = bestfit_model, factor_freq = factor_freq_downsample,
factor_time = factor_time_downsample
spectrum_model = bestfit_model, factor_freq = int(64 / factor_freq_downsample), factor_time = 1
)

ut.plotting.plot_summary_triptych(
Expand Down
29 changes: 24 additions & 5 deletions fitburst/pipelines/fitburst_example_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,14 @@
help="If set, the run preprocessing return to normalize data and mask bad frequencies."
)

parser.add_argument(
"--ref_freq",
action="store",
default=None,
type=float,
help="If set, then replace reference frequency with command-line value."
)

parser.add_argument(
"--remove_smearing",
action="store_true",
Expand Down Expand Up @@ -295,6 +303,7 @@
peakfind_rms = args.peakfind_rms
peakfind_dist = args.peakfind_dist
preprocess_data = args.preprocess_data
ref_freq = args.ref_freq
remove_dispersion_smearing = args.remove_dispersion_smearing
use_outfile_substring = args.use_outfile_substring
scattering_timescale = args.scattering_timescale
Expand Down Expand Up @@ -339,6 +348,8 @@

# load spectrum data into memory and pre-process it, then load in parameter data.
data.load_data()
print(f"INFO: there are {data.num_freq} frequencies and {data.num_time} time samples.")

data.good_freq = np.sum(data.data_weights, axis=1) != 0.
data.good_freq = np.sum(data.data_full, axis=1) != 0.

Expand Down Expand Up @@ -429,6 +440,10 @@
if width is not None:
current_parameters["burst_width"] = width

# now replace ref_freq value, if desired.
if ref_freq is not None:
current_parameters["ref_freq"] = [ref_freq] * num_components

# print parameter info if desired.
if verbose:
print(f"INFO: initial guess for {len(current_parameters['dm'])}-component model:")
Expand Down Expand Up @@ -482,9 +497,9 @@

# now set up fitter and execute least-squares fitting
for current_iteration in range(num_iterations):
fitter = LSFitter(data_windowed, model, data.good_freq)
fitter = LSFitter(data_windowed, model, data.good_freq, weighted_fit=True)
fitter.fix_parameter(parameters_to_fix)
fitter.fit()
fitter.fit(exact_jacobian=True)

print(fitter.results)

Expand All @@ -493,9 +508,13 @@
bestfit_params = fitter.fit_statistics["bestfit_parameters"]
model.update_parameters(bestfit_params)
current_params = model.get_parameters_dict()
current_params["dm"] = [x for x in bestfit_params["dm"] * num_components]
current_params["scattering_timescale"] = [x for x in
bestfit_params["scattering_timescale"] * num_components]

if not any([x == "dm" for x in parameters_to_fix]):
current_params["dm"] = [x for x in bestfit_params["dm"] * num_components]

if "scattering_timescale" not in parameters_to_fix:
current_params["scattering_timescale"] = [x for x in
bestfit_params["scattering_timescale"] * num_components]

# if this is the last iteration, create best-fit model and plot windowed data.
if current_iteration == (num_iterations - 1):
Expand Down
Loading

0 comments on commit e778cf1

Please sign in to comment.