185 lines
No EOL
6.2 KiB
Python
185 lines
No EOL
6.2 KiB
Python
import os
|
|
import time
|
|
import json
|
|
import logging
|
|
import threading
|
|
#from requests_threads import AsyncSession
|
|
from hashlib import sha1
|
|
import asyncio
|
|
from tornado.httpclient import AsyncHTTPClient, HTTPRequest
|
|
|
|
# Project-wide logger and this module's child logger.
mainLogger = logging.getLogger("hugvey")
# Dotted name == mainLogger.getChild("voice"); both resolve to the same Logger object.
logger = logging.getLogger("hugvey.voice")


class VoiceStorage(object):
    """
    Store & keep voices that are not part of the story json.

    Audio is cached on disk at
    ``<cache_dir>/<lang_code>/<static|variable>/<id[:2]>/<id>.wav``, where
    ``id`` is a SHA-1 over the language code, that language's ``token`` and
    the text. Concurrent requests for the same id are de-duplicated: only the
    first coroutine calls the fetcher; the others wait for it to finish.
    """

    def __init__(self, cache_dir, languageConfig):
        """
        :param cache_dir: existing directory in which voice files are cached.
        :param languageConfig: mapping of language code -> config dict. Each
            config's ``type`` selects the fetcher via ``VoiceFetcher.getClass``;
            its ``token`` is used by :meth:`getId`.
        :raises Exception: if ``cache_dir`` does not exist.
        """
        self.cache_dir = cache_dir
        if not os.path.exists(self.cache_dir):
            raise Exception(f"Cache dir does not exists: {self.cache_dir}")

        # id -> asyncio.Event for downloads currently in flight. The event is
        # set (and the entry removed) once the download finishes, either way.
        self.pendingRequests = {}
        self.languages = languageConfig
        self.fetchers = {}

        for lang in self.languages:
            cls = VoiceFetcher.getClass(self.languages[lang]['type'])
            self.fetchers[lang] = cls(self.languages[lang])

    def getId(self, lang_code, text):
        """
        Get a unique id based on text and the voice token.

        So changing the voice or text triggers a re-download.

        :raises KeyError: if ``lang_code`` is not configured.
        """
        key = f"{lang_code}:{self.languages[lang_code]['token']}:{text}"
        return sha1(key.encode()).hexdigest()

    def getFilename(self, lang_code, text, isVariable=False):
        """
        Return the cache path for ``text`` in ``lang_code`` (no directories are created).

        Files are sharded by the first two hex characters of the id and split
        into ``static`` / ``variable`` sub-directories by ``isVariable``.
        """
        voice_id = self.getId(lang_code, text)
        subdir = 'variable' if isVariable else 'static'
        storageDir = os.path.join(self.cache_dir, lang_code, subdir, voice_id[:2])
        return os.path.join(storageDir, f"{voice_id}.wav")

    async def requestFile(self, lang_code, text, isVariable=False) -> str:
        """
        Return the path of the cached voice file for ``text``, fetching it first if needed.

        :returns: the file path, or ``None`` when fetching or writing failed
            (the error is logged), including when a concurrent request for the
            same text that this call waited on failed.
        """
        voice_id = self.getId(lang_code, text)
        fn = self.getFilename(lang_code, text, isVariable)

        if os.path.exists(fn):
            return fn

        pending = self.pendingRequests.get(voice_id)
        if pending is not None and not pending.is_set():
            # Another coroutine is already fetching this voice; wait for it.
            await pending.wait()
            return fn if os.path.exists(fn) else None

        dirname = os.path.dirname(fn)
        if not os.path.exists(dirname):
            logger.debug(f"create directory for file: {dirname}")
            os.makedirs(dirname, exist_ok=True)

        event = asyncio.Event()
        self.pendingRequests[voice_id] = event
        try:
            try:
                contents = await self.fetchers[lang_code].requestVoiceFile(text)
            except Exception as e:
                logger.exception(e)
                return None
            return self._writeFile(fn, contents, lang_code, text)
        finally:
            # Always release waiters -- also on write errors and cancellation --
            # otherwise they would block forever. Then drop the finished entry
            # so the dict does not grow without bound.
            event.set()
            if self.pendingRequests.get(voice_id) is event:
                del self.pendingRequests[voice_id]

    def _writeFile(self, fn, contents, lang_code, text):
        """Atomically write ``contents`` to ``fn``; return ``fn``, or ``None`` on failure."""
        # Write to a sibling temp file and rename it into place, so a crash or
        # error mid-write never leaves a truncated .wav that os.path.exists()
        # would afterwards treat as a valid cached file.
        tmp_fn = f"{fn}.{os.getpid()}.tmp"
        logger.debug(f"Write file for {lang_code}: {text}")
        try:
            with open(tmp_fn, "wb") as f:
                f.write(contents)  # TypeError if the fetcher returned non-bytes (e.g. None)
            os.replace(tmp_fn, fn)
        except (OSError, TypeError) as e:
            logger.exception(e)
            try:
                os.remove(tmp_fn)
            except OSError:
                pass  # best effort: the temp file may never have been created
            return None
        return fn


