Skip to content

Commit 99c57fd

Browse files
williamfisetclaude
andcommitted
Refactor ChineseRemainderTheorem: reuse PrimeFactorization, add tests
Remove duplicated primeFactorization, pollardRho, isPrime, and gcf methods in favor of the shared PrimeFactorization class. Simplify prime-power grouping loop and clean up Javadoc. Add 17 JUnit 5 tests covering eliminateCoefficient, crt, reduce, egcd, and full reduce+crt integration. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 92284a1 commit 99c57fd

3 files changed

Lines changed: 236 additions & 114 deletions

File tree

Lines changed: 58 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,14 @@
11
/**
2-
* Use the chinese remainder theorem to solve a set of congruence equations.
2+
* Solve a set of congruence equations using the Chinese Remainder Theorem.
33
*
4-
* <p>The first method (eliminateCoefficient) is used to reduce an equation of the form cx≡a(mod
5-
* m)cx≡a(mod m) to the form x≡a_new(mod m_new)x≡anew(mod m_new), which gets rids of the
6-
* coefficient. A value of null is returned if the coefficient cannot be eliminated.
4+
* <p>{@link #eliminateCoefficient(long, long, long)} reduces cx ≡ a (mod m) to x ≡ a' (mod m'),
5+
* removing the coefficient. Returns null if unsolvable.
76
*
8-
* <p>The second method (reduce) is used to reduce a set of equations so that the moduli become
9-
* pairwise co-prime (which means that we can apply the Chinese Remainder Theorem). The input and
10-
* output are of the form x≡a_0(mod m_0),...,x≡a_n−1(mod m_n−1)x≡a_0(mod m_0),...,x≡a_n−1(mod
11-
* m_n−1). Note that the number of equations may change during this process. A value of null is
12-
* returned if the set of equations cannot be reduced to co-prime moduli.
7+
* <p>{@link #reduce(long[], long[])} rewrites a system so that moduli become pairwise coprime,
8+
* enabling CRT. The number of equations may change. Returns null if the system is inconsistent.
139
*
14-
* <p>The third method (crt) is the actual Chinese Remainder Theorem. It assumes that all pairs of
15-
* moduli are co-prime to one another. This solves a set of equations of the form x≡a_0(mod
16-
* m_0),...,x≡v_n−1(mod m_n−1)x≡a_0(mod m_0),...,x≡v_n−1(mod m_n−1). It's output is of the form
17-
* x≡a_new(mod m_new)x≡a_new(mod m_new).
10+
* <p>{@link #crt(long[], long[])} is the core CRT step. It assumes all moduli are pairwise coprime
11+
* and returns x ≡ a' (mod M) where M is the product of all moduli.
1812
*
1913
* @author Micah Stairs
2014
*/
@@ -24,12 +18,16 @@
2418

