multi-speaker aligner
keonlee9420 committed Sep 25, 2021
1 parent b0ad2e1 commit edfdb87
Showing 2 changed files with 31 additions and 12 deletions.
14 changes: 6 additions & 8 deletions model/CompTransTTS.py
@@ -85,18 +85,15 @@ def forward(
             else None
         )

-        output, text_embeds = self.encoder(texts, src_masks)
+        texts, text_embeds = self.encoder(texts, src_masks)

+        speaker_embeds = None
         if self.speaker_emb is not None:
             if self.embedder_type == "none":
-                output = output + self.speaker_emb(speakers).unsqueeze(1).expand(
-                    -1, max_src_len, -1
-                )
+                speaker_embeds = self.speaker_emb(speakers)  # [B, H]
             else:
                 assert spker_embeds is not None, "Speaker embedding should not be None"
-                output = output + self.speaker_emb(spker_embeds).unsqueeze(1).expand(
-                    -1, max_src_len, -1
-                )
+                speaker_embeds = self.speaker_emb(spker_embeds)  # [B, H]

         (
             output,
@@ -110,7 +107,8 @@ def forward(
             mel_masks,
             attn_outs,
         ) = self.variance_adaptor(
-            output,
+            speaker_embeds,
+            texts,
             text_embeds,
             src_lens,
             src_masks,
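Taken together, the CompTransTTS.py changes stop adding the speaker embedding to the encoder output inside the model's forward and instead hand a separate [B, H] speaker vector to the variance adaptor alongside the encoder output. A minimal shape sketch of the new calling convention (the tensor sizes and the nn.Embedding stand-in for self.speaker_emb are illustrative assumptions, not values from the repo's configs):

import torch
import torch.nn as nn

B, max_src_len, H = 2, 7, 256           # hypothetical batch, source length, encoder hidden
speakers = torch.tensor([0, 3])         # speaker ids, shape [B]
speaker_emb = nn.Embedding(10, H)       # stand-in for self.speaker_emb when embedder_type == "none"

texts = torch.randn(B, max_src_len, H)      # encoder output (named `output` before this commit)
speaker_embeds = speaker_emb(speakers)      # [B, H]; no longer broadcast over time here
# both tensors are now passed to the variance adaptor, which does the broadcasting itself
print(texts.shape, speaker_embeds.shape)    # torch.Size([2, 7, 256]) torch.Size([2, 256])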
29 changes: 25 additions & 4 deletions model/modules.py
@@ -18,7 +18,7 @@
     pad_1D,
     pad,
 )
-from .transformers.blocks import ConvNorm
+from .transformers.blocks import LinearNorm, ConvNorm


 @jit(nopython=True)
@@ -154,6 +154,7 @@ def __init__(self, preprocess_config, model_config, train_config):
             n_att_channels=preprocess_config["preprocessing"]["mel"]["n_mel_channels"],
             n_text_channels=model_config["transformer"]["encoder_hidden"],
             temperature=model_config["duration_modeling"]["aligner_temperature"],
+            multi_speaker=model_config["multi_speaker"],
         )

         pitch_level_tag, energy_level_tag, self.pitch_feature_level, self.energy_feature_level = \
@@ -285,7 +286,8 @@ def get_energy_embedding(self, x, target, mask, control):

     def forward(
         self,
-        x,
+        speaker_embedding,
+        text,
         text_embedding,
         src_len,
         src_mask,
@@ -302,6 +304,11 @@ def forward(
         d_control=1.0,
         step=None,
     ):
+        x = text
+        if speaker_embedding is not None:
+            x = x + speaker_embedding.unsqueeze(1).expand(
+                -1, text.shape[1], -1
+            )

         log_duration_prediction = self.duration_predictor(x, src_mask)
         duration_rounded = torch.clamp(
@@ -318,6 +325,7 @@
             text_embedding.transpose(1, 2),
             src_mask.unsqueeze(-1),
             attn_prior.transpose(1, 2),
+            speaker_embedding.unsqueeze(1),
         )
         attn_hard = self.binarize_attention_parallel(attn_soft, src_len, mel_len)
         attn_hard_dur = attn_hard.sum(2)[:, 0, :]
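For reference, a small sketch of what the reworked VarianceAdaptor.forward now does with the speaker vector: it broadcasts the [B, H] embedding over the source-time axis before duration prediction, and hands the aligner a [B, 1, H] view of it. The sizes below are illustrative assumptions, not repo config values:

import torch

B, T_src, H = 2, 7, 256                   # hypothetical batch, source length, hidden size
text = torch.randn(B, T_src, H)           # encoder output passed in as `text`
speaker_embedding = torch.randn(B, H)     # per-utterance speaker vector

# same broadcast the new forward() performs before the duration predictor
x = text + speaker_embedding.unsqueeze(1).expand(-1, text.shape[1], -1)
assert x.shape == (B, T_src, H)

# the aligner receives the speaker vector with an explicit length-1 time axis
spk_for_aligner = speaker_embedding.unsqueeze(1)    # [B, 1, H]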
@@ -381,7 +389,8 @@ def __init__(self,
                  n_mel_channels,
                  n_att_channels,
                  n_text_channels,
-                 temperature):
+                 temperature,
+                 multi_speaker):
         super().__init__()
         self.temperature = temperature
         self.softmax = torch.nn.Softmax(dim=3)
@@ -428,17 +437,29 @@ def __init__(self,
             ),
         )

-    def forward(self, queries, keys, mask=None, attn_prior=None):
+        if multi_speaker:
+            self.key_spk_proj = LinearNorm(n_text_channels, n_text_channels)
+            self.query_spk_proj = LinearNorm(n_text_channels, n_mel_channels)
+
+    def forward(self, queries, keys, mask=None, attn_prior=None, speaker_embed=None):
         """Forward pass of the aligner encoder.
         Args:
             queries (torch.tensor): B x C x T1 tensor (probably going to be mel data).
             keys (torch.tensor): B x C2 x T2 tensor (text data).
             mask (torch.tensor): uint8 binary mask for variable length entries (should be in the T2 domain).
             attn_prior (torch.tensor): prior for attention matrix.
+            speaker_embed (torch.tensor): B x 1 x C tensor of speaker embedding for multi-speaker scheme.
         Output:
             attn (torch.tensor): B x 1 x T1 x T2 attention mask. Final dim T2 should sum to 1.
             attn_logprob (torch.tensor): B x 1 x T1 x T2 log-prob attention mask.
         """
+        if speaker_embed is not None:
+            keys = keys + self.key_spk_proj(speaker_embed.expand(
+                -1, keys.shape[-1], -1
+            )).transpose(1, 2)
+            queries = queries + self.query_spk_proj(speaker_embed.expand(
+                -1, queries.shape[-1], -1
+            )).transpose(1, 2)
         keys_enc = self.key_proj(keys)  # B x n_attn_dims x T2
         queries_enc = self.query_proj(queries)
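The AlignmentEncoder change can be exercised in isolation: when multi_speaker is enabled, the speaker vector is projected into the key (text) and query (mel) channel spaces and added to the inputs before the existing key/query convolutions. A self-contained sketch, with plain nn.Linear standing in for the repo's LinearNorm wrapper and made-up channel and length sizes:

import torch
import torch.nn as nn

B, C_text, C_mel, T_text, T_mel = 2, 256, 80, 7, 40   # hypothetical sizes
keys = torch.randn(B, C_text, T_text)       # text encoding, B x C2 x T2
queries = torch.randn(B, C_mel, T_mel)      # mel features, B x C x T1
speaker_embed = torch.randn(B, 1, C_text)   # B x 1 x C, as in the docstring above

# plain nn.Linear stands in for the repo's LinearNorm wrapper
key_spk_proj = nn.Linear(C_text, C_text)
query_spk_proj = nn.Linear(C_text, C_mel)

# same conditioning the new forward() applies when speaker_embed is not None
keys = keys + key_spk_proj(speaker_embed.expand(-1, keys.shape[-1], -1)).transpose(1, 2)
queries = queries + query_spk_proj(speaker_embed.expand(-1, queries.shape[-1], -1)).transpose(1, 2)
assert keys.shape == (B, C_text, T_text) and queries.shape == (B, C_mel, T_mel)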
