Skip to content

Commit 784072e

Browse files
committed
perf(metrics): add fast path without word tracking for speedup in wer/wers/werp/werps functions
1 parent 1abc587 commit 784072e

7 files changed

Lines changed: 159 additions & 48 deletions

File tree

werpy/metrics.pyx

Lines changed: 117 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ cpdef cnp.ndarray calculations(object reference, object hypothesis):
8484
ldm[i, j] = best
8585

8686
ld = ldm[m, n]
87-
wer = (<double>ld) / m
87+
wer = (<double>ld) / m if m > 0 else 0.0
8888

8989
insertions, deletions, substitutions = 0, 0, 0
9090
inserted_words, deleted_words, substituted_words = [], [], []
@@ -125,7 +125,7 @@ cdef cnp.ndarray _metrics_batch(list references, list hypotheses):
125125
[wer, ld, m, insertions, deletions, substitutions, inserted_words, deleted_words, substituted_words]
126126
"""
127127
cdef Py_ssize_t n = len(references)
128-
cdef Py_ssize_t idx, j
128+
cdef Py_ssize_t idx
129129

130130
# Rows output, dtype=object because cols 6-8 are lists
131131
cdef cnp.ndarray out = np.empty((n, 9), dtype=object)
@@ -138,8 +138,7 @@ cdef cnp.ndarray _metrics_batch(list references, list hypotheses):
138138
if isinstance(r, np.ndarray) and r.ndim == 0:
139139
r = r.item()
140140

141-
for j in range(9):
142-
out[idx, j] = r[j]
141+
out[idx, :] = r
143142

144143
return out
145144

@@ -155,3 +154,117 @@ cpdef object metrics(object reference, object hypothesis):
155154
if isinstance(reference, (list, np.ndarray)) and isinstance(hypothesis, (list, np.ndarray)):
156155
return _metrics_batch(list(reference), list(hypothesis))
157156
return calculations(reference, hypothesis)
157+
158+
159+
@cython.boundscheck(False)
160+
@cython.wraparound(False)
161+
cpdef cnp.ndarray calculations_fast(object reference, object hypothesis):
162+
"""
163+
Fast path for WER/LD calculations without word tracking.
164+
Returns only numeric metrics (WER, LD, m, insertions, deletions, substitutions).
165+
166+
This function is optimized for use cases that only need counts and metrics,
167+
not the actual lists of inserted/deleted/substituted words.
168+
169+
Returns (6,) float64 array: [wer, ld, m, insertions, deletions, substitutions]
170+
"""
171+
cdef list reference_word = reference.split()
172+
cdef list hypothesis_word = hypothesis.split()
173+
174+
cdef Py_ssize_t m = len(reference_word)
175+
cdef Py_ssize_t n = len(hypothesis_word)
176+
cdef Py_ssize_t i, j
177+
178+
cdef int ld, insertions, deletions, substitutions
179+
cdef double wer
180+
181+
cdef int cost, del_cost, ins_cost, sub_cost, best
182+
183+
# Allocate the (m+1) x (n+1) DP matrix without zero-initialization
184+
cdef int[:, :] ldm = np.empty((m + 1, n + 1), dtype=np.int32)
185+
186+
# Initialize first column and first row (boundary conditions)
187+
for i in range(m + 1):
188+
ldm[i, 0] = <int>i
189+
for j in range(n + 1):
190+
ldm[0, j] = <int>j
191+
192+
# Fill the Levenshtein distance matrix
193+
for i in range(1, m + 1):
194+
for j in range(1, n + 1):
195+
cost = 0 if reference_word[i - 1] == hypothesis_word[j - 1] else 1
196+
197+
del_cost = ldm[i - 1, j] + 1
198+
ins_cost = ldm[i, j - 1] + 1
199+
sub_cost = ldm[i - 1, j - 1] + cost
200+
201+
best = del_cost
202+
if ins_cost < best:
203+
best = ins_cost
204+
if sub_cost < best:
205+
best = sub_cost
206+
207+
ldm[i, j] = best
208+
209+
ld = ldm[m, n]
210+
wer = (<double>ld) / m if m > 0 else 0.0
211+
212+
# Backtrace to count errors (no word tracking)
213+
insertions, deletions, substitutions = 0, 0, 0
214+
i, j = m, n
215+
while i > 0 or j > 0:
216+
if i > 0 and j > 0 and reference_word[i - 1] == hypothesis_word[j - 1]:
217+
i -= 1
218+
j -= 1
219+
else:
220+
if i > 0 and j > 0 and ldm[i, j] == ldm[i - 1, j - 1] + 1:
221+
substitutions += 1
222+
i -= 1
223+
j -= 1
224+
elif j > 0 and ldm[i, j] == ldm[i, j - 1] + 1:
225+
insertions += 1
226+
j -= 1
227+
elif i > 0 and ldm[i, j] == ldm[i - 1, j] + 1:
228+
deletions += 1
229+
i -= 1
230+
231+
return np.array(
232+
[wer, <double>ld, <double>m,
233+
<double>insertions, <double>deletions, <double>substitutions],
234+
dtype=np.float64
235+
)
236+
237+
238+
@cython.boundscheck(False)
239+
@cython.wraparound(False)
240+
cdef cnp.ndarray _metrics_batch_fast(list references, list hypotheses):
241+
"""
242+
Fast batch processing without word tracking.
243+
244+
Returns (n, 6) float64 array where each row contains:
245+
[wer, ld, m, insertions, deletions, substitutions]
246+
"""
247+
cdef Py_ssize_t n = len(references)
248+
cdef Py_ssize_t idx
249+
250+
cdef cnp.ndarray out = np.empty((n, 6), dtype=np.float64)
251+
252+
cdef cnp.ndarray r
253+
for idx in range(n):
254+
r = calculations_fast(references[idx], hypotheses[idx])
255+
out[idx, :] = r
256+
257+
return out
258+
259+
260+
cpdef object metrics_fast(object reference, object hypothesis):
261+
"""
262+
Fast metrics entry point without word tracking.
263+
264+
Returns:
265+
- strings: (6,) float64 array [wer, ld, m, insertions, deletions, substitutions]
266+
- sequences: (n, 6) float64 array, one row per pair
267+
"""
268+
if isinstance(reference, (list, np.ndarray)) and isinstance(hypothesis, (list, np.ndarray)):
269+
return _metrics_batch_fast(list(reference), list(hypothesis))
270+
return calculations_fast(reference, hypothesis)

werpy/summary.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,11 @@ def summary(reference, hypothesis) -> pd.DataFrame | None:
5353
"""
5454
try:
5555
error_handler(reference, hypothesis)
56+
result = metrics(reference, hypothesis)
5657
except (ValueError, AttributeError, ZeroDivisionError) as err:
5758
print(f"{type(err).__name__}: {str(err)}")
5859
return None
5960

