From 2087fde0399a4887628af9dde17c96bd95a9e7fe Mon Sep 17 00:00:00 2001 From: ggmarshall Date: Wed, 9 Oct 2024 15:21:55 +0200 Subject: [PATCH] first version aoe sf joint fitter --- src/pygama/pargen/AoE_cal.py | 244 +++++++++++++++++++++++++++-------- 1 file changed, 191 insertions(+), 53 deletions(-) diff --git a/src/pygama/pargen/AoE_cal.py b/src/pygama/pargen/AoE_cal.py index 49fa4d38c..af66db883 100644 --- a/src/pygama/pargen/AoE_cal.py +++ b/src/pygama/pargen/AoE_cal.py @@ -4,6 +4,7 @@ from __future__ import annotations +import copy import logging import re from datetime import datetime @@ -607,8 +608,92 @@ def get_peak_label(peak: float) -> str: return "Tl FEP @" +def pass_pdf_hpge( + x, + x_lo, + x_hi, + n_sig1, + n_sig2, + mu, + sigma, + htail, + tau, + n_bkg1, + n_bkg2, + hstep1, + hstep2, + hstep3, +): + return hpge_peak.pdf_ext( + x, x_lo, x_hi, n_sig1, mu, sigma, htail, tau, n_bkg1, hstep1 + ) + + +def fail_pdf_hpge( + x, + x_lo, + x_hi, + n_sig1, + n_sig2, + mu, + sigma, + htail, + tau, + n_bkg1, + n_bkg2, + hstep1, + hstep2, + hstep3, +): + return hpge_peak.pdf_ext( + x, x_lo, x_hi, n_sig2, mu, sigma, htail, tau, n_bkg2, hstep2 + ) + + +def tot_pdf_hpge( + x, + x_lo, + x_hi, + n_sig1, + n_sig2, + mu, + sigma, + htail, + tau, + n_bkg1, + n_bkg2, + hstep1, + hstep2, + hstep3, +): + return hpge_peak.pdf_ext( + x, x_lo, x_hi, n_sig1 + n_sig2, mu, sigma, htail, tau, n_bkg1 + n_bkg2, hstep3 + ) + + +def pass_pdf_gos( + x, x_lo, x_hi, n_sig1, n_sig2, mu, sigma, n_bkg1, n_bkg2, hstep1, hstep2, hstep3 +): + return gauss_on_step.pdf_ext(x, x_lo, x_hi, n_sig1, mu, sigma, n_bkg1, hstep1) + + +def fail_pdf_gos( + x, x_lo, x_hi, n_sig1, n_sig2, mu, sigma, n_bkg1, n_bkg2, hstep1, hstep2, hstep3 +): + return gauss_on_step.pdf_ext(x, x_lo, x_hi, n_sig2, mu, sigma, n_bkg2, hstep2) + + +def tot_pdf_gos( + x, x_lo, x_hi, n_sig1, n_sig2, mu, sigma, n_bkg1, n_bkg2, hstep1, hstep2, hstep3 +): + return gauss_on_step.pdf_ext( + x, x_lo, x_hi, n_sig1 + n_sig2, mu, sigma, n_bkg1 + n_bkg2, hstep3 + ) + + def update_guess(func, parguess, energies): - if func == gauss_on_step: + if func == gauss_on_step or func == hpge_peak: + total_events = len(energies) parguess["n_sig"] = len( energies[ @@ -616,15 +701,16 @@ def update_guess(func, parguess, energies): & (energies < parguess["mu"] + 2 * parguess["sigma"]) ] ) - parguess["n_bkg"] = total_events - parguess["n_sig"] - return parguess - - if func == hpge_peak: - total_events = len(energies) - parguess["n_sig"] = len( + parguess["n_sig"] -= len( energies[ - (energies > parguess["mu"] - 2 * parguess["sigma"]) - & (energies < parguess["mu"] + 2 * parguess["sigma"]) + (energies > parguess["x_lo"]) + & (energies < parguess["x_lo"] + 2 * parguess["sigma"]) + ] + ) + parguess["n_sig"] -= len( + energies[ + (energies > parguess["x_hi"] - 2 * parguess["sigma"]) + & (energies < parguess["x_hi"]) ] ) parguess["n_bkg"] = total_events - parguess["n_sig"] @@ -643,11 +729,11 @@ def get_survival_fraction( eres_pars, fit_range=None, high_cut=None, - guess_pars_cut=None, - guess_pars_surv=None, + pars=None, dt_mask=None, mode="greater", func=hpge_peak, + fix_step=False, display=0, ): if dt_mask is None: @@ -672,8 +758,8 @@ def get_survival_fraction( else: raise ValueError("mode not recognised") - if guess_pars_cut is None or guess_pars_surv is None: - (pars, errs, cov, _, func, _, _, _) = pgc.unbinned_staged_energy_fit( + if pars is None: + (pars, _, _, _, func, _, _, _) = pgc.unbinned_staged_energy_fit( energy, func, guess_func=energy_guess, @@ -682,40 +768,74 @@ def get_survival_fraction( fit_range=fit_range, ) - guess_pars_cut = pars - guess_pars_surv = pars + guess_pars_cut = copy.deepcopy(pars) + guess_pars_surv = copy.deepcopy(pars) + # add update guess here for n_sig and n_bkg guess_pars_cut = update_guess(func, guess_pars_cut, energy[(~nan_idxs) & (~idxs)]) - (cut_pars, cut_errs, cut_cov, _, _, _, _, _) = pgc.unbinned_staged_energy_fit( - energy[(~nan_idxs) & (~idxs)], - func, - guess=guess_pars_cut, - guess_func=energy_guess, - bounds_func=get_bounds, - fixed_func=fix_all_but_nevents, - guess_kwargs={"peak": peak, "eres": eres_pars}, - lock_guess=True, - allow_tail_drop=False, - fit_range=fit_range, - ) - guess_pars_surv = update_guess(func, guess_pars_cut, energy[(~nan_idxs) & (idxs)]) - (surv_pars, surv_errs, surv_cov, _, _, _, _, _) = pgc.unbinned_staged_energy_fit( - energy[(~nan_idxs) & (idxs)], - func, - guess=guess_pars_surv, - guess_func=energy_guess, - bounds_func=get_bounds, - fixed_func=fix_all_but_nevents, - guess_kwargs={"peak": peak, "eres": eres_pars}, - lock_guess=True, - allow_tail_drop=False, - fit_range=fit_range, - ) + guess_pars_surv = update_guess(func, guess_pars_surv, energy[(~nan_idxs) & (idxs)]) + + parguess = { + "x_lo": pars["x_lo"], + "x_hi": pars["x_hi"], + "mu": pars["mu"], + "sigma": pars["sigma"], + "n_sig1": guess_pars_surv["n_sig"], + "n_bkg1": guess_pars_surv["n_bkg"], + "n_sig2": guess_pars_cut["n_sig"], + "n_bkg2": guess_pars_cut["n_bkg"], + "hstep1": pars["hstep"], + "hstep2": pars["hstep"], + "hstep3": pars["hstep"], + } + + bounds = { + "n_sig1": (0, pars["n_sig"] + pars["n_bkg"]), + "n_sig2": (0, pars["n_sig"] + pars["n_bkg"]), + "n_bkg1": (0, pars["n_bkg"] + pars["n_sig"]), + "n_bkg2": (0, pars["n_bkg"] + pars["n_sig"]), + "hstep1": (-1, 1), + "hstep2": (-1, 1), + "hstep3": (-1, 1), + } + + if func == hpge_peak: + parguess.update({"htail": pars["htail"], "tau": pars["tau"]}) + + if func == hpge_peak: + lh = ( + cost.ExtendedUnbinnedNLL(energy[(~nan_idxs) & (idxs)], pass_pdf_hpge) + + cost.ExtendedUnbinnedNLL(energy[(~nan_idxs) & (~idxs)], fail_pdf_hpge) + + cost.ExtendedUnbinnedNLL(energy[(~nan_idxs)], tot_pdf_hpge) + ) + elif func == gauss_on_step: + lh = ( + cost.ExtendedUnbinnedNLL(energy[(~nan_idxs) & (idxs)], pass_pdf_gos) + + cost.ExtendedUnbinnedNLL(energy[(~nan_idxs) & (~idxs)], fail_pdf_gos) + + cost.ExtendedUnbinnedNLL(energy[(~nan_idxs)], tot_pdf_gos) + ) - ct_n = cut_pars["n_sig"] - ct_err = cut_errs["n_sig"] - surv_n = surv_pars["n_sig"] - surv_err = surv_errs["n_sig"] + else: + raise ValueError("Unknown func") + + m = Minuit(lh, **parguess) + fixed = ["x_lo", "x_hi", "mu", "sigma"] + if func == hpge_peak: + fixed += ["tau", "htail"] + if fix_step is True: + fixed += ["hstep1", "hstep2", "hstep3"] + + m.fixed[fixed] = True + for arg, val in bounds.items(): + m.limits[arg] = val + + m.simplex().migrad() + m.hesse() + + ct_n = m.values["n_sig2"] + ct_err = m.errors["n_sig2"] + surv_n = m.values["n_sig1"] + surv_err = m.errors["n_sig1"] pc_n = ct_n + surv_n @@ -723,7 +843,29 @@ def get_survival_fraction( err = 100 * sf * (1 - sf) * np.sqrt((ct_err / ct_n) ** 2 + (surv_err / surv_n) ** 2) sf *= 100 - return sf, err, cut_pars, surv_pars + if display > 1: + fig, (ax1, ax2, ax3) = plt.subplots(1, 3) + bins = np.arange(1552, 1612, 1) + ax1.hist(energy[(~nan_idxs) & (idxs)], bins=bins, histtype="step") + + ax2.hist(energy[(~nan_idxs) & (~idxs)], bins=bins, histtype="step") + + ax3.hist(energy[(~nan_idxs)], bins=bins, histtype="step") + + if func == hpge_peak: + ax1.plot(bins, pass_pdf_hpge(bins, **m.values.to_dict())[1]) + ax2.plot(bins, fail_pdf_hpge(bins, **m.values.to_dict())[1]) + + ax3.plot(bins, tot_pdf_hpge(bins, **m.values.to_dict())[1]) + elif func == gauss_on_step: + ax1.plot(bins, pass_pdf_gos(bins, **m.values.to_dict())[1]) + ax2.plot(bins, fail_pdf_gos(bins, **m.values.to_dict())[1]) + + ax3.plot(bins, tot_pdf_gos(bins, **m.values.to_dict())[1]) + + plt.show() + + return sf, err, m.values, m.errors def get_sf_sweep( @@ -754,7 +896,7 @@ def get_sf_sweep( cut_vals = np.linspace(cut_range[0], cut_range[1], n_samples) out_df = pd.DataFrame() - (pars, _, _, _, func, _, _, _) = pgc.unbinned_staged_energy_fit( + (pars, errs, _, _, func, _, _, _) = pgc.unbinned_staged_energy_fit( energy, hpge_peak, guess_func=energy_guess, @@ -762,8 +904,6 @@ def get_sf_sweep( guess_kwargs={"peak": peak, "eres": eres_pars}, fit_range=fit_range, ) - guess_pars_cut = pars - guess_pars_surv = pars for cut_val in cut_vals: try: @@ -776,8 +916,7 @@ def get_sf_sweep( fit_range=fit_range, dt_mask=dt_mask, mode=mode, - guess_pars_cut=guess_pars_cut, - guess_pars_surv=guess_pars_surv, + pars=pars, func=func, ) out_df = pd.concat( @@ -790,7 +929,7 @@ def get_sf_sweep( raise (e) out_df.set_index("cut_val", inplace=True) if final_cut_value is not None: - sf, sf_err, cut_pars, surv_pars = get_survival_fraction( + sf, sf_err, _, _ = get_survival_fraction( energy, cut_param, final_cut_value, @@ -799,8 +938,7 @@ def get_sf_sweep( fit_range=fit_range, dt_mask=dt_mask, mode=mode, - guess_pars_cut=guess_pars_cut, - guess_pars_surv=guess_pars_surv, + pars=pars, func=func, ) else: