From c59aedbb883920b13b341f811d365c48dfa9e73b Mon Sep 17 00:00:00 2001
From: DJ NUO
Date: Wed, 3 Jul 2024 21:12:04 +0000
Subject: [PATCH] multithreading acceleration

---
 helpers.py | 72 ++++++++++++++++++++++++++++++++----------------------
 1 file changed, 43 insertions(+), 29 deletions(-)

diff --git a/helpers.py b/helpers.py
index 0646cf2..ed6e762 100644
--- a/helpers.py
+++ b/helpers.py
@@ -1,5 +1,6 @@
 import datetime
 import glob
+import multiprocessing
 import time
 import uuid
 import soundfile as sf
@@ -36,6 +37,26 @@ def get_run_guid():
     return str(uuid.uuid4())
 
 
+def process_file_pair(args):
+    index, original_file, calculated_file, stem_name, run_guid, run_datetime = args
+    sdr = _sdr(original_track=original_file,
+               calculated_track=calculated_file)
+
+    # persisting individual SDRs for audit
+    result_audit = {
+        "sdr": sdr,
+        "stem": stem_name,
+        "run_guid": run_guid,
+        "original_track": original_file,
+        "calculated_track": calculated_file,
+        "run_datetime": run_datetime
+    }
+
+    persist_result(result=result_audit, csv_file='results_audit.csv')
+
+    return index, sdr
+
+
 def sdr_folder(glob_original, glob_calculated, stem_name, run_datetime=None, run_guid=None, persist_result=persist_result, title=None, description=None):
     """
     Files are processed in alphabetical ascending order
@@ -68,38 +89,31 @@ def sdr_folder(glob_original, glob_calculated, stem_name, run_datetime=None, run
     if len(files_original) == 0:
         return
 
-    # progress bar just for eye candy
-    progress_bar = tqdm(total=len(files_original))
+    # Prepare the arguments for the pool
+    args_list = [(i, files_original[i], files_calculated[i], stem_name,
+                  run_guid, run_datetime) for i in range(len(files_original))]
 
-    # calculating SDR
+    # multi-core processing
     all_sdr = []
-    for i, original_file in enumerate(files_original):
-        calculated_file = files_calculated[i]
-        sdr = _sdr(original_track=original_file,
-                   calculated_track=calculated_file)
-        all_sdr.append(sdr)
-
-        # persisting individual SDRs for audit
-        result_audit = {
-            "sdr": sdr,
-            "stem": stem_name,
-            "run_guid": run_guid,
-            "original_track": original_file,
-            "calculated_track": calculated_file,
-            "run_datetime": run_datetime
-        }
-
-        persist_result(result=result_audit, csv_file='results_audit.csv')
-        progress_bar.update(i - progress_bar.n)
+    num_cores = multiprocessing.cpu_count()
+    p = multiprocessing.Pool(processes=int(num_cores/2))
+    with tqdm(total=len(args_list)) as pbar:
+        track_iter = p.imap(process_file_pair, args_list)
+        for index, sdr in track_iter:
+            all_sdr.append(sdr)
+            pbar.update()
+    p.close()
 
     all_sdr = np.array(all_sdr).mean()
+    result = {
+        "sdr": all_sdr,
+        "stem": stem_name,
+        "run_guid": run_guid,
+        "run_datetime": run_datetime,
+        "processing_time_sec": "{:.2f}".format(time.time() - start_time),
+        "title": title,
+        "description": description
+    }
     # persisting total SDR for audit
-    result_audit['sdr'] = all_sdr
-    del result_audit['original_track']
-    del result_audit['calculated_track']
-    result_audit['processing_time_sec'] = "{:.2f}".format(
-        time.time() - start_time)
-    result_audit['title'] = title
-    result_audit['description'] = description
-    persist_result(result=result_audit, csv_file='results.csv')
+    persist_result(result=result, csv_file='results.csv')
 
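
Note: the sketch below is a self-contained illustration of the Pool/imap/tqdm pattern introduced by this patch, not part of the patch itself. score_pair and the pairs list are hypothetical placeholders standing in for process_file_pair and args_list, and the max(1, ...) guard plus the with-block are assumptions layered on top of the patch's p = multiprocessing.Pool(...) / p.close() calls, since int(num_cores/2) evaluates to 0 on a single-core machine and the context manager guarantees the pool is torn down even if a worker raises.

import multiprocessing

from tqdm import tqdm


def score_pair(args):
    # Placeholder worker: the real process_file_pair computes SDR and persists
    # an audit row; here we only return (index, score) to show the data flow.
    index, original_file, calculated_file = args
    score = float(index)  # stand-in metric
    return index, score


if __name__ == "__main__":  # guard needed for the spawn start method (Windows/macOS)
    # Hypothetical file pairs standing in for args_list built from the globs.
    pairs = [(i, f"original_{i}.wav", f"calculated_{i}.wav") for i in range(8)]

    # Use half the cores as in the patch, but never fewer than one worker.
    workers = max(1, multiprocessing.cpu_count() // 2)

    scores = []
    # The with-block terminates the pool on exit; tqdm counts completed tasks.
    with multiprocessing.Pool(processes=workers) as pool, tqdm(total=len(pairs)) as pbar:
        # imap yields results in input order, so the returned index is only
        # needed if this is ever switched to imap_unordered.
        for index, score in pool.imap(score_pair, pairs):
            scores.append(score)
            pbar.update()

    print(sum(scores) / len(scores))

Because imap preserves submission order, the averaged result matches the sequential loop it replaces; imap_unordered would trade that ordering for slightly better throughput when file sizes vary.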