Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions native/com_wolfssl_WolfSSL.c
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ jmethodID g_bufferPositionMethodId = NULL;
jmethodID g_bufferLimitMethodId = NULL;
jmethodID g_bufferHasArrayMethodId = NULL;
jmethodID g_bufferArrayMethodId = NULL;
jmethodID g_bufferArrayOffsetMethodId = NULL;
jmethodID g_bufferSetPositionMethodId = NULL;
jmethodID g_verifyCallbackMethodId = NULL;

Expand Down Expand Up @@ -185,6 +186,12 @@ JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void* reserved)
return JNI_ERR;
}

g_bufferArrayOffsetMethodId = (*env)->GetMethodID(env, byteBufferClass,
"arrayOffset", "()I");
if (g_bufferArrayOffsetMethodId == NULL) {
return JNI_ERR;
}

g_bufferSetPositionMethodId = (*env)->GetMethodID(env, byteBufferClass,
"position", "(I)Ljava/nio/Buffer;");
if (g_bufferSetPositionMethodId == NULL) {
Expand Down Expand Up @@ -236,6 +243,7 @@ JNIEXPORT void JNICALL JNI_OnUnload(JavaVM* vm, void* reserved)
g_bufferLimitMethodId = NULL;
g_bufferHasArrayMethodId = NULL;
g_bufferArrayMethodId = NULL;
g_bufferArrayOffsetMethodId = NULL;
g_bufferSetPositionMethodId = NULL;
g_verifyCallbackMethodId = NULL;
}
Expand Down
28 changes: 24 additions & 4 deletions native/com_wolfssl_WolfSSLSession.c
Original file line number Diff line number Diff line change
Expand Up @@ -1303,6 +1303,7 @@ JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_write__JLjava_nio_ByteBuf
int ret = BAD_FUNC_ARG;
int maxInputSz;
int inSz = length;
int arrayOffset = 0;
byte* data = NULL;
WOLFSSL* ssl = (WOLFSSL*)(uintptr_t)sslPtr;
jbyteArray bufArr = NULL;
Expand Down Expand Up @@ -1333,6 +1334,15 @@ JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_write__JLjava_nio_ByteBuf
return SSL_FAILURE;
}

/* Honor arrayOffset() for sliced/duplicated array-backed
* ByteBuffers, where logical position 0 maps to backing
* array index arrayOffset() */
arrayOffset = (int)(*jenv)->CallIntMethod(jenv, buf,
g_bufferArrayOffsetMethodId);
if ((*jenv)->ExceptionCheck(jenv)) {
return SSL_FAILURE;
}

/* Get array elements */
data = (byte *)(*jenv)->GetByteArrayElements(jenv, bufArr, NULL);
if (data == NULL) {
Expand All @@ -1356,8 +1366,8 @@ JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_write__JLjava_nio_ByteBuf
}
}

ret = SSLWriteNonblockingWithSelectPoll(ssl, data + position,
(int)inSz, (int)timeout);
ret = SSLWriteNonblockingWithSelectPoll(ssl,
data + arrayOffset + position, (int)inSz, (int)timeout);

/* release memory if using array mode */
if (hasArray) {
Expand Down Expand Up @@ -1530,6 +1540,7 @@ JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_read__JLjava_nio_ByteBuff
int size = 0;
int maxOutputSz;
int outSz = length;
int arrayOffset = 0;
byte* data = NULL;
WOLFSSL* ssl = (WOLFSSL*)(uintptr_t)sslPtr;
jbyteArray bufArr = NULL;
Expand Down Expand Up @@ -1560,6 +1571,15 @@ JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_read__JLjava_nio_ByteBuff
return SSL_FAILURE;
}

/* Honor arrayOffset() for sliced/duplicated array-backed
* ByteBuffers, where logical position 0 maps to backing
* array index arrayOffset() */
arrayOffset = (int)(*jenv)->CallIntMethod(jenv, buf,
g_bufferArrayOffsetMethodId);
if ((*jenv)->ExceptionCheck(jenv)) {
return SSL_FAILURE;
}

