Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

优化:fast_inference_分支中TTS实例每次调用不同的weights时重复加载底膜,是否有必要? #1092

Open
KevinZhang19870314 opened this issue May 15, 2024 · 8 comments

Comments

@KevinZhang19870314
Copy link

每次api调用时候,都会实例化一个tts实例,其中_init_models方法中除了加载weights相关的模型,还重新加载了底膜,在当前的实现中属于重新加载了,如果有切换weights的需求,每次就重复加载了底膜。

基于此:是否可以优化一下TTS实例,在api初始化的时候加载一次底膜,后续就不用加载了?

@KevinZhang19870314 KevinZhang19870314 changed the title 优化:fast_inference_分支中TTS实例每次调用不同的weights时重复加载底膜,是否没有必要? 优化:fast_inference_分支中TTS实例每次调用不同的weights时重复加载底膜,是否有必要? May 15, 2024
@ChasonJiang
Copy link

是否可以优化一下TTS实例,在api初始化的时候加载一次底膜,后续就不用加载了?

这个想法是合理的,但是如果后续增加并发功能的话, #1097 中的做法可能就不太合理了,因为这可能导致多个TTS线程竞争同一个底模。如果在并发时设置load_base=True,那么和没有改没啥区别。

所以另一个合理的办法是:不修改TTS类,在api_v3.py中复用TTS instance pool已加载的TTS类,并根据当前task的需要,调用TTS.init_vits_weights()和TTS.init_t2s_weights(),重新加载T2S和VITS的weights。这样一来,底模就不必再重新加载了。

@KevinZhang19870314
Copy link
Author

是否可以优化一下TTS实例,在api初始化的时候加载一次底膜,后续就不用加载了?

这个想法是合理的,但是如果后续增加并发功能的话, #1097 中的做法可能就不太合理了,因为这可能导致多个TTS线程竞争同一个底模。如果在并发时设置load_base=True,那么和没有改没啥区别。

所以另一个合理的办法是:不修改TTS类,在api_v3.py中复用TTS instance pool已加载的TTS类,并根据当前task的需要,调用TTS.init_vits_weights()和TTS.init_t2s_weights(),重新加载T2S和VITS的weights。这样一来,底模就不必再重新加载了。

好的,等我有空了按照你的办法实现一下。:)

@KevinZhang19870314
Copy link
Author

是否可以优化一下TTS实例,在api初始化的时候加载一次底膜,后续就不用加载了?

这个想法是合理的,但是如果后续增加并发功能的话, #1097 中的做法可能就不太合理了,因为这可能导致多个TTS线程竞争同一个底模。如果在并发时设置load_base=True,那么和没有改没啥区别。

所以另一个合理的办法是:不修改TTS类,在api_v3.py中复用TTS instance pool已加载的TTS类,并根据当前task的需要,调用TTS.init_vits_weights()和TTS.init_t2s_weights(),重新加载T2S和VITS的weights。这样一来,底模就不必再重新加载了。

今天准备实现一下,发现2个问题。

前提:目前的实现使用的lru cache,就是最近最常使用的tts instance,使用的是config的hashcode来作为这个cache的key,意味着config不同就会new一个新的tts instance,相同则使用cache。

    def __hash__(self):
        return hash(self.configs_path)
  1. 如果某次推理遇到的是不在cache中的config的话,这样还是要重新加载重新new 一个tts instance。

  2. 即使使用了cache,如果调用TTS.init_vits_weights()和TTS.init_t2s_weights()了之后,那么意味着下次调用上一个weights的话,还是得调用TTS.init_vits_weights()和TTS.init_t2s_weights()一次,这样就达不到lru cache的效果了。

不知道我理解的对不对? @ChasonJiang

@ChasonJiang
Copy link

  1. 如果某次推理遇到的是不在cache中的config的话,这样还是要重新加载重新new 一个tts instance。
  2. 即使使用了cache,如果调用TTS.init_vits_weights()和TTS.init_t2s_weights()了之后,那么意味着下次调用上一个weights的话,还是得调用TTS.init_vits_weights()和TTS.init_t2s_weights()一次,这样就达不到lru cache的效果了。

第一点没问题;第二点,如果一个task所需的weight本身不在cache中,那么它就是要被加载的,无论用你的方法还是我的方法。
我说的“在api_v3.py中复用TTS instance pool已加载的TTS类”的意思是:pool已经达到了max length,且task所需的weight本身不在pool中,因此需要在pool中按照某种算法挑一个出来复用。

总之,如果要兼容并发,lru cache应该是不能用了,需要自己实现一个TTS instance pool。(还要防止并发时,两个使用相同config的task,竞争同一个instance。

另外,要完全使用同一个底模也是可以的,不过需要把使用了底模的代码逻辑从TTS类中剥离出来,然后加锁等等一系列的操作,不推荐就是了。

@KevinZhang19870314
Copy link
Author

KevinZhang19870314 commented May 22, 2024

@ChasonJiang 好的,在不修改原有TTS类的情况下,自己实现一个tts instance pool应该相对容易一些,下面是一个实现的想法(待测试),加了锁防止并发竞争同一个instance,重复创建新实例等线程不安全行为,同时也保证了底模只加载了一次(有了复用已有实例的逻辑)。

缺点就是复用已有实例的同时,下一个被复用的config来了之后就是新实例了。。。

class TTSInstancePool:
    def __init__(self, max_size):
        self.max_size = max_size
        self.lock = threading.Lock()
        self.pool = []
        self.current_index = 0

    def get_tts_instance(self, config):
        config_hash = hash(config.configs_path)

        with self.lock:
            # 如果相同config则返回已存在实例
            for tts_instance in self.pool:
                if tts_instance.config_hash == config_hash:
                    return tts_instance

            # 如果池满,则new一个新实例
            if len(self.pool) < self.max_size:
                tts_instance = TTS(config)
                self.pool.append(tts_instance)
                return tts_instance
            else:
                # 复用已有实例,重新设置weights,这里简单的轮转选一个实例,实际场景中如何选?
                tts_instance = self.pool[self.current_index]
                self.current_index = (self.current_index + 1) % self.max_size
                tts_instance.init_vits_weights(config)
                tts_instance.init_t2s_weights(config)
                return tts_instance

    def clear_pool(self):
        with self.lock:
            self.pool.clear()
            self.current_index = 0

@ChasonJiang
Copy link

ChasonJiang commented May 23, 2024

@KevinZhang19870314 我是这么实现的

import asyncio
import math
from time import time, perf_counter
import traceback
from typing import Dict, List, Union
import threading
from TTS_infer_pack.TTS import TTS, TTS_Config

class TTS_Wrapper(TTS):
    heat:float = 0
    usage_counter:int = 0
    usage_time:float = 0.0
    first_used_time:float = 0.0
    def __init__(self, configs: Union[dict, str, TTS_Config]):
        super(TTS_Wrapper, self).__init__(configs)
        self.first_used_time = perf_counter()

    def __hash__(self) -> int:
        return hash(self.first_used_time)
    
    def run(self, *args, **kwargs):
        self.usage_counter+=1
        t0=perf_counter()
        for result in super(TTS_Wrapper, self).run(*args, **kwargs):
            yield result
        t1=perf_counter()
        self.usage_time += t1-t0
        idle_time = self.usage_time-self.first_used_time
        self.heat = self.usage_counter/idle_time

        
    def reset_heat(self):
        self.heat:int = 0
        self.usage_count:int = 0
        self.usage_time:float = 0.0
        self.first_used_time:float = perf_counter()

class TTSInstancePool:
    def __init__(self, max_size):
        self.max_size:int = max_size
        self.semaphore:threading.Semaphore = threading.Semaphore(max_size)
        self.pool_lock:threading.Lock = threading.Lock()
        self.pool:Dict[int, TTS_Wrapper] = dict()
        self.current_index:int = 0
        self.size:int = 0

    async def acquire(self, configs:TTS_Config):

        self.semaphore.acquire()
        try:
            with self.pool_lock:
                ## 查询最匹配的实例
                indexed_key = None
                rank = []
                for key, tts_instance in self.pool.items():
                    if tts_instance.configs.vits_weights_path == configs.vits_weights_path \
                        and tts_instance.configs.t2s_weights_path == configs.t2s_weights_path:
                        indexed_key = key
                    rank.append((tts_instance.heat, key))
                rank.sort(key=lambda x: x[0])
                matched_key = None if len(rank)==0 else rank[0][1]

                # 如果已有实例匹配,则直接复用
                if indexed_key is not None:
                    tts_instance = self._reuse_instance(indexed_key, configs)
                    return tts_instance

                # 如果pool未满,则创建一个新实例
                if self.size < self.max_size:
                    tts_instance = TTS_Wrapper(configs)
                    self.size+=1
                    return tts_instance
                else:
                    # 否则用最合适的实例进行复用
                    tts_instance = self._reuse_instance(matched_key, configs)
                    return tts_instance
        except Exception as e:
            self.semaphore.release()
            traceback.print_exc()
            raise e

            
    async def release(self, tts_instance:TTS_Wrapper):
        assert tts_instance is not None
        with self.pool_lock:
            key = hash(tts_instance)
            if key in self.pool.keys():
                return
            self.pool[key]=tts_instance
        self.semaphore.release()

    async def clear_pool(self):
        for i in range(self.max_size):
            self.semaphore.acquire()
        with self.pool_lock:
            self.pool.clear()
        # for i in range(self.max_size):
        self.semaphore.release(self.max_size)
    
    def _reuse_instance(self, instance_key:int, configs:TTS_Config)->TTS_Wrapper:
        '''
        复用已有实例
        args:
            instance_key: int, 已有实例的Key
            config: TTS_Config
        return:
            TTS_Wrapper: 返回复用的TTS实例
        '''


        # 复用已有实例
        tts_instance = self.pool.pop(instance_key, None)
        if tts_instance is None:
            raise ValueError("Instance not found")
        
        tts_instance.configs.device= configs.device
        if tts_instance.configs.vits_weights_path != configs.vits_weights_path \
            or tts_instance.configs.t2s_weights_path != configs.t2s_weights_path:
            tts_instance.reset_heat()

        if tts_instance.configs.vits_weights_path != configs.vits_weights_path:
            tts_instance.init_vits_weights(configs.vits_weights_path)
            tts_instance.configs.vits_weights_path = configs.vits_weights_path

        if tts_instance.configs.t2s_weights_path != configs.t2s_weights_path:
            tts_instance.init_t2s_weights(configs.t2s_weights_path)
            tts_instance.configs.t2s_weights_path = configs.t2s_weights_path

        tts_instance.set_device(configs.device)
        return tts_instance

@KevinZhang19870314
Copy link
Author

KevinZhang19870314 commented May 24, 2024

好的,添加了threading.Semaphore对max_size控制的确好一些,等我有空了提个PR给你,到时候review一下? @ChasonJiang

通过TTS_Wrapper的热度来判断使用tts实例的想法没想到,蛮不错的!

@ChasonJiang
Copy link

@KevinZhang19870314 好的,代码有些小改动你再看一下

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants