# original implementation: https://github.com/svc-develop-team/so-vits-svc
from __future__ import annotations
import sys, logging, time, io, math, argparse, operator, numpy as np
from functools import partial, reduce
from pathlib import Path
from typing import Tuple, Optional, Type
from tinygrad import nn, dtypes, Tensor
from tinygrad.helpers import getenv, fetch
from tinygrad.nn.state import torch_load
from examples.vits import ResidualCouplingBlock, PosteriorEncoder, Encoder, ResBlock1, ResBlock2, LRELU_SLOPE, sequence_mask, split, get_hparams_from_file, load_checkpoint, weight_norm, HParams
from examples.sovits_helpers import preprocess
import soundfile

DEBUG = getenv("DEBUG")
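# f0 quantization constants: coarse pitch bins spaced on the mel scale, mel(f) = 1127 * ln(1 + f/700)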
F0_BIN = 256
F0_MAX = 1100.0
F0_MIN = 50.0
F0_MEL_MIN = 1127 * np.log(1 + F0_MIN / 700)
F0_MEL_MAX = 1127 * np.log(1 + F0_MAX / 700)
def download_if_not_present(file_path: Path, url: str):
  if not file_path.is_file(): fetch(url, file_path)
  return file_path
class SpeechEncoder:
  def __init__(self, hidden_dim, model:ContentVec): self.hidden_dim, self.model = hidden_dim, model
  def encode(self, wav: Tensor): raise NotImplementedError("implement me")
  @classmethod
  def load_from_pretrained(cls, checkpoint_path:str, checkpoint_url:str) -> SpeechEncoder:
    contentvec = ContentVec.load_from_pretrained(checkpoint_path, checkpoint_url)
    return cls(contentvec)
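# ContentVec256L9 taps encoder layer 9 and projects to 256 dims with final_proj; ContentVec768L12 returns the raw 768-dim layer-12 features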
class ContentVec256L9(SpeechEncoder):
  def __init__(self, model:ContentVec): super().__init__(hidden_dim=256, model=model)
  def encode(self, wav: Tensor):
    feats = wav
    if len(feats.shape) == 2: # stereo -> mono
      feats = feats.mean(-1)
    assert len(feats.shape) == 1, feats.shape
    feats = feats.reshape(1, -1)
    padding_mask = Tensor.zeros_like(feats).cast(dtypes.bool)
    logits = self.model.extract_features(feats.to(wav.device), padding_mask=padding_mask.to(wav.device), output_layer=9)
    feats = self.model.final_proj(logits[0])
    return feats.transpose(1,2)

class ContentVec768L12(SpeechEncoder):
  def __init__(self, model:ContentVec): super().__init__(hidden_dim=768, model=model)
  def encode(self, wav: Tensor):
    feats = wav
    if len(feats.shape) == 2: # stereo -> mono
      feats = feats.mean(-1)
    assert len(feats.shape) == 1, feats.shape
    feats = feats.reshape(1, -1)
    padding_mask = Tensor.zeros_like(feats).cast(dtypes.bool)
    logits = self.model.extract_features(feats.to(wav.device), padding_mask=padding_mask.to(wav.device), output_layer=12)
    return logits[0].transpose(1,2)
# original code for contentvec: https://github.com/auspicious3000/contentvec/
class ContentVec:
  # self.final_proj dims are hardcoded and depend on fairseq.data.dictionary Dictionary in the checkpoint. This param can't yet be loaded since there is no pickle for it. See with DEBUG=2.
  # This means that the ContentVec only works with the hubert weights used in all SVC models
  def __init__(self, cfg: HParams):
    self.feature_grad_mult, self.untie_final_proj = cfg.feature_grad_mult, cfg.untie_final_proj
    feature_enc_layers = eval(cfg.conv_feature_layers)
    self.embed = feature_enc_layers[-1][0]
    final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim
    self.feature_extractor = ConvFeatureExtractionModel(conv_layers=feature_enc_layers, dropout=0.0, mode=cfg.extractor_mode, conv_bias=cfg.conv_bias)
    self.post_extract_proj = nn.Linear(self.embed, cfg.encoder_embed_dim) if self.embed != cfg.encoder_embed_dim else None
    self.encoder = TransformerEncoder(cfg)
    self.layer_norm = nn.LayerNorm(self.embed)
    self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim * 1) if self.untie_final_proj else nn.Linear(cfg.encoder_embed_dim, final_dim)
    self.mask_emb = Tensor.uniform(cfg.encoder_embed_dim, dtype=dtypes.float32)
    self.label_embs_concat = Tensor.uniform(504, final_dim, dtype=dtypes.float32)
  def forward_features(self, source, padding_mask):
    if self.feature_grad_mult > 0:
      features = self.feature_extractor(source, padding_mask)
      if self.feature_grad_mult != 1.0: pass # training: GradMultiply.forward(features, self.feature_grad_mult)
    else:
      features = self.feature_extractor(source, padding_mask)
    return features
  def forward_padding_mask(self, features, padding_mask): # replaces original forward_padding_mask for batch inference
    lengths_org = tilde(padding_mask.cast(dtypes.bool)).cast(dtypes.int64).sum(1) # ensure it's bool for tilde
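    # 400 and 320 are the receptive field and hop (in samples) of the conv feature extractor: 50 Hz frames from 16 kHz audio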
    lengths = (lengths_org - 400).float().div(320).floor().cast(dtypes.int64) + 1 # intermediate float to divide
    padding_mask = lengths_to_padding_mask(lengths)
    return padding_mask
  def extract_features(self, source: Tensor, spk_emb:Tensor=None, padding_mask=None, ret_conv=False, output_layer=None, tap=False):
    features = self.forward_features(source, padding_mask)
    if padding_mask is not None:
      padding_mask = self.forward_padding_mask(features, padding_mask)
    features = features.transpose(1, 2)
    features = self.layer_norm(features)
    if self.post_extract_proj is not None:
      features = self.post_extract_proj(features)
    x, _ = self.encoder(features, spk_emb, padding_mask=padding_mask, layer=(None if output_layer is None else output_layer - 1), tap=tap)
    res = features if ret_conv else x
    return res, padding_mask
  @classmethod
  def load_from_pretrained(cls, checkpoint_path:str, checkpoint_url:str) -> ContentVec:
    download_if_not_present(checkpoint_path, checkpoint_url)
    cfg = load_fairseq_cfg(checkpoint_path)
    enc = cls(cfg.model)
    _ = load_checkpoint_enc(checkpoint_path, enc, None)
    logging.debug(f"{cls.__name__}: Loaded model with cfg={cfg}")
    return enc
class TransformerEncoder:
  def __init__(self, cfg: HParams):
    def make_conv() -> nn.Conv1d:
      layer = nn.Conv1d(self.embedding_dim, self.embedding_dim, kernel_size=cfg.conv_pos, padding=cfg.conv_pos // 2, groups=cfg.conv_pos_groups)
      std = math.sqrt(4 / (cfg.conv_pos * self.embedding_dim))
      layer.weight, layer.bias = Tensor.normal(*layer.weight.shape, std=std), Tensor.zeros(*layer.bias.shape)
      # for training: layer.weight needs to be weight_normed
      return layer
    self.dropout, self.embedding_dim, self.layer_norm_first, self.layerdrop, self.num_layers, self.num_layers_1 = cfg.dropout, cfg.encoder_embed_dim, cfg.layer_norm_first, cfg.encoder_layerdrop, cfg.encoder_layers, cfg.encoder_layers_1
    self.pos_conv, self.pos_conv_remove = [make_conv()], (1 if cfg.conv_pos % 2 == 0 else 0)
    self.layers = [
      TransformerEncoderLayer(self.embedding_dim, cfg.encoder_ffn_embed_dim, cfg.encoder_attention_heads, self.dropout, cfg.attention_dropout, cfg.activation_dropout, cfg.activation_fn, self.layer_norm_first, cond_layer_norm=(i >= cfg.encoder_layers))
      for i in range(cfg.encoder_layers + cfg.encoder_layers_1)
    ]
    self.layer_norm = nn.LayerNorm(self.embedding_dim)
    self.cond_layer_norm = CondLayerNorm(self.embedding_dim) if cfg.encoder_layers_1 > 0 else None
    # training: apply init_bert_params
  def __call__(self, x, spk_emb, padding_mask=None, layer=None, tap=False):
    x, layer_results = self.extract_features(x, spk_emb, padding_mask, layer, tap)
    if self.layer_norm_first and layer is None:
      x = self.cond_layer_norm(x, spk_emb) if (self.num_layers_1 > 0) else self.layer_norm(x)
    return x, layer_results
  def extract_features(self, x: Tensor, spk_emb: Tensor, padding_mask=None, tgt_layer=None, tap=False):
    if tgt_layer is not None: # and not self.training
      assert 0 <= tgt_layer < len(self.layers)
    if padding_mask is not None:
      # x[padding_mask] = 0
      assert padding_mask.shape == x.shape[:len(padding_mask.shape)] # first few dims of x must match padding_mask
      tmp_mask = padding_mask.unsqueeze(-1).repeat((1, 1, x.shape[-1]))
      tmp_mask = tilde(tmp_mask.cast(dtypes.bool))
      x = tmp_mask.where(x, 0)
    x_conv = self.pos_conv[0](x.transpose(1, 2))
    if self.pos_conv_remove > 0: x_conv = x_conv[:, :, : -self.pos_conv_remove]
    x_conv = x_conv.gelu().transpose(1, 2)
    x = (x + x_conv).transpose(0, 1) # B x T x C -> T x B x C
    if not self.layer_norm_first: x = self.layer_norm(x)
    x = x.dropout(p=self.dropout)
    layer_results = []
    r = None
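    # the first num_layers blocks are unconditioned; any blocks beyond that (encoder_layers_1) apply speaker-conditioned layer norm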
    for i, layer in enumerate(self.layers):
      if i < self.num_layers: # if (not self.training or (dropout_probability > self.layerdrop)) and (i < self.num_layers):
        assert not layer.cond_layer_norm
        x = layer(x, self_attn_padding_mask=padding_mask, need_weights=False)
      if tgt_layer is not None or tap:
        layer_results.append(x.transpose(0, 1))
      if i >= self.num_layers:
        assert layer.cond_layer_norm
        x = layer(x, emb=spk_emb, self_attn_padding_mask=padding_mask, need_weights=False)
      if i == tgt_layer:
        r = x
        break
    if r is not None:
      x = r
    x = x.transpose(0, 1) # T x B x C -> B x T x C
    return x, layer_results
class TransformerEncoderLayer:
  def __init__(self, embedding_dim=768, ffn_embedding_dim=3072, num_attention_heads=8, dropout=0.1, attention_dropout=0.1, activation_dropout=0.1, activation_fn="relu", layer_norm_first=False, cond_layer_norm=False):
    def get_activation_fn(activation):
      if activation == "relu": return Tensor.relu
      if activation == "gelu": return Tensor.gelu
      raise RuntimeError(f"activation function={activation} is not supported")
    self.embedding_dim, self.dropout, self.activation_dropout, self.layer_norm_first, self.num_attention_heads, self.cond_layer_norm, self.activation_fn = embedding_dim, dropout, activation_dropout, layer_norm_first, num_attention_heads, cond_layer_norm, get_activation_fn(activation_fn)
    self.self_attn = MultiHeadAttention(self.embedding_dim, self.num_attention_heads)
    self.self_attn_layer_norm = nn.LayerNorm(self.embedding_dim) if not cond_layer_norm else CondLayerNorm(self.embedding_dim)
    self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
    self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
    self.final_layer_norm = nn.LayerNorm(self.embedding_dim) if not cond_layer_norm else CondLayerNorm(self.embedding_dim)
  def __call__(self, x:Tensor, self_attn_mask:Tensor=None, self_attn_padding_mask:Tensor=None, emb:Tensor=None, need_weights=False):
    #self_attn_padding_mask = self_attn_padding_mask.reshape(x.shape[0], 1, 1, self_attn_padding_mask.shape[1]).expand(-1, self.num_attention_heads, -1, -1).reshape(x.shape[0] * self.num_attention_heads, 1, self_attn_padding_mask.shape[1]) if self_attn_padding_mask is not None else None
    assert self_attn_mask is None and self_attn_padding_mask is not None
    residual = x
    if self.layer_norm_first:
      x = self.self_attn_layer_norm(x) if not self.cond_layer_norm else self.self_attn_layer_norm(x, emb)
      x = self.self_attn(x=x, mask=self_attn_padding_mask)
      x = x.dropout(self.dropout)
      x = residual + x
      residual = x
      x = self.final_layer_norm(x) if not self.cond_layer_norm else self.final_layer_norm(x, emb)
      x = self.activation_fn(self.fc1(x))
      x = x.dropout(self.activation_dropout)
      x = self.fc2(x)
      x = x.dropout(self.dropout)
      x = residual + x
    else:
      x = self.self_attn(x=x, mask=self_attn_padding_mask)
      x = x.dropout(self.dropout)
      x = residual + x
      x = self.self_attn_layer_norm(x) if not self.cond_layer_norm else self.self_attn_layer_norm(x, emb)
      residual = x
      x = self.activation_fn(self.fc1(x))
      x = x.dropout(self.activation_dropout)
      x = self.fc2(x)
      x = x.dropout(self.dropout)
      x = residual + x
      x = self.final_layer_norm(x) if not self.cond_layer_norm else self.final_layer_norm(x, emb)
    return x
class MultiHeadAttention:
  def __init__(self, n_state, n_head):
    self.n_state, self.n_head = n_state, n_head
    self.q_proj, self.k_proj, self.v_proj, self.out_proj = [nn.Linear(n_state, n_state) for _ in range(4)]
  def __call__(self, x:Tensor, xa:Optional[Tensor]=None, mask:Optional[Tensor]=None):
    x = x.transpose(0, 1) # TxBxC -> BxTxC
    q, k, v = self.q_proj(x), self.k_proj(xa or x), self.v_proj(xa or x)
    q, k, v = [t.reshape(*q.shape[:2], self.n_head, -1) for t in (q, k, v)]
    wv = Tensor.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), None).transpose(1, 2).reshape(*x.shape[:2], -1)
    return self.out_proj(wv).transpose(0, 1) # BxTxC -> TxBxC
class ConvFeatureExtractionModel:
  def __init__(self, conv_layers, dropout=.0, mode="default", conv_bias=False):
    assert mode in {"default", "group_norm_masked", "layer_norm"}
    def block(n_in, n_out, k, stride, is_layer_norm=False, is_group_norm=False, conv_bias=False):
      def make_conv():
        conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
        conv.weight = Tensor.kaiming_normal(*conv.weight.shape)
        return conv
      assert not (is_layer_norm and is_group_norm), "layer norm and group norm are exclusive"
      if is_layer_norm:
        return [make_conv(), partial(Tensor.dropout, p=dropout), [partial(Tensor.transpose, dim0=-2, dim1=-1), nn.LayerNorm(dim, elementwise_affine=True), partial(Tensor.transpose, dim0=-2, dim1=-1)], Tensor.gelu]
      elif is_group_norm and mode == "default":
        return [make_conv(), partial(Tensor.dropout, p=dropout), nn.GroupNorm(dim, dim, affine=True), Tensor.gelu]
      elif is_group_norm and mode == "group_norm_masked":
        return [make_conv(), partial(Tensor.dropout, p=dropout), GroupNormMasked(dim, dim, affine=True), Tensor.gelu]
      else:
        return [make_conv(), partial(Tensor.dropout, p=dropout), Tensor.gelu]
    in_d, self.conv_layers, self.mode = 1, [], mode
    for i, cl in enumerate(conv_layers):
      assert len(cl) == 3, "invalid conv definition: " + str(cl)
      (dim, k, stride) = cl
      if i == 0: self.cl = cl
      self.conv_layers.append(block(in_d, dim, k, stride, is_layer_norm=(mode == "layer_norm"), is_group_norm=((mode == "default" or mode == "group_norm_masked") and i == 0), conv_bias=conv_bias))
      in_d = dim
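  # only the first conv block carries a (group) norm; in "group_norm_masked" mode it needs the padding mask to compute masked statistics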
  def __call__(self, x:Tensor, padding_mask:Tensor):
    x = x.unsqueeze(1) # BxT -> BxCxT
    if self.mode == "group_norm_masked":
      if padding_mask is not None:
        _, k, stride = self.cl
        lengths_org = tilde(padding_mask.cast(dtypes.bool)).cast(dtypes.int64).sum(1) # ensure padding_mask is bool for tilde
        lengths = (((lengths_org - k) / stride) + 1).floor().cast(dtypes.int64)
        padding_mask = tilde(lengths_to_padding_mask(lengths)).cast(dtypes.int64) # lengths_to_padding_mask returns bool tensor
      x = self.conv_layers[0][0](x) # conv
      x = self.conv_layers[0][1](x) # dropout
      x = self.conv_layers[0][2](x, padding_mask) # GroupNormMasked (padding_mask is numeric here)
      x = self.conv_layers[0][3](x) # gelu
    else:
      x = x.sequential(self.conv_layers[0]) # default
    for _, conv in enumerate(self.conv_layers[1:], start=1):
      conv = reduce(lambda a, b: operator.iconcat(a, b if isinstance(b, list) else [b]), conv, []) # flatten
      x = x.sequential(conv)
    return x
class CondLayerNorm: # https://github.com/auspicious3000/contentvec/blob/main/contentvec/modules/cond_layer_norm.py#L10
  def __init__(self, dim_last, eps=1e-5, dim_spk=256, elementwise_affine=True):
    self.dim_last, self.eps, self.dim_spk, self.elementwise_affine = dim_last, eps, dim_spk, elementwise_affine
    if self.elementwise_affine:
      self.weight_ln = nn.Linear(self.dim_spk, self.dim_last, bias=False)
      self.bias_ln = nn.Linear(self.dim_spk, self.dim_last, bias=False)
      self.weight_ln.weight, self.bias_ln.weight = Tensor.ones(*self.weight_ln.weight.shape), Tensor.zeros(*self.bias_ln.weight.shape)
  def __call__(self, x: Tensor, spk_emb: Tensor):
    axis = tuple(-1 - i for i in range(len(x.shape[1:])))
    x = x.layernorm(axis=axis, eps=self.eps)
    if not self.elementwise_affine: return x
    weights, bias = self.weight_ln(spk_emb), self.bias_ln(spk_emb)
    return weights * x + bias
class GroupNormMasked: # https://github.com/auspicious3000/contentvec/blob/d746688a32940f4bee410ed7c87ec9cf8ff04f74/contentvec/modules/fp32_group_norm.py#L16
  def __init__(self, num_groups, num_channels, eps=1e-5, affine=True):
    self.num_groups, self.num_channels, self.eps, self.affine = num_groups, num_channels, eps, affine
    self.weight, self.bias = (Tensor.ones(num_channels), Tensor.zeros(num_channels)) if self.affine else (None, None)
  def __call__(self, x:Tensor, mask:Tensor):
    bsz, n_c, length = x.shape
    assert n_c % self.num_groups == 0
    x = x.reshape(bsz, self.num_groups, n_c // self.num_groups, length)
    if mask is None: mask = Tensor.ones_like(x)
    else: mask = mask.reshape(bsz, 1, 1, length)
    x = x * mask
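    # masked statistics: rescale sums by the true lengths so the zeroed (padded) positions don't bias mean and variance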
    lengths = mask.sum(axis=3, keepdim=True)
    assert x.shape[2] == 1
    mean_ = x.mean(axis=3, keepdim=True)
    mean = mean_ * length / lengths
    var = (((x.std(axis=3, keepdim=True) ** 2) + mean_**2) * length / lengths - mean**2) + self.eps
    return x.add(-mean).div(var.sqrt()).reshape(bsz, n_c, length).mul(self.weight.reshape(1,-1,1)).add(self.bias.reshape(1,-1,1))
class Synthesizer:
  def __init__(self, spec_channels, segment_size, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels, ssl_dim, n_speakers, sampling_rate=44100, vol_embedding=False, n_flow_layer=4, **kwargs):
    self.spec_channels, self.inter_channels, self.hidden_channels, self.filter_channels, self.n_heads, self.n_layers, self.kernel_size, self.p_dropout, self.resblock, self.resblock_kernel_sizes, self.resblock_dilation_sizes, self.upsample_rates, self.upsample_initial_channel, self.upsample_kernel_sizes, self.segment_size, self.n_speakers, self.gin_channels, self.vol_embedding = spec_channels, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, segment_size, n_speakers, gin_channels, vol_embedding
    self.emb_g = nn.Embedding(n_speakers, gin_channels)
    if vol_embedding: self.emb_vol = nn.Linear(1, hidden_channels)
    self.pre = nn.Conv1d(ssl_dim, hidden_channels, kernel_size=5, padding=2)
    self.enc_p = TextEncoder(inter_channels, hidden_channels, kernel_size, n_layers, filter_channels=filter_channels, n_heads=n_heads, p_dropout=p_dropout)
    self.dec = Generator(sampling_rate, inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels)
    self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
    self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, n_flow_layer, gin_channels=gin_channels)
    self.emb_uv = nn.Embedding(vocab_size=2, embed_size=hidden_channels)
  def infer(self, c:Tensor, f0:Tensor, uv:Tensor, g:Tensor=None, noise_scale=0.35, seed=52468, vol=None) -> Tuple[Tensor, Tensor]:
    Tensor.manual_seed(getenv('SEED', seed))
    c_lengths = (Tensor.ones([c.shape[0]]) * c.shape[-1]).to(c.device)
    if len(g.shape) == 1: g = g.unsqueeze(0)
    g = self.emb_g(g).transpose(1, 2)
    x_mask = sequence_mask(c_lengths, c.shape[2]).unsqueeze(1).cast(c.dtype)
    vol = self.emb_vol(vol[:,:,None]).transpose(1,2) if vol is not None and self.vol_embedding else 0
    x = self.pre(c) * x_mask + self.emb_uv(uv.cast(dtypes.int64)).transpose(1, 2) + vol
    z_p, _, _, c_mask = self.enc_p.forward(x, x_mask, f0=self._f0_to_coarse(f0), noise_scale=noise_scale)
    z = self.flow.forward(z_p, c_mask, g=g, reverse=True)
    o = self.dec.forward(z * c_mask, g=g, f0=f0)
    return o, f0
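  # quantize f0 (Hz) into mel-spaced bins clamped to [1, F0_BIN-1] for the f0 embedding lookup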
  def _f0_to_coarse(self, f0 : Tensor):
    f0_mel = 1127 * (1 + f0 / 700).log()
    a = (F0_BIN - 2) / (F0_MEL_MAX - F0_MEL_MIN)
    b = F0_MEL_MIN * a - 1.
    f0_mel = (f0_mel > 0).where(f0_mel * a - b, f0_mel)
    f0_coarse = f0_mel.ceil().cast(dtype=dtypes.int64)
    f0_coarse = f0_coarse * (f0_coarse > 0)
    f0_coarse = f0_coarse + ((f0_coarse < 1) * 1)
    f0_coarse = f0_coarse * (f0_coarse < F0_BIN)
    f0_coarse = f0_coarse + ((f0_coarse >= F0_BIN) * (F0_BIN - 1))
    return f0_coarse
  @classmethod
  def load_from_pretrained(cls, config_path:str, config_url:str, weights_path:str, weights_url:str) -> Tuple[Synthesizer, HParams]:
    download_if_not_present(config_path, config_url)
    hps = get_hparams_from_file(config_path)
    download_if_not_present(weights_path, weights_url)
    net_g = cls(hps.data.filter_length // 2 + 1, hps.train.segment_size // hps.data.hop_length, **hps.model)
    _ = load_checkpoint(weights_path, net_g, None, skip_list=["f0_decoder"])
    logging.debug(f"{cls.__name__}: Loaded model with hps: {hps}")
    return net_g, hps
class TextEncoder:
  def __init__(self, out_channels, hidden_channels, kernel_size, n_layers, gin_channels=0, filter_channels=None, n_heads=None, p_dropout=None):
    self.out_channels, self.hidden_channels, self.kernel_size, self.n_layers, self.gin_channels = out_channels, hidden_channels, kernel_size, n_layers, gin_channels
    self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
    self.f0_emb = nn.Embedding(256, hidden_channels) # n_vocab = 256
    self.enc_ = Encoder(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout)
  def forward(self, x, x_mask, f0=None, noise_scale=1):
    x = x + self.f0_emb(f0).transpose(1, 2)
    x = self.enc_.forward(x * x_mask, x_mask)
    stats = self.proj(x) * x_mask
    m, logs = split(stats, self.out_channels, dim=1)
    z = (m + randn_like(m) * logs.exp() * noise_scale) * x_mask
    return z, m, logs, x_mask
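# nearest-neighbor upsampling along the last axis via repeat + reshape (integer scale factors only)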
class Upsample:
  def __init__(self, scale_factor):
    assert scale_factor % 1 == 0, "Only integer scale factor allowed."
    self.scale = int(scale_factor)
  def forward(self, x:Tensor):
    repeats = tuple([1] * len(x.shape) + [self.scale])
    new_shape = (*x.shape[:-1], x.shape[-1] * self.scale)
    return x.unsqueeze(-1).repeat(repeats).reshape(new_shape)
class SineGen:
  def __init__(self, samp_rate, harmonic_num=0, sine_amp=0.1, noise_std=0.003, voice_threshold=0, flag_for_pulse=False):
    self.sine_amp, self.noise_std, self.harmonic_num, self.sampling_rate, self.voiced_threshold, self.flag_for_pulse = sine_amp, noise_std, harmonic_num, samp_rate, voice_threshold, flag_for_pulse
    self.dim = self.harmonic_num + 1
  def _f02uv(self, f0): return (f0 > self.voiced_threshold).float() # generate uv signal
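  # build per-harmonic sine waves by integrating instantaneous frequency (cycles/sample) with a random initial phase, then taking sin(2*pi*phase)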
  def _f02sine(self, f0_values):
    def padDiff(x : Tensor): return (x.pad2d((0,0,-1,1)) - x).pad2d((0,0,0,-1))
    def mod(x: Tensor, n: int) -> Tensor: return x - n * x.div(n).floor() # this is what the % operator does in pytorch
    rad_values = mod((f0_values / self.sampling_rate), 1) # convert to F0 in rad
    rand_ini = Tensor.rand(f0_values.shape[0], f0_values.shape[2], device=f0_values.device) # initial phase noise
    # rand_ini[:, 0] = 0
    m = Tensor.ones(f0_values.shape[0]).unsqueeze(1).pad2d((0,f0_values.shape[2]-1,0,0)).cast(dtypes.bool)
    m = tilde(m)
    rand_ini = m.where(rand_ini, 0)
    # rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
    tmp = rad_values[:, 0, :] + rand_ini
    m = Tensor.ones(tmp.shape).pad2d((0,0,0,rad_values.shape[1]-1)).cast(dtypes.bool)
    m = tilde(m)
    tmp = tmp.unsqueeze(1).pad2d((0,0,0,rad_values.shape[1]-1))
    rad_values = m.where(rad_values, tmp)
    tmp_over_one = mod(rad_values.cumsum(1), 1)
    tmp_over_one_idx = padDiff(tmp_over_one) < 0
    cumsum_shift = Tensor.zeros_like(rad_values)
    # cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
    tmp_over_one_idx = (tmp_over_one_idx * -1.0).pad2d((0,0,1,0))
    cumsum_shift = tmp_over_one_idx
    sines = ((rad_values + cumsum_shift).cumsum(1) * 2 * np.pi).sin()
    return sines
  def forward(self, f0, upp=None):
    fn = f0.mul(Tensor([[list(range(1, self.harmonic_num + 2))]], dtype=dtypes.float32).to(f0.device))
    sine_waves = self._f02sine(fn) * self.sine_amp # generate sine waveforms
    uv = self._f02uv(f0) # generate uv signal
    noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
    noise = noise_amp * randn_like(sine_waves)
    sine_waves = sine_waves * uv + noise
    return sine_waves, uv, noise
class SourceHnNSF:
  def __init__(self, sampling_rate, harmonic_num=0, sine_amp=0.1, add_noise_std=0.003, voiced_threshold=0):
    self.sine_amp, self.noise_std = sine_amp, add_noise_std
    self.l_sin_gen = SineGen(sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshold)
    self.l_linear = nn.Linear(harmonic_num + 1, 1)
  def forward(self, x, upp=None):
    sine_waves, uv, _ = self.l_sin_gen.forward(x, upp)
    sine_merge = self.l_linear(sine_waves.cast(self.l_linear.weight.dtype)).tanh()
    noise = randn_like(uv) * self.sine_amp / 3
    return sine_merge, noise, uv

# most of the hifigan in standard vits is reused here, but we need to upsample and construct a harmonic source from f0
class Generator:
  def __init__(self, sampling_rate, inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels):
    self.sampling_rate, self.inter_channels, self.resblock, self.resblock_kernel_sizes, self.resblock_dilation_sizes, self.upsample_rates, self.upsample_initial_channel, self.upsample_kernel_sizes, self.gin_channels = sampling_rate, inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels
    self.num_kernels, self.num_upsamples = len(resblock_kernel_sizes), len(upsample_rates)
    self.conv_pre = nn.Conv1d(inter_channels, upsample_initial_channel, 7, 1, padding=3)
    self.f0_upsamp = Upsample(scale_factor=np.prod(upsample_rates))
    self.m_source = SourceHnNSF(sampling_rate, harmonic_num=8)
    resblock = ResBlock1 if resblock == '1' else ResBlock2
    self.ups, self.noise_convs, self.resblocks = [], [], []
    for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
      c_cur = upsample_initial_channel // (2 ** (i + 1))
      self.ups.append(nn.ConvTranspose1d(upsample_initial_channel // (2 ** i), c_cur, k, u, padding=(k - u) // 2))
      stride_f0 = int(np.prod(upsample_rates[i + 1:]))
      self.noise_convs.append(nn.Conv1d(1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0 + 1) // 2) if (i + 1 < len(upsample_rates)) else nn.Conv1d(1, c_cur, kernel_size=1))
    for i in range(len(self.ups)):
      ch = upsample_initial_channel // (2 ** (i + 1))
      for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
        self.resblocks.append(resblock(ch, k, d))
    self.conv_post = nn.Conv1d(ch, 1, 7, 1, padding=3)
    if gin_channels != 0: self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
    self.upp = np.prod(upsample_rates)
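  # at every upsampling stage the upsampled features are mixed with the downsampled harmonic source before entering the resblocks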
  def forward(self, x, f0, g=None):
    f0 = self.f0_upsamp.forward(f0[:, None]).transpose(1, 2) # bs,n,t
    har_source, _, _ = self.m_source.forward(f0, self.upp)
    har_source = har_source.transpose(1, 2)
    x = self.conv_pre(x)
    if g is not None: x = x + self.cond(g)
    for i in range(self.num_upsamples):
      x, xs = self.ups[i](x.leakyrelu(LRELU_SLOPE)), None
      x_source = self.noise_convs[i](har_source)
      x = x + x_source
      for j in range(self.num_kernels):
        if xs is None: xs = self.resblocks[i * self.num_kernels + j].forward(x)
        else: xs += self.resblocks[i * self.num_kernels + j].forward(x)
      x = xs / self.num_kernels
    return self.conv_post(x.leakyrelu()).tanh()
# **** helpers ****
def randn_like(x:Tensor) -> Tensor: return Tensor.randn(*x.shape, dtype=x.dtype).to(device=x.device)

def tilde(x: Tensor) -> Tensor:
  if x.dtype == dtypes.bool: return (1 - x).cast(dtypes.bool)
  return (x + 1) * -1 # matches pytorch's ~ for non-bool dtypes (two's complement: ~x == -x - 1)
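# build a (bsz, max_len) bool mask that is True at padded positions (arange >= length)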
def lengths_to_padding_mask(lens:Tensor) -> Tensor:
  bsz, max_lens = lens.shape[0], lens.max().numpy().item()
  mask = Tensor.arange(max_lens).to(lens.device).reshape(1, max_lens)
  mask = mask.expand(bsz, -1) >= lens.reshape(bsz, 1).expand(-1, max_lens)
  return mask.cast(dtypes.bool)
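# nearest-left interpolation of content columns onto target_len frames, e.g. src_len=3, target_len=5 -> columns [0, 0, 1, 1, 2]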
def repeat_expand_2d_left(content, target_len): # content : [h, t]
  src_len = content.shape[-1]
  temp = np.arange(src_len + 1) * target_len / src_len
  current_pos, cols = 0, []
  for i in range(target_len):
    if i >= temp[current_pos + 1]:
      current_pos += 1
    cols.append(content[:, current_pos])
  return Tensor.stack(*cols).transpose(0, 1)

def load_fairseq_cfg(checkpoint_path):
  assert Path(checkpoint_path).is_file()
  state = torch_load(checkpoint_path)
  cfg = state["cfg"] if ("cfg" in state and state["cfg"] is not None) else None
  if cfg is None: raise RuntimeError(f"No cfg exists in state, keys = {state.keys()}")
  return HParams(**cfg)
def load_checkpoint_enc(checkpoint_path, model: ContentVec, optimizer=None, skip_list=[]):
  assert Path(checkpoint_path).is_file()
  start_time = time.time()
  checkpoint_dict = torch_load(checkpoint_path)
  saved_state_dict = checkpoint_dict['model']
  weight_g, weight_v, parent = None, None, None
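  # weight-normed params are stored as weight_g/weight_v pairs; recombine them into the parent's weight via weight_norm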
  for key, v in saved_state_dict.items():
    if any(layer in key for layer in skip_list): continue
    try:
      obj, skip = model, False
      for k in key.split('.'):
        if k.isnumeric(): obj = obj[int(k)]
        elif isinstance(obj, dict): obj = obj[k]
        else:
          if k in ["weight_g", "weight_v"]:
            parent, skip = obj, True
            if k == "weight_g": weight_g = v
            else: weight_v = v
          if not skip:
            parent = obj
            obj = getattr(obj, k)
      if weight_g is not None and weight_v is not None:
        setattr(obj, "weight_g", weight_g.numpy())
        setattr(obj, "weight_v", weight_v.numpy())
        obj, v = getattr(parent, "weight"), weight_norm(weight_v, weight_g, 0)
        weight_g, weight_v, parent, skip = None, None, None, False
      if not skip and obj.shape == v.shape:
        if "feature_extractor" in key and (isinstance(parent, nn.GroupNorm) or isinstance(parent, nn.LayerNorm)): # cast
          obj.assign(v.to(obj.device).float())
        else:
          obj.assign(v.to(obj.device))
      elif not skip: logging.error(f"MISMATCH SHAPE IN {key}, {obj.shape} {v.shape}")
    except Exception as e: raise e
  logging.info(f"Loaded checkpoint '{checkpoint_path}' in {time.time() - start_time:.4f}s")
  return model, optimizer
def pad_array(arr, target_length):
  current_length = arr.shape[0]
  if current_length >= target_length: return arr
  pad_width = target_length - current_length
  pad_left = pad_width // 2
  pad_right = pad_width - pad_left
  return np.pad(arr, (pad_left, pad_right), 'constant', constant_values=(0, 0))
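# yield successive n-sized chunks; chunks after the first are extended pre samples to the left, so consecutive chunks overlap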
def split_list_by_n(list_collection, n, pre=0):
  for i in range(0, len(list_collection), n):
    yield list_collection[i - pre if i - pre >= 0 else i: i + n]

def get_sid(spk2id:HParams, speaker:str) -> Tensor:
  speaker_id = spk2id[speaker]
  if not speaker_id and type(speaker) is int:
    if len(spk2id.__dict__) >= speaker: speaker_id = speaker
  if speaker_id is None: raise RuntimeError(f"speaker={speaker} not in the speaker list")
  return Tensor([int(speaker_id)], dtype=dtypes.int64).unsqueeze(0)

def get_encoder(ssl_dim) -> Type[SpeechEncoder]:
  if ssl_dim == 256: return ContentVec256L9
  if ssl_dim == 768: return ContentVec768L12
  raise ValueError(f"unsupported ssl_dim={ssl_dim}")
#########################################################################################
# CODE: https://github.com/svc-develop-team/so-vits-svc
#########################################################################################
# CONTENTVEC:
# CODE: https://github.com/auspicious3000/contentvec
# PAPER: https://arxiv.org/abs/2204.09224
#########################################################################################
# INSTALLATION: dependencies are for preprocessing and loading/saving audio.
# pip3 install soundfile librosa praat-parselmouth
#########################################################################################
# EXAMPLE USAGE:
# python3 examples/so_vits_svc.py --model tf2spy --file ~/recording.wav
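# pitch shift up one octave while converting (--tran is in semitones):
# python3 examples/so_vits_svc.py --model drake --file ~/recording.wav --tran 12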
#########################################################################################
# DEMO USAGE (uses audio sample from LJ-Speech):
# python3 examples/so_vits_svc.py --model saul_goodman
#########################################################################################
SO_VITS_SVC_PATH = Path(__file__).parents[1] / "weights/So-VITS-SVC"
VITS_MODELS = { # config_path, weights_path, config_url, weights_url
  "saul_goodman" : (SO_VITS_SVC_PATH / "config_saul_gman.json", SO_VITS_SVC_PATH / "pretrained_saul_gman.pth", "https://huggingface.co/Amo/so-vits-svc-4.0_GA/resolve/main/ModelsFolder/Saul_Goodman_80000/config.json", "https://huggingface.co/Amo/so-vits-svc-4.0_GA/resolve/main/ModelsFolder/Saul_Goodman_80000/G_80000.pth"),
  "drake" : (SO_VITS_SVC_PATH / "config_drake.json", SO_VITS_SVC_PATH / "pretrained_drake.pth", "https://huggingface.co/jaspa/so-vits-svc/resolve/main/aubrey/config_aubrey.json", "https://huggingface.co/jaspa/so-vits-svc/resolve/main/aubrey/pretrained_aubrey.pth"),
  "cartman" : (SO_VITS_SVC_PATH / "config_cartman.json", SO_VITS_SVC_PATH / "pretrained_cartman.pth", "https://huggingface.co/marcoc2/so-vits-svc-4.0-models/resolve/main/EricCartman/config.json", "https://huggingface.co/marcoc2/so-vits-svc-4.0-models/resolve/main/EricCartman/G_10200.pth"),
  "tf2spy" : (SO_VITS_SVC_PATH / "config_tf2spy.json", SO_VITS_SVC_PATH / "pretrained_tf2spy.pth", "https://huggingface.co/Amo/so-vits-svc-4.0_GA/resolve/main/ModelsFolder/TF2_spy_60k/config.json", "https://huggingface.co/Amo/so-vits-svc-4.0_GA/resolve/main/ModelsFolder/TF2_spy_60k/G_60000.pth"),
  "tf2heavy" : (SO_VITS_SVC_PATH / "config_tf2heavy.json", SO_VITS_SVC_PATH / "pretrained_tf2heavy.pth", "https://huggingface.co/Amo/so-vits-svc-4.0_GA/resolve/main/ModelsFolder/TF2_heavy_100k/config.json", "https://huggingface.co/Amo/so-vits-svc-4.0_GA/resolve/main/ModelsFolder/TF2_heavy_100k/G_100000.pth"),
  "lady_gaga" : (SO_VITS_SVC_PATH / "config_gaga.json", SO_VITS_SVC_PATH / "pretrained_gaga.pth", "https://huggingface.co/marcoc2/so-vits-svc-4.0-models/resolve/main/LadyGaga/config.json", "https://huggingface.co/marcoc2/so-vits-svc-4.0-models/resolve/main/LadyGaga/G_14400.pth")
}
ENCODER_MODELS = { # weights_path, weights_url
  "contentvec": (SO_VITS_SVC_PATH / "contentvec_checkpoint.pt", "https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/hubert_base.pt")
}
ENCODER_MODEL = "contentvec"
DEMO_PATH, DEMO_URL = Path(__file__).parents[1] / "temp/LJ037-0171.wav", "https://keithito.com/LJ-Speech-Dataset/LJ037-0171.wav"
if __name__=="__main__":
  logging.basicConfig(stream=sys.stdout, level=(logging.INFO if DEBUG < 1 else logging.DEBUG))
  parser = argparse.ArgumentParser()
  parser.add_argument("-m", "--model", default=None, help=f"Specify the model to use. All supported models: {VITS_MODELS.keys()}", required=True)
  parser.add_argument("-f", "--file", default=DEMO_PATH, help="Specify the path of the input file")
  parser.add_argument("--out_dir", default=str(Path(__file__).parents[1] / "temp"), help="Specify the output path.")
  parser.add_argument("--out_path", default=None, help="Specify the full output path. Overrides the --out_dir and --base_name parameters.")
  parser.add_argument("--base_name", default="test", help="Specify the base of the output file name. Default is 'test'.")
  parser.add_argument("--speaker", default=None, help="If not specified, the first available speaker is chosen. Usually there is only one speaker per model.")
  parser.add_argument("--noise_scale", type=float, default=0.4)
  parser.add_argument("--tran", type=float, default=0.0, help="Pitch shift, supports positive and negative (semitone) values. Default 0.0")
  parser.add_argument("--pad_seconds", type=float, default=0.5)
  parser.add_argument("--lg_num", type=float, default=0.0)
  parser.add_argument("--clip_seconds", type=float, default=0.0)
  parser.add_argument("--slice_db", type=int, default=-40)
  args = parser.parse_args()
  vits_model = args.model
  encoder_location, vits_location = ENCODER_MODELS[ENCODER_MODEL], VITS_MODELS[vits_model]
  Tensor.no_grad, Tensor.training = True, False
  # Get Synthesizer and ContentVec
  net_g, hps = Synthesizer.load_from_pretrained(vits_location[0], vits_location[2], vits_location[1], vits_location[3])
  encoder_cls = get_encoder(hps.model.ssl_dim) # don't shadow the vits Encoder imported above
  encoder = encoder_cls.load_from_pretrained(encoder_location[0], encoder_location[1])
  # model config args
  target_sample, spk2id, hop_length = hps.data.sampling_rate, hps.spk, hps.data.hop_length
  vol_embedding = hps.model.vol_embedding if hasattr(hps.model, "vol_embedding") and hps.model.vol_embedding is not None else False
  # args
  slice_db, clip_seconds, lg_num, pad_seconds, tran, noise_scale, audio_path = args.slice_db, args.clip_seconds, args.lg_num, args.pad_seconds, args.tran, args.noise_scale, args.file
  speaker = args.speaker if args.speaker is not None else list(hps.spk.__dict__.keys())[0]
  ### Loading audio and slicing ###
  if audio_path == DEMO_PATH: download_if_not_present(DEMO_PATH, DEMO_URL)
  assert Path(audio_path).is_file() and Path(audio_path).suffix == ".wav"
  chunks = preprocess.cut(audio_path, db_thresh=slice_db)
  audio_data, audio_sr = preprocess.chunks2audio(audio_path, chunks)
  per_size = int(clip_seconds * audio_sr)
  lg_size = int(lg_num * audio_sr)
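  # per_size: samples per sub-clip when --clip_seconds is set; lg_size: extra left-overlap samples (--lg_num) between sub-clips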
  ### Infer per slice ###
  global_frame = 0
  audio = []
  for (slice_tag, data) in audio_data:
    print(f"\n====segment start, {round(len(data) / audio_sr, 3)}s====")
    length = int(np.ceil(len(data) / audio_sr * target_sample))
    if slice_tag:
      print("empty segment")
      _audio = np.zeros(length)
      audio.extend(list(pad_array(_audio, length)))
      global_frame += length // hop_length
      continue
    datas = [data] if per_size == 0 else split_list_by_n(data, per_size, lg_size)
    for k, dat in enumerate(datas):
      per_length = int(np.ceil(len(dat) / audio_sr * target_sample)) if clip_seconds != 0 else length
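      # pad both ends before inference and trim the same amount afterwards to avoid boundary artifacts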
      pad_len = int(audio_sr * pad_seconds)
      dat = np.concatenate([np.zeros([pad_len]), dat, np.zeros([pad_len])])
      raw_path = io.BytesIO()
      soundfile.write(raw_path, dat, audio_sr, format="wav")
      raw_path.seek(0)
      ### Infer START ###
      wav, sr = preprocess.load_audiofile(raw_path)
      wav = preprocess.sinc_interp_resample(wav, sr, target_sample)[0]
      wav16k, f0, uv = preprocess.get_unit_f0(wav, tran, hop_length, target_sample)
      sid = get_sid(spk2id, speaker)
      n_frames = f0.shape[1]
      # ContentVec infer
      start = time.time()
      c = encoder.encode(wav16k)
      c = repeat_expand_2d_left(c.squeeze(0).realize(), f0.shape[1]) # interpolate speech encoding to match f0
      c = c.unsqueeze(0).realize()
      enc_time = time.time() - start
      # VITS infer
      vits_start = time.time()
      out_audio, f0 = net_g.infer(c, f0=f0, uv=uv, g=sid, noise_scale=noise_scale, vol=None)
      out_audio = out_audio[0,0].float().realize()
      vits_time = time.time() - vits_start
      infer_time = time.time() - start
      logging.info("total infer time:{:.2f}s, speech_enc time:{:.2f}s, vits time:{:.2f}s".format(infer_time, enc_time, vits_time))
      ### Infer END ###
      global_frame += n_frames
      _audio = out_audio.numpy()
      pad_len = int(target_sample * pad_seconds)
      _audio = _audio[pad_len:-pad_len]
      _audio = pad_array(_audio, per_length)
      audio.extend(list(_audio))
  audio = np.array(audio)
  out_path = Path(args.out_path or Path(args.out_dir) / f"{args.model}_spk_{speaker}_{args.base_name}.wav")
  out_path.parent.mkdir(parents=True, exist_ok=True)
  soundfile.write(out_path, audio, target_sample, format="wav")
  logging.info(f"Saved audio output to {out_path}")
  621. logging.info(f"Saved audio output to {out_path}")