Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions .github/workflows/bench.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@ on:
push:
branches:
- main
pull_request:
branches:
- main

permissions:
contents: write
Expand Down
18 changes: 9 additions & 9 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@ exclude = ["arrow-udf-duckdb-example"]

[workspace.dependencies]
anyhow = "1"
arrow-arith = "54"
arrow-array = "54"
arrow-buffer = "54"
arrow-cast = "54"
arrow-schema = "54"
arrow-select = "54"
arrow-ipc = "54"
arrow-data = "54"
arrow-flight = "54"
arrow-arith = "58.1.0"
arrow-array = "58.1.0"
arrow-buffer = "58.1.0"
arrow-cast = "58.1.0"
arrow-schema = "58.1.0"
arrow-select = "58.1.0"
arrow-ipc = "58.1.0"
arrow-data = "58.1.0"
arrow-flight = "58.1.0"
expect-test = "1"
serde_json = "1"
tokio = "1"
Expand Down
9 changes: 9 additions & 0 deletions arrow-udf-duckdb-example/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,13 @@ all: configure debug
include extension-ci-tools/makefiles/c_api_extensions/base.Makefile
include extension-ci-tools/makefiles/c_api_extensions/rust.Makefile

TEST_RUNNER=$(PYTHON_VENV_BIN) $(PROJ_DIR)run_sqllogictest.py
TEST_RUNNER_BASE=$(TEST_RUNNER) --duckdb-root-dir $(PROJ_DIR) --test-dir test/sql $(EXTRA_EXTENSIONS_PARAM)
TEST_RUNNER_DEBUG=$(TEST_RUNNER_BASE) --build-dir build/debug
TEST_RUNNER_RELEASE=$(TEST_RUNNER_BASE) --build-dir build/release

configure: venv platform extension_version
configure: install_test_runner

debug: build_extension_library_debug build_extension_with_metadata_debug
release: build_extension_library_release build_extension_with_metadata_release
Expand All @@ -21,5 +27,8 @@ test: test_debug
test_debug: test_extension_debug
test_release: test_extension_release

install_test_runner: venv
$(PYTHON_VENV_BIN) -m pip install pytest

clean: clean_build clean_rust
clean_all: clean_configure clean
18 changes: 18 additions & 0 deletions arrow-udf-duckdb-example/run_sqllogictest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#!/usr/bin/env python3

import sys
from pathlib import Path

import pytest
import sqllogic.test_sqllogic as sqllogic_runner


def main() -> int:
runner_path = Path(sqllogic_runner.__file__).resolve()
return pytest.main(
["--noconftest", "-p", "sqllogic.conftest", str(runner_path), *sys.argv[1:]],
)


if __name__ == "__main__":
raise SystemExit(main())
6 changes: 3 additions & 3 deletions arrow-udf-runtime/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ wasmtime = { version = "27", optional = true }
genawaiter2 = { version = "0.100.1", optional = true }
tempfile = { version = "3", optional = true }

pyo3 = { version = "0.24.1", optional = true, features = ["auto-initialize"] }
pyo3 = { version = "0.28", optional = true, features = ["auto-initialize"] }

