diff --git a/native/src/lib.rs b/native/src/lib.rs index 947d47f..cc58b7d 100644 --- a/native/src/lib.rs +++ b/native/src/lib.rs @@ -19,14 +19,15 @@ mod errors; use std::sync::{Arc, OnceLock}; -use datafusion::arrow::datatypes::SchemaRef; +use datafusion::arrow::datatypes::{Schema, SchemaRef}; use datafusion::arrow::ffi_stream::FFI_ArrowArrayStream; +use datafusion::arrow::ipc::reader::StreamReader; use datafusion::arrow::record_batch::RecordBatchIterator; use datafusion::dataframe::DataFrame; use datafusion::error::DataFusionError; use datafusion::prelude::{ParquetReadOptions, SessionContext}; -use jni::objects::{JClass, JString}; -use jni::sys::jlong; +use jni::objects::{JByteArray, JClass, JString}; +use jni::sys::{jboolean, jlong}; use jni::JNIEnv; use tokio::runtime::Runtime; @@ -129,13 +130,61 @@ pub extern "system" fn Java_org_apache_datafusion_SessionContext_closeSessionCon }) } +#[allow(clippy::too_many_arguments)] +fn with_parquet_options<T>( + env: &mut JNIEnv, + file_extension: JString, + parquet_pruning_set: jboolean, + parquet_pruning_value: jboolean, + skip_metadata_set: jboolean, + skip_metadata_value: jboolean, + metadata_size_hint: jlong, + schema_ipc_bytes: JByteArray, + f: impl FnOnce(ParquetReadOptions) -> JniResult<T>, +) -> JniResult<T> { + let file_ext: String = env.get_string(&file_extension)?.into(); + + let schema: Option<Schema> = if !schema_ipc_bytes.is_null() { + let bytes: Vec<u8> = env.convert_byte_array(&schema_ipc_bytes)?; + let reader = StreamReader::try_new(std::io::Cursor::new(bytes), None)?; + Some((*reader.schema()).clone()) + } else { + None + }; + + let mut opts = ParquetReadOptions::default().file_extension(&file_ext); + if parquet_pruning_set != 0 { + opts = opts.parquet_pruning(parquet_pruning_value != 0); + } + if skip_metadata_set != 0 { + opts = opts.skip_metadata(skip_metadata_value != 0); + } + if metadata_size_hint >= 0 { + opts = opts.metadata_size_hint(Some(metadata_size_hint as usize)); + } + if let Some(ref s) = schema { + opts = 
opts.schema(s); + } + + f(opts) +} + #[no_mangle] -pub extern "system" fn Java_org_apache_datafusion_SessionContext_registerParquet<'local>( +pub extern "system" fn Java_org_apache_datafusion_SessionContext_registerParquetWithOptions< + 'local, +>( mut env: JNIEnv<'local>, _class: JClass<'local>, handle: jlong, name: JString<'local>, path: JString<'local>, + file_extension: JString<'local>, + parquet_pruning_set: jboolean, + parquet_pruning_value: jboolean, + skip_metadata_set: jboolean, + skip_metadata_value: jboolean, + metadata_size_hint: jlong, + schema_ipc_bytes: JByteArray<'local>, ) { try_unwrap_or_throw(&mut env, (), |env| -> JniResult<()> { if handle == 0 { @@ -144,11 +193,59 @@ pub extern "system" fn Java_org_apache_datafusion_SessionContext_registerParquet let ctx = unsafe { &*(handle as *const SessionContext) }; let name: String = env.get_string(&name)?.into(); let path: String = env.get_string(&path)?.into(); - runtime().block_on(async { - ctx.register_parquet(&name, &path, ParquetReadOptions::default()) - .await?; - Ok::<(), DataFusionError>(()) - })?; - Ok(()) + with_parquet_options( + env, + file_extension, + parquet_pruning_set, + parquet_pruning_value, + skip_metadata_set, + skip_metadata_value, + metadata_size_hint, + schema_ipc_bytes, + |opts| { + runtime().block_on(async { + ctx.register_parquet(&name, &path, opts).await?; + Ok::<(), DataFusionError>(()) + })?; + Ok(()) + }, + ) + }) +} + +#[no_mangle] +pub extern "system" fn Java_org_apache_datafusion_SessionContext_readParquetWithOptions<'local>( + mut env: JNIEnv<'local>, + _class: JClass<'local>, + handle: jlong, + path: JString<'local>, + file_extension: JString<'local>, + parquet_pruning_set: jboolean, + parquet_pruning_value: jboolean, + skip_metadata_set: jboolean, + skip_metadata_value: jboolean, + metadata_size_hint: jlong, + schema_ipc_bytes: JByteArray<'local>, +) -> jlong { + try_unwrap_or_throw(&mut env, 0, |env| -> JniResult<jlong> { + if handle == 0 { + return Err("SessionContext 
handle is null".into()); + } + let ctx = unsafe { &*(handle as *const SessionContext) }; + let path: String = env.get_string(&path)?.into(); + with_parquet_options( + env, + file_extension, + parquet_pruning_set, + parquet_pruning_value, + skip_metadata_set, + skip_metadata_value, + metadata_size_hint, + schema_ipc_bytes, + |opts| { + let df = runtime().block_on(ctx.read_parquet(path, opts))?; + Ok(Box::into_raw(Box::new(df)) as jlong) + }, + ) }) } diff --git a/src/main/java/org/apache/datafusion/ParquetReadOptions.java b/src/main/java/org/apache/datafusion/ParquetReadOptions.java new file mode 100644 index 0000000..b29d762 --- /dev/null +++ b/src/main/java/org/apache/datafusion/ParquetReadOptions.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.datafusion; + +import org.apache.arrow.vector.types.pojo.Schema; + +/** + * Configuration knobs for parquet sources passed to {@link SessionContext#registerParquet(String, + * String, ParquetReadOptions)} and {@link SessionContext#readParquet(String, ParquetReadOptions)}. + * + *
<p>Mirrors a subset of DataFusion's {@code ParquetReadOptions}. All setters return {@code this} + * for fluent chaining. Defaults: {@code fileExtension = ".parquet"}; all other fields {@code null} + * (meaning the SessionConfig default is used, or the schema is inferred from the file). + */ +public final class ParquetReadOptions { + + private String fileExtension = ".parquet"; + private Boolean parquetPruning; + private Boolean skipMetadata; + private Long metadataSizeHint; + private Schema schema; + + public ParquetReadOptions fileExtension(String ext) { + this.fileExtension = ext; + return this; + } + + public ParquetReadOptions parquetPruning(boolean v) { + this.parquetPruning = v; + return this; + } + + public ParquetReadOptions skipMetadata(boolean v) { + this.skipMetadata = v; + return this; + } + + public ParquetReadOptions metadataSizeHint(long bytes) { + this.metadataSizeHint = bytes; + return this; + } + + public ParquetReadOptions schema(Schema schema) { + this.schema = schema; + return this; + } + + String fileExtension() { + return fileExtension; + } + + Boolean parquetPruning() { + return parquetPruning; + } + + Boolean skipMetadata() { + return skipMetadata; + } + + Long metadataSizeHint() { + return metadataSizeHint; + } + + Schema schema() { + return schema; + } +} diff --git a/src/main/java/org/apache/datafusion/SessionContext.java b/src/main/java/org/apache/datafusion/SessionContext.java index ac1a45e..fb79d43 100644 --- a/src/main/java/org/apache/datafusion/SessionContext.java +++ b/src/main/java/org/apache/datafusion/SessionContext.java @@ -19,6 +19,16 @@ package org.apache.datafusion; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.channels.Channels; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.ArrowStreamWriter; +import org.apache.arrow.vector.types.pojo.Schema; + /** * A 
DataFusion session context. * @@ -54,10 +64,70 @@ public DataFrame sql(String query) { } public void registerParquet(String name, String path) { + registerParquet(name, path, new ParquetReadOptions()); + } + + /** + * Register a parquet file as a table with the supplied {@link ParquetReadOptions}. + * + * @throws RuntimeException if registration fails (path not found, schema mismatch, etc.). + */ + public void registerParquet(String name, String path, ParquetReadOptions options) { + if (nativeHandle == 0) { + throw new IllegalStateException("SessionContext is closed"); + } + registerParquetWithOptions( + nativeHandle, + name, + path, + options.fileExtension(), + options.parquetPruning() != null, + options.parquetPruning() != null && options.parquetPruning(), + options.skipMetadata() != null, + options.skipMetadata() != null && options.skipMetadata(), + options.metadataSizeHint() != null ? options.metadataSizeHint() : -1L, + options.schema() != null ? serializeSchemaIpc(options.schema()) : null); + } + + /** Read a parquet file as a {@link DataFrame} without registering it. */ + public DataFrame readParquet(String path) { + return readParquet(path, new ParquetReadOptions()); + } + + /** + * Read a parquet file as a {@link DataFrame} with the supplied {@link ParquetReadOptions}. + * + * @throws RuntimeException if the read fails. + */ + public DataFrame readParquet(String path, ParquetReadOptions options) { if (nativeHandle == 0) { throw new IllegalStateException("SessionContext is closed"); } - registerParquet(nativeHandle, name, path); + long dfHandle = + readParquetWithOptions( + nativeHandle, + path, + options.fileExtension(), + options.parquetPruning() != null, + options.parquetPruning() != null && options.parquetPruning(), + options.skipMetadata() != null, + options.skipMetadata() != null && options.skipMetadata(), + options.metadataSizeHint() != null ? options.metadataSizeHint() : -1L, + options.schema() != null ? 
serializeSchemaIpc(options.schema()) : null); + return new DataFrame(dfHandle); + } + + private static byte[] serializeSchemaIpc(Schema schema) { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + try (BufferAllocator allocator = new RootAllocator(); + VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator); + ArrowStreamWriter writer = new ArrowStreamWriter(root, null, Channels.newChannel(baos))) { + writer.start(); + writer.end(); + } catch (IOException e) { + throw new RuntimeException("Failed to serialize Arrow schema for JNI", e); + } + return baos.toByteArray(); } @Override @@ -72,7 +142,28 @@ public void close() { private static native long createDataFrame(long handle, String sql); - private static native void registerParquet(long handle, String name, String path); + private static native void registerParquetWithOptions( + long handle, + String name, + String path, + String fileExtension, + boolean parquetPruningSet, + boolean parquetPruningValue, + boolean skipMetadataSet, + boolean skipMetadataValue, + long metadataSizeHint, + byte[] schemaIpcBytes); + + private static native long readParquetWithOptions( + long handle, + String path, + String fileExtension, + boolean parquetPruningSet, + boolean parquetPruningValue, + boolean skipMetadataSet, + boolean skipMetadataValue, + long metadataSizeHint, + byte[] schemaIpcBytes); private static native void closeSessionContext(long handle); } diff --git a/src/test/java/org/apache/datafusion/ParquetReadOptionsTest.java b/src/test/java/org/apache/datafusion/ParquetReadOptionsTest.java new file mode 100644 index 0000000..3eca311 --- /dev/null +++ b/src/test/java/org/apache/datafusion/ParquetReadOptionsTest.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.datafusion; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.List; + +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.jupiter.api.Test; + +class ParquetReadOptionsTest { + + @Test + void defaultsMatchDataFusion() { + ParquetReadOptions opts = new ParquetReadOptions(); + assertEquals(".parquet", opts.fileExtension()); + assertNull(opts.parquetPruning()); + assertNull(opts.skipMetadata()); + assertNull(opts.metadataSizeHint()); + assertNull(opts.schema()); + } + + @Test + void fluentSettersChainAndMutate() { + Schema schema = + new Schema(List.of(new Field("x", FieldType.nullable(new ArrowType.Int(32, true)), null))); + + ParquetReadOptions opts = + new ParquetReadOptions() + .fileExtension(".parq") + .parquetPruning(true) + .skipMetadata(false) + .metadataSizeHint(1_048_576L) + .schema(schema); + + assertEquals(".parq", opts.fileExtension()); + assertEquals(Boolean.TRUE, opts.parquetPruning()); + assertEquals(Boolean.FALSE, opts.skipMetadata()); + 
assertEquals(Long.valueOf(1_048_576L), opts.metadataSizeHint()); + assertTrue(opts.schema() == schema); + } + + @Test + void schemaSetterRetainsReferenceIdentity() { + Schema schema = new Schema(List.of()); + ParquetReadOptions opts = new ParquetReadOptions().schema(schema); + assertSame(schema, opts.schema()); + } +} diff --git a/src/test/java/org/apache/datafusion/SessionContextParquetOptionsTest.java b/src/test/java/org/apache/datafusion/SessionContextParquetOptionsTest.java new file mode 100644 index 0000000..9b79327 --- /dev/null +++ b/src/test/java/org/apache/datafusion/SessionContextParquetOptionsTest.java @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.datafusion; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.StandardCopyOption; +import java.util.List; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.ArrowReader; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.jupiter.api.Assumptions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +class SessionContextParquetOptionsTest { + + @Test + void readParquetWithDefaultOptionsCountsAllRows() throws Exception { + Path lineitem = Path.of("tpch-data/sf1/lineitem.parquet"); + Assumptions.assumeTrue( + Files.exists(lineitem), "TPC-H SF1 data not found; run `make tpch-data` first"); + + try (BufferAllocator allocator = new RootAllocator(); + SessionContext ctx = new SessionContext(); + DataFrame df = ctx.readParquet(lineitem.toAbsolutePath().toString()); + ArrowReader reader = df.collect(allocator)) { + long total = 0; + while (reader.loadNextBatch()) { + VectorSchemaRoot root = reader.getVectorSchemaRoot(); + total += root.getRowCount(); + } + assertEquals(6_001_215L, total); + assertTrue(total > 0); + } + } + + @Test + void registerParquetWithOptionsRespectsCustomExtension(@TempDir Path tempDir) throws Exception { + Path lineitem = Path.of("tpch-data/sf1/lineitem.parquet"); + Assumptions.assumeTrue( + Files.exists(lineitem), "TPC-H SF1 data not found; run `make tpch-data` first"); + Path renamed = tempDir.resolve("lineitem.parq"); + Files.copy(lineitem, renamed, StandardCopyOption.REPLACE_EXISTING); + + try 
(BufferAllocator allocator = new RootAllocator(); + SessionContext ctx = new SessionContext()) { + ctx.registerParquet( + "t", + renamed.toAbsolutePath().toString(), + new ParquetReadOptions().fileExtension(".parq")); + try (DataFrame df = ctx.sql("SELECT COUNT(*) FROM t"); + ArrowReader reader = df.collect(allocator)) { + assertTrue(reader.loadNextBatch()); + VectorSchemaRoot root = reader.getVectorSchemaRoot(); + assertEquals(1, root.getRowCount()); + BigIntVector count = (BigIntVector) root.getVector(0); + assertEquals(6_001_215L, count.get(0)); + } + } + } + + @Test + void readParquetWithExplicitSchemaUsesProvidedSchema() throws Exception { + Path lineitem = Path.of("tpch-data/sf1/lineitem.parquet"); + Assumptions.assumeTrue( + Files.exists(lineitem), "TPC-H SF1 data not found; run `make tpch-data` first"); + + Schema custom = + new Schema( + List.of( + new Field("l_orderkey", FieldType.notNullable(new ArrowType.Int(64, true)), null))); + + try (BufferAllocator allocator = new RootAllocator(); + SessionContext ctx = new SessionContext(); + DataFrame df = + ctx.readParquet( + lineitem.toAbsolutePath().toString(), new ParquetReadOptions().schema(custom)); + ArrowReader reader = df.collect(allocator)) { + assertTrue(reader.loadNextBatch()); + Schema observed = reader.getVectorSchemaRoot().getSchema(); + assertEquals(1, observed.getFields().size()); + assertEquals("l_orderkey", observed.getFields().get(0).getName()); + } + } + + @Test + void registerParquetWithMetadataSizeHintIsAccepted() throws Exception { + Path lineitem = Path.of("tpch-data/sf1/lineitem.parquet"); + Assumptions.assumeTrue( + Files.exists(lineitem), "TPC-H SF1 data not found; run `make tpch-data` first"); + + try (BufferAllocator allocator = new RootAllocator(); + SessionContext ctx = new SessionContext()) { + ctx.registerParquet( + "t", + lineitem.toAbsolutePath().toString(), + new ParquetReadOptions().metadataSizeHint(1L << 20)); + try (DataFrame df = ctx.sql("SELECT COUNT(*) FROM t"); + 
ArrowReader reader = df.collect(allocator)) { + assertTrue(reader.loadNextBatch()); + BigIntVector count = (BigIntVector) reader.getVectorSchemaRoot().getVector(0); + assertEquals(6_001_215L, count.get(0)); + } + } + } +}