diff --git a/lib/pickling.py b/lib/pickling.py new file mode 100644 index 0000000..2927422 --- /dev/null +++ b/lib/pickling.py @@ -0,0 +1,18 @@ +import bz2 +import pickle +import _pickle as cPickle +import sys + + +# Pickle a file and then compress it into a file with extension +def compress_pickle(filename, data): + sys.stderr.write(f"Saving pickle {filename}...\n") + with bz2.BZ2File(filename, 'w') as f: + cPickle.dump(data, f) + +# Load any compressed pickle file +def decompress_pickle(filename): + sys.stderr.write(f"Loading pickle {filename}...\n") + data = bz2.BZ2File(filename, 'rb') + data = cPickle.load(data) + return data diff --git a/main.py b/main.py index cd18915..6a89bf0 100644 --- a/main.py +++ b/main.py @@ -8,7 +8,10 @@ import matplotlib.pyplot as plt from scipy import signal +from tqdm import tqdm +from lib.pickling import * +from os.path import exists def load_wav(fpath): """Loads a .wav file and returns the data and sample rate. @@ -22,7 +25,7 @@ def load_wav(fpath): fs = wav_f.getframerate() clip_len_s = len(data) / fs - print(f"Loaded .wav file, n_samples={len(data)} len_s={clip_len_s}") + print(f"Loaded {fpath} file, n_samples={len(data)} len_s={clip_len_s}") return (data, fs) @@ -97,7 +100,7 @@ def quadratic_interpolation(data, max_idx, bin_size): bin_size = f[1] - f[0] max_freqs = [] - for spectrum in np.abs(np.transpose(Zxx)): + for spectrum in tqdm(np.abs(np.transpose(Zxx))): max_amp = np.amax(spectrum) max_freq_idx = np.where(spectrum == max_amp)[0][0] @@ -139,9 +142,7 @@ def sorted_pmccs(target, references): :returns: list of tuples of (reference index, PMCC), sorted desc by PMCC """ pmccs = [pmcc(target, r) for r in references] - sorted_pmccs = [(idx, v) for idx, v in sorted(enumerate(pmccs), key=lambda item: -item[1])] - - return sorted_pmccs + return list(sorted(enumerate(pmccs), key=lambda item: -item[1])) def search(target_enf, reference_enf): @@ -152,10 +153,9 @@ def search(target_enf, reference_enf): :returns: list of tuples of (reference index, PMCC), sorted desc by PMCC """ n_steps = len(reference_enf) - len(target_enf) - reference_enfs = (reference_enf[step:step+len(target_enf)] for step in range(n_steps)) + reference_enfs = (reference_enf[step:step+len(target_enf)] for step in tqdm(range(n_steps))) - coeffs = sorted_pmccs(target_enf, reference_enfs) - return coeffs + return sorted_pmccs(target_enf, reference_enfs) def gb_reference_data(year, month, day=None): @@ -214,21 +214,35 @@ def plot_series_ax(ax, series, label=None): ax.plot(t, series, label=label) +def wav_to_enf(filename, nominal_freq, freq_band_size, harmonic_n=1): + """ + This code attempts to load a pickle file that contains the enf samples. + If the pickle file does not exist, the code loads the corresponding wav file, + processes it, and saves the enf samples in a new pickle file. + This approach prevents unnecessary loading and computing. + """ + pklfilename = f".{filename}.pkl" + if exists(pklfilename): + return decompress_pickle(pklfilename) + ref_data, ref_fs = load_wav(filename) + enf = enf_series(ref_data, ref_fs, nominal_freq, freq_band_size, harmonic_n) + compress_pickle(pklfilename, enf) + return enf + if __name__ == "__main__": nominal_freq = 50 freq_band_size = 0.2 - # !!!: make sure to run ./bin/download-example-files first - ref_data, ref_fs = load_wav("./001_ref.wav") + refwav = '001_ref.wav' + wav = "001.wav" - ref_enf_output = enf_series(ref_data, ref_fs, nominal_freq, freq_band_size) + # !!!: make sure to run ./bin/download-example-files first + ref_enf_output = wav_to_enf(refwav, nominal_freq, freq_band_size, harmonic_n=1) ref_enf = ref_enf_output['enf'] # !!!: make sure to run ./bin/download-example-files first - data, fs = load_wav("./001.wav") - harmonic_n = 1 - enf_output = enf_series(data, fs, nominal_freq, freq_band_size, harmonic_n=2) + enf_output = wav_to_enf(wav, nominal_freq, freq_band_size, harmonic_n=2) target_enf = enf_output['enf'] stft = enf_output['stft'] @@ -237,7 +251,7 @@ def plot_series_ax(ax, series, label=None): Zxx = stft['Zxx'] pmccs = search(target_enf, ref_enf) - print(pmccs[0:100]) + print(pmccs[:100]) predicted_ts = pmccs[0][0] print(f"Best predicted timestamp is {predicted_ts}") # True value provided by creator of example file