atomic-time = { version = "0.1", optional = true }
rquickjs = { version = "0.6", features = [
Expand All @@ -56,11 +56,11 @@ reqwest = { version = "0.12", features = ["json"], optional = true }
serde_json = { version = "1", optional = true }

arrow-flight = { workspace = true, optional = true }
tonic = { version = "0.12", optional = true }
tonic = { version = "0.14", optional = true }
tracing = { version = "0.1", optional = true }

[build-dependencies]
pyo3-build-config = { version = "0.24", features = ["resolve-config"] }
pyo3-build-config = { version = "0.28", features = ["resolve-config"] }

[dev-dependencies]
arrow-cast = { workspace = true, features = ["prettyprint"] }
Expand Down
7 changes: 3 additions & 4 deletions arrow-udf-runtime/src/javascript/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -568,18 +568,17 @@ impl Runtime {
CallMode::ReturnNullOnNullInput => {
// This is a bit tricky. We build input arrays without nulls, call user_fn on them,
// and then add back null results to form the final result.
let n_cols = input.num_columns();
let n_rows = input.num_rows();

// 1. Build a bitmap of which rows have nulls
let mut bitmap = Vec::with_capacity(n_rows);
for i in 0..n_rows {
let has_null = (0..n_cols).any(|j| js_columns[j][i].is_null());
for row_idx in 0..n_rows {
let has_null = js_columns.iter().any(|column| column[row_idx].is_null());
bitmap.push(!has_null);
}

// 2. Build new inputs with only the rows that don't have nulls
let mut filtered_columns = Vec::with_capacity(n_cols);
let mut filtered_columns = Vec::with_capacity(js_columns.len());
for js_values in js_columns {
let filtered_js_values: Vec<_> = js_values
.into_iter()
Expand Down
4 changes: 2 additions & 2 deletions arrow-udf-runtime/src/python/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ impl Interpreter {
// XXX: import the `decimal` module in the interpreter before calling anything else
// otherwise it will cause `SIGABRT: pointer being freed was not allocated`
// when importing decimal in the second sub-interpreter.
Python::with_gil(|py| {
Python::attach(|py| {
py.import("decimal").unwrap();
});
});
Expand All @@ -59,7 +59,7 @@ impl Interpreter {
where
F: for<'py> FnOnce(Python<'py>) -> Result<R, PyError>,
{
Python::with_gil(f)
Python::attach(f)
}

/// Run Python code in the sub-interpreter.
Expand Down
17 changes: 15 additions & 2 deletions arrow-udf-runtime/src/python/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,14 @@ use arrow_array::builder::{ArrayBuilder, Int32Builder, StringBuilder};
use arrow_array::{Array, ArrayRef, BooleanArray, RecordBatch};
use arrow_schema::{DataType, Field, FieldRef, Schema, SchemaRef};
use pyo3::types::{PyAnyMethods, PyIterator, PyModule, PyTuple};
use pyo3::{Py, PyObject};
use pyo3::{Py, PyAny};
use std::collections::HashMap;
use std::ffi::CString;
use std::fmt::Debug;
use std::sync::Arc;

type PyObject = Py<PyAny>;

// #[cfg(Py_3_12)]
mod interpreter;
mod pyarrow;
Expand Down Expand Up @@ -99,9 +101,17 @@ struct Aggregate {

/// A builder for `Runtime`.
#[derive(Default, Debug)]
pub struct Builder {}
pub struct Builder {
safe_codes: Option<String>,
}

impl Builder {
/// Run initialization code before user-defined functions are registered.
pub fn safe_codes(mut self, code: String) -> Self {
self.safe_codes = Some(code);
self
}

/// Build the `Runtime`.
pub fn build(self) -> Result<Runtime> {
let interpreter = Interpreter::new()?;
Expand All @@ -117,6 +127,9 @@ class Struct:
pass
"#,
)?;
if let Some(code) = self.safe_codes {
interpreter.run(&code)?;
}
Ok(Runtime {
interpreter,
functions: HashMap::new(),
Expand Down
6 changes: 4 additions & 2 deletions arrow-udf-runtime/src/python/pyarrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,12 @@ use pyo3::{
ffi::c_str,
prelude::PyDictMethods,
types::{PyAnyMethods, PyDict},
IntoPyObject, PyObject, PyResult, Python,
IntoPyObject, Py, PyAny, PyResult, Python,
};
use std::{borrow::Cow, ffi::CString, sync::Arc};

type PyObject = Py<PyAny>;

macro_rules! get_pyobject {
($array_type: ty, $py:expr, $array:expr, $i:expr) => {{
let array = $array.as_any().downcast_ref::<$array_type>().unwrap();
Expand Down Expand Up @@ -403,7 +405,7 @@ impl Converter {
for val in values {
if !val.is_none(py) {
let py_any = val.bind(py);
let dict = py_any.downcast::<PyDict>()?;
let dict = py_any.cast::<PyDict>()?;
flatten_keys.reserve(dict.len());
flatten_values.reserve(dict.len());
for key in dict.keys() {
Expand Down
9 changes: 8 additions & 1 deletion arrow-udf-runtime/tests/wasm_build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,14 @@ fn gcd(mut a: i32, mut b: i32) -> i32 {
#[test]
fn test_build_error() {
let err = build("??", "").unwrap_err();
assert!(err.to_string().contains("invalid key"));
let err = err.to_string();
assert!(err.contains("failed to build wasm"));
assert!(err.contains("Cargo.toml"));
assert!(
err.contains("invalid key")
|| err.contains("key with no value")
|| err.contains("expected `=`")
);
}

fn test_build_offline() {
Expand Down
5 changes: 4 additions & 1 deletion arrow-udf/arrow-udf-macros/src/gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -673,7 +673,10 @@ fn transform_input(input: &Ident, ty: &str) -> TokenStream2 {
if ty == "decimal" {
return quote! { #input.parse::<rust_decimal::Decimal>().expect("invalid decimal") };
} else if ty == "date32" {
return quote! { arrow_array::types::Date32Type::to_naive_date(#input) };
return quote! {
arrow_array::types::Date32Type::to_naive_date_opt(#input)
.expect("invalid date32 value")
};
} else if ty == "time64" {
return quote! { arrow_array::temporal_conversions::as_time::<arrow_array::types::Time64MicrosecondType>(#input).expect("invalid time") };
} else if ty == "timestamp" {
Expand Down
Loading