/* Get array elements */
data = (byte *)(*jenv)->GetByteArrayElements(jenv, bufArr, NULL);
if (data == NULL) {
Expand All @@ -1583,8 +1603,8 @@ JNIEXPORT jint JNICALL Java_com_wolfssl_WolfSSLSession_read__JLjava_nio_ByteBuff
}
}

size = SSLReadNonblockingWithSelectPoll(ssl, data + position,
outSz, (int)timeout);
size = SSLReadNonblockingWithSelectPoll(ssl,
data + arrayOffset + position, outSz, (int)timeout);

/* Release array elements if using array-backed buffer.
* Note: DirectByteBuffer doesn't need releasing data */
Expand Down
1 change: 1 addition & 0 deletions native/com_wolfssl_globals.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ extern jmethodID g_bufferPositionMethodId; /* ByteBuffer.position() */
extern jmethodID g_bufferLimitMethodId; /* ByteBuffer.limit() */
extern jmethodID g_bufferHasArrayMethodId; /* ByteBuffer.hasArray() */
extern jmethodID g_bufferArrayMethodId; /* ByteBuffer.array() */
extern jmethodID g_bufferArrayOffsetMethodId; /* ByteBuffer.arrayOffset() */
extern jmethodID g_bufferSetPositionMethodId; /* ByteBuffer.position(int) */
extern jmethodID g_verifyCallbackMethodId; /* WolfSSLVerifyCallback.verifyCallback */

Expand Down
18 changes: 9 additions & 9 deletions src/java/com/wolfssl/provider/jsse/WolfSSLEngine.java
Original file line number Diff line number Diff line change
Expand Up @@ -742,8 +742,8 @@ private synchronized int SendAppData(ByteBuffer[] in, int ofst, int len)
/* Get total input data size, store input array positions */
for (i = ofst; i < ofst + len; i++) {
totalIn += in[i].remaining();
pos[i] = in[i].position();
limit[i] = in[i].limit();
pos[i - ofst] = in[i].position();
limit[i - ofst] = in[i].limit();
}

/* Allocate static buffer for application data, clear before use */
Expand All @@ -765,7 +765,7 @@ private synchronized int SendAppData(ByteBuffer[] in, int ofst, int len)
in[i].limit(in[i].position() + bufChunk); /* set limit */
this.staticAppDataBuf.put(in[i]); /* get data */
inputLeft -= bufChunk;
in[i].limit(limit[i]); /* reset limit */
in[i].limit(limit[i - ofst]); /* reset limit */

