Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Some improvements #1

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions lib/pickling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import bz2
import pickle
import _pickle as cPickle
import sys


def compress_pickle(filename, data):
    """Pickle an object and write it, bz2-compressed, to a file.

    A progress message naming the file is written to stderr first.

    :param filename: path of the file to create or overwrite; it is used
                     exactly as given, no extension is appended
    :param data: any picklable object
    """
    sys.stderr.write(f"Saving pickle {filename}...\n")
    # Binary write mode spelled out to mirror the 'rb' used when loading.
    with bz2.BZ2File(filename, 'wb') as f:
        # The newest protocol serialises large buffers (e.g. numeric arrays)
        # more compactly and faster than the default; loading auto-detects
        # the protocol, so readers need no change.
        cPickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)

def decompress_pickle(filename):
    """Load an object from a bz2-compressed pickle file.

    A progress message naming the file is written to stderr first.

    :param filename: path of a file holding a pickle inside a bz2 stream
    :returns: the unpickled object
    """
    sys.stderr.write(f"Loading pickle {filename}...\n")
    # NOTE: unpickling can execute arbitrary code; only load files that this
    # project wrote itself, never files from an untrusted source.
    # The context manager closes the file handle even if unpickling fails.
    with bz2.BZ2File(filename, 'rb') as f:
        return cPickle.load(f)
44 changes: 29 additions & 15 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@
import matplotlib.pyplot as plt

from scipy import signal
from tqdm import tqdm

from lib.pickling import *
from os.path import exists

def load_wav(fpath):
"""Loads a .wav file and returns the data and sample rate.
Expand All @@ -22,7 +25,7 @@ def load_wav(fpath):
fs = wav_f.getframerate()

clip_len_s = len(data) / fs
print(f"Loaded .wav file, n_samples={len(data)} len_s={clip_len_s}")
print(f"Loaded {fpath} file, n_samples={len(data)} len_s={clip_len_s}")

return (data, fs)

Expand Down Expand Up @@ -97,7 +100,7 @@ def quadratic_interpolation(data, max_idx, bin_size):
bin_size = f[1] - f[0]

max_freqs = []
for spectrum in np.abs(np.transpose(Zxx)):
for spectrum in tqdm(np.abs(np.transpose(Zxx))):
max_amp = np.amax(spectrum)
max_freq_idx = np.where(spectrum == max_amp)[0][0]

Expand Down Expand Up @@ -139,9 +142,7 @@ def sorted_pmccs(target, references):
:returns: list of tuples of (reference index, PMCC), sorted desc by PMCC
"""
pmccs = [pmcc(target, r) for r in references]
sorted_pmccs = [(idx, v) for idx, v in sorted(enumerate(pmccs), key=lambda item: -item[1])]

return sorted_pmccs
return list(sorted(enumerate(pmccs), key=lambda item: -item[1]))


def search(target_enf, reference_enf):
Expand All @@ -152,10 +153,9 @@ def search(target_enf, reference_enf):
:returns: list of tuples of (reference index, PMCC), sorted desc by PMCC
"""
n_steps = len(reference_enf) - len(target_enf)
reference_enfs = (reference_enf[step:step+len(target_enf)] for step in range(n_steps))
reference_enfs = (reference_enf[step:step+len(target_enf)] for step in tqdm(range(n_steps)))

coeffs = sorted_pmccs(target_enf, reference_enfs)
return coeffs
return sorted_pmccs(target_enf, reference_enfs)


def gb_reference_data(year, month, day=None):
Expand Down Expand Up @@ -214,21 +214,35 @@ def plot_series_ax(ax, series, label=None):
ax.plot(t, series, label=label)


def wav_to_enf(filename, nominal_freq, freq_band_size, harmonic_n=1):
    """Return the ENF analysis of a .wav file, cached on disk.

    The result of `enf_series` is stored as a compressed pickle in a hidden
    file next to the .wav file. If that cache file already exists it is
    loaded and returned; otherwise the .wav file is loaded, processed, and
    the result is saved before being returned. This avoids repeating the
    loading and computation on every run.

    :param filename: path to the .wav file
    :param nominal_freq: forwarded unchanged to `enf_series`
    :param freq_band_size: forwarded unchanged to `enf_series`
    :param harmonic_n: forwarded unchanged to `enf_series`
    :returns: the value produced by `enf_series`, freshly computed or
              loaded from the cache
    """
    import os

    # Split the path first: prefixing the whole path with "." would turn
    # "./a.wav" into "../a.wav.pkl" and drop the cache in the parent directory.
    directory, base = os.path.split(filename)
    # The analysis parameters are part of the cache name so a result computed
    # with one setting is never returned for a call with another.
    cache_name = f".{base}.{nominal_freq}_{freq_band_size}_{harmonic_n}.pkl"
    pklfilename = os.path.join(directory, cache_name)

    if os.path.exists(pklfilename):
        return decompress_pickle(pklfilename)

    wav_data, wav_fs = load_wav(filename)
    enf = enf_series(wav_data, wav_fs, nominal_freq, freq_band_size, harmonic_n)
    compress_pickle(pklfilename, enf)
    return enf

if __name__ == "__main__":
nominal_freq = 50
freq_band_size = 0.2

# !!!: make sure to run ./bin/download-example-files first
ref_data, ref_fs = load_wav("./001_ref.wav")
refwav = '001_ref.wav'
wav = "001.wav"

ref_enf_output = enf_series(ref_data, ref_fs, nominal_freq, freq_band_size)
# !!!: make sure to run ./bin/download-example-files first
ref_enf_output = wav_to_enf(refwav, nominal_freq, freq_band_size, harmonic_n=1)
ref_enf = ref_enf_output['enf']

# !!!: make sure to run ./bin/download-example-files first
data, fs = load_wav("./001.wav")

harmonic_n = 1
enf_output = enf_series(data, fs, nominal_freq, freq_band_size, harmonic_n=2)
enf_output = wav_to_enf(wav, nominal_freq, freq_band_size, harmonic_n=2)
target_enf = enf_output['enf']

stft = enf_output['stft']
Expand All @@ -237,7 +251,7 @@ def plot_series_ax(ax, series, label=None):
Zxx = stft['Zxx']

pmccs = search(target_enf, ref_enf)
print(pmccs[0:100])
print(pmccs[:100])
predicted_ts = pmccs[0][0]
print(f"Best predicted timestamp is {predicted_ts}")
# True value provided by creator of example file
Expand Down