From 4bca2097e1c6044a8ceccb0222b142e32adef659 Mon Sep 17 00:00:00 2001 From: Malar Date: Tue, 8 Jun 2021 17:45:09 +0530 Subject: [PATCH] 1. fixed nty-num type spellcheck issue 2. added tests for the same 3. removed the [eval] optional; [infer] now subsumes it --- setup.py | 8 ++++--- src/plume/utils/__init__.py | 1 + src/plume/utils/regentity.py | 33 +++++++++++++++++++++++------ tests/plume/utils/test_regentity.py | 9 ++++++++ 4 files changed, 42 insertions(+), 9 deletions(-) diff --git a/setup.py b/setup.py index 92925e8..ec7b585 100644 --- a/setup.py +++ b/setup.py @@ -6,6 +6,7 @@ requirements = [ # "nemo_toolkit[asr] @ git+https://github.com/NVIDIA/NeMo.git@09e3ba4dfe333f86d6c5c1048e07210924294be9#egg=nemo_toolkit", # "fairseq @ git+https://github.com/pytorch/fairseq.git@94a1b924f3adec25c8c508ac112410d02b400d1e#egg=fairseq", # "google-cloud-texttospeech~=1.0.1", + "six~=1.16.0", "tqdm~=4.49.0", # "pydub~=0.24.0", # "scikit_learn~=0.22.1", @@ -58,14 +59,15 @@ extra_requirements = { "torchvision~=0.8.2", "torchaudio~=0.7.2", ], - "eval": [ + "infer": [ "jiwer~=2.2.0", "pydub~=0.24.0", "tritonclient[grpc]~=2.9.0", "pyspellchecker~=0.6.2", "num2words~=0.5.10", + "pydub~=0.24.0", ], - "infer": [ + "infer_min": [ "pyspellchecker~=0.6.2", "num2words~=0.5.10", ], @@ -85,7 +87,7 @@ extra_requirements = { "train": ["torchaudio~=0.6.0", "torch-stft~=0.1.4"], } extra_requirements["deploy"] = ( - extra_requirements["models"] + extra_requirements["infer"] + extra_requirements["models"] + extra_requirements["infer_min"] ) extra_requirements["all"] = list( {d for r in extra_requirements.values() for d in r} diff --git a/src/plume/utils/__init__.py b/src/plume/utils/__init__.py index a3e3646..9bfdb64 100644 --- a/src/plume/utils/__init__.py +++ b/src/plume/utils/__init__.py @@ -49,6 +49,7 @@ from .regentity import ( # noqa default_num_only_rules, default_alnum_rules, entity_replacer_keeper, + vocab_corrector_gen, ) boto3 = lazy_module("boto3") diff --git a/src/plume/utils/regentity.py 
b/src/plume/utils/regentity.py index 029cee6..6191e15 100644 --- a/src/plume/utils/regentity.py +++ b/src/plume/utils/regentity.py @@ -4,6 +4,8 @@ from .lazy_import import lazy_callable, lazy_module num2words = lazy_callable("num2words.num2words") spellchecker = lazy_module("spellchecker") +editdistance = lazy_module("editdistance") + # from num2words import num2words @@ -363,15 +365,34 @@ def infer_num_replacer(num_range=100, condense=True): return final_replacer -def vocab_corrector_gen(vocab): - spell = spellchecker.SpellChecker(distance=1) +def vocab_corrector_gen(vocab, distance=1, method="spell"): + spell = spellchecker.SpellChecker(distance=distance) words_to_remove = set(spell.word_frequency.words()) - set(vocab) spell.word_frequency.remove_words(words_to_remove) + spell.word_frequency.load_words(vocab) - def corrector(inp): - return " ".join( - [spell.correction(tok) for tok in spell.split_words(inp)] - ) + if method == "spell": + + def corrector(inp): + # return " ".join( + # [spell.correction(tok) for tok in spell.split_words(inp)] + # ) + return " ".join( + [spell.correction(tok) for tok in inp.split()] + ) + + elif method == "edit": + # editdistance.eval("banana", "bahama") + + def corrector(inp): + match_dists = sorted( + [(v, editdistance.eval(inp, v)) for v in vocab], + key=lambda x: x[1], + ) + return match_dists[0] + + else: + raise ValueError(f"unsupported method:{method}") return corrector diff --git a/tests/plume/utils/test_regentity.py b/tests/plume/utils/test_regentity.py index e17a0dd..dcd4adc 100644 --- a/tests/plume/utils/test_regentity.py +++ b/tests/plume/utils/test_regentity.py @@ -15,3 +15,12 @@ def test_infer_num(): repl("SIX NINE FSIX EIGHT IGSIX SIX NINE NINE THRE ZERO TWO SEVEN ONE") == "6968669930271" ) + + assert ( + repl("FORTY-TWO SEVEN SIXTY-FOUR SEVEN THREE FIVE U OH FOUR SIX") + == "42764735046" + ) + assert ( + repl("FORTY-TWO SEVEN SIXTY-FOUR SEVEN THREE FIVE U OH FOUR SIX") + == "42764735046" + )