Skip to content

Commit f6345f0

Browse files
committed
fix(metrics): use cnp.int32_t* for portable buffer pointer types
1 parent b3f498e commit f6345f0

3 files changed

Lines changed: 185 additions & 11 deletions

File tree

werpy/metrics.pyx

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,3 +268,177 @@ cpdef object metrics_fast(object reference, object hypothesis):
268268
if isinstance(reference, (list, np.ndarray)) and isinstance(hypothesis, (list, np.ndarray)):
269269
return _metrics_batch_fast(list(reference), list(hypothesis))
270270
return calculations_fast(reference, hypothesis)
271+
272+
273+
@cython.boundscheck(False)
@cython.wraparound(False)
cpdef cnp.ndarray calculations_wer_only(object reference, object hypothesis):
    """
    WER-only fast path: two-row Wagner-Fischer DP, no backtrace.

    Both inputs are tokenised with ``split()`` and the word-level
    Levenshtein distance is computed with two rolling rows of length
    ``n + 1`` (n = number of hypothesis words) instead of a full
    ``(m + 1) x (n + 1)`` matrix. No error counts or word tracking are
    produced.

    Parameters
    ----------
    reference : object
        Reference text; must provide a ``split()`` method returning a list.
    hypothesis : object
        Hypothesis text; must provide a ``split()`` method returning a list.

    Returns
    -------
    numpy.ndarray
        Shape ``(3,)`` float64 array ``[wer, ld, m]`` where ``ld`` is the
        word-level edit distance, ``m`` is the number of reference words and
        ``wer`` is ``ld / m`` (``0.0`` when the reference has no words).
    """
    cdef list reference_word = reference.split()
    cdef list hypothesis_word = hypothesis.split()

    cdef Py_ssize_t m = len(reference_word)
    cdef Py_ssize_t n = len(hypothesis_word)

    cdef Py_ssize_t i, j
    cdef int cost, del_cost, ins_cost, sub_cost, best, ld
    cdef double wer

    cdef cnp.ndarray prev_arr = np.empty(n + 1, dtype=np.int32)
    cdef cnp.ndarray curr_arr = np.empty(n + 1, dtype=np.int32)

    # The views are typed as cnp.int32_t (not plain C int) so the element
    # type always matches the np.int32 buffers above, independent of the
    # platform width of C int. This mirrors the pointer-based batch path.
    cdef cnp.int32_t[:] prev = prev_arr
    cdef cnp.int32_t[:] curr = curr_arr

    # Base row: turning an empty reference into the first j hypothesis words
    # costs j insertions.
    for j in range(n + 1):
        prev[j] = <cnp.int32_t>j

    for i in range(1, m + 1):
        # First column: deleting the first i reference words.
        curr[0] = <cnp.int32_t>i
        for j in range(1, n + 1):
            cost = 0 if reference_word[i - 1] == hypothesis_word[j - 1] else 1

            del_cost = prev[j] + 1
            ins_cost = curr[j - 1] + 1
            sub_cost = prev[j - 1] + cost

            best = del_cost
            if ins_cost < best:
                best = ins_cost
            if sub_cost < best:
                best = sub_cost

            curr[j] = best

        # Roll the rows: the row just filled becomes "previous".
        prev, curr = curr, prev

    # After the final swap the last completed row is always in ``prev``
    # (for m == 0 it is still the base row).
    ld = prev[n]
    wer = (<double>ld) / m if m > 0 else 0.0

    return np.array([wer, <double>ld, <double>m], dtype=np.float64)
