Skip to content

Commit 8ae87a6

Browse files
authored
Better coverage (#320)
* added test for invalid merge_threshold * parameterized the basic greedy_solver tests such that both native and python version is tested.
1 parent 496f4bb commit 8ae87a6

2 files changed

Lines changed: 24 additions & 8 deletions

File tree

anonlink/solving/_multiparty_solving.pyx

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ def probabilistic_greedy_solve_native(
6969
raise ValueError('inconsistent shape of index arrays')
7070
if len(dset_is_arrs) != 2:
7171
raise NotImplementedError('only binary solving is supported')
72+
if merge_threshold < 0 or merge_threshold > 1:
73+
raise ValueError('merge_threshold must be between 0 and 1')
7274

7375
dset_is0_arr, dset_is1_arr = dset_is_arrs
7476
rec_is0_arr, rec_is1_arr = rec_is_arrs

tests/test_solving.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,11 @@ def _compare_matching(result, truth):
3636
assert result_set == truth_set
3737

3838

39-
def test_greedy_twoparty():
39+
@pytest.mark.parametrize("greedy_solve", [greedy_solve_native, greedy_solve_python])
40+
def test_greedy_twoparty(greedy_solve):
4041
candidates = [(.8, ((0, 0), (1, 0)))]
4142
result = greedy_solve(_zip_candidates(candidates))
42-
_compare_matching(result, [{(0,0), (1,0)}])
43+
_compare_matching(result, [{(0, 0), (1, 0)}])
4344

4445
candidates = [(.8, ((0, 0), (1, 0))),
4546
(.7, ((0, 1), (1, 0)))]
@@ -55,11 +56,12 @@ def test_greedy_twoparty():
5556
(.7, ((0, 1), (1, 0))),
5657
(.6, ((0, 1), (1, 1)))]
5758
result = greedy_solve(_zip_candidates(candidates))
58-
_compare_matching(result, [{(0,0), (1,0)},
59-
{(0,1), (1,1)}])
59+
_compare_matching(result, [{(0, 0), (1, 0)},
60+
{(0, 1), (1, 1)}])
6061

6162

62-
def test_greedy_threeparty():
63+
@pytest.mark.parametrize("greedy_solve", [greedy_solve_native, greedy_solve_python])
64+
def test_greedy_threeparty(greedy_solve):
6365
candidates = [(.9, ((1, 0), (2, 0))),
6466
(.8, ((0, 0), (1, 1))),
6567
(.8, ((0, 0), (2, 1))),
@@ -99,7 +101,8 @@ def test_greedy_threeparty():
99101
_compare_matching(result, [{(0,0), (1,0), (2,0), (3,0), (4,0)}])
100102

101103

102-
def test_greedy_fourparty():
104+
@pytest.mark.parametrize("greedy_solve", [greedy_solve_native, greedy_solve_python])
105+
def test_greedy_fourparty(greedy_solve):
103106
candidates = [(.9, ((0, 0), (1, 0))),
104107
(.9, ((2, 0), (3, 0))),
105108
(.7, ((0, 0), (2, 0))),
@@ -110,7 +113,8 @@ def test_greedy_fourparty():
110113
_compare_matching(result, [{(0,0), (1,0), (2,0), (3,0)}])
111114

112115

113-
def test_inconsistent_dataset_number():
116+
@pytest.mark.parametrize("greedy_solve", [greedy_solve_native, greedy_solve_python])
117+
def test_inconsistent_dataset_number(greedy_solve):
114118
candidates = (
115119
array('d', [.5]),
116120
(array('I', [3]), array('I', [4])),
@@ -119,6 +123,15 @@ def test_inconsistent_dataset_number():
119123
greedy_solve(candidates)
120124

121125

126+
@pytest.mark.parametrize("prob_greedy_solve", [probabilistic_greedy_solve_native, probabilistic_greedy_solve_python])
127+
def test_wrong_merge_threshold(prob_greedy_solve):
128+
candidates = [(.8, ((0, 0), (1, 0)))]
129+
with pytest.raises(ValueError):
130+
prob_greedy_solve(_zip_candidates(candidates), merge_threshold=-0.1)
131+
with pytest.raises(ValueError):
132+
prob_greedy_solve(_zip_candidates(candidates), merge_threshold=1.01)
133+
134+
122135
@pytest.mark.parametrize('datasets_n', [0, 1, 3, 5])
123136
def test_unsupported_shape(datasets_n):
124137
candidates = (
@@ -129,7 +142,8 @@ def test_unsupported_shape(datasets_n):
129142
greedy_solve(candidates)
130143

131144

132-
def test_inconsistent_entry_number():
145+
@pytest.mark.parametrize("greedy_solve", [greedy_solve_native, greedy_solve_python])
146+
def test_inconsistent_entry_number(greedy_solve):
133147
candidates = (
134148
array('d', [.5, .3]),
135149
(array('I', [3]), array('I', [4])),

0 commit comments

Comments
 (0)