if (inputLeft == 0) {
break; /* reached data size needed, stop reading */
Expand All @@ -786,7 +786,7 @@ private synchronized int SendAppData(ByteBuffer[] in, int ofst, int len)
if (ret <= 0) {
/* error, reset in[] positions for next call */
for (i = ofst; i < ofst + len; i++) {
in[i].position(pos[i]);
in[i].position(pos[i - ofst]);
}
}

Expand Down Expand Up @@ -828,7 +828,7 @@ public synchronized SSLEngineResult wrap(ByteBuffer[] in, int ofst, int len,
throw new IndexOutOfBoundsException();
}

for (i = ofst; i < len; ++i) {
for (i = ofst; i < ofst + len; ++i) {
if (in[i] == null) {
throw new SSLException("SSLEngine.wrap() bad arguments");
}
Expand All @@ -850,7 +850,7 @@ public synchronized SSLEngineResult wrap(ByteBuffer[] in, int ofst, int len,
WolfSSLDebug.log(getClass(), WolfSSLDebug.INFO,
() -> "setUseClientMode: " +
this.engineHelper.getUseClientMode());
for (i = 0; i < len; i++) {
for (i = ofst; i < ofst + len; i++) {
final int idx = i;
WolfSSLDebug.log(getClass(), WolfSSLDebug.INFO,
() -> "ByteBuffer in["+idx+"].remaining(): " +
Expand Down Expand Up @@ -1392,7 +1392,7 @@ public synchronized SSLEngineResult unwrap(ByteBuffer in, ByteBuffer[] out,
throw new IndexOutOfBoundsException();
}

for (i = ofst; i < length; ++i) {
for (i = ofst; i < ofst + length; ++i) {
if (out[i] == null) {
throw new IllegalArgumentException(
"SSLEngine.unwrap() bad arguments");
Expand Down Expand Up @@ -1434,7 +1434,7 @@ public synchronized SSLEngineResult unwrap(ByteBuffer in, ByteBuffer[] out,
() -> "in.position(): " + in.position());
WolfSSLDebug.log(getClass(), WolfSSLDebug.INFO,
() -> "in.limit(): " + in.limit());
for (i = 0; i < length; i++) {
for (i = ofst; i < ofst + length; i++) {
final int idx = i;
WolfSSLDebug.log(getClass(), WolfSSLDebug.INFO,
() -> "out["+idx+"].remaining(): " +
Expand Down Expand Up @@ -2093,7 +2093,7 @@ public synchronized void closeOutbound() {

/* If handshake has not started yet, close inBound as well */
if (needInit) {
inBoundOpen = true;
inBoundOpen = false;
}

/* Update status based on internal state. Some calling applications
Expand Down
72 changes: 72 additions & 0 deletions src/test/com/wolfssl/provider/jsse/test/WolfSSLEngineTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.atomic.AtomicIntegerArray;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.fail;
Expand Down Expand Up @@ -3347,5 +3348,76 @@ public void testWrapPartialDrainOffsetUpdate()
fail("drained output does not match injected queue");
}
}

/* Regression: closeOutbound() before handshake must also close
* inbound, otherwise isInboundDone() never returns true. */
@Test
public void testCloseOutboundBeforeHandshake() throws Exception {
this.ctx = tf.createSSLContext("TLS", engineProvider);
SSLEngine e = this.ctx.createSSLEngine();
e.setUseClientMode(true);
e.closeOutbound();
assertTrue(e.isOutboundDone());
assertTrue(e.isInboundDone());
}

/* Regression for wrap(ByteBuffer[], ofst, len, out) when ofst > 0:
* pos[]/limit[] OOB and null-check loop bound. */
@Test
public void testWrapWithBufferArrayOffset() throws Exception {
this.ctx = tf.createSSLContext("TLS", engineProvider);
SSLEngine server = this.ctx.createSSLEngine();
SSLEngine client = this.ctx.createSSLEngine("wolfSSL test", 11111);
server.setUseClientMode(false);
client.setUseClientMode(true);
server.beginHandshake();
client.beginHandshake();
assertEquals(0, tf.testConnection(server, client, null, null, "x"));

byte[] payload = "real-payload".getBytes();
ByteBuffer[] in = {ByteBuffer.wrap("DECOY".getBytes()),
ByteBuffer.wrap(payload)};
ByteBuffer net = ByteBuffer.allocateDirect(
client.getSession().getPacketBufferSize());

SSLEngineResult r = client.wrap(in, 1, 1, net);
assertEquals(SSLEngineResult.Status.OK, r.getStatus());
assertEquals(0, in[0].position());
assertEquals(payload.length, in[1].position());

net.flip();
ByteBuffer plain = ByteBuffer.allocate(
server.getSession().getApplicationBufferSize());
assertEquals(SSLEngineResult.Status.OK,
server.unwrap(net, plain).getStatus());
plain.flip();
byte[] got = new byte[plain.remaining()];
plain.get(got);
assertArrayEquals(payload, got);
}

/* Direct regression: wrap() null-check must reach in[ofst+len-1]. */
@Test(expected = SSLException.class)
public void testWrapRejectsNullAtOffset() throws Exception {
this.ctx = tf.createSSLContext("TLS", engineProvider);
SSLEngine c = this.ctx.createSSLEngine("wolfSSL test", 11111);
c.setUseClientMode(true);
ByteBuffer[] in = {ByteBuffer.wrap("x".getBytes()), null};
c.wrap(in, 1, 1, ByteBuffer.allocateDirect(
c.getSession().getPacketBufferSize()));
}

/* Direct regression: unwrap() readOnly-check must reach
* out[ofst+length-1]. */
@Test(expected = java.nio.ReadOnlyBufferException.class)
public void testUnwrapRejectsReadOnlyAtOffset() throws Exception {
this.ctx = tf.createSSLContext("TLS", engineProvider);
SSLEngine s = this.ctx.createSSLEngine();
s.setUseClientMode(false);
ByteBuffer[] out = {ByteBuffer.allocate(64),
ByteBuffer.allocate(64).asReadOnlyBuffer()};
s.unwrap(ByteBuffer.allocateDirect(
s.getSession().getPacketBufferSize()), out, 1, 1);
}
}

102 changes: 102 additions & 0 deletions src/test/com/wolfssl/test/WolfSSLSessionTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -4210,5 +4210,107 @@ public void test_WolfSSLSession_dtlsCidDataExchangeAfterHandshake()
}
}
}

/* Regression: read(ByteBuffer) must honor arrayOffset() so a
* sliced array-backed buffer reads into backing[arrayOffset+pos),
* not backing[pos). */
@Test
public void test_WolfSSLSession_readSlicedByteBuffer() throws Exception {
final ServerSocket srvSocket = new ServerSocket(0);
final WolfSSLContext srvCtx = createAndSetupWolfSSLContext(
srvCert, srvKey, WolfSSL.SSL_FILETYPE_PEM, cliCert,
WolfSSL.SSLv23_ServerMethod());
WolfSSLContext cliCtx = createAndSetupWolfSSLContext(
cliCert, cliKey, WolfSSL.SSL_FILETYPE_PEM, caCert,
WolfSSL.SSLv23_ClientMethod());
final byte[] payload = "sliced-buf-payload".getBytes();

ExecutorService es = Executors.newSingleThreadExecutor();
Future<Void> srv = es.submit(() -> {
try (Socket s = srvSocket.accept()) {
WolfSSLSession ss = new WolfSSLSession(srvCtx);
ss.setFd(s);
int r;
int e;
do {
r = ss.accept();
e = ss.getError(r);
} while (r != WolfSSL.SSL_SUCCESS &&
(e == WolfSSL.SSL_ERROR_WANT_READ ||
e == WolfSSL.SSL_ERROR_WANT_WRITE));
ss.write(payload, payload.length, 0);
ss.shutdownSSL();
ss.freeSSL();
}
return null;
});

Socket cliSock = null;
WolfSSLSession cliSes = null;
try {
cliSock = new Socket(InetAddress.getLoopbackAddress(),
srvSocket.getLocalPort());
cliSes = new WolfSSLSession(cliCtx);
cliSes.setFd(cliSock);
int r;
int e;
do {
r = cliSes.connect();
e = cliSes.getError(r);
} while (r != WolfSSL.SSL_SUCCESS &&
(e == WolfSSL.SSL_ERROR_WANT_READ ||
e == WolfSSL.SSL_ERROR_WANT_WRITE));

int prefix = 64;
ByteBuffer parent = ByteBuffer.allocate(256);
byte[] backing = parent.array();
byte sentinel = (byte) 0xA5;
Arrays.fill(backing, sentinel);
parent.position(prefix);
ByteBuffer slice = parent.slice();
assertEquals(prefix, slice.arrayOffset());

int total = 0;
while (total < payload.length) {
int n = cliSes.read(slice, payload.length - total, 5000);
if (n > 0) {
total += n;
continue;
}
int err = cliSes.getError(n);
if (err == WolfSSL.SSL_ERROR_WANT_READ ||
err == WolfSSL.SSL_ERROR_WANT_WRITE) {
continue;
}
fail("cliSes.read() failed: ret=" + n + " err=" + err +
" total=" + total + "/" + payload.length);
}

for (int i = 0; i < prefix; i++) {
assertEquals("backing[" + i + "] corrupted",
sentinel, backing[i]);
}
assertArrayEquals(payload, Arrays.copyOfRange(backing,
prefix, prefix + payload.length));
assertEquals(payload.length, slice.position());

cliSes.shutdownSSL();
} finally {
try {
srv.get(10, TimeUnit.SECONDS);
} finally {
es.shutdownNow();
if (cliSes != null) {
cliSes.freeSSL();
}
if (cliSock != null) {
cliSock.close();
}
srvSocket.close();
cliCtx.free();
srvCtx.free();
}
}
}
}

Loading