327+
328+
329+
@cython.boundscheck(False)
@cython.wraparound(False)
cdef inline void _calculations_wer_only_reuse_ptr(
    object reference,
    object hypothesis,
    cnp.int32_t* prev,
    cnp.int32_t* curr,
    double* out3,
) except *:
    """
    Two-row edit-distance kernel that works on caller-supplied scratch rows.

    ``reference`` and ``hypothesis`` are tokenised with ``split()``. ``prev``
    and ``curr`` must each have room for at least ``n + 1`` int32 values,
    where ``n`` is the number of hypothesis words. After every reference word
    the two row pointers are exchanged, so no values are copied between rows.
    The exchange only affects this function's local copies of the pointers.

    Results are written to ``out3``:
        out3[0] = wer (ld / m, or 0.0 when m == 0)
        out3[1] = ld  (word-level edit distance)
        out3[2] = m   (number of reference words)
    """
    cdef list ref_tokens = reference.split()
    cdef list hyp_tokens = hypothesis.split()

    cdef Py_ssize_t n_ref = len(ref_tokens)
    cdef Py_ssize_t n_hyp = len(hyp_tokens)

    cdef Py_ssize_t row, col
    cdef int step, candidate, cheapest, distance
    cdef cnp.int32_t* spare

    # Row 0 of the DP table: distance from an empty reference prefix.
    for col in range(n_hyp + 1):
        prev[col] = col

    for row in range(1, n_ref + 1):
        curr[0] = row
        for col in range(1, n_hyp + 1):
            # A matching word pair is free, a mismatch costs one edit.
            if ref_tokens[row - 1] == hyp_tokens[col - 1]:
                step = 0
            else:
                step = 1

            # Take the minimum of substitution, deletion and insertion.
            cheapest = prev[col - 1] + step
            candidate = prev[col] + 1
            if candidate < cheapest:
                cheapest = candidate
            candidate = curr[col - 1] + 1
            if candidate < cheapest:
                cheapest = candidate

            curr[col] = cheapest

        # Exchange the row pointers; the freshly filled row becomes ``prev``.
        spare = curr
        curr = prev
        prev = spare

    distance = prev[n_hyp]
    out3[2] = <double>n_ref
    out3[1] = <double>distance
    if n_ref > 0:
        out3[0] = (<double>distance) / n_ref
    else:
        out3[0] = 0.0
385+
386+
387+
@cython.boundscheck(False)
@cython.wraparound(False)
cdef cnp.ndarray _metrics_batch_wer_only(list references, list hypotheses):
    """
    Batch WER-only computation with DP buffers shared across all pairs.

    Two int32 scratch rows are allocated once, sized to the longest
    hypothesis (in words) among the pairs that will be processed, and raw
    pointers to them are handed to ``_calculations_wer_only_reuse_ptr`` for
    every pair. Each result is written directly into the matching row of the
    output array.

    Parameters
    ----------
    references : list
        Reference texts; each item must provide ``split()``.
    hypotheses : list
        Hypothesis texts; each item must provide ``split()``. Must contain
        at least ``len(references)`` items.

    Returns
    -------
    numpy.ndarray
        Shape ``(n, 3)`` float64 array, one row ``[wer, ld, m]`` per pair,
        where ``n == len(references)``.

    Raises
    ------
    ValueError
        If ``hypotheses`` has fewer items than ``references``.
    """
    cdef Py_ssize_t n_pairs = len(references)
    cdef Py_ssize_t idx
    cdef Py_ssize_t max_n = 0
    cdef Py_ssize_t this_n
    cdef object h
    cdef list h_words
    cdef cnp.ndarray out
    cdef cnp.ndarray prev_arr
    cdef cnp.ndarray curr_arr
    cdef cnp.int32_t* prev
    cdef cnp.int32_t* curr
    cdef double* out_data

    # boundscheck/wraparound are disabled for this function, so indexing
    # ``hypotheses[idx]`` past the end of the list would read invalid memory
    # instead of raising IndexError. Reject that case explicitly.
    if len(hypotheses) < n_pairs:
        raise ValueError(
            "hypotheses must contain at least as many items as references "
            f"(got {len(hypotheses)} hypotheses for {n_pairs} references)"
        )

    out = np.empty((n_pairs, 3), dtype=np.float64)

    # Find max hypothesis token length to size the buffers once.
    for idx in range(n_pairs):
        h = hypotheses[idx]
        h_words = h.split()
        this_n = len(h_words)
        if this_n > max_n:
            max_n = this_n

    # Allocate reusable DP buffers once for the entire batch. The ndarray
    # objects stay referenced for the whole function, keeping the raw
    # pointers below valid.
    prev_arr = np.empty(max_n + 1, dtype=np.int32)
    curr_arr = np.empty(max_n + 1, dtype=np.int32)

    prev = <cnp.int32_t*>cnp.PyArray_DATA(prev_arr)
    curr = <cnp.int32_t*>cnp.PyArray_DATA(curr_arr)

    # ``out`` was freshly created by np.empty with shape (n_pairs, 3), so
    # row ``idx`` starts ``idx * 3`` doubles after the base pointer.
    out_data = <double*>cnp.PyArray_DATA(out)

    # The helper re-initialises the base row on every call and swaps only its
    # own local copies of the pointers, so the same two buffers can be passed
    # unchanged for every pair.
    for idx in range(n_pairs):
        _calculations_wer_only_reuse_ptr(
            references[idx], hypotheses[idx], prev, curr, out_data + (idx * 3)
        )

    return out
