Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions native/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,87 @@ pub extern "system" fn Java_org_apache_datafusion_DataFrame_filterRows<'local>(
})
}

#[no_mangle]
pub extern "system" fn Java_org_apache_datafusion_DataFrame_limitRows<'local>(
mut env: JNIEnv<'local>,
_class: JClass<'local>,
handle: jlong,
skip: jint,
fetch: jint,
) -> jlong {
try_unwrap_or_throw(&mut env, 0, |_env| -> JniResult<jlong> {
if handle == 0 {
return Err("DataFrame handle is null".into());
}
let df = unsafe { &*(handle as *const DataFrame) }.clone();
let new_df = df.limit(skip as usize, Some(fetch as usize))?;
Ok(Box::into_raw(Box::new(new_df)) as jlong)
})
}

#[no_mangle]
pub extern "system" fn Java_org_apache_datafusion_DataFrame_distinctRows<'local>(
mut env: JNIEnv<'local>,
_class: JClass<'local>,
handle: jlong,
) -> jlong {
try_unwrap_or_throw(&mut env, 0, |_env| -> JniResult<jlong> {
if handle == 0 {
return Err("DataFrame handle is null".into());
}
let df = unsafe { &*(handle as *const DataFrame) }.clone();
let new_df = df.distinct()?;
Ok(Box::into_raw(Box::new(new_df)) as jlong)
})
}

#[no_mangle]
pub extern "system" fn Java_org_apache_datafusion_DataFrame_dropColumns<'local>(
mut env: JNIEnv<'local>,
_class: JClass<'local>,
handle: jlong,
column_names: JObjectArray<'local>,
) -> jlong {
try_unwrap_or_throw(&mut env, 0, |env| -> JniResult<jlong> {
if handle == 0 {
return Err("DataFrame handle is null".into());
}
let df = unsafe { &*(handle as *const DataFrame) }.clone();

let len = env.get_array_length(&column_names)?;
let mut owned: Vec<String> = Vec::with_capacity(len as usize);
for i in 0..len {
let elem = env.get_object_array_element(&column_names, i)?;
let jstr: JString = elem.into();
owned.push(env.get_string(&jstr)?.into());
}
let refs: Vec<&str> = owned.iter().map(String::as_str).collect();

let new_df = df.drop_columns(&refs)?;
Ok(Box::into_raw(Box::new(new_df)) as jlong)
})
}

#[no_mangle]
pub extern "system" fn Java_org_apache_datafusion_DataFrame_renameColumn<'local>(
mut env: JNIEnv<'local>,
_class: JClass<'local>,
handle: jlong,
old_name: JString<'local>,
new_name: JString<'local>,
) -> jlong {
try_unwrap_or_throw(&mut env, 0, |env| -> JniResult<jlong> {
if handle == 0 {
return Err("DataFrame handle is null".into());
}
let df = unsafe { &*(handle as *const DataFrame) }.clone();
let old: String = env.get_string(&old_name)?.into();
let new: String = env.get_string(&new_name)?.into();
let new_df = df.with_column_renamed(&old, &new)?;
Ok(Box::into_raw(Box::new(new_df)) as jlong)
})
}

#[no_mangle]
pub extern "system" fn Java_org_apache_datafusion_DataFrame_writeParquetWithOptions<'local>(
mut env: JNIEnv<'local>,
Expand Down
63 changes: 63 additions & 0 deletions src/main/java/org/apache/datafusion/DataFrame.java
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,61 @@ public DataFrame filter(String predicate) {
return new DataFrame(filterRows(nativeHandle, predicate));
}

/**
* Take the first {@code fetch} rows. Equivalent to {@link #limit(int, int)} with {@code skip =
* 0}. The receiver remains usable and must still be closed independently.
*/
public DataFrame limit(int fetch) {
return limit(0, fetch);
}

/**
* Skip {@code skip} rows, then take the next {@code fetch} rows. Both arguments must be
* non-negative. The receiver remains usable and must still be closed independently.
*/
public DataFrame limit(int skip, int fetch) {
if (skip < 0) {
throw new IllegalArgumentException("skip must be non-negative, was " + skip);
}
if (fetch < 0) {
throw new IllegalArgumentException("fetch must be non-negative, was " + fetch);
}
if (nativeHandle == 0) {
throw new IllegalStateException("DataFrame is closed or already collected");
}
return new DataFrame(limitRows(nativeHandle, skip, fetch));
}

