forked from AIcrowd/music-demixing-challenge-starter-kit
-
Notifications
You must be signed in to change notification settings - Fork 1
/
test.py
65 lines (51 loc) · 2.5 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
#!/usr/bin/env python
# This file is the entrypoint for your submission.
# You can modify this file to include your code or directly call your functions/modules from here.
import shutil
import soundfile as sf
from evaluator.music_demixing import MusicDemixingPredictor
class CopyPredictor(MusicDemixingPredictor):
    """Trivial baseline that copies the input mixture file unchanged to every stem output path."""

    def prediction_setup(self):
        """Prepare resources needed by `prediction`.

        PARTICIPANT_TODO:
        You can do any preprocessing required for your codebase here like loading up models into memory, etc.
        This baseline needs no setup.
        """
        # Load your model here, e.g.:
        # self.separator = torch.hub.load('sigsep/open-unmix-pytorch', 'umxhq')
        pass

    def prediction(self, mixture_file_path, bass_file_path, drums_file_path, other_file_path, vocals_file_path):
        """Write one separated audio file per stem for a single mixture.

        PARTICIPANT_TODO:
        During the evaluation all music files will be provided one by one, along with destination paths
        for saving separated audios.
        NOTE: In case you want to load your model, please do so in the `prediction_setup` method.

        Args:
            mixture_file_path: path of the input mixture audio file.
            bass_file_path: destination path for the bass estimate.
            drums_file_path: destination path for the drums estimate.
            other_file_path: destination path for the "other" estimate.
            vocals_file_path: destination path for the vocals estimate.
        """
        print("Mixture file is present at following location: %s" % mixture_file_path)
        # Write your prediction code here:
        # [...]
        # estimates = separator(audio)
        # Save the wav files at assigned locations.
        # This baseline simply duplicates the mixture file (byte copy) for every stem.
        for stem_file_path in (bass_file_path, drums_file_path, other_file_path, vocals_file_path):
            shutil.copyfile(mixture_file_path, stem_file_path)
        print("%s: prediction completed." % mixture_file_path)
class ScaledMixturePredictor(MusicDemixingPredictor):
    """Lower baseline of using `1/4 * mixture` as prediction for bass, drums, other and vocals."""

    def prediction_setup(self):
        """Initialize predictor. This baseline has no model, so there is nothing to load."""
        pass

    def prediction(self, mixture_file_path, bass_file_path, drums_file_path, other_file_path, vocals_file_path):
        """Perform prediction.

        Reads the mixture and writes the same signal, scaled by ``1 / 4``, to each of the
        four stem paths, keeping the mixture's sample rate.

        Args:
            mixture_file_path: path of the input mixture audio file.
            bass_file_path: destination path for the bass estimate.
            drums_file_path: destination path for the drums estimate.
            other_file_path: destination path for the "other" estimate.
            vocals_file_path: destination path for the vocals estimate.
        """
        print("Mixture file is present at following location: %s" % mixture_file_path)
        # sf.read returns (samples, sample_rate). The challenge mixtures are presumably
        # stereo at 44.1 kHz (TODO confirm), but nothing below depends on that.
        x, rate = sf.read(mixture_file_path)
        stem_file_paths = (bass_file_path, drums_file_path, other_file_path, vocals_file_path)
        n_sources = len(stem_file_paths)
        # Compute the scaled estimate once instead of allocating a new full-length
        # array for every output file; all stems receive the same signal.
        estimate = 1 / n_sources * x
        for stem_file_path in stem_file_paths:
            sf.write(stem_file_path, estimate, rate)
        print("%s: prediction completed." % mixture_file_path)
if __name__ == "__main__":
    # Entry point: instantiate the chosen predictor and run the evaluation loop.
    # (Swap in CopyPredictor() here to submit the copy baseline instead.)
    predictor = ScaledMixturePredictor()
    predictor.run()
    print("Successfully generated predictions!")