Skip to content

Commit 8e0709a

Browse files
committed
Address cconlon PR #334 review feedback
1 parent 655ea82 commit 8e0709a

8 files changed

Lines changed: 233 additions & 414 deletions

File tree

src/java/com/wolfssl/provider/jsse/WolfSSLEngine.java

Lines changed: 84 additions & 230 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ public class WolfSSLEngine extends SSLEngine {
137137

138138
/* Scratch buffer for ssl.read() plaintext. Reused across unwrap() calls
139139
* and expanded only when a larger output window requires it. */
140-
private byte[] recvAppDataBuf = new byte[16 * 1024];
140+
private byte[] recvAppDataBuf = new byte[WolfSSL.MAX_RECORD_SIZE];
141141

142142
/* TLS record header is: type(1) + version(2) + length(2) */
143143
private static final int TLS_RECORD_HEADER_LEN = 5;
@@ -146,7 +146,7 @@ public class WolfSSLEngine extends SSLEngine {
146146

147147
/* Default size of internalIOSendBuf, 16k to match TLS record size.
148148
* TODO - add upper bound on I/O send buf resize allocations. */
149-
private static final int INTERNAL_IOSEND_BUF_SZ = 16 * 1024;
149+
private static final int INTERNAL_IOSEND_BUF_SZ = WolfSSL.MAX_RECORD_SIZE;
150150
/* static buffer used to hold encrypted data to be sent, allocated inside
151151
* internalSendCb() and expanded only if needed. Synchronize on toSendLock
152152
* when accessing this buffer. */
@@ -374,8 +374,7 @@ private List<SNIServerName> parseRequestedServerNamesFromNetData() {
374374
}
375375

376376
if (ret == WolfSSL.NOT_COMPILED_IN) {
377-
return parseRequestedServerNamesFromTlsRecord(
378-
ByteBuffer.wrap(clientHello));
377+
return null;
379378
}
380379
} catch (IllegalArgumentException | WolfSSLException e) {
381380
WolfSSLDebug.log(getClass(), WolfSSLDebug.INFO,
@@ -386,184 +385,6 @@ private List<SNIServerName> parseRequestedServerNamesFromNetData() {
386385
return null;
387386
}
388387

389-
private static List<SNIServerName> parseRequestedServerNamesFromTlsRecord(
390-
ByteBuffer in) {
391-
392-
int recType;
393-
int recLen;
394-
int hsType;
395-
int hsLen;
396-
ByteBuffer recBody;
397-
ByteBuffer hello;
398-
399-
if (in.remaining() < 5) {
400-
return null;
401-
}
402-
403-
recType = getU8(in);
404-
skipU16(in);
405-
recLen = getU16(in);
406-
if (recType != 22 || recLen > in.remaining()) {
407-
return null;
408-
}
409-
410-
recBody = sliceBytes(in, recLen);
411-
if (recBody.remaining() < 4) {
412-
return null;
413-
}
414-
415-
hsType = getU8(recBody);
416-
hsLen = getU24(recBody);
417-
if (hsType != 1 || hsLen > recBody.remaining()) {
418-
return null;
419-
}
420-
421-
hello = sliceBytes(recBody, hsLen);
422-
return parseSniExtensionFromClientHello(hello);
423-
}
424-
425-
private static List<SNIServerName> parseSniExtensionFromClientHello(
426-
ByteBuffer in) {
427-
428-
int sidLen;
429-
int cipherLen;
430-
int compLen;
431-
int extLen;
432-
List<SNIServerName> names;
433-
434-
if (in.remaining() < 34) {
435-
return null;
436-
}
437-
438-
skipU16(in);
439-
skipBytes(in, 32);
440-
441-
sidLen = getU8(in);
442-
skipBytes(in, sidLen);
443-
444-
cipherLen = getU16(in);
445-
skipBytes(in, cipherLen);
446-
447-
compLen = getU8(in);
448-
skipBytes(in, compLen);
449-
450-
if (in.remaining() < 2) {
451-
return null;
452-
}
453-
454-
extLen = getU16(in);
455-
if (extLen > in.remaining()) {
456-
return null;
457-
}
458-
459-
names = parseServerNameExtensionList(sliceBytes(in, extLen));
460-
if (names == null || names.isEmpty()) {
461-
return null;
462-
}
463-
464-
return names;
465-
}
466-
467-
private static List<SNIServerName> parseServerNameExtensionList(
468-
ByteBuffer extBuf) {
469-
470-
while (extBuf.remaining() >= 4) {
471-
int extType = getU16(extBuf);
472-
int extLen = getU16(extBuf);
473-
ByteBuffer extData = sliceBytes(extBuf, extLen);
474-
475-
if (extType == 0) {
476-
return parseServerNameList(extData);
477-
}
478-
}
479-
480-
return null;
481-
}
482-
483-
private static List<SNIServerName> parseServerNameList(ByteBuffer in) {
484-
int listLen;
485-
List<SNIServerName> names = new ArrayList<SNIServerName>(1);
486-
487-
if (in.remaining() < 2) {
488-
return null;
489-
}
490-
491-
listLen = getU16(in);
492-
if (listLen > in.remaining()) {
493-
return null;
494-
}
495-
496-
in = sliceBytes(in, listLen);
497-
while (in.remaining() >= 3) {
498-
int nameType = getU8(in);
499-
int nameLen = getU16(in);
500-
byte[] name = getBytes(in, nameLen);
501-
502-
if (nameType == WolfSSL.WOLFSSL_SNI_HOST_NAME) {
503-
names.add(new SNIHostName(name));
504-
}
505-
}
506-
507-
return names;
508-
}
509-
510-
private static int getU8(ByteBuffer in) {
511-
if (in.remaining() < 1) {
512-
throw new IllegalArgumentException("short TLS field");
513-
}
514-
return in.get() & 0xFF;
515-
}
516-
517-
private static int getU16(ByteBuffer in) {
518-
if (in.remaining() < 2) {
519-
throw new IllegalArgumentException("short TLS field");
520-
}
521-
return ((in.get() & 0xFF) << 8) | (in.get() & 0xFF);
522-
}
523-
524-
private static int getU24(ByteBuffer in) {
525-
if (in.remaining() < 3) {
526-
throw new IllegalArgumentException("short TLS field");
527-
}
528-
return ((in.get() & 0xFF) << 16) |
529-
((in.get() & 0xFF) << 8) |
530-
(in.get() & 0xFF);
531-
}
532-
533-
private static void skipU16(ByteBuffer in) {
534-
skipBytes(in, 2);
535-
}
536-
537-
private static void skipBytes(ByteBuffer in, int len) {
538-
if (len < 0 || len > in.remaining()) {
539-
throw new IllegalArgumentException("invalid TLS length");
540-
}
541-
in.position(in.position() + len);
542-
}
543-
544-
private static ByteBuffer sliceBytes(ByteBuffer in, int len) {
545-
ByteBuffer out;
546-
547-
if (len < 0 || len > in.remaining()) {
548-
throw new IllegalArgumentException("invalid TLS length");
549-
}
550-
551-
out = in.slice();
552-
out.limit(len);
553-
in.position(in.position() + len);
554-
return out;
555-
}
556-
557-
private static byte[] getBytes(ByteBuffer in, int len) {
558-
if (len < 0 || len > in.remaining()) {
559-
throw new IllegalArgumentException("invalid TLS length");
560-
}
561-
562-
byte[] out = new byte[len];
563-
in.get(out);
564-
return out;
565-
}
566-
567388
private void clearPendingAppData() {
568389
if (this.pendingAppData != null) {
569390
Arrays.fill(this.pendingAppData, (byte)0);
@@ -1280,22 +1101,41 @@ private synchronized int RecvAppData(ByteBuffer[] out, int ofst, int length)
12801101
/* Calculate maximum output size across ByteBuffer arrays */
12811102
maxOutSz = getTotalOutputSize(out, ofst, length);
12821103

1283-
/* Read into intermediate buffer to detect
1284-
* BUFFER_OVERFLOW before writing to output */
1285-
int readSz = 16384; /* default max TLS record plaintext */
1286-
if (readSz < maxOutSz) {
1287-
readSz = maxOutSz;
1104+
/* Fast path: single output buffer large enough to hold a full
1105+
* TLS record. Read directly into ByteBuffer to avoid extra
1106+
* allocation and copy (matches original optimization). */
1107+
boolean directRead = (length == 1 && out[ofst] != null &&
1108+
out[ofst].remaining() >= WolfSSL.MAX_RECORD_SIZE);
1109+
1110+
byte[] tmp = null;
1111+
if (!directRead) {
1112+
/* Intermediate buffer path: needed for BUFFER_OVERFLOW
1113+
* detection and multi-buffer scatter writes */
1114+
int readSz = WolfSSL.MAX_RECORD_SIZE;
1115+
if (readSz < maxOutSz) {
1116+
readSz = maxOutSz;
1117+
}
1118+
tmp = getRecvAppDataBuf(readSz);
12881119
}
1289-
byte[] tmp = getRecvAppDataBuf(readSz);
12901120

12911121
synchronized (ioLock) {
12921122
try {
1293-
ret = this.ssl.read(tmp, readSz);
1123+
if (directRead) {
1124+
ret = this.ssl.read(out[ofst], maxOutSz, 0);
1125+
}
1126+
else {
1127+
ret = this.ssl.read(tmp, tmp.length);
1128+
}
12941129
if ((ret < 0) &&
12951130
(ssl.getError(ret) == WolfSSL.APP_DATA_READY)) {
12961131
/* If DTLS, we may need to call SSL_read() again
12971132
* right away again if app data was received */
1298-
ret = this.ssl.read(tmp, readSz);
1133+
if (directRead) {
1134+
ret = this.ssl.read(out[ofst], maxOutSz, 0);
1135+
}
1136+
else {
1137+
ret = this.ssl.read(tmp, tmp.length);
1138+
}
12991139
}
13001140
} catch (SocketTimeoutException | SocketException e) {
13011141
throw new SSLException(e);
@@ -1376,37 +1216,43 @@ private synchronized int RecvAppData(ByteBuffer[] out, int ofst, int length)
13761216
}
13771217
}
13781218
else {
1379-
if (ret > maxOutSz) {
1380-
/* Output too small, stash for next unwrap().
1381-
* Caller returns BUFFER_OVERFLOW. */
1382-
this.pendingAppData = new byte[ret];
1383-
System.arraycopy(tmp, 0, this.pendingAppData, 0, ret);
1384-
this.pendingAppDataLen = ret;
1385-
this.pendingNetConsumed = 0;
1386-
return 0; /* 0 bytes written to output */
1387-
}
1388-
1389-
/* Copy from intermediate buffer to output buffers */
1390-
for (i = 0; i < ret;) {
1391-
if (idx + ofst >= length) {
1392-
/* no more output buffers left */
1393-
break;
1219+
if (directRead) {
1220+
/* Data already in output buffer, just record count */
1221+
totalRead = ret;
1222+
}
1223+
else {
1224+
if (ret > maxOutSz) {
1225+
/* Output too small, stash for next unwrap().
1226+
* Caller returns BUFFER_OVERFLOW. */
1227+
this.pendingAppData = new byte[ret];
1228+
System.arraycopy(tmp, 0,
1229+
this.pendingAppData, 0, ret);
1230+
this.pendingAppDataLen = ret;
1231+
this.pendingNetConsumed = 0;
1232+
return 0; /* 0 bytes written to output */
13941233
}
13951234

1396-
bufSpace = out[idx + ofst].remaining();
1397-
if (bufSpace == 0) {
1398-
/* no more space in current out buffer, advance */
1399-
idx++;
1400-
continue;
1401-
}
1235+
/* Copy from intermediate buffer to output bufs */
1236+
for (i = 0; i < ret;) {
1237+
if (idx + ofst >= length) {
1238+
break;
1239+
}
14021240

1403-
sz = (bufSpace >= (ret - i)) ? (ret - i) : bufSpace;
1404-
out[idx + ofst].put(tmp, i, sz);
1405-
i += sz;
1406-
totalRead += sz;
1241+
bufSpace = out[idx + ofst].remaining();
1242+
if (bufSpace == 0) {
1243+
idx++;
1244+
continue;
1245+
}
1246+
1247+
sz = (bufSpace >= (ret - i)) ?
1248+
(ret - i) : bufSpace;
1249+
out[idx + ofst].put(tmp, i, sz);
1250+
i += sz;
1251+
totalRead += sz;
14071252

1408-
if ((ret - i) > 0) {
1409-
idx++; /* go to next output buffer */
1253+
if ((ret - i) > 0) {
1254+
idx++;
1255+
}
14101256
}
14111257
}
14121258
}
@@ -1416,6 +1262,7 @@ private synchronized int RecvAppData(ByteBuffer[] out, int ofst, int length)
14161262

14171263
private byte[] getRecvAppDataBuf(int minSz) {
14181264
if (this.recvAppDataBuf.length < minSz) {
1265+
Arrays.fill(this.recvAppDataBuf, (byte)0);
14191266
this.recvAppDataBuf = new byte[minSz];
14201267
}
14211268

@@ -1592,27 +1439,33 @@ else if (hs == SSLEngineResult.HandshakeStatus.NEED_WRAP &&
15921439
ret = DoHandshake(false);
15931440
}
15941441
else {
1595-
/* TLS-only fast path: peek at the record header and
1596-
* return BUFFER_UNDERFLOW before a JNI read when the
1597-
* full record is not present. DTLS still relies on the
1598-
* native WANT_READ path as the correctness fallback. */
1442+
/* TLS-only: peek at the record header and
1443+
* return BUFFER_UNDERFLOW before calling into
1444+
* JNI when the full record is not yet present.
1445+
* Without this, native wolfSSL consumes partial
1446+
* record bytes via the I/O callback, violating
1447+
* the JSSE contract that BUFFER_UNDERFLOW must
1448+
* report bytesConsumed() == 0. DTLS still
1449+
* relies on the native WANT_READ path. */
15991450
boolean bufferUnderflow = false;
1600-
if (inRemaining > 0 && (this.ssl.dtls() == 0)) {
1451+
if (inRemaining > 0 &&
1452+
(this.ssl.dtls() == 0)) {
16011453
synchronized (netDataLock) {
16021454
int pos = in.position();
1603-
if (inRemaining < TLS_RECORD_HEADER_LEN) {
1604-
/* Not enough for TLS record header */
1455+
if (inRemaining <
1456+
TLS_RECORD_HEADER_LEN) {
16051457
bufferUnderflow = true;
16061458
} else {
1607-
/* Peek at record length from header
1608-
* bytes 3-4 (big-endian) */
16091459
int recLen =
1610-
((in.get(pos + TLS_RECORD_LEN_HI_OFF)
1460+
((in.get(pos +
1461+
TLS_RECORD_LEN_HI_OFF)
16111462
& 0xFF) << 8)
1612-
| (in.get(pos + TLS_RECORD_LEN_LO_OFF)
1463+
| (in.get(pos +
1464+
TLS_RECORD_LEN_LO_OFF)
16131465
& 0xFF);
16141466
if (inRemaining <
1615-
TLS_RECORD_HEADER_LEN + recLen) {
1467+
TLS_RECORD_HEADER_LEN +
1468+
recLen) {
16161469
bufferUnderflow = true;
16171470
}
16181471
}
@@ -1656,7 +1509,8 @@ else if (hs == SSLEngineResult.HandshakeStatus.NEED_WRAP &&
16561509
}
16571510
}
16581511
else if (bufferUnderflow) {
1659-
status = SSLEngineResult.Status.BUFFER_UNDERFLOW;
1512+
status =
1513+
SSLEngineResult.Status.BUFFER_UNDERFLOW;
16601514
}
16611515
/* If we have input data, make sure output buffer
16621516
* length is greater than zero, otherwise ask app to

0 commit comments

Comments
 (0)