Skip to content

Commit

Permalink
multithreading acceleration
Browse files Browse the repository at this point in the history
  • Loading branch information
dj-nuo committed Jul 3, 2024
1 parent 14c5ab6 commit c59aedb
Showing 1 changed file with 43 additions and 29 deletions.
72 changes: 43 additions & 29 deletions helpers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datetime
import glob
import multiprocessing
import time
import uuid
import soundfile as sf
Expand Down Expand Up @@ -36,6 +37,26 @@ def get_run_guid():
return str(uuid.uuid4())


def process_file_pair(args):
index, original_file, calculated_file, stem_name, run_guid, run_datetime = args
sdr = _sdr(original_track=original_file,
calculated_track=calculated_file)

# persisting individual SDRs for audit
result_audit = {
"sdr": sdr,
"stem": stem_name,
"run_guid": run_guid,
"original_track": original_file,
"calculated_track": calculated_file,
"run_datetime": run_datetime
}

persist_result(result=result_audit, csv_file='results_audit.csv')

return index, sdr


def sdr_folder(glob_original, glob_calculated, stem_name, run_datetime=None, run_guid=None, persist_result=persist_result, title=None, description=None):
"""
Files are processed in alphabetical ascending order
Expand Down Expand Up @@ -68,38 +89,31 @@ def sdr_folder(glob_original, glob_calculated, stem_name, run_datetime=None, run
if len(files_original) == 0:
return

# progress bar just for eye candy
progress_bar = tqdm(total=len(files_original))
# Prepare the arguments for the pool
args_list = [(i, files_original[i], files_calculated[i], stem_name,
run_guid, run_datetime) for i in range(len(files_original))]

# calculating SDR
# multi-core processing
all_sdr = []
for i, original_file in enumerate(files_original):
calculated_file = files_calculated[i]
sdr = _sdr(original_track=original_file,
calculated_track=calculated_file)
all_sdr.append(sdr)

# persisting individual SDRs for audit
result_audit = {
"sdr": sdr,
"stem": stem_name,
"run_guid": run_guid,
"original_track": original_file,
"calculated_track": calculated_file,
"run_datetime": run_datetime
}

persist_result(result=result_audit, csv_file='results_audit.csv')
progress_bar.update(i - progress_bar.n)
num_cores = multiprocessing.cpu_count()
p = multiprocessing.Pool(processes=int(num_cores/2))
with tqdm(total=len(args_list)) as pbar:
track_iter = p.imap(process_file_pair, args_list)
for index, sdr in track_iter:
all_sdr.append(sdr)
pbar.update()
p.close()

all_sdr = np.array(all_sdr).mean()

result = {
"sdr": all_sdr,
"stem": stem_name,
"run_guid": run_guid,
"run_datetime": run_datetime,
"processing_time_sec": "{:.2f}".format(time.time() - start_time),
"title": title,
"description": description
}
# persisting total SDR for audit
result_audit['sdr'] = all_sdr
del result_audit['original_track']
del result_audit['calculated_track']
result_audit['processing_time_sec'] = "{:.2f}".format(
time.time() - start_time)
result_audit['title'] = title
result_audit['description'] = description
persist_result(result=result_audit, csv_file='results.csv')
persist_result(result=result, csv_file='results.csv')

0 comments on commit c59aedb

Please sign in to comment.