speed_comparison_librispeech_full.py
from datasets import load_dataset
import werpy
import werx
import jiwer
# from torchmetrics.text import WordErrorRate as TorchWER # TODO: Uncomment when supports Python 3.14
import pywer
import evaluate
# import universal_edit_distance as ued # TODO: Uncomment when supports Python 3.14
import timeit
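
# NOTE: the packages above are assumed to be installable under the same names as the imports:
#   pip install datasets werpy werx jiwer pywer evaluate
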
# Load the consolidated CSV from the Hugging Face Hub
dataset = load_dataset(
    "analyticsinmotion/librispeech-eval",
    data_files="all_splits.csv",
    split="train"
)
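
# Each row of all_splits.csv is expected to carry the columns used below:
# "split", "model_name", "model_version", "reference", and "hypothesis".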
# Specify which splits and model/version to evaluate
splits = ["test-clean", "test-other"]
model_name = "whisper-base"
model_version = "v20240930"
for split in splits:
print(f"\n{'='*70}")
print(f"Benchmarking: {split}")
print(f"{'='*70}\n")
# Filter references and hypotheses for the chosen split/model/version
filtered = dataset.filter(
lambda x: x["split"] == split and
x["model_name"] == model_name and
x["model_version"] == model_version
)
filtered = list(filtered)
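    # werpy.normalize is applied to both sides so every tool scores the same
    # preprocessed strings (normalization typically lowercases and strips punctuation).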
    references = [werpy.normalize(row["reference"]) for row in filtered]
    hypotheses = [werpy.normalize(row["hypothesis"]) for row in filtered]
    print(f"Loaded {len(references):,} utterances\n")

    # --- WER tools ---
    wer_metric = evaluate.load("wer")
    tools = {
        "WERX": werx.wer,
        "WERPY": werpy.wer,
        "JIWER": jiwer.wer,
        # "TORCHMETRICS": lambda r, h: TorchWER()(r, h).item(),  # TODO: Uncomment when supports Python 3.14
        "PYWER": lambda r, h: pywer.wer(r, h) / 100.0,  # pywer returns percent
        "EVALUATE": lambda r, h: wer_metric.compute(predictions=h, references=r),
        # "UED": lambda r, h: ued.word_error_rate(r, h),  # TODO: Uncomment when supports Python 3.14
    }
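    # Every entry is a callable taking (references, hypotheses) and returning a
    # fractional WER; pywer is rescaled from percent and evaluate is adapted to
    # this call signature so all tools report on the same scale.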

    # --- Run + time each tool using timeit ---
    results = []
    n_repeats = 10  # Number of repeats for timeit
    for name, func in tools.items():
        def stmt():
            return func(references, hypotheses)
        total_time = timeit.timeit(stmt, number=n_repeats)
        avg_time = total_time / n_repeats
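        # Compute the WER once more, outside the timed runs, just to record the metric value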
        wer = func(references, hypotheses)
        results.append((name, wer, avg_time))

    # --- Normalize by fastest average time ---
    min_time = min(r[2] for r in results)
    normalized_results = [
        (name, wer, t, t / min_time) for name, wer, t in results
    ]

    # --- Print CLI-friendly table ---
    print("\n Word Error Rate Benchmark:\n")
    print(f"{'Tool':<15} {'WER':<8} {'WER (%)':<10} {'Time (s)':<12} {'Norm Time':<18}")
    print("-" * 70)
    for name, wer, t, norm in normalized_results:
        if name == "WERX":
            norm_str = "1.00× (baseline)"
        else:
            norm_str = f"{norm:.2f}× slower"
        print(f"{name:<15} {wer:.4f} {wer*100:6.2f}% {t:.6f} {norm_str:<18}")