Skip to content

Commit 6d92a19

Browse files
committed
More ML-KEM cast fixes
1 parent ee914d9 commit 6d92a19

1 file changed

Lines changed: 61 additions & 71 deletions

File tree

wolfcrypt/src/wc_mlkem_poly.c

Lines changed: 61 additions & 71 deletions
Original file line number | Diff line number | Diff line change
@@ -1816,7 +1816,7 @@ static void mlkem_keygen_c(sword16* s, sword16* t, sword16* e, const sword16* a,
18161816
/* For each polynomial in the vectors.
18171817
* Step 17, Step 18: Calculate public from A_hat, s_hat and e_hat. */
18181818
for (i = 0; i < k; ++i) {
1819-
unsigned int j;
1819+
int j;
18201820

18211821
/* Multiply a by private into public polynomial.
18221822
* Step 18: ... A_hat o s_hat ... */
@@ -1825,8 +1825,8 @@ static void mlkem_keygen_c(sword16* s, sword16* t, sword16* e, const sword16* a,
18251825
/* Convert public polynomial to Montgomery form.
18261826
* Step 18: ... MontRed(A_hat o s_hat) ... */
18271827
for (j = 0; j < MLKEM_N; ++j) {
1828-
sword32 n = t[(unsigned int)i * MLKEM_N + j] * (sword32)MLKEM_F;
1829-
t[(unsigned int)i * MLKEM_N + j] = MLKEM_MONT_RED(n);
1828+
sword32 n = t[i * MLKEM_N + j] * (sword32)MLKEM_F;
1829+
t[i * MLKEM_N + j] = MLKEM_MONT_RED(n);
18301830
}
18311831
/* Transform error values polynomial.
18321832
* Step 17: e_hat = NTT(e) */
@@ -1835,9 +1835,8 @@ static void mlkem_keygen_c(sword16* s, sword16* t, sword16* e, const sword16* a,
18351835
/* Add errors to public key and reduce.
18361836
* Step 18: t_hat = BarrettRed(MontRed(A_hat o s_hat) + e_hat) */
18371837
for (j = 0; j < MLKEM_N; ++j) {
1838-
sword16 n = (sword16)(t[(unsigned int)i * MLKEM_N + j] +
1839-
e[(unsigned int)i * MLKEM_N + j]);
1840-
t[(unsigned int)i * MLKEM_N + j] = MLKEM_BARRETT_RED(n);
1838+
sword16 n = (sword16)(t[i * MLKEM_N + j] + e[i * MLKEM_N + j]);
1839+
t[i * MLKEM_N + j] = MLKEM_BARRETT_RED(n);
18411840
}
18421841
#else
18431842
/* Add errors to public key and reduce.
@@ -1919,7 +1918,7 @@ int mlkem_keygen_seeds(sword16* s, sword16* t, MLKEM_PRF_T* prf,
19191918
/* For each polynomial in the vectors.
19201919
* Step 17, Step 18: Calculate public from A_hat, s_hat and e_hat. */
19211920
for (i = 0; i < k; ++i) {
1922-
unsigned int j;
1921+
int j;
19231922

19241923
/* Generate a vector of matrix A.
19251924
* Steps 4-6: generate A[i] */
@@ -1934,8 +1933,8 @@ int mlkem_keygen_seeds(sword16* s, sword16* t, MLKEM_PRF_T* prf,
19341933
/* Convert public polynomial to Montgomery form.
19351934
* Step 18: ... MontRed(A_hat o s_hat) ... */
19361935
for (j = 0; j < MLKEM_N; ++j) {
1937-
sword32 n = t[(unsigned int)i * MLKEM_N + j] * (sword32)MLKEM_F;
1938-
t[(unsigned int)i * MLKEM_N + j] = MLKEM_MONT_RED(n);
1936+
sword32 n = t[i * MLKEM_N + j] * (sword32)MLKEM_F;
1937+
t[i * MLKEM_N + j] = MLKEM_MONT_RED(n);
19391938
}
19401939

19411940
/* Generate noise using PRF.
@@ -1951,8 +1950,8 @@ int mlkem_keygen_seeds(sword16* s, sword16* t, MLKEM_PRF_T* prf,
19511950
/* Add errors to public key and reduce.
19521951
* Step 18: t_hat = BarrettRed(MontRed(A_hat o s_hat) + e_hat) */
19531952
for (j = 0; j < MLKEM_N; ++j) {
1954-
sword16 n = (sword16)(t[(unsigned int)i * MLKEM_N + j] + e[j]);
1955-
t[(unsigned int)i * MLKEM_N + j] = MLKEM_BARRETT_RED(n);
1953+
sword16 n = (sword16)(t[i * MLKEM_N + j] + e[j]);
1954+
t[i * MLKEM_N + j] = MLKEM_BARRETT_RED(n);
19561955
}
19571956
#else
19581957
/* Add errors to public key and reduce.
@@ -1996,7 +1995,7 @@ static void mlkem_encapsulate_c(const sword16* pub, sword16* u, sword16* v,
19961995

19971996
/* For each polynomial in the vectors. */
19981997
for (i = 0; i < k; ++i) {
1999-
unsigned int j;
1998+
int j;
20001999

20012000
/* Multiply at by y into u polynomial. */
20022001
mlkem_pointwise_acc_mont(u + i * MLKEM_N, a + i * k * MLKEM_N, y,
@@ -2006,36 +2005,35 @@ static void mlkem_encapsulate_c(const sword16* pub, sword16* u, sword16* v,
20062005
/* Add errors to u and reduce. */
20072006
#if defined(WOLFSSL_MLKEM_SMALL) || defined(WOLFSSL_MLKEM_NO_LARGE_CODE)
20082007
for (j = 0; j < MLKEM_N; ++j) {
2009-
sword16 t = (sword16)(u[(unsigned int)i * MLKEM_N + j] +
2010-
e1[(unsigned int)i * MLKEM_N + j]);
2011-
u[(unsigned int)i * MLKEM_N + j] = MLKEM_BARRETT_RED(t);
2008+
sword16 t = (sword16)(u[i * MLKEM_N + j] + e1[i * MLKEM_N + j]);
2009+
u[i * MLKEM_N + j] = MLKEM_BARRETT_RED(t);
20122010
}
20132011
#else
20142012
for (j = 0; j < MLKEM_N; j += 8) {
2015-
sword16 t0 = (sword16)(u[(unsigned int)i * MLKEM_N + j + 0] +
2016-
e1[(unsigned int)i * MLKEM_N + j + 0]);
2017-
sword16 t1 = (sword16)(u[(unsigned int)i * MLKEM_N + j + 1] +
2018-
e1[(unsigned int)i * MLKEM_N + j + 1]);
2019-
sword16 t2 = (sword16)(u[(unsigned int)i * MLKEM_N + j + 2] +
2020-
e1[(unsigned int)i * MLKEM_N + j + 2]);
2021-
sword16 t3 = (sword16)(u[(unsigned int)i * MLKEM_N + j + 3] +
2022-
e1[(unsigned int)i * MLKEM_N + j + 3]);
2023-
sword16 t4 = (sword16)(u[(unsigned int)i * MLKEM_N + j + 4] +
2024-
e1[(unsigned int)i * MLKEM_N + j + 4]);
2025-
sword16 t5 = (sword16)(u[(unsigned int)i * MLKEM_N + j + 5] +
2026-
e1[(unsigned int)i * MLKEM_N + j + 5]);
2027-
sword16 t6 = (sword16)(u[(unsigned int)i * MLKEM_N + j + 6] +
2028-
e1[(unsigned int)i * MLKEM_N + j + 6]);
2029-
sword16 t7 = (sword16)(u[(unsigned int)i * MLKEM_N + j + 7] +
2030-
e1[(unsigned int)i * MLKEM_N + j + 7]);
2031-
u[(unsigned int)i * MLKEM_N + j + 0] = MLKEM_BARRETT_RED(t0);
2032-
u[(unsigned int)i * MLKEM_N + j + 1] = MLKEM_BARRETT_RED(t1);
2033-
u[(unsigned int)i * MLKEM_N + j + 2] = MLKEM_BARRETT_RED(t2);
2034-
u[(unsigned int)i * MLKEM_N + j + 3] = MLKEM_BARRETT_RED(t3);
2035-
u[(unsigned int)i * MLKEM_N + j + 4] = MLKEM_BARRETT_RED(t4);
2036-
u[(unsigned int)i * MLKEM_N + j + 5] = MLKEM_BARRETT_RED(t5);
2037-
u[(unsigned int)i * MLKEM_N + j + 6] = MLKEM_BARRETT_RED(t6);
2038-
u[(unsigned int)i * MLKEM_N + j + 7] = MLKEM_BARRETT_RED(t7);
2013+
sword16 t0 = (sword16)(u[i * MLKEM_N + j + 0] +
2014+
e1[i * MLKEM_N + j + 0]);
2015+
sword16 t1 = (sword16)(u[i * MLKEM_N + j + 1] +
2016+
e1[i * MLKEM_N + j + 1]);
2017+
sword16 t2 = (sword16)(u[i * MLKEM_N + j + 2] +
2018+
e1[i * MLKEM_N + j + 2]);
2019+
sword16 t3 = (sword16)(u[i * MLKEM_N + j + 3] +
2020+
e1[i * MLKEM_N + j + 3]);
2021+
sword16 t4 = (sword16)(u[i * MLKEM_N + j + 4] +
2022+
e1[i * MLKEM_N + j + 4]);
2023+
sword16 t5 = (sword16)(u[i * MLKEM_N + j + 5] +
2024+
e1[i * MLKEM_N + j + 5]);
2025+
sword16 t6 = (sword16)(u[i * MLKEM_N + j + 6] +
2026+
e1[i * MLKEM_N + j + 6]);
2027+
sword16 t7 = (sword16)(u[i * MLKEM_N + j + 7] +
2028+
e1[i * MLKEM_N + j + 7]);
2029+
u[i * MLKEM_N + j + 0] = MLKEM_BARRETT_RED(t0);
2030+
u[i * MLKEM_N + j + 1] = MLKEM_BARRETT_RED(t1);
2031+
u[i * MLKEM_N + j + 2] = MLKEM_BARRETT_RED(t2);
2032+
u[i * MLKEM_N + j + 3] = MLKEM_BARRETT_RED(t3);
2033+
u[i * MLKEM_N + j + 4] = MLKEM_BARRETT_RED(t4);
2034+
u[i * MLKEM_N + j + 5] = MLKEM_BARRETT_RED(t5);
2035+
u[i * MLKEM_N + j + 6] = MLKEM_BARRETT_RED(t6);
2036+
u[i * MLKEM_N + j + 7] = MLKEM_BARRETT_RED(t7);
20392037
}
20402038
#endif
20412039
}
@@ -2112,7 +2110,7 @@ int mlkem_encapsulate_seeds(const sword16* pub, MLKEM_PRF_T* prf, sword16* u,
21122110

21132111
/* For each polynomial in the vectors. */
21142112
for (i = 0; i < k; ++i) {
2115-
unsigned int j;
2113+
int j;
21162114

21172115
/* Generate a vector of matrix A. */
21182116
ret = mlkem_gen_matrix_i(prf, a, k, seed, i, 1);
@@ -2133,35 +2131,27 @@ int mlkem_encapsulate_seeds(const sword16* pub, MLKEM_PRF_T* prf, sword16* u,
21332131
/* Add errors to u and reduce. */
21342132
#if defined(WOLFSSL_MLKEM_SMALL) || defined(WOLFSSL_MLKEM_NO_LARGE_CODE)
21352133
for (j = 0; j < MLKEM_N; ++j) {
2136-
sword16 t = (sword16)(u[(unsigned int)i * MLKEM_N + j] + e1[j]);
2137-
u[(unsigned int)i * MLKEM_N + j] = MLKEM_BARRETT_RED(t);
2134+
sword16 t = (sword16)(u[i * MLKEM_N + j] + e1[j]);
2135+
u[i * MLKEM_N + j] = MLKEM_BARRETT_RED(t);
21382136
}
21392137
#else
21402138
for (j = 0; j < MLKEM_N; j += 8) {
2141-
sword16 t0 = (sword16)(u[(unsigned int)i * MLKEM_N + j + 0] +
2142-
e1[j + 0]);
2143-
sword16 t1 = (sword16)(u[(unsigned int)i * MLKEM_N + j + 1] +
2144-
e1[j + 1]);
2145-
sword16 t2 = (sword16)(u[(unsigned int)i * MLKEM_N + j + 2] +
2146-
e1[j + 2]);
2147-
sword16 t3 = (sword16)(u[(unsigned int)i * MLKEM_N + j + 3] +
2148-
e1[j + 3]);
2149-
sword16 t4 = (sword16)(u[(unsigned int)i * MLKEM_N + j + 4] +
2150-
e1[j + 4]);
2151-
sword16 t5 = (sword16)(u[(unsigned int)i * MLKEM_N + j + 5] +
2152-
e1[j + 5]);
2153-
sword16 t6 = (sword16)(u[(unsigned int)i * MLKEM_N + j + 6] +
2154-
e1[j + 6]);
2155-
sword16 t7 = (sword16)(u[(unsigned int)i * MLKEM_N + j + 7] +
2156-
e1[j + 7]);
2157-
u[(unsigned int)i * MLKEM_N + j + 0] = MLKEM_BARRETT_RED(t0);
2158-
u[(unsigned int)i * MLKEM_N + j + 1] = MLKEM_BARRETT_RED(t1);
2159-
u[(unsigned int)i * MLKEM_N + j + 2] = MLKEM_BARRETT_RED(t2);
2160-
u[(unsigned int)i * MLKEM_N + j + 3] = MLKEM_BARRETT_RED(t3);
2161-
u[(unsigned int)i * MLKEM_N + j + 4] = MLKEM_BARRETT_RED(t4);
2162-
u[(unsigned int)i * MLKEM_N + j + 5] = MLKEM_BARRETT_RED(t5);
2163-
u[(unsigned int)i * MLKEM_N + j + 6] = MLKEM_BARRETT_RED(t6);
2164-
u[(unsigned int)i * MLKEM_N + j + 7] = MLKEM_BARRETT_RED(t7);
2139+
sword16 t0 = (sword16)(u[i * MLKEM_N + j + 0] + e1[j + 0]);
2140+
sword16 t1 = (sword16)(u[i * MLKEM_N + j + 1] + e1[j + 1]);
2141+
sword16 t2 = (sword16)(u[i * MLKEM_N + j + 2] + e1[j + 2]);
2142+
sword16 t3 = (sword16)(u[i * MLKEM_N + j + 3] + e1[j + 3]);
2143+
sword16 t4 = (sword16)(u[i * MLKEM_N + j + 4] + e1[j + 4]);
2144+
sword16 t5 = (sword16)(u[i * MLKEM_N + j + 5] + e1[j + 5]);
2145+
sword16 t6 = (sword16)(u[i * MLKEM_N + j + 6] + e1[j + 6]);
2146+
sword16 t7 = (sword16)(u[i * MLKEM_N + j + 7] + e1[j + 7]);
2147+
u[i * MLKEM_N + j + 0] = MLKEM_BARRETT_RED(t0);
2148+
u[i * MLKEM_N + j + 1] = MLKEM_BARRETT_RED(t1);
2149+
u[i * MLKEM_N + j + 2] = MLKEM_BARRETT_RED(t2);
2150+
u[i * MLKEM_N + j + 3] = MLKEM_BARRETT_RED(t3);
2151+
u[i * MLKEM_N + j + 4] = MLKEM_BARRETT_RED(t4);
2152+
u[i * MLKEM_N + j + 5] = MLKEM_BARRETT_RED(t5);
2153+
u[i * MLKEM_N + j + 6] = MLKEM_BARRETT_RED(t6);
2154+
u[i * MLKEM_N + j + 7] = MLKEM_BARRETT_RED(t7);
21652155
}
21662156
#endif
21672157
}
@@ -2444,11 +2434,11 @@ static int mlkem_gen_matrix_k3_avx2(sword16* a, byte* seed, int transposed)
24442434
for (k = 0; k < 2; k++) {
24452435
for (i = 0; i < 4; i++) {
24462436
if (!transposed) {
2447-
state[4*4 + i] = (word64)(0x1f0000 + (((k*4+i)/3) << 8) +
2437+
state[4*4 + i] = (word32)(0x1f0000 + (((k*4+i)/3) << 8) +
24482438
((k*4+i)%3));
24492439
}
24502440
else {
2451-
state[4*4 + i] = (word64)(0x1f0000 + (((k*4+i)%3) << 8) +
2441+
state[4*4 + i] = (word32)(0x1f0000 + (((k*4+i)%3) << 8) +
24522442
((k*4+i)/3));
24532443

24542444
}
@@ -2600,10 +2590,10 @@ static int mlkem_gen_matrix_k4_avx2(sword16* a, byte* seed, int transposed)
26002590
for (k = 0; k < 4; k++) {
26012591
for (i = 0; i < 4; i++) {
26022592
if (!transposed) {
2603-
state[4*4 + i] = (word64)(0x1f0000 + (k << 8) + i);
2593+
state[4*4 + i] = (word32)(0x1f0000 + (k << 8) + i);
26042594
}
26052595
else {
2606-
state[4*4 + i] = (word64)(0x1f0000 + (i << 8) + k);
2596+
state[4*4 + i] = (word32)(0x1f0000 + (i << 8) + k);
26072597
}
26082598
}
26092599

@@ -4116,7 +4106,7 @@ static void mlkem_get_noise_x4_eta2_avx2(byte* rand, byte* seed, byte o)
41164106
word64 state[25 * 4];
41174107

41184108
for (i = 0; i < 4; i++) {
4119-
state[4*4 + i] = (word64)(0x1f00 + i + o);
4109+
state[4*4 + i] = (word32)(0x1f00 + i + o);
41204110
}
41214111

41224112
sha3_256_blocksx4_seed_avx2(state, seed);

0 commit comments

Comments
 (0)