Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 5 additions & 13 deletions vortex-duckdb/cpp/include/duckdb_vx/table_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,32 +114,24 @@ typedef struct {

duckdb_vx_data (*init_global)(const duckdb_vx_tfunc_init_input *input, duckdb_vx_error *error_out);

duckdb_vx_data (*init_local)(const duckdb_vx_tfunc_init_input *input,
void *init_global_data,
duckdb_vx_error *error_out);
duckdb_vx_data (*init_local)(void *init_global_data);

void (*function)(duckdb_client_context ctx,
const void *bind_data,
void *init_global_data,
void (*function)(void *init_global_data,
void *init_local_data,
duckdb_data_chunk data_chunk_out,
duckdb_vx_error *error_out);

bool (*statistics)(duckdb_client_context context,
const void *bind_data,
size_t column_index,
duckdb_column_statistics *stats_out);
bool (*statistics)(const void *bind_data, size_t column_index, duckdb_column_statistics *stats_out);

void (*cardinality)(void *bind_data, duckdb_vx_node_statistics *node_stats_out);

bool (*pushdown_complex_filter)(void *bind_data, duckdb_vx_expr expr, duckdb_vx_error *error_out);

void (*to_string)(void *bind_data, duckdb_vx_string_map map);

double (*table_scan_progress)(duckdb_client_context ctx, void *bind_data, void *global_state);
double (*table_scan_progress)(void *global_state);

void (*get_partition_data)(const void *bind_data,
void *init_global_data,
void (*get_partition_data)(void *init_global_data,
void *init_local_data,
duckdb_vx_partition_data *partition_data_out);
} duckdb_vx_tfunc_vtab_t;
Expand Down
53 changes: 15 additions & 38 deletions vortex-duckdb/cpp/table_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,14 +82,12 @@ struct CTableLocalData final : LocalTableFunctionState {
unique_ptr<CData> ffi_data;
};

double c_table_scan_progress(ClientContext &context,
const FunctionData *bind_data,
const GlobalTableFunctionState *global_state) {
double table_scan_progress(ClientContext &,
const FunctionData *bind_data,
const GlobalTableFunctionState *global_state) {
auto &bind = bind_data->Cast<CTableBindData>();
duckdb_client_context c_ctx = reinterpret_cast<duckdb_client_context>(&context);
void *const c_bind_data = bind.ffi_data->DataPtr();
void *const c_global_state = global_state->Cast<CTableGlobalData>().ffi_data->DataPtr();
return bind.info.vtab.table_scan_progress(c_ctx, c_bind_data, c_global_state);
return bind.info.vtab.table_scan_progress(c_global_state);
}

static Value &UnwrapValue(duckdb_value value) {
Expand Down Expand Up @@ -140,18 +138,16 @@ unique_ptr<BaseStatistics> base_stats(duckdb_column_statistics &stats, LogicalTy
return out.ToUnique();
}

unique_ptr<BaseStatistics>
c_statistics(ClientContext &context, const FunctionData *bind_data, column_t column_index) {
unique_ptr<BaseStatistics> statistics(ClientContext &, const FunctionData *bind_data, column_t column_index) {
if (IsVirtualColumn(column_index)) {
return {};
}

const auto &bind = bind_data->Cast<CTableBindData>();
void *const ffi_bind = bind.ffi_data->DataPtr();

duckdb_client_context c_ctx = reinterpret_cast<duckdb_client_context>(&context);
duckdb_column_statistics statistics = {};
if (!bind.info.vtab.statistics(c_ctx, ffi_bind, column_index, &statistics)) {
if (!bind.info.vtab.statistics(ffi_bind, column_index, &statistics)) {
return {};
}

Expand Down Expand Up @@ -243,43 +239,25 @@ unique_ptr<GlobalTableFunctionState> c_init_global(ClientContext &context, Table
return make_uniq<CTableGlobalData>(std::move(cdata));
}

unique_ptr<LocalTableFunctionState> c_init_local(ExecutionContext &context,
TableFunctionInitInput &input,
GlobalTableFunctionState *global_state) {
unique_ptr<LocalTableFunctionState>
init_local(ExecutionContext &, TableFunctionInitInput &input, GlobalTableFunctionState *global_state) {
const auto &bind = input.bind_data->Cast<CTableBindData>();
void *const ffi_global = global_state->Cast<CTableGlobalData>().ffi_data->DataPtr();

duckdb_vx_tfunc_init_input ffi_input = {
.bind_data = bind.ffi_data->DataPtr(),
.column_ids = input.column_ids.data(),
.column_ids_count = input.column_ids.size(),
.projection_ids = input.projection_ids.data(),
.projection_ids_count = input.projection_ids.size(),
.filters = reinterpret_cast<duckdb_vx_table_filter_set>(input.filters.get()),
.client_context = reinterpret_cast<duckdb_client_context>(&context),
};

duckdb_vx_error error_out = nullptr;
duckdb_vx_data ffi_local_data = bind.info.vtab.init_local(&ffi_input, ffi_global, &error_out);
if (error_out) {
throw BinderException(IntoErrString(error_out));
}

duckdb_vx_data ffi_local_data = bind.info.vtab.init_local(ffi_global);
auto cdata = unique_ptr<CData>(reinterpret_cast<CData *>(ffi_local_data));
return make_uniq<CTableLocalData>(std::move(cdata));
}

void c_function(ClientContext &context, TableFunctionInput &input, DataChunk &output) {
void function(ClientContext &, TableFunctionInput &input, DataChunk &output) {
const auto &bind = input.bind_data->Cast<CTableBindData>();

duckdb_client_context ffi_ctx = reinterpret_cast<duckdb_client_context>(&context);
void *const ffi_bind = bind.ffi_data->DataPtr();
void *const ffi_global = input.global_state->Cast<CTableGlobalData>().ffi_data->DataPtr();
void *const ffi_local = input.local_state->Cast<CTableLocalData>().ffi_data->DataPtr();

duckdb_data_chunk chunk = reinterpret_cast<duckdb_data_chunk>(&output);
duckdb_vx_error error_out = nullptr;
bind.info.vtab.function(ffi_ctx, ffi_bind, ffi_global, ffi_local, chunk, &error_out);
bind.info.vtab.function(ffi_global, ffi_local, chunk, &error_out);
if (error_out) {
throw InvalidInputException(IntoErrString(error_out));
}
Expand Down Expand Up @@ -366,11 +344,10 @@ TablePartitionInfo get_partition_info(ClientContext &, TableFunctionPartitionInp
*/
OperatorPartitionData get_partition_data(ClientContext &, TableFunctionGetPartitionInput &input) {
auto &bind = input.bind_data->Cast<CTableBindData>();
void *const ffi_bind = bind.ffi_data->DataPtr();
void *const ffi_global = input.global_state->Cast<CTableGlobalData>().ffi_data->DataPtr();
void *const ffi_local = input.local_state->Cast<CTableLocalData>().ffi_data->DataPtr();
duckdb_vx_partition_data partition_data;
bind.info.vtab.get_partition_data(ffi_bind, ffi_global, ffi_local, &partition_data);
bind.info.vtab.get_partition_data(ffi_global, ffi_local, &partition_data);

OperatorPartitionData out(partition_data.partition_index);

Expand Down Expand Up @@ -410,7 +387,7 @@ extern "C" duckdb_state duckdb_vx_tfunc_register(duckdb_database ffi_db, const d

const DatabaseWrapper &wrapper = *reinterpret_cast<DatabaseWrapper *>(ffi_db);
DatabaseInstance &db = *wrapper.database->instance;
TableFunction tf(vtab->name, {}, c_function, c_bind, c_init_global, c_init_local);
TableFunction tf(vtab->name, {}, function, c_bind, c_init_global, init_local);

tf.projection_pushdown = true;
tf.filter_pushdown = true;
Expand All @@ -422,8 +399,8 @@ extern "C" duckdb_state duckdb_vx_tfunc_register(duckdb_database ffi_db, const d
tf.get_partition_info = get_partition_info;
tf.get_partition_data = get_partition_data;
tf.to_string = c_to_string;
tf.table_scan_progress = c_table_scan_progress;
tf.statistics = c_statistics;
tf.table_scan_progress = table_scan_progress;
tf.statistics = statistics;

tf.late_materialization = true;
// Columns that uniquely identify a row for deferred re-fetch in a multi
Expand Down
32 changes: 5 additions & 27 deletions vortex-duckdb/src/datasource.rs
Original file line number Diff line number Diff line change
Expand Up @@ -254,10 +254,6 @@ impl ColumnStatisticsAggregate {
}
}

// ---------------------------------------------------------------------------
// Blanket TableFunction implementation for any DataSourceTableFunction
// ---------------------------------------------------------------------------

impl<T: DataSourceTableFunction> TableFunction for T {
type BindData = DataSourceBindData;
type GlobalState = DataSourceGlobal;
Expand Down Expand Up @@ -398,10 +394,7 @@ impl<T: DataSourceTableFunction> TableFunction for T {
})
}

fn init_local(
_init: &TableInitInput<Self>,
global: &Self::GlobalState,
) -> VortexResult<Self::LocalState> {
fn init_local(global: &Self::GlobalState) -> Self::LocalState {
unsafe {
use custom_labels::sys;

Expand All @@ -417,17 +410,15 @@ impl<T: DataSourceTableFunction> TableFunction for T {
CURRENT_LABELSET.set(key, value);
}

Ok(DataSourceLocal {
DataSourceLocal {
iterator: global.iterator.clone(),
exporter: None,
partition_index: 0,
file_index: 0,
})
}
}

fn scan(
_client_context: &ClientContextRef,
_bind_data: &Self::BindData,
local_state: &mut Self::LocalState,
global_state: &Self::GlobalState,
chunk: &mut DataChunkRef,
Expand Down Expand Up @@ -501,11 +492,7 @@ impl<T: DataSourceTableFunction> TableFunction for T {
Ok(())
}

fn table_scan_progress(
_client_context: &ClientContextRef,
_bind_data: &Self::BindData,
global_state: &Self::GlobalState,
) -> f64 {
fn table_scan_progress(global_state: &Self::GlobalState) -> f64 {
progress(&global_state.bytes_read, &global_state.bytes_total)
}

Expand All @@ -532,11 +519,7 @@ impl<T: DataSourceTableFunction> TableFunction for T {

/// Get column-wise statistics. Available only if we're reading a single
/// file.
fn statistics(
_client_context: &ClientContextRef,
bind_data: &Self::BindData,
column_index: usize,
) -> Option<ColumnStatistics> {
fn statistics(bind_data: &Self::BindData, column_index: usize) -> Option<ColumnStatistics> {
let children = bind_data.data_source.children();
// Otherwise we'd have to open all files eagerly which is a performance
// regression. Duckdb's Parquet reader only gets metadata for multiple
Expand Down Expand Up @@ -566,7 +549,6 @@ impl<T: DataSourceTableFunction> TableFunction for T {
}

fn partition_data(
_bind_data: &Self::BindData,
global_init_data: &Self::GlobalState,
local_init_data: &mut Self::LocalState,
) -> PartitionData {
Expand All @@ -586,10 +568,6 @@ impl<T: DataSourceTableFunction> TableFunction for T {
}
}

// ---------------------------------------------------------------------------
// Helper functions
// ---------------------------------------------------------------------------

/// Extracts DuckDB column names and logical types from a Vortex struct DType.
fn extract_schema_from_dtype(dtype: &DType) -> VortexResult<Vec<DuckdbField>> {
let struct_dtype = dtype
Expand Down
17 changes: 2 additions & 15 deletions vortex-duckdb/src/duckdb/table_function/init.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,26 +40,13 @@ pub(crate) unsafe extern "C-unwind" fn init_global_callback<T: TableFunction>(

/// Native callback for the local initialization of a table function.
pub(crate) unsafe extern "C-unwind" fn init_local_callback<T: TableFunction>(
init_input: *const cpp::duckdb_vx_tfunc_init_input,
global_init_data: *mut c_void,
error_out: *mut cpp::duckdb_vx_error,
) -> cpp::duckdb_vx_data {
let init_input = TableInitInput::new(
unsafe { init_input.as_ref() }.vortex_expect("init_input null pointer"),
);

let global_init_data = unsafe { global_init_data.cast::<T::GlobalState>().as_ref() }
.vortex_expect("global_init_data null pointer");

match T::init_local(&init_input, global_init_data) {
Ok(init_data) => Data::from(Box::new(init_data)).as_ptr(),
Err(e) => {
// Set the error in the error output.
let msg = e.to_string();
unsafe { error_out.write(cpp::duckdb_vx_error_create(msg.as_ptr().cast(), msg.len())) };
ptr::null_mut::<cpp::duckdb_vx_data_>().cast()
}
}
let init_data = T::init_local(global_init_data);
Data::from(Box::new(init_data)).as_ptr()
}

/// A typed wrapper for the input to a table function's initialization.
Expand Down
Loading
Loading