2519
public class ChineseRemainderTheorem {
2620

27-
// eliminateCoefficient() takes cx≡a(mod m) and gives x≡a_new(mod m_new).
21+
/**
22+
* Reduces cx ≡ a (mod m) to x ≡ a' (mod m').
23+
*
24+
* @return {a', m'} or null if unsolvable.
25+
*/
2826
public static long[] eliminateCoefficient(long c, long a, long m) {
29-
3027
long d = egcd(c, m)[0];
3128

32-
if (a % d != 0) return null;
29+
if (a % d != 0)
30+
return null;
3331

3432
c /= d;
3533
a /= d;
@@ -42,36 +40,35 @@ public static long[] eliminateCoefficient(long c, long a, long m) {
4240
return new long[] {a, m};
4341
}
4442

45-
// reduce() takes a set of equations and reduces them to an equivalent
46-
// set with pairwise co-prime moduli (or null if not solvable).
43+
/**
44+
* Reduces a system x ≡ a[i] (mod m[i]) to an equivalent system with pairwise coprime moduli.
45+
*
46+
* @return {a[], m[]} with coprime moduli, or null if the system is inconsistent.
47+
*/
4748
public static long[][] reduce(long[] a, long[] m) {
49+
List<Long> aNew = new ArrayList<>();
50+
List<Long> mNew = new ArrayList<>();
4851

49-
List<Long> aNew = new ArrayList<Long>();
50-
List<Long> mNew = new ArrayList<Long>();
51-
52-
// Split up each equation into prime factors
52+
// Split each modulus into prime-power factors
5353
for (int i = 0; i < a.length; i++) {
54-
List<Long> factors = primeFactorization(m[i]);
54+
List<Long> factors = PrimeFactorization.primeFactorization(m[i]);
5555
Collections.sort(factors);
56-
ListIterator<Long> iterator = factors.listIterator();
57-
while (iterator.hasNext()) {
58-
long val = iterator.next();
59-
long total = val;
60-
while (iterator.hasNext()) {
61-
long nextVal = iterator.next();
62-
if (nextVal == val) {
63-
total *= val;
64-
} else {
65-
iterator.previous();
66-
break;
67-
}
56+
57+
int j = 0;
58+
while (j < factors.size()) {
59+
long p = factors.get(j);
60+
long pk = p;
61+
j++;
62+
while (j < factors.size() && factors.get(j) == p) {
63+
pk *= p;
64+
j++;
6865
}
69-
aNew.add(a[i] % total);
70-
mNew.add(total);
66+
aNew.add(a[i] % pk);
67+
mNew.add(pk);
7168
}
7269
}
7370

74-
// Throw away repeated information and look for conflicts
71+
// Remove redundant equations and detect conflicts
7572
for (int i = 0; i < aNew.size(); i++) {
7673
for (int j = i + 1; j < aNew.size(); j++) {
7774
if (mNew.get(i) % mNew.get(j) == 0 || mNew.get(j) % mNew.get(i) == 0) {
@@ -81,111 +78,58 @@ public static long[][] reduce(long[] a, long[] m) {
8178
mNew.remove(j);
8279
j--;
8380
continue;
84-
} else return null;
81+
} else
82+
return null;
8583
} else {
8684
if ((aNew.get(j) % mNew.get(i)) == aNew.get(i)) {
8785
aNew.remove(i);
8886
mNew.remove(i);
8987
i--;
9088
break;
91-
} else return null;
89+
} else
90+
return null;
9291
}
9392
}
9493
}
9594
}
9695

97-
// Put result into an array
9896
long[][] res = new long[2][aNew.size()];
9997
for (int i = 0; i < aNew.size(); i++) {
10098
res[0][i] = aNew.get(i);
10199
res[1][i] = mNew.get(i);
102100
}
103-
104101
return res;
105102
}
106103

104+
/**
105+
* Solves x ≡ a[i] (mod m[i]) assuming all moduli are pairwise coprime.
106+
*
107+
* @return {x, M} where M is the product of all moduli.
108+
*/
107109
public static long[] crt(long[] a, long[] m) {
108-
109110
long M = 1;
110-
for (int i = 0; i < m.length; i++) M *= m[i];
111-
112-
long[] inv = new long[a.length];
113-
for (int i = 0; i < inv.length; i++) inv[i] = egcd(M / m[i], m[i])[1];
111+
for (long mi : m)
112+
M *= mi;
114113

115114
long x = 0;
116115
for (int i = 0; i < m.length; i++) {
117-
x += (M / m[i]) * a[i] * inv[i]; // Overflow could occur here
116+
long Mi = M / m[i];
117+
long inv = egcd(Mi, m[i])[1];
118+
x += Mi * a[i] * inv;
118119
x = ((x % M) + M) % M;
119120
}
120121

121122
return new long[] {x, M};
122123
}
123124

124-
private static ArrayList<Long> primeFactorization(long n) {
125-
ArrayList<Long> factors = new ArrayList<Long>();
126-
if (n <= 0) throw new IllegalArgumentException();
127-
else if (n == 1) return factors;
128-
PriorityQueue<Long> divisorQueue = new PriorityQueue<Long>();
129-
divisorQueue.add(n);
130-
while (!divisorQueue.isEmpty()) {
131-
long divisor = divisorQueue.remove();
132-
if (isPrime(divisor)) {
133-
factors.add(divisor);
134-
continue;
135-
}
136-
long next_divisor = pollardRho(divisor);
137-
if (next_divisor == divisor) {
138-
divisorQueue.add(divisor);
139-
} else {
140-
divisorQueue.add(next_divisor);
141-
divisorQueue.add(divisor / next_divisor);
142-
}
143-
}
144-
return factors;
145-
}
146-
147-
private static long pollardRho(long n) {
148-
if (n % 2 == 0) return 2;
149-
// Get a number in the range [2, 10^6]
150-
long x = 2 + (long) (999999 * Math.random());
151-
long c = 2 + (long) (999999 * Math.random());
152-
long y = x;
153-
long d = 1;
154-
while (d == 1) {
155-
x = (x * x + c) % n;
156-
y = (y * y + c) % n;
157-
y = (y * y + c) % n;
158-
d = gcf(Math.abs(x - y), n);
159-
if (d == n) break;
160-
}
161-
return d;
162-
}
163-
164-
// Extended euclidean algorithm
165-
private static long[] egcd(long a, long b) {
166-
if (b == 0) return new long[] {a, 1, 0};
167-
else {
168-
long[] ret = egcd(b, a % b);
169-
long tmp = ret[1] - ret[2] * (a / b);
170-
ret[1] = ret[2];
171-
ret[2] = tmp;
172-
return ret;
173-
}
174-
}
175-
176-
private static long gcf(long a, long b) {
177-
return b == 0 ? a : gcf(b, a % b);
178-
}
179-
180-
private static boolean isPrime(long n) {
181-
if (n < 2) return false;
182-
if (n == 2 || n == 3) return true;
183-
if (n % 2 == 0 || n % 3 == 0) return false;
184-
185-
int limit = (int) Math.sqrt(n);
186-
187-
for (int i = 5; i <= limit; i += 6) if (n % i == 0 || n % (i + 2) == 0) return false;
188-
189-
return true;
125+
/** Extended Euclidean algorithm. Returns {gcd(a,b), x, y} where ax + by = gcd(a,b). */
126+
static long[] egcd(long a, long b) {
127+
if (b == 0)
128+
return new long[] {a, 1, 0};
129+
long[] ret = egcd(b, a % b);
130+
long tmp = ret[1] - ret[2] * (a / b);
131+
ret[1] = ret[2];
132+
ret[2] = tmp;
133+
return ret;
190134
}
191135
}

src/test/java/com/williamfiset/algorithms/math/BUILD

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,14 @@ java_test(
4747
runtime_deps = JUNIT5_RUNTIME_DEPS,
4848
deps = TEST_DEPS,
4949
)
50+
51+
# bazel test //src/test/java/com/williamfiset/algorithms/math:ChineseRemainderTheoremTest
52+
java_test(
53+
name = "ChineseRemainderTheoremTest",
54+
srcs = ["ChineseRemainderTheoremTest.java"],
55+
main_class = "org.junit.platform.console.ConsoleLauncher",
56+
use_testrunner = False,
57+
args = ["--select-class=com.williamfiset.algorithms.math.ChineseRemainderTheoremTest"],
58+
runtime_deps = JUNIT5_RUNTIME_DEPS,
59+
deps = TEST_DEPS,
60+
)

0 commit comments

Comments
 (0)