
[Feature request] Add Recipe for all 3 Training stages - XTTS V2 #3704

Open
smallsudarshan opened this issue Apr 24, 2024 · 2 comments
Labels: feature request

@smallsudarshan

🚀 Feature Description
Hey, we noticed there is no training code for fine-tuning all parts of XTTS V2. We would like to contribute it if that adds value.

The aim would be to make the model work very reliably for a particular accent (Indian, for example), in a particular language (English), and in a particular speaking style with very little variability. We tried simple fine-tuning: the model picks up the accent and the speaking style to some extent, but it is not very robust and mispronounces quite a lot.

Solution

  1. First, fine-tune the DVAE model using code from the DALL-E repo
  2. Then fine-tune the GPT-2 model - maybe this can be the recipe provided
  3. Finally, fine-tune end to end with the HiFi-GAN

We are not sure if the perceiver needs any fine-tuning.

If licenses permit, we will also share the data.

Does this make sense?

smallsudarshan added the feature request label on Apr 24, 2024
@smallsudarshan
Author

smallsudarshan commented Apr 27, 2024

OK, so here you go. I took the training code from this repo.

train.py:

import torch
import wandb
from models.dvae import DiscreteVAE
from utils.arch_utils import TorchMelSpectrogram
from torch.utils.data import DataLoader
from utils.dvae_dataset import DVAEDataset
from torch.optim import Adam
from torch.nn.utils import clip_grad_norm_

import pdb
from TTS.tts.datasets import load_tts_samples
from TTS.config.shared_configs import BaseDatasetConfig

dvae_checkpoint = '/home/ubuntu/test_tts/SimpleTTS/xtts/run/training/XTTS_v2.0_original_model_files/dvae.pth'
mel_norm_file = '/home/ubuntu/test_tts/SimpleTTS/xtts/run/training/XTTS_v2.0_original_model_files/mel_stats.pth'

config_dataset = BaseDatasetConfig(
    formatter="ljspeech",
    dataset_name="ljspeech",
    path="/home/ubuntu/test_tts/sapien-formatted-english-22050",
    meta_file_train="/home/ubuntu/test_tts/sapien-formatted-english-22050/metadata_norm.txt",
    language="en",
)

# Add here the configs of the datasets
DATASETS_CONFIG_LIST = [config_dataset]
GRAD_CLIP_NORM = 0.5
LEARNING_RATE = 5e-05

dvae = DiscreteVAE(
            channels=80,
            normalization=None,
            positional_dims=1,
            num_tokens=1024,
            codebook_dim=512,
            hidden_dim=512,
            num_resnet_blocks=3,
            kernel_size=3,
            num_layers=2,
            use_transposed_convs=False,
        )

dvae.load_state_dict(torch.load(dvae_checkpoint), strict=False)
dvae.cuda()
opt = Adam(dvae.parameters(), lr = LEARNING_RATE)
torch_mel_spectrogram_dvae = TorchMelSpectrogram(
            mel_norm_file=mel_norm_file, sampling_rate=22050
        ).cuda()

train_samples, eval_samples = load_tts_samples(
        DATASETS_CONFIG_LIST,
        eval_split=True,
        eval_split_max_size=256,
        eval_split_size=0.01,
    )

eval_dataset = DVAEDataset(eval_samples, 22050, True)
train_dataset = DVAEDataset(train_samples, 22050, False)
epochs = 20
eval_data_loader = DataLoader(
                    eval_dataset,
                    batch_size=3,
                    shuffle=False,
                    drop_last=False,
                    collate_fn=eval_dataset.collate_fn,
                    num_workers=0,
                    pin_memory=False,
                )

train_data_loader = DataLoader(
                    train_dataset,
                    batch_size=3,
                    shuffle=False,
                    drop_last=False,
                    collate_fn=train_dataset.collate_fn,
                    num_workers=4,
                    pin_memory=False,
                )

torch.set_grad_enabled(True)
dvae.train()

wandb.init(project = 'train_dvae')
wandb.watch(dvae)

def to_cuda(x: torch.Tensor) -> torch.Tensor:
    if x is None:
        return None
    if torch.is_tensor(x):
        x = x.contiguous()
        if torch.cuda.is_available():
            x = x.cuda(non_blocking=True)
    return x

@torch.no_grad()
def format_batch(batch):
    if isinstance(batch, dict):
        for k, v in batch.items():
            batch[k] = to_cuda(v)
    elif isinstance(batch, list):
        batch = [to_cuda(v) for v in batch]

    try:
        batch['mel'] = torch_mel_spectrogram_dvae(batch['wav'])
        # if the mel spectrogram length is not divisible by 4, the DVAE output
        # shape will not match the input shape, so trim the remainder
        remainder = batch['mel'].shape[-1] % 4
        if remainder:
            batch['mel'] = batch['mel'][:, :, :-remainder]
    except NotImplementedError:
        pass
    return batch

for i in range(epochs):
    for cur_step, batch in enumerate(train_data_loader):

        opt.zero_grad()
        batch = format_batch(batch)
        recon_loss, commitment_loss, out = dvae(batch['mel'])
        total_loss = recon_loss + commitment_loss
        total_loss.backward()
        clip_grad_norm_(dvae.parameters(), GRAD_CLIP_NORM)
        opt.step()

        log = {'epoch': i,
               'cur_step': cur_step,
               'loss': total_loss.item(),
               'recon_loss': recon_loss.item(),
               'commit_loss': commitment_loss.item()}
        print(f"epoch: {i}", print(f"step: {cur_step}"), f'loss - {total_loss.item()}', f'recon_loss - {recon_loss.item()}', f'commit_loss - {commitment_loss.item()}')
        wandb.log(log)
        torch.cuda.empty_cache()
#     if i%10:
#         save_model(f'.dvae.pth')

# wandb.save('./dvae.pth')
# wandb.finish()
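
Saving is only sketched in the commented-out lines above; assuming you just want the raw weights (the filename below is an example, not from the original script), something like this should work:

torch.save(dvae.state_dict(), "finetuned_dvae.pth")  # persist the fine-tuned DVAE weights
wandb.save("finetuned_dvae.pth")  # optionally attach the checkpoint to the wandb run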

I wrote a custom DVAEDataset that is imported in the train.py file above:


import torch
import random
from utils.dataset import key_samples_by_col
from TTS.tts.models.xtts import load_audio

torch.set_num_threads(1)

class DVAEDataset(torch.utils.data.Dataset):
    def __init__(self, samples, sample_rate, is_eval):
        self.sample_rate = sample_rate
        self.is_eval = is_eval
        self.max_wav_len = 255995
        self.samples = samples
        self.training_seed = 1
        self.failed_samples = set()
        if not is_eval:
            random.seed(self.training_seed)
            random.shuffle(self.samples)
            # order by language
            self.samples = key_samples_by_col(self.samples, "language")
            print(" > Sampling by language:", self.samples.keys())
        else:
            # for evaluation, load and check the samples upfront and drop corrupted ones to ensure reproducibility
            self.check_eval_samples()

    def check_eval_samples(self):
        print(" > Filtering invalid eval samples!!")
        new_samples = []
        for sample in self.samples:
            try:
                _, wav = self.load_item(sample)
            except:
                continue
            # Basically, this audio file is nonexistent or too long to be supported by the dataset.
            if (
                wav is None
                or (self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len)
            ):
                continue
            new_samples.append(sample)
        self.samples = new_samples
        print(" > Total eval samples after filtering:", len(self.samples))

    def load_item(self, sample):
        audiopath = sample["audio_file"]
        wav = load_audio(audiopath, self.sample_rate)
        if wav is None or wav.shape[-1] < (0.5 * self.sample_rate):
            # Ultra short clips are also useless (and can cause problems within some models).
            raise ValueError

        return audiopath, wav
    
    def __getitem__(self, index):
        if self.is_eval:
            sample = self.samples[index]
            sample_id = str(index)
        else:
            # select a random language
            lang = random.choice(list(self.samples.keys()))
            # select random sample
            index = random.randint(0, len(self.samples[lang]) - 1)
            sample = self.samples[lang][index]
            # a unique id for each sample to deal with load failures
            sample_id = lang + "_" + str(index)

        # ignore samples that we already know are not valid
        if sample_id in self.failed_samples:
            # call __getitem__ again to get another sample
            return self[1]

        # try to load the sample; if it fails, add it to the failed samples list
        try:
            audiopath, wav = self.load_item(sample)
        except:
            self.failed_samples.add(sample_id)
            return self[1]

        # check the audio length limit; if the sample is out of bounds, add it to failed_samples
        if (
            wav is None
            or (self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len)
        ):
            # Basically, this audio file is nonexistent or too long to be supported by the dataset.
            # It's hard to handle this situation properly. Best bet is to return a random valid sample and skew the dataset somewhat as a result.
            self.failed_samples.add(sample_id)
            return self[1]

        res = {
            "wav": wav,
            "wav_lengths": torch.tensor(wav.shape[-1], dtype=torch.long),
            "filenames": audiopath,
        }
        return res
    
    def __len__(self):
        if self.is_eval:
            return len(self.samples)
        return sum([len(v) for v in self.samples.values()])

    def collate_fn(self, batch):
        # convert list of dicts to dict of lists
        B = len(batch)

        batch = {k: [dic[k] for dic in batch] for k in batch[0]}

        # stack for features that already have the same shape
        batch["wav_lengths"] = torch.stack(batch["wav_lengths"])

        max_wav_len = batch["wav_lengths"].max()

        # create padding tensors
        wav_padded = torch.FloatTensor(B, 1, max_wav_len)

        # initialize tensors for zero padding
        wav_padded = wav_padded.zero_()
        for i in range(B):
            wav = batch["wav"][i]
            wav_padded[i, :, : batch["wav_lengths"][i]] = torch.FloatTensor(wav)

        batch["wav"] = wav_padded
        return batch

This trains the DVAE to encode and decode mel-spectrograms.
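
As a quick sanity check (just a sketch, reusing train_samples from train.py above), you can pull one padded batch through the collate function:

from torch.utils.data import DataLoader

ds = DVAEDataset(train_samples, sample_rate=22050, is_eval=False)
loader = DataLoader(ds, batch_size=2, collate_fn=ds.collate_fn)
batch = next(iter(loader))
print(batch["wav"].shape)    # (B, 1, longest wav in the batch), zero-padded
print(batch["wav_lengths"])  # original per-sample lengths before padding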

A few things:

  1. You can see my import paths are not standard. That is because I have changed the structure of the repo a bit in my personal fork. You can follow the standard import paths as per TTS.
  2. There is a loss called DiscretizationLoss here, but I am not sure where or how it is used, so I am not using it currently.
  3. For some reason, in dvae.py on line 378, the author has added self.loss_fn(img, out, reduction="none"). I am not sure what the purpose of reduction='none' is, so I have summed it up in my code and just added it when calculating the loss (see the sketch after this list).
  4. I am not sure what recipe to use for training (gradient accumulation steps, LR schedule, etc.); I am just doing basic fine-tuning for now.
  5. For a small batch, my train loss is very low from the start and also converges quickly:
    [loss curve screenshot]
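
Regarding point 3: if you run the training loop above against an unmodified dvae.py (where the reconstruction loss is returned with reduction="none", i.e. unreduced), a minimal sketch of the fix is to reduce it yourself before combining the losses:

recon_loss, commitment_loss, out = dvae(batch['mel'])
recon_loss = recon_loss.mean()  # reduce the per-element reconstruction loss (sum() also works, it just changes the loss scale)
total_loss = recon_loss + commitment_loss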

The next step would be to fine-tune on a larger dataset. @erogol @eginhard, if this is in the right direction, I can convert it into a training recipe
and add it to the repo.

PS: The code is a bit dirty since I have just reused whatever was available, as long as it doesn't harm my training.

@smallsudarshan
Author

smallsudarshan commented Apr 29, 2024

I also now understand that the decoder of the DVAE is not used; instead, an LM head on the GPT-2 is used to recompute the mel from the audio codes. I need to understand this a bit better before writing the training code for the next stage.
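
For what it's worth, in the GPT training stage the (frozen) DVAE seems to be used only as a tokenizer for the target audio; assuming DiscreteVAE still exposes tortoise's get_codebook_indices method, extracting the GPT targets would look roughly like this:

dvae.eval()
with torch.no_grad():
    mel = torch_mel_spectrogram_dvae(batch["wav"])  # same mel frontend as in train.py above
    audio_codes = dvae.get_codebook_indices(mel)    # discrete code ids (~ mel frames / 4), used as GPT targets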
