mirror of
https://github.com/malarinv/jasper-asr.git
synced 2026-03-09 19:02:35 +00:00
1. integrated data generator using google tts
2. added training script
This commit is contained in:
1
jasper/data_utils/__init__.py
Normal file
1
jasper/data_utils/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
65
jasper/data_utils/generator.py
Normal file
65
jasper/data_utils/generator.py
Normal file
@@ -0,0 +1,65 @@
|
||||
# import io
|
||||
# import sys
|
||||
# import json
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from .utils import random_pnr_generator, manifest_str
|
||||
from .tts.googletts import GoogleTTS
|
||||
from tqdm import tqdm
|
||||
import random
|
||||
|
||||
# Configure root logging once at import time so every data_utils module
# shares the same timestamped log format.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def generate_asr_data(output_dir, count, sample_rate=24000, num_channels=1):
    """Synthesize `count` random PNR-code utterances with Google TTS.

    Writes one WAV file per code under ``<output_dir>/pnr_data`` and a
    JSON-lines ASR manifest at ``<output_dir>/pnr_data.json``.

    Arguments:
        output_dir {Path} -- directory that receives the wav dir and manifest
        count {int} -- number of datapoints to generate
        sample_rate {int} -- requested TTS sample rate in Hz (default 24000)
        num_channels {int} -- requested channel count (default mono)
    """
    WAV_HEADER_BYTES = 44  # standard RIFF/WAVE header size of the TTS output
    SAMPLE_WIDTH = 2  # LINEAR16 encoding -> 2 bytes per sample
    google_voices = GoogleTTS.voice_list()
    gtts = GoogleTTS()
    wav_dir = output_dir / Path("pnr_data")
    wav_dir.mkdir(parents=True, exist_ok=True)
    asr_manifest = output_dir / Path("pnr_data").with_suffix(".json")
    with asr_manifest.open("w") as mf:
        for pnr_code in tqdm(random_pnr_generator(count)):
            # SSML "verbatim" forces letter-by-letter reading of the code.
            tts_code = (
                f'<speak><say-as interpret-as="verbatim">{pnr_code}</say-as></speak>'
            )
            # Randomize the voice per sample for speaker variety.
            param = random.choice(google_voices)
            param["sample_rate"] = sample_rate
            param["num_channels"] = num_channels
            wav_data = gtts.text_to_speech(text=tts_code, params=param)
            # Duration = payload bytes / (bytes per sample * rate).
            # NOTE(review): assumes a fixed 44-byte header and mono LINEAR16
            # output from the TTS service -- confirm against GoogleTTS.
            audio_dur = (len(wav_data) - WAV_HEADER_BYTES) / (
                SAMPLE_WIDTH * sample_rate
            )
            pnr_af = wav_dir / Path(pnr_code).with_suffix(".wav")
            pnr_af.write_bytes(wav_data)
            rel_pnr_path = pnr_af.relative_to(output_dir)
            manifest = manifest_str(str(rel_pnr_path), audio_dur, pnr_code)
            mf.write(manifest)
|
||||
|
||||
|
||||
def arg_parser():
    """Build the CLI argument parser for the ASR data generator.

    Returns:
        argparse.ArgumentParser -- parser exposing --output_dir and --count
    """
    prog = Path(__file__).stem
    # Plain string: the original used an f-string with no placeholders.
    parser = argparse.ArgumentParser(
        prog=prog, description="generates asr training data"
    )
    parser.add_argument(
        "--output_dir",
        type=Path,
        default=Path("./train/asr_data"),
        help="directory to output asr data",
    )
    parser.add_argument(
        "--count", type=int, default=3, help="number of datapoints to generate"
    )
    return parser
|
||||
|
||||
|
||||
def main():
    """CLI entry point: parse arguments and run the generator."""
    cli_args = arg_parser().parse_args()
    generate_asr_data(**vars(cli_args))


if __name__ == "__main__":
    main()
|
||||
30
jasper/data_utils/parallel.py
Normal file
30
jasper/data_utils/parallel.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import concurrent.futures
|
||||
import urllib.request
|
||||
|
||||
# Sample pages to fetch concurrently; the made-up domain at the end
# exercises the error-reporting path below.
URLS = [
    "http://www.foxnews.com/",
    "http://www.cnn.com/",
    "http://europe.wsj.com/",
    "http://www.bbc.co.uk/",
    "http://some-made-up-domain.com/",
]
|
||||
|
||||
|
||||
# Retrieve a single page and report the URL and contents
def load_url(url, timeout):
    """Fetch `url` and return the raw response body as bytes."""
    with urllib.request.urlopen(url, timeout=timeout) as conn:
        body = conn.read()
    return body
