1. fixed spellcheck issue for "-nty-num" type numbers (e.g. FORTY-TWO)
2. added tests for the same
3. replaced the [infer] optional: it now subsumes [eval]
tegra
parent
af51fe95cb
commit
4bca2097e1
8
setup.py
8
setup.py
|
|
@ -6,6 +6,7 @@ requirements = [
|
||||||
# "nemo_toolkit[asr] @ git+https://github.com/NVIDIA/NeMo.git@09e3ba4dfe333f86d6c5c1048e07210924294be9#egg=nemo_toolkit",
|
# "nemo_toolkit[asr] @ git+https://github.com/NVIDIA/NeMo.git@09e3ba4dfe333f86d6c5c1048e07210924294be9#egg=nemo_toolkit",
|
||||||
# "fairseq @ git+https://github.com/pytorch/fairseq.git@94a1b924f3adec25c8c508ac112410d02b400d1e#egg=fairseq",
|
# "fairseq @ git+https://github.com/pytorch/fairseq.git@94a1b924f3adec25c8c508ac112410d02b400d1e#egg=fairseq",
|
||||||
# "google-cloud-texttospeech~=1.0.1",
|
# "google-cloud-texttospeech~=1.0.1",
|
||||||
|
"six~=1.16.0",
|
||||||
"tqdm~=4.49.0",
|
"tqdm~=4.49.0",
|
||||||
# "pydub~=0.24.0",
|
# "pydub~=0.24.0",
|
||||||
# "scikit_learn~=0.22.1",
|
# "scikit_learn~=0.22.1",
|
||||||
|
|
@ -58,14 +59,15 @@ extra_requirements = {
|
||||||
"torchvision~=0.8.2",
|
"torchvision~=0.8.2",
|
||||||
"torchaudio~=0.7.2",
|
"torchaudio~=0.7.2",
|
||||||
],
|
],
|
||||||
"eval": [
|
"infer": [
|
||||||
"jiwer~=2.2.0",
|
"jiwer~=2.2.0",
|
||||||
"pydub~=0.24.0",
|
"pydub~=0.24.0",
|
||||||
"tritonclient[grpc]~=2.9.0",
|
"tritonclient[grpc]~=2.9.0",
|
||||||
"pyspellchecker~=0.6.2",
|
"pyspellchecker~=0.6.2",
|
||||||
"num2words~=0.5.10",
|
"num2words~=0.5.10",
|
||||||
|
"pydub~=0.24.0",
|
||||||
],
|
],
|
||||||
"infer": [
|
"infer_min": [
|
||||||
"pyspellchecker~=0.6.2",
|
"pyspellchecker~=0.6.2",
|
||||||
"num2words~=0.5.10",
|
"num2words~=0.5.10",
|
||||||
],
|
],
|
||||||
|
|
@ -85,7 +87,7 @@ extra_requirements = {
|
||||||
"train": ["torchaudio~=0.6.0", "torch-stft~=0.1.4"],
|
"train": ["torchaudio~=0.6.0", "torch-stft~=0.1.4"],
|
||||||
}
|
}
|
||||||
extra_requirements["deploy"] = (
|
extra_requirements["deploy"] = (
|
||||||
extra_requirements["models"] + extra_requirements["infer"]
|
extra_requirements["models"] + extra_requirements["infer_min"]
|
||||||
)
|
)
|
||||||
extra_requirements["all"] = list(
|
extra_requirements["all"] = list(
|
||||||
{d for r in extra_requirements.values() for d in r}
|
{d for r in extra_requirements.values() for d in r}
|
||||||
|
|
|
||||||
|
|
@ -49,6 +49,7 @@ from .regentity import ( # noqa
|
||||||
default_num_only_rules,
|
default_num_only_rules,
|
||||||
default_alnum_rules,
|
default_alnum_rules,
|
||||||
entity_replacer_keeper,
|
entity_replacer_keeper,
|
||||||
|
vocab_corrector_gen,
|
||||||
)
|
)
|
||||||
|
|
||||||
boto3 = lazy_module("boto3")
|
boto3 = lazy_module("boto3")
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,8 @@ from .lazy_import import lazy_callable, lazy_module
|
||||||
|
|
||||||
num2words = lazy_callable("num2words.num2words")
|
num2words = lazy_callable("num2words.num2words")
|
||||||
spellchecker = lazy_module("spellchecker")
|
spellchecker = lazy_module("spellchecker")
|
||||||
|
editdistance = lazy_module("editdistance")
|
||||||
|
|
||||||
# from num2words import num2words
|
# from num2words import num2words
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -363,15 +365,34 @@ def infer_num_replacer(num_range=100, condense=True):
|
||||||
return final_replacer
|
return final_replacer
|
||||||
|
|
||||||
|
|
||||||
def vocab_corrector_gen(vocab):
|
def vocab_corrector_gen(vocab, distance=1, method="spell"):
|
||||||
spell = spellchecker.SpellChecker(distance=1)
|
spell = spellchecker.SpellChecker(distance=distance)
|
||||||
words_to_remove = set(spell.word_frequency.words()) - set(vocab)
|
words_to_remove = set(spell.word_frequency.words()) - set(vocab)
|
||||||
spell.word_frequency.remove_words(words_to_remove)
|
spell.word_frequency.remove_words(words_to_remove)
|
||||||
|
spell.word_frequency.load_words(vocab)
|
||||||
|
|
||||||
def corrector(inp):
|
if method == "spell":
|
||||||
return " ".join(
|
|
||||||
[spell.correction(tok) for tok in spell.split_words(inp)]
|
def corrector(inp):
|
||||||
)
|
# return " ".join(
|
||||||
|
# [spell.correction(tok) for tok in spell.split_words(inp)]
|
||||||
|
# )
|
||||||
|
return " ".join(
|
||||||
|
[spell.correction(tok) for tok in inp.split()]
|
||||||
|
)
|
||||||
|
|
||||||
|
elif method == "edit":
|
||||||
|
# editdistance.eval("banana", "bahama")
|
||||||
|
|
||||||
|
def corrector(inp):
|
||||||
|
match_dists = sorted(
|
||||||
|
[(v, editdistance.eval(inp, v)) for v in vocab],
|
||||||
|
key=lambda x: x[1],
|
||||||
|
)
|
||||||
|
return match_dists[0]
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(f"unsupported method:{method}")
|
||||||
|
|
||||||
return corrector
|
return corrector
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -15,3 +15,12 @@ def test_infer_num():
|
||||||
repl("SIX NINE FSIX EIGHT IGSIX SIX NINE NINE THRE ZERO TWO SEVEN ONE")
|
repl("SIX NINE FSIX EIGHT IGSIX SIX NINE NINE THRE ZERO TWO SEVEN ONE")
|
||||||
== "6968669930271"
|
== "6968669930271"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
repl("FORTY-TWO SEVEN SIXTY-FOUR SEVEN THREE FIVE U OH FOUR SIX")
|
||||||
|
== "42764735046"
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
repl("FORTY-TWO SEVEN SIXTY-FOUR SEVEN THREE FIVE U OH FOUR SIX")
|
||||||
|
== "42764735046"
|
||||||
|
)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue