diff --git a/README.md b/README.md index 21637cd..b22ddb0 100644 --- a/README.md +++ b/README.md @@ -48,7 +48,8 @@ Query interfaces: Data sources: - [x] Parquet via `registerParquet` / `readParquet`, with `ParquetReadOptions` -- [ ] CSV, JSON, Avro +- [x] CSV via `registerCsv` / `readCsv`, with `CsvReadOptions` +- [ ] JSON, Avro - [ ] Custom catalog and table providers Results: diff --git a/native/src/csv.rs b/native/src/csv.rs new file mode 100644 index 0000000..a1b6fdd --- /dev/null +++ b/native/src/csv.rs @@ -0,0 +1,204 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::str::FromStr; + +use datafusion::arrow::datatypes::Schema; +use datafusion::arrow::ipc::reader::StreamReader; +use datafusion::datasource::file_format::file_compression_type::FileCompressionType; +use datafusion::error::DataFusionError; +use datafusion::prelude::{CsvReadOptions, SessionContext}; +use jni::objects::{JByteArray, JClass, JString}; +use jni::sys::{jboolean, jbyte, jlong}; +use jni::JNIEnv; + +use crate::errors::{try_unwrap_or_throw, JniResult}; +use crate::runtime; + +#[allow(clippy::too_many_arguments)] +fn with_csv_options( + env: &mut JNIEnv, + has_header: jboolean, + delimiter: jbyte, + quote: jbyte, + terminator_set: jboolean, + terminator_value: jbyte, + escape_set: jboolean, + escape_value: jbyte, + comment_set: jboolean, + comment_value: jbyte, + newlines_in_values_set: jboolean, + newlines_in_values_value: jboolean, + schema_infer_max_records: jlong, + file_extension: JString, + file_compression_type: JString, + schema_ipc_bytes: JByteArray, + f: impl FnOnce(CsvReadOptions) -> JniResult, +) -> JniResult { + let file_ext: String = env.get_string(&file_extension)?.into(); + let compression: String = env.get_string(&file_compression_type)?.into(); + let compression = FileCompressionType::from_str(&compression)?; + + let schema: Option = if !schema_ipc_bytes.is_null() { + let bytes: Vec = env.convert_byte_array(&schema_ipc_bytes)?; + let reader = StreamReader::try_new(std::io::Cursor::new(bytes), None)?; + Some((*reader.schema()).clone()) + } else { + None + }; + + let mut opts = CsvReadOptions::new() + .has_header(has_header != 0) + .delimiter(delimiter as u8) + .quote(quote as u8) + .file_extension(&file_ext) + .file_compression_type(compression); + + if terminator_set != 0 { + opts = opts.terminator(Some(terminator_value as u8)); + } + if escape_set != 0 { + opts = opts.escape(escape_value as u8); + } + if comment_set != 0 { + opts = opts.comment(comment_value as u8); + } + if newlines_in_values_set != 0 { + opts = opts.newlines_in_values(newlines_in_values_value != 0); + } + if schema_infer_max_records >= 0 { + opts = opts.schema_infer_max_records(schema_infer_max_records as usize); + } + if let Some(ref s) = schema { + opts = opts.schema(s); + } + + f(opts) +} + +#[allow(clippy::too_many_arguments)] +#[no_mangle] +pub extern "system" fn Java_org_apache_datafusion_SessionContext_registerCsvWithOptions<'local>( + mut env: JNIEnv<'local>, + _class: JClass<'local>, + handle: jlong, + name: JString<'local>, + path: JString<'local>, + has_header: jboolean, + delimiter: jbyte, + quote: jbyte, + terminator_set: jboolean, + terminator_value: jbyte, + escape_set: jboolean, + escape_value: jbyte, + comment_set: jboolean, + comment_value: jbyte, + newlines_in_values_set: jboolean, + newlines_in_values_value: jboolean, + schema_infer_max_records: jlong, + file_extension: JString<'local>, + file_compression_type: JString<'local>, + schema_ipc_bytes: JByteArray<'local>, +) { + try_unwrap_or_throw(&mut env, (), |env| -> JniResult<()> { + if handle == 0 { + return Err("SessionContext handle is null".into()); + } + let ctx = unsafe { &*(handle as *const SessionContext) }; + let name: String = env.get_string(&name)?.into(); + let path: String = env.get_string(&path)?.into(); + with_csv_options( + env, + has_header, + delimiter, + quote, + terminator_set, + terminator_value, + escape_set, + escape_value, + comment_set, + comment_value, + newlines_in_values_set, + newlines_in_values_value, + schema_infer_max_records, + file_extension, + file_compression_type, + schema_ipc_bytes, + |opts| { + runtime().block_on(async { + ctx.register_csv(&name, &path, opts).await?; + Ok::<(), DataFusionError>(()) + })?; + Ok(()) + }, + ) + }) +} + +#[allow(clippy::too_many_arguments)] +#[no_mangle] +pub extern "system" fn Java_org_apache_datafusion_SessionContext_readCsvWithOptions<'local>( + mut env: JNIEnv<'local>, + _class: JClass<'local>, + handle: jlong, + path: JString<'local>, + has_header: jboolean, + delimiter: jbyte, + quote: jbyte, + terminator_set: jboolean, + terminator_value: jbyte, + escape_set: jboolean, + escape_value: jbyte, + comment_set: jboolean, + comment_value: jbyte, + newlines_in_values_set: jboolean, + newlines_in_values_value: jboolean, + schema_infer_max_records: jlong, + file_extension: JString<'local>, + file_compression_type: JString<'local>, + schema_ipc_bytes: JByteArray<'local>, +) -> jlong { + try_unwrap_or_throw(&mut env, 0, |env| -> JniResult { + if handle == 0 { + return Err("SessionContext handle is null".into()); + } + let ctx = unsafe { &*(handle as *const SessionContext) }; + let path: String = env.get_string(&path)?.into(); + with_csv_options( + env, + has_header, + delimiter, + quote, + terminator_set, + terminator_value, + escape_set, + escape_value, + comment_set, + comment_value, + newlines_in_values_set, + newlines_in_values_value, + schema_infer_max_records, + file_extension, + file_compression_type, + schema_ipc_bytes, + |opts| { + let df = runtime().block_on(ctx.read_csv(path, opts))?; + Ok(Box::into_raw(Box::new(df)) as jlong) + }, + ) + }) +} diff --git a/native/src/lib.rs b/native/src/lib.rs index 463d075..fc3a983 100644 --- a/native/src/lib.rs +++ b/native/src/lib.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +mod csv; mod errors; mod proto; diff --git a/src/main/java/org/apache/datafusion/CsvReadOptions.java b/src/main/java/org/apache/datafusion/CsvReadOptions.java new file mode 100644 index 0000000..4b35b95 --- /dev/null +++ b/src/main/java/org/apache/datafusion/CsvReadOptions.java @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.datafusion; + +import org.apache.arrow.vector.types.pojo.Schema; + +/** + * Configuration knobs for CSV sources passed to {@link SessionContext#registerCsv(String, String, + * CsvReadOptions)} and {@link SessionContext#readCsv(String, CsvReadOptions)}. + * + *

Mirrors a subset of DataFusion's {@code CsvReadOptions}. All setters return {@code this} for + * fluent chaining. Defaults match the Rust struct: {@code hasHeader = true}, {@code delimiter = + * ','}, {@code quote = '"'}, {@code fileExtension = ".csv"}, {@code fileCompressionType = + * UNCOMPRESSED}, all other fields {@code null} (meaning the DataFusion default is used, or the + * schema is inferred from the file). + */ +public final class CsvReadOptions { + + /** Compression of the file. Names match DataFusion's {@code FileCompressionType} variants. */ + public enum FileCompressionType { + UNCOMPRESSED, + GZIP, + BZIP2, + XZ, + ZSTD + } + + private boolean hasHeader = true; + private byte delimiter = (byte) ','; + private byte quote = (byte) '"'; + private Byte terminator; + private Byte escape; + private Byte comment; + private Boolean newlinesInValues; + private Long schemaInferMaxRecords; + private String fileExtension = ".csv"; + private FileCompressionType fileCompressionType = FileCompressionType.UNCOMPRESSED; + private Schema schema; + + public CsvReadOptions hasHeader(boolean v) { + this.hasHeader = v; + return this; + } + + public CsvReadOptions delimiter(byte b) { + this.delimiter = b; + return this; + } + + public CsvReadOptions quote(byte b) { + this.quote = b; + return this; + } + + public CsvReadOptions terminator(byte b) { + this.terminator = b; + return this; + } + + public CsvReadOptions escape(byte b) { + this.escape = b; + return this; + } + + public CsvReadOptions comment(byte b) { + this.comment = b; + return this; + } + + public CsvReadOptions newlinesInValues(boolean v) { + this.newlinesInValues = v; + return this; + } + + public CsvReadOptions schemaInferMaxRecords(long n) { + this.schemaInferMaxRecords = n; + return this; + } + + public CsvReadOptions fileExtension(String ext) { + this.fileExtension = ext; + return this; + } + + public CsvReadOptions fileCompressionType(FileCompressionType t) { + this.fileCompressionType = t; + return this; + } + + public CsvReadOptions schema(Schema schema) { + this.schema = schema; + return this; + } + + boolean hasHeader() { + return hasHeader; + } + + byte delimiter() { + return delimiter; + } + + byte quote() { + return quote; + } + + Byte terminator() { + return terminator; + } + + Byte escape() { + return escape; + } + + Byte comment() { + return comment; + } + + Boolean newlinesInValues() { + return newlinesInValues; + } + + Long schemaInferMaxRecords() { + return schemaInferMaxRecords; + } + + String fileExtension() { + return fileExtension; + } + + FileCompressionType fileCompressionType() { + return fileCompressionType; + } + + Schema schema() { + return schema; + } +} diff --git a/src/main/java/org/apache/datafusion/SessionContext.java b/src/main/java/org/apache/datafusion/SessionContext.java index 823ee13..1aec343 100644 --- a/src/main/java/org/apache/datafusion/SessionContext.java +++ b/src/main/java/org/apache/datafusion/SessionContext.java @@ -103,6 +103,77 @@ public Schema tableSchema(String tableName) { } } + public void registerCsv(String name, String path) { + registerCsv(name, path, new CsvReadOptions()); + } + + /** + * Register a CSV file (or directory of CSV files) as a table with the supplied {@link + * CsvReadOptions}. + * + * @throws RuntimeException if registration fails (path not found, schema inference error, etc.). + */ + public void registerCsv(String name, String path, CsvReadOptions options) { + if (nativeHandle == 0) { + throw new IllegalStateException("SessionContext is closed"); + } + registerCsvWithOptions( + nativeHandle, + name, + path, + options.hasHeader(), + options.delimiter(), + options.quote(), + options.terminator() != null, + options.terminator() != null ? options.terminator() : 0, + options.escape() != null, + options.escape() != null ? options.escape() : 0, + options.comment() != null, + options.comment() != null ? options.comment() : 0, + options.newlinesInValues() != null, + options.newlinesInValues() != null && options.newlinesInValues(), + options.schemaInferMaxRecords() != null ? options.schemaInferMaxRecords() : -1L, + options.fileExtension(), + options.fileCompressionType().name(), + options.schema() != null ? serializeSchemaIpc(options.schema()) : null); + } + + /** Read a CSV file as a {@link DataFrame} without registering it. */ + public DataFrame readCsv(String path) { + return readCsv(path, new CsvReadOptions()); + } + + /** + * Read a CSV file as a {@link DataFrame} with the supplied {@link CsvReadOptions}. + * + * @throws RuntimeException if the read fails. + */ + public DataFrame readCsv(String path, CsvReadOptions options) { + if (nativeHandle == 0) { + throw new IllegalStateException("SessionContext is closed"); + } + long dfHandle = + readCsvWithOptions( + nativeHandle, + path, + options.hasHeader(), + options.delimiter(), + options.quote(), + options.terminator() != null, + options.terminator() != null ? options.terminator() : 0, + options.escape() != null, + options.escape() != null ? options.escape() : 0, + options.comment() != null, + options.comment() != null ? options.comment() : 0, + options.newlinesInValues() != null, + options.newlinesInValues() != null && options.newlinesInValues(), + options.schemaInferMaxRecords() != null ? options.schemaInferMaxRecords() : -1L, + options.fileExtension(), + options.fileCompressionType().name(), + options.schema() != null ? serializeSchemaIpc(options.schema()) : null); + return new DataFrame(dfHandle); + } + public void registerParquet(String name, String path) { registerParquet(name, path, new ParquetReadOptions()); } @@ -209,5 +280,44 @@ private static native long readParquetWithOptions( long metadataSizeHint, byte[] schemaIpcBytes); + private static native void registerCsvWithOptions( + long handle, + String name, + String path, + boolean hasHeader, + byte delimiter, + byte quote, + boolean terminatorSet, + byte terminatorValue, + boolean escapeSet, + byte escapeValue, + boolean commentSet, + byte commentValue, + boolean newlinesInValuesSet, + boolean newlinesInValuesValue, + long schemaInferMaxRecords, + String fileExtension, + String fileCompressionType, + byte[] schemaIpcBytes); + + private static native long readCsvWithOptions( + long handle, + String path, + boolean hasHeader, + byte delimiter, + byte quote, + boolean terminatorSet, + byte terminatorValue, + boolean escapeSet, + byte escapeValue, + boolean commentSet, + byte commentValue, + boolean newlinesInValuesSet, + boolean newlinesInValuesValue, + long schemaInferMaxRecords, + String fileExtension, + String fileCompressionType, + byte[] schemaIpcBytes); + private static native void closeSessionContext(long handle); } diff --git a/src/test/java/org/apache/datafusion/CsvReadOptionsTest.java b/src/test/java/org/apache/datafusion/CsvReadOptionsTest.java new file mode 100644 index 0000000..6f49403 --- /dev/null +++ b/src/test/java/org/apache/datafusion/CsvReadOptionsTest.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.datafusion; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.List; + +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.jupiter.api.Test; + +class CsvReadOptionsTest { + + @Test + void defaultsMatchDataFusion() { + CsvReadOptions opts = new CsvReadOptions(); + assertTrue(opts.hasHeader()); + assertEquals((byte) ',', opts.delimiter()); + assertEquals((byte) '"', opts.quote()); + assertNull(opts.terminator()); + assertNull(opts.escape()); + assertNull(opts.comment()); + assertNull(opts.newlinesInValues()); + assertNull(opts.schemaInferMaxRecords()); + assertEquals(".csv", opts.fileExtension()); + assertEquals(CsvReadOptions.FileCompressionType.UNCOMPRESSED, opts.fileCompressionType()); + assertNull(opts.schema()); + } + + @Test + void fluentSettersChainAndMutate() { + Schema schema = + new Schema(List.of(new Field("x", FieldType.nullable(new ArrowType.Int(32, true)), null))); + + CsvReadOptions opts = + new CsvReadOptions() + .hasHeader(false) + .delimiter((byte) '|') + .quote((byte) '\'') + .terminator((byte) '\n') + .escape((byte) '\\') + .comment((byte) '#') + .newlinesInValues(true) + .schemaInferMaxRecords(10L) + .fileExtension(".tsv") + .fileCompressionType(CsvReadOptions.FileCompressionType.GZIP) + .schema(schema); + + assertEquals(false, opts.hasHeader()); + assertEquals((byte) '|', opts.delimiter()); + assertEquals((byte) '\'', opts.quote()); + assertEquals(Byte.valueOf((byte) '\n'), opts.terminator()); + assertEquals(Byte.valueOf((byte) '\\'), opts.escape()); + assertEquals(Byte.valueOf((byte) '#'), opts.comment()); + assertEquals(Boolean.TRUE, opts.newlinesInValues()); + assertEquals(Long.valueOf(10L), opts.schemaInferMaxRecords()); + assertEquals(".tsv", opts.fileExtension()); + assertEquals(CsvReadOptions.FileCompressionType.GZIP, opts.fileCompressionType()); + assertSame(schema, opts.schema()); + } +} diff --git a/src/test/java/org/apache/datafusion/SessionContextCsvTest.java b/src/test/java/org/apache/datafusion/SessionContextCsvTest.java new file mode 100644 index 0000000..47862d4 --- /dev/null +++ b/src/test/java/org/apache/datafusion/SessionContextCsvTest.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.datafusion; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.List; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.ArrowReader; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +class SessionContextCsvTest { + + private static Path writeCsv(Path dir, String name, String contents) throws IOException { + Path file = dir.resolve(name); + Files.writeString(file, contents); + return file; + } + + @Test + void registerCsvWithHeaderInfersSchemaAndCounts(@TempDir Path tempDir) throws Exception { + Path csv = writeCsv(tempDir, "people.csv", "id,name\n1,alice\n2,bob\n3,carol\n"); + + try (BufferAllocator allocator = new RootAllocator(); + SessionContext ctx = new SessionContext()) { + ctx.registerCsv("people", csv.toAbsolutePath().toString()); + + try (DataFrame df = ctx.sql("SELECT COUNT(*) FROM people"); + ArrowReader reader = df.collect(allocator)) { + assertTrue(reader.loadNextBatch()); + BigIntVector count = (BigIntVector) reader.getVectorSchemaRoot().getVector(0); + assertEquals(3L, count.get(0)); + } + + try (DataFrame df = ctx.sql("SELECT name FROM people WHERE id = 2"); + ArrowReader reader = df.collect(allocator)) { + assertTrue(reader.loadNextBatch()); + VectorSchemaRoot root = reader.getVectorSchemaRoot(); + assertEquals(1, root.getRowCount()); + VarCharVector names = (VarCharVector) root.getVector(0); + assertEquals("bob", new String(names.get(0))); + } + } + } + + @Test + void readCsvWithExplicitSchemaAndNoHeader(@TempDir Path tempDir) throws Exception { + Path csv = writeCsv(tempDir, "headerless.csv", "1|alice\n2|bob\n"); + + Schema schema = + new Schema( + List.of( + new Field("id", FieldType.nullable(new ArrowType.Int(64, true)), null), + new Field("name", FieldType.nullable(new ArrowType.Utf8()), null))); + + CsvReadOptions opts = + new CsvReadOptions().hasHeader(false).delimiter((byte) '|').schema(schema); + + try (BufferAllocator allocator = new RootAllocator(); + SessionContext ctx = new SessionContext(); + DataFrame df = ctx.readCsv(csv.toAbsolutePath().toString(), opts); + ArrowReader reader = df.collect(allocator)) { + assertTrue(reader.loadNextBatch()); + VectorSchemaRoot root = reader.getVectorSchemaRoot(); + assertEquals(2, root.getRowCount()); + assertEquals("id", root.getSchema().getFields().get(0).getName()); + assertEquals("name", root.getSchema().getFields().get(1).getName()); + } + } + + @Test + void registerCsvWithCustomExtension(@TempDir Path tempDir) throws Exception { + Path csv = writeCsv(tempDir, "data.tsv", "x\ty\n10\t20\n30\t40\n"); + + try (BufferAllocator allocator = new RootAllocator(); + SessionContext ctx = new SessionContext()) { + ctx.registerCsv( + "t", + csv.toAbsolutePath().toString(), + new CsvReadOptions().delimiter((byte) '\t').fileExtension(".tsv")); + + try (DataFrame df = ctx.sql("SELECT SUM(x) + SUM(y) FROM t"); + ArrowReader reader = df.collect(allocator)) { + assertTrue(reader.loadNextBatch()); + BigIntVector v = (BigIntVector) reader.getVectorSchemaRoot().getVector(0); + assertEquals(100L, v.get(0)); + } + } + } +}