diff --git a/vortex-array/src/expr/stats/mod.rs b/vortex-array/src/expr/stats/mod.rs index 53f50bbb9e0..746a34fe920 100644 --- a/vortex-array/src/expr/stats/mod.rs +++ b/vortex-array/src/expr/stats/mod.rs @@ -175,12 +175,9 @@ impl Stat { Self::NullCount => DType::Primitive(PType::U64, NonNullable), Self::UncompressedSizeInBytes => DType::Primitive(PType::U64, NonNullable), Self::NaNCount => { - return aggregate_fn::fns::nan_count::NanCount - .return_dtype(&EmptyOptions, data_type); - } - Self::Sum => { - return aggregate_fn::fns::sum::Sum.return_dtype(&EmptyOptions, data_type); + aggregate_fn::fns::nan_count::NanCount.return_dtype(&EmptyOptions, data_type)? } + Self::Sum => aggregate_fn::fns::sum::Sum.return_dtype(&EmptyOptions, data_type)?, }) } diff --git a/vortex-datafusion/src/persistent/format.rs b/vortex-datafusion/src/persistent/format.rs index d5c7fe913f1..796a9c88a8a 100644 --- a/vortex-datafusion/src/persistent/format.rs +++ b/vortex-datafusion/src/persistent/format.rs @@ -6,6 +6,7 @@ use std::fmt::Debug; use std::fmt::Formatter; use std::sync::Arc; +use arrow_schema::DataType; use arrow_schema::Schema; use arrow_schema::SchemaRef; use async_trait::async_trait; @@ -568,13 +569,39 @@ impl FileFormat for VortexFormat { .transpose() }); + let sum = (!matches!(field.data_type(), DataType::Boolean)) + .then(|| { + stats_set.get(Stat::Sum).and_then(|pstat_val| { + { + Stat::Sum.dtype(stats_dtype).map(|stat_dtype| { + pstat_val + .map(|stat_val| { + Scalar::try_new(stat_dtype, Some(stat_val)) + .vortex_expect( + "`Stat::Sum` somehow had an incompatible `DType`", + ) + .cast(&DType::from_arrow(field.as_ref())) + .vortex_expect( + "Unable to cast to target type that DataFusion wants", + ) + .try_to_df() + .ok() + }) + .transpose() + }) + } + .flatten() + }) + }) + .flatten(); + let null_count = stats_set.get_as::(Stat::NullCount, &PType::U64.into()); column_statistics.push(ColumnStatistics { null_count: null_count.to_df(), min_value: min.to_df(), max_value: max.to_df(), - sum_value: Precision::Absent, + sum_value: sum.to_df(), distinct_count: stats_set .get_as::(Stat::IsConstant, &DType::Bool(Nullability::NonNullable)) .and_then(|is_constant| is_constant.as_exact().map(|_| Precision::Exact(1)))