|
||||
|
||||
|
||||
# We can use a with statement to ensure threads are cleaned up promptly
with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
    # Start the load operations and mark each future with its URL
    future_to_url = {executor.submit(load_url, url, 60): url for url in URLS}
    # Iterate in completion order, not submission order.
    for future in concurrent.futures.as_completed(future_to_url):
        url = future_to_url[future]
        try:
            data = future.result()
        except Exception as exc:
            print("%r generated an exception: %s" % (url, exc))
        else:
            print("%r page is %d bytes" % (url, len(data)))
|
||||
95
jasper/data_utils/process.py
Normal file
95
jasper/data_utils/process.py
Normal file
@@ -0,0 +1,95 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
from sklearn.model_selection import train_test_split
|
||||
from num2words import num2words
|
||||
|
||||
|
||||
def separate_space_convert_digit_setpath(
    src_manifest="/home/malar/work/asr-data-utils/asr_data/pnr_data.json",
    dst_manifest="/dataset/asr_data/pnr_data/pnr_data.json",
    path_prefix="/dataset/asr_data/pnr_data/wav/",
):
    """Rewrite a manifest for training: space-separate each character,
    spell out digits with num2words, lower-case the transcript, and
    repoint audio paths to the dataset location.

    Arguments:
        src_manifest {str} -- JSON-lines manifest to read
        dst_manifest {str} -- JSON-lines manifest to write
        path_prefix {str} -- replacement for the relative "pnr_data/" prefix
    """
    with Path(src_manifest).open("r") as pf:
        pnr_jsonl = pf.readlines()

    pnr_data = [json.loads(i) for i in pnr_jsonl]

    new_pnr_data = []
    for i in pnr_data:
        # "AB1" -> "A B 1"; each digit char then becomes its spelled word.
        letters = " ".join(list(i["text"]))
        num_tokens = [num2words(c) if "0" <= c <= "9" else c for c in letters]
        i["text"] = ("".join(num_tokens)).lower()
        i["audio_filepath"] = i["audio_filepath"].replace("pnr_data/", path_prefix)
        new_pnr_data.append(i)

    new_pnr_jsonl = [json.dumps(i) for i in new_pnr_data]

    with Path(dst_manifest).open("w") as pf:
        pf.write("\n".join(new_pnr_jsonl))  # + "\n"
|
||||
|
||||
|
||||
separate_space_convert_digit_setpath()  # runs at import time (one-shot data migration)
|
||||
|
||||
|
||||
def split_data(
    src_manifest="/dataset/asr_data/pnr_data/pnr_data.json",
    train_manifest="/dataset/asr_data/pnr_data/train_manifest.json",
    test_manifest="/dataset/asr_data/pnr_data/test_manifest.json",
    test_size=0.1,
):
    """Split a JSON-lines manifest into train/test manifests.

    Arguments:
        src_manifest {str} -- manifest to split, one JSON object per line
        train_manifest {str} -- output path for the training split
        test_manifest {str} -- output path for the test split
        test_size {float} -- fraction of lines held out for test (default 0.1)
    """
    with Path(src_manifest).open("r") as pf:
        pnr_jsonl = pf.readlines()
    # Lines keep their trailing newlines, so splits can be joined with "".
    train_pnr, test_pnr = train_test_split(pnr_jsonl, test_size=test_size)
    with Path(train_manifest).open("w") as pf:
        pf.write("".join(train_pnr))
    with Path(test_manifest).open("w") as pf:
        pf.write("".join(test_pnr))