60-
result = metrics(reference, hypothesis)
61-
6261
# Batch rows (n, 9)
6362
if isinstance(result, np.ndarray) and result.ndim == 2:
6463
word_error_rate_breakdown = result.tolist()

werpy/summaryp.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,12 +67,11 @@ def summaryp(
6767
"""
6868
try:
6969
error_handler(reference, hypothesis)
70+
result = metrics(reference, hypothesis)
7071
except (ValueError, AttributeError, ZeroDivisionError) as err:
7172
print(f"{type(err).__name__}: {str(err)}")
7273
return None
7374

74-
result = metrics(reference, hypothesis)
75-
7675
# Batch rows (n, 9)
7776
if isinstance(result, np.ndarray) and result.ndim == 2:
7877
word_error_rate_breakdown = result.tolist()

werpy/wer.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
import numpy as np
1414
from .errorhandler import error_handler
15-
from .metrics import metrics
15+
from .metrics import metrics_fast
1616

1717

1818
def wer(reference, hypothesis) -> float | np.float64 | None:
@@ -57,20 +57,18 @@ def wer(reference, hypothesis) -> float | np.float64 | None:
5757
"""
5858
try:
5959
error_handler(reference, hypothesis)
60+
result = metrics_fast(reference, hypothesis)
6061
except (ValueError, AttributeError, ZeroDivisionError) as err:
6162
print(f"{type(err).__name__}: {str(err)}")
6263
return None
6364

64-
result = metrics(reference, hypothesis)
65-
66-
# Batch rows (n, 9)
65+
# Batch: (n, 6) float64
6766
if isinstance(result, np.ndarray) and result.ndim == 2:
68-
ld_total = float(np.sum(result[:, 1]))
69-
m_total = float(np.sum(result[:, 2]))
70-
return ld_total / m_total
67+
den = np.sum(result[:, 2])
68+
return float(np.sum(result[:, 1]) / den) if den else 0.0
7169

72-
# Single row
70+
# Single: (6,) float64, WER is at index 0
7371
if isinstance(result, np.ndarray) and getattr(result, "ndim", 0) == 0:
7472
result = result.item()
7573

76-
return float(result[1]) / float(result[2])
74+
return float(result[0])

werpy/werp.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
import numpy as np
1313
from .errorhandler import error_handler
14-
from .metrics import metrics
14+
from .metrics import metrics_fast
1515

1616

1717
def werp(
@@ -75,27 +75,28 @@ def werp(
7575
"""
7676
try:
7777
error_handler(reference, hypothesis)
78+
result = metrics_fast(reference, hypothesis)
7879
except (ValueError, AttributeError, ZeroDivisionError) as err:
7980
print(f"{type(err).__name__}: {str(err)}")
8081
return None
8182

82-
result = metrics(reference, hypothesis)
83-
84-
# Batch rows (n, 9)
83+
# Batch: (n, 6) float64
8584
if isinstance(result, np.ndarray) and result.ndim == 2:
86-
weighted_insertions = np.sum(result[:, 3]) * insertions_weight
87-
weighted_deletions = np.sum(result[:, 4]) * deletions_weight
88-
weighted_substitutions = np.sum(result[:, 5]) * substitutions_weight
89-
m = np.sum(result[:, 2])
90-
else:
91-
# Single row
92-
if isinstance(result, np.ndarray) and getattr(result, "ndim", 0) == 0:
93-
result = result.item()
94-
weighted_insertions = result[3] * insertions_weight
95-
weighted_deletions = result[4] * deletions_weight
96-
weighted_substitutions = result[5] * substitutions_weight
97-
m = result[2]
98-
85+
weighted_insertions = result[:, 3] * insertions_weight
86+
weighted_deletions = result[:, 4] * deletions_weight
87+
weighted_substitutions = result[:, 5] * substitutions_weight
88+
m = result[:, 2]
89+
weighted_errors = weighted_insertions + weighted_deletions + weighted_substitutions
90+
den = np.sum(m)
91+
return float(np.sum(weighted_errors) / den) if den else 0.0
92+
93+
# Single: (6,) float64
94+
if isinstance(result, np.ndarray) and getattr(result, "ndim", 0) == 0:
95+
result = result.item()
96+
97+
weighted_insertions = result[3] * insertions_weight
98+
weighted_deletions = result[4] * deletions_weight
99+
weighted_substitutions = result[5] * substitutions_weight
100+
m = result[2]
99101
weighted_errors = weighted_insertions + weighted_deletions + weighted_substitutions
100-
werp_result = float(weighted_errors / m) if m else 0.0
101-
return werp_result
102+
return float(weighted_errors / m) if m else 0.0

werpy/werps.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
import numpy as np
1313
from .errorhandler import error_handler
14-
from .metrics import metrics
14+
from .metrics import metrics_fast
1515

1616

1717
def werps(
@@ -69,22 +69,24 @@ def werps(
6969
"""
7070
try:
7171
error_handler(reference, hypothesis)
72+
result = metrics_fast(reference, hypothesis)
7273
except (ValueError, AttributeError, ZeroDivisionError) as err:
7374
print(f"{type(err).__name__}: {str(err)}")
7475
return None
7576

76-
result = metrics(reference, hypothesis)
77-
78-
# Batch rows (n, 9)
77+
# Batch: (n, 6) float64
7978
if isinstance(result, np.ndarray) and result.ndim == 2:
8079
weighted_insertions = result[:, 3] * insertions_weight
8180
weighted_deletions = result[:, 4] * deletions_weight
8281
weighted_substitutions = result[:, 5] * substitutions_weight
8382
m = result[:, 2]
8483
weighted_errors = weighted_insertions + weighted_deletions + weighted_substitutions
85-
return (weighted_errors / m).tolist()
84+
out = np.zeros_like(weighted_errors, dtype=np.float64)
85+
mask = m != 0
86+
out[mask] = weighted_errors[mask] / m[mask]
87+
return out.tolist()
8688

87-
# Single row
89+
# Single: (6,) float64
8890
if isinstance(result, np.ndarray) and getattr(result, "ndim", 0) == 0:
8991
result = result.item()
9092

werpy/wers.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
import numpy as np
1313
from .errorhandler import error_handler
14-
from .metrics import metrics
14+
from .metrics import metrics_fast
1515

1616

1717
def wers(reference, hypothesis):
@@ -50,17 +50,16 @@ def wers(reference, hypothesis):
5050
"""
5151
try:
5252
error_handler(reference, hypothesis)
53+
result = metrics_fast(reference, hypothesis)
5354
except (ValueError, AttributeError, ZeroDivisionError) as err:
5455
print(f"{type(err).__name__}: {str(err)}")
5556
return None
5657

57-
result = metrics(reference, hypothesis)
58-
59-
# Batch rows (n, 9)
58+
# Batch: (n, 6) float64
6059
if isinstance(result, np.ndarray) and result.ndim == 2:
61-
return [float(x) for x in result[:, 0].tolist()]
60+
return result[:, 0].tolist()
6261

63-
# Single row
62+
# Single: (6,) float64, WER is at index 0
6463
if isinstance(result, np.ndarray) and getattr(result, "ndim", 0) == 0:
6564
result = result.item()
6665

0 commit comments

Comments
 (0)