432+
433+
434+
cpdef object metrics_wer_only(object reference, object hypothesis):
    """
    Entry point for the WER-only metrics path.

    Dispatches on the input types: when both arguments are a ``list`` or a
    NumPy array they are converted to lists and processed as a batch;
    any other combination is forwarded to the single-pair routine.

    Returns:
    - single pair: (3,) float64 array [wer, ld, m]
    - sequences: (n, 3) float64 array, one row per pair
    """
    cdef bint batched = (
        isinstance(reference, (list, np.ndarray))
        and isinstance(hypothesis, (list, np.ndarray))
    )
    if not batched:
        return calculations_wer_only(reference, hypothesis)
    return _metrics_batch_wer_only(list(reference), list(hypothesis))

werpy/wer.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
import numpy as np
1414
from .errorhandler import error_handler
15-
from .metrics import metrics_fast
15+
from .metrics import metrics_wer_only
1616

1717

1818
def wer(reference, hypothesis) -> float | np.float64 | None:
@@ -57,17 +57,17 @@ def wer(reference, hypothesis) -> float | np.float64 | None:
5757
"""
5858
try:
5959
error_handler(reference, hypothesis)
60-
result = metrics_fast(reference, hypothesis)
60+
result = metrics_wer_only(reference, hypothesis)
6161
except (ValueError, AttributeError, ZeroDivisionError) as err:
6262
print(f"{type(err).__name__}: {str(err)}")
6363
return None
6464

65-
# Batch: (n, 6) float64
65+
# Batch: (n, 3) float64, columns [wer, ld, m]
6666
if isinstance(result, np.ndarray) and result.ndim == 2:
67-
den = np.sum(result[:, 2])
68-
return float(np.sum(result[:, 1]) / den) if den else 0.0
67+
den = np.sum(result[:, 2]) # m column
68+
return float(np.sum(result[:, 1]) / den) if den else 0.0 # ld column
6969

70-
# Single: (6,) float64, WER is at index 0
70+
# Single: (3,) float64, WER is at index 0
7171
if isinstance(result, np.ndarray) and getattr(result, "ndim", 0) == 0:
7272
result = result.item()
7373

werpy/wers.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
import numpy as np
1313
from .errorhandler import error_handler
14-
from .metrics import metrics_fast
14+
from .metrics import metrics_wer_only
1515

1616

1717
def wers(reference, hypothesis):
@@ -50,16 +50,16 @@ def wers(reference, hypothesis):
5050
"""
5151
try:
5252
error_handler(reference, hypothesis)
53-
result = metrics_fast(reference, hypothesis)
53+
result = metrics_wer_only(reference, hypothesis)
5454
except (ValueError, AttributeError, ZeroDivisionError) as err:
5555
print(f"{type(err).__name__}: {str(err)}")
5656
return None
5757

58-
# Batch: (n, 6) float64
58+
# Batch: (n, 3) float64, columns [wer, ld, m]
5959
if isinstance(result, np.ndarray) and result.ndim == 2:
60-
return result[:, 0].tolist()
60+
return result[:, 0].tolist() # Return wer column
6161

62-
# Single: (6,) float64, WER is at index 0
62+
# Single: (3,) float64, WER is at index 0
6363
if isinstance(result, np.ndarray) and getattr(result, "ndim", 0) == 0:
6464
result = result.item()
6565

0 commit comments

Comments
 (0)