|
||||
|
||||
|
||||
split_data()  # runs at import time (one-shot split step)
|
||||
|
||||
|
||||
def augment_an4(
    an4_dir="/dataset/asr_data/an4",
    pnr_dir="/dataset/asr_data/pnr_data",
    out_dir="/dataset/asr_data/an4_pnr",
):
    """Concatenate the an4 and pnr_data manifests into combined manifests.

    Each directory must contain ``train_manifest.json`` and
    ``test_manifest.json``; the combined files are written under `out_dir`.

    Arguments:
        an4_dir {str} -- directory with the an4 manifests
        pnr_dir {str} -- directory with the pnr_data manifests
        out_dir {str} -- existing directory for the merged manifests
    """
    an4_train = (Path(an4_dir) / "train_manifest.json").read_bytes()
    an4_test = (Path(an4_dir) / "test_manifest.json").read_bytes()
    pnr_train = (Path(pnr_dir) / "train_manifest.json").read_bytes()
    pnr_test = (Path(pnr_dir) / "test_manifest.json").read_bytes()

    # NOTE(review): assumes each source manifest ends with a newline so the
    # concatenation does not fuse two JSON lines -- confirm upstream writers.
    with (Path(out_dir) / "train_manifest.json").open("wb") as pf:
        pf.write(an4_train + pnr_train)
    with (Path(out_dir) / "test_manifest.json").open("wb") as pf:
        pf.write(an4_test + pnr_test)
|
||||
|
||||
|
||||
augment_an4()  # runs at import time (one-shot merge step)
|
||||
|
||||
|
||||
def validate_data(data_file):
    """Check that every line of `data_file` parses as JSON.

    Prints the index and error for each malformed line.

    Arguments:
        data_file {str} -- path to a JSON-lines manifest

    Returns:
        list -- zero-based indices of lines that failed to parse
    """
    with Path(data_file).open("r") as pf:
        pnr_jsonl = pf.readlines()
    bad_lines = []
    for i, s in enumerate(pnr_jsonl):
        try:
            json.loads(s)
        except json.JSONDecodeError as e:
            # Narrowed from BaseException (which also swallowed
            # KeyboardInterrupt); the message now says why it failed.
            print(f"failed on {i}: {e}")
            bad_lines.append(i)
    return bad_lines
|
||||
|
||||
|
||||
# Sanity-check both combined manifests; runs at import time.
validate_data("/dataset/asr_data/an4_pnr/test_manifest.json")
validate_data("/dataset/asr_data/an4_pnr/train_manifest.json")
|
||||
|
||||
|
||||
# def convert_digits(data_file="/dataset/asr_data/an4_pnr/test_manifest.json"):
|
||||
# with Path(data_file).open("r") as pf:
|
||||
# pnr_jsonl = pf.readlines()
|
||||
#
|
||||
# pnr_data = [json.loads(i) for i in pnr_jsonl]
|
||||
# new_pnr_data = []
|
||||
# for i in pnr_data:
|
||||
# num_tokens = [num2words(c) for c in i["text"] if "0" <= c <= "9"]
|
||||
# i["text"] = "".join(num_tokens)
|
||||
# new_pnr_data.append(i)
|
||||
#
|
||||
# new_pnr_jsonl = [json.dumps(i) for i in new_pnr_data]
|
||||
#
|
||||
# with Path(data_file).open("w") as pf:
|
||||
# new_pnr_data = "\n".join(new_pnr_jsonl) # + "\n"
|
||||
# pf.write(new_pnr_data)
|
||||
#
|
||||
#
|
||||
# convert_digits(data_file="/dataset/asr_data/an4_pnr/train_manifest.json")
|
||||
0
jasper/data_utils/tts/__init__.py
Normal file
0
jasper/data_utils/tts/__init__.py
Normal file
52
jasper/data_utils/tts/googletts.py
Normal file
52
jasper/data_utils/tts/googletts.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from logging import getLogger
|
||||
from google.cloud import texttospeech
|
||||
|
||||
LOGGER = getLogger("googletts")  # module logger (not used below; presumably kept for callers)
|
||||
|
||||
|
||||
class GoogleTTS(object):
    """Thin wrapper around the Google Cloud Text-to-Speech client."""

    def __init__(self):
        self.client = texttospeech.TextToSpeechClient()

    def text_to_speech(self, text: str, params: dict) -> bytes:
        """Synthesize SSML `text` into LINEAR16 audio bytes.

        `params` must provide "language", "name" and "sample_rate".
        """
        synthesis_input = texttospeech.types.SynthesisInput(ssml=text)
        voice_selection = texttospeech.types.VoiceSelectionParams(
            language_code=params["language"], name=params["name"]
        )
        audio_cfg = texttospeech.types.AudioConfig(
            audio_encoding=texttospeech.enums.AudioEncoding.LINEAR16,
            sample_rate_hertz=params["sample_rate"],
        )
        response = self.client.synthesize_speech(
            synthesis_input, voice_selection, audio_cfg
        )
        return response.audio_content

    @classmethod
    def voice_list(cls):
        """Lists the available voices.

        Returns only voices with at least one English language code, as
        parameter dicts consumable by `text_to_speech`.
        """
        client = cls().client

        # Performs the list voices request
        voices = client.list_voices()
        results = []
        for voice in voices.voices:
            eng_langs = [lc for lc in voice.language_codes if lc[:2] == "en"]
            if not eng_langs:
                # Skip voices that offer no English variant.
                continue
            gender = texttospeech.enums.SsmlVoiceGender(voice.ssml_gender)
            results.append(
                {
                    "name": voice.name,
                    "language": ",".join(eng_langs),
                    "gender": gender.name,
                    "engine": "wavenet" if "Wav" in voice.name else "standard",
                    "sample_rate": voice.natural_sample_rate_hertz,
                }
            )
        return results