/**
* Deduplicate rows across all columns. The receiver remains usable and must still be closed
* independently.
*/
public DataFrame distinct() {
if (nativeHandle == 0) {
throw new IllegalStateException("DataFrame is closed or already collected");
}
return new DataFrame(distinctRows(nativeHandle));
}

/**
* Drop the named columns. The inverse of {@link #select(String...)}. The receiver remains usable
* and must still be closed independently.
*/
public DataFrame dropColumns(String... columnNames) {
if (nativeHandle == 0) {
throw new IllegalStateException("DataFrame is closed or already collected");
}
return new DataFrame(dropColumns(nativeHandle, columnNames));
}

/** Rename a column. The receiver remains usable and must still be closed independently. */
public DataFrame withColumnRenamed(String oldName, String newName) {
if (nativeHandle == 0) {
throw new IllegalStateException("DataFrame is closed or already collected");
}
return new DataFrame(renameColumn(nativeHandle, oldName, newName));
}

/**
* Materialize this DataFrame as Parquet at {@code path}. The path is treated as a directory
* unless overridden via {@link ParquetWriteOptions#singleFileOutput(boolean)}. The receiver
Expand Down Expand Up @@ -168,6 +223,14 @@ public void close() {

private static native long filterRows(long handle, String predicate);

private static native long limitRows(long handle, int skip, int fetch);

private static native long distinctRows(long handle);

private static native long dropColumns(long handle, String[] columnNames);

private static native long renameColumn(long handle, String oldName, String newName);

private static native void writeParquetWithOptions(
long handle,
String path,
Expand Down
160 changes: 160 additions & 0 deletions src/test/java/org/apache/datafusion/DataFrameTransformationsTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,11 @@ void methodsThrowAfterClose() {
df.close();
assertThrows(IllegalStateException.class, () -> df.select("x"));
assertThrows(IllegalStateException.class, () -> df.filter("x > 0"));
assertThrows(IllegalStateException.class, () -> df.limit(1));
assertThrows(IllegalStateException.class, () -> df.limit(0, 1));
assertThrows(IllegalStateException.class, df::distinct);
assertThrows(IllegalStateException.class, () -> df.dropColumns("x"));
assertThrows(IllegalStateException.class, () -> df.withColumnRenamed("x", "y"));
assertThrows(IllegalStateException.class, df::count);
assertThrows(IllegalStateException.class, df::show);
assertThrows(IllegalStateException.class, () -> df.show(5));
Expand All @@ -144,6 +149,11 @@ void methodsThrowAfterCollect() throws Exception {
}
assertThrows(IllegalStateException.class, () -> df.select("x"));
assertThrows(IllegalStateException.class, () -> df.filter("x > 0"));
assertThrows(IllegalStateException.class, () -> df.limit(1));
assertThrows(IllegalStateException.class, () -> df.limit(0, 1));
assertThrows(IllegalStateException.class, df::distinct);
assertThrows(IllegalStateException.class, () -> df.dropColumns("x"));
assertThrows(IllegalStateException.class, () -> df.withColumnRenamed("x", "y"));
assertThrows(IllegalStateException.class, df::count);
assertThrows(IllegalStateException.class, df::show);
assertThrows(IllegalStateException.class, () -> df.show(5));
Expand Down Expand Up @@ -193,4 +203,154 @@ void lineitemFilterCountAgainstSqlBaseline() throws Exception {
assertEquals(viaSql, viaDataFrame);
}
}

@Test
void limitTakesFirstNRows() {
try (SessionContext ctx = new SessionContext();
DataFrame source = ctx.sql("SELECT * FROM (VALUES (1), (2), (3), (4), (5)) AS t(x)");
DataFrame limited = source.limit(2)) {
assertEquals(2L, limited.count());
}
}

@Test
void limitWithSkipDropsLeadingRows() {
try (SessionContext ctx = new SessionContext();
DataFrame source = ctx.sql("SELECT * FROM (VALUES (1), (2), (3), (4), (5)) AS t(x)");
DataFrame limited = source.limit(2, 2)) {
assertEquals(2L, limited.count());
}
}

@Test
void limitIsNonDestructive() {
try (SessionContext ctx = new SessionContext();
DataFrame source = ctx.sql("SELECT * FROM (VALUES (1), (2), (3)) AS t(x)")) {
try (DataFrame limited = source.limit(1)) {
assertEquals(1L, limited.count());
}
assertEquals(3L, source.count());
}
}

@Test
void limitRejectsNegativeArgs() {
try (SessionContext ctx = new SessionContext();
DataFrame df = ctx.sql("SELECT 1 AS x")) {
assertThrows(IllegalArgumentException.class, () -> df.limit(-1));
assertThrows(IllegalArgumentException.class, () -> df.limit(-1, 0));
assertThrows(IllegalArgumentException.class, () -> df.limit(0, -1));
}
}

@Test
void distinctRemovesDuplicates() {
try (SessionContext ctx = new SessionContext();
DataFrame source =
ctx.sql("SELECT * FROM (VALUES (1), (1), (2), (2), (3)) AS t(x)");
DataFrame deduped = source.distinct()) {
assertEquals(3L, deduped.count());
}
}

@Test
void distinctIsNonDestructive() {
try (SessionContext ctx = new SessionContext();
DataFrame source = ctx.sql("SELECT * FROM (VALUES (1), (1), (2)) AS t(x)")) {
try (DataFrame deduped = source.distinct()) {
assertEquals(2L, deduped.count());
}
assertEquals(3L, source.count());
}
}

@Test
void dropColumnsRemovesNamedColumns() throws Exception {
try (BufferAllocator allocator = new RootAllocator();
SessionContext ctx = new SessionContext();
DataFrame source = ctx.sql("SELECT 1 AS a, 2 AS b, 3 AS c");
DataFrame dropped = source.dropColumns("b");
ArrowReader reader = dropped.collect(allocator)) {
assertTrue(reader.loadNextBatch());
VectorSchemaRoot root = reader.getVectorSchemaRoot();
assertArrayEquals(
new String[] {"a", "c"},
root.getSchema().getFields().stream().map(f -> f.getName()).toArray(String[]::new));
}
}

@Test
void dropColumnsIsNonDestructive() {
try (SessionContext ctx = new SessionContext();
DataFrame source = ctx.sql("SELECT 1 AS a, 2 AS b")) {
try (DataFrame dropped = source.dropColumns("a")) {
assertEquals(1L, dropped.count());
}
assertEquals(1L, source.count());
}
}

@Test
void dropColumnsSilentlyIgnoresUnknownNames() throws Exception {
try (BufferAllocator allocator = new RootAllocator();
SessionContext ctx = new SessionContext();
DataFrame df = ctx.sql("SELECT 1 AS x");
DataFrame dropped = df.dropColumns("not_a_column");
ArrowReader reader = dropped.collect(allocator)) {
assertTrue(reader.loadNextBatch());
VectorSchemaRoot root = reader.getVectorSchemaRoot();
assertArrayEquals(
new String[] {"x"},
root.getSchema().getFields().stream().map(f -> f.getName()).toArray(String[]::new));
}
}

@Test
void withColumnRenamedChangesColumnName() throws Exception {
try (BufferAllocator allocator = new RootAllocator();
SessionContext ctx = new SessionContext();
DataFrame source = ctx.sql("SELECT 1 AS a, 2 AS b");
DataFrame renamed = source.withColumnRenamed("a", "alpha");
ArrowReader reader = renamed.collect(allocator)) {
assertTrue(reader.loadNextBatch());
VectorSchemaRoot root = reader.getVectorSchemaRoot();
assertArrayEquals(
new String[] {"alpha", "b"},
root.getSchema().getFields().stream().map(f -> f.getName()).toArray(String[]::new));
}
}

@Test
void withColumnRenamedIsNonDestructive() throws Exception {
try (BufferAllocator allocator = new RootAllocator();
SessionContext ctx = new SessionContext();
DataFrame source = ctx.sql("SELECT 1 AS a, 2 AS b")) {
try (DataFrame renamed = source.withColumnRenamed("a", "alpha")) {
assertEquals(1L, renamed.count());
}
try (DataFrame again = source.select("a");
ArrowReader reader = again.collect(allocator)) {
assertTrue(reader.loadNextBatch());
VectorSchemaRoot root = reader.getVectorSchemaRoot();
assertArrayEquals(
new String[] {"a"},
root.getSchema().getFields().stream().map(f -> f.getName()).toArray(String[]::new));
}
}
}

@Test
void withColumnRenamedUnknownColumnIsNoOp() throws Exception {
try (BufferAllocator allocator = new RootAllocator();
SessionContext ctx = new SessionContext();
DataFrame df = ctx.sql("SELECT 1 AS x");
DataFrame renamed = df.withColumnRenamed("not_a_column", "y");
ArrowReader reader = renamed.collect(allocator)) {
assertTrue(reader.loadNextBatch());
VectorSchemaRoot root = reader.getVectorSchemaRoot();
assertArrayEquals(
new String[] {"x"},
root.getSchema().getFields().stream().map(f -> f.getName()).toArray(String[]::new));
}
}
}
Loading