Skip to content

Commit

Permalink
[WIP] It works!
Browse files Browse the repository at this point in the history
  • Loading branch information
PeterBowman committed Dec 11, 2023
1 parent 12af619 commit b884ee1
Showing 1 changed file with 65 additions and 22 deletions.
87 changes: 65 additions & 22 deletions programs/speechSynthesis/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,50 +12,53 @@

class SynthesizerFactory(ABC):
@abstractmethod
def create(self, stream):
def create(self, stream, callback):
pass

class PiperSynthesizerFactory(SynthesizerFactory):
def __init__(self, device, model, use_cuda, rf):
def __init__(self, device, model, rf):
self.device = device
self.model = model
self.rf = rf

def create(self, stream):
return PiperSynthesizer(stream, self.device, self.model, use_cuda, self.rf)
def create(self, stream, callback):
return PiperSynthesizer(stream, callback, self.device, self.model, self.rf)

class SpeechSynthesizer(roboticslab_speech.SpeechSynthesis):
def __init__(self, stream, device):
def __init__(self, stream, callback, device):
super().__init__()
self.stream = stream
self.callback = callback
device_info = sd.query_devices(device, 'input')
# soundfile expects an int, sounddevice provides a float:
self.sample_rate = int(device_info['default_samplerate'])
self.text_queue = queue.Queue(maxsize=1)

def say(self, text):
self.text_queue.put(text)
self.callback.set_generator(self._get_generator(text))
return True

@abstractmethod
def synthesize(self, text):
pass

@abstractmethod
def _get_generator(self, text):
pass

@abstractmethod
def sd_callback(self, outdata, frames, time, status):
pass

class PiperSynthesizer(SpeechSynthesizer):
def __init__(self, stream, device, model, use_cuda, rf):
super().__init__(stream, device)
def __init__(self, stream, callback, device, model, rf):
super().__init__(stream, callback, device)
self.model = model
self.rf = rf
self.voice = PiperVoice(self.model, None, use_cuda=use_cuda)
self.voice = PiperVoice.load(self.model, use_cuda=False) # TODO: cuda

def sd_callback(self, outdata, frames, time, status):
# https://stackoverflow.com/q/62521902
for raw in self.voice.synthesize_stream_raw(text):
outdata[:] = raw
def _get_generator(self, text):
return self.voice.synthesize_stream_raw(text)

def play(self):
pass
Expand Down Expand Up @@ -111,6 +114,7 @@ def int_or_str(text):
# Build the full command-line parser on top of the previously defined `parser`
# (passed through `parents=`), so its options and description are kept.
parser = argparse.ArgumentParser(description=parser.description, formatter_class=argparse.ArgumentDefaultsHelpFormatter, parents=[parser])
# This program synthesizes speech, so the backend is a TTS engine and the
# device is the audio *output* device handed to sd.RawOutputStream.
parser.add_argument('--backend', '-b', type=str, required=True, help='TTS backend engine')
parser.add_argument('--device', '-d', type=int_or_str, help='output device (numeric ID or substring)')
parser.add_argument('--model', type=str, help='model, e.g. follow-me')
# NOTE(review): --cuda is parsed here but PiperSynthesizer currently loads the
# voice with use_cuda=False (see its TODO), so the flag has no effect yet.
parser.add_argument('--cuda', action='store_true', help='Use Onnx CUDA execution provider (requires onnxruntime-gpu)')
parser.add_argument('--prefix', '-p', type=str, default='/speechSynthesis', help='YARP port prefix')
parser.add_argument('--context', type=str, default='speechSynthesis', help='YARP context directory')
Expand All @@ -124,11 +128,11 @@ def int_or_str(text):
rf.setDefaultConfigFile(args.ini)

if args.backend == 'piper':
if args.dictionary is None or args.language is None:
print('Dictionary and language must be specified for Piper')
if args.model is None:
print('Model must be specified for Piper')
raise SystemExit

synthesizer_factory = PiperSynthesizerFactory(args.device, args.dictionary, args.language, rf)
synthesizer_factory = PiperSynthesizerFactory(args.device, args.model, rf)
else:
print('Backend not available, must be one of: %s' % ', '.join(BACKENDS))
raise SystemExit
Expand All @@ -143,20 +147,59 @@ def int_or_str(text):
print('Unable to open RPC port')
raise SystemExit

class Callback:
    """Adapt a generator of raw audio chunks to a sounddevice output callback.

    An instance is shared between the synthesizer, which hands over a new
    chunk generator through :meth:`set_generator`, and an
    ``sd.RawOutputStream``, which invokes :meth:`callback` every time it needs
    another block of audio.

    The chunks produced by the generator rarely match the block size requested
    by the stream, so bytes that do not fit into the current block are kept in
    an internal buffer and emitted first on the next invocation. This keeps
    the audio contiguous regardless of how chunk and block sizes relate.
    """

    def __init__(self):
        # Source of raw audio chunks (bytes-like); None while idle.
        self._generator = None
        # Bytes already pulled from the generator but not yet played.
        self._pending = b''

    def set_generator(self, generator):
        """Start playing the chunks yielded by *generator*.

        Any audio still buffered from a previous generator is discarded, so
        the new utterance replaces the old one instead of being mixed with
        its leftovers.
        """
        self._generator = generator
        self._pending = b''

    def callback(self, outdata, frames, time, status):
        """Fill *outdata* with the next block of audio, or silence.

        The signature is the one required by ``sd.RawOutputStream``; only
        *outdata* (a writable buffer whose length is the number of bytes to
        produce) is used.

        Buffered bytes are consumed first, then chunks are pulled from the
        generator until the block is full or the generator is exhausted. Any
        surplus is buffered for the next call and any shortfall is padded with
        zero bytes (silence).
        """
        wanted = len(outdata)
        chunks = [self._pending] if self._pending else []
        filled = len(self._pending)
        self._pending = b''

        while filled < wanted and self._generator is not None:
            try:
                raw = next(self._generator)
            except StopIteration:
                # Nothing left to play: go idle until set_generator() is
                # called again instead of polling a spent generator forever.
                self._generator = None
                break

            chunks.append(raw)
            filled += len(raw)

        data = b''.join(chunks)
        # Keep whatever does not fit in this block for the next invocation.
        self._pending = data[wanted:]
        # No printing here: this runs in the audio callback, where slow I/O
        # risks output underruns.
        outdata[:] = data[:wanted].ljust(wanted, b'\x00')

try:
q = queue.Queue()
cb = Callback()

with sd.RawOutputStream(blocksize=8000,
with sd.RawOutputStream(samplerate=22050,
blocksize=2048,
device=args.device,
dtype='int16',
channels=1,
callback=lambda outdata, frames, time, status: q.put(bytes(outdata))) as stream:
synthesizer = synthesizer_factory.create(stream)
callback=cb.callback) as stream:
synthesizer = synthesizer_factory.create(stream, cb)
synthesizer.yarp().attachAsServer(rpc)

while True:
text = synthesizer.text_queue.get()
synthesizer.synthesize(text)
import time
time.sleep(0.1)
except KeyboardInterrupt:
rpc.interrupt()
rpc.close()
Expand Down

0 comments on commit b884ee1

Please sign in to comment.