|
||||
26
jasper/data_utils/tts/ttsclient.py
Normal file
26
jasper/data_utils/tts/ttsclient.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""
|
||||
TTSClient Abstract Class
|
||||
"""
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class TTSClient(ABC):
    """
    Base class for TTS
    """

    @abstractmethod
    def text_to_speech(self, text: str, num_channels: int, sample_rate: int,
                       audio_encoding) -> bytes:
        """
        Convert text to raw audio bytes.

        Arguments:
            text {str} -- text to convert
            num_channels {int} -- channel count of the output audio
            sample_rate {int} -- sample rate (Hz) of the output audio
            audio_encoding -- encoding of the output audio bytes

        Returns:
            bytes -- synthesized audio
        """
|
||||
47
jasper/data_utils/utils.py
Normal file
47
jasper/data_utils/utils.py
Normal file
@@ -0,0 +1,47 @@
|
||||
import numpy as np
|
||||
import wave
|
||||
import io
|
||||
import json
|
||||
|
||||
|
||||
def manifest_str(path, dur, text):
    """Return one newline-terminated JSON-lines manifest entry.

    The duration is rounded to one decimal place.
    """
    entry = {"audio_filepath": path, "duration": round(dur, 1), "text": text}
    return json.dumps(entry) + "\n"
|
||||
|
||||
|
||||
def wav_bytes(audio_bytes, frame_rate=24000, num_channels=1, sample_width=2):
    """Wrap raw PCM `audio_bytes` in a WAV (RIFF) container.

    Arguments:
        audio_bytes {bytes} -- raw PCM sample data
        frame_rate {int} -- sample rate in Hz (default 24000)
        num_channels {int} -- channel count (default mono)
        sample_width {int} -- bytes per sample (default 2 = 16-bit)

    Returns:
        bytes -- complete WAV file contents
    """
    wf_b = io.BytesIO()
    with wave.open(wf_b, mode="w") as wf:
        wf.setnchannels(num_channels)
        wf.setframerate(frame_rate)
        wf.setsampwidth(sample_width)
        # writeframes (vs writeframesraw) keeps the header's frame count
        # correct without relying on close-time patching.
        wf.writeframes(audio_bytes)
    return wf_b.getvalue()
|
||||
|
||||
|
||||
def random_pnr_generator(count=10000, length=3):
    """Generate `count` random PNR-like codes.

    Each code contains `length` uppercase letters and `length` digits
    (total 2*length characters). The letter/digit positions are shuffled
    once, so the same position ordering applies to every code in a batch.

    Arguments:
        count {int} -- number of codes to generate (default 10000)
        length {int} -- letters (and digits) per code (default 3)

    Returns:
        list -- generated code strings
    """
    # alphabet = list("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
    alphabet = list("ABCDEFGHIJKLMNOPQRSTUVWXYZ")
    numeric = list("0123456789")
    np_alphabet = np.array(alphabet, dtype="|S1")
    np_numeric = np.array(numeric, dtype="|S1")
    np_alpha_codes = np.random.choice(np_alphabet, [count, length])
    np_num_codes = np.random.choice(np_numeric, [count, length])
    # Transpose so rows are character positions, then shuffle the positions
    # to interleave letters and digits.
    np_code_seed = np.concatenate((np_alpha_codes, np_num_codes), axis=1).T
    np.random.shuffle(np_code_seed)
    np_codes = np_code_seed.T
    codes = [(b"".join(row)).decode("utf-8") for row in np_codes]
    return codes
|
||||
|
||||
|
||||
def main():
    """Print one generated code per line (smoke test for the generator)."""
    for code in random_pnr_generator():
        print(code)


if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user