1. fixed nty-num type spellcheck issue

2. added tests for the same
3. removed the [eval] optional extra; [infer] now subsumes it (old [infer] renamed to [infer_min])
tegra
Malar 2021-06-08 17:45:09 +05:30
parent af51fe95cb
commit 4bca2097e1
4 changed files with 42 additions and 9 deletions

View File

@ -6,6 +6,7 @@ requirements = [
# "nemo_toolkit[asr] @ git+https://github.com/NVIDIA/NeMo.git@09e3ba4dfe333f86d6c5c1048e07210924294be9#egg=nemo_toolkit",
# "fairseq @ git+https://github.com/pytorch/fairseq.git@94a1b924f3adec25c8c508ac112410d02b400d1e#egg=fairseq",
# "google-cloud-texttospeech~=1.0.1",
"six~=1.16.0",
"tqdm~=4.49.0",
# "pydub~=0.24.0",
# "scikit_learn~=0.22.1",
@ -58,14 +59,15 @@ extra_requirements = {
"torchvision~=0.8.2",
"torchaudio~=0.7.2",
],
"eval": [
"infer": [
"jiwer~=2.2.0",
"pydub~=0.24.0",
"tritonclient[grpc]~=2.9.0",
"pyspellchecker~=0.6.2",
"num2words~=0.5.10",
"pydub~=0.24.0",
],
"infer": [
"infer_min": [
"pyspellchecker~=0.6.2",
"num2words~=0.5.10",
],
@ -85,7 +87,7 @@ extra_requirements = {
"train": ["torchaudio~=0.6.0", "torch-stft~=0.1.4"],
}
extra_requirements["deploy"] = (
extra_requirements["models"] + extra_requirements["infer"]
extra_requirements["models"] + extra_requirements["infer_min"]
)
extra_requirements["all"] = list(
{d for r in extra_requirements.values() for d in r}

View File

@ -49,6 +49,7 @@ from .regentity import ( # noqa
default_num_only_rules,
default_alnum_rules,
entity_replacer_keeper,
vocab_corrector_gen,
)
boto3 = lazy_module("boto3")

View File

@ -4,6 +4,8 @@ from .lazy_import import lazy_callable, lazy_module
num2words = lazy_callable("num2words.num2words")
spellchecker = lazy_module("spellchecker")
editdistance = lazy_module("editdistance")
# from num2words import num2words
@ -363,15 +365,34 @@ def infer_num_replacer(num_range=100, condense=True):
return final_replacer
def vocab_corrector_gen(vocab, distance=1, method="spell"):
    """Build a corrector that snaps noisy text onto words from *vocab*.

    Parameters:
        vocab: iterable of allowed words (materialized into the
            spell-checker's dictionary; also scanned by the "edit" method).
        distance: maximum edit distance used by the pyspellchecker backend.
        method: "spell" corrects each whitespace-separated token
            independently via pyspellchecker; "edit" matches the *whole*
            input string against the vocabulary by Levenshtein distance.

    Returns:
        A callable ``corrector(inp) -> str`` mapping an input string to
        its corrected form.

    Raises:
        ValueError: if *method* is neither "spell" nor "edit".
    """
    # Restrict the checker's dictionary to exactly `vocab`: drop every
    # built-in word not in the vocabulary, then load the vocabulary.
    spell = spellchecker.SpellChecker(distance=distance)
    words_to_remove = set(spell.word_frequency.words()) - set(vocab)
    spell.word_frequency.remove_words(words_to_remove)
    spell.word_frequency.load_words(vocab)

    if method == "spell":

        def corrector(inp):
            # Correct token-by-token; plain str.split() (rather than
            # spell.split_words) keeps tokenization predictable.
            return " ".join(
                [spell.correction(tok) for tok in inp.split()]
            )

    elif method == "edit":

        def corrector(inp):
            # BUG FIX: the original returned the (word, distance) tuple,
            # inconsistent with the "spell" branch which returns a string.
            # Return the closest vocabulary word itself; min() avoids
            # sorting the entire vocabulary.
            return min(vocab, key=lambda v: editdistance.eval(inp, v))

    else:
        raise ValueError(f"unsupported method:{method}")
    return corrector

View File

@ -15,3 +15,12 @@ def test_infer_num():
repl("SIX NINE FSIX EIGHT IGSIX SIX NINE NINE THRE ZERO TWO SEVEN ONE")
== "6968669930271"
)
assert (
repl("FORTY-TWO SEVEN SIXTY-FOUR SEVEN THREE FIVE U OH FOUR SIX")
== "42764735046"
)
assert (
repl("FORTY-TWO SEVEN SIXTY-FOUR SEVEN THREE FIVE U OH FOUR SIX")
== "42764735046"
)