Optimization: in the fast_inference_ branch, the TTS instance reloads the base model every time it is called with different weights. Is this necessary? #1092
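The idea is reasonable, but if concurrency support is added later, the approach in #1097 may no longer make sense, because multiple TTS threads could end up competing for the same base model. And setting `load_base=True` under concurrency would be no different from not changing anything at all. So another reasonable approach is to leave the TTS class unchanged and, in api_v3.py, reuse a TTS instance that the TTS instance pool has already loaded. Then, depending on what the current task needs, call `TTS.init_vits_weights()` and `TTS.init_t2s_weights()` to reload only the T2S and VITS weights. That way the base model never has to be reloaded.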
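A minimal sketch of that idea, assuming a single pooled instance for brevity. `FakeTTS` is a hypothetical stand-in that only counts loads, and `get_instance_for_task` is a hypothetical helper; only the method names `init_t2s_weights()` / `init_vits_weights()` come from the thread, and their bodies here are placeholders, not the real implementations.

```python
class FakeTTS:
    """Hypothetical stand-in for the TTS class: it only records what was loaded."""
    base_model_loads = 0  # class-wide counter of (expensive) base-model loads
    weight_loads = 0      # class-wide counter of (cheaper) T2S/VITS weight swaps

    def __init__(self, t2s_path: str, vits_path: str):
        FakeTTS.base_model_loads += 1  # constructing an instance loads the base model
        self.t2s_path = t2s_path
        self.vits_path = vits_path

    def init_t2s_weights(self, path: str) -> None:
        FakeTTS.weight_loads += 1
        self.t2s_path = path

    def init_vits_weights(self, path: str) -> None:
        FakeTTS.weight_loads += 1
        self.vits_path = path


def get_instance_for_task(pool: list, t2s_path: str, vits_path: str) -> FakeTTS:
    """Reuse an already-loaded instance and swap only the weights that differ."""
    if not pool:
        pool.append(FakeTTS(t2s_path, vits_path))
    tts = pool[0]
    if tts.t2s_path != t2s_path:
        tts.init_t2s_weights(t2s_path)
    if tts.vits_path != vits_path:
        tts.init_vits_weights(vits_path)
    return tts


pool: list = []
for t2s, vits in [("a.ckpt", "a.pth"), ("b.ckpt", "b.pth"), ("a.ckpt", "a.pth")]:
    tts = get_instance_for_task(pool, t2s, vits)
print(FakeTTS.base_model_loads, FakeTTS.weight_loads)  # 1 4
```

Switching between two weight sets three times loads the base model once and swaps the T2S/VITS weights four times.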
OK, I'll implement it your way when I have time. :)
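I was going to implement this today and ran into two problems. Background: the current implementation uses an LRU cache, which keeps the most recently used TTS instances and uses the config's hash as the cache key. That means a different config creates a new TTS instance, while the same config reuses the cached one:

```python
def __hash__(self):
    return hash(self.configs_path)
```

Is my understanding correct? @ChasonJiang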
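A minimal sketch of the caching behavior described above, assuming the cache is `functools.lru_cache` keyed directly on the config path string (the real code hashes the config object through `__hash__`). `DummyTTS` and `get_tts` are hypothetical names. It also shows the cost: any instance evicted from the cache is rebuilt from scratch, base model included.

```python
from functools import lru_cache

created = []  # records each config path for which a new instance was built


class DummyTTS:
    """Hypothetical stand-in for TTS; building one would load every model."""

    def __init__(self, configs_path: str):
        self.configs_path = configs_path
        created.append(configs_path)


@lru_cache(maxsize=2)
def get_tts(configs_path: str) -> DummyTTS:
    # Same configs_path -> cached instance; a new path -> a new instance
    return DummyTTS(configs_path)


a1 = get_tts("a.yaml")
a2 = get_tts("a.yaml")  # cache hit: same object
b = get_tts("b.yaml")   # new instance
c = get_tts("c.yaml")   # new instance; cache is full, so "a.yaml" is evicted
a3 = get_tts("a.yaml")  # rebuilt, so all models are loaded again
print(a1 is a2, a1 is a3, created)
# True False ['a.yaml', 'b.yaml', 'c.yaml', 'a.yaml']
```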
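The first point is correct. On the second point: if the weights a task needs are not in the cache in the first place, they have to be loaded either way, whether we use your approach or mine. In short, to support concurrency the LRU cache probably won't work anymore; we need to implement our own TTS instance pool (one that also stops two concurrent tasks with the same config from competing for the same instance). It is also possible to have everything share a single base model, but that means pulling the code that uses the base model out of the TTS class and then adding locks and a series of other changes, so I don't recommend it.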
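One way to keep two concurrent tasks from ever sharing an instance, sketched with the standard library's `queue.Queue` instead of an LRU cache: an instance leaves the queue while a task uses it and goes back when the task is done. `ExclusivePool` is hypothetical, and plain strings stand in for TTS instances.

```python
import queue
import threading
import time


class ExclusivePool:
    """Hypothetical pool: each instance is handed to at most one task at a time."""

    def __init__(self, instances):
        self._idle = queue.Queue()
        for inst in instances:
            self._idle.put(inst)

    def acquire(self, timeout=None):
        # Blocks until an instance is idle; the caller then owns it exclusively
        return self._idle.get(timeout=timeout)

    def release(self, inst):
        # Return the instance so a waiting task can take it
        self._idle.put(inst)


pool = ExclusivePool(["tts-0", "tts-1"])  # strings stand in for TTS instances
in_use = set()
in_use_lock = threading.Lock()
overlaps = []


def task():
    inst = pool.acquire()
    try:
        with in_use_lock:
            if inst in in_use:
                overlaps.append(inst)  # would mean two tasks share one instance
            in_use.add(inst)
        time.sleep(0.01)  # stands in for inference (and any weight swap)
        with in_use_lock:
            in_use.discard(inst)
    finally:
        pool.release(inst)


threads = [threading.Thread(target=task) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(overlaps)  # []
```

Eight tasks share two instances, but no instance is ever held by two tasks at once.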
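@ChasonJiang OK. Implementing our own TTS instance pool without modifying the existing TTS class should be relatively easy. Below is an idea for an implementation (untested). It adds a lock to guard against thread-unsafe behavior such as concurrent tasks competing for the same instance or creating duplicate instances, and because existing instances are reused, each instance loads the base model only once. The downside is that when an existing instance is repurposed, the config it served before is dropped, so the next request for that old config has to reload weights again...

```python
import threading

from TTS_infer_pack.TTS import TTS


class TTSInstancePool:
    def __init__(self, max_size):
        self.max_size = max_size
        self.lock = threading.Lock()
        self.pool = []  # list of (config_hash, TTS instance)
        self.current_index = 0

    def get_tts_instance(self, config):
        config_hash = hash(config.configs_path)
        # The lock guards lookup and creation; the returned instance is not
        # reserved, so two callers can still end up using the same instance
        with self.lock:
            # If an instance with the same config exists, return it
            for instance_hash, tts_instance in self.pool:
                if instance_hash == config_hash:
                    return tts_instance
            # If the pool is not full yet, create a new instance
            if len(self.pool) < self.max_size:
                tts_instance = TTS(config)
                self.pool.append((config_hash, tts_instance))
                return tts_instance
            # Otherwise reuse an existing instance and reload its weights.
            # This simply rotates through the pool; how should it choose in practice?
            index = self.current_index
            self.current_index = (index + 1) % self.max_size
            _, tts_instance = self.pool[index]
            tts_instance.init_vits_weights(config.vits_weights_path)
            tts_instance.init_t2s_weights(config.t2s_weights_path)
            self.pool[index] = (config_hash, tts_instance)  # slot now serves the new config
            return tts_instance

    def clear_pool(self):
        with self.lock:
            self.pool.clear()
            self.current_index = 0
```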
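The comment in the code above asks how to choose which instance to repurpose once the pool is full. One option, sketched here as an assumption rather than anything decided in the thread, is least-recently-used eviction instead of round-robin, tracked with `collections.OrderedDict`. `LRUSlots` is a hypothetical helper: it only decides which slot to use; when `needs_reload` is True the caller would call `init_vits_weights()` / `init_t2s_weights()` on that slot's instance.

```python
from collections import OrderedDict


class LRUSlots:
    """Hypothetical helper: tracks which config each pool slot holds, in LRU order."""

    def __init__(self, max_size: int):
        self.max_size = max_size
        self.slots = OrderedDict()  # config key -> slot id, least recently used first

    def choose(self, config_key: str):
        """Return (slot_id, needs_reload) for a request with this config."""
        if config_key in self.slots:
            self.slots.move_to_end(config_key)  # mark as most recently used
            return self.slots[config_key], False
        if len(self.slots) < self.max_size:
            slot_id = len(self.slots)  # a new instance fills the next free slot
        else:
            _, slot_id = self.slots.popitem(last=False)  # evict least recently used
        self.slots[config_key] = slot_id
        return slot_id, True


slots = LRUSlots(max_size=2)
print(slots.choose("a"))  # (0, True)
print(slots.choose("b"))  # (1, True)
print(slots.choose("a"))  # (0, False)
print(slots.choose("c"))  # (1, True), since "b" was the least recently used
```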
@KevinZhang19870314 Here is how I implemented it:

```python
import threading
import traceback
from time import perf_counter
from typing import Dict, Union

from TTS_infer_pack.TTS import TTS, TTS_Config


class TTS_Wrapper(TTS):
    heat: float = 0.0
    usage_counter: int = 0
    usage_time: float = 0.0
    first_used_time: float = 0.0

    def __init__(self, configs: Union[dict, str, TTS_Config]):
        super().__init__(configs)
        self.first_used_time = perf_counter()

    def __hash__(self) -> int:
        # Used as the pool key; it changes after reset_heat()
        return hash(self.first_used_time)

    def run(self, *args, **kwargs):
        self.usage_counter += 1
        t0 = perf_counter()
        try:
            for result in super().run(*args, **kwargs):
                yield result
        finally:
            t1 = perf_counter()
            self.usage_time += t1 - t0
            # heat = uses per second since the weights were (re)loaded
            lifetime = t1 - self.first_used_time
            self.heat = self.usage_counter / lifetime if lifetime > 0 else float(self.usage_counter)

    def reset_heat(self):
        self.heat = 0.0
        self.usage_counter = 0
        self.usage_time = 0.0
        self.first_used_time = perf_counter()


class TTSInstancePool:
    def __init__(self, max_size: int):
        self.max_size: int = max_size
        # Limits how many instances can be checked out at the same time
        self.semaphore: threading.Semaphore = threading.Semaphore(max_size)
        # Protects self.pool and self.size
        self.pool_lock: threading.Lock = threading.Lock()
        # Idle instances, keyed by hash(instance)
        self.pool: Dict[int, TTS_Wrapper] = dict()
        # Total number of instances created (idle + checked out)
        self.size: int = 0

    async def acquire(self, configs: TTS_Config) -> TTS_Wrapper:
        # threading.Semaphore.acquire() blocks the calling thread
        # (including the event loop thread when this coroutine is awaited)
        self.semaphore.acquire()
        try:
            with self.pool_lock:
                # Find the best-matching instance; rank idle instances by heat
                indexed_key = None
                rank = []
                for key, tts_instance in self.pool.items():
                    if tts_instance.configs.vits_weights_path == configs.vits_weights_path \
                            and tts_instance.configs.t2s_weights_path == configs.t2s_weights_path:
                        indexed_key = key
                    rank.append((tts_instance.heat, key))
                rank.sort(key=lambda x: x[0])
                matched_key = None if len(rank) == 0 else rank[0][1]

                # If an existing instance matches, reuse it directly
                if indexed_key is not None:
                    return self._reuse_instance(indexed_key, configs)

                # If the pool is not full, create a new instance
                if self.size < self.max_size:
                    tts_instance = TTS_Wrapper(configs)
                    self.size += 1
                    return tts_instance

                # Otherwise repurpose the most suitable (coldest) idle instance
                return self._reuse_instance(matched_key, configs)
        except Exception:
            self.semaphore.release()
            traceback.print_exc()
            raise

    async def release(self, tts_instance: TTS_Wrapper):
        assert tts_instance is not None
        with self.pool_lock:
            key = hash(tts_instance)
            # Already returned: ignore a duplicate release
            if key in self.pool:
                return
            self.pool[key] = tts_instance
        self.semaphore.release()

    async def clear_pool(self):
        # Wait until every checked-out instance has been returned
        for _ in range(self.max_size):
            self.semaphore.acquire()
        with self.pool_lock:
            self.pool.clear()
            self.size = 0  # otherwise acquire() would never create new instances
        self.semaphore.release(self.max_size)  # release(n) requires Python 3.9+

    def _reuse_instance(self, instance_key: int, configs: TTS_Config) -> TTS_Wrapper:
        '''
        Reuse an existing instance.
        Args:
            instance_key: int, key of the existing instance
            configs: TTS_Config
        Returns:
            TTS_Wrapper: the reused TTS instance
        '''
        tts_instance = self.pool.pop(instance_key, None)
        if tts_instance is None:
            raise ValueError("Instance not found")
        tts_instance.configs.device = configs.device
        if tts_instance.configs.vits_weights_path != configs.vits_weights_path \
                or tts_instance.configs.t2s_weights_path != configs.t2s_weights_path:
            tts_instance.reset_heat()
        if tts_instance.configs.vits_weights_path != configs.vits_weights_path:
            tts_instance.init_vits_weights(configs.vits_weights_path)
            tts_instance.configs.vits_weights_path = configs.vits_weights_path
        if tts_instance.configs.t2s_weights_path != configs.t2s_weights_path:
            tts_instance.init_t2s_weights(configs.t2s_weights_path)
            tts_instance.configs.t2s_weights_path = configs.t2s_weights_path
        tts_instance.set_device(configs.device)
        return tts_instance
```
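The `heat` value above is meant to be "uses per second since the weights were (re)loaded", and when the pool is full the idle instance with the lowest heat is repurposed. A minimal worked sketch of that selection rule with made-up timings; `InstanceStats`, `compute_heat` and `coldest_key` are hypothetical names, not part of the code above.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class InstanceStats:
    """Hypothetical snapshot of one idle instance's usage statistics."""
    key: int
    usage_counter: int
    first_used_time: float  # clock value when the weights were (re)loaded


def compute_heat(stats: InstanceStats, now: float) -> float:
    # Uses per second since the weights were (re)loaded
    lifetime = now - stats.first_used_time
    return stats.usage_counter / lifetime if lifetime > 0 else float(stats.usage_counter)


def coldest_key(idle: List[InstanceStats], now: float) -> int:
    # The instance with the lowest heat is the cheapest one to repurpose
    return min(idle, key=lambda s: compute_heat(s, now)).key


idle = [
    InstanceStats(key=1, usage_counter=30, first_used_time=0.0),   # 30 uses in 60 s: 0.5/s
    InstanceStats(key=2, usage_counter=3, first_used_time=30.0),   # 3 uses in 30 s: 0.1/s
    InstanceStats(key=3, usage_counter=10, first_used_time=50.0),  # 10 uses in 10 s: 1.0/s
]
print(coldest_key(idle, now=60.0))  # 2
```

One design difference: this sketch evaluates heat at selection time, so an instance that sits idle keeps cooling down, whereas `TTS_Wrapper` above only updates `heat` when `run()` finishes.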
OK, added it, and it passes.
@KevinZhang19870314 OK, I made a few small changes to the code; please take another look.
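Every API call instantiates a TTS instance, and its `_init_models` method loads not only the weight-related models but also the base model. In the current implementation this means that whenever the weights need to be switched, the base model is reloaded every time as well. Given that, could the TTS instance be optimized so that the base model is loaded once when the API starts and never reloaded afterwards?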
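A minimal sketch of the "load the base model once at start-up" idea, assuming the base-model code has been pulled out of the TTS class (which the thread notes is required, along with locking). `BaseModels`, `get_base_models` and `LightTTS` are hypothetical; the cache ensures repeated calls return the same object, but concurrent *first* calls could both load it, so call `get_base_models()` once during API initialization.

```python
import threading
from functools import lru_cache


class BaseModels:
    """Hypothetical container for the shared base models."""
    loads = 0

    def __init__(self):
        BaseModels.loads += 1  # stands in for the expensive base-model load
        self.lock = threading.Lock()

    def encode(self, text: str) -> str:
        with self.lock:  # serialize access, since every instance shares these models
            return f"features({text})"


@lru_cache(maxsize=None)
def get_base_models() -> BaseModels:
    # Built on the first call (ideally at API start-up), then returned from the cache
    return BaseModels()


class LightTTS:
    """Hypothetical TTS that owns only its T2S/VITS weights and borrows the base models."""

    def __init__(self, t2s_path: str, vits_path: str):
        self.base = get_base_models()
        self.t2s_path = t2s_path
        self.vits_path = vits_path

    def run(self, text: str) -> str:
        feats = self.base.encode(text)
        return f"{self.t2s_path}+{self.vits_path}:{feats}"


get_base_models()  # load once at start-up
a = LightTTS("a.ckpt", "a.pth")
b = LightTTS("b.ckpt", "b.pth")
print(a.base is b.base, BaseModels.loads)  # True 1
```

The lock is the trade-off the thread points out: instances with different weights no longer reload the base model, but they serialize on it.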