Skip to content

Commit

Permalink
df/scripts: tqdm for voicebank test
Browse files Browse the repository at this point in the history
  • Loading branch information
Rikorose committed Nov 24, 2021
1 parent f1e35eb commit ee01954
Showing 1 changed file with 47 additions and 13 deletions.
60 changes: 47 additions & 13 deletions DeepFilterNet/df/scripts/test_voicebank_demand.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import glob
import os
import tempfile
import sys

import numpy as np
import pystoi
Expand All @@ -20,6 +21,32 @@
from df.utils import as_complex, resample
from libdf import DF

try:
from tqdm import tqdm
except ImportError:

def tqdm(iterable, desc="Progress", total=None, fallback_estimate=1000):
# tqdm not available using fallback
try:
L = iterable.__len__()
except AttributeError:
e = ""
L = total or None

print("{}: {: >2d}".format(desc, 0), end="")
for k, i in enumerate(iterable):
yield i
if L is not None:
p = (k + 1) / L
e = "" if k < (L - 1) else "\n"
else:
# Use an exponentially decaying function
p = 1 - np.exp(-k / fallback_estimate)
print("\b\b\b\b {: >2d}%".format(int(100 * p)), end=e)
sys.stdout.flush()
if L is None:
print()


def main():
parser = argparse.ArgumentParser()
Expand All @@ -33,6 +60,7 @@ def main():
parser.add_argument("--pf", action="store_true")
parser.add_argument("--verbose", action="store_true")
parser.add_argument("--output-dir", "-o", type=str, default=None)
parser.add_argument("--disable-output", action="store_true", default=None)
args = parser.parse_args()
if not os.path.isdir(args.model_base_dir):
NotADirectoryError("Base directory not found at {}".format(args.model_base_dir))
Expand All @@ -59,19 +87,24 @@ def main():
assert os.path.isdir(args.dataset_dir)
noisy_dir = os.path.join(args.dataset_dir, "noisy_testset_wav")
clean_dir = os.path.join(args.dataset_dir, "clean_testset_wav")
if args.output_dir is not None:
enh_dir = args.output_dir
if args.disable_output:
enh_dir = None
else:
enh_dir = os.path.join(args.dataset_dir, "enhanced")
if args.output_dir is not None:
enh_dir = args.output_dir
else:
enh_dir = os.path.join(args.dataset_dir, "enhanced")
os.makedirs(enh_dir, exist_ok=True)
assert os.path.isdir(noisy_dir) and os.path.isdir(clean_dir)
os.makedirs(enh_dir, exist_ok=True)
enh_stoi = []
noisy_stoi = []
enh_sisdr = []
noisy_sisdr = []
enh_comp = []
noisy_comp = []
for noisyfn, cleanfn in zip(glob.iglob(noisy_dir + "/*wav"), glob.iglob(clean_dir + "/*wav")):
noisy_files = glob.glob(noisy_dir + "/*wav")
clean_files = glob.glob(clean_dir + "/*wav")
for noisyfn, cleanfn in tqdm(zip(noisy_files, clean_files), total=len(noisy_files)):
noisy, sr = torchaudio.load(noisyfn)
clean, sr = torchaudio.load(cleanfn)
if sr != p.sr:
Expand All @@ -90,14 +123,15 @@ def main():
if args.verbose:
print(cleanfn, enh_stoi[-1], enh_comp[-1], enh_sisdr[-1])
enh = torch.as_tensor(enh).to(torch.float32).view(1, -1)
save_audio(
os.path.basename(cleanfn),
enh,
p.sr,
output_dir=enh_dir,
suffix=f"{model_n}_{enh_comp[-1][0]:.3f}",
log=args.verbose,
)
if enh_dir is not None:
save_audio(
os.path.basename(cleanfn),
enh,
p.sr,
output_dir=enh_dir,
suffix=f"{model_n}_{enh_comp[-1][0]:.3f}",
log=args.verbose,
)
logger.info(f"noisy stoi: {np.mean(noisy_stoi)}")
logger.info(f"enhanced stoi: {np.mean(enh_stoi)}")
noisy_comp = np.stack(noisy_comp)
Expand Down

0 comments on commit ee01954

Please sign in to comment.