diff --git a/README.md b/README.md index b9c60d5..4161cb3 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,9 @@ **Text-to-text alignment algorithm for speech recognition error analysis.** ErrorAlign helps you dig deeper into your speech recognition projects by accurately aligning each word in a reference transcript with the model-generated transcript. Unlike traditional methods, such as Levenshtein-based alignment, it is not restricted to simple one-to-one alignment, but can map a single reference word to multiple words or subwords in the model output. This enables quick and reliable identification of error patterns in rare words, names, or domain-specific terms that matter most for your application. -→ **Update [2025-12-10]:** As of version `0.1.0b5`, `error-align` will include a word-level pass to efficiently identify unambiguous matches, along with C++ extensions to accelerate beam search and backtrace construction. The combined speedup is ~15× over the pure-Python implementation ⚡ +→ **Update [2026-06-22]:** As of version `0.1.0b10`, the word-level pass defaults to a faster `rapidfuzz`-based method that anchors matches from a single optimal Levenshtein alignment. On longer examples (e.g., Earnings-21), the speedup is expected to be around 30×. The original graph-based pass is still available via `error_align(ref, hyp, word_level_method="unambiguous")`. + +→ **Update [2025-12-10]:** As of version `0.1.0b5`, `error-align` includes a word-level pass to efficiently identify unambiguous matches, along with C++ extensions to accelerate beam search and backtrace construction. The combined speedup is ~15× over the pure-Python implementation. 
[//]: <> (https://raw.githubusercontent.com/corticph/error-align/refs/heads/main/.github/assets/logo_gpt.svg) diff --git a/poetry.lock b/poetry.lock index f8618ed..70956c7 100644 --- a/poetry.lock +++ b/poetry.lock @@ -8036,6 +8036,13 @@ optional = false python-versions = ">=3.8" groups = ["main", "dev"] files = [ + {file = "PyYAML-6.0.3-cp38-cp38-macosx_10_13_x86_64.whl", hash = "sha256:c2514fceb77bc5e7a2f7adfaa1feb2fb311607c9cb518dbc378688ec73d8292f"}, + {file = "PyYAML-6.0.3-cp38-cp38-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9c57bb8c96f6d1808c030b1687b9b5fb476abaa47f0db9c0101f5e9f394e97f4"}, + {file = "PyYAML-6.0.3-cp38-cp38-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:efd7b85f94a6f21e4932043973a7ba2613b059c4a000551892ac9f1d11f5baf3"}, + {file = "PyYAML-6.0.3-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:22ba7cfcad58ef3ecddc7ed1db3409af68d023b7f940da23c6c2a1890976eda6"}, + {file = "PyYAML-6.0.3-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:6344df0d5755a2c9a276d4473ae6b90647e216ab4757f8426893b5dd2ac3f369"}, + {file = "PyYAML-6.0.3-cp38-cp38-win32.whl", hash = "sha256:3ff07ec89bae51176c0549bc4c63aa6202991da2d9a6129d7aef7f1407d3f295"}, + {file = "PyYAML-6.0.3-cp38-cp38-win_amd64.whl", hash = "sha256:5cf4e27da7e3fbed4d6c3d8e797387aaad68102272f8f9752883bc32d61cb87b"}, {file = "pyyaml-6.0.3-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:214ed4befebe12df36bcc8bc2b64b396ca31be9304b8f59e25c11cf94a4c033b"}, {file = "pyyaml-6.0.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:02ea2dfa234451bbb8772601d7b8e426c2bfa197136796224e50e35a78777956"}, {file = "pyyaml-6.0.3-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b30236e45cf30d2b8e7b3e85881719e98507abed1011bf463a8fa23e9c3e98a8"}, @@ -8131,10 +8138,9 @@ decord = ["decord"] name = "rapidfuzz" version = "3.14.3" description = 
"rapid fuzzy string matching" -optional = true +optional = false python-versions = ">=3.10" groups = ["main"] -markers = "python_version == \"3.12\" and extra == \"evaluation\"" files = [ {file = "rapidfuzz-3.14.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b9fcd4d751a4fffa17aed1dde41647923c72c74af02459ad1222e3b0022da3a1"}, {file = "rapidfuzz-3.14.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4ad73afb688b36864a8d9b7344a9cf6da186c471e5790cbf541a635ee0f457f2"}, @@ -11405,9 +11411,9 @@ test = ["big-O", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more_it type = ["pytest-mypy"] [extras] -evaluation = ["backoff", "click", "datasets", "gitpython", "librosa", "lilcom", "matplotlib", "nemo-toolkit", "num2words", "numpy", "pandas", "pyphen", "rapidfuzz", "scipy", "soundfile", "textgrid", "torch", "torchcodec", "torchvision", "transformers"] +evaluation = ["backoff", "click", "datasets", "gitpython", "librosa", "lilcom", "matplotlib", "nemo-toolkit", "num2words", "numpy", "pandas", "pyphen", "scipy", "soundfile", "textgrid", "torch", "torchcodec", "torchvision", "transformers"] [metadata] lock-version = "2.1" python-versions = ">=3.10,<3.15" -content-hash = "35955ef088b01426ffedf84338ac8f82aae808b5960152978fa15e0dec16727f" +content-hash = "068773cea111445347373377f2f45f8527b6ce69612f0332aec8b922b5036f81" diff --git a/pyproject.toml b/pyproject.toml index f49f747..5193622 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ build-backend = "scikit_build_core.build" [project] name = "error-align" -version = "0.1.0b9" +version = "0.1.0b10" description = "Text-to-text alignment algorithm for speech recognition error analysis." 
readme = "README.md" requires-python = ">=3.10,<3.15" @@ -36,6 +36,7 @@ dependencies = [ "tqdm>=4.67.1", "unidecode>=1.4.0", "regex>=2025.9.18", + "rapidfuzz>=3.13.0", ] [project.urls] @@ -43,7 +44,6 @@ Homepage = "https://github.com/corticph/error-align" [project.optional-dependencies] evaluation = [ - "rapidfuzz>=3.13.0; python_version == '3.12'", "num2words>=0.5.14; python_version == '3.12'", "datasets>=3.3.2; python_version == '3.12'", "soundfile>=0.13.1; python_version == '3.12'", diff --git a/src/error_align/error_align.py b/src/error_align/error_align.py index e843fee..7635a21 100644 --- a/src/error_align/error_align.py +++ b/src/error_align/error_align.py @@ -1,3 +1,5 @@ +from rapidfuzz.distance import Levenshtein + from error_align.backtrace_graph import BacktraceGraph from error_align.core import compute_levenshtein_distance_matrix, error_align_beam_search from error_align.graph_metadata import GraphMetadata, SubgraphMetadata @@ -19,6 +21,7 @@ def error_align( normalizer: callable = basic_normalizer, beam_size: int = 100, word_level_pass: bool = True, + word_level_method: str = "rapidfuzz", ): """Run error alignment between reference and hypothesis texts. @@ -28,7 +31,10 @@ def error_align( tokenizer (callable): A function to tokenize the sequences. Must be regex-based and return Match objects. normalizer (callable): A function to normalize the tokens. Defaults to basic_normalizer. beam_size (int): The beam size for beam search alignment. - word_level_pass (bool): Whether to perform a word-level alignment pass to identify unambiguous matches. + word_level_pass (bool): Whether to perform a word-level alignment pass to anchor matches before beam search. + word_level_method (str): Which word-level pass to use when ``word_level_pass`` is True. ``"rapidfuzz"`` + (default) takes the matches from a single optimal Levenshtein alignment via rapidfuzz. ``"unambiguous"`` + builds the full backtrace graph and only anchors matches common to all optimal paths. 
""" graph_metadata = prepare_graph_metadata( @@ -42,8 +48,12 @@ def error_align( return align_identical_inputs(graph_metadata) elif not word_level_pass: return align_beam_search(graph_metadata, beam_size=beam_size) - else: + elif word_level_method == "rapidfuzz": + return align_with_rapidfuzz_word_level_pass(graph_metadata, beam_size=beam_size) + elif word_level_method == "unambiguous": return align_with_word_level_pass(graph_metadata, beam_size=beam_size) + else: + raise ValueError(f"Unknown word_level_method: {word_level_method!r}") def prepare_graph_metadata( @@ -127,10 +137,44 @@ def align_with_word_level_pass( ) backtrace_graph = BacktraceGraph(backtrace_matrix) match_indices = backtrace_graph.get_unambiguous_node_matches() + return align_from_match_indices(graph_metadata, beam_size, match_indices) + + +def align_with_rapidfuzz_word_level_pass( + graph_metadata: GraphMetadata, + beam_size: int, +) -> list[Alignment]: + """Perform a word-level alignment pass using matches from a single optimal Levenshtein alignment.""" + match_indices = get_rapidfuzz_match_indices(graph_metadata.ref_norm, graph_metadata.hyp_norm) + return align_from_match_indices(graph_metadata, beam_size, match_indices) + + +def get_rapidfuzz_match_indices(ref_norm: list[str], hyp_norm: list[str]) -> list[tuple[int, int]]: + """Infer word-level match indices from a rapidfuzz Levenshtein alignment. + + rapidfuzz only emits the non-match operations (insert/delete/replace), so matches are recovered as the + complement of the edited token indices. Returns ``(hyp_idx, ref_idx)`` tuples to match the convention used + by ``BacktraceGraph.get_unambiguous_node_matches`` and consumed by ``align_from_match_indices``. 
+ """ + edit_ops = Levenshtein.editops(ref_norm, hyp_norm).as_list() # (op, ref_idx, hyp_idx) + ref_edit_idxs = {op[1] for op in edit_ops if op[0] != "insert"} + hyp_edit_idxs = {op[2] for op in edit_ops if op[0] != "delete"} + ref_match = [i for i in range(len(ref_norm)) if i not in ref_edit_idxs] + hyp_match = [i for i in range(len(hyp_norm)) if i not in hyp_edit_idxs] + # Matches are monotonic in both axes; zip the complements and swap to (hyp, ref). + return [(h, r) for r, h in zip(ref_match, hyp_match, strict=True)] + + +def align_from_match_indices( + graph_metadata: GraphMetadata, + beam_size: int, + match_indices: list[tuple[int, int]], +) -> list[Alignment]: + """Extract alignments from word-level match anchors, beam-searching the ambiguous spans between them.""" # NOTE: We always add an artificial terminal match node to simplify subspan extraction. match_indices = match_indices + [(len(graph_metadata.hyp_norm), len(graph_metadata.ref_norm))] - # Iterate over the unambiguous matches to extract subspans (i.e., the span of words between two matches). + # Iterate over the matches to extract subspans (i.e., the span of words between two matches). 
hyp_start, ref_start = (0, 0) alignments = [] end_index = len(match_indices) - 1 @@ -174,7 +218,7 @@ def align_with_word_level_pass( hyp_index=hyp_end, ) ) - ref_start, hyp_start = (ref_end + 1, hyp_end + 1) + hyp_start, ref_start = (hyp_end + 1, ref_end + 1) return alignments diff --git a/tests/test_default.py b/tests/test_default.py index d33952c..f314716 100644 --- a/tests/test_default.py +++ b/tests/test_default.py @@ -1,3 +1,4 @@ +import pytest from error_align._cpp_beam_search import error_align_beam_search as cpp_error_align_beam_search from typeguard import suppress_type_checks @@ -6,7 +7,7 @@ from error_align.beam_search import _cpp_path_to_py_path from error_align.beam_search import error_align_beam_search as python_error_align_beam_search from error_align.edit_distance import compute_error_align_distance_matrix, compute_levenshtein_distance_matrix -from error_align.error_align import prepare_graph_metadata +from error_align.error_align import get_rapidfuzz_match_indices, prepare_graph_metadata from error_align.graph_metadata import SubgraphMetadata from error_align.utils import ( Alignment, @@ -19,13 +20,14 @@ ) -def test_error_align() -> None: +@pytest.mark.parametrize("word_level_method", ["rapidfuzz", "unambiguous"]) +def test_error_align(word_level_method: str) -> None: """Test error alignment for an example including all substitution types.""" ref = "This is a substitution test deleted." hyp = "Inserted this is a contribution test." - alignments = error_align(ref, hyp) + alignments = error_align(ref, hyp, word_level_method=word_level_method) expected_ops = [ OpType.INSERT, # Inserted OpType.MATCH, # This @@ -40,6 +42,30 @@ def test_error_align() -> None: assert alignment.op_type == op +@pytest.mark.parametrize( + ("ref_norm", "hyp_norm", "expected"), + [ + # Replace: equal lengths, only the differing token is excluded. 
+ (["this", "is", "a", "test"], ["this", "is", "a", "pest"], [(0, 0), (1, 1), (2, 2)]), + # Insertion: hyp longer; "b" stays matched at (hyp=2, ref=1). + (["a", "b"], ["a", "x", "b"], [(0, 0), (2, 1)]), + # Deletion: ref longer; "c" stays matched at (hyp=1, ref=2). + (["a", "b", "c"], ["a", "c"], [(0, 0), (1, 2)]), + ], +) +def test_get_rapidfuzz_match_indices(ref_norm: list[str], hyp_norm: list[str], expected: list[tuple[int, int]]) -> None: + """Match indices are inferred as the complement of rapidfuzz edit ops, in (hyp_idx, ref_idx) order.""" + + assert get_rapidfuzz_match_indices(ref_norm, hyp_norm) == expected + + +def test_error_align_unknown_word_level_method() -> None: + """An unknown word-level method raises a clear error.""" + + with pytest.raises(ValueError, match="Unknown word_level_method"): + error_align("a b c", "a x c", word_level_method="nonsense") + + def test_beam_search_cpp_vs_python() -> None: """Test that the C++ and Python beam search implementations produce the same results."""