class VoiceFetcher():
    """
    Base class for text-to-speech backends.

    Subclasses implement :meth:`requestVoiceFile`; :meth:`getClass` maps a
    language config's ``type`` to the concrete subclass.
    """

    def __init__(self, config):
        # The per-language config dict (VoiceStorage passes languageConfig[lang]).
        self.config = config

    async def requestVoiceFile(self, text):
        """Return the synthesised audio for ``text`` as bytes. Must be overridden."""
        # Raise instead of silently returning None, so a missing override shows
        # up as a clear fetch error rather than a confusing file-write failure.
        raise NotImplementedError(f"{type(self).__name__} does not implement requestVoiceFile")

    @classmethod
    def getClass(cls, type):
        """
        Return the fetcher class for voice ``type``.

        :param type: ``"lyrebird"`` or ``"ms"``.
        :raises Exception: for any other value.
        """
        # NOTE: the parameter name shadows the builtin ``type``; kept so that
        # keyword callers (getClass(type=...)) keep working.
        if type == "lyrebird":
            return LyrebirdVoiceFetcher
        if type == "ms":
            return MSVoiceFetcher
        raise Exception(f"Unknown voice type: {type}")


class LyrebirdVoiceFetcher(VoiceFetcher):
    """Fetch voice files from the Lyrebird avatar ``generate`` API."""

    async def requestVoiceFile(self, text):
        """
        POST ``text`` to the Lyrebird generate endpoint and return the response body.

        Authenticates with ``self.config['token']`` as a bearer token.

        :returns: the raw response body bytes (presumably audio -- TODO confirm format).
        :raises Exception: on transport/HTTP errors or a non-200 status code.
        """
        http_client = AsyncHTTPClient()
        request = HTTPRequest(
            method="POST",
            url="https://avatar.lyrebird.ai/api/v0/generate",
            body=json.dumps({"text": text}),
            headers={"authorization": f"Bearer {self.config['token']}"}
        )
        try:
            response = await http_client.fetch(request)
        except Exception:
            logger.critical(f"Exception when getting Lyrebird voice file: POST {request.url} body: {request.body}")
            raise
        finally:
            # Close the client on success as well as on error.
            http_client.close()

        if response.code != 200:
            raise Exception(f"No proper response! {response.code}")

        return response.body


class MSVoiceFetcher(VoiceFetcher):
    """Fetch voice files from the Microsoft text-to-speech REST API (SSML in, RIFF PCM out)."""

    # The access token expires after 10 min; refresh after 8 to be on the safe side.
    TOKEN_MAX_AGE = 8 * 60

    def __init__(self, config):
        super().__init__(config)
        # time.monotonic() of the last successful token fetch (0 = never).
        self.timer = 0
        self.access_token = None

    async def getToken(self):
        """
        Return a bearer access token, fetching a new one when none is cached
        or the cached one is older than ``TOKEN_MAX_AGE`` seconds.

        Uses config keys ``token`` (subscription key) and ``token_url``.
        """
        # Monotonic clock: wall-clock jumps must not extend or cut the token's life.
        now = time.monotonic()
        if self.access_token is not None and now - self.timer <= self.TOKEN_MAX_AGE:
            return self.access_token

        headers = {
            'Ocp-Apim-Subscription-Key': self.config['token']
        }
        http_client = AsyncHTTPClient()
        request = HTTPRequest(
            method="POST",
            url=self.config['token_url'],
            headers=headers,
            # POST without a body; tornado rejects that unless explicitly allowed.
            allow_nonstandard_methods=True
        )
        # Do not log the headers: they carry the subscription key.
        logger.debug(f"{request.method} {request.url}")
        try:
            response = await http_client.fetch(request)
        finally:
            http_client.close()

        self.access_token = response.body.decode()
        self.timer = time.monotonic()
        return self.access_token

    async def requestVoiceFile(self, text):
        """
        Synthesise ``text`` and return the audio bytes.

        Uses config keys ``voice_url``, ``ms_lang``, ``ms_gender`` and ``ms_name``;
        requests 24 kHz 16-bit mono PCM in a RIFF container.

        :raises Exception: on transport/HTTP errors or a non-200 status code.
        """
        headers = {
            'Authorization': 'Bearer ' + await self.getToken(),
            'Content-Type': 'application/ssml+xml',
            'X-Microsoft-OutputFormat': 'riff-24khz-16bit-mono-pcm',
        }
        # NOTE(review): ``text`` is inserted into the SSML verbatim, so a literal
        # ``&`` or ``<`` in plain text yields invalid XML. This may be intentional
        # (to allow inline SSML markup in stories) -- confirm before escaping.
        body = f"""<speak version='1.0' xml:lang='{self.config['ms_lang']}'><voice xml:lang='{self.config['ms_lang']}' xml:gender='{self.config['ms_gender']}'
name='{self.config['ms_name']}'>
{text}
</voice></speak>"""
        logger.debug(body)

        http_client = AsyncHTTPClient()
        request = HTTPRequest(
            method="POST",
            url=self.config['voice_url'],
            headers=headers,
            body=body
        )
        try:
            response = await http_client.fetch(request)
        except Exception:
            logger.critical(f"Exception when getting Microsoft voice file: POST {request.url} body: {request.body}")
            raise
        finally:
            http_client.close()

        if response.code != 200:
            raise Exception(f"No proper response! {response.code}")

        return response.body