diff --git a/.asf.yaml b/.asf.yaml index aab8c1e6df2df..b719a495bd735 100644 --- a/.asf.yaml +++ b/.asf.yaml @@ -41,6 +41,7 @@ github: - sql enabled_merge_buttons: squash: true + squash_commit_message: PR_TITLE_AND_DESC merge: false rebase: false features: @@ -50,11 +51,29 @@ github: main: required_pull_request_reviews: required_approving_review_count: 1 + # needs to be updated as part of the release process + # .asf.yaml doesn't support wildcard branch protection rules, only exact branch names + # https://github.com/apache/infrastructure-asfyaml?tab=readme-ov-file#branch-protection + # these branches protection blocks autogenerated during release process which is described in + # https://github.com/apache/datafusion/tree/main/dev/release#2-add-a-protection-to-release-candidate-branch + branch-50: + required_pull_request_reviews: + required_approving_review_count: 1 + branch-51: + required_pull_request_reviews: + required_approving_review_count: 1 + branch-52: + required_pull_request_reviews: + required_approving_review_count: 1 pull_requests: # enable updating head branches of pull requests allow_update_branch: true + allow_auto_merge: true + # auto-delete head branches after being merged + del_branch_on_merge: true # publishes the content of the `asf-site` branch to # https://datafusion.apache.org/ publish: whoami: asf-site + diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile index 9dd627b01abed..49aacd118e19b 100644 --- a/.devcontainer/Dockerfile +++ b/.devcontainer/Dockerfile @@ -4,10 +4,12 @@ RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \ # Remove imagemagick due to https://security-tracker.debian.org/tracker/CVE-2019-10131 && apt-get purge -y imagemagick imagemagick-6-common -# Add protoc -# https://datafusion.apache.org/contributor-guide/getting_started.html#protoc-installation -RUN curl -LO https://github.com/protocolbuffers/protobuf/releases/download/v25.1/protoc-25.1-linux-x86_64.zip \ - && unzip protoc-25.1-linux-x86_64.zip 
-d $HOME/.local \ - && rm protoc-25.1-linux-x86_64.zip +# setup the containers WORKDIR so npm install works +# https://stackoverflow.com/questions/57534295/npm-err-tracker-idealtree-already-exists-while-creating-the-docker-image-for +WORKDIR /root -ENV PATH="$PATH:$HOME/.local/bin" \ No newline at end of file +# Add protoc, npm, prettier +# https://datafusion.apache.org/contributor-guide/development_environment.html#protoc-installation +RUN apt-get update \ + && apt-get install -y --no-install-recommends protobuf-compiler libprotobuf-dev npm nodejs\ + && rm -rf /var/lib/apt/lists/* diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml index a886cbd74c23a..ac5f082113117 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.yml +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -1,5 +1,6 @@ name: Bug report description: Create a report to help us improve +type: Bug labels: bug body: - type: textarea diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml index 2542b28dcae8a..955e59d74d08b 100644 --- a/.github/ISSUE_TEMPLATE/feature_request.yml +++ b/.github/ISSUE_TEMPLATE/feature_request.yml @@ -1,5 +1,6 @@ name: Feature request description: Suggest an idea for this project +type: Feature labels: enhancement body: - type: textarea diff --git a/.github/actions/setup-builder/action.yaml b/.github/actions/setup-builder/action.yaml index 22d2f2187dd07..6228370c955a9 100644 --- a/.github/actions/setup-builder/action.yaml +++ b/.github/actions/setup-builder/action.yaml @@ -46,3 +46,17 @@ runs: # https://github.com/actions/checkout/issues/766 shell: bash run: git config --global --add safe.directory "$GITHUB_WORKSPACE" + - name: Remove unnecessary preinstalled software + shell: bash + run: | + echo "Disk space before cleanup:" + df -h + apt-get clean + # remove tool cache: about 8.5GB (github has host /opt/hostedtoolcache mounted as /__t) + rm -rf /__t/* || true + # remove Haskell runtime: about 6.3GB 
(host /usr/local/.ghcup) + rm -rf /host/usr/local/.ghcup || true + # remove Android library: about 7.8GB (host /usr/local/lib/android) + rm -rf /host/usr/local/lib/android || true + echo "Disk space after cleanup:" + df -h \ No newline at end of file diff --git a/.github/actions/setup-macos-aarch64-builder/action.yaml b/.github/actions/setup-macos-aarch64-builder/action.yaml index 288799a284b01..b62370447adea 100644 --- a/.github/actions/setup-macos-aarch64-builder/action.yaml +++ b/.github/actions/setup-macos-aarch64-builder/action.yaml @@ -44,6 +44,8 @@ runs: rustup default stable rustup component add rustfmt - name: Setup rust cache - uses: Swatinem/rust-cache@v2 + uses: Swatinem/rust-cache@f13886b937689c021905a6b90929199931d60db1 # v2.8.1 + with: + save-if: ${{ github.ref_name == 'main' }} - name: Configure rust runtime env uses: ./.github/actions/setup-rust-runtime diff --git a/.github/actions/setup-macos-builder/action.yaml b/.github/actions/setup-macos-builder/action.yaml deleted file mode 100644 index fffdab160b043..0000000000000 --- a/.github/actions/setup-macos-builder/action.yaml +++ /dev/null @@ -1,47 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -name: Prepare Rust Builder for MacOS -description: 'Prepare Rust Build Environment for MacOS' -inputs: - rust-version: - description: 'version of rust to install (e.g. stable)' - required: true - default: 'stable' -runs: - using: "composite" - steps: - - name: Install protobuf compiler - shell: bash - run: | - mkdir -p $HOME/d/protoc - cd $HOME/d/protoc - export PROTO_ZIP="protoc-29.1-osx-x86_64.zip" - curl -LO https://github.com/protocolbuffers/protobuf/releases/download/v29.1/$PROTO_ZIP - unzip $PROTO_ZIP - echo "$HOME/d/protoc/bin" >> $GITHUB_PATH - export PATH=$PATH:$HOME/d/protoc/bin - protoc --version - - name: Setup Rust toolchain - shell: bash - run: | - rustup update stable - rustup toolchain install stable - rustup default stable - rustup component add rustfmt - - name: Configure rust runtime env - uses: ./.github/actions/setup-rust-runtime diff --git a/.github/actions/setup-rust-runtime/action.yaml b/.github/actions/setup-rust-runtime/action.yaml index b6fb2c898bf2f..e0341de93b83d 100644 --- a/.github/actions/setup-rust-runtime/action.yaml +++ b/.github/actions/setup-rust-runtime/action.yaml @@ -20,10 +20,6 @@ description: 'Setup Rust Runtime Environment' runs: using: "composite" steps: - # https://github.com/apache/datafusion/issues/15535 - # disabled because neither version nor git hash works with apache github policy - #- name: Run sccache-cache - # uses: mozilla-actions/sccache-action@65101d47ea8028ed0c98a1cdea8dd9182e9b5133 # v0.0.8 - name: Configure runtime env shell: bash # do not produce debug symbols to keep memory usage down @@ -32,11 +28,6 @@ runs: # # Set debuginfo=line-tables-only as debuginfo=0 causes immensely slow build # See for more details: https://github.com/rust-lang/rust/issues/119560 - # - # readd the following to the run below once sccache-cache is re-enabled - # echo "RUSTC_WRAPPER=sccache" >> $GITHUB_ENV - # echo "SCCACHE_GHA_ENABLED=true" >> $GITHUB_ENV run: | echo "RUST_BACKTRACE=1" >> $GITHUB_ENV echo "RUSTFLAGS=-C 
debuginfo=line-tables-only -C incremental=false" >> $GITHUB_ENV - diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 7c2b7e3a5458c..2cd4bdfdd7923 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -20,9 +20,10 @@ updates: - package-ecosystem: cargo directory: "/" schedule: - interval: daily + interval: weekly target-branch: main labels: [auto-dependencies] + open-pull-requests-limit: 15 ignore: # major version bumps of arrow* and parquet are handled manually - dependency-name: "arrow*" @@ -44,9 +45,31 @@ updates: patterns: - "prost*" - "pbjson*" + + # Catch-all: group only minor/patch into a single PR, + # excluding deps we want always separate (and excluding arrow/parquet which have their own group) + all-other-cargo-deps: + applies-to: version-updates + patterns: + - "*" + exclude-patterns: + - "arrow*" + - "parquet" + - "object_store" + - "sqlparser" + - "prost*" + - "pbjson*" + update-types: + - "minor" + - "patch" - package-ecosystem: "github-actions" directory: "/" schedule: - interval: "daily" + interval: "weekly" open-pull-requests-limit: 10 labels: [auto-dependencies] + - package-ecosystem: "pip" + directory: "/docs" + schedule: + interval: "weekly" + labels: [auto-dependencies] diff --git a/.github/workflows/audit.yml b/.github/workflows/audit.yml index 491fa27c2a56a..281f600d6766a 100644 --- a/.github/workflows/audit.yml +++ b/.github/workflows/audit.yml @@ -23,25 +23,29 @@ concurrency: on: push: + branches: + - main paths: - "**/Cargo.toml" - "**/Cargo.lock" - branches: - - main pull_request: paths: - "**/Cargo.toml" - "**/Cargo.lock" + + merge_group: jobs: security_audit: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Install cargo-audit - run: cargo install cargo-audit + uses: taiki-e/install-action@de6bbd1333b8f331563d54a051e542c7dfef81c3 # v2.68.34 + with: + tool: cargo-audit - name: Run audit check - # Ignored until 
https://github.com/apache/datafusion/issues/15571 - # ignored py03 warning until arrow 55 upgrade - run: cargo audit --ignore RUSTSEC-2024-0370 --ignore RUSTSEC-2025-0020 + # Note: you can ignore specific RUSTSEC issues using the `--ignore` flag, for example: + # run: cargo audit --ignore RUSTSEC-2026-0001 + run: cargo audit --ignore RUSTSEC-2024-0014 diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml new file mode 100644 index 0000000000000..d42c2b4aa8d39 --- /dev/null +++ b/.github/workflows/codeql.yml @@ -0,0 +1,55 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# + +name: "CodeQL" + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + schedule: + - cron: '16 4 * * 1' + +permissions: + contents: read + +jobs: + analyze: + name: Analyze Actions + runs-on: ubuntu-latest + permissions: + contents: read + security-events: write + packages: read + + steps: + - name: Checkout repository + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 + with: + persist-credentials: false + + - name: Initialize CodeQL + uses: github/codeql-action/init@b1bff81932f5cdfc8695c7752dcee935dcd061c8 # v4 + with: + languages: actions + + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@b1bff81932f5cdfc8695c7752dcee935dcd061c8 # v4 + with: + category: "/language:actions" diff --git a/.github/workflows/dependencies.yml b/.github/workflows/dependencies.yml index a577725fed4b9..3b2cc243d4967 100644 --- a/.github/workflows/dependencies.yml +++ b/.github/workflows/dependencies.yml @@ -23,6 +23,8 @@ concurrency: on: push: + branches-ignore: + - 'gh-readonly-queue/**' paths: - "**/Cargo.toml" - "**/Cargo.lock" @@ -30,6 +32,7 @@ on: paths: - "**/Cargo.toml" - "**/Cargo.lock" + merge_group: # manual trigger # https://docs.github.com/en/actions/managing-workflow-runs/manually-running-a-workflow workflow_dispatch: @@ -41,7 +44,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: submodules: true fetch-depth: 1 @@ -53,3 +56,14 @@ jobs: run: | cd dev/depcheck cargo run + + detect-unused-dependencies: + runs-on: ubuntu-latest + container: + image: amd64/rust + steps: + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + - name: Install cargo-machete + run: cargo install cargo-machete --version ^0.9 --locked + - name: Detect unused dependencies + run: cargo machete --with-metadata diff --git a/.github/workflows/dev.yml b/.github/workflows/dev.yml index 
aa4bd862e09e4..2fec343650914 100644 --- a/.github/workflows/dev.yml +++ b/.github/workflows/dev.yml @@ -16,7 +16,12 @@ # under the License. name: Dev -on: [push, pull_request] +on: + push: + branches-ignore: + - 'gh-readonly-queue/**' + pull_request: + merge_group: concurrency: group: ${{ github.repository }}-${{ github.head_ref || github.sha }}-${{ github.workflow }} @@ -27,25 +32,36 @@ jobs: runs-on: ubuntu-latest name: Check License Header steps: - - uses: actions/checkout@v4 - - uses: korandoru/hawkeye@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + - name: Install HawkEye + # This CI job is bound by installation time, use `--profile dev` to speed it up + run: cargo install hawkeye --version 6.2.0 --locked --profile dev + - name: Run license header check + run: ci/scripts/license_header.sh prettier: name: Use prettier to check formatting of documents runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 - - uses: actions/setup-node@v4 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + - uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # v6.2.0 with: node-version: "20" - name: Prettier check - run: | - # if you encounter error, rerun the command below and commit the changes - # - # ignore subproject CHANGELOG.md because they are machine generated - npx prettier@2.7.1 --write \ - '{datafusion,datafusion-cli,datafusion-examples,dev,docs}/**/*.md' \ - '!datafusion/CHANGELOG.md' \ - README.md \ - CONTRIBUTING.md - git diff --exit-code + # if you encounter error, see instructions inside the script + run: ci/scripts/doc_prettier_check.sh + + typos: + name: Spell Check with Typos + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + with: + persist-credentials: false + # Version fixed on purpose. It uses heuristics to detect typos, so upgrading + # it may cause checks to fail more often. 
+ # We can upgrade it manually once in a while. + - name: Install typos-cli + run: cargo install typos-cli --locked --version 1.37.0 + - name: Run typos check + run: ci/scripts/typos_check.sh diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml index 5f1b2c1395982..63add4dacc812 100644 --- a/.github/workflows/docs.yaml +++ b/.github/workflows/docs.yaml @@ -32,32 +32,31 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout docs sources - uses: actions/checkout@v4 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Checkout asf-site branch - uses: actions/checkout@v4 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: ref: asf-site path: asf-site - - name: Setup Python - uses: actions/setup-python@v5 - with: - python-version: "3.12" + - name: Setup uv + uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0 - name: Install dependencies + run: uv sync --package datafusion-docs + - name: Install dependency graph tooling run: | set -x - python3 -m venv venv - source venv/bin/activate - pip install -r docs/requirements.txt + sudo apt-get update + sudo apt-get install -y graphviz + cargo install cargo-depgraph --version ^1.6 --locked - name: Build docs run: | set -x - source venv/bin/activate cd docs - ./build.sh + uv run --package datafusion-docs ./build.sh - name: Copy & push the generated HTML run: | diff --git a/.github/workflows/docs_pr.yaml b/.github/workflows/docs_pr.yaml index 8d11cdf9d39bb..cc5b9a1e44bb5 100644 --- a/.github/workflows/docs_pr.yaml +++ b/.github/workflows/docs_pr.yaml @@ -40,24 +40,22 @@ jobs: name: Test doc build runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: submodules: true fetch-depth: 1 - - name: Setup Python - uses: actions/setup-python@v5 - with: - python-version: "3.12" + - name: Setup uv + uses: 
astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0 - name: Install doc dependencies + run: uv sync --package datafusion-docs + - name: Install dependency graph tooling run: | set -x - python3 -m venv venv - source venv/bin/activate - pip install -r docs/requirements.txt + sudo apt-get update + sudo apt-get install -y graphviz + cargo install cargo-depgraph --version ^1.6 --locked - name: Build docs html and check for warnings run: | set -x - source venv/bin/activate cd docs - ./build.sh # fails on errors - + uv run --package datafusion-docs ./build.sh # fails on errors diff --git a/.github/workflows/extended.yml b/.github/workflows/extended.yml index 0ccecfc44fd64..c2aa96d92edc0 100644 --- a/.github/workflows/extended.yml +++ b/.github/workflows/extended.yml @@ -32,6 +32,19 @@ on: push: branches: - main + # support extended test suite for release candidate branches, + # it is not expected to have many changes in these branches, + # so running extended tests is not a burden + - 'branch-*' + # Also run for changes to some critical areas that are most likely + # to trigger errors in extended tests + pull_request: + branches: [ '**' ] + paths: + - 'datafusion/physical*/**/*.rs' + - 'datafusion/expr*/**/*.rs' + - 'datafusion/optimizer/**/*.rs' + - 'datafusion-testing' workflow_dispatch: inputs: pr_number: @@ -53,10 +66,11 @@ jobs: # Check crate compiles and base cargo check passes linux-build-lib: name: linux build test - runs-on: ubuntu-latest + runs-on: ${{ github.repository_owner == 'apache' && format('runs-on={0},family=m8a+m7a+c8a,cpu=8,image=ubuntu24-full-x64,extras=s3-cache,disk=large,tag=datafusion', github.run_id) || 'ubuntu-latest' }} # note: do not use amd/rust container to preserve disk space steps: - - uses: actions/checkout@v4 + - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # v2.0.3 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: ref: ${{ github.event.inputs.pr_head_sha }} # will be 
empty if triggered by push submodules: true @@ -67,7 +81,9 @@ jobs: source $HOME/.cargo/env rustup toolchain install - name: Install Protobuf Compiler - run: sudo apt-get install -y protobuf-compiler + run: | + sudo apt-get update + sudo apt-get install -y protobuf-compiler - name: Prepare cargo build run: | cargo check --profile ci --all-targets @@ -77,23 +93,27 @@ jobs: linux-test-extended: name: cargo test 'extended_tests' (amd64) needs: [linux-build-lib] - runs-on: ubuntu-latest + runs-on: ${{ github.repository_owner == 'apache' && format('runs-on={0},family=m8a+m7a+c8a,cpu=32,image=ubuntu24-full-x64,extras=s3-cache,disk=large,tag=datafusion,spot=false', github.run_id) || 'ubuntu-latest' }} + # spot=false because the tests are long, https://runs-on.com/configuration/spot-instances/#disable-spot-pricing # note: do not use amd/rust container to preserve disk space steps: - - uses: actions/checkout@v4 + - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # v2.0.3 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: ref: ${{ github.event.inputs.pr_head_sha }} # will be empty if triggered by push submodules: true fetch-depth: 1 - name: Free Disk Space (Ubuntu) - uses: jlumbroso/free-disk-space@54081f138730dfa15788a46383842cd2f914a1be + uses: jlumbroso/free-disk-space@54081f138730dfa15788a46383842cd2f914a1be # v1.3.1 - name: Install Rust run: | curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y source $HOME/.cargo/env rustup toolchain install - name: Install Protobuf Compiler - run: sudo apt-get install -y protobuf-compiler + run: | + sudo apt-get update + sudo apt-get install -y protobuf-compiler # For debugging, test binaries can be large. 
- name: Show available disk space run: | @@ -111,7 +131,7 @@ jobs: --lib \ --tests \ --bins \ - --features avro,json,backtrace,extended_tests,recursive_protection + --features avro,json,backtrace,extended_tests,recursive_protection,parquet_encryption - name: Verify Working Directory Clean run: git diff --exit-code - name: Cleanup @@ -120,11 +140,12 @@ jobs: # Check answers are correct when hash values collide hash-collisions: name: cargo test hash collisions (amd64) - runs-on: ubuntu-latest + runs-on: ${{ github.repository_owner == 'apache' && format('runs-on={0},family=m8a+m7a+c8a,cpu=16,image=ubuntu24-full-x64,extras=s3-cache,disk=large,tag=datafusion', github.run_id) || 'ubuntu-latest' }} container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # v2.0.3 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: ref: ${{ github.event.inputs.pr_head_sha }} # will be empty if triggered by push submodules: true @@ -136,16 +157,18 @@ jobs: - name: Run tests run: | cd datafusion - cargo test --profile ci --exclude datafusion-examples --exclude datafusion-benchmarks --exclude datafusion-sqllogictest --workspace --lib --tests --features=force_hash_collisions,avro + cargo test --profile ci --exclude datafusion-examples --exclude datafusion-benchmarks --exclude datafusion-sqllogictest --exclude datafusion-cli --workspace --lib --tests --features=force_hash_collisions,avro cargo clean sqllogictest-sqlite: name: "Run sqllogictests with the sqlite test suite" - runs-on: ubuntu-latest + runs-on: ${{ github.repository_owner == 'apache' && format('runs-on={0},family=m8a+m7a+c8a,cpu=48,image=ubuntu24-full-x64,extras=s3-cache,disk=large,tag=datafusion,spot=false', github.run_id) || 'ubuntu-latest' }} + # spot=false because the tests are long, https://runs-on.com/configuration/spot-instances/#disable-spot-pricing container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - 
uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # v2.0.3 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: ref: ${{ github.event.inputs.pr_head_sha }} # will be empty if triggered by push submodules: true @@ -156,47 +179,4 @@ jobs: rust-version: stable - name: Run sqllogictest run: | - cargo test --features backtrace --profile release-nonlto --test sqllogictests -- --include-sqlite - cargo clean - - # If the workflow was triggered by the PR comment (through pr_comment_commands.yml action) we need to manually update check status to display in UI - update-check-status: - needs: [linux-build-lib, linux-test-extended, hash-collisions, sqllogictest-sqlite] - runs-on: ubuntu-latest - if: ${{ always() && github.event_name == 'workflow_dispatch' }} - steps: - - name: Determine workflow status - id: status - run: | - if [[ "${{ contains(needs.*.result, 'failure') || contains(needs.*.result, 'cancelled') }}" == "true" ]]; then - echo "workflow_status=failure" >> $GITHUB_OUTPUT - echo "conclusion=failure" >> $GITHUB_OUTPUT - else - echo "workflow_status=completed" >> $GITHUB_OUTPUT - echo "conclusion=success" >> $GITHUB_OUTPUT - fi - - - name: Update check run - uses: actions/github-script@v7 - with: - github-token: ${{ secrets.GITHUB_TOKEN }} - script: | - const workflowRunUrl = `https://github.com/${context.repo.owner}/${context.repo.repo}/actions/runs/${context.runId}`; - - await github.rest.checks.update({ - owner: context.repo.owner, - repo: context.repo.repo, - check_run_id: ${{ github.event.inputs.check_run_id }}, - status: 'completed', - conclusion: '${{ steps.status.outputs.conclusion }}', - output: { - title: '${{ steps.status.outputs.conclusion == 'success' && 'Extended Tests Passed' || 'Extended Tests Failed' }}', - summary: `Extended tests have completed with status: ${{ steps.status.outputs.conclusion }}.\n\n[View workflow run](${workflowRunUrl})` - }, - details_url: workflowRunUrl - }); - - - - - + cargo test 
--features backtrace,parquet_encryption --profile ci-optimized --test sqllogictests -- --include-sqlite \ No newline at end of file diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml index 8b251552d3b2d..a575b39577477 100644 --- a/.github/workflows/labeler.yml +++ b/.github/workflows/labeler.yml @@ -39,14 +39,12 @@ jobs: contents: read pull-requests: write steps: - - uses: actions/checkout@v4 - - name: Assign GitHub labels if: | github.event_name == 'pull_request_target' && (github.event.action == 'opened' || github.event.action == 'synchronize') - uses: actions/labeler@v5.0.0 + uses: actions/labeler@634933edcd8ababfe52f92936142cc22ac488b1b # v6.0.1 with: repo-token: ${{ secrets.GITHUB_TOKEN }} configuration-path: .github/workflows/labeler/labeler-config.yml diff --git a/.github/workflows/labeler/labeler-config.yml b/.github/workflows/labeler/labeler-config.yml index e408130725215..0e492b6f3f6dc 100644 --- a/.github/workflows/labeler/labeler-config.yml +++ b/.github/workflows/labeler/labeler-config.yml @@ -58,11 +58,11 @@ execution: datasource: - changed-files: - - any-glob-to-any-file: ['datafusion/datasource/**/*', 'datafusion/datasource-avro/**/*', 'datafusion/datasource-csv/**/*', 'datafusion/datasource-json/**/*', 'datafusion/datasource-parquet/**/*'] + - any-glob-to-any-file: ['datafusion/datasource/**/*', 'datafusion/datasource-avro/**/*', 'datafusion/datasource-arrow/**/*', 'datafusion/datasource-csv/**/*', 'datafusion/datasource-json/**/*', 'datafusion/datasource-parquet/**/*'] functions: - changed-files: - - any-glob-to-any-file: ['datafusion/functions/**/*', 'datafusion/functions-aggregate/**/*', 'datafusion/functions-aggregate-common', 'datafusion/functions-nested', 'datafusion/functions-table/**/*', 'datafusion/functions-window/**/*', 'datafusion/functions-window-common/**/*'] + - any-glob-to-any-file: ['datafusion/functions/**/*', 'datafusion/functions-aggregate/**/*', 'datafusion/functions-aggregate-common/**/*', 
'datafusion/functions-nested/**/*', 'datafusion/functions-table/**/*', 'datafusion/functions-window/**/*', 'datafusion/functions-window-common/**/*'] optimizer: diff --git a/.github/workflows/large_files.yml b/.github/workflows/large_files.yml index aa96d55a0d851..12b7bae76ab32 100644 --- a/.github/workflows/large_files.yml +++ b/.github/workflows/large_files.yml @@ -23,12 +23,13 @@ concurrency: on: pull_request: + merge_group: jobs: check-files: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: fetch-depth: 0 - name: Check size of new Git objects @@ -38,7 +39,16 @@ jobs: MAX_FILE_SIZE_BYTES: 1048576 shell: bash run: | - git rev-list --objects ${{ github.event.pull_request.base.sha }}..${{ github.event.pull_request.head.sha }} \ + if [ "${{ github.event_name }}" = "merge_group" ]; then + # For merge queue, compare against the base branch + base_sha="${{ github.event.merge_group.base_sha }}" + head_sha="${{ github.event.merge_group.head_sha }}" + else + # For pull requests + base_sha="${{ github.event.pull_request.base.sha }}" + head_sha="${{ github.event.pull_request.head.sha }}" + fi + git rev-list --objects ${base_sha}..${head_sha} \ > pull-request-objects.txt exit_code=0 while read -r id path; do diff --git a/.github/workflows/pr_comment_commands.yml b/.github/workflows/pr_comment_commands.yml deleted file mode 100644 index 6aa6caaf34d02..0000000000000 --- a/.github/workflows/pr_comment_commands.yml +++ /dev/null @@ -1,89 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -name: PR commands - -on: - issue_comment: - types: [created] - -permissions: - contents: read - pull-requests: write - actions: write - checks: write - -jobs: - # Starts the extended_tests on a PR branch when someone leaves a `Run extended tests` comment - run_extended_tests: - runs-on: ubuntu-latest - if: ${{ github.event_name == 'issue_comment' && github.event.issue.pull_request && contains(github.event.comment.body, 'Run extended tests') }} - steps: - - name: Dispatch extended tests for a PR branch with comment - uses: actions/github-script@v7 - with: - github-token: ${{ secrets.GITHUB_TOKEN }} - script: | - // Get PR details to fetch the branch name - const { data: pullRequest } = await github.rest.pulls.get({ - owner: context.repo.owner, - repo: context.repo.repo, - pull_number: context.payload.issue.number - }); - - // Extract the branch name - const branchName = pullRequest.head.ref; - const headSha = pullRequest.head.sha; - const workflowRunsUrl = `https://github.com/${context.repo.owner}/${context.repo.repo}/actions?query=workflow%3A%22Datafusion+extended+tests%22+branch%3A${branchName}`; - - // Create a check run that links to the Actions tab so the run will be visible in GitHub UI - const check = await github.rest.checks.create({ - owner: context.repo.owner, - repo: context.repo.repo, - name: 'Extended Tests', - head_sha: headSha, - status: 'in_progress', - output: { - title: 'Extended Tests Running', - summary: `Extended tests have been triggered for this PR.\n\n[View workflow runs](${workflowRunsUrl})` - }, - details_url: workflowRunsUrl - 
}); - - // Dispatch the workflow with the PR branch name - await github.rest.actions.createWorkflowDispatch({ - owner: context.repo.owner, - repo: context.repo.repo, - workflow_id: 'extended.yml', - ref: 'main', - inputs: { - pr_number: context.payload.issue.number.toString(), - check_run_id: check.data.id.toString(), - pr_head_sha: headSha - } - }); - - - name: Add reaction to comment - uses: actions/github-script@v7 - with: - script: | - await github.rest.reactions.createForIssueComment({ - owner: context.repo.owner, - repo: context.repo.repo, - comment_id: context.payload.comment.id, - content: 'rocket' - }); diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 2463b04b33738..f7452ee603b1c 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. +# For some actions, we use Runs-On to run them on ASF infrastructure: https://datafusion.apache.org/contributor-guide/#ci-runners + name: Rust concurrency: @@ -23,6 +25,8 @@ concurrency: on: push: + branches-ignore: + - 'gh-readonly-queue/**' paths-ignore: - "docs/**" - "**.md" @@ -34,31 +38,30 @@ on: - "**.md" - ".github/ISSUE_TEMPLATE/**" - ".github/pull_request_template.md" + merge_group: # manual trigger # https://docs.github.com/en/actions/managing-workflow-runs/manually-running-a-workflow workflow_dispatch: jobs: - # Check license header - license-header-check: - runs-on: ubuntu-latest - name: Check License Header - steps: - - uses: actions/checkout@v4 - - uses: korandoru/hawkeye@v6 - # Check crate compiles and base cargo check passes linux-build-lib: name: linux build test - runs-on: ubuntu-latest + runs-on: ${{ github.repository_owner == 'apache' && format('runs-on={0},family=m8a+m7a+c8a,cpu=8,image=ubuntu24-full-x64,extras=s3-cache,disk=large,tag=datafusion', github.run_id) || 'ubuntu-latest' }} container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: 
runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # v2.0.3 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: rust-version: stable + - name: Rust Dependency Cache + uses: Swatinem/rust-cache@c19371144df3bb44fab255c43d04cbc2ab54d1c4 # v2.9.1 + with: + shared-key: "amd-ci-check" # this job uses it's own cache becase check has a separate cache and we need it to be fast as it blocks other jobs + save-if: ${{ github.ref_name == 'main' }} - name: Prepare cargo build run: | # Adding `--locked` here to assert that the `Cargo.lock` file is up to @@ -77,7 +80,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: @@ -98,15 +101,20 @@ jobs: linux-datafusion-substrait-features: name: cargo check datafusion-substrait features needs: linux-build-lib - runs-on: ubuntu-latest + runs-on: ${{ github.repository_owner == 'apache' && format('runs-on={0},family=m8a+m7a+c8a,cpu=16,image=ubuntu24-full-x64,extras=s3-cache,disk=large,tag=datafusion', github.run_id) || 'ubuntu-latest' }} container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: rust-version: stable + - name: Rust Dependency Cache + uses: Swatinem/rust-cache@c19371144df3bb44fab255c43d04cbc2ab54d1c4 # v2.9.1 + with: + save-if: false # set in linux-test + shared-key: "amd-ci" - name: Check datafusion-substrait (default features) run: cargo check --profile ci --all-targets -p datafusion-substrait # @@ -130,11 +138,12 @@ jobs: linux-datafusion-proto-features: name: cargo check datafusion-proto features needs: linux-build-lib - runs-on: ubuntu-latest + runs-on: ${{ github.repository_owner == 'apache' && 
format('runs-on={0},family=m8a+m7a+c8a,cpu=16,image=ubuntu24-full-x64,extras=s3-cache,disk=large,tag=datafusion', github.run_id) || 'ubuntu-latest' }} container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # v2.0.3 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: @@ -161,15 +170,21 @@ jobs: linux-cargo-check-datafusion: name: cargo check datafusion features needs: linux-build-lib - runs-on: ubuntu-latest + runs-on: ${{ github.repository_owner == 'apache' && format('runs-on={0},family=m8a+m7a+c8a,cpu=16,image=ubuntu24-full-x64,extras=s3-cache,disk=large,tag=datafusion', github.run_id) || 'ubuntu-latest' }} container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # v2.0.3 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: rust-version: stable + - name: Rust Dependency Cache + uses: Swatinem/rust-cache@c19371144df3bb44fab255c43d04cbc2ab54d1c4 # v2.9.1 + with: + save-if: false # set in linux-test + shared-key: "amd-ci" - name: Check datafusion (default features) run: cargo check --profile ci --all-targets -p datafusion # @@ -199,18 +214,20 @@ jobs: run: cargo check --profile ci --no-default-features -p datafusion --features=math_expressions - name: Check datafusion (parquet) run: cargo check --profile ci --no-default-features -p datafusion --features=parquet - - name: Check datafusion (pyarrow) - run: cargo check --profile ci --no-default-features -p datafusion --features=pyarrow - name: Check datafusion (regex_expressions) run: cargo check --profile ci --no-default-features -p datafusion --features=regex_expressions - name: Check datafusion (recursive_protection) run: cargo check --profile ci --no-default-features -p 
datafusion --features=recursive_protection - name: Check datafusion (serde) run: cargo check --profile ci --no-default-features -p datafusion --features=serde + - name: Check datafusion (sql) + run: cargo check --profile ci --no-default-features -p datafusion --features=sql - name: Check datafusion (string_expressions) run: cargo check --profile ci --no-default-features -p datafusion --features=string_expressions - name: Check datafusion (unicode_expressions) run: cargo check --profile ci --no-default-features -p datafusion --features=unicode_expressions + - name: Check parquet encryption (parquet_encryption) + run: cargo check --profile ci --no-default-features -p datafusion --features=parquet_encryption # Check datafusion-functions crate features # @@ -223,7 +240,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: @@ -254,16 +271,26 @@ jobs: linux-test: name: cargo test (amd64) needs: linux-build-lib - runs-on: ubuntu-latest + runs-on: ${{ github.repository_owner == 'apache' && format('runs-on={0},family=m8a+m7a+c8a,cpu=16,image=ubuntu24-full-x64,extras=s3-cache,disk=large,tag=datafusion', github.run_id) || 'ubuntu-latest' }} + container: + image: amd64/rust + volumes: + - /usr/local:/host/usr/local steps: - - uses: actions/checkout@v4 + - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # v2.0.3 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: submodules: true fetch-depth: 1 - name: Setup Rust toolchain - run: rustup toolchain install stable - - name: Install Protobuf Compiler - run: sudo apt-get install -y protobuf-compiler + uses: ./.github/actions/setup-builder + with: + rust-version: stable + - name: Rust Dependency Cache + uses: Swatinem/rust-cache@c19371144df3bb44fab255c43d04cbc2ab54d1c4 # v2.9.1 + with: + save-if: ${{ github.ref_name == 
'main' }} + shared-key: "amd-ci" - name: Run tests (excluding doctests and datafusion-cli) env: RUST_BACKTRACE: 1 @@ -278,34 +305,37 @@ jobs: --lib \ --tests \ --bins \ - --features serde,avro,json,backtrace,integration-tests + --features serde,avro,json,backtrace,integration-tests,parquet_encryption - name: Verify Working Directory Clean run: git diff --exit-code + # Check no temporary directories created during test. + # `false/` folder is excuded for rust cache. + - name: Verify Working Directory Clean (No Untracked Files) + run: | + STATUS="$(git status --porcelain | sed -e '/^?? false\/$/d' -e '/^?? false$/d')" + if [ -n "$STATUS" ]; then + echo "$STATUS" + exit 1 + fi # datafusion-cli tests linux-test-datafusion-cli: name: cargo test datafusion-cli (amd64) needs: linux-build-lib - runs-on: ubuntu-latest + runs-on: ${{ github.repository_owner == 'apache' && format('runs-on={0},family=m8a+m7a+c8a,cpu=16,image=ubuntu24-full-x64,extras=s3-cache,disk=large,tag=datafusion', github.run_id) || 'ubuntu-latest' }} steps: - - uses: actions/checkout@v4 + - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # v2.0.3 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: submodules: true fetch-depth: 1 - name: Setup Rust toolchain run: rustup toolchain install stable - - name: Setup Minio - S3-compatible storage - run: | - docker run -d --name minio-container \ - -p 9000:9000 \ - -e MINIO_ROOT_USER=TEST-DataFusionLogin -e MINIO_ROOT_PASSWORD=TEST-DataFusionPassword \ - -v $(pwd)/datafusion/core/tests/data:/source quay.io/minio/minio \ - server /data - docker exec minio-container /bin/sh -c "\ - mc ready local - mc alias set localminio http://localhost:9000 TEST-DataFusionLogin TEST-DataFusionPassword && \ - mc mb localminio/data && \ - mc cp -r /source/* localminio/data" + - name: Rust Dependency Cache + uses: Swatinem/rust-cache@c19371144df3bb44fab255c43d04cbc2ab54d1c4 # v2.9.1 + with: + save-if: false # set in linux-test + 
shared-key: "amd-ci" - name: Run tests (excluding doctests) env: RUST_BACKTRACE: 1 @@ -314,22 +344,20 @@ jobs: AWS_SECRET_ACCESS_KEY: TEST-DataFusionPassword TEST_STORAGE_INTEGRATION: 1 AWS_ALLOW_HTTP: true - run: cargo test --profile ci -p datafusion-cli --lib --tests --bins + run: cargo test --features backtrace --profile ci -p datafusion-cli --lib --tests --bins - name: Verify Working Directory Clean run: git diff --exit-code - - name: Minio Output - if: ${{ !cancelled() }} - run: docker logs minio-container linux-test-example: name: cargo examples (amd64) needs: linux-build-lib - runs-on: ubuntu-latest + runs-on: ${{ github.repository_owner == 'apache' && format('runs-on={0},family=m8a+m7a+c8a,cpu=16,image=ubuntu24-full-x64,extras=s3-cache,disk=large,tag=datafusion', github.run_id) || 'ubuntu-latest' }} container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # v2.0.3 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: submodules: true fetch-depth: 1 @@ -337,6 +365,11 @@ jobs: uses: ./.github/actions/setup-builder with: rust-version: stable + - name: Rust Dependency Cache + uses: Swatinem/rust-cache@c19371144df3bb44fab255c43d04cbc2ab54d1c4 # v2.9.1 + with: + save-if: ${{ github.ref_name == 'main' }} + shared-key: "amd-ci-linux-test-example" - name: Run examples run: | # test datafusion-sql examples @@ -350,11 +383,12 @@ jobs: linux-test-doc: name: cargo test doc (amd64) needs: linux-build-lib - runs-on: ubuntu-latest + runs-on: ${{ github.repository_owner == 'apache' && format('runs-on={0},family=m8a+m7a+c8a,cpu=16,image=ubuntu24-full-x64,extras=s3-cache,disk=large,tag=datafusion', github.run_id) || 'ubuntu-latest' }} container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # v2.0.3 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: submodules: true 
fetch-depth: 1 @@ -371,11 +405,12 @@ jobs: linux-rustdoc: name: cargo doc needs: linux-build-lib - runs-on: ubuntu-latest + runs-on: ${{ github.repository_owner == 'apache' && format('runs-on={0},family=m8a+m7a+c8a,cpu=16,image=ubuntu24-full-x64,extras=s3-cache,disk=large,tag=datafusion', github.run_id) || 'ubuntu-latest' }} container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # v2.0.3 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: @@ -387,7 +422,7 @@ jobs: name: build and run with wasm-pack runs-on: ubuntu-24.04 steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Setup for wasm32 run: | rustup target add wasm32-unknown-unknown @@ -396,23 +431,27 @@ jobs: sudo apt-get update -qq sudo apt-get install -y -qq clang - name: Setup wasm-pack - run: | - cargo install wasm-pack + uses: taiki-e/install-action@de6bbd1333b8f331563d54a051e542c7dfef81c3 # v2.68.34 + with: + tool: wasm-pack - name: Run tests with headless mode working-directory: ./datafusion/wasmtest run: | - RUSTFLAGS='--cfg getrandom_backend="wasm_js"' wasm-pack test --headless --firefox - RUSTFLAGS='--cfg getrandom_backend="wasm_js"' wasm-pack test --headless --chrome --chromedriver $CHROMEWEBDRIVER/chromedriver + # debuginfo=none because CI tests weren't completing successfully after this upstream PR: + # https://github.com/wasm-bindgen/wasm-bindgen/pull/4635 + RUSTFLAGS='--cfg getrandom_backend="wasm_js" -C debuginfo=none' wasm-pack test --headless --firefox + RUSTFLAGS='--cfg getrandom_backend="wasm_js" -C debuginfo=none' wasm-pack test --headless --chrome --chromedriver $CHROMEWEBDRIVER/chromedriver # verify that the benchmark queries return the correct results verify-benchmark-results: name: verify benchmark results (amd64) needs: linux-build-lib - runs-on: 
ubuntu-latest + runs-on: ${{ github.repository_owner == 'apache' && format('runs-on={0},family=m8a+m7a+c8a,cpu=16,image=ubuntu24-full-x64,extras=s3-cache,disk=large,tag=datafusion', github.run_id) || 'ubuntu-latest' }} container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # v2.0.3 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: submodules: true fetch-depth: 1 @@ -434,14 +473,14 @@ jobs: export RUST_MIN_STACK=20971520 export TPCH_DATA=`realpath datafusion/sqllogictest/test_files/tpch/data` cargo test plan_q --package datafusion-benchmarks --profile ci --features=ci -- --test-threads=1 - INCLUDE_TPCH=true cargo test --features backtrace --profile ci --package datafusion-sqllogictest --test sqllogictests + INCLUDE_TPCH=true cargo test --features backtrace,parquet_encryption --profile ci --package datafusion-sqllogictest --test sqllogictests - name: Verify Working Directory Clean run: git diff --exit-code sqllogictest-postgres: name: "Run sqllogictest with Postgres runner" needs: linux-build-lib - runs-on: ubuntu-latest + runs-on: ${{ github.repository_owner == 'apache' && format('runs-on={0},family=m8a+m7a+c8a,cpu=16,image=ubuntu24-full-x64,extras=s3-cache,disk=large,tag=datafusion', github.run_id) || 'ubuntu-latest' }} container: image: amd64/rust services: @@ -459,7 +498,8 @@ jobs: --health-timeout 5s --health-retries 5 steps: - - uses: actions/checkout@v4 + - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # v2.0.3 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: submodules: true fetch-depth: 1 @@ -476,6 +516,29 @@ jobs: POSTGRES_HOST: postgres POSTGRES_PORT: ${{ job.services.postgres.ports[5432] }} + sqllogictest-substrait: + name: "Run sqllogictest in Substrait round-trip mode" + needs: linux-build-lib + runs-on: ${{ github.repository_owner == 'apache' && 
format('runs-on={0},family=m8a+m7a+c8a,cpu=16,image=ubuntu24-full-x64,extras=s3-cache,disk=large,tag=datafusion', github.run_id) || 'ubuntu-latest' }} + container: + image: amd64/rust + steps: + - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # v2.0.3 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + with: + submodules: true + fetch-depth: 1 + - name: Setup Rust toolchain + uses: ./.github/actions/setup-builder + with: + rust-version: stable + - name: Run sqllogictest + # TODO: Right now several tests are failing in Substrait round-trip mode, so this + # command cannot be run for all the .slt files. Run it for just one that works (limit.slt) + # until most of the tickets in https://github.com/apache/datafusion/issues/16248 are addressed + # and this command can be run without filters. + run: cargo test --test sqllogictests -- --substrait-round-trip limit.slt + # Temporarily commenting out the Windows flow, the reason is enormously slow running build # Waiting for new Windows 2025 github runner # Details: https://github.com/apache/datafusion/issues/13726 @@ -495,27 +558,11 @@ jobs: # export PATH=$PATH:$HOME/d/protoc/bin # cargo test --lib --tests --bins --features avro,json,backtrace - # Commenting out intel mac build as so few users would ever use it - # Details: https://github.com/apache/datafusion/issues/13846 - # macos: - # name: cargo test (macos) - # runs-on: macos-latest - # steps: - # - uses: actions/checkout@v4 - # with: - # submodules: true - # fetch-depth: 1 - # - name: Setup Rust toolchain - # uses: ./.github/actions/setup-macos-builder - # - name: Run tests (excluding doctests) - # shell: bash - # run: cargo test run --profile ci --exclude datafusion-examples --exclude datafusion-benchmarks --workspace --lib --tests --bins --features avro,json,backtrace - macos-aarch64: name: cargo test (macos-aarch64) - runs-on: macos-14 + runs-on: macos-15 steps: - - uses: actions/checkout@v4 + - uses: 
actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: submodules: true fetch-depth: 1 @@ -525,37 +572,13 @@ jobs: shell: bash run: cargo test --profile ci --exclude datafusion-cli --workspace --lib --tests --bins --features avro,json,backtrace,integration-tests - test-datafusion-pyarrow: - name: cargo test pyarrow (amd64) - needs: linux-build-lib - runs-on: ubuntu-latest - container: - image: amd64/rust:bullseye # Use the bullseye tag image which comes with python3.9 - steps: - - uses: actions/checkout@v4 - with: - submodules: true - fetch-depth: 1 - - name: Install PyArrow - run: | - echo "LIBRARY_PATH=$LD_LIBRARY_PATH" >> $GITHUB_ENV - apt-get update - apt-get install python3-pip -y - python3 -m pip install pyarrow - - name: Setup Rust toolchain - uses: ./.github/actions/setup-builder - with: - rust-version: stable - - name: Run datafusion-common tests - run: cargo test --profile ci -p datafusion-common --features=pyarrow - vendor: name: Verify Vendored Code runs-on: ubuntu-latest container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: @@ -572,7 +595,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: @@ -627,11 +650,12 @@ jobs: clippy: name: clippy needs: linux-build-lib - runs-on: ubuntu-latest + runs-on: ${{ github.repository_owner == 'apache' && format('runs-on={0},family=m8a+m7a+c8a,cpu=16,image=ubuntu24-full-x64,extras=s3-cache,disk=large,tag=datafusion', github.run_id) || 'ubuntu-latest' }} container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # v2.0.3 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 
with: submodules: true fetch-depth: 1 @@ -641,6 +665,11 @@ jobs: rust-version: stable - name: Install Clippy run: rustup component add clippy + - name: Rust Dependency Cache + uses: Swatinem/rust-cache@c19371144df3bb44fab255c43d04cbc2ab54d1c4 # v2.9.1 + with: + save-if: ${{ github.ref_name == 'main' }} + shared-key: "amd-ci-clippy" - name: Run clippy run: ci/scripts/rust_clippy.sh @@ -651,7 +680,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: submodules: true fetch-depth: 1 @@ -668,11 +697,12 @@ jobs: config-docs-check: name: check configs.md and ***_functions.md is up-to-date needs: linux-build-lib - runs-on: ubuntu-latest + runs-on: ${{ github.repository_owner == 'apache' && format('runs-on={0},family=m8a+m7a+c8a,cpu=16,image=ubuntu24-full-x64,extras=s3-cache,disk=large,tag=datafusion', github.run_id) || 'ubuntu-latest' }} container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # v2.0.3 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: submodules: true fetch-depth: 1 @@ -680,7 +710,7 @@ jobs: uses: ./.github/actions/setup-builder with: rust-version: stable - - uses: actions/setup-node@v4 + - uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # v6.2.0 with: node-version: "20" - name: Check if configs.md has been modified @@ -693,11 +723,38 @@ jobs: # If you encounter an error, run './dev/update_function_docs.sh' and commit ./dev/update_function_docs.sh git diff --exit-code - - name: Check if runtime_configs.md has been modified + +# This job ensures `datafusion-examples/README.md` stays in sync with the source code: +# 1. Generates README automatically using the Rust examples docs generator +# (parsing documentation from `examples//main.rs`) +# 2. Formats the generated Markdown using DataFusion's standard Prettier setup +# 3. 
Compares the result against the committed README.md and fails if out-of-date + examples-docs-check: + name: check example README is up-to-date + needs: linux-build-lib + runs-on: ubuntu-latest + container: + image: amd64/rust + + steps: + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + with: + submodules: true + fetch-depth: 1 + + - name: Mark repository as safe for git + # Required for git commands inside container (avoids "dubious ownership" error) + run: git config --global --add safe.directory "$GITHUB_WORKSPACE" + + - name: Set up Node.js (required for prettier) + # doc_prettier_check.sh uses npx to run prettier for Markdown formatting + uses: actions/setup-node@v6 + with: + node-version: '18' + + - name: Run examples docs check script run: | - # If you encounter an error, run './dev/update_runtime_config_docs.sh' and commit - ./dev/update_runtime_config_docs.sh - git diff --exit-code + bash ci/scripts/check_examples_docs.sh # Verify MSRV for the crates which are directly used by other projects: # - datafusion @@ -710,11 +767,14 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder - name: Install cargo-msrv - run: cargo install cargo-msrv + uses: taiki-e/install-action@de6bbd1333b8f331563d54a051e542c7dfef81c3 # v2.68.34 + with: + tool: cargo-msrv + - name: Check datafusion working-directory: datafusion/core run: | @@ -724,10 +784,15 @@ jobs: # `rust-version` key of `Cargo.toml`. # # To reproduce: - # 1. Install the version of Rust that is failing. Example: - # rustup install 1.80.1 - # 2. Run the command that failed with that version. Example: - # cargo +1.80.1 check -p datafusion + # 1. Install the version of Rust that is failing. + # 2. Run the command that failed with that version. + # + # Example: + # # MSRV looks like "1.80.0" and is specified in Cargo.toml. 
We can read the value with the following command: + # msrv="$(cargo metadata --format-version=1 | jq '.packages[] | select( .name == "datafusion" ) | .rust_version' -r)" + # echo "MSRV: ${msrv}" + # rustup install "${msrv}" + # cargo "+${msrv}" check # # To resolve, either: # 1. Change your code to use older Rust features, diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index 2312526824a91..ec7f54ec24dbc 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -27,7 +27,7 @@ jobs: issues: write pull-requests: write steps: - - uses: actions/stale@v9 + - uses: actions/stale@b5d41d4e1d5dceea10e7104786b73624c18a190f # v10.2.0 with: stale-pr-message: "Thank you for your contribution. Unfortunately, this pull request is stale because it has been open 60 days with no activity. Please remove the stale label or comment or this will be closed in 7 days." days-before-pr-stale: 60 diff --git a/.github/workflows/take.yml b/.github/workflows/take.yml index 86dc190add1d1..ffb5f728e04c1 100644 --- a/.github/workflows/take.yml +++ b/.github/workflows/take.yml @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
-name: Assign the issue via a `take` comment +name: Assign/unassign the issue via `take` or `untake` comment on: issue_comment: types: created @@ -26,16 +26,30 @@ permissions: jobs: issue_assign: runs-on: ubuntu-latest - if: (!github.event.issue.pull_request) && github.event.comment.body == 'take' + if: (!github.event.issue.pull_request) && (github.event.comment.body == 'take' || github.event.comment.body == 'untake') concurrency: group: ${{ github.actor }}-issue-assign steps: - - run: | - CODE=$(curl -H "Authorization: token ${{ secrets.GITHUB_TOKEN }}" -LI https://api.github.com/repos/${{ github.repository }}/issues/${{ github.event.issue.number }}/assignees/${{ github.event.comment.user.login }} -o /dev/null -w '%{http_code}\n' -s) - if [ "$CODE" -eq "204" ] + - name: Take or untake issue + env: + COMMENT_BODY: ${{ github.event.comment.body }} + ISSUE_NUMBER: ${{ github.event.issue.number }} + USER_LOGIN: ${{ github.event.comment.user.login }} + REPO: ${{ github.repository }} + TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + if [ "$COMMENT_BODY" == "take" ] then - echo "Assigning issue ${{ github.event.issue.number }} to ${{ github.event.comment.user.login }}" - curl -H "Authorization: token ${{ secrets.GITHUB_TOKEN }}" -d '{"assignees": ["${{ github.event.comment.user.login }}"]}' https://api.github.com/repos/${{ github.repository }}/issues/${{ github.event.issue.number }}/assignees - else - echo "Cannot assign issue ${{ github.event.issue.number }} to ${{ github.event.comment.user.login }}" + CODE=$(curl -H "Authorization: token $TOKEN" -LI https://api.github.com/repos/$REPO/issues/$ISSUE_NUMBER/assignees/$USER_LOGIN -o /dev/null -w '%{http_code}\n' -s) + if [ "$CODE" -eq "204" ] + then + echo "Assigning issue $ISSUE_NUMBER to $USER_LOGIN" + curl -X POST -H "Authorization: token $TOKEN" -H "Content-Type: application/json" -d "{\"assignees\": [\"$USER_LOGIN\"]}" https://api.github.com/repos/$REPO/issues/$ISSUE_NUMBER/assignees + else + echo "Cannot assign issue 
$ISSUE_NUMBER to $USER_LOGIN" + fi + elif [ "$COMMENT_BODY" == "untake" ] + then + echo "Unassigning issue $ISSUE_NUMBER from $USER_LOGIN" + curl -X DELETE -H "Authorization: token $TOKEN" -H "Content-Type: application/json" -d "{\"assignees\": [\"$USER_LOGIN\"]}" https://api.github.com/repos/$REPO/issues/$ISSUE_NUMBER/assignees fi \ No newline at end of file diff --git a/.gitignore b/.gitignore index 4ae32925d908e..8466a72adaec8 100644 --- a/.gitignore +++ b/.gitignore @@ -31,6 +31,7 @@ docker_cache *.orig .*.swp .*.swo +*.pending-snap venv/* diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000000000..eeedbd8bc45ec --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,34 @@ +# Agent Guidelines for Apache DataFusion + +## Developer Documentation + +- [Contributor Guide](docs/source/contributor-guide/index.md) +- [Architecture Guide](docs/source/contributor-guide/architecture.md) + +## Before Committing + +Before committing any changes, you **must** run the following checks and fix any issues: + +```bash +cargo fmt --all +cargo clippy --all-targets --all-features -- -D warnings +``` + +- `cargo fmt` ensures consistent code formatting across the project. +- `cargo clippy` catches common mistakes and enforces idiomatic Rust patterns. All warnings must be resolved (treated as errors via `-D warnings`). + +Do not commit code that fails either of these checks. + +## Testing + +Run relevant tests before submitting changes: + +```bash +cargo test --all-features +``` + +For SQL logic tests: + +```bash +cargo test -p datafusion-sqllogictest +``` diff --git a/CLAUDE.md b/CLAUDE.md new file mode 120000 index 0000000000000..47dc3e3d863cf --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1 @@ +AGENTS.md \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index f918b3ae2663d..5cef3742dfd18 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. 
-version = 3 +version = 4 [[package]] name = "abi_stable" @@ -14,7 +14,7 @@ dependencies = [ "core_extensions", "crossbeam-channel", "generational-arena", - "libloading 0.7.4", + "libloading", "lock_api", "parking_lot", "paste", @@ -50,37 +50,11 @@ dependencies = [ "core_extensions", ] -[[package]] -name = "addr2line" -version = "0.24.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dfbe277e56a376000877090da837660b4427aad530e3028d44e0bffe4f89a1c1" -dependencies = [ - "gimli", -] - [[package]] name = "adler2" -version = "2.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" - -[[package]] -name = "adler32" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aae1277d39aeec15cb388266ecc24b11c80469deae6067e17a1a7aa9e5c1f234" - -[[package]] -name = "ahash" -version = "0.7.8" +version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "891477e0c6a8957309ee5c45a6368af3ae14bb510732d2684ffa19af310920f9" -dependencies = [ - "getrandom 0.2.16", - "once_cell", - "version_check", -] +checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" [[package]] name = "ahash" @@ -90,7 +64,7 @@ checksum = "5a15f179cd60c4584b8a8c596927aadc462e27f2ca70c04e0071964a73ba7a75" dependencies = [ "cfg-if", "const-random", - "getrandom 0.3.3", + "getrandom 0.3.4", "once_cell", "version_check", "zerocopy", @@ -98,9 +72,9 @@ dependencies = [ [[package]] name = "aho-corasick" -version = "1.1.3" +version = "1.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" +checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301" dependencies = [ "memchr", ] @@ -121,16 +95,19 @@ dependencies = [ ] [[package]] -name = "allocator-api2" -version = "0.2.21" +name = "alloca" 
+version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" +checksum = "e5a7d05ea6aea7e9e64d25b9156ba2fee3fdd659e34e41063cd2fc7cd020d7f4" +dependencies = [ + "cc", +] [[package]] -name = "android-tzdata" -version = "0.1.1" +name = "allocator-api2" +version = "0.2.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0" +checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" [[package]] name = "android_system_properties" @@ -149,12 +126,27 @@ checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" [[package]] name = "anstream" -version = "0.6.18" +version = "0.6.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43d5b281e737544384e969a5ccad3f1cdd24b48086a0fc1b2a5262a26b8f4f4a" +dependencies = [ + "anstyle", + "anstyle-parse 0.2.7", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "is_terminal_polyfill", + "utf8parse", +] + +[[package]] +name = "anstream" +version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8acc5369981196006228e28809f761875c0327210a891e941f4c683b3a99529b" +checksum = "824a212faf96e9acacdbd09febd34438f8f711fb84e09a8916013cd7815ca28d" dependencies = [ "anstyle", - "anstyle-parse", + "anstyle-parse 1.0.0", "anstyle-query", "anstyle-wincon", "colorchoice", @@ -164,74 +156,92 @@ dependencies = [ [[package]] name = "anstyle" -version = "1.0.10" +version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "55cc3b69f167a1ef2e161439aa98aed94e6028e5f9a59be9a6ffb47aef1651f9" +checksum = "940b3a0ca603d1eade50a4846a2afffd5ef57a9feac2c0e2ec2e14f9ead76000" [[package]] name = "anstyle-parse" -version = "0.2.6" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"4e7644824f0aa2c7b9384579234ef10eb7efb6a0deb83f9630a49594dd9c15c2" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-parse" +version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b2d16507662817a6a20a9ea92df6652ee4f94f914589377d69f3b21bc5798a9" +checksum = "52ce7f38b242319f7cabaa6813055467063ecdc9d355bbb4ce0c68908cd8130e" dependencies = [ "utf8parse", ] [[package]] name = "anstyle-query" -version = "1.1.2" +version = "1.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "79947af37f4177cfead1110013d678905c37501914fba0efea834c3fe9a8d60c" +checksum = "40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc" dependencies = [ - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] name = "anstyle-wincon" -version = "3.0.7" +version = "3.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca3534e77181a9cc07539ad51f2141fe32f6c3ffd4df76db8ad92346b003ae4e" +checksum = "291e6a250ff86cd4a820112fb8898808a366d8f9f58ce16d1f538353ad55747d" dependencies = [ "anstyle", - "once_cell", - "windows-sys 0.59.0", + "once_cell_polyfill", + "windows-sys 0.61.2", ] [[package]] name = "anyhow" -version = "1.0.98" +version = "1.0.102" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e16d2d3311acee920a9eb8d33b8cbc1787ce4a264e85f964c2404b969bdcd487" +checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" [[package]] name = "apache-avro" -version = "0.17.0" +version = "0.21.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1aef82843a0ec9f8b19567445ad2421ceeb1d711514384bdd3d49fe37102ee13" +checksum = "36fa98bc79671c7981272d91a8753a928ff6a1cd8e4f20a44c45bd5d313840bf" dependencies = [ "bigdecimal", - "bzip2 0.4.4", + "bon", + "bzip2", "crc32fast", "digest", - "libflate", + "liblzma", "log", + "miniz_oxide", "num-bigint", "quad-rand", - "rand 0.8.5", + "rand 0.9.2", "regex-lite", 
"serde", "serde_bytes", "serde_json", "snap", - "strum", - "strum_macros", - "thiserror 1.0.69", - "typed-builder", + "strum 0.27.2", + "strum_macros 0.27.2", + "thiserror", "uuid", - "xz2", "zstd", ] +[[package]] +name = "ar_archive_writer" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7eb93bbb63b9c227414f6eb3a0adfddca591a8ce1e9b60661bb08969b87e340b" +dependencies = [ + "object", +] + [[package]] name = "arrayref" version = "0.3.9" @@ -246,9 +256,9 @@ checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" [[package]] name = "arrow" -version = "55.1.0" +version = "58.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1bb018b6960c87fd9d025009820406f74e83281185a8bdcb44880d2aa5c9a87" +checksum = "602268ce9f569f282cedb9a9f6bac569b680af47b9b077d515900c03c5d190da" dependencies = [ "arrow-arith", "arrow-array", @@ -264,61 +274,64 @@ dependencies = [ "arrow-select", "arrow-string", "half", - "pyo3", - "rand 0.9.1", + "rand 0.9.2", ] [[package]] name = "arrow-arith" -version = "55.1.0" +version = "58.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44de76b51473aa888ecd6ad93ceb262fb8d40d1f1154a4df2f069b3590aa7575" +checksum = "cd53c6bf277dea91f136ae8e3a5d7041b44b5e489e244e637d00ae302051f56f" dependencies = [ "arrow-array", "arrow-buffer", "arrow-data", "arrow-schema", "chrono", - "num", + "num-traits", ] [[package]] name = "arrow-array" -version = "55.1.0" +version = "58.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "29ed77e22744475a9a53d00026cf8e166fe73cf42d89c4c4ae63607ee1cfcc3f" +checksum = "e53796e07a6525edaf7dc28b540d477a934aff14af97967ad1d5550878969b9e" dependencies = [ - "ahash 0.8.12", + "ahash", "arrow-buffer", "arrow-data", "arrow-schema", "chrono", "chrono-tz", "half", - "hashbrown 0.15.3", - "num", + "hashbrown 0.16.1", + "num-complex", + "num-integer", + "num-traits", ] [[package]] name 
= "arrow-buffer" -version = "55.1.0" +version = "58.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b0391c96eb58bf7389171d1e103112d3fc3e5625ca6b372d606f2688f1ea4cce" +checksum = "f2c1a85bb2e94ee10b76531d8bc3ce9b7b4c0d508cabfb17d477f63f2617bd20" dependencies = [ "bytes", "half", - "num", + "num-bigint", + "num-traits", ] [[package]] name = "arrow-cast" -version = "55.1.0" +version = "58.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f39e1d774ece9292697fcbe06b5584401b26bd34be1bec25c33edae65c2420ff" +checksum = "89fb245db6b0e234ed8e15b644edb8664673fefe630575e94e62cd9d489a8a26" dependencies = [ "arrow-array", "arrow-buffer", "arrow-data", + "arrow-ord", "arrow-schema", "arrow-select", "atoi", @@ -327,15 +340,15 @@ dependencies = [ "comfy-table", "half", "lexical-core", - "num", + "num-traits", "ryu", ] [[package]] name = "arrow-csv" -version = "55.1.0" +version = "58.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9055c972a07bf12c2a827debfd34f88d3b93da1941d36e1d9fee85eebe38a12a" +checksum = "d374882fb465a194462527c0c15a93aa19a554cf690a6b77a26b2a02539937a7" dependencies = [ "arrow-array", "arrow-cast", @@ -343,27 +356,27 @@ dependencies = [ "chrono", "csv", "csv-core", - "lazy_static", "regex", ] [[package]] name = "arrow-data" -version = "55.1.0" +version = "58.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf75ac27a08c7f48b88e5c923f267e980f27070147ab74615ad85b5c5f90473d" +checksum = "189d210bc4244c715fa3ed9e6e22864673cccb73d5da28c2723fb2e527329b33" dependencies = [ "arrow-buffer", "arrow-schema", "half", - "num", + "num-integer", + "num-traits", ] [[package]] name = "arrow-flight" -version = "55.1.0" +version = "58.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91efc67a4f5a438833dd76ef674745c80f6f6b9a428a3b440cbfbf74e32867e6" +checksum = 
"b4f5cdf00ee0003ba0768d3575d0afc47d736b29673b14c3c228fdffa9a3fb29" dependencies = [ "arrow-arith", "arrow-array", @@ -384,27 +397,30 @@ dependencies = [ "prost", "prost-types", "tonic", + "tonic-prost", ] [[package]] name = "arrow-ipc" -version = "55.1.0" +version = "58.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a222f0d93772bd058d1268f4c28ea421a603d66f7979479048c429292fac7b2e" +checksum = "7968c2e5210c41f4909b2ef76f6e05e172b99021c2def5edf3cc48fdd39d1d6c" dependencies = [ "arrow-array", "arrow-buffer", "arrow-data", "arrow-schema", + "arrow-select", "flatbuffers", "lz4_flex", + "zstd", ] [[package]] name = "arrow-json" -version = "55.1.0" +version = "58.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9085342bbca0f75e8cb70513c0807cc7351f1fbf5cb98192a67d5e3044acb033" +checksum = "92111dba5bf900f443488e01f00d8c4ddc2f47f5c50039d18120287b580baa22" dependencies = [ "arrow-array", "arrow-buffer", @@ -413,20 +429,22 @@ dependencies = [ "arrow-schema", "chrono", "half", - "indexmap 2.9.0", + "indexmap 2.13.0", + "itoa", "lexical-core", "memchr", - "num", - "serde", + "num-traits", + "ryu", + "serde_core", "serde_json", "simdutf8", ] [[package]] name = "arrow-ord" -version = "55.1.0" +version = "58.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab2f1065a5cad7b9efa9e22ce5747ce826aa3855766755d4904535123ef431e7" +checksum = "211136cb253577ee1a6665f741a13136d4e563f64f5093ffd6fb837af90b9495" dependencies = [ "arrow-array", "arrow-buffer", @@ -437,9 +455,9 @@ dependencies = [ [[package]] name = "arrow-row" -version = "55.1.0" +version = "58.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3703a0e3e92d23c3f756df73d2dc9476873f873a76ae63ef9d3de17fda83b2d8" +checksum = "8e0f20145f9f5ea3fe383e2ba7a7487bf19be36aa9dbf5dd6a1f92f657179663" dependencies = [ "arrow-array", "arrow-buffer", @@ -450,34 +468,35 @@ dependencies = [ [[package]] name = 
"arrow-schema" -version = "55.1.0" +version = "58.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "73a47aa0c771b5381de2b7f16998d351a6f4eb839f1e13d48353e17e873d969b" +checksum = "1b47e0ca91cc438d2c7879fe95e0bca5329fff28649e30a88c6f760b1faeddcb" dependencies = [ - "bitflags 2.9.1", + "bitflags", "serde", + "serde_core", "serde_json", ] [[package]] name = "arrow-select" -version = "55.1.0" +version = "58.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "24b7b85575702b23b85272b01bc1c25a01c9b9852305e5d0078c79ba25d995d4" +checksum = "750a7d1dda177735f5e82a314485b6915c7cccdbb278262ac44090f4aba4a325" dependencies = [ - "ahash 0.8.12", + "ahash", "arrow-array", "arrow-buffer", "arrow-data", "arrow-schema", - "num", + "num-traits", ] [[package]] name = "arrow-string" -version = "55.1.0" +version = "58.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9260fddf1cdf2799ace2b4c2fc0356a9789fa7551e0953e35435536fecefebbd" +checksum = "e1eab1208bc4fe55d768cdc9b9f3d9df5a794cdb3ee2586bf89f9b30dc31ad8c" dependencies = [ "arrow-array", "arrow-buffer", @@ -485,7 +504,7 @@ dependencies = [ "arrow-schema", "arrow-select", "memchr", - "num", + "num-traits", "regex", "regex-syntax", ] @@ -503,36 +522,31 @@ dependencies = [ ] [[package]] -name = "assert_cmd" -version = "2.0.17" +name = "astral-tokio-tar" +version = "0.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2bd389a4b2970a01282ee455294913c0a43724daedcd1a24c3eb0ec1c1320b66" +checksum = "ec179a06c1769b1e42e1e2cbe74c7dcdb3d6383c838454d063eaac5bbb7ebbe5" dependencies = [ - "anstyle", - "bstr", - "doc-comment", + "filetime", + "futures-core", "libc", - "predicates", - "predicates-core", - "predicates-tree", - "wait-timeout", + "portable-atomic", + "rustc-hash", + "tokio", + "tokio-stream", + "xattr", ] [[package]] name = "async-compression" -version = "0.4.19" +version = "0.4.41" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "06575e6a9673580f52661c92107baabffbf41e2141373441cbcdc47cb733003c" +checksum = "d0f9ee0f6e02ffd7ad5816e9464499fba7b3effd01123b515c41d1697c43dad1" dependencies = [ - "bzip2 0.5.2", - "flate2", - "futures-core", - "memchr", + "compression-codecs", + "compression-core", "pin-project-lite", "tokio", - "xz2", - "zstd", - "zstd-safe", ] [[package]] @@ -552,7 +566,7 @@ checksum = "3b43422f69d8ff38f95f1b2bb76517c91589a924d1559a0e935d7c8ce0274c11" dependencies = [ "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.117", ] [[package]] @@ -574,18 +588,18 @@ checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.117", ] [[package]] name = "async-trait" -version = "0.1.88" +version = "0.1.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e539d3fca749fcee5236ab05e93a52867dd549cc157c8cb7f99595f3cedffdb5" +checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" dependencies = [ "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.117", ] [[package]] @@ -605,15 +619,15 @@ checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" [[package]] name = "autocfg" -version = "1.4.0" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" +checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" [[package]] name = "aws-config" -version = "1.6.3" +version = "1.8.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02a18fd934af6ae7ca52410d4548b98eb895aab0f1ea417d168d85db1434a141" +checksum = "11493b0bad143270fb8ad284a096dd529ba91924c5409adeac856cc1bf047dbc" dependencies = [ "aws-credential-types", "aws-runtime", @@ -630,8 +644,8 @@ dependencies = [ "bytes", "fastrand", "hex", - "http 1.3.1", - "ring", + 
"http 1.4.0", + "sha1", "time", "tokio", "tracing", @@ -641,9 +655,9 @@ dependencies = [ [[package]] name = "aws-credential-types" -version = "1.2.3" +version = "1.2.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "687bc16bc431a8533fe0097c7f0182874767f920989d7260950172ae8e3c4465" +checksum = "8f20799b373a1be121fe3005fba0c2090af9411573878f224df44b42727fcaf7" dependencies = [ "aws-smithy-async", "aws-smithy-runtime-api", @@ -653,9 +667,9 @@ dependencies = [ [[package]] name = "aws-lc-rs" -version = "1.13.1" +version = "1.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93fcc8f365936c834db5514fc45aee5b1202d677e6b40e48468aaaa8183ca8c7" +checksum = "94bffc006df10ac2a68c83692d734a465f8ee6c5b384d8545a636f81d858f4bf" dependencies = [ "aws-lc-sys", "zeroize", @@ -663,11 +677,10 @@ dependencies = [ [[package]] name = "aws-lc-sys" -version = "0.29.0" +version = "0.38.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61b1d86e7705efe1be1b569bab41d4fa1e14e220b60a160f78de2db687add079" +checksum = "4321e568ed89bb5a7d291a7f37997c2c0df89809d7b6d12062c81ddb54aa782e" dependencies = [ - "bindgen", "cc", "cmake", "dunce", @@ -676,9 +689,9 @@ dependencies = [ [[package]] name = "aws-runtime" -version = "1.5.7" +version = "1.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c4063282c69991e57faab9e5cb21ae557e59f5b0fb285c196335243df8dc25c" +checksum = "5fc0651c57e384202e47153c1260b84a9936e19803d747615edf199dc3b98d17" dependencies = [ "aws-credential-types", "aws-sigv4", @@ -689,9 +702,10 @@ dependencies = [ "aws-smithy-types", "aws-types", "bytes", + "bytes-utils", "fastrand", - "http 0.2.12", - "http-body 0.4.6", + "http 1.4.0", + "http-body 1.0.1", "percent-encoding", "pin-project-lite", "tracing", @@ -700,15 +714,16 @@ dependencies = [ [[package]] name = "aws-sdk-sso" -version = "1.70.0" +version = "1.96.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "83447efb7179d8e2ad2afb15ceb9c113debbc2ecdf109150e338e2e28b86190b" +checksum = "f64a6eded248c6b453966e915d32aeddb48ea63ad17932682774eb026fbef5b1" dependencies = [ "aws-credential-types", "aws-runtime", "aws-smithy-async", "aws-smithy-http", "aws-smithy-json", + "aws-smithy-observability", "aws-smithy-runtime", "aws-smithy-runtime-api", "aws-smithy-types", @@ -716,21 +731,23 @@ dependencies = [ "bytes", "fastrand", "http 0.2.12", + "http 1.4.0", "regex-lite", "tracing", ] [[package]] name = "aws-sdk-ssooidc" -version = "1.71.0" +version = "1.98.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c5f9bfbbda5e2b9fe330de098f14558ee8b38346408efe9f2e9cee82dc1636a4" +checksum = "db96d720d3c622fcbe08bae1c4b04a72ce6257d8b0584cb5418da00ae20a344f" dependencies = [ "aws-credential-types", "aws-runtime", "aws-smithy-async", "aws-smithy-http", "aws-smithy-json", + "aws-smithy-observability", "aws-smithy-runtime", "aws-smithy-runtime-api", "aws-smithy-types", @@ -738,21 +755,23 @@ dependencies = [ "bytes", "fastrand", "http 0.2.12", + "http 1.4.0", "regex-lite", "tracing", ] [[package]] name = "aws-sdk-sts" -version = "1.71.0" +version = "1.100.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e17b984a66491ec08b4f4097af8911251db79296b3e4a763060b45805746264f" +checksum = "fafbdda43b93f57f699c5dfe8328db590b967b8a820a13ccdd6687355dfcc7ca" dependencies = [ "aws-credential-types", "aws-runtime", "aws-smithy-async", "aws-smithy-http", "aws-smithy-json", + "aws-smithy-observability", "aws-smithy-query", "aws-smithy-runtime", "aws-smithy-runtime-api", @@ -761,15 +780,16 @@ dependencies = [ "aws-types", "fastrand", "http 0.2.12", + "http 1.4.0", "regex-lite", "tracing", ] [[package]] name = "aws-sigv4" -version = "1.3.2" +version = "1.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"3734aecf9ff79aa401a6ca099d076535ab465ff76b46440cf567c8e70b65dc13" +checksum = "b0b660013a6683ab23797778e21f1f854744fdf05f68204b4cca4c8c04b5d1f4" dependencies = [ "aws-credential-types", "aws-smithy-http", @@ -780,7 +800,7 @@ dependencies = [ "hex", "hmac", "http 0.2.12", - "http 1.3.1", + "http 1.4.0", "percent-encoding", "sha2", "time", @@ -789,9 +809,9 @@ dependencies = [ [[package]] name = "aws-smithy-async" -version = "1.2.5" +version = "1.2.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e190749ea56f8c42bf15dd76c65e14f8f765233e6df9b0506d9d934ebef867c" +checksum = "2ffcaf626bdda484571968400c326a244598634dc75fd451325a54ad1a59acfc" dependencies = [ "futures-util", "pin-project-lite", @@ -800,18 +820,19 @@ dependencies = [ [[package]] name = "aws-smithy-http" -version = "0.62.1" +version = "0.63.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "99335bec6cdc50a346fda1437f9fefe33abf8c99060739a546a16457f2862ca9" +checksum = "ba1ab2dc1c2c3749ead27180d333c42f11be8b0e934058fb4b2258ee8dbe5231" dependencies = [ "aws-smithy-runtime-api", "aws-smithy-types", "bytes", "bytes-utils", "futures-core", - "http 0.2.12", - "http 1.3.1", - "http-body 0.4.6", + "futures-util", + "http 1.4.0", + "http-body 1.0.1", + "http-body-util", "percent-encoding", "pin-project-lite", "pin-utils", @@ -820,15 +841,15 @@ dependencies = [ [[package]] name = "aws-smithy-http-client" -version = "1.0.2" +version = "1.1.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e44697a9bded898dcd0b1cb997430d949b87f4f8940d91023ae9062bf218250" +checksum = "6a2f165a7feee6f263028b899d0a181987f4fa7179a6411a32a439fba7c5f769" dependencies = [ "aws-smithy-async", "aws-smithy-runtime-api", "aws-smithy-types", "h2", - "http 1.3.1", + "http 1.4.0", "hyper", "hyper-rustls", "hyper-util", @@ -837,33 +858,34 @@ dependencies = [ "rustls-native-certs", "rustls-pki-types", "tokio", - "tower 0.5.2", + "tokio-rustls", + "tower", 
"tracing", ] [[package]] name = "aws-smithy-json" -version = "0.61.3" +version = "0.62.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92144e45819cae7dc62af23eac5a038a58aa544432d2102609654376a900bd07" +checksum = "9648b0bb82a2eedd844052c6ad2a1a822d1f8e3adee5fbf668366717e428856a" dependencies = [ "aws-smithy-types", ] [[package]] name = "aws-smithy-observability" -version = "0.1.3" +version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9364d5989ac4dd918e5cc4c4bdcc61c9be17dcd2586ea7f69e348fc7c6cab393" +checksum = "a06c2315d173edbf1920da8ba3a7189695827002e4c0fc961973ab1c54abca9c" dependencies = [ "aws-smithy-runtime-api", ] [[package]] name = "aws-smithy-query" -version = "0.60.7" +version = "0.60.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2fbd61ceb3fe8a1cb7352e42689cec5335833cd9f94103a61e98f9bb61c64bb" +checksum = "1a56d79744fb3edb5d722ef79d86081e121d3b9422cb209eb03aea6aa4f21ebd" dependencies = [ "aws-smithy-types", "urlencoding", @@ -871,9 +893,9 @@ dependencies = [ [[package]] name = "aws-smithy-runtime" -version = "1.8.3" +version = "1.10.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "14302f06d1d5b7d333fd819943075b13d27c7700b414f574c3c35859bfb55d5e" +checksum = "028999056d2d2fd58a697232f9eec4a643cf73a71cf327690a7edad1d2af2110" dependencies = [ "aws-smithy-async", "aws-smithy-http", @@ -884,9 +906,10 @@ dependencies = [ "bytes", "fastrand", "http 0.2.12", - "http 1.3.1", + "http 1.4.0", "http-body 0.4.6", "http-body 1.0.1", + "http-body-util", "pin-project-lite", "pin-utils", "tokio", @@ -895,15 +918,15 @@ dependencies = [ [[package]] name = "aws-smithy-runtime-api" -version = "1.8.0" +version = "1.11.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1e5d9e3a80a18afa109391fb5ad09c3daf887b516c6fd805a157c6ea7994a57" +checksum = 
"876ab3c9c29791ba4ba02b780a3049e21ec63dabda09268b175272c3733a79e6" dependencies = [ "aws-smithy-async", "aws-smithy-types", "bytes", "http 0.2.12", - "http 1.3.1", + "http 1.4.0", "pin-project-lite", "tokio", "tracing", @@ -912,15 +935,15 @@ dependencies = [ [[package]] name = "aws-smithy-types" -version = "1.3.1" +version = "1.4.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "40076bd09fadbc12d5e026ae080d0930defa606856186e31d83ccc6a255eeaf3" +checksum = "9d73dbfbaa8e4bc57b9045137680b958d274823509a360abfd8e1d514d40c95c" dependencies = [ "base64-simd", "bytes", "bytes-utils", "http 0.2.12", - "http 1.3.1", + "http 1.4.0", "http-body 0.4.6", "http-body 1.0.1", "http-body-util", @@ -935,18 +958,18 @@ dependencies = [ [[package]] name = "aws-smithy-xml" -version = "0.60.9" +version = "0.60.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab0b0166827aa700d3dc519f72f8b3a91c35d0b8d042dc5d643a91e6f80648fc" +checksum = "0ce02add1aa3677d022f8adf81dcbe3046a95f17a1b1e8979c145cd21d3d22b3" dependencies = [ "xmlparser", ] [[package]] name = "aws-types" -version = "1.3.7" +version = "1.3.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a322fec39e4df22777ed3ad8ea868ac2f94cd15e1a55f6ee8d8d6305057689a" +checksum = "47c8323699dd9b3c8d5b3c13051ae9cdef58fd179957c882f8374dd8725962d9" dependencies = [ "aws-credential-types", "aws-smithy-async", @@ -958,15 +981,14 @@ dependencies = [ [[package]] name = "axum" -version = "0.7.9" +version = "0.8.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f" +checksum = "8b52af3cb4058c895d37317bb27508dccc8e5f2d39454016b297bf4a400597b8" dependencies = [ - "async-trait", "axum-core", "bytes", "futures-util", - "http 1.3.1", + "http 1.4.0", "http-body 1.0.1", "http-body-util", "itoa", @@ -975,49 +997,31 @@ dependencies = [ "mime", "percent-encoding", 
"pin-project-lite", - "rustversion", - "serde", + "serde_core", "sync_wrapper", - "tower 0.5.2", + "tower", "tower-layer", "tower-service", ] [[package]] name = "axum-core" -version = "0.4.5" +version = "0.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09f2bd6146b97ae3359fa0cc6d6b376d9539582c7b4220f041a33ec24c226199" +checksum = "08c78f31d7b1291f7ee735c1c6780ccde7785daae9a9206026862dab7d8792d1" dependencies = [ - "async-trait", "bytes", - "futures-util", - "http 1.3.1", + "futures-core", + "http 1.4.0", "http-body 1.0.1", "http-body-util", "mime", "pin-project-lite", - "rustversion", "sync_wrapper", "tower-layer", "tower-service", ] -[[package]] -name = "backtrace" -version = "0.3.75" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6806a6321ec58106fea15becdad98371e28d92ccbc7c8f1b3b6dd724fe8f1002" -dependencies = [ - "addr2line", - "cfg-if", - "libc", - "miniz_oxide", - "object", - "rustc-demangle", - "windows-targets 0.52.6", -] - [[package]] name = "base64" version = "0.21.7" @@ -1042,9 +1046,9 @@ dependencies = [ [[package]] name = "bigdecimal" -version = "0.4.8" +version = "0.4.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a22f228ab7a1b23027ccc6c350b72868017af7ea8356fbdf19f8d991c690013" +checksum = "4d6867f1565b3aad85681f1015055b087fcfd840d6aeee6eee7f2da317603695" dependencies = [ "autocfg", "libm", @@ -1054,52 +1058,11 @@ dependencies = [ "serde", ] -[[package]] -name = "bindgen" -version = "0.69.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088" -dependencies = [ - "bitflags 2.9.1", - "cexpr", - "clang-sys", - "itertools 0.12.1", - "lazy_static", - "lazycell", - "log", - "prettyplease", - "proc-macro2", - "quote", - "regex", - "rustc-hash 1.1.0", - "shlex", - "syn 2.0.101", - "which", -] - -[[package]] -name = "bitflags" -version = "1.3.2" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" - [[package]] name = "bitflags" -version = "2.9.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b8e56985ec62d17e9c1001dc89c88ecd7dc08e47eba5ec7c29c7b5eeecde967" - -[[package]] -name = "bitvec" -version = "1.0.1" +version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c" -dependencies = [ - "funty", - "radium", - "tap", - "wyz", -] +checksum = "843867be96c8daad0d758b57df9392b6d8d271134fce549de6ce169ff98a92af" [[package]] name = "blake2" @@ -1112,15 +1075,16 @@ dependencies = [ [[package]] name = "blake3" -version = "1.8.2" +version = "1.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3888aaa89e4b2a40fca9848e400f6a658a5a3978de7be858e209cafa8be9a4a0" +checksum = "2468ef7d57b3fb7e16b576e8377cdbde2320c60e1491e961d11da40fc4f02a2d" dependencies = [ "arrayref", "arrayvec", "cc", "cfg-if", "constant_time_eq", + "cpufeatures", ] [[package]] @@ -1134,18 +1098,21 @@ dependencies = [ [[package]] name = "bollard" -version = "0.18.1" +version = "0.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97ccca1260af6a459d75994ad5acc1651bcabcbdbc41467cc9786519ab854c30" +checksum = "ee04c4c84f1f811b017f2fbb7dd8815c976e7ca98593de9c1e2afad0f636bff4" dependencies = [ + "async-stream", "base64 0.22.1", + "bitflags", + "bollard-buildkit-proto", "bollard-stubs", "bytes", "futures-core", "futures-util", "hex", "home", - "http 1.3.1", + "http 1.4.0", "http-body-util", "hyper", "hyper-named-pipe", @@ -1153,63 +1120,86 @@ dependencies = [ "hyper-util", "hyperlocal", "log", + "num", "pin-project-lite", + "rand 0.9.2", "rustls", "rustls-native-certs", - "rustls-pemfile", "rustls-pki-types", "serde", "serde_derive", "serde_json", - "serde_repr", 
"serde_urlencoded", - "thiserror 2.0.12", + "thiserror", + "time", "tokio", + "tokio-stream", "tokio-util", + "tonic", "tower-service", "url", "winapi", ] +[[package]] +name = "bollard-buildkit-proto" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85a885520bf6249ab931a764ffdb87b0ceef48e6e7d807cfdb21b751e086e1ad" +dependencies = [ + "prost", + "prost-types", + "tonic", + "tonic-prost", + "ureq", +] + [[package]] name = "bollard-stubs" -version = "1.47.1-rc.27.3.1" +version = "1.52.1-rc.29.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f179cfbddb6e77a5472703d4b30436bff32929c0aa8a9008ecf23d1d3cdd0da" +checksum = "0f0a8ca8799131c1837d1282c3f81f31e76ceb0ce426e04a7fe1ccee3287c066" dependencies = [ + "base64 0.22.1", + "bollard-buildkit-proto", + "bytes", + "prost", "serde", + "serde_json", "serde_repr", - "serde_with", + "time", ] [[package]] -name = "borsh" -version = "1.5.7" +name = "bon" +version = "3.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad8646f98db542e39fc66e68a20b2144f6a732636df7c2354e74645faaa433ce" +checksum = "f47dbe92550676ee653353c310dfb9cf6ba17ee70396e1f7cf0a2020ad49b2fe" dependencies = [ - "borsh-derive", - "cfg_aliases", + "bon-macros", + "rustversion", ] [[package]] -name = "borsh-derive" -version = "1.5.7" +name = "bon-macros" +version = "3.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fdd1d3c0c2f5833f22386f252fe8ed005c7f59fdcddeef025c01b4c3b9fd9ac3" +checksum = "519bd3116aeeb42d5372c29d982d16d0170d3d4a5ed85fc7dd91642ffff3c67c" dependencies = [ - "once_cell", - "proc-macro-crate", + "darling", + "ident_case", + "prettyplease", "proc-macro2", "quote", - "syn 2.0.101", + "rustversion", + "syn 2.0.117", ] [[package]] name = "brotli" -version = "8.0.1" +version = "8.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"9991eea70ea4f293524138648e41ee89b0b2b12ddef3b255effa43c8056e0e0d" +checksum = "4bd8b9603c7aa97359dbd97ecf258968c95f3adddd6db2f7e7a5bef101c84560" dependencies = [ "alloc-no-stdlib", "alloc-stdlib", @@ -1228,42 +1218,19 @@ dependencies = [ [[package]] name = "bstr" -version = "1.12.0" +version = "1.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "234113d19d0d7d613b40e86fb654acf958910802bcceab913a4f9e7cda03b1a4" +checksum = "63044e1ae8e69f3b5a92c736ca6269b8d12fa7efe39bf34ddb06d102cf0e2cab" dependencies = [ "memchr", - "regex-automata", "serde", ] [[package]] name = "bumpalo" -version = "3.17.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1628fb46dfa0b37568d12e5edd512553eccf6a22a78e8bde00bb4aed84d5bdbf" - -[[package]] -name = "bytecheck" -version = "0.6.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23cdc57ce23ac53c931e88a43d06d070a6fd142f2617be5855eb75efc9beb1c2" -dependencies = [ - "bytecheck_derive", - "ptr_meta", - "simdutf8", -] - -[[package]] -name = "bytecheck_derive" -version = "0.6.12" +version = "3.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3db406d29fbcd95542e92559bed4d8ad92636d1ca8b3b72ede10b4bcc010e659" -dependencies = [ - "proc-macro2", - "quote", - "syn 1.0.109", -] +checksum = "5d20789868f4b01b2f2caec9f5c4e0213b41e3e5702a50157d699ae31ced2fcb" [[package]] name = "byteorder" @@ -1273,9 +1240,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.10.1" +version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" +checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" [[package]] name = "bytes-utils" @@ -1289,31 +1256,11 @@ dependencies = [ [[package]] name = "bzip2" -version = "0.4.4" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "bdb116a6ef3f6c3698828873ad02c3014b3c85cadb88496095628e3ef1e347f8" -dependencies = [ - "bzip2-sys", - "libc", -] - -[[package]] -name = "bzip2" -version = "0.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49ecfb22d906f800d4fe833b6282cf4dc1c298f5057ca0b5445e5c209735ca47" -dependencies = [ - "bzip2-sys", -] - -[[package]] -name = "bzip2-sys" -version = "0.1.13+1.0.8" +version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "225bff33b2141874fe80d71e07d6eec4f85c5c216453dd96388240f96e1acc14" +checksum = "f3a53fac24f34a81bc9954b5d6cfce0c21e18ec6959f44f56e8e90e4bb7c346c" dependencies = [ - "cc", - "pkg-config", + "libbz2-rs-sys", ] [[package]] @@ -1324,29 +1271,21 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.2.23" +version = "1.2.57" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f4ac86a9e5bc1e2b3449ab9d7d3a6a405e3d1bb28d7b9be8614f55846ae3766" +checksum = "7a0dd1ca384932ff3641c8718a02769f1698e7563dc6974ffd03346116310423" dependencies = [ + "find-msvc-tools", "jobserver", "libc", "shlex", ] -[[package]] -name = "cexpr" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766" -dependencies = [ - "nom", -] - [[package]] name = "cfg-if" -version = "1.0.0" +version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" [[package]] name = "cfg_aliases" @@ -1356,11 +1295,10 @@ checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" [[package]] name = "chrono" -version = "0.4.41" +version = "0.4.44" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "c469d952047f47f91b68d1cba3f10d63c11d73e4636f24f08daf0278abf01c4d" +checksum = "c673075a2e0e5f4a1dde27ce9dee1ea4558c7ffe648f576438a20ca1d2acc4b0" dependencies = [ - "android-tzdata", "iana-time-zone", "js-sys", "num-traits", @@ -1371,23 +1309,12 @@ dependencies = [ [[package]] name = "chrono-tz" -version = "0.10.3" +version = "0.10.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "efdce149c370f133a071ca8ef6ea340b7b88748ab0810097a9e2976eaa34b4f3" +checksum = "a6139a8597ed92cf816dfb33f5dd6cf0bb93a6adc938f11039f371bc5bcd26c3" dependencies = [ "chrono", - "chrono-tz-build", - "phf", -] - -[[package]] -name = "chrono-tz-build" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f10f8c9340e31fc120ff885fcdb54a0b48e474bbd77cab557f0c30a3e569402" -dependencies = [ - "parse-zoneinfo", - "phf_codegen", + "phf 0.12.1", ] [[package]] @@ -1417,33 +1344,11 @@ dependencies = [ "half", ] -[[package]] -name = "clang-sys" -version = "1.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4" -dependencies = [ - "glob", - "libc", - "libloading 0.8.7", -] - -[[package]] -name = "clap" -version = "2.34.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a0610544180c38b88101fecf2dd634b174a62eef6946f84dfc6a7127512b381c" -dependencies = [ - "bitflags 1.3.2", - "textwrap", - "unicode-width 0.1.14", -] - [[package]] name = "clap" -version = "4.5.39" +version = "4.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd60e63e9be68e5fb56422e397cf9baddded06dae1d2e523401542383bc72a9f" +checksum = "b193af5b67834b676abd72466a96c1024e6a6ad978a1f484bd90b85c94041351" dependencies = [ "clap_builder", "clap_derive", @@ -1451,11 +1356,11 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.39" 
+version = "4.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89cc6392a1f72bbeb820d71f32108f61fdaf18bc526e1d23954168a67759ef51" +checksum = "714a53001bf66416adb0e2ef5ac857140e7dc3a0c48fb28b2f10762fc4b5069f" dependencies = [ - "anstream", + "anstream 1.0.0", "anstyle", "clap_lex", "strsim", @@ -1463,56 +1368,77 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.5.32" +version = "4.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09176aae279615badda0765c0c0b3f6ed53f4709118af73cf4655d85d1530cd7" +checksum = "1110bd8a634a1ab8cb04345d8d878267d57c3cf1b38d91b71af6686408bbca6a" dependencies = [ - "heck 0.5.0", + "heck", "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.117", ] [[package]] name = "clap_lex" -version = "0.7.4" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f46ad14479a25103f283c0f10005961cf086d8dc42205bb44c46ac563475dca6" +checksum = "c8d4a3bb8b1e0c1050499d1815f5ab16d04f0959b233085fb31653fbfc9d98f9" [[package]] name = "clipboard-win" -version = "5.4.0" +version = "5.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "15efe7a882b08f34e38556b14f2fb3daa98769d06c7f0c1b076dfd0d983bc892" +checksum = "bde03770d3df201d4fb868f2c9c59e66a3e4e2bd06692a0fe701e7103c7e84d4" dependencies = [ "error-code", ] [[package]] name = "cmake" -version = "0.1.54" +version = "0.1.57" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7caa3f9de89ddbe2c607f4101924c5abec803763ae9534e4f4d7d8f84aa81f0" +checksum = "75443c44cd6b379beb8c5b45d85d0773baf31cce901fe7bb252f4eff3008ef7d" dependencies = [ "cc", ] [[package]] name = "colorchoice" -version = "1.0.3" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990" +checksum = "1d07550c9036bf2ae0c684c4297d503f838287c83c53686d05370d0e139ae570" 
[[package]] name = "comfy-table" -version = "7.1.4" +version = "7.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4a65ebfec4fb190b6f90e944a817d60499ee0744e582530e2c9900a22e591d9a" +checksum = "958c5d6ecf1f214b4c2bbbbf6ab9523a864bd136dcf71a7e8904799acfe1ad47" dependencies = [ "unicode-segmentation", - "unicode-width 0.2.0", + "unicode-width 0.2.2", +] + +[[package]] +name = "compression-codecs" +version = "0.4.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb7b51a7d9c967fc26773061ba86150f19c50c0d65c887cb1fbe295fd16619b7" +dependencies = [ + "bzip2", + "compression-core", + "flate2", + "liblzma", + "memchr", + "zstd", + "zstd-safe", ] +[[package]] +name = "compression-core" +version = "0.4.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75984efb6ed102a0d42db99afb6c1948f0380d1d91808d5529916e6c08b49d8d" + [[package]] name = "console" version = "0.15.11" @@ -1522,10 +1448,21 @@ dependencies = [ "encode_unicode", "libc", "once_cell", - "unicode-width 0.2.0", "windows-sys 0.59.0", ] +[[package]] +name = "console" +version = "0.16.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d64e8af5551369d19cf50138de61f1c42074ab970f74e99be916646777f8fc87" +dependencies = [ + "encode_unicode", + "libc", + "unicode-width 0.2.2", + "windows-sys 0.61.2", +] + [[package]] name = "console_error_panic_hook" version = "0.1.7" @@ -1551,28 +1488,31 @@ version = "0.1.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f9d839f2a20b0aee515dc581a6172f2321f96cab76c1a38a4c584a194955390e" dependencies = [ - "getrandom 0.2.16", + "getrandom 0.2.17", "once_cell", "tiny-keccak", ] [[package]] name = "const_panic" -version = "0.2.12" +version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2459fc9262a1aa204eb4b5764ad4f189caec88aea9634389c0a25f8be7f6265e" +checksum = 
"e262cdaac42494e3ae34c43969f9cdeb7da178bdb4b66fa6a1ea2edb4c8ae652" +dependencies = [ + "typewit", +] [[package]] name = "constant_time_eq" -version = "0.3.1" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c74b8349d32d297c9134b8c88677813a227df8f779daa29bfc29c183fe3dca6" +checksum = "3d52eff69cd5e647efe296129160853a42795992097e8af39800e1060caeea9b" [[package]] name = "core-foundation" -version = "0.10.0" +version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b55271e5c8c478ad3f38ad24ef34923091e0548492a266d19b3c0b4d82574c63" +checksum = "b2a6cd9ae233e7f62ba4e9353e81a88df7fc8a5987b8d445b4d90c879bd156f6" dependencies = [ "core-foundation-sys", "libc", @@ -1584,29 +1524,20 @@ version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" -[[package]] -name = "core2" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b49ba7ef1ad6107f8824dbe97de947cbaac53c44e7f9756a1fba0d37c1eec505" -dependencies = [ - "memchr", -] - [[package]] name = "core_extensions" -version = "1.5.3" +version = "1.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92c71dc07c9721607e7a16108336048ee978c3a8b129294534272e8bac96c0ee" +checksum = "42bb5e5d0269fd4f739ea6cedaf29c16d81c27a7ce7582008e90eb50dcd57003" dependencies = [ "core_extensions_proc_macros", ] [[package]] name = "core_extensions_proc_macros" -version = "1.5.3" +version = "1.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69f3b219d28b6e3b4ac87bc1fc522e0803ab22e055da177bff0068c4150c61a6" +checksum = "533d38ecd2709b7608fb8e18e4504deb99e9a72879e6aa66373a76d8dc4259ea" [[package]] name = "cpufeatures" @@ -1619,35 +1550,34 @@ dependencies = [ [[package]] name = "crc32fast" -version = "1.4.2" +version = "1.5.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "a97769d94ddab943e4510d138150169a2758b5ef3eb191a9ee688de3e23ef7b3" +checksum = "9481c1c90cbf2ac953f07c8d4a58aa3945c425b7185c9154d67a65e4230da511" dependencies = [ "cfg-if", ] [[package]] name = "criterion" -version = "0.5.1" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f" +checksum = "950046b2aa2492f9a536f5f4f9a3de7b9e2476e575e05bd6c333371add4d98f3" dependencies = [ + "alloca", "anes", "cast", "ciborium", - "clap 4.5.39", + "clap", "criterion-plot", "futures", - "is-terminal", - "itertools 0.10.5", + "itertools 0.13.0", "num-traits", - "once_cell", "oorandom", + "page_size", "plotters", "rayon", "regex", "serde", - "serde_derive", "serde_json", "tinytemplate", "tokio", @@ -1656,12 +1586,12 @@ dependencies = [ [[package]] name = "criterion-plot" -version = "0.5.0" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" +checksum = "d8d80a2f4f5b554395e47b5d8305bc3d27813bacb73493eb1001e8f76dae29ea" dependencies = [ "cast", - "itertools 0.10.5", + "itertools 0.13.0", ] [[package]] @@ -1692,6 +1622,15 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "crossbeam-queue" +version = "0.3.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f58bbc28f91df819d0aa2a2c00cd19754769c2fad90579b3592b1c9ba7a3115" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "crossbeam-utils" version = "0.8.21" @@ -1700,15 +1639,15 @@ checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" [[package]] name = "crunchy" -version = "0.2.3" +version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43da5946c66ffcc7745f48db692ffbb10a83bfe0afd96235c5c2a4fb23994929" +checksum = 
"460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5" [[package]] name = "crypto-common" -version = "0.1.6" +version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" +checksum = "78c8292055d1c1df0cce5d180393dc8cce0abec0a7102adb6c7b1eef6016d60a" dependencies = [ "generic-array", "typenum", @@ -1716,30 +1655,30 @@ dependencies = [ [[package]] name = "csv" -version = "1.3.1" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "acdc4883a9c96732e4733212c01447ebd805833b7275a73ca3ee080fd77afdaf" +checksum = "52cd9d68cf7efc6ddfaaee42e7288d3a99d613d4b50f76ce9827ae0c6e14f938" dependencies = [ "csv-core", "itoa", "ryu", - "serde", + "serde_core", ] [[package]] name = "csv-core" -version = "0.1.12" +version = "0.1.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d02f3b0da4c6504f86e9cd789d8dbafab48c2321be74e9987593de5a894d93d" +checksum = "704a3c26996a80471189265814dbc2c257598b96b8a7feae2d31ace646bb9782" dependencies = [ "memchr", ] [[package]] name = "ctor" -version = "0.4.2" +version = "0.6.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4735f265ba6a1188052ca32d461028a7d1125868be18e287e756019da7607b5" +checksum = "424e0138278faeb2b401f174ad17e715c829512d74f3d1e81eb43365c2e0590e" dependencies = [ "ctor-proc-macro", "dtor", @@ -1747,15 +1686,21 @@ dependencies = [ [[package]] name = "ctor-proc-macro" -version = "0.0.5" +version = "0.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52560adf09603e58c9a7ee1fe1dcb95a16927b17c127f0ac02d6e768a0e25bc1" + +[[package]] +name = "cty" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f211af61d8efdd104f96e57adf5e426ba1bc3ed7a4ead616e15e5881fd79c4d" +checksum = "b365fabc795046672053e29c954733ec3b05e4be654ab130fe8f1f94d7051f35" 
[[package]] name = "darling" -version = "0.20.11" +version = "0.23.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc7f46116c46ff9ab3eb1597a45688b6715c6e628b5c133e288e709a29bcb4ee" +checksum = "25ae13da2f202d56bd7f91c25fba009e7717a1e4a1cc98a76d844b65ae912e9d" dependencies = [ "darling_core", "darling_macro", @@ -1763,35 +1708,28 @@ dependencies = [ [[package]] name = "darling_core" -version = "0.20.11" +version = "0.23.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d00b9596d185e565c2207a0b01f8bd1a135483d02d9b7b0a54b11da8d53412e" +checksum = "9865a50f7c335f53564bb694ef660825eb8610e0a53d3e11bf1b0d3df31e03b0" dependencies = [ - "fnv", "ident_case", "proc-macro2", "quote", "strsim", - "syn 2.0.101", + "syn 2.0.117", ] [[package]] name = "darling_macro" -version = "0.20.11" +version = "0.23.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead" +checksum = "ac3984ec7bd6cfa798e62b4a642426a5be0e68f9401cfc2a01e3fa9ea2fcdb8d" dependencies = [ "darling_core", "quote", - "syn 2.0.101", + "syn 2.0.117", ] -[[package]] -name = "dary_heap" -version = "0.3.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04d2cd9c18b9f454ed67da600630b021a8a80bf33f8c95896ab33aaf1c26b728" - [[package]] name = "dashmap" version = "6.1.0" @@ -1808,14 +1746,13 @@ dependencies = [ [[package]] name = "datafusion" -version = "47.0.0" +version = "52.3.0" dependencies = [ "arrow", - "arrow-ipc", "arrow-schema", "async-trait", "bytes", - "bzip2 0.5.2", + "bzip2", "chrono", "criterion", "ctor", @@ -1825,6 +1762,7 @@ dependencies = [ "datafusion-common", "datafusion-common-runtime", "datafusion-datasource", + "datafusion-datasource-arrow", "datafusion-datasource-avro", "datafusion-datasource-csv", "datafusion-datasource-json", @@ -1842,6 +1780,7 @@ dependencies = [ "datafusion-macros", "datafusion-optimizer", 
"datafusion-physical-expr", + "datafusion-physical-expr-adapter", "datafusion-physical-expr-common", "datafusion-physical-optimizer", "datafusion-physical-plan", @@ -1851,16 +1790,19 @@ dependencies = [ "env_logger", "flate2", "futures", + "glob", "insta", "itertools 0.14.0", + "liblzma", "log", - "nix", + "nix 0.31.2", "object_store", "parking_lot", "parquet", - "paste", - "rand 0.9.1", + "pretty_assertions", + "rand 0.9.2", "rand_distr", + "recursive", "regex", "rstest", "serde", @@ -1872,37 +1814,39 @@ dependencies = [ "tokio", "url", "uuid", - "xz2", "zstd", ] [[package]] name = "datafusion-benchmarks" -version = "47.0.0" +version = "52.3.0" dependencies = [ "arrow", + "async-trait", + "bytes", + "clap", "datafusion", "datafusion-common", "datafusion-proto", "env_logger", "futures", + "libmimalloc-sys", "log", "mimalloc", "object_store", "parquet", - "rand 0.9.1", + "rand 0.9.2", + "regex", "serde", "serde_json", "snmalloc-rs", - "structopt", - "test-utils", "tokio", "tokio-util", ] [[package]] name = "datafusion-catalog" -version = "47.0.0" +version = "52.3.0" dependencies = [ "arrow", "async-trait", @@ -1915,7 +1859,6 @@ dependencies = [ "datafusion-physical-expr", "datafusion-physical-plan", "datafusion-session", - "datafusion-sql", "futures", "itertools 0.14.0", "log", @@ -1926,75 +1869,78 @@ dependencies = [ [[package]] name = "datafusion-catalog-listing" -version = "47.0.0" +version = "52.3.0" dependencies = [ "arrow", "async-trait", "datafusion-catalog", "datafusion-common", "datafusion-datasource", + "datafusion-datasource-parquet", "datafusion-execution", "datafusion-expr", "datafusion-physical-expr", + "datafusion-physical-expr-adapter", "datafusion-physical-expr-common", "datafusion-physical-plan", - "datafusion-session", "futures", + "itertools 0.14.0", "log", "object_store", - "tokio", ] [[package]] name = "datafusion-cli" -version = "47.0.0" +version = "52.3.0" dependencies = [ "arrow", - "assert_cmd", "async-trait", "aws-config", 
"aws-credential-types", - "clap 4.5.39", + "chrono", + "clap", "ctor", "datafusion", + "datafusion-common", "dirs", "env_logger", "futures", "insta", "insta-cmd", + "log", "mimalloc", "object_store", "parking_lot", "parquet", - "predicates", "regex", "rstest", "rustyline", + "testcontainers-modules", "tokio", "url", ] [[package]] name = "datafusion-common" -version = "47.0.0" +version = "52.3.0" dependencies = [ - "ahash 0.8.12", "apache-avro", "arrow", "arrow-ipc", - "base64 0.22.1", "chrono", + "criterion", + "foldhash 0.2.0", "half", - "hashbrown 0.14.5", - "indexmap 2.9.0", + "hashbrown 0.16.1", + "hex", + "indexmap 2.13.0", "insta", + "itertools 0.14.0", "libc", "log", "object_store", "parquet", - "paste", - "pyo3", - "rand 0.9.1", + "rand 0.9.2", "recursive", "sqlparser", "tokio", @@ -2003,7 +1949,7 @@ dependencies = [ [[package]] name = "datafusion-common-runtime" -version = "47.0.0" +version = "52.3.0" dependencies = [ "futures", "log", @@ -2012,78 +1958,95 @@ dependencies = [ [[package]] name = "datafusion-datasource" -version = "47.0.0" +version = "52.3.0" dependencies = [ "arrow", "async-compression", "async-trait", "bytes", - "bzip2 0.5.2", + "bzip2", "chrono", "criterion", + "crossbeam-queue", "datafusion-common", "datafusion-common-runtime", "datafusion-execution", "datafusion-expr", "datafusion-physical-expr", + "datafusion-physical-expr-adapter", "datafusion-physical-expr-common", "datafusion-physical-plan", "datafusion-session", "flate2", "futures", "glob", + "insta", "itertools 0.14.0", + "liblzma", "log", "object_store", - "parquet", - "rand 0.9.1", + "rand 0.9.2", "tempfile", "tokio", "tokio-util", "url", - "xz2", "zstd", ] [[package]] -name = "datafusion-datasource-avro" -version = "47.0.0" +name = "datafusion-datasource-arrow" +version = "52.3.0" dependencies = [ - "apache-avro", "arrow", + "arrow-ipc", "async-trait", "bytes", "chrono", - "datafusion-catalog", "datafusion-common", + "datafusion-common-runtime", "datafusion-datasource", 
"datafusion-execution", - "datafusion-physical-expr", + "datafusion-expr", + "datafusion-physical-expr-common", + "datafusion-physical-plan", + "datafusion-session", + "futures", + "itertools 0.14.0", + "object_store", + "tokio", +] + +[[package]] +name = "datafusion-datasource-avro" +version = "52.3.0" +dependencies = [ + "apache-avro", + "arrow", + "async-trait", + "bytes", + "datafusion-common", + "datafusion-datasource", "datafusion-physical-expr-common", "datafusion-physical-plan", "datafusion-session", "futures", "num-traits", "object_store", - "rstest", "serde_json", - "tokio", ] [[package]] name = "datafusion-datasource-csv" -version = "47.0.0" +version = "52.3.0" dependencies = [ "arrow", "async-trait", "bytes", - "datafusion-catalog", "datafusion-common", "datafusion-common-runtime", "datafusion-datasource", "datafusion-execution", "datafusion-expr", - "datafusion-physical-expr", "datafusion-physical-expr-common", "datafusion-physical-plan", "datafusion-session", @@ -2095,18 +2058,16 @@ dependencies = [ [[package]] name = "datafusion-datasource-json" -version = "47.0.0" +version = "52.3.0" dependencies = [ "arrow", "async-trait", "bytes", - "datafusion-catalog", "datafusion-common", "datafusion-common-runtime", "datafusion-datasource", "datafusion-execution", "datafusion-expr", - "datafusion-physical-expr", "datafusion-physical-expr-common", "datafusion-physical-plan", "datafusion-session", @@ -2114,27 +2075,31 @@ dependencies = [ "object_store", "serde_json", "tokio", + "tokio-stream", ] [[package]] name = "datafusion-datasource-parquet" -version = "47.0.0" +version = "52.3.0" dependencies = [ "arrow", "async-trait", "bytes", "chrono", - "datafusion-catalog", + "criterion", "datafusion-common", "datafusion-common-runtime", "datafusion-datasource", "datafusion-execution", "datafusion-expr", - "datafusion-functions-aggregate", + "datafusion-functions", + "datafusion-functions-aggregate-common", + "datafusion-functions-nested", "datafusion-physical-expr", + 
"datafusion-physical-expr-adapter", "datafusion-physical-expr-common", - "datafusion-physical-optimizer", "datafusion-physical-plan", + "datafusion-pruning", "datafusion-session", "futures", "itertools 0.14.0", @@ -2142,34 +2107,45 @@ dependencies = [ "object_store", "parking_lot", "parquet", - "rand 0.9.1", + "tempfile", "tokio", ] [[package]] name = "datafusion-doc" -version = "47.0.0" +version = "52.3.0" [[package]] name = "datafusion-examples" -version = "47.0.0" +version = "52.3.0" dependencies = [ "arrow", "arrow-flight", "arrow-schema", "async-trait", + "base64 0.22.1", "bytes", "dashmap", "datafusion", - "datafusion-ffi", + "datafusion-common", + "datafusion-expr", + "datafusion-physical-expr-adapter", "datafusion-proto", + "datafusion-sql", "env_logger", "futures", + "insta", "log", "mimalloc", - "nix", + "nix 0.31.2", + "nom", "object_store", "prost", + "rand 0.9.2", + "serde", + "serde_json", + "strum 0.28.0", + "strum_macros 0.28.0", "tempfile", "test-utils", "tokio", @@ -2182,28 +2158,33 @@ dependencies = [ [[package]] name = "datafusion-execution" -version = "47.0.0" +version = "52.3.0" dependencies = [ "arrow", + "arrow-buffer", + "async-trait", "chrono", "dashmap", "datafusion-common", "datafusion-expr", + "datafusion-physical-expr-common", "futures", "insta", "log", "object_store", "parking_lot", - "rand 0.9.1", + "parquet", + "rand 0.9.2", "tempfile", "url", ] [[package]] name = "datafusion-expr" -version = "47.0.0" +version = "52.3.0" dependencies = [ "arrow", + "async-trait", "chrono", "ctor", "datafusion-common", @@ -2213,9 +2194,9 @@ dependencies = [ "datafusion-functions-window-common", "datafusion-physical-expr-common", "env_logger", - "indexmap 2.9.0", + "indexmap 2.13.0", "insta", - "paste", + "itertools 0.14.0", "recursive", "serde_json", "sqlparser", @@ -2223,18 +2204,18 @@ dependencies = [ [[package]] name = "datafusion-expr-common" -version = "47.0.0" +version = "52.3.0" dependencies = [ "arrow", "datafusion-common", - "indexmap 
2.9.0", + "indexmap 2.13.0", + "insta", "itertools 0.14.0", - "paste", ] [[package]] name = "datafusion-ffi" -version = "47.0.0" +version = "52.3.0" dependencies = [ "abi_stable", "arrow", @@ -2242,7 +2223,22 @@ dependencies = [ "async-ffi", "async-trait", "datafusion", + "datafusion-catalog", + "datafusion-common", + "datafusion-datasource", + "datafusion-execution", + "datafusion-expr", + "datafusion-functions", + "datafusion-functions-aggregate", + "datafusion-functions-aggregate-common", + "datafusion-functions-table", + "datafusion-functions-window", + "datafusion-physical-expr", + "datafusion-physical-expr-common", + "datafusion-physical-plan", "datafusion-proto", + "datafusion-proto-common", + "datafusion-session", "doc-comment", "futures", "log", @@ -2253,7 +2249,7 @@ dependencies = [ [[package]] name = "datafusion-functions" -version = "47.0.0" +version = "52.3.0" dependencies = [ "arrow", "arrow-buffer", @@ -2261,18 +2257,23 @@ dependencies = [ "blake2", "blake3", "chrono", + "chrono-tz", "criterion", + "ctor", "datafusion-common", "datafusion-doc", "datafusion-execution", "datafusion-expr", "datafusion-expr-common", "datafusion-macros", + "env_logger", "hex", "itertools 0.14.0", "log", "md-5", - "rand 0.9.1", + "memchr", + "num-traits", + "rand 0.9.2", "regex", "sha2", "tokio", @@ -2282,9 +2283,8 @@ dependencies = [ [[package]] name = "datafusion-functions-aggregate" -version = "47.0.0" +version = "52.3.0" dependencies = [ - "ahash 0.8.12", "arrow", "criterion", "datafusion-common", @@ -2295,28 +2295,28 @@ dependencies = [ "datafusion-macros", "datafusion-physical-expr", "datafusion-physical-expr-common", + "foldhash 0.2.0", "half", "log", - "paste", - "rand 0.9.1", + "num-traits", + "rand 0.9.2", ] [[package]] name = "datafusion-functions-aggregate-common" -version = "47.0.0" +version = "52.3.0" dependencies = [ - "ahash 0.8.12", "arrow", "criterion", "datafusion-common", "datafusion-expr-common", "datafusion-physical-expr-common", - "rand 0.9.1", + 
"rand 0.9.2", ] [[package]] name = "datafusion-functions-nested" -version = "47.0.0" +version = "52.3.0" dependencies = [ "arrow", "arrow-ord", @@ -2325,19 +2325,22 @@ dependencies = [ "datafusion-doc", "datafusion-execution", "datafusion-expr", + "datafusion-expr-common", "datafusion-functions", "datafusion-functions-aggregate", + "datafusion-functions-aggregate-common", "datafusion-macros", "datafusion-physical-expr-common", + "hashbrown 0.16.1", "itertools 0.14.0", + "itoa", "log", - "paste", - "rand 0.9.1", + "rand 0.9.2", ] [[package]] name = "datafusion-functions-table" -version = "47.0.0" +version = "52.3.0" dependencies = [ "arrow", "async-trait", @@ -2346,14 +2349,14 @@ dependencies = [ "datafusion-expr", "datafusion-physical-plan", "parking_lot", - "paste", ] [[package]] name = "datafusion-functions-window" -version = "47.0.0" +version = "52.3.0" dependencies = [ "arrow", + "criterion", "datafusion-common", "datafusion-doc", "datafusion-expr", @@ -2362,12 +2365,11 @@ dependencies = [ "datafusion-physical-expr", "datafusion-physical-expr-common", "log", - "paste", ] [[package]] name = "datafusion-functions-window-common" -version = "47.0.0" +version = "52.3.0" dependencies = [ "datafusion-common", "datafusion-physical-expr-common", @@ -2375,16 +2377,16 @@ dependencies = [ [[package]] name = "datafusion-macros" -version = "47.0.0" +version = "52.3.0" dependencies = [ - "datafusion-expr", + "datafusion-doc", "quote", - "syn 2.0.101", + "syn 2.0.117", ] [[package]] name = "datafusion-optimizer" -version = "47.0.0" +version = "52.3.0" dependencies = [ "arrow", "async-trait", @@ -2393,13 +2395,14 @@ dependencies = [ "ctor", "datafusion-common", "datafusion-expr", + "datafusion-expr-common", "datafusion-functions-aggregate", "datafusion-functions-window", "datafusion-functions-window-common", "datafusion-physical-expr", "datafusion-sql", "env_logger", - "indexmap 2.9.0", + "indexmap 2.13.0", "insta", "itertools 0.14.0", "log", @@ -2410,9 +2413,8 @@ dependencies 
= [ [[package]] name = "datafusion-physical-expr" -version = "47.0.0" +version = "52.3.0" dependencies = [ - "ahash 0.8.12", "arrow", "criterion", "datafusion-common", @@ -2422,101 +2424,137 @@ dependencies = [ "datafusion-functions-aggregate-common", "datafusion-physical-expr-common", "half", - "hashbrown 0.14.5", - "indexmap 2.9.0", + "hashbrown 0.16.1", + "indexmap 2.13.0", "insta", "itertools 0.14.0", - "log", - "paste", - "petgraph 0.8.1", - "rand 0.9.1", + "parking_lot", + "petgraph", + "rand 0.9.2", + "recursive", "rstest", + "tokio", +] + +[[package]] +name = "datafusion-physical-expr-adapter" +version = "52.3.0" +dependencies = [ + "arrow", + "datafusion-common", + "datafusion-expr", + "datafusion-functions", + "datafusion-physical-expr", + "datafusion-physical-expr-common", + "itertools 0.14.0", ] [[package]] name = "datafusion-physical-expr-common" -version = "47.0.0" +version = "52.3.0" dependencies = [ - "ahash 0.8.12", "arrow", + "chrono", + "criterion", "datafusion-common", "datafusion-expr-common", - "hashbrown 0.14.5", + "hashbrown 0.16.1", + "indexmap 2.13.0", "itertools 0.14.0", + "parking_lot", + "rand 0.9.2", ] [[package]] name = "datafusion-physical-optimizer" -version = "47.0.0" +version = "52.3.0" dependencies = [ "arrow", "datafusion-common", "datafusion-execution", "datafusion-expr", "datafusion-expr-common", - "datafusion-functions-nested", + "datafusion-functions", + "datafusion-functions-window", "datafusion-physical-expr", "datafusion-physical-expr-common", "datafusion-physical-plan", + "datafusion-pruning", "insta", "itertools 0.14.0", - "log", "recursive", + "tokio", ] [[package]] name = "datafusion-physical-plan" -version = "47.0.0" +version = "52.3.0" dependencies = [ - "ahash 0.8.12", "arrow", "arrow-ord", "arrow-schema", "async-trait", - "chrono", "criterion", "datafusion-common", "datafusion-common-runtime", "datafusion-execution", "datafusion-expr", + "datafusion-functions", "datafusion-functions-aggregate", + 
"datafusion-functions-aggregate-common", "datafusion-functions-window", "datafusion-functions-window-common", "datafusion-physical-expr", "datafusion-physical-expr-common", "futures", "half", - "hashbrown 0.14.5", - "indexmap 2.9.0", + "hashbrown 0.16.1", + "indexmap 2.13.0", "insta", "itertools 0.14.0", "log", + "num-traits", "parking_lot", "pin-project-lite", - "rand 0.9.1", + "rand 0.9.2", "rstest", "rstest_reuse", - "tempfile", "tokio", ] [[package]] name = "datafusion-proto" -version = "47.0.0" +version = "52.3.0" dependencies = [ "arrow", + "async-trait", "chrono", "datafusion", + "datafusion-catalog", + "datafusion-catalog-listing", "datafusion-common", + "datafusion-datasource", + "datafusion-datasource-arrow", + "datafusion-datasource-avro", + "datafusion-datasource-csv", + "datafusion-datasource-json", + "datafusion-datasource-parquet", + "datafusion-execution", "datafusion-expr", "datafusion-functions", "datafusion-functions-aggregate", + "datafusion-functions-table", "datafusion-functions-window-common", + "datafusion-physical-expr", + "datafusion-physical-expr-common", + "datafusion-physical-plan", "datafusion-proto-common", "doc-comment", "object_store", - "pbjson", + "pbjson 0.9.0", + "pretty_assertions", "prost", + "rand 0.9.2", "serde", "serde_json", "tokio", @@ -2524,59 +2562,79 @@ dependencies = [ [[package]] name = "datafusion-proto-common" -version = "47.0.0" +version = "52.3.0" dependencies = [ "arrow", "datafusion-common", "doc-comment", - "pbjson", + "pbjson 0.9.0", "prost", "serde", - "serde_json", ] [[package]] -name = "datafusion-session" -version = "47.0.0" +name = "datafusion-pruning" +version = "52.3.0" dependencies = [ "arrow", - "async-trait", - "dashmap", "datafusion-common", - "datafusion-common-runtime", - "datafusion-execution", + "datafusion-datasource", "datafusion-expr", + "datafusion-expr-common", + "datafusion-functions-nested", "datafusion-physical-expr", + "datafusion-physical-expr-common", "datafusion-physical-plan", - 
"datafusion-sql", - "futures", + "insta", "itertools 0.14.0", "log", - "object_store", +] + +[[package]] +name = "datafusion-session" +version = "52.3.0" +dependencies = [ + "async-trait", + "datafusion-common", + "datafusion-execution", + "datafusion-expr", + "datafusion-physical-plan", "parking_lot", - "tokio", ] [[package]] name = "datafusion-spark" -version = "47.0.0" +version = "52.3.0" dependencies = [ "arrow", + "bigdecimal", + "chrono", + "crc32fast", + "criterion", + "datafusion", "datafusion-catalog", "datafusion-common", "datafusion-execution", "datafusion-expr", "datafusion-functions", - "datafusion-macros", + "datafusion-functions-aggregate", + "datafusion-functions-nested", "log", + "percent-encoding", + "rand 0.9.2", + "serde_json", + "sha1", + "sha2", + "url", ] [[package]] name = "datafusion-sql" -version = "47.0.0" +version = "52.3.0" dependencies = [ "arrow", "bigdecimal", + "chrono", "ctor", "datafusion-common", "datafusion-expr", @@ -2585,10 +2643,10 @@ dependencies = [ "datafusion-functions-nested", "datafusion-functions-window", "env_logger", - "indexmap 2.9.0", + "indexmap 2.13.0", "insta", + "itertools 0.14.0", "log", - "paste", "recursive", "regex", "rstest", @@ -2597,16 +2655,17 @@ dependencies = [ [[package]] name = "datafusion-sqllogictest" -version = "47.0.0" +version = "52.3.0" dependencies = [ "arrow", "async-trait", "bigdecimal", "bytes", "chrono", - "clap 4.5.39", + "clap", "datafusion", "datafusion-spark", + "datafusion-substrait", "env_logger", "futures", "half", @@ -2614,28 +2673,27 @@ dependencies = [ "itertools 0.14.0", "log", "object_store", - "postgres-protocol", "postgres-types", - "rust_decimal", + "regex", "sqllogictest", "sqlparser", "tempfile", - "testcontainers", "testcontainers-modules", - "thiserror 2.0.12", + "thiserror", "tokio", "tokio-postgres", ] [[package]] name = "datafusion-substrait" -version = "47.0.0" +version = "52.3.0" dependencies = [ "async-recursion", "async-trait", "chrono", "datafusion", 
"datafusion-functions-aggregate", + "half", "insta", "itertools 0.14.0", "object_store", @@ -2649,8 +2707,9 @@ dependencies = [ [[package]] name = "datafusion-wasmtest" -version = "47.0.0" +version = "52.3.0" dependencies = [ + "bytes", "chrono", "console_error_panic_hook", "datafusion", @@ -2660,8 +2719,8 @@ dependencies = [ "datafusion-optimizer", "datafusion-physical-plan", "datafusion-sql", - "getrandom 0.3.3", - "insta", + "futures", + "getrandom 0.3.4", "object_store", "tokio", "url", @@ -2671,19 +2730,19 @@ dependencies = [ [[package]] name = "deranged" -version = "0.4.0" +version = "0.5.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c9e6a11ca8224451684bc0d7d5a7adbf8f2fd6887261a1cfc3c0432f9d4068e" +checksum = "7cd812cc2bc1d69d4764bd80df88b4317eaef9e773c75226407d9bc0876b211c" dependencies = [ "powerfmt", - "serde", + "serde_core", ] [[package]] -name = "difflib" -version = "0.4.0" +name = "diff" +version = "0.1.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6184e33543162437515c2e2b48714794e37845ec9851711914eec9d308f6ebe8" +checksum = "56254986775e3233ffa9c4d7d3faaf6d36a2c09d30b20687e9f88bc8bafc16c8" [[package]] name = "digest" @@ -2714,7 +2773,7 @@ dependencies = [ "libc", "option-ext", "redox_users", - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] @@ -2725,14 +2784,14 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.117", ] [[package]] name = "doc-comment" -version = "0.3.3" +version = "0.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10" +checksum = "780955b8b195a21ab8e4ac6b60dd1dbdcec1dc6c51c0617964b08c81785e12c9" [[package]] name = "docker_credential" @@ -2747,18 +2806,18 @@ dependencies = [ [[package]] name = "dtor" -version = "0.0.6" +version = "0.1.1" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "97cbdf2ad6846025e8e25df05171abfb30e3ababa12ee0a0e44b9bbe570633a8" +checksum = "404d02eeb088a82cfd873006cb713fe411306c7d182c344905e101fb1167d301" dependencies = [ "dtor-proc-macro", ] [[package]] name = "dtor-proc-macro" -version = "0.0.5" +version = "0.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7454e41ff9012c00d53cf7f475c5e3afa3b91b7c90568495495e8d9bf47a1055" +checksum = "f678cf4a922c215c63e0de95eb1ff08a958a81d47e485cf9da1e27bf6305cfa5" [[package]] name = "dunce" @@ -2768,9 +2827,9 @@ checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813" [[package]] name = "dyn-clone" -version = "1.0.19" +version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1c7a8fb8a9fbf66c1f703fe16184d10ca0ee9d23be5b4436400408ba54a95005" +checksum = "d0881ea181b1df73ff77ffaaf9c7544ecc11e82fba9b5f27b262a3c73a332555" [[package]] name = "educe" @@ -2781,7 +2840,7 @@ dependencies = [ "enum-ordinalize", "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.117", ] [[package]] @@ -2804,29 +2863,29 @@ checksum = "c34f04666d835ff5d62e058c3995147c06f42fe86ff053337632bca83e42702d" [[package]] name = "enum-ordinalize" -version = "4.3.0" +version = "4.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fea0dcfa4e54eeb516fe454635a95753ddd39acda650ce703031c6973e315dd5" +checksum = "4a1091a7bb1f8f2c4b28f1fe2cef4980ca2d410a3d727d67ecc3178c9b0800f0" dependencies = [ "enum-ordinalize-derive", ] [[package]] name = "enum-ordinalize-derive" -version = "4.3.1" +version = "4.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d28318a75d4aead5c4db25382e8ef717932d0346600cacae6357eb5941bc5ff" +checksum = "8ca9601fb2d62598ee17836250842873a413586e5d7ed88b356e38ddbb0ec631" dependencies = [ "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.117", ] [[package]] name = "env_filter" -version = 
"0.1.3" +version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "186e05a59d4c50738528153b83b0b0194d3a29507dfec16eccd4b342903397d0" +checksum = "7a1c3cc8e57274ec99de65301228b537f1e4eedc1b8e0f9411c6caac8ae7308f" dependencies = [ "log", "regex", @@ -2834,11 +2893,11 @@ dependencies = [ [[package]] name = "env_logger" -version = "0.11.8" +version = "0.11.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13c863f0904021b108aa8b2f55046443e6b1ebde8fd4a15c399893aae4fa069f" +checksum = "b2daee4ea451f429a58296525ddf28b45a3b64f1acf6587e2067437bb11e218d" dependencies = [ - "anstream", + "anstream 0.6.21", "anstyle", "env_filter", "jiff", @@ -2853,12 +2912,12 @@ checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" [[package]] name = "errno" -version = "0.3.12" +version = "0.3.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cea14ef9355e3beab063703aa9dab15afd25f0667c341310c1e5274bb1d0da18" +checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" dependencies = [ "libc", - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] @@ -2875,13 +2934,12 @@ checksum = "5692dd7b5a1978a5aeb0ce83b7655c58ca8efdcb79d21036ea249da95afec2c6" [[package]] name = "etcetera" -version = "0.10.0" +version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26c7b13d0780cb82722fd59f6f57f925e143427e4a75313a6c77243bf5326ae6" +checksum = "de48cc4d1c1d97a20fd819def54b890cadde72ed3ad0c614822a0a433361be96" dependencies = [ "cfg-if", - "home", - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] @@ -2903,10 +2961,21 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ce92ff622d6dadf7349484f42c93271a0d49b7cc4d466a936405bacbe10aa78" dependencies = [ "cfg-if", - "rustix 1.0.7", + "rustix", "windows-sys 0.59.0", ] +[[package]] +name = "ferroid" +version = "0.8.9" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb330bbd4cb7a5b9f559427f06f98a4f853a137c8298f3bd3f8ca57663e21986" +dependencies = [ + "portable-atomic", + "rand 0.9.2", + "web-time", +] + [[package]] name = "ffi_example_table_provider" version = "0.1.0" @@ -2939,16 +3008,21 @@ dependencies = [ [[package]] name = "filetime" -version = "0.2.25" +version = "0.2.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35c0522e981e68cbfa8c3f978441a5f34b30b96e146b33cd3359176b50fe8586" +checksum = "f98844151eee8917efc50bd9e8318cb963ae8b297431495d3f758616ea5c57db" dependencies = [ "cfg-if", "libc", "libredox", - "windows-sys 0.59.0", ] +[[package]] +name = "find-msvc-tools" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582" + [[package]] name = "fixedbitset" version = "0.5.7" @@ -2957,32 +3031,23 @@ checksum = "1d674e81391d1e1ab681a28d99df07927c6d4aa5b027d7da16ba32d1d21ecd99" [[package]] name = "flatbuffers" -version = "25.2.10" +version = "25.12.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1045398c1bfd89168b5fd3f1fc11f6e70b34f6f66300c87d44d3de849463abf1" +checksum = "35f6839d7b3b98adde531effaf34f0c2badc6f4735d26fe74709d8e513a96ef3" dependencies = [ - "bitflags 2.9.1", + "bitflags", "rustc_version", ] [[package]] name = "flate2" -version = "1.1.1" +version = "1.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ced92e76e966ca2fd84c8f7aa01a4aea65b0eb6648d72f7c8f3e2764a67fece" +checksum = "843fba2746e448b37e26a819579957415c8cef339bf08564fe8b7ddbd959573c" dependencies = [ "crc32fast", - "libz-rs-sys", "miniz_oxide", -] - -[[package]] -name = "float-cmp" -version = "0.10.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b09cf3155332e944990140d967ff5eceb70df778b34f77d8075db46e4704e6d8" -dependencies = [ - "num-traits", + 
"zlib-rs", ] [[package]] @@ -2997,20 +3062,26 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" +[[package]] +name = "foldhash" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77ce24cb58228fbb8aa041425bb1050850ac19177686ea6e0f41a70416f56fdb" + [[package]] name = "form_urlencoded" -version = "1.2.1" +version = "1.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456" +checksum = "cb4cb245038516f5f85277875cdaa4f7d2c9a0fa0468de06ed190163b1581fcf" dependencies = [ "percent-encoding", ] [[package]] name = "fs-err" -version = "3.1.0" +version = "3.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1f89bda4c2a21204059a977ed3bfe746677dfd137b83c339e702b0ac91d482aa" +checksum = "73fde052dbfc920003cfd2c8e2c6e6d4cc7c1091538c3a24226cec0665ab08c0" dependencies = [ "autocfg", ] @@ -3021,17 +3092,11 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" -[[package]] -name = "funty" -version = "2.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" - [[package]] name = "futures" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" +checksum = "8b147ee9d1f6d097cef9ce628cd2ee62288d963e16fb287bd9286455b241382d" dependencies = [ "futures-channel", "futures-core", @@ -3044,9 +3109,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" +checksum = "07bbe89c50d7a535e539b8c17bc0b49bdb77747034daa8087407d655f3f7cc1d" dependencies = [ "futures-core", "futures-sink", @@ -3054,15 +3119,15 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" +checksum = "7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d" [[package]] name = "futures-executor" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f" +checksum = "baf29c38818342a3b26b5b923639e7b1f4a61fc5e76102d4b1981c6dc7a7579d" dependencies = [ "futures-core", "futures-task", @@ -3071,32 +3136,32 @@ dependencies = [ [[package]] name = "futures-io" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" +checksum = "cecba35d7ad927e23624b22ad55235f2239cfa44fd10428eecbeba6d6a717718" [[package]] name = "futures-macro" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" +checksum = "e835b70203e41293343137df5c0664546da5745f82ec9b84d40be8336958447b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.117", ] [[package]] name = "futures-sink" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7" +checksum = "c39754e157331b013978ec91992bde1ac089843443c49cbc7f46150b0fad0893" [[package]] name = "futures-task" -version = "0.3.31" +version = "0.3.32" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" +checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393" [[package]] name = "futures-timer" @@ -3106,9 +3171,9 @@ checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" [[package]] name = "futures-util" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" +checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6" dependencies = [ "futures-channel", "futures-core", @@ -3118,7 +3183,6 @@ dependencies = [ "futures-task", "memchr", "pin-project-lite", - "pin-utils", "slab", ] @@ -3126,7 +3190,7 @@ dependencies = [ name = "gen" version = "0.1.0" dependencies = [ - "pbjson-build", + "pbjson-build 0.9.0", "prost-build", ] @@ -3134,7 +3198,7 @@ dependencies = [ name = "gen-common" version = "0.1.0" dependencies = [ - "pbjson-build", + "pbjson-build 0.9.0", "prost-build", ] @@ -3159,48 +3223,55 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.16" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592" +checksum = "ff2abc00be7fca6ebc474524697ae276ad847ad0a6b3faa4bcb027e9a4614ad0" dependencies = [ "cfg-if", "js-sys", "libc", - "wasi 0.11.0+wasi-snapshot-preview1", + "wasi 0.11.1+wasi-snapshot-preview1", "wasm-bindgen", ] [[package]] name = "getrandom" -version = "0.3.3" +version = "0.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4" +checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" dependencies = [ "cfg-if", "js-sys", "libc", - "r-efi", - "wasi 0.14.2+wasi-0.2.4", + "r-efi 5.3.0", + "wasip2", 
"wasm-bindgen", ] [[package]] -name = "gimli" -version = "0.31.1" +name = "getrandom" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" +checksum = "0de51e6874e94e7bf76d726fc5d13ba782deca734ff60d5bb2fb2607c7406555" +dependencies = [ + "cfg-if", + "libc", + "r-efi 6.0.0", + "wasip2", + "wasip3", +] [[package]] name = "glob" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" +checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280" [[package]] name = "globset" -version = "0.4.16" +version = "0.4.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "54a1028dfc5f5df5da8a56a73e6c153c9a9708ec57232470703592a3f18e49f5" +checksum = "52dfc19153a48bde0cbd630453615c8151bce3a5adfac7a0aebfbf0a1e1f57e3" dependencies = [ "aho-corasick", "bstr", @@ -3211,17 +3282,17 @@ dependencies = [ [[package]] name = "h2" -version = "0.4.10" +version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a9421a676d1b147b16b82c9225157dc629087ef8ec4d5e2960f9437a90dac0a5" +checksum = "2f44da3a8150a6703ed5d34e164b875fd14c2cdab9af1252a9a1020bde2bdc54" dependencies = [ "atomic-waker", "bytes", "fnv", "futures-core", "futures-sink", - "http 1.3.1", - "indexmap 2.9.0", + "http 1.4.0", + "indexmap 2.13.0", "slab", "tokio", "tokio-util", @@ -3230,13 +3301,16 @@ dependencies = [ [[package]] name = "half" -version = "2.6.0" +version = "2.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "459196ed295495a68f7d7fe1d84f6c4b7ff0e21fe3017b2f283c6fac3ad803c9" +checksum = "6ea2d84b969582b4b1864a92dc5d27cd2b77b622a8d79306834f1be5ba20d84b" dependencies = [ "cfg-if", "crunchy", "num-traits", + "rand 0.9.2", + "rand_distr", + "zerocopy", ] [[package]] @@ -3244,38 
+3318,31 @@ name = "hashbrown" version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" -dependencies = [ - "ahash 0.7.8", -] [[package]] name = "hashbrown" version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" -dependencies = [ - "ahash 0.8.12", - "allocator-api2", -] [[package]] name = "hashbrown" -version = "0.15.3" +version = "0.15.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "84b26c544d002229e640969970a2e74021aadf6e2f96372b9c58eff97de08eb3" +checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" dependencies = [ - "allocator-api2", - "equivalent", - "foldhash", + "foldhash 0.1.5", ] [[package]] -name = "heck" -version = "0.3.3" +name = "hashbrown" +version = "0.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d621efb26863f0e9924c6ac577e8275e5e6b77455db64ffa6c65c904e9e132c" +checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" dependencies = [ - "unicode-segmentation", + "allocator-api2", + "equivalent", + "foldhash 0.2.0", ] [[package]] @@ -3284,12 +3351,6 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" -[[package]] -name = "hermit-abi" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f154ce46856750ed433c8649605bf7ed2de3bc35fd9d2a9f30cddd873c80cb08" - [[package]] name = "hex" version = "0.4.3" @@ -3307,11 +3368,11 @@ dependencies = [ [[package]] name = "home" -version = "0.5.11" +version = "0.5.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "589533453244b0995c858700322199b2becb13b627df2851f64a2775d024abcf" +checksum = 
"cc627f471c528ff0c4a49e1d5e60450c8f6461dd6d10ba9dcd3a61d3dff7728d" dependencies = [ - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] @@ -3327,12 +3388,11 @@ dependencies = [ [[package]] name = "http" -version = "1.3.1" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f4a85d31aea989eead29a3aaf9e1115a180df8282431156e533de47660892565" +checksum = "e3ba2a386d7f85a81f119ad7498ebe444d2e22c2af0b86b069416ace48b3311a" dependencies = [ "bytes", - "fnv", "itoa", ] @@ -3354,7 +3414,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" dependencies = [ "bytes", - "http 1.3.1", + "http 1.4.0", ] [[package]] @@ -3365,7 +3425,7 @@ checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" dependencies = [ "bytes", "futures-core", - "http 1.3.1", + "http 1.4.0", "http-body 1.0.1", "pin-project-lite", ] @@ -3384,26 +3444,28 @@ checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" [[package]] name = "humantime" -version = "2.2.0" +version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b112acc8b3adf4b107a8ec20977da0273a8c386765a3ec0229bd500a1443f9f" +checksum = "135b12329e5e3ce057a9f972339ea52bc954fe1e9358ef27f95e89716fbc5424" [[package]] name = "hyper" -version = "1.6.0" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc2b571658e38e0c01b1fdca3bbbe93c00d3d71693ff2770043f8c29bc7d6f80" +checksum = "2ab2d4f250c3d7b1c9fcdff1cece94ea4e2dfbec68614f7b87cb205f24ca9d11" dependencies = [ + "atomic-waker", "bytes", "futures-channel", - "futures-util", + "futures-core", "h2", - "http 1.3.1", + "http 1.4.0", "http-body 1.0.1", "httparse", "httpdate", "itoa", "pin-project-lite", + "pin-utils", "smallvec", "tokio", "want", @@ -3426,12 +3488,11 @@ dependencies = [ [[package]] name = "hyper-rustls" -version = "0.27.5" 
+version = "0.27.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2d191583f3da1305256f22463b9bb0471acad48a4e534a5218b9963e9c1f59b2" +checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58" dependencies = [ - "futures-util", - "http 1.3.1", + "http 1.4.0", "hyper", "hyper-util", "rustls", @@ -3457,17 +3518,20 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.12" +version = "0.1.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf9f1e950e0d9d1d3c47184416723cf29c0d1f93bd8cccf37e4beb6b44f31710" +checksum = "96547c2556ec9d12fb1578c4eaf448b04993e7fb79cbaad930a656880a6bdfa0" dependencies = [ + "base64 0.22.1", "bytes", "futures-channel", "futures-util", - "http 1.3.1", + "http 1.4.0", "http-body 1.0.1", "hyper", + "ipnet", "libc", + "percent-encoding", "pin-project-lite", "socket2", "tokio", @@ -3492,9 +3556,9 @@ dependencies = [ [[package]] name = "iana-time-zone" -version = "0.1.63" +version = "0.1.65" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b0c919e5debc312ad217002b8048a17b7d83f80703865bbfcfebb0458b0b27d8" +checksum = "e31bc9ad994ba00e440a8aa5c9ef0ec67d5cb5e5cb0cc7f8b744a35b389cc470" dependencies = [ "android_system_properties", "core-foundation-sys", @@ -3516,9 +3580,9 @@ dependencies = [ [[package]] name = "icu_collections" -version = "2.0.0" +version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "200072f5d0e3614556f94a9930d5dc3e0662a652823904c3a75dc3b0af7fee47" +checksum = "4c6b649701667bbe825c3b7e6388cb521c23d88644678e83c0c4d0a621a34b43" dependencies = [ "displaydoc", "potential_utf", @@ -3529,9 +3593,9 @@ dependencies = [ [[package]] name = "icu_locale_core" -version = "2.0.0" +version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0cde2700ccaed3872079a65fb1a78f6c0a36c91570f28755dda67bc8f7d9f00a" +checksum = 
"edba7861004dd3714265b4db54a3c390e880ab658fec5f7db895fae2046b5bb6" dependencies = [ "displaydoc", "litemap", @@ -3542,11 +3606,10 @@ dependencies = [ [[package]] name = "icu_normalizer" -version = "2.0.0" +version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "436880e8e18df4d7bbc06d58432329d6458cc84531f7ac5f024e93deadb37979" +checksum = "5f6c8828b67bf8908d82127b2054ea1b4427ff0230ee9141c54251934ab1b599" dependencies = [ - "displaydoc", "icu_collections", "icu_normalizer_data", "icu_properties", @@ -3557,42 +3620,38 @@ dependencies = [ [[package]] name = "icu_normalizer_data" -version = "2.0.0" +version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "00210d6893afc98edb752b664b8890f0ef174c8adbb8d0be9710fa66fbbf72d3" +checksum = "7aedcccd01fc5fe81e6b489c15b247b8b0690feb23304303a9e560f37efc560a" [[package]] name = "icu_properties" -version = "2.0.1" +version = "2.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "016c619c1eeb94efb86809b015c58f479963de65bdb6253345c1a1276f22e32b" +checksum = "020bfc02fe870ec3a66d93e677ccca0562506e5872c650f893269e08615d74ec" dependencies = [ - "displaydoc", "icu_collections", "icu_locale_core", "icu_properties_data", "icu_provider", - "potential_utf", "zerotrie", "zerovec", ] [[package]] name = "icu_properties_data" -version = "2.0.1" +version = "2.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "298459143998310acd25ffe6810ed544932242d3f07083eee1084d83a71bd632" +checksum = "616c294cf8d725c6afcd8f55abc17c56464ef6211f9ed59cccffe534129c77af" [[package]] name = "icu_provider" -version = "2.0.0" +version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03c80da27b5f4187909049ee2d72f276f0d9f99a42c306bd0131ecfe04d8e5af" +checksum = "85962cf0ce02e1e0a629cc34e7ca3e373ce20dda4c4d7294bbd0bf1fdb59e614" dependencies = [ "displaydoc", "icu_locale_core", - 
"stable_deref_trait", - "tinystr", "writeable", "yoke", "zerofrom", @@ -3600,6 +3659,12 @@ dependencies = [ "zerovec", ] +[[package]] +name = "id-arena" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d3067d79b975e8844ca9eb072e16b31c3c1c36928edf9c6789548c524d0d954" + [[package]] name = "ident_case" version = "1.0.1" @@ -3608,9 +3673,9 @@ checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" [[package]] name = "idna" -version = "1.0.3" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "686f825264d630750a544639377bae737628043f20d38bbc029e8f29ea968a7e" +checksum = "3b0875f23caa03898994f6ddc501886a45c7d3d62d04d2d90788d47be1b1e4de" dependencies = [ "idna_adapter", "smallvec", @@ -3640,46 +3705,42 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.9.0" +version = "2.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cea70ddb795996207ad57735b50c5982d8844f38ba9ee5f1aedcfb708a2aa11e" +checksum = "7714e70437a7dc3ac8eb7e6f8df75fd8eb422675fc7678aff7364301092b1017" dependencies = [ "equivalent", - "hashbrown 0.15.3", + "hashbrown 0.16.1", "serde", + "serde_core", ] [[package]] name = "indicatif" -version = "0.17.11" +version = "0.18.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "183b3088984b400f4cfac3620d5e076c84da5364016b4f49473de574b2586235" +checksum = "25470f23803092da7d239834776d653104d551bc4d7eacaf31e6837854b8e9eb" dependencies = [ - "console", - "number_prefix", + "console 0.16.3", "portable-atomic", - "unicode-width 0.2.0", + "unicode-width 0.2.2", + "unit-prefix", "web-time", ] -[[package]] -name = "indoc" -version = "2.0.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f4c7245a08504955605670dbf141fceab975f15ca21570696aebe9d2e71576bd" - [[package]] name = "insta" -version = "1.43.1" +version = "1.46.3" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "154934ea70c58054b556dd430b99a98c2a7ff5309ac9891597e339b5c28f4371" +checksum = "e82db8c87c7f1ccecb34ce0c24399b8a73081427f3c7c50a5d597925356115e4" dependencies = [ - "console", + "console 0.15.11", "globset", "once_cell", "regex", "serde", "similar", + "tempfile", "walkdir", ] @@ -3702,44 +3763,25 @@ checksum = "8bb03732005da905c88227371639bf1ad885cc712789c011c31c5fb3ab3ccf02" [[package]] name = "ipnet" -version = "2.11.0" +version = "2.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130" +checksum = "d98f6fed1fde3f8c21bc40a1abb88dd75e67924f9cffc3ef95607bad8017f8e2" [[package]] -name = "is-terminal" -version = "0.4.16" +name = "iri-string" +version = "0.7.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e04d7f318608d35d4b61ddd75cbdaee86b023ebe2bd5a66ee0915f0bf93095a9" +checksum = "c91338f0783edbd6195decb37bae672fd3b165faffb89bf7b9e6942f8b1a731a" dependencies = [ - "hermit-abi", - "libc", - "windows-sys 0.59.0", + "memchr", + "serde", ] [[package]] name = "is_terminal_polyfill" -version = "1.70.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" - -[[package]] -name = "itertools" -version = "0.10.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473" -dependencies = [ - "either", -] - -[[package]] -name = "itertools" -version = "0.12.1" +version = "1.70.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" -dependencies = [ - "either", -] +checksum = "a6cb138bb79a146c1bd460005623e142ef0181e3d0219cb493e02f7d08a35695" [[package]] name = "itertools" @@ -3761,49 +3803,49 @@ dependencies = [ 
[[package]] name = "itoa" -version = "1.0.15" +version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" +checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2" [[package]] name = "jiff" -version = "0.2.14" +version = "0.2.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a194df1107f33c79f4f93d02c80798520551949d59dfad22b6157048a88cca93" +checksum = "1a3546dc96b6d42c5f24902af9e2538e82e39ad350b0c766eb3fbf2d8f3d8359" dependencies = [ "jiff-static", "log", "portable-atomic", "portable-atomic-util", - "serde", + "serde_core", ] [[package]] name = "jiff-static" -version = "0.2.14" +version = "0.2.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c6e1db7ed32c6c71b759497fae34bf7933636f75a251b9e736555da426f6442" +checksum = "2a8c8b344124222efd714b73bb41f8b5120b27a7cc1c75593a6ff768d9d05aa4" dependencies = [ "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.117", ] [[package]] name = "jobserver" -version = "0.1.33" +version = "0.1.34" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38f262f097c174adebe41eb73d66ae9c06b2844fb0da69969647bbddd9b0538a" +checksum = "9afb3de4395d6b3e67a780b6de64b51c978ecf11cb9a462c66be7d4ca9039d33" dependencies = [ - "getrandom 0.3.3", + "getrandom 0.3.4", "libc", ] [[package]] name = "js-sys" -version = "0.3.77" +version = "0.3.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1cfaf33c695fc6e08064efbc1f72ec937429614f25eef83af942d0e227c3a28f" +checksum = "b49715b7073f385ba4bc528e5747d02e66cb39c6146efb66b781f131f0fb399c" dependencies = [ "once_cell", "wasm-bindgen", @@ -3816,16 +3858,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" [[package]] -name = "lazycell" -version = "1.3.0" +name = "leb128fmt" 
+version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" +checksum = "09edd9e8b54e49e587e4f6295a7d29c3ea94d469cb40ab8ca70b288248a81db2" [[package]] name = "lexical-core" -version = "1.0.5" +version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b765c31809609075565a70b4b71402281283aeda7ecaf4818ac14a7b2ade8958" +checksum = "7d8d125a277f807e55a77304455eb7b1cb52f2b18c143b60e766c120bd64a594" dependencies = [ "lexical-parse-float", "lexical-parse-integer", @@ -3836,84 +3878,59 @@ dependencies = [ [[package]] name = "lexical-parse-float" -version = "1.0.5" +version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "de6f9cb01fb0b08060209a057c048fcbab8717b4c1ecd2eac66ebfe39a65b0f2" +checksum = "52a9f232fbd6f550bc0137dcb5f99ab674071ac2d690ac69704593cb4abbea56" dependencies = [ "lexical-parse-integer", "lexical-util", - "static_assertions", ] [[package]] name = "lexical-parse-integer" -version = "1.0.5" +version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72207aae22fc0a121ba7b6d479e42cbfea549af1479c3f3a4f12c70dd66df12e" +checksum = "9a7a039f8fb9c19c996cd7b2fcce303c1b2874fe1aca544edc85c4a5f8489b34" dependencies = [ "lexical-util", - "static_assertions", ] [[package]] name = "lexical-util" -version = "1.0.6" +version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a82e24bf537fd24c177ffbbdc6ebcc8d54732c35b50a3f28cc3f4e4c949a0b3" -dependencies = [ - "static_assertions", -] +checksum = "2604dd126bb14f13fb5d1bd6a66155079cb9fa655b37f875b3a742c705dbed17" [[package]] name = "lexical-write-float" -version = "1.0.5" +version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c5afc668a27f460fb45a81a757b6bf2f43c2d7e30cb5a2dcd3abf294c78d62bd" +checksum = 
"50c438c87c013188d415fbabbb1dceb44249ab81664efbd31b14ae55dabb6361" dependencies = [ "lexical-util", "lexical-write-integer", - "static_assertions", ] [[package]] name = "lexical-write-integer" -version = "1.0.5" +version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "629ddff1a914a836fb245616a7888b62903aae58fa771e1d83943035efa0f978" +checksum = "409851a618475d2d5796377cad353802345cba92c867d9fbcde9cf4eac4e14df" dependencies = [ "lexical-util", - "static_assertions", ] [[package]] -name = "libc" -version = "0.2.172" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d750af042f7ef4f724306de029d18836c26c1765a54a6a3f094cbd23a7267ffa" - -[[package]] -name = "libflate" -version = "2.1.0" +name = "libbz2-rs-sys" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "45d9dfdc14ea4ef0900c1cddbc8dcd553fbaacd8a4a282cf4018ae9dd04fb21e" -dependencies = [ - "adler32", - "core2", - "crc32fast", - "dary_heap", - "libflate_lz77", -] +checksum = "2c4a545a15244c7d945065b5d392b2d2d7f21526fba56ce51467b06ed445e8f7" [[package]] -name = "libflate_lz77" -version = "2.1.0" +name = "libc" +version = "0.2.183" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6e0d73b369f386f1c44abd9c570d5318f55ccde816ff4b562fa452e5182863d" -dependencies = [ - "core2", - "hashbrown 0.14.5", - "rle-decode-fast", -] +checksum = "b5b646652bf6661599e1da8901b3b9522896f01e736bad5f723fe7a3a27f899d" [[package]] name = "libloading" @@ -3926,96 +3943,92 @@ dependencies = [ ] [[package]] -name = "libloading" -version = "0.8.7" +name = "liblzma" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a793df0d7afeac54f95b471d3af7f0d4fb975699f972341a4b76988d49cdf0c" +checksum = "b6033b77c21d1f56deeae8014eb9fbe7bdf1765185a6c508b5ca82eeaed7f899" dependencies = [ - "cfg-if", - "windows-targets 0.53.0", + "liblzma-sys", +] + +[[package]] +name = 
"liblzma-sys" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f2db66f3268487b5033077f266da6777d057949b8f93c8ad82e441df25e6186" +dependencies = [ + "cc", + "libc", + "pkg-config", ] [[package]] name = "libm" -version = "0.2.15" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de" +checksum = "b6d2cec3eae94f9f509c767b45932f1ada8350c4bdb85af2fcab4a3c14807981" [[package]] name = "libmimalloc-sys" -version = "0.1.42" +version = "0.1.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec9d6fac27761dabcd4ee73571cdb06b7022dc99089acbe5435691edffaac0f4" +checksum = "667f4fec20f29dfc6bc7357c582d91796c169ad7e2fce709468aefeb2c099870" dependencies = [ "cc", + "cty", "libc", ] [[package]] name = "libredox" -version = "0.1.3" +version = "0.1.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d" +checksum = "1744e39d1d6a9948f4f388969627434e31128196de472883b39f148769bfe30a" dependencies = [ - "bitflags 2.9.1", + "bitflags", "libc", - "redox_syscall 0.5.12", + "plain", + "redox_syscall 0.7.3", ] [[package]] name = "libtest-mimic" -version = "0.8.1" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5297962ef19edda4ce33aaa484386e0a5b3d7f2f4e037cbeee00503ef6b29d33" +checksum = "14e6ba06f0ade6e504aff834d7c34298e5155c6baca353cc6a4aaff2f9fd7f33" dependencies = [ - "anstream", + "anstream 1.0.0", "anstyle", - "clap 4.5.39", + "clap", "escape8259", ] -[[package]] -name = "libz-rs-sys" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6489ca9bd760fe9642d7644e827b0c9add07df89857b0416ee15c1cc1a3b8c5a" -dependencies = [ - "zlib-rs", -] - [[package]] name = "linux-raw-sys" -version = "0.4.15" +version = "0.12.1" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab" +checksum = "32a66949e030da00e8c7d4434b251670a91556f4144941d37452769c25d58a53" [[package]] -name = "linux-raw-sys" -version = "0.9.4" +name = "litemap" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd945864f07fe9f5371a27ad7b52a172b4b499999f1d97574c9fa68373937e12" - -[[package]] -name = "litemap" -version = "0.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "241eaef5fd12c88705a01fc1066c48c4b36e0dd4377dcdc7ec3942cea7a69956" +checksum = "6373607a59f0be73a39b6fe456b8192fcc3585f602af20751600e974dd455e77" [[package]] name = "lock_api" -version = "0.4.12" +version = "0.4.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" +checksum = "224399e74b87b5f3557511d98dff8b14089b3dadafcab6bb93eab67d3aace965" dependencies = [ - "autocfg", "scopeguard", ] [[package]] name = "log" -version = "0.4.27" +version = "0.4.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94" +checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" [[package]] name = "lru-slab" @@ -4025,29 +4038,18 @@ checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" [[package]] name = "lz4_flex" -version = "0.11.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75761162ae2b0e580d7e7c390558127e5f01b4194debd6221fd8c207fc80e3f5" -dependencies = [ - "twox-hash 1.6.3", -] - -[[package]] -name = "lzma-sys" -version = "0.1.20" +version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5fda04ab3764e6cde78b9974eec4f779acaba7c4e84b36eca3cf77c581b85d27" +checksum = 
"98c23545df7ecf1b16c303910a69b079e8e251d60f7dd2cc9b4177f2afaf1746" dependencies = [ - "cc", - "libc", - "pkg-config", + "twox-hash", ] [[package]] name = "matchit" -version = "0.7.3" +version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" +checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" [[package]] name = "md-5" @@ -4061,24 +4063,15 @@ dependencies = [ [[package]] name = "memchr" -version = "2.7.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" - -[[package]] -name = "memoffset" -version = "0.9.1" +version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a" -dependencies = [ - "autocfg", -] +checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" [[package]] name = "mimalloc" -version = "0.1.46" +version = "0.1.48" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "995942f432bbb4822a7e9c3faa87a695185b0d09273ba85f097b54f4e458f2af" +checksum = "e1ee66a4b64c74f4ef288bcbb9192ad9c3feaad75193129ac8509af543894fd8" dependencies = [ "libmimalloc-sys", ] @@ -4091,38 +4084,33 @@ checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" [[package]] name = "minicov" -version = "0.3.7" +version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f27fe9f1cc3c22e1687f9446c2083c4c5fc7f0bcf1c7a86bdbded14985895b4b" +checksum = "4869b6a491569605d66d3952bcdf03df789e5b536e5f0cf7758a7f08a55ae24d" dependencies = [ "cc", "walkdir", ] -[[package]] -name = "minimal-lexical" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" - [[package]] name = 
"miniz_oxide" -version = "0.8.8" +version = "0.8.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3be647b768db090acb35d5ec5db2b0e1f1de11133ca123b9eacf5137868f892a" +checksum = "1fa76a2c86f704bdb222d66965fb3d63269ce38518b83cb0575fca855ebb6316" dependencies = [ "adler2", + "simd-adler32", ] [[package]] name = "mio" -version = "1.0.3" +version = "1.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2886843bf800fba2e3377cff24abf6379b4c4d5c6681eaf9ea5b0d15090450bd" +checksum = "a69bcab0ad47271a0234d9422b131806bf3968021e5dc9328caf2d4cd58557fc" dependencies = [ "libc", - "wasi 0.11.0+wasi-snapshot-preview1", - "windows-sys 0.52.0", + "wasi 0.11.1+wasi-snapshot-preview1", + "windows-sys 0.61.2", ] [[package]] @@ -4146,45 +4134,49 @@ version = "0.30.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "74523f3a35e05aba87a1d978330aef40f67b0304ac79c1c00b294c9830543db6" dependencies = [ - "bitflags 2.9.1", + "bitflags", "cfg-if", "cfg_aliases", "libc", ] [[package]] -name = "nom" -version = "7.1.3" +name = "nix" +version = "0.31.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" +checksum = "5d6d0705320c1e6ba1d912b5e37cf18071b6c2e9b7fa8215a1e8a7651966f5d3" dependencies = [ - "memchr", - "minimal-lexical", + "bitflags", + "cfg-if", + "cfg_aliases", + "libc", ] [[package]] -name = "normalize-line-endings" -version = "0.3.0" +name = "nom" +version = "8.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61807f77802ff30975e01f4f071c8ba10c022052f98b3294119f3e615d13e5be" +checksum = "df9761775871bdef83bee530e60050f7e54b1105350d6884eb0fb4f46c2f9405" +dependencies = [ + "memchr", +] [[package]] name = "ntapi" -version = "0.4.1" +version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"e8a3895c6391c39d7fe7ebc444a87eb2991b2a0bc718fdabd071eec617fc68e4" +checksum = "c3b335231dfd352ffb0f8017f3b6027a4917f7df785ea2143d8af2adc66980ae" dependencies = [ "winapi", ] [[package]] name = "nu-ansi-term" -version = "0.46.0" +version = "0.50.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84" +checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" dependencies = [ - "overload", - "winapi", + "windows-sys 0.61.2", ] [[package]] @@ -4223,9 +4215,9 @@ dependencies = [ [[package]] name = "num-conv" -version = "0.1.0" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" +checksum = "cf97ec579c3c42f953ef76dbf8d55ac91fb219dde70e49aa4a6b7d74e9919050" [[package]] name = "num-integer" @@ -4268,45 +4260,48 @@ dependencies = [ "libm", ] -[[package]] -name = "number_prefix" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" - [[package]] name = "objc2-core-foundation" -version = "0.3.1" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1c10c2894a6fed806ade6027bcd50662746363a9589d3ec9d9bef30a4e4bc166" +checksum = "2a180dd8642fa45cdb7dd721cd4c11b1cadd4929ce112ebd8b9f5803cc79d536" dependencies = [ - "bitflags 2.9.1", + "bitflags", ] [[package]] name = "objc2-io-kit" -version = "0.3.1" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "71c1c64d6120e51cd86033f67176b1cb66780c2efe34dec55176f77befd93c0a" +checksum = "33fafba39597d6dc1fb709123dfa8289d39406734be322956a69f0931c73bb15" dependencies = [ "libc", "objc2-core-foundation", ] +[[package]] +name = "objc2-system-configuration" +version = "0.3.2" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "7216bd11cbda54ccabcab84d523dc93b858ec75ecfb3a7d89513fa22464da396" +dependencies = [ + "objc2-core-foundation", +] + [[package]] name = "object" -version = "0.36.7" +version = "0.37.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62948e14d923ea95ea2c7c86c71013138b66525b86bdc08d2dcc262bdb497b87" +checksum = "ff76201f031d8863c38aa7f905eca4f53abbfa15f609db4277d44cd8938f33fe" dependencies = [ "memchr", ] [[package]] name = "object_store" -version = "0.12.1" +version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d94ac16b433c0ccf75326388c893d2835ab7457ea35ab8ba5d745c053ef5fa16" +checksum = "c2858065e55c148d294a9f3aae3b0fa9458edadb41a108397094566f4e3c0dfb" dependencies = [ "async-trait", "base64 0.22.1", @@ -4314,7 +4309,7 @@ dependencies = [ "chrono", "form_urlencoded", "futures", - "http 1.3.1", + "http 1.4.0", "http-body-util", "humantime", "hyper", @@ -4323,14 +4318,14 @@ dependencies = [ "parking_lot", "percent-encoding", "quick-xml", - "rand 0.9.1", + "rand 0.9.2", "reqwest", "ring", - "rustls-pemfile", + "rustls-pki-types", "serde", "serde_json", "serde_urlencoded", - "thiserror 2.0.12", + "thiserror", "tokio", "tracing", "url", @@ -4341,9 +4336,15 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.21.3" +version = "1.21.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50" + +[[package]] +name = "once_cell_polyfill" +version = "1.70.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" +checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" [[package]] name = "oorandom" @@ -4353,9 +4354,9 @@ checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" [[package]] name = "openssl-probe" 
-version = "0.1.6" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" +checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe" [[package]] name = "option-ext" @@ -4379,22 +4380,26 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1a80800c0488c3a21695ea981a54918fbb37abf04f4d0720c453632255e2ff0e" [[package]] -name = "overload" -version = "0.1.1" +name = "owo-colors" +version = "4.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" +checksum = "d211803b9b6b570f68772237e415a029d5a50c65d382910b879fb19d3271f94d" [[package]] -name = "owo-colors" -version = "4.2.1" +name = "page_size" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26995317201fa17f3656c36716aed4a7c81743a9634ac4c99c0eeda495db0cec" +checksum = "30d5b2194ed13191c1999ae0704b7839fb18384fa22e49b57eeaa97d79ce40da" +dependencies = [ + "libc", + "winapi", +] [[package]] name = "parking_lot" -version = "0.12.3" +version = "0.12.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27" +checksum = "93857453250e3077bd71ff98b6a65ea6621a19bb0f559a85248955ac12c45a1a" dependencies = [ "lock_api", "parking_lot_core", @@ -4402,27 +4407,26 @@ dependencies = [ [[package]] name = "parking_lot_core" -version = "0.9.10" +version = "0.9.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" +checksum = "2621685985a2ebf1c516881c026032ac7deafcda1a2c9b7850dc81e3dfcb64c1" dependencies = [ "cfg-if", "libc", - "redox_syscall 0.5.12", + "redox_syscall 0.5.18", "smallvec", - "windows-targets 0.52.6", + "windows-link", ] [[package]] name = "parquet" 
-version = "55.1.0" +version = "58.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be7b2d778f6b841d37083ebdf32e33a524acde1266b5884a8ca29bf00dfa1231" +checksum = "3f491d0ef1b510194426ee67ddc18a9b747ef3c42050c19322a2cd2e1666c29b" dependencies = [ - "ahash 0.8.12", + "ahash", "arrow-array", "arrow-buffer", - "arrow-cast", "arrow-data", "arrow-ipc", "arrow-schema", @@ -4434,18 +4438,20 @@ dependencies = [ "flate2", "futures", "half", - "hashbrown 0.15.3", + "hashbrown 0.16.1", "lz4_flex", - "num", "num-bigint", + "num-integer", + "num-traits", "object_store", "paste", + "ring", "seq-macro", "simdutf8", "snap", "thrift", "tokio", - "twox-hash 2.1.0", + "twox-hash", "zstd", ] @@ -4471,16 +4477,7 @@ dependencies = [ "regex", "regex-syntax", "structmeta", - "syn 2.0.101", -] - -[[package]] -name = "parse-zoneinfo" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1f2a05b18d44e2957b88f96ba460715e295bc1d7510468a2f3d3b44535d26c24" -dependencies = [ - "regex", + "syn 2.0.117", ] [[package]] @@ -4491,36 +4488,58 @@ checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" [[package]] name = "pbjson" -version = "0.7.0" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7e6349fa080353f4a597daffd05cb81572a9c031a6d4fff7e504947496fcc68" +checksum = "898bac3fa00d0ba57a4e8289837e965baa2dee8c3749f3b11d45a64b4223d9c3" dependencies = [ - "base64 0.21.7", + "base64 0.22.1", + "serde", +] + +[[package]] +name = "pbjson" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8edd1efdd8ab23ba9cb9ace3d9987a72663d5d7c9f74fa00b51d6213645cf6c" +dependencies = [ + "base64 0.22.1", "serde", ] [[package]] name = "pbjson-build" -version = "0.7.0" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6eea3058763d6e656105d1403cb04e0a41b7bbac6362d413e7c33be0c32279c9" 
+checksum = "af22d08a625a2213a78dbb0ffa253318c5c79ce3133d32d296655a7bdfb02095" dependencies = [ - "heck 0.5.0", - "itertools 0.13.0", + "heck", + "itertools 0.14.0", + "prost", + "prost-types", +] + +[[package]] +name = "pbjson-build" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2ed4d5c6ae95e08ac768883c8401cf0e8deb4e6e1d6a4e1fd3d2ec4f0ec63200" +dependencies = [ + "heck", + "itertools 0.14.0", "prost", "prost-types", ] [[package]] name = "pbjson-types" -version = "0.7.0" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e54e5e7bfb1652f95bc361d76f3c780d8e526b134b85417e774166ee941f0887" +checksum = "8e748e28374f10a330ee3bb9f29b828c0ac79831a32bab65015ad9b661ead526" dependencies = [ "bytes", "chrono", - "pbjson", - "pbjson-build", + "pbjson 0.8.0", + "pbjson-build 0.8.0", "prost", "prost-build", "serde", @@ -4528,95 +4547,84 @@ dependencies = [ [[package]] name = "percent-encoding" -version = "2.3.1" +version = "2.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" +checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220" [[package]] name = "petgraph" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3672b37090dbd86368a4145bc067582552b29c27377cad4e0a306c97f9bd7772" -dependencies = [ - "fixedbitset", - "indexmap 2.9.0", -] - -[[package]] -name = "petgraph" -version = "0.8.1" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a98c6720655620a521dcc722d0ad66cd8afd5d86e34a89ef691c50b7b24de06" +checksum = "8701b58ea97060d5e5b155d383a69952a60943f0e6dfe30b04c287beb0b27455" dependencies = [ "fixedbitset", - "hashbrown 0.15.3", - "indexmap 2.9.0", + "hashbrown 0.15.5", + "indexmap 2.13.0", "serde", ] [[package]] name = "phf" -version = "0.11.3" +version = "0.12.1" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fd6780a80ae0c52cc120a26a1a42c1ae51b247a253e4e06113d23d2c2edd078" +checksum = "913273894cec178f401a31ec4b656318d95473527be05c0752cc41cdc32be8b7" dependencies = [ - "phf_shared", + "phf_shared 0.12.1", ] [[package]] -name = "phf_codegen" -version = "0.11.3" +name = "phf" +version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aef8048c789fa5e851558d709946d6d79a8ff88c0440c587967f8e94bfb1216a" +checksum = "c1562dc717473dbaa4c1f85a36410e03c047b2e7df7f45ee938fbef64ae7fadf" dependencies = [ - "phf_generator", - "phf_shared", + "phf_shared 0.13.1", + "serde", ] [[package]] -name = "phf_generator" -version = "0.11.3" +name = "phf_shared" +version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c80231409c20246a13fddb31776fb942c38553c51e871f8cbd687a4cfb5843d" +checksum = "06005508882fb681fd97892ecff4b7fd0fee13ef1aa569f8695dae7ab9099981" dependencies = [ - "phf_shared", - "rand 0.8.5", + "siphasher", ] [[package]] name = "phf_shared" -version = "0.11.3" +version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67eabc2ef2a60eb7faa00097bd1ffdb5bd28e62bf39990626a582201b7a754e5" +checksum = "e57fef6bc5981e38c2ce2d63bfa546861309f875b8a75f092d1d54ae2d64f266" dependencies = [ "siphasher", ] [[package]] name = "pin-project" -version = "1.1.10" +version = "1.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "677f1add503faace112b9f1373e43e9e054bfdd22ff1a63c1bc485eaec6a6a8a" +checksum = "f1749c7ed4bcaf4c3d0a3efc28538844fb29bcdd7d2b67b2be7e20ba861ff517" dependencies = [ "pin-project-internal", ] [[package]] name = "pin-project-internal" -version = "1.1.10" +version = "1.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e918e4ff8c4549eb882f14b3a4bc8c8bc93de829416eacf579f1207a8fbf861" +checksum = 
"d9b20ed30f105399776b9c883e68e536ef602a16ae6f596d2c473591d6ad64c6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.117", ] [[package]] name = "pin-project-lite" -version = "0.2.16" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" +checksum = "a89322df9ebe1c1578d689c92318e070967d1042b512afbe49518723f4e6d5cd" [[package]] name = "pin-utils" @@ -4630,6 +4638,12 @@ version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" +[[package]] +name = "plain" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4596b6d070b27117e987119b4dac604f3c58cfb0b191112e24771b2faeac1a6" + [[package]] name = "plotters" version = "0.3.7" @@ -4660,36 +4674,36 @@ dependencies = [ [[package]] name = "portable-atomic" -version = "1.11.0" +version = "1.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "350e9b48cbc6b0e028b0473b114454c6316e57336ee184ceab6e53f72c178b3e" +checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49" [[package]] name = "portable-atomic-util" -version = "0.2.4" +version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8a2f0d8d040d7848a709caf78912debcc3f33ee4b3cac47d73d1e1069e83507" +checksum = "091397be61a01d4be58e7841595bd4bfedb15f1cd54977d79b8271e94ed799a3" dependencies = [ "portable-atomic", ] [[package]] name = "postgres-derive" -version = "0.4.6" +version = "0.4.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69700ea4603c5ef32d447708e6a19cd3e8ac197a000842e97f527daea5e4175f" +checksum = "56df96f5394370d1b20e49de146f9e6c25aa9ae750f449c9d665eafecb3ccae6" dependencies = [ - "heck 0.5.0", + "heck", "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.117", ] [[package]] 
name = "postgres-protocol" -version = "0.6.8" +version = "0.6.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76ff0abab4a9b844b93ef7b81f1efc0a366062aaef2cd702c76256b5dc075c54" +checksum = "3ee9dd5fe15055d2b6806f4736aa0c9637217074e224bbec46d4041b91bb9491" dependencies = [ "base64 0.22.1", "byteorder", @@ -4698,16 +4712,16 @@ dependencies = [ "hmac", "md-5", "memchr", - "rand 0.9.1", + "rand 0.9.2", "sha2", "stringprep", ] [[package]] name = "postgres-types" -version = "0.2.9" +version = "0.2.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "613283563cd90e1dfc3518d548caee47e0e725455ed619881f5cf21f36de4b48" +checksum = "54b858f82211e84682fecd373f68e1ceae642d8d751a1ebd13f33de6257b3e20" dependencies = [ "bytes", "chrono", @@ -4718,9 +4732,9 @@ dependencies = [ [[package]] name = "potential_utf" -version = "0.1.2" +version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5a7c30837279ca13e7c867e9e40053bc68740f988cb07f7ca6df43cc734b585" +checksum = "b73949432f5e2a09657003c25bca5e19a0e9c84f8058ca374f49e0ebe605af77" dependencies = [ "zerovec", ] @@ -4741,92 +4755,48 @@ dependencies = [ ] [[package]] -name = "predicates" -version = "3.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a5d19ee57562043d37e82899fade9a22ebab7be9cef5026b07fda9cdd4293573" -dependencies = [ - "anstyle", - "difflib", - "float-cmp", - "normalize-line-endings", - "predicates-core", - "regex", -] - -[[package]] -name = "predicates-core" -version = "1.0.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "727e462b119fe9c93fd0eb1429a5f7647394014cf3c04ab2c0350eeb09095ffa" - -[[package]] -name = "predicates-tree" -version = "1.0.12" +name = "pretty_assertions" +version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72dd2d6d381dfb73a193c7fca536518d7caee39fc8503f74e7dc0be0531b425c" +checksum = 
"3ae130e2f271fbc2ac3a40fb1d07180839cdbbe443c7a27e1e3c13c5cac0116d" dependencies = [ - "predicates-core", - "termtree", + "diff", + "yansi", ] [[package]] name = "prettyplease" -version = "0.2.32" +version = "0.2.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "664ec5419c51e34154eec046ebcba56312d5a2fc3b09a06da188e1ad21afadf6" +checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" dependencies = [ "proc-macro2", - "syn 2.0.101", + "syn 2.0.117", ] [[package]] name = "proc-macro-crate" -version = "3.3.0" +version = "3.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "edce586971a4dfaa28950c6f18ed55e0406c1ab88bbce2c6f6293a7aaba73d35" +checksum = "e67ba7e9b2b56446f1d419b1d807906278ffa1a658a8a5d8a39dcb1f5a78614f" dependencies = [ "toml_edit", ] -[[package]] -name = "proc-macro-error" -version = "1.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c" -dependencies = [ - "proc-macro-error-attr", - "proc-macro2", - "quote", - "syn 1.0.109", - "version_check", -] - -[[package]] -name = "proc-macro-error-attr" -version = "1.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869" -dependencies = [ - "proc-macro2", - "quote", - "version_check", -] - [[package]] name = "proc-macro2" -version = "1.0.95" +version = "1.0.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02b3e5e68a3a1a02aad3ec490a98007cbc13c37cbe84a3cd7b8e406d76e7f778" +checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934" dependencies = [ "unicode-ident", ] [[package]] name = "prost" -version = "0.13.5" +version = "0.14.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2796faa41db3ec313a31f7624d9286acf277b52de526150b7e69f3debf891ee5" +checksum = 
"d2ea70524a2f82d518bce41317d0fae74151505651af45faf1ffbd6fd33f0568" dependencies = [ "bytes", "prost-derive", @@ -4834,42 +4804,41 @@ dependencies = [ [[package]] name = "prost-build" -version = "0.13.5" +version = "0.14.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be769465445e8c1474e9c5dac2018218498557af32d9ed057325ec9a41ae81bf" +checksum = "343d3bd7056eda839b03204e68deff7d1b13aba7af2b2fd16890697274262ee7" dependencies = [ - "heck 0.5.0", + "heck", "itertools 0.14.0", "log", "multimap", - "once_cell", - "petgraph 0.7.1", + "petgraph", "prettyplease", "prost", "prost-types", "regex", - "syn 2.0.101", + "syn 2.0.117", "tempfile", ] [[package]] name = "prost-derive" -version = "0.13.5" +version = "0.14.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a56d757972c98b346a9b766e3f02746cde6dd1cd1d1d563472929fdd74bec4d" +checksum = "27c6023962132f4b30eb4c172c91ce92d933da334c59c23cddee82358ddafb0b" dependencies = [ "anyhow", "itertools 0.14.0", "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.117", ] [[package]] name = "prost-types" -version = "0.13.5" +version = "0.14.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52c2c1bf36ddb1a1c396b3601a3cec27c2462e45f07c386894ec3ccf5332bd16" +checksum = "8991c4cbdb8bc5b11f0b074ffe286c30e523de90fee5ba8132f1399f23cb3dd7" dependencies = [ "prost", ] @@ -4885,96 +4854,14 @@ dependencies = [ [[package]] name = "psm" -version = "0.1.26" +version = "0.1.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e944464ec8536cd1beb0bbfd96987eb5e3b72f2ecdafdc5c769a37f1fa2ae1f" +checksum = "3852766467df634d74f0b2d7819bf8dc483a0eb2e3b0f50f756f9cfe8b0d18d8" dependencies = [ + "ar_archive_writer", "cc", ] -[[package]] -name = "ptr_meta" -version = "0.1.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0738ccf7ea06b608c10564b31debd4f5bc5e197fc8bfe088f68ae5ce81e7a4f1" -dependencies = [ - 
"ptr_meta_derive", -] - -[[package]] -name = "ptr_meta_derive" -version = "0.1.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16b845dbfca988fa33db069c0e230574d15a3088f147a87b64c7589eb662c9ac" -dependencies = [ - "proc-macro2", - "quote", - "syn 1.0.109", -] - -[[package]] -name = "pyo3" -version = "0.24.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5203598f366b11a02b13aa20cab591229ff0a89fd121a308a5df751d5fc9219" -dependencies = [ - "cfg-if", - "indoc", - "libc", - "memoffset", - "once_cell", - "portable-atomic", - "pyo3-build-config", - "pyo3-ffi", - "pyo3-macros", - "unindent", -] - -[[package]] -name = "pyo3-build-config" -version = "0.24.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "99636d423fa2ca130fa5acde3059308006d46f98caac629418e53f7ebb1e9999" -dependencies = [ - "once_cell", - "target-lexicon", -] - -[[package]] -name = "pyo3-ffi" -version = "0.24.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78f9cf92ba9c409279bc3305b5409d90db2d2c22392d443a87df3a1adad59e33" -dependencies = [ - "libc", - "pyo3-build-config", -] - -[[package]] -name = "pyo3-macros" -version = "0.24.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b999cb1a6ce21f9a6b147dcf1be9ffedf02e0043aec74dc390f3007047cecd9" -dependencies = [ - "proc-macro2", - "pyo3-macros-backend", - "quote", - "syn 2.0.101", -] - -[[package]] -name = "pyo3-macros-backend" -version = "0.24.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "822ece1c7e1012745607d5cf0bcb2874769f0f7cb34c4cde03b9358eb9ef911a" -dependencies = [ - "heck 0.5.0", - "proc-macro2", - "pyo3-build-config", - "quote", - "syn 2.0.101", -] - [[package]] name = "quad-rand" version = "0.2.3" @@ -4983,9 +4870,9 @@ checksum = "5a651516ddc9168ebd67b24afd085a718be02f8858fe406591b013d101ce2f40" [[package]] name = "quick-xml" -version = "0.37.5" +version = 
"0.38.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "331e97a1af0bf59823e6eadffe373d7b27f485be8748f71471c662c1f269b7fb" +checksum = "b66c2058c55a409d601666cffe35f04333cf1013010882cec174a7467cd4e21c" dependencies = [ "memchr", "serde", @@ -4993,19 +4880,19 @@ dependencies = [ [[package]] name = "quinn" -version = "0.11.8" +version = "0.11.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "626214629cda6781b6dc1d316ba307189c85ba657213ce642d9c77670f8202c8" +checksum = "b9e20a958963c291dc322d98411f541009df2ced7b5a4f2bd52337638cfccf20" dependencies = [ "bytes", "cfg_aliases", "pin-project-lite", "quinn-proto", "quinn-udp", - "rustc-hash 2.1.1", + "rustc-hash", "rustls", "socket2", - "thiserror 2.0.12", + "thiserror", "tokio", "tracing", "web-time", @@ -5013,20 +4900,20 @@ dependencies = [ [[package]] name = "quinn-proto" -version = "0.11.12" +version = "0.11.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49df843a9161c85bb8aae55f101bc0bac8bcafd637a620d9122fd7e0b2f7422e" +checksum = "434b42fec591c96ef50e21e886936e66d3cc3f737104fdb9b737c40ffb94c098" dependencies = [ "bytes", - "getrandom 0.3.3", + "getrandom 0.3.4", "lru-slab", - "rand 0.9.1", + "rand 0.9.2", "ring", - "rustc-hash 2.1.1", + "rustc-hash", "rustls", "rustls-pki-types", "slab", - "thiserror 2.0.12", + "thiserror", "tinyvec", "tracing", "web-time", @@ -5034,38 +4921,38 @@ dependencies = [ [[package]] name = "quinn-udp" -version = "0.5.12" +version = "0.5.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ee4e529991f949c5e25755532370b8af5d114acae52326361d68d47af64aa842" +checksum = "addec6a0dcad8a8d96a771f815f0eaf55f9d1805756410b39f5fa81332574cbd" dependencies = [ "cfg_aliases", "libc", "once_cell", "socket2", "tracing", - "windows-sys 0.59.0", + "windows-sys 0.60.2", ] [[package]] name = "quote" -version = "1.0.40" +version = "1.0.45" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "1885c039570dc00dcb4ff087a89e185fd56bae234ddc7f056a945bf36467248d" +checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924" dependencies = [ "proc-macro2", ] [[package]] name = "r-efi" -version = "5.2.0" +version = "5.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74765f6d916ee2faa39bc8e68e4f3ed8949b48cccdac59983d287a7cb71ce9c5" +checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" [[package]] -name = "radium" -version = "0.7.0" +name = "r-efi" +version = "6.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09" +checksum = "f8dcc9c7d52a811697d2151c701e0d08956f92b0e24136cf4cf27b57a6a0d9bf" [[package]] name = "radix_trie" @@ -5090,12 +4977,12 @@ dependencies = [ [[package]] name = "rand" -version = "0.9.1" +version = "0.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fbfd9d094a40bf3ae768db9361049ace4c0e04a4fd6b359518bd7b73a73dd97" +checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" dependencies = [ "rand_chacha 0.9.0", - "rand_core 0.9.3", + "rand_core 0.9.5", ] [[package]] @@ -5115,7 +5002,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" dependencies = [ "ppv-lite86", - "rand_core 0.9.3", + "rand_core 0.9.5", ] [[package]] @@ -5124,16 +5011,16 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ - "getrandom 0.2.16", + "getrandom 0.2.17", ] [[package]] name = "rand_core" -version = "0.9.3" +version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38" +checksum = "76afc826de14238e6e8c374ddcc1fa19e374fd8dd986b0d2af0d02377261d83c" dependencies = [ - "getrandom 0.3.3", + "getrandom 0.3.4", ] [[package]] @@ -5143,14 +5030,14 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6a8615d50dcf34fa31f7ab52692afec947c4dd0ab803cc87cb3b0b4570ff7463" dependencies = [ "num-traits", - "rand 0.9.1", + "rand 0.9.2", ] [[package]] name = "rayon" -version = "1.10.0" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" +checksum = "368f01d005bf8fd9b1206fb6fa653e6c4a81ceb1466406b81792d87c5677a58f" dependencies = [ "either", "rayon-core", @@ -5158,9 +5045,9 @@ dependencies = [ [[package]] name = "rayon-core" -version = "1.12.1" +version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" +checksum = "22e18b0f0062d30d4230b2e85ff77fdfe4326feb054b9783a3460d8435c8ab91" dependencies = [ "crossbeam-deque", "crossbeam-utils", @@ -5183,43 +5070,63 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "76009fbe0614077fc1a2ce255e3a1881a2e3a3527097d5dc6d8212c585e7e38b" dependencies = [ "quote", - "syn 2.0.101", + "syn 2.0.117", ] [[package]] name = "redox_syscall" -version = "0.3.5" +version = "0.5.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "567664f262709473930a4bf9e51bf2ebf3348f2e748ccc50dea20646858f8f29" +checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d" dependencies = [ - "bitflags 1.3.2", + "bitflags", ] [[package]] name = "redox_syscall" -version = "0.5.12" +version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "928fca9cf2aa042393a8325b9ead81d2f0df4cb12e1e24cef072922ccd99c5af" +checksum = 
"6ce70a74e890531977d37e532c34d45e9055d2409ed08ddba14529471ed0be16" dependencies = [ - "bitflags 2.9.1", + "bitflags", ] [[package]] name = "redox_users" -version = "0.5.0" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd6f9d3d47bdd2ad6945c5015a226ec6155d0bcdfd8f7cd29f86b71f8de99d2b" +checksum = "a4e608c6638b9c18977b00b475ac1f28d14e84b27d8d42f70e0bf1e3dec127ac" dependencies = [ - "getrandom 0.2.16", + "getrandom 0.2.17", "libredox", - "thiserror 2.0.12", + "thiserror", +] + +[[package]] +name = "ref-cast" +version = "1.0.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f354300ae66f76f1c85c5f84693f0ce81d747e2c3f21a45fef496d89c960bf7d" +dependencies = [ + "ref-cast-impl", +] + +[[package]] +name = "ref-cast-impl" +version = "1.0.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7186006dcb21920990093f30e3dea63b7d6e977bf1256be20c3563a5db070da" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", ] [[package]] name = "regex" -version = "1.11.1" +version = "1.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" +checksum = "e10754a14b9137dd7b1e3e5b0493cc9171fdd105e0ab477f51b72e7f3ac0e276" dependencies = [ "aho-corasick", "memchr", @@ -5229,9 +5136,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.9" +version = "0.4.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" +checksum = "6e1dd4122fc1595e8162618945476892eefca7b88c52820e74af6262213cae8f" dependencies = [ "aho-corasick", "memchr", @@ -5240,23 +5147,23 @@ dependencies = [ [[package]] name = "regex-lite" -version = "0.1.6" +version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"53a49587ad06b26609c52e423de037e7f57f20d53535d66e08c695f347df952a" +checksum = "cab834c73d247e67f4fae452806d17d3c7501756d98c8808d7c9c7aa7d18f973" [[package]] name = "regex-syntax" -version = "0.8.5" +version = "0.8.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" +checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a" [[package]] name = "regress" -version = "0.10.3" +version = "0.10.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78ef7fa9ed0256d64a688a3747d0fef7a88851c18a5e1d57f115f38ec2e09366" +checksum = "2057b2325e68a893284d1538021ab90279adac1139957ca2a74426c6f118fb48" dependencies = [ - "hashbrown 0.15.3", + "hashbrown 0.16.1", "memchr", ] @@ -5266,15 +5173,6 @@ version = "1.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ba39f3699c378cd8970968dcbff9c43159ea4cfbd88d43c00b22f2ef10a435d2" -[[package]] -name = "rend" -version = "0.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "71fe3824f5629716b1589be05dacd749f6aa084c87e00e016714a8cdfccc997c" -dependencies = [ - "bytecheck", -] - [[package]] name = "repr_offset" version = "0.2.2" @@ -5286,32 +5184,28 @@ dependencies = [ [[package]] name = "reqwest" -version = "0.12.15" +version = "0.12.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d19c46a6fdd48bc4dab94b6103fccc55d34c67cc0ad04653aad4ea2a07cd7bbb" +checksum = "eddd3ca559203180a307f12d114c268abf583f59b03cb906fd0b3ff8646c1147" dependencies = [ "base64 0.22.1", "bytes", "futures-core", "futures-util", "h2", - "http 1.3.1", + "http 1.4.0", "http-body 1.0.1", "http-body-util", "hyper", "hyper-rustls", "hyper-util", - "ipnet", "js-sys", "log", - "mime", - "once_cell", "percent-encoding", "pin-project-lite", "quinn", "rustls", "rustls-native-certs", - "rustls-pemfile", "rustls-pki-types", "serde", "serde_json", @@ 
-5320,14 +5214,14 @@ dependencies = [ "tokio", "tokio-rustls", "tokio-util", - "tower 0.5.2", + "tower", + "tower-http", "tower-service", "url", "wasm-bindgen", "wasm-bindgen-futures", "wasm-streams", "web-sys", - "windows-registry", ] [[package]] @@ -5338,64 +5232,28 @@ checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7" dependencies = [ "cc", "cfg-if", - "getrandom 0.2.16", + "getrandom 0.2.17", "libc", "untrusted", "windows-sys 0.52.0", ] -[[package]] -name = "rkyv" -version = "0.7.45" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9008cd6385b9e161d8229e1f6549dd23c3d022f132a2ea37ac3a10ac4935779b" -dependencies = [ - "bitvec", - "bytecheck", - "bytes", - "hashbrown 0.12.3", - "ptr_meta", - "rend", - "rkyv_derive", - "seahash", - "tinyvec", - "uuid", -] - -[[package]] -name = "rkyv_derive" -version = "0.7.45" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "503d1d27590a2b0a3a4ca4c94755aa2875657196ecbf401a42eff41d7de532c0" -dependencies = [ - "proc-macro2", - "quote", - "syn 1.0.109", -] - -[[package]] -name = "rle-decode-fast" -version = "1.0.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3582f63211428f83597b51b2ddb88e2a91a9d52d12831f9d08f5e624e8977422" - [[package]] name = "rstest" -version = "0.25.0" +version = "0.26.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6fc39292f8613e913f7df8fa892b8944ceb47c247b78e1b1ae2f09e019be789d" +checksum = "f5a3193c063baaa2a95a33f03035c8a72b83d97a54916055ba22d35ed3839d49" dependencies = [ "futures-timer", "futures-util", "rstest_macros", - "rustc_version", ] [[package]] name = "rstest_macros" -version = "0.25.0" +version = "0.26.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1f168d99749d307be9de54d23fd226628d99768225ef08f6ffb52e0182a27746" +checksum = "9c845311f0ff7951c5506121a9ad75aec44d083c31583b2ea5a30bcb0b0abba0" dependencies = [ 
"cfg-if", "glob", @@ -5405,7 +5263,7 @@ dependencies = [ "regex", "relative-path", "rustc_version", - "syn 2.0.101", + "syn 2.0.117", "unicode-ident", ] @@ -5417,38 +5275,9 @@ checksum = "b3a8fb4672e840a587a66fc577a5491375df51ddb88f2a2c2a792598c326fe14" dependencies = [ "quote", "rand 0.8.5", - "syn 2.0.101", -] - -[[package]] -name = "rust_decimal" -version = "1.37.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "faa7de2ba56ac291bd90c6b9bece784a52ae1411f9506544b3eae36dd2356d50" -dependencies = [ - "arrayvec", - "borsh", - "bytes", - "num-traits", - "postgres-types", - "rand 0.8.5", - "rkyv", - "serde", - "serde_json", + "syn 2.0.117", ] -[[package]] -name = "rustc-demangle" -version = "0.1.24" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" - -[[package]] -name = "rustc-hash" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" - [[package]] name = "rustc-hash" version = "2.1.1" @@ -5466,37 +5295,25 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.44" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154" -dependencies = [ - "bitflags 2.9.1", - "errno", - "libc", - "linux-raw-sys 0.4.15", - "windows-sys 0.59.0", -] - -[[package]] -name = "rustix" -version = "1.0.7" +version = "1.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c71e83d6afe7ff64890ec6b71d6a69bb8a610ab78ce364b3352876bb4c801266" +checksum = "b6fe4565b9518b83ef4f91bb47ce29620ca828bd32cb7e408f0062e9930ba190" dependencies = [ - "bitflags 2.9.1", + "bitflags", "errno", "libc", - "linux-raw-sys 0.9.4", - "windows-sys 0.59.0", + "linux-raw-sys", + "windows-sys 0.61.2", ] [[package]] name = "rustls" -version = "0.23.27" +version = 
"0.23.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "730944ca083c1c233a75c09f199e973ca499344a2b7ba9e755c457e86fb4a321" +checksum = "758025cb5fccfd3bc2fd74708fd4682be41d99e5dff73c377c0646c6012c73a4" dependencies = [ "aws-lc-rs", + "log", "once_cell", "ring", "rustls-pki-types", @@ -5507,9 +5324,9 @@ dependencies = [ [[package]] name = "rustls-native-certs" -version = "0.8.1" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7fcff2dd52b58a8d98a70243663a0d234c4e2b79235637849d15913394a247d3" +checksum = "612460d5f7bea540c490b2b6395d8e34a953e52b491accd6c86c8164c5932a63" dependencies = [ "openssl-probe", "rustls-pki-types", @@ -5517,20 +5334,11 @@ dependencies = [ "security-framework", ] -[[package]] -name = "rustls-pemfile" -version = "2.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dce314e5fee3f39953d46bb63bb8a46d40c2f8fb7cc5a3b6cab2bde9721d6e50" -dependencies = [ - "rustls-pki-types", -] - [[package]] name = "rustls-pki-types" -version = "1.12.0" +version = "1.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "229a4a4c221013e7e1f1a043678c5cc39fe5171437c88fb47151a21e6f5b5c79" +checksum = "be040f8b0a225e40375822a563fa9524378b9d63112f53e19ffff34df5d33fdd" dependencies = [ "web-time", "zeroize", @@ -5538,9 +5346,9 @@ dependencies = [ [[package]] name = "rustls-webpki" -version = "0.103.3" +version = "0.103.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e4a72fe2bcf7a6ac6fd7d0b9e5cb68aeb7d4c0a0271730218b3e92d43b4eb435" +checksum = "d7df23109aa6c1567d1c575b9952556388da57401e4ace1d15f79eedad0d8f53" dependencies = [ "aws-lc-rs", "ring", @@ -5550,17 +5358,17 @@ dependencies = [ [[package]] name = "rustversion" -version = "1.0.20" +version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eded382c5f5f786b989652c49544c4877d9f015cc22e145a5ea8ea66c2921cd2" 
+checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" [[package]] name = "rustyline" -version = "16.0.0" +version = "17.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62fd9ca5ebc709e8535e8ef7c658eb51457987e48c98ead2be482172accc408d" +checksum = "e902948a25149d50edc1a8e0141aad50f54e22ba83ff988cf8f7c9ef07f50564" dependencies = [ - "bitflags 2.9.1", + "bitflags", "cfg-if", "clipboard-win", "fd-lock", @@ -5568,19 +5376,19 @@ dependencies = [ "libc", "log", "memchr", - "nix", + "nix 0.30.1", "radix_trie", "unicode-segmentation", - "unicode-width 0.2.0", + "unicode-width 0.2.2", "utf8parse", - "windows-sys 0.59.0", + "windows-sys 0.60.2", ] [[package]] name = "ryu" -version = "1.0.20" +version = "1.0.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" +checksum = "9774ba4a74de5f7b1c1451ed6cd5285a32eddb5cccb8cc655a4e50009e06477f" [[package]] name = "same-file" @@ -5593,11 +5401,11 @@ dependencies = [ [[package]] name = "schannel" -version = "0.1.27" +version = "0.1.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1f29ebaa345f945cec9fbbc532eb307f0fdad8161f281b6369539c8d84876b3d" +checksum = "91c1b7e4904c873ef0710c1f407dde2e6287de2bebc1bbbf7d430bb7cbffd939" dependencies = [ - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] @@ -5612,6 +5420,30 @@ dependencies = [ "serde_json", ] +[[package]] +name = "schemars" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4cd191f9397d57d581cddd31014772520aa448f65ef991055d7f61582c65165f" +dependencies = [ + "dyn-clone", + "ref-cast", + "serde", + "serde_json", +] + +[[package]] +name = "schemars" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2b42f36aa1cd011945615b92222f6bf73c599a102a300334cd7f8dbeec726cc" +dependencies = [ + "dyn-clone", + 
"ref-cast", + "serde", + "serde_json", +] + [[package]] name = "schemars_derive" version = "0.8.22" @@ -5621,7 +5453,7 @@ dependencies = [ "proc-macro2", "quote", "serde_derive_internals", - "syn 2.0.101", + "syn 2.0.117", ] [[package]] @@ -5630,19 +5462,13 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" -[[package]] -name = "seahash" -version = "4.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1c107b6f4780854c8b126e228ea8869f4d7b71260f962fefb57b996b8959ba6b" - [[package]] name = "security-framework" -version = "3.2.0" +version = "3.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "271720403f46ca04f7ba6f55d438f8bd878d6b8ca0a1046e8228c4145bcbb316" +checksum = "b7f4bc775c73d9a02cde8bf7b2ec4c9d12743edf609006c7facc23998404cd1d" dependencies = [ - "bitflags 2.9.1", + "bitflags", "core-foundation", "core-foundation-sys", "libc", @@ -5651,9 +5477,9 @@ dependencies = [ [[package]] name = "security-framework-sys" -version = "2.14.0" +version = "2.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49db231d56a190491cb4aeda9527f1ad45345af50b0851622a7adb8c03b01c32" +checksum = "6ce2691df843ecc5d231c0b14ece2acc3efb62c0a398c7e1d875f3983ce020e3" dependencies = [ "core-foundation-sys", "libc", @@ -5661,11 +5487,12 @@ dependencies = [ [[package]] name = "semver" -version = "1.0.26" +version = "1.0.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56e6fa9c48d24d85fb3de5ad847117517440f6beceb7798af16b4a87d616b8d0" +checksum = "d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2" dependencies = [ "serde", + "serde_core", ] [[package]] @@ -5676,31 +5503,42 @@ checksum = "1bc711410fbe7399f390ca1c3b60ad0f53f80e95c5eb935e52268a0e2cd49acc" [[package]] name = "serde" -version = "1.0.219" +version = "1.0.228" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f0e2c6ed6606019b4e29e69dbaba95b11854410e5347d525002456dbbb786b6" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" dependencies = [ + "serde_core", "serde_derive", ] [[package]] name = "serde_bytes" -version = "0.11.17" +version = "0.11.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8437fd221bde2d4ca316d61b90e337e9e702b3820b87d63caa9ba6c02bd06d96" +checksum = "a5d440709e79d88e51ac01c4b72fc6cb7314017bb7da9eeff678aa94c10e3ea8" dependencies = [ "serde", + "serde_core", +] + +[[package]] +name = "serde_core" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +dependencies = [ + "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.219" +version = "1.0.228" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" dependencies = [ "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.117", ] [[package]] @@ -5711,19 +5549,21 @@ checksum = "18d26a20a969b9e3fdf2fc2d9f21eda6c40e2de84c9408bb5d3b05d499aae711" dependencies = [ "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.117", ] [[package]] name = "serde_json" -version = "1.0.140" +version = "1.0.149" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373" +checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" dependencies = [ + "indexmap 2.13.0", "itoa", "memchr", - "ryu", "serde", + "serde_core", + "zmij", ] [[package]] @@ -5734,19 +5574,19 @@ checksum = "175ee3e80ae9982737ca543e96133087cbd9a485eecc3bc4de9c1a37b47ea59c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.101", + "syn 
2.0.117", ] [[package]] name = "serde_tokenstream" -version = "0.2.2" +version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64060d864397305347a78851c51588fd283767e7e7589829e8121d65512340f1" +checksum = "d7c49585c52c01f13c5c2ebb333f14f6885d76daa768d8a037d28017ec538c69" dependencies = [ "proc-macro2", "quote", "serde", - "syn 2.0.101", + "syn 2.0.117", ] [[package]] @@ -5763,17 +5603,18 @@ dependencies = [ [[package]] name = "serde_with" -version = "3.12.0" +version = "3.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d6b6f7f2fcb69f747921f79f3926bd1e203fce4fef62c268dd3abfb6d86029aa" +checksum = "dd5414fad8e6907dbdd5bc441a50ae8d6e26151a03b1de04d89a5576de61d01f" dependencies = [ "base64 0.22.1", "chrono", "hex", "indexmap 1.9.3", - "indexmap 2.9.0", - "serde", - "serde_derive", + "indexmap 2.13.0", + "schemars 0.9.0", + "schemars 1.2.1", + "serde_core", "serde_json", "serde_with_macros", "time", @@ -5781,14 +5622,14 @@ dependencies = [ [[package]] name = "serde_with_macros" -version = "3.12.0" +version = "3.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d00caa5193a3c8362ac2b73be6b9e768aa5a4b2f721d8f4b339600c3cb51f8e" +checksum = "d3db8978e608f1fe7357e211969fd9abdcae80bac1ba7a3369bb7eb6b404eb65" dependencies = [ "darling", "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.117", ] [[package]] @@ -5797,13 +5638,24 @@ version = "0.9.34+deprecated" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6a8b1a1a2ebf674015cc02edccce75287f1a0130d394307b36743c2f5d504b47" dependencies = [ - "indexmap 2.9.0", + "indexmap 2.13.0", "itoa", "ryu", "serde", "unsafe-libyaml", ] +[[package]] +name = "sha1" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name 
= "sha2" version = "0.10.9" @@ -5832,13 +5684,20 @@ checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" [[package]] name = "signal-hook-registry" -version = "1.4.5" +version = "1.4.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9203b8055f63a2a00e2f593bb0510367fe707d7ff1e5c872de2f537b339e5410" +checksum = "c4db69cba1110affc0e9f7bcd48bbf87b3f4fc7c61fc9155afd4c469eb3d6c1b" dependencies = [ + "errno", "libc", ] +[[package]] +name = "simd-adler32" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e320a6c5ad31d271ad523dcf3ad13e2767ad8b1cb8f047f75a8aeaf8da139da2" + [[package]] name = "simdutf8" version = "0.1.5" @@ -5853,24 +5712,21 @@ checksum = "bbbb5d9659141646ae647b42fe094daf6c6192d1620870b449d9557f748b2daa" [[package]] name = "siphasher" -version = "1.0.1" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56199f7ddabf13fe5074ce809e7d3f42b42ae711800501b5b16ea82ad029c39d" +checksum = "b2aa850e253778c88a04c3d7323b043aeda9d3e30d5971937c1855769763678e" [[package]] name = "slab" -version = "0.4.9" +version = "0.4.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f92a496fb766b417c996b9c5e57daf2f7ad3b0bebe1ccfca4856390e3d3bb67" -dependencies = [ - "autocfg", -] +checksum = "0c790de23124f9ab44544d7ac05d60440adc586479ce501c1d6d7da3cd8c9cf5" [[package]] name = "smallvec" -version = "1.15.0" +version = "1.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8917285742e9f3e1683f0a9c4e6b57960b7314d0b08d30d1ecd426713ee2eee9" +checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" [[package]] name = "snap" @@ -5898,19 +5754,19 @@ dependencies = [ [[package]] name = "socket2" -version = "0.5.9" +version = "0.6.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"4f5fd57c80058a56cf5c777ab8a126398ece8e442983605d280a44ce79d0edef" +checksum = "3a766e1110788c36f4fa1c2b71b387a7815aa65f88ce0229841826633d93723e" dependencies = [ "libc", - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] name = "sqllogictest" -version = "0.28.2" +version = "0.29.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94181af64007792bd1ab6d22023fbe86c2ccc50c1031b5bac554b5d057597e7b" +checksum = "d03b2262a244037b0b510edbd25a8e6c9fb8d73ee0237fc6cc95a54c16f94a82" dependencies = [ "async-trait", "educe", @@ -5927,15 +5783,15 @@ dependencies = [ "similar", "subst", "tempfile", - "thiserror 2.0.12", + "thiserror", "tracing", ] [[package]] name = "sqlparser" -version = "0.55.0" +version = "0.61.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4521174166bac1ff04fe16ef4524c70144cd29682a45978978ca3d7f4e0be11" +checksum = "dbf5ea8d4d7c808e1af1cbabebca9a2abe603bcefc22294c5b95018d53200cb7" dependencies = [ "log", "recursive", @@ -5944,26 +5800,26 @@ dependencies = [ [[package]] name = "sqlparser_derive" -version = "0.3.0" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da5fc6819faabb412da764b99d3b713bb55083c11e7e0c00144d386cd6a1939c" +checksum = "a6dd45d8fc1c79299bfbb7190e42ccbbdf6a5f52e4a6ad98d92357ea965bd289" dependencies = [ "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.117", ] [[package]] name = "stable_deref_trait" -version = "1.2.0" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" +checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596" [[package]] name = "stacker" -version = "0.1.21" +version = "0.1.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cddb07e32ddb770749da91081d8d0ac3a16f1a569a18b20348cd371f5dead06b" +checksum = 
"08d74a23609d509411d10e2176dc2a4346e3b4aea2e7b1869f19fdedbc71c013" dependencies = [ "cc", "cfg-if", @@ -5972,12 +5828,6 @@ dependencies = [ "windows-sys 0.59.0", ] -[[package]] -name = "static_assertions" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" - [[package]] name = "stringprep" version = "0.1.5" @@ -6004,7 +5854,7 @@ dependencies = [ "proc-macro2", "quote", "structmeta-derive", - "syn 2.0.101", + "syn 2.0.117", ] [[package]] @@ -6015,50 +5865,43 @@ checksum = "152a0b65a590ff6c3da95cabe2353ee04e6167c896b28e3b14478c2636c922fc" dependencies = [ "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.117", ] [[package]] -name = "structopt" -version = "0.3.26" +name = "strum" +version = "0.27.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c6b5c64445ba8094a6ab0c3cd2ad323e07171012d9c98b0b15651daf1787a10" -dependencies = [ - "clap 2.34.0", - "lazy_static", - "structopt-derive", -] +checksum = "af23d6f6c1a224baef9d3f61e287d2761385a5b88fdab4eb4c6f11aeb54c4bcf" [[package]] -name = "structopt-derive" -version = "0.4.18" +name = "strum" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9628de9b8791db39ceda2b119bbe13134770b56c138ec1d3af810d045c04f9bd" + +[[package]] +name = "strum_macros" +version = "0.27.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dcb5ae327f9cc13b68763b5749770cb9e048a99bd9dfdfa58d0cf05d5f64afe0" +checksum = "7695ce3845ea4b33927c055a39dc438a45b059f7c1b3d91d38d10355fb8cbca7" dependencies = [ - "heck 0.3.3", - "proc-macro-error", + "heck", "proc-macro2", "quote", - "syn 1.0.109", + "syn 2.0.117", ] -[[package]] -name = "strum" -version = "0.26.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06" - [[package]] name = "strum_macros" 
-version = "0.26.4" +version = "0.28.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be" +checksum = "ab85eea0270ee17587ed4156089e10b9e6880ee688791d45a905f5b1ca36f664" dependencies = [ - "heck 0.5.0", + "heck", "proc-macro2", "quote", - "rustversion", - "syn 2.0.101", + "syn 2.0.117", ] [[package]] @@ -6073,13 +5916,14 @@ dependencies = [ [[package]] name = "substrait" -version = "0.56.0" +version = "0.63.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13de2e20128f2a018dab1cfa30be83ae069219a65968c6f89df66ad124de2397" +checksum = "e620ff4d5c02fd6f7752931aa74b16a26af66a63022cc1ad412c77edbe0bab47" dependencies = [ - "heck 0.5.0", - "pbjson", - "pbjson-build", + "heck", + "indexmap 2.13.0", + "pbjson 0.8.0", + "pbjson-build 0.8.0", "pbjson-types", "prettyplease", "prost", @@ -6087,12 +5931,12 @@ dependencies = [ "prost-types", "protobuf-src", "regress", - "schemars", + "schemars 0.8.22", "semver", "serde", "serde_json", "serde_yaml", - "syn 2.0.101", + "syn 2.0.117", "typify", "walkdir", ] @@ -6116,9 +5960,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.101" +version = "2.0.117" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ce2b7fc941b3a24138a0a7cf8e858bfc6a992e7978a068a5c760deb0ed43caf" +checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99" dependencies = [ "proc-macro2", "quote", @@ -6142,14 +5986,14 @@ checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2" dependencies = [ "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.117", ] [[package]] name = "sysinfo" -version = "0.35.1" +version = "0.38.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "79251336d17c72d9762b8b54be4befe38d2db56fbbc0241396d70f173c39d47a" +checksum = "92ab6a2f8bfe508deb3c6406578252e491d299cbbf3bc0529ecc3313aee4a52f" dependencies = [ "libc", 
"memchr", @@ -6159,37 +6003,19 @@ dependencies = [ "windows", ] -[[package]] -name = "tap" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" - -[[package]] -name = "target-lexicon" -version = "0.13.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e502f78cdbb8ba4718f566c418c52bc729126ffd16baee5baa718cf25dd5a69a" - [[package]] name = "tempfile" -version = "3.20.0" +version = "3.27.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8a64e3985349f2441a1a9ef0b853f869006c3855f2cda6862a94d26ebb9d6a1" +checksum = "32497e9a4c7b38532efcdebeef879707aa9f794296a4f0244f6f69e9bc8574bd" dependencies = [ "fastrand", - "getrandom 0.3.3", + "getrandom 0.4.2", "once_cell", - "rustix 1.0.7", - "windows-sys 0.59.0", + "rustix", + "windows-sys 0.61.2", ] -[[package]] -name = "termtree" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f50febec83f5ee1df3015341d8bd429f2d1cc62bcba7ea2076759d315084683" - [[package]] name = "test-utils" version = "0.1.0" @@ -6198,23 +6024,26 @@ dependencies = [ "chrono-tz", "datafusion-common", "env_logger", - "rand 0.9.1", + "rand 0.9.2", ] [[package]] name = "testcontainers" -version = "0.24.0" +version = "0.27.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23bb7577dca13ad86a78e8271ef5d322f37229ec83b8d98da6d996c588a1ddb1" +checksum = "c1c0624faaa317c56d6d19136580be889677259caf5c897941c6f446b4655068" dependencies = [ + "astral-tokio-tar", "async-trait", "bollard", - "bollard-stubs", "bytes", "docker_credential", "either", "etcetera", + "ferroid", "futures", + "http 1.4.0", + "itertools 0.14.0", "log", "memchr", "parse-display", @@ -6222,80 +6051,49 @@ dependencies = [ "serde", "serde_json", "serde_with", - "thiserror 2.0.12", + "thiserror", "tokio", "tokio-stream", - "tokio-tar", "tokio-util", "url", ] 
[[package]] name = "testcontainers-modules" -version = "0.12.1" +version = "0.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eac95cde96549fc19c6bf19ef34cc42bd56e264c1cb97e700e21555be0ecf9e2" +checksum = "e5985fde5befe4ffa77a052e035e16c2da86e8bae301baa9f9904ad3c494d357" dependencies = [ "testcontainers", ] -[[package]] -name = "textwrap" -version = "0.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d326610f408c7a4eb6f51c37c330e496b08506c9457c9d34287ecc38809fb060" -dependencies = [ - "unicode-width 0.1.14", -] - [[package]] name = "thiserror" -version = "1.0.69" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" -dependencies = [ - "thiserror-impl 1.0.69", -] - -[[package]] -name = "thiserror" -version = "2.0.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "567b8a2dae586314f7be2a752ec7474332959c6460e02bde30d702a66d488708" -dependencies = [ - "thiserror-impl 2.0.12", -] - -[[package]] -name = "thiserror-impl" -version = "1.0.69" +version = "2.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" +checksum = "4288b5bcbc7920c07a1149a35cf9590a2aa808e0bc1eafaade0b80947865fbc4" dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.101", + "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "2.0.12" +version = "2.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d" +checksum = "ebc4ee7f67670e9b64d05fa4253e753e016c6c95ff35b89b7941d6b856dec1d5" dependencies = [ "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.117", ] [[package]] name = "thread_local" -version = "1.1.8" +version = "1.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"8b9ef9bad013ada3808854ceac7b46812a6465ba368859a37e2100283d2d719c" +checksum = "f60246a4944f24f6e018aa17cdeffb7818b76356965d03b07d6a9886e8962185" dependencies = [ "cfg-if", - "once_cell", ] [[package]] @@ -6311,30 +6109,30 @@ dependencies = [ [[package]] name = "time" -version = "0.3.41" +version = "0.3.47" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a7619e19bc266e0f9c5e6686659d394bc57973859340060a69221e57dbc0c40" +checksum = "743bd48c283afc0388f9b8827b976905fb217ad9e647fae3a379a9283c4def2c" dependencies = [ "deranged", "itoa", "num-conv", "powerfmt", - "serde", + "serde_core", "time-core", "time-macros", ] [[package]] name = "time-core" -version = "0.1.4" +version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c9e9a38711f559d9e3ce1cdb06dd7c5b8ea546bc90052da6d06bb76da74bb07c" +checksum = "7694e1cfe791f8d31026952abf09c69ca6f6fa4e1a1229e18988f06a04a12dca" [[package]] name = "time-macros" -version = "0.2.22" +version = "0.2.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3526739392ec93fd8b359c8e98514cb3e8e021beb4e5f597b00a0221f8ed8a49" +checksum = "2e70e4c5a0e0a8a4823ad65dfe1a6930e4f4d756dcd9dd7939022b5e8c501215" dependencies = [ "num-conv", "time-core", @@ -6351,9 +6149,9 @@ dependencies = [ [[package]] name = "tinystr" -version = "0.8.1" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d4f6d1145dcb577acf783d4e601bc1d76a13337bb54e6233add580b07344c8b" +checksum = "42d3e9c45c09de15d06dd8acf5f4e0e399e85927b7f00711024eb7ae10fa4869" dependencies = [ "displaydoc", "zerovec", @@ -6371,9 +6169,9 @@ dependencies = [ [[package]] name = "tinyvec" -version = "1.9.0" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09b3661f17e86524eccd4371ab0429194e0d7c008abb45f7a7495b1719463c71" +checksum = "3e61e67053d25a4e82c844e8424039d9745781b3fc4f32b8d55ed50f5f667ef3" dependencies = [ 
"tinyvec_macros", ] @@ -6386,11 +6184,10 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.45.1" +version = "1.50.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75ef51a33ef1da925cea3e4eb122833cb377c61439ca401b770f54902b806779" +checksum = "27ad5e34374e03cfffefc301becb44e9dc3c17584f414349ebe29ed26661822d" dependencies = [ - "backtrace", "bytes", "libc", "mio", @@ -6399,25 +6196,25 @@ dependencies = [ "signal-hook-registry", "socket2", "tokio-macros", - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] name = "tokio-macros" -version = "2.5.0" +version = "2.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8" +checksum = "5c55a2eff8b69ce66c84f85e1da1c233edc36ceb85a2058d11b0d6a3c7e7569c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.117", ] [[package]] name = "tokio-postgres" -version = "0.7.13" +version = "0.7.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c95d533c83082bb6490e0189acaa0bbeef9084e60471b696ca6988cd0541fb0" +checksum = "dcea47c8f71744367793f16c2db1f11cb859d28f436bdb4ca9193eb1f787ee42" dependencies = [ "async-trait", "byteorder", @@ -6428,11 +6225,11 @@ dependencies = [ "log", "parking_lot", "percent-encoding", - "phf", + "phf 0.13.1", "pin-project-lite", "postgres-protocol", "postgres-types", - "rand 0.9.1", + "rand 0.9.2", "socket2", "tokio", "tokio-util", @@ -6441,9 +6238,9 @@ dependencies = [ [[package]] name = "tokio-rustls" -version = "0.26.2" +version = "0.26.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e727b36a1a0e8b74c376ac2211e40c2c8af09fb4013c60d910495810f008e9b" +checksum = "1729aa945f29d91ba541258c8df89027d5792d85a8841fb65e8bf0f4ede4ef61" dependencies = [ "rustls", "tokio", @@ -6451,35 +6248,21 @@ dependencies = [ [[package]] name = 
"tokio-stream" -version = "0.1.17" +version = "0.1.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eca58d7bba4a75707817a2c44174253f9236b2d5fbd055602e9d5c07c139a047" +checksum = "32da49809aab5c3bc678af03902d4ccddea2a87d028d86392a4b1560c6906c70" dependencies = [ "futures-core", "pin-project-lite", "tokio", -] - -[[package]] -name = "tokio-tar" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d5714c010ca3e5c27114c1cdeb9d14641ace49874aa5626d7149e47aedace75" -dependencies = [ - "filetime", - "futures-core", - "libc", - "redox_syscall 0.3.5", - "tokio", - "tokio-stream", - "xattr", + "tokio-util", ] [[package]] name = "tokio-util" -version = "0.7.15" +version = "0.7.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "66a539a9ad6d5d281510d5bd368c973d636c02dbf8a67300bfb6b950696ad7df" +checksum = "9ae9cec805b01e8fc3fd2fe289f89149a9b66dd16786abd8b19cfa7b48cb0098" dependencies = [ "bytes", "futures-core", @@ -6490,34 +6273,46 @@ dependencies = [ [[package]] name = "toml_datetime" -version = "0.6.9" +version = "1.0.0+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3da5db5a963e24bc68be8b17b6fa82814bb22ee8660f192bb182771d498f09a3" +checksum = "32c2555c699578a4f59f0cc68e5116c8d7cabbd45e1409b989d4be085b53f13e" +dependencies = [ + "serde_core", +] [[package]] name = "toml_edit" -version = "0.22.26" +version = "0.25.4+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "310068873db2c5b3e7659d2cc35d21855dbafa50d1ce336397c666e3cb08137e" +checksum = "7193cbd0ce53dc966037f54351dbbcf0d5a642c7f0038c382ef9e677ce8c13f2" dependencies = [ - "indexmap 2.9.0", + "indexmap 2.13.0", "toml_datetime", + "toml_parser", + "winnow", +] + +[[package]] +name = "toml_parser" +version = "1.0.9+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"702d4415e08923e7e1ef96cd5727c0dfed80b4d2fa25db9647fe5eb6f7c5a4c4" +dependencies = [ "winnow", ] [[package]] name = "tonic" -version = "0.12.3" +version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "877c5b330756d856ffcc4553ab34a5684481ade925ecc54bcd1bf02b1d0d4d52" +checksum = "fec7c61a0695dc1887c1b53952990f3ad2e3a31453e1f49f10e75424943a93ec" dependencies = [ - "async-stream", "async-trait", "axum", "base64 0.22.1", "bytes", "h2", - "http 1.3.1", + "http 1.4.0", "http-body 1.0.1", "http-body-util", "hyper", @@ -6525,29 +6320,39 @@ dependencies = [ "hyper-util", "percent-encoding", "pin-project", - "prost", "socket2", + "sync_wrapper", "tokio", "tokio-stream", - "tower 0.4.13", + "tower", "tower-layer", "tower-service", "tracing", ] [[package]] -name = "tower" -version = "0.4.13" +name = "tonic-prost" +version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c" +checksum = "a55376a0bbaa4975a3f10d009ad763d8f4108f067c7c2e74f3001fb49778d309" dependencies = [ - "futures-core", - "futures-util", - "indexmap 1.9.3", - "pin-project", + "bytes", + "prost", + "tonic", +] + +[[package]] +name = "tower" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebe5ef63511595f1344e2d5cfa636d973292adc0eec1f0ad45fae9f0851ab1d4" +dependencies = [ + "futures-core", + "futures-util", + "indexmap 2.13.0", "pin-project-lite", - "rand 0.8.5", "slab", + "sync_wrapper", "tokio", "tokio-util", "tower-layer", @@ -6556,16 +6361,19 @@ dependencies = [ ] [[package]] -name = "tower" -version = "0.5.2" +name = "tower-http" +version = "0.6.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d039ad9159c98b70ecfd540b2573b97f7f52c3e8d9f8ad57a24b916a536975f9" +checksum = "d4e6559d53cc268e5031cd8429d05415bc4cb4aefc4aa5d6cc35fbf5b924a1f8" dependencies = [ - "futures-core", + "bitflags", + 
"bytes", "futures-util", + "http 1.4.0", + "http-body 1.0.1", + "iri-string", "pin-project-lite", - "sync_wrapper", - "tokio", + "tower", "tower-layer", "tower-service", ] @@ -6584,9 +6392,9 @@ checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" [[package]] name = "tracing" -version = "0.1.41" +version = "0.1.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" +checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" dependencies = [ "pin-project-lite", "tracing-attributes", @@ -6595,20 +6403,20 @@ dependencies = [ [[package]] name = "tracing-attributes" -version = "0.1.28" +version = "0.1.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "395ae124c09f9e6918a2310af6038fba074bcf474ac352496d5910dd59a2226d" +checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da" dependencies = [ "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.117", ] [[package]] name = "tracing-core" -version = "0.1.33" +version = "0.1.36" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e672c95779cf947c5311f83787af4fa8fffd12fb27e4993211a84bdfd9610f9c" +checksum = "db97caf9d906fbde555dd62fa95ddba9eecfd14cb388e4f491a66d74cd5fb79a" dependencies = [ "once_cell", "valuable", @@ -6627,9 +6435,9 @@ dependencies = [ [[package]] name = "tracing-subscriber" -version = "0.3.19" +version = "0.3.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8189decb5ac0fa7bc8b96b7cb9b2701d60d48805aca84a238004d665fcc4008" +checksum = "cb7f578e5945fb242538965c2d0b04418d38ec25c79d160cd279bf0731c8d319" dependencies = [ "nu-ansi-term", "sharded-slab", @@ -6662,19 +6470,9 @@ checksum = "e78122066b0cb818b8afd08f7ed22f7fdbc3e90815035726f0840d0d26c0747a" [[package]] name = "twox-hash" -version = "1.6.3" +version = "2.1.2" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "97fee6b57c6a41524a810daee9286c02d7752c4253064d0b05472833a438f675" -dependencies = [ - "cfg-if", - "static_assertions", -] - -[[package]] -name = "twox-hash" -version = "2.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7b17f197b3050ba473acf9181f7b1d3b66d1cf7356c6cc57886662276e65908" +checksum = "9ea3136b675547379c4bd395ca6b938e5ad3c3d20fad76e7fe85f9e0d011419c" [[package]] name = "typed-arena" @@ -6683,36 +6481,22 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6af6ae20167a9ece4bcb41af5b80f8a1f1df981f6391189ce00fd257af04126a" [[package]] -name = "typed-builder" -version = "0.19.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a06fbd5b8de54c5f7c91f6fe4cebb949be2125d7758e630bb58b1d831dbce600" -dependencies = [ - "typed-builder-macro", -] - -[[package]] -name = "typed-builder-macro" -version = "0.19.1" +name = "typenum" +version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f9534daa9fd3ed0bd911d462a37f172228077e7abf18c18a5f67199d959205f8" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.101", -] +checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" [[package]] -name = "typenum" -version = "1.18.0" +name = "typewit" +version = "1.14.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1dccffe3ce07af9386bfd29e80c0ab1a8205a2fc34e4bcd40364df902cfa8f3f" +checksum = "f8c1ae7cc0fdb8b842d65d127cb981574b0d2b249b74d1c7a2986863dc134f71" [[package]] name = "typify" -version = "0.4.1" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fcc5bec3cdff70fd542e579aa2e52967833e543a25fae0d14579043d2e868a50" +checksum = "e6d5bcc6f62eb1fa8aa4098f39b29f93dcb914e17158b76c50360911257aa629" dependencies = [ "typify-impl", "typify-macro", @@ -6720,38 +6504,38 @@ dependencies = [ 
[[package]] name = "typify-impl" -version = "0.4.1" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b52a67305054e1da6f3d99ad94875dcd0c7c49adbd17b4b64f0eefb7ae5bf8ab" +checksum = "a1eb359f7ffa4f9ebe947fa11a1b2da054564502968db5f317b7e37693cb2240" dependencies = [ - "heck 0.5.0", + "heck", "log", "proc-macro2", "quote", "regress", - "schemars", + "schemars 0.8.22", "semver", "serde", "serde_json", - "syn 2.0.101", - "thiserror 2.0.12", + "syn 2.0.117", + "thiserror", "unicode-ident", ] [[package]] name = "typify-macro" -version = "0.4.1" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ff5799be156e4f635c348c6051d165e1c59997827155133351a8c4d333d9841" +checksum = "911c32f3c8514b048c1b228361bebb5e6d73aeec01696e8cc0e82e2ffef8ab7a" dependencies = [ "proc-macro2", "quote", - "schemars", + "schemars 0.8.22", "semver", "serde", "serde_json", "serde_tokenstream", - "syn 2.0.101", + "syn 2.0.117", "typify-impl", ] @@ -6763,24 +6547,24 @@ checksum = "5c1cb5db39152898a79168971543b1cb5020dff7fe43c8dc468b0885f5e29df5" [[package]] name = "unicode-ident" -version = "1.0.18" +version = "1.0.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512" +checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" [[package]] name = "unicode-normalization" -version = "0.1.24" +version = "0.1.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5033c97c4262335cded6d6fc3e5c18ab755e1a3dc96376350f3d8e9f009ad956" +checksum = "5fd4f6878c9cb28d874b009da9e8d183b5abc80117c40bbd187a1fde336be6e8" dependencies = [ "tinyvec", ] [[package]] name = "unicode-properties" -version = "0.1.3" +version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e70f2a8b45122e719eb623c01822704c4e0907e7e426a05927e1a1cfff5b75d0" +checksum = 
"7df058c713841ad818f1dc5d3fd88063241cc61f49f5fbea4b951e8cf5a8d71d" [[package]] name = "unicode-segmentation" @@ -6796,15 +6580,21 @@ checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af" [[package]] name = "unicode-width" -version = "0.2.0" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fc81956842c57dac11422a97c3b8195a1ff727f06e85c84ed2e8aa277c9a0fd" +checksum = "b4ac048d71ede7ee76d585517add45da530660ef4390e49b098733c6e897f254" [[package]] -name = "unindent" -version = "0.2.4" +name = "unicode-xid" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" + +[[package]] +name = "unit-prefix" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3" +checksum = "81e544489bf3d8ef66c953931f56617f423cd4b5494be343d9b9d3dda037b9a3" [[package]] name = "unsafe-libyaml" @@ -6818,16 +6608,44 @@ version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" +[[package]] +name = "ureq" +version = "3.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdc97a28575b85cfedf2a7e7d3cc64b3e11bd8ac766666318003abbacc7a21fc" +dependencies = [ + "base64 0.22.1", + "log", + "percent-encoding", + "rustls", + "rustls-pki-types", + "ureq-proto", + "utf-8", +] + +[[package]] +name = "ureq-proto" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d81f9efa9df032be5934a46a068815a10a042b494b6a58cb0a1a97bb5467ed6f" +dependencies = [ + "base64 0.22.1", + "http 1.4.0", + "httparse", + "log", +] + [[package]] name = "url" -version = "2.5.4" +version = "2.5.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"32f8b686cadd1473f4bd0117a5d28d36b1ade384ea9b5069a1c40aefed7fda60" +checksum = "ff67a8a4397373c3ef660812acab3268222035010ab8680ec4215f38ba3d0eed" dependencies = [ "form_urlencoded", "idna", "percent-encoding", "serde", + "serde_derive", ] [[package]] @@ -6836,6 +6654,12 @@ version = "2.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" +[[package]] +name = "utf-8" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" + [[package]] name = "utf8_iter" version = "1.0.4" @@ -6850,13 +6674,13 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "uuid" -version = "1.17.0" +version = "1.22.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3cf4199d1e5d15ddd86a694e4d0dffa9c323ce759fea589f00fef9d81cc1931d" +checksum = "a68d3c8f01c0cfa54a75291d83601161799e4a89a39e0929f4b0354d88757a37" dependencies = [ - "getrandom 0.3.3", + "getrandom 0.4.2", "js-sys", - "serde", + "serde_core", "wasm-bindgen", ] @@ -6878,15 +6702,6 @@ version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5c3082ca00d5a5ef149bb8b555a72ae84c9c59f7250f013ac822ac2e49b19c64" -[[package]] -name = "wait-timeout" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09ac3b126d3914f9849036f826e054cbabdc8519970b8998ddaf3b5bd3c65f11" -dependencies = [ - "libc", -] - [[package]] name = "walkdir" version = "2.5.0" @@ -6908,58 +6723,67 @@ dependencies = [ [[package]] name = "wasi" -version = "0.11.0+wasi-snapshot-preview1" +version = "0.11.1+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +checksum = 
"ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" [[package]] name = "wasi" -version = "0.14.2+wasi-0.2.4" +version = "0.14.7+wasi-0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "883478de20367e224c0090af9cf5f9fa85bed63a95c1abf3afc5c083ebc06e8c" +dependencies = [ + "wasip2", +] + +[[package]] +name = "wasip2" +version = "1.0.2+wasi-0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9517f9239f02c069db75e65f174b3da828fe5f5b945c4dd26bd25d89c03ebcf5" +dependencies = [ + "wit-bindgen", +] + +[[package]] +name = "wasip3" +version = "0.4.0+wasi-0.3.0-rc-2026-01-06" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9683f9a5a998d873c0d21fcbe3c083009670149a8fab228644b8bd36b2c48cb3" +checksum = "5428f8bf88ea5ddc08faddef2ac4a67e390b88186c703ce6dbd955e1c145aca5" dependencies = [ - "wit-bindgen-rt", + "wit-bindgen", ] [[package]] name = "wasite" -version = "0.1.0" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8dad83b4f25e74f184f64c43b150b91efe7647395b42289f38e50566d82855b" +checksum = "66fe902b4a6b8028a753d5424909b764ccf79b7a209eac9bf97e59cda9f71a42" +dependencies = [ + "wasi 0.14.7+wasi-0.2.4", +] [[package]] name = "wasm-bindgen" -version = "0.2.100" +version = "0.2.114" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1edc8929d7499fc4e8f0be2262a241556cfc54a0bea223790e71446f2aab1ef5" +checksum = "6532f9a5c1ece3798cb1c2cfdba640b9b3ba884f5db45973a6f442510a87d38e" dependencies = [ "cfg-if", "once_cell", "rustversion", "wasm-bindgen-macro", -] - -[[package]] -name = "wasm-bindgen-backend" -version = "0.2.100" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f0a0651a5c2bc21487bde11ee802ccaf4c51935d0d3d42a6101f98161700bc6" -dependencies = [ - "bumpalo", - "log", - "proc-macro2", - "quote", - "syn 2.0.101", "wasm-bindgen-shared", ] [[package]] name = 
"wasm-bindgen-futures" -version = "0.4.50" +version = "0.4.64" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "555d470ec0bc3bb57890405e5d4322cc9ea83cebb085523ced7be4144dac1e61" +checksum = "e9c5522b3a28661442748e09d40924dfb9ca614b21c00d3fd135720e48b67db8" dependencies = [ "cfg-if", + "futures-util", "js-sys", "once_cell", "wasm-bindgen", @@ -6968,9 +6792,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.100" +version = "0.2.114" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7fe63fc6d09ed3792bd0897b314f53de8e16568c2b3f7982f468c0bf9bd0b407" +checksum = "18a2d50fcf105fb33bb15f00e7a77b772945a2ee45dcf454961fd843e74c18e6" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -6978,48 +6802,85 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.100" +version = "0.2.114" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ae87ea40c9f689fc23f209965b6fb8a99ad69aeeb0231408be24920604395de" +checksum = "03ce4caeaac547cdf713d280eda22a730824dd11e6b8c3ca9e42247b25c631e3" dependencies = [ + "bumpalo", "proc-macro2", "quote", - "syn 2.0.101", - "wasm-bindgen-backend", + "syn 2.0.117", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-shared" -version = "0.2.100" +version = "0.2.114" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a05d73b933a847d6cccdda8f838a22ff101ad9bf93e33684f39c1f5f0eece3d" +checksum = "75a326b8c223ee17883a4251907455a2431acc2791c98c26279376490c378c16" dependencies = [ "unicode-ident", ] [[package]] name = "wasm-bindgen-test" -version = "0.3.50" +version = "0.3.64" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "66c8d5e33ca3b6d9fa3b4676d774c5778031d27a578c2b007f905acf816152c3" +checksum = "6311c867385cc7d5602463b31825d454d0837a3aba7cdb5e56d5201792a3f7fe" dependencies = [ + "async-trait", + "cast", "js-sys", + "libm", "minicov", + "nu-ansi-term", 
+ "num-traits", + "oorandom", + "serde", + "serde_json", "wasm-bindgen", "wasm-bindgen-futures", "wasm-bindgen-test-macro", + "wasm-bindgen-test-shared", ] [[package]] name = "wasm-bindgen-test-macro" -version = "0.3.50" +version = "0.3.64" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17d5042cc5fa009658f9a7333ef24291b1291a25b6382dd68862a7f3b969f69b" +checksum = "67008cdde4769831958536b0f11b3bdd0380bde882be17fff9c2f34bb4549abd" dependencies = [ "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.117", +] + +[[package]] +name = "wasm-bindgen-test-shared" +version = "0.2.114" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfe29135b180b72b04c74aa97b2b4a2ef275161eff9a6c7955ea9eaedc7e1d4e" + +[[package]] +name = "wasm-encoder" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "990065f2fe63003fe337b932cfb5e3b80e0b4d0f5ff650e6985b1048f62c8319" +dependencies = [ + "leb128fmt", + "wasmparser", +] + +[[package]] +name = "wasm-metadata" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb0e353e6a2fbdc176932bbaab493762eb1255a7900fe0fea1a2f96c296cc909" +dependencies = [ + "anyhow", + "indexmap 2.13.0", + "wasm-encoder", + "wasmparser", ] [[package]] @@ -7035,11 +6896,23 @@ dependencies = [ "web-sys", ] +[[package]] +name = "wasmparser" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47b807c72e1bac69382b3a6fb3dbe8ea4c0ed87ff5629b8685ae6b9a611028fe" +dependencies = [ + "bitflags", + "hashbrown 0.15.5", + "indexmap 2.13.0", + "semver", +] + [[package]] name = "web-sys" -version = "0.3.77" +version = "0.3.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33b6dd2ef9186f1f2072e409e99cd22a975331a6b3591b12c764e0e55c60d5d2" +checksum = "854ba17bb104abfb26ba36da9729addc7ce7f06f5c0f90f3c391f8461cca21f9" dependencies = [ "js-sys", "wasm-bindgen", 
@@ -7055,25 +6928,15 @@ dependencies = [ "wasm-bindgen", ] -[[package]] -name = "which" -version = "4.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7" -dependencies = [ - "either", - "home", - "once_cell", - "rustix 0.38.44", -] - [[package]] name = "whoami" -version = "1.6.0" +version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6994d13118ab492c3c80c1f81928718159254c53c472bf9ce36f8dae4add02a7" +checksum = "d6a5b12f9df4f978d2cfdb1bd3bac52433f44393342d7ee9c25f5a1c14c0f45d" dependencies = [ - "redox_syscall 0.5.12", + "libc", + "libredox", + "objc2-system-configuration", "wasite", "web-sys", ] @@ -7096,11 +6959,11 @@ checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" [[package]] name = "winapi-util" -version = "0.1.9" +version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" +checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" dependencies = [ - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] @@ -7111,44 +6974,43 @@ checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" [[package]] name = "windows" -version = "0.61.1" +version = "0.62.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c5ee8f3d025738cb02bad7868bbb5f8a6327501e870bf51f1b455b0a2454a419" +checksum = "527fadee13e0c05939a6a05d5bd6eec6cd2e3dbd648b9f8e447c6518133d8580" dependencies = [ "windows-collections", "windows-core", "windows-future", - "windows-link", "windows-numerics", ] [[package]] name = "windows-collections" -version = "0.2.0" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3beeceb5e5cfd9eb1d76b381630e82c4241ccd0d27f1a39ed41b2760b255c5e8" +checksum = 
"23b2d95af1a8a14a3c7367e1ed4fc9c20e0a26e79551b1454d72583c97cc6610" dependencies = [ "windows-core", ] [[package]] name = "windows-core" -version = "0.61.2" +version = "0.62.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0fdd3ddb90610c7638aa2b3a3ab2904fb9e5cdbecc643ddb3647212781c4ae3" +checksum = "b8e83a14d34d0623b51dce9581199302a221863196a1dde71a7663a4c2be9deb" dependencies = [ "windows-implement", "windows-interface", "windows-link", "windows-result", - "windows-strings 0.4.2", + "windows-strings", ] [[package]] name = "windows-future" -version = "0.2.1" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc6a41e98427b19fe4b73c550f060b59fa592d7d686537eebf9385621bfbad8e" +checksum = "e1d6f90251fe18a279739e78025bd6ddc52a7e22f921070ccdc67dde84c605cb" dependencies = [ "windows-core", "windows-link", @@ -7157,96 +7019,94 @@ dependencies = [ [[package]] name = "windows-implement" -version = "0.60.0" +version = "0.60.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a47fddd13af08290e67f4acabf4b459f647552718f683a7b415d290ac744a836" +checksum = "053e2e040ab57b9dc951b72c264860db7eb3b0200ba345b4e4c3b14f67855ddf" dependencies = [ "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.117", ] [[package]] name = "windows-interface" -version = "0.59.1" +version = "0.59.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd9211b69f8dcdfa817bfd14bf1c97c9188afa36f4750130fcdf3f400eca9fa8" +checksum = "3f316c4a2570ba26bbec722032c4099d8c8bc095efccdc15688708623367e358" dependencies = [ "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.117", ] [[package]] name = "windows-link" -version = "0.1.1" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76840935b766e1b0a05c0066835fb9ec80071d4c09a16f6bd5f7e655e3c14c38" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" [[package]] name 
= "windows-numerics" -version = "0.2.0" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9150af68066c4c5c07ddc0ce30421554771e528bde427614c61038bc2c92c2b1" +checksum = "6e2e40844ac143cdb44aead537bbf727de9b044e107a0f1220392177d15b0f26" dependencies = [ "windows-core", "windows-link", ] [[package]] -name = "windows-registry" -version = "0.4.0" +name = "windows-result" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4286ad90ddb45071efd1a66dfa43eb02dd0dfbae1545ad6cc3c51cf34d7e8ba3" +checksum = "7781fa89eaf60850ac3d2da7af8e5242a5ea78d1a11c49bf2910bb5a73853eb5" dependencies = [ - "windows-result", - "windows-strings 0.3.1", - "windows-targets 0.53.0", + "windows-link", ] [[package]] -name = "windows-result" -version = "0.3.4" +name = "windows-strings" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56f42bd332cc6c8eac5af113fc0c1fd6a8fd2aa08a0119358686e5160d0586c6" +checksum = "7837d08f69c77cf6b07689544538e017c1bfcf57e34b4c0ff58e6c2cd3b37091" dependencies = [ "windows-link", ] [[package]] -name = "windows-strings" -version = "0.3.1" +name = "windows-sys" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87fa48cc5d406560701792be122a10132491cff9d0aeb23583cc2dcafc847319" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" dependencies = [ - "windows-link", + "windows-targets 0.52.6", ] [[package]] -name = "windows-strings" -version = "0.4.2" +name = "windows-sys" +version = "0.59.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56e6c93f3a0c3b36176cb1327a4958a0353d5d166c2a35cb268ace15e91d3b57" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" dependencies = [ - "windows-link", + "windows-targets 0.52.6", ] [[package]] name = "windows-sys" -version = "0.52.0" +version = "0.60.2" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +checksum = "f2f500e4d28234f72040990ec9d39e3a6b950f9f22d3dba18416c35882612bcb" dependencies = [ - "windows-targets 0.52.6", + "windows-targets 0.53.5", ] [[package]] name = "windows-sys" -version = "0.59.0" +version = "0.61.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" dependencies = [ - "windows-targets 0.52.6", + "windows-link", ] [[package]] @@ -7267,25 +7127,26 @@ dependencies = [ [[package]] name = "windows-targets" -version = "0.53.0" +version = "0.53.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1e4c7e8ceaaf9cb7d7507c974735728ab453b67ef8f18febdd7c11fe59dca8b" +checksum = "4945f9f551b88e0d65f3db0bc25c33b8acea4d9e41163edf90dcd0b19f9069f3" dependencies = [ - "windows_aarch64_gnullvm 0.53.0", - "windows_aarch64_msvc 0.53.0", - "windows_i686_gnu 0.53.0", - "windows_i686_gnullvm 0.53.0", - "windows_i686_msvc 0.53.0", - "windows_x86_64_gnu 0.53.0", - "windows_x86_64_gnullvm 0.53.0", - "windows_x86_64_msvc 0.53.0", + "windows-link", + "windows_aarch64_gnullvm 0.53.1", + "windows_aarch64_msvc 0.53.1", + "windows_i686_gnu 0.53.1", + "windows_i686_gnullvm 0.53.1", + "windows_i686_msvc 0.53.1", + "windows_x86_64_gnu 0.53.1", + "windows_x86_64_gnullvm 0.53.1", + "windows_x86_64_msvc 0.53.1", ] [[package]] name = "windows-threading" -version = "0.1.0" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b66463ad2e0ea3bbf808b7f1d371311c80e115c0b71d60efc142cafbcfb057a6" +checksum = "3949bd5b99cafdf1c7ca86b43ca564028dfe27d66958f2470940f73d86d75b37" dependencies = [ "windows-link", ] @@ -7298,9 +7159,9 @@ checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" 
[[package]] name = "windows_aarch64_gnullvm" -version = "0.53.0" +version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "86b8d5f90ddd19cb4a147a5fa63ca848db3df085e25fee3cc10b39b6eebae764" +checksum = "a9d8416fa8b42f5c947f8482c43e7d89e73a173cead56d044f6a56104a6d1b53" [[package]] name = "windows_aarch64_msvc" @@ -7310,9 +7171,9 @@ checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" [[package]] name = "windows_aarch64_msvc" -version = "0.53.0" +version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7651a1f62a11b8cbd5e0d42526e55f2c99886c77e007179efff86c2b137e66c" +checksum = "b9d782e804c2f632e395708e99a94275910eb9100b2114651e04744e9b125006" [[package]] name = "windows_i686_gnu" @@ -7322,9 +7183,9 @@ checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" [[package]] name = "windows_i686_gnu" -version = "0.53.0" +version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1dc67659d35f387f5f6c479dc4e28f1d4bb90ddd1a5d3da2e5d97b42d6272c3" +checksum = "960e6da069d81e09becb0ca57a65220ddff016ff2d6af6a223cf372a506593a3" [[package]] name = "windows_i686_gnullvm" @@ -7334,9 +7195,9 @@ checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" [[package]] name = "windows_i686_gnullvm" -version = "0.53.0" +version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ce6ccbdedbf6d6354471319e781c0dfef054c81fbc7cf83f338a4296c0cae11" +checksum = "fa7359d10048f68ab8b09fa71c3daccfb0e9b559aed648a8f95469c27057180c" [[package]] name = "windows_i686_msvc" @@ -7346,9 +7207,9 @@ checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" [[package]] name = "windows_i686_msvc" -version = "0.53.0" +version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"581fee95406bb13382d2f65cd4a908ca7b1e4c2f1917f143ba16efe98a589b5d" +checksum = "1e7ac75179f18232fe9c285163565a57ef8d3c89254a30685b57d83a38d326c2" [[package]] name = "windows_x86_64_gnu" @@ -7358,9 +7219,9 @@ checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" [[package]] name = "windows_x86_64_gnu" -version = "0.53.0" +version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e55b5ac9ea33f2fc1716d1742db15574fd6fc8dadc51caab1c16a3d3b4190ba" +checksum = "9c3842cdd74a865a8066ab39c8a7a473c0778a3f29370b5fd6b4b9aa7df4a499" [[package]] name = "windows_x86_64_gnullvm" @@ -7370,9 +7231,9 @@ checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" [[package]] name = "windows_x86_64_gnullvm" -version = "0.53.0" +version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0a6e035dd0599267ce1ee132e51c27dd29437f63325753051e71dd9e42406c57" +checksum = "0ffa179e2d07eee8ad8f57493436566c7cc30ac536a3379fdf008f47f6bb7ae1" [[package]] name = "windows_x86_64_msvc" @@ -7382,51 +7243,121 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] name = "windows_x86_64_msvc" -version = "0.53.0" +version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486" +checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650" [[package]] name = "winnow" -version = "0.7.10" +version = "0.7.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c06928c8748d81b05c9be96aad92e1b6ff01833332f281e8cfca3be4b35fc9ec" +checksum = "df79d97927682d2fd8adb29682d1140b343be4ac0f08fd68b7765d9c059d3945" dependencies = [ "memchr", ] [[package]] -name = "wit-bindgen-rt" -version = "0.39.0" +name = "wit-bindgen" +version = "0.51.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"6f42320e61fe2cfd34354ecb597f86f413484a798ba44a8ca1165c58d42da6c1" +checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5" dependencies = [ - "bitflags 2.9.1", + "wit-bindgen-rust-macro", ] [[package]] -name = "writeable" -version = "0.6.1" +name = "wit-bindgen-core" +version = "0.51.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea2f10b9bb0928dfb1b42b65e1f9e36f7f54dbdf08457afefb38afcdec4fa2bb" +checksum = "ea61de684c3ea68cb082b7a88508a8b27fcc8b797d738bfc99a82facf1d752dc" +dependencies = [ + "anyhow", + "heck", + "wit-parser", +] [[package]] -name = "wyz" -version = "0.5.1" +name = "wit-bindgen-rust" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7c566e0f4b284dd6561c786d9cb0142da491f46a9fbed79ea69cdad5db17f21" +dependencies = [ + "anyhow", + "heck", + "indexmap 2.13.0", + "prettyplease", + "syn 2.0.117", + "wasm-metadata", + "wit-bindgen-core", + "wit-component", +] + +[[package]] +name = "wit-bindgen-rust-macro" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c0f9bfd77e6a48eccf51359e3ae77140a7f50b1e2ebfe62422d8afdaffab17a" +dependencies = [ + "anyhow", + "prettyplease", + "proc-macro2", + "quote", + "syn 2.0.117", + "wit-bindgen-core", + "wit-bindgen-rust", +] + +[[package]] +name = "wit-component" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d66ea20e9553b30172b5e831994e35fbde2d165325bec84fc43dbf6f4eb9cb2" +dependencies = [ + "anyhow", + "bitflags", + "indexmap 2.13.0", + "log", + "serde", + "serde_derive", + "serde_json", + "wasm-encoder", + "wasm-metadata", + "wasmparser", + "wit-parser", +] + +[[package]] +name = "wit-parser" +version = "0.244.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05f360fc0b24296329c78fda852a1e9ae82de9cf7b27dae4b7f62f118f77b9ed" +checksum = 
"ecc8ac4bc1dc3381b7f59c34f00b67e18f910c2c0f50015669dde7def656a736" dependencies = [ - "tap", + "anyhow", + "id-arena", + "indexmap 2.13.0", + "log", + "semver", + "serde", + "serde_derive", + "serde_json", + "unicode-xid", + "wasmparser", ] +[[package]] +name = "writeable" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9edde0db4769d2dc68579893f2306b26c6ecfbe0ef499b013d731b7b9247e0b9" + [[package]] name = "xattr" -version = "1.5.0" +version = "1.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d65cbf2f12c15564212d48f4e3dfb87923d25d611f2aed18f4cb23f0413d89e" +checksum = "32e45ad4206f6d2479085147f02bc2ef834ac85886624a23575ae137c8aa8156" dependencies = [ "libc", - "rustix 1.0.7", + "rustix", ] [[package]] @@ -7436,21 +7367,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "66fee0b777b0f5ac1c69bb06d361268faafa61cd4682ae064a171c16c433e9e4" [[package]] -name = "xz2" -version = "0.1.7" +name = "yansi" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "388c44dc09d76f1536602ead6d325eb532f5c122f17782bd57fb47baeeb767e2" -dependencies = [ - "lzma-sys", -] +checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049" [[package]] name = "yoke" -version = "0.8.0" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f41bb01b8226ef4bfd589436a297c53d118f65921786300e427be8d487695cc" +checksum = "72d6e5c6afb84d73944e5cedb052c4680d5657337201555f9f2a16b7406d4954" dependencies = [ - "serde", "stable_deref_trait", "yoke-derive", "zerofrom", @@ -7458,34 +7385,34 @@ dependencies = [ [[package]] name = "yoke-derive" -version = "0.8.0" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38da3c9736e16c5d3c8c597a9aaa5d1fa565d0532ae05e27c24aa62fb32c0ab6" +checksum = 
"b659052874eb698efe5b9e8cf382204678a0086ebf46982b79d6ca3182927e5d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.117", "synstructure", ] [[package]] name = "zerocopy" -version = "0.8.25" +version = "0.8.42" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1702d9583232ddb9174e01bb7c15a2ab8fb1bc6f227aa1233858c351a3ba0cb" +checksum = "f2578b716f8a7a858b7f02d5bd870c14bf4ddbbcf3a4c05414ba6503640505e3" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.8.25" +version = "0.8.42" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28a6e20d751156648aa063f3800b706ee209a32c0b4d9f24be3d980b01be55ef" +checksum = "7e6cc098ea4d3bd6246687de65af3f920c430e236bee1e3bf2e441463f08a02f" dependencies = [ "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.117", ] [[package]] @@ -7505,21 +7432,21 @@ checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502" dependencies = [ "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.117", "synstructure", ] [[package]] name = "zeroize" -version = "1.8.1" +version = "1.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" +checksum = "b97154e67e32c85465826e8bcc1c59429aaaf107c1e4a9e53c8d8ccd5eff88d0" [[package]] name = "zerotrie" -version = "0.2.2" +version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "36f0bbd478583f79edad978b407914f61b2972f5af6fa089686016be8f9af595" +checksum = "2a59c17a5562d507e4b54960e8569ebee33bee890c70aa3fe7b97e85a9fd7851" dependencies = [ "displaydoc", "yoke", @@ -7528,9 +7455,9 @@ dependencies = [ [[package]] name = "zerovec" -version = "0.11.2" +version = "0.11.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4a05eb080e015ba39cc9e23bbe5e7fb04d5fb040350f99f34e338d5fdd294428" +checksum = 
"6c28719294829477f525be0186d13efa9a3c602f7ec202ca9e353d310fb9a002" dependencies = [ "yoke", "zerofrom", @@ -7539,20 +7466,26 @@ dependencies = [ [[package]] name = "zerovec-derive" -version = "0.11.1" +version = "0.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b96237efa0c878c64bd89c436f661be4e46b2f3eff1ebb976f7ef2321d2f58f" +checksum = "eadce39539ca5cb3985590102671f2567e659fca9666581ad3411d59207951f3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.101", + "syn 2.0.117", ] [[package]] name = "zlib-rs" -version = "0.5.0" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3be3d40e40a133f9c916ee3f9f4fa2d9d63435b5fbe1bfc6d9dae0aa0ada1513" + +[[package]] +name = "zmij" +version = "1.0.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "868b928d7949e09af2f6086dfc1e01936064cc7a819253bce650d4e2a2d63ba8" +checksum = "b8848ee67ecc8aedbaf3e4122217aff892639231befc6a1b58d29fff4c2cabaa" [[package]] name = "zstd" @@ -7574,9 +7507,9 @@ dependencies = [ [[package]] name = "zstd-sys" -version = "2.0.15+zstd.1.5.7" +version = "2.0.16+zstd.1.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb81183ddd97d0c74cedf1d50d85c8d08c1b8b68ee863bdee9e706eedba1a237" +checksum = "91e19ebc2adc8f83e43039e79776e3fda8ca919132d68a1fed6a5faca2683748" dependencies = [ "cc", "pkg-config", diff --git a/Cargo.toml b/Cargo.toml index 79bb2f3cc602d..08d585d3ef906 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,7 @@ members = [ "datafusion/catalog", "datafusion/catalog-listing", "datafusion/datasource", + "datafusion/datasource-arrow", "datafusion/datasource-avro", "datafusion/datasource-csv", "datafusion/datasource-json", @@ -40,8 +41,10 @@ members = [ "datafusion/functions-window-common", "datafusion/optimizer", "datafusion/physical-expr", + "datafusion/physical-expr-adapter", "datafusion/physical-expr-common", "datafusion/physical-optimizer", + 
"datafusion/pruning", "datafusion/physical-plan", "datafusion/proto", "datafusion/proto/gen", @@ -68,15 +71,15 @@ resolver = "2" [workspace.package] authors = ["Apache DataFusion "] -edition = "2021" +edition = "2024" homepage = "https://datafusion.apache.org" license = "Apache-2.0" readme = "README.md" repository = "https://github.com/apache/datafusion" # Define Minimum Supported Rust Version (MSRV) -rust-version = "1.82.0" +rust-version = "1.88.0" # Define DataFusion version -version = "47.0.0" +version = "52.3.0" [workspace.dependencies] # We turn off default-features for some dependencies here so the workspaces which inherit them can @@ -84,100 +87,166 @@ version = "47.0.0" # for the inherited dependency but cannot do the reverse (override from true to false). # # See for more details: https://github.com/rust-lang/cargo/issues/11329 -ahash = { version = "0.8", default-features = false, features = [ - "runtime-rng", -] } -apache-avro = { version = "0.17", default-features = false } -arrow = { version = "55.1.0", features = [ +apache-avro = { version = "0.21", default-features = false } +arrow = { version = "58.0.0", features = [ "prettyprint", "chrono-tz", ] } -arrow-buffer = { version = "55.0.0", default-features = false } -arrow-flight = { version = "55.1.0", features = [ +arrow-buffer = { version = "58.0.0", default-features = false } +arrow-flight = { version = "58.0.0", features = [ "flight-sql-experimental", ] } -arrow-ipc = { version = "55.0.0", default-features = false, features = [ +arrow-ipc = { version = "58.0.0", default-features = false, features = [ "lz4", ] } -arrow-ord = { version = "55.0.0", default-features = false } -arrow-schema = { version = "55.0.0", default-features = false } -async-trait = "0.1.88" +arrow-ord = { version = "58.0.0", default-features = false } +arrow-schema = { version = "58.0.0", default-features = false } +async-trait = "0.1.89" bigdecimal = "0.4.8" -bytes = "1.10" -chrono = { version = "0.4.41", default-features = false 
} -criterion = "0.5.1" -ctor = "0.4.0" +bytes = "1.11" +bzip2 = "0.6.1" +chrono = { version = "0.4.44", default-features = false } +criterion = "0.8" +ctor = "0.6.3" dashmap = "6.0.1" -datafusion = { path = "datafusion/core", version = "47.0.0", default-features = false } -datafusion-catalog = { path = "datafusion/catalog", version = "47.0.0" } -datafusion-catalog-listing = { path = "datafusion/catalog-listing", version = "47.0.0" } -datafusion-common = { path = "datafusion/common", version = "47.0.0", default-features = false } -datafusion-common-runtime = { path = "datafusion/common-runtime", version = "47.0.0" } -datafusion-datasource = { path = "datafusion/datasource", version = "47.0.0", default-features = false } -datafusion-datasource-avro = { path = "datafusion/datasource-avro", version = "47.0.0", default-features = false } -datafusion-datasource-csv = { path = "datafusion/datasource-csv", version = "47.0.0", default-features = false } -datafusion-datasource-json = { path = "datafusion/datasource-json", version = "47.0.0", default-features = false } -datafusion-datasource-parquet = { path = "datafusion/datasource-parquet", version = "47.0.0", default-features = false } -datafusion-doc = { path = "datafusion/doc", version = "47.0.0" } -datafusion-execution = { path = "datafusion/execution", version = "47.0.0" } -datafusion-expr = { path = "datafusion/expr", version = "47.0.0" } -datafusion-expr-common = { path = "datafusion/expr-common", version = "47.0.0" } -datafusion-ffi = { path = "datafusion/ffi", version = "47.0.0" } -datafusion-functions = { path = "datafusion/functions", version = "47.0.0" } -datafusion-functions-aggregate = { path = "datafusion/functions-aggregate", version = "47.0.0" } -datafusion-functions-aggregate-common = { path = "datafusion/functions-aggregate-common", version = "47.0.0" } -datafusion-functions-nested = { path = "datafusion/functions-nested", version = "47.0.0" } -datafusion-functions-table = { path = 
"datafusion/functions-table", version = "47.0.0" } -datafusion-functions-window = { path = "datafusion/functions-window", version = "47.0.0" } -datafusion-functions-window-common = { path = "datafusion/functions-window-common", version = "47.0.0" } -datafusion-macros = { path = "datafusion/macros", version = "47.0.0" } -datafusion-optimizer = { path = "datafusion/optimizer", version = "47.0.0", default-features = false } -datafusion-physical-expr = { path = "datafusion/physical-expr", version = "47.0.0", default-features = false } -datafusion-physical-expr-common = { path = "datafusion/physical-expr-common", version = "47.0.0", default-features = false } -datafusion-physical-optimizer = { path = "datafusion/physical-optimizer", version = "47.0.0" } -datafusion-physical-plan = { path = "datafusion/physical-plan", version = "47.0.0" } -datafusion-proto = { path = "datafusion/proto", version = "47.0.0" } -datafusion-proto-common = { path = "datafusion/proto-common", version = "47.0.0" } -datafusion-session = { path = "datafusion/session", version = "47.0.0" } -datafusion-spark = { path = "datafusion/spark", version = "47.0.0" } -datafusion-sql = { path = "datafusion/sql", version = "47.0.0" } +datafusion = { path = "datafusion/core", version = "52.3.0", default-features = false } +datafusion-catalog = { path = "datafusion/catalog", version = "52.3.0" } +datafusion-catalog-listing = { path = "datafusion/catalog-listing", version = "52.3.0" } +datafusion-common = { path = "datafusion/common", version = "52.3.0", default-features = false } +datafusion-common-runtime = { path = "datafusion/common-runtime", version = "52.3.0" } +datafusion-datasource = { path = "datafusion/datasource", version = "52.3.0", default-features = false } +datafusion-datasource-arrow = { path = "datafusion/datasource-arrow", version = "52.3.0", default-features = false } +datafusion-datasource-avro = { path = "datafusion/datasource-avro", version = "52.3.0", default-features = false } 
+datafusion-datasource-csv = { path = "datafusion/datasource-csv", version = "52.3.0", default-features = false } +datafusion-datasource-json = { path = "datafusion/datasource-json", version = "52.3.0", default-features = false } +datafusion-datasource-parquet = { path = "datafusion/datasource-parquet", version = "52.3.0", default-features = false } +datafusion-doc = { path = "datafusion/doc", version = "52.3.0" } +datafusion-execution = { path = "datafusion/execution", version = "52.3.0", default-features = false } +datafusion-expr = { path = "datafusion/expr", version = "52.3.0", default-features = false } +datafusion-expr-common = { path = "datafusion/expr-common", version = "52.3.0" } +datafusion-ffi = { path = "datafusion/ffi", version = "52.3.0" } +datafusion-functions = { path = "datafusion/functions", version = "52.3.0" } +datafusion-functions-aggregate = { path = "datafusion/functions-aggregate", version = "52.3.0" } +datafusion-functions-aggregate-common = { path = "datafusion/functions-aggregate-common", version = "52.3.0" } +datafusion-functions-nested = { path = "datafusion/functions-nested", version = "52.3.0", default-features = false } +datafusion-functions-table = { path = "datafusion/functions-table", version = "52.3.0" } +datafusion-functions-window = { path = "datafusion/functions-window", version = "52.3.0" } +datafusion-functions-window-common = { path = "datafusion/functions-window-common", version = "52.3.0" } +datafusion-macros = { path = "datafusion/macros", version = "52.3.0" } +datafusion-optimizer = { path = "datafusion/optimizer", version = "52.3.0", default-features = false } +datafusion-physical-expr = { path = "datafusion/physical-expr", version = "52.3.0", default-features = false } +datafusion-physical-expr-adapter = { path = "datafusion/physical-expr-adapter", version = "52.3.0", default-features = false } +datafusion-physical-expr-common = { path = "datafusion/physical-expr-common", version = "52.3.0", default-features = false } 
+datafusion-physical-optimizer = { path = "datafusion/physical-optimizer", version = "52.3.0" } +datafusion-physical-plan = { path = "datafusion/physical-plan", version = "52.3.0" } +datafusion-proto = { path = "datafusion/proto", version = "52.3.0" } +datafusion-proto-common = { path = "datafusion/proto-common", version = "52.3.0" } +datafusion-pruning = { path = "datafusion/pruning", version = "52.3.0" } +datafusion-session = { path = "datafusion/session", version = "52.3.0" } +datafusion-spark = { path = "datafusion/spark", version = "52.3.0" } +datafusion-sql = { path = "datafusion/sql", version = "52.3.0" } +datafusion-substrait = { path = "datafusion/substrait", version = "52.3.0" } + doc-comment = "0.3" env_logger = "0.11" +flate2 = "1.1.9" futures = "0.3" -half = { version = "2.6.0", default-features = false } -hashbrown = { version = "0.14.5", features = ["raw"] } -indexmap = "2.9.0" +glob = "0.3.0" +half = { version = "2.7.0", default-features = false } +hashbrown = { version = "0.16.1" } +hex = { version = "0.4.3" } +indexmap = "2.13.0" +insta = { version = "1.46.3", features = ["glob", "filters"] } itertools = "0.14" +itoa = "1.0" +liblzma = { version = "0.4.6", features = ["static"] } log = "^0.4" -object_store = { version = "0.12.0", default-features = false } +memchr = "2.8.0" +num-traits = { version = "0.2" } +object_store = { version = "0.13.1", default-features = false } parking_lot = "0.12" -parquet = { version = "55.1.0", default-features = false, features = [ +parquet = { version = "58.0.0", default-features = false, features = [ "arrow", "async", "object_store", ] } -pbjson = { version = "0.7.0" } -pbjson-types = "0.7" +pbjson = { version = "0.9.0" } +pbjson-types = "0.9" # Should match arrow-flight's version of prost. 
-insta = { version = "1.43.1", features = ["glob", "filters"] } -prost = "0.13.1" +prost = "0.14.1" rand = "0.9" recursive = "0.1.1" -regex = "1.8" -rstest = "0.25.0" +regex = "1.12" +rstest = "0.26.1" serde_json = "1" -sqlparser = { version = "0.55.0", features = ["visitor"] } +sha2 = "^0.10.9" +sqlparser = { version = "0.61.0", default-features = false, features = ["std", "visitor"] } +strum = "0.28.0" +strum_macros = "0.28.0" tempfile = "3" -tokio = { version = "1.45", features = ["macros", "rt", "sync"] } -url = "2.5.4" +testcontainers-modules = { version = "0.15" } +tokio = { version = "1.48", features = ["macros", "rt", "sync"] } +tokio-stream = "0.1" +tokio-util = "0.7" +url = "2.5.7" +uuid = "1.21" +zstd = { version = "0.13", default-features = false } + +[workspace.lints.clippy] +# Detects large stack-allocated futures that may cause stack overflow crashes (see threshold in clippy.toml) +large_futures = "warn" +used_underscore_binding = "warn" +or_fun_call = "warn" +unnecessary_lazy_evaluations = "warn" +uninlined_format_args = "warn" +inefficient_to_string = "warn" +# https://github.com/apache/datafusion/issues/18503 +needless_pass_by_value = "warn" +# https://github.com/apache/datafusion/issues/18881 +allow_attributes = "warn" +assigning_clones = "warn" + +[workspace.lints.rust] +unexpected_cfgs = { level = "warn", check-cfg = [ + 'cfg(datafusion_coop, values("tokio", "tokio_fallback", "per_stream"))', + "cfg(tarpaulin)", + "cfg(tarpaulin_include)", +] } +unused_qualifications = "deny" +# -------------------- +# Compilation Profiles +# -------------------- +# A Cargo profile is a preset for the compiler/linker knobs that trade off: +# - Build time: how quickly code compiles and links +# - Runtime performance: how fast the resulting binaries execute +# - Binary size: how large the executables end up +# - Debuggability: how much debug information is preserved for debugging and profiling +# +# Profiles available: +# - dev: default debug build; fastest to 
compile, slowest to run, full debug info +# for everyday development. +# Run: cargo run +# - release: optimized build; slowest to compile, fastest to run, smallest +# binaries for public releases. +# Run: cargo run --release +# - release-nonlto: skips LTO, so it builds quicker while staying close to +# release performance. It is useful when developing performance optimizations. +# Run: cargo run --profile release-nonlto +# - profiling: inherits release optimizations but retains debug info to support +# profiling tools and flamegraphs. +# Run: cargo run --profile profiling +# - ci: derived from `dev` but disables incremental builds and strips dependency +# symbols to keep CI artifacts small and reproducible. +# Run: cargo run --profile ci +# +# If you want to optimize compilation, the `compile_profile` benchmark can be useful. +# See `benchmarks/README.md` for more details. [profile.release] codegen-units = 1 lto = true strip = true # Eliminate debug information to minimize binary size -# the release profile takes a long time to build so we can use this profile during development to save time -# cargo build --profile release-nonlto [profile.release-nonlto] codegen-units = 16 debug-assertions = false @@ -189,32 +258,27 @@ overflow-checks = false rpath = false strip = false # Retain debug info for flamegraphs +[profile.ci-optimized] +inherits = "release" +codegen-units = 16 +lto = "thin" +strip = true + [profile.ci] +debug = false inherits = "dev" incremental = false -# ci turns off debug info, etc. for dependencies to allow for smaller binaries making caching more effective +# This rule applies to every package except workspace members (dependencies +# such as `arrow` and `tokio`). It disables debug info and related features on +# dependencies so their binaries stay smaller, improving cache reuse. 
[profile.ci.package."*"] debug = false debug-assertions = false strip = "debuginfo" incremental = false -# release inherited profile keeping debug information and symbols -# for mem/cpu profiling [profile.profiling] inherits = "release" debug = true strip = false - -[workspace.lints.clippy] -# Detects large stack-allocated futures that may cause stack overflow crashes (see threshold in clippy.toml) -large_futures = "warn" -used_underscore_binding = "warn" -or_fun_call = "warn" -unnecessary_lazy_evaluations = "warn" -uninlined_format_args = "warn" - -[workspace.lints.rust] -unexpected_cfgs = { level = "warn", check-cfg = ["cfg(tarpaulin)"] } -unused_qualifications = "deny" diff --git a/NOTICE.txt b/NOTICE.txt index 7f3c80d606c07..0bd2d52368fea 100644 --- a/NOTICE.txt +++ b/NOTICE.txt @@ -1,5 +1,5 @@ Apache DataFusion -Copyright 2019-2025 The Apache Software Foundation +Copyright 2019-2026 The Apache Software Foundation This product includes software developed at The Apache Software Foundation (http://www.apache.org/). 
diff --git a/README.md b/README.md index c142d8f366b2e..630d4295bd427 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,7 @@ [![Build Status][actions-badge]][actions-url] ![Commit Activity][commit-activity-badge] [![Open Issues][open-issues-badge]][open-issues-url] +[![Pending PRs][pending-pr-badge]][pending-pr-url] [![Discord chat][discord-badge]][discord-url] [![Linkedin][linkedin-badge]][linkedin-url] ![Crates.io MSRV][msrv-badge] @@ -39,6 +40,8 @@ [commit-activity-badge]: https://img.shields.io/github/commit-activity/m/apache/datafusion [open-issues-badge]: https://img.shields.io/github/issues-raw/apache/datafusion [open-issues-url]: https://github.com/apache/datafusion/issues +[pending-pr-badge]: https://img.shields.io/github/issues-search/apache/datafusion?query=is%3Apr+is%3Aopen+draft%3Afalse+review%3Arequired+status%3Asuccess&label=Pending%20PRs&logo=github +[pending-pr-url]: https://github.com/apache/datafusion/pulls?q=is%3Apr+is%3Aopen+draft%3Afalse+review%3Arequired+status%3Asuccess+sort%3Aupdated-desc [linkedin-badge]: https://img.shields.io/badge/Follow-Linkedin-blue [linkedin-url]: https://www.linkedin.com/company/apache-datafusion/ [msrv-badge]: https://img.shields.io/crates/msrv/datafusion?label=Min%20Rust%20Version @@ -55,18 +58,16 @@ DataFusion is an extensible query engine written in [Rust] that uses [Apache Arrow] as its in-memory format. This crate provides libraries and binaries for developers building fast and -feature rich database and analytic systems, customized to particular workloads. +feature-rich database and analytic systems, customized for particular workloads. See [use cases] for examples. The following related subprojects target end users: - [DataFusion Python](https://github.com/apache/datafusion-python/) offers a Python interface for SQL and DataFrame queries. -- [DataFusion Ray](https://github.com/apache/datafusion-ray/) provides a distributed version of DataFusion that scales - out on Ray clusters. 
- [DataFusion Comet](https://github.com/apache/datafusion-comet/) is an accelerator for Apache Spark based on DataFusion. "Out of the box," -DataFusion offers [SQL] and [`Dataframe`] APIs, excellent [performance], +DataFusion offers [SQL](https://datafusion.apache.org/user-guide/sql/index.html) and [DataFrame](https://datafusion.apache.org/user-guide/dataframe.html) APIs, excellent [performance], built-in support for CSV, Parquet, JSON, and Avro, extensive customization, and a great community. @@ -83,7 +84,7 @@ See the [Architecture] section for more details. [performance]: https://benchmark.clickhouse.com/ [architecture]: https://datafusion.apache.org/contributor-guide/architecture.html -Here are links to some important information +Here are links to important resources: - [Project Site](https://datafusion.apache.org/) - [Installation](https://datafusion.apache.org/user-guide/cli/installation.html) @@ -96,8 +97,8 @@ Here are links to some important information ## What can you do with this crate? -DataFusion is great for building projects such as domain specific query engines, new database platforms and data pipelines, query languages and more. -It lets you start quickly from a fully working engine, and then customize those features specific to your use. [Click Here](https://datafusion.apache.org/user-guide/introduction.html#known-users) to see a list known users. +DataFusion is great for building projects such as domain-specific query engines, new database platforms and data pipelines, query languages and more. +It lets you start quickly from a fully working engine, and then customize those features specific to your needs. See the [list of known users](https://datafusion.apache.org/user-guide/introduction.html#known-users). ## Contributing to DataFusion @@ -114,14 +115,15 @@ This crate has several [features] which can be specified in your `Cargo.toml`. 
Default features: -- `nested_expressions`: functions for working with nested type function such as `array_to_string` +- `nested_expressions`: functions for working with nested types such as `array_to_string` - `compression`: reading files compressed with `xz2`, `bzip2`, `flate2`, and `zstd` - `crypto_expressions`: cryptographic functions such as `md5` and `sha256` - `datetime_expressions`: date and time functions such as `to_timestamp` - `encoding_expressions`: `encode` and `decode` functions - `parquet`: support for reading the [Apache Parquet] format +- `sql`: support for SQL parsing and planning - `regex_expressions`: regular expression functions, such as `regexp_match` -- `unicode_expressions`: Include unicode aware functions such as `character_length` +- `unicode_expressions`: include Unicode-aware functions such as `character_length` - `unparser`: enables support to reverse LogicalPlans back into SQL - `recursive_protection`: uses [recursive](https://docs.rs/recursive/latest/recursive/) for stack overflow protection. @@ -129,11 +131,12 @@ Optional features: - `avro`: support for reading the [Apache Avro] format - `backtrace`: include backtrace information in error messages -- `pyarrow`: conversions between PyArrow and DataFusion types +- `parquet_encryption`: support for using [Parquet Modular Encryption] - `serde`: enable arrow-schema's `serde` feature [apache avro]: https://avro.apache.org/ [apache parquet]: https://parquet.apache.org/ +[parquet modular encryption]: https://parquet.apache.org/docs/file-format/data-pages/encryption/ ## DataFusion API Evolution and Deprecation Guidelines @@ -141,7 +144,7 @@ Public methods in Apache DataFusion evolve over time: while we try to maintain a stable API, we also improve the API over time. As a result, we typically deprecate methods before removing them, according to the [deprecation guidelines]. 
-[deprecation guidelines]: https://datafusion.apache.org/library-user-guide/api-health.html +[deprecation guidelines]: https://datafusion.apache.org/contributor-guide/api-health.html ## Dependencies and `Cargo.lock` diff --git a/benchmarks/Cargo.toml b/benchmarks/Cargo.toml index f9c198597b74c..56f7704309780 100644 --- a/benchmarks/Cargo.toml +++ b/benchmarks/Cargo.toml @@ -26,6 +26,9 @@ repository = { workspace = true } license = { workspace = true } rust-version = { workspace = true } +# Note: add additional linter rules in lib.rs. +# Rust does not support workspace + new linter rules in subcrates yet +# https://github.com/rust-lang/cargo/issues/13157 [lints] workspace = true @@ -33,25 +36,29 @@ workspace = true ci = [] default = ["mimalloc"] snmalloc = ["snmalloc-rs"] +mimalloc_extended = ["libmimalloc-sys/extended"] [dependencies] arrow = { workspace = true } +async-trait = "0.1" +bytes = { workspace = true } +clap = { version = "4.5.60", features = ["derive"] } datafusion = { workspace = true, default-features = true } datafusion-common = { workspace = true, default-features = true } env_logger = { workspace = true } futures = { workspace = true } +libmimalloc-sys = { version = "0.1", optional = true } log = { workspace = true } mimalloc = { version = "0.1", optional = true, default-features = false } object_store = { workspace = true } parquet = { workspace = true, default-features = true } rand = { workspace = true } -serde = { version = "1.0.219", features = ["derive"] } +regex.workspace = true +serde = { version = "1.0.228", features = ["derive"] } serde_json = { workspace = true } snmalloc-rs = { version = "0.3", optional = true } -structopt = { version = "0.3", default-features = false } -test-utils = { path = "../test-utils/", version = "0.1.0" } tokio = { workspace = true, features = ["rt-multi-thread", "parking_lot"] } -tokio-util = { version = "0.7.15" } +tokio-util = { version = "0.7.17" } [dev-dependencies] datafusion-proto = { workspace = true } 
diff --git a/benchmarks/README.md b/benchmarks/README.md index b19b3385afc83..3aa4f4bb8640c 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -87,6 +87,38 @@ To run for specific query, for example Q21 ./bench.sh run tpch10 21 ``` +## Compile profile benchmark + +Generate the data required for the compile profile helper (TPC-H SF=1): + +```shell +./bench.sh data compile_profile +``` + +Run the benchmark across all default Cargo profiles (`dev`, `release`, `ci`, `release-nonlto`): + +```shell +./bench.sh run compile_profile +``` + +Limit the run to a single profile: + +```shell +./bench.sh run compile_profile dev +``` + +Or specify a subset of profiles: + +```shell +./bench.sh run compile_profile dev release +``` + +You can also invoke the helper directly if you need to customise arguments further: + +```shell +./benchmarks/compile_profile.py --profiles dev release --data /path/to/tpch_sf1 +``` + ## Benchmark with modified configurations ### Select join algorithm @@ -114,6 +146,19 @@ To verify that datafusion picked up your configuration, run the benchmarks with ## Comparing performance of main and a branch +For TPC-H +```shell +./benchmarks/compare_tpch.sh main mybranch +``` + +For TPC-DS. 
+To get data in `DATA_DIR` for TPCDS, please follow instructions in `./benchmarks/bench.sh data tpcds` +```shell +DATA_DIR=../../datafusion-benchmarks/tpcds/data/sf1/ ./benchmarks/compare_tpcds.sh main mybranch +``` + +Alternatively, you can compare manually following the example below + ```shell git checkout main @@ -195,6 +240,23 @@ Benchmark tpch_mem.json └──────────────┴──────────────┴──────────────┴───────────────┘ ``` +## Comparing performance of main and a PR + +### TPCDS + +Considering you already have TPCDS data locally + +```shell +export DATA_DIR=../../datafusion-benchmarks/tpcds/data/sf1/ +export PR_NUMBER=19464 +git fetch upstream pull/$PR_NUMBER/head:pr-$PR_NUMBER +git checkout main +git pull +./benchmarks/compare_tpcds.sh main pr-$PR_NUMBER +``` + +Note: if `gh` is installed, you can also run `gh pr checkout $PR_NUMBER` instead of `git fetch upstream pull/$PR_NUMBER/head:pr-$PR_NUMBER` + ### Running Benchmarks Manually Assuming data is in the `data` directory, the `tpch` benchmark can be run with a command like this: @@ -210,28 +272,11 @@ See the help for more details. You can enable `mimalloc` or `snmalloc` (to use either the mimalloc or snmalloc allocator) as features by passing them in as `--features`. For example: ```shell -cargo run --release --features "mimalloc" --bin tpch -- benchmark datafusion --iterations 3 --path ./data --format tbl --query 1 --batch-size 4096 -``` - -The benchmark program also supports CSV and Parquet input file formats and a utility is provided to convert from `tbl` -(generated by the `dbgen` utility) to CSV and Parquet. - -```bash -cargo run --release --bin tpch -- convert --input ./data --output /mnt/tpch-parquet --format parquet +cargo run --release --features "mimalloc" --bin dfbench tpch --iterations 3 --path ./data --format tbl --query 1 --batch-size 4096 ``` Or if you want to verify and run all the queries in the benchmark, you can just run `cargo test`.
-#### Sorted Conversion - -The TPCH tables generated by the dbgen utility are sorted by their first column (their primary key for most tables, the `l_orderkey` column for the `lineitem` table.) - -To preserve this sorted order information during conversion (useful for benchmarking execution on pre-sorted data) include the `--sort` flag: - -```bash -cargo run --release --bin tpch -- convert --input ./data --output /mnt/tpch-sorted-parquet --format parquet --sort -``` - ### Comparing results between runs Any `dfbench` execution with `-o ` argument will produce a @@ -321,6 +366,72 @@ FLAGS: ... ``` +# Profiling Memory Stats for each benchmark query + +The `mem_profile` program wraps benchmark execution to measure memory usage statistics, such as peak RSS. It runs each benchmark query in a separate subprocess, capturing the child process’s stdout to print structured output. + +Subcommands supported by mem_profile are the subset of those in `dfbench`. +Currently supported benchmarks include: Clickbench, H2o, Imdb, SortTpch, Tpch, TPCDS + +Before running benchmarks, `mem_profile` automatically compiles the benchmark binary (`dfbench`) using `cargo build`. Note that the build profile used for `dfbench` is not tied to the profile used for running `mem_profile` itself. We can explicitly specify the desired build profile using the `--bench-profile` option (e.g. release-nonlto). By prebuilding the binary and running each query in a separate process, we can ensure accurate memory statistics. + +Currently, `mem_profile` only supports `mimalloc` as the memory allocator, since it relies on `mimalloc`'s API to collect memory statistics. + +Because it runs the compiled binary directly from the target directory, make sure your working directory is the top-level datafusion/ directory, where the target/ is also located. + +The benchmark subcommand (e.g., `tpch`) and all following arguments are passed directly to `dfbench`. 
Be sure to specify `--bench-profile` before the benchmark subcommand. + +Example: + +```shell +datafusion$ cargo run --profile release-nonlto --bin mem_profile -- --bench-profile release-nonlto tpch --path benchmarks/data/tpch_sf1 --partitions 4 --format parquet +``` + +Example Output: + +``` +Query Time (ms) Peak RSS Peak Commit Major Page Faults +---------------------------------------------------------------- +1 503.42 283.4 MB 3.0 GB 0 +2 431.09 240.7 MB 3.0 GB 0 +3 594.28 350.1 MB 3.0 GB 0 +4 468.90 462.4 MB 3.0 GB 0 +5 653.58 385.4 MB 3.0 GB 0 +6 296.79 247.3 MB 2.0 GB 0 +7 662.32 652.4 MB 3.0 GB 0 +8 702.48 396.0 MB 3.0 GB 0 +9 774.21 611.5 MB 3.0 GB 0 +10 733.62 397.9 MB 3.0 GB 0 +11 271.71 209.6 MB 3.0 GB 0 +12 512.60 212.5 MB 2.0 GB 0 +13 507.83 381.5 MB 2.0 GB 0 +14 420.89 313.5 MB 3.0 GB 0 +15 539.97 288.0 MB 2.0 GB 0 +16 370.91 229.8 MB 3.0 GB 0 +17 758.33 467.0 MB 2.0 GB 0 +18 1112.32 638.9 MB 3.0 GB 0 +19 712.72 280.9 MB 2.0 GB 0 +20 620.64 402.9 MB 2.9 GB 0 +21 971.63 388.9 MB 2.9 GB 0 +22 404.50 164.8 MB 2.0 GB 0 +``` + +## Reported Metrics + +When running benchmarks, `mem_profile` collects several memory-related statistics using the mimalloc API: + +- Peak RSS (Resident Set Size): + The maximum amount of physical memory used by the process. + This is a process-level metric collected via OS-specific mechanisms and is not mimalloc-specific. + +- Peak Commit: + The peak amount of memory committed by the allocator (i.e., total virtual memory reserved). + This is mimalloc-specific. It gives a more allocator-aware view of memory usage than RSS. + +- Major Page Faults: + The number of major page faults triggered during execution. + This metric is obtained from the operating system and is not mimalloc-specific. 
+ # Writing a new benchmark ## Creating or downloading data outside of the benchmark @@ -379,37 +490,6 @@ Your benchmark should create and use an instance of `BenchmarkRun` defined in `b The output of `dfbench` help includes a description of each benchmark, which is reproduced here for convenience. -## Cancellation - -Test performance of cancelling queries. - -Queries in DataFusion should stop executing "quickly" after they are -cancelled (the output stream is dropped). - -The queries are executed on a synthetic dataset generated during -the benchmark execution that is an anonymized version of a -real-world data set. - -The query is an anonymized version of a real-world query, and the -test starts the query then cancels it and reports how long it takes -for the runtime to fully exit. - -Example output: - -``` -Using 7 files found on disk -Starting to load data into in-memory object store -Done loading data into in-memory object store -in main, sleeping -Starting spawned -Creating logical plan... -Creating physical plan... -Executing physical plan... -Getting results... -cancelling thread -done dropping runtime in 83.531417ms -``` - ## ClickBench The ClickBench[1] benchmarks are widely cited in the industry and @@ -510,6 +590,14 @@ See [`sort_tpch.rs`](src/sort_tpch.rs) for more details. ./bench.sh run sort_tpch ``` +### TopK TPCH + +In addition, topk_tpch is available from the bench.sh script: + +```bash +./bench.sh run topk_tpch +``` + ## IMDB Run Join Order Benchmark (JOB) on IMDB dataset. @@ -532,6 +620,34 @@ This benchmarks is derived from the [TPC-H][1] version [2]: https://github.com/databricks/tpch-dbgen.git, [2.17.1]: https://www.tpc.org/tpc_documents_current_versions/pdf/tpc-h_v2.17.1.pdf +## TPCDS + +Run the tpcds benchmark. + +For data please clone `datafusion-benchmarks` repo which contains the predefined parquet data with SF1. 
+ +```shell +git clone https://github.com/apache/datafusion-benchmarks +``` + +Then run the benchmark with the following command: + +```shell +DATA_DIR=../../datafusion-benchmarks/tpcds/data/sf1/ ./benchmarks/bench.sh run tpcds +``` + +Alternatively, benchmark a specific query: + +```shell +DATA_DIR=../../datafusion-benchmarks/tpcds/data/sf1/ ./benchmarks/bench.sh run tpcds 30 +``` + +More help + +```shell +cargo run --release --bin dfbench -- tpcds --help +``` + ## External Aggregation Run the benchmark for aggregations with limited memory. @@ -672,3 +788,115 @@ For example, to run query 1 with the small data generated above: ```bash cargo run --release --bin dfbench -- h2o --join-paths ./benchmarks/data/h2o/J1_1e7_NA_0.csv,./benchmarks/data/h2o/J1_1e7_1e1_0.csv,./benchmarks/data/h2o/J1_1e7_1e4_0.csv,./benchmarks/data/h2o/J1_1e7_1e7_NA.csv --queries-path ./benchmarks/queries/h2o/window.sql --query 1 ``` + +# Micro-Benchmarks + +## Nested Loop Join + +This benchmark focuses on the performance of queries with nested loop joins, minimizing other overheads such as scanning data sources or evaluating predicates. + +Different queries are included to test nested loop joins under various workloads. + +### Example Run + +```bash +# No need to generate data: this benchmark uses table function `range()` as the data source + +./bench.sh run nlj +``` + +## Hash Join + +This benchmark focuses on the performance of queries with hash joins, minimizing other overheads such as scanning data sources or evaluating predicates. + +Several queries are included to test hash joins under various workloads. + +### Example Run + +```bash +# No need to generate data: this benchmark uses table function `range()` as the data source + +./bench.sh run hj +``` + +## Sort Merge Join + +This benchmark focuses on the performance of queries with sort merge joins, minimizing other overheads such as scanning data sources or evaluating predicates. 
+ +Several queries are included to test sort merge joins under various workloads. + +### Example Run + +```bash +# No need to generate data: this benchmark uses table function `range()` as the data source + +./bench.sh run smj +``` +## Cancellation + +Test performance of cancelling queries. + +Queries in DataFusion should stop executing "quickly" after they are +cancelled (the output stream is dropped). + +The queries are executed on a synthetic dataset generated during +the benchmark execution that is an anonymized version of a +real-world data set. + +The query is an anonymized version of a real-world query, and the +test starts the query then cancels it and reports how long it takes +for the runtime to fully exit. + +Example output: + +``` +Using 7 files found on disk +Starting to load data into in-memory object store +Done loading data into in-memory object store +in main, sleeping +Starting spawned +Creating logical plan... +Creating physical plan... +Executing physical plan... +Getting results... +cancelling thread +done dropping runtime in 83.531417ms +``` + +## Sorted Data Benchmarks + +### Data Sorted ClickBench + +Benchmark for queries on pre-sorted data to test sort order optimization. +This benchmark uses a subset of the ClickBench dataset (hits.parquet, ~14GB) that has been pre-sorted by the EventTime column. The queries are designed to test DataFusion's performance when the data is already sorted as is common in timeseries workloads. + +The benchmark includes queries that: +- Scan pre-sorted data with ORDER BY clauses that match the sort order +- Test reverse scans on sorted data +- Verify the performance result + +#### Generating Sorted Data + +The sorted dataset is automatically generated from the ClickBench partitioned dataset. You can configure the memory used during the sorting process with the `DATAFUSION_MEMORY_GB` environment variable. The default memory limit is 12GB. 
+```bash +./bench.sh data clickbench_sorted +``` + +To create the sorted dataset, for example with 16GB of memory, run: + +```bash +DATAFUSION_MEMORY_GB=16 ./bench.sh data clickbench_sorted +``` + +This command will: +1. Download the ClickBench partitioned dataset if not present +2. Sort hits.parquet by EventTime in ascending order +3. Save the sorted file as hits_sorted.parquet + +#### Running the Benchmark + +```bash +./bench.sh run clickbench_sorted +``` + +This runs queries against the pre-sorted dataset with the `--sorted-by EventTime` flag, which informs DataFusion that the data is pre-sorted, allowing it to optimize away redundant sort operations. diff --git a/benchmarks/bench.sh b/benchmarks/bench.sh index 6f8cac2b6bfd5..0fc6ede3b3af4 100755 --- a/benchmarks/bench.sh +++ b/benchmarks/bench.sh @@ -28,6 +28,12 @@ set -e # https://stackoverflow.com/questions/59895/how-do-i-get-the-directory-where-a-bash-script-is-located-from-within-the-script SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +# Execute command and also print it, for debugging purposes +debug_run() { + set -x + "$@" + set +x +} # Set Defaults COMMAND= @@ -36,71 +42,113 @@ DATAFUSION_DIR=${DATAFUSION_DIR:-$SCRIPT_DIR/..} DATA_DIR=${DATA_DIR:-$SCRIPT_DIR/data} CARGO_COMMAND=${CARGO_COMMAND:-"cargo run --release"} PREFER_HASH_JOIN=${PREFER_HASH_JOIN:-true} -VIRTUAL_ENV=${VIRTUAL_ENV:-$SCRIPT_DIR/venv} +SIMULATE_LATENCY=${SIMULATE_LATENCY:-false} + +# Build latency arg based on SIMULATE_LATENCY setting +LATENCY_ARG="" +if [ "$SIMULATE_LATENCY" = "true" ]; then + LATENCY_ARG="--simulate-latency" +fi usage() { echo " Orchestrates running benchmarks against DataFusion checkouts Usage: -$0 data [benchmark] [query] -$0 run [benchmark] +$0 data [benchmark] +$0 run [benchmark] [query] $0 compare -$0 venv +$0 compare_detail -********** +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ Examples: -********** +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ # 
Create the datasets for all benchmarks in $DATA_DIR ./bench.sh data # Run the 'tpch' benchmark on the datafusion checkout in /source/datafusion DATAFUSION_DIR=/source/datafusion ./bench.sh run tpch -********** -* Commands -********** -data: Generates or downloads data needed for benchmarking -run: Runs the named benchmark -compare: Compares results from benchmark runs -venv: Creates new venv (unless already exists) and installs compare's requirements into it - -********** -* Benchmarks -********** +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +Commands +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +data: Generates or downloads data needed for benchmarking +run: Runs the named benchmark +compare: Compares fastest results from benchmark runs +compare_detail: Compares minimum, average (±stddev), and maximum results from benchmark runs + +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +Benchmarks +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +# Run all of the following benchmarks all(default): Data/Run/Compare for all benchmarks + +# TPC-H Benchmarks tpch: TPCH inspired benchmark on Scale Factor (SF) 1 (~1GB), single parquet file per table, hash join +tpch_csv: TPCH inspired benchmark on Scale Factor (SF) 1 (~1GB), single csv file per table, hash join tpch_mem: TPCH inspired benchmark on Scale Factor (SF) 1 (~1GB), query from memory tpch10: TPCH inspired benchmark on Scale Factor (SF) 10 (~10GB), single parquet file per table, hash join +tpch_csv10: TPCH inspired benchmark on Scale Factor (SF) 10 (~10GB), single csv file per table, hash join tpch_mem10: TPCH inspired benchmark on Scale Factor (SF) 10 (~10GB), query from memory -cancellation: How long cancelling a query takes -parquet: Benchmark of parquet reader's filtering speed -sort: Benchmark of sorting speed -sort_tpch: Benchmark of sorting speed for end-to-end sort queries on TPCH dataset + +# TPC-DS Benchmarks +tpcds: TPCDS inspired benchmark on Scale Factor (SF) 1 
(~1GB), single parquet file per table, hash join + +# Extended TPC-H Benchmarks +sort_tpch: Benchmark of sorting speed for end-to-end sort queries on TPC-H dataset (SF=1) +sort_tpch10: Benchmark of sorting speed for end-to-end sort queries on TPC-H dataset (SF=10) +topk_tpch: Benchmark of top-k (sorting with limit) queries on TPC-H dataset (SF=1) +external_aggr: External aggregation benchmark on TPC-H dataset (SF=1) + +# ClickBench Benchmarks clickbench_1: ClickBench queries against a single parquet file -clickbench_partitioned: ClickBench queries against a partitioned (100 files) parquet +clickbench_partitioned: ClickBench queries against partitioned (100 files) parquet +clickbench_pushdown: ClickBench queries against partitioned (100 files) parquet w/ filter_pushdown enabled clickbench_extended: ClickBench \"inspired\" queries against a single parquet (DataFusion specific) -external_aggr: External aggregation benchmark -h2o_small: h2oai benchmark with small dataset (1e7 rows) for groupby, default file format is csv -h2o_medium: h2oai benchmark with medium dataset (1e8 rows) for groupby, default file format is csv -h2o_big: h2oai benchmark with large dataset (1e9 rows) for groupby, default file format is csv -h2o_small_join: h2oai benchmark with small dataset (1e7 rows) for join, default file format is csv -h2o_medium_join: h2oai benchmark with medium dataset (1e8 rows) for join, default file format is csv -h2o_big_join: h2oai benchmark with large dataset (1e9 rows) for join, default file format is csv -h2o_small_window: Extended h2oai benchmark with small dataset (1e7 rows) for window, default file format is csv -h2o_medium_window: Extended h2oai benchmark with medium dataset (1e8 rows) for window, default file format is csv -h2o_big_window: Extended h2oai benchmark with large dataset (1e9 rows) for window, default file format is csv + +# Sorted Data Benchmarks (ORDER BY Optimization) +clickbench_sorted: ClickBench queries on pre-sorted data using 
prefer_existing_sort (tests sort elimination optimization) + +# H2O.ai Benchmarks (Group By, Join, Window) +h2o_small: h2oai benchmark with small dataset (1e7 rows) for groupby, default file format is csv +h2o_medium: h2oai benchmark with medium dataset (1e8 rows) for groupby, default file format is csv +h2o_big: h2oai benchmark with large dataset (1e9 rows) for groupby, default file format is csv +h2o_small_join: h2oai benchmark with small dataset (1e7 rows) for join, default file format is csv +h2o_medium_join: h2oai benchmark with medium dataset (1e8 rows) for join, default file format is csv +h2o_big_join: h2oai benchmark with large dataset (1e9 rows) for join, default file format is csv +h2o_small_window: Extended h2oai benchmark with small dataset (1e7 rows) for window, default file format is csv +h2o_medium_window: Extended h2oai benchmark with medium dataset (1e8 rows) for window, default file format is csv +h2o_big_window: Extended h2oai benchmark with large dataset (1e9 rows) for window, default file format is csv +h2o_small_parquet: h2oai benchmark with small dataset (1e7 rows) for groupby, file format is parquet +h2o_medium_parquet: h2oai benchmark with medium dataset (1e8 rows) for groupby, file format is parquet +h2o_big_parquet: h2oai benchmark with large dataset (1e9 rows) for groupby, file format is parquet +h2o_small_join_parquet: h2oai benchmark with small dataset (1e7 rows) for join, file format is parquet +h2o_medium_join_parquet: h2oai benchmark with medium dataset (1e8 rows) for join, file format is parquet +h2o_big_join_parquet: h2oai benchmark with large dataset (1e9 rows) for join, file format is parquet +h2o_small_window_parquet: Extended h2oai benchmark with small dataset (1e7 rows) for window, file format is parquet +h2o_medium_window_parquet: Extended h2oai benchmark with medium dataset (1e8 rows) for window, file format is parquet +h2o_big_window_parquet: Extended h2oai benchmark with large dataset (1e9 rows) for window, file format 
is parquet + +# Join Order Benchmark (IMDB) imdb: Join Order Benchmark (JOB) using the IMDB dataset converted to parquet -********** -* Supported Configuration (Environment Variables) -********** +# Micro-Benchmarks (specific operators and features) +cancellation: How long cancelling a query takes +nlj: Benchmark for simple nested loop joins, testing various join scenarios +hj: Benchmark for simple hash joins, testing various join scenarios +smj: Benchmark for simple sort merge joins, testing various join scenarios +compile_profile: Compile and execute TPC-H across selected Cargo profiles, reporting timing and binary size + + +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +Supported Configuration (Environment Variables) +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ DATA_DIR directory to store datasets CARGO_COMMAND command that runs the benchmark binary DATAFUSION_DIR directory to use (default $DATAFUSION_DIR) RESULTS_NAME folder where the benchmark files are stored PREFER_HASH_JOIN Prefer hash join algorithm (default true) -VENV_PATH Python venv to use for compare and venv commands (default ./venv, override by /bin/activate) +SIMULATE_LATENCY Simulate object store latency to mimic S3 (default false) DATAFUSION_* Set the given datafusion configuration " exit 1 @@ -152,8 +200,8 @@ main() { echo "***************************" case "$BENCHMARK" in all) - data_tpch "1" - data_tpch "10" + data_tpch "1" "parquet" + data_tpch "10" "parquet" data_h2o "SMALL" data_h2o "MEDIUM" data_h2o "BIG" @@ -163,20 +211,28 @@ main() { data_clickbench_1 data_clickbench_partitioned data_imdb + # nlj uses range() function, no data generation needed ;; tpch) - data_tpch "1" + data_tpch "1" "parquet" ;; tpch_mem) - # same data as for tpch - data_tpch "1" + data_tpch "1" "parquet" + ;; + tpch_csv) + data_tpch "1" "csv" ;; tpch10) - data_tpch "10" + data_tpch "10" "parquet" ;; tpch_mem10) - # same data as for tpch10 - data_tpch "10" + data_tpch "10" "parquet" + ;; + 
tpch_csv10) + data_tpch "10" "csv" + ;; + tpcds) + data_tpcds ;; clickbench_1) data_clickbench_1 @@ -184,6 +240,9 @@ main() { clickbench_partitioned) data_clickbench_partitioned ;; + clickbench_pushdown) + data_clickbench_partitioned # same data as clickbench_partitioned + ;; clickbench_extended) data_clickbench_1 ;; @@ -218,13 +277,66 @@ main() { h2o_big_window) data_h2o_join "BIG" "CSV" ;; + h2o_small_parquet) + data_h2o "SMALL" "PARQUET" + ;; + h2o_medium_parquet) + data_h2o "MEDIUM" "PARQUET" + ;; + h2o_big_parquet) + data_h2o "BIG" "PARQUET" + ;; + h2o_small_join_parquet) + data_h2o_join "SMALL" "PARQUET" + ;; + h2o_medium_join_parquet) + data_h2o_join "MEDIUM" "PARQUET" + ;; + h2o_big_join_parquet) + data_h2o_join "BIG" "PARQUET" + ;; + # h2o window benchmark uses the same data as the h2o join + h2o_small_window_parquet) + data_h2o_join "SMALL" "PARQUET" + ;; + h2o_medium_window_parquet) + data_h2o_join "MEDIUM" "PARQUET" + ;; + h2o_big_window_parquet) + data_h2o_join "BIG" "PARQUET" + ;; external_aggr) # same data as for tpch - data_tpch "1" + data_tpch "1" "parquet" ;; sort_tpch) # same data as for tpch - data_tpch "1" + data_tpch "1" "parquet" + ;; + sort_tpch10) + # same data as for tpch10 + data_tpch "10" "parquet" + ;; + topk_tpch) + # same data as for tpch + data_tpch "1" "parquet" + ;; + nlj) + # nlj uses range() function, no data generation needed + echo "NLJ benchmark does not require data generation" + ;; + hj) + data_tpch "10" "parquet" + ;; + smj) + # smj uses range() function, no data generation needed + echo "SMJ benchmark does not require data generation" + ;; + compile_profile) + data_tpch "1" "parquet" + ;; + clickbench_sorted) + clickbench_sorted ;; *) echo "Error: unknown benchmark '$BENCHMARK' for data generation" @@ -235,6 +347,18 @@ main() { run) # Parse positional parameters BENCHMARK=${ARG2:-"${BENCHMARK}"} + EXTRA_ARGS=("${POSITIONAL_ARGS[@]:2}") + PROFILE_ARGS=() + QUERY="" + QUERY_ARG="" + if [ "$BENCHMARK" = "compile_profile" ]; 
then + PROFILE_ARGS=("${EXTRA_ARGS[@]}") + else + QUERY=${EXTRA_ARGS[0]} + if [ -n "$QUERY" ]; then + QUERY_ARG="--query ${QUERY}" + fi + fi BRANCH_NAME=$(cd "${DATAFUSION_DIR}" && git rev-parse --abbrev-ref HEAD) BRANCH_NAME=${BRANCH_NAME//\//_} # mind blowing syntax to replace / with _ RESULTS_NAME=${RESULTS_NAME:-"${BRANCH_NAME}"} @@ -244,12 +368,18 @@ main() { echo "DataFusion Benchmark Script" echo "COMMAND: ${COMMAND}" echo "BENCHMARK: ${BENCHMARK}" + if [ "$BENCHMARK" = "compile_profile" ]; then + echo "PROFILES: ${PROFILE_ARGS[*]:-All}" + else + echo "QUERY: ${QUERY:-All}" + fi echo "DATAFUSION_DIR: ${DATAFUSION_DIR}" echo "BRANCH_NAME: ${BRANCH_NAME}" echo "DATA_DIR: ${DATA_DIR}" echo "RESULTS_DIR: ${RESULTS_DIR}" echo "CARGO_COMMAND: ${CARGO_COMMAND}" echo "PREFER_HASH_JOIN: ${PREFER_HASH_JOIN}" + echo "SIMULATE_LATENCY: ${SIMULATE_LATENCY}" echo "***************************" # navigate to the appropriate directory @@ -258,15 +388,16 @@ main() { mkdir -p "${DATA_DIR}" case "$BENCHMARK" in all) - run_tpch "1" + run_tpch "1" "parquet" + run_tpch "1" "csv" run_tpch_mem "1" - run_tpch "10" + run_tpch "10" "parquet" + run_tpch "10" "csv" run_tpch_mem "10" run_cancellation - run_parquet - run_sort run_clickbench_1 run_clickbench_partitioned + run_clickbench_pushdown run_clickbench_extended run_h2o "SMALL" "PARQUET" "groupby" run_h2o "MEDIUM" "PARQUET" "groupby" @@ -276,34 +407,44 @@ main() { run_h2o_join "BIG" "PARQUET" "join" run_imdb run_external_aggr + run_nlj + run_hj + run_tpcds + run_smj ;; tpch) - run_tpch "1" + run_tpch "1" "parquet" + ;; + tpch_csv) + run_tpch "1" "csv" ;; tpch_mem) run_tpch_mem "1" ;; tpch10) - run_tpch "10" + run_tpch "10" "parquet" + ;; + tpch_csv10) + run_tpch "10" "csv" ;; tpch_mem10) run_tpch_mem "10" ;; + tpcds) + run_tpcds + ;; cancellation) run_cancellation ;; - parquet) - run_parquet - ;; - sort) - run_sort - ;; clickbench_1) run_clickbench_1 ;; clickbench_partitioned) run_clickbench_partitioned ;; + clickbench_pushdown) + 
run_clickbench_pushdown + ;; clickbench_extended) run_clickbench_extended ;; @@ -334,14 +475,63 @@ main() { h2o_medium_window) run_h2o_window "MEDIUM" "CSV" "window" ;; - h2o_big_window) + h2o_big_window) run_h2o_window "BIG" "CSV" "window" ;; + h2o_small_parquet) + run_h2o "SMALL" "PARQUET" + ;; + h2o_medium_parquet) + run_h2o "MEDIUM" "PARQUET" + ;; + h2o_big_parquet) + run_h2o "BIG" "PARQUET" + ;; + h2o_small_join_parquet) + run_h2o_join "SMALL" "PARQUET" + ;; + h2o_medium_join_parquet) + run_h2o_join "MEDIUM" "PARQUET" + ;; + h2o_big_join_parquet) + run_h2o_join "BIG" "PARQUET" + ;; + # h2o window benchmark uses the same data as the h2o join + h2o_small_window_parquet) + run_h2o_window "SMALL" "PARQUET" + ;; + h2o_medium_window_parquet) + run_h2o_window "MEDIUM" "PARQUET" + ;; + h2o_big_window_parquet) + run_h2o_window "BIG" "PARQUET" + ;; external_aggr) run_external_aggr ;; sort_tpch) - run_sort_tpch + run_sort_tpch "1" + ;; + sort_tpch10) + run_sort_tpch "10" + ;; + topk_tpch) + run_topk_tpch + ;; + nlj) + run_nlj + ;; + hj) + run_hj + ;; + smj) + run_smj + ;; + compile_profile) + run_compile_profile "${PROFILE_ARGS[@]}" + ;; + clickbench_sorted) + run_clickbench_sorted ;; *) echo "Error: unknown benchmark '$BENCHMARK' for run" @@ -354,8 +544,8 @@ main() { compare) compare_benchmarks "$ARG2" "$ARG3" ;; - venv) - setup_venv + compare_detail) + compare_benchmarks "$ARG2" "$ARG3" "--detailed" ;; "") usage @@ -372,7 +562,7 @@ main() { # Creates TPCH data at a certain scale factor, if it doesn't already # exist # -# call like: data_tpch($scale_factor) +# call like: data_tpch($scale_factor, format) # # Creates data in $DATA_DIR/tpch_sf1 for scale factor 1 # Creates data in $DATA_DIR/tpch_sf10 for scale factor 10 @@ -383,20 +573,23 @@ data_tpch() { echo "Internal error: Scale factor not specified" exit 1 fi + FORMAT=$2 + if [ -z "$FORMAT" ] ; then + echo "Internal error: Format not specified" + exit 1 + fi TPCH_DIR="${DATA_DIR}/tpch_sf${SCALE_FACTOR}" - echo 
"Creating tpch dataset at Scale Factor ${SCALE_FACTOR} in ${TPCH_DIR}..." + echo "Creating tpch $FORMAT dataset at Scale Factor ${SCALE_FACTOR} in ${TPCH_DIR}..." # Ensure the target data directory exists mkdir -p "${TPCH_DIR}" - # Create 'tbl' (CSV format) data into $DATA_DIR if it does not already exist - FILE="${TPCH_DIR}/supplier.tbl" - if test -f "${FILE}"; then - echo " tbl files exist ($FILE exists)." - else - echo " creating tbl files with tpch_dbgen..." - docker run -v "${TPCH_DIR}":/data -it --rm ghcr.io/scalytics/tpch-docker:main -vf -s "${SCALE_FACTOR}" + # check if tpchgen-cli is installed + if ! command -v tpchgen-cli &> /dev/null + then + echo "tpchgen-cli could not be found, please install it via 'cargo install tpchgen-cli'" + exit 1 fi # Copy expected answers into the ./data/answers directory if it does not already exist @@ -409,16 +602,52 @@ data_tpch() { docker run -v "${TPCH_DIR}":/data -it --entrypoint /bin/bash --rm ghcr.io/scalytics/tpch-docker:main -c "cp -f /opt/tpch/2.18.0_rc2/dbgen/answers/* /data/answers/" fi - # Create 'parquet' files from tbl - FILE="${TPCH_DIR}/supplier" - if test -d "${FILE}"; then - echo " parquet files exist ($FILE exists)." - else - echo " creating parquet files using benchmark binary ..." - pushd "${SCRIPT_DIR}" > /dev/null - $CARGO_COMMAND --bin tpch -- convert --input "${TPCH_DIR}" --output "${TPCH_DIR}" --format parquet - popd > /dev/null + if [ "$FORMAT" = "parquet" ]; then + # Create 'parquet' files, one directory per file + FILE="${TPCH_DIR}/supplier" + if test -d "${FILE}"; then + echo " parquet files exist ($FILE exists)." + else + echo " creating parquet files using tpchgen-cli ..." 
+ tpchgen-cli --scale-factor "${SCALE_FACTOR}" --format parquet --parquet-compression='ZSTD(1)' --parts=1 --output-dir "${TPCH_DIR}" + fi + return + fi + + # Create 'csv' files, one directory per file + if [ "$FORMAT" = "csv" ]; then + FILE="${TPCH_DIR}/csv/supplier" + if test -d "${FILE}"; then + echo " csv files exist ($FILE exists)." + else + echo " creating csv files using tpchgen-cli binary ..." + tpchgen-cli --scale-factor "${SCALE_FACTOR}" --format csv --parts=1 --output-dir "${TPCH_DIR}/csv" + fi + return + fi + + echo "Error: unknown format '$FORMAT' for tpch data generation, expected 'parquet' or 'csv'" + exit 1 +} + +# Downloads TPC-DS data +data_tpcds() { + TPCDS_DIR="${DATA_DIR}/tpcds_sf1" + + # Check if `web_site.parquet` exists in the TPCDS data directory to verify data presence + echo "Checking TPC-DS data directory: ${TPCDS_DIR}" + if [ ! -f "${TPCDS_DIR}/web_site.parquet" ]; then + mkdir -p "${TPCDS_DIR}" + # Download the DataFusion benchmarks repository zip if it is not already downloaded + if [ ! -f "${DATA_DIR}/datafusion-benchmarks.zip" ]; then + echo "Downloading DataFusion benchmarks repository zip to: ${DATA_DIR}/datafusion-benchmarks.zip" + wget --timeout=30 --tries=3 -O "${DATA_DIR}/datafusion-benchmarks.zip" https://github.com/apache/datafusion-benchmarks/archive/refs/heads/main.zip + fi + echo "Extracting TPC-DS parquet data to ${TPCDS_DIR}..." + unzip -o -j -d "${TPCDS_DIR}" "${DATA_DIR}/datafusion-benchmarks.zip" datafusion-benchmarks-main/tpcds/data/sf1/* + echo "TPC-DS data extracted." fi + echo "Done." } # Runs the tpch benchmark @@ -433,15 +662,12 @@ run_tpch() { RESULTS_FILE="${RESULTS_DIR}/tpch_sf${SCALE_FACTOR}.json" echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running tpch benchmark..." 
- # Optional query filter to run specific query - QUERY=$([ -n "$ARG3" ] && echo "--query $ARG3" || echo "") - # debug the target command - set -x - $CARGO_COMMAND --bin tpch -- benchmark datafusion --iterations 5 --path "${TPCH_DIR}" --prefer_hash_join "${PREFER_HASH_JOIN}" --format parquet -o "${RESULTS_FILE}" $QUERY - set +x + + FORMAT=$2 + debug_run $CARGO_COMMAND --bin dfbench -- tpch --iterations 5 --path "${TPCH_DIR}" --prefer_hash_join "${PREFER_HASH_JOIN}" --format ${FORMAT} -o "${RESULTS_FILE}" ${QUERY_ARG} ${LATENCY_ARG} } -# Runs the tpch in memory +# Runs the tpch in memory (needs tpch parquet data) run_tpch_mem() { SCALE_FACTOR=$1 if [ -z "$SCALE_FACTOR" ] ; then @@ -453,37 +679,50 @@ run_tpch_mem() { RESULTS_FILE="${RESULTS_DIR}/tpch_mem_sf${SCALE_FACTOR}.json" echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running tpch_mem benchmark..." - # Optional query filter to run specific query - QUERY=$([ -n "$ARG3" ] && echo "--query $ARG3" || echo "") - # debug the target command - set -x # -m means in memory - $CARGO_COMMAND --bin tpch -- benchmark datafusion --iterations 5 --path "${TPCH_DIR}" --prefer_hash_join "${PREFER_HASH_JOIN}" -m --format parquet -o "${RESULTS_FILE}" $QUERY - set +x + debug_run $CARGO_COMMAND --bin dfbench -- tpch --iterations 5 --path "${TPCH_DIR}" --prefer_hash_join "${PREFER_HASH_JOIN}" -m --format parquet -o "${RESULTS_FILE}" ${QUERY_ARG} ${LATENCY_ARG} } -# Runs the cancellation benchmark -run_cancellation() { - RESULTS_FILE="${RESULTS_DIR}/cancellation.json" +# Runs the tpcds benchmark +run_tpcds() { + TPCDS_DIR="${DATA_DIR}/tpcds_sf1" + + # Check if TPCDS data directory and representative file exists + if [ ! -f "${TPCDS_DIR}/web_site.parquet" ]; then + echo "" >&2 + echo "Please prepare TPC-DS data first by following instructions:" >&2 + echo " ./bench.sh data tpcds" >&2 + echo "" >&2 + exit 1 + fi + + RESULTS_FILE="${RESULTS_DIR}/tpcds_sf1.json" echo "RESULTS_FILE: ${RESULTS_FILE}" - echo "Running cancellation benchmark..." 
- $CARGO_COMMAND --bin dfbench -- cancellation --iterations 5 --path "${DATA_DIR}/cancellation" -o "${RESULTS_FILE}" + echo "Running tpcds benchmark..." + + debug_run $CARGO_COMMAND --bin dfbench -- tpcds --iterations 5 --path "${TPCDS_DIR}" --query_path "../datafusion/core/tests/tpc-ds" --prefer_hash_join "${PREFER_HASH_JOIN}" -o "${RESULTS_FILE}" ${QUERY_ARG} ${LATENCY_ARG} } -# Runs the parquet filter benchmark -run_parquet() { - RESULTS_FILE="${RESULTS_DIR}/parquet.json" - echo "RESULTS_FILE: ${RESULTS_FILE}" - echo "Running parquet filter benchmark..." - $CARGO_COMMAND --bin parquet -- filter --path "${DATA_DIR}" --scale-factor 1.0 --iterations 5 -o "${RESULTS_FILE}" +# Runs the compile profile benchmark helper +run_compile_profile() { + local profiles=("$@") + local runner="${SCRIPT_DIR}/compile_profile.py" + local data_path="${DATA_DIR}/tpch_sf1" + + echo "Running compile profile benchmark..." + local cmd=(uv run python3 "${runner}" --data "${data_path}") + if [ ${#profiles[@]} -gt 0 ]; then + cmd+=(--profiles "${profiles[@]}") + fi + debug_run "${cmd[@]}" } -# Runs the sort benchmark -run_sort() { - RESULTS_FILE="${RESULTS_DIR}/sort.json" +# Runs the cancellation benchmark +run_cancellation() { + RESULTS_FILE="${RESULTS_DIR}/cancellation.json" echo "RESULTS_FILE: ${RESULTS_FILE}" - echo "Running sort benchmark..." - $CARGO_COMMAND --bin parquet -- sort --path "${DATA_DIR}" --scale-factor 1.0 --iterations 5 -o "${RESULTS_FILE}" + echo "Running cancellation benchmark..." + debug_run $CARGO_COMMAND --bin dfbench -- cancellation --iterations 5 --path "${DATA_DIR}/cancellation" -o "${RESULTS_FILE}" ${LATENCY_ARG} } @@ -537,23 +776,33 @@ run_clickbench_1() { RESULTS_FILE="${RESULTS_DIR}/clickbench_1.json" echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running clickbench (1 file) benchmark..." 
- $CARGO_COMMAND --bin dfbench -- clickbench --iterations 5 --path "${DATA_DIR}/hits.parquet" --queries-path "${SCRIPT_DIR}/queries/clickbench/queries.sql" -o "${RESULTS_FILE}" + debug_run $CARGO_COMMAND --bin dfbench -- clickbench --iterations 5 --path "${DATA_DIR}/hits.parquet" --queries-path "${SCRIPT_DIR}/queries/clickbench/queries" -o "${RESULTS_FILE}" ${QUERY_ARG} ${LATENCY_ARG} } - # Runs the clickbench benchmark with the partitioned parquet files + # Runs the clickbench benchmark with the partitioned parquet dataset (100 files) run_clickbench_partitioned() { RESULTS_FILE="${RESULTS_DIR}/clickbench_partitioned.json" echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running clickbench (partitioned, 100 files) benchmark..." - $CARGO_COMMAND --bin dfbench -- clickbench --iterations 5 --path "${DATA_DIR}/hits_partitioned" --queries-path "${SCRIPT_DIR}/queries/clickbench/queries.sql" -o "${RESULTS_FILE}" + debug_run $CARGO_COMMAND --bin dfbench -- clickbench --iterations 5 --path "${DATA_DIR}/hits_partitioned" --queries-path "${SCRIPT_DIR}/queries/clickbench/queries" -o "${RESULTS_FILE}" ${QUERY_ARG} ${LATENCY_ARG} +} + + + # Runs the clickbench benchmark with the partitioned parquet files and filter_pushdown enabled +run_clickbench_pushdown() { + RESULTS_FILE="${RESULTS_DIR}/clickbench_pushdown.json" + echo "RESULTS_FILE: ${RESULTS_FILE}" + echo "Running clickbench (partitioned, 100 files) benchmark with pushdown_filters=true, reorder_filters=true..." + debug_run $CARGO_COMMAND --bin dfbench -- clickbench --pushdown --iterations 5 --path "${DATA_DIR}/hits_partitioned" --queries-path "${SCRIPT_DIR}/queries/clickbench/queries" -o "${RESULTS_FILE}" ${QUERY_ARG} ${LATENCY_ARG} } + # Runs the clickbench "extended" benchmark with a single large parquet file run_clickbench_extended() { RESULTS_FILE="${RESULTS_DIR}/clickbench_extended.json" echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running clickbench (1 file) extended benchmark..." 
- $CARGO_COMMAND --bin dfbench -- clickbench --iterations 5 --path "${DATA_DIR}/hits.parquet" --queries-path "${SCRIPT_DIR}/queries/clickbench/extended.sql" -o "${RESULTS_FILE}" + debug_run $CARGO_COMMAND --bin dfbench -- clickbench --iterations 5 --path "${DATA_DIR}/hits.parquet" --queries-path "${SCRIPT_DIR}/queries/clickbench/extended" -o "${RESULTS_FILE}" ${QUERY_ARG} ${LATENCY_ARG} } # Downloads the csv.gz files IMDB datasets from Peter Boncz's homepage(one of the JOB paper authors) @@ -668,7 +917,7 @@ run_imdb() { RESULTS_FILE="${RESULTS_DIR}/imdb.json" echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running imdb benchmark..." - $CARGO_COMMAND --bin imdb -- benchmark datafusion --iterations 5 --path "${IMDB_DIR}" --prefer_hash_join "${PREFER_HASH_JOIN}" --format parquet -o "${RESULTS_FILE}" + debug_run $CARGO_COMMAND --bin imdb -- benchmark datafusion --iterations 5 --path "${IMDB_DIR}" --prefer_hash_join "${PREFER_HASH_JOIN}" --format parquet -o "${RESULTS_FILE}" ${QUERY_ARG} ${LATENCY_ARG} } data_h2o() { @@ -676,75 +925,13 @@ data_h2o() { SIZE=${1:-"SMALL"} DATA_FORMAT=${2:-"CSV"} - # Function to compare Python versions - version_ge() { - [ "$(printf '%s\n' "$1" "$2" | sort -V | head -n1)" = "$2" ] - } - - export PYO3_USE_ABI3_FORWARD_COMPATIBILITY=1 - - # Find the highest available Python version (3.10 or higher) - REQUIRED_VERSION="3.10" - PYTHON_CMD=$(command -v python3 || true) - - if [ -n "$PYTHON_CMD" ]; then - PYTHON_VERSION=$($PYTHON_CMD -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}')") - if version_ge "$PYTHON_VERSION" "$REQUIRED_VERSION"; then - echo "Found Python version $PYTHON_VERSION, which is suitable." - else - echo "Python version $PYTHON_VERSION found, but version $REQUIRED_VERSION or higher is required." 
- PYTHON_CMD="" - fi - fi - - # Search for suitable Python versions if the default is unsuitable - if [ -z "$PYTHON_CMD" ]; then - # Loop through all available Python3 commands on the system - for CMD in $(compgen -c | grep -E '^python3(\.[0-9]+)?$'); do - if command -v "$CMD" &> /dev/null; then - PYTHON_VERSION=$($CMD -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}')") - if version_ge "$PYTHON_VERSION" "$REQUIRED_VERSION"; then - PYTHON_CMD="$CMD" - echo "Found suitable Python version: $PYTHON_VERSION ($CMD)" - break - fi - fi - done - fi - - # If no suitable Python version found, exit with an error - if [ -z "$PYTHON_CMD" ]; then - echo "Python 3.10 or higher is required. Please install it." - return 1 - fi - - echo "Using Python command: $PYTHON_CMD" - - # Install falsa and other dependencies - echo "Installing falsa..." - - # Set virtual environment directory - VIRTUAL_ENV="${PWD}/venv" - - # Create a virtual environment using the detected Python command - $PYTHON_CMD -m venv "$VIRTUAL_ENV" - - # Activate the virtual environment and install dependencies - source "$VIRTUAL_ENV/bin/activate" - - # Ensure 'falsa' is installed (avoid unnecessary reinstall) - pip install --quiet --upgrade falsa - # Create directory if it doesn't exist H2O_DIR="${DATA_DIR}/h2o" mkdir -p "${H2O_DIR}" # Generate h2o test data echo "Generating h2o test data in ${H2O_DIR} with size=${SIZE} and format=${DATA_FORMAT}" - falsa groupby --path-prefix="${H2O_DIR}" --size "${SIZE}" --data-format "${DATA_FORMAT}" - - # Deactivate virtual environment after completion - deactivate + uv run falsa groupby --path-prefix="${H2O_DIR}" --size "${SIZE}" --data-format "${DATA_FORMAT}" } data_h2o_join() { @@ -752,75 +939,13 @@ data_h2o_join() { SIZE=${1:-"SMALL"} DATA_FORMAT=${2:-"CSV"} - # Function to compare Python versions - version_ge() { - [ "$(printf '%s\n' "$1" "$2" | sort -V | head -n1)" = "$2" ] - } - - export PYO3_USE_ABI3_FORWARD_COMPATIBILITY=1 - - # Find the highest 
available Python version (3.10 or higher) - REQUIRED_VERSION="3.10" - PYTHON_CMD=$(command -v python3 || true) - - if [ -n "$PYTHON_CMD" ]; then - PYTHON_VERSION=$($PYTHON_CMD -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}')") - if version_ge "$PYTHON_VERSION" "$REQUIRED_VERSION"; then - echo "Found Python version $PYTHON_VERSION, which is suitable." - else - echo "Python version $PYTHON_VERSION found, but version $REQUIRED_VERSION or higher is required." - PYTHON_CMD="" - fi - fi - - # Search for suitable Python versions if the default is unsuitable - if [ -z "$PYTHON_CMD" ]; then - # Loop through all available Python3 commands on the system - for CMD in $(compgen -c | grep -E '^python3(\.[0-9]+)?$'); do - if command -v "$CMD" &> /dev/null; then - PYTHON_VERSION=$($CMD -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}')") - if version_ge "$PYTHON_VERSION" "$REQUIRED_VERSION"; then - PYTHON_CMD="$CMD" - echo "Found suitable Python version: $PYTHON_VERSION ($CMD)" - break - fi - fi - done - fi - - # If no suitable Python version found, exit with an error - if [ -z "$PYTHON_CMD" ]; then - echo "Python 3.10 or higher is required. Please install it." - return 1 - fi - - echo "Using Python command: $PYTHON_CMD" - - # Install falsa and other dependencies - echo "Installing falsa..." 
- - # Set virtual environment directory - VIRTUAL_ENV="${PWD}/venv" - - # Create a virtual environment using the detected Python command - $PYTHON_CMD -m venv "$VIRTUAL_ENV" - - # Activate the virtual environment and install dependencies - source "$VIRTUAL_ENV/bin/activate" - - # Ensure 'falsa' is installed (avoid unnecessary reinstall) - pip install --quiet --upgrade falsa - # Create directory if it doesn't exist H2O_DIR="${DATA_DIR}/h2o" mkdir -p "${H2O_DIR}" # Generate h2o test data echo "Generating h2o test data in ${H2O_DIR} with size=${SIZE} and format=${DATA_FORMAT}" - falsa join --path-prefix="${H2O_DIR}" --size "${SIZE}" --data-format "${DATA_FORMAT}" - - # Deactivate virtual environment after completion - deactivate + uv run falsa join --path-prefix="${H2O_DIR}" --size "${SIZE}" --data-format "${DATA_FORMAT}" } # Runner for h2o groupby benchmark @@ -859,11 +984,12 @@ run_h2o() { QUERY_FILE="${SCRIPT_DIR}/queries/h2o/${RUN_Type}.sql" # Run the benchmark using the dynamically constructed file path and query file - $CARGO_COMMAND --bin dfbench -- h2o \ + debug_run $CARGO_COMMAND --bin dfbench -- h2o \ --iterations 3 \ --path "${H2O_DIR}/${FILE_NAME}" \ --queries-path "${QUERY_FILE}" \ - -o "${RESULTS_FILE}" + -o "${RESULTS_FILE}" \ + ${QUERY_ARG} ${LATENCY_ARG} } # Utility function to run h2o join/window benchmark @@ -910,11 +1036,12 @@ h2o_runner() { # Set the query file name based on the RUN_Type QUERY_FILE="${SCRIPT_DIR}/queries/h2o/${RUN_Type}.sql" - $CARGO_COMMAND --bin dfbench -- h2o \ + debug_run $CARGO_COMMAND --bin dfbench -- h2o \ --iterations 3 \ --join-paths "${H2O_DIR}/${X_TABLE_FILE_NAME},${H2O_DIR}/${SMALL_TABLE_FILE_NAME},${H2O_DIR}/${MEDIUM_TABLE_FILE_NAME},${H2O_DIR}/${LARGE_TABLE_FILE_NAME}" \ --queries-path "${QUERY_FILE}" \ - -o "${RESULTS_FILE}" + -o "${RESULTS_FILE}" \ + ${QUERY_ARG} ${LATENCY_ARG} } # Runners for h2o join benchmark @@ -940,17 +1067,57 @@ run_external_aggr() { # number-of-partitions), and by default `--partitions` is 
set to number of # CPU cores, we set a constant number of partitions to prevent this # benchmark to fail on some machines. - $CARGO_COMMAND --bin external_aggr -- benchmark --partitions 4 --iterations 5 --path "${TPCH_DIR}" -o "${RESULTS_FILE}" + debug_run $CARGO_COMMAND --bin external_aggr -- benchmark --partitions 4 --iterations 5 --path "${TPCH_DIR}" -o "${RESULTS_FILE}" ${QUERY_ARG} } # Runs the sort integration benchmark run_sort_tpch() { - TPCH_DIR="${DATA_DIR}/tpch_sf1" - RESULTS_FILE="${RESULTS_DIR}/sort_tpch.json" + SCALE_FACTOR=$1 + if [ -z "$SCALE_FACTOR" ] ; then + echo "Internal error: Scale factor not specified" + exit 1 + fi + TPCH_DIR="${DATA_DIR}/tpch_sf${SCALE_FACTOR}" + RESULTS_FILE="${RESULTS_DIR}/sort_tpch${SCALE_FACTOR}.json" echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running sort tpch benchmark..." - $CARGO_COMMAND --bin dfbench -- sort-tpch --iterations 5 --path "${TPCH_DIR}" -o "${RESULTS_FILE}" + debug_run $CARGO_COMMAND --bin dfbench -- sort-tpch --iterations 5 --path "${TPCH_DIR}" -o "${RESULTS_FILE}" ${QUERY_ARG} ${LATENCY_ARG} +} + +# Runs the sort tpch integration benchmark with limit 100 (topk) +run_topk_tpch() { + TPCH_DIR="${DATA_DIR}/tpch_sf1" + RESULTS_FILE="${RESULTS_DIR}/run_topk_tpch.json" + echo "RESULTS_FILE: ${RESULTS_FILE}" + echo "Running topk tpch benchmark..." + + $CARGO_COMMAND --bin dfbench -- sort-tpch --iterations 5 --path "${TPCH_DIR}" -o "${RESULTS_FILE}" --limit 100 ${QUERY_ARG} ${LATENCY_ARG} +} + +# Runs the nlj benchmark +run_nlj() { + RESULTS_FILE="${RESULTS_DIR}/nlj.json" + echo "RESULTS_FILE: ${RESULTS_FILE}" + echo "Running nlj benchmark..." + debug_run $CARGO_COMMAND --bin dfbench -- nlj --iterations 5 -o "${RESULTS_FILE}" ${QUERY_ARG} ${LATENCY_ARG} +} + +# Runs the hj benchmark +run_hj() { + TPCH_DIR="${DATA_DIR}/tpch_sf10" + RESULTS_FILE="${RESULTS_DIR}/hj.json" + echo "RESULTS_FILE: ${RESULTS_FILE}" + echo "Running hj benchmark..." 
+ debug_run $CARGO_COMMAND --bin dfbench -- hj --iterations 5 --path "${TPCH_DIR}" -o "${RESULTS_FILE}" ${QUERY_ARG} ${LATENCY_ARG} +} + +# Runs the smj benchmark +run_smj() { + RESULTS_FILE="${RESULTS_DIR}/smj.json" + echo "RESULTS_FILE: ${RESULTS_FILE}" + echo "Running smj benchmark..." + debug_run $CARGO_COMMAND --bin dfbench -- smj --iterations 5 -o "${RESULTS_FILE}" ${QUERY_ARG} ${LATENCY_ARG} } @@ -958,6 +1125,8 @@ compare_benchmarks() { BASE_RESULTS_DIR="${SCRIPT_DIR}/results" BRANCH1="$1" BRANCH2="$2" + OPTS="$3" + if [ -z "$BRANCH1" ] ; then echo " not specified. Available branches:" ls -1 "${BASE_RESULTS_DIR}" @@ -978,7 +1147,7 @@ compare_benchmarks() { echo "--------------------" echo "Benchmark ${BENCH}" echo "--------------------" - PATH=$VIRTUAL_ENV/bin:$PATH python3 "${SCRIPT_DIR}"/compare.py "${RESULTS_FILE1}" "${RESULTS_FILE2}" + uv run python3 "${SCRIPT_DIR}"/compare.py $OPTS "${RESULTS_FILE1}" "${RESULTS_FILE2}" else echo "Note: Skipping ${RESULTS_FILE1} as ${RESULTS_FILE2} does not exist" fi @@ -986,10 +1155,113 @@ compare_benchmarks() { } -setup_venv() { - python3 -m venv "$VIRTUAL_ENV" - PATH=$VIRTUAL_ENV/bin:$PATH python3 -m pip install -r requirements.txt +# Creates sorted ClickBench data from hits.parquet (full dataset) +# The data is sorted by EventTime in ascending order +# Uses datafusion-cli to reduce dependencies +clickbench_sorted() { + SORTED_FILE="${DATA_DIR}/hits_sorted.parquet" + ORIGINAL_FILE="${DATA_DIR}/hits.parquet" + + # Default memory limit is 12GB, can be overridden with DATAFUSION_MEMORY_GB env var + MEMORY_LIMIT_GB=${DATAFUSION_MEMORY_GB:-12} + + echo "Creating sorted ClickBench dataset from hits.parquet..." + echo "Configuration:" + echo " Memory limit: ${MEMORY_LIMIT_GB}G" + echo " Row group size: 64K rows" + echo " Compression: uncompressed" + + if [ ! -f "${ORIGINAL_FILE}" ]; then + echo "hits.parquet not found. Running data_clickbench_1 first..." 
+ data_clickbench_1 + fi + + if [ -f "${SORTED_FILE}" ]; then + echo "Sorted hits.parquet already exists at ${SORTED_FILE}" + return 0 + fi + + echo "Sorting hits.parquet by EventTime (this may take several minutes)..." + + pushd "${DATAFUSION_DIR}" > /dev/null + echo "Building datafusion-cli..." + cargo build --release --bin datafusion-cli + DATAFUSION_CLI="${DATAFUSION_DIR}/target/release/datafusion-cli" + popd > /dev/null + + + START_TIME=$(date +%s) + echo "Start time: $(date '+%Y-%m-%d %H:%M:%S')" + echo "Using datafusion-cli to create sorted parquet file..." + "${DATAFUSION_CLI}" << EOF +-- Memory and performance configuration +SET datafusion.runtime.memory_limit = '${MEMORY_LIMIT_GB}G'; +SET datafusion.execution.spill_compression = 'uncompressed'; +SET datafusion.execution.sort_spill_reservation_bytes = 10485760; -- 10MB +SET datafusion.execution.batch_size = 8192; +SET datafusion.execution.target_partitions = 1; + +-- Parquet output configuration +SET datafusion.execution.parquet.max_row_group_size = 65536; +SET datafusion.execution.parquet.compression = 'uncompressed'; + +-- Execute sort and write +COPY (SELECT * FROM '${ORIGINAL_FILE}' ORDER BY "EventTime") +TO '${SORTED_FILE}' +STORED AS PARQUET; +EOF + + local result=$? 
+ + END_TIME=$(date +%s) + DURATION=$((END_TIME - START_TIME)) + echo "End time: $(date '+%Y-%m-%d %H:%M:%S')" + + if [ $result -eq 0 ]; then + echo "✓ Successfully created sorted ClickBench dataset" + + INPUT_SIZE=$(stat -f%z "${ORIGINAL_FILE}" 2>/dev/null || stat -c%s "${ORIGINAL_FILE}" 2>/dev/null) + OUTPUT_SIZE=$(stat -f%z "${SORTED_FILE}" 2>/dev/null || stat -c%s "${SORTED_FILE}" 2>/dev/null) + INPUT_MB=$((INPUT_SIZE / 1024 / 1024)) + OUTPUT_MB=$((OUTPUT_SIZE / 1024 / 1024)) + + echo " Input: ${INPUT_MB} MB" + echo " Output: ${OUTPUT_MB} MB" + + echo "" + echo "Time Statistics:" + echo " Total duration: ${DURATION} seconds ($(printf '%02d:%02d:%02d' $((DURATION/3600)) $((DURATION%3600/60)) $((DURATION%60))))" + echo " Throughput: $((INPUT_MB / DURATION)) MB/s" + + return 0 + else + echo "✗ Error: Failed to create sorted dataset" + echo "💡 Tip: Try increasing memory with: DATAFUSION_MEMORY_GB=16 ./bench.sh data clickbench_sorted" + return 1 + fi +} + +# Runs the sorted data benchmark with prefer_existing_sort configuration +run_clickbench_sorted() { + RESULTS_FILE="${RESULTS_DIR}/clickbench_sorted.json" + echo "RESULTS_FILE: ${RESULTS_FILE}" + echo "Running sorted data benchmark with prefer_existing_sort optimization..." 
+ + # Ensure sorted data exists + clickbench_sorted + + # Run benchmark with prefer_existing_sort configuration + # This allows DataFusion to optimize away redundant sorts while maintaining parallelism + debug_run $CARGO_COMMAND --bin dfbench -- clickbench \ + --iterations 5 \ + --path "${DATA_DIR}/hits_sorted.parquet" \ + --queries-path "${SCRIPT_DIR}/queries/clickbench/queries/sorted_data" \ + --sorted-by "EventTime" \ + -c datafusion.optimizer.prefer_existing_sort=true \ + -o "${RESULTS_FILE}" \ + ${QUERY_ARG} ${LATENCY_ARG} } + # And start the process up main diff --git a/benchmarks/compare.py b/benchmarks/compare.py index 4b609c744d503..9ad1de980abe8 100755 --- a/benchmarks/compare.py +++ b/benchmarks/compare.py @@ -18,7 +18,9 @@ from __future__ import annotations +import argparse import json +import math from dataclasses import dataclass from typing import Dict, List, Any from pathlib import Path @@ -47,6 +49,7 @@ class QueryRun: query: int iterations: List[QueryResult] start_time: int + success: bool = True @classmethod def load_from(cls, data: Dict[str, Any]) -> QueryRun: @@ -54,17 +57,57 @@ def load_from(cls, data: Dict[str, Any]) -> QueryRun: query=data["query"], iterations=[QueryResult(**iteration) for iteration in data["iterations"]], start_time=data["start_time"], + success=data.get("success", True), ) @property - def execution_time(self) -> float: + def min_execution_time(self) -> float: assert len(self.iterations) >= 1 - # Use minimum execution time to account for variations / other - # things the system was doing return min(iteration.elapsed for iteration in self.iterations) + @property + def max_execution_time(self) -> float: + assert len(self.iterations) >= 1 + + return max(iteration.elapsed for iteration in self.iterations) + + + @property + def mean_execution_time(self) -> float: + assert len(self.iterations) >= 1 + + total = sum(iteration.elapsed for iteration in self.iterations) + return total / len(self.iterations) + + + @property + def 
stddev_execution_time(self) -> float: + assert len(self.iterations) >= 1 + + mean = self.mean_execution_time + squared_diffs = [(iteration.elapsed - mean) ** 2 for iteration in self.iterations] + variance = sum(squared_diffs) / len(self.iterations) + return math.sqrt(variance) + + def execution_time_report(self, detailed = False) -> tuple[float, str]: + if detailed: + mean_execution_time = self.mean_execution_time + return ( + mean_execution_time, + f"{self.min_execution_time:.2f} / {mean_execution_time :.2f} ±{self.stddev_execution_time:.2f} / {self.max_execution_time:.2f} ms" + ) + else: + # Use minimum execution time to account for variations / other + # things the system was doing + min_execution_time = self.min_execution_time + return ( + min_execution_time, + f"{min_execution_time :.2f} ms" + ) + + @dataclass class Context: benchmark_version: str @@ -106,35 +149,54 @@ def compare( baseline_path: Path, comparison_path: Path, noise_threshold: float, + detailed: bool, ) -> None: baseline = BenchmarkRun.load_from_file(baseline_path) comparison = BenchmarkRun.load_from_file(comparison_path) - console = Console() + console = Console(width=200) # use basename as the column names - baseline_header = baseline_path.parent.stem - comparison_header = comparison_path.parent.stem + baseline_header = baseline_path.parent.name + comparison_header = comparison_path.parent.name table = Table(show_header=True, header_style="bold magenta") - table.add_column("Query", style="dim", width=12) - table.add_column(baseline_header, justify="right", style="dim") - table.add_column(comparison_header, justify="right", style="dim") - table.add_column("Change", justify="right", style="dim") + table.add_column("Query", style="dim", no_wrap=True) + table.add_column(baseline_header, justify="right", style="dim", no_wrap=True) + table.add_column(comparison_header, justify="right", style="dim", no_wrap=True) + table.add_column("Change", justify="right", style="dim", no_wrap=True) faster_count = 
0 slower_count = 0 no_change_count = 0 + failure_count = 0 total_baseline_time = 0 total_comparison_time = 0 for baseline_result, comparison_result in zip(baseline.queries, comparison.queries): assert baseline_result.query == comparison_result.query - total_baseline_time += baseline_result.execution_time - total_comparison_time += comparison_result.execution_time + base_failed = not baseline_result.success + comp_failed = not comparison_result.success + # If a query fails, its execution time is excluded from the performance comparison + if base_failed or comp_failed: + change_text = "incomparable" + failure_count += 1 + table.add_row( + f"Q{baseline_result.query}", + "FAIL" if base_failed else baseline_result.execution_time_report(detailed)[1], + "FAIL" if comp_failed else comparison_result.execution_time_report(detailed)[1], + change_text, + ) + continue + + baseline_value, baseline_text = baseline_result.execution_time_report(detailed) + comparison_value, comparison_text = comparison_result.execution_time_report(detailed) + + total_baseline_time += baseline_value + total_comparison_time += comparison_value - change = comparison_result.execution_time / baseline_result.execution_time + change = comparison_value / baseline_value if (1.0 - noise_threshold) <= change <= (1.0 + noise_threshold): change_text = "no change" @@ -148,16 +210,20 @@ def compare( table.add_row( f"Q{baseline_result.query}", - f"{baseline_result.execution_time:.2f}ms", - f"{comparison_result.execution_time:.2f}ms", + baseline_text, + comparison_text, change_text, ) console.print(table) # Calculate averages - avg_baseline_time = total_baseline_time / len(baseline.queries) - avg_comparison_time = total_comparison_time / len(comparison.queries) + avg_baseline_time = 0.0 + avg_comparison_time = 0.0 + if len(baseline.queries) - failure_count > 0: + avg_baseline_time = total_baseline_time / (len(baseline.queries) - failure_count) + if len(comparison.queries) - failure_count > 0: + avg_comparison_time 
= total_comparison_time / (len(comparison.queries) - failure_count) # Summary table summary_table = Table(show_header=True, header_style="bold magenta") @@ -171,6 +237,7 @@ def compare( summary_table.add_row("Queries Faster", str(faster_count)) summary_table.add_row("Queries Slower", str(slower_count)) summary_table.add_row("Queries with No Change", str(no_change_count)) + summary_table.add_row("Queries with Failure", str(failure_count)) console.print(summary_table) @@ -193,10 +260,16 @@ def main() -> None: default=0.05, help="The threshold for statistically insignificant results (+/- %5).", ) + compare_parser.add_argument( + "--detailed", + action=argparse.BooleanOptionalAction, + default=False, + help="Show detailed result comparison instead of minimum runtime.", + ) options = parser.parse_args() - compare(options.baseline_path, options.comparison_path, options.noise_threshold) + compare(options.baseline_path, options.comparison_path, options.noise_threshold, options.detailed) diff --git a/benchmarks/compare_tpcds.sh b/benchmarks/compare_tpcds.sh new file mode 100755 index 0000000000000..48331a7c7510e --- /dev/null +++ b/benchmarks/compare_tpcds.sh @@ -0,0 +1,58 @@ +#!/usr/bin/env bash +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +# Compare TPC-DS benchmarks between two branches + +set -e + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) + +usage() { + echo "Usage: $0 " + echo "" + echo "Example: $0 main dev2" + echo "" + echo "Note: TPC-DS benchmarks are not currently implemented in bench.sh" + exit 1 +} + +BRANCH1=${1:-""} +BRANCH2=${2:-""} + +if [ -z "$BRANCH1" ] || [ -z "$BRANCH2" ]; then + usage +fi + +# Store current branch +CURRENT_BRANCH=$(git rev-parse --abbrev-ref HEAD) + +echo "Comparing TPC-DS benchmarks: ${BRANCH1} vs ${BRANCH2}" + +# Run benchmark on first branch +git checkout "$BRANCH1" +./benchmarks/bench.sh run tpcds + +# Run benchmark on second branch +git checkout "$BRANCH2" +./benchmarks/bench.sh run tpcds + +# Compare results +./benchmarks/bench.sh compare "$BRANCH1" "$BRANCH2" + +# Return to original branch +git checkout "$CURRENT_BRANCH" \ No newline at end of file diff --git a/benchmarks/compare_tpch.sh b/benchmarks/compare_tpch.sh new file mode 100755 index 0000000000000..85e8da29ce41d --- /dev/null +++ b/benchmarks/compare_tpch.sh @@ -0,0 +1,56 @@ +#!/usr/bin/env bash +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +# Compare TPC-H benchmarks between two branches + +set -e + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) + +usage() { + echo "Usage: $0 " + echo "" + echo "Example: $0 main dev2" + exit 1 +} + +BRANCH1=${1:-""} +BRANCH2=${2:-""} + +if [ -z "$BRANCH1" ] || [ -z "$BRANCH2" ]; then + usage +fi + +# Store current branch +CURRENT_BRANCH=$(git rev-parse --abbrev-ref HEAD) + +echo "Comparing TPC-H benchmarks: ${BRANCH1} vs ${BRANCH2}" + +# Run benchmark on first branch +git checkout "$BRANCH1" +./benchmarks/bench.sh run tpch + +# Run benchmark on second branch +git checkout "$BRANCH2" +./benchmarks/bench.sh run tpch + +# Compare results +./benchmarks/bench.sh compare "$BRANCH1" "$BRANCH2" + +# Return to original branch +git checkout "$CURRENT_BRANCH" \ No newline at end of file diff --git a/benchmarks/compile_profile.py b/benchmarks/compile_profile.py new file mode 100644 index 0000000000000..ae51de94937bf --- /dev/null +++ b/benchmarks/compile_profile.py @@ -0,0 +1,180 @@ +#!/usr/bin/env python3 + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""Compile profile benchmark runner for DataFusion. + +Builds the `tpch` benchmark binary with several Cargo profiles (e.g. `--release` or `--profile ci`), runs the full TPC-H suite against the Parquet data under `benchmarks/data/tpch_sf1`, and reports compile time, execution time, and resulting +binary size. + +See `benchmarks/README.md` for usages. +""" + +from __future__ import annotations + +import argparse +import os +import subprocess +import sys +import time +from pathlib import Path +from typing import Iterable, NamedTuple + +REPO_ROOT = Path(__file__).resolve().parents[1] +DEFAULT_DATA_DIR = REPO_ROOT / "benchmarks" / "data" / "tpch_sf1" +DEFAULT_ITERATIONS = 1 +DEFAULT_FORMAT = "parquet" +DEFAULT_PARTITIONS: int | None = None +TPCH_BINARY = "tpch.exe" if os.name == "nt" else "tpch" +PROFILE_TARGET_DIR = { + "dev": "debug", + "release": "release", + "ci": "ci", + "release-nonlto": "release-nonlto", +} + + +class ProfileResult(NamedTuple): + profile: str + compile_seconds: float + run_seconds: float + binary_bytes: int + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--profiles", + nargs="+", + default=list(PROFILE_TARGET_DIR.keys()), + help="Cargo profiles to test (default: dev release ci release-nonlto)", + ) + parser.add_argument( + "--data", + type=Path, + default=DEFAULT_DATA_DIR, + help="Path to TPCH dataset (default: benchmarks/data/tpch_sf1)", + ) + return parser.parse_args() + + +def timed_run(command: Iterable[str]) -> float: + start = time.perf_counter() + try: + subprocess.run(command, cwd=REPO_ROOT, check=True) + except subprocess.CalledProcessError as exc: + raise RuntimeError(f"command failed: {' '.join(command)}") from exc + return time.perf_counter() - start + + +def cargo_build(profile: str) -> float: + if profile == "dev": + command = ["cargo", "build", "--bin", "tpch"] + else: + command = ["cargo", "build", "--profile", profile, "--bin", "tpch"] + return 
timed_run(command) + + +def cargo_clean(profile: str) -> None: + command = ["cargo", "clean", "--profile", profile] + try: + subprocess.run(command, cwd=REPO_ROOT, check=True) + except subprocess.CalledProcessError as exc: + raise RuntimeError(f"failed to clean cargo artifacts for profile '{profile}'") from exc + + +def run_benchmark(profile: str, data_path: Path) -> float: + binary_dir = PROFILE_TARGET_DIR.get(profile) + if not binary_dir: + raise ValueError(f"unknown profile '{profile}'") + binary_path = REPO_ROOT / "target" / binary_dir / TPCH_BINARY + if not binary_path.exists(): + raise FileNotFoundError(f"compiled binary not found at {binary_path}") + + command = [ + str(binary_path), + "benchmark", + "datafusion", + "--iterations", + str(DEFAULT_ITERATIONS), + "--path", + str(data_path), + "--format", + DEFAULT_FORMAT, + ] + if DEFAULT_PARTITIONS is not None: + command.extend(["--partitions", str(DEFAULT_PARTITIONS)]) + env = os.environ.copy() + env.setdefault("RUST_LOG", "warn") + + start = time.perf_counter() + try: + subprocess.run(command, cwd=REPO_ROOT, env=env, check=True) + except subprocess.CalledProcessError as exc: + raise RuntimeError(f"benchmark failed for profile '{profile}'") from exc + return time.perf_counter() - start + + +def binary_size(profile: str) -> int: + binary_dir = PROFILE_TARGET_DIR[profile] + binary_path = REPO_ROOT / "target" / binary_dir / TPCH_BINARY + return binary_path.stat().st_size + + +def human_time(seconds: float) -> str: + return f"{seconds:6.2f}s" + + +def human_size(size: int) -> str: + value = float(size) + for unit in ("B", "KB", "MB", "GB", "TB"): + if value < 1024 or unit == "TB": + return f"{value:6.1f}{unit}" + value /= 1024 + return f"{value:6.1f}TB" + + +def main() -> None: + args = parse_args() + data_path = args.data.resolve() + if not data_path.exists(): + print(f"Data directory not found: {data_path}", file=sys.stderr) + sys.exit(1) + + results: list[ProfileResult] = [] + for profile in args.profiles: + 
print(f"\n=== Profile: {profile} ===") + print("Cleaning previous build artifacts...") + cargo_clean(profile) + compile_seconds = cargo_build(profile) + run_seconds = run_benchmark(profile, data_path) + size_bytes = binary_size(profile) + results.append(ProfileResult(profile, compile_seconds, run_seconds, size_bytes)) + + print("\nSummary") + header = f"{'Profile':<15}{'Compile':>12}{'Run':>12}{'Size':>12}" + print(header) + print("-" * len(header)) + for result in results: + print( + f"{result.profile:<15}{human_time(result.compile_seconds):>12}" + f"{human_time(result.run_seconds):>12}{human_size(result.binary_bytes):>12}" + ) + +if __name__ == "__main__": + main() diff --git a/benchmarks/pyproject.toml b/benchmarks/pyproject.toml new file mode 100644 index 0000000000000..e6a60582148ce --- /dev/null +++ b/benchmarks/pyproject.toml @@ -0,0 +1,6 @@ +[project] +name = "datafusion-benchmarks" +version = "0.1.0" +requires-python = ">=3.11" +# typing_extensions is an undeclared dependency of falsa +dependencies = ["rich", "falsa", "typing_extensions"] diff --git a/benchmarks/queries/clickbench/README.md b/benchmarks/queries/clickbench/README.md index e5acd8f348a47..877ea0e0c3192 100644 --- a/benchmarks/queries/clickbench/README.md +++ b/benchmarks/queries/clickbench/README.md @@ -6,8 +6,8 @@ ClickBench is focused on aggregation and filtering performance (though it has no ## Files: -- `queries.sql` - Actual ClickBench queries, downloaded from the [ClickBench repository] -- `extended.sql` - "Extended" DataFusion specific queries. +- `queries/*.sql` - Actual ClickBench queries, downloaded from the [ClickBench repository](https://raw.githubusercontent.com/ClickHouse/ClickBench/main/datafusion/queries.sql) and split by the `update_queries.sh` script. +- `extended/*.sql` - "Extended" DataFusion specific queries. 
[clickbench repository]: https://github.com/ClickHouse/ClickBench/blob/main/datafusion/queries.sql @@ -15,8 +15,8 @@ ClickBench is focused on aggregation and filtering performance (though it has no The "extended" queries are not part of the official ClickBench benchmark. Instead they are used to test other DataFusion features that are not covered by -the standard benchmark. Each description below is for the corresponding line in -`extended.sql` (line 1 is `Q0`, line 2 is `Q1`, etc.) +the standard benchmark. Each description below is for the corresponding file in +`extended` ### Q0: Data Exploration diff --git a/benchmarks/queries/clickbench/extended.sql b/benchmarks/queries/clickbench/extended.sql deleted file mode 100644 index 93c39efe4f8e3..0000000000000 --- a/benchmarks/queries/clickbench/extended.sql +++ /dev/null @@ -1,9 +0,0 @@ -SELECT COUNT(DISTINCT "SearchPhrase"), COUNT(DISTINCT "MobilePhone"), COUNT(DISTINCT "MobilePhoneModel") FROM hits; -SELECT COUNT(DISTINCT "HitColor"), COUNT(DISTINCT "BrowserCountry"), COUNT(DISTINCT "BrowserLanguage") FROM hits; -SELECT "BrowserCountry", COUNT(DISTINCT "SocialNetwork"), COUNT(DISTINCT "HitColor"), COUNT(DISTINCT "BrowserLanguage"), COUNT(DISTINCT "SocialAction") FROM hits GROUP BY 1 ORDER BY 2 DESC LIMIT 10; -SELECT "SocialSourceNetworkID", "RegionID", COUNT(*), AVG("Age"), AVG("ParamPrice"), STDDEV("ParamPrice") as s, VAR("ParamPrice") FROM hits GROUP BY "SocialSourceNetworkID", "RegionID" HAVING s IS NOT NULL ORDER BY s DESC LIMIT 10; -SELECT "ClientIP", "WatchID", COUNT(*) c, MIN("ResponseStartTiming") tmin, MEDIAN("ResponseStartTiming") tmed, MAX("ResponseStartTiming") tmax FROM hits WHERE "JavaEnable" = 0 GROUP BY "ClientIP", "WatchID" HAVING c > 1 ORDER BY tmed DESC LIMIT 10; -SELECT "ClientIP", "WatchID", COUNT(*) c, MIN("ResponseStartTiming") tmin, APPROX_PERCENTILE_CONT(0.95) WITHIN GROUP (ORDER BY "ResponseStartTiming") tp95, MAX("ResponseStartTiming") tmax FROM 'hits' WHERE "JavaEnable" = 0 GROUP BY 
"ClientIP", "WatchID" HAVING c > 1 ORDER BY tp95 DESC LIMIT 10; -SELECT COUNT(*) AS ShareCount FROM hits WHERE "IsMobile" = 1 AND "MobilePhoneModel" LIKE 'iPhone%' AND "SocialAction" = 'share' AND "SocialSourceNetworkID" IN (5, 12) AND "ClientTimeZone" BETWEEN -5 AND 5 AND regexp_match("Referer", '\/campaign\/(spring|summer)_promo') IS NOT NULL AND CASE WHEN split_part(split_part("URL", 'resolution=', 2), '&', 1) ~ '^\d+$' THEN split_part(split_part("URL", 'resolution=', 2), '&', 1)::INT ELSE 0 END > 1920 AND levenshtein(CAST("UTMSource" AS STRING), CAST("UTMCampaign" AS STRING)) < 3; -SELECT "WatchID", MIN("ResolutionWidth") as wmin, MAX("ResolutionWidth") as wmax, SUM("IsRefresh") as srefresh FROM hits GROUP BY "WatchID" ORDER BY "WatchID" DESC LIMIT 10; -SELECT "RegionID", "UserAgent", "OS", AVG(to_timestamp("ResponseEndTiming")-to_timestamp("ResponseStartTiming")) as avg_response_time, AVG(to_timestamp("ResponseEndTiming")-to_timestamp("ConnectTiming")) as avg_latency FROM hits GROUP BY "RegionID", "UserAgent", "OS" ORDER BY avg_latency DESC limit 10; \ No newline at end of file diff --git a/benchmarks/queries/clickbench/extended/q0.sql b/benchmarks/queries/clickbench/extended/q0.sql new file mode 100644 index 0000000000000..cb826e5f947e9 --- /dev/null +++ b/benchmarks/queries/clickbench/extended/q0.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT COUNT(DISTINCT "SearchPhrase"), COUNT(DISTINCT "MobilePhone"), COUNT(DISTINCT "MobilePhoneModel") FROM hits; diff --git a/benchmarks/queries/clickbench/extended/q1.sql b/benchmarks/queries/clickbench/extended/q1.sql new file mode 100644 index 0000000000000..7862423787d85 --- /dev/null +++ b/benchmarks/queries/clickbench/extended/q1.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. 
See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT COUNT(DISTINCT "HitColor"), COUNT(DISTINCT "BrowserCountry"), COUNT(DISTINCT "BrowserLanguage") FROM hits; diff --git a/benchmarks/queries/clickbench/extended/q2.sql b/benchmarks/queries/clickbench/extended/q2.sql new file mode 100644 index 0000000000000..de2be79885792 --- /dev/null +++ b/benchmarks/queries/clickbench/extended/q2.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "BrowserCountry", COUNT(DISTINCT "SocialNetwork"), COUNT(DISTINCT "HitColor"), COUNT(DISTINCT "BrowserLanguage"), COUNT(DISTINCT "SocialAction") FROM hits GROUP BY 1 ORDER BY 2 DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/extended/q3.sql b/benchmarks/queries/clickbench/extended/q3.sql new file mode 100644 index 0000000000000..f52990b9843a5 --- /dev/null +++ b/benchmarks/queries/clickbench/extended/q3.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "SocialSourceNetworkID", "RegionID", COUNT(*), AVG("Age"), AVG("ParamPrice"), STDDEV("ParamPrice") as s, VAR("ParamPrice") FROM hits GROUP BY "SocialSourceNetworkID", "RegionID" HAVING s IS NOT NULL ORDER BY s DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/extended/q4.sql b/benchmarks/queries/clickbench/extended/q4.sql new file mode 100644 index 0000000000000..5865129db6425 --- /dev/null +++ b/benchmarks/queries/clickbench/extended/q4.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. 
See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "ClientIP", "WatchID", COUNT(*) c, MIN("ResponseStartTiming") tmin, MEDIAN("ResponseStartTiming") tmed, MAX("ResponseStartTiming") tmax FROM hits WHERE "JavaEnable" = 0 GROUP BY "ClientIP", "WatchID" HAVING c > 1 ORDER BY tmed DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/extended/q5.sql b/benchmarks/queries/clickbench/extended/q5.sql new file mode 100644 index 0000000000000..18d3e01c82c4b --- /dev/null +++ b/benchmarks/queries/clickbench/extended/q5.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "ClientIP", "WatchID", COUNT(*) c, MIN("ResponseStartTiming") tmin, APPROX_PERCENTILE_CONT(0.95) WITHIN GROUP (ORDER BY "ResponseStartTiming") tp95, MAX("ResponseStartTiming") tmax FROM 'hits' WHERE "JavaEnable" = 0 GROUP BY "ClientIP", "WatchID" HAVING c > 1 ORDER BY tp95 DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/extended/q6.sql b/benchmarks/queries/clickbench/extended/q6.sql new file mode 100644 index 0000000000000..0a6467b8898aa --- /dev/null +++ b/benchmarks/queries/clickbench/extended/q6.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. 
See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT COUNT(*) AS ShareCount FROM hits WHERE "IsMobile" = 1 AND "MobilePhoneModel" LIKE 'iPhone%' AND "SocialAction" = 'share' AND "SocialSourceNetworkID" IN (5, 12) AND "ClientTimeZone" BETWEEN -5 AND 5 AND regexp_match("Referer", '\/campaign\/(spring|summer)_promo') IS NOT NULL AND CASE WHEN split_part(split_part("URL", 'resolution=', 2), '&', 1) ~ '^\d+$' THEN split_part(split_part("URL", 'resolution=', 2), '&', 1)::INT ELSE 0 END > 1920 AND levenshtein(CAST("UTMSource" AS STRING), CAST("UTMCampaign" AS STRING)) < 3; diff --git a/benchmarks/queries/clickbench/extended/q7.sql b/benchmarks/queries/clickbench/extended/q7.sql new file mode 100644 index 0000000000000..ddaff7f8804f5 --- /dev/null +++ b/benchmarks/queries/clickbench/extended/q7.sql @@ -0,0 +1 @@ +SELECT "WatchID", MIN("ResolutionWidth") as wmin, MAX("ResolutionWidth") as wmax, SUM("IsRefresh") as srefresh FROM hits GROUP BY "WatchID" ORDER BY "WatchID" DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries.sql b/benchmarks/queries/clickbench/queries.sql deleted file mode 100644 index 9a183cd6e259c..0000000000000 --- a/benchmarks/queries/clickbench/queries.sql +++ /dev/null @@ -1,43 +0,0 @@ -SELECT COUNT(*) FROM hits; -SELECT COUNT(*) FROM hits WHERE "AdvEngineID" <> 0; -SELECT SUM("AdvEngineID"), COUNT(*), AVG("ResolutionWidth") FROM hits; -SELECT AVG("UserID") FROM hits; -SELECT COUNT(DISTINCT "UserID") FROM hits; -SELECT COUNT(DISTINCT "SearchPhrase") FROM hits; -SELECT MIN("EventDate"), MAX("EventDate") FROM hits; -SELECT "AdvEngineID", COUNT(*) FROM hits WHERE "AdvEngineID" <> 0 GROUP BY "AdvEngineID" ORDER BY COUNT(*) DESC; -SELECT "RegionID", COUNT(DISTINCT "UserID") AS u FROM hits GROUP BY "RegionID" ORDER BY u DESC LIMIT 10; -SELECT "RegionID", SUM("AdvEngineID"), COUNT(*) AS c, AVG("ResolutionWidth"), COUNT(DISTINCT "UserID") FROM hits GROUP BY "RegionID" ORDER BY c 
DESC LIMIT 10; -SELECT "MobilePhoneModel", COUNT(DISTINCT "UserID") AS u FROM hits WHERE "MobilePhoneModel" <> '' GROUP BY "MobilePhoneModel" ORDER BY u DESC LIMIT 10; -SELECT "MobilePhone", "MobilePhoneModel", COUNT(DISTINCT "UserID") AS u FROM hits WHERE "MobilePhoneModel" <> '' GROUP BY "MobilePhone", "MobilePhoneModel" ORDER BY u DESC LIMIT 10; -SELECT "SearchPhrase", COUNT(*) AS c FROM hits WHERE "SearchPhrase" <> '' GROUP BY "SearchPhrase" ORDER BY c DESC LIMIT 10; -SELECT "SearchPhrase", COUNT(DISTINCT "UserID") AS u FROM hits WHERE "SearchPhrase" <> '' GROUP BY "SearchPhrase" ORDER BY u DESC LIMIT 10; -SELECT "SearchEngineID", "SearchPhrase", COUNT(*) AS c FROM hits WHERE "SearchPhrase" <> '' GROUP BY "SearchEngineID", "SearchPhrase" ORDER BY c DESC LIMIT 10; -SELECT "UserID", COUNT(*) FROM hits GROUP BY "UserID" ORDER BY COUNT(*) DESC LIMIT 10; -SELECT "UserID", "SearchPhrase", COUNT(*) FROM hits GROUP BY "UserID", "SearchPhrase" ORDER BY COUNT(*) DESC LIMIT 10; -SELECT "UserID", "SearchPhrase", COUNT(*) FROM hits GROUP BY "UserID", "SearchPhrase" LIMIT 10; -SELECT "UserID", extract(minute FROM to_timestamp_seconds("EventTime")) AS m, "SearchPhrase", COUNT(*) FROM hits GROUP BY "UserID", m, "SearchPhrase" ORDER BY COUNT(*) DESC LIMIT 10; -SELECT "UserID" FROM hits WHERE "UserID" = 435090932899640449; -SELECT COUNT(*) FROM hits WHERE "URL" LIKE '%google%'; -SELECT "SearchPhrase", MIN("URL"), COUNT(*) AS c FROM hits WHERE "URL" LIKE '%google%' AND "SearchPhrase" <> '' GROUP BY "SearchPhrase" ORDER BY c DESC LIMIT 10; -SELECT "SearchPhrase", MIN("URL"), MIN("Title"), COUNT(*) AS c, COUNT(DISTINCT "UserID") FROM hits WHERE "Title" LIKE '%Google%' AND "URL" NOT LIKE '%.google.%' AND "SearchPhrase" <> '' GROUP BY "SearchPhrase" ORDER BY c DESC LIMIT 10; -SELECT * FROM hits WHERE "URL" LIKE '%google%' ORDER BY "EventTime" LIMIT 10; -SELECT "SearchPhrase" FROM hits WHERE "SearchPhrase" <> '' ORDER BY "EventTime" LIMIT 10; -SELECT "SearchPhrase" FROM hits WHERE 
"SearchPhrase" <> '' ORDER BY "SearchPhrase" LIMIT 10; -SELECT "SearchPhrase" FROM hits WHERE "SearchPhrase" <> '' ORDER BY "EventTime", "SearchPhrase" LIMIT 10; -SELECT "CounterID", AVG(length("URL")) AS l, COUNT(*) AS c FROM hits WHERE "URL" <> '' GROUP BY "CounterID" HAVING COUNT(*) > 100000 ORDER BY l DESC LIMIT 25; -SELECT REGEXP_REPLACE("Referer", '^https?://(?:www\.)?([^/]+)/.*$', '\1') AS k, AVG(length("Referer")) AS l, COUNT(*) AS c, MIN("Referer") FROM hits WHERE "Referer" <> '' GROUP BY k HAVING COUNT(*) > 100000 ORDER BY l DESC LIMIT 25; -SELECT SUM("ResolutionWidth"), SUM("ResolutionWidth" + 1), SUM("ResolutionWidth" + 2), SUM("ResolutionWidth" + 3), SUM("ResolutionWidth" + 4), SUM("ResolutionWidth" + 5), SUM("ResolutionWidth" + 6), SUM("ResolutionWidth" + 7), SUM("ResolutionWidth" + 8), SUM("ResolutionWidth" + 9), SUM("ResolutionWidth" + 10), SUM("ResolutionWidth" + 11), SUM("ResolutionWidth" + 12), SUM("ResolutionWidth" + 13), SUM("ResolutionWidth" + 14), SUM("ResolutionWidth" + 15), SUM("ResolutionWidth" + 16), SUM("ResolutionWidth" + 17), SUM("ResolutionWidth" + 18), SUM("ResolutionWidth" + 19), SUM("ResolutionWidth" + 20), SUM("ResolutionWidth" + 21), SUM("ResolutionWidth" + 22), SUM("ResolutionWidth" + 23), SUM("ResolutionWidth" + 24), SUM("ResolutionWidth" + 25), SUM("ResolutionWidth" + 26), SUM("ResolutionWidth" + 27), SUM("ResolutionWidth" + 28), SUM("ResolutionWidth" + 29), SUM("ResolutionWidth" + 30), SUM("ResolutionWidth" + 31), SUM("ResolutionWidth" + 32), SUM("ResolutionWidth" + 33), SUM("ResolutionWidth" + 34), SUM("ResolutionWidth" + 35), SUM("ResolutionWidth" + 36), SUM("ResolutionWidth" + 37), SUM("ResolutionWidth" + 38), SUM("ResolutionWidth" + 39), SUM("ResolutionWidth" + 40), SUM("ResolutionWidth" + 41), SUM("ResolutionWidth" + 42), SUM("ResolutionWidth" + 43), SUM("ResolutionWidth" + 44), SUM("ResolutionWidth" + 45), SUM("ResolutionWidth" + 46), SUM("ResolutionWidth" + 47), SUM("ResolutionWidth" + 48), SUM("ResolutionWidth" + 49), 
SUM("ResolutionWidth" + 50), SUM("ResolutionWidth" + 51), SUM("ResolutionWidth" + 52), SUM("ResolutionWidth" + 53), SUM("ResolutionWidth" + 54), SUM("ResolutionWidth" + 55), SUM("ResolutionWidth" + 56), SUM("ResolutionWidth" + 57), SUM("ResolutionWidth" + 58), SUM("ResolutionWidth" + 59), SUM("ResolutionWidth" + 60), SUM("ResolutionWidth" + 61), SUM("ResolutionWidth" + 62), SUM("ResolutionWidth" + 63), SUM("ResolutionWidth" + 64), SUM("ResolutionWidth" + 65), SUM("ResolutionWidth" + 66), SUM("ResolutionWidth" + 67), SUM("ResolutionWidth" + 68), SUM("ResolutionWidth" + 69), SUM("ResolutionWidth" + 70), SUM("ResolutionWidth" + 71), SUM("ResolutionWidth" + 72), SUM("ResolutionWidth" + 73), SUM("ResolutionWidth" + 74), SUM("ResolutionWidth" + 75), SUM("ResolutionWidth" + 76), SUM("ResolutionWidth" + 77), SUM("ResolutionWidth" + 78), SUM("ResolutionWidth" + 79), SUM("ResolutionWidth" + 80), SUM("ResolutionWidth" + 81), SUM("ResolutionWidth" + 82), SUM("ResolutionWidth" + 83), SUM("ResolutionWidth" + 84), SUM("ResolutionWidth" + 85), SUM("ResolutionWidth" + 86), SUM("ResolutionWidth" + 87), SUM("ResolutionWidth" + 88), SUM("ResolutionWidth" + 89) FROM hits; -SELECT "SearchEngineID", "ClientIP", COUNT(*) AS c, SUM("IsRefresh"), AVG("ResolutionWidth") FROM hits WHERE "SearchPhrase" <> '' GROUP BY "SearchEngineID", "ClientIP" ORDER BY c DESC LIMIT 10; -SELECT "WatchID", "ClientIP", COUNT(*) AS c, SUM("IsRefresh"), AVG("ResolutionWidth") FROM hits WHERE "SearchPhrase" <> '' GROUP BY "WatchID", "ClientIP" ORDER BY c DESC LIMIT 10; -SELECT "WatchID", "ClientIP", COUNT(*) AS c, SUM("IsRefresh"), AVG("ResolutionWidth") FROM hits GROUP BY "WatchID", "ClientIP" ORDER BY c DESC LIMIT 10; -SELECT "URL", COUNT(*) AS c FROM hits GROUP BY "URL" ORDER BY c DESC LIMIT 10; -SELECT 1, "URL", COUNT(*) AS c FROM hits GROUP BY 1, "URL" ORDER BY c DESC LIMIT 10; -SELECT "ClientIP", "ClientIP" - 1, "ClientIP" - 2, "ClientIP" - 3, COUNT(*) AS c FROM hits GROUP BY "ClientIP", "ClientIP" - 1, 
"ClientIP" - 2, "ClientIP" - 3 ORDER BY c DESC LIMIT 10; -SELECT "URL", COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate" >= '2013-07-01' AND "EventDate" <= '2013-07-31' AND "DontCountHits" = 0 AND "IsRefresh" = 0 AND "URL" <> '' GROUP BY "URL" ORDER BY PageViews DESC LIMIT 10; -SELECT "Title", COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate" >= '2013-07-01' AND "EventDate" <= '2013-07-31' AND "DontCountHits" = 0 AND "IsRefresh" = 0 AND "Title" <> '' GROUP BY "Title" ORDER BY PageViews DESC LIMIT 10; -SELECT "URL", COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate" >= '2013-07-01' AND "EventDate" <= '2013-07-31' AND "IsRefresh" = 0 AND "IsLink" <> 0 AND "IsDownload" = 0 GROUP BY "URL" ORDER BY PageViews DESC LIMIT 10 OFFSET 1000; -SELECT "TraficSourceID", "SearchEngineID", "AdvEngineID", CASE WHEN ("SearchEngineID" = 0 AND "AdvEngineID" = 0) THEN "Referer" ELSE '' END AS Src, "URL" AS Dst, COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate" >= '2013-07-01' AND "EventDate" <= '2013-07-31' AND "IsRefresh" = 0 GROUP BY "TraficSourceID", "SearchEngineID", "AdvEngineID", Src, Dst ORDER BY PageViews DESC LIMIT 10 OFFSET 1000; -SELECT "URLHash", "EventDate", COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate" >= '2013-07-01' AND "EventDate" <= '2013-07-31' AND "IsRefresh" = 0 AND "TraficSourceID" IN (-1, 6) AND "RefererHash" = 3594120000172545465 GROUP BY "URLHash", "EventDate" ORDER BY PageViews DESC LIMIT 10 OFFSET 100; -SELECT "WindowClientWidth", "WindowClientHeight", COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate" >= '2013-07-01' AND "EventDate" <= '2013-07-31' AND "IsRefresh" = 0 AND "DontCountHits" = 0 AND "URLHash" = 2868770270353813622 GROUP BY "WindowClientWidth", "WindowClientHeight" ORDER BY PageViews DESC LIMIT 10 OFFSET 10000; -SELECT DATE_TRUNC('minute', to_timestamp_seconds("EventTime")) AS M, COUNT(*) AS PageViews FROM hits WHERE 
"CounterID" = 62 AND "EventDate" >= '2013-07-14' AND "EventDate" <= '2013-07-15' AND "IsRefresh" = 0 AND "DontCountHits" = 0 GROUP BY DATE_TRUNC('minute', to_timestamp_seconds("EventTime")) ORDER BY DATE_TRUNC('minute', M) LIMIT 10 OFFSET 1000; diff --git a/benchmarks/queries/clickbench/queries/q0.sql b/benchmarks/queries/clickbench/queries/q0.sql new file mode 100644 index 0000000000000..35f2b32ed4863 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q0.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 + +-- set datafusion.execution.parquet.binary_as_string = true +SELECT COUNT(*) FROM hits; diff --git a/benchmarks/queries/clickbench/queries/q1.sql b/benchmarks/queries/clickbench/queries/q1.sql new file mode 100644 index 0000000000000..0bee959ec3c7d --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q1.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT COUNT(*) FROM hits WHERE "AdvEngineID" <> 0; diff --git a/benchmarks/queries/clickbench/queries/q10.sql b/benchmarks/queries/clickbench/queries/q10.sql new file mode 100644 index 0000000000000..0f9114803fecf --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q10.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. 
See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "MobilePhoneModel", COUNT(DISTINCT "UserID") AS u FROM hits WHERE "MobilePhoneModel" <> '' GROUP BY "MobilePhoneModel" ORDER BY u DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q11.sql b/benchmarks/queries/clickbench/queries/q11.sql new file mode 100644 index 0000000000000..bed8bb210e130 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q11.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "MobilePhone", "MobilePhoneModel", COUNT(DISTINCT "UserID") AS u FROM hits WHERE "MobilePhoneModel" <> '' GROUP BY "MobilePhone", "MobilePhoneModel" ORDER BY u DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q12.sql b/benchmarks/queries/clickbench/queries/q12.sql new file mode 100644 index 0000000000000..8cf09c0049f3d --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q12.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "SearchPhrase", COUNT(*) AS c FROM hits WHERE "SearchPhrase" <> '' GROUP BY "SearchPhrase" ORDER BY c DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q13.sql b/benchmarks/queries/clickbench/queries/q13.sql new file mode 100644 index 0000000000000..ef6583c8d1886 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q13.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. 
See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "SearchPhrase", COUNT(DISTINCT "UserID") AS u FROM hits WHERE "SearchPhrase" <> '' GROUP BY "SearchPhrase" ORDER BY u DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q14.sql b/benchmarks/queries/clickbench/queries/q14.sql new file mode 100644 index 0000000000000..dd267146edec5 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q14.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "SearchEngineID", "SearchPhrase", COUNT(*) AS c FROM hits WHERE "SearchPhrase" <> '' GROUP BY "SearchEngineID", "SearchPhrase" ORDER BY c DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q15.sql b/benchmarks/queries/clickbench/queries/q15.sql new file mode 100644 index 0000000000000..721d924cb9b95 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q15.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "UserID", COUNT(*) FROM hits GROUP BY "UserID" ORDER BY COUNT(*) DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q16.sql b/benchmarks/queries/clickbench/queries/q16.sql new file mode 100644 index 0000000000000..389725d58d7a3 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q16.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. 
See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "UserID", "SearchPhrase", COUNT(*) FROM hits GROUP BY "UserID", "SearchPhrase" ORDER BY COUNT(*) DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q17.sql b/benchmarks/queries/clickbench/queries/q17.sql new file mode 100644 index 0000000000000..be9976a01d7a4 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q17.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "UserID", "SearchPhrase", COUNT(*) FROM hits GROUP BY "UserID", "SearchPhrase" LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q18.sql b/benchmarks/queries/clickbench/queries/q18.sql new file mode 100644 index 0000000000000..d649f1edfe2a4 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q18.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "UserID", extract(minute FROM to_timestamp_seconds("EventTime")) AS m, "SearchPhrase", COUNT(*) FROM hits GROUP BY "UserID", m, "SearchPhrase" ORDER BY COUNT(*) DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q19.sql b/benchmarks/queries/clickbench/queries/q19.sql new file mode 100644 index 0000000000000..8212a765730a3 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q19.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. 
See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "UserID" FROM hits WHERE "UserID" = 435090932899640449; diff --git a/benchmarks/queries/clickbench/queries/q2.sql b/benchmarks/queries/clickbench/queries/q2.sql new file mode 100644 index 0000000000000..bcdfad84ec10f --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q2.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT SUM("AdvEngineID"), COUNT(*), AVG("ResolutionWidth") FROM hits; diff --git a/benchmarks/queries/clickbench/queries/q20.sql b/benchmarks/queries/clickbench/queries/q20.sql new file mode 100644 index 0000000000000..a7e488c2abcd8 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q20.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT COUNT(*) FROM hits WHERE "URL" LIKE '%google%'; diff --git a/benchmarks/queries/clickbench/queries/q21.sql b/benchmarks/queries/clickbench/queries/q21.sql new file mode 100644 index 0000000000000..3551689728ede --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q21.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. 
See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "SearchPhrase", MIN("URL"), COUNT(*) AS c FROM hits WHERE "URL" LIKE '%google%' AND "SearchPhrase" <> '' GROUP BY "SearchPhrase" ORDER BY c DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q22.sql b/benchmarks/queries/clickbench/queries/q22.sql new file mode 100644 index 0000000000000..d5f696e75a8c8 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q22.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "SearchPhrase", MIN("URL"), MIN("Title"), COUNT(*) AS c, COUNT(DISTINCT "UserID") FROM hits WHERE "Title" LIKE '%Google%' AND "URL" NOT LIKE '%.google.%' AND "SearchPhrase" <> '' GROUP BY "SearchPhrase" ORDER BY c DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q23.sql b/benchmarks/queries/clickbench/queries/q23.sql new file mode 100644 index 0000000000000..ff399ded6ed8c --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q23.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT * FROM hits WHERE "URL" LIKE '%google%' ORDER BY "EventTime" LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q24.sql b/benchmarks/queries/clickbench/queries/q24.sql new file mode 100644 index 0000000000000..bc7a364151e23 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q24.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. 
See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "SearchPhrase" FROM hits WHERE "SearchPhrase" <> '' ORDER BY "EventTime" LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q25.sql b/benchmarks/queries/clickbench/queries/q25.sql new file mode 100644 index 0000000000000..5332e3451aeaf --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q25.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "SearchPhrase" FROM hits WHERE "SearchPhrase" <> '' ORDER BY "SearchPhrase" LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q26.sql b/benchmarks/queries/clickbench/queries/q26.sql new file mode 100644 index 0000000000000..bc1108aea1255 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q26.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "SearchPhrase" FROM hits WHERE "SearchPhrase" <> '' ORDER BY "EventTime", "SearchPhrase" LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q27.sql b/benchmarks/queries/clickbench/queries/q27.sql new file mode 100644 index 0000000000000..ba234d34f8877 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q27.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. 
See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "CounterID", AVG(length("URL")) AS l, COUNT(*) AS c FROM hits WHERE "URL" <> '' GROUP BY "CounterID" HAVING COUNT(*) > 100000 ORDER BY l DESC LIMIT 25; diff --git a/benchmarks/queries/clickbench/queries/q28.sql b/benchmarks/queries/clickbench/queries/q28.sql new file mode 100644 index 0000000000000..6a3bd037bece7 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q28.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT REGEXP_REPLACE("Referer", '^https?://(?:www\.)?([^/]+)/.*$', '\1') AS k, AVG(length("Referer")) AS l, COUNT(*) AS c, MIN("Referer") FROM hits WHERE "Referer" <> '' GROUP BY k HAVING COUNT(*) > 100000 ORDER BY l DESC LIMIT 25; diff --git a/benchmarks/queries/clickbench/queries/q29.sql b/benchmarks/queries/clickbench/queries/q29.sql new file mode 100644 index 0000000000000..bca1eb7bbe54b --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q29.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. 
See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT SUM("ResolutionWidth"), SUM("ResolutionWidth" + 1), SUM("ResolutionWidth" + 2), SUM("ResolutionWidth" + 3), SUM("ResolutionWidth" + 4), SUM("ResolutionWidth" + 5), SUM("ResolutionWidth" + 6), SUM("ResolutionWidth" + 7), SUM("ResolutionWidth" + 8), SUM("ResolutionWidth" + 9), SUM("ResolutionWidth" + 10), SUM("ResolutionWidth" + 11), SUM("ResolutionWidth" + 12), SUM("ResolutionWidth" + 13), SUM("ResolutionWidth" + 14), SUM("ResolutionWidth" + 15), SUM("ResolutionWidth" + 16), SUM("ResolutionWidth" + 17), SUM("ResolutionWidth" + 18), SUM("ResolutionWidth" + 19), SUM("ResolutionWidth" + 20), SUM("ResolutionWidth" + 21), SUM("ResolutionWidth" + 22), SUM("ResolutionWidth" + 23), SUM("ResolutionWidth" + 24), SUM("ResolutionWidth" + 25), SUM("ResolutionWidth" + 26), SUM("ResolutionWidth" + 27), SUM("ResolutionWidth" + 28), SUM("ResolutionWidth" + 29), SUM("ResolutionWidth" + 30), SUM("ResolutionWidth" + 31), SUM("ResolutionWidth" + 32), SUM("ResolutionWidth" + 33), SUM("ResolutionWidth" + 34), SUM("ResolutionWidth" + 35), SUM("ResolutionWidth" + 36), SUM("ResolutionWidth" + 37), SUM("ResolutionWidth" + 38), SUM("ResolutionWidth" + 39), SUM("ResolutionWidth" + 40), SUM("ResolutionWidth" + 41), SUM("ResolutionWidth" + 42), SUM("ResolutionWidth" + 43), SUM("ResolutionWidth" + 44), SUM("ResolutionWidth" + 45), SUM("ResolutionWidth" + 46), SUM("ResolutionWidth" + 47), SUM("ResolutionWidth" + 48), SUM("ResolutionWidth" + 49), SUM("ResolutionWidth" + 50), SUM("ResolutionWidth" + 51), SUM("ResolutionWidth" + 52), SUM("ResolutionWidth" + 53), SUM("ResolutionWidth" + 54), SUM("ResolutionWidth" + 55), SUM("ResolutionWidth" + 56), SUM("ResolutionWidth" + 57), SUM("ResolutionWidth" + 58), SUM("ResolutionWidth" + 59), SUM("ResolutionWidth" + 60), SUM("ResolutionWidth" + 61), SUM("ResolutionWidth" + 62), SUM("ResolutionWidth" + 63), SUM("ResolutionWidth" + 64), 
SUM("ResolutionWidth" + 65), SUM("ResolutionWidth" + 66), SUM("ResolutionWidth" + 67), SUM("ResolutionWidth" + 68), SUM("ResolutionWidth" + 69), SUM("ResolutionWidth" + 70), SUM("ResolutionWidth" + 71), SUM("ResolutionWidth" + 72), SUM("ResolutionWidth" + 73), SUM("ResolutionWidth" + 74), SUM("ResolutionWidth" + 75), SUM("ResolutionWidth" + 76), SUM("ResolutionWidth" + 77), SUM("ResolutionWidth" + 78), SUM("ResolutionWidth" + 79), SUM("ResolutionWidth" + 80), SUM("ResolutionWidth" + 81), SUM("ResolutionWidth" + 82), SUM("ResolutionWidth" + 83), SUM("ResolutionWidth" + 84), SUM("ResolutionWidth" + 85), SUM("ResolutionWidth" + 86), SUM("ResolutionWidth" + 87), SUM("ResolutionWidth" + 88), SUM("ResolutionWidth" + 89) FROM hits; diff --git a/benchmarks/queries/clickbench/queries/q3.sql b/benchmarks/queries/clickbench/queries/q3.sql new file mode 100644 index 0000000000000..09cdaca713047 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q3.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT AVG("UserID") FROM hits; diff --git a/benchmarks/queries/clickbench/queries/q30.sql b/benchmarks/queries/clickbench/queries/q30.sql new file mode 100644 index 0000000000000..c0d657927478e --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q30.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. 
See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "SearchEngineID", "ClientIP", COUNT(*) AS c, SUM("IsRefresh"), AVG("ResolutionWidth") FROM hits WHERE "SearchPhrase" <> '' GROUP BY "SearchEngineID", "ClientIP" ORDER BY c DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q31.sql b/benchmarks/queries/clickbench/queries/q31.sql new file mode 100644 index 0000000000000..76ab3622ffb57 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q31.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "WatchID", "ClientIP", COUNT(*) AS c, SUM("IsRefresh"), AVG("ResolutionWidth") FROM hits WHERE "SearchPhrase" <> '' GROUP BY "WatchID", "ClientIP" ORDER BY c DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q32.sql b/benchmarks/queries/clickbench/queries/q32.sql new file mode 100644 index 0000000000000..88f1e4ce42d23 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q32.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "WatchID", "ClientIP", COUNT(*) AS c, SUM("IsRefresh"), AVG("ResolutionWidth") FROM hits GROUP BY "WatchID", "ClientIP" ORDER BY c DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q33.sql b/benchmarks/queries/clickbench/queries/q33.sql new file mode 100644 index 0000000000000..3740503bbc0e9 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q33.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. 
See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "URL", COUNT(*) AS c FROM hits GROUP BY "URL" ORDER BY c DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q34.sql b/benchmarks/queries/clickbench/queries/q34.sql new file mode 100644 index 0000000000000..fdb7edbb656ac --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q34.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT 1, "URL", COUNT(*) AS c FROM hits GROUP BY 1, "URL" ORDER BY c DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q35.sql b/benchmarks/queries/clickbench/queries/q35.sql new file mode 100644 index 0000000000000..de7e2256eb551 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q35.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "ClientIP", "ClientIP" - 1, "ClientIP" - 2, "ClientIP" - 3, COUNT(*) AS c FROM hits GROUP BY "ClientIP", "ClientIP" - 1, "ClientIP" - 2, "ClientIP" - 3 ORDER BY c DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q36.sql b/benchmarks/queries/clickbench/queries/q36.sql new file mode 100644 index 0000000000000..81b1199b0381e --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q36.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. 
See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "URL", COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate" >= '2013-07-01' AND "EventDate" <= '2013-07-31' AND "DontCountHits" = 0 AND "IsRefresh" = 0 AND "URL" <> '' GROUP BY "URL" ORDER BY PageViews DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q37.sql b/benchmarks/queries/clickbench/queries/q37.sql new file mode 100644 index 0000000000000..fa4b85ffbd9cb --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q37.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "Title", COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate" >= '2013-07-01' AND "EventDate" <= '2013-07-31' AND "DontCountHits" = 0 AND "IsRefresh" = 0 AND "Title" <> '' GROUP BY "Title" ORDER BY PageViews DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q38.sql b/benchmarks/queries/clickbench/queries/q38.sql new file mode 100644 index 0000000000000..18fafab6c888f --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q38.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "URL", COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate" >= '2013-07-01' AND "EventDate" <= '2013-07-31' AND "IsRefresh" = 0 AND "IsLink" <> 0 AND "IsDownload" = 0 GROUP BY "URL" ORDER BY PageViews DESC LIMIT 10 OFFSET 1000; diff --git a/benchmarks/queries/clickbench/queries/q39.sql b/benchmarks/queries/clickbench/queries/q39.sql new file mode 100644 index 0000000000000..306f0caacff64 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q39.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. 
See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "TraficSourceID", "SearchEngineID", "AdvEngineID", CASE WHEN ("SearchEngineID" = 0 AND "AdvEngineID" = 0) THEN "Referer" ELSE '' END AS Src, "URL" AS Dst, COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate" >= '2013-07-01' AND "EventDate" <= '2013-07-31' AND "IsRefresh" = 0 GROUP BY "TraficSourceID", "SearchEngineID", "AdvEngineID", Src, Dst ORDER BY PageViews DESC LIMIT 10 OFFSET 1000; diff --git a/benchmarks/queries/clickbench/queries/q4.sql b/benchmarks/queries/clickbench/queries/q4.sql new file mode 100644 index 0000000000000..d89ca78c2fb6f --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q4.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT COUNT(DISTINCT "UserID") FROM hits; diff --git a/benchmarks/queries/clickbench/queries/q40.sql b/benchmarks/queries/clickbench/queries/q40.sql new file mode 100644 index 0000000000000..e9d27f5985fa9 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q40.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. 
See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "URLHash", "EventDate", COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate" >= '2013-07-01' AND "EventDate" <= '2013-07-31' AND "IsRefresh" = 0 AND "TraficSourceID" IN (-1, 6) AND "RefererHash" = 3594120000172545465 GROUP BY "URLHash", "EventDate" ORDER BY PageViews DESC LIMIT 10 OFFSET 100; diff --git a/benchmarks/queries/clickbench/queries/q41.sql b/benchmarks/queries/clickbench/queries/q41.sql new file mode 100644 index 0000000000000..0e067e2dfc9da --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q41.sql @@ -0,0 +1,3 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true +SELECT "WindowClientWidth", "WindowClientHeight", COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate" >= '2013-07-01' AND "EventDate" <= '2013-07-31' AND "IsRefresh" = 0 AND "DontCountHits" = 0 AND "URLHash" = 2868770270353813622 GROUP BY "WindowClientWidth", "WindowClientHeight" ORDER BY PageViews DESC LIMIT 10 OFFSET 10000; diff --git a/benchmarks/queries/clickbench/queries/q42.sql b/benchmarks/queries/clickbench/queries/q42.sql new file mode 100644 index 0000000000000..111cc1d3c4a9d --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q42.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. 
See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT DATE_TRUNC('minute', to_timestamp_seconds("EventTime")) AS M, COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate" >= '2013-07-14' AND "EventDate" <= '2013-07-15' AND "IsRefresh" = 0 AND "DontCountHits" = 0 GROUP BY DATE_TRUNC('minute', to_timestamp_seconds("EventTime")) ORDER BY DATE_TRUNC('minute', M) LIMIT 10 OFFSET 1000; diff --git a/benchmarks/queries/clickbench/queries/q5.sql b/benchmarks/queries/clickbench/queries/q5.sql new file mode 100644 index 0000000000000..d371cfb6b3557 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q5.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT COUNT(DISTINCT "SearchPhrase") FROM hits; diff --git a/benchmarks/queries/clickbench/queries/q6.sql b/benchmarks/queries/clickbench/queries/q6.sql new file mode 100644 index 0000000000000..5b4e896a1df26 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q6.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT MIN("EventDate"), MAX("EventDate") FROM hits; diff --git a/benchmarks/queries/clickbench/queries/q7.sql b/benchmarks/queries/clickbench/queries/q7.sql new file mode 100644 index 0000000000000..afffcb1306d54 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q7.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. 
See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "AdvEngineID", COUNT(*) FROM hits WHERE "AdvEngineID" <> 0 GROUP BY "AdvEngineID" ORDER BY COUNT(*) DESC; diff --git a/benchmarks/queries/clickbench/queries/q8.sql b/benchmarks/queries/clickbench/queries/q8.sql new file mode 100644 index 0000000000000..097880a9da5ed --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q8.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "RegionID", COUNT(DISTINCT "UserID") AS u FROM hits GROUP BY "RegionID" ORDER BY u DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q9.sql b/benchmarks/queries/clickbench/queries/q9.sql new file mode 100644 index 0000000000000..cb1b79bf5bdc1 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q9.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "RegionID", SUM("AdvEngineID"), COUNT(*) AS c, AVG("ResolutionWidth"), COUNT(DISTINCT "UserID") FROM hits GROUP BY "RegionID" ORDER BY c DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/sorted_data/q0.sql b/benchmarks/queries/clickbench/queries/sorted_data/q0.sql new file mode 100644 index 0000000000000..1170a383bcb22 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/sorted_data/q0.sql @@ -0,0 +1,3 @@ +-- Must set for ClickBench hits_partitioned dataset. 
See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true +SELECT * FROM hits ORDER BY "EventTime" DESC limit 10; diff --git a/benchmarks/queries/clickbench/update_queries.sh b/benchmarks/queries/clickbench/update_queries.sh new file mode 100755 index 0000000000000..d7db7359aa394 --- /dev/null +++ b/benchmarks/queries/clickbench/update_queries.sh @@ -0,0 +1,80 @@ +#!/usr/bin/env bash +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This script is meant for developers of DataFusion -- it is runnable +# from the standard DataFusion development environment and uses cargo, +# etc and orchestrates gathering data and run the benchmark binary in +# different configurations. + +# Script to download ClickBench queries and split them into individual files + +set -e # Exit on any error + +# URL for the raw file (not the GitHub page) +URL="https://raw.githubusercontent.com/ClickHouse/ClickBench/main/datafusion/queries.sql" + +# Temporary file to store downloaded content +TEMP_FILE="queries.sql" + +TARGET_DIR="queries" + +# Download the file +echo "Downloading queries from $URL..." 
+if command -v curl &> /dev/null; then + curl -s -o "$TEMP_FILE" "$URL" +elif command -v wget &> /dev/null; then + wget -q -O "$TEMP_FILE" "$URL" +else + echo "Error: Neither curl nor wget is available. Please install one of them." + exit 1 +fi + +# Check if download was successful +if [ ! -f "$TEMP_FILE" ] || [ ! -s "$TEMP_FILE" ]; then + echo "Error: Failed to download or file is empty" + exit 1 +fi + +# Initialize counter +counter=0 + +# Ensure the target directory exists +if [ ! -d ${TARGET_DIR} ]; then + mkdir -p ${TARGET_DIR} +fi + +# Read the file line by line and create individual query files +mapfile -t lines < $TEMP_FILE +for line in "${lines[@]}"; do + # Skip empty lines + if [ -n "$line" ]; then + # Create filename with zero-padded counter + filename="q${counter}.sql" + + # Write the line to the individual file + echo "$line" > "${TARGET_DIR}/$filename" + + echo "Created ${TARGET_DIR}/$filename" + + # Increment counter + (( counter += 1 )) + fi +done + +# Clean up temporary file +rm "$TEMP_FILE" \ No newline at end of file diff --git a/benchmarks/queries/q10.sql b/benchmarks/queries/q10.sql index cf45e43485fb5..8613fd4962837 100644 --- a/benchmarks/queries/q10.sql +++ b/benchmarks/queries/q10.sql @@ -28,4 +28,5 @@ group by c_address, c_comment order by - revenue desc; \ No newline at end of file + revenue desc +limit 20; diff --git a/benchmarks/queries/q18.sql b/benchmarks/queries/q18.sql index 835de28a57be2..ba7ee7f716cf1 100644 --- a/benchmarks/queries/q18.sql +++ b/benchmarks/queries/q18.sql @@ -29,4 +29,5 @@ group by o_totalprice order by o_totalprice desc, - o_orderdate; \ No newline at end of file + o_orderdate +limit 100; diff --git a/benchmarks/queries/q2.sql b/benchmarks/queries/q2.sql index f66af210205e9..68e478f65d3f9 100644 --- a/benchmarks/queries/q2.sql +++ b/benchmarks/queries/q2.sql @@ -40,4 +40,5 @@ order by s_acctbal desc, n_name, s_name, - p_partkey; \ No newline at end of file + p_partkey +limit 100; diff --git 
a/benchmarks/queries/q21.sql b/benchmarks/queries/q21.sql index 9d2fe32cee228..b95e7b0dfca02 100644 --- a/benchmarks/queries/q21.sql +++ b/benchmarks/queries/q21.sql @@ -36,4 +36,5 @@ group by s_name order by numwait desc, - s_name; \ No newline at end of file + s_name +limit 100; diff --git a/benchmarks/queries/q3.sql b/benchmarks/queries/q3.sql index 7dbc6d9ef6783..e5fa9e38664c3 100644 --- a/benchmarks/queries/q3.sql +++ b/benchmarks/queries/q3.sql @@ -19,4 +19,5 @@ group by o_shippriority order by revenue desc, - o_orderdate; \ No newline at end of file + o_orderdate +limit 10; diff --git a/benchmarks/src/bin/dfbench.rs b/benchmarks/src/bin/dfbench.rs index 06337cb758885..7e21890519fd1 100644 --- a/benchmarks/src/bin/dfbench.rs +++ b/benchmarks/src/bin/dfbench.rs @@ -18,7 +18,7 @@ //! DataFusion benchmark runner use datafusion::error::Result; -use structopt::StructOpt; +use clap::{Parser, Subcommand}; #[cfg(all(feature = "snmalloc", feature = "mimalloc"))] compile_error!( @@ -34,21 +34,28 @@ static ALLOC: snmalloc_rs::SnMalloc = snmalloc_rs::SnMalloc; static ALLOC: mimalloc::MiMalloc = mimalloc::MiMalloc; use datafusion_benchmarks::{ - cancellation, clickbench, h2o, imdb, parquet_filter, sort, sort_tpch, tpch, + cancellation, clickbench, h2o, hj, imdb, nlj, smj, sort_tpch, tpcds, tpch, }; -#[derive(Debug, StructOpt)] -#[structopt(about = "benchmark command")] +#[derive(Debug, Parser)] +#[command(about = "benchmark command")] +struct Cli { + #[command(subcommand)] + command: Options, +} + +#[derive(Debug, Subcommand)] enum Options { Cancellation(cancellation::RunOpt), Clickbench(clickbench::RunOpt), H2o(h2o::RunOpt), + HJ(hj::RunOpt), Imdb(imdb::RunOpt), - ParquetFilter(parquet_filter::RunOpt), - Sort(sort::RunOpt), + Nlj(nlj::RunOpt), + Smj(smj::RunOpt), SortTpch(sort_tpch::RunOpt), Tpch(tpch::RunOpt), - TpchConvert(tpch::ConvertOpt), + Tpcds(tpcds::RunOpt), } // Main benchmark runner entrypoint @@ -56,15 +63,17 @@ enum Options { pub async fn main() -> 
Result<()> { env_logger::init(); - match Options::from_args() { + let cli = Cli::parse(); + match cli.command { Options::Cancellation(opt) => opt.run().await, Options::Clickbench(opt) => opt.run().await, Options::H2o(opt) => opt.run().await, - Options::Imdb(opt) => opt.run().await, - Options::ParquetFilter(opt) => opt.run().await, - Options::Sort(opt) => opt.run().await, + Options::HJ(opt) => opt.run().await, + Options::Imdb(opt) => Box::pin(opt.run()).await, + Options::Nlj(opt) => opt.run().await, + Options::Smj(opt) => opt.run().await, Options::SortTpch(opt) => opt.run().await, - Options::Tpch(opt) => opt.run().await, - Options::TpchConvert(opt) => opt.run().await, + Options::Tpch(opt) => Box::pin(opt.run()).await, + Options::Tpcds(opt) => Box::pin(opt.run()).await, } } diff --git a/benchmarks/src/bin/external_aggr.rs b/benchmarks/src/bin/external_aggr.rs index 36cd64222cc6b..ee604ec7365a1 100644 --- a/benchmarks/src/bin/external_aggr.rs +++ b/benchmarks/src/bin/external_aggr.rs @@ -17,13 +17,13 @@ //! 
external_aggr binary entrypoint +use clap::{Args, Parser, Subcommand}; use datafusion::execution::memory_pool::GreedyMemoryPool; use datafusion::execution::memory_pool::MemoryPool; use std::collections::HashMap; use std::path::PathBuf; use std::sync::Arc; use std::sync::LazyLock; -use structopt::StructOpt; use arrow::record_batch::RecordBatch; use arrow::util::pretty; @@ -33,55 +33,56 @@ use datafusion::datasource::listing::{ }; use datafusion::datasource::{MemTable, TableProvider}; use datafusion::error::Result; +use datafusion::execution::SessionStateBuilder; use datafusion::execution::memory_pool::FairSpillPool; -use datafusion::execution::memory_pool::{human_readable_size, units}; use datafusion::execution::runtime_env::RuntimeEnvBuilder; -use datafusion::execution::SessionStateBuilder; use datafusion::physical_plan::display::DisplayableExecutionPlan; use datafusion::physical_plan::{collect, displayable}; use datafusion::prelude::*; -use datafusion_benchmarks::util::{BenchmarkRun, CommonOpt}; +use datafusion_benchmarks::util::{BenchmarkRun, CommonOpt, QueryResult}; use datafusion_common::instant::Instant; use datafusion_common::utils::get_available_parallelism; -use datafusion_common::{exec_err, DEFAULT_PARQUET_EXTENSION}; +use datafusion_common::{DEFAULT_PARQUET_EXTENSION, exec_err}; +use datafusion_common::{human_readable_size, units}; -#[derive(Debug, StructOpt)] -#[structopt( +#[derive(Debug, Parser)] +#[command( name = "datafusion-external-aggregation", about = "DataFusion external aggregation benchmark" )] +struct Cli { + #[command(subcommand)] + command: ExternalAggrOpt, +} + +#[derive(Debug, Subcommand)] enum ExternalAggrOpt { Benchmark(ExternalAggrConfig), } -#[derive(Debug, StructOpt)] +#[derive(Debug, Args)] struct ExternalAggrConfig { /// Query number. 
If not specified, runs all queries - #[structopt(short, long)] + #[arg(short, long)] query: Option, /// Common options - #[structopt(flatten)] + #[command(flatten)] common: CommonOpt, /// Path to data files (lineitem). Only parquet format is supported - #[structopt(parse(from_os_str), required = true, short = "p", long = "path")] + #[arg(required = true, short = 'p', long = "path")] path: PathBuf, /// Load the data into a MemTable before executing the query - #[structopt(short = "m", long = "mem-table")] + #[arg(short = 'm', long = "mem-table")] mem_table: bool, /// Path to JSON benchmark result to be compare using `compare.py` - #[structopt(parse(from_os_str), short = "o", long = "output")] + #[arg(short = 'o', long = "output")] output_path: Option, } -struct QueryResult { - elapsed: std::time::Duration, - row_count: usize, -} - /// Query Memory Limits /// Map query id to predefined memory limits /// @@ -118,7 +119,7 @@ impl ExternalAggrConfig { "#, ]; - /// If `--query` and `--memory-limit` is not speicified, run all queries + /// If `--query` and `--memory-limit` is not specified, run all queries /// with pre-configured memory limits /// If only `--query` is specified, run the query with all memory limits /// for this query @@ -343,7 +344,8 @@ impl ExternalAggrConfig { pub async fn main() -> Result<()> { env_logger::init(); - match ExternalAggrOpt::from_args() { + let cli = Cli::parse(); + match cli.command { ExternalAggrOpt::Benchmark(opt) => opt.run().await?, } diff --git a/benchmarks/src/bin/imdb.rs b/benchmarks/src/bin/imdb.rs index 13421f8a89a9b..e86735f87b8f1 100644 --- a/benchmarks/src/bin/imdb.rs +++ b/benchmarks/src/bin/imdb.rs @@ -17,9 +17,9 @@ //! 
IMDB binary entrypoint +use clap::{Parser, Subcommand}; use datafusion::error::Result; use datafusion_benchmarks::imdb; -use structopt::StructOpt; #[cfg(all(feature = "snmalloc", feature = "mimalloc"))] compile_error!( @@ -34,26 +34,32 @@ static ALLOC: snmalloc_rs::SnMalloc = snmalloc_rs::SnMalloc; #[global_allocator] static ALLOC: mimalloc::MiMalloc = mimalloc::MiMalloc; -#[derive(Debug, StructOpt)] -#[structopt(about = "benchmark command")] -enum BenchmarkSubCommandOpt { - #[structopt(name = "datafusion")] - DataFusionBenchmark(imdb::RunOpt), +#[derive(Debug, Parser)] +#[command(name = "IMDB", about = "IMDB Dataset Processing.")] +struct Cli { + #[command(subcommand)] + command: ImdbOpt, } -#[derive(Debug, StructOpt)] -#[structopt(name = "IMDB", about = "IMDB Dataset Processing.")] +#[derive(Debug, Subcommand)] enum ImdbOpt { + #[command(subcommand)] Benchmark(BenchmarkSubCommandOpt), Convert(imdb::ConvertOpt), } +#[derive(Debug, Subcommand)] +enum BenchmarkSubCommandOpt { + #[command(name = "datafusion")] + DataFusionBenchmark(imdb::RunOpt), +} + #[tokio::main] pub async fn main() -> Result<()> { env_logger::init(); - match ImdbOpt::from_args() { + match Cli::parse().command { ImdbOpt::Benchmark(BenchmarkSubCommandOpt::DataFusionBenchmark(opt)) => { - opt.run().await + Box::pin(opt.run()).await } ImdbOpt::Convert(opt) => opt.run().await, } diff --git a/benchmarks/src/bin/mem_profile.rs b/benchmarks/src/bin/mem_profile.rs new file mode 100644 index 0000000000000..41a0baecbba86 --- /dev/null +++ b/benchmarks/src/bin/mem_profile.rs @@ -0,0 +1,357 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! mem_profile binary entrypoint +use clap::{Parser, Subcommand}; +use datafusion::error::Result; +use std::{ + env, + io::{BufRead, BufReader}, + path::Path, + process::{Command, Stdio}, +}; + +use datafusion_benchmarks::{ + clickbench, + h2o::{self, AllQueries}, + imdb, sort_tpch, tpch, +}; + +#[derive(Debug, Parser)] +#[command(name = "Memory Profiling Utility")] +struct Cli { + /// Cargo profile to use in dfbench (e.g. release, release-nonlto) + #[arg(long, default_value = "release")] + bench_profile: String, + + #[command(subcommand)] + command: Options, +} + +#[derive(Debug, Subcommand)] +#[command(about = "Benchmark command")] +enum Options { + Clickbench(clickbench::RunOpt), + H2o(h2o::RunOpt), + Imdb(imdb::RunOpt), + SortTpch(sort_tpch::RunOpt), + Tpch(tpch::RunOpt), +} + +#[tokio::main] +pub async fn main() -> Result<()> { + // 1. Parse args and check which benchmarks should be run + let cli = Cli::parse(); + let profile = cli.bench_profile; + let query_range = match cli.command { + Options::Clickbench(opt) => { + let entries = std::fs::read_dir(&opt.queries_path)? 
+ .filter_map(Result::ok) + .filter(|e| { + let path = e.path(); + path.extension().map(|ext| ext == "sql").unwrap_or(false) + }) + .collect::>(); + + let max_query_id = entries.len().saturating_sub(1); + match opt.query { + Some(query_id) => query_id..=query_id, + None => 0..=max_query_id, + } + } + Options::H2o(opt) => { + let queries = AllQueries::try_new(&opt.queries_path)?; + match opt.query { + Some(query_id) => query_id..=query_id, + None => queries.min_query_id()..=queries.max_query_id(), + } + } + Options::Imdb(opt) => match opt.query { + Some(query_id) => query_id..=query_id, + None => imdb::IMDB_QUERY_START_ID..=imdb::IMDB_QUERY_END_ID, + }, + Options::SortTpch(opt) => match opt.query { + Some(query_id) => query_id..=query_id, + None => { + sort_tpch::SORT_TPCH_QUERY_START_ID..=sort_tpch::SORT_TPCH_QUERY_END_ID + } + }, + Options::Tpch(opt) => match opt.query { + Some(query_id) => query_id..=query_id, + None => tpch::TPCH_QUERY_START_ID..=tpch::TPCH_QUERY_END_ID, + }, + }; + + // 2. Prebuild dfbench binary so that memory does not blow up due to build process + println!("Pre-building benchmark binary..."); + let status = Command::new("cargo") + .args([ + "build", + "--profile", + &profile, + "--features", + "mimalloc_extended", + "--bin", + "dfbench", + ]) + .status() + .expect("Failed to build dfbench"); + assert!(status.success()); + println!("Benchmark binary built successfully."); + + // 3. 
Create a new process per each benchmark query and print summary + // Find position of subcommand to collect args for dfbench + let args: Vec<_> = env::args().collect(); + let subcommands = ["tpch", "clickbench", "h2o", "imdb", "sort-tpch"]; + let sub_pos = args + .iter() + .position(|s| subcommands.iter().any(|&cmd| s == cmd)) + .expect("No benchmark subcommand found"); + + // Args starting from subcommand become dfbench args + let mut dfbench_args: Vec = + args[sub_pos..].iter().map(|s| s.to_string()).collect(); + + run_benchmark_as_child_process(&profile, query_range, &mut dfbench_args)?; + + Ok(()) +} + +fn run_benchmark_as_child_process( + profile: &str, + query_range: std::ops::RangeInclusive, + args: &mut Vec, +) -> Result<()> { + let mut query_strings: Vec = Vec::new(); + for i in query_range { + query_strings.push(i.to_string()); + } + + let target_dir = + env::var("CARGO_TARGET_DIR").unwrap_or_else(|_| "target".to_string()); + let command = format!("{target_dir}/{profile}/dfbench"); + // Check whether benchmark binary exists + if !Path::new(&command).exists() { + panic!( + "Benchmark binary not found: `{command}`\nRun this command from the top-level `datafusion/` directory so `target/{profile}/dfbench` can be found.", + ); + } + args.insert(0, command); + let mut results = vec![]; + + // Run Single Query (args already contain --query num) + if args.contains(&"--query".to_string()) { + let _ = run_query(args, &mut results); + print_summary_table(&results); + return Ok(()); + } + + // Run All Queries + args.push("--query".to_string()); + for query_str in query_strings { + args.push(query_str); + let _ = run_query(args, &mut results); + args.pop(); + } + + print_summary_table(&results); + Ok(()) +} + +fn run_query(args: &[String], results: &mut Vec) -> Result<()> { + let exec_path = &args[0]; + let exec_args = &args[1..]; + + let mut child = Command::new(exec_path) + .args(exec_args) + .stdout(Stdio::piped()) + .spawn() + .expect("Failed to start benchmark"); 
+ + let stdout = child.stdout.take().unwrap(); + let reader = BufReader::new(stdout); + + // Buffer child's stdout + let lines: Result, std::io::Error> = + reader.lines().collect::>(); + + child + .wait() + .expect("Benchmark process exited with an error"); + + // Parse after child process terminates + let lines = lines?; + let mut iter = lines.iter().peekable(); + + // Look for lines that contain execution time / memory stats + while let Some(line) = iter.next() { + if let Some((query, duration_ms)) = parse_query_time(line) + && let Some(next_line) = iter.peek() + && let Some((peak_rss, peak_commit, page_faults)) = parse_vm_line(next_line) + { + results.push(QueryResult { + query, + duration_ms, + peak_rss, + peak_commit, + page_faults, + }); + break; + } + } + + Ok(()) +} + +#[derive(Debug)] +struct QueryResult { + query: usize, + duration_ms: f64, + peak_rss: String, + peak_commit: String, + page_faults: String, +} + +fn parse_query_time(line: &str) -> Option<(usize, f64)> { + let re = regex::Regex::new(r"Query (\d+) avg time: ([\d.]+) ms").unwrap(); + if let Some(caps) = re.captures(line) { + let query_id = caps[1].parse::().ok()?; + let avg_time = caps[2].parse::().ok()?; + Some((query_id, avg_time)) + } else { + None + } +} + +fn parse_vm_line(line: &str) -> Option<(String, String, String)> { + let re = regex::Regex::new( + r"Peak RSS:\s*([\d.]+\s*[A-Z]+),\s*Peak Commit:\s*([\d.]+\s*[A-Z]+),\s*Page Faults:\s*([\d.]+)" + ).ok()?; + let caps = re.captures(line)?; + let peak_rss = caps.get(1)?.as_str().to_string(); + let peak_commit = caps.get(2)?.as_str().to_string(); + let page_faults = caps.get(3)?.as_str().to_string(); + Some((peak_rss, peak_commit, page_faults)) +} + +// Print as simple aligned table +fn print_summary_table(results: &[QueryResult]) { + println!( + "\n{:<8} {:>10} {:>12} {:>12} {:>18}", + "Query", "Time (ms)", "Peak RSS", "Peak Commit", "Major Page Faults" + ); + println!("{}", "-".repeat(64)); + + for r in results { + println!( + "{:<8} 
{:>10.2} {:>12} {:>12} {:>18}", + r.query, r.duration_ms, r.peak_rss, r.peak_commit, r.page_faults + ); + } +} + +#[cfg(test)] +// Only run with "ci" mode when we have the data +#[cfg(feature = "ci")] +mod tests { + use datafusion::common::exec_err; + use datafusion::error::Result; + use std::path::{Path, PathBuf}; + use std::process::Command; + + fn get_tpch_data_path() -> Result { + let path = + std::env::var("TPCH_DATA").unwrap_or_else(|_| "benchmarks/data".to_string()); + if !Path::new(&path).exists() { + return exec_err!( + "Benchmark data not found (set TPCH_DATA env var to override): {}", + path + ); + } + Ok(path) + } + + // Try to find target/ dir upward + fn find_target_dir(start: &Path) -> Option { + let mut dir = start; + + while let Some(current) = Some(dir) { + if current.join("target").is_dir() { + return Some(current.join("target")); + } + + dir = match current.parent() { + Some(parent) => parent, + None => break, + }; + } + + None + } + + #[test] + // This test checks whether `mem_profile` runs successfully and produces expected output + // using TPC-H query 6 (which runs quickly). + fn mem_profile_e2e_tpch_q6() -> Result<()> { + let profile = "ci"; + let tpch_data = get_tpch_data_path()?; + + // The current working directory may not be the top-level datafusion/ directory, + // so we manually walkdir upward, locate the target directory + // and set it explicitly via CARGO_TARGET_DIR for the mem_profile command. 
+ let target_dir = find_target_dir(&std::env::current_dir()?); + let output = Command::new("cargo") + .env("CARGO_TARGET_DIR", target_dir.unwrap()) + .args([ + "run", + "--profile", + profile, + "--bin", + "mem_profile", + "--", + "--bench-profile", + profile, + "tpch", + "--query", + "6", + "--path", + &tpch_data, + "--format", + "tbl", + ]) + .output() + .expect("Failed to run mem_profile"); + + let stdout = String::from_utf8_lossy(&output.stdout); + let stderr = String::from_utf8_lossy(&output.stderr); + + if !output.status.success() { + panic!( + "mem_profile failed\nstdout:\n{stdout}\nstderr:\n{stderr}---------------------", + ); + } + + assert!( + stdout.contains("Peak RSS") + && stdout.contains("Query") + && stdout.contains("Time"), + "Unexpected output:\n{stdout}", + ); + + Ok(()) + } +} diff --git a/benchmarks/src/bin/parquet.rs b/benchmarks/src/bin/parquet.rs deleted file mode 100644 index 6351a71a7bd3f..0000000000000 --- a/benchmarks/src/bin/parquet.rs +++ /dev/null @@ -1,49 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -use datafusion::common::Result; - -use datafusion_benchmarks::{parquet_filter, sort}; -use structopt::StructOpt; - -#[cfg(feature = "snmalloc")] -#[global_allocator] -static ALLOC: snmalloc_rs::SnMalloc = snmalloc_rs::SnMalloc; - -#[derive(Debug, Clone, StructOpt)] -#[structopt(name = "Benchmarks", about = "Apache DataFusion Rust Benchmarks.")] -enum ParquetBenchCmd { - /// Benchmark sorting parquet files - Sort(sort::RunOpt), - /// Benchmark parquet filter pushdown - Filter(parquet_filter::RunOpt), -} - -#[tokio::main] -async fn main() -> Result<()> { - let cmd = ParquetBenchCmd::from_args(); - match cmd { - ParquetBenchCmd::Filter(opt) => { - println!("running filter benchmarks"); - opt.run().await - } - ParquetBenchCmd::Sort(opt) => { - println!("running sort benchmarks"); - opt.run().await - } - } -} diff --git a/benchmarks/src/bin/tpch.rs b/benchmarks/src/bin/tpch.rs deleted file mode 100644 index 3270b082cfb43..0000000000000 --- a/benchmarks/src/bin/tpch.rs +++ /dev/null @@ -1,65 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! 
tpch binary only entrypoint - -use datafusion::error::Result; -use datafusion_benchmarks::tpch; -use structopt::StructOpt; - -#[cfg(all(feature = "snmalloc", feature = "mimalloc"))] -compile_error!( - "feature \"snmalloc\" and feature \"mimalloc\" cannot be enabled at the same time" -); - -#[cfg(feature = "snmalloc")] -#[global_allocator] -static ALLOC: snmalloc_rs::SnMalloc = snmalloc_rs::SnMalloc; - -#[cfg(feature = "mimalloc")] -#[global_allocator] -static ALLOC: mimalloc::MiMalloc = mimalloc::MiMalloc; - -#[derive(Debug, StructOpt)] -#[structopt(about = "benchmark command")] -enum BenchmarkSubCommandOpt { - #[structopt(name = "datafusion")] - DataFusionBenchmark(tpch::RunOpt), -} - -#[derive(Debug, StructOpt)] -#[structopt(name = "TPC-H", about = "TPC-H Benchmarks.")] -enum TpchOpt { - Benchmark(BenchmarkSubCommandOpt), - Convert(tpch::ConvertOpt), -} - -/// 'tpch' entry point, with tortured command line arguments. Please -/// use `dbbench` instead. -/// -/// Note: this is kept to be backwards compatible with the benchmark names prior to -/// -#[tokio::main] -async fn main() -> Result<()> { - env_logger::init(); - match TpchOpt::from_args() { - TpchOpt::Benchmark(BenchmarkSubCommandOpt::DataFusionBenchmark(opt)) => { - opt.run().await - } - TpchOpt::Convert(opt) => opt.run().await, - } -} diff --git a/benchmarks/src/cancellation.rs b/benchmarks/src/cancellation.rs index fcf03fbc54550..d3da1b0e83623 100644 --- a/benchmarks/src/cancellation.rs +++ b/benchmarks/src/cancellation.rs @@ -24,24 +24,24 @@ use crate::util::{BenchmarkRun, CommonOpt}; use arrow::array::Array; use arrow::datatypes::DataType; use arrow::record_batch::RecordBatch; +use clap::Args; use datafusion::common::{Result, ScalarValue}; -use datafusion::datasource::file_format::parquet::ParquetFormat; use datafusion::datasource::file_format::FileFormat; +use datafusion::datasource::file_format::parquet::ParquetFormat; use datafusion::datasource::listing::{ListingOptions, ListingTableUrl}; -use 
datafusion::execution::object_store::ObjectStoreUrl; use datafusion::execution::TaskContext; -use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; +use datafusion::execution::object_store::ObjectStoreUrl; use datafusion::physical_plan::ExecutionPlan; +use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion::prelude::*; use datafusion_common::instant::Instant; use futures::TryStreamExt; use object_store::ObjectStore; -use parquet::arrow::async_writer::ParquetObjectWriter; use parquet::arrow::AsyncArrowWriter; +use parquet::arrow::async_writer::ParquetObjectWriter; +use rand::Rng; use rand::distr::Alphanumeric; use rand::rngs::ThreadRng; -use rand::Rng; -use structopt::StructOpt; use tokio::runtime::Runtime; use tokio_util::sync::CancellationToken; @@ -57,31 +57,31 @@ use tokio_util::sync::CancellationToken; /// The query is an anonymized version of a real-world query, and the /// test starts the query then cancels it and reports how long it takes /// for the runtime to fully exit. 
-#[derive(Debug, StructOpt, Clone)] -#[structopt(verbatim_doc_comment)] +#[derive(Debug, Args, Clone)] +#[command(verbatim_doc_comment)] pub struct RunOpt { /// Common options - #[structopt(flatten)] + #[command(flatten)] common: CommonOpt, /// Path to folder where data will be generated - #[structopt(parse(from_os_str), required = true, short = "p", long = "path")] + #[arg(required = true, short = 'p', long = "path")] path: PathBuf, /// Path to machine readable output file - #[structopt(parse(from_os_str), short = "o", long = "output")] + #[arg(short = 'o', long = "output")] output_path: Option, /// Number of files to generate - #[structopt(long = "num-files", default_value = "7")] + #[arg(long = "num-files", default_value = "7")] num_files: usize, /// Number of rows per file to generate - #[structopt(long = "num-rows-per-file", default_value = "5000000")] + #[arg(long = "num-rows-per-file", default_value = "5000000")] num_rows_per_file: usize, /// How long to wait, in milliseconds, before attempting to cancel - #[structopt(long = "wait-time", default_value = "100")] + #[arg(long = "wait-time", default_value = "100")] wait_time: u64, } diff --git a/benchmarks/src/clickbench.rs b/benchmarks/src/clickbench.rs index 2e934346748e1..70aaeb7d2d192 100644 --- a/benchmarks/src/clickbench.rs +++ b/benchmarks/src/clickbench.rs @@ -15,19 +15,31 @@ // specific language governing permissions and limitations // under the License. 
-use std::path::Path; -use std::path::PathBuf; +use std::fs; +use std::io::ErrorKind; +use std::path::{Path, PathBuf}; -use crate::util::{BenchmarkRun, CommonOpt}; +use crate::util::{BenchmarkRun, CommonOpt, QueryResult, print_memory_stats}; +use clap::Args; +use datafusion::logical_expr::{ExplainFormat, ExplainOption}; use datafusion::{ error::{DataFusionError, Result}, prelude::SessionContext, }; use datafusion_common::exec_datafusion_err; use datafusion_common::instant::Instant; -use structopt::StructOpt; -/// Run the clickbench benchmark +/// SQL to create the hits view with proper EventDate casting. +/// +/// ClickBench stores EventDate as UInt16 (days since 1970-01-01) for +/// storage efficiency (2 bytes vs 4-8 bytes for date types). +/// This view transforms it to SQL DATE type for query compatibility. +const HITS_VIEW_DDL: &str = r#"CREATE VIEW hits AS +SELECT * EXCEPT ("EventDate"), + CAST(CAST("EventDate" AS INTEGER) AS DATE) AS "EventDate" +FROM hits_raw"#; + +/// Driver program to run the ClickBench benchmark /// /// The ClickBench[1] benchmarks are widely cited in the industry and /// focus on grouping / aggregation / filtering. This runner uses the @@ -35,140 +47,308 @@ use structopt::StructOpt; /// /// [1]: https://github.com/ClickHouse/ClickBench /// [2]: https://github.com/ClickHouse/ClickBench/tree/main/datafusion -#[derive(Debug, StructOpt, Clone)] -#[structopt(verbatim_doc_comment)] +#[derive(Debug, Args, Clone)] +#[command(verbatim_doc_comment)] pub struct RunOpt { /// Query number (between 0 and 42). If not specified, runs all queries - #[structopt(short, long)] - query: Option, + #[arg(short, long)] + pub query: Option, + + /// If specified, enables Parquet Filter Pushdown. 
+ /// + /// Specifically, it enables: + /// * `pushdown_filters = true` + /// * `reorder_filters = true` + #[arg(long = "pushdown")] + pushdown: bool, /// Common options - #[structopt(flatten)] + #[command(flatten)] common: CommonOpt, /// Path to hits.parquet (single file) or `hits_partitioned` /// (partitioned, 100 files) - #[structopt( - parse(from_os_str), - short = "p", + #[arg( + short = 'p', long = "path", default_value = "benchmarks/data/hits.parquet" )] path: PathBuf, - /// Path to queries.sql (single file) - #[structopt( - parse(from_os_str), - short = "r", + /// Path to queries directory + #[arg( + short = 'r', long = "queries-path", - default_value = "benchmarks/queries/clickbench/queries.sql" + default_value = "benchmarks/queries/clickbench/queries" )] - queries_path: PathBuf, + pub queries_path: PathBuf, /// If present, write results json here - #[structopt(parse(from_os_str), short = "o", long = "output")] + #[arg(short = 'o', long = "output")] output_path: Option, -} -struct AllQueries { - queries: Vec, -} + /// Column name that the data is sorted by (e.g., "EventTime") + /// If specified, DataFusion will be informed that the data has this sort order + /// using CREATE EXTERNAL TABLE with WITH ORDER clause. + /// + /// Recommended to use with: -c datafusion.optimizer.prefer_existing_sort=true + /// This allows DataFusion to optimize away redundant sorts while maintaining + /// multi-core parallelism for other operations. 
+ #[arg(long = "sorted-by")] + sorted_by: Option, -impl AllQueries { - fn try_new(path: &Path) -> Result { - // ClickBench has all queries in a single file identified by line number - let all_queries = std::fs::read_to_string(path) - .map_err(|e| exec_datafusion_err!("Could not open {path:?}: {e}"))?; - Ok(Self { - queries: all_queries.lines().map(|s| s.to_string()).collect(), - }) - } + /// Sort order: ASC or DESC (default: ASC) + #[arg(long = "sort-order", default_value = "ASC")] + sort_order: String, - /// Returns the text of query `query_id` - fn get_query(&self, query_id: usize) -> Result<&str> { - self.queries - .get(query_id) - .ok_or_else(|| { - let min_id = self.min_query_id(); - let max_id = self.max_query_id(); - exec_datafusion_err!( - "Invalid query id {query_id}. Must be between {min_id} and {max_id}" - ) - }) - .map(|s| s.as_str()) - } + /// Configuration options in the format key=value + /// Can be specified multiple times. + /// + /// Example: -c datafusion.optimizer.prefer_existing_sort=true + #[arg(short = 'c', long = "config")] + config_options: Vec, +} - fn min_query_id(&self) -> usize { - 0 - } +/// Get the SQL file path +pub fn get_query_path(query_dir: &Path, query: usize) -> PathBuf { + let mut query_path = query_dir.to_path_buf(); + query_path.push(format!("q{query}.sql")); + query_path +} - fn max_query_id(&self) -> usize { - self.queries.len() - 1 +/// Get the SQL statement from the specified query file +pub fn get_query_sql(query_path: &Path) -> Result> { + if fs::exists(query_path)? 
{ + Ok(Some(fs::read_to_string(query_path)?)) + } else { + Ok(None) } } + impl RunOpt { pub async fn run(self) -> Result<()> { println!("Running benchmarks with the following options: {self:?}"); - let queries = AllQueries::try_new(self.queries_path.as_path())?; + + let query_dir_metadata = fs::metadata(&self.queries_path).map_err(|e| { + if e.kind() == ErrorKind::NotFound { + exec_datafusion_err!( + "Query path '{}' does not exist.", + &self.queries_path.to_str().unwrap() + ) + } else { + DataFusionError::External(Box::new(e)) + } + })?; + + if !query_dir_metadata.is_dir() { + return Err(exec_datafusion_err!( + "Query path '{}' is not a directory.", + &self.queries_path.to_str().unwrap() + )); + } + let query_range = match self.query { Some(query_id) => query_id..=query_id, - None => queries.min_query_id()..=queries.max_query_id(), + None => 0..=usize::MAX, }; // configure parquet options let mut config = self.common.config()?; + + if self.sorted_by.is_some() { + println!("ℹ️ Data is registered with sort order"); + + let has_prefer_sort = self + .config_options + .iter() + .any(|opt| opt.contains("prefer_existing_sort=true")); + + if !has_prefer_sort { + println!( + "ℹ️ Consider using -c datafusion.optimizer.prefer_existing_sort=true" + ); + println!("ℹ️ to optimize queries while maintaining parallelism"); + } + } + + // Apply user-provided configuration options + for config_opt in &self.config_options { + let parts: Vec<&str> = config_opt.splitn(2, '=').collect(); + if parts.len() != 2 { + return Err(exec_datafusion_err!( + "Invalid config option format: '{}'. Expected 'key=value'", + config_opt + )); + } + let key = parts[0]; + let value = parts[1]; + + println!("Setting config: {key} = {value}"); + config = config.set_str(key, value); + } + { let parquet_options = &mut config.options_mut().execution.parquet; // The hits_partitioned dataset specifies string columns // as binary due to how it was written. 
Force it to strings parquet_options.binary_as_string = true; + + // Turn on Parquet filter pushdown if requested + if self.pushdown { + parquet_options.pushdown_filters = true; + parquet_options.reorder_filters = true; + } + + if self.sorted_by.is_some() { + // We should compare the dynamic topk optimization when data is sorted, so we make the + // assumption that filter pushdown is also enabled in this case. + parquet_options.pushdown_filters = true; + parquet_options.reorder_filters = true; + } } - let rt_builder = self.common.runtime_env_builder()?; - let ctx = SessionContext::new_with_config_rt(config, rt_builder.build_arc()?); + let rt = self.common.build_runtime()?; + let ctx = SessionContext::new_with_config_rt(config, rt); + self.register_hits(&ctx).await?; - let iterations = self.common.iterations; let mut benchmark_run = BenchmarkRun::new(); for query_id in query_range { - let mut millis = Vec::with_capacity(iterations); + let query_path = get_query_path(&self.queries_path, query_id); + let Some(sql) = get_query_sql(&query_path)? 
else { + if self.query.is_some() { + return Err(exec_datafusion_err!( + "Could not load query file '{}'.", + &query_path.to_str().unwrap() + )); + } + break; + }; benchmark_run.start_new_case(&format!("Query {query_id}")); - let sql = queries.get_query(query_id)?; - println!("Q{query_id}: {sql}"); - - for i in 0..iterations { - let start = Instant::now(); - let results = ctx.sql(sql).await?.collect().await?; - let elapsed = start.elapsed(); - let ms = elapsed.as_secs_f64() * 1000.0; - millis.push(ms); - let row_count: usize = results.iter().map(|b| b.num_rows()).sum(); - println!( - "Query {query_id} iteration {i} took {ms:.1} ms and returned {row_count} rows" - ); - benchmark_run.write_iter(elapsed, row_count); - } - if self.common.debug { - ctx.sql(sql).await?.explain(false, false)?.show().await?; + let query_run = self.benchmark_query(&sql, query_id, &ctx).await; + match query_run { + Ok(query_results) => { + for iter in query_results { + benchmark_run.write_iter(iter.elapsed, iter.row_count); + } + } + Err(e) => { + benchmark_run.mark_failed(); + eprintln!("Query {query_id} failed: {e}"); + } } - let avg = millis.iter().sum::() / millis.len() as f64; - println!("Query {query_id} avg time: {avg:.2} ms"); } benchmark_run.maybe_write_json(self.output_path.as_ref())?; + benchmark_run.maybe_print_failures(); Ok(()) } + async fn benchmark_query( + &self, + sql: &str, + query_id: usize, + ctx: &SessionContext, + ) -> Result> { + println!("Q{query_id}: {sql}"); + + let mut millis = Vec::with_capacity(self.iterations()); + let mut query_results = vec![]; + for i in 0..self.iterations() { + let start = Instant::now(); + let results = ctx.sql(sql).await?.collect().await?; + let elapsed = start.elapsed(); + let ms = elapsed.as_secs_f64() * 1000.0; + millis.push(ms); + let row_count: usize = results.iter().map(|b| b.num_rows()).sum(); + println!( + "Query {query_id} iteration {i} took {ms:.1} ms and returned {row_count} rows" + ); + query_results.push(QueryResult { elapsed, 
row_count }) + } + if self.common.debug { + ctx.sql(sql) + .await? + .explain_with_options( + ExplainOption::default().with_format(ExplainFormat::Tree), + )? + .show() + .await?; + } + let avg = millis.iter().sum::() / millis.len() as f64; + println!("Query {query_id} avg time: {avg:.2} ms"); + + // Print memory usage stats using mimalloc (only when compiled with --features mimalloc_extended) + print_memory_stats(); + + Ok(query_results) + } + /// Registers the `hits.parquet` as a table named `hits` + /// If sorted_by is specified, uses CREATE EXTERNAL TABLE with WITH ORDER async fn register_hits(&self, ctx: &SessionContext) -> Result<()> { - let options = Default::default(); let path = self.path.as_os_str().to_str().unwrap(); - ctx.register_parquet("hits", path, options) - .await - .map_err(|e| { - DataFusionError::Context( - format!("Registering 'hits' as {path}"), - Box::new(e), - ) - }) + + // If sorted_by is specified, use CREATE EXTERNAL TABLE with WITH ORDER + if let Some(ref sort_column) = self.sorted_by { + println!( + "Registering table with sort order: {} {}", + sort_column, self.sort_order + ); + + // Escape column name with double quotes + let escaped_column = if sort_column.contains('"') { + sort_column.clone() + } else { + format!("\"{sort_column}\"") + }; + + // Build CREATE EXTERNAL TABLE DDL with WITH ORDER clause + // Schema will be automatically inferred from the Parquet file + let create_table_sql = format!( + "CREATE EXTERNAL TABLE hits_raw \ + STORED AS PARQUET \ + LOCATION '{}' \ + WITH ORDER ({} {})", + path, + escaped_column, + self.sort_order.to_uppercase() + ); + + println!("Executing: {create_table_sql}"); + + // Execute the CREATE EXTERNAL TABLE statement + ctx.sql(&create_table_sql).await?.collect().await?; + } else { + // Original registration without sort order + let options = Default::default(); + ctx.register_parquet("hits_raw", path, options) + .await + .map_err(|e| { + DataFusionError::Context( + format!("Registering 'hits_raw' 
as {path}"), + Box::new(e), + ) + })?; + } + + // Create the hits view with EventDate transformation + Self::create_hits_view(ctx).await + } + + /// Creates the hits view with EventDate transformation from UInt16 to DATE. + /// + /// ClickBench encodes EventDate as UInt16 days since epoch (1970-01-01). + async fn create_hits_view(ctx: &SessionContext) -> Result<()> { + ctx.sql(HITS_VIEW_DDL).await?.collect().await.map_err(|e| { + DataFusionError::Context( + "Creating 'hits' view with EventDate transformation".to_string(), + Box::new(e), + ) + })?; + Ok(()) + } + + fn iterations(&self) -> usize { + self.common.iterations } } diff --git a/benchmarks/src/h2o.rs b/benchmarks/src/h2o.rs index 23dba07f426da..8b6e04932cb39 100644 --- a/benchmarks/src/h2o.rs +++ b/benchmarks/src/h2o.rs @@ -20,41 +20,40 @@ //! - [H2O AI Benchmark](https://duckdb.org/2023/04/14/h2oai.html) //! - [Extended window function benchmark](https://duckdb.org/2024/06/26/benchmarks-over-time.html#window-functions-benchmark) -use crate::util::{BenchmarkRun, CommonOpt}; +use crate::util::{BenchmarkRun, CommonOpt, print_memory_stats}; +use clap::Args; +use datafusion::logical_expr::{ExplainFormat, ExplainOption}; use datafusion::{error::Result, prelude::SessionContext}; use datafusion_common::{ - exec_datafusion_err, instant::Instant, internal_err, DataFusionError, + DataFusionError, TableReference, exec_datafusion_err, instant::Instant, internal_err, }; use std::path::{Path, PathBuf}; -use structopt::StructOpt; /// Run the H2O benchmark -#[derive(Debug, StructOpt, Clone)] -#[structopt(verbatim_doc_comment)] +#[derive(Debug, Args, Clone)] +#[command(verbatim_doc_comment)] pub struct RunOpt { - #[structopt(short, long)] - query: Option, + #[arg(short, long)] + pub query: Option, /// Common options - #[structopt(flatten)] + #[command(flatten)] common: CommonOpt, /// Path to queries.sql (single file) /// default value is the groupby.sql file in the h2o benchmark - #[structopt( - parse(from_os_str), - short 
= "r", + #[arg( + short = 'r', long = "queries-path", default_value = "benchmarks/queries/h2o/groupby.sql" )] - queries_path: PathBuf, + pub queries_path: PathBuf, /// Path to data file (parquet or csv) /// Default value is the G1_1e7_1e7_100_0.csv file in the h2o benchmark /// This is the small csv file with 10^7 rows - #[structopt( - parse(from_os_str), - short = "p", + #[arg( + short = 'p', long = "path", default_value = "benchmarks/data/h2o/G1_1e7_1e7_100_0.csv" )] @@ -63,15 +62,15 @@ pub struct RunOpt { /// Path to data files (parquet or csv), using , to separate the paths /// Default value is the small files for join x table, small table, medium table, big table files in the h2o benchmark /// This is the small csv file case - #[structopt( - short = "join-paths", + #[arg( + short = 'j', long = "join-paths", default_value = "benchmarks/data/h2o/J1_1e7_NA_0.csv,benchmarks/data/h2o/J1_1e7_1e1_0.csv,benchmarks/data/h2o/J1_1e7_1e4_0.csv,benchmarks/data/h2o/J1_1e7_1e7_NA.csv" )] join_paths: String, /// If present, write results json here - #[structopt(parse(from_os_str), short = "o", long = "output")] + #[arg(short = 'o', long = "output")] output_path: Option, } @@ -85,24 +84,24 @@ impl RunOpt { }; let config = self.common.config()?; - let rt_builder = self.common.runtime_env_builder()?; - let ctx = SessionContext::new_with_config_rt(config, rt_builder.build_arc()?); + let rt = self.common.build_runtime()?; + let ctx = SessionContext::new_with_config_rt(config, rt); // Register tables depending on which h2o benchmark is being run // (groupby/join/window) if self.queries_path.to_str().unwrap().ends_with("groupby.sql") { - self.register_data(&ctx).await?; + self.register_data("x", self.path.as_os_str().to_str().unwrap(), &ctx) + .await?; } else if self.queries_path.to_str().unwrap().ends_with("join.sql") { let join_paths: Vec<&str> = self.join_paths.split(',').collect(); let table_name: Vec<&str> = vec!["x", "small", "medium", "large"]; for (i, path) in 
join_paths.iter().enumerate() { - ctx.register_csv(table_name[i], path, Default::default()) - .await?; + self.register_data(table_name[i], path, &ctx).await?; } } else if self.queries_path.to_str().unwrap().ends_with("window.sql") { // Only register the 'large' table in h2o-join dataset let h2o_join_large_path = self.join_paths.split(',').nth(3).unwrap(); - ctx.register_csv("large", h2o_join_large_path, Default::default()) + self.register_data("large", h2o_join_large_path, &ctx) .await?; } else { return internal_err!("Invalid query file path"); @@ -131,8 +130,17 @@ impl RunOpt { let avg = millis.iter().sum::() / millis.len() as f64; println!("Query {query_id} avg time: {avg:.2} ms"); + // Print memory usage stats using mimalloc (only when compiled with --features mimalloc_extended) + print_memory_stats(); + if self.common.debug { - ctx.sql(sql).await?.explain(false, false)?.show().await?; + ctx.sql(sql) + .await? + .explain_with_options( + ExplainOption::default().with_format(ExplainFormat::Tree), + )? 
+ .show() + .await?; } benchmark_run.maybe_write_json(self.output_path.as_ref())?; } @@ -140,49 +148,62 @@ impl RunOpt { Ok(()) } - async fn register_data(&self, ctx: &SessionContext) -> Result<()> { + async fn register_data( + &self, + table_ref: impl Into, + table_path: impl AsRef, + ctx: &SessionContext, + ) -> Result<()> { let csv_options = Default::default(); let parquet_options = Default::default(); - let path = self.path.as_os_str().to_str().unwrap(); - - if self.path.extension().map(|s| s == "csv").unwrap_or(false) { - ctx.register_csv("x", path, csv_options) - .await - .map_err(|e| { - DataFusionError::Context( - format!("Registering 'table' as {path}"), - Box::new(e), - ) - }) - .expect("error registering csv"); - } - if self - .path + let table_path_str = table_path.as_ref(); + + let extension = Path::new(table_path_str) .extension() - .map(|s| s == "parquet") - .unwrap_or(false) - { - ctx.register_parquet("x", path, parquet_options) - .await - .map_err(|e| { - DataFusionError::Context( - format!("Registering 'table' as {path}"), - Box::new(e), - ) - }) - .expect("error registering parquet"); + .and_then(|s| s.to_str()) + .unwrap_or(""); + + match extension { + "csv" => { + ctx.register_csv(table_ref, table_path_str, csv_options) + .await + .map_err(|e| { + DataFusionError::Context( + format!("Registering 'table' as {table_path_str}"), + Box::new(e), + ) + }) + .expect("error registering csv"); + } + "parquet" => { + ctx.register_parquet(table_ref, table_path_str, parquet_options) + .await + .map_err(|e| { + DataFusionError::Context( + format!("Registering 'table' as {table_path_str}"), + Box::new(e), + ) + }) + .expect("error registering parquet"); + } + _ => { + return Err(DataFusionError::Plan(format!( + "Unsupported file extension: {extension}", + ))); + } } + Ok(()) } } -struct AllQueries { +pub struct AllQueries { queries: Vec, } impl AllQueries { - fn try_new(path: &Path) -> Result { + pub fn try_new(path: &Path) -> Result { let all_queries = 
std::fs::read_to_string(path) .map_err(|e| exec_datafusion_err!("Could not open {path:?}: {e}"))?; @@ -192,7 +213,7 @@ impl AllQueries { } /// Returns the text of query `query_id` - fn get_query(&self, query_id: usize) -> Result<&str> { + pub fn get_query(&self, query_id: usize) -> Result<&str> { self.queries .get(query_id - 1) .ok_or_else(|| { @@ -205,11 +226,11 @@ impl AllQueries { .map(|s| s.as_str()) } - fn min_query_id(&self) -> usize { + pub fn min_query_id(&self) -> usize { 1 } - fn max_query_id(&self) -> usize { + pub fn max_query_id(&self) -> usize { self.queries.len() } } diff --git a/benchmarks/src/hj.rs b/benchmarks/src/hj.rs new file mode 100644 index 0000000000000..301fe0d599cd6 --- /dev/null +++ b/benchmarks/src/hj.rs @@ -0,0 +1,441 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::util::{BenchmarkRun, CommonOpt, QueryResult}; +use clap::Args; +use datafusion::physical_plan::execute_stream; +use datafusion::{error::Result, prelude::SessionContext}; +use datafusion_common::instant::Instant; +use datafusion_common::{DataFusionError, exec_datafusion_err, exec_err}; +use std::path::PathBuf; + +use futures::StreamExt; + +// TODO: Add existence joins + +/// Run the Hash Join benchmark +/// +/// This micro-benchmark focuses on the performance characteristics of Hash Joins. +/// It uses simple equality predicates to ensure a hash join is selected. +/// Where we vary selectivity, we do so with additional cheap predicates that +/// do not change the join key (so the physical operator remains HashJoin). +#[derive(Debug, Args, Clone)] +#[command(verbatim_doc_comment)] +pub struct RunOpt { + /// Query number. If not specified, runs all queries + #[arg(short, long)] + query: Option, + + /// Common options (iterations, batch size, target_partitions, etc.) + #[command(flatten)] + common: CommonOpt, + + /// Path to TPC-H SF10 data + #[arg(short = 'p', long = "path")] + path: Option, + + /// If present, write results json here + #[arg(short = 'o', long = "output")] + output_path: Option, +} + +struct HashJoinQuery { + sql: &'static str, + density: f64, + prob_hit: f64, + build_size: &'static str, + probe_size: &'static str, +} + +/// Inline SQL queries for Hash Join benchmarks +const HASH_QUERIES: &[HashJoinQuery] = &[ + // Q1: Very Small Build Side (Dense) + // Build Side: nation (25 rows) | Probe Side: customer (1.5M rows) + HashJoinQuery { + sql: r###"SELECT n_nationkey FROM nation JOIN customer ON c_nationkey = n_nationkey"###, + density: 1.0, + prob_hit: 1.0, + build_size: "25", + probe_size: "1.5M", + }, + // Q2: Very Small Build Side (Sparse, range < 1024) + // Build Side: nation (25 rows, range 961) | Probe Side: customer (1.5M rows) + HashJoinQuery { + sql: r###"SELECT l.k + FROM ( + SELECT c_nationkey * 40 as k + FROM customer + ) l + 
JOIN ( + SELECT n_nationkey * 40 as k FROM nation + ) s ON l.k = s.k"###, + density: 0.026, + prob_hit: 1.0, + build_size: "25", + probe_size: "1.5M", + }, + // Q3: 100% Density, 100% Hit rate + HashJoinQuery { + sql: r###"SELECT s_suppkey FROM supplier JOIN lineitem ON s_suppkey = l_suppkey"###, + density: 1.0, + prob_hit: 1.0, + build_size: "100K", + probe_size: "60M", + }, + // Q4: 100% Density, 10% Hit rate + HashJoinQuery { + sql: r###"SELECT l.k + FROM ( + SELECT CASE WHEN l_suppkey % 10 = 0 THEN l_suppkey ELSE l_suppkey + 1000000 END as k + FROM lineitem + ) l + JOIN ( + SELECT s_suppkey as k FROM supplier + ) s ON l.k = s.k"###, + density: 1.0, + prob_hit: 0.1, + build_size: "100K", + probe_size: "60M", + }, + // Q5: 75% Density, 100% Hit rate + HashJoinQuery { + sql: r###"SELECT l.k + FROM ( + SELECT l_suppkey * 4 / 3 as k + FROM lineitem + ) l + JOIN ( + SELECT s_suppkey * 4 / 3 as k FROM supplier + ) s ON l.k = s.k"###, + density: 0.75, + prob_hit: 1.0, + build_size: "100K", + probe_size: "60M", + }, + // Q6: 75% Density, 10% Hit rate + HashJoinQuery { + sql: r###"SELECT l.k + FROM ( + SELECT CASE + WHEN l_suppkey % 10 = 0 THEN l_suppkey * 4 / 3 + WHEN l_suppkey % 10 < 9 THEN (l_suppkey * 4 / 3 / 4) * 4 + 3 + ELSE l_suppkey * 4 / 3 + 1000000 + END as k + FROM lineitem + ) l + JOIN ( + SELECT s_suppkey * 4 / 3 as k FROM supplier + ) s ON l.k = s.k"###, + density: 0.75, + prob_hit: 0.1, + build_size: "100K", + probe_size: "60M", + }, + // Q7: 50% Density, 100% Hit rate + HashJoinQuery { + sql: r###"SELECT l.k + FROM ( + SELECT l_suppkey * 2 as k + FROM lineitem + ) l + JOIN ( + SELECT s_suppkey * 2 as k FROM supplier + ) s ON l.k = s.k"###, + density: 0.5, + prob_hit: 1.0, + build_size: "100K", + probe_size: "60M", + }, + // Q8: 50% Density, 10% Hit rate + HashJoinQuery { + sql: r###"SELECT l.k + FROM ( + SELECT CASE + WHEN l_suppkey % 10 = 0 THEN l_suppkey * 2 + WHEN l_suppkey % 10 < 9 THEN l_suppkey * 2 + 1 + ELSE l_suppkey * 2 + 1000000 + END as k + 
FROM lineitem + ) l + JOIN ( + SELECT s_suppkey * 2 as k FROM supplier + ) s ON l.k = s.k"###, + density: 0.5, + prob_hit: 0.1, + build_size: "100K", + probe_size: "60M", + }, + // Q9: 20% Density, 100% Hit rate + HashJoinQuery { + sql: r###"SELECT l.k + FROM ( + SELECT l_suppkey * 5 as k + FROM lineitem + ) l + JOIN ( + SELECT s_suppkey * 5 as k FROM supplier + ) s ON l.k = s.k"###, + density: 0.2, + prob_hit: 1.0, + build_size: "100K", + probe_size: "60M", + }, + // Q10: 20% Density, 10% Hit rate + HashJoinQuery { + sql: r###"SELECT l.k + FROM ( + SELECT CASE + WHEN l_suppkey % 10 = 0 THEN l_suppkey * 5 + WHEN l_suppkey % 10 < 9 THEN l_suppkey * 5 + 1 + ELSE l_suppkey * 5 + 1000000 + END as k + FROM lineitem + ) l + JOIN ( + SELECT s_suppkey * 5 as k FROM supplier + ) s ON l.k = s.k"###, + density: 0.2, + prob_hit: 0.1, + build_size: "100K", + probe_size: "60M", + }, + // Q11: 10% Density, 100% Hit rate + HashJoinQuery { + sql: r###"SELECT l.k + FROM ( + SELECT l_suppkey * 10 as k + FROM lineitem + ) l + JOIN ( + SELECT s_suppkey * 10 as k FROM supplier + ) s ON l.k = s.k"###, + density: 0.1, + prob_hit: 1.0, + build_size: "100K", + probe_size: "60M", + }, + // Q12: 10% Density, 10% Hit rate + HashJoinQuery { + sql: r###"SELECT l.k + FROM ( + SELECT CASE + WHEN l_suppkey % 10 = 0 THEN l_suppkey * 10 + WHEN l_suppkey % 10 < 9 THEN l_suppkey * 10 + 1 + ELSE l_suppkey * 10 + 1000000 + END as k + FROM lineitem + ) l + JOIN ( + SELECT s_suppkey * 10 as k FROM supplier + ) s ON l.k = s.k"###, + density: 0.1, + prob_hit: 0.1, + build_size: "100K", + probe_size: "60M", + }, + // Q13: 1% Density, 100% Hit rate + HashJoinQuery { + sql: r###"SELECT l.k + FROM ( + SELECT l_suppkey * 100 as k + FROM lineitem + ) l + JOIN ( + SELECT s_suppkey * 100 as k FROM supplier + ) s ON l.k = s.k"###, + density: 0.01, + prob_hit: 1.0, + build_size: "100K", + probe_size: "60M", + }, + // Q14: 1% Density, 10% Hit rate + HashJoinQuery { + sql: r###"SELECT l.k + FROM ( + SELECT CASE + WHEN 
l_suppkey % 10 = 0 THEN l_suppkey * 100 + WHEN l_suppkey % 10 < 9 THEN l_suppkey * 100 + 1 + ELSE l_suppkey * 100 + 11000000 -- oob + END as k + FROM lineitem + ) l + JOIN ( + SELECT s_suppkey * 100 as k FROM supplier + ) s ON l.k = s.k"###, + density: 0.01, + prob_hit: 0.1, + build_size: "100K", + probe_size: "60M", + }, + // Q15: 20% Density, 10% Hit rate, 20% Duplicates in Build Side + HashJoinQuery { + sql: r###"SELECT l.k + FROM ( + SELECT CASE + WHEN l_suppkey % 10 = 0 THEN ((l_suppkey % 80000) + 1) * 25 / 4 + ELSE ((l_suppkey % 80000) + 1) * 25 / 4 + 1 + END as k + FROM lineitem + ) l + JOIN ( + SELECT CASE + WHEN s_suppkey <= 80000 THEN (s_suppkey * 25) / 4 + ELSE ((s_suppkey - 80000) * 25) / 4 + END as k + FROM supplier + ) s ON l.k = s.k"###, + density: 0.2, + prob_hit: 0.1, + build_size: "100K_(20%_dups)", + probe_size: "60M", + }, +]; + +impl RunOpt { + pub async fn run(self) -> Result<()> { + println!("Running Hash Join benchmarks with the following options: {self:#?}\n"); + + let query_range = match self.query { + Some(query_id) => { + if query_id >= 1 && query_id <= HASH_QUERIES.len() { + query_id..=query_id + } else { + return exec_err!( + "Query {query_id} not found. 
Available queries: 1 to {}", + HASH_QUERIES.len() + ); + } + } + None => 1..=HASH_QUERIES.len(), + }; + + let config = self.common.config()?; + let rt = self.common.build_runtime()?; + let ctx = SessionContext::new_with_config_rt(config, rt); + + if let Some(path) = &self.path { + for table in &["lineitem", "supplier", "nation", "customer"] { + let table_path = path.join(table); + if !table_path.exists() { + return exec_err!( + "TPC-H table {} not found at {:?}", + table, + table_path + ); + } + ctx.register_parquet( + *table, + table_path.to_str().unwrap(), + Default::default(), + ) + .await?; + } + } + + let mut benchmark_run = BenchmarkRun::new(); + + for query_id in query_range { + let query_index = query_id - 1; + let query = &HASH_QUERIES[query_index]; + + let case_name = format!( + "Query {}_density={}_prob_hit={}_{}*{}", + query_id, + query.density, + query.prob_hit, + query.build_size, + query.probe_size + ); + benchmark_run.start_new_case(&case_name); + + let query_run = self + .benchmark_query(query.sql, &query_id.to_string(), &ctx) + .await; + match query_run { + Ok(query_results) => { + for iter in query_results { + benchmark_run.write_iter(iter.elapsed, iter.row_count); + } + } + Err(e) => { + return Err(DataFusionError::Context( + format!("Hash Join benchmark Q{query_id} failed with error:"), + Box::new(e), + )); + } + } + } + + benchmark_run.maybe_write_json(self.output_path.as_ref())?; + Ok(()) + } + + /// Validates that the physical plan uses a HashJoin, then executes. + async fn benchmark_query( + &self, + sql: &str, + query_name: &str, + ctx: &SessionContext, + ) -> Result> { + let mut query_results = vec![]; + + // Build/validate plan + let df = ctx.sql(sql).await?; + let physical_plan = df.create_physical_plan().await?; + let plan_string = format!("{physical_plan:#?}"); + + if !plan_string.contains("HashJoinExec") { + return Err(exec_datafusion_err!( + "Query {query_name} does not use Hash Join. 
Physical plan: {plan_string}" + )); + } + + // Execute without buffering + for i in 0..self.common.iterations { + let start = Instant::now(); + let row_count = Self::execute_sql_without_result_buffering(sql, ctx).await?; + let elapsed = start.elapsed(); + + println!( + "Query {query_name} iteration {i} returned {row_count} rows in {elapsed:?}" + ); + + query_results.push(QueryResult { elapsed, row_count }); + } + + Ok(query_results) + } + + /// Executes the SQL query and drops each batch to avoid result buffering. + async fn execute_sql_without_result_buffering( + sql: &str, + ctx: &SessionContext, + ) -> Result { + let mut row_count = 0; + + let df = ctx.sql(sql).await?; + let physical_plan = df.create_physical_plan().await?; + let mut stream = execute_stream(physical_plan, ctx.task_ctx())?; + + while let Some(batch) = stream.next().await { + row_count += batch?.num_rows(); + // Drop batches immediately to minimize memory pressure + } + + Ok(row_count) + } +} diff --git a/benchmarks/src/imdb/convert.rs b/benchmarks/src/imdb/convert.rs index e7949aa715c23..aaed186da4905 100644 --- a/benchmarks/src/imdb/convert.rs +++ b/benchmarks/src/imdb/convert.rs @@ -20,31 +20,31 @@ use datafusion::logical_expr::select_expr::SelectExpr; use datafusion_common::instant::Instant; use std::path::PathBuf; +use clap::Args; use datafusion::error::Result; use datafusion::prelude::*; -use structopt::StructOpt; use datafusion::common::not_impl_err; -use super::get_imdb_table_schema; use super::IMDB_TABLES; +use super::get_imdb_table_schema; -#[derive(Debug, StructOpt)] +#[derive(Debug, Args)] pub struct ConvertOpt { /// Path to csv files - #[structopt(parse(from_os_str), required = true, short = "i", long = "input")] + #[arg(required = true, short = 'i', long = "input")] input_path: PathBuf, /// Output path - #[structopt(parse(from_os_str), required = true, short = "o", long = "output")] + #[arg(required = true, short = 'o', long = "output")] output_path: PathBuf, /// Output file format: 
`csv` or `parquet` - #[structopt(short = "f", long = "format")] + #[arg(short = 'f', long = "format")] file_format: String, /// Batch size when reading CSV or Parquet files - #[structopt(short = "s", long = "batch-size", default_value = "8192")] + #[arg(short = 's', long = "batch-size", default_value = "8192")] batch_size: usize, } diff --git a/benchmarks/src/imdb/mod.rs b/benchmarks/src/imdb/mod.rs index 6a45242e6ff4b..87462bc3e81ba 100644 --- a/benchmarks/src/imdb/mod.rs +++ b/benchmarks/src/imdb/mod.rs @@ -54,6 +54,9 @@ pub const IMDB_TABLES: &[&str] = &[ "person_info", ]; +pub const IMDB_QUERY_START_ID: usize = 1; +pub const IMDB_QUERY_END_ID: usize = 113; + /// Get the schema for the IMDB dataset tables /// see benchmarks/data/imdb/schematext.sql pub fn get_imdb_table_schema(table: &str) -> Schema { diff --git a/benchmarks/src/imdb/run.rs b/benchmarks/src/imdb/run.rs index 61dcc07ebd639..ca9710a920517 100644 --- a/benchmarks/src/imdb/run.rs +++ b/benchmarks/src/imdb/run.rs @@ -18,14 +18,17 @@ use std::path::PathBuf; use std::sync::Arc; -use super::{get_imdb_table_schema, get_query_sql, IMDB_TABLES}; -use crate::util::{BenchmarkRun, CommonOpt}; +use super::{ + IMDB_QUERY_END_ID, IMDB_QUERY_START_ID, IMDB_TABLES, get_imdb_table_schema, + get_query_sql, +}; +use crate::util::{BenchmarkRun, CommonOpt, QueryResult, print_memory_stats}; use arrow::record_batch::RecordBatch; use arrow::util::pretty::{self, pretty_format_batches}; +use datafusion::datasource::file_format::FileFormat; use datafusion::datasource::file_format::csv::CsvFormat; use datafusion::datasource::file_format::parquet::ParquetFormat; -use datafusion::datasource::file_format::FileFormat; use datafusion::datasource::listing::{ ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl, }; @@ -38,8 +41,8 @@ use datafusion_common::instant::Instant; use datafusion_common::utils::get_available_parallelism; use datafusion_common::{DEFAULT_CSV_EXTENSION, DEFAULT_PARQUET_EXTENSION}; +use clap::Args; 
use log::info; -use structopt::StructOpt; // hack to avoid `default_value is meaningless for bool` errors type BoolDefaultTrue = bool; @@ -51,48 +54,49 @@ type BoolDefaultTrue = bool; /// [2] and [3]. /// /// [1]: https://www.vldb.org/pvldb/vol9/p204-leis.pdf -/// [2]: http://homepages.cwi.nl/~boncz/job/imdb.tgz +/// [2]: https://event.cwi.nl/da/job/imdb.tgz /// [3]: https://db.in.tum.de/~leis/qo/job.tgz -#[derive(Debug, StructOpt, Clone)] -#[structopt(verbatim_doc_comment)] +#[derive(Debug, Args, Clone)] +#[command(verbatim_doc_comment)] pub struct RunOpt { /// Query number. If not specified, runs all queries - #[structopt(short, long)] - query: Option, + #[arg(short, long)] + pub query: Option, /// Common options - #[structopt(flatten)] + #[command(flatten)] common: CommonOpt, /// Path to data files - #[structopt(parse(from_os_str), required = true, short = "p", long = "path")] + #[arg(required = true, short = 'p', long = "path")] path: PathBuf, /// File format: `csv` or `parquet` - #[structopt(short = "f", long = "format", default_value = "csv")] + #[arg(short = 'f', long = "format", default_value = "csv")] file_format: String, /// Load the data into a MemTable before executing the query - #[structopt(short = "m", long = "mem-table")] + #[arg(short = 'm', long = "mem-table")] mem_table: bool, /// Path to machine readable output file - #[structopt(parse(from_os_str), short = "o", long = "output")] + #[arg(short = 'o', long = "output")] output_path: Option, /// Whether to disable collection of statistics (and cost based optimizations) or not. - #[structopt(short = "S", long = "disable-statistics")] + #[arg(short = 'S', long = "disable-statistics")] disable_statistics: bool, /// If true then hash join used, if false then sort merge join /// True by default. 
- #[structopt(short = "j", long = "prefer_hash_join", default_value = "true")] + #[arg(short = 'j', long = "prefer_hash_join", default_value = "true")] prefer_hash_join: BoolDefaultTrue, -} -const IMDB_QUERY_START_ID: usize = 1; -const IMDB_QUERY_END_ID: usize = 113; + /// How many bytes to buffer on the probe side of hash joins. + #[arg(long, default_value = "0")] + hash_join_buffering_capacity: usize, +} fn map_query_id_to_str(query_id: usize) -> &'static str { match query_id { @@ -306,8 +310,10 @@ impl RunOpt { .config()? .with_collect_statistics(!self.disable_statistics); config.options_mut().optimizer.prefer_hash_join = self.prefer_hash_join; - let rt_builder = self.common.runtime_env_builder()?; - let ctx = SessionContext::new_with_config_rt(config, rt_builder.build_arc()?); + config.options_mut().execution.hash_join_buffering_capacity = + self.hash_join_buffering_capacity; + let rt = self.common.build_runtime()?; + let ctx = SessionContext::new_with_config_rt(config, rt); // register tables self.register_tables(&ctx).await?; @@ -341,6 +347,9 @@ impl RunOpt { let avg = millis.iter().sum::() / millis.len() as f64; println!("Query {query_id} avg time: {avg:.2} ms"); + // Print memory usage stats using mimalloc (only when compiled with --features mimalloc_extended) + print_memory_stats(); + Ok(query_results) } @@ -475,11 +484,6 @@ impl RunOpt { } } -struct QueryResult { - elapsed: std::time::Duration, - row_count: usize, -} - #[cfg(test)] // Only run with "ci" mode when we have the data #[cfg(feature = "ci")] @@ -519,6 +523,7 @@ mod tests { memory_limit: None, sort_spill_reservation_bytes: None, debug: false, + simulate_latency: false, }; let opt = RunOpt { query: Some(query), @@ -529,6 +534,7 @@ mod tests { output_path: None, disable_statistics: false, prefer_hash_join: true, + hash_join_buffering_capacity: 0, }; opt.register_tables(&ctx).await?; let queries = get_query_sql(map_query_id_to_str(query))?; @@ -536,7 +542,7 @@ mod tests { let plan = 
ctx.sql(&query).await?; let plan = plan.into_optimized_plan()?; let bytes = logical_plan_to_bytes(&plan)?; - let plan2 = logical_plan_from_bytes(&bytes, &ctx)?; + let plan2 = logical_plan_from_bytes(&bytes, &ctx.task_ctx())?; let plan_formatted = format!("{}", plan.display_indent()); let plan2_formatted = format!("{}", plan2.display_indent()); assert_eq!(plan_formatted, plan2_formatted); @@ -555,6 +561,7 @@ mod tests { memory_limit: None, sort_spill_reservation_bytes: None, debug: false, + simulate_latency: false, }; let opt = RunOpt { query: Some(query), @@ -565,6 +572,7 @@ mod tests { output_path: None, disable_statistics: false, prefer_hash_join: true, + hash_join_buffering_capacity: 0, }; opt.register_tables(&ctx).await?; let queries = get_query_sql(map_query_id_to_str(query))?; @@ -572,7 +580,7 @@ mod tests { let plan = ctx.sql(&query).await?; let plan = plan.create_physical_plan().await?; let bytes = physical_plan_to_bytes(plan.clone())?; - let plan2 = physical_plan_from_bytes(&bytes, &ctx)?; + let plan2 = physical_plan_from_bytes(&bytes, &ctx.task_ctx())?; let plan_formatted = format!("{}", displayable(plan.as_ref()).indent(false)); let plan2_formatted = format!("{}", displayable(plan2.as_ref()).indent(false)); diff --git a/benchmarks/src/lib.rs b/benchmarks/src/lib.rs index a402fc1b8ce04..a3bc221840ada 100644 --- a/benchmarks/src/lib.rs +++ b/benchmarks/src/lib.rs @@ -19,9 +19,11 @@ pub mod cancellation; pub mod clickbench; pub mod h2o; +pub mod hj; pub mod imdb; -pub mod parquet_filter; -pub mod sort; +pub mod nlj; +pub mod smj; pub mod sort_tpch; +pub mod tpcds; pub mod tpch; pub mod util; diff --git a/benchmarks/src/nlj.rs b/benchmarks/src/nlj.rs new file mode 100644 index 0000000000000..361cc35ec200c --- /dev/null +++ b/benchmarks/src/nlj.rs @@ -0,0 +1,303 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::util::{BenchmarkRun, CommonOpt, QueryResult}; +use clap::Args; +use datafusion::physical_plan::execute_stream; +use datafusion::{error::Result, prelude::SessionContext}; +use datafusion_common::instant::Instant; +use datafusion_common::{DataFusionError, exec_datafusion_err, exec_err}; + +use futures::StreamExt; + +/// Run the Nested Loop Join (NLJ) benchmark +/// +/// This micro-benchmark focuses on the performance characteristics of NLJs. +/// +/// It always tries to use fast scanners (without decoding overhead) and +/// efficient predicate expressions to ensure it can reflect the performance +/// of the NLJ operator itself. +/// +/// In this micro-benchmark, the following workload characteristics will be +/// varied: +/// - Join type: Inner/Left/Right/Full (all for the NestedLoopJoin physical +/// operator) +/// TODO: Include special join types (Semi/Anti/Mark joins) +/// - Input size: Different combinations of left (build) side and right (probe) +/// side sizes +/// - Selectivity of join filters +#[derive(Debug, Args, Clone)] +#[command(verbatim_doc_comment)] +pub struct RunOpt { + /// Query number (between 1 and 17).
If not specified, runs all queries + #[arg(short, long)] + query: Option, + + /// Common options + #[command(flatten)] + common: CommonOpt, + + /// If present, write results json here + #[arg(short = 'o', long = "output")] + output_path: Option, +} + +/// Inline SQL queries for NLJ benchmarks +/// +/// Each query's comment includes: +/// - Left (build) side row count × Right (probe) side row count +/// - Join predicate selectivity (1% means the output size is 1% * input size) +const NLJ_QUERIES: &[&str] = &[ + // Q1: INNER 10K x 10K | LOW 0.1% + r#" + SELECT * + FROM range(10000) AS t1 + JOIN range(10000) AS t2 + ON (t1.value + t2.value) % 1000 = 0; + "#, + // Q2: INNER 10K x 10K | Medium 20% + r#" + SELECT * + FROM range(10000) AS t1 + JOIN range(10000) AS t2 + ON (t1.value + t2.value) % 5 = 0; + "#, + // Q3: INNER 10K x 10K | High 90% + r#" + SELECT * + FROM range(10000) AS t1 + JOIN range(10000) AS t2 + ON (t1.value + t2.value) % 10 <> 0; + "#, + // Q4: INNER 30K x 30K | Medium 20% + r#" + SELECT * + FROM range(30000) AS t1 + JOIN range(30000) AS t2 + ON (t1.value + t2.value) % 5 = 0; + "#, + // Q5: INNER 10K x 200K | LOW 0.1% (small to large) + r#" + SELECT * + FROM range(10000) AS t1 + JOIN range(200000) AS t2 + ON (t1.value + t2.value) % 1000 = 0; + "#, + // Q6: INNER 200K x 10K | LOW 0.1% (large to small) + r#" + SELECT * + FROM range(200000) AS t1 + JOIN range(10000) AS t2 + ON (t1.value + t2.value) % 1000 = 0; + "#, + // Q7: RIGHT OUTER 10K x 200K | LOW 0.1% + r#" + SELECT * + FROM range(10000) AS t1 + RIGHT JOIN range(200000) AS t2 + ON (t1.value + t2.value) % 1000 = 0; + "#, + // Q8: LEFT OUTER 200K x 10K | LOW 0.1% + r#" + SELECT * + FROM range(200000) AS t1 + LEFT JOIN range(10000) AS t2 + ON (t1.value + t2.value) % 1000 = 0; + "#, + // Q9: FULL OUTER 30K x 30K | LOW 0.1% + r#" + SELECT * + FROM range(30000) AS t1 + FULL JOIN range(30000) AS t2 + ON (t1.value + t2.value) % 1000 = 0; + "#, + // Q10: FULL OUTER 30K x 30K | High 90% + r#" + SELECT * + 
FROM range(30000) AS t1 + FULL JOIN range(30000) AS t2 + ON (t1.value + t2.value) % 10 <> 0; + "#, + // Q11: INNER 30K x 30K | MEDIUM 50% | cheap predicate + r#" + SELECT * + FROM range(30000) AS t1 + INNER JOIN range(30000) AS t2 + ON (t1.value > t2.value); + "#, + // Q12: FULL OUTER 30K x 30K | MEDIUM 50% | cheap predicate + r#" + SELECT * + FROM range(30000) AS t1 + FULL JOIN range(30000) AS t2 + ON (t1.value > t2.value); + "#, + // Q13: LEFT SEMI 30K x 30K | HIGH 99.9% + r#" + SELECT t1.* + FROM range(30000) AS t1 + LEFT SEMI JOIN range(30000) AS t2 + ON t1.value < t2.value; + "#, + // Q14: LEFT ANTI 30K x 30K | LOW 0.003% + r#" + SELECT t1.* + FROM range(30000) AS t1 + LEFT ANTI JOIN range(30000) AS t2 + ON t1.value < t2.value; + "#, + // Q15: RIGHT SEMI 30K x 30K | HIGH 99.9% + r#" + SELECT t1.* + FROM range(30000) AS t2 + RIGHT SEMI JOIN range(30000) AS t1 + ON t2.value < t1.value; + "#, + // Q16: RIGHT ANTI 30K x 30K | LOW 0.003% + r#" + SELECT t1.* + FROM range(30000) AS t2 + RIGHT ANTI JOIN range(30000) AS t1 + ON t2.value < t1.value; + "#, + // Q17: LEFT MARK | HIGH 99.9% + r#" + SELECT * + FROM range(30000) AS t2(k2) + WHERE k2 > 0 + OR EXISTS ( + SELECT 1 + FROM range(30000) AS t1(k1) + WHERE t2.k2 > t1.k1 + ); + "#, +]; + +impl RunOpt { + pub async fn run(self) -> Result<()> { + println!("Running NLJ benchmarks with the following options: {self:#?}\n"); + + // Define query range + let query_range = match self.query { + Some(query_id) => { + if query_id >= 1 && query_id <= NLJ_QUERIES.len() { + query_id..=query_id + } else { + return exec_err!( + "Query {query_id} not found. 
Available queries: 1 to {}", + NLJ_QUERIES.len() + ); + } + } + None => 1..=NLJ_QUERIES.len(), + }; + + let config = self.common.config()?; + let rt = self.common.build_runtime()?; + let ctx = SessionContext::new_with_config_rt(config, rt); + + let mut benchmark_run = BenchmarkRun::new(); + for query_id in query_range { + let query_index = query_id - 1; // Convert 1-based to 0-based index + + let sql = NLJ_QUERIES[query_index]; + benchmark_run.start_new_case(&format!("Query {query_id}")); + let query_run = self.benchmark_query(sql, &query_id.to_string(), &ctx).await; + match query_run { + Ok(query_results) => { + for iter in query_results { + benchmark_run.write_iter(iter.elapsed, iter.row_count); + } + } + Err(e) => { + return Err(DataFusionError::Context( + format!("NLJ benchmark Q{query_id} failed with error:"), + Box::new(e), + )); + } + } + } + + benchmark_run.maybe_write_json(self.output_path.as_ref())?; + Ok(()) + } + + /// Validates that the query's physical plan uses a NestedLoopJoin (NLJ), + /// then executes the query and collects execution times. + /// + /// TODO: ensure the optimizer won't change the join order (it's not at + /// v48.0.0). + async fn benchmark_query( + &self, + sql: &str, + query_name: &str, + ctx: &SessionContext, + ) -> Result<Vec<QueryResult>> { + let mut query_results = vec![]; + + // Validate that the query plan includes a Nested Loop Join + let df = ctx.sql(sql).await?; + let physical_plan = df.create_physical_plan().await?; + let plan_string = format!("{physical_plan:#?}"); + + if !plan_string.contains("NestedLoopJoinExec") { + return Err(exec_datafusion_err!( + "Query {query_name} does not use Nested Loop Join.
Physical plan: {plan_string}" + )); + } + + for i in 0..self.common.iterations { + let start = Instant::now(); + + let row_count = Self::execute_sql_without_result_buffering(sql, ctx).await?; + + let elapsed = start.elapsed(); + + println!( + "Query {query_name} iteration {i} returned {row_count} rows in {elapsed:?}" + ); + + query_results.push(QueryResult { elapsed, row_count }); + } + + Ok(query_results) + } + + /// Executes the SQL query and drops each result batch after evaluation, to + /// minimize memory usage by not buffering results. + /// + /// Returns the total result row count + async fn execute_sql_without_result_buffering( + sql: &str, + ctx: &SessionContext, + ) -> Result<usize> { + let mut row_count = 0; + + let df = ctx.sql(sql).await?; + let physical_plan = df.create_physical_plan().await?; + let mut stream = execute_stream(physical_plan, ctx.task_ctx())?; + + while let Some(batch) = stream.next().await { + row_count += batch?.num_rows(); + + // Evaluate the result and do nothing, the result will be dropped + // to reduce memory pressure + } + + Ok(row_count) + } +} diff --git a/benchmarks/src/parquet_filter.rs b/benchmarks/src/parquet_filter.rs deleted file mode 100644 index 34103af0ffd21..0000000000000 --- a/benchmarks/src/parquet_filter.rs +++ /dev/null @@ -1,194 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License.
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::path::PathBuf; - -use crate::util::{AccessLogOpt, BenchmarkRun, CommonOpt}; - -use arrow::util::pretty; -use datafusion::common::Result; -use datafusion::logical_expr::utils::disjunction; -use datafusion::logical_expr::{lit, or, Expr}; -use datafusion::physical_plan::collect; -use datafusion::prelude::{col, SessionContext}; -use datafusion::test_util::parquet::{ParquetScanOptions, TestParquetFile}; -use datafusion_common::instant::Instant; - -use structopt::StructOpt; - -/// Test performance of parquet filter pushdown -/// -/// The queries are executed on a synthetic dataset generated during -/// the benchmark execution and designed to simulate web server access -/// logs. -/// -/// Example -/// -/// dfbench parquet-filter --path ./data --scale-factor 1.0 -/// -/// generates the synthetic dataset at `./data/logs.parquet`. The size -/// of the dataset can be controlled through the `size_factor` -/// (with the default value of `1.0` generating a ~1GB parquet file). -/// -/// For each filter we will run the query using different -/// `ParquetScanOption` settings. 
-/// -/// Example output: -/// -/// Running benchmarks with the following options: Opt { debug: false, iterations: 3, partitions: 2, path: "./data", batch_size: 8192, scale_factor: 1.0 } -/// Generated test dataset with 10699521 rows -/// Executing with filter 'request_method = Utf8("GET")' -/// Using scan options ParquetScanOptions { pushdown_filters: false, reorder_predicates: false, enable_page_index: false } -/// Iteration 0 returned 10699521 rows in 1303 ms -/// Iteration 1 returned 10699521 rows in 1288 ms -/// Iteration 2 returned 10699521 rows in 1266 ms -/// Using scan options ParquetScanOptions { pushdown_filters: true, reorder_predicates: true, enable_page_index: true } -/// Iteration 0 returned 1781686 rows in 1970 ms -/// Iteration 1 returned 1781686 rows in 2002 ms -/// Iteration 2 returned 1781686 rows in 1988 ms -/// Using scan options ParquetScanOptions { pushdown_filters: true, reorder_predicates: false, enable_page_index: true } -/// Iteration 0 returned 1781686 rows in 1940 ms -/// Iteration 1 returned 1781686 rows in 1986 ms -/// Iteration 2 returned 1781686 rows in 1947 ms -/// ... 
-#[derive(Debug, StructOpt, Clone)] -#[structopt(verbatim_doc_comment)] -pub struct RunOpt { - /// Common options - #[structopt(flatten)] - common: CommonOpt, - - /// Create data files - #[structopt(flatten)] - access_log: AccessLogOpt, - - /// Path to machine readable output file - #[structopt(parse(from_os_str), short = "o", long = "output")] - output_path: Option, -} - -impl RunOpt { - pub async fn run(self) -> Result<()> { - let test_file = self.access_log.build()?; - - let mut rundata = BenchmarkRun::new(); - let scan_options_matrix = vec![ - ParquetScanOptions { - pushdown_filters: false, - reorder_filters: false, - enable_page_index: false, - }, - ParquetScanOptions { - pushdown_filters: true, - reorder_filters: true, - enable_page_index: true, - }, - ParquetScanOptions { - pushdown_filters: true, - reorder_filters: true, - enable_page_index: false, - }, - ]; - - let filter_matrix = vec![ - ("Selective-ish filter", col("request_method").eq(lit("GET"))), - ( - "Non-selective filter", - col("request_method").not_eq(lit("GET")), - ), - ( - "Basic conjunction", - col("request_method") - .eq(lit("POST")) - .and(col("response_status").eq(lit(503_u16))), - ), - ( - "Nested filters", - col("request_method").eq(lit("POST")).and(or( - col("response_status").eq(lit(503_u16)), - col("response_status").eq(lit(403_u16)), - )), - ), - ( - "Many filters", - disjunction([ - col("request_method").not_eq(lit("GET")), - col("response_status").eq(lit(400_u16)), - col("service").eq(lit("backend")), - ]) - .unwrap(), - ), - ("Filter everything", col("response_status").eq(lit(429_u16))), - ("Filter nothing", col("response_status").gt(lit(0_u16))), - ]; - - for (name, filter_expr) in &filter_matrix { - println!("Executing '{name}' (filter: {filter_expr})"); - for scan_options in &scan_options_matrix { - println!("Using scan options {scan_options:?}"); - rundata.start_new_case(&format!( - "{name}: {}", - parquet_scan_disp(scan_options) - )); - for i in 0..self.common.iterations { - 
let config = self.common.update_config(scan_options.config()); - let ctx = SessionContext::new_with_config(config); - - let (rows, elapsed) = exec_scan( - &ctx, - &test_file, - filter_expr.clone(), - self.common.debug, - ) - .await?; - let ms = elapsed.as_secs_f64() * 1000.0; - println!("Iteration {i} returned {rows} rows in {ms} ms"); - rundata.write_iter(elapsed, rows); - } - } - println!("\n"); - } - rundata.maybe_write_json(self.output_path.as_ref())?; - Ok(()) - } -} - -fn parquet_scan_disp(opts: &ParquetScanOptions) -> String { - format!( - "pushdown_filters={}, reorder_filters={}, page_index={}", - opts.pushdown_filters, opts.reorder_filters, opts.enable_page_index - ) -} - -async fn exec_scan( - ctx: &SessionContext, - test_file: &TestParquetFile, - filter: Expr, - debug: bool, -) -> Result<(usize, std::time::Duration)> { - let start = Instant::now(); - let exec = test_file.create_scan(ctx, Some(filter)).await?; - - let task_ctx = ctx.task_ctx(); - let result = collect(exec, task_ctx).await?; - let elapsed = start.elapsed(); - if debug { - pretty::print_batches(&result)?; - } - let rows = result.iter().map(|b| b.num_rows()).sum(); - Ok((rows, elapsed)) -} diff --git a/benchmarks/src/smj.rs b/benchmarks/src/smj.rs new file mode 100644 index 0000000000000..5056fd5096156 --- /dev/null +++ b/benchmarks/src/smj.rs @@ -0,0 +1,524 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::util::{BenchmarkRun, CommonOpt, QueryResult}; +use clap::Args; +use datafusion::physical_plan::execute_stream; +use datafusion::{error::Result, prelude::SessionContext}; +use datafusion_common::instant::Instant; +use datafusion_common::{DataFusionError, exec_datafusion_err, exec_err}; + +use futures::StreamExt; + +/// Run the Sort Merge Join (SMJ) benchmark +/// +/// This micro-benchmark focuses on the performance characteristics of SMJs. +/// +/// It uses equality join predicates (to ensure SMJ is selected) and varies: +/// - Join type: Inner/Left/Right/Full/LeftSemi/LeftAnti/RightSemi/RightAnti +/// - Key cardinality: 1:1, 1:N, N:M relationships +/// - Filter selectivity: Low (1%), Medium (10%), High (50%) +/// - Input sizes: Small to large, balanced and skewed +/// +/// All inputs are pre-sorted in CTEs before the join to isolate join +/// performance from sort overhead. +#[derive(Debug, Args, Clone)] +#[command(verbatim_doc_comment)] +pub struct RunOpt { + /// Query number (between 1 and 20). 
If not specified, runs all queries + #[arg(short, long)] + query: Option, + + /// Common options + #[command(flatten)] + common: CommonOpt, + + /// If present, write results json here + #[arg(short = 'o', long = "output")] + output_path: Option, +} + +/// Inline SQL queries for SMJ benchmarks +/// +/// Each query's comment includes: +/// - Join type +/// - Left row count × Right row count +/// - Key cardinality (rows per key) +/// - Filter selectivity (if applicable) +const SMJ_QUERIES: &[&str] = &[ + // Q1: INNER 100K x 100K | 1:1 + r#" + WITH t1_sorted AS ( + SELECT value as key FROM range(100000) ORDER BY value + ), + t2_sorted AS ( + SELECT value as key FROM range(100000) ORDER BY value + ) + SELECT t1_sorted.key as k1, t2_sorted.key as k2 + FROM t1_sorted JOIN t2_sorted ON t1_sorted.key = t2_sorted.key + "#, + // Q2: INNER 100K x 1M | 1:10 + r#" + WITH t1_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(100000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ) + SELECT t1_sorted.key, t1_sorted.data as d1, t2_sorted.data as d2 + FROM t1_sorted JOIN t2_sorted ON t1_sorted.key = t2_sorted.key + "#, + // Q3: INNER 1M x 1M | 1:100 + r#" + WITH t1_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ) + SELECT t1_sorted.key, t1_sorted.data as d1, t2_sorted.data as d2 + FROM t1_sorted JOIN t2_sorted ON t1_sorted.key = t2_sorted.key + "#, + // Q4: INNER 100K x 1M | 1:10 | 1% + r#" + WITH t1_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(100000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ) + SELECT t1_sorted.key, t1_sorted.data as d1, t2_sorted.data as d2 + FROM t1_sorted JOIN t2_sorted ON 
t1_sorted.key = t2_sorted.key + WHERE t2_sorted.data % 100 = 0 + "#, + // Q5: INNER 1M x 1M | 1:100 | 10% + r#" + WITH t1_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ) + SELECT t1_sorted.key, t1_sorted.data as d1, t2_sorted.data as d2 + FROM t1_sorted JOIN t2_sorted ON t1_sorted.key = t2_sorted.key + WHERE t1_sorted.data <> t2_sorted.data AND t2_sorted.data % 10 = 0 + "#, + // Q6: LEFT 100K x 1M | 1:10 + r#" + WITH t1_sorted AS ( + SELECT value % 10500 as key, value as data + FROM range(100000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ) + SELECT t1_sorted.key, t1_sorted.data as d1, t2_sorted.data as d2 + FROM t1_sorted LEFT JOIN t2_sorted ON t1_sorted.key = t2_sorted.key + "#, + // Q7: LEFT 100K x 1M | 1:10 | 50% + r#" + WITH t1_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(100000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ) + SELECT t1_sorted.key, t1_sorted.data as d1, t2_sorted.data as d2 + FROM t1_sorted LEFT JOIN t2_sorted ON t1_sorted.key = t2_sorted.key + WHERE t2_sorted.data IS NULL OR t2_sorted.data % 2 = 0 + "#, + // Q8: FULL 100K x 100K | 1:10 + r#" + WITH t1_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(100000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 12500 as key, value as data + FROM range(100000) + ORDER BY key, data + ) + SELECT t1_sorted.key as k1, t1_sorted.data as d1, + t2_sorted.key as k2, t2_sorted.data as d2 + FROM t1_sorted FULL JOIN t2_sorted ON t1_sorted.key = t2_sorted.key + "#, + // Q9: FULL 100K x 1M | 1:10 | 10% + r#" + WITH t1_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(100000) + ORDER BY key, data + 
), + t2_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ) + SELECT t1_sorted.key as k1, t1_sorted.data as d1, + t2_sorted.key as k2, t2_sorted.data as d2 + FROM t1_sorted FULL JOIN t2_sorted ON t1_sorted.key = t2_sorted.key + WHERE (t1_sorted.data IS NULL OR t2_sorted.data IS NULL + OR t1_sorted.data <> t2_sorted.data) + AND (t1_sorted.data IS NULL OR t1_sorted.data % 10 = 0) + "#, + // Q10: LEFT SEMI 100K x 1M | 1:10 + r#" + WITH t1_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(100000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key + FROM range(1000000) + ORDER BY key + ) + SELECT t1_sorted.key, t1_sorted.data + FROM t1_sorted + WHERE EXISTS ( + SELECT 1 FROM t2_sorted + WHERE t2_sorted.key = t1_sorted.key + ) + "#, + // Q11: LEFT SEMI 100K x 1M | 1:10 | 1% + r#" + WITH t1_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(100000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ) + SELECT t1_sorted.key, t1_sorted.data + FROM t1_sorted + WHERE EXISTS ( + SELECT 1 FROM t2_sorted + WHERE t2_sorted.key = t1_sorted.key + AND t2_sorted.data <> t1_sorted.data + AND t2_sorted.data % 100 = 0 + ) + "#, + // Q12: LEFT SEMI 100K x 1M | 1:10 | 50% + r#" + WITH t1_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(100000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ) + SELECT t1_sorted.key, t1_sorted.data + FROM t1_sorted + WHERE EXISTS ( + SELECT 1 FROM t2_sorted + WHERE t2_sorted.key = t1_sorted.key + AND t2_sorted.data <> t1_sorted.data + AND t2_sorted.data % 2 = 0 + ) + "#, + // Q13: LEFT SEMI 100K x 1M | 1:10 | 90% + r#" + WITH t1_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(100000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT 
value % 10000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ) + SELECT t1_sorted.key, t1_sorted.data + FROM t1_sorted + WHERE EXISTS ( + SELECT 1 FROM t2_sorted + WHERE t2_sorted.key = t1_sorted.key + AND t2_sorted.data % 10 <> 0 + ) + "#, + // Q14: LEFT ANTI 100K x 1M | 1:10 + r#" + WITH t1_sorted AS ( + SELECT value % 10500 as key, value as data + FROM range(100000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key + FROM range(1000000) + ORDER BY key + ) + SELECT t1_sorted.key, t1_sorted.data + FROM t1_sorted + WHERE NOT EXISTS ( + SELECT 1 FROM t2_sorted + WHERE t2_sorted.key = t1_sorted.key + ) + "#, + // Q15: LEFT ANTI 100K x 1M | 1:10 | partial match + r#" + WITH t1_sorted AS ( + SELECT value % 12000 as key, value as data + FROM range(100000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key + FROM range(1000000) + ORDER BY key + ) + SELECT t1_sorted.key, t1_sorted.data + FROM t1_sorted + WHERE NOT EXISTS ( + SELECT 1 FROM t2_sorted + WHERE t2_sorted.key = t1_sorted.key + ) + "#, + // Q16: LEFT ANTI 100K x 100K | 1:1 | stress + r#" + WITH t1_sorted AS ( + SELECT value % 11000 as key, value as data + FROM range(100000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key + FROM range(100000) + ORDER BY key + ) + SELECT t1_sorted.key, t1_sorted.data + FROM t1_sorted + WHERE NOT EXISTS ( + SELECT 1 FROM t2_sorted + WHERE t2_sorted.key = t1_sorted.key + ) + "#, + // Q17: INNER 100K x 5M | 1:50 | 5% + r#" + WITH t1_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(100000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(5000000) + ORDER BY key, data + ) + SELECT t1_sorted.key, t1_sorted.data as d1, t2_sorted.data as d2 + FROM t1_sorted JOIN t2_sorted ON t1_sorted.key = t2_sorted.key + WHERE t2_sorted.data <> t1_sorted.data AND t2_sorted.data % 20 = 0 + "#, + // Q18: LEFT SEMI 100K x 5M | 1:50 | 2% + r#" + 
WITH t1_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(100000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(5000000) + ORDER BY key, data + ) + SELECT t1_sorted.key, t1_sorted.data + FROM t1_sorted + WHERE EXISTS ( + SELECT 1 FROM t2_sorted + WHERE t2_sorted.key = t1_sorted.key + AND t2_sorted.data <> t1_sorted.data + AND t2_sorted.data % 50 = 0 + ) + "#, + // Q19: LEFT ANTI 100K x 5M | 1:50 | partial match + r#" + WITH t1_sorted AS ( + SELECT value % 15000 as key, value as data + FROM range(100000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key + FROM range(5000000) + ORDER BY key + ) + SELECT t1_sorted.key, t1_sorted.data + FROM t1_sorted + WHERE NOT EXISTS ( + SELECT 1 FROM t2_sorted + WHERE t2_sorted.key = t1_sorted.key + ) + "#, + // Q20: INNER 1M x 10M | 1:100 + GROUP BY + r#" + WITH t1_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(10000000) + ORDER BY key, data + ) + SELECT t1_sorted.key, count(*) as cnt + FROM t1_sorted JOIN t2_sorted ON t1_sorted.key = t2_sorted.key + GROUP BY t1_sorted.key + "#, +]; + +impl RunOpt { + pub async fn run(self) -> Result<()> { + println!("Running SMJ benchmarks with the following options: {self:#?}\n"); + + // Define query range + let query_range = match self.query { + Some(query_id) => { + if query_id >= 1 && query_id <= SMJ_QUERIES.len() { + query_id..=query_id + } else { + return exec_err!( + "Query {query_id} not found. 
Available queries: 1 to {}", + SMJ_QUERIES.len() + ); + } + } + None => 1..=SMJ_QUERIES.len(), + }; + + let mut config = self.common.config()?; + // Disable hash joins to force SMJ + config = config.set_bool("datafusion.optimizer.prefer_hash_join", false); + let rt = self.common.build_runtime()?; + let ctx = SessionContext::new_with_config_rt(config, rt); + + let mut benchmark_run = BenchmarkRun::new(); + for query_id in query_range { + let query_index = query_id - 1; // Convert 1-based to 0-based index + + let sql = SMJ_QUERIES[query_index]; + benchmark_run.start_new_case(&format!("Query {query_id}")); + let query_run = self.benchmark_query(sql, &query_id.to_string(), &ctx).await; + match query_run { + Ok(query_results) => { + for iter in query_results { + benchmark_run.write_iter(iter.elapsed, iter.row_count); + } + } + Err(e) => { + return Err(DataFusionError::Context( + format!("SMJ benchmark Q{query_id} failed with error:"), + Box::new(e), + )); + } + } + } + + benchmark_run.maybe_write_json(self.output_path.as_ref())?; + Ok(()) + } + + async fn benchmark_query( + &self, + sql: &str, + query_name: &str, + ctx: &SessionContext, + ) -> Result<Vec<QueryResult>> { + let mut query_results = vec![]; + + // Validate that the query plan includes a Sort Merge Join + let df = ctx.sql(sql).await?; + let physical_plan = df.create_physical_plan().await?; + let plan_string = format!("{physical_plan:#?}"); + + if !plan_string.contains("SortMergeJoinExec") { + return Err(exec_datafusion_err!( + "Query {query_name} does not use Sort Merge Join.
Physical plan: {plan_string}" + )); + } + + for i in 0..self.common.iterations { + let start = Instant::now(); + + let row_count = Self::execute_sql_without_result_buffering(sql, ctx).await?; + + let elapsed = start.elapsed(); + + println!( + "Query {query_name} iteration {i} returned {row_count} rows in {elapsed:?}" + ); + + query_results.push(QueryResult { elapsed, row_count }); + } + + Ok(query_results) + } + + /// Executes the SQL query and drops each result batch after evaluation, to + /// minimizes memory usage by not buffering results. + /// + /// Returns the total result row count + async fn execute_sql_without_result_buffering( + sql: &str, + ctx: &SessionContext, + ) -> Result { + let mut row_count = 0; + + let df = ctx.sql(sql).await?; + let physical_plan = df.create_physical_plan().await?; + let mut stream = execute_stream(physical_plan, ctx.task_ctx())?; + + while let Some(batch) = stream.next().await { + row_count += batch?.num_rows(); + + // Evaluate the result and do nothing, the result will be dropped + // to reduce memory pressure + } + + Ok(row_count) + } +} diff --git a/benchmarks/src/sort.rs b/benchmarks/src/sort.rs deleted file mode 100644 index 8b2b02670449e..0000000000000 --- a/benchmarks/src/sort.rs +++ /dev/null @@ -1,187 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. 
See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::path::PathBuf; -use std::sync::Arc; - -use crate::util::{AccessLogOpt, BenchmarkRun, CommonOpt}; - -use arrow::util::pretty; -use datafusion::common::Result; -use datafusion::physical_expr::{LexOrdering, PhysicalSortExpr}; -use datafusion::physical_plan::collect; -use datafusion::physical_plan::sorts::sort::SortExec; -use datafusion::prelude::{SessionConfig, SessionContext}; -use datafusion::test_util::parquet::TestParquetFile; -use datafusion_common::instant::Instant; -use datafusion_common::utils::get_available_parallelism; -use structopt::StructOpt; - -/// Test performance of sorting large datasets -/// -/// This test sorts a a synthetic dataset generated during the -/// benchmark execution, designed to simulate sorting web server -/// access logs. Such sorting is often done during data transformation -/// steps. -/// -/// The tests sort the entire dataset using several different sort -/// orders. 
-/// -/// Example: -/// -/// dfbench sort --path ./data --scale-factor 1.0 -#[derive(Debug, StructOpt, Clone)] -#[structopt(verbatim_doc_comment)] -pub struct RunOpt { - /// Common options - #[structopt(flatten)] - common: CommonOpt, - - /// Create data files - #[structopt(flatten)] - access_log: AccessLogOpt, - - /// Path to machine readable output file - #[structopt(parse(from_os_str), short = "o", long = "output")] - output_path: Option, -} - -impl RunOpt { - pub async fn run(self) -> Result<()> { - let test_file = self.access_log.build()?; - - use datafusion::physical_expr::expressions::col; - let mut rundata = BenchmarkRun::new(); - let schema = test_file.schema(); - let sort_cases = vec![ - ( - "sort utf8", - LexOrdering::new(vec![PhysicalSortExpr { - expr: col("request_method", &schema)?, - options: Default::default(), - }]), - ), - ( - "sort int", - LexOrdering::new(vec![PhysicalSortExpr { - expr: col("response_bytes", &schema)?, - options: Default::default(), - }]), - ), - ( - "sort decimal", - LexOrdering::new(vec![PhysicalSortExpr { - expr: col("decimal_price", &schema)?, - options: Default::default(), - }]), - ), - ( - "sort integer tuple", - LexOrdering::new(vec![ - PhysicalSortExpr { - expr: col("request_bytes", &schema)?, - options: Default::default(), - }, - PhysicalSortExpr { - expr: col("response_bytes", &schema)?, - options: Default::default(), - }, - ]), - ), - ( - "sort utf8 tuple", - LexOrdering::new(vec![ - // sort utf8 tuple - PhysicalSortExpr { - expr: col("service", &schema)?, - options: Default::default(), - }, - PhysicalSortExpr { - expr: col("host", &schema)?, - options: Default::default(), - }, - PhysicalSortExpr { - expr: col("pod", &schema)?, - options: Default::default(), - }, - PhysicalSortExpr { - expr: col("image", &schema)?, - options: Default::default(), - }, - ]), - ), - ( - "sort mixed tuple", - LexOrdering::new(vec![ - PhysicalSortExpr { - expr: col("service", &schema)?, - options: Default::default(), - }, - PhysicalSortExpr 
{ - expr: col("request_bytes", &schema)?, - options: Default::default(), - }, - PhysicalSortExpr { - expr: col("decimal_price", &schema)?, - options: Default::default(), - }, - ]), - ), - ]; - for (title, expr) in sort_cases { - println!("Executing '{title}' (sorting by: {expr:?})"); - rundata.start_new_case(title); - for i in 0..self.common.iterations { - let config = SessionConfig::new().with_target_partitions( - self.common - .partitions - .unwrap_or_else(get_available_parallelism), - ); - let ctx = SessionContext::new_with_config(config); - let (rows, elapsed) = - exec_sort(&ctx, &expr, &test_file, self.common.debug).await?; - let ms = elapsed.as_secs_f64() * 1000.0; - println!("Iteration {i} finished in {ms} ms"); - rundata.write_iter(elapsed, rows); - } - println!("\n"); - } - if let Some(path) = &self.output_path { - std::fs::write(path, rundata.to_json())?; - } - Ok(()) - } -} - -async fn exec_sort( - ctx: &SessionContext, - expr: &LexOrdering, - test_file: &TestParquetFile, - debug: bool, -) -> Result<(usize, std::time::Duration)> { - let start = Instant::now(); - let scan = test_file.create_scan(ctx, None).await?; - let exec = Arc::new(SortExec::new(expr.clone(), scan)); - let task_ctx = ctx.task_ctx(); - let result = collect(exec, task_ctx).await?; - let elapsed = start.elapsed(); - if debug { - pretty::print_batches(&result)?; - } - let rows = result.iter().map(|b| b.num_rows()).sum(); - Ok((rows, elapsed)) -} diff --git a/benchmarks/src/sort_tpch.rs b/benchmarks/src/sort_tpch.rs index ba03529a930e7..95c90d826de20 100644 --- a/benchmarks/src/sort_tpch.rs +++ b/benchmarks/src/sort_tpch.rs @@ -21,10 +21,10 @@ //! Another `Sort` benchmark focus on single core execution. This benchmark //! runs end-to-end sort queries and test the performance on multiple CPU cores. 
+use clap::Args; use futures::StreamExt; use std::path::PathBuf; use std::sync::Arc; -use structopt::StructOpt; use datafusion::datasource::file_format::parquet::ParquetFormat; use datafusion::datasource::listing::{ @@ -36,48 +36,46 @@ use datafusion::execution::SessionStateBuilder; use datafusion::physical_plan::display::DisplayableExecutionPlan; use datafusion::physical_plan::{displayable, execute_stream}; use datafusion::prelude::*; +use datafusion_common::DEFAULT_PARQUET_EXTENSION; use datafusion_common::instant::Instant; use datafusion_common::utils::get_available_parallelism; -use datafusion_common::DEFAULT_PARQUET_EXTENSION; -use crate::util::{BenchmarkRun, CommonOpt}; +use crate::util::{BenchmarkRun, CommonOpt, QueryResult, print_memory_stats}; -#[derive(Debug, StructOpt)] +#[derive(Debug, Args)] pub struct RunOpt { /// Common options - #[structopt(flatten)] + #[command(flatten)] common: CommonOpt, /// Sort query number. If not specified, runs all queries - #[structopt(short, long)] - query: Option, + #[arg(short, long)] + pub query: Option, /// Path to data files (lineitem). Only parquet format is supported - #[structopt(parse(from_os_str), required = true, short = "p", long = "path")] + #[arg(required = true, short = 'p', long = "path")] path: PathBuf, /// Path to JSON benchmark result to be compare using `compare.py` - #[structopt(parse(from_os_str), short = "o", long = "output")] + #[arg(short = 'o', long = "output")] output_path: Option, /// Load the data into a MemTable before executing the query - #[structopt(short = "m", long = "mem-table")] + #[arg(short = 'm', long = "mem-table")] mem_table: bool, /// Mark the first column of each table as sorted in ascending order. /// The tables should have been created with the `--sort` option for this to have any effect. 
- #[structopt(short = "t", long = "sorted")] + #[arg(short = 't', long = "sorted")] sorted: bool, /// Append a `LIMIT n` clause to the query - #[structopt(short = "l", long = "limit")] + #[arg(short = 'l', long = "limit")] limit: Option, } -struct QueryResult { - elapsed: std::time::Duration, - row_count: usize, -} +pub const SORT_TPCH_QUERY_START_ID: usize = 1; +pub const SORT_TPCH_QUERY_END_ID: usize = 11; impl RunOpt { const SORT_TABLES: [&'static str; 1] = ["lineitem"]; @@ -179,34 +177,42 @@ impl RunOpt { /// If query is specified from command line, run only that query. /// Otherwise, run all queries. pub async fn run(&self) -> Result<()> { - let mut benchmark_run = BenchmarkRun::new(); + let mut benchmark_run: BenchmarkRun = BenchmarkRun::new(); let query_range = match self.query { Some(query_id) => query_id..=query_id, - None => 1..=Self::SORT_QUERIES.len(), + None => SORT_TPCH_QUERY_START_ID..=SORT_TPCH_QUERY_END_ID, }; for query_id in query_range { benchmark_run.start_new_case(&format!("{query_id}")); - let query_results = self.benchmark_query(query_id).await?; - for iter in query_results { - benchmark_run.write_iter(iter.elapsed, iter.row_count); + let query_results = self.benchmark_query(query_id).await; + match query_results { + Ok(query_results) => { + for iter in query_results { + benchmark_run.write_iter(iter.elapsed, iter.row_count); + } + } + Err(e) => { + benchmark_run.mark_failed(); + eprintln!("Query {query_id} failed: {e}"); + } } } benchmark_run.maybe_write_json(self.output_path.as_ref())?; - + benchmark_run.maybe_print_failures(); Ok(()) } /// Benchmark query `query_id` in `SORT_QUERIES` async fn benchmark_query(&self, query_id: usize) -> Result> { let config = self.common.config()?; - let rt_builder = self.common.runtime_env_builder()?; + let rt = self.common.build_runtime()?; let state = SessionStateBuilder::new() .with_config(config) - .with_runtime_env(rt_builder.build_arc()?) 
+ .with_runtime_env(rt) .with_default_features() .build(); let ctx = SessionContext::from(state); @@ -235,13 +241,16 @@ impl RunOpt { millis.push(ms); println!( - "Q{query_id} iteration {i} took {ms:.1} ms and returned {row_count} rows" + "Query {query_id} iteration {i} took {ms:.1} ms and returned {row_count} rows" ); query_results.push(QueryResult { elapsed, row_count }); } let avg = millis.iter().sum::() / millis.len() as f64; - println!("Q{query_id} avg time: {avg:.2} ms"); + println!("Query {query_id} avg time: {avg:.2} ms"); + + // Print memory usage stats using mimalloc (only when compiled with --features mimalloc_extended) + print_memory_stats(); Ok(query_results) } @@ -294,7 +303,7 @@ impl RunOpt { let mut stream = execute_stream(physical_plan.clone(), state.task_ctx())?; while let Some(batch) = stream.next().await { - row_count += batch.unwrap().num_rows(); + row_count += batch?.num_rows(); } if debug { diff --git a/benchmarks/src/tpcds/mod.rs b/benchmarks/src/tpcds/mod.rs new file mode 100644 index 0000000000000..4829eb9fd348a --- /dev/null +++ b/benchmarks/src/tpcds/mod.rs @@ -0,0 +1,19 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +mod run; +pub use run::RunOpt; diff --git a/benchmarks/src/tpcds/run.rs b/benchmarks/src/tpcds/run.rs new file mode 100644 index 0000000000000..f7ef6991515da --- /dev/null +++ b/benchmarks/src/tpcds/run.rs @@ -0,0 +1,362 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::fs; +use std::path::PathBuf; +use std::sync::Arc; + +use crate::util::{BenchmarkRun, CommonOpt, QueryResult, print_memory_stats}; + +use arrow::record_batch::RecordBatch; +use arrow::util::pretty::{self, pretty_format_batches}; +use datafusion::datasource::file_format::parquet::ParquetFormat; +use datafusion::datasource::listing::{ + ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl, +}; +use datafusion::datasource::{MemTable, TableProvider}; +use datafusion::error::Result; +use datafusion::physical_plan::display::DisplayableExecutionPlan; +use datafusion::physical_plan::{collect, displayable}; +use datafusion::prelude::*; +use datafusion_common::instant::Instant; +use datafusion_common::utils::get_available_parallelism; +use datafusion_common::{DEFAULT_PARQUET_EXTENSION, plan_err}; + +use clap::Args; +use log::info; + +// hack to avoid `default_value is meaningless for bool` errors +type BoolDefaultTrue = bool; +pub const TPCDS_QUERY_START_ID: usize = 1; +pub const TPCDS_QUERY_END_ID: usize = 99; + +pub const TPCDS_TABLES: &[&str] = &[ + "call_center", + "customer_address", + "household_demographics", + "promotion", + "store_sales", + "web_page", + "catalog_page", + "customer_demographics", + "income_band", + "reason", + "store", + "web_returns", + "catalog_returns", + "customer", + "inventory", + "ship_mode", + "time_dim", + "web_sales", + "catalog_sales", + "date_dim", + "item", + "store_returns", + "warehouse", + "web_site", +]; + +/// Get the SQL statements from the specified query file +pub fn get_query_sql(base_query_path: &str, query: usize) -> Result> { + if query > 0 && query < 100 { + let filename = format!("{base_query_path}/{query}.sql"); + let mut errors = vec![]; + match fs::read_to_string(&filename) { + Ok(contents) => { + return Ok(contents + .split(';') + .map(|s| s.trim()) + .filter(|s| !s.is_empty()) + .map(|s| s.to_string()) + .collect()); + } + Err(e) => errors.push(format!("{filename}: {e}")), + }; + + 
plan_err!("invalid query. Could not find query: {:?}", errors) + } else { + plan_err!("invalid query. Expected value between 1 and 99") + } +} + +/// Run the tpcds benchmark. +#[derive(Debug, Args, Clone)] +#[command(verbatim_doc_comment)] +pub struct RunOpt { + /// Query number. If not specified, runs all queries + #[arg(short, long)] + pub query: Option, + + /// Common options + #[command(flatten)] + common: CommonOpt, + + /// Path to data files + #[arg(required = true, short = 'p', long = "path")] + path: PathBuf, + + /// Path to query files + #[arg(required = true, short = 'Q', long = "query_path")] + query_path: PathBuf, + + /// Load the data into a MemTable before executing the query + #[arg(short = 'm', long = "mem-table")] + mem_table: bool, + + /// Path to machine readable output file + #[arg(short = 'o', long = "output")] + output_path: Option, + + /// Whether to disable collection of statistics (and cost based optimizations) or not. + #[arg(short = 'S', long = "disable-statistics")] + disable_statistics: bool, + + /// If true then hash join used, if false then sort merge join + /// True by default. + #[arg(short = 'j', long = "prefer_hash_join", default_value = "true")] + prefer_hash_join: BoolDefaultTrue, + + /// If true then Piecewise Merge Join can be used, if false then it will opt for Nested Loop Join + /// False by default. + #[arg( + short = 'w', + long = "enable_piecewise_merge_join", + default_value = "false" + )] + enable_piecewise_merge_join: BoolDefaultTrue, + + /// Mark the first column of each table as sorted in ascending order. + /// The tables should have been created with the `--sort` option for this to have any effect. + #[arg(short = 't', long = "sorted")] + sorted: bool, + + /// How many bytes to buffer on the probe side of hash joins. 
+ #[arg(long, default_value = "0")] + hash_join_buffering_capacity: usize, +} + +impl RunOpt { + pub async fn run(self) -> Result<()> { + println!("Running benchmarks with the following options: {self:?}"); + let query_range = match self.query { + Some(query_id) => query_id..=query_id, + None => TPCDS_QUERY_START_ID..=TPCDS_QUERY_END_ID, + }; + + let mut benchmark_run = BenchmarkRun::new(); + let mut config = self + .common + .config()? + .with_collect_statistics(!self.disable_statistics); + config.options_mut().optimizer.prefer_hash_join = self.prefer_hash_join; + config.options_mut().optimizer.enable_piecewise_merge_join = + self.enable_piecewise_merge_join; + config.options_mut().execution.hash_join_buffering_capacity = + self.hash_join_buffering_capacity; + let rt = self.common.build_runtime()?; + let ctx = SessionContext::new_with_config_rt(config, rt); + // register tables + self.register_tables(&ctx).await?; + + for query_id in query_range { + benchmark_run.start_new_case(&format!("Query {query_id}")); + let query_run = self.benchmark_query(query_id, &ctx).await; + match query_run { + Ok(query_results) => { + for iter in query_results { + benchmark_run.write_iter(iter.elapsed, iter.row_count); + } + } + Err(e) => { + benchmark_run.mark_failed(); + eprintln!("Query {query_id} failed: {e}"); + } + } + } + benchmark_run.maybe_write_json(self.output_path.as_ref())?; + benchmark_run.maybe_print_failures(); + Ok(()) + } + + async fn benchmark_query( + &self, + query_id: usize, + ctx: &SessionContext, + ) -> Result> { + let mut millis = vec![]; + // run benchmark + let mut query_results = vec![]; + + let sql = &get_query_sql(self.query_path.to_str().unwrap(), query_id)?; + + if self.common.debug { + println!("=== SQL for query {query_id} ===\n{}\n", sql.join(";\n")); + } + + for i in 0..self.iterations() { + let start = Instant::now(); + + // query 15 is special, with 3 statements. 
the second statement is the one from which we + // want to capture the results + let mut result = vec![]; + + for query in sql { + result = self.execute_query(ctx, query).await?; + } + + let elapsed = start.elapsed(); + let ms = elapsed.as_secs_f64() * 1000.0; + millis.push(ms); + info!("output:\n\n{}\n\n", pretty_format_batches(&result)?); + let row_count = result.iter().map(|b| b.num_rows()).sum(); + println!( + "Query {query_id} iteration {i} took {ms:.1} ms and returned {row_count} rows" + ); + query_results.push(QueryResult { elapsed, row_count }); + } + + let avg = millis.iter().sum::() / millis.len() as f64; + println!("Query {query_id} avg time: {avg:.2} ms"); + + // Print memory stats using mimalloc (only when compiled with --features mimalloc_extended) + print_memory_stats(); + + Ok(query_results) + } + + async fn register_tables(&self, ctx: &SessionContext) -> Result<()> { + for table in TPCDS_TABLES { + let table_provider = { self.get_table(ctx, table).await? }; + + if self.mem_table { + println!("Loading table '{table}' into memory"); + let start = Instant::now(); + let memtable = + MemTable::load(table_provider, Some(self.partitions()), &ctx.state()) + .await?; + println!( + "Loaded table '{}' into memory in {} ms", + table, + start.elapsed().as_millis() + ); + ctx.register_table(*table, Arc::new(memtable))?; + } else { + ctx.register_table(*table, table_provider)?; + } + } + Ok(()) + } + + async fn execute_query( + &self, + ctx: &SessionContext, + sql: &str, + ) -> Result> { + let debug = self.common.debug; + let plan = ctx.sql(sql).await?; + let (state, plan) = plan.into_parts(); + + if debug { + println!("=== Logical plan ===\n{plan}\n"); + } + + let plan = state.optimize(&plan)?; + if debug { + println!("=== Optimized logical plan ===\n{plan}\n"); + } + let physical_plan = state.create_physical_plan(&plan).await?; + if debug { + println!( + "=== Physical plan ===\n{}\n", + displayable(physical_plan.as_ref()).indent(true) + ); + } + let result = 
collect(physical_plan.clone(), state.task_ctx()).await?; + if debug { + println!( + "=== Physical plan with metrics ===\n{}\n", + DisplayableExecutionPlan::with_metrics(physical_plan.as_ref()) + .indent(true) + ); + if !result.is_empty() { + // do not call print_batches if there are no batches as the result is confusing + // and makes it look like there is a batch with no columns + pretty::print_batches(&result)?; + } + } + Ok(result) + } + + async fn get_table( + &self, + ctx: &SessionContext, + table: &str, + ) -> Result> { + let path = self.path.to_str().unwrap(); + let target_partitions = self.partitions(); + + // Obtain a snapshot of the SessionState + let state = ctx.state(); + let path = format!("{path}/{table}.parquet"); + + // Check if the file exists + if !std::path::Path::new(&path).exists() { + eprintln!("Warning registering {table}: Table file does not exist: {path}"); + } + + let format = ParquetFormat::default() + .with_options(ctx.state().table_options().parquet.clone()); + + let table_path = ListingTableUrl::parse(path)?; + let options = ListingOptions::new(Arc::new(format)) + .with_file_extension(DEFAULT_PARQUET_EXTENSION) + .with_target_partitions(target_partitions) + .with_collect_stat(state.config().collect_statistics()); + let schema = options.infer_schema(&state, &table_path).await?; + + if self.common.debug { + println!( + "Inferred schema from {table_path} for table '{table}':\n{schema:#?}\n" + ); + } + + let options = if self.sorted { + let key_column_name = schema.fields()[0].name(); + options + .with_file_sort_order(vec![vec![col(key_column_name).sort(true, false)]]) + } else { + options + }; + + let config = ListingTableConfig::new(table_path) + .with_listing_options(options) + .with_schema(schema); + + Ok(Arc::new(ListingTable::try_new(config)?)) + } + + fn iterations(&self) -> usize { + self.common.iterations + } + + fn partitions(&self) -> usize { + self.common + .partitions + .unwrap_or_else(get_available_parallelism) + } +} diff 
--git a/benchmarks/src/tpch/convert.rs b/benchmarks/src/tpch/convert.rs deleted file mode 100644 index 5219e09cd3052..0000000000000 --- a/benchmarks/src/tpch/convert.rs +++ /dev/null @@ -1,162 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -use datafusion::logical_expr::select_expr::SelectExpr; -use datafusion_common::instant::Instant; -use std::fs; -use std::path::{Path, PathBuf}; - -use datafusion::common::not_impl_err; - -use super::get_tbl_tpch_table_schema; -use super::TPCH_TABLES; -use datafusion::error::Result; -use datafusion::prelude::*; -use parquet::basic::Compression; -use parquet::file::properties::WriterProperties; -use structopt::StructOpt; - -/// Convert tpch .slt files to .parquet or .csv files -#[derive(Debug, StructOpt)] -pub struct ConvertOpt { - /// Path to csv files - #[structopt(parse(from_os_str), required = true, short = "i", long = "input")] - input_path: PathBuf, - - /// Output path - #[structopt(parse(from_os_str), required = true, short = "o", long = "output")] - output_path: PathBuf, - - /// Output file format: `csv` or `parquet` - #[structopt(short = "f", long = "format")] - file_format: String, - - /// Compression to use when writing Parquet files - #[structopt(short = "c", long = "compression", default_value = "zstd")] - compression: String, - - /// Number of partitions to produce - #[structopt(short = "n", long = "partitions", default_value = "1")] - partitions: usize, - - /// Batch size when reading CSV or Parquet files - #[structopt(short = "s", long = "batch-size", default_value = "8192")] - batch_size: usize, - - /// Sort each table by its first column in ascending order. 
- #[structopt(short = "t", long = "sort")] - sort: bool, -} - -impl ConvertOpt { - pub async fn run(self) -> Result<()> { - let compression = self.compression()?; - - let input_path = self.input_path.to_str().unwrap(); - let output_path = self.output_path.to_str().unwrap(); - - let output_root_path = Path::new(output_path); - for table in TPCH_TABLES { - let start = Instant::now(); - let schema = get_tbl_tpch_table_schema(table); - let key_column_name = schema.fields()[0].name(); - - let input_path = format!("{input_path}/{table}.tbl"); - let options = CsvReadOptions::new() - .schema(&schema) - .has_header(false) - .delimiter(b'|') - .file_extension(".tbl"); - let options = if self.sort { - // indicated that the file is already sorted by its first column to speed up the conversion - options - .file_sort_order(vec![vec![col(key_column_name).sort(true, false)]]) - } else { - options - }; - - let config = SessionConfig::new().with_batch_size(self.batch_size); - let ctx = SessionContext::new_with_config(config); - - // build plan to read the TBL file - let mut csv = ctx.read_csv(&input_path, options).await?; - - // Select all apart from the padding column - let selection = csv - .schema() - .iter() - .take(schema.fields.len() - 1) - .map(Expr::from) - .map(SelectExpr::from) - .collect::>(); - - csv = csv.select(selection)?; - // optionally, repartition the file - let partitions = self.partitions; - if partitions > 1 { - csv = csv.repartition(Partitioning::RoundRobinBatch(partitions))? - } - let csv = if self.sort { - csv.sort_by(vec![col(key_column_name)])? 
- } else { - csv - }; - - // create the physical plan - let csv = csv.create_physical_plan().await?; - - let output_path = output_root_path.join(table); - let output_path = output_path.to_str().unwrap().to_owned(); - fs::create_dir_all(&output_path)?; - println!( - "Converting '{}' to {} files in directory '{}'", - &input_path, self.file_format, &output_path - ); - match self.file_format.as_str() { - "csv" => ctx.write_csv(csv, output_path).await?, - "parquet" => { - let props = WriterProperties::builder() - .set_compression(compression) - .build(); - ctx.write_parquet(csv, output_path, Some(props)).await? - } - other => { - return not_impl_err!("Invalid output format: {other}"); - } - } - println!("Conversion completed in {} ms", start.elapsed().as_millis()); - } - - Ok(()) - } - - /// return the compression method to use when writing parquet - fn compression(&self) -> Result { - Ok(match self.compression.as_str() { - "none" => Compression::UNCOMPRESSED, - "snappy" => Compression::SNAPPY, - "brotli" => Compression::BROTLI(Default::default()), - "gzip" => Compression::GZIP(Default::default()), - "lz4" => Compression::LZ4, - "lz0" => Compression::LZO, - "zstd" => Compression::ZSTD(Default::default()), - other => { - return not_impl_err!("Invalid compression format: {other}"); - } - }) - } -} diff --git a/benchmarks/src/tpch/mod.rs b/benchmarks/src/tpch/mod.rs index 23d0681f560c8..681aa0a403ee1 100644 --- a/benchmarks/src/tpch/mod.rs +++ b/benchmarks/src/tpch/mod.rs @@ -27,13 +27,13 @@ use std::fs; mod run; pub use run::RunOpt; -mod convert; -pub use convert::ConvertOpt; - pub const TPCH_TABLES: &[&str] = &[ "part", "supplier", "partsupp", "customer", "orders", "lineitem", "nation", "region", ]; +pub const TPCH_QUERY_START_ID: usize = 1; +pub const TPCH_QUERY_END_ID: usize = 22; + /// The `.tbl` file contains a trailing column pub fn get_tbl_tpch_table_schema(table: &str) -> Schema { let mut schema = SchemaBuilder::from(get_tpch_table_schema(table).fields); diff 
--git a/benchmarks/src/tpch/run.rs b/benchmarks/src/tpch/run.rs index caef823aaf31d..0d1268013c168 100644 --- a/benchmarks/src/tpch/run.rs +++ b/benchmarks/src/tpch/run.rs @@ -19,15 +19,16 @@ use std::path::PathBuf; use std::sync::Arc; use super::{ - get_query_sql, get_tbl_tpch_table_schema, get_tpch_table_schema, TPCH_TABLES, + TPCH_QUERY_END_ID, TPCH_QUERY_START_ID, TPCH_TABLES, get_query_sql, + get_tbl_tpch_table_schema, get_tpch_table_schema, }; -use crate::util::{BenchmarkRun, CommonOpt}; +use crate::util::{BenchmarkRun, CommonOpt, QueryResult, print_memory_stats}; use arrow::record_batch::RecordBatch; use arrow::util::pretty::{self, pretty_format_batches}; +use datafusion::datasource::file_format::FileFormat; use datafusion::datasource::file_format::csv::CsvFormat; use datafusion::datasource::file_format::parquet::ParquetFormat; -use datafusion::datasource::file_format::FileFormat; use datafusion::datasource::listing::{ ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl, }; @@ -40,8 +41,8 @@ use datafusion_common::instant::Instant; use datafusion_common::utils::get_available_parallelism; use datafusion_common::{DEFAULT_CSV_EXTENSION, DEFAULT_PARQUET_EXTENSION}; +use clap::Args; use log::info; -use structopt::StructOpt; // hack to avoid `default_value is meaningless for bool` errors type BoolDefaultTrue = bool; @@ -53,52 +54,62 @@ type BoolDefaultTrue = bool; /// [2]. /// /// [1]: http://www.tpc.org/tpch/ -/// [2]: https://github.com/databricks/tpch-dbgen.git, +/// [2]: https://github.com/databricks/tpch-dbgen.git /// [2.17.1]: https://www.tpc.org/tpc_documents_current_versions/pdf/tpc-h_v2.17.1.pdf -#[derive(Debug, StructOpt, Clone)] -#[structopt(verbatim_doc_comment)] +#[derive(Debug, Args, Clone)] +#[command(verbatim_doc_comment)] pub struct RunOpt { /// Query number. 
If not specified, runs all queries - #[structopt(short, long)] - query: Option, + #[arg(short, long)] + pub query: Option, /// Common options - #[structopt(flatten)] + #[command(flatten)] common: CommonOpt, /// Path to data files - #[structopt(parse(from_os_str), required = true, short = "p", long = "path")] + #[arg(required = true, short = 'p', long = "path")] path: PathBuf, /// File format: `csv` or `parquet` - #[structopt(short = "f", long = "format", default_value = "csv")] + #[arg(short = 'f', long = "format", default_value = "csv")] file_format: String, /// Load the data into a MemTable before executing the query - #[structopt(short = "m", long = "mem-table")] + #[arg(short = 'm', long = "mem-table")] mem_table: bool, /// Path to machine readable output file - #[structopt(parse(from_os_str), short = "o", long = "output")] + #[arg(short = 'o', long = "output")] output_path: Option, /// Whether to disable collection of statistics (and cost based optimizations) or not. - #[structopt(short = "S", long = "disable-statistics")] + #[arg(short = 'S', long = "disable-statistics")] disable_statistics: bool, /// If true then hash join used, if false then sort merge join /// True by default. - #[structopt(short = "j", long = "prefer_hash_join", default_value = "true")] + #[arg(short = 'j', long = "prefer_hash_join", default_value = "true")] prefer_hash_join: BoolDefaultTrue, + /// If true then Piecewise Merge Join can be used, if false then it will opt for Nested Loop Join + /// False by default. + #[arg( + short = 'w', + long = "enable_piecewise_merge_join", + default_value = "false" + )] + enable_piecewise_merge_join: BoolDefaultTrue, + /// Mark the first column of each table as sorted in ascending order. /// The tables should have been created with the `--sort` option for this to have any effect. 
- #[structopt(short = "t", long = "sorted")] + #[arg(short = 't', long = "sorted")] sorted: bool, -} -const TPCH_QUERY_START_ID: usize = 1; -const TPCH_QUERY_END_ID: usize = 22; + /// How many bytes to buffer on the probe side of hash joins. + #[arg(long, default_value = "0")] + hash_join_buffering_capacity: usize, +} impl RunOpt { pub async fn run(self) -> Result<()> { @@ -114,19 +125,32 @@ impl RunOpt { .config()? .with_collect_statistics(!self.disable_statistics); config.options_mut().optimizer.prefer_hash_join = self.prefer_hash_join; - let rt_builder = self.common.runtime_env_builder()?; - let ctx = SessionContext::new_with_config_rt(config, rt_builder.build_arc()?); + config.options_mut().optimizer.enable_piecewise_merge_join = + self.enable_piecewise_merge_join; + config.options_mut().execution.hash_join_buffering_capacity = + self.hash_join_buffering_capacity; + let rt = self.common.build_runtime()?; + let ctx = SessionContext::new_with_config_rt(config, rt); // register tables self.register_tables(&ctx).await?; for query_id in query_range { benchmark_run.start_new_case(&format!("Query {query_id}")); - let query_run = self.benchmark_query(query_id, &ctx).await?; - for iter in query_run { - benchmark_run.write_iter(iter.elapsed, iter.row_count); + let query_run = self.benchmark_query(query_id, &ctx).await; + match query_run { + Ok(query_results) => { + for iter in query_results { + benchmark_run.write_iter(iter.elapsed, iter.row_count); + } + } + Err(e) => { + benchmark_run.mark_failed(); + eprintln!("Query {query_id} failed: {e}"); + } } } benchmark_run.maybe_write_json(self.output_path.as_ref())?; + benchmark_run.maybe_print_failures(); Ok(()) } @@ -138,11 +162,12 @@ impl RunOpt { let mut millis = vec![]; // run benchmark let mut query_results = vec![]; + + let sql = &get_query_sql(query_id)?; + for i in 0..self.iterations() { let start = Instant::now(); - let sql = &get_query_sql(query_id)?; - // query 15 is special, with 3 statements. 
the second statement is the one from which we // want to capture the results let mut result = vec![]; @@ -160,7 +185,7 @@ impl RunOpt { } } - let elapsed = start.elapsed(); //.as_secs_f64() * 1000.0; + let elapsed = start.elapsed(); let ms = elapsed.as_secs_f64() * 1000.0; millis.push(ms); info!("output:\n\n{}\n\n", pretty_format_batches(&result)?); @@ -174,6 +199,9 @@ impl RunOpt { let avg = millis.iter().sum::() / millis.len() as f64; println!("Query {query_id} avg time: {avg:.2} ms"); + // Print memory stats using mimalloc (only when compiled with --features mimalloc_extended) + print_memory_stats(); + Ok(query_results) } @@ -264,7 +292,7 @@ impl RunOpt { (Arc::new(format), path, ".tbl") } "csv" => { - let path = format!("{path}/{table}"); + let path = format!("{path}/csv/{table}"); let format = CsvFormat::default() .with_delimiter(b',') .with_has_header(true); @@ -320,11 +348,6 @@ impl RunOpt { } } -struct QueryResult { - elapsed: std::time::Duration, - row_count: usize, -} - #[cfg(test)] // Only run with "ci" mode when we have the data #[cfg(feature = "ci")] @@ -363,6 +386,7 @@ mod tests { memory_limit: None, sort_spill_reservation_bytes: None, debug: false, + simulate_latency: false, }; let opt = RunOpt { query: Some(query), @@ -373,7 +397,9 @@ mod tests { output_path: None, disable_statistics: false, prefer_hash_join: true, + enable_piecewise_merge_join: false, sorted: false, + hash_join_buffering_capacity: 0, }; opt.register_tables(&ctx).await?; let queries = get_query_sql(query)?; @@ -381,7 +407,7 @@ mod tests { let plan = ctx.sql(&query).await?; let plan = plan.into_optimized_plan()?; let bytes = logical_plan_to_bytes(&plan)?; - let plan2 = logical_plan_from_bytes(&bytes, &ctx)?; + let plan2 = logical_plan_from_bytes(&bytes, &ctx.task_ctx())?; let plan_formatted = format!("{}", plan.display_indent()); let plan2_formatted = format!("{}", plan2.display_indent()); assert_eq!(plan_formatted, plan2_formatted); @@ -400,6 +426,7 @@ mod tests { memory_limit: 
None, sort_spill_reservation_bytes: None, debug: false, + simulate_latency: false, }; let opt = RunOpt { query: Some(query), @@ -410,7 +437,9 @@ mod tests { output_path: None, disable_statistics: false, prefer_hash_join: true, + enable_piecewise_merge_join: false, sorted: false, + hash_join_buffering_capacity: 0, }; opt.register_tables(&ctx).await?; let queries = get_query_sql(query)?; @@ -418,7 +447,7 @@ mod tests { let plan = ctx.sql(&query).await?; let plan = plan.create_physical_plan().await?; let bytes = physical_plan_to_bytes(plan.clone())?; - let plan2 = physical_plan_from_bytes(&bytes, &ctx)?; + let plan2 = physical_plan_from_bytes(&bytes, &ctx.task_ctx())?; let plan_formatted = format!("{}", displayable(plan.as_ref()).indent(false)); let plan2_formatted = format!("{}", displayable(plan2.as_ref()).indent(false)); diff --git a/benchmarks/src/util/access_log.rs b/benchmarks/src/util/access_log.rs deleted file mode 100644 index 2b29465ee20e3..0000000000000 --- a/benchmarks/src/util/access_log.rs +++ /dev/null @@ -1,74 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! 
Benchmark data generation - -use datafusion::common::Result; -use datafusion::test_util::parquet::TestParquetFile; -use parquet::file::properties::WriterProperties; -use std::path::PathBuf; -use structopt::StructOpt; -use test_utils::AccessLogGenerator; - -// Options and builder for making an access log test file -// Note don't use docstring or else it ends up in help -#[derive(Debug, StructOpt, Clone)] -pub struct AccessLogOpt { - /// Path to folder where access log file will be generated - #[structopt(parse(from_os_str), required = true, short = "p", long = "path")] - path: PathBuf, - - /// Data page size of the generated parquet file - #[structopt(long = "page-size")] - page_size: Option, - - /// Data page size of the generated parquet file - #[structopt(long = "row-group-size")] - row_group_size: Option, - - /// Total size of generated dataset. The default scale factor of 1.0 will generate a roughly 1GB parquet file - #[structopt(long = "scale-factor", default_value = "1.0")] - scale_factor: f32, -} - -impl AccessLogOpt { - /// Create the access log and return the file. 
- /// - /// See [`TestParquetFile`] for more details - pub fn build(self) -> Result { - let path = self.path.join("logs.parquet"); - - let mut props_builder = WriterProperties::builder(); - - if let Some(s) = self.page_size { - props_builder = props_builder - .set_data_page_size_limit(s) - .set_write_batch_size(s); - } - - if let Some(s) = self.row_group_size { - props_builder = props_builder.set_max_row_group_size(s); - } - let props = props_builder.build(); - - let generator = AccessLogGenerator::new(); - - let num_batches = 100_f32 * self.scale_factor; - - TestParquetFile::try_new(path, props, generator.take(num_batches as usize)) - } -} diff --git a/benchmarks/src/util/latency_object_store.rs b/benchmarks/src/util/latency_object_store.rs new file mode 100644 index 0000000000000..9ef8d1b78b751 --- /dev/null +++ b/benchmarks/src/util/latency_object_store.rs @@ -0,0 +1,157 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! An ObjectStore wrapper that adds simulated S3-like latency to get and list operations. +//! +//! Cycles through a fixed latency distribution inspired by real S3 performance: +//! - P50: ~30ms +//! - P75-P90: ~100-120ms +//! 
- P99: ~150-200ms + +use std::fmt; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::time::Duration; + +use async_trait::async_trait; +use futures::StreamExt; +use futures::stream::BoxStream; +use object_store::path::Path; +use object_store::{ + CopyOptions, GetOptions, GetResult, ListResult, MultipartUpload, ObjectMeta, + ObjectStore, PutMultipartOptions, PutOptions, PutPayload, PutResult, Result, +}; + +/// GET latency distribution, inspired by S3 latencies. +/// Deterministic but shuffled to avoid artificial patterns. +/// 20 values: 11x P50 (~25-35ms), 5x P75-P90 (~70-110ms), 2x P95 (~120-150ms), 2x P99 (~180-200ms) +/// Sorted: 25,25,28,28,30,30,30,30,32,32,35, 70,85,100,100,110, 130,150, 180,200 +/// P50≈32ms, P90≈110ms, P99≈200ms +const GET_LATENCIES_MS: &[u64] = &[ + 30, 100, 25, 85, 32, 200, 28, 130, 35, 70, 30, 150, 30, 110, 28, 180, 32, 25, 100, 30, +]; + +/// LIST latency distribution, generally higher than GET. +/// 20 values: 11x P50 (~40-70ms), 5x P75-P90 (~120-180ms), 2x P95 (~200-250ms), 2x P99 (~300-400ms) +/// Sorted: 40,40,50,50,55,55,60,60,65,65,70, 120,140,160,160,180, 210,250, 300,400 +/// P50≈65ms, P90≈180ms, P99≈400ms +const LIST_LATENCIES_MS: &[u64] = &[ + 55, 160, 40, 140, 65, 400, 50, 210, 70, 120, 60, 250, 55, 180, 50, 300, 65, 40, 160, + 60, +]; + +/// An ObjectStore wrapper that injects simulated latency on get and list calls. 
+#[derive(Debug)] +pub struct LatencyObjectStore { + inner: T, + get_counter: AtomicUsize, + list_counter: AtomicUsize, +} + +impl LatencyObjectStore { + pub fn new(inner: T) -> Self { + Self { + inner, + get_counter: AtomicUsize::new(0), + list_counter: AtomicUsize::new(0), + } + } + + fn next_get_latency(&self) -> Duration { + let idx = + self.get_counter.fetch_add(1, Ordering::Relaxed) % GET_LATENCIES_MS.len(); + Duration::from_millis(GET_LATENCIES_MS[idx]) + } + + fn next_list_latency(&self) -> Duration { + let idx = + self.list_counter.fetch_add(1, Ordering::Relaxed) % LIST_LATENCIES_MS.len(); + Duration::from_millis(LIST_LATENCIES_MS[idx]) + } +} + +impl fmt::Display for LatencyObjectStore { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "LatencyObjectStore({})", self.inner) + } +} + +#[async_trait] +impl ObjectStore for LatencyObjectStore { + async fn put_opts( + &self, + location: &Path, + payload: PutPayload, + opts: PutOptions, + ) -> Result { + self.inner.put_opts(location, payload, opts).await + } + + async fn put_multipart_opts( + &self, + location: &Path, + opts: PutMultipartOptions, + ) -> Result> { + self.inner.put_multipart_opts(location, opts).await + } + + async fn get_opts(&self, location: &Path, options: GetOptions) -> Result { + tokio::time::sleep(self.next_get_latency()).await; + self.inner.get_opts(location, options).await + } + + async fn get_ranges( + &self, + location: &Path, + ranges: &[std::ops::Range], + ) -> Result> { + tokio::time::sleep(self.next_get_latency()).await; + self.inner.get_ranges(location, ranges).await + } + + fn delete_stream( + &self, + locations: BoxStream<'static, Result>, + ) -> BoxStream<'static, Result> { + self.inner.delete_stream(locations) + } + + fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, Result> { + let latency = self.next_list_latency(); + let stream = self.inner.list(prefix); + futures::stream::once(async move { + tokio::time::sleep(latency).await; + 
futures::stream::empty() + }) + .flatten() + .chain(stream) + .boxed() + } + + async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result { + tokio::time::sleep(self.next_list_latency()).await; + self.inner.list_with_delimiter(prefix).await + } + + async fn copy_opts( + &self, + from: &Path, + to: &Path, + options: CopyOptions, + ) -> Result<()> { + self.inner.copy_opts(from, to, options).await + } +} diff --git a/benchmarks/src/util/memory.rs b/benchmarks/src/util/memory.rs new file mode 100644 index 0000000000000..11b96ef227756 --- /dev/null +++ b/benchmarks/src/util/memory.rs @@ -0,0 +1,57 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +/// Print Peak RSS, Peak Commit, Page Faults based on mimalloc api +pub fn print_memory_stats() { + #[cfg(all(feature = "mimalloc", feature = "mimalloc_extended"))] + { + use datafusion_common::human_readable_size; + let mut peak_rss = 0; + let mut peak_commit = 0; + let mut page_faults = 0; + unsafe { + libmimalloc_sys::mi_process_info( + std::ptr::null_mut(), + std::ptr::null_mut(), + std::ptr::null_mut(), + std::ptr::null_mut(), + &mut peak_rss, + std::ptr::null_mut(), + &mut peak_commit, + &mut page_faults, + ); + } + + // When modifying this output format, make sure to update the corresponding + // parsers in `mem_profile.rs`, specifically `parse_vm_line` and `parse_query_time`, + // to keep the log output and parser logic in sync. + println!( + "Peak RSS: {}, Peak Commit: {}, Page Faults: {}", + if peak_rss == 0 { + "N/A".to_string() + } else { + human_readable_size(peak_rss) + }, + if peak_commit == 0 { + "N/A".to_string() + } else { + human_readable_size(peak_commit) + }, + page_faults + ); + } +} diff --git a/benchmarks/src/util/mod.rs b/benchmarks/src/util/mod.rs index 95c6e5f53d0f0..6dc11c0f425bd 100644 --- a/benchmarks/src/util/mod.rs +++ b/benchmarks/src/util/mod.rs @@ -16,10 +16,11 @@ // under the License. //! 
Shared benchmark utilities -mod access_log; +pub mod latency_object_store; +mod memory; mod options; mod run; -pub use access_log::AccessLogOpt; +pub use memory::print_memory_stats; pub use options::CommonOpt; -pub use run::{BenchQuery, BenchmarkRun}; +pub use run::{BenchQuery, BenchmarkRun, QueryResult}; diff --git a/benchmarks/src/util/options.rs b/benchmarks/src/util/options.rs index 6627a287dfcd4..a50a5268c0bfe 100644 --- a/benchmarks/src/util/options.rs +++ b/benchmarks/src/util/options.rs @@ -17,50 +17,59 @@ use std::{num::NonZeroUsize, sync::Arc}; +use clap::Args; use datafusion::{ execution::{ disk_manager::DiskManagerBuilder, memory_pool::{FairSpillPool, GreedyMemoryPool, MemoryPool, TrackConsumersPool}, - runtime_env::RuntimeEnvBuilder, + object_store::ObjectStoreUrl, + runtime_env::{RuntimeEnv, RuntimeEnvBuilder}, }, prelude::SessionConfig, }; use datafusion_common::{DataFusionError, Result}; -use structopt::StructOpt; +use object_store::local::LocalFileSystem; + +use super::latency_object_store::LatencyObjectStore; // Common benchmark options (don't use doc comments otherwise this doc // shows up in help files) -#[derive(Debug, StructOpt, Clone)] +#[derive(Debug, Args, Clone)] pub struct CommonOpt { /// Number of iterations of each test run - #[structopt(short = "i", long = "iterations", default_value = "3")] + #[arg(short = 'i', long = "iterations", default_value = "3")] pub iterations: usize, /// Number of partitions to process in parallel. Defaults to number of available cores. 
- #[structopt(short = "n", long = "partitions")] + #[arg(short = 'n', long = "partitions")] pub partitions: Option, /// Batch size when reading CSV or Parquet files - #[structopt(short = "s", long = "batch-size")] + #[arg(short = 's', long = "batch-size")] pub batch_size: Option, /// The memory pool type to use, should be one of "fair" or "greedy" - #[structopt(long = "mem-pool-type", default_value = "fair")] + #[arg(long = "mem-pool-type", default_value = "fair")] pub mem_pool_type: String, /// Memory limit (e.g. '100M', '1.5G'). If not specified, run all pre-defined memory limits for given query /// if there's any, otherwise run with no memory limit. - #[structopt(long = "memory-limit", parse(try_from_str = parse_memory_limit))] + #[arg(long = "memory-limit", value_parser = parse_capacity_limit)] pub memory_limit: Option, /// The amount of memory to reserve for sort spill operations. DataFusion's default value will be used /// if not specified. - #[structopt(long = "sort-spill-reservation-bytes", parse(try_from_str = parse_memory_limit))] + #[arg(long = "sort-spill-reservation-bytes", value_parser = parse_capacity_limit)] pub sort_spill_reservation_bytes: Option, /// Activate debug mode to see more details - #[structopt(short, long)] + #[arg(short, long)] pub debug: bool, + + /// Simulate object store latency to mimic remote storage (e.g. S3). + /// Adds random latency in the range 20-200ms to each object store operation. 
+ #[arg(long = "simulate-latency")] + pub simulate_latency: bool, } impl CommonOpt { @@ -91,7 +100,15 @@ impl CommonOpt { pub fn runtime_env_builder(&self) -> Result { let mut rt_builder = RuntimeEnvBuilder::new(); const NUM_TRACKED_CONSUMERS: usize = 5; - if let Some(memory_limit) = self.memory_limit { + // Use CLI --memory-limit if provided, otherwise fall back to + // DATAFUSION_RUNTIME_MEMORY_LIMIT env var + let memory_limit = self.memory_limit.or_else(|| { + std::env::var("DATAFUSION_RUNTIME_MEMORY_LIMIT") + .ok() + .and_then(|val| parse_capacity_limit(&val).ok()) + }); + + if let Some(memory_limit) = memory_limit { let pool: Arc = match self.mem_pool_type.as_str() { "fair" => Arc::new(TrackConsumersPool::new( FairSpillPool::new(memory_limit), @@ -105,7 +122,7 @@ impl CommonOpt { return Err(DataFusionError::Configuration(format!( "Invalid memory pool type: {}", self.mem_pool_type - ))) + ))); } }; rt_builder = rt_builder @@ -114,22 +131,44 @@ impl CommonOpt { } Ok(rt_builder) } + + /// Build the runtime environment, optionally wrapping the local filesystem + /// with a throttled object store to simulate remote storage latency. + pub fn build_runtime(&self) -> Result> { + let rt = self.runtime_env_builder()?.build_arc()?; + if self.simulate_latency { + let store: Arc = + Arc::new(LatencyObjectStore::new(LocalFileSystem::new())); + let url = ObjectStoreUrl::parse("file:///")?; + rt.register_object_store(url.as_ref(), store); + println!( + "Simulating S3-like object store latency (get: 25-200ms, list: 40-400ms)" + ); + } + Ok(rt) + } } -/// Parse memory limit from string to number of bytes -/// e.g. '1.5G', '100M' -> 1572864 -fn parse_memory_limit(limit: &str) -> Result { +/// Parse capacity limit from string to number of bytes by allowing units: K, M and G. 
+/// Supports formats like '1.5G' -> 1610612736, '100M' -> 104857600 +fn parse_capacity_limit(limit: &str) -> Result { + if limit.trim().is_empty() { + return Err("Capacity limit cannot be empty".to_string()); + } let (number, unit) = limit.split_at(limit.len() - 1); let number: f64 = number .parse() - .map_err(|_| format!("Failed to parse number from memory limit '{limit}'"))?; + .map_err(|_| format!("Failed to parse number from capacity limit '{limit}'"))?; + if number.is_sign_negative() || number.is_infinite() { + return Err("Limit value should be positive finite number".to_string()); + } match unit { "K" => Ok((number * 1024.0) as usize), "M" => Ok((number * 1024.0 * 1024.0) as usize), "G" => Ok((number * 1024.0 * 1024.0 * 1024.0) as usize), _ => Err(format!( - "Unsupported unit '{unit}' in memory limit '{limit}'" + "Unsupported unit '{unit}' in capacity limit '{limit}'. Unit must be one of: 'K', 'M', 'G'" )), } } @@ -139,16 +178,59 @@ mod tests { use super::*; #[test] - fn test_parse_memory_limit_all() { + fn test_runtime_env_builder_reads_env_var() { + // Set the env var and verify runtime_env_builder picks it up + // when no CLI --memory-limit is provided + let opt = CommonOpt { + iterations: 3, + partitions: None, + batch_size: None, + mem_pool_type: "fair".to_string(), + memory_limit: None, + sort_spill_reservation_bytes: None, + debug: false, + simulate_latency: false, + }; + + // With env var set, builder should succeed and have a memory pool + // SAFETY: This test is single-threaded and the env var is restored after use + unsafe { + std::env::set_var("DATAFUSION_RUNTIME_MEMORY_LIMIT", "2G"); + } + let builder = opt.runtime_env_builder().unwrap(); + let runtime = builder.build().unwrap(); + unsafe { + std::env::remove_var("DATAFUSION_RUNTIME_MEMORY_LIMIT"); + } + // A 2G memory pool should be present — verify it reports the correct limit + match runtime.memory_pool.memory_limit() { + datafusion::execution::memory_pool::MemoryLimit::Finite(limit) => { + 
assert_eq!(limit, 2 * 1024 * 1024 * 1024); + } + _ => panic!("Expected Finite memory limit"), + } + } + + #[test] + fn test_parse_capacity_limit_all() { // Test valid inputs - assert_eq!(parse_memory_limit("100K").unwrap(), 102400); - assert_eq!(parse_memory_limit("1.5M").unwrap(), 1572864); - assert_eq!(parse_memory_limit("2G").unwrap(), 2147483648); + assert_eq!(parse_capacity_limit("100K").unwrap(), 102400); + assert_eq!(parse_capacity_limit("1.5M").unwrap(), 1572864); + assert_eq!(parse_capacity_limit("2G").unwrap(), 2147483648); // Test invalid unit - assert!(parse_memory_limit("500X").is_err()); + assert!(parse_capacity_limit("500X").is_err()); // Test invalid number - assert!(parse_memory_limit("abcM").is_err()); + assert!(parse_capacity_limit("abcM").is_err()); + + // Test negative number + assert!(parse_capacity_limit("-1M").is_err()); + + // Test infinite number + assert!(parse_capacity_limit("infM").is_err()); + + // Test negative infinite number + assert!(parse_capacity_limit("-infM").is_err()); } } diff --git a/benchmarks/src/util/run.rs b/benchmarks/src/util/run.rs index 13969f4d39497..df17674e62961 100644 --- a/benchmarks/src/util/run.rs +++ b/benchmarks/src/util/run.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use datafusion::{error::Result, DATAFUSION_VERSION}; +use datafusion::{DATAFUSION_VERSION, error::Result}; use datafusion_common::utils::get_available_parallelism; use serde::{Serialize, Serializer}; use serde_json::Value; @@ -90,8 +90,13 @@ pub struct BenchQuery { iterations: Vec, #[serde(serialize_with = "serialize_start_time")] start_time: SystemTime, + success: bool, +} +/// Internal representation of a single benchmark query iteration result. 
+pub struct QueryResult { + pub elapsed: Duration, + pub row_count: usize, } - /// collects benchmark run data and then serializes it at the end pub struct BenchmarkRun { context: RunContext, @@ -120,6 +125,7 @@ impl BenchmarkRun { query: id.to_owned(), iterations: vec![], start_time: SystemTime::now(), + success: true, }); if let Some(c) = self.current_case.as_mut() { *c += 1; @@ -138,6 +144,28 @@ impl BenchmarkRun { } } + /// Print the names of failed queries, if any + pub fn maybe_print_failures(&self) { + let failed_queries: Vec<&str> = self + .queries + .iter() + .filter_map(|q| (!q.success).then_some(q.query.as_str())) + .collect(); + + if !failed_queries.is_empty() { + println!("Failed Queries: {}", failed_queries.join(", ")); + } + } + + /// Mark current query + pub fn mark_failed(&mut self) { + if let Some(idx) = self.current_case { + self.queries[idx].success = false; + } else { + unreachable!("Cannot mark failure: no current case"); + } + } + /// Stringify data into formatted json pub fn to_json(&self) -> String { let mut output = HashMap::<&str, Value>::new(); diff --git a/ci/scripts/check_examples_docs.sh b/ci/scripts/check_examples_docs.sh new file mode 100755 index 0000000000000..62308b323b535 --- /dev/null +++ b/ci/scripts/check_examples_docs.sh @@ -0,0 +1,77 @@ +#!/usr/bin/env bash +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Generates documentation for DataFusion examples using the Rust-based +# documentation generator and verifies that the committed README.md +# is up to date. +# +# The README is generated from documentation comments in: +# datafusion-examples/examples//main.rs +# +# This script is intended to be run in CI to ensure that example +# documentation stays in sync with the code. +# +# To update the README locally, run this script and replace README.md +# with the generated output. + +set -euo pipefail + +ROOT_DIR="$(git rev-parse --show-toplevel)" + +# Load centralized tool versions +source "${ROOT_DIR}/ci/scripts/utils/tool_versions.sh" + +EXAMPLES_DIR="$ROOT_DIR/datafusion-examples" +README="$EXAMPLES_DIR/README.md" +README_NEW="$EXAMPLES_DIR/README-NEW.md" + +echo "▶ Generating examples README (Rust generator)…" +cargo run --quiet \ + --manifest-path "$EXAMPLES_DIR/Cargo.toml" \ + --bin examples-docs \ + > "$README_NEW" + +echo "▶ Formatting generated README with prettier ${PRETTIER_VERSION}…" +npx "prettier@${PRETTIER_VERSION}" \ + --parser markdown \ + --write "$README_NEW" + +echo "▶ Comparing generated README with committed version…" + +if ! diff -u "$README" "$README_NEW" > /tmp/examples-readme.diff; then + echo "" + echo "❌ Examples README is out of date." 
+ echo "" + echo "The examples documentation is generated automatically from:" + echo " - datafusion-examples/examples//main.rs" + echo "" + echo "To update the README locally, run:" + echo "" + echo " cargo run --bin examples-docs \\" + echo " | npx prettier@${PRETTIER_VERSION} --parser markdown --write \\" + echo " > datafusion-examples/README.md" + echo "" + echo "Diff:" + echo "------------------------------------------------------------" + cat /tmp/examples-readme.diff + echo "------------------------------------------------------------" + exit 1 +fi + +echo "✅ Examples README is up-to-date." diff --git a/ci/scripts/doc_prettier_check.sh b/ci/scripts/doc_prettier_check.sh new file mode 100755 index 0000000000000..95332eb65aaf2 --- /dev/null +++ b/ci/scripts/doc_prettier_check.sh @@ -0,0 +1,86 @@ +#!/usr/bin/env bash +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +set -euo pipefail + +ROOT_DIR="$(git rev-parse --show-toplevel)" +SCRIPT_NAME="$(basename "${BASH_SOURCE[0]}")" + +# Load shared utilities and tool versions +source "${ROOT_DIR}/ci/scripts/utils/tool_versions.sh" +source "${ROOT_DIR}/ci/scripts/utils/git.sh" + +PRETTIER_TARGETS=( + '{datafusion,datafusion-cli,datafusion-examples,dev,docs}/**/*.md' + '!datafusion/CHANGELOG.md' + README.md + CONTRIBUTING.md +) + +MODE="check" +ALLOW_DIRTY=0 + +usage() { + cat >&2 </dev/null 2>&1; then + echo "npx is required to run the prettier check. Install Node.js (e.g., brew install node) and re-run." >&2 + exit 1 +fi + +PRETTIER_MODE=(--check) +if [[ "$MODE" == "write" ]]; then + PRETTIER_MODE=(--write) +fi + +# Ignore subproject CHANGELOG.md because it is machine generated +npx "prettier@${PRETTIER_VERSION}" "${PRETTIER_MODE[@]}" "${PRETTIER_TARGETS[@]}" diff --git a/ci/scripts/license_header.sh b/ci/scripts/license_header.sh new file mode 100755 index 0000000000000..7ab8c9637598b --- /dev/null +++ b/ci/scripts/license_header.sh @@ -0,0 +1,78 @@ +#!/usr/bin/env bash +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +SCRIPT_NAME="$(basename "${BASH_SOURCE[0]}")" + +source "${SCRIPT_DIR}/utils/git.sh" + +MODE="check" +ALLOW_DIRTY=0 +HAWKEYE_CONFIG="licenserc.toml" + +usage() { + cat >&2 <&2 <&2 <&2 <&2 <&2 + return 1 + fi +} diff --git a/ci/scripts/utils/tool_versions.sh b/ci/scripts/utils/tool_versions.sh new file mode 100644 index 0000000000000..ac731ed0d5341 --- /dev/null +++ b/ci/scripts/utils/tool_versions.sh @@ -0,0 +1,23 @@ +#!/usr/bin/env bash +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file defines centralized tool versions used by CI and development scripts. +# It is intended to be sourced by other scripts and should not be executed directly. 
+ +PRETTIER_VERSION="2.7.1" diff --git a/clippy.toml b/clippy.toml index 114e3bfceb272..ea3609b574c06 100644 --- a/clippy.toml +++ b/clippy.toml @@ -9,4 +9,14 @@ disallowed-types = [ # Lowering the threshold to help prevent stack overflows (default is 16384) # See: https://rust-lang.github.io/rust-clippy/master/index.html#/large_futures -future-size-threshold = 10000 \ No newline at end of file +future-size-threshold = 10000 + +# Be more aware of large error variants which can impact the "happy path" due +# to large stack footprint when considering async state machines (default is 128). +# +# Value of 70 picked arbitrarily as something less than 100. +# +# See: +# - https://github.com/apache/datafusion/issues/16652 +# - https://rust-lang.github.io/rust-clippy/master/index.html#result_large_err +large-error-threshold = 70 diff --git a/datafusion-cli/CONTRIBUTING.md b/datafusion-cli/CONTRIBUTING.md index 4b464dffc57ce..8be656ec4ee34 100644 --- a/datafusion-cli/CONTRIBUTING.md +++ b/datafusion-cli/CONTRIBUTING.md @@ -21,55 +21,40 @@ ## Running Tests -Tests can be run using `cargo` +First check out test files with ```shell -cargo test +git submodule update --init ``` -## Running Storage Integration Tests - -By default, storage integration tests are not run. To run them you will need to set `TEST_STORAGE_INTEGRATION=1` and -then provide the necessary configuration for that object store. +Then run all the tests with -For some of the tests, [snapshots](https://datafusion.apache.org/contributor-guide/testing.html#snapshot-testing) are used. +```shell +cargo test --all-targets +``` -### AWS +## Running Storage Integration Tests -To test the S3 integration against [Minio](https://github.com/minio/minio) +By default, storage integration tests are not run. These tests use the `testcontainers` crate to start up a local MinIO server using Docker on port 9000. -First start up a container with Minio and load test files. 
+To run them you will need to set `TEST_STORAGE_INTEGRATION`: ```shell -docker run -d \ - --name datafusion-test-minio \ - -p 9000:9000 \ - -e MINIO_ROOT_USER=TEST-DataFusionLogin \ - -e MINIO_ROOT_PASSWORD=TEST-DataFusionPassword \ - -v $(pwd)/../datafusion/core/tests/data:/source \ - quay.io/minio/minio server /data - -docker exec datafusion-test-minio /bin/sh -c "\ - mc ready local - mc alias set localminio http://localhost:9000 TEST-DataFusionLogin TEST-DataFusionPassword && \ - mc mb localminio/data && \ - mc cp -r /source/* localminio/data" +TEST_STORAGE_INTEGRATION=1 cargo test ``` -Setup environment +For some of the tests, [snapshots](https://datafusion.apache.org/contributor-guide/testing.html#snapshot-testing) are used. -```shell -export TEST_STORAGE_INTEGRATION=1 -export AWS_ACCESS_KEY_ID=TEST-DataFusionLogin -export AWS_SECRET_ACCESS_KEY=TEST-DataFusionPassword -export AWS_ENDPOINT=http://127.0.0.1:9000 -export AWS_ALLOW_HTTP=true -``` +### AWS -Note that `AWS_ENDPOINT` is set without slash at the end. +S3 integration is tested against [Minio](https://github.com/minio/minio) with [TestContainers](https://github.com/testcontainers/testcontainers-rs) +This requires Docker to be running on your machine and port 9000 to be free. 
-Run tests +If you see an error mentioning "failed to load IMDS session token" such as -```shell -cargo test -``` +> ---- object_storage::tests::s3_object_store_builder_resolves_region_when_none_provided stdout ---- +> Error: ObjectStore(Generic { store: "S3", source: "Error getting credentials from provider: an error occurred while loading credentials: failed to load IMDS session token" }) + +You may need to disable trying to fetch S3 credentials from the environment using the `AWS_EC2_METADATA_DISABLED`, for example: + +> $ AWS_EC2_METADATA_DISABLED=true TEST_STORAGE_INTEGRATION=1 cargo test diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index 2eec93628b520..3fe6be964c3f6 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -37,37 +37,44 @@ backtrace = ["datafusion/backtrace"] [dependencies] arrow = { workspace = true } async-trait = { workspace = true } -aws-config = "1.6.2" -aws-credential-types = "1.2.0" -clap = { version = "4.5.39", features = ["derive", "cargo"] } +aws-config = "1.8.14" +aws-credential-types = "1.2.13" +chrono = { workspace = true } +clap = { version = "4.5.60", features = ["cargo", "derive"] } datafusion = { workspace = true, features = [ "avro", + "compression", "crypto_expressions", "datetime_expressions", "encoding_expressions", "nested_expressions", "parquet", + "parquet_encryption", "recursive_protection", "regex_expressions", + "sql", "unicode_expressions", - "compression", ] } +datafusion-common = { workspace = true } dirs = "6.0.0" env_logger = { workspace = true } futures = { workspace = true } +log = { workspace = true } mimalloc = { version = "0.1", default-features = false } object_store = { workspace = true, features = ["aws", "gcp", "http"] } parking_lot = { workspace = true } parquet = { workspace = true, default-features = false } regex = { workspace = true } -rustyline = "16.0" -tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread", "sync", "parking_lot", 
"signal"] } +rustyline = "17.0" +tokio = { workspace = true, features = ["macros", "parking_lot", "rt", "rt-multi-thread", "signal", "sync"] } url = { workspace = true } +[lints] +workspace = true + [dev-dependencies] -assert_cmd = "2.0" ctor = { workspace = true } insta = { workspace = true } insta-cmd = "0.6.0" -predicates = "3.0" rstest = { workspace = true } +testcontainers-modules = { workspace = true, features = ["minio"] } diff --git a/datafusion-cli/README.md b/datafusion-cli/README.md index ca796b525fa15..b34aa770374da 100644 --- a/datafusion-cli/README.md +++ b/datafusion-cli/README.md @@ -19,12 +19,15 @@ -# DataFusion Command-line Interface +# Apache DataFusion Command-line Interface -[DataFusion](https://datafusion.apache.org/) is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. +[Apache DataFusion] is an extensible query execution framework, written in Rust, that uses [Apache Arrow] as its in-memory format. DataFusion CLI (`datafusion-cli`) is a small command line utility that runs SQL queries using the DataFusion engine. +[apache arrow]: https://arrow.apache.org/ +[apache datafusion]: https://datafusion.apache.org/ + # Frequently Asked Questions ## Where can I find more information? 
diff --git a/datafusion-cli/examples/cli-session-context.rs b/datafusion-cli/examples/cli-session-context.rs index 1a8f15c8731b2..6095072163870 100644 --- a/datafusion-cli/examples/cli-session-context.rs +++ b/datafusion-cli/examples/cli-session-context.rs @@ -23,12 +23,14 @@ use std::sync::Arc; use datafusion::{ dataframe::DataFrame, error::DataFusionError, - execution::{context::SessionState, TaskContext}, + execution::{TaskContext, context::SessionState}, logical_expr::{LogicalPlan, LogicalPlanBuilder}, prelude::SessionContext, }; use datafusion_cli::{ - cli_context::CliSessionContext, exec::exec_from_repl, print_options::PrintOptions, + cli_context::CliSessionContext, exec::exec_from_repl, + object_storage::instrumented::InstrumentedObjectStoreRegistry, + print_options::PrintOptions, }; use object_store::ObjectStore; @@ -89,6 +91,7 @@ pub async fn main() { quiet: false, maxrows: datafusion_cli::print_options::MaxRows::Unlimited, color: true, + instrumented_registry: Arc::new(InstrumentedObjectStoreRegistry::new()), }; exec_from_repl(&my_ctx, &mut print_options).await.unwrap(); diff --git a/datafusion-cli/src/catalog.rs b/datafusion-cli/src/catalog.rs index 3298b7deaeba2..63b055388fdbe 100644 --- a/datafusion-cli/src/catalog.rs +++ b/datafusion-cli/src/catalog.rs @@ -18,13 +18,13 @@ use std::any::Any; use std::sync::{Arc, Weak}; -use crate::object_storage::{get_object_store, AwsOptions, GcpOptions}; +use crate::object_storage::{AwsOptions, GcpOptions, get_object_store}; use datafusion::catalog::{CatalogProvider, CatalogProviderList, SchemaProvider}; use datafusion::common::plan_datafusion_err; -use datafusion::datasource::listing::ListingTableUrl; use datafusion::datasource::TableProvider; +use datafusion::datasource::listing::ListingTableUrl; use datafusion::error::Result; use datafusion::execution::context::SessionState; use datafusion::execution::session_state::SessionStateBuilder; @@ -152,10 +152,10 @@ impl SchemaProvider for DynamicObjectStoreSchemaProvider 
{ async fn table(&self, name: &str) -> Result>> { let inner_table = self.inner.table(name).await; - if inner_table.is_ok() { - if let Some(inner_table) = inner_table? { - return Ok(Some(inner_table)); - } + if inner_table.is_ok() + && let Some(inner_table) = inner_table? + { + return Ok(Some(inner_table)); } // if the inner schema provider didn't have a table by @@ -200,6 +200,7 @@ impl SchemaProvider for DynamicObjectStoreSchemaProvider { table_url.scheme(), url, &state.default_table_options(), + false, ) .await?; state.runtime_env().register_object_store(url, store); @@ -218,17 +219,18 @@ impl SchemaProvider for DynamicObjectStoreSchemaProvider { } pub fn substitute_tilde(cur: String) -> String { - if let Some(usr_dir_path) = home_dir() { - if let Some(usr_dir) = usr_dir_path.to_str() { - if cur.starts_with('~') && !usr_dir.is_empty() { - return cur.replacen('~', usr_dir, 1); - } - } + if let Some(usr_dir_path) = home_dir() + && let Some(usr_dir) = usr_dir_path.to_str() + && cur.starts_with('~') + && !usr_dir.is_empty() + { + return cur.replacen('~', usr_dir, 1); } cur } #[cfg(test)] mod tests { + use std::{env, vec}; use super::*; @@ -284,6 +286,19 @@ mod tests { #[tokio::test] async fn query_s3_location_test() -> Result<()> { + let aws_envs = vec![ + "AWS_ENDPOINT", + "AWS_ACCESS_KEY_ID", + "AWS_SECRET_ACCESS_KEY", + "AWS_ALLOW_HTTP", + ]; + for aws_env in aws_envs { + if env::var(aws_env).is_err() { + eprint!("aws envs not set, skipping s3 test"); + return Ok(()); + } + } + let bucket = "examples3bucket"; let location = format!("s3://{bucket}/file.parquet"); @@ -344,10 +359,12 @@ mod tests { } else { "/home/user" }; - env::set_var( - if cfg!(windows) { "USERPROFILE" } else { "HOME" }, - test_home_path, - ); + unsafe { + env::set_var( + if cfg!(windows) { "USERPROFILE" } else { "HOME" }, + test_home_path, + ); + } let input = "~/Code/datafusion/benchmarks/data/tpch_sf1/part/part-0.parquet"; let expected = PathBuf::from(test_home_path) .join("Code") @@ -361,12 
+378,16 @@ mod tests { .to_string(); let actual = substitute_tilde(input.to_string()); assert_eq!(actual, expected); - match original_home { - Some(home_path) => env::set_var( - if cfg!(windows) { "USERPROFILE" } else { "HOME" }, - home_path.to_str().unwrap(), - ), - None => env::remove_var(if cfg!(windows) { "USERPROFILE" } else { "HOME" }), + unsafe { + match original_home { + Some(home_path) => env::set_var( + if cfg!(windows) { "USERPROFILE" } else { "HOME" }, + home_path.to_str().unwrap(), + ), + None => { + env::remove_var(if cfg!(windows) { "USERPROFILE" } else { "HOME" }) + } + } } } } diff --git a/datafusion-cli/src/cli_context.rs b/datafusion-cli/src/cli_context.rs index 516929ebacf19..a6320f03fe4de 100644 --- a/datafusion-cli/src/cli_context.rs +++ b/datafusion-cli/src/cli_context.rs @@ -20,7 +20,7 @@ use std::sync::Arc; use datafusion::{ dataframe::DataFrame, error::DataFusionError, - execution::{context::SessionState, TaskContext}, + execution::{TaskContext, context::SessionState}, logical_expr::LogicalPlan, prelude::SessionContext, }; diff --git a/datafusion-cli/src/command.rs b/datafusion-cli/src/command.rs index 77bc8d3d20003..8aaa8025d1c3a 100644 --- a/datafusion-cli/src/command.rs +++ b/datafusion-cli/src/command.rs @@ -19,16 +19,16 @@ use crate::cli_context::CliSessionContext; use crate::exec::{exec_and_print, exec_from_lines}; -use crate::functions::{display_all_functions, Function}; +use crate::functions::{Function, display_all_functions}; use crate::print_format::PrintFormat; use crate::print_options::PrintOptions; use clap::ValueEnum; use datafusion::arrow::array::{ArrayRef, StringArray}; use datafusion::arrow::datatypes::{DataType, Field, Schema}; use datafusion::arrow::record_batch::RecordBatch; -use datafusion::common::exec_err; use datafusion::common::instant::Instant; -use datafusion::error::{DataFusionError, Result}; +use datafusion::common::{exec_datafusion_err, exec_err}; +use datafusion::error::Result; use std::fs::File; use 
std::io::BufReader; use std::str::FromStr; @@ -46,6 +46,7 @@ pub enum Command { SearchFunctions(String), QuietMode(Option), OutputFormat(Option), + ObjectStoreProfileMode(Option), } pub enum OutputFormat { @@ -84,9 +85,7 @@ impl Command { Self::Include(filename) => { if let Some(filename) = filename { let file = File::open(filename).map_err(|e| { - DataFusionError::Execution(format!( - "Error opening {filename:?} {e}" - )) + exec_datafusion_err!("Error opening {filename:?} {e}") })?; exec_from_lines(ctx, &mut BufReader::new(file), print_options) .await?; @@ -124,6 +123,29 @@ impl Command { Self::OutputFormat(_) => exec_err!( "Unexpected change output format, this should be handled outside" ), + Self::ObjectStoreProfileMode(mode) => { + if let Some(mode) = mode { + let profile_mode = mode + .parse() + .map_err(|_| + exec_datafusion_err!("Failed to parse input: {mode}. Valid options are disabled, summary, trace") + )?; + print_options + .instrumented_registry + .set_instrument_mode(profile_mode); + println!( + "ObjectStore Profile mode set to {}", + print_options.instrumented_registry.instrument_mode() + ); + } else { + println!( + "ObjectStore Profile mode is {}", + print_options.instrumented_registry.instrument_mode() + ); + } + + Ok(()) + } } } @@ -142,11 +164,15 @@ impl Command { Self::OutputFormat(_) => { ("\\pset [NAME [VALUE]]", "set table output option\n(format)") } + Self::ObjectStoreProfileMode(_) => ( + "\\object_store_profiling (disabled|summary|trace)", + "print or set object store profile mode", + ), } } } -const ALL_COMMANDS: [Command; 9] = [ +const ALL_COMMANDS: [Command; 10] = [ Command::ListTables, Command::DescribeTableStmt(String::new()), Command::Quit, @@ -156,6 +182,7 @@ const ALL_COMMANDS: [Command; 9] = [ Command::SearchFunctions(String::new()), Command::QuietMode(None), Command::OutputFormat(None), + Command::ObjectStoreProfileMode(None), ]; fn all_commands_info() -> RecordBatch { @@ -206,6 +233,10 @@ impl FromStr for Command { 
Self::OutputFormat(Some(subcommand.to_string())) } ("pset", None) => Self::OutputFormat(None), + ("object_store_profiling", Some(mode)) => { + Self::ObjectStoreProfileMode(Some(mode.to_string())) + } + ("object_store_profiling", None) => Self::ObjectStoreProfileMode(None), _ => return Err(()), }) } @@ -246,3 +277,62 @@ impl OutputFormat { } } } + +#[cfg(test)] +mod tests { + use datafusion::prelude::SessionContext; + + use crate::{ + object_storage::instrumented::{ + InstrumentedObjectStoreMode, InstrumentedObjectStoreRegistry, + }, + print_options::MaxRows, + }; + + use super::*; + + #[tokio::test] + async fn command_execute_profile_mode() { + let ctx = SessionContext::new(); + + let mut print_options = PrintOptions { + format: PrintFormat::Automatic, + quiet: false, + maxrows: MaxRows::Unlimited, + color: true, + instrumented_registry: Arc::new(InstrumentedObjectStoreRegistry::new()), + }; + + let mut cmd: Command = "object_store_profiling" + .parse() + .expect("expected parse to succeed"); + assert!(cmd.execute(&ctx, &mut print_options).await.is_ok()); + assert_eq!( + print_options.instrumented_registry.instrument_mode(), + InstrumentedObjectStoreMode::default() + ); + + cmd = "object_store_profiling summary" + .parse() + .expect("expected parse to succeed"); + assert!(cmd.execute(&ctx, &mut print_options).await.is_ok()); + assert_eq!( + print_options.instrumented_registry.instrument_mode(), + InstrumentedObjectStoreMode::Summary + ); + + cmd = "object_store_profiling trace" + .parse() + .expect("expected parse to succeed"); + assert!(cmd.execute(&ctx, &mut print_options).await.is_ok()); + assert_eq!( + print_options.instrumented_registry.instrument_mode(), + InstrumentedObjectStoreMode::Trace + ); + + cmd = "object_store_profiling does_not_exist" + .parse() + .expect("expected parse to succeed"); + assert!(cmd.execute(&ctx, &mut print_options).await.is_err()); + } +} diff --git a/datafusion-cli/src/exec.rs b/datafusion-cli/src/exec.rs index 
3c2a6e68bbe1b..09347d6d7dc2c 100644 --- a/datafusion-cli/src/exec.rs +++ b/datafusion-cli/src/exec.rs @@ -26,28 +26,28 @@ use crate::{ object_storage::get_object_store, print_options::{MaxRows, PrintOptions}, }; -use futures::StreamExt; -use std::collections::HashMap; -use std::fs::File; -use std::io::prelude::*; -use std::io::BufReader; - use datafusion::common::instant::Instant; use datafusion::common::{plan_datafusion_err, plan_err}; use datafusion::config::ConfigFileType; use datafusion::datasource::listing::ListingTableUrl; use datafusion::error::{DataFusionError, Result}; +use datafusion::execution::memory_pool::MemoryConsumer; use datafusion::logical_expr::{DdlStatement, LogicalPlan}; use datafusion::physical_plan::execution_plan::EmissionType; -use datafusion::physical_plan::{execute_stream, ExecutionPlanProperties}; -use datafusion::sql::parser::{DFParser, Statement}; -use datafusion::sql::sqlparser::dialect::dialect_from_str; - -use datafusion::execution::memory_pool::MemoryConsumer; use datafusion::physical_plan::spill::get_record_batch_memory_size; +use datafusion::physical_plan::{ExecutionPlanProperties, execute_stream}; +use datafusion::sql::parser::{DFParser, Statement}; use datafusion::sql::sqlparser; -use rustyline::error::ReadlineError; +use datafusion::sql::sqlparser::dialect::dialect_from_str; +use futures::StreamExt; +use log::warn; +use object_store::Error::Generic; use rustyline::Editor; +use rustyline::error::ReadlineError; +use std::collections::HashMap; +use std::fs::File; +use std::io::BufReader; +use std::io::prelude::*; use tokio::signal; /// run and execute SQL statements and commands, against a context with the given print options @@ -153,7 +153,7 @@ pub async fn exec_from_repl( } } else { eprintln!( - "'\\{}' is not a valid command", + "'\\{}' is not a valid command, you can use '\\?' to see all commands", &line[1..] 
); } @@ -168,7 +168,10 @@ pub async fn exec_from_repl( } } } else { - eprintln!("'\\{}' is not a valid command", &line[1..]); + eprintln!( + "'\\{}' is not a valid command, you can use '\\?' to see all commands", + &line[1..] + ); } } Ok(line) => { @@ -193,6 +196,7 @@ pub async fn exec_from_repl( } Err(ReadlineError::Interrupted) => { println!("^C"); + rl.helper().unwrap().reset_hint(); continue; } Err(ReadlineError::Eof) => { @@ -214,7 +218,6 @@ pub(super) async fn exec_and_print( print_options: &PrintOptions, sql: String, ) -> Result<()> { - let now = Instant::now(); let task_ctx = ctx.task_ctx(); let options = task_ctx.session_config().options(); let dialect = &options.sql_parser.dialect; @@ -228,17 +231,46 @@ pub(super) async fn exec_and_print( let statements = DFParser::parse_sql_with_dialect(&sql, dialect.as_ref())?; for statement in statements { - let adjusted = - AdjustedPrintOptions::new(print_options.clone()).with_statement(&statement); + StatementExecutor::new(statement) + .execute(ctx, print_options) + .await?; + } - let plan = create_plan(ctx, statement).await?; - let adjusted = adjusted.with_plan(&plan); + Ok(()) +} - let df = ctx.execute_logical_plan(plan).await?; +/// Executor for SQL statements, including special handling for S3 region detection retry logic +struct StatementExecutor { + statement: Statement, + statement_for_retry: Option, +} + +impl StatementExecutor { + fn new(statement: Statement) -> Self { + let statement_for_retry = matches!(statement, Statement::CreateExternalTable(_)) + .then(|| statement.clone()); + + Self { + statement, + statement_for_retry, + } + } + + async fn execute( + self, + ctx: &dyn CliSessionContext, + print_options: &PrintOptions, + ) -> Result<()> { + let now = Instant::now(); + let (df, adjusted) = self + .create_and_execute_logical_plan(ctx, print_options) + .await?; let physical_plan = df.create_physical_plan().await?; + let task_ctx = ctx.task_ctx(); + let options = task_ctx.session_config().options(); // 
Track memory usage for the query result if it's bounded - let mut reservation = + let reservation = MemoryConsumer::new("DataFusion-Cli").register(task_ctx.memory_pool()); if physical_plan.boundedness().is_unbounded() { @@ -269,7 +301,7 @@ pub(super) async fn exec_and_print( let curr_num_rows = batch.num_rows(); // Stop collecting results if the number of rows exceeds the limit // results batch should include the last batch that exceeds the limit - if row_count < max_rows + curr_num_rows { + if row_count < max_rows.saturating_add(curr_num_rows) { // Try to grow the reservation to accommodate the batch in memory reservation.try_grow(get_record_batch_memory_size(&batch))?; results.push(batch); @@ -285,9 +317,40 @@ pub(super) async fn exec_and_print( )?; reservation.free(); } + + Ok(()) } - Ok(()) + async fn create_and_execute_logical_plan( + mut self, + ctx: &dyn CliSessionContext, + print_options: &PrintOptions, + ) -> Result<(datafusion::dataframe::DataFrame, AdjustedPrintOptions)> { + let adjusted = AdjustedPrintOptions::new(print_options.clone()) + .with_statement(&self.statement); + + let plan = create_plan(ctx, self.statement, false).await?; + let adjusted = adjusted.with_plan(&plan); + + let df = match ctx.execute_logical_plan(plan).await { + Ok(df) => Ok(df), + Err(DataFusionError::ObjectStore(err)) + if matches!(err.as_ref(), Generic { store, source: _ } if "S3".eq_ignore_ascii_case(store)) + && self.statement_for_retry.is_some() => + { + warn!( + "S3 region is incorrect, auto-detecting the correct region (this may be slow). Consider updating your region configuration." 
+ ); + let plan = + create_plan(ctx, self.statement_for_retry.take().unwrap(), true) + .await?; + ctx.execute_logical_plan(plan).await + } + Err(e) => Err(e), + }?; + + Ok((df, adjusted)) + } } /// Track adjustments to the print options based on the plan / statement being executed @@ -348,6 +411,7 @@ fn config_file_type_from_str(ext: &str) -> Option { async fn create_plan( ctx: &dyn CliSessionContext, statement: Statement, + resolve_region: bool, ) -> Result { let mut plan = ctx.session_state().statement_to_plan(statement).await?; @@ -362,6 +426,7 @@ async fn create_plan( &cmd.location, &cmd.options, format, + resolve_region, ) .await?; } @@ -374,6 +439,7 @@ async fn create_plan( ©_to.output_url, ©_to.options, format, + false, ) .await?; } @@ -412,6 +478,7 @@ pub(crate) async fn register_object_store_and_config_extensions( location: &String, options: &HashMap, format: Option, + resolve_region: bool, ) -> Result<()> { // Parse the location URL to extract the scheme and other components let table_path = ListingTableUrl::parse(location)?; @@ -433,8 +500,14 @@ pub(crate) async fn register_object_store_and_config_extensions( table_options.alter_with_string_hash_map(options)?; // Retrieve the appropriate object store based on the scheme, URL, and modified table options - let store = - get_object_store(&ctx.session_state(), scheme, url, &table_options).await?; + let store = get_object_store( + &ctx.session_state(), + scheme, + url, + &table_options, + resolve_region, + ) + .await?; // Register the retrieved object store in the session context's runtime environment ctx.register_object_store(url, store); @@ -449,6 +522,7 @@ mod tests { use datafusion::common::plan_err; use datafusion::prelude::SessionContext; + use datafusion_common::assert_contains; use url::Url; async fn create_external_table_test(location: &str, sql: &str) -> Result<()> { @@ -462,6 +536,7 @@ mod tests { &cmd.location, &cmd.options, format, + false, ) .await?; } else { @@ -488,6 +563,7 @@ mod tests { 
&cmd.output_url, &cmd.options, format, + false, ) .await?; } else { @@ -513,6 +589,19 @@ mod tests { } #[tokio::test] async fn copy_to_external_object_store_test() -> Result<()> { + let aws_envs = vec![ + "AWS_ENDPOINT", + "AWS_ACCESS_KEY_ID", + "AWS_SECRET_ACCESS_KEY", + "AWS_ALLOW_HTTP", + ]; + for aws_env in aws_envs { + if std::env::var(aws_env).is_err() { + eprint!("aws envs not set, skipping s3 test"); + return Ok(()); + } + } + let locations = vec![ "s3://bucket/path/file.parquet", "oss://bucket/path/file.parquet", @@ -534,7 +623,7 @@ mod tests { let statements = DFParser::parse_sql_with_dialect(&sql, dialect.as_ref())?; for statement in statements { //Should not fail - let mut plan = create_plan(&ctx, statement).await?; + let mut plan = create_plan(&ctx, statement, false).await?; if let LogicalPlan::Copy(copy_to) = &mut plan { assert_eq!(copy_to.output_url, location); assert_eq!(copy_to.file_type.get_ext(), "parquet".to_string()); @@ -617,8 +706,7 @@ mod tests { #[tokio::test] async fn create_object_store_table_gcs() -> Result<()> { let service_account_path = "fake_service_account_path"; - let service_account_key = - "{\"private_key\": \"fake_private_key.pem\",\"client_email\":\"fake_client_email\", \"private_key_id\":\"id\"}"; + let service_account_key = "{\"private_key\": \"fake_private_key.pem\",\"client_email\":\"fake_client_email\", \"private_key_id\":\"id\"}"; let application_credentials_path = "fake_application_credentials_path"; let location = "gcs://bucket/path/file.parquet"; @@ -628,15 +716,16 @@ mod tests { let err = create_external_table_test(location, &sql) .await .unwrap_err(); - assert!(err.to_string().contains("os error 2")); + assert_contains!(err.to_string(), "os error 2"); // for service_account_key - let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET OPTIONS('gcp.service_account_key' '{service_account_key}') LOCATION '{location}'"); + let sql = format!( + "CREATE EXTERNAL TABLE test STORED AS PARQUET 
OPTIONS('gcp.service_account_key' '{service_account_key}') LOCATION '{location}'" + ); let err = create_external_table_test(location, &sql) .await - .unwrap_err() - .to_string(); - assert!(err.contains("No RSA key found in pem file"), "{err}"); + .unwrap_err(); + assert_contains!(err.to_string(), "Error reading pem file: no items found"); // for application_credentials_path let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET @@ -644,7 +733,7 @@ mod tests { let err = create_external_table_test(location, &sql) .await .unwrap_err(); - assert!(err.to_string().contains("os error 2")); + assert_contains!(err.to_string(), "os error 2"); Ok(()) } @@ -666,8 +755,9 @@ mod tests { let location = "path/to/file.cvs"; // Test with format options - let sql = - format!("CREATE EXTERNAL TABLE test STORED AS CSV LOCATION '{location}' OPTIONS('format.has_header' 'true')"); + let sql = format!( + "CREATE EXTERNAL TABLE test STORED AS CSV LOCATION '{location}' OPTIONS('format.has_header' 'true')" + ); create_external_table_test(location, &sql).await.unwrap(); Ok(()) diff --git a/datafusion-cli/src/functions.rs b/datafusion-cli/src/functions.rs index 911bbf34b06f4..67f3dc28269ef 100644 --- a/datafusion-cli/src/functions.rs +++ b/datafusion-cli/src/functions.rs @@ -17,20 +17,26 @@ //! 
Functions that are query-able and searchable via the `\h` command +use datafusion_common::instant::Instant; use std::fmt; use std::fs::File; use std::str::FromStr; use std::sync::Arc; -use arrow::array::{Int64Array, StringArray}; -use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use arrow::array::{ + DurationMillisecondArray, GenericListArray, Int64Array, StringArray, StructArray, + TimestampMillisecondArray, UInt64Array, +}; +use arrow::buffer::{Buffer, OffsetBuffer, ScalarBuffer}; +use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef, TimeUnit}; use arrow::record_batch::RecordBatch; use arrow::util::pretty::pretty_format_batches; use datafusion::catalog::{Session, TableFunctionImpl}; -use datafusion::common::{plan_err, Column}; -use datafusion::datasource::memory::MemorySourceConfig; +use datafusion::common::{Column, plan_err}; use datafusion::datasource::TableProvider; +use datafusion::datasource::memory::MemorySourceConfig; use datafusion::error::Result; +use datafusion::execution::cache::cache_manager::CacheManager; use datafusion::logical_expr::Expr; use datafusion::physical_plan::ExecutionPlan; use datafusion::scalar::ScalarValue; @@ -227,7 +233,7 @@ impl TableProvider for ParquetMetadataTable { self } - fn schema(&self) -> arrow::datatypes::SchemaRef { + fn schema(&self) -> SchemaRef { self.schema.clone() } @@ -322,7 +328,7 @@ pub struct ParquetMetadataFunc {} impl TableFunctionImpl for ParquetMetadataFunc { fn call(&self, exprs: &[Expr]) -> Result> { let filename = match exprs.first() { - Some(Expr::Literal(ScalarValue::Utf8(Some(s)))) => s, // single quote: parquet_metadata('x.parquet') + Some(Expr::Literal(ScalarValue::Utf8(Some(s)), _)) => s, // single quote: parquet_metadata('x.parquet') Some(Expr::Column(Column { name, .. 
})) => name, // double quote: parquet_metadata("x.parquet") _ => { return plan_err!( @@ -418,7 +424,9 @@ impl TableFunctionImpl for ParquetMetadataFunc { stats_max_value_arr.push(None); }; compression_arr.push(format!("{:?}", column.compression())); - encodings_arr.push(format!("{:?}", column.encodings())); + // need to collect into Vec to format + let encodings: Vec<_> = column.encodings().collect(); + encodings_arr.push(format!("{encodings:?}")); index_page_offset_arr.push(column.index_page_offset()); dictionary_page_offset_arr.push(column.dictionary_page_offset()); data_page_offset_arr.push(column.data_page_offset()); @@ -460,3 +468,416 @@ impl TableFunctionImpl for ParquetMetadataFunc { Ok(Arc::new(parquet_metadata)) } } + +/// METADATA_CACHE table function +#[derive(Debug)] +struct MetadataCacheTable { + schema: SchemaRef, + batch: RecordBatch, +} + +#[async_trait] +impl TableProvider for MetadataCacheTable { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + fn table_type(&self) -> datafusion::logical_expr::TableType { + datafusion::logical_expr::TableType::Base + } + + async fn scan( + &self, + _state: &dyn Session, + projection: Option<&Vec>, + _filters: &[Expr], + _limit: Option, + ) -> Result> { + Ok(MemorySourceConfig::try_new_exec( + &[vec![self.batch.clone()]], + TableProvider::schema(self), + projection.cloned(), + )?) 
+ } +} + +#[derive(Debug)] +pub struct MetadataCacheFunc { + cache_manager: Arc, +} + +impl MetadataCacheFunc { + pub fn new(cache_manager: Arc) -> Self { + Self { cache_manager } + } +} + +impl TableFunctionImpl for MetadataCacheFunc { + fn call(&self, exprs: &[Expr]) -> Result> { + if !exprs.is_empty() { + return plan_err!("metadata_cache should have no arguments"); + } + + let schema = Arc::new(Schema::new(vec![ + Field::new("path", DataType::Utf8, false), + Field::new( + "file_modified", + DataType::Timestamp(TimeUnit::Millisecond, None), + false, + ), + Field::new("file_size_bytes", DataType::UInt64, false), + Field::new("e_tag", DataType::Utf8, true), + Field::new("version", DataType::Utf8, true), + Field::new("metadata_size_bytes", DataType::UInt64, false), + Field::new("hits", DataType::UInt64, false), + Field::new("extra", DataType::Utf8, true), + ])); + + // construct record batch from metadata + let mut path_arr = vec![]; + let mut file_modified_arr = vec![]; + let mut file_size_bytes_arr = vec![]; + let mut e_tag_arr = vec![]; + let mut version_arr = vec![]; + let mut metadata_size_bytes = vec![]; + let mut hits_arr = vec![]; + let mut extra_arr = vec![]; + + let cached_entries = self.cache_manager.get_file_metadata_cache().list_entries(); + + for (path, entry) in cached_entries { + path_arr.push(path.to_string()); + file_modified_arr + .push(Some(entry.object_meta.last_modified.timestamp_millis())); + file_size_bytes_arr.push(entry.object_meta.size); + e_tag_arr.push(entry.object_meta.e_tag); + version_arr.push(entry.object_meta.version); + metadata_size_bytes.push(entry.size_bytes as u64); + hits_arr.push(entry.hits as u64); + + let mut extra = entry + .extra + .iter() + .map(|(k, v)| format!("{k}={v}")) + .collect::>(); + extra.sort(); + extra_arr.push(extra.join(" ")); + } + + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(StringArray::from(path_arr)), + Arc::new(TimestampMillisecondArray::from(file_modified_arr)), + 
Arc::new(UInt64Array::from(file_size_bytes_arr)), + Arc::new(StringArray::from(e_tag_arr)), + Arc::new(StringArray::from(version_arr)), + Arc::new(UInt64Array::from(metadata_size_bytes)), + Arc::new(UInt64Array::from(hits_arr)), + Arc::new(StringArray::from(extra_arr)), + ], + )?; + + let metadata_cache = MetadataCacheTable { schema, batch }; + Ok(Arc::new(metadata_cache)) + } +} + +/// STATISTICS_CACHE table function +#[derive(Debug)] +struct StatisticsCacheTable { + schema: SchemaRef, + batch: RecordBatch, +} + +#[async_trait] +impl TableProvider for StatisticsCacheTable { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + fn table_type(&self) -> datafusion::logical_expr::TableType { + datafusion::logical_expr::TableType::Base + } + + async fn scan( + &self, + _state: &dyn Session, + projection: Option<&Vec>, + _filters: &[Expr], + _limit: Option, + ) -> Result> { + Ok(MemorySourceConfig::try_new_exec( + &[vec![self.batch.clone()]], + TableProvider::schema(self), + projection.cloned(), + )?) 
+ } +} + +#[derive(Debug)] +pub struct StatisticsCacheFunc { + cache_manager: Arc, +} + +impl StatisticsCacheFunc { + pub fn new(cache_manager: Arc) -> Self { + Self { cache_manager } + } +} + +impl TableFunctionImpl for StatisticsCacheFunc { + fn call(&self, exprs: &[Expr]) -> Result> { + if !exprs.is_empty() { + return plan_err!("statistics_cache should have no arguments"); + } + + let schema = Arc::new(Schema::new(vec![ + Field::new("path", DataType::Utf8, false), + Field::new( + "file_modified", + DataType::Timestamp(TimeUnit::Millisecond, None), + false, + ), + Field::new("file_size_bytes", DataType::UInt64, false), + Field::new("e_tag", DataType::Utf8, true), + Field::new("version", DataType::Utf8, true), + Field::new("num_rows", DataType::Utf8, false), + Field::new("num_columns", DataType::UInt64, false), + Field::new("table_size_bytes", DataType::Utf8, false), + Field::new("statistics_size_bytes", DataType::UInt64, false), + ])); + + // construct record batch from metadata + let mut path_arr = vec![]; + let mut file_modified_arr = vec![]; + let mut file_size_bytes_arr = vec![]; + let mut e_tag_arr = vec![]; + let mut version_arr = vec![]; + let mut num_rows_arr = vec![]; + let mut num_columns_arr = vec![]; + let mut table_size_bytes_arr = vec![]; + let mut statistics_size_bytes_arr = vec![]; + + if let Some(file_statistics_cache) = self.cache_manager.get_file_statistic_cache() + { + for (path, entry) in file_statistics_cache.list_entries() { + path_arr.push(path.to_string()); + file_modified_arr + .push(Some(entry.object_meta.last_modified.timestamp_millis())); + file_size_bytes_arr.push(entry.object_meta.size); + e_tag_arr.push(entry.object_meta.e_tag); + version_arr.push(entry.object_meta.version); + num_rows_arr.push(entry.num_rows.to_string()); + num_columns_arr.push(entry.num_columns as u64); + table_size_bytes_arr.push(entry.table_size_bytes.to_string()); + statistics_size_bytes_arr.push(entry.statistics_size_bytes as u64); + } + } + + let batch = 
RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(StringArray::from(path_arr)), + Arc::new(TimestampMillisecondArray::from(file_modified_arr)), + Arc::new(UInt64Array::from(file_size_bytes_arr)), + Arc::new(StringArray::from(e_tag_arr)), + Arc::new(StringArray::from(version_arr)), + Arc::new(StringArray::from(num_rows_arr)), + Arc::new(UInt64Array::from(num_columns_arr)), + Arc::new(StringArray::from(table_size_bytes_arr)), + Arc::new(UInt64Array::from(statistics_size_bytes_arr)), + ], + )?; + + let statistics_cache = StatisticsCacheTable { schema, batch }; + Ok(Arc::new(statistics_cache)) + } +} + +/// Implementation of the `list_files_cache` table function in datafusion-cli. +/// +/// This function returns the cached results of running a LIST command on a +/// particular object store path for a table. The object metadata is returned as +/// a List of Structs, with one Struct for each object. DataFusion uses these +/// cached results to plan queries against external tables. +/// +/// # Schema +/// ```sql +/// > describe select * from list_files_cache(); +/// +---------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------+-------------+ +/// | column_name | data_type | is_nullable | +/// +---------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------+-------------+ +/// | table | Utf8 | NO | +/// | path | Utf8 | NO | +/// | metadata_size_bytes | UInt64 | NO | +/// | expires_in | Duration(ms) | YES | +/// | metadata_list | List(Struct("file_path": non-null Utf8, "file_modified": non-null Timestamp(ms), "file_size_bytes": non-null UInt64, "e_tag": Utf8, "version": Utf8), field: 'metadata') | YES | +/// 
+---------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------+-------------+ +/// ``` +#[derive(Debug)] +struct ListFilesCacheTable { + schema: SchemaRef, + batch: RecordBatch, +} + +#[async_trait] +impl TableProvider for ListFilesCacheTable { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + fn table_type(&self) -> datafusion::logical_expr::TableType { + datafusion::logical_expr::TableType::Base + } + + async fn scan( + &self, + _state: &dyn Session, + projection: Option<&Vec>, + _filters: &[Expr], + _limit: Option, + ) -> Result> { + Ok(MemorySourceConfig::try_new_exec( + &[vec![self.batch.clone()]], + TableProvider::schema(self), + projection.cloned(), + )?) + } +} + +#[derive(Debug)] +pub struct ListFilesCacheFunc { + cache_manager: Arc, +} + +impl ListFilesCacheFunc { + pub fn new(cache_manager: Arc) -> Self { + Self { cache_manager } + } +} + +impl TableFunctionImpl for ListFilesCacheFunc { + fn call(&self, exprs: &[Expr]) -> Result> { + if !exprs.is_empty() { + return plan_err!("list_files_cache should have no arguments"); + } + + let nested_fields = Fields::from(vec![ + Field::new("file_path", DataType::Utf8, false), + Field::new( + "file_modified", + DataType::Timestamp(TimeUnit::Millisecond, None), + false, + ), + Field::new("file_size_bytes", DataType::UInt64, false), + Field::new("e_tag", DataType::Utf8, true), + Field::new("version", DataType::Utf8, true), + ]); + + let metadata_field = + Field::new("metadata", DataType::Struct(nested_fields.clone()), true); + + let schema = Arc::new(Schema::new(vec![ + Field::new("table", DataType::Utf8, true), + Field::new("path", DataType::Utf8, false), + Field::new("metadata_size_bytes", DataType::UInt64, false), + // expires field in ListFilesEntry has type Instant when set, from which we cannot get "the number of 
seconds", hence using Duration instead of Timestamp as data type. + Field::new( + "expires_in", + DataType::Duration(TimeUnit::Millisecond), + true, + ), + Field::new( + "metadata_list", + DataType::List(Arc::new(metadata_field.clone())), + true, + ), + ])); + + let mut table_arr = vec![]; + let mut path_arr = vec![]; + let mut metadata_size_bytes_arr = vec![]; + let mut expires_arr = vec![]; + + let mut file_path_arr = vec![]; + let mut file_modified_arr = vec![]; + let mut file_size_bytes_arr = vec![]; + let mut etag_arr = vec![]; + let mut version_arr = vec![]; + let mut offsets: Vec = vec![0]; + + if let Some(list_files_cache) = self.cache_manager.get_list_files_cache() { + let now = Instant::now(); + let mut current_offset: i32 = 0; + + for (path, entry) in list_files_cache.list_entries() { + table_arr.push(path.table.map(|t| t.to_string())); + path_arr.push(path.path.to_string()); + metadata_size_bytes_arr.push(entry.size_bytes as u64); + // calculates time left before entry expires + expires_arr.push( + entry + .expires + .map(|t| t.duration_since(now).as_millis() as i64), + ); + + for meta in entry.metas.files.iter() { + file_path_arr.push(meta.location.to_string()); + file_modified_arr.push(meta.last_modified.timestamp_millis()); + file_size_bytes_arr.push(meta.size); + etag_arr.push(meta.e_tag.clone()); + version_arr.push(meta.version.clone()); + } + current_offset += entry.metas.files.len() as i32; + offsets.push(current_offset); + } + } + + let struct_arr = StructArray::new( + nested_fields, + vec![ + Arc::new(StringArray::from(file_path_arr)), + Arc::new(TimestampMillisecondArray::from(file_modified_arr)), + Arc::new(UInt64Array::from(file_size_bytes_arr)), + Arc::new(StringArray::from(etag_arr)), + Arc::new(StringArray::from(version_arr)), + ], + None, + ); + + let offsets_buffer: OffsetBuffer = + OffsetBuffer::new(ScalarBuffer::from(Buffer::from_vec(offsets))); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + 
Arc::new(StringArray::from(table_arr)), + Arc::new(StringArray::from(path_arr)), + Arc::new(UInt64Array::from(metadata_size_bytes_arr)), + Arc::new(DurationMillisecondArray::from(expires_arr)), + Arc::new(GenericListArray::new( + Arc::new(metadata_field), + offsets_buffer, + Arc::new(struct_arr), + None, + )), + ], + )?; + + let list_files_cache = ListFilesCacheTable { schema, batch }; + Ok(Arc::new(list_files_cache)) + } +} diff --git a/datafusion-cli/src/helper.rs b/datafusion-cli/src/helper.rs index 64c34c4737369..f01d0891b964c 100644 --- a/datafusion-cli/src/helper.rs +++ b/datafusion-cli/src/helper.rs @@ -19,11 +19,13 @@ //! and auto-completion for file name during creating external table. use std::borrow::Cow; +use std::cell::Cell; -use crate::highlighter::{NoSyntaxHighlighter, SyntaxHighlighter}; +use crate::highlighter::{Color, NoSyntaxHighlighter, SyntaxHighlighter}; use datafusion::sql::parser::{DFParser, Statement}; use datafusion::sql::sqlparser::dialect::dialect_from_str; +use datafusion_common::config::Dialect; use rustyline::completion::{Completer, FilenameCompleter, Pair}; use rustyline::error::ReadlineError; @@ -32,14 +34,21 @@ use rustyline::hint::Hinter; use rustyline::validate::{ValidationContext, ValidationResult, Validator}; use rustyline::{Context, Helper, Result}; +/// Default suggestion shown when the input line is empty. +const DEFAULT_HINT_SUGGESTION: &str = " \\? for help, \\q to quit"; + pub struct CliHelper { completer: FilenameCompleter, - dialect: String, + dialect: Dialect, highlighter: Box, + /// Tracks whether to show the default hint. Set to `false` once the user + /// types anything, so the hint doesn't reappear after deleting back to + /// an empty line. Reset to `true` when the line is submitted. 
+ show_hint: Cell, } impl CliHelper { - pub fn new(dialect: &str, color: bool) -> Self { + pub fn new(dialect: &Dialect, color: bool) -> Self { let highlighter: Box = if !color { Box::new(NoSyntaxHighlighter {}) } else { @@ -47,26 +56,32 @@ impl CliHelper { }; Self { completer: FilenameCompleter::new(), - dialect: dialect.into(), + dialect: *dialect, highlighter, + show_hint: Cell::new(true), } } - pub fn set_dialect(&mut self, dialect: &str) { - if dialect != self.dialect { - self.dialect = dialect.to_string(); + pub fn set_dialect(&mut self, dialect: &Dialect) { + if *dialect != self.dialect { + self.dialect = *dialect; } } + /// Re-enable the default hint for the next prompt. + pub fn reset_hint(&self) { + self.show_hint.set(true); + } + fn validate_input(&self, input: &str) -> Result { if let Some(sql) = input.strip_suffix(';') { - let dialect = match dialect_from_str(&self.dialect) { + let dialect = match dialect_from_str(self.dialect) { Some(dialect) => dialect, None => { return Ok(ValidationResult::Invalid(Some(format!( " 🤔 Invalid dialect: {}", self.dialect - )))) + )))); } }; let lines = split_from_semicolon(sql); @@ -97,7 +112,7 @@ impl CliHelper { impl Default for CliHelper { fn default() -> Self { - Self::new("generic", false) + Self::new(&Dialect::Generic, false) } } @@ -113,6 +128,14 @@ impl Highlighter for CliHelper { impl Hinter for CliHelper { type Hint = String; + + fn hint(&self, line: &str, _pos: usize, _ctx: &Context<'_>) -> Option { + if !line.is_empty() { + self.show_hint.set(false); + } + (self.show_hint.get() && line.trim().is_empty()) + .then(|| Color::gray(DEFAULT_HINT_SUGGESTION)) + } } /// returns true if the current position is after the open quote for @@ -120,12 +143,9 @@ impl Hinter for CliHelper { fn is_open_quote_for_location(line: &str, pos: usize) -> bool { let mut sql = line[..pos].to_string(); sql.push('\''); - if let Ok(stmts) = DFParser::parse_sql(&sql) { - if let Some(Statement::CreateExternalTable(_)) = stmts.back() { - 
return true; - } - } - false + DFParser::parse_sql(&sql).is_ok_and(|stmts| { + matches!(stmts.back(), Some(Statement::CreateExternalTable(_))) + }) } impl Completer for CliHelper { @@ -148,7 +168,9 @@ impl Completer for CliHelper { impl Validator for CliHelper { fn validate(&self, ctx: &mut ValidationContext<'_>) -> Result { let input = ctx.input().trim_end(); - self.validate_input(input) + let result = self.validate_input(input); + self.reset_hint(); + result } } @@ -289,7 +311,7 @@ mod tests { ); // valid in postgresql dialect - validator.set_dialect("postgresql"); + validator.set_dialect(&Dialect::PostgreSQL); let result = readline_direct(Cursor::new(r"select 1 # 2;".as_bytes()), &validator)?; assert!(matches!(result, ValidationResult::Valid(None))); diff --git a/datafusion-cli/src/highlighter.rs b/datafusion-cli/src/highlighter.rs index 7a886b94740bd..adcb135bb401f 100644 --- a/datafusion-cli/src/highlighter.rs +++ b/datafusion-cli/src/highlighter.rs @@ -23,10 +23,11 @@ use std::{ }; use datafusion::sql::sqlparser::{ - dialect::{dialect_from_str, Dialect, GenericDialect}, + dialect::{Dialect, GenericDialect, dialect_from_str}, keywords::Keyword, tokenizer::{Token, Tokenizer}, }; +use datafusion_common::config; use rustyline::highlight::{CmdKind, Highlighter}; /// The syntax highlighter. @@ -36,8 +37,9 @@ pub struct SyntaxHighlighter { } impl SyntaxHighlighter { - pub fn new(dialect: &str) -> Self { - let dialect = dialect_from_str(dialect).unwrap_or(Box::new(GenericDialect {})); + pub fn new(dialect: &config::Dialect) -> Self { + let dialect = + dialect_from_str(dialect).unwrap_or_else(|| Box::new(GenericDialect {})); Self { dialect } } } @@ -79,27 +81,32 @@ impl Highlighter for SyntaxHighlighter { } /// Convenient utility to return strings with [ANSI color](https://gist.github.com/JBlond/2fea43a3049b38287e5e9cefc87b2124). 
-struct Color {} +pub(crate) struct Color {} impl Color { - fn green(s: impl Display) -> String { + pub(crate) fn green(s: impl Display) -> String { format!("\x1b[92m{s}\x1b[0m") } - fn red(s: impl Display) -> String { + pub(crate) fn red(s: impl Display) -> String { format!("\x1b[91m{s}\x1b[0m") } + + pub(crate) fn gray(s: impl Display) -> String { + format!("\x1b[90m{s}\x1b[0m") + } } #[cfg(test)] mod tests { use super::SyntaxHighlighter; + use super::config::Dialect; use rustyline::highlight::Highlighter; #[test] fn highlighter_valid() { let s = "SElect col_a from tab_1;"; - let highlighter = SyntaxHighlighter::new("generic"); + let highlighter = SyntaxHighlighter::new(&Dialect::Generic); let out = highlighter.highlight(s, s.len()); assert_eq!( "\u{1b}[91mSElect\u{1b}[0m col_a \u{1b}[91mfrom\u{1b}[0m tab_1;", @@ -110,7 +117,7 @@ mod tests { #[test] fn highlighter_valid_with_new_line() { let s = "SElect col_a from tab_1\n WHERE col_b = 'なにか';"; - let highlighter = SyntaxHighlighter::new("generic"); + let highlighter = SyntaxHighlighter::new(&Dialect::Generic); let out = highlighter.highlight(s, s.len()); assert_eq!( "\u{1b}[91mSElect\u{1b}[0m col_a \u{1b}[91mfrom\u{1b}[0m tab_1\n \u{1b}[91mWHERE\u{1b}[0m col_b = \u{1b}[92m'なにか'\u{1b}[0m;", @@ -121,7 +128,7 @@ mod tests { #[test] fn highlighter_invalid() { let s = "SElect col_a from tab_1 WHERE col_b = ';"; - let highlighter = SyntaxHighlighter::new("generic"); + let highlighter = SyntaxHighlighter::new(&Dialect::Generic); let out = highlighter.highlight(s, s.len()); assert_eq!("SElect col_a from tab_1 WHERE col_b = ';", out); } diff --git a/datafusion-cli/src/lib.rs b/datafusion-cli/src/lib.rs index 34fba6f79304b..f0b0bc23fd73d 100644 --- a/datafusion-cli/src/lib.rs +++ b/datafusion-cli/src/lib.rs @@ -19,7 +19,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = 
"https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] #![doc = include_str!("../README.md")] pub const DATAFUSION_CLI_VERSION: &str = env!("CARGO_PKG_VERSION"); diff --git a/datafusion-cli/src/main.rs b/datafusion-cli/src/main.rs index fdecb185e33e4..6bfe1160ecdd6 100644 --- a/datafusion-cli/src/main.rs +++ b/datafusion-cli/src/main.rs @@ -28,15 +28,20 @@ use datafusion::execution::memory_pool::{ FairSpillPool, GreedyMemoryPool, MemoryPool, TrackConsumersPool, }; use datafusion::execution::runtime_env::RuntimeEnvBuilder; +use datafusion::logical_expr::ExplainFormat; use datafusion::prelude::SessionContext; use datafusion_cli::catalog::DynamicObjectStoreCatalog; -use datafusion_cli::functions::ParquetMetadataFunc; +use datafusion_cli::functions::{ + ListFilesCacheFunc, MetadataCacheFunc, ParquetMetadataFunc, StatisticsCacheFunc, +}; +use datafusion_cli::object_storage::instrumented::{ + InstrumentedObjectStoreMode, InstrumentedObjectStoreRegistry, +}; use datafusion_cli::{ - exec, + DATAFUSION_CLI_VERSION, exec, pool_type::PoolType, print_format::PrintFormat, print_options::{MaxRows, PrintOptions}, - DATAFUSION_CLI_VERSION, }; use clap::Parser; @@ -144,6 +149,13 @@ struct Args { value_parser(extract_disk_limit) )] disk_limit: Option, + + #[clap( + long, + help = "Specify the default object_store_profiling mode, defaults to 'disabled'.\n[possible values: disabled, summary, trace]", + default_value_t = InstrumentedObjectStoreMode::Disabled + )] + object_store_profiling: InstrumentedObjectStoreMode, } #[tokio::main] @@ -205,6 +217,12 @@ async fn main_inner() -> Result<()> { rt_builder = rt_builder.with_disk_manager_builder(builder); } + let instrumented_registry = Arc::new( + InstrumentedObjectStoreRegistry::new() + .with_profile_mode(args.object_store_profiling), + ); + rt_builder = 
rt_builder.with_object_store_registry(instrumented_registry.clone()); + let runtime_env = rt_builder.build_arc()?; // enable dynamic file query @@ -219,11 +237,35 @@ async fn main_inner() -> Result<()> { // register `parquet_metadata` table function to get metadata from parquet files ctx.register_udtf("parquet_metadata", Arc::new(ParquetMetadataFunc {})); + // register `metadata_cache` table function to get the contents of the file metadata cache + ctx.register_udtf( + "metadata_cache", + Arc::new(MetadataCacheFunc::new( + ctx.task_ctx().runtime_env().cache_manager.clone(), + )), + ); + + // register `statistics_cache` table function to get the contents of the file statistics cache + ctx.register_udtf( + "statistics_cache", + Arc::new(StatisticsCacheFunc::new( + ctx.task_ctx().runtime_env().cache_manager.clone(), + )), + ); + + ctx.register_udtf( + "list_files_cache", + Arc::new(ListFilesCacheFunc::new( + ctx.task_ctx().runtime_env().cache_manager.clone(), + )), + ); + let mut print_options = PrintOptions { format: args.format, quiet: args.quiet, maxrows: args.maxrows, color: args.color, + instrumented_registry: Arc::clone(&instrumented_registry), }; let commands = args.command; @@ -280,7 +322,7 @@ fn get_session_config(args: &Args) -> Result { // use easier to understand "tree" mode by default // if the user hasn't specified an explain format in the environment if env::var_os("DATAFUSION_EXPLAIN_FORMAT").is_none() { - config_options.explain.format = String::from("tree"); + config_options.explain.format = ExplainFormat::Tree; } // in the CLI, we want to show NULL values rather the empty strings @@ -396,9 +438,20 @@ pub fn extract_disk_limit(size: &str) -> Result { #[cfg(test)] mod tests { + use std::time::Duration; + use super::*; - use datafusion::common::test_util::batches_to_string; + use datafusion::{ + common::test_util::batches_to_string, + execution::cache::{ + DefaultListFilesCache, cache_manager::CacheManagerConfig, + 
cache_unit::DefaultFileStatisticsCache, + }, + prelude::{ParquetReadOptions, col, lit, split_part}, + }; use insta::assert_snapshot; + use object_store::memory::InMemory; + use url::Url; fn assert_conversion(input: &str, expected: Result) { let result = extract_memory_pool_size(input); @@ -462,8 +515,7 @@ mod tests { ctx.register_udtf("parquet_metadata", Arc::new(ParquetMetadataFunc {})); // input with single quote - let sql = - "SELECT * FROM parquet_metadata('../datafusion/core/tests/data/fixed_size_list_array.parquet')"; + let sql = "SELECT * FROM parquet_metadata('../datafusion/core/tests/data/fixed_size_list_array.parquet')"; let df = ctx.sql(sql).await?; let rbs = df.collect().await?; @@ -471,20 +523,19 @@ mod tests { +-------------------------------------------------------------+--------------+--------------------+-----------------------+-----------------+-----------+-------------+------------+----------------+-------+-----------+-----------+------------------+----------------------+-----------------+-----------------+-------------+------------------------------+-------------------+------------------------+------------------+-----------------------+-------------------------+ | filename | row_group_id | row_group_num_rows | row_group_num_columns | row_group_bytes | column_id | file_offset | num_values | path_in_schema | type | stats_min | stats_max | stats_null_count | stats_distinct_count | stats_min_value | stats_max_value | compression | encodings | index_page_offset | dictionary_page_offset | data_page_offset | total_compressed_size | total_uncompressed_size | 
+-------------------------------------------------------------+--------------+--------------------+-----------------------+-----------------+-----------+-------------+------------+----------------+-------+-----------+-----------+------------------+----------------------+-----------------+-----------------+-------------+------------------------------+-------------------+------------------------+------------------+-----------------------+-------------------------+ - | ../datafusion/core/tests/data/fixed_size_list_array.parquet | 0 | 2 | 1 | 123 | 0 | 125 | 4 | "f0.list.item" | INT64 | 1 | 4 | 0 | | 1 | 4 | SNAPPY | [RLE_DICTIONARY, PLAIN, RLE] | | 4 | 46 | 121 | 123 | + | ../datafusion/core/tests/data/fixed_size_list_array.parquet | 0 | 2 | 1 | 123 | 0 | 125 | 4 | "f0.list.item" | INT64 | 1 | 4 | 0 | | 1 | 4 | SNAPPY | [PLAIN, RLE, RLE_DICTIONARY] | | 4 | 46 | 121 | 123 | +-------------------------------------------------------------+--------------+--------------------+-----------------------+-----------------+-----------+-------------+------------+----------------+-------+-----------+-----------+------------------+----------------------+-----------------+-----------------+-------------+------------------------------+-------------------+------------------------+------------------+-----------------------+-------------------------+ "#); // input with double quote - let sql = - "SELECT * FROM parquet_metadata(\"../datafusion/core/tests/data/fixed_size_list_array.parquet\")"; + let sql = "SELECT * FROM parquet_metadata(\"../datafusion/core/tests/data/fixed_size_list_array.parquet\")"; let df = ctx.sql(sql).await?; let rbs = df.collect().await?; assert_snapshot!(batches_to_string(&rbs), @r#" 
+-------------------------------------------------------------+--------------+--------------------+-----------------------+-----------------+-----------+-------------+------------+----------------+-------+-----------+-----------+------------------+----------------------+-----------------+-----------------+-------------+------------------------------+-------------------+------------------------+------------------+-----------------------+-------------------------+ | filename | row_group_id | row_group_num_rows | row_group_num_columns | row_group_bytes | column_id | file_offset | num_values | path_in_schema | type | stats_min | stats_max | stats_null_count | stats_distinct_count | stats_min_value | stats_max_value | compression | encodings | index_page_offset | dictionary_page_offset | data_page_offset | total_compressed_size | total_uncompressed_size | +-------------------------------------------------------------+--------------+--------------------+-----------------------+-----------------+-----------+-------------+------------+----------------+-------+-----------+-----------+------------------+----------------------+-----------------+-----------------+-------------+------------------------------+-------------------+------------------------+------------------+-----------------------+-------------------------+ - | ../datafusion/core/tests/data/fixed_size_list_array.parquet | 0 | 2 | 1 | 123 | 0 | 125 | 4 | "f0.list.item" | INT64 | 1 | 4 | 0 | | 1 | 4 | SNAPPY | [RLE_DICTIONARY, PLAIN, RLE] | | 4 | 46 | 121 | 123 | + | ../datafusion/core/tests/data/fixed_size_list_array.parquet | 0 | 2 | 1 | 123 | 0 | 125 | 4 | "f0.list.item" | INT64 | 1 | 4 | 0 | | 1 | 4 | SNAPPY | [PLAIN, RLE, RLE_DICTIONARY] | | 4 | 46 | 121 | 123 | 
+-------------------------------------------------------------+--------------+--------------------+-----------------------+-----------------+-----------+-------------+------------+----------------+-------+-----------+-----------+------------------+----------------------+-----------------+-----------------+-------------+------------------------------+-------------------+------------------------+------------------+-----------------------+-------------------------+ "#); @@ -497,8 +548,7 @@ mod tests { ctx.register_udtf("parquet_metadata", Arc::new(ParquetMetadataFunc {})); // input with string columns - let sql = - "SELECT * FROM parquet_metadata('../parquet-testing/data/data_index_bloom_encoding_stats.parquet')"; + let sql = "SELECT * FROM parquet_metadata('../parquet-testing/data/data_index_bloom_encoding_stats.parquet')"; let df = ctx.sql(sql).await?; let rbs = df.collect().await?; @@ -506,10 +556,296 @@ mod tests { +-----------------------------------------------------------------+--------------+--------------------+-----------------------+-----------------+-----------+-------------+------------+----------------+------------+-----------+-----------+------------------+----------------------+-----------------+-----------------+--------------------+--------------------------+-------------------+------------------------+------------------+-----------------------+-------------------------+ | filename | row_group_id | row_group_num_rows | row_group_num_columns | row_group_bytes | column_id | file_offset | num_values | path_in_schema | type | stats_min | stats_max | stats_null_count | stats_distinct_count | stats_min_value | stats_max_value | compression | encodings | index_page_offset | dictionary_page_offset | data_page_offset | total_compressed_size | total_uncompressed_size | 
+-----------------------------------------------------------------+--------------+--------------------+-----------------------+-----------------+-----------+-------------+------------+----------------+------------+-----------+-----------+------------------+----------------------+-----------------+-----------------+--------------------+--------------------------+-------------------+------------------------+------------------+-----------------------+-------------------------+ - | ../parquet-testing/data/data_index_bloom_encoding_stats.parquet | 0 | 14 | 1 | 163 | 0 | 4 | 14 | "String" | BYTE_ARRAY | Hello | today | 0 | | Hello | today | GZIP(GzipLevel(6)) | [BIT_PACKED, RLE, PLAIN] | | | 4 | 152 | 163 | + | ../parquet-testing/data/data_index_bloom_encoding_stats.parquet | 0 | 14 | 1 | 163 | 0 | 4 | 14 | "String" | BYTE_ARRAY | Hello | today | 0 | | Hello | today | GZIP(GzipLevel(6)) | [PLAIN, RLE, BIT_PACKED] | | | 4 | 152 | 163 | +-----------------------------------------------------------------+--------------+--------------------+-----------------------+-----------------+-----------+-------------+------------+----------------+------------+-----------+-----------+------------------+----------------------+-----------------+-----------------+--------------------+--------------------------+-------------------+------------------------+------------------+-----------------------+-------------------------+ "#); Ok(()) } + + #[tokio::test] + async fn test_metadata_cache() -> Result<(), DataFusionError> { + let ctx = SessionContext::new(); + ctx.register_udtf( + "metadata_cache", + Arc::new(MetadataCacheFunc::new( + ctx.task_ctx().runtime_env().cache_manager.clone(), + )), + ); + + ctx.register_parquet( + "alltypes_plain", + "../parquet-testing/data/alltypes_plain.parquet", + ParquetReadOptions::new(), + ) + .await?; + + ctx.register_parquet( + "alltypes_tiny_pages", + "../parquet-testing/data/alltypes_tiny_pages.parquet", + ParquetReadOptions::new(), + ) + .await?; + + 
ctx.register_parquet( + "lz4_raw_compressed_larger", + "../parquet-testing/data/lz4_raw_compressed_larger.parquet", + ParquetReadOptions::new(), + ) + .await?; + + ctx.sql("select * from alltypes_plain") + .await? + .collect() + .await?; + ctx.sql("select * from alltypes_tiny_pages") + .await? + .collect() + .await?; + ctx.sql("select * from lz4_raw_compressed_larger") + .await? + .collect() + .await?; + + // initial state + let sql = "SELECT split_part(path, '/', -1) as filename, file_size_bytes, metadata_size_bytes, hits, extra from metadata_cache() order by filename"; + let df = ctx.sql(sql).await?; + let rbs = df.collect().await?; + + assert_snapshot!(batches_to_string(&rbs),@r" + +-----------------------------------+-----------------+---------------------+------+------------------+ + | filename | file_size_bytes | metadata_size_bytes | hits | extra | + +-----------------------------------+-----------------+---------------------+------+------------------+ + | alltypes_plain.parquet | 1851 | 8882 | 2 | page_index=false | + | alltypes_tiny_pages.parquet | 454233 | 269074 | 2 | page_index=true | + | lz4_raw_compressed_larger.parquet | 380836 | 1339 | 2 | page_index=false | + +-----------------------------------+-----------------+---------------------+------+------------------+ + "); + + // increase the number of hits + ctx.sql("select * from alltypes_plain") + .await? + .collect() + .await?; + ctx.sql("select * from alltypes_plain") + .await? + .collect() + .await?; + ctx.sql("select * from alltypes_plain") + .await? + .collect() + .await?; + ctx.sql("select * from lz4_raw_compressed_larger") + .await? 
+ .collect() + .await?; + let sql = "select split_part(path, '/', -1) as filename, file_size_bytes, metadata_size_bytes, hits, extra from metadata_cache() order by filename"; + let df = ctx.sql(sql).await?; + let rbs = df.collect().await?; + + assert_snapshot!(batches_to_string(&rbs),@r" + +-----------------------------------+-----------------+---------------------+------+------------------+ + | filename | file_size_bytes | metadata_size_bytes | hits | extra | + +-----------------------------------+-----------------+---------------------+------+------------------+ + | alltypes_plain.parquet | 1851 | 8882 | 5 | page_index=false | + | alltypes_tiny_pages.parquet | 454233 | 269074 | 2 | page_index=true | + | lz4_raw_compressed_larger.parquet | 380836 | 1339 | 3 | page_index=false | + +-----------------------------------+-----------------+---------------------+------+------------------+ + "); + + Ok(()) + } + + /// Shows that the statistics cache is not enabled by default yet + /// See https://github.com/apache/datafusion/issues/19217 + #[tokio::test] + async fn test_statistics_cache_default() -> Result<(), DataFusionError> { + let ctx = SessionContext::new(); + + ctx.register_udtf( + "statistics_cache", + Arc::new(StatisticsCacheFunc::new( + ctx.task_ctx().runtime_env().cache_manager.clone(), + )), + ); + + for filename in [ + "alltypes_plain", + "alltypes_tiny_pages", + "lz4_raw_compressed_larger", + ] { + ctx.sql( + format!( + "create external table {filename} + stored as parquet + location '../parquet-testing/data/{filename}.parquet'", + ) + .as_str(), + ) + .await? 
+ .collect() + .await?; + } + + // When the cache manager creates a StatisticsCache by default, + // the contents will show up here + let sql = "SELECT split_part(path, '/', -1) as filename, file_size_bytes, num_rows, num_columns, table_size_bytes from statistics_cache() order by filename"; + let df = ctx.sql(sql).await?; + let rbs = df.collect().await?; + assert_snapshot!(batches_to_string(&rbs),@r" + ++ + ++ + "); + + Ok(()) + } + + // Can be removed when https://github.com/apache/datafusion/issues/19217 is resolved + #[tokio::test] + async fn test_statistics_cache_override() -> Result<(), DataFusionError> { + // Install a specific StatisticsCache implementation + let file_statistics_cache = Arc::new(DefaultFileStatisticsCache::default()); + let cache_config = CacheManagerConfig::default() + .with_files_statistics_cache(Some(file_statistics_cache.clone())); + let runtime = RuntimeEnvBuilder::new() + .with_cache_manager(cache_config) + .build()?; + let config = SessionConfig::new().with_collect_statistics(true); + let ctx = SessionContext::new_with_config_rt(config, Arc::new(runtime)); + + ctx.register_udtf( + "statistics_cache", + Arc::new(StatisticsCacheFunc::new( + ctx.task_ctx().runtime_env().cache_manager.clone(), + )), + ); + + for filename in [ + "alltypes_plain", + "alltypes_tiny_pages", + "lz4_raw_compressed_larger", + ] { + ctx.sql( + format!( + "create external table {filename} + stored as parquet + location '../parquet-testing/data/{filename}.parquet'", + ) + .as_str(), + ) + .await? 
+ .collect() + .await?; + } + + let sql = "SELECT split_part(path, '/', -1) as filename, file_size_bytes, num_rows, num_columns, table_size_bytes from statistics_cache() order by filename"; + let df = ctx.sql(sql).await?; + let rbs = df.collect().await?; + assert_snapshot!(batches_to_string(&rbs),@r" + +-----------------------------------+-----------------+--------------+-------------+------------------+ + | filename | file_size_bytes | num_rows | num_columns | table_size_bytes | + +-----------------------------------+-----------------+--------------+-------------+------------------+ + | alltypes_plain.parquet | 1851 | Exact(8) | 11 | Absent | + | alltypes_tiny_pages.parquet | 454233 | Exact(7300) | 13 | Absent | + | lz4_raw_compressed_larger.parquet | 380836 | Exact(10000) | 1 | Absent | + +-----------------------------------+-----------------+--------------+-------------+------------------+ + "); + + Ok(()) + } + + #[tokio::test] + async fn test_list_files_cache() -> Result<(), DataFusionError> { + let list_files_cache = Arc::new(DefaultListFilesCache::new( + 1024, + Some(Duration::from_secs(1)), + )); + + let rt = RuntimeEnvBuilder::new() + .with_cache_manager( + CacheManagerConfig::default() + .with_list_files_cache(Some(list_files_cache)), + ) + .build_arc() + .unwrap(); + + let ctx = SessionContext::new_with_config_rt(SessionConfig::default(), rt); + + ctx.register_object_store( + &Url::parse("mem://test_table").unwrap(), + Arc::new(InMemory::new()), + ); + + ctx.register_udtf( + "list_files_cache", + Arc::new(ListFilesCacheFunc::new( + ctx.task_ctx().runtime_env().cache_manager.clone(), + )), + ); + + ctx.sql( + "CREATE EXTERNAL TABLE src_table + STORED AS PARQUET + LOCATION '../parquet-testing/data/alltypes_plain.parquet'", + ) + .await? 
+ .collect() + .await?; + + ctx.sql("COPY (SELECT * FROM src_table) TO 'mem://test_table/0.parquet' STORED AS PARQUET").await?.collect().await?; + + ctx.sql("COPY (SELECT * FROM src_table) TO 'mem://test_table/1.parquet' STORED AS PARQUET").await?.collect().await?; + + ctx.sql( + "CREATE EXTERNAL TABLE test_table + STORED AS PARQUET + LOCATION 'mem://test_table/' + ", + ) + .await? + .collect() + .await?; + + let sql = "SELECT metadata_size_bytes, expires_in, metadata_list FROM list_files_cache()"; + let df = ctx + .sql(sql) + .await? + .unnest_columns(&["metadata_list"])? + .with_column_renamed("metadata_list", "metadata")? + .unnest_columns(&["metadata"])?; + + assert_eq!( + 2, + df.clone() + .filter(col("expires_in").is_not_null())? + .count() + .await? + ); + + let df = df + .with_column_renamed(r#""metadata.file_size_bytes""#, "file_size_bytes")? + .with_column_renamed(r#""metadata.e_tag""#, "etag")? + .with_column( + "filename", + split_part(col(r#""metadata.file_path""#), lit("/"), lit(-1)), + )? + .select_columns(&[ + "metadata_size_bytes", + "filename", + "file_size_bytes", + "etag", + ])? + .sort(vec![col("filename").sort(true, false)])?; + let rbs = df.collect().await?; + assert_snapshot!(batches_to_string(&rbs),@r" + +---------------------+-----------+-----------------+------+ + | metadata_size_bytes | filename | file_size_bytes | etag | + +---------------------+-----------+-----------------+------+ + | 212 | 0.parquet | 3642 | 0 | + | 212 | 1.parquet | 3642 | 1 | + +---------------------+-----------+-----------------+------+ + "); + + Ok(()) + } } diff --git a/datafusion-cli/src/object_storage.rs b/datafusion-cli/src/object_storage.rs index c31310093ac6b..34787838929f1 100644 --- a/datafusion-cli/src/object_storage.rs +++ b/datafusion-cli/src/object_storage.rs @@ -15,29 +15,70 @@ // specific language governing permissions and limitations // under the License. 
-use std::any::Any; -use std::fmt::{Debug, Display}; -use std::sync::Arc; - -use datafusion::common::config::{ - ConfigEntry, ConfigExtension, ConfigField, ExtensionOptions, TableOptions, Visit, -}; -use datafusion::common::{config_err, exec_datafusion_err, exec_err}; -use datafusion::error::{DataFusionError, Result}; -use datafusion::execution::context::SessionState; +pub mod instrumented; use async_trait::async_trait; use aws_config::BehaviorVersion; -use aws_credential_types::provider::ProvideCredentials; -use object_store::aws::{AmazonS3Builder, AwsCredential}; -use object_store::gcp::GoogleCloudStorageBuilder; -use object_store::http::HttpBuilder; -use object_store::{ClientOptions, CredentialProvider, ObjectStore}; +use aws_credential_types::provider::{ + ProvideCredentials, SharedCredentialsProvider, error::CredentialsError, +}; +use datafusion::{ + common::{ + config::ConfigEntry, config::ConfigExtension, config::ConfigField, + config::ExtensionOptions, config::TableOptions, config::Visit, config_err, + exec_datafusion_err, exec_err, + }, + error::{DataFusionError, Result}, + execution::context::SessionState, +}; +use log::debug; +use object_store::{ + ClientOptions, CredentialProvider, + Error::Generic, + ObjectStore, + aws::{AmazonS3Builder, AmazonS3ConfigKey, AwsCredential}, + gcp::GoogleCloudStorageBuilder, + http::HttpBuilder, +}; +use std::{ + any::Any, + error::Error, + fmt::{Debug, Display}, + sync::Arc, +}; use url::Url; +#[cfg(not(test))] +use object_store::aws::resolve_bucket_region; + +// Provide a local mock when running tests so we don't make network calls +#[cfg(test)] +async fn resolve_bucket_region( + _bucket: &str, + _client_options: &ClientOptions, +) -> object_store::Result { + Ok("eu-central-1".to_string()) +} + pub async fn get_s3_object_store_builder( url: &Url, aws_options: &AwsOptions, + resolve_region: bool, +) -> Result { + // Box the inner future to reduce the future size of this async function, + // which is deeply nested in the 
CLI's async call chain. + Box::pin(get_s3_object_store_builder_inner( + url, + aws_options, + resolve_region, + )) + .await +} + +async fn get_s3_object_store_builder_inner( + url: &Url, + aws_options: &AwsOptions, + resolve_region: bool, ) -> Result { let AwsOptions { access_key_id, @@ -46,6 +87,7 @@ pub async fn get_s3_object_store_builder( region, endpoint, allow_http, + skip_signature, } = aws_options; let bucket_name = get_bucket_name(url)?; @@ -54,6 +96,7 @@ pub async fn get_s3_object_store_builder( if let (Some(access_key_id), Some(secret_access_key)) = (access_key_id, secret_access_key) { + debug!("Using explicitly provided S3 access_key_id and secret_access_key"); builder = builder .with_access_key_id(access_key_id) .with_secret_access_key(secret_access_key); @@ -62,40 +105,49 @@ pub async fn get_s3_object_store_builder( builder = builder.with_token(session_token); } } else { - let config = aws_config::defaults(BehaviorVersion::latest()).load().await; - if let Some(region) = config.region() { - builder = builder.with_region(region.to_string()); + debug!("Using AWS S3 SDK to determine credentials"); + let CredentialsFromConfig { + region, + credentials, + } = CredentialsFromConfig::try_new().await?; + if let Some(region) = region { + builder = builder.with_region(region); + } + if let Some(credentials) = credentials { + let credentials = Arc::new(S3CredentialProvider { credentials }); + builder = builder.with_credentials(credentials); + } else { + debug!("No credentials found, defaulting to skip signature "); + builder = builder.with_skip_signature(true); } - - let credentials = config - .credentials_provider() - .ok_or_else(|| { - DataFusionError::ObjectStore(object_store::Error::Generic { - store: "S3", - source: "Failed to get S3 credentials from the environment".into(), - }) - })? 
- .clone(); - - let credentials = Arc::new(S3CredentialProvider { credentials }); - builder = builder.with_credentials(credentials); } if let Some(region) = region { builder = builder.with_region(region); } + // If the region is not set or `resolve_region` is true, resolve the region. + if builder + .get_config_value(&AmazonS3ConfigKey::Region) + .is_none() + || resolve_region + { + let region = resolve_bucket_region(bucket_name, &ClientOptions::new()).await?; + builder = builder.with_region(region); + } + if let Some(endpoint) = endpoint { // Make a nicer error if the user hasn't allowed http and the endpoint // is http as the default message is "URL scheme is not allowed" - if let Ok(endpoint_url) = Url::try_from(endpoint.as_str()) { - if !matches!(allow_http, Some(true)) && endpoint_url.scheme() == "http" { - return config_err!( - "Invalid endpoint: {endpoint}. \ + if let Ok(endpoint_url) = Url::try_from(endpoint.as_str()) + && !matches!(allow_http, Some(true)) + && endpoint_url.scheme() == "http" + { + return config_err!( + "Invalid endpoint: {endpoint}. \ HTTP is not allowed for S3 endpoints.
\ To allow HTTP, set 'aws.allow_http' to true" - ); - } + ); } builder = builder.with_endpoint(endpoint); @@ -105,12 +157,74 @@ pub async fn get_s3_object_store_builder( builder = builder.with_allow_http(*allow_http); } + if let Some(skip_signature) = skip_signature { + builder = builder.with_skip_signature(*skip_signature); + } + Ok(builder) } +/// Credentials from the AWS SDK +struct CredentialsFromConfig { + region: Option, + credentials: Option, +} + +impl CredentialsFromConfig { + /// Attempt to find AWS S3 credentials via the AWS SDK + pub async fn try_new() -> Result { + let config = aws_config::defaults(BehaviorVersion::latest()).load().await; + let region = config.region().map(|r| r.to_string()); + + let credentials = config + .credentials_provider() + .ok_or_else(|| { + DataFusionError::ObjectStore(Box::new(Generic { + store: "S3", + source: "Failed to get S3 credentials aws_config".into(), + })) + })? + .clone(); + + // The credential provider is lazy, so it does not fetch credentials + // until they are needed. To ensure that the credentials are valid, + // we can call `provide_credentials` here.
+ let credentials = match credentials.provide_credentials().await { + Ok(_) => Some(credentials), + Err(CredentialsError::CredentialsNotLoaded(_)) => { + debug!("Could not use AWS SDK to get credentials"); + None + } + // other errors like `CredentialsError::InvalidConfiguration` + // should be returned to the user so they can be fixed + Err(e) => { + // Pass back underlying error to the user, including underlying source + let source_message = if let Some(source) = e.source() { + format!(": {source}") + } else { + String::new() + }; + + let message = format!( + "Error getting credentials from provider: {e}{source_message}", + ); + + return Err(DataFusionError::ObjectStore(Box::new(Generic { + store: "S3", + source: message.into(), + }))); + } + }; + Ok(Self { + region, + credentials, + }) + } +} + #[derive(Debug)] struct S3CredentialProvider { - credentials: aws_credential_types::provider::SharedCredentialsProvider, + credentials: SharedCredentialsProvider, } #[async_trait] @@ -118,12 +232,14 @@ impl CredentialProvider for S3CredentialProvider { type Credential = AwsCredential; async fn get_credential(&self) -> object_store::Result> { - let creds = self.credentials.provide_credentials().await.map_err(|e| { - object_store::Error::Generic { - store: "S3", - source: Box::new(e), - } - })?; + let creds = + self.credentials + .provide_credentials() + .await + .map_err(|e| Generic { + store: "S3", + source: Box::new(e), + })?; Ok(Arc::new(AwsCredential { key_id: creds.access_key_id().to_string(), secret_key: creds.secret_access_key().to_string(), @@ -197,10 +313,7 @@ pub fn get_gcs_object_store_builder( fn get_bucket_name(url: &Url) -> Result<&str> { url.host_str().ok_or_else(|| { - DataFusionError::Execution(format!( - "Not able to parse bucket name from url: {}", - url.as_str() - )) + exec_datafusion_err!("Not able to parse bucket name from url: {}", url.as_str()) }) } @@ -219,6 +332,11 @@ pub struct AwsOptions { pub endpoint: Option, /// Allow HTTP (otherwise will 
always use https) pub allow_http: Option, + /// Do not fetch credentials and do not sign requests + /// + /// This can be useful when interacting with public S3 buckets that deny + /// authorized requests + pub skip_signature: Option, } impl ExtensionOptions for AwsOptions { @@ -256,6 +374,9 @@ impl ExtensionOptions for AwsOptions { "allow_http" => { self.allow_http.set(rem, value)?; } + "skip_signature" | "nosign" => { + self.skip_signature.set(rem, value)?; + } _ => { return config_err!("Config value \"{}\" not found on AwsOptions", rem); } @@ -397,6 +518,7 @@ pub(crate) async fn get_object_store( scheme: &str, url: &Url, table_options: &TableOptions, + resolve_region: bool, ) -> Result, DataFusionError> { let store: Arc = match scheme { "s3" => { @@ -405,7 +527,8 @@ pub(crate) async fn get_object_store( "Given table options incompatible with the 's3' scheme" ); }; - let builder = get_s3_object_store_builder(url, options).await?; + let builder = + get_s3_object_store_builder(url, options, resolve_region).await?; Arc::new(builder.build()?) 
} "oss" => { @@ -461,7 +584,6 @@ mod tests { use super::*; - use datafusion::common::plan_err; use datafusion::{ datasource::listing::ListingTableUrl, logical_expr::{DdlStatement, LogicalPlan}, @@ -470,6 +592,74 @@ mod tests { use object_store::{aws::AmazonS3ConfigKey, gcp::GoogleConfigKey}; + #[tokio::test] + async fn s3_object_store_builder_default() -> Result<()> { + if let Err(DataFusionError::Execution(e)) = check_aws_envs().await { + // Skip test if AWS envs are not set + eprintln!("{e}"); + return Ok(()); + } + + let location = "s3://bucket/path/FAKE/file.parquet"; + // Set it to a non-existent file to avoid reading the default configuration file + unsafe { + std::env::set_var("AWS_CONFIG_FILE", "data/aws.config"); + std::env::set_var("AWS_SHARED_CREDENTIALS_FILE", "data/aws.credentials"); + } + + // No options + let table_url = ListingTableUrl::parse(location)?; + let scheme = table_url.scheme(); + let sql = + format!("CREATE EXTERNAL TABLE test STORED AS PARQUET LOCATION '{location}'"); + + let ctx = SessionContext::new(); + ctx.register_table_options_extension_from_scheme(scheme); + let table_options = get_table_options(&ctx, &sql).await; + let aws_options = table_options.extensions.get::().unwrap(); + let builder = + get_s3_object_store_builder(table_url.as_ref(), aws_options, false).await?; + + // If the environment variables are set (as they are in CI) use them + let expected_access_key_id = std::env::var("AWS_ACCESS_KEY_ID").ok(); + let expected_secret_access_key = std::env::var("AWS_SECRET_ACCESS_KEY").ok(); + let expected_region = Some( + std::env::var("AWS_REGION").unwrap_or_else(|_| "eu-central-1".to_string()), + ); + let expected_endpoint = std::env::var("AWS_ENDPOINT").ok(); + + // get the actual configuration information, then assert_eq! 
+ assert_eq!( + builder.get_config_value(&AmazonS3ConfigKey::AccessKeyId), + expected_access_key_id + ); + assert_eq!( + builder.get_config_value(&AmazonS3ConfigKey::SecretAccessKey), + expected_secret_access_key + ); + // Default is to skip signature when no credentials are provided + let expected_skip_signature = + if expected_access_key_id.is_none() && expected_secret_access_key.is_none() { + Some(String::from("true")) + } else { + Some(String::from("false")) + }; + assert_eq!( + builder.get_config_value(&AmazonS3ConfigKey::Region), + expected_region + ); + assert_eq!( + builder.get_config_value(&AmazonS3ConfigKey::Endpoint), + expected_endpoint + ); + assert_eq!(builder.get_config_value(&AmazonS3ConfigKey::Token), None); + assert_eq!( + builder.get_config_value(&AmazonS3ConfigKey::SkipSignature), + expected_skip_signature + ); + Ok(()) + } + #[tokio::test] async fn s3_object_store_builder() -> Result<()> { // "fake" is uppercase to ensure the values are not lowercased when parsed @@ -493,29 +683,27 @@ mod tests { ); let ctx = SessionContext::new(); - let mut plan = ctx.state().create_logical_plan(&sql).await?; - - if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan { - ctx.register_table_options_extension_from_scheme(scheme); - let mut table_options = ctx.state().default_table_options(); - table_options.alter_with_string_hash_map(&cmd.options)?; - let aws_options = table_options.extensions.get::().unwrap(); - let builder = - get_s3_object_store_builder(table_url.as_ref(), aws_options).await?; - // get the actual configuration information, then assert_eq! 
- let config = [ - (AmazonS3ConfigKey::AccessKeyId, access_key_id), - (AmazonS3ConfigKey::SecretAccessKey, secret_access_key), - (AmazonS3ConfigKey::Region, region), - (AmazonS3ConfigKey::Endpoint, endpoint), - (AmazonS3ConfigKey::Token, session_token), - ]; - for (key, value) in config { - assert_eq!(value, builder.get_config_value(&key).unwrap()); - } - } else { - return plan_err!("LogicalPlan is not a CreateExternalTable"); + ctx.register_table_options_extension_from_scheme(scheme); + let table_options = get_table_options(&ctx, &sql).await; + let aws_options = table_options.extensions.get::().unwrap(); + let builder = + get_s3_object_store_builder(table_url.as_ref(), aws_options, false).await?; + // get the actual configuration information, then assert_eq! + let config = [ + (AmazonS3ConfigKey::AccessKeyId, access_key_id), + (AmazonS3ConfigKey::SecretAccessKey, secret_access_key), + (AmazonS3ConfigKey::Region, region), + (AmazonS3ConfigKey::Endpoint, endpoint), + (AmazonS3ConfigKey::Token, session_token), + ]; + for (key, value) in config { + assert_eq!(value, builder.get_config_value(&key).unwrap()); } + // Should not skip signature when credentials are provided + assert_eq!( + builder.get_config_value(&AmazonS3ConfigKey::SkipSignature), + Some("false".into()) + ); Ok(()) } @@ -538,21 +726,18 @@ mod tests { ); let ctx = SessionContext::new(); - let mut plan = ctx.state().create_logical_plan(&sql).await?; - - if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan { - ctx.register_table_options_extension_from_scheme(scheme); - let mut table_options = ctx.state().default_table_options(); - table_options.alter_with_string_hash_map(&cmd.options)?; - let aws_options = table_options.extensions.get::().unwrap(); - let err = get_s3_object_store_builder(table_url.as_ref(), aws_options) - .await - .unwrap_err(); + ctx.register_table_options_extension_from_scheme(scheme); - assert_eq!(err.to_string().lines().next().unwrap_or_default(), "Invalid or 
Unsupported Configuration: Invalid endpoint: http://endpoint33. HTTP is not allowed for S3 endpoints. To allow HTTP, set 'aws.allow_http' to true"); - } else { - return plan_err!("LogicalPlan is not a CreateExternalTable"); - } + let table_options = get_table_options(&ctx, &sql).await; + let aws_options = table_options.extensions.get::().unwrap(); + let err = get_s3_object_store_builder(table_url.as_ref(), aws_options, false) + .await + .unwrap_err(); + + assert_eq!( + err.to_string().lines().next().unwrap_or_default(), + "Invalid or Unsupported Configuration: Invalid endpoint: http://endpoint33. HTTP is not allowed for S3 endpoints. To allow HTTP, set 'aws.allow_http' to true" + ); // Now add `allow_http` to the options and check if it works let sql = format!( @@ -563,20 +748,75 @@ mod tests { 'aws.allow_http' 'true'\ ) LOCATION '{location}'" ); + let table_options = get_table_options(&ctx, &sql).await; - let mut plan = ctx.state().create_logical_plan(&sql).await?; + let aws_options = table_options.extensions.get::().unwrap(); + // ensure this isn't an error + get_s3_object_store_builder(table_url.as_ref(), aws_options, false).await?; - if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan { - ctx.register_table_options_extension_from_scheme(scheme); - let mut table_options = ctx.state().default_table_options(); - table_options.alter_with_string_hash_map(&cmd.options)?; - let aws_options = table_options.extensions.get::().unwrap(); - // ensure this isn't an error - get_s3_object_store_builder(table_url.as_ref(), aws_options).await?; - } else { - return plan_err!("LogicalPlan is not a CreateExternalTable"); + Ok(()) + } + + #[tokio::test] + async fn s3_object_store_builder_resolves_region_when_none_provided() -> Result<()> { + if let Err(DataFusionError::Execution(e)) = check_aws_envs().await { + // Skip test if AWS envs are not set + eprintln!("{e}"); + return Ok(()); + } + let location = "s3://test-bucket/path/file.parquet"; + // Set it to a 
non-existent file to avoid reading the default configuration file + unsafe { + std::env::set_var("AWS_CONFIG_FILE", "data/aws.config"); + } + + let table_url = ListingTableUrl::parse(location)?; + let aws_options = AwsOptions { + region: None, // No region specified - should auto-detect + ..Default::default() + }; + + let builder = + get_s3_object_store_builder(table_url.as_ref(), &aws_options, false).await?; + + // Verify that the region was auto-detected in test environment + assert!( + builder + .get_config_value(&AmazonS3ConfigKey::Region) + .is_some() + ); + + Ok(()) + } + + #[tokio::test] + async fn s3_object_store_builder_overrides_region_when_resolve_region_enabled() + -> Result<()> { + if let Err(DataFusionError::Execution(e)) = check_aws_envs().await { + // Skip test if AWS envs are not set + eprintln!("{e}"); + return Ok(()); } + let original_region = "us-east-1"; + let expected_region = "eu-central-1"; // This should be the auto-detected region + let location = "s3://test-bucket/path/file.parquet"; + + let table_url = ListingTableUrl::parse(location)?; + let aws_options = AwsOptions { + region: Some(original_region.to_string()), // Explicit region provided + ..Default::default() + }; + + let builder = + get_s3_object_store_builder(table_url.as_ref(), &aws_options, true).await?; + + // Verify that the region was overridden by auto-detection + assert_eq!( + builder.get_config_value(&AmazonS3ConfigKey::Region), + Some(expected_region.to_string()) + ); + Ok(()) } @@ -589,28 +829,24 @@ mod tests { let table_url = ListingTableUrl::parse(location)?; let scheme = table_url.scheme(); - let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET OPTIONS('aws.access_key_id' '{access_key_id}', 'aws.secret_access_key' '{secret_access_key}', 'aws.oss.endpoint' '{endpoint}') LOCATION '{location}'"); + let sql = format!( + "CREATE EXTERNAL TABLE test STORED AS PARQUET OPTIONS('aws.access_key_id' '{access_key_id}', 'aws.secret_access_key' '{secret_access_key}', 
'aws.oss.endpoint' '{endpoint}') LOCATION '{location}'" + ); let ctx = SessionContext::new(); - let mut plan = ctx.state().create_logical_plan(&sql).await?; - - if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan { - ctx.register_table_options_extension_from_scheme(scheme); - let mut table_options = ctx.state().default_table_options(); - table_options.alter_with_string_hash_map(&cmd.options)?; - let aws_options = table_options.extensions.get::().unwrap(); - let builder = get_oss_object_store_builder(table_url.as_ref(), aws_options)?; - // get the actual configuration information, then assert_eq! - let config = [ - (AmazonS3ConfigKey::AccessKeyId, access_key_id), - (AmazonS3ConfigKey::SecretAccessKey, secret_access_key), - (AmazonS3ConfigKey::Endpoint, endpoint), - ]; - for (key, value) in config { - assert_eq!(value, builder.get_config_value(&key).unwrap()); - } - } else { - return plan_err!("LogicalPlan is not a CreateExternalTable"); + ctx.register_table_options_extension_from_scheme(scheme); + let table_options = get_table_options(&ctx, &sql).await; + + let aws_options = table_options.extensions.get::().unwrap(); + let builder = get_oss_object_store_builder(table_url.as_ref(), aws_options)?; + // get the actual configuration information, then assert_eq! 
+ let config = [ + (AmazonS3ConfigKey::AccessKeyId, access_key_id), + (AmazonS3ConfigKey::SecretAccessKey, secret_access_key), + (AmazonS3ConfigKey::Endpoint, endpoint), + ]; + for (key, value) in config { + assert_eq!(value, builder.get_config_value(&key).unwrap()); } Ok(()) @@ -619,40 +855,66 @@ mod tests { #[tokio::test] async fn gcs_object_store_builder() -> Result<()> { let service_account_path = "fake_service_account_path"; - let service_account_key = - "{\"private_key\": \"fake_private_key.pem\",\"client_email\":\"fake_client_email\"}"; + let service_account_key = "{\"private_key\": \"fake_private_key.pem\",\"client_email\":\"fake_client_email\"}"; let application_credentials_path = "fake_application_credentials_path"; let location = "gcs://bucket/path/file.parquet"; let table_url = ListingTableUrl::parse(location)?; let scheme = table_url.scheme(); - let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET OPTIONS('gcp.service_account_path' '{service_account_path}', 'gcp.service_account_key' '{service_account_key}', 'gcp.application_credentials_path' '{application_credentials_path}') LOCATION '{location}'"); + let sql = format!( + "CREATE EXTERNAL TABLE test STORED AS PARQUET OPTIONS('gcp.service_account_path' '{service_account_path}', 'gcp.service_account_key' '{service_account_key}', 'gcp.application_credentials_path' '{application_credentials_path}') LOCATION '{location}'" + ); let ctx = SessionContext::new(); - let mut plan = ctx.state().create_logical_plan(&sql).await?; - - if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan { - ctx.register_table_options_extension_from_scheme(scheme); - let mut table_options = ctx.state().default_table_options(); - table_options.alter_with_string_hash_map(&cmd.options)?; - let gcp_options = table_options.extensions.get::().unwrap(); - let builder = get_gcs_object_store_builder(table_url.as_ref(), gcp_options)?; - // get the actual configuration information, then assert_eq! 
- let config = [ - (GoogleConfigKey::ServiceAccount, service_account_path), - (GoogleConfigKey::ServiceAccountKey, service_account_key), - ( - GoogleConfigKey::ApplicationCredentials, - application_credentials_path, - ), - ]; - for (key, value) in config { - assert_eq!(value, builder.get_config_value(&key).unwrap()); - } - } else { - return plan_err!("LogicalPlan is not a CreateExternalTable"); + ctx.register_table_options_extension_from_scheme(scheme); + let table_options = get_table_options(&ctx, &sql).await; + + let gcp_options = table_options.extensions.get::().unwrap(); + let builder = get_gcs_object_store_builder(table_url.as_ref(), gcp_options)?; + // get the actual configuration information, then assert_eq! + let config = [ + (GoogleConfigKey::ServiceAccount, service_account_path), + (GoogleConfigKey::ServiceAccountKey, service_account_key), + ( + GoogleConfigKey::ApplicationCredentials, + application_credentials_path, + ), + ]; + for (key, value) in config { + assert_eq!(value, builder.get_config_value(&key).unwrap()); } Ok(()) } + + /// Plans the `CREATE EXTERNAL TABLE` SQL statement and returns the + /// default table options updated with that command's options.
+ async fn get_table_options(ctx: &SessionContext, sql: &str) -> TableOptions { + let mut plan = ctx.state().create_logical_plan(sql).await.unwrap(); + + let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan else { + panic!("plan is not a CreateExternalTable"); + }; + + let mut table_options = ctx.state().default_table_options(); + table_options + .alter_with_string_hash_map(&cmd.options) + .unwrap(); + table_options + } + + async fn check_aws_envs() -> Result<()> { + let aws_envs = [ + "AWS_ACCESS_KEY_ID", + "AWS_SECRET_ACCESS_KEY", + "AWS_REGION", + "AWS_ALLOW_HTTP", + ]; + for aws_env in aws_envs { + std::env::var(aws_env).map_err(|_| { + exec_datafusion_err!("aws envs not set, skipping s3 tests") + })?; + } + Ok(()) + } } diff --git a/datafusion-cli/src/object_storage/instrumented.rs b/datafusion-cli/src/object_storage/instrumented.rs new file mode 100644 index 0000000000000..a0321cacb374b --- /dev/null +++ b/datafusion-cli/src/object_storage/instrumented.rs @@ -0,0 +1,1388 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::{ + cmp, fmt, + ops::AddAssign, + str::FromStr, + sync::{ + Arc, + atomic::{AtomicU8, AtomicU64, Ordering}, + }, + time::Duration, +}; + +use arrow::array::{ArrayRef, RecordBatch, StringArray}; +use arrow::util::pretty::pretty_format_batches; +use async_trait::async_trait; +use chrono::Utc; +use datafusion::{ + common::{HashMap, instant::Instant}, + error::DataFusionError, + execution::object_store::{DefaultObjectStoreRegistry, ObjectStoreRegistry}, +}; +use futures::stream::{BoxStream, Stream}; +use futures::{StreamExt, TryStreamExt}; +use object_store::{ + CopyOptions, GetOptions, GetRange, GetResult, ListResult, MultipartUpload, + ObjectMeta, ObjectStore, ObjectStoreExt, PutMultipartOptions, PutOptions, PutPayload, + PutResult, Result, path::Path, +}; +use parking_lot::{Mutex, RwLock}; +use url::Url; + +/// A stream wrapper that measures the time until the first response (item or end of stream) is yielded. +/// +/// The timer starts on the first `poll_next` call (not at stream creation) to avoid +/// measuring unrelated work between stream creation and first poll. +/// Duration is stored as nanoseconds in an `AtomicU64` (0 = not yet set).
+struct TimeToFirstItemStream { + inner: S, + start: Option, + request_duration: Arc, + duration_recorded: bool, +} + +impl TimeToFirstItemStream { + fn new(inner: S, request_duration: Arc) -> Self { + Self { + inner, + start: None, + request_duration, + duration_recorded: false, + } + } +} + +impl Stream for TimeToFirstItemStream +where + S: Stream> + Unpin, +{ + type Item = Result; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let start = *self.start.get_or_insert_with(Instant::now); + + let poll_result = std::pin::Pin::new(&mut self.inner).poll_next(cx); + + if !self.duration_recorded && poll_result.is_ready() { + self.duration_recorded = true; + let nanos = start.elapsed().as_nanos() as u64; + self.request_duration.store(nanos, Ordering::Release); + } + + poll_result + } +} + +/// The profiling mode to use for an [`InstrumentedObjectStore`] instance. Collecting profiling +/// data will have a small negative impact on both CPU and memory usage. 
Default is `Disabled` +#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)] +pub enum InstrumentedObjectStoreMode { + /// Disable collection of profiling data + #[default] + Disabled, + /// Enable collection of profiling data and output a summary + Summary, + /// Enable collection of profiling data and output a summary and all details + Trace, +} + +impl fmt::Display for InstrumentedObjectStoreMode { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{self:?}") + } +} + +impl FromStr for InstrumentedObjectStoreMode { + type Err = DataFusionError; + + fn from_str(s: &str) -> std::result::Result { + match s.to_lowercase().as_str() { + "disabled" => Ok(Self::Disabled), + "summary" => Ok(Self::Summary), + "trace" => Ok(Self::Trace), + _ => Err(DataFusionError::Execution(format!("Unrecognized mode {s}"))), + } + } +} + +impl From for InstrumentedObjectStoreMode { + fn from(value: u8) -> Self { + match value { + 1 => InstrumentedObjectStoreMode::Summary, + 2 => InstrumentedObjectStoreMode::Trace, + _ => InstrumentedObjectStoreMode::Disabled, + } + } +} + +/// Wrapped [`ObjectStore`] instances that record information for reporting on the usage of the +/// inner [`ObjectStore`] +#[derive(Debug)] +pub struct InstrumentedObjectStore { + inner: Arc, + instrument_mode: AtomicU8, + requests: Arc>>, +} + +impl InstrumentedObjectStore { + /// Returns a new [`InstrumentedObjectStore`] that wraps the provided [`ObjectStore`] + fn new(object_store: Arc, instrument_mode: AtomicU8) -> Self { + Self { + inner: object_store, + instrument_mode, + requests: Arc::new(Mutex::new(Vec::new())), + } + } + + fn set_instrument_mode(&self, mode: InstrumentedObjectStoreMode) { + self.instrument_mode.store(mode as u8, Ordering::Relaxed) + } + + /// Returns all [`RequestDetails`] accumulated in this [`InstrumentedObjectStore`] and clears + /// the stored requests + pub fn take_requests(&self) -> Vec { + let mut req = self.requests.lock(); + + req.drain(..).collect() + } + + 
fn enabled(&self) -> bool { + self.instrument_mode.load(Ordering::Relaxed) + != InstrumentedObjectStoreMode::Disabled as u8 + } + + async fn instrumented_put_opts( + &self, + location: &Path, + payload: PutPayload, + opts: PutOptions, + ) -> Result { + let timestamp = Utc::now(); + let start = Instant::now(); + let size = payload.content_length(); + let ret = self.inner.put_opts(location, payload, opts).await?; + let elapsed = start.elapsed(); + + self.requests.lock().push(RequestDetails { + op: Operation::Put, + path: location.clone(), + timestamp, + duration_nanos: Arc::new(AtomicU64::new(elapsed.as_nanos() as u64)), + size: Some(size), + range: None, + extra_display: None, + }); + + Ok(ret) + } + + async fn instrumented_put_multipart( + &self, + location: &Path, + opts: PutMultipartOptions, + ) -> Result> { + let timestamp = Utc::now(); + let start = Instant::now(); + let ret = self.inner.put_multipart_opts(location, opts).await?; + let elapsed = start.elapsed(); + + self.requests.lock().push(RequestDetails { + op: Operation::Put, + path: location.clone(), + timestamp, + duration_nanos: Arc::new(AtomicU64::new(elapsed.as_nanos() as u64)), + size: None, + range: None, + extra_display: None, + }); + + Ok(ret) + } + + async fn instrumented_get_opts( + &self, + location: &Path, + options: GetOptions, + ) -> Result { + let timestamp = Utc::now(); + let range = options.range.clone(); + + let head = options.head; + let start = Instant::now(); + let ret = self.inner.get_opts(location, options).await?; + let elapsed = start.elapsed(); + + let (op, size) = if head { + (Operation::Head, None) + } else { + ( + Operation::Get, + Some((ret.range.end - ret.range.start) as usize), + ) + }; + + self.requests.lock().push(RequestDetails { + op, + path: location.clone(), + timestamp, + duration_nanos: Arc::new(AtomicU64::new(elapsed.as_nanos() as u64)), + size, + range, + extra_display: None, + }); + + Ok(ret) + } + + fn instrumented_delete_stream( + &self, + locations: 
BoxStream<'static, Result>, + ) -> BoxStream<'static, Result> { + let requests_captured = Arc::clone(&self.requests); + + let timestamp = Utc::now(); + let start = Instant::now(); + self.inner + .delete_stream(locations) + .and_then(move |location| { + let elapsed = start.elapsed(); + requests_captured.lock().push(RequestDetails { + op: Operation::Delete, + path: location.clone(), + timestamp, + duration_nanos: Arc::new(AtomicU64::new(elapsed.as_nanos() as u64)), + size: None, + range: None, + extra_display: None, + }); + futures::future::ok(location) + }) + .boxed() + } + + fn instrumented_list( + &self, + prefix: Option<&Path>, + ) -> BoxStream<'static, Result> { + let timestamp = Utc::now(); + let inner_stream = self.inner.list(prefix); + + let duration_nanos = Arc::new(AtomicU64::new(0)); + self.requests.lock().push(RequestDetails { + op: Operation::List, + path: prefix.cloned().unwrap_or_else(|| Path::from("")), + timestamp, + duration_nanos: Arc::clone(&duration_nanos), + size: None, + range: None, + extra_display: None, + }); + + Box::pin(TimeToFirstItemStream::new(inner_stream, duration_nanos)) + } + + async fn instrumented_list_with_delimiter( + &self, + prefix: Option<&Path>, + ) -> Result { + let timestamp = Utc::now(); + let start = Instant::now(); + let ret = self.inner.list_with_delimiter(prefix).await?; + let elapsed = start.elapsed(); + + self.requests.lock().push(RequestDetails { + op: Operation::List, + path: prefix.cloned().unwrap_or_else(|| Path::from("")), + timestamp, + duration_nanos: Arc::new(AtomicU64::new(elapsed.as_nanos() as u64)), + size: None, + range: None, + extra_display: None, + }); + + Ok(ret) + } + + async fn instrumented_copy(&self, from: &Path, to: &Path) -> Result<()> { + let timestamp = Utc::now(); + let start = Instant::now(); + self.inner.copy(from, to).await?; + let elapsed = start.elapsed(); + + self.requests.lock().push(RequestDetails { + op: Operation::Copy, + path: from.clone(), + timestamp, + duration_nanos: 
Arc::new(AtomicU64::new(elapsed.as_nanos() as u64)), + size: None, + range: None, + extra_display: Some(format!("copy_to: {to}")), + }); + + Ok(()) + } + + async fn instrumented_copy_if_not_exists( + &self, + from: &Path, + to: &Path, + ) -> Result<()> { + let timestamp = Utc::now(); + let start = Instant::now(); + self.inner.copy_if_not_exists(from, to).await?; + let elapsed = start.elapsed(); + + self.requests.lock().push(RequestDetails { + op: Operation::Copy, + path: from.clone(), + timestamp, + duration_nanos: Arc::new(AtomicU64::new(elapsed.as_nanos() as u64)), + size: None, + range: None, + extra_display: Some(format!("copy_to: {to}")), + }); + + Ok(()) + } +} + +impl fmt::Display for InstrumentedObjectStore { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mode: InstrumentedObjectStoreMode = + self.instrument_mode.load(Ordering::Relaxed).into(); + write!( + f, + "Instrumented Object Store: instrument_mode: {mode}, inner: {}", + self.inner + ) + } +} + +#[async_trait] +impl ObjectStore for InstrumentedObjectStore { + async fn put_opts( + &self, + location: &Path, + payload: PutPayload, + opts: PutOptions, + ) -> Result { + if self.enabled() { + return self.instrumented_put_opts(location, payload, opts).await; + } + + self.inner.put_opts(location, payload, opts).await + } + + async fn put_multipart_opts( + &self, + location: &Path, + opts: PutMultipartOptions, + ) -> Result> { + if self.enabled() { + return self.instrumented_put_multipart(location, opts).await; + } + + self.inner.put_multipart_opts(location, opts).await + } + + async fn get_opts(&self, location: &Path, options: GetOptions) -> Result { + if self.enabled() { + return self.instrumented_get_opts(location, options).await; + } + + self.inner.get_opts(location, options).await + } + + fn delete_stream( + &self, + locations: BoxStream<'static, Result>, + ) -> BoxStream<'static, Result> { + if self.enabled() { + return self.instrumented_delete_stream(locations); + } + + 
self.inner.delete_stream(locations) + } + + fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, Result> { + if self.enabled() { + return self.instrumented_list(prefix); + } + + self.inner.list(prefix) + } + + async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result { + if self.enabled() { + return self.instrumented_list_with_delimiter(prefix).await; + } + + self.inner.list_with_delimiter(prefix).await + } + + async fn copy_opts( + &self, + from: &Path, + to: &Path, + options: CopyOptions, + ) -> Result<()> { + if self.enabled() { + return match options.mode { + object_store::CopyMode::Create => { + self.instrumented_copy_if_not_exists(from, to).await + } + object_store::CopyMode::Overwrite => { + self.instrumented_copy(from, to).await + } + }; + } + + self.inner.copy_opts(from, to, options).await + } +} + +/// Object store operation types tracked by [`InstrumentedObjectStore`] +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub enum Operation { + Copy, + Delete, + Get, + Head, + List, + Put, +} + +impl fmt::Display for Operation { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{self:?}") + } +} + +/// Holds profiling details about individual requests made through an [`InstrumentedObjectStore`] +pub struct RequestDetails { + op: Operation, + path: Path, + timestamp: chrono::DateTime, + /// Duration stored as nanoseconds in an AtomicU64. 0 means not yet set. 
+ duration_nanos: Arc, + size: Option, + range: Option, + extra_display: Option, +} + +impl fmt::Debug for RequestDetails { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("RequestDetails") + .field("op", &self.op) + .field("path", &self.path) + .field("timestamp", &self.timestamp) + .field("duration", &self.duration()) + .field("size", &self.size) + .field("range", &self.range) + .field("extra_display", &self.extra_display) + .finish() + } +} + +impl RequestDetails { + fn duration(&self) -> Option { + let nanos = self.duration_nanos.load(Ordering::Acquire); + if nanos == 0 { + None + } else { + Some(Duration::from_nanos(nanos)) + } + } +} + +impl fmt::Display for RequestDetails { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut output_parts = vec![format!( + "{} operation={:?}", + self.timestamp.to_rfc3339(), + self.op + )]; + + if let Some(d) = self.duration() { + output_parts.push(format!("duration={:.6}s", d.as_secs_f32())); + } + if let Some(s) = self.size { + output_parts.push(format!("size={s}")); + } + if let Some(r) = &self.range { + output_parts.push(format!("range: {r}")); + } + output_parts.push(format!("path={}", self.path)); + + if let Some(ed) = &self.extra_display { + output_parts.push(ed.clone()); + } + + write!(f, "{}", output_parts.join(" ")) + } +} + +/// Summary statistics for all requests recorded in an [`InstrumentedObjectStore`] +#[derive(Default)] +pub struct RequestSummaries { + summaries: Vec, +} + +/// Display the summary as a table +impl fmt::Display for RequestSummaries { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + // Don't expect an error, but avoid panicking if it happens + match pretty_format_batches(&[self.to_batch()]) { + Err(e) => { + write!(f, "Error formatting summary: {e}") + } + Ok(displayable) => { + write!(f, "{displayable}") + } + } + } +} + +impl RequestSummaries { + /// Summarizes input [`RequestDetails`] + pub fn new(requests: &[RequestDetails]) -> 
Self { + let mut summaries: HashMap = HashMap::new(); + for rd in requests { + match summaries.get_mut(&rd.op) { + Some(rs) => rs.push(rd), + None => { + let mut rs = RequestSummary::new(rd.op); + rs.push(rd); + summaries.insert(rd.op, rs); + } + } + } + // Convert to a Vec with consistent ordering + let mut summaries: Vec = summaries.into_values().collect(); + summaries.sort_by_key(|s| s.operation); + Self { summaries } + } + + /// Convert the summaries into a `RecordBatch` for display + /// + /// Results in a table like: + /// ```text + /// +-----------+----------+-----------+-----------+-----------+-----------+-----------+ + /// | Operation | Metric | min | max | avg | sum | count | + /// +-----------+----------+-----------+-----------+-----------+-----------+-----------+ + /// | Get | duration | 5.000000s | 5.000000s | 5.000000s | | 1 | + /// | Get | size | 100 B | 100 B | 100 B | 100 B | 1 | + /// +-----------+----------+-----------+-----------+-----------+-----------+-----------+ + /// ``` + pub fn to_batch(&self) -> RecordBatch { + let operations: StringArray = self + .iter() + .flat_map(|s| std::iter::repeat_n(Some(s.operation.to_string()), 2)) + .collect(); + let metrics: StringArray = self + .iter() + .flat_map(|_s| [Some("duration"), Some("size")]) + .collect(); + let mins: StringArray = self + .stats_iter() + .flat_map(|(duration_stats, size_stats)| { + let dur_min = + duration_stats.map(|d| format!("{:.6}s", d.min.as_secs_f32())); + let size_min = size_stats.map(|s| format!("{} B", s.min)); + [dur_min, size_min] + }) + .collect(); + let maxs: StringArray = self + .stats_iter() + .flat_map(|(duration_stats, size_stats)| { + let dur_max = + duration_stats.map(|d| format!("{:.6}s", d.max.as_secs_f32())); + let size_max = size_stats.map(|s| format!("{} B", s.max)); + [dur_max, size_max] + }) + .collect(); + let avgs: StringArray = self + .iter() + .flat_map(|s| { + let count = s.count as f32; + let duration_stats = s.duration_stats.as_ref(); + let 
size_stats = s.size_stats.as_ref(); + let dur_avg = duration_stats.map(|d| { + let avg = d.sum.as_secs_f32() / count; + format!("{avg:.6}s") + }); + let size_avg = size_stats.map(|s| { + let avg = s.sum as f32 / count; + format!("{avg} B") + }); + [dur_avg, size_avg] + }) + .collect(); + let sums: StringArray = self + .stats_iter() + .flat_map(|(duration_stats, size_stats)| { + // Omit a sum stat for duration in the initial + // implementation because it can be a bit misleading (at least + // at first glance). For example, particularly large queries the + // sum of the durations was often larger than the total time of + // the query itself, can be confusing without additional + // explanation (e.g. that the sum is of individual requests, + // which may be concurrent). + let dur_sum = + duration_stats.map(|d| format!("{:.6}s", d.sum.as_secs_f32())); + let size_sum = size_stats.map(|s| format!("{} B", s.sum)); + [dur_sum, size_sum] + }) + .collect(); + let counts: StringArray = self + .iter() + .flat_map(|s| { + let count = s.count.to_string(); + [Some(count.clone()), Some(count)] + }) + .collect(); + + RecordBatch::try_from_iter(vec![ + ("Operation", Arc::new(operations) as ArrayRef), + ("Metric", Arc::new(metrics) as ArrayRef), + ("min", Arc::new(mins) as ArrayRef), + ("max", Arc::new(maxs) as ArrayRef), + ("avg", Arc::new(avgs) as ArrayRef), + ("sum", Arc::new(sums) as ArrayRef), + ("count", Arc::new(counts) as ArrayRef), + ]) + .expect("Created the batch correctly") + } + + /// Return an iterator over the summaries + fn iter(&self) -> impl Iterator { + self.summaries.iter() + } + + /// Return an iterator over (duration_stats, size_stats) tuples + /// for each summary + fn stats_iter( + &self, + ) -> impl Iterator>, Option<&Stats>)> { + self.summaries + .iter() + .map(|s| (s.duration_stats.as_ref(), s.size_stats.as_ref())) + } +} + +/// Summary statistics for a particular type of [`Operation`] (e.g. 
`GET` or `PUT`) +/// in an [`InstrumentedObjectStore`]'s [`RequestDetails`] +pub struct RequestSummary { + operation: Operation, + count: usize, + duration_stats: Option>, + size_stats: Option>, +} + +impl RequestSummary { + fn new(operation: Operation) -> Self { + Self { + operation, + count: 0, + duration_stats: None, + size_stats: None, + } + } + fn push(&mut self, request: &RequestDetails) { + self.count += 1; + if let Some(dur) = request.duration() { + self.duration_stats.get_or_insert_default().push(dur) + } + if let Some(size) = request.size { + self.size_stats.get_or_insert_default().push(size) + } + } +} + +struct Stats> { + min: T, + max: T, + sum: T, +} + +impl> Stats { + fn push(&mut self, val: T) { + self.min = cmp::min(val, self.min); + self.max = cmp::max(val, self.max); + self.sum += val; + } +} + +impl Default for Stats { + fn default() -> Self { + Self { + min: Duration::MAX, + max: Duration::ZERO, + sum: Duration::ZERO, + } + } +} + +impl Default for Stats { + fn default() -> Self { + Self { + min: usize::MAX, + max: usize::MIN, + sum: 0, + } + } +} + +/// Provides access to [`InstrumentedObjectStore`] instances that record requests for reporting +#[derive(Debug)] +pub struct InstrumentedObjectStoreRegistry { + inner: Arc, + instrument_mode: AtomicU8, + stores: RwLock>>, +} + +impl Default for InstrumentedObjectStoreRegistry { + fn default() -> Self { + Self::new() + } +} + +impl InstrumentedObjectStoreRegistry { + /// Returns a new [`InstrumentedObjectStoreRegistry`] that wraps the provided + /// [`ObjectStoreRegistry`] + pub fn new() -> Self { + Self { + inner: Arc::new(DefaultObjectStoreRegistry::new()), + instrument_mode: AtomicU8::new(InstrumentedObjectStoreMode::default() as u8), + stores: RwLock::new(Vec::new()), + } + } + + pub fn with_profile_mode(self, mode: InstrumentedObjectStoreMode) -> Self { + self.instrument_mode.store(mode as u8, Ordering::Relaxed); + self + } + + /// Provides access to all of the [`InstrumentedObjectStore`]s 
managed by this + /// [`InstrumentedObjectStoreRegistry`] + pub fn stores(&self) -> Vec> { + self.stores.read().clone() + } + + /// Returns the current [`InstrumentedObjectStoreMode`] for this + /// [`InstrumentedObjectStoreRegistry`] + pub fn instrument_mode(&self) -> InstrumentedObjectStoreMode { + self.instrument_mode.load(Ordering::Relaxed).into() + } + + /// Sets the [`InstrumentedObjectStoreMode`] for this [`InstrumentedObjectStoreRegistry`] + pub fn set_instrument_mode(&self, mode: InstrumentedObjectStoreMode) { + self.instrument_mode.store(mode as u8, Ordering::Relaxed); + for s in self.stores.read().iter() { + s.set_instrument_mode(mode) + } + } +} + +impl ObjectStoreRegistry for InstrumentedObjectStoreRegistry { + fn register_store( + &self, + url: &Url, + store: Arc, + ) -> Option> { + let mode = self.instrument_mode.load(Ordering::Relaxed); + let instrumented = + Arc::new(InstrumentedObjectStore::new(store, AtomicU8::new(mode))); + self.stores.write().push(Arc::clone(&instrumented)); + self.inner.register_store(url, instrumented) + } + + fn deregister_store( + &self, + url: &Url, + ) -> datafusion::common::Result> { + self.inner.deregister_store(url) + } + + fn get_store(&self, url: &Url) -> datafusion::common::Result> { + self.inner.get_store(url) + } +} + +#[cfg(test)] +mod tests { + use futures::StreamExt; + use object_store::WriteMultipart; + + use super::*; + use insta::assert_snapshot; + + #[test] + fn instrumented_mode() { + assert!(matches!( + InstrumentedObjectStoreMode::default(), + InstrumentedObjectStoreMode::Disabled + )); + + assert!(matches!( + "dIsABleD".parse().unwrap(), + InstrumentedObjectStoreMode::Disabled + )); + assert!(matches!( + "SUmMaRy".parse().unwrap(), + InstrumentedObjectStoreMode::Summary + )); + assert!(matches!( + "TRaCe".parse().unwrap(), + InstrumentedObjectStoreMode::Trace + )); + assert!( + "does_not_exist" + .parse::() + .is_err() + ); + + assert!(matches!(0.into(), InstrumentedObjectStoreMode::Disabled)); + 
assert!(matches!(1.into(), InstrumentedObjectStoreMode::Summary)); + assert!(matches!(2.into(), InstrumentedObjectStoreMode::Trace)); + assert!(matches!(3.into(), InstrumentedObjectStoreMode::Disabled)); + } + + #[test] + fn instrumented_registry() { + let mut reg = InstrumentedObjectStoreRegistry::new(); + assert!(reg.stores().is_empty()); + assert_eq!( + reg.instrument_mode(), + InstrumentedObjectStoreMode::default() + ); + + reg = reg.with_profile_mode(InstrumentedObjectStoreMode::Trace); + assert_eq!(reg.instrument_mode(), InstrumentedObjectStoreMode::Trace); + + let store = object_store::memory::InMemory::new(); + let url = "mem://test".parse().unwrap(); + let registered = reg.register_store(&url, Arc::new(store)); + assert!(registered.is_none()); + + let fetched = reg.get_store(&url); + assert!(fetched.is_ok()); + assert_eq!(reg.stores().len(), 1); + } + + // Returns an `InstrumentedObjectStore` with some data loaded for testing and the path to + // access the data + async fn setup_test_store() -> (InstrumentedObjectStore, Path) { + let store = Arc::new(object_store::memory::InMemory::new()); + let mode = AtomicU8::new(InstrumentedObjectStoreMode::default() as u8); + let instrumented = InstrumentedObjectStore::new(store, mode); + + // Load the test store with some data we can read + let path = Path::from("test/data"); + let payload = PutPayload::from_static(b"test_data"); + instrumented.put(&path, payload).await.unwrap(); + + (instrumented, path) + } + + #[tokio::test] + async fn instrumented_store_get() { + let (instrumented, path) = setup_test_store().await; + + // By default no requests should be instrumented/stored + assert!(instrumented.requests.lock().is_empty()); + let _ = instrumented.get(&path).await.unwrap(); + assert!(instrumented.requests.lock().is_empty()); + + instrumented.set_instrument_mode(InstrumentedObjectStoreMode::Trace); + assert!(instrumented.requests.lock().is_empty()); + let _ = instrumented.get(&path).await.unwrap(); + 
assert_eq!(instrumented.requests.lock().len(), 1); + + let mut requests = instrumented.take_requests(); + assert_eq!(requests.len(), 1); + assert!(instrumented.requests.lock().is_empty()); + + let request = requests.pop().unwrap(); + assert_eq!(request.op, Operation::Get); + assert_eq!(request.path, path); + assert!(request.duration().is_some()); + assert_eq!(request.size, Some(9)); + assert_eq!(request.range, None); + assert!(request.extra_display.is_none()); + } + + #[tokio::test] + async fn instrumented_store_delete() { + let (instrumented, path) = setup_test_store().await; + + // By default no requests should be instrumented/stored + assert!(instrumented.requests.lock().is_empty()); + instrumented.delete(&path).await.unwrap(); + assert!(instrumented.requests.lock().is_empty()); + + // We need a new store so we have data to delete again + let (instrumented, path) = setup_test_store().await; + instrumented.set_instrument_mode(InstrumentedObjectStoreMode::Trace); + assert!(instrumented.requests.lock().is_empty()); + instrumented.delete(&path).await.unwrap(); + assert_eq!(instrumented.requests.lock().len(), 1); + + let mut requests = instrumented.take_requests(); + assert_eq!(requests.len(), 1); + assert!(instrumented.requests.lock().is_empty()); + + let request = requests.pop().unwrap(); + assert_eq!(request.op, Operation::Delete); + assert_eq!(request.path, path); + assert!(request.duration().is_some()); + assert!(request.size.is_none()); + assert!(request.range.is_none()); + assert!(request.extra_display.is_none()); + } + + #[tokio::test] + async fn instrumented_store_list() { + let (instrumented, path) = setup_test_store().await; + + // By default no requests should be instrumented/stored + assert!(instrumented.requests.lock().is_empty()); + let _ = instrumented.list(Some(&path)); + assert!(instrumented.requests.lock().is_empty()); + + instrumented.set_instrument_mode(InstrumentedObjectStoreMode::Trace); + assert!(instrumented.requests.lock().is_empty()); + let 
mut stream = instrumented.list(Some(&path)); + // Sleep between stream creation and first poll to verify the timer + // starts on first poll, not at stream creation. + let delay = Duration::from_millis(50); + tokio::time::sleep(delay).await; + let _ = stream.next().await; + assert_eq!(instrumented.requests.lock().len(), 1); + + let request = instrumented.take_requests().pop().unwrap(); + assert_eq!(request.op, Operation::List); + assert_eq!(request.path, path); + let duration = request + .duration() + .expect("duration should be set after consuming stream"); + assert!( + duration < delay, + "duration {duration:?} should exclude the {delay:?} sleep before first poll" + ); + assert!(request.size.is_none()); + assert!(request.range.is_none()); + assert!(request.extra_display.is_none()); + } + + #[tokio::test] + async fn time_to_first_item_stream_captures_inner_latency() { + let inner_delay = Duration::from_millis(50); + let inner_stream = futures::stream::once(async move { + tokio::time::sleep(inner_delay).await; + Ok(ObjectMeta { + location: Path::from("test"), + last_modified: Utc::now(), + size: 0, + e_tag: None, + version: None, + }) + }) + .boxed(); + + let duration_nanos = Arc::new(AtomicU64::new(0)); + let mut stream = Box::pin(TimeToFirstItemStream::new( + inner_stream, + Arc::clone(&duration_nanos), + )); + let _ = stream.next().await; + + let recorded = Duration::from_nanos(duration_nanos.load(Ordering::Acquire)); + assert!( + recorded >= inner_delay, + "recorded duration {recorded:?} should be >= inner stream delay {inner_delay:?}" + ); + } + + #[tokio::test] + async fn instrumented_store_list_with_delimiter() { + let (instrumented, path) = setup_test_store().await; + + // By default no requests should be instrumented/stored + assert!(instrumented.requests.lock().is_empty()); + let _ = instrumented.list_with_delimiter(Some(&path)).await.unwrap(); + assert!(instrumented.requests.lock().is_empty()); + + 
instrumented.set_instrument_mode(InstrumentedObjectStoreMode::Trace); + assert!(instrumented.requests.lock().is_empty()); + let _ = instrumented.list_with_delimiter(Some(&path)).await.unwrap(); + assert_eq!(instrumented.requests.lock().len(), 1); + + let request = instrumented.take_requests().pop().unwrap(); + assert_eq!(request.op, Operation::List); + assert_eq!(request.path, path); + assert!(request.duration().is_some()); + assert!(request.size.is_none()); + assert!(request.range.is_none()); + assert!(request.extra_display.is_none()); + } + + #[tokio::test] + async fn instrumented_store_put_opts() { + // The `setup_test_store()` method comes with data already `put` into it, so we'll setup + // manually for this test + let store = Arc::new(object_store::memory::InMemory::new()); + let mode = AtomicU8::new(InstrumentedObjectStoreMode::default() as u8); + let instrumented = InstrumentedObjectStore::new(store, mode); + + let path = Path::from("test/data"); + let payload = PutPayload::from_static(b"test_data"); + let size = payload.content_length(); + + // By default no requests should be instrumented/stored + assert!(instrumented.requests.lock().is_empty()); + instrumented.put(&path, payload.clone()).await.unwrap(); + assert!(instrumented.requests.lock().is_empty()); + + instrumented.set_instrument_mode(InstrumentedObjectStoreMode::Trace); + assert!(instrumented.requests.lock().is_empty()); + instrumented.put(&path, payload).await.unwrap(); + assert_eq!(instrumented.requests.lock().len(), 1); + + let request = instrumented.take_requests().pop().unwrap(); + assert_eq!(request.op, Operation::Put); + assert_eq!(request.path, path); + assert!(request.duration().is_some()); + assert_eq!(request.size.unwrap(), size); + assert!(request.range.is_none()); + assert!(request.extra_display.is_none()); + } + + #[tokio::test] + async fn instrumented_store_put_multipart() { + // The `setup_test_store()` method comes with data already `put` into it, so we'll setup + // manually for 
this test + let store = Arc::new(object_store::memory::InMemory::new()); + let mode = AtomicU8::new(InstrumentedObjectStoreMode::default() as u8); + let instrumented = InstrumentedObjectStore::new(store, mode); + + let path = Path::from("test/data"); + + // By default no requests should be instrumented/stored + assert!(instrumented.requests.lock().is_empty()); + let mp = instrumented.put_multipart(&path).await.unwrap(); + let mut write = WriteMultipart::new(mp); + write.write(b"test_data"); + write.finish().await.unwrap(); + assert!(instrumented.requests.lock().is_empty()); + + instrumented.set_instrument_mode(InstrumentedObjectStoreMode::Trace); + assert!(instrumented.requests.lock().is_empty()); + let mp = instrumented.put_multipart(&path).await.unwrap(); + let mut write = WriteMultipart::new(mp); + write.write(b"test_data"); + write.finish().await.unwrap(); + assert_eq!(instrumented.requests.lock().len(), 1); + + let request = instrumented.take_requests().pop().unwrap(); + assert_eq!(request.op, Operation::Put); + assert_eq!(request.path, path); + assert!(request.duration().is_some()); + assert!(request.size.is_none()); + assert!(request.range.is_none()); + assert!(request.extra_display.is_none()); + } + + #[tokio::test] + async fn instrumented_store_copy() { + let (instrumented, path) = setup_test_store().await; + let copy_to = Path::from("test/copied"); + + // By default no requests should be instrumented/stored + assert!(instrumented.requests.lock().is_empty()); + instrumented.copy(&path, ©_to).await.unwrap(); + assert!(instrumented.requests.lock().is_empty()); + + instrumented.set_instrument_mode(InstrumentedObjectStoreMode::Trace); + assert!(instrumented.requests.lock().is_empty()); + instrumented.copy(&path, ©_to).await.unwrap(); + assert_eq!(instrumented.requests.lock().len(), 1); + + let mut requests = instrumented.take_requests(); + assert_eq!(requests.len(), 1); + assert!(instrumented.requests.lock().is_empty()); + + let request = 
requests.pop().unwrap(); + assert_eq!(request.op, Operation::Copy); + assert_eq!(request.path, path); + assert!(request.duration().is_some()); + assert!(request.size.is_none()); + assert!(request.range.is_none()); + assert_eq!( + request.extra_display.unwrap(), + format!("copy_to: {copy_to}") + ); + } + + #[tokio::test] + async fn instrumented_store_copy_if_not_exists() { + let (instrumented, path) = setup_test_store().await; + let mut copy_to = Path::from("test/copied"); + + // By default no requests should be instrumented/stored + assert!(instrumented.requests.lock().is_empty()); + instrumented + .copy_if_not_exists(&path, ©_to) + .await + .unwrap(); + assert!(instrumented.requests.lock().is_empty()); + + // Use a new destination since the previous one already exists + copy_to = Path::from("test/copied_again"); + instrumented.set_instrument_mode(InstrumentedObjectStoreMode::Trace); + assert!(instrumented.requests.lock().is_empty()); + instrumented + .copy_if_not_exists(&path, ©_to) + .await + .unwrap(); + assert_eq!(instrumented.requests.lock().len(), 1); + + let mut requests = instrumented.take_requests(); + assert_eq!(requests.len(), 1); + assert!(instrumented.requests.lock().is_empty()); + + let request = requests.pop().unwrap(); + assert_eq!(request.op, Operation::Copy); + assert_eq!(request.path, path); + assert!(request.duration().is_some()); + assert!(request.size.is_none()); + assert!(request.range.is_none()); + assert_eq!( + request.extra_display.unwrap(), + format!("copy_to: {copy_to}") + ); + } + + #[tokio::test] + async fn instrumented_store_head() { + let (instrumented, path) = setup_test_store().await; + + // By default no requests should be instrumented/stored + assert!(instrumented.requests.lock().is_empty()); + let _ = instrumented.head(&path).await.unwrap(); + assert!(instrumented.requests.lock().is_empty()); + + instrumented.set_instrument_mode(InstrumentedObjectStoreMode::Trace); + assert!(instrumented.requests.lock().is_empty()); + let _ = 
instrumented.head(&path).await.unwrap(); + assert_eq!(instrumented.requests.lock().len(), 1); + + let mut requests = instrumented.take_requests(); + assert_eq!(requests.len(), 1); + assert!(instrumented.requests.lock().is_empty()); + + let request = requests.pop().unwrap(); + assert_eq!(request.op, Operation::Head); + assert_eq!(request.path, path); + assert!(request.duration().is_some()); + assert!(request.size.is_none()); + assert!(request.range.is_none()); + assert!(request.extra_display.is_none()); + } + + #[test] + fn request_details() { + let rd = RequestDetails { + op: Operation::Get, + path: Path::from("test"), + timestamp: chrono::DateTime::from_timestamp(0, 0).unwrap(), + duration_nanos: Arc::new(AtomicU64::new( + Duration::new(5, 0).as_nanos() as u64 + )), + size: Some(10), + range: Some((..10).into()), + extra_display: Some(String::from("extra info")), + }; + + assert_eq!( + format!("{rd}"), + "1970-01-01T00:00:00+00:00 operation=Get duration=5.000000s size=10 range: bytes=0-9 path=test extra info" + ); + } + + #[test] + fn request_summary() { + // Test empty request list + let mut requests = Vec::new(); + assert_snapshot!(RequestSummaries::new(&requests), @r" + +-----------+--------+-----+-----+-----+-----+-------+ + | Operation | Metric | min | max | avg | sum | count | + +-----------+--------+-----+-----+-----+-----+-------+ + +-----------+--------+-----+-----+-----+-----+-------+ + "); + + requests.push(RequestDetails { + op: Operation::Get, + path: Path::from("test1"), + timestamp: chrono::DateTime::from_timestamp(0, 0).unwrap(), + duration_nanos: Arc::new(AtomicU64::new( + Duration::from_secs(5).as_nanos() as u64 + )), + size: Some(100), + range: None, + extra_display: None, + }); + + assert_snapshot!(RequestSummaries::new(&requests), @r" + +-----------+----------+-----------+-----------+-----------+-----------+-------+ + | Operation | Metric | min | max | avg | sum | count | + 
+-----------+----------+-----------+-----------+-----------+-----------+-------+ + | Get | duration | 5.000000s | 5.000000s | 5.000000s | 5.000000s | 1 | + | Get | size | 100 B | 100 B | 100 B | 100 B | 1 | + +-----------+----------+-----------+-----------+-----------+-----------+-------+ + "); + + // Add more Get requests to test aggregation + requests.push(RequestDetails { + op: Operation::Get, + path: Path::from("test2"), + timestamp: chrono::DateTime::from_timestamp(1, 0).unwrap(), + duration_nanos: Arc::new(AtomicU64::new( + Duration::from_secs(8).as_nanos() as u64 + )), + size: Some(150), + range: None, + extra_display: None, + }); + requests.push(RequestDetails { + op: Operation::Get, + path: Path::from("test3"), + timestamp: chrono::DateTime::from_timestamp(2, 0).unwrap(), + duration_nanos: Arc::new(AtomicU64::new( + Duration::from_secs(2).as_nanos() as u64 + )), + size: Some(50), + range: None, + extra_display: None, + }); + assert_snapshot!(RequestSummaries::new(&requests), @r" + +-----------+----------+-----------+-----------+-----------+------------+-------+ + | Operation | Metric | min | max | avg | sum | count | + +-----------+----------+-----------+-----------+-----------+------------+-------+ + | Get | duration | 2.000000s | 8.000000s | 5.000000s | 15.000000s | 3 | + | Get | size | 50 B | 150 B | 100 B | 300 B | 3 | + +-----------+----------+-----------+-----------+-----------+------------+-------+ + "); + + // Add Put requests to test grouping + requests.push(RequestDetails { + op: Operation::Put, + path: Path::from("test4"), + timestamp: chrono::DateTime::from_timestamp(3, 0).unwrap(), + duration_nanos: Arc::new(AtomicU64::new( + Duration::from_millis(200).as_nanos() as u64, + )), + size: Some(75), + range: None, + extra_display: None, + }); + + assert_snapshot!(RequestSummaries::new(&requests), @r" + +-----------+----------+-----------+-----------+-----------+------------+-------+ + | Operation | Metric | min | max | avg | sum | count | + 
+-----------+----------+-----------+-----------+-----------+------------+-------+ + | Get | duration | 2.000000s | 8.000000s | 5.000000s | 15.000000s | 3 | + | Get | size | 50 B | 150 B | 100 B | 300 B | 3 | + | Put | duration | 0.200000s | 0.200000s | 0.200000s | 0.200000s | 1 | + | Put | size | 75 B | 75 B | 75 B | 75 B | 1 | + +-----------+----------+-----------+-----------+-----------+------------+-------+ + "); + } + + #[test] + fn request_summary_only_duration() { + // Test request with only duration (no size) + let only_duration = vec![RequestDetails { + op: Operation::Get, + path: Path::from("test1"), + timestamp: chrono::DateTime::from_timestamp(0, 0).unwrap(), + duration_nanos: Arc::new(AtomicU64::new( + Duration::from_secs(3).as_nanos() as u64 + )), + size: None, + range: None, + extra_display: None, + }]; + assert_snapshot!(RequestSummaries::new(&only_duration), @r" + +-----------+----------+-----------+-----------+-----------+-----------+-------+ + | Operation | Metric | min | max | avg | sum | count | + +-----------+----------+-----------+-----------+-----------+-----------+-------+ + | Get | duration | 3.000000s | 3.000000s | 3.000000s | 3.000000s | 1 | + | Get | size | | | | | 1 | + +-----------+----------+-----------+-----------+-----------+-----------+-------+ + "); + } + + #[test] + fn request_summary_only_size() { + // Test request with only size (no duration) + let only_size = vec![RequestDetails { + op: Operation::Get, + path: Path::from("test1"), + timestamp: chrono::DateTime::from_timestamp(0, 0).unwrap(), + duration_nanos: Arc::new(AtomicU64::new(0)), + size: Some(200), + range: None, + extra_display: None, + }]; + assert_snapshot!(RequestSummaries::new(&only_size), @r" + +-----------+----------+-------+-------+-------+-------+-------+ + | Operation | Metric | min | max | avg | sum | count | + +-----------+----------+-------+-------+-------+-------+-------+ + | Get | duration | | | | | 1 | + | Get | size | 200 B | 200 B | 200 B | 200 B | 1 
| + +-----------+----------+-------+-------+-------+-------+-------+ + "); + } + + #[test] + fn request_summary_neither_duration_or_size() { + // Test request with neither duration nor size + let no_stats = vec![RequestDetails { + op: Operation::Get, + path: Path::from("test1"), + timestamp: chrono::DateTime::from_timestamp(0, 0).unwrap(), + duration_nanos: Arc::new(AtomicU64::new(0)), + size: None, + range: None, + extra_display: None, + }]; + assert_snapshot!(RequestSummaries::new(&no_stats), @r" + +-----------+----------+-----+-----+-----+-----+-------+ + | Operation | Metric | min | max | avg | sum | count | + +-----------+----------+-----+-----+-----+-----+-------+ + | Get | duration | | | | | 1 | + | Get | size | | | | | 1 | + +-----------+----------+-----+-----+-----+-----+-------+ + "); + } +} diff --git a/datafusion-cli/src/print_format.rs b/datafusion-cli/src/print_format.rs index 1d6a8396aee74..6a6a0370b08ac 100644 --- a/datafusion-cli/src/print_format.rs +++ b/datafusion-cli/src/print_format.rs @@ -221,6 +221,7 @@ mod tests { use arrow::array::Int32Array; use arrow::datatypes::{DataType, Field, Schema}; + use insta::{allow_duplicates, assert_snapshot}; #[test] fn print_empty() { @@ -232,249 +233,202 @@ mod tests { PrintFormat::Automatic, ] { // no output for empty batches, even with header set - PrintBatchesTest::new() + let output = PrintBatchesTest::new() .with_format(format) .with_schema(three_column_schema()) .with_batches(vec![]) - .with_expected(&[""]) .run(); + assert_eq!(output, "") } // output column headers for empty batches when format is Table - #[rustfmt::skip] - let expected = &[ - "+---+---+---+", - "| a | b | c |", - "+---+---+---+", - "+---+---+---+", - ]; - PrintBatchesTest::new() + let output = PrintBatchesTest::new() .with_format(PrintFormat::Table) .with_schema(three_column_schema()) .with_batches(vec![]) - .with_expected(expected) .run(); + assert_snapshot!(output, @r" + +---+---+---+ + | a | b | c | + +---+---+---+ + +---+---+---+ 
+ "); } #[test] fn print_csv_no_header() { - #[rustfmt::skip] - let expected = &[ - "1,4,7", - "2,5,8", - "3,6,9", - ]; - - PrintBatchesTest::new() + let output = PrintBatchesTest::new() .with_format(PrintFormat::Csv) - .with_batches(split_batch(three_column_batch())) + .with_batches(split_batch(&three_column_batch())) .with_header(WithHeader::No) - .with_expected(expected) .run(); + assert_snapshot!(output, @r" + 1,4,7 + 2,5,8 + 3,6,9 + "); } #[test] fn print_csv_with_header() { - #[rustfmt::skip] - let expected = &[ - "a,b,c", - "1,4,7", - "2,5,8", - "3,6,9", - ]; - - PrintBatchesTest::new() + let output = PrintBatchesTest::new() .with_format(PrintFormat::Csv) - .with_batches(split_batch(three_column_batch())) + .with_batches(split_batch(&three_column_batch())) .with_header(WithHeader::Yes) - .with_expected(expected) .run(); + assert_snapshot!(output, @r" + a,b,c + 1,4,7 + 2,5,8 + 3,6,9 + "); } #[test] fn print_tsv_no_header() { - #[rustfmt::skip] - let expected = &[ - "1\t4\t7", - "2\t5\t8", - "3\t6\t9", - ]; - - PrintBatchesTest::new() + let output = PrintBatchesTest::new() .with_format(PrintFormat::Tsv) - .with_batches(split_batch(three_column_batch())) + .with_batches(split_batch(&three_column_batch())) .with_header(WithHeader::No) - .with_expected(expected) .run(); + assert_snapshot!(output, @r" + 1 4 7 + 2 5 8 + 3 6 9 + ") } #[test] fn print_tsv_with_header() { - #[rustfmt::skip] - let expected = &[ - "a\tb\tc", - "1\t4\t7", - "2\t5\t8", - "3\t6\t9", - ]; - - PrintBatchesTest::new() + let output = PrintBatchesTest::new() .with_format(PrintFormat::Tsv) - .with_batches(split_batch(three_column_batch())) + .with_batches(split_batch(&three_column_batch())) .with_header(WithHeader::Yes) - .with_expected(expected) .run(); + assert_snapshot!(output, @r" + a b c + 1 4 7 + 2 5 8 + 3 6 9 + "); } #[test] fn print_table() { - let expected = &[ - "+---+---+---+", - "| a | b | c |", - "+---+---+---+", - "| 1 | 4 | 7 |", - "| 2 | 5 | 8 |", - "| 3 | 6 | 9 |", - 
"+---+---+---+", - ]; - - PrintBatchesTest::new() + let output = PrintBatchesTest::new() .with_format(PrintFormat::Table) - .with_batches(split_batch(three_column_batch())) + .with_batches(split_batch(&three_column_batch())) .with_header(WithHeader::Ignored) - .with_expected(expected) .run(); + assert_snapshot!(output, @r" + +---+---+---+ + | a | b | c | + +---+---+---+ + | 1 | 4 | 7 | + | 2 | 5 | 8 | + | 3 | 6 | 9 | + +---+---+---+ + "); } #[test] fn print_json() { - let expected = - &[r#"[{"a":1,"b":4,"c":7},{"a":2,"b":5,"c":8},{"a":3,"b":6,"c":9}]"#]; - - PrintBatchesTest::new() + let output = PrintBatchesTest::new() .with_format(PrintFormat::Json) - .with_batches(split_batch(three_column_batch())) + .with_batches(split_batch(&three_column_batch())) .with_header(WithHeader::Ignored) - .with_expected(expected) .run(); + assert_snapshot!(output, @r#"[{"a":1,"b":4,"c":7},{"a":2,"b":5,"c":8},{"a":3,"b":6,"c":9}]"#); } #[test] fn print_ndjson() { - let expected = &[ - r#"{"a":1,"b":4,"c":7}"#, - r#"{"a":2,"b":5,"c":8}"#, - r#"{"a":3,"b":6,"c":9}"#, - ]; - - PrintBatchesTest::new() + let output = PrintBatchesTest::new() .with_format(PrintFormat::NdJson) - .with_batches(split_batch(three_column_batch())) + .with_batches(split_batch(&three_column_batch())) .with_header(WithHeader::Ignored) - .with_expected(expected) .run(); + assert_snapshot!(output, @r#" + {"a":1,"b":4,"c":7} + {"a":2,"b":5,"c":8} + {"a":3,"b":6,"c":9} + "#); } #[test] fn print_automatic_no_header() { - #[rustfmt::skip] - let expected = &[ - "1,4,7", - "2,5,8", - "3,6,9", - ]; - - PrintBatchesTest::new() + let output = PrintBatchesTest::new() .with_format(PrintFormat::Automatic) - .with_batches(split_batch(three_column_batch())) + .with_batches(split_batch(&three_column_batch())) .with_header(WithHeader::No) - .with_expected(expected) .run(); + assert_snapshot!(output, @r" + 1,4,7 + 2,5,8 + 3,6,9 + "); } #[test] fn print_automatic_with_header() { - #[rustfmt::skip] - let expected = &[ - "a,b,c", - 
"1,4,7", - "2,5,8", - "3,6,9", - ]; - - PrintBatchesTest::new() + let output = PrintBatchesTest::new() .with_format(PrintFormat::Automatic) - .with_batches(split_batch(three_column_batch())) + .with_batches(split_batch(&three_column_batch())) .with_header(WithHeader::Yes) - .with_expected(expected) .run(); + assert_snapshot!(output, @r" + a,b,c + 1,4,7 + 2,5,8 + 3,6,9 + "); } #[test] fn print_maxrows_unlimited() { - #[rustfmt::skip] - let expected = &[ - "+---+", - "| a |", - "+---+", - "| 1 |", - "| 2 |", - "| 3 |", - "+---+", - ]; - // should print out entire output with no truncation if unlimited or // limit greater than number of batches or equal to the number of batches for max_rows in [MaxRows::Unlimited, MaxRows::Limited(5), MaxRows::Limited(3)] { - PrintBatchesTest::new() + let output = PrintBatchesTest::new() .with_format(PrintFormat::Table) .with_schema(one_column_schema()) .with_batches(vec![one_column_batch()]) .with_maxrows(max_rows) - .with_expected(expected) .run(); + allow_duplicates! { + assert_snapshot!(output, @r" + +---+ + | a | + +---+ + | 1 | + | 2 | + | 3 | + +---+ + "); + } } } #[test] fn print_maxrows_limited_one_batch() { - #[rustfmt::skip] - let expected = &[ - "+---+", - "| a |", - "+---+", - "| 1 |", - "| . |", - "| . |", - "| . |", - "+---+", - ]; - - PrintBatchesTest::new() + let output = PrintBatchesTest::new() .with_format(PrintFormat::Table) .with_batches(vec![one_column_batch()]) .with_maxrows(MaxRows::Limited(1)) - .with_expected(expected) .run(); + assert_snapshot!(output, @r" + +---+ + | a | + +---+ + | 1 | + | . | + | . | + | . | + +---+ + "); } #[test] fn print_maxrows_limited_multi_batched() { - #[rustfmt::skip] - let expected = &[ - "+---+", - "| a |", - "+---+", - "| 1 |", - "| 2 |", - "| 3 |", - "| 1 |", - "| 2 |", - "| . |", - "| . |", - "| . 
|", - "+---+", - ]; - - PrintBatchesTest::new() + let output = PrintBatchesTest::new() .with_format(PrintFormat::Table) .with_batches(vec![ one_column_batch(), @@ -482,8 +436,21 @@ mod tests { one_column_batch(), ]) .with_maxrows(MaxRows::Limited(5)) - .with_expected(expected) .run(); + assert_snapshot!(output, @r" + +---+ + | a | + +---+ + | 1 | + | 2 | + | 3 | + | 1 | + | 2 | + | . | + | . | + | . | + +---+ + "); } #[test] @@ -491,22 +458,19 @@ mod tests { let batch = one_column_batch(); let empty_batch = RecordBatch::new_empty(batch.schema()); - #[rustfmt::skip] - let expected =&[ - "+---+", - "| a |", - "+---+", - "| 1 |", - "| 2 |", - "| 3 |", - "+---+", - ]; - - PrintBatchesTest::new() + let output = PrintBatchesTest::new() .with_format(PrintFormat::Table) .with_batches(vec![empty_batch.clone(), batch, empty_batch]) - .with_expected(expected) .run(); + assert_snapshot!(output, @r" + +---+ + | a | + +---+ + | 1 | + | 2 | + | 3 | + +---+ + "); } #[test] @@ -514,32 +478,28 @@ mod tests { let empty_batch = RecordBatch::new_empty(one_column_batch().schema()); // Print column headers for empty batch when format is Table - #[rustfmt::skip] - let expected =&[ - "+---+", - "| a |", - "+---+", - "+---+", - ]; - - PrintBatchesTest::new() + let output = PrintBatchesTest::new() .with_format(PrintFormat::Table) .with_schema(one_column_schema()) .with_batches(vec![empty_batch]) .with_header(WithHeader::Yes) - .with_expected(expected) .run(); + assert_snapshot!(output, @r" + +---+ + | a | + +---+ + +---+ + "); // No output for empty batch when schema contains no columns let empty_batch = RecordBatch::new_empty(Arc::new(Schema::empty())); - let expected = &[""]; - PrintBatchesTest::new() + let output = PrintBatchesTest::new() .with_format(PrintFormat::Table) .with_schema(Arc::new(Schema::empty())) .with_batches(vec![empty_batch]) .with_header(WithHeader::Yes) - .with_expected(expected) .run(); + assert_eq!(output, "") } #[derive(Debug)] @@ -549,7 +509,6 @@ mod tests { 
batches: Vec, maxrows: MaxRows, with_header: WithHeader, - expected: Vec<&'static str>, } /// How to test with_header @@ -569,7 +528,6 @@ mod tests { batches: vec![], maxrows: MaxRows::Unlimited, with_header: WithHeader::Ignored, - expected: vec![], } } @@ -603,25 +561,9 @@ mod tests { self } - /// set expected output - fn with_expected(mut self, expected: &[&'static str]) -> Self { - self.expected = expected.to_vec(); - self - } - /// run the test - fn run(self) { - let actual = self.output(); - let actual: Vec<_> = actual.trim_end().split('\n').collect(); - let expected = self.expected; - assert_eq!( - actual, expected, - "\n\nactual:\n{actual:#?}\n\nexpected:\n{expected:#?}" - ); - } - /// formats batches using parameters and returns the resulting output - fn output(&self) -> String { + fn run(self) -> String { match self.with_header { WithHeader::Yes => self.output_with_header(true), WithHeader::No => self.output_with_header(false), @@ -691,7 +633,7 @@ mod tests { } /// Slice the record batch into 2 batches - fn split_batch(batch: RecordBatch) -> Vec { + fn split_batch(batch: &RecordBatch) -> Vec { assert!(batch.num_rows() > 1); let split = batch.num_rows() / 2; vec![ diff --git a/datafusion-cli/src/print_options.rs b/datafusion-cli/src/print_options.rs index 56d787b0fe087..d0810cb034df1 100644 --- a/datafusion-cli/src/print_options.rs +++ b/datafusion-cli/src/print_options.rs @@ -16,16 +16,20 @@ // under the License. 
use std::fmt::{Display, Formatter}; -use std::io::Write; +use std::io; use std::pin::Pin; use std::str::FromStr; +use std::sync::Arc; +use crate::object_storage::instrumented::{ + InstrumentedObjectStoreMode, InstrumentedObjectStoreRegistry, RequestSummaries, +}; use crate::print_format::PrintFormat; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; -use datafusion::common::instant::Instant; use datafusion::common::DataFusionError; +use datafusion::common::instant::Instant; use datafusion::error::Result; use datafusion::physical_plan::RecordBatchStream; @@ -51,8 +55,10 @@ impl FromStr for MaxRows { Ok(Self::Unlimited) } else { match maxrows.parse::() { - Ok(nrows) => Ok(Self::Limited(nrows)), - _ => Err(format!("Invalid maxrows {maxrows}. Valid inputs are natural numbers or \'none\', \'inf\', or \'infinite\' for no limit.")), + Ok(nrows) => Ok(Self::Limited(nrows)), + _ => Err(format!( + "Invalid maxrows {maxrows}. Valid inputs are natural numbers or \'none\', \'inf\', or \'infinite\' for no limit." 
+ )), } } } @@ -67,12 +73,15 @@ impl Display for MaxRows { } } +const OBJECT_STORE_PROFILING_HEADER: &str = "Object Store Profiling"; + #[derive(Debug, Clone)] pub struct PrintOptions { pub format: PrintFormat, pub quiet: bool, pub maxrows: MaxRows, pub color: bool, + pub instrumented_registry: Arc, } // Returns the query execution details formatted @@ -106,7 +115,7 @@ impl PrintOptions { row_count: usize, format_options: &FormatOptions, ) -> Result<()> { - let stdout = std::io::stdout(); + let stdout = io::stdout(); let mut writer = stdout.lock(); self.format.print_batches( @@ -128,11 +137,7 @@ impl PrintOptions { query_start_time, ); - if !self.quiet { - writeln!(writer, "{formatted_exec_details}")?; - } - - Ok(()) + self.write_output(&mut writer, &formatted_exec_details) } /// Print the stream to stdout using the specified format @@ -148,7 +153,7 @@ impl PrintOptions { )); }; - let stdout = std::io::stdout(); + let stdout = io::stdout(); let mut writer = stdout.lock(); let mut row_count = 0_usize; @@ -174,10 +179,88 @@ impl PrintOptions { query_start_time, ); + self.write_output(&mut writer, &formatted_exec_details) + } + + fn write_output( + &self, + writer: &mut W, + formatted_exec_details: &str, + ) -> Result<()> { if !self.quiet { writeln!(writer, "{formatted_exec_details}")?; + + let instrument_mode = self.instrumented_registry.instrument_mode(); + if instrument_mode != InstrumentedObjectStoreMode::Disabled { + writeln!(writer, "{OBJECT_STORE_PROFILING_HEADER}")?; + for store in self.instrumented_registry.stores() { + let requests = store.take_requests(); + + if !requests.is_empty() { + writeln!(writer, "{store}")?; + if instrument_mode == InstrumentedObjectStoreMode::Trace { + for req in requests.iter() { + writeln!(writer, "{req}")?; + } + // Add an extra blank line to help visually organize the output + writeln!(writer)?; + } + + writeln!(writer, "Summaries:")?; + let summaries = RequestSummaries::new(&requests); + writeln!(writer, "{summaries}")?; + } + 
} + } } Ok(()) } } + +#[cfg(test)] +mod tests { + use datafusion::error::Result; + + use super::*; + + #[test] + fn write_output() -> Result<()> { + let instrumented_registry = Arc::new(InstrumentedObjectStoreRegistry::new()); + let mut print_options = PrintOptions { + format: PrintFormat::Automatic, + quiet: true, + maxrows: MaxRows::Unlimited, + color: true, + instrumented_registry: Arc::clone(&instrumented_registry), + }; + + let mut print_output: Vec = Vec::new(); + let exec_out = String::from("Formatted Exec Output"); + print_options.write_output(&mut print_output, &exec_out)?; + assert!(print_output.is_empty()); + + print_options.quiet = false; + print_options.write_output(&mut print_output, &exec_out)?; + let out_str: String = print_output + .clone() + .try_into() + .expect("Expected successful String conversion"); + assert!(out_str.contains(&exec_out)); + + // clear the previous data from the output so it doesn't pollute the next test + print_output.clear(); + print_options + .instrumented_registry + .set_instrument_mode(InstrumentedObjectStoreMode::Trace); + print_options.write_output(&mut print_output, &exec_out)?; + let out_str: String = print_output + .clone() + .try_into() + .expect("Expected successful String conversion"); + assert!(out_str.contains(&exec_out)); + assert!(out_str.contains(OBJECT_STORE_PROFILING_HEADER)); + + Ok(()) + } +} diff --git a/datafusion-cli/tests/cli_integration.rs b/datafusion-cli/tests/cli_integration.rs index fb2f08157f674..7bc45693a8b0d 100644 --- a/datafusion-cli/tests/cli_integration.rs +++ b/datafusion-cli/tests/cli_integration.rs @@ -19,9 +19,18 @@ use std::process::Command; use rstest::rstest; -use insta::{glob, Settings}; +use async_trait::async_trait; +use insta::internals::SettingsBindDropGuard; +use insta::{Settings, glob}; use insta_cmd::{assert_cmd_snapshot, get_cargo_bin}; +use std::path::PathBuf; use std::{env, fs}; +use testcontainers_modules::minio; +use 
testcontainers_modules::testcontainers::core::{CmdWaitFor, ExecCommand, Mount}; +use testcontainers_modules::testcontainers::runners::AsyncRunner; +use testcontainers_modules::testcontainers::{ + ContainerAsync, ImageExt, TestcontainersError, +}; fn cli() -> Command { Command::new(get_cargo_bin("datafusion-cli")) @@ -32,9 +41,85 @@ fn make_settings() -> Settings { settings.set_prepend_module_to_snapshot(false); settings.add_filter(r"Elapsed .* seconds\.", "[ELAPSED]"); settings.add_filter(r"DataFusion CLI v.*", "[CLI_VERSION]"); + settings.add_filter(r"(?s)backtrace:.*?\n\n\n", ""); settings } +async fn setup_minio_container() -> Result, String> { + const MINIO_ROOT_USER: &str = "TEST-DataFusionLogin"; + const MINIO_ROOT_PASSWORD: &str = "TEST-DataFusionPassword"; + + let data_path = + PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../datafusion/core/tests/data"); + + let absolute_data_path = data_path + .canonicalize() + .expect("Failed to get absolute path for test data"); + + let container = minio::MinIO::default() + .with_env_var("MINIO_ROOT_USER", MINIO_ROOT_USER) + .with_env_var("MINIO_ROOT_PASSWORD", MINIO_ROOT_PASSWORD) + .with_mount(Mount::bind_mount( + absolute_data_path.to_str().unwrap(), + "/source", + )) + .start() + .await; + + match container { + Ok(container) => { + // We wait for MinIO to be healthy and prepare test files. 
We do it via CLI to avoid s3 dependency + let commands = [ + ExecCommand::new(["/usr/bin/mc", "ready", "local"]), + ExecCommand::new([ + "/usr/bin/mc", + "alias", + "set", + "localminio", + "http://localhost:9000", + MINIO_ROOT_USER, + MINIO_ROOT_PASSWORD, + ]), + ExecCommand::new(["/usr/bin/mc", "mb", "localminio/data"]), + ExecCommand::new([ + "/usr/bin/mc", + "cp", + "-r", + "/source/", + "localminio/data/", + ]), + ]; + + for command in commands { + let command = + command.with_cmd_ready_condition(CmdWaitFor::Exit { code: Some(0) }); + + let cmd_ref = format!("{command:?}"); + + if let Err(e) = container.exec(command).await { + let stdout = container.stdout_to_vec().await.unwrap_or_default(); + let stderr = container.stderr_to_vec().await.unwrap_or_default(); + + return Err(format!( + "Failed to execute command: {}\nError: {}\nStdout: {:?}\nStderr: {:?}", + cmd_ref, + e, + String::from_utf8_lossy(&stdout), + String::from_utf8_lossy(&stderr) + )); + } + } + + Ok(container) + } + + Err(TestcontainersError::Client(e)) => Err(format!( + "Failed to start MinIO container. 
Ensure Docker is running and accessible: {e}" + )), + Err(e) => Err(format!("Failed to start MinIO container: {e}")), + } +} + #[cfg(test)] #[ctor::ctor] fn init() { @@ -131,13 +216,49 @@ fn test_cli_top_memory_consumers<'a>( #[case] snapshot_name: &str, #[case] top_memory_consumers: impl IntoIterator, ) { + let _bound = bind_to_settings(snapshot_name); + + let mut cmd = cli(); + let sql = "select * from generate_series(1,500000) as t1(v1) order by v1;"; + cmd.args(["--memory-limit", "10M", "--command", sql]); + cmd.args(top_memory_consumers); + + assert_cmd_snapshot!(cmd); +} + +#[rstest] +#[case("no_track", ["--top-memory-consumers", "0"])] +#[case("top2", ["--top-memory-consumers", "2"])] +#[test] +fn test_cli_top_memory_consumers_with_mem_pool_type<'a>( + #[case] snapshot_name: &str, + #[case] top_memory_consumers: impl IntoIterator, +) { + let _bound = bind_to_settings(snapshot_name); + + let mut cmd = cli(); + let sql = "select * from generate_series(1,500000) as t1(v1) order by v1;"; + cmd.args([ + "--memory-limit", + "10M", + "--mem-pool-type", + "fair", + "--command", + sql, + ]); + cmd.args(top_memory_consumers); + + assert_cmd_snapshot!(cmd); +} + +fn bind_to_settings(snapshot_name: &str) -> SettingsBindDropGuard { let mut settings = make_settings(); settings.set_snapshot_suffix(snapshot_name); settings.add_filter( - r"[^\s]+\#\d+\(can spill: (true|false)\) consumed .*?B", - "Consumer(can spill: bool) consumed XB", + r"[^\s]+\#\d+\(can spill: (true|false)\) consumed .*?B, peak .*?B", + "Consumer(can spill: bool) consumed XB, peak XB", ); settings.add_filter( r"Error: Failed to allocate additional .*? for .*? with .*? already allocated for this reservation - .*? 
remain available for the total pool", @@ -148,12 +269,20 @@ fn test_cli_top_memory_consumers<'a>( "Resources exhausted: Failed to allocate", ); + settings.bind_to_scope() +} + +#[test] +fn test_cli_with_unbounded_memory_pool() { + let mut settings = make_settings(); + + settings.set_snapshot_suffix("default"); + let _bound = settings.bind_to_scope(); let mut cmd = cli(); let sql = "select * from generate_series(1,500000) as t1(v1) order by v1;"; - cmd.args(["--memory-limit", "10M", "--command", sql]); - cmd.args(top_memory_consumers); + cmd.args(["--maxrows", "10", "--command", sql]); assert_cmd_snapshot!(cmd); } @@ -165,12 +294,31 @@ async fn test_cli() { return; } + let container = match setup_minio_container().await { + Ok(c) => c, + Err(e) if e.contains("toomanyrequests") => { + eprintln!("Skipping test: Docker pull rate limit reached: {e}"); + return; + } + e @ Err(_) => e.unwrap(), + }; + let settings = make_settings(); let _bound = settings.bind_to_scope(); + let port = container.get_host_port_ipv4(9000).await.unwrap(); + glob!("sql/integration/*.sql", |path| { let input = fs::read_to_string(path).unwrap(); - assert_cmd_snapshot!(cli().pass_stdin(input)) + assert_cmd_snapshot!( + cli() + .env_clear() + .env("AWS_ACCESS_KEY_ID", "TEST-DataFusionLogin") + .env("AWS_SECRET_ACCESS_KEY", "TEST-DataFusionPassword") + .env("AWS_ENDPOINT", format!("http://localhost:{port}")) + .env("AWS_ALLOW_HTTP", "true") + .pass_stdin(input) + ) }); } @@ -186,20 +334,24 @@ async fn test_aws_options() { let settings = make_settings(); let _bound = settings.bind_to_scope(); - let access_key_id = - env::var("AWS_ACCESS_KEY_ID").expect("AWS_ACCESS_KEY_ID is not set"); - let secret_access_key = - env::var("AWS_SECRET_ACCESS_KEY").expect("AWS_SECRET_ACCESS_KEY is not set"); - let endpoint_url = env::var("AWS_ENDPOINT").expect("AWS_ENDPOINT is not set"); + let container = match setup_minio_container().await { + Ok(c) => c, + Err(e) if e.contains("toomanyrequests") => { + 
eprintln!("Skipping test: Docker pull rate limit reached: {e}"); + return; + } + e @ Err(_) => e.unwrap(), + }; + let port = container.get_host_port_ipv4(9000).await.unwrap(); let input = format!( r#"CREATE EXTERNAL TABLE CARS STORED AS CSV LOCATION 's3://data/cars.csv' OPTIONS( - 'aws.access_key_id' '{access_key_id}', - 'aws.secret_access_key' '{secret_access_key}', - 'aws.endpoint' '{endpoint_url}', + 'aws.access_key_id' 'TEST-DataFusionLogin', + 'aws.secret_access_key' 'TEST-DataFusionPassword', + 'aws.endpoint' 'http://localhost:{port}', 'aws.allow_http' 'true' ); @@ -209,3 +361,186 @@ SELECT * FROM CARS limit 1; assert_cmd_snapshot!(cli().env_clear().pass_stdin(input)); } + +#[tokio::test] +async fn test_aws_region_auto_resolution() { + if env::var("TEST_STORAGE_INTEGRATION").is_err() { + eprintln!("Skipping external storages integration tests"); + return; + } + + let mut settings = make_settings(); + settings.add_filter(r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}Z", "[TIME]"); + let _bound = settings.bind_to_scope(); + + let bucket = "s3://clickhouse-public-datasets/hits_compatible/athena_partitioned/hits_1.parquet"; + let region = "us-east-1"; + + let input = format!( + r#"CREATE EXTERNAL TABLE hits +STORED AS PARQUET +LOCATION '{bucket}' +OPTIONS( + 'aws.region' '{region}', + 'aws.skip_signature' true +); + +SELECT COUNT(*) FROM hits; +"# + ); + + assert_cmd_snapshot!( + cli() + .env("RUST_LOG", "warn") + .env_remove("AWS_ENDPOINT") + .pass_stdin(input) + ); +} + +/// Ensure backtrace will be printed, if executing `datafusion-cli` with a query +/// that triggers error. 
+/// Example: +/// RUST_BACKTRACE=1 cargo run --features backtrace -- -c 'select pow(1,'foo');' +#[rstest] +#[case("SELECT pow(1,'foo')")] +#[case("SELECT CAST('not_a_number' AS INTEGER);")] +#[cfg(feature = "backtrace")] +fn test_backtrace_output(#[case] query: &str) { + let mut cmd = cli(); + // Use a command that will cause an error and trigger backtrace + cmd.args(["--command", query, "-q"]) + .env("RUST_BACKTRACE", "1"); // Enable backtrace + + let output = cmd.output().expect("Failed to execute command"); + let stdout = String::from_utf8_lossy(&output.stdout); + let stderr = String::from_utf8_lossy(&output.stderr); + let combined_output = format!("{}{}", stdout, stderr); + + // Assert that the output includes literal 'backtrace' + assert!( + combined_output.to_lowercase().contains("backtrace"), + "Expected output to contain 'backtrace', but got stdout: '{}' stderr: '{}'", + stdout, + stderr + ); +} + +#[tokio::test] +async fn test_s3_url_fallback() { + if env::var("TEST_STORAGE_INTEGRATION").is_err() { + eprintln!("Skipping external storages integration tests"); + return; + } + + let container = match setup_minio_container().await { + Ok(c) => c, + Err(e) if e.contains("toomanyrequests") => { + eprintln!("Skipping test: Docker pull rate limit reached: {e}"); + return; + } + e @ Err(_) => e.unwrap(), + }; + + let mut settings = make_settings(); + settings.set_snapshot_suffix("s3_url_fallback"); + let _bound = settings.bind_to_scope(); + + // Create a table using a prefix path (without trailing slash) + // This should trigger the fallback logic where head() fails on the prefix + // and list() is used to discover the actual files + let input = r#"CREATE EXTERNAL TABLE partitioned_data +STORED AS CSV +LOCATION 's3://data/partitioned_csv' +OPTIONS ( + 'format.has_header' 'false' +); + +SELECT * FROM partitioned_data ORDER BY column_1, column_2 LIMIT 5; +"#; + + assert_cmd_snapshot!(cli().with_minio(&container).await.pass_stdin(input)); +} + +/// Validate object 
store profiling output +#[tokio::test] +async fn test_object_store_profiling() { + if env::var("TEST_STORAGE_INTEGRATION").is_err() { + eprintln!("Skipping external storages integration tests"); + return; + } + + let container = match setup_minio_container().await { + Ok(c) => c, + Err(e) if e.contains("toomanyrequests") => { + eprintln!("Skipping test: Docker pull rate limit reached: {e}"); + return; + } + e @ Err(_) => e.unwrap(), + }; + let mut settings = make_settings(); + + // as the object store profiling contains timestamps and durations, we must + // filter them out to have stable snapshots + // + // Example line to filter: + // 2025-10-11T12:02:59.722646+00:00 operation=Get duration=0.001495s size=1006 path=cars.csv + // Output: + // operation=Get duration=[DURATION] size=1006 path=cars.csv + settings.add_filter( + r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}(?:\.\d+)?[+-]\d{2}:\d{2} operation=(Get|Put|Delete|List|Head) duration=\d+\.\d{6}s (size=\d+\s+)?path=(.*)", + " operation=$1 duration=[DURATION] ${2}path=$3", + ); + + // We also need to filter out the summary statistics (anything with an 's' at the end) + // Example line(s) to filter: + // | Get | duration | 5.000000s | 5.000000s | 5.000000s | | 1 | + settings.add_filter( + r"\| (Get|Put|Delete|List|Head)( +)\| duration \| .*? \| .*? \| .*? \| .*? \| (.*?) 
\|", + "| $1$2 | duration | ...NORMALIZED...| $3 |", + ); + + let _bound = settings.bind_to_scope(); + + let input = r#" + CREATE EXTERNAL TABLE CARS +STORED AS CSV +LOCATION 's3://data/cars.csv'; + +-- Initial query should not show any profiling as the object store is not instrumented yet +SELECT * from CARS LIMIT 1; +\object_store_profiling trace +-- Query again to see the full profiling output +SELECT * from CARS LIMIT 1; +\object_store_profiling summary +-- Query again to see the summarized profiling output +SELECT * from CARS LIMIT 1; +\object_store_profiling disabled +-- Final query should not show any profiling as we disabled it again +SELECT * from CARS LIMIT 1; +"#; + + assert_cmd_snapshot!(cli().with_minio(&container).await.pass_stdin(input)); +} + +/// Extension trait to Add the minio connection information to a Command +#[async_trait] +trait MinioCommandExt { + async fn with_minio(&mut self, container: &ContainerAsync) + -> &mut Self; +} + +#[async_trait] +impl MinioCommandExt for Command { + async fn with_minio( + &mut self, + container: &ContainerAsync, + ) -> &mut Self { + let port = container.get_host_port_ipv4(9000).await.unwrap(); + + self.env_clear() + .env("AWS_ACCESS_KEY_ID", "TEST-DataFusionLogin") + .env("AWS_SECRET_ACCESS_KEY", "TEST-DataFusionPassword") + .env("AWS_ENDPOINT", format!("http://localhost:{port}")) + .env("AWS_ALLOW_HTTP", "true") + } +} diff --git a/datafusion-cli/tests/snapshots/aws_region_auto_resolution.snap b/datafusion-cli/tests/snapshots/aws_region_auto_resolution.snap new file mode 100644 index 0000000000000..cd6d918b78d99 --- /dev/null +++ b/datafusion-cli/tests/snapshots/aws_region_auto_resolution.snap @@ -0,0 +1,29 @@ +--- +source: datafusion-cli/tests/cli_integration.rs +info: + program: datafusion-cli + args: [] + env: + AWS_ENDPOINT: "" + RUST_LOG: warn + stdin: "CREATE EXTERNAL TABLE hits\nSTORED AS PARQUET\nLOCATION 's3://clickhouse-public-datasets/hits_compatible/athena_partitioned/hits_1.parquet'\nOPTIONS(\n 
'aws.region' 'us-east-1',\n 'aws.skip_signature' true\n);\n\nSELECT COUNT(*) FROM hits;\n" +--- +success: true +exit_code: 0 +----- stdout ----- +[CLI_VERSION] +0 row(s) fetched. +[ELAPSED] + ++----------+ +| count(*) | ++----------+ +| 1000000 | ++----------+ +1 row(s) fetched. +[ELAPSED] + +\q + +----- stderr ----- +[[TIME] WARN datafusion_cli::exec] S3 region is incorrect, auto-detecting the correct region (this may be slow). Consider updating your region configuration. diff --git a/datafusion-cli/tests/snapshots/cli_explain_environment_overrides@explain_plan_environment_overrides.snap b/datafusion-cli/tests/snapshots/cli_explain_environment_overrides@explain_plan_environment_overrides.snap index 6b3a247dd7b82..1359cefbe71c7 100644 --- a/datafusion-cli/tests/snapshots/cli_explain_environment_overrides@explain_plan_environment_overrides.snap +++ b/datafusion-cli/tests/snapshots/cli_explain_environment_overrides@explain_plan_environment_overrides.snap @@ -7,7 +7,6 @@ info: - EXPLAIN SELECT 123 env: DATAFUSION_EXPLAIN_FORMAT: pgjson -snapshot_kind: text --- success: true exit_code: 0 diff --git a/datafusion-cli/tests/snapshots/cli_format@automatic.snap b/datafusion-cli/tests/snapshots/cli_format@automatic.snap index 2591f493e90a8..76b14d9a3a924 100644 --- a/datafusion-cli/tests/snapshots/cli_format@automatic.snap +++ b/datafusion-cli/tests/snapshots/cli_format@automatic.snap @@ -1,5 +1,5 @@ --- -source: tests/cli_integration.rs +source: datafusion-cli/tests/cli_integration.rs info: program: datafusion-cli args: diff --git a/datafusion-cli/tests/snapshots/cli_format@csv.snap b/datafusion-cli/tests/snapshots/cli_format@csv.snap index c41b042298eb0..2c969bd91d121 100644 --- a/datafusion-cli/tests/snapshots/cli_format@csv.snap +++ b/datafusion-cli/tests/snapshots/cli_format@csv.snap @@ -1,5 +1,5 @@ --- -source: tests/cli_integration.rs +source: datafusion-cli/tests/cli_integration.rs info: program: datafusion-cli args: diff --git 
a/datafusion-cli/tests/snapshots/cli_format@json.snap b/datafusion-cli/tests/snapshots/cli_format@json.snap index 8f804a337cce5..22a9cc4657a91 100644 --- a/datafusion-cli/tests/snapshots/cli_format@json.snap +++ b/datafusion-cli/tests/snapshots/cli_format@json.snap @@ -1,5 +1,5 @@ --- -source: tests/cli_integration.rs +source: datafusion-cli/tests/cli_integration.rs info: program: datafusion-cli args: diff --git a/datafusion-cli/tests/snapshots/cli_format@nd-json.snap b/datafusion-cli/tests/snapshots/cli_format@nd-json.snap index 7b4ce1e2530cf..513bcb7372ca6 100644 --- a/datafusion-cli/tests/snapshots/cli_format@nd-json.snap +++ b/datafusion-cli/tests/snapshots/cli_format@nd-json.snap @@ -1,5 +1,5 @@ --- -source: tests/cli_integration.rs +source: datafusion-cli/tests/cli_integration.rs info: program: datafusion-cli args: diff --git a/datafusion-cli/tests/snapshots/cli_format@table.snap b/datafusion-cli/tests/snapshots/cli_format@table.snap index 99914182462aa..8677847588385 100644 --- a/datafusion-cli/tests/snapshots/cli_format@table.snap +++ b/datafusion-cli/tests/snapshots/cli_format@table.snap @@ -1,5 +1,5 @@ --- -source: tests/cli_integration.rs +source: datafusion-cli/tests/cli_integration.rs info: program: datafusion-cli args: diff --git a/datafusion-cli/tests/snapshots/cli_format@tsv.snap b/datafusion-cli/tests/snapshots/cli_format@tsv.snap index 968268c31dd55..c56e60fcab155 100644 --- a/datafusion-cli/tests/snapshots/cli_format@tsv.snap +++ b/datafusion-cli/tests/snapshots/cli_format@tsv.snap @@ -1,5 +1,5 @@ --- -source: tests/cli_integration.rs +source: datafusion-cli/tests/cli_integration.rs info: program: datafusion-cli args: diff --git a/datafusion-cli/tests/snapshots/cli_quick_test@batch_size.snap b/datafusion-cli/tests/snapshots/cli_quick_test@batch_size.snap index c27d527df0b6a..9fd07fa6f4e1b 100644 --- a/datafusion-cli/tests/snapshots/cli_quick_test@batch_size.snap +++ b/datafusion-cli/tests/snapshots/cli_quick_test@batch_size.snap @@ -1,5 +1,5 @@ 
--- -source: tests/cli_integration.rs +source: datafusion-cli/tests/cli_integration.rs info: program: datafusion-cli args: diff --git a/datafusion-cli/tests/snapshots/cli_quick_test@can_see_indent_format.snap b/datafusion-cli/tests/snapshots/cli_quick_test@can_see_indent_format.snap index b2fb64709974e..8275041acaecc 100644 --- a/datafusion-cli/tests/snapshots/cli_quick_test@can_see_indent_format.snap +++ b/datafusion-cli/tests/snapshots/cli_quick_test@can_see_indent_format.snap @@ -5,7 +5,6 @@ info: args: - "--command" - EXPLAIN FORMAT indent SELECT 123 -snapshot_kind: text --- success: true exit_code: 0 @@ -15,7 +14,7 @@ exit_code: 0 | plan_type | plan | +---------------+------------------------------------------+ | logical_plan | Projection: Int64(123) | -| | EmptyRelation | +| | EmptyRelation: rows=1 | | physical_plan | ProjectionExec: expr=[123 as Int64(123)] | | | PlaceholderRowExec | | | | diff --git a/datafusion-cli/tests/snapshots/cli_quick_test@default_explain_plan.snap b/datafusion-cli/tests/snapshots/cli_quick_test@default_explain_plan.snap index 46ee6be64f624..8620f6da84488 100644 --- a/datafusion-cli/tests/snapshots/cli_quick_test@default_explain_plan.snap +++ b/datafusion-cli/tests/snapshots/cli_quick_test@default_explain_plan.snap @@ -5,7 +5,6 @@ info: args: - "--command" - EXPLAIN SELECT 123 -snapshot_kind: text --- success: true exit_code: 0 diff --git a/datafusion-cli/tests/snapshots/cli_quick_test@files.snap b/datafusion-cli/tests/snapshots/cli_quick_test@files.snap index 7c44e41729a17..df3a10b6bb54b 100644 --- a/datafusion-cli/tests/snapshots/cli_quick_test@files.snap +++ b/datafusion-cli/tests/snapshots/cli_quick_test@files.snap @@ -1,5 +1,5 @@ --- -source: tests/cli_integration.rs +source: datafusion-cli/tests/cli_integration.rs info: program: datafusion-cli args: diff --git a/datafusion-cli/tests/snapshots/cli_quick_test@statements.snap b/datafusion-cli/tests/snapshots/cli_quick_test@statements.snap index 3b975bb6a927d..a394458768d1b 100644 
--- a/datafusion-cli/tests/snapshots/cli_quick_test@statements.snap +++ b/datafusion-cli/tests/snapshots/cli_quick_test@statements.snap @@ -1,5 +1,5 @@ --- -source: tests/cli_integration.rs +source: datafusion-cli/tests/cli_integration.rs info: program: datafusion-cli args: diff --git a/datafusion-cli/tests/snapshots/cli_top_memory_consumers@no_track.snap b/datafusion-cli/tests/snapshots/cli_top_memory_consumers@no_track.snap index 89b646a531f8b..fe454595eb4bc 100644 --- a/datafusion-cli/tests/snapshots/cli_top_memory_consumers@no_track.snap +++ b/datafusion-cli/tests/snapshots/cli_top_memory_consumers@no_track.snap @@ -14,7 +14,7 @@ success: false exit_code: 1 ----- stdout ----- [CLI_VERSION] -Error: Not enough memory to continue external sort. Consider increasing the memory limit, or decreasing sort_spill_reservation_bytes +Error: Not enough memory to continue external sort. Consider increasing the memory limit config: 'datafusion.runtime.memory_limit', or decreasing the config: 'datafusion.execution.sort_spill_reservation_bytes'. caused by Resources exhausted: Failed to allocate diff --git a/datafusion-cli/tests/snapshots/cli_top_memory_consumers@top2.snap b/datafusion-cli/tests/snapshots/cli_top_memory_consumers@top2.snap index ed925a6f64613..bb30e387166bc 100644 --- a/datafusion-cli/tests/snapshots/cli_top_memory_consumers@top2.snap +++ b/datafusion-cli/tests/snapshots/cli_top_memory_consumers@top2.snap @@ -14,11 +14,11 @@ success: false exit_code: 1 ----- stdout ----- [CLI_VERSION] -Error: Not enough memory to continue external sort. Consider increasing the memory limit, or decreasing sort_spill_reservation_bytes +Error: Not enough memory to continue external sort. Consider increasing the memory limit config: 'datafusion.runtime.memory_limit', or decreasing the config: 'datafusion.execution.sort_spill_reservation_bytes'. 
caused by -Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: - Consumer(can spill: bool) consumed XB, - Consumer(can spill: bool) consumed XB. +Resources exhausted: Additional allocation failed for ExternalSorter[0] with top memory consumers (across reservations) as: + Consumer(can spill: bool) consumed XB, peak XB, + Consumer(can spill: bool) consumed XB, peak XB. Error: Failed to allocate ----- stderr ----- diff --git a/datafusion-cli/tests/snapshots/cli_top_memory_consumers@top3_default.snap b/datafusion-cli/tests/snapshots/cli_top_memory_consumers@top3_default.snap index f35e3b117178f..891d72e3cc639 100644 --- a/datafusion-cli/tests/snapshots/cli_top_memory_consumers@top3_default.snap +++ b/datafusion-cli/tests/snapshots/cli_top_memory_consumers@top3_default.snap @@ -12,12 +12,12 @@ success: false exit_code: 1 ----- stdout ----- [CLI_VERSION] -Error: Not enough memory to continue external sort. Consider increasing the memory limit, or decreasing sort_spill_reservation_bytes +Error: Not enough memory to continue external sort. Consider increasing the memory limit config: 'datafusion.runtime.memory_limit', or decreasing the config: 'datafusion.execution.sort_spill_reservation_bytes'. caused by -Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: - Consumer(can spill: bool) consumed XB, - Consumer(can spill: bool) consumed XB, - Consumer(can spill: bool) consumed XB. +Resources exhausted: Additional allocation failed for ExternalSorter[0] with top memory consumers (across reservations) as: + Consumer(can spill: bool) consumed XB, peak XB, + Consumer(can spill: bool) consumed XB, peak XB, + Consumer(can spill: bool) consumed XB, peak XB. 
Error: Failed to allocate ----- stderr ----- diff --git a/datafusion-cli/tests/snapshots/cli_top_memory_consumers_with_mem_pool_type@no_track.snap b/datafusion-cli/tests/snapshots/cli_top_memory_consumers_with_mem_pool_type@no_track.snap new file mode 100644 index 0000000000000..25267ea1617e5 --- /dev/null +++ b/datafusion-cli/tests/snapshots/cli_top_memory_consumers_with_mem_pool_type@no_track.snap @@ -0,0 +1,23 @@ +--- +source: datafusion-cli/tests/cli_integration.rs +info: + program: datafusion-cli + args: + - "--memory-limit" + - 10M + - "--mem-pool-type" + - fair + - "--command" + - "select * from generate_series(1,500000) as t1(v1) order by v1;" + - "--top-memory-consumers" + - "0" +--- +success: false +exit_code: 1 +----- stdout ----- +[CLI_VERSION] +Error: Not enough memory to continue external sort. Consider increasing the memory limit config: 'datafusion.runtime.memory_limit', or decreasing the config: 'datafusion.execution.sort_spill_reservation_bytes'. +caused by +Resources exhausted: Failed to allocate + +----- stderr ----- diff --git a/datafusion-cli/tests/snapshots/cli_top_memory_consumers_with_mem_pool_type@top2.snap b/datafusion-cli/tests/snapshots/cli_top_memory_consumers_with_mem_pool_type@top2.snap new file mode 100644 index 0000000000000..6515050047107 --- /dev/null +++ b/datafusion-cli/tests/snapshots/cli_top_memory_consumers_with_mem_pool_type@top2.snap @@ -0,0 +1,26 @@ +--- +source: datafusion-cli/tests/cli_integration.rs +info: + program: datafusion-cli + args: + - "--memory-limit" + - 10M + - "--mem-pool-type" + - fair + - "--command" + - "select * from generate_series(1,500000) as t1(v1) order by v1;" + - "--top-memory-consumers" + - "2" +--- +success: false +exit_code: 1 +----- stdout ----- +[CLI_VERSION] +Error: Not enough memory to continue external sort. Consider increasing the memory limit config: 'datafusion.runtime.memory_limit', or decreasing the config: 'datafusion.execution.sort_spill_reservation_bytes'. 
+caused by +Resources exhausted: Additional allocation failed for ExternalSorter[0] with top memory consumers (across reservations) as: + Consumer(can spill: bool) consumed XB, peak XB, + Consumer(can spill: bool) consumed XB, peak XB. +Error: Failed to allocate + +----- stderr ----- diff --git a/datafusion-cli/tests/snapshots/cli_with_unbounded_memory_pool@default.snap b/datafusion-cli/tests/snapshots/cli_with_unbounded_memory_pool@default.snap new file mode 100644 index 0000000000000..7bdcd63dc7be6 --- /dev/null +++ b/datafusion-cli/tests/snapshots/cli_with_unbounded_memory_pool@default.snap @@ -0,0 +1,36 @@ +--- +source: datafusion-cli/tests/cli_integration.rs +info: + program: datafusion-cli + args: + - "--maxrows" + - "10" + - "--command" + - "select * from generate_series(1,500000) as t1(v1) order by v1;" +--- +success: true +exit_code: 0 +----- stdout ----- +[CLI_VERSION] ++----+ +| v1 | ++----+ +| 1 | +| 2 | +| 3 | +| 4 | +| 5 | +| 6 | +| 7 | +| 8 | +| 9 | +| 10 | +| . | +| . | +| . | ++----+ +500000 row(s) fetched. (First 10 displayed. 
Use --maxrows to adjust) +[ELAPSED] + + +----- stderr ----- diff --git a/datafusion-cli/tests/snapshots/object_store_profiling.snap b/datafusion-cli/tests/snapshots/object_store_profiling.snap new file mode 100644 index 0000000000000..029b07c324f5d --- /dev/null +++ b/datafusion-cli/tests/snapshots/object_store_profiling.snap @@ -0,0 +1,83 @@ +--- +source: datafusion-cli/tests/cli_integration.rs +info: + program: datafusion-cli + args: [] + env: + AWS_ACCESS_KEY_ID: TEST-DataFusionLogin + AWS_ALLOW_HTTP: "true" + AWS_ENDPOINT: "http://localhost:55057" + AWS_SECRET_ACCESS_KEY: TEST-DataFusionPassword + stdin: "\n CREATE EXTERNAL TABLE CARS\nSTORED AS CSV\nLOCATION 's3://data/cars.csv';\n\n-- Initial query should not show any profiling as the object store is not instrumented yet\nSELECT * from CARS LIMIT 1;\n\\object_store_profiling trace\n-- Query again to see the full profiling output\nSELECT * from CARS LIMIT 1;\n\\object_store_profiling summary\n-- Query again to see the summarized profiling output\nSELECT * from CARS LIMIT 1;\n\\object_store_profiling disabled\n-- Final query should not show any profiling as we disabled it again\nSELECT * from CARS LIMIT 1;\n" +snapshot_kind: text +--- +success: true +exit_code: 0 +----- stdout ----- +[CLI_VERSION] +0 row(s) fetched. +[ELAPSED] + ++-----+-------+---------------------+ +| car | speed | time | ++-----+-------+---------------------+ +| red | 20.0 | 1996-04-12T12:05:03 | ++-----+-------+---------------------+ +1 row(s) fetched. +[ELAPSED] + +ObjectStore Profile mode set to Trace ++-----+-------+---------------------+ +| car | speed | time | ++-----+-------+---------------------+ +| red | 20.0 | 1996-04-12T12:05:03 | ++-----+-------+---------------------+ +1 row(s) fetched. 
+[ELAPSED] + +Object Store Profiling +Instrumented Object Store: instrument_mode: Trace, inner: AmazonS3(data) + operation=Head duration=[DURATION] path=cars.csv + operation=Get duration=[DURATION] size=1006 path=cars.csv + +Summaries: ++-----------+----------+-----------+-----------+-----------+-----------+-------+ +| Operation | Metric | min | max | avg | sum | count | ++-----------+----------+-----------+-----------+-----------+-----------+-------+ +| Get | duration | ...NORMALIZED...| 1 | +| Get | size | 1006 B | 1006 B | 1006 B | 1006 B | 1 | +| Head | duration | ...NORMALIZED...| 1 | +| Head | size | | | | | 1 | ++-----------+----------+-----------+-----------+-----------+-----------+-------+ +ObjectStore Profile mode set to Summary ++-----+-------+---------------------+ +| car | speed | time | ++-----+-------+---------------------+ +| red | 20.0 | 1996-04-12T12:05:03 | ++-----+-------+---------------------+ +1 row(s) fetched. +[ELAPSED] + +Object Store Profiling +Instrumented Object Store: instrument_mode: Summary, inner: AmazonS3(data) +Summaries: ++-----------+----------+-----------+-----------+-----------+-----------+-------+ +| Operation | Metric | min | max | avg | sum | count | ++-----------+----------+-----------+-----------+-----------+-----------+-------+ +| Get | duration | ...NORMALIZED...| 1 | +| Get | size | 1006 B | 1006 B | 1006 B | 1006 B | 1 | +| Head | duration | ...NORMALIZED...| 1 | +| Head | size | | | | | 1 | ++-----------+----------+-----------+-----------+-----------+-----------+-------+ +ObjectStore Profile mode set to Disabled ++-----+-------+---------------------+ +| car | speed | time | ++-----+-------+---------------------+ +| red | 20.0 | 1996-04-12T12:05:03 | ++-----+-------+---------------------+ +1 row(s) fetched. 
+[ELAPSED] + +\q + +----- stderr ----- diff --git a/datafusion-cli/tests/snapshots/s3_url_fallback@s3_url_fallback.snap b/datafusion-cli/tests/snapshots/s3_url_fallback@s3_url_fallback.snap new file mode 100644 index 0000000000000..07036d041b42c --- /dev/null +++ b/datafusion-cli/tests/snapshots/s3_url_fallback@s3_url_fallback.snap @@ -0,0 +1,34 @@ +--- +source: datafusion-cli/tests/cli_integration.rs +info: + program: datafusion-cli + args: [] + env: + AWS_ACCESS_KEY_ID: TEST-DataFusionLogin + AWS_ALLOW_HTTP: "true" + AWS_ENDPOINT: "http://localhost:32771" + AWS_SECRET_ACCESS_KEY: TEST-DataFusionPassword + stdin: "CREATE EXTERNAL TABLE partitioned_data\nSTORED AS CSV\nLOCATION 's3://data/partitioned_csv'\nOPTIONS (\n 'format.has_header' 'false'\n);\n\nSELECT * FROM partitioned_data ORDER BY column_1, column_2 LIMIT 5;\n" +--- +success: true +exit_code: 0 +----- stdout ----- +[CLI_VERSION] +0 row(s) fetched. +[ELAPSED] + ++----------+----------+----------+ +| column_1 | column_2 | column_3 | ++----------+----------+----------+ +| 0 | 0 | true | +| 0 | 1 | false | +| 0 | 2 | true | +| 0 | 3 | false | +| 0 | 4 | true | ++----------+----------+----------+ +5 row(s) fetched. +[ELAPSED] + +\q + +----- stderr ----- diff --git a/datafusion-examples/Cargo.toml b/datafusion-examples/Cargo.toml index b31708a5c1cc7..e56f5ad6b8ca7 100644 --- a/datafusion-examples/Cargo.toml +++ b/datafusion-examples/Cargo.toml @@ -29,55 +29,50 @@ license = { workspace = true } authors = { workspace = true } rust-version = { workspace = true } +# Note: add additional linter rules in lib.rs. 
+# Rust does not support workspace + new linter rules in subcrates yet +# https://github.com/rust-lang/cargo/issues/13157 [lints] workspace = true -[[example]] -name = "flight_sql_server" -path = "examples/flight/flight_sql_server.rs" - -[[example]] -name = "flight_server" -path = "examples/flight/flight_server.rs" - -[[example]] -name = "flight_client" -path = "examples/flight/flight_client.rs" - -[[example]] -name = "dataframe_to_s3" -path = "examples/external_dependency/dataframe-to-s3.rs" - -[[example]] -name = "query_aws_s3" -path = "examples/external_dependency/query-aws-s3.rs" +[dependencies] +arrow = { workspace = true } +arrow-schema = { workspace = true } +datafusion = { workspace = true, default-features = true, features = ["parquet_encryption"] } +datafusion-common = { workspace = true } +nom = "8.0.0" +tempfile = { workspace = true } +tokio = { workspace = true, features = ["rt-multi-thread", "parking_lot", "fs"] } [dev-dependencies] -arrow = { workspace = true } -# arrow_schema is required for record_batch! 
macro :sad: arrow-flight = { workspace = true } -arrow-schema = { workspace = true } async-trait = { workspace = true } bytes = { workspace = true } dashmap = { workspace = true } # note only use main datafusion crate for examples -datafusion = { workspace = true, default-features = true } -datafusion-ffi = { workspace = true } +base64 = "0.22.1" +datafusion-expr = { workspace = true } +datafusion-physical-expr-adapter = { workspace = true } datafusion-proto = { workspace = true } +datafusion-sql = { workspace = true } env_logger = { workspace = true } futures = { workspace = true } +insta = { workspace = true } log = { workspace = true } mimalloc = { version = "0.1", default-features = false } object_store = { workspace = true, features = ["aws", "http"] } prost = { workspace = true } -tempfile = { workspace = true } +rand = { workspace = true } +serde = { version = "1", features = ["derive"] } +serde_json = { workspace = true } +strum = { workspace = true } +strum_macros = { workspace = true } test-utils = { path = "../test-utils" } -tokio = { workspace = true, features = ["rt-multi-thread", "parking_lot"] } -tonic = "0.12.1" +tonic = "0.14" tracing = { version = "0.1" } tracing-subscriber = { version = "0.3" } url = { workspace = true } -uuid = "1.17" +uuid = { workspace = true } [target.'cfg(not(target_os = "windows"))'.dev-dependencies] -nix = { version = "0.30.1", features = ["fs"] } +nix = { version = "0.31.1", features = ["fs"] } diff --git a/datafusion-examples/README.md b/datafusion-examples/README.md index 3ba4c77cd84c3..2cf0ec52409f8 100644 --- a/datafusion-examples/README.md +++ b/datafusion-examples/README.md @@ -39,51 +39,181 @@ git submodule update --init # Change to the examples directory cd datafusion-examples/examples -# Run the `dataframe` example: -# ... 
use the equivalent for other examples -cargo run --example dataframe +# Run all examples in a group +cargo run --example <group> -- all + +# Run a specific example within a group +cargo run --example <group> -- <subcommand> + +# Run all examples in the `dataframe` group +cargo run --example dataframe -- all + +# Run a single example from the `dataframe` group +# (apply the same pattern for any other group) +cargo run --example dataframe -- dataframe ``` -## Single Process - -- [`advanced_udaf.rs`](examples/advanced_udaf.rs): Define and invoke a more complicated User Defined Aggregate Function (UDAF) -- [`advanced_udf.rs`](examples/advanced_udf.rs): Define and invoke a more complicated User Defined Scalar Function (UDF) -- [`advanced_udwf.rs`](examples/advanced_udwf.rs): Define and invoke a more complicated User Defined Window Function (UDWF) -- [`advanced_parquet_index.rs`](examples/advanced_parquet_index.rs): Creates a detailed secondary index that covers the contents of several parquet files -- [`analyzer_rule.rs`](examples/analyzer_rule.rs): Use a custom AnalyzerRule to change a query's semantics (row level access control) -- [`catalog.rs`](examples/catalog.rs): Register the table into a custom catalog -- [`composed_extension_codec`](examples/composed_extension_codec.rs): Example of using multiple extension codecs for serialization / deserialization -- [`csv_sql_streaming.rs`](examples/csv_sql_streaming.rs): Build and run a streaming query plan from a SQL statement against a local CSV file -- [`csv_json_opener.rs`](examples/csv_json_opener.rs): Use low level `FileOpener` APIs to read CSV/JSON into Arrow `RecordBatch`es -- [`custom_datasource.rs`](examples/custom_datasource.rs): Run queries against a custom datasource (TableProvider) -- [`custom_file_format.rs`](examples/custom_file_format.rs): Write data to a custom file format -- [`dataframe-to-s3.rs`](examples/external_dependency/dataframe-to-s3.rs): Run a query using a DataFrame against a parquet file from s3 and writing back to s3 --
[`dataframe.rs`](examples/dataframe.rs): Run a query using a DataFrame API against parquet files, csv files, and in-memory data, including multiple subqueries. Also demonstrates the various methods to write out a DataFrame to a table, parquet file, csv file, and json file. -- [`deserialize_to_struct.rs`](examples/deserialize_to_struct.rs): Convert query results (Arrow ArrayRefs) into Rust structs -- [`expr_api.rs`](examples/expr_api.rs): Create, execute, simplify, analyze and coerce `Expr`s -- [`file_stream_provider.rs`](examples/file_stream_provider.rs): Run a query on `FileStreamProvider` which implements `StreamProvider` for reading and writing to arbitrary stream sources / sinks. -- [`flight_sql_server.rs`](examples/flight/flight_sql_server.rs): Run DataFusion as a standalone process and execute SQL queries from JDBC clients -- [`function_factory.rs`](examples/function_factory.rs): Register `CREATE FUNCTION` handler to implement SQL macros -- [`optimizer_rule.rs`](examples/optimizer_rule.rs): Use a custom OptimizerRule to replace certain predicates -- [`parquet_index.rs`](examples/parquet_index.rs): Create an secondary index over several parquet files and use it to speed up queries -- [`parquet_exec_visitor.rs`](examples/parquet_exec_visitor.rs): Extract statistics by visiting an ExecutionPlan after execution -- [`parse_sql_expr.rs`](examples/parse_sql_expr.rs): Parse SQL text into DataFusion `Expr`. 
-- [`plan_to_sql.rs`](examples/plan_to_sql.rs): Generate SQL from DataFusion `Expr` and `LogicalPlan` -- [`planner_api.rs`](examples/planner_api.rs) APIs to manipulate logical and physical plans -- [`pruning.rs`](examples/pruning.rs): Use pruning to rule out files based on statistics -- [`query-aws-s3.rs`](examples/external_dependency/query-aws-s3.rs): Configure `object_store` and run a query against files stored in AWS S3 -- [`query-http-csv.rs`](examples/query-http-csv.rs): Configure `object_store` and run a query against files vi HTTP -- [`regexp.rs`](examples/regexp.rs): Examples of using regular expression functions -- [`remote_catalog.rs`](examples/regexp.rs): Examples of interfacing with a remote catalog (e.g. over a network) -- [`simple_udaf.rs`](examples/simple_udaf.rs): Define and invoke a User Defined Aggregate Function (UDAF) -- [`simple_udf.rs`](examples/simple_udf.rs): Define and invoke a User Defined Scalar Function (UDF) -- [`simple_udfw.rs`](examples/simple_udwf.rs): Define and invoke a User Defined Window Function (UDWF) -- [`sql_analysis.rs`](examples/sql_analysis.rs): Analyse SQL queries with DataFusion structures -- [`sql_frontend.rs`](examples/sql_frontend.rs): Create LogicalPlans (only) from sql strings -- [`sql_dialect.rs`](examples/sql_dialect.rs): Example of implementing a custom SQL dialect on top of `DFParser` -- [`sql_query.rs`](examples/memtable.rs): Query data using SQL (in memory `RecordBatches`, local Parquet files) -- [`date_time_function.rs`](examples/date_time_function.rs): Examples of date-time related functions and queries. - -## Distributed - -- [`flight_client.rs`](examples/flight/flight_client.rs) and [`flight_server.rs`](examples/flight/flight_server.rs): Run DataFusion as a standalone process and execute SQL queries from a client using the Flight protocol. 
+## Builtin Functions Examples + +### Group: `builtin_functions` + +#### Category: Single Process + +| Subcommand | File Path | Description | +| ---------------- | ----------------------------------------------------------------------------------------- | ---------------------------------------------------------- | +| date_time | [`builtin_functions/date_time.rs`](examples/builtin_functions/date_time.rs) | Examples of date-time related functions and queries | +| function_factory | [`builtin_functions/function_factory.rs`](examples/builtin_functions/function_factory.rs) | Register `CREATE FUNCTION` handler to implement SQL macros | +| regexp | [`builtin_functions/regexp.rs`](examples/builtin_functions/regexp.rs) | Examples of using regular expression functions | + +## Custom Data Source Examples + +### Group: `custom_data_source` + +#### Category: Single Process + +| Subcommand | File Path | Description | +| --------------------- | ----------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------- | +| adapter_serialization | [`custom_data_source/adapter_serialization.rs`](examples/custom_data_source/adapter_serialization.rs) | Preserve custom PhysicalExprAdapter information during plan serialization using PhysicalExtensionCodec interception | +| csv_json_opener | [`custom_data_source/csv_json_opener.rs`](examples/custom_data_source/csv_json_opener.rs) | Use low-level FileOpener APIs for CSV/JSON | +| csv_sql_streaming | [`custom_data_source/csv_sql_streaming.rs`](examples/custom_data_source/csv_sql_streaming.rs) | Run a streaming SQL query against CSV data | +| custom_datasource | [`custom_data_source/custom_datasource.rs`](examples/custom_data_source/custom_datasource.rs) | Query a custom TableProvider | +| custom_file_casts | 
[`custom_data_source/custom_file_casts.rs`](examples/custom_data_source/custom_file_casts.rs) | Implement custom casting rules | +| custom_file_format | [`custom_data_source/custom_file_format.rs`](examples/custom_data_source/custom_file_format.rs) | Write to a custom file format | +| default_column_values | [`custom_data_source/default_column_values.rs`](examples/custom_data_source/default_column_values.rs) | Custom default values using metadata | +| file_stream_provider | [`custom_data_source/file_stream_provider.rs`](examples/custom_data_source/file_stream_provider.rs) | Read/write via FileStreamProvider for streams | + +## Data IO Examples + +### Group: `data_io` + +#### Category: Single Process + +| Subcommand | File Path | Description | +| -------------------- | ----------------------------------------------------------------------------------------- | ------------------------------------------------------ | +| catalog | [`data_io/catalog.rs`](examples/data_io/catalog.rs) | Register tables into a custom catalog | +| json_shredding | [`data_io/json_shredding.rs`](examples/data_io/json_shredding.rs) | Implement filter rewriting for JSON shredding | +| parquet_adv_idx | [`data_io/parquet_advanced_index.rs`](examples/data_io/parquet_advanced_index.rs) | Create a secondary index across multiple parquet files | +| parquet_emb_idx | [`data_io/parquet_embedded_index.rs`](examples/data_io/parquet_embedded_index.rs) | Store a custom index inside Parquet files | +| parquet_enc | [`data_io/parquet_encrypted.rs`](examples/data_io/parquet_encrypted.rs) | Read & write encrypted Parquet files | +| parquet_enc_with_kms | [`data_io/parquet_encrypted_with_kms.rs`](examples/data_io/parquet_encrypted_with_kms.rs) | Encrypted Parquet I/O using a KMS-backed factory | +| parquet_exec_visitor | [`data_io/parquet_exec_visitor.rs`](examples/data_io/parquet_exec_visitor.rs) | Extract statistics by visiting an ExecutionPlan | +| parquet_idx | 
[`data_io/parquet_index.rs`](examples/data_io/parquet_index.rs) | Create a secondary index | +| query_http_csv | [`data_io/query_http_csv.rs`](examples/data_io/query_http_csv.rs) | Query CSV files via HTTP | +| remote_catalog | [`data_io/remote_catalog.rs`](examples/data_io/remote_catalog.rs) | Interact with a remote catalog | + +## DataFrame Examples + +### Group: `dataframe` + +#### Category: Single Process + +| Subcommand | File Path | Description | +| --------------------- | ----------------------------------------------------------------------------------- | ------------------------------------------------------- | +| cache_factory | [`dataframe/cache_factory.rs`](examples/dataframe/cache_factory.rs) | Custom lazy caching for DataFrames using `CacheFactory` | +| dataframe | [`dataframe/dataframe.rs`](examples/dataframe/dataframe.rs) | Query DataFrames from various sources and write output | +| deserialize_to_struct | [`dataframe/deserialize_to_struct.rs`](examples/dataframe/deserialize_to_struct.rs) | Convert Arrow arrays into Rust structs | + +## Execution Monitoring Examples + +### Group: `execution_monitoring` + +#### Category: Single Process + +| Subcommand | File Path | Description | +| ------------------ | ------------------------------------------------------------------------------------------------------------------- | ---------------------------------------- | +| mem_pool_exec_plan | [`execution_monitoring/memory_pool_execution_plan.rs`](examples/execution_monitoring/memory_pool_execution_plan.rs) | Memory-aware ExecutionPlan with spilling | +| mem_pool_tracking | [`execution_monitoring/memory_pool_tracking.rs`](examples/execution_monitoring/memory_pool_tracking.rs) | Demonstrates memory tracking | +| tracing | [`execution_monitoring/tracing.rs`](examples/execution_monitoring/tracing.rs) | Demonstrates tracing integration | + +## External Dependency Examples + +### Group: `external_dependency` + +#### Category: Single Process + +| Subcommand | File 
Path | Description | +| --------------- | ------------------------------------------------------------------------------------------- | ---------------------------------------- | +| dataframe_to_s3 | [`external_dependency/dataframe_to_s3.rs`](examples/external_dependency/dataframe_to_s3.rs) | Query DataFrames and write results to S3 | +| query_aws_s3 | [`external_dependency/query_aws_s3.rs`](examples/external_dependency/query_aws_s3.rs) | Query S3-backed data using object_store | + +## Flight Examples + +### Group: `flight` + +#### Category: Distributed + +| Subcommand | File Path | Description | +| ---------- | ------------------------------------------------------- | ------------------------------------------------------ | +| client | [`flight/client.rs`](examples/flight/client.rs) | Execute SQL queries via Arrow Flight protocol | +| server | [`flight/server.rs`](examples/flight/server.rs) | Run DataFusion server accepting FlightSQL/JDBC queries | +| sql_server | [`flight/sql_server.rs`](examples/flight/sql_server.rs) | Standalone SQL server for JDBC clients | + +## Proto Examples + +### Group: `proto` + +#### Category: Single Process + +| Subcommand | File Path | Description | +| ------------------------ | --------------------------------------------------------------------------------- | ----------------------------------------------------------------------------- | +| composed_extension_codec | [`proto/composed_extension_codec.rs`](examples/proto/composed_extension_codec.rs) | Use multiple extension codecs for serialization/deserialization | +| expression_deduplication | [`proto/expression_deduplication.rs`](examples/proto/expression_deduplication.rs) | Example of expression caching/deduplication using the codec decorator pattern | + +## Query Planning Examples + +### Group: `query_planning` + +#### Category: Single Process + +| Subcommand | File Path | Description | +| -------------- | 
------------------------------------------------------------------------------- | ------------------------------------------------------ | +| analyzer_rule | [`query_planning/analyzer_rule.rs`](examples/query_planning/analyzer_rule.rs) | Custom AnalyzerRule to change query semantics | +| expr_api | [`query_planning/expr_api.rs`](examples/query_planning/expr_api.rs) | Create, execute, analyze, and coerce Exprs | +| optimizer_rule | [`query_planning/optimizer_rule.rs`](examples/query_planning/optimizer_rule.rs) | Replace predicates via a custom OptimizerRule | +| parse_sql_expr | [`query_planning/parse_sql_expr.rs`](examples/query_planning/parse_sql_expr.rs) | Parse SQL into DataFusion Expr | +| plan_to_sql | [`query_planning/plan_to_sql.rs`](examples/query_planning/plan_to_sql.rs) | Generate SQL from expressions or plans | +| planner_api | [`query_planning/planner_api.rs`](examples/query_planning/planner_api.rs) | APIs for logical and physical plan manipulation | +| pruning | [`query_planning/pruning.rs`](examples/query_planning/pruning.rs) | Use pruning to skip irrelevant files | +| thread_pools | [`query_planning/thread_pools.rs`](examples/query_planning/thread_pools.rs) | Configure custom thread pools for DataFusion execution | + +## Relation Planner Examples + +### Group: `relation_planner` + +#### Category: Single Process + +| Subcommand | File Path | Description | +| --------------- | ------------------------------------------------------------------------------------- | ------------------------------------------ | +| match_recognize | [`relation_planner/match_recognize.rs`](examples/relation_planner/match_recognize.rs) | Implement MATCH_RECOGNIZE pattern matching | +| pivot_unpivot | [`relation_planner/pivot_unpivot.rs`](examples/relation_planner/pivot_unpivot.rs) | Implement PIVOT / UNPIVOT | +| table_sample | [`relation_planner/table_sample.rs`](examples/relation_planner/table_sample.rs) | Implement TABLESAMPLE | + +## SQL Ops Examples + +### Group: 
`sql_ops` + +#### Category: Single Process + +| Subcommand | File Path | Description | +| ----------------- | ----------------------------------------------------------------------- | -------------------------------------------------- | +| analysis | [`sql_ops/analysis.rs`](examples/sql_ops/analysis.rs) | Analyze SQL queries | +| custom_sql_parser | [`sql_ops/custom_sql_parser.rs`](examples/sql_ops/custom_sql_parser.rs) | Implement a custom SQL parser to extend DataFusion | +| frontend | [`sql_ops/frontend.rs`](examples/sql_ops/frontend.rs) | Build LogicalPlans from SQL | +| query | [`sql_ops/query.rs`](examples/sql_ops/query.rs) | Query data using SQL | + +## UDF Examples + +### Group: `udf` + +#### Category: Single Process + +| Subcommand | File Path | Description | +| ---------- | ------------------------------------------------------- | ----------------------------------------------- | +| adv_udaf | [`udf/advanced_udaf.rs`](examples/udf/advanced_udaf.rs) | Advanced User Defined Aggregate Function (UDAF) | +| adv_udf | [`udf/advanced_udf.rs`](examples/udf/advanced_udf.rs) | Advanced User Defined Scalar Function (UDF) | +| adv_udwf | [`udf/advanced_udwf.rs`](examples/udf/advanced_udwf.rs) | Advanced User Defined Window Function (UDWF) | +| async_udf | [`udf/async_udf.rs`](examples/udf/async_udf.rs) | Asynchronous User Defined Scalar Function | +| udaf | [`udf/simple_udaf.rs`](examples/udf/simple_udaf.rs) | Simple UDAF example | +| udf | [`udf/simple_udf.rs`](examples/udf/simple_udf.rs) | Simple UDF example | +| udtf | [`udf/simple_udtf.rs`](examples/udf/simple_udtf.rs) | Simple UDTF example | +| udwf | [`udf/simple_udwf.rs`](examples/udf/simple_udwf.rs) | Simple UDWF example | diff --git a/datafusion-examples/data/README.md b/datafusion-examples/data/README.md new file mode 100644 index 0000000000000..e8296a8856e60 --- /dev/null +++ b/datafusion-examples/data/README.md @@ -0,0 +1,25 @@ + + +## Example datasets + +| Filename | Path | Description | +| ----------- | 
--------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| `cars.csv` | [`data/csv/cars.csv`](./csv/cars.csv) | Time-series–like dataset containing car identifiers, speed values, and timestamps. Used in window function and time-based query examples (e.g. ordering, window frames). | +| `regex.csv` | [`data/csv/regex.csv`](./csv/regex.csv) | Dataset for regular expression examples. Contains input values, regex patterns, replacement strings, and optional flags. Covers ASCII, Unicode, and locale-specific text processing. | diff --git a/datafusion-examples/data/csv/cars.csv b/datafusion-examples/data/csv/cars.csv new file mode 100644 index 0000000000000..bc40f3b01e7a5 --- /dev/null +++ b/datafusion-examples/data/csv/cars.csv @@ -0,0 +1,26 @@ +car,speed,time +red,20.0,1996-04-12T12:05:03.000000000 +red,20.3,1996-04-12T12:05:04.000000000 +red,21.4,1996-04-12T12:05:05.000000000 +red,21.5,1996-04-12T12:05:06.000000000 +red,19.0,1996-04-12T12:05:07.000000000 +red,18.0,1996-04-12T12:05:08.000000000 +red,17.0,1996-04-12T12:05:09.000000000 +red,7.0,1996-04-12T12:05:10.000000000 +red,7.1,1996-04-12T12:05:11.000000000 +red,7.2,1996-04-12T12:05:12.000000000 +red,3.0,1996-04-12T12:05:13.000000000 +red,1.0,1996-04-12T12:05:14.000000000 +red,0.0,1996-04-12T12:05:15.000000000 +green,10.0,1996-04-12T12:05:03.000000000 +green,10.3,1996-04-12T12:05:04.000000000 +green,10.4,1996-04-12T12:05:05.000000000 +green,10.5,1996-04-12T12:05:06.000000000 +green,11.0,1996-04-12T12:05:07.000000000 +green,12.0,1996-04-12T12:05:08.000000000 +green,14.0,1996-04-12T12:05:09.000000000 +green,15.0,1996-04-12T12:05:10.000000000 +green,15.1,1996-04-12T12:05:11.000000000 +green,15.2,1996-04-12T12:05:12.000000000 +green,8.0,1996-04-12T12:05:13.000000000 +green,2.0,1996-04-12T12:05:14.000000000 diff --git a/datafusion-examples/data/csv/regex.csv 
b/datafusion-examples/data/csv/regex.csv new file mode 100644 index 0000000000000..b249c39522b60 --- /dev/null +++ b/datafusion-examples/data/csv/regex.csv @@ -0,0 +1,12 @@ +values,patterns,replacement,flags +abc,^(a),bb\1bb,i +ABC,^(A).*,B,i +aBc,(b|d),e,i +AbC,(B|D),e, +aBC,^(b|c),d, +4000,\b4([1-9]\d\d|\d[1-9]\d|\d\d[1-9])\b,xyz, +4010,\b4([1-9]\d\d|\d[1-9]\d|\d\d[1-9])\b,xyz, +Düsseldorf,[\p{Letter}-]+,München, +Москва,[\p{L}-]+,Moscow, +Köln,[a-zA-Z]ö[a-zA-Z]{2},Koln, +اليوم,^\p{Arabic}+$,Today, \ No newline at end of file diff --git a/datafusion-examples/examples/date_time_functions.rs b/datafusion-examples/examples/builtin_functions/date_time.rs similarity index 94% rename from datafusion-examples/examples/date_time_functions.rs rename to datafusion-examples/examples/builtin_functions/date_time.rs index dbe9970439df7..08d4bc6e29978 100644 --- a/datafusion-examples/examples/date_time_functions.rs +++ b/datafusion-examples/examples/builtin_functions/date_time.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. + use std::sync::Arc; use arrow::array::{Date32Array, Int32Array}; @@ -26,8 +28,20 @@ use datafusion::common::assert_contains; use datafusion::error::Result; use datafusion::prelude::*; -#[tokio::main] -async fn main() -> Result<()> { +/// Example: Working with Date and Time Functions +/// +/// This example demonstrates how to work with various date and time +/// functions in DataFusion using both the DataFrame API and SQL queries. 
+/// +/// It includes: +/// - `make_date`: building `DATE` values from year, month, and day columns +/// - `to_date`: converting string expressions into `DATE` values +/// - `to_timestamp`: parsing strings or numeric values into `TIMESTAMP`s +/// - `to_char`: formatting dates, timestamps, and durations as strings +/// +/// Together, these examples show how to create, convert, and format temporal +/// data using DataFusion’s built-in functions. +pub async fn date_time() -> Result<()> { query_make_date().await?; query_to_date().await?; query_to_timestamp().await?; @@ -167,12 +181,13 @@ async fn query_make_date() -> Result<()> { // invalid column values will result in an error let result = ctx - .sql("select make_date(2024, null, 23)") + .sql("select make_date(2024, '', 23)") .await? .collect() .await; - let expected = "Execution error: Unable to parse date from null/empty value"; + let expected = + "Arrow error: Cast error: Cannot cast string '' to value of Int32 type"; assert_contains!(result.unwrap_err().to_string(), expected); // invalid date values will also result in an error @@ -182,7 +197,7 @@ async fn query_make_date() -> Result<()> { .collect() .await; - let expected = "Execution error: Unable to parse date from 2024, 1, 32"; + let expected = "Execution error: Day value '32' is out of range"; assert_contains!(result.unwrap_err().to_string(), expected); Ok(()) @@ -492,14 +507,14 @@ async fn query_to_char() -> Result<()> { assert_batches_eq!( &[ - "+------------------------------+", - "| to_char(t.values,t.patterns) |", - "+------------------------------+", - "| 2020-09-01 |", - "| 2020:09:02 |", - "| 20200903 |", - "| 04-09-2020 |", - "+------------------------------+", + "+----------------------------------+", + "| date_format(t.values,t.patterns) |", + "+----------------------------------+", + "| 2020-09-01 |", + "| 2020:09:02 |", + "| 20200903 |", + "| 04-09-2020 |", + "+----------------------------------+", ], &result ); diff --git 
a/datafusion-examples/examples/function_factory.rs b/datafusion-examples/examples/builtin_functions/function_factory.rs similarity index 95% rename from datafusion-examples/examples/function_factory.rs rename to datafusion-examples/examples/builtin_functions/function_factory.rs index e712f4ea8eaa4..106c53cdf7f12 100644 --- a/datafusion-examples/examples/function_factory.rs +++ b/datafusion-examples/examples/builtin_functions/function_factory.rs @@ -15,19 +15,22 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. + use arrow::datatypes::DataType; use datafusion::common::tree_node::{Transformed, TreeNode}; -use datafusion::common::{exec_err, internal_err, DataFusionError}; +use datafusion::common::{DataFusionError, exec_datafusion_err, exec_err, internal_err}; use datafusion::error::Result; use datafusion::execution::context::{ FunctionFactory, RegisterFunction, SessionContext, SessionState, }; -use datafusion::logical_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; +use datafusion::logical_expr::simplify::{ExprSimplifyResult, SimplifyContext}; use datafusion::logical_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion::logical_expr::{ ColumnarValue, CreateFunction, Expr, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility, }; +use std::hash::Hash; use std::result::Result as RResult; use std::sync::Arc; @@ -41,8 +44,7 @@ use std::sync::Arc; /// /// This example is rather simple and does not cover all cases required for a /// real implementation. -#[tokio::main] -async fn main() -> Result<()> { +pub async fn function_factory() -> Result<()> { // First we must configure the SessionContext with our function factory let ctx = SessionContext::new() // register custom function factory @@ -106,7 +108,7 @@ impl FunctionFactory for CustomFunctionFactory { } /// this function represents the newly created execution engine. 
-#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] struct ScalarFunctionWrapper { /// The text of the function body, `$1 + f1($2)` in our example name: String, @@ -143,17 +145,13 @@ impl ScalarUDFImpl for ScalarFunctionWrapper { fn simplify( &self, args: Vec, - _info: &dyn SimplifyInfo, + _info: &SimplifyContext, ) -> Result { let replacement = Self::replacement(&self.expr, &args)?; Ok(ExprSimplifyResult::Simplified(replacement)) } - fn aliases(&self) -> &[String] { - &[] - } - fn output_ordering(&self, _input: &[ExprProperties]) -> Result { Ok(SortProperties::Unordered) } @@ -188,9 +186,7 @@ impl ScalarFunctionWrapper { fn parse_placeholder_identifier(placeholder: &str) -> Result { if let Some(value) = placeholder.strip_prefix('$') { Ok(value.parse().map(|v: usize| v - 1).map_err(|e| { - DataFusionError::Execution(format!( - "Placeholder `{placeholder}` parsing error: {e}!" - )) + exec_datafusion_err!("Placeholder `{placeholder}` parsing error: {e}!") })?) } else { exec_err!("Placeholder should start with `$`!") diff --git a/datafusion-examples/examples/builtin_functions/main.rs b/datafusion-examples/examples/builtin_functions/main.rs new file mode 100644 index 0000000000000..42ca15f91935d --- /dev/null +++ b/datafusion-examples/examples/builtin_functions/main.rs @@ -0,0 +1,94 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! # These are miscellaneous function-related examples +//! +//! These examples demonstrate miscellaneous function-related features. +//! +//! ## Usage +//! ```bash +//! cargo run --example builtin_functions -- [all|date_time|function_factory|regexp] +//! ``` +//! +//! Each subcommand runs a corresponding example: +//! - `all` — run all examples included in this module +//! +//! - `date_time` +//! (file: date_time.rs, desc: Examples of date-time related functions and queries) +//! +//! - `function_factory` +//! (file: function_factory.rs, desc: Register `CREATE FUNCTION` handler to implement SQL macros) +//! +//! - `regexp` +//! 
(file: regexp.rs, desc: Examples of using regular expression functions) + +mod date_time; +mod function_factory; +mod regexp; + +use datafusion::error::{DataFusionError, Result}; +use strum::{IntoEnumIterator, VariantNames}; +use strum_macros::{Display, EnumIter, EnumString, VariantNames}; + +#[derive(EnumIter, EnumString, Display, VariantNames)] +#[strum(serialize_all = "snake_case")] +enum ExampleKind { + All, + DateTime, + FunctionFactory, + Regexp, +} + +impl ExampleKind { + const EXAMPLE_NAME: &str = "builtin_functions"; + + fn runnable() -> impl Iterator { + ExampleKind::iter().filter(|v| !matches!(v, ExampleKind::All)) + } + + async fn run(&self) -> Result<()> { + match self { + ExampleKind::All => { + for example in ExampleKind::runnable() { + println!("Running example: {example}"); + Box::pin(example.run()).await?; + } + } + ExampleKind::DateTime => date_time::date_time().await?, + ExampleKind::FunctionFactory => function_factory::function_factory().await?, + ExampleKind::Regexp => regexp::regexp().await?, + } + Ok(()) + } +} + +#[tokio::main] +async fn main() -> Result<()> { + let usage = format!( + "Usage: cargo run --example {} -- [{}]", + ExampleKind::EXAMPLE_NAME, + ExampleKind::VARIANTS.join("|") + ); + + let example: ExampleKind = std::env::args() + .nth(1) + .unwrap_or_else(|| ExampleKind::All.to_string()) + .parse() + .map_err(|_| DataFusionError::Execution(format!("Unknown example. 
{usage}")))?; + + example.run().await +} diff --git a/datafusion-examples/examples/regexp.rs b/datafusion-examples/examples/builtin_functions/regexp.rs similarity index 74% rename from datafusion-examples/examples/regexp.rs rename to datafusion-examples/examples/builtin_functions/regexp.rs index 12d115b9b502c..97dc71b94e934 100644 --- a/datafusion-examples/examples/regexp.rs +++ b/datafusion-examples/examples/builtin_functions/regexp.rs @@ -1,5 +1,4 @@ // Licensed to the Apache Software Foundation (ASF) under one -// Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file @@ -16,9 +15,12 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. + use datafusion::common::{assert_batches_eq, assert_contains}; use datafusion::error::Result; use datafusion::prelude::*; +use datafusion_examples::utils::datasets::ExampleDataset; /// This example demonstrates how to use the regexp_* functions /// @@ -28,15 +30,12 @@ use datafusion::prelude::*; /// /// Supported flags can be found at /// https://docs.rs/regex/latest/regex/#grouping-and-flags -#[tokio::main] -async fn main() -> Result<()> { +pub async fn regexp() -> Result<()> { let ctx = SessionContext::new(); - ctx.register_csv( - "examples", - "../../datafusion/physical-expr/tests/data/regex.csv", - CsvReadOptions::new(), - ) - .await?; + let dataset = ExampleDataset::Regex; + + ctx.register_csv("examples", dataset.path_str()?, CsvReadOptions::new()) + .await?; // // @@ -112,11 +111,11 @@ async fn main() -> Result<()> { assert_batches_eq!( &[ - "+---------------------------------------------------+----------------------------------------------------+", - "| regexp_like(Utf8(\"John Smith\"),Utf8(\"^.*Smith$\")) | regexp_like(Utf8(\"Smith Jones\"),Utf8(\"^Smith.*$\")) |", - 
"+---------------------------------------------------+----------------------------------------------------+", - "| true | true |", - "+---------------------------------------------------+----------------------------------------------------+", + "+---------------------------------------------------+----------------------------------------------------+", + "| regexp_like(Utf8(\"John Smith\"),Utf8(\"^.*Smith$\")) | regexp_like(Utf8(\"Smith Jones\"),Utf8(\"^Smith.*$\")) |", + "+---------------------------------------------------+----------------------------------------------------+", + "| true | true |", + "+---------------------------------------------------+----------------------------------------------------+", ], &result ); @@ -242,11 +241,11 @@ async fn main() -> Result<()> { assert_batches_eq!( &[ - "+----------------------------------------------------+-----------------------------------------------------+", - "| regexp_match(Utf8(\"John Smith\"),Utf8(\"^.*Smith$\")) | regexp_match(Utf8(\"Smith Jones\"),Utf8(\"^Smith.*$\")) |", - "+----------------------------------------------------+-----------------------------------------------------+", - "| [John Smith] | [Smith Jones] |", - "+----------------------------------------------------+-----------------------------------------------------+", + "+----------------------------------------------------+-----------------------------------------------------+", + "| regexp_match(Utf8(\"John Smith\"),Utf8(\"^.*Smith$\")) | regexp_match(Utf8(\"Smith Jones\"),Utf8(\"^Smith.*$\")) |", + "+----------------------------------------------------+-----------------------------------------------------+", + "| [John Smith] | [Smith Jones] |", + "+----------------------------------------------------+-----------------------------------------------------+", ], &result ); @@ -268,21 +267,21 @@ async fn main() -> Result<()> { assert_batches_eq!( &[ - 
"+---------------------------------------------------------------------------------------------------------+", - "| regexp_replace(examples.values,examples.patterns,examples.replacement,concat(Utf8(\"g\"),examples.flags)) |", - "+---------------------------------------------------------------------------------------------------------+", - "| bbabbbc |", - "| B |", - "| aec |", - "| AbC |", - "| aBC |", - "| 4000 |", - "| xyz |", - "| München |", - "| Moscow |", - "| Koln |", - "| Today |", - "+---------------------------------------------------------------------------------------------------------+", + "+---------------------------------------------------------------------------------------------------------+", + "| regexp_replace(examples.values,examples.patterns,examples.replacement,concat(Utf8(\"g\"),examples.flags)) |", + "+---------------------------------------------------------------------------------------------------------+", + "| bbabbbc |", + "| B |", + "| aec |", + "| AbC |", + "| aBC |", + "| 4000 |", + "| xyz |", + "| München |", + "| Moscow |", + "| Koln |", + "| Today |", + "+---------------------------------------------------------------------------------------------------------+", ], &result ); @@ -296,11 +295,11 @@ async fn main() -> Result<()> { assert_batches_eq!( &[ - "+------------------------------------------------------------------------+", - "| regexp_replace(Utf8(\"foobarbaz\"),Utf8(\"b(..)\"),Utf8(\"X\\1Y\"),Utf8(\"g\")) |", - "+------------------------------------------------------------------------+", - "| fooXarYXazY |", - "+------------------------------------------------------------------------+", + "+------------------------------------------------------------------------+", + "| regexp_replace(Utf8(\"foobarbaz\"),Utf8(\"b(..)\"),Utf8(\"X\\1Y\"),Utf8(\"g\")) |", + "+------------------------------------------------------------------------+", + "| fooXarYXazY |", + 
"+------------------------------------------------------------------------+", ], &result ); diff --git a/datafusion-examples/examples/custom_data_source/adapter_serialization.rs b/datafusion-examples/examples/custom_data_source/adapter_serialization.rs new file mode 100644 index 0000000000000..a2cd187fee067 --- /dev/null +++ b/datafusion-examples/examples/custom_data_source/adapter_serialization.rs @@ -0,0 +1,519 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! See `main.rs` for how to run it. +//! +//! This example demonstrates how to use the `PhysicalExtensionCodec` trait's +//! interception methods (`serialize_physical_plan` and `deserialize_physical_plan`) +//! to implement custom serialization logic. +//! +//! The key insight is that `FileScanConfig::expr_adapter_factory` is NOT serialized by +//! default. This example shows how to: +//! 1. Detect plans with custom adapters during serialization +//! 2. Wrap them as Extension nodes with JSON-serialized adapter metadata +//! 3. Store the inner DataSourceExec (without adapter) as a child in the extension's inputs field +//! 4. Unwrap and restore the adapter during deserialization +//! +//! 
This demonstrates nested serialization (protobuf outer, JSON inner) and the power +//! of the `PhysicalExtensionCodec` interception pattern. Both plan and expression +//! serialization route through the codec, enabling interception at every node in the tree. + +use std::fmt::Debug; +use std::sync::Arc; + +use arrow::array::record_batch; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use datafusion::assert_batches_eq; +use datafusion::common::{Result, not_impl_err}; +use datafusion::datasource::listing::{ + ListingTable, ListingTableConfig, ListingTableConfigExt, ListingTableUrl, +}; +use datafusion::datasource::physical_plan::{FileScanConfig, FileScanConfigBuilder}; +use datafusion::datasource::source::DataSourceExec; +use datafusion::execution::TaskContext; +use datafusion::execution::context::SessionContext; +use datafusion::execution::object_store::ObjectStoreUrl; +use datafusion::parquet::arrow::ArrowWriter; +use datafusion::physical_expr::PhysicalExpr; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::prelude::SessionConfig; +use datafusion_physical_expr_adapter::{ + DefaultPhysicalExprAdapterFactory, PhysicalExprAdapter, PhysicalExprAdapterFactory, +}; +use datafusion_proto::bytes::{ + physical_plan_from_bytes_with_proto_converter, + physical_plan_to_bytes_with_proto_converter, +}; +use datafusion_proto::physical_plan::from_proto::parse_physical_expr_with_converter; +use datafusion_proto::physical_plan::to_proto::serialize_physical_expr_with_converter; +use datafusion_proto::physical_plan::{ + PhysicalExtensionCodec, PhysicalProtoConverterExtension, +}; +use datafusion_proto::protobuf::physical_plan_node::PhysicalPlanType; +use datafusion_proto::protobuf::{ + PhysicalExprNode, PhysicalExtensionNode, PhysicalPlanNode, +}; +use object_store::memory::InMemory; +use object_store::path::Path; +use object_store::{ObjectStore, ObjectStoreExt, PutPayload}; +use serde::{Deserialize, Serialize}; + +/// Example showing how to preserve custom 
adapter information during plan serialization. +/// +/// This demonstrates: +/// 1. Creating a custom PhysicalExprAdapter with metadata +/// 2. Using PhysicalExtensionCodec to intercept serialization +/// 3. Wrapping adapter info as Extension nodes +/// 4. Restoring adapters during deserialization +pub async fn adapter_serialization() -> Result<()> { + println!("=== PhysicalExprAdapter Serialization Example ===\n"); + + // Step 1: Create sample Parquet data in memory + println!("Step 1: Creating sample Parquet data..."); + let store = Arc::new(InMemory::new()) as Arc; + let batch = record_batch!(("id", Int32, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]))?; + let path = Path::from("data.parquet"); + write_parquet(&store, &path, &batch).await?; + + // Step 2: Set up session with custom adapter + println!("Step 2: Setting up session with custom adapter..."); + let logical_schema = + Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)])); + + let mut cfg = SessionConfig::new(); + cfg.options_mut().execution.parquet.pushdown_filters = true; + let ctx = SessionContext::new_with_config(cfg); + ctx.runtime_env().register_object_store( + ObjectStoreUrl::parse("memory://")?.as_ref(), + Arc::clone(&store), + ); + + // Create a table with our custom MetadataAdapterFactory + let adapter_factory = Arc::new(MetadataAdapterFactory::new("v1")); + let listing_config = + ListingTableConfig::new(ListingTableUrl::parse("memory:///data.parquet")?) + .infer_options(&ctx.state()) + .await? 
+ .with_schema(logical_schema) + .with_expr_adapter_factory( + Arc::clone(&adapter_factory) as Arc + ); + let table = ListingTable::try_new(listing_config)?; + ctx.register_table("my_table", Arc::new(table))?; + + // Step 3: Create physical plan with filter + println!("Step 3: Creating physical plan with filter..."); + let df = ctx.sql("SELECT * FROM my_table WHERE id > 5").await?; + let original_plan = df.create_physical_plan().await?; + + // Verify adapter is present in original plan + let has_adapter_before = verify_adapter_in_plan(&original_plan, "original"); + println!(" Original plan has adapter: {has_adapter_before}"); + + // Step 4: Serialize with our custom codec + println!("\nStep 4: Serializing plan with AdapterPreservingCodec..."); + let codec = AdapterPreservingCodec; + let bytes = physical_plan_to_bytes_with_proto_converter( + Arc::clone(&original_plan), + &codec, + &codec, + )?; + println!(" Serialized {} bytes", bytes.len()); + println!(" (DataSourceExec with adapter was wrapped as PhysicalExtensionNode)"); + + // Step 5: Deserialize with our custom codec + println!("\nStep 5: Deserializing plan with AdapterPreservingCodec..."); + let task_ctx = ctx.task_ctx(); + let restored_plan = + physical_plan_from_bytes_with_proto_converter(&bytes, &task_ctx, &codec, &codec)?; + + // Verify adapter is restored + let has_adapter_after = verify_adapter_in_plan(&restored_plan, "restored"); + println!(" Restored plan has adapter: {has_adapter_after}"); + + // Step 6: Execute and compare results + println!("\nStep 6: Executing plans and comparing results..."); + let original_results = + datafusion::physical_plan::collect(Arc::clone(&original_plan), task_ctx.clone()) + .await?; + let restored_results = + datafusion::physical_plan::collect(restored_plan, task_ctx).await?; + + #[rustfmt::skip] + let expected = [ + "+----+", + "| id |", + "+----+", + "| 6 |", + "| 7 |", + "| 8 |", + "| 9 |", + "| 10 |", + "+----+", + ]; + + println!("\n Original plan results:"); + 
arrow::util::pretty::print_batches(&original_results)?; + assert_batches_eq!(expected, &original_results); + + println!("\n Restored plan results:"); + arrow::util::pretty::print_batches(&restored_results)?; + assert_batches_eq!(expected, &restored_results); + + println!("\n=== Example Complete! ==="); + println!("Key takeaways:"); + println!( + " 1. PhysicalExtensionCodec provides serialize_physical_plan/deserialize_physical_plan hooks" + ); + println!(" 2. Custom metadata can be wrapped as PhysicalExtensionNode"); + println!(" 3. Nested serialization (protobuf + JSON) works seamlessly"); + println!( + " 4. Both plans produce identical results despite serialization round-trip" + ); + println!(" 5. Adapters are fully preserved through the serialization round-trip"); + + Ok(()) +} + +// ============================================================================ +// MetadataAdapter - A simple custom adapter with a tag +// ============================================================================ + +/// A custom PhysicalExprAdapter that wraps another adapter. +/// The tag metadata is stored in the factory, not the adapter itself. +#[derive(Debug)] +struct MetadataAdapter { + inner: Arc, +} + +impl PhysicalExprAdapter for MetadataAdapter { + fn rewrite(&self, expr: Arc) -> Result> { + // Simply delegate to inner adapter + self.inner.rewrite(expr) + } +} + +// ============================================================================ +// MetadataAdapterFactory - Factory for creating MetadataAdapter instances +// ============================================================================ + +/// Factory for creating MetadataAdapter instances. +/// The tag is stored in the factory and extracted via Debug formatting in `extract_adapter_tag`. +#[derive(Debug)] +struct MetadataAdapterFactory { + // Note: This field is read via Debug formatting in `extract_adapter_tag`. + // Rust's dead code analysis doesn't recognize Debug-based field access. 
+ // In PR #19234, this field is used by `with_partition_values`, but that method + // doesn't exist in upstream DataFusion's PhysicalExprAdapter trait. + #[expect(dead_code)] + tag: String, +} + +impl MetadataAdapterFactory { + fn new(tag: impl Into) -> Self { + Self { tag: tag.into() } + } +} + +impl PhysicalExprAdapterFactory for MetadataAdapterFactory { + fn create( + &self, + logical_file_schema: SchemaRef, + physical_file_schema: SchemaRef, + ) -> Result> { + let inner = DefaultPhysicalExprAdapterFactory + .create(logical_file_schema, physical_file_schema)?; + Ok(Arc::new(MetadataAdapter { inner })) + } +} + +// ============================================================================ +// AdapterPreservingCodec - Custom codec that preserves adapters +// ============================================================================ + +/// Extension payload structure for serializing adapter info +#[derive(Serialize, Deserialize)] +struct ExtensionPayload { + /// Marker to identify this is our custom extension + marker: String, + /// JSON-serialized adapter metadata + adapter_metadata: AdapterMetadata, +} + +/// Metadata about the adapter to recreate it during deserialization +#[derive(Serialize, Deserialize)] +struct AdapterMetadata { + /// The adapter tag (e.g., "v1") + tag: String, +} + +const EXTENSION_MARKER: &str = "adapter_preserving_extension_v1"; + +/// A codec that intercepts serialization to preserve adapter information. 
+#[derive(Debug)] +struct AdapterPreservingCodec; + +impl PhysicalExtensionCodec for AdapterPreservingCodec { + // Required method: decode custom extension nodes + fn try_decode( + &self, + buf: &[u8], + inputs: &[Arc], + _ctx: &TaskContext, + ) -> Result> { + // Try to parse as our extension payload + if let Ok(payload) = serde_json::from_slice::(buf) + && payload.marker == EXTENSION_MARKER + { + if inputs.len() != 1 { + return Err(datafusion::error::DataFusionError::Plan(format!( + "Extension node expected exactly 1 child, got {}", + inputs.len() + ))); + } + let inner_plan = inputs[0].clone(); + + // Recreate the adapter factory + let adapter_factory = create_adapter_factory(&payload.adapter_metadata.tag); + + // Inject adapter into the plan + return inject_adapter_into_plan(inner_plan, adapter_factory); + } + + not_impl_err!("Unknown extension type") + } + + // Required method: encode custom execution plans + fn try_encode( + &self, + _node: Arc, + _buf: &mut Vec, + ) -> Result<()> { + // We don't need this for the example - we use serialize_physical_plan instead + not_impl_err!( + "try_encode not used - adapter wrapping happens in serialize_physical_plan" + ) + } +} + +impl PhysicalProtoConverterExtension for AdapterPreservingCodec { + fn execution_plan_to_proto( + &self, + plan: &Arc, + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result { + // Check if this is a DataSourceExec with adapter + if let Some(exec) = plan.as_any().downcast_ref::() + && let Some(config) = + exec.data_source().as_any().downcast_ref::() + && let Some(adapter_factory) = &config.expr_adapter_factory + && let Some(tag) = extract_adapter_tag(adapter_factory.as_ref()) + { + // Try to extract our MetadataAdapterFactory's tag + println!(" [Serialize] Found DataSourceExec with adapter tag: {tag}"); + + // 1. Create adapter metadata + let adapter_metadata = AdapterMetadata { tag }; + + // 2. 
Serialize the inner plan to protobuf + // Note that this will drop the custom adapter since the default serialization cannot handle it + let inner_proto = PhysicalPlanNode::try_from_physical_plan_with_converter( + Arc::clone(plan), + extension_codec, + self, + )?; + + // 3. Create extension payload to wrap the plan + // so that the custom adapter gets re-attached during deserialization + // The choice of JSON is arbitrary; other formats could be used. + let payload = ExtensionPayload { + marker: EXTENSION_MARKER.to_string(), + adapter_metadata, + }; + let payload_bytes = serde_json::to_vec(&payload).map_err(|e| { + datafusion::error::DataFusionError::Plan(format!( + "Failed to serialize payload: {e}" + )) + })?; + + // 4. Return as PhysicalExtensionNode with child plan in inputs + return Ok(PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::Extension( + PhysicalExtensionNode { + node: payload_bytes, + inputs: vec![inner_proto], + }, + )), + }); + } + + // No adapter found, not a DataSourceExec, etc. 
- use default serialization + PhysicalPlanNode::try_from_physical_plan_with_converter( + Arc::clone(plan), + extension_codec, + self, + ) + } + + // Interception point: override deserialization to unwrap adapters + fn proto_to_execution_plan( + &self, + ctx: &TaskContext, + extension_codec: &dyn PhysicalExtensionCodec, + proto: &PhysicalPlanNode, + ) -> Result> { + // Check if this is our custom extension wrapper + if let Some(PhysicalPlanType::Extension(extension)) = &proto.physical_plan_type + && let Ok(payload) = + serde_json::from_slice::(&extension.node) + && payload.marker == EXTENSION_MARKER + { + println!( + " [Deserialize] Found adapter extension with tag: {}", + payload.adapter_metadata.tag + ); + + // Get the inner plan proto from inputs field + if extension.inputs.is_empty() { + return Err(datafusion::error::DataFusionError::Plan( + "Extension node missing child plan in inputs".to_string(), + )); + } + let inner_proto = &extension.inputs[0]; + + // Deserialize the inner plan + let inner_plan = inner_proto.try_into_physical_plan_with_converter( + ctx, + extension_codec, + self, + )?; + + // Recreate the adapter factory + let adapter_factory = create_adapter_factory(&payload.adapter_metadata.tag); + + // Inject adapter into the plan + return inject_adapter_into_plan(inner_plan, adapter_factory); + } + + // Not our extension - use default deserialization + proto.try_into_physical_plan_with_converter(ctx, extension_codec, self) + } + + fn proto_to_physical_expr( + &self, + proto: &PhysicalExprNode, + ctx: &TaskContext, + input_schema: &Schema, + codec: &dyn PhysicalExtensionCodec, + ) -> Result> { + parse_physical_expr_with_converter(proto, ctx, input_schema, codec, self) + } + + fn physical_expr_to_proto( + &self, + expr: &Arc, + codec: &dyn PhysicalExtensionCodec, + ) -> Result { + serialize_physical_expr_with_converter(expr, codec, self) + } +} + +// ============================================================================ +// Helper functions +// 
============================================================================ + +/// Write a RecordBatch to Parquet in the object store +async fn write_parquet( + store: &dyn ObjectStore, + path: &Path, + batch: &arrow::record_batch::RecordBatch, +) -> Result<()> { + let mut buf = vec![]; + let mut writer = ArrowWriter::try_new(&mut buf, batch.schema(), None)?; + writer.write(batch)?; + writer.close()?; + + let payload = PutPayload::from_bytes(buf.into()); + store.put(path, payload).await?; + Ok(()) +} + +/// Extract the tag from a MetadataAdapterFactory. +/// +/// Note: Since `PhysicalExprAdapterFactory` doesn't provide `as_any()` for downcasting, +/// we parse the Debug output. In a production system, you might add a dedicated trait +/// method for metadata extraction. +fn extract_adapter_tag(factory: &dyn PhysicalExprAdapterFactory) -> Option { + let debug_str = format!("{factory:?}"); + if debug_str.contains("MetadataAdapterFactory") { + // Extract tag from debug output: MetadataAdapterFactory { tag: "v1" } + if let Some(start) = debug_str.find("tag: \"") { + let after_tag = &debug_str[start + 6..]; + if let Some(end) = after_tag.find('"') { + return Some(after_tag[..end].to_string()); + } + } + } + None +} + +/// Create an adapter factory from a tag +fn create_adapter_factory(tag: &str) -> Arc { + Arc::new(MetadataAdapterFactory::new(tag)) +} + +/// Inject an adapter into a plan (assumes plan is a DataSourceExec with FileScanConfig) +fn inject_adapter_into_plan( + plan: Arc, + adapter_factory: Arc, +) -> Result> { + if let Some(exec) = plan.as_any().downcast_ref::() + && let Some(config) = exec.data_source().as_any().downcast_ref::() + { + let new_config = FileScanConfigBuilder::from(config.clone()) + .with_expr_adapter(Some(adapter_factory)) + .build(); + return Ok(DataSourceExec::from_data_source(new_config)); + } + // If not a DataSourceExec with FileScanConfig, return as-is + Ok(plan) +} + +/// Helper to verify if a plan has an adapter (for 
testing/validation) +fn verify_adapter_in_plan(plan: &Arc, label: &str) -> bool { + // Walk the plan tree to find DataSourceExec with adapter + fn check_plan(plan: &dyn ExecutionPlan) -> bool { + if let Some(exec) = plan.as_any().downcast_ref::() + && let Some(config) = + exec.data_source().as_any().downcast_ref::() + && config.expr_adapter_factory.is_some() + { + return true; + } + // Check children + for child in plan.children() { + if check_plan(child.as_ref()) { + return true; + } + } + false + } + + let has_adapter = check_plan(plan.as_ref()); + println!(" [Verify] {label} plan adapter check: {has_adapter}"); + has_adapter +} diff --git a/datafusion-examples/examples/csv_json_opener.rs b/datafusion-examples/examples/custom_data_source/csv_json_opener.rs similarity index 66% rename from datafusion-examples/examples/csv_json_opener.rs rename to datafusion-examples/examples/custom_data_source/csv_json_opener.rs index 1a2c2cbff4183..4804586382dc2 100644 --- a/datafusion-examples/examples/csv_json_opener.rs +++ b/datafusion-examples/examples/custom_data_source/csv_json_opener.rs @@ -15,32 +15,36 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. 
+ use std::sync::Arc; use arrow::datatypes::{DataType, Field, Schema}; +use datafusion::common::config::CsvOptions; use datafusion::{ assert_batches_eq, datasource::{ file_format::file_compression_type::FileCompressionType, listing::PartitionedFile, object_store::ObjectStoreUrl, - physical_plan::{CsvSource, FileSource, FileStream, JsonOpener, JsonSource}, + physical_plan::{ + CsvSource, FileSource, FileStreamBuilder, JsonOpener, JsonSource, + }, }, error::Result, physical_plan::metrics::ExecutionPlanMetricsSet, - test_util::aggr_test_schema, }; use datafusion::datasource::physical_plan::FileScanConfigBuilder; +use datafusion_examples::utils::datasets::ExampleDataset; use futures::StreamExt; -use object_store::{local::LocalFileSystem, memory::InMemory, ObjectStore}; +use object_store::{ObjectStoreExt, local::LocalFileSystem, memory::InMemory}; /// This example demonstrates using the low level [`FileStream`] / [`FileOpener`] APIs to directly /// read data from (CSV/JSON) into Arrow RecordBatches. 
/// /// If you want to query data in CSV or JSON files, see the [`dataframe.rs`] and [`sql_query.rs`] examples -#[tokio::main] -async fn main() -> Result<()> { +pub async fn csv_json_opener() -> Result<()> { csv_opener().await?; json_opener().await?; Ok(()) @@ -48,48 +52,53 @@ async fn main() -> Result<()> { async fn csv_opener() -> Result<()> { let object_store = Arc::new(LocalFileSystem::new()); - let schema = aggr_test_schema(); - let testdata = datafusion::test_util::arrow_test_data(); - let path = format!("{testdata}/csv/aggregate_test_100.csv"); + let dataset = ExampleDataset::Cars; + let csv_path = dataset.path(); + let schema = dataset.schema(); - let path = std::path::Path::new(&path).canonicalize()?; + let options = CsvOptions { + has_header: Some(true), + delimiter: b',', + quote: b'"', + ..Default::default() + }; - let scan_config = FileScanConfigBuilder::new( - ObjectStoreUrl::local_filesystem(), - Arc::clone(&schema), - Arc::new(CsvSource::default()), - ) - .with_projection(Some(vec![12, 0])) - .with_limit(Some(5)) - .with_file(PartitionedFile::new(path.display().to_string(), 10)) - .build(); - - let config = CsvSource::new(true, b',', b'"') + let source = CsvSource::new(Arc::clone(&schema)) + .with_csv_options(options) .with_comment(Some(b'#')) - .with_schema(schema) - .with_batch_size(8192) - .with_projection(&scan_config); + .with_batch_size(8192); - let opener = config.create_file_opener(object_store, &scan_config, 0); + let scan_config = + FileScanConfigBuilder::new(ObjectStoreUrl::local_filesystem(), source) + .with_projection_indices(Some(vec![0, 1]))? 
+ .with_limit(Some(5)) + .with_file(PartitionedFile::new(csv_path.display().to_string(), 10)) + .build(); + + let opener = + scan_config + .file_source() + .create_file_opener(object_store, &scan_config, 0)?; let mut result = vec![]; let mut stream = - FileStream::new(&scan_config, 0, opener, &ExecutionPlanMetricsSet::new())?; + FileStreamBuilder::new(&scan_config, 0, opener, &ExecutionPlanMetricsSet::new()) + .build()?; while let Some(batch) = stream.next().await.transpose()? { result.push(batch); } assert_batches_eq!( &[ - "+--------------------------------+----+", - "| c13 | c1 |", - "+--------------------------------+----+", - "| 6WfVFBVGJSQb7FhA7E0lBwdvjfZnSW | c |", - "| C2GT5KVyOPZpgKVl110TyZO0NcJ434 | d |", - "| AyYVExXK6AR2qUTxNZ7qRHQOVGMLcz | b |", - "| 0keZ5G8BffGwgF2RwQD59TFzMStxCB | a |", - "| Ig1QcuKsjHXkproePdERo2w0mYzIqd | b |", - "+--------------------------------+----+", + "+-----+-------+", + "| car | speed |", + "+-----+-------+", + "| red | 20.0 |", + "| red | 20.3 |", + "| red | 21.4 |", + "| red | 21.5 |", + "| red | 19.0 |", + "+-----+-------+", ], &result ); @@ -119,24 +128,25 @@ async fn json_opener() -> Result<()> { projected, FileCompressionType::UNCOMPRESSED, Arc::new(object_store), + true, ); let scan_config = FileScanConfigBuilder::new( ObjectStoreUrl::local_filesystem(), - schema, - Arc::new(JsonSource::default()), + Arc::new(JsonSource::new(schema)), ) - .with_projection(Some(vec![1, 0])) + .with_projection_indices(Some(vec![1, 0]))? .with_limit(Some(5)) .with_file(PartitionedFile::new(path.to_string(), 10)) .build(); - let mut stream = FileStream::new( + let mut stream = FileStreamBuilder::new( &scan_config, 0, Arc::new(opener), &ExecutionPlanMetricsSet::new(), - )?; + ) + .build()?; let mut result = vec![]; while let Some(batch) = stream.next().await.transpose()? 
{ result.push(batch); diff --git a/datafusion-examples/examples/csv_sql_streaming.rs b/datafusion-examples/examples/custom_data_source/csv_sql_streaming.rs similarity index 82% rename from datafusion-examples/examples/csv_sql_streaming.rs rename to datafusion-examples/examples/custom_data_source/csv_sql_streaming.rs index 99264bbcb486d..4692086a10b26 100644 --- a/datafusion-examples/examples/csv_sql_streaming.rs +++ b/datafusion-examples/examples/custom_data_source/csv_sql_streaming.rs @@ -15,44 +15,46 @@ // specific language governing permissions and limitations // under the License. -use datafusion::common::test_util::datafusion_test_data; +//! See `main.rs` for how to run it. + use datafusion::error::Result; use datafusion::prelude::*; +use datafusion_examples::utils::datasets::ExampleDataset; /// This example demonstrates executing a simple query against an Arrow data source (CSV) and /// fetching results with streaming aggregation and streaming window -#[tokio::main] -async fn main() -> Result<()> { +pub async fn csv_sql_streaming() -> Result<()> { // create local execution context let ctx = SessionContext::new(); - let testdata = datafusion_test_data(); + let dataset = ExampleDataset::Cars; + let csv_path = dataset.path(); - // Register a table source and tell DataFusion the file is ordered by `ts ASC`. + // Register a table source and tell DataFusion the file is ordered by `car ASC`. // Note it is the responsibility of the user to make sure // that file indeed satisfies this condition or else incorrect answers may be produced. 
let asc = true; let nulls_first = true; - let sort_expr = vec![col("ts").sort(asc, nulls_first)]; + let sort_expr = vec![col("car").sort(asc, nulls_first)]; // register csv file with the execution context ctx.register_csv( "ordered_table", - &format!("{testdata}/window_1.csv"), + csv_path.to_str().unwrap(), CsvReadOptions::new().file_sort_order(vec![sort_expr]), ) .await?; // execute the query - // Following query can be executed with unbounded sources because group by expressions (e.g ts) is + // Following query can be executed with unbounded sources because group by expressions (e.g car) is // already ordered at the source. // // Unbounded sources means that if the input came from a "never ending" source (such as a FIFO // file on unix) the query could produce results incrementally as data was read. let df = ctx .sql( - "SELECT ts, MIN(inc_col), MAX(inc_col) \ + "SELECT car, MIN(speed), MAX(speed) \ FROM ordered_table \ - GROUP BY ts", + GROUP BY car", ) .await?; @@ -63,7 +65,7 @@ async fn main() -> Result<()> { // its result in streaming fashion, because its required ordering is already satisfied at the source. let df = ctx .sql( - "SELECT ts, SUM(inc_col) OVER(ORDER BY ts ASC) \ + "SELECT car, SUM(speed) OVER(ORDER BY car ASC) \ FROM ordered_table", ) .await?; diff --git a/datafusion-examples/examples/custom_datasource.rs b/datafusion-examples/examples/custom_data_source/custom_datasource.rs similarity index 87% rename from datafusion-examples/examples/custom_datasource.rs rename to datafusion-examples/examples/custom_data_source/custom_datasource.rs index bc865fac5a338..71e589dcf6e88 100644 --- a/datafusion-examples/examples/custom_datasource.rs +++ b/datafusion-examples/examples/custom_data_source/custom_datasource.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. 
+ use std::any::Any; use std::collections::{BTreeMap, HashMap}; use std::fmt::{self, Debug, Formatter}; @@ -22,10 +24,11 @@ use std::sync::{Arc, Mutex}; use std::time::Duration; use async_trait::async_trait; -use datafusion::arrow::array::{UInt64Builder, UInt8Builder}; +use datafusion::arrow::array::{UInt8Builder, UInt64Builder}; use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion::arrow::record_batch::RecordBatch; -use datafusion::datasource::{provider_as_source, TableProvider, TableType}; +use datafusion::common::tree_node::TreeNodeRecursion; +use datafusion::datasource::{TableProvider, TableType, provider_as_source}; use datafusion::error::Result; use datafusion::execution::context::TaskContext; use datafusion::logical_expr::LogicalPlanBuilder; @@ -33,8 +36,8 @@ use datafusion::physical_expr::EquivalenceProperties; use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType}; use datafusion::physical_plan::memory::MemoryStream; use datafusion::physical_plan::{ - project_schema, DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, - PlanProperties, SendableRecordBatchStream, + DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, + SendableRecordBatchStream, project_schema, }; use datafusion::prelude::*; @@ -42,8 +45,7 @@ use datafusion::catalog::Session; use tokio::time::timeout; /// This example demonstrates executing a simple query against a custom datasource -#[tokio::main] -async fn main() -> Result<()> { +pub async fn custom_datasource() -> Result<()> { // create our custom datasource and adding some users let db = CustomDataSource::default(); db.populate_users(); @@ -191,10 +193,11 @@ impl TableProvider for CustomDataSource { struct CustomExec { db: CustomDataSource, projected_schema: SchemaRef, - cache: PlanProperties, + cache: Arc, } impl CustomExec { + #[expect(clippy::needless_pass_by_value)] fn new( projections: Option<&Vec>, schema: SchemaRef, @@ -205,7 +208,7 @@ impl 
CustomExec { Self { db, projected_schema, - cache, + cache: Arc::new(cache), } } @@ -236,7 +239,7 @@ impl ExecutionPlan for CustomExec { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -281,4 +284,20 @@ impl ExecutionPlan for CustomExec { None, )?)) } + + fn apply_expressions( + &self, + f: &mut dyn FnMut( + &dyn datafusion::physical_plan::PhysicalExpr, + ) -> Result, + ) -> Result { + // Visit expressions in the output ordering from equivalence properties + let mut tnr = TreeNodeRecursion::Continue; + if let Some(ordering) = self.cache.output_ordering() { + for sort_expr in ordering { + tnr = tnr.visit_sibling(|| f(sort_expr.expr.as_ref()))?; + } + } + Ok(tnr) + } } diff --git a/datafusion-examples/examples/custom_data_source/custom_file_casts.rs b/datafusion-examples/examples/custom_data_source/custom_file_casts.rs new file mode 100644 index 0000000000000..6b37db653e35d --- /dev/null +++ b/datafusion-examples/examples/custom_data_source/custom_file_casts.rs @@ -0,0 +1,212 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! See `main.rs` for how to run it. 
+ +use std::sync::Arc; + +use arrow::array::{RecordBatch, record_batch}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; + +use datafusion::assert_batches_eq; +use datafusion::common::Result; +use datafusion::common::not_impl_err; +use datafusion::common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion::datasource::listing::{ + ListingTable, ListingTableConfig, ListingTableConfigExt, ListingTableUrl, +}; +use datafusion::execution::context::SessionContext; +use datafusion::execution::object_store::ObjectStoreUrl; +use datafusion::parquet::arrow::ArrowWriter; +use datafusion::physical_expr::PhysicalExpr; +use datafusion::physical_expr::expressions::{CastColumnExpr, CastExpr}; +use datafusion::prelude::SessionConfig; +use datafusion_physical_expr_adapter::{ + DefaultPhysicalExprAdapterFactory, PhysicalExprAdapter, PhysicalExprAdapterFactory, +}; +use object_store::memory::InMemory; +use object_store::path::Path; +use object_store::{ObjectStore, ObjectStoreExt, PutPayload}; + +// Example showing how to implement custom casting rules to adapt file schemas. +// This example enforces that casts must be strictly widening: if the file type is Int64 and the table type is Int32, it will error +// before even reading the data. +// Without this custom cast rule DataFusion would happily do the narrowing cast, potentially erroring only if it found a row with data it could not cast. 
+pub async fn custom_file_casts() -> Result<()> { + println!("=== Creating example data ==="); + + // Create a logical / table schema with an Int32 column (nullable) + let logical_schema = + Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, true)])); + + // Create some data that can be cast (Int16 -> Int32 is widening) and some that cannot (Int64 -> Int32 is narrowing) + let store = Arc::new(InMemory::new()) as Arc; + let path = Path::from("good.parquet"); + let batch = record_batch!(("id", Int16, [1, 2, 3]))?; + write_data(&store, &path, &batch).await?; + let path = Path::from("bad.parquet"); + let batch = record_batch!(("id", Int64, [1, 2, 3]))?; + write_data(&store, &path, &batch).await?; + + // Set up query execution + let mut cfg = SessionConfig::new(); + // Turn on filter pushdown so that the PhysicalExprAdapter is used + cfg.options_mut().execution.parquet.pushdown_filters = true; + let ctx = SessionContext::new_with_config(cfg); + ctx.runtime_env() + .register_object_store(ObjectStoreUrl::parse("memory://")?.as_ref(), store); + + // Register our good and bad files via ListingTable + let listing_table_config = + ListingTableConfig::new(ListingTableUrl::parse("memory:///good.parquet")?) + .infer_options(&ctx.state()) + .await? + .with_schema(Arc::clone(&logical_schema)) + .with_expr_adapter_factory(Arc::new( + CustomCastPhysicalExprAdapterFactory::new(Arc::new( + DefaultPhysicalExprAdapterFactory, + )), + )); + let table = ListingTable::try_new(listing_table_config).unwrap(); + ctx.register_table("good_table", Arc::new(table))?; + let listing_table_config = + ListingTableConfig::new(ListingTableUrl::parse("memory:///bad.parquet")?) + .infer_options(&ctx.state()) + .await? 
+ .with_schema(Arc::clone(&logical_schema)) + .with_expr_adapter_factory(Arc::new( + CustomCastPhysicalExprAdapterFactory::new(Arc::new( + DefaultPhysicalExprAdapterFactory, + )), + )); + let table = ListingTable::try_new(listing_table_config).unwrap(); + ctx.register_table("bad_table", Arc::new(table))?; + + println!("\n=== File with narrower schema is cast ==="); + let query = "SELECT id FROM good_table WHERE id > 1"; + println!("Query: {query}"); + let batches = ctx.sql(query).await?.collect().await?; + #[rustfmt::skip] + let expected = [ + "+----+", + "| id |", + "+----+", + "| 2 |", + "| 3 |", + "+----+", + ]; + arrow::util::pretty::print_batches(&batches)?; + assert_batches_eq!(expected, &batches); + + println!("\n=== File with wider schema errors ==="); + let query = "SELECT id FROM bad_table WHERE id > 1"; + println!("Query: {query}"); + match ctx.sql(query).await?.collect().await { + Ok(_) => panic!("Expected error for narrowing cast, but query succeeded"), + Err(e) => { + println!("Caught expected error: {e}"); + } + } + Ok(()) +} + +async fn write_data( + store: &dyn ObjectStore, + path: &Path, + batch: &RecordBatch, +) -> Result<()> { + let mut buf = vec![]; + let mut writer = ArrowWriter::try_new(&mut buf, batch.schema(), None)?; + writer.write(batch)?; + writer.close()?; + + let payload = PutPayload::from_bytes(buf.into()); + store.put(path, payload).await?; + Ok(()) +} + +/// Factory for creating CustomCastsPhysicalExprAdapter instances +#[derive(Debug)] +struct CustomCastPhysicalExprAdapterFactory { + inner: Arc, +} + +impl CustomCastPhysicalExprAdapterFactory { + fn new(inner: Arc) -> Self { + Self { inner } + } +} + +impl PhysicalExprAdapterFactory for CustomCastPhysicalExprAdapterFactory { + fn create( + &self, + logical_file_schema: SchemaRef, + physical_file_schema: SchemaRef, + ) -> Result> { + let inner = self + .inner + .create(logical_file_schema, Arc::clone(&physical_file_schema))?; + Ok(Arc::new(CustomCastsPhysicalExprAdapter { +
physical_file_schema, + inner, + })) + } +} + +/// Custom PhysicalExprAdapter that rejects any cast that is not a widening cast +/// and wraps an inner PhysicalExprAdapter for standard schema adaptation +#[derive(Debug, Clone)] +struct CustomCastsPhysicalExprAdapter { + physical_file_schema: SchemaRef, + inner: Arc, +} + +impl PhysicalExprAdapter for CustomCastsPhysicalExprAdapter { + fn rewrite(&self, mut expr: Arc) -> Result> { + // First delegate to the inner adapter to handle missing columns and discover any necessary casts + expr = self.inner.rewrite(expr)?; + // Now we can apply custom casting rules or even swap out all CastExprs for a custom cast kernel / expression + // For example, [DataFusion Comet](https://github.com/apache/datafusion-comet) has a [custom cast kernel](https://github.com/apache/datafusion-comet/blob/b4ac876ab420ed403ac7fc8e1b29f42f1f442566/native/spark-expr/src/conversion_funcs/cast.rs#L133-L138). + expr.transform(|expr| { + if let Some(cast) = expr.as_any().downcast_ref::() { + let input_data_type = + cast.expr().data_type(&self.physical_file_schema)?; + let output_data_type = cast.data_type(&self.physical_file_schema)?; + if !cast.is_bigger_cast(&input_data_type) { + return not_impl_err!( + "Unsupported CAST from {input_data_type} to {output_data_type}" + ); + } + } + if let Some(cast) = expr.as_any().downcast_ref::() { + let input_data_type = + cast.expr().data_type(&self.physical_file_schema)?; + let output_data_type = cast.data_type(&self.physical_file_schema)?; + if !CastExpr::check_bigger_cast( + cast.target_field().data_type(), + &input_data_type, + ) { + return not_impl_err!( + "Unsupported CAST from {input_data_type} to {output_data_type}" + ); + } + } + Ok(Transformed::no(expr)) + }) + .data() + } +} diff --git a/datafusion-examples/examples/custom_file_format.rs b/datafusion-examples/examples/custom_data_source/custom_file_format.rs similarity index 89% rename from datafusion-examples/examples/custom_file_format.rs
rename to datafusion-examples/examples/custom_data_source/custom_file_format.rs index ac1e643517685..6817beec41188 100644 --- a/datafusion-examples/examples/custom_file_format.rs +++ b/datafusion-examples/examples/custom_data_source/custom_file_format.rs @@ -15,33 +15,33 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. + use std::{any::Any, sync::Arc}; use arrow::{ array::{AsArray, RecordBatch, StringArray, UInt8Array}, datatypes::{DataType, Field, Schema, SchemaRef, UInt64Type}, }; -use datafusion::physical_expr::LexRequirement; use datafusion::{ catalog::Session, common::{GetExt, Statistics}, -}; -use datafusion::{ - datasource::physical_plan::FileSource, execution::session_state::SessionStateBuilder, -}; -use datafusion::{ datasource::{ + MemTable, file_format::{ - csv::CsvFormatFactory, file_compression_type::FileCompressionType, - FileFormat, FileFormatFactory, + FileFormat, FileFormatFactory, csv::CsvFormatFactory, + file_compression_type::FileCompressionType, }, - physical_plan::{FileScanConfig, FileSinkConfig}, - MemTable, + physical_plan::{FileScanConfig, FileSinkConfig, FileSource}, + table_schema::TableSchema, }, error::Result, + execution::session_state::SessionStateBuilder, + physical_expr_common::sort_expr::LexRequirement, physical_plan::ExecutionPlan, prelude::SessionContext, }; + use object_store::{ObjectMeta, ObjectStore}; use tempfile::tempdir; @@ -50,6 +50,42 @@ use tempfile::tempdir; /// TSVFileFormatFactory is responsible for creating instances of TSVFileFormat. /// The former, once registered with the SessionState, will then be used /// to facilitate SQL operations on TSV files, such as `COPY TO` shown here. 
+pub async fn custom_file_format() -> Result<()> { + // Create a new context with the default configuration + let mut state = SessionStateBuilder::new().with_default_features().build(); + + // Register the custom file format + let file_format = Arc::new(TSVFileFactory::new()); + state.register_file_format(file_format, true)?; + + // Create a new context with the custom file format + let ctx = SessionContext::new_with_state(state); + + let mem_table = create_mem_table(); + ctx.register_table("mem_table", mem_table)?; + + let temp_dir = tempdir().unwrap(); + let table_save_path = temp_dir.path().join("mem_table.tsv"); + + let d = ctx + .sql(&format!( + "COPY mem_table TO '{}' STORED AS TSV;", + table_save_path.display(), + )) + .await?; + + let results = d.collect().await?; + println!( + "Number of inserted rows: {:?}", + (results[0] + .column_by_name("count") + .unwrap() + .as_primitive::() + .value(0)) + ); + + Ok(()) +} #[derive(Debug)] /// Custom file format that reads and writes TSV files @@ -84,6 +120,10 @@ impl FileFormat for TSVFileFormat { } } + fn compression_type(&self) -> Option { + None + } + async fn infer_schema( &self, state: &dyn Session, @@ -127,8 +167,8 @@ impl FileFormat for TSVFileFormat { .await } - fn file_source(&self) -> Arc { - self.csv_file_format.file_source() + fn file_source(&self, table_schema: TableSchema) -> Arc { + self.csv_file_format.file_source(table_schema) } } @@ -179,44 +219,6 @@ impl GetExt for TSVFileFactory { } } -#[tokio::main] -async fn main() -> Result<()> { - // Create a new context with the default configuration - let mut state = SessionStateBuilder::new().with_default_features().build(); - - // Register the custom file format - let file_format = Arc::new(TSVFileFactory::new()); - state.register_file_format(file_format, true).unwrap(); - - // Create a new context with the custom file format - let ctx = SessionContext::new_with_state(state); - - let mem_table = create_mem_table(); - ctx.register_table("mem_table", 
mem_table).unwrap(); - - let temp_dir = tempdir().unwrap(); - let table_save_path = temp_dir.path().join("mem_table.tsv"); - - let d = ctx - .sql(&format!( - "COPY mem_table TO '{}' STORED AS TSV;", - table_save_path.display(), - )) - .await?; - - let results = d.collect().await?; - println!( - "Number of inserted rows: {:?}", - (results[0] - .column_by_name("count") - .unwrap() - .as_primitive::() - .value(0)) - ); - - Ok(()) -} - // create a simple mem table fn create_mem_table() -> Arc { let fields = vec![ diff --git a/datafusion-examples/examples/custom_data_source/default_column_values.rs b/datafusion-examples/examples/custom_data_source/default_column_values.rs new file mode 100644 index 0000000000000..40c8836c1f822 --- /dev/null +++ b/datafusion-examples/examples/custom_data_source/default_column_values.rs @@ -0,0 +1,335 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! See `main.rs` for how to run it. 
+ +use std::any::Any; +use std::collections::HashMap; +use std::sync::Arc; + +use arrow::array::RecordBatch; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use async_trait::async_trait; + +use datafusion::assert_batches_eq; +use datafusion::catalog::memory::DataSourceExec; +use datafusion::catalog::{Session, TableProvider}; +use datafusion::common::DFSchema; +use datafusion::common::{Result, ScalarValue}; +use datafusion::datasource::listing::PartitionedFile; +use datafusion::datasource::physical_plan::{FileScanConfigBuilder, ParquetSource}; +use datafusion::execution::context::SessionContext; +use datafusion::execution::object_store::ObjectStoreUrl; +use datafusion::logical_expr::utils::conjunction; +use datafusion::logical_expr::{Expr, TableProviderFilterPushDown, TableType}; +use datafusion::parquet::arrow::ArrowWriter; +use datafusion::parquet::file::properties::WriterProperties; +use datafusion::physical_expr::PhysicalExpr; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::prelude::{SessionConfig, lit}; +use datafusion_physical_expr_adapter::{ + DefaultPhysicalExprAdapterFactory, PhysicalExprAdapter, PhysicalExprAdapterFactory, + replace_columns_with_literals, +}; +use futures::StreamExt; +use object_store::memory::InMemory; +use object_store::path::Path; +use object_store::{ObjectStore, ObjectStoreExt, PutPayload}; + +// Metadata key for storing default values in field metadata +const DEFAULT_VALUE_METADATA_KEY: &str = "example.default_value"; + +/// Example showing how to implement custom default value handling for missing columns +/// using field metadata and PhysicalExprAdapter. +/// +/// This example demonstrates how to: +/// 1. Store default values in field metadata using a constant key +/// 2. Create a custom PhysicalExprAdapter that reads these defaults +/// 3. Inject default values for missing columns in filter predicates using `replace_columns_with_literals` +/// 4. 
Use the DefaultPhysicalExprAdapter as a fallback for standard schema adaptation +/// 5. Convert string default values to proper types using `ScalarValue::cast_to()` at planning time +/// +/// Important: PhysicalExprAdapter handles rewriting both filter predicates and projection +/// expressions for file scans, including handling missing columns. +/// +/// The metadata-based approach provides a flexible way to store default values as strings +/// and cast them to the appropriate types at planning time, avoiding runtime overhead. +pub async fn default_column_values() -> Result<()> { + println!("=== Creating example data with missing columns and default values ==="); + + // Create sample data where the logical schema has more columns than the physical schema + let (logical_schema, physical_schema, batch) = create_sample_data_with_defaults(); + + let store = InMemory::new(); + let buf = { + let mut buf = vec![]; + + let props = WriterProperties::builder() + .set_max_row_group_row_count(Some(2)) + .build(); + + let mut writer = + ArrowWriter::try_new(&mut buf, physical_schema.clone(), Some(props))?; + + writer.write(&batch)?; + writer.close()?; + buf + }; + let path = Path::from("example.parquet"); + let payload = PutPayload::from_bytes(buf.into()); + store.put(&path, payload).await?; + + // Create a custom table provider that handles missing columns with defaults + let table_provider = Arc::new(DefaultValueTableProvider::new(logical_schema)); + + // Set up query execution + let mut cfg = SessionConfig::new(); + cfg.options_mut().execution.parquet.pushdown_filters = true; + let ctx = SessionContext::new_with_config(cfg); + + // Register our table + ctx.register_table("example_table", table_provider)?; + + ctx.runtime_env().register_object_store( + ObjectStoreUrl::parse("memory://")?.as_ref(), + Arc::new(store), + ); + + println!("\n=== Demonstrating default value injection in filter predicates ==="); + let query = "SELECT id, name FROM example_table WHERE status = 
'active' ORDER BY id"; + println!("Query: {query}"); + println!("Note: The 'status' column doesn't exist in the physical schema,"); + println!( + "but our adapter injects the default value 'active' for the filter predicate." + ); + + let batches = ctx.sql(query).await?.collect().await?; + + #[rustfmt::skip] + let expected = [ + "+----+-------+", + "| id | name |", + "+----+-------+", + "| 1 | Alice |", + "| 2 | Bob |", + "| 3 | Carol |", + "+----+-------+", + ]; + arrow::util::pretty::print_batches(&batches)?; + assert_batches_eq!(expected, &batches); + + println!("\n=== Key Insight ==="); + println!("This example demonstrates how PhysicalExprAdapter works:"); + println!("1. Physical schema only has 'id' and 'name' columns"); + println!( + "2. Logical schema has 'id', 'name', 'status', and 'priority' columns with defaults" + ); + println!( + "3. Our custom adapter uses replace_columns_with_literals to inject default values" + ); + println!("4. Default values from metadata are cast to proper types at planning time"); + println!("5. 
The DefaultPhysicalExprAdapter handles other schema adaptations"); + + Ok(()) +} + +/// Create sample data with a logical schema that has default values in metadata +/// and a physical schema that's missing some columns +fn create_sample_data_with_defaults() -> (SchemaRef, SchemaRef, RecordBatch) { + // Create metadata for default values + let mut status_metadata = HashMap::new(); + status_metadata.insert(DEFAULT_VALUE_METADATA_KEY.to_string(), "active".to_string()); + + let mut priority_metadata = HashMap::new(); + priority_metadata.insert(DEFAULT_VALUE_METADATA_KEY.to_string(), "1".to_string()); + + // The logical schema includes all columns with their default values in metadata + // Note: We make the columns with defaults nullable to allow the default adapter to handle them + let logical_schema = Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, false), + Field::new("status", DataType::Utf8, true).with_metadata(status_metadata), + Field::new("priority", DataType::Int32, true).with_metadata(priority_metadata), + ]); + + // The physical schema only has some columns (simulating missing columns in storage) + let physical_schema = Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, false), + ]); + + // Create sample data for the physical schema + let batch = RecordBatch::try_new( + Arc::new(physical_schema.clone()), + vec![ + Arc::new(arrow::array::Int32Array::from(vec![1, 2, 3])), + Arc::new(arrow::array::StringArray::from(vec![ + "Alice", "Bob", "Carol", + ])), + ], + ) + .unwrap(); + + (Arc::new(logical_schema), Arc::new(physical_schema), batch) +} + +/// Custom TableProvider that uses DefaultValuePhysicalExprAdapter +#[derive(Debug)] +struct DefaultValueTableProvider { + schema: SchemaRef, +} + +impl DefaultValueTableProvider { + fn new(schema: SchemaRef) -> Self { + Self { schema } + } +} + +#[async_trait] +impl TableProvider for DefaultValueTableProvider { + fn 
as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } + + fn table_type(&self) -> TableType { + TableType::Base + } + + fn supports_filters_pushdown( + &self, + filters: &[&Expr], + ) -> Result> { + Ok(vec![TableProviderFilterPushDown::Inexact; filters.len()]) + } + + async fn scan( + &self, + state: &dyn Session, + projection: Option<&Vec>, + filters: &[Expr], + limit: Option, + ) -> Result> { + let schema = Arc::clone(&self.schema); + let df_schema = DFSchema::try_from(schema.clone())?; + let filter = state.create_physical_expr( + conjunction(filters.iter().cloned()).unwrap_or_else(|| lit(true)), + &df_schema, + )?; + + let parquet_source = ParquetSource::new(schema.clone()) + .with_predicate(filter) + .with_pushdown_filters(true); + + let object_store_url = ObjectStoreUrl::parse("memory://")?; + let store = state.runtime_env().object_store(object_store_url)?; + + let mut files = vec![]; + let mut listing = store.list(None); + while let Some(file) = listing.next().await { + if let Ok(file) = file { + files.push(file); + } + } + + let file_group = files + .iter() + .map(|file| PartitionedFile::new(file.location.clone(), file.size)) + .collect(); + + let file_scan_config = FileScanConfigBuilder::new( + ObjectStoreUrl::parse("memory://")?, + Arc::new(parquet_source), + ) + .with_projection_indices(projection.cloned())? 
+ .with_limit(limit) + .with_file_group(file_group) + .with_expr_adapter(Some(Arc::new(DefaultValuePhysicalExprAdapterFactory) as _)); + + Ok(Arc::new(DataSourceExec::new(Arc::new( + file_scan_config.build(), + )))) + } +} + +/// Factory for creating DefaultValuePhysicalExprAdapter instances +#[derive(Debug)] +struct DefaultValuePhysicalExprAdapterFactory; + +impl PhysicalExprAdapterFactory for DefaultValuePhysicalExprAdapterFactory { + fn create( + &self, + logical_file_schema: SchemaRef, + physical_file_schema: SchemaRef, + ) -> Result> { + let default_factory = DefaultPhysicalExprAdapterFactory; + let default_adapter = default_factory.create( + Arc::clone(&logical_file_schema), + Arc::clone(&physical_file_schema), + )?; + + Ok(Arc::new(DefaultValuePhysicalExprAdapter { + logical_file_schema, + physical_file_schema, + default_adapter, + })) + } +} + +/// Custom PhysicalExprAdapter that handles missing columns with default values from metadata +/// and wraps DefaultPhysicalExprAdapter for standard schema adaptation +#[derive(Debug)] +struct DefaultValuePhysicalExprAdapter { + logical_file_schema: SchemaRef, + physical_file_schema: SchemaRef, + default_adapter: Arc, +} + +impl PhysicalExprAdapter for DefaultValuePhysicalExprAdapter { + fn rewrite(&self, expr: Arc) -> Result> { + // Pre-compute replacements for missing columns with default values + let mut replacements = HashMap::new(); + for field in self.logical_file_schema.fields() { + // Skip columns that exist in physical schema + if self.physical_file_schema.index_of(field.name()).is_ok() { + continue; + } + + // Check if this missing column has a default value in metadata + if let Some(default_str) = field.metadata().get(DEFAULT_VALUE_METADATA_KEY) { + // Create a Utf8 ScalarValue from the string and cast it to the target type + let string_value = ScalarValue::Utf8(Some(default_str.to_string())); + let typed_value = string_value.cast_to(field.data_type())?; + replacements.insert(field.name().as_str(), 
typed_value); + } + } + + // Replace columns with their default literals if any + let rewritten = if !replacements.is_empty() { + let refs: HashMap<_, _> = replacements.iter().map(|(k, v)| (*k, v)).collect(); + replace_columns_with_literals(expr, &refs)? + } else { + expr + }; + + // Apply the default adapter as a fallback for other schema adaptations + self.default_adapter.rewrite(rewritten) + } +} diff --git a/datafusion-examples/examples/file_stream_provider.rs b/datafusion-examples/examples/custom_data_source/file_stream_provider.rs similarity index 90% rename from datafusion-examples/examples/file_stream_provider.rs rename to datafusion-examples/examples/custom_data_source/file_stream_provider.rs index e6c59d57e98de..5b43072d43f80 100644 --- a/datafusion-examples/examples/file_stream_provider.rs +++ b/datafusion-examples/examples/custom_data_source/file_stream_provider.rs @@ -15,6 +15,31 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. + +/// Demonstrates how to use [`FileStreamProvider`] and [`StreamTable`] to stream data +/// from a file-like source (FIFO) into DataFusion for continuous querying. +/// +/// On non-Windows systems, this example creates a named pipe (FIFO) and +/// writes rows into it asynchronously while DataFusion reads the data +/// through a `FileStreamProvider`. +/// +/// This illustrates how to integrate dynamically updated data sources +/// with DataFusion without needing to reload the entire dataset each time. +/// +/// This example does not work on Windows. 
+pub async fn file_stream_provider() -> datafusion::error::Result<()> { + #[cfg(target_os = "windows")] + { + println!("file_stream_provider example does not work on windows"); + Ok(()) + } + #[cfg(not(target_os = "windows"))] + { + non_windows::main().await + } +} + #[cfg(not(target_os = "windows"))] mod non_windows { use datafusion::assert_batches_eq; @@ -22,8 +47,8 @@ mod non_windows { use std::fs::{File, OpenOptions}; use std::io::Write; use std::path::PathBuf; - use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; + use std::sync::atomic::{AtomicBool, Ordering}; use std::thread; use std::time::Duration; @@ -34,9 +59,9 @@ mod non_windows { use tempfile::TempDir; use tokio::task::JoinSet; - use datafusion::common::{exec_err, Result}; - use datafusion::datasource::stream::{FileStreamProvider, StreamConfig, StreamTable}; + use datafusion::common::{Result, exec_err}; use datafusion::datasource::TableProvider; + use datafusion::datasource::stream::{FileStreamProvider, StreamConfig, StreamTable}; use datafusion::logical_expr::SortExpr; use datafusion::prelude::{SessionConfig, SessionContext}; @@ -101,7 +126,6 @@ mod non_windows { let broken_pipe_timeout = Duration::from_secs(10); let sa = file_path; // Spawn a new thread to write to the FIFO file - #[allow(clippy::disallowed_methods)] // spawn allowed only in tests tasks.spawn_blocking(move || { let file = OpenOptions::new().write(true).open(sa).unwrap(); // Reference time to use when deciding to fail the test @@ -186,16 +210,3 @@ mod non_windows { Ok(()) } } - -#[tokio::main] -async fn main() -> datafusion::error::Result<()> { - #[cfg(target_os = "windows")] - { - println!("file_stream_provider example does not work on windows"); - Ok(()) - } - #[cfg(not(target_os = "windows"))] - { - non_windows::main().await - } -} diff --git a/datafusion-examples/examples/custom_data_source/main.rs b/datafusion-examples/examples/custom_data_source/main.rs new file mode 100644 index 0000000000000..0d21a62591129 --- 
/dev/null +++ b/datafusion-examples/examples/custom_data_source/main.rs @@ -0,0 +1,138 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! # These examples are all related to extending or defining how DataFusion reads data +//! +//! These examples demonstrate how DataFusion reads data. +//! +//! ## Usage +//! ```bash +//! cargo run --example custom_data_source -- [all|csv_json_opener|csv_sql_streaming|custom_datasource|custom_file_casts|custom_file_format|default_column_values|file_stream_provider] +//! ``` +//! +//! Each subcommand runs a corresponding example: +//! - `all` — run all examples included in this module +//! +//! - `adapter_serialization` +//! (file: adapter_serialization.rs, desc: Preserve custom PhysicalExprAdapter information during plan serialization using PhysicalExtensionCodec interception) +//! +//! - `csv_json_opener` +//! (file: csv_json_opener.rs, desc: Use low-level FileOpener APIs for CSV/JSON) +//! +//! - `csv_sql_streaming` +//! (file: csv_sql_streaming.rs, desc: Run a streaming SQL query against CSV data) +//! +//! - `custom_datasource` +//! (file: custom_datasource.rs, desc: Query a custom TableProvider) +//! +//! - `custom_file_casts` +//! 
(file: custom_file_casts.rs, desc: Implement custom casting rules) +//! +//! - `custom_file_format` +//! (file: custom_file_format.rs, desc: Write to a custom file format) +//! +//! - `default_column_values` +//! (file: default_column_values.rs, desc: Custom default values using metadata) +//! +//! - `file_stream_provider` +//! (file: file_stream_provider.rs, desc: Read/write via FileStreamProvider for streams) + +mod adapter_serialization; +mod csv_json_opener; +mod csv_sql_streaming; +mod custom_datasource; +mod custom_file_casts; +mod custom_file_format; +mod default_column_values; +mod file_stream_provider; + +use datafusion::error::{DataFusionError, Result}; +use strum::{IntoEnumIterator, VariantNames}; +use strum_macros::{Display, EnumIter, EnumString, VariantNames}; + +#[derive(EnumIter, EnumString, Display, VariantNames)] +#[strum(serialize_all = "snake_case")] +enum ExampleKind { + All, + AdapterSerialization, + CsvJsonOpener, + CsvSqlStreaming, + CustomDatasource, + CustomFileCasts, + CustomFileFormat, + DefaultColumnValues, + FileStreamProvider, +} + +impl ExampleKind { + const EXAMPLE_NAME: &str = "custom_data_source"; + + fn runnable() -> impl Iterator { + ExampleKind::iter().filter(|v| !matches!(v, ExampleKind::All)) + } + + async fn run(&self) -> Result<()> { + match self { + ExampleKind::All => { + for example in ExampleKind::runnable() { + println!("Running example: {example}"); + Box::pin(example.run()).await?; + } + } + ExampleKind::AdapterSerialization => { + adapter_serialization::adapter_serialization().await? + } + ExampleKind::CsvJsonOpener => csv_json_opener::csv_json_opener().await?, + ExampleKind::CsvSqlStreaming => { + csv_sql_streaming::csv_sql_streaming().await? + } + ExampleKind::CustomDatasource => { + custom_datasource::custom_datasource().await? + } + ExampleKind::CustomFileCasts => { + custom_file_casts::custom_file_casts().await? + } + ExampleKind::CustomFileFormat => { + custom_file_format::custom_file_format().await? 
+ } + ExampleKind::DefaultColumnValues => { + default_column_values::default_column_values().await? + } + ExampleKind::FileStreamProvider => { + file_stream_provider::file_stream_provider().await? + } + } + Ok(()) + } +} + +#[tokio::main] +async fn main() -> Result<()> { + let usage = format!( + "Usage: cargo run --example {} -- [{}]", + ExampleKind::EXAMPLE_NAME, + ExampleKind::VARIANTS.join("|") + ); + + let example: ExampleKind = std::env::args() + .nth(1) + .unwrap_or_else(|| ExampleKind::All.to_string()) + .parse() + .map_err(|_| DataFusionError::Execution(format!("Unknown example. {usage}")))?; + + example.run().await +} diff --git a/datafusion-examples/examples/catalog.rs b/datafusion-examples/examples/data_io/catalog.rs similarity index 97% rename from datafusion-examples/examples/catalog.rs rename to datafusion-examples/examples/data_io/catalog.rs index 229867cdfc5bb..9781a93374ea6 100644 --- a/datafusion-examples/examples/catalog.rs +++ b/datafusion-examples/examples/data_io/catalog.rs @@ -15,15 +15,17 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. +//! //! Simple example of a catalog/schema implementation. 
use async_trait::async_trait; use datafusion::{ arrow::util::pretty, catalog::{CatalogProvider, CatalogProviderList, SchemaProvider}, datasource::{ - file_format::{csv::CsvFormat, FileFormat}, - listing::{ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl}, TableProvider, + file_format::{FileFormat, csv::CsvFormat}, + listing::{ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl}, }, error::Result, execution::context::SessionState, @@ -34,8 +36,8 @@ use std::{any::Any, collections::HashMap, path::Path, sync::Arc}; use std::{fs::File, io::Write}; use tempfile::TempDir; -#[tokio::main] -async fn main() -> Result<()> { +/// Register the table into a custom catalog +pub async fn catalog() -> Result<()> { env_logger::builder() .filter_level(log::LevelFilter::Info) .init(); @@ -134,12 +136,13 @@ struct DirSchemaOpts<'a> { dir: &'a Path, format: Arc, } + /// Schema where every file with extension `ext` in a given `dir` is a table. #[derive(Debug)] struct DirSchema { - ext: String, tables: RwLock>>, } + impl DirSchema { async fn create(state: &SessionState, opts: DirSchemaOpts<'_>) -> Result> { let DirSchemaOpts { ext, dir, format } = opts; @@ -169,13 +172,8 @@ impl DirSchema { } Ok(Arc::new(Self { tables: RwLock::new(tables), - ext: ext.to_string(), })) } - #[allow(unused)] - fn name(&self) -> &str { - &self.ext - } } #[async_trait] @@ -198,6 +196,7 @@ impl SchemaProvider for DirSchema { let tables = self.tables.read().unwrap(); tables.contains_key(name) } + fn register_table( &self, name: String, @@ -211,7 +210,6 @@ impl SchemaProvider for DirSchema { /// If supported by the implementation, removes an existing table from this schema and returns it. /// If no table of that name exists, returns Ok(None). 
- #[allow(unused_variables)] fn deregister_table(&self, name: &str) -> Result>> { let mut tables = self.tables.write().unwrap(); log::info!("dropping table {name}"); @@ -223,6 +221,7 @@ impl SchemaProvider for DirSchema { struct DirCatalog { schemas: RwLock>>, } + impl DirCatalog { fn new() -> Self { Self { @@ -230,10 +229,12 @@ impl DirCatalog { } } } + impl CatalogProvider for DirCatalog { fn as_any(&self) -> &dyn Any { self } + fn register_schema( &self, name: &str, @@ -260,11 +261,13 @@ impl CatalogProvider for DirCatalog { } } } + /// Catalog lists holds multiple catalog providers. Each context has a single catalog list. #[derive(Debug)] struct CustomCatalogProviderList { catalogs: RwLock>>, } + impl CustomCatalogProviderList { fn new() -> Self { Self { @@ -272,10 +275,12 @@ impl CustomCatalogProviderList { } } } + impl CatalogProviderList for CustomCatalogProviderList { fn as_any(&self) -> &dyn Any { self } + fn register_catalog( &self, name: String, diff --git a/datafusion-examples/examples/data_io/json_shredding.rs b/datafusion-examples/examples/data_io/json_shredding.rs new file mode 100644 index 0000000000000..ca1513f626245 --- /dev/null +++ b/datafusion-examples/examples/data_io/json_shredding.rs @@ -0,0 +1,363 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +//! See `main.rs` for how to run it. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{RecordBatch, StringArray}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; + +use datafusion::assert_batches_eq; +use datafusion::common::tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, +}; +use datafusion::common::{Result, assert_contains, exec_datafusion_err}; +use datafusion::datasource::listing::{ + ListingTable, ListingTableConfig, ListingTableConfigExt, ListingTableUrl, +}; +use datafusion::execution::context::SessionContext; +use datafusion::execution::object_store::ObjectStoreUrl; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion::parquet::arrow::ArrowWriter; +use datafusion::parquet::file::properties::WriterProperties; +use datafusion::physical_expr::PhysicalExpr; +use datafusion::physical_expr::{ScalarFunctionExpr, expressions}; +use datafusion::prelude::SessionConfig; +use datafusion::scalar::ScalarValue; +use datafusion_physical_expr_adapter::{ + DefaultPhysicalExprAdapterFactory, PhysicalExprAdapter, PhysicalExprAdapterFactory, +}; +use object_store::memory::InMemory; +use object_store::path::Path; +use object_store::{ObjectStoreExt, PutPayload}; + +// Example showing how to implement custom filter rewriting for JSON shredding. +// +// JSON shredding is a technique for optimizing queries on semi-structured data +// by materializing commonly accessed fields into separate columns for better +// columnar storage performance. 
+// +// In this example, we have a table with both: +// - Original JSON data: data: '{"age": 30}' +// - Shredded flat columns: _data.name: "Alice" (extracted from JSON) +// +// Our custom TableProvider uses a PhysicalExprAdapter to rewrite +// expressions like `json_get_str('name', data)` to use the pre-computed +// flat column `_data.name` when available. This allows the query engine to: +// 1. Push down predicates for better filtering +// 2. Avoid expensive JSON parsing at query time +// 3. Leverage columnar storage benefits for the materialized fields +pub async fn json_shredding() -> Result<()> { + println!("=== Creating example data with flat columns and underscore prefixes ==="); + + // Create sample data with flat columns using underscore prefixes + let (table_schema, batch) = create_sample_data(); + + let store = InMemory::new(); + let buf = { + let mut buf = vec![]; + + let props = WriterProperties::builder() + .set_max_row_group_row_count(Some(2)) + .build(); + + let mut writer = ArrowWriter::try_new(&mut buf, batch.schema(), Some(props)) + .expect("creating writer"); + + writer.write(&batch).expect("Writing batch"); + writer.close().unwrap(); + buf + }; + let path = Path::from("example.parquet"); + let payload = PutPayload::from_bytes(buf.into()); + store.put(&path, payload).await?; + + // Set up query execution + let mut cfg = SessionConfig::new(); + cfg.options_mut().execution.parquet.pushdown_filters = true; + let ctx = SessionContext::new_with_config(cfg); + ctx.runtime_env().register_object_store( + ObjectStoreUrl::parse("memory://")?.as_ref(), + Arc::new(store), + ); + + // Create a custom table provider that rewrites struct field access + let listing_table_config = + ListingTableConfig::new(ListingTableUrl::parse("memory:///example.parquet")?) + .infer_options(&ctx.state()) + .await? 
+ .with_schema(table_schema) + .with_expr_adapter_factory(Arc::new(ShreddedJsonRewriterFactory)); + let table = ListingTable::try_new(listing_table_config).unwrap(); + let table_provider = Arc::new(table); + + // Register our table + ctx.register_table("structs", table_provider)?; + ctx.register_udf(ScalarUDF::new_from_impl(JsonGetStr::default())); + + println!("\n=== Showing all data ==="); + let batches = ctx.sql("SELECT * FROM structs").await?.collect().await?; + arrow::util::pretty::print_batches(&batches)?; + + println!("\n=== Running query with flat column access and filter ==="); + let query = "SELECT json_get_str('age', data) as age FROM structs WHERE json_get_str('name', data) = 'Bob'"; + println!("Query: {query}"); + + let batches = ctx.sql(query).await?.collect().await?; + + #[rustfmt::skip] + let expected = [ + "+-----+", + "| age |", + "+-----+", + "| 25 |", + "+-----+", + ]; + arrow::util::pretty::print_batches(&batches)?; + assert_batches_eq!(expected, &batches); + + println!("\n=== Running explain analyze to confirm row group pruning ==="); + + let batches = ctx + .sql(&format!("EXPLAIN ANALYZE {query}")) + .await? + .collect() + .await?; + let plan = format!("{}", arrow::util::pretty::pretty_format_batches(&batches)?); + println!("{plan}"); + assert_contains!(&plan, "row_groups_pruned_statistics=2 total → 1 matched"); + assert_contains!(&plan, "pushdown_rows_pruned=1"); + + Ok(()) +} + +/// Create the example data with flat columns using underscore prefixes. 
+/// +/// This demonstrates the logical data structure: +/// - Table schema: What users see (just the 'data' JSON column) +/// - File schema: What's physically stored (both 'data' and materialized '_data.name') +/// +/// The naming convention uses underscore prefixes to indicate shredded columns: +/// - `data` -> original JSON column +/// - `_data.name` -> materialized field from JSON data.name +fn create_sample_data() -> (SchemaRef, RecordBatch) { + // The table schema only has the main data column - this is what users query against + let table_schema = Schema::new(vec![Field::new("data", DataType::Utf8, false)]); + + // The file schema has both the main column and the shredded flat column with underscore prefix + // This represents the actual physical storage with pre-computed columns + let file_schema = Schema::new(vec![ + Field::new("data", DataType::Utf8, false), // Original JSON data + Field::new("_data.name", DataType::Utf8, false), // Materialized name field + ]); + + let batch = create_sample_record_batch(&file_schema); + + (Arc::new(table_schema), batch) +} + +/// Create the actual RecordBatch with sample data +fn create_sample_record_batch(file_schema: &Schema) -> RecordBatch { + // Build a RecordBatch with flat columns + let data_array = StringArray::from(vec![ + r#"{"age": 30}"#, + r#"{"age": 25}"#, + r#"{"age": 35}"#, + r#"{"age": 22}"#, + ]); + let names_array = StringArray::from(vec!["Alice", "Bob", "Charlie", "Dave"]); + + RecordBatch::try_new( + Arc::new(file_schema.clone()), + vec![Arc::new(data_array), Arc::new(names_array)], + ) + .unwrap() +} + +/// Scalar UDF that uses serde_json to access json fields +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct JsonGetStr { + signature: Signature, +} + +impl Default for JsonGetStr { + fn default() -> Self { + Self { + signature: Signature::variadic_any(Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for JsonGetStr { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + 
"json_get_str" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Utf8) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + assert!( + args.args.len() == 2, + "json_get_str requires exactly 2 arguments" + ); + let key = match &args.args[0] { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(key))) => key, + _ => { + return Err(exec_datafusion_err!( + "json_get_str first argument must be a string" + )); + } + }; + // We expect a string array that contains JSON strings + let json_array = match &args.args[1] { + ColumnarValue::Array(array) => array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + exec_datafusion_err!( + "json_get_str second argument must be a string array" + ) + })?, + _ => { + return Err(exec_datafusion_err!( + "json_get_str second argument must be a string array" + )); + } + }; + let values = json_array + .iter() + .map(|value| { + value.and_then(|v| { + let json_value: serde_json::Value = + serde_json::from_str(v).unwrap_or_default(); + json_value.get(key).map(|v| v.to_string()) + }) + }) + .collect::(); + Ok(ColumnarValue::Array(Arc::new(values))) + } +} + +/// Factory for creating ShreddedJsonRewriter instances +#[derive(Debug)] +struct ShreddedJsonRewriterFactory; + +impl PhysicalExprAdapterFactory for ShreddedJsonRewriterFactory { + fn create( + &self, + logical_file_schema: SchemaRef, + physical_file_schema: SchemaRef, + ) -> Result> { + let default_factory = DefaultPhysicalExprAdapterFactory; + let default_adapter = default_factory.create( + Arc::clone(&logical_file_schema), + Arc::clone(&physical_file_schema), + )?; + + Ok(Arc::new(ShreddedJsonRewriter { + physical_file_schema, + default_adapter, + })) + } +} + +/// Rewriter that converts json_get_str calls to direct flat column references +/// and wraps DefaultPhysicalExprAdapter for standard schema adaptation +#[derive(Debug)] +struct ShreddedJsonRewriter { + 
physical_file_schema: SchemaRef, + default_adapter: Arc, +} + +impl PhysicalExprAdapter for ShreddedJsonRewriter { + fn rewrite(&self, expr: Arc) -> Result> { + // First try our custom JSON shredding rewrite + let rewritten = expr + .transform(|expr| self.rewrite_impl(expr, &self.physical_file_schema)) + .data()?; + + // Then apply the default adapter as a fallback to handle standard schema differences + // like type casting and missing columns + self.default_adapter.rewrite(rewritten) + } +} + +impl ShreddedJsonRewriter { + fn rewrite_impl( + &self, + expr: Arc, + physical_file_schema: &Schema, + ) -> Result>> { + if let Some(func) = expr.as_any().downcast_ref::() + && func.name() == "json_get_str" + && func.args().len() == 2 + { + // Get the key from the first argument + if let Some(literal) = func.args()[0] + .as_any() + .downcast_ref::() + && let ScalarValue::Utf8(Some(field_name)) = literal.value() + { + // Get the column from the second argument + if let Some(column) = func.args()[1] + .as_any() + .downcast_ref::() + { + let column_name = column.name(); + // Check if there's a flat column with underscore prefix + let flat_column_name = format!("_{column_name}.{field_name}"); + + if let Ok(flat_field_index) = + physical_file_schema.index_of(&flat_column_name) + { + let flat_field = physical_file_schema.field(flat_field_index); + + if flat_field.data_type() == &DataType::Utf8 { + // Replace the whole expression with a direct column reference + let new_expr = Arc::new(expressions::Column::new( + &flat_column_name, + flat_field_index, + )) + as Arc; + + return Ok(Transformed { + data: new_expr, + tnr: TreeNodeRecursion::Stop, + transformed: true, + }); + } + } + } + } + } + Ok(Transformed::no(expr)) + } +} diff --git a/datafusion-examples/examples/data_io/main.rs b/datafusion-examples/examples/data_io/main.rs new file mode 100644 index 0000000000000..0039585d15b60 --- /dev/null +++ b/datafusion-examples/examples/data_io/main.rs @@ -0,0 +1,144 @@ +// Licensed to 
the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! # These examples of data formats and I/O +//! +//! These examples demonstrate data formats and I/O. +//! +//! ## Usage +//! ```bash +//! cargo run --example data_io -- [all|catalog|json_shredding|parquet_adv_idx|parquet_emb_idx|parquet_enc_with_kms|parquet_enc|parquet_exec_visitor|parquet_idx|query_http_csv|remote_catalog] +//! ``` +//! +//! Each subcommand runs a corresponding example: +//! - `all` — run all examples included in this module +//! +//! - `catalog` +//! (file: catalog.rs, desc: Register tables into a custom catalog) +//! +//! - `json_shredding` +//! (file: json_shredding.rs, desc: Implement filter rewriting for JSON shredding) +//! +//! - `parquet_adv_idx` +//! (file: parquet_advanced_index.rs, desc: Create a secondary index across multiple parquet files) +//! +//! - `parquet_emb_idx` +//! (file: parquet_embedded_index.rs, desc: Store a custom index inside Parquet files) +//! +//! - `parquet_enc` +//! (file: parquet_encrypted.rs, desc: Read & write encrypted Parquet files) +//! +//! - `parquet_enc_with_kms` +//! (file: parquet_encrypted_with_kms.rs, desc: Encrypted Parquet I/O using a KMS-backed factory) +//! +//! - `parquet_exec_visitor` +//! 
(file: parquet_exec_visitor.rs, desc: Extract statistics by visiting an ExecutionPlan) +//! +//! - `parquet_idx` +//! (file: parquet_index.rs, desc: Create a secondary index) +//! +//! - `query_http_csv` +//! (file: query_http_csv.rs, desc: Query CSV files via HTTP) +//! +//! - `remote_catalog` +//! (file: remote_catalog.rs, desc: Interact with a remote catalog) + +mod catalog; +mod json_shredding; +mod parquet_advanced_index; +mod parquet_embedded_index; +mod parquet_encrypted; +mod parquet_encrypted_with_kms; +mod parquet_exec_visitor; +mod parquet_index; +mod query_http_csv; +mod remote_catalog; + +use datafusion::error::{DataFusionError, Result}; +use strum::{IntoEnumIterator, VariantNames}; +use strum_macros::{Display, EnumIter, EnumString, VariantNames}; + +#[derive(EnumIter, EnumString, Display, VariantNames)] +#[strum(serialize_all = "snake_case")] +enum ExampleKind { + All, + Catalog, + JsonShredding, + ParquetAdvIdx, + ParquetEmbIdx, + ParquetEnc, + ParquetEncWithKms, + ParquetExecVisitor, + ParquetIdx, + QueryHttpCsv, + RemoteCatalog, +} + +impl ExampleKind { + const EXAMPLE_NAME: &str = "data_io"; + + fn runnable() -> impl Iterator { + ExampleKind::iter().filter(|v| !matches!(v, ExampleKind::All)) + } + + async fn run(&self) -> Result<()> { + match self { + ExampleKind::All => { + for example in ExampleKind::runnable() { + println!("Running example: {example}"); + Box::pin(example.run()).await?; + } + } + ExampleKind::Catalog => catalog::catalog().await?, + ExampleKind::JsonShredding => json_shredding::json_shredding().await?, + ExampleKind::ParquetAdvIdx => { + parquet_advanced_index::parquet_advanced_index().await? + } + ExampleKind::ParquetEmbIdx => { + parquet_embedded_index::parquet_embedded_index().await? + } + ExampleKind::ParquetEncWithKms => { + parquet_encrypted_with_kms::parquet_encrypted_with_kms().await? 
+ } + ExampleKind::ParquetEnc => parquet_encrypted::parquet_encrypted().await?, + ExampleKind::ParquetExecVisitor => { + parquet_exec_visitor::parquet_exec_visitor().await? + } + ExampleKind::ParquetIdx => parquet_index::parquet_index().await?, + ExampleKind::QueryHttpCsv => query_http_csv::query_http_csv().await?, + ExampleKind::RemoteCatalog => remote_catalog::remote_catalog().await?, + } + Ok(()) + } +} + +#[tokio::main] +async fn main() -> Result<()> { + let usage = format!( + "Usage: cargo run --example {} -- [{}]", + ExampleKind::EXAMPLE_NAME, + ExampleKind::VARIANTS.join("|") + ); + + let example: ExampleKind = std::env::args() + .nth(1) + .unwrap_or_else(|| ExampleKind::All.to_string()) + .parse() + .map_err(|_| DataFusionError::Execution(format!("Unknown example. {usage}")))?; + + example.run().await +} diff --git a/datafusion-examples/examples/advanced_parquet_index.rs b/datafusion-examples/examples/data_io/parquet_advanced_index.rs similarity index 96% rename from datafusion-examples/examples/advanced_parquet_index.rs rename to datafusion-examples/examples/data_io/parquet_advanced_index.rs index efaee23366a1c..f02b01354b784 100644 --- a/datafusion-examples/examples/advanced_parquet_index.rs +++ b/datafusion-examples/examples/data_io/parquet_advanced_index.rs @@ -15,40 +15,42 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. 
+ use std::any::Any; use std::collections::{HashMap, HashSet}; use std::fs::File; use std::ops::Range; use std::path::{Path, PathBuf}; -use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; +use std::sync::atomic::{AtomicBool, Ordering}; use datafusion::catalog::Session; use datafusion::common::{ - internal_datafusion_err, DFSchema, DataFusionError, Result, ScalarValue, + DFSchema, DataFusionError, Result, ScalarValue, internal_datafusion_err, }; +use datafusion::datasource::TableProvider; use datafusion::datasource::listing::PartitionedFile; use datafusion::datasource::physical_plan::parquet::ParquetAccessPlan; use datafusion::datasource::physical_plan::{ - FileMeta, FileScanConfigBuilder, ParquetFileReaderFactory, ParquetSource, + FileScanConfigBuilder, ParquetFileReaderFactory, ParquetSource, }; -use datafusion::datasource::TableProvider; use datafusion::execution::object_store::ObjectStoreUrl; use datafusion::logical_expr::utils::conjunction; use datafusion::logical_expr::{TableProviderFilterPushDown, TableType}; +use datafusion::parquet::arrow::ArrowWriter; use datafusion::parquet::arrow::arrow_reader::{ ArrowReaderOptions, ParquetRecordBatchReaderBuilder, RowSelection, RowSelector, }; use datafusion::parquet::arrow::async_reader::{AsyncFileReader, ParquetObjectReader}; -use datafusion::parquet::arrow::ArrowWriter; -use datafusion::parquet::file::metadata::ParquetMetaData; +use datafusion::parquet::file::metadata::{PageIndexPolicy, ParquetMetaData}; use datafusion::parquet::file::properties::{EnabledStatistics, WriterProperties}; use datafusion::parquet::schema::types::ColumnPath; -use datafusion::physical_expr::utils::{Guarantee, LiteralGuarantee}; use datafusion::physical_expr::PhysicalExpr; +use datafusion::physical_expr::utils::{Guarantee, LiteralGuarantee}; use datafusion::physical_optimizer::pruning::PruningPredicate; -use datafusion::physical_plan::metrics::ExecutionPlanMetricsSet; use datafusion::physical_plan::ExecutionPlan; +use 
datafusion::physical_plan::metrics::ExecutionPlanMetricsSet; use datafusion::prelude::*; use arrow::array::{ArrayRef, Int32Array, RecordBatch, StringArray}; @@ -56,8 +58,8 @@ use arrow::datatypes::SchemaRef; use async_trait::async_trait; use bytes::Bytes; use datafusion::datasource::memory::DataSourceExec; -use futures::future::BoxFuture; use futures::FutureExt; +use futures::future::BoxFuture; use object_store::ObjectStore; use tempfile::TempDir; use url::Url; @@ -121,7 +123,6 @@ use url::Url; /// │ ╚═══════════════════╝ │ 1. With cached ParquetMetadata, so /// └───────────────────────┘ the ParquetSource does not re-read / /// Parquet File decode the thrift footer -/// /// ``` /// /// Within a Row Group, Column Chunks store data in DataPages. This example also @@ -156,8 +157,7 @@ use url::Url; /// /// [`ListingTable`]: datafusion::datasource::listing::ListingTable /// [Page Index](https://github.com/apache/parquet-format/blob/master/PageIndex.md) -#[tokio::main] -async fn main() -> Result<()> { +pub async fn parquet_advanced_index() -> Result<()> { // the object store is used to read the parquet files (in this case, it is // a local file system, but in a real system it could be S3, GCS, etc) let object_store: Arc = @@ -240,6 +240,7 @@ pub struct IndexTableProvider { /// if true, use row selections in addition to row group selections use_row_selections: AtomicBool, } + impl IndexTableProvider { /// Create a new IndexTableProvider /// * `object_store` - the object store implementation to use for reading files @@ -409,7 +410,7 @@ impl IndexedFile { let options = ArrowReaderOptions::new() // Load the page index when reading metadata to cache // so it is available to interpret row selections - .with_page_index(true); + .with_page_index_policy(PageIndexPolicy::Required); let reader = ParquetRecordBatchReaderBuilder::try_new_with_options(file, options)?; let metadata = reader.metadata().clone(); @@ -492,19 +493,18 @@ impl TableProvider for IndexTableProvider { 
.with_file(indexed_file); let file_source = Arc::new( - ParquetSource::default() + ParquetSource::new(schema.clone()) // provide the predicate so the DataSourceExec can try and prune // row groups internally .with_predicate(predicate) // provide the factory to create parquet reader without re-reading metadata .with_parquet_file_reader_factory(Arc::new(reader_factory)), ); - let file_scan_config = - FileScanConfigBuilder::new(object_store_url, schema, file_source) - .with_limit(limit) - .with_projection(projection.cloned()) - .with_file(partitioned_file) - .build(); + let file_scan_config = FileScanConfigBuilder::new(object_store_url, file_source) + .with_limit(limit) + .with_projection_indices(projection.cloned())? + .with_file(partitioned_file) + .build(); // Finally, put it all together into a DataSourceExec Ok(DataSourceExec::from_data_source(file_scan_config)) @@ -541,6 +541,7 @@ impl CachedParquetFileReaderFactory { metadata: HashMap::new(), } } + /// Add the pre-parsed information about the file to the factor fn with_file(mut self, indexed_file: &IndexedFile) -> Self { self.metadata.insert( @@ -555,25 +556,26 @@ impl ParquetFileReaderFactory for CachedParquetFileReaderFactory { fn create_reader( &self, _partition_index: usize, - file_meta: FileMeta, + partitioned_file: PartitionedFile, metadata_size_hint: Option, _metrics: &ExecutionPlanMetricsSet, ) -> Result> { // for this example we ignore the partition index and metrics // but in a real system you would likely use them to report details on // the performance of the reader. 
- let filename = file_meta - .location() + let filename = partitioned_file + .object_meta + .location .parts() - .last() + .next_back() .expect("No path in location") .as_ref() .to_string(); let object_store = Arc::clone(&self.object_store); let mut inner = - ParquetObjectReader::new(object_store, file_meta.object_meta.location) - .with_file_size(file_meta.object_meta.size); + ParquetObjectReader::new(object_store, partitioned_file.object_meta.location) + .with_file_size(partitioned_file.object_meta.size); if let Some(hint) = metadata_size_hint { inner = inner.with_footer_size_hint(hint) @@ -657,7 +659,7 @@ fn make_demo_file(path: impl AsRef, value_range: Range) -> Result<()> // enable page statistics for the tag column, // for everything else. let props = WriterProperties::builder() - .set_max_row_group_size(100) + .set_max_row_group_row_count(Some(100)) // compute column chunk (per row group) statistics by default .set_statistics_enabled(EnabledStatistics::Chunk) // compute column page statistics for the tag column diff --git a/datafusion-examples/examples/data_io/parquet_embedded_index.rs b/datafusion-examples/examples/data_io/parquet_embedded_index.rs new file mode 100644 index 0000000000000..bcaca2ed5c85b --- /dev/null +++ b/datafusion-examples/examples/data_io/parquet_embedded_index.rs @@ -0,0 +1,475 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! See `main.rs` for how to run it. +//! +//! Embedding and using a custom index in Parquet files +//! +//! # Background +//! +//! This example shows how to add an application‑specific index to an Apache +//! Parquet file without modifying the Parquet format itself. The resulting +//! files can be read by any standard Parquet reader, which will simply +//! ignore the extra index data. +//! +//! A “distinct value” index, similar to a ["set" Skip Index in ClickHouse], +//! is stored in a custom binary format within the parquet file. Only the +//! location of index is stored in Parquet footer key/value metadata. +//! This approach is more efficient than storing the index itself in the footer +//! metadata because the footer must be read and parsed by all readers, +//! even those that do not use the index. +//! +//! This example uses a file level index for skipping entire files, but any +//! index can be stored using the same techniques and used skip row groups, +//! data pages, or rows using the APIs on [`TableProvider`] and [`ParquetSource`]. +//! +//! The resulting Parquet file layout is as follows: +//! +//! ```text +//! ┌──────────────────────┐ +//! │┌───────────────────┐ │ +//! ││ DataPage │ │ +//! │└───────────────────┘ │ +//! Standard Parquet │┌───────────────────┐ │ +//! Data Pages ││ DataPage │ │ +//! │└───────────────────┘ │ +//! │ ... │ +//! │┌───────────────────┐ │ +//! ││ DataPage │ │ +//! │└───────────────────┘ │ +//! │┏━━━━━━━━━━━━━━━━━━━┓ │ +//! Non standard │┃ ┃ │ +//! index (ignored by │┃Custom Binary Index┃ │ +//! 
other Parquet │┃ (Distinct Values) ┃◀│─ ─ ─ +//! readers) │┃ ┃ │ │ +//! │┗━━━━━━━━━━━━━━━━━━━┛ │ +//! Standard Parquet │┏━━━━━━━━━━━━━━━━━━━┓ │ │ key/value metadata +//! Page Index │┃ Page Index ┃ │ contains location +//! │┗━━━━━━━━━━━━━━━━━━━┛ │ │ of special index +//! │╔═══════════════════╗ │ +//! │║ Parquet Footer w/ ║ │ │ +//! │║ Metadata ║ ┼ ─ ─ +//! │║ (Thrift Encoded) ║ │ +//! │╚═══════════════════╝ │ +//! └──────────────────────┘ +//! +//! Parquet File +//! +//! # High Level Flow +//! +//! To create a custom Parquet index: +//! +//! 1. Compute the index and serialize it to a binary format. +//! +//! 2. Write the Parquet file with: +//! - regular data pages +//! - the serialized index inline +//! - footer key/value metadata entry to locate the index +//! +//! To read and use the index: +//! +//! 1. Read and deserialize the file’s footer to locate the index. +//! +//! 2. Read and deserialize the index. +//! +//! 3. Create a `TableProvider` that knows how to use the index to quickly find +//! the relevant files, row groups, data pages or rows based on pushed down +//! filters. +//! +//! # FAQ: Why do other Parquet readers skip over the custom index? +//! +//! The flow for reading a parquet file is: +//! +//! 1. Seek to the end of the file and read the last 8 bytes (a 4‑byte +//! little‑endian footer length followed by the `PAR1` magic bytes). +//! +//! 2. Seek backwards by that length to parse the Thrift‑encoded footer +//! metadata (including key/value pairs). +//! +//! 3. Read data required for decoding such as data pages based on the offsets +//! encoded in the metadata. +//! +//! Since parquet readers do not scan from the start of the file they will not read +//! data in the file unless it is explicitly referenced in the footer metadata. +//! +//! Thus other readers will encounter and ignore an unknown key +//! (`distinct_index_offset`) in the footer key/value metadata. Unless they +//!
know how to use that information, they will not attempt to read or +//! interpret the bytes that make up the index. +//! +//! ["set" Skip Index in ClickHouse]: https://clickhouse.com/docs/optimize/skipping-indexes#set + +use arrow::array::{ArrayRef, StringArray}; +use arrow::record_batch::RecordBatch; +use arrow_schema::{DataType, Field, Schema, SchemaRef}; +use async_trait::async_trait; +use datafusion::catalog::{Session, TableProvider}; +use datafusion::common::{HashMap, HashSet, Result, exec_err}; +use datafusion::datasource::TableType; +use datafusion::datasource::listing::PartitionedFile; +use datafusion::datasource::memory::DataSourceExec; +use datafusion::datasource::physical_plan::{FileScanConfigBuilder, ParquetSource}; +use datafusion::execution::object_store::ObjectStoreUrl; +use datafusion::logical_expr::{Operator, TableProviderFilterPushDown}; +use datafusion::parquet::arrow::ArrowWriter; +use datafusion::parquet::errors::ParquetError; +use datafusion::parquet::file::metadata::{FileMetaData, KeyValue}; +use datafusion::parquet::file::reader::{FileReader, SerializedFileReader}; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::prelude::*; +use datafusion::scalar::ScalarValue; +use std::fs::{File, read_dir}; +use std::io::{Read, Seek, SeekFrom, Write}; +use std::path::{Path, PathBuf}; +use std::sync::Arc; +use tempfile::TempDir; + +/// Store a custom index inside a Parquet file and use it to speed up queries +pub async fn parquet_embedded_index() -> Result<()> { + // 1. Create temp dir and write 3 Parquet files with different category sets + let tmp = TempDir::new()?; + let dir = tmp.path(); + write_file_with_index(&dir.join("a.parquet"), &["foo", "bar", "foo"])?; + write_file_with_index(&dir.join("b.parquet"), &["baz", "qux"])?; + write_file_with_index(&dir.join("c.parquet"), &["foo", "quux", "quux"])?; + + // 2.
Register our custom TableProvider + let field = Field::new("category", DataType::Utf8, false); + let schema_ref = Arc::new(Schema::new(vec![field])); + let provider = Arc::new(DistinctIndexTable::try_new(dir, schema_ref.clone())?); + + let ctx = SessionContext::new(); + ctx.register_table("t", provider)?; + + // 3. Run a query: only files containing 'foo' get scanned. The rest are pruned + // based on the distinct index. + let df = ctx.sql("SELECT * FROM t WHERE category = 'foo'").await?; + df.show().await?; + + Ok(()) +} + +/// An index of distinct values for a single column +/// +/// In this example the index is a simple set of strings, but in a real +/// application it could be any arbitrary data structure. +/// +/// Also, this example indexes the distinct values for an entire file +/// but a real application could create multiple indexes for multiple +/// row groups and/or columns, depending on the use case. +#[derive(Debug, Clone)] +struct DistinctIndex { + inner: HashSet, +} + +impl DistinctIndex { + /// Create a DistinctIndex from an iterator of strings + pub fn new>(iter: I) -> Self { + Self { + inner: iter.into_iter().collect(), + } + } + + /// Returns true if the index contains the given value + pub fn contains(&self, value: &str) -> bool { + self.inner.contains(value) + } + + /// Serialize the distinct index to a writer as bytes + /// + /// In this example, we use a simple newline-separated format, + /// but a real application can use any arbitrary binary format. + /// + /// Note that we must use the ArrowWriter to write the index so that its + /// internal accounting of offsets can correctly track the actual size of + /// the file. If we wrote directly to the underlying writer, the PageIndex + /// offsets written right before the footer would be incorrect as they would not account + /// for the extra bytes written.
+ fn serialize( + &self, + arrow_writer: &mut ArrowWriter, + ) -> Result<()> { + let serialized = self + .inner + .iter() + .map(|s| s.as_str()) + .collect::>() + .join("\n"); + let index_bytes = serialized.into_bytes(); + + // Set the offset for the index + let offset = arrow_writer.bytes_written(); + let index_len = index_bytes.len() as u64; + + println!("Writing custom index at offset: {offset}, length: {index_len}"); + // Write the index magic and length to the file + arrow_writer.write_all(INDEX_MAGIC)?; + arrow_writer.write_all(&index_len.to_le_bytes())?; + + // Write the index bytes + arrow_writer.write_all(&index_bytes)?; + + // Append metadata about the index to the Parquet file footer + arrow_writer.append_key_value_metadata(KeyValue::new( + "distinct_index_offset".to_string(), + offset.to_string(), + )); + Ok(()) + } + + /// Read the distinct values index from a reader at the given offset + pub fn new_from_reader(mut reader: R, offset: u64) -> Result { + reader.seek(SeekFrom::Start(offset))?; + + let mut magic_buf = [0u8; 4]; + reader.read_exact(&mut magic_buf)?; + if magic_buf != INDEX_MAGIC { + return exec_err!("Invalid index magic number at offset {offset}"); + } + + let mut len_buf = [0u8; 8]; + reader.read_exact(&mut len_buf)?; + let stored_len = u64::from_le_bytes(len_buf) as usize; + + let mut index_buf = vec![0u8; stored_len]; + reader.read_exact(&mut index_buf)?; + + let Ok(s) = String::from_utf8(index_buf) else { + return exec_err!("Invalid UTF-8 in index data"); + }; + + Ok(Self { + inner: s.lines().map(|s| s.to_string()).collect(), + }) + } +} + +/// DataFusion [`TableProvider`] that reads Parquet files and uses a +/// `DistinctIndex` to prune files based on pushed down filters.
+#[derive(Debug)] +struct DistinctIndexTable { + /// The schema of the table + schema: SchemaRef, + /// Key is file name, value is DistinctIndex for that file + files_and_index: HashMap, + /// Directory containing the Parquet files + dir: PathBuf, +} + +impl DistinctIndexTable { + /// Create a new DistinctIndexTable for files in the given directory + /// + /// Scans the directory, reading the `DistinctIndex` from each file + fn try_new(dir: impl Into, schema: SchemaRef) -> Result { + let dir = dir.into(); + let mut index = HashMap::new(); + + for entry in read_dir(&dir)? { + let path = entry?.path(); + if path.extension().and_then(|s| s.to_str()) != Some("parquet") { + continue; + } + let file_name = path.file_name().unwrap().to_string_lossy().to_string(); + + let distinct_set = read_distinct_index(&path)?; + + println!("Read distinct index for {file_name}: {file_name:?}"); + index.insert(file_name, distinct_set); + } + + Ok(Self { + schema, + files_and_index: index, + dir, + }) + } +} + +/// Wrapper around ArrowWriter to write Parquet files with an embedded index +struct IndexedParquetWriter { + writer: ArrowWriter, +} + +/// Magic bytes to identify our custom index format +const INDEX_MAGIC: &[u8] = b"IDX1"; + +impl IndexedParquetWriter { + pub fn try_new(sink: W, schema: Arc) -> Result { + let writer = ArrowWriter::try_new(sink, schema, None)?; + Ok(Self { writer }) + } + + /// Write a RecordBatch to the Parquet file + pub fn write(&mut self, batch: &RecordBatch) -> Result<()> { + self.writer.write(batch)?; + Ok(()) + } + + /// Flush the current row group + pub fn flush(&mut self) -> Result<()> { + self.writer.flush()?; + Ok(()) + } + + /// Close the Parquet file, flushing any remaining data + pub fn close(self) -> Result<()> { + self.writer.close()?; + Ok(()) + } + + /// write the DistinctIndex to the Parquet file + pub fn write_index(&mut self, index: &DistinctIndex) -> Result<()> { + index.serialize(&mut self.writer) + } +} + +/// Write a Parquet file with a 
single column "category" containing the +/// strings in `values` and a DistinctIndex for that column. +fn write_file_with_index(path: &Path, values: &[&str]) -> Result<()> { + // form an input RecordBatch with the string values + let field = Field::new("category", DataType::Utf8, false); + let schema = Arc::new(Schema::new(vec![field.clone()])); + let arr: ArrayRef = Arc::new(StringArray::from(values.to_vec())); + let batch = RecordBatch::try_new(schema.clone(), vec![arr])?; + + // compute the distinct index + let distinct_index: DistinctIndex = + DistinctIndex::new(values.iter().map(|s| (*s).to_string())); + + let file = File::create(path)?; + + let mut writer = IndexedParquetWriter::try_new(file, schema.clone())?; + writer.write(&batch)?; + writer.flush()?; + writer.write_index(&distinct_index)?; + writer.close()?; + + println!("Finished writing file to {}", path.display()); + Ok(()) +} + +/// Read a `DistinctIndex` from a Parquet file +fn read_distinct_index(path: &Path) -> Result { + let file = File::open(path)?; + + let file_size = file.metadata()?.len(); + println!("Reading index from {} (size: {file_size})", path.display(),); + + let reader = SerializedFileReader::new(file.try_clone()?)?; + let meta = reader.metadata().file_metadata(); + + let offset = get_key_value(meta, "distinct_index_offset") + .ok_or_else(|| ParquetError::General("Missing index offset".into()))? 
+ .parse::() + .map_err(|e| ParquetError::General(e.to_string()))?; + + println!("Reading index at offset: {offset}, length"); + DistinctIndex::new_from_reader(file, offset) +} + +/// Returns the value of a named key from the Parquet file metadata +/// +/// Returns None if the key is not found +fn get_key_value<'a>(file_meta_data: &'a FileMetaData, key: &'_ str) -> Option<&'a str> { + let kvs = file_meta_data.key_value_metadata()?; + let kv = kvs.iter().find(|kv| kv.key == key)?; + kv.value.as_deref() +} + +/// Implement TableProvider for DistinctIndexTable, using the distinct index to prune files +#[async_trait] +impl TableProvider for DistinctIndexTable { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + fn table_type(&self) -> TableType { + TableType::Base + } + + /// Prune files before reading: only keep files whose distinct set + /// contains the filter value + async fn scan( + &self, + _ctx: &dyn Session, + _proj: Option<&Vec>, + filters: &[Expr], + _limit: Option, + ) -> Result> { + // This example only handles filters of the form + // `category = 'X'` where X is a string literal + // + // You can use `PruningPredicate` for much more general range and + // equality analysis or write your own custom logic. + let mut target: Option<&str> = None; + + if filters.len() == 1 + && let Expr::BinaryExpr(expr) = &filters[0] + && expr.op == Operator::Eq + && let (Expr::Column(c), Expr::Literal(ScalarValue::Utf8(Some(v)), _)) = + (&*expr.left, &*expr.right) + && c.name == "category" + { + println!("Filtering for category: {v}"); + target = Some(v); + } + // Determine which files to scan + let files_to_scan: Vec<_> = self + .files_and_index + .iter() + .filter_map(|(f, distinct_index)| { + // keep file if no target or target is in the distinct set + if target.is_none() || distinct_index.contains(target?) 
{ + Some(f) + } else { + None + } + }) + .collect(); + + println!("Scanning only files: {files_to_scan:?}"); + + // Build ParquetSource to actually read the files + let url = ObjectStoreUrl::parse("file://")?; + let source = Arc::new( + ParquetSource::new(self.schema.clone()).with_enable_page_index(true), + ); + let mut builder = FileScanConfigBuilder::new(url, source); + for file in files_to_scan { + let path = self.dir.join(file); + let len = std::fs::metadata(&path)?.len(); + // If the index contained information about row groups or pages, + // you could also pass that information here to further prune + // the data read from the file. + let partitioned_file = + PartitionedFile::new(path.to_str().unwrap().to_string(), len); + builder = builder.with_file(partitioned_file); + } + Ok(DataSourceExec::from_data_source(builder.build())) + } + + /// Tell DataFusion that we can handle filters on the "category" column + fn supports_filters_pushdown( + &self, + fs: &[&Expr], + ) -> Result> { + // Mark as inexact since pruning is file‑granular + Ok(vec![TableProviderFilterPushDown::Inexact; fs.len()]) + } +} diff --git a/datafusion-examples/examples/data_io/parquet_encrypted.rs b/datafusion-examples/examples/data_io/parquet_encrypted.rs new file mode 100644 index 0000000000000..26361e9b52be0 --- /dev/null +++ b/datafusion-examples/examples/data_io/parquet_encrypted.rs @@ -0,0 +1,129 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! See `main.rs` for how to run it. + +use std::sync::Arc; + +use datafusion::common::DataFusionError; +use datafusion::config::{ConfigFileEncryptionProperties, TableParquetOptions}; +use datafusion::dataframe::{DataFrame, DataFrameWriteOptions}; +use datafusion::logical_expr::{col, lit}; +use datafusion::parquet::encryption::decrypt::FileDecryptionProperties; +use datafusion::parquet::encryption::encrypt::FileEncryptionProperties; +use datafusion::prelude::{ParquetReadOptions, SessionContext}; +use datafusion_examples::utils::{datasets::ExampleDataset, write_csv_to_parquet}; +use tempfile::TempDir; + +/// Read and write encrypted Parquet files using DataFusion +pub async fn parquet_encrypted() -> datafusion::common::Result<()> { + // The SessionContext is the main high level API for interacting with DataFusion + let ctx = SessionContext::new(); + + // Convert the CSV input into a temporary Parquet directory for querying + let dataset = ExampleDataset::Cars; + let parquet_temp = write_csv_to_parquet(&ctx, &dataset.path()).await?; + + // Read the sample parquet file + let parquet_df = ctx + .read_parquet(parquet_temp.path_str()?, ParquetReadOptions::default()) + .await?; + + // Show information from the dataframe + println!( + "===============================================================================" + ); + println!("Original Parquet DataFrame:"); + query_dataframe(&parquet_df).await?; + + // Setup encryption and decryption properties + let (encrypt, decrypt) = setup_encryption(&parquet_df)?; + + // Create a temporary file location for 
the encrypted parquet file + let tmp_source = TempDir::new()?; + let tempfile = tmp_source.path().join("cars_encrypted.parquet"); + + // Write encrypted parquet + let mut options = TableParquetOptions::default(); + options.crypto.file_encryption = Some(ConfigFileEncryptionProperties::from(&encrypt)); + parquet_df + .write_parquet( + tempfile.to_str().unwrap(), + DataFrameWriteOptions::new().with_single_file_output(true), + Some(options), + ) + .await?; + + // Read encrypted parquet back as a DataFrame using matching decryption config + let ctx: SessionContext = SessionContext::new(); + let read_options = + ParquetReadOptions::default().file_decryption_properties((&decrypt).into()); + + let encrypted_parquet_df = ctx + .read_parquet(tempfile.to_str().unwrap(), read_options) + .await?; + + // Show information from the dataframe + println!( + "\n\n===============================================================================" + ); + println!("Encrypted Parquet DataFrame:"); + query_dataframe(&encrypted_parquet_df).await?; + + Ok(()) +} + +// Show information from the dataframe +async fn query_dataframe(df: &DataFrame) -> Result<(), DataFusionError> { + // show its schema using 'describe' + println!("Schema:"); + df.clone().describe().await?.show().await?; + + // Select three columns and filter the results + // so that only rows where speed > 5 are returned + // select car, speed, time from t where speed > 5 + println!("\nSelected rows and columns:"); + df.clone() + .select_columns(&["car", "speed", "time"])? + .filter(col("speed").gt(lit(5)))? 
+ .show() + .await?; + + Ok(()) +} + +// Setup encryption and decryption properties +fn setup_encryption( + parquet_df: &DataFrame, +) -> Result<(Arc, Arc), DataFusionError> +{ + let schema = parquet_df.schema(); + let footer_key = b"0123456789012345".to_vec(); // 128bit/16 + let column_key = b"1234567890123450".to_vec(); // 128bit/16 + + let mut encrypt = FileEncryptionProperties::builder(footer_key.clone()); + let mut decrypt = FileDecryptionProperties::builder(footer_key.clone()); + + for field in schema.fields().iter() { + encrypt = encrypt.with_column_key(field.name().as_str(), column_key.clone()); + decrypt = decrypt.with_column_key(field.name().as_str(), column_key.clone()); + } + + let encrypt = encrypt.build()?; + let decrypt = decrypt.build()?; + Ok((encrypt, decrypt)) +} diff --git a/datafusion-examples/examples/data_io/parquet_encrypted_with_kms.rs b/datafusion-examples/examples/data_io/parquet_encrypted_with_kms.rs new file mode 100644 index 0000000000000..1a9bf56c09b35 --- /dev/null +++ b/datafusion-examples/examples/data_io/parquet_encrypted_with_kms.rs @@ -0,0 +1,304 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! See `main.rs` for how to run it. 
+ +use arrow::array::{ArrayRef, Int32Array, RecordBatch, StringArray}; +use arrow_schema::SchemaRef; +use async_trait::async_trait; +use base64::Engine; +use datafusion::common::extensions_options; +use datafusion::config::{EncryptionFactoryOptions, TableParquetOptions}; +use datafusion::dataframe::DataFrameWriteOptions; +use datafusion::datasource::file_format::parquet::ParquetFormat; +use datafusion::datasource::listing::ListingOptions; +use datafusion::error::Result; +use datafusion::execution::parquet_encryption::EncryptionFactory; +use datafusion::parquet::encryption::decrypt::KeyRetriever; +use datafusion::parquet::encryption::{ + decrypt::FileDecryptionProperties, encrypt::FileEncryptionProperties, +}; +use datafusion::prelude::SessionContext; +use futures::StreamExt; +use object_store::path::Path; +use rand::rand_core::{OsRng, TryRngCore}; +use std::collections::HashSet; +use std::sync::Arc; +use tempfile::TempDir; + +const ENCRYPTION_FACTORY_ID: &str = "example.mock_kms_encryption"; + +/// This example demonstrates reading and writing Parquet files that +/// are encrypted using Parquet Modular Encryption. +/// +/// Compared to the `parquet_encrypted` example, where AES keys +/// are specified directly, this example implements an `EncryptionFactory` that +/// generates encryption keys dynamically per file. +/// Encryption key metadata is stored inline in the Parquet files and is used to determine +/// the decryption keys when reading the files. +/// +/// In this example, encryption keys are simply stored base64 encoded in the Parquet metadata, +/// which is not a secure way to store encryption keys. +/// For production use, it is recommended to use a key-management service (KMS) to encrypt +/// data encryption keys. +pub async fn parquet_encrypted_with_kms() -> Result<()> { + let ctx = SessionContext::new(); + + // Register an `EncryptionFactory` implementation to be used for Parquet encryption + // in the runtime environment. 
+ // `EncryptionFactory` instances are registered with a name to identify them so + // they can be later referenced in configuration options, and it's possible to register + // multiple different factories to handle different ways of encrypting Parquet. + let encryption_factory = TestEncryptionFactory::default(); + ctx.runtime_env().register_parquet_encryption_factory( + ENCRYPTION_FACTORY_ID, + Arc::new(encryption_factory), + ); + + // Register some simple test data + let a: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d"])); + let b: ArrayRef = Arc::new(Int32Array::from(vec![1, 10, 10, 100])); + let c: ArrayRef = Arc::new(Int32Array::from(vec![2, 20, 20, 200])); + let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)])?; + ctx.register_batch("test_data", batch)?; + + { + // Write and read encrypted Parquet with the programmatic API + let tmpdir = TempDir::new()?; + let table_path = format!("{}/", tmpdir.path().to_str().unwrap()); + write_encrypted(&ctx, &table_path).await?; + read_encrypted(&ctx, &table_path).await?; + } + + { + // Write and read encrypted Parquet with the SQL API + let tmpdir = TempDir::new()?; + let table_path = format!("{}/", tmpdir.path().to_str().unwrap()); + write_encrypted_with_sql(&ctx, &table_path).await?; + read_encrypted_with_sql(&ctx, &table_path).await?; + } + + Ok(()) +} + +/// Write an encrypted Parquet file +async fn write_encrypted(ctx: &SessionContext, table_path: &str) -> Result<()> { + let df = ctx.table("test_data").await?; + + let mut parquet_options = TableParquetOptions::new(); + // We specify that we want to use Parquet encryption by setting the identifier of the + // encryption factory to use and providing the factory-specific configuration. + // Our encryption factory only requires specifying the columns to encrypt. 
+ let encryption_config = EncryptionConfig { + encrypted_columns: "b,c".to_owned(), + }; + parquet_options + .crypto + .configure_factory(ENCRYPTION_FACTORY_ID, &encryption_config); + + df.write_parquet( + table_path, + DataFrameWriteOptions::new(), + Some(parquet_options), + ) + .await?; + + println!("Encrypted Parquet written to {table_path}"); + Ok(()) +} + +/// Read from an encrypted Parquet file +async fn read_encrypted(ctx: &SessionContext, table_path: &str) -> Result<()> { + let mut parquet_options = TableParquetOptions::new(); + // Specify the encryption factory to use for decrypting Parquet. + // In this example, we don't require any additional configuration options when reading + // as we only need the key metadata from the Parquet files to determine the decryption keys. + parquet_options + .crypto + .configure_factory(ENCRYPTION_FACTORY_ID, &EncryptionConfig::default()); + + let file_format = ParquetFormat::default().with_options(parquet_options); + let listing_options = ListingOptions::new(Arc::new(file_format)); + + ctx.register_listing_table( + "encrypted_parquet_table", + &table_path, + listing_options.clone(), + None, + None, + ) + .await?; + + let mut batch_stream = ctx + .table("encrypted_parquet_table") + .await? 
+ .execute_stream() + .await?; + println!("Reading encrypted Parquet as a RecordBatch stream"); + while let Some(batch) = batch_stream.next().await { + let batch = batch?; + println!("Read batch with {} rows", batch.num_rows()); + } + + println!("Finished reading"); + Ok(()) +} + +/// Write an encrypted Parquet file using only SQL syntax with string configuration +async fn write_encrypted_with_sql(ctx: &SessionContext, table_path: &str) -> Result<()> { + let query = format!( + "COPY test_data \ + TO '{table_path}' \ + STORED AS parquet + OPTIONS (\ + 'format.crypto.factory_id' '{ENCRYPTION_FACTORY_ID}', \ + 'format.crypto.factory_options.encrypted_columns' 'b,c' \ + )" + ); + let _ = ctx.sql(&query).await?.collect().await?; + + println!("Encrypted Parquet written to {table_path}"); + Ok(()) +} + +/// Read from an encrypted Parquet file using only the SQL API and string-based configuration +async fn read_encrypted_with_sql(ctx: &SessionContext, table_path: &str) -> Result<()> { + let ddl = format!( + "CREATE EXTERNAL TABLE encrypted_parquet_table_2 \ + STORED AS PARQUET LOCATION '{table_path}' OPTIONS (\ + 'format.crypto.factory_id' '{ENCRYPTION_FACTORY_ID}' \ + )" + ); + ctx.sql(&ddl).await?; + let df = ctx.sql("SELECT * FROM encrypted_parquet_table_2").await?; + let mut batch_stream = df.execute_stream().await?; + + println!("Reading encrypted Parquet as a RecordBatch stream"); + while let Some(batch) = batch_stream.next().await { + let batch = batch?; + println!("Read batch with {} rows", batch.num_rows()); + } + println!("Finished reading"); + Ok(()) +} + +// Options used to configure our example encryption factory +extensions_options! { + struct EncryptionConfig { + /// Comma-separated list of columns to encrypt + pub encrypted_columns: String, default = "".to_owned() + } +} + +/// Mock implementation of an `EncryptionFactory` that stores encryption keys +/// base64 encoded in the Parquet encryption metadata. 
+/// For production use, integrating with a key-management service to encrypt +/// data encryption keys is recommended. +#[derive(Default, Debug)] +struct TestEncryptionFactory {} + +/// `EncryptionFactory` is a DataFusion trait for types that generate +/// file encryption and decryption properties. +#[async_trait] +impl EncryptionFactory for TestEncryptionFactory { + /// Generate file encryption properties to use when writing a Parquet file. + /// The `schema` is provided so that it may be used to dynamically configure + /// per-column encryption keys. + /// The file path is also available. We don't use the path in this example, + /// but other implementations may want to use this to compute an + /// AAD prefix for the file, or to allow use of external key material + /// (where key metadata is stored in a JSON file alongside Parquet files). + async fn get_file_encryption_properties( + &self, + options: &EncryptionFactoryOptions, + schema: &SchemaRef, + _file_path: &Path, + ) -> Result>> { + let config: EncryptionConfig = options.to_extension_options()?; + + // Generate a random encryption key for this file. + let mut key = vec![0u8; 16]; + OsRng.try_fill_bytes(&mut key).unwrap(); + + // Generate the key metadata that allows retrieving the key when reading the file. + let key_metadata = wrap_key(&key); + + let mut builder = FileEncryptionProperties::builder(key.to_vec()) + .with_footer_key_metadata(key_metadata.clone()); + + let encrypted_columns: HashSet<&str> = + config.encrypted_columns.split(",").collect(); + if !encrypted_columns.is_empty() { + // Set up per-column encryption. + for field in schema.fields().iter() { + if encrypted_columns.contains(field.name().as_str()) { + // Here we re-use the same key for all encrypted columns, + // but new keys could also be generated per column. 
+ builder = builder.with_column_key_and_metadata( + field.name().as_str(), + key.clone(), + key_metadata.clone(), + ); + } + } + } + + let encryption_properties = builder.build()?; + + Ok(Some(encryption_properties)) + } + + /// Generate file decryption properties to use when reading a Parquet file. + /// Rather than provide the AES keys directly for decryption, we set a `KeyRetriever` + /// that can determine the keys using the encryption metadata. + async fn get_file_decryption_properties( + &self, + _options: &EncryptionFactoryOptions, + _file_path: &Path, + ) -> Result>> { + let decryption_properties = + FileDecryptionProperties::with_key_retriever(Arc::new(TestKeyRetriever {})) + .build()?; + Ok(Some(decryption_properties)) + } +} + +/// Mock implementation of encrypting a key that simply base64 encodes the key. +/// Note that this is not a secure way to store encryption keys, +/// and for production use keys should be encrypted with a KMS. +fn wrap_key(key: &[u8]) -> Vec { + base64::prelude::BASE64_STANDARD + .encode(key) + .as_bytes() + .to_vec() +} + +struct TestKeyRetriever {} + +impl KeyRetriever for TestKeyRetriever { + /// Get a data encryption key using the metadata stored in the Parquet file. 
+ fn retrieve_key( + &self, + key_metadata: &[u8], + ) -> datafusion::parquet::errors::Result> { + let key_metadata = std::str::from_utf8(key_metadata)?; + let key = base64::prelude::BASE64_STANDARD + .decode(key_metadata) + .unwrap(); + Ok(key) + } +} diff --git a/datafusion-examples/examples/parquet_exec_visitor.rs b/datafusion-examples/examples/data_io/parquet_exec_visitor.rs similarity index 73% rename from datafusion-examples/examples/parquet_exec_visitor.rs rename to datafusion-examples/examples/data_io/parquet_exec_visitor.rs index 84f92d4f450e1..47caf9480df93 100644 --- a/datafusion-examples/examples/parquet_exec_visitor.rs +++ b/datafusion-examples/examples/data_io/parquet_exec_visitor.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. + use std::sync::Arc; use datafusion::datasource::file_format::parquet::ParquetFormat; @@ -25,34 +27,37 @@ use datafusion::error::DataFusionError; use datafusion::execution::context::SessionContext; use datafusion::physical_plan::metrics::MetricValue; use datafusion::physical_plan::{ - execute_stream, visit_execution_plan, ExecutionPlan, ExecutionPlanVisitor, + ExecutionPlan, ExecutionPlanVisitor, execute_stream, visit_execution_plan, }; +use datafusion_examples::utils::{datasets::ExampleDataset, write_csv_to_parquet}; use futures::StreamExt; /// Example of collecting metrics after execution by visiting the `ExecutionPlan` -#[tokio::main] -async fn main() { +pub async fn parquet_exec_visitor() -> datafusion::common::Result<()> { let ctx = SessionContext::new(); - let test_data = datafusion::test_util::parquet_test_data(); + // Convert the CSV input into a temporary Parquet directory for querying + let dataset = ExampleDataset::Cars; + let parquet_temp = write_csv_to_parquet(&ctx, &dataset.path()).await?; // Configure listing options let file_format = ParquetFormat::default().with_enable_pruning(true); let listing_options = 
ListingOptions::new(Arc::new(file_format)); + let table_path = parquet_temp.file_uri()?; + // First example were we use an absolute path, which requires no additional setup. - let _ = ctx - .register_listing_table( - "my_table", - &format!("file://{test_data}/alltypes_plain.parquet"), - listing_options.clone(), - None, - None, - ) - .await; - - let df = ctx.sql("SELECT * FROM my_table").await.unwrap(); - let plan = df.create_physical_plan().await.unwrap(); + ctx.register_listing_table( + "my_table", + &table_path, + listing_options.clone(), + None, + None, + ) + .await?; + + let df = ctx.sql("SELECT * FROM my_table").await?; + let plan = df.create_physical_plan().await?; // Create empty visitor let mut visitor = ParquetExecVisitor { @@ -63,12 +68,12 @@ async fn main() { // Make sure you execute the plan to collect actual execution statistics. // For example, in this example the `file_scan_config` is known without executing // but the `bytes_scanned` would be None if we did not execute. - let mut batch_stream = execute_stream(plan.clone(), ctx.task_ctx()).unwrap(); + let mut batch_stream = execute_stream(plan.clone(), ctx.task_ctx())?; while let Some(batch) = batch_stream.next().await { println!("Batch rows: {}", batch.unwrap().num_rows()); } - visit_execution_plan(plan.as_ref(), &mut visitor).unwrap(); + visit_execution_plan(plan.as_ref(), &mut visitor)?; println!( "ParquetExecVisitor bytes_scanned: {:?}", @@ -78,6 +83,8 @@ async fn main() { "ParquetExecVisitor file_groups: {:?}", visitor.file_groups.unwrap() ); + + Ok(()) } /// Define a struct with fields to hold the execution information you want to @@ -97,18 +104,17 @@ impl ExecutionPlanVisitor for ParquetExecVisitor { /// or `post_visit` (visit each node after its children/inputs) fn pre_visit(&mut self, plan: &dyn ExecutionPlan) -> Result { // If needed match on a specific `ExecutionPlan` node type - if let Some(data_source_exec) = plan.as_any().downcast_ref::() { - if let Some((file_config, _)) = + if let 
Some(data_source_exec) = plan.as_any().downcast_ref::() + && let Some((file_config, _)) = data_source_exec.downcast_to_file_source::() - { - self.file_groups = Some(file_config.file_groups.clone()); - - let metrics = match data_source_exec.metrics() { - None => return Ok(true), - Some(metrics) => metrics, - }; - self.bytes_scanned = metrics.sum_by_name("bytes_scanned"); - } + { + self.file_groups = Some(file_config.file_groups.clone()); + + let metrics = match data_source_exec.metrics() { + None => return Ok(true), + Some(metrics) => metrics, + }; + self.bytes_scanned = metrics.sum_by_name("bytes_scanned"); } Ok(true) } diff --git a/datafusion-examples/examples/parquet_index.rs b/datafusion-examples/examples/data_io/parquet_index.rs similarity index 97% rename from datafusion-examples/examples/parquet_index.rs rename to datafusion-examples/examples/data_io/parquet_index.rs index e5ae3cc86bfe5..e11a303f442a4 100644 --- a/datafusion-examples/examples/parquet_index.rs +++ b/datafusion-examples/examples/data_io/parquet_index.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. 
+ use arrow::array::{ Array, ArrayRef, AsArray, BooleanArray, Int32Array, RecordBatch, StringArray, UInt64Array, @@ -25,19 +27,19 @@ use async_trait::async_trait; use datafusion::catalog::Session; use datafusion::common::pruning::PruningStatistics; use datafusion::common::{ - internal_datafusion_err, DFSchema, DataFusionError, Result, ScalarValue, + DFSchema, DataFusionError, Result, ScalarValue, internal_datafusion_err, }; +use datafusion::datasource::TableProvider; use datafusion::datasource::listing::PartitionedFile; use datafusion::datasource::memory::DataSourceExec; use datafusion::datasource::physical_plan::{FileScanConfigBuilder, ParquetSource}; -use datafusion::datasource::TableProvider; use datafusion::execution::object_store::ObjectStoreUrl; use datafusion::logical_expr::{ - utils::conjunction, TableProviderFilterPushDown, TableType, + TableProviderFilterPushDown, TableType, utils::conjunction, }; use datafusion::parquet::arrow::arrow_reader::statistics::StatisticsConverter; use datafusion::parquet::arrow::{ - arrow_reader::ParquetRecordBatchReaderBuilder, ArrowWriter, + ArrowWriter, arrow_reader::ParquetRecordBatchReaderBuilder, }; use datafusion::physical_expr::PhysicalExpr; use datafusion::physical_optimizer::pruning::PruningPredicate; @@ -50,8 +52,8 @@ use std::fs; use std::fs::{DirEntry, File}; use std::ops::Range; use std::path::{Path, PathBuf}; -use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; use tempfile::TempDir; use url::Url; @@ -71,7 +73,7 @@ use url::Url; /// (using the same underlying APIs) /// /// For a more advanced example of using an index to prune row groups within a -/// file, see the (forthcoming) `advanced_parquet_index` example. +/// file, see the `advanced_parquet_index` example. 
/// /// # Diagram /// @@ -99,12 +101,10 @@ use url::Url; /// Thus some parquet files are │ │ /// "pruned" and thus are not └─────────────┘ /// scanned at all Parquet Files -/// /// ``` /// /// [`ListingTable`]: datafusion::datasource::listing::ListingTable -#[tokio::main] -async fn main() -> Result<()> { +pub async fn parquet_index() -> Result<()> { // Demo data has three files, each with schema // * file_name (string) // * value (int32) @@ -243,10 +243,11 @@ impl TableProvider for IndexTableProvider { let files = self.index.get_files(predicate.clone())?; let object_store_url = ObjectStoreUrl::parse("file://")?; - let source = Arc::new(ParquetSource::default().with_predicate(predicate)); + let source = + Arc::new(ParquetSource::new(self.schema()).with_predicate(predicate)); let mut file_scan_config_builder = - FileScanConfigBuilder::new(object_store_url, self.schema(), source) - .with_projection(projection.cloned()) + FileScanConfigBuilder::new(object_store_url, source) + .with_projection_indices(projection.cloned())? .with_limit(limit); // Transform to the format needed to pass to DataSourceExec @@ -313,7 +314,7 @@ impl Display for ParquetMetadataIndex { "ParquetMetadataIndex(last_num_pruned: {})", self.last_num_pruned() )?; - let batches = pretty_format_batches(&[self.index.clone()]).unwrap(); + let batches = pretty_format_batches(std::slice::from_ref(&self.index)).unwrap(); write!(f, "{batches}",) } } @@ -510,7 +511,7 @@ impl ParquetMetadataIndexBuilder { // Get the schema of the file. A real system might have to handle the // case where the schema of the file is not the same as the schema of - // the other files e.g. using SchemaAdapter. + // the other files e.g. using PhysicalExprAdapterFactory. 
if self.file_schema.is_none() { self.file_schema = Some(reader.schema().clone()); } diff --git a/datafusion-examples/examples/query-http-csv.rs b/datafusion-examples/examples/data_io/query_http_csv.rs similarity index 91% rename from datafusion-examples/examples/query-http-csv.rs rename to datafusion-examples/examples/data_io/query_http_csv.rs index fa3fd2ac068df..71421e6270ccb 100644 --- a/datafusion-examples/examples/query-http-csv.rs +++ b/datafusion-examples/examples/data_io/query_http_csv.rs @@ -15,16 +15,16 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. + use datafusion::error::Result; use datafusion::prelude::*; use object_store::http::HttpBuilder; use std::sync::Arc; use url::Url; -/// This example demonstrates executing a simple query against an Arrow data source (CSV) and -/// fetching results -#[tokio::main] -async fn main() -> Result<()> { +/// Configure `object_store` and run a query against files via HTTP +pub async fn query_http_csv() -> Result<()> { // create local execution context let ctx = SessionContext::new(); diff --git a/datafusion-examples/examples/remote_catalog.rs b/datafusion-examples/examples/data_io/remote_catalog.rs similarity index 97% rename from datafusion-examples/examples/remote_catalog.rs rename to datafusion-examples/examples/data_io/remote_catalog.rs index 70c0963545e08..10ec26b1d5c05 100644 --- a/datafusion-examples/examples/remote_catalog.rs +++ b/datafusion-examples/examples/data_io/remote_catalog.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. +//! /// This example shows how to implement the DataFusion [`CatalogProvider`] API /// for catalogs that are remote (require network access) and/or offer only /// asynchronous APIs such as [Polaris], [Unity], and [Hive]. 
@@ -39,15 +41,15 @@ use datafusion::common::{assert_batches_eq, internal_datafusion_err, plan_err}; use datafusion::datasource::memory::MemorySourceConfig; use datafusion::execution::SendableRecordBatchStream; use datafusion::logical_expr::{Expr, TableType}; -use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion::physical_plan::ExecutionPlan; +use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion::prelude::{DataFrame, SessionContext}; use futures::TryStreamExt; use std::any::Any; use std::sync::Arc; -#[tokio::main] -async fn main() -> Result<()> { +/// Interfacing with a remote catalog (e.g. over a network) +pub async fn remote_catalog() -> Result<()> { // As always, we create a session context to interact with DataFusion let ctx = SessionContext::new(); @@ -75,8 +77,8 @@ async fn main() -> Result<()> { let state = ctx.state(); // First, parse the SQL (but don't plan it / resolve any table references) - let dialect = state.config().options().sql_parser.dialect.as_str(); - let statement = state.sql_to_statement(sql, dialect)?; + let dialect = state.config().options().sql_parser.dialect; + let statement = state.sql_to_statement(sql, &dialect)?; // Find all `TableReferences` in the parsed queries. These correspond to the // tables referred to by the query (in this case diff --git a/datafusion-examples/examples/dataframe/cache_factory.rs b/datafusion-examples/examples/dataframe/cache_factory.rs new file mode 100644 index 0000000000000..a92c3dc4ce26a --- /dev/null +++ b/datafusion-examples/examples/dataframe/cache_factory.rs @@ -0,0 +1,229 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! See `main.rs` for how to run it. + +use std::fmt::Debug; +use std::hash::Hash; +use std::sync::{Arc, RwLock}; + +use arrow::array::RecordBatch; +use async_trait::async_trait; +use datafusion::catalog::memory::MemorySourceConfig; +use datafusion::common::DFSchemaRef; +use datafusion::error::Result; +use datafusion::execution::context::QueryPlanner; +use datafusion::execution::session_state::CacheFactory; +use datafusion::execution::{SessionState, SessionStateBuilder}; +use datafusion::logical_expr::{ + Extension, LogicalPlan, UserDefinedLogicalNode, UserDefinedLogicalNodeCore, +}; +use datafusion::physical_plan::{ExecutionPlan, collect_partitioned}; +use datafusion::physical_planner::{ + DefaultPhysicalPlanner, ExtensionPlanner, PhysicalPlanner, +}; +use datafusion::prelude::*; +use datafusion_common::HashMap; +use datafusion_examples::utils::{datasets::ExampleDataset, write_csv_to_parquet}; + +/// This example demonstrates how to leverage [CacheFactory] to implement custom caching strategies for dataframes in DataFusion. +/// By default, [DataFrame::cache] in Datafusion is eager and creates an in-memory table. This example shows a basic alternative implementation for lazy caching. +/// Specifically, it implements: +/// - A [CustomCacheFactory] that creates a logical node [CacheNode] representing the cache operation. 
+/// - A [CacheNodePlanner] (an [ExtensionPlanner]) that understands [CacheNode] and performs caching. +/// - A [CacheNodeQueryPlanner] that installs [CacheNodePlanner]. +/// - A simple in-memory [CacheManager] that stores cached [RecordBatch]es. Note that the implementation for this example is very naive and only implements put, but for real production use cases cache eviction and drop should also be implemented. +pub async fn cache_dataframe_with_custom_logic() -> Result<()> { + let session_state = SessionStateBuilder::new() + .with_cache_factory(Some(Arc::new(CustomCacheFactory {}))) + .with_query_planner(Arc::new(CacheNodeQueryPlanner::default())) + .build(); + let ctx = SessionContext::new_with_state(session_state); + + // Convert the CSV input into a temporary Parquet directory for querying + let dataset = ExampleDataset::Cars; + let parquet_temp = write_csv_to_parquet(&ctx, &dataset.path()).await?; + + // Read the parquet files and show its schema using 'describe' + let parquet_df = ctx + .read_parquet(parquet_temp.path_str()?, ParquetReadOptions::default()) + .await?; + + let df_cached = parquet_df + .select_columns(&["car", "speed", "time"])? + .filter(col("speed").gt(lit(1.0)))? 
+ .cache() + .await?; + + let df1 = df_cached.clone().filter(col("car").eq(lit("red")))?; + let df2 = df1.clone().sort(vec![col("car").sort(true, false)])?; + + // should see log for caching only once + df_cached.show().await?; + df1.show().await?; + df2.show().await?; + + Ok(()) +} + +#[derive(Debug)] +struct CustomCacheFactory {} + +impl CacheFactory for CustomCacheFactory { + fn create( + &self, + plan: LogicalPlan, + _session_state: &SessionState, + ) -> Result { + Ok(LogicalPlan::Extension(Extension { + node: Arc::new(CacheNode { input: plan }), + })) + } +} + +#[derive(PartialEq, Eq, PartialOrd, Hash, Debug)] +struct CacheNode { + input: LogicalPlan, +} + +impl UserDefinedLogicalNodeCore for CacheNode { + fn name(&self) -> &str { + "CacheNode" + } + + fn inputs(&self) -> Vec<&LogicalPlan> { + vec![&self.input] + } + + fn schema(&self) -> &DFSchemaRef { + self.input.schema() + } + + fn expressions(&self) -> Vec { + vec![] + } + + fn fmt_for_explain(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "CacheNode") + } + + fn with_exprs_and_inputs( + &self, + _exprs: Vec, + mut inputs: Vec, + ) -> Result { + assert_eq!(inputs.len(), 1, "input size must be one"); + Ok(Self { + input: inputs.swap_remove(0), + }) + } +} + +struct CacheNodePlanner { + cache_manager: Arc>, +} + +#[async_trait] +impl ExtensionPlanner for CacheNodePlanner { + async fn plan_extension( + &self, + _planner: &dyn PhysicalPlanner, + node: &dyn UserDefinedLogicalNode, + logical_inputs: &[&LogicalPlan], + physical_inputs: &[Arc], + session_state: &SessionState, + ) -> Result>> { + if let Some(cache_node) = node.as_any().downcast_ref::() { + assert_eq!(logical_inputs.len(), 1, "Inconsistent number of inputs"); + assert_eq!(physical_inputs.len(), 1, "Inconsistent number of inputs"); + if self + .cache_manager + .read() + .unwrap() + .get(&cache_node.input) + .is_none() + { + let ctx = session_state.task_ctx(); + println!("caching in memory"); + let batches = + 
collect_partitioned(physical_inputs[0].clone(), ctx).await?; + self.cache_manager + .write() + .unwrap() + .put(cache_node.input.clone(), batches); + } else { + println!("fetching directly from cache manager"); + } + Ok(self + .cache_manager + .read() + .unwrap() + .get(&cache_node.input) + .map(|batches| { + let exec: Arc = MemorySourceConfig::try_new_exec( + batches, + physical_inputs[0].schema(), + None, + ) + .unwrap(); + exec + })) + } else { + Ok(None) + } + } +} + +#[derive(Debug, Default)] +struct CacheNodeQueryPlanner { + cache_manager: Arc>, +} + +#[async_trait] +impl QueryPlanner for CacheNodeQueryPlanner { + async fn create_physical_plan( + &self, + logical_plan: &LogicalPlan, + session_state: &SessionState, + ) -> Result> { + let physical_planner = + DefaultPhysicalPlanner::with_extension_planners(vec![Arc::new( + CacheNodePlanner { + cache_manager: Arc::clone(&self.cache_manager), + }, + )]); + physical_planner + .create_physical_plan(logical_plan, session_state) + .await + } +} + +// This naive implementation only includes put, but for real production use cases cache eviction and drop should also be implemented. +#[derive(Debug, Default)] +struct CacheManager { + cache: HashMap>>, +} + +impl CacheManager { + pub fn put(&mut self, k: LogicalPlan, v: Vec>) { + self.cache.insert(k, v); + } + + pub fn get(&self, k: &LogicalPlan) -> Option<&Vec>> { + self.cache.get(k) + } +} diff --git a/datafusion-examples/examples/dataframe.rs b/datafusion-examples/examples/dataframe/dataframe.rs similarity index 73% rename from datafusion-examples/examples/dataframe.rs rename to datafusion-examples/examples/dataframe/dataframe.rs index 57a28aeca0de2..dde19cb476f14 100644 --- a/datafusion-examples/examples/dataframe.rs +++ b/datafusion-examples/examples/dataframe/dataframe.rs @@ -15,22 +15,26 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. 
+ +use std::fs::File; +use std::io::Write; +use std::sync::Arc; + use arrow::array::{ArrayRef, Int32Array, RecordBatch, StringArray, StringViewArray}; use datafusion::arrow::datatypes::{DataType, Field, Schema}; use datafusion::catalog::MemTable; +use datafusion::common::ScalarValue; use datafusion::common::config::CsvOptions; use datafusion::common::parsers::CompressionTypeVariant; -use datafusion::common::DataFusionError; -use datafusion::common::ScalarValue; use datafusion::dataframe::DataFrameWriteOptions; use datafusion::error::Result; use datafusion::functions_aggregate::average::avg; use datafusion::functions_aggregate::min_max::max; use datafusion::prelude::*; -use std::fs::File; -use std::io::Write; -use std::sync::Arc; -use tempfile::tempdir; +use datafusion_examples::utils::{datasets::ExampleDataset, write_csv_to_parquet}; +use tempfile::{TempDir, tempdir}; +use tokio::fs::create_dir_all; /// This example demonstrates using DataFusion's DataFrame API /// @@ -39,6 +43,7 @@ use tempfile::tempdir; /// * [read_parquet]: execute queries against parquet files /// * [read_csv]: execute queries against csv files /// * [read_memory]: execute queries against in-memory arrow data +/// * [read_memory_macro]: execute queries against in-memory arrow data using macro /// /// # Writing out to local storage /// @@ -53,12 +58,8 @@ use tempfile::tempdir; /// * [where_scalar_subquery]: execute a scalar subquery /// * [where_in_subquery]: execute a subquery with an IN clause /// * [where_exist_subquery]: execute a subquery with an EXISTS clause -/// -/// # Querying data -/// -/// * [query_to_date]: execute queries against parquet files -#[tokio::main] -async fn main() -> Result<()> { +pub async fn dataframe_example() -> Result<()> { + env_logger::init(); // The SessionContext is the main high level API for interacting with DataFusion let ctx = SessionContext::new(); read_parquet(&ctx).await?; @@ -66,8 +67,8 @@ async fn main() -> Result<()> { read_memory(&ctx).await?; 
read_memory_macro().await?; write_out(&ctx).await?; - register_aggregate_test_data("t1", &ctx).await?; - register_aggregate_test_data("t2", &ctx).await?; + register_cars_test_data("t1", &ctx).await?; + register_cars_test_data("t2", &ctx).await?; where_scalar_subquery(&ctx).await?; where_in_subquery(&ctx).await?; where_exist_subquery(&ctx).await?; @@ -79,23 +80,24 @@ async fn main() -> Result<()> { /// 2. Show the schema /// 3. Select columns and rows async fn read_parquet(ctx: &SessionContext) -> Result<()> { - // Find the local path of "alltypes_plain.parquet" - let testdata = datafusion::test_util::parquet_test_data(); - let filename = &format!("{testdata}/alltypes_plain.parquet"); + // Convert the CSV input into a temporary Parquet directory for querying + let dataset = ExampleDataset::Cars; + let parquet_temp = write_csv_to_parquet(ctx, &dataset.path()).await?; // Read the parquet files and show its schema using 'describe' let parquet_df = ctx - .read_parquet(filename, ParquetReadOptions::default()) + .read_parquet(parquet_temp.path_str()?, ParquetReadOptions::default()) .await?; // show its schema using 'describe' parquet_df.clone().describe().await?.show().await?; // Select three columns and filter the results - // so that only rows where id > 1 are returned + // so that only rows where speed > 1 are returned + // select car, speed, time from t where speed > 1 parquet_df - .select_columns(&["id", "bool_col", "timestamp_col"])? - .filter(col("id").gt(lit(1)))? + .select_columns(&["car", "speed", "time"])? + .filter(col("speed").gt(lit(1)))? .show() .await?; @@ -198,7 +200,7 @@ async fn read_memory_macro() -> Result<()> { /// 2. Write out a DataFrame to a parquet file /// 3. Write out a DataFrame to a csv file /// 4. 
Write out a DataFrame to a json file -async fn write_out(ctx: &SessionContext) -> std::result::Result<(), DataFusionError> { +async fn write_out(ctx: &SessionContext) -> Result<()> { let array = StringViewArray::from(vec!["a", "b", "c"]); let schema = Arc::new(Schema::new(vec![Field::new( "tablecol1", @@ -210,15 +212,26 @@ async fn write_out(ctx: &SessionContext) -> std::result::Result<(), DataFusionEr ctx.register_table("initial_data", Arc::new(mem_table))?; let df = ctx.table("initial_data").await?; - ctx.sql( - "create external table - test(tablecol1 varchar) - stored as parquet - location './datafusion-examples/test_table/'", - ) - .await? - .collect() - .await?; + // Create a single temp root with subdirectories + let tmp_root = TempDir::new()?; + let examples_root = tmp_root.path().join("datafusion-examples"); + create_dir_all(&examples_root).await?; + let table_dir = examples_root.join("test_table"); + let parquet_dir = examples_root.join("test_parquet"); + let csv_dir = examples_root.join("test_csv"); + let json_dir = examples_root.join("test_json"); + create_dir_all(&table_dir).await?; + create_dir_all(&parquet_dir).await?; + create_dir_all(&csv_dir).await?; + create_dir_all(&json_dir).await?; + + let create_sql = format!( + "CREATE EXTERNAL TABLE test(tablecol1 varchar) + STORED AS parquet + LOCATION '{}'", + table_dir.display() + ); + ctx.sql(&create_sql).await?.collect().await?; // This is equivalent to INSERT INTO test VALUES ('a'), ('b'), ('c'). 
// The behavior of write_table depends on the TableProvider's implementation @@ -229,7 +242,7 @@ async fn write_out(ctx: &SessionContext) -> std::result::Result<(), DataFusionEr df.clone() .write_parquet( - "./datafusion-examples/test_parquet/", + parquet_dir.to_str().unwrap(), DataFrameWriteOptions::new(), None, ) @@ -237,7 +250,7 @@ async fn write_out(ctx: &SessionContext) -> std::result::Result<(), DataFusionEr df.clone() .write_csv( - "./datafusion-examples/test_csv/", + csv_dir.to_str().unwrap(), // DataFrameWriteOptions contains options which control how data is written // such as compression codec DataFrameWriteOptions::new(), @@ -247,7 +260,7 @@ async fn write_out(ctx: &SessionContext) -> std::result::Result<(), DataFusionEr df.clone() .write_json( - "./datafusion-examples/test_json/", + json_dir.to_str().unwrap(), DataFrameWriteOptions::new(), None, ) @@ -257,7 +270,7 @@ async fn write_out(ctx: &SessionContext) -> std::result::Result<(), DataFusionEr } /// Use the DataFrame API to execute the following subquery: -/// select c1,c2 from t1 where (select avg(t2.c2) from t2 where t1.c1 = t2.c1)>0 limit 3; +/// select car, speed from t1 where (select avg(t2.speed) from t2 where t1.car = t2.car) > 0 limit 3; async fn where_scalar_subquery(ctx: &SessionContext) -> Result<()> { ctx.table("t1") .await? @@ -265,14 +278,14 @@ async fn where_scalar_subquery(ctx: &SessionContext) -> Result<()> { scalar_subquery(Arc::new( ctx.table("t2") .await? - .filter(out_ref_col(DataType::Utf8, "t1.c1").eq(col("t2.c1")))? - .aggregate(vec![], vec![avg(col("t2.c2"))])? - .select(vec![avg(col("t2.c2"))])? + .filter(out_ref_col(DataType::Utf8, "t1.car").eq(col("t2.car")))? + .aggregate(vec![], vec![avg(col("t2.speed"))])? + .select(vec![avg(col("t2.speed"))])? .into_unoptimized_plan(), )) - .gt(lit(0u8)), + .gt(lit(0.0)), )? - .select(vec![col("t1.c1"), col("t1.c2")])? + .select(vec![col("t1.car"), col("t1.speed")])? .limit(0, Some(3))? 
.show() .await?; @@ -280,22 +293,24 @@ async fn where_scalar_subquery(ctx: &SessionContext) -> Result<()> { } /// Use the DataFrame API to execute the following subquery: -/// select t1.c1, t1.c2 from t1 where t1.c2 in (select max(t2.c2) from t2 where t2.c1 > 0 ) limit 3; +/// select t1.car, t1.speed from t1 where t1.speed in (select max(t2.speed) from t2 where t2.car = 'red') limit 3; async fn where_in_subquery(ctx: &SessionContext) -> Result<()> { ctx.table("t1") .await? .filter(in_subquery( - col("t1.c2"), + col("t1.speed"), Arc::new( ctx.table("t2") .await? - .filter(col("t2.c1").gt(lit(ScalarValue::UInt8(Some(0)))))? - .aggregate(vec![], vec![max(col("t2.c2"))])? - .select(vec![max(col("t2.c2"))])? + .filter( + col("t2.car").eq(lit(ScalarValue::Utf8(Some("red".to_string())))), + )? + .aggregate(vec![], vec![max(col("t2.speed"))])? + .select(vec![max(col("t2.speed"))])? .into_unoptimized_plan(), ), ))? - .select(vec![col("t1.c1"), col("t1.c2")])? + .select(vec![col("t1.car"), col("t1.speed")])? .limit(0, Some(3))? .show() .await?; @@ -303,31 +318,27 @@ async fn where_in_subquery(ctx: &SessionContext) -> Result<()> { } /// Use the DataFrame API to execute the following subquery: -/// select t1.c1, t1.c2 from t1 where exists (select t2.c2 from t2 where t1.c1 = t2.c1) limit 3; +/// select t1.car, t1.speed from t1 where exists (select t2.speed from t2 where t1.car = t2.car) limit 3; async fn where_exist_subquery(ctx: &SessionContext) -> Result<()> { ctx.table("t1") .await? .filter(exists(Arc::new( ctx.table("t2") .await? - .filter(out_ref_col(DataType::Utf8, "t1.c1").eq(col("t2.c1")))? - .select(vec![col("t2.c2")])? + .filter(out_ref_col(DataType::Utf8, "t1.car").eq(col("t2.car")))? + .select(vec![col("t2.speed")])? .into_unoptimized_plan(), )))? - .select(vec![col("t1.c1"), col("t1.c2")])? + .select(vec![col("t1.car"), col("t1.speed")])? .limit(0, Some(3))? 
.show() .await?; Ok(()) } -async fn register_aggregate_test_data(name: &str, ctx: &SessionContext) -> Result<()> { - let testdata = datafusion::test_util::arrow_test_data(); - ctx.register_csv( - name, - &format!("{testdata}/csv/aggregate_test_100.csv"), - CsvReadOptions::default(), - ) - .await?; +async fn register_cars_test_data(name: &str, ctx: &SessionContext) -> Result<()> { + let dataset = ExampleDataset::Cars; + ctx.register_csv(name, dataset.path_str()?, CsvReadOptions::default()) + .await?; Ok(()) } diff --git a/datafusion-examples/examples/dataframe/deserialize_to_struct.rs b/datafusion-examples/examples/dataframe/deserialize_to_struct.rs new file mode 100644 index 0000000000000..b031225dc9b69 --- /dev/null +++ b/datafusion-examples/examples/dataframe/deserialize_to_struct.rs @@ -0,0 +1,366 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! See `main.rs` for how to run it. 
+ +use arrow::array::{Array, Float64Array, StringViewArray}; +use datafusion::common::assert_batches_eq; +use datafusion::error::Result; +use datafusion::prelude::*; +use datafusion_examples::utils::{datasets::ExampleDataset, write_csv_to_parquet}; +use futures::StreamExt; + +/// This example shows how to convert query results into Rust structs by using +/// the Arrow APIs to convert the results into Rust native types. +/// +/// This is a bit tricky initially as the results are returned as columns stored +/// as [ArrayRef] +/// +/// [ArrayRef]: arrow::array::ArrayRef +pub async fn deserialize_to_struct() -> Result<()> { + // Run a query that returns two columns of data + let ctx = SessionContext::new(); + + // Convert the CSV input into a temporary Parquet directory for querying + let dataset = ExampleDataset::Cars; + let parquet_temp = write_csv_to_parquet(&ctx, &dataset.path()).await?; + + ctx.register_parquet( + "cars", + parquet_temp.path_str()?, + ParquetReadOptions::default(), + ) + .await?; + + let df = ctx + .sql("SELECT car, speed FROM cars ORDER BY speed LIMIT 50") + .await?; + + // print out the results showing we have car and speed columns and a deterministic ordering + let results = df.clone().collect().await?; + assert_batches_eq!( + [ + "+-------+-------+", + "| car | speed |", + "+-------+-------+", + "| red | 0.0 |", + "| red | 1.0 |", + "| green | 2.0 |", + "| red | 3.0 |", + "| red | 7.0 |", + "| red | 7.1 |", + "| red | 7.2 |", + "| green | 8.0 |", + "| green | 10.0 |", + "| green | 10.3 |", + "| green | 10.4 |", + "| green | 10.5 |", + "| green | 11.0 |", + "| green | 12.0 |", + "| green | 14.0 |", + "| green | 15.0 |", + "| green | 15.1 |", + "| green | 15.2 |", + "| red | 17.0 |", + "| red | 18.0 |", + "| red | 19.0 |", + "| red | 20.0 |", + "| red | 20.3 |", + "| red | 21.4 |", + "| red | 21.5 |", + "+-------+-------+", + ], + &results + ); + + // We will now convert the query results into a Rust struct + let mut stream = 
df.execute_stream().await?; + let mut list: Vec = vec![]; + + // DataFusion produces data in chunks called `RecordBatch`es which are + // typically 8000 rows each. This loop processes each `RecordBatch` as it is + // produced by the query plan and adds it to the list + while let Some(batch) = stream.next().await.transpose()? { + // Each `RecordBatch` has one or more columns. Each column is stored as + // an `ArrayRef`. To interact with data using Rust native types we need to + // convert these `ArrayRef`s into concrete array types using APIs from + // the arrow crate. + + // In this case, we know that each batch has two columns of the Arrow + // types StringView and Float64, so first we cast the two columns to the + // appropriate Arrow PrimitiveArray (this is a fast / zero-copy cast).: + let car_col = batch + .column(0) + .as_any() + .downcast_ref::() + .expect("car column must be Utf8View"); + + let speed_col = batch + .column(1) + .as_any() + .downcast_ref::() + .expect("speed column must be Float64"); + + // With PrimitiveArrays, we can access to the values as native Rust + // types String and f64, and forming the desired `Data` structs + for i in 0..batch.num_rows() { + let car = if car_col.is_null(i) { + None + } else { + Some(car_col.value(i).to_string()) + }; + + let speed = if speed_col.is_null(i) { + None + } else { + Some(speed_col.value(i)) + }; + + list.push(Data { car, speed }); + } + } + + // Finally, we have the results in the list of Rust structs + let res = format!("{list:#?}"); + assert_eq!( + res, + r#"[ + Data { + car: Some( + "red", + ), + speed: Some( + 0.0, + ), + }, + Data { + car: Some( + "red", + ), + speed: Some( + 1.0, + ), + }, + Data { + car: Some( + "green", + ), + speed: Some( + 2.0, + ), + }, + Data { + car: Some( + "red", + ), + speed: Some( + 3.0, + ), + }, + Data { + car: Some( + "red", + ), + speed: Some( + 7.0, + ), + }, + Data { + car: Some( + "red", + ), + speed: Some( + 7.1, + ), + }, + Data { + car: Some( + "red", + ), + 
speed: Some( + 7.2, + ), + }, + Data { + car: Some( + "green", + ), + speed: Some( + 8.0, + ), + }, + Data { + car: Some( + "green", + ), + speed: Some( + 10.0, + ), + }, + Data { + car: Some( + "green", + ), + speed: Some( + 10.3, + ), + }, + Data { + car: Some( + "green", + ), + speed: Some( + 10.4, + ), + }, + Data { + car: Some( + "green", + ), + speed: Some( + 10.5, + ), + }, + Data { + car: Some( + "green", + ), + speed: Some( + 11.0, + ), + }, + Data { + car: Some( + "green", + ), + speed: Some( + 12.0, + ), + }, + Data { + car: Some( + "green", + ), + speed: Some( + 14.0, + ), + }, + Data { + car: Some( + "green", + ), + speed: Some( + 15.0, + ), + }, + Data { + car: Some( + "green", + ), + speed: Some( + 15.1, + ), + }, + Data { + car: Some( + "green", + ), + speed: Some( + 15.2, + ), + }, + Data { + car: Some( + "red", + ), + speed: Some( + 17.0, + ), + }, + Data { + car: Some( + "red", + ), + speed: Some( + 18.0, + ), + }, + Data { + car: Some( + "red", + ), + speed: Some( + 19.0, + ), + }, + Data { + car: Some( + "red", + ), + speed: Some( + 20.0, + ), + }, + Data { + car: Some( + "red", + ), + speed: Some( + 20.3, + ), + }, + Data { + car: Some( + "red", + ), + speed: Some( + 21.4, + ), + }, + Data { + car: Some( + "red", + ), + speed: Some( + 21.5, + ), + }, +]"# + ); + + let speed_green_sum: f64 = list + .iter() + .filter(|data| data.car.as_deref() == Some("green")) + .filter_map(|data| data.speed) + .sum(); + let speed_red_sum: f64 = list + .iter() + .filter(|data| data.car.as_deref() == Some("red")) + .filter_map(|data| data.speed) + .sum(); + assert_eq!(speed_green_sum, 133.5); + assert_eq!(speed_red_sum, 162.5); + + Ok(()) +} + +/// This is target struct where we want the query results. 
+#[derive(Debug)] +struct Data { + car: Option, + speed: Option, +} diff --git a/datafusion-examples/examples/dataframe/main.rs b/datafusion-examples/examples/dataframe/main.rs new file mode 100644 index 0000000000000..25b5377d38239 --- /dev/null +++ b/datafusion-examples/examples/dataframe/main.rs @@ -0,0 +1,100 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! # Core DataFrame API usage examples +//! +//! These examples demonstrate core DataFrame API usage. +//! +//! ## Usage +//! ```bash +//! cargo run --example dataframe -- [all|dataframe|deserialize_to_struct|cache_factory] +//! ``` +//! +//! Each subcommand runs a corresponding example: +//! - `all` — run all examples included in this module +//! +//! - `cache_factory` +//! (file: cache_factory.rs, desc: Custom lazy caching for DataFrames using `CacheFactory`) +//! +//! - `dataframe` +//! (file: dataframe.rs, desc: Query DataFrames from various sources and write output) +//! +//! - `deserialize_to_struct` +//!
(file: deserialize_to_struct.rs, desc: Convert Arrow arrays into Rust structs) + +mod cache_factory; +mod dataframe; +mod deserialize_to_struct; + +use datafusion::error::{DataFusionError, Result}; +use strum::{IntoEnumIterator, VariantNames}; +use strum_macros::{Display, EnumIter, EnumString, VariantNames}; + +#[derive(EnumIter, EnumString, Display, VariantNames)] +#[strum(serialize_all = "snake_case")] +enum ExampleKind { + All, + Dataframe, + DeserializeToStruct, + CacheFactory, +} + +impl ExampleKind { + const EXAMPLE_NAME: &str = "dataframe"; + + fn runnable() -> impl Iterator { + ExampleKind::iter().filter(|v| !matches!(v, ExampleKind::All)) + } + + async fn run(&self) -> Result<()> { + match self { + ExampleKind::All => { + for example in ExampleKind::runnable() { + println!("Running example: {example}"); + Box::pin(example.run()).await?; + } + } + ExampleKind::Dataframe => { + dataframe::dataframe_example().await?; + } + ExampleKind::DeserializeToStruct => { + deserialize_to_struct::deserialize_to_struct().await?; + } + ExampleKind::CacheFactory => { + cache_factory::cache_dataframe_with_custom_logic().await?; + } + } + Ok(()) + } +} + +#[tokio::main] +async fn main() -> Result<()> { + let usage = format!( + "Usage: cargo run --example {} -- [{}]", + ExampleKind::EXAMPLE_NAME, + ExampleKind::VARIANTS.join("|") + ); + + let example: ExampleKind = std::env::args() + .nth(1) + .unwrap_or_else(|| ExampleKind::All.to_string()) + .parse() + .map_err(|_| DataFusionError::Execution(format!("Unknown example. {usage}")))?; + + example.run().await +} diff --git a/datafusion-examples/examples/deserialize_to_struct.rs b/datafusion-examples/examples/deserialize_to_struct.rs deleted file mode 100644 index d6655b3b654f9..0000000000000 --- a/datafusion-examples/examples/deserialize_to_struct.rs +++ /dev/null @@ -1,150 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. 
See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use arrow::array::{AsArray, PrimitiveArray}; -use arrow::datatypes::{Float64Type, Int32Type}; -use datafusion::common::assert_batches_eq; -use datafusion::error::Result; -use datafusion::prelude::*; -use futures::StreamExt; - -/// This example shows how to convert query results into Rust structs by using -/// the Arrow APIs to convert the results into Rust native types. 
-/// -/// This is a bit tricky initially as the results are returned as columns stored -/// as [ArrayRef] -/// -/// [ArrayRef]: arrow::array::ArrayRef -#[tokio::main] -async fn main() -> Result<()> { - // Run a query that returns two columns of data - let ctx = SessionContext::new(); - let testdata = datafusion::test_util::parquet_test_data(); - ctx.register_parquet( - "alltypes_plain", - &format!("{testdata}/alltypes_plain.parquet"), - ParquetReadOptions::default(), - ) - .await?; - let df = ctx - .sql("SELECT int_col, double_col FROM alltypes_plain") - .await?; - - // print out the results showing we have an int32 and a float64 column - let results = df.clone().collect().await?; - assert_batches_eq!( - [ - "+---------+------------+", - "| int_col | double_col |", - "+---------+------------+", - "| 0 | 0.0 |", - "| 1 | 10.1 |", - "| 0 | 0.0 |", - "| 1 | 10.1 |", - "| 0 | 0.0 |", - "| 1 | 10.1 |", - "| 0 | 0.0 |", - "| 1 | 10.1 |", - "+---------+------------+", - ], - &results - ); - - // We will now convert the query results into a Rust struct - let mut stream = df.execute_stream().await?; - let mut list = vec![]; - - // DataFusion produces data in chunks called `RecordBatch`es which are - // typically 8000 rows each. This loop processes each `RecordBatch` as it is - // produced by the query plan and adds it to the list - while let Some(b) = stream.next().await.transpose()? { - // Each `RecordBatch` has one or more columns. Each column is stored as - // an `ArrayRef`. To interact with data using Rust native types we need to - // convert these `ArrayRef`s into concrete array types using APIs from - // the arrow crate. 
- - // In this case, we know that each batch has two columns of the Arrow - // types Int32 and Float64, so first we cast the two columns to the - // appropriate Arrow PrimitiveArray (this is a fast / zero-copy cast).: - let int_col: &PrimitiveArray = b.column(0).as_primitive(); - let float_col: &PrimitiveArray = b.column(1).as_primitive(); - - // With PrimitiveArrays, we can access to the values as native Rust - // types i32 and f64, and forming the desired `Data` structs - for (i, f) in int_col.values().iter().zip(float_col.values()) { - list.push(Data { - int_col: *i, - double_col: *f, - }) - } - } - - // Finally, we have the results in the list of Rust structs - let res = format!("{list:#?}"); - assert_eq!( - res, - r#"[ - Data { - int_col: 0, - double_col: 0.0, - }, - Data { - int_col: 1, - double_col: 10.1, - }, - Data { - int_col: 0, - double_col: 0.0, - }, - Data { - int_col: 1, - double_col: 10.1, - }, - Data { - int_col: 0, - double_col: 0.0, - }, - Data { - int_col: 1, - double_col: 10.1, - }, - Data { - int_col: 0, - double_col: 0.0, - }, - Data { - int_col: 1, - double_col: 10.1, - }, -]"# - ); - - // Use the fields in the struct to avoid clippy complaints - let int_sum = list.iter().fold(0, |acc, x| acc + x.int_col); - let double_sum = list.iter().fold(0.0, |acc, x| acc + x.double_col); - assert_eq!(int_sum, 4); - assert_eq!(double_sum, 40.4); - - Ok(()) -} - -/// This is target struct where we want the query results. -#[derive(Debug)] -struct Data { - int_col: i32, - double_col: f64, -} diff --git a/datafusion-examples/examples/execution_monitoring/main.rs b/datafusion-examples/examples/execution_monitoring/main.rs new file mode 100644 index 0000000000000..8f80c36929ca2 --- /dev/null +++ b/datafusion-examples/examples/execution_monitoring/main.rs @@ -0,0 +1,98 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! # Memory and performance management examples +//! +//! These examples demonstrate memory and performance management. +//! +//! ## Usage +//! ```bash +//! cargo run --example execution_monitoring -- [all|mem_pool_exec_plan|mem_pool_tracking|tracing] +//! ``` +//! +//! Each subcommand runs a corresponding example: +//! - `all` — run all examples included in this module +//! +//! - `mem_pool_exec_plan` +//! (file: memory_pool_execution_plan.rs, desc: Memory-aware ExecutionPlan with spilling) +//! +//! - `mem_pool_tracking` +//! (file: memory_pool_tracking.rs, desc: Demonstrates memory tracking) +//! +//! - `tracing` +//!
(file: tracing.rs, desc: Demonstrates tracing integration) + +mod memory_pool_execution_plan; +mod memory_pool_tracking; +mod tracing; + +use datafusion::error::{DataFusionError, Result}; +use strum::{IntoEnumIterator, VariantNames}; +use strum_macros::{Display, EnumIter, EnumString, VariantNames}; + +#[derive(EnumIter, EnumString, Display, VariantNames)] +#[strum(serialize_all = "snake_case")] +enum ExampleKind { + All, + MemPoolExecPlan, + MemPoolTracking, + Tracing, +} + +impl ExampleKind { + const EXAMPLE_NAME: &str = "execution_monitoring"; + + fn runnable() -> impl Iterator { + ExampleKind::iter().filter(|v| !matches!(v, ExampleKind::All)) + } + + async fn run(&self) -> Result<()> { + match self { + ExampleKind::All => { + for example in ExampleKind::runnable() { + println!("Running example: {example}"); + Box::pin(example.run()).await?; + } + } + ExampleKind::MemPoolExecPlan => { + memory_pool_execution_plan::memory_pool_execution_plan().await? + } + ExampleKind::MemPoolTracking => { + memory_pool_tracking::mem_pool_tracking().await? + } + ExampleKind::Tracing => tracing::tracing().await?, + } + Ok(()) + } +} + +#[tokio::main] +async fn main() -> Result<()> { + let usage = format!( + "Usage: cargo run --example {} -- [{}]", + ExampleKind::EXAMPLE_NAME, + ExampleKind::VARIANTS.join("|") + ); + + let example: ExampleKind = std::env::args() + .nth(1) + .unwrap_or_else(|| ExampleKind::All.to_string()) + .parse() + .map_err(|_| DataFusionError::Execution(format!("Unknown example. 
{usage}")))?; + + example.run().await +} diff --git a/datafusion-examples/examples/execution_monitoring/memory_pool_execution_plan.rs b/datafusion-examples/examples/execution_monitoring/memory_pool_execution_plan.rs new file mode 100644 index 0000000000000..1440347d4413d --- /dev/null +++ b/datafusion-examples/examples/execution_monitoring/memory_pool_execution_plan.rs @@ -0,0 +1,316 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! See `main.rs` for how to run it. +//! +//! This example demonstrates how to implement custom ExecutionPlans that properly +//! use memory tracking through TrackConsumersPool. +//! +//! This shows the pattern for implementing memory-aware operators that: +//! - Register memory consumers with the pool +//! - Reserve memory before allocating +//! - Handle memory pressure by spilling to disk +//! 
- Release memory when done + +use arrow::record_batch::RecordBatch; +use arrow_schema::SchemaRef; +use datafusion::common::record_batch; +use datafusion::common::tree_node::TreeNodeRecursion; +use datafusion::common::{exec_datafusion_err, internal_err}; +use datafusion::datasource::{DefaultTableSource, memory::MemTable}; +use datafusion::error::Result; +use datafusion::execution::memory_pool::{MemoryConsumer, MemoryReservation}; +use datafusion::execution::runtime_env::RuntimeEnvBuilder; +use datafusion::execution::{SendableRecordBatchStream, TaskContext}; +use datafusion::logical_expr::LogicalPlanBuilder; +use datafusion::physical_plan::stream::RecordBatchStreamAdapter; +use datafusion::physical_plan::{ + DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, +}; +use datafusion::prelude::*; +use futures::stream::{StreamExt, TryStreamExt}; +use std::any::Any; +use std::fmt; +use std::sync::Arc; + +/// Shows how to implement memory-aware ExecutionPlan with memory reservation and spilling +pub async fn memory_pool_execution_plan() -> Result<()> { + println!("=== DataFusion ExecutionPlan Memory Tracking Example ===\n"); + + // Set up a runtime with memory tracking + // Set a low memory limit to trigger spilling on the second batch + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(15000, 1.0) // Allow only enough for 1 batch at once + .build_arc()?; + + let config = SessionConfig::new().with_coalesce_batches(false); + let ctx = SessionContext::new_with_config_rt(config, runtime.clone()); + + // Create smaller batches to ensure we get multiple RecordBatches from the scan + // Make each batch smaller than the memory limit to force multiple batches + let batch1 = record_batch!( + ("id", Int32, vec![1; 800]), + ("name", Utf8, vec!["Alice"; 800]) + )?; + + let batch2 = record_batch!( + ("id", Int32, vec![2; 800]), + ("name", Utf8, vec!["Bob"; 800]) + )?; + + let batch3 = record_batch!( + ("id", Int32, vec![3; 800]), + ("name", Utf8, vec!["Charlie"; 800]) 
+ )?; + + let batch4 = record_batch!( + ("id", Int32, vec![4; 800]), + ("name", Utf8, vec!["David"; 800]) + )?; + + let schema = batch1.schema(); + + // Create a single MemTable with all batches in one partition to preserve order but ensure streaming + let mem_table = Arc::new(MemTable::try_new( + Arc::clone(&schema), + vec![vec![batch1, batch2, batch3, batch4]], // Single partition with multiple batches + )?); + + // Build logical plan with a single scan that will yield multiple batches + let table_source = Arc::new(DefaultTableSource::new(mem_table)); + let logical_plan = + LogicalPlanBuilder::scan("multi_batch_table", table_source, None)?.build()?; + + // Convert to physical plan + let physical_plan = ctx.state().create_physical_plan(&logical_plan).await?; + + println!("Example: Custom Memory-Aware BufferingExecutionPlan"); + println!("---------------------------------------------------"); + + // Wrap our input plan with our custom BufferingExecutionPlan + let buffering_plan = Arc::new(BufferingExecutionPlan::new(schema, physical_plan)); + + // Create a task context from our runtime + let task_ctx = Arc::new(TaskContext::default().with_runtime(runtime)); + + // Execute the plan directly to demonstrate memory tracking + println!("Executing BufferingExecutionPlan with memory tracking..."); + println!("Memory limit: 15000 bytes - should trigger spill on later batches\n"); + + let stream = buffering_plan.execute(0, task_ctx.clone())?; + let _results: Vec = stream.try_collect().await?; + + println!("\nSuccessfully executed BufferingExecutionPlan!"); + + println!("\nThe BufferingExecutionPlan processed 4 input batches and"); + println!("demonstrated memory tracking with spilling behavior when the"); + println!("memory limit was exceeded by later batches."); + println!("Check the console output above to see the spill messages."); + + Ok(()) +} + +/// Example of an external batch bufferer that uses memory reservation. 
+/// +/// It's a simple example which spills all existing data to disk +/// whenever the memory limit is reached. +struct ExternalBatchBufferer { + buffer: Vec, + reservation: MemoryReservation, + spill_count: usize, +} + +impl ExternalBatchBufferer { + fn new(reservation: MemoryReservation) -> Self { + Self { + buffer: Vec::new(), + reservation, + spill_count: 0, + } + } + + #[expect(clippy::needless_pass_by_value)] + fn add_batch(&mut self, batch_data: Vec) -> Result<()> { + let additional_memory = batch_data.len(); + + // Try to reserve memory before allocating + if self.reservation.try_grow(additional_memory).is_err() { + // Memory limit reached - handle by spilling + println!( + "Memory limit reached, spilling {} bytes to disk", + self.buffer.len() + ); + self.spill_to_disk()?; + + // Try again after spilling + self.reservation.try_grow(additional_memory)?; + } + + self.buffer.extend_from_slice(&batch_data); + println!( + "Added batch of {} bytes, total buffered: {} bytes", + additional_memory, + self.buffer.len() + ); + Ok(()) + } + + fn spill_to_disk(&mut self) -> Result<()> { + // Simulate writing buffer to disk + self.spill_count += 1; + println!( + "Spill #{}: Writing {} bytes to disk", + self.spill_count, + self.buffer.len() + ); + + // Free the memory after spilling + let freed_bytes = self.buffer.len(); + self.buffer.clear(); + self.reservation.shrink(freed_bytes); + + Ok(()) + } + + fn finish(&mut self) -> Vec { + let result = std::mem::take(&mut self.buffer); + // Free the memory when done + self.reservation.free(); + println!("Finished processing, released {} bytes", result.len()); + result + } +} + +/// Example of an ExecutionPlan that uses the ExternalBatchBufferer. 
+#[derive(Debug)] +struct BufferingExecutionPlan { + schema: SchemaRef, + input: Arc, + properties: Arc, +} + +impl BufferingExecutionPlan { + fn new(schema: SchemaRef, input: Arc) -> Self { + let properties = input.properties().clone(); + + Self { + schema, + input, + properties, + } + } +} + +impl DisplayAs for BufferingExecutionPlan { + fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "BufferingExecutionPlan") + } +} + +impl ExecutionPlan for BufferingExecutionPlan { + fn name(&self) -> &'static str { + "BufferingExecutionPlan" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + fn properties(&self) -> &Arc { + &self.properties + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.input] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + if children.len() == 1 { + Ok(Arc::new(BufferingExecutionPlan::new( + self.schema.clone(), + children[0].clone(), + ))) + } else { + internal_err!("BufferingExecutionPlan must have exactly one child") + } + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + // Register memory consumer with the context's memory pool + let reservation = MemoryConsumer::new("MyExternalBatchBufferer") + .with_can_spill(true) + .register(context.memory_pool()); + + // Incoming stream of batches + let mut input_stream = self.input.execute(partition, context)?; + + // Process the stream and collect all batches + Ok(Box::pin(RecordBatchStreamAdapter::new( + self.schema(), + futures::stream::once(async move { + let mut operator = ExternalBatchBufferer::new(reservation); + + while let Some(batch) = input_stream.next().await { + let batch = batch?; + + // Convert RecordBatch to bytes for this example + let batch_data = vec![1u8; batch.get_array_memory_size()]; + + operator.add_batch(batch_data)?; + } + + // Finish processing and get results + let _final_result = operator.finish(); + // In a 
real implementation, you would convert final_result back to RecordBatches + + // Since this is a simplified example, return an empty batch + // In a real implementation, you would create a batch stream from the processed results + record_batch!(("id", Int32, vec![5]), ("name", Utf8, vec!["Eve"])) + .map_err(|e| { + exec_datafusion_err!("Failed to create final RecordBatch: {e}") + }) + }), + ))) + } + + fn apply_expressions( + &self, + f: &mut dyn FnMut( + &dyn datafusion::physical_plan::PhysicalExpr, + ) -> Result, + ) -> Result { + // Visit expressions in the output ordering from equivalence properties + let mut tnr = TreeNodeRecursion::Continue; + if let Some(ordering) = self.properties.output_ordering() { + for sort_expr in ordering { + tnr = tnr.visit_sibling(|| f(sort_expr.expr.as_ref()))?; + } + } + Ok(tnr) + } +} diff --git a/datafusion-examples/examples/execution_monitoring/memory_pool_tracking.rs b/datafusion-examples/examples/execution_monitoring/memory_pool_tracking.rs new file mode 100644 index 0000000000000..af3031c690fa3 --- /dev/null +++ b/datafusion-examples/examples/execution_monitoring/memory_pool_tracking.rs @@ -0,0 +1,131 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
See `main.rs` for how to run it. +//! +//! This example demonstrates how to use TrackConsumersPool for memory tracking and debugging. +//! +//! The TrackConsumersPool provides enhanced error messages that show the top memory consumers +//! when memory allocation fails, making it easier to debug memory issues in DataFusion queries. +//! +//! # Examples +//! +//! * [`automatic_usage_example`]: Shows how to use RuntimeEnvBuilder to automatically enable memory tracking + +use datafusion::error::Result; +use datafusion::execution::runtime_env::RuntimeEnvBuilder; +use datafusion::prelude::*; + +/// Demonstrates TrackConsumersPool for memory tracking and debugging with enhanced error messages +pub async fn mem_pool_tracking() -> Result<()> { + println!("=== DataFusion Memory Pool Tracking Example ===\n"); + + // Example 1: Automatic Usage with RuntimeEnvBuilder + automatic_usage_example().await?; + + Ok(()) +} + +/// Example 1: Automatic Usage with RuntimeEnvBuilder +/// +/// This shows the recommended way to use TrackConsumersPool through RuntimeEnvBuilder, +/// which automatically creates a TrackConsumersPool with sensible defaults. +async fn automatic_usage_example() -> Result<()> { + println!("Example 1: Automatic Usage with RuntimeEnvBuilder"); + println!("------------------------------------------------"); + + // Success case: Create a runtime with reasonable memory limit + println!("Success case: Normal operation with sufficient memory"); + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(5_000_000, 1.0) // 5MB, 100% utilization + .build_arc()?; + + let config = SessionConfig::new(); + let ctx = SessionContext::new_with_config_rt(config, runtime); + + // Create a simple table for demonstration + ctx.sql("CREATE TABLE test AS VALUES (1, 'a'), (2, 'b'), (3, 'c')") + .await? 
+ .collect() + .await?; + + println!("✓ Created table with memory tracking enabled"); + + // Run a simple query to show it works + let results = ctx.sql("SELECT * FROM test").await?.collect().await?; + println!( + "✓ Query executed successfully. Found {} rows", + results.len() + ); + + println!("\n{}", "-".repeat(50)); + + // Error case: Create a runtime with low memory limit to trigger errors + println!("Error case: Triggering memory limit error with detailed error messages"); + + // Use a WITH query that generates data and then processes it to trigger memory usage + match ctx.sql(" + WITH large_dataset AS ( + SELECT + column1 as id, + column1 * 2 as doubled, + repeat('data_', 20) || column1 as text_field, + column1 * column1 as squared + FROM generate_series(1, 2000) as t(column1) + ), + aggregated AS ( + SELECT + id, + doubled, + text_field, + squared, + sum(doubled) OVER (ORDER BY id ROWS BETWEEN 100 PRECEDING AND CURRENT ROW) as running_sum + FROM large_dataset + ) + SELECT + a1.id, + a1.text_field, + a2.text_field as text_field2, + a1.running_sum + a2.running_sum as combined_sum + FROM aggregated a1 + JOIN aggregated a2 ON a1.id = a2.id - 1 + ORDER BY a1.id + ").await?.collect().await { + Ok(results) => panic!("Should not succeed! Yet got {} batches", results.len()), + Err(e) => { + println!("✓ Expected memory limit error during data processing:"); + println!("Error: {e}"); + /* Example error message: + Error: Not enough memory to continue external sort. Consider increasing the memory limit config: 'datafusion.runtime.memory_limit', + or decreasing the config: 'datafusion.execution.sort_spill_reservation_bytes'. 
+ caused by + Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: + ExternalSorterMerge[3]#112(can spill: false) consumed 10.0 MB, peak 10.0 MB, + ExternalSorterMerge[10]#147(can spill: false) consumed 10.0 MB, peak 10.0 MB, + ExternalSorter[1]#93(can spill: true) consumed 69.0 KB, peak 69.0 KB, + ExternalSorter[13]#155(can spill: true) consumed 67.6 KB, peak 67.6 KB, + ExternalSorter[8]#140(can spill: true) consumed 67.2 KB, peak 67.2 KB. + Error: Failed to allocate additional 10.0 MB for ExternalSorterMerge[0] with 0.0 B already allocated for this reservation - 7.1 MB remain available for the total pool + */ + } + } + + println!("\nNote: The error message above shows which memory consumers"); + println!("were using the most memory when the limit was exceeded."); + + Ok(()) +} diff --git a/datafusion-examples/examples/tracing.rs b/datafusion-examples/examples/execution_monitoring/tracing.rs similarity index 82% rename from datafusion-examples/examples/tracing.rs rename to datafusion-examples/examples/execution_monitoring/tracing.rs index 334ee0f4e5686..172c1ca83b3bd 100644 --- a/datafusion-examples/examples/tracing.rs +++ b/datafusion-examples/examples/execution_monitoring/tracing.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. +//! //! This example demonstrates the tracing injection feature for the DataFusion runtime. //! Tasks spawned on new threads behave differently depending on whether a tracer is injected. //! The log output clearly distinguishes the two cases. @@ -49,20 +51,21 @@ //! 10:29:40.809 INFO main ThreadId(01) tracing: ***** WITH tracer: Non-main tasks DID inherit the `run_instrumented_query` span ***** //! 
``` -use datafusion::common::runtime::{set_join_set_tracer, JoinSetTracer}; +use std::any::Any; +use std::sync::Arc; + +use datafusion::common::runtime::{JoinSetTracer, set_join_set_tracer}; use datafusion::datasource::file_format::parquet::ParquetFormat; use datafusion::datasource::listing::ListingOptions; use datafusion::error::Result; use datafusion::prelude::*; -use datafusion::test_util::parquet_test_data; -use futures::future::BoxFuture; +use datafusion_examples::utils::{datasets::ExampleDataset, write_csv_to_parquet}; use futures::FutureExt; -use std::any::Any; -use std::sync::Arc; -use tracing::{info, instrument, Instrument, Level, Span}; +use futures::future::BoxFuture; +use tracing::{Instrument, Level, Span, info, instrument}; -#[tokio::main] -async fn main() -> Result<()> { +/// Demonstrates the tracing injection feature for the DataFusion runtime +pub async fn tracing() -> Result<()> { // Initialize tracing subscriber with thread info. tracing_subscriber::fmt() .with_thread_ids(true) @@ -73,7 +76,9 @@ async fn main() -> Result<()> { // Run query WITHOUT tracer injection. info!("***** RUNNING WITHOUT INJECTED TRACER *****"); run_instrumented_query().await?; - info!("***** WITHOUT tracer: `tokio-runtime-worker` tasks did NOT inherit the `run_instrumented_query` span *****"); + info!( + "***** WITHOUT tracer: `tokio-runtime-worker` tasks did NOT inherit the `run_instrumented_query` span *****" + ); // Inject custom tracer so tasks run in the current span. info!("Injecting custom tracer..."); @@ -82,7 +87,9 @@ async fn main() -> Result<()> { // Run query WITH tracer injection. 
info!("***** RUNNING WITH INJECTED TRACER *****"); run_instrumented_query().await?; - info!("***** WITH tracer: `tokio-runtime-worker` tasks DID inherit the `run_instrumented_query` span *****"); + info!( + "***** WITH tracer: `tokio-runtime-worker` tasks DID inherit the `run_instrumented_query` span *****" + ); Ok(()) } @@ -120,18 +127,27 @@ async fn run_instrumented_query() -> Result<()> { info!("Starting query execution"); let ctx = SessionContext::new(); - let test_data = parquet_test_data(); + + // Convert the CSV input into a temporary Parquet directory for querying + let dataset = ExampleDataset::Cars; + let parquet_temp = write_csv_to_parquet(&ctx, &dataset.path()).await?; + let file_format = ParquetFormat::default().with_enable_pruning(true); - let listing_options = ListingOptions::new(Arc::new(file_format)) - .with_file_extension("alltypes_tiny_pages_plain.parquet"); + let listing_options = + ListingOptions::new(Arc::new(file_format)).with_file_extension(".parquet"); - let table_path = format!("file://{test_data}/"); - info!("Registering table 'alltypes' from {}", table_path); - ctx.register_listing_table("alltypes", &table_path, listing_options, None, None) - .await - .expect("Failed to register table"); + info!("Registering table 'cars' from {}", parquet_temp.path_str()?); + ctx.register_listing_table( + "cars", + parquet_temp.path_str()?, + listing_options, + None, + None, + ) + .await + .expect("Failed to register table"); - let sql = "SELECT COUNT(*), string_col FROM alltypes GROUP BY string_col"; + let sql = "SELECT COUNT(*), car, sum(speed) FROM cars GROUP BY car"; info!(sql, "Executing SQL query"); let result = ctx.sql(sql).await?.collect().await?; info!("Query complete: {} batches returned", result.len()); diff --git a/datafusion-examples/examples/external_dependency/dataframe-to-s3.rs b/datafusion-examples/examples/external_dependency/dataframe_to_s3.rs similarity index 87% rename from 
datafusion-examples/examples/external_dependency/dataframe-to-s3.rs rename to datafusion-examples/examples/external_dependency/dataframe_to_s3.rs index e75ba5dd5328a..fdb8a3c9c051a 100644 --- a/datafusion-examples/examples/external_dependency/dataframe-to-s3.rs +++ b/datafusion-examples/examples/external_dependency/dataframe_to_s3.rs @@ -15,12 +15,14 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. + use std::env; use std::sync::Arc; use datafusion::dataframe::DataFrameWriteOptions; -use datafusion::datasource::file_format::parquet::ParquetFormat; use datafusion::datasource::file_format::FileFormat; +use datafusion::datasource::file_format::parquet::ParquetFormat; use datafusion::datasource::listing::ListingOptions; use datafusion::error::Result; use datafusion::prelude::*; @@ -28,14 +30,18 @@ use datafusion::prelude::*; use object_store::aws::AmazonS3Builder; use url::Url; -/// This example demonstrates querying data from AmazonS3 and writing -/// the result of a query back to AmazonS3 -#[tokio::main] -async fn main() -> Result<()> { +/// This example demonstrates querying data from Amazon S3 and writing +/// the result of a query back to Amazon S3. 
+/// +/// The following environment variables must be defined: +/// +/// - AWS_ACCESS_KEY_ID +/// - AWS_SECRET_ACCESS_KEY +pub async fn dataframe_to_s3() -> Result<()> { // create local execution context let ctx = SessionContext::new(); - //enter region and bucket to which your credentials have GET and PUT access + // enter region and bucket to which your credentials have GET and PUT access let region = ""; let bucket_name = ""; @@ -66,13 +72,13 @@ async fn main() -> Result<()> { .write_parquet(&out_path, DataFrameWriteOptions::new(), None) .await?; - //write as JSON to s3 + // write as JSON to s3 let json_out = format!("s3://{bucket_name}/json_out"); df.clone() .write_json(&json_out, DataFrameWriteOptions::new(), None) .await?; - //write as csv to s3 + // write as csv to s3 let csv_out = format!("s3://{bucket_name}/csv_out"); df.write_csv(&csv_out, DataFrameWriteOptions::new(), None) .await?; diff --git a/datafusion-examples/examples/external_dependency/main.rs b/datafusion-examples/examples/external_dependency/main.rs new file mode 100644 index 0000000000000..447e7d38bdd5b --- /dev/null +++ b/datafusion-examples/examples/external_dependency/main.rs @@ -0,0 +1,88 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +//! # These are using data from Amazon S3 examples +//! +//! These examples demonstrate how to work with data from Amazon S3. +//! +//! ## Usage +//! ```bash +//! cargo run --example external_dependency -- [all|dataframe_to_s3|query_aws_s3] +//! ``` +//! +//! Each subcommand runs a corresponding example: +//! - `all` — run all examples included in this module +//! +//! - `dataframe_to_s3` +//! (file: dataframe_to_s3.rs, desc: Query DataFrames and write results to S3) +//! +//! - `query_aws_s3` +//! (file: query_aws_s3.rs, desc: Query S3-backed data using object_store) + +mod dataframe_to_s3; +mod query_aws_s3; + +use datafusion::error::{DataFusionError, Result}; +use strum::{IntoEnumIterator, VariantNames}; +use strum_macros::{Display, EnumIter, EnumString, VariantNames}; + +#[derive(EnumIter, EnumString, Display, VariantNames)] +#[strum(serialize_all = "snake_case")] +enum ExampleKind { + All, + DataframeToS3, + QueryAwsS3, +} + +impl ExampleKind { + const EXAMPLE_NAME: &str = "external_dependency"; + + fn runnable() -> impl Iterator { + ExampleKind::iter().filter(|v| !matches!(v, ExampleKind::All)) + } + + async fn run(&self) -> Result<()> { + match self { + ExampleKind::All => { + for example in ExampleKind::runnable() { + println!("Running example: {example}"); + Box::pin(example.run()).await?; + } + } + ExampleKind::DataframeToS3 => dataframe_to_s3::dataframe_to_s3().await?, + ExampleKind::QueryAwsS3 => query_aws_s3::query_aws_s3().await?, + } + Ok(()) + } +} + +#[tokio::main] +async fn main() -> Result<()> { + let usage = format!( + "Usage: cargo run --example {} -- [{}]", + ExampleKind::EXAMPLE_NAME, + ExampleKind::VARIANTS.join("|") + ); + + let example: ExampleKind = std::env::args() + .nth(1) + .unwrap_or_else(|| ExampleKind::All.to_string()) + .parse() + .map_err(|_| DataFusionError::Execution(format!("Unknown example. 
{usage}")))?; + + example.run().await +} diff --git a/datafusion-examples/examples/external_dependency/query-aws-s3.rs b/datafusion-examples/examples/external_dependency/query_aws_s3.rs similarity index 90% rename from datafusion-examples/examples/external_dependency/query-aws-s3.rs rename to datafusion-examples/examples/external_dependency/query_aws_s3.rs index da2d7e4879f99..63507bb3eed11 100644 --- a/datafusion-examples/examples/external_dependency/query-aws-s3.rs +++ b/datafusion-examples/examples/external_dependency/query_aws_s3.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. + use datafusion::error::Result; use datafusion::prelude::*; use object_store::aws::AmazonS3Builder; @@ -22,15 +24,13 @@ use std::env; use std::sync::Arc; use url::Url; -/// This example demonstrates querying data in an S3 bucket. +/// This example demonstrates querying data in a public S3 bucket +/// (the NYC TLC open dataset: `s3://nyc-tlc`). 
/// /// The following environment variables must be defined: -/// -/// - AWS_ACCESS_KEY_ID -/// - AWS_SECRET_ACCESS_KEY -/// -#[tokio::main] -async fn main() -> Result<()> { +/// - `AWS_ACCESS_KEY_ID` +/// - `AWS_SECRET_ACCESS_KEY` +pub async fn query_aws_s3() -> Result<()> { let ctx = SessionContext::new(); // the region must be set to the region where the bucket exists until the following diff --git a/datafusion-examples/examples/ffi/ffi_example_table_provider/Cargo.toml b/datafusion-examples/examples/ffi/ffi_example_table_provider/Cargo.toml index e9c0c5b43d682..e2d0e3fa6744d 100644 --- a/datafusion-examples/examples/ffi/ffi_example_table_provider/Cargo.toml +++ b/datafusion-examples/examples/ffi/ffi_example_table_provider/Cargo.toml @@ -28,6 +28,9 @@ datafusion = { workspace = true } datafusion-ffi = { workspace = true } ffi_module_interface = { path = "../ffi_module_interface" } +[lints] +workspace = true + [lib] name = "ffi_example_table_provider" crate-type = ["cdylib", 'rlib'] diff --git a/datafusion-examples/examples/ffi/ffi_example_table_provider/src/lib.rs b/datafusion-examples/examples/ffi/ffi_example_table_provider/src/lib.rs index a83f15926f054..eb217ef9e4832 100644 --- a/datafusion-examples/examples/ffi/ffi_example_table_provider/src/lib.rs +++ b/datafusion-examples/examples/ffi/ffi_example_table_provider/src/lib.rs @@ -21,6 +21,7 @@ use abi_stable::{export_root_module, prefix_type::PrefixTypeTrait}; use arrow::array::RecordBatch; use arrow::datatypes::{DataType, Field, Schema}; use datafusion::{common::record_batch, datasource::MemTable}; +use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec; use datafusion_ffi::table_provider::FFI_TableProvider; use ffi_module_interface::{TableProviderModule, TableProviderModuleRef}; @@ -34,7 +35,9 @@ fn create_record_batch(start_value: i32, num_values: usize) -> RecordBatch { /// Here we only wish to create a simple table provider as an example. 
/// We create an in-memory table and convert it to it's FFI counterpart. -extern "C" fn construct_simple_table_provider() -> FFI_TableProvider { +extern "C" fn construct_simple_table_provider( + codec: FFI_LogicalExtensionCodec, +) -> FFI_TableProvider { let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Int32, true), Field::new("b", DataType::Float64, true), @@ -50,7 +53,7 @@ extern "C" fn construct_simple_table_provider() -> FFI_TableProvider { let table_provider = MemTable::try_new(schema, vec![batches]).unwrap(); - FFI_TableProvider::new(Arc::new(table_provider), true, None) + FFI_TableProvider::new_with_ffi_codec(Arc::new(table_provider), true, None, codec) } #[export_root_module] diff --git a/datafusion-examples/examples/ffi/ffi_module_interface/Cargo.toml b/datafusion-examples/examples/ffi/ffi_module_interface/Cargo.toml index 612a219324763..fe4902711241e 100644 --- a/datafusion-examples/examples/ffi/ffi_module_interface/Cargo.toml +++ b/datafusion-examples/examples/ffi/ffi_module_interface/Cargo.toml @@ -18,9 +18,12 @@ [package] name = "ffi_module_interface" version = "0.1.0" -edition = "2021" +edition = "2024" publish = false +[lints] +workspace = true + [dependencies] abi_stable = "0.11.3" datafusion-ffi = { workspace = true } diff --git a/datafusion-examples/examples/ffi/ffi_module_interface/src/lib.rs b/datafusion-examples/examples/ffi/ffi_module_interface/src/lib.rs index 88690e9297135..3b2b9e1871dae 100644 --- a/datafusion-examples/examples/ffi/ffi_module_interface/src/lib.rs +++ b/datafusion-examples/examples/ffi/ffi_module_interface/src/lib.rs @@ -16,12 +16,12 @@ // under the License. 
use abi_stable::{ - declare_root_module_statics, + StableAbi, declare_root_module_statics, library::{LibraryError, RootModule}, package_version_strings, sabi_types::VersionStrings, - StableAbi, }; +use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec; use datafusion_ffi::table_provider::FFI_TableProvider; #[repr(C)] @@ -34,7 +34,8 @@ use datafusion_ffi::table_provider::FFI_TableProvider; /// how a user may wish to separate these concerns. pub struct TableProviderModule { /// Constructs the table provider - pub create_table: extern "C" fn() -> FFI_TableProvider, + pub create_table: + extern "C" fn(codec: FFI_LogicalExtensionCodec) -> FFI_TableProvider, } impl RootModule for TableProviderModuleRef { diff --git a/datafusion-examples/examples/ffi/ffi_module_loader/Cargo.toml b/datafusion-examples/examples/ffi/ffi_module_loader/Cargo.toml index 028a366aab1c0..8d7434dca211b 100644 --- a/datafusion-examples/examples/ffi/ffi_module_loader/Cargo.toml +++ b/datafusion-examples/examples/ffi/ffi_module_loader/Cargo.toml @@ -18,9 +18,12 @@ [package] name = "ffi_module_loader" version = "0.1.0" -edition = "2021" +edition = "2024" publish = false +[lints] +workspace = true + [dependencies] abi_stable = "0.11.3" datafusion = { workspace = true } diff --git a/datafusion-examples/examples/ffi/ffi_module_loader/src/main.rs b/datafusion-examples/examples/ffi/ffi_module_loader/src/main.rs index 6e376ca866e8f..8ce5b156df3b1 100644 --- a/datafusion-examples/examples/ffi/ffi_module_loader/src/main.rs +++ b/datafusion-examples/examples/ffi/ffi_module_loader/src/main.rs @@ -22,8 +22,10 @@ use datafusion::{ prelude::SessionContext, }; -use abi_stable::library::{development_utils::compute_library_path, RootModule}; -use datafusion_ffi::table_provider::ForeignTableProvider; +use abi_stable::library::{RootModule, development_utils::compute_library_path}; +use datafusion::datasource::TableProvider; +use datafusion::execution::TaskContextProvider; +use 
datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec; use ffi_module_interface::TableProviderModuleRef; #[tokio::main] @@ -39,6 +41,11 @@ async fn main() -> Result<()> { TableProviderModuleRef::load_from_directory(&library_path) .map_err(|e| DataFusionError::External(Box::new(e)))?; + let ctx = Arc::new(SessionContext::new()); + let codec = FFI_LogicalExtensionCodec::new_default( + &(Arc::clone(&ctx) as Arc), + ); + // By calling the code below, the table provided will be created within // the module's code. let ffi_table_provider = @@ -46,16 +53,14 @@ async fn main() -> Result<()> { .create_table() .ok_or(DataFusionError::NotImplemented( "External table provider failed to implement create_table".to_string(), - ))?(); + ))?(codec); // In order to access the table provider within this executable, we need to - // turn it into a `ForeignTableProvider`. - let foreign_table_provider: ForeignTableProvider = (&ffi_table_provider).into(); - - let ctx = SessionContext::new(); + // turn it into a `TableProvider`. + let foreign_table_provider: Arc = (&ffi_table_provider).into(); // Display the data to show the full cycle works. - ctx.register_table("external_table", Arc::new(foreign_table_provider))?; + ctx.register_table("external_table", foreign_table_provider)?; let df = ctx.table("external_table").await?; df.show().await?; diff --git a/datafusion-examples/examples/flight/flight_client.rs b/datafusion-examples/examples/flight/client.rs similarity index 78% rename from datafusion-examples/examples/flight/flight_client.rs rename to datafusion-examples/examples/flight/client.rs index e3237284b4307..8f6856a4e4849 100644 --- a/datafusion-examples/examples/flight/flight_client.rs +++ b/datafusion-examples/examples/flight/client.rs @@ -15,32 +15,41 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. 
+ use std::collections::HashMap; use std::sync::Arc; -use datafusion::arrow::datatypes::Schema; - use arrow_flight::flight_descriptor; use arrow_flight::flight_service_client::FlightServiceClient; use arrow_flight::utils::flight_data_to_arrow_batch; use arrow_flight::{FlightDescriptor, Ticket}; +use datafusion::arrow::datatypes::Schema; use datafusion::arrow::util::pretty; +use datafusion::prelude::SessionContext; +use datafusion_examples::utils::{datasets::ExampleDataset, write_csv_to_parquet}; +use tonic::transport::Endpoint; /// This example shows how to wrap DataFusion with `FlightService` to support looking up schema information for /// Parquet files and executing SQL queries against them on a remote server. /// This example is run along-side the example `flight_server`. -#[tokio::main] -async fn main() -> Result<(), Box> { - let testdata = datafusion::test_util::parquet_test_data(); +pub async fn client() -> Result<(), Box> { + let ctx = SessionContext::new(); + + // Convert the CSV input into a temporary Parquet directory for querying + let dataset = ExampleDataset::Cars; + let parquet_temp = write_csv_to_parquet(&ctx, &dataset.path()).await?; // Create Flight client - let mut client = FlightServiceClient::connect("http://localhost:50051").await?; + let endpoint = Endpoint::new("http://localhost:50051")?; + let channel = endpoint.connect().await?; + let mut client = FlightServiceClient::new(channel); // Call get_schema to get the schema of a Parquet file let request = tonic::Request::new(FlightDescriptor { r#type: flight_descriptor::DescriptorType::Path as i32, cmd: Default::default(), - path: vec![format!("{testdata}/alltypes_plain.parquet")], + path: vec![format!("{}", parquet_temp.path_str()?)], }); let schema_result = client.get_schema(request).await?.into_inner(); @@ -49,7 +58,7 @@ async fn main() -> Result<(), Box> { // Call do_get to execute a SQL query and receive results let request = tonic::Request::new(Ticket { - ticket: "SELECT id FROM 
alltypes_plain".into(), + ticket: "SELECT car FROM cars".into(), }); let mut stream = client.do_get(request).await?.into_inner(); diff --git a/datafusion-examples/examples/flight/main.rs b/datafusion-examples/examples/flight/main.rs new file mode 100644 index 0000000000000..426e806486f70 --- /dev/null +++ b/datafusion-examples/examples/flight/main.rs @@ -0,0 +1,101 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! # Arrow Flight Examples +//! +//! These examples demonstrate Arrow Flight usage. +//! +//! ## Usage +//! ```bash +//! cargo run --example flight -- [all|client|server|sql_server] +//! ``` +//! +//! Each subcommand runs a corresponding example: +//! - `all` — run all examples included in this module +//! Note: The Flight server must be started in a separate process +//! before running the `client` example. Therefore, running `all` will +//! not produce a full server+client workflow automatically. +//! +//! - `client` +//! (file: client.rs, desc: Execute SQL queries via Arrow Flight protocol) +//! +//! - `server` +//! (file: server.rs, desc: Run DataFusion server accepting FlightSQL/JDBC queries) +//! +//! - `sql_server` +//! 
(file: sql_server.rs, desc: Standalone SQL server for JDBC clients) + +mod client; +mod server; +mod sql_server; + +use datafusion::error::{DataFusionError, Result}; +use strum::{IntoEnumIterator, VariantNames}; +use strum_macros::{Display, EnumIter, EnumString, VariantNames}; + +/// The `all` option cannot run all examples end-to-end because the +/// `server` example must run in a separate process before the `client` +/// example can connect. +/// Therefore, `all` only iterates over individually runnable examples. +#[derive(EnumIter, EnumString, Display, VariantNames)] +#[strum(serialize_all = "snake_case")] +enum ExampleKind { + All, + Client, + Server, + SqlServer, +} + +impl ExampleKind { + const EXAMPLE_NAME: &str = "flight"; + + fn runnable() -> impl Iterator { + ExampleKind::iter().filter(|v| !matches!(v, ExampleKind::All)) + } + + async fn run(&self) -> Result<(), Box> { + match self { + ExampleKind::All => { + for example in ExampleKind::runnable() { + println!("Running example: {example}"); + Box::pin(example.run()).await?; + } + } + ExampleKind::Client => client::client().await?, + ExampleKind::Server => server::server().await?, + ExampleKind::SqlServer => sql_server::sql_server().await?, + } + Ok(()) + } +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + let usage = format!( + "Usage: cargo run --example {} -- [{}]", + ExampleKind::EXAMPLE_NAME, + ExampleKind::VARIANTS.join("|") + ); + + let example: ExampleKind = std::env::args() + .nth(1) + .unwrap_or_else(|| ExampleKind::All.to_string()) + .parse() + .map_err(|_| DataFusionError::Execution(format!("Unknown example. 
{usage}")))?; + + example.run().await +} diff --git a/datafusion-examples/examples/flight/flight_server.rs b/datafusion-examples/examples/flight/server.rs similarity index 84% rename from datafusion-examples/examples/flight/flight_server.rs rename to datafusion-examples/examples/flight/server.rs index cc5f43746ddfb..b73c81dd7d2c3 100644 --- a/datafusion-examples/examples/flight/flight_server.rs +++ b/datafusion-examples/examples/flight/server.rs @@ -15,25 +15,26 @@ // specific language governing permissions and limitations // under the License. -use arrow::ipc::writer::{DictionaryTracker, IpcDataGenerator}; +//! See `main.rs` for how to run it. + use std::sync::Arc; +use arrow::ipc::writer::{CompressionContext, DictionaryTracker, IpcDataGenerator}; +use arrow_flight::{ + Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo, + HandshakeRequest, HandshakeResponse, PutResult, SchemaResult, Ticket, + flight_service_server::FlightService, flight_service_server::FlightServiceServer, +}; use arrow_flight::{PollInfo, SchemaAsIpc}; use datafusion::arrow::error::ArrowError; use datafusion::datasource::file_format::parquet::ParquetFormat; use datafusion::datasource::listing::{ListingOptions, ListingTableUrl}; +use datafusion::prelude::*; +use datafusion_examples::utils::{datasets::ExampleDataset, write_csv_to_parquet}; use futures::stream::BoxStream; use tonic::transport::Server; use tonic::{Request, Response, Status, Streaming}; -use datafusion::prelude::*; - -use arrow_flight::{ - flight_service_server::FlightService, flight_service_server::FlightServiceServer, - Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo, - HandshakeRequest, HandshakeResponse, PutResult, SchemaResult, Ticket, -}; - #[derive(Clone)] pub struct FlightServiceImpl {} @@ -83,22 +84,27 @@ impl FlightService for FlightServiceImpl { // create local execution context let ctx = SessionContext::new(); - let testdata = 
datafusion::test_util::parquet_test_data(); + // Convert the CSV input into a temporary Parquet directory for querying + let dataset = ExampleDataset::Cars; + let parquet_temp = write_csv_to_parquet(&ctx, &dataset.path()) + .await + .map_err(|e| { + Status::internal(format!("Error writing csv to parquet: {e}")) + })?; + let parquet_path = parquet_temp.path_str().map_err(|e| { + Status::internal(format!("Error getting parquet path: {e}")) + })?; // register parquet file with the execution context - ctx.register_parquet( - "alltypes_plain", - &format!("{testdata}/alltypes_plain.parquet"), - ParquetReadOptions::default(), - ) - .await - .map_err(to_tonic_err)?; + ctx.register_parquet("cars", parquet_path, ParquetReadOptions::default()) + .await + .map_err(to_tonic_err)?; // create the DataFrame let df = ctx.sql(sql).await.map_err(to_tonic_err)?; // execute the query - let schema = df.schema().clone().into(); + let schema = Arc::clone(df.schema().inner()); let results = df.collect().await.map_err(to_tonic_err)?; if results.is_empty() { return Err(Status::internal("There were no results from ticket")); @@ -106,6 +112,7 @@ impl FlightService for FlightServiceImpl { // add an initial FlightData message that sends schema let options = arrow::ipc::writer::IpcWriteOptions::default(); + let mut compression_context = CompressionContext::default(); let schema_flight_data = SchemaAsIpc::new(&schema, &options); let mut flights = vec![FlightData::from(schema_flight_data)]; @@ -115,7 +122,7 @@ impl FlightService for FlightServiceImpl { for batch in &results { let (flight_dictionaries, flight_batch) = encoder - .encoded_batch(batch, &mut tracker, &options) + .encode(batch, &mut tracker, &options, &mut compression_context) .map_err(|e: ArrowError| Status::internal(e.to_string()))?; flights.extend(flight_dictionaries.into_iter().map(Into::into)); @@ -186,6 +193,7 @@ impl FlightService for FlightServiceImpl { } } +#[expect(clippy::needless_pass_by_value)] fn to_tonic_err(e: 
datafusion::error::DataFusionError) -> Status { Status::internal(format!("{e:?}")) } @@ -193,8 +201,7 @@ fn to_tonic_err(e: datafusion::error::DataFusionError) -> Status { /// This example shows how to wrap DataFusion with `FlightService` to support looking up schema information for /// Parquet files and executing SQL queries against them on a remote server. /// This example is run along-side the example `flight_client`. -#[tokio::main] -async fn main() -> Result<(), Box> { +pub async fn server() -> Result<(), Box> { let addr = "0.0.0.0:50051".parse()?; let service = FlightServiceImpl {}; diff --git a/datafusion-examples/examples/flight/flight_sql_server.rs b/datafusion-examples/examples/flight/sql_server.rs similarity index 93% rename from datafusion-examples/examples/flight/flight_sql_server.rs rename to datafusion-examples/examples/flight/sql_server.rs index 5a573ed52320d..e55aaa7250ea7 100644 --- a/datafusion-examples/examples/flight/flight_sql_server.rs +++ b/datafusion-examples/examples/flight/sql_server.rs @@ -15,6 +15,11 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. + +use std::pin::Pin; +use std::sync::Arc; + use arrow::array::{ArrayRef, StringArray}; use arrow::datatypes::{DataType, Field, Schema}; use arrow::ipc::writer::IpcWriteOptions; @@ -36,12 +41,11 @@ use arrow_flight::{ use dashmap::DashMap; use datafusion::logical_expr::LogicalPlan; use datafusion::prelude::{DataFrame, ParquetReadOptions, SessionConfig, SessionContext}; +use datafusion_examples::utils::{datasets::ExampleDataset, write_csv_to_parquet}; use futures::{Stream, StreamExt, TryStreamExt}; use log::info; use mimalloc::MiMalloc; use prost::Message; -use std::pin::Pin; -use std::sync::Arc; use tonic::metadata::MetadataValue; use tonic::transport::Server; use tonic::{Request, Response, Status, Streaming}; @@ -68,9 +72,7 @@ macro_rules! 
status { /// /// Based heavily on Ballista's implementation: https://github.com/apache/datafusion-ballista/blob/main/ballista/scheduler/src/flight_sql.rs /// and the example in arrow-rs: https://github.com/apache/arrow-rs/blob/master/arrow-flight/examples/flight_sql_server.rs -/// -#[tokio::main] -async fn main() -> Result<(), Box> { +pub async fn sql_server() -> Result<(), Box> { env_logger::init(); let addr = "0.0.0.0:50051".parse()?; let service = FlightSqlServiceImpl { @@ -100,22 +102,24 @@ impl FlightSqlServiceImpl { .with_information_schema(true); let ctx = Arc::new(SessionContext::new_with_config(session_config)); - let testdata = datafusion::test_util::parquet_test_data(); + // Convert the CSV input into a temporary Parquet directory for querying + let dataset = ExampleDataset::Cars; + let parquet_temp = write_csv_to_parquet(&ctx, &dataset.path()) + .await + .map_err(|e| status!("Error writing csv to parquet", e))?; + let parquet_path = parquet_temp + .path_str() + .map_err(|e| status!("Error getting parquet path", e))?; // register parquet file with the execution context - ctx.register_parquet( - "alltypes_plain", - &format!("{testdata}/alltypes_plain.parquet"), - ParquetReadOptions::default(), - ) - .await - .map_err(|e| status!("Error registering table", e))?; + ctx.register_parquet("cars", parquet_path, ParquetReadOptions::default()) + .await + .map_err(|e| status!("Error registering table", e))?; self.contexts.insert(uuid.clone(), ctx); Ok(uuid) } - #[allow(clippy::result_large_err)] fn get_ctx(&self, req: &Request) -> Result, Status> { // get the token from the authorization header on Request let auth = req @@ -141,7 +145,6 @@ impl FlightSqlServiceImpl { } } - #[allow(clippy::result_large_err)] fn get_plan(&self, handle: &str) -> Result { if let Some(plan) = self.statements.get(handle) { Ok(plan.clone()) @@ -150,7 +153,6 @@ impl FlightSqlServiceImpl { } } - #[allow(clippy::result_large_err)] fn get_result(&self, handle: &str) -> Result, Status> { if 
let Some(result) = self.results.get(handle) { Ok(result.clone()) @@ -198,13 +200,11 @@ impl FlightSqlServiceImpl { .unwrap() } - #[allow(clippy::result_large_err)] fn remove_plan(&self, handle: &str) -> Result<(), Status> { self.statements.remove(&handle.to_string()); Ok(()) } - #[allow(clippy::result_large_err)] fn remove_result(&self, handle: &str) -> Result<(), Status> { self.results.remove(&handle.to_string()); Ok(()) @@ -395,10 +395,8 @@ impl FlightSqlService for FlightSqlServiceImpl { let plan_uuid = Uuid::new_v4().hyphenated().to_string(); self.statements.insert(plan_uuid.clone(), plan.clone()); - let plan_schema = plan.schema(); - - let arrow_schema = (&**plan_schema).into(); - let message = SchemaAsIpc::new(&arrow_schema, &IpcWriteOptions::default()) + let arrow_schema = plan.schema().as_arrow(); + let message = SchemaAsIpc::new(arrow_schema, &IpcWriteOptions::default()) .try_into() .map_err(|e| status!("Unable to serialize schema", e))?; let IpcMessage(schema_bytes) = message; @@ -418,7 +416,9 @@ impl FlightSqlService for FlightSqlServiceImpl { ) -> Result<(), Status> { let handle = std::str::from_utf8(&handle.prepared_statement_handle); if let Ok(handle) = handle { - info!("do_action_close_prepared_statement: removing plan and results for {handle}"); + info!( + "do_action_close_prepared_statement: removing plan and results for {handle}" + ); let _ = self.remove_plan(handle); let _ = self.remove_result(handle); } diff --git a/datafusion-examples/examples/composed_extension_codec.rs b/datafusion-examples/examples/proto/composed_extension_codec.rs similarity index 66% rename from datafusion-examples/examples/composed_extension_codec.rs rename to datafusion-examples/examples/proto/composed_extension_codec.rs index 4baefcae507f6..df3d58b7bfb81 100644 --- a/datafusion-examples/examples/composed_extension_codec.rs +++ b/datafusion-examples/examples/proto/composed_extension_codec.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // 
under the License. +//! See `main.rs` for how to run it. +//! //! This example demonstrates how to compose multiple PhysicalExtensionCodecs //! //! This can be helpful when an Execution plan tree has different nodes from different crates @@ -32,20 +34,21 @@ use std::any::Any; use std::fmt::Debug; -use std::ops::Deref; use std::sync::Arc; use datafusion::common::Result; -use datafusion::common::{internal_err, DataFusionError}; -use datafusion::logical_expr::registry::FunctionRegistry; -use datafusion::logical_expr::{AggregateUDF, ScalarUDF}; +use datafusion::common::internal_err; +use datafusion::common::tree_node::TreeNodeRecursion; +use datafusion::execution::TaskContext; use datafusion::physical_plan::{DisplayAs, ExecutionPlan}; use datafusion::prelude::SessionContext; -use datafusion_proto::physical_plan::{AsExecutionPlan, PhysicalExtensionCodec}; +use datafusion_proto::physical_plan::{ + AsExecutionPlan, ComposedPhysicalExtensionCodec, PhysicalExtensionCodec, +}; use datafusion_proto::protobuf; -#[tokio::main] -async fn main() { +/// Example of using multiple extension codecs for serialization / deserialization +pub async fn composed_extension_codec() -> Result<()> { // build execution plan that has both types of nodes // // Note each node requires a different `PhysicalExtensionCodec` to decode @@ -54,29 +57,28 @@ async fn main() { }); let ctx = SessionContext::new(); - let composed_codec = ComposedPhysicalExtensionCodec { - codecs: vec![ - Arc::new(ParentPhysicalExtensionCodec {}), - Arc::new(ChildPhysicalExtensionCodec {}), - ], - }; + // Position in this list is important as it will be used for decoding. + // If new codec is added it should go to last position. 
+ let composed_codec = ComposedPhysicalExtensionCodec::new(vec![ + Arc::new(ParentPhysicalExtensionCodec {}), + Arc::new(ChildPhysicalExtensionCodec {}), + ]); // serialize execution plan to proto let proto: protobuf::PhysicalPlanNode = protobuf::PhysicalPlanNode::try_from_physical_plan( exec_plan.clone(), &composed_codec, - ) - .expect("to proto"); + )?; // deserialize proto back to execution plan - let runtime = ctx.runtime_env(); - let result_exec_plan: Arc = proto - .try_into_physical_plan(&ctx, runtime.deref(), &composed_codec) - .expect("from proto"); + let result_exec_plan: Arc = + proto.try_into_physical_plan(&ctx.task_ctx(), &composed_codec)?; // assert that the original and deserialized execution plans are equal assert_eq!(format!("{exec_plan:?}"), format!("{result_exec_plan:?}")); + + Ok(()) } /// This example has two types of nodes: `ParentExec` and `ChildExec` which can only @@ -105,7 +107,7 @@ impl ExecutionPlan for ParentExec { self } - fn properties(&self) -> &datafusion::physical_plan::PlanProperties { + fn properties(&self) -> &Arc { unreachable!() } @@ -123,10 +125,19 @@ impl ExecutionPlan for ParentExec { fn execute( &self, _partition: usize, - _context: Arc, + _context: Arc, ) -> Result { unreachable!() } + + fn apply_expressions( + &self, + _f: &mut dyn FnMut( + &dyn datafusion::physical_plan::PhysicalExpr, + ) -> Result, + ) -> Result { + Ok(TreeNodeRecursion::Continue) + } } /// A PhysicalExtensionCodec that can serialize and deserialize ParentExec @@ -138,7 +149,7 @@ impl PhysicalExtensionCodec for ParentPhysicalExtensionCodec { &self, buf: &[u8], inputs: &[Arc], - _registry: &dyn FunctionRegistry, + _ctx: &TaskContext, ) -> Result> { if buf == "ParentExec".as_bytes() { Ok(Arc::new(ParentExec { @@ -181,7 +192,7 @@ impl ExecutionPlan for ChildExec { self } - fn properties(&self) -> &datafusion::physical_plan::PlanProperties { + fn properties(&self) -> &Arc { unreachable!() } @@ -199,10 +210,19 @@ impl ExecutionPlan for ChildExec { fn 
execute( &self, _partition: usize, - _context: Arc, + _context: Arc, ) -> Result { unreachable!() } + + fn apply_expressions( + &self, + _f: &mut dyn FnMut( + &dyn datafusion::physical_plan::PhysicalExpr, + ) -> Result, + ) -> Result { + Ok(TreeNodeRecursion::Continue) + } } /// A PhysicalExtensionCodec that can serialize and deserialize ChildExec @@ -214,7 +234,7 @@ impl PhysicalExtensionCodec for ChildPhysicalExtensionCodec { &self, buf: &[u8], _inputs: &[Arc], - _registry: &dyn FunctionRegistry, + _ctx: &TaskContext, ) -> Result> { if buf == "ChildExec".as_bytes() { Ok(Arc::new(ChildExec {})) @@ -232,60 +252,3 @@ impl PhysicalExtensionCodec for ChildPhysicalExtensionCodec { } } } - -/// A PhysicalExtensionCodec that tries one of multiple inner codecs -/// until one works -#[derive(Debug)] -struct ComposedPhysicalExtensionCodec { - codecs: Vec>, -} - -impl ComposedPhysicalExtensionCodec { - fn try_any( - &self, - mut f: impl FnMut(&dyn PhysicalExtensionCodec) -> Result, - ) -> Result { - let mut last_err = None; - for codec in &self.codecs { - match f(codec.as_ref()) { - Ok(node) => return Ok(node), - Err(err) => last_err = Some(err), - } - } - - Err(last_err.unwrap_or_else(|| { - DataFusionError::NotImplemented("Empty list of composed codecs".to_owned()) - })) - } -} - -impl PhysicalExtensionCodec for ComposedPhysicalExtensionCodec { - fn try_decode( - &self, - buf: &[u8], - inputs: &[Arc], - registry: &dyn FunctionRegistry, - ) -> Result> { - self.try_any(|codec| codec.try_decode(buf, inputs, registry)) - } - - fn try_encode(&self, node: Arc, buf: &mut Vec) -> Result<()> { - self.try_any(|codec| codec.try_encode(node.clone(), buf)) - } - - fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result> { - self.try_any(|codec| codec.try_decode_udf(name, buf)) - } - - fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec) -> Result<()> { - self.try_any(|codec| codec.try_encode_udf(node, buf)) - } - - fn try_decode_udaf(&self, name: &str, buf: &[u8]) -> Result> { 
- self.try_any(|codec| codec.try_decode_udaf(name, buf)) - } - - fn try_encode_udaf(&self, node: &AggregateUDF, buf: &mut Vec) -> Result<()> { - self.try_any(|codec| codec.try_encode_udaf(node, buf)) - } -} diff --git a/datafusion-examples/examples/proto/expression_deduplication.rs b/datafusion-examples/examples/proto/expression_deduplication.rs new file mode 100644 index 0000000000000..0dec807f8043a --- /dev/null +++ b/datafusion-examples/examples/proto/expression_deduplication.rs @@ -0,0 +1,275 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! See `main.rs` for how to run it. +//! +//! This example demonstrates how to use the `PhysicalExtensionCodec` trait's +//! interception methods to implement expression deduplication during deserialization. +//! +//! This pattern is inspired by PR #18192, which introduces expression caching +//! to reduce memory usage when deserializing plans with duplicate expressions. +//! +//! The key insight is that identical expressions serialize to identical protobuf bytes. +//! By caching deserialized expressions keyed by their protobuf bytes, we can: +//! 1. Return the same Arc for duplicate expressions +//! 2. Reduce memory allocation during deserialization +//! 3. 
Enable downstream optimizations that rely on Arc pointer equality +//! +//! This demonstrates the decorator pattern enabled by the `PhysicalExtensionCodec` trait, +//! where all expression serialization/deserialization routes through the codec methods. + +use std::collections::HashMap; +use std::fmt::Debug; +use std::sync::{Arc, RwLock}; + +use arrow::datatypes::{DataType, Field, Schema}; +use datafusion::common::Result; +use datafusion::execution::TaskContext; +use datafusion::logical_expr::Operator; +use datafusion::physical_expr::PhysicalExpr; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::physical_plan::expressions::{BinaryExpr, col}; +use datafusion::physical_plan::filter::FilterExec; +use datafusion::physical_plan::placeholder_row::PlaceholderRowExec; +use datafusion::prelude::SessionContext; +use datafusion_proto::physical_plan::from_proto::parse_physical_expr_with_converter; +use datafusion_proto::physical_plan::to_proto::serialize_physical_expr_with_converter; +use datafusion_proto::physical_plan::{ + DefaultPhysicalExtensionCodec, PhysicalExtensionCodec, + PhysicalProtoConverterExtension, +}; +use datafusion_proto::protobuf::{PhysicalExprNode, PhysicalPlanNode}; +use prost::Message; + +/// Example showing how to implement expression deduplication using the codec decorator pattern. +/// +/// This demonstrates: +/// 1. Creating a CachingCodec that caches expressions by their protobuf bytes +/// 2. Intercepting deserialization to return cached Arcs for duplicate expressions +/// 3. Verifying that duplicate expressions share the same Arc after deserialization +/// +/// Deduplication is keyed by the protobuf bytes representing the expression, +/// in reality deduplication could be done based on e.g. the pointer address of the +/// serialized expression in memory, but this is simpler to demonstrate. +/// +/// In this case our expression is trivial and just for demonstration purposes. 
+/// In real scenarios, expressions can be much more complex, e.g. a large InList +/// expression could be megabytes in size, so deduplication can save significant memory +/// in addition to more correctly representing the original plan structure. +pub async fn expression_deduplication() -> Result<()> { + println!("=== Expression Deduplication Example ===\n"); + + // Create a schema for our test expressions + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Boolean, false)])); + + // Step 1: Create expressions with duplicates + println!("Step 1: Creating expressions with duplicates..."); + + // Create expression: col("a") + let a = col("a", &schema)?; + + // Create a clone to show duplicates + let a_clone = Arc::clone(&a); + + // Combine: a OR a_clone + let combined_expr = + Arc::new(BinaryExpr::new(a, Operator::Or, a_clone)) as Arc; + println!(" Created expression: a OR a with duplicates"); + println!(" Note: a appears twice in the expression tree\n"); + // Step 2: Create a filter plan with this expression + println!("Step 2: Creating physical plan with the expression..."); + + let input = Arc::new(PlaceholderRowExec::new(Arc::clone(&schema))); + let filter_plan: Arc = + Arc::new(FilterExec::try_new(combined_expr, input)?); + + println!(" Created FilterExec with duplicate sub-expressions\n"); + + // Step 3: Serialize with the caching codec + println!("Step 3: Serializing plan..."); + + let extension_codec = DefaultPhysicalExtensionCodec {}; + let caching_converter = CachingCodec::new(); + let proto = + caching_converter.execution_plan_to_proto(&filter_plan, &extension_codec)?; + + // Serialize to bytes + let mut bytes = Vec::new(); + proto.encode(&mut bytes).unwrap(); + println!(" Serialized plan to {} bytes\n", bytes.len()); + + // Step 4: Deserialize with the caching codec + println!("Step 4: Deserializing plan with CachingCodec..."); + + let ctx = SessionContext::new(); + let deserialized_plan = proto.try_into_physical_plan_with_converter( + 
&ctx.task_ctx(), + &extension_codec, + &caching_converter, + )?; + + // Step 5: check that we deduplicated expressions + println!("Step 5: Checking for deduplicated expressions..."); + let Some(filter_exec) = deserialized_plan.as_any().downcast_ref::() + else { + panic!("Deserialized plan is not a FilterExec"); + }; + let predicate = Arc::clone(filter_exec.predicate()); + let binary_expr = predicate + .as_any() + .downcast_ref::() + .expect("Predicate is not a BinaryExpr"); + let left = &binary_expr.left(); + let right = &binary_expr.right(); + // Check if left and right point to the same Arc + let deduplicated = Arc::ptr_eq(left, right); + if deduplicated { + println!(" Success: Duplicate expressions were deduplicated!"); + println!( + " Cache Stats: hits={}, misses={}", + caching_converter.stats.read().unwrap().cache_hits, + caching_converter.stats.read().unwrap().cache_misses, + ); + } else { + println!(" Failure: Duplicate expressions were NOT deduplicated."); + } + + Ok(()) +} + +// ============================================================================ +// CachingCodec - Implements expression deduplication +// ============================================================================ + +/// Statistics for cache performance monitoring +#[derive(Debug, Default)] +struct CacheStats { + cache_hits: usize, + cache_misses: usize, +} + +/// A codec that caches deserialized expressions to enable deduplication. +/// +/// When deserializing, if we've already seen the same protobuf bytes, +/// we return the cached Arc instead of creating a new allocation. 
+#[derive(Debug, Default)] +struct CachingCodec { + /// Cache mapping protobuf bytes -> deserialized expression + expr_cache: RwLock, Arc>>, + /// Statistics for demonstration + stats: RwLock, +} + +impl CachingCodec { + fn new() -> Self { + Self::default() + } +} + +impl PhysicalExtensionCodec for CachingCodec { + // Required: decode custom extension nodes + fn try_decode( + &self, + _buf: &[u8], + _inputs: &[Arc], + _ctx: &TaskContext, + ) -> Result> { + datafusion::common::not_impl_err!("No custom extension nodes") + } + + // Required: encode custom execution plans + fn try_encode( + &self, + _node: Arc, + _buf: &mut Vec, + ) -> Result<()> { + datafusion::common::not_impl_err!("No custom extension nodes") + } +} + +impl PhysicalProtoConverterExtension for CachingCodec { + fn proto_to_execution_plan( + &self, + ctx: &TaskContext, + extension_codec: &dyn PhysicalExtensionCodec, + proto: &PhysicalPlanNode, + ) -> Result> { + proto.try_into_physical_plan_with_converter(ctx, extension_codec, self) + } + + fn execution_plan_to_proto( + &self, + plan: &Arc, + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result { + PhysicalPlanNode::try_from_physical_plan_with_converter( + Arc::clone(plan), + extension_codec, + self, + ) + } + + // CACHING IMPLEMENTATION: Intercept expression deserialization + fn proto_to_physical_expr( + &self, + proto: &PhysicalExprNode, + ctx: &TaskContext, + input_schema: &Schema, + codec: &dyn PhysicalExtensionCodec, + ) -> Result> { + // Create cache key from protobuf bytes + let mut key = Vec::new(); + proto.encode(&mut key).map_err(|e| { + datafusion::error::DataFusionError::Internal(format!( + "Failed to encode proto for cache key: {e}" + )) + })?; + + // Check cache first + { + let cache = self.expr_cache.read().unwrap(); + if let Some(cached) = cache.get(&key) { + // Cache hit! 
Update stats and return cached Arc + let mut stats = self.stats.write().unwrap(); + stats.cache_hits += 1; + return Ok(Arc::clone(cached)); + } + } + + // Cache miss - deserialize and store + let expr = + parse_physical_expr_with_converter(proto, ctx, input_schema, codec, self)?; + + // Store in cache + { + let mut cache = self.expr_cache.write().unwrap(); + cache.insert(key, Arc::clone(&expr)); + let mut stats = self.stats.write().unwrap(); + stats.cache_misses += 1; + } + + Ok(expr) + } + + fn physical_expr_to_proto( + &self, + expr: &Arc, + codec: &dyn PhysicalExtensionCodec, + ) -> Result { + serialize_physical_expr_with_converter(expr, codec, self) + } +} diff --git a/datafusion-examples/examples/proto/main.rs b/datafusion-examples/examples/proto/main.rs new file mode 100644 index 0000000000000..3f525b5d46afa --- /dev/null +++ b/datafusion-examples/examples/proto/main.rs @@ -0,0 +1,92 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! # Examples demonstrating DataFusion's plan serialization via the `datafusion-proto` crate +//! +//! These examples show how to use multiple extension codecs for serialization / deserialization. +//! +//! ## Usage +//! ```bash +//! 
cargo run --example proto -- [all|composed_extension_codec|expression_deduplication] +//! ``` +//! +//! Each subcommand runs a corresponding example: +//! - `all` — run all examples included in this module +//! +//! - `composed_extension_codec` +//! (file: composed_extension_codec.rs, desc: Use multiple extension codecs for serialization/deserialization) +//! +//! - `expression_deduplication` +//! (file: expression_deduplication.rs, desc: Example of expression caching/deduplication using the codec decorator pattern) + +mod composed_extension_codec; +mod expression_deduplication; + +use datafusion::error::{DataFusionError, Result}; +use strum::{IntoEnumIterator, VariantNames}; +use strum_macros::{Display, EnumIter, EnumString, VariantNames}; + +#[derive(EnumIter, EnumString, Display, VariantNames)] +#[strum(serialize_all = "snake_case")] +enum ExampleKind { + All, + ComposedExtensionCodec, + ExpressionDeduplication, +} + +impl ExampleKind { + const EXAMPLE_NAME: &str = "proto"; + + fn runnable() -> impl Iterator { + ExampleKind::iter().filter(|v| !matches!(v, ExampleKind::All)) + } + + async fn run(&self) -> Result<()> { + match self { + ExampleKind::All => { + for example in ExampleKind::runnable() { + println!("Running example: {example}"); + Box::pin(example.run()).await?; + } + } + ExampleKind::ComposedExtensionCodec => { + composed_extension_codec::composed_extension_codec().await? + } + ExampleKind::ExpressionDeduplication => { + expression_deduplication::expression_deduplication().await? + } + } + Ok(()) + } +} + +#[tokio::main] +async fn main() -> Result<()> { + let usage = format!( + "Usage: cargo run --example {} -- [{}]", + ExampleKind::EXAMPLE_NAME, + ExampleKind::VARIANTS.join("|") + ); + + let example: ExampleKind = std::env::args() + .nth(1) + .unwrap_or_else(|| ExampleKind::All.to_string()) + .parse() + .map_err(|_| DataFusionError::Execution(format!("Unknown example. 
{usage}")))?; + + example.run().await +} diff --git a/datafusion-examples/examples/analyzer_rule.rs b/datafusion-examples/examples/query_planning/analyzer_rule.rs similarity index 97% rename from datafusion-examples/examples/analyzer_rule.rs rename to datafusion-examples/examples/query_planning/analyzer_rule.rs index cb81cd167a88b..a86f5cdd2a5e3 100644 --- a/datafusion-examples/examples/analyzer_rule.rs +++ b/datafusion-examples/examples/query_planning/analyzer_rule.rs @@ -15,11 +15,13 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. + use arrow::array::{ArrayRef, Int32Array, RecordBatch, StringArray}; +use datafusion::common::Result; use datafusion::common::config::ConfigOptions; use datafusion::common::tree_node::{Transformed, TreeNode}; -use datafusion::common::Result; -use datafusion::logical_expr::{col, lit, Expr, LogicalPlan, LogicalPlanBuilder}; +use datafusion::logical_expr::{Expr, LogicalPlan, LogicalPlanBuilder, col, lit}; use datafusion::optimizer::analyzer::AnalyzerRule; use datafusion::prelude::SessionContext; use std::sync::{Arc, Mutex}; @@ -35,8 +37,7 @@ use std::sync::{Arc, Mutex}; /// level access control scheme by introducing a filter to the query. /// /// See [optimizer_rule.rs] for an example of a optimizer rule -#[tokio::main] -pub async fn main() -> Result<()> { +pub async fn analyzer_rule() -> Result<()> { // AnalyzerRules run before OptimizerRules. 
// // DataFusion includes several built in AnalyzerRules for tasks such as type diff --git a/datafusion-examples/examples/expr_api.rs b/datafusion-examples/examples/query_planning/expr_api.rs similarity index 94% rename from datafusion-examples/examples/expr_api.rs rename to datafusion-examples/examples/query_planning/expr_api.rs index 089b8db6a5a06..386273c72817b 100644 --- a/datafusion-examples/examples/expr_api.rs +++ b/datafusion-examples/examples/query_planning/expr_api.rs @@ -15,10 +15,12 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. + use std::collections::HashMap; use std::sync::Arc; -use arrow::array::{BooleanArray, Int32Array, Int8Array}; +use arrow::array::{BooleanArray, Int8Array, Int32Array}; use arrow::record_batch::RecordBatch; use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit}; @@ -35,7 +37,7 @@ use datafusion::logical_expr::simplify::SimplifyContext; use datafusion::logical_expr::{ColumnarValue, ExprFunctionExt, ExprSchemable, Operator}; use datafusion::optimizer::analyzer::type_coercion::TypeCoercionRewriter; use datafusion::optimizer::simplify_expressions::ExprSimplifier; -use datafusion::physical_expr::{analyze, AnalysisContext, ExprBoundaries}; +use datafusion::physical_expr::{AnalysisContext, ExprBoundaries, analyze}; use datafusion::prelude::*; /// This example demonstrates the DataFusion [`Expr`] API. @@ -55,8 +57,7 @@ use datafusion::prelude::*; /// 5. Analyze predicates for boundary ranges: [`range_analysis_demo`] /// 6. Get the types of the expressions: [`expression_type_demo`] /// 7. 
Apply type coercion to expressions: [`type_coercion_demo`] -#[tokio::main] -async fn main() -> Result<()> { +pub async fn expr_api() -> Result<()> { // The easiest way to do create expressions is to use the // "fluent"-style API: let expr = col("a") + lit(5); @@ -65,7 +66,7 @@ async fn main() -> Result<()> { let expr2 = Expr::BinaryExpr(BinaryExpr::new( Box::new(col("a")), Operator::Plus, - Box::new(Expr::Literal(ScalarValue::Int32(Some(5)))), + Box::new(Expr::Literal(ScalarValue::Int32(Some(5)), None)), )); assert_eq!(expr, expr2); @@ -85,7 +86,7 @@ async fn main() -> Result<()> { boundary_analysis_and_selectivity_demo()?; // See how boundary analysis works for `AND` & `OR` conjunctions. - boundary_analysis_in_conjuctions_demo()?; + boundary_analysis_in_conjunctions_demo()?; // See how to determine the data types of expressions expression_type_demo()?; @@ -174,8 +175,9 @@ fn simplify_demo() -> Result<()> { // the ExecutionProps carries information needed to simplify // expressions, such as the current time (to evaluate `now()` // correctly) - let props = ExecutionProps::new(); - let context = SimplifyContext::new(&props).with_schema(schema); + let context = SimplifyContext::default() + .with_schema(schema) + .with_current_time(); let simplifier = ExprSimplifier::new(context); // And then call the simplify_expr function: @@ -190,7 +192,9 @@ fn simplify_demo() -> Result<()> { // here are some other examples of what DataFusion is capable of let schema = Schema::new(vec![make_field("i", DataType::Int64)]).to_dfschema_ref()?; - let context = SimplifyContext::new(&props).with_schema(schema.clone()); + let context = SimplifyContext::default() + .with_schema(Arc::clone(&schema)) + .with_current_time(); let simplifier = ExprSimplifier::new(context); // basic arithmetic simplification @@ -302,6 +306,7 @@ fn boundary_analysis_and_selectivity_demo() -> Result<()> { min_value: Precision::Exact(ScalarValue::Int64(Some(1))), sum_value: Precision::Absent, distinct_count: 
Precision::Absent, + byte_size: Precision::Absent, }; // We can then build our expression boundaries from the column statistics @@ -342,16 +347,18 @@ fn boundary_analysis_and_selectivity_demo() -> Result<()> { // // (a' - b' + 1) / (a - b) // (10000 - 5000 + 1) / (10000 - 1) - assert!(analysis - .selectivity - .is_some_and(|selectivity| (0.5..=0.6).contains(&selectivity))); + assert!( + analysis + .selectivity + .is_some_and(|selectivity| (0.5..=0.6).contains(&selectivity)) + ); Ok(()) } /// This function shows how to think about and leverage the analysis API /// to infer boundaries in `AND` & `OR` conjunctions. -fn boundary_analysis_in_conjuctions_demo() -> Result<()> { +fn boundary_analysis_in_conjunctions_demo() -> Result<()> { // Let us consider the more common case of AND & OR conjunctions. // // age > 18 AND age <= 25 @@ -369,6 +376,7 @@ fn boundary_analysis_in_conjuctions_demo() -> Result<()> { min_value: Precision::Exact(ScalarValue::Int64(Some(14))), sum_value: Precision::Absent, distinct_count: Precision::Absent, + byte_size: Precision::Absent, }; let initial_boundaries = @@ -414,9 +422,11 @@ fn boundary_analysis_in_conjuctions_demo() -> Result<()> { // // Granted a column such as age will more likely follow a Normal distribution // as such our selectivity estimation will not be as good as it can. - assert!(analysis - .selectivity - .is_some_and(|selectivity| (0.1..=0.2).contains(&selectivity))); + assert!( + analysis + .selectivity + .is_some_and(|selectivity| (0.1..=0.2).contains(&selectivity)) + ); // The above example was a good way to look at how we can derive better // interval and get a lower selectivity during boundary analysis. 
@@ -519,7 +529,7 @@ fn type_coercion_demo() -> Result<()> { )?; let i8_array = Int8Array::from_iter_values(vec![0, 1, 2]); let batch = RecordBatch::try_new( - Arc::new(df_schema.as_arrow().to_owned()), + Arc::clone(df_schema.inner()), vec![Arc::new(i8_array) as _], )?; @@ -532,10 +542,11 @@ fn type_coercion_demo() -> Result<()> { let physical_expr = datafusion::physical_expr::create_physical_expr(&expr, &df_schema, &props)?; let e = physical_expr.evaluate(&batch).unwrap_err(); - assert!(e - .find_root() - .to_string() - .contains("Invalid comparison operation: Int8 > Int32")); + assert!( + e.find_root() + .to_string() + .contains("Invalid comparison operation: Int8 > Int32") + ); // 1. Type coercion with `SessionContext::create_physical_expr` which implicitly applies type coercion before constructing the physical expr. let physical_expr = @@ -543,7 +554,9 @@ fn type_coercion_demo() -> Result<()> { assert!(physical_expr.evaluate(&batch).is_ok()); // 2. Type coercion with `ExprSimplifier::coerce`. - let context = SimplifyContext::new(&props).with_schema(Arc::new(df_schema.clone())); + let context = SimplifyContext::default() + .with_schema(Arc::new(df_schema.clone())) + .with_current_time(); let simplifier = ExprSimplifier::new(context); let coerced_expr = simplifier.coerce(expr.clone(), &df_schema)?; let physical_expr = datafusion::physical_expr::create_physical_expr( diff --git a/datafusion-examples/examples/query_planning/main.rs b/datafusion-examples/examples/query_planning/main.rs new file mode 100644 index 0000000000000..d3f99aedceb3d --- /dev/null +++ b/datafusion-examples/examples/query_planning/main.rs @@ -0,0 +1,124 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! # These are all internal mechanics of the query planning and optimization layers +//! +//! These examples demonstrate internal mechanics of the query planning and optimization layers. +//! +//! ## Usage +//! ```bash +//! cargo run --example query_planning -- [all|analyzer_rule|expr_api|optimizer_rule|parse_sql_expr|plan_to_sql|planner_api|pruning|thread_pools] +//! ``` +//! +//! Each subcommand runs a corresponding example: +//! - `all` — run all examples included in this module +//! +//! - `analyzer_rule` +//! (file: analyzer_rule.rs, desc: Custom AnalyzerRule to change query semantics) +//! +//! - `expr_api` +//! (file: expr_api.rs, desc: Create, execute, analyze, and coerce Exprs) +//! +//! - `optimizer_rule` +//! (file: optimizer_rule.rs, desc: Replace predicates via a custom OptimizerRule) +//! +//! - `parse_sql_expr` +//! (file: parse_sql_expr.rs, desc: Parse SQL into DataFusion Expr) +//! +//! - `plan_to_sql` +//! (file: plan_to_sql.rs, desc: Generate SQL from expressions or plans) +//! +//! - `planner_api` +//! (file: planner_api.rs, desc: APIs for logical and physical plan manipulation) +//! +//! - `pruning` +//! (file: pruning.rs, desc: Use pruning to skip irrelevant files) +//! +//! - `thread_pools` +//! 
(file: thread_pools.rs, desc: Configure custom thread pools for DataFusion execution) + +mod analyzer_rule; +mod expr_api; +mod optimizer_rule; +mod parse_sql_expr; +mod plan_to_sql; +mod planner_api; +mod pruning; +mod thread_pools; + +use datafusion::error::{DataFusionError, Result}; +use strum::{IntoEnumIterator, VariantNames}; +use strum_macros::{Display, EnumIter, EnumString, VariantNames}; + +#[derive(EnumIter, EnumString, Display, VariantNames)] +#[strum(serialize_all = "snake_case")] +enum ExampleKind { + All, + AnalyzerRule, + ExprApi, + OptimizerRule, + ParseSqlExpr, + PlanToSql, + PlannerApi, + Pruning, + ThreadPools, +} + +impl ExampleKind { + const EXAMPLE_NAME: &str = "query_planning"; + + fn runnable() -> impl Iterator { + ExampleKind::iter().filter(|v| !matches!(v, ExampleKind::All)) + } + + async fn run(&self) -> Result<()> { + match self { + ExampleKind::All => { + for example in ExampleKind::runnable() { + println!("Running example: {example}"); + Box::pin(example.run()).await?; + } + } + ExampleKind::AnalyzerRule => analyzer_rule::analyzer_rule().await?, + ExampleKind::ExprApi => expr_api::expr_api().await?, + ExampleKind::OptimizerRule => optimizer_rule::optimizer_rule().await?, + ExampleKind::ParseSqlExpr => parse_sql_expr::parse_sql_expr().await?, + ExampleKind::PlanToSql => plan_to_sql::plan_to_sql_examples().await?, + ExampleKind::PlannerApi => planner_api::planner_api().await?, + ExampleKind::Pruning => pruning::pruning().await?, + ExampleKind::ThreadPools => thread_pools::thread_pools().await?, + } + Ok(()) + } +} + +#[tokio::main] +async fn main() -> Result<()> { + let usage = format!( + "Usage: cargo run --example {} -- [{}]", + ExampleKind::EXAMPLE_NAME, + ExampleKind::VARIANTS.join("|") + ); + + let example: ExampleKind = std::env::args() + .nth(1) + .unwrap_or_else(|| ExampleKind::All.to_string()) + .parse() + .map_err(|_| DataFusionError::Execution(format!("Unknown example. 
{usage}")))?; + + example.run().await +} diff --git a/datafusion-examples/examples/optimizer_rule.rs b/datafusion-examples/examples/query_planning/optimizer_rule.rs similarity index 96% rename from datafusion-examples/examples/optimizer_rule.rs rename to datafusion-examples/examples/query_planning/optimizer_rule.rs index 63f17484809e2..de9de7737a6a0 100644 --- a/datafusion-examples/examples/optimizer_rule.rs +++ b/datafusion-examples/examples/query_planning/optimizer_rule.rs @@ -15,10 +15,12 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. + use arrow::array::{ArrayRef, Int32Array, RecordBatch, StringArray}; use arrow::datatypes::DataType; use datafusion::common::tree_node::{Transformed, TreeNode}; -use datafusion::common::{assert_batches_eq, Result, ScalarValue}; +use datafusion::common::{Result, ScalarValue, assert_batches_eq}; use datafusion::logical_expr::{ BinaryExpr, ColumnarValue, Expr, LogicalPlan, Operator, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility, @@ -37,8 +39,7 @@ use std::sync::Arc; /// /// See [analyzer_rule.rs] for an example of AnalyzerRules, which are for /// changing plan semantics. -#[tokio::main] -pub async fn main() -> Result<()> { +pub async fn optimizer_rule() -> Result<()> { // DataFusion includes many built in OptimizerRules for tasks such as outer // to inner join conversion and constant folding. 
// @@ -171,11 +172,11 @@ fn is_binary_eq(binary_expr: &BinaryExpr) -> bool { /// Return true if the expression is a literal or column reference fn is_lit_or_col(expr: &Expr) -> bool { - matches!(expr, Expr::Column(_) | Expr::Literal(_)) + matches!(expr, Expr::Column(_) | Expr::Literal(_, _)) } /// A simple user defined filter function -#[derive(Debug, Clone)] +#[derive(Debug, PartialEq, Eq, Hash, Clone)] struct MyEq { signature: Signature, } diff --git a/datafusion-examples/examples/parse_sql_expr.rs b/datafusion-examples/examples/query_planning/parse_sql_expr.rs similarity index 68% rename from datafusion-examples/examples/parse_sql_expr.rs rename to datafusion-examples/examples/query_planning/parse_sql_expr.rs index 5387e7c4a05dc..74072b8480f99 100644 --- a/datafusion-examples/examples/parse_sql_expr.rs +++ b/datafusion-examples/examples/query_planning/parse_sql_expr.rs @@ -15,8 +15,11 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. + use arrow::datatypes::{DataType, Field, Schema}; use datafusion::common::DFSchema; +use datafusion::common::ScalarValue; use datafusion::logical_expr::{col, lit}; use datafusion::sql::unparser::Unparser; use datafusion::{ @@ -24,6 +27,7 @@ use datafusion::{ error::Result, prelude::{ParquetReadOptions, SessionContext}, }; +use datafusion_examples::utils::{datasets::ExampleDataset, write_csv_to_parquet}; /// This example demonstrates the programmatic parsing of SQL expressions using /// the DataFusion [`SessionContext::parse_sql_expr`] API or the [`DataFrame::parse_sql_expr`] API. @@ -32,17 +36,15 @@ use datafusion::{ /// The code in this example shows how to: /// /// 1. [`simple_session_context_parse_sql_expr_demo`]: Parse a simple SQL text into a logical -/// expression using a schema at [`SessionContext`]. +/// expression using a schema at [`SessionContext`]. /// /// 2. 
[`simple_dataframe_parse_sql_expr_demo`]: Parse a simple SQL text into a logical expression -/// using a schema at [`DataFrame`]. +/// using a schema at [`DataFrame`]. /// /// 3. [`query_parquet_demo`]: Query a parquet file using the parsed_sql_expr from a DataFrame. /// /// 4. [`round_trip_parse_sql_expr_demo`]: Parse a SQL text and convert it back to SQL using [`Unparser`]. - -#[tokio::main] -async fn main() -> Result<()> { +pub async fn parse_sql_expr() -> Result<()> { // See how to evaluate expressions simple_session_context_parse_sql_expr_demo()?; simple_dataframe_parse_sql_expr_demo().await?; @@ -70,18 +72,19 @@ fn simple_session_context_parse_sql_expr_demo() -> Result<()> { /// DataFusion can parse a SQL text to an logical expression using schema at [`DataFrame`]. async fn simple_dataframe_parse_sql_expr_demo() -> Result<()> { - let sql = "int_col < 5 OR double_col = 8.0"; - let expr = col("int_col") - .lt(lit(5_i64)) - .or(col("double_col").eq(lit(8.0_f64))); + let sql = "car = 'red' OR speed > 1.0"; + let expr = col("car") + .eq(lit(ScalarValue::Utf8(Some("red".to_string())))) + .or(col("speed").gt(lit(1.0_f64))); let ctx = SessionContext::new(); - let testdata = datafusion::test_util::parquet_test_data(); + + // Convert the CSV input into a temporary Parquet directory for querying + let dataset = ExampleDataset::Cars; + let parquet_temp = write_csv_to_parquet(&ctx, &dataset.path()).await?; + let df = ctx - .read_parquet( - &format!("{testdata}/alltypes_plain.parquet"), - ParquetReadOptions::default(), - ) + .read_parquet(parquet_temp.path_str()?, ParquetReadOptions::default()) .await?; let parsed_expr = df.parse_sql_expr(sql)?; @@ -93,39 +96,37 @@ async fn simple_dataframe_parse_sql_expr_demo() -> Result<()> { async fn query_parquet_demo() -> Result<()> { let ctx = SessionContext::new(); - let testdata = datafusion::test_util::parquet_test_data(); + + // Convert the CSV input into a temporary Parquet directory for querying + let dataset = 
ExampleDataset::Cars; + let parquet_temp = write_csv_to_parquet(&ctx, &dataset.path()).await?; + let df = ctx - .read_parquet( - &format!("{testdata}/alltypes_plain.parquet"), - ParquetReadOptions::default(), - ) + .read_parquet(parquet_temp.path_str()?, ParquetReadOptions::default()) .await?; let df = df .clone() - .select(vec![ - df.parse_sql_expr("int_col")?, - df.parse_sql_expr("double_col")?, - ])? - .filter(df.parse_sql_expr("int_col < 5 OR double_col = 8.0")?)? + .select(vec![df.parse_sql_expr("car")?, df.parse_sql_expr("speed")?])? + .filter(df.parse_sql_expr("car = 'red' OR speed > 1.0")?)? .aggregate( - vec![df.parse_sql_expr("double_col")?], - vec![df.parse_sql_expr("SUM(int_col) as sum_int_col")?], + vec![df.parse_sql_expr("car")?], + vec![df.parse_sql_expr("SUM(speed) as sum_speed")?], )? // Directly parsing the SQL text into a sort expression is not supported yet, so // construct it programmatically - .sort(vec![col("double_col").sort(false, false)])? + .sort(vec![col("car").sort(false, false)])? .limit(0, Some(1))?; let result = df.collect().await?; assert_batches_eq!( &[ - "+------------+-------------+", - "| double_col | sum_int_col |", - "+------------+-------------+", - "| 10.1 | 4 |", - "+------------+-------------+", + "+-----+--------------------+", + "| car | sum_speed |", + "+-----+--------------------+", + "| red | 162.49999999999997 |", + "+-----+--------------------+" ], &result ); @@ -135,15 +136,16 @@ async fn query_parquet_demo() -> Result<()> { /// DataFusion can parse a SQL text and convert it back to SQL using [`Unparser`]. 
async fn round_trip_parse_sql_expr_demo() -> Result<()> { - let sql = "((int_col < 5) OR (double_col = 8))"; + let sql = "((car = 'red') OR (speed > 1.0))"; let ctx = SessionContext::new(); - let testdata = datafusion::test_util::parquet_test_data(); + + // Convert the CSV input into a temporary Parquet directory for querying + let dataset = ExampleDataset::Cars; + let parquet_temp = write_csv_to_parquet(&ctx, &dataset.path()).await?; + let df = ctx - .read_parquet( - &format!("{testdata}/alltypes_plain.parquet"), - ParquetReadOptions::default(), - ) + .read_parquet(parquet_temp.path_str()?, ParquetReadOptions::default()) .await?; let parsed_expr = df.parse_sql_expr(sql)?; @@ -158,7 +160,7 @@ async fn round_trip_parse_sql_expr_demo() -> Result<()> { // difference in precedence rules between DataFusion and target engines. let unparser = Unparser::default().with_pretty(true); - let pretty = "int_col < 5 OR double_col = 8"; + let pretty = "car = 'red' OR speed > 1.0"; let pretty_round_trip_sql = unparser.expr_to_sql(&parsed_expr)?.to_string(); assert_eq!(pretty, pretty_round_trip_sql); diff --git a/datafusion-examples/examples/plan_to_sql.rs b/datafusion-examples/examples/query_planning/plan_to_sql.rs similarity index 77% rename from datafusion-examples/examples/plan_to_sql.rs rename to datafusion-examples/examples/query_planning/plan_to_sql.rs index 54483b143a169..86aebbc0b2c33 100644 --- a/datafusion-examples/examples/plan_to_sql.rs +++ b/datafusion-examples/examples/query_planning/plan_to_sql.rs @@ -15,7 +15,13 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. 
+ +use std::fmt; +use std::sync::Arc; + use datafusion::common::DFSchemaRef; +use datafusion::common::ScalarValue; use datafusion::error::Result; use datafusion::logical_expr::sqlparser::ast::Statement; use datafusion::logical_expr::{ @@ -32,9 +38,8 @@ use datafusion::sql::unparser::extension_unparser::UserDefinedLogicalNodeUnparse use datafusion::sql::unparser::extension_unparser::{ UnparseToStatementResult, UnparseWithinStatementResult, }; -use datafusion::sql::unparser::{plan_to_sql, Unparser}; -use std::fmt; -use std::sync::Arc; +use datafusion::sql::unparser::{Unparser, plan_to_sql}; +use datafusion_examples::utils::{datasets::ExampleDataset, write_csv_to_parquet}; /// This example demonstrates the programmatic construction of SQL strings using /// the DataFusion Expr [`Expr`] and LogicalPlan [`LogicalPlan`] API. @@ -43,28 +48,26 @@ use std::sync::Arc; /// The code in this example shows how to: /// /// 1. [`simple_expr_to_sql_demo`]: Create a simple expression [`Exprs`] with -/// fluent API and convert to sql suitable for passing to another database +/// fluent API and convert to sql suitable for passing to another database /// /// 2. [`simple_expr_to_pretty_sql_demo`] Create a simple expression -/// [`Exprs`] with fluent API and convert to sql without extra parentheses, -/// suitable for displaying to humans +/// [`Exprs`] with fluent API and convert to sql without extra parentheses, +/// suitable for displaying to humans /// /// 3. [`simple_expr_to_sql_demo_escape_mysql_style`]" Create a simple -/// expression [`Exprs`] with fluent API and convert to sql escaping column -/// names in MySQL style. +/// expression [`Exprs`] with fluent API and convert to sql escaping column +/// names in MySQL style. /// /// 4. [`simple_plan_to_sql_demo`]: Create a simple logical plan using the -/// DataFrames API and convert to sql string. +/// DataFrames API and convert to sql string. /// /// 5. 
[`round_trip_plan_to_sql_demo`]: Create a logical plan from a SQL string, modify it using the -/// DataFrames API and convert it back to a sql string. +/// DataFrames API and convert it back to a sql string. /// /// 6. [`unparse_my_logical_plan_as_statement`]: Create a custom logical plan and unparse it as a statement. /// /// 7. [`unparse_my_logical_plan_as_subquery`]: Create a custom logical plan and unparse it as a subquery. - -#[tokio::main] -async fn main() -> Result<()> { +pub async fn plan_to_sql_examples() -> Result<()> { // See how to evaluate expressions simple_expr_to_sql_demo()?; simple_expr_to_pretty_sql_demo()?; @@ -114,21 +117,21 @@ fn simple_expr_to_sql_demo_escape_mysql_style() -> Result<()> { async fn simple_plan_to_sql_demo() -> Result<()> { let ctx = SessionContext::new(); - let testdata = datafusion::test_util::parquet_test_data(); + // Convert the CSV input into a temporary Parquet directory for querying + let dataset = ExampleDataset::Cars; + let parquet_temp = write_csv_to_parquet(&ctx, &dataset.path()).await?; + let df = ctx - .read_parquet( - &format!("{testdata}/alltypes_plain.parquet"), - ParquetReadOptions::default(), - ) + .read_parquet(parquet_temp.path_str()?, ParquetReadOptions::default()) .await? 
- .select_columns(&["id", "int_col", "double_col", "date_string_col"])?; + .select_columns(&["car", "speed", "time"])?; // Convert the data frame to a SQL string let sql = plan_to_sql(df.logical_plan())?.to_string(); assert_eq!( sql, - r#"SELECT "?table?".id, "?table?".int_col, "?table?".double_col, "?table?".date_string_col FROM "?table?""# + r#"SELECT "?table?".car, "?table?".speed, "?table?"."time" FROM "?table?""# ); Ok(()) @@ -139,35 +142,35 @@ async fn simple_plan_to_sql_demo() -> Result<()> { async fn round_trip_plan_to_sql_demo() -> Result<()> { let ctx = SessionContext::new(); - let testdata = datafusion::test_util::parquet_test_data(); + // Convert the CSV input into a temporary Parquet directory for querying + let dataset = ExampleDataset::Cars; + let parquet_temp = write_csv_to_parquet(&ctx, &dataset.path()).await?; // register parquet file with the execution context ctx.register_parquet( - "alltypes_plain", - &format!("{testdata}/alltypes_plain.parquet"), + "cars", + parquet_temp.path_str()?, ParquetReadOptions::default(), ) .await?; // create a logical plan from a SQL string and then programmatically add new filters + // select car, speed, time from cars where speed > 1 and car = 'red' let df = ctx // Use SQL to read some data from the parquet file - .sql( - "SELECT int_col, double_col, CAST(date_string_col as VARCHAR) \ - FROM alltypes_plain", - ) + .sql("SELECT car, speed, time FROM cars") .await? 
- // Add id > 1 and tinyint_col < double_col filter + // Add speed > 1 and car = 'red' filter .filter( - col("id") + col("speed") .gt(lit(1)) - .and(col("tinyint_col").lt(col("double_col"))), + .and(col("car").eq(lit(ScalarValue::Utf8(Some("red".to_string()))))), )?; let sql = plan_to_sql(df.logical_plan())?.to_string(); assert_eq!( sql, - r#"SELECT alltypes_plain.int_col, alltypes_plain.double_col, CAST(alltypes_plain.date_string_col AS VARCHAR) FROM alltypes_plain WHERE ((alltypes_plain.id > 1) AND (alltypes_plain.tinyint_col < alltypes_plain.double_col))"# + r#"SELECT cars.car, cars.speed, cars."time" FROM cars WHERE ((cars.speed > 1) AND (cars.car = 'red'))"# ); Ok(()) @@ -211,6 +214,7 @@ impl UserDefinedLogicalNodeCore for MyLogicalPlan { } struct PlanToStatement {} + impl UserDefinedLogicalNodeUnparser for PlanToStatement { fn unparse_to_statement( &self, @@ -231,14 +235,15 @@ impl UserDefinedLogicalNodeUnparser for PlanToStatement { /// It can be unparse as a statement that reads from the same parquet file. async fn unparse_my_logical_plan_as_statement() -> Result<()> { let ctx = SessionContext::new(); - let testdata = datafusion::test_util::parquet_test_data(); + + // Convert the CSV input into a temporary Parquet directory for querying + let dataset = ExampleDataset::Cars; + let parquet_temp = write_csv_to_parquet(&ctx, &dataset.path()).await?; + let inner_plan = ctx - .read_parquet( - &format!("{testdata}/alltypes_plain.parquet"), - ParquetReadOptions::default(), - ) + .read_parquet(parquet_temp.path_str()?, ParquetReadOptions::default()) .await? - .select_columns(&["id", "int_col", "double_col", "date_string_col"])? + .select_columns(&["car", "speed", "time"])? 
.into_unoptimized_plan(); let node = Arc::new(MyLogicalPlan { input: inner_plan }); @@ -249,7 +254,7 @@ async fn unparse_my_logical_plan_as_statement() -> Result<()> { let sql = unparser.plan_to_sql(&my_plan)?.to_string(); assert_eq!( sql, - r#"SELECT "?table?".id, "?table?".int_col, "?table?".double_col, "?table?".date_string_col FROM "?table?""# + r#"SELECT "?table?".car, "?table?".speed, "?table?"."time" FROM "?table?""# ); Ok(()) } @@ -284,14 +289,15 @@ impl UserDefinedLogicalNodeUnparser for PlanToSubquery { /// It can be unparse as a subquery that reads from the same parquet file, with some columns projected. async fn unparse_my_logical_plan_as_subquery() -> Result<()> { let ctx = SessionContext::new(); - let testdata = datafusion::test_util::parquet_test_data(); + + // Convert the CSV input into a temporary Parquet directory for querying + let dataset = ExampleDataset::Cars; + let parquet_temp = write_csv_to_parquet(&ctx, &dataset.path()).await?; + let inner_plan = ctx - .read_parquet( - &format!("{testdata}/alltypes_plain.parquet"), - ParquetReadOptions::default(), - ) + .read_parquet(parquet_temp.path_str()?, ParquetReadOptions::default()) .await? - .select_columns(&["id", "int_col", "double_col", "date_string_col"])? + .select_columns(&["car", "speed", "time"])? .into_unoptimized_plan(); let node = Arc::new(MyLogicalPlan { input: inner_plan }); @@ -299,8 +305,8 @@ async fn unparse_my_logical_plan_as_subquery() -> Result<()> { let my_plan = LogicalPlan::Extension(Extension { node }); let plan = LogicalPlanBuilder::from(my_plan) .project(vec![ - col("id").alias("my_id"), - col("int_col").alias("my_int"), + col("car").alias("my_car"), + col("speed").alias("my_speed"), ])? 
.build()?; let unparser = @@ -308,8 +314,8 @@ async fn unparse_my_logical_plan_as_subquery() -> Result<()> { let sql = unparser.plan_to_sql(&plan)?.to_string(); assert_eq!( sql, - "SELECT \"?table?\".id AS my_id, \"?table?\".int_col AS my_int FROM \ - (SELECT \"?table?\".id, \"?table?\".int_col, \"?table?\".double_col, \"?table?\".date_string_col FROM \"?table?\")", + "SELECT \"?table?\".car AS my_car, \"?table?\".speed AS my_speed FROM \ + (SELECT \"?table?\".car, \"?table?\".speed, \"?table?\".\"time\" FROM \"?table?\")", ); Ok(()) } diff --git a/datafusion-examples/examples/planner_api.rs b/datafusion-examples/examples/query_planning/planner_api.rs similarity index 86% rename from datafusion-examples/examples/planner_api.rs rename to datafusion-examples/examples/query_planning/planner_api.rs index 55aec7b0108a4..8b2c09f4aecba 100644 --- a/datafusion-examples/examples/planner_api.rs +++ b/datafusion-examples/examples/query_planning/planner_api.rs @@ -15,11 +15,14 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. + use datafusion::error::Result; use datafusion::logical_expr::LogicalPlan; use datafusion::physical_plan::displayable; use datafusion::physical_planner::DefaultPhysicalPlanner; use datafusion::prelude::*; +use datafusion_examples::utils::{datasets::ExampleDataset, write_csv_to_parquet}; /// This example demonstrates the process of converting logical plan /// into physical execution plans using DataFusion. @@ -32,29 +35,26 @@ use datafusion::prelude::*; /// physical plan: /// - Via the combined `create_physical_plan` API. /// - Utilizing the analyzer, optimizer, and query planner APIs separately. 
-#[tokio::main] -async fn main() -> Result<()> { +pub async fn planner_api() -> Result<()> { // Set up a DataFusion context and load a Parquet file let ctx = SessionContext::new(); - let testdata = datafusion::test_util::parquet_test_data(); + + // Convert the CSV input into a temporary Parquet directory for querying + let dataset = ExampleDataset::Cars; + let parquet_temp = write_csv_to_parquet(&ctx, &dataset.path()).await?; + let df = ctx - .read_parquet( - &format!("{testdata}/alltypes_plain.parquet"), - ParquetReadOptions::default(), - ) + .read_parquet(parquet_temp.path_str()?, ParquetReadOptions::default()) .await?; // Construct the input logical plan using DataFrame API let df = df .clone() - .select(vec![ - df.parse_sql_expr("int_col")?, - df.parse_sql_expr("double_col")?, - ])? - .filter(df.parse_sql_expr("int_col < 5 OR double_col = 8.0")?)? + .select(vec![df.parse_sql_expr("car")?, df.parse_sql_expr("speed")?])? + .filter(df.parse_sql_expr("car = 'red' OR speed > 1.0")?)? .aggregate( - vec![df.parse_sql_expr("double_col")?], - vec![df.parse_sql_expr("SUM(int_col) as sum_int_col")?], + vec![df.parse_sql_expr("car")?], + vec![df.parse_sql_expr("SUM(speed) as sum_speed")?], )? .limit(0, Some(1))?; let logical_plan = df.logical_plan().clone(); diff --git a/datafusion-examples/examples/pruning.rs b/datafusion-examples/examples/query_planning/pruning.rs similarity index 95% rename from datafusion-examples/examples/pruning.rs rename to datafusion-examples/examples/query_planning/pruning.rs index b2d2fa13b7ed2..33f3f8428a77f 100644 --- a/datafusion-examples/examples/pruning.rs +++ b/datafusion-examples/examples/query_planning/pruning.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. 
+ use std::collections::HashSet; use std::sync::Arc; @@ -22,6 +24,7 @@ use arrow::array::{ArrayRef, BooleanArray, Int32Array}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion::common::pruning::PruningStatistics; use datafusion::common::{DFSchema, ScalarValue}; +use datafusion::error::Result; use datafusion::execution::context::ExecutionProps; use datafusion::physical_expr::create_physical_expr; use datafusion::physical_optimizer::pruning::PruningPredicate; @@ -40,8 +43,7 @@ use datafusion::prelude::*; /// one might do as part of a higher level storage engine. See /// `parquet_index.rs` for an example that uses pruning in the context of an /// individual query. -#[tokio::main] -async fn main() { +pub async fn pruning() -> Result<()> { // In this example, we'll use the PruningPredicate to determine if // the expression `x = 5 AND y = 10` can never be true based on statistics @@ -69,7 +71,7 @@ async fn main() { let predicate = create_pruning_predicate(expr, &my_catalog.schema); // Evaluate the predicate for the three files in the catalog - let prune_results = predicate.prune(&my_catalog).unwrap(); + let prune_results = predicate.prune(&my_catalog)?; println!("Pruning results: {prune_results:?}"); // The result is a `Vec` of bool values, one for each file in the catalog @@ -93,6 +95,8 @@ async fn main() { false ] ); + + Ok(()) } /// A simple model catalog that has information about the three files that store @@ -186,11 +190,12 @@ impl PruningStatistics for MyCatalog { } } +#[expect(clippy::needless_pass_by_value)] fn create_pruning_predicate(expr: Expr, schema: &SchemaRef) -> PruningPredicate { - let df_schema = DFSchema::try_from(schema.as_ref().clone()).unwrap(); + let df_schema = DFSchema::try_from(Arc::clone(schema)).unwrap(); let props = ExecutionProps::new(); let physical_expr = create_physical_expr(&expr, &df_schema, &props).unwrap(); - PruningPredicate::try_new(physical_expr, schema.clone()).unwrap() + 
PruningPredicate::try_new(physical_expr, Arc::clone(schema)).unwrap() } fn i32_array<'a>(values: impl Iterator>) -> ArrayRef { diff --git a/datafusion-examples/examples/query_planning/thread_pools.rs b/datafusion-examples/examples/query_planning/thread_pools.rs new file mode 100644 index 0000000000000..2ff73a77c4024 --- /dev/null +++ b/datafusion-examples/examples/query_planning/thread_pools.rs @@ -0,0 +1,355 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! See `main.rs` for how to run it. +//! +//! This example shows how to use separate thread pools (tokio [`Runtime`]s) to +//! run the IO and CPU intensive parts of DataFusion plans. +//! +//! # Background +//! +//! DataFusion, by default, plans and executes all operations (both CPU and IO) +//! on the same thread pool. This makes it fast and easy to get started, but +//! can cause issues when running at scale, especially when fetching and operating +//! on data directly from remote sources. +//! +//! Specifically, without configuration such as in this example, DataFusion +//! plans and executes everything on the same thread pool (Tokio Runtime), including +//! any I/O, such as reading Parquet files from remote object storage +//! (e.g. 
AWS S3), catalog access, and CPU intensive work. Running this diverse +//! workload can lead to issues described in the [Architecture section] such as +//! throttled network bandwidth (due to congestion control) and increased +//! latencies or timeouts while processing network messages. +//! +//! [Architecture section]: https://docs.rs/datafusion/latest/datafusion/index.html#thread-scheduling-cpu--io-thread-pools-and-tokio-runtimes + +use std::sync::Arc; + +use arrow::util::pretty::pretty_format_batches; +use datafusion::common::runtime::JoinSet; +use datafusion::error::Result; +use datafusion::execution::SendableRecordBatchStream; +use datafusion::prelude::*; +use datafusion_examples::utils::{datasets::ExampleDataset, write_csv_to_parquet}; +use futures::stream::StreamExt; +use object_store::client::SpawnedReqwestConnector; +use object_store::http::HttpBuilder; +use tokio::runtime::Handle; +use tokio::sync::Notify; +use url::Url; + +/// Normally, you don't need to worry about the details of the tokio +/// [`Runtime`], but for this example it is important to understand how the +/// [`Runtime`]s work. +/// +/// Each thread has a "current" runtime that is installed in a thread local +/// variable which is used by the `tokio::spawn` function. +/// +/// The `#[tokio::main]` macro creates a [`Runtime`] and installs it +/// as the "current" runtime in a thread local variable, on which any `async` +/// [`Future`], [`Stream`]s and [`Task`]s are run. +/// +/// This example uses the runtime created by [`tokio::main`] to do I/O and spawn +/// CPU intensive tasks on a separate [`Runtime`], mirroring the common pattern +/// when using Rust libraries such as `tonic`. Using a separate `Runtime` for +/// CPU bound tasks will often be simpler in larger applications, even though it +/// makes this example slightly more complex. +pub async fn thread_pools() -> Result<()> { + // The first two examples read local files. 
Enabling the URL table feature + // lets us treat filenames as tables in SQL. + let ctx = SessionContext::new().enable_url_table(); + + // Convert the CSV input into a temporary Parquet directory for querying + let dataset = ExampleDataset::Cars; + let parquet_temp = write_csv_to_parquet(&ctx, &dataset.path()).await?; + + let sql = format!("SELECT * FROM '{}'", parquet_temp.path_str()?); + + // Run a query on the current runtime. Calling `await` means the future + // (in this case the `async` function and all spawned work in DataFusion + // plans) runs on the current runtime. + same_runtime(&ctx, &sql).await?; + + // Run the same query but this time on a different runtime. + // + // Since we call `await` here, the `async` function itself runs on the + // current runtime, but internally `different_runtime_basic` executes the + // DataFusion plan on a different Runtime. + different_runtime_basic(ctx, sql).await?; + + // Run the same query on a different runtime, including remote IO. + // + // NOTE: This is best practice for production systems + different_runtime_advanced().await?; + + Ok(()) +} + +/// Run queries directly on the current tokio `Runtime` +/// +/// This is how most examples in DataFusion are written and works well for +/// development, local query processing, and non latency sensitive workloads. +async fn same_runtime(ctx: &SessionContext, sql: &str) -> Result<()> { + // Calling .sql is an async function as it may also do network + // I/O, for example to contact a remote catalog or do an object store LIST + let df = ctx.sql(sql).await?; + + // While many examples call `collect` or `show()`, those methods buffer the + // results. Internally DataFusion generates output one RecordBatch at a time + + // Calling `execute_stream` returns a `SendableRecordBatchStream`. Depending + // on the plan, this may also do network I/O, for example to begin reading a + // parquet file from a remote object store. 
+ let mut stream: SendableRecordBatchStream = df.execute_stream().await?; + + // `next()` drives the plan, incrementally producing new `RecordBatch`es + // using the current runtime. + // + // Perhaps somewhat non obviously, calling `next()` can also result in other + // tasks being spawned on the current runtime (e.g. for `RepartitionExec` to + // read data from each of its input partitions in parallel). + // + // Executing the plan using this pattern intermixes any IO and CPU intensive + // work on the same Runtime + while let Some(batch) = stream.next().await { + println!("{}", pretty_format_batches(&[batch?])?); + } + Ok(()) +} + +/// Run queries on a **different** Runtime dedicated for CPU bound work +/// +/// This example is suitable for running DataFusion plans against local data +/// sources (e.g. files) and returning results to an async destination, as might +/// be done to return query results to a remote client. +/// +/// Production systems which also read data locally or require very low latency +/// should follow the recommendations on [`different_runtime_advanced`] when +/// processing data from a remote source such as object storage. +async fn different_runtime_basic(ctx: SessionContext, sql: String) -> Result<()> { + // Since we are already in the context of a runtime (installed by + // #[tokio::main]), we need a new Runtime (threadpool) for CPU bound tasks + let cpu_runtime = CpuRuntime::try_new()?; + + // Prepare a task that runs the plan on cpu_runtime and sends + // the results back to the original runtime via a channel. 
+ let (tx, mut rx) = tokio::sync::mpsc::channel(2); + let driver_task = async move { + // Plan the query (which might require CPU work to evaluate statistics) + let df = ctx.sql(&sql).await?; + let mut stream: SendableRecordBatchStream = df.execute_stream().await?; + + // Calling `next()` to drive the plan in this task drives the + // execution on the cpu runtime (the other thread pool) + // + // NOTE any IO run by this plan (for example, reading from an + // `ObjectStore`) will be done on this new thread pool as well. + while let Some(batch) = stream.next().await { + if tx.send(batch).await.is_err() { + // error means dropped receiver, so nothing will get results anymore + return Ok(()); + } + } + Ok(()) as Result<()> + }; + + // Run the driver task on the cpu runtime. Use a JoinSet to + // ensure the spawned task is canceled on error/drop + let mut join_set = JoinSet::new(); + join_set.spawn_on(driver_task, cpu_runtime.handle()); + + // Retrieve the results in the original (IO) runtime. This requires only + // minimal work (pass pointers around). + while let Some(batch) = rx.recv().await { + println!("{}", pretty_format_batches(&[batch?])?); + } + + // wait for completion of the driver task + drain_join_set(join_set).await; + + Ok(()) +} + +/// Run CPU intensive work on a different runtime but do IO operations (object +/// store access) on the current runtime. +async fn different_runtime_advanced() -> Result<()> { + // In this example, we will query a file via https, reading + // the data directly from the plan + + // The current runtime (created by tokio::main) is used for IO + // + // Note this handle should be used for *ALL* remote IO operations in your + // systems, including remote catalog access, which is not included in this + // example. + let cpu_runtime = CpuRuntime::try_new()?; + let io_handle = Handle::current(); + + let ctx = SessionContext::new(); + + // By default, the HttpStore uses the same runtime that calls `await` for IO + // operations. 
This means that if the DataFusion plan is called from the + // cpu_runtime, the HttpStore IO operations will *also* run on the CPU + // runtime, which will error. + // + // To avoid this, we use a `SpawnedReqwestConnector` to configure the + // `ObjectStore` to run the HTTP requests on the IO runtime. + let base_url = Url::parse("https://github.com").unwrap(); + let http_store = HttpBuilder::new() + .with_url(base_url.clone()) + // Use the io_runtime to run the HTTP requests. Without this line, + // you will see an error such as: + // A Tokio 1.x context was found, but IO is disabled. + .with_http_connector(SpawnedReqwestConnector::new(io_handle)) + .build()?; + + // Tell DataFusion to process `https://github.com` URLs with this wrapped object store + ctx.register_object_store(&base_url, Arc::new(http_store)); + + // As above, plan and execute the query on the cpu runtime. + let (tx, mut rx) = tokio::sync::mpsc::channel(2); + let driver_task = async move { + // Plan / execute the query + let url = "https://github.com/apache/arrow-testing/raw/master/data/csv/aggregate_test_100.csv"; + let df = ctx + .sql(&format!("SELECT c1,c2,c3 FROM '{url}' LIMIT 5")) + .await?; + + let mut stream: SendableRecordBatchStream = df.execute_stream().await?; + + // Note you can do other non trivial CPU work on the results of the + // stream before sending it back to the original runtime. For example, + // calling a FlightDataEncoder to convert the results to flight messages + // to send over the network + + // send results, as above + while let Some(batch) = stream.next().await { + if tx.send(batch).await.is_err() { + return Ok(()); + } + } + Ok(()) as Result<()> + }; + + let mut join_set = JoinSet::new(); + join_set.spawn_on(driver_task, cpu_runtime.handle()); + while let Some(batch) = rx.recv().await { + println!("{}", pretty_format_batches(&[batch?])?); + } + + Ok(()) +} + +/// Waits for all tasks in the JoinSet to complete and reports any errors that +/// occurred. 
+/// +/// If we don't do this, any errors that occur in the task (such as IO errors) +/// are not reported. +async fn drain_join_set(mut join_set: JoinSet>) { + // retrieve any errors from the tasks + while let Some(result) = join_set.join_next().await { + match result { + Ok(Ok(())) => {} // task completed successfully + Ok(Err(e)) => eprintln!("Task failed: {e}"), // task failed + Err(e) => eprintln!("JoinSet error: {e}"), // JoinSet error + } + } +} + +/// Creates a Tokio [`Runtime`] for use with CPU bound tasks +/// +/// Tokio forbids dropping `Runtime`s in async contexts, so creating a separate +/// `Runtime` correctly is somewhat tricky. This structure manages the creation +/// and shutdown of a separate thread. +/// +/// # Notes +/// On drop, the thread will wait for all remaining tasks to complete. +/// +/// Depending on your application, more sophisticated shutdown logic may be +/// required, such as ensuring that no new tasks are added to the runtime. +/// +/// # Credits +/// This code is derived from code originally written for [InfluxDB 3.0] +/// +/// [InfluxDB 3.0]: https://github.com/influxdata/influxdb3_core/tree/6fcbb004232738d55655f32f4ad2385523d10696/executor +struct CpuRuntime { + /// Handle is the tokio structure for interacting with a Runtime. + handle: Handle, + /// Signal to start shutting down + notify_shutdown: Arc, + /// When thread is active, is Some + thread_join_handle: Option>, +} + +impl Drop for CpuRuntime { + fn drop(&mut self) { + // Notify the thread to shutdown. + self.notify_shutdown.notify_one(); + // In a production system you also need to ensure your code stops adding + // new tasks to the underlying runtime after this point to allow the + // thread to complete its work and exit cleanly. 
+ if let Some(thread_join_handle) = self.thread_join_handle.take() { + // If the thread is still running, we wait for it to finish + print!("Shutting down CPU runtime thread..."); + if let Err(e) = thread_join_handle.join() { + eprintln!("Error joining CPU runtime thread: {e:?}",); + } else { + println!("CPU runtime thread shutdown successfully."); + } + } + } +} + +impl CpuRuntime { + /// Create a new Tokio Runtime for CPU bound tasks + pub fn try_new() -> Result { + let cpu_runtime = tokio::runtime::Builder::new_multi_thread() + .enable_time() + .build()?; + let handle = cpu_runtime.handle().clone(); + let notify_shutdown = Arc::new(Notify::new()); + let notify_shutdown_captured = Arc::clone(¬ify_shutdown); + + // The cpu_runtime runs and is dropped on a separate thread + let thread_join_handle = std::thread::spawn(move || { + cpu_runtime.block_on(async move { + notify_shutdown_captured.notified().await; + }); + // Note: cpu_runtime is dropped here, which will wait for all tasks + // to complete + }); + + Ok(Self { + handle, + notify_shutdown, + thread_join_handle: Some(thread_join_handle), + }) + } + + /// Return a handle suitable for spawning CPU bound tasks + /// + /// # Notes + /// + /// If a task spawned on this handle attempts to do IO, it will error with a + /// message such as: + /// + /// ```text + /// A Tokio 1.x context was found, but IO is disabled. + /// ``` + pub fn handle(&self) -> &Handle { + &self.handle + } +} diff --git a/datafusion-examples/examples/relation_planner/main.rs b/datafusion-examples/examples/relation_planner/main.rs new file mode 100644 index 0000000000000..babc0d3714f72 --- /dev/null +++ b/datafusion-examples/examples/relation_planner/main.rs @@ -0,0 +1,127 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! # Relation Planner Examples +//! +//! These examples demonstrate how to use custom relation planners to extend +//! DataFusion's SQL syntax with custom table operators. +//! +//! ## Usage +//! ```bash +//! cargo run --example relation_planner -- [all|match_recognize|pivot_unpivot|table_sample] +//! ``` +//! +//! Each subcommand runs a corresponding example: +//! - `all` — run all examples included in this module +//! +//! - `match_recognize` +//! (file: match_recognize.rs, desc: Implement MATCH_RECOGNIZE pattern matching) +//! +//! - `pivot_unpivot` +//! (file: pivot_unpivot.rs, desc: Implement PIVOT / UNPIVOT) +//! +//! - `table_sample` +//! (file: table_sample.rs, desc: Implement TABLESAMPLE) +//! +//! ## Snapshot Testing +//! +//! These examples use [insta](https://insta.rs) for inline snapshot assertions. +//! If query output changes, regenerate the snapshots with: +//! ```bash +//! cargo insta test --example relation_planner --accept +//! 
``` + +mod match_recognize; +mod pivot_unpivot; +mod table_sample; + +use datafusion::error::{DataFusionError, Result}; +use strum::{IntoEnumIterator, VariantNames}; +use strum_macros::{Display, EnumIter, EnumString, VariantNames}; + +#[derive(EnumIter, EnumString, Display, VariantNames)] +#[strum(serialize_all = "snake_case")] +enum ExampleKind { + All, + MatchRecognize, + PivotUnpivot, + TableSample, +} + +impl ExampleKind { + const EXAMPLE_NAME: &str = "relation_planner"; + + fn runnable() -> impl Iterator { + ExampleKind::iter().filter(|v| !matches!(v, ExampleKind::All)) + } + + async fn run(&self) -> Result<()> { + match self { + ExampleKind::All => { + for example in ExampleKind::runnable() { + println!("Running example: {example}"); + Box::pin(example.run()).await?; + } + } + ExampleKind::MatchRecognize => match_recognize::match_recognize().await?, + ExampleKind::PivotUnpivot => pivot_unpivot::pivot_unpivot().await?, + ExampleKind::TableSample => table_sample::table_sample().await?, + } + + Ok(()) + } +} + +#[tokio::main] +async fn main() -> Result<()> { + let usage = format!( + "Usage: cargo run --example {} -- [{}]", + ExampleKind::EXAMPLE_NAME, + ExampleKind::VARIANTS.join("|") + ); + + let example: ExampleKind = std::env::args() + .nth(1) + .unwrap_or_else(|| ExampleKind::All.to_string()) + .parse() + .map_err(|_| DataFusionError::Execution(format!("Unknown example. {usage}")))?; + + example.run().await +} + +/// Test wrappers that enable `cargo insta test --example relation_planner --accept` +/// to regenerate inline snapshots. Without these, insta cannot run the examples +/// in test mode since they only have `main()` functions. 
+#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_match_recognize() { + match_recognize::match_recognize().await.unwrap(); + } + + #[tokio::test] + async fn test_pivot_unpivot() { + pivot_unpivot::pivot_unpivot().await.unwrap(); + } + + #[tokio::test] + async fn test_table_sample() { + table_sample::table_sample().await.unwrap(); + } +} diff --git a/datafusion-examples/examples/relation_planner/match_recognize.rs b/datafusion-examples/examples/relation_planner/match_recognize.rs new file mode 100644 index 0000000000000..c4b3d522efc17 --- /dev/null +++ b/datafusion-examples/examples/relation_planner/match_recognize.rs @@ -0,0 +1,408 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! # MATCH_RECOGNIZE Example +//! +//! This example demonstrates implementing SQL `MATCH_RECOGNIZE` pattern matching +//! using a custom [`RelationPlanner`]. Unlike the [`pivot_unpivot`] example that +//! rewrites SQL to standard operations, this example creates a **custom logical +//! plan node** (`MiniMatchRecognizeNode`) to represent the operation. +//! +//! ## Supported Syntax +//! +//! ```sql +//! SELECT * FROM events +//! MATCH_RECOGNIZE ( +//! PARTITION BY region +//! 
MEASURES SUM(price) AS total, AVG(price) AS average +//! PATTERN (A B+ C) +//! DEFINE +//! A AS price < 100, +//! B AS price BETWEEN 100 AND 200, +//! C AS price > 200 +//! ) AS matches +//! ``` +//! +//! ## Architecture +//! +//! This example demonstrates **logical planning only**. Physical execution would +//! require implementing an [`ExecutionPlan`] (see the [`table_sample`] example +//! for a complete implementation with physical planning). +//! +//! ```text +//! SQL Query +//! │ +//! ▼ +//! ┌─────────────────────────────────────┐ +//! │ MatchRecognizePlanner │ +//! │ (RelationPlanner trait) │ +//! │ │ +//! │ • Parses MATCH_RECOGNIZE syntax │ +//! │ • Creates MiniMatchRecognizeNode │ +//! │ • Converts SQL exprs to DataFusion │ +//! └─────────────────────────────────────┘ +//! │ +//! ▼ +//! ┌─────────────────────────────────────┐ +//! │ MiniMatchRecognizeNode │ +//! │ (UserDefinedLogicalNode) │ +//! │ │ +//! │ • measures: [(alias, expr), ...] │ +//! │ • definitions: [(symbol, expr), ...]│ +//! └─────────────────────────────────────┘ +//! ``` +//! +//! [`pivot_unpivot`]: super::pivot_unpivot +//! [`table_sample`]: super::table_sample +//! 
[`ExecutionPlan`]: datafusion::physical_plan::ExecutionPlan + +use std::{any::Any, cmp::Ordering, hash::Hasher, sync::Arc}; + +use arrow::array::{ArrayRef, Float64Array, Int32Array, StringArray}; +use arrow::record_batch::RecordBatch; +use datafusion::prelude::*; +use datafusion_common::{DFSchemaRef, Result}; +use datafusion_expr::{ + Expr, UserDefinedLogicalNode, + logical_plan::{Extension, InvariantLevel, LogicalPlan}, + planner::{ + PlannedRelation, RelationPlanner, RelationPlannerContext, RelationPlanning, + }, +}; +use datafusion_sql::sqlparser::ast::TableFactor; +use insta::assert_snapshot; + +// ============================================================================ +// Example Entry Point +// ============================================================================ + +/// Runs the MATCH_RECOGNIZE examples demonstrating pattern matching on event streams. +/// +/// Note: This example demonstrates **logical planning only**. Physical execution +/// would require additional implementation of an [`ExecutionPlan`]. 
+pub async fn match_recognize() -> Result<()> { + let ctx = SessionContext::new(); + ctx.register_relation_planner(Arc::new(MatchRecognizePlanner))?; + register_sample_data(&ctx)?; + + println!("MATCH_RECOGNIZE Example (Logical Planning Only)"); + println!("================================================\n"); + + run_examples(&ctx).await +} + +async fn run_examples(ctx: &SessionContext) -> Result<()> { + // Example 1: Basic MATCH_RECOGNIZE with MEASURES and DEFINE + // Demonstrates: Aggregate measures over matched rows + let plan = run_example( + ctx, + "Example 1: MATCH_RECOGNIZE with aggregations", + r#"SELECT * FROM events + MATCH_RECOGNIZE ( + PARTITION BY 1 + MEASURES SUM(price) AS total_price, AVG(price) AS avg_price + PATTERN (A) + DEFINE A AS price > 10 + ) AS matches"#, + ) + .await?; + assert_snapshot!(plan, @r" + Projection: matches.price + SubqueryAlias: matches + MiniMatchRecognize measures=[total_price := sum(events.price), avg_price := avg(events.price)] define=[a := events.price > Int64(10)] + TableScan: events + "); + + // Example 2: Stock price pattern detection + // Demonstrates: Real-world use case finding prices above threshold + let plan = run_example( + ctx, + "Example 2: Detect high stock prices", + r#"SELECT * FROM stock_prices + MATCH_RECOGNIZE ( + MEASURES + MIN(price) AS min_price, + MAX(price) AS max_price, + AVG(price) AS avg_price + PATTERN (HIGH) + DEFINE HIGH AS price > 151.0 + ) AS trends"#, + ) + .await?; + assert_snapshot!(plan, @r" + Projection: trends.symbol, trends.price + SubqueryAlias: trends + MiniMatchRecognize measures=[min_price := min(stock_prices.price), max_price := max(stock_prices.price), avg_price := avg(stock_prices.price)] define=[high := stock_prices.price > Float64(151)] + TableScan: stock_prices + "); + + Ok(()) +} + +/// Helper to run a single example query and display the logical plan. 
+async fn run_example(ctx: &SessionContext, title: &str, sql: &str) -> Result { + println!("{title}:\n{sql}\n"); + let plan = ctx.sql(sql).await?.into_unoptimized_plan(); + let plan_str = plan.display_indent().to_string(); + println!("{plan_str}\n"); + Ok(plan_str) +} + +/// Register test data tables. +fn register_sample_data(ctx: &SessionContext) -> Result<()> { + // events: simple price series + ctx.register_batch( + "events", + RecordBatch::try_from_iter(vec![( + "price", + Arc::new(Int32Array::from(vec![5, 12, 8, 15, 20])) as ArrayRef, + )])?, + )?; + + // stock_prices: realistic stock data + ctx.register_batch( + "stock_prices", + RecordBatch::try_from_iter(vec![ + ( + "symbol", + Arc::new(StringArray::from(vec!["DDOG", "DDOG", "DDOG", "DDOG"])) + as ArrayRef, + ), + ( + "price", + Arc::new(Float64Array::from(vec![150.0, 155.0, 152.0, 158.0])), + ), + ])?, + )?; + + Ok(()) +} + +// ============================================================================ +// Logical Plan Node: MiniMatchRecognizeNode +// ============================================================================ + +/// A custom logical plan node representing MATCH_RECOGNIZE operations. 
+/// +/// This is a simplified implementation that captures the essential structure: +/// - `measures`: Aggregate expressions computed over matched rows +/// - `definitions`: Symbol definitions (predicate expressions) +/// +/// A production implementation would also include: +/// - Pattern specification (regex-like pattern) +/// - Partition and order by clauses +/// - Output mode (ONE ROW PER MATCH, ALL ROWS PER MATCH) +/// - After match skip strategy +#[derive(Debug)] +struct MiniMatchRecognizeNode { + input: Arc, + schema: DFSchemaRef, + /// Measures: (alias, aggregate_expr) + measures: Vec<(String, Expr)>, + /// Symbol definitions: (symbol_name, predicate_expr) + definitions: Vec<(String, Expr)>, +} + +impl UserDefinedLogicalNode for MiniMatchRecognizeNode { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "MiniMatchRecognize" + } + + fn inputs(&self) -> Vec<&LogicalPlan> { + vec![&self.input] + } + + fn schema(&self) -> &DFSchemaRef { + &self.schema + } + + fn check_invariants(&self, _check: InvariantLevel) -> Result<()> { + Ok(()) + } + + fn expressions(&self) -> Vec { + self.measures + .iter() + .chain(&self.definitions) + .map(|(_, expr)| expr.clone()) + .collect() + } + + fn fmt_for_explain(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "MiniMatchRecognize")?; + + if !self.measures.is_empty() { + write!(f, " measures=[")?; + for (i, (alias, expr)) in self.measures.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{alias} := {expr}")?; + } + write!(f, "]")?; + } + + if !self.definitions.is_empty() { + write!(f, " define=[")?; + for (i, (symbol, expr)) in self.definitions.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{symbol} := {expr}")?; + } + write!(f, "]")?; + } + + Ok(()) + } + + fn with_exprs_and_inputs( + &self, + exprs: Vec, + inputs: Vec, + ) -> Result> { + let expected_len = self.measures.len() + self.definitions.len(); + if exprs.len() != expected_len 
{ + return Err(datafusion_common::plan_datafusion_err!( + "MiniMatchRecognize: expected {expected_len} expressions, got {}", + exprs.len() + )); + } + + let input = inputs.into_iter().next().ok_or_else(|| { + datafusion_common::plan_datafusion_err!( + "MiniMatchRecognize requires exactly one input" + ) + })?; + + let (measure_exprs, definition_exprs) = exprs.split_at(self.measures.len()); + + let measures = self + .measures + .iter() + .zip(measure_exprs) + .map(|((alias, _), expr)| (alias.clone(), expr.clone())) + .collect(); + + let definitions = self + .definitions + .iter() + .zip(definition_exprs) + .map(|((symbol, _), expr)| (symbol.clone(), expr.clone())) + .collect(); + + Ok(Arc::new(Self { + input: Arc::new(input), + schema: Arc::clone(&self.schema), + measures, + definitions, + })) + } + + fn dyn_hash(&self, state: &mut dyn Hasher) { + state.write_usize(Arc::as_ptr(&self.input) as usize); + state.write_usize(self.measures.len()); + state.write_usize(self.definitions.len()); + } + + fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool { + other.as_any().downcast_ref::().is_some_and(|o| { + Arc::ptr_eq(&self.input, &o.input) + && self.measures == o.measures + && self.definitions == o.definitions + }) + } + + fn dyn_ord(&self, other: &dyn UserDefinedLogicalNode) -> Option { + if self.dyn_eq(other) { + Some(Ordering::Equal) + } else { + None + } + } +} + +// ============================================================================ +// Relation Planner: MatchRecognizePlanner +// ============================================================================ + +/// Relation planner that creates `MiniMatchRecognizeNode` for MATCH_RECOGNIZE queries. +#[derive(Debug)] +struct MatchRecognizePlanner; + +impl RelationPlanner for MatchRecognizePlanner { + fn plan_relation( + &self, + relation: TableFactor, + ctx: &mut dyn RelationPlannerContext, + ) -> Result { + let TableFactor::MatchRecognize { + table, + measures, + symbols, + alias, + .. 
+ } = relation + else { + return Ok(RelationPlanning::Original(Box::new(relation))); + }; + + // Plan the input table + let input = ctx.plan(*table)?; + let schema = input.schema().clone(); + + // Convert MEASURES: SQL expressions → DataFusion expressions + let planned_measures: Vec<(String, Expr)> = measures + .iter() + .map(|m| { + let alias = ctx.normalize_ident(m.alias.clone()); + let expr = ctx.sql_to_expr(m.expr.clone(), schema.as_ref())?; + Ok((alias, expr)) + }) + .collect::>()?; + + // Convert DEFINE: symbol definitions → DataFusion expressions + let planned_definitions: Vec<(String, Expr)> = symbols + .iter() + .map(|s| { + let name = ctx.normalize_ident(s.symbol.clone()); + let expr = ctx.sql_to_expr(s.definition.clone(), schema.as_ref())?; + Ok((name, expr)) + }) + .collect::>()?; + + // Create the custom node + let node = MiniMatchRecognizeNode { + input: Arc::new(input), + schema, + measures: planned_measures, + definitions: planned_definitions, + }; + + let plan = LogicalPlan::Extension(Extension { + node: Arc::new(node), + }); + + Ok(RelationPlanning::Planned(Box::new(PlannedRelation::new( + plan, alias, + )))) + } +} diff --git a/datafusion-examples/examples/relation_planner/pivot_unpivot.rs b/datafusion-examples/examples/relation_planner/pivot_unpivot.rs new file mode 100644 index 0000000000000..2e1696956bf62 --- /dev/null +++ b/datafusion-examples/examples/relation_planner/pivot_unpivot.rs @@ -0,0 +1,571 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! # PIVOT and UNPIVOT Example +//! +//! This example demonstrates implementing SQL `PIVOT` and `UNPIVOT` operations +//! using a custom [`RelationPlanner`]. Unlike the other examples that create +//! custom logical/physical nodes, this example shows how to **rewrite** SQL +//! constructs into equivalent standard SQL operations: +//! +//! ## Supported Syntax +//! +//! ```sql +//! -- PIVOT: Transform rows into columns +//! SELECT * FROM sales +//! PIVOT (SUM(amount) FOR quarter IN ('Q1', 'Q2', 'Q3', 'Q4')) +//! +//! -- UNPIVOT: Transform columns into rows +//! SELECT * FROM wide_table +//! UNPIVOT (value FOR name IN (col1, col2, col3)) +//! ``` +//! +//! ## Rewrite Strategy +//! +//! **PIVOT** is rewritten to `GROUP BY` with `CASE` expressions: +//! ```sql +//! -- Original: +//! SELECT * FROM sales PIVOT (SUM(amount) FOR quarter IN ('Q1', 'Q2')) +//! +//! -- Rewritten to: +//! SELECT region, +//! SUM(CASE quarter WHEN 'Q1' THEN amount END) AS Q1, +//! SUM(CASE quarter WHEN 'Q2' THEN amount END) AS Q2 +//! FROM sales +//! GROUP BY region +//! ``` +//! +//! **UNPIVOT** is rewritten to `UNION ALL` of projections: +//! ```sql +//! -- Original: +//! SELECT * FROM wide UNPIVOT (sales FOR quarter IN (q1, q2)) +//! +//! -- Rewritten to: +//! SELECT region, 'q1' AS quarter, q1 AS sales FROM wide +//! UNION ALL +//! SELECT region, 'q2' AS quarter, q2 AS sales FROM wide +//! 
``` + +use std::sync::Arc; + +use arrow::array::{ArrayRef, Int64Array, StringArray}; +use arrow::record_batch::RecordBatch; +use datafusion::prelude::*; +use datafusion_common::{Result, ScalarValue, plan_datafusion_err}; +use datafusion_expr::{ + Expr, case, col, lit, + logical_plan::builder::LogicalPlanBuilder, + planner::{ + PlannedRelation, RelationPlanner, RelationPlannerContext, RelationPlanning, + }, +}; +use datafusion_sql::sqlparser::ast::{NullInclusion, PivotValueSource, TableFactor}; +use insta::assert_snapshot; + +// ============================================================================ +// Example Entry Point +// ============================================================================ + +/// Runs the PIVOT/UNPIVOT examples demonstrating data reshaping operations. +pub async fn pivot_unpivot() -> Result<()> { + let ctx = SessionContext::new(); + ctx.register_relation_planner(Arc::new(PivotUnpivotPlanner))?; + register_sample_data(&ctx)?; + + println!("PIVOT and UNPIVOT Example"); + println!("=========================\n"); + + run_examples(&ctx).await +} + +async fn run_examples(ctx: &SessionContext) -> Result<()> { + // ----- PIVOT Examples ----- + + // Example 1: Basic PIVOT + // Transforms: (region, quarter, amount) → (region, Q1, Q2) + let results = run_example( + ctx, + "Example 1: Basic PIVOT", + r#"SELECT * FROM quarterly_sales + PIVOT (SUM(amount) FOR quarter IN ('Q1', 'Q2')) AS p + ORDER BY region"#, + ) + .await?; + assert_snapshot!(results, @r" + +--------+------+------+ + | region | Q1 | Q2 | + +--------+------+------+ + | North | 1000 | 1500 | + | South | 1200 | 1300 | + +--------+------+------+ + "); + + // Example 2: PIVOT with multiple aggregates + // Creates columns for each (aggregate, value) combination + let results = run_example( + ctx, + "Example 2: PIVOT with multiple aggregates", + r#"SELECT * FROM quarterly_sales + PIVOT (SUM(amount), AVG(amount) FOR quarter IN ('Q1', 'Q2')) AS p + ORDER BY region"#, + ) + .await?; + 
assert_snapshot!(results, @r" + +--------+--------+--------+--------+--------+ + | region | sum_Q1 | sum_Q2 | avg_Q1 | avg_Q2 | + +--------+--------+--------+--------+--------+ + | North | 1000 | 1500 | 1000.0 | 1500.0 | + | South | 1200 | 1300 | 1200.0 | 1300.0 | + +--------+--------+--------+--------+--------+ + "); + + // Example 3: PIVOT with multiple grouping columns + // Non-pivot, non-aggregate columns become GROUP BY columns + let results = run_example( + ctx, + "Example 3: PIVOT with multiple grouping columns", + r#"SELECT * FROM product_sales + PIVOT (SUM(amount) FOR quarter IN ('Q1', 'Q2')) AS p + ORDER BY region, product"#, + ) + .await?; + assert_snapshot!(results, @r" + +--------+----------+-----+-----+ + | region | product | Q1 | Q2 | + +--------+----------+-----+-----+ + | North | ProductA | 500 | | + | North | ProductB | 500 | | + | South | ProductA | | 650 | + +--------+----------+-----+-----+ + "); + + // ----- UNPIVOT Examples ----- + + // Example 4: Basic UNPIVOT + // Transforms: (region, q1, q2) → (region, quarter, sales) + let results = run_example( + ctx, + "Example 4: Basic UNPIVOT", + r#"SELECT * FROM wide_sales + UNPIVOT (sales FOR quarter IN (q1 AS 'Q1', q2 AS 'Q2')) AS u + ORDER BY quarter, region"#, + ) + .await?; + assert_snapshot!(results, @r" + +--------+---------+-------+ + | region | quarter | sales | + +--------+---------+-------+ + | North | Q1 | 1000 | + | South | Q1 | 1200 | + | North | Q2 | 1500 | + | South | Q2 | 1300 | + +--------+---------+-------+ + "); + + // Example 5: UNPIVOT with INCLUDE NULLS + // By default, UNPIVOT excludes rows where the value column is NULL. + // INCLUDE NULLS keeps them (same result here since no NULLs in data). 
+ let results = run_example( + ctx, + "Example 5: UNPIVOT INCLUDE NULLS", + r#"SELECT * FROM wide_sales + UNPIVOT INCLUDE NULLS (sales FOR quarter IN (q1 AS 'Q1', q2 AS 'Q2')) AS u + ORDER BY quarter, region"#, + ) + .await?; + assert_snapshot!(results, @r" + +--------+---------+-------+ + | region | quarter | sales | + +--------+---------+-------+ + | North | Q1 | 1000 | + | South | Q1 | 1200 | + | North | Q2 | 1500 | + | South | Q2 | 1300 | + +--------+---------+-------+ + "); + + // Example 6: PIVOT with column projection + // Standard SQL operations work seamlessly after PIVOT + let results = run_example( + ctx, + "Example 6: PIVOT with projection", + r#"SELECT region FROM quarterly_sales + PIVOT (SUM(amount) FOR quarter IN ('Q1', 'Q2')) AS p + ORDER BY region"#, + ) + .await?; + assert_snapshot!(results, @r" + +--------+ + | region | + +--------+ + | North | + | South | + +--------+ + "); + + Ok(()) +} + +/// Helper to run a single example query and capture results. +async fn run_example(ctx: &SessionContext, title: &str, sql: &str) -> Result { + println!("{title}:\n{sql}\n"); + let df = ctx.sql(sql).await?; + println!("{}\n", df.logical_plan().display_indent()); + + let batches = df.collect().await?; + let results = arrow::util::pretty::pretty_format_batches(&batches)?.to_string(); + println!("{results}\n"); + + Ok(results) +} + +/// Register test data tables. 
+fn register_sample_data(ctx: &SessionContext) -> Result<()> { + // quarterly_sales: normalized sales data (region, quarter, amount) + ctx.register_batch( + "quarterly_sales", + RecordBatch::try_from_iter(vec![ + ( + "region", + Arc::new(StringArray::from(vec!["North", "North", "South", "South"])) + as ArrayRef, + ), + ( + "quarter", + Arc::new(StringArray::from(vec!["Q1", "Q2", "Q1", "Q2"])), + ), + ( + "amount", + Arc::new(Int64Array::from(vec![1000, 1500, 1200, 1300])), + ), + ])?, + )?; + + // product_sales: sales with additional grouping dimension + ctx.register_batch( + "product_sales", + RecordBatch::try_from_iter(vec![ + ( + "region", + Arc::new(StringArray::from(vec!["North", "North", "South"])) as ArrayRef, + ), + ( + "quarter", + Arc::new(StringArray::from(vec!["Q1", "Q1", "Q2"])), + ), + ( + "product", + Arc::new(StringArray::from(vec!["ProductA", "ProductB", "ProductA"])), + ), + ("amount", Arc::new(Int64Array::from(vec![500, 500, 650]))), + ])?, + )?; + + // wide_sales: denormalized/wide format (for UNPIVOT) + ctx.register_batch( + "wide_sales", + RecordBatch::try_from_iter(vec![ + ( + "region", + Arc::new(StringArray::from(vec!["North", "South"])) as ArrayRef, + ), + ("q1", Arc::new(Int64Array::from(vec![1000, 1200]))), + ("q2", Arc::new(Int64Array::from(vec![1500, 1300]))), + ])?, + )?; + + Ok(()) +} + +// ============================================================================ +// Relation Planner: PivotUnpivotPlanner +// ============================================================================ + +/// Relation planner that rewrites PIVOT and UNPIVOT into standard SQL. +#[derive(Debug)] +struct PivotUnpivotPlanner; + +impl RelationPlanner for PivotUnpivotPlanner { + fn plan_relation( + &self, + relation: TableFactor, + ctx: &mut dyn RelationPlannerContext, + ) -> Result { + match relation { + TableFactor::Pivot { + table, + aggregate_functions, + value_column, + value_source, + alias, + .. 
+ } => plan_pivot( + ctx, + *table, + &aggregate_functions, + &value_column, + value_source, + alias, + ), + + TableFactor::Unpivot { + table, + value, + name, + columns, + null_inclusion, + alias, + } => plan_unpivot( + ctx, + *table, + &value, + name, + &columns, + null_inclusion.as_ref(), + alias, + ), + + other => Ok(RelationPlanning::Original(Box::new(other))), + } + } +} + +// ============================================================================ +// PIVOT Implementation +// ============================================================================ + +/// Rewrite PIVOT to GROUP BY with CASE expressions. +fn plan_pivot( + ctx: &mut dyn RelationPlannerContext, + table: TableFactor, + aggregate_functions: &[datafusion_sql::sqlparser::ast::ExprWithAlias], + value_column: &[datafusion_sql::sqlparser::ast::Expr], + value_source: PivotValueSource, + alias: Option, +) -> Result { + // Plan the input table + let input = ctx.plan(table)?; + let schema = input.schema(); + + // Parse aggregate functions + let aggregates: Vec = aggregate_functions + .iter() + .map(|agg| ctx.sql_to_expr(agg.expr.clone(), schema.as_ref())) + .collect::>()?; + + // Get the pivot column (only single-column pivot supported) + if value_column.len() != 1 { + return Err(plan_datafusion_err!( + "Only single-column PIVOT is supported" + )); + } + let pivot_col = ctx.sql_to_expr(value_column[0].clone(), schema.as_ref())?; + let pivot_col_name = extract_column_name(&pivot_col)?; + + // Parse pivot values + let pivot_values = match value_source { + PivotValueSource::List(list) => list + .iter() + .map(|item| { + let alias = item + .alias + .as_ref() + .map(|id| ctx.normalize_ident(id.clone())); + let expr = ctx.sql_to_expr(item.expr.clone(), schema.as_ref())?; + Ok((alias, expr)) + }) + .collect::>>()?, + _ => { + return Err(plan_datafusion_err!( + "Dynamic PIVOT (ANY/Subquery) is not supported" + )); + } + }; + + // Determine GROUP BY columns (non-pivot, non-aggregate columns) + let 
agg_input_cols: Vec<&str> = aggregates + .iter() + .filter_map(|agg| { + if let Expr::AggregateFunction(f) = agg { + f.params.args.first().and_then(|e| { + if let Expr::Column(c) = e { + Some(c.name.as_str()) + } else { + None + } + }) + } else { + None + } + }) + .collect(); + + let group_by_cols: Vec = schema + .fields() + .iter() + .map(|f| f.name().as_str()) + .filter(|name| *name != pivot_col_name.as_str() && !agg_input_cols.contains(name)) + .map(col) + .collect(); + + // Build CASE expressions for each (aggregate, pivot_value) pair + let mut pivot_exprs = Vec::new(); + for agg in &aggregates { + let Expr::AggregateFunction(agg_fn) = agg else { + continue; + }; + let Some(agg_input) = agg_fn.params.args.first().cloned() else { + continue; + }; + + for (value_alias, pivot_value) in &pivot_values { + // CASE pivot_col WHEN pivot_value THEN agg_input END + let case_expr = case(col(&pivot_col_name)) + .when(pivot_value.clone(), agg_input.clone()) + .end()?; + + // Wrap in aggregate function + let pivoted = agg_fn.func.call(vec![case_expr]); + + // Determine column alias + let value_str = value_alias + .clone() + .unwrap_or_else(|| expr_to_string(pivot_value)); + let col_alias = if aggregates.len() > 1 { + format!("{}_{}", agg_fn.func.name(), value_str) + } else { + value_str + }; + + pivot_exprs.push(pivoted.alias(col_alias)); + } + } + + let plan = LogicalPlanBuilder::from(input) + .aggregate(group_by_cols, pivot_exprs)? + .build()?; + + Ok(RelationPlanning::Planned(Box::new(PlannedRelation::new( + plan, alias, + )))) +} + +// ============================================================================ +// UNPIVOT Implementation +// ============================================================================ + +/// Rewrite UNPIVOT to UNION ALL of projections. 
+fn plan_unpivot( + ctx: &mut dyn RelationPlannerContext, + table: TableFactor, + value: &datafusion_sql::sqlparser::ast::Expr, + name: datafusion_sql::sqlparser::ast::Ident, + columns: &[datafusion_sql::sqlparser::ast::ExprWithAlias], + null_inclusion: Option<&NullInclusion>, + alias: Option, +) -> Result { + // Plan the input table + let input = ctx.plan(table)?; + let schema = input.schema(); + + // Output column names + let value_col_name = value.to_string(); + let name_col_name = ctx.normalize_ident(name); + + // Parse columns to unpivot: (source_column, label) + let unpivot_cols: Vec<(String, String)> = columns + .iter() + .map(|c| { + let label = c + .alias + .as_ref() + .map(|id| ctx.normalize_ident(id.clone())) + .unwrap_or_else(|| c.expr.to_string()); + let expr = ctx.sql_to_expr(c.expr.clone(), schema.as_ref())?; + let col_name = extract_column_name(&expr)?; + Ok((col_name.to_string(), label)) + }) + .collect::>()?; + + // Columns to preserve (not being unpivoted) + let keep_cols: Vec<&str> = schema + .fields() + .iter() + .map(|f| f.name().as_str()) + .filter(|name| !unpivot_cols.iter().any(|(c, _)| c == *name)) + .collect(); + + // Build UNION ALL: one SELECT per unpivot column + if unpivot_cols.is_empty() { + return Err(plan_datafusion_err!("UNPIVOT requires at least one column")); + } + + let mut union_inputs: Vec<_> = unpivot_cols + .iter() + .map(|(col_name, label)| { + let mut projection: Vec = keep_cols.iter().map(|c| col(*c)).collect(); + projection.push(lit(label.clone()).alias(&name_col_name)); + projection.push(col(col_name).alias(&value_col_name)); + + LogicalPlanBuilder::from(input.clone()) + .project(projection)? 
+ .build() + }) + .collect::>()?; + + // Combine with UNION ALL + let mut plan = union_inputs.remove(0); + for branch in union_inputs { + plan = LogicalPlanBuilder::from(plan).union(branch)?.build()?; + } + + // Apply EXCLUDE NULLS filter (default behavior) + let exclude_nulls = null_inclusion.is_none() + || matches!(null_inclusion, Some(&NullInclusion::ExcludeNulls)); + if exclude_nulls { + plan = LogicalPlanBuilder::from(plan) + .filter(col(&value_col_name).is_not_null())? + .build()?; + } + + Ok(RelationPlanning::Planned(Box::new(PlannedRelation::new( + plan, alias, + )))) +} + +// ============================================================================ +// Helpers +// ============================================================================ + +/// Extract column name from an expression. +fn extract_column_name(expr: &Expr) -> Result { + match expr { + Expr::Column(c) => Ok(c.name.clone()), + _ => Err(plan_datafusion_err!( + "Expected column reference, got {expr}" + )), + } +} + +/// Convert an expression to a string for use as column alias. +fn expr_to_string(expr: &Expr) -> String { + match expr { + Expr::Literal(ScalarValue::Utf8(Some(s)), _) => s.clone(), + Expr::Literal(v, _) => v.to_string(), + other => other.to_string(), + } +} diff --git a/datafusion-examples/examples/relation_planner/table_sample.rs b/datafusion-examples/examples/relation_planner/table_sample.rs new file mode 100644 index 0000000000000..04e5efd9706a6 --- /dev/null +++ b/datafusion-examples/examples/relation_planner/table_sample.rs @@ -0,0 +1,836 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! # TABLESAMPLE Example +//! +//! This example demonstrates implementing SQL `TABLESAMPLE` support using +//! DataFusion's extensibility APIs. +//! +//! This is a working `TABLESAMPLE` implementation that can serve as a starting +//! point for your own projects. It also works as a template for adding other +//! custom SQL operators, covering the full pipeline from parsing to execution. +//! +//! It shows how to: +//! +//! 1. **Parse** TABLESAMPLE syntax via a custom [`RelationPlanner`] +//! 2. **Plan** sampling as a custom logical node ([`TableSamplePlanNode`]) +//! 3. **Execute** sampling via a custom physical operator ([`SampleExec`]) +//! +//! ## Supported Syntax +//! +//! ```sql +//! -- Bernoulli sampling (each row has N% chance of selection) +//! SELECT * FROM table TABLESAMPLE BERNOULLI(10 PERCENT) +//! +//! -- Fractional sampling (0.0 to 1.0) +//! SELECT * FROM table TABLESAMPLE (0.1) +//! +//! -- Row count limit +//! SELECT * FROM table TABLESAMPLE (100 ROWS) +//! +//! -- Reproducible sampling with a seed +//! SELECT * FROM table TABLESAMPLE (10 PERCENT) REPEATABLE(42) +//! ``` +//! +//! ## Architecture +//! +//! ```text +//! ┌─────────────────────────────────────────────────────────────────┐ +//! │ SQL Query │ +//! │ SELECT * FROM t TABLESAMPLE BERNOULLI(10 PERCENT) REPEATABLE(1)│ +//! └─────────────────────────────────────────────────────────────────┘ +//! │ +//! ▼ +//! ┌─────────────────────────────────────────────────────────────────┐ +//! │ TableSamplePlanner │ +//! 
│ (RelationPlanner: parses TABLESAMPLE, creates logical node) │ +//! └─────────────────────────────────────────────────────────────────┘ +//! │ +//! ▼ +//! ┌─────────────────────────────────────────────────────────────────┐ +//! │ TableSamplePlanNode │ +//! │ (UserDefinedLogicalNode: stores sampling params) │ +//! └─────────────────────────────────────────────────────────────────┘ +//! │ +//! ▼ +//! ┌─────────────────────────────────────────────────────────────────┐ +//! │ TableSampleExtensionPlanner │ +//! │ (ExtensionPlanner: creates physical execution plan) │ +//! └─────────────────────────────────────────────────────────────────┘ +//! │ +//! ▼ +//! ┌─────────────────────────────────────────────────────────────────┐ +//! │ SampleExec │ +//! │ (ExecutionPlan: performs actual row sampling at runtime) │ +//! └─────────────────────────────────────────────────────────────────┘ +//! ``` + +use std::{ + any::Any, + fmt::{self, Debug, Formatter}, + hash::{Hash, Hasher}, + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; + +use arrow::datatypes::{Float64Type, Int64Type}; +use arrow::{ + array::{ArrayRef, Int32Array, RecordBatch, StringArray, UInt32Array}, + compute, +}; +use arrow_schema::SchemaRef; +use futures::{ + ready, + stream::{Stream, StreamExt}, +}; +use rand::{Rng, SeedableRng, rngs::StdRng}; +use tonic::async_trait; + +use datafusion::optimizer::simplify_expressions::simplify_literal::parse_literal; +use datafusion::{ + execution::{ + RecordBatchStream, SendableRecordBatchStream, SessionState, SessionStateBuilder, + TaskContext, context::QueryPlanner, + }, + physical_expr::EquivalenceProperties, + physical_plan::{ + DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, + metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, RecordOutput}, + }, + physical_planner::{DefaultPhysicalPlanner, ExtensionPlanner, PhysicalPlanner}, + prelude::*, +}; +use datafusion_common::{ + DFSchemaRef, DataFusionError, Result, Statistics, internal_err, 
not_impl_err, + plan_datafusion_err, plan_err, tree_node::TreeNodeRecursion, +}; +use datafusion_expr::{ + UserDefinedLogicalNode, UserDefinedLogicalNodeCore, + logical_plan::{Extension, LogicalPlan, LogicalPlanBuilder}, + planner::{ + PlannedRelation, RelationPlanner, RelationPlannerContext, RelationPlanning, + }, +}; +use datafusion_sql::sqlparser::ast::{ + self, TableFactor, TableSampleMethod, TableSampleUnit, +}; +use insta::assert_snapshot; + +// ============================================================================ +// Example Entry Point +// ============================================================================ + +/// Runs the TABLESAMPLE examples demonstrating various sampling techniques. +pub async fn table_sample() -> Result<()> { + // Build session with custom query planner for physical planning + let state = SessionStateBuilder::new() + .with_default_features() + .with_query_planner(Arc::new(TableSampleQueryPlanner)) + .build(); + + let ctx = SessionContext::new_with_state(state); + + // Register custom relation planner for logical planning + ctx.register_relation_planner(Arc::new(TableSamplePlanner))?; + register_sample_data(&ctx)?; + + println!("TABLESAMPLE Example"); + println!("===================\n"); + + run_examples(&ctx).await +} + +async fn run_examples(ctx: &SessionContext) -> Result<()> { + // Example 1: Baseline - full table scan + let results = run_example( + ctx, + "Example 1: Full table (baseline)", + "SELECT * FROM sample_data", + ) + .await?; + assert_snapshot!(results, @r" + +---------+---------+ + | column1 | column2 | + +---------+---------+ + | 1 | row_1 | + | 2 | row_2 | + | 3 | row_3 | + | 4 | row_4 | + | 5 | row_5 | + | 6 | row_6 | + | 7 | row_7 | + | 8 | row_8 | + | 9 | row_9 | + | 10 | row_10 | + +---------+---------+ + "); + + // Example 2: Percentage-based Bernoulli sampling + // REPEATABLE(seed) ensures deterministic results for snapshot testing + let results = run_example( + ctx, + "Example 2: BERNOULLI 
percentage sampling", + "SELECT * FROM sample_data TABLESAMPLE BERNOULLI(30 PERCENT) REPEATABLE(123)", + ) + .await?; + assert_snapshot!(results, @r" + +---------+---------+ + | column1 | column2 | + +---------+---------+ + | 1 | row_1 | + | 2 | row_2 | + | 7 | row_7 | + | 8 | row_8 | + +---------+---------+ + "); + + // Example 3: Fractional sampling (0.0 to 1.0) + // REPEATABLE(seed) ensures deterministic results for snapshot testing + let results = run_example( + ctx, + "Example 3: Fractional sampling", + "SELECT * FROM sample_data TABLESAMPLE (0.5) REPEATABLE(456)", + ) + .await?; + assert_snapshot!(results, @r" + +---------+---------+ + | column1 | column2 | + +---------+---------+ + | 2 | row_2 | + | 4 | row_4 | + | 8 | row_8 | + +---------+---------+ + "); + + // Example 4: Row count limit (deterministic, no seed needed) + let results = run_example( + ctx, + "Example 4: Row count limit", + "SELECT * FROM sample_data TABLESAMPLE (3 ROWS)", + ) + .await?; + assert_snapshot!(results, @r" + +---------+---------+ + | column1 | column2 | + +---------+---------+ + | 1 | row_1 | + | 2 | row_2 | + | 3 | row_3 | + +---------+---------+ + "); + + // Example 5: Sampling combined with filtering + let results = run_example( + ctx, + "Example 5: Sampling with WHERE clause", + "SELECT * FROM sample_data TABLESAMPLE (5 ROWS) WHERE column1 > 2", + ) + .await?; + assert_snapshot!(results, @r" + +---------+---------+ + | column1 | column2 | + +---------+---------+ + | 3 | row_3 | + | 4 | row_4 | + | 5 | row_5 | + +---------+---------+ + "); + + // Example 6: Sampling in JOIN queries + // REPEATABLE(seed) ensures deterministic results for snapshot testing + let results = run_example( + ctx, + "Example 6: Sampling in JOINs", + r#"SELECT t1.column1, t2.column1, t1.column2, t2.column2 + FROM sample_data t1 TABLESAMPLE (0.7) REPEATABLE(789) + JOIN sample_data t2 TABLESAMPLE (0.7) REPEATABLE(123) + ON t1.column1 = t2.column1"#, + ) + .await?; + assert_snapshot!(results, @r" + 
+---------+---------+---------+---------+ + | column1 | column1 | column2 | column2 | + +---------+---------+---------+---------+ + | 2 | 2 | row_2 | row_2 | + | 5 | 5 | row_5 | row_5 | + | 7 | 7 | row_7 | row_7 | + | 8 | 8 | row_8 | row_8 | + | 10 | 10 | row_10 | row_10 | + +---------+---------+---------+---------+ + "); + + Ok(()) +} + +/// Helper to run a single example query and capture results. +async fn run_example(ctx: &SessionContext, title: &str, sql: &str) -> Result { + println!("{title}:\n{sql}\n"); + let df = ctx.sql(sql).await?; + println!("{}\n", df.logical_plan().display_indent()); + + let batches = df.collect().await?; + let results = arrow::util::pretty::pretty_format_batches(&batches)?.to_string(); + println!("{results}\n"); + + Ok(results) +} + +/// Register test data: 10 rows with column1=1..10 and column2="row_1".."row_10" +fn register_sample_data(ctx: &SessionContext) -> Result<()> { + let column1: ArrayRef = Arc::new(Int32Array::from((1..=10).collect::>())); + let column2: ArrayRef = Arc::new(StringArray::from( + (1..=10).map(|i| format!("row_{i}")).collect::>(), + )); + let batch = + RecordBatch::try_from_iter(vec![("column1", column1), ("column2", column2)])?; + ctx.register_batch("sample_data", batch)?; + Ok(()) +} + +// ============================================================================ +// Logical Planning: TableSamplePlanner + TableSamplePlanNode +// ============================================================================ + +/// Relation planner that intercepts `TABLESAMPLE` clauses in SQL and creates +/// [`TableSamplePlanNode`] logical nodes. 
+#[derive(Debug)] +struct TableSamplePlanner; + +impl RelationPlanner for TableSamplePlanner { + fn plan_relation( + &self, + relation: TableFactor, + context: &mut dyn RelationPlannerContext, + ) -> Result { + // Only handle Table relations with TABLESAMPLE clause + let TableFactor::Table { + sample: Some(sample), + alias, + name, + args, + with_hints, + version, + with_ordinality, + partitions, + json_path, + index_hints, + } = relation + else { + return Ok(RelationPlanning::Original(Box::new(relation))); + }; + + // Extract sample spec (handles both before/after alias positions) + let sample = match sample { + ast::TableSampleKind::BeforeTableAlias(s) + | ast::TableSampleKind::AfterTableAlias(s) => s, + }; + + // Validate sampling method + if let Some(method) = &sample.name + && *method != TableSampleMethod::Bernoulli + && *method != TableSampleMethod::Row + { + return not_impl_err!( + "Sampling method {} is not supported (only BERNOULLI and ROW)", + method + ); + } + + // Offset sampling (ClickHouse-style) not supported + if sample.offset.is_some() { + return not_impl_err!( + "TABLESAMPLE with OFFSET is not supported (requires total row count)" + ); + } + + // Parse optional REPEATABLE seed + let seed = sample + .seed + .map(|s| { + s.value.to_string().parse::().map_err(|_| { + plan_datafusion_err!("REPEATABLE seed must be an integer") + }) + }) + .transpose()?; + + // Plan the underlying table without the sample clause + let base_relation = TableFactor::Table { + sample: None, + alias: alias.clone(), + name, + args, + with_hints, + version, + with_ordinality, + partitions, + json_path, + index_hints, + }; + let input = context.plan(base_relation)?; + + // Handle bucket sampling (Hive-style: TABLESAMPLE(BUCKET x OUT OF y)) + if let Some(bucket) = sample.bucket { + if bucket.on.is_some() { + return not_impl_err!( + "TABLESAMPLE BUCKET with ON clause requires CLUSTERED BY table" + ); + } + let bucket_num: u64 = + bucket.bucket.to_string().parse().map_err(|_| { + 
plan_datafusion_err!("bucket number must be an integer") + })?; + let total: u64 = + bucket.total.to_string().parse().map_err(|_| { + plan_datafusion_err!("bucket total must be an integer") + })?; + + let fraction = bucket_num as f64 / total as f64; + let plan = TableSamplePlanNode::new(input, fraction, seed).into_plan(); + return Ok(RelationPlanning::Planned(Box::new(PlannedRelation::new( + plan, alias, + )))); + } + + // Handle quantity-based sampling + let Some(quantity) = sample.quantity else { + return plan_err!( + "TABLESAMPLE requires a quantity (percentage, fraction, or row count)" + ); + }; + let quantity_value_expr = context.sql_to_expr(quantity.value, input.schema())?; + + match quantity.unit { + // TABLESAMPLE (N ROWS) - exact row limit + Some(TableSampleUnit::Rows) => { + let rows: i64 = parse_literal::(&quantity_value_expr)?; + if rows < 0 { + return plan_err!("row count must be non-negative, got {}", rows); + } + let plan = LogicalPlanBuilder::from(input) + .limit(0, Some(rows as usize))? + .build()?; + Ok(RelationPlanning::Planned(Box::new(PlannedRelation::new( + plan, alias, + )))) + } + + // TABLESAMPLE (N PERCENT) - percentage sampling + Some(TableSampleUnit::Percent) => { + let percent: f64 = parse_literal::(&quantity_value_expr)?; + let fraction = percent / 100.0; + let plan = TableSamplePlanNode::new(input, fraction, seed).into_plan(); + Ok(RelationPlanning::Planned(Box::new(PlannedRelation::new( + plan, alias, + )))) + } + + // TABLESAMPLE (N) - fraction if <1.0, row limit if >=1.0 + None => { + let value = parse_literal::(&quantity_value_expr)?; + if value < 0.0 { + return plan_err!("sample value must be non-negative, got {}", value); + } + let plan = if value >= 1.0 { + // Interpret as row limit + LogicalPlanBuilder::from(input) + .limit(0, Some(value as usize))? + .build()? 
+ } else { + // Interpret as fraction + TableSamplePlanNode::new(input, value, seed).into_plan() + }; + Ok(RelationPlanning::Planned(Box::new(PlannedRelation::new( + plan, alias, + )))) + } + } + } +} + +/// Custom logical plan node representing a TABLESAMPLE operation. +/// +/// Stores sampling parameters (bounds, seed) and wraps the input plan. +/// Gets converted to [`SampleExec`] during physical planning. +#[derive(Debug, Clone, Hash, Eq, PartialEq, PartialOrd)] +struct TableSamplePlanNode { + input: LogicalPlan, + lower_bound: HashableF64, + upper_bound: HashableF64, + seed: u64, +} + +impl TableSamplePlanNode { + /// Create a new sampling node with the given fraction (0.0 to 1.0). + fn new(input: LogicalPlan, fraction: f64, seed: Option) -> Self { + Self { + input, + lower_bound: HashableF64(0.0), + upper_bound: HashableF64(fraction), + seed: seed.unwrap_or_else(rand::random), + } + } + + /// Wrap this node in a LogicalPlan::Extension. + fn into_plan(self) -> LogicalPlan { + LogicalPlan::Extension(Extension { + node: Arc::new(self), + }) + } +} + +impl UserDefinedLogicalNodeCore for TableSamplePlanNode { + fn name(&self) -> &str { + "TableSample" + } + + fn inputs(&self) -> Vec<&LogicalPlan> { + vec![&self.input] + } + + fn schema(&self) -> &DFSchemaRef { + self.input.schema() + } + + fn expressions(&self) -> Vec { + vec![] + } + + fn fmt_for_explain(&self, f: &mut Formatter) -> fmt::Result { + write!( + f, + "Sample: bounds=[{}, {}], seed={}", + self.lower_bound.0, self.upper_bound.0, self.seed + ) + } + + fn with_exprs_and_inputs( + &self, + _exprs: Vec, + mut inputs: Vec, + ) -> Result { + Ok(Self { + input: inputs.swap_remove(0), + lower_bound: self.lower_bound, + upper_bound: self.upper_bound, + seed: self.seed, + }) + } +} + +/// Wrapper for f64 that implements Hash and Eq (required for LogicalPlan). 
+#[derive(Debug, Clone, Copy, PartialOrd)] +struct HashableF64(f64); + +impl PartialEq for HashableF64 { + fn eq(&self, other: &Self) -> bool { + self.0.to_bits() == other.0.to_bits() + } +} + +impl Eq for HashableF64 {} + +impl Hash for HashableF64 { + fn hash(&self, state: &mut H) { + self.0.to_bits().hash(state); + } +} + +// ============================================================================ +// Physical Planning: TableSampleQueryPlanner + TableSampleExtensionPlanner +// ============================================================================ + +/// Custom query planner that registers [`TableSampleExtensionPlanner`] to +/// convert [`TableSamplePlanNode`] into [`SampleExec`]. +#[derive(Debug)] +struct TableSampleQueryPlanner; + +#[async_trait] +impl QueryPlanner for TableSampleQueryPlanner { + async fn create_physical_plan( + &self, + logical_plan: &LogicalPlan, + session_state: &SessionState, + ) -> Result> { + let planner = DefaultPhysicalPlanner::with_extension_planners(vec![Arc::new( + TableSampleExtensionPlanner, + )]); + planner + .create_physical_plan(logical_plan, session_state) + .await + } +} + +/// Extension planner that converts [`TableSamplePlanNode`] to [`SampleExec`]. 
+struct TableSampleExtensionPlanner; + +#[async_trait] +impl ExtensionPlanner for TableSampleExtensionPlanner { + async fn plan_extension( + &self, + _planner: &dyn PhysicalPlanner, + node: &dyn UserDefinedLogicalNode, + _logical_inputs: &[&LogicalPlan], + physical_inputs: &[Arc], + _session_state: &SessionState, + ) -> Result>> { + let Some(sample_node) = node.as_any().downcast_ref::() + else { + return Ok(None); + }; + + let exec = SampleExec::try_new( + Arc::clone(&physical_inputs[0]), + sample_node.lower_bound.0, + sample_node.upper_bound.0, + sample_node.seed, + )?; + Ok(Some(Arc::new(exec))) + } +} + +// ============================================================================ +// Physical Execution: SampleExec + BernoulliSampler +// ============================================================================ + +/// Physical execution plan that samples rows from its input using Bernoulli sampling. +/// +/// Each row is independently selected with probability `(upper_bound - lower_bound)` +/// and appears at most once. +#[derive(Debug, Clone)] +pub struct SampleExec { + input: Arc, + lower_bound: f64, + upper_bound: f64, + seed: u64, + metrics: ExecutionPlanMetricsSet, + cache: Arc, +} + +impl SampleExec { + /// Create a new SampleExec with Bernoulli sampling (without replacement). 
+ /// + /// # Arguments + /// * `input` - The input execution plan + /// * `lower_bound` - Lower bound of sampling range (typically 0.0) + /// * `upper_bound` - Upper bound of sampling range (0.0 to 1.0) + /// * `seed` - Random seed for reproducible sampling + pub fn try_new( + input: Arc, + lower_bound: f64, + upper_bound: f64, + seed: u64, + ) -> Result { + if lower_bound < 0.0 || upper_bound > 1.0 || lower_bound > upper_bound { + return internal_err!( + "Sampling bounds must satisfy 0.0 <= lower <= upper <= 1.0, got [{}, {}]", + lower_bound, + upper_bound + ); + } + + let cache = PlanProperties::new( + EquivalenceProperties::new(input.schema()), + input.properties().partitioning.clone(), + input.properties().emission_type, + input.properties().boundedness, + ); + + Ok(Self { + input, + lower_bound, + upper_bound, + seed, + metrics: ExecutionPlanMetricsSet::new(), + cache: Arc::new(cache), + }) + } + + /// Create a sampler for the given partition. + fn create_sampler(&self, partition: usize) -> BernoulliSampler { + let seed = self.seed.wrapping_add(partition as u64); + BernoulliSampler::new(self.lower_bound, self.upper_bound, seed) + } +} + +impl DisplayAs for SampleExec { + fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> fmt::Result { + write!( + f, + "SampleExec: bounds=[{}, {}], seed={}", + self.lower_bound, self.upper_bound, self.seed + ) + } +} + +impl ExecutionPlan for SampleExec { + fn name(&self) -> &'static str { + "SampleExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &Arc { + &self.cache + } + + fn maintains_input_order(&self) -> Vec { + // Sampling preserves row order (rows are filtered, not reordered) + vec![true] + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.input] + } + + fn with_new_children( + self: Arc, + mut children: Vec>, + ) -> Result> { + Ok(Arc::new(Self::try_new( + children.swap_remove(0), + self.lower_bound, + self.upper_bound, + self.seed, + )?)) + } + + fn execute( + &self, + 
partition: usize, + context: Arc, + ) -> Result { + Ok(Box::pin(SampleStream { + input: self.input.execute(partition, context)?, + sampler: self.create_sampler(partition), + metrics: BaselineMetrics::new(&self.metrics, partition), + })) + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } + + fn partition_statistics(&self, partition: Option) -> Result> { + let mut stats = Arc::unwrap_or_clone(self.input.partition_statistics(partition)?); + let ratio = self.upper_bound - self.lower_bound; + + // Scale statistics by sampling ratio (inexact due to randomness) + stats.num_rows = stats + .num_rows + .map(|n| (n as f64 * ratio) as usize) + .to_inexact(); + stats.total_byte_size = stats + .total_byte_size + .map(|n| (n as f64 * ratio) as usize) + .to_inexact(); + + Ok(Arc::new(stats)) + } + + fn apply_expressions( + &self, + f: &mut dyn FnMut( + &dyn datafusion::physical_plan::PhysicalExpr, + ) -> Result, + ) -> Result { + // Visit expressions in the output ordering from equivalence properties + let mut tnr = TreeNodeRecursion::Continue; + if let Some(ordering) = self.cache.output_ordering() { + for sort_expr in ordering { + tnr = tnr.visit_sibling(|| f(sort_expr.expr.as_ref()))?; + } + } + Ok(tnr) + } +} + +/// Bernoulli sampler: includes each row with probability `(upper - lower)`. +/// This is sampling **without replacement** - each row appears at most once. 
+struct BernoulliSampler { + lower_bound: f64, + upper_bound: f64, + rng: StdRng, +} + +impl BernoulliSampler { + fn new(lower_bound: f64, upper_bound: f64, seed: u64) -> Self { + Self { + lower_bound, + upper_bound, + rng: StdRng::seed_from_u64(seed), + } + } + + fn sample(&mut self, batch: &RecordBatch) -> Result { + let range = self.upper_bound - self.lower_bound; + if range <= 0.0 { + return Ok(RecordBatch::new_empty(batch.schema())); + } + + // Select rows where random value falls in [lower, upper) + let indices: Vec = (0..batch.num_rows()) + .filter(|_| { + let r: f64 = self.rng.random(); + r >= self.lower_bound && r < self.upper_bound + }) + .map(|i| i as u32) + .collect(); + + if indices.is_empty() { + return Ok(RecordBatch::new_empty(batch.schema())); + } + + compute::take_record_batch(batch, &UInt32Array::from(indices)) + .map_err(DataFusionError::from) + } +} + +/// Stream adapter that applies sampling to each batch. +struct SampleStream { + input: SendableRecordBatchStream, + sampler: BernoulliSampler, + metrics: BaselineMetrics, +} + +impl Stream for SampleStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + match ready!(self.input.poll_next_unpin(cx)) { + Some(Ok(batch)) => { + let elapsed = self.metrics.elapsed_compute().clone(); + let _timer = elapsed.timer(); + let result = self.sampler.sample(&batch); + Poll::Ready(Some(result.record_output(&self.metrics))) + } + Some(Err(e)) => Poll::Ready(Some(Err(e))), + None => Poll::Ready(None), + } + } +} + +impl RecordBatchStream for SampleStream { + fn schema(&self) -> SchemaRef { + self.input.schema() + } +} diff --git a/datafusion-examples/examples/sql_dialect.rs b/datafusion-examples/examples/sql_dialect.rs deleted file mode 100644 index 20b515506f3b4..0000000000000 --- a/datafusion-examples/examples/sql_dialect.rs +++ /dev/null @@ -1,134 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license 
agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::fmt::Display; - -use datafusion::error::{DataFusionError, Result}; -use datafusion::sql::{ - parser::{CopyToSource, CopyToStatement, DFParser, DFParserBuilder, Statement}, - sqlparser::{keywords::Keyword, tokenizer::Token}, -}; - -/// This example demonstrates how to use the DFParser to parse a statement in a custom way -/// -/// This technique can be used to implement a custom SQL dialect, for example. 
-#[tokio::main] -async fn main() -> Result<()> { - let mut my_parser = - MyParser::new("COPY source_table TO 'file.fasta' STORED AS FASTA")?; - - let my_statement = my_parser.parse_statement()?; - - match my_statement { - MyStatement::DFStatement(s) => println!("df: {s}"), - MyStatement::MyCopyTo(s) => println!("my_copy: {s}"), - } - - Ok(()) -} - -/// Here we define a Parser for our new SQL dialect that wraps the existing `DFParser` -struct MyParser<'a> { - df_parser: DFParser<'a>, -} - -impl<'a> MyParser<'a> { - fn new(sql: &'a str) -> Result { - let df_parser = DFParserBuilder::new(sql).build()?; - Ok(Self { df_parser }) - } - - /// Returns true if the next token is `COPY` keyword, false otherwise - fn is_copy(&self) -> bool { - matches!( - self.df_parser.parser.peek_token().token, - Token::Word(w) if w.keyword == Keyword::COPY - ) - } - - /// This is the entry point to our parser -- it handles `COPY` statements specially - /// but otherwise delegates to the existing DataFusion parser. 
- pub fn parse_statement(&mut self) -> Result { - if self.is_copy() { - self.df_parser.parser.next_token(); // COPY - let df_statement = self.df_parser.parse_copy()?; - - if let Statement::CopyTo(s) = df_statement { - Ok(MyStatement::from(s)) - } else { - Ok(MyStatement::DFStatement(Box::from(df_statement))) - } - } else { - let df_statement = self.df_parser.parse_statement()?; - Ok(MyStatement::from(df_statement)) - } - } -} - -enum MyStatement { - DFStatement(Box), - MyCopyTo(MyCopyToStatement), -} - -impl Display for MyStatement { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - MyStatement::DFStatement(s) => write!(f, "{s}"), - MyStatement::MyCopyTo(s) => write!(f, "{s}"), - } - } -} - -impl From for MyStatement { - fn from(s: Statement) -> Self { - Self::DFStatement(Box::from(s)) - } -} - -impl From for MyStatement { - fn from(s: CopyToStatement) -> Self { - if s.stored_as == Some("FASTA".to_string()) { - Self::MyCopyTo(MyCopyToStatement::from(s)) - } else { - Self::DFStatement(Box::from(Statement::CopyTo(s))) - } - } -} - -struct MyCopyToStatement { - pub source: CopyToSource, - pub target: String, -} - -impl From for MyCopyToStatement { - fn from(s: CopyToStatement) -> Self { - Self { - source: s.source, - target: s.target, - } - } -} - -impl Display for MyCopyToStatement { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "COPY {} TO '{}' STORED AS FASTA", - self.source, self.target - ) - } -} diff --git a/datafusion-examples/examples/sql_analysis.rs b/datafusion-examples/examples/sql_ops/analysis.rs similarity index 96% rename from datafusion-examples/examples/sql_analysis.rs rename to datafusion-examples/examples/sql_ops/analysis.rs index d3826026a9725..4243a2927865b 100644 --- a/datafusion-examples/examples/sql_analysis.rs +++ b/datafusion-examples/examples/sql_ops/analysis.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! 
See `main.rs` for how to run it. +//! //! This example shows how to use the structures that DataFusion provides to perform //! Analysis on SQL queries and their plans. //! @@ -23,8 +25,8 @@ use std::sync::Arc; -use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion::common::Result; +use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion::logical_expr::LogicalPlan; use datafusion::{ datasource::MemTable, @@ -32,141 +34,9 @@ use datafusion::{ }; use test_utils::tpcds::tpcds_schemas; -/// Counts the total number of joins in a plan -fn total_join_count(plan: &LogicalPlan) -> usize { - let mut total = 0; - - // We can use the TreeNode API to walk over a LogicalPlan. - plan.apply(|node| { - // if we encounter a join we update the running count - if matches!(node, LogicalPlan::Join(_)) { - total += 1; - } - Ok(TreeNodeRecursion::Continue) - }) - .unwrap(); - - total -} - -/// Counts the total number of joins in a plan and collects every join tree in -/// the plan with their respective join count. 
-/// -/// Join Tree Definition: the largest subtree consisting entirely of joins -/// -/// For example, this plan: -/// -/// ```text -/// JOIN -/// / \ -/// A JOIN -/// / \ -/// B C -/// ``` -/// -/// has a single join tree `(A-B-C)` which will result in `(2, [2])` -/// -/// This plan: -/// -/// ```text -/// JOIN -/// / \ -/// A GROUP -/// | -/// JOIN -/// / \ -/// B C -/// ``` -/// -/// Has two join trees `(A-, B-C)` which will result in `(2, [1, 1])` -fn count_trees(plan: &LogicalPlan) -> (usize, Vec) { - // this works the same way as `total_count`, but now when we encounter a Join - // we try to collect it's entire tree - let mut to_visit = vec![plan]; - let mut total = 0; - let mut groups = vec![]; - - while let Some(node) = to_visit.pop() { - // if we encounter a join, we know were at the root of the tree - // count this tree and recurse on it's inputs - if matches!(node, LogicalPlan::Join(_)) { - let (group_count, inputs) = count_tree(node); - total += group_count; - groups.push(group_count); - to_visit.extend(inputs); - } else { - to_visit.extend(node.inputs()); - } - } - - (total, groups) -} - -/// Count the entire join tree and return its inputs using TreeNode API -/// -/// For example, if this function receives following plan: -/// -/// ```text -/// JOIN -/// / \ -/// A GROUP -/// | -/// JOIN -/// / \ -/// B C -/// ``` -/// -/// It will return `(1, [A, GROUP])` -fn count_tree(join: &LogicalPlan) -> (usize, Vec<&LogicalPlan>) { - let mut inputs = Vec::new(); - let mut total = 0; - - join.apply(|node| { - // Some extra knowledge: - // - // optimized plans have their projections pushed down as far as - // possible, which sometimes results in a projection going in between 2 - // subsequent joins giving the illusion these joins are not "related", - // when in fact they are. 
- // - // This plan: - // JOIN - // / \ - // A PROJECTION - // | - // JOIN - // / \ - // B C - // - // is the same as: - // - // JOIN - // / \ - // A JOIN - // / \ - // B C - // we can continue the recursion in this case - if let LogicalPlan::Projection(_) = node { - return Ok(TreeNodeRecursion::Continue); - } - - // any join we count - if matches!(node, LogicalPlan::Join(_)) { - total += 1; - Ok(TreeNodeRecursion::Continue) - } else { - inputs.push(node); - // skip children of input node - Ok(TreeNodeRecursion::Jump) - } - }) - .unwrap(); - - (total, inputs) -} - -#[tokio::main] -async fn main() -> Result<()> { +/// Demonstrates how to analyze a SQL query by counting JOINs and identifying +/// join-trees using DataFusion’s `LogicalPlan` and `TreeNode` API. +pub async fn analysis() -> Result<()> { // To show how we can count the joins in a sql query we'll be using query 88 // from the TPC-DS benchmark. // @@ -274,7 +144,10 @@ from for table in tables { ctx.register_table( table.name, - Arc::new(MemTable::try_new(Arc::new(table.schema.clone()), vec![])?), + Arc::new(MemTable::try_new( + Arc::new(table.schema.clone()), + vec![vec![]], + )?), )?; } // We can create a LogicalPlan from a SQL query like this @@ -307,3 +180,136 @@ from Ok(()) } + +/// Counts the total number of joins in a plan +fn total_join_count(plan: &LogicalPlan) -> usize { + let mut total = 0; + + // We can use the TreeNode API to walk over a LogicalPlan. + plan.apply(|node| { + // if we encounter a join we update the running count + if matches!(node, LogicalPlan::Join(_)) { + total += 1; + } + Ok(TreeNodeRecursion::Continue) + }) + .unwrap(); + + total +} + +/// Counts the total number of joins in a plan and collects every join tree in +/// the plan with their respective join count. 
+/// +/// Join Tree Definition: the largest subtree consisting entirely of joins +/// +/// For example, this plan: +/// +/// ```text +/// JOIN +/// / \ +/// A JOIN +/// / \ +/// B C +/// ``` +/// +/// has a single join tree `(A-B-C)` which will result in `(2, [2])` +/// +/// This plan: +/// +/// ```text +/// JOIN +/// / \ +/// A GROUP +/// | +/// JOIN +/// / \ +/// B C +/// ``` +/// +/// Has two join trees `(A-, B-C)` which will result in `(2, [1, 1])` +fn count_trees(plan: &LogicalPlan) -> (usize, Vec) { + // this works the same way as `total_count`, but now when we encounter a Join + // we try to collect it's entire tree + let mut to_visit = vec![plan]; + let mut total = 0; + let mut groups = vec![]; + + while let Some(node) = to_visit.pop() { + // if we encounter a join, we know were at the root of the tree + // count this tree and recurse on it's inputs + if matches!(node, LogicalPlan::Join(_)) { + let (group_count, inputs) = count_tree(node); + total += group_count; + groups.push(group_count); + to_visit.extend(inputs); + } else { + to_visit.extend(node.inputs()); + } + } + + (total, groups) +} + +/// Count the entire join tree and return its inputs using TreeNode API +/// +/// For example, if this function receives following plan: +/// +/// ```text +/// JOIN +/// / \ +/// A GROUP +/// | +/// JOIN +/// / \ +/// B C +/// ``` +/// +/// It will return `(1, [A, GROUP])` +fn count_tree(join: &LogicalPlan) -> (usize, Vec<&LogicalPlan>) { + let mut inputs = Vec::new(); + let mut total = 0; + + join.apply(|node| { + // Some extra knowledge: + // + // optimized plans have their projections pushed down as far as + // possible, which sometimes results in a projection going in between 2 + // subsequent joins giving the illusion these joins are not "related", + // when in fact they are. 
+ // + // This plan: + // JOIN + // / \ + // A PROJECTION + // | + // JOIN + // / \ + // B C + // + // is the same as: + // + // JOIN + // / \ + // A JOIN + // / \ + // B C + // we can continue the recursion in this case + if let LogicalPlan::Projection(_) = node { + return Ok(TreeNodeRecursion::Continue); + } + + // any join we count + if matches!(node, LogicalPlan::Join(_)) { + total += 1; + Ok(TreeNodeRecursion::Continue) + } else { + inputs.push(node); + // skip children of input node + Ok(TreeNodeRecursion::Jump) + } + }) + .unwrap(); + + (total, inputs) +} diff --git a/datafusion-examples/examples/sql_ops/custom_sql_parser.rs b/datafusion-examples/examples/sql_ops/custom_sql_parser.rs new file mode 100644 index 0000000000000..308a0de62a242 --- /dev/null +++ b/datafusion-examples/examples/sql_ops/custom_sql_parser.rs @@ -0,0 +1,420 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This example demonstrates extending the DataFusion SQL parser to support +//! custom DDL statements, specifically `CREATE EXTERNAL CATALOG`. +//! +//! ### Custom Syntax +//! ```sql +//! CREATE EXTERNAL CATALOG my_catalog +//! STORED AS ICEBERG +//! LOCATION 's3://my-bucket/warehouse/' +//! OPTIONS ( +//! 
'region' = 'us-west-2' +//! ); +//! ``` +//! +//! Note: For the purpose of this example, we use `local://workspace/` to +//! automatically discover and register files from the project's test data. + +use std::collections::HashMap; +use std::fmt::Display; +use std::sync::Arc; + +use datafusion::catalog::{ + CatalogProvider, MemoryCatalogProvider, MemorySchemaProvider, SchemaProvider, + TableProviderFactory, +}; +use datafusion::datasource::listing_table_factory::ListingTableFactory; +use datafusion::error::{DataFusionError, Result}; +use datafusion::prelude::SessionContext; +use datafusion::sql::{ + parser::{DFParser, DFParserBuilder, Statement}, + sqlparser::{ + ast::{ObjectName, Value}, + keywords::Keyword, + tokenizer::Token, + }, +}; +use datafusion_common::{DFSchema, TableReference, plan_datafusion_err, plan_err}; +use datafusion_expr::CreateExternalTable; +use futures::StreamExt; +use insta::assert_snapshot; +use object_store::ObjectStore; +use object_store::local::LocalFileSystem; + +/// Entry point for the example. +pub async fn custom_sql_parser() -> Result<()> { + // Use standard Parquet testing data as our "external" source. 
+ let base_path = datafusion::common::test_util::parquet_test_data(); + let base_path = std::path::Path::new(&base_path).canonicalize()?; + + // Make the path relative to the workspace root + let workspace_root = workspace_root(); + let location = base_path + .strip_prefix(&workspace_root) + .map(|p| p.to_string_lossy().to_string()) + .unwrap_or_else(|_| base_path.to_string_lossy().to_string()); + + let create_catalog_sql = format!( + "CREATE EXTERNAL CATALOG parquet_testing + STORED AS parquet + LOCATION 'local://workspace/{location}' + OPTIONS ( + 'schema_name' = 'staged_data', + 'format.pruning' = 'true' + )" + ); + + // ========================================================================= + // Part 1: Standard DataFusion parser rejects the custom DDL + // ========================================================================= + println!("=== Part 1: Standard DataFusion Parser ===\n"); + println!("Parsing: {}\n", create_catalog_sql.trim()); + + let ctx_standard = SessionContext::new(); + let err = ctx_standard + .sql(&create_catalog_sql) + .await + .expect_err("Expected the standard parser to reject CREATE EXTERNAL CATALOG (custom DDL syntax)"); + + println!("Error: {err}\n"); + assert_snapshot!(err.to_string(), @r#"SQL error: ParserError("Expected: TABLE, found: CATALOG at Line: 1, Column: 17")"#); + + // ========================================================================= + // Part 2: Custom parser handles the statement + // ========================================================================= + println!("=== Part 2: Custom Parser ===\n"); + println!("Parsing: {}\n", create_catalog_sql.trim()); + + let ctx = SessionContext::new(); + + let mut parser = CustomParser::new(&create_catalog_sql)?; + let statement = parser.parse_statement()?; + match statement { + CustomStatement::CreateExternalCatalog(stmt) => { + handle_create_external_catalog(&ctx, stmt).await?; + } + CustomStatement::DFStatement(_) => { + panic!("Expected CreateExternalCatalog 
statement"); + } + } + + // Query a table from the registered catalog + let query_sql = "SELECT id, bool_col, tinyint_col FROM parquet_testing.staged_data.alltypes_plain LIMIT 5"; + println!("Executing: {query_sql}\n"); + + let results = execute_sql(&ctx, query_sql).await?; + println!("{results}"); + assert_snapshot!(results, @r" + +----+----------+-------------+ + | id | bool_col | tinyint_col | + +----+----------+-------------+ + | 4 | true | 0 | + | 5 | false | 1 | + | 6 | true | 0 | + | 7 | false | 1 | + | 2 | true | 0 | + +----+----------+-------------+ + "); + + Ok(()) +} + +/// Execute SQL and return formatted results. +async fn execute_sql(ctx: &SessionContext, sql: &str) -> Result { + let batches = ctx.sql(sql).await?.collect().await?; + Ok(arrow::util::pretty::pretty_format_batches(&batches)?.to_string()) +} + +/// Custom handler for the `CREATE EXTERNAL CATALOG` statement. +async fn handle_create_external_catalog( + ctx: &SessionContext, + stmt: CreateExternalCatalog, +) -> Result<()> { + let factory = ListingTableFactory::new(); + let catalog = Arc::new(MemoryCatalogProvider::new()); + let schema = Arc::new(MemorySchemaProvider::new()); + + // Extract options + let mut schema_name = "public".to_string(); + let mut table_options = HashMap::new(); + + for (k, v) in stmt.options { + let val_str = match v { + Value::SingleQuotedString(ref s) | Value::DoubleQuotedString(ref s) => { + s.to_string() + } + Value::Number(ref n, _) => n.to_string(), + Value::Boolean(b) => b.to_string(), + _ => v.to_string(), + }; + + if k == "schema_name" { + schema_name = val_str; + } else { + table_options.insert(k, val_str); + } + } + + println!(" Target Catalog: {}", stmt.name); + println!(" Data Location: {}", stmt.location); + println!(" Resolved Schema: {schema_name}"); + + // Register a local object store rooted at the workspace root. + // We use a specific authority 'workspace' to ensure consistent resolution. 
+ let store = Arc::new(LocalFileSystem::new_with_prefix(workspace_root())?); + let store_url = url::Url::parse("local://workspace").unwrap(); + ctx.register_object_store(&store_url, Arc::clone(&store) as _); + + let target_ext = format!(".{}", stmt.catalog_type.to_lowercase()); + + // For 'local://workspace/parquet-testing/data', the path is 'parquet-testing/data'. + let path_str = stmt + .location + .strip_prefix("local://workspace/") + .unwrap_or(&stmt.location); + let prefix = object_store::path::Path::from(path_str); + + // Discover data files using the ObjectStore API + let mut table_count = 0; + let mut list_stream = store.list(Some(&prefix)); + + while let Some(meta) = list_stream.next().await { + let meta = meta?; + let path = &meta.location; + + if path.as_ref().ends_with(&target_ext) { + let name = std::path::Path::new(path.as_ref()) + .file_stem() + .unwrap() + .to_string_lossy() + .to_string(); + + let table_url = format!("local://workspace/{path}"); + + let cmd = CreateExternalTable::builder( + TableReference::bare(name.clone()), + table_url, + stmt.catalog_type.clone(), + Arc::new(DFSchema::empty()), + ) + .with_options(table_options.clone()) + .build(); + + match factory.create(&ctx.state(), &cmd).await { + Ok(table) => { + schema.register_table(name, table)?; + table_count += 1; + } + Err(e) => { + eprintln!("Failed to create table {name}: {e}"); + } + } + } + } + println!(" Registered {table_count} tables into schema: {schema_name}"); + + catalog.register_schema(&schema_name, schema)?; + ctx.register_catalog(stmt.name.to_string(), catalog); + + Ok(()) +} + +/// Possible statements returned by our custom parser. +#[derive(Debug, Clone)] +pub enum CustomStatement { + /// Standard DataFusion statement + DFStatement(Box), + /// Custom `CREATE EXTERNAL CATALOG` statement + CreateExternalCatalog(CreateExternalCatalog), +} + +/// Data structure for `CREATE EXTERNAL CATALOG`. 
+#[derive(Debug, Clone)] +pub struct CreateExternalCatalog { + pub name: ObjectName, + pub catalog_type: String, + pub location: String, + pub options: Vec<(String, Value)>, +} + +impl Display for CustomStatement { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::DFStatement(s) => write!(f, "{s}"), + Self::CreateExternalCatalog(s) => write!(f, "{s}"), + } + } +} + +impl Display for CreateExternalCatalog { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "CREATE EXTERNAL CATALOG {} STORED AS {} LOCATION '{}'", + self.name, self.catalog_type, self.location + )?; + if !self.options.is_empty() { + write!(f, " OPTIONS (")?; + for (i, (k, v)) in self.options.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "'{k}' = '{v}'")?; + } + write!(f, ")")?; + } + Ok(()) + } +} + +/// A parser that extends `DFParser` with custom syntax. +struct CustomParser<'a> { + df_parser: DFParser<'a>, +} + +impl<'a> CustomParser<'a> { + fn new(sql: &'a str) -> Result { + Ok(Self { + df_parser: DFParserBuilder::new(sql).build()?, + }) + } + + pub fn parse_statement(&mut self) -> Result { + if self.is_create_external_catalog() { + return self.parse_create_external_catalog(); + } + Ok(CustomStatement::DFStatement(Box::new( + self.df_parser.parse_statement()?, + ))) + } + + fn is_create_external_catalog(&self) -> bool { + let t1 = &self.df_parser.parser.peek_nth_token(0).token; + let t2 = &self.df_parser.parser.peek_nth_token(1).token; + let t3 = &self.df_parser.parser.peek_nth_token(2).token; + + matches!(t1, Token::Word(w) if w.keyword == Keyword::CREATE) + && matches!(t2, Token::Word(w) if w.keyword == Keyword::EXTERNAL) + && matches!(t3, Token::Word(w) if w.value.to_uppercase() == "CATALOG") + } + + fn parse_create_external_catalog(&mut self) -> Result { + // Consume prefix tokens: CREATE EXTERNAL CATALOG + for _ in 0..3 { + self.df_parser.parser.next_token(); + } + + let name = self + 
.df_parser + .parser + .parse_object_name(false) + .map_err(|e| DataFusionError::External(Box::new(e)))?; + + let mut catalog_type = None; + let mut location = None; + let mut options = vec![]; + + while let Some(keyword) = self.df_parser.parser.parse_one_of_keywords(&[ + Keyword::STORED, + Keyword::LOCATION, + Keyword::OPTIONS, + ]) { + match keyword { + Keyword::STORED => { + if catalog_type.is_some() { + return plan_err!("Duplicate STORED AS"); + } + self.df_parser + .parser + .expect_keyword(Keyword::AS) + .map_err(|e| DataFusionError::External(Box::new(e)))?; + catalog_type = Some( + self.df_parser + .parser + .parse_identifier() + .map_err(|e| DataFusionError::External(Box::new(e)))? + .value, + ); + } + Keyword::LOCATION => { + if location.is_some() { + return plan_err!("Duplicate LOCATION"); + } + location = Some( + self.df_parser + .parser + .parse_literal_string() + .map_err(|e| DataFusionError::External(Box::new(e)))?, + ); + } + Keyword::OPTIONS => { + if !options.is_empty() { + return plan_err!("Duplicate OPTIONS"); + } + options = self.parse_value_options()?; + } + _ => unreachable!(), + } + } + + Ok(CustomStatement::CreateExternalCatalog( + CreateExternalCatalog { + name, + catalog_type: catalog_type + .ok_or_else(|| plan_datafusion_err!("Missing STORED AS"))?, + location: location + .ok_or_else(|| plan_datafusion_err!("Missing LOCATION"))?, + options, + }, + )) + } + + /// Parse options in the form: (key [=] value, key [=] value, ...) 
+ fn parse_value_options(&mut self) -> Result> { + let mut options = vec![]; + self.df_parser + .parser + .expect_token(&Token::LParen) + .map_err(|e| DataFusionError::External(Box::new(e)))?; + + loop { + let key = self.df_parser.parse_option_key()?; + // Support optional '=' between key and value + let _ = self.df_parser.parser.consume_token(&Token::Eq); + let value = self.df_parser.parse_option_value()?; + options.push((key, value)); + + let comma = self.df_parser.parser.consume_token(&Token::Comma); + if self.df_parser.parser.consume_token(&Token::RParen) { + break; + } else if !comma { + return plan_err!("Expected ',' or ')' in OPTIONS"); + } + } + Ok(options) + } +} + +/// Returns the workspace root directory (parent of datafusion-examples). +fn workspace_root() -> std::path::PathBuf { + std::path::Path::new(env!("CARGO_MANIFEST_DIR")) + .parent() + .expect("CARGO_MANIFEST_DIR should have a parent") + .to_path_buf() +} diff --git a/datafusion-examples/examples/sql_frontend.rs b/datafusion-examples/examples/sql_ops/frontend.rs similarity index 98% rename from datafusion-examples/examples/sql_frontend.rs rename to datafusion-examples/examples/sql_ops/frontend.rs index 3955d5038cfb0..025fe47e75b07 100644 --- a/datafusion-examples/examples/sql_frontend.rs +++ b/datafusion-examples/examples/sql_ops/frontend.rs @@ -15,8 +15,10 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. + use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; -use datafusion::common::{plan_err, TableReference}; +use datafusion::common::{TableReference, plan_err}; use datafusion::config::ConfigOptions; use datafusion::error::Result; use datafusion::logical_expr::{ @@ -44,7 +46,7 @@ use std::sync::Arc; /// /// In this example, we demonstrate how to use the lower level APIs directly, /// which only requires the `datafusion-sql` dependency. 
-pub fn main() -> Result<()> { +pub fn frontend() -> Result<()> { // First, we parse the SQL string. Note that we use the DataFusion // Parser, which wraps the `sqlparser-rs` SQL parser and adds DataFusion // specific syntax such as `CREATE EXTERNAL TABLE` @@ -83,7 +85,7 @@ pub fn main() -> Result<()> { let config = OptimizerContext::default().with_skip_failing_rules(false); let analyzed_plan = Analyzer::new().execute_and_check( logical_plan, - config.options(), + &config.options(), observe_analyzer, )?; // Note that the Analyzer has added a CAST to the plan to align the types diff --git a/datafusion-examples/examples/sql_ops/main.rs b/datafusion-examples/examples/sql_ops/main.rs new file mode 100644 index 0000000000000..ce7be8fa2bada --- /dev/null +++ b/datafusion-examples/examples/sql_ops/main.rs @@ -0,0 +1,102 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! # SQL Examples +//! +//! These examples demonstrate SQL operations in DataFusion. +//! +//! ## Usage +//! ```bash +//! cargo run --example sql_ops -- [all|analysis|custom_sql_parser|frontend|query] +//! ``` +//! +//! Each subcommand runs a corresponding example: +//! - `all` — run all examples included in this module +//! +//! - `analysis` +//! 
(file: analysis.rs, desc: Analyze SQL queries) +//! +//! - `custom_sql_parser` +//! (file: custom_sql_parser.rs, desc: Implement a custom SQL parser to extend DataFusion) +//! +//! - `frontend` +//! (file: frontend.rs, desc: Build LogicalPlans from SQL) +//! +//! - `query` +//! (file: query.rs, desc: Query data using SQL) + +mod analysis; +mod custom_sql_parser; +mod frontend; +mod query; + +use datafusion::error::{DataFusionError, Result}; +use strum::{IntoEnumIterator, VariantNames}; +use strum_macros::{Display, EnumIter, EnumString, VariantNames}; + +#[derive(EnumIter, EnumString, Display, VariantNames)] +#[strum(serialize_all = "snake_case")] +enum ExampleKind { + All, + Analysis, + CustomSqlParser, + Frontend, + Query, +} + +impl ExampleKind { + const EXAMPLE_NAME: &str = "sql_ops"; + + fn runnable() -> impl Iterator { + ExampleKind::iter().filter(|v| !matches!(v, ExampleKind::All)) + } + + async fn run(&self) -> Result<()> { + match self { + ExampleKind::All => { + for example in ExampleKind::runnable() { + println!("Running example: {example}"); + Box::pin(example.run()).await?; + } + } + ExampleKind::Analysis => analysis::analysis().await?, + ExampleKind::CustomSqlParser => { + custom_sql_parser::custom_sql_parser().await? + } + ExampleKind::Frontend => frontend::frontend()?, + ExampleKind::Query => query::query().await?, + } + Ok(()) + } +} + +#[tokio::main] +async fn main() -> Result<()> { + let usage = format!( + "Usage: cargo run --example {} -- [{}]", + ExampleKind::EXAMPLE_NAME, + ExampleKind::VARIANTS.join("|") + ); + + let example: ExampleKind = std::env::args() + .nth(1) + .unwrap_or_else(|| ExampleKind::All.to_string()) + .parse() + .map_err(|_| DataFusionError::Execution(format!("Unknown example. 
{usage}")))?; + + example.run().await +} diff --git a/datafusion-examples/examples/sql_query.rs b/datafusion-examples/examples/sql_ops/query.rs similarity index 66% rename from datafusion-examples/examples/sql_query.rs rename to datafusion-examples/examples/sql_ops/query.rs index 0ac203cfb7e74..60b47c36b9ae2 100644 --- a/datafusion-examples/examples/sql_query.rs +++ b/datafusion-examples/examples/sql_ops/query.rs @@ -15,26 +15,27 @@ // specific language governing permissions and limitations // under the License. -use datafusion::arrow::array::{UInt64Array, UInt8Array}; +//! See `main.rs` for how to run it. + +use std::sync::Arc; + +use datafusion::arrow::array::{UInt8Array, UInt64Array}; use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion::arrow::record_batch::RecordBatch; +use datafusion::catalog::MemTable; use datafusion::common::{assert_batches_eq, exec_datafusion_err}; use datafusion::datasource::file_format::parquet::ParquetFormat; use datafusion::datasource::listing::ListingOptions; -use datafusion::datasource::MemTable; use datafusion::error::{DataFusionError, Result}; use datafusion::prelude::*; +use datafusion_examples::utils::{datasets::ExampleDataset, write_csv_to_parquet}; use object_store::local::LocalFileSystem; -use std::path::Path; -use std::sync::Arc; /// Examples of various ways to execute queries using SQL /// /// [`query_memtable`]: a simple query against a [`MemTable`] /// [`query_parquet`]: a simple query against a directory with multiple Parquet files -/// -#[tokio::main] -async fn main() -> Result<()> { +pub async fn query() -> Result<()> { query_memtable().await?; query_parquet().await?; Ok(()) @@ -113,32 +114,33 @@ async fn query_parquet() -> Result<()> { // create local execution context let ctx = SessionContext::new(); - let test_data = datafusion::test_util::parquet_test_data(); + // Convert the CSV input into a temporary Parquet directory for querying + let dataset = ExampleDataset::Cars; + let 
parquet_temp = write_csv_to_parquet(&ctx, &dataset.path()).await?; // Configure listing options let file_format = ParquetFormat::default().with_enable_pruning(true); - let listing_options = ListingOptions::new(Arc::new(file_format)) - // This is a workaround for this example since `test_data` contains - // many different parquet different files, - // in practice use FileType::PARQUET.get_ext(). - .with_file_extension("alltypes_plain.parquet"); + let listing_options = + ListingOptions::new(Arc::new(file_format)).with_file_extension(".parquet"); + + let table_path = parquet_temp.file_uri()?; // First example were we use an absolute path, which requires no additional setup. ctx.register_listing_table( "my_table", - &format!("file://{test_data}/"), + &table_path, listing_options.clone(), None, None, ) - .await - .unwrap(); + .await?; // execute the query let df = ctx .sql( "SELECT * \ FROM my_table \ + ORDER BY speed \ LIMIT 1", ) .await?; @@ -147,20 +149,22 @@ async fn query_parquet() -> Result<()> { let results = df.collect().await?; assert_batches_eq!( [ - "+----+----------+-------------+--------------+---------+------------+-----------+------------+------------------+------------+---------------------+", - "| id | bool_col | tinyint_col | smallint_col | int_col | bigint_col | float_col | double_col | date_string_col | string_col | timestamp_col |", - "+----+----------+-------------+--------------+---------+------------+-----------+------------+------------------+------------+---------------------+", - "| 4 | true | 0 | 0 | 0 | 0 | 0.0 | 0.0 | 30332f30312f3039 | 30 | 2009-03-01T00:00:00 |", - "+----+----------+-------------+--------------+---------+------------+-----------+------------+------------------+------------+---------------------+", + "+-----+-------+---------------------+", + "| car | speed | time |", + "+-----+-------+---------------------+", + "| red | 0.0 | 1996-04-12T12:05:15 |", + "+-----+-------+---------------------+", ], - &results); + &results + 
); - // Second example were we temporarily move into the test data's parent directory and - // simulate a relative path, this requires registering an ObjectStore. + // Second example where we change the current working directory and explicitly + // register a local filesystem object store. This demonstrates how listing tables + // resolve paths via an ObjectStore, even when using filesystem-backed data. let cur_dir = std::env::current_dir()?; - - let test_data_path = Path::new(&test_data); - let test_data_path_parent = test_data_path + let test_data_path_parent = parquet_temp + .tmp_dir + .path() .parent() .ok_or(exec_datafusion_err!("test_data path needs a parent"))?; @@ -168,15 +172,15 @@ async fn query_parquet() -> Result<()> { let local_fs = Arc::new(LocalFileSystem::default()); - let u = url::Url::parse("file://./") + let url = url::Url::parse("file://./") .map_err(|e| DataFusionError::External(Box::new(e)))?; - ctx.register_object_store(&u, local_fs); + ctx.register_object_store(&url, local_fs); // Register a listing table - this will use all files in the directory as data sources // for the query ctx.register_listing_table( "relative_table", - "./data", + parquet_temp.path_str()?, listing_options.clone(), None, None, @@ -188,6 +192,7 @@ async fn query_parquet() -> Result<()> { .sql( "SELECT * \ FROM relative_table \ + ORDER BY speed \ LIMIT 1", ) .await?; @@ -196,13 +201,14 @@ async fn query_parquet() -> Result<()> { let results = df.collect().await?; assert_batches_eq!( [ - "+----+----------+-------------+--------------+---------+------------+-----------+------------+------------------+------------+---------------------+", - "| id | bool_col | tinyint_col | smallint_col | int_col | bigint_col | float_col | double_col | date_string_col | string_col | timestamp_col |", - "+----+----------+-------------+--------------+---------+------------+-----------+------------+------------------+------------+---------------------+", - "| 4 | true | 0 | 0 | 0 | 0 | 0.0 | 
0.0 | 30332f30312f3039 | 30 | 2009-03-01T00:00:00 |", - "+----+----------+-------------+--------------+---------+------------+-----------+------------+------------------+------------+---------------------+", + "+-----+-------+---------------------+", + "| car | speed | time |", + "+-----+-------+---------------------+", + "| red | 0.0 | 1996-04-12T12:05:15 |", + "+-----+-------+---------------------+", ], - &results); + &results + ); // Reset the current directory std::env::set_current_dir(cur_dir)?; diff --git a/datafusion-examples/examples/advanced_udaf.rs b/datafusion-examples/examples/udf/advanced_udaf.rs similarity index 95% rename from datafusion-examples/examples/advanced_udaf.rs rename to datafusion-examples/examples/udf/advanced_udaf.rs index 7b1d3e94b2efe..89f621d30e18d 100644 --- a/datafusion-examples/examples/advanced_udaf.rs +++ b/datafusion-examples/examples/udf/advanced_udaf.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. 
+ use arrow::datatypes::{Field, Schema}; use datafusion::physical_expr::NullState; use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility}; @@ -26,13 +28,13 @@ use arrow::array::{ use arrow::datatypes::{ArrowNativeTypeOp, ArrowPrimitiveType, Float64Type, UInt32Type}; use arrow::record_batch::RecordBatch; use arrow_schema::FieldRef; -use datafusion::common::{cast::as_float64_array, ScalarValue}; +use datafusion::common::{ScalarValue, cast::as_float64_array}; use datafusion::error::Result; use datafusion::logical_expr::{ + Accumulator, AggregateUDF, AggregateUDFImpl, EmitTo, GroupsAccumulator, Signature, expr::AggregateFunction, function::{AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs}, - simplify::SimplifyInfo, - Accumulator, AggregateUDF, AggregateUDFImpl, EmitTo, GroupsAccumulator, Signature, + simplify::SimplifyContext, }; use datafusion::prelude::*; @@ -41,7 +43,7 @@ use datafusion::prelude::*; /// a function `accumulator` that returns the `Accumulator` instance. /// /// To do so, we must implement the `AggregateUDFImpl` trait. 
-#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] struct GeoMeanUdaf { signature: Signature, } @@ -312,12 +314,16 @@ impl GroupsAccumulator for GeometricMeanGroupsAccumulator { let prods = emit_to.take_needed(&mut self.prods); let nulls = self.null_state.build(emit_to); - assert_eq!(nulls.len(), prods.len()); + if let Some(nulls) = &nulls { + assert_eq!(nulls.len(), counts.len()); + } assert_eq!(counts.len(), prods.len()); // don't evaluate geometric mean with null inputs to avoid errors on null values - let array: PrimitiveArray = if nulls.null_count() > 0 { + let array: PrimitiveArray = if let Some(nulls) = &nulls + && nulls.null_count() > 0 + { let mut builder = PrimitiveBuilder::::with_capacity(nulls.len()); let iter = prods.into_iter().zip(counts).zip(nulls.iter()); @@ -335,7 +341,7 @@ impl GroupsAccumulator for GeometricMeanGroupsAccumulator { .zip(counts) .map(|(prod, count)| prod.powf(1.0 / count as f64)) .collect::>(); - PrimitiveArray::new(geo_mean.into(), Some(nulls)) // no copy + PrimitiveArray::new(geo_mean.into(), nulls) // no copy .with_data_type(self.return_data_type.clone()) }; @@ -345,7 +351,6 @@ impl GroupsAccumulator for GeometricMeanGroupsAccumulator { // return arrays for counts and prods fn state(&mut self, emit_to: EmitTo) -> Result> { let nulls = self.null_state.build(emit_to); - let nulls = Some(nulls); let counts = emit_to.take_needed(&mut self.counts); let counts = UInt32Array::new(counts.into(), nulls.clone()); // zero copy @@ -368,7 +373,7 @@ impl GroupsAccumulator for GeometricMeanGroupsAccumulator { /// This example shows how to use the AggregateUDFImpl::simplify API to simplify/replace user /// defined aggregate function with a different expression which is defined in the `simplify` method. 
-#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] struct SimplifiedGeoMeanUdaf { signature: Signature, } @@ -419,7 +424,7 @@ impl AggregateUDFImpl for SimplifiedGeoMeanUdaf { /// Optionally replaces a UDAF with another expression during query optimization. fn simplify(&self) -> Option { - let simplify = |aggregate_function: AggregateFunction, _: &dyn SimplifyInfo| { + let simplify = |aggregate_function: AggregateFunction, _: &SimplifyContext| { // Replaces the UDAF with `GeoMeanUdaf` as a placeholder example to demonstrate the `simplify` method. // In real-world scenarios, you might create UDFs from built-in expressions. Ok(Expr::AggregateFunction(AggregateFunction::new_udf( @@ -469,8 +474,9 @@ fn create_context() -> Result { Ok(ctx) } -#[tokio::main] -async fn main() -> Result<()> { +/// In this example we register `GeoMeanUdaf` and `SimplifiedGeoMeanUdaf` +/// as user defined aggregate functions and invoke them via the DataFrame API and SQL +pub async fn advanced_udaf() -> Result<()> { let ctx = create_context()?; let geo_mean_udf = AggregateUDF::from(GeoMeanUdaf::new()); diff --git a/datafusion-examples/examples/advanced_udf.rs b/datafusion-examples/examples/udf/advanced_udf.rs similarity index 98% rename from datafusion-examples/examples/advanced_udf.rs rename to datafusion-examples/examples/udf/advanced_udf.rs index 290d1c53334b7..a00a7e7df434f 100644 --- a/datafusion-examples/examples/advanced_udf.rs +++ b/datafusion-examples/examples/udf/advanced_udf.rs @@ -15,19 +15,21 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. 
+ use std::any::Any; use std::sync::Arc; use arrow::array::{ - new_null_array, Array, ArrayRef, AsArray, Float32Array, Float64Array, + Array, ArrayRef, AsArray, Float32Array, Float64Array, new_null_array, }; use arrow::compute; use arrow::datatypes::{DataType, Float64Type}; use arrow::record_batch::RecordBatch; -use datafusion::common::{exec_err, internal_err, ScalarValue}; +use datafusion::common::{ScalarValue, exec_err, internal_err}; use datafusion::error::Result; -use datafusion::logical_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion::logical_expr::Volatility; +use datafusion::logical_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion::logical_expr::{ ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, }; @@ -39,7 +41,7 @@ use datafusion::prelude::*; /// the power of the second argument `a^b`. /// /// To do so, we must implement the `ScalarUDFImpl` trait. -#[derive(Debug, Clone)] +#[derive(Debug, PartialEq, Eq, Hash)] struct PowUdf { signature: Signature, aliases: Vec, @@ -245,10 +247,35 @@ fn maybe_pow_in_place(base: f64, exp_array: ArrayRef) -> Result { } } +/// create local execution context with an in-memory table: +/// +/// ```text +/// +-----+-----+ +/// | a | b | +/// +-----+-----+ +/// | 2.1 | 1.0 | +/// | 3.1 | 2.0 | +/// | 4.1 | 3.0 | +/// | 5.1 | 4.0 | +/// +-----+-----+ +/// ``` +fn create_context() -> Result { + // define data. + let a: ArrayRef = Arc::new(Float32Array::from(vec![2.1, 3.1, 4.1, 5.1])); + let b: ArrayRef = Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])); + let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)])?; + + // declare a new context. In Spark API, this corresponds to a new SparkSession + let ctx = SessionContext::new(); + + // declare a table in memory. In Spark API, this corresponds to createDataFrame(...). 
+ ctx.register_batch("t", batch)?; + Ok(ctx) +} + /// In this example we register `PowUdf` as a user defined function /// and invoke it via the DataFrame API and SQL -#[tokio::main] -async fn main() -> Result<()> { +pub async fn advanced_udf() -> Result<()> { let ctx = create_context()?; // create the UDF @@ -295,29 +322,3 @@ async fn main() -> Result<()> { Ok(()) } - -/// create local execution context with an in-memory table: -/// -/// ```text -/// +-----+-----+ -/// | a | b | -/// +-----+-----+ -/// | 2.1 | 1.0 | -/// | 3.1 | 2.0 | -/// | 4.1 | 3.0 | -/// | 5.1 | 4.0 | -/// +-----+-----+ -/// ``` -fn create_context() -> Result { - // define data. - let a: ArrayRef = Arc::new(Float32Array::from(vec![2.1, 3.1, 4.1, 5.1])); - let b: ArrayRef = Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])); - let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)])?; - - // declare a new context. In Spark API, this corresponds to a new SparkSession - let ctx = SessionContext::new(); - - // declare a table in memory. In Spark API, this corresponds to createDataFrame(...). - ctx.register_batch("t", batch)?; - Ok(ctx) -} diff --git a/datafusion-examples/examples/advanced_udwf.rs b/datafusion-examples/examples/udf/advanced_udwf.rs similarity index 89% rename from datafusion-examples/examples/advanced_udwf.rs rename to datafusion-examples/examples/udf/advanced_udwf.rs index 4f00e04e7e993..615d099c2854d 100644 --- a/datafusion-examples/examples/advanced_udwf.rs +++ b/datafusion-examples/examples/udf/advanced_udwf.rs @@ -15,8 +15,9 @@ // specific language governing permissions and limitations // under the License. -use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility}; -use std::any::Any; +//! See `main.rs` for how to run it. 
+ +use std::{any::Any, sync::Arc}; use arrow::datatypes::Field; use arrow::{ @@ -31,19 +32,22 @@ use datafusion::logical_expr::expr::{WindowFunction, WindowFunctionParams}; use datafusion::logical_expr::function::{ PartitionEvaluatorArgs, WindowFunctionSimplification, WindowUDFFieldArgs, }; -use datafusion::logical_expr::simplify::SimplifyInfo; +use datafusion::logical_expr::simplify::SimplifyContext; use datafusion::logical_expr::{ - Expr, PartitionEvaluator, Signature, WindowFrame, WindowFunctionDefinition, - WindowUDF, WindowUDFImpl, + Expr, LimitEffect, PartitionEvaluator, Signature, WindowFrame, + WindowFunctionDefinition, WindowUDF, WindowUDFImpl, }; +use datafusion::physical_expr::PhysicalExpr; use datafusion::prelude::*; +use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility}; +use datafusion_examples::utils::datasets::ExampleDataset; /// This example shows how to use the full WindowUDFImpl API to implement a user /// defined window function. As in the `simple_udwf.rs` example, this struct implements /// a function `partition_evaluator` that returns the `MyPartitionEvaluator` instance. /// /// To do so, we must implement the `WindowUDFImpl` trait. 
-#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] struct SmoothItUdf { signature: Signature, } @@ -91,6 +95,10 @@ impl WindowUDFImpl for SmoothItUdf { fn field(&self, field_args: WindowUDFFieldArgs) -> Result { Ok(Field::new(field_args.name(), DataType::Float64, true).into()) } + + fn limit_effect(&self, _args: &[Arc]) -> LimitEffect { + LimitEffect::Unknown + } } /// This implements the lowest level evaluation for a window function @@ -149,7 +157,7 @@ impl PartitionEvaluator for MyPartitionEvaluator { } /// This UDWF will show how to use the WindowUDFImpl::simplify() API -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] struct SimplifySmoothItUdf { signature: Signature, } @@ -190,8 +198,8 @@ impl WindowUDFImpl for SimplifySmoothItUdf { /// this function will simplify `SimplifySmoothItUdf` to `AggregateUDF` for `Avg` /// default implementation will not be called (left as `todo!()`) fn simplify(&self) -> Option { - let simplify = |window_function: WindowFunction, _: &dyn SimplifyInfo| { - Ok(Expr::WindowFunction(WindowFunction { + let simplify = |window_function: WindowFunction, _: &SimplifyContext| { + Ok(Expr::from(WindowFunction { fun: WindowFunctionDefinition::AggregateUDF(avg_udaf()), params: WindowFunctionParams { args: window_function.params.args, @@ -199,6 +207,8 @@ impl WindowUDFImpl for SimplifySmoothItUdf { order_by: window_function.params.order_by, window_frame: window_function.params.window_frame, null_treatment: window_function.params.null_treatment, + distinct: window_function.params.distinct, + filter: window_function.params.filter, }, })) }; @@ -209,6 +219,10 @@ impl WindowUDFImpl for SimplifySmoothItUdf { fn field(&self, field_args: WindowUDFFieldArgs) -> Result { Ok(Field::new(field_args.name(), DataType::Float64, true).into()) } + + fn limit_effect(&self, _args: &[Arc]) -> LimitEffect { + LimitEffect::Unknown + } } // create local execution context with `cars.csv` registered as a table named `cars` @@ 
-216,17 +230,17 @@ async fn create_context() -> Result { // declare a new context. In spark API, this corresponds to a new spark SQL session let ctx = SessionContext::new(); - // declare a table in memory. In spark API, this corresponds to createDataFrame(...). - println!("pwd: {}", std::env::current_dir().unwrap().display()); - let csv_path = "../../datafusion/core/tests/data/cars.csv".to_string(); - let read_options = CsvReadOptions::default().has_header(true); + let dataset = ExampleDataset::Cars; + + ctx.register_csv("cars", dataset.path_str()?, CsvReadOptions::new()) + .await?; - ctx.register_csv("cars", &csv_path, read_options).await?; Ok(ctx) } -#[tokio::main] -async fn main() -> Result<()> { +/// In this example we register `SmoothItUdf` as user defined window function +/// and invoke it via the DataFrame API and SQL +pub async fn advanced_udwf() -> Result<()> { let ctx = create_context().await?; let smooth_it = WindowUDF::from(SmoothItUdf::new()); ctx.register_udwf(smooth_it.clone()); diff --git a/datafusion-examples/examples/udf/async_udf.rs b/datafusion-examples/examples/udf/async_udf.rs new file mode 100644 index 0000000000000..3d8faf623d439 --- /dev/null +++ b/datafusion-examples/examples/udf/async_udf.rs @@ -0,0 +1,238 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +//! See `main.rs` for how to run it. +//! +//! This example shows how to create and use "Async UDFs" in DataFusion. +//! +//! Async UDFs allow you to perform asynchronous operations, such as +//! making network requests. This can be used for tasks like fetching +//! data from an external API such as a LLM service or an external database. + +use std::{any::Any, sync::Arc}; + +use arrow::array::{ArrayRef, BooleanArray, Int64Array, RecordBatch, StringArray}; +use arrow_schema::{DataType, Field, Schema}; +use async_trait::async_trait; +use datafusion::assert_batches_eq; +use datafusion::common::cast::as_string_view_array; +use datafusion::common::error::Result; +use datafusion::common::not_impl_err; +use datafusion::common::utils::take_function_args; +use datafusion::execution::SessionStateBuilder; +use datafusion::logical_expr::async_udf::{AsyncScalarUDF, AsyncScalarUDFImpl}; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion::prelude::{SessionConfig, SessionContext}; + +/// In this example we register `AskLLM` as an asynchronous user defined function +/// and invoke it via the DataFrame API and SQL +pub async fn async_udf() -> Result<()> { + // Use a hard coded parallelism level of 4 so the explain plan + // is consistent across machines. + let config = SessionConfig::new().with_target_partitions(4); + let ctx = + SessionContext::from(SessionStateBuilder::new().with_config(config).build()); + + // Similarly to regular UDFs, you create an AsyncScalarUDF by implementing + // `AsyncScalarUDFImpl` and creating an instance of `AsyncScalarUDF`. + let async_equal = AskLLM::new(); + let udf = AsyncScalarUDF::new(Arc::new(async_equal)); + + // Async UDFs are registered with the SessionContext, using the same + // `register_udf` method as regular UDFs. 
+ ctx.register_udf(udf.into_scalar_udf()); + + // Create a table named 'animal' with some sample data + ctx.register_batch("animal", animal()?)?; + + // You can use the async UDF as normal in SQL queries + // + // Note: Async UDFs can currently be used in the select list and filter conditions. + let results = ctx + .sql("select * from animal a where ask_llm(a.name, 'Is this animal furry?')") + .await? + .collect() + .await?; + + assert_batches_eq!( + [ + "+----+------+", + "| id | name |", + "+----+------+", + "| 1 | cat |", + "| 2 | dog |", + "+----+------+", + ], + &results + ); + + // While the interface is the same for both normal and async UDFs, you can + // use `EXPLAIN` output to see that the async UDF uses a special + // `AsyncFuncExec` node in the physical plan: + let results = ctx + .sql("explain select * from animal a where ask_llm(a.name, 'Is this animal furry?')") + .await? + .collect() + .await?; + + assert_batches_eq!( + [ + "+---------------+------------------------------------------------------------------------------------------------------------------------------+", + "| plan_type | plan |", + "+---------------+------------------------------------------------------------------------------------------------------------------------------+", + "| logical_plan | SubqueryAlias: a |", + "| | Filter: ask_llm(CAST(animal.name AS Utf8View), Utf8View(\"Is this animal furry?\")) |", + "| | TableScan: animal projection=[id, name] |", + "| physical_plan | FilterExec: __async_fn_0@2, projection=[id@0, name@1] |", + "| | RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 |", + "| | AsyncFuncExec: async_expr=[async_expr(name=__async_fn_0, expr=ask_llm(CAST(name@1 AS Utf8View), Is this animal furry?))] |", + "| | DataSourceExec: partitions=1, partition_sizes=[1] |", + "| | |", + "+---------------+------------------------------------------------------------------------------------------------------------------------------+", + ], + &results + ); 
+ + Ok(()) +} + +/// Returns a sample `RecordBatch` representing an "animal" table with two columns: +fn animal() -> Result { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("name", DataType::Utf8, false), + ])); + + let id_array = Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5])); + let name_array = Arc::new(StringArray::from(vec![ + "cat", "dog", "fish", "bird", "snake", + ])); + + Ok(RecordBatch::try_new(schema, vec![id_array, name_array])?) +} + +/// An async UDF that simulates asking a large language model (LLM) service a +/// question based on the content of two columns. The UDF will return a boolean +/// indicating whether the LLM thinks the first argument matches the question in +/// the second argument. +/// +/// Since this is a simplified example, it does not call an LLM service, but +/// could be extended to do so in a real-world scenario. +#[derive(Debug, PartialEq, Eq, Hash)] +struct AskLLM { + signature: Signature, +} + +impl Default for AskLLM { + fn default() -> Self { + Self::new() + } +} + +impl AskLLM { + pub fn new() -> Self { + Self { + signature: Signature::exact( + vec![DataType::Utf8View, DataType::Utf8View], + Volatility::Volatile, + ), + } + } +} + +/// All async UDFs implement the `ScalarUDFImpl` trait, which provides the basic +/// information for the function, such as its name, signature, and return type. +/// [async_trait] +impl ScalarUDFImpl for AskLLM { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "ask_llm" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Boolean) + } + + /// Since this is an async UDF, the `invoke_with_args` method will not be + /// called directly. 
+ fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + not_impl_err!("AskLLM can only be called from async contexts") + } +} + +/// In addition to [`ScalarUDFImpl`], we also need to implement the +/// [`AsyncScalarUDFImpl`] trait. +#[async_trait] +impl AsyncScalarUDFImpl for AskLLM { + /// The `invoke_async_with_args` method is similar to `invoke_with_args`, + /// but it returns a `Future` that resolves to the result. + /// + /// Since this signature is `async`, it can do any `async` operations, such + /// as network requests. This method is run on the same tokio `Runtime` that + /// is processing the query, so you may wish to make actual network requests + /// on a different `Runtime`, as explained in the `thread_pools.rs` example + /// in this directory. + async fn invoke_async_with_args( + &self, + args: ScalarFunctionArgs, + ) -> Result { + // in a real UDF you would likely want to special case constant + // arguments to improve performance, but this example converts the + // arguments to arrays for simplicity. + let args = ColumnarValue::values_to_arrays(&args.args)?; + let [content_column, question_column] = take_function_args(self.name(), args)?; + + // In a real function, you would use a library such as `reqwest` here to + // make an async HTTP request. Credentials and other configurations can + // be supplied via the `ConfigOptions` parameter. + + // In this example, we will simulate the LLM response by comparing the two + // input arguments using some static strings + let content_column = as_string_view_array(&content_column)?; + let question_column = as_string_view_array(&question_column)?; + + let result_array: BooleanArray = content_column + .iter() + .zip(question_column.iter()) + .map(|(a, b)| { + // If either value is null, return None + let a = a?; + let b = b?; + // Simulate an LLM response by checking the arguments to some + // hardcoded conditions. 
+ if a.contains("cat") && b.contains("furry") + || a.contains("dog") && b.contains("furry") + { + Some(true) + } else { + Some(false) + } + }) + .collect(); + + Ok(ColumnarValue::from(Arc::new(result_array) as ArrayRef)) + } +} diff --git a/datafusion-examples/examples/udf/main.rs b/datafusion-examples/examples/udf/main.rs new file mode 100644 index 0000000000000..e024e466ab07e --- /dev/null +++ b/datafusion-examples/examples/udf/main.rs @@ -0,0 +1,125 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! # User-Defined Functions Examples +//! +//! These examples demonstrate user-defined functions in DataFusion. +//! +//! ## Usage +//! ```bash +//! cargo run --example udf -- [all|adv_udaf|adv_udf|adv_udwf|async_udf|udaf|udf|udtf|udwf] +//! ``` +//! +//! Each subcommand runs a corresponding example: +//! - `all` — run all examples included in this module +//! +//! - `adv_udaf` +//! (file: advanced_udaf.rs, desc: Advanced User Defined Aggregate Function (UDAF)) +//! +//! - `adv_udf` +//! (file: advanced_udf.rs, desc: Advanced User Defined Scalar Function (UDF)) +//! +//! - `adv_udwf` +//! (file: advanced_udwf.rs, desc: Advanced User Defined Window Function (UDWF)) +//! +//! - `async_udf` +//! 
(file: async_udf.rs, desc: Asynchronous User Defined Scalar Function) +//! +//! - `udaf` +//! (file: simple_udaf.rs, desc: Simple UDAF example) +//! +//! - `udf` +//! (file: simple_udf.rs, desc: Simple UDF example) +//! +//! - `udtf` +//! (file: simple_udtf.rs, desc: Simple UDTF example) +//! +//! - `udwf` +//! (file: simple_udwf.rs, desc: Simple UDWF example) + +mod advanced_udaf; +mod advanced_udf; +mod advanced_udwf; +mod async_udf; +mod simple_udaf; +mod simple_udf; +mod simple_udtf; +mod simple_udwf; + +use datafusion::error::{DataFusionError, Result}; +use strum::{IntoEnumIterator, VariantNames}; +use strum_macros::{Display, EnumIter, EnumString, VariantNames}; + +#[derive(EnumIter, EnumString, Display, VariantNames)] +#[strum(serialize_all = "snake_case")] +enum ExampleKind { + All, + AdvUdaf, + AdvUdf, + AdvUdwf, + AsyncUdf, + Udf, + Udaf, + Udwf, + Udtf, +} + +impl ExampleKind { + const EXAMPLE_NAME: &str = "udf"; + + fn runnable() -> impl Iterator { + ExampleKind::iter().filter(|v| !matches!(v, ExampleKind::All)) + } + + async fn run(&self) -> Result<()> { + match self { + ExampleKind::All => { + for example in ExampleKind::runnable() { + println!("Running example: {example}"); + Box::pin(example.run()).await?; + } + } + ExampleKind::AdvUdaf => advanced_udaf::advanced_udaf().await?, + ExampleKind::AdvUdf => advanced_udf::advanced_udf().await?, + ExampleKind::AdvUdwf => advanced_udwf::advanced_udwf().await?, + ExampleKind::AsyncUdf => async_udf::async_udf().await?, + ExampleKind::Udaf => simple_udaf::simple_udaf().await?, + ExampleKind::Udf => simple_udf::simple_udf().await?, + ExampleKind::Udtf => simple_udtf::simple_udtf().await?, + ExampleKind::Udwf => simple_udwf::simple_udwf().await?, + } + + Ok(()) + } +} + +#[tokio::main] +async fn main() -> Result<()> { + let usage = format!( + "Usage: cargo run --example {} -- [{}]", + ExampleKind::EXAMPLE_NAME, + ExampleKind::VARIANTS.join("|") + ); + + let example: ExampleKind = std::env::args() + .nth(1) + 
.unwrap_or_else(|| ExampleKind::All.to_string()) + .parse() + .map_err(|_| DataFusionError::Execution(format!("Unknown example. {usage}")))?; + + example.run().await +} diff --git a/datafusion-examples/examples/simple_udaf.rs b/datafusion-examples/examples/udf/simple_udaf.rs similarity index 96% rename from datafusion-examples/examples/simple_udaf.rs rename to datafusion-examples/examples/udf/simple_udaf.rs index 82bde7c034a57..42ea0054b759f 100644 --- a/datafusion-examples/examples/simple_udaf.rs +++ b/datafusion-examples/examples/udf/simple_udaf.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. +//! /// In this example we will declare a single-type, single return type UDAF that computes the geometric mean. /// The geometric mean is described here: https://en.wikipedia.org/wiki/Geometric_mean use datafusion::arrow::{ @@ -135,8 +137,9 @@ impl Accumulator for GeometricMean { } } -#[tokio::main] -async fn main() -> Result<()> { +/// In this example we register `GeometricMean` +/// as user defined aggregate function and invoke it via the DataFrame API and SQL +pub async fn simple_udaf() -> Result<()> { let ctx = create_context()?; // here is where we define the UDAF. We also declare its signature: diff --git a/datafusion-examples/examples/simple_udf.rs b/datafusion-examples/examples/udf/simple_udf.rs similarity index 98% rename from datafusion-examples/examples/simple_udf.rs rename to datafusion-examples/examples/udf/simple_udf.rs index 5612e0939f709..e8d6c9c8173ac 100644 --- a/datafusion-examples/examples/simple_udf.rs +++ b/datafusion-examples/examples/udf/simple_udf.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! See `main.rs` for how to run it. 
+ use datafusion::{ arrow::{ array::{ArrayRef, Float32Array, Float64Array}, @@ -57,8 +59,7 @@ fn create_context() -> Result { } /// In this example we will declare a single-type, single return type UDF that exponentiates f64, a^b -#[tokio::main] -async fn main() -> Result<()> { +pub async fn simple_udf() -> Result<()> { let ctx = create_context()?; // First, declare the actual implementation of the calculation diff --git a/datafusion-examples/examples/simple_udtf.rs b/datafusion-examples/examples/udf/simple_udtf.rs similarity index 87% rename from datafusion-examples/examples/simple_udtf.rs rename to datafusion-examples/examples/udf/simple_udtf.rs index d2b2d1bf96551..ee2615c4a5ac1 100644 --- a/datafusion-examples/examples/simple_udtf.rs +++ b/datafusion-examples/examples/udf/simple_udtf.rs @@ -15,53 +15,56 @@ // specific language governing permissions and limitations // under the License. -use arrow::csv::reader::Format; +//! See `main.rs` for how to run it. + +use std::fs::File; +use std::io::Seek; +use std::path::Path; +use std::sync::Arc; + use arrow::csv::ReaderBuilder; +use arrow::csv::reader::Format; use async_trait::async_trait; use datafusion::arrow::datatypes::SchemaRef; use datafusion::arrow::record_batch::RecordBatch; -use datafusion::catalog::Session; -use datafusion::catalog::TableFunctionImpl; -use datafusion::common::{plan_err, ScalarValue}; -use datafusion::datasource::memory::MemorySourceConfig; +use datafusion::catalog::{Session, TableFunctionImpl}; +use datafusion::common::{ScalarValue, plan_err}; use datafusion::datasource::TableProvider; +use datafusion::datasource::memory::MemorySourceConfig; use datafusion::error::Result; -use datafusion::execution::context::ExecutionProps; use datafusion::logical_expr::simplify::SimplifyContext; use datafusion::logical_expr::{Expr, TableType}; use datafusion::optimizer::simplify_expressions::ExprSimplifier; use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::*; -use std::fs::File; -use 
std::io::Seek; -use std::path::Path; -use std::sync::Arc; +use datafusion_examples::utils::datasets::ExampleDataset; + // To define your own table function, you only need to do the following 3 things: // 1. Implement your own [`TableProvider`] // 2. Implement your own [`TableFunctionImpl`] and return your [`TableProvider`] // 3. Register the function using [`SessionContext::register_udtf`] /// This example demonstrates how to register a TableFunction -#[tokio::main] -async fn main() -> Result<()> { +pub async fn simple_udtf() -> Result<()> { // create local execution context let ctx = SessionContext::new(); // register the table function that will be called in SQL statements by `read_csv` ctx.register_udtf("read_csv", Arc::new(LocalCsvTableFunc {})); - let testdata = datafusion::test_util::arrow_test_data(); - let csv_file = format!("{testdata}/csv/aggregate_test_100.csv"); + let dataset = ExampleDataset::Cars; // Pass 2 arguments, read csv with at most 2 rows (simplify logic makes 1+1 --> 2) let df = ctx - .sql(format!("SELECT * FROM read_csv('{csv_file}', 1 + 1);").as_str()) + .sql( + format!("SELECT * FROM read_csv('{}', 1 + 1);", dataset.path_str()?).as_str(), + ) .await?; df.show().await?; // just run, return all rows let df = ctx - .sql(format!("SELECT * FROM read_csv('{csv_file}');").as_str()) + .sql(format!("SELECT * FROM read_csv('{}');", dataset.path_str()?).as_str()) .await?; df.show().await?; @@ -133,7 +136,7 @@ struct LocalCsvTableFunc {} impl TableFunctionImpl for LocalCsvTableFunc { fn call(&self, exprs: &[Expr]) -> Result> { - let Some(Expr::Literal(ScalarValue::Utf8(Some(ref path)))) = exprs.first() else { + let Some(Expr::Literal(ScalarValue::Utf8(Some(path)), _)) = exprs.first() else { return plan_err!("read_csv requires at least one string argument"); }; @@ -141,11 +144,10 @@ impl TableFunctionImpl for LocalCsvTableFunc { .get(1) .map(|expr| { // try to simplify the expression, so 1+2 becomes 3, for example - let execution_props = 
ExecutionProps::new(); - let info = SimplifyContext::new(&execution_props); + let info = SimplifyContext::default(); let expr = ExprSimplifier::new(info).simplify(expr.clone())?; - if let Expr::Literal(ScalarValue::Int64(Some(limit))) = expr { + if let Expr::Literal(ScalarValue::Int64(Some(limit)), _) = expr { Ok(limit as usize) } else { plan_err!("Limit must be an integer") diff --git a/datafusion-examples/examples/simple_udwf.rs b/datafusion-examples/examples/udf/simple_udwf.rs similarity index 79% rename from datafusion-examples/examples/simple_udwf.rs rename to datafusion-examples/examples/udf/simple_udwf.rs index 1736ff00bd700..1842d88b9ba29 100644 --- a/datafusion-examples/examples/simple_udwf.rs +++ b/datafusion-examples/examples/udf/simple_udwf.rs @@ -15,35 +15,70 @@ // specific language governing permissions and limitations // under the License. -use std::sync::Arc; +//! See `main.rs` for how to run it. + +use std::{fs::File, io::Write, sync::Arc}; use arrow::{ array::{ArrayRef, AsArray, Float64Array}, datatypes::{DataType, Float64Type}, }; - use datafusion::common::ScalarValue; use datafusion::error::Result; use datafusion::logical_expr::{PartitionEvaluator, Volatility, WindowFrame}; use datafusion::prelude::*; +use tempfile::tempdir; // create local execution context with `cars.csv` registered as a table named `cars` async fn create_context() -> Result { // declare a new context. In spark API, this corresponds to a new spark SQL session let ctx = SessionContext::new(); - // declare a table in memory. In spark API, this corresponds to createDataFrame(...). 
- println!("pwd: {}", std::env::current_dir().unwrap().display()); - let csv_path = "../../datafusion/core/tests/data/cars.csv".to_string(); - let read_options = CsvReadOptions::default().has_header(true); + // content from file 'datafusion/core/tests/data/cars.csv' + let csv_data = r#"car,speed,time +red,20.0,1996-04-12T12:05:03.000000000 +red,20.3,1996-04-12T12:05:04.000000000 +red,21.4,1996-04-12T12:05:05.000000000 +red,21.5,1996-04-12T12:05:06.000000000 +red,19.0,1996-04-12T12:05:07.000000000 +red,18.0,1996-04-12T12:05:08.000000000 +red,17.0,1996-04-12T12:05:09.000000000 +red,7.0,1996-04-12T12:05:10.000000000 +red,7.1,1996-04-12T12:05:11.000000000 +red,7.2,1996-04-12T12:05:12.000000000 +red,3.0,1996-04-12T12:05:13.000000000 +red,1.0,1996-04-12T12:05:14.000000000 +red,0.0,1996-04-12T12:05:15.000000000 +green,10.0,1996-04-12T12:05:03.000000000 +green,10.3,1996-04-12T12:05:04.000000000 +green,10.4,1996-04-12T12:05:05.000000000 +green,10.5,1996-04-12T12:05:06.000000000 +green,11.0,1996-04-12T12:05:07.000000000 +green,12.0,1996-04-12T12:05:08.000000000 +green,14.0,1996-04-12T12:05:09.000000000 +green,15.0,1996-04-12T12:05:10.000000000 +green,15.1,1996-04-12T12:05:11.000000000 +green,15.2,1996-04-12T12:05:12.000000000 +green,8.0,1996-04-12T12:05:13.000000000 +green,2.0,1996-04-12T12:05:14.000000000 +"#; + let dir = tempdir()?; + let file_path = dir.path().join("cars.csv"); + { + let mut file = File::create(&file_path)?; + // write CSV data + file.write_all(csv_data.as_bytes())?; + } // scope closes the file + let file_path = file_path.to_str().unwrap(); + + ctx.register_csv("cars", file_path, CsvReadOptions::new()) + .await?; - ctx.register_csv("cars", &csv_path, read_options).await?; Ok(ctx) } /// In this example we will declare a user defined window function that computes a moving average and then run it using SQL -#[tokio::main] -async fn main() -> Result<()> { +pub async fn simple_udwf() -> Result<()> { let ctx = create_context().await?; // here is where we 
define the UDWF. We also declare its signature: diff --git a/datafusion-examples/src/bin/examples-docs.rs b/datafusion-examples/src/bin/examples-docs.rs new file mode 100644 index 0000000000000..7efcf4da15d20 --- /dev/null +++ b/datafusion-examples/src/bin/examples-docs.rs @@ -0,0 +1,47 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Generates Markdown documentation for DataFusion example groups. +//! +//! This binary scans `datafusion-examples/examples`, extracts structured +//! documentation from each group's `main.rs` file, and renders a README-style +//! Markdown document. +//! +//! By default, documentation is generated for all example groups. If a group +//! name is provided as the first CLI argument, only that group is rendered. +//! +//! ## Usage +//! +//! ```bash +//! # Generate docs for all example groups +//! cargo run --bin examples-docs +//! +//! # Generate docs for a single group +//! cargo run --bin examples-docs -- dataframe +//! 
``` + +use datafusion_examples::utils::example_metadata::{ + RepoLayout, generate_examples_readme, +}; + +fn main() -> Result<(), Box> { + let layout = RepoLayout::detect()?; + let group = std::env::args().nth(1); + let markdown = generate_examples_readme(&layout, group.as_deref())?; + print!("{markdown}"); + Ok(()) +} diff --git a/datafusion-examples/src/lib.rs b/datafusion-examples/src/lib.rs new file mode 100644 index 0000000000000..7f334aedaafe2 --- /dev/null +++ b/datafusion-examples/src/lib.rs @@ -0,0 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Internal utilities shared by the DataFusion examples. + +pub mod utils; diff --git a/datafusion-examples/src/utils/csv_to_parquet.rs b/datafusion-examples/src/utils/csv_to_parquet.rs new file mode 100644 index 0000000000000..1fbf2930e9043 --- /dev/null +++ b/datafusion-examples/src/utils/csv_to_parquet.rs @@ -0,0 +1,244 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::path::{Path, PathBuf}; + +use datafusion::dataframe::DataFrameWriteOptions; +use datafusion::error::{DataFusionError, Result}; +use datafusion::prelude::{CsvReadOptions, SessionContext}; +use tempfile::TempDir; +use tokio::fs::create_dir_all; + +/// Temporary Parquet directory that is deleted when dropped. +#[derive(Debug)] +pub struct ParquetTemp { + pub tmp_dir: TempDir, + pub parquet_dir: PathBuf, +} + +impl ParquetTemp { + pub fn path(&self) -> &Path { + &self.parquet_dir + } + + pub fn path_str(&self) -> Result<&str> { + self.parquet_dir.to_str().ok_or_else(|| { + DataFusionError::Execution(format!( + "Parquet directory path is not valid UTF-8: {}", + self.parquet_dir.display() + )) + }) + } + + pub fn file_uri(&self) -> Result { + Ok(format!("file://{}", self.path_str()?)) + } +} + +/// Helper for examples: load a CSV file and materialize it as Parquet +/// in a temporary directory. 
+/// +/// # Example +/// ``` +/// use std::path::PathBuf; +/// use datafusion::prelude::*; +/// use datafusion_examples::utils::write_csv_to_parquet; +/// # use datafusion::assert_batches_eq; +/// # use datafusion::error::Result; +/// # #[tokio::main] +/// # async fn main() -> Result<()> { +/// let ctx = SessionContext::new(); +/// let csv_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")) +/// .join("data") +/// .join("csv") +/// .join("cars.csv"); +/// let parquet_dir = write_csv_to_parquet(&ctx, &csv_path).await?; +/// let df = ctx.read_parquet(parquet_dir.path_str()?, ParquetReadOptions::default()).await?; +/// let rows = df +/// .sort(vec![col("speed").sort(true, true)])? +/// .limit(0, Some(5))?; +/// assert_batches_eq!( +/// &[ +/// "+-------+-------+---------------------+", +/// "| car | speed | time |", +/// "+-------+-------+---------------------+", +/// "| red | 0.0 | 1996-04-12T12:05:15 |", +/// "| red | 1.0 | 1996-04-12T12:05:14 |", +/// "| green | 2.0 | 1996-04-12T12:05:14 |", +/// "| red | 3.0 | 1996-04-12T12:05:13 |", +/// "| red | 7.0 | 1996-04-12T12:05:10 |", +/// "+-------+-------+---------------------+", +/// ], +/// &rows.collect().await? 
+/// ); +/// # Ok(()) +/// # } +/// ``` +pub async fn write_csv_to_parquet( + ctx: &SessionContext, + csv_path: &Path, +) -> Result { + if !csv_path.is_file() { + return Err(DataFusionError::Execution(format!( + "CSV file does not exist: {}", + csv_path.display() + ))); + } + + let csv_path = csv_path.to_str().ok_or_else(|| { + DataFusionError::Execution("CSV path is not valid UTF-8".to_string()) + })?; + + let csv_df = ctx.read_csv(csv_path, CsvReadOptions::default()).await?; + + let tmp_dir = TempDir::new()?; + let parquet_dir = tmp_dir.path().join("parquet_source"); + create_dir_all(&parquet_dir).await?; + + let path = parquet_dir.to_str().ok_or_else(|| { + DataFusionError::Execution("Failed processing tmp directory path".to_string()) + })?; + + csv_df + .write_parquet(path, DataFrameWriteOptions::default(), None) + .await?; + + Ok(ParquetTemp { + tmp_dir, + parquet_dir, + }) +} + +#[cfg(test)] +mod tests { + use super::*; + + use std::path::PathBuf; + + use datafusion::assert_batches_eq; + use datafusion::prelude::*; + + #[tokio::test] + async fn test_write_csv_to_parquet_with_cars_data() -> Result<()> { + let ctx = SessionContext::new(); + let csv_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("data") + .join("csv") + .join("cars.csv"); + + let parquet_dir = write_csv_to_parquet(&ctx, &csv_path).await?; + let df = ctx + .read_parquet(parquet_dir.path_str()?, ParquetReadOptions::default()) + .await?; + + let rows = df.sort(vec![col("speed").sort(true, true)])?; + assert_batches_eq!( + &[ + "+-------+-------+---------------------+", + "| car | speed | time |", + "+-------+-------+---------------------+", + "| red | 0.0 | 1996-04-12T12:05:15 |", + "| red | 1.0 | 1996-04-12T12:05:14 |", + "| green | 2.0 | 1996-04-12T12:05:14 |", + "| red | 3.0 | 1996-04-12T12:05:13 |", + "| red | 7.0 | 1996-04-12T12:05:10 |", + "| red | 7.1 | 1996-04-12T12:05:11 |", + "| red | 7.2 | 1996-04-12T12:05:12 |", + "| green | 8.0 | 1996-04-12T12:05:13 |", + "| green | 10.0 | 
1996-04-12T12:05:03 |", + "| green | 10.3 | 1996-04-12T12:05:04 |", + "| green | 10.4 | 1996-04-12T12:05:05 |", + "| green | 10.5 | 1996-04-12T12:05:06 |", + "| green | 11.0 | 1996-04-12T12:05:07 |", + "| green | 12.0 | 1996-04-12T12:05:08 |", + "| green | 14.0 | 1996-04-12T12:05:09 |", + "| green | 15.0 | 1996-04-12T12:05:10 |", + "| green | 15.1 | 1996-04-12T12:05:11 |", + "| green | 15.2 | 1996-04-12T12:05:12 |", + "| red | 17.0 | 1996-04-12T12:05:09 |", + "| red | 18.0 | 1996-04-12T12:05:08 |", + "| red | 19.0 | 1996-04-12T12:05:07 |", + "| red | 20.0 | 1996-04-12T12:05:03 |", + "| red | 20.3 | 1996-04-12T12:05:04 |", + "| red | 21.4 | 1996-04-12T12:05:05 |", + "| red | 21.5 | 1996-04-12T12:05:06 |", + "+-------+-------+---------------------+", + ], + &rows.collect().await? + ); + + Ok(()) + } + + #[tokio::test] + async fn test_write_csv_to_parquet_with_regex_data() -> Result<()> { + let ctx = SessionContext::new(); + let csv_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("data") + .join("csv") + .join("regex.csv"); + + let parquet_dir = write_csv_to_parquet(&ctx, &csv_path).await?; + let df = ctx + .read_parquet(parquet_dir.path_str()?, ParquetReadOptions::default()) + .await?; + + let rows = df.sort(vec![col("values").sort(true, true)])?; + assert_batches_eq!( + &[ + "+------------+--------------------------------------+-------------+-------+", + "| values | patterns | replacement | flags |", + "+------------+--------------------------------------+-------------+-------+", + "| 4000 | \\b4([1-9]\\d\\d|\\d[1-9]\\d|\\d\\d[1-9])\\b | xyz | |", + "| 4010 | \\b4([1-9]\\d\\d|\\d[1-9]\\d|\\d\\d[1-9])\\b | xyz | |", + "| ABC | ^(A).* | B | i |", + "| AbC | (B|D) | e | |", + "| Düsseldorf | [\\p{Letter}-]+ | München | |", + "| Köln | [a-zA-Z]ö[a-zA-Z]{2} | Koln | |", + "| aBC | ^(b|c) | d | |", + "| aBc | (b|d) | e | i |", + "| abc | ^(a) | bb\\1bb | i |", + "| Москва | [\\p{L}-]+ | Moscow | |", + "| اليوم | ^\\p{Arabic}+$ | Today | |", + 
"+------------+--------------------------------------+-------------+-------+", + ], + &rows.collect().await? + ); + + Ok(()) + } + + #[tokio::test] + async fn test_write_csv_to_parquet_error() { + let ctx = SessionContext::new(); + let csv_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("data") + .join("csv") + .join("file-does-not-exist.csv"); + + let err = write_csv_to_parquet(&ctx, &csv_path).await.unwrap_err(); + match err { + DataFusionError::Execution(msg) => { + assert!( + msg.contains("CSV file does not exist"), + "unexpected error message: {msg}" + ); + } + other => panic!("unexpected error variant: {other:?}"), + } + } +} diff --git a/datafusion-examples/src/utils/datasets/cars.rs b/datafusion-examples/src/utils/datasets/cars.rs new file mode 100644 index 0000000000000..2d8547c16d686 --- /dev/null +++ b/datafusion-examples/src/utils/datasets/cars.rs @@ -0,0 +1,33 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; + +/// Schema for the `data/csv/cars.csv` example dataset. 
+pub fn schema() -> Arc { + Arc::new(Schema::new(vec![ + Field::new("car", DataType::Utf8, false), + Field::new("speed", DataType::Float64, false), + Field::new( + "time", + DataType::Timestamp(TimeUnit::Nanosecond, None), + false, + ), + ])) +} diff --git a/datafusion-examples/src/utils/datasets/mod.rs b/datafusion-examples/src/utils/datasets/mod.rs new file mode 100644 index 0000000000000..1857e6af9b559 --- /dev/null +++ b/datafusion-examples/src/utils/datasets/mod.rs @@ -0,0 +1,139 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::path::PathBuf; + +use arrow_schema::SchemaRef; +use datafusion::error::{DataFusionError, Result}; + +pub mod cars; +pub mod regex; + +/// Describes example datasets used across DataFusion examples. +/// +/// This enum provides a single, discoverable place to define +/// dataset-specific metadata such as file paths and schemas. 
+#[derive(Debug)] +pub enum ExampleDataset { + Cars, + Regex, +} + +impl ExampleDataset { + pub fn file_stem(&self) -> &'static str { + match self { + Self::Cars => "cars", + Self::Regex => "regex", + } + } + + pub fn path(&self) -> PathBuf { + PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("data") + .join("csv") + .join(format!("{}.csv", self.file_stem())) + } + + pub fn path_str(&self) -> Result { + let path = self.path(); + path.to_str().map(String::from).ok_or_else(|| { + DataFusionError::Execution(format!( + "CSV directory path is not valid UTF-8: {}", + path.display() + )) + }) + } + + pub fn schema(&self) -> SchemaRef { + match self { + Self::Cars => cars::schema(), + Self::Regex => regex::schema(), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use arrow::datatypes::{DataType, TimeUnit}; + + #[test] + fn example_dataset_file_stem() { + assert_eq!(ExampleDataset::Cars.file_stem(), "cars"); + assert_eq!(ExampleDataset::Regex.file_stem(), "regex"); + } + + #[test] + fn example_dataset_path_points_to_csv() { + let path = ExampleDataset::Cars.path(); + assert!(path.ends_with("data/csv/cars.csv")); + + let path = ExampleDataset::Regex.path(); + assert!(path.ends_with("data/csv/regex.csv")); + } + + #[test] + fn example_dataset_path_str_is_valid_utf8() { + let path = ExampleDataset::Cars.path_str().unwrap(); + assert!(path.ends_with("cars.csv")); + + let path = ExampleDataset::Regex.path_str().unwrap(); + assert!(path.ends_with("regex.csv")); + } + + #[test] + fn cars_schema_is_stable() { + let schema = ExampleDataset::Cars.schema(); + + let fields: Vec<_> = schema + .fields() + .iter() + .map(|f| (f.name().as_str(), f.data_type().clone())) + .collect(); + + assert_eq!( + fields, + vec![ + ("car", DataType::Utf8), + ("speed", DataType::Float64), + ("time", DataType::Timestamp(TimeUnit::Nanosecond, None)), + ] + ); + } + + #[test] + fn regex_schema_is_stable() { + let schema = ExampleDataset::Regex.schema(); + + let fields: Vec<_> = schema + 
.fields() + .iter() + .map(|f| (f.name().as_str(), f.data_type().clone())) + .collect(); + + assert_eq!( + fields, + vec![ + ("values", DataType::Utf8), + ("patterns", DataType::Utf8), + ("replacement", DataType::Utf8), + ("flags", DataType::Utf8), + ] + ); + } +} diff --git a/datafusion-examples/src/utils/datasets/regex.rs b/datafusion-examples/src/utils/datasets/regex.rs new file mode 100644 index 0000000000000..d44582126a053 --- /dev/null +++ b/datafusion-examples/src/utils/datasets/regex.rs @@ -0,0 +1,30 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::datatypes::{DataType, Field, Schema}; + +/// Schema for the `data/csv/regex.csv` example dataset. 
+pub fn schema() -> Arc { + Arc::new(Schema::new(vec![ + Field::new("values", DataType::Utf8, false), + Field::new("patterns", DataType::Utf8, false), + Field::new("replacement", DataType::Utf8, false), + Field::new("flags", DataType::Utf8, true), + ])) +} diff --git a/datafusion-examples/src/utils/example_metadata/discover.rs b/datafusion-examples/src/utils/example_metadata/discover.rs new file mode 100644 index 0000000000000..1ba5f6d29a14e --- /dev/null +++ b/datafusion-examples/src/utils/example_metadata/discover.rs @@ -0,0 +1,103 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Utilities for discovering example groups in the repository filesystem. +//! +//! An example group is defined as a directory containing a `main.rs` file +//! under the examples root. This module is intentionally filesystem-focused +//! and does not perform any parsing or rendering. +//! Discovery fails if no valid example groups are found. + +use std::fs; +use std::path::{Path, PathBuf}; + +use datafusion::common::exec_err; +use datafusion::error::Result; + +/// Discovers all example group directories under the given root. +/// +/// A directory is considered an example group if it contains a `main.rs` file. 
+pub fn discover_example_groups(root: &Path) -> Result> { + let mut groups = Vec::new(); + for entry in fs::read_dir(root)? { + let entry = entry?; + let path = entry.path(); + + if path.is_dir() && path.join("main.rs").is_file() { + groups.push(path); + } + } + + if groups.is_empty() { + return exec_err!("No example groups found under: {}", root.display()); + } + + groups.sort(); + Ok(groups) +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::utils::example_metadata::test_utils::assert_exec_err_contains; + + use std::fs::{self, File}; + + use tempfile::TempDir; + + #[test] + fn discover_example_groups_finds_dirs_with_main_rs() -> Result<()> { + let tmp = TempDir::new()?; + let root = tmp.path(); + + // valid example group + let group1 = root.join("group1"); + fs::create_dir(&group1)?; + File::create(group1.join("main.rs"))?; + + // not an example group + let group2 = root.join("group2"); + fs::create_dir(&group2)?; + + let groups = discover_example_groups(root)?; + assert_eq!(groups.len(), 1); + assert_eq!(groups[0], group1); + Ok(()) + } + + #[test] + fn discover_example_groups_errors_if_main_rs_is_a_directory() -> Result<()> { + let tmp = TempDir::new()?; + let root = tmp.path(); + let group = root.join("group"); + fs::create_dir(&group)?; + fs::create_dir(group.join("main.rs"))?; + + let err = discover_example_groups(root).unwrap_err(); + assert_exec_err_contains(err, "No example groups found"); + Ok(()) + } + + #[test] + fn discover_example_groups_errors_if_none_found() -> Result<()> { + let tmp = TempDir::new()?; + let err = discover_example_groups(tmp.path()).unwrap_err(); + assert_exec_err_contains(err, "No example groups found"); + Ok(()) + } +} diff --git a/datafusion-examples/src/utils/example_metadata/layout.rs b/datafusion-examples/src/utils/example_metadata/layout.rs new file mode 100644 index 0000000000000..ee6fad89855f9 --- /dev/null +++ b/datafusion-examples/src/utils/example_metadata/layout.rs @@ -0,0 +1,113 @@ +// Licensed to the 
Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Repository layout utilities. +//! +//! This module provides a small helper (`RepoLayout`) that encapsulates +//! knowledge about the DataFusion repository structure, in particular +//! where example groups are located relative to the repository root. + +use std::path::{Path, PathBuf}; + +use datafusion::error::{DataFusionError, Result}; + +/// Describes the layout of a DataFusion repository. +/// +/// This type centralizes knowledge about where example-related +/// directories live relative to the repository root. +#[derive(Debug, Clone)] +pub struct RepoLayout { + root: PathBuf, +} + +impl From<&Path> for RepoLayout { + fn from(path: &Path) -> Self { + Self { + root: path.to_path_buf(), + } + } +} + +impl RepoLayout { + /// Creates a layout from an explicit repository root. + pub fn from_root(root: PathBuf) -> Self { + Self { root } + } + + /// Detects the repository root based on `CARGO_MANIFEST_DIR`. + /// + /// This is intended for use from binaries inside the workspace. 
+ pub fn detect() -> Result { + let manifest_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + + let root = manifest_dir.parent().ok_or_else(|| { + DataFusionError::Execution( + "CARGO_MANIFEST_DIR does not have a parent".to_string(), + ) + })?; + + Ok(Self { + root: root.to_path_buf(), + }) + } + + /// Returns the repository root directory. + pub fn root(&self) -> &Path { + &self.root + } + + /// Returns the `datafusion-examples/examples` directory. + pub fn examples_root(&self) -> PathBuf { + self.root.join("datafusion-examples").join("examples") + } + + /// Returns the directory for a single example group. + /// + /// Example: `examples/udf` + pub fn example_group_dir(&self, group: &str) -> PathBuf { + self.examples_root().join(group) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn detect_sets_non_empty_root() -> Result<()> { + let layout = RepoLayout::detect()?; + assert!(!layout.root().as_os_str().is_empty()); + Ok(()) + } + + #[test] + fn examples_root_is_under_repo_root() -> Result<()> { + let layout = RepoLayout::detect()?; + let examples_root = layout.examples_root(); + assert!(examples_root.starts_with(layout.root())); + assert!(examples_root.ends_with("datafusion-examples/examples")); + Ok(()) + } + + #[test] + fn example_group_dir_appends_group_name() -> Result<()> { + let layout = RepoLayout::detect()?; + let group_dir = layout.example_group_dir("foo"); + assert!(group_dir.ends_with("datafusion-examples/examples/foo")); + Ok(()) + } +} diff --git a/datafusion-examples/src/utils/example_metadata/mod.rs b/datafusion-examples/src/utils/example_metadata/mod.rs new file mode 100644 index 0000000000000..ab4c8e4a8e4c2 --- /dev/null +++ b/datafusion-examples/src/utils/example_metadata/mod.rs @@ -0,0 +1,67 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Documentation generator for DataFusion examples. +//! +//! # Design goals +//! +//! - Keep README.md in sync with runnable examples +//! - Fail fast on malformed documentation +//! +//! # Overview +//! +//! Each example group corresponds to a directory under +//! `datafusion-examples/examples/` containing a `main.rs` file. +//! Documentation is extracted from structured `//!` comments in that file. +//! +//! For each example group, the generator produces: +//! +//! ```text +//! ## Examples +//! ### Group: `` +//! #### Category: Single Process | Distributed +//! +//! | Subcommand | File Path | Description | +//! ``` +//! +//! # Usage +//! +//! Generate documentation for a single group only: +//! +//! ```bash +//! cargo run --bin examples-docs -- dataframe +//! ``` +//! +//! Generate documentation for all examples: +//! +//! ```bash +//! cargo run --bin examples-docs +//! 
``` + +pub mod discover; +pub mod layout; +pub mod model; +pub mod parser; +pub mod render; + +#[cfg(test)] +pub mod test_utils; + +pub use layout::RepoLayout; +pub use model::{Category, ExampleEntry, ExampleGroup, GroupName}; +pub use parser::parse_main_rs_docs; +pub use render::generate_examples_readme; diff --git a/datafusion-examples/src/utils/example_metadata/model.rs b/datafusion-examples/src/utils/example_metadata/model.rs new file mode 100644 index 0000000000000..11416d141eb74 --- /dev/null +++ b/datafusion-examples/src/utils/example_metadata/model.rs @@ -0,0 +1,418 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Domain model for DataFusion example documentation. +//! +//! This module defines the core data structures used to represent +//! example groups, individual examples, and their categorization +//! as parsed from `main.rs` documentation comments. + +use std::path::Path; + +use datafusion::error::{DataFusionError, Result}; + +use crate::utils::example_metadata::parse_main_rs_docs; + +/// Well-known abbreviations used to preserve correct capitalization +/// when generating human-readable documentation titles. 
+const ABBREVIATIONS: &[(&str, &str)] = &[ + ("dataframe", "DataFrame"), + ("io", "IO"), + ("sql", "SQL"), + ("udf", "UDF"), +]; + +/// A group of related examples (e.g. `builtin_functions`, `udf`). +/// +/// Each group corresponds to a directory containing a `main.rs` file +/// with structured documentation comments. +#[derive(Debug)] +pub struct ExampleGroup { + pub name: GroupName, + pub examples: Vec, + pub category: Category, +} + +impl ExampleGroup { + /// Parses an example group from its directory. + /// + /// The group name is derived from the directory name, and example + /// entries are extracted from `main.rs`. + pub fn from_dir(dir: &Path, category: Category) -> Result { + let raw_name = dir + .file_name() + .and_then(|s| s.to_str()) + .ok_or_else(|| { + DataFusionError::Execution("Invalid example group dir".to_string()) + })? + .to_string(); + + let name = GroupName::from_dir_name(raw_name); + let main_rs = dir.join("main.rs"); + let examples = parse_main_rs_docs(&main_rs)?; + + Ok(Self { + name, + examples, + category, + }) + } +} + +/// Represents an example group name in both raw and human-readable forms. +/// +/// For example: +/// - raw: `builtin_functions` +/// - title: `Builtin Functions` +#[derive(Debug)] +pub struct GroupName { + raw: String, + title: String, +} + +impl GroupName { + /// Creates a group name from a directory name. + pub fn from_dir_name(raw: String) -> Self { + let title = raw + .split('_') + .map(format_part) + .collect::>() + .join(" "); + + Self { raw, title } + } + + /// Returns the raw group name (directory name). + pub fn raw(&self) -> &str { + &self.raw + } + + /// Returns a title-cased name for documentation. + pub fn title(&self) -> &str { + &self.title + } +} + +/// A single runnable example within a group. +/// +/// Each entry corresponds to a subcommand documented in `main.rs`. +#[derive(Debug)] +pub struct ExampleEntry { + /// CLI subcommand name. + pub subcommand: String, + /// Rust source file name. 
+ pub file: String, + /// Human-readable description. + pub desc: String, +} + +/// Execution category of an example group. +#[derive(Debug, Default)] +pub enum Category { + /// Runs in a single process. + #[default] + SingleProcess, + /// Requires a distributed setup. + Distributed, +} + +impl Category { + /// Returns the display name used in documentation. + pub fn name(&self) -> &str { + match self { + Self::SingleProcess => "Single Process", + Self::Distributed => "Distributed", + } + } + + /// Determines the category for a group by name. + pub fn for_group(name: &str) -> Self { + match name { + "flight" => Category::Distributed, + _ => Category::SingleProcess, + } + } +} + +/// Formats a single group-name segment for display. +/// +/// This function applies DataFusion-specific capitalization rules: +/// - Known abbreviations (e.g. `sql`, `io`, `udf`) are rendered in all caps +/// - All other segments fall back to standard Title Case +fn format_part(part: &str) -> String { + let lower = part.to_ascii_lowercase(); + + if let Some((_, replacement)) = ABBREVIATIONS.iter().find(|(k, _)| *k == lower) { + return replacement.to_string(); + } + + let mut chars = part.chars(); + match chars.next() { + Some(first) => first.to_uppercase().collect::() + chars.as_str(), + None => String::new(), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::utils::example_metadata::test_utils::{ + assert_exec_err_contains, example_group_from_docs, + }; + + use std::fs; + + use tempfile::TempDir; + + #[test] + fn category_for_group_works() { + assert!(matches!( + Category::for_group("flight"), + Category::Distributed + )); + assert!(matches!( + Category::for_group("anything_else"), + Category::SingleProcess + )); + } + + #[test] + fn all_subcommand_is_ignored() -> Result<()> { + let group = example_group_from_docs( + r#" + //! - `all` — run all examples included in this module + //! + //! - `foo` + //! 
(file: foo.rs, desc: foo example) + "#, + )?; + assert_eq!(group.examples.len(), 1); + assert_eq!(group.examples[0].subcommand, "foo"); + Ok(()) + } + + #[test] + fn metadata_without_subcommand_fails() { + let err = example_group_from_docs("//! (file: foo.rs, desc: missing subcommand)") + .unwrap_err(); + assert_exec_err_contains(err, "Metadata without preceding subcommand"); + } + + #[test] + fn group_name_handles_abbreviations() { + assert_eq!( + GroupName::from_dir_name("dataframe".to_string()).title(), + "DataFrame" + ); + assert_eq!( + GroupName::from_dir_name("data_io".to_string()).title(), + "Data IO" + ); + assert_eq!( + GroupName::from_dir_name("sql_ops".to_string()).title(), + "SQL Ops" + ); + assert_eq!(GroupName::from_dir_name("udf".to_string()).title(), "UDF"); + } + + #[test] + fn group_name_title_cases() { + let cases = [ + ("very_long_group_name", "Very Long Group Name"), + ("foo", "Foo"), + ("dataframe", "DataFrame"), + ("data_io", "Data IO"), + ("sql_ops", "SQL Ops"), + ("udf", "UDF"), + ]; + for (input, expected) in cases { + let name = GroupName::from_dir_name(input.to_string()); + assert_eq!(name.title(), expected); + } + } + + #[test] + fn parse_group_example_works() -> Result<()> { + let tmp = TempDir::new().unwrap(); + + // Simulate: examples/builtin_functions/ + let group_dir = tmp.path().join("builtin_functions"); + fs::create_dir(&group_dir)?; + + // Write a fake main.rs with docs + let main_rs = group_dir.join("main.rs"); + fs::write( + &main_rs, + r#" + // Licensed to the Apache Software Foundation (ASF) under one + // or more contributor license agreements. See the NOTICE file + // distributed with this work for additional information + // regarding copyright ownership. The ASF licenses this file + // to you under the Apache License, Version 2.0 (the + // "License"); you may not use this file except in compliance + // with the License. 
You may obtain a copy of the License at + // + // http://www.apache.org/licenses/LICENSE-2.0 + // + // Unless required by applicable law or agreed to in writing, + // software distributed under the License is distributed on an + // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + // KIND, either express or implied. See the License for the + // specific language governing permissions and limitations + // under the License. + // + //! # These are miscellaneous function-related examples + //! + //! These examples demonstrate miscellaneous function-related features. + //! + //! ## Usage + //! ```bash + //! cargo run --example builtin_functions -- [all|date_time|function_factory|regexp] + //! ``` + //! + //! Each subcommand runs a corresponding example: + //! - `all` — run all examples included in this module + //! + //! - `date_time` + //! (file: date_time.rs, desc: Examples of date-time related functions and queries) + //! + //! - `function_factory` + //! (file: function_factory.rs, desc: Register `CREATE FUNCTION` handler to implement SQL macros) + //! + //! - `regexp` + //! 
(file: regexp.rs, desc: Examples of using regular expression functions) + "#, + )?; + + let group = ExampleGroup::from_dir(&group_dir, Category::SingleProcess)?; + + // Assert group-level data + assert_eq!(group.name.title(), "Builtin Functions"); + assert_eq!(group.examples.len(), 3); + + // Assert 1 example + assert_eq!(group.examples[0].subcommand, "date_time"); + assert_eq!(group.examples[0].file, "date_time.rs"); + assert_eq!( + group.examples[0].desc, + "Examples of date-time related functions and queries" + ); + + // Assert 2 example + assert_eq!(group.examples[1].subcommand, "function_factory"); + assert_eq!(group.examples[1].file, "function_factory.rs"); + assert_eq!( + group.examples[1].desc, + "Register `CREATE FUNCTION` handler to implement SQL macros" + ); + + // Assert 3 example + assert_eq!(group.examples[2].subcommand, "regexp"); + assert_eq!(group.examples[2].file, "regexp.rs"); + assert_eq!( + group.examples[2].desc, + "Examples of using regular expression functions" + ); + + Ok(()) + } + + #[test] + fn duplicate_metadata_without_repeating_subcommand_fails() { + let err = example_group_from_docs( + r#" + //! - `foo` + //! (file: a.rs, desc: first) + //! (file: b.rs, desc: second) + "#, + ) + .unwrap_err(); + assert_exec_err_contains(err, "Metadata without preceding subcommand"); + } + + #[test] + fn duplicate_metadata_for_same_subcommand_fails() { + let err = example_group_from_docs( + r#" + //! - `foo` + //! (file: a.rs, desc: first) + //! + //! - `foo` + //! (file: b.rs, desc: second) + "#, + ) + .unwrap_err(); + assert_exec_err_contains(err, "Duplicate metadata for subcommand `foo`"); + } + + #[test] + fn metadata_must_follow_subcommand() { + let err = example_group_from_docs( + r#" + //! - `foo` + //! some unrelated comment + //! 
(file: foo.rs, desc: test) + "#, + ) + .unwrap_err(); + assert_exec_err_contains(err, "Metadata without preceding subcommand"); + } + + #[test] + fn preserves_example_order_from_main_rs() -> Result<()> { + let group = example_group_from_docs( + r#" + //! - `second` + //! (file: second.rs, desc: second example) + //! + //! - `first` + //! (file: first.rs, desc: first example) + //! + //! - `third` + //! (file: third.rs, desc: third example) + "#, + )?; + + let subcommands: Vec<&str> = group + .examples + .iter() + .map(|e| e.subcommand.as_str()) + .collect(); + + assert_eq!( + subcommands, + vec!["second", "first", "third"], + "examples must preserve the order defined in main.rs" + ); + + Ok(()) + } + + #[test] + fn metadata_can_follow_blank_doc_line() -> Result<()> { + let group = example_group_from_docs( + r#" + //! - `foo` + //! + //! (file: foo.rs, desc: test) + "#, + )?; + assert_eq!(group.examples.len(), 1); + Ok(()) + } +} diff --git a/datafusion-examples/src/utils/example_metadata/parser.rs b/datafusion-examples/src/utils/example_metadata/parser.rs new file mode 100644 index 0000000000000..4ead3e5a2ae9f --- /dev/null +++ b/datafusion-examples/src/utils/example_metadata/parser.rs @@ -0,0 +1,267 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Parser for example metadata embedded in `main.rs` documentation comments. +//! +//! This module scans `//!` doc comments to extract example subcommands +//! and their associated metadata (file name and description), enforcing +//! a strict ordering and structure to avoid ambiguous documentation. + +use std::{collections::HashSet, fs, path::Path}; + +use datafusion::common::exec_err; +use datafusion::error::Result; +use nom::{ + Err, IResult, Parser, + bytes::complete::{tag, take_until, take_while}, + character::complete::multispace0, + combinator::all_consuming, + error::{Error, ErrorKind}, + sequence::{delimited, preceded}, +}; + +use crate::utils::example_metadata::ExampleEntry; + +/// Parsing state machine used while scanning `main.rs` docs. +/// +/// This makes the "subcommand - metadata" relationship explicit: +/// metadata is only valid immediately after a subcommand has been seen. +enum ParserState<'a> { + /// Not currently expecting metadata. + Idle, + /// A subcommand was just parsed; the next valid metadata (if any) + /// must belong to this subcommand. + SeenSubcommand(&'a str), +} + +/// Parses a subcommand declaration line from `main.rs` docs. +/// +/// Expected format: +/// ```text +/// //! - `` +/// ``` +fn parse_subcommand_line(input: &str) -> IResult<&str, &str> { + let parser = preceded( + multispace0, + delimited(tag("//! - `"), take_until("`"), tag("`")), + ); + all_consuming(parser).parse(input) +} + +/// Parses example metadata (file name and description) from `main.rs` docs. +/// +/// Expected format: +/// ```text +/// //! 
(file: .rs, desc: ) +/// ``` +fn parse_metadata_line(input: &str) -> IResult<&str, (&str, &str)> { + let parser = preceded( + multispace0, + preceded(tag("//!"), preceded(multispace0, take_while(|_| true))), + ); + let (rest, payload) = all_consuming(parser).parse(input)?; + + let content = payload + .strip_prefix("(") + .and_then(|s| s.strip_suffix(")")) + .ok_or_else(|| Err::Error(Error::new(payload, ErrorKind::Tag)))?; + + let (file, desc) = content + .strip_prefix("file:") + .ok_or_else(|| Err::Error(Error::new(payload, ErrorKind::Tag)))? + .split_once(", desc:") + .ok_or_else(|| Err::Error(Error::new(payload, ErrorKind::Tag)))?; + + Ok((rest, (file.trim(), desc.trim()))) +} + +/// Parses example entries from a group's `main.rs` file. +pub fn parse_main_rs_docs(path: &Path) -> Result> { + let content = fs::read_to_string(path)?; + let mut entries = vec![]; + let mut state = ParserState::Idle; + let mut seen_subcommands = HashSet::new(); + + for (line_no, raw_line) in content.lines().enumerate() { + let line = raw_line.trim(); + + // Try parsing subcommand, excluding `all` because it's not used in README + if let Ok((_, sub)) = parse_subcommand_line(line) { + state = if sub == "all" { + ParserState::Idle + } else { + ParserState::SeenSubcommand(sub) + }; + continue; + } + + // Try parsing metadata + if let Ok((_, (file, desc))) = parse_metadata_line(line) { + let subcommand = match state { + ParserState::SeenSubcommand(s) => s, + ParserState::Idle => { + return exec_err!( + "Metadata without preceding subcommand at {}:{}", + path.display(), + line_no + 1 + ); + } + }; + + if !seen_subcommands.insert(subcommand) { + return exec_err!("Duplicate metadata for subcommand `{subcommand}`"); + } + + entries.push(ExampleEntry { + subcommand: subcommand.to_string(), + file: file.to_string(), + desc: desc.to_string(), + }); + + state = ParserState::Idle; + continue; + } + + // If a non-blank doc line interrupts a pending subcommand, reset the state + if let 
ParserState::SeenSubcommand(_) = state + && is_non_blank_doc_line(line) + { + state = ParserState::Idle; + } + } + + Ok(entries) +} + +/// Returns `true` for non-blank Rust doc comment lines (`//!`). +/// +/// Used to detect when a subcommand is interrupted by unrelated documentation, +/// so metadata is only accepted immediately after a subcommand (blank doc lines +/// are allowed in between). +fn is_non_blank_doc_line(line: &str) -> bool { + line.starts_with("//!") && !line.trim_start_matches("//!").trim().is_empty() +} + +#[cfg(test)] +mod tests { + use super::*; + + use tempfile::TempDir; + + #[test] + fn parse_subcommand_line_accepts_valid_input() { + let line = "//! - `date_time`"; + let sub = parse_subcommand_line(line); + assert_eq!(sub, Ok(("", "date_time"))); + } + + #[test] + fn parse_subcommand_line_invalid_inputs() { + let err_lines = [ + "//! - ", + "//! - foo", + "//! - `foo` bar", + "//! --", + "//!-", + "//!--", + "//!", + "//", + "/", + "", + ]; + for line in err_lines { + assert!( + parse_subcommand_line(line).is_err(), + "expected error for input: {line}" + ); + } + } + + #[test] + fn parse_metadata_line_accepts_valid_input() { + let line = + "//! (file: date_time.rs, desc: Examples of date-time related functions)"; + let res = parse_metadata_line(line); + assert_eq!( + res, + Ok(( + "", + ("date_time.rs", "Examples of date-time related functions") + )) + ); + + let line = "//! (file: foo.rs, desc: Foo, bar, baz)"; + let res = parse_metadata_line(line); + assert_eq!(res, Ok(("", ("foo.rs", "Foo, bar, baz")))); + + let line = "//! (file: foo.rs, desc: Foo(FOO))"; + let res = parse_metadata_line(line); + assert_eq!(res, Ok(("", ("foo.rs", "Foo(FOO)")))); + } + + #[test] + fn parse_metadata_line_invalid_inputs() { + let bad_lines = [ + "//! (file: foo.rs)", + "//! (desc: missing file)", + "//! file: foo.rs, desc: test", + "//! file: foo.rs,desc: test", + "//! (file: foo.rs desc: test)", + "//! (file: foo.rs,desc: test)", + "//! 
(desc: test, file: foo.rs)", + "//! ()", + "//! (file: foo.rs, desc: test) extra", + "", + ]; + for line in bad_lines { + assert!( + parse_metadata_line(line).is_err(), + "expected error for input: {line}" + ); + } + } + + #[test] + fn parse_main_rs_docs_extracts_entries() -> Result<()> { + let tmp = TempDir::new().unwrap(); + let main_rs = tmp.path().join("main.rs"); + + fs::write( + &main_rs, + r#" + //! - `foo` + //! (file: foo.rs, desc: first example) + //! + //! - `bar` + //! (file: bar.rs, desc: second example) + "#, + )?; + + let entries = parse_main_rs_docs(&main_rs)?; + + assert_eq!(entries.len(), 2); + + assert_eq!(entries[0].subcommand, "foo"); + assert_eq!(entries[0].file, "foo.rs"); + assert_eq!(entries[0].desc, "first example"); + + assert_eq!(entries[1].subcommand, "bar"); + assert_eq!(entries[1].file, "bar.rs"); + assert_eq!(entries[1].desc, "second example"); + Ok(()) + } +} diff --git a/datafusion-examples/src/utils/example_metadata/render.rs b/datafusion-examples/src/utils/example_metadata/render.rs new file mode 100644 index 0000000000000..a4ea620e78352 --- /dev/null +++ b/datafusion-examples/src/utils/example_metadata/render.rs @@ -0,0 +1,203 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +//! Markdown renderer for DataFusion example documentation. +//! +//! This module takes parsed example metadata and generates the +//! `README.md` content for `datafusion-examples`, including group +//! sections and example tables. + +use std::path::PathBuf; + +use datafusion::error::{DataFusionError, Result}; + +use crate::utils::example_metadata::discover::discover_example_groups; +use crate::utils::example_metadata::model::ExampleGroup; +use crate::utils::example_metadata::{Category, RepoLayout}; + +const STATIC_HEADER: &str = r#" + +# DataFusion Examples + +This crate includes end to end, highly commented examples of how to use +various DataFusion APIs to help you get started. + +## Prerequisites + +Run `git submodule update --init` to init test files. + +## Running Examples + +To run an example, use the `cargo run` command, such as: + +```bash +git clone https://github.com/apache/datafusion +cd datafusion +# Download test data +git submodule update --init + +# Change to the examples directory +cd datafusion-examples/examples + +# Run all examples in a group +cargo run --example -- all + +# Run a specific example within a group +cargo run --example -- + +# Run all examples in the `dataframe` group +cargo run --example dataframe -- all + +# Run a single example from the `dataframe` group +# (apply the same pattern for any other group) +cargo run --example dataframe -- dataframe +``` +"#; + +/// Generates Markdown documentation for DataFusion examples. +/// +/// If `group` is `None`, documentation is generated for all example groups. +/// If `group` is `Some`, only that group is rendered. 
+/// +/// # Errors +/// +/// Returns an error if: +/// - the requested group does not exist +/// - a `main.rs` file is missing +/// - documentation comments are malformed +pub fn generate_examples_readme( + layout: &RepoLayout, + group: Option<&str>, +) -> Result { + let examples_root = layout.examples_root(); + + let mut out = String::new(); + out.push_str(STATIC_HEADER); + + let group_dirs: Vec = match group { + Some(name) => { + let dir = examples_root.join(name); + if !dir.is_dir() { + return Err(DataFusionError::Execution(format!( + "Example group `{name}` does not exist" + ))); + } + vec![dir] + } + None => discover_example_groups(&examples_root)?, + }; + + for group_dir in group_dirs { + let raw_name = + group_dir + .file_name() + .and_then(|s| s.to_str()) + .ok_or_else(|| { + DataFusionError::Execution("Invalid example group dir".to_string()) + })?; + + let category = Category::for_group(raw_name); + let group = ExampleGroup::from_dir(&group_dir, category)?; + + out.push_str(&group.render_markdown()); + } + + Ok(out) +} + +impl ExampleGroup { + /// Renders this example group as a Markdown section for the README. 
+ pub fn render_markdown(&self) -> String { + let mut out = String::new(); + out.push_str(&format!("\n## {} Examples\n\n", self.name.title())); + out.push_str(&format!("### Group: `{}`\n\n", self.name.raw())); + out.push_str(&format!("#### Category: {}\n\n", self.category.name())); + out.push_str("| Subcommand | File Path | Description |\n"); + out.push_str("| --- | --- | --- |\n"); + + for example in &self.examples { + out.push_str(&format!( + "| {} | [`{}/{}`](examples/{}/{}) | {} |\n", + example.subcommand, + self.name.raw(), + example.file, + self.name.raw(), + example.file, + example.desc + )); + } + + out + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::utils::example_metadata::test_utils::assert_exec_err_contains; + + use std::fs; + + use tempfile::TempDir; + + #[test] + fn single_group_generation_works() { + let tmp = TempDir::new().unwrap(); + // Fake repo root + let layout = RepoLayout::from_root(tmp.path().to_path_buf()); + + // Create: datafusion-examples/examples/builtin_functions + let examples_dir = layout.example_group_dir("builtin_functions"); + fs::create_dir_all(&examples_dir).unwrap(); + + fs::write( + examples_dir.join("main.rs"), + "//! - `x`\n//! 
(file: foo.rs, desc: test)", + ) + .unwrap(); + + let out = generate_examples_readme(&layout, Some("builtin_functions")).unwrap(); + assert!(out.contains("Builtin Functions")); + assert!(out.contains("| x | [`builtin_functions/foo.rs`]")); + } + + #[test] + fn single_group_generation_fails_if_group_missing() { + let tmp = TempDir::new().unwrap(); + let layout = RepoLayout::from_root(tmp.path().to_path_buf()); + let err = generate_examples_readme(&layout, Some("missing_group")).unwrap_err(); + assert_exec_err_contains(err, "Example group `missing_group` does not exist"); + } +} diff --git a/datafusion-examples/src/utils/example_metadata/test_utils.rs b/datafusion-examples/src/utils/example_metadata/test_utils.rs new file mode 100644 index 0000000000000..d6ab3b06ba06d --- /dev/null +++ b/datafusion-examples/src/utils/example_metadata/test_utils.rs @@ -0,0 +1,63 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Test helpers for example metadata parsing and validation. +//! +//! This module provides small, focused utilities to reduce duplication +//! and keep tests readable across the example metadata submodules. 
+ +use std::fs; + +use datafusion::error::{DataFusionError, Result}; +use tempfile::TempDir; + +use crate::utils::example_metadata::{Category, ExampleGroup}; + +/// Asserts that an `Execution` error contains the expected message fragment. +/// +/// Keeps tests focused on semantic error causes without coupling them +/// to full error string formatting. +pub fn assert_exec_err_contains(err: DataFusionError, needle: &str) { + match err { + DataFusionError::Execution(msg) => { + assert!( + msg.contains(needle), + "expected '{needle}' in error message, got: {msg}" + ); + } + other => panic!("expected Execution error, got: {other:?}"), + } +} + +/// Helper for grammar-focused tests. +/// +/// Creates a minimal temporary example group with a single `main.rs` +/// containing the provided docs. Intended for testing parsing and +/// validation rules, not full integration behavior. +pub fn example_group_from_docs(docs: &str) -> Result { + let tmp = TempDir::new().map_err(|e| { + DataFusionError::Execution(format!("Failed initializing temp dir: {e}")) + })?; + let dir = tmp.path().join("group"); + fs::create_dir(&dir).map_err(|e| { + DataFusionError::Execution(format!("Failed creating temp dir: {e}")) + })?; + fs::write(dir.join("main.rs"), docs).map_err(|e| { + DataFusionError::Execution(format!("Failed writing to temp file: {e}")) + })?; + ExampleGroup::from_dir(&dir, Category::SingleProcess) +} diff --git a/datafusion-examples/src/utils/mod.rs b/datafusion-examples/src/utils/mod.rs new file mode 100644 index 0000000000000..da96724a49cb3 --- /dev/null +++ b/datafusion-examples/src/utils/mod.rs @@ -0,0 +1,22 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod csv_to_parquet; +pub mod datasets; +pub mod example_metadata; + +pub use csv_to_parquet::write_csv_to_parquet; diff --git a/datafusion-testing b/datafusion-testing index e9f9e22ccf091..eccb0e4a42634 160000 --- a/datafusion-testing +++ b/datafusion-testing @@ -1 +1 @@ -Subproject commit e9f9e22ccf09145a7368f80fd6a871f11e2b4481 +Subproject commit eccb0e4a426344ef3faf534cd60e02e9c3afd3ac diff --git a/datafusion/catalog-listing/Cargo.toml b/datafusion/catalog-listing/Cargo.toml index b88461e7ebcbc..be1374b371485 100644 --- a/datafusion/catalog-listing/Cargo.toml +++ b/datafusion/catalog-listing/Cargo.toml @@ -18,11 +18,11 @@ [package] name = "datafusion-catalog-listing" description = "datafusion-catalog-listing" +readme = "README.md" authors.workspace = true edition.workspace = true homepage.workspace = true license.workspace = true -readme.workspace = true repository.workspace = true rust-version.workspace = true version.workspace = true @@ -39,19 +39,26 @@ datafusion-datasource = { workspace = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } datafusion-physical-expr = { workspace = true } +datafusion-physical-expr-adapter = { workspace = true } datafusion-physical-expr-common = { workspace = true } datafusion-physical-plan = { workspace = true } -datafusion-session = { workspace = true } futures = { workspace = true } +itertools = { 
workspace = true } log = { workspace = true } object_store = { workspace = true } -tokio = { workspace = true } [dev-dependencies] +datafusion-datasource-parquet = { workspace = true } +# Note: add additional linter rules in lib.rs. +# Rust does not support workspace + new linter rules in subcrates yet +# https://github.com/rust-lang/cargo/issues/13157 [lints] workspace = true [lib] name = "datafusion_catalog_listing" path = "src/mod.rs" + +[package.metadata.cargo-machete] +ignored = ["datafusion-datasource-parquet"] diff --git a/datafusion/catalog-listing/README.md b/datafusion/catalog-listing/README.md index b4760c413d60b..81a7c7b1da3ae 100644 --- a/datafusion/catalog-listing/README.md +++ b/datafusion/catalog-listing/README.md @@ -17,14 +17,20 @@ under the License. --> -# DataFusion catalog-listing +# Apache DataFusion Catalog Listing -[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. +[Apache DataFusion] is an extensible query execution framework, written in Rust, that uses [Apache Arrow] as its in-memory format. This crate is a submodule of DataFusion with [ListingTable], an implementation of [TableProvider] based on files in a directory (either locally or on remote object storage such as S3). -[df]: https://crates.io/crates/datafusion +Most projects should use the [`datafusion`] crate directly, which re-exports +this module. If you are already using the [`datafusion`] crate, there is no +reason to use this crate directly in your project as well. 
+ +[apache arrow]: https://arrow.apache.org/ +[apache datafusion]: https://datafusion.apache.org/ [listingtable]: https://docs.rs/datafusion/latest/datafusion/datasource/listing/struct.ListingTable.html [tableprovider]: https://docs.rs/datafusion/latest/datafusion/datasource/trait.TableProvider.html +[`datafusion`]: https://crates.io/crates/datafusion diff --git a/datafusion/catalog-listing/src/config.rs b/datafusion/catalog-listing/src/config.rs new file mode 100644 index 0000000000000..ca4d2abfcd737 --- /dev/null +++ b/datafusion/catalog-listing/src/config.rs @@ -0,0 +1,319 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::options::ListingOptions; +use arrow::datatypes::{DataType, Schema, SchemaRef}; +use datafusion_catalog::Session; +use datafusion_common::{config_err, internal_err}; +use datafusion_datasource::ListingTableUrl; +use datafusion_datasource::file_compression_type::FileCompressionType; +#[expect(deprecated)] +use datafusion_datasource::schema_adapter::SchemaAdapterFactory; +use datafusion_physical_expr_adapter::PhysicalExprAdapterFactory; +use std::str::FromStr; +use std::sync::Arc; + +/// Indicates the source of the schema for a [`crate::ListingTable`] +// PartialEq required for assert_eq! in tests +#[derive(Debug, Clone, Copy, PartialEq, Default)] +pub enum SchemaSource { + /// Schema is not yet set (initial state) + #[default] + Unset, + /// Schema was inferred from first table_path + Inferred, + /// Schema was specified explicitly via with_schema + Specified, +} + +/// Configuration for creating a [`crate::ListingTable`] +/// +/// # Schema Evolution Support +/// +/// This configuration supports schema evolution through the optional +/// [`PhysicalExprAdapterFactory`]. You might want to override the default factory when you need: +/// +/// - **Type coercion requirements**: When you need custom logic for converting between +/// different Arrow data types (e.g., Int32 ↔ Int64, Utf8 ↔ LargeUtf8) +/// - **Column mapping**: You need to map columns with a legacy name to a new name +/// - **Custom handling of missing columns**: By default they are filled in with nulls, but you may e.g. want to fill them in with `0` or `""`. +#[derive(Debug, Clone, Default)] +pub struct ListingTableConfig { + /// Paths on the `ObjectStore` for creating [`crate::ListingTable`]. + /// They should share the same schema and object store. + pub table_paths: Vec, + /// Optional `SchemaRef` for the to be created [`crate::ListingTable`]. 
+ /// + /// See details on [`ListingTableConfig::with_schema`] + pub file_schema: Option, + /// Optional [`ListingOptions`] for the to be created [`crate::ListingTable`]. + /// + /// See details on [`ListingTableConfig::with_listing_options`] + pub options: Option, + /// Tracks the source of the schema information + pub(crate) schema_source: SchemaSource, + /// Optional [`PhysicalExprAdapterFactory`] for creating physical expression adapters + pub(crate) expr_adapter_factory: Option>, +} + +impl ListingTableConfig { + /// Creates new [`ListingTableConfig`] for reading the specified URL + pub fn new(table_path: ListingTableUrl) -> Self { + Self { + table_paths: vec![table_path], + ..Default::default() + } + } + + /// Creates new [`ListingTableConfig`] with multiple table paths. + /// + /// See `ListingTableConfigExt::infer_options` for details on what happens with multiple paths + pub fn new_with_multi_paths(table_paths: Vec) -> Self { + Self { + table_paths, + ..Default::default() + } + } + + /// Returns the source of the schema for this configuration + pub fn schema_source(&self) -> SchemaSource { + self.schema_source + } + /// Set the `schema` for the overall [`crate::ListingTable`] + /// + /// [`crate::ListingTable`] will automatically coerce, when possible, the schema + /// for individual files to match this schema. + /// + /// If a schema is not provided, it is inferred using + /// [`Self::infer_schema`]. + /// + /// If the schema is provided, it must contain only the fields in the file + /// without the table partitioning columns. 
+ /// + /// # Example: Specifying Table Schema + /// ```rust + /// # use std::sync::Arc; + /// # use datafusion_catalog_listing::{ListingTableConfig, ListingOptions}; + /// # use datafusion_datasource::ListingTableUrl; + /// # use datafusion_datasource_parquet::file_format::ParquetFormat; + /// # use arrow::datatypes::{Schema, Field, DataType}; + /// # let table_paths = ListingTableUrl::parse("file:///path/to/data").unwrap(); + /// # let listing_options = ListingOptions::new(Arc::new(ParquetFormat::default())); + /// let schema = Arc::new(Schema::new(vec![ + /// Field::new("id", DataType::Int64, false), + /// Field::new("name", DataType::Utf8, true), + /// ])); + /// + /// let config = ListingTableConfig::new(table_paths) + /// .with_listing_options(listing_options) // Set options first + /// .with_schema(schema); // Then set schema + /// ``` + pub fn with_schema(self, schema: SchemaRef) -> Self { + // Note: We preserve existing options state, but downstream code may expect + // options to be set. Consider calling with_listing_options() or infer_options() + // before operations that require options to be present. + debug_assert!( + self.options.is_some() || cfg!(test), + "ListingTableConfig::with_schema called without options set. \ + Consider calling with_listing_options() or infer_options() first to avoid panics in downstream code." + ); + + Self { + file_schema: Some(schema), + schema_source: SchemaSource::Specified, + ..self + } + } + + /// Add `listing_options` to [`ListingTableConfig`] + /// + /// If not provided, format and other options are inferred via + /// `ListingTableConfigExt::infer_options`. 
+ /// + /// # Example: Configuring Parquet Files with Custom Options + /// ```rust + /// # use std::sync::Arc; + /// # use datafusion_catalog_listing::{ListingTableConfig, ListingOptions}; + /// # use datafusion_datasource::ListingTableUrl; + /// # use datafusion_datasource_parquet::file_format::ParquetFormat; + /// # let table_paths = ListingTableUrl::parse("file:///path/to/data").unwrap(); + /// let options = ListingOptions::new(Arc::new(ParquetFormat::default())) + /// .with_file_extension(".parquet") + /// .with_collect_stat(true); + /// + /// let config = ListingTableConfig::new(table_paths).with_listing_options(options); + /// // Configure file format and options + /// ``` + pub fn with_listing_options(self, listing_options: ListingOptions) -> Self { + // Note: This method properly sets options, but be aware that downstream + // methods like infer_schema() and try_new() require both schema and options + // to be set to function correctly. + debug_assert!( + !self.table_paths.is_empty() || cfg!(test), + "ListingTableConfig::with_listing_options called without table_paths set. \ + Consider calling new() or new_with_multi_paths() first to establish table paths." 
+ ); + + Self { + options: Some(listing_options), + ..self + } + } + + /// Returns a tuple of `(file_extension, optional compression_extension)` + /// + /// For example a path ending with blah.test.csv.gz returns `("csv", Some("gz"))` + /// For example a path ending with blah.test.csv returns `("csv", None)` + pub fn infer_file_extension_and_compression_type( + path: &str, + ) -> datafusion_common::Result<(String, Option)> { + let mut exts = path.rsplit('.'); + + let split = exts.next().unwrap_or(""); + + let file_compression_type = FileCompressionType::from_str(split) + .unwrap_or(FileCompressionType::UNCOMPRESSED); + + if file_compression_type.is_compressed() { + let split2 = exts.next().unwrap_or(""); + Ok((split2.to_string(), Some(split.to_string()))) + } else { + Ok((split.to_string(), None)) + } + } + + /// Infer the [`SchemaRef`] based on `table_path`s. + /// + /// This method infers the table schema using the first `table_path`. + /// See [`ListingOptions::infer_schema`] for more details + /// + /// # Errors + /// * if `self.options` is not set. 
See [`Self::with_listing_options`] + pub async fn infer_schema( + self, + state: &dyn Session, + ) -> datafusion_common::Result { + match self.options { + Some(options) => { + let ListingTableConfig { + table_paths, + file_schema, + options: _, + schema_source, + expr_adapter_factory, + } = self; + + let (schema, new_schema_source) = match file_schema { + Some(schema) => (schema, schema_source), // Keep existing source if schema exists + None => { + if let Some(url) = table_paths.first() { + ( + options.infer_schema(state, url).await?, + SchemaSource::Inferred, + ) + } else { + (Arc::new(Schema::empty()), SchemaSource::Inferred) + } + } + }; + + Ok(Self { + table_paths, + file_schema: Some(schema), + options: Some(options), + schema_source: new_schema_source, + expr_adapter_factory, + }) + } + None => internal_err!("No `ListingOptions` set for inferring schema"), + } + } + + /// Infer the partition columns from `table_paths`. + /// + /// # Errors + /// * if `self.options` is not set. See [`Self::with_listing_options`] + pub async fn infer_partitions_from_path( + self, + state: &dyn Session, + ) -> datafusion_common::Result { + match self.options { + Some(options) => { + let Some(url) = self.table_paths.first() else { + return config_err!("No table path found"); + }; + let partitions = options + .infer_partitions(state, url) + .await? 
+ .into_iter() + .map(|col_name| { + ( + col_name, + DataType::Dictionary( + Box::new(DataType::UInt16), + Box::new(DataType::Utf8), + ), + ) + }) + .collect::>(); + let options = options.with_table_partition_cols(partitions); + Ok(Self { + table_paths: self.table_paths, + file_schema: self.file_schema, + options: Some(options), + schema_source: self.schema_source, + expr_adapter_factory: self.expr_adapter_factory, + }) + } + None => config_err!("No `ListingOptions` set for inferring schema"), + } + } + + /// Set the [`PhysicalExprAdapterFactory`] for the [`crate::ListingTable`] + /// + /// The expression adapter factory is used to create physical expression adapters that can + /// handle schema evolution and type conversions when evaluating expressions + /// with different schemas than the table schema. + pub fn with_expr_adapter_factory( + self, + expr_adapter_factory: Arc, + ) -> Self { + Self { + expr_adapter_factory: Some(expr_adapter_factory), + ..self + } + } + + /// Deprecated: Set the [`SchemaAdapterFactory`] for the [`crate::ListingTable`] + /// + /// `SchemaAdapterFactory` has been removed. Use [`Self::with_expr_adapter_factory`] + /// and `PhysicalExprAdapterFactory` instead. See `upgrading.md` for more details. + /// + /// This method is a no-op and returns `self` unchanged. + #[deprecated( + since = "52.0.0", + note = "SchemaAdapterFactory has been removed. Use with_expr_adapter_factory and PhysicalExprAdapterFactory instead. See upgrading.md for more details." 
+ )] + #[expect(deprecated)] + pub fn with_schema_adapter_factory( + self, + _schema_adapter_factory: Arc, + ) -> Self { + // No-op - just return self unchanged + self + } +} diff --git a/datafusion/catalog-listing/src/helpers.rs b/datafusion/catalog-listing/src/helpers.rs index 037c69cebd572..031b2ebfb8109 100644 --- a/datafusion/catalog-listing/src/helpers.rs +++ b/datafusion/catalog-listing/src/helpers.rs @@ -21,25 +21,23 @@ use std::mem; use std::sync::Arc; use datafusion_catalog::Session; -use datafusion_common::internal_err; -use datafusion_common::{HashMap, Result, ScalarValue}; +use datafusion_common::{HashMap, Result, ScalarValue, assert_or_internal_err}; use datafusion_datasource::ListingTableUrl; use datafusion_datasource::PartitionedFile; -use datafusion_expr::{BinaryExpr, Operator}; +use datafusion_expr::{BinaryExpr, Operator, lit, utils}; use arrow::{ - array::{Array, ArrayRef, AsArray, StringBuilder}, - compute::{and, cast, prep_null_mask_filter}, - datatypes::{DataType, Field, Fields, Schema}, + array::AsArray, + datatypes::{DataType, Field}, record_batch::RecordBatch, }; use datafusion_expr::execution_props::ExecutionProps; use futures::stream::FuturesUnordered; -use futures::{stream::BoxStream, StreamExt, TryStreamExt}; +use futures::{StreamExt, TryStreamExt, stream::BoxStream}; use log::{debug, trace}; use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; -use datafusion_common::{Column, DFSchema, DataFusionError}; +use datafusion_common::{Column, DFSchema}; use datafusion_expr::{Expr, Volatility}; use datafusion_physical_expr::create_physical_expr; use object_store::path::Path; @@ -53,7 +51,7 @@ use object_store::{ObjectMeta, ObjectStore}; pub fn expr_applicable_for_cols(col_names: &[&str], expr: &Expr) -> bool { let mut is_applicable = true; expr.apply(|expr| match expr { - Expr::Column(Column { ref name, .. }) => { + Expr::Column(Column { name, .. 
}) => { is_applicable &= col_names.contains(&name.as_str()); if is_applicable { Ok(TreeNodeRecursion::Jump) @@ -61,7 +59,7 @@ pub fn expr_applicable_for_cols(col_names: &[&str], expr: &Expr) -> bool { Ok(TreeNodeRecursion::Stop) } } - Expr::Literal(_) + Expr::Literal(_, _) | Expr::Alias(_) | Expr::OuterReferenceColumn(_, _) | Expr::ScalarVariable(_, _) @@ -85,6 +83,7 @@ pub fn expr_applicable_for_cols(col_names: &[&str], expr: &Expr) -> bool { | Expr::Exists(_) | Expr::InSubquery(_) | Expr::ScalarSubquery(_) + | Expr::SetComparison(_) | Expr::GroupingSet(_) | Expr::Case(_) => Ok(TreeNodeRecursion::Continue), @@ -156,6 +155,7 @@ pub fn split_files( chunks } +#[derive(Debug)] pub struct Partition { /// The path to the partition, including the table prefix path: Path, @@ -238,96 +238,6 @@ pub async fn list_partitions( Ok(out) } -async fn prune_partitions( - table_path: &ListingTableUrl, - partitions: Vec, - filters: &[Expr], - partition_cols: &[(String, DataType)], -) -> Result> { - if filters.is_empty() { - return Ok(partitions); - } - - let mut builders: Vec<_> = (0..partition_cols.len()) - .map(|_| StringBuilder::with_capacity(partitions.len(), partitions.len() * 10)) - .collect(); - - for partition in &partitions { - let cols = partition_cols.iter().map(|x| x.0.as_str()); - let parsed = parse_partitions_for_path(table_path, &partition.path, cols) - .unwrap_or_default(); - - let mut builders = builders.iter_mut(); - for (p, b) in parsed.iter().zip(&mut builders) { - b.append_value(p); - } - builders.for_each(|b| b.append_null()); - } - - let arrays = partition_cols - .iter() - .zip(builders) - .map(|((_, d), mut builder)| { - let array = builder.finish(); - cast(&array, d) - }) - .collect::>()?; - - let fields: Fields = partition_cols - .iter() - .map(|(n, d)| Field::new(n, d.clone(), true)) - .collect(); - let schema = Arc::new(Schema::new(fields)); - - let df_schema = DFSchema::from_unqualified_fields( - partition_cols - .iter() - .map(|(n, d)| Field::new(n, 
d.clone(), true)) - .collect(), - Default::default(), - )?; - - let batch = RecordBatch::try_new(schema, arrays)?; - - // TODO: Plumb this down - let props = ExecutionProps::new(); - - // Applies `filter` to `batch` returning `None` on error - let do_filter = |filter| -> Result { - let expr = create_physical_expr(filter, &df_schema, &props)?; - expr.evaluate(&batch)?.into_array(partitions.len()) - }; - - //.Compute the conjunction of the filters - let mask = filters - .iter() - .map(|f| do_filter(f).map(|a| a.as_boolean().clone())) - .reduce(|a, b| Ok(and(&a?, &b?)?)); - - let mask = match mask { - Some(Ok(mask)) => mask, - Some(Err(err)) => return Err(err), - None => return Ok(partitions), - }; - - // Don't retain partitions that evaluated to null - let prepared = match mask.null_count() { - 0 => mask, - _ => prep_null_mask_filter(&mask), - }; - - // Sanity check - assert_eq!(prepared.len(), partitions.len()); - - let filtered = partitions - .into_iter() - .zip(prepared.values()) - .filter_map(|(p, f)| f.then_some(p)) - .collect(); - - Ok(filtered) -} - #[derive(Debug)] enum PartitionValue { Single(String), @@ -338,16 +248,11 @@ fn populate_partition_values<'a>( partition_values: &mut HashMap<&'a str, PartitionValue>, filter: &'a Expr, ) { - if let Expr::BinaryExpr(BinaryExpr { - ref left, - op, - ref right, - }) = filter - { + if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = filter { match op { Operator::Eq => match (left.as_ref(), right.as_ref()) { - (Expr::Column(Column { ref name, .. }), Expr::Literal(val)) - | (Expr::Literal(val), Expr::Column(Column { ref name, .. })) => { + (Expr::Column(Column { name, .. }), Expr::Literal(val, _)) + | (Expr::Literal(val, _), Expr::Column(Column { name, .. 
})) => { if partition_values .insert(name, PartitionValue::Single(val.to_string())) .is_some() @@ -402,6 +307,62 @@ pub fn evaluate_partition_prefix<'a>( } } +fn filter_partitions( + pf: PartitionedFile, + filters: &[Expr], + df_schema: &DFSchema, +) -> Result> { + if pf.partition_values.is_empty() && !filters.is_empty() { + return Ok(None); + } else if filters.is_empty() { + return Ok(Some(pf)); + } + + let arrays = pf + .partition_values + .iter() + .map(|v| v.to_array()) + .collect::>()?; + + let batch = RecordBatch::try_new(Arc::clone(df_schema.inner()), arrays)?; + + let filter = utils::conjunction(filters.iter().cloned()).unwrap_or_else(|| lit(true)); + let props = ExecutionProps::new(); + let expr = create_physical_expr(&filter, df_schema, &props)?; + + // Since we're only operating on a single file, our batch and resulting "array" holds only one + // value indicating if the input file matches the provided filters + let matches = expr.evaluate(&batch)?.into_array(1)?; + if matches.as_boolean().value(0) { + return Ok(Some(pf)); + } + + Ok(None) +} + +fn try_into_partitioned_file( + object_meta: ObjectMeta, + partition_cols: &[(String, DataType)], + table_path: &ListingTableUrl, +) -> Result { + let cols = partition_cols.iter().map(|(name, _)| name.as_str()); + let parsed = parse_partitions_for_path(table_path, &object_meta.location, cols); + + let partition_values = parsed + .into_iter() + .flatten() + .zip(partition_cols) + .map(|(parsed, (_, datatype))| { + ScalarValue::try_from_string(parsed.to_string(), datatype) + }) + .collect::>>()?; + + let mut pf: PartitionedFile = object_meta.into(); + pf.partition_values = partition_values; + + Ok(pf) +} + /// Discover the partitions on the given path and prune out files /// that belong to irrelevant partitions using `filters` expressions. 
/// `filters` should only contain expressions that can be evaluated @@ -414,79 +375,46 @@ pub async fn pruned_partition_list<'a>( file_extension: &'a str, partition_cols: &'a [(String, DataType)], ) -> Result>> { - // if no partition col => simply list all the files - if partition_cols.is_empty() { - if !filters.is_empty() { - return internal_err!( - "Got partition filters for unpartitioned table {}", - table_path - ); - } - return Ok(Box::pin( - table_path - .list_all_files(ctx, store, file_extension) - .await? - .try_filter(|object_meta| futures::future::ready(object_meta.size > 0)) - .map_ok(|object_meta| object_meta.into()), - )); - } - - let partition_prefix = evaluate_partition_prefix(partition_cols, filters); - let partitions = - list_partitions(store, table_path, partition_cols.len(), partition_prefix) - .await?; - debug!("Listed {} partitions", partitions.len()); - - let pruned = - prune_partitions(table_path, partitions, filters, partition_cols).await?; + let prefix = if !partition_cols.is_empty() { + evaluate_partition_prefix(partition_cols, filters) + } else { + None + }; - debug!("Pruning yielded {} partitions", pruned.len()); + let objects = table_path + .list_prefixed_files(ctx, store, prefix, file_extension) + .await? 
+ .try_filter(|object_meta| futures::future::ready(object_meta.size > 0)); - let stream = futures::stream::iter(pruned) - .map(move |partition: Partition| async move { - let cols = partition_cols.iter().map(|x| x.0.as_str()); - let parsed = parse_partitions_for_path(table_path, &partition.path, cols); + if partition_cols.is_empty() { + assert_or_internal_err!( + filters.is_empty(), + "Got partition filters for unpartitioned table {}", + table_path + ); - let partition_values = parsed - .into_iter() - .flatten() - .zip(partition_cols) - .map(|(parsed, (_, datatype))| { - ScalarValue::try_from_string(parsed.to_string(), datatype) - }) - .collect::>>()?; - - let files = match partition.files { - Some(files) => files, - None => { - trace!("Recursively listing partition {}", partition.path); - store.list(Some(&partition.path)).try_collect().await? - } - }; - let files = files.into_iter().filter(move |o| { - let extension_match = o.location.as_ref().ends_with(file_extension); - // here need to scan subdirectories(`listing_table_ignore_subdirectory` = false) - let glob_match = table_path.contains(&o.location, false); - extension_match && glob_match - }); - - let stream = futures::stream::iter(files.map(move |object_meta| { - Ok(PartitionedFile { - object_meta, - partition_values: partition_values.clone(), - range: None, - statistics: None, - extensions: None, - metadata_size_hint: None, - }) - })); - - Ok::<_, DataFusionError>(stream) - }) - .buffer_unordered(CONCURRENCY_LIMIT) - .try_flatten() - .boxed(); - Ok(stream) + // if no partition col => simply list all the files + Ok(objects.map_ok(|object_meta| object_meta.into()).boxed()) + } else { + let df_schema = DFSchema::from_unqualified_fields( + partition_cols + .iter() + .map(|(n, d)| Field::new(n, d.clone(), true)) + .collect(), + Default::default(), + )?; + + Ok(objects + .map_ok(|object_meta| { + try_into_partitioned_file(object_meta, partition_cols, table_path) + }) + .try_filter_map(move |pf| { + 
futures::future::ready( + pf.and_then(|pf| filter_partitions(pf, filters, &df_schema)), + ) + }) + .boxed()) + } } /// Extract the partition values for the given `file_path` (in the given `table_path`) @@ -502,12 +430,12 @@ where let subpath = table_path.strip_prefix(file_path)?; let mut part_values = vec![]; - for (part, pn) in subpath.zip(table_partition_cols) { + for (part, expected_partition) in subpath.zip(table_partition_cols) { match part.split_once('=') { - Some((name, val)) if name == pn => part_values.push(val), + Some((name, val)) if name == expected_partition => part_values.push(val), _ => { debug!( - "Ignoring file: file_path='{file_path}', table_path='{table_path}', part='{part}', partition_col='{pn}'", + "Ignoring file: file_path='{file_path}', table_path='{table_path}', part='{part}', partition_col='{expected_partition}'", ); return None; } @@ -530,22 +458,11 @@ pub fn describe_partition(partition: &Partition) -> (&str, usize, Vec<&str>) { #[cfg(test)] mod tests { - use async_trait::async_trait; - use datafusion_common::config::TableOptions; use datafusion_datasource::file_groups::FileGroup; - use datafusion_execution::config::SessionConfig; - use datafusion_execution::runtime_env::RuntimeEnv; - use futures::FutureExt; - use object_store::memory::InMemory; - use std::any::Any; use std::ops::Not; use super::*; - use datafusion_expr::{ - case, col, lit, AggregateUDF, Expr, LogicalPlan, ScalarUDF, WindowUDF, - }; - use datafusion_physical_expr_common::physical_expr::PhysicalExpr; - use datafusion_physical_plan::ExecutionPlan; + use datafusion_expr::{Expr, case, col, lit}; #[test] fn test_split_files() { @@ -588,205 +505,6 @@ mod tests { assert_eq!(0, chunks.len()); } - #[tokio::test] - async fn test_pruned_partition_list_empty() { - let (store, state) = make_test_store_and_state(&[ - ("tablepath/mypartition=val1/notparquetfile", 100), - ("tablepath/mypartition=val1/ignoresemptyfile.parquet", 0), - ("tablepath/file.parquet", 100), - ]); - let filter = 
Expr::eq(col("mypartition"), lit("val1")); - let pruned = pruned_partition_list( - state.as_ref(), - store.as_ref(), - &ListingTableUrl::parse("file:///tablepath/").unwrap(), - &[filter], - ".parquet", - &[(String::from("mypartition"), DataType::Utf8)], - ) - .await - .expect("partition pruning failed") - .collect::>() - .await; - - assert_eq!(pruned.len(), 0); - } - - #[tokio::test] - async fn test_pruned_partition_list() { - let (store, state) = make_test_store_and_state(&[ - ("tablepath/mypartition=val1/file.parquet", 100), - ("tablepath/mypartition=val2/file.parquet", 100), - ("tablepath/mypartition=val1/ignoresemptyfile.parquet", 0), - ("tablepath/mypartition=val1/other=val3/file.parquet", 100), - ]); - let filter = Expr::eq(col("mypartition"), lit("val1")); - let pruned = pruned_partition_list( - state.as_ref(), - store.as_ref(), - &ListingTableUrl::parse("file:///tablepath/").unwrap(), - &[filter], - ".parquet", - &[(String::from("mypartition"), DataType::Utf8)], - ) - .await - .expect("partition pruning failed") - .try_collect::>() - .await - .unwrap(); - - assert_eq!(pruned.len(), 2); - let f1 = &pruned[0]; - assert_eq!( - f1.object_meta.location.as_ref(), - "tablepath/mypartition=val1/file.parquet" - ); - assert_eq!(&f1.partition_values, &[ScalarValue::from("val1")]); - let f2 = &pruned[1]; - assert_eq!( - f2.object_meta.location.as_ref(), - "tablepath/mypartition=val1/other=val3/file.parquet" - ); - assert_eq!(f2.partition_values, &[ScalarValue::from("val1"),]); - } - - #[tokio::test] - async fn test_pruned_partition_list_multi() { - let (store, state) = make_test_store_and_state(&[ - ("tablepath/part1=p1v1/file.parquet", 100), - ("tablepath/part1=p1v2/part2=p2v1/file1.parquet", 100), - ("tablepath/part1=p1v2/part2=p2v1/file2.parquet", 100), - ("tablepath/part1=p1v3/part2=p2v1/file2.parquet", 100), - ("tablepath/part1=p1v2/part2=p2v2/file2.parquet", 100), - ]); - let filter1 = Expr::eq(col("part1"), lit("p1v2")); - let filter2 = Expr::eq(col("part2"), 
lit("p2v1")); - let pruned = pruned_partition_list( - state.as_ref(), - store.as_ref(), - &ListingTableUrl::parse("file:///tablepath/").unwrap(), - &[filter1, filter2], - ".parquet", - &[ - (String::from("part1"), DataType::Utf8), - (String::from("part2"), DataType::Utf8), - ], - ) - .await - .expect("partition pruning failed") - .try_collect::>() - .await - .unwrap(); - - assert_eq!(pruned.len(), 2); - let f1 = &pruned[0]; - assert_eq!( - f1.object_meta.location.as_ref(), - "tablepath/part1=p1v2/part2=p2v1/file1.parquet" - ); - assert_eq!( - &f1.partition_values, - &[ScalarValue::from("p1v2"), ScalarValue::from("p2v1"),] - ); - let f2 = &pruned[1]; - assert_eq!( - f2.object_meta.location.as_ref(), - "tablepath/part1=p1v2/part2=p2v1/file2.parquet" - ); - assert_eq!( - &f2.partition_values, - &[ScalarValue::from("p1v2"), ScalarValue::from("p2v1")] - ); - } - - #[tokio::test] - async fn test_list_partition() { - let (store, _) = make_test_store_and_state(&[ - ("tablepath/part1=p1v1/file.parquet", 100), - ("tablepath/part1=p1v2/part2=p2v1/file1.parquet", 100), - ("tablepath/part1=p1v2/part2=p2v1/file2.parquet", 100), - ("tablepath/part1=p1v3/part2=p2v1/file3.parquet", 100), - ("tablepath/part1=p1v2/part2=p2v2/file4.parquet", 100), - ("tablepath/part1=p1v2/part2=p2v2/empty.parquet", 0), - ]); - - let partitions = list_partitions( - store.as_ref(), - &ListingTableUrl::parse("file:///tablepath/").unwrap(), - 0, - None, - ) - .await - .expect("listing partitions failed"); - - assert_eq!( - &partitions - .iter() - .map(describe_partition) - .collect::>(), - &vec![ - ("tablepath", 0, vec![]), - ("tablepath/part1=p1v1", 1, vec![]), - ("tablepath/part1=p1v2", 1, vec![]), - ("tablepath/part1=p1v3", 1, vec![]), - ] - ); - - let partitions = list_partitions( - store.as_ref(), - &ListingTableUrl::parse("file:///tablepath/").unwrap(), - 1, - None, - ) - .await - .expect("listing partitions failed"); - - assert_eq!( - &partitions - .iter() - .map(describe_partition) - 
.collect::>(), - &vec![ - ("tablepath", 0, vec![]), - ("tablepath/part1=p1v1", 1, vec!["file.parquet"]), - ("tablepath/part1=p1v2", 1, vec![]), - ("tablepath/part1=p1v2/part2=p2v1", 2, vec![]), - ("tablepath/part1=p1v2/part2=p2v2", 2, vec![]), - ("tablepath/part1=p1v3", 1, vec![]), - ("tablepath/part1=p1v3/part2=p2v1", 2, vec![]), - ] - ); - - let partitions = list_partitions( - store.as_ref(), - &ListingTableUrl::parse("file:///tablepath/").unwrap(), - 2, - None, - ) - .await - .expect("listing partitions failed"); - - assert_eq!( - &partitions - .iter() - .map(describe_partition) - .collect::>(), - &vec![ - ("tablepath", 0, vec![]), - ("tablepath/part1=p1v1", 1, vec!["file.parquet"]), - ("tablepath/part1=p1v2", 1, vec![]), - ("tablepath/part1=p1v3", 1, vec![]), - ( - "tablepath/part1=p1v2/part2=p2v1", - 2, - vec!["file1.parquet", "file2.parquet"] - ), - ("tablepath/part1=p1v2/part2=p2v2", 2, vec!["file4.parquet"]), - ("tablepath/part1=p1v3/part2=p2v1", 2, vec!["file3.parquet"]), - ] - ); - } - #[test] fn test_parse_partitions_for_path() { assert_eq!( @@ -984,7 +702,7 @@ mod tests { assert_eq!( evaluate_partition_prefix( partitions, - &[col("a").eq(Expr::Literal(ScalarValue::Date32(Some(3))))], + &[col("a").eq(Expr::Literal(ScalarValue::Date32(Some(3)), None))], ), Some(Path::from("a=1970-01-04")), ); @@ -993,93 +711,12 @@ mod tests { assert_eq!( evaluate_partition_prefix( partitions, - &[col("a").eq(Expr::Literal(ScalarValue::Date64(Some( - 4 * 24 * 60 * 60 * 1000 - )))),], + &[col("a").eq(Expr::Literal( + ScalarValue::Date64(Some(4 * 24 * 60 * 60 * 1000)), + None + )),], ), Some(Path::from("a=1970-01-05")), ); } - - pub fn make_test_store_and_state( - files: &[(&str, u64)], - ) -> (Arc, Arc) { - let memory = InMemory::new(); - - for (name, size) in files { - memory - .put(&Path::from(*name), vec![0; *size as usize].into()) - .now_or_never() - .unwrap() - .unwrap(); - } - - (Arc::new(memory), Arc::new(MockSession {})) - } - - struct MockSession {} - - 
#[async_trait] - impl Session for MockSession { - fn session_id(&self) -> &str { - unimplemented!() - } - - fn config(&self) -> &SessionConfig { - unimplemented!() - } - - async fn create_physical_plan( - &self, - _logical_plan: &LogicalPlan, - ) -> Result> { - unimplemented!() - } - - fn create_physical_expr( - &self, - _expr: Expr, - _df_schema: &DFSchema, - ) -> Result> { - unimplemented!() - } - - fn scalar_functions(&self) -> &std::collections::HashMap> { - unimplemented!() - } - - fn aggregate_functions( - &self, - ) -> &std::collections::HashMap> { - unimplemented!() - } - - fn window_functions(&self) -> &std::collections::HashMap> { - unimplemented!() - } - - fn runtime_env(&self) -> &Arc { - unimplemented!() - } - - fn execution_props(&self) -> &ExecutionProps { - unimplemented!() - } - - fn as_any(&self) -> &dyn Any { - unimplemented!() - } - - fn table_options(&self) -> &TableOptions { - unimplemented!() - } - - fn table_options_mut(&mut self) -> &mut TableOptions { - unimplemented!() - } - - fn task_ctx(&self) -> Arc { - unimplemented!() - } - } } diff --git a/datafusion/catalog-listing/src/mod.rs b/datafusion/catalog-listing/src/mod.rs index fb0a960f37b6a..9efb5aa96267e 100644 --- a/datafusion/catalog-listing/src/mod.rs +++ b/datafusion/catalog-listing/src/mod.rs @@ -15,13 +15,21 @@ // specific language governing permissions and limitations // under the License. 
+#![cfg_attr(test, allow(clippy::needless_pass_by_value))] #![doc( html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![cfg_attr(not(test), deny(clippy::clone_on_ref_ptr))] +mod config; pub mod helpers; +mod options; +mod table; + +pub use config::{ListingTableConfig, SchemaSource}; +pub use options::ListingOptions; +pub use table::{ListFilesResult, ListingTable}; diff --git a/datafusion/catalog-listing/src/options.rs b/datafusion/catalog-listing/src/options.rs new file mode 100644 index 0000000000000..146f98d62335e --- /dev/null +++ b/datafusion/catalog-listing/src/options.rs @@ -0,0 +1,399 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow::datatypes::{DataType, SchemaRef}; +use datafusion_catalog::Session; +use datafusion_common::plan_err; +use datafusion_datasource::ListingTableUrl; +use datafusion_datasource::file_format::FileFormat; +use datafusion_execution::config::SessionConfig; +use datafusion_expr::SortExpr; +use futures::StreamExt; +use futures::{TryStreamExt, future}; +use itertools::Itertools; +use std::sync::Arc; + +/// Options for creating a [`crate::ListingTable`] +#[derive(Clone, Debug)] +pub struct ListingOptions { + /// A suffix on which files should be filtered (leave empty to + /// keep all files on the path) + pub file_extension: String, + /// The file format + pub format: Arc, + /// The expected partition column names in the folder structure. + /// See [Self::with_table_partition_cols] for details + pub table_partition_cols: Vec<(String, DataType)>, + /// Set true to try to guess statistics from the files. + /// This can add a lot of overhead as it will usually require files + /// to be opened and at least partially parsed. + pub collect_stat: bool, + /// Group files to avoid that the number of partitions exceeds + /// this limit + pub target_partitions: usize, + /// Optional pre-known sort order(s). Must be `SortExpr`s. + /// + /// DataFusion may take advantage of this ordering to omit sorts + /// or use more efficient algorithms. Currently sortedness must be + /// provided if it is known by some external mechanism, but may in + /// the future be automatically determined, for example using + /// parquet metadata. + /// + /// See + /// + /// NOTE: This attribute stores all equivalent orderings (the outer `Vec`) + /// where each ordering consists of an individual lexicographic + /// ordering (encapsulated by a `Vec`). If there aren't + /// multiple equivalent orderings, the outer `Vec` will have a + /// single element. 
+ pub file_sort_order: Vec>, +} + +impl ListingOptions { + /// Creates an options instance with the given format + /// Default values: + /// - use default file extension filter + /// - no input partition to discover + /// - one target partition + /// - do not collect statistics + pub fn new(format: Arc) -> Self { + Self { + file_extension: format.get_ext(), + format, + table_partition_cols: vec![], + collect_stat: false, + target_partitions: 1, + file_sort_order: vec![], + } + } + + /// Set options from [`SessionConfig`] and returns self. + /// + /// Currently this sets `target_partitions` and `collect_stat` + /// but if more options are added in the future that need to be coordinated + /// they will be synchronized through this method. + pub fn with_session_config_options(mut self, config: &SessionConfig) -> Self { + self = self.with_target_partitions(config.target_partitions()); + self = self.with_collect_stat(config.collect_statistics()); + self + } + + /// Set file extension on [`ListingOptions`] and returns self. + /// + /// # Example + /// ``` + /// # use std::sync::Arc; + /// # use datafusion_catalog_listing::ListingOptions; + /// # use datafusion_datasource_parquet::file_format::ParquetFormat; + /// + /// let listing_options = ListingOptions::new(Arc::new(ParquetFormat::default())) + /// .with_file_extension(".parquet"); + /// + /// assert_eq!(listing_options.file_extension, ".parquet"); + /// ``` + pub fn with_file_extension(mut self, file_extension: impl Into) -> Self { + self.file_extension = file_extension.into(); + self + } + + /// Optionally set file extension on [`ListingOptions`] and returns self. 
+ /// + /// If `file_extension` is `None`, the file extension will not be changed + /// + /// # Example + /// ``` + /// # use std::sync::Arc; + /// # use datafusion_catalog_listing::ListingOptions; + /// # use datafusion_datasource_parquet::file_format::ParquetFormat; + /// + /// let extension = Some(".parquet"); + /// let listing_options = ListingOptions::new(Arc::new(ParquetFormat::default())) + /// .with_file_extension_opt(extension); + /// + /// assert_eq!(listing_options.file_extension, ".parquet"); + /// ``` + pub fn with_file_extension_opt(mut self, file_extension: Option) -> Self + where + S: Into, + { + if let Some(file_extension) = file_extension { + self.file_extension = file_extension.into(); + } + self + } + + /// Set `table partition columns` on [`ListingOptions`] and returns self. + /// + /// "partition columns," used to support [Hive Partitioning], are + /// columns added to the data that is read, based on the folder + /// structure where the data resides. + /// + /// For example, give the following files in your filesystem: + /// + /// ```text + /// /mnt/nyctaxi/year=2022/month=01/tripdata.parquet + /// /mnt/nyctaxi/year=2021/month=12/tripdata.parquet + /// /mnt/nyctaxi/year=2021/month=11/tripdata.parquet + /// ``` + /// + /// A [`crate::ListingTable`] created at `/mnt/nyctaxi/` with partition + /// columns "year" and "month" will include new `year` and `month` + /// columns while reading the files. The `year` column would have + /// value `2022` and the `month` column would have value `01` for + /// the rows read from + /// `/mnt/nyctaxi/year=2022/month=01/tripdata.parquet` + /// + ///# Notes + /// + /// - If only one level (e.g. `year` in the example above) is + /// specified, the other levels are ignored but the files are + /// still read. + /// + /// - Files that don't follow this partitioning scheme will be + /// ignored. 
+ /// + /// - Since the columns have the same value for all rows read from + /// each individual file (such as dates), they are typically + /// dictionary encoded for efficiency. You may use + /// [`wrap_partition_type_in_dict`] to request a + /// dictionary-encoded type. + /// + /// - The partition columns are solely extracted from the file path. Especially they are NOT part of the parquet files itself. + /// + /// # Example + /// + /// ``` + /// # use std::sync::Arc; + /// # use arrow::datatypes::DataType; + /// # use datafusion_expr::col; + /// # use datafusion_catalog_listing::ListingOptions; + /// # use datafusion_datasource_parquet::file_format::ParquetFormat; + /// + /// // listing options for files with paths such as `/mnt/data/col_a=x/col_b=y/data.parquet` + /// // `col_a` and `col_b` will be included in the data read from those files + /// let listing_options = ListingOptions::new(Arc::new( + /// ParquetFormat::default() + /// )) + /// .with_table_partition_cols(vec![("col_a".to_string(), DataType::Utf8), + /// ("col_b".to_string(), DataType::Utf8)]); + /// + /// assert_eq!(listing_options.table_partition_cols, vec![("col_a".to_string(), DataType::Utf8), + /// ("col_b".to_string(), DataType::Utf8)]); + /// ``` + /// + /// [Hive Partitioning]: https://docs.cloudera.com/HDPDocuments/HDP2/HDP-2.1.3/bk_system-admin-guide/content/hive_partitioned_tables.html + /// [`wrap_partition_type_in_dict`]: datafusion_datasource::file_scan_config::wrap_partition_type_in_dict + pub fn with_table_partition_cols( + mut self, + table_partition_cols: Vec<(String, DataType)>, + ) -> Self { + self.table_partition_cols = table_partition_cols; + self + } + + /// Set stat collection on [`ListingOptions`] and returns self. 
+ /// + /// ``` + /// # use std::sync::Arc; + /// # use datafusion_catalog_listing::ListingOptions; + /// # use datafusion_datasource_parquet::file_format::ParquetFormat; + /// + /// let listing_options = + /// ListingOptions::new(Arc::new(ParquetFormat::default())).with_collect_stat(true); + /// + /// assert_eq!(listing_options.collect_stat, true); + /// ``` + pub fn with_collect_stat(mut self, collect_stat: bool) -> Self { + self.collect_stat = collect_stat; + self + } + + /// Set number of target partitions on [`ListingOptions`] and returns self. + /// + /// ``` + /// # use std::sync::Arc; + /// # use datafusion_catalog_listing::ListingOptions; + /// # use datafusion_datasource_parquet::file_format::ParquetFormat; + /// + /// let listing_options = + /// ListingOptions::new(Arc::new(ParquetFormat::default())).with_target_partitions(8); + /// + /// assert_eq!(listing_options.target_partitions, 8); + /// ``` + pub fn with_target_partitions(mut self, target_partitions: usize) -> Self { + self.target_partitions = target_partitions; + self + } + + /// Set file sort order on [`ListingOptions`] and returns self. + /// + /// ``` + /// # use std::sync::Arc; + /// # use datafusion_expr::col; + /// # use datafusion_catalog_listing::ListingOptions; + /// # use datafusion_datasource_parquet::file_format::ParquetFormat; + /// + /// // Tell datafusion that the files are sorted by column "a" + /// let file_sort_order = vec![vec![col("a").sort(true, true)]]; + /// + /// let listing_options = ListingOptions::new(Arc::new(ParquetFormat::default())) + /// .with_file_sort_order(file_sort_order.clone()); + /// + /// assert_eq!(listing_options.file_sort_order, file_sort_order); + /// ``` + pub fn with_file_sort_order(mut self, file_sort_order: Vec>) -> Self { + self.file_sort_order = file_sort_order; + self + } + + /// Infer the schema of the files at the given path on the provided object store. + /// + /// If the table_path contains one or more files (i.e. 
it is a directory / + /// prefix of files) their schema is merged by calling [`FileFormat::infer_schema`] + /// + /// Note: The inferred schema does not include any partitioning columns. + /// + /// This method is called as part of creating a [`crate::ListingTable`]. + pub async fn infer_schema<'a>( + &'a self, + state: &dyn Session, + table_path: &'a ListingTableUrl, + ) -> datafusion_common::Result { + let store = state.runtime_env().object_store(table_path)?; + + let files: Vec<_> = table_path + .list_all_files(state, store.as_ref(), &self.file_extension) + .await? + // Empty files cannot affect schema but may throw when trying to read for it + .try_filter(|object_meta| future::ready(object_meta.size > 0)) + .try_collect() + .await?; + + let schema = self.format.infer_schema(state, &store, &files).await?; + + Ok(schema) + } + + /// Infers the partition columns stored in `LOCATION` and compares + /// them with the columns provided in `PARTITIONED BY` to help prevent + /// accidental corrupts of partitioned tables. + /// + /// Allows specifying partial partitions. + pub async fn validate_partitions( + &self, + state: &dyn Session, + table_path: &ListingTableUrl, + ) -> datafusion_common::Result<()> { + if self.table_partition_cols.is_empty() { + return Ok(()); + } + + if !table_path.is_collection() { + return plan_err!( + "Can't create a partitioned table backed by a single file, \ + perhaps the URL is missing a trailing slash?" 
+ ); + } + + let inferred = self.infer_partitions(state, table_path).await?; + + // no partitioned files found on disk + if inferred.is_empty() { + return Ok(()); + } + + let table_partition_names = self + .table_partition_cols + .iter() + .map(|(col_name, _)| col_name.clone()) + .collect_vec(); + + if inferred.len() < table_partition_names.len() { + return plan_err!( + "Inferred partitions to be {:?}, but got {:?}", + inferred, + table_partition_names + ); + } + + // match prefix to allow creating tables with partial partitions + for (idx, col) in table_partition_names.iter().enumerate() { + if &inferred[idx] != col { + return plan_err!( + "Inferred partitions to be {:?}, but got {:?}", + inferred, + table_partition_names + ); + } + } + + Ok(()) + } + + /// Infer the partitioning at the given path on the provided object store. + /// For performance reasons, it doesn't read all the files on disk + /// and therefore may fail to detect invalid partitioning. + pub async fn infer_partitions( + &self, + state: &dyn Session, + table_path: &ListingTableUrl, + ) -> datafusion_common::Result> { + let store = state.runtime_env().object_store(table_path)?; + + // only use 10 files for inference + // This can fail to detect inconsistent partition keys + // A DFS traversal approach of the store can help here + let files: Vec<_> = table_path + .list_all_files(state, store.as_ref(), &self.file_extension) + .await? 
+ .take(10) + .try_collect() + .await?; + + let stripped_path_parts = files.iter().map(|file| { + table_path + .strip_prefix(&file.location) + .unwrap() + .collect_vec() + }); + + let partition_keys = stripped_path_parts + .map(|path_parts| { + path_parts + .into_iter() + .rev() + .skip(1) // get parents only; skip the file itself + .rev() + // Partitions are expected to follow the format "column_name=value", so we + // should ignore any path part that cannot be parsed into the expected format + .filter(|s| s.contains('=')) + .map(|s| s.split('=').take(1).collect()) + .collect_vec() + }) + .collect_vec(); + + match partition_keys.into_iter().all_equal_value() { + Ok(v) => Ok(v), + Err(None) => Ok(vec![]), + Err(Some(diff)) => { + let mut sorted_diff = [diff.0, diff.1]; + sorted_diff.sort(); + plan_err!("Found mixed partition values on disk {:?}", sorted_diff) + } + } + } +} diff --git a/datafusion/catalog-listing/src/table.rs b/datafusion/catalog-listing/src/table.rs new file mode 100644 index 0000000000000..a5de79b052a4e --- /dev/null +++ b/datafusion/catalog-listing/src/table.rs @@ -0,0 +1,1058 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::config::SchemaSource; +use crate::helpers::{expr_applicable_for_cols, pruned_partition_list}; +use crate::{ListingOptions, ListingTableConfig}; +use arrow::datatypes::{Field, Schema, SchemaBuilder, SchemaRef}; +use async_trait::async_trait; +use datafusion_catalog::{ScanArgs, ScanResult, Session, TableProvider}; +use datafusion_common::stats::Precision; +use datafusion_common::{ + Constraints, SchemaExt, Statistics, internal_datafusion_err, plan_err, project_schema, +}; +use datafusion_datasource::file::FileSource; +use datafusion_datasource::file_groups::FileGroup; +use datafusion_datasource::file_scan_config::{FileScanConfig, FileScanConfigBuilder}; +use datafusion_datasource::file_sink_config::{FileOutputMode, FileSinkConfig}; +#[expect(deprecated)] +use datafusion_datasource::schema_adapter::SchemaAdapterFactory; +use datafusion_datasource::{ + ListingTableUrl, PartitionedFile, TableSchema, compute_all_files_statistics, +}; +use datafusion_execution::cache::TableScopedPath; +use datafusion_execution::cache::cache_manager::FileStatisticsCache; +use datafusion_execution::cache::cache_unit::DefaultFileStatisticsCache; +use datafusion_expr::dml::InsertOp; +use datafusion_expr::execution_props::ExecutionProps; +use datafusion_expr::{Expr, TableProviderFilterPushDown, TableType}; +use datafusion_physical_expr::create_lex_ordering; +use datafusion_physical_expr_adapter::PhysicalExprAdapterFactory; +use datafusion_physical_expr_common::sort_expr::LexOrdering; +use datafusion_physical_plan::ExecutionPlan; +use datafusion_physical_plan::empty::EmptyExec; +use futures::{Stream, StreamExt, TryStreamExt, future, stream}; +use object_store::ObjectStore; +use std::any::Any; +use std::collections::HashMap; +use std::sync::Arc; + +/// Result of a file listing operation from [`ListingTable::list_files_for_scan`]. +#[derive(Debug)] +pub struct ListFilesResult { + /// File groups organized by the partitioning strategy. 
+ pub file_groups: Vec, + /// Aggregated statistics for all files. + pub statistics: Statistics, + /// Whether files are grouped by partition values (enables Hash partitioning). + pub grouped_by_partition: bool, +} + +/// Built in [`TableProvider`] that reads data from one or more files as a single table. +/// +/// The files are read using an [`ObjectStore`] instance, for example from +/// local files or objects from AWS S3. +/// +/// # Features: +/// * Reading multiple files as a single table +/// * Hive style partitioning (e.g., directories named `date=2024-06-01`) +/// * Merges schemas from files with compatible but not identical schemas (see [`ListingTableConfig::file_schema`]) +/// * `limit`, `filter` and `projection` pushdown for formats that support it (e.g., +/// Parquet) +/// * Statistics collection and pruning based on file metadata +/// * Pre-existing sort order (see [`ListingOptions::file_sort_order`]) +/// * Metadata caching to speed up repeated queries (see [`FileMetadataCache`]) +/// * Statistics caching (see [`FileStatisticsCache`]) +/// +/// [`FileMetadataCache`]: datafusion_execution::cache::cache_manager::FileMetadataCache +/// +/// # Reading Directories and Hive Style Partitioning +/// +/// For example, given the `table1` directory (or object store prefix) +/// +/// ```text +/// table1 +/// ├── file1.parquet +/// └── file2.parquet +/// ``` +/// +/// A `ListingTable` would read the files `file1.parquet` and `file2.parquet` as +/// a single table, merging the schemas if the files have compatible but not +/// identical schemas. +/// +/// Given the `table2` directory (or object store prefix) +/// +/// ```text +/// table2 +/// ├── date=2024-06-01 +/// │ ├── file3.parquet +/// │ └── file4.parquet +/// └── date=2024-06-02 +/// └── file5.parquet +/// ``` +/// +/// A `ListingTable` would read the files `file3.parquet`, `file4.parquet`, and +/// `file5.parquet` as a single table, again merging schemas if necessary. 
+/// +/// Given the hive style partitioning structure (e.g., directories named +/// `date=2024-06-01` and `date=2024-06-02`), `ListingTable` also adds a `date` +/// column when reading the table: +/// * The files in `table2/date=2024-06-01` will have the value `2024-06-01` +/// * The files in `table2/date=2024-06-02` will have the value `2024-06-02`. +/// +/// If the query has a predicate like `WHERE date = '2024-06-01'` +/// only the corresponding directory will be read. +/// +/// # See Also +/// +/// 1. [`ListingTableConfig`]: Configuration options +/// 1. [`DataSourceExec`]: `ExecutionPlan` used by `ListingTable` +/// +/// [`DataSourceExec`]: datafusion_datasource::source::DataSourceExec +/// +/// # Caching Metadata +/// +/// Some formats, such as Parquet, use the `FileMetadataCache` to cache file +/// metadata that is needed to execute but expensive to read, such as row +/// groups and statistics. The cache is scoped to the `SessionContext` and can +/// be configured via the [runtime config options].
+/// +/// [runtime config options]: https://datafusion.apache.org/user-guide/configs.html#runtime-configuration-settings +/// +/// # Example: Read a directory of parquet files using a [`ListingTable`] +/// +/// ```no_run +/// # use datafusion_common::Result; +/// # use std::sync::Arc; +/// # use datafusion_catalog::TableProvider; +/// # use datafusion_catalog_listing::{ListingOptions, ListingTable, ListingTableConfig}; +/// # use datafusion_datasource::ListingTableUrl; +/// # use datafusion_datasource_parquet::file_format::ParquetFormat;/// # +/// # use datafusion_catalog::Session; +/// async fn get_listing_table(session: &dyn Session) -> Result> { +/// let table_path = "/path/to/parquet"; +/// +/// // Parse the path +/// let table_path = ListingTableUrl::parse(table_path)?; +/// +/// // Create default parquet options +/// let file_format = ParquetFormat::new(); +/// let listing_options = ListingOptions::new(Arc::new(file_format)) +/// .with_file_extension(".parquet"); +/// +/// // Resolve the schema +/// let resolved_schema = listing_options +/// .infer_schema(session, &table_path) +/// .await?; +/// +/// let config = ListingTableConfig::new(table_path) +/// .with_listing_options(listing_options) +/// .with_schema(resolved_schema); +/// +/// // Create a new TableProvider +/// let provider = Arc::new(ListingTable::try_new(config)?); +/// +/// # Ok(provider) +/// # } +/// ``` +#[derive(Debug, Clone)] +pub struct ListingTable { + table_paths: Vec, + /// `file_schema` contains only the columns physically stored in the data files themselves. + /// - Represents the actual fields found in files like Parquet, CSV, etc. 
+ /// - Used when reading the raw data from files + file_schema: SchemaRef, + /// `table_schema` combines `file_schema` + partition columns + /// - Partition columns are derived from directory paths (not stored in files) + /// - These are columns like "year=2022/month=01" in paths like `/data/year=2022/month=01/file.parquet` + table_schema: SchemaRef, + /// Indicates how the schema was derived (inferred or explicitly specified) + schema_source: SchemaSource, + /// Options used to configure the listing table such as the file format + /// and partitioning information + options: ListingOptions, + /// The SQL definition for this table, if any + definition: Option, + /// Cache for collected file statistics + collected_statistics: Arc, + /// Constraints applied to this table + constraints: Constraints, + /// Column default expressions for columns that are not physically present in the data files + column_defaults: HashMap, + /// Optional [`PhysicalExprAdapterFactory`] for creating physical expression adapters + expr_adapter_factory: Option>, +} + +impl ListingTable { + /// Create new [`ListingTable`] + /// + /// See documentation and example on [`ListingTable`] and [`ListingTableConfig`] + pub fn try_new(config: ListingTableConfig) -> datafusion_common::Result { + // Extract schema_source before moving other parts of the config + let schema_source = config.schema_source(); + + let file_schema = config + .file_schema + .ok_or_else(|| internal_datafusion_err!("No schema provided."))?; + + let options = config + .options + .ok_or_else(|| internal_datafusion_err!("No ListingOptions provided"))?; + + // Add the partition columns to the file schema + let mut builder = SchemaBuilder::from(file_schema.as_ref().to_owned()); + for (part_col_name, part_col_type) in &options.table_partition_cols { + builder.push(Field::new(part_col_name, part_col_type.clone(), false)); + } + + let table_schema = Arc::new( + builder + .finish() + .with_metadata(file_schema.metadata().clone()), + ); + 
+ let table = Self { + table_paths: config.table_paths, + file_schema, + table_schema, + schema_source, + options, + definition: None, + collected_statistics: Arc::new(DefaultFileStatisticsCache::default()), + constraints: Constraints::default(), + column_defaults: HashMap::new(), + expr_adapter_factory: config.expr_adapter_factory, + }; + + Ok(table) + } + + /// Assign constraints + pub fn with_constraints(mut self, constraints: Constraints) -> Self { + self.constraints = constraints; + self + } + + /// Assign column defaults + pub fn with_column_defaults( + mut self, + column_defaults: HashMap, + ) -> Self { + self.column_defaults = column_defaults; + self + } + + /// Set the [`FileStatisticsCache`] used to cache parquet file statistics. + /// + /// Setting a statistics cache on the `SessionContext` can avoid refetching statistics + /// multiple times in the same session. + /// + /// If `None`, creates a new [`DefaultFileStatisticsCache`] scoped to this query. + pub fn with_cache(mut self, cache: Option>) -> Self { + self.collected_statistics = + cache.unwrap_or_else(|| Arc::new(DefaultFileStatisticsCache::default())); + self + } + + /// Specify the SQL definition for this table, if any + pub fn with_definition(mut self, definition: Option) -> Self { + self.definition = definition; + self + } + + /// Get paths ref + pub fn table_paths(&self) -> &Vec { + &self.table_paths + } + + /// Get options ref + pub fn options(&self) -> &ListingOptions { + &self.options + } + + /// Get the schema source + pub fn schema_source(&self) -> SchemaSource { + self.schema_source + } + + /// Deprecated: Set the [`SchemaAdapterFactory`] for this [`ListingTable`] + /// + /// `SchemaAdapterFactory` has been removed. Use [`ListingTableConfig::with_expr_adapter_factory`] + /// and `PhysicalExprAdapterFactory` instead. See `upgrading.md` for more details. + /// + /// This method is a no-op and returns `self` unchanged. 
+ #[deprecated( + since = "52.0.0", + note = "SchemaAdapterFactory has been removed. Use ListingTableConfig::with_expr_adapter_factory and PhysicalExprAdapterFactory instead. See upgrading.md for more details." + )] + #[expect(deprecated)] + pub fn with_schema_adapter_factory( + self, + _schema_adapter_factory: Arc, + ) -> Self { + // No-op - just return self unchanged + self + } + + /// Deprecated: Returns the [`SchemaAdapterFactory`] used by this [`ListingTable`]. + /// + /// `SchemaAdapterFactory` has been removed. Use `PhysicalExprAdapterFactory` instead. + /// See `upgrading.md` for more details. + /// + /// Always returns `None`. + #[deprecated( + since = "52.0.0", + note = "SchemaAdapterFactory has been removed. Use PhysicalExprAdapterFactory instead. See upgrading.md for more details." + )] + #[expect(deprecated)] + pub fn schema_adapter_factory(&self) -> Option> { + None + } + + /// Creates a file source for this table + fn create_file_source(&self) -> Arc { + let table_schema = TableSchema::new( + Arc::clone(&self.file_schema), + self.options + .table_partition_cols + .iter() + .map(|(col, field)| Arc::new(Field::new(col, field.clone(), false))) + .collect(), + ); + + self.options.format.file_source(table_schema) + } + + /// Creates output ordering from user-specified file_sort_order or derives + /// from file orderings when user doesn't specify. + /// + /// If user specified `file_sort_order`, that takes precedence. + /// Otherwise, attempts to derive common ordering from file orderings in + /// the provided file groups. 
+ pub fn try_create_output_ordering( + &self, + execution_props: &ExecutionProps, + file_groups: &[FileGroup], + ) -> datafusion_common::Result> { + // If user specified sort order, use that + if !self.options.file_sort_order.is_empty() { + return create_lex_ordering( + &self.table_schema, + &self.options.file_sort_order, + execution_props, + ); + } + if let Some(ordering) = derive_common_ordering_from_files(file_groups) { + return Ok(vec![ordering]); + } + Ok(vec![]) + } +} + +/// Derives a common ordering from file orderings across all file groups. +/// +/// Returns the common ordering if all files have compatible orderings, +/// otherwise returns None. +/// +/// The function finds the longest common prefix among all file orderings. +/// For example, if files have orderings `[a, b, c]` and `[a, b]`, the common +/// ordering is `[a, b]`. +fn derive_common_ordering_from_files(file_groups: &[FileGroup]) -> Option { + enum CurrentOrderingState { + /// Initial state before processing any files + FirstFile, + /// Some common ordering found so far + SomeOrdering(LexOrdering), + /// No files have ordering + NoOrdering, + } + let mut state = CurrentOrderingState::FirstFile; + + // Collect file orderings and track counts + for group in file_groups { + for file in group.iter() { + state = match (&state, &file.ordering) { + // If this is the first file with ordering, set it as current + (CurrentOrderingState::FirstFile, Some(ordering)) => { + CurrentOrderingState::SomeOrdering(ordering.clone()) + } + (CurrentOrderingState::FirstFile, None) => { + CurrentOrderingState::NoOrdering + } + // If we have an existing ordering, find common prefix with new ordering + (CurrentOrderingState::SomeOrdering(current), Some(ordering)) => { + // Find common prefix between current and new ordering + let prefix_len = current + .as_ref() + .iter() + .zip(ordering.as_ref().iter()) + .take_while(|(a, b)| a == b) + .count(); + if prefix_len == 0 { + log::trace!( + "Cannot derive common ordering: 
no common prefix between orderings {current:?} and {ordering:?}" + ); + return None; + } else { + let ordering = + LexOrdering::new(current.as_ref()[..prefix_len].to_vec()) + .expect("prefix_len > 0, so ordering must be valid"); + CurrentOrderingState::SomeOrdering(ordering) + } + } + // If one file has ordering and another doesn't, no common ordering + // Return None and log a trace message explaining why + (CurrentOrderingState::SomeOrdering(ordering), None) + | (CurrentOrderingState::NoOrdering, Some(ordering)) => { + log::trace!( + "Cannot derive common ordering: some files have ordering {ordering:?}, others don't" + ); + return None; + } + // Both have no ordering, remain in NoOrdering state + (CurrentOrderingState::NoOrdering, None) => { + CurrentOrderingState::NoOrdering + } + }; + } + } + + match state { + CurrentOrderingState::SomeOrdering(ordering) => Some(ordering), + _ => None, + } +} + +// Expressions can be used for partition pruning if they can be evaluated using +// only the partition columns and there are partition columns. 
+fn can_be_evaluated_for_partition_pruning( + partition_column_names: &[&str], + expr: &Expr, +) -> bool { + !partition_column_names.is_empty() + && expr_applicable_for_cols(partition_column_names, expr) +} + +#[async_trait] +impl TableProvider for ListingTable { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + Arc::clone(&self.table_schema) + } + + fn constraints(&self) -> Option<&Constraints> { + Some(&self.constraints) + } + + fn table_type(&self) -> TableType { + TableType::Base + } + + async fn scan( + &self, + state: &dyn Session, + projection: Option<&Vec>, + filters: &[Expr], + limit: Option, + ) -> datafusion_common::Result> { + let options = ScanArgs::default() + .with_projection(projection.map(|p| p.as_slice())) + .with_filters(Some(filters)) + .with_limit(limit); + Ok(self.scan_with_args(state, options).await?.into_inner()) + } + + async fn scan_with_args<'a>( + &self, + state: &dyn Session, + args: ScanArgs<'a>, + ) -> datafusion_common::Result { + let projection = args.projection().map(|p| p.to_vec()); + let filters = args.filters().map(|f| f.to_vec()).unwrap_or_default(); + let limit = args.limit(); + + // extract types of partition columns + let table_partition_cols = self + .options + .table_partition_cols + .iter() + .map(|col| Ok(Arc::new(self.table_schema.field_with_name(&col.0)?.clone()))) + .collect::>>()?; + + let table_partition_col_names = table_partition_cols + .iter() + .map(|field| field.name().as_str()) + .collect::>(); + + // If the filters can be resolved using only partition cols, there is no need to + // pushdown it to TableScan, otherwise, `unhandled` pruning predicates will be generated + let (partition_filters, filters): (Vec<_>, Vec<_>) = + filters.iter().cloned().partition(|filter| { + can_be_evaluated_for_partition_pruning(&table_partition_col_names, filter) + }); + + // We should not limit the number of partitioned files to scan if there are filters and limit + // at the same time. 
This is because the limit should be applied after the filters are applied. + let statistic_file_limit = if filters.is_empty() { limit } else { None }; + + let ListFilesResult { + file_groups: mut partitioned_file_lists, + statistics, + grouped_by_partition: partitioned_by_file_group, + } = self + .list_files_for_scan(state, &partition_filters, statistic_file_limit) + .await?; + + // if no files need to be read, return an `EmptyExec` + if partitioned_file_lists.is_empty() { + let projected_schema = project_schema(&self.schema(), projection.as_ref())?; + return Ok(ScanResult::new(Arc::new(EmptyExec::new(projected_schema)))); + } + + let output_ordering = self.try_create_output_ordering( + state.execution_props(), + &partitioned_file_lists, + )?; + match state + .config_options() + .execution + .split_file_groups_by_statistics + .then(|| { + output_ordering.first().map(|output_ordering| { + FileScanConfig::split_groups_by_statistics_with_target_partitions( + &self.table_schema, + &partitioned_file_lists, + output_ordering, + self.options.target_partitions, + ) + }) + }) + .flatten() + { + Some(Err(e)) => log::debug!("failed to split file groups by statistics: {e}"), + Some(Ok(new_groups)) => { + if new_groups.len() <= self.options.target_partitions { + partitioned_file_lists = new_groups; + } else { + log::debug!( + "attempted to split file groups by statistics, but there were more file groups than target_partitions; falling back to unordered" + ) + } + } + None => {} // no ordering required + }; + + let Some(object_store_url) = + self.table_paths.first().map(ListingTableUrl::object_store) + else { + return Ok(ScanResult::new(Arc::new(EmptyExec::new(Arc::new( + Schema::empty(), + ))))); + }; + + let file_source = self.create_file_source(); + + // create the execution plan + let plan = self + .options + .format + .create_physical_plan( + state, + FileScanConfigBuilder::new(object_store_url, file_source) + .with_file_groups(partitioned_file_lists) + 
.with_constraints(self.constraints.clone()) + .with_statistics(statistics) + .with_projection_indices(projection)? + .with_limit(limit) + .with_output_ordering(output_ordering) + .with_expr_adapter(self.expr_adapter_factory.clone()) + .with_partitioned_by_file_group(partitioned_by_file_group) + .build(), + ) + .await?; + + Ok(ScanResult::new(plan)) + } + + fn supports_filters_pushdown( + &self, + filters: &[&Expr], + ) -> datafusion_common::Result> { + let partition_column_names = self + .options + .table_partition_cols + .iter() + .map(|col| col.0.as_str()) + .collect::>(); + filters + .iter() + .map(|filter| { + if can_be_evaluated_for_partition_pruning(&partition_column_names, filter) + { + // if filter can be handled by partition pruning, it is exact + return Ok(TableProviderFilterPushDown::Exact); + } + + Ok(TableProviderFilterPushDown::Inexact) + }) + .collect() + } + + fn get_table_definition(&self) -> Option<&str> { + self.definition.as_deref() + } + + async fn insert_into( + &self, + state: &dyn Session, + input: Arc, + insert_op: InsertOp, + ) -> datafusion_common::Result> { + // Check that the schema of the plan matches the schema of this table. + self.schema() + .logically_equivalent_names_and_types(&input.schema())?; + + let table_path = &self.table_paths()[0]; + if !table_path.is_collection() { + return plan_err!( + "Inserting into a ListingTable backed by a single file is not supported, URL is possibly missing a trailing `/`. \ + To append to an existing file use StreamTable, e.g. by using CREATE UNBOUNDED EXTERNAL TABLE" + ); + } + + // Get the object store for the table path. 
+ let store = state.runtime_env().object_store(table_path)?; + + let file_list_stream = pruned_partition_list( + state, + store.as_ref(), + table_path, + &[], + &self.options.file_extension, + &self.options.table_partition_cols, + ) + .await?; + + let file_group = file_list_stream.try_collect::>().await?.into(); + let keep_partition_by_columns = + state.config_options().execution.keep_partition_by_columns; + + // Invalidate cache entries for this table if they exist + if let Some(lfc) = state.runtime_env().cache_manager.get_list_files_cache() { + let key = TableScopedPath { + table: table_path.get_table_ref().clone(), + path: table_path.prefix().clone(), + }; + let _ = lfc.remove(&key); + } + + // Sink related option, apart from format + let config = FileSinkConfig { + original_url: String::default(), + object_store_url: self.table_paths()[0].object_store(), + table_paths: self.table_paths().clone(), + file_group, + output_schema: self.schema(), + table_partition_cols: self.options.table_partition_cols.clone(), + insert_op, + keep_partition_by_columns, + file_extension: self.options().format.get_ext(), + file_output_mode: FileOutputMode::Automatic, + }; + + // For writes, we only use user-specified ordering (no file groups to derive from) + let orderings = self.try_create_output_ordering(state.execution_props(), &[])?; + // It is sufficient to pass only one of the equivalent orderings: + let order_requirements = orderings.into_iter().next().map(Into::into); + + self.options() + .format + .create_writer_physical_plan(input, state, config, order_requirements) + .await + } + + fn get_column_default(&self, column: &str) -> Option<&Expr> { + self.column_defaults.get(column) + } +} + +impl ListingTable { + /// Get the list of files for a scan as well as the file level statistics. + /// The list is grouped to let the execution plan know how the files should + /// be distributed to different threads / executors. 
+ pub async fn list_files_for_scan<'a>( + &'a self, + ctx: &'a dyn Session, + filters: &'a [Expr], + limit: Option, + ) -> datafusion_common::Result { + let store = if let Some(url) = self.table_paths.first() { + ctx.runtime_env().object_store(url)? + } else { + return Ok(ListFilesResult { + file_groups: vec![], + statistics: Statistics::new_unknown(&self.file_schema), + grouped_by_partition: false, + }); + }; + // list files (with partitions) + let file_list = future::try_join_all(self.table_paths.iter().map(|table_path| { + pruned_partition_list( + ctx, + store.as_ref(), + table_path, + filters, + &self.options.file_extension, + &self.options.table_partition_cols, + ) + })) + .await?; + let meta_fetch_concurrency = + ctx.config_options().execution.meta_fetch_concurrency; + let file_list = stream::iter(file_list).flatten_unordered(meta_fetch_concurrency); + // collect the statistics and ordering if required by the config + let files = file_list + .map(|part_file| async { + let part_file = part_file?; + let (statistics, ordering) = if self.options.collect_stat { + self.do_collect_statistics_and_ordering(ctx, &store, &part_file) + .await? + } else { + (Arc::new(Statistics::new_unknown(&self.file_schema)), None) + }; + Ok(part_file + .with_statistics(statistics) + .with_ordering(ordering)) + }) + .boxed() + .buffer_unordered(ctx.config_options().execution.meta_fetch_concurrency); + + let (file_group, inexact_stats) = + get_files_with_limit(files, limit, self.options.collect_stat).await?; + + // Threshold: 0 = disabled, N > 0 = enabled when distinct_keys >= N + // + // When enabled, files are grouped by their Hive partition column values, allowing + // FileScanConfig to declare Hash partitioning. This enables the optimizer to skip + // hash repartitioning for aggregates and joins on partition columns. 
+ let threshold = ctx.config_options().optimizer.preserve_file_partitions; + + let (file_groups, grouped_by_partition) = if threshold > 0 + && !self.options.table_partition_cols.is_empty() + { + let grouped = + file_group.group_by_partition_values(self.options.target_partitions); + if grouped.len() >= threshold { + (grouped, true) + } else { + let all_files: Vec<_> = + grouped.into_iter().flat_map(|g| g.into_inner()).collect(); + ( + FileGroup::new(all_files).split_files(self.options.target_partitions), + false, + ) + } + } else { + ( + file_group.split_files(self.options.target_partitions), + false, + ) + }; + + let (file_groups, stats) = compute_all_files_statistics( + file_groups, + self.schema(), + self.options.collect_stat, + inexact_stats, + )?; + + // Note: Statistics already include both file columns and partition columns. + // PartitionedFile::with_statistics automatically appends exact partition column + // statistics (min=max=partition_value, null_count=0, distinct_count=1) computed + // from partition_values. + Ok(ListFilesResult { + file_groups, + statistics: stats, + grouped_by_partition, + }) + } + + /// Collects statistics and ordering for a given partitioned file. + /// + /// This method checks if statistics are cached. If cached, it returns the + /// cached statistics and infers ordering separately. If not cached, it infers + /// both statistics and ordering in a single metadata read for efficiency. 
+ async fn do_collect_statistics_and_ordering( + &self, + ctx: &dyn Session, + store: &Arc, + part_file: &PartitionedFile, + ) -> datafusion_common::Result<(Arc, Option)> { + use datafusion_execution::cache::cache_manager::CachedFileMetadata; + + let path = &part_file.object_meta.location; + let meta = &part_file.object_meta; + + // Check cache first - if we have valid cached statistics and ordering + if let Some(cached) = self.collected_statistics.get(path) + && cached.is_valid_for(meta) + { + // Return cached statistics and ordering + return Ok((Arc::clone(&cached.statistics), cached.ordering.clone())); + } + + // Cache miss or invalid: fetch both statistics and ordering in a single metadata read + let file_meta = self + .options + .format + .infer_stats_and_ordering(ctx, store, Arc::clone(&self.file_schema), meta) + .await?; + + let statistics = Arc::new(file_meta.statistics); + + // Store in cache + self.collected_statistics.put( + path, + CachedFileMetadata::new( + meta.clone(), + Arc::clone(&statistics), + file_meta.ordering.clone(), + ), + ); + + Ok((statistics, file_meta.ordering)) + } +} + +/// Processes a stream of partitioned files and returns a `FileGroup` containing the files. +/// +/// This function collects files from the provided stream until either: +/// 1. The stream is exhausted +/// 2. The accumulated number of rows exceeds the provided `limit` (if specified) +/// +/// # Arguments +/// * `files` - A stream of `Result` items to process +/// * `limit` - An optional row count limit. If provided, the function will stop collecting files +/// once the accumulated number of rows exceeds this limit +/// * `collect_stats` - Whether to collect and accumulate statistics from the files +/// +/// # Returns +/// A `Result` containing a `FileGroup` with the collected files +/// and a boolean indicating whether the statistics are inexact. 
+/// +/// # Note +/// The function will continue processing files if statistics are not available or if the +/// limit is not provided. If `collect_stats` is false, statistics won't be accumulated +/// but files will still be collected. +async fn get_files_with_limit( + files: impl Stream>, + limit: Option, + collect_stats: bool, +) -> datafusion_common::Result<(FileGroup, bool)> { + let mut file_group = FileGroup::default(); + // Fusing the stream allows us to call next safely even once it is finished. + let mut all_files = Box::pin(files.fuse()); + enum ProcessingState { + ReadingFiles, + ReachedLimit, + } + + let mut state = ProcessingState::ReadingFiles; + let mut num_rows = Precision::Absent; + + while let Some(file_result) = all_files.next().await { + // Early exit if we've already reached our limit + if matches!(state, ProcessingState::ReachedLimit) { + break; + } + + let file = file_result?; + + // Update file statistics regardless of state + if collect_stats && let Some(file_stats) = &file.statistics { + num_rows = if file_group.is_empty() { + // For the first file, just take its row count + file_stats.num_rows + } else { + // For subsequent files, accumulate the counts + num_rows.add(&file_stats.num_rows) + }; + } + + // Always add the file to our group + file_group.push(file); + + // Check if we've hit the limit (if one was specified) + if let Some(limit) = limit + && let Precision::Exact(row_count) = num_rows + && row_count > limit + { + state = ProcessingState::ReachedLimit; + } + } + // If we still have files in the stream, it means that the limit kicked + // in, and the statistic could have been different had we processed the + // files in a different order. 
+ let inexact_stats = all_files.next().await.is_some(); + Ok((file_group, inexact_stats)) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::compute::SortOptions; + use datafusion_physical_expr::expressions::Column; + use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; + use std::sync::Arc; + + /// Helper to create a PhysicalSortExpr + fn sort_expr( + name: &str, + idx: usize, + descending: bool, + nulls_first: bool, + ) -> PhysicalSortExpr { + PhysicalSortExpr::new( + Arc::new(Column::new(name, idx)), + SortOptions { + descending, + nulls_first, + }, + ) + } + + /// Helper to create a LexOrdering (unwraps the Option) + fn lex_ordering(exprs: Vec) -> LexOrdering { + LexOrdering::new(exprs).expect("expected non-empty ordering") + } + + /// Helper to create a PartitionedFile with optional ordering + fn create_file(name: &str, ordering: Option) -> PartitionedFile { + PartitionedFile::new(name.to_string(), 1024).with_ordering(ordering) + } + + #[test] + fn test_derive_common_ordering_all_files_same_ordering() { + // All files have the same ordering -> returns that ordering + let ordering = lex_ordering(vec![ + sort_expr("a", 0, false, true), + sort_expr("b", 1, true, false), + ]); + + let file_groups = vec![ + FileGroup::new(vec![ + create_file("f1.parquet", Some(ordering.clone())), + create_file("f2.parquet", Some(ordering.clone())), + ]), + FileGroup::new(vec![create_file("f3.parquet", Some(ordering.clone()))]), + ]; + + let result = derive_common_ordering_from_files(&file_groups); + assert_eq!(result, Some(ordering)); + } + + #[test] + fn test_derive_common_ordering_common_prefix() { + // Files have different orderings but share a common prefix + let ordering_abc = lex_ordering(vec![ + sort_expr("a", 0, false, true), + sort_expr("b", 1, false, true), + sort_expr("c", 2, false, true), + ]); + let ordering_ab = lex_ordering(vec![ + sort_expr("a", 0, false, true), + sort_expr("b", 1, false, true), + ]); + + let file_groups = 
vec![FileGroup::new(vec![ + create_file("f1.parquet", Some(ordering_abc)), + create_file("f2.parquet", Some(ordering_ab.clone())), + ])]; + + let result = derive_common_ordering_from_files(&file_groups); + assert_eq!(result, Some(ordering_ab)); + } + + #[test] + fn test_derive_common_ordering_no_common_prefix() { + // Files have completely different orderings -> returns None + let ordering_a = lex_ordering(vec![sort_expr("a", 0, false, true)]); + let ordering_b = lex_ordering(vec![sort_expr("b", 1, false, true)]); + + let file_groups = vec![FileGroup::new(vec![ + create_file("f1.parquet", Some(ordering_a)), + create_file("f2.parquet", Some(ordering_b)), + ])]; + + let result = derive_common_ordering_from_files(&file_groups); + assert_eq!(result, None); + } + + #[test] + fn test_derive_common_ordering_mixed_with_none() { + // Some files have ordering, some don't -> returns None + let ordering = lex_ordering(vec![sort_expr("a", 0, false, true)]); + + let file_groups = vec![FileGroup::new(vec![ + create_file("f1.parquet", Some(ordering)), + create_file("f2.parquet", None), + ])]; + + let result = derive_common_ordering_from_files(&file_groups); + assert_eq!(result, None); + } + + #[test] + fn test_derive_common_ordering_all_none() { + // No files have ordering -> returns None + let file_groups = vec![FileGroup::new(vec![ + create_file("f1.parquet", None), + create_file("f2.parquet", None), + ])]; + + let result = derive_common_ordering_from_files(&file_groups); + assert_eq!(result, None); + } + + #[test] + fn test_derive_common_ordering_empty_groups() { + // Empty file groups -> returns None + let file_groups: Vec = vec![]; + let result = derive_common_ordering_from_files(&file_groups); + assert_eq!(result, None); + } + + #[test] + fn test_derive_common_ordering_single_file() { + // Single file with ordering -> returns that ordering + let ordering = lex_ordering(vec![ + sort_expr("a", 0, false, true), + sort_expr("b", 1, true, false), + ]); + + let file_groups = 
vec![FileGroup::new(vec![create_file( + "f1.parquet", + Some(ordering.clone()), + )])]; + + let result = derive_common_ordering_from_files(&file_groups); + assert_eq!(result, Some(ordering)); + } +} diff --git a/datafusion/catalog/Cargo.toml b/datafusion/catalog/Cargo.toml index 7307c4de87a8a..1009e9aee477b 100644 --- a/datafusion/catalog/Cargo.toml +++ b/datafusion/catalog/Cargo.toml @@ -18,11 +18,11 @@ [package] name = "datafusion-catalog" description = "datafusion-catalog" +readme = "README.md" authors.workspace = true edition.workspace = true homepage.workspace = true license.workspace = true -readme.workspace = true repository.workspace = true rust-version.workspace = true version.workspace = true @@ -42,7 +42,6 @@ datafusion-expr = { workspace = true } datafusion-physical-expr = { workspace = true } datafusion-physical-plan = { workspace = true } datafusion-session = { workspace = true } -datafusion-sql = { workspace = true } futures = { workspace = true } itertools = { workspace = true } log = { workspace = true } @@ -50,5 +49,8 @@ object_store = { workspace = true } parking_lot = { workspace = true } tokio = { workspace = true } +# Note: add additional linter rules in lib.rs. +# Rust does not support workspace + new linter rules in subcrates yet +# https://github.com/rust-lang/cargo/issues/13157 [lints] workspace = true diff --git a/datafusion/catalog/README.md b/datafusion/catalog/README.md index 5b201e736fdc4..48c61b43c025b 100644 --- a/datafusion/catalog/README.md +++ b/datafusion/catalog/README.md @@ -17,10 +17,16 @@ under the License. --> -# DataFusion Catalog +# Apache DataFusion Catalog -[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. +[Apache DataFusion] is an extensible query execution framework, written in Rust, that uses [Apache Arrow] as its in-memory format. 
This crate is a submodule of DataFusion that provides catalog management functionality, including catalogs, schemas, and tables. -[df]: https://crates.io/crates/datafusion +Most projects should use the [`datafusion`] crate directly, which re-exports +this module. If you are already using the [`datafusion`] crate, there is no +reason to use this crate directly in your project as well. + +[apache arrow]: https://arrow.apache.org/ +[apache datafusion]: https://datafusion.apache.org/ +[`datafusion`]: https://crates.io/crates/datafusion diff --git a/datafusion/catalog/src/async.rs b/datafusion/catalog/src/async.rs index 5d7a51ad71232..1b8039d828fdb 100644 --- a/datafusion/catalog/src/async.rs +++ b/datafusion/catalog/src/async.rs @@ -18,7 +18,7 @@ use std::sync::Arc; use async_trait::async_trait; -use datafusion_common::{error::Result, not_impl_err, HashMap, TableReference}; +use datafusion_common::{HashMap, TableReference, error::Result, not_impl_err}; use datafusion_execution::config::SessionConfig; use crate::{CatalogProvider, CatalogProviderList, SchemaProvider, TableProvider}; @@ -60,7 +60,9 @@ impl SchemaProvider for ResolvedSchemaProvider { } fn deregister_table(&self, name: &str) -> Result>> { - not_impl_err!("Attempt to deregister table '{name}' with ResolvedSchemaProvider which is not supported") + not_impl_err!( + "Attempt to deregister table '{name}' with ResolvedSchemaProvider which is not supported" + ) } fn table_exist(&self, name: &str) -> bool { @@ -193,7 +195,7 @@ impl CatalogProviderList for ResolvedCatalogProviderList { /// /// See the [remote_catalog.rs] for an end to end example /// -/// [remote_catalog.rs]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/remote_catalog.rs +/// [remote_catalog.rs]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/data_io/remote_catalog.rs #[async_trait] pub trait AsyncSchemaProvider: Send + Sync { /// Lookup a table in the schema provider @@ -425,14 +427,14 @@ 
mod tests { use std::{ any::Any, sync::{ - atomic::{AtomicU32, Ordering}, Arc, + atomic::{AtomicU32, Ordering}, }, }; use arrow::datatypes::SchemaRef; use async_trait::async_trait; - use datafusion_common::{error::Result, Statistics, TableReference}; + use datafusion_common::{Statistics, TableReference, error::Result}; use datafusion_execution::config::SessionConfig; use datafusion_expr::{Expr, TableType}; use datafusion_physical_plan::ExecutionPlan; @@ -737,7 +739,7 @@ mod tests { ] { let async_provider = MockAsyncCatalogProviderList::default(); let cached_provider = async_provider - .resolve(&[table_ref.clone()], &test_config()) + .resolve(std::slice::from_ref(table_ref), &test_config()) .await .unwrap(); diff --git a/datafusion/catalog/src/catalog.rs b/datafusion/catalog/src/catalog.rs index 71b9eccf9d657..bb9e89eba2fef 100644 --- a/datafusion/catalog/src/catalog.rs +++ b/datafusion/catalog/src/catalog.rs @@ -20,8 +20,8 @@ use std::fmt::Debug; use std::sync::Arc; pub use crate::schema::SchemaProvider; -use datafusion_common::not_impl_err; use datafusion_common::Result; +use datafusion_common::not_impl_err; /// Represents a catalog, comprising a number of named schemas. /// @@ -61,7 +61,7 @@ use datafusion_common::Result; /// schemas and tables exist. /// /// [Delta Lake]: https://delta.io/ -/// [`remote_catalog`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/remote_catalog.rs +/// [`remote_catalog`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/data_io/remote_catalog.rs /// /// The [`CatalogProvider`] can support this use case, but it takes some care. 
/// The planning APIs in DataFusion are not `async` and thus network IO can not @@ -100,7 +100,7 @@ use datafusion_common::Result; /// /// [`datafusion-cli`]: https://datafusion.apache.org/user-guide/cli/index.html /// [`DynamicFileCatalogProvider`]: https://github.com/apache/datafusion/blob/31b9b48b08592b7d293f46e75707aad7dadd7cbc/datafusion-cli/src/catalog.rs#L75 -/// [`catalog.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/catalog.rs +/// [`catalog.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/data_io/catalog.rs /// [delta-rs]: https://github.com/delta-io/delta-rs /// [`UnityCatalogProvider`]: https://github.com/delta-io/delta-rs/blob/951436ecec476ce65b5ed3b58b50fb0846ca7b91/crates/deltalake-core/src/data_catalog/unity/datafusion.rs#L111-L123 /// diff --git a/datafusion/catalog/src/cte_worktable.rs b/datafusion/catalog/src/cte_worktable.rs index d72a30909c02c..9565dcc60141e 100644 --- a/datafusion/catalog/src/cte_worktable.rs +++ b/datafusion/catalog/src/cte_worktable.rs @@ -17,20 +17,18 @@ //! CteWorkTable implementation used for recursive queries +use std::any::Any; +use std::borrow::Cow; use std::sync::Arc; -use std::{any::Any, borrow::Cow}; -use crate::Session; use arrow::datatypes::SchemaRef; use async_trait::async_trait; -use datafusion_physical_plan::work_table::WorkTableExec; - -use datafusion_physical_plan::ExecutionPlan; - use datafusion_common::error::Result; use datafusion_expr::{Expr, LogicalPlan, TableProviderFilterPushDown, TableType}; +use datafusion_physical_plan::ExecutionPlan; +use datafusion_physical_plan::work_table::WorkTableExec; -use crate::TableProvider; +use crate::{ScanArgs, ScanResult, Session, TableProvider}; /// The temporary working table where the previous iteration of a recursive query is stored /// Naming is based on PostgreSQL's implementation. 
@@ -71,7 +69,7 @@ impl TableProvider for CteWorkTable { self } - fn get_logical_plan(&self) -> Option> { + fn get_logical_plan(&'_ self) -> Option> { None } @@ -85,16 +83,28 @@ impl TableProvider for CteWorkTable { async fn scan( &self, - _state: &dyn Session, - _projection: Option<&Vec>, - _filters: &[Expr], - _limit: Option, + state: &dyn Session, + projection: Option<&Vec>, + filters: &[Expr], + limit: Option, ) -> Result> { - // TODO: pushdown filters and limits - Ok(Arc::new(WorkTableExec::new( + let options = ScanArgs::default() + .with_projection(projection.map(|p| p.as_slice())) + .with_filters(Some(filters)) + .with_limit(limit); + Ok(self.scan_with_args(state, options).await?.into_inner()) + } + + async fn scan_with_args<'a>( + &self, + _state: &dyn Session, + args: ScanArgs<'a>, + ) -> Result { + Ok(ScanResult::new(Arc::new(WorkTableExec::new( self.name.clone(), Arc::clone(&self.table_schema), - ))) + args.projection().map(|p| p.to_vec()), + )?))) } fn supports_filters_pushdown( diff --git a/datafusion/catalog/src/default_table_source.rs b/datafusion/catalog/src/default_table_source.rs index 9db8242caa999..fb6531ba0b2ee 100644 --- a/datafusion/catalog/src/default_table_source.rs +++ b/datafusion/catalog/src/default_table_source.rs @@ -23,7 +23,7 @@ use std::{any::Any, borrow::Cow}; use crate::TableProvider; use arrow::datatypes::SchemaRef; -use datafusion_common::{internal_err, Constraints}; +use datafusion_common::{Constraints, internal_err}; use datafusion_expr::{Expr, TableProviderFilterPushDown, TableSource, TableType}; /// Implements [`TableSource`] for a [`TableProvider`] @@ -33,8 +33,6 @@ use datafusion_expr::{Expr, TableProviderFilterPushDown, TableSource, TableType} /// /// It is used so logical plans in the `datafusion_expr` crate do not have a /// direct dependency on physical plans, such as [`TableProvider`]s. 
-/// -/// [`TableProvider`]: https://docs.rs/datafusion/latest/datafusion/datasource/provider/trait.TableProvider.html pub struct DefaultTableSource { /// table provider pub table_provider: Arc, @@ -78,7 +76,7 @@ impl TableSource for DefaultTableSource { self.table_provider.supports_filters_pushdown(filter) } - fn get_logical_plan(&self) -> Option> { + fn get_logical_plan(&'_ self) -> Option> { self.table_provider.get_logical_plan() } diff --git a/datafusion/catalog/src/information_schema.rs b/datafusion/catalog/src/information_schema.rs index 057d1a8198820..ea93dc21a3f5b 100644 --- a/datafusion/catalog/src/information_schema.rs +++ b/datafusion/catalog/src/information_schema.rs @@ -24,20 +24,24 @@ use crate::{CatalogProviderList, SchemaProvider, TableProvider}; use arrow::array::builder::{BooleanBuilder, UInt8Builder}; use arrow::{ array::{StringBuilder, UInt64Builder}, - datatypes::{DataType, Field, Schema, SchemaRef}, + datatypes::{DataType, Field, FieldRef, Schema, SchemaRef}, record_batch::RecordBatch, }; use async_trait::async_trait; +use datafusion_common::DataFusionError; use datafusion_common::config::{ConfigEntry, ConfigOptions}; use datafusion_common::error::Result; use datafusion_common::types::NativeType; -use datafusion_common::DataFusionError; use datafusion_execution::TaskContext; -use datafusion_expr::{AggregateUDF, ScalarUDF, Signature, TypeSignature, WindowUDF}; +use datafusion_execution::runtime_env::RuntimeEnv; +use datafusion_expr::function::WindowUDFFieldArgs; +use datafusion_expr::{ + AggregateUDF, ReturnFieldArgs, ScalarUDF, Signature, TypeSignature, WindowUDF, +}; use datafusion_expr::{TableType, Volatility}; +use datafusion_physical_plan::SendableRecordBatchStream; use datafusion_physical_plan::stream::RecordBatchStreamAdapter; use datafusion_physical_plan::streaming::PartitionStream; -use datafusion_physical_plan::SendableRecordBatchStream; use std::collections::{BTreeSet, HashMap, HashSet}; use std::fmt::Debug; use std::{any::Any, 
sync::Arc}; @@ -103,12 +107,14 @@ impl InformationSchemaConfig { // schema name may not exist in the catalog, so we need to check if let Some(schema) = catalog.schema(&schema_name) { for table_name in schema.table_names() { - if let Some(table) = schema.table(&table_name).await? { + if let Some(table_type) = + schema.table_type(&table_name).await? + { builder.add_table( &catalog_name, &schema_name, &table_name, - table.table_type(), + table_type, ); } } @@ -135,11 +141,11 @@ impl InformationSchemaConfig { let catalog = self.catalog_list.catalog(&catalog_name).unwrap(); for schema_name in catalog.schema_names() { - if schema_name != INFORMATION_SCHEMA { - if let Some(schema) = catalog.schema(&schema_name) { - let schema_owner = schema.owner_name(); - builder.add_schemata(&catalog_name, &schema_name, schema_owner); - } + if schema_name != INFORMATION_SCHEMA + && let Some(schema) = catalog.schema(&schema_name) + { + let schema_owner = schema.owner_name(); + builder.add_schemata(&catalog_name, &schema_name, schema_owner); } } } @@ -213,11 +219,16 @@ impl InformationSchemaConfig { fn make_df_settings( &self, config_options: &ConfigOptions, + runtime_env: &Arc, builder: &mut InformationSchemaDfSettingsBuilder, ) { for entry in config_options.entries() { builder.add_setting(entry); } + // Add runtime configuration entries + for entry in runtime_env.config_entries() { + builder.add_setting(entry); + } } fn make_routines( @@ -243,7 +254,7 @@ impl InformationSchemaConfig { name, "FUNCTION", Self::is_deterministic(udf.signature()), - return_type, + return_type.as_ref(), "SCALAR", udf.documentation().map(|d| d.description.to_string()), udf.documentation().map(|d| d.syntax_example.to_string()), @@ -263,7 +274,7 @@ impl InformationSchemaConfig { name, "FUNCTION", Self::is_deterministic(udaf.signature()), - return_type, + return_type.as_ref(), "AGGREGATE", udaf.documentation().map(|d| d.description.to_string()), udaf.documentation().map(|d| d.syntax_example.to_string()), @@ 
-283,7 +294,7 @@ impl InformationSchemaConfig { name, "FUNCTION", Self::is_deterministic(udwf.signature()), - return_type, + return_type.as_ref(), "WINDOW", udwf.documentation().map(|d| d.description.to_string()), udwf.documentation().map(|d| d.syntax_example.to_string()), @@ -413,14 +424,28 @@ fn get_udf_args_and_return_types( Ok(arg_types .into_iter() .map(|arg_types| { - // only handle the function which implemented [`ScalarUDFImpl::return_type`] method + let arg_fields: Vec = arg_types + .iter() + .enumerate() + .map(|(i, t)| { + Arc::new(Field::new(format!("arg_{i}"), t.clone(), true)) + }) + .collect(); + let scalar_arguments = vec![None; arg_fields.len()]; let return_type = udf - .return_type(&arg_types) - .map(|t| remove_native_type_prefix(NativeType::from(t))) + .return_field_from_args(ReturnFieldArgs { + arg_fields: &arg_fields, + scalar_arguments: &scalar_arguments, + }) + .map(|f| { + remove_native_type_prefix(&NativeType::from( + f.data_type().clone(), + )) + }) .ok(); let arg_types = arg_types .into_iter() - .map(|t| remove_native_type_prefix(NativeType::from(t))) + .map(|t| remove_native_type_prefix(&NativeType::from(t))) .collect::>(); (arg_types, return_type) }) @@ -439,14 +464,24 @@ fn get_udaf_args_and_return_types( Ok(arg_types .into_iter() .map(|arg_types| { - // only handle the function which implemented [`ScalarUDFImpl::return_type`] method + let arg_fields: Vec = arg_types + .iter() + .enumerate() + .map(|(i, t)| { + Arc::new(Field::new(format!("arg_{i}"), t.clone(), true)) + }) + .collect(); let return_type = udaf - .return_type(&arg_types) - .ok() - .map(|t| remove_native_type_prefix(NativeType::from(t))); + .return_field(&arg_fields) + .map(|f| { + remove_native_type_prefix(&NativeType::from( + f.data_type().clone(), + )) + }) + .ok(); let arg_types = arg_types .into_iter() - .map(|t| remove_native_type_prefix(NativeType::from(t))) + .map(|t| remove_native_type_prefix(&NativeType::from(t))) .collect::>(); (arg_types, return_type) }) @@ 
-465,20 +500,34 @@ fn get_udwf_args_and_return_types( Ok(arg_types .into_iter() .map(|arg_types| { - // only handle the function which implemented [`ScalarUDFImpl::return_type`] method + let arg_fields: Vec = arg_types + .iter() + .enumerate() + .map(|(i, t)| { + Arc::new(Field::new(format!("arg_{i}"), t.clone(), true)) + }) + .collect(); + let return_type = udwf + .field(WindowUDFFieldArgs::new(&arg_fields, udwf.name())) + .map(|f| { + remove_native_type_prefix(&NativeType::from( + f.data_type().clone(), + )) + }) + .ok(); let arg_types = arg_types .into_iter() - .map(|t| remove_native_type_prefix(NativeType::from(t))) + .map(|t| remove_native_type_prefix(&NativeType::from(t))) .collect::>(); - (arg_types, None) + (arg_types, return_type) }) .collect::>()) } } #[inline] -fn remove_native_type_prefix(native_type: NativeType) -> String { - format!("{native_type:?}") +fn remove_native_type_prefix(native_type: &NativeType) -> String { + format!("{native_type}") } #[async_trait] @@ -490,7 +539,7 @@ impl SchemaProvider for InformationSchemaProvider { fn table_names(&self) -> Vec { INFORMATION_SCHEMA_TABLES .iter() - .map(|t| t.to_string()) + .map(|t| (*t).to_string()) .collect() } @@ -677,7 +726,7 @@ impl InformationSchemaViewBuilder { catalog_name: impl AsRef, schema_name: impl AsRef, table_name: impl AsRef, - definition: Option>, + definition: Option<&(impl AsRef + ?Sized)>, ) { // Note: append_value is actually infallible. self.catalog_names.append_value(catalog_name.as_ref()); @@ -808,7 +857,7 @@ impl InformationSchemaColumnsBuilder { ) { use DataType::*; - // Note: append_value is actually infallable. + // Note: append_value is actually infallible. 
self.catalog_names.append_value(catalog_name); self.schema_names.append_value(schema_name); self.table_names.append_value(table_name); @@ -825,8 +874,7 @@ impl InformationSchemaColumnsBuilder { self.is_nullables.append_value(nullable_str); // "System supplied type" --> Use debug format of the datatype - self.data_types - .append_value(format!("{:?}", field.data_type())); + self.data_types.append_value(field.data_type().to_string()); // "If data_type identifies a character or bit string type, the // declared maximum length; null for all other data types or @@ -1059,7 +1107,12 @@ impl PartitionStream for InformationSchemaDfSettings { // TODO: Stream this futures::stream::once(async move { // create a mem table with the names of tables - config.make_df_settings(ctx.session_config().options(), &mut builder); + let runtime_env = ctx.runtime_env(); + config.make_df_settings( + ctx.session_config().options(), + &runtime_env, + &mut builder, + ); Ok(builder.finish()) }), )) @@ -1155,7 +1208,7 @@ struct InformationSchemaRoutinesBuilder { } impl InformationSchemaRoutinesBuilder { - #[allow(clippy::too_many_arguments)] + #[expect(clippy::too_many_arguments)] fn add_routine( &mut self, catalog_name: impl AsRef, @@ -1163,7 +1216,7 @@ impl InformationSchemaRoutinesBuilder { routine_name: impl AsRef, routine_type: impl AsRef, is_deterministic: bool, - data_type: Option>, + data_type: Option<&impl AsRef>, function_type: impl AsRef, description: Option>, syntax_example: Option>, @@ -1289,7 +1342,7 @@ struct InformationSchemaParametersBuilder { } impl InformationSchemaParametersBuilder { - #[allow(clippy::too_many_arguments)] + #[expect(clippy::too_many_arguments)] fn add_parameter( &mut self, specific_catalog: impl AsRef, @@ -1297,7 +1350,7 @@ impl InformationSchemaParametersBuilder { specific_name: impl AsRef, ordinal_position: u64, parameter_mode: impl AsRef, - parameter_name: Option>, + parameter_name: Option<&(impl AsRef + ?Sized)>, data_type: impl AsRef, parameter_default: 
Option>, is_variadic: bool, @@ -1359,3 +1412,94 @@ impl PartitionStream for InformationSchemaParameters { )) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::CatalogProvider; + + #[tokio::test] + async fn make_tables_uses_table_type() { + let config = InformationSchemaConfig { + catalog_list: Arc::new(Fixture), + }; + let mut builder = InformationSchemaTablesBuilder { + catalog_names: StringBuilder::new(), + schema_names: StringBuilder::new(), + table_names: StringBuilder::new(), + table_types: StringBuilder::new(), + schema: Arc::new(Schema::empty()), + }; + + assert!(config.make_tables(&mut builder).await.is_ok()); + + assert_eq!("BASE TABLE", builder.table_types.finish().value(0)); + } + + #[derive(Debug)] + struct Fixture; + + #[async_trait] + impl SchemaProvider for Fixture { + // InformationSchemaConfig::make_tables should use this. + async fn table_type(&self, _: &str) -> Result> { + Ok(Some(TableType::Base)) + } + + // InformationSchemaConfig::make_tables used this before `table_type` + // existed but should not, as it may be expensive. 
+ async fn table(&self, _: &str) -> Result>> { + panic!( + "InformationSchemaConfig::make_tables called SchemaProvider::table instead of table_type" + ) + } + + fn as_any(&self) -> &dyn Any { + unimplemented!("not required for these tests") + } + + fn table_names(&self) -> Vec { + vec!["atable".to_string()] + } + + fn table_exist(&self, _: &str) -> bool { + unimplemented!("not required for these tests") + } + } + + impl CatalogProviderList for Fixture { + fn as_any(&self) -> &dyn Any { + unimplemented!("not required for these tests") + } + + fn register_catalog( + &self, + _: String, + _: Arc, + ) -> Option> { + unimplemented!("not required for these tests") + } + + fn catalog_names(&self) -> Vec { + vec!["acatalog".to_string()] + } + + fn catalog(&self, _: &str) -> Option> { + Some(Arc::new(Self)) + } + } + + impl CatalogProvider for Fixture { + fn as_any(&self) -> &dyn Any { + unimplemented!("not required for these tests") + } + + fn schema_names(&self) -> Vec { + vec!["aschema".to_string()] + } + + fn schema(&self, _: &str) -> Option> { + Some(Arc::new(Self)) + } + } +} diff --git a/datafusion/catalog/src/lib.rs b/datafusion/catalog/src/lib.rs index 0394b05277dac..931941e8fdfad 100644 --- a/datafusion/catalog/src/lib.rs +++ b/datafusion/catalog/src/lib.rs @@ -19,10 +19,11 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![cfg_attr(not(test), deny(clippy::clone_on_ref_ptr))] +#![cfg_attr(test, allow(clippy::needless_pass_by_value))] //! Interfaces and default implementations of catalogs and schemas. //! 
@@ -46,13 +47,13 @@ mod dynamic_file; mod schema; mod table; +pub use r#async::*; pub use catalog::*; pub use datafusion_session::Session; pub use dynamic_file::catalog::*; pub use memory::{ MemTable, MemoryCatalogProvider, MemoryCatalogProviderList, MemorySchemaProvider, }; -pub use r#async::*; pub use schema::*; pub use table::*; diff --git a/datafusion/catalog/src/listing_schema.rs b/datafusion/catalog/src/listing_schema.rs index cc2c2ee606b3d..77fbea8577089 100644 --- a/datafusion/catalog/src/listing_schema.rs +++ b/datafusion/catalog/src/listing_schema.rs @@ -26,7 +26,7 @@ use crate::{SchemaProvider, TableProvider, TableProviderFactory}; use crate::Session; use datafusion_common::{ - Constraints, DFSchema, DataFusionError, HashMap, TableReference, + DFSchema, DataFusionError, HashMap, TableReference, internal_datafusion_err, }; use datafusion_expr::CreateExternalTable; @@ -111,17 +111,13 @@ impl ListingSchemaProvider { let file_name = table .path .file_name() - .ok_or_else(|| { - DataFusionError::Internal("Cannot parse file name!".to_string()) - })? + .ok_or_else(|| internal_datafusion_err!("Cannot parse file name!"))? 
.to_str() - .ok_or_else(|| { - DataFusionError::Internal("Cannot parse file name!".to_string()) - })?; + .ok_or_else(|| internal_datafusion_err!("Cannot parse file name!"))?; let table_name = file_name.split('.').collect_vec()[0]; - let table_path = table.to_string().ok_or_else(|| { - DataFusionError::Internal("Cannot parse file name!".to_string()) - })?; + let table_path = table + .to_string() + .ok_or_else(|| internal_datafusion_err!("Cannot parse file name!"))?; if !self.table_exist(table_name) { let table_url = format!("{}/{}", self.authority, table_path); @@ -131,21 +127,13 @@ impl ListingSchemaProvider { .factory .create( state, - &CreateExternalTable { - schema: Arc::new(DFSchema::empty()), + &CreateExternalTable::builder( name, - location: table_url, - file_type: self.format.clone(), - table_partition_cols: vec![], - if_not_exists: false, - temporary: false, - definition: None, - order_exprs: vec![], - unbounded: false, - options: Default::default(), - constraints: Constraints::empty(), - column_defaults: Default::default(), - }, + table_url, + self.format.clone(), + Arc::new(DFSchema::empty()), + ) + .build(), ) .await?; let _ = diff --git a/datafusion/catalog/src/memory/schema.rs b/datafusion/catalog/src/memory/schema.rs index f1b3628f7affc..97a579b021617 100644 --- a/datafusion/catalog/src/memory/schema.rs +++ b/datafusion/catalog/src/memory/schema.rs @@ -20,7 +20,7 @@ use crate::{SchemaProvider, TableProvider}; use async_trait::async_trait; use dashmap::DashMap; -use datafusion_common::{exec_err, DataFusionError}; +use datafusion_common::{DataFusionError, exec_err}; use std::any::Any; use std::sync::Arc; diff --git a/datafusion/catalog/src/memory/table.rs b/datafusion/catalog/src/memory/table.rs index 81243e2c4889e..9b91062657a07 100644 --- a/datafusion/catalog/src/memory/table.rs +++ b/datafusion/catalog/src/memory/table.rs @@ -23,25 +23,32 @@ use std::fmt::Debug; use std::sync::Arc; use crate::TableProvider; -use datafusion_common::error::Result; -use 
datafusion_expr::Expr; -use datafusion_expr::TableType; -use datafusion_physical_expr::create_physical_sort_exprs; -use datafusion_physical_plan::repartition::RepartitionExec; -use datafusion_physical_plan::{ - common, ExecutionPlan, ExecutionPlanProperties, Partitioning, -}; -use arrow::datatypes::SchemaRef; +use arrow::array::{ + Array, ArrayRef, BooleanArray, RecordBatch as ArrowRecordBatch, UInt64Array, +}; +use arrow::compute::kernels::zip::zip; +use arrow::compute::{and, filter_record_batch}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; -use datafusion_common::{not_impl_err, plan_err, Constraints, DFSchema, SchemaExt}; +use datafusion_common::error::Result; +use datafusion_common::tree_node::TreeNodeRecursion; +use datafusion_common::{Constraints, DFSchema, SchemaExt, not_impl_err, plan_err}; use datafusion_common_runtime::JoinSet; -use datafusion_datasource::memory::MemSink; -use datafusion_datasource::memory::MemorySourceConfig; +use datafusion_datasource::memory::{MemSink, MemorySourceConfig}; use datafusion_datasource::sink::DataSinkExec; use datafusion_datasource::source::DataSourceExec; use datafusion_expr::dml::InsertOp; -use datafusion_expr::SortExpr; +use datafusion_expr::{Expr, SortExpr, TableType}; +use datafusion_physical_expr::{ + LexOrdering, create_physical_expr, create_physical_sort_exprs, +}; +use datafusion_physical_plan::repartition::RepartitionExec; +use datafusion_physical_plan::stream::RecordBatchStreamAdapter; +use datafusion_physical_plan::{ + DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, Partitioning, + PhysicalExpr, PlanProperties, common, +}; use datafusion_session::Session; use async_trait::async_trait; @@ -70,8 +77,16 @@ pub struct MemTable { } impl MemTable { - /// Create a new in-memory table from the provided schema and record batches + /// Create a new in-memory table from the provided schema and record batches. 
+ /// + /// Requires at least one partition. To construct an empty `MemTable`, pass + /// `vec![vec![]]` as the `partitions` argument, this represents one partition with + /// no batches. pub fn try_new(schema: SchemaRef, partitions: Vec>) -> Result { + if partitions.is_empty() { + return plan_err!("No partitions provided, expected at least one partition"); + } + for batches in partitions.iter().flatten() { let batches_schema = batches.schema(); if !schema.contains(&batches_schema) { @@ -89,7 +104,7 @@ impl MemTable { .into_iter() .map(|e| Arc::new(RwLock::new(e))) .collect::>(), - constraints: Constraints::empty(), + constraints: Constraints::default(), column_defaults: HashMap::new(), sort_order: Arc::new(Mutex::new(vec![])), }) @@ -237,18 +252,15 @@ impl TableProvider for MemTable { // add sort information if present let sort_order = self.sort_order.lock(); if !sort_order.is_empty() { - let df_schema = DFSchema::try_from(self.schema.as_ref().clone())?; - - let file_sort_order = sort_order - .iter() - .map(|sort_exprs| { - create_physical_sort_exprs( - sort_exprs, - &df_schema, - state.execution_props(), - ) - }) - .collect::>>()?; + let df_schema = DFSchema::try_from(Arc::clone(&self.schema))?; + + let eqp = state.execution_props(); + let mut file_sort_order = vec![]; + for sort_exprs in sort_order.iter() { + let physical_exprs = + create_physical_sort_exprs(sort_exprs, &df_schema, eqp)?; + file_sort_order.extend(LexOrdering::new(physical_exprs)); + } source = source.try_with_sort_information(file_sort_order)?; } @@ -293,4 +305,342 @@ impl TableProvider for MemTable { fn get_column_default(&self, column: &str) -> Option<&Expr> { self.column_defaults.get(column) } + + async fn delete_from( + &self, + state: &dyn Session, + filters: Vec, + ) -> Result> { + // Early exit if table has no partitions + if self.batches.is_empty() { + return Ok(Arc::new(DmlResultExec::new(0))); + } + + *self.sort_order.lock() = vec![]; + + let mut total_deleted: u64 = 0; + let df_schema 
= DFSchema::try_from(Arc::clone(&self.schema))?; + + for partition_data in &self.batches { + let mut partition = partition_data.write().await; + let mut new_batches = Vec::with_capacity(partition.len()); + + for batch in partition.iter() { + if batch.num_rows() == 0 { + continue; + } + + // Evaluate filters - None means "match all rows" + let filter_mask = evaluate_filters_to_mask( + &filters, + batch, + &df_schema, + state.execution_props(), + )?; + + let (delete_count, keep_mask) = match filter_mask { + Some(mask) => { + // Count rows where mask is true (will be deleted) + let count = mask.iter().filter(|v| v == &Some(true)).count(); + // Keep rows where predicate is false or NULL (SQL three-valued logic) + let keep: BooleanArray = + mask.iter().map(|v| Some(v != Some(true))).collect(); + (count, keep) + } + None => { + // No filters = delete all rows + ( + batch.num_rows(), + BooleanArray::from(vec![false; batch.num_rows()]), + ) + } + }; + + total_deleted += delete_count as u64; + + let filtered_batch = filter_record_batch(batch, &keep_mask)?; + if filtered_batch.num_rows() > 0 { + new_batches.push(filtered_batch); + } + } + + *partition = new_batches; + } + + Ok(Arc::new(DmlResultExec::new(total_deleted))) + } + + async fn update( + &self, + state: &dyn Session, + assignments: Vec<(String, Expr)>, + filters: Vec, + ) -> Result> { + // Early exit if table has no partitions + if self.batches.is_empty() { + return Ok(Arc::new(DmlResultExec::new(0))); + } + + // Validate column names upfront with clear error messages + let available_columns: Vec<&str> = self + .schema + .fields() + .iter() + .map(|f| f.name().as_str()) + .collect(); + for (column_name, _) in &assignments { + if self.schema.field_with_name(column_name).is_err() { + return plan_err!( + "UPDATE failed: column '{}' does not exist. 
Available columns: {}", + column_name, + available_columns.join(", ") + ); + } + } + + let df_schema = DFSchema::try_from(Arc::clone(&self.schema))?; + + // Create physical expressions for assignments upfront (outside batch loop) + let physical_assignments: HashMap> = assignments + .iter() + .map(|(name, expr)| { + let physical_expr = + create_physical_expr(expr, &df_schema, state.execution_props())?; + Ok((name.clone(), physical_expr)) + }) + .collect::>()?; + + *self.sort_order.lock() = vec![]; + + let mut total_updated: u64 = 0; + + for partition_data in &self.batches { + let mut partition = partition_data.write().await; + let mut new_batches = Vec::with_capacity(partition.len()); + + for batch in partition.iter() { + if batch.num_rows() == 0 { + continue; + } + + // Evaluate filters - None means "match all rows" + let filter_mask = evaluate_filters_to_mask( + &filters, + batch, + &df_schema, + state.execution_props(), + )?; + + let (update_count, update_mask) = match filter_mask { + Some(mask) => { + // Count rows where mask is true (will be updated) + let count = mask.iter().filter(|v| v == &Some(true)).count(); + // Normalize mask: only true (not NULL) triggers update + let normalized: BooleanArray = + mask.iter().map(|v| Some(v == Some(true))).collect(); + (count, normalized) + } + None => { + // No filters = update all rows + ( + batch.num_rows(), + BooleanArray::from(vec![true; batch.num_rows()]), + ) + } + }; + + total_updated += update_count as u64; + + if update_count == 0 { + new_batches.push(batch.clone()); + continue; + } + + let mut new_columns: Vec = + Vec::with_capacity(batch.num_columns()); + + for field in self.schema.fields() { + let column_name = field.name(); + let original_column = + batch.column_by_name(column_name).ok_or_else(|| { + datafusion_common::DataFusionError::Internal(format!( + "Column '{column_name}' not found in batch" + )) + })?; + + let new_column = if let Some(physical_expr) = + physical_assignments.get(column_name.as_str()) 
+ { + // Use evaluate_selection to only evaluate on matching rows. + // This avoids errors (e.g., divide-by-zero) on rows that won't + // be updated. The result is scattered back with nulls for + // non-matching rows, which zip() will replace with originals. + let new_values = + physical_expr.evaluate_selection(batch, &update_mask)?; + let new_array = new_values.into_array(batch.num_rows())?; + + // Convert to &dyn Array which implements Datum + let new_arr: &dyn Array = new_array.as_ref(); + let orig_arr: &dyn Array = original_column.as_ref(); + zip(&update_mask, &new_arr, &orig_arr)? + } else { + Arc::clone(original_column) + }; + + new_columns.push(new_column); + } + + let updated_batch = + ArrowRecordBatch::try_new(Arc::clone(&self.schema), new_columns)?; + new_batches.push(updated_batch); + } + + *partition = new_batches; + } + + Ok(Arc::new(DmlResultExec::new(total_updated))) + } +} + +/// Evaluate filter expressions against a batch and return a combined boolean mask. +/// Returns None if filters is empty (meaning "match all rows"). +/// The returned mask has true for rows that match the filter predicates. +fn evaluate_filters_to_mask( + filters: &[Expr], + batch: &RecordBatch, + df_schema: &DFSchema, + execution_props: &datafusion_expr::execution_props::ExecutionProps, +) -> Result> { + if filters.is_empty() { + return Ok(None); + } + + let mut combined_mask: Option = None; + + for filter_expr in filters { + let physical_expr = + create_physical_expr(filter_expr, df_schema, execution_props)?; + + let result = physical_expr.evaluate(batch)?; + let array = result.into_array(batch.num_rows())?; + let bool_array = array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + datafusion_common::DataFusionError::Internal( + "Filter did not evaluate to boolean".to_string(), + ) + })? 
+ .clone(); + + combined_mask = Some(match combined_mask { + Some(existing) => and(&existing, &bool_array)?, + None => bool_array, + }); + } + + Ok(combined_mask) +} + +/// Returns a single row with the count of affected rows. +#[derive(Debug)] +struct DmlResultExec { + rows_affected: u64, + schema: SchemaRef, + properties: Arc, +} + +impl DmlResultExec { + fn new(rows_affected: u64) -> Self { + let schema = Arc::new(Schema::new(vec![Field::new( + "count", + DataType::UInt64, + false, + )])); + + let properties = PlanProperties::new( + datafusion_physical_expr::EquivalenceProperties::new(Arc::clone(&schema)), + Partitioning::UnknownPartitioning(1), + datafusion_physical_plan::execution_plan::EmissionType::Final, + datafusion_physical_plan::execution_plan::Boundedness::Bounded, + ); + + Self { + rows_affected, + schema, + properties: Arc::new(properties), + } + } +} + +impl DisplayAs for DmlResultExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default + | DisplayFormatType::Verbose + | DisplayFormatType::TreeRender => { + write!(f, "DmlResultExec: rows_affected={}", self.rows_affected) + } + } + } +} + +impl ExecutionPlan for DmlResultExec { + fn name(&self) -> &str { + "DmlResultExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } + + fn properties(&self) -> &Arc { + &self.properties + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + _children: Vec>, + ) -> Result> { + Ok(self) + } + + fn execute( + &self, + _partition: usize, + _context: Arc, + ) -> Result { + // Create a single batch with the count + let count_array = UInt64Array::from(vec![self.rows_affected]); + let batch = ArrowRecordBatch::try_new( + Arc::clone(&self.schema), + vec![Arc::new(count_array) as ArrayRef], + )?; + + // Create a stream that yields just this one batch + let stream = 
futures::stream::iter(vec![Ok(batch)]); + Ok(Box::pin(RecordBatchStreamAdapter::new( + Arc::clone(&self.schema), + stream, + ))) + } + + fn apply_expressions( + &self, + _f: &mut dyn FnMut(&dyn PhysicalExpr) -> Result, + ) -> Result { + Ok(TreeNodeRecursion::Continue) + } } diff --git a/datafusion/catalog/src/schema.rs b/datafusion/catalog/src/schema.rs index 5b37348fd7427..c6299582813b4 100644 --- a/datafusion/catalog/src/schema.rs +++ b/datafusion/catalog/src/schema.rs @@ -19,13 +19,14 @@ //! representing collections of named tables. use async_trait::async_trait; -use datafusion_common::{exec_err, DataFusionError}; +use datafusion_common::{DataFusionError, exec_err}; use std::any::Any; use std::fmt::Debug; use std::sync::Arc; use crate::table::TableProvider; use datafusion_common::Result; +use datafusion_expr::TableType; /// Represents a schema, comprising a number of named tables. /// @@ -54,12 +55,20 @@ pub trait SchemaProvider: Debug + Sync + Send { name: &str, ) -> Result>, DataFusionError>; + /// Retrieves the type of a specific table from the schema by name, if it exists, otherwise + /// returns `None`. Implementations for which this operation is cheap but [Self::table] is + /// expensive can override this to improve operations that only need the type, e.g. + /// `SELECT * FROM information_schema.tables`. + async fn table_type(&self, name: &str) -> Result> { + self.table(name).await.map(|o| o.map(|t| t.table_type())) + } + /// If supported by the implementation, adds a new table named `name` to /// this schema. /// /// If a table of the same name was already registered, returns "Table /// already exists" error. - #[allow(unused_variables)] + #[expect(unused_variables)] fn register_table( &self, name: String, @@ -72,7 +81,7 @@ pub trait SchemaProvider: Debug + Sync + Send { /// schema and returns the previously registered [`TableProvider`], if any. /// /// If no `name` table exists, returns Ok(None). 
- #[allow(unused_variables)] + #[expect(unused_variables)] fn deregister_table(&self, name: &str) -> Result>> { exec_err!("schema provider does not support deregistering tables") } diff --git a/datafusion/catalog/src/stream.rs b/datafusion/catalog/src/stream.rs index fbfab513229e0..bdd72a1b1d70b 100644 --- a/datafusion/catalog/src/stream.rs +++ b/datafusion/catalog/src/stream.rs @@ -28,13 +28,13 @@ use std::sync::Arc; use crate::{Session, TableProvider, TableProviderFactory}; use arrow::array::{RecordBatch, RecordBatchReader, RecordBatchWriter}; use arrow::datatypes::SchemaRef; -use datafusion_common::{config_err, plan_err, Constraints, DataFusionError, Result}; +use datafusion_common::{Constraints, DataFusionError, Result, config_err, plan_err}; use datafusion_common_runtime::SpawnedTask; use datafusion_datasource::sink::{DataSink, DataSinkExec}; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_expr::dml::InsertOp; use datafusion_expr::{CreateExternalTable, Expr, SortExpr, TableType}; -use datafusion_physical_expr::create_ordering; +use datafusion_physical_expr::create_lex_ordering; use datafusion_physical_plan::stream::RecordBatchReceiverStreamBuilder; use datafusion_physical_plan::streaming::{PartitionStream, StreamingTableExec}; use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan}; @@ -53,7 +53,7 @@ impl TableProviderFactory for StreamTableFactory { state: &dyn Session, cmd: &CreateExternalTable, ) -> Result> { - let schema: SchemaRef = Arc::new(cmd.schema.as_ref().into()); + let schema: SchemaRef = Arc::clone(cmd.schema.inner()); let location = cmd.location.clone(); let encoding = cmd.file_type.parse()?; let header = if let Ok(opt) = cmd @@ -256,7 +256,7 @@ impl StreamConfig { Self { source, order: vec![], - constraints: Constraints::empty(), + constraints: Constraints::default(), } } @@ -321,17 +321,21 @@ impl TableProvider for StreamTable { async fn scan( &self, - _state: &dyn Session, + state: &dyn 
Session, projection: Option<&Vec>, _filters: &[Expr], limit: Option, ) -> Result> { let projected_schema = match projection { Some(p) => { - let projected = self.0.source.schema().project(p)?; - create_ordering(&projected, &self.0.order)? + let projected = Arc::new(self.0.source.schema().project(p)?); + create_lex_ordering(&projected, &self.0.order, state.execution_props())? } - None => create_ordering(self.0.source.schema(), &self.0.order)?, + None => create_lex_ordering( + self.0.source.schema(), + &self.0.order, + state.execution_props(), + )?, }; Ok(Arc::new(StreamingTableExec::try_new( @@ -350,15 +354,11 @@ impl TableProvider for StreamTable { input: Arc, _insert_op: InsertOp, ) -> Result> { - let ordering = match self.0.order.first() { - Some(x) => { - let schema = self.0.source.schema(); - let orders = create_ordering(schema, std::slice::from_ref(x))?; - let ordering = orders.into_iter().next().unwrap(); - Some(ordering.into_iter().map(Into::into).collect()) - } - None => None, - }; + let schema = self.0.source.schema(); + let orders = + create_lex_ordering(schema, &self.0.order, _state.execution_props())?; + // It is sufficient to pass only one of the equivalent orderings: + let ordering = orders.into_iter().next().map(Into::into); Ok(Arc::new(DataSinkExec::new( input, @@ -440,6 +440,6 @@ impl DataSink for StreamWrite { write_task .join_unwind() .await - .map_err(DataFusionError::ExecutionJoin)? + .map_err(|e| DataFusionError::ExecutionJoin(Box::new(e)))? 
} } diff --git a/datafusion/catalog/src/streaming.rs b/datafusion/catalog/src/streaming.rs index 654e6755d7d4c..db9596b420b7b 100644 --- a/datafusion/catalog/src/streaming.rs +++ b/datafusion/catalog/src/streaming.rs @@ -22,21 +22,23 @@ use std::sync::Arc; use arrow::datatypes::SchemaRef; use async_trait::async_trait; - -use crate::Session; -use crate::TableProvider; -use datafusion_common::{plan_err, Result}; -use datafusion_expr::{Expr, TableType}; -use datafusion_physical_plan::streaming::{PartitionStream, StreamingTableExec}; +use datafusion_common::{DFSchema, Result, plan_err}; +use datafusion_expr::{Expr, SortExpr, TableType}; +use datafusion_physical_expr::equivalence::project_ordering; +use datafusion_physical_expr::{LexOrdering, create_physical_sort_exprs}; use datafusion_physical_plan::ExecutionPlan; +use datafusion_physical_plan::streaming::{PartitionStream, StreamingTableExec}; use log::debug; +use crate::{Session, TableProvider}; + /// A [`TableProvider`] that streams a set of [`PartitionStream`] #[derive(Debug)] pub struct StreamingTable { schema: SchemaRef, partitions: Vec>, infinite: bool, + sort_order: Vec, } impl StreamingTable { @@ -60,13 +62,21 @@ impl StreamingTable { schema, partitions, infinite: false, + sort_order: vec![], }) } + /// Sets streaming table can be infinite. pub fn with_infinite_table(mut self, infinite: bool) -> Self { self.infinite = infinite; self } + + /// Sets the existing ordering of streaming table. 
+ pub fn with_sort_order(mut self, sort_order: Vec) -> Self { + self.sort_order = sort_order; + self + } } #[async_trait] @@ -85,16 +95,40 @@ impl TableProvider for StreamingTable { async fn scan( &self, - _state: &dyn Session, + state: &dyn Session, projection: Option<&Vec>, _filters: &[Expr], limit: Option, ) -> Result> { + let physical_sort = if !self.sort_order.is_empty() { + let df_schema = DFSchema::try_from(Arc::clone(&self.schema))?; + let eqp = state.execution_props(); + + let original_sort_exprs = + create_physical_sort_exprs(&self.sort_order, &df_schema, eqp)?; + + if let Some(p) = projection { + // When performing a projection, the output columns will not match + // the original physical sort expression indices. Also the sort columns + // may not be in the output projection. To correct for these issues + // we need to project the ordering based on the output schema. + let schema = Arc::new(self.schema.project(p)?); + LexOrdering::new(original_sort_exprs) + .and_then(|lex_ordering| project_ordering(&lex_ordering, &schema)) + .map(|lex_ordering| lex_ordering.to_vec()) + .unwrap_or_default() + } else { + original_sort_exprs + } + } else { + vec![] + }; + Ok(Arc::new(StreamingTableExec::try_new( Arc::clone(&self.schema), self.partitions.clone(), projection, - None, + LexOrdering::new(physical_sort), self.infinite, limit, )?)) diff --git a/datafusion/catalog/src/table.rs b/datafusion/catalog/src/table.rs index 207abb9c66703..c9b4e974c8994 100644 --- a/datafusion/catalog/src/table.rs +++ b/datafusion/catalog/src/table.rs @@ -24,7 +24,7 @@ use crate::session::Session; use arrow::datatypes::SchemaRef; use async_trait::async_trait; use datafusion_common::Result; -use datafusion_common::{not_impl_err, Constraints, Statistics}; +use datafusion_common::{Constraints, Statistics, not_impl_err}; use datafusion_expr::Expr; use datafusion_expr::dml::InsertOp; @@ -49,7 +49,7 @@ use datafusion_physical_plan::ExecutionPlan; /// [`CatalogProvider`]: super::CatalogProvider 
#[async_trait] pub trait TableProvider: Debug + Sync + Send { - /// Returns the table provider as [`Any`](std::any::Any) so that it can be + /// Returns the table provider as [`Any`] so that it can be /// downcast to a specific implementation. fn as_any(&self) -> &dyn Any; @@ -75,7 +75,7 @@ pub trait TableProvider: Debug + Sync + Send { } /// Get the [`LogicalPlan`] of this table, if available. - fn get_logical_plan(&self) -> Option> { + fn get_logical_plan(&'_ self) -> Option> { None } @@ -171,6 +171,37 @@ pub trait TableProvider: Debug + Sync + Send { limit: Option, ) -> Result>; + /// Create an [`ExecutionPlan`] for scanning the table using structured arguments. + /// + /// This method uses [`ScanArgs`] to pass scan parameters in a structured way + /// and returns a [`ScanResult`] containing the execution plan. + /// + /// Table providers can override this method to take advantage of additional + /// parameters like the upcoming `preferred_ordering` that may not be available through + /// other scan methods. + /// + /// # Arguments + /// * `state` - The session state containing configuration and context + /// * `args` - Structured scan arguments including projection, filters, limit, and ordering preferences + /// + /// # Returns + /// A [`ScanResult`] containing the [`ExecutionPlan`] for scanning the table + /// + /// See [`Self::scan`] for detailed documentation about projection, filters, and limits. + async fn scan_with_args<'a>( + &self, + state: &dyn Session, + args: ScanArgs<'a>, + ) -> Result { + let filters = args.filters().unwrap_or(&[]); + let projection = args.projection().map(|p| p.to_vec()); + let limit = args.limit(); + let plan = self + .scan(state, projection.as_ref(), filters, limit) + .await?; + Ok(plan.into()) + } + /// Specify if DataFusion should provide filter expressions to the /// TableProvider to apply *during* the scan. 
/// @@ -297,6 +328,147 @@ pub trait TableProvider: Debug + Sync + Send { ) -> Result> { not_impl_err!("Insert into not implemented for this table") } + + /// Delete rows matching the filter predicates. + /// + /// Returns an [`ExecutionPlan`] producing a single row with `count` (UInt64). + /// Empty `filters` deletes all rows. + async fn delete_from( + &self, + _state: &dyn Session, + _filters: Vec, + ) -> Result> { + not_impl_err!("DELETE not supported for {} table", self.table_type()) + } + + /// Update rows matching the filter predicates. + /// + /// Returns an [`ExecutionPlan`] producing a single row with `count` (UInt64). + /// Empty `filters` updates all rows. + async fn update( + &self, + _state: &dyn Session, + _assignments: Vec<(String, Expr)>, + _filters: Vec, + ) -> Result> { + not_impl_err!("UPDATE not supported for {} table", self.table_type()) + } + + /// Remove all rows from the table. + /// + /// Should return an [ExecutionPlan] producing a single row with count (UInt64), + /// representing the number of rows removed. + async fn truncate(&self, _state: &dyn Session) -> Result> { + not_impl_err!("TRUNCATE not supported for {} table", self.table_type()) + } +} + +/// Arguments for scanning a table with [`TableProvider::scan_with_args`]. +#[derive(Debug, Clone, Default)] +pub struct ScanArgs<'a> { + filters: Option<&'a [Expr]>, + projection: Option<&'a [usize]>, + limit: Option, +} + +impl<'a> ScanArgs<'a> { + /// Set the column projection for the scan. + /// + /// The projection is a list of column indices from [`TableProvider::schema`] + /// that should be included in the scan results. If `None`, all columns are included. + /// + /// # Arguments + /// * `projection` - Optional slice of column indices to project + pub fn with_projection(mut self, projection: Option<&'a [usize]>) -> Self { + self.projection = projection; + self + } + + /// Get the column projection for the scan. 
+ /// + /// Returns a reference to the projection column indices, or `None` if + /// no projection was specified (meaning all columns should be included). + pub fn projection(&self) -> Option<&'a [usize]> { + self.projection + } + + /// Set the filter expressions for the scan. + /// + /// Filters are boolean expressions that should be evaluated during the scan + /// to reduce the number of rows returned. All expressions are combined with AND logic. + /// Whether filters are actually pushed down depends on [`TableProvider::supports_filters_pushdown`]. + /// + /// # Arguments + /// * `filters` - Optional slice of filter expressions + pub fn with_filters(mut self, filters: Option<&'a [Expr]>) -> Self { + self.filters = filters; + self + } + + /// Get the filter expressions for the scan. + /// + /// Returns a reference to the filter expressions, or `None` if no filters were specified. + pub fn filters(&self) -> Option<&'a [Expr]> { + self.filters + } + + /// Set the maximum number of rows to return from the scan. + /// + /// If specified, the scan should return at most this many rows. This is typically + /// used to optimize queries with `LIMIT` clauses. + /// + /// # Arguments + /// * `limit` - Optional maximum number of rows to return + pub fn with_limit(mut self, limit: Option) -> Self { + self.limit = limit; + self + } + + /// Get the maximum number of rows to return from the scan. + /// + /// Returns the row limit, or `None` if no limit was specified. + pub fn limit(&self) -> Option { + self.limit + } +} + +/// Result of a table scan operation from [`TableProvider::scan_with_args`]. +#[derive(Debug, Clone)] +pub struct ScanResult { + /// The ExecutionPlan to run. + plan: Arc, +} + +impl ScanResult { + /// Create a new `ScanResult` with the given execution plan. 
+ /// + /// # Arguments + /// * `plan` - The execution plan that will perform the table scan + pub fn new(plan: Arc) -> Self { + Self { plan } + } + + /// Get a reference to the execution plan for this scan result. + /// + /// Returns a reference to the [`ExecutionPlan`] that will perform + /// the actual table scanning and data retrieval. + pub fn plan(&self) -> &Arc { + &self.plan + } + + /// Consume this ScanResult and return the execution plan. + /// + /// Returns the owned [`ExecutionPlan`] that will perform + /// the actual table scanning and data retrieval. + pub fn into_inner(self) -> Arc { + self.plan + } +} + +impl From> for ScanResult { + fn from(plan: Arc) -> Self { + Self::new(plan) + } } /// A factory which creates [`TableProvider`]s at runtime given a URL. @@ -314,13 +486,13 @@ pub trait TableProviderFactory: Debug + Sync + Send { } /// A trait for table function implementations -pub trait TableFunctionImpl: Debug + Sync + Send { +pub trait TableFunctionImpl: Debug + Sync + Send + Any { /// Create a table provider fn call(&self, args: &[Expr]) -> Result>; } /// A table that uses a function to generate data -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct TableFunction { /// Name of the table function name: String, diff --git a/datafusion/catalog/src/view.rs b/datafusion/catalog/src/view.rs index 8dfb79718c9bb..54c54431a5913 100644 --- a/datafusion/catalog/src/view.rs +++ b/datafusion/catalog/src/view.rs @@ -24,8 +24,8 @@ use crate::TableProvider; use arrow::datatypes::SchemaRef; use async_trait::async_trait; -use datafusion_common::error::Result; use datafusion_common::Column; +use datafusion_common::error::Result; use datafusion_expr::TableType; use datafusion_expr::{Expr, LogicalPlan}; use datafusion_expr::{LogicalPlanBuilder, TableProviderFilterPushDown}; @@ -51,7 +51,7 @@ impl ViewTable { /// Notes: the `LogicalPlan` is not validated or type coerced. If this is /// needed it should be done after calling this function. 
pub fn new(logical_plan: LogicalPlan, definition: Option) -> Self { - let table_schema = logical_plan.schema().as_ref().to_owned().into(); + let table_schema = Arc::clone(logical_plan.schema().inner()); Self { logical_plan, table_schema, @@ -87,7 +87,7 @@ impl TableProvider for ViewTable { self } - fn get_logical_plan(&self) -> Option> { + fn get_logical_plan(&'_ self) -> Option> { Some(Cow::Borrowed(&self.logical_plan)) } diff --git a/datafusion/common-runtime/Cargo.toml b/datafusion/common-runtime/Cargo.toml index 7ddc021e640c9..fd9a818bcb1d0 100644 --- a/datafusion/common-runtime/Cargo.toml +++ b/datafusion/common-runtime/Cargo.toml @@ -31,6 +31,9 @@ rust-version = { workspace = true } [package.metadata.docs.rs] all-features = true +# Note: add additional linter rules in lib.rs. +# Rust does not support workspace + new linter rules in subcrates yet +# https://github.com/rust-lang/cargo/issues/13157 [lints] workspace = true @@ -43,4 +46,4 @@ log = { workspace = true } tokio = { workspace = true } [dev-dependencies] -tokio = { version = "1.45", features = ["rt", "rt-multi-thread", "time"] } +tokio = { workspace = true, features = ["rt", "rt-multi-thread", "time"] } diff --git a/datafusion/common-runtime/README.md b/datafusion/common-runtime/README.md index 77100e52603c9..ff44e6c3e209e 100644 --- a/datafusion/common-runtime/README.md +++ b/datafusion/common-runtime/README.md @@ -17,10 +17,16 @@ under the License. --> -# DataFusion Common Runtime +# Apache DataFusion Common Runtime -[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. +[Apache DataFusion] is an extensible query execution framework, written in Rust, that uses [Apache Arrow] as its in-memory format. This crate is a submodule of DataFusion that provides common utilities. -[df]: https://crates.io/crates/datafusion +Most projects should use the [`datafusion`] crate directly, which re-exports +this module. 
If you are already using the [`datafusion`] crate, there is no +reason to use this crate directly in your project as well. + +[apache arrow]: https://arrow.apache.org/ +[apache datafusion]: https://datafusion.apache.org/ +[`datafusion`]: https://crates.io/crates/datafusion diff --git a/datafusion/common-runtime/src/common.rs b/datafusion/common-runtime/src/common.rs index e7aba1d455ee6..ca618b19ed2f1 100644 --- a/datafusion/common-runtime/src/common.rs +++ b/datafusion/common-runtime/src/common.rs @@ -44,7 +44,7 @@ impl SpawnedTask { R: Send, { // Ok to use spawn here as SpawnedTask handles aborting/cancelling the task on Drop - #[allow(clippy::disallowed_methods)] + #[expect(clippy::disallowed_methods)] let inner = tokio::task::spawn(trace_future(task)); Self { inner } } @@ -56,7 +56,7 @@ impl SpawnedTask { R: Send, { // Ok to use spawn_blocking here as SpawnedTask handles aborting/cancelling the task on Drop - #[allow(clippy::disallowed_methods)] + #[expect(clippy::disallowed_methods)] let inner = tokio::task::spawn_blocking(trace_block(task)); Self { inner } } @@ -68,15 +68,28 @@ impl SpawnedTask { } /// Joins the task and unwinds the panic if it happens. - pub async fn join_unwind(self) -> Result { + pub async fn join_unwind(mut self) -> Result { + self.join_unwind_mut().await + } + + /// Joins the task using a mutable reference and unwinds the panic if it happens. + /// + /// This method is similar to [`join_unwind`](Self::join_unwind), but takes a mutable + /// reference instead of consuming `self`. This allows the `SpawnedTask` to remain + /// usable after the call. 
+ /// + /// If called multiple times on the same task: + /// - If the task is still running, it will continue waiting for completion + /// - If the task has already completed successfully, subsequent calls will + /// continue to return the same `JoinError` indicating the task is finished + /// - If the task panicked, the first call will resume the panic, and the + /// program will not reach subsequent calls + pub async fn join_unwind_mut(&mut self) -> Result { self.await.map_err(|e| { // `JoinError` can be caused either by panic or cancellation. We have to handle panics: if e.is_panic() { std::panic::resume_unwind(e.into_panic()); } else { - // Cancellation may be caused by two reasons: - // 1. Abort is called, but since we consumed `self`, it's not our case (`JoinHandle` not accessible outside). - // 2. The runtime is shutting down. log::warn!("SpawnedTask was polled during shutdown"); e } @@ -102,14 +115,14 @@ impl Drop for SpawnedTask { mod tests { use super::*; - use std::future::{pending, Pending}; + use std::future::{Pending, pending}; use tokio::{runtime::Runtime, sync::oneshot}; #[tokio::test] async fn runtime_shutdown() { let rt = Runtime::new().unwrap(); - #[allow(clippy::async_yields_async)] + #[expect(clippy::async_yields_async)] let task = rt .spawn(async { SpawnedTask::spawn(async { diff --git a/datafusion/common-runtime/src/lib.rs b/datafusion/common-runtime/src/lib.rs index ec8db0bdcd911..cf45ccf3ef63a 100644 --- a/datafusion/common-runtime/src/lib.rs +++ b/datafusion/common-runtime/src/lib.rs @@ -15,11 +15,12 @@ // specific language governing permissions and limitations // under the License. 
+#![cfg_attr(test, allow(clippy::needless_pass_by_value))] #![doc( html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] @@ -30,4 +31,6 @@ mod trace_utils; pub use common::SpawnedTask; pub use join_set::JoinSet; -pub use trace_utils::{set_join_set_tracer, JoinSetTracer}; +pub use trace_utils::{ + JoinSetTracer, JoinSetTracerError, set_join_set_tracer, trace_block, trace_future, +}; diff --git a/datafusion/common-runtime/src/trace_utils.rs b/datafusion/common-runtime/src/trace_utils.rs index c3a39c355fc88..f8adbe8825bc1 100644 --- a/datafusion/common-runtime/src/trace_utils.rs +++ b/datafusion/common-runtime/src/trace_utils.rs @@ -15,8 +15,8 @@ // specific language governing permissions and limitations // under the License. -use futures::future::BoxFuture; use futures::FutureExt; +use futures::future::BoxFuture; use std::any::Any; use std::error::Error; use std::fmt::{Display, Formatter, Result as FmtResult}; diff --git a/datafusion/common/Cargo.toml b/datafusion/common/Cargo.toml index d471e48be4e75..92dd76aa97d47 100644 --- a/datafusion/common/Cargo.toml +++ b/datafusion/common/Cargo.toml @@ -31,6 +31,9 @@ rust-version = { workspace = true } [package.metadata.docs.rs] all-features = true +# Note: add additional linter rules in lib.rs. 
+# Rust does not support workspace + new linter rules in subcrates yet +# https://github.com/rust-lang/cargo/issues/13157 [lints] workspace = true @@ -40,13 +43,30 @@ name = "datafusion_common" [features] avro = ["apache-avro"] backtrace = [] -pyarrow = ["pyo3", "arrow/pyarrow", "parquet"] +parquet_encryption = [ + "parquet", + "parquet/encryption", + "dep:hex", +] force_hash_collisions = [] recursive_protection = ["dep:recursive"] +parquet = ["dep:parquet"] +sql = ["sqlparser"] + +[[bench]] +harness = false +name = "with_hashes" + +[[bench]] +harness = false +name = "scalar_to_array" + +[[bench]] +harness = false +name = "stats_merge" [dependencies] -ahash = { workspace = true } -apache-avro = { version = "0.17", default-features = false, features = [ +apache-avro = { workspace = true, features = [ "bzip", "snappy", "xz", @@ -54,18 +74,19 @@ apache-avro = { version = "0.17", default-features = false, features = [ ], optional = true } arrow = { workspace = true } arrow-ipc = { workspace = true } -base64 = "0.22.1" +chrono = { workspace = true } +foldhash = "0.2" half = { workspace = true } hashbrown = { workspace = true } +hex = { workspace = true, optional = true } indexmap = { workspace = true } -libc = "0.2.172" +itertools = { workspace = true } +libc = "0.2.180" log = { workspace = true } object_store = { workspace = true, optional = true } parquet = { workspace = true, optional = true, default-features = true } -paste = "1.0.15" -pyo3 = { version = "0.24.2", optional = true } recursive = { workspace = true, optional = true } -sqlparser = { workspace = true } +sqlparser = { workspace = true, optional = true } tokio = { workspace = true } [target.'cfg(target_family = "wasm")'.dependencies] @@ -73,5 +94,7 @@ web-time = "1.1.0" [dev-dependencies] chrono = { workspace = true } +criterion = { workspace = true } insta = { workspace = true } rand = { workspace = true } +sqlparser = { workspace = true } diff --git a/datafusion/common/README.md 
b/datafusion/common/README.md index 524ab4420d2a8..4948c8c581be9 100644 --- a/datafusion/common/README.md +++ b/datafusion/common/README.md @@ -17,10 +17,16 @@ under the License. --> -# DataFusion Common +# Apache DataFusion Common -[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. +[Apache DataFusion] is an extensible query execution framework, written in Rust, that uses [Apache Arrow] as its in-memory format. This crate is a submodule of DataFusion that provides common data types and utilities. -[df]: https://crates.io/crates/datafusion +Most projects should use the [`datafusion`] crate directly, which re-exports +this module. If you are already using the [`datafusion`] crate, there is no +reason to use this crate directly in your project as well. + +[apache arrow]: https://arrow.apache.org/ +[apache datafusion]: https://datafusion.apache.org/ +[`datafusion`]: https://crates.io/crates/datafusion diff --git a/datafusion/common/benches/scalar_to_array.rs b/datafusion/common/benches/scalar_to_array.rs new file mode 100644 index 0000000000000..90a152e515fe5 --- /dev/null +++ b/datafusion/common/benches/scalar_to_array.rs @@ -0,0 +1,107 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Benchmarks for `ScalarValue::to_array_of_size`, focusing on List +//! scalars. + +use arrow::array::{Array, ArrayRef, AsArray, StringViewBuilder}; +use arrow::datatypes::{DataType, Field}; +use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; +use datafusion_common::utils::SingleRowListArrayBuilder; +use std::sync::Arc; + +/// Build a `ScalarValue::List` of `num_elements` Utf8View strings whose +/// inner StringViewArray has `num_buffers` data buffers. +fn make_list_scalar(num_elements: usize, num_buffers: usize) -> ScalarValue { + let elements_per_buffer = num_elements.div_ceil(num_buffers); + + let mut small_arrays: Vec = Vec::new(); + let mut remaining = num_elements; + for buf_idx in 0..num_buffers { + let count = remaining.min(elements_per_buffer); + if count == 0 { + break; + } + let start = buf_idx * elements_per_buffer; + let mut builder = StringViewBuilder::with_capacity(count); + for i in start..start + count { + builder.append_value(format!("{i:024x}")); + } + small_arrays.push(Arc::new(builder.finish()) as ArrayRef); + remaining -= count; + } + + let refs: Vec<&dyn Array> = small_arrays.iter().map(|a| a.as_ref()).collect(); + let concated = arrow::compute::concat(&refs).unwrap(); + + let list_array = SingleRowListArrayBuilder::new(concated) + .with_field(&Field::new_list_field(DataType::Utf8View, true)) + .build_list_array(); + ScalarValue::List(Arc::new(list_array)) +} + +/// We want to measure the cost of doing the conversion and then also accessing +/// the results, to model what would happen during query evaluation. 
+fn consume_list_array(arr: &ArrayRef) { + let list_arr = arr.as_list::(); + let mut total_len: usize = 0; + for i in 0..list_arr.len() { + let inner = list_arr.value(i); + let sv = inner.as_string_view(); + for j in 0..sv.len() { + total_len += sv.value(j).len(); + } + } + std::hint::black_box(total_len); +} + +fn bench_list_to_array_of_size(c: &mut Criterion) { + let mut group = c.benchmark_group("list_to_array_of_size"); + + let num_elements = 1245; + let scalar_1buf = make_list_scalar(num_elements, 1); + let scalar_50buf = make_list_scalar(num_elements, 50); + + for batch_size in [256, 1024] { + group.bench_with_input( + BenchmarkId::new("1_buffer", batch_size), + &batch_size, + |b, &sz| { + b.iter(|| { + let arr = scalar_1buf.to_array_of_size(sz).unwrap(); + consume_list_array(&arr); + }); + }, + ); + group.bench_with_input( + BenchmarkId::new("50_buffers", batch_size), + &batch_size, + |b, &sz| { + b.iter(|| { + let arr = scalar_50buf.to_array_of_size(sz).unwrap(); + consume_list_array(&arr); + }); + }, + ); + } + + group.finish(); +} + +criterion_group!(benches, bench_list_to_array_of_size); +criterion_main!(benches); diff --git a/datafusion/common/benches/stats_merge.rs b/datafusion/common/benches/stats_merge.rs new file mode 100644 index 0000000000000..73229b6379360 --- /dev/null +++ b/datafusion/common/benches/stats_merge.rs @@ -0,0 +1,85 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Benchmark for `Statistics::try_merge_iter`. + +use std::sync::Arc; + +use arrow::datatypes::{DataType, Field, Schema}; +use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; +use datafusion_common::stats::Precision; +use datafusion_common::{ColumnStatistics, ScalarValue, Statistics}; + +/// Build a vector of `n` with `num_cols` columns +fn make_stats(n: usize, num_cols: usize) -> Vec { + (0..n) + .map(|i| { + let mut stats = Statistics::default() + .with_num_rows(Precision::Exact(100 + i)) + .with_total_byte_size(Precision::Exact(8000 + i * 80)); + for c in 0..num_cols { + let base = (i * num_cols + c) as i64; + stats = stats.add_column_statistics( + ColumnStatistics::new_unknown() + .with_null_count(Precision::Exact(i)) + .with_min_value(Precision::Exact(ScalarValue::Int64(Some(base)))) + .with_max_value(Precision::Exact(ScalarValue::Int64(Some( + base + 1000, + )))) + .with_sum_value(Precision::Exact(ScalarValue::Int64(Some( + base * 100, + )))), + ); + } + stats + }) + .collect() +} + +fn bench_stats_merge(c: &mut Criterion) { + let mut group = c.benchmark_group("stats_merge"); + + for &num_partitions in &[10, 100, 500] { + for &num_cols in &[1, 5, 20] { + let items = make_stats(num_partitions, num_cols); + let schema = Arc::new(Schema::new( + (0..num_cols) + .map(|i| Field::new(format!("col{i}"), DataType::Int64, true)) + .collect::>(), + )); + + let param = format!("{num_partitions}parts_{num_cols}cols"); + + group.bench_with_input( + BenchmarkId::new("try_merge_iter", ¶m), + &(&items, &schema), + |b, (items, 
schema)| { + b.iter(|| { + std::hint::black_box( + Statistics::try_merge_iter(*items, schema).unwrap(), + ); + }); + }, + ); + } + } + + group.finish(); +} + +criterion_group!(benches, bench_stats_merge); +criterion_main!(benches); diff --git a/datafusion/common/benches/with_hashes.rs b/datafusion/common/benches/with_hashes.rs new file mode 100644 index 0000000000000..0e9c53c896a5e --- /dev/null +++ b/datafusion/common/benches/with_hashes.rs @@ -0,0 +1,569 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Benchmarks for `with_hashes` function + +use arrow::array::{ + Array, ArrayRef, ArrowPrimitiveType, DictionaryArray, GenericStringArray, Int32Array, + Int64Array, ListArray, MapArray, NullBufferBuilder, OffsetSizeTrait, PrimitiveArray, + RunArray, StringViewArray, StructArray, UnionArray, make_array, +}; +use arrow::buffer::{NullBuffer, OffsetBuffer, ScalarBuffer}; +use arrow::datatypes::{ + ArrowDictionaryKeyType, DataType, Field, Fields, Int32Type, Int64Type, UnionFields, +}; +use criterion::{Bencher, Criterion, criterion_group, criterion_main}; +use datafusion_common::hash_utils::RandomState; +use datafusion_common::hash_utils::with_hashes; +use rand::Rng; +use rand::SeedableRng; +use rand::distr::{Alphanumeric, Distribution, StandardUniform}; +use rand::prelude::StdRng; +use std::sync::Arc; + +const BATCH_SIZE: usize = 8192; + +struct BenchData { + name: &'static str, + array: ArrayRef, + /// Union arrays can't have null bitmasks added + supports_nulls: bool, +} + +fn criterion_benchmark(c: &mut Criterion) { + let pool = StringPool::new(100, 64); + // poll with small strings for string view tests (<=12 bytes are inlined) + let small_pool = StringPool::new(100, 5); + let cases = [ + BenchData { + name: "int64", + array: primitive_array::(BATCH_SIZE), + supports_nulls: true, + }, + BenchData { + name: "utf8", + array: pool.string_array::(BATCH_SIZE), + supports_nulls: true, + }, + BenchData { + name: "large_utf8", + array: pool.string_array::(BATCH_SIZE), + supports_nulls: true, + }, + BenchData { + name: "utf8_view", + array: pool.string_view_array(BATCH_SIZE), + supports_nulls: true, + }, + BenchData { + name: "utf8_view (small)", + array: small_pool.string_view_array(BATCH_SIZE), + supports_nulls: true, + }, + BenchData { + name: "dictionary_utf8_int32", + array: pool.dictionary_array::(BATCH_SIZE), + supports_nulls: true, + }, + BenchData { + name: "list_array", + array: list_array(BATCH_SIZE), + supports_nulls: true, + }, + BenchData { + name: "map_array", + 
array: map_array(BATCH_SIZE), + supports_nulls: true, + }, + BenchData { + name: "sparse_union", + array: sparse_union_array(BATCH_SIZE), + supports_nulls: false, + }, + BenchData { + name: "dense_union", + array: dense_union_array(BATCH_SIZE), + supports_nulls: false, + }, + BenchData { + name: "struct_array", + array: create_struct_array(&pool, BATCH_SIZE), + supports_nulls: true, + }, + BenchData { + name: "run_array_int32", + array: create_run_array::(BATCH_SIZE), + supports_nulls: true, + }, + ]; + + for BenchData { + name, + array, + supports_nulls, + } in cases + { + c.bench_function(&format!("{name}: single, no nulls"), |b| { + do_hash_test(b, std::slice::from_ref(&array)); + }); + c.bench_function(&format!("{name}: multiple, no nulls"), |b| { + let arrays = vec![array.clone(), array.clone(), array.clone()]; + do_hash_test(b, &arrays); + }); + // Union arrays can't have null bitmasks + if supports_nulls { + let nullable_array = add_nulls(&array); + c.bench_function(&format!("{name}: single, nulls"), |b| { + do_hash_test(b, std::slice::from_ref(&nullable_array)); + }); + c.bench_function(&format!("{name}: multiple, nulls"), |b| { + let arrays = vec![ + nullable_array.clone(), + nullable_array.clone(), + nullable_array.clone(), + ]; + do_hash_test(b, &arrays); + }); + } + } +} + +fn do_hash_test(b: &mut Bencher, arrays: &[ArrayRef]) { + let state = RandomState::default(); + b.iter(|| { + with_hashes(arrays, &state, |hashes| { + assert_eq!(hashes.len(), BATCH_SIZE); // make sure the result is used + Ok(()) + }) + .unwrap(); + }); +} + +fn create_null_mask(len: usize) -> NullBuffer +where + StandardUniform: Distribution, +{ + let mut rng = make_rng(); + let null_density = 0.03; + let mut builder = NullBufferBuilder::new(len); + for _ in 0..len { + if rng.random::() < null_density { + builder.append_null(); + } else { + builder.append_non_null(); + } + } + builder.finish().expect("should be nulls in buffer") +} + +// Returns a new array that is the same as 
array, but with nulls +// Handles the special case of RunArray where nulls must be in the values array +fn add_nulls(array: &ArrayRef) -> ArrayRef { + use arrow::datatypes::DataType; + + match array.data_type() { + DataType::RunEndEncoded(_, _) => { + // RunArray can't have top-level nulls, so apply nulls to the values array + let run_array = array + .as_any() + .downcast_ref::>() + .expect("Expected RunArray"); + + let run_ends_buffer = run_array.run_ends().inner().clone(); + let run_ends_array = PrimitiveArray::::new(run_ends_buffer, None); + let values = run_array.values().clone(); + + // Add nulls to the values array + let values_with_nulls = { + let array_data = values + .clone() + .into_data() + .into_builder() + .nulls(Some(create_null_mask(values.len()))) + .build() + .unwrap(); + make_array(array_data) + }; + + Arc::new( + RunArray::try_new(&run_ends_array, values_with_nulls.as_ref()) + .expect("Failed to create RunArray with null values"), + ) + } + _ => { + let array_data = array + .clone() + .into_data() + .into_builder() + .nulls(Some(create_null_mask(array.len()))) + .build() + .unwrap(); + make_array(array_data) + } + } +} + +pub fn make_rng() -> StdRng { + StdRng::seed_from_u64(42) +} + +/// String pool for generating low cardinality data (for dictionaries and string views) +struct StringPool { + strings: Vec, +} + +impl StringPool { + /// Create a new string pool with the given number of random strings + /// each having between 1 and max_length characters. 
+ fn new(pool_size: usize, max_length: usize) -> Self { + let mut rng = make_rng(); + let mut strings = Vec::with_capacity(pool_size); + for _ in 0..pool_size { + let len = rng.random_range(1..=max_length); + let value: Vec = + rng.clone().sample_iter(&Alphanumeric).take(len).collect(); + strings.push(String::from_utf8(value).unwrap()); + } + Self { strings } + } + + /// Return an iterator over &str of the given length with values randomly chosen from the pool + fn iter_strings(&self, len: usize) -> impl Iterator { + let mut rng = make_rng(); + (0..len).map(move |_| { + let idx = rng.random_range(0..self.strings.len()); + self.strings[idx].as_str() + }) + } + + /// Return a StringArray of the given length with values randomly chosen from the pool + fn string_array(&self, array_length: usize) -> ArrayRef { + Arc::new(GenericStringArray::::from_iter_values( + self.iter_strings(array_length), + )) + } + + /// Return a StringViewArray of the given length with values randomly chosen from the pool + fn string_view_array(&self, array_length: usize) -> ArrayRef { + Arc::new(StringViewArray::from_iter_values( + self.iter_strings(array_length), + )) + } + + /// Return a DictionaryArray of the given length with values randomly chosen from the pool + fn dictionary_array( + &self, + array_length: usize, + ) -> ArrayRef { + Arc::new(DictionaryArray::::from_iter( + self.iter_strings(array_length), + )) + } +} + +pub fn primitive_array(array_len: usize) -> ArrayRef +where + T: ArrowPrimitiveType, + StandardUniform: Distribution, +{ + let mut rng = make_rng(); + + let array: PrimitiveArray = (0..array_len) + .map(|_| Some(rng.random::())) + .collect(); + Arc::new(array) +} + +/// Benchmark sliced arrays to demonstrate the optimization for when an array is +/// sliced, the underlying buffer may be much larger than what's referenced by +/// the slice. The optimization avoids hashing unreferenced elements. 
+fn sliced_array_benchmark(c: &mut Criterion) { + // Test with different slice ratios: slice_size / total_size + // Smaller ratio = more potential savings from the optimization + let slice_ratios = [10, 5, 2]; // 1/10, 1/5, 1/2 of total + + for ratio in slice_ratios { + let total_rows = BATCH_SIZE * ratio; + let slice_offset = BATCH_SIZE * (ratio / 2); // Take from middle + let slice_len = BATCH_SIZE; + + // Sliced ListArray + { + let full_array = list_array(total_rows); + let sliced: ArrayRef = Arc::new( + full_array + .as_any() + .downcast_ref::() + .unwrap() + .slice(slice_offset, slice_len), + ); + c.bench_function( + &format!("list_array_sliced: 1/{ratio} of {total_rows} rows"), + |b| { + do_hash_test_with_len(b, std::slice::from_ref(&sliced), slice_len); + }, + ); + } + + // Sliced MapArray + { + let full_array = map_array(total_rows); + let sliced: ArrayRef = Arc::new( + full_array + .as_any() + .downcast_ref::() + .unwrap() + .slice(slice_offset, slice_len), + ); + c.bench_function( + &format!("map_array_sliced: 1/{ratio} of {total_rows} rows"), + |b| { + do_hash_test_with_len(b, std::slice::from_ref(&sliced), slice_len); + }, + ); + } + + // Sliced Sparse UnionArray + { + let full_array = sparse_union_array(total_rows); + let sliced: ArrayRef = Arc::new( + full_array + .as_any() + .downcast_ref::() + .unwrap() + .slice(slice_offset, slice_len), + ); + c.bench_function( + &format!("sparse_union_sliced: 1/{ratio} of {total_rows} rows"), + |b| { + do_hash_test_with_len(b, std::slice::from_ref(&sliced), slice_len); + }, + ); + } + } +} + +fn do_hash_test_with_len(b: &mut Bencher, arrays: &[ArrayRef], expected_len: usize) { + let state = RandomState::default(); + b.iter(|| { + with_hashes(arrays, &state, |hashes| { + assert_eq!(hashes.len(), expected_len); + Ok(()) + }) + .unwrap(); + }); +} + +fn list_array(num_rows: usize) -> ArrayRef { + let mut rng = make_rng(); + let elements_per_row = 5; + let total_elements = num_rows * elements_per_row; + + let values: 
Int64Array = (0..total_elements) + .map(|_| Some(rng.random::())) + .collect(); + let offsets: Vec = (0..=num_rows) + .map(|i| (i * elements_per_row) as i32) + .collect(); + + Arc::new(ListArray::new( + Arc::new(Field::new("item", DataType::Int64, true)), + OffsetBuffer::new(ScalarBuffer::from(offsets)), + Arc::new(values), + None, + )) +} + +fn map_array(num_rows: usize) -> ArrayRef { + let mut rng = make_rng(); + let entries_per_row = 5; + let total_entries = num_rows * entries_per_row; + + let keys: Int32Array = (0..total_entries) + .map(|_| Some(rng.random::())) + .collect(); + let values: Int64Array = (0..total_entries) + .map(|_| Some(rng.random::())) + .collect(); + let offsets: Vec = (0..=num_rows) + .map(|i| (i * entries_per_row) as i32) + .collect(); + + let entries = StructArray::try_new( + Fields::from(vec![ + Field::new("keys", DataType::Int32, false), + Field::new("values", DataType::Int64, true), + ]), + vec![Arc::new(keys), Arc::new(values)], + None, + ) + .unwrap(); + + Arc::new(MapArray::new( + Arc::new(Field::new( + "entries", + DataType::Struct(Fields::from(vec![ + Field::new("keys", DataType::Int32, false), + Field::new("values", DataType::Int64, true), + ])), + false, + )), + OffsetBuffer::new(ScalarBuffer::from(offsets)), + entries, + None, + false, + )) +} + +fn sparse_union_array(num_rows: usize) -> ArrayRef { + let mut rng = make_rng(); + let num_types = 5; + + let type_ids: Vec = (0..num_rows) + .map(|_| rng.random_range(0..num_types) as i8) + .collect(); + let (fields, children): (Vec<_>, Vec<_>) = (0..num_types) + .map(|i| { + ( + ( + i as i8, + Arc::new(Field::new(format!("f{i}"), DataType::Int64, true)), + ), + primitive_array::(num_rows), + ) + }) + .unzip(); + + Arc::new( + UnionArray::try_new( + UnionFields::from_iter(fields), + ScalarBuffer::from(type_ids), + None, + children, + ) + .unwrap(), + ) +} + +fn dense_union_array(num_rows: usize) -> ArrayRef { + let mut rng = make_rng(); + let num_types = 5; + let type_ids: Vec = 
(0..num_rows) + .map(|_| rng.random_range(0..num_types) as i8) + .collect(); + + let mut type_counts = vec![0i32; num_types]; + for &tid in &type_ids { + type_counts[tid as usize] += 1; + } + + let mut current_offsets = vec![0i32; num_types]; + let offsets: Vec = type_ids + .iter() + .map(|&tid| { + let offset = current_offsets[tid as usize]; + current_offsets[tid as usize] += 1; + offset + }) + .collect(); + + let (fields, children): (Vec<_>, Vec<_>) = (0..num_types) + .map(|i| { + ( + ( + i as i8, + Arc::new(Field::new(format!("f{i}"), DataType::Int64, true)), + ), + primitive_array::(type_counts[i] as usize), + ) + }) + .unzip(); + + Arc::new( + UnionArray::try_new( + UnionFields::from_iter(fields), + ScalarBuffer::from(type_ids), + Some(ScalarBuffer::from(offsets)), + children, + ) + .unwrap(), + ) +} + +fn boolean_array(array_len: usize) -> ArrayRef { + let mut rng = make_rng(); + Arc::new( + (0..array_len) + .map(|_| Some(rng.random::())) + .collect::(), + ) +} + +/// Create a StructArray with multiple columns +fn create_struct_array(pool: &StringPool, array_len: usize) -> ArrayRef { + let bool_array = boolean_array(array_len); + let int32_array = primitive_array::(array_len); + let int64_array = primitive_array::(array_len); + let str_array = pool.string_array::(array_len); + + let fields = Fields::from(vec![ + Field::new("bool_col", DataType::Boolean, false), + Field::new("int32_col", DataType::Int32, false), + Field::new("int64_col", DataType::Int64, false), + Field::new("string_col", DataType::Utf8, false), + ]); + + Arc::new(StructArray::new( + fields, + vec![bool_array, int32_array, int64_array, str_array], + None, + )) +} + +/// Create a RunArray to test run array hashing. 
+fn create_run_array(array_len: usize) -> ArrayRef +where + T: ArrowPrimitiveType, + StandardUniform: Distribution, +{ + let mut rng = make_rng(); + + // Create runs of varying lengths + let mut run_ends = Vec::new(); + let mut values = Vec::new(); + let mut current_end = 0; + + while current_end < array_len { + // Random run length between 1 and 50 + let run_length = rng.random_range(1..=50).min(array_len - current_end); + current_end += run_length; + run_ends.push(current_end as i32); + values.push(Some(rng.random::())); + } + + let run_ends_array = Arc::new(PrimitiveArray::::from(run_ends)); + let values_array: Arc = + Arc::new(values.into_iter().collect::>()); + + Arc::new( + RunArray::try_new(&run_ends_array, values_array.as_ref()) + .expect("Failed to create RunArray"), + ) +} + +criterion_group!(benches, criterion_benchmark, sliced_array_benchmark); +criterion_main!(benches); diff --git a/datafusion/common/src/alias.rs b/datafusion/common/src/alias.rs index 2ee2cb4dc7add..99f6447a6acd8 100644 --- a/datafusion/common/src/alias.rs +++ b/datafusion/common/src/alias.rs @@ -37,6 +37,16 @@ impl AliasGenerator { Self::default() } + /// Advance the counter to at least `min_id`, ensuring future aliases + /// won't collide with already-existing ones. + /// + /// For example, if the query already contains an alias `alias_42`, then calling + /// `update_min_id(42)` will ensure that future aliases generated by this + /// [`AliasGenerator`] will start from `alias_43`. + pub fn update_min_id(&self, min_id: usize) { + self.next_id.fetch_max(min_id + 1, Ordering::Relaxed); + } + /// Return a unique alias with the provided prefix pub fn next(&self, prefix: &str) -> String { let id = self.next_id.fetch_add(1, Ordering::Relaxed); diff --git a/datafusion/common/src/cast.rs b/datafusion/common/src/cast.rs index 28202c6684b50..bc4313ed95665 100644 --- a/datafusion/common/src/cast.rs +++ b/datafusion/common/src/cast.rs @@ -20,11 +20,14 @@ //! 
but provide an error message rather than a panic, as the corresponding //! kernels in arrow-rs such as `as_boolean_array` do. -use crate::{downcast_value, Result}; +use crate::{Result, downcast_value}; use arrow::array::{ - BinaryViewArray, Float16Array, Int16Array, Int8Array, LargeBinaryArray, - LargeStringArray, StringViewArray, UInt16Array, + BinaryViewArray, Decimal32Array, Decimal64Array, DurationMicrosecondArray, + DurationMillisecondArray, DurationNanosecondArray, DurationSecondArray, Float16Array, + Int8Array, Int16Array, LargeBinaryArray, LargeListViewArray, LargeStringArray, + ListViewArray, RunArray, StringViewArray, UInt16Array, }; +use arrow::datatypes::RunEndIndexType; use arrow::{ array::{ Array, BinaryArray, BooleanArray, Date32Array, Date64Array, Decimal128Array, @@ -35,254 +38,305 @@ use arrow::{ MapArray, NullArray, OffsetSizeTrait, PrimitiveArray, StringArray, StructArray, Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, - TimestampNanosecondArray, TimestampSecondArray, UInt32Array, UInt64Array, - UInt8Array, UnionArray, + TimestampNanosecondArray, TimestampSecondArray, UInt8Array, UInt32Array, + UInt64Array, UnionArray, }, datatypes::{ArrowDictionaryKeyType, ArrowPrimitiveType}, }; -// Downcast ArrayRef to Date32Array +// Downcast Array to Date32Array pub fn as_date32_array(array: &dyn Array) -> Result<&Date32Array> { Ok(downcast_value!(array, Date32Array)) } -// Downcast ArrayRef to Date64Array +// Downcast Array to Date64Array pub fn as_date64_array(array: &dyn Array) -> Result<&Date64Array> { Ok(downcast_value!(array, Date64Array)) } -// Downcast ArrayRef to StructArray +// Downcast Array to StructArray pub fn as_struct_array(array: &dyn Array) -> Result<&StructArray> { Ok(downcast_value!(array, StructArray)) } -// Downcast ArrayRef to Int8Array +// Downcast Array to Int8Array pub fn as_int8_array(array: &dyn Array) -> Result<&Int8Array> { 
Ok(downcast_value!(array, Int8Array)) } -// Downcast ArrayRef to UInt8Array +// Downcast Array to UInt8Array pub fn as_uint8_array(array: &dyn Array) -> Result<&UInt8Array> { Ok(downcast_value!(array, UInt8Array)) } -// Downcast ArrayRef to Int16Array +// Downcast Array to Int16Array pub fn as_int16_array(array: &dyn Array) -> Result<&Int16Array> { Ok(downcast_value!(array, Int16Array)) } -// Downcast ArrayRef to UInt16Array +// Downcast Array to UInt16Array pub fn as_uint16_array(array: &dyn Array) -> Result<&UInt16Array> { Ok(downcast_value!(array, UInt16Array)) } -// Downcast ArrayRef to Int32Array +// Downcast Array to Int32Array pub fn as_int32_array(array: &dyn Array) -> Result<&Int32Array> { Ok(downcast_value!(array, Int32Array)) } -// Downcast ArrayRef to UInt32Array +// Downcast Array to UInt32Array pub fn as_uint32_array(array: &dyn Array) -> Result<&UInt32Array> { Ok(downcast_value!(array, UInt32Array)) } -// Downcast ArrayRef to Int64Array +// Downcast Array to Int64Array pub fn as_int64_array(array: &dyn Array) -> Result<&Int64Array> { Ok(downcast_value!(array, Int64Array)) } -// Downcast ArrayRef to UInt64Array +// Downcast Array to UInt64Array pub fn as_uint64_array(array: &dyn Array) -> Result<&UInt64Array> { Ok(downcast_value!(array, UInt64Array)) } -// Downcast ArrayRef to Decimal128Array +// Downcast Array to Decimal32Array +pub fn as_decimal32_array(array: &dyn Array) -> Result<&Decimal32Array> { + Ok(downcast_value!(array, Decimal32Array)) +} + +// Downcast Array to Decimal64Array +pub fn as_decimal64_array(array: &dyn Array) -> Result<&Decimal64Array> { + Ok(downcast_value!(array, Decimal64Array)) +} + +// Downcast Array to Decimal128Array pub fn as_decimal128_array(array: &dyn Array) -> Result<&Decimal128Array> { Ok(downcast_value!(array, Decimal128Array)) } -// Downcast ArrayRef to Decimal256Array +// Downcast Array to Decimal256Array pub fn as_decimal256_array(array: &dyn Array) -> Result<&Decimal256Array> { Ok(downcast_value!(array, 
Decimal256Array)) } -// Downcast ArrayRef to Float16Array +// Downcast Array to Float16Array pub fn as_float16_array(array: &dyn Array) -> Result<&Float16Array> { Ok(downcast_value!(array, Float16Array)) } -// Downcast ArrayRef to Float32Array +// Downcast Array to Float32Array pub fn as_float32_array(array: &dyn Array) -> Result<&Float32Array> { Ok(downcast_value!(array, Float32Array)) } -// Downcast ArrayRef to Float64Array +// Downcast Array to Float64Array pub fn as_float64_array(array: &dyn Array) -> Result<&Float64Array> { Ok(downcast_value!(array, Float64Array)) } -// Downcast ArrayRef to StringArray +// Downcast Array to StringArray pub fn as_string_array(array: &dyn Array) -> Result<&StringArray> { Ok(downcast_value!(array, StringArray)) } -// Downcast ArrayRef to StringViewArray +// Downcast Array to StringViewArray pub fn as_string_view_array(array: &dyn Array) -> Result<&StringViewArray> { Ok(downcast_value!(array, StringViewArray)) } -// Downcast ArrayRef to LargeStringArray +// Downcast Array to LargeStringArray pub fn as_large_string_array(array: &dyn Array) -> Result<&LargeStringArray> { Ok(downcast_value!(array, LargeStringArray)) } -// Downcast ArrayRef to BooleanArray +// Downcast Array to BooleanArray pub fn as_boolean_array(array: &dyn Array) -> Result<&BooleanArray> { Ok(downcast_value!(array, BooleanArray)) } -// Downcast ArrayRef to ListArray +// Downcast Array to ListArray pub fn as_list_array(array: &dyn Array) -> Result<&ListArray> { Ok(downcast_value!(array, ListArray)) } -// Downcast ArrayRef to DictionaryArray +// Downcast Array to DictionaryArray pub fn as_dictionary_array( array: &dyn Array, ) -> Result<&DictionaryArray> { Ok(downcast_value!(array, DictionaryArray, T)) } -// Downcast ArrayRef to GenericBinaryArray +// Downcast Array to GenericBinaryArray pub fn as_generic_binary_array( array: &dyn Array, ) -> Result<&GenericBinaryArray> { Ok(downcast_value!(array, GenericBinaryArray, T)) } -// Downcast ArrayRef to GenericListArray 
+// Downcast Array to GenericListArray pub fn as_generic_list_array( array: &dyn Array, ) -> Result<&GenericListArray> { Ok(downcast_value!(array, GenericListArray, T)) } -// Downcast ArrayRef to LargeListArray +// Downcast Array to LargeListArray pub fn as_large_list_array(array: &dyn Array) -> Result<&LargeListArray> { Ok(downcast_value!(array, LargeListArray)) } -// Downcast ArrayRef to PrimitiveArray +// Downcast Array to PrimitiveArray pub fn as_primitive_array( array: &dyn Array, ) -> Result<&PrimitiveArray> { Ok(downcast_value!(array, PrimitiveArray, T)) } -// Downcast ArrayRef to MapArray +// Downcast Array to MapArray pub fn as_map_array(array: &dyn Array) -> Result<&MapArray> { Ok(downcast_value!(array, MapArray)) } -// Downcast ArrayRef to NullArray +// Downcast Array to NullArray pub fn as_null_array(array: &dyn Array) -> Result<&NullArray> { Ok(downcast_value!(array, NullArray)) } -// Downcast ArrayRef to NullArray +// Downcast Array to NullArray pub fn as_union_array(array: &dyn Array) -> Result<&UnionArray> { Ok(downcast_value!(array, UnionArray)) } -// Downcast ArrayRef to Time32SecondArray +// Downcast Array to Time32SecondArray pub fn as_time32_second_array(array: &dyn Array) -> Result<&Time32SecondArray> { Ok(downcast_value!(array, Time32SecondArray)) } -// Downcast ArrayRef to Time32MillisecondArray +// Downcast Array to Time32MillisecondArray pub fn as_time32_millisecond_array(array: &dyn Array) -> Result<&Time32MillisecondArray> { Ok(downcast_value!(array, Time32MillisecondArray)) } -// Downcast ArrayRef to Time64MicrosecondArray +// Downcast Array to Time64MicrosecondArray pub fn as_time64_microsecond_array(array: &dyn Array) -> Result<&Time64MicrosecondArray> { Ok(downcast_value!(array, Time64MicrosecondArray)) } -// Downcast ArrayRef to Time64NanosecondArray +// Downcast Array to Time64NanosecondArray pub fn as_time64_nanosecond_array(array: &dyn Array) -> Result<&Time64NanosecondArray> { Ok(downcast_value!(array, Time64NanosecondArray)) } 
-// Downcast ArrayRef to TimestampNanosecondArray +// Downcast Array to TimestampNanosecondArray pub fn as_timestamp_nanosecond_array( array: &dyn Array, ) -> Result<&TimestampNanosecondArray> { Ok(downcast_value!(array, TimestampNanosecondArray)) } -// Downcast ArrayRef to TimestampMillisecondArray +// Downcast Array to TimestampMillisecondArray pub fn as_timestamp_millisecond_array( array: &dyn Array, ) -> Result<&TimestampMillisecondArray> { Ok(downcast_value!(array, TimestampMillisecondArray)) } -// Downcast ArrayRef to TimestampMicrosecondArray +// Downcast Array to TimestampMicrosecondArray pub fn as_timestamp_microsecond_array( array: &dyn Array, ) -> Result<&TimestampMicrosecondArray> { Ok(downcast_value!(array, TimestampMicrosecondArray)) } -// Downcast ArrayRef to TimestampSecondArray +// Downcast Array to TimestampSecondArray pub fn as_timestamp_second_array(array: &dyn Array) -> Result<&TimestampSecondArray> { Ok(downcast_value!(array, TimestampSecondArray)) } -// Downcast ArrayRef to IntervalYearMonthArray +// Downcast Array to IntervalYearMonthArray pub fn as_interval_ym_array(array: &dyn Array) -> Result<&IntervalYearMonthArray> { Ok(downcast_value!(array, IntervalYearMonthArray)) } -// Downcast ArrayRef to IntervalDayTimeArray +// Downcast Array to IntervalDayTimeArray pub fn as_interval_dt_array(array: &dyn Array) -> Result<&IntervalDayTimeArray> { Ok(downcast_value!(array, IntervalDayTimeArray)) } -// Downcast ArrayRef to IntervalMonthDayNanoArray +// Downcast Array to IntervalMonthDayNanoArray pub fn as_interval_mdn_array(array: &dyn Array) -> Result<&IntervalMonthDayNanoArray> { Ok(downcast_value!(array, IntervalMonthDayNanoArray)) } -// Downcast ArrayRef to BinaryArray +// Downcast Array to DurationSecondArray +pub fn as_duration_second_array(array: &dyn Array) -> Result<&DurationSecondArray> { + Ok(downcast_value!(array, DurationSecondArray)) +} + +// Downcast Array to DurationMillisecondArray +pub fn as_duration_millisecond_array( + array: 
&dyn Array, +) -> Result<&DurationMillisecondArray> { + Ok(downcast_value!(array, DurationMillisecondArray)) +} + +// Downcast Array to DurationMicrosecondArray +pub fn as_duration_microsecond_array( + array: &dyn Array, +) -> Result<&DurationMicrosecondArray> { + Ok(downcast_value!(array, DurationMicrosecondArray)) +} + +// Downcast Array to DurationNanosecondArray +pub fn as_duration_nanosecond_array( + array: &dyn Array, +) -> Result<&DurationNanosecondArray> { + Ok(downcast_value!(array, DurationNanosecondArray)) +} + +// Downcast Array to BinaryArray pub fn as_binary_array(array: &dyn Array) -> Result<&BinaryArray> { Ok(downcast_value!(array, BinaryArray)) } -// Downcast ArrayRef to BinaryViewArray +// Downcast Array to BinaryViewArray pub fn as_binary_view_array(array: &dyn Array) -> Result<&BinaryViewArray> { Ok(downcast_value!(array, BinaryViewArray)) } -// Downcast ArrayRef to LargeBinaryArray +// Downcast Array to LargeBinaryArray pub fn as_large_binary_array(array: &dyn Array) -> Result<&LargeBinaryArray> { Ok(downcast_value!(array, LargeBinaryArray)) } -// Downcast ArrayRef to FixedSizeListArray +// Downcast Array to FixedSizeListArray pub fn as_fixed_size_list_array(array: &dyn Array) -> Result<&FixedSizeListArray> { Ok(downcast_value!(array, FixedSizeListArray)) } -// Downcast ArrayRef to FixedSizeListArray +// Downcast Array to FixedSizeBinaryArray pub fn as_fixed_size_binary_array(array: &dyn Array) -> Result<&FixedSizeBinaryArray> { Ok(downcast_value!(array, FixedSizeBinaryArray)) } -// Downcast ArrayRef to GenericBinaryArray +// Downcast Array to GenericStringArray pub fn as_generic_string_array( array: &dyn Array, ) -> Result<&GenericStringArray> { Ok(downcast_value!(array, GenericStringArray, T)) } + +// Downcast Array to ListViewArray +pub fn as_list_view_array(array: &dyn Array) -> Result<&ListViewArray> { + Ok(downcast_value!(array, ListViewArray)) +} + +// Downcast Array to LargeListViewArray +pub fn as_large_list_view_array(array: &dyn
Array) -> Result<&LargeListViewArray> { + Ok(downcast_value!(array, LargeListViewArray)) +} + +// Downcast Array to RunArray +pub fn as_run_array(array: &dyn Array) -> Result<&RunArray> { + Ok(downcast_value!(array, RunArray, T)) +} diff --git a/datafusion/common/src/column.rs b/datafusion/common/src/column.rs index b3acaeee5a54c..c7f0b5a4f4881 100644 --- a/datafusion/common/src/column.rs +++ b/datafusion/common/src/column.rs @@ -18,13 +18,12 @@ //! Column use crate::error::{_schema_err, add_possible_columns_to_diag}; -use crate::utils::{parse_identifiers_normalized, quote_identifier}; +use crate::utils::parse_identifiers_normalized; +use crate::utils::quote_identifier; use crate::{DFSchema, Diagnostic, Result, SchemaError, Spans, TableReference}; use arrow::datatypes::{Field, FieldRef}; use std::collections::HashSet; -use std::convert::Infallible; use std::fmt; -use std::str::FromStr; /// A named reference to a qualified field in a schema. #[derive(Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] @@ -140,6 +139,7 @@ impl Column { } /// Deserialize a fully qualified name string into a column preserving column text case + #[cfg(feature = "sql")] pub fn from_qualified_name_ignore_case(flat_name: impl Into) -> Self { let flat_name = flat_name.into(); Self::from_idents(parse_identifiers_normalized(&flat_name, true)).unwrap_or_else( @@ -151,6 +151,11 @@ impl Column { ) } + #[cfg(not(feature = "sql"))] + pub fn from_qualified_name_ignore_case(flat_name: impl Into) -> Self { + Self::from_qualified_name(flat_name) + } + /// return the column's name. /// /// Note: This ignores the relation and returns the column name only. 
@@ -262,7 +267,7 @@ impl Column { // If not due to USING columns then due to ambiguous column name return _schema_err!(SchemaError::AmbiguousReference { - field: Column::new_unqualified(&self.name), + field: Box::new(Column::new_unqualified(&self.name)), }) .map_err(|err| { let mut diagnostic = Diagnostic::new_error( @@ -356,8 +361,9 @@ impl From<(Option<&TableReference>, &FieldRef)> for Column { } } -impl FromStr for Column { - type Err = Infallible; +#[cfg(feature = "sql")] +impl std::str::FromStr for Column { + type Err = std::convert::Infallible; fn from_str(s: &str) -> Result { Ok(s.into()) diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index 726015d171496..9b6e6aa5dac37 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -17,15 +17,25 @@ //! Runtime configuration, via [`ConfigOptions`] +use arrow_ipc::CompressionType; + +#[cfg(feature = "parquet_encryption")] +use crate::encryption::{FileDecryptionProperties, FileEncryptionProperties}; use crate::error::_config_err; +use crate::format::{ExplainAnalyzeLevel, ExplainFormat}; +use crate::parquet_config::DFParquetWriterVersion; use crate::parsers::CompressionTypeVariant; use crate::utils::get_available_parallelism; use crate::{DataFusionError, Result}; +#[cfg(feature = "parquet_encryption")] +use hex; use std::any::Any; use std::collections::{BTreeMap, HashMap}; use std::error::Error; use std::fmt::{self, Display}; use std::str::FromStr; +#[cfg(feature = "parquet_encryption")] +use std::sync::Arc; /// A macro that wraps a configuration struct and automatically derives /// [`Default`] and [`ConfigField`] for it, allowing it to be used @@ -48,7 +58,7 @@ use std::str::FromStr; /// /// Field 3 doc /// field3: Option, default = None /// } -///} +/// } /// ``` /// /// Will generate @@ -148,12 +158,10 @@ macro_rules! config_namespace { // $(#[allow(deprecated)])? { $(let value = $transform(value);)? 
// Apply transformation if specified - #[allow(deprecated)] let ret = self.$field_name.set(rem, value.as_ref()); $(if !$warn.is_empty() { let default: $field_type = $default; - #[allow(deprecated)] if default != self.$field_name { log::warn!($warn); } @@ -172,14 +180,36 @@ macro_rules! config_namespace { $( let key = format!(concat!("{}.", stringify!($field_name)), key_prefix); let desc = concat!($($d),*).trim(); - #[allow(deprecated)] self.$field_name.visit(v, key.as_str(), desc); )* } + + fn reset(&mut self, key: &str) -> $crate::error::Result<()> { + let (key, rem) = key.split_once('.').unwrap_or((key, "")); + match key { + $( + stringify!($field_name) => { + { + if rem.is_empty() { + let default_value: $field_type = $default; + self.$field_name = default_value; + Ok(()) + } else { + self.$field_name.reset(rem) + } + } + }, + )* + _ => $crate::error::_config_err!( + "Config value \"{}\" not found on {}", + key, + stringify!($struct_name) + ), + } + } } impl Default for $struct_name { fn default() -> Self { - #[allow(deprecated)] Self { $($field_name: $default),* } @@ -250,7 +280,7 @@ config_namespace! { /// Configure the SQL dialect used by DataFusion's parser; supported values include: Generic, /// MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, MsSQL, ClickHouse, BigQuery, Ansi, DuckDB and Databricks. - pub dialect: String, default = "generic".to_string() + pub dialect: Dialect, default = Dialect::Generic // no need to lowercase because `sqlparser::dialect_from_str`] is case-insensitive /// If true, permit lengths for `VARCHAR` such as `VARCHAR(20)`, but @@ -259,10 +289,10 @@ config_namespace! { /// string length and thus DataFusion can not enforce such limits. pub support_varchar_with_length: bool, default = true - /// If true, `VARCHAR` is mapped to `Utf8View` during SQL planning. - /// If false, `VARCHAR` is mapped to `Utf8` during SQL planning. - /// Default is false. 
- pub map_varchar_to_utf8view: bool, default = true + /// If true, string types (VARCHAR, CHAR, Text, and String) are mapped to `Utf8View` during SQL planning. + /// If false, they are mapped to `Utf8`. + /// Default is true. + pub map_string_types_to_utf8view: bool, default = true /// When set to true, the source locations relative to the original SQL /// query (i.e. [`Span`](https://docs.rs/sqlparser/latest/sqlparser/tokenizer/struct.Span.html)) will be collected @@ -271,6 +301,159 @@ config_namespace! { /// Specifies the recursion depth limit when parsing complex SQL Queries pub recursion_limit: usize, default = 50 + + /// Specifies the default null ordering for query results. There are 4 options: + /// - `nulls_max`: Nulls appear last in ascending order. + /// - `nulls_min`: Nulls appear first in ascending order. + /// - `nulls_first`: Nulls are always first, regardless of sort direction. + /// - `nulls_last`: Nulls are always last, regardless of sort direction. + /// + /// By default, `nulls_max` is used to follow Postgres's behavior. + /// postgres rule: <https://www.postgresql.org/docs/current/queries-order.html> + pub default_null_ordering: String, default = "nulls_max".to_string() + } +} + +/// This is the SQL dialect used by DataFusion's parser.
+/// This mirrors [sqlparser::dialect::Dialect](https://docs.rs/sqlparser/latest/sqlparser/dialect/trait.Dialect.html) +/// trait in order to offer an easier API and avoid adding the `sqlparser` dependency +#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)] +pub enum Dialect { + #[default] + Generic, + MySQL, + PostgreSQL, + Hive, + SQLite, + Snowflake, + Redshift, + MsSQL, + ClickHouse, + BigQuery, + Ansi, + DuckDB, + Databricks, +} + +impl AsRef for Dialect { + fn as_ref(&self) -> &str { + match self { + Self::Generic => "generic", + Self::MySQL => "mysql", + Self::PostgreSQL => "postgresql", + Self::Hive => "hive", + Self::SQLite => "sqlite", + Self::Snowflake => "snowflake", + Self::Redshift => "redshift", + Self::MsSQL => "mssql", + Self::ClickHouse => "clickhouse", + Self::BigQuery => "bigquery", + Self::Ansi => "ansi", + Self::DuckDB => "duckdb", + Self::Databricks => "databricks", + } + } +} + +impl FromStr for Dialect { + type Err = DataFusionError; + + fn from_str(s: &str) -> Result { + let value = match s.to_ascii_lowercase().as_str() { + "generic" => Self::Generic, + "mysql" => Self::MySQL, + "postgresql" | "postgres" => Self::PostgreSQL, + "hive" => Self::Hive, + "sqlite" => Self::SQLite, + "snowflake" => Self::Snowflake, + "redshift" => Self::Redshift, + "mssql" => Self::MsSQL, + "clickhouse" => Self::ClickHouse, + "bigquery" => Self::BigQuery, + "ansi" => Self::Ansi, + "duckdb" => Self::DuckDB, + "databricks" => Self::Databricks, + other => { + let error_message = format!( + "Invalid Dialect: {other}. 
Expected one of: Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, MsSQL, ClickHouse, BigQuery, Ansi, DuckDB, Databricks" + ); + return Err(DataFusionError::Configuration(error_message)); + } + }; + Ok(value) + } +} + +impl ConfigField for Dialect { + fn visit(&self, v: &mut V, key: &str, description: &'static str) { + v.some(key, self, description) + } + + fn set(&mut self, _: &str, value: &str) -> Result<()> { + *self = Self::from_str(value)?; + Ok(()) + } +} + +impl Display for Dialect { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let str = self.as_ref(); + write!(f, "{str}") + } +} + +#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)] +pub enum SpillCompression { + Zstd, + Lz4Frame, + #[default] + Uncompressed, +} + +impl FromStr for SpillCompression { + type Err = DataFusionError; + + fn from_str(s: &str) -> Result { + match s.to_ascii_lowercase().as_str() { + "zstd" => Ok(Self::Zstd), + "lz4_frame" => Ok(Self::Lz4Frame), + "uncompressed" | "" => Ok(Self::Uncompressed), + other => Err(DataFusionError::Configuration(format!( + "Invalid Spill file compression type: {other}. 
Expected one of: zstd, lz4_frame, uncompressed" + ))), + } + } +} + +impl ConfigField for SpillCompression { + fn visit(&self, v: &mut V, key: &str, description: &'static str) { + v.some(key, self, description) + } + + fn set(&mut self, _: &str, value: &str) -> Result<()> { + *self = SpillCompression::from_str(value)?; + Ok(()) + } +} + +impl Display for SpillCompression { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let str = match self { + Self::Zstd => "zstd", + Self::Lz4Frame => "lz4_frame", + Self::Uncompressed => "uncompressed", + }; + write!(f, "{str}") + } +} + +impl From for Option { + fn from(c: SpillCompression) -> Self { + match c { + SpillCompression::Zstd => Some(CompressionType::ZSTD), + SpillCompression::Lz4Frame => Some(CompressionType::LZ4_FRAME), + SpillCompression::Uncompressed => None, + } } } @@ -286,6 +469,25 @@ config_namespace! { /// metadata memory consumption pub batch_size: usize, default = 8192 + /// A perfect hash join (see `HashJoinExec` for more details) will be considered + /// if the range of keys (max - min) on the build side is < this threshold. + /// This provides a fast path for joins with very small key ranges, + /// bypassing the density check. + /// + /// Currently only supports cases where build_side.num_rows() < u32::MAX. + /// Support for build_side.num_rows() >= u32::MAX will be added in the future. + pub perfect_hash_join_small_build_threshold: usize, default = 1024 + + /// The minimum required density of join keys on the build side to consider a + /// perfect hash join (see `HashJoinExec` for more details). Density is calculated as: + /// `(number of rows) / (max_key - min_key + 1)`. + /// A perfect hash join may be used if the actual key density > this + /// value. + /// + /// Currently only supports cases where build_side.num_rows() < u32::MAX. + /// Support for build_side.num_rows() >= u32::MAX will be added in the future. 
+ pub perfect_hash_join_min_key_density: f64, default = 0.15 + /// When set to true, record batches will be examined between each operator and /// small batches will be coalesced into larger batches. This is helpful when there /// are highly selective filters or joins that could produce tiny output batches. The @@ -294,8 +496,8 @@ config_namespace! { /// Should DataFusion collect statistics when first creating a table. /// Has no effect after the table is created. Applies to the default - /// `ListingTableProvider` in DataFusion. Defaults to false. - pub collect_statistics: bool, default = false + /// `ListingTableProvider` in DataFusion. Defaults to true. + pub collect_statistics: bool, default = true /// Number of partitions for query execution. Increasing partitions can increase /// concurrency. @@ -305,9 +507,8 @@ config_namespace! { /// The default time zone /// - /// Some functions, e.g. `EXTRACT(HOUR from SOME_TIME)`, shift the underlying datetime - /// according to this time zone, and then extract the hour - pub time_zone: Option, default = Some("+00:00".into()) + /// Some functions, e.g. `now` return timestamps in this time zone + pub time_zone: Option, default = None /// Parquet options pub parquet: ParquetOptions, default = Default::default() @@ -330,6 +531,16 @@ config_namespace! { /// the new schema verification step. pub skip_physical_aggregate_schema_check: bool, default = false + /// Sets the compression codec used when spilling data to disk. + /// + /// Since datafusion writes spill files using the Arrow IPC Stream format, + /// only codecs supported by the Arrow IPC Stream Writer are allowed. + /// Valid values are: uncompressed, lz4_frame, zstd. + /// Note: lz4_frame offers faster (de)compression, but typically results in + /// larger spill files. In contrast, zstd achieves + /// higher compression ratios at the cost of slower (de)compression speed. 
+ pub spill_compression: SpillCompression, default = SpillCompression::Uncompressed + /// Specifies the reserved memory for each spillable sort operation to /// facilitate an in-memory merge. /// @@ -346,6 +557,23 @@ config_namespace! { /// batches and merged. pub sort_in_place_threshold_bytes: usize, default = 1024 * 1024 + /// Maximum size in bytes for individual spill files before rotating to a new file. + /// + /// When operators spill data to disk (e.g., RepartitionExec), they write + /// multiple batches to the same file until this size limit is reached, then rotate + /// to a new file. This reduces syscall overhead compared to one-file-per-batch + /// while preventing files from growing too large. + /// + /// A larger value reduces file creation overhead but may hold more disk space. + /// A smaller value creates more files but allows finer-grained space reclamation + /// as files can be deleted once fully consumed. + /// + /// Now only `RepartitionExec` supports this spill file rotation feature, other spilling operators + /// may create spill files larger than the limit. + /// + /// Default: 128 MB + pub max_spill_file_size_bytes: usize, default = 128 * 1024 * 1024 + /// Number of files to read in parallel when inferring schema and statistics pub meta_fetch_concurrency: usize, default = 32 @@ -373,6 +601,11 @@ config_namespace! { /// tables (e.g. `/table/year=2021/month=01/data.parquet`). pub listing_table_ignore_subdirectory: bool, default = true + /// Should a `ListingTable` created through the `ListingTableFactory` infer table + /// partitions from Hive compliant directories. Defaults to true (partition columns are + /// inferred and will be represented in the table schema). + pub listing_table_factory_infer_partitions: bool, default = true + /// Should DataFusion support recursive CTEs pub enable_recursive_ctes: bool, default = true @@ -413,6 +646,44 @@ config_namespace! 
{ /// written, it may be necessary to increase this size to avoid errors from /// the remote end point. pub objectstore_writer_buffer_size: usize, default = 10 * 1024 * 1024 + + /// Whether to enable ANSI SQL mode. + /// + /// The flag is experimental and relevant only for DataFusion Spark built-in functions + /// + /// When `enable_ansi_mode` is set to `true`, the query engine follows ANSI SQL + /// semantics for expressions, casting, and error handling. This means: + /// - **Strict type coercion rules:** implicit casts between incompatible types are disallowed. + /// - **Standard SQL arithmetic behavior:** operations such as division by zero, + /// numeric overflow, or invalid casts raise runtime errors rather than returning + /// `NULL` or adjusted values. + /// - **Consistent ANSI behavior** for string concatenation, comparisons, and `NULL` handling. + /// + /// When `enable_ansi_mode` is `false` (the default), the engine uses a more permissive, + /// non-ANSI mode designed for user convenience and backward compatibility. In this mode: + /// - Implicit casts between types are allowed (e.g., string to integer when possible). + /// - Arithmetic operations are more lenient — for example, `abs()` on the minimum + /// representable integer value returns the input value instead of raising overflow. + /// - Division by zero or invalid casts may return `NULL` instead of failing. + /// + /// # Default + /// `false` — ANSI SQL mode is disabled by default. + pub enable_ansi_mode: bool, default = false + + /// How many bytes to buffer in the probe side of hash joins while the build side is + /// concurrently being built. + /// + /// Without this, hash joins will wait until the full materialization of the build side + /// before polling the probe side. This is useful in scenarios where the query is not + /// completely CPU bounded, allowing to do some early work concurrently and reducing the + /// latency of the query. 
+ /// + /// Note that when hash join buffering is enabled, the probe side will start eagerly + /// polling data, not giving time for the producer side of dynamic filters to produce any + /// meaningful predicate. Queries with dynamic filters might see performance degradation. + /// + /// Disabled by default, set to a number greater than 0 for enabling it. + pub hash_join_buffering_capacity: usize, default = 0 } } @@ -444,7 +715,10 @@ config_namespace! { /// bytes of the parquet file optimistically. If not specified, two reads are required: /// One read to fetch the 8-byte parquet footer and /// another to fetch the metadata length encoded in the footer - pub metadata_size_hint: Option, default = None + /// Default setting to 512 KiB, which should be sufficient for most parquet files, + /// it can reduce one I/O operation per parquet file. If the metadata is larger than + /// the hint, two reads will still be performed. + pub metadata_size_hint: Option, default = Some(512 * 1024) /// (reading) If true, filter expressions are be applied during the parquet decoding operation to /// reduce the number of rows decoded. This optimization is sometimes called "late materialization". @@ -455,6 +729,12 @@ config_namespace! { /// the filters are applied in the same order as written in the query pub reorder_filters: bool, default = false + /// (reading) Force the use of RowSelections for filter results, when + /// pushdown_filters is enabled. If false, the reader will automatically + /// choose between a RowSelection and a Bitmap based on the number and + /// pattern of selected rows. + pub force_filter_selections: bool, default = false + /// (reading) If true, parquet reader will read columns of `Utf8/Utf8Large` with `Utf8View`, /// and `Binary/BinaryLarge` with `BinaryView`. pub schema_force_view_types: bool, default = true @@ -478,18 +758,26 @@ config_namespace! 
{ /// (reading) Use any available bloom filters when reading parquet files pub bloom_filter_on_read: bool, default = true + /// (reading) The maximum predicate cache size, in bytes. When + /// `pushdown_filters` is enabled, sets the maximum memory used to cache + /// the results of predicate evaluation between filter evaluation and + /// output generation. Decreasing this value will reduce memory usage, + /// but may increase IO and CPU usage. None means use the default + /// parquet reader setting. 0 means no caching. + pub max_predicate_cache_size: Option, default = None + // The following options affect writing to parquet files // and map to parquet::file::properties::WriterProperties /// (writing) Sets best effort maximum size of data page in bytes pub data_pagesize_limit: usize, default = 1024 * 1024 - /// (writing) Sets write_batch_size in bytes + /// (writing) Sets write_batch_size in rows pub write_batch_size: usize, default = 1024 /// (writing) Sets parquet writer version /// valid values are "1.0" and "2.0" - pub writer_version: String, default = "1.0".to_string() + pub writer_version: DFParquetWriterVersion, default = DFParquetWriterVersion::default() /// (writing) Skip encoding the embedded arrow metadata in the KV_meta /// @@ -499,7 +787,7 @@ config_namespace! { /// (writing) Sets default parquet compression codec. /// Valid values are: uncompressed, snappy, gzip(level), - /// lzo, brotli(level), lz4, zstd(level), and lz4_raw. + /// brotli(level), lz4, zstd(level), and lz4_raw. /// These values are not case sensitive. If NULL, uses /// default parquet writer setting /// @@ -520,13 +808,6 @@ config_namespace! { /// default parquet writer setting pub statistics_enabled: Option, transform = str::to_lowercase, default = Some("page".into()) - /// (writing) Sets max statistics size for any column. 
If NULL, uses - /// default parquet writer setting - /// max_statistics_size is deprecated, currently it is not being used - // TODO: remove once deprecated - #[deprecated(since = "45.0.0", note = "Setting does not do anything")] - pub max_statistics_size: Option, default = Some(4096) - /// (writing) Target maximum number of rows in each row group (defaults to 1M /// rows). Writing larger row groups requires more memory to write, but /// can get better compression and be faster to read. @@ -538,9 +819,9 @@ config_namespace! { /// (writing) Sets column index truncate length pub column_index_truncate_length: Option, default = Some(64) - /// (writing) Sets statictics truncate length. If NULL, uses + /// (writing) Sets statistics truncate length. If NULL, uses /// default parquet writer setting - pub statistics_truncate_length: Option, default = None + pub statistics_truncate_length: Option, default = Some(64) /// (writing) Sets best effort maximum number of rows in data page pub data_page_row_count_limit: usize, default = 20_000 @@ -594,6 +875,44 @@ config_namespace! { } } +config_namespace! { + /// Options for configuring Parquet Modular Encryption + /// + /// To use Parquet encryption, you must enable the `parquet_encryption` feature flag, as it is not activated by default. + pub struct ParquetEncryptionOptions { + /// Optional file decryption properties + pub file_decryption: Option, default = None + + /// Optional file encryption properties + pub file_encryption: Option, default = None + + /// Identifier for the encryption factory to use to create file encryption and decryption properties. + /// Encryption factories can be registered in the runtime environment with + /// `RuntimeEnv::register_parquet_encryption_factory`. 
+ pub factory_id: Option, default = None + + /// Any encryption factory specific options + pub factory_options: EncryptionFactoryOptions, default = EncryptionFactoryOptions::default() + } +} + +impl ParquetEncryptionOptions { + /// Specify the encryption factory to use for Parquet modular encryption, along with its configuration + pub fn configure_factory( + &mut self, + factory_id: &str, + config: &impl ExtensionOptions, + ) { + self.factory_id = Some(factory_id.to_owned()); + self.factory_options.options.clear(); + for entry in config.entries() { + if let Some(value) = entry.value { + self.factory_options.options.insert(entry.key, value); + } + } + } +} + config_namespace! { /// Options related to query optimization /// @@ -614,6 +933,36 @@ config_namespace! { /// during aggregations, if possible pub enable_topk_aggregation: bool, default = true + /// When set to true, the optimizer will attempt to push limit operations + /// past window functions, if possible + pub enable_window_limits: bool, default = true + + /// When set to true, the optimizer will push TopK (Sort with fetch) + /// below hash repartition when the partition key is a prefix of the + /// sort key, reducing data volume before the shuffle. + pub enable_topk_repartition: bool, default = true + + /// When set to true, the optimizer will attempt to push down TopK dynamic filters + /// into the file scan phase. + pub enable_topk_dynamic_filter_pushdown: bool, default = true + + /// When set to true, the optimizer will attempt to push down Join dynamic filters + /// into the file scan phase. + pub enable_join_dynamic_filter_pushdown: bool, default = true + + /// When set to true, the optimizer will attempt to push down Aggregate dynamic filters + /// into the file scan phase. + pub enable_aggregate_dynamic_filter_pushdown: bool, default = true + + /// When set to true attempts to push down dynamic filters generated by operators (TopK, Join & Aggregate) into the file scan phase. 
+ /// For example, for a query such as `SELECT * FROM t ORDER BY timestamp DESC LIMIT 10`, the optimizer + /// will attempt to push down the current top 10 timestamps that the TopK operator references into the file scans. + /// This means that if we already have 10 timestamps in the year 2025 + /// any files that only have timestamps in the year 2024 can be skipped / pruned at various stages in the scan. + /// The config will suppress `enable_join_dynamic_filter_pushdown`, `enable_topk_dynamic_filter_pushdown` & `enable_aggregate_dynamic_filter_pushdown` + /// So if you disable `enable_topk_dynamic_filter_pushdown`, then enable `enable_dynamic_filter_pushdown`, the `enable_topk_dynamic_filter_pushdown` will be overridden. + pub enable_dynamic_filter_pushdown: bool, default = true + /// When set to true, the optimizer will insert filters before a join between /// a nullable and non-nullable column to filter out nulls on the nullable side. This /// filter can add additional overhead when the file format does not fully support @@ -656,6 +1005,19 @@ config_namespace! { /// record tables provided to the MemTable on creation. pub repartition_file_scans: bool, default = true + /// Minimum number of distinct partition values required to group files by their + /// Hive partition column values (enabling Hash partitioning declaration). + /// + /// How the option is used: + /// - preserve_file_partitions=0: Disable it. + /// - preserve_file_partitions=1: Always enable it. + /// - preserve_file_partitions=N, actual file partitions=M: Only enable when M >= N. + /// This threshold preserves I/O parallelism when file partitioning is below it. + /// + /// Note: This may reduce parallelism, rooting from the I/O level, if the number of distinct + /// partitions is less than the target_partitions. 
+ pub preserve_file_partitions: usize, default = 0 + /// Should DataFusion repartition data using the partitions keys to execute window /// functions in parallel using the provided `target_partitions` level pub repartition_windows: bool, default = true @@ -678,6 +1040,34 @@ config_namespace! { /// ``` pub repartition_sorts: bool, default = true + /// Partition count threshold for subset satisfaction optimization. + /// + /// When the current partition count is >= this threshold, DataFusion will + /// skip repartitioning if the required partitioning expression is a subset + /// of the current partition expression such as Hash(a) satisfies Hash(a, b). + /// + /// When the current partition count is < this threshold, DataFusion will + /// repartition to increase parallelism even when subset satisfaction applies. + /// + /// Set to 0 to always repartition (disable subset satisfaction optimization). + /// Set to a high value to always use subset satisfaction. + /// + /// Example (subset_repartition_threshold = 4): + /// ```text + /// Hash([a]) satisfies Hash([a, b]) because (Hash([a, b]) is subset of Hash([a]) + /// + /// If current partitions (3) < threshold (4), repartition: + /// AggregateExec: mode=FinalPartitioned, gby=[a, b], aggr=[SUM(x)] + /// RepartitionExec: partitioning=Hash([a, b], 8), input_partitions=3 + /// AggregateExec: mode=Partial, gby=[a, b], aggr=[SUM(x)] + /// DataSourceExec: file_groups={...}, output_partitioning=Hash([a], 3) + /// + /// If current partitions (8) >= threshold (4), use subset satisfaction: + /// AggregateExec: mode=SinglePartitioned, gby=[a, b], aggr=[SUM(x)] + /// DataSourceExec: file_groups={...}, output_partitioning=Hash([a], 8) + /// ``` + pub subset_repartition_threshold: usize, default = 4 + /// When true, DataFusion will opportunistically remove sorts when the data is already sorted, /// (i.e. 
setting `preserve_order` to true on `RepartitionExec` and /// using `SortPreservingMergeExec`) @@ -702,6 +1092,11 @@ config_namespace! { /// HashJoin can work more efficiently than SortMergeJoin but consumes more memory pub prefer_hash_join: bool, default = true + /// When set to true, piecewise merge join is enabled. PiecewiseMergeJoin is currently + /// experimental. Physical planner will opt for PiecewiseMergeJoin when there is only + /// one range filter. + pub enable_piecewise_merge_join: bool, default = false + /// The maximum estimated size in bytes for one input side of a HashJoin /// will be collected into a single partition pub hash_join_single_partition_threshold: usize, default = 1024 * 1024 @@ -710,6 +1105,36 @@ config_namespace! { /// will be collected into a single partition pub hash_join_single_partition_threshold_rows: usize, default = 1024 * 128 + /// Maximum size in bytes for the build side of a hash join to be pushed down as an InList expression for dynamic filtering. + /// Build sides larger than this will use hash table lookups instead. + /// Set to 0 to always use hash table lookups. + /// + /// InList pushdown can be more efficient for small build sides because it can result in better + /// statistics pruning as well as use any bloom filters present on the scan side. + /// InList expressions are also more transparent and easier to serialize over the network in distributed uses of DataFusion. + /// On the other hand InList pushdown requires making a copy of the data and thus adds some overhead to the build side and uses more memory. + /// + /// This setting is per-partition, so we may end up using `hash_join_inlist_pushdown_max_size` * `target_partitions` memory. + /// + /// The default is 128kB per partition. + /// This should allow point lookup joins (e.g. joining on a unique primary key) to use InList pushdown in most cases + /// but avoids excessive memory usage or overhead for larger joins. 
+ pub hash_join_inlist_pushdown_max_size: usize, default = 128 * 1024 + + /// Maximum number of distinct values (rows) in the build side of a hash join to be pushed down as an InList expression for dynamic filtering. + /// Build sides with more rows than this will use hash table lookups instead. + /// Set to 0 to always use hash table lookups. + /// + /// This provides an additional limit beyond `hash_join_inlist_pushdown_max_size` to prevent + /// very large IN lists that might not provide much benefit over hash table lookups. + /// + /// This uses the deduplicated row count once the build side has been evaluated. + /// + /// The default is 150 values per partition. + /// This is inspired by Trino's `max-filter-keys-per-column` setting. + /// See: + pub hash_join_inlist_pushdown_max_distinct_values: usize, default = 150 + /// The default filter selectivity used by Filter Statistics /// when an exact selectivity cannot be determined. Valid values are /// between 0 (no selectivity) and 100 (all rows are selected). @@ -722,6 +1147,27 @@ config_namespace! { /// then the output will be coerced to a non-view. /// Coerces `Utf8View` to `LargeUtf8`, and `BinaryView` to `LargeBinary`. pub expand_views_at_output: bool, default = false + + /// Enable sort pushdown optimization. + /// When enabled, attempts to push sort requirements down to data sources + /// that can natively handle them (e.g., by reversing file/row group read order). + /// + /// Returns **inexact ordering**: Sort operator is kept for correctness, + /// but optimized input enables early termination for TopK queries (ORDER BY ... LIMIT N), + /// providing significant speedup. + /// + /// Memory: No additional overhead (only changes read order). + /// + /// Future: Will add option to detect perfectly sorted data and eliminate Sort completely. 
+ /// + /// Default: true + pub enable_sort_pushdown: bool, default = true + + /// When set to true, the optimizer will extract leaf expressions + /// (such as `get_field`) from filter/sort/join nodes into projections + /// closer to the leaf table scans, and push those projections down + /// towards the leaf nodes. + pub enable_leaf_expression_pushdown: bool, default = true } } @@ -750,7 +1196,16 @@ config_namespace! { /// Display format of explain. Default is "indent". /// When set to "tree", it will print the plan in a tree-rendered format. - pub format: String, default = "indent".to_string() + pub format: ExplainFormat, default = ExplainFormat::Indent + + /// (format=tree only) Maximum total width of the rendered tree. + /// When set to 0, the tree will have no width limit. + pub tree_maximum_render_width: usize, default = 240 + + /// Verbosity level for "EXPLAIN ANALYZE". Default is "dev" + /// "summary" shows common metrics for high-level insights. + /// "dev" provides deep operator-level introspection for developers. + pub analyze_level: ExplainAnalyzeLevel, default = ExplainAnalyzeLevel::Dev } } @@ -803,7 +1258,7 @@ impl<'a> TryInto> for &'a FormatOptions return _config_err!( "Invalid duration format: {}. 
Valid values are pretty or iso8601", self.duration_format - ) + ); } }; @@ -821,7 +1276,7 @@ impl<'a> TryInto> for &'a FormatOptions } /// A key value pair, with a corresponding description -#[derive(Debug)] +#[derive(Debug, Clone, Hash, PartialEq, Eq)] pub struct ConfigEntry { /// A unique string to identify this config value pub key: String, @@ -854,6 +1309,15 @@ pub struct ConfigOptions { } impl ConfigField for ConfigOptions { + fn visit(&self, v: &mut V, _key_prefix: &str, _description: &'static str) { + self.catalog.visit(v, "datafusion.catalog", ""); + self.execution.visit(v, "datafusion.execution", ""); + self.optimizer.visit(v, "datafusion.optimizer", ""); + self.explain.visit(v, "datafusion.explain", ""); + self.sql_parser.visit(v, "datafusion.sql_parser", ""); + self.format.visit(v, "datafusion.format", ""); + } + fn set(&mut self, key: &str, value: &str) -> Result<()> { // Extensions are handled in the public `ConfigOptions::set` let (key, rem) = key.split_once('.').unwrap_or((key, "")); @@ -868,16 +1332,50 @@ impl ConfigField for ConfigOptions { } } - fn visit(&self, v: &mut V, _key_prefix: &str, _description: &'static str) { - self.catalog.visit(v, "datafusion.catalog", ""); - self.execution.visit(v, "datafusion.execution", ""); - self.optimizer.visit(v, "datafusion.optimizer", ""); - self.explain.visit(v, "datafusion.explain", ""); - self.sql_parser.visit(v, "datafusion.sql_parser", ""); - self.format.visit(v, "datafusion.format", ""); + /// Reset a configuration option back to its default value + fn reset(&mut self, key: &str) -> Result<()> { + let Some((prefix, rest)) = key.split_once('.') else { + return _config_err!("could not find config namespace for key \"{key}\""); + }; + + if prefix != "datafusion" { + return _config_err!("Could not find config namespace \"{prefix}\""); + } + + let (section, rem) = rest.split_once('.').unwrap_or((rest, "")); + if rem.is_empty() { + return _config_err!("could not find config field for key \"{key}\""); + } + + 
match section { + "catalog" => self.catalog.reset(rem), + "execution" => self.execution.reset(rem), + "optimizer" => { + if rem == "enable_dynamic_filter_pushdown" { + let defaults = OptimizerOptions::default(); + self.optimizer.enable_dynamic_filter_pushdown = + defaults.enable_dynamic_filter_pushdown; + self.optimizer.enable_topk_dynamic_filter_pushdown = + defaults.enable_topk_dynamic_filter_pushdown; + self.optimizer.enable_join_dynamic_filter_pushdown = + defaults.enable_join_dynamic_filter_pushdown; + Ok(()) + } else { + self.optimizer.reset(rem) + } + } + "explain" => self.explain.reset(rem), + "sql_parser" => self.sql_parser.reset(rem), + "format" => self.format.reset(rem), + other => _config_err!("Config value \"{other}\" not found on ConfigOptions"), + } } } +/// This namespace is reserved for interacting with Foreign Function Interface +/// (FFI) based configuration extensions. +pub const DATAFUSION_FFI_CONFIG_NAMESPACE: &str = "datafusion_ffi"; + impl ConfigOptions { /// Creates a new [`ConfigOptions`] with default values pub fn new() -> Self { @@ -892,25 +1390,62 @@ impl ConfigOptions { /// Set a configuration option pub fn set(&mut self, key: &str, value: &str) -> Result<()> { - let Some((prefix, key)) = key.split_once('.') else { + let Some((mut prefix, mut inner_key)) = key.split_once('.') else { return _config_err!("could not find config namespace for key \"{key}\""); }; if prefix == "datafusion" { - return ConfigField::set(self, key, value); + if inner_key == "optimizer.enable_dynamic_filter_pushdown" { + let bool_value = value.parse::().map_err(|e| { + DataFusionError::Configuration(format!( + "Failed to parse '{value}' as bool: {e}", + )) + })?; + + { + self.optimizer.enable_dynamic_filter_pushdown = bool_value; + self.optimizer.enable_topk_dynamic_filter_pushdown = bool_value; + self.optimizer.enable_join_dynamic_filter_pushdown = bool_value; + self.optimizer.enable_aggregate_dynamic_filter_pushdown = bool_value; + } + return Ok(()); + } + 
return ConfigField::set(self, inner_key, value); + } + + if !self.extensions.0.contains_key(prefix) + && self + .extensions + .0 + .contains_key(DATAFUSION_FFI_CONFIG_NAMESPACE) + { + inner_key = key; + prefix = DATAFUSION_FFI_CONFIG_NAMESPACE; } let Some(e) = self.extensions.0.get_mut(prefix) else { return _config_err!("Could not find config namespace \"{prefix}\""); }; - e.0.set(key, value) + e.0.set(inner_key, value) } - /// Create new ConfigOptions struct, taking values from - /// environment variables where possible. + /// Create new [`ConfigOptions`], taking values from environment variables + /// where possible. + /// + /// For example, to configure `datafusion.execution.batch_size` + /// ([`ExecutionOptions::batch_size`]) you would set the + /// `DATAFUSION_EXECUTION_BATCH_SIZE` environment variable. /// - /// For example, setting `DATAFUSION_EXECUTION_BATCH_SIZE` will - /// control `datafusion.execution.batch_size`. + /// The name of the environment variable is the option's key, transformed to + /// uppercase and with periods replaced with underscores. + /// + /// Values are parsed according to the [same rules used in casts from + /// Utf8](https://docs.rs/arrow/latest/arrow/compute/kernels/cast/fn.cast.html). + /// + /// If the value in the environment variable cannot be cast to the type of + /// the configuration option, the default value will be used instead and a + /// warning emitted. Environment variables are read when this method is + /// called, and are not re-read later. pub fn from_env() -> Result { struct Visitor(Vec); @@ -1046,36 +1581,35 @@ impl ConfigOptions { /// # Example /// ``` /// use datafusion_common::{ -/// config::ConfigExtension, extensions_options, -/// config::ConfigOptions, +/// config::ConfigExtension, config::ConfigOptions, extensions_options, /// }; -/// // Define a new configuration struct using the `extensions_options` macro -/// extensions_options! { -/// /// My own config options. 
-/// pub struct MyConfig { -/// /// Should "foo" be replaced by "bar"? -/// pub foo_to_bar: bool, default = true +/// // Define a new configuration struct using the `extensions_options` macro +/// extensions_options! { +/// /// My own config options. +/// pub struct MyConfig { +/// /// Should "foo" be replaced by "bar"? +/// pub foo_to_bar: bool, default = true /// -/// /// How many "baz" should be created? -/// pub baz_count: usize, default = 1337 -/// } -/// } +/// /// How many "baz" should be created? +/// pub baz_count: usize, default = 1337 +/// } +/// } /// -/// impl ConfigExtension for MyConfig { +/// impl ConfigExtension for MyConfig { /// const PREFIX: &'static str = "my_config"; -/// } +/// } /// -/// // set up config struct and register extension -/// let mut config = ConfigOptions::default(); -/// config.extensions.insert(MyConfig::default()); +/// // set up config struct and register extension +/// let mut config = ConfigOptions::default(); +/// config.extensions.insert(MyConfig::default()); /// -/// // overwrite config default -/// config.set("my_config.baz_count", "42").unwrap(); +/// // overwrite config default +/// config.set("my_config.baz_count", "42").unwrap(); /// -/// // check config state -/// let my_config = config.extensions.get::().unwrap(); -/// assert!(my_config.foo_to_bar,); -/// assert_eq!(my_config.baz_count, 42,); +/// // check config state +/// let my_config = config.extensions.get::().unwrap(); +/// assert!(my_config.foo_to_bar,); +/// assert_eq!(my_config.baz_count, 42,); /// ``` /// /// # Note: @@ -1142,6 +1676,14 @@ impl Extensions { let e = self.0.get_mut(T::PREFIX)?; e.0.as_any_mut().downcast_mut() } + + /// Iterates all the config extension entries yielding their prefix and their + /// [ExtensionOptions] implementation. 
+ pub fn iter( + &self, + ) -> impl Iterator)> { + self.0.iter().map(|(k, v)| (*k, &v.0)) + } } #[derive(Debug)] @@ -1159,6 +1701,10 @@ pub trait ConfigField { fn visit(&self, v: &mut V, key: &str, description: &'static str); fn set(&mut self, key: &str, value: &str) -> Result<()>; + + fn reset(&mut self, key: &str) -> Result<()> { + _config_err!("Reset is not supported for this config field, key: {}", key) + } } impl ConfigField for Option { @@ -1172,9 +1718,21 @@ impl ConfigField for Option { fn set(&mut self, key: &str, value: &str) -> Result<()> { self.get_or_insert_with(Default::default).set(key, value) } + + fn reset(&mut self, key: &str) -> Result<()> { + if key.is_empty() { + *self = Default::default(); + Ok(()) + } else { + self.get_or_insert_with(Default::default).reset(key) + } + } } -fn default_transform(input: &str) -> Result +/// Default transformation to parse a [`ConfigField`] for a string. +/// +/// This uses [`FromStr`] to parse the data. +pub fn default_config_transform(input: &str) -> Result where T: FromStr, ::Err: Sync + Send + Error + 'static, @@ -1191,31 +1749,71 @@ where }) } +/// Macro that generates [`ConfigField`] for a given type. +/// +/// # Usage +/// This always requires [`Display`] to be implemented for the given type. +/// +/// There are two ways to invoke this macro. The first one uses +/// [`default_config_transform`]/[`FromStr`] to parse the data: +/// +/// ```ignore +/// config_field(MyType); +/// ``` +/// +/// Note that the parsing error MUST implement [`std::error::Error`]! +/// +/// Or you can specify how you want to parse an [`str`] into the type: +/// +/// ```ignore +/// fn parse_it(s: &str) -> Result { +/// ... +/// } +/// +/// config_field( +/// MyType, +/// value => parse_it(value) +/// ); +/// ``` #[macro_export] macro_rules! 
config_field { ($t:ty) => { - config_field!($t, value => default_transform(value)?); + config_field!($t, value => $crate::config::default_config_transform(value)?); }; ($t:ty, $arg:ident => $transform:expr) => { - impl ConfigField for $t { - fn visit(&self, v: &mut V, key: &str, description: &'static str) { + impl $crate::config::ConfigField for $t { + fn visit(&self, v: &mut V, key: &str, description: &'static str) { v.some(key, self, description) } - fn set(&mut self, _: &str, $arg: &str) -> Result<()> { + fn set(&mut self, _: &str, $arg: &str) -> $crate::error::Result<()> { *self = $transform; Ok(()) } + + fn reset(&mut self, key: &str) -> $crate::error::Result<()> { + if key.is_empty() { + *self = <$t as Default>::default(); + Ok(()) + } else { + $crate::error::_config_err!( + "Config field is a scalar {} and does not have nested field \"{}\"", + stringify!($t), + key + ) + } + } } }; } config_field!(String); -config_field!(bool, value => default_transform(value.to_lowercase().as_str())?); +config_field!(bool, value => default_config_transform(value.to_lowercase().as_str())?); config_field!(usize); config_field!(f64); config_field!(u64); +config_field!(u32); impl ConfigField for u8 { fn visit(&self, v: &mut V, key: &str, description: &'static str) { @@ -1406,8 +2004,7 @@ macro_rules! extensions_options { // Safely apply deprecated attribute if present // $(#[allow(deprecated)])? { - #[allow(deprecated)] - self.$field_name.set(rem, value.as_ref()) + self.$field_name.set(rem, value.as_ref()) } }, )* @@ -1421,7 +2018,6 @@ macro_rules! extensions_options { $( let key = stringify!($field_name).to_string(); let desc = concat!($($d),*).trim(); - #[allow(deprecated)] self.$field_name.visit(v, key.as_str(), desc); )* } @@ -1595,7 +2191,7 @@ impl TableOptions { /// /// A result indicating success or failure in setting the configuration option. 
pub fn set(&mut self, key: &str, value: &str) -> Result<()> { - let Some((prefix, _)) = key.split_once('.') else { + let Some((mut prefix, _)) = key.split_once('.') else { return _config_err!("could not find config namespace for key \"{key}\""); }; @@ -1607,6 +2203,15 @@ impl TableOptions { return Ok(()); } + if !self.extensions.0.contains_key(prefix) + && self + .extensions + .0 + .contains_key(DATAFUSION_FFI_CONFIG_NAMESPACE) + { + prefix = DATAFUSION_FFI_CONFIG_NAMESPACE; + } + let Some(e) = self.extensions.0.get_mut(prefix) else { return _config_err!("Could not find config namespace \"{prefix}\""); }; @@ -1692,7 +2297,7 @@ impl TableOptions { /// Options that control how Parquet files are read, including global options /// that apply to all columns and optional column-specific overrides /// -/// Closely tied to [`ParquetWriterOptions`](crate::file_options::parquet_writer::ParquetWriterOptions). +/// Closely tied to `ParquetWriterOptions` (see `crate::file_options::parquet_writer::ParquetWriterOptions` when the "parquet" feature is enabled). /// Properties not included in [`TableParquetOptions`] may not be configurable at the external API /// (e.g. sorting_columns). #[derive(Clone, Default, Debug, PartialEq)] @@ -1716,6 +2321,26 @@ pub struct TableParquetOptions { /// ) /// ``` pub key_value_metadata: HashMap>, + /// Options for configuring Parquet modular encryption + /// + /// To use Parquet encryption, you must enable the `parquet_encryption` feature flag, as it is not activated by default. 
+ /// See ConfigFileEncryptionProperties and ConfigFileDecryptionProperties in datafusion/common/src/config.rs + /// These can be set via 'format.crypto', for example: + /// ```sql + /// OPTIONS ( + /// 'format.crypto.file_encryption.encrypt_footer' 'true', + /// 'format.crypto.file_encryption.footer_key_as_hex' '30313233343536373839303132333435', -- b"0123456789012345" + /// 'format.crypto.file_encryption.column_key_as_hex::double_field' '31323334353637383930313233343530', -- b"1234567890123450" + /// 'format.crypto.file_encryption.column_key_as_hex::float_field' '31323334353637383930313233343531', -- b"1234567890123451" + /// -- Same for decryption + /// 'format.crypto.file_decryption.footer_key_as_hex' '30313233343536373839303132333435', -- b"0123456789012345" + /// 'format.crypto.file_decryption.column_key_as_hex::double_field' '31323334353637383930313233343530', -- b"1234567890123450" + /// 'format.crypto.file_decryption.column_key_as_hex::float_field' '31323334353637383930313233343531', -- b"1234567890123451" + /// ) + /// ``` + /// See datafusion-cli/tests/sql/encrypted_parquet.sql for a more complete example. + /// Note that keys must be provided in hex format since these are binary strings. + pub crypto: ParquetEncryptionOptions, } impl TableParquetOptions { @@ -1737,13 +2362,52 @@ impl TableParquetOptions { ..self } } + + /// Retrieves all configuration entries from this `TableParquetOptions`. 
+ /// + /// # Returns + /// + /// A vector of `ConfigEntry` instances, representing all the configuration options within this `TableParquetOptions`. + pub fn entries(self: &TableParquetOptions) -> Vec { + struct Visitor(Vec); + + impl Visit for Visitor { + fn some( + &mut self, + key: &str, + value: V, + description: &'static str, + ) { + self.0.push(ConfigEntry { + key: key[1..].to_string(), + value: Some(value.to_string()), + description, + }) + } + + fn none(&mut self, key: &str, description: &'static str) { + self.0.push(ConfigEntry { + key: key[1..].to_string(), + value: None, + description, + }) + } + } + + let mut v = Visitor(vec![]); + self.visit(&mut v, "", ""); + + v.0 + } } impl ConfigField for TableParquetOptions { fn visit(&self, v: &mut V, key_prefix: &str, description: &'static str) { self.global.visit(v, key_prefix, description); self.column_specific_options - .visit(v, key_prefix, description) + .visit(v, key_prefix, description); + self.crypto + .visit(v, &format!("{key_prefix}.crypto"), description); } fn set(&mut self, key: &str, value: &str) -> Result<()> { @@ -1753,17 +2417,19 @@ impl ConfigField for TableParquetOptions { [_meta] | [_meta, ""] => { return _config_err!( "Invalid metadata key provided, missing key in metadata::" - ) + ); } [_meta, k] => k.into(), _ => { return _config_err!( "Invalid metadata key provided, found too many '::' in \"{key}\"" - ) + ); } }; self.key_value_metadata.insert(k, Some(value.into())); Ok(()) + } else if let Some(crypto_feature) = key.strip_prefix("crypto.") { + self.crypto.set(crypto_feature, value) } else if key.contains("::") { self.column_specific_options.set(key, value) } else { @@ -1803,7 +2469,6 @@ macro_rules! config_namespace_with_hashmap { $( stringify!($field_name) => { // Handle deprecated fields - #[allow(deprecated)] // Allow deprecated fields $(let value = $transform(value);)? self.$field_name.set(rem, value.as_ref()) }, @@ -1819,7 +2484,6 @@ macro_rules! 
config_namespace_with_hashmap { let key = format!(concat!("{}.", stringify!($field_name)), key_prefix); let desc = concat!($($d),*).trim(); // Handle deprecated fields - #[allow(deprecated)] self.$field_name.visit(v, key.as_str(), desc); )* } @@ -1827,7 +2491,6 @@ macro_rules! config_namespace_with_hashmap { impl Default for $struct_name { fn default() -> Self { - #[allow(deprecated)] Self { $($field_name: $default),* } @@ -1855,7 +2518,6 @@ macro_rules! config_namespace_with_hashmap { $( let key = format!("{}.{field}::{}", key_prefix, column_name, field = stringify!($field_name)); let desc = concat!($($d),*).trim(); - #[allow(deprecated)] col_options.$field_name.visit(v, key.as_str(), desc); )* } @@ -1886,7 +2548,7 @@ config_namespace_with_hashmap! { /// Sets default parquet compression codec for the column path. /// Valid values are: uncompressed, snappy, gzip(level), - /// lzo, brotli(level), lz4, zstd(level), and lz4_raw. + /// brotli(level), lz4, zstd(level), and lz4_raw. /// These values are not case-sensitive. If NULL, uses /// default parquet options pub compression: Option, transform = str::to_lowercase, default = None @@ -1904,13 +2566,352 @@ config_namespace_with_hashmap! { /// Sets bloom filter number of distinct values. If NULL, uses /// default parquet options pub bloom_filter_ndv: Option, default = None + } +} - /// Sets max statistics size for the column path. 
If NULL, uses - /// default parquet options - /// max_statistics_size is deprecated, currently it is not being used - // TODO: remove once deprecated - #[deprecated(since = "45.0.0", note = "Setting does not do anything")] - pub max_statistics_size: Option, default = None +#[derive(Clone, Debug, PartialEq)] +pub struct ConfigFileEncryptionProperties { + /// Should the parquet footer be encrypted + /// default is true + pub encrypt_footer: bool, + /// Key to use for the parquet footer encoded in hex format + pub footer_key_as_hex: String, + /// Metadata information for footer key + pub footer_key_metadata_as_hex: String, + /// HashMap of column names --> (key in hex format, metadata) + pub column_encryption_properties: HashMap, + /// AAD prefix string uniquely identifies the file and prevents file swapping + pub aad_prefix_as_hex: String, + /// If true, store the AAD prefix in the file + /// default is false + pub store_aad_prefix: bool, +} + +// Setup to match EncryptionPropertiesBuilder::new() +impl Default for ConfigFileEncryptionProperties { + fn default() -> Self { + ConfigFileEncryptionProperties { + encrypt_footer: true, + footer_key_as_hex: String::new(), + footer_key_metadata_as_hex: String::new(), + column_encryption_properties: Default::default(), + aad_prefix_as_hex: String::new(), + store_aad_prefix: false, + } + } +} + +config_namespace_with_hashmap! 
{ + pub struct ColumnEncryptionProperties { + /// Per column encryption key + pub column_key_as_hex: String, default = "".to_string() + /// Per column encryption key metadata + pub column_metadata_as_hex: Option, default = None + } +} + +impl ConfigField for ConfigFileEncryptionProperties { + fn visit(&self, v: &mut V, key_prefix: &str, _description: &'static str) { + let key = format!("{key_prefix}.encrypt_footer"); + let desc = "Encrypt the footer"; + self.encrypt_footer.visit(v, key.as_str(), desc); + + let key = format!("{key_prefix}.footer_key_as_hex"); + let desc = "Key to use for the parquet footer"; + self.footer_key_as_hex.visit(v, key.as_str(), desc); + + let key = format!("{key_prefix}.footer_key_metadata_as_hex"); + let desc = "Metadata to use for the parquet footer"; + self.footer_key_metadata_as_hex.visit(v, key.as_str(), desc); + + self.column_encryption_properties.visit(v, key_prefix, desc); + + let key = format!("{key_prefix}.aad_prefix_as_hex"); + let desc = "AAD prefix to use"; + self.aad_prefix_as_hex.visit(v, key.as_str(), desc); + + let key = format!("{key_prefix}.store_aad_prefix"); + let desc = "If true, store the AAD prefix"; + self.store_aad_prefix.visit(v, key.as_str(), desc); + + self.aad_prefix_as_hex.visit(v, key.as_str(), desc); + } + + fn set(&mut self, key: &str, value: &str) -> Result<()> { + // Any hex encoded values must be pre-encoded using + // hex::encode() before calling set. 
+ + if key.contains("::") { + // Handle any column specific properties + return self.column_encryption_properties.set(key, value); + }; + + let (key, rem) = key.split_once('.').unwrap_or((key, "")); + match key { + "encrypt_footer" => self.encrypt_footer.set(rem, value.as_ref()), + "footer_key_as_hex" => self.footer_key_as_hex.set(rem, value.as_ref()), + "footer_key_metadata_as_hex" => { + self.footer_key_metadata_as_hex.set(rem, value.as_ref()) + } + "aad_prefix_as_hex" => self.aad_prefix_as_hex.set(rem, value.as_ref()), + "store_aad_prefix" => self.store_aad_prefix.set(rem, value.as_ref()), + _ => _config_err!( + "Config value \"{}\" not found on ConfigFileEncryptionProperties", + key + ), + } + } +} + +#[cfg(feature = "parquet_encryption")] +impl From for FileEncryptionProperties { + fn from(val: ConfigFileEncryptionProperties) -> Self { + let mut fep = FileEncryptionProperties::builder( + hex::decode(val.footer_key_as_hex).unwrap(), + ) + .with_plaintext_footer(!val.encrypt_footer) + .with_aad_prefix_storage(val.store_aad_prefix); + + if !val.footer_key_metadata_as_hex.is_empty() { + fep = fep.with_footer_key_metadata( + hex::decode(&val.footer_key_metadata_as_hex) + .expect("Invalid footer key metadata"), + ); + } + + for (column_name, encryption_props) in val.column_encryption_properties.iter() { + let encryption_key = hex::decode(&encryption_props.column_key_as_hex) + .expect("Invalid column encryption key"); + let key_metadata = encryption_props + .column_metadata_as_hex + .as_ref() + .map(|x| hex::decode(x).expect("Invalid column metadata")); + match key_metadata { + Some(key_metadata) => { + fep = fep.with_column_key_and_metadata( + column_name, + encryption_key, + key_metadata, + ); + } + None => { + fep = fep.with_column_key(column_name, encryption_key); + } + } + } + + if !val.aad_prefix_as_hex.is_empty() { + let aad_prefix: Vec = + hex::decode(&val.aad_prefix_as_hex).expect("Invalid AAD prefix"); + fep = fep.with_aad_prefix(aad_prefix); + } + 
Arc::unwrap_or_clone(fep.build().unwrap()) + } +} + +#[cfg(feature = "parquet_encryption")] +impl From<&Arc> for ConfigFileEncryptionProperties { + fn from(f: &Arc) -> Self { + let (column_names_vec, column_keys_vec, column_metas_vec) = f.column_keys(); + + let mut column_encryption_properties: HashMap< + String, + ColumnEncryptionProperties, + > = HashMap::new(); + + for (i, column_name) in column_names_vec.iter().enumerate() { + let column_key_as_hex = hex::encode(&column_keys_vec[i]); + let column_metadata_as_hex: Option = + column_metas_vec.get(i).map(hex::encode); + column_encryption_properties.insert( + column_name.clone(), + ColumnEncryptionProperties { + column_key_as_hex, + column_metadata_as_hex, + }, + ); + } + let aad_prefix = f.aad_prefix().cloned().unwrap_or_default(); + ConfigFileEncryptionProperties { + encrypt_footer: f.encrypt_footer(), + footer_key_as_hex: hex::encode(f.footer_key()), + footer_key_metadata_as_hex: f + .footer_key_metadata() + .map(hex::encode) + .unwrap_or_default(), + column_encryption_properties, + aad_prefix_as_hex: hex::encode(aad_prefix), + store_aad_prefix: f.store_aad_prefix(), + } + } +} + +#[derive(Clone, Debug, PartialEq)] +pub struct ConfigFileDecryptionProperties { + /// Binary string to use for the parquet footer encoded in hex format + pub footer_key_as_hex: String, + /// HashMap of column names --> key in hex format + pub column_decryption_properties: HashMap, + /// AAD prefix string uniquely identifies the file and prevents file swapping + pub aad_prefix_as_hex: String, + /// If true, then verify signature for files with plaintext footers. + /// default = true + pub footer_signature_verification: bool, +} + +config_namespace_with_hashmap! 
{ + pub struct ColumnDecryptionProperties { + /// Per column encryption key + pub column_key_as_hex: String, default = "".to_string() + } +} + +// Setup to match DecryptionPropertiesBuilder::new() +impl Default for ConfigFileDecryptionProperties { + fn default() -> Self { + ConfigFileDecryptionProperties { + footer_key_as_hex: String::new(), + column_decryption_properties: Default::default(), + aad_prefix_as_hex: String::new(), + footer_signature_verification: true, + } + } +} + +impl ConfigField for ConfigFileDecryptionProperties { + fn visit(&self, v: &mut V, key_prefix: &str, _description: &'static str) { + let key = format!("{key_prefix}.footer_key_as_hex"); + let desc = "Key to use for the parquet footer"; + self.footer_key_as_hex.visit(v, key.as_str(), desc); + + let key = format!("{key_prefix}.aad_prefix_as_hex"); + let desc = "AAD prefix to use"; + self.aad_prefix_as_hex.visit(v, key.as_str(), desc); + + let key = format!("{key_prefix}.footer_signature_verification"); + let desc = "If true, verify the footer signature"; + self.footer_signature_verification + .visit(v, key.as_str(), desc); + + self.column_decryption_properties.visit(v, key_prefix, desc); + } + + fn set(&mut self, key: &str, value: &str) -> Result<()> { + // Any hex encoded values must be pre-encoded using + // hex::encode() before calling set. 
+ + if key.contains("::") { + // Handle any column specific properties + return self.column_decryption_properties.set(key, value); + }; + + let (key, rem) = key.split_once('.').unwrap_or((key, "")); + match key { + "footer_key_as_hex" => self.footer_key_as_hex.set(rem, value.as_ref()), + "aad_prefix_as_hex" => self.aad_prefix_as_hex.set(rem, value.as_ref()), + "footer_signature_verification" => { + self.footer_signature_verification.set(rem, value.as_ref()) + } + _ => _config_err!( + "Config value \"{}\" not found on ConfigFileDecryptionProperties", + key + ), + } + } +} + +#[cfg(feature = "parquet_encryption")] +impl From for FileDecryptionProperties { + fn from(val: ConfigFileDecryptionProperties) -> Self { + let mut column_names: Vec<&str> = Vec::new(); + let mut column_keys: Vec> = Vec::new(); + + for (col_name, decryption_properties) in val.column_decryption_properties.iter() { + column_names.push(col_name.as_str()); + column_keys.push( + hex::decode(&decryption_properties.column_key_as_hex) + .expect("Invalid column decryption key"), + ); + } + + let mut fep = FileDecryptionProperties::builder( + hex::decode(val.footer_key_as_hex).expect("Invalid footer key"), + ) + .with_column_keys(column_names, column_keys) + .unwrap(); + + if !val.footer_signature_verification { + fep = fep.disable_footer_signature_verification(); + } + + if !val.aad_prefix_as_hex.is_empty() { + let aad_prefix = + hex::decode(&val.aad_prefix_as_hex).expect("Invalid AAD prefix"); + fep = fep.with_aad_prefix(aad_prefix); + } + + Arc::unwrap_or_clone(fep.build().unwrap()) + } +} + +#[cfg(feature = "parquet_encryption")] +impl From<&Arc> for ConfigFileDecryptionProperties { + fn from(f: &Arc) -> Self { + let (column_names_vec, column_keys_vec) = f.column_keys(); + let mut column_decryption_properties: HashMap< + String, + ColumnDecryptionProperties, + > = HashMap::new(); + for (i, column_name) in column_names_vec.iter().enumerate() { + let props = ColumnDecryptionProperties { + 
column_key_as_hex: hex::encode(column_keys_vec[i].clone()), + }; + column_decryption_properties.insert(column_name.clone(), props); + } + + let aad_prefix = f.aad_prefix().cloned().unwrap_or_default(); + ConfigFileDecryptionProperties { + footer_key_as_hex: hex::encode( + f.footer_key(None).unwrap_or_default().as_ref(), + ), + column_decryption_properties, + aad_prefix_as_hex: hex::encode(aad_prefix), + footer_signature_verification: f.check_plaintext_footer_integrity(), + } + } +} + +/// Holds implementation-specific options for an encryption factory +#[derive(Clone, Debug, Default, PartialEq)] +pub struct EncryptionFactoryOptions { + pub options: HashMap, +} + +impl ConfigField for EncryptionFactoryOptions { + fn visit(&self, v: &mut V, key: &str, _description: &'static str) { + for (option_key, option_value) in &self.options { + v.some( + &format!("{key}.{option_key}"), + option_value, + "Encryption factory specific option", + ); + } + } + + fn set(&mut self, key: &str, value: &str) -> Result<()> { + self.options.insert(key.to_owned(), value.to_owned()); + Ok(()) + } +} + +impl EncryptionFactoryOptions { + /// Convert these encryption factory options to an [`ExtensionOptions`] instance. + pub fn to_extension_options(&self) -> Result { + let mut options = T::default(); + for (key, value) in &self.options { + options.set(key, value)?; + } + Ok(options) } } @@ -1935,6 +2936,14 @@ config_namespace! { /// The default behaviour depends on the `datafusion.catalog.newlines_in_values` setting. pub newlines_in_values: Option, default = None pub compression: CompressionTypeVariant, default = CompressionTypeVariant::UNCOMPRESSED + /// Compression level for the output file. The valid range depends on the + /// compression algorithm: + /// - ZSTD: 1 to 22 (default: 3) + /// - GZIP: 0 to 9 (default: 6) + /// - BZIP2: 0 to 9 (default: 6) + /// - XZ: 0 to 9 (default: 6) + /// If not specified, the default level for the compression algorithm is used. 
+ pub compression_level: Option, default = None pub schema_infer_max_rec: Option, default = None pub date_format: Option, default = None pub datetime_format: Option, default = None @@ -1946,6 +2955,16 @@ config_namespace! { // The input regex for Nulls when loading CSVs. pub null_regex: Option, default = None pub comment: Option, default = None + /// Whether to allow truncated rows when parsing, both within a single file and across files. + /// + /// When set to false (default), reading a single CSV file which has rows of different lengths will + /// error; if reading multiple CSV files with different number of columns, it will also fail. + /// + /// When set to true, reading a single CSV file with rows of different lengths will pad the truncated + /// rows with null values for the missing columns; if reading multiple CSV files with different number + /// of columns, it creates a union schema containing all columns found across the files, and will + /// pad any files missing columns with null values for their rows. + pub truncated_rows: Option, default = None } } @@ -2038,6 +3057,23 @@ impl CsvOptions { self } + /// Whether to allow truncated rows when parsing. + /// By default this is set to false and will error if the CSV rows have different lengths. + /// When set to true then it will allow records with less than the expected number of columns and fill the missing columns with nulls. + /// If the record’s schema is not nullable, then it will still return an error. + pub fn with_truncated_rows(mut self, allow: bool) -> Self { + self.truncated_rows = Some(allow); + self + } + + /// Set the compression level for the output file. + /// The valid range depends on the compression algorithm. + /// If not specified, the default level for the algorithm is used. + pub fn with_compression_level(mut self, level: u32) -> Self { + self.compression_level = Some(level); + self + } + /// The delimiter character. 
pub fn delimiter(&self) -> u8 { self.delimiter @@ -2063,14 +3099,38 @@ config_namespace! { /// Options controlling JSON format pub struct JsonOptions { pub compression: CompressionTypeVariant, default = CompressionTypeVariant::UNCOMPRESSED + /// Compression level for the output file. The valid range depends on the + /// compression algorithm: + /// - ZSTD: 1 to 22 (default: 3) + /// - GZIP: 0 to 9 (default: 6) + /// - BZIP2: 0 to 9 (default: 6) + /// - XZ: 0 to 9 (default: 6) + /// If not specified, the default level for the compression algorithm is used. + pub compression_level: Option, default = None pub schema_infer_max_rec: Option, default = None + /// The JSON format to use when reading files. + /// + /// When `true` (default), expects newline-delimited JSON (NDJSON): + /// ```text + /// {"key1": 1, "key2": "val"} + /// {"key1": 2, "key2": "vals"} + /// ``` + /// + /// When `false`, expects JSON array format: + /// ```text + /// [ + /// {"key1": 1, "key2": "val"}, + /// {"key1": 2, "key2": "vals"} + /// ] + /// ``` + pub newline_delimited: bool, default = true } } pub trait OutputFormatExt: Display {} #[derive(Debug, Clone, PartialEq)] -#[allow(clippy::large_enum_variant)] +#[cfg_attr(feature = "parquet", expect(clippy::large_enum_variant))] pub enum OutputFormat { CSV(CsvOptions), JSON(JsonOptions), @@ -2096,13 +3156,14 @@ impl Display for OutputFormat { #[cfg(test)] mod tests { - use std::any::Any; - use std::collections::HashMap; - + #[cfg(feature = "parquet")] + use crate::config::TableParquetOptions; use crate::config::{ ConfigEntry, ConfigExtension, ConfigField, ConfigFileType, ExtensionOptions, Extensions, TableOptions, }; + use std::any::Any; + use std::collections::HashMap; #[derive(Default, Debug, Clone)] pub struct TestExtensionConfig { @@ -2174,6 +3235,16 @@ mod tests { ); } + #[test] + fn iter_test_extension_config() { + let mut extension = Extensions::new(); + extension.insert(TestExtensionConfig::default()); + let table_config = 
TableOptions::new().with_extensions(extension); + let extensions = table_config.extensions.iter().collect::>(); + assert_eq!(extensions.len(), 1); + assert_eq!(extensions[0].0, TestExtensionConfig::PREFIX); + } + #[test] fn csv_u8_table_options() { let mut table_config = TableOptions::new(); @@ -2217,6 +3288,19 @@ mod tests { assert_eq!(COUNT.load(std::sync::atomic::Ordering::Relaxed), 1); } + #[test] + fn reset_nested_scalar_reports_helpful_error() { + let mut value = true; + let err = ::reset(&mut value, "nested").unwrap_err(); + let message = err.to_string(); + assert!( + message.starts_with( + "Invalid or Unsupported Configuration: Config field is a scalar bool and does not have nested field \"nested\"" + ), + "unexpected error message: {message}" + ); + } + #[cfg(feature = "parquet")] #[test] fn parquet_table_options() { @@ -2231,6 +3315,159 @@ mod tests { ); } + #[cfg(feature = "parquet_encryption")] + #[test] + fn parquet_table_encryption() { + use crate::config::{ + ConfigFileDecryptionProperties, ConfigFileEncryptionProperties, + }; + use parquet::encryption::decrypt::FileDecryptionProperties; + use parquet::encryption::encrypt::FileEncryptionProperties; + use std::sync::Arc; + + let footer_key = b"0123456789012345".to_vec(); // 128bit/16 + let column_names = vec!["double_field", "float_field"]; + let column_keys = + vec![b"1234567890123450".to_vec(), b"1234567890123451".to_vec()]; + + let file_encryption_properties = + FileEncryptionProperties::builder(footer_key.clone()) + .with_column_keys(column_names.clone(), column_keys.clone()) + .unwrap() + .build() + .unwrap(); + + let decryption_properties = FileDecryptionProperties::builder(footer_key.clone()) + .with_column_keys(column_names.clone(), column_keys.clone()) + .unwrap() + .build() + .unwrap(); + + // Test round-trip + let config_encrypt = + ConfigFileEncryptionProperties::from(&file_encryption_properties); + let encryption_properties_built = + 
Arc::new(FileEncryptionProperties::from(config_encrypt.clone())); + assert_eq!(file_encryption_properties, encryption_properties_built); + + let config_decrypt = ConfigFileDecryptionProperties::from(&decryption_properties); + let decryption_properties_built = + Arc::new(FileDecryptionProperties::from(config_decrypt.clone())); + assert_eq!(decryption_properties, decryption_properties_built); + + /////////////////////////////////////////////////////////////////////////////////// + // Test encryption config + + // Display original encryption config + // println!("{:#?}", config_encrypt); + + let mut table_config = TableOptions::new(); + table_config.set_config_format(ConfigFileType::PARQUET); + table_config + .parquet + .set( + "crypto.file_encryption.encrypt_footer", + config_encrypt.encrypt_footer.to_string().as_str(), + ) + .unwrap(); + table_config + .parquet + .set( + "crypto.file_encryption.footer_key_as_hex", + config_encrypt.footer_key_as_hex.as_str(), + ) + .unwrap(); + + for (i, col_name) in column_names.iter().enumerate() { + let key = format!("crypto.file_encryption.column_key_as_hex::{col_name}"); + let value = hex::encode(column_keys[i].clone()); + table_config + .parquet + .set(key.as_str(), value.as_str()) + .unwrap(); + } + + // Print matching final encryption config + // println!("{:#?}", table_config.parquet.crypto.file_encryption); + + assert_eq!( + table_config.parquet.crypto.file_encryption, + Some(config_encrypt) + ); + + /////////////////////////////////////////////////////////////////////////////////// + // Test decryption config + + // Display original decryption config + // println!("{:#?}", config_decrypt); + + let mut table_config = TableOptions::new(); + table_config.set_config_format(ConfigFileType::PARQUET); + table_config + .parquet + .set( + "crypto.file_decryption.footer_key_as_hex", + config_decrypt.footer_key_as_hex.as_str(), + ) + .unwrap(); + + for (i, col_name) in column_names.iter().enumerate() { + let key = 
format!("crypto.file_decryption.column_key_as_hex::{col_name}"); + let value = hex::encode(column_keys[i].clone()); + table_config + .parquet + .set(key.as_str(), value.as_str()) + .unwrap(); + } + + // Print matching final decryption config + // println!("{:#?}", table_config.parquet.crypto.file_decryption); + + assert_eq!( + table_config.parquet.crypto.file_decryption, + Some(config_decrypt.clone()) + ); + + // Set config directly + let mut table_config = TableOptions::new(); + table_config.set_config_format(ConfigFileType::PARQUET); + table_config.parquet.crypto.file_decryption = Some(config_decrypt.clone()); + assert_eq!( + table_config.parquet.crypto.file_decryption, + Some(config_decrypt.clone()) + ); + } + + #[cfg(feature = "parquet_encryption")] + #[test] + fn parquet_encryption_factory_config() { + let mut parquet_options = TableParquetOptions::default(); + + assert_eq!(parquet_options.crypto.factory_id, None); + assert_eq!(parquet_options.crypto.factory_options.options.len(), 0); + + let mut input_config = TestExtensionConfig::default(); + input_config + .properties + .insert("key1".to_string(), "value 1".to_string()); + input_config + .properties + .insert("key2".to_string(), "value 2".to_string()); + + parquet_options + .crypto + .configure_factory("example_factory", &input_config); + + assert_eq!( + parquet_options.crypto.factory_id, + Some("example_factory".to_string()) + ); + let factory_options = &parquet_options.crypto.factory_options.options; + assert_eq!(factory_options.len(), 2); + assert_eq!(factory_options.get("key1"), Some(&"value 1".to_string())); + assert_eq!(factory_options.get("key2"), Some(&"value 2".to_string())); + } + #[cfg(feature = "parquet")] #[test] fn parquet_table_options_config_entry() { @@ -2240,9 +3477,28 @@ mod tests { .set("format.bloom_filter_enabled::col1", "true") .unwrap(); let entries = table_config.entries(); - assert!(entries - .iter() - .any(|item| item.key == "format.bloom_filter_enabled::col1")) + assert!( + 
entries + .iter() + .any(|item| item.key == "format.bloom_filter_enabled::col1") + ) + } + + #[cfg(feature = "parquet")] + #[test] + fn parquet_table_parquet_options_config_entry() { + let mut table_parquet_options = TableParquetOptions::new(); + table_parquet_options + .set( + "crypto.file_encryption.column_key_as_hex::double_field", + "31323334353637383930313233343530", + ) + .unwrap(); + let entries = table_parquet_options.entries(); + assert!( + entries.iter().any(|item| item.key + == "crypto.file_encryption.column_key_as_hex::double_field") + ) } #[cfg(feature = "parquet")] @@ -2278,4 +3534,37 @@ mod tests { let parsed_metadata = table_config.parquet.key_value_metadata; assert_eq!(parsed_metadata.get("key_dupe"), Some(&Some("B".into()))); } + #[cfg(feature = "parquet")] + #[test] + fn test_parquet_writer_version_validation() { + use crate::{config::ConfigOptions, parquet_config::DFParquetWriterVersion}; + + let mut config = ConfigOptions::default(); + + // Valid values should work + config + .set("datafusion.execution.parquet.writer_version", "1.0") + .unwrap(); + assert_eq!( + config.execution.parquet.writer_version, + DFParquetWriterVersion::V1_0 + ); + + config + .set("datafusion.execution.parquet.writer_version", "2.0") + .unwrap(); + assert_eq!( + config.execution.parquet.writer_version, + DFParquetWriterVersion::V2_0 + ); + + // Invalid value should error immediately at SET time + let err = config + .set("datafusion.execution.parquet.writer_version", "3.0") + .unwrap_err(); + assert_eq!( + err.to_string(), + "Invalid or Unsupported Configuration: Invalid parquet writer version: 3.0. Expected one of: 1.0, 2.0" + ); + } } diff --git a/datafusion/common/src/cse.rs b/datafusion/common/src/cse.rs index 674d3386171f8..93169d6a02ff1 100644 --- a/datafusion/common/src/cse.rs +++ b/datafusion/common/src/cse.rs @@ -19,12 +19,12 @@ //! a [`CSEController`], that defines how to eliminate common subtrees from a particular //! [`TreeNode`] tree. 
+use crate::Result; use crate::hash_utils::combine_hashes; use crate::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor, }; -use crate::Result; use indexmap::IndexMap; use std::collections::HashMap; use std::hash::{BuildHasher, Hash, Hasher, RandomState}; @@ -676,13 +676,13 @@ where #[cfg(test)] mod test { + use crate::Result; use crate::alias::AliasGenerator; use crate::cse::{ - CSEController, HashNode, IdArray, Identifier, NodeStats, NormalizeEq, - Normalizeable, CSE, + CSE, CSEController, HashNode, IdArray, Identifier, NodeStats, NormalizeEq, + Normalizeable, }; use crate::tree_node::tests::TestTreeNode; - use crate::Result; use std::collections::HashSet; use std::hash::{Hash, Hasher}; diff --git a/datafusion/common/src/datatype.rs b/datafusion/common/src/datatype.rs new file mode 100644 index 0000000000000..19847f8583505 --- /dev/null +++ b/datafusion/common/src/datatype.rs @@ -0,0 +1,273 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
[`DataTypeExt`] and [`FieldExt`] extension trait for working with Arrow [`DataType`] and [`Field`]s + +use crate::arrow::datatypes::{DataType, Field, FieldRef}; +use crate::metadata::FieldMetadata; +use std::sync::Arc; + +/// DataFusion extension methods for Arrow [`DataType`] +pub trait DataTypeExt { + /// Convert the type to field with nullable type and "" name + /// + /// This is used to track the places where we convert a [`DataType`] + /// into a nameless field to interact with an API that is + /// capable of representing an extension type and/or nullability. + /// + /// For example, it will convert a `DataType::Int32` into + /// `Field::new("", DataType::Int32, true)`. + /// + /// ``` + /// # use datafusion_common::datatype::DataTypeExt; + /// # use arrow::datatypes::DataType; + /// let dt = DataType::Utf8; + /// let field = dt.into_nullable_field(); + /// // result is a nullable Utf8 field with "" name + /// assert_eq!(field.name(), ""); + /// assert_eq!(field.data_type(), &DataType::Utf8); + /// assert!(field.is_nullable()); + /// ``` + fn into_nullable_field(self) -> Field; + + /// Convert the type to [`FieldRef`] with nullable type and "" name + /// + /// Concise wrapper around [`DataTypeExt::into_nullable_field`] that + /// constructs a [`FieldRef`]. + fn into_nullable_field_ref(self) -> FieldRef; +} + +impl DataTypeExt for DataType { + fn into_nullable_field(self) -> Field { + Field::new("", self, true) + } + + fn into_nullable_field_ref(self) -> FieldRef { + Arc::new(Field::new("", self, true)) + } +} + +/// DataFusion extension methods for Arrow [`Field`] and [`FieldRef`] +/// +/// This trait is implemented for both [`Field`] and [`FieldRef`] and +/// provides convenience methods for efficiently working with both types. +/// +/// For [`FieldRef`], the methods will attempt to unwrap the `Arc` +/// to avoid unnecessary cloning when possible. 
+pub trait FieldExt { + /// Ensure the field is named `new_name`, returning the given field if the + /// name matches, and a new field if not. + /// + /// This method avoids `clone`ing fields and names if the name is the same + /// as the field's existing name. + /// + /// Example: + /// ``` + /// # use std::sync::Arc; + /// # use arrow::datatypes::{DataType, Field}; + /// # use datafusion_common::datatype::FieldExt; + /// let int_field = Field::new("my_int", DataType::Int32, true); + /// // rename to "your_int" + /// let renamed_field = int_field.renamed("your_int"); + /// assert_eq!(renamed_field.name(), "your_int"); + /// ``` + fn renamed(self, new_name: &str) -> Self; + + /// Ensure the field has the given data type + /// + /// Note this is different than simply calling [`Field::with_data_type`] as + /// it avoids copying if the data type is already the same. + /// + /// Example: + /// ``` + /// # use std::sync::Arc; + /// # use arrow::datatypes::{DataType, Field}; + /// # use datafusion_common::datatype::FieldExt; + /// let int_field = Field::new("my_int", DataType::Int32, true); + /// // change to Float64 + /// let retyped_field = int_field.retyped(DataType::Float64); + /// assert_eq!(retyped_field.data_type(), &DataType::Float64); + /// ``` + fn retyped(self, new_data_type: DataType) -> Self; + + /// Add field metadata to the Field + fn with_field_metadata(self, metadata: &FieldMetadata) -> Self; + + /// Add optional field metadata, + fn with_field_metadata_opt(self, metadata: Option<&FieldMetadata>) -> Self; + + /// Returns a new Field representing a List of this Field's DataType. + /// + /// For example if input represents an `Int32`, the return value will + /// represent a `List`. 
+ /// + /// Example: + /// ``` + /// # use std::sync::Arc; + /// # use arrow::datatypes::{DataType, Field}; + /// # use datafusion_common::datatype::FieldExt; + /// // Int32 field + /// let int_field = Field::new("my_int", DataType::Int32, true); + /// // convert to a List field + /// let list_field = int_field.into_list(); + /// // List + /// // Note that the item field name has been renamed to "item" + /// assert_eq!(list_field.data_type(), &DataType::List(Arc::new( + /// Field::new("item", DataType::Int32, true) + /// ))); + fn into_list(self) -> Self; + + /// Return a new Field representing this Field as the item type of a + /// [`DataType::FixedSizeList`] + /// + /// For example if input represents an `Int32`, the return value will + /// represent a `FixedSizeList`. + /// + /// Example: + /// ``` + /// # use std::sync::Arc; + /// # use arrow::datatypes::{DataType, Field}; + /// # use datafusion_common::datatype::FieldExt; + /// // Int32 field + /// let int_field = Field::new("my_int", DataType::Int32, true); + /// // convert to a FixedSizeList field of size 3 + /// let fixed_size_list_field = int_field.into_fixed_size_list(3); + /// // FixedSizeList + /// // Note that the item field name has been renamed to "item" + /// assert_eq!( + /// fixed_size_list_field.data_type(), + /// &DataType::FixedSizeList(Arc::new( + /// Field::new("item", DataType::Int32, true)), + /// 3 + /// )); + fn into_fixed_size_list(self, list_size: i32) -> Self; + + /// Update the field to have the default list field name ("item") + /// + /// Lists are allowed to have an arbitrarily named field; however, a name + /// other than 'item' will cause it to fail an == check against a more + /// idiomatically created list in arrow-rs which causes issues. + /// + /// For example, if input represents an `Int32` field named "my_int", + /// the return value will represent an `Int32` field named "item". 
+ /// + /// Example: + /// ``` + /// # use arrow::datatypes::Field; + /// # use datafusion_common::datatype::FieldExt; + /// let my_field = Field::new("my_int", arrow::datatypes::DataType::Int32, true); + /// let item_field = my_field.into_list_item(); + /// assert_eq!(item_field.name(), Field::LIST_FIELD_DEFAULT_NAME); + /// assert_eq!(item_field.name(), "item"); + /// ``` + fn into_list_item(self) -> Self; +} + +impl FieldExt for Field { + fn renamed(self, new_name: &str) -> Self { + // check if this is a new name before allocating a new Field / copying + // the existing one + if self.name() != new_name { + self.with_name(new_name) + } else { + self + } + } + + fn retyped(self, new_data_type: DataType) -> Self { + self.with_data_type(new_data_type) + } + + fn with_field_metadata(self, metadata: &FieldMetadata) -> Self { + metadata.add_to_field(self) + } + + fn with_field_metadata_opt(self, metadata: Option<&FieldMetadata>) -> Self { + if let Some(metadata) = metadata { + self.with_field_metadata(metadata) + } else { + self + } + } + + fn into_list(self) -> Self { + DataType::List(Arc::new(self.into_list_item())).into_nullable_field() + } + + fn into_fixed_size_list(self, list_size: i32) -> Self { + DataType::FixedSizeList(self.into_list_item().into(), list_size) + .into_nullable_field() + } + + fn into_list_item(self) -> Self { + if self.name() != Field::LIST_FIELD_DEFAULT_NAME { + self.with_name(Field::LIST_FIELD_DEFAULT_NAME) + } else { + self + } + } +} + +impl FieldExt for Arc { + fn renamed(mut self, new_name: &str) -> Self { + if self.name() != new_name { + // avoid cloning if possible + Arc::make_mut(&mut self).set_name(new_name); + } + self + } + + fn retyped(mut self, new_data_type: DataType) -> Self { + if self.data_type() != &new_data_type { + // avoid cloning if possible + Arc::make_mut(&mut self).set_data_type(new_data_type); + } + self + } + + fn with_field_metadata(self, metadata: &FieldMetadata) -> Self { + metadata.add_to_field_ref(self) + } + + 
fn with_field_metadata_opt(self, metadata: Option<&FieldMetadata>) -> Self { + if let Some(metadata) = metadata { + self.with_field_metadata(metadata) + } else { + self + } + } + + fn into_list(self) -> Self { + DataType::List(self.into_list_item()) + .into_nullable_field() + .into() + } + + fn into_fixed_size_list(self, list_size: i32) -> Self { + DataType::FixedSizeList(self.into_list_item(), list_size) + .into_nullable_field() + .into() + } + + fn into_list_item(mut self) -> Self { + if self.name() != Field::LIST_FIELD_DEFAULT_NAME { + // avoid cloning if possible + Arc::make_mut(&mut self).set_name(Field::LIST_FIELD_DEFAULT_NAME); + } + self + } +} diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs index 804e14bf72fb0..de0aacf9e8bcd 100644 --- a/datafusion/common/src/dfschema.rs +++ b/datafusion/common/src/dfschema.rs @@ -23,10 +23,10 @@ use std::fmt::{Display, Formatter}; use std::hash::Hash; use std::sync::Arc; -use crate::error::{DataFusionError, Result, _plan_err, _schema_err}; +use crate::error::{_plan_err, _schema_err, DataFusionError, Result}; use crate::{ - field_not_found, unqualified_field_not_found, Column, FunctionalDependencies, - SchemaError, TableReference, + Column, FunctionalDependencies, SchemaError, TableReference, field_not_found, + unqualified_field_not_found, }; use arrow::compute::can_cast_types; @@ -37,7 +37,7 @@ use arrow::datatypes::{ /// A reference-counted reference to a [DFSchema]. pub type DFSchemaRef = Arc; -/// DFSchema wraps an Arrow schema and adds relation names. +/// DFSchema wraps an Arrow schema and add a relation (table) name. /// /// The schema may hold the fields across multiple tables. Some fields may be /// qualified and some unqualified. A qualified field is a field that has a @@ -47,8 +47,14 @@ pub type DFSchemaRef = Arc; /// have a distinct name from any qualified field names. 
This allows finding a /// qualified field by name to be possible, so long as there aren't multiple /// qualified fields with the same name. +/// +/// # See Also +/// * [DFSchemaRef], an alias to `Arc` +/// * [DataTypeExt], common methods for working with Arrow [DataType]s +/// * [FieldExt], extension methods for working with Arrow [Field]s /// -/// There is an alias to `Arc` named [DFSchemaRef]. +/// [DataTypeExt]: crate::datatype::DataTypeExt +/// [FieldExt]: crate::datatype::FieldExt /// /// # Creating qualified schemas /// @@ -56,12 +62,10 @@ pub type DFSchemaRef = Arc; /// an Arrow schema. /// /// ```rust -/// use datafusion_common::{DFSchema, Column}; /// use arrow::datatypes::{DataType, Field, Schema}; +/// use datafusion_common::{Column, DFSchema}; /// -/// let arrow_schema = Schema::new(vec![ -/// Field::new("c1", DataType::Int32, false), -/// ]); +/// let arrow_schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); /// /// let df_schema = DFSchema::try_from_qualified_schema("t1", &arrow_schema).unwrap(); /// let column = Column::from_qualified_name("t1.c1"); @@ -77,12 +81,10 @@ pub type DFSchemaRef = Arc; /// Create an unqualified schema using TryFrom: /// /// ```rust -/// use datafusion_common::{DFSchema, Column}; /// use arrow::datatypes::{DataType, Field, Schema}; +/// use datafusion_common::{Column, DFSchema}; /// -/// let arrow_schema = Schema::new(vec![ -/// Field::new("c1", DataType::Int32, false), -/// ]); +/// let arrow_schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); /// /// let df_schema = DFSchema::try_from(arrow_schema).unwrap(); /// let column = Column::new_unqualified("c1"); @@ -94,14 +96,16 @@ pub type DFSchemaRef = Arc; /// Use the `Into` trait to convert `DFSchema` into an Arrow schema: /// /// ```rust +/// use arrow::datatypes::{Field, Schema}; /// use datafusion_common::DFSchema; -/// use arrow::datatypes::{Schema, Field}; /// use std::collections::HashMap; /// -/// let df_schema =
DFSchema::from_unqualified_fields(vec![ -/// Field::new("c1", arrow::datatypes::DataType::Int32, false), -/// ].into(),HashMap::new()).unwrap(); -/// let schema = Schema::from(df_schema); +/// let df_schema = DFSchema::from_unqualified_fields( +/// vec![Field::new("c1", arrow::datatypes::DataType::Int32, false)].into(), +/// HashMap::new(), +/// ) +/// .unwrap(); +/// let schema: &Schema = df_schema.as_arrow(); /// assert_eq!(schema.fields().len(), 1); /// ``` #[derive(Debug, Clone, PartialEq, Eq)] @@ -206,6 +210,25 @@ impl DFSchema { Ok(dfschema) } + /// Return the same schema, where all fields have a given qualifier. + pub fn with_field_specific_qualified_schema( + &self, + qualifiers: Vec>, + ) -> Result { + if qualifiers.len() != self.fields().len() { + return _plan_err!( + "Number of qualifiers must match number of fields. Expected {}, got {}", + self.fields().len(), + qualifiers.len() + ); + } + Ok(DFSchema { + inner: Arc::clone(&self.inner), + field_qualifiers: qualifiers, + functional_dependencies: self.functional_dependencies.clone(), + }) + } + /// Check if the schema have some fields with the same name pub fn check_names(&self) -> Result<()> { let mut qualified_names = BTreeSet::new(); @@ -229,7 +252,7 @@ impl DFSchema { for (qualifier, name) in qualified_names { if unqualified_names.contains(name) { return _schema_err!(SchemaError::AmbiguousReference { - field: Column::new(Some(qualifier.clone()), name) + field: Box::new(Column::new(Some(qualifier.clone()), name)) }); } } @@ -278,6 +301,20 @@ impl DFSchema { /// Modify this schema by appending the fields from the supplied schema, ignoring any /// duplicate fields. + /// + /// ## Merge Precedence + /// + /// **Schema-level metadata**: Metadata from both schemas is merged. + /// If both schemas have the same metadata key, the value from the `other_schema` parameter takes precedence. + /// + /// **Field-level merging**: Only non-duplicate fields are added. 
This means that the + /// `self` fields will always take precedence over the `other_schema` fields. + /// Duplicate field detection is based on: + /// - For qualified fields: both qualifier and field name must match + /// - For unqualified fields: only field name needs to match + /// + /// Take note how the precedence for fields & metadata merging differs; + /// merging prefers fields from `self` but prefers metadata from `other_schema`. pub fn merge(&mut self, other_schema: &DFSchema) { if other_schema.inner.fields.is_empty() { return; @@ -315,20 +352,22 @@ impl DFSchema { self.field_qualifiers.extend(qualifiers); } - /// Get a list of fields + /// Get a list of fields for this schema pub fn fields(&self) -> &Fields { &self.inner.fields } - /// Returns an immutable reference of a specific `Field` instance selected using an - /// offset within the internal `fields` vector - pub fn field(&self, i: usize) -> &Field { + /// Returns a reference to [`FieldRef`] for a column at specific index + /// within the schema. + /// + /// See also [Self::qualified_field] to get both qualifier and field + pub fn field(&self, i: usize) -> &FieldRef { &self.inner.fields[i] } - /// Returns an immutable reference of a specific `Field` instance selected using an - /// offset within the internal `fields` vector and its qualifier - pub fn qualified_field(&self, i: usize) -> (Option<&TableReference>, &Field) { + /// Returns the qualifier (if any) and [`FieldRef`] for a column at specific + /// index within the schema. 
+ pub fn qualified_field(&self, i: usize) -> (Option<&TableReference>, &FieldRef) { (self.field_qualifiers[i].as_ref(), self.field(i)) } @@ -379,12 +418,12 @@ impl DFSchema { .is_some() } - /// Find the field with the given name + /// Find the [`FieldRef`] with the given name and optional qualifier pub fn field_with_name( &self, qualifier: Option<&TableReference>, name: &str, - ) -> Result<&Field> { + ) -> Result<&FieldRef> { if let Some(qualifier) = qualifier { self.field_with_qualified_name(qualifier, name) } else { @@ -397,7 +436,7 @@ impl DFSchema { &self, qualifier: Option<&TableReference>, name: &str, - ) -> Result<(Option<&TableReference>, &Field)> { + ) -> Result<(Option<&TableReference>, &FieldRef)> { if let Some(qualifier) = qualifier { let idx = self .index_of_column_by_name(Some(qualifier), name) @@ -409,10 +448,10 @@ impl DFSchema { } /// Find all fields having the given qualifier - pub fn fields_with_qualified(&self, qualifier: &TableReference) -> Vec<&Field> { + pub fn fields_with_qualified(&self, qualifier: &TableReference) -> Vec<&FieldRef> { self.iter() .filter(|(q, _)| q.map(|q| q.eq(qualifier)).unwrap_or(false)) - .map(|(_, f)| f.as_ref()) + .map(|(_, f)| f) .collect() } @@ -428,11 +467,10 @@ impl DFSchema { } /// Find all fields that match the given name - pub fn fields_with_unqualified_name(&self, name: &str) -> Vec<&Field> { + pub fn fields_with_unqualified_name(&self, name: &str) -> Vec<&FieldRef> { self.fields() .iter() .filter(|field| field.name() == name) - .map(|f| f.as_ref()) .collect() } @@ -440,10 +478,9 @@ impl DFSchema { pub fn qualified_fields_with_unqualified_name( &self, name: &str, - ) -> Vec<(Option<&TableReference>, &Field)> { + ) -> Vec<(Option<&TableReference>, &FieldRef)> { self.iter() .filter(|(_, field)| field.name() == name) - .map(|(qualifier, field)| (qualifier, field.as_ref())) .collect() } @@ -468,7 +505,7 @@ impl DFSchema { pub fn qualified_field_with_unqualified_name( &self, name: &str, - ) -> 
Result<(Option<&TableReference>, &Field)> { + ) -> Result<(Option<&TableReference>, &FieldRef)> { let matches = self.qualified_fields_with_unqualified_name(name); match matches.len() { 0 => Err(unqualified_field_not_found(name, self)), @@ -489,7 +526,7 @@ impl DFSchema { Ok((fields_without_qualifier[0].0, fields_without_qualifier[0].1)) } else { _schema_err!(SchemaError::AmbiguousReference { - field: Column::new_unqualified(name.to_string(),), + field: Box::new(Column::new_unqualified(name.to_string())) }) } } @@ -497,7 +534,7 @@ impl DFSchema { } /// Find the field with the given name - pub fn field_with_unqualified_name(&self, name: &str) -> Result<&Field> { + pub fn field_with_unqualified_name(&self, name: &str) -> Result<&FieldRef> { self.qualified_field_with_unqualified_name(name) .map(|(_, field)| field) } @@ -507,7 +544,7 @@ impl DFSchema { &self, qualifier: &TableReference, name: &str, - ) -> Result<&Field> { + ) -> Result<&FieldRef> { let idx = self .index_of_column_by_name(Some(qualifier), name) .ok_or_else(|| field_not_found(Some(qualifier.clone()), name, self))?; @@ -519,7 +556,7 @@ impl DFSchema { pub fn qualified_field_from_column( &self, column: &Column, - ) -> Result<(Option<&TableReference>, &Field)> { + ) -> Result<(Option<&TableReference>, &FieldRef)> { self.qualified_field_with_name(column.relation.as_ref(), &column.name) } @@ -561,7 +598,7 @@ impl DFSchema { &self, arrow_schema: &Schema, ) -> Result<()> { - let self_arrow_schema: Schema = self.into(); + let self_arrow_schema = self.as_arrow(); self_arrow_schema .fields() .iter() @@ -636,8 +673,8 @@ impl DFSchema { )) { _plan_err!( - "Schema mismatch: Expected field '{}' with type {:?}, \ - but got '{}' with type {:?}.", + "Schema mismatch: Expected field '{}' with type {}, \ + but got '{}' with type {}.", f1.name(), f1.data_type(), f2.name(), @@ -661,10 +698,12 @@ impl DFSchema { // check nested fields match (dt1, dt2) { (DataType::Dictionary(_, v1), DataType::Dictionary(_, v2)) => { - 
v1.as_ref() == v2.as_ref() + Self::datatype_is_logically_equal(v1.as_ref(), v2.as_ref()) + } + (DataType::Dictionary(_, v1), othertype) + | (othertype, DataType::Dictionary(_, v1)) => { + Self::datatype_is_logically_equal(v1.as_ref(), othertype) } - (DataType::Dictionary(_, v1), othertype) => v1.as_ref() == othertype, - (othertype, DataType::Dictionary(_, v1)) => v1.as_ref() == othertype, (DataType::List(f1), DataType::List(f2)) | (DataType::LargeList(f1), DataType::LargeList(f2)) | (DataType::FixedSizeList(f1, _), DataType::FixedSizeList(f2, _)) => { @@ -714,7 +753,8 @@ impl DFSchema { } /// Returns true of two [`DataType`]s are semantically equal (same - /// name and type), ignoring both metadata and nullability, and decimal precision/scale. + /// name and type), ignoring both metadata and nullability, decimal precision/scale, + /// and timezone time units/timezones. /// /// request to upstream: pub fn datatype_is_semantically_equal(dt1: &DataType, dt2: &DataType) -> bool { @@ -765,6 +805,14 @@ impl DFSchema { .zip(iter2) .all(|((t1, f1), (t2, f2))| t1 == t2 && Self::field_is_semantically_equal(f1, f2)) } + ( + DataType::Decimal32(_l_precision, _l_scale), + DataType::Decimal32(_r_precision, _r_scale), + ) => true, + ( + DataType::Decimal64(_l_precision, _l_scale), + DataType::Decimal64(_r_precision, _r_scale), + ) => true, ( DataType::Decimal128(_l_precision, _l_scale), DataType::Decimal128(_r_precision, _r_scale), @@ -773,6 +821,10 @@ impl DFSchema { DataType::Decimal256(_l_precision, _l_scale), DataType::Decimal256(_r_precision, _r_scale), ) => true, + ( + DataType::Timestamp(_l_time_unit, _l_timezone), + DataType::Timestamp(_r_time_unit, _r_timezone), + ) => true, _ => dt1 == dt2, } } @@ -830,21 +882,216 @@ impl DFSchema { .zip(self.inner.fields().iter()) .map(|(qualifier, field)| (qualifier.as_ref(), field)) } + /// Returns a tree-like string representation of the schema. 
+ /// + /// This method formats the schema + /// with a tree-like structure showing field names, types, and nullability. + /// + /// # Example + /// + /// ``` + /// use arrow::datatypes::{DataType, Field, Schema}; + /// use datafusion_common::DFSchema; + /// use std::collections::HashMap; + /// + /// let schema = DFSchema::from_unqualified_fields( + /// vec![ + /// Field::new("id", DataType::Int32, false), + /// Field::new("name", DataType::Utf8, true), + /// ] + /// .into(), + /// HashMap::new(), + /// ) + /// .unwrap(); + /// + /// assert_eq!( + /// schema.tree_string().to_string(), + /// r#"root + /// |-- id: int32 (nullable = false) + /// |-- name: utf8 (nullable = true)"# + /// ); + /// ``` + pub fn tree_string(&self) -> impl Display + '_ { + let mut result = String::from("root\n"); + + for (qualifier, field) in self.iter() { + let field_name = match qualifier { + Some(q) => format!("{}.{}", q, field.name()), + None => field.name().to_string(), + }; + + format_field_with_indent( + &mut result, + &field_name, + field.data_type(), + field.is_nullable(), + " ", + ); + } + + // Remove the trailing newline + if result.ends_with('\n') { + result.pop(); + } + + result + } } -impl From for Schema { - /// Convert DFSchema into a Schema - fn from(df_schema: DFSchema) -> Self { - let fields: Fields = df_schema.inner.fields.clone(); - Schema::new_with_metadata(fields, df_schema.inner.metadata.clone()) +/// Format field with proper nested indentation for complex types +fn format_field_with_indent( + result: &mut String, + field_name: &str, + data_type: &DataType, + nullable: bool, + indent: &str, +) { + let nullable_str = nullable.to_string().to_lowercase(); + let child_indent = format!("{indent}| "); + + match data_type { + DataType::List(field) => { + result.push_str(&format!( + "{indent}|-- {field_name}: list (nullable = {nullable_str})\n" + )); + format_field_with_indent( + result, + field.name(), + field.data_type(), + field.is_nullable(), + &child_indent, + ); + } + 
DataType::LargeList(field) => { + result.push_str(&format!( + "{indent}|-- {field_name}: large list (nullable = {nullable_str})\n" + )); + format_field_with_indent( + result, + field.name(), + field.data_type(), + field.is_nullable(), + &child_indent, + ); + } + DataType::FixedSizeList(field, _size) => { + result.push_str(&format!( + "{indent}|-- {field_name}: fixed size list (nullable = {nullable_str})\n" + )); + format_field_with_indent( + result, + field.name(), + field.data_type(), + field.is_nullable(), + &child_indent, + ); + } + DataType::Map(field, _) => { + result.push_str(&format!( + "{indent}|-- {field_name}: map (nullable = {nullable_str})\n" + )); + if let DataType::Struct(inner_fields) = field.data_type() + && inner_fields.len() == 2 + { + format_field_with_indent( + result, + "key", + inner_fields[0].data_type(), + inner_fields[0].is_nullable(), + &child_indent, + ); + let value_contains_null = field.is_nullable().to_string().to_lowercase(); + // Handle complex value types properly + match inner_fields[1].data_type() { + DataType::Struct(_) + | DataType::List(_) + | DataType::LargeList(_) + | DataType::FixedSizeList(_, _) + | DataType::Map(_, _) => { + format_field_with_indent( + result, + "value", + inner_fields[1].data_type(), + inner_fields[1].is_nullable(), + &child_indent, + ); + } + _ => { + result.push_str(&format!("{child_indent}|-- value: {} (nullable = {value_contains_null})\n", + format_simple_data_type(inner_fields[1].data_type()))); + } + } + } + } + DataType::Struct(fields) => { + result.push_str(&format!( + "{indent}|-- {field_name}: struct (nullable = {nullable_str})\n" + )); + for struct_field in fields { + format_field_with_indent( + result, + struct_field.name(), + struct_field.data_type(), + struct_field.is_nullable(), + &child_indent, + ); + } + } + _ => { + let type_str = format_simple_data_type(data_type); + result.push_str(&format!( + "{indent}|-- {field_name}: {type_str} (nullable = {nullable_str})\n" + )); + } } } -impl 
From<&DFSchema> for Schema { - /// Convert DFSchema reference into a Schema - fn from(df_schema: &DFSchema) -> Self { - let fields: Fields = df_schema.inner.fields.clone(); - Schema::new_with_metadata(fields, df_schema.inner.metadata.clone()) +/// Format simple DataType in lowercase format (for leaf nodes) +fn format_simple_data_type(data_type: &DataType) -> String { + match data_type { + DataType::Boolean => "boolean".to_string(), + DataType::Int8 => "int8".to_string(), + DataType::Int16 => "int16".to_string(), + DataType::Int32 => "int32".to_string(), + DataType::Int64 => "int64".to_string(), + DataType::UInt8 => "uint8".to_string(), + DataType::UInt16 => "uint16".to_string(), + DataType::UInt32 => "uint32".to_string(), + DataType::UInt64 => "uint64".to_string(), + DataType::Float16 => "float16".to_string(), + DataType::Float32 => "float32".to_string(), + DataType::Float64 => "float64".to_string(), + DataType::Utf8 => "utf8".to_string(), + DataType::LargeUtf8 => "large_utf8".to_string(), + DataType::Binary => "binary".to_string(), + DataType::LargeBinary => "large_binary".to_string(), + DataType::FixedSizeBinary(_) => "fixed_size_binary".to_string(), + DataType::Date32 => "date32".to_string(), + DataType::Date64 => "date64".to_string(), + DataType::Time32(_) => "time32".to_string(), + DataType::Time64(_) => "time64".to_string(), + DataType::Timestamp(_, tz) => match tz { + Some(tz_str) => format!("timestamp ({tz_str})"), + None => "timestamp".to_string(), + }, + DataType::Interval(_) => "interval".to_string(), + DataType::Dictionary(_, value_type) => { + format_simple_data_type(value_type.as_ref()) + } + DataType::Decimal32(precision, scale) => { + format!("decimal32({precision}, {scale})") + } + DataType::Decimal64(precision, scale) => { + format!("decimal64({precision}, {scale})") + } + DataType::Decimal128(precision, scale) => { + format!("decimal128({precision}, {scale})") + } + DataType::Decimal256(precision, scale) => { + format!("decimal256({precision}, 
{scale})") + } + DataType::Null => "null".to_string(), + _ => format!("{data_type}").to_lowercase(), } } @@ -880,13 +1127,18 @@ impl TryFrom for DFSchema { field_qualifiers: vec![None; field_count], functional_dependencies: FunctionalDependencies::empty(), }; + // Without checking names, because schema here may have duplicate field names. + // For example, Partial AggregateMode will generate duplicate field names from + // state_fields. + // See + // dfschema.check_names()?; Ok(dfschema) } } impl From for SchemaRef { - fn from(df_schema: DFSchema) -> Self { - SchemaRef::new(df_schema.into()) + fn from(dfschema: DFSchema) -> Self { + Arc::clone(&dfschema.inner) } } @@ -982,7 +1234,7 @@ pub trait ExprSchema: std::fmt::Debug { } // Return the column's field - fn field_from_column(&self, col: &Column) -> Result<&Field>; + fn field_from_column(&self, col: &Column) -> Result<&FieldRef>; } // Implement `ExprSchema` for `Arc` @@ -1003,13 +1255,13 @@ impl + std::fmt::Debug> ExprSchema for P { self.as_ref().data_type_and_nullable(col) } - fn field_from_column(&self, col: &Column) -> Result<&Field> { + fn field_from_column(&self, col: &Column) -> Result<&FieldRef> { self.as_ref().field_from_column(col) } } impl ExprSchema for DFSchema { - fn field_from_column(&self, col: &Column) -> Result<&Field> { + fn field_from_column(&self, col: &Column) -> Result<&FieldRef> { match &col.relation { Some(r) => self.field_with_qualified_name(r, &col.name), None => self.field_with_unqualified_name(&col.name), @@ -1072,8 +1324,8 @@ impl SchemaExt for Schema { .try_for_each(|(f1, f2)| { if f1.name() != f2.name() || (!DFSchema::datatype_is_logically_equal(f1.data_type(), f2.data_type()) && !can_cast_types(f2.data_type(), f1.data_type())) { _plan_err!( - "Inserting query schema mismatch: Expected table field '{}' with type {:?}, \ - but got '{}' with type {:?}.", + "Inserting query schema mismatch: Expected table field '{}' with type {}, \ + but got '{}' with type {}.", f1.name(), 
f1.data_type(), f2.name(), @@ -1179,10 +1431,8 @@ mod tests { #[test] fn from_qualified_schema_into_arrow_schema() -> Result<()> { let schema = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; - let arrow_schema: Schema = schema.into(); - let expected = "Field { name: \"c0\", data_type: Boolean, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, \ - Field { name: \"c1\", data_type: Boolean, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }"; - assert_eq!(expected, arrow_schema.to_string()); + let arrow_schema = schema.as_arrow(); + insta::assert_snapshot!(arrow_schema.to_string(), @r#"Field { "c0": nullable Boolean }, Field { "c1": nullable Boolean }"#); Ok(()) } @@ -1196,12 +1446,14 @@ mod tests { join.to_string() ); // test valid access - assert!(join - .field_with_qualified_name(&TableReference::bare("t1"), "c0") - .is_ok()); - assert!(join - .field_with_qualified_name(&TableReference::bare("t2"), "c0") - .is_ok()); + assert!( + join.field_with_qualified_name(&TableReference::bare("t1"), "c0") + .is_ok() + ); + assert!( + join.field_with_qualified_name(&TableReference::bare("t2"), "c0") + .is_ok() + ); // test invalid access assert!(join.field_with_unqualified_name("c0").is_err()); assert!(join.field_with_unqualified_name("t1.c0").is_err()); @@ -1243,18 +1495,20 @@ mod tests { join.to_string() ); // test valid access - assert!(join - .field_with_qualified_name(&TableReference::bare("t1"), "c0") - .is_ok()); + assert!( + join.field_with_qualified_name(&TableReference::bare("t1"), "c0") + .is_ok() + ); assert!(join.field_with_unqualified_name("c0").is_ok()); assert!(join.field_with_unqualified_name("c100").is_ok()); assert!(join.field_with_name(None, "c100").is_ok()); // test invalid access assert!(join.field_with_unqualified_name("t1.c0").is_err()); assert!(join.field_with_unqualified_name("t1.c100").is_err()); - assert!(join - .field_with_qualified_name(&TableReference::bare(""), "c100") - .is_err()); + assert!( + 
join.field_with_qualified_name(&TableReference::bare(""), "c100") + .is_err() + ); Ok(()) } @@ -1263,9 +1517,11 @@ mod tests { let left = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; let right = DFSchema::try_from(test_schema_1())?; let join = left.join(&right); - assert_contains!(join.unwrap_err().to_string(), - "Schema error: Schema contains qualified \ - field name t1.c0 and unqualified field name c0 which would be ambiguous"); + assert_contains!( + join.unwrap_err().to_string(), + "Schema error: Schema contains qualified \ + field name t1.c0 and unqualified field name c0 which would be ambiguous" + ); Ok(()) } @@ -1544,6 +1800,27 @@ mod tests { &DataType::Utf8, &DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)) )); + + // Dictionary is logically equal to the logically equivalent value type + assert!(DFSchema::datatype_is_logically_equal( + &DataType::Utf8View, + &DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)) + )); + + assert!(DFSchema::datatype_is_logically_equal( + &DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::List( + Field::new("element", DataType::Utf8, false).into() + )) + ), + &DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::List( + Field::new("element", DataType::Utf8View, false).into() + )) + ) + )); } #[test] @@ -1558,6 +1835,36 @@ mod tests { &DataType::Int16 )); + // Succeeds if decimal precision and scale are different + assert!(DFSchema::datatype_is_semantically_equal( + &DataType::Decimal32(1, 2), + &DataType::Decimal32(2, 1), + )); + + assert!(DFSchema::datatype_is_semantically_equal( + &DataType::Decimal64(1, 2), + &DataType::Decimal64(2, 1), + )); + + assert!(DFSchema::datatype_is_semantically_equal( + &DataType::Decimal128(1, 2), + &DataType::Decimal128(2, 1), + )); + + assert!(DFSchema::datatype_is_semantically_equal( + &DataType::Decimal256(1, 2), + &DataType::Decimal256(2, 1), + )); + + // Any two timestamp types should 
match + assert!(DFSchema::datatype_is_semantically_equal( + &DataType::Timestamp( + arrow::datatypes::TimeUnit::Microsecond, + Some("UTC".into()) + ), + &DataType::Timestamp(arrow::datatypes::TimeUnit::Millisecond, None), + )); + // Test lists // Succeeds if both have the same element type, disregards names and nullability @@ -1700,4 +2007,488 @@ mod tests { fn test_metadata_n(n: usize) -> HashMap { (0..n).map(|i| (format!("k{i}"), format!("v{i}"))).collect() } + + #[test] + fn test_print_schema_unqualified() { + let schema = DFSchema::from_unqualified_fields( + vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + Field::new("age", DataType::Int64, true), + Field::new("active", DataType::Boolean, false), + ] + .into(), + HashMap::new(), + ) + .unwrap(); + + let output = schema.tree_string(); + + insta::assert_snapshot!(output, @r" + root + |-- id: int32 (nullable = false) + |-- name: utf8 (nullable = true) + |-- age: int64 (nullable = true) + |-- active: boolean (nullable = false) + "); + } + + #[test] + fn test_print_schema_qualified() { + let schema = DFSchema::try_from_qualified_schema( + "table1", + &Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + ]), + ) + .unwrap(); + + let output = schema.tree_string(); + + insta::assert_snapshot!(output, @r" + root + |-- table1.id: int32 (nullable = false) + |-- table1.name: utf8 (nullable = true) + "); + } + + #[test] + fn test_print_schema_complex_types() { + let struct_field = Field::new( + "address", + DataType::Struct(Fields::from(vec![ + Field::new("street", DataType::Utf8, true), + Field::new("city", DataType::Utf8, true), + ])), + true, + ); + + let list_field = Field::new( + "tags", + DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))), + true, + ); + + let schema = DFSchema::from_unqualified_fields( + vec![ + Field::new("id", DataType::Int32, false), + struct_field, + list_field, + 
Field::new("score", DataType::Decimal128(10, 2), true), + ] + .into(), + HashMap::new(), + ) + .unwrap(); + + let output = schema.tree_string(); + insta::assert_snapshot!(output, @r" + root + |-- id: int32 (nullable = false) + |-- address: struct (nullable = true) + | |-- street: utf8 (nullable = true) + | |-- city: utf8 (nullable = true) + |-- tags: list (nullable = true) + | |-- item: utf8 (nullable = true) + |-- score: decimal128(10, 2) (nullable = true) + "); + } + + #[test] + fn test_print_schema_empty() { + let schema = DFSchema::empty(); + let output = schema.tree_string(); + insta::assert_snapshot!(output, @"root"); + } + + #[test] + fn test_print_schema_deeply_nested_types() { + // Create a deeply nested structure to test indentation and complex type formatting + let inner_struct = Field::new( + "inner", + DataType::Struct(Fields::from(vec![ + Field::new("level1", DataType::Utf8, true), + Field::new("level2", DataType::Int32, false), + ])), + true, + ); + + let nested_list = Field::new( + "nested_list", + DataType::List(Arc::new(Field::new( + "item", + DataType::Struct(Fields::from(vec![ + Field::new("id", DataType::Int64, false), + Field::new("value", DataType::Float64, true), + ])), + true, + ))), + true, + ); + + let map_field = Field::new( + "map_data", + DataType::Map( + Arc::new(Field::new( + "entries", + DataType::Struct(Fields::from(vec![ + Field::new("key", DataType::Utf8, false), + Field::new( + "value", + DataType::List(Arc::new(Field::new( + "item", + DataType::Int32, + true, + ))), + true, + ), + ])), + false, + )), + false, + ), + true, + ); + + let schema = DFSchema::from_unqualified_fields( + vec![ + Field::new("simple_field", DataType::Utf8, true), + inner_struct, + nested_list, + map_field, + Field::new( + "timestamp_field", + DataType::Timestamp( + arrow::datatypes::TimeUnit::Microsecond, + Some("UTC".into()), + ), + false, + ), + ] + .into(), + HashMap::new(), + ) + .unwrap(); + + let output = schema.tree_string(); + + 
insta::assert_snapshot!(output, @r" + root + |-- simple_field: utf8 (nullable = true) + |-- inner: struct (nullable = true) + | |-- level1: utf8 (nullable = true) + | |-- level2: int32 (nullable = false) + |-- nested_list: list (nullable = true) + | |-- item: struct (nullable = true) + | | |-- id: int64 (nullable = false) + | | |-- value: float64 (nullable = true) + |-- map_data: map (nullable = true) + | |-- key: utf8 (nullable = false) + | |-- value: list (nullable = true) + | | |-- item: int32 (nullable = true) + |-- timestamp_field: timestamp (UTC) (nullable = false) + "); + } + + #[test] + fn test_print_schema_mixed_qualified_unqualified() { + // Test a schema with mixed qualified and unqualified fields + let schema = DFSchema::new_with_metadata( + vec![ + ( + Some("table1".into()), + Arc::new(Field::new("id", DataType::Int32, false)), + ), + (None, Arc::new(Field::new("name", DataType::Utf8, true))), + ( + Some("table2".into()), + Arc::new(Field::new("score", DataType::Float64, true)), + ), + ( + None, + Arc::new(Field::new("active", DataType::Boolean, false)), + ), + ], + HashMap::new(), + ) + .unwrap(); + + let output = schema.tree_string(); + + insta::assert_snapshot!(output, @r" + root + |-- table1.id: int32 (nullable = false) + |-- name: utf8 (nullable = true) + |-- table2.score: float64 (nullable = true) + |-- active: boolean (nullable = false) + "); + } + + #[test] + fn test_print_schema_array_of_map() { + // Test the specific example from user feedback: array of map + let map_field = Field::new( + "entries", + DataType::Struct(Fields::from(vec![ + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Utf8, false), + ])), + false, + ); + + let array_of_map_field = Field::new( + "array_map_field", + DataType::List(Arc::new(Field::new( + "item", + DataType::Map(Arc::new(map_field), false), + false, + ))), + false, + ); + + let schema = DFSchema::from_unqualified_fields( + vec![array_of_map_field].into(), + HashMap::new(), + ) + 
.unwrap(); + + let output = schema.tree_string(); + + insta::assert_snapshot!(output, @r" + root + |-- array_map_field: list (nullable = false) + | |-- item: map (nullable = false) + | | |-- key: utf8 (nullable = false) + | | |-- value: utf8 (nullable = false) + "); + } + + #[test] + fn test_print_schema_complex_type_combinations() { + // Test various combinations of list, struct, and map types + + // List of structs + let list_of_structs = Field::new( + "list_of_structs", + DataType::List(Arc::new(Field::new( + "item", + DataType::Struct(Fields::from(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + Field::new("score", DataType::Float64, true), + ])), + true, + ))), + true, + ); + + // Struct containing lists + let struct_with_lists = Field::new( + "struct_with_lists", + DataType::Struct(Fields::from(vec![ + Field::new( + "tags", + DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))), + true, + ), + Field::new( + "scores", + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + false, + ), + Field::new("metadata", DataType::Utf8, true), + ])), + false, + ); + + // Map with struct values + let map_with_struct_values = Field::new( + "map_with_struct_values", + DataType::Map( + Arc::new(Field::new( + "entries", + DataType::Struct(Fields::from(vec![ + Field::new("key", DataType::Utf8, false), + Field::new( + "value", + DataType::Struct(Fields::from(vec![ + Field::new("count", DataType::Int64, false), + Field::new("active", DataType::Boolean, true), + ])), + true, + ), + ])), + false, + )), + false, + ), + true, + ); + + // List of maps + let list_of_maps = Field::new( + "list_of_maps", + DataType::List(Arc::new(Field::new( + "item", + DataType::Map( + Arc::new(Field::new( + "entries", + DataType::Struct(Fields::from(vec![ + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Int32, true), + ])), + false, + )), + false, + ), + true, + ))), + true, + ); + + // Deeply 
nested: struct containing list of structs containing maps + let deeply_nested = Field::new( + "deeply_nested", + DataType::Struct(Fields::from(vec![ + Field::new("level1", DataType::Utf8, true), + Field::new( + "level2", + DataType::List(Arc::new(Field::new( + "item", + DataType::Struct(Fields::from(vec![ + Field::new("id", DataType::Int32, false), + Field::new( + "properties", + DataType::Map( + Arc::new(Field::new( + "entries", + DataType::Struct(Fields::from(vec![ + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Float64, true), + ])), + false, + )), + false, + ), + true, + ), + ])), + true, + ))), + false, + ), + ])), + true, + ); + + let schema = DFSchema::from_unqualified_fields( + vec![ + list_of_structs, + struct_with_lists, + map_with_struct_values, + list_of_maps, + deeply_nested, + ] + .into(), + HashMap::new(), + ) + .unwrap(); + + let output = schema.tree_string(); + + insta::assert_snapshot!(output, @r" + root + |-- list_of_structs: list (nullable = true) + | |-- item: struct (nullable = true) + | | |-- id: int32 (nullable = false) + | | |-- name: utf8 (nullable = true) + | | |-- score: float64 (nullable = true) + |-- struct_with_lists: struct (nullable = false) + | |-- tags: list (nullable = true) + | | |-- item: utf8 (nullable = true) + | |-- scores: list (nullable = false) + | | |-- item: int32 (nullable = true) + | |-- metadata: utf8 (nullable = true) + |-- map_with_struct_values: map (nullable = true) + | |-- key: utf8 (nullable = false) + | |-- value: struct (nullable = true) + | | |-- count: int64 (nullable = false) + | | |-- active: boolean (nullable = true) + |-- list_of_maps: list (nullable = true) + | |-- item: map (nullable = true) + | | |-- key: utf8 (nullable = false) + | | |-- value: int32 (nullable = false) + |-- deeply_nested: struct (nullable = true) + | |-- level1: utf8 (nullable = true) + | |-- level2: list (nullable = false) + | | |-- item: struct (nullable = true) + | | | |-- id: int32 (nullable = false) 
+ | | | |-- properties: map (nullable = true) + | | | | |-- key: utf8 (nullable = false) + | | | | |-- value: float64 (nullable = false) + "); + } + + #[test] + fn test_print_schema_edge_case_types() { + // Test edge cases and special types + let schema = DFSchema::from_unqualified_fields( + vec![ + Field::new("null_field", DataType::Null, true), + Field::new("binary_field", DataType::Binary, false), + Field::new("large_binary", DataType::LargeBinary, true), + Field::new("large_utf8", DataType::LargeUtf8, false), + Field::new("fixed_size_binary", DataType::FixedSizeBinary(16), true), + Field::new( + "fixed_size_list", + DataType::FixedSizeList( + Arc::new(Field::new("item", DataType::Int32, true)), + 5, + ), + false, + ), + Field::new("decimal32", DataType::Decimal32(9, 4), true), + Field::new("decimal64", DataType::Decimal64(9, 4), true), + Field::new("decimal128", DataType::Decimal128(18, 4), true), + Field::new("decimal256", DataType::Decimal256(38, 10), false), + Field::new("date32", DataType::Date32, true), + Field::new("date64", DataType::Date64, false), + Field::new( + "time32_seconds", + DataType::Time32(arrow::datatypes::TimeUnit::Second), + true, + ), + Field::new( + "time64_nanoseconds", + DataType::Time64(arrow::datatypes::TimeUnit::Nanosecond), + false, + ), + ] + .into(), + HashMap::new(), + ) + .unwrap(); + + let output = schema.tree_string(); + + insta::assert_snapshot!(output, @r" + root + |-- null_field: null (nullable = true) + |-- binary_field: binary (nullable = false) + |-- large_binary: large_binary (nullable = true) + |-- large_utf8: large_utf8 (nullable = false) + |-- fixed_size_binary: fixed_size_binary (nullable = true) + |-- fixed_size_list: fixed size list (nullable = false) + | |-- item: int32 (nullable = true) + |-- decimal32: decimal32(9, 4) (nullable = true) + |-- decimal64: decimal64(9, 4) (nullable = true) + |-- decimal128: decimal128(18, 4) (nullable = true) + |-- decimal256: decimal256(38, 10) (nullable = false) + |-- date32: 
date32 (nullable = true) + |-- date64: date64 (nullable = false) + |-- time32_seconds: time32 (nullable = true) + |-- time64_nanoseconds: time64 (nullable = false) + "); + } } diff --git a/datafusion/common/src/diagnostic.rs b/datafusion/common/src/diagnostic.rs index 0dce8e6a56eca..b25bf1c12e44a 100644 --- a/datafusion/common/src/diagnostic.rs +++ b/datafusion/common/src/diagnostic.rs @@ -30,8 +30,11 @@ use crate::Span; /// ```rust /// # use datafusion_common::{Location, Span, Diagnostic}; /// let span = Some(Span { -/// start: Location{ line: 2, column: 1 }, -/// end: Location{ line: 4, column: 15 } +/// start: Location { line: 2, column: 1 }, +/// end: Location { +/// line: 4, +/// column: 15, +/// }, /// }); /// let diagnostic = Diagnostic::new_error("Something went wrong", span) /// .with_help("Have you tried turning it on and off again?", None); diff --git a/datafusion/common/src/display/human_readable.rs b/datafusion/common/src/display/human_readable.rs new file mode 100644 index 0000000000000..0e0d677bd8904 --- /dev/null +++ b/datafusion/common/src/display/human_readable.rs @@ -0,0 +1,139 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Helpers for rendering sizes, counts, and durations in human readable form. + +/// Common data size units +pub mod units { + pub const TB: u64 = 1 << 40; + pub const GB: u64 = 1 << 30; + pub const MB: u64 = 1 << 20; + pub const KB: u64 = 1 << 10; +} + +/// Present size in human-readable form +pub fn human_readable_size(size: usize) -> String { + use units::*; + + let size = size as u64; + let (value, unit) = { + if size >= 2 * TB { + (size as f64 / TB as f64, "TB") + } else if size >= 2 * GB { + (size as f64 / GB as f64, "GB") + } else if size >= 2 * MB { + (size as f64 / MB as f64, "MB") + } else if size >= 2 * KB { + (size as f64 / KB as f64, "KB") + } else { + (size as f64, "B") + } + }; + format!("{value:.1} {unit}") +} + +/// Present count in human-readable form with K, M, B, T suffixes +pub fn human_readable_count(count: usize) -> String { + let count = count as u64; + let (value, unit) = { + if count >= 1_000_000_000_000 { + (count as f64 / 1_000_000_000_000.0, " T") + } else if count >= 1_000_000_000 { + (count as f64 / 1_000_000_000.0, " B") + } else if count >= 1_000_000 { + (count as f64 / 1_000_000.0, " M") + } else if count >= 1_000 { + (count as f64 / 1_000.0, " K") + } else { + return count.to_string(); + } + }; + + // Format with appropriate precision + // For values >= 100, show 1 decimal place (e.g., 123.4 K) + // For values < 100, show 2 decimal places (e.g., 10.12 K) + if value >= 100.0 { + format!("{value:.1}{unit}") + } else { + format!("{value:.2}{unit}") + } +} + +/// Present duration in human-readable form with 2 decimal places +pub fn human_readable_duration(nanos: u64) -> String { + const NANOS_PER_SEC: f64 = 1_000_000_000.0; + const NANOS_PER_MILLI: f64 = 1_000_000.0; + const NANOS_PER_MICRO: f64 = 1_000.0; + + let nanos_f64 = nanos as f64; + + if nanos >= 1_000_000_000 { + // >= 1 second: show in seconds + format!("{:.2}s", nanos_f64 / NANOS_PER_SEC) + } else if nanos >= 1_000_000 { + // >= 1 millisecond: show in milliseconds + 
format!("{:.2}ms", nanos_f64 / NANOS_PER_MILLI) + } else if nanos >= 1_000 { + // >= 1 microsecond: show in microseconds + format!("{:.2}µs", nanos_f64 / NANOS_PER_MICRO) + } else { + // < 1 microsecond: show in nanoseconds + format!("{nanos}ns") + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_human_readable_count() { + assert_eq!(human_readable_count(0), "0"); + assert_eq!(human_readable_count(1), "1"); + assert_eq!(human_readable_count(999), "999"); + assert_eq!(human_readable_count(1_000), "1.00 K"); + assert_eq!(human_readable_count(10_100), "10.10 K"); + assert_eq!(human_readable_count(1_532), "1.53 K"); + assert_eq!(human_readable_count(99_999), "100.00 K"); + assert_eq!(human_readable_count(1_000_000), "1.00 M"); + assert_eq!(human_readable_count(1_532_000), "1.53 M"); + assert_eq!(human_readable_count(99_000_000), "99.00 M"); + assert_eq!(human_readable_count(123_456_789), "123.5 M"); + assert_eq!(human_readable_count(1_000_000_000), "1.00 B"); + assert_eq!(human_readable_count(1_532_000_000), "1.53 B"); + assert_eq!(human_readable_count(999_999_999_999), "1000.0 B"); + assert_eq!(human_readable_count(1_000_000_000_000), "1.00 T"); + assert_eq!(human_readable_count(42_000_000_000_000), "42.00 T"); + } + + #[test] + fn test_human_readable_duration() { + assert_eq!(human_readable_duration(0), "0ns"); + assert_eq!(human_readable_duration(1), "1ns"); + assert_eq!(human_readable_duration(999), "999ns"); + assert_eq!(human_readable_duration(1_000), "1.00µs"); + assert_eq!(human_readable_duration(1_234), "1.23µs"); + assert_eq!(human_readable_duration(999_999), "1000.00µs"); + assert_eq!(human_readable_duration(1_000_000), "1.00ms"); + assert_eq!(human_readable_duration(11_295_377), "11.30ms"); + assert_eq!(human_readable_duration(1_234_567), "1.23ms"); + assert_eq!(human_readable_duration(999_999_999), "1000.00ms"); + assert_eq!(human_readable_duration(1_000_000_000), "1.00s"); + assert_eq!(human_readable_duration(1_234_567_890), 
"1.23s"); + assert_eq!(human_readable_duration(42_000_000_000), "42.00s"); + } +} diff --git a/datafusion/common/src/display/mod.rs b/datafusion/common/src/display/mod.rs index bad51c45f8ee8..a6a97b243f06a 100644 --- a/datafusion/common/src/display/mod.rs +++ b/datafusion/common/src/display/mod.rs @@ -18,6 +18,7 @@ //! Types for plan display mod graphviz; +pub mod human_readable; pub use graphviz::*; use std::{ diff --git a/datafusion/common/src/encryption.rs b/datafusion/common/src/encryption.rs new file mode 100644 index 0000000000000..2a8cfdbc89966 --- /dev/null +++ b/datafusion/common/src/encryption.rs @@ -0,0 +1,33 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Support optional features for encryption in Parquet files. +//! This module provides types and functions related to encryption in Parquet files. 
+ +#[cfg(feature = "parquet_encryption")] +pub use parquet::encryption::decrypt::FileDecryptionProperties; +#[cfg(feature = "parquet_encryption")] +pub use parquet::encryption::encrypt::FileEncryptionProperties; + +#[cfg(not(feature = "parquet_encryption"))] +#[derive(Default, Clone, Debug)] +pub struct FileDecryptionProperties; +#[cfg(not(feature = "parquet_encryption"))] +#[derive(Default, Clone, Debug)] +pub struct FileEncryptionProperties; + +pub use crate::config::{ConfigFileDecryptionProperties, ConfigFileEncryptionProperties}; diff --git a/datafusion/common/src/error.rs b/datafusion/common/src/error.rs index b4a537fdce7ee..b7a30f868a02b 100644 --- a/datafusion/common/src/error.rs +++ b/datafusion/common/src/error.rs @@ -15,7 +15,25 @@ // specific language governing permissions and limitations // under the License. -//! DataFusion error types +//! # Error Handling in DataFusion +//! +//! In DataFusion, there are two types of errors that can be raised: +//! +//! 1. Expected errors – These indicate invalid operations performed by the caller, +//! such as attempting to open a non-existent file. Different categories exist to +//! distinguish their sources (e.g., [`DataFusionError::ArrowError`], +//! [`DataFusionError::IoError`], etc.). +//! +//! 2. Unexpected errors – Represented by [`DataFusionError::Internal`], these +//! indicate that an internal invariant has been broken, suggesting a potential +//! bug in the system. +//! +//! There are several convenient macros for throwing errors. For example, use +//! `exec_err!` for expected errors. +//! For invariant checks, you can use `assert_or_internal_err!`, +//! `assert_eq_or_internal_err!`, `assert_ne_or_internal_err!` for easier assertions. +//! On the performance-critical path, use `debug_assert!` instead to reduce overhead. 
+ #[cfg(feature = "backtrace")] use std::backtrace::{Backtrace, BacktraceStatus}; @@ -35,6 +53,7 @@ use apache_avro::Error as AvroError; use arrow::error::ArrowError; #[cfg(feature = "parquet")] use parquet::errors::ParquetError; +#[cfg(feature = "sql")] use sqlparser::parser::ParserError; use tokio::task::JoinError; @@ -53,22 +72,23 @@ pub enum DataFusionError { /// Error returned by arrow. /// /// 2nd argument is for optional backtrace - ArrowError(ArrowError, Option), + ArrowError(Box, Option), /// Error when reading / writing Parquet data. #[cfg(feature = "parquet")] - ParquetError(ParquetError), + ParquetError(Box), /// Error when reading Avro data. #[cfg(feature = "avro")] AvroError(Box), /// Error when reading / writing to / from an object_store (e.g. S3 or LocalFile) #[cfg(feature = "object_store")] - ObjectStore(object_store::Error), + ObjectStore(Box), /// Error when an I/O operation fails IoError(io::Error), /// Error when SQL is syntactically incorrect. /// /// 2nd argument is for optional backtrace - SQL(ParserError, Option), + #[cfg(feature = "sql")] + SQL(Box, Option), /// Error when a feature is not yet implemented. /// /// These errors are sometimes returned for features that are still in @@ -107,7 +127,7 @@ pub enum DataFusionError { /// /// 2nd argument is for optional backtrace /// Boxing the optional backtrace to prevent - SchemaError(SchemaError, Box>), + SchemaError(Box, Box>), /// Error during execution of the query. /// /// This error is returned when an error happens during execution due to a @@ -118,7 +138,7 @@ pub enum DataFusionError { /// [`JoinError`] during execution of the query. /// /// This error can't occur for unjoined tasks, such as execution shutdown. - ExecutionJoin(JoinError), + ExecutionJoin(Box), /// Error when resources (such as memory of scratch disk space) are exhausted. 
/// /// This error is thrown when a consumer cannot acquire additional memory @@ -151,6 +171,10 @@ pub enum DataFusionError { /// to multiple receivers. For example, when the source of a repartition /// errors and the error is propagated to multiple consumers. Shared(Arc), + /// An error that originated during a foreign function interface call. + /// Transferring errors across the FFI boundary is difficult, so the original + /// error will be converted to a string. + Ffi(String), } #[macro_export] @@ -164,7 +188,7 @@ macro_rules! context { #[derive(Debug)] pub enum SchemaError { /// Schema contains a (possibly) qualified and unqualified field with same unqualified name - AmbiguousReference { field: Column }, + AmbiguousReference { field: Box }, /// Schema contains duplicate qualified field name DuplicateQualifiedField { qualifier: Box, @@ -276,14 +300,14 @@ impl From for DataFusionError { impl From for DataFusionError { fn from(e: ArrowError) -> Self { - DataFusionError::ArrowError(e, None) + DataFusionError::ArrowError(Box::new(e), Some(DataFusionError::get_back_trace())) } } impl From for ArrowError { fn from(e: DataFusionError) -> Self { match e { - DataFusionError::ArrowError(e, _) => e, + DataFusionError::ArrowError(e, _) => *e, DataFusionError::External(e) => ArrowError::ExternalError(e), other => ArrowError::ExternalError(Box::new(other)), } @@ -304,7 +328,7 @@ impl From<&Arc> for DataFusionError { #[cfg(feature = "parquet")] impl From for DataFusionError { fn from(e: ParquetError) -> Self { - DataFusionError::ParquetError(e) + DataFusionError::ParquetError(Box::new(e)) } } @@ -318,20 +342,21 @@ impl From for DataFusionError { #[cfg(feature = "object_store")] impl From for DataFusionError { fn from(e: object_store::Error) -> Self { - DataFusionError::ObjectStore(e) + DataFusionError::ObjectStore(Box::new(e)) } } #[cfg(feature = "object_store")] impl From for DataFusionError { fn from(e: object_store::path::Error) -> Self { - 
DataFusionError::ObjectStore(e.into()) + DataFusionError::ObjectStore(Box::new(e.into())) } } +#[cfg(feature = "sql")] impl From for DataFusionError { fn from(e: ParserError) -> Self { - DataFusionError::SQL(e, None) + DataFusionError::SQL(Box::new(e), None) } } @@ -361,22 +386,23 @@ impl Display for DataFusionError { impl Error for DataFusionError { fn source(&self) -> Option<&(dyn Error + 'static)> { match self { - DataFusionError::ArrowError(e, _) => Some(e), + DataFusionError::ArrowError(e, _) => Some(e.as_ref()), #[cfg(feature = "parquet")] - DataFusionError::ParquetError(e) => Some(e), + DataFusionError::ParquetError(e) => Some(e.as_ref()), #[cfg(feature = "avro")] - DataFusionError::AvroError(e) => Some(e), + DataFusionError::AvroError(e) => Some(e.as_ref()), #[cfg(feature = "object_store")] - DataFusionError::ObjectStore(e) => Some(e), + DataFusionError::ObjectStore(e) => Some(e.as_ref()), DataFusionError::IoError(e) => Some(e), - DataFusionError::SQL(e, _) => Some(e), + #[cfg(feature = "sql")] + DataFusionError::SQL(e, _) => Some(e.as_ref()), DataFusionError::NotImplemented(_) => None, DataFusionError::Internal(_) => None, DataFusionError::Configuration(_) => None, DataFusionError::Plan(_) => None, - DataFusionError::SchemaError(e, _) => Some(e), + DataFusionError::SchemaError(e, _) => Some(e.as_ref()), DataFusionError::Execution(_) => None, - DataFusionError::ExecutionJoin(e) => Some(e), + DataFusionError::ExecutionJoin(e) => Some(e.as_ref()), DataFusionError::ResourcesExhausted(_) => None, DataFusionError::External(e) => Some(e.as_ref()), DataFusionError::Context(_, e) => Some(e.as_ref()), @@ -391,6 +417,7 @@ impl Error for DataFusionError { // can't be executed. 
DataFusionError::Collection(errs) => errs.first().map(|e| e as &dyn Error), DataFusionError::Shared(e) => Some(e.as_ref()), + DataFusionError::Ffi(_) => None, } } } @@ -451,12 +478,13 @@ impl DataFusionError { /// If backtrace enabled then error has a format "message" [`Self::BACK_TRACE_SEP`] "backtrace" /// The method strips the backtrace and outputs "message" pub fn strip_backtrace(&self) -> String { - self.to_string() + (*self + .to_string() .split(Self::BACK_TRACE_SEP) .collect::>() .first() - .unwrap_or(&"") - .to_string() + .unwrap_or(&"")) + .to_string() } /// To enable optional rust backtrace in DataFusion: @@ -497,6 +525,7 @@ impl DataFusionError { #[cfg(feature = "object_store")] DataFusionError::ObjectStore(_) => "Object Store error: ", DataFusionError::IoError(_) => "IO error: ", + #[cfg(feature = "sql")] DataFusionError::SQL(_, _) => "SQL error: ", DataFusionError::NotImplemented(_) => { "This feature is not implemented: " @@ -520,10 +549,11 @@ impl DataFusionError { errs.first().expect("cannot construct DataFusionError::Collection with 0 errors, but got one such case").error_prefix() } DataFusionError::Shared(_) => "", + DataFusionError::Ffi(_) => "FFI error: ", } } - pub fn message(&self) -> Cow { + pub fn message(&self) -> Cow<'_, str> { match *self { DataFusionError::ArrowError(ref desc, ref backtrace) => { let backtrace = backtrace.clone().unwrap_or_else(|| "".to_owned()); @@ -534,6 +564,7 @@ impl DataFusionError { #[cfg(feature = "avro")] DataFusionError::AvroError(ref desc) => Cow::Owned(desc.to_string()), DataFusionError::IoError(ref desc) => Cow::Owned(desc.to_string()), + #[cfg(feature = "sql")] DataFusionError::SQL(ref desc, ref backtrace) => { let backtrace: String = backtrace.clone().unwrap_or_else(|| "".to_owned()); @@ -542,8 +573,9 @@ impl DataFusionError { DataFusionError::Configuration(ref desc) => Cow::Owned(desc.to_string()), DataFusionError::NotImplemented(ref desc) => Cow::Owned(desc.to_string()), DataFusionError::Internal(ref 
desc) => Cow::Owned(format!( - "{desc}.\nThis was likely caused by a bug in DataFusion's \ - code and we would welcome that you file an bug report in our issue tracker" + "{desc}.\nThis issue was likely caused by a bug in DataFusion's code. \ + Please help us to resolve this by filing a bug report in our issue tracker: \ + https://github.com/apache/datafusion/issues" )), DataFusionError::Plan(ref desc) => Cow::Owned(desc.to_string()), DataFusionError::SchemaError(ref desc, ref backtrace) => { @@ -570,6 +602,7 @@ impl DataFusionError { .expect("cannot construct DataFusionError::Collection with 0 errors") .message(), DataFusionError::Shared(ref desc) => Cow::Owned(desc.to_string()), + DataFusionError::Ffi(ref desc) => Cow::Owned(desc.to_string()), } } @@ -676,7 +709,10 @@ impl DataFusionError { /// let mut builder = DataFusionError::builder(); /// builder.add_error(DataFusionError::Internal("foo".to_owned())); /// // ok_or returns the value if no errors have been added -/// assert_contains!(builder.error_or(42).unwrap_err().to_string(), "Internal error: foo"); +/// assert_contains!( +/// builder.error_or(42).unwrap_err().to_string(), +/// "Internal error: foo" +/// ); /// ``` #[derive(Debug, Default)] pub struct DataFusionErrorBuilder(Vec); @@ -694,7 +730,10 @@ impl DataFusionErrorBuilder { /// # use datafusion_common::{assert_contains, DataFusionError}; /// let mut builder = DataFusionError::builder(); /// builder.add_error(DataFusionError::Internal("foo".to_owned())); - /// assert_contains!(builder.error_or(42).unwrap_err().to_string(), "Internal error: foo"); + /// assert_contains!( + /// builder.error_or(42).unwrap_err().to_string(), + /// "Internal error: foo" + /// ); /// ``` pub fn add_error(&mut self, error: DataFusionError) { self.0.push(error); @@ -706,8 +745,11 @@ impl DataFusionErrorBuilder { /// ``` /// # use datafusion_common::{assert_contains, DataFusionError}; /// let builder = DataFusionError::builder() - /// 
.with_error(DataFusionError::Internal("foo".to_owned())); - /// assert_contains!(builder.error_or(42).unwrap_err().to_string(), "Internal error: foo"); + /// .with_error(DataFusionError::Internal("foo".to_owned())); + /// assert_contains!( + /// builder.error_or(42).unwrap_err().to_string(), + /// "Internal error: foo" + /// ); /// ``` pub fn with_error(mut self, error: DataFusionError) -> Self { self.0.push(error); @@ -733,7 +775,7 @@ impl DataFusionErrorBuilder { macro_rules! unwrap_or_internal_err { ($Value: ident) => { $Value.ok_or_else(|| { - DataFusionError::Internal(format!( + $crate::DataFusionError::Internal(format!( "{} should not be None", stringify!($Value) )) @@ -741,6 +783,116 @@ macro_rules! unwrap_or_internal_err { }; } +/// Assert a condition, returning `DataFusionError::Internal` on failure. +/// +/// # Examples +/// +/// ```text +/// assert_or_internal_err!(predicate); +/// assert_or_internal_err!(predicate, "human readable message"); +/// assert_or_internal_err!(predicate, format!("details: {}", value)); +/// ``` +#[macro_export] +macro_rules! assert_or_internal_err { + ($cond:expr) => { + if !$cond { + return Err($crate::DataFusionError::Internal(format!( + "Assertion failed: {}", + stringify!($cond) + ))); + } + }; + ($cond:expr, $($arg:tt)+) => { + if !$cond { + return Err($crate::DataFusionError::Internal(format!( + "Assertion failed: {}: {}", + stringify!($cond), + format!($($arg)+) + ))); + } + }; +} + +/// Assert equality, returning `DataFusionError::Internal` on failure. +/// +/// # Examples +/// +/// ```text +/// assert_eq_or_internal_err!(actual, expected); +/// assert_eq_or_internal_err!(left_expr, right_expr, "values must match"); +/// assert_eq_or_internal_err!(lhs, rhs, "metadata: {}", extra); +/// ``` +#[macro_export] +macro_rules! assert_eq_or_internal_err { + ($left:expr, $right:expr $(,)?) 
=> {{ + let left_val = &$left; + let right_val = &$right; + if left_val != right_val { + return Err($crate::DataFusionError::Internal(format!( + "Assertion failed: {} == {} (left: {:?}, right: {:?})", + stringify!($left), + stringify!($right), + left_val, + right_val + ))); + } + }}; + ($left:expr, $right:expr, $($arg:tt)+) => {{ + let left_val = &$left; + let right_val = &$right; + if left_val != right_val { + return Err($crate::DataFusionError::Internal(format!( + "Assertion failed: {} == {} (left: {:?}, right: {:?}): {}", + stringify!($left), + stringify!($right), + left_val, + right_val, + format!($($arg)+) + ))); + } + }}; +} + +/// Assert inequality, returning `DataFusionError::Internal` on failure. +/// +/// # Examples +/// +/// ```text +/// assert_ne_or_internal_err!(left, right); +/// assert_ne_or_internal_err!(lhs_expr, rhs_expr, "values must differ"); +/// assert_ne_or_internal_err!(a, b, "context {}", info); +/// ``` +#[macro_export] +macro_rules! assert_ne_or_internal_err { + ($left:expr, $right:expr $(,)?) => {{ + let left_val = &$left; + let right_val = &$right; + if left_val == right_val { + return Err($crate::DataFusionError::Internal(format!( + "Assertion failed: {} != {} (left: {:?}, right: {:?})", + stringify!($left), + stringify!($right), + left_val, + right_val + ))); + } + }}; + ($left:expr, $right:expr, $($arg:tt)+) => {{ + let left_val = &$left; + let right_val = &$right; + if left_val == right_val { + return Err($crate::DataFusionError::Internal(format!( + "Assertion failed: {} != {} (left: {:?}, right: {:?}): {}", + stringify!($left), + stringify!($right), + left_val, + right_val, + format!($($arg)+) + ))); + } + }}; +} + /// Add a macros for concise DataFusionError::* errors declaration /// supports placeholders the same way as `format!` /// Examples: @@ -751,84 +903,131 @@ macro_rules! 
unwrap_or_internal_err { /// plan_err!("Error {val:?}") /// /// `NAME_ERR` - macro name for wrapping Err(DataFusionError::*) +/// `PREFIXED_NAME_ERR` - underscore-prefixed alias for NAME_ERR (e.g., _plan_err) +/// (Needed to avoid compiler error when using macro in the same crate: `macros from the current crate cannot be referred to by absolute paths`) /// `NAME_DF_ERR` - macro name for wrapping DataFusionError::*. Needed to keep backtrace opportunity /// in construction where DataFusionError::* used directly, like `map_err`, `ok_or_else`, etc +/// `PREFIXED_NAME_DF_ERR` - underscore-prefixed alias for NAME_DF_ERR (e.g., _plan_datafusion_err). +/// (Needed to avoid compiler error when using macro in the same crate: `macros from the current crate cannot be referred to by absolute paths`) macro_rules! make_error { - ($NAME_ERR:ident, $NAME_DF_ERR: ident, $ERR:ident) => { make_error!(@inner ($), $NAME_ERR, $NAME_DF_ERR, $ERR); }; - (@inner ($d:tt), $NAME_ERR:ident, $NAME_DF_ERR:ident, $ERR:ident) => { - ::paste::paste!{ - /// Macro wraps `$ERR` to add backtrace feature - #[macro_export] - macro_rules! $NAME_DF_ERR { - ($d($d args:expr),* $d(; diagnostic=$d DIAG:expr)?) => {{ - let err =$crate::DataFusionError::$ERR( - ::std::format!( - "{}{}", - ::std::format!($d($d args),*), - $crate::DataFusionError::get_back_trace(), - ).into() - ); - $d ( - let err = err.with_diagnostic($d DIAG); - )? - err - } - } + ($NAME_ERR:ident, $PREFIXED_NAME_ERR:ident, $NAME_DF_ERR:ident, $PREFIXED_NAME_DF_ERR:ident, $ERR:ident) => { + make_error!(@inner ($), $NAME_ERR, $PREFIXED_NAME_ERR, $NAME_DF_ERR, $PREFIXED_NAME_DF_ERR, $ERR); + }; + (@inner ($d:tt), $NAME_ERR:ident, $PREFIXED_NAME_ERR:ident, $NAME_DF_ERR:ident, $PREFIXED_NAME_DF_ERR:ident, $ERR:ident) => { + /// Macro wraps `$ERR` to add backtrace feature + #[macro_export] + macro_rules! $NAME_DF_ERR { + ($d($d args:expr),* $d(; diagnostic = $d DIAG:expr)?) 
=> {{ + let err = $crate::DataFusionError::$ERR( + ::std::format!( + "{}{}", + ::std::format!($d($d args),*), + $crate::DataFusionError::get_back_trace(), + ).into() + ); + $d ( + let err = err.with_diagnostic($d DIAG); + )? + err + }} } - /// Macro wraps Err(`$ERR`) to add backtrace feature - #[macro_export] - macro_rules! $NAME_ERR { - ($d($d args:expr),* $d(; diagnostic = $d DIAG:expr)?) => {{ - let err = $crate::[<_ $NAME_DF_ERR>]!($d($d args),*); - $d ( - let err = err.with_diagnostic($d DIAG); - )? - Err(err) - - }} - } - - - // Note: Certain macros are used in this crate, but not all. - // This macro generates a use or all of them in case they are needed - // so we allow unused code to avoid warnings when they are not used - #[doc(hidden)] - #[allow(unused)] - pub use $NAME_ERR as [<_ $NAME_ERR>]; - #[doc(hidden)] - #[allow(unused)] - pub use $NAME_DF_ERR as [<_ $NAME_DF_ERR>]; + /// Macro wraps Err(`$ERR`) to add backtrace feature + #[macro_export] + macro_rules! $NAME_ERR { + ($d($d args:expr),* $d(; diagnostic = $d DIAG:expr)?) => {{ + let err = $crate::$PREFIXED_NAME_DF_ERR!($d($d args),*); + $d ( + let err = err.with_diagnostic($d DIAG); + )? 
+ Err(err) + }} } + + #[doc(hidden)] + pub use $NAME_ERR as $PREFIXED_NAME_ERR; + #[doc(hidden)] + pub use $NAME_DF_ERR as $PREFIXED_NAME_DF_ERR; }; } // Exposes a macro to create `DataFusionError::Plan` with optional backtrace -make_error!(plan_err, plan_datafusion_err, Plan); +make_error!( + plan_err, + _plan_err, + plan_datafusion_err, + _plan_datafusion_err, + Plan +); // Exposes a macro to create `DataFusionError::Internal` with optional backtrace -make_error!(internal_err, internal_datafusion_err, Internal); +make_error!( + internal_err, + _internal_err, + internal_datafusion_err, + _internal_datafusion_err, + Internal +); // Exposes a macro to create `DataFusionError::NotImplemented` with optional backtrace -make_error!(not_impl_err, not_impl_datafusion_err, NotImplemented); +make_error!( + not_impl_err, + _not_impl_err, + not_impl_datafusion_err, + _not_impl_datafusion_err, + NotImplemented +); // Exposes a macro to create `DataFusionError::Execution` with optional backtrace -make_error!(exec_err, exec_datafusion_err, Execution); +make_error!( + exec_err, + _exec_err, + exec_datafusion_err, + _exec_datafusion_err, + Execution +); // Exposes a macro to create `DataFusionError::Configuration` with optional backtrace -make_error!(config_err, config_datafusion_err, Configuration); +make_error!( + config_err, + _config_err, + config_datafusion_err, + _config_datafusion_err, + Configuration +); // Exposes a macro to create `DataFusionError::Substrait` with optional backtrace -make_error!(substrait_err, substrait_datafusion_err, Substrait); +make_error!( + substrait_err, + _substrait_err, + substrait_datafusion_err, + _substrait_datafusion_err, + Substrait +); // Exposes a macro to create `DataFusionError::ResourcesExhausted` with optional backtrace -make_error!(resources_err, resources_datafusion_err, ResourcesExhausted); +make_error!( + resources_err, + _resources_err, + resources_datafusion_err, + _resources_datafusion_err, + ResourcesExhausted +); + +// 
Exposes a macro to create `DataFusionError::Ffi` with optional backtrace +make_error!( + ffi_err, + _ffi_err, + ffi_datafusion_err, + _ffi_datafusion_err, + Ffi +); // Exposes a macro to create `DataFusionError::SQL` with optional backtrace #[macro_export] macro_rules! sql_datafusion_err { ($ERR:expr $(; diagnostic = $DIAG:expr)?) => {{ - let err = DataFusionError::SQL($ERR, Some(DataFusionError::get_back_trace())); + let err = $crate::DataFusionError::SQL(Box::new($ERR), Some($crate::DataFusionError::get_back_trace())); $( let err = err.with_diagnostic($DIAG); )? @@ -840,7 +1039,7 @@ macro_rules! sql_datafusion_err { #[macro_export] macro_rules! sql_err { ($ERR:expr $(; diagnostic = $DIAG:expr)?) => {{ - let err = datafusion_common::sql_datafusion_err!($ERR); + let err = $crate::sql_datafusion_err!($ERR); $( let err = err.with_diagnostic($DIAG); )? @@ -852,7 +1051,7 @@ macro_rules! sql_err { #[macro_export] macro_rules! arrow_datafusion_err { ($ERR:expr $(; diagnostic = $DIAG:expr)?) => {{ - let err = DataFusionError::ArrowError($ERR, Some(DataFusionError::get_back_trace())); + let err = $crate::DataFusionError::ArrowError(Box::new($ERR), Some($crate::DataFusionError::get_back_trace())); $( let err = err.with_diagnostic($DIAG); )? @@ -865,7 +1064,7 @@ macro_rules! arrow_datafusion_err { macro_rules! arrow_err { ($ERR:expr $(; diagnostic = $DIAG:expr)?) => { { - let err = datafusion_common::arrow_datafusion_err!($ERR); + let err = $crate::arrow_datafusion_err!($ERR); $( let err = err.with_diagnostic($DIAG); )? @@ -877,9 +1076,9 @@ macro_rules! arrow_err { #[macro_export] macro_rules! schema_datafusion_err { ($ERR:expr $(; diagnostic = $DIAG:expr)?) 
=> {{ - let err = $crate::error::DataFusionError::SchemaError( - $ERR, - Box::new(Some($crate::error::DataFusionError::get_back_trace())), + let err = $crate::DataFusionError::SchemaError( + Box::new($ERR), + Box::new(Some($crate::DataFusionError::get_back_trace())), ); $( let err = err.with_diagnostic($DIAG); @@ -892,9 +1091,9 @@ macro_rules! schema_datafusion_err { #[macro_export] macro_rules! schema_err { ($ERR:expr $(; diagnostic = $DIAG:expr)?) => {{ - let err = $crate::error::DataFusionError::SchemaError( - $ERR, - Box::new(Some($crate::error::DataFusionError::get_back_trace())), + let err = $crate::DataFusionError::SchemaError( + Box::new($ERR), + Box::new(Some($crate::DataFusionError::get_back_trace())), ); $( let err = err.with_diagnostic($DIAG); @@ -951,17 +1150,137 @@ pub fn add_possible_columns_to_diag( #[cfg(test)] mod test { + use super::*; + + use std::mem::size_of; use std::sync::Arc; - use crate::error::{DataFusionError, GenericError}; use arrow::error::ArrowError; + use insta::assert_snapshot; + + fn ok_result() -> Result<()> { + Ok(()) + } + + #[test] + fn test_assert_eq_or_internal_err_passes() -> Result<()> { + assert_eq_or_internal_err!(1, 1); + ok_result() + } + + #[test] + fn test_assert_eq_or_internal_err_fails() { + fn check() -> Result<()> { + assert_eq_or_internal_err!(1, 2, "expected equality"); + ok_result() + } + + let err = check().unwrap_err(); + assert_snapshot!( + err.to_string(), + @r" + Internal error: Assertion failed: 1 == 2 (left: 1, right: 2): expected equality. + This issue was likely caused by a bug in DataFusion's code. 
Please help us to resolve this by filing a bug report in our issue tracker: https://github.com/apache/datafusion/issues + " + ); + } + + #[test] + fn test_assert_ne_or_internal_err_passes() -> Result<()> { + assert_ne_or_internal_err!(1, 2); + ok_result() + } + + #[test] + fn test_assert_ne_or_internal_err_fails() { + fn check() -> Result<()> { + assert_ne_or_internal_err!(3, 3, "values must differ"); + ok_result() + } + + let err = check().unwrap_err(); + assert_snapshot!( + err.to_string(), + @r" + Internal error: Assertion failed: 3 != 3 (left: 3, right: 3): values must differ. + This issue was likely caused by a bug in DataFusion's code. Please help us to resolve this by filing a bug report in our issue tracker: https://github.com/apache/datafusion/issues + " + ); + } + + #[test] + fn test_assert_or_internal_err_passes() -> Result<()> { + assert_or_internal_err!(true); + assert_or_internal_err!(true, "message"); + ok_result() + } + + #[test] + fn test_assert_or_internal_err_fails_default() { + fn check() -> Result<()> { + assert_or_internal_err!(false); + ok_result() + } + + let err = check().unwrap_err(); + assert_snapshot!( + err.to_string(), + @r" + Internal error: Assertion failed: false. + This issue was likely caused by a bug in DataFusion's code. Please help us to resolve this by filing a bug report in our issue tracker: https://github.com/apache/datafusion/issues + " + ); + } + + #[test] + fn test_assert_or_internal_err_fails_with_message() { + fn check() -> Result<()> { + assert_or_internal_err!(false, "custom message"); + ok_result() + } + + let err = check().unwrap_err(); + assert_snapshot!( + err.to_string(), + @r" + Internal error: Assertion failed: false: custom message. + This issue was likely caused by a bug in DataFusion's code. 
Please help us to resolve this by filing a bug report in our issue tracker: https://github.com/apache/datafusion/issues + " + ); + } + + #[test] + fn test_assert_or_internal_err_with_format_arguments() { + fn check() -> Result<()> { + assert_or_internal_err!(false, "custom {}", 42); + ok_result() + } + + let err = check().unwrap_err(); + assert_snapshot!( + err.to_string(), + @r" + Internal error: Assertion failed: false: custom 42. + This issue was likely caused by a bug in DataFusion's code. Please help us to resolve this by filing a bug report in our issue tracker: https://github.com/apache/datafusion/issues + " + ); + } + + #[test] + fn test_error_size() { + // Since Errors influence the size of Result which influence the size of the stack + // please don't allow this to grow larger + assert_eq!(size_of::(), 40); + assert_eq!(size_of::(), 40); + } #[test] fn datafusion_error_to_arrow() { let res = return_arrow_error().unwrap_err(); - assert!(res - .to_string() - .starts_with("External error: Error during planning: foo")); + assert!( + res.to_string() + .starts_with("External error: Error during planning: foo") + ); } #[test] @@ -973,7 +1292,7 @@ mod test { // To pass the test the environment variable RUST_BACKTRACE should be set to 1 to enforce backtrace #[cfg(feature = "backtrace")] #[test] - #[allow(clippy::unnecessary_literal_unwrap)] + #[expect(clippy::unnecessary_literal_unwrap)] fn test_enabled_backtrace() { match std::env::var("RUST_BACKTRACE") { Ok(val) if val == "1" => {} @@ -990,17 +1309,17 @@ mod test { .unwrap(), &"Error during planning: Err" ); - assert!(!err - .split(DataFusionError::BACK_TRACE_SEP) - .collect::>() - .get(1) - .unwrap() - .is_empty()); + assert!( + !err.split(DataFusionError::BACK_TRACE_SEP) + .collect::>() + .get(1) + .unwrap() + .is_empty() + ); } #[cfg(not(feature = "backtrace"))] #[test] - #[allow(clippy::unnecessary_literal_unwrap)] fn test_disabled_backtrace() { let res: Result<(), DataFusionError> = plan_err!("Err"); let 
res = res.unwrap_err().to_string(); @@ -1020,8 +1339,8 @@ mod test { do_root_test( DataFusionError::ArrowError( - ArrowError::ExternalError(Box::new(DataFusionError::ResourcesExhausted( - "foo".to_string(), + Box::new(ArrowError::ExternalError(Box::new( + DataFusionError::ResourcesExhausted("foo".to_string()), ))), None, ), @@ -1044,9 +1363,11 @@ mod test { do_root_test( DataFusionError::ArrowError( - ArrowError::ExternalError(Box::new(ArrowError::ExternalError(Box::new( - DataFusionError::ResourcesExhausted("foo".to_string()), - )))), + Box::new(ArrowError::ExternalError(Box::new( + ArrowError::ExternalError(Box::new( + DataFusionError::ResourcesExhausted("foo".to_string()), + )), + ))), None, ), DataFusionError::ResourcesExhausted("foo".to_string()), @@ -1068,7 +1389,6 @@ mod test { } #[test] - #[allow(clippy::unnecessary_literal_unwrap)] fn test_make_error_parse_input() { let res: Result<(), DataFusionError> = plan_err!("Err"); let res = res.unwrap_err(); @@ -1120,7 +1440,7 @@ mod test { ); // assert wrapping other Error - let generic_error: GenericError = Box::new(std::io::Error::other("io error")); + let generic_error: GenericError = Box::new(io::Error::other("io error")); let datafusion_error: DataFusionError = generic_error.into(); println!("{}", datafusion_error.strip_backtrace()); assert_eq!( @@ -1131,15 +1451,17 @@ mod test { #[test] fn external_error_no_recursive() { - let generic_error_1: GenericError = Box::new(std::io::Error::other("io error")); + let generic_error_1: GenericError = Box::new(io::Error::other("io error")); let external_error_1: DataFusionError = generic_error_1.into(); let generic_error_2: GenericError = Box::new(external_error_1); let external_error_2: DataFusionError = generic_error_2.into(); println!("{external_error_2}"); - assert!(external_error_2 - .to_string() - .starts_with("External error: io error")); + assert!( + external_error_2 + .to_string() + .starts_with("External error: io error") + ); } /// Model what happens when 
implementing SendableRecordBatchStream: @@ -1151,7 +1473,7 @@ mod test { /// Model what happens when using arrow kernels in DataFusion /// code: need to turn an ArrowError into a DataFusionError - fn return_datafusion_error() -> crate::error::Result<()> { + fn return_datafusion_error() -> Result<()> { // Expect the '?' to work Err(ArrowError::SchemaError("bar".to_string()).into()) } diff --git a/datafusion/common/src/file_options/csv_writer.rs b/datafusion/common/src/file_options/csv_writer.rs index 943288af91642..4e6f74a4448af 100644 --- a/datafusion/common/src/file_options/csv_writer.rs +++ b/datafusion/common/src/file_options/csv_writer.rs @@ -31,6 +31,8 @@ pub struct CsvWriterOptions { /// Compression to apply after ArrowWriter serializes RecordBatches. /// This compression is applied by DataFusion not the ArrowWriter itself. pub compression: CompressionTypeVariant, + /// Compression level for the output file. + pub compression_level: Option, } impl CsvWriterOptions { @@ -41,6 +43,20 @@ impl CsvWriterOptions { Self { writer_options, compression, + compression_level: None, + } + } + + /// Create a new `CsvWriterOptions` with the specified compression level. 
+ pub fn new_with_level( + writer_options: WriterBuilder, + compression: CompressionTypeVariant, + compression_level: u32, + ) -> Self { + Self { + writer_options, + compression, + compression_level: Some(compression_level), } } } @@ -81,6 +97,7 @@ impl TryFrom<&CsvOptions> for CsvWriterOptions { Ok(CsvWriterOptions { writer_options: builder, compression: value.compression, + compression_level: value.compression_level, }) } } diff --git a/datafusion/common/src/file_options/json_writer.rs b/datafusion/common/src/file_options/json_writer.rs index 750d2972329bb..a537192c8128a 100644 --- a/datafusion/common/src/file_options/json_writer.rs +++ b/datafusion/common/src/file_options/json_writer.rs @@ -27,11 +27,26 @@ use crate::{ #[derive(Clone, Debug)] pub struct JsonWriterOptions { pub compression: CompressionTypeVariant, + pub compression_level: Option, } impl JsonWriterOptions { pub fn new(compression: CompressionTypeVariant) -> Self { - Self { compression } + Self { + compression, + compression_level: None, + } + } + + /// Create a new `JsonWriterOptions` with the specified compression and level. 
+ pub fn new_with_level( + compression: CompressionTypeVariant, + compression_level: u32, + ) -> Self { + Self { + compression, + compression_level: Some(compression_level), + } } } @@ -41,6 +56,7 @@ impl TryFrom<&JsonOptions> for JsonWriterOptions { fn try_from(value: &JsonOptions) -> Result { Ok(JsonWriterOptions { compression: value.compression, + compression_level: value.compression_level, }) } } diff --git a/datafusion/common/src/file_options/mod.rs b/datafusion/common/src/file_options/mod.rs index 02667e0165717..5d2abd23172ed 100644 --- a/datafusion/common/src/file_options/mod.rs +++ b/datafusion/common/src/file_options/mod.rs @@ -31,10 +31,10 @@ mod tests { use std::collections::HashMap; use crate::{ + Result, config::{ConfigFileType, TableOptions}, file_options::{csv_writer::CsvWriterOptions, json_writer::JsonWriterOptions}, parsers::CompressionTypeVariant, - Result, }; use parquet::{ @@ -84,7 +84,7 @@ mod tests { .build(); // Verify the expected options propagated down to parquet crate WriterProperties struct - assert_eq!(properties.max_row_group_size(), 123); + assert_eq!(properties.max_row_group_row_count(), Some(123)); assert_eq!(properties.data_page_size_limit(), 123); assert_eq!(properties.write_batch_size(), 123); assert_eq!(properties.writer_version(), WriterVersion::PARQUET_2_0); diff --git a/datafusion/common/src/file_options/parquet_writer.rs b/datafusion/common/src/file_options/parquet_writer.rs index 07e763f0ee6f3..a7a1fc6d0bb66 100644 --- a/datafusion/common/src/file_options/parquet_writer.rs +++ b/datafusion/common/src/file_options/parquet_writer.rs @@ -17,25 +17,23 @@ //! 
Options related to how parquet files should be written -use base64::Engine; use std::sync::Arc; use crate::{ + _internal_datafusion_err, DataFusionError, Result, config::{ParquetOptions, TableParquetOptions}, - DataFusionError, Result, _internal_datafusion_err, }; use arrow::datatypes::Schema; -// TODO: handle once deprecated -#[allow(deprecated)] +use parquet::arrow::encode_arrow_schema; use parquet::{ arrow::ARROW_SCHEMA_META_KEY, basic::{BrotliLevel, GzipLevel, ZstdLevel}, file::{ metadata::KeyValue, properties::{ - EnabledStatistics, WriterProperties, WriterPropertiesBuilder, WriterVersion, - DEFAULT_MAX_STATISTICS_SIZE, DEFAULT_STATISTICS_ENABLED, + DEFAULT_STATISTICS_ENABLED, EnabledStatistics, WriterProperties, + WriterPropertiesBuilder, }, }, schema::types::ColumnPath, @@ -89,12 +87,15 @@ impl TryFrom<&TableParquetOptions> for WriterPropertiesBuilder { /// Convert the session's [`TableParquetOptions`] into a single write action's [`WriterPropertiesBuilder`]. /// /// The returned [`WriterPropertiesBuilder`] includes customizations applicable per column. + /// Note that any encryption options are ignored as building the `FileEncryptionProperties` + /// might require other inputs besides the [`TableParquetOptions`]. 
fn try_from(table_parquet_options: &TableParquetOptions) -> Result { // Table options include kv_metadata and col-specific options let TableParquetOptions { global, column_specific_options, key_value_metadata, + crypto: _, } = table_parquet_options; let mut builder = global.into_writer_properties_builder()?; @@ -103,7 +104,9 @@ impl TryFrom<&TableParquetOptions> for WriterPropertiesBuilder { if !global.skip_arrow_metadata && !key_value_metadata.contains_key(ARROW_SCHEMA_META_KEY) { - return Err(_internal_datafusion_err!("arrow schema was not added to the kv_metadata, even though it is required by configuration settings")); + return Err(_internal_datafusion_err!( + "arrow schema was not added to the kv_metadata, even though it is required by configuration settings" + )); } // add kv_meta, if any @@ -157,47 +160,12 @@ impl TryFrom<&TableParquetOptions> for WriterPropertiesBuilder { builder = builder.set_column_bloom_filter_ndv(path.clone(), bloom_filter_ndv); } - - // max_statistics_size is deprecated, currently it is not being used - // TODO: remove once deprecated - #[allow(deprecated)] - if let Some(max_statistics_size) = options.max_statistics_size { - builder = { - #[allow(deprecated)] - builder.set_column_max_statistics_size(path, max_statistics_size) - } - } } Ok(builder) } } -/// Encodes the Arrow schema into the IPC format, and base64 encodes it -/// -/// TODO: use extern parquet's private method, once publicly available. 
-/// Refer to -fn encode_arrow_schema(schema: &Arc) -> String { - let options = arrow_ipc::writer::IpcWriteOptions::default(); - let mut dictionary_tracker = arrow_ipc::writer::DictionaryTracker::new(true); - let data_gen = arrow_ipc::writer::IpcDataGenerator::default(); - let mut serialized_schema = data_gen.schema_to_bytes_with_dictionary_tracker( - schema, - &mut dictionary_tracker, - &options, - ); - - // manually prepending the length to the schema as arrow uses the legacy IPC format - // TODO: change after addressing ARROW-9777 - let schema_len = serialized_schema.ipc_message.len(); - let mut len_prefix_schema = Vec::with_capacity(schema_len + 8); - len_prefix_schema.append(&mut vec![255u8, 255, 255, 255]); - len_prefix_schema.append((schema_len as u32).to_le_bytes().to_vec().as_mut()); - len_prefix_schema.append(&mut serialized_schema.ipc_message); - - base64::prelude::BASE64_STANDARD.encode(&len_prefix_schema) -} - impl ParquetOptions { /// Convert the global session options, [`ParquetOptions`], into a single write action's [`WriterPropertiesBuilder`]. /// @@ -206,7 +174,6 @@ impl ParquetOptions { /// /// Note that this method does not include the key_value_metadata from [`TableParquetOptions`]. 
pub fn into_writer_properties_builder(&self) -> Result { - #[allow(deprecated)] let ParquetOptions { data_pagesize_limit, write_batch_size, @@ -215,7 +182,6 @@ impl ParquetOptions { dictionary_enabled, dictionary_page_size_limit, statistics_enabled, - max_statistics_size, max_row_group_size, created_by, column_index_truncate_length, @@ -233,6 +199,7 @@ impl ParquetOptions { metadata_size_hint: _, pushdown_filters: _, reorder_filters: _, + force_filter_selections: _, // not used for writer props allow_single_file_parallelism: _, maximum_parallel_row_group_writers: _, maximum_buffered_record_batches_per_stream: _, @@ -241,12 +208,13 @@ impl ParquetOptions { binary_as_string: _, // not used for writer props coerce_int96: _, // not used for writer props skip_arrow_metadata: _, + max_predicate_cache_size: _, } = self; let mut builder = WriterProperties::builder() .set_data_page_size_limit(*data_pagesize_limit) .set_write_batch_size(*write_batch_size) - .set_writer_version(parse_version_string(writer_version.as_str())?) 
+ .set_writer_version((*writer_version).into()) .set_dictionary_page_size_limit(*dictionary_page_size_limit) .set_statistics_enabled( statistics_enabled @@ -254,20 +222,13 @@ impl ParquetOptions { .and_then(|s| parse_statistics_string(s).ok()) .unwrap_or(DEFAULT_STATISTICS_ENABLED), ) - .set_max_row_group_size(*max_row_group_size) + .set_max_row_group_row_count(Some(*max_row_group_size)) .set_created_by(created_by.clone()) .set_column_index_truncate_length(*column_index_truncate_length) .set_statistics_truncate_length(*statistics_truncate_length) .set_data_page_row_count_limit(*data_page_row_count_limit) .set_bloom_filter_enabled(*bloom_filter_on_write); - builder = { - #[allow(deprecated)] - builder.set_max_statistics_size( - max_statistics_size.unwrap_or(DEFAULT_MAX_STATISTICS_SIZE), - ) - }; - if let Some(bloom_filter_fpp) = bloom_filter_fpp { builder = builder.set_bloom_filter_fpp(*bloom_filter_fpp); }; @@ -300,7 +261,7 @@ pub(crate) fn parse_encoding_string( "plain" => Ok(parquet::basic::Encoding::PLAIN), "plain_dictionary" => Ok(parquet::basic::Encoding::PLAIN_DICTIONARY), "rle" => Ok(parquet::basic::Encoding::RLE), - #[allow(deprecated)] + #[expect(deprecated)] "bit_packed" => Ok(parquet::basic::Encoding::BIT_PACKED), "delta_binary_packed" => Ok(parquet::basic::Encoding::DELTA_BINARY_PACKED), "delta_length_byte_array" => { @@ -380,10 +341,6 @@ pub fn parse_compression_string( level, )?)) } - "lzo" => { - check_level_is_none(codec, &level)?; - Ok(parquet::basic::Compression::LZO) - } "brotli" => { let level = require_level(codec, level)?; Ok(parquet::basic::Compression::BROTLI(BrotliLevel::try_new( @@ -407,19 +364,7 @@ pub fn parse_compression_string( _ => Err(DataFusionError::Configuration(format!( "Unknown or unsupported parquet compression: \ {str_setting}. Valid values are: uncompressed, snappy, gzip(level), \ - lzo, brotli(level), lz4, zstd(level), and lz4_raw." 
- ))), - } -} - -pub(crate) fn parse_version_string(str_setting: &str) -> Result { - let str_setting_lower: &str = &str_setting.to_lowercase(); - match str_setting_lower { - "1.0" => Ok(WriterVersion::PARQUET_1_0), - "2.0" => Ok(WriterVersion::PARQUET_2_0), - _ => Err(DataFusionError::Configuration(format!( - "Unknown or unsupported parquet writer version {str_setting} \ - valid options are 1.0 and 2.0" + brotli(level), lz4, zstd(level), and lz4_raw." ))), } } @@ -440,31 +385,28 @@ pub(crate) fn parse_statistics_string(str_setting: &str) -> Result ParquetColumnOptions { - #[allow(deprecated)] // max_statistics_size ParquetColumnOptions { compression: Some("zstd(22)".into()), dictionary_enabled: src_col_defaults.dictionary_enabled.map(|v| !v), statistics_enabled: Some("none".into()), - max_statistics_size: Some(72), encoding: Some("RLE".into()), bloom_filter_enabled: Some(true), bloom_filter_fpp: Some(0.72), @@ -474,22 +416,21 @@ mod tests { fn parquet_options_with_non_defaults() -> ParquetOptions { let defaults = ParquetOptions::default(); - let writer_version = if defaults.writer_version.eq("1.0") { - "2.0" + let writer_version = if defaults.writer_version.eq(&DFParquetWriterVersion::V1_0) + { + DFParquetWriterVersion::V2_0 } else { - "1.0" + DFParquetWriterVersion::V1_0 }; - #[allow(deprecated)] // max_statistics_size ParquetOptions { data_pagesize_limit: 42, write_batch_size: 42, - writer_version: writer_version.into(), + writer_version, compression: Some("zstd(22)".into()), dictionary_enabled: Some(!defaults.dictionary_enabled.unwrap_or(false)), dictionary_page_size_limit: 42, statistics_enabled: Some("chunk".into()), - max_statistics_size: Some(42), max_row_group_size: 42, created_by: "wordy".into(), column_index_truncate_length: Some(42), @@ -507,6 +448,7 @@ mod tests { metadata_size_hint: defaults.metadata_size_hint, pushdown_filters: defaults.pushdown_filters, reorder_filters: defaults.reorder_filters, + force_filter_selections: 
defaults.force_filter_selections, allow_single_file_parallelism: defaults.allow_single_file_parallelism, maximum_parallel_row_group_writers: defaults .maximum_parallel_row_group_writers, @@ -517,6 +459,7 @@ mod tests { binary_as_string: defaults.binary_as_string, skip_arrow_metadata: defaults.skip_arrow_metadata, coerce_int96: None, + max_predicate_cache_size: defaults.max_predicate_cache_size, } } @@ -526,7 +469,6 @@ mod tests { ) -> ParquetColumnOptions { let bloom_filter_default_props = props.bloom_filter_properties(&col); - #[allow(deprecated)] // max_statistics_size ParquetColumnOptions { bloom_filter_enabled: Some(bloom_filter_default_props.is_some()), encoding: props.encoding(&col).map(|s| s.to_string()), @@ -547,7 +489,6 @@ mod tests { ), bloom_filter_fpp: bloom_filter_default_props.map(|p| p.fpp), bloom_filter_ndv: bloom_filter_default_props.map(|p| p.ndv), - max_statistics_size: Some(props.max_statistics_size(&col)), } } @@ -580,15 +521,24 @@ mod tests { HashMap::from([(COL_NAME.into(), configured_col_props)]) }; - #[allow(deprecated)] // max_statistics_size + #[cfg(feature = "parquet_encryption")] + let fep = props + .file_encryption_properties() + .map(ConfigFileEncryptionProperties::from); + + #[cfg(not(feature = "parquet_encryption"))] + let fep = None; + TableParquetOptions { global: ParquetOptions { // global options data_pagesize_limit: props.dictionary_page_size_limit(), write_batch_size: props.write_batch_size(), - writer_version: format!("{}.0", props.writer_version().as_num()), + writer_version: props.writer_version().into(), dictionary_page_size_limit: props.dictionary_page_size_limit(), - max_row_group_size: props.max_row_group_size(), + max_row_group_size: props + .max_row_group_row_count() + .unwrap_or(DEFAULT_MAX_ROW_GROUP_ROW_COUNT), created_by: props.created_by().to_string(), column_index_truncate_length: props.column_index_truncate_length(), statistics_truncate_length: props.statistics_truncate_length(), @@ -599,7 +549,6 @@ mod tests { 
compression: default_col_props.compression, dictionary_enabled: default_col_props.dictionary_enabled, statistics_enabled: default_col_props.statistics_enabled, - max_statistics_size: default_col_props.max_statistics_size, bloom_filter_on_write: default_col_props .bloom_filter_enabled .unwrap_or_default(), @@ -613,6 +562,7 @@ mod tests { metadata_size_hint: global_options_defaults.metadata_size_hint, pushdown_filters: global_options_defaults.pushdown_filters, reorder_filters: global_options_defaults.reorder_filters, + force_filter_selections: global_options_defaults.force_filter_selections, allow_single_file_parallelism: global_options_defaults .allow_single_file_parallelism, maximum_parallel_row_group_writers: global_options_defaults @@ -620,6 +570,8 @@ mod tests { maximum_buffered_record_batches_per_stream: global_options_defaults .maximum_buffered_record_batches_per_stream, bloom_filter_on_read: global_options_defaults.bloom_filter_on_read, + max_predicate_cache_size: global_options_defaults + .max_predicate_cache_size, schema_force_view_types: global_options_defaults.schema_force_view_types, binary_as_string: global_options_defaults.binary_as_string, skip_arrow_metadata: global_options_defaults.skip_arrow_metadata, @@ -627,6 +579,12 @@ mod tests { }, column_specific_options, key_value_metadata, + crypto: ParquetEncryptionOptions { + file_encryption: fep, + file_decryption: None, + factory_id: None, + factory_options: Default::default(), + }, } } @@ -681,6 +639,7 @@ mod tests { )] .into(), key_value_metadata: [(key, value)].into(), + crypto: Default::default(), }; let writer_props = WriterPropertiesBuilder::try_from(&table_parquet_opts) @@ -701,8 +660,7 @@ mod tests { let mut default_table_writer_opts = TableParquetOptions::default(); let default_parquet_opts = ParquetOptions::default(); assert_eq!( - default_table_writer_opts.global, - default_parquet_opts, + default_table_writer_opts.global, default_parquet_opts, "should have matching defaults for 
TableParquetOptions.global and ParquetOptions", ); @@ -726,7 +684,9 @@ mod tests { "should have different created_by sources", ); assert!( - default_writer_props.created_by().starts_with("parquet-rs version"), + default_writer_props + .created_by() + .starts_with("parquet-rs version"), "should indicate that writer_props defaults came from the extern parquet crate", ); assert!( @@ -760,8 +720,7 @@ mod tests { from_extern_parquet.global.skip_arrow_metadata = true; assert_eq!( - default_table_writer_opts, - from_extern_parquet, + default_table_writer_opts, from_extern_parquet, "the default writer_props should have the same configuration as the session's default TableParquetOptions", ); } diff --git a/datafusion/common/src/format.rs b/datafusion/common/src/format.rs index a4ebd17539996..a505bd0e1c74e 100644 --- a/datafusion/common/src/format.rs +++ b/datafusion/common/src/format.rs @@ -15,9 +15,15 @@ // specific language governing permissions and limitations // under the License. +use std::fmt::{self, Display}; +use std::str::FromStr; + use arrow::compute::CastOptions; use arrow::util::display::{DurationFormat, FormatOptions}; +use crate::config::{ConfigField, Visit}; +use crate::error::{DataFusionError, Result}; + /// The default [`FormatOptions`] to use within DataFusion /// Also see [`crate::config::FormatOptions`] pub const DEFAULT_FORMAT_OPTIONS: FormatOptions<'static> = @@ -28,3 +34,219 @@ pub const DEFAULT_CAST_OPTIONS: CastOptions<'static> = CastOptions { safe: false, format_options: DEFAULT_FORMAT_OPTIONS, }; + +/// Output formats for controlling for Explain plans +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum ExplainFormat { + /// Indent mode + /// + /// Example: + /// ```text + /// > explain format indent select x from values (1) t(x); + /// +---------------+-----------------------------------------------------+ + /// | plan_type | plan | + /// +---------------+-----------------------------------------------------+ + /// | logical_plan | 
SubqueryAlias: t | + /// | | Projection: column1 AS x | + /// | | Values: (Int64(1)) | + /// | physical_plan | ProjectionExec: expr=[column1@0 as x] | + /// | | DataSourceExec: partitions=1, partition_sizes=[1] | + /// | | | + /// +---------------+-----------------------------------------------------+ + /// ``` + Indent, + /// Tree mode + /// + /// Example: + /// ```text + /// > explain format tree select x from values (1) t(x); + /// +---------------+-------------------------------+ + /// | plan_type | plan | + /// +---------------+-------------------------------+ + /// | physical_plan | ┌───────────────────────────┐ | + /// | | │ ProjectionExec │ | + /// | | │ -------------------- │ | + /// | | │ x: column1@0 │ | + /// | | └─────────────┬─────────────┘ | + /// | | ┌─────────────┴─────────────┐ | + /// | | │ DataSourceExec │ | + /// | | │ -------------------- │ | + /// | | │ bytes: 128 │ | + /// | | │ format: memory │ | + /// | | │ rows: 1 │ | + /// | | └───────────────────────────┘ | + /// | | | + /// +---------------+-------------------------------+ + /// ``` + Tree, + /// Postgres Json mode + /// + /// A displayable structure that produces plan in postgresql JSON format. 
+ /// + /// Users can use this format to visualize the plan in existing plan + /// visualization tools, for example [dalibo](https://explain.dalibo.com/) + /// + /// Example: + /// ```text + /// > explain format pgjson select x from values (1) t(x); + /// +--------------+--------------------------------------+ + /// | plan_type | plan | + /// +--------------+--------------------------------------+ + /// | logical_plan | [ | + /// | | { | + /// | | "Plan": { | + /// | | "Alias": "t", | + /// | | "Node Type": "Subquery", | + /// | | "Output": [ | + /// | | "x" | + /// | | ], | + /// | | "Plans": [ | + /// | | { | + /// | | "Expressions": [ | + /// | | "column1 AS x" | + /// | | ], | + /// | | "Node Type": "Projection", | + /// | | "Output": [ | + /// | | "x" | + /// | | ], | + /// | | "Plans": [ | + /// | | { | + /// | | "Node Type": "Values", | + /// | | "Output": [ | + /// | | "column1" | + /// | | ], | + /// | | "Plans": [], | + /// | | "Values": "(Int64(1))" | + /// | | } | + /// | | ] | + /// | | } | + /// | | ] | + /// | | } | + /// | | } | + /// | | ] | + /// +--------------+--------------------------------------+ + /// ``` + PostgresJSON, + /// Graphviz mode + /// + /// Example: + /// ```text + /// > explain format graphviz select x from values (1) t(x); + /// +--------------+------------------------------------------------------------------------+ + /// | plan_type | plan | + /// +--------------+------------------------------------------------------------------------+ + /// | logical_plan | | + /// | | // Begin DataFusion GraphViz Plan, | + /// | | // display it online here: https://dreampuf.github.io/GraphvizOnline | + /// | | | + /// | | digraph { | + /// | | subgraph cluster_1 | + /// | | { | + /// | | graph[label="LogicalPlan"] | + /// | | 2[shape=box label="SubqueryAlias: t"] | + /// | | 3[shape=box label="Projection: column1 AS x"] | + /// | | 2 -> 3 [arrowhead=none, arrowtail=normal, dir=back] | + /// | | 4[shape=box label="Values: (Int64(1))"] | + 
/// | | 3 -> 4 [arrowhead=none, arrowtail=normal, dir=back] | + /// | | } | + /// | | subgraph cluster_5 | + /// | | { | + /// | | graph[label="Detailed LogicalPlan"] | + /// | | 6[shape=box label="SubqueryAlias: t\nSchema: [x:Int64;N]"] | + /// | | 7[shape=box label="Projection: column1 AS x\nSchema: [x:Int64;N]"] | + /// | | 6 -> 7 [arrowhead=none, arrowtail=normal, dir=back] | + /// | | 8[shape=box label="Values: (Int64(1))\nSchema: [column1:Int64;N]"] | + /// | | 7 -> 8 [arrowhead=none, arrowtail=normal, dir=back] | + /// | | } | + /// | | } | + /// | | // End DataFusion GraphViz Plan | + /// | | | + /// +--------------+------------------------------------------------------------------------+ + /// ``` + Graphviz, +} + +/// Implement parsing strings to `ExplainFormat` +impl FromStr for ExplainFormat { + type Err = DataFusionError; + + fn from_str(format: &str) -> Result { + match format.to_lowercase().as_str() { + "indent" => Ok(ExplainFormat::Indent), + "tree" => Ok(ExplainFormat::Tree), + "pgjson" => Ok(ExplainFormat::PostgresJSON), + "graphviz" => Ok(ExplainFormat::Graphviz), + _ => Err(DataFusionError::Configuration(format!( + "Invalid explain format. Expected 'indent', 'tree', 'pgjson' or 'graphviz'. 
Got '{format}'" + ))), + } + } +} + +impl Display for ExplainFormat { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let s = match self { + ExplainFormat::Indent => "indent", + ExplainFormat::Tree => "tree", + ExplainFormat::PostgresJSON => "pgjson", + ExplainFormat::Graphviz => "graphviz", + }; + write!(f, "{s}") + } +} + +impl ConfigField for ExplainFormat { + fn visit(&self, v: &mut V, key: &str, description: &'static str) { + v.some(key, self, description) + } + + fn set(&mut self, _: &str, value: &str) -> Result<()> { + *self = ExplainFormat::from_str(value)?; + Ok(()) + } +} + +/// Verbosity levels controlling how `EXPLAIN ANALYZE` renders metrics +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum ExplainAnalyzeLevel { + /// Show a compact view containing high-level metrics + Summary, + /// Show a developer-focused view with per-operator details + Dev, + // When adding new enum, update the error message in `from_str()` accordingly. +} + +impl FromStr for ExplainAnalyzeLevel { + type Err = DataFusionError; + + fn from_str(level: &str) -> Result { + match level.to_lowercase().as_str() { + "summary" => Ok(ExplainAnalyzeLevel::Summary), + "dev" => Ok(ExplainAnalyzeLevel::Dev), + other => Err(DataFusionError::Configuration(format!( + "Invalid explain analyze level. Expected 'summary' or 'dev'. 
Got '{other}'" + ))), + } + } +} + +impl Display for ExplainAnalyzeLevel { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let s = match self { + ExplainAnalyzeLevel::Summary => "summary", + ExplainAnalyzeLevel::Dev => "dev", + }; + write!(f, "{s}") + } +} + +impl ConfigField for ExplainAnalyzeLevel { + fn visit(&self, v: &mut V, key: &str, description: &'static str) { + v.some(key, self, description) + } + + fn set(&mut self, _: &str, value: &str) -> Result<()> { + *self = ExplainAnalyzeLevel::from_str(value)?; + Ok(()) + } +} diff --git a/datafusion/common/src/functional_dependencies.rs b/datafusion/common/src/functional_dependencies.rs index 77e00d6dcda23..63962998ad18b 100644 --- a/datafusion/common/src/functional_dependencies.rs +++ b/datafusion/common/src/functional_dependencies.rs @@ -36,35 +36,31 @@ pub enum Constraint { } /// This object encapsulates a list of functional constraints: -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] +#[derive(Clone, Debug, Default, Eq, Hash, PartialEq, PartialOrd)] pub struct Constraints { inner: Vec, } impl Constraints { - /// Create empty constraints - pub fn empty() -> Self { - Constraints::new_unverified(vec![]) - } - /// Create a new [`Constraints`] object from the given `constraints`. - /// Users should use the [`Constraints::empty`] or [`SqlToRel::new_constraint_from_table_constraints`] functions - /// for constructing [`Constraints`]. This constructor is for internal - /// purposes only and does not check whether the argument is valid. The user - /// is responsible for supplying a valid vector of [`Constraint`] objects. + /// Users should use the [`Constraints::default`] or [`SqlToRel::new_constraint_from_table_constraints`] + /// functions for constructing [`Constraints`] instances. This constructor + /// is for internal purposes only and does not check whether the argument + /// is valid. The user is responsible for supplying a valid vector of + /// [`Constraint`] objects. 
/// /// [`SqlToRel::new_constraint_from_table_constraints`]: https://docs.rs/datafusion/latest/datafusion/sql/planner/struct.SqlToRel.html#method.new_constraint_from_table_constraints pub fn new_unverified(constraints: Vec) -> Self { Self { inner: constraints } } - /// Check whether constraints is empty - pub fn is_empty(&self) -> bool { - self.inner.is_empty() + /// Extends the current constraints with the given `other` constraints. + pub fn extend(&mut self, other: Constraints) { + self.inner.extend(other.inner); } - /// Projects constraints using the given projection indices. - /// Returns None if any of the constraint columns are not included in the projection. + /// Projects constraints using the given projection indices. Returns `None` + /// if any of the constraint columns are not included in the projection. pub fn project(&self, proj_indices: &[usize]) -> Option { let projected = self .inner @@ -74,14 +70,14 @@ impl Constraints { Constraint::PrimaryKey(indices) => { let new_indices = update_elements_with_matching_indices(indices, proj_indices); - // Only keep constraint if all columns are preserved + // Only keep the constraint if all columns are preserved: (new_indices.len() == indices.len()) .then_some(Constraint::PrimaryKey(new_indices)) } Constraint::Unique(indices) => { let new_indices = update_elements_with_matching_indices(indices, proj_indices); - // Only keep constraint if all columns are preserved + // Only keep the constraint if all columns are preserved: (new_indices.len() == indices.len()) .then_some(Constraint::Unique(new_indices)) } @@ -93,15 +89,9 @@ impl Constraints { } } -impl Default for Constraints { - fn default() -> Self { - Constraints::empty() - } -} - impl IntoIterator for Constraints { type Item = Constraint; - type IntoIter = IntoIter; + type IntoIter = IntoIter; fn into_iter(self) -> Self::IntoIter { self.inner.into_iter() @@ -374,7 +364,7 @@ impl FunctionalDependencies { // These joins preserve functional dependencies of the 
left side: left_func_dependencies } - JoinType::RightSemi | JoinType::RightAnti => { + JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => { // These joins preserve functional dependencies of the right side: right_func_dependencies } diff --git a/datafusion/common/src/hash_utils.rs b/datafusion/common/src/hash_utils.rs index e78d42257b9cb..255525b92e0c0 100644 --- a/datafusion/common/src/hash_utils.rs +++ b/datafusion/common/src/hash_utils.rs @@ -17,25 +17,30 @@ //! Functionality used both on logical and physical plans -#[cfg(not(feature = "force_hash_collisions"))] -use std::sync::Arc; - -use ahash::RandomState; use arrow::array::types::{IntervalDayTime, IntervalMonthDayNano}; use arrow::array::*; +use arrow::compute::take; use arrow::datatypes::*; #[cfg(not(feature = "force_hash_collisions"))] use arrow::{downcast_dictionary_array, downcast_primitive_array}; +use foldhash::fast::FixedState; +use itertools::Itertools; +use std::collections::HashMap; +use std::hash::{BuildHasher, Hash, Hasher}; + +/// The hash random state used throughout DataFusion for hashing. 
+pub type RandomState = FixedState; #[cfg(not(feature = "force_hash_collisions"))] use crate::cast::{ as_binary_view_array, as_boolean_array, as_fixed_size_list_array, - as_generic_binary_array, as_large_list_array, as_list_array, as_map_array, - as_string_array, as_string_view_array, as_struct_array, + as_generic_binary_array, as_large_list_array, as_large_list_view_array, + as_list_array, as_list_view_array, as_map_array, as_string_array, + as_string_view_array, as_struct_array, as_union_array, }; use crate::error::Result; -#[cfg(not(feature = "force_hash_collisions"))] -use crate::error::_internal_err; +use crate::error::{_internal_datafusion_err, _internal_err}; +use std::cell::RefCell; // Combines two hashes into one hash #[inline] @@ -44,6 +49,94 @@ pub fn combine_hashes(l: u64, r: u64) -> u64 { hash.wrapping_mul(37).wrapping_add(r) } +/// Maximum size for the thread-local hash buffer before truncation (4MB = 524,288 u64 elements). +/// The goal of this is to avoid unbounded memory growth that would appear as a memory leak. +/// We allow temporary allocations beyond this size, but after use the buffer is truncated +/// to this size. +const MAX_BUFFER_SIZE: usize = 524_288; + +thread_local! { + /// Thread-local buffer for hash computations to avoid repeated allocations. + /// The buffer is reused across calls and truncated if it exceeds MAX_BUFFER_SIZE. + /// Defaults to a capacity of 8192 u64 elements which is the default batch size. + /// This corresponds to 64KB of memory. + static HASH_BUFFER: RefCell> = const { RefCell::new(Vec::new()) }; +} + +/// Creates hashes for the given arrays using a thread-local buffer, then calls the provided callback +/// with an immutable reference to the computed hashes. +/// +/// This function manages a thread-local buffer to avoid repeated allocations. The buffer is automatically +/// truncated if it exceeds `MAX_BUFFER_SIZE` after use. 
+/// +/// # Arguments +/// * `arrays` - The arrays to hash (must contain at least one array) +/// * `random_state` - The random state for hashing +/// * `callback` - A function that receives an immutable reference to the hash slice and returns a result +/// +/// # Errors +/// Returns an error if: +/// - No arrays are provided +/// - The function is called reentrantly (i.e., the callback invokes `with_hashes` again on the same thread) +/// - The function is called during or after thread destruction +/// +/// # Example +/// ```ignore +/// use datafusion_common::hash_utils::{with_hashes, RandomState}; +/// use arrow::array::{Int32Array, ArrayRef}; +/// use std::sync::Arc; +/// +/// let array: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3])); +/// let random_state = RandomState::default(); +/// +/// let result = with_hashes([&array], &random_state, |hashes| { +/// // Use the hashes here +/// Ok(hashes.len()) +/// })?; +/// ``` +pub fn with_hashes( + arrays: I, + random_state: &RandomState, + callback: F, +) -> Result +where + I: IntoIterator, + T: AsDynArray, + F: FnOnce(&[u64]) -> Result, +{ + // Peek at the first array to determine buffer size without fully collecting + let mut iter = arrays.into_iter().peekable(); + + // Get the required size from the first array + let required_size = match iter.peek() { + Some(arr) => arr.as_dyn_array().len(), + None => return _internal_err!("with_hashes requires at least one array"), + }; + + HASH_BUFFER.try_with(|cell| { + let mut buffer = cell.try_borrow_mut() + .map_err(|_| _internal_datafusion_err!("with_hashes cannot be called reentrantly on the same thread"))?; + + // Ensure buffer has sufficient length, clearing old values + buffer.clear(); + buffer.resize(required_size, 0); + + // Create hashes in the buffer - this consumes the iterator + create_hashes(iter, random_state, &mut buffer[..required_size])?; + + // Execute the callback with an immutable slice + let result = callback(&buffer[..required_size])?; + + // 
Cleanup: truncate if buffer grew too large + if buffer.capacity() > MAX_BUFFER_SIZE { + buffer.truncate(MAX_BUFFER_SIZE); + buffer.shrink_to_fit(); + } + + Ok(result) + }).map_err(|_| _internal_datafusion_err!("with_hashes cannot access thread-local storage during or after thread destruction"))? +} + #[cfg(not(feature = "force_hash_collisions"))] fn hash_null(random_state: &RandomState, hashes_buffer: &'_ mut [u64], mul_col: bool) { if mul_col { @@ -60,12 +153,17 @@ fn hash_null(random_state: &RandomState, hashes_buffer: &'_ mut [u64], mul_col: pub trait HashValue { fn hash_one(&self, state: &RandomState) -> u64; + /// Write this value into an existing hasher (same data as `hash_one`). + fn hash_write(&self, hasher: &mut impl Hasher); } impl HashValue for &T { fn hash_one(&self, state: &RandomState) -> u64 { T::hash_one(self, state) } + fn hash_write(&self, hasher: &mut impl Hasher) { + T::hash_write(self, hasher) + } } macro_rules! hash_value { @@ -74,10 +172,13 @@ macro_rules! hash_value { fn hash_one(&self, state: &RandomState) -> u64 { state.hash_one(self) } + fn hash_write(&self, hasher: &mut impl Hasher) { + Hash::hash(self, hasher) + } })+ }; } -hash_value!(i8, i16, i32, i64, i128, i256, u8, u16, u32, u64); +hash_value!(i8, i16, i32, i64, i128, i256, u8, u16, u32, u64, u128); hash_value!(bool, str, [u8], IntervalDayTime, IntervalMonthDayNano); macro_rules! hash_float_value { @@ -86,14 +187,28 @@ macro_rules! hash_float_value { fn hash_one(&self, state: &RandomState) -> u64 { state.hash_one(<$i>::from_ne_bytes(self.to_ne_bytes())) } + fn hash_write(&self, hasher: &mut impl Hasher) { + hasher.write(&self.to_ne_bytes()) + } })+ }; } hash_float_value!((half::f16, u16), (f32, u32), (f64, u64)); +/// Create a `SeedableRandomState` whose per-hasher seed incorporates `seed`. +/// This folds the previous hash into the hasher's initial state so only the +/// new value needs to pass through the hash function — same cost as `hash_one`. 
+#[inline] +fn seeded_state(seed: u64) -> foldhash::fast::SeedableRandomState { + foldhash::fast::SeedableRandomState::with_seed( + seed, + foldhash::SharedSeed::global_fixed(), + ) +} + /// Builds hash values of PrimitiveArray and writes them into `hashes_buffer` -/// If `rehash==true` this combines the previous hash value in the buffer -/// with the new hash using `combine_hashes` +/// If `rehash==true` this folds the existing hash into the hasher state +/// and hashes only the new value (avoiding a separate combine step). #[cfg(not(feature = "force_hash_collisions"))] fn hash_array_primitive( array: &PrimitiveArray, @@ -112,7 +227,9 @@ fn hash_array_primitive( if array.null_count() == 0 { if rehash { for (hash, &value) in hashes_buffer.iter_mut().zip(array.values().iter()) { - *hash = combine_hashes(value.hash_one(random_state), *hash); + let mut hasher = seeded_state(*hash).build_hasher(); + value.hash_write(&mut hasher); + *hash = hasher.finish(); } } else { for (hash, &value) in hashes_buffer.iter_mut().zip(array.values().iter()) { @@ -120,18 +237,16 @@ fn hash_array_primitive( } } } else if rehash { - for (i, hash) in hashes_buffer.iter_mut().enumerate() { - if !array.is_null(i) { - let value = unsafe { array.value_unchecked(i) }; - *hash = combine_hashes(value.hash_one(random_state), *hash); - } + for i in array.nulls().unwrap().valid_indices() { + let value = unsafe { array.value_unchecked(i) }; + let mut hasher = seeded_state(hashes_buffer[i]).build_hasher(); + value.hash_write(&mut hasher); + hashes_buffer[i] = hasher.finish(); } } else { - for (i, hash) in hashes_buffer.iter_mut().enumerate() { - if !array.is_null(i) { - let value = unsafe { array.value_unchecked(i) }; - *hash = value.hash_one(random_state); - } + for i in array.nulls().unwrap().valid_indices() { + let value = unsafe { array.value_unchecked(i) }; + hashes_buffer[i] = value.hash_one(random_state); } } } @@ -141,7 +256,7 @@ fn hash_array_primitive( /// with the new hash using 
`combine_hashes` #[cfg(not(feature = "force_hash_collisions"))] fn hash_array( - array: T, + array: &T, random_state: &RandomState, hashes_buffer: &mut [u64], rehash: bool, @@ -168,54 +283,255 @@ fn hash_array( } } } else if rehash { - for (i, hash) in hashes_buffer.iter_mut().enumerate() { - if !array.is_null(i) { - let value = unsafe { array.value_unchecked(i) }; - *hash = combine_hashes(value.hash_one(random_state), *hash); - } + for i in array.nulls().unwrap().valid_indices() { + let value = unsafe { array.value_unchecked(i) }; + hashes_buffer[i] = + combine_hashes(value.hash_one(random_state), hashes_buffer[i]); } } else { - for (i, hash) in hashes_buffer.iter_mut().enumerate() { - if !array.is_null(i) { - let value = unsafe { array.value_unchecked(i) }; - *hash = value.hash_one(random_state); + for i in array.nulls().unwrap().valid_indices() { + let value = unsafe { array.value_unchecked(i) }; + hashes_buffer[i] = value.hash_one(random_state); + } + } +} + +/// Hash a StringView or BytesView array +/// +/// Templated to optimize inner loop based on presence of nulls and external buffers. 
+/// +/// HAS_NULLS: do we have to check null in the inner loop +/// HAS_BUFFERS: if true, array has external buffers; if false, all strings are inlined/ less then 12 bytes +/// REHASH: if true, combining with existing hash, otherwise initializing +#[inline(never)] +fn hash_string_view_array_inner< + T: ByteViewType, + const HAS_NULLS: bool, + const HAS_BUFFERS: bool, + const REHASH: bool, +>( + array: &GenericByteViewArray, + random_state: &RandomState, + hashes_buffer: &mut [u64], +) { + assert_eq!( + hashes_buffer.len(), + array.len(), + "hashes_buffer and array should be of equal length" + ); + + let buffers = array.data_buffers(); + let view_bytes = |view_len: u32, view: u128| { + let view = ByteView::from(view); + let offset = view.offset as usize; + // SAFETY: view is a valid view as it came from the array + unsafe { + let data = buffers.get_unchecked(view.buffer_index as usize); + data.get_unchecked(offset..offset + view_len as usize) + } + }; + + let hashes_and_views = hashes_buffer.iter_mut().zip(array.views().iter()); + for (i, (hash, &v)) in hashes_and_views.enumerate() { + if HAS_NULLS && array.is_null(i) { + continue; + } + let view_len = v as u32; + // all views are inlined, no need to access external buffers + if !HAS_BUFFERS || view_len <= 12 { + if REHASH { + let mut hasher = seeded_state(*hash).build_hasher(); + v.hash_write(&mut hasher); + *hash = hasher.finish(); + } else { + *hash = v.hash_one(random_state); } + continue; + } + // view is not inlined, so we need to hash the bytes as well + let value = view_bytes(view_len, v); + if REHASH { + let mut hasher = seeded_state(*hash).build_hasher(); + value.hash_write(&mut hasher); + *hash = hasher.finish(); + } else { + *hash = value.hash_one(random_state); } } } -/// Hash the values in a dictionary array +/// Builds hash values for array views and writes them into `hashes_buffer` +/// If `rehash==true` this combines the previous hash value in the buffer +/// with the new hash using 
`combine_hashes` #[cfg(not(feature = "force_hash_collisions"))] -fn hash_dictionary( +fn hash_generic_byte_view_array( + array: &GenericByteViewArray, + random_state: &RandomState, + hashes_buffer: &mut [u64], + rehash: bool, +) { + // instantiate the correct version based on presence of nulls and external buffers + match ( + array.null_count() != 0, + !array.data_buffers().is_empty(), + rehash, + ) { + // no nulls or buffers ==> hash the inlined views directly + // don't call the inner function as Rust seems better able to inline this simpler code (2-3% faster) + (false, false, false) => { + for (hash, &view) in hashes_buffer.iter_mut().zip(array.views().iter()) { + *hash = view.hash_one(random_state); + } + } + (false, false, true) => { + for (hash, &view) in hashes_buffer.iter_mut().zip(array.views().iter()) { + let mut hasher = seeded_state(*hash).build_hasher(); + view.hash_write(&mut hasher); + *hash = hasher.finish(); + } + } + (false, true, false) => hash_string_view_array_inner::( + array, + random_state, + hashes_buffer, + ), + (false, true, true) => hash_string_view_array_inner::( + array, + random_state, + hashes_buffer, + ), + (true, false, false) => hash_string_view_array_inner::( + array, + random_state, + hashes_buffer, + ), + (true, false, true) => hash_string_view_array_inner::( + array, + random_state, + hashes_buffer, + ), + (true, true, false) => hash_string_view_array_inner::( + array, + random_state, + hashes_buffer, + ), + (true, true, true) => hash_string_view_array_inner::( + array, + random_state, + hashes_buffer, + ), + } +} + +/// Hash dictionary array with compile-time specialization for null handling. 
+/// +/// Uses const generics to eliminate runtim branching in the hot loop: +/// - `HAS_NULL_KEYS`: Whether to check for null dictionary keys +/// - `HAS_NULL_VALUES`: Whether to check for null dictionary values +/// - `MULTI_COL`: Whether to combine with existing hash (true) or initialize (false) +#[inline(never)] +fn hash_dictionary_inner< + K: ArrowDictionaryKeyType, + const HAS_NULL_KEYS: bool, + const HAS_NULL_VALUES: bool, + const MULTI_COL: bool, +>( array: &DictionaryArray, random_state: &RandomState, hashes_buffer: &mut [u64], - multi_col: bool, ) -> Result<()> { // Hash each dictionary value once, and then use that computed // hash for each key value to avoid a potentially expensive // redundant hashing for large dictionary elements (e.g. strings) - let values = Arc::clone(array.values()); - let mut dict_hashes = vec![0; values.len()]; - create_hashes(&[values], random_state, &mut dict_hashes)?; + let dict_values = array.values(); + let mut dict_hashes = vec![0; dict_values.len()]; + create_hashes([dict_values], random_state, &mut dict_hashes)?; - // combine hash for each index in values - if multi_col { + if HAS_NULL_KEYS { for (hash, key) in hashes_buffer.iter_mut().zip(array.keys().iter()) { if let Some(key) = key { - *hash = combine_hashes(dict_hashes[key.as_usize()], *hash) - } // no update for Null, consistent with other hashes + let idx = key.as_usize(); + if !HAS_NULL_VALUES || dict_values.is_valid(idx) { + if MULTI_COL { + *hash = combine_hashes(dict_hashes[idx], *hash); + } else { + *hash = dict_hashes[idx]; + } + } + } } } else { - for (hash, key) in hashes_buffer.iter_mut().zip(array.keys().iter()) { - if let Some(key) = key { - *hash = dict_hashes[key.as_usize()] - } // no update for Null, consistent with other hashes + for (hash, key) in hashes_buffer.iter_mut().zip(array.keys().values()) { + let idx = key.as_usize(); + if !HAS_NULL_VALUES || dict_values.is_valid(idx) { + if MULTI_COL { + *hash = combine_hashes(dict_hashes[idx], *hash); + } 
else { + *hash = dict_hashes[idx]; + } + } } } Ok(()) } +/// Hash the values in a dictionary array +#[cfg(not(feature = "force_hash_collisions"))] +fn hash_dictionary( + array: &DictionaryArray, + random_state: &RandomState, + hashes_buffer: &mut [u64], + multi_col: bool, +) -> Result<()> { + let has_null_keys = array.keys().null_count() != 0; + let has_null_values = array.values().null_count() != 0; + + // Dispatcher based on null presence and multi-column mode + // Should reduce branching within hot loops + match (has_null_keys, has_null_values, multi_col) { + (false, false, false) => hash_dictionary_inner::( + array, + random_state, + hashes_buffer, + ), + (false, false, true) => hash_dictionary_inner::( + array, + random_state, + hashes_buffer, + ), + (false, true, false) => hash_dictionary_inner::( + array, + random_state, + hashes_buffer, + ), + (false, true, true) => hash_dictionary_inner::( + array, + random_state, + hashes_buffer, + ), + (true, false, false) => hash_dictionary_inner::( + array, + random_state, + hashes_buffer, + ), + (true, false, true) => hash_dictionary_inner::( + array, + random_state, + hashes_buffer, + ), + (true, true, false) => hash_dictionary_inner::( + array, + random_state, + hashes_buffer, + ), + (true, true, true) => hash_dictionary_inner::( + array, + random_state, + hashes_buffer, + ), + } +} + #[cfg(not(feature = "force_hash_collisions"))] fn hash_struct_array( array: &StructArray, @@ -225,19 +541,21 @@ fn hash_struct_array( let nulls = array.nulls(); let row_len = array.len(); - let valid_row_indices: Vec = if let Some(nulls) = nulls { - nulls.valid_indices().collect() - } else { - (0..row_len).collect() - }; - // Create hashes for each row that combines the hashes over all the column at that row. 
let mut values_hashes = vec![0u64; row_len]; create_hashes(array.columns(), random_state, &mut values_hashes)?; - for i in valid_row_indices { - let hash = &mut hashes_buffer[i]; - *hash = combine_hashes(*hash, values_hashes[i]); + // Separate paths to avoid allocating Vec when there are no nulls + if let Some(nulls) = nulls { + for i in nulls.valid_indices() { + let hash = &mut hashes_buffer[i]; + *hash = combine_hashes(*hash, values_hashes[i]); + } + } else { + for i in 0..row_len { + let hash = &mut hashes_buffer[i]; + *hash = combine_hashes(*hash, values_hashes[i]); + } } Ok(()) @@ -254,15 +572,29 @@ fn hash_map_array( let offsets = array.offsets(); // Create hashes for each entry in each row - let mut values_hashes = vec![0u64; array.entries().len()]; - create_hashes(array.entries().columns(), random_state, &mut values_hashes)?; + let first_offset = offsets.first().copied().unwrap_or_default() as usize; + let last_offset = offsets.last().copied().unwrap_or_default() as usize; + let entries_len = last_offset - first_offset; + + // Only hash the entries that are actually referenced + let mut values_hashes = vec![0u64; entries_len]; + let entries = array.entries(); + let sliced_columns: Vec = entries + .columns() + .iter() + .map(|col| col.slice(first_offset, entries_len)) + .collect(); + create_hashes(&sliced_columns, random_state, &mut values_hashes)?; // Combine the hashes for entries on each row with each other and previous hash for that row + // Adjust indices by first_offset since values_hashes is sliced starting from first_offset if let Some(nulls) = nulls { for (i, (start, stop)) in offsets.iter().zip(offsets.iter().skip(1)).enumerate() { if nulls.is_valid(i) { let hash = &mut hashes_buffer[i]; - for values_hash in &values_hashes[start.as_usize()..stop.as_usize()] { + for values_hash in &values_hashes + [start.as_usize() - first_offset..stop.as_usize() - first_offset] + { *hash = combine_hashes(*hash, *values_hash); } } @@ -270,7 +602,9 @@ fn 
hash_map_array( } else { for (i, (start, stop)) in offsets.iter().zip(offsets.iter().skip(1)).enumerate() { let hash = &mut hashes_buffer[i]; - for values_hash in &values_hashes[start.as_usize()..stop.as_usize()] { + for values_hash in &values_hashes + [start.as_usize() - first_offset..stop.as_usize() - first_offset] + { *hash = combine_hashes(*hash, *values_hash); } } @@ -288,24 +622,80 @@ fn hash_list_array( where OffsetSize: OffsetSizeTrait, { - let values = Arc::clone(array.values()); + // In case values is sliced, hash only the bytes used by the offsets of this ListArray + let first_offset = array.value_offsets().first().cloned().unwrap_or_default(); + let last_offset = array.value_offsets().last().cloned().unwrap_or_default(); + let value_bytes_len = (last_offset - first_offset).as_usize(); + let mut values_hashes = vec![0u64; value_bytes_len]; + create_hashes( + [array + .values() + .slice(first_offset.as_usize(), value_bytes_len)], + random_state, + &mut values_hashes, + )?; + + if array.null_count() > 0 { + for (i, (start, stop)) in array.value_offsets().iter().tuple_windows().enumerate() + { + if array.is_valid(i) { + let hash = &mut hashes_buffer[i]; + for values_hash in &values_hashes[(*start - first_offset).as_usize() + ..(*stop - first_offset).as_usize()] + { + *hash = combine_hashes(*hash, *values_hash); + } + } + } + } else { + for ((start, stop), hash) in array + .value_offsets() + .iter() + .tuple_windows() + .zip(hashes_buffer.iter_mut()) + { + for values_hash in &values_hashes + [(*start - first_offset).as_usize()..(*stop - first_offset).as_usize()] + { + *hash = combine_hashes(*hash, *values_hash); + } + } + } + Ok(()) +} + +#[cfg(not(feature = "force_hash_collisions"))] +fn hash_list_view_array( + array: &GenericListViewArray, + random_state: &RandomState, + hashes_buffer: &mut [u64], +) -> Result<()> +where + OffsetSize: OffsetSizeTrait, +{ + let values = array.values(); let offsets = array.value_offsets(); + let sizes = array.value_sizes(); 
let nulls = array.nulls(); let mut values_hashes = vec![0u64; values.len()]; - create_hashes(&[values], random_state, &mut values_hashes)?; + create_hashes([values], random_state, &mut values_hashes)?; if let Some(nulls) = nulls { - for (i, (start, stop)) in offsets.iter().zip(offsets.iter().skip(1)).enumerate() { + for (i, (offset, size)) in offsets.iter().zip(sizes.iter()).enumerate() { if nulls.is_valid(i) { let hash = &mut hashes_buffer[i]; - for values_hash in &values_hashes[start.as_usize()..stop.as_usize()] { + let start = offset.as_usize(); + let end = start + size.as_usize(); + for values_hash in &values_hashes[start..end] { *hash = combine_hashes(*hash, *values_hash); } } } } else { - for (i, (start, stop)) in offsets.iter().zip(offsets.iter().skip(1)).enumerate() { + for (i, (offset, size)) in offsets.iter().zip(sizes.iter()).enumerate() { let hash = &mut hashes_buffer[i]; - for values_hash in &values_hashes[start.as_usize()..stop.as_usize()] { + let start = offset.as_usize(); + let end = start + size.as_usize(); + for values_hash in &values_hashes[start..end] { *hash = combine_hashes(*hash, *values_hash); } } @@ -313,17 +703,145 @@ where Ok(()) } +#[cfg(not(feature = "force_hash_collisions"))] +fn hash_union_array( + array: &UnionArray, + random_state: &RandomState, + hashes_buffer: &mut [u64], +) -> Result<()> { + let DataType::Union(union_fields, _mode) = array.data_type() else { + unreachable!() + }; + + if array.is_dense() { + // Dense union: children only contain values of their type, so they're already compact. + // Use the default hashing approach which is efficient for dense unions. + hash_union_array_default(array, union_fields, random_state, hashes_buffer) + } else { + // Sparse union: each child has the same length as the union array. + // Optimization: only hash the elements that are actually referenced by type_ids, + // instead of hashing all K*N elements (where K = num types, N = array length). 
+ hash_sparse_union_array(array, union_fields, random_state, hashes_buffer) + } +} + +/// Default hashing for union arrays - hashes all elements of each child array fully. +/// +/// This approach works for both dense and sparse union arrays: +/// - Dense unions: children are compact (each child only contains values of that type) +/// - Sparse unions: children have the same length as the union array +/// +/// For sparse unions with 3+ types, the optimized take/scatter approach in +/// `hash_sparse_union_array` is more efficient, but for 1-2 types or dense unions, +/// this simpler approach is preferred. +#[cfg(not(feature = "force_hash_collisions"))] +fn hash_union_array_default( + array: &UnionArray, + union_fields: &UnionFields, + random_state: &RandomState, + hashes_buffer: &mut [u64], +) -> Result<()> { + let mut child_hashes: HashMap> = + HashMap::with_capacity(union_fields.len()); + + // Hash each child array fully + for (type_id, _field) in union_fields.iter() { + let child = array.child(type_id); + let mut child_hash_buffer = vec![0; child.len()]; + create_hashes([child], random_state, &mut child_hash_buffer)?; + + child_hashes.insert(type_id, child_hash_buffer); + } + + // Combine hashes for each row using the appropriate child offset + // For dense unions: value_offset points to the actual position in the child + // For sparse unions: value_offset equals the row index + #[expect(clippy::needless_range_loop)] + for i in 0..array.len() { + let type_id = array.type_id(i); + let child_offset = array.value_offset(i); + + let child_hash = child_hashes.get(&type_id).expect("invalid type_id"); + hashes_buffer[i] = combine_hashes(hashes_buffer[i], child_hash[child_offset]); + } + + Ok(()) +} + +/// Hash a sparse union array. +/// Sparse unions have child arrays with the same length as the union array. +/// For 3+ types, we optimize by only hashing the N elements that are actually used +/// (via take/scatter), instead of hashing all K*N elements. 
+/// +/// For 1-2 types, the overhead of take/scatter outweighs the benefit, so we use +/// the default approach of hashing all children (same as dense unions). +#[cfg(not(feature = "force_hash_collisions"))] +fn hash_sparse_union_array( + array: &UnionArray, + union_fields: &UnionFields, + random_state: &RandomState, + hashes_buffer: &mut [u64], +) -> Result<()> { + use std::collections::HashMap; + + // For 1-2 types, the take/scatter overhead isn't worth it. + // Fall back to the default approach (same as dense union). + if union_fields.len() <= 2 { + return hash_union_array_default( + array, + union_fields, + random_state, + hashes_buffer, + ); + } + + let type_ids = array.type_ids(); + + // Group indices by type_id + let mut indices_by_type: HashMap> = HashMap::new(); + for (i, &type_id) in type_ids.iter().enumerate() { + indices_by_type.entry(type_id).or_default().push(i as u32); + } + + // For each type, extract only the needed elements, hash them, and scatter back + for (type_id, _field) in union_fields.iter() { + if let Some(indices) = indices_by_type.get(&type_id) { + if indices.is_empty() { + continue; + } + + let child = array.child(type_id); + let indices_array = UInt32Array::from(indices.clone()); + + // Extract only the elements we need using take() + let filtered = take(child.as_ref(), &indices_array, None)?; + + // Hash the filtered array + let mut filtered_hashes = vec![0u64; filtered.len()]; + create_hashes([&filtered], random_state, &mut filtered_hashes)?; + + // Scatter hashes back to correct positions + for (hash, &idx) in filtered_hashes.iter().zip(indices.iter()) { + hashes_buffer[idx as usize] = + combine_hashes(hashes_buffer[idx as usize], *hash); + } + } + } + + Ok(()) +} + #[cfg(not(feature = "force_hash_collisions"))] fn hash_fixed_list_array( array: &FixedSizeListArray, random_state: &RandomState, hashes_buffer: &mut [u64], ) -> Result<()> { - let values = Arc::clone(array.values()); + let values = array.values(); let value_length = 
array.value_length() as usize; let nulls = array.nulls(); let mut values_hashes = vec![0u64; values.len()]; - create_hashes(&[values], random_state, &mut values_hashes)?; + create_hashes([values], random_state, &mut values_hashes)?; if let Some(nulls) = nulls { for i in 0..array.len() { if nulls.is_valid(i) { @@ -346,83 +864,246 @@ fn hash_fixed_list_array( Ok(()) } -/// Test version of `create_hashes` that produces the same value for -/// all hashes (to test collisions) -/// -/// See comments on `hashes_buffer` for more details +/// Inner hash function for RunArray +#[inline(never)] +#[cfg(not(feature = "force_hash_collisions"))] +fn hash_run_array_inner< + R: RunEndIndexType, + const HAS_NULL_VALUES: bool, + const REHASH: bool, +>( + array: &RunArray, + random_state: &RandomState, + hashes_buffer: &mut [u64], +) -> Result<()> { + // We find the relevant runs that cover potentially sliced arrays, so we can only hash those + // values. Then we find the runs that refer to the original runs and ensure that we apply + // hashes correctly to the sliced, whether sliced at the start, end, or both. + let array_offset = array.offset(); + let array_len = array.len(); + + if array_len == 0 { + return Ok(()); + } + + let run_ends = array.run_ends(); + let run_ends_values = run_ends.values(); + let values = array.values(); + + let start_physical_index = array.get_start_physical_index(); + // get_end_physical_index returns the inclusive last index, but we need the exclusive range end + // for the operations we use below. 
+ let end_physical_index = array.get_end_physical_index() + 1; + + let sliced_values = values.slice( + start_physical_index, + end_physical_index - start_physical_index, + ); + let mut values_hashes = vec![0u64; sliced_values.len()]; + create_hashes( + std::slice::from_ref(&sliced_values), + random_state, + &mut values_hashes, + )?; + + let mut start_in_slice = 0; + for (adjusted_physical_index, &absolute_run_end) in run_ends_values + [start_physical_index..end_physical_index] + .iter() + .enumerate() + { + let absolute_run_end = absolute_run_end.as_usize(); + let end_in_slice = (absolute_run_end - array_offset).min(array_len); + + if HAS_NULL_VALUES && sliced_values.is_null(adjusted_physical_index) { + start_in_slice = end_in_slice; + continue; + } + + let value_hash = values_hashes[adjusted_physical_index]; + let run_slice = &mut hashes_buffer[start_in_slice..end_in_slice]; + + if REHASH { + for hash in run_slice.iter_mut() { + *hash = combine_hashes(value_hash, *hash); + } + } else { + run_slice.fill(value_hash); + } + + start_in_slice = end_in_slice; + } + + Ok(()) +} + +#[cfg(not(feature = "force_hash_collisions"))] +fn hash_run_array( + array: &RunArray, + random_state: &RandomState, + hashes_buffer: &mut [u64], + rehash: bool, +) -> Result<()> { + let has_null_values = array.values().null_count() != 0; + + match (has_null_values, rehash) { + (false, false) => { + hash_run_array_inner::(array, random_state, hashes_buffer) + } + (false, true) => { + hash_run_array_inner::(array, random_state, hashes_buffer) + } + (true, false) => { + hash_run_array_inner::(array, random_state, hashes_buffer) + } + (true, true) => { + hash_run_array_inner::(array, random_state, hashes_buffer) + } + } +} + +/// Internal helper function that hashes a single array and either initializes or combines +/// the hash values in the buffer. 
+#[cfg(not(feature = "force_hash_collisions"))] +fn hash_single_array( + array: &dyn Array, + random_state: &RandomState, + hashes_buffer: &mut [u64], + rehash: bool, +) -> Result<()> { + downcast_primitive_array! { + array => hash_array_primitive(array, random_state, hashes_buffer, rehash), + DataType::Null => hash_null(random_state, hashes_buffer, rehash), + DataType::Boolean => hash_array(&as_boolean_array(array)?, random_state, hashes_buffer, rehash), + DataType::Utf8 => hash_array(&as_string_array(array)?, random_state, hashes_buffer, rehash), + DataType::Utf8View => hash_generic_byte_view_array(as_string_view_array(array)?, random_state, hashes_buffer, rehash), + DataType::LargeUtf8 => hash_array(&as_largestring_array(array), random_state, hashes_buffer, rehash), + DataType::Binary => hash_array(&as_generic_binary_array::(array)?, random_state, hashes_buffer, rehash), + DataType::BinaryView => hash_generic_byte_view_array(as_binary_view_array(array)?, random_state, hashes_buffer, rehash), + DataType::LargeBinary => hash_array(&as_generic_binary_array::(array)?, random_state, hashes_buffer, rehash), + DataType::FixedSizeBinary(_) => { + let array: &FixedSizeBinaryArray = array.as_any().downcast_ref().unwrap(); + hash_array(&array, random_state, hashes_buffer, rehash) + } + DataType::Dictionary(_, _) => downcast_dictionary_array! 
{ + array => hash_dictionary(array, random_state, hashes_buffer, rehash)?, + _ => unreachable!() + } + DataType::Struct(_) => { + let array = as_struct_array(array)?; + hash_struct_array(array, random_state, hashes_buffer)?; + } + DataType::List(_) => { + let array = as_list_array(array)?; + hash_list_array(array, random_state, hashes_buffer)?; + } + DataType::LargeList(_) => { + let array = as_large_list_array(array)?; + hash_list_array(array, random_state, hashes_buffer)?; + } + DataType::ListView(_) => { + let array = as_list_view_array(array)?; + hash_list_view_array(array, random_state, hashes_buffer)?; + } + DataType::LargeListView(_) => { + let array = as_large_list_view_array(array)?; + hash_list_view_array(array, random_state, hashes_buffer)?; + } + DataType::Map(_, _) => { + let array = as_map_array(array)?; + hash_map_array(array, random_state, hashes_buffer)?; + } + DataType::FixedSizeList(_,_) => { + let array = as_fixed_size_list_array(array)?; + hash_fixed_list_array(array, random_state, hashes_buffer)?; + } + DataType::Union(_, _) => { + let array = as_union_array(array)?; + hash_union_array(array, random_state, hashes_buffer)?; + } + DataType::RunEndEncoded(_, _) => downcast_run_array! { + array => hash_run_array(array, random_state, hashes_buffer, rehash)?, + _ => unreachable!() + } + _ => { + // This is internal because we should have caught this before. + return _internal_err!( + "Unsupported data type in hasher: {}", + array.data_type() + ); + } + } + Ok(()) +} + +/// Test version of `hash_single_array` that forces all hashes to collide to zero. 
#[cfg(feature = "force_hash_collisions")] -pub fn create_hashes<'a>( - _arrays: &[ArrayRef], +fn hash_single_array( + _array: &dyn Array, _random_state: &RandomState, - hashes_buffer: &'a mut Vec, -) -> Result<&'a mut Vec> { + hashes_buffer: &mut [u64], + _rehash: bool, +) -> Result<()> { for hash in hashes_buffer.iter_mut() { *hash = 0 } - Ok(hashes_buffer) + Ok(()) } -/// Creates hash values for every row, based on the values in the -/// columns. +/// Something that can be returned as a `&dyn Array`. +/// +/// We want `create_hashes` to accept either `&dyn Array` or `ArrayRef`, +/// and this seems the best way to do so. +/// +/// We tried having it accept `AsRef` +/// but that is not implemented for and cannot be implemented for +/// `&dyn Array` so callers that have the latter would not be able +/// to call `create_hashes` directly. This shim trait makes it possible. +pub trait AsDynArray { + fn as_dyn_array(&self) -> &dyn Array; +} + +impl AsDynArray for dyn Array { + fn as_dyn_array(&self) -> &dyn Array { + self + } +} + +impl AsDynArray for &dyn Array { + fn as_dyn_array(&self) -> &dyn Array { + *self + } +} + +impl AsDynArray for ArrayRef { + fn as_dyn_array(&self) -> &dyn Array { + self.as_ref() + } +} + +impl AsDynArray for &ArrayRef { + fn as_dyn_array(&self) -> &dyn Array { + self.as_ref() + } +} + +/// Creates hash values for every row, based on the values in the columns. /// /// The number of rows to hash is determined by `hashes_buffer.len()`. -/// `hashes_buffer` should be pre-sized appropriately -#[cfg(not(feature = "force_hash_collisions"))] -pub fn create_hashes<'a>( - arrays: &[ArrayRef], +/// `hashes_buffer` should be pre-sized appropriately. 
+pub fn create_hashes<'a, I, T>( + arrays: I, random_state: &RandomState, - hashes_buffer: &'a mut Vec, -) -> Result<&'a mut Vec> { - for (i, col) in arrays.iter().enumerate() { - let array = col.as_ref(); + hashes_buffer: &'a mut [u64], +) -> Result<&'a mut [u64]> +where + I: IntoIterator, + T: AsDynArray, +{ + for (i, array) in arrays.into_iter().enumerate() { // combine hashes with `combine_hashes` for all columns besides the first let rehash = i >= 1; - downcast_primitive_array! { - array => hash_array_primitive(array, random_state, hashes_buffer, rehash), - DataType::Null => hash_null(random_state, hashes_buffer, rehash), - DataType::Boolean => hash_array(as_boolean_array(array)?, random_state, hashes_buffer, rehash), - DataType::Utf8 => hash_array(as_string_array(array)?, random_state, hashes_buffer, rehash), - DataType::Utf8View => hash_array(as_string_view_array(array)?, random_state, hashes_buffer, rehash), - DataType::LargeUtf8 => hash_array(as_largestring_array(array), random_state, hashes_buffer, rehash), - DataType::Binary => hash_array(as_generic_binary_array::(array)?, random_state, hashes_buffer, rehash), - DataType::BinaryView => hash_array(as_binary_view_array(array)?, random_state, hashes_buffer, rehash), - DataType::LargeBinary => hash_array(as_generic_binary_array::(array)?, random_state, hashes_buffer, rehash), - DataType::FixedSizeBinary(_) => { - let array: &FixedSizeBinaryArray = array.as_any().downcast_ref().unwrap(); - hash_array(array, random_state, hashes_buffer, rehash) - } - DataType::Dictionary(_, _) => downcast_dictionary_array! 
{ - array => hash_dictionary(array, random_state, hashes_buffer, rehash)?, - _ => unreachable!() - } - DataType::Struct(_) => { - let array = as_struct_array(array)?; - hash_struct_array(array, random_state, hashes_buffer)?; - } - DataType::List(_) => { - let array = as_list_array(array)?; - hash_list_array(array, random_state, hashes_buffer)?; - } - DataType::LargeList(_) => { - let array = as_large_list_array(array)?; - hash_list_array(array, random_state, hashes_buffer)?; - } - DataType::Map(_, _) => { - let array = as_map_array(array)?; - hash_map_array(array, random_state, hashes_buffer)?; - } - DataType::FixedSizeList(_,_) => { - let array = as_fixed_size_list_array(array)?; - hash_fixed_list_array(array, random_state, hashes_buffer)?; - } - _ => { - // This is internal because we should have caught this before. - return _internal_err!( - "Unsupported data type in hasher: {}", - col.data_type() - ); - } - } + hash_single_array(array.as_dyn_array(), random_state, hashes_buffer, rehash)?; } Ok(hashes_buffer) } @@ -445,8 +1126,8 @@ mod tests { .collect::() .with_precision_and_scale(20, 3) .unwrap(); - let array_ref = Arc::new(array); - let random_state = RandomState::with_seeds(0, 0, 0, 0); + let array_ref: ArrayRef = Arc::new(array); + let random_state = RandomState::with_seed(0); let hashes_buff = &mut vec![0; array_ref.len()]; let hashes = create_hashes(&[array_ref], &random_state, hashes_buff)?; assert_eq!(hashes.len(), 4); @@ -456,19 +1137,25 @@ mod tests { #[test] fn create_hashes_for_empty_fixed_size_lit() -> Result<()> { let empty_array = FixedSizeListBuilder::new(StringBuilder::new(), 1).finish(); - let random_state = RandomState::with_seeds(0, 0, 0, 0); - let hashes_buff = &mut vec![0; 0]; - let hashes = create_hashes(&[Arc::new(empty_array)], &random_state, hashes_buff)?; + let random_state = RandomState::with_seed(0); + let hashes_buff = &mut [0; 0]; + let hashes = create_hashes( + &[Arc::new(empty_array) as ArrayRef], + &random_state, + hashes_buff, 
+ )?; assert_eq!(hashes, &Vec::::new()); Ok(()) } #[test] fn create_hashes_for_float_arrays() -> Result<()> { - let f32_arr = Arc::new(Float32Array::from(vec![0.12, 0.5, 1f32, 444.7])); - let f64_arr = Arc::new(Float64Array::from(vec![0.12, 0.5, 1f64, 444.7])); + let f32_arr: ArrayRef = + Arc::new(Float32Array::from(vec![0.12, 0.5, 1f32, 444.7])); + let f64_arr: ArrayRef = + Arc::new(Float64Array::from(vec![0.12, 0.5, 1f64, 444.7])); - let random_state = RandomState::with_seeds(0, 0, 0, 0); + let random_state = RandomState::with_seed(0); let hashes_buff = &mut vec![0; f32_arr.len()]; let hashes = create_hashes(&[f32_arr], &random_state, hashes_buff)?; assert_eq!(hashes.len(), 4,); @@ -494,18 +1181,15 @@ mod tests { Some(b"Longer than 12 bytes string"), ]; - let binary_array = Arc::new(binary.iter().cloned().collect::<$ARRAY>()); - let ref_array = Arc::new(binary.iter().cloned().collect::()); + let binary_array: ArrayRef = + Arc::new(binary.iter().cloned().collect::<$ARRAY>()); - let random_state = RandomState::with_seeds(0, 0, 0, 0); + let random_state = RandomState::with_seed(0); let mut binary_hashes = vec![0; binary.len()]; create_hashes(&[binary_array], &random_state, &mut binary_hashes) .unwrap(); - let mut ref_hashes = vec![0; binary.len()]; - create_hashes(&[ref_array], &random_state, &mut ref_hashes).unwrap(); - // Null values result in a zero hash, for (val, hash) in binary.iter().zip(binary_hashes.iter()) { match val { @@ -514,9 +1198,6 @@ mod tests { } } - // same logical values should hash to the same hash value - assert_eq!(binary_hashes, ref_hashes); - // Same values should map to same hash values assert_eq!(binary[0], binary[5]); assert_eq!(binary[4], binary[6]); @@ -528,15 +1209,16 @@ mod tests { } create_hash_binary!(binary_array, BinaryArray); + create_hash_binary!(large_binary_array, LargeBinaryArray); create_hash_binary!(binary_view_array, BinaryViewArray); #[test] fn create_hashes_fixed_size_binary() -> Result<()> { let input_arg = vec![vec![1, 
2], vec![5, 6], vec![5, 6]]; - let fixed_size_binary_array = + let fixed_size_binary_array: ArrayRef = Arc::new(FixedSizeBinaryArray::try_from_iter(input_arg.into_iter()).unwrap()); - let random_state = RandomState::with_seeds(0, 0, 0, 0); + let random_state = RandomState::with_seed(0); let hashes_buff = &mut vec![0; fixed_size_binary_array.len()]; let hashes = create_hashes(&[fixed_size_binary_array], &random_state, hashes_buff)?; @@ -560,15 +1242,16 @@ mod tests { Some("Longer than 12 bytes string"), ]; - let string_array = Arc::new(strings.iter().cloned().collect::<$ARRAY>()); - let dict_array = Arc::new( + let string_array: ArrayRef = + Arc::new(strings.iter().cloned().collect::<$ARRAY>()); + let dict_array: ArrayRef = Arc::new( strings .iter() .cloned() .collect::>(), ); - let random_state = RandomState::with_seeds(0, 0, 0, 0); + let random_state = RandomState::with_seed(0); let mut string_hashes = vec![0; strings.len()]; create_hashes(&[string_array], &random_state, &mut string_hashes) @@ -603,21 +1286,90 @@ mod tests { create_hash_string!(string_view_array, StringArray); create_hash_string!(dict_string_array, DictionaryArray); + #[test] + #[cfg(not(feature = "force_hash_collisions"))] + fn create_hashes_for_run_array() -> Result<()> { + let values = Arc::new(Int32Array::from(vec![10, 20, 30])); + let run_ends = Arc::new(Int32Array::from(vec![2, 5, 7])); + let array = Arc::new(RunArray::try_new(&run_ends, values.as_ref()).unwrap()); + + let random_state = RandomState::with_seed(0); + let hashes_buff = &mut vec![0; array.len()]; + let hashes = create_hashes( + &[Arc::clone(&array) as ArrayRef], + &random_state, + hashes_buff, + )?; + + assert_eq!(hashes.len(), 7); + assert_eq!(hashes[0], hashes[1]); + assert_eq!(hashes[2], hashes[3]); + assert_eq!(hashes[3], hashes[4]); + assert_eq!(hashes[5], hashes[6]); + assert_ne!(hashes[0], hashes[2]); + assert_ne!(hashes[2], hashes[5]); + assert_ne!(hashes[0], hashes[5]); + + Ok(()) + } + + #[test] + #[cfg(not(feature = 
"force_hash_collisions"))] + fn create_multi_column_hash_with_run_array() -> Result<()> { + let int_array = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7])); + let values = Arc::new(StringArray::from(vec!["foo", "bar", "baz"])); + let run_ends = Arc::new(Int32Array::from(vec![2, 5, 7])); + let run_array = Arc::new(RunArray::try_new(&run_ends, values.as_ref()).unwrap()); + + let random_state = RandomState::with_seed(0); + let mut one_col_hashes = vec![0; int_array.len()]; + create_hashes( + &[Arc::clone(&int_array) as ArrayRef], + &random_state, + &mut one_col_hashes, + )?; + + let mut two_col_hashes = vec![0; int_array.len()]; + create_hashes( + &[ + Arc::clone(&int_array) as ArrayRef, + Arc::clone(&run_array) as ArrayRef, + ], + &random_state, + &mut two_col_hashes, + )?; + + assert_eq!(one_col_hashes.len(), 7); + assert_eq!(two_col_hashes.len(), 7); + assert_ne!(one_col_hashes, two_col_hashes); + + let diff_0_vs_1_one_col = one_col_hashes[0] != one_col_hashes[1]; + let diff_0_vs_1_two_col = two_col_hashes[0] != two_col_hashes[1]; + assert_eq!(diff_0_vs_1_one_col, diff_0_vs_1_two_col); + + let diff_2_vs_3_one_col = one_col_hashes[2] != one_col_hashes[3]; + let diff_2_vs_3_two_col = two_col_hashes[2] != two_col_hashes[3]; + assert_eq!(diff_2_vs_3_one_col, diff_2_vs_3_two_col); + + Ok(()) + } + #[test] // Tests actual values of hashes, which are different if forcing collisions #[cfg(not(feature = "force_hash_collisions"))] fn create_hashes_for_dict_arrays() { let strings = [Some("foo"), None, Some("bar"), Some("foo"), None]; - let string_array = Arc::new(strings.iter().cloned().collect::()); - let dict_array = Arc::new( + let string_array: ArrayRef = + Arc::new(strings.iter().cloned().collect::()); + let dict_array: ArrayRef = Arc::new( strings .iter() .cloned() .collect::>(), ); - let random_state = RandomState::with_seeds(0, 0, 0, 0); + let random_state = RandomState::with_seed(0); let mut string_hashes = vec![0; strings.len()]; create_hashes(&[string_array], 
&random_state, &mut string_hashes).unwrap(); @@ -662,7 +1414,7 @@ mod tests { ]; let list_array = Arc::new(ListArray::from_iter_primitive::(data)) as ArrayRef; - let random_state = RandomState::with_seeds(0, 0, 0, 0); + let random_state = RandomState::with_seed(0); let mut hashes = vec![0; list_array.len()]; create_hashes(&[list_array], &random_state, &mut hashes).unwrap(); assert_eq!(hashes[0], hashes[5]); @@ -671,6 +1423,130 @@ mod tests { assert_eq!(hashes[1], hashes[6]); // null vs empty list } + #[test] + #[cfg(not(feature = "force_hash_collisions"))] + fn create_hashes_for_sliced_list_arrays() { + let data = vec![ + Some(vec![Some(0), Some(1), Some(2)]), + None, + // Slice from here + Some(vec![Some(3), None, Some(5)]), + Some(vec![Some(3), None, Some(5)]), + None, + // To here + Some(vec![Some(0), Some(1), Some(2)]), + Some(vec![]), + ]; + let list_array = + Arc::new(ListArray::from_iter_primitive::(data)) as ArrayRef; + let list_array = list_array.slice(2, 3); + let random_state = RandomState::with_seed(0); + let mut hashes = vec![0; list_array.len()]; + create_hashes(&[list_array], &random_state, &mut hashes).unwrap(); + assert_eq!(hashes[0], hashes[1]); + assert_ne!(hashes[1], hashes[2]); + } + + #[test] + // Tests actual values of hashes, which are different if forcing collisions + #[cfg(not(feature = "force_hash_collisions"))] + fn create_hashes_for_list_view_arrays() { + use arrow::buffer::{NullBuffer, ScalarBuffer}; + + // Create values array: [0, 1, 2, 3, null, 5] + let values = Arc::new(Int32Array::from(vec![ + Some(0), + Some(1), + Some(2), + Some(3), + None, + Some(5), + ])) as ArrayRef; + let field = Arc::new(Field::new("item", DataType::Int32, true)); + + // Create ListView with the following logical structure: + // Row 0: [0, 1, 2] (offset=0, size=3) + // Row 1: null (null bit set) + // Row 2: [3, null, 5] (offset=3, size=3) + // Row 3: [3, null, 5] (offset=3, size=3) - same as row 2 + // Row 4: null (null bit set) + // Row 5: [0, 1, 2] 
(offset=0, size=3) - same as row 0 + // Row 6: [] (offset=0, size=0) - empty list + let offsets = ScalarBuffer::from(vec![0i32, 0, 3, 3, 0, 0, 0]); + let sizes = ScalarBuffer::from(vec![3i32, 0, 3, 3, 0, 3, 0]); + let nulls = Some(NullBuffer::from(vec![ + true, false, true, true, false, true, true, + ])); + + let list_view_array = + Arc::new(ListViewArray::new(field, offsets, sizes, values, nulls)) + as ArrayRef; + + let random_state = RandomState::with_seed(0); + let mut hashes = vec![0; list_view_array.len()]; + create_hashes(&[list_view_array], &random_state, &mut hashes).unwrap(); + + assert_eq!(hashes[0], hashes[5]); // same content [0, 1, 2] + assert_eq!(hashes[1], hashes[4]); // both null + assert_eq!(hashes[2], hashes[3]); // same content [3, null, 5] + assert_eq!(hashes[1], hashes[6]); // null vs empty list + + // Negative tests: different content should produce different hashes + assert_ne!(hashes[0], hashes[2]); // [0, 1, 2] vs [3, null, 5] + assert_ne!(hashes[0], hashes[6]); // [0, 1, 2] vs [] + assert_ne!(hashes[2], hashes[6]); // [3, null, 5] vs [] + } + + #[test] + // Tests actual values of hashes, which are different if forcing collisions + #[cfg(not(feature = "force_hash_collisions"))] + fn create_hashes_for_large_list_view_arrays() { + use arrow::buffer::{NullBuffer, ScalarBuffer}; + + // Create values array: [0, 1, 2, 3, null, 5] + let values = Arc::new(Int32Array::from(vec![ + Some(0), + Some(1), + Some(2), + Some(3), + None, + Some(5), + ])) as ArrayRef; + let field = Arc::new(Field::new("item", DataType::Int32, true)); + + // Create LargeListView with the following logical structure: + // Row 0: [0, 1, 2] (offset=0, size=3) + // Row 1: null (null bit set) + // Row 2: [3, null, 5] (offset=3, size=3) + // Row 3: [3, null, 5] (offset=3, size=3) - same as row 2 + // Row 4: null (null bit set) + // Row 5: [0, 1, 2] (offset=0, size=3) - same as row 0 + // Row 6: [] (offset=0, size=0) - empty list + let offsets = ScalarBuffer::from(vec![0i64, 0, 3, 
3, 0, 0, 0]); + let sizes = ScalarBuffer::from(vec![3i64, 0, 3, 3, 0, 3, 0]); + let nulls = Some(NullBuffer::from(vec![ + true, false, true, true, false, true, true, + ])); + + let large_list_view_array = Arc::new(LargeListViewArray::new( + field, offsets, sizes, values, nulls, + )) as ArrayRef; + + let random_state = RandomState::with_seed(0); + let mut hashes = vec![0; large_list_view_array.len()]; + create_hashes(&[large_list_view_array], &random_state, &mut hashes).unwrap(); + + assert_eq!(hashes[0], hashes[5]); // same content [0, 1, 2] + assert_eq!(hashes[1], hashes[4]); // both null + assert_eq!(hashes[2], hashes[3]); // same content [3, null, 5] + assert_eq!(hashes[1], hashes[6]); // null vs empty list + + // Negative tests: different content should produce different hashes + assert_ne!(hashes[0], hashes[2]); // [0, 1, 2] vs [3, null, 5] + assert_ne!(hashes[0], hashes[6]); // [0, 1, 2] vs [] + assert_ne!(hashes[2], hashes[6]); // [3, null, 5] vs [] + } + #[test] // Tests actual values of hashes, which are different if forcing collisions #[cfg(not(feature = "force_hash_collisions"))] @@ -687,7 +1563,7 @@ mod tests { Arc::new(FixedSizeListArray::from_iter_primitive::( data, 3, )) as ArrayRef; - let random_state = RandomState::with_seeds(0, 0, 0, 0); + let random_state = RandomState::with_seed(0); let mut hashes = vec![0; list_array.len()]; create_hashes(&[list_array], &random_state, &mut hashes).unwrap(); assert_eq!(hashes[0], hashes[5]); @@ -737,7 +1613,7 @@ mod tests { let array = Arc::new(struct_array) as ArrayRef; - let random_state = RandomState::with_seeds(0, 0, 0, 0); + let random_state = RandomState::with_seed(0); let mut hashes = vec![0; array.len()]; create_hashes(&[array], &random_state, &mut hashes).unwrap(); assert_eq!(hashes[0], hashes[1]); @@ -774,7 +1650,7 @@ mod tests { assert!(struct_array.is_valid(1)); let array = Arc::new(struct_array) as ArrayRef; - let random_state = RandomState::with_seeds(0, 0, 0, 0); + let random_state = 
RandomState::with_seed(0); let mut hashes = vec![0; array.len()]; create_hashes(&[array], &random_state, &mut hashes).unwrap(); assert_eq!(hashes[0], hashes[1]); @@ -827,7 +1703,7 @@ mod tests { let array = Arc::new(builder.finish()) as ArrayRef; - let random_state = RandomState::with_seeds(0, 0, 0, 0); + let random_state = RandomState::with_seed(0); let mut hashes = vec![0; array.len()]; create_hashes(&[array], &random_state, &mut hashes).unwrap(); assert_eq!(hashes[0], hashes[1]); // same value @@ -845,15 +1721,16 @@ mod tests { let strings1 = [Some("foo"), None, Some("bar")]; let strings2 = [Some("blarg"), Some("blah"), None]; - let string_array = Arc::new(strings1.iter().cloned().collect::()); - let dict_array = Arc::new( + let string_array: ArrayRef = + Arc::new(strings1.iter().cloned().collect::()); + let dict_array: ArrayRef = Arc::new( strings2 .iter() .cloned() .collect::>(), ); - let random_state = RandomState::with_seeds(0, 0, 0, 0); + let random_state = RandomState::with_seed(0); let mut one_col_hashes = vec![0; strings1.len()]; create_hashes( @@ -876,4 +1753,345 @@ mod tests { assert_ne!(one_col_hashes, two_col_hashes); } + + #[test] + fn test_create_hashes_from_arrays() { + let int_array: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4])); + let float_array: ArrayRef = + Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])); + + let random_state = RandomState::with_seed(0); + let hashes_buff = &mut vec![0; int_array.len()]; + let hashes = + create_hashes(&[int_array, float_array], &random_state, hashes_buff).unwrap(); + assert_eq!(hashes.len(), 4,); + } + + #[test] + fn test_create_hashes_from_dyn_arrays() { + let int_array: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4])); + let float_array: ArrayRef = + Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])); + + // Verify that we can call create_hashes with only &dyn Array + fn test(arr1: &dyn Array, arr2: &dyn Array) { + let random_state = RandomState::with_seed(0); + let hashes_buff = 
&mut vec![0; arr1.len()]; + let hashes = create_hashes([arr1, arr2], &random_state, hashes_buff).unwrap(); + assert_eq!(hashes.len(), 4,); + } + test(&*int_array, &*float_array); + } + + #[test] + fn test_create_hashes_equivalence() { + let array: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4])); + let random_state = RandomState::with_seed(0); + + let mut hashes1 = vec![0; array.len()]; + create_hashes( + &[Arc::clone(&array) as ArrayRef], + &random_state, + &mut hashes1, + ) + .unwrap(); + + let mut hashes2 = vec![0; array.len()]; + create_hashes([array], &random_state, &mut hashes2).unwrap(); + + assert_eq!(hashes1, hashes2); + } + + #[test] + fn test_with_hashes() { + let array: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4])); + let random_state = RandomState::with_seed(0); + + // Test that with_hashes produces the same results as create_hashes + let mut expected_hashes = vec![0; array.len()]; + create_hashes([&array], &random_state, &mut expected_hashes).unwrap(); + + let result = with_hashes([&array], &random_state, |hashes| { + assert_eq!(hashes.len(), 4); + // Verify hashes match expected values + assert_eq!(hashes, &expected_hashes[..]); + // Return a copy of the hashes + Ok(hashes.to_vec()) + }) + .unwrap(); + + // Verify callback result is returned correctly + assert_eq!(result, expected_hashes); + } + + #[test] + fn test_with_hashes_multi_column() { + let int_array: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3])); + let str_array: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c"])); + let random_state = RandomState::with_seed(0); + + // Test multi-column hashing + let mut expected_hashes = vec![0; int_array.len()]; + create_hashes( + [&int_array, &str_array], + &random_state, + &mut expected_hashes, + ) + .unwrap(); + + with_hashes([&int_array, &str_array], &random_state, |hashes| { + assert_eq!(hashes.len(), 3); + assert_eq!(hashes, &expected_hashes[..]); + Ok(()) + }) + .unwrap(); + } + + #[test] + fn 
test_with_hashes_empty_arrays() { + let random_state = RandomState::with_seed(0); + + // Test that passing no arrays returns an error + let empty: [&ArrayRef; 0] = []; + let result = with_hashes(empty, &random_state, |_hashes| Ok(())); + + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("requires at least one array") + ); + } + + #[test] + fn test_with_hashes_reentrancy() { + let array: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3])); + let array2: ArrayRef = Arc::new(Int32Array::from(vec![4, 5, 6])); + let random_state = RandomState::with_seed(0); + + // Test that reentrant calls return an error instead of panicking + let result = with_hashes([&array], &random_state, |_hashes| { + // Try to call with_hashes again inside the callback + with_hashes([&array2], &random_state, |_inner_hashes| Ok(())) + }); + + assert!(result.is_err()); + let err_msg = result.unwrap_err().to_string(); + assert!( + err_msg.contains("reentrantly") || err_msg.contains("cannot be called"), + "Error message should mention reentrancy: {err_msg}", + ); + } + + #[test] + #[cfg(not(feature = "force_hash_collisions"))] + fn create_hashes_for_sparse_union_arrays() { + // logical array: [int(5), str("foo"), int(10), int(5)] + let int_array = Int32Array::from(vec![Some(5), None, Some(10), Some(5)]); + let str_array = StringArray::from(vec![None, Some("foo"), None, None]); + + let type_ids = vec![0_i8, 1, 0, 0].into(); + let children = vec![ + Arc::new(int_array) as ArrayRef, + Arc::new(str_array) as ArrayRef, + ]; + + let union_fields = [ + (0, Arc::new(Field::new("a", DataType::Int32, true))), + (1, Arc::new(Field::new("b", DataType::Utf8, true))), + ] + .into_iter() + .collect(); + + let array = UnionArray::try_new(union_fields, type_ids, None, children).unwrap(); + let array_ref = Arc::new(array) as ArrayRef; + + let random_state = RandomState::with_seed(0); + let mut hashes = vec![0; array_ref.len()]; + create_hashes(&[array_ref], &random_state, &mut 
hashes).unwrap(); + + // Rows 0 and 3 both have type_id=0 (int) with value 5 + assert_eq!(hashes[0], hashes[3]); + // Row 0 (int 5) vs Row 2 (int 10) - different values + assert_ne!(hashes[0], hashes[2]); + // Row 0 (int) vs Row 1 (string) - different types + assert_ne!(hashes[0], hashes[1]); + } + + #[test] + #[cfg(not(feature = "force_hash_collisions"))] + fn create_hashes_for_sparse_union_arrays_with_nulls() { + // logical array: [int(5), str("foo"), int(null), str(null)] + let int_array = Int32Array::from(vec![Some(5), None, None, None]); + let str_array = StringArray::from(vec![None, Some("foo"), None, None]); + + let type_ids = vec![0, 1, 0, 1].into(); + let children = vec![ + Arc::new(int_array) as ArrayRef, + Arc::new(str_array) as ArrayRef, + ]; + + let union_fields = [ + (0, Arc::new(Field::new("a", DataType::Int32, true))), + (1, Arc::new(Field::new("b", DataType::Utf8, true))), + ] + .into_iter() + .collect(); + + let array = UnionArray::try_new(union_fields, type_ids, None, children).unwrap(); + let array_ref = Arc::new(array) as ArrayRef; + + let random_state = RandomState::with_seed(0); + let mut hashes = vec![0; array_ref.len()]; + create_hashes(&[array_ref], &random_state, &mut hashes).unwrap(); + + // row 2 (int null) and row 3 (str null) should have the same hash + // because they are both null values + assert_eq!(hashes[2], hashes[3]); + + // row 0 (int 5) vs row 2 (int null) - different (value vs null) + assert_ne!(hashes[0], hashes[2]); + + // row 1 (str "foo") vs row 3 (str null) - different (value vs null) + assert_ne!(hashes[1], hashes[3]); + } + + #[test] + #[cfg(not(feature = "force_hash_collisions"))] + fn create_hashes_for_dense_union_arrays() { + // creates a dense union array with int and string types + // [67, "norm", 100, "macdonald", 67] + let int_array = Int32Array::from(vec![67, 100, 67]); + let str_array = StringArray::from(vec!["norm", "macdonald"]); + + let type_ids = vec![0, 1, 0, 1, 0].into(); + let offsets = vec![0, 0, 1, 
1, 2].into(); + let children = vec![ + Arc::new(int_array) as ArrayRef, + Arc::new(str_array) as ArrayRef, + ]; + + let union_fields = [ + (0, Arc::new(Field::new("a", DataType::Int32, false))), + (1, Arc::new(Field::new("b", DataType::Utf8, false))), + ] + .into_iter() + .collect(); + + let array = + UnionArray::try_new(union_fields, type_ids, Some(offsets), children).unwrap(); + let array_ref = Arc::new(array) as ArrayRef; + + let random_state = RandomState::with_seed(0); + let mut hashes = vec![0; array_ref.len()]; + create_hashes(&[array_ref], &random_state, &mut hashes).unwrap(); + + // 67 vs "norm" + assert_ne!(hashes[0], hashes[1]); + // 67 vs 100 + assert_ne!(hashes[0], hashes[2]); + // "norm" vs "macdonald" + assert_ne!(hashes[1], hashes[3]); + // 100 vs "macdonald" + assert_ne!(hashes[2], hashes[3]); + // 67 vs 67 + assert_eq!(hashes[0], hashes[4]); + } + + #[test] + #[cfg(not(feature = "force_hash_collisions"))] + fn create_hashes_for_sliced_run_array() -> Result<()> { + let values = Arc::new(Int32Array::from(vec![10, 20, 30])); + let run_ends = Arc::new(Int32Array::from(vec![2, 5, 7])); + let array = Arc::new(RunArray::try_new(&run_ends, values.as_ref()).unwrap()); + + let random_state = RandomState::with_seed(0); + let mut full_hashes = vec![0; array.len()]; + create_hashes( + &[Arc::clone(&array) as ArrayRef], + &random_state, + &mut full_hashes, + )?; + + let array_ref: ArrayRef = Arc::clone(&array) as ArrayRef; + let sliced_array = array_ref.slice(2, 3); + + let mut sliced_hashes = vec![0; sliced_array.len()]; + create_hashes( + std::slice::from_ref(&sliced_array), + &random_state, + &mut sliced_hashes, + )?; + + assert_eq!(sliced_hashes.len(), 3); + assert_eq!(sliced_hashes[0], sliced_hashes[1]); + assert_eq!(sliced_hashes[1], sliced_hashes[2]); + assert_eq!(&sliced_hashes, &full_hashes[2..5]); + + Ok(()) + } + + #[test] + #[cfg(not(feature = "force_hash_collisions"))] + fn test_run_array_with_nulls() -> Result<()> { + let values = 
Arc::new(Int32Array::from(vec![Some(10), None, Some(20)])); + let run_ends = Arc::new(Int32Array::from(vec![2, 4, 6])); + let array = Arc::new(RunArray::try_new(&run_ends, values.as_ref()).unwrap()); + + let random_state = RandomState::with_seed(0); + let mut hashes = vec![0; array.len()]; + create_hashes( + &[Arc::clone(&array) as ArrayRef], + &random_state, + &mut hashes, + )?; + + assert_eq!(hashes[0], hashes[1]); + assert_ne!(hashes[0], 0); + assert_eq!(hashes[2], hashes[3]); + assert_eq!(hashes[2], 0); + assert_eq!(hashes[4], hashes[5]); + assert_ne!(hashes[4], 0); + assert_ne!(hashes[0], hashes[4]); + + Ok(()) + } + + #[test] + #[cfg(not(feature = "force_hash_collisions"))] + fn test_run_array_with_nulls_multicolumn() -> Result<()> { + let primitive_array = Arc::new(Int32Array::from(vec![Some(10), None, Some(20)])); + let run_values = Arc::new(Int32Array::from(vec![Some(10), None, Some(20)])); + let run_ends = Arc::new(Int32Array::from(vec![1, 2, 3])); + let run_array = + Arc::new(RunArray::try_new(&run_ends, run_values.as_ref()).unwrap()); + let second_col = Arc::new(Int32Array::from(vec![100, 200, 300])); + + let random_state = RandomState::with_seed(0); + + let mut primitive_hashes = vec![0; 3]; + create_hashes( + &[ + Arc::clone(&primitive_array) as ArrayRef, + Arc::clone(&second_col) as ArrayRef, + ], + &random_state, + &mut primitive_hashes, + )?; + + let mut run_hashes = vec![0; 3]; + create_hashes( + &[ + Arc::clone(&run_array) as ArrayRef, + Arc::clone(&second_col) as ArrayRef, + ], + &random_state, + &mut run_hashes, + )?; + + assert_eq!(primitive_hashes, run_hashes); + + Ok(()) + } } diff --git a/datafusion/common/src/instant.rs b/datafusion/common/src/instant.rs index 42f21c061c0c2..a5dfb28292581 100644 --- a/datafusion/common/src/instant.rs +++ b/datafusion/common/src/instant.rs @@ -22,7 +22,7 @@ /// under `wasm` feature gate. It provides the same API as [`std::time::Instant`]. 
pub type Instant = web_time::Instant; -#[allow(clippy::disallowed_types)] +#[expect(clippy::disallowed_types)] #[cfg(not(target_family = "wasm"))] /// DataFusion wrapper around [`std::time::Instant`]. This is only a type alias. pub type Instant = std::time::Instant; diff --git a/datafusion/common/src/join_type.rs b/datafusion/common/src/join_type.rs index ac81d977b7296..8855e993f2bc7 100644 --- a/datafusion/common/src/join_type.rs +++ b/datafusion/common/src/join_type.rs @@ -67,6 +67,11 @@ pub enum JoinType { /// /// [1]: http://btw2017.informatik.uni-stuttgart.de/slidesandpapers/F1-10-37/paper_web.pdf LeftMark, + /// Right Mark Join + /// + /// Same logic as the LeftMark Join above, however it returns a record for each record from the + /// right input. + RightMark, } impl JoinType { @@ -87,13 +92,41 @@ impl JoinType { JoinType::RightSemi => JoinType::LeftSemi, JoinType::LeftAnti => JoinType::RightAnti, JoinType::RightAnti => JoinType::LeftAnti, - JoinType::LeftMark => { - unreachable!("LeftMark join type does not support swapping") - } + JoinType::LeftMark => JoinType::RightMark, + JoinType::RightMark => JoinType::LeftMark, + } + } + + /// Whether each side of the join is preserved for ON-clause filter pushdown. + /// + /// It is only correct to push ON-clause filters below a join for preserved + /// inputs. + /// + /// # "Preserved" input definition + /// + /// A join side is preserved if the join returns all or a subset of the rows + /// from that side, such that each output row directly maps to an input row. + /// If a side is not preserved, the join can produce extra null rows that + /// don't map to any input row. + /// + /// # Return Value + /// + /// A tuple of booleans - (left_preserved, right_preserved). 
+ pub fn on_lr_is_preserved(&self) -> (bool, bool) { + match self { + JoinType::Inner => (true, true), + JoinType::Left => (false, true), + JoinType::Right => (true, false), + JoinType::Full => (false, false), + JoinType::LeftSemi | JoinType::RightSemi => (true, true), + JoinType::LeftAnti => (false, true), + JoinType::RightAnti => (true, false), + JoinType::LeftMark => (false, true), + JoinType::RightMark => (true, false), } } - /// Does the join type support swapping inputs? + /// Does the join type support swapping inputs? pub fn supports_swap(&self) -> bool { matches!( self, @@ -105,6 +138,8 @@ impl JoinType { | JoinType::RightSemi | JoinType::LeftAnti | JoinType::RightAnti + | JoinType::LeftMark + | JoinType::RightMark ) } } @@ -121,6 +156,7 @@ impl Display for JoinType { JoinType::LeftAnti => "LeftAnti", JoinType::RightAnti => "RightAnti", JoinType::LeftMark => "LeftMark", + JoinType::RightMark => "RightMark", }; write!(f, "{join_type}") } @@ -141,6 +177,7 @@ impl FromStr for JoinType { "LEFTANTI" => Ok(JoinType::LeftAnti), "RIGHTANTI" => Ok(JoinType::RightAnti), "LEFTMARK" => Ok(JoinType::LeftMark), + "RIGHTMARK" => Ok(JoinType::RightMark), _ => _not_impl_err!("The join type {s} does not exist or is not implemented"), } } diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index 7b2c86d3975ff..fdd04f752455e 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -19,18 +19,17 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] 
+#![cfg_attr(test, allow(clippy::needless_pass_by_value))] mod column; mod dfschema; mod functional_dependencies; mod join_type; mod param_value; -#[cfg(feature = "pyarrow")] -mod pyarrow; mod schema_reference; mod table_reference; mod unnest; @@ -39,13 +38,19 @@ pub mod alias; pub mod cast; pub mod config; pub mod cse; +pub mod datatype; pub mod diagnostic; pub mod display; +pub mod encryption; pub mod error; pub mod file_options; pub mod format; pub mod hash_utils; pub mod instant; +pub mod metadata; +pub mod nested_struct; +mod null_equality; +pub mod parquet_config; pub mod parsers; pub mod pruning; pub mod rounding; @@ -56,29 +61,33 @@ pub mod test_util; pub mod tree_node; pub mod types; pub mod utils; - /// Reexport arrow crate pub use arrow; pub use column::Column; pub use dfschema::{ - qualified_name, DFSchema, DFSchemaRef, ExprSchema, SchemaExt, ToDFSchema, + DFSchema, DFSchemaRef, ExprSchema, SchemaExt, ToDFSchema, qualified_name, }; pub use diagnostic::Diagnostic; +pub use display::human_readable::{ + human_readable_count, human_readable_duration, human_readable_size, units, +}; pub use error::{ - field_not_found, unqualified_field_not_found, DataFusionError, Result, SchemaError, - SharedResult, + DataFusionError, Result, SchemaError, SharedResult, field_not_found, + unqualified_field_not_found, }; pub use file_options::file_type::{ - GetExt, DEFAULT_ARROW_EXTENSION, DEFAULT_AVRO_EXTENSION, DEFAULT_CSV_EXTENSION, - DEFAULT_JSON_EXTENSION, DEFAULT_PARQUET_EXTENSION, + DEFAULT_ARROW_EXTENSION, DEFAULT_AVRO_EXTENSION, DEFAULT_CSV_EXTENSION, + DEFAULT_JSON_EXTENSION, DEFAULT_PARQUET_EXTENSION, GetExt, }; pub use functional_dependencies::{ + Constraint, Constraints, Dependency, FunctionalDependence, FunctionalDependencies, aggregate_functional_dependencies, get_required_group_by_exprs_indices, - get_target_functional_dependencies, Constraint, Constraints, Dependency, - FunctionalDependence, FunctionalDependencies, + get_target_functional_dependencies, }; -use 
hashbrown::hash_map::DefaultHashBuilder; +use hashbrown::DefaultHashBuilder; pub use join_type::{JoinConstraint, JoinSide, JoinType}; +pub use nested_struct::cast_column; +pub use null_equality::NullEquality; pub use param_value::ParamValues; pub use scalar::{ScalarType, ScalarValue}; pub use schema_reference::SchemaReference; @@ -95,14 +104,20 @@ pub use utils::project_schema; // https://github.com/rust-lang/rust/pull/52234#issuecomment-976702997 #[doc(hidden)] pub use error::{ - _config_datafusion_err, _exec_datafusion_err, _internal_datafusion_err, - _not_impl_datafusion_err, _plan_datafusion_err, _resources_datafusion_err, - _substrait_datafusion_err, + _config_datafusion_err, _exec_datafusion_err, _ffi_datafusion_err, + _internal_datafusion_err, _not_impl_datafusion_err, _plan_datafusion_err, + _resources_datafusion_err, _substrait_datafusion_err, }; // The HashMap and HashSet implementations that should be used as the uniform defaults pub type HashMap = hashbrown::HashMap; pub type HashSet = hashbrown::HashSet; +pub mod hash_map { + pub use hashbrown::hash_map::Entry; +} +pub mod hash_set { + pub use hashbrown::hash_set::Entry; +} /// Downcast an Arrow Array to a concrete type, return an `DataFusionError::Internal` if the cast is /// not possible. In normal usage of DataFusion the downcast should always succeed. @@ -123,10 +138,10 @@ macro_rules! downcast_value { // Not public API. 
#[doc(hidden)] pub mod __private { - use crate::error::_internal_datafusion_err; use crate::Result; + use crate::error::_internal_datafusion_err; use arrow::array::Array; - use std::any::{type_name, Any}; + use std::any::{Any, type_name}; #[doc(hidden)] pub trait DowncastArrayHelper { @@ -136,10 +151,12 @@ pub mod __private { impl DowncastArrayHelper for T { fn downcast_array_helper(&self) -> Result<&U> { self.as_any().downcast_ref().ok_or_else(|| { + let actual_type = self.data_type(); + let desired_type_name = type_name::(); _internal_datafusion_err!( "could not cast array of type {} to {}", - self.data_type(), - type_name::() + actual_type, + desired_type_name ) }) } @@ -175,7 +192,7 @@ mod tests { assert_starts_with( error.to_string(), - "Internal error: could not cast array of type Int32 to arrow_array::array::primitive_array::PrimitiveArray" + "Internal error: could not cast array of type Int32 to arrow_array::array::primitive_array::PrimitiveArray", ); } diff --git a/datafusion/common/src/metadata.rs b/datafusion/common/src/metadata.rs new file mode 100644 index 0000000000000..d6d8fb7b0ed0c --- /dev/null +++ b/datafusion/common/src/metadata.rs @@ -0,0 +1,384 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::{collections::BTreeMap, sync::Arc}; + +use arrow::datatypes::{DataType, Field, FieldRef}; +use hashbrown::HashMap; + +use crate::{DataFusionError, ScalarValue, error::_plan_err}; + +/// A [`ScalarValue`] with optional [`FieldMetadata`] +#[derive(Debug, Clone)] +pub struct ScalarAndMetadata { + pub value: ScalarValue, + pub metadata: Option, +} + +impl ScalarAndMetadata { + /// Create a new Literal from a scalar value with optional [`FieldMetadata`] + pub fn new(value: ScalarValue, metadata: Option) -> Self { + Self { value, metadata } + } + + /// Access the underlying [ScalarValue] storage + pub fn value(&self) -> &ScalarValue { + &self.value + } + + /// Access the [FieldMetadata] attached to this value, if any + pub fn metadata(&self) -> Option<&FieldMetadata> { + self.metadata.as_ref() + } + + /// Consume self and return components + pub fn into_inner(self) -> (ScalarValue, Option) { + (self.value, self.metadata) + } + + /// Cast this value's storage type + /// + /// This operation assumes that if the underlying [ScalarValue] can be cast + /// to a given type that any extension type represented by the metadata is also + /// valid.
+ pub fn cast_storage_to( + &self, + target_type: &DataType, + ) -> Result { + let new_value = self.value().cast_to(target_type)?; + Ok(Self::new(new_value, self.metadata.clone())) + } +} + +/// create a new ScalarAndMetadata from a ScalarValue without +/// any metadata +impl From for ScalarAndMetadata { + fn from(value: ScalarValue) -> Self { + Self::new(value, None) + } +} + +/// Assert equality of data types where one or both sides may have field metadata +/// +/// This currently compares absent metadata (e.g., one side was a DataType) and +/// empty metadata (e.g., one side was a field where the field had no metadata) +/// as equal and uses byte-for-byte comparison for the keys and values of the +/// fields, even though this is potentially too strict for some cases (e.g., +/// extension types where extension metadata is represented by JSON, or cases +/// where field metadata is orthogonal to the interpretation of the data type). +/// +/// Returns a planning error with suitably formatted type representations if +/// actual and expected do not compare to equal. 
+pub fn check_metadata_with_storage_equal(
+    actual: (
+        &DataType,
+        Option<&std::collections::HashMap<String, String>>,
+    ),
+    expected: (
+        &DataType,
+        Option<&std::collections::HashMap<String, String>>,
+    ),
+    what: &str,
+    context: &str,
+) -> Result<(), DataFusionError> {
+    if actual.0 != expected.0 {
+        return _plan_err!(
+            "Expected {what} of type {}, got {}{context}",
+            format_type_and_metadata(expected.0, expected.1),
+            format_type_and_metadata(actual.0, actual.1)
+        );
+    }
+
+    let metadata_equal = match (actual.1, expected.1) {
+        (None, None) => true,
+        (None, Some(expected_metadata)) => expected_metadata.is_empty(),
+        (Some(actual_metadata), None) => actual_metadata.is_empty(),
+        (Some(actual_metadata), Some(expected_metadata)) => {
+            actual_metadata == expected_metadata
+        }
+    };
+
+    if !metadata_equal {
+        return _plan_err!(
+            "Expected {what} of type {}, got {}{context}",
+            format_type_and_metadata(expected.0, expected.1),
+            format_type_and_metadata(actual.0, actual.1)
+        );
+    }
+
+    Ok(())
+}
+
+/// Given a data type represented by storage and optional metadata, generate
+/// a user-facing string
+///
+/// This function exists to reduce the number of Field debug strings that are
+/// used to communicate type information in error messages and plan explain
+/// renderings.
+pub fn format_type_and_metadata(
+    data_type: &DataType,
+    metadata: Option<&std::collections::HashMap<String, String>>,
+) -> String {
+    match metadata {
+        Some(metadata) if !metadata.is_empty() => {
+            format!("{data_type}<{metadata:?}>")
+        }
+        _ => data_type.to_string(),
+    }
+}
+
+/// Literal metadata
+///
+/// Stores metadata associated with a literal expression
+/// and is designed to be fast to `clone`.
+///
+/// This structure is used to store metadata associated with a literal expression, and it
+/// corresponds to the `metadata` field on [`Field`].
+///
+/// # Example: Create [`FieldMetadata`] from a [`Field`]
+/// ```
+/// # use std::collections::HashMap;
+/// # use datafusion_common::metadata::FieldMetadata;
+/// # use arrow::datatypes::{Field, DataType};
+/// # let field = Field::new("c1", DataType::Int32, true)
+/// #     .with_metadata(HashMap::from([("foo".to_string(), "bar".to_string())]));
+/// // Create a new `FieldMetadata` instance from a `Field`
+/// let metadata = FieldMetadata::new_from_field(&field);
+/// // There is also a `From` impl:
+/// let metadata = FieldMetadata::from(&field);
+/// ```
+///
+/// # Example: Update a [`Field`] with [`FieldMetadata`]
+/// ```
+/// # use datafusion_common::metadata::FieldMetadata;
+/// # use arrow::datatypes::{Field, DataType};
+/// # let field = Field::new("c1", DataType::Int32, true);
+/// # let metadata = FieldMetadata::new_from_field(&field);
+/// // Add any metadata from `FieldMetadata` to `Field`
+/// let updated_field = metadata.add_to_field(field);
+/// ```
+///
+/// For more background, please also see the [Implementing User Defined Types and Custom Metadata in DataFusion blog]
+///
+/// [Implementing User Defined Types and Custom Metadata in DataFusion blog]: https://datafusion.apache.org/blog/2025/09/21/custom-types-using-metadata
+#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)]
+pub struct FieldMetadata {
+    /// The inner metadata of a literal expression, which is a map of string
+    /// keys to string values.
+    ///
+    /// Note this is not a `HashMap` because `HashMap` does not provide
+    /// implementations for traits like `Hash` and `PartialOrd`.
+    inner: Arc<BTreeMap<String, String>>,
+}
+
+impl Default for FieldMetadata {
+    fn default() -> Self {
+        Self::new_empty()
+    }
+}
+
+impl FieldMetadata {
+    /// Create a new empty metadata instance.
+    pub fn new_empty() -> Self {
+        Self {
+            inner: Arc::new(BTreeMap::new()),
+        }
+    }
+
+    /// Merges two optional `FieldMetadata` instances, overwriting any existing
+    /// keys in `m` with keys from `n` if present.
+    ///
+    /// This function is commonly used in alias operations, particularly for literals
+    /// with metadata. When creating an alias expression, the metadata from the original
+    /// expression (such as a literal) is combined with any metadata specified on the alias.
+    ///
+    /// # Arguments
+    ///
+    /// * `m` - The first metadata (typically from the original expression like a literal)
+    /// * `n` - The second metadata (typically from the alias definition)
+    ///
+    /// # Merge Strategy
+    ///
+    /// - If both metadata instances exist, they are merged with `n` taking precedence
+    /// - Keys from `n` will overwrite keys from `m` if they have the same name
+    /// - If only one metadata instance exists, it is returned unchanged
+    /// - If neither exists, `None` is returned
+    ///
+    /// # Example usage
+    /// ```rust
+    /// use datafusion_common::metadata::FieldMetadata;
+    /// use std::collections::BTreeMap;
+    ///
+    /// // Create metadata for a literal expression
+    /// let literal_metadata = Some(FieldMetadata::from(BTreeMap::from([
+    ///     ("source".to_string(), "constant".to_string()),
+    ///     ("type".to_string(), "int".to_string()),
+    /// ])));
+    ///
+    /// // Create metadata for an alias
+    /// let alias_metadata = Some(FieldMetadata::from(BTreeMap::from([
+    ///     ("description".to_string(), "answer".to_string()),
+    ///     ("source".to_string(), "user".to_string()), // This will override literal's "source"
+    /// ])));
+    ///
+    /// // Merge the metadata
+    /// let merged = FieldMetadata::merge_options(
+    ///     literal_metadata.as_ref(),
+    ///     alias_metadata.as_ref(),
+    /// );
+    ///
+    /// // Result contains: {"source": "user", "type": "int", "description": "answer"}
+    /// assert!(merged.is_some());
+    /// ```
+    pub fn merge_options(
+        m: Option<&FieldMetadata>,
+        n: Option<&FieldMetadata>,
+    ) -> Option<FieldMetadata> {
+        match (m, n) {
+            (Some(m), Some(n)) => {
+                let mut merged = m.clone();
+                merged.extend(n.clone());
+                Some(merged)
+            }
+            (Some(m), None) => Some(m.clone()),
+            (None, Some(n)) => Some(n.clone()),
+            (None, None) => None,
+        }
+    }
+
+    /// Create a new metadata instance from a `Field`'s metadata.
+    pub fn new_from_field(field: &Field) -> Self {
+        let inner = field
+            .metadata()
+            .iter()
+            .map(|(k, v)| (k.to_string(), v.to_string()))
+            .collect();
+        Self {
+            inner: Arc::new(inner),
+        }
+    }
+
+    /// Create a new metadata instance from a map of string keys to string values.
+    pub fn new(inner: BTreeMap<String, String>) -> Self {
+        Self {
+            inner: Arc::new(inner),
+        }
+    }
+
+    /// Get the inner metadata as a reference to a `BTreeMap`.
+    pub fn inner(&self) -> &BTreeMap<String, String> {
+        &self.inner
+    }
+
+    /// Return the inner metadata
+    pub fn into_inner(self) -> Arc<BTreeMap<String, String>> {
+        self.inner
+    }
+
+    /// Adds metadata from `other` into `self`, overwriting any existing keys.
+    pub fn extend(&mut self, other: Self) {
+        if other.is_empty() {
+            return;
+        }
+        let other = Arc::unwrap_or_clone(other.into_inner());
+        Arc::make_mut(&mut self.inner).extend(other);
+    }
+
+    /// Returns true if the metadata is empty.
+    pub fn is_empty(&self) -> bool {
+        self.inner.is_empty()
+    }
+
+    /// Returns the number of key-value pairs in the metadata.
+    pub fn len(&self) -> usize {
+        self.inner.len()
+    }
+
+    /// Convert this `FieldMetadata` into a `HashMap`
+    pub fn to_hashmap(&self) -> std::collections::HashMap<String, String> {
+        self.inner
+            .iter()
+            .map(|(k, v)| (k.to_string(), v.to_string()))
+            .collect()
+    }
+
+    /// Updates the metadata on the Field with this metadata, if it is not empty.
+    pub fn add_to_field(&self, field: Field) -> Field {
+        if self.inner.is_empty() {
+            return field;
+        }
+
+        field.with_metadata(self.to_hashmap())
+    }
+
+    /// Updates the metadata on the FieldRef with this metadata, if it is not empty.
+    pub fn add_to_field_ref(&self, mut field_ref: FieldRef) -> FieldRef {
+        if self.inner.is_empty() {
+            return field_ref;
+        }
+
+        Arc::make_mut(&mut field_ref).set_metadata(self.to_hashmap());
+        field_ref
+    }
+}
+
+impl From<&Field> for FieldMetadata {
+    fn from(field: &Field) -> Self {
+        Self::new_from_field(field)
+    }
+}
+
+impl From<BTreeMap<String, String>> for FieldMetadata {
+    fn from(inner: BTreeMap<String, String>) -> Self {
+        Self::new(inner)
+    }
+}
+
+impl From<std::collections::HashMap<String, String>> for FieldMetadata {
+    fn from(map: std::collections::HashMap<String, String>) -> Self {
+        Self::new(map.into_iter().collect())
+    }
+}
+
+/// From reference
+impl From<&std::collections::HashMap<String, String>> for FieldMetadata {
+    fn from(map: &std::collections::HashMap<String, String>) -> Self {
+        let inner = map
+            .iter()
+            .map(|(k, v)| (k.to_string(), v.to_string()))
+            .collect();
+        Self::new(inner)
+    }
+}
+
+/// From hashbrown map
+impl From<HashMap<String, String>> for FieldMetadata {
+    fn from(map: HashMap<String, String>) -> Self {
+        let inner = map.into_iter().collect();
+        Self::new(inner)
+    }
+}
+
+impl From<&HashMap<String, String>> for FieldMetadata {
+    fn from(map: &HashMap<String, String>) -> Self {
+        let inner = map
+            .into_iter()
+            .map(|(k, v)| (k.to_string(), v.to_string()))
+            .collect();
+        Self::new(inner)
+    }
+}
diff --git a/datafusion/common/src/nested_struct.rs b/datafusion/common/src/nested_struct.rs
new file mode 100644
index 0000000000000..bf2558f313069
--- /dev/null
+++ b/datafusion/common/src/nested_struct.rs
@@ -0,0 +1,1013 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.
You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::error::{_plan_err, Result};
+use arrow::{
+    array::{Array, ArrayRef, StructArray, new_null_array},
+    compute::{CastOptions, cast_with_options},
+    datatypes::{DataType, DataType::Struct, Field, FieldRef},
+};
+use std::{collections::HashSet, sync::Arc};
+
+/// Cast a struct column to match target struct fields, handling nested structs recursively.
+///
+/// This function implements struct-to-struct casting with the assumption that **structs should
+/// always be allowed to cast to other structs**. However, the source column must already be
+/// a struct type - non-struct sources will result in an error.
+///
+/// ## Field Matching Strategy
+/// - **By Name**: Source struct fields are matched to target fields by name (case-sensitive)
+/// - **No Positional Mapping**: Structs with no overlapping field names are rejected
+/// - **Type Adaptation**: When a matching field is found, it is recursively cast to the target field's type
+/// - **Missing Fields**: Target fields not present in the source are filled with null values
+/// - **Extra Fields**: Source fields not present in the target are ignored
+///
+/// ## Nested Struct Handling
+/// - Nested structs are handled recursively using the same casting rules
+/// - Each level of nesting follows the same field matching and null-filling strategy
+/// - This allows for complex struct transformations while maintaining data integrity
+///
+/// # Arguments
+/// * `source_col` - The source array to cast (a struct array, or a `Null`-typed / entirely-null array)
+/// * `target_fields` - The target struct field definitions to cast to
+///
+/// # Returns
+/// A `Result<ArrayRef>` containing the cast struct array
+///
+/// # Errors
+/// Returns a `DataFusionError::Plan` if the source is not a struct or fails struct validation
+fn cast_struct_column(
+    source_col: &ArrayRef,
+    target_fields: &[Arc<Field>],
+    cast_options: &CastOptions,
+) -> Result<ArrayRef> {
+    if source_col.data_type() == &DataType::Null
+        || (!source_col.is_empty() && source_col.null_count() == source_col.len())
+    {
+        return Ok(new_null_array(
+            &Struct(target_fields.to_vec().into()),
+            source_col.len(),
+        ));
+    }
+
+    if let Some(source_struct) = source_col.as_any().downcast_ref::<StructArray>() {
+        let source_fields = source_struct.fields();
+        validate_struct_compatibility(source_fields, target_fields)?;
+        let mut fields: Vec<Arc<Field>> = Vec::with_capacity(target_fields.len());
+        let mut arrays: Vec<ArrayRef> = Vec::with_capacity(target_fields.len());
+        let num_rows = source_col.len();
+
+        // Iterate target fields and pick source child by name when present.
+        for target_child_field in target_fields.iter() {
+            fields.push(Arc::clone(target_child_field));
+
+            let source_child_opt =
+                source_struct.column_by_name(target_child_field.name());
+
+            match source_child_opt {
+                Some(source_child_col) => {
+                    let adapted_child =
+                        cast_column(source_child_col, target_child_field, cast_options)
+                            .map_err(|e| {
+                                e.context(format!(
+                                    "While casting struct field '{}'",
+                                    target_child_field.name()
+                                ))
+                            })?;
+                    arrays.push(adapted_child);
+                }
+                None => {
+                    arrays.push(new_null_array(target_child_field.data_type(), num_rows));
+                }
+            }
+        }
+
+        let struct_array =
+            StructArray::new(fields.into(), arrays, source_struct.nulls().cloned());
+        Ok(Arc::new(struct_array))
+    } else {
+        // Return error if source is not a struct type
+        _plan_err!(
+            "Cannot cast column of type {} to struct type. Source must be a struct to cast to struct.",
+            source_col.data_type()
+        )
+    }
+}
+
+/// Cast a column to match the target field type, with special handling for nested structs.
+///
+/// This function serves as the main entry point for column casting operations. For struct
+/// types, it enforces that **only struct columns can be cast to struct types**.
+///
+/// ## Casting Behavior
+/// - **Struct Types**: Delegates to `cast_struct_column` for struct-to-struct casting only
+/// - **Non-Struct Types**: Uses Arrow's `cast_with_options` function for all other conversions
+///
+/// ## Cast Options
+/// The `cast_options` argument controls how Arrow handles values that cannot be represented
+/// in the target type. When `safe` is `false` (DataFusion's default) the cast will return an
+/// error if such a value is encountered. Setting `safe` to `true` instead produces `NULL`
+/// for out-of-range or otherwise invalid values. The options also allow customizing how
+/// temporal values are formatted when cast to strings.
+///
+/// ```
+/// use arrow::array::{ArrayRef, Int64Array};
+/// use arrow::compute::CastOptions;
+/// use arrow::datatypes::{DataType, Field};
+/// use datafusion_common::nested_struct::cast_column;
+/// use std::sync::Arc;
+///
+/// let source: ArrayRef = Arc::new(Int64Array::from(vec![1, i64::MAX]));
+/// let target = Field::new("ints", DataType::Int32, true);
+/// // Permit lossy conversions by producing NULL on overflow instead of erroring
+/// let options = CastOptions {
+///     safe: true,
+///     ..Default::default()
+/// };
+/// let result = cast_column(&source, &target, &options).unwrap();
+/// assert!(result.is_null(1));
+/// ```
+///
+/// ## Struct Casting Requirements
+/// The struct casting logic requires that the source column must already be a struct type.
+/// This makes the function useful for:
+/// - Schema evolution scenarios where struct layouts change over time
+/// - Data migration between different struct schemas
+/// - Type-safe data processing pipelines that maintain struct type integrity
+///
+/// # Arguments
+/// * `source_col` - The source array to cast
+/// * `target_field` - The target field definition (including type and metadata)
+/// * `cast_options` - Options that govern strictness and formatting of the cast
+///
+/// # Returns
+/// A `Result<ArrayRef>` containing the cast array
+///
+/// # Errors
+/// Returns an error if:
+/// - Attempting to cast a non-struct column to a struct type
+/// - Arrow's cast function fails for non-struct types
+/// - Source and target struct fields are incompatible (no name overlap, nullability, types)
+/// - Invalid data type combinations are encountered
+pub fn cast_column(
+    source_col: &ArrayRef,
+    target_field: &Field,
+    cast_options: &CastOptions,
+) -> Result<ArrayRef> {
+    match target_field.data_type() {
+        Struct(target_fields) => {
+            cast_struct_column(source_col, target_fields, cast_options)
+        }
+        _ => Ok(cast_with_options(
+            source_col,
+            target_field.data_type(),
+            cast_options,
+        )?),
+    }
+}
+
+/// Validates compatibility between source and target
struct fields for casting operations.
+///
+/// This function implements comprehensive struct compatibility checking by examining:
+/// - Field name matching between source and target structs
+/// - Type castability for each matching field (including recursive struct validation)
+/// - Proper handling of missing fields (nullable target fields not in source are allowed - filled with nulls)
+/// - Proper handling of extra fields (source fields not in target are allowed - ignored)
+///
+/// # Compatibility Rules
+/// - **Field Matching**: Fields are matched by name (case-sensitive); at least one name must overlap
+/// - **Missing Target Fields**: Allowed only if nullable - will be filled with null values during casting
+/// - **Extra Source Fields**: Allowed - will be ignored during casting
+/// - **Type Compatibility**: Each matching field must be castable using Arrow's type system
+/// - **Nested Structs**: Recursively validates nested struct compatibility
+///
+/// # Arguments
+/// * `source_fields` - Fields from the source struct type
+/// * `target_fields` - Fields from the target struct type
+///
+/// # Returns
+/// * `Ok(())` if the structs are compatible for casting
+/// * `Err(DataFusionError)` with detailed error message if incompatible
+///
+/// # Examples
+/// ```text
+/// // Compatible: source has extra field, target has missing (nullable) field
+/// // Source: {a: i32, b: string, c: f64}
+/// // Target: {a: i64, d: bool}
+/// // Result: Ok(()) - 'a' can cast i32->i64, 'b','c' ignored, 'd' filled with nulls
+///
+/// // Incompatible: matching field has incompatible types
+/// // Source: {a: binary}
+/// // Target: {a: i32}
+/// // Result: Err(...) - binary cannot cast to i32
+/// ```
+///
+pub fn validate_struct_compatibility(
+    source_fields: &[FieldRef],
+    target_fields: &[FieldRef],
+) -> Result<()> {
+    let has_overlap = has_one_of_more_common_fields(source_fields, target_fields);
+    if !has_overlap {
+        return _plan_err!(
+            "Cannot cast struct with {} fields to {} fields because there is no field name overlap",
+            source_fields.len(),
+            target_fields.len()
+        );
+    }
+
+    // Check compatibility for each target field
+    for target_field in target_fields {
+        // Look for matching field in source by name
+        if let Some(source_field) = source_fields
+            .iter()
+            .find(|f| f.name() == target_field.name())
+        {
+            validate_field_compatibility(source_field, target_field)?;
+        } else {
+            // Target field is missing from source
+            // If it's non-nullable, we cannot fill it with NULL
+            if !target_field.is_nullable() {
+                return _plan_err!(
+                    "Cannot cast struct: target field '{}' is non-nullable but missing from source. \
+                     Cannot fill with NULL.",
+                    target_field.name()
+                );
+            }
+        }
+    }
+
+    // Extra fields in source are OK - they'll be ignored
+    Ok(())
+}
+
+fn validate_field_compatibility(
+    source_field: &Field,
+    target_field: &Field,
+) -> Result<()> {
+    if source_field.data_type() == &DataType::Null {
+        // Validate that target allows nulls before returning early.
+        // It is invalid to cast a NULL source field to a non-nullable target field.
+        if !target_field.is_nullable() {
+            return _plan_err!(
+                "Cannot cast NULL struct field '{}' to non-nullable field '{}'",
+                source_field.name(),
+                target_field.name()
+            );
+        }
+        return Ok(());
+    }
+
+    // Ensure nullability is compatible. It is invalid to cast a nullable
+    // source field to a non-nullable target field as this may discard
+    // null values.
+    if source_field.is_nullable() && !target_field.is_nullable() {
+        return _plan_err!(
+            "Cannot cast nullable struct field '{}' to non-nullable field",
+            target_field.name()
+        );
+    }
+
+    // Check if the matching field types are compatible
+    match (source_field.data_type(), target_field.data_type()) {
+        // Recursively validate nested structs
+        (Struct(source_nested), Struct(target_nested)) => {
+            validate_struct_compatibility(source_nested, target_nested)?;
+        }
+        // For non-struct types, use the existing castability check
+        _ => {
+            if !arrow::compute::can_cast_types(
+                source_field.data_type(),
+                target_field.data_type(),
+            ) {
+                return _plan_err!(
+                    "Cannot cast struct field '{}' from type {} to type {}",
+                    target_field.name(),
+                    source_field.data_type(),
+                    target_field.data_type()
+                );
+            }
+        }
+    }
+
+    Ok(())
+}
+
+/// Check if two field lists have at least one common field by name.
+///
+/// This is useful for validating struct compatibility when casting between structs,
+/// ensuring that source and target fields have overlapping names.
+pub fn has_one_of_more_common_fields(
+    source_fields: &[FieldRef],
+    target_fields: &[FieldRef],
+) -> bool {
+    let source_names: HashSet<&str> = source_fields
+        .iter()
+        .map(|field| field.name().as_str())
+        .collect();
+    target_fields
+        .iter()
+        .any(|field| source_names.contains(field.name().as_str()))
+}
+
+#[cfg(test)]
+mod tests {
+
+    use super::*;
+    use crate::{assert_contains, format::DEFAULT_CAST_OPTIONS};
+    use arrow::{
+        array::{
+            BinaryArray, Int32Array, Int32Builder, Int64Array, ListArray, MapArray,
+            MapBuilder, NullArray, StringArray, StringBuilder,
+        },
+        buffer::NullBuffer,
+        datatypes::{DataType, Field, FieldRef, Int32Type},
+    };
+    /// Macro to extract and downcast a column from a StructArray
+    macro_rules!
get_column_as { + ($struct_array:expr, $column_name:expr, $array_type:ty) => { + $struct_array + .column_by_name($column_name) + .unwrap() + .as_any() + .downcast_ref::<$array_type>() + .unwrap() + }; + } + + fn field(name: &str, data_type: DataType) -> Field { + Field::new(name, data_type, true) + } + + fn non_null_field(name: &str, data_type: DataType) -> Field { + Field::new(name, data_type, false) + } + + fn arc_field(name: &str, data_type: DataType) -> FieldRef { + Arc::new(field(name, data_type)) + } + + fn struct_type(fields: Vec) -> DataType { + Struct(fields.into()) + } + + fn struct_field(name: &str, fields: Vec) -> Field { + field(name, struct_type(fields)) + } + + fn arc_struct_field(name: &str, fields: Vec) -> FieldRef { + Arc::new(struct_field(name, fields)) + } + + #[test] + fn test_cast_simple_column() { + let source = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef; + let target_field = field("ints", DataType::Int64); + let result = cast_column(&source, &target_field, &DEFAULT_CAST_OPTIONS).unwrap(); + let result = result.as_any().downcast_ref::().unwrap(); + assert_eq!(result.len(), 3); + assert_eq!(result.value(0), 1); + assert_eq!(result.value(1), 2); + assert_eq!(result.value(2), 3); + } + + #[test] + fn test_cast_column_with_options() { + let source = Arc::new(Int64Array::from(vec![1, i64::MAX])) as ArrayRef; + let target_field = field("ints", DataType::Int32); + + let safe_opts = CastOptions { + // safe: false - return Err for failure + safe: false, + ..DEFAULT_CAST_OPTIONS + }; + assert!(cast_column(&source, &target_field, &safe_opts).is_err()); + + let unsafe_opts = CastOptions { + // safe: true - return Null for failure + safe: true, + ..DEFAULT_CAST_OPTIONS + }; + let result = cast_column(&source, &target_field, &unsafe_opts).unwrap(); + let result = result.as_any().downcast_ref::().unwrap(); + assert_eq!(result.value(0), 1); + assert!(result.is_null(1)); + } + + #[test] + fn test_cast_struct_with_missing_field() { + let a_array = 
Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef; + let source_struct = StructArray::from(vec![( + arc_field("a", DataType::Int32), + Arc::clone(&a_array), + )]); + let source_col = Arc::new(source_struct) as ArrayRef; + + let target_field = struct_field( + "s", + vec![field("a", DataType::Int32), field("b", DataType::Utf8)], + ); + + let result = + cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS).unwrap(); + let struct_array = result.as_any().downcast_ref::().unwrap(); + assert_eq!(struct_array.fields().len(), 2); + let a_result = get_column_as!(&struct_array, "a", Int32Array); + assert_eq!(a_result.value(0), 1); + assert_eq!(a_result.value(1), 2); + + let b_result = get_column_as!(&struct_array, "b", StringArray); + assert_eq!(b_result.len(), 2); + assert!(b_result.is_null(0)); + assert!(b_result.is_null(1)); + } + + #[test] + fn test_cast_struct_source_not_struct() { + let source = Arc::new(Int32Array::from(vec![10, 20])) as ArrayRef; + let target_field = struct_field("s", vec![field("a", DataType::Int32)]); + + let result = cast_column(&source, &target_field, &DEFAULT_CAST_OPTIONS); + assert!(result.is_err()); + let error_msg = result.unwrap_err().to_string(); + assert!(error_msg.contains("Cannot cast column of type")); + assert!(error_msg.contains("to struct type")); + assert!(error_msg.contains("Source must be a struct")); + } + + #[test] + fn test_cast_struct_incompatible_child_type() { + let a_array = Arc::new(BinaryArray::from(vec![ + Some(b"a".as_ref()), + Some(b"b".as_ref()), + ])) as ArrayRef; + let source_struct = + StructArray::from(vec![(arc_field("a", DataType::Binary), a_array)]); + let source_col = Arc::new(source_struct) as ArrayRef; + + let target_field = struct_field("s", vec![field("a", DataType::Int32)]); + + let result = cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS); + assert!(result.is_err()); + let error_msg = result.unwrap_err().to_string(); + assert!(error_msg.contains("Cannot cast struct field 'a'")); + 
} + + #[test] + fn test_validate_struct_compatibility_incompatible_types() { + // Source struct: {field1: Binary, field2: String} + let source_fields = vec![ + arc_field("field1", DataType::Binary), + arc_field("field2", DataType::Utf8), + ]; + + // Target struct: {field1: Int32} + let target_fields = vec![arc_field("field1", DataType::Int32)]; + + let result = validate_struct_compatibility(&source_fields, &target_fields); + assert!(result.is_err()); + let error_msg = result.unwrap_err().to_string(); + assert!(error_msg.contains("Cannot cast struct field 'field1'")); + assert!(error_msg.contains("Binary")); + assert!(error_msg.contains("Int32")); + } + + #[test] + fn test_validate_struct_compatibility_compatible_types() { + // Source struct: {field1: Int32, field2: String} + let source_fields = vec![ + arc_field("field1", DataType::Int32), + arc_field("field2", DataType::Utf8), + ]; + + // Target struct: {field1: Int64} (Int32 can cast to Int64) + let target_fields = vec![arc_field("field1", DataType::Int64)]; + + let result = validate_struct_compatibility(&source_fields, &target_fields); + assert!(result.is_ok()); + } + + #[test] + fn test_validate_struct_compatibility_missing_field_in_source() { + // Source struct: {field1: Int32} (missing field2) + let source_fields = vec![arc_field("field1", DataType::Int32)]; + + // Target struct: {field1: Int32, field2: Utf8} + let target_fields = vec![ + arc_field("field1", DataType::Int32), + arc_field("field2", DataType::Utf8), + ]; + + // Should be OK - missing fields will be filled with nulls + let result = validate_struct_compatibility(&source_fields, &target_fields); + assert!(result.is_ok()); + } + + #[test] + fn test_validate_struct_compatibility_additional_field_in_source() { + // Source struct: {field1: Int32, field2: String} (extra field2) + let source_fields = vec![ + arc_field("field1", DataType::Int32), + arc_field("field2", DataType::Utf8), + ]; + + // Target struct: {field1: Int32} + let target_fields = 
vec![arc_field("field1", DataType::Int32)]; + + // Should be OK - extra fields in source are ignored + let result = validate_struct_compatibility(&source_fields, &target_fields); + assert!(result.is_ok()); + } + + #[test] + fn test_validate_struct_compatibility_no_overlap_mismatch_len() { + let source_fields = vec![ + arc_field("left", DataType::Int32), + arc_field("right", DataType::Int32), + ]; + let target_fields = vec![arc_field("alpha", DataType::Int32)]; + + let result = validate_struct_compatibility(&source_fields, &target_fields); + assert!(result.is_err()); + let error_msg = result.unwrap_err().to_string(); + assert_contains!(error_msg, "no field name overlap"); + } + + #[test] + fn test_cast_struct_parent_nulls_retained() { + let a_array = Arc::new(Int32Array::from(vec![Some(1), Some(2)])) as ArrayRef; + let fields = vec![arc_field("a", DataType::Int32)]; + let nulls = Some(NullBuffer::from(vec![true, false])); + let source_struct = StructArray::new(fields.clone().into(), vec![a_array], nulls); + let source_col = Arc::new(source_struct) as ArrayRef; + + let target_field = struct_field("s", vec![field("a", DataType::Int64)]); + + let result = + cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS).unwrap(); + let struct_array = result.as_any().downcast_ref::().unwrap(); + assert_eq!(struct_array.null_count(), 1); + assert!(struct_array.is_valid(0)); + assert!(struct_array.is_null(1)); + + let a_result = get_column_as!(&struct_array, "a", Int64Array); + assert_eq!(a_result.value(0), 1); + assert_eq!(a_result.value(1), 2); + } + + #[test] + fn test_validate_struct_compatibility_nullable_to_non_nullable() { + // Source struct: {field1: Int32 nullable} + let source_fields = vec![arc_field("field1", DataType::Int32)]; + + // Target struct: {field1: Int32 non-nullable} + let target_fields = vec![Arc::new(non_null_field("field1", DataType::Int32))]; + + let result = validate_struct_compatibility(&source_fields, &target_fields); + 
assert!(result.is_err()); + let error_msg = result.unwrap_err().to_string(); + assert!(error_msg.contains("field1")); + assert!(error_msg.contains("non-nullable")); + } + + #[test] + fn test_validate_struct_compatibility_non_nullable_to_nullable() { + // Source struct: {field1: Int32 non-nullable} + let source_fields = vec![Arc::new(non_null_field("field1", DataType::Int32))]; + + // Target struct: {field1: Int32 nullable} + let target_fields = vec![arc_field("field1", DataType::Int32)]; + + let result = validate_struct_compatibility(&source_fields, &target_fields); + assert!(result.is_ok()); + } + + #[test] + fn test_validate_struct_compatibility_nested_nullable_to_non_nullable() { + // Source struct: {field1: {nested: Int32 nullable}} + let source_fields = vec![Arc::new(non_null_field( + "field1", + struct_type(vec![field("nested", DataType::Int32)]), + ))]; + + // Target struct: {field1: {nested: Int32 non-nullable}} + let target_fields = vec![Arc::new(non_null_field( + "field1", + struct_type(vec![non_null_field("nested", DataType::Int32)]), + ))]; + + let result = validate_struct_compatibility(&source_fields, &target_fields); + assert!(result.is_err()); + let error_msg = result.unwrap_err().to_string(); + assert!(error_msg.contains("nested")); + assert!(error_msg.contains("non-nullable")); + } + + #[test] + fn test_validate_struct_compatibility_by_name() { + // Source struct: {field1: Int32, field2: String} + let source_fields = vec![ + arc_field("field1", DataType::Int32), + arc_field("field2", DataType::Utf8), + ]; + + // Target struct: {field2: String, field1: Int64} + let target_fields = vec![ + arc_field("field2", DataType::Utf8), + arc_field("field1", DataType::Int64), + ]; + + let result = validate_struct_compatibility(&source_fields, &target_fields); + assert!(result.is_ok()); + } + + #[test] + fn test_validate_struct_compatibility_by_name_with_type_mismatch() { + // Source struct: {field1: Binary} + let source_fields = vec![arc_field("field1", 
DataType::Binary)]; + + // Target struct: {field1: Int32} (incompatible type) + let target_fields = vec![arc_field("field1", DataType::Int32)]; + + let result = validate_struct_compatibility(&source_fields, &target_fields); + assert!(result.is_err()); + let error_msg = result.unwrap_err().to_string(); + assert_contains!( + error_msg, + "Cannot cast struct field 'field1' from type Binary to type Int32" + ); + } + + #[test] + fn test_validate_struct_compatibility_no_overlap_equal_len() { + let source_fields = vec![ + arc_field("left", DataType::Int32), + arc_field("right", DataType::Utf8), + ]; + + let target_fields = vec![ + arc_field("alpha", DataType::Int32), + arc_field("beta", DataType::Utf8), + ]; + + let result = validate_struct_compatibility(&source_fields, &target_fields); + assert!(result.is_err()); + let error_msg = result.unwrap_err().to_string(); + assert_contains!(error_msg, "no field name overlap"); + } + + #[test] + fn test_validate_struct_compatibility_mixed_name_overlap() { + // Source struct: {a: Int32, b: String, extra: Boolean} + let source_fields = vec![ + arc_field("a", DataType::Int32), + arc_field("b", DataType::Utf8), + arc_field("extra", DataType::Boolean), + ]; + + // Target struct: {b: String, a: Int64, c: Float32} + // Name overlap with a and b, missing c (nullable) + let target_fields = vec![ + arc_field("b", DataType::Utf8), + arc_field("a", DataType::Int64), + arc_field("c", DataType::Float32), + ]; + + let result = validate_struct_compatibility(&source_fields, &target_fields); + assert!(result.is_ok()); + } + + #[test] + fn test_validate_struct_compatibility_by_name_missing_required_field() { + // Source struct: {field1: Int32} (missing field2) + let source_fields = vec![arc_field("field1", DataType::Int32)]; + + // Target struct: {field1: Int32, field2: Int32 non-nullable} + let target_fields = vec![ + arc_field("field1", DataType::Int32), + Arc::new(non_null_field("field2", DataType::Int32)), + ]; + + let result = 
validate_struct_compatibility(&source_fields, &target_fields); + assert!(result.is_err()); + let error_msg = result.unwrap_err().to_string(); + assert_contains!( + error_msg, + "Cannot cast struct: target field 'field2' is non-nullable but missing from source. Cannot fill with NULL." + ); + } + + #[test] + fn test_validate_struct_compatibility_partial_name_overlap_with_count_mismatch() { + // Source struct: {a: Int32} (only one field) + let source_fields = vec![arc_field("a", DataType::Int32)]; + + // Target struct: {a: Int32, b: String} (two fields, but 'a' overlaps) + let target_fields = vec![ + arc_field("a", DataType::Int32), + arc_field("b", DataType::Utf8), + ]; + + // This should succeed - partial overlap means by-name mapping + // and missing field 'b' is nullable + let result = validate_struct_compatibility(&source_fields, &target_fields); + assert!(result.is_ok()); + } + + #[test] + fn test_cast_nested_struct_with_extra_and_missing_fields() { + // Source inner struct has fields a, b, extra + let a = Arc::new(Int32Array::from(vec![Some(1), None])) as ArrayRef; + let b = Arc::new(Int32Array::from(vec![Some(2), Some(3)])) as ArrayRef; + let extra = Arc::new(Int32Array::from(vec![Some(9), Some(10)])) as ArrayRef; + + let inner = StructArray::from(vec![ + (arc_field("a", DataType::Int32), a), + (arc_field("b", DataType::Int32), b), + (arc_field("extra", DataType::Int32), extra), + ]); + + let source_struct = StructArray::from(vec![( + arc_struct_field( + "inner", + vec![ + field("a", DataType::Int32), + field("b", DataType::Int32), + field("extra", DataType::Int32), + ], + ), + Arc::new(inner) as ArrayRef, + )]); + let source_col = Arc::new(source_struct) as ArrayRef; + + // Target inner struct reorders fields, adds "missing", and drops "extra" + let target_field = struct_field( + "outer", + vec![struct_field( + "inner", + vec![ + field("b", DataType::Int64), + field("a", DataType::Int32), + field("missing", DataType::Int32), + ], + )], + ); + + let result = + 
cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS).unwrap(); + let outer = result.as_any().downcast_ref::().unwrap(); + let inner = get_column_as!(&outer, "inner", StructArray); + assert_eq!(inner.fields().len(), 3); + + let b = get_column_as!(inner, "b", Int64Array); + assert_eq!(b.value(0), 2); + assert_eq!(b.value(1), 3); + assert!(!b.is_null(0)); + assert!(!b.is_null(1)); + + let a = get_column_as!(inner, "a", Int32Array); + assert_eq!(a.value(0), 1); + assert!(a.is_null(1)); + + let missing = get_column_as!(inner, "missing", Int32Array); + assert!(missing.is_null(0)); + assert!(missing.is_null(1)); + } + + #[test] + fn test_cast_null_struct_field_to_nested_struct() { + let null_inner = Arc::new(NullArray::new(2)) as ArrayRef; + let source_struct = StructArray::from(vec![( + arc_field("inner", DataType::Null), + Arc::clone(&null_inner), + )]); + let source_col = Arc::new(source_struct) as ArrayRef; + + let target_field = struct_field( + "outer", + vec![struct_field("inner", vec![field("a", DataType::Int32)])], + ); + + let result = + cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS).unwrap(); + let outer = result.as_any().downcast_ref::().unwrap(); + let inner = get_column_as!(&outer, "inner", StructArray); + assert_eq!(inner.len(), 2); + assert!(inner.is_null(0)); + assert!(inner.is_null(1)); + + let inner_a = get_column_as!(inner, "a", Int32Array); + assert!(inner_a.is_null(0)); + assert!(inner_a.is_null(1)); + } + + #[test] + fn test_cast_struct_with_array_and_map_fields() { + // Array field with second row null + let arr_array = Arc::new(ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2)]), + None, + ])) as ArrayRef; + + // Map field with second row null + let string_builder = StringBuilder::new(); + let int_builder = Int32Builder::new(); + let mut map_builder = MapBuilder::new(None, string_builder, int_builder); + map_builder.keys().append_value("a"); + map_builder.values().append_value(1); + 
map_builder.append(true).unwrap(); + map_builder.append(false).unwrap(); + let map_array = Arc::new(map_builder.finish()) as ArrayRef; + + let source_struct = StructArray::from(vec![ + ( + arc_field( + "arr", + DataType::List(Arc::new(field("item", DataType::Int32))), + ), + arr_array, + ), + ( + arc_field( + "map", + DataType::Map( + Arc::new(non_null_field( + "entries", + struct_type(vec![ + non_null_field("keys", DataType::Utf8), + field("values", DataType::Int32), + ]), + )), + false, + ), + ), + map_array, + ), + ]); + let source_col = Arc::new(source_struct) as ArrayRef; + + let target_field = struct_field( + "s", + vec![ + field( + "arr", + DataType::List(Arc::new(field("item", DataType::Int32))), + ), + field( + "map", + DataType::Map( + Arc::new(non_null_field( + "entries", + struct_type(vec![ + non_null_field("keys", DataType::Utf8), + field("values", DataType::Int32), + ]), + )), + false, + ), + ), + ], + ); + + let result = + cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS).unwrap(); + let struct_array = result.as_any().downcast_ref::().unwrap(); + + let arr = get_column_as!(&struct_array, "arr", ListArray); + assert!(!arr.is_null(0)); + assert!(arr.is_null(1)); + let arr0 = arr.value(0); + let values = arr0.as_any().downcast_ref::().unwrap(); + assert_eq!(values.value(0), 1); + assert_eq!(values.value(1), 2); + + let map = get_column_as!(&struct_array, "map", MapArray); + assert!(!map.is_null(0)); + assert!(map.is_null(1)); + let map0 = map.value(0); + let entries = map0.as_any().downcast_ref::().unwrap(); + let keys = get_column_as!(entries, "keys", StringArray); + let vals = get_column_as!(entries, "values", Int32Array); + assert_eq!(keys.value(0), "a"); + assert_eq!(vals.value(0), 1); + } + + #[test] + fn test_cast_struct_field_order_differs() { + let a = Arc::new(Int32Array::from(vec![Some(1), Some(2)])) as ArrayRef; + let b = Arc::new(Int32Array::from(vec![Some(3), None])) as ArrayRef; + + let source_struct = StructArray::from(vec![ 
+ (arc_field("a", DataType::Int32), a), + (arc_field("b", DataType::Int32), b), + ]); + let source_col = Arc::new(source_struct) as ArrayRef; + + let target_field = struct_field( + "s", + vec![field("b", DataType::Int64), field("a", DataType::Int32)], + ); + + let result = + cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS).unwrap(); + let struct_array = result.as_any().downcast_ref::().unwrap(); + + let b_col = get_column_as!(&struct_array, "b", Int64Array); + assert_eq!(b_col.value(0), 3); + assert!(b_col.is_null(1)); + + let a_col = get_column_as!(&struct_array, "a", Int32Array); + assert_eq!(a_col.value(0), 1); + assert_eq!(a_col.value(1), 2); + } + + #[test] + fn test_cast_struct_no_overlap_rejected() { + let first = Arc::new(Int32Array::from(vec![Some(10), Some(20)])) as ArrayRef; + let second = + Arc::new(StringArray::from(vec![Some("alpha"), Some("beta")])) as ArrayRef; + + let source_struct = StructArray::from(vec![ + (arc_field("left", DataType::Int32), first), + (arc_field("right", DataType::Utf8), second), + ]); + let source_col = Arc::new(source_struct) as ArrayRef; + + let target_field = struct_field( + "s", + vec![field("a", DataType::Int64), field("b", DataType::Utf8)], + ); + + let result = cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS); + assert!(result.is_err()); + let error_msg = result.unwrap_err().to_string(); + assert_contains!(error_msg, "no field name overlap"); + } + + #[test] + fn test_cast_struct_missing_non_nullable_field_fails() { + // Source has only field 'a' + let a = Arc::new(Int32Array::from(vec![Some(1), Some(2)])) as ArrayRef; + let source_struct = StructArray::from(vec![(arc_field("a", DataType::Int32), a)]); + let source_col = Arc::new(source_struct) as ArrayRef; + + // Target has fields 'a' (nullable) and 'b' (non-nullable) + let target_field = struct_field( + "s", + vec![ + field("a", DataType::Int32), + non_null_field("b", DataType::Int32), + ], + ); + + // Should fail because 'b' is 
non-nullable but missing from source + let result = cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS); + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!( + err.to_string() + .contains("target field 'b' is non-nullable but missing from source"), + "Unexpected error: {err}" + ); + } + + #[test] + fn test_cast_struct_missing_nullable_field_succeeds() { + // Source has only field 'a' + let a = Arc::new(Int32Array::from(vec![Some(1), Some(2)])) as ArrayRef; + let source_struct = StructArray::from(vec![(arc_field("a", DataType::Int32), a)]); + let source_col = Arc::new(source_struct) as ArrayRef; + + // Target has fields 'a' and 'b' (both nullable) + let target_field = struct_field( + "s", + vec![field("a", DataType::Int32), field("b", DataType::Int32)], + ); + + // Should succeed - 'b' is nullable so can be filled with NULL + let result = + cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS).unwrap(); + let struct_array = result.as_any().downcast_ref::().unwrap(); + + let a_col = get_column_as!(&struct_array, "a", Int32Array); + assert_eq!(a_col.value(0), 1); + assert_eq!(a_col.value(1), 2); + + let b_col = get_column_as!(&struct_array, "b", Int32Array); + assert!(b_col.is_null(0)); + assert!(b_col.is_null(1)); + } +} diff --git a/datafusion/common/src/null_equality.rs b/datafusion/common/src/null_equality.rs new file mode 100644 index 0000000000000..847fb0975703e --- /dev/null +++ b/datafusion/common/src/null_equality.rs @@ -0,0 +1,46 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// Represents the behavior for null values when evaluating equality. Currently, its primary use +/// case is to define the behavior of joins for null values. +/// +/// # Examples +/// +/// The following table shows the expected equality behavior for `NullEquality`. +/// +/// | A | B | NullEqualsNothing | NullEqualsNull | +/// |------|------|-------------------|----------------| +/// | NULL | NULL | false | true | +/// | NULL | 'b' | false | false | +/// | 'a' | NULL | false | false | +/// | 'a' | 'b' | false | false | +/// +/// # Order +/// +/// The order on this type represents the "restrictiveness" of the behavior. The more restrictive +/// a behavior is, the fewer elements are considered to be equal to null. +/// [NullEquality::NullEqualsNothing] represents the most restrictive behavior. +/// +/// This mirrors the old order with `null_equals_null` booleans, as `false` indicated that +/// `null != null`. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Hash)] +pub enum NullEquality { + /// Null is *not* equal to anything (`null != null`) + NullEqualsNothing, + /// Null is equal to null (`null == null`) + NullEqualsNull, +} diff --git a/datafusion/common/src/param_value.rs b/datafusion/common/src/param_value.rs index d2802c096da1b..0fac6b529eb0f 100644 --- a/datafusion/common/src/param_value.rs +++ b/datafusion/common/src/param_value.rs @@ -16,22 +16,37 @@ // under the License. 
use crate::error::{_plan_datafusion_err, _plan_err}; +use crate::metadata::{ScalarAndMetadata, check_metadata_with_storage_equal}; use crate::{Result, ScalarValue}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field, FieldRef}; use std::collections::HashMap; /// The parameter value corresponding to the placeholder #[derive(Debug, Clone)] pub enum ParamValues { /// For positional query parameters, like `SELECT * FROM test WHERE a > $1 AND b = $2` - List(Vec), + List(Vec), /// For named query parameters, like `SELECT * FROM test WHERE a > $foo AND b = $goo` - Map(HashMap), + Map(HashMap), } impl ParamValues { - /// Verify parameter list length and type + /// Verify parameter list length and DataType + /// + /// Use [`ParamValues::verify_fields`] to ensure field metadata is considered when + /// computing type equality. + #[deprecated(since = "51.0.0", note = "Use verify_fields instead")] pub fn verify(&self, expect: &[DataType]) -> Result<()> { + // make dummy Fields + let expect = expect + .iter() + .map(|dt| Field::new("", dt.clone(), true).into()) + .collect::>(); + self.verify_fields(&expect) + } + + /// Verify parameter list length and type + pub fn verify_fields(&self, expect: &[FieldRef]) -> Result<()> { match self { ParamValues::List(list) => { // Verify if the number of params matches the number of values @@ -45,15 +60,16 @@ impl ParamValues { // Verify if the types of the params matches the types of the values let iter = expect.iter().zip(list.iter()); - for (i, (param_type, value)) in iter.enumerate() { - if *param_type != value.data_type() { - return _plan_err!( - "Expected parameter of type {:?}, got {:?} at index {}", - param_type, - value.data_type(), - i - ); - } + for (i, (param_type, lit)) in iter.enumerate() { + check_metadata_with_storage_equal( + ( + &lit.value.data_type(), + lit.metadata.as_ref().map(|m| m.to_hashmap()).as_ref(), + ), + (param_type.data_type(), Some(param_type.metadata())), + "parameter", + &format!(" at 
index {i}"), + )?; } Ok(()) } @@ -65,7 +81,7 @@ impl ParamValues { } } - pub fn get_placeholders_with_values(&self, id: &str) -> Result { + pub fn get_placeholders_with_values(&self, id: &str) -> Result { match self { ParamValues::List(list) => { if id.is_empty() { @@ -99,7 +115,7 @@ impl ParamValues { impl From> for ParamValues { fn from(value: Vec) -> Self { - Self::List(value) + Self::List(value.into_iter().map(ScalarAndMetadata::from).collect()) } } @@ -108,8 +124,10 @@ where K: Into, { fn from(value: Vec<(K, ScalarValue)>) -> Self { - let value: HashMap = - value.into_iter().map(|(k, v)| (k.into(), v)).collect(); + let value: HashMap = value + .into_iter() + .map(|(k, v)| (k.into(), ScalarAndMetadata::from(v))) + .collect(); Self::Map(value) } } @@ -119,8 +137,10 @@ where K: Into, { fn from(value: HashMap) -> Self { - let value: HashMap = - value.into_iter().map(|(k, v)| (k.into(), v)).collect(); + let value: HashMap = value + .into_iter() + .map(|(k, v)| (k.into(), ScalarAndMetadata::from(v))) + .collect(); Self::Map(value) } } diff --git a/datafusion/common/src/parquet_config.rs b/datafusion/common/src/parquet_config.rs new file mode 100644 index 0000000000000..9d6d7a88566a7 --- /dev/null +++ b/datafusion/common/src/parquet_config.rs @@ -0,0 +1,108 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::fmt::{self, Display}; +use std::str::FromStr; + +use crate::config::{ConfigField, Visit}; +use crate::error::{DataFusionError, Result}; + +/// Parquet writer version options for controlling the Parquet file format version +/// +/// This enum validates parquet writer version values at configuration time, +/// ensuring only valid versions ("1.0" or "2.0") can be set via `SET` commands +/// or proto deserialization. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum DFParquetWriterVersion { + /// Parquet format version 1.0 + #[default] + V1_0, + /// Parquet format version 2.0 + V2_0, +} + +/// Implement parsing strings to `DFParquetWriterVersion` +impl FromStr for DFParquetWriterVersion { + type Err = DataFusionError; + + fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "1.0" => Ok(DFParquetWriterVersion::V1_0), + "2.0" => Ok(DFParquetWriterVersion::V2_0), + other => Err(DataFusionError::Configuration(format!( + "Invalid parquet writer version: {other}. Expected one of: 1.0, 2.0" + ))), + } + } +} + +impl Display for DFParquetWriterVersion { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let s = match self { + DFParquetWriterVersion::V1_0 => "1.0", + DFParquetWriterVersion::V2_0 => "2.0", + }; + write!(f, "{s}") + } +} + +impl ConfigField for DFParquetWriterVersion { + fn visit(&self, v: &mut V, key: &str, description: &'static str) { + v.some(key, self, description) + } + + fn set(&mut self, _: &str, value: &str) -> Result<()> { + *self = DFParquetWriterVersion::from_str(value)?; + Ok(()) + } +} + +/// Convert `DFParquetWriterVersion` to parquet crate's `WriterVersion` +/// +/// This conversion is infallible since `DFParquetWriterVersion` only contains +/// valid values that have been validated at configuration time. 
+#[cfg(feature = "parquet")] +impl From for parquet::file::properties::WriterVersion { + fn from(value: DFParquetWriterVersion) -> Self { + match value { + DFParquetWriterVersion::V1_0 => { + parquet::file::properties::WriterVersion::PARQUET_1_0 + } + DFParquetWriterVersion::V2_0 => { + parquet::file::properties::WriterVersion::PARQUET_2_0 + } + } + } +} + +/// Convert parquet crate's `WriterVersion` to `DFParquetWriterVersion` +/// +/// This is used when converting from existing parquet writer properties, +/// such as when reading from proto or test code. +#[cfg(feature = "parquet")] +impl From for DFParquetWriterVersion { + fn from(version: parquet::file::properties::WriterVersion) -> Self { + match version { + parquet::file::properties::WriterVersion::PARQUET_1_0 => { + DFParquetWriterVersion::V1_0 + } + parquet::file::properties::WriterVersion::PARQUET_2_0 => { + DFParquetWriterVersion::V2_0 + } + } + } +} diff --git a/datafusion/common/src/parsers.rs b/datafusion/common/src/parsers.rs index 41571ebb8576c..cd3d607dacd88 100644 --- a/datafusion/common/src/parsers.rs +++ b/datafusion/common/src/parsers.rs @@ -20,7 +20,7 @@ use std::fmt::Display; use std::str::FromStr; -use sqlparser::parser::ParserError; +use crate::DataFusionError; /// Readable file compression type #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -38,9 +38,9 @@ pub enum CompressionTypeVariant { } impl FromStr for CompressionTypeVariant { - type Err = ParserError; + type Err = DataFusionError; - fn from_str(s: &str) -> Result { + fn from_str(s: &str) -> Result { let s = s.to_uppercase(); match s.as_str() { "GZIP" | "GZ" => Ok(Self::GZIP), @@ -48,7 +48,7 @@ impl FromStr for CompressionTypeVariant { "XZ" => Ok(Self::XZ), "ZST" | "ZSTD" => Ok(Self::ZSTD), "" | "UNCOMPRESSED" => Ok(Self::UNCOMPRESSED), - _ => Err(ParserError::ParserError(format!( + _ => Err(DataFusionError::NotImplemented(format!( "Unsupported file compression type {s}" ))), } diff --git a/datafusion/common/src/pruning.rs 
b/datafusion/common/src/pruning.rs index 48750e3c995c4..5a7598ea1f299 100644 --- a/datafusion/common/src/pruning.rs +++ b/datafusion/common/src/pruning.rs @@ -135,6 +135,10 @@ pub trait PruningStatistics { /// This feeds into [`CompositePruningStatistics`] to allow pruning /// with filters that depend both on partition columns and data columns /// (e.g. `WHERE partition_col = data_col`). +#[deprecated( + since = "52.0.0", + note = "This struct is no longer used internally. Use `replace_columns_with_literals` from `datafusion-physical-expr-adapter` to substitute partition column values before pruning. It will be removed in 58.0.0 or 6 months after 52.0.0 is released, whichever comes first." +)] #[derive(Clone)] pub struct PartitionPruningStatistics { /// Values for each column for each container. @@ -156,6 +160,7 @@ pub struct PartitionPruningStatistics { partition_schema: SchemaRef, } +#[expect(deprecated)] impl PartitionPruningStatistics { /// Create a new instance of [`PartitionPruningStatistics`]. /// @@ -169,6 +174,36 @@ impl PartitionPruningStatistics { /// This must **not** be the schema of the entire file or table: /// instead it must only be the schema of the partition columns, /// in the same order as the values in `partition_values`. 
+ /// + /// # Example + /// + /// To create [`PartitionPruningStatistics`] for two partition columns `a` and `b`, + /// for three containers like this: + /// + /// | a | b | + /// | - | - | + /// | 1 | 2 | + /// | 3 | 4 | + /// | 5 | 6 | + /// + /// ``` + /// # use std::sync::Arc; + /// # use datafusion_common::ScalarValue; + /// # use arrow::datatypes::{DataType, Field}; + /// # use datafusion_common::pruning::PartitionPruningStatistics; + /// + /// let partition_values = vec![ + /// vec![ScalarValue::from(1i32), ScalarValue::from(2i32)], + /// vec![ScalarValue::from(3i32), ScalarValue::from(4i32)], + /// vec![ScalarValue::from(5i32), ScalarValue::from(6i32)], + /// ]; + /// let partition_fields = vec![ + /// Arc::new(Field::new("a", DataType::Int32, false)), + /// Arc::new(Field::new("b", DataType::Int32, false)), + /// ]; + /// let partition_stats = + /// PartitionPruningStatistics::try_new(partition_values, partition_fields).unwrap(); + /// ``` pub fn try_new( partition_values: Vec>, partition_fields: Vec, @@ -202,6 +237,7 @@ impl PartitionPruningStatistics { } } +#[expect(deprecated)] impl PruningStatistics for PartitionPruningStatistics { fn min_values(&self, column: &Column) -> Option { let index = self.partition_schema.index_of(column.name()).ok()?; @@ -245,7 +281,7 @@ impl PruningStatistics for PartitionPruningStatistics { match acc { None => Some(Some(eq_result)), Some(acc_array) => { - arrow::compute::kernels::boolean::and(&acc_array, &eq_result) + arrow::compute::kernels::boolean::or_kleene(&acc_array, &eq_result) .map(Some) .ok() } @@ -409,10 +445,15 @@ impl PruningStatistics for PrunableStatistics { /// the first one is returned without any regard for completeness or accuracy. /// That is: if the first statistics has information for a column, even if it is incomplete, /// that is returned even if a later statistics has more complete information. +#[deprecated( + since = "52.0.0", + note = "This struct is no longer used internally. 
It may be removed in 58.0.0 or 6 months after 52.0.0 is released, whichever comes first. Please open an issue if you have a use case for it." +)] pub struct CompositePruningStatistics { pub statistics: Vec>, } +#[expect(deprecated)] impl CompositePruningStatistics { /// Create a new instance of [`CompositePruningStatistics`] from /// a vector of [`PruningStatistics`]. @@ -427,6 +468,7 @@ impl CompositePruningStatistics { } } +#[expect(deprecated)] impl PruningStatistics for CompositePruningStatistics { fn min_values(&self, column: &Column) -> Option { for stats in &self.statistics { @@ -483,18 +525,25 @@ impl PruningStatistics for CompositePruningStatistics { } #[cfg(test)] +#[expect(deprecated)] mod tests { use crate::{ - cast::{as_int32_array, as_uint64_array}, ColumnStatistics, + cast::{as_int32_array, as_uint64_array}, }; use super::*; use arrow::datatypes::{DataType, Field}; use std::sync::Arc; - #[test] - fn test_partition_pruning_statistics() { + /// return a PartitionPruningStatistics for two columns 'a' and 'b' + /// and the following stats + /// + /// | a | b | + /// | - | - | + /// | 1 | 2 | + /// | 3 | 4 | + fn partition_pruning_statistics_setup() -> PartitionPruningStatistics { let partition_values = vec![ vec![ScalarValue::from(1i32), ScalarValue::from(2i32)], vec![ScalarValue::from(3i32), ScalarValue::from(4i32)], @@ -503,9 +552,12 @@ mod tests { Arc::new(Field::new("a", DataType::Int32, false)), Arc::new(Field::new("b", DataType::Int32, false)), ]; - let partition_stats = - PartitionPruningStatistics::try_new(partition_values, partition_fields) - .unwrap(); + PartitionPruningStatistics::try_new(partition_values, partition_fields).unwrap() + } + + #[test] + fn test_partition_pruning_statistics() { + let partition_stats = partition_pruning_statistics_setup(); let column_a = Column::new_unqualified("a"); let column_b = Column::new_unqualified("b"); @@ -560,6 +612,85 @@ mod tests { assert_eq!(partition_stats.num_containers(), 2); } + #[test] + fn 
test_partition_pruning_statistics_multiple_positive_values() { + let partition_stats = partition_pruning_statistics_setup(); + + let column_a = Column::new_unqualified("a"); + + // The two containers have `a` values 1 and 3, so they both only contain values from 1 and 3 + let values = HashSet::from([ScalarValue::from(1i32), ScalarValue::from(3i32)]); + let contained_a = partition_stats.contained(&column_a, &values).unwrap(); + let expected_contained_a = BooleanArray::from(vec![true, true]); + assert_eq!(contained_a, expected_contained_a); + } + + #[test] + fn test_partition_pruning_statistics_multiple_negative_values() { + let partition_stats = partition_pruning_statistics_setup(); + + let column_a = Column::new_unqualified("a"); + + // The two containers have `a` values 1 and 3, + // so the first contains ONLY values from 1,2 + // but the second does not + let values = HashSet::from([ScalarValue::from(1i32), ScalarValue::from(2i32)]); + let contained_a = partition_stats.contained(&column_a, &values).unwrap(); + let expected_contained_a = BooleanArray::from(vec![true, false]); + assert_eq!(contained_a, expected_contained_a); + } + + #[test] + fn test_partition_pruning_statistics_null_in_values() { + let partition_values = vec![ + vec![ + ScalarValue::from(1i32), + ScalarValue::from(2i32), + ScalarValue::from(3i32), + ], + vec![ + ScalarValue::from(4i32), + ScalarValue::from(5i32), + ScalarValue::from(6i32), + ], + ]; + let partition_fields = vec![ + Arc::new(Field::new("a", DataType::Int32, false)), + Arc::new(Field::new("b", DataType::Int32, false)), + Arc::new(Field::new("c", DataType::Int32, false)), + ]; + let partition_stats = + PartitionPruningStatistics::try_new(partition_values, partition_fields) + .unwrap(); + + let column_a = Column::new_unqualified("a"); + let column_b = Column::new_unqualified("b"); + let column_c = Column::new_unqualified("c"); + + let values_a = HashSet::from([ScalarValue::from(1i32), ScalarValue::Int32(None)]); + let contained_a = 
partition_stats.contained(&column_a, &values_a).unwrap(); + let mut builder = BooleanArray::builder(2); + builder.append_value(true); + builder.append_null(); + let expected_contained_a = builder.finish(); + assert_eq!(contained_a, expected_contained_a); + + // First match creates a NULL boolean array + // The accumulator should update the value to true for the second value + let values_b = HashSet::from([ScalarValue::Int32(None), ScalarValue::from(5i32)]); + let contained_b = partition_stats.contained(&column_b, &values_b).unwrap(); + let mut builder = BooleanArray::builder(2); + builder.append_null(); + builder.append_value(true); + let expected_contained_b = builder.finish(); + assert_eq!(contained_b, expected_contained_b); + + // All matches are null, contained should return None + let values_c = HashSet::from([ScalarValue::Int32(None)]); + let contained_c = partition_stats.contained(&column_c, &values_c); + assert!(contained_c.is_none()); + } + #[test] fn test_partition_pruning_statistics_empty() { let partition_values = vec![]; diff --git a/datafusion/common/src/pyarrow.rs b/datafusion/common/src/pyarrow.rs deleted file mode 100644 index ff413e08ab076..0000000000000 --- a/datafusion/common/src/pyarrow.rs +++ /dev/null @@ -1,171 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. 
See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Conversions between PyArrow and DataFusion types - -use arrow::array::{Array, ArrayData}; -use arrow::pyarrow::{FromPyArrow, ToPyArrow}; -use pyo3::exceptions::PyException; -use pyo3::prelude::PyErr; -use pyo3::types::{PyAnyMethods, PyList}; -use pyo3::{Bound, FromPyObject, IntoPyObject, PyAny, PyObject, PyResult, Python}; - -use crate::{DataFusionError, ScalarValue}; - -impl From for PyErr { - fn from(err: DataFusionError) -> PyErr { - PyException::new_err(err.to_string()) - } -} - -impl FromPyArrow for ScalarValue { - fn from_pyarrow_bound(value: &Bound<'_, PyAny>) -> PyResult { - let py = value.py(); - let typ = value.getattr("type")?; - let val = value.call_method0("as_py")?; - - // construct pyarrow array from the python value and pyarrow type - let factory = py.import("pyarrow")?.getattr("array")?; - let args = PyList::new(py, [val])?; - let array = factory.call1((args, typ))?; - - // convert the pyarrow array to rust array using C data interface - let array = arrow::array::make_array(ArrayData::from_pyarrow_bound(&array)?); - let scalar = ScalarValue::try_from_array(&array, 0)?; - - Ok(scalar) - } -} - -impl ToPyArrow for ScalarValue { - fn to_pyarrow(&self, py: Python) -> PyResult { - let array = self.to_array()?; - // convert to pyarrow array using C data interface - let pyarray = array.to_data().to_pyarrow(py)?; - let pyscalar = pyarray.call_method1(py, "__getitem__", (0,))?; - - Ok(pyscalar) - } -} - -impl<'source> FromPyObject<'source> for ScalarValue { - fn extract_bound(value: &Bound<'source, PyAny>) -> PyResult { - Self::from_pyarrow_bound(value) - } -} - -impl<'source> IntoPyObject<'source> for ScalarValue { - type Target = PyAny; - - type Output = Bound<'source, Self::Target>; - - type Error = PyErr; - - fn into_pyobject(self, py: Python<'source>) -> Result { - let array = self.to_array()?; - // convert to pyarrow array using C data 
interface - let pyarray = array.to_data().to_pyarrow(py)?; - let pyarray_bound = pyarray.bind(py); - pyarray_bound.call_method1("__getitem__", (0,)) - } -} - -#[cfg(test)] -mod tests { - use pyo3::ffi::c_str; - use pyo3::prepare_freethreaded_python; - use pyo3::py_run; - use pyo3::types::PyDict; - - use super::*; - - fn init_python() { - prepare_freethreaded_python(); - Python::with_gil(|py| { - if py.run(c_str!("import pyarrow"), None, None).is_err() { - let locals = PyDict::new(py); - py.run( - c_str!( - "import sys; executable = sys.executable; python_path = sys.path" - ), - None, - Some(&locals), - ) - .expect("Couldn't get python info"); - let executable = locals.get_item("executable").unwrap(); - let executable: String = executable.extract().unwrap(); - - let python_path = locals.get_item("python_path").unwrap(); - let python_path: Vec = python_path.extract().unwrap(); - - panic!("pyarrow not found\nExecutable: {executable}\nPython path: {python_path:?}\n\ - HINT: try `pip install pyarrow`\n\ - NOTE: On Mac OS, you must compile against a Framework Python \ - (default in python.org installers and brew, but not pyenv)\n\ - NOTE: On Mac OS, PYO3 might point to incorrect Python library \ - path when using virtual environments. 
Try \ - `export PYTHONPATH=$(python -c \"import sys; print(sys.path[-1])\")`\n") - } - }) - } - - #[test] - fn test_roundtrip() { - init_python(); - - let example_scalars = vec![ - ScalarValue::Boolean(Some(true)), - ScalarValue::Int32(Some(23)), - ScalarValue::Float64(Some(12.34)), - ScalarValue::from("Hello!"), - ScalarValue::Date32(Some(1234)), - ]; - - Python::with_gil(|py| { - for scalar in example_scalars.iter() { - let result = ScalarValue::from_pyarrow_bound( - scalar.to_pyarrow(py).unwrap().bind(py), - ) - .unwrap(); - assert_eq!(scalar, &result); - } - }); - } - - #[test] - fn test_py_scalar() -> PyResult<()> { - init_python(); - - Python::with_gil(|py| -> PyResult<()> { - let scalar_float = ScalarValue::Float64(Some(12.34)); - let py_float = scalar_float - .into_pyobject(py)? - .call_method0("as_py") - .unwrap(); - py_run!(py, py_float, "assert py_float == 12.34"); - - let scalar_string = ScalarValue::Utf8(Some("Hello!".to_string())); - let py_string = scalar_string - .into_pyobject(py)? - .call_method0("as_py") - .unwrap(); - py_run!(py, py_string, "assert py_string == 'Hello!'"); - - Ok(()) - }) - } -} diff --git a/datafusion/common/src/rounding.rs b/datafusion/common/src/rounding.rs index 413067ecd61ed..1796143d7cf1a 100644 --- a/datafusion/common/src/rounding.rs +++ b/datafusion/common/src/rounding.rs @@ -47,7 +47,7 @@ extern crate libc; any(target_arch = "x86_64", target_arch = "aarch64"), not(target_os = "windows") ))] -extern "C" { +unsafe extern "C" { fn fesetround(round: i32); fn fegetround() -> i32; } @@ -77,6 +77,7 @@ pub trait FloatBits { /// The integer value 0, used in bitwise operations. const ZERO: Self::Item; + const NEG_ZERO: Self::Item; /// Converts the floating-point value to its bitwise representation. 
fn to_bits(self) -> Self::Item; @@ -101,6 +102,7 @@ impl FloatBits for f32 { const CLEAR_SIGN_MASK: u32 = 0x7fff_ffff; const ONE: Self::Item = 1; const ZERO: Self::Item = 0; + const NEG_ZERO: Self::Item = 0x8000_0000; fn to_bits(self) -> Self::Item { self.to_bits() @@ -130,6 +132,7 @@ impl FloatBits for f64 { const CLEAR_SIGN_MASK: u64 = 0x7fff_ffff_ffff_ffff; const ONE: Self::Item = 1; const ZERO: Self::Item = 0; + const NEG_ZERO: Self::Item = 0x8000_0000_0000_0000; fn to_bits(self) -> Self::Item { self.to_bits() @@ -175,8 +178,10 @@ pub fn next_up(float: F) -> F { } let abs = bits & F::CLEAR_SIGN_MASK; - let next_bits = if abs == F::ZERO { + let next_bits = if bits == F::ZERO { F::TINY_BITS + } else if abs == F::ZERO { + F::ZERO } else if bits == abs { bits + F::ONE } else { @@ -206,8 +211,11 @@ pub fn next_down(float: F) -> F { if float.float_is_nan() || bits == F::neg_infinity().to_bits() { return float; } + let abs = bits & F::CLEAR_SIGN_MASK; - let next_bits = if abs == F::ZERO { + let next_bits = if bits == F::ZERO { + F::NEG_ZERO + } else if abs == F::ZERO { F::NEG_TINY_BITS } else if bits == abs { bits - F::ONE @@ -396,4 +404,32 @@ mod tests { let result = next_down(value); assert!(result.is_nan()); } + + #[test] + fn test_next_up_neg_zero_f32() { + let value: f32 = -0.0; + let result = next_up(value); + assert_eq!(result, 0.0); + } + + #[test] + fn test_next_down_zero_f32() { + let value: f32 = 0.0; + let result = next_down(value); + assert_eq!(result, -0.0); + } + + #[test] + fn test_next_up_neg_zero_f64() { + let value: f64 = -0.0; + let result = next_up(value); + assert_eq!(result, 0.0); + } + + #[test] + fn test_next_down_zero_f64() { + let value: f64 = 0.0; + let result = next_down(value); + assert_eq!(result, -0.0); + } } diff --git a/datafusion/common/src/scalar/cache.rs b/datafusion/common/src/scalar/cache.rs new file mode 100644 index 0000000000000..5b1ad4e4ede01 --- /dev/null +++ b/datafusion/common/src/scalar/cache.rs @@ -0,0 +1,215 @@ +// 
Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Array caching utilities for scalar values + +use std::iter::repeat_n; +use std::sync::{Arc, LazyLock, Mutex}; + +use arrow::array::{Array, ArrayRef, PrimitiveArray, new_null_array}; +use arrow::datatypes::{ + ArrowDictionaryKeyType, DataType, Int8Type, Int16Type, Int32Type, Int64Type, + UInt8Type, UInt16Type, UInt32Type, UInt64Type, +}; + +/// Maximum number of rows to cache to be conservative on memory usage +const MAX_CACHE_SIZE: usize = 1024 * 1024; + +/// Cache for dictionary key arrays to avoid repeated allocations +/// when the same size is used frequently. +/// +/// Similar to PartitionColumnProjector's ZeroBufferGenerators, this cache +/// stores key arrays for different dictionary key types. The cache is +/// limited to 1 entry per type (the last size used) to prevent memory leaks +/// for extremely large array requests. 
+#[derive(Debug)] +struct KeyArrayCache { + cache: Option<(usize, bool, PrimitiveArray)>, // (num_rows, is_null, key_array) +} + +impl Default for KeyArrayCache { + fn default() -> Self { + Self { cache: None } + } +} + +impl KeyArrayCache { + /// Get or create a cached key array for the given number of rows and null status + fn get_or_create(&mut self, num_rows: usize, is_null: bool) -> PrimitiveArray { + // Check cache size limit to prevent memory leaks + if num_rows > MAX_CACHE_SIZE { + // For very large arrays, don't cache them - just create and return + return self.create_key_array(num_rows, is_null); + } + + match &self.cache { + Some((cached_num_rows, cached_is_null, cached_array)) + if *cached_num_rows == num_rows && *cached_is_null == is_null => + { + // Cache hit: reuse existing array if same size and null status + cached_array.clone() + } + _ => { + // Cache miss: create new array and cache it + let key_array = self.create_key_array(num_rows, is_null); + self.cache = Some((num_rows, is_null, key_array.clone())); + key_array + } + } + } + + /// Create a new key array with the specified number of rows and null status + fn create_key_array(&self, num_rows: usize, is_null: bool) -> PrimitiveArray { + let key_array: PrimitiveArray = repeat_n( + if is_null { + None + } else { + Some(K::default_value()) + }, + num_rows, + ) + .collect(); + key_array + } +} + +/// Cache for null arrays to avoid repeated allocations +/// when the same size is used frequently. 
+#[derive(Debug, Default)] +struct NullArrayCache { + cache: Option<(usize, ArrayRef)>, // (num_rows, null_array) +} + +impl NullArrayCache { + /// Get or create a cached null array for the given number of rows + fn get_or_create(&mut self, num_rows: usize) -> ArrayRef { + // Check cache size limit to prevent memory leaks + if num_rows > MAX_CACHE_SIZE { + // For very large arrays, don't cache them - just create and return + return new_null_array(&DataType::Null, num_rows); + } + + match &self.cache { + Some((cached_num_rows, cached_array)) if *cached_num_rows == num_rows => { + // Cache hit: reuse existing array if same size + Arc::clone(cached_array) + } + _ => { + // Cache miss: create new array and cache it + let null_array = new_null_array(&DataType::Null, num_rows); + self.cache = Some((num_rows, Arc::clone(&null_array))); + null_array + } + } + } +} + +/// Global cache for dictionary key arrays and null arrays +#[derive(Debug, Default)] +struct ArrayCaches { + cache_i8: KeyArrayCache, + cache_i16: KeyArrayCache, + cache_i32: KeyArrayCache, + cache_i64: KeyArrayCache, + cache_u8: KeyArrayCache, + cache_u16: KeyArrayCache, + cache_u32: KeyArrayCache, + cache_u64: KeyArrayCache, + null_cache: NullArrayCache, +} + +static ARRAY_CACHES: LazyLock> = + LazyLock::new(|| Mutex::new(ArrayCaches::default())); + +/// Get the global cache for arrays +fn get_array_caches() -> &'static Mutex { + &ARRAY_CACHES +} + +/// Get or create a cached null array for the given number of rows +pub(crate) fn get_or_create_cached_null_array(num_rows: usize) -> ArrayRef { + let cache = get_array_caches(); + let mut caches = cache.lock().unwrap(); + caches.null_cache.get_or_create(num_rows) +} + +/// Get or create a cached key array for a specific key type +pub(crate) fn get_or_create_cached_key_array( + num_rows: usize, + is_null: bool, +) -> PrimitiveArray { + let cache = get_array_caches(); + let mut caches = cache.lock().unwrap(); + + // Use the DATA_TYPE to dispatch to the correct 
cache, similar to original implementation + match K::DATA_TYPE { + DataType::Int8 => { + let array = caches.cache_i8.get_or_create(num_rows, is_null); + // Convert using ArrayData to avoid unsafe transmute + let array_data = array.to_data(); + PrimitiveArray::::from(array_data) + } + DataType::Int16 => { + let array = caches.cache_i16.get_or_create(num_rows, is_null); + let array_data = array.to_data(); + PrimitiveArray::::from(array_data) + } + DataType::Int32 => { + let array = caches.cache_i32.get_or_create(num_rows, is_null); + let array_data = array.to_data(); + PrimitiveArray::::from(array_data) + } + DataType::Int64 => { + let array = caches.cache_i64.get_or_create(num_rows, is_null); + let array_data = array.to_data(); + PrimitiveArray::::from(array_data) + } + DataType::UInt8 => { + let array = caches.cache_u8.get_or_create(num_rows, is_null); + let array_data = array.to_data(); + PrimitiveArray::::from(array_data) + } + DataType::UInt16 => { + let array = caches.cache_u16.get_or_create(num_rows, is_null); + let array_data = array.to_data(); + PrimitiveArray::::from(array_data) + } + DataType::UInt32 => { + let array = caches.cache_u32.get_or_create(num_rows, is_null); + let array_data = array.to_data(); + PrimitiveArray::::from(array_data) + } + DataType::UInt64 => { + let array = caches.cache_u64.get_or_create(num_rows, is_null); + let array_data = array.to_data(); + PrimitiveArray::::from(array_data) + } + _ => { + // Fallback for unsupported types - create array directly without caching + let key_array: PrimitiveArray = repeat_n( + if is_null { + None + } else { + Some(K::default_value()) + }, + num_rows, + ) + .collect(); + key_array + } + } +} diff --git a/datafusion/common/src/scalar/consts.rs b/datafusion/common/src/scalar/consts.rs index efcde651841b0..599c2523cd2c7 100644 --- a/datafusion/common/src/scalar/consts.rs +++ b/datafusion/common/src/scalar/consts.rs @@ -17,28 +17,40 @@ // Constants defined for scalar construction. 
-// PI ~ 3.1415927 in f32 -#[allow(clippy::approx_constant)] -pub(super) const PI_UPPER_F32: f32 = 3.141593_f32; +// Next f16 value above π (upper bound) +pub(super) const PI_UPPER_F16: half::f16 = half::f16::from_bits(0x4249); -// PI ~ 3.141592653589793 in f64 -pub(super) const PI_UPPER_F64: f64 = 3.141592653589794_f64; +// Next f32 value above π (upper bound) +pub(super) const PI_UPPER_F32: f32 = std::f32::consts::PI.next_up(); -// -PI ~ -3.1415927 in f32 -#[allow(clippy::approx_constant)] -pub(super) const NEGATIVE_PI_LOWER_F32: f32 = -3.141593_f32; +// Next f64 value above π (upper bound) +pub(super) const PI_UPPER_F64: f64 = std::f64::consts::PI.next_up(); -// -PI ~ -3.141592653589793 in f64 -pub(super) const NEGATIVE_PI_LOWER_F64: f64 = -3.141592653589794_f64; +// Next f16 value below -π (lower bound) +pub(super) const NEGATIVE_PI_LOWER_F16: half::f16 = half::f16::from_bits(0xC249); -// PI / 2 ~ 1.5707964 in f32 -pub(super) const FRAC_PI_2_UPPER_F32: f32 = 1.5707965_f32; +// Next f32 value below -π (lower bound) +pub(super) const NEGATIVE_PI_LOWER_F32: f32 = (-std::f32::consts::PI).next_down(); -// PI / 2 ~ 1.5707963267948966 in f64 -pub(super) const FRAC_PI_2_UPPER_F64: f64 = 1.5707963267948967_f64; +// Next f64 value below -π (lower bound) +pub(super) const NEGATIVE_PI_LOWER_F64: f64 = (-std::f64::consts::PI).next_down(); -// -PI / 2 ~ -1.5707964 in f32 -pub(super) const NEGATIVE_FRAC_PI_2_LOWER_F32: f32 = -1.5707965_f32; +// Next f16 value above π/2 (upper bound) +pub(super) const FRAC_PI_2_UPPER_F16: half::f16 = half::f16::from_bits(0x3E49); -// -PI / 2 ~ -1.5707963267948966 in f64 -pub(super) const NEGATIVE_FRAC_PI_2_LOWER_F64: f64 = -1.5707963267948967_f64; +// Next f32 value above π/2 (upper bound) +pub(super) const FRAC_PI_2_UPPER_F32: f32 = std::f32::consts::FRAC_PI_2.next_up(); + +// Next f64 value above π/2 (upper bound) +pub(super) const FRAC_PI_2_UPPER_F64: f64 = std::f64::consts::FRAC_PI_2.next_up(); + +// Next f16 value below -π/2 (lower bound)
+pub(super) const NEGATIVE_FRAC_PI_2_LOWER_F16: half::f16 = half::f16::from_bits(0xBE49); + +// Next f32 value below -π/2 (lower bound) +pub(super) const NEGATIVE_FRAC_PI_2_LOWER_F32: f32 = + (-std::f32::consts::FRAC_PI_2).next_down(); + +// Next f64 value below -π/2 (lower bound) +pub(super) const NEGATIVE_FRAC_PI_2_LOWER_F64: f64 = + (-std::f64::consts::FRAC_PI_2).next_down(); diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index 3d4aa78b6da65..ebed41e9d8587 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -17,6 +17,7 @@ //! [`ScalarValue`]: stores single values +mod cache; mod consts; mod struct_builder; @@ -25,6 +26,7 @@ use std::cmp::Ordering; use std::collections::{HashSet, VecDeque}; use std::convert::Infallible; use std::fmt; +use std::fmt::Write; use std::hash::Hash; use std::hash::Hasher; use std::iter::repeat_n; @@ -32,36 +34,162 @@ use std::mem::{size_of, size_of_val}; use std::str::FromStr; use std::sync::Arc; -use crate::arrow_datafusion_err; +use crate::assert_or_internal_err; use crate::cast::{ - as_decimal128_array, as_decimal256_array, as_dictionary_array, - as_fixed_size_binary_array, as_fixed_size_list_array, + as_binary_array, as_binary_view_array, as_boolean_array, as_date32_array, + as_date64_array, as_decimal32_array, as_decimal64_array, as_decimal128_array, + as_decimal256_array, as_dictionary_array, as_duration_microsecond_array, + as_duration_millisecond_array, as_duration_nanosecond_array, + as_duration_second_array, as_fixed_size_binary_array, as_fixed_size_list_array, + as_float16_array, as_float32_array, as_float64_array, as_int8_array, as_int16_array, + as_int32_array, as_int64_array, as_interval_dt_array, as_interval_mdn_array, + as_interval_ym_array, as_large_binary_array, as_large_list_array, + as_large_string_array, as_run_array, as_string_array, as_string_view_array, + as_time32_millisecond_array, as_time32_second_array, 
as_time64_microsecond_array, + as_time64_nanosecond_array, as_timestamp_microsecond_array, + as_timestamp_millisecond_array, as_timestamp_nanosecond_array, + as_timestamp_second_array, as_uint8_array, as_uint16_array, as_uint32_array, + as_uint64_array, as_union_array, }; -use crate::error::{DataFusionError, Result, _exec_err, _internal_err, _not_impl_err}; +use crate::error::{_exec_err, _internal_err, _not_impl_err, DataFusionError, Result}; use crate::format::DEFAULT_CAST_OPTIONS; use crate::hash_utils::create_hashes; use crate::utils::SingleRowListArrayBuilder; +use crate::{_internal_datafusion_err, arrow_datafusion_err}; use arrow::array::{ - types::{IntervalDayTime, IntervalMonthDayNano}, - *, + Array, ArrayData, ArrayDataBuilder, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, + AsArray, BinaryArray, BinaryViewArray, BinaryViewBuilder, BooleanArray, Date32Array, + Date64Array, Decimal32Array, Decimal64Array, Decimal128Array, Decimal256Array, + DictionaryArray, DurationMicrosecondArray, DurationMillisecondArray, + DurationNanosecondArray, DurationSecondArray, FixedSizeBinaryArray, + FixedSizeListArray, Float16Array, Float32Array, Float64Array, GenericListArray, + Int8Array, Int16Array, Int32Array, Int64Array, IntervalDayTimeArray, + IntervalMonthDayNanoArray, IntervalYearMonthArray, LargeBinaryArray, LargeListArray, + LargeStringArray, ListArray, MapArray, MutableArrayData, OffsetSizeTrait, + PrimitiveArray, RunArray, Scalar, StringArray, StringViewArray, StringViewBuilder, + StructArray, Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, + Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, + TimestampNanosecondArray, TimestampSecondArray, UInt8Array, UInt16Array, UInt32Array, + UInt64Array, UnionArray, downcast_run_array, new_empty_array, new_null_array, }; -use arrow::buffer::ScalarBuffer; -use arrow::compute::kernels::{ - cast::{cast_with_options, CastOptions}, - numeric::*, +use arrow::buffer::{BooleanBuffer, 
ScalarBuffer}; +use arrow::compute::kernels::cast::{CastOptions, cast_with_options}; +use arrow::compute::kernels::numeric::{ + add, add_wrapping, div, mul, mul_wrapping, rem, sub, sub_wrapping, }; use arrow::datatypes::{ - i256, ArrowDictionaryKeyType, ArrowNativeType, ArrowTimestampType, DataType, - Date32Type, Date64Type, Field, Float32Type, Int16Type, Int32Type, Int64Type, - Int8Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, - IntervalYearMonthType, TimeUnit, TimestampMicrosecondType, TimestampMillisecondType, - TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, - UInt8Type, UnionFields, UnionMode, DECIMAL128_MAX_PRECISION, + ArrowDictionaryKeyType, ArrowNativeType, ArrowTimestampType, DataType, Date32Type, + Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, DecimalType, Field, + FieldRef, Float32Type, Int8Type, Int16Type, Int32Type, Int64Type, IntervalDayTime, + IntervalDayTimeType, IntervalMonthDayNano, IntervalMonthDayNanoType, IntervalUnit, + IntervalYearMonthType, RunEndIndexType, TimeUnit, TimestampMicrosecondType, + TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt8Type, + UInt16Type, UInt32Type, UInt64Type, UnionFields, UnionMode, i256, + validate_decimal_precision_and_scale, }; -use arrow::util::display::{array_value_to_string, ArrayFormatter, FormatOptions}; +use arrow::util::display::{ArrayFormatter, FormatOptions, array_value_to_string}; +use cache::{get_or_create_cached_key_array, get_or_create_cached_null_array}; +use chrono::{Duration, NaiveDate}; use half::f16; pub use struct_builder::ScalarStructBuilder; +const SECONDS_PER_DAY: i64 = 86_400; +const MILLIS_PER_DAY: i64 = SECONDS_PER_DAY * 1_000; +const MICROS_PER_DAY: i64 = MILLIS_PER_DAY * 1_000; +const NANOS_PER_DAY: i64 = MICROS_PER_DAY * 1_000; +const MICROS_PER_MILLISECOND: i64 = 1_000; +const NANOS_PER_MILLISECOND: i64 = 1_000_000; + +/// Returns the multiplier that converts the input date 
representation into the +/// desired timestamp unit, if the conversion requires a multiplication that can +/// overflow an `i64`. +pub fn date_to_timestamp_multiplier( + source_type: &DataType, + target_type: &DataType, +) -> Option { + let DataType::Timestamp(target_unit, _) = target_type else { + return None; + }; + + // Only `Timestamp` target types have a time unit; otherwise no + // multiplier applies (handled above). The function returns `Some(m)` + // when converting the `source_type` to `target_type` requires a + // multiplication that could overflow `i64`. It returns `None` when + // the conversion is a division or otherwise doesn't require a + // multiplication (e.g. Date64 -> Second). + match source_type { + // Date32 stores days since epoch. Converting to any timestamp + // unit requires multiplying by the per-day factor (seconds, + // milliseconds, microseconds, nanoseconds). + DataType::Date32 => Some(match target_unit { + TimeUnit::Second => SECONDS_PER_DAY, + TimeUnit::Millisecond => MILLIS_PER_DAY, + TimeUnit::Microsecond => MICROS_PER_DAY, + TimeUnit::Nanosecond => NANOS_PER_DAY, + }), + + // Date64 stores milliseconds since epoch. Converting to + // seconds is a division (no multiplication), so return `None`. + // Converting to milliseconds is 1:1, so also `None`. Converting + // to micro/nano requires multiplying by 1_000 / 1_000_000. + DataType::Date64 => match target_unit { + TimeUnit::Second => None, + // Converting Date64 (ms since epoch) to millisecond timestamps + // is an identity conversion and does not require multiplication. + // Returning `None` indicates no multiplication-based overflow + // check is necessary. + TimeUnit::Millisecond => None, + TimeUnit::Microsecond => Some(MICROS_PER_MILLISECOND), + TimeUnit::Nanosecond => Some(NANOS_PER_MILLISECOND), + }, + + _ => None, + } +} + +/// Ensures the provided value can be represented as a timestamp with the given +/// multiplier.
Returns a [`DataFusionError::Execution`] when the converted +/// value would overflow the timestamp range. +pub fn ensure_timestamp_in_bounds( + value: i64, + multiplier: i64, + source_type: &DataType, + target_type: &DataType, +) -> Result<()> { + if multiplier <= 1 { + return Ok(()); + } + + if value.checked_mul(multiplier).is_none() { + let target = format_timestamp_type_for_error(target_type); + _exec_err!( + "Cannot cast {} value {} to {}: converted value exceeds the representable i64 range", + source_type, + value, + target + ) + } else { + Ok(()) + } +} + +/// Format a `DataType::Timestamp` into a short, stable string used in +/// user-facing error messages. +pub(crate) fn format_timestamp_type_for_error(target_type: &DataType) -> String { + match target_type { + DataType::Timestamp(unit, _) => { + let s = match unit { + TimeUnit::Second => "s", + TimeUnit::Millisecond => "ms", + TimeUnit::Microsecond => "us", + TimeUnit::Nanosecond => "ns", + }; + format!("Timestamp({s})") + } + other => format!("{other}"), + } +} + /// A dynamically typed, nullable single value.
/// /// While an arrow [`Array`]) stores one or more values of the same type, in a @@ -142,9 +270,9 @@ pub use struct_builder::ScalarStructBuilder; /// let field_b = Field::new("b", DataType::Utf8, false); /// /// let s1 = ScalarStructBuilder::new() -/// .with_scalar(field_a, ScalarValue::from(1i32)) -/// .with_scalar(field_b, ScalarValue::from("foo")) -/// .build(); +/// .with_scalar(field_a, ScalarValue::from(1i32)) +/// .with_scalar(field_b, ScalarValue::from("foo")) +/// .build(); /// ``` /// /// ## Example: Creating a null [`ScalarValue::Struct`] using [`ScalarStructBuilder`] @@ -170,13 +298,13 @@ pub use struct_builder::ScalarStructBuilder; /// // Build a struct like: {a: 1, b: "foo"} /// // Field description /// let fields = Fields::from(vec![ -/// Field::new("a", DataType::Int32, false), -/// Field::new("b", DataType::Utf8, false), +/// Field::new("a", DataType::Int32, false), +/// Field::new("b", DataType::Utf8, false), /// ]); /// // one row arrays for each field /// let arrays: Vec = vec![ -/// Arc::new(Int32Array::from(vec![1])), -/// Arc::new(StringArray::from(vec!["foo"])), +/// Arc::new(Int32Array::from(vec![1])), +/// Arc::new(StringArray::from(vec!["foo"])), /// ]; /// // no nulls for this array /// let nulls = None; @@ -191,6 +319,8 @@ pub use struct_builder::ScalarStructBuilder; /// See [datatypes](https://arrow.apache.org/docs/python/api/datatypes.html) for /// details on datatypes and the [format](https://github.com/apache/arrow/blob/master/format/Schema.fbs#L354-L375) /// for the definitive reference. 
+/// +/// [`NullArray`]: arrow::array::NullArray #[derive(Clone)] pub enum ScalarValue { /// represents `DataType::Null` (castable to/from any other type) @@ -203,6 +333,10 @@ pub enum ScalarValue { Float32(Option), /// 64bit float Float64(Option), + /// 32bit decimal, using the i32 to represent the decimal, precision scale + Decimal32(Option, u8, i8), + /// 64bit decimal, using the i64 to represent the decimal, precision scale + Decimal64(Option, u8, i8), /// 128bit decimal, using the i128 to represent the decimal, precision scale Decimal128(Option, u8, i8), /// 256bit decimal, using the i256 to represent the decimal, precision scale @@ -296,6 +430,8 @@ pub enum ScalarValue { Union(Option<(i8, Box)>, UnionFields, UnionMode), /// Dictionary type: index type and value Dictionary(Box, Box), + /// (run-ends field, value field, value) + RunEndEncoded(FieldRef, FieldRef, Box), } impl Hash for Fl { @@ -312,6 +448,14 @@ impl PartialEq for ScalarValue { // any newly added enum variant will require editing this list // or else face a compile error match (self, other) { + (Decimal32(v1, p1, s1), Decimal32(v2, p2, s2)) => { + v1.eq(v2) && p1.eq(p2) && s1.eq(s2) + } + (Decimal32(_, _, _), _) => false, + (Decimal64(v1, p1, s1), Decimal64(v2, p2, s2)) => { + v1.eq(v2) && p1.eq(p2) && s1.eq(s2) + } + (Decimal64(_, _, _), _) => false, (Decimal128(v1, p1, s1), Decimal128(v2, p2, s2)) => { v1.eq(v2) && p1.eq(p2) && s1.eq(s2) } @@ -417,6 +561,10 @@ impl PartialEq for ScalarValue { (Union(_, _, _), _) => false, (Dictionary(k1, v1), Dictionary(k2, v2)) => k1.eq(k2) && v1.eq(v2), (Dictionary(_, _), _) => false, + (RunEndEncoded(rf1, vf1, v1), RunEndEncoded(rf2, vf2, v2)) => { + rf1.eq(rf2) && vf1.eq(vf2) && v1.eq(v2) + } + (RunEndEncoded(_, _, _), _) => false, (Null, Null) => true, (Null, _) => false, } @@ -431,6 +579,24 @@ impl PartialOrd for ScalarValue { // any newly added enum variant will require editing this list // or else face a compile error match (self, other) { + 
(Decimal32(v1, p1, s1), Decimal32(v2, p2, s2)) => { + if p1.eq(p2) && s1.eq(s2) { + v1.partial_cmp(v2) + } else { + // Two decimal values can be compared if they have the same precision and scale. + None + } + } + (Decimal32(_, _, _), _) => None, + (Decimal64(v1, p1, s1), Decimal64(v2, p2, s2)) => { + if p1.eq(p2) && s1.eq(s2) { + v1.partial_cmp(v2) + } else { + // Two decimal values can be compared if they have the same precision and scale. + None + } + } + (Decimal64(_, _, _), _) => None, (Decimal128(v1, p1, s1), Decimal128(v2, p2, s2)) => { if p1.eq(p2) && s1.eq(s2) { v1.partial_cmp(v2) @@ -561,13 +727,18 @@ impl PartialOrd for ScalarValue { (Union(_, _, _), _) => None, (Dictionary(k1, v1), Dictionary(k2, v2)) => { // Don't compare if the key types don't match (it is effectively a different datatype) - if k1 == k2 { + if k1 == k2 { v1.partial_cmp(v2) } else { None } + } + (Dictionary(_, _), _) => None, + (RunEndEncoded(rf1, vf1, v1), RunEndEncoded(rf2, vf2, v2)) => { + // Don't compare if the run ends fields don't match (it is effectively a different datatype) + if rf1 == rf2 && vf1 == vf2 { v1.partial_cmp(v2) } else { None } } - (Dictionary(_, _), _) => None, + (RunEndEncoded(_, _, _), _) => None, (Null, Null) => Some(Ordering::Equal), (Null, _) => None, } @@ -585,7 +756,9 @@ fn first_array_for_list(arr: &dyn Array) -> ArrayRef { } else if let Some(arr) = arr.as_fixed_size_list_opt() { arr.value(0) } else { - unreachable!("Since only List / LargeList / FixedSizeList are supported, this should never happen") + unreachable!( + "Since only List / LargeList / FixedSizeList are supported, this should never happen" + ) } } @@ -732,6 +905,16 @@ impl Hash for ScalarValue { fn hash(&self, state: &mut H) { use ScalarValue::*; match self { + Decimal32(v, p, s) => { + v.hash(state); + p.hash(state); + s.hash(state) + } + Decimal64(v, p, s) => { + v.hash(state); + p.hash(state); + s.hash(state) + } Decimal128(v, p, s) => { v.hash(state); p.hash(state); @@ -799,6 +982,11 @@ 
impl Hash for ScalarValue { k.hash(state); v.hash(state); } + RunEndEncoded(rf, vf, v) => { + rf.hash(state); + vf.hash(state); + v.hash(state); + } // stable hash for Null value Null => 1.hash(state), } @@ -806,10 +994,11 @@ impl Hash for ScalarValue { } fn hash_nested_array(arr: ArrayRef, state: &mut H) { - let arrays = vec![arr.to_owned()]; - let hashes_buffer = &mut vec![0; arr.len()]; - let random_state = ahash::RandomState::with_seeds(0, 0, 0, 0); - let hashes = create_hashes(&arrays, &random_state, hashes_buffer).unwrap(); + let len = arr.len(); + let hashes_buffer = &mut vec![0; len]; + let random_state = crate::hash_utils::RandomState::with_seed(0); + let hashes = create_hashes(&[arr], &random_state, hashes_buffer) + .expect("hash_nested_array: failed to create row hashes"); // Hash back to std::hash::Hasher hashes.hash(state); } @@ -839,15 +1028,9 @@ fn dict_from_scalar( let values_array = value.to_array_of_size(1)?; // Create a key array with `size` elements, each of 0 - let key_array: PrimitiveArray = repeat_n( - if value.is_null() { - None - } else { - Some(K::default_value()) - }, - size, - ) - .collect(); + // Use cache to avoid repeated allocations for the same size + let key_array: PrimitiveArray = + get_or_create_cached_key_array::(size, value.is_null()); // create a new DictionaryArray // @@ -859,8 +1042,21 @@ fn dict_from_scalar( )) } -/// Create a dictionary array representing all the values in values -fn dict_from_values( +/// Create a `DictionaryArray` from the provided values array. +/// +/// Each element gets a unique key (`0..N-1`), without deduplication. +/// Useful for wrapping arrays in dictionary form. 
+/// +/// # Input +/// ["alice", "bob", "alice", null, "carol"] +/// +/// # Output +/// `DictionaryArray` +/// { +/// keys: [0, 1, 2, 3, 4], +/// values: ["alice", "bob", "alice", null, "carol"] +/// } +pub fn dict_from_values( values_array: ArrayRef, ) -> Result { // Create a key array with `size` elements of 0..array_len for all @@ -869,11 +1065,10 @@ fn dict_from_values( .map(|index| { if values_array.is_valid(index) { let native_index = K::Native::from_usize(index).ok_or_else(|| { - DataFusionError::Internal(format!( - "Can not create index of type {} from value {}", - K::DATA_TYPE, - index - )) + _internal_datafusion_err!( + "Can not create index of type {} from value {index}", + K::DATA_TYPE + ) })?; Ok(Some(native_index)) } else { @@ -894,17 +1089,8 @@ fn dict_from_values( } macro_rules! typed_cast_tz { - ($array:expr, $index:expr, $ARRAYTYPE:ident, $SCALAR:ident, $TZ:expr) => {{ - use std::any::type_name; - let array = $array - .as_any() - .downcast_ref::<$ARRAYTYPE>() - .ok_or_else(|| { - DataFusionError::Internal(format!( - "could not cast value to {}", - type_name::<$ARRAYTYPE>() - )) - })?; + ($array:expr, $index:expr, $array_cast:ident, $SCALAR:ident, $TZ:expr) => {{ + let array = $array_cast($array)?; Ok::(ScalarValue::$SCALAR( match array.is_null($index) { true => None, @@ -916,17 +1102,8 @@ macro_rules! typed_cast_tz { } macro_rules! typed_cast { - ($array:expr, $index:expr, $ARRAYTYPE:ident, $SCALAR:ident) => {{ - use std::any::type_name; - let array = $array - .as_any() - .downcast_ref::<$ARRAYTYPE>() - .ok_or_else(|| { - DataFusionError::Internal(format!( - "could not cast value to {}", - type_name::<$ARRAYTYPE>() - )) - })?; + ($array:expr, $index:expr, $array_cast:ident, $SCALAR:ident) => {{ + let array = $array_cast($array)?; Ok::(ScalarValue::$SCALAR( match array.is_null($index) { true => None, @@ -963,17 +1140,8 @@ macro_rules! build_timestamp_array_from_option { } macro_rules! 
eq_array_primitive { - ($array:expr, $index:expr, $ARRAYTYPE:ident, $VALUE:expr) => {{ - use std::any::type_name; - let array = $array - .as_any() - .downcast_ref::<$ARRAYTYPE>() - .ok_or_else(|| { - DataFusionError::Internal(format!( - "could not cast value to {}", - type_name::<$ARRAYTYPE>() - )) - })?; + ($array:expr, $index:expr, $array_cast:ident, $VALUE:expr) => {{ + let array = $array_cast($array)?; let is_valid = array.is_valid($index); Ok::(match $VALUE { Some(val) => is_valid && &array.value($index) == val, @@ -1004,21 +1172,16 @@ impl ScalarValue { /// Create a decimal Scalar from value/precision and scale. pub fn try_new_decimal128(value: i128, precision: u8, scale: i8) -> Result { - // make sure the precision and scale is valid - if precision <= DECIMAL128_MAX_PRECISION && scale.unsigned_abs() <= precision { - return Ok(ScalarValue::Decimal128(Some(value), precision, scale)); - } - _internal_err!( - "Can not new a decimal type ScalarValue for precision {precision} and scale {scale}" - ) + Self::validate_decimal_or_internal_err::(precision, scale)?; + Ok(ScalarValue::Decimal128(Some(value), precision, scale)) } /// Create a Null instance of ScalarValue for this datatype /// /// Example /// ``` - /// use datafusion_common::ScalarValue; /// use arrow::datatypes::DataType; + /// use datafusion_common::ScalarValue; /// /// let scalar = ScalarValue::try_new_null(&DataType::Int32).unwrap(); /// assert_eq!(scalar.is_null(), true); @@ -1038,6 +1201,12 @@ impl ScalarValue { DataType::UInt16 => ScalarValue::UInt16(None), DataType::UInt32 => ScalarValue::UInt32(None), DataType::UInt64 => ScalarValue::UInt64(None), + DataType::Decimal32(precision, scale) => { + ScalarValue::Decimal32(None, *precision, *scale) + } + DataType::Decimal64(precision, scale) => { + ScalarValue::Decimal64(None, *precision, *scale) + } DataType::Decimal128(precision, scale) => { ScalarValue::Decimal128(None, *precision, *scale) } @@ -1096,7 +1265,14 @@ impl ScalarValue { 
index_type.clone(), Box::new(value_type.as_ref().try_into()?), ), - // `ScalaValue::List` contains single element `ListArray`. + DataType::RunEndEncoded(run_ends_field, value_field) => { + ScalarValue::RunEndEncoded( + Arc::clone(run_ends_field), + Arc::clone(value_field), + Box::new(value_field.data_type().try_into()?), + ) + } + // `ScalarValue::List` contains single element `ListArray`. DataType::List(field_ref) => ScalarValue::List(Arc::new( GenericListArray::new_null(Arc::clone(field_ref), 1), )), @@ -1104,7 +1280,7 @@ impl ScalarValue { DataType::LargeList(field_ref) => ScalarValue::LargeList(Arc::new( GenericListArray::new_null(Arc::clone(field_ref), 1), )), - // `ScalaValue::FixedSizeList` contains single element `FixedSizeList`. + // `ScalarValue::FixedSizeList` contains single element `FixedSizeList`. DataType::FixedSizeList(field_ref, fixed_length) => { ScalarValue::FixedSizeList(Arc::new(FixedSizeListArray::new_null( Arc::clone(field_ref), @@ -1130,7 +1306,7 @@ impl ScalarValue { DataType::Null => ScalarValue::Null, _ => { return _not_impl_err!( - "Can't create a null scalar from data_type \"{data_type:?}\"" + "Can't create a null scalar from data_type \"{data_type}\"" ); } }) @@ -1184,21 +1360,21 @@ impl ScalarValue { /// Returns a [`ScalarValue`] representing PI pub fn new_pi(datatype: &DataType) -> Result { match datatype { + DataType::Float16 => Ok(ScalarValue::from(f16::PI)), DataType::Float32 => Ok(ScalarValue::from(std::f32::consts::PI)), DataType::Float64 => Ok(ScalarValue::from(std::f64::consts::PI)), - _ => _internal_err!("PI is not supported for data type: {:?}", datatype), + _ => _internal_err!("PI is not supported for data type: {}", datatype), } } /// Returns a [`ScalarValue`] representing PI's upper bound pub fn new_pi_upper(datatype: &DataType) -> Result { - // TODO: replace the constants with next_up/next_down when - // they are stabilized: https://doc.rust-lang.org/std/primitive.f64.html#method.next_up match datatype { + 
DataType::Float16 => Ok(ScalarValue::Float16(Some(consts::PI_UPPER_F16))), DataType::Float32 => Ok(ScalarValue::from(consts::PI_UPPER_F32)), DataType::Float64 => Ok(ScalarValue::from(consts::PI_UPPER_F64)), _ => { - _internal_err!("PI_UPPER is not supported for data type: {:?}", datatype) + _internal_err!("PI_UPPER is not supported for data type: {}", datatype) } } } @@ -1206,10 +1382,13 @@ impl ScalarValue { /// Returns a [`ScalarValue`] representing -PI's lower bound pub fn new_negative_pi_lower(datatype: &DataType) -> Result { match datatype { + DataType::Float16 => { + Ok(ScalarValue::Float16(Some(consts::NEGATIVE_PI_LOWER_F16))) + } DataType::Float32 => Ok(ScalarValue::from(consts::NEGATIVE_PI_LOWER_F32)), DataType::Float64 => Ok(ScalarValue::from(consts::NEGATIVE_PI_LOWER_F64)), _ => { - _internal_err!("-PI_LOWER is not supported for data type: {:?}", datatype) + _internal_err!("-PI_LOWER is not supported for data type: {}", datatype) } } } @@ -1217,13 +1396,13 @@ impl ScalarValue { /// Returns a [`ScalarValue`] representing FRAC_PI_2's upper bound pub fn new_frac_pi_2_upper(datatype: &DataType) -> Result { match datatype { + DataType::Float16 => { + Ok(ScalarValue::Float16(Some(consts::FRAC_PI_2_UPPER_F16))) + } DataType::Float32 => Ok(ScalarValue::from(consts::FRAC_PI_2_UPPER_F32)), DataType::Float64 => Ok(ScalarValue::from(consts::FRAC_PI_2_UPPER_F64)), _ => { - _internal_err!( - "PI_UPPER/2 is not supported for data type: {:?}", - datatype - ) + _internal_err!("PI_UPPER/2 is not supported for data type: {}", datatype) } } } @@ -1231,6 +1410,9 @@ impl ScalarValue { // Returns a [`ScalarValue`] representing FRAC_PI_2's lower bound pub fn new_neg_frac_pi_2_lower(datatype: &DataType) -> Result { match datatype { + DataType::Float16 => Ok(ScalarValue::Float16(Some( + consts::NEGATIVE_FRAC_PI_2_LOWER_F16, + ))), DataType::Float32 => { Ok(ScalarValue::from(consts::NEGATIVE_FRAC_PI_2_LOWER_F32)) } @@ -1238,10 +1420,7 @@ impl ScalarValue { 
Ok(ScalarValue::from(consts::NEGATIVE_FRAC_PI_2_LOWER_F64)) } _ => { - _internal_err!( - "-PI/2_LOWER is not supported for data type: {:?}", - datatype - ) + _internal_err!("-PI/2_LOWER is not supported for data type: {}", datatype) } } } @@ -1249,37 +1428,41 @@ impl ScalarValue { /// Returns a [`ScalarValue`] representing -PI pub fn new_negative_pi(datatype: &DataType) -> Result { match datatype { + DataType::Float16 => Ok(ScalarValue::from(-f16::PI)), DataType::Float32 => Ok(ScalarValue::from(-std::f32::consts::PI)), DataType::Float64 => Ok(ScalarValue::from(-std::f64::consts::PI)), - _ => _internal_err!("-PI is not supported for data type: {:?}", datatype), + _ => _internal_err!("-PI is not supported for data type: {}", datatype), } } /// Returns a [`ScalarValue`] representing PI/2 pub fn new_frac_pi_2(datatype: &DataType) -> Result { match datatype { + DataType::Float16 => Ok(ScalarValue::from(f16::FRAC_PI_2)), DataType::Float32 => Ok(ScalarValue::from(std::f32::consts::FRAC_PI_2)), DataType::Float64 => Ok(ScalarValue::from(std::f64::consts::FRAC_PI_2)), - _ => _internal_err!("PI/2 is not supported for data type: {:?}", datatype), + _ => _internal_err!("PI/2 is not supported for data type: {}", datatype), } } /// Returns a [`ScalarValue`] representing -PI/2 pub fn new_neg_frac_pi_2(datatype: &DataType) -> Result { match datatype { + DataType::Float16 => Ok(ScalarValue::from(-f16::FRAC_PI_2)), DataType::Float32 => Ok(ScalarValue::from(-std::f32::consts::FRAC_PI_2)), DataType::Float64 => Ok(ScalarValue::from(-std::f64::consts::FRAC_PI_2)), - _ => _internal_err!("-PI/2 is not supported for data type: {:?}", datatype), + _ => _internal_err!("-PI/2 is not supported for data type: {}", datatype), } } /// Returns a [`ScalarValue`] representing infinity pub fn new_infinity(datatype: &DataType) -> Result { match datatype { + DataType::Float16 => Ok(ScalarValue::from(f16::INFINITY)), DataType::Float32 => Ok(ScalarValue::from(f32::INFINITY)), DataType::Float64 => 
Ok(ScalarValue::from(f64::INFINITY)), _ => { - _internal_err!("Infinity is not supported for data type: {:?}", datatype) + _internal_err!("Infinity is not supported for data type: {}", datatype) } } } @@ -1287,11 +1470,12 @@ impl ScalarValue { /// Returns a [`ScalarValue`] representing negative infinity pub fn new_neg_infinity(datatype: &DataType) -> Result { match datatype { + DataType::Float16 => Ok(ScalarValue::from(f16::NEG_INFINITY)), DataType::Float32 => Ok(ScalarValue::from(f32::NEG_INFINITY)), DataType::Float64 => Ok(ScalarValue::from(f64::NEG_INFINITY)), _ => { _internal_err!( - "Negative Infinity is not supported for data type: {:?}", + "Negative Infinity is not supported for data type: {}", datatype ) } @@ -1310,9 +1494,15 @@ impl ScalarValue { DataType::UInt16 => ScalarValue::UInt16(Some(0)), DataType::UInt32 => ScalarValue::UInt32(Some(0)), DataType::UInt64 => ScalarValue::UInt64(Some(0)), - DataType::Float16 => ScalarValue::Float16(Some(f16::from_f32(0.0))), + DataType::Float16 => ScalarValue::Float16(Some(f16::ZERO)), DataType::Float32 => ScalarValue::Float32(Some(0.0)), DataType::Float64 => ScalarValue::Float64(Some(0.0)), + DataType::Decimal32(precision, scale) => { + ScalarValue::Decimal32(Some(0), *precision, *scale) + } + DataType::Decimal64(precision, scale) => { + ScalarValue::Decimal64(Some(0), *precision, *scale) + } DataType::Decimal128(precision, scale) => { ScalarValue::Decimal128(Some(0), *precision, *scale) } @@ -1364,12 +1554,159 @@ impl ScalarValue { DataType::Date64 => ScalarValue::Date64(Some(0)), _ => { return _not_impl_err!( - "Can't create a zero scalar from data_type \"{datatype:?}\"" + "Can't create a zero scalar from data_type \"{datatype}\"" ); } }) } + /// Returns a default value for the given `DataType`. + /// + /// This function is useful when you need to initialize a column with + /// non-null values in a DataFrame or when you need a "zero" value + /// for a specific data type. 
+ /// + /// # Default Values + /// + /// - **Numeric types**: Returns zero (via [`new_zero`]) + /// - **String types**: Returns empty string (`""`) + /// - **Binary types**: Returns empty byte array + /// - **Temporal types**: Returns zero/epoch value + /// - **List types**: Returns empty list + /// - **Struct types**: Returns struct with all fields set to their defaults + /// - **Dictionary types**: Returns dictionary with default value + /// - **Map types**: Returns empty map + /// - **Union types**: Returns first variant with default value + /// + /// # Errors + /// + /// Returns an error for data types that don't have a clear default value + /// or are not yet supported (e.g., `RunEndEncoded`). + /// + /// [`new_zero`]: Self::new_zero + pub fn new_default(datatype: &DataType) -> Result { + match datatype { + // Null type + DataType::Null => Ok(ScalarValue::Null), + + // Numeric types + DataType::Boolean + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Float16 + | DataType::Float32 + | DataType::Float64 + | DataType::Decimal32(_, _) + | DataType::Decimal64(_, _) + | DataType::Decimal128(_, _) + | DataType::Decimal256(_, _) + | DataType::Timestamp(_, _) + | DataType::Time32(_) + | DataType::Time64(_) + | DataType::Interval(_) + | DataType::Duration(_) + | DataType::Date32 + | DataType::Date64 => ScalarValue::new_zero(datatype), + + // String types + DataType::Utf8 => Ok(ScalarValue::Utf8(Some("".to_string()))), + DataType::LargeUtf8 => Ok(ScalarValue::LargeUtf8(Some("".to_string()))), + DataType::Utf8View => Ok(ScalarValue::Utf8View(Some("".to_string()))), + + // Binary types + DataType::Binary => Ok(ScalarValue::Binary(Some(vec![]))), + DataType::LargeBinary => Ok(ScalarValue::LargeBinary(Some(vec![]))), + DataType::BinaryView => Ok(ScalarValue::BinaryView(Some(vec![]))), + + // Fixed-size binary + DataType::FixedSizeBinary(size) => 
Ok(ScalarValue::FixedSizeBinary( + *size, + Some(vec![0; *size as usize]), + )), + + // List types + DataType::List(field) => { + let list = + ScalarValue::new_list(&[], field.data_type(), field.is_nullable()); + Ok(ScalarValue::List(list)) + } + DataType::FixedSizeList(field, _size) => { + let empty_arr = new_empty_array(field.data_type()); + let values = Arc::new( + SingleRowListArrayBuilder::new(empty_arr) + .with_nullable(field.is_nullable()) + .build_fixed_size_list_array(0), + ); + Ok(ScalarValue::FixedSizeList(values)) + } + DataType::LargeList(field) => { + let list = ScalarValue::new_large_list(&[], field.data_type()); + Ok(ScalarValue::LargeList(list)) + } + + // Struct types + DataType::Struct(fields) => { + let values = fields + .iter() + .map(|f| ScalarValue::new_default(f.data_type())) + .collect::>>()?; + Ok(ScalarValue::Struct(Arc::new(StructArray::new( + fields.clone(), + values + .into_iter() + .map(|v| v.to_array()) + .collect::>()?, + None, + )))) + } + + // Dictionary types + DataType::Dictionary(key_type, value_type) => Ok(ScalarValue::Dictionary( + key_type.clone(), + Box::new(ScalarValue::new_default(value_type)?), + )), + + DataType::RunEndEncoded(run_ends_field, value_field) => { + Ok(ScalarValue::RunEndEncoded( + Arc::clone(run_ends_field), + Arc::clone(value_field), + Box::new(ScalarValue::new_default(value_field.data_type())?), + )) + } + + // Map types + DataType::Map(field, _) => Ok(ScalarValue::Map(Arc::new(MapArray::from( + ArrayData::new_empty(field.data_type()), + )))), + + // Union types - return first variant with default value + DataType::Union(fields, mode) => { + if let Some((type_id, field)) = fields.iter().next() { + let default_value = ScalarValue::new_default(field.data_type())?; + Ok(ScalarValue::Union( + Some((type_id, Box::new(default_value))), + fields.clone(), + *mode, + )) + } else { + _internal_err!("Union type must have at least one field") + } + } + + DataType::ListView(_) | DataType::LargeListView(_) => { + 
_not_impl_err!( + "Default value for data_type \"{datatype}\" is not implemented yet" + ) + } + } + } + /// Create an one value in the given type. pub fn new_one(datatype: &DataType) -> Result { Ok(match datatype { @@ -1381,12 +1718,60 @@ impl ScalarValue { DataType::UInt16 => ScalarValue::UInt16(Some(1)), DataType::UInt32 => ScalarValue::UInt32(Some(1)), DataType::UInt64 => ScalarValue::UInt64(Some(1)), - DataType::Float16 => ScalarValue::Float16(Some(f16::from_f32(1.0))), + DataType::Float16 => ScalarValue::Float16(Some(f16::ONE)), DataType::Float32 => ScalarValue::Float32(Some(1.0)), DataType::Float64 => ScalarValue::Float64(Some(1.0)), + DataType::Decimal32(precision, scale) => { + Self::validate_decimal_or_internal_err::( + *precision, *scale, + )?; + assert_or_internal_err!(*scale >= 0, "Negative scale is not supported"); + match 10_i32.checked_pow(*scale as u32) { + Some(value) => { + ScalarValue::Decimal32(Some(value), *precision, *scale) + } + None => return _internal_err!("Unsupported scale {scale}"), + } + } + DataType::Decimal64(precision, scale) => { + Self::validate_decimal_or_internal_err::( + *precision, *scale, + )?; + assert_or_internal_err!(*scale >= 0, "Negative scale is not supported"); + match i64::from(10).checked_pow(*scale as u32) { + Some(value) => { + ScalarValue::Decimal64(Some(value), *precision, *scale) + } + None => return _internal_err!("Unsupported scale {scale}"), + } + } + DataType::Decimal128(precision, scale) => { + Self::validate_decimal_or_internal_err::( + *precision, *scale, + )?; + assert_or_internal_err!(*scale >= 0, "Negative scale is not supported"); + match i128::from(10).checked_pow(*scale as u32) { + Some(value) => { + ScalarValue::Decimal128(Some(value), *precision, *scale) + } + None => return _internal_err!("Unsupported scale {scale}"), + } + } + DataType::Decimal256(precision, scale) => { + Self::validate_decimal_or_internal_err::( + *precision, *scale, + )?; + assert_or_internal_err!(*scale >= 0, "Negative scale 
is not supported"); + match i256::from(10).checked_pow(*scale as u32) { + Some(value) => { + ScalarValue::Decimal256(Some(value), *precision, *scale) + } + None => return _internal_err!("Unsupported scale {scale}"), + } + } _ => { return _not_impl_err!( - "Can't create an one scalar from data_type \"{datatype:?}\"" + "Can't create an one scalar from data_type \"{datatype}\"" ); } }) @@ -1399,12 +1784,60 @@ impl ScalarValue { DataType::Int16 | DataType::UInt16 => ScalarValue::Int16(Some(-1)), DataType::Int32 | DataType::UInt32 => ScalarValue::Int32(Some(-1)), DataType::Int64 | DataType::UInt64 => ScalarValue::Int64(Some(-1)), - DataType::Float16 => ScalarValue::Float16(Some(f16::from_f32(-1.0))), + DataType::Float16 => ScalarValue::Float16(Some(f16::NEG_ONE)), DataType::Float32 => ScalarValue::Float32(Some(-1.0)), DataType::Float64 => ScalarValue::Float64(Some(-1.0)), + DataType::Decimal32(precision, scale) => { + Self::validate_decimal_or_internal_err::( + *precision, *scale, + )?; + assert_or_internal_err!(*scale >= 0, "Negative scale is not supported"); + match 10_i32.checked_pow(*scale as u32) { + Some(value) => { + ScalarValue::Decimal32(Some(-value), *precision, *scale) + } + None => return _internal_err!("Unsupported scale {scale}"), + } + } + DataType::Decimal64(precision, scale) => { + Self::validate_decimal_or_internal_err::( + *precision, *scale, + )?; + assert_or_internal_err!(*scale >= 0, "Negative scale is not supported"); + match i64::from(10).checked_pow(*scale as u32) { + Some(value) => { + ScalarValue::Decimal64(Some(-value), *precision, *scale) + } + None => return _internal_err!("Unsupported scale {scale}"), + } + } + DataType::Decimal128(precision, scale) => { + Self::validate_decimal_or_internal_err::( + *precision, *scale, + )?; + assert_or_internal_err!(*scale >= 0, "Negative scale is not supported"); + match i128::from(10).checked_pow(*scale as u32) { + Some(value) => { + ScalarValue::Decimal128(Some(-value), *precision, *scale) + } + None 
=> return _internal_err!("Unsupported scale {scale}"), + } + } + DataType::Decimal256(precision, scale) => { + Self::validate_decimal_or_internal_err::( + *precision, *scale, + )?; + assert_or_internal_err!(*scale >= 0, "Negative scale is not supported"); + match i256::from(10).checked_pow(*scale as u32) { + Some(value) => { + ScalarValue::Decimal256(Some(-value), *precision, *scale) + } + None => return _internal_err!("Unsupported scale {scale}"), + } + } _ => { return _not_impl_err!( - "Can't create a negative one scalar from data_type \"{datatype:?}\"" + "Can't create a negative one scalar from data_type \"{datatype}\"" ); } }) @@ -1423,9 +1856,57 @@ impl ScalarValue { DataType::Float16 => ScalarValue::Float16(Some(f16::from_f32(10.0))), DataType::Float32 => ScalarValue::Float32(Some(10.0)), DataType::Float64 => ScalarValue::Float64(Some(10.0)), + DataType::Decimal32(precision, scale) => { + Self::validate_decimal_or_internal_err::( + *precision, *scale, + )?; + assert_or_internal_err!(*scale >= 0, "Negative scale is not supported"); + match 10_i32.checked_pow((*scale + 1) as u32) { + Some(value) => { + ScalarValue::Decimal32(Some(value), *precision, *scale) + } + None => return _internal_err!("Unsupported scale {scale}"), + } + } + DataType::Decimal64(precision, scale) => { + Self::validate_decimal_or_internal_err::( + *precision, *scale, + )?; + assert_or_internal_err!(*scale >= 0, "Negative scale is not supported"); + match i64::from(10).checked_pow((*scale + 1) as u32) { + Some(value) => { + ScalarValue::Decimal64(Some(value), *precision, *scale) + } + None => return _internal_err!("Unsupported scale {scale}"), + } + } + DataType::Decimal128(precision, scale) => { + Self::validate_decimal_or_internal_err::( + *precision, *scale, + )?; + assert_or_internal_err!(*scale >= 0, "Negative scale is not supported"); + match i128::from(10).checked_pow((*scale + 1) as u32) { + Some(value) => { + ScalarValue::Decimal128(Some(value), *precision, *scale) + } + None => 
return _internal_err!("Unsupported scale {scale}"), + } + } + DataType::Decimal256(precision, scale) => { + Self::validate_decimal_or_internal_err::( + *precision, *scale, + )?; + assert_or_internal_err!(*scale >= 0, "Negative scale is not supported"); + match i256::from(10).checked_pow((*scale + 1) as u32) { + Some(value) => { + ScalarValue::Decimal256(Some(value), *precision, *scale) + } + None => return _internal_err!("Unsupported scale {scale}"), + } + } _ => { return _not_impl_err!( - "Can't create a ten scalar from data_type \"{datatype:?}\"" + "Can't create a ten scalar from data_type \"{datatype}\"" ); } }) @@ -1443,6 +1924,12 @@ impl ScalarValue { ScalarValue::Int16(_) => DataType::Int16, ScalarValue::Int32(_) => DataType::Int32, ScalarValue::Int64(_) => DataType::Int64, + ScalarValue::Decimal32(_, precision, scale) => { + DataType::Decimal32(*precision, *scale) + } + ScalarValue::Decimal64(_, precision, scale) => { + DataType::Decimal64(*precision, *scale) + } ScalarValue::Decimal128(_, precision, scale) => { DataType::Decimal128(*precision, *scale) } @@ -1503,6 +1990,12 @@ impl ScalarValue { ScalarValue::Dictionary(k, v) => { DataType::Dictionary(k.clone(), Box::new(v.data_type())) } + ScalarValue::RunEndEncoded(run_ends_field, value_field, _) => { + DataType::RunEndEncoded( + Arc::clone(run_ends_field), + Arc::clone(value_field), + ) + } ScalarValue::Null => DataType::Null, } } @@ -1524,9 +2017,7 @@ impl ScalarValue { | ScalarValue::Float16(None) | ScalarValue::Float32(None) | ScalarValue::Float64(None) => Ok(self.clone()), - ScalarValue::Float16(Some(v)) => { - Ok(ScalarValue::Float16(Some(f16::from_f32(-v.to_f32())))) - } + ScalarValue::Float16(Some(v)) => Ok(ScalarValue::Float16(Some(-v))), ScalarValue::Float64(Some(v)) => Ok(ScalarValue::Float64(Some(-v))), ScalarValue::Float32(Some(v)) => Ok(ScalarValue::Float32(Some(-v))), ScalarValue::Int8(Some(v)) => Ok(ScalarValue::Int8(Some(v.neg_checked()?))), @@ -1565,6 +2056,24 @@ impl ScalarValue { ); 
Ok(ScalarValue::IntervalMonthDayNano(Some(val))) } + ScalarValue::Decimal32(Some(v), precision, scale) => { + Ok(ScalarValue::Decimal32( + Some(neg_checked_with_ctx(*v, || { + format!("In negation of Decimal32({v}, {precision}, {scale})") + })?), + *precision, + *scale, + )) + } + ScalarValue::Decimal64(Some(v), precision, scale) => { + Ok(ScalarValue::Decimal64( + Some(neg_checked_with_ctx(*v, || { + format!("In negation of Decimal64({v}, {precision}, {scale})") + })?), + *precision, + *scale, + )) + } ScalarValue::Decimal128(Some(v), precision, scale) => { Ok(ScalarValue::Decimal128( Some(neg_checked_with_ctx(*v, || { @@ -1629,6 +2138,7 @@ impl ScalarValue { let r = add_wrapping(&self.to_scalar()?, &other.borrow().to_scalar()?)?; Self::try_from_array(r.as_ref(), 0) } + /// Checked addition of `ScalarValue` /// /// NB: operating on `ScalarValue` directly is not efficient, performance sensitive code @@ -1716,6 +2226,8 @@ impl ScalarValue { ScalarValue::Float16(v) => v.is_none(), ScalarValue::Float32(v) => v.is_none(), ScalarValue::Float64(v) => v.is_none(), + ScalarValue::Decimal32(v, _, _) => v.is_none(), + ScalarValue::Decimal64(v, _, _) => v.is_none(), ScalarValue::Decimal128(v, _, _) => v.is_none(), ScalarValue::Decimal256(v, _, _) => v.is_none(), ScalarValue::Int8(v) => v.is_none(), @@ -1762,6 +2274,7 @@ impl ScalarValue { None => true, }, ScalarValue::Dictionary(_, v) => v.is_null(), + ScalarValue::RunEndEncoded(_, _, v) => v.is_null(), } } @@ -1792,6 +2305,26 @@ impl ScalarValue { (Self::Float64(Some(l)), Self::Float64(Some(r))) => { Some((l - r).abs().round() as _) } + ( + Self::Decimal128(Some(l), lprecision, lscale), + Self::Decimal128(Some(r), rprecision, rscale), + ) => { + if lprecision == rprecision && lscale == rscale { + l.checked_sub(*r)?.checked_abs()?.to_usize() + } else { + None + } + } + ( + Self::Decimal256(Some(l), lprecision, lscale), + Self::Decimal256(Some(r), rprecision, rscale), + ) => { + if lprecision == rprecision && lscale == rscale 
{ + l.checked_sub(*r)?.checked_abs()?.to_usize() + } else { + None + } + } _ => None, } } @@ -1816,23 +2349,16 @@ impl ScalarValue { /// /// # Example /// ``` - /// use datafusion_common::ScalarValue; /// use arrow::array::{BooleanArray, Int32Array}; + /// use datafusion_common::ScalarValue; /// /// let arr = Int32Array::from(vec![Some(1), None, Some(10)]); /// let five = ScalarValue::Int32(Some(5)); /// - /// let result = arrow::compute::kernels::cmp::lt( - /// &arr, - /// &five.to_scalar().unwrap(), - /// ).unwrap(); + /// let result = + /// arrow::compute::kernels::cmp::lt(&arr, &five.to_scalar().unwrap()).unwrap(); /// - /// let expected = BooleanArray::from(vec![ - /// Some(true), - /// None, - /// Some(false) - /// ] - /// ); + /// let expected = BooleanArray::from(vec![Some(true), None, Some(false)]); /// /// assert_eq!(&result, &expected); /// ``` @@ -1848,32 +2374,22 @@ impl ScalarValue { /// Returns an error if the iterator is empty or if the /// [`ScalarValue`]s are not all the same type /// - /// # Panics - /// - /// Panics if `self` is a dictionary with invalid key type - /// /// # Example /// ``` - /// use datafusion_common::ScalarValue; /// use arrow::array::{ArrayRef, BooleanArray}; + /// use datafusion_common::ScalarValue; /// /// let scalars = vec![ - /// ScalarValue::Boolean(Some(true)), - /// ScalarValue::Boolean(None), - /// ScalarValue::Boolean(Some(false)), + /// ScalarValue::Boolean(Some(true)), + /// ScalarValue::Boolean(None), + /// ScalarValue::Boolean(Some(false)), /// ]; /// /// // Build an Array from the list of ScalarValues - /// let array = ScalarValue::iter_to_array(scalars.into_iter()) - /// .unwrap(); + /// let array = ScalarValue::iter_to_array(scalars.into_iter()).unwrap(); /// - /// let expected: ArrayRef = std::sync::Arc::new( - /// BooleanArray::from(vec![ - /// Some(true), - /// None, - /// Some(false) - /// ] - /// )); + /// let expected: ArrayRef = + /// std::sync::Arc::new(BooleanArray::from(vec![Some(true), None, 
Some(false)])); /// /// assert_eq!(&array, &expected); /// ``` @@ -1895,18 +2411,20 @@ impl ScalarValue { macro_rules! build_array_primitive { ($ARRAY_TY:ident, $SCALAR_TY:ident) => {{ { - let array = scalars.map(|sv| { - if let ScalarValue::$SCALAR_TY(v) = sv { - Ok(v) - } else { - _exec_err!( - "Inconsistent types in ScalarValue::iter_to_array. \ + let array = scalars + .map(|sv| { + if let ScalarValue::$SCALAR_TY(v) = sv { + Ok(v) + } else { + _exec_err!( + "Inconsistent types in ScalarValue::iter_to_array. \ Expected {:?}, got {:?}", - data_type, sv - ) - } - }) - .collect::>()?; + data_type, + sv + ) + } + }) + .collect::>()?; Arc::new(array) } }}; @@ -1915,18 +2433,20 @@ impl ScalarValue { macro_rules! build_array_primitive_tz { ($ARRAY_TY:ident, $SCALAR_TY:ident, $TZ:expr) => {{ { - let array = scalars.map(|sv| { - if let ScalarValue::$SCALAR_TY(v, _) = sv { - Ok(v) - } else { - _exec_err!( - "Inconsistent types in ScalarValue::iter_to_array. \ + let array = scalars + .map(|sv| { + if let ScalarValue::$SCALAR_TY(v, _) = sv { + Ok(v) + } else { + _exec_err!( + "Inconsistent types in ScalarValue::iter_to_array. \ Expected {:?}, got {:?}", - data_type, sv - ) - } - }) - .collect::>()?; + data_type, + sv + ) + } + }) + .collect::>()?; Arc::new(array.with_timezone_opt($TZ.clone())) } }}; @@ -1937,36 +2457,48 @@ impl ScalarValue { macro_rules! build_array_string { ($ARRAY_TY:ident, $SCALAR_TY:ident) => {{ { - let array = scalars.map(|sv| { - if let ScalarValue::$SCALAR_TY(v) = sv { - Ok(v) - } else { - _exec_err!( - "Inconsistent types in ScalarValue::iter_to_array. \ + let array = scalars + .map(|sv| { + if let ScalarValue::$SCALAR_TY(v) = sv { + Ok(v) + } else { + _exec_err!( + "Inconsistent types in ScalarValue::iter_to_array. 
\ Expected {:?}, got {:?}", - data_type, sv - ) - } - }) - .collect::>()?; + data_type, + sv + ) + } + }) + .collect::>()?; Arc::new(array) } }}; } let array: ArrayRef = match &data_type { - DataType::Decimal128(precision, scale) => { + DataType::Decimal32(precision, scale) => { let decimal_array = - ScalarValue::iter_to_decimal_array(scalars, *precision, *scale)?; + ScalarValue::iter_to_decimal32_array(scalars, *precision, *scale)?; Arc::new(decimal_array) } - DataType::Decimal256(precision, scale) => { + DataType::Decimal64(precision, scale) => { let decimal_array = - ScalarValue::iter_to_decimal256_array(scalars, *precision, *scale)?; + ScalarValue::iter_to_decimal64_array(scalars, *precision, *scale)?; Arc::new(decimal_array) } - DataType::Null => ScalarValue::iter_to_null_array(scalars)?, - DataType::Boolean => build_array_primitive!(BooleanArray, Boolean), + DataType::Decimal128(precision, scale) => { + let decimal_array = + ScalarValue::iter_to_decimal128_array(scalars, *precision, *scale)?; + Arc::new(decimal_array) + } + DataType::Decimal256(precision, scale) => { + let decimal_array = + ScalarValue::iter_to_decimal256_array(scalars, *precision, *scale)?; + Arc::new(decimal_array) + } + DataType::Null => ScalarValue::iter_to_null_array(scalars)?, + DataType::Boolean => build_array_primitive!(BooleanArray, Boolean), DataType::Float16 => build_array_primitive!(Float16Array, Float16), DataType::Float32 => build_array_primitive!(Float32Array, Float32), DataType::Float64 => build_array_primitive!(Float64Array, Float64), @@ -2107,7 +2639,95 @@ impl ScalarValue { DataType::UInt16 => dict_from_values::(values)?, DataType::UInt32 => dict_from_values::(values)?, DataType::UInt64 => dict_from_values::(values)?, - _ => unreachable!("Invalid dictionary keys type: {:?}", key_type), + _ => unreachable!("Invalid dictionary keys type: {}", key_type), + } + } + DataType::RunEndEncoded(run_ends_field, value_field) => { + fn make_run_array( + scalars: impl IntoIterator, + 
run_ends_field: &FieldRef, + values_field: &FieldRef, + ) -> Result { + let mut scalars = scalars.into_iter(); + + let mut run_ends = vec![]; + let mut value_scalars = vec![]; + + let mut len = R::Native::ONE; + let mut current = + if let Some(ScalarValue::RunEndEncoded(_, _, scalar)) = + scalars.next() + { + *scalar + } else { + // We are guaranteed to have one element of correct + // type because we peeked above + unreachable!() + }; + for scalar in scalars { + let scalar = match scalar { + ScalarValue::RunEndEncoded( + inner_run_ends_field, + inner_value_field, + scalar, + ) if &inner_run_ends_field == run_ends_field + && &inner_value_field == values_field => + { + *scalar + } + _ => { + return _exec_err!( + "Expected RunEndEncoded scalar with run-ends field {run_ends_field} but got: {scalar:?}" + ); + } + }; + + // new run + if scalar != current { + run_ends.push(len); + value_scalars.push(current); + current = scalar; + } + + len = len.add_checked(R::Native::ONE).map_err(|_| { + DataFusionError::Execution(format!( + "Cannot construct RunArray: Overflows run-ends type {}", + run_ends_field.data_type() + )) + })?; + } + + run_ends.push(len); + value_scalars.push(current); + + let run_ends = PrimitiveArray::::from_iter_values(run_ends); + let values = ScalarValue::iter_to_array(value_scalars)?; + + // Using ArrayDataBuilder so we can maintain the fields + let dt = DataType::RunEndEncoded( + Arc::clone(run_ends_field), + Arc::clone(values_field), + ); + let builder = ArrayDataBuilder::new(dt) + .len(RunArray::logical_len(&run_ends)) + .add_child_data(run_ends.to_data()) + .add_child_data(values.to_data()); + let run_array = RunArray::::from(builder.build()?); + + Ok(Arc::new(run_array)) + } + + match run_ends_field.data_type() { + DataType::Int16 => { + make_run_array::(scalars, run_ends_field, value_field)? + } + DataType::Int32 => { + make_run_array::(scalars, run_ends_field, value_field)? 
+ } + DataType::Int64 => { + make_run_array::(scalars, run_ends_field, value_field)? + } + dt => unreachable!("Invalid run-ends type: {dt}"), } } DataType::FixedSizeBinary(size) => { @@ -2118,7 +2738,7 @@ impl ScalarValue { } else { _exec_err!( "Inconsistent types in ScalarValue::iter_to_array. \ - Expected {data_type:?}, got {sv:?}" + Expected {data_type}, got {sv:?}" ) } }) @@ -2130,7 +2750,7 @@ impl ScalarValue { Arc::new(array) } // explicitly enumerate unsupported types so newly added - // types must be aknowledged, Time32 and Time64 types are + // types must be acknowledged, Time32 and Time64 types are // not supported if the TimeUnit is not valid (Time32 can // only be used with Second and Millisecond, Time64 only // with Microsecond and Nanosecond) @@ -2138,7 +2758,6 @@ impl ScalarValue { | DataType::Time32(TimeUnit::Nanosecond) | DataType::Time64(TimeUnit::Second) | DataType::Time64(TimeUnit::Millisecond) - | DataType::RunEndEncoded(_, _) | DataType::ListView(_) | DataType::LargeListView(_) => { return _not_impl_err!( @@ -2166,71 +2785,78 @@ impl ScalarValue { Ok(new_null_array(&DataType::Null, length)) } - fn iter_to_decimal_array( + fn iter_to_decimal32_array( scalars: impl IntoIterator, precision: u8, scale: i8, - ) -> Result { + ) -> Result { let array = scalars .into_iter() .map(|element: ScalarValue| match element { - ScalarValue::Decimal128(v1, _, _) => Ok(v1), + ScalarValue::Decimal32(v1, _, _) => Ok(v1), s => { _internal_err!("Expected ScalarValue::Null element. Received {s:?}") } }) - .collect::>()? + .collect::>()? 
.with_precision_and_scale(precision, scale)?; Ok(array) } - fn iter_to_decimal256_array( + fn iter_to_decimal64_array( scalars: impl IntoIterator, precision: u8, scale: i8, - ) -> Result { + ) -> Result { let array = scalars .into_iter() .map(|element: ScalarValue| match element { - ScalarValue::Decimal256(v1, _, _) => Ok(v1), + ScalarValue::Decimal64(v1, _, _) => Ok(v1), s => { - _internal_err!( - "Expected ScalarValue::Decimal256 element. Received {s:?}" - ) + _internal_err!("Expected ScalarValue::Null element. Received {s:?}") } }) - .collect::>()? + .collect::>()? .with_precision_and_scale(precision, scale)?; Ok(array) } - fn build_decimal_array( - value: Option, + fn iter_to_decimal128_array( + scalars: impl IntoIterator, precision: u8, scale: i8, - size: usize, ) -> Result { - Ok(match value { - Some(val) => Decimal128Array::from(vec![val; size]) - .with_precision_and_scale(precision, scale)?, - None => { - let mut builder = Decimal128Array::builder(size) - .with_precision_and_scale(precision, scale)?; - builder.append_nulls(size); - builder.finish() - } - }) + let array = scalars + .into_iter() + .map(|element: ScalarValue| match element { + ScalarValue::Decimal128(v1, _, _) => Ok(v1), + s => { + _internal_err!("Expected ScalarValue::Null element. Received {s:?}") + } + }) + .collect::>()? + .with_precision_and_scale(precision, scale)?; + Ok(array) } - fn build_decimal256_array( - value: Option, + fn iter_to_decimal256_array( + scalars: impl IntoIterator, precision: u8, scale: i8, - size: usize, ) -> Result { - Ok(repeat_n(value, size) - .collect::() - .with_precision_and_scale(precision, scale)?) + let array = scalars + .into_iter() + .map(|element: ScalarValue| match element { + ScalarValue::Decimal256(v1, _, _) => Ok(v1), + s => { + _internal_err!( + "Expected ScalarValue::Decimal256 element. Received {s:?}" + ) + } + }) + .collect::>()? 
+ .with_precision_and_scale(precision, scale)?; + Ok(array) } /// Converts `Vec` where each element has type corresponding to @@ -2238,23 +2864,24 @@ impl ScalarValue { /// /// Example /// ``` - /// use datafusion_common::ScalarValue; - /// use arrow::array::{ListArray, Int32Array}; + /// use arrow::array::{Int32Array, ListArray}; /// use arrow::datatypes::{DataType, Int32Type}; /// use datafusion_common::cast::as_list_array; + /// use datafusion_common::ScalarValue; /// /// let scalars = vec![ - /// ScalarValue::Int32(Some(1)), - /// ScalarValue::Int32(None), - /// ScalarValue::Int32(Some(2)) + /// ScalarValue::Int32(Some(1)), + /// ScalarValue::Int32(None), + /// ScalarValue::Int32(Some(2)), /// ]; /// /// let result = ScalarValue::new_list(&scalars, &DataType::Int32, true); /// - /// let expected = ListArray::from_iter_primitive::( - /// vec![ - /// Some(vec![Some(1), None, Some(2)]) - /// ]); + /// let expected = ListArray::from_iter_primitive::(vec![Some(vec![ + /// Some(1), + /// None, + /// Some(2), + /// ])]); /// /// assert_eq!(*result, expected); /// ``` @@ -2298,23 +2925,25 @@ impl ScalarValue { /// /// Example /// ``` - /// use datafusion_common::ScalarValue; - /// use arrow::array::{ListArray, Int32Array}; + /// use arrow::array::{Int32Array, ListArray}; /// use arrow::datatypes::{DataType, Int32Type}; /// use datafusion_common::cast::as_list_array; + /// use datafusion_common::ScalarValue; /// /// let scalars = vec![ - /// ScalarValue::Int32(Some(1)), - /// ScalarValue::Int32(None), - /// ScalarValue::Int32(Some(2)) + /// ScalarValue::Int32(Some(1)), + /// ScalarValue::Int32(None), + /// ScalarValue::Int32(Some(2)), /// ]; /// - /// let result = ScalarValue::new_list_from_iter(scalars.into_iter(), &DataType::Int32, true); + /// let result = + /// ScalarValue::new_list_from_iter(scalars.into_iter(), &DataType::Int32, true); /// - /// let expected = ListArray::from_iter_primitive::( - /// vec![ - /// Some(vec![Some(1), None, Some(2)]) - /// ]); + /// 
let expected = ListArray::from_iter_primitive::(vec![Some(vec![ + /// Some(1), + /// None, + /// Some(2), + /// ])]); /// /// assert_eq!(*result, expected); /// ``` @@ -2340,23 +2969,25 @@ impl ScalarValue { /// /// Example /// ``` - /// use datafusion_common::ScalarValue; - /// use arrow::array::{LargeListArray, Int32Array}; + /// use arrow::array::{Int32Array, LargeListArray}; /// use arrow::datatypes::{DataType, Int32Type}; /// use datafusion_common::cast::as_large_list_array; + /// use datafusion_common::ScalarValue; /// /// let scalars = vec![ - /// ScalarValue::Int32(Some(1)), - /// ScalarValue::Int32(None), - /// ScalarValue::Int32(Some(2)) + /// ScalarValue::Int32(Some(1)), + /// ScalarValue::Int32(None), + /// ScalarValue::Int32(Some(2)), /// ]; /// /// let result = ScalarValue::new_large_list(&scalars, &DataType::Int32); /// - /// let expected = LargeListArray::from_iter_primitive::( - /// vec![ - /// Some(vec![Some(1), None, Some(2)]) - /// ]); + /// let expected = + /// LargeListArray::from_iter_primitive::(vec![Some(vec![ + /// Some(1), + /// None, + /// Some(2), + /// ])]); /// /// assert_eq!(*result, expected); /// ``` @@ -2378,20 +3009,51 @@ impl ScalarValue { /// /// Errors if `self` is /// - a decimal that fails be converted to a decimal array of size - /// - a `FixedsizeList` that fails to be concatenated into an array of size + /// - a `FixedSizeList` that fails to be concatenated into an array of size /// - a `List` that fails to be concatenated into an array of size /// - a `Dictionary` that fails be converted to a dictionary array of size pub fn to_array_of_size(&self, size: usize) -> Result { Ok(match self { - ScalarValue::Decimal128(e, precision, scale) => Arc::new( - ScalarValue::build_decimal_array(*e, *precision, *scale, size)?, + ScalarValue::Decimal32(Some(e), precision, scale) => Arc::new( + Decimal32Array::from_value(*e, size) + .with_precision_and_scale(*precision, *scale)?, + ), + ScalarValue::Decimal32(None, precision, scale) => { 
+ new_null_array(&DataType::Decimal32(*precision, *scale), size) + } + ScalarValue::Decimal64(Some(e), precision, scale) => Arc::new( + Decimal64Array::from_value(*e, size) + .with_precision_and_scale(*precision, *scale)?, ), - ScalarValue::Decimal256(e, precision, scale) => Arc::new( - ScalarValue::build_decimal256_array(*e, *precision, *scale, size)?, + ScalarValue::Decimal64(None, precision, scale) => { + new_null_array(&DataType::Decimal64(*precision, *scale), size) + } + ScalarValue::Decimal128(Some(e), precision, scale) => Arc::new( + Decimal128Array::from_value(*e, size) + .with_precision_and_scale(*precision, *scale)?, + ), + ScalarValue::Decimal128(None, precision, scale) => { + new_null_array(&DataType::Decimal128(*precision, *scale), size) + } + ScalarValue::Decimal256(Some(e), precision, scale) => Arc::new( + Decimal256Array::from_value(*e, size) + .with_precision_and_scale(*precision, *scale)?, ), - ScalarValue::Boolean(e) => { - Arc::new(BooleanArray::from(vec![*e; size])) as ArrayRef + ScalarValue::Decimal256(None, precision, scale) => { + new_null_array(&DataType::Decimal256(*precision, *scale), size) } + + ScalarValue::Boolean(e) => match e { + None => new_null_array(&DataType::Boolean, size), + Some(true) => { + Arc::new(BooleanArray::new(BooleanBuffer::new_set(size), None)) + as ArrayRef + } + Some(false) => { + Arc::new(BooleanArray::new(BooleanBuffer::new_unset(size), None)) + as ArrayRef + } + }, ScalarValue::Float64(e) => { build_array_from_option!(Float64, Float64Array, e, size) } @@ -2453,36 +3115,36 @@ impl ScalarValue { ) } ScalarValue::Utf8(e) => match e { - Some(value) => { - Arc::new(StringArray::from_iter_values(repeat_n(value, size))) - } + Some(value) => Arc::new(StringArray::new_repeated(value, size)), None => new_null_array(&DataType::Utf8, size), }, ScalarValue::Utf8View(e) => match e { Some(value) => { - Arc::new(StringViewArray::from_iter_values(repeat_n(value, size))) + let mut builder = StringViewBuilder::with_capacity(size); 
+ builder.try_append_value_n(value, size)?; + let array = builder.finish(); + Arc::new(array) } None => new_null_array(&DataType::Utf8View, size), }, ScalarValue::LargeUtf8(e) => match e { - Some(value) => { - Arc::new(LargeStringArray::from_iter_values(repeat_n(value, size))) - } + Some(value) => Arc::new(LargeStringArray::new_repeated(value, size)), None => new_null_array(&DataType::LargeUtf8, size), }, ScalarValue::Binary(e) => match e { - Some(value) => Arc::new( - repeat_n(Some(value.as_slice()), size).collect::(), - ), - None => Arc::new(repeat_n(None::<&str>, size).collect::()), + Some(value) => { + Arc::new(BinaryArray::new_repeated(value.as_slice(), size)) + } + None => new_null_array(&DataType::Binary, size), }, ScalarValue::BinaryView(e) => match e { - Some(value) => Arc::new( - repeat_n(Some(value.as_slice()), size).collect::(), - ), - None => { - Arc::new(repeat_n(None::<&str>, size).collect::()) + Some(value) => { + let mut builder = BinaryViewBuilder::with_capacity(size); + builder.try_append_value_n(value, size)?; + let array = builder.finish(); + Arc::new(array) } + None => new_null_array(&DataType::BinaryView, size), }, ScalarValue::FixedSizeBinary(s, e) => match e { Some(value) => Arc::new( @@ -2492,35 +3154,42 @@ impl ScalarValue { ) .unwrap(), ), - None => Arc::new( - FixedSizeBinaryArray::try_from_sparse_iter_with_size( - repeat_n(None::<&[u8]>, size), - *s, - ) - .unwrap(), - ), + None => Arc::new(FixedSizeBinaryArray::new_null(*s, size)), }, ScalarValue::LargeBinary(e) => match e { - Some(value) => Arc::new( - repeat_n(Some(value.as_slice()), size).collect::(), - ), - None => { - Arc::new(repeat_n(None::<&str>, size).collect::()) + Some(value) => { + Arc::new(LargeBinaryArray::new_repeated(value.as_slice(), size)) } + None => new_null_array(&DataType::LargeBinary, size), }, ScalarValue::List(arr) => { + if size == 1 { + return Ok(Arc::clone(arr) as Arc); + } Self::list_to_array_of_size(arr.as_ref() as &dyn Array, size)? 
} ScalarValue::LargeList(arr) => { + if size == 1 { + return Ok(Arc::clone(arr) as Arc); + } Self::list_to_array_of_size(arr.as_ref() as &dyn Array, size)? } ScalarValue::FixedSizeList(arr) => { + if size == 1 { + return Ok(Arc::clone(arr) as Arc); + } Self::list_to_array_of_size(arr.as_ref() as &dyn Array, size)? } ScalarValue::Struct(arr) => { + if size == 1 { + return Ok(Arc::clone(arr) as Arc); + } Self::list_to_array_of_size(arr.as_ref() as &dyn Array, size)? } ScalarValue::Map(arr) => { + if size == 1 { + return Ok(Arc::clone(arr) as Arc); + } Self::list_to_array_of_size(arr.as_ref() as &dyn Array, size)? } ScalarValue::Date32(e) => { @@ -2646,13 +3315,10 @@ impl ScalarValue { value_offsets, child_arrays, ) - .map_err(|e| DataFusionError::ArrowError(e, None))?; + .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?; Arc::new(ar) } - None => { - let dt = self.data_type(); - new_null_array(&dt, size) - } + None => new_null_array(&DataType::Union(fields.clone(), *mode), size), }, ScalarValue::Dictionary(key_type, v) => { // values array is one element long (the value) @@ -2665,10 +3331,58 @@ impl ScalarValue { DataType::UInt16 => dict_from_scalar::(v, size)?, DataType::UInt32 => dict_from_scalar::(v, size)?, DataType::UInt64 => dict_from_scalar::(v, size)?, - _ => unreachable!("Invalid dictionary keys type: {:?}", key_type), + _ => unreachable!("Invalid dictionary keys type: {}", key_type), + } + } + ScalarValue::RunEndEncoded(run_ends_field, values_field, value) => { + fn make_run_array( + run_ends_field: &Arc, + values_field: &Arc, + value: &ScalarValue, + size: usize, + ) -> Result { + let size_native = R::Native::from_usize(size) + .ok_or_else(|| DataFusionError::Execution(format!("Cannot construct RunArray of size {size}: Overflows run-ends type {}", R::DATA_TYPE)))?; + let values = value.to_array_of_size(1)?; + let run_ends = + PrimitiveArray::::new(vec![size_native].into(), None); + + // Using ArrayDataBuilder so we can maintain the fields + let 
dt = DataType::RunEndEncoded( + Arc::clone(run_ends_field), + Arc::clone(values_field), + ); + let builder = ArrayDataBuilder::new(dt) + .len(size) + .add_child_data(run_ends.to_data()) + .add_child_data(values.to_data()); + let run_array = RunArray::::from(builder.build()?); + + Ok(Arc::new(run_array)) + } + match run_ends_field.data_type() { + DataType::Int16 => make_run_array::( + run_ends_field, + values_field, + value, + size, + )?, + DataType::Int32 => make_run_array::( + run_ends_field, + values_field, + value, + size, + )?, + DataType::Int64 => make_run_array::( + run_ends_field, + values_field, + value, + size, + )?, + dt => unreachable!("Invalid run-ends type: {dt}"), } } - ScalarValue::Null => new_null_array(&DataType::Null, size), + ScalarValue::Null => get_or_create_cached_null_array(size), }) } @@ -2679,6 +3393,24 @@ impl ScalarValue { scale: i8, ) -> Result { match array.data_type() { + DataType::Decimal32(_, _) => { + let array = as_decimal32_array(array)?; + if array.is_null(index) { + Ok(ScalarValue::Decimal32(None, precision, scale)) + } else { + let value = array.value(index); + Ok(ScalarValue::Decimal32(Some(value), precision, scale)) + } + } + DataType::Decimal64(_, _) => { + let array = as_decimal64_array(array)?; + if array.is_null(index) { + Ok(ScalarValue::Decimal64(None, precision, scale)) + } else { + let value = array.value(index); + Ok(ScalarValue::Decimal64(Some(value), precision, scale)) + } + } DataType::Decimal128(_, _) => { let array = as_decimal128_array(array)?; if array.is_null(index) { @@ -2697,46 +3429,59 @@ impl ScalarValue { Ok(ScalarValue::Decimal256(Some(value), precision, scale)) } } - _ => _internal_err!("Unsupported decimal type"), + other => { + unreachable!("Invalid type isn't decimal: {other:?}") + } } } + /// Repeats the rows of `arr` `size` times, producing an array with + /// `arr.len() * size` total rows. 
fn list_to_array_of_size(arr: &dyn Array, size: usize) -> Result { - let arrays = repeat_n(arr, size).collect::>(); - let ret = match !arrays.is_empty() { - true => arrow::compute::concat(arrays.as_slice())?, - false => arr.slice(0, 0), - }; - Ok(ret) + if size == 0 { + return Ok(arr.slice(0, 0)); + } + + // Examples: given `arr = [[A, B, C]]` and `size = 3`, `indices = [0, 0, 0]` and + // the result is `[[A, B, C], [A, B, C], [A, B, C]]`. + // + // Given `arr = [[A, B], [C]]` and `size = 2`, `indices = [0, 1, 0, 1]` and the + // result is `[[A, B], [C], [A, B], [C]]`. (But in practice, we are always called + // with `arr.len() == 1`.) + let n = arr.len() as u32; + let indices = UInt32Array::from_iter_values((0..size).flat_map(|_| 0..n)); + Ok(arrow::compute::take(arr, &indices, None)?) } /// Retrieve ScalarValue for each row in `array` /// + /// Elements in `array` may be NULL, in which case the corresponding element in the returned vector is None. + /// /// Example 1: Array (ScalarValue::Int32) /// ``` - /// use datafusion_common::ScalarValue; /// use arrow::array::ListArray; /// use arrow::datatypes::{DataType, Int32Type}; + /// use datafusion_common::ScalarValue; /// /// // Equivalent to [[1,2,3], [4,5]] /// let list_arr = ListArray::from_iter_primitive::(vec![ - /// Some(vec![Some(1), Some(2), Some(3)]), - /// Some(vec![Some(4), Some(5)]) + /// Some(vec![Some(1), Some(2), Some(3)]), + /// Some(vec![Some(4), Some(5)]), /// ]); /// /// // Convert the array into Scalar Values for each row /// let scalar_vec = ScalarValue::convert_array_to_scalar_vec(&list_arr).unwrap(); /// /// let expected = vec![ - /// vec![ - /// ScalarValue::Int32(Some(1)), - /// ScalarValue::Int32(Some(2)), - /// ScalarValue::Int32(Some(3)), - /// ], - /// vec![ - /// ScalarValue::Int32(Some(4)), - /// ScalarValue::Int32(Some(5)), - /// ], + /// Some(vec![ + /// ScalarValue::Int32(Some(1)), + /// ScalarValue::Int32(Some(2)), + /// ScalarValue::Int32(Some(3)), + /// ]), + /// Some(vec![ + /// 
ScalarValue::Int32(Some(4)), + /// ScalarValue::Int32(Some(5)), + /// ]), /// ]; /// /// assert_eq!(scalar_vec, expected); @@ -2744,15 +3489,15 @@ impl ScalarValue { /// /// Example 2: Nested array (ScalarValue::List) /// ``` - /// use datafusion_common::ScalarValue; /// use arrow::array::ListArray; /// use arrow::datatypes::{DataType, Int32Type}; /// use datafusion_common::utils::SingleRowListArrayBuilder; + /// use datafusion_common::ScalarValue; /// use std::sync::Arc; /// /// let list_arr = ListArray::from_iter_primitive::(vec![ - /// Some(vec![Some(1), Some(2), Some(3)]), - /// Some(vec![Some(4), Some(5)]) + /// Some(vec![Some(1), Some(2), Some(3)]), + /// Some(vec![Some(4), Some(5)]), /// ]); /// /// // Wrap into another layer of list, we got nested array as [ [[1,2,3], [4,5]] ] @@ -2761,34 +3506,82 @@ impl ScalarValue { /// // Convert the array into Scalar Values for each row, we got 1D arrays in this example /// let scalar_vec = ScalarValue::convert_array_to_scalar_vec(&list_arr).unwrap(); /// - /// let l1 = ListArray::from_iter_primitive::(vec![ + /// let l1 = ListArray::from_iter_primitive::(vec![Some(vec![ + /// Some(1), + /// Some(2), + /// Some(3), + /// ])]); + /// let l2 = ListArray::from_iter_primitive::(vec![Some(vec![ + /// Some(4), + /// Some(5), + /// ])]); + /// + /// let expected = vec![Some(vec![ + /// ScalarValue::List(Arc::new(l1)), + /// ScalarValue::List(Arc::new(l2)), + /// ])]; + /// + /// assert_eq!(scalar_vec, expected); + /// ``` + /// + /// Example 3: Nullable array + /// ``` + /// use arrow::array::ListArray; + /// use arrow::datatypes::{DataType, Int32Type}; + /// use datafusion_common::ScalarValue; + /// + /// let list_arr = ListArray::from_iter_primitive::(vec![ /// Some(vec![Some(1), Some(2), Some(3)]), - /// ]); - /// let l2 = ListArray::from_iter_primitive::(vec![ + /// None, /// Some(vec![Some(4), Some(5)]), /// ]); /// + /// // Convert the array into Scalar Values for each row + /// let scalar_vec = 
ScalarValue::convert_array_to_scalar_vec(&list_arr).unwrap(); + /// /// let expected = vec![ - /// vec![ - /// ScalarValue::List(Arc::new(l1)), - /// ScalarValue::List(Arc::new(l2)), - /// ], + /// Some(vec![ + /// ScalarValue::Int32(Some(1)), + /// ScalarValue::Int32(Some(2)), + /// ScalarValue::Int32(Some(3)), + /// ]), + /// None, + /// Some(vec![ + /// ScalarValue::Int32(Some(4)), + /// ScalarValue::Int32(Some(5)), + /// ]), /// ]; /// /// assert_eq!(scalar_vec, expected); /// ``` - pub fn convert_array_to_scalar_vec(array: &dyn Array) -> Result>> { - let mut scalars = Vec::with_capacity(array.len()); - - for index in 0..array.len() { - let nested_array = array.as_list::().value(index); - let scalar_values = (0..nested_array.len()) - .map(|i| ScalarValue::try_from_array(&nested_array, i)) - .collect::>>()?; - scalars.push(scalar_values); + pub fn convert_array_to_scalar_vec( + array: &dyn Array, + ) -> Result>>> { + fn generic_collect( + array: &dyn Array, + ) -> Result>>> { + array + .as_list::() + .iter() + .map(|nested_array| { + nested_array + .map(|array| { + (0..array.len()) + .map(|i| ScalarValue::try_from_array(&array, i)) + .collect::>>() + }) + .transpose() + }) + .collect() } - Ok(scalars) + match array.data_type() { + DataType::List(_) => generic_collect::(array), + DataType::LargeList(_) => generic_collect::(array), + _ => _internal_err!( + "ScalarValue::convert_array_to_scalar_vec input must be a List/LargeList type" + ), + } } #[deprecated( @@ -2805,12 +3598,22 @@ impl ScalarValue { /// Converts a value in `array` at `index` into a ScalarValue pub fn try_from_array(array: &dyn Array, index: usize) -> Result { // handle NULL value - if !array.is_valid(index) { + if array.is_null(index) { return array.data_type().try_into(); } Ok(match array.data_type() { DataType::Null => ScalarValue::Null, + DataType::Decimal32(precision, scale) => { + ScalarValue::get_decimal_value_from_array( + array, index, *precision, *scale, + )? 
+ } + DataType::Decimal64(precision, scale) => { + ScalarValue::get_decimal_value_from_array( + array, index, *precision, *scale, + )? + } DataType::Decimal128(precision, scale) => { ScalarValue::get_decimal_value_from_array( array, index, *precision, *scale, @@ -2821,30 +3624,32 @@ impl ScalarValue { array, index, *precision, *scale, )? } - DataType::Boolean => typed_cast!(array, index, BooleanArray, Boolean)?, - DataType::Float64 => typed_cast!(array, index, Float64Array, Float64)?, - DataType::Float32 => typed_cast!(array, index, Float32Array, Float32)?, - DataType::Float16 => typed_cast!(array, index, Float16Array, Float16)?, - DataType::UInt64 => typed_cast!(array, index, UInt64Array, UInt64)?, - DataType::UInt32 => typed_cast!(array, index, UInt32Array, UInt32)?, - DataType::UInt16 => typed_cast!(array, index, UInt16Array, UInt16)?, - DataType::UInt8 => typed_cast!(array, index, UInt8Array, UInt8)?, - DataType::Int64 => typed_cast!(array, index, Int64Array, Int64)?, - DataType::Int32 => typed_cast!(array, index, Int32Array, Int32)?, - DataType::Int16 => typed_cast!(array, index, Int16Array, Int16)?, - DataType::Int8 => typed_cast!(array, index, Int8Array, Int8)?, - DataType::Binary => typed_cast!(array, index, BinaryArray, Binary)?, + DataType::Boolean => typed_cast!(array, index, as_boolean_array, Boolean)?, + DataType::Float64 => typed_cast!(array, index, as_float64_array, Float64)?, + DataType::Float32 => typed_cast!(array, index, as_float32_array, Float32)?, + DataType::Float16 => typed_cast!(array, index, as_float16_array, Float16)?, + DataType::UInt64 => typed_cast!(array, index, as_uint64_array, UInt64)?, + DataType::UInt32 => typed_cast!(array, index, as_uint32_array, UInt32)?, + DataType::UInt16 => typed_cast!(array, index, as_uint16_array, UInt16)?, + DataType::UInt8 => typed_cast!(array, index, as_uint8_array, UInt8)?, + DataType::Int64 => typed_cast!(array, index, as_int64_array, Int64)?, + DataType::Int32 => typed_cast!(array, index, 
as_int32_array, Int32)?, + DataType::Int16 => typed_cast!(array, index, as_int16_array, Int16)?, + DataType::Int8 => typed_cast!(array, index, as_int8_array, Int8)?, + DataType::Binary => typed_cast!(array, index, as_binary_array, Binary)?, DataType::LargeBinary => { - typed_cast!(array, index, LargeBinaryArray, LargeBinary)? + typed_cast!(array, index, as_large_binary_array, LargeBinary)? } DataType::BinaryView => { - typed_cast!(array, index, BinaryViewArray, BinaryView)? + typed_cast!(array, index, as_binary_view_array, BinaryView)? } - DataType::Utf8 => typed_cast!(array, index, StringArray, Utf8)?, + DataType::Utf8 => typed_cast!(array, index, as_string_array, Utf8)?, DataType::LargeUtf8 => { - typed_cast!(array, index, LargeStringArray, LargeUtf8)? + typed_cast!(array, index, as_large_string_array, LargeUtf8)? + } + DataType::Utf8View => { + typed_cast!(array, index, as_string_view_array, Utf8View)? } - DataType::Utf8View => typed_cast!(array, index, StringViewArray, Utf8View)?, DataType::List(field) => { let list_array = array.as_list::(); let nested_array = list_array.value(index); @@ -2854,7 +3659,7 @@ impl ScalarValue { .build_list_scalar() } DataType::LargeList(field) => { - let list_array = as_large_list_array(array); + let list_array = as_large_list_array(array)?; let nested_array = list_array.value(index); // Produces a single element `LargeListArray` with the value at `index`. SingleRowListArrayBuilder::new(nested_array) @@ -2871,45 +3676,61 @@ impl ScalarValue { .with_field(field) .build_fixed_size_list_scalar(list_size) } - DataType::Date32 => typed_cast!(array, index, Date32Array, Date32)?, - DataType::Date64 => typed_cast!(array, index, Date64Array, Date64)?, + DataType::ListView(field) => { + let list_array = array.as_list_view::(); + let nested_array = list_array.value(index); + // Store as List scalar since ScalarValue has no ListView variant. 
+ SingleRowListArrayBuilder::new(nested_array) + .with_field(field) + .build_list_scalar() + } + DataType::LargeListView(field) => { + let list_array = array.as_list_view::(); + let nested_array = list_array.value(index); + // Store as LargeList scalar since ScalarValue has no LargeListView variant. + SingleRowListArrayBuilder::new(nested_array) + .with_field(field) + .build_large_list_scalar() + } + DataType::Date32 => typed_cast!(array, index, as_date32_array, Date32)?, + DataType::Date64 => typed_cast!(array, index, as_date64_array, Date64)?, DataType::Time32(TimeUnit::Second) => { - typed_cast!(array, index, Time32SecondArray, Time32Second)? + typed_cast!(array, index, as_time32_second_array, Time32Second)? } DataType::Time32(TimeUnit::Millisecond) => { - typed_cast!(array, index, Time32MillisecondArray, Time32Millisecond)? + typed_cast!(array, index, as_time32_millisecond_array, Time32Millisecond)? } DataType::Time64(TimeUnit::Microsecond) => { - typed_cast!(array, index, Time64MicrosecondArray, Time64Microsecond)? + typed_cast!(array, index, as_time64_microsecond_array, Time64Microsecond)? } DataType::Time64(TimeUnit::Nanosecond) => { - typed_cast!(array, index, Time64NanosecondArray, Time64Nanosecond)? + typed_cast!(array, index, as_time64_nanosecond_array, Time64Nanosecond)? 
} DataType::Timestamp(TimeUnit::Second, tz_opt) => typed_cast_tz!( array, index, - TimestampSecondArray, + as_timestamp_second_array, TimestampSecond, tz_opt )?, DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => typed_cast_tz!( array, index, - TimestampMillisecondArray, + as_timestamp_millisecond_array, TimestampMillisecond, tz_opt )?, DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => typed_cast_tz!( array, index, - TimestampMicrosecondArray, + as_timestamp_microsecond_array, TimestampMicrosecond, tz_opt )?, DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => typed_cast_tz!( array, index, - TimestampNanosecondArray, + as_timestamp_nanosecond_array, TimestampNanosecond, tz_opt )?, @@ -2923,7 +3744,7 @@ impl ScalarValue { DataType::UInt16 => get_dict_value::(array, index)?, DataType::UInt32 => get_dict_value::(array, index)?, DataType::UInt64 => get_dict_value::(array, index)?, - _ => unreachable!("Invalid dictionary keys type: {:?}", key_type), + _ => unreachable!("Invalid dictionary keys type: {}", key_type), }; // look up the index in the values dictionary let value = match values_index { @@ -2936,6 +3757,28 @@ impl ScalarValue { Self::Dictionary(key_type.clone(), Box::new(value)) } + DataType::RunEndEncoded(run_ends_field, value_field) => { + // Explicitly check length here since get_physical_index() doesn't + // bound check for us + if index > array.len() { + return _exec_err!( + "Index {index} out of bounds for array of length {}", + array.len() + ); + } + let scalar = downcast_run_array!( + array => { + let index = array.get_physical_index(index); + ScalarValue::try_from_array(array.values(), index)? 
+ }, + dt => unreachable!("Invalid run-ends type: {dt}") + ); + Self::RunEndEncoded( + Arc::clone(run_ends_field), + Arc::clone(value_field), + Box::new(scalar), + ) + } DataType::Struct(_) => { let a = array.slice(index, 1); Self::Struct(Arc::new(a.as_struct().to_owned())) @@ -2955,36 +3798,42 @@ impl ScalarValue { ) } DataType::Interval(IntervalUnit::DayTime) => { - typed_cast!(array, index, IntervalDayTimeArray, IntervalDayTime)? + typed_cast!(array, index, as_interval_dt_array, IntervalDayTime)? } DataType::Interval(IntervalUnit::YearMonth) => { - typed_cast!(array, index, IntervalYearMonthArray, IntervalYearMonth)? + typed_cast!(array, index, as_interval_ym_array, IntervalYearMonth)? + } + DataType::Interval(IntervalUnit::MonthDayNano) => { + typed_cast!(array, index, as_interval_mdn_array, IntervalMonthDayNano)? } - DataType::Interval(IntervalUnit::MonthDayNano) => typed_cast!( - array, - index, - IntervalMonthDayNanoArray, - IntervalMonthDayNano - )?, DataType::Duration(TimeUnit::Second) => { - typed_cast!(array, index, DurationSecondArray, DurationSecond)? - } - DataType::Duration(TimeUnit::Millisecond) => { - typed_cast!(array, index, DurationMillisecondArray, DurationMillisecond)? - } - DataType::Duration(TimeUnit::Microsecond) => { - typed_cast!(array, index, DurationMicrosecondArray, DurationMicrosecond)? - } - DataType::Duration(TimeUnit::Nanosecond) => { - typed_cast!(array, index, DurationNanosecondArray, DurationNanosecond)? + typed_cast!(array, index, as_duration_second_array, DurationSecond)? 
} + DataType::Duration(TimeUnit::Millisecond) => typed_cast!( + array, + index, + as_duration_millisecond_array, + DurationMillisecond + )?, + DataType::Duration(TimeUnit::Microsecond) => typed_cast!( + array, + index, + as_duration_microsecond_array, + DurationMicrosecond + )?, + DataType::Duration(TimeUnit::Nanosecond) => typed_cast!( + array, + index, + as_duration_nanosecond_array, + DurationNanosecond + )?, DataType::Map(_, _) => { let a = array.slice(index, 1); Self::Map(Arc::new(a.as_map().to_owned())) } DataType::Union(fields, mode) => { - let array = as_union_array(array); + let array = as_union_array(array)?; let ti = array.type_id(index); let index = array.value_offset(index); let value = ScalarValue::try_from_array(array.child(ti), index)?; @@ -3042,6 +3891,7 @@ impl ScalarValue { ScalarValue::LargeUtf8(v) => v, ScalarValue::Utf8View(v) => v, ScalarValue::Dictionary(_, v) => return v.try_as_str(), + ScalarValue::RunEndEncoded(_, _, v) => return v.try_as_str(), _ => return None, }; Some(v.as_ref().map(|v| v.as_str())) @@ -3058,55 +3908,50 @@ impl ScalarValue { target_type: &DataType, cast_options: &CastOptions<'static>, ) -> Result { - let scalar_array = match (self, target_type) { - ( - ScalarValue::Float64(Some(float_ts)), - DataType::Timestamp(TimeUnit::Nanosecond, None), - ) => ScalarValue::Int64(Some((float_ts * 1_000_000_000_f64).trunc() as i64)) - .to_array()?, - ( - ScalarValue::Decimal128(Some(decimal_value), _, scale), - DataType::Timestamp(time_unit, None), - ) => { - let scale_factor = 10_i128.pow(*scale as u32); - let seconds = decimal_value / scale_factor; - let fraction = decimal_value % scale_factor; - - let timestamp_value = match time_unit { - TimeUnit::Second => ScalarValue::Int64(Some(seconds as i64)), - TimeUnit::Millisecond => { - let millis = seconds * 1_000 + (fraction * 1_000) / scale_factor; - ScalarValue::Int64(Some(millis as i64)) - } - TimeUnit::Microsecond => { - let micros = - seconds * 1_000_000 + (fraction * 1_000_000) / 
scale_factor; - ScalarValue::Int64(Some(micros as i64)) - } - TimeUnit::Nanosecond => { - let nanos = seconds * 1_000_000_000 - + (fraction * 1_000_000_000) / scale_factor; - ScalarValue::Int64(Some(nanos as i64)) - } - }; + let source_type = self.data_type(); + if let Some(multiplier) = date_to_timestamp_multiplier(&source_type, target_type) + && let Some(value) = self.date_scalar_value_as_i64() + { + ensure_timestamp_in_bounds(value, multiplier, &source_type, target_type)?; + } - timestamp_value.to_array()? + let scalar_array = self.to_array()?; + + // For struct types, use name-based casting logic that matches fields by name + // and recursively casts nested structs. The field name wrapper is arbitrary + // since cast_column only uses the DataType::Struct field definitions inside. + let cast_arr = match target_type { + DataType::Struct(_) => { + // Field name is unused; only the struct's inner field names matter + let target_field = Field::new("_", target_type.clone(), true); + crate::nested_struct::cast_column( + &scalar_array, + &target_field, + cast_options, + )? 
} - _ => self.to_array()?, + _ => cast_with_options(&scalar_array, target_type, cast_options)?, }; - let cast_arr = cast_with_options(&scalar_array, target_type, cast_options)?; ScalarValue::try_from_array(&cast_arr, 0) } - fn eq_array_decimal( + fn date_scalar_value_as_i64(&self) -> Option { + match self { + ScalarValue::Date32(Some(value)) => Some(i64::from(*value)), + ScalarValue::Date64(Some(value)) => Some(*value), + _ => None, + } + } + + fn eq_array_decimal32( array: &ArrayRef, index: usize, - value: Option<&i128>, + value: Option<&i32>, precision: u8, scale: i8, ) -> Result { - let array = as_decimal128_array(array)?; + let array = as_decimal32_array(array)?; if array.precision() != precision || array.scale() != scale { return Ok(false); } @@ -3118,14 +3963,52 @@ impl ScalarValue { } } - fn eq_array_decimal256( + fn eq_array_decimal64( array: &ArrayRef, index: usize, - value: Option<&i256>, + value: Option<&i64>, precision: u8, scale: i8, ) -> Result { - let array = as_decimal256_array(array)?; + let array = as_decimal64_array(array)?; + if array.precision() != precision || array.scale() != scale { + return Ok(false); + } + let is_null = array.is_null(index); + if let Some(v) = value { + Ok(!array.is_null(index) && array.value(index) == *v) + } else { + Ok(is_null) + } + } + + fn eq_array_decimal( + array: &ArrayRef, + index: usize, + value: Option<&i128>, + precision: u8, + scale: i8, + ) -> Result { + let array = as_decimal128_array(array)?; + if array.precision() != precision || array.scale() != scale { + return Ok(false); + } + let is_null = array.is_null(index); + if let Some(v) = value { + Ok(!array.is_null(index) && array.value(index) == *v) + } else { + Ok(is_null) + } + } + + fn eq_array_decimal256( + array: &ArrayRef, + index: usize, + value: Option<&i256>, + precision: u8, + scale: i8, + ) -> Result { + let array = as_decimal256_array(array)?; if array.precision() != precision || array.scale() != scale { return Ok(false); } @@ -3166,6 +4049,24 @@ 
impl ScalarValue { #[inline] pub fn eq_array(&self, array: &ArrayRef, index: usize) -> Result { Ok(match self { + ScalarValue::Decimal32(v, precision, scale) => { + ScalarValue::eq_array_decimal32( + array, + index, + v.as_ref(), + *precision, + *scale, + )? + } + ScalarValue::Decimal64(v, precision, scale) => { + ScalarValue::eq_array_decimal64( + array, + index, + v.as_ref(), + *precision, + *scale, + )? + } ScalarValue::Decimal128(v, precision, scale) => { ScalarValue::eq_array_decimal( array, @@ -3185,59 +4086,61 @@ impl ScalarValue { )? } ScalarValue::Boolean(val) => { - eq_array_primitive!(array, index, BooleanArray, val)? + eq_array_primitive!(array, index, as_boolean_array, val)? } ScalarValue::Float16(val) => { - eq_array_primitive!(array, index, Float16Array, val)? + eq_array_primitive!(array, index, as_float16_array, val)? } ScalarValue::Float32(val) => { - eq_array_primitive!(array, index, Float32Array, val)? + eq_array_primitive!(array, index, as_float32_array, val)? } ScalarValue::Float64(val) => { - eq_array_primitive!(array, index, Float64Array, val)? + eq_array_primitive!(array, index, as_float64_array, val)? + } + ScalarValue::Int8(val) => { + eq_array_primitive!(array, index, as_int8_array, val)? } - ScalarValue::Int8(val) => eq_array_primitive!(array, index, Int8Array, val)?, ScalarValue::Int16(val) => { - eq_array_primitive!(array, index, Int16Array, val)? + eq_array_primitive!(array, index, as_int16_array, val)? } ScalarValue::Int32(val) => { - eq_array_primitive!(array, index, Int32Array, val)? + eq_array_primitive!(array, index, as_int32_array, val)? } ScalarValue::Int64(val) => { - eq_array_primitive!(array, index, Int64Array, val)? + eq_array_primitive!(array, index, as_int64_array, val)? } ScalarValue::UInt8(val) => { - eq_array_primitive!(array, index, UInt8Array, val)? + eq_array_primitive!(array, index, as_uint8_array, val)? } ScalarValue::UInt16(val) => { - eq_array_primitive!(array, index, UInt16Array, val)? 
+ eq_array_primitive!(array, index, as_uint16_array, val)? } ScalarValue::UInt32(val) => { - eq_array_primitive!(array, index, UInt32Array, val)? + eq_array_primitive!(array, index, as_uint32_array, val)? } ScalarValue::UInt64(val) => { - eq_array_primitive!(array, index, UInt64Array, val)? + eq_array_primitive!(array, index, as_uint64_array, val)? } ScalarValue::Utf8(val) => { - eq_array_primitive!(array, index, StringArray, val)? + eq_array_primitive!(array, index, as_string_array, val)? } ScalarValue::Utf8View(val) => { - eq_array_primitive!(array, index, StringViewArray, val)? + eq_array_primitive!(array, index, as_string_view_array, val)? } ScalarValue::LargeUtf8(val) => { - eq_array_primitive!(array, index, LargeStringArray, val)? + eq_array_primitive!(array, index, as_large_string_array, val)? } ScalarValue::Binary(val) => { - eq_array_primitive!(array, index, BinaryArray, val)? + eq_array_primitive!(array, index, as_binary_array, val)? } ScalarValue::BinaryView(val) => { - eq_array_primitive!(array, index, BinaryViewArray, val)? + eq_array_primitive!(array, index, as_binary_view_array, val)? } ScalarValue::FixedSizeBinary(_, val) => { - eq_array_primitive!(array, index, FixedSizeBinaryArray, val)? + eq_array_primitive!(array, index, as_fixed_size_binary_array, val)? } ScalarValue::LargeBinary(val) => { - eq_array_primitive!(array, index, LargeBinaryArray, val)? + eq_array_primitive!(array, index, as_large_binary_array, val)? } ScalarValue::List(arr) => { Self::eq_array_list(&(arr.to_owned() as ArrayRef), array, index) @@ -3255,58 +4158,58 @@ impl ScalarValue { Self::eq_array_list(&(arr.to_owned() as ArrayRef), array, index) } ScalarValue::Date32(val) => { - eq_array_primitive!(array, index, Date32Array, val)? + eq_array_primitive!(array, index, as_date32_array, val)? } ScalarValue::Date64(val) => { - eq_array_primitive!(array, index, Date64Array, val)? + eq_array_primitive!(array, index, as_date64_array, val)? 
} ScalarValue::Time32Second(val) => { - eq_array_primitive!(array, index, Time32SecondArray, val)? + eq_array_primitive!(array, index, as_time32_second_array, val)? } ScalarValue::Time32Millisecond(val) => { - eq_array_primitive!(array, index, Time32MillisecondArray, val)? + eq_array_primitive!(array, index, as_time32_millisecond_array, val)? } ScalarValue::Time64Microsecond(val) => { - eq_array_primitive!(array, index, Time64MicrosecondArray, val)? + eq_array_primitive!(array, index, as_time64_microsecond_array, val)? } ScalarValue::Time64Nanosecond(val) => { - eq_array_primitive!(array, index, Time64NanosecondArray, val)? + eq_array_primitive!(array, index, as_time64_nanosecond_array, val)? } ScalarValue::TimestampSecond(val, _) => { - eq_array_primitive!(array, index, TimestampSecondArray, val)? + eq_array_primitive!(array, index, as_timestamp_second_array, val)? } ScalarValue::TimestampMillisecond(val, _) => { - eq_array_primitive!(array, index, TimestampMillisecondArray, val)? + eq_array_primitive!(array, index, as_timestamp_millisecond_array, val)? } ScalarValue::TimestampMicrosecond(val, _) => { - eq_array_primitive!(array, index, TimestampMicrosecondArray, val)? + eq_array_primitive!(array, index, as_timestamp_microsecond_array, val)? } ScalarValue::TimestampNanosecond(val, _) => { - eq_array_primitive!(array, index, TimestampNanosecondArray, val)? + eq_array_primitive!(array, index, as_timestamp_nanosecond_array, val)? } ScalarValue::IntervalYearMonth(val) => { - eq_array_primitive!(array, index, IntervalYearMonthArray, val)? + eq_array_primitive!(array, index, as_interval_ym_array, val)? } ScalarValue::IntervalDayTime(val) => { - eq_array_primitive!(array, index, IntervalDayTimeArray, val)? + eq_array_primitive!(array, index, as_interval_dt_array, val)? } ScalarValue::IntervalMonthDayNano(val) => { - eq_array_primitive!(array, index, IntervalMonthDayNanoArray, val)? + eq_array_primitive!(array, index, as_interval_mdn_array, val)? 
} ScalarValue::DurationSecond(val) => { - eq_array_primitive!(array, index, DurationSecondArray, val)? + eq_array_primitive!(array, index, as_duration_second_array, val)? } ScalarValue::DurationMillisecond(val) => { - eq_array_primitive!(array, index, DurationMillisecondArray, val)? + eq_array_primitive!(array, index, as_duration_millisecond_array, val)? } ScalarValue::DurationMicrosecond(val) => { - eq_array_primitive!(array, index, DurationMicrosecondArray, val)? + eq_array_primitive!(array, index, as_duration_microsecond_array, val)? } ScalarValue::DurationNanosecond(val) => { - eq_array_primitive!(array, index, DurationNanosecondArray, val)? + eq_array_primitive!(array, index, as_duration_nanosecond_array, val)? } ScalarValue::Union(value, _, _) => { - let array = as_union_array(array); + let array = as_union_array(array)?; let ti = array.type_id(index); let index = array.value_offset(index); if let Some((ti_v, value)) = value { @@ -3325,7 +4228,7 @@ impl ScalarValue { DataType::UInt16 => get_dict_value::(array, index)?, DataType::UInt32 => get_dict_value::(array, index)?, DataType::UInt64 => get_dict_value::(array, index)?, - _ => unreachable!("Invalid dictionary keys type: {:?}", key_type), + _ => unreachable!("Invalid dictionary keys type: {}", key_type), }; // was the value in the array non null? match values_index { @@ -3333,6 +4236,34 @@ impl ScalarValue { None => v.is_null(), } } + ScalarValue::RunEndEncoded(run_ends_field, _, value) => { + // Explicitly check length here since get_physical_index() doesn't + // bound check for us + if index > array.len() { + return _exec_err!( + "Index {index} out of bounds for array of length {}", + array.len() + ); + } + match run_ends_field.data_type() { + DataType::Int16 => { + let array = as_run_array::(array)?; + let index = array.get_physical_index(index); + value.eq_array(array.values(), index)? 
+ } + DataType::Int32 => { + let array = as_run_array::(array)?; + let index = array.get_physical_index(index); + value.eq_array(array.values(), index)? + } + DataType::Int64 => { + let array = as_run_array::(array)?; + let index = array.get_physical_index(index); + value.eq_array(array.values(), index)? + } + dt => unreachable!("Invalid run-ends type: {dt}"), + } + } ScalarValue::Null => array.is_null(index), }) } @@ -3342,6 +4273,16 @@ impl ScalarValue { arr1 == &right } + /// Compare `self` with `other` and return an `Ordering`. + /// + /// This is the same as [`PartialOrd`] except that it returns + /// `Err` if the values cannot be compared, e.g., they have incompatible data types. + pub fn try_cmp(&self, other: &Self) -> Result { + self.partial_cmp(other).ok_or_else(|| { + _internal_datafusion_err!("Uncomparable values: {self:?}, {other:?}") + }) + } + /// Estimate size if bytes including `Self`. For values with internal containers such as `String` /// includes the allocated size (`capacity`) rather than the current length (`len`) pub fn size(&self) -> usize { @@ -3352,6 +4293,8 @@ impl ScalarValue { | ScalarValue::Float16(_) | ScalarValue::Float32(_) | ScalarValue::Float64(_) + | ScalarValue::Decimal32(_, _, _) + | ScalarValue::Decimal64(_, _, _) | ScalarValue::Decimal128(_, _, _) | ScalarValue::Decimal256(_, _, _) | ScalarValue::Int8(_) @@ -3410,6 +4353,7 @@ impl ScalarValue { // `dt` and `sv` are boxed, so they are NOT already included in `self` dt.size() + sv.size() } + ScalarValue::RunEndEncoded(rf, vf, v) => rf.size() + vf.size() + v.size(), } } @@ -3461,6 +4405,8 @@ impl ScalarValue { | ScalarValue::Float16(_) | ScalarValue::Float32(_) | ScalarValue::Float64(_) + | ScalarValue::Decimal32(_, _, _) + | ScalarValue::Decimal64(_, _, _) | ScalarValue::Decimal128(_, _, _) | ScalarValue::Decimal256(_, _, _) | ScalarValue::Int8(_) @@ -3523,13 +4469,258 @@ impl ScalarValue { ScalarValue::Dictionary(_, value) => { value.compact(); } + 
ScalarValue::RunEndEncoded(_, _, value) => { + value.compact(); + } + } + } + + /// Compacts ([ScalarValue::compact]) the current [ScalarValue] and returns it. + pub fn compacted(mut self) -> Self { + self.compact(); + self + } + + /// Returns the minimum value for the given numeric `DataType`. + /// + /// This function returns the smallest representable value for numeric + /// and temporal data types. For non-numeric types, it returns `None`. + /// + /// # Supported Types + /// + /// - **Integer types**: `i8::MIN`, `i16::MIN`, etc. + /// - **Unsigned types**: Always 0 (`u8::MIN`, `u16::MIN`, etc.) + /// - **Float types**: Negative infinity (IEEE 754) + /// - **Decimal types**: Smallest value based on precision + /// - **Temporal types**: Minimum timestamp/date values + /// - **Time types**: 0 (midnight) + /// - **Duration types**: `i64::MIN` + pub fn min(datatype: &DataType) -> Option { + match datatype { + DataType::Int8 => Some(ScalarValue::Int8(Some(i8::MIN))), + DataType::Int16 => Some(ScalarValue::Int16(Some(i16::MIN))), + DataType::Int32 => Some(ScalarValue::Int32(Some(i32::MIN))), + DataType::Int64 => Some(ScalarValue::Int64(Some(i64::MIN))), + DataType::UInt8 => Some(ScalarValue::UInt8(Some(u8::MIN))), + DataType::UInt16 => Some(ScalarValue::UInt16(Some(u16::MIN))), + DataType::UInt32 => Some(ScalarValue::UInt32(Some(u32::MIN))), + DataType::UInt64 => Some(ScalarValue::UInt64(Some(u64::MIN))), + DataType::Float16 => Some(ScalarValue::Float16(Some(f16::NEG_INFINITY))), + DataType::Float32 => Some(ScalarValue::Float32(Some(f32::NEG_INFINITY))), + DataType::Float64 => Some(ScalarValue::Float64(Some(f64::NEG_INFINITY))), + DataType::Decimal128(precision, scale) => { + // For decimal, min is -10^(precision-scale) + 10^(-scale) + // But for simplicity, we use the minimum i128 value that fits the precision + let max_digits = 10_i128.pow(*precision as u32) - 1; + Some(ScalarValue::Decimal128( + Some(-max_digits), + *precision, + *scale, + )) + } + 
DataType::Decimal256(precision, scale) => { + // Similar to Decimal128 but with i256 + // For now, use a large negative value + let max_digits = i256::from_i128(10_i128) + .checked_pow(*precision as u32) + .and_then(|v| v.checked_sub(i256::from_i128(1))) + .unwrap_or(i256::MAX); + Some(ScalarValue::Decimal256( + Some(max_digits.neg_wrapping()), + *precision, + *scale, + )) + } + DataType::Date32 => Some(ScalarValue::Date32(Some(i32::MIN))), + DataType::Date64 => Some(ScalarValue::Date64(Some(i64::MIN))), + DataType::Time32(TimeUnit::Second) => { + Some(ScalarValue::Time32Second(Some(0))) + } + DataType::Time32(TimeUnit::Millisecond) => { + Some(ScalarValue::Time32Millisecond(Some(0))) + } + DataType::Time64(TimeUnit::Microsecond) => { + Some(ScalarValue::Time64Microsecond(Some(0))) + } + DataType::Time64(TimeUnit::Nanosecond) => { + Some(ScalarValue::Time64Nanosecond(Some(0))) + } + DataType::Timestamp(unit, tz) => match unit { + TimeUnit::Second => { + Some(ScalarValue::TimestampSecond(Some(i64::MIN), tz.clone())) + } + TimeUnit::Millisecond => Some(ScalarValue::TimestampMillisecond( + Some(i64::MIN), + tz.clone(), + )), + TimeUnit::Microsecond => Some(ScalarValue::TimestampMicrosecond( + Some(i64::MIN), + tz.clone(), + )), + TimeUnit::Nanosecond => { + Some(ScalarValue::TimestampNanosecond(Some(i64::MIN), tz.clone())) + } + }, + DataType::Duration(unit) => match unit { + TimeUnit::Second => Some(ScalarValue::DurationSecond(Some(i64::MIN))), + TimeUnit::Millisecond => { + Some(ScalarValue::DurationMillisecond(Some(i64::MIN))) + } + TimeUnit::Microsecond => { + Some(ScalarValue::DurationMicrosecond(Some(i64::MIN))) + } + TimeUnit::Nanosecond => { + Some(ScalarValue::DurationNanosecond(Some(i64::MIN))) + } + }, + _ => None, + } + } + + /// Returns the maximum value for the given numeric `DataType`. + /// + /// This function returns the largest representable value for numeric + /// and temporal data types. For non-numeric types, it returns `None`. 
+ /// + /// # Supported Types + /// + /// - **Integer types**: `i8::MAX`, `i16::MAX`, etc. + /// - **Unsigned types**: `u8::MAX`, `u16::MAX`, etc. + /// - **Float types**: Positive infinity (IEEE 754) + /// - **Decimal types**: Largest value based on precision + /// - **Temporal types**: Maximum timestamp/date values + /// - **Time types**: Maximum time in the day (1 day - 1 unit) + /// - **Duration types**: `i64::MAX` + pub fn max(datatype: &DataType) -> Option { + match datatype { + DataType::Int8 => Some(ScalarValue::Int8(Some(i8::MAX))), + DataType::Int16 => Some(ScalarValue::Int16(Some(i16::MAX))), + DataType::Int32 => Some(ScalarValue::Int32(Some(i32::MAX))), + DataType::Int64 => Some(ScalarValue::Int64(Some(i64::MAX))), + DataType::UInt8 => Some(ScalarValue::UInt8(Some(u8::MAX))), + DataType::UInt16 => Some(ScalarValue::UInt16(Some(u16::MAX))), + DataType::UInt32 => Some(ScalarValue::UInt32(Some(u32::MAX))), + DataType::UInt64 => Some(ScalarValue::UInt64(Some(u64::MAX))), + DataType::Float16 => Some(ScalarValue::Float16(Some(f16::INFINITY))), + DataType::Float32 => Some(ScalarValue::Float32(Some(f32::INFINITY))), + DataType::Float64 => Some(ScalarValue::Float64(Some(f64::INFINITY))), + DataType::Decimal128(precision, scale) => { + // For decimal, max is 10^(precision-scale) - 10^(-scale) + // But for simplicity, we use the maximum i128 value that fits the precision + let max_digits = 10_i128.pow(*precision as u32) - 1; + Some(ScalarValue::Decimal128( + Some(max_digits), + *precision, + *scale, + )) + } + DataType::Decimal256(precision, scale) => { + // Similar to Decimal128 but with i256 + let max_digits = i256::from_i128(10_i128) + .checked_pow(*precision as u32) + .and_then(|v| v.checked_sub(i256::from_i128(1))) + .unwrap_or(i256::MAX); + Some(ScalarValue::Decimal256( + Some(max_digits), + *precision, + *scale, + )) + } + DataType::Date32 => Some(ScalarValue::Date32(Some(i32::MAX))), + DataType::Date64 => Some(ScalarValue::Date64(Some(i64::MAX))), + 
DataType::Time32(TimeUnit::Second) => { + // 86399 seconds = 23:59:59 + Some(ScalarValue::Time32Second(Some(86_399))) + } + DataType::Time32(TimeUnit::Millisecond) => { + // 86_399_999 milliseconds = 23:59:59.999 + Some(ScalarValue::Time32Millisecond(Some(86_399_999))) + } + DataType::Time64(TimeUnit::Microsecond) => { + // 86_399_999_999 microseconds = 23:59:59.999999 + Some(ScalarValue::Time64Microsecond(Some(86_399_999_999))) + } + DataType::Time64(TimeUnit::Nanosecond) => { + // 86_399_999_999_999 nanoseconds = 23:59:59.999999999 + Some(ScalarValue::Time64Nanosecond(Some(86_399_999_999_999))) + } + DataType::Timestamp(unit, tz) => match unit { + TimeUnit::Second => { + Some(ScalarValue::TimestampSecond(Some(i64::MAX), tz.clone())) + } + TimeUnit::Millisecond => Some(ScalarValue::TimestampMillisecond( + Some(i64::MAX), + tz.clone(), + )), + TimeUnit::Microsecond => Some(ScalarValue::TimestampMicrosecond( + Some(i64::MAX), + tz.clone(), + )), + TimeUnit::Nanosecond => { + Some(ScalarValue::TimestampNanosecond(Some(i64::MAX), tz.clone())) + } + }, + DataType::Duration(unit) => match unit { + TimeUnit::Second => Some(ScalarValue::DurationSecond(Some(i64::MAX))), + TimeUnit::Millisecond => { + Some(ScalarValue::DurationMillisecond(Some(i64::MAX))) + } + TimeUnit::Microsecond => { + Some(ScalarValue::DurationMicrosecond(Some(i64::MAX))) + } + TimeUnit::Nanosecond => { + Some(ScalarValue::DurationNanosecond(Some(i64::MAX))) + } + }, + _ => None, } } + + /// A thin wrapper on Arrow's validation that throws internal error if validation + /// fails. 
+ fn validate_decimal_or_internal_err( + precision: u8, + scale: i8, + ) -> Result<()> { + validate_decimal_precision_and_scale::(precision, scale).map_err(|err| { + _internal_datafusion_err!( + "Decimal precision/scale invariant violated \ + (precision={precision}, scale={scale}): {err}" + ) + }) + } } -pub fn copy_array_data(data: &ArrayData) -> ArrayData { - let mut copy = MutableArrayData::new(vec![&data], true, data.len()); - copy.extend(0, 0, data.len()); +/// Compacts the data of an `ArrayData` into a new `ArrayData`. +/// +/// This is useful when you want to minimize the memory footprint of an +/// `ArrayData`. For example, the value returned by [`Array::slice`] still +/// points at the same underlying data buffers as the original array, which may +/// hold many more values. Calling `copy_array_data` on the sliced array will +/// create a new, smaller, `ArrayData` that only contains the data for the +/// sliced array. +/// +/// # Example +/// ``` +/// # use arrow::array::{make_array, Array, Int32Array}; +/// use datafusion_common::scalar::copy_array_data; +/// let array = Int32Array::from_iter_values(0..8192); +/// // Take only the first 2 elements +/// let sliced_array = array.slice(0, 2); +/// // The memory footprint of `sliced_array` is close to 8192 * 4 bytes +/// assert_eq!(32864, sliced_array.get_array_memory_size()); +/// // however, we can copy the data to a new `ArrayData` +/// let new_array = make_array(copy_array_data(&sliced_array.into_data())); +/// // The memory footprint of `new_array` is now only 2 * 4 bytes +/// // and overhead: +/// assert_eq!(160, new_array.get_array_memory_size()); +/// ``` +/// +/// See also [`ScalarValue::compact`] which applies to `ScalarValue` instances +/// as necessary. +pub fn copy_array_data(src_data: &ArrayData) -> ArrayData { + let mut copy = MutableArrayData::new(vec![&src_data], true, src_data.len()); + copy.extend(0, 0, src_data.len()); copy.freeze() } @@ -3551,6 +4742,7 @@ macro_rules! 
impl_scalar { impl_scalar!(f64, Float64); impl_scalar!(f32, Float32); +impl_scalar!(f16, Float16); impl_scalar!(i8, Int8); impl_scalar!(i16, Int16); impl_scalar!(i32, Int32); @@ -3570,7 +4762,7 @@ impl From<&str> for ScalarValue { impl From> for ScalarValue { fn from(value: Option<&str>) -> Self { let value = value.map(|s| s.to_string()); - ScalarValue::Utf8(value) + value.into() } } @@ -3597,7 +4789,13 @@ impl FromStr for ScalarValue { impl From for ScalarValue { fn from(value: String) -> Self { - ScalarValue::Utf8(Some(value)) + Some(value).into() + } +} + +impl From> for ScalarValue { + fn from(value: Option) -> Self { + ScalarValue::Utf8(value) } } @@ -3701,6 +4899,7 @@ impl_try_from!(UInt8, u8); impl_try_from!(UInt16, u16); impl_try_from!(UInt32, u32); impl_try_from!(UInt64, u64); +impl_try_from!(Float16, f16); impl_try_from!(Float32, f32); impl_try_from!(Float64, f64); impl_try_from!(Boolean, bool); @@ -3740,6 +4939,12 @@ macro_rules! format_option { impl fmt::Display for ScalarValue { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { + ScalarValue::Decimal32(v, p, s) => { + write!(f, "{v:?},{p:?},{s:?}")?; + } + ScalarValue::Decimal64(v, p, s) => { + write!(f, "{v:?},{p:?},{s:?}")?; + } ScalarValue::Decimal128(v, p, s) => { write!(f, "{v:?},{p:?},{s:?}")?; } @@ -3771,8 +4976,10 @@ impl fmt::Display for ScalarValue { | ScalarValue::BinaryView(e) => match e { Some(bytes) => { // print up to first 10 bytes, with trailing ... 
if needed + const HEX_CHARS_UPPER: &[u8; 16] = b"0123456789ABCDEF"; for b in bytes.iter().take(10) { - write!(f, "{b:02X}")?; + f.write_char(HEX_CHARS_UPPER[(b >> 4) as usize] as char)?; + f.write_char(HEX_CHARS_UPPER[(b & 0x0f) as usize] as char)?; } if bytes.len() > 10 { write!(f, "...")?; @@ -3780,15 +4987,31 @@ impl fmt::Display for ScalarValue { } None => write!(f, "NULL")?, }, - ScalarValue::List(arr) => fmt_list(arr.to_owned() as ArrayRef, f)?, - ScalarValue::LargeList(arr) => fmt_list(arr.to_owned() as ArrayRef, f)?, - ScalarValue::FixedSizeList(arr) => fmt_list(arr.to_owned() as ArrayRef, f)?, - ScalarValue::Date32(e) => { - format_option!(f, e.map(|v| Date32Type::to_naive_date(v).to_string()))? - } - ScalarValue::Date64(e) => { - format_option!(f, e.map(|v| Date64Type::to_naive_date(v).to_string()))? - } + ScalarValue::List(arr) => fmt_list(arr.as_ref(), f)?, + ScalarValue::LargeList(arr) => fmt_list(arr.as_ref(), f)?, + ScalarValue::FixedSizeList(arr) => fmt_list(arr.as_ref(), f)?, + ScalarValue::Date32(e) => format_option!( + f, + e.map(|v| { + let epoch = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(); + match epoch.checked_add_signed(Duration::try_days(v as i64).unwrap()) + { + Some(date) => date.to_string(), + None => "".to_string(), + } + }) + )?, + ScalarValue::Date64(e) => format_option!( + f, + e.map(|v| { + let epoch = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(); + match epoch.checked_add_signed(Duration::try_milliseconds(v).unwrap()) + { + Some(date) => date.to_string(), + None => "".to_string(), + } + }) + )?, ScalarValue::Time32Second(e) => format_option!(f, e)?, ScalarValue::Time32Millisecond(e) => format_option!(f, e)?, ScalarValue::Time64Microsecond(e) => format_option!(f, e)?, @@ -3882,18 +5105,18 @@ impl fmt::Display for ScalarValue { None => write!(f, "NULL")?, }, ScalarValue::Dictionary(_k, v) => write!(f, "{v}")?, + ScalarValue::RunEndEncoded(_, _, v) => write!(f, "{v}")?, ScalarValue::Null => write!(f, "NULL")?, }; Ok(()) } } -fn 
fmt_list(arr: ArrayRef, f: &mut fmt::Formatter) -> fmt::Result { +fn fmt_list(arr: &dyn Array, f: &mut fmt::Formatter) -> fmt::Result { // ScalarValue List, LargeList, FixedSizeList should always have a single element assert_eq!(arr.len(), 1); let options = FormatOptions::default().with_display_error(true); - let formatter = - ArrayFormatter::try_new(arr.as_ref() as &dyn Array, &options).unwrap(); + let formatter = ArrayFormatter::try_new(arr, &options).unwrap(); let value_formatter = formatter.value(0); write!(f, "{value_formatter}") } @@ -3913,6 +5136,8 @@ fn fmt_binary(data: &[u8], f: &mut fmt::Formatter) -> fmt::Result { impl fmt::Debug for ScalarValue { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { + ScalarValue::Decimal32(_, _, _) => write!(f, "Decimal32({self})"), + ScalarValue::Decimal64(_, _, _) => write!(f, "Decimal64({self})"), ScalarValue::Decimal128(_, _, _) => write!(f, "Decimal128({self})"), ScalarValue::Decimal256(_, _, _) => write!(f, "Decimal256({self})"), ScalarValue::Boolean(_) => write!(f, "Boolean({self})"), @@ -4059,6 +5284,9 @@ impl fmt::Debug for ScalarValue { None => write!(f, "Union(NULL)"), }, ScalarValue::Dictionary(k, v) => write!(f, "Dictionary({k:?}, {v:?})"), + ScalarValue::RunEndEncoded(rf, vf, v) => { + write!(f, "RunEndEncoded({rf:?}, {vf:?}, {v:?})") + } ScalarValue::Null => write!(f, "NULL"), } } @@ -4108,17 +5336,22 @@ impl ScalarType for Date32Type { #[cfg(test)] mod tests { + use std::sync::Arc; use super::*; - use crate::cast::{ - as_map_array, as_string_array, as_struct_array, as_uint32_array, as_uint64_array, - }; - + use crate::cast::{as_list_array, as_map_array, as_struct_array}; use crate::test_util::batches_to_string; - use arrow::array::{types::Float64Type, NullBufferBuilder}; - use arrow::buffer::{Buffer, OffsetBuffer}; + use arrow::array::{ + FixedSizeListBuilder, Int32Builder, LargeListBuilder, ListBuilder, MapBuilder, + NullArray, NullBufferBuilder, OffsetSizeTrait, PrimitiveBuilder, 
RecordBatch, + StringBuilder, StringDictionaryBuilder, StructBuilder, UnionBuilder, + }; + use arrow::buffer::{Buffer, NullBuffer, OffsetBuffer}; use arrow::compute::{is_null, kernels}; - use arrow::datatypes::Fields; + use arrow::datatypes::{ + ArrowNumericType, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, Fields, + Float64Type, TimeUnit, + }; use arrow::error::ArrowError; use arrow::util::pretty::pretty_format_columns; use chrono::NaiveDate; @@ -4150,6 +5383,52 @@ mod tests { assert_eq!(actual, &expected); } + #[test] + fn test_format_timestamp_type_for_error_and_bounds() { + // format helper + let ts_ns = format_timestamp_type_for_error(&DataType::Timestamp( + TimeUnit::Nanosecond, + None, + )); + assert_eq!(ts_ns, "Timestamp(ns)"); + + let ts_us = format_timestamp_type_for_error(&DataType::Timestamp( + TimeUnit::Microsecond, + None, + )); + assert_eq!(ts_us, "Timestamp(us)"); + + // ensure_timestamp_in_bounds: Date32 non-overflow + let ok = ensure_timestamp_in_bounds( + 1000, + NANOS_PER_DAY, + &DataType::Date32, + &DataType::Timestamp(TimeUnit::Nanosecond, None), + ); + assert!(ok.is_ok()); + + // Date32 overflow -- known large day value (9999-12-31 -> 2932896) + let err = ensure_timestamp_in_bounds( + 2932896, + NANOS_PER_DAY, + &DataType::Date32, + &DataType::Timestamp(TimeUnit::Nanosecond, None), + ); + assert!(err.is_err()); + let msg = err.unwrap_err().to_string(); + assert!(msg.contains("Cannot cast Date32 value 2932896 to Timestamp(ns): converted value exceeds the representable i64 range")); + + // Date64 overflow for ns (millis * 1_000_000) + let overflow_millis: i64 = (i64::MAX / NANOS_PER_MILLISECOND) + 1; + let err2 = ensure_timestamp_in_bounds( + overflow_millis, + NANOS_PER_MILLISECOND, + &DataType::Date64, + &DataType::Timestamp(TimeUnit::Nanosecond, None), + ); + assert!(err2.is_err()); + } + #[test] fn test_scalar_value_from_for_struct() { let boolean = Arc::new(BooleanArray::from(vec![false])); @@ -4281,6 +5560,91 @@ mod tests { 
assert_eq!(empty_array.len(), 0); } + #[test] + fn test_to_array_of_size_list_size_one() { + // size=1 takes the fast path (Arc::clone) + let arr = ListArray::from_iter_primitive::(vec![Some(vec![ + Some(10), + Some(20), + ])]); + let sv = ScalarValue::List(Arc::new(arr.clone())); + let result = sv.to_array_of_size(1).unwrap(); + assert_eq!(result.as_list::(), &arr); + } + + #[test] + fn test_to_array_of_size_list_empty_inner() { + // A list scalar containing an empty list: [[]] + let arr = ListArray::from_iter_primitive::(vec![Some(vec![])]); + let sv = ScalarValue::List(Arc::new(arr)); + let result = sv.to_array_of_size(3).unwrap(); + let result_list = result.as_list::(); + assert_eq!(result_list.len(), 3); + for i in 0..3 { + assert_eq!(result_list.value(i).len(), 0); + } + } + + #[test] + fn test_to_array_of_size_large_list() { + let arr = + LargeListArray::from_iter_primitive::(vec![Some(vec![ + Some(100), + Some(200), + ])]); + let sv = ScalarValue::LargeList(Arc::new(arr)); + let result = sv.to_array_of_size(3).unwrap(); + let expected = LargeListArray::from_iter_primitive::(vec![ + Some(vec![Some(100), Some(200)]), + Some(vec![Some(100), Some(200)]), + Some(vec![Some(100), Some(200)]), + ]); + assert_eq!(result.as_list::(), &expected); + } + + #[test] + fn test_list_to_array_of_size_multi_row() { + // Call list_to_array_of_size directly with arr.len() > 1 + let arr = Int32Array::from(vec![Some(10), None, Some(30)]); + let result = ScalarValue::list_to_array_of_size(&arr, 3).unwrap(); + let result = result.as_primitive::(); + assert_eq!( + result.iter().collect::>(), + vec![ + Some(10), + None, + Some(30), + Some(10), + None, + Some(30), + Some(10), + None, + Some(30), + ] + ); + } + + #[test] + fn test_to_array_of_size_null_list() { + let dt = DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))); + let sv = ScalarValue::try_from(&dt).unwrap(); + let result = sv.to_array_of_size(3).unwrap(); + assert_eq!(result.len(), 3); + 
assert_eq!(result.null_count(), 3); + } + + /// See https://github.com/apache/datafusion/issues/18870 + #[test] + fn test_to_array_of_size_for_none_fsb() { + let sv = ScalarValue::FixedSizeBinary(5, None); + let result = sv + .to_array_of_size(2) + .expect("Failed to convert to array of size"); + assert_eq!(result.len(), 2); + assert_eq!(result.null_count(), 2); + assert_eq!(result.as_fixed_size_binary().values().len(), 10); + } + #[test] fn test_list_to_array_string() { let scalars = vec![ @@ -4475,7 +5839,7 @@ mod tests { ]); let array = ScalarValue::iter_to_array(scalars).unwrap(); - let list_array = as_list_array(&array); + let list_array = as_list_array(&array).unwrap(); // List[[1,2,3], null, [4,5]] let expected = ListArray::from_iter_primitive::(vec![ Some(vec![Some(1), Some(2), Some(3)]), @@ -4491,7 +5855,7 @@ mod tests { ]); let array = ScalarValue::iter_to_array(scalars).unwrap(); - let list_array = as_large_list_array(&array); + let list_array = as_large_list_array(&array).unwrap(); let expected = LargeListArray::from_iter_primitive::(vec![ Some(vec![Some(1), Some(2), Some(3)]), None, @@ -4555,6 +5919,17 @@ mod tests { } } + #[test] + fn test_eq_array_err_message() { + assert_starts_with( + ScalarValue::Utf8(Some("123".to_string())) + .eq_array(&(Arc::new(Int32Array::from(vec![123])) as ArrayRef), 0) + .unwrap_err() + .message(), + "could not cast array of type Int32 to arrow_array::array::byte_array::GenericByteArray>", + ); + } + #[test] fn scalar_add_trait_test() -> Result<()> { let float_value = ScalarValue::Float64(Some(123.)); @@ -4625,7 +6000,10 @@ mod tests { .sub_checked(&int_value_2) .unwrap_err() .strip_backtrace(); - assert_eq!(err, "Arrow error: Arithmetic overflow: Overflow happened on: 9223372036854775807 - -9223372036854775808") + assert_eq!( + err, + "Arrow error: Arithmetic overflow: Overflow happened on: 9223372036854775807 - -9223372036854775808" + ) } #[test] @@ -4711,6 +6089,32 @@ mod tests { Ok(()) } + #[test] + fn test_try_cmp() { 
+ assert_eq!( + ScalarValue::try_cmp( + &ScalarValue::Int32(Some(1)), + &ScalarValue::Int32(Some(2)) + ) + .unwrap(), + Ordering::Less + ); + assert_eq!( + ScalarValue::try_cmp(&ScalarValue::Int32(None), &ScalarValue::Int32(Some(2))) + .unwrap(), + Ordering::Less + ); + assert_starts_with( + ScalarValue::try_cmp( + &ScalarValue::Int32(Some(1)), + &ScalarValue::Int64(Some(2)), + ) + .unwrap_err() + .message(), + "Uncomparable values: Int32(1), Int64(2)", + ); + } + #[test] fn scalar_decimal_test() -> Result<()> { let decimal_value = ScalarValue::Decimal128(Some(123), 10, 1); @@ -4747,12 +6151,16 @@ mod tests { assert_eq!(123i128, array_decimal.value(0)); assert_eq!(123i128, array_decimal.value(9)); // test eq array - assert!(decimal_value - .eq_array(&array, 1) - .expect("Failed to compare arrays")); - assert!(decimal_value - .eq_array(&array, 5) - .expect("Failed to compare arrays")); + assert!( + decimal_value + .eq_array(&array, 1) + .expect("Failed to compare arrays") + ); + assert!( + decimal_value + .eq_array(&array, 5) + .expect("Failed to compare arrays") + ); // test try from array assert_eq!( decimal_value, @@ -4797,18 +6205,24 @@ mod tests { assert_eq!(4, array.len()); assert_eq!(DataType::Decimal128(10, 2), array.data_type().clone()); - assert!(ScalarValue::try_new_decimal128(1, 10, 2) - .unwrap() - .eq_array(&array, 0) - .expect("Failed to compare arrays")); - assert!(ScalarValue::try_new_decimal128(2, 10, 2) - .unwrap() - .eq_array(&array, 1) - .expect("Failed to compare arrays")); - assert!(ScalarValue::try_new_decimal128(3, 10, 2) - .unwrap() - .eq_array(&array, 2) - .expect("Failed to compare arrays")); + assert!( + ScalarValue::try_new_decimal128(1, 10, 2) + .unwrap() + .eq_array(&array, 0) + .expect("Failed to compare arrays") + ); + assert!( + ScalarValue::try_new_decimal128(2, 10, 2) + .unwrap() + .eq_array(&array, 1) + .expect("Failed to compare arrays") + ); + assert!( + ScalarValue::try_new_decimal128(3, 10, 2) + .unwrap() + .eq_array(&array, 
2) + .expect("Failed to compare arrays") + ); assert_eq!( ScalarValue::Decimal128(None, 10, 2), ScalarValue::try_from_array(&array, 3).unwrap() @@ -4818,12 +6232,120 @@ mod tests { } #[test] - fn test_list_partial_cmp() { - let a = - ScalarValue::List(Arc::new( - ListArray::from_iter_primitive::(vec![Some(vec![ - Some(1), - Some(2), + fn test_new_one_decimal128() { + assert_eq!( + ScalarValue::new_one(&DataType::Decimal128(5, 0)).unwrap(), + ScalarValue::Decimal128(Some(1), 5, 0) + ); + assert_eq!( + ScalarValue::new_one(&DataType::Decimal128(5, 1)).unwrap(), + ScalarValue::Decimal128(Some(10), 5, 1) + ); + assert_eq!( + ScalarValue::new_one(&DataType::Decimal128(5, 2)).unwrap(), + ScalarValue::Decimal128(Some(100), 5, 2) + ); + // More precision + assert_eq!( + ScalarValue::new_one(&DataType::Decimal128(7, 2)).unwrap(), + ScalarValue::Decimal128(Some(100), 7, 2) + ); + // No negative scale + assert!(ScalarValue::new_one(&DataType::Decimal128(5, -1)).is_err()); + // Invalid combination + assert!(ScalarValue::new_one(&DataType::Decimal128(0, 2)).is_err()); + assert!(ScalarValue::new_one(&DataType::Decimal128(5, 7)).is_err()); + } + + #[test] + fn test_new_one_decimal256() { + assert_eq!( + ScalarValue::new_one(&DataType::Decimal256(5, 0)).unwrap(), + ScalarValue::Decimal256(Some(1.into()), 5, 0) + ); + assert_eq!( + ScalarValue::new_one(&DataType::Decimal256(5, 1)).unwrap(), + ScalarValue::Decimal256(Some(10.into()), 5, 1) + ); + assert_eq!( + ScalarValue::new_one(&DataType::Decimal256(5, 2)).unwrap(), + ScalarValue::Decimal256(Some(100.into()), 5, 2) + ); + // More precision + assert_eq!( + ScalarValue::new_one(&DataType::Decimal256(7, 2)).unwrap(), + ScalarValue::Decimal256(Some(100.into()), 7, 2) + ); + // No negative scale + assert!(ScalarValue::new_one(&DataType::Decimal256(5, -1)).is_err()); + // Invalid combination + assert!(ScalarValue::new_one(&DataType::Decimal256(0, 2)).is_err()); + assert!(ScalarValue::new_one(&DataType::Decimal256(5, 7)).is_err()); + } 
+ + #[test] + fn test_new_ten_decimal128() { + assert_eq!( + ScalarValue::new_ten(&DataType::Decimal128(5, 1)).unwrap(), + ScalarValue::Decimal128(Some(100), 5, 1) + ); + assert_eq!( + ScalarValue::new_ten(&DataType::Decimal128(5, 2)).unwrap(), + ScalarValue::Decimal128(Some(1000), 5, 2) + ); + // More precision + assert_eq!( + ScalarValue::new_ten(&DataType::Decimal128(7, 2)).unwrap(), + ScalarValue::Decimal128(Some(1000), 7, 2) + ); + // No negative scale + assert!(ScalarValue::new_ten(&DataType::Decimal128(5, -1)).is_err()); + // Invalid combination + assert!(ScalarValue::new_ten(&DataType::Decimal128(0, 2)).is_err()); + assert!(ScalarValue::new_ten(&DataType::Decimal128(5, 7)).is_err()); + } + + #[test] + fn test_new_ten_decimal256() { + assert_eq!( + ScalarValue::new_ten(&DataType::Decimal256(5, 1)).unwrap(), + ScalarValue::Decimal256(Some(100.into()), 5, 1) + ); + assert_eq!( + ScalarValue::new_ten(&DataType::Decimal256(5, 2)).unwrap(), + ScalarValue::Decimal256(Some(1000.into()), 5, 2) + ); + // More precision + assert_eq!( + ScalarValue::new_ten(&DataType::Decimal256(7, 2)).unwrap(), + ScalarValue::Decimal256(Some(1000.into()), 7, 2) + ); + // No negative scale + assert!(ScalarValue::new_ten(&DataType::Decimal256(5, -1)).is_err()); + // Invalid combination + assert!(ScalarValue::new_ten(&DataType::Decimal256(0, 2)).is_err()); + assert!(ScalarValue::new_ten(&DataType::Decimal256(5, 7)).is_err()); + } + + #[test] + fn test_new_negative_one_decimal128() { + assert_eq!( + ScalarValue::new_negative_one(&DataType::Decimal128(5, 0)).unwrap(), + ScalarValue::Decimal128(Some(-1), 5, 0) + ); + assert_eq!( + ScalarValue::new_negative_one(&DataType::Decimal128(5, 2)).unwrap(), + ScalarValue::Decimal128(Some(-100), 5, 2) + ); + } + + #[test] + fn test_list_partial_cmp() { + let a = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), Some(3), ])]), )); @@ -5136,8 +6658,6 @@ mod tests { } #[test] - // despite clippy 
claiming they are useless, the code doesn't compile otherwise. - #[allow(clippy::useless_vec)] fn scalar_iter_to_array_boolean() { check_scalar_iter!(Boolean, BooleanArray, vec![Some(true), None, Some(false)]); check_scalar_iter!(Float32, Float32Array, vec![Some(1.9), None, Some(-2.1)]); @@ -5187,12 +6707,12 @@ mod tests { check_scalar_iter_binary!( Binary, BinaryArray, - vec![Some(b"foo"), None, Some(b"bar")] + [Some(b"foo"), None, Some(b"bar")] ); check_scalar_iter_binary!( LargeBinary, LargeBinaryArray, - vec![Some(b"foo"), None, Some(b"bar")] + [Some(b"foo"), None, Some(b"bar")] ); } @@ -5645,7 +7165,9 @@ mod tests { for other_index in 0..array.len() { if index != other_index { assert!( - !scalar.eq_array(&array, other_index).expect("Failed to compare arrays"), + !scalar + .eq_array(&array, other_index) + .expect("Failed to compare arrays"), "Expected {scalar:?} to be NOT equal to {array:?} at index {other_index}" ); } @@ -6073,6 +7595,31 @@ mod tests { } } + #[test] + fn roundtrip_run_array() { + // Comparison logic in round_trip_through_scalar doesn't work for RunArrays + // so we have a custom test for them + // TODO: https://github.com/apache/arrow-rs/pull/9213 might fix this ^ + let run_ends = Int16Array::from(vec![2, 3]); + let values = Int64Array::from(vec![Some(1), None]); + let run_array = RunArray::try_new(&run_ends, &values).unwrap(); + let run_array = run_array.downcast::().unwrap(); + + let expected_values = run_array.into_iter().collect::>(); + + for i in 0..run_array.len() { + let scalar = ScalarValue::try_from_array(&run_array, i).unwrap(); + let array = scalar.to_array_of_size(1).unwrap(); + assert_eq!(array.data_type(), run_array.data_type()); + let array = array.as_run::(); + let array = array.downcast::().unwrap(); + assert_eq!( + array.into_iter().collect::>(), + expected_values[i..i + 1] + ); + } + } + #[test] fn test_scalar_union_sparse() { let field_a = Arc::new(Field::new("A", DataType::Int32, true)); @@ -6570,7 +8117,6 @@ mod tests { } 
#[test] - #[allow(arithmetic_overflow)] // we want to test them fn test_scalar_negative_overflows() -> Result<()> { macro_rules! test_overflow_on_value { ($($val:expr),* $(,)?) => {$( @@ -6579,10 +8125,7 @@ mod tests { let err = value.arithmetic_negate().expect_err("Should receive overflow error on negating {value:?}"); let root_err = err.find_root(); match root_err{ - DataFusionError::ArrowError( - ArrowError::ArithmeticOverflow(_), - _, - ) => {} + DataFusionError::ArrowError(err, _) if matches!(err.as_ref(), ArrowError::ArithmeticOverflow(_)) => {} _ => return Err(err), }; } @@ -6870,6 +8413,26 @@ mod tests { ScalarValue::Float64(Some(-9.9)), 5, ), + ( + ScalarValue::Decimal128(Some(10), 1, 0), + ScalarValue::Decimal128(Some(5), 1, 0), + 5, + ), + ( + ScalarValue::Decimal128(Some(5), 1, 0), + ScalarValue::Decimal128(Some(10), 1, 0), + 5, + ), + ( + ScalarValue::Decimal256(Some(10.into()), 1, 0), + ScalarValue::Decimal256(Some(5.into()), 1, 0), + 5, + ), + ( + ScalarValue::Decimal256(Some(5.into()), 1, 0), + ScalarValue::Decimal256(Some(10.into()), 1, 0), + 5, + ), ]; for (lhs, rhs, expected) in cases.iter() { let distance = lhs.distance(rhs).unwrap(); @@ -6877,6 +8440,24 @@ mod tests { } } + #[test] + fn test_distance_none() { + let cases = [ + ( + ScalarValue::Decimal128(Some(i128::MAX), DECIMAL128_MAX_PRECISION, 0), + ScalarValue::Decimal128(Some(-i128::MAX), DECIMAL128_MAX_PRECISION, 0), + ), + ( + ScalarValue::Decimal256(Some(i256::MAX), DECIMAL256_MAX_PRECISION, 0), + ScalarValue::Decimal256(Some(-i256::MAX), DECIMAL256_MAX_PRECISION, 0), + ), + ]; + for (lhs, rhs) in cases.iter() { + let distance = lhs.distance(rhs); + assert!(distance.is_none(), "{lhs} vs {rhs}"); + } + } + #[test] fn test_scalar_distance_invalid() { let cases = [ @@ -6918,7 +8499,33 @@ mod tests { (ScalarValue::Date64(Some(0)), ScalarValue::Date64(Some(1))), ( ScalarValue::Decimal128(Some(123), 5, 5), - ScalarValue::Decimal128(Some(120), 5, 5), + ScalarValue::Decimal128(Some(120), 5, 3), 
+ ), + ( + ScalarValue::Decimal128(Some(123), 5, 5), + ScalarValue::Decimal128(Some(120), 3, 5), + ), + ( + ScalarValue::Decimal256(Some(123.into()), 5, 5), + ScalarValue::Decimal256(Some(120.into()), 3, 5), + ), + // Distance 2 * 2^50 is larger than usize + ( + ScalarValue::Decimal256( + Some(i256::from_parts(0, 2_i64.pow(50).into())), + 1, + 0, + ), + ScalarValue::Decimal256( + Some(i256::from_parts(0, (-(2_i64).pow(50)).into())), + 1, + 0, + ), + ), + // Distance overflow + ( + ScalarValue::Decimal256(Some(i256::from_parts(0, i128::MAX)), 1, 0), + ScalarValue::Decimal256(Some(i256::from_parts(0, -i128::MAX)), 1, 0), ), ]; for (lhs, rhs) in cases { @@ -7196,6 +8803,19 @@ mod tests { "); } + #[test] + fn test_display_date64_large_values() { + assert_eq!( + format!("{}", ScalarValue::Date64(Some(790179464505))), + "1995-01-15" + ); + // This used to panic, see https://github.com/apache/arrow-rs/issues/7728 + assert_eq!( + format!("{}", ScalarValue::Date64(Some(-790179464505600000))), + "" + ); + } + #[test] fn test_struct_display_null() { let fields = vec![Field::new("a", DataType::Int32, false)]; @@ -7512,6 +9132,19 @@ mod tests { assert!(dense_scalar.is_null()); } + #[test] + fn cast_date_to_timestamp_overflow_returns_error() { + let scalar = ScalarValue::Date32(Some(i32::MAX)); + let err = scalar + .cast_to(&DataType::Timestamp(TimeUnit::Nanosecond, None)) + .expect_err("expected cast to fail"); + assert!( + err.to_string() + .contains("converted value exceeds the representable i64 range"), + "unexpected error: {err}" + ); + } + #[test] fn null_dictionary_scalar_produces_null_dictionary_array() { let dictionary_scalar = ScalarValue::Dictionary( @@ -7584,7 +9217,7 @@ mod tests { ])), true, )); - let scalars = vec![ + let scalars = [ ScalarValue::try_new_null(&DataType::List(Arc::clone(&field_ref))).unwrap(), ScalarValue::try_new_null(&DataType::LargeList(Arc::clone(&field_ref))) .unwrap(), @@ -7599,11 +9232,654 @@ mod tests { .unwrap(), 
ScalarValue::try_new_null(&DataType::Map(map_field_ref, false)).unwrap(), ScalarValue::try_new_null(&DataType::Union( - UnionFields::new(vec![42], vec![field_ref]), + UnionFields::try_new(vec![42], vec![field_ref]).unwrap(), UnionMode::Dense, )) .unwrap(), ]; assert!(scalars.iter().all(|s| s.is_null())); } + + // `err.to_string()` depends on backtrace being present (may have backtrace appended) + // `err.strip_backtrace()` also depends on backtrace being present (may have "This was likely caused by ..." stripped) + fn assert_starts_with(actual: impl AsRef, expected_prefix: impl AsRef) { + let actual = actual.as_ref(); + let expected_prefix = expected_prefix.as_ref(); + assert!( + actual.starts_with(expected_prefix), + "Expected '{actual}' to start with '{expected_prefix}'" + ); + } + + #[test] + fn test_new_default() { + // Test numeric types + assert_eq!( + ScalarValue::new_default(&DataType::Int32).unwrap(), + ScalarValue::Int32(Some(0)) + ); + assert_eq!( + ScalarValue::new_default(&DataType::Float64).unwrap(), + ScalarValue::Float64(Some(0.0)) + ); + assert_eq!( + ScalarValue::new_default(&DataType::Boolean).unwrap(), + ScalarValue::Boolean(Some(false)) + ); + + // Test string types + assert_eq!( + ScalarValue::new_default(&DataType::Utf8).unwrap(), + ScalarValue::Utf8(Some("".to_string())) + ); + assert_eq!( + ScalarValue::new_default(&DataType::LargeUtf8).unwrap(), + ScalarValue::LargeUtf8(Some("".to_string())) + ); + + // Test binary types + assert_eq!( + ScalarValue::new_default(&DataType::Binary).unwrap(), + ScalarValue::Binary(Some(vec![])) + ); + + // Test fixed size binary + assert_eq!( + ScalarValue::new_default(&DataType::FixedSizeBinary(5)).unwrap(), + ScalarValue::FixedSizeBinary(5, Some(vec![0, 0, 0, 0, 0])) + ); + + // Test temporal types + assert_eq!( + ScalarValue::new_default(&DataType::Date32).unwrap(), + ScalarValue::Date32(Some(0)) + ); + assert_eq!( + ScalarValue::new_default(&DataType::Time32(TimeUnit::Second)).unwrap(), + 
ScalarValue::Time32Second(Some(0)) + ); + + // Test decimal types + assert_eq!( + ScalarValue::new_default(&DataType::Decimal128(10, 2)).unwrap(), + ScalarValue::Decimal128(Some(0), 10, 2) + ); + + // Test list type + let list_field = Field::new_list_field(DataType::Int32, true); + let list_result = + ScalarValue::new_default(&DataType::List(Arc::new(list_field.clone()))) + .unwrap(); + match list_result { + ScalarValue::List(arr) => { + assert_eq!(arr.len(), 1); + assert_eq!(arr.value_length(0), 0); // empty list + } + _ => panic!("Expected List"), + } + + // Test struct type + let struct_fields = Fields::from(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Utf8, false), + ]); + let struct_result = + ScalarValue::new_default(&DataType::Struct(struct_fields.clone())).unwrap(); + match struct_result { + ScalarValue::Struct(arr) => { + assert_eq!(arr.len(), 1); + assert_eq!(arr.column(0).as_primitive::().value(0), 0); + assert_eq!(arr.column(1).as_string::().value(0), ""); + } + _ => panic!("Expected Struct"), + } + + // Test union type + let union_fields = UnionFields::try_new( + vec![0, 1], + vec![ + Field::new("i32", DataType::Int32, false), + Field::new("f64", DataType::Float64, false), + ], + ) + .unwrap(); + let union_result = ScalarValue::new_default(&DataType::Union( + union_fields.clone(), + UnionMode::Sparse, + )) + .unwrap(); + match union_result { + ScalarValue::Union(Some((type_id, value)), _, _) => { + assert_eq!(type_id, 0); + assert_eq!(*value, ScalarValue::Int32(Some(0))); + } + _ => panic!("Expected Union"), + } + } + + #[test] + fn test_scalar_min() { + // Test integer types + assert_eq!( + ScalarValue::min(&DataType::Int8), + Some(ScalarValue::Int8(Some(i8::MIN))) + ); + assert_eq!( + ScalarValue::min(&DataType::Int32), + Some(ScalarValue::Int32(Some(i32::MIN))) + ); + assert_eq!( + ScalarValue::min(&DataType::UInt8), + Some(ScalarValue::UInt8(Some(0))) + ); + assert_eq!( + ScalarValue::min(&DataType::UInt64), + 
Some(ScalarValue::UInt64(Some(0))) + ); + + // Test float types + assert_eq!( + ScalarValue::min(&DataType::Float32), + Some(ScalarValue::Float32(Some(f32::NEG_INFINITY))) + ); + assert_eq!( + ScalarValue::min(&DataType::Float64), + Some(ScalarValue::Float64(Some(f64::NEG_INFINITY))) + ); + + // Test decimal types + let decimal_min = ScalarValue::min(&DataType::Decimal128(5, 2)).unwrap(); + match decimal_min { + ScalarValue::Decimal128(Some(val), 5, 2) => { + assert_eq!(val, -99999); // -999.99 with scale 2 + } + _ => panic!("Expected Decimal128"), + } + + // Test temporal types + assert_eq!( + ScalarValue::min(&DataType::Date32), + Some(ScalarValue::Date32(Some(i32::MIN))) + ); + assert_eq!( + ScalarValue::min(&DataType::Time32(TimeUnit::Second)), + Some(ScalarValue::Time32Second(Some(0))) + ); + assert_eq!( + ScalarValue::min(&DataType::Timestamp(TimeUnit::Nanosecond, None)), + Some(ScalarValue::TimestampNanosecond(Some(i64::MIN), None)) + ); + + // Test duration types + assert_eq!( + ScalarValue::min(&DataType::Duration(TimeUnit::Second)), + Some(ScalarValue::DurationSecond(Some(i64::MIN))) + ); + + // Test unsupported types + assert_eq!(ScalarValue::min(&DataType::Utf8), None); + assert_eq!(ScalarValue::min(&DataType::Binary), None); + assert_eq!( + ScalarValue::min(&DataType::List(Arc::new(Field::new( + "item", + DataType::Int32, + true + )))), + None + ); + } + + #[test] + fn test_scalar_max() { + // Test integer types + assert_eq!( + ScalarValue::max(&DataType::Int8), + Some(ScalarValue::Int8(Some(i8::MAX))) + ); + assert_eq!( + ScalarValue::max(&DataType::Int32), + Some(ScalarValue::Int32(Some(i32::MAX))) + ); + assert_eq!( + ScalarValue::max(&DataType::UInt8), + Some(ScalarValue::UInt8(Some(u8::MAX))) + ); + assert_eq!( + ScalarValue::max(&DataType::UInt64), + Some(ScalarValue::UInt64(Some(u64::MAX))) + ); + + // Test float types + assert_eq!( + ScalarValue::max(&DataType::Float32), + Some(ScalarValue::Float32(Some(f32::INFINITY))) + ); + assert_eq!( + 
ScalarValue::max(&DataType::Float64), + Some(ScalarValue::Float64(Some(f64::INFINITY))) + ); + + // Test decimal types + let decimal_max = ScalarValue::max(&DataType::Decimal128(5, 2)).unwrap(); + match decimal_max { + ScalarValue::Decimal128(Some(val), 5, 2) => { + assert_eq!(val, 99999); // 999.99 with scale 2 + } + _ => panic!("Expected Decimal128"), + } + + // Test temporal types + assert_eq!( + ScalarValue::max(&DataType::Date32), + Some(ScalarValue::Date32(Some(i32::MAX))) + ); + assert_eq!( + ScalarValue::max(&DataType::Time32(TimeUnit::Second)), + Some(ScalarValue::Time32Second(Some(86_399))) // 23:59:59 + ); + assert_eq!( + ScalarValue::max(&DataType::Time64(TimeUnit::Microsecond)), + Some(ScalarValue::Time64Microsecond(Some(86_399_999_999))) // 23:59:59.999999 + ); + assert_eq!( + ScalarValue::max(&DataType::Timestamp(TimeUnit::Nanosecond, None)), + Some(ScalarValue::TimestampNanosecond(Some(i64::MAX), None)) + ); + + // Test duration types + assert_eq!( + ScalarValue::max(&DataType::Duration(TimeUnit::Millisecond)), + Some(ScalarValue::DurationMillisecond(Some(i64::MAX))) + ); + + // Test unsupported types + assert_eq!(ScalarValue::max(&DataType::Utf8), None); + assert_eq!(ScalarValue::max(&DataType::Binary), None); + assert_eq!( + ScalarValue::max(&DataType::Struct(Fields::from(vec![Field::new( + "field", + DataType::Int32, + true + )]))), + None + ); + } + + #[test] + fn test_min_max_float16() { + // Test Float16 min and max + let min_f16 = ScalarValue::min(&DataType::Float16).unwrap(); + match min_f16 { + ScalarValue::Float16(Some(val)) => { + assert_eq!(val, f16::NEG_INFINITY); + } + _ => panic!("Expected Float16"), + } + + let max_f16 = ScalarValue::max(&DataType::Float16).unwrap(); + match max_f16 { + ScalarValue::Float16(Some(val)) => { + assert_eq!(val, f16::INFINITY); + } + _ => panic!("Expected Float16"), + } + } + + #[test] + fn test_new_default_interval() { + // Test all interval types + assert_eq!( + 
ScalarValue::new_default(&DataType::Interval(IntervalUnit::YearMonth)) + .unwrap(), + ScalarValue::IntervalYearMonth(Some(0)) + ); + assert_eq!( + ScalarValue::new_default(&DataType::Interval(IntervalUnit::DayTime)).unwrap(), + ScalarValue::IntervalDayTime(Some(IntervalDayTime::ZERO)) + ); + assert_eq!( + ScalarValue::new_default(&DataType::Interval(IntervalUnit::MonthDayNano)) + .unwrap(), + ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNano::ZERO)) + ); + } + + #[test] + fn test_min_max_with_timezone() { + let tz = Some(Arc::from("UTC")); + + // Test timestamp with timezone + let min_ts = + ScalarValue::min(&DataType::Timestamp(TimeUnit::Second, tz.clone())).unwrap(); + match min_ts { + ScalarValue::TimestampSecond(Some(val), Some(tz_str)) => { + assert_eq!(val, i64::MIN); + assert_eq!(tz_str.as_ref(), "UTC"); + } + _ => panic!("Expected TimestampSecond with timezone"), + } + + let max_ts = + ScalarValue::max(&DataType::Timestamp(TimeUnit::Millisecond, tz.clone())) + .unwrap(); + match max_ts { + ScalarValue::TimestampMillisecond(Some(val), Some(tz_str)) => { + assert_eq!(val, i64::MAX); + assert_eq!(tz_str.as_ref(), "UTC"); + } + _ => panic!("Expected TimestampMillisecond with timezone"), + } + } + + #[test] + fn test_views_minimize_memory() { + let value = "this string is longer than 12 bytes".to_string(); + + let scalar = ScalarValue::Utf8View(Some(value.clone())); + let array = scalar.to_array_of_size(10).unwrap(); + let array = array.as_string_view(); + let buffers = array.data_buffers(); + assert_eq!(1, buffers.len()); + // Ensure we only have a single copy of the value string + assert_eq!(value.len(), buffers[0].len()); + + // Same but for BinaryView + let scalar = ScalarValue::BinaryView(Some(value.bytes().collect())); + let array = scalar.to_array_of_size(10).unwrap(); + let array = array.as_binary_view(); + let buffers = array.data_buffers(); + assert_eq!(1, buffers.len()); + assert_eq!(value.len(), buffers[0].len()); + } + + #[test] + fn 
test_to_array_of_size_run_end_encoded() { + fn run_test() { + let value = Box::new(ScalarValue::Float32(Some(1.0))); + let size = 5; + let scalar = ScalarValue::RunEndEncoded( + Field::new("run_ends", R::DATA_TYPE, false).into(), + Field::new("values", DataType::Float32, true).into(), + value.clone(), + ); + let array = scalar.to_array_of_size(size).unwrap(); + let array = array.as_run::(); + let array = array.downcast::().unwrap(); + assert_eq!(vec![Some(1.0); size], array.into_iter().collect::>()); + assert_eq!(1, array.values().len()); + } + + run_test::(); + run_test::(); + run_test::(); + + let scalar = ScalarValue::RunEndEncoded( + Field::new("run_ends", DataType::Int16, false).into(), + Field::new("values", DataType::Float32, true).into(), + Box::new(ScalarValue::Float32(Some(1.0))), + ); + let err = scalar.to_array_of_size(i16::MAX as usize + 10).unwrap_err(); + assert_eq!( + "Execution error: Cannot construct RunArray of size 32777: Overflows run-ends type Int16", + err.to_string() + ) + } + + #[test] + fn test_eq_array_run_end_encoded() { + let run_ends = Int16Array::from(vec![1, 3]); + let values = Float32Array::from(vec![None, Some(1.0)]); + let run_array = + Arc::new(RunArray::try_new(&run_ends, &values).unwrap()) as ArrayRef; + + let scalar = ScalarValue::RunEndEncoded( + Field::new("run_ends", DataType::Int16, false).into(), + Field::new("values", DataType::Float32, true).into(), + Box::new(ScalarValue::Float32(None)), + ); + assert!(scalar.eq_array(&run_array, 0).unwrap()); + + let scalar = ScalarValue::RunEndEncoded( + Field::new("run_ends", DataType::Int16, false).into(), + Field::new("values", DataType::Float32, true).into(), + Box::new(ScalarValue::Float32(Some(1.0))), + ); + assert!(scalar.eq_array(&run_array, 1).unwrap()); + assert!(scalar.eq_array(&run_array, 2).unwrap()); + + // value types must match + let scalar = ScalarValue::RunEndEncoded( + Field::new("run_ends", DataType::Int16, false).into(), + Field::new("values", DataType::Float64, 
true).into(), + Box::new(ScalarValue::Float64(Some(1.0))), + ); + let err = scalar.eq_array(&run_array, 1).unwrap_err(); + let expected = "Internal error: could not cast array of type Float32 to arrow_array::array::primitive_array::PrimitiveArray"; + assert!(err.to_string().starts_with(expected)); + + // run ends type must match + let scalar = ScalarValue::RunEndEncoded( + Field::new("run_ends", DataType::Int32, false).into(), + Field::new("values", DataType::Float32, true).into(), + Box::new(ScalarValue::Float32(None)), + ); + let err = scalar.eq_array(&run_array, 0).unwrap_err(); + let expected = "Internal error: could not cast array of type RunEndEncoded(\"run_ends\": non-null Int16, \"values\": Float32) to arrow_array::array::run_array::RunArray"; + assert!(err.to_string().starts_with(expected)); + } + + #[test] + fn test_iter_to_array_run_end_encoded() { + let run_ends_field = Arc::new(Field::new("run_ends", DataType::Int16, false)); + let values_field = Arc::new(Field::new("values", DataType::Int64, true)); + let scalars = vec![ + ScalarValue::RunEndEncoded( + Arc::clone(&run_ends_field), + Arc::clone(&values_field), + Box::new(ScalarValue::Int64(Some(1))), + ), + ScalarValue::RunEndEncoded( + Arc::clone(&run_ends_field), + Arc::clone(&values_field), + Box::new(ScalarValue::Int64(Some(1))), + ), + ScalarValue::RunEndEncoded( + Arc::clone(&run_ends_field), + Arc::clone(&values_field), + Box::new(ScalarValue::Int64(None)), + ), + ScalarValue::RunEndEncoded( + Arc::clone(&run_ends_field), + Arc::clone(&values_field), + Box::new(ScalarValue::Int64(Some(2))), + ), + ScalarValue::RunEndEncoded( + Arc::clone(&run_ends_field), + Arc::clone(&values_field), + Box::new(ScalarValue::Int64(Some(2))), + ), + ScalarValue::RunEndEncoded( + Arc::clone(&run_ends_field), + Arc::clone(&values_field), + Box::new(ScalarValue::Int64(Some(2))), + ), + ]; + + let run_array = ScalarValue::iter_to_array(scalars).unwrap(); + let expected = RunArray::try_new( + &Int16Array::from(vec![2, 
3, 6]), + &Int64Array::from(vec![Some(1), None, Some(2)]), + ) + .unwrap(); + assert_eq!(&expected as &dyn Array, run_array.as_ref()); + + // inconsistent run-ends type + let scalars = vec![ + ScalarValue::RunEndEncoded( + Arc::clone(&run_ends_field), + Arc::clone(&values_field), + Box::new(ScalarValue::Int64(Some(1))), + ), + ScalarValue::RunEndEncoded( + Field::new("run_ends", DataType::Int32, false).into(), + Arc::clone(&values_field), + Box::new(ScalarValue::Int64(Some(1))), + ), + ]; + let err = ScalarValue::iter_to_array(scalars).unwrap_err(); + let expected = "Execution error: Expected RunEndEncoded scalar with run-ends field Field { \"run_ends\": Int16 } but got: RunEndEncoded(Field { name: \"run_ends\", data_type: Int32 }, Field { name: \"values\", data_type: Int64, nullable: true }, Int64(1))"; + assert!(err.to_string().starts_with(expected)); + + // inconsistent value type + let scalars = vec![ + ScalarValue::RunEndEncoded( + Arc::clone(&run_ends_field), + Arc::clone(&values_field), + Box::new(ScalarValue::Int64(Some(1))), + ), + ScalarValue::RunEndEncoded( + Arc::clone(&run_ends_field), + Field::new("values", DataType::Int32, true).into(), + Box::new(ScalarValue::Int32(Some(1))), + ), + ]; + let err = ScalarValue::iter_to_array(scalars).unwrap_err(); + let expected = "Execution error: Expected RunEndEncoded scalar with run-ends field Field { \"run_ends\": Int16 } but got: RunEndEncoded(Field { name: \"run_ends\", data_type: Int16 }, Field { name: \"values\", data_type: Int32, nullable: true }, Int32(1))"; + assert!(err.to_string().starts_with(expected)); + + // inconsistent scalars type + let scalars = vec![ + ScalarValue::RunEndEncoded( + Arc::clone(&run_ends_field), + Arc::clone(&values_field), + Box::new(ScalarValue::Int64(Some(1))), + ), + ScalarValue::Int64(Some(1)), + ]; + let err = ScalarValue::iter_to_array(scalars).unwrap_err(); + let expected = "Execution error: Expected RunEndEncoded scalar with run-ends field Field { \"run_ends\": Int16 } 
but got: Int64(1)"; + assert!(err.to_string().starts_with(expected)); + } + + #[test] + fn test_convert_array_to_scalar_vec() { + // 1: Regular ListArray + let list = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2)]), + None, + Some(vec![Some(3), None, Some(4)]), + ]); + let converted = ScalarValue::convert_array_to_scalar_vec(&list).unwrap(); + assert_eq!( + converted, + vec![ + Some(vec![ + ScalarValue::Int64(Some(1)), + ScalarValue::Int64(Some(2)) + ]), + None, + Some(vec![ + ScalarValue::Int64(Some(3)), + ScalarValue::Int64(None), + ScalarValue::Int64(Some(4)) + ]), + ] + ); + + // 2: Regular LargeListArray + let large_list = LargeListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2)]), + None, + Some(vec![Some(3), None, Some(4)]), + ]); + let converted = ScalarValue::convert_array_to_scalar_vec(&large_list).unwrap(); + assert_eq!( + converted, + vec![ + Some(vec![ + ScalarValue::Int64(Some(1)), + ScalarValue::Int64(Some(2)) + ]), + None, + Some(vec![ + ScalarValue::Int64(Some(3)), + ScalarValue::Int64(None), + ScalarValue::Int64(Some(4)) + ]), + ] + ); + + // 3: Funky (null slot has non-zero list offsets) + // Offsets + Values looks like this: [[1, 2], [3, 4], [5]] + // But with NullBuffer it's like this: [[1, 2], NULL, [5]] + let funky = ListArray::new( + Field::new_list_field(DataType::Int64, true).into(), + OffsetBuffer::new(vec![0, 2, 4, 5].into()), + Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5, 6])), + Some(NullBuffer::from(vec![true, false, true])), + ); + let converted = ScalarValue::convert_array_to_scalar_vec(&funky).unwrap(); + assert_eq!( + converted, + vec![ + Some(vec![ + ScalarValue::Int64(Some(1)), + ScalarValue::Int64(Some(2)) + ]), + None, + Some(vec![ScalarValue::Int64(Some(5))]), + ] + ); + + // 4: Offsets + Values looks like this: [[1, 2], [], [3, 4, 5]] + // But with NullBuffer it's like this: [[1, 2], NULL, [3, 4, 5]] + // The converted result is: [[1, 2], None, [3, 4, 5]] + let array4 = ListArray::new( +
Field::new_list_field(DataType::Int64, true).into(), + OffsetBuffer::new(vec![0, 2, 2, 5].into()), + Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5, 6])), + Some(NullBuffer::from(vec![true, false, true])), + ); + let converted = ScalarValue::convert_array_to_scalar_vec(&array4).unwrap(); + assert_eq!( + converted, + vec![ + Some(vec![ + ScalarValue::Int64(Some(1)), + ScalarValue::Int64(Some(2)) + ]), + None, + Some(vec![ + ScalarValue::Int64(Some(3)), + ScalarValue::Int64(Some(4)), + ScalarValue::Int64(Some(5)), + ]), + ] + ); + + // 5: Offsets + Values looks like this: [[1, 2], [], [3, 4, 5]] + // Same as 4, but the middle array is not null, so after conversion it's empty. + let array5 = ListArray::new( + Field::new_list_field(DataType::Int64, true).into(), + OffsetBuffer::new(vec![0, 2, 2, 5].into()), + Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5, 6])), + Some(NullBuffer::from(vec![true, true, true])), + ); + let converted = ScalarValue::convert_array_to_scalar_vec(&array5).unwrap(); + assert_eq!( + converted, + vec![ + Some(vec![ + ScalarValue::Int64(Some(1)), + ScalarValue::Int64(Some(2)) + ]), + Some(vec![]), + Some(vec![ + ScalarValue::Int64(Some(3)), + ScalarValue::Int64(Some(4)), + ScalarValue::Int64(Some(5)), + ]), + ] + ); + } } diff --git a/datafusion/common/src/scalar/struct_builder.rs b/datafusion/common/src/scalar/struct_builder.rs index fd19dccf89636..045b5778243df 100644 --- a/datafusion/common/src/scalar/struct_builder.rs +++ b/datafusion/common/src/scalar/struct_builder.rs @@ -47,13 +47,11 @@ impl ScalarStructBuilder { /// ```rust /// # use arrow::datatypes::{DataType, Field}; /// # use datafusion_common::scalar::ScalarStructBuilder; - /// let fields = vec![ - /// Field::new("a", DataType::Int32, false), - /// ]; + /// let fields = vec![Field::new("a", DataType::Int32, false)]; /// let sv = ScalarStructBuilder::new_null(fields); /// // Note this is `NULL`, not `{a: NULL}` /// assert_eq!(format!("{sv}"), "NULL"); - ///``` + /// ``` /// /// To create a
struct where the *fields* are null, use `Self::new()` and /// pass null values for each field: @@ -65,9 +63,9 @@ impl ScalarStructBuilder { /// let field = Field::new("a", DataType::Int32, true); /// // add a null value for the "a" field /// let sv = ScalarStructBuilder::new() - /// .with_scalar(field, ScalarValue::Int32(None)) - /// .build() - /// .unwrap(); + /// .with_scalar(field, ScalarValue::Int32(None)) + /// .build() + /// .unwrap(); /// // value is not null, but field is /// assert_eq!(format!("{sv}"), "{a:}"); /// ``` @@ -85,6 +83,7 @@ impl ScalarStructBuilder { } /// Add the specified field and `ScalarValue` to the struct. + #[expect(clippy::needless_pass_by_value)] // Skip for public API's compatibility pub fn with_scalar(self, field: impl IntoFieldRef, value: ScalarValue) -> Self { // valid scalar value should not fail let array = value.to_array().unwrap(); diff --git a/datafusion/common/src/spans.rs b/datafusion/common/src/spans.rs index 5111e264123ce..c0b52977e14a9 100644 --- a/datafusion/common/src/spans.rs +++ b/datafusion/common/src/spans.rs @@ -39,6 +39,7 @@ impl fmt::Debug for Location { } } +#[cfg(feature = "sql")] impl From for Location { fn from(value: sqlparser::tokenizer::Location) -> Self { Self { @@ -70,6 +71,7 @@ impl Span { /// Convert a [`Span`](sqlparser::tokenizer::Span) from the parser, into a /// DataFusion [`Span`]. If the input span is empty (line 0 column 0, to /// line 0 column 0), then [`None`] is returned. 
+ #[cfg(feature = "sql")] pub fn try_from_sqlparser_span(span: sqlparser::tokenizer::Span) -> Option { if span == sqlparser::tokenizer::Span::empty() { None diff --git a/datafusion/common/src/stats.rs b/datafusion/common/src/stats.rs index a6d132ef51f6a..f263c905faf6b 100644 --- a/datafusion/common/src/stats.rs +++ b/datafusion/common/src/stats.rs @@ -22,17 +22,40 @@ use std::fmt::{self, Debug, Display}; use crate::{Result, ScalarValue}; use crate::error::_plan_err; -use arrow::datatypes::{DataType, Schema, SchemaRef}; +use crate::utils::aggregate::precision_add; +use arrow::datatypes::{DataType, Schema}; /// Represents a value with a degree of certainty. `Precision` is used to /// propagate information the precision of statistical values. #[derive(Clone, PartialEq, Eq, Default, Copy)] pub enum Precision { - /// The exact value is known + /// The exact value is known. Used for guaranteeing correctness. + /// + /// Comes from definitive sources such as: + /// - Parquet file metadata (row counts, byte sizes) + /// - In-memory RecordBatch data (actual row counts, byte sizes, null counts) + /// - and more... Exact(T), - /// The value is not known exactly, but is likely close to this value + /// The value is not known exactly, but is likely close to this value. + /// Used for cost-based optimizations. + /// + /// Some operations that would result in `Inexact(T)` would be: + /// - Applying a filter (selectivity is unknown) + /// - Mixing exact and inexact values in arithmetic + /// - and more... Inexact(T), - /// Nothing is known about the value + /// Nothing is known about the value. This is the default state. + /// + /// Acts as an absorbing element in arithmetic -> any operation + /// involving `Absent` yields `Absent`. [`Precision::to_inexact`] + /// on `Absent` returns `Absent`, not `Inexact` — it represents + /// a fundamentally different state. 
+ /// + /// Common sources include: + /// - Data sources without statistics + /// - Parquet columns missing from file metadata + /// - Statistics that cannot be derived for an operation (e.g., + /// `distinct_count` after a union, `total_byte_size` for joins) #[default] Absent, } @@ -120,10 +143,15 @@ impl Precision { /// values is [`Precision::Absent`], the result is `Absent` too. pub fn add(&self, other: &Precision) -> Precision { match (self, other) { - (Precision::Exact(a), Precision::Exact(b)) => Precision::Exact(a + b), + (Precision::Exact(a), Precision::Exact(b)) => a.checked_add(*b).map_or_else( + || Precision::Inexact(a.saturating_add(*b)), + Precision::Exact, + ), (Precision::Inexact(a), Precision::Exact(b)) | (Precision::Exact(a), Precision::Inexact(b)) - | (Precision::Inexact(a), Precision::Inexact(b)) => Precision::Inexact(a + b), + | (Precision::Inexact(a), Precision::Inexact(b)) => { + Precision::Inexact(a.saturating_add(*b)) + } (_, _) => Precision::Absent, } } @@ -133,10 +161,15 @@ impl Precision { /// values is [`Precision::Absent`], the result is `Absent` too. pub fn sub(&self, other: &Precision) -> Precision { match (self, other) { - (Precision::Exact(a), Precision::Exact(b)) => Precision::Exact(a - b), + (Precision::Exact(a), Precision::Exact(b)) => a.checked_sub(*b).map_or_else( + || Precision::Inexact(a.saturating_sub(*b)), + Precision::Exact, + ), (Precision::Inexact(a), Precision::Exact(b)) | (Precision::Exact(a), Precision::Inexact(b)) - | (Precision::Inexact(a), Precision::Inexact(b)) => Precision::Inexact(a - b), + | (Precision::Inexact(a), Precision::Inexact(b)) => { + Precision::Inexact(a.saturating_sub(*b)) + } (_, _) => Precision::Absent, } } @@ -146,10 +179,15 @@ impl Precision { /// values is [`Precision::Absent`], the result is `Absent` too. 
pub fn multiply(&self, other: &Precision) -> Precision { match (self, other) { - (Precision::Exact(a), Precision::Exact(b)) => Precision::Exact(a * b), + (Precision::Exact(a), Precision::Exact(b)) => a.checked_mul(*b).map_or_else( + || Precision::Inexact(a.saturating_mul(*b)), + Precision::Exact, + ), (Precision::Inexact(a), Precision::Exact(b)) | (Precision::Exact(a), Precision::Inexact(b)) - | (Precision::Inexact(a), Precision::Inexact(b)) => Precision::Inexact(a * b), + | (Precision::Inexact(a), Precision::Inexact(b)) => { + Precision::Inexact(a.saturating_mul(*b)) + } (_, _) => Precision::Absent, } } @@ -168,15 +206,22 @@ impl Precision { /// Calculates the sum of two (possibly inexact) [`ScalarValue`] values, /// conservatively propagating exactness information. If one of the input /// values is [`Precision::Absent`], the result is `Absent` too. + /// + /// Uses [`ScalarValue::add_checked`] so that integer overflow returns + /// an error (mapped to `Absent`) instead of silently wrapping. + /// + /// For performance-sensitive paths prefer `precision_add` which + /// avoids the Arrow array round-trip. pub fn add(&self, other: &Precision) -> Precision { match (self, other) { - (Precision::Exact(a), Precision::Exact(b)) => { - a.add(b).map(Precision::Exact).unwrap_or(Precision::Absent) - } + (Precision::Exact(a), Precision::Exact(b)) => a + .add_checked(b) + .map(Precision::Exact) + .unwrap_or(Precision::Absent), (Precision::Inexact(a), Precision::Exact(b)) | (Precision::Exact(a), Precision::Inexact(b)) | (Precision::Inexact(a), Precision::Inexact(b)) => a - .add(b) + .add_checked(b) .map(Precision::Inexact) .unwrap_or(Precision::Absent), (_, _) => Precision::Absent, @@ -268,9 +313,14 @@ impl From> for Precision { /// and the transformations output are not always predictable. #[derive(Debug, Clone, PartialEq, Eq)] pub struct Statistics { - /// The number of table rows. + /// The number of rows estimated to be scanned. 
pub num_rows: Precision, - /// Total bytes of the table rows. + /// The total bytes of the output data. + /// + /// Note that this is not the same as the total bytes that may be scanned, + /// processed, etc. + /// E.g. we may read 1GB of data from a Parquet file but the Arrow data + /// the node produces may be 2GB; it's this 2GB that is tracked here. pub total_byte_size: Precision, /// Statistics on a column level. /// @@ -302,6 +352,31 @@ impl Statistics { } } + /// Calculates `total_byte_size` based on the schema and `num_rows`. + /// If any of the columns has non-primitive width, `total_byte_size` is set to inexact. + pub fn calculate_total_byte_size(&mut self, schema: &Schema) { + let mut row_size = Some(0); + for field in schema.fields() { + match field.data_type().primitive_width() { + Some(width) => { + row_size = row_size.map(|s| s + width); + } + None => { + row_size = None; + break; + } + } + } + match row_size { + None => { + self.total_byte_size = self.total_byte_size.to_inexact(); + } + Some(size) => { + self.total_byte_size = self.num_rows.multiply(&Precision::Exact(size)); + } + } + } + /// Returns an unbounded `ColumnStatistics` for each field in the schema. pub fn unknown_column(schema: &Schema) -> Vec { schema @@ -347,12 +422,17 @@ impl Statistics { /// For example, if we had statistics for columns `{"a", "b", "c"}`, /// projecting to `vec![2, 1]` would return statistics for columns `{"c", /// "b"}`. 
- pub fn project(mut self, projection: Option<&Vec>) -> Self { - let Some(projection) = projection else { + pub fn project(self, projection: Option<&impl AsRef<[usize]>>) -> Self { + let projection = projection.map(AsRef::as_ref); + self.project_impl(projection) + } + + fn project_impl(mut self, projection: Option<&[usize]>) -> Self { + let Some(projection) = projection.map(AsRef::as_ref) else { return self; }; - #[allow(clippy::large_enum_variant)] + #[expect(clippy::large_enum_variant)] enum Slot { /// The column is taken and put into the specified statistics location Taken(usize), @@ -366,7 +446,7 @@ impl Statistics { .map(Slot::Present) .collect(); - for idx in projection { + for idx in projection.iter() { let next_idx = self.column_statistics.len(); let slot = std::mem::replace( columns.get_mut(*idx).expect("projection out of bounds"), @@ -391,13 +471,15 @@ impl Statistics { /// parameter to compute global statistics in a multi-partition setting. pub fn with_fetch( mut self, - schema: SchemaRef, fetch: Option, skip: usize, n_partitions: usize, ) -> Result { let fetch_val = fetch.unwrap_or(usize::MAX); + // Get the ratio of rows after / rows before on a per-partition basis + let num_rows_before = self.num_rows; + self.num_rows = match self { Statistics { num_rows: Precision::Exact(nr), @@ -431,8 +513,7 @@ impl Statistics { // At this point we know that we were given a `fetch` value // as the `None` case would go into the branch above. Since // the input has more rows than `fetch + skip`, the number - // of rows will be the `fetch`, but we won't be able to - // predict the other statistics. + // of rows will be the `fetch`, other statistics will have to be downgraded to inexact. check_num_rows( fetch_val.checked_mul(n_partitions), // We know that we have an estimate for the number of rows: @@ -445,8 +526,55 @@ impl Statistics { .. 
} => check_num_rows(fetch.and_then(|v| v.checked_mul(n_partitions)), false), }; - self.column_statistics = Statistics::unknown_column(&schema); - self.total_byte_size = Precision::Absent; + let ratio: f64 = match (num_rows_before, self.num_rows) { + ( + Precision::Exact(nr_before) | Precision::Inexact(nr_before), + Precision::Exact(nr_after) | Precision::Inexact(nr_after), + ) => { + if nr_before == 0 { + 0.0 + } else { + nr_after as f64 / nr_before as f64 + } + } + _ => 0.0, + }; + self.column_statistics = self + .column_statistics + .into_iter() + .map(|cs| { + let mut cs = cs.to_inexact(); + // Scale byte_size by the row ratio + cs.byte_size = match cs.byte_size { + Precision::Exact(n) | Precision::Inexact(n) => { + Precision::Inexact((n as f64 * ratio) as usize) + } + Precision::Absent => Precision::Absent, + }; + cs + }) + .collect(); + + // Compute total_byte_size as sum of column byte_size values if all are present, + // otherwise fall back to scaling the original total_byte_size + let sum_scan_bytes: Option = self + .column_statistics + .iter() + .map(|cs| cs.byte_size.get_value().copied()) + .try_fold(0usize, |acc, val| val.map(|v| acc + v)); + + self.total_byte_size = match sum_scan_bytes { + Some(sum) => Precision::Inexact(sum), + None => { + // Fall back to scaling original total_byte_size if not all columns have byte_size + match &self.total_byte_size { + Precision::Exact(n) | Precision::Inexact(n) => { + Precision::Inexact((*n as f64 * ratio) as usize) + } + Precision::Absent => Precision::Absent, + } + } + }; Ok(self) } @@ -456,23 +584,6 @@ impl Statistics { /// If not, maybe you can call `SchemaMapper::map_column_statistics` to make them consistent. /// /// Returns an error if the statistics do not match the specified schemas. 
- pub fn try_merge_iter<'a, I>(items: I, schema: &Schema) -> Result - where - I: IntoIterator, - { - let mut items = items.into_iter(); - - let Some(init) = items.next() else { - return Ok(Statistics::new_unknown(schema)); - }; - items.try_fold(init.clone(), |acc: Statistics, item_stats: &Statistics| { - acc.try_merge(item_stats) - }) - } - - /// Merge this Statistics value with another Statistics value. - /// - /// Returns an error if the statistics do not match (different schemas). /// /// # Example /// ``` @@ -480,64 +591,113 @@ impl Statistics { /// # use arrow::datatypes::{Field, Schema, DataType}; /// # use datafusion_common::stats::Precision; /// let stats1 = Statistics::default() - /// .with_num_rows(Precision::Exact(1)) - /// .with_total_byte_size(Precision::Exact(2)) - /// .add_column_statistics(ColumnStatistics::new_unknown() - /// .with_null_count(Precision::Exact(3)) - /// .with_min_value(Precision::Exact(ScalarValue::from(4))) - /// .with_max_value(Precision::Exact(ScalarValue::from(5))) - /// ); + /// .with_num_rows(Precision::Exact(10)) + /// .add_column_statistics( + /// ColumnStatistics::new_unknown() + /// .with_min_value(Precision::Exact(ScalarValue::from(1))) + /// .with_max_value(Precision::Exact(ScalarValue::from(100))) + /// .with_sum_value(Precision::Exact(ScalarValue::from(500))), + /// ); /// /// let stats2 = Statistics::default() - /// .with_num_rows(Precision::Exact(10)) - /// .with_total_byte_size(Precision::Inexact(20)) - /// .add_column_statistics(ColumnStatistics::new_unknown() - /// // absent null count - /// .with_min_value(Precision::Exact(ScalarValue::from(40))) - /// .with_max_value(Precision::Exact(ScalarValue::from(50))) - /// ); + /// .with_num_rows(Precision::Exact(20)) + /// .add_column_statistics( + /// ColumnStatistics::new_unknown() + /// .with_min_value(Precision::Exact(ScalarValue::from(5))) + /// .with_max_value(Precision::Exact(ScalarValue::from(200))) + /// 
.with_sum_value(Precision::Exact(ScalarValue::from(1000))), + /// ); /// - /// let merged_stats = stats1.try_merge(&stats2).unwrap(); - /// let expected_stats = Statistics::default() - /// .with_num_rows(Precision::Exact(11)) - /// .with_total_byte_size(Precision::Inexact(22)) // inexact in stats2 --> inexact - /// .add_column_statistics( - /// ColumnStatistics::new_unknown() - /// .with_null_count(Precision::Absent) // missing from stats2 --> absent - /// .with_min_value(Precision::Exact(ScalarValue::from(4))) - /// .with_max_value(Precision::Exact(ScalarValue::from(50))) - /// ); + /// let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + /// let merged = Statistics::try_merge_iter( + /// &[stats1, stats2], + /// &schema, + /// ).unwrap(); /// - /// assert_eq!(merged_stats, expected_stats) + /// assert_eq!(merged.num_rows, Precision::Exact(30)); + /// assert_eq!(merged.column_statistics[0].min_value, + /// Precision::Exact(ScalarValue::from(1))); + /// assert_eq!(merged.column_statistics[0].max_value, + /// Precision::Exact(ScalarValue::from(200))); + /// assert_eq!(merged.column_statistics[0].sum_value, + /// Precision::Exact(ScalarValue::from(1500))); /// ``` - pub fn try_merge(self, other: &Statistics) -> Result { - let Self { - mut num_rows, - mut total_byte_size, - mut column_statistics, - } = self; - - // Accumulate statistics for subsequent items - num_rows = num_rows.add(&other.num_rows); - total_byte_size = total_byte_size.add(&other.total_byte_size); - - if column_statistics.len() != other.column_statistics.len() { - return _plan_err!( - "Cannot merge statistics with different number of columns: {} vs {}", - column_statistics.len(), - other.column_statistics.len() - ); + pub fn try_merge_iter<'a, I>(items: I, schema: &Schema) -> Result + where + I: IntoIterator, + { + let items: Vec<&Statistics> = items.into_iter().collect(); + + if items.is_empty() { + return Ok(Statistics::new_unknown(schema)); + } + if items.len() == 1 { + return 
Ok(items[0].clone()); + } + + let num_cols = items[0].column_statistics.len(); + // Validate all items have the same number of columns + for (i, stat) in items.iter().enumerate().skip(1) { + if stat.column_statistics.len() != num_cols { + return _plan_err!( + "Cannot merge statistics with different number of columns: {} vs {} (item {})", + num_cols, + stat.column_statistics.len(), + i + ); + } } - for (item_col_stats, col_stats) in other + // Aggregate usize fields (cheap arithmetic) + let mut num_rows = Precision::Exact(0usize); + let mut total_byte_size = Precision::Exact(0usize); + for stat in &items { + num_rows = num_rows.add(&stat.num_rows); + total_byte_size = total_byte_size.add(&stat.total_byte_size); + } + + let first = items[0]; + let mut column_statistics: Vec = first .column_statistics .iter() - .zip(column_statistics.iter_mut()) - { - col_stats.null_count = col_stats.null_count.add(&item_col_stats.null_count); - col_stats.max_value = col_stats.max_value.max(&item_col_stats.max_value); - col_stats.min_value = col_stats.min_value.min(&item_col_stats.min_value); - col_stats.sum_value = col_stats.sum_value.add(&item_col_stats.sum_value); + .map(|cs| ColumnStatistics { + null_count: cs.null_count, + max_value: cs.max_value.clone(), + min_value: cs.min_value.clone(), + sum_value: cs.sum_value.clone(), + distinct_count: cs.distinct_count, + byte_size: cs.byte_size, + }) + .collect(); + + // Accumulate all statistics in a single pass. + // Uses precision_add for sum (avoids the expensive + // ScalarValue::add round-trip through Arrow arrays), and + // Precision::min/max which use cheap PartialOrd comparison. 
+ for stat in items.iter().skip(1) { + for (col_idx, col_stats) in column_statistics.iter_mut().enumerate() { + let item_cs = &stat.column_statistics[col_idx]; + + col_stats.null_count = col_stats.null_count.add(&item_cs.null_count); + + // NDV must be computed before min/max update (needs pre-merge ranges) + col_stats.distinct_count = match ( + col_stats.distinct_count.get_value(), + item_cs.distinct_count.get_value(), + ) { + (Some(&l), Some(&r)) => Precision::Inexact( + estimate_ndv_with_overlap(col_stats, item_cs, l, r) + .unwrap_or_else(|| usize::max(l, r)), + ), + _ => Precision::Absent, + }; + + col_stats.min_value = col_stats.min_value.min(&item_cs.min_value); + col_stats.max_value = col_stats.max_value.max(&item_cs.max_value); + col_stats.sum_value = + precision_add(&col_stats.sum_value, &item_cs.sum_value); + col_stats.byte_size = col_stats.byte_size.add(&item_cs.byte_size); + } } Ok(Statistics { @@ -548,6 +708,96 @@ impl Statistics { } } +/// Estimates the combined number of distinct values (NDV) when merging two +/// column statistics, using range overlap to avoid double-counting shared values. +/// +/// Assumes values are distributed uniformly within each input's +/// `[min, max]` range (the standard assumption when only summary +/// statistics are available). Under uniformity the fraction of an input's +/// distinct values that land in a sub-range equals the fraction of +/// the range that sub-range covers. +/// +/// The combined value space is split into three disjoint regions: +/// +/// ```text +/// |-- only A --|-- overlap --|-- only B --| +/// ``` +/// +/// * **Only in A/B** - values outside the other input's range +/// contribute `(1 - overlap_a) * NDV_a` and `(1 - overlap_b) * NDV_b`. +/// * **Overlap** - both inputs may produce values here. We take +/// `max(overlap_a * NDV_a, overlap_b * NDV_b)` rather than the +/// sum because values in the same sub-range are likely shared +/// (the smaller set is assumed to be a subset of the larger). 
+/// +/// The formula ranges between `[max(NDV_a, NDV_b), NDV_a + NDV_b]`, +/// from full overlap to no overlap. +/// +/// ```text +/// NDV = max(overlap_a * NDV_a, overlap_b * NDV_b) [intersection] +/// + (1 - overlap_a) * NDV_a [only in A] +/// + (1 - overlap_b) * NDV_b [only in B] +/// ``` +/// +/// Returns `None` when min/max are absent or distance is unsupported +/// (e.g. strings), in which case the caller should fall back to a simpler +/// estimate. +pub fn estimate_ndv_with_overlap( + left: &ColumnStatistics, + right: &ColumnStatistics, + ndv_left: usize, + ndv_right: usize, +) -> Option { + let left_min = left.min_value.get_value()?; + let left_max = left.max_value.get_value()?; + let right_min = right.min_value.get_value()?; + let right_max = right.max_value.get_value()?; + + let range_left = left_max.distance(left_min)?; + let range_right = right_max.distance(right_min)?; + + // Constant columns (range == 0) can't use the proportional overlap + // formula below, so check interval overlap directly instead. + if range_left == 0 || range_right == 0 { + let overlaps = left_min <= right_max && right_min <= left_max; + return Some(if overlaps { + usize::max(ndv_left, ndv_right) + } else { + ndv_left + ndv_right + }); + } + + let overlap_min = if left_min >= right_min { + left_min + } else { + right_min + }; + let overlap_max = if left_max <= right_max { + left_max + } else { + right_max + }; + + // Disjoint ranges: no overlap, NDVs are additive + if overlap_min > overlap_max { + return Some(ndv_left + ndv_right); + } + + let overlap_range = overlap_max.distance(overlap_min)? 
as f64; + + let overlap_left = overlap_range / range_left as f64; + let overlap_right = overlap_range / range_right as f64; + + let intersection = f64::max( + overlap_left * ndv_left as f64, + overlap_right * ndv_right as f64, + ); + let only_left = (1.0 - overlap_left) * ndv_left as f64; + let only_right = (1.0 - overlap_right) * ndv_right as f64; + + Some((intersection + only_left + only_right).round() as usize) +} + /// Creates an estimate of the number of rows in the output using the given /// optional value and exactness flag. fn check_num_rows(value: Option, is_exact: bool) -> Precision { @@ -599,6 +849,11 @@ impl Display for Statistics { } else { s }; + let s = if cs.byte_size != Precision::Absent { + format!("{} ScanBytes={}", s, cs.byte_size) + } else { + s + }; s + ")" }) @@ -628,6 +883,21 @@ pub struct ColumnStatistics { pub sum_value: Precision, /// Number of distinct values pub distinct_count: Precision, + /// Estimated size of this column's data in bytes for the output. + /// + /// Note that this is not the same as the total bytes that may be scanned, + /// processed, etc. + /// + /// E.g. we may read 1GB of data from a Parquet file but the Arrow data + /// the node produces may be 2GB; it's this 2GB that is tracked here. + /// + /// Currently this is accurately calculated for primitive types only. + /// For complex types (like Utf8, List, Struct, etc), this value may be + /// absent or inexact (e.g. estimated from the size of the data in the source Parquet files). + /// + /// This value is automatically scaled when operations like limits or + /// filters reduce the number of rows (see [`Statistics::with_fetch`]). 
+ pub byte_size: Precision, } impl ColumnStatistics { @@ -650,6 +920,7 @@ impl ColumnStatistics { min_value: Precision::Absent, sum_value: Precision::Absent, distinct_count: Precision::Absent, + byte_size: Precision::Absent, } } @@ -683,6 +954,13 @@ impl ColumnStatistics { self } + /// Set the scan byte size + /// This should initially be set to the total size of the column. + pub fn with_byte_size(mut self, byte_size: Precision) -> Self { + self.byte_size = byte_size; + self + } + /// If the exactness of a [`ColumnStatistics`] instance is lost, this /// function relaxes the exactness of all information by converting them /// [`Precision::Inexact`]. @@ -692,6 +970,7 @@ impl ColumnStatistics { self.min_value = self.min_value.to_inexact(); self.sum_value = self.sum_value.to_inexact(); self.distinct_count = self.distinct_count.to_inexact(); + self.byte_size = self.byte_size.to_inexact(); self } } @@ -781,11 +1060,21 @@ mod tests { let precision2 = Precision::Inexact(23); let precision3 = Precision::Exact(30); let absent_precision = Precision::Absent; + let precision_max_exact = Precision::Exact(usize::MAX); + let precision_max_inexact = Precision::Exact(usize::MAX); assert_eq!(precision1.add(&precision2), Precision::Inexact(65)); assert_eq!(precision1.add(&precision3), Precision::Exact(72)); assert_eq!(precision2.add(&precision3), Precision::Inexact(53)); assert_eq!(precision1.add(&absent_precision), Precision::Absent); + assert_eq!( + precision_max_exact.add(&precision1), + Precision::Inexact(usize::MAX) + ); + assert_eq!( + precision_max_inexact.add(&precision1), + Precision::Inexact(usize::MAX) + ); } #[test] @@ -817,6 +1106,8 @@ mod tests { assert_eq!(precision1.sub(&precision2), Precision::Inexact(19)); assert_eq!(precision1.sub(&precision3), Precision::Exact(12)); + assert_eq!(precision2.sub(&precision1), Precision::Inexact(0)); + assert_eq!(precision3.sub(&precision1), Precision::Inexact(0)); assert_eq!(precision1.sub(&absent_precision), Precision::Absent); } 
@@ -845,12 +1136,22 @@ mod tests { let precision1 = Precision::Exact(6); let precision2 = Precision::Inexact(3); let precision3 = Precision::Exact(5); + let precision_max_exact = Precision::Exact(usize::MAX); + let precision_max_inexact = Precision::Exact(usize::MAX); let absent_precision = Precision::Absent; assert_eq!(precision1.multiply(&precision2), Precision::Inexact(18)); assert_eq!(precision1.multiply(&precision3), Precision::Exact(30)); assert_eq!(precision2.multiply(&precision3), Precision::Inexact(15)); assert_eq!(precision1.multiply(&absent_precision), Precision::Absent); + assert_eq!( + precision_max_exact.multiply(&precision1), + Precision::Inexact(usize::MAX) + ); + assert_eq!( + precision_max_inexact.multiply(&precision1), + Precision::Inexact(usize::MAX) + ); } #[test] @@ -896,9 +1197,11 @@ mod tests { Precision::Exact(ScalarValue::Int64(None)), ); // Overflow returns error - assert!(Precision::Exact(ScalarValue::Int32(Some(256))) - .cast_to(&DataType::Int8) - .is_err()); + assert!( + Precision::Exact(ScalarValue::Int32(Some(256))) + .cast_to(&DataType::Int8) + .is_err() + ); } #[test] @@ -911,15 +1214,13 @@ mod tests { // Precision is not copy (requires .clone()) let precision: Precision = Precision::Exact(ScalarValue::Int64(Some(42))); - // Clippy would complain about this if it were Copy - #[allow(clippy::redundant_clone)] let p2 = precision.clone(); assert_eq!(precision, p2); } #[test] fn test_project_none() { - let projection = None; + let projection: Option> = None; let stats = make_stats(vec![10, 20, 30]).project(projection.as_ref()); assert_eq!(stats, make_stats(vec![10, 20, 30])); } @@ -961,11 +1262,12 @@ mod tests { min_value: Precision::Exact(ScalarValue::Int64(Some(64))), sum_value: Precision::Exact(ScalarValue::Int64(Some(4600))), distinct_count: Precision::Exact(100), + byte_size: Precision::Exact(800), } } #[test] - fn test_try_merge_basic() { + fn test_try_merge() { // Create a schema with two columns let schema = 
Arc::new(Schema::new(vec![ Field::new("col1", DataType::Int32, false), @@ -983,6 +1285,7 @@ mod tests { min_value: Precision::Exact(ScalarValue::Int32(Some(1))), sum_value: Precision::Exact(ScalarValue::Int32(Some(500))), distinct_count: Precision::Absent, + byte_size: Precision::Exact(40), }, ColumnStatistics { null_count: Precision::Exact(2), @@ -990,6 +1293,7 @@ mod tests { min_value: Precision::Exact(ScalarValue::Int32(Some(10))), sum_value: Precision::Exact(ScalarValue::Int32(Some(1000))), distinct_count: Precision::Absent, + byte_size: Precision::Exact(40), }, ], }; @@ -1004,6 +1308,7 @@ mod tests { min_value: Precision::Exact(ScalarValue::Int32(Some(-10))), sum_value: Precision::Exact(ScalarValue::Int32(Some(600))), distinct_count: Precision::Absent, + byte_size: Precision::Exact(60), }, ColumnStatistics { null_count: Precision::Exact(3), @@ -1011,6 +1316,7 @@ mod tests { min_value: Precision::Exact(ScalarValue::Int32(Some(5))), sum_value: Precision::Exact(ScalarValue::Int32(Some(1200))), distinct_count: Precision::Absent, + byte_size: Precision::Exact(60), }, ], }; @@ -1074,6 +1380,7 @@ mod tests { min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), sum_value: Precision::Exact(ScalarValue::Int32(Some(500))), distinct_count: Precision::Absent, + byte_size: Precision::Exact(40), }], }; @@ -1086,6 +1393,7 @@ mod tests { min_value: Precision::Exact(ScalarValue::Int32(Some(-10))), sum_value: Precision::Absent, distinct_count: Precision::Absent, + byte_size: Precision::Inexact(60), }], }; @@ -1106,7 +1414,7 @@ mod tests { col_stats.min_value, Precision::Inexact(ScalarValue::Int32(Some(-10))) ); - assert!(matches!(col_stats.sum_value, Precision::Absent)); + assert_eq!(col_stats.sum_value, Precision::Absent); } #[test] @@ -1150,6 +1458,1059 @@ mod tests { let items = vec![stats1, stats2]; let e = Statistics::try_merge_iter(&items, &schema).unwrap_err(); - assert_contains!(e.to_string(), "Error during planning: Cannot merge statistics with different number 
of columns: 0 vs 1"); + assert_contains!( + e.to_string(), + "Error during planning: Cannot merge statistics with different number of columns: 0 vs 1" + ); + } + + #[test] + fn test_try_merge_distinct_count_absent() { + // Create statistics with known distinct counts + let stats1 = Statistics::default() + .with_num_rows(Precision::Exact(10)) + .with_total_byte_size(Precision::Exact(100)) + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_null_count(Precision::Exact(0)) + .with_min_value(Precision::Exact(ScalarValue::Int32(Some(1)))) + .with_max_value(Precision::Exact(ScalarValue::Int32(Some(10)))) + .with_distinct_count(Precision::Exact(5)), + ); + + let stats2 = Statistics::default() + .with_num_rows(Precision::Exact(15)) + .with_total_byte_size(Precision::Exact(150)) + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_null_count(Precision::Exact(0)) + .with_min_value(Precision::Exact(ScalarValue::Int32(Some(5)))) + .with_max_value(Precision::Exact(ScalarValue::Int32(Some(20)))) + .with_distinct_count(Precision::Exact(7)), + ); + + // Merge statistics + let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + let merged_stats = + Statistics::try_merge_iter([&stats1, &stats2], &schema).unwrap(); + + // Verify the results + assert_eq!(merged_stats.num_rows, Precision::Exact(25)); + assert_eq!(merged_stats.total_byte_size, Precision::Exact(250)); + + let col_stats = &merged_stats.column_statistics[0]; + assert_eq!(col_stats.null_count, Precision::Exact(0)); + assert_eq!( + col_stats.min_value, + Precision::Exact(ScalarValue::Int32(Some(1))) + ); + assert_eq!( + col_stats.max_value, + Precision::Exact(ScalarValue::Int32(Some(20))) + ); + // Overlap-based NDV: ranges [1,10] and [5,20], overlap [5,10] + // range_left=9, range_right=15, overlap=5 + // overlap_left=5*(5/9)=2.78, overlap_right=7*(5/15)=2.33 + // result = max(2.78, 2.33) + (5-2.78) + (7-2.33) = 9.67 -> 10 + assert_eq!(col_stats.distinct_count, 
Precision::Inexact(10)); + } + + #[test] + fn test_try_merge_ndv_disjoint_ranges() { + let stats1 = Statistics::default() + .with_num_rows(Precision::Exact(10)) + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::Int32(Some(0)))) + .with_max_value(Precision::Exact(ScalarValue::Int32(Some(10)))) + .with_distinct_count(Precision::Exact(5)), + ); + let stats2 = Statistics::default() + .with_num_rows(Precision::Exact(10)) + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::Int32(Some(20)))) + .with_max_value(Precision::Exact(ScalarValue::Int32(Some(30)))) + .with_distinct_count(Precision::Exact(8)), + ); + + let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + let merged = Statistics::try_merge_iter([&stats1, &stats2], &schema).unwrap(); + // No overlap -> sum of NDVs + assert_eq!( + merged.column_statistics[0].distinct_count, + Precision::Inexact(13) + ); + } + + #[test] + fn test_try_merge_ndv_identical_ranges() { + let stats1 = Statistics::default() + .with_num_rows(Precision::Exact(100)) + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::Int32(Some(0)))) + .with_max_value(Precision::Exact(ScalarValue::Int32(Some(100)))) + .with_distinct_count(Precision::Exact(50)), + ); + let stats2 = Statistics::default() + .with_num_rows(Precision::Exact(100)) + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::Int32(Some(0)))) + .with_max_value(Precision::Exact(ScalarValue::Int32(Some(100)))) + .with_distinct_count(Precision::Exact(30)), + ); + + let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + let merged = Statistics::try_merge_iter([&stats1, &stats2], &schema).unwrap(); + // Full overlap -> max(50, 30) = 50 + assert_eq!( + merged.column_statistics[0].distinct_count, + Precision::Inexact(50) + ); + } + + #[test] 
+ fn test_try_merge_ndv_partial_overlap() { + let stats1 = Statistics::default() + .with_num_rows(Precision::Exact(100)) + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::Int32(Some(0)))) + .with_max_value(Precision::Exact(ScalarValue::Int32(Some(100)))) + .with_distinct_count(Precision::Exact(80)), + ); + let stats2 = Statistics::default() + .with_num_rows(Precision::Exact(100)) + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::Int32(Some(50)))) + .with_max_value(Precision::Exact(ScalarValue::Int32(Some(150)))) + .with_distinct_count(Precision::Exact(60)), + ); + + let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + let merged = Statistics::try_merge_iter([&stats1, &stats2], &schema).unwrap(); + // overlap=[50,100], range_left=100, range_right=100, overlap_range=50 + // overlap_left=80*(50/100)=40, overlap_right=60*(50/100)=30 + // result = max(40,30) + (80-40) + (60-30) = 40 + 40 + 30 = 110 + assert_eq!( + merged.column_statistics[0].distinct_count, + Precision::Inexact(110) + ); + } + + #[test] + fn test_try_merge_ndv_missing_min_max() { + let stats1 = Statistics::default() + .with_num_rows(Precision::Exact(10)) + .add_column_statistics( + ColumnStatistics::new_unknown().with_distinct_count(Precision::Exact(5)), + ); + let stats2 = Statistics::default() + .with_num_rows(Precision::Exact(10)) + .add_column_statistics( + ColumnStatistics::new_unknown().with_distinct_count(Precision::Exact(8)), + ); + + let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + let merged = Statistics::try_merge_iter([&stats1, &stats2], &schema).unwrap(); + // No min/max -> fallback to max(5, 8) + assert_eq!( + merged.column_statistics[0].distinct_count, + Precision::Inexact(8) + ); + } + + #[test] + fn test_try_merge_ndv_non_numeric_types() { + let stats1 = Statistics::default() + .with_num_rows(Precision::Exact(10)) + 
.add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::Utf8(Some( + "aaa".to_string(), + )))) + .with_max_value(Precision::Exact(ScalarValue::Utf8(Some( + "zzz".to_string(), + )))) + .with_distinct_count(Precision::Exact(5)), + ); + let stats2 = Statistics::default() + .with_num_rows(Precision::Exact(10)) + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::Utf8(Some( + "bbb".to_string(), + )))) + .with_max_value(Precision::Exact(ScalarValue::Utf8(Some( + "yyy".to_string(), + )))) + .with_distinct_count(Precision::Exact(8)), + ); + + let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]); + let merged = Statistics::try_merge_iter([&stats1, &stats2], &schema).unwrap(); + // distance() unsupported for strings -> fallback to max + assert_eq!( + merged.column_statistics[0].distinct_count, + Precision::Inexact(8) + ); + } + + #[test] + fn test_try_merge_ndv_constant_columns() { + // Same constant: [5,5]+[5,5] -> max + let stats1 = Statistics::default() + .with_num_rows(Precision::Exact(10)) + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::Int32(Some(5)))) + .with_max_value(Precision::Exact(ScalarValue::Int32(Some(5)))) + .with_distinct_count(Precision::Exact(1)), + ); + let stats2 = Statistics::default() + .with_num_rows(Precision::Exact(10)) + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::Int32(Some(5)))) + .with_max_value(Precision::Exact(ScalarValue::Int32(Some(5)))) + .with_distinct_count(Precision::Exact(1)), + ); + + let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + let merged = Statistics::try_merge_iter([&stats1, &stats2], &schema).unwrap(); + assert_eq!( + merged.column_statistics[0].distinct_count, + Precision::Inexact(1) + ); + + // Different constants: [5,5]+[10,10] -> sum + let stats3 = 
Statistics::default() + .with_num_rows(Precision::Exact(10)) + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::Int32(Some(5)))) + .with_max_value(Precision::Exact(ScalarValue::Int32(Some(5)))) + .with_distinct_count(Precision::Exact(1)), + ); + let stats4 = Statistics::default() + .with_num_rows(Precision::Exact(10)) + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::Int32(Some(10)))) + .with_max_value(Precision::Exact(ScalarValue::Int32(Some(10)))) + .with_distinct_count(Precision::Exact(1)), + ); + + let merged = Statistics::try_merge_iter([&stats3, &stats4], &schema).unwrap(); + assert_eq!( + merged.column_statistics[0].distinct_count, + Precision::Inexact(2) + ); + } + + #[test] + fn test_with_fetch_basic_preservation() { + // Test that column statistics and byte size are preserved (as inexact) when applying fetch + let original_stats = Statistics { + num_rows: Precision::Exact(1000), + total_byte_size: Precision::Exact(8000), + column_statistics: vec![ + ColumnStatistics { + null_count: Precision::Exact(10), + max_value: Precision::Exact(ScalarValue::Int32(Some(100))), + min_value: Precision::Exact(ScalarValue::Int32(Some(0))), + sum_value: Precision::Exact(ScalarValue::Int32(Some(5050))), + distinct_count: Precision::Exact(50), + byte_size: Precision::Exact(4000), + }, + ColumnStatistics { + null_count: Precision::Exact(20), + max_value: Precision::Exact(ScalarValue::Int64(Some(200))), + min_value: Precision::Exact(ScalarValue::Int64(Some(10))), + sum_value: Precision::Exact(ScalarValue::Int64(Some(10100))), + distinct_count: Precision::Exact(75), + byte_size: Precision::Exact(8000), + }, + ], + }; + + // Apply fetch of 100 rows (10% of original) + let result = original_stats.clone().with_fetch(Some(100), 0, 1).unwrap(); + + // Check num_rows + assert_eq!(result.num_rows, Precision::Exact(100)); + + // Check total_byte_size is computed as sum of 
scaled column byte_size values + // Column 1: 4000 * 0.1 = 400, Column 2: 8000 * 0.1 = 800, Sum = 1200 + assert_eq!(result.total_byte_size, Precision::Inexact(1200)); + + // Check column statistics are preserved but marked as inexact + assert_eq!(result.column_statistics.len(), 2); + + // First column + assert_eq!( + result.column_statistics[0].null_count, + Precision::Inexact(10) + ); + assert_eq!( + result.column_statistics[0].max_value, + Precision::Inexact(ScalarValue::Int32(Some(100))) + ); + assert_eq!( + result.column_statistics[0].min_value, + Precision::Inexact(ScalarValue::Int32(Some(0))) + ); + assert_eq!( + result.column_statistics[0].sum_value, + Precision::Inexact(ScalarValue::Int32(Some(5050))) + ); + assert_eq!( + result.column_statistics[0].distinct_count, + Precision::Inexact(50) + ); + + // Second column + assert_eq!( + result.column_statistics[1].null_count, + Precision::Inexact(20) + ); + assert_eq!( + result.column_statistics[1].max_value, + Precision::Inexact(ScalarValue::Int64(Some(200))) + ); + assert_eq!( + result.column_statistics[1].min_value, + Precision::Inexact(ScalarValue::Int64(Some(10))) + ); + assert_eq!( + result.column_statistics[1].sum_value, + Precision::Inexact(ScalarValue::Int64(Some(10100))) + ); + assert_eq!( + result.column_statistics[1].distinct_count, + Precision::Inexact(75) + ); + } + + #[test] + fn test_with_fetch_inexact_input() { + // Test that inexact input statistics remain inexact + let original_stats = Statistics { + num_rows: Precision::Inexact(1000), + total_byte_size: Precision::Inexact(8000), + column_statistics: vec![ColumnStatistics { + null_count: Precision::Inexact(10), + max_value: Precision::Inexact(ScalarValue::Int32(Some(100))), + min_value: Precision::Inexact(ScalarValue::Int32(Some(0))), + sum_value: Precision::Inexact(ScalarValue::Int32(Some(5050))), + distinct_count: Precision::Inexact(50), + byte_size: Precision::Inexact(4000), + }], + }; + + let result = 
original_stats.clone().with_fetch(Some(500), 0, 1).unwrap(); + + // Check num_rows is inexact + assert_eq!(result.num_rows, Precision::Inexact(500)); + + // Check total_byte_size is computed as sum of scaled column byte_size values + // Column 1: 4000 * 0.5 = 2000, Sum = 2000 + assert_eq!(result.total_byte_size, Precision::Inexact(2000)); + + // Column stats remain inexact + assert_eq!( + result.column_statistics[0].null_count, + Precision::Inexact(10) + ); + } + + #[test] + fn test_with_fetch_skip_all_rows() { + // Test when skip >= num_rows (all rows are skipped) + let original_stats = Statistics { + num_rows: Precision::Exact(100), + total_byte_size: Precision::Exact(800), + column_statistics: vec![col_stats_i64(10)], + }; + + let result = original_stats.clone().with_fetch(Some(50), 100, 1).unwrap(); + + assert_eq!(result.num_rows, Precision::Exact(0)); + // When ratio is 0/100 = 0, byte size should be 0 + assert_eq!(result.total_byte_size, Precision::Inexact(0)); + } + + #[test] + fn test_with_fetch_no_limit() { + // Test when fetch is None and skip is 0 (no limit applied) + let original_stats = Statistics { + num_rows: Precision::Exact(100), + total_byte_size: Precision::Exact(800), + column_statistics: vec![col_stats_i64(10)], + }; + + let result = original_stats.clone().with_fetch(None, 0, 1).unwrap(); + + // Stats should be unchanged when no fetch and no skip + assert_eq!(result.num_rows, Precision::Exact(100)); + assert_eq!(result.total_byte_size, Precision::Exact(800)); + } + + #[test] + fn test_with_fetch_with_skip() { + // Test with both skip and fetch + let original_stats = Statistics { + num_rows: Precision::Exact(1000), + total_byte_size: Precision::Exact(8000), + column_statistics: vec![col_stats_i64(10)], + }; + + // Skip 200, fetch 300, so we get rows 200-500 + let result = original_stats + .clone() + .with_fetch(Some(300), 200, 1) + .unwrap(); + + assert_eq!(result.num_rows, Precision::Exact(300)); + // Column 1: byte_size 800 * (300/1000) = 240,
Sum = 240 + assert_eq!(result.total_byte_size, Precision::Inexact(240)); + } + + #[test] + fn test_with_fetch_multi_partition() { + // Test with multiple partitions + let original_stats = Statistics { + num_rows: Precision::Exact(1000), // per partition + total_byte_size: Precision::Exact(8000), + column_statistics: vec![col_stats_i64(10)], + }; + + // Fetch 100 per partition, 4 partitions = 400 total + let result = original_stats.clone().with_fetch(Some(100), 0, 4).unwrap(); + + assert_eq!(result.num_rows, Precision::Exact(400)); + // Column 1: byte_size 800 * 0.4 = 320, Sum = 320 + assert_eq!(result.total_byte_size, Precision::Inexact(320)); + } + + #[test] + fn test_with_fetch_absent_stats() { + // Test with absent statistics + let original_stats = Statistics { + num_rows: Precision::Absent, + total_byte_size: Precision::Absent, + column_statistics: vec![ColumnStatistics { + null_count: Precision::Absent, + max_value: Precision::Absent, + min_value: Precision::Absent, + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Absent, + }], + }; + + let result = original_stats.clone().with_fetch(Some(100), 0, 1).unwrap(); + + // With absent input stats, output should be inexact estimate + assert_eq!(result.num_rows, Precision::Inexact(100)); + assert_eq!(result.total_byte_size, Precision::Absent); + // Column stats should remain absent + assert_eq!(result.column_statistics[0].null_count, Precision::Absent); + } + + #[test] + fn test_with_fetch_fetch_exceeds_rows() { + // Test when fetch is larger than available rows after skip + let original_stats = Statistics { + num_rows: Precision::Exact(100), + total_byte_size: Precision::Exact(800), + column_statistics: vec![col_stats_i64(10)], + }; + + // Skip 50, fetch 100, but only 50 rows remain + let result = original_stats.clone().with_fetch(Some(100), 50, 1).unwrap(); + + assert_eq!(result.num_rows, Precision::Exact(50)); + // 50/100 = 0.5, so 800 * 0.5 = 400 + 
assert_eq!(result.total_byte_size, Precision::Inexact(400)); + } + + #[test] + fn test_with_fetch_preserves_all_column_stats() { + // Comprehensive test that all column statistic fields are preserved + let original_col_stats = ColumnStatistics { + null_count: Precision::Exact(42), + max_value: Precision::Exact(ScalarValue::Int32(Some(999))), + min_value: Precision::Exact(ScalarValue::Int32(Some(-100))), + sum_value: Precision::Exact(ScalarValue::Int32(Some(123456))), + distinct_count: Precision::Exact(789), + byte_size: Precision::Exact(4000), + }; + + let original_stats = Statistics { + num_rows: Precision::Exact(1000), + total_byte_size: Precision::Exact(8000), + column_statistics: vec![original_col_stats.clone()], + }; + + let result = original_stats.with_fetch(Some(250), 0, 1).unwrap(); + + let result_col_stats = &result.column_statistics[0]; + + // All values should be preserved but marked as inexact + assert_eq!(result_col_stats.null_count, Precision::Inexact(42)); + assert_eq!( + result_col_stats.max_value, + Precision::Inexact(ScalarValue::Int32(Some(999))) + ); + assert_eq!( + result_col_stats.min_value, + Precision::Inexact(ScalarValue::Int32(Some(-100))) + ); + assert_eq!( + result_col_stats.sum_value, + Precision::Inexact(ScalarValue::Int32(Some(123456))) + ); + assert_eq!(result_col_stats.distinct_count, Precision::Inexact(789)); + } + + #[test] + fn test_byte_size_to_inexact() { + let col_stats = ColumnStatistics { + null_count: Precision::Exact(10), + max_value: Precision::Absent, + min_value: Precision::Absent, + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Exact(5000), + }; + + let inexact = col_stats.to_inexact(); + assert_eq!(inexact.byte_size, Precision::Inexact(5000)); + } + + #[test] + fn test_with_byte_size_builder() { + let col_stats = + ColumnStatistics::new_unknown().with_byte_size(Precision::Exact(8192)); + assert_eq!(col_stats.byte_size, Precision::Exact(8192)); + } + + #[test] + fn 
test_with_fetch_scales_byte_size() { + // Test that byte_size is scaled by the row ratio in with_fetch + let original_stats = Statistics { + num_rows: Precision::Exact(1000), + total_byte_size: Precision::Exact(8000), + column_statistics: vec![ + ColumnStatistics { + null_count: Precision::Exact(10), + max_value: Precision::Absent, + min_value: Precision::Absent, + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Exact(4000), + }, + ColumnStatistics { + null_count: Precision::Exact(20), + max_value: Precision::Absent, + min_value: Precision::Absent, + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Exact(8000), + }, + ], + }; + + // Apply fetch of 100 rows (10% of original) + let result = original_stats.with_fetch(Some(100), 0, 1).unwrap(); + + // byte_size should be scaled: 4000 * 0.1 = 400, 8000 * 0.1 = 800 + assert_eq!( + result.column_statistics[0].byte_size, + Precision::Inexact(400) + ); + assert_eq!( + result.column_statistics[1].byte_size, + Precision::Inexact(800) + ); + + // total_byte_size should be computed as sum of byte_size values: 400 + 800 = 1200 + assert_eq!(result.total_byte_size, Precision::Inexact(1200)); + } + + #[test] + fn test_with_fetch_total_byte_size_fallback() { + // Test that total_byte_size falls back to scaling when not all columns have byte_size + let original_stats = Statistics { + num_rows: Precision::Exact(1000), + total_byte_size: Precision::Exact(8000), + column_statistics: vec![ + ColumnStatistics { + null_count: Precision::Exact(10), + max_value: Precision::Absent, + min_value: Precision::Absent, + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Exact(4000), + }, + ColumnStatistics { + null_count: Precision::Exact(20), + max_value: Precision::Absent, + min_value: Precision::Absent, + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Absent, // One column has 
no byte_size + }, + ], + }; + + // Apply fetch of 100 rows (10% of original) + let result = original_stats.with_fetch(Some(100), 0, 1).unwrap(); + + // total_byte_size should fall back to scaling: 8000 * 0.1 = 800 + assert_eq!(result.total_byte_size, Precision::Inexact(800)); + } + + #[test] + fn test_try_merge_iter_basic() { + let schema = Arc::new(Schema::new(vec![ + Field::new("col1", DataType::Int32, false), + Field::new("col2", DataType::Int32, false), + ])); + + let stats1 = Statistics { + num_rows: Precision::Exact(10), + total_byte_size: Precision::Exact(100), + column_statistics: vec![ + ColumnStatistics { + null_count: Precision::Exact(1), + max_value: Precision::Exact(ScalarValue::Int32(Some(100))), + min_value: Precision::Exact(ScalarValue::Int32(Some(1))), + sum_value: Precision::Exact(ScalarValue::Int32(Some(500))), + distinct_count: Precision::Absent, + byte_size: Precision::Exact(40), + }, + ColumnStatistics { + null_count: Precision::Exact(2), + max_value: Precision::Exact(ScalarValue::Int32(Some(200))), + min_value: Precision::Exact(ScalarValue::Int32(Some(10))), + sum_value: Precision::Exact(ScalarValue::Int32(Some(1000))), + distinct_count: Precision::Absent, + byte_size: Precision::Exact(40), + }, + ], + }; + + let stats2 = Statistics { + num_rows: Precision::Exact(15), + total_byte_size: Precision::Exact(150), + column_statistics: vec![ + ColumnStatistics { + null_count: Precision::Exact(2), + max_value: Precision::Exact(ScalarValue::Int32(Some(120))), + min_value: Precision::Exact(ScalarValue::Int32(Some(-10))), + sum_value: Precision::Exact(ScalarValue::Int32(Some(600))), + distinct_count: Precision::Absent, + byte_size: Precision::Exact(60), + }, + ColumnStatistics { + null_count: Precision::Exact(3), + max_value: Precision::Exact(ScalarValue::Int32(Some(180))), + min_value: Precision::Exact(ScalarValue::Int32(Some(5))), + sum_value: Precision::Exact(ScalarValue::Int32(Some(1200))), + distinct_count: Precision::Absent, + byte_size: 
Precision::Exact(60), + }, + ], + }; + + let items = vec![&stats1, &stats2]; + let summary_stats = Statistics::try_merge_iter(items, &schema).unwrap(); + + assert_eq!(summary_stats.num_rows, Precision::Exact(25)); + assert_eq!(summary_stats.total_byte_size, Precision::Exact(250)); + + let col1_stats = &summary_stats.column_statistics[0]; + assert_eq!(col1_stats.null_count, Precision::Exact(3)); + assert_eq!( + col1_stats.max_value, + Precision::Exact(ScalarValue::Int32(Some(120))) + ); + assert_eq!( + col1_stats.min_value, + Precision::Exact(ScalarValue::Int32(Some(-10))) + ); + assert_eq!( + col1_stats.sum_value, + Precision::Exact(ScalarValue::Int32(Some(1100))) + ); + + let col2_stats = &summary_stats.column_statistics[1]; + assert_eq!(col2_stats.null_count, Precision::Exact(5)); + assert_eq!( + col2_stats.max_value, + Precision::Exact(ScalarValue::Int32(Some(200))) + ); + assert_eq!( + col2_stats.min_value, + Precision::Exact(ScalarValue::Int32(Some(5))) + ); + assert_eq!( + col2_stats.sum_value, + Precision::Exact(ScalarValue::Int32(Some(2200))) + ); + } + + #[test] + fn test_try_merge_iter_mixed_precision() { + let schema = Arc::new(Schema::new(vec![Field::new( + "col1", + DataType::Int32, + false, + )])); + + let stats1 = Statistics { + num_rows: Precision::Exact(10), + total_byte_size: Precision::Inexact(100), + column_statistics: vec![ColumnStatistics { + null_count: Precision::Exact(1), + max_value: Precision::Exact(ScalarValue::Int32(Some(100))), + min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), + sum_value: Precision::Exact(ScalarValue::Int32(Some(500))), + distinct_count: Precision::Absent, + byte_size: Precision::Exact(40), + }], + }; + + let stats2 = Statistics { + num_rows: Precision::Inexact(15), + total_byte_size: Precision::Exact(150), + column_statistics: vec![ColumnStatistics { + null_count: Precision::Inexact(2), + max_value: Precision::Inexact(ScalarValue::Int32(Some(120))), + min_value: 
Precision::Exact(ScalarValue::Int32(Some(-10))), + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Inexact(60), + }], + }; + + let items = vec![&stats1, &stats2]; + let summary_stats = Statistics::try_merge_iter(items, &schema).unwrap(); + + assert_eq!(summary_stats.num_rows, Precision::Inexact(25)); + assert_eq!(summary_stats.total_byte_size, Precision::Inexact(250)); + + let col_stats = &summary_stats.column_statistics[0]; + assert_eq!(col_stats.null_count, Precision::Inexact(3)); + assert_eq!( + col_stats.max_value, + Precision::Inexact(ScalarValue::Int32(Some(120))) + ); + assert_eq!( + col_stats.min_value, + Precision::Inexact(ScalarValue::Int32(Some(-10))) + ); + // sum_value becomes Absent because stats2 has Absent sum + assert_eq!(col_stats.sum_value, Precision::Absent); + } + + #[test] + fn test_try_merge_iter_empty() { + let schema = Arc::new(Schema::new(vec![Field::new( + "col1", + DataType::Int32, + false, + )])); + + let items: Vec<&Statistics> = vec![]; + let summary_stats = Statistics::try_merge_iter(items, &schema).unwrap(); + + assert_eq!(summary_stats.num_rows, Precision::Absent); + assert_eq!(summary_stats.total_byte_size, Precision::Absent); + assert_eq!(summary_stats.column_statistics.len(), 1); + assert_eq!( + summary_stats.column_statistics[0].null_count, + Precision::Absent + ); + } + + #[test] + fn test_try_merge_iter_single_item() { + let schema = Arc::new(Schema::new(vec![Field::new( + "col1", + DataType::Int32, + false, + )])); + + let stats = Statistics { + num_rows: Precision::Exact(10), + total_byte_size: Precision::Exact(100), + column_statistics: vec![ColumnStatistics { + null_count: Precision::Exact(1), + max_value: Precision::Exact(ScalarValue::Int32(Some(100))), + min_value: Precision::Exact(ScalarValue::Int32(Some(1))), + sum_value: Precision::Exact(ScalarValue::Int32(Some(500))), + distinct_count: Precision::Exact(10), + byte_size: Precision::Exact(40), + }], + }; + + let items = 
vec![&stats]; + let summary_stats = Statistics::try_merge_iter(items, &schema).unwrap(); + + assert_eq!(summary_stats, stats); + } + + #[test] + fn test_try_merge_iter_mismatched_columns() { + let schema = Arc::new(Schema::new(vec![Field::new( + "col1", + DataType::Int32, + false, + )])); + + let stats1 = Statistics::default(); + let stats2 = + Statistics::default().add_column_statistics(ColumnStatistics::new_unknown()); + + let items = vec![&stats1, &stats2]; + let e = Statistics::try_merge_iter(items, &schema).unwrap_err(); + assert_contains!( + e.to_string(), + "Cannot merge statistics with different number of columns: 0 vs 1" + ); + } + + #[test] + fn test_try_merge_iter_three_items() { + // Verify that merging three items works correctly + let schema = Arc::new(Schema::new(vec![Field::new( + "col1", + DataType::Int64, + false, + )])); + + let stats1 = Statistics { + num_rows: Precision::Exact(10), + total_byte_size: Precision::Exact(100), + column_statistics: vec![ColumnStatistics { + null_count: Precision::Exact(1), + max_value: Precision::Exact(ScalarValue::Int64(Some(100))), + min_value: Precision::Exact(ScalarValue::Int64(Some(10))), + sum_value: Precision::Exact(ScalarValue::Int64(Some(500))), + distinct_count: Precision::Exact(8), + byte_size: Precision::Exact(80), + }], + }; + + let stats2 = Statistics { + num_rows: Precision::Exact(20), + total_byte_size: Precision::Exact(200), + column_statistics: vec![ColumnStatistics { + null_count: Precision::Exact(2), + max_value: Precision::Exact(ScalarValue::Int64(Some(200))), + min_value: Precision::Exact(ScalarValue::Int64(Some(5))), + sum_value: Precision::Exact(ScalarValue::Int64(Some(1000))), + distinct_count: Precision::Exact(15), + byte_size: Precision::Exact(160), + }], + }; + + let stats3 = Statistics { + num_rows: Precision::Exact(30), + total_byte_size: Precision::Exact(300), + column_statistics: vec![ColumnStatistics { + null_count: Precision::Exact(3), + max_value: 
Precision::Exact(ScalarValue::Int64(Some(150))), + min_value: Precision::Exact(ScalarValue::Int64(Some(1))), + sum_value: Precision::Exact(ScalarValue::Int64(Some(2000))), + distinct_count: Precision::Exact(25), + byte_size: Precision::Exact(240), + }], + }; + + let items = vec![&stats1, &stats2, &stats3]; + let summary_stats = Statistics::try_merge_iter(items, &schema).unwrap(); + + assert_eq!(summary_stats.num_rows, Precision::Exact(60)); + assert_eq!(summary_stats.total_byte_size, Precision::Exact(600)); + + let col_stats = &summary_stats.column_statistics[0]; + assert_eq!(col_stats.null_count, Precision::Exact(6)); + assert_eq!( + col_stats.max_value, + Precision::Exact(ScalarValue::Int64(Some(200))) + ); + assert_eq!( + col_stats.min_value, + Precision::Exact(ScalarValue::Int64(Some(1))) + ); + assert_eq!( + col_stats.sum_value, + Precision::Exact(ScalarValue::Int64(Some(3500))) + ); + assert_eq!(col_stats.byte_size, Precision::Exact(480)); + // Overlap-based NDV merge (pairwise left-to-right): + // stats1+stats2: [10,100]+[5,200] -> NDV=16, then +stats3: [5,200]+[1,150] -> NDV=29 + assert_eq!(col_stats.distinct_count, Precision::Inexact(29)); + } + + #[test] + fn test_try_merge_iter_float_types() { + let schema = Arc::new(Schema::new(vec![Field::new( + "col1", + DataType::Float64, + false, + )])); + + let stats1 = Statistics { + num_rows: Precision::Exact(10), + total_byte_size: Precision::Exact(80), + column_statistics: vec![ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Float64(Some(99.9))), + min_value: Precision::Exact(ScalarValue::Float64(Some(1.1))), + sum_value: Precision::Exact(ScalarValue::Float64(Some(500.5))), + distinct_count: Precision::Absent, + byte_size: Precision::Exact(80), + }], + }; + + let stats2 = Statistics { + num_rows: Precision::Exact(10), + total_byte_size: Precision::Exact(80), + column_statistics: vec![ColumnStatistics { + null_count: Precision::Exact(0), + max_value: 
Precision::Exact(ScalarValue::Float64(Some(200.0))), + min_value: Precision::Exact(ScalarValue::Float64(Some(0.5))), + sum_value: Precision::Exact(ScalarValue::Float64(Some(1000.0))), + distinct_count: Precision::Absent, + byte_size: Precision::Exact(80), + }], + }; + + let items = vec![&stats1, &stats2]; + let summary_stats = Statistics::try_merge_iter(items, &schema).unwrap(); + + let col_stats = &summary_stats.column_statistics[0]; + assert_eq!( + col_stats.max_value, + Precision::Exact(ScalarValue::Float64(Some(200.0))) + ); + assert_eq!( + col_stats.min_value, + Precision::Exact(ScalarValue::Float64(Some(0.5))) + ); + assert_eq!( + col_stats.sum_value, + Precision::Exact(ScalarValue::Float64(Some(1500.5))) + ); + } + + #[test] + fn test_try_merge_iter_string_types() { + let schema = + Arc::new(Schema::new(vec![Field::new("col1", DataType::Utf8, false)])); + + let stats1 = Statistics { + num_rows: Precision::Exact(10), + total_byte_size: Precision::Exact(100), + column_statistics: vec![ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Utf8(Some("dog".to_string()))), + min_value: Precision::Exact(ScalarValue::Utf8(Some("ant".to_string()))), + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Exact(100), + }], + }; + + let stats2 = Statistics { + num_rows: Precision::Exact(10), + total_byte_size: Precision::Exact(100), + column_statistics: vec![ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Utf8(Some("zebra".to_string()))), + min_value: Precision::Exact(ScalarValue::Utf8(Some("bat".to_string()))), + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Exact(100), + }], + }; + + let items = vec![&stats1, &stats2]; + let summary_stats = Statistics::try_merge_iter(items, &schema).unwrap(); + + let col_stats = &summary_stats.column_statistics[0]; + assert_eq!( + col_stats.max_value, + 
Precision::Exact(ScalarValue::Utf8(Some("zebra".to_string()))) + ); + assert_eq!( + col_stats.min_value, + Precision::Exact(ScalarValue::Utf8(Some("ant".to_string()))) + ); + assert_eq!(col_stats.sum_value, Precision::Absent); + } + + #[test] + fn test_try_merge_iter_all_inexact() { + let schema = Arc::new(Schema::new(vec![Field::new( + "col1", + DataType::Int32, + false, + )])); + + let stats1 = Statistics { + num_rows: Precision::Inexact(10), + total_byte_size: Precision::Inexact(100), + column_statistics: vec![ColumnStatistics { + null_count: Precision::Inexact(1), + max_value: Precision::Inexact(ScalarValue::Int32(Some(100))), + min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), + sum_value: Precision::Inexact(ScalarValue::Int32(Some(500))), + distinct_count: Precision::Absent, + byte_size: Precision::Inexact(40), + }], + }; + + let stats2 = Statistics { + num_rows: Precision::Inexact(20), + total_byte_size: Precision::Inexact(200), + column_statistics: vec![ColumnStatistics { + null_count: Precision::Inexact(2), + max_value: Precision::Inexact(ScalarValue::Int32(Some(200))), + min_value: Precision::Inexact(ScalarValue::Int32(Some(-5))), + sum_value: Precision::Inexact(ScalarValue::Int32(Some(1000))), + distinct_count: Precision::Absent, + byte_size: Precision::Inexact(60), + }], + }; + + let items = vec![&stats1, &stats2]; + let summary_stats = Statistics::try_merge_iter(items, &schema).unwrap(); + + assert_eq!(summary_stats.num_rows, Precision::Inexact(30)); + assert_eq!(summary_stats.total_byte_size, Precision::Inexact(300)); + + let col_stats = &summary_stats.column_statistics[0]; + assert_eq!(col_stats.null_count, Precision::Inexact(3)); + assert_eq!( + col_stats.max_value, + Precision::Inexact(ScalarValue::Int32(Some(200))) + ); + assert_eq!( + col_stats.min_value, + Precision::Inexact(ScalarValue::Int32(Some(-5))) + ); + assert_eq!( + col_stats.sum_value, + Precision::Inexact(ScalarValue::Int32(Some(1500))) + ); } } diff --git 
a/datafusion/common/src/table_reference.rs b/datafusion/common/src/table_reference.rs index 9b6f9696c00bb..3163a8b16c8dc 100644 --- a/datafusion/common/src/table_reference.rs +++ b/datafusion/common/src/table_reference.rs @@ -15,7 +15,8 @@ // specific language governing permissions and limitations // under the License. -use crate::utils::{parse_identifiers_normalized, quote_identifier}; +use crate::utils::parse_identifiers_normalized; +use crate::utils::quote_identifier; use std::sync::Arc; /// A fully resolved path to a table of the form "catalog.schema.table" @@ -68,8 +69,11 @@ impl std::fmt::Display for ResolvedTableReference { /// /// // Get a table reference to 'myschema.mytable' (note the capitalization) /// let table_reference = TableReference::from("MySchema.MyTable"); -/// assert_eq!(table_reference, TableReference::partial("myschema", "mytable")); -///``` +/// assert_eq!( +/// table_reference, +/// TableReference::partial("myschema", "mytable") +/// ); +/// ``` #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] pub enum TableReference { /// An unqualified table reference, e.g. "table" @@ -246,7 +250,10 @@ impl TableReference { /// assert_eq!(table_reference.to_quoted_string(), "myschema.mytable"); /// /// let table_reference = TableReference::partial("MySchema", "MyTable"); - /// assert_eq!(table_reference.to_quoted_string(), r#""MySchema"."MyTable""#); + /// assert_eq!( + /// table_reference.to_quoted_string(), + /// r#""MySchema"."MyTable""# + /// ); /// ``` pub fn to_quoted_string(&self) -> String { match self { @@ -268,24 +275,41 @@ impl TableReference { } /// Forms a [`TableReference`] by parsing `s` as a multipart SQL - /// identifier. See docs on [`TableReference`] for more details. + /// identifier, normalizing `s` to lowercase. + /// See docs on [`TableReference`] for more details. 
pub fn parse_str(s: &str) -> Self { - let mut parts = parse_identifiers_normalized(s, false); + Self::parse_str_normalized(s, false) + } + /// Forms a [`TableReference`] by parsing `s` as a multipart SQL + /// identifier, normalizing `s` to lowercase if `ignore_case` is `false`. + /// See docs on [`TableReference`] for more details. + pub fn parse_str_normalized(s: &str, ignore_case: bool) -> Self { + let table_parts = parse_identifiers_normalized(s, ignore_case); + + Self::from_vec(table_parts).unwrap_or_else(|| Self::Bare { table: s.into() }) + } + + /// Consume a vector of identifier parts to compose a [`TableReference`]. The input vector + /// should contain 1 <= N <= 3 elements in the following sequence: + /// ```no_rust + /// [<catalog>, <schema>, table] + /// ``` + fn from_vec(mut parts: Vec) -> Option { match parts.len() { - 1 => Self::Bare { - table: parts.remove(0).into(), - }, - 2 => Self::Partial { - schema: parts.remove(0).into(), - table: parts.remove(0).into(), - }, - 3 => Self::Full { - catalog: parts.remove(0).into(), - schema: parts.remove(0).into(), - table: parts.remove(0).into(), - }, - _ => Self::Bare { table: s.into() }, + 1 => Some(Self::Bare { + table: parts.pop()?.into(), + }), + 2 => Some(Self::Partial { + table: parts.pop()?.into(), + schema: parts.pop()?.into(), + }), + 3 => Some(Self::Full { + table: parts.pop()?.into(), + schema: parts.pop()?.into(), + catalog: parts.pop()?.into(), + }), + _ => None, } } @@ -367,26 +391,32 @@ mod tests { let actual = TableReference::from("TABLE"); assert_eq!(expected, actual); - // if fail to parse, take entire input string as identifier - let expected = TableReference::Bare { - table: "TABLE()".into(), - }; - let actual = TableReference::from("TABLE()"); - assert_eq!(expected, actual); + // Disable this test for non-sql features so that we don't need to reproduce + // things like table function upper case conventions, since those will not + // be used if SQL is not selected.
+ #[cfg(feature = "sql")] + { + // if fail to parse, take entire input string as identifier + let expected = TableReference::Bare { + table: "TABLE()".into(), + }; + let actual = TableReference::from("TABLE()"); + assert_eq!(expected, actual); + } } #[test] fn test_table_reference_to_vector() { - let table_reference = TableReference::parse_str("table"); + let table_reference = TableReference::from("table"); assert_eq!(vec!["table".to_string()], table_reference.to_vec()); - let table_reference = TableReference::parse_str("schema.table"); + let table_reference = TableReference::from("schema.table"); assert_eq!( vec!["schema".to_string(), "table".to_string()], table_reference.to_vec() ); - let table_reference = TableReference::parse_str("catalog.schema.table"); + let table_reference = TableReference::from("catalog.schema.table"); assert_eq!( vec![ "catalog".to_string(), diff --git a/datafusion/common/src/test_util.rs b/datafusion/common/src/test_util.rs index 820a230bf6e17..f060704944233 100644 --- a/datafusion/common/src/test_util.rs +++ b/datafusion/common/src/test_util.rs @@ -55,7 +55,7 @@ pub fn format_batches(results: &[RecordBatch]) -> Result) +/// Both arguments must be convertible into Strings ([`Into`]<[`String`]>) #[macro_export] macro_rules! assert_contains { ($ACTUAL: expr, $EXPECTED: expr) => { @@ -181,7 +181,7 @@ macro_rules! assert_contains { /// Is a macro so test error /// messages are on the same line as the failure; /// -/// Both arguments must be convertable into Strings ([`Into`]<[`String`]>) +/// Both arguments must be convertible into Strings ([`Into`]<[`String`]>) #[macro_export] macro_rules! 
assert_not_contains { ($ACTUAL: expr, $UNEXPECTED: expr) => { @@ -255,7 +255,14 @@ pub fn arrow_test_data() -> String { #[cfg(feature = "parquet")] pub fn parquet_test_data() -> String { match get_data_dir("PARQUET_TEST_DATA", "../../parquet-testing/data") { - Ok(pb) => pb.display().to_string(), + Ok(pb) => { + let mut path = pb.display().to_string(); + if cfg!(target_os = "windows") { + // Replace backslashes (Windows paths; avoids some test issues). + path = path.replace("\\", "/"); + } + path + } Err(err) => panic!("failed to get parquet data dir: {err}"), } } @@ -314,43 +321,43 @@ pub fn get_data_dir( #[macro_export] macro_rules! create_array { (Boolean, $values: expr) => { - std::sync::Arc::new(arrow::array::BooleanArray::from($values)) + std::sync::Arc::new($crate::arrow::array::BooleanArray::from($values)) }; (Int8, $values: expr) => { - std::sync::Arc::new(arrow::array::Int8Array::from($values)) + std::sync::Arc::new($crate::arrow::array::Int8Array::from($values)) }; (Int16, $values: expr) => { - std::sync::Arc::new(arrow::array::Int16Array::from($values)) + std::sync::Arc::new($crate::arrow::array::Int16Array::from($values)) }; (Int32, $values: expr) => { - std::sync::Arc::new(arrow::array::Int32Array::from($values)) + std::sync::Arc::new($crate::arrow::array::Int32Array::from($values)) }; (Int64, $values: expr) => { - std::sync::Arc::new(arrow::array::Int64Array::from($values)) + std::sync::Arc::new($crate::arrow::array::Int64Array::from($values)) }; (UInt8, $values: expr) => { - std::sync::Arc::new(arrow::array::UInt8Array::from($values)) + std::sync::Arc::new($crate::arrow::array::UInt8Array::from($values)) }; (UInt16, $values: expr) => { - std::sync::Arc::new(arrow::array::UInt16Array::from($values)) + std::sync::Arc::new($crate::arrow::array::UInt16Array::from($values)) }; (UInt32, $values: expr) => { - std::sync::Arc::new(arrow::array::UInt32Array::from($values)) + std::sync::Arc::new($crate::arrow::array::UInt32Array::from($values)) }; (UInt64, 
$values: expr) => { - std::sync::Arc::new(arrow::array::UInt64Array::from($values)) + std::sync::Arc::new($crate::arrow::array::UInt64Array::from($values)) }; (Float16, $values: expr) => { - std::sync::Arc::new(arrow::array::Float16Array::from($values)) + std::sync::Arc::new($crate::arrow::array::Float16Array::from($values)) }; (Float32, $values: expr) => { - std::sync::Arc::new(arrow::array::Float32Array::from($values)) + std::sync::Arc::new($crate::arrow::array::Float32Array::from($values)) }; (Float64, $values: expr) => { - std::sync::Arc::new(arrow::array::Float64Array::from($values)) + std::sync::Arc::new($crate::arrow::array::Float64Array::from($values)) }; (Utf8, $values: expr) => { - std::sync::Arc::new(arrow::array::StringArray::from($values)) + std::sync::Arc::new($crate::arrow::array::StringArray::from($values)) }; } @@ -359,7 +366,7 @@ macro_rules! create_array { /// /// Example: /// ``` -/// use datafusion_common::{record_batch, create_array}; +/// use datafusion_common::record_batch; /// let batch = record_batch!( /// ("a", Int32, vec![1, 2, 3]), /// ("b", Float64, vec![Some(4.0), None, Some(5.0)]), @@ -370,13 +377,13 @@ macro_rules! create_array { macro_rules! 
record_batch { ($(($name: expr, $type: ident, $values: expr)),*) => { { - let schema = std::sync::Arc::new(arrow::datatypes::Schema::new(vec![ + let schema = std::sync::Arc::new($crate::arrow::datatypes::Schema::new(vec![ $( - arrow::datatypes::Field::new($name, arrow::datatypes::DataType::$type, true), + $crate::arrow::datatypes::Field::new($name, $crate::arrow::datatypes::DataType::$type, true), )* ])); - let batch = arrow::array::RecordBatch::try_new( + let batch = $crate::arrow::array::RecordBatch::try_new( schema, vec![$( $crate::create_array!($type, $values), @@ -728,32 +735,34 @@ mod tests { let non_existing = cwd.join("non-existing-dir").display().to_string(); let non_existing_str = non_existing.as_str(); - env::set_var(udf_env, non_existing_str); - let res = get_data_dir(udf_env, existing_str); - assert!(res.is_err()); + unsafe { + env::set_var(udf_env, non_existing_str); + let res = get_data_dir(udf_env, existing_str); + assert!(res.is_err()); - env::set_var(udf_env, ""); - let res = get_data_dir(udf_env, existing_str); - assert!(res.is_ok()); - assert_eq!(res.unwrap(), existing_pb); + env::set_var(udf_env, ""); + let res = get_data_dir(udf_env, existing_str); + assert!(res.is_ok()); + assert_eq!(res.unwrap(), existing_pb); - env::set_var(udf_env, " "); - let res = get_data_dir(udf_env, existing_str); - assert!(res.is_ok()); - assert_eq!(res.unwrap(), existing_pb); + env::set_var(udf_env, " "); + let res = get_data_dir(udf_env, existing_str); + assert!(res.is_ok()); + assert_eq!(res.unwrap(), existing_pb); - env::set_var(udf_env, existing_str); - let res = get_data_dir(udf_env, existing_str); - assert!(res.is_ok()); - assert_eq!(res.unwrap(), existing_pb); + env::set_var(udf_env, existing_str); + let res = get_data_dir(udf_env, existing_str); + assert!(res.is_ok()); + assert_eq!(res.unwrap(), existing_pb); - env::remove_var(udf_env); - let res = get_data_dir(udf_env, non_existing_str); - assert!(res.is_err()); + env::remove_var(udf_env); + let res = 
get_data_dir(udf_env, non_existing_str); + assert!(res.is_err()); - let res = get_data_dir(udf_env, existing_str); - assert!(res.is_ok()); - assert_eq!(res.unwrap(), existing_pb); + let res = get_data_dir(udf_env, existing_str); + assert!(res.is_ok()); + assert_eq!(res.unwrap(), existing_pb); + } } #[test] diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index cf51dadf6b4ad..1e7c02e424256 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -638,12 +638,13 @@ impl TreeNodeRecursion { /// # fn make_new_expr(i: i64) -> i64 { 2 } /// let expr = orig_expr(); /// let ret = Transformed::no(expr.clone()) -/// .transform_data(|expr| { -/// // closure returns a result and potentially transforms the node -/// // in this example, it does transform the node -/// let new_expr = make_new_expr(expr); -/// Ok(Transformed::yes(new_expr)) -/// }).unwrap(); +/// .transform_data(|expr| { +/// // closure returns a result and potentially transforms the node +/// // in this example, it does transform the node +/// let new_expr = make_new_expr(expr); +/// Ok(Transformed::yes(new_expr)) +/// }) +/// .unwrap(); /// // transformed flag is the union of the original ans closure's transformed flag /// assert!(ret.transformed); /// ``` @@ -680,6 +681,11 @@ impl Transformed { Self::new(data, true, TreeNodeRecursion::Continue) } + /// Wrapper for transformed data with [`TreeNodeRecursion::Stop`] statement. + pub fn complete(data: T) -> Self { + Self::new(data, true, TreeNodeRecursion::Stop) + } + /// Wrapper for unchanged data with [`TreeNodeRecursion::Continue`] statement. 
pub fn no(data: T) -> Self { Self::new(data, false, TreeNodeRecursion::Continue) @@ -950,12 +956,12 @@ impl<'a, T: 'a, C0: TreeNodeContainer<'a, T>, C1: TreeNodeContainer<'a, T>> } impl< - 'a, - T: 'a, - C0: TreeNodeContainer<'a, T>, - C1: TreeNodeContainer<'a, T>, - C2: TreeNodeContainer<'a, T>, - > TreeNodeContainer<'a, T> for (C0, C1, C2) + 'a, + T: 'a, + C0: TreeNodeContainer<'a, T>, + C1: TreeNodeContainer<'a, T>, + C2: TreeNodeContainer<'a, T>, +> TreeNodeContainer<'a, T> for (C0, C1, C2) { fn apply_elements Result>( &'a self, @@ -985,6 +991,48 @@ impl< } } +impl< + 'a, + T: 'a, + C0: TreeNodeContainer<'a, T>, + C1: TreeNodeContainer<'a, T>, + C2: TreeNodeContainer<'a, T>, + C3: TreeNodeContainer<'a, T>, +> TreeNodeContainer<'a, T> for (C0, C1, C2, C3) +{ + fn apply_elements Result>( + &'a self, + mut f: F, + ) -> Result { + self.0 + .apply_elements(&mut f)? + .visit_sibling(|| self.1.apply_elements(&mut f))? + .visit_sibling(|| self.2.apply_elements(&mut f))? + .visit_sibling(|| self.3.apply_elements(&mut f)) + } + + fn map_elements Result>>( + self, + mut f: F, + ) -> Result> { + self.0 + .map_elements(&mut f)? + .map_data(|new_c0| Ok((new_c0, self.1, self.2, self.3)))? + .transform_sibling(|(new_c0, c1, c2, c3)| { + c1.map_elements(&mut f)? + .map_data(|new_c1| Ok((new_c0, new_c1, c2, c3))) + })? + .transform_sibling(|(new_c0, new_c1, c2, c3)| { + c2.map_elements(&mut f)? + .map_data(|new_c2| Ok((new_c0, new_c1, new_c2, c3))) + })? + .transform_sibling(|(new_c0, new_c1, new_c2, c3)| { + c3.map_elements(&mut f)? + .map_data(|new_c3| Ok((new_c0, new_c1, new_c2, new_c3))) + }) + } +} + /// [`TreeNodeRefContainer`] contains references to elements that a function can be /// applied on. The elements of the container are siblings so the continuation rules are /// similar to [`TreeNodeRecursion::visit_sibling`]. 
@@ -1042,12 +1090,12 @@ impl<'a, T: 'a, C0: TreeNodeContainer<'a, T>, C1: TreeNodeContainer<'a, T>> } impl< - 'a, - T: 'a, - C0: TreeNodeContainer<'a, T>, - C1: TreeNodeContainer<'a, T>, - C2: TreeNodeContainer<'a, T>, - > TreeNodeRefContainer<'a, T> for (&'a C0, &'a C1, &'a C2) + 'a, + T: 'a, + C0: TreeNodeContainer<'a, T>, + C1: TreeNodeContainer<'a, T>, + C2: TreeNodeContainer<'a, T>, +> TreeNodeRefContainer<'a, T> for (&'a C0, &'a C1, &'a C2) { fn apply_ref_elements Result>( &self, @@ -1060,6 +1108,27 @@ impl< } } +impl< + 'a, + T: 'a, + C0: TreeNodeContainer<'a, T>, + C1: TreeNodeContainer<'a, T>, + C2: TreeNodeContainer<'a, T>, + C3: TreeNodeContainer<'a, T>, +> TreeNodeRefContainer<'a, T> for (&'a C0, &'a C1, &'a C2, &'a C3) +{ + fn apply_ref_elements Result>( + &self, + mut f: F, + ) -> Result { + self.0 + .apply_elements(&mut f)? + .visit_sibling(|| self.1.apply_elements(&mut f))? + .visit_sibling(|| self.2.apply_elements(&mut f))? + .visit_sibling(|| self.3.apply_elements(&mut f)) + } +} + /// Transformation helper to process a sequence of iterable tree nodes that are siblings. pub trait TreeNodeIterator: Iterator { /// Apples `f` to each item in this iterator @@ -1267,11 +1336,11 @@ pub(crate) mod tests { use std::collections::HashMap; use std::fmt::Display; + use crate::Result; use crate::tree_node::{ Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor, }; - use crate::Result; #[derive(Debug, Eq, Hash, PartialEq, Clone)] pub struct TestTreeNode { diff --git a/datafusion/common/src/types/builtin.rs b/datafusion/common/src/types/builtin.rs index ec69db7903779..dfd2cc4cf2d8b 100644 --- a/datafusion/common/src/types/builtin.rs +++ b/datafusion/common/src/types/builtin.rs @@ -15,9 +15,18 @@ // specific language governing permissions and limitations // under the License. 
+use arrow::datatypes::IntervalUnit::*; +use arrow::datatypes::TimeUnit::*; + use crate::types::{LogicalTypeRef, NativeType}; use std::sync::{Arc, LazyLock}; +/// Create a singleton and accompanying static variable for a [`LogicalTypeRef`] +/// of a [`NativeType`]. +/// * `name`: name of the static variable, must be unique. +/// * `getter`: name of the public function that will return the singleton instance +/// of the static variable. +/// * `ty`: the [`NativeType`]. macro_rules! singleton { ($name:ident, $getter:ident, $ty:ident) => { static $name: LazyLock = @@ -31,6 +40,26 @@ macro_rules! singleton { }; } +/// Similar to [`singleton`], but for native types that have variants, such as +/// `NativeType::Interval(MonthDayNano)`. +/// * `name`: name of the static variable, must be unique. +/// * `getter`: name of the public function that will return the singleton instance +/// of the static variable. +/// * `ty`: the [`NativeType`]. +/// * `variant`: specific variant of the `ty`. +macro_rules! 
singleton_variant { + ($name:ident, $getter:ident, $ty:ident, $variant:ident) => { + static $name: LazyLock = + LazyLock::new(|| Arc::new(NativeType::$ty($variant))); + + #[doc = "Getter for singleton instance of a logical type representing"] + #[doc = concat!("[`NativeType::", stringify!($ty), "`] of unit [`", stringify!($variant),"`].`")] + pub fn $getter() -> LogicalTypeRef { + Arc::clone(&$name) + } + }; +} + singleton!(LOGICAL_NULL, logical_null, Null); singleton!(LOGICAL_BOOLEAN, logical_boolean, Boolean); singleton!(LOGICAL_INT8, logical_int8, Int8); @@ -47,3 +76,24 @@ singleton!(LOGICAL_FLOAT64, logical_float64, Float64); singleton!(LOGICAL_DATE, logical_date, Date); singleton!(LOGICAL_BINARY, logical_binary, Binary); singleton!(LOGICAL_STRING, logical_string, String); + +singleton_variant!( + LOGICAL_INTERVAL_MDN, + logical_interval_mdn, + Interval, + MonthDayNano +); + +singleton_variant!( + LOGICAL_INTERVAL_YEAR_MONTH, + logical_interval_year_month, + Interval, + YearMonth +); + +singleton_variant!( + LOGICAL_DURATION_MICROSECOND, + logical_duration_microsecond, + Duration, + Microsecond +); diff --git a/datafusion/common/src/types/logical.rs b/datafusion/common/src/types/logical.rs index 884ce20fd9e29..0f886252d6452 100644 --- a/datafusion/common/src/types/logical.rs +++ b/datafusion/common/src/types/logical.rs @@ -67,12 +67,12 @@ pub type LogicalTypeRef = Arc; /// &NativeType::String /// } /// -/// fn signature(&self) -> TypeSignature<'_> { -/// TypeSignature::Extension { -/// name: "JSON", -/// parameters: &[], -/// } -/// } +/// fn signature(&self) -> TypeSignature<'_> { +/// TypeSignature::Extension { +/// name: "JSON", +/// parameters: &[], +/// } +/// } /// } /// ``` pub trait LogicalType: Sync + Send { @@ -100,12 +100,16 @@ impl fmt::Debug for dyn LogicalType { impl std::fmt::Display for dyn LogicalType { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{self:?}") + match self.signature() { + 
TypeSignature::Native(_) => write!(f, "{}", self.native()), + TypeSignature::Extension { name, .. } => write!(f, "{name}"), + } } } impl PartialEq for dyn LogicalType { fn eq(&self, other: &Self) -> bool { + // Logical types with identical signatures are considered equal. self.signature().eq(&other.signature()) } } @@ -120,15 +124,129 @@ impl PartialOrd for dyn LogicalType { impl Ord for dyn LogicalType { fn cmp(&self, other: &Self) -> Ordering { - self.signature() - .cmp(&other.signature()) - .then(self.native().cmp(other.native())) + // Logical types with identical signatures are considered equal. + self.signature().cmp(&other.signature()) } } impl Hash for dyn LogicalType { fn hash(&self, state: &mut H) { + // Logical types with identical signatures are considered equal. self.signature().hash(state); - self.native().hash(state); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::types::{ + LogicalField, LogicalFields, logical_boolean, logical_date, logical_float32, + logical_float64, logical_int32, logical_int64, logical_null, logical_string, + }; + use arrow::datatypes::{DataType, Field, Fields}; + use insta::assert_snapshot; + + #[test] + fn test_logical_type_display_simple() { + assert_snapshot!(logical_null(), @"Null"); + assert_snapshot!(logical_boolean(), @"Boolean"); + assert_snapshot!(logical_int32(), @"Int32"); + assert_snapshot!(logical_int64(), @"Int64"); + assert_snapshot!(logical_float32(), @"Float32"); + assert_snapshot!(logical_float64(), @"Float64"); + assert_snapshot!(logical_string(), @"String"); + assert_snapshot!(logical_date(), @"Date"); + } + + #[test] + fn test_logical_type_display_list() { + let list_type: Arc = Arc::new(NativeType::List(Arc::new( + LogicalField::from(&Field::new("item", DataType::Int32, true)), + ))); + assert_snapshot!(list_type, @"List(Int32)"); + } + + #[test] + fn test_logical_type_display_struct() { + let struct_type: Arc = Arc::new(NativeType::Struct( + LogicalFields::from(&Fields::from(vec![ + 
Field::new("x", DataType::Float64, false), + Field::new("y", DataType::Float64, true), + ])), + )); + assert_snapshot!(struct_type, @r#"Struct("x": non-null Float64, "y": Float64)"#); + } + + #[test] + fn test_logical_type_display_fixed_size_list() { + let fsl_type: Arc = Arc::new(NativeType::FixedSizeList( + Arc::new(LogicalField::from(&Field::new( + "item", + DataType::Float32, + false, + ))), + 3, + )); + assert_snapshot!(fsl_type, @"FixedSizeList(3 x non-null Float32)"); + } + + #[test] + fn test_logical_type_display_map() { + let map_type: Arc = Arc::new(NativeType::Map(Arc::new( + LogicalField::from(&Field::new("entries", DataType::Utf8, false)), + ))); + assert_snapshot!(map_type, @"Map(non-null String)"); + } + + #[test] + fn test_logical_type_display_union() { + use arrow::datatypes::UnionFields; + + let union_fields = UnionFields::try_new( + vec![0, 1], + vec![ + Field::new("int_val", DataType::Int32, false), + Field::new("str_val", DataType::Utf8, true), + ], + ) + .unwrap(); + let union_type: Arc = Arc::new(NativeType::Union( + crate::types::LogicalUnionFields::from(&union_fields), + )); + assert_snapshot!(union_type, @r#"Union(0: ("int_val": non-null Int32), 1: ("str_val": String))"#); + } + + #[test] + fn test_logical_type_display_nullable_vs_non_nullable() { + let nullable_list: Arc = Arc::new(NativeType::List(Arc::new( + LogicalField::from(&Field::new("item", DataType::Int32, true)), + ))); + let non_nullable_list: Arc = + Arc::new(NativeType::List(Arc::new(LogicalField::from(&Field::new( + "item", + DataType::Int32, + false, + ))))); + + assert_snapshot!(nullable_list, @"List(Int32)"); + assert_snapshot!(non_nullable_list, @"List(non-null Int32)"); + } + + #[test] + fn test_logical_type_display_extension() { + struct JsonType; + impl LogicalType for JsonType { + fn native(&self) -> &NativeType { + &NativeType::String + } + fn signature(&self) -> TypeSignature<'_> { + TypeSignature::Extension { + name: "JSON", + parameters: &[], + } + } + } + let 
json: Arc = Arc::new(JsonType); + assert_snapshot!(json, @"JSON"); } } diff --git a/datafusion/common/src/types/native.rs b/datafusion/common/src/types/native.rs index 39c79b4b99742..a4202db986bbf 100644 --- a/datafusion/common/src/types/native.rs +++ b/datafusion/common/src/types/native.rs @@ -19,10 +19,11 @@ use super::{ LogicalField, LogicalFieldRef, LogicalFields, LogicalType, LogicalUnionFields, TypeSignature, }; -use crate::error::{Result, _internal_err}; +use crate::error::{_internal_err, Result}; use arrow::compute::can_cast_types; use arrow::datatypes::{ - DataType, Field, FieldRef, Fields, IntervalUnit, TimeUnit, UnionFields, + DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION, DECIMAL128_MAX_PRECISION, DataType, + Field, FieldRef, Fields, IntervalUnit, TimeUnit, UnionFields, }; use std::{fmt::Display, sync::Arc}; @@ -183,9 +184,82 @@ pub enum NativeType { Map(LogicalFieldRef), } +/// Format a [`LogicalField`] for display, matching [`arrow::datatypes::DataType`]'s +/// Display convention of showing a `"non-null "` prefix for non-nullable fields. 
+fn format_logical_field( + f: &mut std::fmt::Formatter<'_>, + field: &LogicalField, +) -> std::fmt::Result { + let non_null = if field.nullable { "" } else { "non-null " }; + write!(f, "{:?}: {non_null}{}", field.name, field.logical_type) +} + impl Display for NativeType { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "NativeType::{self:?}") + // Match the format used by arrow::datatypes::DataType's Display impl + match self { + Self::Null => write!(f, "Null"), + Self::Boolean => write!(f, "Boolean"), + Self::Int8 => write!(f, "Int8"), + Self::Int16 => write!(f, "Int16"), + Self::Int32 => write!(f, "Int32"), + Self::Int64 => write!(f, "Int64"), + Self::UInt8 => write!(f, "UInt8"), + Self::UInt16 => write!(f, "UInt16"), + Self::UInt32 => write!(f, "UInt32"), + Self::UInt64 => write!(f, "UInt64"), + Self::Float16 => write!(f, "Float16"), + Self::Float32 => write!(f, "Float32"), + Self::Float64 => write!(f, "Float64"), + Self::Timestamp(unit, Some(tz)) => write!(f, "Timestamp({unit}, {tz:?})"), + Self::Timestamp(unit, None) => write!(f, "Timestamp({unit})"), + Self::Date => write!(f, "Date"), + Self::Time(unit) => write!(f, "Time({unit})"), + Self::Duration(unit) => write!(f, "Duration({unit})"), + Self::Interval(unit) => write!(f, "Interval({unit:?})"), + Self::Binary => write!(f, "Binary"), + Self::FixedSizeBinary(size) => write!(f, "FixedSizeBinary({size})"), + Self::String => write!(f, "String"), + Self::List(field) => { + let non_null = if field.nullable { "" } else { "non-null " }; + write!(f, "List({non_null}{})", field.logical_type) + } + Self::FixedSizeList(field, size) => { + let non_null = if field.nullable { "" } else { "non-null " }; + write!( + f, + "FixedSizeList({size} x {non_null}{})", + field.logical_type + ) + } + Self::Struct(fields) => { + write!(f, "Struct(")?; + for (i, field) in fields.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + format_logical_field(f, field)?; + } + write!(f, ")") + } + 
Self::Union(fields) => { + write!(f, "Union(")?; + for (i, (type_id, field)) in fields.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{type_id}: (")?; + format_logical_field(f, field)?; + write!(f, ")")?; + } + write!(f, ")") + } + Self::Decimal(precision, scale) => write!(f, "Decimal({precision}, {scale})"), + Self::Map(field) => { + let non_null = if field.nullable { "" } else { "non-null " }; + write!(f, "Map({non_null}{})", field.logical_type) + } + } } } @@ -228,13 +302,19 @@ impl LogicalType for NativeType { (Self::Float16, _) => Float16, (Self::Float32, _) => Float32, (Self::Float64, _) => Float64, - (Self::Decimal(p, s), _) if p <= &38 => Decimal128(*p, *s), + (Self::Decimal(p, s), _) if *p <= DECIMAL32_MAX_PRECISION => { + Decimal32(*p, *s) + } + (Self::Decimal(p, s), _) if *p <= DECIMAL64_MAX_PRECISION => { + Decimal64(*p, *s) + } + (Self::Decimal(p, s), _) if *p <= DECIMAL128_MAX_PRECISION => { + Decimal128(*p, *s) + } (Self::Decimal(p, s), _) => Decimal256(*p, *s), (Self::Timestamp(tu, tz), _) => Timestamp(*tu, tz.clone()), // If given type is Date, return the same type - (Self::Date, origin) if matches!(origin, Date32 | Date64) => { - origin.to_owned() - } + (Self::Date, Date32 | Date64) => origin.to_owned(), (Self::Date, _) => Date32, (Self::Time(tu), _) => match tu { TimeUnit::Second | TimeUnit::Millisecond => Time32(*tu), @@ -244,6 +324,8 @@ impl LogicalType for NativeType { (Self::Interval(iu), _) => Interval(*iu), (Self::Binary, LargeUtf8) => LargeBinary, (Self::Binary, Utf8View) => BinaryView, + // We don't cast to another kind of binary type if the origin one is already a binary type + (Self::Binary, Binary | LargeBinary | BinaryView) => origin.to_owned(), (Self::Binary, data_type) if can_cast_types(data_type, &BinaryView) => { BinaryView } @@ -352,10 +434,10 @@ impl LogicalType for NativeType { } _ => { return _internal_err!( - "Unavailable default cast for native type {:?} from physical type {:?}", - self, - origin - ) 
+ "Unavailable default cast for native type {} from physical type {}", + self, + origin + ); } }) } @@ -407,7 +489,10 @@ impl From for NativeType { DataType::Union(union_fields, _) => { Union(LogicalUnionFields::from(&union_fields)) } - DataType::Decimal128(p, s) | DataType::Decimal256(p, s) => Decimal(p, s), + DataType::Decimal32(p, s) + | DataType::Decimal64(p, s) + | DataType::Decimal128(p, s) + | DataType::Decimal256(p, s) => Decimal(p, s), DataType::Map(field, _) => Map(Arc::new(field.as_ref().into())), DataType::Dictionary(_, data_type) => data_type.as_ref().clone().into(), DataType::RunEndEncoded(_, field) => field.data_type().clone().into(), @@ -418,22 +503,7 @@ impl From for NativeType { impl NativeType { #[inline] pub fn is_numeric(&self) -> bool { - use NativeType::*; - matches!( - self, - UInt8 - | UInt16 - | UInt32 - | UInt64 - | Int8 - | Int16 - | Int32 - | Int64 - | Float16 - | Float32 - | Float64 - | Decimal(_, _) - ) + self.is_integer() || self.is_float() || self.is_decimal() } #[inline] @@ -452,7 +522,7 @@ impl NativeType { #[inline] pub fn is_date(&self) -> bool { - matches!(self, NativeType::Date) + *self == NativeType::Date } #[inline] @@ -469,4 +539,111 @@ impl NativeType { pub fn is_duration(&self) -> bool { matches!(self, NativeType::Duration(_)) } + + #[inline] + pub fn is_binary(&self) -> bool { + matches!(self, NativeType::Binary | NativeType::FixedSizeBinary(_)) + } + + #[inline] + pub fn is_null(&self) -> bool { + *self == NativeType::Null + } + + #[inline] + pub fn is_decimal(&self) -> bool { + matches!(self, Self::Decimal(_, _)) + } + + #[inline] + pub fn is_float(&self) -> bool { + matches!(self, Self::Float16 | Self::Float32 | Self::Float64) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::types::LogicalField; + use arrow::datatypes::Field; + use insta::assert_snapshot; + + #[test] + fn test_native_type_display() { + assert_snapshot!(NativeType::Null, @"Null"); + assert_snapshot!(NativeType::Boolean, @"Boolean"); + 
assert_snapshot!(NativeType::Int8, @"Int8"); + assert_snapshot!(NativeType::Int16, @"Int16"); + assert_snapshot!(NativeType::Int32, @"Int32"); + assert_snapshot!(NativeType::Int64, @"Int64"); + assert_snapshot!(NativeType::UInt8, @"UInt8"); + assert_snapshot!(NativeType::UInt16, @"UInt16"); + assert_snapshot!(NativeType::UInt32, @"UInt32"); + assert_snapshot!(NativeType::UInt64, @"UInt64"); + assert_snapshot!(NativeType::Float16, @"Float16"); + assert_snapshot!(NativeType::Float32, @"Float32"); + assert_snapshot!(NativeType::Float64, @"Float64"); + assert_snapshot!(NativeType::Date, @"Date"); + assert_snapshot!(NativeType::Binary, @"Binary"); + assert_snapshot!(NativeType::String, @"String"); + assert_snapshot!(NativeType::FixedSizeBinary(16), @"FixedSizeBinary(16)"); + assert_snapshot!(NativeType::Decimal(10, 2), @"Decimal(10, 2)"); + } + + #[test] + fn test_native_type_display_timestamp() { + assert_snapshot!( + NativeType::Timestamp(TimeUnit::Second, None), + @"Timestamp(s)" + ); + assert_snapshot!( + NativeType::Timestamp(TimeUnit::Millisecond, None), + @"Timestamp(ms)" + ); + assert_snapshot!( + NativeType::Timestamp(TimeUnit::Nanosecond, Some(Arc::from("UTC"))), + @r#"Timestamp(ns, "UTC")"# + ); + } + + #[test] + fn test_native_type_display_time_duration_interval() { + assert_snapshot!(NativeType::Time(TimeUnit::Microsecond), @"Time(µs)"); + assert_snapshot!(NativeType::Duration(TimeUnit::Nanosecond), @"Duration(ns)"); + assert_snapshot!(NativeType::Interval(IntervalUnit::YearMonth), @"Interval(YearMonth)"); + assert_snapshot!(NativeType::Interval(IntervalUnit::MonthDayNano), @"Interval(MonthDayNano)"); + } + + #[test] + fn test_native_type_display_nested() { + let list = NativeType::List(Arc::new(LogicalField::from(&Field::new( + "item", + DataType::Int32, + true, + )))); + assert_snapshot!(list, @"List(Int32)"); + + let fixed_list = NativeType::FixedSizeList( + Arc::new(LogicalField::from(&Field::new( + "item", + DataType::Float64, + false, + ))), + 3, + ); 
+ assert_snapshot!(fixed_list, @"FixedSizeList(3 x non-null Float64)"); + + let struct_type = NativeType::Struct(LogicalFields::from(&Fields::from(vec![ + Field::new("name", DataType::Utf8, false), + Field::new("age", DataType::Int32, true), + ]))); + assert_snapshot!(struct_type, @r#"Struct("name": non-null String, "age": Int32)"#); + + let map = NativeType::Map(Arc::new(LogicalField::from(&Field::new( + "entries", + DataType::Utf8, + false, + )))); + assert_snapshot!(map, @"Map(non-null String)"); + } } diff --git a/datafusion/common/src/utils/aggregate.rs b/datafusion/common/src/utils/aggregate.rs new file mode 100644 index 0000000000000..43bc0676b2d3c --- /dev/null +++ b/datafusion/common/src/utils/aggregate.rs @@ -0,0 +1,149 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Scalar-level aggregation utilities for statistics merging. +//! +//! Provides a cheap pairwise [`ScalarValue`] addition that directly +//! extracts inner primitive values, avoiding the expensive +//! `ScalarValue::add` path (which round-trips through Arrow arrays). 
+use arrow::datatypes::i256; + +use crate::stats::Precision; +use crate::{Result, ScalarValue}; + +/// Saturating addition for [`i256`] (which lacks a built-in +/// `saturating_add`). Returns `i256::MAX` on positive overflow and +/// `i256::MIN` on negative overflow. +#[inline] +fn i256_saturating_add(a: i256, b: i256) -> i256 { + match a.checked_add(b) { + Some(sum) => sum, + None => { + // If b is non-negative the overflow is positive, otherwise + // negative. + if b >= i256::ZERO { + i256::MAX + } else { + i256::MIN + } + } + } +} + +/// Add two [`ScalarValue`]s by directly extracting and adding their +/// inner primitive values. +/// +/// This avoids `ScalarValue::add` which converts both operands to +/// single-element Arrow arrays, runs the `add_wrapping` kernel, and +/// converts the result back — 3 heap allocations per call. +/// +/// For non-primitive types, falls back to `ScalarValue::add`. +pub(crate) fn scalar_add(lhs: &ScalarValue, rhs: &ScalarValue) -> Result { + macro_rules! add_int { + ($lhs:expr, $rhs:expr, $VARIANT:ident) => { + match ($lhs, $rhs) { + (ScalarValue::$VARIANT(Some(a)), ScalarValue::$VARIANT(Some(b))) => { + Ok(ScalarValue::$VARIANT(Some(a.saturating_add(*b)))) + } + (ScalarValue::$VARIANT(None), other) + | (other, ScalarValue::$VARIANT(None)) => Ok(other.clone()), + _ => unreachable!(), + } + }; + } + + macro_rules! add_decimal { + ($lhs:expr, $rhs:expr, $VARIANT:ident) => { + match ($lhs, $rhs) { + ( + ScalarValue::$VARIANT(Some(a), p, s), + ScalarValue::$VARIANT(Some(b), _, _), + ) => Ok(ScalarValue::$VARIANT(Some(a.saturating_add(*b)), *p, *s)), + (ScalarValue::$VARIANT(None, _, _), other) + | (other, ScalarValue::$VARIANT(None, _, _)) => Ok(other.clone()), + _ => unreachable!(), + } + }; + } + + macro_rules! 
add_float { + ($lhs:expr, $rhs:expr, $VARIANT:ident) => { + match ($lhs, $rhs) { + (ScalarValue::$VARIANT(Some(a)), ScalarValue::$VARIANT(Some(b))) => { + Ok(ScalarValue::$VARIANT(Some(*a + *b))) + } + (ScalarValue::$VARIANT(None), other) + | (other, ScalarValue::$VARIANT(None)) => Ok(other.clone()), + _ => unreachable!(), + } + }; + } + + match lhs { + ScalarValue::Int8(_) => add_int!(lhs, rhs, Int8), + ScalarValue::Int16(_) => add_int!(lhs, rhs, Int16), + ScalarValue::Int32(_) => add_int!(lhs, rhs, Int32), + ScalarValue::Int64(_) => add_int!(lhs, rhs, Int64), + ScalarValue::UInt8(_) => add_int!(lhs, rhs, UInt8), + ScalarValue::UInt16(_) => add_int!(lhs, rhs, UInt16), + ScalarValue::UInt32(_) => add_int!(lhs, rhs, UInt32), + ScalarValue::UInt64(_) => add_int!(lhs, rhs, UInt64), + ScalarValue::Float16(_) => add_float!(lhs, rhs, Float16), + ScalarValue::Float32(_) => add_float!(lhs, rhs, Float32), + ScalarValue::Float64(_) => add_float!(lhs, rhs, Float64), + ScalarValue::Decimal32(_, _, _) => add_decimal!(lhs, rhs, Decimal32), + ScalarValue::Decimal64(_, _, _) => add_decimal!(lhs, rhs, Decimal64), + ScalarValue::Decimal128(_, _, _) => add_decimal!(lhs, rhs, Decimal128), + ScalarValue::Decimal256(_, _, _) => match (lhs, rhs) { + ( + ScalarValue::Decimal256(Some(a), p, s), + ScalarValue::Decimal256(Some(b), _, _), + ) => Ok(ScalarValue::Decimal256( + Some(i256_saturating_add(*a, *b)), + *p, + *s, + )), + (ScalarValue::Decimal256(None, _, _), other) + | (other, ScalarValue::Decimal256(None, _, _)) => Ok(other.clone()), + _ => unreachable!(), + }, + // Fallback: use the existing ScalarValue::add + _ => lhs.add(rhs), + } +} + +/// [`Precision`]-aware sum of two [`ScalarValue`] precisions using +/// cheap direct addition via [`scalar_add`]. +/// +/// Mirrors the semantics of `Precision::add` but avoids +/// the expensive `ScalarValue::add` round-trip through Arrow arrays. 
+pub(crate) fn precision_add( + lhs: &Precision, + rhs: &Precision, +) -> Precision { + match (lhs, rhs) { + (Precision::Exact(a), Precision::Exact(b)) => scalar_add(a, b) + .map(Precision::Exact) + .unwrap_or(Precision::Absent), + (Precision::Inexact(a), Precision::Exact(b)) + | (Precision::Exact(a), Precision::Inexact(b)) + | (Precision::Inexact(a), Precision::Inexact(b)) => scalar_add(a, b) + .map(Precision::Inexact) + .unwrap_or(Precision::Absent), + (_, _) => Precision::Absent, + } +} diff --git a/datafusion/common/src/utils/memory.rs b/datafusion/common/src/utils/memory.rs index 7ac081e0beb84..78ec434d2b577 100644 --- a/datafusion/common/src/utils/memory.rs +++ b/datafusion/common/src/utils/memory.rs @@ -17,8 +17,11 @@ //! This module provides a function to estimate the memory size of a HashTable prior to allocation -use crate::{DataFusionError, Result}; -use std::mem::size_of; +use crate::error::_exec_datafusion_err; +use crate::{HashSet, Result}; +use arrow::array::ArrayData; +use arrow::record_batch::RecordBatch; +use std::{mem::size_of, ptr::NonNull}; /// Estimates the memory size required for a hash table prior to allocation. /// @@ -36,7 +39,7 @@ use std::mem::size_of; /// buckets. /// - One byte overhead for each bucket. /// - The fixed size overhead of the collection. 
-/// - If the estimation overflows, we return a [`DataFusionError`] +/// - If the estimation overflows, we return a [`crate::error::DataFusionError`] /// /// # Examples /// --- @@ -55,8 +58,8 @@ use std::mem::size_of; /// impl MyStruct { /// fn size(&self) -> Result { /// let num_elements = self.values.len(); -/// let fixed_size = std::mem::size_of_val(self) + -/// std::mem::size_of_val(&self.values); +/// let fixed_size = +/// std::mem::size_of_val(self) + std::mem::size_of_val(&self.values); /// /// estimate_memory_size::(num_elements, fixed_size) /// } @@ -72,8 +75,8 @@ use std::mem::size_of; /// let num_rows = 100; /// let fixed_size = std::mem::size_of::>(); /// let estimated_hashtable_size = -/// estimate_memory_size::<(u64, u64)>(num_rows,fixed_size) -/// .expect("Size estimation failed"); +/// estimate_memory_size::<(u64, u64)>(num_rows, fixed_size) +/// .expect("Size estimation failed"); /// ``` pub fn estimate_memory_size(num_elements: usize, fixed_size: usize) -> Result { // For the majority of cases hashbrown overestimates the bucket quantity @@ -94,12 +97,78 @@ pub fn estimate_memory_size(num_elements: usize, fixed_size: usize) -> Result .checked_add(fixed_size) }) .ok_or_else(|| { - DataFusionError::Execution( - "usize overflow while estimating the number of buckets".to_string(), - ) + _exec_datafusion_err!("usize overflow while estimating the number of buckets") }) } +/// Calculate total used memory of this batch. +/// +/// This function is used to estimate the physical memory usage of the `RecordBatch`. +/// It only counts the memory of large data `Buffer`s, and ignores metadata like +/// types and pointers. +/// The implementation will add up all unique `Buffer`'s memory +/// size, due to: +/// - The data pointer inside `Buffer` are memory regions returned by global memory +/// allocator, those regions can't have overlap. +/// - The actual used range of `ArrayRef`s inside `RecordBatch` can have overlap +/// or reuse the same `Buffer`. 
For example: taking a slice from `Array`. +/// +/// Example: +/// For a `RecordBatch` with two columns: `col1` and `col2`, two columns are pointing +/// to a sub-region of the same buffer. +/// +/// {xxxxxxxxxxxxxxxxxxx} <--- buffer +/// ^ ^ ^ ^ +/// | | | | +/// col1->{ } | | +/// col2--------->{ } +/// +/// In the above case, `get_record_batch_memory_size` will return the size of +/// the buffer, instead of the sum of `col1` and `col2`'s actual memory size. +/// +/// Note: Current `RecordBatch`.get_array_memory_size()` will double count the +/// buffer memory size if multiple arrays within the batch are sharing the same +/// `Buffer`. This method provides temporary fix until the issue is resolved: +/// +pub fn get_record_batch_memory_size(batch: &RecordBatch) -> usize { + // Store pointers to `Buffer`'s start memory address (instead of actual + // used data region's pointer represented by current `Array`) + let mut counted_buffers: HashSet> = HashSet::new(); + let mut total_size = 0; + + for array in batch.columns() { + let array_data = array.to_data(); + count_array_data_memory_size(&array_data, &mut counted_buffers, &mut total_size); + } + + total_size +} + +/// Count the memory usage of `array_data` and its children recursively. 
+fn count_array_data_memory_size( + array_data: &ArrayData, + counted_buffers: &mut HashSet>, + total_size: &mut usize, +) { + // Count memory usage for `array_data` + for buffer in array_data.buffers() { + if counted_buffers.insert(buffer.data_ptr()) { + *total_size += buffer.capacity(); + } // Otherwise the buffer's memory is already counted + } + + if let Some(null_buffer) = array_data.nulls() + && counted_buffers.insert(null_buffer.inner().inner().data_ptr()) + { + *total_size += null_buffer.inner().inner().capacity(); + } + + // Count all children `ArrayData` recursively + for child in array_data.child_data() { + count_array_data_memory_size(child, counted_buffers, total_size); + } +} + #[cfg(test)] mod tests { use std::{collections::HashSet, mem::size_of}; @@ -133,3 +202,129 @@ mod tests { assert!(estimated.is_err()); } } + +#[cfg(test)] +mod record_batch_tests { + use super::*; + use arrow::array::{Float64Array, Int32Array, ListArray}; + use arrow::datatypes::{DataType, Field, Int32Type, Schema}; + use std::sync::Arc; + + #[test] + fn test_get_record_batch_memory_size() { + let schema = Arc::new(Schema::new(vec![ + Field::new("ints", DataType::Int32, true), + Field::new("float64", DataType::Float64, false), + ])); + + let int_array = + Int32Array::from(vec![Some(1), Some(2), Some(3), Some(4), Some(5)]); + let float64_array = Float64Array::from(vec![1.0, 2.0, 3.0, 4.0, 5.0]); + + let batch = RecordBatch::try_new( + schema, + vec![Arc::new(int_array), Arc::new(float64_array)], + ) + .unwrap(); + + let size = get_record_batch_memory_size(&batch); + assert_eq!(size, 60); + } + + #[test] + fn test_get_record_batch_memory_size_with_null() { + let schema = Arc::new(Schema::new(vec![ + Field::new("ints", DataType::Int32, true), + Field::new("float64", DataType::Float64, false), + ])); + + let int_array = Int32Array::from(vec![None, Some(2), Some(3)]); + let float64_array = Float64Array::from(vec![1.0, 2.0, 3.0]); + + let batch = RecordBatch::try_new( + schema, + 
vec![Arc::new(int_array), Arc::new(float64_array)], + ) + .unwrap(); + + let size = get_record_batch_memory_size(&batch); + assert_eq!(size, 100); + } + + #[test] + fn test_get_record_batch_memory_size_empty() { + let schema = Arc::new(Schema::new(vec![Field::new( + "ints", + DataType::Int32, + false, + )])); + + let int_array: Int32Array = Int32Array::from(vec![] as Vec); + let batch = RecordBatch::try_new(schema, vec![Arc::new(int_array)]).unwrap(); + + let size = get_record_batch_memory_size(&batch); + assert_eq!(size, 0, "Empty batch should have 0 memory size"); + } + + #[test] + fn test_get_record_batch_memory_size_shared_buffer() { + let original = Int32Array::from(vec![1, 2, 3, 4, 5]); + let slice1 = original.slice(0, 3); + let slice2 = original.slice(2, 3); + + let schema_origin = Arc::new(Schema::new(vec![Field::new( + "origin_col", + DataType::Int32, + false, + )])); + let batch_origin = + RecordBatch::try_new(schema_origin, vec![Arc::new(original)]).unwrap(); + + let schema = Arc::new(Schema::new(vec![ + Field::new("slice1", DataType::Int32, false), + Field::new("slice2", DataType::Int32, false), + ])); + + let batch_sliced = + RecordBatch::try_new(schema, vec![Arc::new(slice1), Arc::new(slice2)]) + .unwrap(); + + let size_origin = get_record_batch_memory_size(&batch_origin); + let size_sliced = get_record_batch_memory_size(&batch_sliced); + + assert_eq!(size_origin, size_sliced); + } + + #[test] + fn test_get_record_batch_memory_size_nested_array() { + let schema = Arc::new(Schema::new(vec![ + Field::new( + "nested_int", + DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))), + false, + ), + Field::new( + "nested_int2", + DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))), + false, + ), + ])); + + let int_list_array = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2), Some(3)]), + ]); + + let int_list_array2 = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(4), Some(5), Some(6)]), + ]); + + 
let batch = RecordBatch::try_new( + schema, + vec![Arc::new(int_list_array), Arc::new(int_list_array2)], + ) + .unwrap(); + + let size = get_record_batch_memory_size(&batch); + assert_eq!(size, 8208); + } +} diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs index 409f248621f7f..075a189c371dc 100644 --- a/datafusion/common/src/utils/mod.rs +++ b/datafusion/common/src/utils/mod.rs @@ -17,25 +17,26 @@ //! This module provides the bisect function, which implements binary search. +pub(crate) mod aggregate; pub mod expr; pub mod memory; pub mod proxy; pub mod string_utils; -use crate::error::{_exec_datafusion_err, _internal_datafusion_err, _internal_err}; -use crate::{DataFusionError, Result, ScalarValue}; +use crate::assert_or_internal_err; +use crate::error::{_exec_datafusion_err, _internal_datafusion_err}; +use crate::{Result, ScalarValue}; use arrow::array::{ - cast::AsArray, Array, ArrayRef, FixedSizeListArray, LargeListArray, ListArray, - OffsetSizeTrait, + Array, ArrayRef, FixedSizeListArray, LargeListArray, ListArray, OffsetSizeTrait, + cast::AsArray, }; use arrow::buffer::OffsetBuffer; -use arrow::compute::{partition, SortColumn, SortOptions}; +use arrow::compute::{SortColumn, SortOptions, partition}; use arrow::datatypes::{DataType, Field, SchemaRef}; -use sqlparser::ast::Ident; -use sqlparser::dialect::GenericDialect; -use sqlparser::parser::Parser; +#[cfg(feature = "sql")] +use sqlparser::{ast::Ident, dialect::GenericDialect, parser::Parser}; use std::borrow::{Borrow, Cow}; -use std::cmp::{min, Ordering}; +use std::cmp::{Ordering, min}; use std::collections::HashSet; use std::num::NonZero; use std::ops::Range; @@ -47,36 +48,33 @@ use std::thread::available_parallelism; /// /// Example: /// ``` -/// use arrow::datatypes::{SchemaRef, Schema, Field, DataType}; +/// use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; /// use datafusion_common::project_schema; /// /// // Schema with columns 'a', 'b', and 'c' /// let 
schema = SchemaRef::new(Schema::new(vec![ -/// Field::new("a", DataType::Int32, true), -/// Field::new("b", DataType::Int64, true), -/// Field::new("c", DataType::Utf8, true), +/// Field::new("a", DataType::Int32, true), +/// Field::new("b", DataType::Int64, true), +/// Field::new("c", DataType::Utf8, true), /// ])); /// /// // Pick columns 'c' and 'b' -/// let projection = Some(vec![2,1]); -/// let projected_schema = project_schema( -/// &schema, -/// projection.as_ref() -/// ).unwrap(); +/// let projection = Some(vec![2, 1]); +/// let projected_schema = project_schema(&schema, projection.as_ref()).unwrap(); /// /// let expected_schema = SchemaRef::new(Schema::new(vec![ -/// Field::new("c", DataType::Utf8, true), -/// Field::new("b", DataType::Int64, true), +/// Field::new("c", DataType::Utf8, true), +/// Field::new("b", DataType::Int64, true), /// ])); /// /// assert_eq!(projected_schema, expected_schema); /// ``` pub fn project_schema( schema: &SchemaRef, - projection: Option<&Vec>, + projection: Option<&impl AsRef<[usize]>>, ) -> Result { let schema = match projection { - Some(columns) => Arc::new(schema.project(columns)?), + Some(columns) => Arc::new(schema.project(columns.as_ref())?), None => Arc::clone(schema), }; Ok(schema) @@ -120,14 +118,13 @@ pub fn compare_rows( let result = match (lhs.is_null(), rhs.is_null(), sort_options.nulls_first) { (true, false, false) | (false, true, true) => Ordering::Greater, (true, false, true) | (false, true, false) => Ordering::Less, - (false, false, _) => if sort_options.descending { - rhs.partial_cmp(lhs) - } else { - lhs.partial_cmp(rhs) + (false, false, _) => { + if sort_options.descending { + rhs.try_cmp(lhs)? + } else { + lhs.try_cmp(rhs)? 
+ } } - .ok_or_else(|| { - _internal_datafusion_err!("Column array shouldn't be empty") - })?, (true, true, _) => continue, }; if result != Ordering::Equal { @@ -149,9 +146,7 @@ pub fn bisect( let low: usize = 0; let high: usize = item_columns .first() - .ok_or_else(|| { - DataFusionError::Internal("Column array shouldn't be empty".to_string()) - })? + .ok_or_else(|| _internal_datafusion_err!("Column array shouldn't be empty"))? .len(); let compare_fn = |current: &[ScalarValue], target: &[ScalarValue]| { let cmp = compare_rows(current, target, sort_options)?; @@ -200,9 +195,7 @@ pub fn linear_search( let low: usize = 0; let high: usize = item_columns .first() - .ok_or_else(|| { - DataFusionError::Internal("Column array shouldn't be empty".to_string()) - })? + .ok_or_else(|| _internal_datafusion_err!("Column array shouldn't be empty"))? .len(); let compare_fn = |current: &[ScalarValue], target: &[ScalarValue]| { let cmp = compare_rows(current, target, sort_options)?; @@ -261,7 +254,7 @@ pub fn evaluate_partition_ranges( /// the identifier by replacing it with two double quotes /// /// e.g. 
identifier `tab.le"name` becomes `"tab.le""name"` -pub fn quote_identifier(s: &str) -> Cow { +pub fn quote_identifier(s: &str) -> Cow<'_, str> { if needs_quotes(s) { Cow::Owned(format!("\"{}\"", s.replace('"', "\"\""))) } else { @@ -274,15 +267,16 @@ fn needs_quotes(s: &str) -> bool { let mut chars = s.chars(); // first char can not be a number unless escaped - if let Some(first_char) = chars.next() { - if !(first_char.is_ascii_lowercase() || first_char == '_') { - return true; - } + if let Some(first_char) = chars.next() + && !(first_char.is_ascii_lowercase() || first_char == '_') + { + return true; } !chars.all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '_') } +#[cfg(feature = "sql")] pub(crate) fn parse_identifiers(s: &str) -> Result> { let dialect = GenericDialect; let mut parser = Parser::new(&dialect).try_with_sql(s)?; @@ -290,6 +284,10 @@ pub(crate) fn parse_identifiers(s: &str) -> Result> { Ok(idents) } +/// Parse a string into a vector of identifiers. +/// +/// Note: If ignore_case is false, the string will be normalized to lowercase. +#[cfg(feature = "sql")] pub(crate) fn parse_identifiers_normalized(s: &str, ignore_case: bool) -> Vec { parse_identifiers(s) .unwrap_or_default() @@ -302,6 +300,59 @@ pub(crate) fn parse_identifiers_normalized(s: &str, ignore_case: bool) -> Vec>() } +#[cfg(not(feature = "sql"))] +pub(crate) fn parse_identifiers(s: &str) -> Result> { + let mut result = Vec::new(); + let mut current = String::new(); + let mut in_quotes = false; + + for ch in s.chars() { + match ch { + '"' => { + in_quotes = !in_quotes; + current.push(ch); + } + '.' 
if !in_quotes => { + result.push(current.clone()); + current.clear(); + } + _ => { + current.push(ch); + } + } + } + + // Push the last part if it's not empty + if !current.is_empty() { + result.push(current); + } + + Ok(result) +} + +#[cfg(not(feature = "sql"))] +pub(crate) fn parse_identifiers_normalized(s: &str, ignore_case: bool) -> Vec { + parse_identifiers(s) + .unwrap_or_default() + .into_iter() + .map(|id| { + let is_double_quoted = if id.len() > 2 { + let mut chars = id.chars(); + chars.next() == Some('"') && chars.last() == Some('"') + } else { + false + }; + if is_double_quoted { + id[1..id.len() - 1].to_string().replace("\"\"", "\"") + } else if ignore_case { + id + } else { + id.to_ascii_lowercase() + } + }) + .collect::>() +} + /// This function "takes" the elements at `indices` from the slice `items`. pub fn get_at_indices>( items: &[T], @@ -312,9 +363,7 @@ pub fn get_at_indices>( .map(|idx| items.get(*idx.borrow()).cloned()) .collect::>>() .ok_or_else(|| { - DataFusionError::Execution( - "Expects indices to be in the range of searched vector".to_string(), - ) + _exec_datafusion_err!("Expects indices to be in the range of searched vector") }) } @@ -348,9 +397,11 @@ pub fn longest_consecutive_prefix>( /// # use arrow::array::types::Int64Type; /// # use datafusion_common::utils::SingleRowListArrayBuilder; /// // Array is [1, 2, 3] -/// let arr = ListArray::from_iter_primitive::(vec![ -/// Some(vec![Some(1), Some(2), Some(3)]), -/// ]); +/// let arr = ListArray::from_iter_primitive::(vec![Some(vec![ +/// Some(1), +/// Some(2), +/// Some(3), +/// ])]); /// // Wrap as a list array: [[1, 2, 3]] /// let list_arr = SingleRowListArrayBuilder::new(Arc::new(arr)).build_list_array(); /// assert_eq!(list_arr.len(), 1); @@ -445,94 +496,6 @@ impl SingleRowListArrayBuilder { } } -/// Wrap an array into a single element `ListArray`. -/// For example `[1, 2, 3]` would be converted into `[[1, 2, 3]]` -/// The field in the list array is nullable. 
-#[deprecated( - since = "44.0.0", - note = "please use `SingleRowListArrayBuilder` instead" -)] -pub fn array_into_list_array_nullable(arr: ArrayRef) -> ListArray { - SingleRowListArrayBuilder::new(arr) - .with_nullable(true) - .build_list_array() -} - -/// Wrap an array into a single element `ListArray`. -/// For example `[1, 2, 3]` would be converted into `[[1, 2, 3]]` -#[deprecated( - since = "44.0.0", - note = "please use `SingleRowListArrayBuilder` instead" -)] -pub fn array_into_list_array(arr: ArrayRef, nullable: bool) -> ListArray { - SingleRowListArrayBuilder::new(arr) - .with_nullable(nullable) - .build_list_array() -} - -#[deprecated( - since = "44.0.0", - note = "please use `SingleRowListArrayBuilder` instead" -)] -pub fn array_into_list_array_with_field_name( - arr: ArrayRef, - nullable: bool, - field_name: &str, -) -> ListArray { - SingleRowListArrayBuilder::new(arr) - .with_nullable(nullable) - .with_field_name(Some(field_name.to_string())) - .build_list_array() -} - -/// Wrap an array into a single element `LargeListArray`. 
-/// For example `[1, 2, 3]` would be converted into `[[1, 2, 3]]` -#[deprecated( - since = "44.0.0", - note = "please use `SingleRowListArrayBuilder` instead" -)] -pub fn array_into_large_list_array(arr: ArrayRef) -> LargeListArray { - SingleRowListArrayBuilder::new(arr).build_large_list_array() -} - -#[deprecated( - since = "44.0.0", - note = "please use `SingleRowListArrayBuilder` instead" -)] -pub fn array_into_large_list_array_with_field_name( - arr: ArrayRef, - field_name: &str, -) -> LargeListArray { - SingleRowListArrayBuilder::new(arr) - .with_field_name(Some(field_name.to_string())) - .build_large_list_array() -} - -#[deprecated( - since = "44.0.0", - note = "please use `SingleRowListArrayBuilder` instead" -)] -pub fn array_into_fixed_size_list_array( - arr: ArrayRef, - list_size: usize, -) -> FixedSizeListArray { - SingleRowListArrayBuilder::new(arr).build_fixed_size_list_array(list_size) -} - -#[deprecated( - since = "44.0.0", - note = "please use `SingleRowListArrayBuilder` instead" -)] -pub fn array_into_fixed_size_list_array_with_field_name( - arr: ArrayRef, - list_size: usize, - field_name: &str, -) -> FixedSizeListArray { - SingleRowListArrayBuilder::new(arr) - .with_field_name(Some(field_name.to_string())) - .build_fixed_size_list_array(list_size) -} - /// Wrap arrays into a single element `ListArray`. 
/// /// Example: @@ -554,13 +517,12 @@ pub fn array_into_fixed_size_list_array_with_field_name( /// ); /// /// assert_eq!(list_arr, expected); +/// ``` pub fn arrays_into_list_array( arr: impl IntoIterator, ) -> Result { let arr = arr.into_iter().collect::>(); - if arr.is_empty() { - return _internal_err!("Cannot wrap empty array into list array"); - } + assert_or_internal_err!(!arr.is_empty(), "Cannot wrap empty array into list array"); let lens = arr.iter().map(|x| x.len()).collect::>(); // Assume data type is consistent @@ -592,7 +554,8 @@ pub fn fixed_size_list_to_arrays(a: &ArrayRef) -> Vec { /// use datafusion_common::utils::base_type; /// use std::sync::Arc; /// -/// let data_type = DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))); +/// let data_type = +/// DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))); /// assert_eq!(base_type(&data_type), DataType::Int32); /// /// let data_type = DataType::Int32; @@ -626,6 +589,7 @@ pub enum ListCoercion { /// let base_type = DataType::Float64; /// let coerced_type = coerced_type_with_base_type_only(&data_type, &base_type, None); /// assert_eq!(coerced_type, DataType::List(Arc::new(Field::new_list_field(DataType::Float64, true)))); +/// ``` pub fn coerced_type_with_base_type_only( data_type: &DataType, base_type: &DataType, @@ -732,10 +696,14 @@ pub mod datafusion_strsim { } /// Calculates the minimum number of insertions, deletions, and substitutions - /// required to change one sequence into the other. - fn generic_levenshtein<'a, 'b, Iter1, Iter2, Elem1, Elem2>( + /// required to change one sequence into the other, using a reusable cache buffer. + /// + /// This is the generic implementation that works with any iterator types. + /// The `cache` buffer will be resized as needed and reused across calls. 
+ fn generic_levenshtein_with_buffer<'a, 'b, Iter1, Iter2, Elem1, Elem2>( a: &'a Iter1, b: &'b Iter2, + cache: &mut Vec, ) -> usize where &'a Iter1: IntoIterator, @@ -748,7 +716,9 @@ pub mod datafusion_strsim { return b_len; } - let mut cache: Vec = (1..b_len + 1).collect(); + // Resize cache to fit b_len elements + cache.clear(); + cache.extend(1..=b_len); let mut result = 0; @@ -768,6 +738,21 @@ pub mod datafusion_strsim { result } + /// Calculates the minimum number of insertions, deletions, and substitutions + /// required to change one sequence into the other. + fn generic_levenshtein<'a, 'b, Iter1, Iter2, Elem1, Elem2>( + a: &'a Iter1, + b: &'b Iter2, + ) -> usize + where + &'a Iter1: IntoIterator, + &'b Iter2: IntoIterator, + Elem1: PartialEq, + { + let mut cache = Vec::new(); + generic_levenshtein_with_buffer(a, b, &mut cache) + } + /// Calculates the minimum number of insertions, deletions, and substitutions /// required to change one string into the other. /// @@ -780,6 +765,15 @@ pub mod datafusion_strsim { generic_levenshtein(&StringWrapper(a), &StringWrapper(b)) } + /// Calculates the Levenshtein distance using a reusable cache buffer. + /// This avoids allocating a new Vec for each call, improving performance + /// when computing many distances. + /// + /// The `cache` buffer will be resized as needed and reused across calls. + pub fn levenshtein_with_buffer(a: &str, b: &str, cache: &mut Vec) -> usize { + generic_levenshtein_with_buffer(&StringWrapper(a), &StringWrapper(b), cache) + } + /// Calculates the normalized Levenshtein distance between two strings. /// The normalized distance is a value between 0.0 and 1.0, where 1.0 indicates /// that the strings are identical and 0.0 indicates no similarity. @@ -833,21 +827,6 @@ pub fn set_difference, S: Borrow>( .collect() } -/// Checks whether the given index sequence is monotonically non-decreasing. 
-#[deprecated(since = "45.0.0", note = "Use std::Iterator::is_sorted instead")] -pub fn is_sorted>(sequence: impl IntoIterator) -> bool { - // TODO: Remove this function when `is_sorted` graduates from Rust nightly. - let mut previous = 0; - for item in sequence.into_iter() { - let current = *item.borrow(); - if current < previous { - return false; - } - previous = current; - } - true -} - /// Find indices of each element in `targets` inside `items`. If one of the /// elements is absent in `items`, returns an error. pub fn find_indices>( @@ -858,7 +837,7 @@ pub fn find_indices>( .into_iter() .map(|target| items.iter().position(|e| target.borrow().eq(e))) .collect::>() - .ok_or_else(|| DataFusionError::Execution("Target not found".to_string())) + .ok_or_else(|| _exec_datafusion_err!("Target not found")) } /// Transposes the given vector of vectors. @@ -950,7 +929,7 @@ pub fn get_available_parallelism() -> usize { .get() } -/// Converts a collection of function arguments into an fixed-size array of length N +/// Converts a collection of function arguments into a fixed-size array of length N /// producing a reasonable error message in case of unexpected number of arguments. /// /// # Example @@ -959,16 +938,19 @@ pub fn get_available_parallelism() -> usize { /// # use datafusion_common::utils::take_function_args; /// # use datafusion_common::ScalarValue; /// fn my_function(args: &[ScalarValue]) -> Result<()> { -/// // function expects 2 args, so create a 2-element array -/// let [arg1, arg2] = take_function_args("my_function", args)?; -/// // ... do stuff.. -/// Ok(()) +/// // function expects 2 args, so create a 2-element array +/// let [arg1, arg2] = take_function_args("my_function", args)?; +/// // ... do stuff.. 
+/// Ok(()) /// } /// /// // Calling the function with 1 argument produces an error: /// let args = vec![ScalarValue::Int32(Some(10))]; /// let err = my_function(&args).unwrap_err(); -/// assert_eq!(err.to_string(), "Execution error: my_function function requires 2 arguments, got 1"); +/// assert_eq!( +/// err.to_string(), +/// "Execution error: my_function function requires 2 arguments, got 1" +/// ); /// // Calling the function with 2 arguments works great /// let args = vec![ScalarValue::Int32(Some(10)), ScalarValue::Int32(Some(20))]; /// my_function(&args).unwrap(); @@ -994,7 +976,6 @@ mod tests { use super::*; use crate::ScalarValue::Null; use arrow::array::Float64Array; - use sqlparser::tokenizer::Span; #[test] fn test_bisect_linear_left_and_right() -> Result<()> { @@ -1190,6 +1171,7 @@ mod tests { Ok(()) } + #[cfg(feature = "sql")] #[test] fn test_quote_identifier() -> Result<()> { let cases = vec![ @@ -1222,7 +1204,7 @@ mod tests { let expected_parsed = vec![Ident { value: identifier.to_string(), quote_style, - span: Span::empty(), + span: sqlparser::tokenizer::Span::empty(), }]; assert_eq!( @@ -1275,19 +1257,6 @@ mod tests { assert_eq!(set_difference([3, 4, 0], [4, 1, 2]), vec![3, 0]); } - #[test] - #[expect(deprecated)] - fn test_is_sorted() { - assert!(is_sorted::([])); - assert!(is_sorted([0])); - assert!(is_sorted([0, 3, 4])); - assert!(is_sorted([0, 1, 2])); - assert!(is_sorted([0, 1, 4])); - assert!(is_sorted([0usize; 0])); - assert!(is_sorted([1, 2])); - assert!(!is_sorted([3, 2])); - } - #[test] fn test_find_indices() -> Result<()> { assert_eq!(find_indices(&[0, 3, 4], [0, 3, 4])?, vec![0, 1, 2]); diff --git a/datafusion/common/src/utils/proxy.rs b/datafusion/common/src/utils/proxy.rs index d940677a5fb3b..846c928515d60 100644 --- a/datafusion/common/src/utils/proxy.rs +++ b/datafusion/common/src/utils/proxy.rs @@ -15,12 +15,9 @@ // specific language governing permissions and limitations // under the License. -//! 
[`VecAllocExt`] and [`RawTableAllocExt`] to help tracking of memory allocations +//! [`VecAllocExt`] to help tracking of memory allocations -use hashbrown::{ - hash_table::HashTable, - raw::{Bucket, RawTable}, -}; +use hashbrown::hash_table::HashTable; use std::mem::size_of; /// Extension trait for [`Vec`] to account for allocations. @@ -47,7 +44,9 @@ pub trait VecAllocExt { /// assert_eq!(allocated, 16); // no new allocation needed /// /// // push more data into the vec - /// for _ in 0..10 { vec.push_accounted(1, &mut allocated); } + /// for _ in 0..10 { + /// vec.push_accounted(1, &mut allocated); + /// } /// assert_eq!(allocated, 64); // underlying vec has space for 10 u32s /// assert_eq!(vec.allocated_size(), 64); /// ``` @@ -82,7 +81,9 @@ pub trait VecAllocExt { /// assert_eq!(vec.allocated_size(), 16); // no new allocation needed /// /// // push more data into the vec - /// for _ in 0..10 { vec.push(1); } + /// for _ in 0..10 { + /// vec.push(1); + /// } /// assert_eq!(vec.allocated_size(), 64); // space for 64 now /// ``` fn allocated_size(&self) -> usize; @@ -110,73 +111,6 @@ impl VecAllocExt for Vec { } } -/// Extension trait for hash browns [`RawTable`] to account for allocations. -pub trait RawTableAllocExt { - /// Item type. - type T; - - /// [Insert](RawTable::insert) new element into table and increase - /// `accounting` by any newly allocated bytes. - /// - /// Returns the bucket where the element was inserted. - /// Note that allocation counts capacity, not size. 
- /// - /// # Example: - /// ``` - /// # use datafusion_common::utils::proxy::RawTableAllocExt; - /// # use hashbrown::raw::RawTable; - /// let mut table = RawTable::new(); - /// let mut allocated = 0; - /// let hash_fn = |x: &u32| (*x as u64) % 1000; - /// // pretend 0x3117 is the hash value for 1 - /// table.insert_accounted(1, hash_fn, &mut allocated); - /// assert_eq!(allocated, 64); - /// - /// // insert more values - /// for i in 0..100 { table.insert_accounted(i, hash_fn, &mut allocated); } - /// assert_eq!(allocated, 400); - /// ``` - fn insert_accounted( - &mut self, - x: Self::T, - hasher: impl Fn(&Self::T) -> u64, - accounting: &mut usize, - ) -> Bucket; -} - -impl RawTableAllocExt for RawTable { - type T = T; - - fn insert_accounted( - &mut self, - x: Self::T, - hasher: impl Fn(&Self::T) -> u64, - accounting: &mut usize, - ) -> Bucket { - let hash = hasher(&x); - - match self.try_insert_no_grow(hash, x) { - Ok(bucket) => bucket, - Err(x) => { - // need to request more memory - - let bump_elements = self.capacity().max(16); - let bump_size = bump_elements * size_of::(); - *accounting = (*accounting).checked_add(bump_size).expect("overflow"); - - self.reserve(bump_elements, hasher); - - // still need to insert the element since first try failed - // Note: cannot use `.expect` here because `T` may not implement `Debug` - match self.try_insert_no_grow(hash, x) { - Ok(bucket) => bucket, - Err(_) => panic!("just grew the container"), - } - } - } - } -} - /// Extension trait for hash browns [`HashTable`] to account for allocations. pub trait HashTableAllocExt { /// Item type. @@ -187,6 +121,8 @@ pub trait HashTableAllocExt { /// /// Returns the bucket where the element was inserted. /// Note that allocation counts capacity, not size. 
+ /// Panics: + /// Assumes the element is not already present, and may panic if it does /// /// # Example: /// ``` @@ -200,7 +136,9 @@ pub trait HashTableAllocExt { /// assert_eq!(allocated, 64); /// /// // insert more values - /// for i in 0..100 { table.insert_accounted(i, hash_fn, &mut allocated); } + /// for i in 2..100 { + /// table.insert_accounted(i, hash_fn, &mut allocated); + /// } /// assert_eq!(allocated, 400); /// ``` fn insert_accounted( @@ -225,22 +163,24 @@ where ) { let hash = hasher(&x); - // NOTE: `find_entry` does NOT grow! - match self.find_entry(hash, |y| y == &x) { - Ok(_occupied) => {} - Err(_absent) => { - if self.len() == self.capacity() { - // need to request more memory - let bump_elements = self.capacity().max(16); - let bump_size = bump_elements * size_of::(); - *accounting = (*accounting).checked_add(bump_size).expect("overflow"); + if cfg!(debug_assertions) { + // In debug mode, check that the element is not already present + debug_assert!( + self.find_entry(hash, |y| y == &x).is_err(), + "attempted to insert duplicate element into HashTableAllocExt::insert_accounted" + ); + } - self.reserve(bump_elements, &hasher); - } + if self.len() == self.capacity() { + // need to request more memory + let bump_elements = self.capacity().max(16); + let bump_size = bump_elements * size_of::(); + *accounting = (*accounting).checked_add(bump_size).expect("overflow"); - // still need to insert the element since first try failed - self.entry(hash, |y| y == &x, hasher).insert(x); - } + self.reserve(bump_elements, &hasher); } + + // We assume the element is not already present + self.insert_unique(hash, x, hasher); } } diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 03a9ec8f3f150..326b791a2f624 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -32,6 +32,9 @@ rust-version = { workspace = true } [package.metadata.docs.rs] all-features = true +# Note: add additional linter rules in lib.rs. 
+# Rust does not support workspace + new linter rules in subcrates yet +# https://github.com/rust-lang/cargo/issues/13157 [lints] workspace = true @@ -43,10 +46,11 @@ array_expressions = ["nested_expressions"] avro = ["datafusion-common/avro", "datafusion-datasource-avro"] backtrace = ["datafusion-common/backtrace"] compression = [ - "xz2", + "liblzma", "bzip2", "flate2", "zstd", + "datafusion-datasource-arrow/compression", "datafusion-datasource/compression", ] crypto_expressions = ["datafusion-functions/crypto_expressions"] @@ -62,13 +66,19 @@ default = [ "compression", "parquet", "recursive_protection", + "sql", ] encoding_expressions = ["datafusion-functions/encoding_expressions"] # Used for testing ONLY: causes all values to hash to the same value (test for collisions) force_hash_collisions = ["datafusion-physical-plan/force_hash_collisions", "datafusion-common/force_hash_collisions"] math_expressions = ["datafusion-functions/math_expressions"] parquet = ["datafusion-common/parquet", "dep:parquet", "datafusion-datasource-parquet"] -pyarrow = ["datafusion-common/pyarrow", "parquet"] +parquet_encryption = [ + "parquet", + "parquet/encryption", + "datafusion-common/parquet_encryption", + "datafusion-datasource-parquet/parquet_encryption", +] regex_expressions = [ "datafusion-functions/regex_expressions", ] @@ -77,7 +87,9 @@ recursive_protection = [ "datafusion-expr/recursive_protection", "datafusion-optimizer/recursive_protection", "datafusion-physical-optimizer/recursive_protection", - "datafusion-sql/recursive_protection", + "datafusion-physical-expr/recursive_protection", + "datafusion-sql?/recursive_protection", + "sqlparser?/recursive-protection", ] serde = [ "dep:serde", @@ -85,62 +97,66 @@ serde = [ # statements in `arrow-schema` crate "arrow-schema/serde", ] +sql = [ + "datafusion-common/sql", + "datafusion-functions-nested?/sql", + "datafusion-sql", + "sqlparser", +] string_expressions = ["datafusion-functions/string_expressions"] unicode_expressions = [ 
- "datafusion-sql/unicode_expressions", + "datafusion-sql?/unicode_expressions", "datafusion-functions/unicode_expressions", ] extended_tests = [] [dependencies] arrow = { workspace = true } -arrow-ipc = { workspace = true } arrow-schema = { workspace = true, features = ["canonical_extension_types"] } async-trait = { workspace = true } -bytes = { workspace = true } -bzip2 = { version = "0.5.2", optional = true } +bzip2 = { workspace = true, optional = true } chrono = { workspace = true } datafusion-catalog = { workspace = true } datafusion-catalog-listing = { workspace = true } datafusion-common = { workspace = true, features = ["object_store"] } datafusion-common-runtime = { workspace = true } datafusion-datasource = { workspace = true } +datafusion-datasource-arrow = { workspace = true } datafusion-datasource-avro = { workspace = true, optional = true } datafusion-datasource-csv = { workspace = true } datafusion-datasource-json = { workspace = true } datafusion-datasource-parquet = { workspace = true, optional = true } datafusion-execution = { workspace = true } -datafusion-expr = { workspace = true } +datafusion-expr = { workspace = true, default-features = false } datafusion-expr-common = { workspace = true } datafusion-functions = { workspace = true } datafusion-functions-aggregate = { workspace = true } -datafusion-functions-nested = { workspace = true, optional = true } +datafusion-functions-nested = { workspace = true, default-features = false, optional = true } datafusion-functions-table = { workspace = true } datafusion-functions-window = { workspace = true } datafusion-optimizer = { workspace = true } datafusion-physical-expr = { workspace = true } +datafusion-physical-expr-adapter = { workspace = true } datafusion-physical-expr-common = { workspace = true } datafusion-physical-optimizer = { workspace = true } datafusion-physical-plan = { workspace = true } datafusion-session = { workspace = true } -datafusion-sql = { workspace = true } -flate2 = { 
version = "1.1.1", optional = true } +datafusion-sql = { workspace = true, optional = true } +flate2 = { workspace = true, optional = true } futures = { workspace = true } itertools = { workspace = true } +liblzma = { workspace = true, optional = true } log = { workspace = true } object_store = { workspace = true } parking_lot = { workspace = true } parquet = { workspace = true, optional = true, default-features = true } -rand = { workspace = true } -regex = { workspace = true } serde = { version = "1.0", default-features = false, features = ["derive"], optional = true } -sqlparser = { workspace = true } +sqlparser = { workspace = true, optional = true } tempfile = { workspace = true } tokio = { workspace = true } url = { workspace = true } -uuid = { version = "1.17", features = ["v4", "js"] } -xz2 = { version = "0.1", optional = true, features = ["static"] } -zstd = { version = "0.13", optional = true, default-features = false } +uuid = { workspace = true, features = ["v4", "js"] } +zstd = { workspace = true, optional = true } [dev-dependencies] async-trait = { workspace = true } @@ -152,20 +168,26 @@ datafusion-functions-window-common = { workspace = true } datafusion-macros = { workspace = true } datafusion-physical-optimizer = { workspace = true } doc-comment = { workspace = true } +bytes = { workspace = true } env_logger = { workspace = true } +glob = { workspace = true } insta = { workspace = true } -paste = "^1.0" +pretty_assertions = "1.0" rand = { workspace = true, features = ["small_rng"] } rand_distr = "0.5" +recursive = { workspace = true } regex = { workspace = true } rstest = { workspace = true } serde_json = { workspace = true } -sysinfo = "0.35.1" +sysinfo = "0.38.2" test-utils = { path = "../../test-utils" } tokio = { workspace = true, features = ["rt-multi-thread", "parking_lot", "fs"] } +[package.metadata.cargo-machete] +ignored = ["datafusion-doc", "datafusion-macros", "dashmap"] + [target.'cfg(not(target_os = "windows"))'.dev-dependencies] -nix 
= { version = "0.30.1", features = ["fs"] } +nix = { version = "0.31.1", features = ["fs"] } [[bench]] harness = false @@ -203,6 +225,10 @@ name = "struct_query_sql" harness = false name = "window_query_sql" +[[bench]] +harness = false +name = "topk_repartition" + [[bench]] harness = false name = "scalar" @@ -216,10 +242,23 @@ harness = false name = "parquet_query_sql" required-features = ["parquet"] +[[bench]] +harness = false +name = "parquet_struct_query" +required-features = ["parquet"] + +[[bench]] +harness = false +name = "range_and_generate_series" + [[bench]] harness = false name = "sql_planner" +[[bench]] +harness = false +name = "sql_planner_extended" + [[bench]] harness = false name = "sql_query_with_io" @@ -244,3 +283,12 @@ name = "dataframe" [[bench]] harness = false name = "spm" + +[[bench]] +harness = false +name = "preserve_file_partitioning" +required-features = ["parquet"] + +[[bench]] +harness = false +name = "reset_plan_states" diff --git a/datafusion/core/README.md b/datafusion/core/README.md index b5501087d2647..859fcb9c0dff9 100644 --- a/datafusion/core/README.md +++ b/datafusion/core/README.md @@ -17,15 +17,12 @@ under the License. --> -# DataFusion Core + -DataFusion is an extensible query execution framework, written in Rust, -that uses Apache Arrow as its in-memory format. +# Apache DataFusion Core This crate contains the main entry points and high level DataFusion APIs such as `SessionContext`, `DataFrame` and `ListingTable`. 
- -For more information, please see: - -- [DataFusion Website](https://datafusion.apache.org) -- [DataFusion API Docs](https://docs.rs/datafusion/latest/datafusion/) diff --git a/datafusion/core/benches/aggregate_query_sql.rs b/datafusion/core/benches/aggregate_query_sql.rs index 057a0e1d1b54c..402ac9c7176b5 100644 --- a/datafusion/core/benches/aggregate_query_sql.rs +++ b/datafusion/core/benches/aggregate_query_sql.rs @@ -15,23 +15,21 @@ // specific language governing permissions and limitations // under the License. -#[macro_use] -extern crate criterion; -extern crate arrow; -extern crate datafusion; - mod data_utils; -use crate::criterion::Criterion; + +use criterion::{Criterion, criterion_group, criterion_main}; use data_utils::create_table_provider; use datafusion::error::Result; use datafusion::execution::context::SessionContext; use parking_lot::Mutex; +use std::hint::black_box; use std::sync::Arc; use tokio::runtime::Runtime; +#[expect(clippy::needless_pass_by_value)] fn query(ctx: Arc>, rt: &Runtime, sql: &str) { let df = rt.block_on(ctx.lock().sql(sql)).unwrap(); - criterion::black_box(rt.block_on(df.collect()).unwrap()); + black_box(rt.block_on(df.collect()).unwrap()); } fn create_context( @@ -153,6 +151,38 @@ fn criterion_benchmark(c: &mut Criterion) { }) }); + c.bench_function( + "aggregate_query_group_by_wide_u64_and_string_without_aggregate_expressions", + |b| { + b.iter(|| { + query( + ctx.clone(), + &rt, + // Due to the large number of distinct values in u64_wide, + // this query test the actual grouping performance for more than 1 column + "SELECT u64_wide, utf8 \ + FROM t GROUP BY u64_wide, utf8", + ) + }) + }, + ); + + c.bench_function( + "aggregate_query_group_by_wide_u64_and_f32_without_aggregate_expressions", + |b| { + b.iter(|| { + query( + ctx.clone(), + &rt, + // Due to the large number of distinct values in u64_wide, + // this query test the actual grouping performance for more than 1 column + "SELECT u64_wide, f32 \ + FROM t GROUP BY 
u64_wide, f32", + ) + }) + }, + ); + c.bench_function("aggregate_query_approx_percentile_cont_on_u64", |b| { b.iter(|| { query( @@ -221,6 +251,50 @@ fn criterion_benchmark(c: &mut Criterion) { ) }) }); + + c.bench_function("array_agg_query_group_by_few_groups", |b| { + b.iter(|| { + query( + ctx.clone(), + &rt, + "SELECT u64_narrow, array_agg(f64) \ + FROM t GROUP BY u64_narrow", + ) + }) + }); + + c.bench_function("array_agg_query_group_by_mid_groups", |b| { + b.iter(|| { + query( + ctx.clone(), + &rt, + "SELECT u64_mid, array_agg(f64) \ + FROM t GROUP BY u64_mid", + ) + }) + }); + + c.bench_function("array_agg_query_group_by_many_groups", |b| { + b.iter(|| { + query( + ctx.clone(), + &rt, + "SELECT u64_wide, array_agg(f64) \ + FROM t GROUP BY u64_wide", + ) + }) + }); + + c.bench_function("array_agg_struct_query_group_by_mid_groups", |b| { + b.iter(|| { + query( + ctx.clone(), + &rt, + "SELECT u64_mid, array_agg(named_struct('market', dict10, 'price', f64)) \ + FROM t GROUP BY u64_mid", + ) + }) + }); } criterion_group!(benches, criterion_benchmark); diff --git a/datafusion/core/benches/csv_load.rs b/datafusion/core/benches/csv_load.rs index 3f984757466d5..13843dadddd0c 100644 --- a/datafusion/core/benches/csv_load.rs +++ b/datafusion/core/benches/csv_load.rs @@ -15,23 +15,21 @@ // specific language governing permissions and limitations // under the License. 
-#[macro_use] -extern crate criterion; -extern crate arrow; -extern crate datafusion; - mod data_utils; -use crate::criterion::Criterion; + +use criterion::{Criterion, criterion_group, criterion_main}; use datafusion::error::Result; use datafusion::execution::context::SessionContext; use datafusion::prelude::CsvReadOptions; use datafusion::test_util::csv::TestCsvFile; use parking_lot::Mutex; +use std::hint::black_box; use std::sync::Arc; use std::time::Duration; use test_utils::AccessLogGenerator; use tokio::runtime::Runtime; +#[expect(clippy::needless_pass_by_value)] fn load_csv( ctx: Arc>, rt: &Runtime, @@ -39,7 +37,7 @@ fn load_csv( options: CsvReadOptions, ) { let df = rt.block_on(ctx.lock().read_csv(path, options)).unwrap(); - criterion::black_box(rt.block_on(df.collect()).unwrap()); + black_box(rt.block_on(df.collect()).unwrap()); } fn create_context() -> Result>> { diff --git a/datafusion/core/benches/data_utils/mod.rs b/datafusion/core/benches/data_utils/mod.rs index c0477b1306f75..728c6490c72bd 100644 --- a/datafusion/core/benches/data_utils/mod.rs +++ b/datafusion/core/benches/data_utils/mod.rs @@ -18,10 +18,11 @@ //! This module provides the in-memory table for more realistic benchmarking. use arrow::array::{ - builder::{Int64Builder, StringBuilder}, ArrayRef, Float32Array, Float64Array, RecordBatch, StringArray, StringViewBuilder, UInt64Array, + builder::{Int64Builder, StringBuilder, StringDictionaryBuilder}, }; +use arrow::datatypes::Int32Type; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion::datasource::MemTable; use datafusion::error::Result; @@ -36,6 +37,7 @@ use std::sync::Arc; /// create an in-memory table given the partition len, array len, and batch size, /// and the result table will be of array_len in total, and then partitioned, and batched. 
+#[expect(clippy::allow_attributes)] // some issue where expect(dead_code) doesn't fire properly #[allow(dead_code)] pub fn create_table_provider( partitions_len: usize, @@ -44,7 +46,7 @@ pub fn create_table_provider( ) -> Result> { let schema = Arc::new(create_schema()); let partitions = - create_record_batches(schema.clone(), array_len, partitions_len, batch_size); + create_record_batches(&schema, array_len, partitions_len, batch_size); // declare a table in memory. In spark API, this corresponds to createDataFrame(...). MemTable::try_new(schema, partitions).map(Arc::new) } @@ -55,21 +57,24 @@ pub fn create_schema() -> Schema { Field::new("utf8", DataType::Utf8, false), Field::new("f32", DataType::Float32, false), Field::new("f64", DataType::Float64, true), - // This field will contain integers randomly selected from a large - // range of values, i.e. [0, u64::MAX], such that there are none (or - // very few) repeated values. - Field::new("u64_wide", DataType::UInt64, true), - // This field will contain integers randomly selected from a narrow - // range of values such that there are a few distinct values, but they - // are repeated often. + // Integers randomly selected from a wide range of values, i.e. [0, + // u64::MAX], such that there are ~no repeated values. + Field::new("u64_wide", DataType::UInt64, false), + // Integers randomly selected from a mid-range of values [0, 1000), + // providing ~1000 distinct groups. + Field::new("u64_mid", DataType::UInt64, false), + // Integers randomly selected from a narrow range of values such that + // there are a few distinct values, but they are repeated often. 
Field::new("u64_narrow", DataType::UInt64, false), + Field::new( + "dict10", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + true, + ), ]) } -fn create_data(size: usize, null_density: f64) -> Vec> { - // use random numbers to avoid spurious compiler optimizations wrt to branching - let mut rng = StdRng::seed_from_u64(42); - +fn create_data(rng: &mut StdRng, size: usize, null_density: f64) -> Vec> { (0..size) .map(|_| { if rng.random::() > null_density { @@ -81,56 +86,54 @@ fn create_data(size: usize, null_density: f64) -> Vec> { .collect() } -fn create_integer_data(size: usize, value_density: f64) -> Vec> { - // use random numbers to avoid spurious compiler optimizations wrt to branching - let mut rng = StdRng::seed_from_u64(42); - - (0..size) - .map(|_| { - if rng.random::() > value_density { - None - } else { - Some(rng.random::()) - } - }) - .collect() -} - fn create_record_batch( schema: SchemaRef, rng: &mut StdRng, batch_size: usize, - i: usize, + batch_index: usize, ) -> RecordBatch { - // the 4 here is the number of different keys. - // a higher number increase sparseness - let vs = [0, 1, 2, 3]; - let keys: Vec = (0..batch_size) - .map( - // use random numbers to avoid spurious compiler optimizations wrt to branching - |_| format!("hi{:?}", vs.choose(rng)), - ) - .collect(); - let keys: Vec<&str> = keys.iter().map(|e| &**e).collect(); + // Randomly choose from 4 distinct key values; a higher number increases sparseness. + let key_suffixes = [0, 1, 2, 3]; + let keys = StringArray::from_iter_values( + (0..batch_size).map(|_| format!("hi{}", key_suffixes.choose(rng).unwrap())), + ); - let values = create_data(batch_size, 0.5); + let values = create_data(rng, batch_size, 0.5); // Integer values between [0, u64::MAX]. - let integer_values_wide = create_integer_data(batch_size, 9.0); + let integer_values_wide = (0..batch_size) + .map(|_| rng.random::()) + .collect::>(); + + // Integer values between [0, 1000). 
+ let integer_values_mid = (0..batch_size) + .map(|_| rng.random_range(0..1000)) + .collect::>(); - // Integer values between [0, 9]. + // Integer values between [0, 10). let integer_values_narrow = (0..batch_size) - .map(|_| rng.random_range(0_u64..10)) + .map(|_| rng.random_range(0..10)) .collect::>(); + let mut dict_builder = StringDictionaryBuilder::::new(); + for _ in 0..batch_size { + if rng.random::() > 0.9 { + dict_builder.append_null(); + } else { + dict_builder.append_value(format!("market_{}", rng.random_range(0..10))); + } + } + RecordBatch::try_new( schema, vec![ - Arc::new(StringArray::from(keys)), - Arc::new(Float32Array::from(vec![i as f32; batch_size])), + Arc::new(keys), + Arc::new(Float32Array::from(vec![batch_index as f32; batch_size])), Arc::new(Float64Array::from(values)), Arc::new(UInt64Array::from(integer_values_wide)), + Arc::new(UInt64Array::from(integer_values_mid)), Arc::new(UInt64Array::from(integer_values_narrow)), + Arc::new(dict_builder.finish()), ], ) .unwrap() @@ -139,19 +142,28 @@ fn create_record_batch( /// Create record batches of `partitions_len` partitions and `batch_size` for each batch, /// with a total number of `array_len` records pub fn create_record_batches( - schema: SchemaRef, + schema: &SchemaRef, array_len: usize, partitions_len: usize, batch_size: usize, ) -> Vec> { let mut rng = StdRng::seed_from_u64(42); - (0..partitions_len) - .map(|_| { - (0..array_len / batch_size / partitions_len) - .map(|i| create_record_batch(schema.clone(), &mut rng, batch_size, i)) - .collect::>() - }) - .collect::>() + let mut partitions = Vec::with_capacity(partitions_len); + let batches_per_partition = array_len / batch_size / partitions_len; + + for _ in 0..partitions_len { + let mut batches = Vec::with_capacity(batches_per_partition); + for batch_index in 0..batches_per_partition { + batches.push(create_record_batch( + schema.clone(), + &mut rng, + batch_size, + batch_index, + )); + } + partitions.push(batches); + } + partitions } /// 
An enum that wraps either a regular StringBuilder or a GenericByteViewBuilder @@ -181,6 +193,7 @@ impl TraceIdBuilder { /// Create time series data with `partition_cnt` partitions and `sample_cnt` rows per partition /// in ascending order, if `asc` is true, otherwise randomly sampled using a Pareto distribution +#[expect(clippy::allow_attributes)] // some issue where expect(dead_code) doesn't fire properly #[allow(dead_code)] pub(crate) fn make_data( partition_cnt: i32, diff --git a/datafusion/core/benches/dataframe.rs b/datafusion/core/benches/dataframe.rs index 12eb34719e4ba..5aeade315cc7b 100644 --- a/datafusion/core/benches/dataframe.rs +++ b/datafusion/core/benches/dataframe.rs @@ -15,17 +15,13 @@ // specific language governing permissions and limitations // under the License. -extern crate arrow; -#[macro_use] -extern crate criterion; -extern crate datafusion; - use arrow_schema::{DataType, Field, Schema}; -use criterion::Criterion; +use criterion::{Criterion, criterion_group, criterion_main}; use datafusion::datasource::MemTable; use datafusion::prelude::SessionContext; use datafusion_expr::col; use datafusion_functions::expr_fn::btrim; +use std::hint::black_box; use std::sync::Arc; use tokio::runtime::Runtime; @@ -44,8 +40,9 @@ fn create_context(field_count: u32) -> datafusion_common::Result, rt: &Runtime) { - criterion::black_box(rt.block_on(async { + black_box(rt.block_on(async { let mut data_frame = ctx.table("t").await.unwrap(); for i in 0..column_count { diff --git a/datafusion/core/benches/distinct_query_sql.rs b/datafusion/core/benches/distinct_query_sql.rs index c1ef55992689e..d389b1b3d6a22 100644 --- a/datafusion/core/benches/distinct_query_sql.rs +++ b/datafusion/core/benches/distinct_query_sql.rs @@ -15,27 +15,25 @@ // specific language governing permissions and limitations // under the License. 
-#[macro_use] -extern crate criterion; -extern crate arrow; -extern crate datafusion; - mod data_utils; -use crate::criterion::Criterion; + +use criterion::{Criterion, criterion_group, criterion_main}; use data_utils::{create_table_provider, make_data}; use datafusion::execution::context::SessionContext; -use datafusion::physical_plan::{collect, ExecutionPlan}; +use datafusion::physical_plan::{ExecutionPlan, collect}; use datafusion::{datasource::MemTable, error::Result}; -use datafusion_execution::config::SessionConfig; use datafusion_execution::TaskContext; +use datafusion_execution::config::SessionConfig; use parking_lot::Mutex; +use std::hint::black_box; use std::{sync::Arc, time::Duration}; use tokio::runtime::Runtime; +#[expect(clippy::needless_pass_by_value)] fn query(ctx: Arc>, rt: &Runtime, sql: &str) { let df = rt.block_on(ctx.lock().sql(sql)).unwrap(); - criterion::black_box(rt.block_on(df.collect()).unwrap()); + black_box(rt.block_on(df.collect()).unwrap()); } fn create_context( @@ -123,9 +121,9 @@ async fn distinct_with_limit( Ok(()) } +#[expect(clippy::needless_pass_by_value)] fn run(rt: &Runtime, plan: Arc, ctx: Arc) { - criterion::black_box(rt.block_on(distinct_with_limit(plan.clone(), ctx.clone()))) - .unwrap(); + black_box(rt.block_on(distinct_with_limit(plan.clone(), ctx.clone()))).unwrap(); } pub async fn create_context_sampled_data( diff --git a/datafusion/core/benches/filter_query_sql.rs b/datafusion/core/benches/filter_query_sql.rs index c82a1607184dc..3b80518d32dcd 100644 --- a/datafusion/core/benches/filter_query_sql.rs +++ b/datafusion/core/benches/filter_query_sql.rs @@ -20,17 +20,18 @@ use arrow::{ datatypes::{DataType, Field, Schema}, record_batch::RecordBatch, }; -use criterion::{criterion_group, criterion_main, Criterion}; +use criterion::{Criterion, criterion_group, criterion_main}; use datafusion::prelude::SessionContext; use datafusion::{datasource::MemTable, error::Result}; use futures::executor::block_on; +use 
std::hint::black_box; use std::sync::Arc; use tokio::runtime::Runtime; async fn query(ctx: &SessionContext, rt: &Runtime, sql: &str) { // execute the query let df = rt.block_on(ctx.sql(sql)).unwrap(); - criterion::black_box(rt.block_on(df.collect()).unwrap()); + black_box(rt.block_on(df.collect()).unwrap()); } fn create_context(array_len: usize, batch_size: usize) -> Result { diff --git a/datafusion/core/benches/map_query_sql.rs b/datafusion/core/benches/map_query_sql.rs index 97d47fc3b9079..67904197bc257 100644 --- a/datafusion/core/benches/map_query_sql.rs +++ b/datafusion/core/benches/map_query_sql.rs @@ -15,13 +15,15 @@ // specific language governing permissions and limitations // under the License. +use std::collections::HashSet; +use std::hint::black_box; use std::sync::Arc; use arrow::array::{ArrayRef, Int32Array, RecordBatch}; -use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use criterion::{Criterion, criterion_group, criterion_main}; use parking_lot::Mutex; -use rand::prelude::ThreadRng; use rand::Rng; +use rand::prelude::ThreadRng; use tokio::runtime::Runtime; use datafusion::prelude::SessionContext; @@ -32,11 +34,12 @@ use datafusion_functions_nested::map::map; mod data_utils; fn build_keys(rng: &mut ThreadRng) -> Vec { - let mut keys = vec![]; - for _ in 0..1000 { - keys.push(rng.random_range(0..9999).to_string()); + let mut keys = HashSet::with_capacity(1000); + while keys.len() < 1000 { + let key = rng.random_range(0..9999).to_string(); + keys.insert(key); } - keys + keys.into_iter().collect() } fn build_values(rng: &mut ThreadRng) -> Vec { @@ -71,8 +74,11 @@ fn criterion_benchmark(c: &mut Criterion) { let mut value_buffer = Vec::new(); for i in 0..1000 { - key_buffer.push(Expr::Literal(ScalarValue::Utf8(Some(keys[i].clone())))); - value_buffer.push(Expr::Literal(ScalarValue::Int32(Some(values[i])))); + key_buffer.push(Expr::Literal( + ScalarValue::Utf8(Some(keys[i].clone())), + None, + )); + 
value_buffer.push(Expr::Literal(ScalarValue::Int32(Some(values[i])), None)); } c.bench_function("map_1000_1", |b| { b.iter(|| { diff --git a/datafusion/core/benches/math_query_sql.rs b/datafusion/core/benches/math_query_sql.rs index 76824850c114c..f5df56e95a2d8 100644 --- a/datafusion/core/benches/math_query_sql.rs +++ b/datafusion/core/benches/math_query_sql.rs @@ -15,18 +15,13 @@ // specific language governing permissions and limitations // under the License. -#[macro_use] -extern crate criterion; -use criterion::Criterion; +use criterion::{Criterion, criterion_group, criterion_main}; use parking_lot::Mutex; use std::sync::Arc; use tokio::runtime::Runtime; -extern crate arrow; -extern crate datafusion; - use arrow::{ array::{Float32Array, Float64Array}, datatypes::{DataType, Field, Schema}, @@ -36,6 +31,7 @@ use datafusion::datasource::MemTable; use datafusion::error::Result; use datafusion::execution::context::SessionContext; +#[expect(clippy::needless_pass_by_value)] fn query(ctx: Arc>, rt: &Runtime, sql: &str) { // execute the query let df = rt.block_on(ctx.lock().sql(sql)).unwrap(); diff --git a/datafusion/core/benches/parquet_query_sql.rs b/datafusion/core/benches/parquet_query_sql.rs index 14dcdf15f173b..f099137973592 100644 --- a/datafusion/core/benches/parquet_query_sql.rs +++ b/datafusion/core/benches/parquet_query_sql.rs @@ -23,14 +23,14 @@ use arrow::datatypes::{ SchemaRef, }; use arrow::record_batch::RecordBatch; -use criterion::{criterion_group, criterion_main, Criterion}; +use criterion::{Criterion, criterion_group, criterion_main}; use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_common::instant::Instant; use futures::stream::StreamExt; use parquet::arrow::ArrowWriter; use parquet::file::properties::{WriterProperties, WriterVersion}; -use rand::distr::uniform::SampleUniform; use rand::distr::Alphanumeric; +use rand::distr::uniform::SampleUniform; use rand::prelude::*; use rand::rng; use std::fs::File; @@ -45,7 +45,7 @@ const 
NUM_BATCHES: usize = 2048; /// The number of rows in each record batch to write const WRITE_RECORD_BATCH_SIZE: usize = 1024; /// The number of rows in a row group -const ROW_GROUP_SIZE: usize = 1024 * 1024; +const ROW_GROUP_ROW_COUNT: usize = 1024 * 1024; /// The number of row groups expected const EXPECTED_ROW_GROUPS: usize = 2; @@ -154,7 +154,7 @@ fn generate_file() -> NamedTempFile { let properties = WriterProperties::builder() .set_writer_version(WriterVersion::PARQUET_2_0) - .set_max_row_group_size(ROW_GROUP_SIZE) + .set_max_row_group_row_count(Some(ROW_GROUP_ROW_COUNT)) .build(); let mut writer = @@ -166,11 +166,12 @@ fn generate_file() -> NamedTempFile { } let metadata = writer.close().unwrap(); + let file_metadata = metadata.file_metadata(); assert_eq!( - metadata.num_rows as usize, + file_metadata.num_rows() as usize, WRITE_RECORD_BATCH_SIZE * NUM_BATCHES ); - assert_eq!(metadata.row_groups.len(), EXPECTED_ROW_GROUPS); + assert_eq!(metadata.row_groups().len(), EXPECTED_ROW_GROUPS); println!( "Generated parquet file in {} seconds", diff --git a/datafusion/core/benches/parquet_struct_query.rs b/datafusion/core/benches/parquet_struct_query.rs new file mode 100644 index 0000000000000..e7e91f0dd0e1e --- /dev/null +++ b/datafusion/core/benches/parquet_struct_query.rs @@ -0,0 +1,312 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Benchmarks of SQL queries on struct columns in parquet data + +use arrow::array::{ArrayRef, Int32Array, StringArray, StructArray}; +use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef}; +use arrow::record_batch::RecordBatch; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion::prelude::SessionContext; +use datafusion_common::instant::Instant; +use parquet::arrow::ArrowWriter; +use parquet::file::properties::{WriterProperties, WriterVersion}; +use rand::distr::Alphanumeric; +use rand::prelude::*; +use rand::rng; +use std::hint::black_box; +use std::ops::Range; +use std::path::Path; +use std::sync::Arc; +use tempfile::NamedTempFile; +use tokio::runtime::Runtime; + +/// The number of batches to write +const NUM_BATCHES: usize = 128; +/// The number of rows in each record batch to write +const WRITE_RECORD_BATCH_SIZE: usize = 4096; +/// The number of rows in a row group +const ROW_GROUP_ROW_COUNT: usize = 65536; +/// The number of row groups expected +const EXPECTED_ROW_GROUPS: usize = 8; +/// The range for random string lengths +const STRING_LENGTH_RANGE: Range = 50..200; + +fn schema() -> SchemaRef { + let struct_fields = Fields::from(vec![ + Field::new("id", DataType::Int32, false), + Field::new("value", DataType::Utf8, false), + ]); + let struct_type = DataType::Struct(struct_fields); + + Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("s", struct_type, false), + ])) +} + +fn generate_strings(len: usize) -> ArrayRef { + let mut rng = rng(); + 
Arc::new(StringArray::from_iter((0..len).map(|_| { + let string_len = rng.random_range(STRING_LENGTH_RANGE.clone()); + Some( + (0..string_len) + .map(|_| char::from(rng.sample(Alphanumeric))) + .collect::(), + ) + }))) +} + +fn generate_batch(batch_id: usize) -> RecordBatch { + let schema = schema(); + let len = WRITE_RECORD_BATCH_SIZE; + + // Generate sequential IDs based on batch_id for uniqueness + let base_id = (batch_id * len) as i32; + let id_values: Vec = (0..len).map(|i| base_id + i as i32).collect(); + let id_array = Arc::new(Int32Array::from(id_values.clone())); + + // Create struct id array (matching top-level id) + let struct_id_array = Arc::new(Int32Array::from(id_values)); + + // Generate random strings for struct value field + let value_array = generate_strings(len); + + // Construct StructArray + let struct_array = StructArray::from(vec![ + ( + Arc::new(Field::new("id", DataType::Int32, false)), + struct_id_array as ArrayRef, + ), + ( + Arc::new(Field::new("value", DataType::Utf8, false)), + value_array, + ), + ]); + + RecordBatch::try_new(schema, vec![id_array, Arc::new(struct_array)]).unwrap() +} + +fn generate_file() -> NamedTempFile { + let now = Instant::now(); + let mut named_file = tempfile::Builder::new() + .prefix("parquet_struct_query") + .suffix(".parquet") + .tempfile() + .unwrap(); + + println!("Generating parquet file - {}", named_file.path().display()); + let schema = schema(); + + let properties = WriterProperties::builder() + .set_writer_version(WriterVersion::PARQUET_2_0) + .set_max_row_group_row_count(Some(ROW_GROUP_ROW_COUNT)) + .build(); + + let mut writer = + ArrowWriter::try_new(&mut named_file, schema, Some(properties)).unwrap(); + + for batch_id in 0..NUM_BATCHES { + let batch = generate_batch(batch_id); + writer.write(&batch).unwrap(); + } + + let metadata = writer.close().unwrap(); + let file_metadata = metadata.file_metadata(); + let expected_rows = WRITE_RECORD_BATCH_SIZE * NUM_BATCHES; + assert_eq!( + 
file_metadata.num_rows() as usize, + expected_rows, + "Expected {} rows but got {}", + expected_rows, + file_metadata.num_rows() + ); + assert_eq!( + metadata.row_groups().len(), + EXPECTED_ROW_GROUPS, + "Expected {} row groups but got {}", + EXPECTED_ROW_GROUPS, + metadata.row_groups().len() + ); + + println!( + "Generated parquet file with {} rows and {} row groups in {} seconds", + file_metadata.num_rows(), + metadata.row_groups().len(), + now.elapsed().as_secs_f32() + ); + + named_file +} + +fn create_context(file_path: &str) -> SessionContext { + let ctx = SessionContext::new(); + let rt = Runtime::new().unwrap(); + rt.block_on(ctx.register_parquet("t", file_path, Default::default())) + .unwrap(); + ctx +} + +fn query(ctx: &SessionContext, rt: &Runtime, sql: &str) { + let ctx = ctx.clone(); + let sql = sql.to_string(); + let df = rt.block_on(ctx.sql(&sql)).unwrap(); + black_box(rt.block_on(df.collect()).unwrap()); +} + +fn criterion_benchmark(c: &mut Criterion) { + let (file_path, temp_file) = match std::env::var("PARQUET_FILE") { + Ok(file) => (file, None), + Err(_) => { + let temp_file = generate_file(); + (temp_file.path().display().to_string(), Some(temp_file)) + } + }; + + assert!(Path::new(&file_path).exists(), "path not found"); + println!("Using parquet file {file_path}"); + + let ctx = create_context(&file_path); + let rt = Runtime::new().unwrap(); + + // Basic struct access + c.bench_function("struct_access", |b| { + b.iter(|| query(&ctx, &rt, "select id, s['id'] from t")) + }); + + // Filter queries + c.bench_function("filter_struct_field_eq", |b| { + b.iter(|| query(&ctx, &rt, "select id from t where s['id'] = 5")) + }); + + c.bench_function("filter_struct_field_with_select", |b| { + b.iter(|| query(&ctx, &rt, "select id, s['id'] from t where s['id'] = 5")) + }); + + c.bench_function("filter_top_level_with_struct_select", |b| { + b.iter(|| query(&ctx, &rt, "select s['id'] from t where id = 5")) + }); + + 
c.bench_function("filter_struct_string_length", |b| { + b.iter(|| query(&ctx, &rt, "select id from t where length(s['value']) > 100")) + }); + + c.bench_function("filter_struct_range", |b| { + b.iter(|| { + query( + &ctx, + &rt, + "select id from t where s['id'] > 100 and s['id'] < 200", + ) + }) + }); + + // Join queries (limited with WHERE id < 1000 for performance) + c.bench_function("join_struct_to_struct", |b| { + b.iter(|| query( + &ctx, + &rt, + "select t1.id from t t1 join t t2 on t1.s['id'] = t2.s['id'] where t1.id < 1000" + )) + }); + + c.bench_function("join_struct_to_toplevel", |b| { + b.iter(|| query( + &ctx, + &rt, + "select t1.id from t t1 join t t2 on t1.s['id'] = t2.id where t1.id < 1000" + )) + }); + + c.bench_function("join_toplevel_to_struct", |b| { + b.iter(|| query( + &ctx, + &rt, + "select t1.id from t t1 join t t2 on t1.id = t2.s['id'] where t1.id < 1000" + )) + }); + + c.bench_function("join_struct_to_struct_with_top_level", |b| { + b.iter(|| query( + &ctx, + &rt, + "select t1.id from t t1 join t t2 on t1.s['id'] = t2.s['id'] and t1.id = t2.id where t1.id < 1000" + )) + }); + + c.bench_function("join_struct_and_struct_value", |b| { + b.iter(|| query( + &ctx, + &rt, + "select t1.s['id'], t2.s['value'] from t t1 join t t2 on t1.id = t2.id where t1.id < 1000" + )) + }); + + // Group by queries + c.bench_function("group_by_struct_field", |b| { + b.iter(|| query(&ctx, &rt, "select s['id'] from t group by s['id']")) + }); + + c.bench_function("group_by_struct_select_toplevel", |b| { + b.iter(|| query(&ctx, &rt, "select max(id) from t group by s['id']")) + }); + + c.bench_function("group_by_toplevel_select_struct", |b| { + b.iter(|| query(&ctx, &rt, "select max(s['id']) from t group by id")) + }); + + c.bench_function("group_by_struct_with_count", |b| { + b.iter(|| { + query( + &ctx, + &rt, + "select s['id'], count(*) from t group by s['id']", + ) + }) + }); + + c.bench_function("group_by_multiple_with_count", |b| { + b.iter(|| { + query( + &ctx, 
+ &rt, + "select id, s['id'], count(*) from t group by id, s['id']", + ) + }) + }); + + // Additional queries + c.bench_function("order_by_struct_limit", |b| { + b.iter(|| { + query( + &ctx, + &rt, + "select id, s['id'] from t order by s['id'] limit 1000", + ) + }) + }); + + c.bench_function("distinct_struct_field", |b| { + b.iter(|| query(&ctx, &rt, "select distinct s['id'] from t")) + }); + + // Temporary file must outlive the benchmarks, it is deleted when dropped + drop(temp_file); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/core/benches/physical_plan.rs b/datafusion/core/benches/physical_plan.rs index 0a65c52f72def..7b66996b05929 100644 --- a/datafusion/core/benches/physical_plan.rs +++ b/datafusion/core/benches/physical_plan.rs @@ -15,11 +15,7 @@ // specific language governing permissions and limitations // under the License. -#[macro_use] -extern crate criterion; -use criterion::{BatchSize, Criterion}; -extern crate arrow; -extern crate datafusion; +use criterion::{BatchSize, Criterion, criterion_group, criterion_main}; use std::sync::Arc; @@ -32,7 +28,7 @@ use tokio::runtime::Runtime; use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use datafusion::physical_plan::{ collect, - expressions::{col, PhysicalSortExpr}, + expressions::{PhysicalSortExpr, col}, }; use datafusion::prelude::SessionContext; use datafusion_datasource::memory::MemorySourceConfig; @@ -40,6 +36,7 @@ use datafusion_physical_expr_common::sort_expr::LexOrdering; // Initialize the operator using the provided record batches and the sort key // as inputs. All record batches must have the same schema. 
+#[expect(clippy::needless_pass_by_value)] fn sort_preserving_merge_operator( session_ctx: Arc, rt: &Runtime, @@ -50,11 +47,8 @@ fn sort_preserving_merge_operator( let sort = sort .iter() - .map(|name| PhysicalSortExpr { - expr: col(name, &schema).unwrap(), - options: Default::default(), - }) - .collect::(); + .map(|name| PhysicalSortExpr::new_default(col(name, &schema).unwrap())); + let sort = LexOrdering::new(sort).unwrap(); let exec = MemorySourceConfig::try_new_exec( &batches.into_iter().map(|rb| vec![rb]).collect::>(), diff --git a/datafusion/core/benches/preserve_file_partitioning.rs b/datafusion/core/benches/preserve_file_partitioning.rs new file mode 100644 index 0000000000000..9b1f59adc6823 --- /dev/null +++ b/datafusion/core/benches/preserve_file_partitioning.rs @@ -0,0 +1,838 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Benchmark for `preserve_file_partitions` optimization. +//! +//! When enabled, this optimization declares Hive-partitioned tables as +//! `Hash([partition_col])` partitioned, allowing the query optimizer to +//! skip unnecessary repartitioning and sorting operations. +//! +//! When This Optimization Helps +//! 
- Window functions: PARTITION BY on partition column eliminates RepartitionExec and SortExec +//! - Aggregates with ORDER BY: GROUP BY partition column and ORDER BY eliminates post aggregate sort +//! +//! When This Optimization Does NOT Help +//! - GROUP BY non-partition columns: Required Hash distribution doesn't match declared partitioning +//! - When the number of distinct file partitioning groups < the number of CPUs available: Reduces +//! parallelization, thus may outweigh the pros of reduced shuffles +//! +//! Usage +//! - BENCH_SIZE=small|medium|large cargo bench -p datafusion --bench preserve_file_partitions +//! - SAVE_PLANS=1 cargo bench ... # Save query plans to files + +use arrow::array::{ArrayRef, Float64Array, StringArray, TimestampMillisecondArray}; +use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; +use arrow::record_batch::RecordBatch; +use arrow::util::pretty::pretty_format_batches; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion::prelude::{ParquetReadOptions, SessionConfig, SessionContext, col}; +use datafusion_expr::SortExpr; +use parquet::arrow::ArrowWriter; +use parquet::file::properties::WriterProperties; +use std::fs::{self, File}; +use std::io::Write; +use std::path::Path; +use std::sync::Arc; +use tempfile::TempDir; +use tokio::runtime::Runtime; + +#[derive(Debug, Clone, Copy)] +struct BenchConfig { + fact_partitions: usize, + rows_per_partition: usize, + target_partitions: usize, + measurement_time_secs: u64, +} + +impl BenchConfig { + fn small() -> Self { + Self { + fact_partitions: 10, + rows_per_partition: 1_000_000, + target_partitions: 10, + measurement_time_secs: 15, + } + } + + fn medium() -> Self { + Self { + fact_partitions: 30, + rows_per_partition: 3_000_000, + target_partitions: 30, + measurement_time_secs: 30, + } + } + + fn large() -> Self { + Self { + fact_partitions: 50, + rows_per_partition: 6_000_000, + target_partitions: 50, + measurement_time_secs: 90, + } + } + + fn from_env() 
-> Self { + match std::env::var("BENCH_SIZE").as_deref() { + Ok("small") | Ok("SMALL") => Self::small(), + Ok("medium") | Ok("MEDIUM") => Self::medium(), + Ok("large") | Ok("LARGE") => Self::large(), + _ => { + println!("Using SMALL dataset (set BENCH_SIZE=small|medium|large)"); + Self::small() + } + } + } + + fn total_rows(&self) -> usize { + self.fact_partitions * self.rows_per_partition + } + + fn high_cardinality(base: &Self) -> Self { + Self { + fact_partitions: (base.fact_partitions as f64 * 2.5) as usize, + rows_per_partition: base.rows_per_partition / 2, + target_partitions: base.target_partitions, + measurement_time_secs: base.measurement_time_secs, + } + } +} + +fn dkey_names(count: usize) -> Vec { + (0..count) + .map(|i| { + if i < 26 { + ((b'A' + i as u8) as char).to_string() + } else { + format!( + "{}{}", + (b'A' + ((i / 26) - 1) as u8) as char, + (b'A' + (i % 26) as u8) as char + ) + } + }) + .collect() +} + +/// Hive-partitioned fact table, sorted by timestamp within each partition. 
+fn generate_fact_table( + base_dir: &Path, + num_partitions: usize, + rows_per_partition: usize, +) { + let fact_dir = base_dir.join("fact"); + + let schema = Arc::new(Schema::new(vec![ + Field::new( + "timestamp", + DataType::Timestamp(TimeUnit::Millisecond, None), + false, + ), + Field::new("value", DataType::Float64, false), + ])); + + let props = WriterProperties::builder() + .set_compression(parquet::basic::Compression::SNAPPY) + .build(); + + let dkeys = dkey_names(num_partitions); + + for dkey in &dkeys { + let part_dir = fact_dir.join(format!("f_dkey={dkey}")); + fs::create_dir_all(&part_dir).unwrap(); + let file_path = part_dir.join("data.parquet"); + let file = File::create(file_path).unwrap(); + + let mut writer = + ArrowWriter::try_new(file, schema.clone(), Some(props.clone())).unwrap(); + + let base_ts = 1672567200000i64; // 2023-01-01T09:00:00 + let timestamps: Vec = (0..rows_per_partition) + .map(|i| base_ts + (i as i64 * 10000)) + .collect(); + + let values: Vec = (0..rows_per_partition) + .map(|i| 50.0 + (i % 100) as f64 + ((i % 7) as f64 * 10.0)) + .collect(); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(TimestampMillisecondArray::from(timestamps)) as ArrayRef, + Arc::new(Float64Array::from(values)), + ], + ) + .unwrap(); + + writer.write(&batch).unwrap(); + writer.close().unwrap(); + } +} + +/// Single-file dimension table for CollectLeft joins. 
+fn generate_dimension_table(base_dir: &Path, num_partitions: usize) { + let dim_dir = base_dir.join("dimension"); + fs::create_dir_all(&dim_dir).unwrap(); + + let schema = Arc::new(Schema::new(vec![ + Field::new("d_dkey", DataType::Utf8, false), + Field::new("env", DataType::Utf8, false), + Field::new("service", DataType::Utf8, false), + Field::new("host", DataType::Utf8, false), + ])); + + let props = WriterProperties::builder() + .set_compression(parquet::basic::Compression::SNAPPY) + .build(); + + let file_path = dim_dir.join("data.parquet"); + let file = File::create(file_path).unwrap(); + let mut writer = ArrowWriter::try_new(file, schema.clone(), Some(props)).unwrap(); + + let dkeys = dkey_names(num_partitions); + let envs = ["dev", "prod", "staging", "test"]; + let services = ["log", "trace", "metric"]; + let hosts = ["ma", "vim", "nano", "emacs"]; + + let d_dkey_vals: Vec = dkeys.clone(); + let env_vals: Vec = dkeys + .iter() + .enumerate() + .map(|(i, _)| envs[i % envs.len()].to_string()) + .collect(); + let service_vals: Vec = dkeys + .iter() + .enumerate() + .map(|(i, _)| services[i % services.len()].to_string()) + .collect(); + let host_vals: Vec = dkeys + .iter() + .enumerate() + .map(|(i, _)| hosts[i % hosts.len()].to_string()) + .collect(); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(StringArray::from(d_dkey_vals)) as ArrayRef, + Arc::new(StringArray::from(env_vals)), + Arc::new(StringArray::from(service_vals)), + Arc::new(StringArray::from(host_vals)), + ], + ) + .unwrap(); + + writer.write(&batch).unwrap(); + writer.close().unwrap(); +} + +struct BenchVariant { + name: &'static str, + preserve_file_partitions: usize, + prefer_existing_sort: bool, +} + +const BENCH_VARIANTS: [BenchVariant; 3] = [ + BenchVariant { + name: "with_optimization", + preserve_file_partitions: 1, + prefer_existing_sort: false, + }, + BenchVariant { + name: "prefer_existing_sort", + preserve_file_partitions: 0, + prefer_existing_sort: true, + 
}, + BenchVariant { + name: "without_optimization", + preserve_file_partitions: 0, + prefer_existing_sort: false, + }, +]; + +async fn save_plans( + output_file: &Path, + fact_path: &str, + dim_path: Option<&str>, + target_partitions: usize, + query: &str, + file_sort_order: Option>>, +) { + let mut file = File::create(output_file).unwrap(); + writeln!(file, "Query: {query}\n").unwrap(); + + for variant in &BENCH_VARIANTS { + let session_config = SessionConfig::new() + .with_target_partitions(target_partitions) + .set_usize( + "datafusion.optimizer.preserve_file_partitions", + variant.preserve_file_partitions, + ) + .set_bool( + "datafusion.optimizer.prefer_existing_sort", + variant.prefer_existing_sort, + ); + let ctx = SessionContext::new_with_config(session_config); + + let mut fact_options = ParquetReadOptions { + table_partition_cols: vec![("f_dkey".to_string(), DataType::Utf8)], + ..Default::default() + }; + if let Some(ref order) = file_sort_order { + fact_options.file_sort_order = order.clone(); + } + ctx.register_parquet("fact", fact_path, fact_options) + .await + .unwrap(); + + if let Some(dim) = dim_path { + let dim_schema = Arc::new(Schema::new(vec![ + Field::new("d_dkey", DataType::Utf8, false), + Field::new("env", DataType::Utf8, false), + Field::new("service", DataType::Utf8, false), + Field::new("host", DataType::Utf8, false), + ])); + let dim_options = ParquetReadOptions { + schema: Some(&dim_schema), + ..Default::default() + }; + ctx.register_parquet("dimension", dim, dim_options) + .await + .unwrap(); + } + + let df = ctx.sql(query).await.unwrap(); + let plan = df.explain(false, false).unwrap().collect().await.unwrap(); + writeln!(file, "=== {} ===", variant.name).unwrap(); + writeln!(file, "{}\n", pretty_format_batches(&plan).unwrap()).unwrap(); + } +} + +#[expect(clippy::too_many_arguments)] +fn run_benchmark( + c: &mut Criterion, + rt: &Runtime, + name: &str, + fact_path: &str, + dim_path: Option<&str>, + target_partitions: usize, + query: 
&str, + file_sort_order: &Option>>, +) { + if std::env::var("SAVE_PLANS").is_ok() { + let output_path = format!("{name}_plans.txt"); + rt.block_on(save_plans( + Path::new(&output_path), + fact_path, + dim_path, + target_partitions, + query, + file_sort_order.clone(), + )); + println!("Plans saved to {output_path}"); + } + + let mut group = c.benchmark_group(name); + + for variant in &BENCH_VARIANTS { + let fact_path_owned = fact_path.to_string(); + let dim_path_owned = dim_path.map(|s| s.to_string()); + let sort_order = file_sort_order.clone(); + let query_owned = query.to_string(); + let preserve_file_partitions = variant.preserve_file_partitions; + let prefer_existing_sort = variant.prefer_existing_sort; + + group.bench_function(variant.name, |b| { + b.to_async(rt).iter(|| { + let fact_path = fact_path_owned.clone(); + let dim_path = dim_path_owned.clone(); + let sort_order = sort_order.clone(); + let query = query_owned.clone(); + async move { + let session_config = SessionConfig::new() + .with_target_partitions(target_partitions) + .set_usize( + "datafusion.optimizer.preserve_file_partitions", + preserve_file_partitions, + ) + .set_bool( + "datafusion.optimizer.prefer_existing_sort", + prefer_existing_sort, + ); + let ctx = SessionContext::new_with_config(session_config); + + let mut fact_options = ParquetReadOptions { + table_partition_cols: vec![( + "f_dkey".to_string(), + DataType::Utf8, + )], + ..Default::default() + }; + if let Some(ref order) = sort_order { + fact_options.file_sort_order = order.clone(); + } + ctx.register_parquet("fact", &fact_path, fact_options) + .await + .unwrap(); + + if let Some(ref dim) = dim_path { + let dim_schema = Arc::new(Schema::new(vec![ + Field::new("d_dkey", DataType::Utf8, false), + Field::new("env", DataType::Utf8, false), + Field::new("service", DataType::Utf8, false), + Field::new("host", DataType::Utf8, false), + ])); + let dim_options = ParquetReadOptions { + schema: Some(&dim_schema), + ..Default::default() + }; + 
ctx.register_parquet("dimension", dim, dim_options) + .await + .unwrap(); + } + + let df = ctx.sql(&query).await.unwrap(); + df.collect().await.unwrap() + } + }) + }); + } + + group.finish(); +} + +/// Aggregate on high-cardinality partitions which eliminates repartition and sort. +/// +/// Query: SELECT f_dkey, COUNT(*), SUM(value) FROM fact GROUP BY f_dkey ORDER BY f_dkey +/// +/// ┌─────────────────────────────────────────────────────────────────────────────────────────────────────────┐ +/// │ with_optimization │ +/// │ (preserve_file_partitions=enabled) │ +/// │ │ +/// │ ┌───────────────────────────┐ │ +/// │ │ SortPreservingMergeExec │ Sort Preserved │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ AggregateExec │ No repartitioning needed │ +/// │ │ (SinglePartitioned) │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ DataSourceExec │ partitioning=Hash([f_dkey]) │ +/// │ │ file_groups={N groups} │ │ +/// │ └───────────────────────────┘ │ +/// └─────────────────────────────────────────────────────────────────────────────────────────────────────────┘ +/// +/// ┌─────────────────────────────────────────────────────────────────────────────────────────────────────────┐ +/// │ prefer_existing_sort │ +/// │ (preserve_file_partitions=disabled, prefer_existing_sort=true) │ +/// │ │ +/// │ ┌───────────────────────────┐ │ +/// │ │ SortPreservingMergeExec │ Sort Preserved │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ AggregateExec │ │ +/// │ │ (FinalPartitioned) │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ RepartitionExec │ Hash shuffle with order preservation │ +/// │ │ Hash([f_dkey], N) │ Uses k-way merge to maintain sort, has overhead │ +/// │ │ preserve_order=true │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ 
+/// │ │ AggregateExec │ │ +/// │ │ (Partial) │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ DataSourceExec │ partitioning=UnknownPartitioning │ +/// │ └───────────────────────────┘ │ +/// └─────────────────────────────────────────────────────────────────────────────────────────────────────────┘ +/// +/// ┌─────────────────────────────────────────────────────────────────────────────────────────────────────────┐ +/// │ without_optimization │ +/// │ (preserve_file_partitions=disabled, prefer_existing_sort=false) │ +/// │ │ +/// │ ┌───────────────────────────┐ │ +/// │ │ SortPreservingMergeExec │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ AggregateExec │ FinalPartitioned │ +/// │ │ (FinalPartitioned) │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ SortExec │ Must sort after shuffle │ +/// │ │ [f_dkey ASC] │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ RepartitionExec │ Hash shuffle destroys ordering │ +/// │ │ Hash([f_dkey], N) │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ AggregateExec │ │ +/// │ │ (Partial) │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ DataSourceExec │ partitioning=UnknownPartitioning │ +/// │ └───────────────────────────┘ │ +/// └─────────────────────────────────────────────────────────────────────────────────────────────────────────┘ +fn preserve_order_bench( + c: &mut Criterion, + rt: &Runtime, + hc_fact_path: &str, + target_partitions: usize, +) { + let query = "SELECT f_dkey, COUNT(*) as cnt, SUM(value) as total \ + FROM fact \ + GROUP BY f_dkey \ + ORDER BY f_dkey"; + + let file_sort_order = vec![vec![col("f_dkey").sort(true, false)]]; + + run_benchmark( + c, + rt, + "preserve_order", + hc_fact_path, + None, + 
target_partitions, + query, + &Some(file_sort_order), + ); +} + +/// Join and aggregate on partition column which demonstrates propagation through join. +/// +/// Query: SELECT f.f_dkey, MAX(d.env), ... FROM fact f JOIN dimension d ON f.f_dkey = d.d_dkey +/// WHERE d.service = 'log' GROUP BY f.f_dkey ORDER BY f.f_dkey +/// +/// ┌─────────────────────────────────────────────────────────────────────────────────────────────────────────┐ +/// │ with_optimization │ +/// │ (preserve_file_partitions=enabled) │ +/// │ │ +/// │ ┌───────────────────────────┐ │ +/// │ │ SortPreservingMergeExec │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ AggregateExec │ Hash partitioning propagates through join │ +/// │ │ (SinglePartitioned) │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ HashJoinExec │ Hash partitioning preserved on probe side │ +/// │ │ (CollectLeft) │ │ +/// │ └──────────┬────────────────┘ │ +/// │ │ │ +/// │ ┌──────┴──────┐ │ +/// │ │ │ │ +/// │ ┌───▼───┐ ┌────▼────────────────┐ │ +/// │ │ Dim │ │ DataSourceExec │ partitioning=Hash([f_dkey]), output_ordering=[f_dkey] │ +/// │ │ Table │ │ (fact, N groups) │ │ +/// │ └───────┘ └─────────────────────┘ │ +/// └─────────────────────────────────────────────────────────────────────────────────────────────────────────┘ +/// +/// ┌─────────────────────────────────────────────────────────────────────────────────────────────────────────┐ +/// │ prefer_existing_sort │ +/// │ (preserve_file_partitions=disabled, prefer_existing_sort=true) │ +/// │ │ +/// │ ┌───────────────────────────┐ │ +/// │ │ SortPreservingMergeExec │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ AggregateExec │ │ +/// │ │ (FinalPartitioned) │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ RepartitionExec │ Hash shuffle with order preservation │ +/// 
│ │ preserve_order=true │ Uses k-way merge to maintain sort, has overhead │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ AggregateExec │ │ +/// │ │ (Partial) │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ HashJoinExec │ │ +/// │ │ (CollectLeft) │ │ +/// │ └──────────┬────────────────┘ │ +/// │ │ │ +/// │ ┌──────┴──────┐ │ +/// │ │ │ │ +/// │ ┌───▼───┐ ┌────▼────────────────┐ │ +/// │ │ Dim │ │ DataSourceExec │ partitioning=UnknownPartitioning, output_ordering=[f_dkey] │ +/// │ │ Table │ │ (fact) │ │ +/// │ └───────┘ └─────────────────────┘ │ +/// └─────────────────────────────────────────────────────────────────────────────────────────────────────────┘ +/// +/// ┌─────────────────────────────────────────────────────────────────────────────────────────────────────────┐ +/// │ without_optimization │ +/// │ (preserve_file_partitions=disabled, prefer_existing_sort=false) │ +/// │ │ +/// │ ┌───────────────────────────┐ │ +/// │ │ SortPreservingMergeExec │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ AggregateExec │ │ +/// │ │ (FinalPartitioned) │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ SortExec │ Must sort after shuffle │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ RepartitionExec │ Hash shuffle destroys ordering │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ AggregateExec │ │ +/// │ │ (Partial) │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ HashJoinExec │ │ +/// │ │ (CollectLeft) │ │ +/// │ └──────────┬────────────────┘ │ +/// │ │ │ +/// │ ┌──────┴──────┐ │ +/// │ │ │ │ +/// │ ┌───▼───┐ ┌────▼────────────────┐ │ +/// │ │ Dim │ │ DataSourceExec │ partitioning=UnknownPartitioning, 
output_ordering=[f_dkey] │ +/// │ │ Table │ │ (fact) │ │ +/// │ └───────┘ └─────────────────────┘ │ +/// └─────────────────────────────────────────────────────────────────────────────────────────────────────────┘ +fn preserve_order_join_bench( + c: &mut Criterion, + rt: &Runtime, + hc_fact_path: &str, + dim_path: &str, + target_partitions: usize, +) { + let query = "SELECT f.f_dkey, MAX(d.env), MAX(d.service), COUNT(*), SUM(f.value) \ + FROM fact f \ + INNER JOIN dimension d ON f.f_dkey = d.d_dkey \ + WHERE d.service = 'log' \ + GROUP BY f.f_dkey \ + ORDER BY f.f_dkey"; + + let file_sort_order = vec![vec![col("f_dkey").sort(true, false)]]; + + run_benchmark( + c, + rt, + "preserve_order_join", + hc_fact_path, + Some(dim_path), + target_partitions, + query, + &Some(file_sort_order), + ); +} + +/// Window function with LIMIT which demonstrates partition and sort elimination. +/// +/// Query: SELECT f_dkey, timestamp, value, +/// ROW_NUMBER() OVER (PARTITION BY f_dkey ORDER BY timestamp) as rn +/// FROM fact LIMIT 1000 +/// +/// ┌─────────────────────────────────────────────────────────────────────────────────────────────────────────┐ +/// │ with_optimization │ +/// │ (preserve_file_partitions=enabled) │ +/// │ │ +/// │ ┌───────────────────────────┐ │ +/// │ │ GlobalLimitExec │ │ +/// │ │ (LIMIT 1000) │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ BoundedWindowAggExec │ No repaartition needed │ +/// │ │ PARTITION BY f_dkey │ │ +/// │ │ ORDER BY timestamp │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ DataSourceExec │ partitioning=Hash([f_dkey]), output_ordering=[f_dkey, timestamp] │ +/// │ │ file_groups={N groups} │ │ +/// │ └───────────────────────────┘ │ +/// └─────────────────────────────────────────────────────────────────────────────────────────────────────────┘ +/// +/// 
┌─────────────────────────────────────────────────────────────────────────────────────────────────────────┐ +/// │ prefer_existing_sort │ +/// │ (preserve_file_partitions=disabled, prefer_existing_sort=true) │ +/// │ │ +/// │ ┌───────────────────────────┐ │ +/// │ │ GlobalLimitExec │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ BoundedWindowAggExec │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ RepartitionExec │ Hash shuffle with order preservation │ +/// │ │ Hash([f_dkey], N) │ Uses k-way merge to maintain sort, has overhead │ +/// │ │ preserve_order=true │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ DataSourceExec │ partitioning=UnknownPartitioning, output_ordering=[f_dkey, timestamp] │ +/// │ └───────────────────────────┘ │ +/// └─────────────────────────────────────────────────────────────────────────────────────────────────────────┘ +/// +/// ┌─────────────────────────────────────────────────────────────────────────────────────────────────────────┐ +/// │ without_optimization │ +/// │ (preserve_file_partitions=disabled, prefer_existing_sort=false) │ +/// │ │ +/// │ ┌───────────────────────────┐ │ +/// │ │ GlobalLimitExec │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ BoundedWindowAggExec │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ SortExec │ Must sort after shuffle │ +/// │ │ [f_dkey, timestamp] │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ RepartitionExec │ Hash shuffle destroys ordering │ +/// │ │ Hash([f_dkey], N) │ │ +/// │ └─────────────┬─────────────┘ │ +/// │ │ │ +/// │ ┌─────────────▼─────────────┐ │ +/// │ │ DataSourceExec │ partitioning=UnknownPartitioning, output_ordering=[f_dkey, timestamp] │ +/// │ 
└───────────────────────────┘ │ +/// └─────────────────────────────────────────────────────────────────────────────────────────────────────────┘ +fn preserve_order_window_bench( + c: &mut Criterion, + rt: &Runtime, + fact_path: &str, + target_partitions: usize, +) { + let query = "SELECT f_dkey, timestamp, value, \ + ROW_NUMBER() OVER (PARTITION BY f_dkey ORDER BY timestamp) as rn \ + FROM fact \ + LIMIT 1000"; + + let file_sort_order = vec![vec![ + col("f_dkey").sort(true, false), + col("timestamp").sort(true, false), + ]]; + + run_benchmark( + c, + rt, + "preserve_order_window", + fact_path, + None, + target_partitions, + query, + &Some(file_sort_order), + ); +} + +fn benchmark_main(c: &mut Criterion) { + let config = BenchConfig::from_env(); + let hc_config = BenchConfig::high_cardinality(&config); + + println!("\n=== Preserve File Partitioning Benchmark ==="); + println!( + "Normal config: {} partitions × {} rows = {} total rows", + config.fact_partitions, + config.rows_per_partition, + config.total_rows() + ); + println!( + "High-cardinality config: {} partitions × {} rows = {} total rows", + hc_config.fact_partitions, + hc_config.rows_per_partition, + hc_config.total_rows() + ); + println!("Target partitions: {}\n", config.target_partitions); + + let tmp_dir = TempDir::new().unwrap(); + println!("Generating data..."); + + // High-cardinality fact table + generate_fact_table( + tmp_dir.path(), + hc_config.fact_partitions, + hc_config.rows_per_partition, + ); + let hc_fact_dir = tmp_dir.path().join("fact_hc"); + fs::rename(tmp_dir.path().join("fact"), &hc_fact_dir).unwrap(); + let hc_fact_path = hc_fact_dir.to_str().unwrap().to_string(); + + // Normal fact table + generate_fact_table( + tmp_dir.path(), + config.fact_partitions, + config.rows_per_partition, + ); + let fact_path = tmp_dir.path().join("fact").to_str().unwrap().to_string(); + + // Dimension table (for join) + generate_dimension_table(tmp_dir.path(), hc_config.fact_partitions); + let dim_path = 
tmp_dir + .path() + .join("dimension") + .to_str() + .unwrap() + .to_string(); + + println!("Done.\n"); + + let rt = Runtime::new().unwrap(); + + preserve_order_bench(c, &rt, &hc_fact_path, hc_config.target_partitions); + preserve_order_join_bench( + c, + &rt, + &hc_fact_path, + &dim_path, + hc_config.target_partitions, + ); + preserve_order_window_bench(c, &rt, &fact_path, config.target_partitions); +} + +criterion_group! { + name = benches; + config = { + let config = BenchConfig::from_env(); + Criterion::default() + .measurement_time(std::time::Duration::from_secs(config.measurement_time_secs)) + .sample_size(10) + }; + targets = benchmark_main +} +criterion_main!(benches); diff --git a/datafusion/core/benches/push_down_filter.rs b/datafusion/core/benches/push_down_filter.rs index 139fb12c30947..d41085907dbc8 100644 --- a/datafusion/core/benches/push_down_filter.rs +++ b/datafusion/core/benches/push_down_filter.rs @@ -18,16 +18,16 @@ use arrow::array::RecordBatch; use arrow::datatypes::{DataType, Field, Schema}; use bytes::{BufMut, BytesMut}; -use criterion::{criterion_group, criterion_main, Criterion}; +use criterion::{Criterion, criterion_group, criterion_main}; use datafusion::config::ConfigOptions; use datafusion::prelude::{ParquetReadOptions, SessionContext}; use datafusion_execution::object_store::ObjectStoreUrl; -use datafusion_physical_optimizer::filter_pushdown::FilterPushdown; use datafusion_physical_optimizer::PhysicalOptimizerRule; +use datafusion_physical_optimizer::filter_pushdown::FilterPushdown; use datafusion_physical_plan::ExecutionPlan; use object_store::memory::InMemory; use object_store::path::Path; -use object_store::ObjectStore; +use object_store::{ObjectStore, ObjectStoreExt}; use parquet::arrow::ArrowWriter; use std::sync::Arc; diff --git a/datafusion/core/benches/range_and_generate_series.rs b/datafusion/core/benches/range_and_generate_series.rs new file mode 100644 index 0000000000000..10d560df0813e --- /dev/null +++ 
b/datafusion/core/benches/range_and_generate_series.rs @@ -0,0 +1,86 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod data_utils; + +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion::execution::context::SessionContext; +use parking_lot::Mutex; +use std::hint::black_box; +use std::sync::Arc; +use tokio::runtime::Runtime; + +#[expect(clippy::needless_pass_by_value)] +fn query(ctx: Arc>, rt: &Runtime, sql: &str) { + let df = rt.block_on(ctx.lock().sql(sql)).unwrap(); + black_box(rt.block_on(df.collect()).unwrap()); +} + +fn create_context() -> Arc> { + let ctx = SessionContext::new(); + Arc::new(Mutex::new(ctx)) +} + +fn criterion_benchmark(c: &mut Criterion) { + let ctx = create_context(); + let rt = Runtime::new().unwrap(); + + c.bench_function("range(1000000)", |b| { + b.iter(|| query(ctx.clone(), &rt, "SELECT value from range(1000000)")) + }); + + c.bench_function("generate_series(1000000)", |b| { + b.iter(|| { + query( + ctx.clone(), + &rt, + "SELECT value from generate_series(1000000)", + ) + }) + }); + + c.bench_function("range(0, 1000000, 5)", |b| { + b.iter(|| query(ctx.clone(), &rt, "SELECT value from range(0, 1000000, 5)")) + }); + + c.bench_function("generate_series(0, 
1000000, 5)", |b| { + b.iter(|| { + query( + ctx.clone(), + &rt, + "SELECT value from generate_series(0, 1000000, 5)", + ) + }) + }); + + c.bench_function("range(1000000, 0, -5)", |b| { + b.iter(|| query(ctx.clone(), &rt, "SELECT value from range(1000000, 0, -5)")) + }); + + c.bench_function("generate_series(1000000, 0, -5)", |b| { + b.iter(|| { + query( + ctx.clone(), + &rt, + "SELECT value from generate_series(1000000, 0, -5)", + ) + }) + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/core/benches/reset_plan_states.rs b/datafusion/core/benches/reset_plan_states.rs new file mode 100644 index 0000000000000..5afae7f43242d --- /dev/null +++ b/datafusion/core/benches/reset_plan_states.rs @@ -0,0 +1,200 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::sync::{Arc, LazyLock}; + +use arrow_schema::{DataType, Field, Fields, Schema, SchemaRef}; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion::prelude::SessionContext; +use datafusion_catalog::MemTable; +use datafusion_physical_plan::ExecutionPlan; +use datafusion_physical_plan::displayable; +use datafusion_physical_plan::execution_plan::reset_plan_states; +use tokio::runtime::Runtime; + +const NUM_FIELDS: usize = 1000; +const PREDICATE_LEN: usize = 50; + +static SCHEMA: LazyLock = LazyLock::new(|| { + Arc::new(Schema::new( + (0..NUM_FIELDS) + .map(|i| Arc::new(Field::new(format!("x_{i}"), DataType::Int64, false))) + .collect::(), + )) +}); + +fn col_name(i: usize) -> String { + format!("x_{i}") +} + +fn aggr_name(i: usize) -> String { + format!("aggr_{i}") +} + +fn physical_plan( + ctx: &SessionContext, + rt: &Runtime, + sql: &str, +) -> Arc { + rt.block_on(async { + ctx.sql(sql) + .await + .unwrap() + .create_physical_plan() + .await + .unwrap() + }) +} + +fn predicate(col_name: impl Fn(usize) -> String, len: usize) -> String { + let mut predicate = String::new(); + for i in 0..len { + if i > 0 { + predicate.push_str(" AND "); + } + predicate.push_str(&col_name(i)); + predicate.push_str(" = "); + predicate.push_str(&i.to_string()); + } + predicate +} + +/// Returns a typical plan for the query like: +/// +/// ```sql +/// SELECT aggr1(col1) as aggr1, aggr2(col2) as aggr2 FROM t +/// WHERE p1 +/// HAVING p2 +/// ``` +/// +/// Where `p1` and `p2` some long predicates. 
+/// +fn query1() -> String { + let mut query = String::new(); + query.push_str("SELECT "); + for i in 0..NUM_FIELDS { + if i > 0 { + query.push_str(", "); + } + query.push_str("AVG("); + query.push_str(&col_name(i)); + query.push_str(") AS "); + query.push_str(&aggr_name(i)); + } + query.push_str(" FROM t WHERE "); + query.push_str(&predicate(col_name, PREDICATE_LEN)); + query.push_str(" HAVING "); + query.push_str(&predicate(aggr_name, PREDICATE_LEN)); + query +} + +/// Returns a typical plan for the query like: +/// +/// ```sql +/// SELECT projection FROM t JOIN v ON t.a = v.a +/// WHERE p1 +/// ``` +/// +fn query2() -> String { + let mut query = String::new(); + query.push_str("SELECT "); + for i in (0..NUM_FIELDS).step_by(2) { + if i > 0 { + query.push_str(", "); + } + if (i / 2) % 2 == 0 { + query.push_str(&format!("t.{}", col_name(i))); + } else { + query.push_str(&format!("v.{}", col_name(i))); + } + } + query.push_str(" FROM t JOIN v ON t.x_0 = v.x_0 WHERE "); + + fn qualified_name(i: usize) -> String { + format!("t.{}", col_name(i)) + } + + query.push_str(&predicate(qualified_name, PREDICATE_LEN)); + query +} + +/// Returns a typical plan for the query like: +/// +/// ```sql +/// SELECT projection FROM t +/// WHERE p +/// ``` +/// +fn query3() -> String { + let mut query = String::new(); + query.push_str("SELECT "); + + // Create non-trivial projection. 
+ for i in 0..NUM_FIELDS / 2 { + if i > 0 { + query.push_str(", "); + } + query.push_str(&col_name(i * 2)); + query.push_str(" + "); + query.push_str(&col_name(i * 2 + 1)); + } + + query.push_str(" FROM t WHERE "); + query.push_str(&predicate(col_name, PREDICATE_LEN)); + query +} + +fn run_reset_states(b: &mut criterion::Bencher, plan: &Arc) { + b.iter(|| std::hint::black_box(reset_plan_states(Arc::clone(plan)).unwrap())); +} + +/// Benchmark is intended to measure overhead of actions, required to perform +/// making an independent instance of the execution plan to re-execute it, avoiding +/// re-planning stage. +fn bench_reset_plan_states(c: &mut Criterion) { + env_logger::init(); + + let rt = Runtime::new().unwrap(); + let ctx = SessionContext::new(); + ctx.register_table( + "t", + Arc::new(MemTable::try_new(Arc::clone(&SCHEMA), vec![vec![], vec![]]).unwrap()), + ) + .unwrap(); + + ctx.register_table( + "v", + Arc::new(MemTable::try_new(Arc::clone(&SCHEMA), vec![vec![], vec![]]).unwrap()), + ) + .unwrap(); + + macro_rules! bench_query { + ($query_producer: expr) => {{ + let sql = $query_producer(); + let plan = physical_plan(&ctx, &rt, &sql); + log::debug!("plan:\n{}", displayable(plan.as_ref()).indent(true)); + move |b| run_reset_states(b, &plan) + }}; + } + + c.bench_function("query1", bench_query!(query1)); + c.bench_function("query2", bench_query!(query2)); + c.bench_function("query3", bench_query!(query3)); +} + +criterion_group!(benches, bench_reset_plan_states); +criterion_main!(benches); diff --git a/datafusion/core/benches/scalar.rs b/datafusion/core/benches/scalar.rs index 540f7212e96e9..d06ed3f28b743 100644 --- a/datafusion/core/benches/scalar.rs +++ b/datafusion/core/benches/scalar.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. 
-use criterion::{criterion_group, criterion_main, Criterion}; +use criterion::{Criterion, criterion_group, criterion_main}; use datafusion::scalar::ScalarValue; fn criterion_benchmark(c: &mut Criterion) { diff --git a/datafusion/core/benches/sort.rs b/datafusion/core/benches/sort.rs index e1bc478b36f0a..4ba57a1530e81 100644 --- a/datafusion/core/benches/sort.rs +++ b/datafusion/core/benches/sort.rs @@ -71,7 +71,6 @@ use std::sync::Arc; use arrow::array::StringViewArray; use arrow::{ array::{DictionaryArray, Float64Array, Int64Array, StringArray}, - compute::SortOptions, datatypes::{Int32Type, Schema}, record_batch::RecordBatch, }; @@ -79,18 +78,18 @@ use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::{ execution::context::TaskContext, physical_plan::{ + ExecutionPlan, ExecutionPlanProperties, coalesce_partitions::CoalescePartitionsExec, - sorts::sort_preserving_merge::SortPreservingMergeExec, ExecutionPlan, - ExecutionPlanProperties, + sorts::sort_preserving_merge::SortPreservingMergeExec, }, prelude::SessionContext, }; use datafusion_datasource::memory::MemorySourceConfig; -use datafusion_physical_expr::{expressions::col, PhysicalSortExpr}; +use datafusion_physical_expr::{PhysicalSortExpr, expressions::col}; use datafusion_physical_expr_common::sort_expr::LexOrdering; /// Benchmarks for SortPreservingMerge stream -use criterion::{criterion_group, criterion_main, Criterion}; +use criterion::{Criterion, criterion_group, criterion_main}; use futures::StreamExt; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; @@ -272,14 +271,11 @@ impl BenchCase { /// Make sort exprs for each column in `schema` fn make_sort_exprs(schema: &Schema) -> LexOrdering { - schema + let sort_exprs = schema .fields() .iter() - .map(|f| PhysicalSortExpr { - expr: col(f.name(), schema).unwrap(), - options: SortOptions::default(), - }) - .collect() + .map(|f| PhysicalSortExpr::new_default(col(f.name(), schema).unwrap())); + LexOrdering::new(sort_exprs).unwrap() } /// 
Create streams of int64 (where approximately 1/3 values is repeated) @@ -359,14 +355,14 @@ fn utf8_high_cardinality_streams(sorted: bool) -> PartitionedBatches { /// Create a batch of (utf8_low, utf8_low, utf8_high) fn utf8_tuple_streams(sorted: bool) -> PartitionedBatches { - let mut gen = DataGenerator::new(); + let mut data_gen = DataGenerator::new(); // need to sort by the combined key, so combine them together - let mut tuples: Vec<_> = gen + let mut tuples: Vec<_> = data_gen .utf8_low_cardinality_values() .into_iter() - .zip(gen.utf8_low_cardinality_values()) - .zip(gen.utf8_high_cardinality_values()) + .zip(data_gen.utf8_low_cardinality_values()) + .zip(data_gen.utf8_high_cardinality_values()) .collect(); if sorted { @@ -392,14 +388,14 @@ fn utf8_tuple_streams(sorted: bool) -> PartitionedBatches { /// Create a batch of (utf8_view_low, utf8_view_low, utf8_view_high) fn utf8_view_tuple_streams(sorted: bool) -> PartitionedBatches { - let mut gen = DataGenerator::new(); + let mut data_gen = DataGenerator::new(); // need to sort by the combined key, so combine them together - let mut tuples: Vec<_> = gen + let mut tuples: Vec<_> = data_gen .utf8_low_cardinality_values() .into_iter() - .zip(gen.utf8_low_cardinality_values()) - .zip(gen.utf8_high_cardinality_values()) + .zip(data_gen.utf8_low_cardinality_values()) + .zip(data_gen.utf8_high_cardinality_values()) .collect(); if sorted { @@ -425,15 +421,15 @@ fn utf8_view_tuple_streams(sorted: bool) -> PartitionedBatches { /// Create a batch of (f64, utf8_low, utf8_low, i64) fn mixed_tuple_streams(sorted: bool) -> PartitionedBatches { - let mut gen = DataGenerator::new(); + let mut data_gen = DataGenerator::new(); // need to sort by the combined key, so combine them together - let mut tuples: Vec<_> = gen + let mut tuples: Vec<_> = data_gen .i64_values() .into_iter() - .zip(gen.utf8_low_cardinality_values()) - .zip(gen.utf8_low_cardinality_values()) - .zip(gen.i64_values()) + 
.zip(data_gen.utf8_low_cardinality_values()) + .zip(data_gen.utf8_low_cardinality_values()) + .zip(data_gen.i64_values()) .collect(); if sorted { @@ -463,15 +459,15 @@ fn mixed_tuple_streams(sorted: bool) -> PartitionedBatches { /// Create a batch of (f64, utf8_view_low, utf8_view_low, i64) fn mixed_tuple_with_utf8_view_streams(sorted: bool) -> PartitionedBatches { - let mut gen = DataGenerator::new(); + let mut data_gen = DataGenerator::new(); // need to sort by the combined key, so combine them together - let mut tuples: Vec<_> = gen + let mut tuples: Vec<_> = data_gen .i64_values() .into_iter() - .zip(gen.utf8_low_cardinality_values()) - .zip(gen.utf8_low_cardinality_values()) - .zip(gen.i64_values()) + .zip(data_gen.utf8_low_cardinality_values()) + .zip(data_gen.utf8_low_cardinality_values()) + .zip(data_gen.i64_values()) .collect(); if sorted { @@ -501,8 +497,8 @@ fn mixed_tuple_with_utf8_view_streams(sorted: bool) -> PartitionedBatches { /// Create a batch of (utf8_dict) fn dictionary_streams(sorted: bool) -> PartitionedBatches { - let mut gen = DataGenerator::new(); - let mut values = gen.utf8_low_cardinality_values(); + let mut data_gen = DataGenerator::new(); + let mut values = data_gen.utf8_low_cardinality_values(); if sorted { values.sort_unstable(); } @@ -516,12 +512,12 @@ fn dictionary_streams(sorted: bool) -> PartitionedBatches { /// Create a batch of (utf8_dict, utf8_dict, utf8_dict) fn dictionary_tuple_streams(sorted: bool) -> PartitionedBatches { - let mut gen = DataGenerator::new(); - let mut tuples: Vec<_> = gen + let mut data_gen = DataGenerator::new(); + let mut tuples: Vec<_> = data_gen .utf8_low_cardinality_values() .into_iter() - .zip(gen.utf8_low_cardinality_values()) - .zip(gen.utf8_low_cardinality_values()) + .zip(data_gen.utf8_low_cardinality_values()) + .zip(data_gen.utf8_low_cardinality_values()) .collect(); if sorted { @@ -547,13 +543,13 @@ fn dictionary_tuple_streams(sorted: bool) -> PartitionedBatches { /// Create a batch of 
(utf8_dict, utf8_dict, utf8_dict, i64) fn mixed_dictionary_tuple_streams(sorted: bool) -> PartitionedBatches { - let mut gen = DataGenerator::new(); - let mut tuples: Vec<_> = gen + let mut data_gen = DataGenerator::new(); + let mut tuples: Vec<_> = data_gen .utf8_low_cardinality_values() .into_iter() - .zip(gen.utf8_low_cardinality_values()) - .zip(gen.utf8_low_cardinality_values()) - .zip(gen.i64_values()) + .zip(data_gen.utf8_low_cardinality_values()) + .zip(data_gen.utf8_low_cardinality_values()) + .zip(data_gen.i64_values()) .collect(); if sorted { diff --git a/datafusion/core/benches/sort_limit_query_sql.rs b/datafusion/core/benches/sort_limit_query_sql.rs index e535a018161f1..54cd9a0bcd547 100644 --- a/datafusion/core/benches/sort_limit_query_sql.rs +++ b/datafusion/core/benches/sort_limit_query_sql.rs @@ -15,9 +15,7 @@ // specific language governing permissions and limitations // under the License. -#[macro_use] -extern crate criterion; -use criterion::Criterion; +use criterion::{Criterion, criterion_group, criterion_main}; use datafusion::datasource::file_format::csv::CsvFormat; use datafusion::datasource::listing::{ ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl, @@ -27,9 +25,6 @@ use datafusion::prelude::SessionConfig; use parking_lot::Mutex; use std::sync::Arc; -extern crate arrow; -extern crate datafusion; - use arrow::datatypes::{DataType, Field, Schema}; use datafusion::datasource::MemTable; @@ -37,6 +32,7 @@ use datafusion::execution::context::SessionContext; use tokio::runtime::Runtime; +#[expect(clippy::needless_pass_by_value)] fn query(ctx: Arc>, rt: &Runtime, sql: &str) { // execute the query let df = rt.block_on(ctx.lock().sql(sql)).unwrap(); @@ -97,8 +93,7 @@ fn create_context() -> Arc> { ctx_holder.lock().push(Arc::new(Mutex::new(ctx))) }); - let ctx = ctx_holder.lock().first().unwrap().clone(); - ctx + ctx_holder.lock().first().unwrap().clone() } fn criterion_benchmark(c: &mut Criterion) { diff --git 
a/datafusion/core/benches/spm.rs b/datafusion/core/benches/spm.rs index 8613525cb248d..afd384f7b170e 100644 --- a/datafusion/core/benches/spm.rs +++ b/datafusion/core/benches/spm.rs @@ -15,18 +15,18 @@ // specific language governing permissions and limitations // under the License. +use std::hint::black_box; use std::sync::Arc; use arrow::array::{ArrayRef, Int32Array, Int64Array, RecordBatch, StringArray}; use datafusion_execution::TaskContext; -use datafusion_physical_expr::expressions::col; use datafusion_physical_expr::PhysicalSortExpr; -use datafusion_physical_expr_common::sort_expr::LexOrdering; +use datafusion_physical_expr::expressions::col; use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; -use datafusion_physical_plan::{collect, ExecutionPlan}; +use datafusion_physical_plan::{ExecutionPlan, collect}; use criterion::async_executor::FuturesExecutor; -use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use criterion::{Criterion, criterion_group, criterion_main}; use datafusion_datasource::memory::MemorySourceConfig; fn generate_spm_for_round_robin_tie_breaker( @@ -66,11 +66,10 @@ fn generate_spm_for_round_robin_tie_breaker( RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap() }; - let rbs = (0..batch_count).map(|_| rb.clone()).collect::>(); - let partitiones = vec![rbs.clone(); partition_count]; - let schema = rb.schema(); - let sort = LexOrdering::new(vec![ + let rbs = std::iter::repeat_n(rb, batch_count).collect::>(); + let partitions = vec![rbs.clone(); partition_count]; + let sort = [ PhysicalSortExpr { expr: col("b", &schema).unwrap(), options: Default::default(), @@ -79,9 +78,10 @@ fn generate_spm_for_round_robin_tie_breaker( expr: col("c", &schema).unwrap(), options: Default::default(), }, - ]); + ] + .into(); - let exec = MemorySourceConfig::try_new_exec(&partitiones, schema, None).unwrap(); + let exec = MemorySourceConfig::try_new_exec(&partitions, schema, None).unwrap(); 
SortPreservingMergeExec::new(sort, exec) .with_round_robin_repartition(enable_round_robin_repartition) } diff --git a/datafusion/core/benches/sql_planner.rs b/datafusion/core/benches/sql_planner.rs index 6dc953f56b435..59502da987904 100644 --- a/datafusion/core/benches/sql_planner.rs +++ b/datafusion/core/benches/sql_planner.rs @@ -15,29 +15,26 @@ // specific language governing permissions and limitations // under the License. -extern crate arrow; -#[macro_use] -extern crate criterion; -extern crate datafusion; - mod data_utils; -use crate::criterion::Criterion; +use arrow::array::PrimitiveArray; use arrow::array::{ArrayRef, RecordBatch}; +use arrow::datatypes::ArrowNativeTypeOp; +use arrow::datatypes::ArrowPrimitiveType; use arrow::datatypes::{DataType, Field, Fields, Schema}; use criterion::Bencher; +use criterion::{Criterion, criterion_group, criterion_main}; use datafusion::datasource::MemTable; use datafusion::execution::context::SessionContext; -use datafusion_common::ScalarValue; +use datafusion_common::{ScalarValue, config::Dialect}; use datafusion_expr::col; -use itertools::Itertools; -use std::fs::File; -use std::io::{BufRead, BufReader}; +use rand_distr::num_traits::NumCast; +use std::hint::black_box; use std::path::PathBuf; use std::sync::Arc; +use test_utils::TableDef; use test_utils::tpcds::tpcds_schemas; use test_utils::tpch::tpch_schemas; -use test_utils::TableDef; use tokio::runtime::Runtime; const BENCHMARKS_PATH_1: &str = "../../benchmarks/"; @@ -46,12 +43,12 @@ const CLICKBENCH_DATA_PATH: &str = "data/hits_partitioned/"; /// Create a logical plan from the specified sql fn logical_plan(ctx: &SessionContext, rt: &Runtime, sql: &str) { - criterion::black_box(rt.block_on(ctx.sql(sql)).unwrap()); + black_box(rt.block_on(ctx.sql(sql)).unwrap()); } /// Create a physical ExecutionPlan (by way of logical plan) fn physical_plan(ctx: &SessionContext, rt: &Runtime, sql: &str) { - criterion::black_box(rt.block_on(async { + black_box(rt.block_on(async { 
ctx.sql(sql) .await .unwrap() @@ -76,6 +73,21 @@ fn create_table_provider(column_prefix: &str, num_columns: usize) -> Arc Arc { + let struct_fields = Fields::from(vec![ + Field::new("value", DataType::Int32, true), + Field::new("label", DataType::Utf8, true), + ]); + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, true), + Field::new("props", DataType::Struct(struct_fields), true), + ])); + MemTable::try_new(schema, vec![vec![]]) + .map(Arc::new) + .unwrap() +} + fn create_context() -> SessionContext { let ctx = SessionContext::new(); ctx.register_table("t1", create_table_provider("a", 200)) @@ -86,11 +98,16 @@ fn create_context() -> SessionContext { .unwrap(); ctx.register_table("t1000", create_table_provider("d", 1000)) .unwrap(); + ctx.register_table("struct_t1", create_struct_table_provider()) + .unwrap(); + ctx.register_table("struct_t2", create_struct_table_provider()) + .unwrap(); ctx } /// Register the table definitions as a MemTable with the context and return the /// context +#[expect(clippy::needless_pass_by_value)] fn register_defs(ctx: SessionContext, defs: Vec) -> SessionContext { defs.iter().for_each(|TableDef { name, schema }| { ctx.register_table( @@ -115,6 +132,11 @@ fn register_clickbench_hits_table(rt: &Runtime) -> SessionContext { let sql = format!("CREATE EXTERNAL TABLE hits STORED AS PARQUET LOCATION '{path}'"); + // ClickBench partitioned dataset was written by an ancient version of pyarrow that + // that wrote strings with the wrong logical type. To read it correctly, we must + // automatically convert binary to string. + rt.block_on(ctx.sql("SET datafusion.execution.parquet.binary_as_string = true;")) + .unwrap(); rt.block_on(ctx.sql(&sql)).unwrap(); let count = @@ -140,12 +162,15 @@ fn benchmark_with_param_values_many_columns( } // SELECT max(attr0), ..., max(attrN) FROM t1. 
let query = format!("SELECT {aggregates} FROM t1"); - let statement = ctx.state().sql_to_statement(&query, "Generic").unwrap(); + let statement = ctx + .state() + .sql_to_statement(&query, &Dialect::Generic) + .unwrap(); let plan = rt.block_on(async { ctx.state().statement_to_plan(statement).await.unwrap() }); b.iter(|| { let plan = plan.clone(); - criterion::black_box(plan.with_param_values(vec![ScalarValue::from(1)]).unwrap()); + black_box(plan.with_param_values(vec![ScalarValue::from(1)]).unwrap()); }); } @@ -154,18 +179,30 @@ fn benchmark_with_param_values_many_columns( /// 0,100...9900 /// 0,200...19800 /// 0,300...29700 -fn register_union_order_table(ctx: &SessionContext, num_columns: usize, num_rows: usize) { - // ("c0", [0, 0, ...]) - // ("c1": [100, 200, ...]) - // etc - let iter = (0..num_columns).map(|i| i as u64).map(|i| { - let array: ArrayRef = Arc::new(arrow::array::UInt64Array::from_iter_values( - (0..num_rows) - .map(|j| j as u64 * 100 + i) - .collect::>(), - )); +fn register_union_order_table_generic( + ctx: &SessionContext, + num_columns: usize, + num_rows: usize, +) where + T: ArrowPrimitiveType, + T::Native: ArrowNativeTypeOp + NumCast, +{ + let iter = (0..num_columns).map(|i| { + let array_data: Vec = (0..num_rows) + .map(|j| { + let value = (j as u64) * 100 + (i as u64); + ::from(value).unwrap_or_else(|| { + panic!("Failed to cast numeric value to Native type") + }) + }) + .collect(); + + // Use PrimitiveArray which is generic over the ArrowPrimitiveType T + let array: ArrayRef = Arc::new(PrimitiveArray::::from_iter_values(array_data)); + (format!("c{i}"), array) }); + let batch = RecordBatch::try_from_iter(iter).unwrap(); let schema = batch.schema(); let partitions = vec![vec![batch]]; @@ -182,14 +219,13 @@ fn register_union_order_table(ctx: &SessionContext, num_columns: usize, num_rows ctx.register_table("t", Arc::new(table)).unwrap(); } - /// return a query like /// ```sql -/// select c1, null as c2, ... 
null as cn from t ORDER BY c1 +/// select c1, 2 as c2, ... n as cn from t ORDER BY c1 /// UNION ALL -/// select null as c1, c2, ... null as cn from t ORDER BY c2 +/// select 1 as c1, c2, ... n as cn from t ORDER BY c2 /// ... -/// select null as c1, null as c2, ... cn from t ORDER BY cn +/// select 1 as c1, 2 as c2, ... cn from t ORDER BY cn /// ORDER BY c1, c2 ... CN /// ``` fn union_orderby_query(n: usize) -> String { @@ -203,7 +239,7 @@ fn union_orderby_query(n: usize) -> String { if i == j { format!("c{j}") } else { - format!("null as c{j}") + format!("{j} as c{j}") } }) .collect::>() @@ -225,8 +261,10 @@ fn criterion_benchmark(c: &mut Criterion) { if !PathBuf::from(format!("{BENCHMARKS_PATH_1}{CLICKBENCH_DATA_PATH}")).exists() && !PathBuf::from(format!("{BENCHMARKS_PATH_2}{CLICKBENCH_DATA_PATH}")).exists() { - panic!("benchmarks/data/hits_partitioned/ could not be loaded. Please run \ - 'benchmarks/bench.sh data clickbench_partitioned' prior to running this benchmark") + panic!( + "benchmarks/data/hits_partitioned/ could not be loaded. 
Please run \ + 'benchmarks/bench.sh data clickbench_partitioned' prior to running this benchmark" + ) } let ctx = create_context(); @@ -301,6 +339,34 @@ fn criterion_benchmark(c: &mut Criterion) { }); }); + // It was observed in production that queries with window functions sometimes partition over more than 30 columns + for partitioning_columns in [4, 7, 8, 12, 30] { + c.bench_function( + &format!( + "physical_window_function_partition_by_{partitioning_columns}_on_values" + ), + |b| { + let source = format!( + "SELECT 1 AS n{}", + (0..partitioning_columns) + .map(|i| format!(", {i} AS c{i}")) + .collect::() + ); + let window = format!( + "SUM(n) OVER (PARTITION BY {}) AS sum_n", + (0..partitioning_columns) + .map(|i| format!("c{i}")) + .collect::>() + .join(", ") + ); + let query = format!("SELECT {window} FROM ({source})"); + b.iter(|| { + physical_plan(&ctx, &rt, &query); + }); + }, + ); + } + // Benchmark for Physical Planning Joins c.bench_function("physical_join_consider_sort", |b| { b.iter(|| { @@ -372,16 +438,70 @@ fn criterion_benchmark(c: &mut Criterion) { }); }); - // -- Sorted Queries -- - register_union_order_table(&ctx, 100, 1000); - - // this query has many expressions in its sort order so stresses - // order equivalence validation - c.bench_function("physical_sorted_union_orderby", |b| { - // SELECT ... UNION ALL ... 
- let query = union_orderby_query(20); - b.iter(|| physical_plan(&ctx, &rt, &query)) + let struct_agg_sort_query = "SELECT \ + struct_t1.props['label'], \ + SUM(struct_t1.props['value']), \ + MAX(struct_t2.props['value']), \ + COUNT(*) \ + FROM struct_t1 \ + JOIN struct_t2 ON struct_t1.id = struct_t2.id \ + WHERE struct_t1.props['value'] > 50 \ + GROUP BY struct_t1.props['label'] \ + ORDER BY SUM(struct_t1.props['value']) DESC"; + + // -- Struct column benchmarks -- + c.bench_function("logical_plan_struct_join_agg_sort", |b| { + b.iter(|| logical_plan(&ctx, &rt, struct_agg_sort_query)) }); + c.bench_function("physical_plan_struct_join_agg_sort", |b| { + b.iter(|| physical_plan(&ctx, &rt, struct_agg_sort_query)) + }); + + // -- Sorted Queries -- + // 100, 200 && 300 is taking too long - https://github.com/apache/datafusion/issues/18366 + // Logical Plan for datatype Int64 and UInt64 differs, UInt64 Logical Plan's Union are wrapped + // up in Projection, and EliminateNestedUnion OptimezerRule is not applied leading to significantly + // longer execution time. + // https://github.com/apache/datafusion/issues/17261 + + for column_count in [10, 50 /* 100, 200, 300 */] { + register_union_order_table_generic::( + &ctx, + column_count, + 1000, + ); + + // this query has many expressions in its sort order so stresses + // order equivalence validation + c.bench_function( + &format!("physical_sorted_union_order_by_{column_count}_int64"), + |b| { + // SELECT ... UNION ALL ... + let query = union_orderby_query(column_count); + b.iter(|| physical_plan(&ctx, &rt, &query)) + }, + ); + + let _ = ctx.deregister_table("t"); + } + + for column_count in [10, 50 /* 100, 200, 300 */] { + register_union_order_table_generic::( + &ctx, + column_count, + 1000, + ); + c.bench_function( + &format!("physical_sorted_union_order_by_{column_count}_uint64"), + |b| { + // SELECT ... UNION ALL ... 
+ let query = union_orderby_query(column_count); + b.iter(|| physical_plan(&ctx, &rt, &query)) + }, + ); + + let _ = ctx.deregister_table("t"); + } // --- TPC-H --- @@ -466,17 +586,20 @@ fn criterion_benchmark(c: &mut Criterion) { // }); // -- clickbench -- - - let queries_file = - File::open(format!("{benchmarks_path}queries/clickbench/queries.sql")).unwrap(); - let extended_file = - File::open(format!("{benchmarks_path}queries/clickbench/extended.sql")).unwrap(); - - let clickbench_queries: Vec = BufReader::new(queries_file) - .lines() - .chain(BufReader::new(extended_file).lines()) - .map(|l| l.expect("Could not parse line")) - .collect_vec(); + let clickbench_queries = (0..=42) + .map(|q| { + std::fs::read_to_string(format!( + "{benchmarks_path}queries/clickbench/queries/q{q}.sql" + )) + .unwrap() + }) + .chain((0..=7).map(|q| { + std::fs::read_to_string(format!( + "{benchmarks_path}queries/clickbench/extended/q{q}.sql" + )) + .unwrap() + })) + .collect::>(); let clickbench_ctx = register_clickbench_hits_table(&rt); diff --git a/datafusion/core/benches/sql_planner_extended.rs b/datafusion/core/benches/sql_planner_extended.rs new file mode 100644 index 0000000000000..d4955313c79c3 --- /dev/null +++ b/datafusion/core/benches/sql_planner_extended.rs @@ -0,0 +1,466 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, RecordBatch}; +use arrow_schema::DataType; +use arrow_schema::TimeUnit::Nanosecond; +use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; +use datafusion::prelude::{DataFrame, SessionContext}; +use datafusion_catalog::MemTable; +use datafusion_common::ScalarValue; +use datafusion_expr::Expr::Literal; +use datafusion_expr::{cast, col, lit, not, try_cast, when}; +use datafusion_functions::expr_fn::{ + btrim, length, regexp_like, regexp_replace, to_timestamp, upper, +}; +use std::fmt::Write; +use std::hint::black_box; +use std::ops::Rem; +use std::sync::Arc; +use tokio::runtime::Runtime; + +// This benchmark suite is designed to test the performance of +// logical planning with a large plan containing unions, many columns +// with a variety of operations in it. +// +// Since it is (currently) very slow to execute it has been separated +// out from the sql_planner benchmark suite to this file. +// +// See https://github.com/apache/datafusion/issues/17261 for details. 
+ +/// Registers a table like this: +/// c0,c1,c2...,c99 +/// "0","100"..."9900" +/// "0","200"..."19800" +/// "0","300"..."29700" +fn register_string_table(ctx: &SessionContext, num_columns: usize, num_rows: usize) { + // ("c0", ["0", "0", ...]) + // ("c1": ["100", "200", ...]) + // etc + let iter = (0..num_columns).map(|i| i as u64).map(|i| { + let array: ArrayRef = Arc::new(arrow::array::StringViewArray::from_iter_values( + (0..num_rows) + .map(|j| format!("c{}", j as u64 * 100 + i)) + .collect::>(), + )); + (format!("c{i}"), array) + }); + let batch = RecordBatch::try_from_iter(iter).unwrap(); + let schema = batch.schema(); + let partitions = vec![vec![batch]]; + + // create the table + let table = MemTable::try_new(schema, partitions).unwrap(); + + ctx.register_table("t", Arc::new(table)).unwrap(); +} + +/// Build a dataframe for testing logical plan optimization +fn build_test_data_frame(ctx: &SessionContext, rt: &Runtime) -> DataFrame { + register_string_table(ctx, 100, 1000); + + rt.block_on(async { + let mut df = ctx.table("t").await.unwrap(); + // add some columns in + for i in 100..150 { + df = df + .with_column(&format!("c{i}"), Literal(ScalarValue::Utf8(None), None)) + .unwrap(); + } + // add in some columns with string encoded timestamps + for i in 150..175 { + df = df + .with_column( + &format!("c{i}"), + Literal(ScalarValue::Utf8(Some("2025-08-21 09:43:17".into())), None), + ) + .unwrap(); + } + // do a bunch of ops on the columns + for i in 0..175 { + // trim the columns + df = df + .with_column(&format!("c{i}"), btrim(vec![col(format!("c{i}"))])) + .unwrap(); + } + + for i in 0..175 { + let c_name = format!("c{i}"); + let c = col(&c_name); + + // random ops + if i % 5 == 0 && i < 150 { + // the actual ops here are largely unimportant as they are just a sample + // of ops that could occur on a dataframe + df = df + .with_column(&c_name, cast(c.clone(), DataType::Utf8)) + .unwrap() + .with_column( + &c_name, + when( + cast(c.clone(), 
DataType::Int32).gt(lit(135)), + cast( + cast(c.clone(), DataType::Int32) - lit(i + 3), + DataType::Utf8, + ), + ) + .otherwise(c.clone()) + .unwrap(), + ) + .unwrap() + .with_column( + &c_name, + when( + c.clone().is_not_null().and( + cast(c.clone(), DataType::Int32) + .between(lit(120), lit(130)), + ), + Literal(ScalarValue::Utf8(None), None), + ) + .otherwise( + when( + c.clone().is_not_null().and(regexp_like( + cast(c.clone(), DataType::Utf8View), + lit("[0-9]*"), + None, + )), + upper(c.clone()), + ) + .otherwise(c.clone()) + .unwrap(), + ) + .unwrap(), + ) + .unwrap() + .with_column( + &c_name, + when( + c.clone().is_not_null().and( + cast(c.clone(), DataType::Int32) + .between(lit(90), lit(100)), + ), + cast(c.clone(), DataType::Utf8View), + ) + .otherwise(Literal(ScalarValue::Date32(None), None)) + .unwrap(), + ) + .unwrap() + .with_column( + &c_name, + when( + c.clone().is_not_null().and( + cast(c.clone(), DataType::Int32).rem(lit(10)).gt(lit(7)), + ), + regexp_replace( + cast(c.clone(), DataType::Utf8View), + lit("1"), + lit("a"), + None, + ), + ) + .otherwise(Literal(ScalarValue::Date32(None), None)) + .unwrap(), + ) + .unwrap() + } + if i >= 150 { + df = df + .with_column( + &c_name, + try_cast( + to_timestamp(vec![c.clone(), lit("%Y-%m-%d %H:%M:%S")]), + DataType::Timestamp(Nanosecond, Some("UTC".into())), + ), + ) + .unwrap() + .with_column(&c_name, try_cast(c.clone(), DataType::Date32)) + .unwrap() + } + + // add in a few unions + if i % 30 == 0 { + let df1 = df + .clone() + .filter(length(c.clone()).gt(lit(2))) + .unwrap() + .with_column(&format!("c{i}_filtered"), lit(true)) + .unwrap(); + let df2 = df + .filter(not(length(c.clone()).gt(lit(2)))) + .unwrap() + .with_column(&format!("c{i}_filtered"), lit(false)) + .unwrap(); + + df = df1.union_by_name(df2).unwrap() + } + } + + df + }) +} + +/// Build a CASE-heavy dataframe over a non-inner join to stress +/// planner-time filter pushdown and nullability/type inference. 
+fn build_case_heavy_left_join_df(ctx: &SessionContext, rt: &Runtime) -> DataFrame { + register_string_table(ctx, 100, 1000); + let query = build_case_heavy_left_join_query(30, 1); + rt.block_on(async { ctx.sql(&query).await.unwrap() }) +} + +fn build_case_heavy_left_join_query(predicate_count: usize, case_depth: usize) -> String { + let mut query = String::from( + "SELECT l.c0, r.c0 AS rc0 FROM t l LEFT JOIN t r ON l.c0 = r.c0 WHERE ", + ); + + if predicate_count == 0 { + query.push_str("TRUE"); + return query; + } + + // Keep this deterministic so comparisons between profiles are stable. + for i in 0..predicate_count { + if i > 0 { + query.push_str(" AND "); + } + + let mut expr = format!("length(l.c{})", i % 20); + for depth in 0..case_depth { + let left_col = (i + depth + 1) % 20; + let right_col = (i + depth + 2) % 20; + expr = format!( + "CASE WHEN l.c{left_col} IS NOT NULL THEN {expr} ELSE length(r.c{right_col}) END" + ); + } + + let _ = write!(&mut query, "{expr} > 2"); + } + + query +} + +fn build_case_heavy_left_join_df_with_push_down_filter( + rt: &Runtime, + predicate_count: usize, + case_depth: usize, + push_down_filter_enabled: bool, +) -> DataFrame { + let ctx = SessionContext::new(); + register_string_table(&ctx, 100, 1000); + if !push_down_filter_enabled { + let removed = ctx.remove_optimizer_rule("push_down_filter"); + assert!( + removed, + "push_down_filter rule should be present in the default optimizer" + ); + } + + let query = build_case_heavy_left_join_query(predicate_count, case_depth); + rt.block_on(async { ctx.sql(&query).await.unwrap() }) +} + +fn build_non_case_left_join_query( + predicate_count: usize, + nesting_depth: usize, +) -> String { + let mut query = String::from( + "SELECT l.c0, r.c0 AS rc0 FROM t l LEFT JOIN t r ON l.c0 = r.c0 WHERE ", + ); + + if predicate_count == 0 { + query.push_str("TRUE"); + return query; + } + + // Keep this deterministic so comparisons between profiles are stable. 
+ for i in 0..predicate_count { + if i > 0 { + query.push_str(" AND "); + } + + let left_col = i % 20; + let mut expr = format!("l.c{left_col}"); + for depth in 0..nesting_depth { + let right_col = (i + depth + 1) % 20; + expr = format!("coalesce({expr}, r.c{right_col})"); + } + + let _ = write!(&mut query, "length({expr}) > 2"); + } + + query +} + +fn build_non_case_left_join_df_with_push_down_filter( + rt: &Runtime, + predicate_count: usize, + nesting_depth: usize, + push_down_filter_enabled: bool, +) -> DataFrame { + let ctx = SessionContext::new(); + register_string_table(&ctx, 100, 1000); + if !push_down_filter_enabled { + let removed = ctx.remove_optimizer_rule("push_down_filter"); + assert!( + removed, + "push_down_filter rule should be present in the default optimizer" + ); + } + + let query = build_non_case_left_join_query(predicate_count, nesting_depth); + rt.block_on(async { ctx.sql(&query).await.unwrap() }) +} + +fn criterion_benchmark(c: &mut Criterion) { + let baseline_ctx = SessionContext::new(); + let case_heavy_ctx = SessionContext::new(); + let rt = Runtime::new().unwrap(); + + // validate logical plan optimize performance + // https://github.com/apache/datafusion/issues/17261 + + let df = build_test_data_frame(&baseline_ctx, &rt); + let case_heavy_left_join_df = build_case_heavy_left_join_df(&case_heavy_ctx, &rt); + + c.bench_function("logical_plan_optimize", |b| { + b.iter(|| { + let df_clone = df.clone(); + black_box(rt.block_on(async { df_clone.into_optimized_plan().unwrap() })); + }) + }); + + c.bench_function("logical_plan_optimize_hotspot_case_heavy_left_join", |b| { + b.iter(|| { + let df_clone = case_heavy_left_join_df.clone(); + black_box(rt.block_on(async { df_clone.into_optimized_plan().unwrap() })); + }) + }); + + let predicate_sweep = [10, 20, 30, 40, 60]; + let case_depth_sweep = [1, 2, 3]; + + let mut hotspot_group = + c.benchmark_group("push_down_filter_hotspot_case_heavy_left_join_ab"); + for case_depth in case_depth_sweep { + 
for predicate_count in predicate_sweep { + let with_push_down_filter = + build_case_heavy_left_join_df_with_push_down_filter( + &rt, + predicate_count, + case_depth, + true, + ); + let without_push_down_filter = + build_case_heavy_left_join_df_with_push_down_filter( + &rt, + predicate_count, + case_depth, + false, + ); + + let input_label = + format!("predicates={predicate_count},case_depth={case_depth}"); + // A/B interpretation: + // - with_push_down_filter: default optimizer path (rule enabled) + // - without_push_down_filter: control path with the rule removed + // Compare both IDs at the same sweep point to isolate rule impact. + hotspot_group.bench_with_input( + BenchmarkId::new("with_push_down_filter", &input_label), + &with_push_down_filter, + |b, df| { + b.iter(|| { + let df_clone = df.clone(); + black_box( + rt.block_on(async { + df_clone.into_optimized_plan().unwrap() + }), + ); + }) + }, + ); + hotspot_group.bench_with_input( + BenchmarkId::new("without_push_down_filter", &input_label), + &without_push_down_filter, + |b, df| { + b.iter(|| { + let df_clone = df.clone(); + black_box( + rt.block_on(async { + df_clone.into_optimized_plan().unwrap() + }), + ); + }) + }, + ); + } + } + hotspot_group.finish(); + + let mut control_group = + c.benchmark_group("push_down_filter_control_non_case_left_join_ab"); + for nesting_depth in case_depth_sweep { + for predicate_count in predicate_sweep { + let with_push_down_filter = build_non_case_left_join_df_with_push_down_filter( + &rt, + predicate_count, + nesting_depth, + true, + ); + let without_push_down_filter = + build_non_case_left_join_df_with_push_down_filter( + &rt, + predicate_count, + nesting_depth, + false, + ); + + let input_label = + format!("predicates={predicate_count},nesting_depth={nesting_depth}"); + control_group.bench_with_input( + BenchmarkId::new("with_push_down_filter", &input_label), + &with_push_down_filter, + |b, df| { + b.iter(|| { + let df_clone = df.clone(); + black_box( + 
rt.block_on(async { + df_clone.into_optimized_plan().unwrap() + }), + ); + }) + }, + ); + control_group.bench_with_input( + BenchmarkId::new("without_push_down_filter", &input_label), + &without_push_down_filter, + |b, df| { + b.iter(|| { + let df_clone = df.clone(); + black_box( + rt.block_on(async { + df_clone.into_optimized_plan().unwrap() + }), + ); + }) + }, + ); + } + } + control_group.finish(); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/core/benches/sql_query_with_io.rs b/datafusion/core/benches/sql_query_with_io.rs index 58797dfed6b67..fc8caf31acd11 100644 --- a/datafusion/core/benches/sql_query_with_io.rs +++ b/datafusion/core/benches/sql_query_with_io.rs @@ -20,7 +20,7 @@ use std::{fmt::Write, sync::Arc, time::Duration}; use arrow::array::{Int64Builder, RecordBatch, UInt64Builder}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use bytes::Bytes; -use criterion::{criterion_group, criterion_main, Criterion, SamplingMode}; +use criterion::{Criterion, SamplingMode, criterion_group, criterion_main}; use datafusion::{ datasource::{ file_format::parquet::ParquetFormat, @@ -31,13 +31,13 @@ use datafusion::{ use datafusion_execution::runtime_env::RuntimeEnv; use itertools::Itertools; use object_store::{ + ObjectStore, ObjectStoreExt, memory::InMemory, path::Path, throttle::{ThrottleConfig, ThrottledStore}, - ObjectStore, }; use parquet::arrow::ArrowWriter; -use rand::{rngs::StdRng, Rng, SeedableRng}; +use rand::{Rng, SeedableRng, rngs::StdRng}; use tokio::runtime::Runtime; use url::Url; diff --git a/datafusion/core/benches/struct_query_sql.rs b/datafusion/core/benches/struct_query_sql.rs index f9cc43d1ea2c5..96434fc379ea6 100644 --- a/datafusion/core/benches/struct_query_sql.rs +++ b/datafusion/core/benches/struct_query_sql.rs @@ -20,17 +20,18 @@ use arrow::{ datatypes::{DataType, Field, Schema}, record_batch::RecordBatch, }; -use criterion::{criterion_group, criterion_main, Criterion}; 
+use criterion::{Criterion, criterion_group, criterion_main}; use datafusion::prelude::SessionContext; use datafusion::{datasource::MemTable, error::Result}; use futures::executor::block_on; +use std::hint::black_box; use std::sync::Arc; use tokio::runtime::Runtime; async fn query(ctx: &SessionContext, rt: &Runtime, sql: &str) { // execute the query let df = rt.block_on(ctx.sql(sql)).unwrap(); - criterion::black_box(rt.block_on(df.collect()).unwrap()); + black_box(rt.block_on(df.collect()).unwrap()); } fn create_context(array_len: usize, batch_size: usize) -> Result { diff --git a/datafusion/core/benches/topk_aggregate.rs b/datafusion/core/benches/topk_aggregate.rs index cf3c7fa2e26fe..f71cf1087be7d 100644 --- a/datafusion/core/benches/topk_aggregate.rs +++ b/datafusion/core/benches/topk_aggregate.rs @@ -16,25 +16,71 @@ // under the License. mod data_utils; + +use arrow::array::Int64Builder; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use arrow::record_batch::RecordBatch; use arrow::util::pretty::pretty_format_batches; -use criterion::{criterion_group, criterion_main, Criterion}; +use criterion::{Criterion, criterion_group, criterion_main}; use data_utils::make_data; -use datafusion::physical_plan::{collect, displayable, ExecutionPlan}; +use datafusion::physical_plan::{collect, displayable}; use datafusion::prelude::SessionContext; use datafusion::{datasource::MemTable, error::Result}; use datafusion_execution::config::SessionConfig; -use datafusion_execution::TaskContext; +use rand::SeedableRng; +use rand::seq::SliceRandom; +use std::hint::black_box; use std::sync::Arc; use tokio::runtime::Runtime; +const LIMIT: usize = 10; + +/// Create deterministic data for DISTINCT benchmarks with predictable trace_ids +/// This ensures consistent results across benchmark runs +fn make_distinct_data( + partition_cnt: i32, + sample_cnt: i32, +) -> Result<(Arc, Vec>)> { + let mut rng = rand::rngs::SmallRng::from_seed([42; 32]); + let total_samples = 
partition_cnt as usize * sample_cnt as usize; + let mut ids = Vec::new(); + for i in 0..total_samples { + ids.push(i as i64); + } + ids.shuffle(&mut rng); + + let mut global_idx = 0; + let schema = test_distinct_schema(); + let mut partitions = vec![]; + for _ in 0..partition_cnt { + let mut id_builder = Int64Builder::new(); + + for _ in 0..sample_cnt { + let id = ids[global_idx]; + id_builder.append_value(id); + global_idx += 1; + } + + let id_col = Arc::new(id_builder.finish()); + let batch = RecordBatch::try_new(schema.clone(), vec![id_col])?; + partitions.push(vec![batch]); + } + + Ok((schema, partitions)) +} + +/// Returns a Schema for distinct benchmarks with i64 trace_id +fn test_distinct_schema() -> SchemaRef { + Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)])) +} + async fn create_context( - limit: usize, partition_cnt: i32, sample_cnt: i32, asc: bool, use_topk: bool, use_view: bool, -) -> Result<(Arc, Arc)> { +) -> Result { let (schema, parts) = make_data(partition_cnt, sample_cnt, asc, use_view).unwrap(); let mem_table = Arc::new(MemTable::try_new(schema, parts).unwrap()); @@ -44,55 +90,196 @@ async fn create_context( opts.optimizer.enable_topk_aggregation = use_topk; let ctx = SessionContext::new_with_config(cfg); let _ = ctx.register_table("traces", mem_table)?; - let sql = format!("select trace_id, max(timestamp_ms) from traces group by trace_id order by max(timestamp_ms) desc limit {limit};"); + + Ok(ctx) +} + +async fn create_context_distinct( + partition_cnt: i32, + sample_cnt: i32, + use_topk: bool, +) -> Result { + // Use deterministic data generation for DISTINCT queries to ensure consistent results + let (schema, parts) = make_distinct_data(partition_cnt, sample_cnt).unwrap(); + let mem_table = Arc::new(MemTable::try_new(schema, parts).unwrap()); + + // Create the DataFrame + let mut cfg = SessionConfig::new(); + let opts = cfg.options_mut(); + opts.optimizer.enable_topk_aggregation = use_topk; + let ctx = 
SessionContext::new_with_config(cfg); + let _ = ctx.register_table("traces", mem_table)?; + + Ok(ctx) +} + +fn run(rt: &Runtime, ctx: SessionContext, limit: usize, use_topk: bool, asc: bool) { + black_box(rt.block_on(async { aggregate(ctx, limit, use_topk, asc).await })).unwrap(); +} + +fn run_string(rt: &Runtime, ctx: SessionContext, limit: usize, use_topk: bool) { + black_box(rt.block_on(async { aggregate_string(ctx, limit, use_topk).await })) + .unwrap(); +} + +fn run_distinct( + rt: &Runtime, + ctx: SessionContext, + limit: usize, + use_topk: bool, + asc: bool, +) { + black_box(rt.block_on(async { aggregate_distinct(ctx, limit, use_topk, asc).await })) + .unwrap(); +} + +async fn aggregate( + ctx: SessionContext, + limit: usize, + use_topk: bool, + asc: bool, +) -> Result<()> { + let sql = format!( + "select max(timestamp_ms) from traces group by trace_id order by max(timestamp_ms) desc limit {limit};" + ); let df = ctx.sql(sql.as_str()).await?; - let physical_plan = df.create_physical_plan().await?; - let actual_phys_plan = displayable(physical_plan.as_ref()).indent(true).to_string(); + let plan = df.create_physical_plan().await?; + let actual_phys_plan = displayable(plan.as_ref()).indent(true).to_string(); assert_eq!( actual_phys_plan.contains(&format!("lim=[{limit}]")), use_topk ); - Ok((physical_plan, ctx.task_ctx())) + let batches = collect(plan, ctx.task_ctx()).await?; + assert_eq!(batches.len(), 1); + let batch = batches.first().unwrap(); + assert_eq!(batch.num_rows(), LIMIT); + + let actual = format!("{}", pretty_format_batches(&batches)?).to_lowercase(); + let expected_asc = r#" ++--------------------------+ +| max(traces.timestamp_ms) | ++--------------------------+ +| 16909009999999 | +| 16909009999998 | +| 16909009999997 | +| 16909009999996 | +| 16909009999995 | +| 16909009999994 | +| 16909009999993 | +| 16909009999992 | +| 16909009999991 | +| 16909009999990 | ++--------------------------+ + "# + .trim(); + if asc { + assert_eq!(actual.trim(), 
expected_asc); + } + + Ok(()) } -fn run(rt: &Runtime, plan: Arc, ctx: Arc, asc: bool) { - criterion::black_box( - rt.block_on(async { aggregate(plan.clone(), ctx.clone(), asc).await }), - ) - .unwrap(); +/// Benchmark for string aggregate functions with topk optimization. +/// This tests grouping by a numeric column (timestamp_ms) and aggregating +/// a string column (trace_id) with Utf8 or Utf8View data types. +async fn aggregate_string( + ctx: SessionContext, + limit: usize, + use_topk: bool, +) -> Result<()> { + let sql = format!( + "select max(trace_id) from traces group by timestamp_ms order by max(trace_id) desc limit {limit};" + ); + let df = ctx.sql(sql.as_str()).await?; + let plan = df.create_physical_plan().await?; + let actual_phys_plan = displayable(plan.as_ref()).indent(true).to_string(); + assert_eq!( + actual_phys_plan.contains(&format!("lim=[{limit}]")), + use_topk + ); + + let batches = collect(plan, ctx.task_ctx()).await?; + assert_eq!(batches.len(), 1); + let batch = batches.first().unwrap(); + assert_eq!(batch.num_rows(), LIMIT); + + Ok(()) } -async fn aggregate( - plan: Arc, - ctx: Arc, +async fn aggregate_distinct( + ctx: SessionContext, + limit: usize, + use_topk: bool, asc: bool, ) -> Result<()> { - let batches = collect(plan, ctx).await?; + let order_direction = if asc { "asc" } else { "desc" }; + let sql = format!( + "select id from traces group by id order by id {order_direction} limit {limit};" + ); + let df = ctx.sql(sql.as_str()).await?; + let plan = df.create_physical_plan().await?; + let actual_phys_plan = displayable(plan.as_ref()).indent(true).to_string(); + assert_eq!( + actual_phys_plan.contains(&format!("lim=[{limit}]")), + use_topk + ); + let batches = collect(plan, ctx.task_ctx()).await?; assert_eq!(batches.len(), 1); let batch = batches.first().unwrap(); - assert_eq!(batch.num_rows(), 10); + assert_eq!(batch.num_rows(), LIMIT); let actual = format!("{}", pretty_format_batches(&batches)?).to_lowercase(); + let expected_asc = 
r#" -+----------------------------------+--------------------------+ -| trace_id | max(traces.timestamp_ms) | -+----------------------------------+--------------------------+ -| 5868861a23ed31355efc5200eb80fe74 | 16909009999999 | -| 4040e64656804c3d77320d7a0e7eb1f0 | 16909009999998 | -| 02801bbe533190a9f8713d75222f445d | 16909009999997 | -| 9e31b3b5a620de32b68fefa5aeea57f1 | 16909009999996 | -| 2d88a860e9bd1cfaa632d8e7caeaa934 | 16909009999995 | -| a47edcef8364ab6f191dd9103e51c171 | 16909009999994 | -| 36a3fa2ccfbf8e00337f0b1254384db6 | 16909009999993 | -| 0756be84f57369012e10de18b57d8a2f | 16909009999992 | -| d4d6bf9845fa5897710e3a8db81d5907 | 16909009999991 | -| 3c2cc1abe728a66b61e14880b53482a0 | 16909009999990 | -+----------------------------------+--------------------------+ - "# ++----+ +| id | ++----+ +| 0 | +| 1 | +| 2 | +| 3 | +| 4 | +| 5 | +| 6 | +| 7 | +| 8 | +| 9 | ++----+ +"# .trim(); + + let expected_desc = r#" ++---------+ +| id | ++---------+ +| 9999999 | +| 9999998 | +| 9999997 | +| 9999996 | +| 9999995 | +| 9999994 | +| 9999993 | +| 9999992 | +| 9999991 | +| 9999990 | ++---------+ +"# + .trim(); + + // Verify exact results match expected values if asc { - assert_eq!(actual.trim(), expected_asc); + assert_eq!( + actual.trim(), + expected_asc, + "Ascending DISTINCT results do not match expected values" + ); + } else { + assert_eq!( + actual.trim(), + expected_desc, + "Descending DISTINCT results do not match expected values" + ); } Ok(()) @@ -100,110 +287,154 @@ async fn aggregate( fn criterion_benchmark(c: &mut Criterion) { let rt = Runtime::new().unwrap(); - let limit = 10; + let limit = LIMIT; let partitions = 10; let samples = 1_000_000; + let ctx = rt + .block_on(create_context(partitions, samples, false, false, false)) + .unwrap(); c.bench_function( format!("aggregate {} time-series rows", partitions * samples).as_str(), - |b| { - b.iter(|| { - let real = rt.block_on(async { - create_context(limit, partitions, samples, false, false, false) - 
.await - .unwrap() - }); - run(&rt, real.0.clone(), real.1.clone(), false) - }) - }, + |b| b.iter(|| run(&rt, ctx.clone(), limit, false, false)), ); + let ctx = rt + .block_on(create_context(partitions, samples, true, false, false)) + .unwrap(); c.bench_function( format!("aggregate {} worst-case rows", partitions * samples).as_str(), - |b| { - b.iter(|| { - let asc = rt.block_on(async { - create_context(limit, partitions, samples, true, false, false) - .await - .unwrap() - }); - run(&rt, asc.0.clone(), asc.1.clone(), true) - }) - }, + |b| b.iter(|| run(&rt, ctx.clone(), limit, false, true)), ); + let ctx = rt + .block_on(create_context(partitions, samples, false, true, false)) + .unwrap(); c.bench_function( format!( "top k={limit} aggregate {} time-series rows", partitions * samples ) .as_str(), - |b| { - b.iter(|| { - let topk_real = rt.block_on(async { - create_context(limit, partitions, samples, false, true, false) - .await - .unwrap() - }); - run(&rt, topk_real.0.clone(), topk_real.1.clone(), false) - }) - }, + |b| b.iter(|| run(&rt, ctx.clone(), limit, true, false)), ); + let ctx = rt + .block_on(create_context(partitions, samples, true, true, false)) + .unwrap(); c.bench_function( format!( "top k={limit} aggregate {} worst-case rows", partitions * samples ) .as_str(), - |b| { - b.iter(|| { - let topk_asc = rt.block_on(async { - create_context(limit, partitions, samples, true, true, false) - .await - .unwrap() - }); - run(&rt, topk_asc.0.clone(), topk_asc.1.clone(), true) - }) - }, + |b| b.iter(|| run(&rt, ctx.clone(), limit, true, true)), ); // Utf8View schema,time-series rows + let ctx = rt + .block_on(create_context(partitions, samples, false, true, true)) + .unwrap(); c.bench_function( format!( "top k={limit} aggregate {} time-series rows [Utf8View]", partitions * samples ) .as_str(), - |b| { - b.iter(|| { - let topk_real = rt.block_on(async { - create_context(limit, partitions, samples, false, true, true) - .await - .unwrap() - }); - run(&rt, 
topk_real.0.clone(), topk_real.1.clone(), false) - }) - }, + |b| b.iter(|| run(&rt, ctx.clone(), limit, true, false)), ); // Utf8View schema,worst-case rows + let ctx = rt + .block_on(create_context(partitions, samples, true, true, true)) + .unwrap(); c.bench_function( format!( "top k={limit} aggregate {} worst-case rows [Utf8View]", partitions * samples ) .as_str(), - |b| { - b.iter(|| { - let topk_asc = rt.block_on(async { - create_context(limit, partitions, samples, true, true, true) - .await - .unwrap() - }); - run(&rt, topk_asc.0.clone(), topk_asc.1.clone(), true) - }) - }, + |b| b.iter(|| run(&rt, ctx.clone(), limit, true, true)), + ); + + // String aggregate benchmarks - grouping by timestamp, aggregating string column + let ctx = rt + .block_on(create_context(partitions, samples, false, true, false)) + .unwrap(); + c.bench_function( + format!( + "top k={limit} string aggregate {} time-series rows [Utf8]", + partitions * samples + ) + .as_str(), + |b| b.iter(|| run_string(&rt, ctx.clone(), limit, true)), + ); + + let ctx = rt + .block_on(create_context(partitions, samples, true, true, false)) + .unwrap(); + c.bench_function( + format!( + "top k={limit} string aggregate {} worst-case rows [Utf8]", + partitions * samples + ) + .as_str(), + |b| b.iter(|| run_string(&rt, ctx.clone(), limit, true)), + ); + + let ctx = rt + .block_on(create_context(partitions, samples, false, true, true)) + .unwrap(); + c.bench_function( + format!( + "top k={limit} string aggregate {} time-series rows [Utf8View]", + partitions * samples + ) + .as_str(), + |b| b.iter(|| run_string(&rt, ctx.clone(), limit, true)), + ); + + let ctx = rt + .block_on(create_context(partitions, samples, true, true, true)) + .unwrap(); + c.bench_function( + format!( + "top k={limit} string aggregate {} worst-case rows [Utf8View]", + partitions * samples + ) + .as_str(), + |b| b.iter(|| run_string(&rt, ctx.clone(), limit, true)), + ); + + // DISTINCT benchmarks + let ctx = rt.block_on(async { + 
create_context_distinct(partitions, samples, false) + .await + .unwrap() + }); + c.bench_function( + format!("distinct {} rows desc [no TopK]", partitions * samples).as_str(), + |b| b.iter(|| run_distinct(&rt, ctx.clone(), limit, false, false)), + ); + + c.bench_function( + format!("distinct {} rows asc [no TopK]", partitions * samples).as_str(), + |b| b.iter(|| run_distinct(&rt, ctx.clone(), limit, false, true)), + ); + + let ctx_topk = rt.block_on(async { + create_context_distinct(partitions, samples, true) + .await + .unwrap() + }); + c.bench_function( + format!("distinct {} rows desc [TopK]", partitions * samples).as_str(), + |b| b.iter(|| run_distinct(&rt, ctx_topk.clone(), limit, true, false)), + ); + + c.bench_function( + format!("distinct {} rows asc [TopK]", partitions * samples).as_str(), + |b| b.iter(|| run_distinct(&rt, ctx_topk.clone(), limit, true, true)), ); } diff --git a/datafusion/core/benches/topk_repartition.rs b/datafusion/core/benches/topk_repartition.rs new file mode 100644 index 0000000000000..e1f14e4aaa633 --- /dev/null +++ b/datafusion/core/benches/topk_repartition.rs @@ -0,0 +1,90 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Benchmark for the TopKRepartition optimizer rule. 
+//! +//! Measures the benefit of pushing TopK (Sort with fetch) below hash +//! repartition when running partitioned window functions with LIMIT. + +mod data_utils; + +use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; +use data_utils::create_table_provider; +use datafusion::prelude::{SessionConfig, SessionContext}; +use parking_lot::Mutex; +use std::hint::black_box; +use std::sync::Arc; +use tokio::runtime::Runtime; + +#[expect(clippy::needless_pass_by_value)] +fn query(ctx: Arc>, rt: &Runtime, sql: &str) { + let df = rt.block_on(ctx.lock().sql(sql)).unwrap(); + black_box(rt.block_on(df.collect()).unwrap()); +} + +fn create_context( + partitions_len: usize, + target_partitions: usize, + enable_topk_repartition: bool, +) -> Arc> { + let array_len = 1024 * 1024; + let batch_size = 8 * 1024; + let mut config = SessionConfig::new().with_target_partitions(target_partitions); + config.options_mut().optimizer.enable_topk_repartition = enable_topk_repartition; + let ctx = SessionContext::new_with_config(config); + let rt = Runtime::new().unwrap(); + rt.block_on(async { + let provider = + create_table_provider(partitions_len, array_len, batch_size).unwrap(); + ctx.register_table("t", provider).unwrap(); + }); + Arc::new(Mutex::new(ctx)) +} + +fn criterion_benchmark(c: &mut Criterion) { + let rt = Runtime::new().unwrap(); + + let limits = [10, 1_000, 10_000, 100_000]; + let scans = 16; + let target_partitions = 4; + + let group = format!("topk_repartition_{scans}_to_{target_partitions}"); + let mut group = c.benchmark_group(group); + for limit in limits { + let sql = format!( + "SELECT \ + SUM(f64) OVER (PARTITION BY u64_narrow ORDER BY u64_wide ROWS UNBOUNDED PRECEDING) \ + FROM t \ + ORDER BY u64_narrow, u64_wide \ + LIMIT {limit}" + ); + + let ctx_disabled = create_context(scans, target_partitions, false); + group.bench_function(BenchmarkId::new("disabled", limit), |b| { + b.iter(|| query(ctx_disabled.clone(), &rt, &sql)) + }); + + let ctx_enabled 
= create_context(scans, target_partitions, true); + group.bench_function(BenchmarkId::new("enabled", limit), |b| { + b.iter(|| query(ctx_enabled.clone(), &rt, &sql)) + }); + } + group.finish(); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/core/benches/window_query_sql.rs b/datafusion/core/benches/window_query_sql.rs index a55d17a7c5dcf..1657cae913fef 100644 --- a/datafusion/core/benches/window_query_sql.rs +++ b/datafusion/core/benches/window_query_sql.rs @@ -15,23 +15,21 @@ // specific language governing permissions and limitations // under the License. -#[macro_use] -extern crate criterion; -extern crate arrow; -extern crate datafusion; - mod data_utils; -use crate::criterion::Criterion; + +use criterion::{Criterion, criterion_group, criterion_main}; use data_utils::create_table_provider; use datafusion::error::Result; use datafusion::execution::context::SessionContext; use parking_lot::Mutex; +use std::hint::black_box; use std::sync::Arc; use tokio::runtime::Runtime; +#[expect(clippy::needless_pass_by_value)] fn query(ctx: Arc>, rt: &Runtime, sql: &str) { let df = rt.block_on(ctx.lock().sql(sql)).unwrap(); - criterion::black_box(rt.block_on(df.collect()).unwrap()); + black_box(rt.block_on(df.collect()).unwrap()); } fn create_context( diff --git a/datafusion/core/src/bin/print_functions_docs.rs b/datafusion/core/src/bin/print_functions_docs.rs index 1044717aaffb1..2466d42692192 100644 --- a/datafusion/core/src/bin/print_functions_docs.rs +++ b/datafusion/core/src/bin/print_functions_docs.rs @@ -16,10 +16,10 @@ // under the License. 
use datafusion::execution::SessionStateDefaults; -use datafusion_common::{not_impl_err, HashSet, Result}; +use datafusion_common::{HashSet, Result, not_impl_err}; use datafusion_expr::{ - aggregate_doc_sections, scalar_doc_sections, window_doc_sections, AggregateUDF, - DocSection, Documentation, ScalarUDF, WindowUDF, + AggregateUDF, DocSection, Documentation, ScalarUDF, WindowUDF, + aggregate_doc_sections, scalar_doc_sections, window_doc_sections, }; use itertools::Itertools; use std::env::args; @@ -84,30 +84,7 @@ fn print_window_docs() -> Result { print_docs(providers, window_doc_sections::doc_sections()) } -// Temporary method useful to semi automate -// the migration of UDF documentation generation from code based -// to attribute based -// To be removed -#[allow(dead_code)] -fn save_doc_code_text(documentation: &Documentation, name: &str) { - let attr_text = documentation.to_doc_attribute(); - - let file_path = format!("{name}.txt"); - if std::path::Path::new(&file_path).exists() { - std::fs::remove_file(&file_path).unwrap(); - } - - // Open the file in append mode, create it if it doesn't exist - let mut file = std::fs::OpenOptions::new() - .append(true) // Open in append mode - .create(true) // Create the file if it doesn't exist - .open(file_path) - .unwrap(); - - use std::io::Write; - file.write_all(attr_text.as_bytes()).unwrap(); -} - +#[expect(clippy::needless_pass_by_value)] fn print_docs( providers: Vec>, doc_sections: Vec, @@ -254,13 +231,15 @@ fn print_docs( for f in &providers_with_no_docs { eprintln!(" - {f}"); } - not_impl_err!("Some functions do not have documentation. Please implement `documentation` for: {providers_with_no_docs:?}") + not_impl_err!( + "Some functions do not have documentation. 
Please implement `documentation` for: {providers_with_no_docs:?}" + ) } else { Ok(docs) } } -/// Trait for accessing name / aliases / documentation for differnet functions +/// Trait for accessing name / aliases / documentation for different functions trait DocProvider { fn get_name(&self) -> String; fn get_aliases(&self) -> Vec; @@ -303,8 +282,7 @@ impl DocProvider for WindowUDF { } } -#[allow(clippy::borrowed_box)] -#[allow(clippy::ptr_arg)] +#[expect(clippy::borrowed_box)] fn get_names_and_aliases(functions: &Vec<&Box>) -> Vec { functions .iter() diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 69992e57ca7d0..2292f5855bfde 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -26,42 +26,43 @@ use crate::datasource::file_format::csv::CsvFormatFactory; use crate::datasource::file_format::format_as_file_type; use crate::datasource::file_format::json::JsonFormatFactory; use crate::datasource::{ - provider_as_source, DefaultTableSource, MemTable, TableProvider, + DefaultTableSource, MemTable, TableProvider, provider_as_source, }; use crate::error::Result; -use crate::execution::context::{SessionState, TaskContext}; use crate::execution::FunctionRegistry; +use crate::execution::context::{SessionState, TaskContext}; use crate::logical_expr::utils::find_window_exprs; use crate::logical_expr::{ - col, ident, Expr, JoinType, LogicalPlan, LogicalPlanBuilder, - LogicalPlanBuilderOptions, Partitioning, TableType, + Expr, JoinType, LogicalPlan, LogicalPlanBuilder, LogicalPlanBuilderOptions, + Partitioning, TableType, col, ident, }; use crate::physical_plan::{ - collect, collect_partitioned, execute_stream, execute_stream_partitioned, - ExecutionPlan, SendableRecordBatchStream, + ExecutionPlan, SendableRecordBatchStream, collect, collect_partitioned, + execute_stream, execute_stream_partitioned, }; use crate::prelude::SessionContext; use std::any::Any; use std::borrow::Cow; -use 
std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::sync::Arc; use arrow::array::{Array, ArrayRef, Int64Array, StringArray}; use arrow::compute::{cast, concat}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use arrow_schema::FieldRef; use datafusion_common::config::{CsvOptions, JsonOptions}; use datafusion_common::{ - exec_err, not_impl_err, plan_datafusion_err, plan_err, Column, DFSchema, - DataFusionError, ParamValues, ScalarValue, SchemaError, UnnestOptions, + Column, DFSchema, DataFusionError, ParamValues, ScalarValue, SchemaError, + TableReference, UnnestOptions, exec_err, internal_datafusion_err, not_impl_err, + plan_datafusion_err, plan_err, unqualified_field_not_found, }; use datafusion_expr::select_expr::SelectExpr; use datafusion_expr::{ - case, + ExplainOption, SortExpr, TableProviderFilterPushDown, UNNAMED_TABLE, case, dml::InsertOp, expr::{Alias, ScalarFunction}, is_null, lit, utils::COUNT_STAR_EXPANSION, - SortExpr, TableProviderFilterPushDown, UNNAMED_TABLE, }; use datafusion_functions::core::coalesce; use datafusion_functions_aggregate::expr_fn::{ @@ -70,7 +71,6 @@ use datafusion_functions_aggregate::expr_fn::{ use async_trait::async_trait; use datafusion_catalog::Session; -use datafusion_sql::TableReference; /// Contains options that control how data is /// written out from a DataFrame @@ -78,9 +78,11 @@ pub struct DataFrameWriteOptions { /// Controls how new data should be written to the table, determining whether /// to append, overwrite, or replace existing data. insert_op: InsertOp, - /// Controls if all partitions should be coalesced into a single output file - /// Generally will have slower performance when set to true. - single_file_output: bool, + /// Controls if all partitions should be coalesced into a single output file. 
+ /// - `None`: Use automatic mode (extension-based heuristic) + /// - `Some(true)`: Force single file output at exact path + /// - `Some(false)`: Force directory output with generated filenames + single_file_output: Option, /// Sets which columns should be used for hive-style partitioned writes by name. /// Can be set to empty vec![] for non-partitioned writes. partition_by: Vec, @@ -94,7 +96,7 @@ impl DataFrameWriteOptions { pub fn new() -> Self { DataFrameWriteOptions { insert_op: InsertOp::Append, - single_file_output: false, + single_file_output: None, partition_by: vec![], sort_by: vec![], } @@ -107,8 +109,14 @@ impl DataFrameWriteOptions { } /// Set the single_file_output value to true or false + /// + /// - `true`: Force single file output at the exact path specified + /// - `false`: Force directory output with generated filenames + /// + /// When not called, automatic mode is used (extension-based heuristic). + /// When set to true, an output file will always be created even if the DataFrame is empty. pub fn with_single_file_output(mut self, single_file_output: bool) -> Self { - self.single_file_output = single_file_output; + self.single_file_output = Some(single_file_output); self } @@ -123,6 +131,15 @@ impl DataFrameWriteOptions { self.sort_by = sort_by; self } + + /// Build the options HashMap to pass to CopyTo for sink configuration. + fn build_sink_options(&self) -> HashMap { + let mut options = HashMap::new(); + if let Some(single_file) = self.single_file_output { + options.insert("single_file_output".to_string(), single_file.to_string()); + } + options + } } impl Default for DataFrameWriteOptions { @@ -258,15 +275,19 @@ impl DataFrame { /// # async fn main() -> Result<()> { /// // datafusion will parse number as i64 first. 
/// let sql = "a > 1 and b in (1, 10)"; - /// let expected = col("a").gt(lit(1 as i64)) - /// .and(col("b").in_list(vec![lit(1 as i64), lit(10 as i64)], false)); + /// let expected = col("a") + /// .gt(lit(1 as i64)) + /// .and(col("b").in_list(vec![lit(1 as i64), lit(10 as i64)], false)); /// let ctx = SessionContext::new(); - /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; + /// let df = ctx + /// .read_csv("tests/data/example.csv", CsvReadOptions::new()) + /// .await?; /// let expr = df.parse_sql_expr(sql)?; /// assert_eq!(expected, expr); /// # Ok(()) /// # } /// ``` + #[cfg(feature = "sql")] pub fn parse_sql_expr(&self, sql: &str) -> Result { let df_schema = self.schema(); @@ -288,14 +309,16 @@ impl DataFrame { /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); - /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; + /// let df = ctx + /// .read_csv("tests/data/example.csv", CsvReadOptions::new()) + /// .await?; /// let df = df.select_columns(&["a", "b"])?; /// let expected = vec![ /// "+---+---+", /// "| a | b |", /// "+---+---+", /// "| 1 | 2 |", - /// "+---+---+" + /// "+---+---+", /// ]; /// # assert_batches_sorted_eq!(expected, &df.collect().await?); /// # Ok(()) @@ -305,11 +328,20 @@ impl DataFrame { let fields = columns .iter() .map(|name| { - self.plan + let fields = self + .plan .schema() - .qualified_field_with_unqualified_name(name) + .qualified_fields_with_unqualified_name(name); + if fields.is_empty() { + Err(unqualified_field_not_found(name, self.plan.schema())) + } else { + Ok(fields) + } }) - .collect::>>()?; + .collect::, _>>()? 
+ .into_iter() + .flatten() + .collect::>(); let expr: Vec = fields .into_iter() .map(|(qualifier, field)| Expr::Column(Column::from((qualifier, field)))) @@ -328,11 +360,14 @@ impl DataFrame { /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); - /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; - /// let df : DataFrame = df.select_exprs(&["a * b", "c"])?; + /// let df = ctx + /// .read_csv("tests/data/example.csv", CsvReadOptions::new()) + /// .await?; + /// let df: DataFrame = df.select_exprs(&["a * b", "c"])?; /// # Ok(()) /// # } /// ``` + #[cfg(feature = "sql")] pub fn select_exprs(self, exprs: &[&str]) -> Result { let expr_list = exprs .iter() @@ -355,14 +390,16 @@ impl DataFrame { /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); - /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; + /// let df = ctx + /// .read_csv("tests/data/example.csv", CsvReadOptions::new()) + /// .await?; /// let df = df.select(vec![col("a"), col("b") * col("c")])?; /// let expected = vec![ /// "+---+-----------------------+", /// "| a | ?table?.b * ?table?.c |", /// "+---+-----------------------+", /// "| 1 | 6 |", - /// "+---+-----------------------+" + /// "+---+-----------------------+", /// ]; /// # assert_batches_sorted_eq!(expected, &df.collect().await?); /// # Ok(()) @@ -375,15 +412,12 @@ impl DataFrame { let expr_list: Vec = expr_list.into_iter().map(|e| e.into()).collect::>(); - let expressions = expr_list - .iter() - .filter_map(|e| match e { - SelectExpr::Expression(expr) => Some(expr.clone()), - _ => None, - }) - .collect::>(); + let expressions = expr_list.iter().filter_map(|e| match e { + SelectExpr::Expression(expr) => Some(expr), + _ => None, + }); - let window_func_exprs = find_window_exprs(&expressions); + let window_func_exprs = find_window_exprs(expressions); let plan = if window_func_exprs.is_empty() { 
self.plan } else { @@ -408,7 +442,9 @@ impl DataFrame { /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); - /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; + /// let df = ctx + /// .read_csv("tests/data/example.csv", CsvReadOptions::new()) + /// .await?; /// // +----+----+----+ /// // | a | b | c | /// // +----+----+----+ @@ -420,22 +456,37 @@ impl DataFrame { /// "| b | c |", /// "+---+---+", /// "| 2 | 3 |", - /// "+---+---+" + /// "+---+---+", /// ]; /// # assert_batches_sorted_eq!(expected, &df.collect().await?); /// # Ok(()) /// # } /// ``` - pub fn drop_columns(self, columns: &[&str]) -> Result { + pub fn drop_columns(self, columns: &[T]) -> Result + where + T: Into + Clone, + { let fields_to_drop = columns .iter() - .map(|name| { - self.plan - .schema() - .qualified_field_with_unqualified_name(name) + .flat_map(|col| { + let column: Column = col.clone().into(); + match column.relation.as_ref() { + Some(_) => { + // qualified_field_from_column returns Result<(Option<&TableReference>, &FieldRef)> + vec![self.plan.schema().qualified_field_from_column(&column)] + } + None => { + // qualified_fields_with_unqualified_name returns Vec<(Option<&TableReference>, &FieldRef)> + self.plan + .schema() + .qualified_fields_with_unqualified_name(&column.name) + .into_iter() + .map(Ok) + .collect::>() + } + } }) - .filter(|r| r.is_ok()) - .collect::>>()?; + .collect::, _>>()?; let expr: Vec = self .plan .schema() @@ -461,7 +512,7 @@ impl DataFrame { /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); - /// let df = ctx.read_json("tests/data/unnest.json", NdJsonReadOptions::default()).await?; + /// let df = ctx.read_json("tests/data/unnest.json", JsonReadOptions::default()).await?; /// // expand into multiple columns if it's json array, flatten field name if it's nested structure /// let df = df.unnest_columns(&["b","c","d"])?; /// let expected = vec![ 
@@ -519,7 +570,9 @@ impl DataFrame { /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); - /// let df = ctx.read_csv("tests/data/example_long.csv", CsvReadOptions::new()).await?; + /// let df = ctx + /// .read_csv("tests/data/example_long.csv", CsvReadOptions::new()) + /// .await?; /// let df = df.filter(col("a").lt_eq(col("b")))?; /// // all rows where a <= b are returned /// let expected = vec![ @@ -529,7 +582,7 @@ impl DataFrame { /// "| 1 | 2 | 3 |", /// "| 4 | 5 | 6 |", /// "| 7 | 8 | 9 |", - /// "+---+---+---+" + /// "+---+---+---+", /// ]; /// # assert_batches_sorted_eq!(expected, &df.collect().await?); /// # Ok(()) @@ -558,7 +611,9 @@ impl DataFrame { /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); - /// let df = ctx.read_csv("tests/data/example_long.csv", CsvReadOptions::new()).await?; + /// let df = ctx + /// .read_csv("tests/data/example_long.csv", CsvReadOptions::new()) + /// .await?; /// /// // The following use is the equivalent of "SELECT MIN(b) GROUP BY a" /// let df1 = df.clone().aggregate(vec![col("a")], vec![min(col("b"))])?; @@ -569,7 +624,7 @@ impl DataFrame { /// "| 1 | 2 |", /// "| 4 | 5 |", /// "| 7 | 8 |", - /// "+---+----------------+" + /// "+---+----------------+", /// ]; /// assert_batches_sorted_eq!(expected1, &df1.collect().await?); /// // The following use is the equivalent of "SELECT MIN(b)" @@ -579,7 +634,7 @@ impl DataFrame { /// "| min(?table?.b) |", /// "+----------------+", /// "| 2 |", - /// "+----------------+" + /// "+----------------+", /// ]; /// # assert_batches_sorted_eq!(expected2, &df2.collect().await?); /// # Ok(()) @@ -647,7 +702,9 @@ impl DataFrame { /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); - /// let df = ctx.read_csv("tests/data/example_long.csv", CsvReadOptions::new()).await?; + /// let df = ctx + /// .read_csv("tests/data/example_long.csv", CsvReadOptions::new()) + /// 
.await?; /// let df = df.limit(1, Some(2))?; /// let expected = vec![ /// "+---+---+---+", @@ -655,7 +712,7 @@ impl DataFrame { /// "+---+---+---+", /// "| 4 | 5 | 6 |", /// "| 7 | 8 | 9 |", - /// "+---+---+---+" + /// "+---+---+---+", /// ]; /// # assert_batches_sorted_eq!(expected, &df.collect().await?); /// # Ok(()) @@ -684,7 +741,9 @@ impl DataFrame { /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); - /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await? ; + /// let df = ctx + /// .read_csv("tests/data/example.csv", CsvReadOptions::new()) + /// .await?; /// let d2 = df.clone(); /// let df = df.union(d2)?; /// let expected = vec![ @@ -693,7 +752,7 @@ impl DataFrame { /// "+---+---+---+", /// "| 1 | 2 | 3 |", /// "| 1 | 2 | 3 |", - /// "+---+---+---+" + /// "+---+---+---+", /// ]; /// # assert_batches_sorted_eq!(expected, &df.collect().await?); /// # Ok(()) @@ -724,8 +783,13 @@ impl DataFrame { /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); - /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; - /// let d2 = df.clone().select_columns(&["b", "c", "a"])?.with_column("d", lit("77"))?; + /// let df = ctx + /// .read_csv("tests/data/example.csv", CsvReadOptions::new()) + /// .await?; + /// let d2 = df + /// .clone() + /// .select_columns(&["b", "c", "a"])? 
+ /// .with_column("d", lit("77"))?; /// let df = df.union_by_name(d2)?; /// let expected = vec![ /// "+---+---+---+----+", @@ -733,7 +797,7 @@ impl DataFrame { /// "+---+---+---+----+", /// "| 1 | 2 | 3 | |", /// "| 1 | 2 | 3 | 77 |", - /// "+---+---+---+----+" + /// "+---+---+---+----+", /// ]; /// # assert_batches_sorted_eq!(expected, &df.collect().await?); /// # Ok(()) @@ -763,7 +827,9 @@ impl DataFrame { /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); - /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; + /// let df = ctx + /// .read_csv("tests/data/example.csv", CsvReadOptions::new()) + /// .await?; /// let d2 = df.clone(); /// let df = df.union_distinct(d2)?; /// // df2 are duplicate of df @@ -772,7 +838,7 @@ impl DataFrame { /// "| a | b | c |", /// "+---+---+---+", /// "| 1 | 2 | 3 |", - /// "+---+---+---+" + /// "+---+---+---+", /// ]; /// # assert_batches_sorted_eq!(expected, &df.collect().await?); /// # Ok(()) @@ -803,7 +869,9 @@ impl DataFrame { /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); - /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; + /// let df = ctx + /// .read_csv("tests/data/example.csv", CsvReadOptions::new()) + /// .await?; /// let d2 = df.clone().select_columns(&["b", "c", "a"])?; /// let df = df.union_by_name_distinct(d2)?; /// let expected = vec![ @@ -811,7 +879,7 @@ impl DataFrame { /// "| a | b | c |", /// "+---+---+---+", /// "| 1 | 2 | 3 |", - /// "+---+---+---+" + /// "+---+---+---+", /// ]; /// # assert_batches_sorted_eq!(expected, &df.collect().await?); /// # Ok(()) @@ -838,14 +906,16 @@ impl DataFrame { /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); - /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; + /// let df = ctx + /// .read_csv("tests/data/example.csv", CsvReadOptions::new()) + 
/// .await?; /// let df = df.distinct()?; /// let expected = vec![ /// "+---+---+---+", /// "| a | b | c |", /// "+---+---+---+", /// "| 1 | 2 | 3 |", - /// "+---+---+---+" + /// "+---+---+---+", /// ]; /// # assert_batches_sorted_eq!(expected, &df.collect().await?); /// # Ok(()) @@ -872,15 +942,17 @@ impl DataFrame { /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); - /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await? - /// // Return a single row (a, b) for each distinct value of a - /// .distinct_on(vec![col("a")], vec![col("a"), col("b")], None)?; + /// let df = ctx + /// .read_csv("tests/data/example.csv", CsvReadOptions::new()) + /// .await? + /// // Return a single row (a, b) for each distinct value of a + /// .distinct_on(vec![col("a")], vec![col("a"), col("b")], None)?; /// let expected = vec![ /// "+---+---+", /// "| a | b |", /// "+---+---+", /// "| 1 | 2 |", - /// "+---+---+" + /// "+---+---+", /// ]; /// # assert_batches_sorted_eq!(expected, &df.collect().await?); /// # Ok(()) @@ -953,7 +1025,7 @@ impl DataFrame { })); //collect recordBatch - let describe_record_batch = vec![ + let describe_record_batch = [ // count aggregation self.clone().aggregate( vec![], @@ -1126,11 +1198,13 @@ impl DataFrame { /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); - /// let df = ctx.read_csv("tests/data/example_long.csv", CsvReadOptions::new()).await?; + /// let df = ctx + /// .read_csv("tests/data/example_long.csv", CsvReadOptions::new()) + /// .await?; /// let df = df.sort(vec![ - /// col("a").sort(false, true), // a DESC, nulls first - /// col("b").sort(true, false), // b ASC, nulls last - /// ])?; + /// col("a").sort(false, true), // a DESC, nulls first + /// col("b").sort(true, false), // b ASC, nulls last + /// ])?; /// let expected = vec![ /// "+---+---+---+", /// "| a | b | c |", @@ -1177,12 +1251,17 @@ impl DataFrame { /// # #[tokio::main] 
/// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); - /// let left = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; - /// let right = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await? - /// .select(vec![ - /// col("a").alias("a2"), - /// col("b").alias("b2"), - /// col("c").alias("c2")])?; + /// let left = ctx + /// .read_csv("tests/data/example.csv", CsvReadOptions::new()) + /// .await?; + /// let right = ctx + /// .read_csv("tests/data/example.csv", CsvReadOptions::new()) + /// .await? + /// .select(vec![ + /// col("a").alias("a2"), + /// col("b").alias("b2"), + /// col("c").alias("c2"), + /// ])?; /// // Perform the equivalent of `left INNER JOIN right ON (a = a2 AND b = b2)` /// // finding all pairs of rows from `left` and `right` where `a = a2` and `b = b2`. /// let join = left.join(right, JoinType::Inner, &["a", "b"], &["a2", "b2"], None)?; @@ -1191,13 +1270,12 @@ impl DataFrame { /// "| a | b | c | a2 | b2 | c2 |", /// "+---+---+---+----+----+----+", /// "| 1 | 2 | 3 | 1 | 2 | 3 |", - /// "+---+---+---+----+----+----+" + /// "+---+---+---+----+----+----+", /// ]; /// assert_batches_sorted_eq!(expected, &join.collect().await?); /// # Ok(()) /// # } /// ``` - /// pub fn join( self, right: DataFrame, @@ -1259,7 +1337,7 @@ impl DataFrame { /// "+---+---+---+----+----+----+", /// "| a | b | c | a2 | b2 | c2 |", /// "+---+---+---+----+----+----+", - /// "+---+---+---+----+----+----+" + /// "+---+---+---+----+----+----+", /// ]; /// # assert_batches_sorted_eq!(expected, &join_on.collect().await?); /// # Ok(()) @@ -1291,7 +1369,9 @@ impl DataFrame { /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); - /// let df = ctx.read_csv("tests/data/example_long.csv", CsvReadOptions::new()).await?; + /// let df = ctx + /// .read_csv("tests/data/example_long.csv", CsvReadOptions::new()) + /// .await?; /// let df1 = df.repartition(Partitioning::RoundRobinBatch(4))?; /// 
let expected = vec![ /// "+---+---+---+", @@ -1300,7 +1380,7 @@ impl DataFrame { /// "| 1 | 2 | 3 |", /// "| 4 | 5 | 6 |", /// "| 7 | 8 | 9 |", - /// "+---+---+---+" + /// "+---+---+---+", /// ]; /// # assert_batches_sorted_eq!(expected, &df1.collect().await?); /// # Ok(()) @@ -1329,7 +1409,9 @@ impl DataFrame { /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); - /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; + /// let df = ctx + /// .read_csv("tests/data/example.csv", CsvReadOptions::new()) + /// .await?; /// let count = df.count().await?; // 1 /// # assert_eq!(count, 1); /// # Ok(()) @@ -1337,7 +1419,10 @@ impl DataFrame { /// ``` pub async fn count(self) -> Result { let rows = self - .aggregate(vec![], vec![count(Expr::Literal(COUNT_STAR_EXPANSION))])? + .aggregate( + vec![], + vec![count(Expr::Literal(COUNT_STAR_EXPANSION, None))], + )? .collect() .await?; let len = *rows @@ -1345,9 +1430,9 @@ impl DataFrame { .and_then(|r| r.columns().first()) .and_then(|c| c.as_any().downcast_ref::()) .and_then(|a| a.values().first()) - .ok_or(DataFusionError::Internal( - "Unexpected output when collecting for count()".to_string(), - ))? as usize; + .ok_or_else(|| { + internal_datafusion_err!("Unexpected output when collecting for count()") + })? 
as usize; Ok(len) } @@ -1365,7 +1450,9 @@ impl DataFrame { /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); - /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; + /// let df = ctx + /// .read_csv("tests/data/example.csv", CsvReadOptions::new()) + /// .await?; /// let batches = df.collect().await?; /// # Ok(()) /// # } @@ -1385,7 +1472,9 @@ impl DataFrame { /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); - /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; + /// let df = ctx + /// .read_csv("tests/data/example.csv", CsvReadOptions::new()) + /// .await?; /// df.show().await?; /// # Ok(()) /// # } @@ -1444,7 +1533,9 @@ impl DataFrame { /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); - /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; + /// let df = ctx + /// .read_csv("tests/data/example.csv", CsvReadOptions::new()) + /// .await?; /// df.show_limit(10).await?; /// # Ok(()) /// # } @@ -1470,7 +1561,9 @@ impl DataFrame { /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); - /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; + /// let df = ctx + /// .read_csv("tests/data/example.csv", CsvReadOptions::new()) + /// .await?; /// let stream = df.execute_stream().await?; /// # Ok(()) /// # } @@ -1496,7 +1589,9 @@ impl DataFrame { /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); - /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; + /// let df = ctx + /// .read_csv("tests/data/example.csv", CsvReadOptions::new()) + /// .await?; /// let batches = df.collect_partitioned().await?; /// # Ok(()) /// # } @@ -1516,7 +1611,9 @@ impl DataFrame { /// # #[tokio::main] /// # async fn main() -> 
Result<()> { /// let ctx = SessionContext::new(); - /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; + /// let df = ctx + /// .read_csv("tests/data/example.csv", CsvReadOptions::new()) + /// .await?; /// let batches = df.execute_stream_partitioned().await?; /// # Ok(()) /// # } @@ -1545,7 +1642,9 @@ impl DataFrame { /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); - /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; + /// let df = ctx + /// .read_csv("tests/data/example.csv", CsvReadOptions::new()) + /// .await?; /// let schema = df.schema(); /// # Ok(()) /// # } @@ -1596,12 +1695,26 @@ impl DataFrame { /// Note: This discards the [`SessionState`] associated with this /// [`DataFrame`] in favour of the one passed to [`TableProvider::scan`] pub fn into_view(self) -> Arc { - Arc::new(DataFrameTableProvider { plan: self.plan }) + Arc::new(DataFrameTableProvider { + plan: self.plan, + table_type: TableType::View, + }) + } + + /// See [`Self::into_view`]. The returned [`TableProvider`] will + /// create a transient table. + pub fn into_temporary_view(self) -> Arc { + Arc::new(DataFrameTableProvider { + plan: self.plan, + table_type: TableType::Temporary, + }) } /// Return a DataFrame with the explanation of its plan so far. /// /// if `analyze` is specified, runs the plan and reports metrics + /// if `verbose` is true, prints out additional details. + /// The default format is Indent format. 
/// /// ``` /// # use datafusion::prelude::*; @@ -1609,17 +1722,60 @@ impl DataFrame { /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); - /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; - /// let batches = df.limit(0, Some(100))?.explain(false, false)?.collect().await?; + /// let df = ctx + /// .read_csv("tests/data/example.csv", CsvReadOptions::new()) + /// .await?; + /// let batches = df + /// .limit(0, Some(100))? + /// .explain(false, false)? + /// .collect() + /// .await?; /// # Ok(()) /// # } /// ``` pub fn explain(self, verbose: bool, analyze: bool) -> Result { + // Set the default format to Indent to keep the previous behavior + let opts = ExplainOption::default() + .with_verbose(verbose) + .with_analyze(analyze); + self.explain_with_options(opts) + } + + /// Return a DataFrame with the explanation of its plan so far. + /// + /// `opt` is used to specify the options for the explain operation. + /// Details of the options can be found in [`ExplainOption`]. + /// ``` + /// # use datafusion::prelude::*; + /// # use datafusion::error::Result; + /// # #[tokio::main] + /// # async fn main() -> Result<()> { + /// use datafusion_expr::{Explain, ExplainOption}; + /// let ctx = SessionContext::new(); + /// let df = ctx + /// .read_csv("tests/data/example.csv", CsvReadOptions::new()) + /// .await?; + /// let batches = df + /// .limit(0, Some(100))? + /// .explain_with_options( + /// ExplainOption::default() + /// .with_verbose(false) + /// .with_analyze(false), + /// )? + /// .collect() + /// .await?; + /// # Ok(()) + /// # } + /// ``` + pub fn explain_with_options( + self, + explain_option: ExplainOption, + ) -> Result { if matches!(self.plan, LogicalPlan::Explain(_)) { return plan_err!("Nested EXPLAINs are not supported"); } let plan = LogicalPlanBuilder::from(self.plan) - .explain(verbose, analyze)? + .explain_option_format(explain_option)? 
.build()?; Ok(DataFrame { session_state: self.session_state, @@ -1637,7 +1793,9 @@ impl DataFrame { /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); - /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; + /// let df = ctx + /// .read_csv("tests/data/example.csv", CsvReadOptions::new()) + /// .await?; /// let f = df.registry(); /// // use f.udf("name", vec![...]) to use the udf /// # Ok(()) @@ -1656,15 +1814,19 @@ impl DataFrame { /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); - /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; - /// let d2 = ctx.read_csv("tests/data/example_long.csv", CsvReadOptions::new()).await?; + /// let df = ctx + /// .read_csv("tests/data/example.csv", CsvReadOptions::new()) + /// .await?; + /// let d2 = ctx + /// .read_csv("tests/data/example_long.csv", CsvReadOptions::new()) + /// .await?; /// let df = df.intersect(d2)?; /// let expected = vec![ /// "+---+---+---+", /// "| a | b | c |", /// "+---+---+---+", /// "| 1 | 2 | 3 |", - /// "+---+---+---+" + /// "+---+---+---+", /// ]; /// # assert_batches_sorted_eq!(expected, &df.collect().await?); /// # Ok(()) @@ -1681,6 +1843,44 @@ impl DataFrame { }) } + /// Calculate the distinct intersection of two [`DataFrame`]s. 
The two [`DataFrame`]s must have exactly the same schema + /// + /// ``` + /// # use datafusion::prelude::*; + /// # use datafusion::error::Result; + /// # use datafusion_common::assert_batches_sorted_eq; + /// # #[tokio::main] + /// # async fn main() -> Result<()> { + /// let ctx = SessionContext::new(); + /// let df = ctx + /// .read_csv("tests/data/example.csv", CsvReadOptions::new()) + /// .await?; + /// let d2 = ctx + /// .read_csv("tests/data/example_long.csv", CsvReadOptions::new()) + /// .await?; + /// let df = df.intersect_distinct(d2)?; + /// let expected = vec![ + /// "+---+---+---+", + /// "| a | b | c |", + /// "+---+---+---+", + /// "| 1 | 2 | 3 |", + /// "+---+---+---+", + /// ]; + /// # assert_batches_sorted_eq!(expected, &df.collect().await?); + /// # Ok(()) + /// # } + /// ``` + pub fn intersect_distinct(self, dataframe: DataFrame) -> Result { + let left_plan = self.plan; + let right_plan = dataframe.plan; + let plan = LogicalPlanBuilder::intersect(left_plan, right_plan, false)?; + Ok(DataFrame { + session_state: self.session_state, + plan, + projection_requires_validation: true, + }) + } + /// Calculate the exception of two [`DataFrame`]s. 
The two [`DataFrame`]s must have exactly the same schema /// /// ``` @@ -1690,8 +1890,12 @@ impl DataFrame { /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); - /// let df = ctx.read_csv("tests/data/example_long.csv", CsvReadOptions::new()).await?; - /// let d2 = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; + /// let df = ctx + /// .read_csv("tests/data/example_long.csv", CsvReadOptions::new()) + /// .await?; + /// let d2 = ctx + /// .read_csv("tests/data/example.csv", CsvReadOptions::new()) + /// .await?; /// let result = df.except(d2)?; /// // those columns are not in example.csv, but in example_long.csv /// let expected = vec![ @@ -1700,7 +1904,7 @@ impl DataFrame { /// "+---+---+---+", /// "| 4 | 5 | 6 |", /// "| 7 | 8 | 9 |", - /// "+---+---+---+" + /// "+---+---+---+", /// ]; /// # assert_batches_sorted_eq!(expected, &result.collect().await?); /// # Ok(()) @@ -1717,6 +1921,46 @@ impl DataFrame { }) } + /// Calculate the distinct exception of two [`DataFrame`]s. 
The two [`DataFrame`]s must have exactly the same schema + /// + /// ``` + /// # use datafusion::prelude::*; + /// # use datafusion::error::Result; + /// # use datafusion_common::assert_batches_sorted_eq; + /// # #[tokio::main] + /// # async fn main() -> Result<()> { + /// let ctx = SessionContext::new(); + /// let df = ctx + /// .read_csv("tests/data/example_long.csv", CsvReadOptions::new()) + /// .await?; + /// let d2 = ctx + /// .read_csv("tests/data/example.csv", CsvReadOptions::new()) + /// .await?; + /// let result = df.except_distinct(d2)?; + /// // those columns are not in example.csv, but in example_long.csv + /// let expected = vec![ + /// "+---+---+---+", + /// "| a | b | c |", + /// "+---+---+---+", + /// "| 4 | 5 | 6 |", + /// "| 7 | 8 | 9 |", + /// "+---+---+---+", + /// ]; + /// # assert_batches_sorted_eq!(expected, &result.collect().await?); + /// # Ok(()) + /// # } + /// ``` + pub fn except_distinct(self, dataframe: DataFrame) -> Result { + let left_plan = self.plan; + let right_plan = dataframe.plan; + let plan = LogicalPlanBuilder::except(left_plan, right_plan, false)?; + Ok(DataFrame { + session_state: self.session_state, + plan, + projection_requires_validation: true, + }) + } + /// Execute this `DataFrame` and write the results to `table_name`. /// /// Returns a single [RecordBatch] containing a single column and @@ -1777,13 +2021,15 @@ impl DataFrame { /// use datafusion::dataframe::DataFrameWriteOptions; /// let ctx = SessionContext::new(); /// // Sort the data by column "b" and write it to a new location - /// ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await? - /// .sort(vec![col("b").sort(true, true)])? // sort by b asc, nulls first - /// .write_csv( - /// "output.csv", - /// DataFrameWriteOptions::new(), - /// None, // can also specify CSV writing options here - /// ).await?; + /// ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()) + /// .await? + /// .sort(vec![col("b").sort(true, true)])? 
// sort by b asc, nulls first + /// .write_csv( + /// "output.csv", + /// DataFrameWriteOptions::new(), + /// None, // can also specify CSV writing options here + /// ) + /// .await?; /// # fs::remove_file("output.csv")?; /// # Ok(()) /// # } @@ -1809,6 +2055,8 @@ impl DataFrame { let file_type = format_as_file_type(format); + let copy_options = options.build_sink_options(); + let plan = if options.sort_by.is_empty() { self.plan } else { @@ -1821,7 +2069,7 @@ impl DataFrame { plan, path.into(), file_type, - HashMap::new(), + copy_options, options.partition_by, )? .build()?; @@ -1847,13 +2095,11 @@ impl DataFrame { /// use datafusion::dataframe::DataFrameWriteOptions; /// let ctx = SessionContext::new(); /// // Sort the data by column "b" and write it to a new location - /// ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await? - /// .sort(vec![col("b").sort(true, true)])? // sort by b asc, nulls first - /// .write_json( - /// "output.json", - /// DataFrameWriteOptions::new(), - /// None - /// ).await?; + /// ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()) + /// .await? + /// .sort(vec![col("b").sort(true, true)])? // sort by b asc, nulls first + /// .write_json("output.json", DataFrameWriteOptions::new(), None) + /// .await?; /// # fs::remove_file("output.json")?; /// # Ok(()) /// # } @@ -1879,6 +2125,8 @@ impl DataFrame { let file_type = format_as_file_type(format); + let copy_options = options.build_sink_options(); + let plan = if options.sort_by.is_empty() { self.plan } else { @@ -1891,7 +2139,7 @@ impl DataFrame { plan, path.into(), file_type, - Default::default(), + copy_options, options.partition_by, )? 
.build()?; @@ -1914,39 +2162,48 @@ impl DataFrame { /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); - /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; + /// let df = ctx + /// .read_csv("tests/data/example.csv", CsvReadOptions::new()) + /// .await?; /// let df = df.with_column("ab_sum", col("a") + col("b"))?; /// # Ok(()) /// # } /// ``` pub fn with_column(self, name: &str, expr: Expr) -> Result { - let window_func_exprs = find_window_exprs(std::slice::from_ref(&expr)); + let window_func_exprs = find_window_exprs([&expr]); + + let original_names: HashSet = self + .plan + .schema() + .iter() + .map(|(_, f)| f.name().clone()) + .collect(); - let (window_fn_str, plan) = if window_func_exprs.is_empty() { - (None, self.plan) + // Maybe build window plan + let plan = if window_func_exprs.is_empty() { + self.plan } else { - ( - Some(window_func_exprs[0].to_string()), - LogicalPlanBuilder::window_plan(self.plan, window_func_exprs)?, - ) + LogicalPlanBuilder::window_plan(self.plan, window_func_exprs)? 
}; - let mut col_exists = false; let new_column = expr.alias(name); + let mut col_exists = false; + let mut fields: Vec<(Expr, bool)> = plan .schema() .iter() .filter_map(|(qualifier, field)| { + // Skip new fields introduced by window_plan + if !original_names.contains(field.name()) { + return None; + } + if field.name() == name { col_exists = true; Some((new_column.clone(), true)) } else { let e = col(Column::from((qualifier, field))); - window_fn_str - .as_ref() - .filter(|s| *s == &e.to_string()) - .is_none() - .then_some((e, self.projection_requires_validation)) + Some((e, self.projection_requires_validation)) } }) .collect(); @@ -1981,7 +2238,9 @@ impl DataFrame { /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); - /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; + /// let df = ctx + /// .read_csv("tests/data/example.csv", CsvReadOptions::new()) + /// .await?; /// let df = df.with_column_renamed("ab_sum", "total")?; /// /// # Ok(()) @@ -2007,10 +2266,11 @@ impl DataFrame { match self.plan.schema().qualified_field_from_column(&old_column) { Ok(qualifier_and_field) => qualifier_and_field, // no-op if field not found - Err(DataFusionError::SchemaError( - SchemaError::FieldNotFound { .. }, - _, - )) => return Ok(self), + Err(DataFusionError::SchemaError(e, _)) + if matches!(*e, SchemaError::FieldNotFound { .. }) => + { + return Ok(self); + } Err(err) => return Err(err), }; let projection = self @@ -2018,7 +2278,7 @@ impl DataFrame { .schema() .iter() .map(|(qualifier, field)| { - if qualifier.eq(&qualifier_rename) && field.as_ref() == field_rename { + if qualifier.eq(&qualifier_rename) && field == field_rename { ( col(Column::from((qualifier, field))) .alias_qualified(qualifier.cloned(), new_name), @@ -2107,26 +2367,38 @@ impl DataFrame { /// Cache DataFrame as a memory table. 
/// + /// Default behavior could be changed using + /// a [`crate::execution::session_state::CacheFactory`] + /// configured via [`SessionState`]. + /// /// ``` /// # use datafusion::prelude::*; /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); - /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; + /// let df = ctx + /// .read_csv("tests/data/example.csv", CsvReadOptions::new()) + /// .await?; /// let df = df.cache().await?; /// # Ok(()) /// # } /// ``` pub async fn cache(self) -> Result { - let context = SessionContext::new_with_state((*self.session_state).clone()); - // The schema is consistent with the output - let plan = self.clone().create_physical_plan().await?; - let schema = plan.schema(); - let task_ctx = Arc::new(self.task_ctx()); - let partitions = collect_partitioned(plan, task_ctx).await?; - let mem_table = MemTable::try_new(schema, partitions)?; - context.read_table(Arc::new(mem_table)) + if let Some(cache_factory) = self.session_state.cache_factory() { + let new_plan = + cache_factory.create(self.plan, self.session_state.as_ref())?; + Ok(Self::new(*self.session_state, new_plan)) + } else { + let context = SessionContext::new_with_state((*self.session_state).clone()); + // The schema is consistent with the output + let plan = self.clone().create_physical_plan().await?; + let schema = plan.schema(); + let task_ctx = Arc::new(self.task_ctx()); + let partitions = collect_partitioned(plan, task_ctx).await?; + let mem_table = MemTable::try_new(schema, partitions)?; + context.read_table(Arc::new(mem_table)) + } } /// Apply an alias to the DataFrame. 
@@ -2157,7 +2429,9 @@ impl DataFrame { /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); - /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; + /// let df = ctx + /// .read_csv("tests/data/example.csv", CsvReadOptions::new()) + /// .await?; /// // Fill nulls in only columns "a" and "c": /// let df = df.fill_null(ScalarValue::from(0), vec!["a".to_owned(), "c".to_owned()])?; /// // Fill nulls across all columns: @@ -2165,6 +2439,7 @@ impl DataFrame { /// # Ok(()) /// # } /// ``` + #[expect(clippy::needless_pass_by_value)] pub fn fill_null( &self, value: ScalarValue, @@ -2175,7 +2450,7 @@ impl DataFrame { .schema() .fields() .iter() - .map(|f| f.as_ref().clone()) + .map(Arc::clone) .collect() } else { self.find_columns(&columns)? @@ -2212,7 +2487,7 @@ impl DataFrame { } // Helper to find columns from names - fn find_columns(&self, names: &[String]) -> Result> { + fn find_columns(&self, names: &[String]) -> Result> { let schema = self.logical_plan().schema(); names .iter() @@ -2225,12 +2500,54 @@ impl DataFrame { .collect() } + /// Find qualified columns for this dataframe from names + /// + /// # Arguments + /// * `names` - Unqualified names to find. 
+ /// + /// # Example + /// ``` + /// # use datafusion::prelude::*; + /// # use datafusion::error::Result; + /// # use datafusion_common::ScalarValue; + /// # #[tokio::main] + /// # async fn main() -> Result<()> { + /// let ctx = SessionContext::new(); + /// ctx.register_csv("first_table", "tests/data/example.csv", CsvReadOptions::new()) + /// .await?; + /// let df = ctx.table("first_table").await?; + /// ctx.register_csv("second_table", "tests/data/example.csv", CsvReadOptions::new()) + /// .await?; + /// let df2 = ctx.table("second_table").await?; + /// let join_expr = df.find_qualified_columns(&["a"])?.iter() + /// .zip(df2.find_qualified_columns(&["a"])?.iter()) + /// .map(|(col1, col2)| col(*col1).eq(col(*col2))) + /// .collect::>(); + /// let df3 = df.join_on(df2, JoinType::Inner, join_expr)?; + /// # Ok(()) + /// # } + /// ``` + pub fn find_qualified_columns( + &self, + names: &[&str], + ) -> Result, &FieldRef)>> { + let schema = self.logical_plan().schema(); + names + .iter() + .map(|name| { + schema + .qualified_field_from_column(&Column::from_name(*name)) + .map_err(|_| plan_datafusion_err!("Column '{}' not found", name)) + }) + .collect() + } + /// Helper for creating DataFrame. /// # Example /// ``` - /// use std::sync::Arc; /// use arrow::array::{ArrayRef, Int32Array, StringArray}; /// use datafusion::prelude::DataFrame; + /// use std::sync::Arc; /// let id: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3])); /// let name: ArrayRef = Arc::new(StringArray::from(vec!["foo", "bar", "baz"])); /// let df = DataFrame::from_columns(vec![("id", id), ("name", name)]).unwrap(); @@ -2317,6 +2634,7 @@ macro_rules! 
dataframe { #[derive(Debug)] struct DataFrameTableProvider { plan: LogicalPlan, + table_type: TableType, } #[async_trait] @@ -2325,7 +2643,7 @@ impl TableProvider for DataFrameTableProvider { self } - fn get_logical_plan(&self) -> Option> { + fn get_logical_plan(&self) -> Option> { Some(Cow::Borrowed(&self.plan)) } @@ -2338,12 +2656,11 @@ impl TableProvider for DataFrameTableProvider { } fn schema(&self) -> SchemaRef { - let schema: Schema = self.plan.schema().as_ref().into(); - Arc::new(schema) + Arc::clone(self.plan.schema().inner()) } fn table_type(&self) -> TableType { - TableType::View + self.table_type } async fn scan( diff --git a/datafusion/core/src/dataframe/parquet.rs b/datafusion/core/src/dataframe/parquet.rs index 1bb5444ca009f..e9c49a92843d6 100644 --- a/datafusion/core/src/dataframe/parquet.rs +++ b/datafusion/core/src/dataframe/parquet.rs @@ -42,13 +42,15 @@ impl DataFrame { /// use datafusion::dataframe::DataFrameWriteOptions; /// let ctx = SessionContext::new(); /// // Sort the data by column "b" and write it to a new location - /// ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await? - /// .sort(vec![col("b").sort(true, true)])? // sort by b asc, nulls first - /// .write_parquet( - /// "output.parquet", - /// DataFrameWriteOptions::new(), - /// None, // can also specify parquet writing options here - /// ).await?; + /// ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()) + /// .await? + /// .sort(vec![col("b").sort(true, true)])? 
// sort by b asc, nulls first + /// .write_parquet( + /// "output.parquet", + /// DataFrameWriteOptions::new(), + /// None, // can also specify parquet writing options here + /// ) + /// .await?; /// # fs::remove_file("output.parquet")?; /// # Ok(()) /// # } @@ -74,6 +76,8 @@ impl DataFrame { let file_type = format_as_file_type(format); + let copy_options = options.build_sink_options(); + let plan = if options.sort_by.is_empty() { self.plan } else { @@ -86,7 +90,7 @@ impl DataFrame { plan, path.into(), file_type, - Default::default(), + copy_options, options.partition_by, )? .build()?; @@ -116,11 +120,26 @@ mod tests { use datafusion_execution::config::SessionConfig; use datafusion_expr::{col, lit}; + #[cfg(feature = "parquet_encryption")] + use datafusion_common::config::ConfigFileEncryptionProperties; use object_store::local::LocalFileSystem; use parquet::file::reader::FileReader; use tempfile::TempDir; use url::Url; + /// Helper to extract a metric value by name from aggregated metrics. 
+ fn metric_usize( + aggregated: &datafusion_physical_expr_common::metrics::MetricsSet, + name: &str, + ) -> usize { + aggregated + .iter() + .find(|m| m.value().name() == name) + .unwrap_or_else(|| panic!("should have {name} metric")) + .value() + .as_usize() + } + #[tokio::test] async fn filter_pushdown_dataframe() -> Result<()> { let ctx = SessionContext::new(); @@ -146,7 +165,7 @@ mod tests { let plan = df.explain(false, false)?.collect().await?; // Filters all the way to Parquet let formatted = pretty::pretty_format_batches(&plan)?.to_string(); - assert!(formatted.contains("FilterExec: id@0 = 1")); + assert!(formatted.contains("FilterExec: id@0 = 1"), "{formatted}"); Ok(()) } @@ -205,7 +224,7 @@ mod tests { &HashMap::from_iter( [("datafusion.execution.batch_size", "10")] .iter() - .map(|(s1, s2)| (s1.to_string(), s2.to_string())), + .map(|(s1, s2)| ((*s1).to_string(), (*s2).to_string())), ), )?); register_aggregate_csv(&ctx, "aggregate_test_100").await?; @@ -246,4 +265,350 @@ mod tests { Ok(()) } + + #[rstest::rstest] + #[cfg(feature = "parquet_encryption")] + #[tokio::test] + async fn roundtrip_parquet_with_encryption( + #[values(false, true)] allow_single_file_parallelism: bool, + ) -> Result<()> { + use parquet::encryption::decrypt::FileDecryptionProperties; + use parquet::encryption::encrypt::FileEncryptionProperties; + + let test_df = test_util::test_table().await?; + + let schema = test_df.schema(); + let footer_key = b"0123456789012345".to_vec(); // 128bit/16 + let column_key = b"1234567890123450".to_vec(); // 128bit/16 + + let mut encrypt = FileEncryptionProperties::builder(footer_key.clone()); + let mut decrypt = FileDecryptionProperties::builder(footer_key.clone()); + + for field in schema.fields().iter() { + encrypt = encrypt.with_column_key(field.name().as_str(), column_key.clone()); + decrypt = decrypt.with_column_key(field.name().as_str(), column_key.clone()); + } + + let encrypt = encrypt.build()?; + let decrypt = decrypt.build()?; + + let df = 
test_df.clone(); + let tmp_dir = TempDir::new()?; + let tempfile = tmp_dir.path().join("roundtrip.parquet"); + let tempfile_str = tempfile.into_os_string().into_string().unwrap(); + + // Write encrypted parquet using write_parquet + let mut options = TableParquetOptions::default(); + options.crypto.file_encryption = + Some(ConfigFileEncryptionProperties::from(&encrypt)); + options.global.allow_single_file_parallelism = allow_single_file_parallelism; + + df.write_parquet( + tempfile_str.as_str(), + DataFrameWriteOptions::new().with_single_file_output(true), + Some(options), + ) + .await?; + let num_rows_written = test_df.count().await?; + + // Read encrypted parquet + let ctx: SessionContext = SessionContext::new(); + let read_options = + ParquetReadOptions::default().file_decryption_properties((&decrypt).into()); + + ctx.register_parquet("roundtrip_parquet", &tempfile_str, read_options.clone()) + .await?; + + let df_enc = ctx.sql("SELECT * FROM roundtrip_parquet").await?; + let num_rows_read = df_enc.count().await?; + + assert_eq!(num_rows_read, num_rows_written); + + // Read encrypted parquet and subset rows + columns + let encrypted_parquet_df = ctx.read_parquet(tempfile_str, read_options).await?; + + // Select three columns and filter the results + // Test that the filter works as expected + let selected = encrypted_parquet_df + .clone() + .select_columns(&["c1", "c2", "c3"])? + .filter(col("c2").gt(lit(4)))?; + + let num_rows_selected = selected.count().await?; + assert_eq!(num_rows_selected, 14); + + Ok(()) + } + + /// Test FileOutputMode::SingleFile - explicitly request single file output + /// for paths WITHOUT file extensions. This verifies the fix for the regression + /// where extension heuristics ignored the explicit with_single_file_output(true). 
+ #[tokio::test] + async fn test_file_output_mode_single_file() -> Result<()> { + use arrow::array::Int32Array; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow::record_batch::RecordBatch; + + let ctx = SessionContext::new(); + let tmp_dir = TempDir::new()?; + + // Path WITHOUT .parquet extension - this is the key scenario + let output_path = tmp_dir.path().join("data_no_ext"); + let output_path_str = output_path.to_str().unwrap(); + + let df = ctx.read_batch(RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)])), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + )?)?; + + // Explicitly request single file output + df.write_parquet( + output_path_str, + DataFrameWriteOptions::new().with_single_file_output(true), + None, + ) + .await?; + + // Verify: output should be a FILE, not a directory + assert!( + output_path.is_file(), + "Expected single file at {:?}, but got is_file={}, is_dir={}", + output_path, + output_path.is_file(), + output_path.is_dir() + ); + + // Verify the file is readable as parquet + let file = std::fs::File::open(&output_path)?; + let reader = parquet::file::reader::SerializedFileReader::new(file)?; + let metadata = reader.metadata(); + assert_eq!(metadata.num_row_groups(), 1); + assert_eq!(metadata.file_metadata().num_rows(), 3); + + Ok(()) + } + + /// Test FileOutputMode::Automatic - uses extension heuristic. + /// Path WITH extension -> single file; path WITHOUT extension -> directory. 
+ #[tokio::test] + async fn test_file_output_mode_automatic() -> Result<()> { + use arrow::array::Int32Array; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow::record_batch::RecordBatch; + + let ctx = SessionContext::new(); + let tmp_dir = TempDir::new()?; + + let schema = + Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)])); + let batch = RecordBatch::try_new( + schema, + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + )?; + + // Case 1: Path WITH extension -> should create single file (Automatic mode) + let output_with_ext = tmp_dir.path().join("data.parquet"); + let df = ctx.read_batch(batch.clone())?; + df.write_parquet( + output_with_ext.to_str().unwrap(), + DataFrameWriteOptions::new(), // Automatic mode (default) + None, + ) + .await?; + + assert!( + output_with_ext.is_file(), + "Path with extension should be a single file, got is_file={}, is_dir={}", + output_with_ext.is_file(), + output_with_ext.is_dir() + ); + + // Case 2: Path WITHOUT extension -> should create directory (Automatic mode) + let output_no_ext = tmp_dir.path().join("data_dir"); + let df = ctx.read_batch(batch)?; + df.write_parquet( + output_no_ext.to_str().unwrap(), + DataFrameWriteOptions::new(), // Automatic mode (default) + None, + ) + .await?; + + assert!( + output_no_ext.is_dir(), + "Path without extension should be a directory, got is_file={}, is_dir={}", + output_no_ext.is_file(), + output_no_ext.is_dir() + ); + + Ok(()) + } + + /// Test that ParquetSink exposes rows_written, bytes_written, and + /// elapsed_compute metrics via DataSinkExec. 
+ #[tokio::test] + async fn test_parquet_sink_metrics() -> Result<()> { + use arrow::array::Int32Array; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow::record_batch::RecordBatch; + use datafusion_execution::TaskContext; + + use futures::TryStreamExt; + + let ctx = SessionContext::new(); + let tmp_dir = TempDir::new()?; + let output_path = tmp_dir.path().join("metrics_test.parquet"); + let output_path_str = output_path.to_str().unwrap(); + + // Register a table with 100 rows + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("val", DataType::Int32, false), + ])); + let ids: Vec = (0..100).collect(); + let vals: Vec = (100..200).collect(); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(ids)), + Arc::new(Int32Array::from(vals)), + ], + )?; + ctx.register_batch("source", batch)?; + + // Create the physical plan for COPY TO + let df = ctx + .sql(&format!( + "COPY source TO '{output_path_str}' STORED AS PARQUET" + )) + .await?; + let plan = df.create_physical_plan().await?; + + // Execute the plan + let task_ctx = Arc::new(TaskContext::from(&ctx.state())); + let stream = plan.execute(0, task_ctx)?; + let _batches: Vec<_> = stream.try_collect().await?; + + // Check metrics on the DataSinkExec (top-level plan) + let metrics = plan + .metrics() + .expect("DataSinkExec should return metrics from ParquetSink"); + let aggregated = metrics.aggregate_by_name(); + + // rows_written should be 100 + assert_eq!( + metric_usize(&aggregated, "rows_written"), + 100, + "expected 100 rows written" + ); + + // bytes_written should be > 0 + let bytes_written = metric_usize(&aggregated, "bytes_written"); + assert!( + bytes_written > 0, + "expected bytes_written > 0, got {bytes_written}" + ); + + // elapsed_compute should be > 0 + let elapsed = metric_usize(&aggregated, "elapsed_compute"); + assert!(elapsed > 0, "expected elapsed_compute > 0"); + + Ok(()) + } + + /// Test that 
ParquetSink metrics work with single_file_parallelism enabled. + #[tokio::test] + async fn test_parquet_sink_metrics_parallel() -> Result<()> { + use arrow::array::Int32Array; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow::record_batch::RecordBatch; + use datafusion_execution::TaskContext; + + use futures::TryStreamExt; + + let ctx = SessionContext::new(); + ctx.sql("SET datafusion.execution.parquet.allow_single_file_parallelism = true") + .await? + .collect() + .await?; + + let tmp_dir = TempDir::new()?; + let output_path = tmp_dir.path().join("metrics_parallel.parquet"); + let output_path_str = output_path.to_str().unwrap(); + + let schema = + Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)])); + let ids: Vec = (0..50).collect(); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(Int32Array::from(ids))], + )?; + ctx.register_batch("source2", batch)?; + + let df = ctx + .sql(&format!( + "COPY source2 TO '{output_path_str}' STORED AS PARQUET" + )) + .await?; + let plan = df.create_physical_plan().await?; + let task_ctx = Arc::new(TaskContext::from(&ctx.state())); + let stream = plan.execute(0, task_ctx)?; + let _batches: Vec<_> = stream.try_collect().await?; + + let metrics = plan.metrics().expect("DataSinkExec should return metrics"); + let aggregated = metrics.aggregate_by_name(); + + assert_eq!(metric_usize(&aggregated, "rows_written"), 50); + assert!(metric_usize(&aggregated, "bytes_written") > 0); + + Ok(()) + } + + /// Test FileOutputMode::Directory - explicitly request directory output + /// even for paths WITH file extensions. 
+ #[tokio::test] + async fn test_file_output_mode_directory() -> Result<()> { + use arrow::array::Int32Array; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow::record_batch::RecordBatch; + + let ctx = SessionContext::new(); + let tmp_dir = TempDir::new()?; + + // Path WITH .parquet extension but explicitly requesting directory output + let output_path = tmp_dir.path().join("output.parquet"); + let output_path_str = output_path.to_str().unwrap(); + + let df = ctx.read_batch(RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)])), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + )?)?; + + // Explicitly request directory output (single_file_output = false) + df.write_parquet( + output_path_str, + DataFrameWriteOptions::new().with_single_file_output(false), + None, + ) + .await?; + + // Verify: output should be a DIRECTORY, not a single file + assert!( + output_path.is_dir(), + "Expected directory at {:?}, but got is_file={}, is_dir={}", + output_path, + output_path.is_file(), + output_path.is_dir() + ); + + // Verify the directory contains parquet file(s) + let entries: Vec<_> = std::fs::read_dir(&output_path)? 
+ .filter_map(|e| e.ok()) + .collect(); + assert!( + !entries.is_empty(), + "Directory should contain at least one file" + ); + + Ok(()) + } } diff --git a/datafusion/core/src/datasource/dynamic_file.rs b/datafusion/core/src/datasource/dynamic_file.rs index b30d53e586911..50ee96da3dff0 100644 --- a/datafusion/core/src/datasource/dynamic_file.rs +++ b/datafusion/core/src/datasource/dynamic_file.rs @@ -20,8 +20,9 @@ use std::sync::Arc; -use crate::datasource::listing::{ListingTable, ListingTableConfig, ListingTableUrl}; use crate::datasource::TableProvider; +use crate::datasource::listing::ListingTableConfigExt; +use crate::datasource::listing::{ListingTable, ListingTableConfig, ListingTableUrl}; use crate::error::Result; use crate::execution::context::SessionState; diff --git a/datafusion/core/src/datasource/empty.rs b/datafusion/core/src/datasource/empty.rs index 77686c5eb7c27..5aeca92b1626d 100644 --- a/datafusion/core/src/datasource/empty.rs +++ b/datafusion/core/src/datasource/empty.rs @@ -28,8 +28,8 @@ use datafusion_common::project_schema; use crate::datasource::{TableProvider, TableType}; use crate::error::Result; use crate::logical_expr::Expr; -use datafusion_physical_plan::empty::EmptyExec; use datafusion_physical_plan::ExecutionPlan; +use datafusion_physical_plan::empty::EmptyExec; /// An empty plan that is useful for testing and generating plans /// without mapping them to actual data. diff --git a/datafusion/core/src/datasource/file_format/arrow.rs b/datafusion/core/src/datasource/file_format/arrow.rs index b620ff62d9a65..338de76b1353b 100644 --- a/datafusion/core/src/datasource/file_format/arrow.rs +++ b/datafusion/core/src/datasource/file_format/arrow.rs @@ -15,504 +15,97 @@ // specific language governing permissions and limitations // under the License. -//! [`ArrowFormat`]: Apache Arrow [`FileFormat`] abstractions -//! -//! Works with files following the [Arrow IPC format](https://arrow.apache.org/docs/format/Columnar.html#ipc-file-format) +//! 
Re-exports the [`datafusion_datasource_arrow::file_format`] module, and contains tests for it. +pub use datafusion_datasource_arrow::file_format::*; -use std::any::Any; -use std::borrow::Cow; -use std::collections::HashMap; -use std::fmt::{self, Debug}; -use std::sync::Arc; - -use super::file_compression_type::FileCompressionType; -use super::write::demux::DemuxedStreamReceiver; -use super::write::SharedBuffer; -use super::FileFormatFactory; -use crate::datasource::file_format::write::get_writer_schema; -use crate::datasource::file_format::FileFormat; -use crate::datasource::physical_plan::{ArrowSource, FileSink, FileSinkConfig}; -use crate::error::Result; -use crate::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan}; - -use arrow::datatypes::{Schema, SchemaRef}; -use arrow::error::ArrowError; -use arrow::ipc::convert::fb_to_schema; -use arrow::ipc::reader::FileReader; -use arrow::ipc::writer::IpcWriteOptions; -use arrow::ipc::{root_as_message, CompressionType}; -use datafusion_catalog::Session; -use datafusion_common::parsers::CompressionTypeVariant; -use datafusion_common::{ - not_impl_err, DataFusionError, GetExt, Statistics, DEFAULT_ARROW_EXTENSION, -}; -use datafusion_common_runtime::{JoinSet, SpawnedTask}; -use datafusion_datasource::display::FileGroupDisplay; -use datafusion_datasource::file::FileSource; -use datafusion_datasource::file_scan_config::{FileScanConfig, FileScanConfigBuilder}; -use datafusion_datasource::sink::{DataSink, DataSinkExec}; -use datafusion_datasource::write::ObjectWriterBuilder; -use datafusion_execution::{SendableRecordBatchStream, TaskContext}; -use datafusion_expr::dml::InsertOp; -use datafusion_physical_expr_common::sort_expr::LexRequirement; - -use async_trait::async_trait; -use bytes::Bytes; -use datafusion_datasource::source::DataSourceExec; -use futures::stream::BoxStream; -use futures::StreamExt; -use object_store::{GetResultPayload, ObjectMeta, ObjectStore}; -use tokio::io::AsyncWriteExt; - -/// Initial writing 
buffer size. Note this is just a size hint for efficiency. It -/// will grow beyond the set value if needed. -const INITIAL_BUFFER_BYTES: usize = 1048576; - -/// If the buffered Arrow data exceeds this size, it is flushed to object store -const BUFFER_FLUSH_BYTES: usize = 1024000; - -#[derive(Default, Debug)] -/// Factory struct used to create [ArrowFormat] -pub struct ArrowFormatFactory; - -impl ArrowFormatFactory { - /// Creates an instance of [ArrowFormatFactory] - pub fn new() -> Self { - Self {} - } -} - -impl FileFormatFactory for ArrowFormatFactory { - fn create( - &self, - _state: &dyn Session, - _format_options: &HashMap, - ) -> Result> { - Ok(Arc::new(ArrowFormat)) - } - - fn default(&self) -> Arc { - Arc::new(ArrowFormat) - } - - fn as_any(&self) -> &dyn Any { - self - } -} - -impl GetExt for ArrowFormatFactory { - fn get_ext(&self) -> String { - // Removes the dot, i.e. ".parquet" -> "parquet" - DEFAULT_ARROW_EXTENSION[1..].to_string() - } -} - -/// Arrow `FileFormat` implementation. 
-#[derive(Default, Debug)] -pub struct ArrowFormat; - -#[async_trait] -impl FileFormat for ArrowFormat { - fn as_any(&self) -> &dyn Any { - self - } - - fn get_ext(&self) -> String { - ArrowFormatFactory::new().get_ext() - } - - fn get_ext_with_compression( - &self, - file_compression_type: &FileCompressionType, - ) -> Result { - let ext = self.get_ext(); - match file_compression_type.get_variant() { - CompressionTypeVariant::UNCOMPRESSED => Ok(ext), - _ => Err(DataFusionError::Internal( - "Arrow FileFormat does not support compression.".into(), - )), - } - } - - async fn infer_schema( - &self, - _state: &dyn Session, - store: &Arc, - objects: &[ObjectMeta], - ) -> Result { - let mut schemas = vec![]; - for object in objects { - let r = store.as_ref().get(&object.location).await?; - let schema = match r.payload { - #[cfg(not(target_arch = "wasm32"))] - GetResultPayload::File(mut file, _) => { - let reader = FileReader::try_new(&mut file, None)?; - reader.schema() - } - GetResultPayload::Stream(stream) => { - infer_schema_from_file_stream(stream).await? 
- } - }; - schemas.push(schema.as_ref().clone()); - } - let merged_schema = Schema::try_merge(schemas)?; - Ok(Arc::new(merged_schema)) - } - - async fn infer_stats( - &self, - _state: &dyn Session, - _store: &Arc, - table_schema: SchemaRef, - _object: &ObjectMeta, - ) -> Result { - Ok(Statistics::new_unknown(&table_schema)) - } - - async fn create_physical_plan( - &self, - _state: &dyn Session, - conf: FileScanConfig, - ) -> Result> { - let source = Arc::new(ArrowSource::default()); - let config = FileScanConfigBuilder::from(conf) - .with_source(source) - .build(); - - Ok(DataSourceExec::from_data_source(config)) - } - - async fn create_writer_physical_plan( - &self, - input: Arc, - _state: &dyn Session, - conf: FileSinkConfig, - order_requirements: Option, - ) -> Result> { - if conf.insert_op != InsertOp::Append { - return not_impl_err!("Overwrites are not implemented yet for Arrow format"); - } - - let sink = Arc::new(ArrowFileSink::new(conf)); - - Ok(Arc::new(DataSinkExec::new(input, sink, order_requirements)) as _) - } - - fn file_source(&self) -> Arc { - Arc::new(ArrowSource::default()) - } -} - -/// Implements [`FileSink`] for writing to arrow_ipc files -struct ArrowFileSink { - config: FileSinkConfig, -} - -impl ArrowFileSink { - fn new(config: FileSinkConfig) -> Self { - Self { config } - } -} - -#[async_trait] -impl FileSink for ArrowFileSink { - fn config(&self) -> &FileSinkConfig { - &self.config - } - - async fn spawn_writer_tasks_and_join( - &self, - context: &Arc, - demux_task: SpawnedTask>, - mut file_stream_rx: DemuxedStreamReceiver, - object_store: Arc, - ) -> Result { - let mut file_write_tasks: JoinSet> = - JoinSet::new(); - - let ipc_options = - IpcWriteOptions::try_new(64, false, arrow_ipc::MetadataVersion::V5)? 
- .try_with_compression(Some(CompressionType::LZ4_FRAME))?; - while let Some((path, mut rx)) = file_stream_rx.recv().await { - let shared_buffer = SharedBuffer::new(INITIAL_BUFFER_BYTES); - let mut arrow_writer = arrow_ipc::writer::FileWriter::try_new_with_options( - shared_buffer.clone(), - &get_writer_schema(&self.config), - ipc_options.clone(), - )?; - let mut object_store_writer = ObjectWriterBuilder::new( - FileCompressionType::UNCOMPRESSED, - &path, - Arc::clone(&object_store), - ) - .with_buffer_size(Some( - context - .session_config() - .options() - .execution - .objectstore_writer_buffer_size, - )) - .build()?; - file_write_tasks.spawn(async move { - let mut row_count = 0; - while let Some(batch) = rx.recv().await { - row_count += batch.num_rows(); - arrow_writer.write(&batch)?; - let mut buff_to_flush = shared_buffer.buffer.try_lock().unwrap(); - if buff_to_flush.len() > BUFFER_FLUSH_BYTES { - object_store_writer - .write_all(buff_to_flush.as_slice()) - .await?; - buff_to_flush.clear(); - } - } - arrow_writer.finish()?; - let final_buff = shared_buffer.buffer.try_lock().unwrap(); - - object_store_writer.write_all(final_buff.as_slice()).await?; - object_store_writer.shutdown().await?; - Ok(row_count) - }); - } - - let mut row_count = 0; - while let Some(result) = file_write_tasks.join_next().await { - match result { - Ok(r) => { - row_count += r?; - } - Err(e) => { - if e.is_panic() { - std::panic::resume_unwind(e.into_panic()); - } else { - unreachable!(); - } - } - } - } - - demux_task - .join_unwind() - .await - .map_err(DataFusionError::ExecutionJoin)??; - Ok(row_count as u64) - } -} - -impl Debug for ArrowFileSink { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("ArrowFileSink").finish() - } -} - -impl DisplayAs for ArrowFileSink { - fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match t { - DisplayFormatType::Default | DisplayFormatType::Verbose => { - write!(f, 
"ArrowFileSink(file_groups=",)?; - FileGroupDisplay(&self.config.file_group).fmt_as(t, f)?; - write!(f, ")") - } - DisplayFormatType::TreeRender => { - writeln!(f, "format: arrow")?; - write!(f, "file={}", &self.config.original_url) - } - } - } -} - -#[async_trait] -impl DataSink for ArrowFileSink { - fn as_any(&self) -> &dyn Any { - self - } - - fn schema(&self) -> &SchemaRef { - self.config.output_schema() - } - - async fn write_all( - &self, - data: SendableRecordBatchStream, - context: &Arc, - ) -> Result { - FileSink::write_all(self, data, context).await - } -} - -const ARROW_MAGIC: [u8; 6] = [b'A', b'R', b'R', b'O', b'W', b'1']; -const CONTINUATION_MARKER: [u8; 4] = [0xff; 4]; - -/// Custom implementation of inferring schema. Should eventually be moved upstream to arrow-rs. -/// See -async fn infer_schema_from_file_stream( - mut stream: BoxStream<'static, object_store::Result>, -) -> Result { - // Expected format: - // - 6 bytes - // - 2 bytes - // - 4 bytes, not present below v0.15.0 - // - 4 bytes - // - // - - // So in first read we need at least all known sized sections, - // which is 6 + 2 + 4 + 4 = 16 bytes. 
- let bytes = collect_at_least_n_bytes(&mut stream, 16, None).await?; - - // Files should start with these magic bytes - if bytes[0..6] != ARROW_MAGIC { - return Err(ArrowError::ParseError( - "Arrow file does not contain correct header".to_string(), - ))?; - } - - // Since continuation marker bytes added in later versions - let (meta_len, rest_of_bytes_start_index) = if bytes[8..12] == CONTINUATION_MARKER { - (&bytes[12..16], 16) - } else { - (&bytes[8..12], 12) - }; - - let meta_len = [meta_len[0], meta_len[1], meta_len[2], meta_len[3]]; - let meta_len = i32::from_le_bytes(meta_len); - - // Read bytes for Schema message - let block_data = if bytes[rest_of_bytes_start_index..].len() < meta_len as usize { - // Need to read more bytes to decode Message - let mut block_data = Vec::with_capacity(meta_len as usize); - // In case we had some spare bytes in our initial read chunk - block_data.extend_from_slice(&bytes[rest_of_bytes_start_index..]); - let size_to_read = meta_len as usize - block_data.len(); - let block_data = - collect_at_least_n_bytes(&mut stream, size_to_read, Some(block_data)).await?; - Cow::Owned(block_data) - } else { - // Already have the bytes we need - let end_index = meta_len as usize + rest_of_bytes_start_index; - let block_data = &bytes[rest_of_bytes_start_index..end_index]; - Cow::Borrowed(block_data) - }; +#[cfg(test)] +mod tests { + use futures::StreamExt; + use std::sync::Arc; - // Decode Schema message - let message = root_as_message(&block_data).map_err(|err| { - ArrowError::ParseError(format!("Unable to read IPC message as metadata: {err:?}")) - })?; - let ipc_schema = message.header_as_schema().ok_or_else(|| { - ArrowError::IpcError("Unable to read IPC message as schema".to_string()) - })?; - let schema = fb_to_schema(ipc_schema); + use arrow::array::{Int64Array, StringArray}; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow::record_batch::RecordBatch; + use datafusion_common::Result; - Ok(Arc::new(schema)) -} + use 
crate::execution::options::ArrowReadOptions; + use crate::prelude::SessionContext; -async fn collect_at_least_n_bytes( - stream: &mut BoxStream<'static, object_store::Result>, - n: usize, - extend_from: Option>, -) -> Result> { - let mut buf = extend_from.unwrap_or_else(|| Vec::with_capacity(n)); - // If extending existing buffer then ensure we read n additional bytes - let n = n + buf.len(); - while let Some(bytes) = stream.next().await.transpose()? { - buf.extend_from_slice(&bytes); - if buf.len() >= n { - break; - } - } - if buf.len() < n { - return Err(ArrowError::ParseError( - "Unexpected end of byte stream for Arrow IPC file".to_string(), - ))?; - } - Ok(buf) -} + #[tokio::test] + async fn test_write_empty_arrow_from_sql() -> Result<()> { + let ctx = SessionContext::new(); -#[cfg(test)] -mod tests { - use super::*; - use crate::execution::context::SessionContext; + let tmp_dir = tempfile::TempDir::new()?; + let path = format!("{}/empty_sql.arrow", tmp_dir.path().to_string_lossy()); - use chrono::DateTime; - use object_store::{chunked::ChunkedStore, memory::InMemory, path::Path}; + ctx.sql(&format!( + "COPY (SELECT CAST(1 AS BIGINT) AS id LIMIT 0) TO '{path}' STORED AS ARROW", + )) + .await? 
+ .collect() + .await?; - #[tokio::test] - async fn test_infer_schema_stream() -> Result<()> { - let mut bytes = std::fs::read("tests/data/example.arrow")?; - bytes.truncate(bytes.len() - 20); // mangle end to show we don't need to read whole file - let location = Path::parse("example.arrow")?; - let in_memory_store: Arc = Arc::new(InMemory::new()); - in_memory_store.put(&location, bytes.into()).await?; + assert!(std::path::Path::new(&path).exists()); - let session_ctx = SessionContext::new(); - let state = session_ctx.state(); - let object_meta = ObjectMeta { - location, - last_modified: DateTime::default(), - size: u64::MAX, - e_tag: None, - version: None, - }; + let read_df = ctx.read_arrow(&path, ArrowReadOptions::default()).await?; + let stream = read_df.execute_stream().await?; - let arrow_format = ArrowFormat {}; - let expected = vec!["f0: Int64", "f1: Utf8", "f2: Boolean"]; + assert_eq!(stream.schema().fields().len(), 1); + assert_eq!(stream.schema().field(0).name(), "id"); - // Test chunk sizes where too small so we keep having to read more bytes - // And when large enough that first read contains all we need - for chunk_size in [7, 3000] { - let store = Arc::new(ChunkedStore::new(in_memory_store.clone(), chunk_size)); - let inferred_schema = arrow_format - .infer_schema( - &state, - &(store.clone() as Arc), - std::slice::from_ref(&object_meta), - ) - .await?; - let actual_fields = inferred_schema - .fields() - .iter() - .map(|f| format!("{}: {:?}", f.name(), f.data_type())) - .collect::>(); - assert_eq!(expected, actual_fields); - } + let results: Vec<_> = stream.collect().await; + let total_rows: usize = results + .iter() + .filter_map(|r| r.as_ref().ok()) + .map(|b| b.num_rows()) + .sum(); + assert_eq!(total_rows, 0); Ok(()) } #[tokio::test] - async fn test_infer_schema_short_stream() -> Result<()> { - let mut bytes = std::fs::read("tests/data/example.arrow")?; - bytes.truncate(20); // should cause error that file shorter than expected - let location = 
Path::parse("example.arrow")?; - let in_memory_store: Arc = Arc::new(InMemory::new()); - in_memory_store.put(&location, bytes.into()).await?; - - let session_ctx = SessionContext::new(); - let state = session_ctx.state(); - let object_meta = ObjectMeta { - location, - last_modified: DateTime::default(), - size: u64::MAX, - e_tag: None, - version: None, - }; - - let arrow_format = ArrowFormat {}; - - let store = Arc::new(ChunkedStore::new(in_memory_store.clone(), 7)); - let err = arrow_format - .infer_schema( - &state, - &(store.clone() as Arc), - std::slice::from_ref(&object_meta), - ) - .await; - - assert!(err.is_err()); - assert_eq!( - "Arrow error: Parser error: Unexpected end of byte stream for Arrow IPC file", - err.unwrap_err().to_string() - ); + async fn test_write_empty_arrow_from_record_batch() -> Result<()> { + let ctx = SessionContext::new(); + + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("name", DataType::Utf8, true), + ])); + let empty_batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int64Array::from(Vec::::new())), + Arc::new(StringArray::from(Vec::>::new())), + ], + )?; + + let tmp_dir = tempfile::TempDir::new()?; + let path = format!("{}/empty_batch.arrow", tmp_dir.path().to_string_lossy()); + + ctx.register_batch("empty_table", empty_batch)?; + + ctx.sql(&format!("COPY empty_table TO '{path}' STORED AS ARROW")) + .await? 
+ .collect() + .await?; + + assert!(std::path::Path::new(&path).exists()); + + let read_df = ctx.read_arrow(&path, ArrowReadOptions::default()).await?; + let stream = read_df.execute_stream().await?; + + assert_eq!(stream.schema().fields().len(), 2); + assert_eq!(stream.schema().field(0).name(), "id"); + assert_eq!(stream.schema().field(1).name(), "name"); + + let results: Vec<_> = stream.collect().await; + let total_rows: usize = results + .iter() + .filter_map(|r| r.as_ref().ok()) + .map(|b| b.num_rows()) + .sum(); + assert_eq!(total_rows, 0); Ok(()) } diff --git a/datafusion/core/src/datasource/file_format/avro.rs b/datafusion/core/src/datasource/file_format/avro.rs index 3428d08a6ae52..7cf23ee294d86 100644 --- a/datafusion/core/src/datasource/file_format/avro.rs +++ b/datafusion/core/src/datasource/file_format/avro.rs @@ -26,20 +26,21 @@ mod tests { use crate::{ datasource::file_format::test_util::scan_format, prelude::SessionContext, }; - use arrow::array::{as_string_array, Array}; + use arrow::array::{Array, as_string_array}; use datafusion_catalog::Session; use datafusion_common::test_util::batches_to_string; use datafusion_common::{ + Result, cast::{ as_binary_array, as_boolean_array, as_float32_array, as_float64_array, as_int32_array, as_timestamp_microsecond_array, }, - test_util, Result, + test_util, }; use datafusion_datasource_avro::AvroFormat; use datafusion_execution::config::SessionConfig; - use datafusion_physical_plan::{collect, ExecutionPlan}; + use datafusion_physical_plan::{ExecutionPlan, collect}; use futures::StreamExt; use insta::assert_snapshot; @@ -94,7 +95,7 @@ mod tests { .schema() .fields() .iter() - .map(|f| format!("{}: {:?}", f.name(), f.data_type())) + .map(|f| format!("{}: {}", f.name(), f.data_type())) .collect(); assert_eq!( vec![ @@ -108,7 +109,7 @@ mod tests { "double_col: Float64", "date_string_col: Binary", "string_col: Binary", - "timestamp_col: Timestamp(Microsecond, None)", + "timestamp_col: Timestamp(µs)", ], x ); @@ 
-116,20 +117,20 @@ mod tests { let batches = collect(exec, task_ctx).await?; assert_eq!(batches.len(), 1); - assert_snapshot!(batches_to_string(&batches),@r###" - +----+----------+-------------+--------------+---------+------------+-----------+------------+------------------+------------+---------------------+ - | id | bool_col | tinyint_col | smallint_col | int_col | bigint_col | float_col | double_col | date_string_col | string_col | timestamp_col | - +----+----------+-------------+--------------+---------+------------+-----------+------------+------------------+------------+---------------------+ - | 4 | true | 0 | 0 | 0 | 0 | 0.0 | 0.0 | 30332f30312f3039 | 30 | 2009-03-01T00:00:00 | - | 5 | false | 1 | 1 | 1 | 10 | 1.1 | 10.1 | 30332f30312f3039 | 31 | 2009-03-01T00:01:00 | - | 6 | true | 0 | 0 | 0 | 0 | 0.0 | 0.0 | 30342f30312f3039 | 30 | 2009-04-01T00:00:00 | - | 7 | false | 1 | 1 | 1 | 10 | 1.1 | 10.1 | 30342f30312f3039 | 31 | 2009-04-01T00:01:00 | - | 2 | true | 0 | 0 | 0 | 0 | 0.0 | 0.0 | 30322f30312f3039 | 30 | 2009-02-01T00:00:00 | - | 3 | false | 1 | 1 | 1 | 10 | 1.1 | 10.1 | 30322f30312f3039 | 31 | 2009-02-01T00:01:00 | - | 0 | true | 0 | 0 | 0 | 0 | 0.0 | 0.0 | 30312f30312f3039 | 30 | 2009-01-01T00:00:00 | - | 1 | false | 1 | 1 | 1 | 10 | 1.1 | 10.1 | 30312f30312f3039 | 31 | 2009-01-01T00:01:00 | - +----+----------+-------------+--------------+---------+------------+-----------+------------+------------------+------------+---------------------+ - "###); + assert_snapshot!(batches_to_string(&batches),@r" + +----+----------+-------------+--------------+---------+------------+-----------+------------+------------------+------------+---------------------+ + | id | bool_col | tinyint_col | smallint_col | int_col | bigint_col | float_col | double_col | date_string_col | string_col | timestamp_col | + +----+----------+-------------+--------------+---------+------------+-----------+------------+------------------+------------+---------------------+ + | 4 | 
true | 0 | 0 | 0 | 0 | 0.0 | 0.0 | 30332f30312f3039 | 30 | 2009-03-01T00:00:00 | + | 5 | false | 1 | 1 | 1 | 10 | 1.1 | 10.1 | 30332f30312f3039 | 31 | 2009-03-01T00:01:00 | + | 6 | true | 0 | 0 | 0 | 0 | 0.0 | 0.0 | 30342f30312f3039 | 30 | 2009-04-01T00:00:00 | + | 7 | false | 1 | 1 | 1 | 10 | 1.1 | 10.1 | 30342f30312f3039 | 31 | 2009-04-01T00:01:00 | + | 2 | true | 0 | 0 | 0 | 0 | 0.0 | 0.0 | 30322f30312f3039 | 30 | 2009-02-01T00:00:00 | + | 3 | false | 1 | 1 | 1 | 10 | 1.1 | 10.1 | 30322f30312f3039 | 31 | 2009-02-01T00:01:00 | + | 0 | true | 0 | 0 | 0 | 0 | 0.0 | 0.0 | 30312f30312f3039 | 30 | 2009-01-01T00:00:00 | + | 1 | false | 1 | 1 | 1 | 10 | 1.1 | 10.1 | 30312f30312f3039 | 31 | 2009-01-01T00:01:00 | + +----+----------+-------------+--------------+---------+------------+-----------+------------+------------------+------------+---------------------+ + "); Ok(()) } @@ -245,7 +246,10 @@ mod tests { values.push(array.value(i)); } - assert_eq!("[1235865600000000, 1235865660000000, 1238544000000000, 1238544060000000, 1233446400000000, 1233446460000000, 1230768000000000, 1230768060000000]", format!("{values:?}")); + assert_eq!( + "[1235865600000000, 1235865660000000, 1238544000000000, 1238544060000000, 1233446400000000, 1233446460000000, 1230768000000000, 1230768060000000]", + format!("{values:?}") + ); Ok(()) } diff --git a/datafusion/core/src/datasource/file_format/csv.rs b/datafusion/core/src/datasource/file_format/csv.rs index efec07abbca05..a068b4f5c0413 100644 --- a/datafusion/core/src/datasource/file_format/csv.rs +++ b/datafusion/core/src/datasource/file_format/csv.rs @@ -32,11 +32,12 @@ mod tests { use crate::prelude::{CsvReadOptions, SessionConfig, SessionContext}; use arrow_schema::{DataType, Field, Schema, SchemaRef}; use datafusion_catalog::Session; + use datafusion_common::Result; use datafusion_common::cast::as_string_array; + use datafusion_common::config::CsvOptions; use datafusion_common::internal_err; use datafusion_common::stats::Precision; use 
datafusion_common::test_util::{arrow_test_data, batches_to_string}; - use datafusion_common::Result; use datafusion_datasource::decoder::{ BatchDeserializer, DecoderDeserializer, DeserializerOutput, }; @@ -44,10 +45,10 @@ mod tests { use datafusion_datasource::file_format::FileFormat; use datafusion_datasource::write::BatchSerializer; use datafusion_expr::{col, lit}; - use datafusion_physical_plan::{collect, ExecutionPlan}; + use datafusion_physical_plan::{ExecutionPlan, collect}; use arrow::array::{ - BooleanArray, Float64Array, Int32Array, RecordBatch, StringArray, + Array, BooleanArray, Float64Array, Int32Array, RecordBatch, StringArray, }; use arrow::compute::concat_batches; use arrow::csv::ReaderBuilder; @@ -55,14 +56,17 @@ mod tests { use async_trait::async_trait; use bytes::Bytes; use chrono::DateTime; - use futures::stream::BoxStream; + use datafusion_common::parsers::CompressionTypeVariant; use futures::StreamExt; + use futures::stream::BoxStream; use insta::assert_snapshot; + use object_store::chunked::ChunkedStore; use object_store::local::LocalFileSystem; use object_store::path::Path; use object_store::{ Attributes, GetOptions, GetResult, GetResultPayload, ListResult, MultipartUpload, - ObjectMeta, ObjectStore, PutMultipartOpts, PutOptions, PutPayload, PutResult, + ObjectMeta, ObjectStore, ObjectStoreExt, PutMultipartOptions, PutOptions, + PutPayload, PutResult, }; use regex::Regex; use rstest::*; @@ -96,16 +100,22 @@ mod tests { async fn put_multipart_opts( &self, _location: &Path, - _opts: PutMultipartOpts, + _opts: PutMultipartOptions, ) -> object_store::Result> { unimplemented!() } - async fn get(&self, location: &Path) -> object_store::Result { + async fn get_opts( + &self, + location: &Path, + _opts: GetOptions, + ) -> object_store::Result { let bytes = self.bytes_to_repeat.clone(); let len = bytes.len() as u64; let range = 0..len * self.max_iterations; let arc = self.iterations_detected.clone(); + #[expect(clippy::result_large_err)] + // closure 
only ever returns Ok; Err type is never constructed let stream = futures::stream::repeat_with(move || { let arc_inner = arc.clone(); *arc_inner.lock().unwrap() += 1; @@ -128,14 +138,6 @@ mod tests { }) } - async fn get_opts( - &self, - _location: &Path, - _opts: GetOptions, - ) -> object_store::Result { - unimplemented!() - } - async fn get_ranges( &self, _location: &Path, @@ -144,14 +146,6 @@ mod tests { unimplemented!() } - async fn head(&self, _location: &Path) -> object_store::Result { - unimplemented!() - } - - async fn delete(&self, _location: &Path) -> object_store::Result<()> { - unimplemented!() - } - fn list( &self, _prefix: Option<&Path>, @@ -166,17 +160,21 @@ mod tests { unimplemented!() } - async fn copy(&self, _from: &Path, _to: &Path) -> object_store::Result<()> { - unimplemented!() - } - - async fn copy_if_not_exists( + async fn copy_opts( &self, _from: &Path, _to: &Path, + _options: object_store::CopyOptions, ) -> object_store::Result<()> { unimplemented!() } + + fn delete_stream( + &self, + _locations: BoxStream<'static, object_store::Result>, + ) -> BoxStream<'static, object_store::Result> { + unimplemented!() + } } impl VariableStream { @@ -468,6 +466,59 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_infer_schema_stream_null_chunks() -> Result<()> { + let session_ctx = SessionContext::new(); + let state = session_ctx.state(); + + // a stream where each line is read as a separate chunk, + // data type for each chunk is inferred separately. 
+ // +----+-----+----+ + // | c1 | c2 | c3 | + // +----+-----+----+ + // | 1 | 1.0 | | type: Int64, Float64, Null + // | | | | type: Null, Null, Null + // +----+-----+----+ + let chunked_object_store = Arc::new(ChunkedStore::new( + Arc::new(VariableStream::new( + Bytes::from( + r#"c1,c2,c3 +1,1.0, +,, +"#, + ), + 1, + )), + 1, + )); + let object_meta = ObjectMeta { + location: Path::parse("/")?, + last_modified: DateTime::default(), + size: u64::MAX, + e_tag: None, + version: None, + }; + + let csv_format = CsvFormat::default().with_has_header(true); + let inferred_schema = csv_format + .infer_schema( + &state, + &(chunked_object_store as Arc), + &[object_meta], + ) + .await?; + + let actual_fields: Vec<_> = inferred_schema + .fields() + .iter() + .map(|f| format!("{}: {:?}", f.name(), f.data_type())) + .collect(); + + // ensure null chunks don't skew type inference + assert_eq!(vec!["c1: Int64", "c2: Float64", "c3: Null"], actual_fields); + Ok(()) + } + #[rstest( file_compression_type, case(FileCompressionType::UNCOMPRESSED), @@ -565,15 +616,15 @@ mod tests { .collect() .await?; - assert_snapshot!(batches_to_string(&record_batch), @r###" - +----+------+ - | c2 | c3 | - +----+------+ - | 5 | 36 | - | 5 | -31 | - | 5 | -101 | - +----+------+ - "###); + assert_snapshot!(batches_to_string(&record_batch), @r" + +----+------+ + | c2 | c3 | + +----+------+ + | 5 | 36 | + | 5 | -31 | + | 5 | -101 | + +----+------+ + "); Ok(()) } @@ -650,11 +701,11 @@ mod tests { let re = Regex::new(r"DataSourceExec: file_groups=\{(\d+) group").unwrap(); - if let Some(captures) = re.captures(&plan) { - if let Some(match_) = captures.get(1) { - let n_partitions = match_.as_str().parse::().unwrap(); - return Ok(n_partitions); - } + if let Some(captures) = re.captures(&plan) + && let Some(match_) = captures.get(1) + { + let n_partitions = match_.as_str().parse::().unwrap(); + return Ok(n_partitions); } internal_err!("query contains no DataSourceExec") @@ -680,13 +731,13 @@ mod tests { let 
query_result = ctx.sql(query).await?.collect().await?; let actual_partitions = count_query_csv_partitions(&ctx, query).await?; - insta::allow_duplicates! {assert_snapshot!(batches_to_string(&query_result),@r###" + insta::allow_duplicates! {assert_snapshot!(batches_to_string(&query_result),@r" +--------------+ | sum(aggr.c2) | +--------------+ | 285 | +--------------+ - "###); + "); } assert_eq!(n_partitions, actual_partitions); @@ -719,13 +770,13 @@ mod tests { let query_result = ctx.sql(query).await?.collect().await?; let actual_partitions = count_query_csv_partitions(&ctx, query).await?; - insta::allow_duplicates! {assert_snapshot!(batches_to_string(&query_result),@r###" + insta::allow_duplicates! {assert_snapshot!(batches_to_string(&query_result),@r" +--------------+ | sum(aggr.c3) | +--------------+ | 781 | +--------------+ - "###); + "); } assert_eq!(1, actual_partitions); // Compressed csv won't be scanned in parallel @@ -756,13 +807,13 @@ mod tests { let query_result = ctx.sql(query).await?.collect().await?; let actual_partitions = count_query_csv_partitions(&ctx, query).await?; - insta::allow_duplicates! {assert_snapshot!(batches_to_string(&query_result),@r###" + insta::allow_duplicates! 
{assert_snapshot!(batches_to_string(&query_result),@r" +--------------+ | sum(aggr.c3) | +--------------+ | 781 | +--------------+ - "###); + "); } assert_eq!(1, actual_partitions); // csv won't be scanned in parallel when newlines_in_values is set @@ -787,10 +838,10 @@ mod tests { let query = "select * from empty where random() > 0.5;"; let query_result = ctx.sql(query).await?.collect().await?; - assert_snapshot!(batches_to_string(&query_result),@r###" - ++ - ++ - "###); + assert_snapshot!(batches_to_string(&query_result),@r" + ++ + ++ + "); Ok(()) } @@ -812,10 +863,136 @@ mod tests { let query = "select * from empty where random() > 0.5;"; let query_result = ctx.sql(query).await?.collect().await?; - assert_snapshot!(batches_to_string(&query_result),@r###" - ++ - ++ - "###); + assert_snapshot!(batches_to_string(&query_result),@r" + ++ + ++ + "); + + Ok(()) + } + + /// Read multiple csv files (some are empty) with header + /// + /// some_empty_with_header + /// ├── a_empty.csv + /// ├── b.csv + /// └── c_nulls_column.csv + /// + /// a_empty.csv: + /// c1,c2,c3 + /// + /// b.csv: + /// c1,c2,c3 + /// 1,1,1 + /// 2,2,2 + /// + /// c_nulls_column.csv: + /// c1,c2,c3 + /// 3,3, + #[tokio::test] + async fn test_csv_some_empty_with_header() -> Result<()> { + let ctx = SessionContext::new(); + ctx.register_csv( + "some_empty_with_header", + "tests/data/empty_files/some_empty_with_header", + CsvReadOptions::new().has_header(true), + ) + .await?; + + let query = "select sum(c3) from some_empty_with_header;"; + let query_result = ctx.sql(query).await?.collect().await?; + + assert_snapshot!(batches_to_string(&query_result),@r" + +--------------------------------+ + | sum(some_empty_with_header.c3) | + +--------------------------------+ + | 3 | + +--------------------------------+ + "); + + Ok(()) + } + + #[tokio::test] + async fn test_csv_extension_compressed() -> Result<()> { + // Write compressed CSV files + // Expect: under the directory, a file is created with ".csv.gz" 
extension + let ctx = SessionContext::new(); + + let df = ctx + .read_csv( + &format!("{}/csv/aggregate_test_100.csv", arrow_test_data()), + CsvReadOptions::default().has_header(true), + ) + .await?; + + let tmp_dir = tempfile::TempDir::new().unwrap(); + let path = format!("{}", tmp_dir.path().to_string_lossy()); + + let cfg1 = crate::dataframe::DataFrameWriteOptions::new(); + let cfg2 = CsvOptions::default() + .with_has_header(true) + .with_compression(CompressionTypeVariant::GZIP); + + df.write_csv(&path, cfg1, Some(cfg2)).await?; + assert!(std::path::Path::new(&path).exists()); + + let files: Vec<_> = std::fs::read_dir(&path).unwrap().collect(); + assert_eq!(files.len(), 1); + assert!( + files + .last() + .unwrap() + .as_ref() + .unwrap() + .path() + .file_name() + .unwrap() + .to_str() + .unwrap() + .ends_with(".csv.gz") + ); + + Ok(()) + } + + #[tokio::test] + async fn test_csv_extension_uncompressed() -> Result<()> { + // Write plain uncompressed CSV files + // Expect: under the directory, a file is created with ".csv" extension + let ctx = SessionContext::new(); + + let df = ctx + .read_csv( + &format!("{}/csv/aggregate_test_100.csv", arrow_test_data()), + CsvReadOptions::default().has_header(true), + ) + .await?; + + let tmp_dir = tempfile::TempDir::new().unwrap(); + let path = format!("{}", tmp_dir.path().to_string_lossy()); + + let cfg1 = crate::dataframe::DataFrameWriteOptions::new(); + let cfg2 = CsvOptions::default().with_has_header(true); + + df.write_csv(&path, cfg1, Some(cfg2)).await?; + assert!(std::path::Path::new(&path).exists()); + + let files: Vec<_> = std::fs::read_dir(&path).unwrap().collect(); + assert_eq!(files.len(), 1); + assert!( + files + .last() + .unwrap() + .as_ref() + .unwrap() + .path() + .file_name() + .unwrap() + .to_str() + .unwrap() + .ends_with(".csv") + ); Ok(()) } @@ -854,10 +1031,10 @@ mod tests { let query = "select * from empty where random() > 0.5;"; let query_result = ctx.sql(query).await?.collect().await?; - 
assert_snapshot!(batches_to_string(&query_result),@r###" - ++ - ++ - "###); + assert_snapshot!(batches_to_string(&query_result),@r" + ++ + ++ + "); Ok(()) } @@ -906,13 +1083,13 @@ mod tests { let query_result = ctx.sql(query).await?.collect().await?; let actual_partitions = count_query_csv_partitions(&ctx, query).await?; - insta::allow_duplicates! {assert_snapshot!(batches_to_string(&query_result),@r###" - +---------------------+ - | sum(empty.column_1) | - +---------------------+ - | 10 | - +---------------------+ - "###);} + insta::allow_duplicates! {assert_snapshot!(batches_to_string(&query_result),@r" + +---------------------+ + | sum(empty.column_1) | + +---------------------+ + | 10 | + +---------------------+ + ");} assert_eq!(n_partitions, actual_partitions); // Won't get partitioned if all files are empty @@ -954,13 +1131,13 @@ mod tests { file_size }; - insta::allow_duplicates! {assert_snapshot!(batches_to_string(&query_result),@r###" + insta::allow_duplicates! {assert_snapshot!(batches_to_string(&query_result),@r" +-----------------------+ | sum(one_col.column_1) | +-----------------------+ | 50 | +-----------------------+ - "###); + "); } assert_eq!(expected_partitions, actual_partitions); @@ -993,13 +1170,13 @@ mod tests { let query_result = ctx.sql(query).await?.collect().await?; let actual_partitions = count_query_csv_partitions(&ctx, query).await?; - insta::allow_duplicates! {assert_snapshot!(batches_to_string(&query_result),@r###" - +---------------+ - | sum_of_5_cols | - +---------------+ - | 15 | - +---------------+ - "###);} + insta::allow_duplicates! 
{assert_snapshot!(batches_to_string(&query_result),@r" + +---------------+ + | sum_of_5_cols | + +---------------+ + | 15 | + +---------------+ + ");} assert_eq!(n_partitions, actual_partitions); @@ -1013,7 +1190,9 @@ mod tests { ) -> Result<()> { let schema = csv_schema(); let generator = CsvBatchGenerator::new(batch_size, line_count); - let mut deserializer = csv_deserializer(batch_size, &schema); + + let schema_clone = Arc::clone(&schema); + let mut deserializer = csv_deserializer(batch_size, &schema_clone); for data in generator { deserializer.digest(data); @@ -1052,7 +1231,8 @@ mod tests { ) -> Result<()> { let schema = csv_schema(); let generator = CsvBatchGenerator::new(batch_size, line_count); - let mut deserializer = csv_deserializer(batch_size, &schema); + let schema_clone = Arc::clone(&schema); + let mut deserializer = csv_deserializer(batch_size, &schema_clone); for data in generator { deserializer.digest(data); @@ -1151,7 +1331,7 @@ mod tests { fn csv_values(line_number: usize) -> (i32, f64, bool, String) { let int_value = line_number as i32; let float_value = line_number as f64; - let bool_value = line_number % 2 == 0; + let bool_value = line_number.is_multiple_of(2); let char_value = format!("{line_number}-string"); (int_value, float_value, bool_value, char_value) } @@ -1174,4 +1354,271 @@ mod tests { .build_decoder(); DecoderDeserializer::new(CsvDecoder::new(decoder)) } + + fn csv_deserializer_with_truncated( + batch_size: usize, + schema: &Arc, + ) -> impl BatchDeserializer { + // using Arrow's ReaderBuilder and enabling truncated_rows + let decoder = ReaderBuilder::new(schema.clone()) + .with_batch_size(batch_size) + .with_truncated_rows(true) // <- enable runtime truncated_rows + .build_decoder(); + DecoderDeserializer::new(CsvDecoder::new(decoder)) + } + + #[tokio::test] + async fn infer_schema_with_truncated_rows_true() -> Result<()> { + let session_ctx = SessionContext::new(); + let state = session_ctx.state(); + + // CSV: header has 3 
columns, but first data row has only 2 columns, second row has 3 + let csv_data = Bytes::from("a,b,c\n1,2\n3,4,5\n"); + let variable_object_store = Arc::new(VariableStream::new(csv_data, 1)); + let object_meta = ObjectMeta { + location: Path::parse("/")?, + last_modified: DateTime::default(), + size: u64::MAX, + e_tag: None, + version: None, + }; + + // Construct CsvFormat and enable truncated_rows via CsvOptions + let csv_options = CsvOptions::default().with_truncated_rows(true); + let csv_format = CsvFormat::default() + .with_has_header(true) + .with_options(csv_options) + .with_schema_infer_max_rec(10); + + let inferred_schema = csv_format + .infer_schema( + &state, + &(variable_object_store.clone() as Arc), + &[object_meta], + ) + .await?; + + // header has 3 columns; inferred schema should also have 3 + assert_eq!(inferred_schema.fields().len(), 3); + + // inferred columns should be nullable + for f in inferred_schema.fields() { + assert!(f.is_nullable()); + } + + Ok(()) + } + #[test] + fn test_decoder_truncated_rows_runtime() -> Result<()> { + // Synchronous test: Decoder API used here is synchronous + let schema = csv_schema(); // helper already defined in file + + // Construct a decoder that enables truncated_rows at runtime + let mut deserializer = csv_deserializer_with_truncated(10, &schema); + + // Provide two rows: first row complete, second row missing last column + let input = Bytes::from("0,0.0,true,0-string\n1,1.0,true\n"); + deserializer.digest(input); + + // Finish and collect output + deserializer.finish(); + + let output = deserializer.next()?; + match output { + DeserializerOutput::RecordBatch(batch) => { + // ensure at least two rows present + assert!(batch.num_rows() >= 2); + // column 4 (index 3) should be a StringArray where second row is NULL + let col4 = batch + .column(3) + .as_any() + .downcast_ref::() + .expect("column 4 should be StringArray"); + + // first row present, second row should be null + assert!(!col4.is_null(0)); + 
assert!(col4.is_null(1)); + } + other => panic!("expected RecordBatch but got {other:?}"), + } + Ok(()) + } + + #[tokio::test] + async fn infer_schema_truncated_rows_false_error() -> Result<()> { + let session_ctx = SessionContext::new(); + let state = session_ctx.state(); + + // CSV: header has 4 cols, first data row has 3 cols -> truncated at end + let csv_data = Bytes::from("id,a,b,c\n1,foo,bar\n2,foo,bar,baz\n"); + let variable_object_store = Arc::new(VariableStream::new(csv_data, 1)); + let object_meta = ObjectMeta { + location: Path::parse("/")?, + last_modified: DateTime::default(), + size: u64::MAX, + e_tag: None, + version: None, + }; + + // CsvFormat without enabling truncated_rows (default behavior = false) + let csv_format = CsvFormat::default() + .with_has_header(true) + .with_schema_infer_max_rec(10); + + let res = csv_format + .infer_schema( + &state, + &(variable_object_store.clone() as Arc), + &[object_meta], + ) + .await; + + // Expect an error due to unequal lengths / incorrect number of fields + assert!( + res.is_err(), + "expected infer_schema to error on truncated rows when disabled" + ); + + // Optional: check message contains indicative text (two known possibilities) + if let Err(err) = res { + let msg = format!("{err}"); + assert!( + msg.contains("Encountered unequal lengths") + || msg.contains("incorrect number of fields"), + "unexpected error message: {msg}", + ); + } + + Ok(()) + } + + #[tokio::test] + async fn test_read_csv_truncated_rows_via_tempfile() -> Result<()> { + use std::io::Write; + + // create a SessionContext + let ctx = SessionContext::new(); + + // Create a temp file with a .csv suffix so the reader accepts it + let mut tmp = tempfile::Builder::new().suffix(".csv").tempfile()?; // ensures path ends with .csv + // CSV has header "a,b,c". First data row is truncated (only "1,2"), second row is complete. 
+ write!(tmp, "a,b,c\n1,2\n3,4,5\n")?; + let path = tmp.path().to_str().unwrap().to_string(); + + // Build CsvReadOptions: header present, enable truncated_rows. + // (Use the exact builder method your crate exposes: `truncated_rows(true)` here, + // if the method name differs in your codebase use the appropriate one.) + let options = CsvReadOptions::default().truncated_rows(true); + + println!("options: {}, path: {path}", options.truncated_rows); + + // Call the API under test + let df = ctx.read_csv(&path, options).await?; + + // Collect the results and combine batches so we can inspect columns + let batches = df.collect().await?; + let combined = concat_batches(&batches[0].schema(), &batches)?; + + // Column 'c' is the 3rd column (index 2). The first data row was truncated -> should be NULL. + let col_c = combined.column(2); + assert!( + col_c.is_null(0), + "expected first row column 'c' to be NULL due to truncated row" + ); + + // Also ensure we read at least one row + assert!(combined.num_rows() >= 2); + + Ok(()) + } + + #[tokio::test] + async fn test_write_empty_csv_from_sql() -> Result<()> { + let ctx = SessionContext::new(); + let tmp_dir = tempfile::TempDir::new()?; + let path = format!("{}/empty_sql.csv", tmp_dir.path().to_string_lossy()); + let df = ctx.sql("SELECT CAST(1 AS BIGINT) AS id LIMIT 0").await?; + df.write_csv(&path, crate::dataframe::DataFrameWriteOptions::new(), None) + .await?; + assert!(std::path::Path::new(&path).exists()); + + let read_df = ctx + .read_csv(&path, CsvReadOptions::default().has_header(true)) + .await?; + let stream = read_df.execute_stream().await?; + assert_eq!(stream.schema().fields().len(), 1); + assert_eq!(stream.schema().field(0).name(), "id"); + + let results: Vec<_> = stream.collect().await; + assert_eq!(results.len(), 0); + + Ok(()) + } + + #[tokio::test] + async fn test_write_empty_csv_from_record_batch() -> Result<()> { + let ctx = SessionContext::new(); + let schema = Arc::new(Schema::new(vec![ + 
Field::new("id", DataType::Int64, false), + Field::new("name", DataType::Utf8, true), + ])); + let empty_batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(arrow::array::Int64Array::from(Vec::::new())), + Arc::new(StringArray::from(Vec::>::new())), + ], + )?; + + let tmp_dir = tempfile::TempDir::new()?; + let path = format!("{}/empty_batch.csv", tmp_dir.path().to_string_lossy()); + + // Write empty RecordBatch + let df = ctx.read_batch(empty_batch.clone())?; + df.write_csv(&path, crate::dataframe::DataFrameWriteOptions::new(), None) + .await?; + // Expected the file to exist + assert!(std::path::Path::new(&path).exists()); + + let read_df = ctx + .read_csv(&path, CsvReadOptions::default().has_header(true)) + .await?; + let stream = read_df.execute_stream().await?; + assert_eq!(stream.schema().fields().len(), 2); + assert_eq!(stream.schema().field(0).name(), "id"); + assert_eq!(stream.schema().field(1).name(), "name"); + + let results: Vec<_> = stream.collect().await; + assert_eq!(results.len(), 0); + + Ok(()) + } + + #[tokio::test] + async fn test_infer_schema_with_zero_max_records() -> Result<()> { + let session_ctx = SessionContext::new(); + let state = session_ctx.state(); + + let root = format!("{}/csv", arrow_test_data()); + let format = CsvFormat::default() + .with_has_header(true) + .with_schema_infer_max_rec(0); // Set to 0 to disable inference + let exec = scan_format( + &state, + &format, + None, + &root, + "aggregate_test_100.csv", + None, + None, + ) + .await?; + + // related to https://github.com/apache/datafusion/issues/19417 + for f in exec.schema().fields() { + assert_eq!(*f.data_type(), DataType::Utf8); + } + + Ok(()) + } } diff --git a/datafusion/core/src/datasource/file_format/json.rs b/datafusion/core/src/datasource/file_format/json.rs index 34d3d64f07fb2..5b3e22705620e 100644 --- a/datafusion/core/src/datasource/file_format/json.rs +++ b/datafusion/core/src/datasource/file_format/json.rs @@ -25,7 +25,7 @@ mod tests { use 
super::*; use crate::datasource::file_format::test_util::scan_format; - use crate::prelude::{NdJsonReadOptions, SessionConfig, SessionContext}; + use crate::prelude::{SessionConfig, SessionContext}; use crate::test::object_store::local_unpartitioned_file; use arrow::array::RecordBatch; use arrow_schema::Schema; @@ -36,7 +36,7 @@ mod tests { BatchDeserializer, DecoderDeserializer, DeserializerOutput, }; use datafusion_datasource::file_format::FileFormat; - use datafusion_physical_plan::{collect, ExecutionPlan}; + use datafusion_physical_plan::{ExecutionPlan, collect}; use arrow::compute::concat_batches; use arrow::datatypes::{DataType, Field}; @@ -46,12 +46,54 @@ mod tests { use datafusion_common::internal_err; use datafusion_common::stats::Precision; + use crate::execution::options::JsonReadOptions; use datafusion_common::Result; + use datafusion_datasource::file_compression_type::FileCompressionType; use futures::StreamExt; use insta::assert_snapshot; use object_store::local::LocalFileSystem; use regex::Regex; use rstest::rstest; + // ==================== Test Helpers ==================== + + /// Create a temporary JSON file and return (TempDir, path) + fn create_temp_json(content: &str) -> (tempfile::TempDir, String) { + let tmp_dir = tempfile::TempDir::new().unwrap(); + let path = tmp_dir.path().join("test.json"); + std::fs::write(&path, content).unwrap(); + (tmp_dir, path.to_string_lossy().to_string()) + } + + /// Infer schema from JSON array format file + async fn infer_json_array_schema( + content: &str, + ) -> Result { + let (_tmp_dir, path) = create_temp_json(content); + let session = SessionContext::new(); + let ctx = session.state(); + let store = Arc::new(LocalFileSystem::new()) as _; + let format = JsonFormat::default().with_newline_delimited(false); + format + .infer_schema(&ctx, &store, &[local_unpartitioned_file(&path)]) + .await + } + + /// Register a JSON array table and run a query + async fn query_json_array(content: &str, query: &str) -> Result> 
{ + let (_tmp_dir, path) = create_temp_json(content); + let ctx = SessionContext::new(); + let options = JsonReadOptions::default().newline_delimited(false); + ctx.register_json("test_table", &path, options).await?; + ctx.sql(query).await?.collect().await + } + + /// Register a JSON array table and run a query, return formatted string + async fn query_json_array_str(content: &str, query: &str) -> Result { + let result = query_json_array(content, query).await?; + Ok(batches_to_string(&result)) + } + + // ==================== Existing Tests ==================== #[tokio::test] async fn read_small_batches() -> Result<()> { @@ -187,11 +229,11 @@ mod tests { let re = Regex::new(r"file_groups=\{(\d+) group").unwrap(); - if let Some(captures) = re.captures(&plan) { - if let Some(match_) = captures.get(1) { - let count = match_.as_str().parse::().unwrap(); - return Ok(count); - } + if let Some(captures) = re.captures(&plan) + && let Some(match_) = captures.get(1) + { + let count = match_.as_str().parse::().unwrap(); + return Ok(count); } internal_err!("Query contains no Exec: file_groups") @@ -208,7 +250,7 @@ mod tests { let ctx = SessionContext::new_with_config(config); let table_path = "tests/data/1.json"; - let options = NdJsonReadOptions::default(); + let options = JsonReadOptions::default(); ctx.register_json("json_parallel", table_path, options) .await?; @@ -218,13 +260,13 @@ mod tests { let result = ctx.sql(query).await?.collect().await?; let actual_partitions = count_num_partitions(&ctx, query).await?; - insta::allow_duplicates! {assert_snapshot!(batches_to_string(&result),@r###" - +----------------------+ - | sum(json_parallel.a) | - +----------------------+ - | -7 | - +----------------------+ - "###);} + insta::allow_duplicates! 
{assert_snapshot!(batches_to_string(&result),@r" + +----------------------+ + | sum(json_parallel.a) | + +----------------------+ + | -7 | + +----------------------+ + ");} assert_eq!(n_partitions, actual_partitions); @@ -240,7 +282,7 @@ mod tests { let ctx = SessionContext::new_with_config(config); let table_path = "tests/data/empty.json"; - let options = NdJsonReadOptions::default(); + let options = JsonReadOptions::default(); ctx.register_json("json_parallel_empty", table_path, options) .await?; @@ -249,10 +291,10 @@ mod tests { let result = ctx.sql(query).await?.collect().await?; - assert_snapshot!(batches_to_string(&result),@r###" - ++ - ++ - "###); + assert_snapshot!(batches_to_string(&result),@r" + ++ + ++ + "); Ok(()) } @@ -284,15 +326,15 @@ mod tests { } assert_eq!(deserializer.next()?, DeserializerOutput::InputExhausted); - assert_snapshot!(batches_to_string(&[all_batches]),@r###" - +----+----+----+----+----+ - | c1 | c2 | c3 | c4 | c5 | - +----+----+----+----+----+ - | 1 | 2 | 3 | 4 | 5 | - | 6 | 7 | 8 | 9 | 10 | - | 11 | 12 | 13 | 14 | 15 | - +----+----+----+----+----+ - "###); + assert_snapshot!(batches_to_string(&[all_batches]),@r" + +----+----+----+----+----+ + | c1 | c2 | c3 | c4 | c5 | + +----+----+----+----+----+ + | 1 | 2 | 3 | 4 | 5 | + | 6 | 7 | 8 | 9 | 10 | + | 11 | 12 | 13 | 14 | 15 | + +----+----+----+----+----+ + "); Ok(()) } @@ -314,7 +356,6 @@ mod tests { .digest(r#"{ "c1": 11, "c2": 12, "c3": 13, "c4": 14, "c5": 15 }"#.into()); let mut all_batches = RecordBatch::new_empty(schema.clone()); - // We get RequiresMoreData after 2 batches because of how json::Decoder works for _ in 0..2 { let output = deserializer.next()?; let DeserializerOutput::RecordBatch(batch) = output else { @@ -324,14 +365,14 @@ mod tests { } assert_eq!(deserializer.next()?, DeserializerOutput::RequiresMoreData); - insta::assert_snapshot!(fmt_batches(&[all_batches]),@r###" - +----+----+----+----+----+ - | c1 | c2 | c3 | c4 | c5 | - +----+----+----+----+----+ - | 1 | 2 | 
3 | 4 | 5 | - | 6 | 7 | 8 | 9 | 10 | - +----+----+----+----+----+ - "###); + insta::assert_snapshot!(fmt_batches(&[all_batches]),@r" + +----+----+----+----+----+ + | c1 | c2 | c3 | c4 | c5 | + +----+----+----+----+----+ + | 1 | 2 | 3 | 4 | 5 | + | 6 | 7 | 8 | 9 | 10 | + +----+----+----+----+----+ + "); Ok(()) } @@ -349,4 +390,248 @@ mod tests { fn fmt_batches(batches: &[RecordBatch]) -> String { pretty::pretty_format_batches(batches).unwrap().to_string() } + + #[tokio::test] + async fn test_write_empty_json_from_sql() -> Result<()> { + let ctx = SessionContext::new(); + let tmp_dir = tempfile::TempDir::new()?; + let path = tmp_dir.path().join("empty_sql.json"); + let path = path.to_string_lossy().to_string(); + let df = ctx.sql("SELECT CAST(1 AS BIGINT) AS id LIMIT 0").await?; + df.write_json(&path, crate::dataframe::DataFrameWriteOptions::new(), None) + .await?; + assert!(std::path::Path::new(&path).exists()); + let metadata = std::fs::metadata(&path)?; + assert_eq!(metadata.len(), 0); + Ok(()) + } + + #[tokio::test] + async fn test_write_empty_json_from_record_batch() -> Result<()> { + let ctx = SessionContext::new(); + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("name", DataType::Utf8, true), + ])); + let empty_batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(arrow::array::Int64Array::from(Vec::::new())), + Arc::new(arrow::array::StringArray::from(Vec::>::new())), + ], + )?; + + let tmp_dir = tempfile::TempDir::new()?; + let path = tmp_dir.path().join("empty_batch.json"); + let path = path.to_string_lossy().to_string(); + let df = ctx.read_batch(empty_batch.clone())?; + df.write_json(&path, crate::dataframe::DataFrameWriteOptions::new(), None) + .await?; + assert!(std::path::Path::new(&path).exists()); + let metadata = std::fs::metadata(&path)?; + assert_eq!(metadata.len(), 0); + Ok(()) + } + + // ==================== JSON Array Format Tests ==================== + + #[tokio::test] + async fn 
test_json_array_schema_inference() -> Result<()> { + let schema = infer_json_array_schema( + r#"[{"a": 1, "b": 2.0, "c": true}, {"a": 2, "b": 3.5, "c": false}]"#, + ) + .await?; + + let fields: Vec<_> = schema + .fields() + .iter() + .map(|f| format!("{}: {:?}", f.name(), f.data_type())) + .collect(); + assert_eq!(vec!["a: Int64", "b: Float64", "c: Boolean"], fields); + Ok(()) + } + + #[tokio::test] + async fn test_json_array_empty() -> Result<()> { + let schema = infer_json_array_schema("[]").await?; + assert_eq!(schema.fields().len(), 0); + Ok(()) + } + + #[tokio::test] + async fn test_json_array_nested_struct() -> Result<()> { + let schema = infer_json_array_schema( + r#"[{"id": 1, "info": {"name": "Alice", "age": 30}}]"#, + ) + .await?; + + let info_field = schema.field_with_name("info").unwrap(); + assert!(matches!(info_field.data_type(), DataType::Struct(_))); + Ok(()) + } + + #[tokio::test] + async fn test_json_array_list_type() -> Result<()> { + let schema = + infer_json_array_schema(r#"[{"id": 1, "tags": ["a", "b", "c"]}]"#).await?; + + let tags_field = schema.field_with_name("tags").unwrap(); + assert!(matches!(tags_field.data_type(), DataType::List(_))); + Ok(()) + } + + #[tokio::test] + async fn test_json_array_basic_query() -> Result<()> { + let result = query_json_array_str( + r#"[{"a": 1, "b": "hello"}, {"a": 2, "b": "world"}, {"a": 3, "b": "test"}]"#, + "SELECT a, b FROM test_table ORDER BY a", + ) + .await?; + + assert_snapshot!(result, @r" + +---+-------+ + | a | b | + +---+-------+ + | 1 | hello | + | 2 | world | + | 3 | test | + +---+-------+ + "); + Ok(()) + } + + #[tokio::test] + async fn test_json_array_with_nulls() -> Result<()> { + let result = query_json_array_str( + r#"[{"id": 1, "name": "Alice"}, {"id": 2, "name": null}, {"id": 3, "name": "Charlie"}]"#, + "SELECT id, name FROM test_table ORDER BY id", + ) + .await?; + + assert_snapshot!(result, @r" + +----+---------+ + | id | name | + +----+---------+ + | 1 | Alice | + | 2 | | + | 3 | 
Charlie | + +----+---------+ + "); + Ok(()) + } + + #[tokio::test] + async fn test_json_array_unnest() -> Result<()> { + let result = query_json_array_str( + r#"[{"id": 1, "values": [10, 20, 30]}, {"id": 2, "values": [40, 50]}]"#, + "SELECT id, unnest(values) as value FROM test_table ORDER BY id, value", + ) + .await?; + + assert_snapshot!(result, @r" + +----+-------+ + | id | value | + +----+-------+ + | 1 | 10 | + | 1 | 20 | + | 1 | 30 | + | 2 | 40 | + | 2 | 50 | + +----+-------+ + "); + Ok(()) + } + + #[tokio::test] + async fn test_json_array_unnest_struct() -> Result<()> { + let result = query_json_array_str( + r#"[{"id": 1, "orders": [{"product": "A", "qty": 2}, {"product": "B", "qty": 3}]}, {"id": 2, "orders": [{"product": "C", "qty": 1}]}]"#, + "SELECT id, unnest(orders)['product'] as product, unnest(orders)['qty'] as qty FROM test_table ORDER BY id, product", + ) + .await?; + + assert_snapshot!(result, @r" + +----+---------+-----+ + | id | product | qty | + +----+---------+-----+ + | 1 | A | 2 | + | 1 | B | 3 | + | 2 | C | 1 | + +----+---------+-----+ + "); + Ok(()) + } + + #[tokio::test] + async fn test_json_array_nested_struct_access() -> Result<()> { + let result = query_json_array_str( + r#"[{"id": 1, "dept": {"name": "Engineering", "head": "Alice"}}, {"id": 2, "dept": {"name": "Sales", "head": "Bob"}}]"#, + "SELECT id, dept['name'] as dept_name, dept['head'] as head FROM test_table ORDER BY id", + ) + .await?; + + assert_snapshot!(result, @r" + +----+-------------+-------+ + | id | dept_name | head | + +----+-------------+-------+ + | 1 | Engineering | Alice | + | 2 | Sales | Bob | + +----+-------------+-------+ + "); + Ok(()) + } + + #[tokio::test] + async fn test_json_array_with_compression() -> Result<()> { + use flate2::Compression; + use flate2::write::GzEncoder; + use std::io::Write; + + let tmp_dir = tempfile::TempDir::new()?; + let path = tmp_dir.path().join("array.json.gz"); + let path = path.to_string_lossy().to_string(); + + let file = 
std::fs::File::create(&path)?; + let mut encoder = GzEncoder::new(file, Compression::default()); + encoder.write_all( + r#"[{"a": 1, "b": "hello"}, {"a": 2, "b": "world"}]"#.as_bytes(), + )?; + encoder.finish()?; + + let ctx = SessionContext::new(); + let options = JsonReadOptions::default() + .newline_delimited(false) + .file_compression_type(FileCompressionType::GZIP) + .file_extension(".json.gz"); + + ctx.register_json("test_table", &path, options).await?; + let result = ctx + .sql("SELECT a, b FROM test_table ORDER BY a") + .await? + .collect() + .await?; + + assert_snapshot!(batches_to_string(&result), @r" + +---+-------+ + | a | b | + +---+-------+ + | 1 | hello | + | 2 | world | + +---+-------+ + "); + Ok(()) + } + + #[tokio::test] + async fn test_json_array_list_of_structs() -> Result<()> { + let batches = query_json_array( + r#"[{"id": 1, "items": [{"name": "x", "price": 10.5}]}, {"id": 2, "items": []}]"#, + "SELECT id, items FROM test_table ORDER BY id", + ) + .await?; + + assert_eq!(1, batches.len()); + assert_eq!(2, batches[0].num_rows()); + Ok(()) + } } diff --git a/datafusion/core/src/datasource/file_format/mod.rs b/datafusion/core/src/datasource/file_format/mod.rs index e165707c2eb0e..b04238ebc9b37 100644 --- a/datafusion/core/src/datasource/file_format/mod.rs +++ b/datafusion/core/src/datasource/file_format/mod.rs @@ -39,8 +39,9 @@ pub(crate) mod test_util { use arrow_schema::SchemaRef; use datafusion_catalog::Session; use datafusion_common::Result; + use datafusion_datasource::TableSchema; use datafusion_datasource::file_scan_config::FileScanConfigBuilder; - use datafusion_datasource::{file_format::FileFormat, PartitionedFile}; + use datafusion_datasource::{PartitionedFile, file_format::FileFormat}; use datafusion_execution::object_store::ObjectStoreUrl; use std::sync::Arc; @@ -66,31 +67,24 @@ pub(crate) mod test_util { .await? 
}; + let table_schema = TableSchema::new(file_schema.clone(), vec![]); + let statistics = format .infer_stats(state, &store, file_schema.clone(), &meta) .await?; - let file_groups = vec![vec![PartitionedFile { - object_meta: meta, - partition_values: vec![], - range: None, - statistics: None, - extensions: None, - metadata_size_hint: None, - }] - .into()]; + let file_groups = vec![vec![PartitionedFile::new_from_meta(meta)].into()]; let exec = format .create_physical_plan( state, FileScanConfigBuilder::new( ObjectStoreUrl::local_filesystem(), - file_schema, - format.file_source(), + format.file_source(table_schema), ) .with_file_groups(file_groups) .with_statistics(statistics) - .with_projection(projection) + .with_projection_indices(projection)? .with_limit(limit) .build(), ) @@ -131,7 +125,10 @@ mod tests { .write_parquet(out_dir_url, DataFrameWriteOptions::new(), None) .await .expect_err("should fail because input file does not match inferred schema"); - assert_eq!(e.strip_backtrace(), "Arrow error: Parser error: Error while parsing value 'd' as type 'Int64' for column 0 at line 4. Row data: '[d,4]'"); + assert_eq!( + e.strip_backtrace(), + "Arrow error: Parser error: Error while parsing value 'd' as type 'Int64' for column 0 at line 4. 
Row data: '[d,4]'" + ); Ok(()) } } diff --git a/datafusion/core/src/datasource/file_format/options.rs b/datafusion/core/src/datasource/file_format/options.rs index 9aaf1cf598113..bd0ac36087381 100644 --- a/datafusion/core/src/datasource/file_format/options.rs +++ b/datafusion/core/src/datasource/file_format/options.rs @@ -25,16 +25,16 @@ use crate::datasource::file_format::avro::AvroFormat; #[cfg(feature = "parquet")] use crate::datasource::file_format::parquet::ParquetFormat; +use crate::datasource::file_format::DEFAULT_SCHEMA_INFER_MAX_RECORD; use crate::datasource::file_format::arrow::ArrowFormat; use crate::datasource::file_format::file_compression_type::FileCompressionType; -use crate::datasource::file_format::DEFAULT_SCHEMA_INFER_MAX_RECORD; use crate::datasource::listing::ListingTableUrl; use crate::datasource::{file_format::csv::CsvFormat, listing::ListingOptions}; use crate::error::Result; use crate::execution::context::{SessionConfig, SessionState}; use arrow::datatypes::{DataType, Schema, SchemaRef}; -use datafusion_common::config::TableOptions; +use datafusion_common::config::{ConfigFileDecryptionProperties, TableOptions}; use datafusion_common::{ DEFAULT_ARROW_EXTENSION, DEFAULT_AVRO_EXTENSION, DEFAULT_CSV_EXTENSION, DEFAULT_JSON_EXTENSION, DEFAULT_PARQUET_EXTENSION, @@ -91,6 +91,11 @@ pub struct CsvReadOptions<'a> { pub file_sort_order: Vec>, /// Optional regex to match null values pub null_regex: Option, + /// Whether to allow truncated rows when parsing. + /// By default this is set to false and will error if the CSV rows have different lengths. + /// When set to true then it will allow records with less than the expected number of columns and fill the missing columns with nulls. + /// If the record’s schema is not nullable, then it will still return an error. 
+ pub truncated_rows: bool, } impl Default for CsvReadOptions<'_> { @@ -117,6 +122,7 @@ impl<'a> CsvReadOptions<'a> { file_sort_order: vec![], comment: None, null_regex: None, + truncated_rows: false, } } @@ -223,6 +229,15 @@ impl<'a> CsvReadOptions<'a> { self.null_regex = null_regex; self } + + /// Configure whether to allow truncated rows when parsing. + /// By default this is set to false and will error if the CSV rows have different lengths + /// When set to true then it will allow records with less than the expected number of columns and fill the missing columns with nulls. + /// If the record’s schema is not nullable, then it will still return an error. + pub fn truncated_rows(mut self, truncated_rows: bool) -> Self { + self.truncated_rows = truncated_rows; + self + } } /// Options that control the reading of Parquet files. @@ -252,6 +267,10 @@ pub struct ParquetReadOptions<'a> { pub schema: Option<&'a Schema>, /// Indicates how the file is sorted pub file_sort_order: Vec>, + /// Properties for decryption of Parquet files that use modular encryption + pub file_decryption_properties: Option, + /// Metadata size hint for Parquet files reading (in bytes) + pub metadata_size_hint: Option, } impl Default for ParquetReadOptions<'_> { @@ -263,6 +282,8 @@ impl Default for ParquetReadOptions<'_> { skip_metadata: None, schema: None, file_sort_order: vec![], + file_decryption_properties: None, + metadata_size_hint: None, } } } @@ -313,6 +334,21 @@ impl<'a> ParquetReadOptions<'a> { self.file_sort_order = file_sort_order; self } + + /// Configure file decryption properties for reading encrypted Parquet files + pub fn file_decryption_properties( + mut self, + file_decryption_properties: ConfigFileDecryptionProperties, + ) -> Self { + self.file_decryption_properties = Some(file_decryption_properties); + self + } + + /// Configure metadata size hint for Parquet files reading (in bytes) + pub fn metadata_size_hint(mut self, size_hint: Option) -> Self { + 
self.metadata_size_hint = size_hint; + self + } } /// Options that control the reading of ARROW files. @@ -406,14 +442,23 @@ impl<'a> AvroReadOptions<'a> { } } -/// Options that control the reading of Line-delimited JSON files (NDJson) +#[deprecated( + since = "53.0.0", + note = "Use `JsonReadOptions` instead. This alias will be removed in a future version." +)] +#[doc = "Deprecated: Use [`JsonReadOptions`] instead."] +pub type NdJsonReadOptions<'a> = JsonReadOptions<'a>; + +/// Options that control the reading of JSON files. +/// +/// Supports both newline-delimited JSON (NDJSON) and JSON array formats. /// /// Note this structure is supplied when a datasource is created and -/// can not not vary from statement to statement. For settings that +/// can not vary from statement to statement. For settings that /// can vary statement to statement see /// [`ConfigOptions`](crate::config::ConfigOptions). #[derive(Clone)] -pub struct NdJsonReadOptions<'a> { +pub struct JsonReadOptions<'a> { /// The data source schema. pub schema: Option<&'a Schema>, /// Max number of rows to read from JSON files for schema inference if needed. Defaults to `DEFAULT_SCHEMA_INFER_MAX_RECORD`. @@ -429,9 +474,25 @@ pub struct NdJsonReadOptions<'a> { pub infinite: bool, /// Indicates how the file is sorted pub file_sort_order: Vec>, + /// Whether to read as newline-delimited JSON (default: true). 
+ /// + /// When `true` (default), expects newline-delimited JSON (NDJSON): + /// ```text + /// {"key1": 1, "key2": "val"} + /// {"key1": 2, "key2": "vals"} + /// ``` + /// + /// When `false`, expects JSON array format: + /// ```text + /// [ + /// {"key1": 1, "key2": "val"}, + /// {"key1": 2, "key2": "vals"} + /// ] + /// ``` + pub newline_delimited: bool, } -impl Default for NdJsonReadOptions<'_> { +impl Default for JsonReadOptions<'_> { fn default() -> Self { Self { schema: None, @@ -441,11 +502,12 @@ impl Default for NdJsonReadOptions<'_> { file_compression_type: FileCompressionType::UNCOMPRESSED, infinite: false, file_sort_order: vec![], + newline_delimited: true, } } } -impl<'a> NdJsonReadOptions<'a> { +impl<'a> JsonReadOptions<'a> { /// Specify table_partition_cols for partition pruning pub fn table_partition_cols( mut self, @@ -487,6 +549,32 @@ impl<'a> NdJsonReadOptions<'a> { self.file_sort_order = file_sort_order; self } + + /// Specify how many rows to read for schema inference + pub fn schema_infer_max_records(mut self, schema_infer_max_records: usize) -> Self { + self.schema_infer_max_records = schema_infer_max_records; + self + } + + /// Set whether to read as newline-delimited JSON. 
+ /// + /// When `true` (default), expects newline-delimited JSON (NDJSON): + /// ```text + /// {"key1": 1, "key2": "val"} + /// {"key1": 2, "key2": "vals"} + /// ``` + /// + /// When `false`, expects JSON array format: + /// ```text + /// [ + /// {"key1": 1, "key2": "val"}, + /// {"key1": 2, "key2": "vals"} + /// ] + /// ``` + pub fn newline_delimited(mut self, newline_delimited: bool) -> Self { + self.newline_delimited = newline_delimited; + self + } } #[async_trait] @@ -546,7 +634,8 @@ impl ReadOptions<'_> for CsvReadOptions<'_> { .with_newlines_in_values(self.newlines_in_values) .with_schema_infer_max_rec(self.schema_infer_max_records) .with_file_compression_type(self.file_compression_type.to_owned()) - .with_null_regex(self.null_regex.clone()); + .with_null_regex(self.null_regex.clone()) + .with_truncated_rows(self.truncated_rows); ListingOptions::new(Arc::new(file_format)) .with_file_extension(self.file_extension) @@ -574,7 +663,16 @@ impl ReadOptions<'_> for ParquetReadOptions<'_> { config: &SessionConfig, table_options: TableOptions, ) -> ListingOptions { - let mut file_format = ParquetFormat::new().with_options(table_options.parquet); + let mut options = table_options.parquet; + if let Some(file_decryption_properties) = &self.file_decryption_properties { + options.crypto.file_decryption = Some(file_decryption_properties.clone()); + } + // This can be overridden per-read in ParquetReadOptions, if setting. 
+ if let Some(metadata_size_hint) = self.metadata_size_hint { + options.global.metadata_size_hint = Some(metadata_size_hint); + } + + let mut file_format = ParquetFormat::new().with_options(options); if let Some(parquet_pruning) = self.parquet_pruning { file_format = file_format.with_enable_pruning(parquet_pruning) @@ -602,7 +700,7 @@ impl ReadOptions<'_> for ParquetReadOptions<'_> { } #[async_trait] -impl ReadOptions<'_> for NdJsonReadOptions<'_> { +impl ReadOptions<'_> for JsonReadOptions<'_> { fn to_listing_options( &self, config: &SessionConfig, @@ -611,7 +709,8 @@ impl ReadOptions<'_> for NdJsonReadOptions<'_> { let file_format = JsonFormat::default() .with_options(table_options.json) .with_schema_infer_max_rec(self.schema_infer_max_records) - .with_file_compression_type(self.file_compression_type.to_owned()); + .with_file_compression_type(self.file_compression_type.to_owned()) + .with_newline_delimited(self.newline_delimited); ListingOptions::new(Arc::new(file_format)) .with_file_extension(self.file_extension) diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index 9705225c24c7b..6a8f7ab999757 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -107,10 +107,8 @@ pub(crate) mod test_util { mod tests { use std::fmt::{self, Display, Formatter}; - use std::pin::Pin; - use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; - use std::task::{Context, Poll}; + use std::sync::atomic::{AtomicUsize, Ordering}; use std::time::Duration; use crate::datasource::file_format::parquet::test_util::store_parquet; @@ -120,8 +118,9 @@ mod tests { use crate::prelude::{ParquetReadOptions, SessionConfig, SessionContext}; use arrow::array::RecordBatch; - use arrow_schema::{Schema, SchemaRef}; + use arrow_schema::Schema; use datafusion_catalog::Session; + use datafusion_common::ScalarValue::Utf8; use 
datafusion_common::cast::{ as_binary_array, as_binary_view_array, as_boolean_array, as_float32_array, as_float64_array, as_int32_array, as_timestamp_nanosecond_array, @@ -129,44 +128,47 @@ mod tests { use datafusion_common::config::{ParquetOptions, TableParquetOptions}; use datafusion_common::stats::Precision; use datafusion_common::test_util::batches_to_string; - use datafusion_common::ScalarValue::Utf8; use datafusion_common::{Result, ScalarValue}; use datafusion_datasource::file_format::FileFormat; - use datafusion_datasource::file_sink_config::{FileSink, FileSinkConfig}; + use datafusion_datasource::file_sink_config::{ + FileOutputMode, FileSink, FileSinkConfig, + }; use datafusion_datasource::{ListingTableUrl, PartitionedFile}; use datafusion_datasource_parquet::{ - fetch_parquet_metadata, fetch_statistics, statistics_from_parquet_meta_calc, ParquetFormat, ParquetFormatFactory, ParquetSink, }; + use datafusion_execution::TaskContext; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_execution::runtime_env::RuntimeEnv; - use datafusion_execution::{RecordBatchStream, TaskContext}; use datafusion_expr::dml::InsertOp; use datafusion_physical_plan::stream::RecordBatchStreamAdapter; - use datafusion_physical_plan::{collect, ExecutionPlan}; + use datafusion_physical_plan::{ExecutionPlan, collect}; + use crate::test_util::bounded_stream; use arrow::array::{ - types::Int32Type, Array, ArrayRef, DictionaryArray, Int32Array, Int64Array, - StringArray, + Array, ArrayRef, DictionaryArray, Int32Array, Int64Array, StringArray, + types::Int32Type, }; use arrow::datatypes::{DataType, Field}; use async_trait::async_trait; use datafusion_datasource::file_groups::FileGroup; + use datafusion_datasource_parquet::metadata::DFParquetMetadata; + use futures::StreamExt; use futures::stream::BoxStream; - use futures::{Stream, StreamExt}; use insta::assert_snapshot; - use log::error; use object_store::local::LocalFileSystem; - use object_store::ObjectMeta; + use 
object_store::{CopyOptions, ObjectMeta}; use object_store::{ - path::Path, GetOptions, GetResult, ListResult, MultipartUpload, ObjectStore, - PutMultipartOpts, PutOptions, PutPayload, PutResult, + GetOptions, GetResult, ListResult, MultipartUpload, ObjectStore, + PutMultipartOptions, PutOptions, PutPayload, PutResult, path::Path, }; - use parquet::arrow::arrow_reader::ArrowReaderOptions; use parquet::arrow::ParquetRecordBatchStreamBuilder; - use parquet::file::metadata::{KeyValue, ParquetColumnIndex, ParquetOffsetIndex}; - use parquet::file::page_index::index::Index; - use parquet::format::FileMetaData; + use parquet::arrow::arrow_reader::ArrowReaderOptions; + use parquet::file::metadata::{ + KeyValue, PageIndexPolicy, ParquetColumnIndex, ParquetMetaData, + ParquetOffsetIndex, + }; + use parquet::file::page_index::column_index::ColumnIndexMetaData; use tokio::fs::File; enum ForceViews { @@ -180,8 +182,8 @@ mod tests { let c2: ArrayRef = Arc::new(Int64Array::from(vec![Some(1), Some(2), None])); - let batch1 = RecordBatch::try_from_iter(vec![("c1", c1.clone())]).unwrap(); - let batch2 = RecordBatch::try_from_iter(vec![("c2", c2)]).unwrap(); + let batch1 = RecordBatch::try_from_iter(vec![("c1", c1.clone())])?; + let batch2 = RecordBatch::try_from_iter(vec![("c2", c2)])?; let store = Arc::new(LocalFileSystem::new()) as _; let (meta, _files) = store_parquet(vec![batch1, batch2], false).await?; @@ -193,10 +195,14 @@ mod tests { ForceViews::No => false, }; let format = ParquetFormat::default().with_force_view_types(force_views); - let schema = format.infer_schema(&ctx, &store, &meta).await.unwrap(); + let schema = format.infer_schema(&ctx, &store, &meta).await?; - let stats = - fetch_statistics(store.as_ref(), schema.clone(), &meta[0], None).await?; + let file_metadata_cache = + ctx.runtime_env().cache_manager.get_file_metadata_cache(); + let stats = DFParquetMetadata::new(&store, &meta[0]) + .with_file_metadata_cache(Some(Arc::clone(&file_metadata_cache))) + 
.fetch_statistics(&schema) + .await?; assert_eq!(stats.num_rows, Precision::Exact(3)); let c1_stats = &stats.column_statistics[0]; @@ -204,7 +210,11 @@ mod tests { assert_eq!(c1_stats.null_count, Precision::Exact(1)); assert_eq!(c2_stats.null_count, Precision::Exact(3)); - let stats = fetch_statistics(store.as_ref(), schema, &meta[1], None).await?; + let stats = DFParquetMetadata::new(&store, &meta[1]) + .with_file_metadata_cache(Some(Arc::clone(&file_metadata_cache))) + .fetch_statistics(&schema) + .await?; + assert_eq!(stats.num_rows, Precision::Exact(3)); let c1_stats = &stats.column_statistics[0]; let c2_stats = &stats.column_statistics[1]; @@ -238,11 +248,9 @@ mod tests { let c2: ArrayRef = Arc::new(Int64Array::from(vec![Some(1), Some(2), None])); let batch1 = - RecordBatch::try_from_iter(vec![("a", c1.clone()), ("b", c1.clone())]) - .unwrap(); + RecordBatch::try_from_iter(vec![("a", c1.clone()), ("b", c1.clone())])?; let batch2 = - RecordBatch::try_from_iter(vec![("c", c2.clone()), ("d", c2.clone())]) - .unwrap(); + RecordBatch::try_from_iter(vec![("c", c2.clone()), ("d", c2.clone())])?; let store = Arc::new(LocalFileSystem::new()) as _; let (meta, _files) = store_parquet(vec![batch1, batch2], false).await?; @@ -250,7 +258,7 @@ mod tests { let session = SessionContext::new(); let ctx = session.state(); let format = ParquetFormat::default(); - let schema = format.infer_schema(&ctx, &store, &meta).await.unwrap(); + let schema = format.infer_schema(&ctx, &store, &meta).await?; let order: Vec<_> = ["a", "b", "c", "d"] .into_iter() @@ -303,15 +311,15 @@ mod tests { _payload: PutPayload, _opts: PutOptions, ) -> object_store::Result { - Err(object_store::Error::NotImplemented) + unimplemented!() } async fn put_multipart_opts( &self, _location: &Path, - _opts: PutMultipartOpts, + _opts: PutMultipartOptions, ) -> object_store::Result> { - Err(object_store::Error::NotImplemented) + unimplemented!() } async fn get_opts( @@ -323,40 +331,34 @@ mod tests { 
self.inner.get_opts(location, options).await } - async fn head(&self, _location: &Path) -> object_store::Result { - Err(object_store::Error::NotImplemented) - } - - async fn delete(&self, _location: &Path) -> object_store::Result<()> { - Err(object_store::Error::NotImplemented) + fn delete_stream( + &self, + _locations: BoxStream<'static, object_store::Result>, + ) -> BoxStream<'static, object_store::Result> { + unimplemented!() } fn list( &self, _prefix: Option<&Path>, ) -> BoxStream<'static, object_store::Result> { - Box::pin(futures::stream::once(async { - Err(object_store::Error::NotImplemented) - })) + unimplemented!() } async fn list_with_delimiter( &self, _prefix: Option<&Path>, ) -> object_store::Result { - Err(object_store::Error::NotImplemented) + unimplemented!() } - async fn copy(&self, _from: &Path, _to: &Path) -> object_store::Result<()> { - Err(object_store::Error::NotImplemented) - } - - async fn copy_if_not_exists( + async fn copy_opts( &self, _from: &Path, _to: &Path, + _options: CopyOptions, ) -> object_store::Result<()> { - Err(object_store::Error::NotImplemented) + unimplemented!() } } @@ -366,24 +368,42 @@ mod tests { let c2: ArrayRef = Arc::new(Int64Array::from(vec![Some(1), Some(2), None])); - let batch1 = RecordBatch::try_from_iter(vec![("c1", c1.clone())]).unwrap(); - let batch2 = RecordBatch::try_from_iter(vec![("c2", c2)]).unwrap(); + let batch1 = RecordBatch::try_from_iter(vec![("c1", c1.clone())])?; + let batch2 = RecordBatch::try_from_iter(vec![("c2", c2)])?; let store = Arc::new(RequestCountingObjectStore::new(Arc::new( LocalFileSystem::new(), ))); let (meta, _files) = store_parquet(vec![batch1, batch2], false).await?; + let session = SessionContext::new(); + let ctx = session.state(); + // Use a size hint larger than the parquet footer but smaller than the actual metadata, requiring a second fetch // for the remaining metadata - fetch_parquet_metadata(store.as_ref() as &dyn ObjectStore, &meta[0], Some(9)) - .await - .expect("error 
reading metadata with hint"); - + let file_metadata_cache = + ctx.runtime_env().cache_manager.get_file_metadata_cache(); + let df_meta = DFParquetMetadata::new(store.as_ref(), &meta[0]) + .with_metadata_size_hint(Some(9)); + df_meta.fetch_metadata().await?; assert_eq!(store.request_count(), 2); - let session = SessionContext::new(); - let ctx = session.state(); + let df_meta = + df_meta.with_file_metadata_cache(Some(Arc::clone(&file_metadata_cache))); + + // Increases by 3 because cache has no entries yet + df_meta.fetch_metadata().await?; + assert_eq!(store.request_count(), 5); + + // No increase because cache has an entry + df_meta.fetch_metadata().await?; + assert_eq!(store.request_count(), 5); + + // Increase by 2 because `get_file_metadata_cache()` is None + let df_meta = df_meta.with_file_metadata_cache(None); + df_meta.fetch_metadata().await?; + assert_eq!(store.request_count(), 7); + let force_views = match force_views { ForceViews::Yes => true, ForceViews::No => false, @@ -391,14 +411,18 @@ mod tests { let format = ParquetFormat::default() .with_metadata_size_hint(Some(9)) .with_force_view_types(force_views); - let schema = format - .infer_schema(&ctx, &store.upcast(), &meta) - .await - .unwrap(); - - let stats = - fetch_statistics(store.upcast().as_ref(), schema.clone(), &meta[0], Some(9)) - .await?; + // Increase by 3, partial cache being used. + let _schema = format.infer_schema(&ctx, &store.upcast(), &meta).await?; + assert_eq!(store.request_count(), 10); + // No increase, full cache being used. 
+ let schema = format.infer_schema(&ctx, &store.upcast(), &meta).await?; + assert_eq!(store.request_count(), 10); + + // No increase, cache being used + let df_meta = + df_meta.with_file_metadata_cache(Some(Arc::clone(&file_metadata_cache))); + let stats = df_meta.fetch_statistics(&schema).await?; + assert_eq!(store.request_count(), 10); assert_eq!(stats.num_rows, Precision::Exact(3)); let c1_stats = &stats.column_statistics[0]; @@ -412,28 +436,46 @@ mod tests { // Use the file size as the hint so we can get the full metadata from the first fetch let size_hint = meta[0].size as usize; + let df_meta = DFParquetMetadata::new(store.as_ref(), &meta[0]) + .with_metadata_size_hint(Some(size_hint)); - fetch_parquet_metadata(store.upcast().as_ref(), &meta[0], Some(size_hint)) - .await - .expect("error reading metadata with hint"); - + df_meta.fetch_metadata().await?; // ensure the requests were coalesced into a single request assert_eq!(store.request_count(), 1); + let session = SessionContext::new(); + let ctx = session.state(); + let file_metadata_cache = + ctx.runtime_env().cache_manager.get_file_metadata_cache(); + let df_meta = + df_meta.with_file_metadata_cache(Some(Arc::clone(&file_metadata_cache))); + // Increases by 1 because cache has no entries yet and new session context + df_meta.fetch_metadata().await?; + assert_eq!(store.request_count(), 2); + + // No increase because cache has an entry + df_meta.fetch_metadata().await?; + assert_eq!(store.request_count(), 2); + + // Increase by 1 because `get_file_metadata_cache` is None + let df_meta = df_meta.with_file_metadata_cache(None); + df_meta.fetch_metadata().await?; + assert_eq!(store.request_count(), 3); + let format = ParquetFormat::default() .with_metadata_size_hint(Some(size_hint)) .with_force_view_types(force_views); - let schema = format - .infer_schema(&ctx, &store.upcast(), &meta) - .await - .unwrap(); - let stats = fetch_statistics( - store.upcast().as_ref(), - schema.clone(), - &meta[0], - 
Some(size_hint), - ) - .await?; + // Increase by 1, partial cache being used. + let _schema = format.infer_schema(&ctx, &store.upcast(), &meta).await?; + assert_eq!(store.request_count(), 4); + // No increase, full cache being used. + let schema = format.infer_schema(&ctx, &store.upcast(), &meta).await?; + assert_eq!(store.request_count(), 4); + // No increase, cache being used + let df_meta = + df_meta.with_file_metadata_cache(Some(Arc::clone(&file_metadata_cache))); + let stats = df_meta.fetch_statistics(&schema).await?; + assert_eq!(store.request_count(), 4); assert_eq!(stats.num_rows, Precision::Exact(3)); let c1_stats = &stats.column_statistics[0]; @@ -445,13 +487,18 @@ mod tests { LocalFileSystem::new(), ))); - // Use the a size hint larger than the file size to make sure we don't panic + // Use a size hint larger than the file size to make sure we don't panic let size_hint = (meta[0].size + 100) as usize; + let df_meta = DFParquetMetadata::new(store.as_ref(), &meta[0]) + .with_metadata_size_hint(Some(size_hint)); - fetch_parquet_metadata(store.upcast().as_ref(), &meta[0], Some(size_hint)) - .await - .expect("error reading metadata with hint"); + df_meta.fetch_metadata().await?; + assert_eq!(store.request_count(), 1); + // No increase because cache has an entry + let df_meta = + df_meta.with_file_metadata_cache(Some(Arc::clone(&file_metadata_cache))); + df_meta.fetch_metadata().await?; assert_eq!(store.request_count(), 1); Ok(()) @@ -470,25 +517,47 @@ mod tests { // Data for column c_dic: ["a", "b", "c", "d"] let values = StringArray::from_iter_values(["a", "b", "c", "d"]); let keys = Int32Array::from_iter_values([0, 1, 2, 3]); - let dic_array = - DictionaryArray::::try_new(keys, Arc::new(values)).unwrap(); + let dic_array = DictionaryArray::::try_new(keys, Arc::new(values))?; let c_dic: ArrayRef = Arc::new(dic_array); - let batch1 = RecordBatch::try_from_iter(vec![("c_dic", c_dic)]).unwrap(); + // Data for column string_truncation: ["a".repeat(128), null, 
"b".repeat(128), null] + let string_truncation: ArrayRef = Arc::new(StringArray::from(vec![ + Some("a".repeat(128)), + None, + Some("b".repeat(128)), + None, + ])); + + let batch1 = RecordBatch::try_from_iter(vec![ + ("c_dic", c_dic), + ("string_truncation", string_truncation), + ])?; // Use store_parquet to write each batch to its own file // . batch1 written into first file and includes: // - column c_dic that has 4 rows with no null. Stats min and max of dictionary column is available. - let store = Arc::new(LocalFileSystem::new()) as _; + // - column string_truncation that has 4 rows with 2 nulls. Stats min and max of string column is available but not exact. + let store = Arc::new(RequestCountingObjectStore::new(Arc::new( + LocalFileSystem::new(), + ))); let (files, _file_names) = store_parquet(vec![batch1], false).await?; let state = SessionContext::new().state(); - let format = ParquetFormat::default(); - let schema = format.infer_schema(&state, &store, &files).await.unwrap(); - - // Fetch statistics for first file - let pq_meta = fetch_parquet_metadata(store.as_ref(), &files[0], None).await?; - let stats = statistics_from_parquet_meta_calc(&pq_meta, schema.clone())?; + // Make metadata size hint None to keep original behavior + let format = ParquetFormat::default().with_metadata_size_hint(None); + let _schema = format.infer_schema(&state, &store.upcast(), &files).await?; + assert_eq!(store.request_count(), 3); + // No increase, cache being used. 
+ let schema = format.infer_schema(&state, &store.upcast(), &files).await?; + assert_eq!(store.request_count(), 3); + + // No increase in request count because cache is not empty + let file_metadata_cache = + state.runtime_env().cache_manager.get_file_metadata_cache(); + let stats = DFParquetMetadata::new(store.as_ref(), &files[0]) + .with_file_metadata_cache(Some(Arc::clone(&file_metadata_cache))) + .fetch_statistics(&schema) + .await?; assert_eq!(stats.num_rows, Precision::Exact(4)); // column c_dic @@ -504,6 +573,19 @@ mod tests { Precision::Exact(Utf8(Some("a".into()))) ); + // column string_truncation + let string_truncation_stats = &stats.column_statistics[1]; + + assert_eq!(string_truncation_stats.null_count, Precision::Exact(2)); + assert_eq!( + string_truncation_stats.max_value, + Precision::Inexact(ScalarValue::Utf8View(Some("b".repeat(63) + "c"))) + ); + assert_eq!( + string_truncation_stats.min_value, + Precision::Inexact(ScalarValue::Utf8View(Some("a".repeat(64)))) + ); + Ok(()) } @@ -513,18 +595,20 @@ mod tests { // Data for column c1: ["Foo", null, "bar"] let c1: ArrayRef = Arc::new(StringArray::from(vec![Some("Foo"), None, Some("bar")])); - let batch1 = RecordBatch::try_from_iter(vec![("c1", c1.clone())]).unwrap(); + let batch1 = RecordBatch::try_from_iter(vec![("c1", c1.clone())])?; // Data for column c2: [1, 2, null] let c2: ArrayRef = Arc::new(Int64Array::from(vec![Some(1), Some(2), None])); - let batch2 = RecordBatch::try_from_iter(vec![("c2", c2)]).unwrap(); + let batch2 = RecordBatch::try_from_iter(vec![("c2", c2)])?; // Use store_parquet to write each batch to its own file // . batch1 written into first file and includes: // - column c1 that has 3 rows with one null. Stats min and max of string column is missing for this test even the column has values // . batch2 written into second file and includes: // - column c2 that has 3 rows with one null. 
Stats min and max of int are available and 1 and 2 respectively - let store = Arc::new(LocalFileSystem::new()) as _; + let store = Arc::new(RequestCountingObjectStore::new(Arc::new( + LocalFileSystem::new(), + ))); let (files, _file_names) = store_parquet(vec![batch1, batch2], false).await?; let force_views = match force_views { @@ -534,8 +618,11 @@ mod tests { let mut state = SessionContext::new().state(); state = set_view_state(state, force_views); - let format = ParquetFormat::default().with_force_view_types(force_views); - let schema = format.infer_schema(&state, &store, &files).await.unwrap(); + let format = ParquetFormat::default() + .with_force_view_types(force_views) + .with_metadata_size_hint(None); + let schema = format.infer_schema(&state, &store.upcast(), &files).await?; + assert_eq!(store.request_count(), 6); let null_i64 = ScalarValue::Int64(None); let null_utf8 = if force_views { @@ -544,9 +631,14 @@ mod tests { Utf8(None) }; - // Fetch statistics for first file - let pq_meta = fetch_parquet_metadata(store.as_ref(), &files[0], None).await?; - let stats = statistics_from_parquet_meta_calc(&pq_meta, schema.clone())?; + // No increase in request count because cache is not empty + let file_metadata_cache = + state.runtime_env().cache_manager.get_file_metadata_cache(); + let stats = DFParquetMetadata::new(store.as_ref(), &files[0]) + .with_file_metadata_cache(Some(Arc::clone(&file_metadata_cache))) + .fetch_statistics(&schema) + .await?; + assert_eq!(store.request_count(), 6); assert_eq!(stats.num_rows, Precision::Exact(3)); // column c1 let c1_stats = &stats.column_statistics[0]; @@ -570,9 +662,12 @@ mod tests { assert_eq!(c2_stats.max_value, Precision::Exact(null_i64.clone())); assert_eq!(c2_stats.min_value, Precision::Exact(null_i64.clone())); - // Fetch statistics for second file - let pq_meta = fetch_parquet_metadata(store.as_ref(), &files[1], None).await?; - let stats = statistics_from_parquet_meta_calc(&pq_meta, schema.clone())?; + // No increase 
in request count because cache is not empty + let stats = DFParquetMetadata::new(store.as_ref(), &files[1]) + .with_file_metadata_cache(Some(Arc::clone(&file_metadata_cache))) + .fetch_statistics(&schema) + .await?; + assert_eq!(store.request_count(), 6); assert_eq!(stats.num_rows, Precision::Exact(3)); // column c1: missing from the file so the table treats all 3 rows as null let c1_stats = &stats.column_statistics[0]; @@ -626,7 +721,7 @@ mod tests { // TODO correct byte size: https://github.com/apache/datafusion/issues/14936 assert_eq!( exec.partition_statistics(None)?.total_byte_size, - Precision::Exact(671) + Precision::Absent, ); Ok(()) @@ -672,10 +767,9 @@ mod tests { exec.partition_statistics(None)?.num_rows, Precision::Exact(8) ); - // TODO correct byte size: https://github.com/apache/datafusion/issues/14936 assert_eq!( exec.partition_statistics(None)?.total_byte_size, - Precision::Exact(671) + Precision::Absent, ); let batches = collect(exec, task_ctx).await?; assert_eq!(1, batches.len()); @@ -718,7 +812,7 @@ mod tests { .schema() .fields() .iter() - .map(|f| format!("{}: {:?}", f.name(), f.data_type())) + .map(|f| format!("{}: {}", f.name(), f.data_type())) .collect(); let y = x.join("\n"); assert_eq!(expected, y); @@ -744,7 +838,7 @@ mod tests { double_col: Float64\n\ date_string_col: Binary\n\ string_col: Binary\n\ - timestamp_col: Timestamp(Nanosecond, None)"; + timestamp_col: Timestamp(ns)"; _run_read_alltypes_plain_parquet(ForceViews::No, no_views).await?; let with_views = "id: Int32\n\ @@ -757,7 +851,7 @@ mod tests { double_col: Float64\n\ date_string_col: BinaryView\n\ string_col: BinaryView\n\ - timestamp_col: Timestamp(Nanosecond, None)"; + timestamp_col: Timestamp(ns)"; _run_read_alltypes_plain_parquet(ForceViews::Yes, with_views).await?; Ok(()) @@ -833,7 +927,10 @@ mod tests { values.push(array.value(i)); } - assert_eq!("[1235865600000000000, 1235865660000000000, 1238544000000000000, 1238544060000000000, 1233446400000000000, 
1233446460000000000, 1230768000000000000, 1230768060000000000]", format!("{values:?}")); + assert_eq!( + "[1235865600000000000, 1235865660000000000, 1238544000000000000, 1238544060000000000, 1233446400000000000, 1233446460000000000, 1230768000000000000, 1230768060000000000]", + format!("{values:?}") + ); Ok(()) } @@ -1002,22 +1099,21 @@ mod tests { async fn test_read_parquet_page_index() -> Result<()> { let testdata = datafusion_common::test_util::parquet_test_data(); let path = format!("{testdata}/alltypes_tiny_pages.parquet"); - let file = File::open(path).await.unwrap(); - let options = ArrowReaderOptions::new().with_page_index(true); + let file = File::open(path).await?; + let options = + ArrowReaderOptions::new().with_page_index_policy(PageIndexPolicy::Required); let builder = ParquetRecordBatchStreamBuilder::new_with_options(file, options.clone()) - .await - .unwrap() + .await? .metadata() .clone(); check_page_index_validation(builder.column_index(), builder.offset_index()); let path = format!("{testdata}/alltypes_tiny_pages_plain.parquet"); - let file = File::open(path).await.unwrap(); + let file = File::open(path).await?; let builder = ParquetRecordBatchStreamBuilder::new_with_options(file, options) - .await - .unwrap() + .await? 
.metadata() .clone(); check_page_index_validation(builder.column_index(), builder.offset_index()); @@ -1051,18 +1147,14 @@ mod tests { // 325 pages in int_col assert_eq!(int_col_offset.len(), 325); - match int_col_index { - Index::INT32(index) => { - assert_eq!(index.indexes.len(), 325); - for min_max in index.clone().indexes { - assert!(min_max.min.is_some()); - assert!(min_max.max.is_some()); - assert!(min_max.null_count.is_some()); - } - } - _ => { - error!("fail to read page index.") - } + let ColumnIndexMetaData::INT32(index) = int_col_index else { + panic!("fail to read page index.") + }; + assert_eq!(index.min_values().len(), 325); + assert_eq!(index.max_values().len(), 325); + // all values are non null + for idx in 0..325 { + assert_eq!(index.null_count(idx), Some(0)); } } @@ -1099,7 +1191,7 @@ mod tests { /// Test that 0-byte files don't break while reading #[tokio::test] async fn test_read_empty_parquet() -> Result<()> { - let tmp_dir = tempfile::TempDir::new().unwrap(); + let tmp_dir = tempfile::TempDir::new()?; let path = format!("{}/empty.parquet", tmp_dir.path().to_string_lossy()); File::create(&path).await?; @@ -1112,10 +1204,10 @@ mod tests { let result = df.collect().await?; - assert_snapshot!(batches_to_string(&result), @r###" - ++ - ++ - "###); + assert_snapshot!(batches_to_string(&result), @r" + ++ + ++ + "); Ok(()) } @@ -1123,12 +1215,10 @@ mod tests { /// Test that 0-byte files don't break while reading #[tokio::test] async fn test_read_partitioned_empty_parquet() -> Result<()> { - let tmp_dir = tempfile::TempDir::new().unwrap(); + let tmp_dir = tempfile::TempDir::new()?; let partition_dir = tmp_dir.path().join("col1=a"); - std::fs::create_dir(&partition_dir).unwrap(); - File::create(partition_dir.join("empty.parquet")) - .await - .unwrap(); + std::fs::create_dir(&partition_dir)?; + File::create(partition_dir.join("empty.parquet")).await?; let ctx = SessionContext::new(); @@ -1143,10 +1233,10 @@ mod tests { let result = df.collect().await?; - 
assert_snapshot!(batches_to_string(&result), @r###" - ++ - ++ - "###); + assert_snapshot!(batches_to_string(&result), @r" + ++ + ++ + "); Ok(()) } @@ -1246,6 +1336,56 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_write_empty_recordbatch_creates_file() -> Result<()> { + let empty_record_batch = RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)])), + vec![Arc::new(Int32Array::from(Vec::::new()))], + ) + .expect("Failed to create empty RecordBatch"); + + let tmp_dir = tempfile::TempDir::new()?; + let path = format!("{}/empty2.parquet", tmp_dir.path().to_string_lossy()); + + let ctx = SessionContext::new(); + let df = ctx.read_batch(empty_record_batch.clone())?; + df.write_parquet(&path, crate::dataframe::DataFrameWriteOptions::new(), None) + .await?; + assert!(std::path::Path::new(&path).exists()); + + let stream = ctx + .read_parquet(&path, ParquetReadOptions::new()) + .await? + .execute_stream() + .await?; + assert_eq!(stream.schema(), empty_record_batch.schema()); + let results = stream.collect::>().await; + assert_eq!(results.len(), 0); + Ok(()) + } + + #[tokio::test] + async fn test_write_empty_parquet_from_sql() -> Result<()> { + let ctx = SessionContext::new(); + + let tmp_dir = tempfile::TempDir::new()?; + let path = format!("{}/empty_sql.parquet", tmp_dir.path().to_string_lossy()); + let df = ctx.sql("SELECT CAST(1 AS INT) AS id LIMIT 0").await?; + df.write_parquet(&path, crate::dataframe::DataFrameWriteOptions::new(), None) + .await?; + // Expected the file to exist + assert!(std::path::Path::new(&path).exists()); + let read_df = ctx.read_parquet(&path, ParquetReadOptions::new()).await?; + let stream = read_df.execute_stream().await?; + assert_eq!(stream.schema().fields().len(), 1); + assert_eq!(stream.schema().field(0).name(), "id"); + + let results: Vec<_> = stream.collect().await; + assert_eq!(results.len(), 0); + + Ok(()) + } + #[tokio::test] async fn parquet_sink_write_insert_schema_into_metadata() -> 
Result<()> { // expected kv metadata without schema @@ -1405,6 +1545,7 @@ mod tests { insert_op: InsertOp::Overwrite, keep_partition_by_columns: false, file_extension: "parquet".into(), + file_output_mode: FileOutputMode::Automatic, }; let parquet_sink = Arc::new(ParquetSink::new( file_sink_config, @@ -1421,7 +1562,7 @@ mod tests { // create data let col_a: ArrayRef = Arc::new(StringArray::from(vec!["foo", "bar"])); let col_b: ArrayRef = Arc::new(StringArray::from(vec!["baz", "baz"])); - let batch = RecordBatch::try_from_iter(vec![("a", col_a), ("b", col_b)]).unwrap(); + let batch = RecordBatch::try_from_iter(vec![("a", col_a), ("b", col_b)])?; // write stream FileSink::write_all( @@ -1437,7 +1578,7 @@ mod tests { Ok(parquet_sink) } - fn get_written(parquet_sink: Arc) -> Result<(Path, FileMetaData)> { + fn get_written(parquet_sink: Arc) -> Result<(Path, ParquetMetaData)> { let mut written = parquet_sink.written(); let written = written.drain(); assert_eq!( @@ -1447,28 +1588,33 @@ mod tests { written.len() ); - let (path, file_metadata) = written.take(1).next().unwrap(); - Ok((path, file_metadata)) + let (path, parquet_meta_data) = written.take(1).next().unwrap(); + Ok((path, parquet_meta_data)) } - fn assert_file_metadata(file_metadata: FileMetaData, expected_kv: &Vec) { - let FileMetaData { - num_rows, - schema, - key_value_metadata, - .. 
- } = file_metadata; - assert_eq!(num_rows, 2, "file metadata to have 2 rows"); + fn assert_file_metadata( + parquet_meta_data: ParquetMetaData, + expected_kv: &Vec, + ) { + let file_metadata = parquet_meta_data.file_metadata(); + let schema_descr = file_metadata.schema_descr(); + assert_eq!(file_metadata.num_rows(), 2, "file metadata to have 2 rows"); assert!( - schema.iter().any(|col_schema| col_schema.name == "a"), + schema_descr + .columns() + .iter() + .any(|col_schema| col_schema.name() == "a"), "output file metadata should contain col a" ); assert!( - schema.iter().any(|col_schema| col_schema.name == "b"), + schema_descr + .columns() + .iter() + .any(|col_schema| col_schema.name() == "b"), "output file metadata should contain col b" ); - let mut key_value_metadata = key_value_metadata.unwrap(); + let mut key_value_metadata = file_metadata.key_value_metadata().unwrap().clone(); key_value_metadata.sort_by(|a, b| a.key.cmp(&b.key)); assert_eq!(&key_value_metadata, expected_kv); } @@ -1491,6 +1637,7 @@ mod tests { insert_op: InsertOp::Overwrite, keep_partition_by_columns: false, file_extension: "parquet".into(), + file_output_mode: FileOutputMode::Automatic, }; let parquet_sink = Arc::new(ParquetSink::new( file_sink_config, @@ -1500,7 +1647,7 @@ mod tests { // create data with 2 partitions let col_a: ArrayRef = Arc::new(StringArray::from(vec!["foo", "bar"])); let col_b: ArrayRef = Arc::new(StringArray::from(vec!["baz", "baz"])); - let batch = RecordBatch::try_from_iter(vec![("a", col_a), ("b", col_b)]).unwrap(); + let batch = RecordBatch::try_from_iter(vec![("a", col_a), ("b", col_b)])?; // write stream FileSink::write_all( @@ -1525,13 +1672,11 @@ mod tests { // check the file metadata includes partitions let mut expected_partitions = std::collections::HashSet::from(["a=foo", "a=bar"]); - for ( - path, - FileMetaData { - num_rows, schema, .. 
- }, - ) in written.take(2) - { + for (path, parquet_metadata) in written.take(2) { + let file_metadata = parquet_metadata.file_metadata(); + let schema = file_metadata.schema_descr(); + let num_rows = file_metadata.num_rows(); + let path_parts = path.parts().collect::>(); assert_eq!(path_parts.len(), 2, "should have path prefix"); @@ -1544,11 +1689,17 @@ mod tests { assert_eq!(num_rows, 1, "file metadata to have 1 row"); assert!( - !schema.iter().any(|col_schema| col_schema.name == "a"), + !schema + .columns() + .iter() + .any(|col_schema| col_schema.name() == "a"), "output file metadata will not contain partitioned col a" ); assert!( - schema.iter().any(|col_schema| col_schema.name == "b"), + schema + .columns() + .iter() + .any(|col_schema| col_schema.name() == "b"), "output file metadata should contain col b" ); } @@ -1577,6 +1728,7 @@ mod tests { insert_op: InsertOp::Overwrite, keep_partition_by_columns: false, file_extension: "parquet".into(), + file_output_mode: FileOutputMode::Automatic, }; let parquet_sink = Arc::new(ParquetSink::new( file_sink_config, @@ -1593,8 +1745,7 @@ mod tests { // create data let col_a: ArrayRef = Arc::new(StringArray::from(vec!["foo", "bar"])); let col_b: ArrayRef = Arc::new(StringArray::from(vec!["baz", "baz"])); - let batch = - RecordBatch::try_from_iter(vec![("a", col_a), ("b", col_b)]).unwrap(); + let batch = RecordBatch::try_from_iter(vec![("a", col_a), ("b", col_b)])?; // create task context let task_context = build_ctx(object_store_url.as_ref()); @@ -1662,43 +1813,4 @@ mod tests { Ok(()) } - - /// Creates an bounded stream for testing purposes. 
- fn bounded_stream( - batch: RecordBatch, - limit: usize, - ) -> datafusion_execution::SendableRecordBatchStream { - Box::pin(BoundedStream { - count: 0, - limit, - batch, - }) - } - - struct BoundedStream { - limit: usize, - count: usize, - batch: RecordBatch, - } - - impl Stream for BoundedStream { - type Item = Result; - - fn poll_next( - mut self: Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll> { - if self.count >= self.limit { - return Poll::Ready(None); - } - self.count += 1; - Poll::Ready(Some(Ok(self.batch.clone()))) - } - } - - impl RecordBatchStream for BoundedStream { - fn schema(&self) -> SchemaRef { - self.batch.schema() - } - } } diff --git a/datafusion/core/src/datasource/listing/mod.rs b/datafusion/core/src/datasource/listing/mod.rs index a58db55bccb61..85dee3f91cffb 100644 --- a/datafusion/core/src/datasource/listing/mod.rs +++ b/datafusion/core/src/datasource/listing/mod.rs @@ -20,7 +20,9 @@ mod table; pub use datafusion_catalog_listing::helpers; -pub use datafusion_datasource::{ - FileRange, ListingTableUrl, PartitionedFile, PartitionedFileStream, -}; -pub use table::{ListingOptions, ListingTable, ListingTableConfig}; +pub use datafusion_catalog_listing::{ListingOptions, ListingTable, ListingTableConfig}; +// Keep for backwards compatibility until removed +#[expect(deprecated)] +pub use datafusion_datasource::PartitionedFileStream; +pub use datafusion_datasource::{FileRange, ListingTableUrl, PartitionedFile}; +pub use table::ListingTableConfigExt; diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index 3c87d3ee2329c..5dd11739c1f57 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -15,144 +15,42 @@ // specific language governing permissions and limitations // under the License. -//! The table implementation. 
- -use super::helpers::{expr_applicable_for_cols, pruned_partition_list}; -use super::{ListingTableUrl, PartitionedFile}; +use crate::execution::SessionState; +use async_trait::async_trait; +use datafusion_catalog_listing::{ListingOptions, ListingTableConfig}; +use datafusion_common::{config_datafusion_err, internal_datafusion_err}; +use datafusion_session::Session; +use futures::StreamExt; use std::collections::HashMap; -use std::{any::Any, str::FromStr, sync::Arc}; - -use crate::datasource::{ - create_ordering, - file_format::{file_compression_type::FileCompressionType, FileFormat}, - physical_plan::FileSinkConfig, -}; -use crate::execution::context::SessionState; -use datafusion_catalog::TableProvider; -use datafusion_common::{config_err, DataFusionError, Result}; -use datafusion_datasource::file_scan_config::{FileScanConfig, FileScanConfigBuilder}; -use datafusion_datasource::schema_adapter::DefaultSchemaAdapterFactory; -use datafusion_execution::config::SessionConfig; -use datafusion_expr::dml::InsertOp; -use datafusion_expr::{Expr, TableProviderFilterPushDown}; -use datafusion_expr::{SortExpr, TableType}; -use datafusion_physical_plan::empty::EmptyExec; -use datafusion_physical_plan::{ExecutionPlan, Statistics}; - -use arrow::datatypes::{DataType, Field, Schema, SchemaBuilder, SchemaRef}; -use datafusion_common::{ - config_datafusion_err, internal_err, plan_err, project_schema, Constraints, SchemaExt, -}; -use datafusion_execution::cache::{ - cache_manager::FileStatisticsCache, cache_unit::DefaultFileStatisticsCache, -}; -use datafusion_physical_expr::{LexOrdering, PhysicalSortRequirement}; -use async_trait::async_trait; -use datafusion_catalog::Session; -use datafusion_common::stats::Precision; -use datafusion_datasource::compute_all_files_statistics; -use datafusion_datasource::file_groups::FileGroup; -use datafusion_physical_expr_common::sort_expr::LexRequirement; -use futures::{future, stream, Stream, StreamExt, TryStreamExt}; -use itertools::Itertools; 
-use object_store::ObjectStore; - -/// Configuration for creating a [`ListingTable`] -/// +/// Extension trait for [`ListingTableConfig`] that supports inferring schemas /// -#[derive(Debug, Clone)] -pub struct ListingTableConfig { - /// Paths on the `ObjectStore` for creating `ListingTable`. - /// They should share the same schema and object store. - pub table_paths: Vec, - /// Optional `SchemaRef` for the to be created `ListingTable`. - /// - /// See details on [`ListingTableConfig::with_schema`] - pub file_schema: Option, - /// Optional [`ListingOptions`] for the to be created [`ListingTable`]. - /// - /// See details on [`ListingTableConfig::with_listing_options`] - pub options: Option, -} - -impl ListingTableConfig { - /// Creates new [`ListingTableConfig`] for reading the specified URL - pub fn new(table_path: ListingTableUrl) -> Self { - let table_paths = vec![table_path]; - Self { - table_paths, - file_schema: None, - options: None, - } - } - - /// Creates new [`ListingTableConfig`] with multiple table paths. - /// - /// See [`Self::infer_options`] for details on what happens with multiple paths - pub fn new_with_multi_paths(table_paths: Vec) -> Self { - Self { - table_paths, - file_schema: None, - options: None, - } - } - /// Set the `schema` for the overall [`ListingTable`] - /// - /// [`ListingTable`] will automatically coerce, when possible, the schema - /// for individual files to match this schema. - /// - /// If a schema is not provided, it is inferred using - /// [`Self::infer_schema`]. - /// - /// If the schema is provided, it must contain only the fields in the file - /// without the table partitioning columns. - pub fn with_schema(self, schema: SchemaRef) -> Self { - Self { - table_paths: self.table_paths, - file_schema: Some(schema), - options: self.options, - } - } - - /// Add `listing_options` to [`ListingTableConfig`] - /// - /// If not provided, format and other options are inferred via - /// [`Self::infer_options`]. 
- pub fn with_listing_options(self, listing_options: ListingOptions) -> Self { - Self { - table_paths: self.table_paths, - file_schema: self.file_schema, - options: Some(listing_options), - } - } - - /// Returns a tuple of `(file_extension, optional compression_extension)` - /// - /// For example a path ending with blah.test.csv.gz returns `("csv", Some("gz"))` - /// For example a path ending with blah.test.csv returns `("csv", None)` - fn infer_file_extension_and_compression_type( - path: &str, - ) -> Result<(String, Option)> { - let mut exts = path.rsplit('.'); - - let splitted = exts.next().unwrap_or(""); - - let file_compression_type = FileCompressionType::from_str(splitted) - .unwrap_or(FileCompressionType::UNCOMPRESSED); - - if file_compression_type.is_compressed() { - let splitted2 = exts.next().unwrap_or(""); - Ok((splitted2.to_string(), Some(splitted.to_string()))) - } else { - Ok((splitted.to_string(), None)) - } - } - +/// This trait exists because the following inference methods only +/// work for [`SessionState`] implementations of [`Session`]. +/// See [`ListingTableConfig`] for the remaining inference methods. +#[async_trait] +pub trait ListingTableConfigExt { /// Infer `ListingOptions` based on `table_path` and file suffix. /// /// The format is inferred based on the first `table_path`. - pub async fn infer_options(self, state: &dyn Session) -> Result { + async fn infer_options( + self, + state: &dyn Session, + ) -> datafusion_common::Result; + + /// Convenience method to call both [`Self::infer_options`] and [`ListingTableConfig::infer_schema`] + async fn infer( + self, + state: &dyn Session, + ) -> datafusion_common::Result; +} + +#[async_trait] +impl ListingTableConfigExt for ListingTableConfig { + async fn infer_options( + self, + state: &dyn Session, + ) -> datafusion_common::Result { let store = if let Some(url) = self.table_paths.first() { state.runtime_env().object_store(url)? } else { @@ -167,7 +65,7 @@ impl ListingTableConfig { .await? 
.next() .await - .ok_or_else(|| DataFusionError::Internal("No files for table".into()))??; + .ok_or_else(|| internal_datafusion_err!("No files for table"))??; let (file_extension, maybe_compression_type) = ListingTableConfig::infer_file_extension_and_compression_type( @@ -199,1099 +97,136 @@ impl ListingTableConfig { .with_target_partitions(state.config().target_partitions()) .with_collect_stat(state.config().collect_statistics()); - Ok(Self { - table_paths: self.table_paths, - file_schema: self.file_schema, - options: Some(listing_options), - }) + Ok(self.with_listing_options(listing_options)) } - /// Infer the [`SchemaRef`] based on `table_path`s. - /// - /// This method infers the table schema using the first `table_path`. - /// See [`ListingOptions::infer_schema`] for more details - /// - /// # Errors - /// * if `self.options` is not set. See [`Self::with_listing_options`] - pub async fn infer_schema(self, state: &dyn Session) -> Result { - match self.options { - Some(options) => { - let schema = if let Some(url) = self.table_paths.first() { - options.infer_schema(state, url).await? - } else { - Arc::new(Schema::empty()) - }; - - Ok(Self { - table_paths: self.table_paths, - file_schema: Some(schema), - options: Some(options), - }) - } - None => internal_err!("No `ListingOptions` set for inferring schema"), - } - } - - /// Convenience method to call both [`Self::infer_options`] and [`Self::infer_schema`] - pub async fn infer(self, state: &dyn Session) -> Result { + async fn infer(self, state: &dyn Session) -> datafusion_common::Result { self.infer_options(state).await?.infer_schema(state).await } - - /// Infer the partition columns from `table_paths`. - /// - /// # Errors - /// * if `self.options` is not set. 
See [`Self::with_listing_options`] - pub async fn infer_partitions_from_path(self, state: &dyn Session) -> Result { - match self.options { - Some(options) => { - let Some(url) = self.table_paths.first() else { - return config_err!("No table path found"); - }; - let partitions = options - .infer_partitions(state, url) - .await? - .into_iter() - .map(|col_name| { - ( - col_name, - DataType::Dictionary( - Box::new(DataType::UInt16), - Box::new(DataType::Utf8), - ), - ) - }) - .collect::>(); - let options = options.with_table_partition_cols(partitions); - Ok(Self { - table_paths: self.table_paths, - file_schema: self.file_schema, - options: Some(options), - }) - } - None => config_err!("No `ListingOptions` set for inferring schema"), - } - } } -/// Options for creating a [`ListingTable`] -#[derive(Clone, Debug)] -pub struct ListingOptions { - /// A suffix on which files should be filtered (leave empty to - /// keep all files on the path) - pub file_extension: String, - /// The file format - pub format: Arc, - /// The expected partition column names in the folder structure. - /// See [Self::with_table_partition_cols] for details - pub table_partition_cols: Vec<(String, DataType)>, - /// Set true to try to guess statistics from the files. - /// This can add a lot of overhead as it will usually require files - /// to be opened and at least partially parsed. - pub collect_stat: bool, - /// Group files to avoid that the number of partitions exceeds - /// this limit - pub target_partitions: usize, - /// Optional pre-known sort order(s). Must be `SortExpr`s. - /// - /// DataFusion may take advantage of this ordering to omit sorts - /// or use more efficient algorithms. Currently sortedness must be - /// provided if it is known by some external mechanism, but may in - /// the future be automatically determined, for example using - /// parquet metadata. 
- /// - /// See - /// - /// NOTE: This attribute stores all equivalent orderings (the outer `Vec`) - /// where each ordering consists of an individual lexicographic - /// ordering (encapsulated by a `Vec`). If there aren't - /// multiple equivalent orderings, the outer `Vec` will have a - /// single element. - pub file_sort_order: Vec>, -} - -impl ListingOptions { - /// Creates an options instance with the given format - /// Default values: - /// - use default file extension filter - /// - no input partition to discover - /// - one target partition - /// - do not collect statistics - pub fn new(format: Arc) -> Self { - Self { - file_extension: format.get_ext(), - format, - table_partition_cols: vec![], - collect_stat: false, - target_partitions: 1, - file_sort_order: vec![], - } - } - - /// Set options from [`SessionConfig`] and returns self. - /// - /// Currently this sets `target_partitions` and `collect_stat` - /// but if more options are added in the future that need to be coordinated - /// they will be synchronized thorugh this method. - pub fn with_session_config_options(mut self, config: &SessionConfig) -> Self { - self = self.with_target_partitions(config.target_partitions()); - self = self.with_collect_stat(config.collect_statistics()); - self - } - - /// Set file extension on [`ListingOptions`] and returns self. - /// - /// # Example - /// ``` - /// # use std::sync::Arc; - /// # use datafusion::prelude::SessionContext; - /// # use datafusion::datasource::{listing::ListingOptions, file_format::parquet::ParquetFormat}; - /// - /// let listing_options = ListingOptions::new(Arc::new( - /// ParquetFormat::default() - /// )) - /// .with_file_extension(".parquet"); - /// - /// assert_eq!(listing_options.file_extension, ".parquet"); - /// ``` - pub fn with_file_extension(mut self, file_extension: impl Into) -> Self { - self.file_extension = file_extension.into(); - self - } - - /// Optionally set file extension on [`ListingOptions`] and returns self. 
- /// - /// If `file_extension` is `None`, the file extension will not be changed - /// - /// # Example - /// ``` - /// # use std::sync::Arc; - /// # use datafusion::prelude::SessionContext; - /// # use datafusion::datasource::{listing::ListingOptions, file_format::parquet::ParquetFormat}; - /// let extension = Some(".parquet"); - /// let listing_options = ListingOptions::new(Arc::new( - /// ParquetFormat::default() - /// )) - /// .with_file_extension_opt(extension); - /// - /// assert_eq!(listing_options.file_extension, ".parquet"); - /// ``` - pub fn with_file_extension_opt(mut self, file_extension: Option) -> Self - where - S: Into, - { - if let Some(file_extension) = file_extension { - self.file_extension = file_extension.into(); - } - self - } - - /// Set `table partition columns` on [`ListingOptions`] and returns self. - /// - /// "partition columns," used to support [Hive Partitioning], are - /// columns added to the data that is read, based on the folder - /// structure where the data resides. - /// - /// For example, give the following files in your filesystem: - /// - /// ```text - /// /mnt/nyctaxi/year=2022/month=01/tripdata.parquet - /// /mnt/nyctaxi/year=2021/month=12/tripdata.parquet - /// /mnt/nyctaxi/year=2021/month=11/tripdata.parquet - /// ``` - /// - /// A [`ListingTable`] created at `/mnt/nyctaxi/` with partition - /// columns "year" and "month" will include new `year` and `month` - /// columns while reading the files. The `year` column would have - /// value `2022` and the `month` column would have value `01` for - /// the rows read from - /// `/mnt/nyctaxi/year=2022/month=01/tripdata.parquet` - /// - ///# Notes - /// - /// - If only one level (e.g. `year` in the example above) is - /// specified, the other levels are ignored but the files are - /// still read. - /// - /// - Files that don't follow this partitioning scheme will be - /// ignored. 
- /// - /// - Since the columns have the same value for all rows read from - /// each individual file (such as dates), they are typically - /// dictionary encoded for efficiency. You may use - /// [`wrap_partition_type_in_dict`] to request a - /// dictionary-encoded type. - /// - /// - The partition columns are solely extracted from the file path. Especially they are NOT part of the parquet files itself. - /// - /// # Example - /// - /// ``` - /// # use std::sync::Arc; - /// # use arrow::datatypes::DataType; - /// # use datafusion::prelude::col; - /// # use datafusion::datasource::{listing::ListingOptions, file_format::parquet::ParquetFormat}; - /// - /// // listing options for files with paths such as `/mnt/data/col_a=x/col_b=y/data.parquet` - /// // `col_a` and `col_b` will be included in the data read from those files - /// let listing_options = ListingOptions::new(Arc::new( - /// ParquetFormat::default() - /// )) - /// .with_table_partition_cols(vec![("col_a".to_string(), DataType::Utf8), - /// ("col_b".to_string(), DataType::Utf8)]); - /// - /// assert_eq!(listing_options.table_partition_cols, vec![("col_a".to_string(), DataType::Utf8), - /// ("col_b".to_string(), DataType::Utf8)]); - /// ``` - /// - /// [Hive Partitioning]: https://docs.cloudera.com/HDPDocuments/HDP2/HDP-2.1.3/bk_system-admin-guide/content/hive_partitioned_tables.html - /// [`wrap_partition_type_in_dict`]: crate::datasource::physical_plan::wrap_partition_type_in_dict - pub fn with_table_partition_cols( - mut self, - table_partition_cols: Vec<(String, DataType)>, - ) -> Self { - self.table_partition_cols = table_partition_cols; - self - } - - /// Set stat collection on [`ListingOptions`] and returns self. 
- /// - /// ``` - /// # use std::sync::Arc; - /// # use datafusion::datasource::{listing::ListingOptions, file_format::parquet::ParquetFormat}; - /// - /// let listing_options = ListingOptions::new(Arc::new( - /// ParquetFormat::default() - /// )) - /// .with_collect_stat(true); - /// - /// assert_eq!(listing_options.collect_stat, true); - /// ``` - pub fn with_collect_stat(mut self, collect_stat: bool) -> Self { - self.collect_stat = collect_stat; - self - } - - /// Set number of target partitions on [`ListingOptions`] and returns self. - /// - /// ``` - /// # use std::sync::Arc; - /// # use datafusion::datasource::{listing::ListingOptions, file_format::parquet::ParquetFormat}; - /// - /// let listing_options = ListingOptions::new(Arc::new( - /// ParquetFormat::default() - /// )) - /// .with_target_partitions(8); - /// - /// assert_eq!(listing_options.target_partitions, 8); - /// ``` - pub fn with_target_partitions(mut self, target_partitions: usize) -> Self { - self.target_partitions = target_partitions; - self - } - - /// Set file sort order on [`ListingOptions`] and returns self. - /// - /// ``` - /// # use std::sync::Arc; - /// # use datafusion::prelude::col; - /// # use datafusion::datasource::{listing::ListingOptions, file_format::parquet::ParquetFormat}; - /// - /// // Tell datafusion that the files are sorted by column "a" - /// let file_sort_order = vec![vec![ - /// col("a").sort(true, true) - /// ]]; - /// - /// let listing_options = ListingOptions::new(Arc::new( - /// ParquetFormat::default() - /// )) - /// .with_file_sort_order(file_sort_order.clone()); - /// - /// assert_eq!(listing_options.file_sort_order, file_sort_order); - /// ``` - pub fn with_file_sort_order(mut self, file_sort_order: Vec>) -> Self { - self.file_sort_order = file_sort_order; - self - } - - /// Infer the schema of the files at the given path on the provided object store. - /// - /// If the table_path contains one or more files (i.e. 
it is a directory / - /// prefix of files) their schema is merged by calling [`FileFormat::infer_schema`] - /// - /// Note: The inferred schema does not include any partitioning columns. - /// - /// This method is called as part of creating a [`ListingTable`]. - pub async fn infer_schema<'a>( - &'a self, - state: &dyn Session, - table_path: &'a ListingTableUrl, - ) -> Result { - let store = state.runtime_env().object_store(table_path)?; - - let files: Vec<_> = table_path - .list_all_files(state, store.as_ref(), &self.file_extension) - .await? - // Empty files cannot affect schema but may throw when trying to read for it - .try_filter(|object_meta| future::ready(object_meta.size > 0)) - .try_collect() - .await?; - - let schema = self.format.infer_schema(state, &store, &files).await?; - - Ok(schema) - } - - /// Infers the partition columns stored in `LOCATION` and compares - /// them with the columns provided in `PARTITIONED BY` to help prevent - /// accidental corrupts of partitioned tables. - /// - /// Allows specifying partial partitions. - pub async fn validate_partitions( - &self, - state: &dyn Session, - table_path: &ListingTableUrl, - ) -> Result<()> { - if self.table_partition_cols.is_empty() { - return Ok(()); - } - - if !table_path.is_collection() { - return plan_err!( - "Can't create a partitioned table backed by a single file, \ - perhaps the URL is missing a trailing slash?" 
- ); - } - - let inferred = self.infer_partitions(state, table_path).await?; - - // no partitioned files found on disk - if inferred.is_empty() { - return Ok(()); - } - - let table_partition_names = self - .table_partition_cols - .iter() - .map(|(col_name, _)| col_name.clone()) - .collect_vec(); - - if inferred.len() < table_partition_names.len() { - return plan_err!( - "Inferred partitions to be {:?}, but got {:?}", - inferred, - table_partition_names - ); - } - - // match prefix to allow creating tables with partial partitions - for (idx, col) in table_partition_names.iter().enumerate() { - if &inferred[idx] != col { - return plan_err!( - "Inferred partitions to be {:?}, but got {:?}", - inferred, - table_partition_names - ); - } - } - - Ok(()) - } - - /// Infer the partitioning at the given path on the provided object store. - /// For performance reasons, it doesn't read all the files on disk - /// and therefore may fail to detect invalid partitioning. - pub(crate) async fn infer_partitions( - &self, - state: &dyn Session, - table_path: &ListingTableUrl, - ) -> Result> { - let store = state.runtime_env().object_store(table_path)?; - - // only use 10 files for inference - // This can fail to detect inconsistent partition keys - // A DFS traversal approach of the store can help here - let files: Vec<_> = table_path - .list_all_files(state, store.as_ref(), &self.file_extension) - .await? 
- .take(10) - .try_collect() - .await?; +#[cfg(test)] +mod tests { + #[cfg(feature = "parquet")] + use crate::datasource::file_format::parquet::ParquetFormat; + use crate::datasource::listing::table::ListingTableConfigExt; + use crate::execution::options::JsonReadOptions; + use crate::prelude::*; + use crate::{ + datasource::{ + DefaultTableSource, MemTable, file_format::csv::CsvFormat, + file_format::json::JsonFormat, provider_as_source, + }, + execution::options::ArrowReadOptions, + test::{ + columns, object_store::ensure_head_concurrency, + object_store::make_test_store_and_state, object_store::register_test_store, + }, + }; + use arrow::{compute::SortOptions, record_batch::RecordBatch}; + use arrow_schema::{DataType, Field, Schema, SchemaRef}; + use datafusion_catalog::TableProvider; + use datafusion_catalog_listing::{ + ListingOptions, ListingTable, ListingTableConfig, SchemaSource, + }; + use datafusion_common::{ + DataFusionError, Result, ScalarValue, assert_contains, + stats::Precision, + test_util::{batches_to_string, datafusion_test_data}, + }; + use datafusion_datasource::ListingTableUrl; + use datafusion_datasource::file_compression_type::FileCompressionType; + use datafusion_datasource::file_format::FileFormat; + use datafusion_expr::dml::InsertOp; + use datafusion_expr::{BinaryExpr, LogicalPlanBuilder, Operator}; + use datafusion_physical_expr::PhysicalSortExpr; + use datafusion_physical_expr::expressions::binary; + use datafusion_physical_expr_common::sort_expr::LexOrdering; + use datafusion_physical_plan::empty::EmptyExec; + use datafusion_physical_plan::{ExecutionPlanProperties, collect}; + use std::collections::HashMap; + use std::io::Write; + use std::sync::Arc; + use tempfile::TempDir; + use url::Url; - let stripped_path_parts = files.iter().map(|file| { - table_path - .strip_prefix(&file.location) - .unwrap() - .collect_vec() - }); - - let partition_keys = stripped_path_parts - .map(|path_parts| { - path_parts - .into_iter() - .rev() - .skip(1) 
// get parents only; skip the file itself - .rev() - .map(|s| s.split('=').take(1).collect()) - .collect_vec() - }) - .collect_vec(); - - match partition_keys.into_iter().all_equal_value() { - Ok(v) => Ok(v), - Err(None) => Ok(vec![]), - Err(Some(diff)) => { - let mut sorted_diff = [diff.0, diff.1]; - sorted_diff.sort(); - plan_err!("Found mixed partition values on disk {:?}", sorted_diff) - } - } + /// Creates a test schema with standard field types used in tests + fn create_test_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Float32, true), + Field::new("c2", DataType::Float64, true), + Field::new("c3", DataType::Boolean, true), + Field::new("c4", DataType::Utf8, true), + ])) + } + + /// Helper function to generate test file paths with given prefix, count, and optional start index + fn generate_test_files(prefix: &str, count: usize) -> Vec { + generate_test_files_with_start(prefix, count, 0) + } + + /// Helper function to generate test file paths with given prefix, count, and start index + fn generate_test_files_with_start( + prefix: &str, + count: usize, + start_index: usize, + ) -> Vec { + (start_index..start_index + count) + .map(|i| format!("{prefix}/file{i}")) + .collect() } -} - -/// Reads data from one or more files as a single table. -/// -/// Implements [`TableProvider`], a DataFusion data source. The files are read -/// using an [`ObjectStore`] instance, for example from local files or objects -/// from AWS S3. -/// -/// # Reading Directories -/// For example, given the `table1` directory (or object store prefix) -/// -/// ```text -/// table1 -/// ├── file1.parquet -/// └── file2.parquet -/// ``` -/// -/// A `ListingTable` would read the files `file1.parquet` and `file2.parquet` as -/// a single table, merging the schemas if the files have compatible but not -/// identical schemas. 
-/// -/// Given the `table2` directory (or object store prefix) -/// -/// ```text -/// table2 -/// ├── date=2024-06-01 -/// │ ├── file3.parquet -/// │ └── file4.parquet -/// └── date=2024-06-02 -/// └── file5.parquet -/// ``` -/// -/// A `ListingTable` would read the files `file3.parquet`, `file4.parquet`, and -/// `file5.parquet` as a single table, again merging schemas if necessary. -/// -/// Given the hive style partitioning structure (e.g,. directories named -/// `date=2024-06-01` and `date=2026-06-02`), `ListingTable` also adds a `date` -/// column when reading the table: -/// * The files in `table2/date=2024-06-01` will have the value `2024-06-01` -/// * The files in `table2/date=2024-06-02` will have the value `2024-06-02`. -/// -/// If the query has a predicate like `WHERE date = '2024-06-01'` -/// only the corresponding directory will be read. -/// -/// `ListingTable` also supports limit, filter and projection pushdown for formats that -/// support it as such as Parquet. -/// -/// # See Also -/// -/// 1. [`ListingTableConfig`]: Configuration options -/// 1. 
[`DataSourceExec`]: `ExecutionPlan` used by `ListingTable` -/// -/// [`DataSourceExec`]: crate::datasource::source::DataSourceExec -/// -/// # Example: Read a directory of parquet files using a [`ListingTable`] -/// -/// ```no_run -/// # use datafusion::prelude::SessionContext; -/// # use datafusion::error::Result; -/// # use std::sync::Arc; -/// # use datafusion::datasource::{ -/// # listing::{ -/// # ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl, -/// # }, -/// # file_format::parquet::ParquetFormat, -/// # }; -/// # #[tokio::main] -/// # async fn main() -> Result<()> { -/// let ctx = SessionContext::new(); -/// let session_state = ctx.state(); -/// let table_path = "/path/to/parquet"; -/// -/// // Parse the path -/// let table_path = ListingTableUrl::parse(table_path)?; -/// -/// // Create default parquet options -/// let file_format = ParquetFormat::new(); -/// let listing_options = ListingOptions::new(Arc::new(file_format)) -/// .with_file_extension(".parquet"); -/// -/// // Resolve the schema -/// let resolved_schema = listing_options -/// .infer_schema(&session_state, &table_path) -/// .await?; -/// -/// let config = ListingTableConfig::new(table_path) -/// .with_listing_options(listing_options) -/// .with_schema(resolved_schema); -/// -/// // Create a new TableProvider -/// let provider = Arc::new(ListingTable::try_new(config)?); -/// -/// // This provider can now be read as a dataframe: -/// let df = ctx.read_table(provider.clone()); -/// -/// // or registered as a named table: -/// ctx.register_table("my_table", provider); -/// -/// # Ok(()) -/// # } -/// ``` -#[derive(Debug, Clone)] -pub struct ListingTable { - table_paths: Vec, - /// `file_schema` contains only the columns physically stored in the data files themselves. - /// - Represents the actual fields found in files like Parquet, CSV, etc. 
- /// - Used when reading the raw data from files - file_schema: SchemaRef, - /// `table_schema` combines `file_schema` + partition columns - /// - Partition columns are derived from directory paths (not stored in files) - /// - These are columns like "year=2022/month=01" in paths like `/data/year=2022/month=01/file.parquet` - table_schema: SchemaRef, - options: ListingOptions, - definition: Option, - collected_statistics: FileStatisticsCache, - constraints: Constraints, - column_defaults: HashMap, -} - -impl ListingTable { - /// Create new [`ListingTable`] - /// - /// See documentation and example on [`ListingTable`] and [`ListingTableConfig`] - pub fn try_new(config: ListingTableConfig) -> Result { - let file_schema = config - .file_schema - .ok_or_else(|| DataFusionError::Internal("No schema provided.".into()))?; - - let options = config.options.ok_or_else(|| { - DataFusionError::Internal("No ListingOptions provided".into()) - })?; - - // Add the partition columns to the file schema - let mut builder = SchemaBuilder::from(file_schema.as_ref().to_owned()); - for (part_col_name, part_col_type) in &options.table_partition_cols { - builder.push(Field::new(part_col_name, part_col_type.clone(), false)); - } - let table_schema = Arc::new( - builder - .finish() - .with_metadata(file_schema.metadata().clone()), + #[tokio::test] + async fn test_schema_source_tracking_comprehensive() -> Result<()> { + let ctx = SessionContext::new(); + let testdata = datafusion_test_data(); + let filename = format!("{testdata}/aggregate_simple.csv"); + let table_path = ListingTableUrl::parse(filename)?; + + // Test default schema source + let format = CsvFormat::default(); + let options = ListingOptions::new(Arc::new(format)); + let config = + ListingTableConfig::new(table_path.clone()).with_listing_options(options); + assert_eq!(config.schema_source(), SchemaSource::Unset); + + // Test schema source after setting a schema explicitly + let provided_schema = create_test_schema(); + let 
config_with_schema = config.clone().with_schema(provided_schema.clone()); + assert_eq!(config_with_schema.schema_source(), SchemaSource::Specified); + + // Test schema source after inferring schema + assert_eq!(config.schema_source(), SchemaSource::Unset); + + let config_with_inferred = config.infer_schema(&ctx.state()).await?; + assert_eq!(config_with_inferred.schema_source(), SchemaSource::Inferred); + + // Test schema preservation through operations + let config_with_schema_and_options = config_with_schema.clone(); + assert_eq!( + config_with_schema_and_options.schema_source(), + SchemaSource::Specified ); - let table = Self { - table_paths: config.table_paths, - file_schema, - table_schema, - options, - definition: None, - collected_statistics: Arc::new(DefaultFileStatisticsCache::default()), - constraints: Constraints::empty(), - column_defaults: HashMap::new(), - }; - - Ok(table) - } - - /// Assign constraints - pub fn with_constraints(mut self, constraints: Constraints) -> Self { - self.constraints = constraints; - self - } - - /// Assign column defaults - pub fn with_column_defaults( - mut self, - column_defaults: HashMap, - ) -> Self { - self.column_defaults = column_defaults; - self - } - - /// Set the [`FileStatisticsCache`] used to cache parquet file statistics. - /// - /// Setting a statistics cache on the `SessionContext` can avoid refetching statistics - /// multiple times in the same session. - /// - /// If `None`, creates a new [`DefaultFileStatisticsCache`] scoped to this query. 
- pub fn with_cache(mut self, cache: Option) -> Self { - self.collected_statistics = - cache.unwrap_or_else(|| Arc::new(DefaultFileStatisticsCache::default())); - self - } - - /// Specify the SQL definition for this table, if any - pub fn with_definition(mut self, definition: Option) -> Self { - self.definition = definition; - self - } - - /// Get paths ref - pub fn table_paths(&self) -> &Vec { - &self.table_paths - } - - /// Get options ref - pub fn options(&self) -> &ListingOptions { - &self.options - } - - /// If file_sort_order is specified, creates the appropriate physical expressions - fn try_create_output_ordering(&self) -> Result> { - create_ordering(&self.table_schema, &self.options.file_sort_order) - } -} - -// Expressions can be used for parttion pruning if they can be evaluated using -// only the partiton columns and there are partition columns. -fn can_be_evaluted_for_partition_pruning( - partition_column_names: &[&str], - expr: &Expr, -) -> bool { - !partition_column_names.is_empty() - && expr_applicable_for_cols(partition_column_names, expr) -} - -#[async_trait] -impl TableProvider for ListingTable { - fn as_any(&self) -> &dyn Any { - self - } - - fn schema(&self) -> SchemaRef { - Arc::clone(&self.table_schema) - } - - fn constraints(&self) -> Option<&Constraints> { - Some(&self.constraints) - } - - fn table_type(&self) -> TableType { - TableType::Base - } - - async fn scan( - &self, - state: &dyn Session, - projection: Option<&Vec>, - filters: &[Expr], - limit: Option, - ) -> Result> { - // extract types of partition columns - let table_partition_cols = self - .options - .table_partition_cols - .iter() - .map(|col| Ok(self.table_schema.field_with_name(&col.0)?.clone())) - .collect::>>()?; - - let table_partition_col_names = table_partition_cols - .iter() - .map(|field| field.name().as_str()) - .collect::>(); - // If the filters can be resolved using only partition cols, there is no need to - // pushdown it to TableScan, otherwise, `unhandled` 
pruning predicates will be generated - let (partition_filters, filters): (Vec<_>, Vec<_>) = - filters.iter().cloned().partition(|filter| { - can_be_evaluted_for_partition_pruning(&table_partition_col_names, filter) - }); - - // We should not limit the number of partitioned files to scan if there are filters and limit - // at the same time. This is because the limit should be applied after the filters are applied. - let statistic_file_limit = if filters.is_empty() { limit } else { None }; - - let (mut partitioned_file_lists, statistics) = self - .list_files_for_scan(state, &partition_filters, statistic_file_limit) + // Make sure inferred schema doesn't override specified schema + let config_with_schema_and_infer = config_with_schema_and_options + .clone() + .infer(&ctx.state()) .await?; + assert_eq!( + config_with_schema_and_infer.schema_source(), + SchemaSource::Specified + ); - // if no files need to be read, return an `EmptyExec` - if partitioned_file_lists.is_empty() { - let projected_schema = project_schema(&self.schema(), projection)?; - return Ok(Arc::new(EmptyExec::new(projected_schema))); - } - - let output_ordering = self.try_create_output_ordering()?; - match state - .config_options() - .execution - .split_file_groups_by_statistics - .then(|| { - output_ordering.first().map(|output_ordering| { - FileScanConfig::split_groups_by_statistics_with_target_partitions( - &self.table_schema, - &partitioned_file_lists, - output_ordering, - self.options.target_partitions, - ) - }) - }) - .flatten() - { - Some(Err(e)) => log::debug!("failed to split file groups by statistics: {e}"), - Some(Ok(new_groups)) => { - if new_groups.len() <= self.options.target_partitions { - partitioned_file_lists = new_groups; - } else { - log::debug!("attempted to split file groups by statistics, but there were more file groups than target_partitions; falling back to unordered") - } - } - None => {} // no ordering required - }; - - let Some(object_store_url) = - 
self.table_paths.first().map(ListingTableUrl::object_store) - else { - return Ok(Arc::new(EmptyExec::new(Arc::new(Schema::empty())))); - }; - - // create the execution plan - self.options - .format - .create_physical_plan( - state, - FileScanConfigBuilder::new( - object_store_url, - Arc::clone(&self.file_schema), - self.options.format.file_source(), - ) - .with_file_groups(partitioned_file_lists) - .with_constraints(self.constraints.clone()) - .with_statistics(statistics) - .with_projection(projection.cloned()) - .with_limit(limit) - .with_output_ordering(output_ordering) - .with_table_partition_cols(table_partition_cols) - .build(), - ) - .await - } - - fn supports_filters_pushdown( - &self, - filters: &[&Expr], - ) -> Result> { - let partition_column_names = self - .options - .table_partition_cols - .iter() - .map(|col| col.0.as_str()) - .collect::>(); - filters - .iter() - .map(|filter| { - if can_be_evaluted_for_partition_pruning(&partition_column_names, filter) - { - // if filter can be handled by partition pruning, it is exact - return Ok(TableProviderFilterPushDown::Exact); - } - - Ok(TableProviderFilterPushDown::Inexact) - }) - .collect() - } - - fn get_table_definition(&self) -> Option<&str> { - self.definition.as_deref() - } - - async fn insert_into( - &self, - state: &dyn Session, - input: Arc, - insert_op: InsertOp, - ) -> Result> { - // Check that the schema of the plan matches the schema of this table. - self.schema() - .logically_equivalent_names_and_types(&input.schema())?; - - let table_path = &self.table_paths()[0]; - if !table_path.is_collection() { - return plan_err!( - "Inserting into a ListingTable backed by a single file is not supported, URL is possibly missing a trailing `/`. \ - To append to an existing file use StreamTable, e.g. by using CREATE UNBOUNDED EXTERNAL TABLE" - ); - } - - // Get the object store for the table path. 
- let store = state.runtime_env().object_store(table_path)?; - - let file_list_stream = pruned_partition_list( - state, - store.as_ref(), - table_path, - &[], - &self.options.file_extension, - &self.options.table_partition_cols, - ) - .await?; - - let file_group = file_list_stream.try_collect::>().await?.into(); - let keep_partition_by_columns = - state.config_options().execution.keep_partition_by_columns; - - // Sink related option, apart from format - let config = FileSinkConfig { - original_url: String::default(), - object_store_url: self.table_paths()[0].object_store(), - table_paths: self.table_paths().clone(), - file_group, - output_schema: self.schema(), - table_partition_cols: self.options.table_partition_cols.clone(), - insert_op, - keep_partition_by_columns, - file_extension: self.options().format.get_ext(), - }; - - let order_requirements = if !self.options().file_sort_order.is_empty() { - // Multiple sort orders in outer vec are equivalent, so we pass only the first one - let orderings = self.try_create_output_ordering()?; - let Some(ordering) = orderings.first() else { - return internal_err!( - "Expected ListingTable to have a sort order, but none found!" - ); - }; - // Converts Vec> into type required by execution plan to specify its required input ordering - Some(LexRequirement::new( - ordering - .into_iter() - .cloned() - .map(PhysicalSortRequirement::from) - .collect::>(), - )) - } else { - None - }; - - self.options() - .format - .create_writer_physical_plan(input, state, config, order_requirements) - .await - } - - fn get_column_default(&self, column: &str) -> Option<&Expr> { - self.column_defaults.get(column) - } -} - -impl ListingTable { - /// Get the list of files for a scan as well as the file level statistics. - /// The list is grouped to let the execution plan know how the files should - /// be distributed to different threads / executors. 
- async fn list_files_for_scan<'a>( - &'a self, - ctx: &'a dyn Session, - filters: &'a [Expr], - limit: Option, - ) -> Result<(Vec, Statistics)> { - let store = if let Some(url) = self.table_paths.first() { - ctx.runtime_env().object_store(url)? - } else { - return Ok((vec![], Statistics::new_unknown(&self.file_schema))); - }; - // list files (with partitions) - let file_list = future::try_join_all(self.table_paths.iter().map(|table_path| { - pruned_partition_list( - ctx, - store.as_ref(), - table_path, - filters, - &self.options.file_extension, - &self.options.table_partition_cols, - ) - })) - .await?; - let meta_fetch_concurrency = - ctx.config_options().execution.meta_fetch_concurrency; - let file_list = stream::iter(file_list).flatten_unordered(meta_fetch_concurrency); - // collect the statistics if required by the config - let files = file_list - .map(|part_file| async { - let part_file = part_file?; - let statistics = if self.options.collect_stat { - self.do_collect_statistics(ctx, &store, &part_file).await? 
- } else { - Arc::new(Statistics::new_unknown(&self.file_schema)) - }; - Ok(part_file.with_statistics(statistics)) - }) - .boxed() - .buffer_unordered(ctx.config_options().execution.meta_fetch_concurrency); - - let (file_group, inexact_stats) = - get_files_with_limit(files, limit, self.options.collect_stat).await?; - - let file_groups = file_group.split_files(self.options.target_partitions); - let (mut file_groups, mut stats) = compute_all_files_statistics( - file_groups, - self.schema(), - self.options.collect_stat, - inexact_stats, - )?; - let (schema_mapper, _) = DefaultSchemaAdapterFactory::from_schema(self.schema()) - .map_schema(self.file_schema.as_ref())?; - stats.column_statistics = - schema_mapper.map_column_statistics(&stats.column_statistics)?; - file_groups.iter_mut().try_for_each(|file_group| { - if let Some(stat) = file_group.statistics_mut() { - stat.column_statistics = - schema_mapper.map_column_statistics(&stat.column_statistics)?; - } - Ok::<_, DataFusionError>(()) - })?; - Ok((file_groups, stats)) - } - - /// Collects statistics for a given partitioned file. - /// - /// This method first checks if the statistics for the given file are already cached. - /// If they are, it returns the cached statistics. - /// If they are not, it infers the statistics from the file and stores them in the cache. 
- async fn do_collect_statistics( - &self, - ctx: &dyn Session, - store: &Arc, - part_file: &PartitionedFile, - ) -> Result> { - match self - .collected_statistics - .get_with_extra(&part_file.object_meta.location, &part_file.object_meta) - { - Some(statistics) => Ok(statistics), - None => { - let statistics = self - .options - .format - .infer_stats( - ctx, - store, - Arc::clone(&self.file_schema), - &part_file.object_meta, - ) - .await?; - let statistics = Arc::new(statistics); - self.collected_statistics.put_with_extra( - &part_file.object_meta.location, - Arc::clone(&statistics), - &part_file.object_meta, - ); - Ok(statistics) - } - } - } -} - -/// Processes a stream of partitioned files and returns a `FileGroup` containing the files. -/// -/// This function collects files from the provided stream until either: -/// 1. The stream is exhausted -/// 2. The accumulated number of rows exceeds the provided `limit` (if specified) -/// -/// # Arguments -/// * `files` - A stream of `Result` items to process -/// * `limit` - An optional row count limit. If provided, the function will stop collecting files -/// once the accumulated number of rows exceeds this limit -/// * `collect_stats` - Whether to collect and accumulate statistics from the files -/// -/// # Returns -/// A `Result` containing a `FileGroup` with the collected files -/// and a boolean indicating whether the statistics are inexact. -/// -/// # Note -/// The function will continue processing files if statistics are not available or if the -/// limit is not provided. If `collect_stats` is false, statistics won't be accumulated -/// but files will still be collected. -async fn get_files_with_limit( - files: impl Stream>, - limit: Option, - collect_stats: bool, -) -> Result<(FileGroup, bool)> { - let mut file_group = FileGroup::default(); - // Fusing the stream allows us to call next safely even once it is finished. 
- let mut all_files = Box::pin(files.fuse()); - enum ProcessingState { - ReadingFiles, - ReachedLimit, - } - - let mut state = ProcessingState::ReadingFiles; - let mut num_rows = Precision::Absent; - - while let Some(file_result) = all_files.next().await { - // Early exit if we've already reached our limit - if matches!(state, ProcessingState::ReachedLimit) { - break; - } - - let file = file_result?; - - // Update file statistics regardless of state - if collect_stats { - if let Some(file_stats) = &file.statistics { - num_rows = if file_group.is_empty() { - // For the first file, just take its row count - file_stats.num_rows - } else { - // For subsequent files, accumulate the counts - num_rows.add(&file_stats.num_rows) - }; - } - } + // Verify sources in actual ListingTable objects + let table_specified = ListingTable::try_new(config_with_schema_and_options)?; + assert_eq!(table_specified.schema_source(), SchemaSource::Specified); - // Always add the file to our group - file_group.push(file); + let table_inferred = ListingTable::try_new(config_with_inferred)?; + assert_eq!(table_inferred.schema_source(), SchemaSource::Inferred); - // Check if we've hit the limit (if one was specified) - if let Some(limit) = limit { - if let Precision::Exact(row_count) = num_rows { - if row_count > limit { - state = ProcessingState::ReachedLimit; - } - } - } + Ok(()) } - // If we still have files in the stream, it means that the limit kicked - // in, and the statistic could have been different had we processed the - // files in a different order. 
- let inexact_stats = all_files.next().await.is_some(); - Ok((file_group, inexact_stats)) -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::datasource::file_format::csv::CsvFormat; - use crate::datasource::file_format::json::JsonFormat; - #[cfg(feature = "parquet")] - use crate::datasource::file_format::parquet::ParquetFormat; - use crate::datasource::{provider_as_source, DefaultTableSource, MemTable}; - use crate::execution::options::ArrowReadOptions; - use crate::prelude::*; - use crate::test::{columns, object_store::register_test_store}; - - use arrow::compute::SortOptions; - use arrow::record_batch::RecordBatch; - use datafusion_common::stats::Precision; - use datafusion_common::test_util::batches_to_string; - use datafusion_common::{assert_contains, ScalarValue}; - use datafusion_expr::{BinaryExpr, LogicalPlanBuilder, Operator}; - use datafusion_physical_expr::PhysicalSortExpr; - use datafusion_physical_plan::collect; - use datafusion_physical_plan::ExecutionPlanProperties; - - use crate::test::object_store::{ensure_head_concurrency, make_test_store_and_state}; - use tempfile::TempDir; - use url::Url; #[tokio::test] async fn read_single_file() -> Result<()> { @@ -1316,86 +251,7 @@ mod tests { ); assert_eq!( exec.partition_statistics(None)?.total_byte_size, - Precision::Exact(671) - ); - - Ok(()) - } - - #[cfg(feature = "parquet")] - #[tokio::test] - async fn do_not_load_table_stats_by_default() -> Result<()> { - use crate::datasource::file_format::parquet::ParquetFormat; - - let testdata = crate::test_util::parquet_test_data(); - let filename = format!("{}/{}", testdata, "alltypes_plain.parquet"); - let table_path = ListingTableUrl::parse(filename).unwrap(); - - let ctx = SessionContext::new(); - let state = ctx.state(); - - let opt = ListingOptions::new(Arc::new(ParquetFormat::default())); - let schema = opt.infer_schema(&state, &table_path).await?; - let config = ListingTableConfig::new(table_path.clone()) - .with_listing_options(opt) - 
.with_schema(schema); - let table = ListingTable::try_new(config)?; - - let exec = table.scan(&state, None, &[], None).await?; - assert_eq!(exec.partition_statistics(None)?.num_rows, Precision::Absent); - // TODO correct byte size: https://github.com/apache/datafusion/issues/14936 - assert_eq!( - exec.partition_statistics(None)?.total_byte_size, - Precision::Absent - ); - - let opt = ListingOptions::new(Arc::new(ParquetFormat::default())) - .with_collect_stat(true); - let schema = opt.infer_schema(&state, &table_path).await?; - let config = ListingTableConfig::new(table_path) - .with_listing_options(opt) - .with_schema(schema); - let table = ListingTable::try_new(config)?; - - let exec = table.scan(&state, None, &[], None).await?; - assert_eq!( - exec.partition_statistics(None)?.num_rows, - Precision::Exact(8) - ); - // TODO correct byte size: https://github.com/apache/datafusion/issues/14936 - assert_eq!( - exec.partition_statistics(None)?.total_byte_size, - Precision::Exact(671) - ); - - Ok(()) - } - - #[cfg(feature = "parquet")] - #[tokio::test] - async fn load_table_stats_when_no_stats() -> Result<()> { - use crate::datasource::file_format::parquet::ParquetFormat; - - let testdata = crate::test_util::parquet_test_data(); - let filename = format!("{}/{}", testdata, "alltypes_plain.parquet"); - let table_path = ListingTableUrl::parse(filename).unwrap(); - - let ctx = SessionContext::new(); - let state = ctx.state(); - - let opt = ListingOptions::new(Arc::new(ParquetFormat::default())) - .with_collect_stat(false); - let schema = opt.infer_schema(&state, &table_path).await?; - let config = ListingTableConfig::new(table_path) - .with_listing_options(opt) - .with_schema(schema); - let table = ListingTable::try_new(config)?; - - let exec = table.scan(&state, None, &[], None).await?; - assert_eq!(exec.partition_statistics(None)?.num_rows, Precision::Absent); - assert_eq!( - exec.partition_statistics(None)?.total_byte_size, - Precision::Absent + Precision::Absent, ); 
Ok(()) @@ -1415,31 +271,48 @@ mod tests { use crate::datasource::file_format::parquet::ParquetFormat; use datafusion_physical_plan::expressions::col as physical_col; + use datafusion_physical_plan::expressions::lit as physical_lit; use std::ops::Add; // (file_sort_order, expected_result) let cases = vec![ - (vec![], Ok(vec![])), + ( + vec![], + Ok::, DataFusionError>(Vec::::new()), + ), // sort expr, but non column ( - vec![vec![ - col("int_col").add(lit(1)).sort(true, true), - ]], - Err("Expected single column reference in sort_order[0][0], got int_col + Int32(1)"), + vec![vec![col("int_col").add(lit(1)).sort(true, true)]], + Ok(vec![ + [PhysicalSortExpr { + expr: binary( + physical_col("int_col", &schema).unwrap(), + Operator::Plus, + physical_lit(1), + &schema, + ) + .unwrap(), + options: SortOptions { + descending: false, + nulls_first: true, + }, + }] + .into(), + ]), ), // ok with one column ( vec![vec![col("string_col").sort(true, false)]], - Ok(vec![LexOrdering::new( - vec![PhysicalSortExpr { - expr: physical_col("string_col", &schema).unwrap(), - options: SortOptions { - descending: false, - nulls_first: false, - }, - }], - ) - ]) + Ok(vec![ + [PhysicalSortExpr { + expr: physical_col("string_col", &schema).unwrap(), + options: SortOptions { + descending: false, + nulls_first: false, + }, + }] + .into(), + ]), ), // ok with two columns, different options ( @@ -1447,17 +320,21 @@ mod tests { col("string_col").sort(true, false), col("int_col").sort(false, true), ]], - Ok(vec![LexOrdering::new( - vec![ - PhysicalSortExpr::new_default(physical_col("string_col", &schema).unwrap()) - .asc() - .nulls_last(), - PhysicalSortExpr::new_default(physical_col("int_col", &schema).unwrap()) - .desc() - .nulls_first() - ], - ) - ]) + Ok(vec![ + [ + PhysicalSortExpr::new_default( + physical_col("string_col", &schema).unwrap(), + ) + .asc() + .nulls_last(), + PhysicalSortExpr::new_default( + physical_col("int_col", &schema).unwrap(), + ) + .desc() + .nulls_first(), + ] + 
.into(), + ]), ), ]; @@ -1470,7 +347,8 @@ mod tests { let table = ListingTable::try_new(config.clone()).expect("Creating the table"); - let ordering_result = table.try_create_output_ordering(); + let ordering_result = + table.try_create_output_ordering(state.execution_props(), &[]); match (expected_result, ordering_result) { (Ok(expected), Ok(result)) => { @@ -1505,290 +383,33 @@ mod tests { .with_table_partition_cols(vec![(String::from("p1"), DataType::Utf8)]) .with_target_partitions(4); - let table_path = ListingTableUrl::parse("test:///table/").unwrap(); + let table_path = ListingTableUrl::parse("test:///table/")?; let file_schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Boolean, false)])); let config = ListingTableConfig::new(table_path) .with_listing_options(opt) - .with_schema(file_schema); - let table = ListingTable::try_new(config)?; - - assert_eq!( - columns(&table.schema()), - vec!["a".to_owned(), "p1".to_owned()] - ); - - // this will filter out the only file in the store - let filter = Expr::not_eq(col("p1"), lit("v1")); - - let scan = table - .scan(&ctx.state(), None, &[filter], None) - .await - .expect("Empty execution plan"); - - assert!(scan.as_any().is::()); - assert_eq!( - columns(&scan.schema()), - vec!["a".to_owned(), "p1".to_owned()] - ); - - Ok(()) - } - - #[tokio::test] - async fn test_assert_list_files_for_scan_grouping() -> Result<()> { - // more expected partitions than files - assert_list_files_for_scan_grouping( - &[ - "bucket/key-prefix/file0", - "bucket/key-prefix/file1", - "bucket/key-prefix/file2", - "bucket/key-prefix/file3", - "bucket/key-prefix/file4", - ], - "test:///bucket/key-prefix/", - 12, - 5, - Some(""), - ) - .await?; - - // as many expected partitions as files - assert_list_files_for_scan_grouping( - &[ - "bucket/key-prefix/file0", - "bucket/key-prefix/file1", - "bucket/key-prefix/file2", - "bucket/key-prefix/file3", - ], - "test:///bucket/key-prefix/", - 4, - 4, - Some(""), - ) - .await?; - - // more files 
as expected partitions - assert_list_files_for_scan_grouping( - &[ - "bucket/key-prefix/file0", - "bucket/key-prefix/file1", - "bucket/key-prefix/file2", - "bucket/key-prefix/file3", - "bucket/key-prefix/file4", - ], - "test:///bucket/key-prefix/", - 2, - 2, - Some(""), - ) - .await?; - - // no files => no groups - assert_list_files_for_scan_grouping( - &[], - "test:///bucket/key-prefix/", - 2, - 0, - Some(""), - ) - .await?; - - // files that don't match the prefix - assert_list_files_for_scan_grouping( - &[ - "bucket/key-prefix/file0", - "bucket/key-prefix/file1", - "bucket/other-prefix/roguefile", - ], - "test:///bucket/key-prefix/", - 10, - 2, - Some(""), - ) - .await?; - - // files that don't match the prefix or the default file extention - assert_list_files_for_scan_grouping( - &[ - "bucket/key-prefix/file0.json", - "bucket/key-prefix/file1.parquet", - "bucket/other-prefix/roguefile.json", - ], - "test:///bucket/key-prefix/", - 10, - 1, - None, - ) - .await?; - Ok(()) - } - - #[tokio::test] - async fn test_assert_list_files_for_multi_path() -> Result<()> { - // more expected partitions than files - assert_list_files_for_multi_paths( - &[ - "bucket/key1/file0", - "bucket/key1/file1", - "bucket/key1/file2", - "bucket/key2/file3", - "bucket/key2/file4", - "bucket/key3/file5", - ], - &["test:///bucket/key1/", "test:///bucket/key2/"], - 12, - 5, - Some(""), - ) - .await?; - - // as many expected partitions as files - assert_list_files_for_multi_paths( - &[ - "bucket/key1/file0", - "bucket/key1/file1", - "bucket/key1/file2", - "bucket/key2/file3", - "bucket/key2/file4", - "bucket/key3/file5", - ], - &["test:///bucket/key1/", "test:///bucket/key2/"], - 5, - 5, - Some(""), - ) - .await?; - - // more files as expected partitions - assert_list_files_for_multi_paths( - &[ - "bucket/key1/file0", - "bucket/key1/file1", - "bucket/key1/file2", - "bucket/key2/file3", - "bucket/key2/file4", - "bucket/key3/file5", - ], - &["test:///bucket/key1/"], - 2, - 2, - Some(""), - ) - 
.await?; - - // no files => no groups - assert_list_files_for_multi_paths(&[], &["test:///bucket/key1/"], 2, 0, Some("")) - .await?; - - // files that don't match the prefix - assert_list_files_for_multi_paths( - &[ - "bucket/key1/file0", - "bucket/key1/file1", - "bucket/key1/file2", - "bucket/key2/file3", - "bucket/key2/file4", - "bucket/key3/file5", - ], - &["test:///bucket/key3/"], - 2, - 1, - Some(""), - ) - .await?; - - // files that don't match the prefix or the default file ext - assert_list_files_for_multi_paths( - &[ - "bucket/key1/file0.json", - "bucket/key1/file1.csv", - "bucket/key1/file2.json", - "bucket/key2/file3.csv", - "bucket/key2/file4.json", - "bucket/key3/file5.csv", - ], - &["test:///bucket/key1/", "test:///bucket/key3/"], - 2, - 2, - None, - ) - .await?; - Ok(()) - } + .with_schema(file_schema); + let table = ListingTable::try_new(config)?; - #[tokio::test] - async fn test_assert_list_files_for_exact_paths() -> Result<()> { - // more expected partitions than files - assert_list_files_for_exact_paths( - &[ - "bucket/key1/file0", - "bucket/key1/file1", - "bucket/key1/file2", - "bucket/key2/file3", - "bucket/key2/file4", - ], - 12, - 5, - Some(""), - ) - .await?; + assert_eq!( + columns(&table.schema()), + vec!["a".to_owned(), "p1".to_owned()] + ); - // more files than meta_fetch_concurrency (32) - let files: Vec = - (0..64).map(|i| format!("bucket/key1/file{i}")).collect(); - // Collect references to each string - let file_refs: Vec<&str> = files.iter().map(|s| s.as_str()).collect(); - assert_list_files_for_exact_paths(file_refs.as_slice(), 5, 5, Some("")).await?; - - // as many expected partitions as files - assert_list_files_for_exact_paths( - &[ - "bucket/key1/file0", - "bucket/key1/file1", - "bucket/key1/file2", - "bucket/key2/file3", - "bucket/key2/file4", - ], - 5, - 5, - Some(""), - ) - .await?; + // this will filter out the only file in the store + let filter = Expr::not_eq(col("p1"), lit("v1")); - // more files as expected partitions - 
assert_list_files_for_exact_paths( - &[ - "bucket/key1/file0", - "bucket/key1/file1", - "bucket/key1/file2", - "bucket/key2/file3", - "bucket/key2/file4", - ], - 2, - 2, - Some(""), - ) - .await?; + let scan = table + .scan(&ctx.state(), None, &[filter], None) + .await + .expect("Empty execution plan"); + + assert!(scan.as_any().is::()); + assert_eq!( + columns(&scan.schema()), + vec!["a".to_owned(), "p1".to_owned()] + ); - // no files => no groups - assert_list_files_for_exact_paths(&[], 2, 0, Some("")).await?; - - // files that don't match the default file ext - assert_list_files_for_exact_paths( - &[ - "bucket/key1/file0.json", - "bucket/key1/file1.csv", - "bucket/key1/file2.json", - "bucket/key2/file3.csv", - "bucket/key2/file4.json", - "bucket/key3/file5.csv", - ], - 2, - 2, - None, - ) - .await?; Ok(()) } @@ -1798,7 +419,7 @@ mod tests { ) -> Result> { let testdata = crate::test_util::parquet_test_data(); let filename = format!("{testdata}/{name}"); - let table_path = ListingTableUrl::parse(filename).unwrap(); + let table_path = ListingTableUrl::parse(filename)?; let config = ListingTableConfig::new(table_path) .infer(&ctx.state()) @@ -1825,16 +446,16 @@ mod tests { let schema = Schema::new(vec![Field::new("a", DataType::Boolean, false)]); - let table_path = ListingTableUrl::parse(table_prefix).unwrap(); + let table_path = ListingTableUrl::parse(table_prefix)?; let config = ListingTableConfig::new(table_path) .with_listing_options(opt) .with_schema(Arc::new(schema)); let table = ListingTable::try_new(config)?; - let (file_list, _) = table.list_files_for_scan(&ctx.state(), &[], None).await?; + let result = table.list_files_for_scan(&ctx.state(), &[], None).await?; - assert_eq!(file_list.len(), output_partitioning); + assert_eq!(result.file_groups.len(), output_partitioning); Ok(()) } @@ -1867,9 +488,9 @@ mod tests { let table = ListingTable::try_new(config)?; - let (file_list, _) = table.list_files_for_scan(&ctx.state(), &[], None).await?; + let result = 
table.list_files_for_scan(&ctx.state(), &[], None).await?; - assert_eq!(file_list.len(), output_partitioning); + assert_eq!(result.file_groups.len(), output_partitioning); Ok(()) } @@ -1894,10 +515,10 @@ mod tests { .execution .meta_fetch_concurrency; let expected_concurrency = files.len().min(meta_fetch_concurrency); - let head_blocking_store = ensure_head_concurrency(store, expected_concurrency); + let head_concurrency_store = ensure_head_concurrency(store, expected_concurrency); let url = Url::parse("test://").unwrap(); - ctx.register_object_store(&url, head_blocking_store.clone()); + ctx.register_object_store(&url, head_concurrency_store.clone()); let format = JsonFormat::default(); @@ -1917,84 +538,10 @@ mod tests { let table = ListingTable::try_new(config)?; - let (file_list, _) = table.list_files_for_scan(&ctx.state(), &[], None).await?; - - assert_eq!(file_list.len(), output_partitioning); - - Ok(()) - } - - #[tokio::test] - async fn test_insert_into_append_new_json_files() -> Result<()> { - let mut config_map: HashMap = HashMap::new(); - config_map.insert("datafusion.execution.batch_size".into(), "10".into()); - config_map.insert( - "datafusion.execution.soft_max_rows_per_output_file".into(), - "10".into(), - ); - helper_test_append_new_files_to_table( - JsonFormat::default().get_ext(), - FileCompressionType::UNCOMPRESSED, - Some(config_map), - 2, - ) - .await?; - Ok(()) - } - - #[tokio::test] - async fn test_insert_into_append_new_csv_files() -> Result<()> { - let mut config_map: HashMap = HashMap::new(); - config_map.insert("datafusion.execution.batch_size".into(), "10".into()); - config_map.insert( - "datafusion.execution.soft_max_rows_per_output_file".into(), - "10".into(), - ); - helper_test_append_new_files_to_table( - CsvFormat::default().get_ext(), - FileCompressionType::UNCOMPRESSED, - Some(config_map), - 2, - ) - .await?; - Ok(()) - } + let result = table.list_files_for_scan(&ctx.state(), &[], None).await?; - #[cfg(feature = "parquet")] - 
#[tokio::test] - async fn test_insert_into_append_2_new_parquet_files_defaults() -> Result<()> { - let mut config_map: HashMap = HashMap::new(); - config_map.insert("datafusion.execution.batch_size".into(), "10".into()); - config_map.insert( - "datafusion.execution.soft_max_rows_per_output_file".into(), - "10".into(), - ); - helper_test_append_new_files_to_table( - ParquetFormat::default().get_ext(), - FileCompressionType::UNCOMPRESSED, - Some(config_map), - 2, - ) - .await?; - Ok(()) - } + assert_eq!(result.file_groups.len(), output_partitioning); - #[cfg(feature = "parquet")] - #[tokio::test] - async fn test_insert_into_append_1_new_parquet_files_defaults() -> Result<()> { - let mut config_map: HashMap = HashMap::new(); - config_map.insert("datafusion.execution.batch_size".into(), "20".into()); - config_map.insert( - "datafusion.execution.soft_max_rows_per_output_file".into(), - "20".into(), - ); - helper_test_append_new_files_to_table( - ParquetFormat::default().get_ext(), - FileCompressionType::UNCOMPRESSED, - Some(config_map), - 1, - ) - .await?; Ok(()) } @@ -2108,7 +655,6 @@ mod tests { #[tokio::test] async fn test_insert_into_append_new_parquet_files_session_overrides() -> Result<()> { let mut config_map: HashMap = HashMap::new(); - config_map.insert("datafusion.execution.batch_size".into(), "10".into()); config_map.insert( "datafusion.execution.soft_max_rows_per_output_file".into(), "10".into(), @@ -2173,7 +719,7 @@ mod tests { "datafusion.execution.parquet.write_batch_size".into(), "5".into(), ); - config_map.insert("datafusion.execution.batch_size".into(), "1".into()); + config_map.insert("datafusion.execution.batch_size".into(), "10".into()); helper_test_append_new_files_to_table( ParquetFormat::default().get_ext(), FileCompressionType::UNCOMPRESSED, @@ -2185,8 +731,8 @@ mod tests { } #[tokio::test] - async fn test_insert_into_append_new_parquet_files_invalid_session_fails( - ) -> Result<()> { + async fn 
test_insert_into_append_new_parquet_files_invalid_session_fails() + -> Result<()> { let mut config_map: HashMap = HashMap::new(); config_map.insert( "datafusion.execution.parquet.compression".into(), @@ -2200,7 +746,10 @@ mod tests { ) .await .expect_err("Example should fail!"); - assert_eq!(e.strip_backtrace(), "Invalid or Unsupported Configuration: zstd compression requires specifying a level such as zstd(4)"); + assert_eq!( + e.strip_backtrace(), + "Invalid or Unsupported Configuration: zstd compression requires specifying a level such as zstd(4)" + ); Ok(()) } @@ -2230,7 +779,7 @@ mod tests { let filter_predicate = Expr::BinaryExpr(BinaryExpr::new( Box::new(Expr::Column("column1".into())), Operator::GtEq, - Box::new(Expr::Literal(ScalarValue::Int32(Some(0)))), + Box::new(Expr::Literal(ScalarValue::Int32(Some(0)), None)), )); // Create a new batch of data to insert into the table @@ -2260,7 +809,7 @@ mod tests { .register_json( "t", tmp_dir.path().to_str().unwrap(), - NdJsonReadOptions::default() + JsonReadOptions::default() .schema(schema.as_ref()) .file_compression_type(file_compression_type), ) @@ -2327,13 +876,13 @@ mod tests { let res = collect(plan, session_ctx.task_ctx()).await?; // Insert returns the number of rows written, in our case this would be 6. - insta::allow_duplicates! {insta::assert_snapshot!(batches_to_string(&res),@r###" - +-------+ - | count | - +-------+ - | 20 | - +-------+ - "###);} + insta::allow_duplicates! {insta::assert_snapshot!(batches_to_string(&res),@r" + +-------+ + | count | + +-------+ + | 20 | + +-------+ + ");} // Read the records in the table let batches = session_ctx @@ -2342,13 +891,13 @@ mod tests { .collect() .await?; - insta::allow_duplicates! {insta::assert_snapshot!(batches_to_string(&batches),@r###" - +-------+ - | count | - +-------+ - | 20 | - +-------+ - "###);} + insta::allow_duplicates! 
{insta::assert_snapshot!(batches_to_string(&batches),@r" + +-------+ + | count | + +-------+ + | 20 | + +-------+ + ");} // Assert that `target_partition_number` many files were added to the table. let num_files = tmp_dir.path().read_dir()?.count(); @@ -2363,13 +912,13 @@ mod tests { // Again, execute the physical plan and collect the results let res = collect(plan, session_ctx.task_ctx()).await?; - insta::allow_duplicates! {insta::assert_snapshot!(batches_to_string(&res),@r###" - +-------+ - | count | - +-------+ - | 20 | - +-------+ - "###);} + insta::allow_duplicates! {insta::assert_snapshot!(batches_to_string(&res),@r" + +-------+ + | count | + +-------+ + | 20 | + +-------+ + ");} // Read the contents of the table let batches = session_ctx @@ -2378,13 +927,13 @@ mod tests { .collect() .await?; - insta::allow_duplicates! {insta::assert_snapshot!(batches_to_string(&batches),@r###" - +-------+ - | count | - +-------+ - | 40 | - +-------+ - "###);} + insta::allow_duplicates! {insta::assert_snapshot!(batches_to_string(&batches),@r" + +-------+ + | count | + +-------+ + | 40 | + +-------+ + ");} // Assert that another `target_partition_number` many files were added to the table. let num_files = tmp_dir.path().read_dir()?.count(); @@ -2442,15 +991,15 @@ mod tests { .collect() .await?; - insta::allow_duplicates! {insta::assert_snapshot!(batches_to_string(&batches),@r###" - +-----+-----+---+ - | a | b | c | - +-----+-----+---+ - | foo | bar | 1 | - | foo | bar | 2 | - | foo | bar | 3 | - +-----+-----+---+ - "###);} + insta::allow_duplicates! 
{insta::assert_snapshot!(batches_to_string(&batches),@r" + +-----+-----+---+ + | a | b | c | + +-----+-----+---+ + | foo | bar | 1 | + | foo | bar | 2 | + | foo | bar | 3 | + +-----+-----+---+ + ");} Ok(()) } @@ -2459,7 +1008,7 @@ mod tests { async fn test_infer_options_compressed_csv() -> Result<()> { let testdata = crate::test_util::arrow_test_data(); let filename = format!("{testdata}/csv/aggregate_test_100.csv.gz"); - let table_path = ListingTableUrl::parse(filename).unwrap(); + let table_path = ListingTableUrl::parse(filename)?; let ctx = SessionContext::new(); @@ -2473,4 +1022,467 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn infer_preserves_provided_schema() -> Result<()> { + let ctx = SessionContext::new(); + + let testdata = datafusion_test_data(); + let filename = format!("{testdata}/aggregate_simple.csv"); + let table_path = ListingTableUrl::parse(filename)?; + + let provided_schema = create_test_schema(); + + let format = CsvFormat::default(); + let options = ListingOptions::new(Arc::new(format)); + let config = ListingTableConfig::new(table_path) + .with_listing_options(options) + .with_schema(Arc::clone(&provided_schema)); + + let config = config.infer(&ctx.state()).await?; + + assert_eq!(*config.file_schema.unwrap(), *provided_schema); + + Ok(()) + } + + #[tokio::test] + async fn test_listing_table_config_with_multiple_files_comprehensive() -> Result<()> { + let ctx = SessionContext::new(); + + // Create test files with different schemas + let tmp_dir = TempDir::new()?; + let file_path1 = tmp_dir.path().join("file1.csv"); + let file_path2 = tmp_dir.path().join("file2.csv"); + + // File 1: c1,c2,c3 + let mut file1 = std::fs::File::create(&file_path1)?; + writeln!(file1, "c1,c2,c3")?; + writeln!(file1, "1,2,3")?; + writeln!(file1, "4,5,6")?; + + // File 2: c1,c2,c3,c4 + let mut file2 = std::fs::File::create(&file_path2)?; + writeln!(file2, "c1,c2,c3,c4")?; + writeln!(file2, "7,8,9,10")?; + writeln!(file2, "11,12,13,14")?; + + // Parse paths + 
let table_path1 = ListingTableUrl::parse(file_path1.to_str().unwrap())?; + let table_path2 = ListingTableUrl::parse(file_path2.to_str().unwrap())?; + + // Create format and options + let format = CsvFormat::default().with_has_header(true); + let options = ListingOptions::new(Arc::new(format)); + + // Test case 1: Infer schema using first file's schema + let config1 = ListingTableConfig::new_with_multi_paths(vec![ + table_path1.clone(), + table_path2.clone(), + ]) + .with_listing_options(options.clone()); + let config1 = config1.infer_schema(&ctx.state()).await?; + assert_eq!(config1.schema_source(), SchemaSource::Inferred); + + // Verify schema matches first file + let schema1 = config1.file_schema.as_ref().unwrap().clone(); + assert_eq!(schema1.fields().len(), 3); + assert_eq!(schema1.field(0).name(), "c1"); + assert_eq!(schema1.field(1).name(), "c2"); + assert_eq!(schema1.field(2).name(), "c3"); + + // Test case 2: Use specified schema with 3 columns + let schema_3cols = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Utf8, true), + Field::new("c2", DataType::Utf8, true), + Field::new("c3", DataType::Utf8, true), + ])); + + let config2 = ListingTableConfig::new_with_multi_paths(vec![ + table_path1.clone(), + table_path2.clone(), + ]) + .with_listing_options(options.clone()) + .with_schema(schema_3cols); + let config2 = config2.infer_schema(&ctx.state()).await?; + assert_eq!(config2.schema_source(), SchemaSource::Specified); + + // Verify that the schema is still the one we specified (3 columns) + let schema2 = config2.file_schema.as_ref().unwrap().clone(); + assert_eq!(schema2.fields().len(), 3); + assert_eq!(schema2.field(0).name(), "c1"); + assert_eq!(schema2.field(1).name(), "c2"); + assert_eq!(schema2.field(2).name(), "c3"); + + // Test case 3: Use specified schema with 4 columns + let schema_4cols = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Utf8, true), + Field::new("c2", DataType::Utf8, true), + Field::new("c3", DataType::Utf8, true), 
+ Field::new("c4", DataType::Utf8, true), + ])); + + let config3 = ListingTableConfig::new_with_multi_paths(vec![ + table_path1.clone(), + table_path2.clone(), + ]) + .with_listing_options(options.clone()) + .with_schema(schema_4cols); + let config3 = config3.infer_schema(&ctx.state()).await?; + assert_eq!(config3.schema_source(), SchemaSource::Specified); + + // Verify that the schema is still the one we specified (4 columns) + let schema3 = config3.file_schema.as_ref().unwrap().clone(); + assert_eq!(schema3.fields().len(), 4); + assert_eq!(schema3.field(0).name(), "c1"); + assert_eq!(schema3.field(1).name(), "c2"); + assert_eq!(schema3.field(2).name(), "c3"); + assert_eq!(schema3.field(3).name(), "c4"); + + // Test case 4: Verify order matters when inferring schema + let config4 = ListingTableConfig::new_with_multi_paths(vec![ + table_path2.clone(), + table_path1.clone(), + ]) + .with_listing_options(options); + let config4 = config4.infer_schema(&ctx.state()).await?; + + // Should use first file's schema, which now has 4 columns + let schema4 = config4.file_schema.as_ref().unwrap().clone(); + assert_eq!(schema4.fields().len(), 4); + assert_eq!(schema4.field(0).name(), "c1"); + assert_eq!(schema4.field(1).name(), "c2"); + assert_eq!(schema4.field(2).name(), "c3"); + assert_eq!(schema4.field(3).name(), "c4"); + + Ok(()) + } + + #[tokio::test] + async fn test_list_files_configurations() -> Result<()> { + // Define common test cases as (description, files, paths, target_partitions, expected_partitions, file_ext) + let test_cases = vec![ + // Single path cases + ( + "Single path, more partitions than files", + generate_test_files("bucket/key-prefix", 5), + vec!["test:///bucket/key-prefix/"], + 12, + 5, + Some(""), + ), + ( + "Single path, equal partitions and files", + generate_test_files("bucket/key-prefix", 4), + vec!["test:///bucket/key-prefix/"], + 4, + 4, + Some(""), + ), + ( + "Single path, more files than partitions", + generate_test_files("bucket/key-prefix", 
5), + vec!["test:///bucket/key-prefix/"], + 2, + 2, + Some(""), + ), + // Multi path cases + ( + "Multi path, more partitions than files", + { + let mut files = generate_test_files("bucket/key1", 3); + files.extend(generate_test_files_with_start("bucket/key2", 2, 3)); + files.extend(generate_test_files_with_start("bucket/key3", 1, 5)); + files + }, + vec!["test:///bucket/key1/", "test:///bucket/key2/"], + 12, + 5, + Some(""), + ), + // No files case + ( + "No files", + vec![], + vec!["test:///bucket/key-prefix/"], + 2, + 0, + Some(""), + ), + // Exact path cases + ( + "Exact paths test", + { + let mut files = generate_test_files("bucket/key1", 3); + files.extend(generate_test_files_with_start("bucket/key2", 2, 3)); + files + }, + vec![ + "test:///bucket/key1/file0", + "test:///bucket/key1/file1", + "test:///bucket/key1/file2", + "test:///bucket/key2/file3", + "test:///bucket/key2/file4", + ], + 12, + 5, + Some(""), + ), + ]; + + // Run each test case + for (test_name, files, paths, target_partitions, expected_partitions, file_ext) in + test_cases + { + println!("Running test: {test_name}"); + + if files.is_empty() { + // Test empty files case + assert_list_files_for_multi_paths( + &[], + &paths, + target_partitions, + expected_partitions, + file_ext, + ) + .await?; + } else if paths.len() == 1 { + // Test using single path API + let file_refs: Vec<&str> = files.iter().map(|s| s.as_str()).collect(); + assert_list_files_for_scan_grouping( + &file_refs, + paths[0], + target_partitions, + expected_partitions, + file_ext, + ) + .await?; + } else if paths[0].contains("test:///bucket/key") { + // Test using multi path API + let file_refs: Vec<&str> = files.iter().map(|s| s.as_str()).collect(); + assert_list_files_for_multi_paths( + &file_refs, + &paths, + target_partitions, + expected_partitions, + file_ext, + ) + .await?; + } else { + // Test using exact path API for specific cases + let file_refs: Vec<&str> = files.iter().map(|s| s.as_str()).collect(); + 
assert_list_files_for_exact_paths( + &file_refs, + target_partitions, + expected_partitions, + file_ext, + ) + .await?; + } + } + + Ok(()) + } + + #[tokio::test] + async fn test_listing_table_prunes_extra_files_in_hive() -> Result<()> { + let files = [ + "bucket/test/pid=1/file1", + "bucket/test/pid=1/file2", + "bucket/test/pid=2/file3", + "bucket/test/pid=2/file4", + "bucket/test/other/file5", + ]; + + let ctx = SessionContext::new(); + register_test_store(&ctx, &files.iter().map(|f| (*f, 10)).collect::>()); + + let opt = ListingOptions::new(Arc::new(JsonFormat::default())) + .with_file_extension_opt(Some("")) + .with_table_partition_cols(vec![("pid".to_string(), DataType::Int32)]); + + let table_path = ListingTableUrl::parse("test:///bucket/test/").unwrap(); + let schema = Schema::new(vec![Field::new("a", DataType::Boolean, false)]); + let config = ListingTableConfig::new(table_path) + .with_listing_options(opt) + .with_schema(Arc::new(schema)); + + let table = ListingTable::try_new(config)?; + + let result = table.list_files_for_scan(&ctx.state(), &[], None).await?; + assert_eq!(result.file_groups.len(), 1); + + let files = result.file_groups[0].clone(); + + assert_eq!( + files + .iter() + .map(|f| f.path().to_string()) + .collect::>(), + vec![ + "bucket/test/pid=1/file1", + "bucket/test/pid=1/file2", + "bucket/test/pid=2/file3", + "bucket/test/pid=2/file4", + ] + ); + + Ok(()) + } + + #[cfg(feature = "parquet")] + #[tokio::test] + async fn test_table_stats_behaviors() -> Result<()> { + use crate::datasource::file_format::parquet::ParquetFormat; + + let testdata = crate::test_util::parquet_test_data(); + let filename = format!("{}/{}", testdata, "alltypes_plain.parquet"); + let table_path = ListingTableUrl::parse(filename)?; + + let ctx = SessionContext::new(); + let state = ctx.state(); + + // Test 1: Default behavior - stats not collected + let opt_default = ListingOptions::new(Arc::new(ParquetFormat::default())); + let schema_default = 
opt_default.infer_schema(&state, &table_path).await?; + let config_default = ListingTableConfig::new(table_path.clone()) + .with_listing_options(opt_default) + .with_schema(schema_default); + + let table_default = ListingTable::try_new(config_default)?; + + let exec_default = table_default.scan(&state, None, &[], None).await?; + assert_eq!( + exec_default.partition_statistics(None)?.num_rows, + Precision::Absent + ); + + // TODO correct byte size: https://github.com/apache/datafusion/issues/14936 + assert_eq!( + exec_default.partition_statistics(None)?.total_byte_size, + Precision::Absent + ); + + // Test 2: Explicitly disable stats + let opt_disabled = ListingOptions::new(Arc::new(ParquetFormat::default())) + .with_collect_stat(false); + let schema_disabled = opt_disabled.infer_schema(&state, &table_path).await?; + let config_disabled = ListingTableConfig::new(table_path.clone()) + .with_listing_options(opt_disabled) + .with_schema(schema_disabled); + let table_disabled = ListingTable::try_new(config_disabled)?; + + let exec_disabled = table_disabled.scan(&state, None, &[], None).await?; + assert_eq!( + exec_disabled.partition_statistics(None)?.num_rows, + Precision::Absent + ); + assert_eq!( + exec_disabled.partition_statistics(None)?.total_byte_size, + Precision::Absent + ); + + // Test 3: Explicitly enable stats + let opt_enabled = ListingOptions::new(Arc::new(ParquetFormat::default())) + .with_collect_stat(true); + let schema_enabled = opt_enabled.infer_schema(&state, &table_path).await?; + let config_enabled = ListingTableConfig::new(table_path) + .with_listing_options(opt_enabled) + .with_schema(schema_enabled); + let table_enabled = ListingTable::try_new(config_enabled)?; + + let exec_enabled = table_enabled.scan(&state, None, &[], None).await?; + assert_eq!( + exec_enabled.partition_statistics(None)?.num_rows, + Precision::Exact(8) + ); + // TODO correct byte size: https://github.com/apache/datafusion/issues/14936 + assert_eq!( + 
exec_enabled.partition_statistics(None)?.total_byte_size, + Precision::Absent, + ); + + Ok(()) + } + + #[tokio::test] + async fn test_insert_into_parameterized() -> Result<()> { + let test_cases = vec![ + // (file_format, batch_size, soft_max_rows, expected_files) + ("json", 10, 10, 2), + ("csv", 10, 10, 2), + #[cfg(feature = "parquet")] + ("parquet", 10, 10, 2), + #[cfg(feature = "parquet")] + ("parquet", 20, 20, 1), + ]; + + for (format, batch_size, soft_max_rows, expected_files) in test_cases { + println!( + "Testing insert with format: {format}, batch_size: {batch_size}, expected files: {expected_files}" + ); + + let mut config_map = HashMap::new(); + config_map.insert( + "datafusion.execution.batch_size".into(), + batch_size.to_string(), + ); + config_map.insert( + "datafusion.execution.soft_max_rows_per_output_file".into(), + soft_max_rows.to_string(), + ); + + let file_extension = match format { + "json" => JsonFormat::default().get_ext(), + "csv" => CsvFormat::default().get_ext(), + #[cfg(feature = "parquet")] + "parquet" => ParquetFormat::default().get_ext(), + _ => unreachable!("Unsupported format"), + }; + + helper_test_append_new_files_to_table( + file_extension, + FileCompressionType::UNCOMPRESSED, + Some(config_map), + expected_files, + ) + .await?; + } + + Ok(()) + } + + #[tokio::test] + async fn test_basic_table_scan() -> Result<()> { + let ctx = SessionContext::new(); + + // Test basic table creation and scanning + let path = "table/file.json"; + register_test_store(&ctx, &[(path, 10)]); + + let format = JsonFormat::default(); + let opt = ListingOptions::new(Arc::new(format)).with_collect_stat(false); + let schema = Schema::new(vec![Field::new("a", DataType::Boolean, false)]); + let table_path = ListingTableUrl::parse("test:///table/")?; + + let config = ListingTableConfig::new(table_path) + .with_listing_options(opt) + .with_schema(Arc::new(schema)); + + let table = ListingTable::try_new(config)?; + + // The scan should work correctly + let 
scan_result = table.scan(&ctx.state(), None, &[], None).await; + assert!(scan_result.is_ok(), "Scan should succeed"); + + // Verify file listing works + let result = table.list_files_for_scan(&ctx.state(), &[], None).await?; + assert!( + !result.file_groups.is_empty(), + "Should list files successfully" + ); + + Ok(()) + } } diff --git a/datafusion/core/src/datasource/listing_table_factory.rs b/datafusion/core/src/datasource/listing_table_factory.rs index 71686c61a8f76..f85f15a6d8c63 100644 --- a/datafusion/core/src/datasource/listing_table_factory.rs +++ b/datafusion/core/src/datasource/listing_table_factory.rs @@ -27,9 +27,9 @@ use crate::datasource::listing::{ }; use crate::execution::context::SessionState; -use arrow::datatypes::{DataType, SchemaRef}; -use datafusion_common::{arrow_datafusion_err, plan_err, DataFusionError, ToDFSchema}; -use datafusion_common::{config_datafusion_err, Result}; +use arrow::datatypes::DataType; +use datafusion_common::{Result, config_datafusion_err}; +use datafusion_common::{ToDFSchema, arrow_datafusion_err, plan_err}; use datafusion_expr::CreateExternalTable; use async_trait::async_trait; @@ -54,7 +54,15 @@ impl TableProviderFactory for ListingTableFactory { cmd: &CreateExternalTable, ) -> Result> { // TODO (https://github.com/apache/datafusion/issues/11600) remove downcast_ref from here. Should file format factory be an extension to session state? - let session_state = state.as_any().downcast_ref::().unwrap(); + let session_state = + state + .as_any() + .downcast_ref::() + .ok_or_else(|| { + datafusion_common::internal_datafusion_err!( + "ListingTableFactory requires SessionState" + ) + })?; let file_format = session_state .get_file_format_factory(cmd.file_type.as_str()) .ok_or(config_datafusion_err!( @@ -63,16 +71,40 @@ impl TableProviderFactory for ListingTableFactory { ))? 
.create(session_state, &cmd.options)?; - let file_extension = get_extension(cmd.location.as_str()); + let mut table_path = + ListingTableUrl::parse(&cmd.location)?.with_table_ref(cmd.name.clone()); + let file_extension = match table_path.is_collection() { + // Setting the extension to be empty instead of allowing the default extension seems + // odd, but was done to ensure existing behavior isn't modified. It seems like this + // could be refactored to either use the default extension or set the fully expected + // extension when compression is included (e.g. ".csv.gz") + true => "", + false => &get_extension(cmd.location.as_str()), + }; + let mut options = ListingOptions::new(file_format) + .with_session_config_options(session_state.config()) + .with_file_extension(file_extension); let (provided_schema, table_partition_cols) = if cmd.schema.fields().is_empty() { + let infer_parts = session_state + .config_options() + .execution + .listing_table_factory_infer_partitions; + let part_cols = if cmd.table_partition_cols.is_empty() && infer_parts { + options + .infer_partitions(session_state, &table_path) + .await? 
+ .into_iter() + } else { + cmd.table_partition_cols.clone().into_iter() + }; + ( None, - cmd.table_partition_cols - .iter() - .map(|x| { + part_cols + .map(|p| { ( - x.clone(), + p, DataType::Dictionary( Box::new(DataType::UInt16), Box::new(DataType::Utf8), @@ -82,7 +114,7 @@ impl TableProviderFactory for ListingTableFactory { .collect::>(), ) } else { - let schema: SchemaRef = Arc::new(cmd.schema.as_ref().to_owned().into()); + let schema = Arc::clone(cmd.schema.inner()); let table_partition_cols = cmd .table_partition_cols .iter() @@ -108,12 +140,7 @@ impl TableProviderFactory for ListingTableFactory { (Some(schema), table_partition_cols) }; - let table_path = ListingTableUrl::parse(&cmd.location)?; - - let options = ListingOptions::new(file_format) - .with_file_extension(file_extension) - .with_session_config_options(session_state.config()) - .with_table_partition_cols(table_partition_cols); + options = options.with_table_partition_cols(table_partition_cols); options .validate_partitions(session_state, &table_path) @@ -125,6 +152,25 @@ impl TableProviderFactory for ListingTableFactory { // specifically for parquet file format. // See: https://github.com/apache/datafusion/issues/7317 None => { + // if the folder then rewrite a file path as 'path/*.parquet' + // to only read the files the reader can understand + if table_path.is_folder() && table_path.get_glob().is_none() { + // Since there are no files yet to infer an actual extension, + // derive the pattern based on compression type. 
+ // So for gzipped CSV the pattern is `*.csv.gz` + let glob = match options.format.compression_type() { + Some(compression) => { + match options.format.get_ext_with_compression(&compression) { + // Use glob based on `FileFormat` extension + Ok(ext) => format!("*.{ext}"), + // Fallback to `file_type`, if not supported by `FileFormat` + Err(_) => format!("*.{}", cmd.file_type.to_lowercase()), + } + } + None => format!("*.{}", cmd.file_type.to_lowercase()), + }; + table_path = table_path.with_glob(glob.as_ref())?; + } let schema = options.infer_schema(session_state, &table_path).await?; let df_schema = Arc::clone(&schema).to_dfschema()?; let column_refs: HashSet<_> = cmd @@ -153,6 +199,16 @@ impl TableProviderFactory for ListingTableFactory { .with_definition(cmd.definition.clone()) .with_constraints(cmd.constraints.clone()) .with_column_defaults(cmd.column_defaults.clone()); + + // Pre-warm statistics cache if collect_statistics is enabled + if session_state.config().collect_statistics() { + let filters = &[]; + let limit = None; + if let Err(e) = table.list_files_for_scan(state, filters, limit).await { + log::warn!("Failed to pre-warm statistics cache: {e}"); + } + } + Ok(Arc::new(table)) } } @@ -168,14 +224,23 @@ fn get_extension(path: &str) -> String { #[cfg(test)] mod tests { - use std::collections::HashMap; - use super::*; use crate::{ datasource::file_format::csv::CsvFormat, execution::context::SessionContext, + test_util::parquet_test_data, }; + use datafusion_execution::cache::CacheAccessor; + use datafusion_execution::cache::cache_manager::CacheManagerConfig; + use datafusion_execution::cache::cache_unit::DefaultFileStatisticsCache; + use datafusion_execution::config::SessionConfig; + use datafusion_execution::runtime_env::RuntimeEnvBuilder; + use glob::Pattern; + use std::collections::HashMap; + use std::fs; + use std::path::PathBuf; - use datafusion_common::{Constraints, DFSchema, TableReference}; + use datafusion_common::parsers::CompressionTypeVariant; 
+ use datafusion_common::{DFSchema, TableReference}; #[tokio::test] async fn test_create_using_non_std_file_ext() { @@ -189,21 +254,14 @@ mod tests { let context = SessionContext::new(); let state = context.state(); let name = TableReference::bare("foo"); - let cmd = CreateExternalTable { + let cmd = CreateExternalTable::builder( name, - location: csv_file.path().to_str().unwrap().to_string(), - file_type: "csv".to_string(), - schema: Arc::new(DFSchema::empty()), - table_partition_cols: vec![], - if_not_exists: false, - temporary: false, - definition: None, - order_exprs: vec![], - unbounded: false, - options: HashMap::from([("format.has_header".into(), "true".into())]), - constraints: Constraints::empty(), - column_defaults: HashMap::new(), - }; + csv_file.path().to_str().unwrap().to_string(), + "csv", + Arc::new(DFSchema::empty()), + ) + .with_options(HashMap::from([("format.has_header".into(), "true".into())])) + .build(); let table_provider = factory.create(&state, &cmd).await.unwrap(); let listing_table = table_provider .as_any() @@ -229,21 +287,14 @@ mod tests { let mut options = HashMap::new(); options.insert("format.schema_infer_max_rec".to_owned(), "1000".to_owned()); options.insert("format.has_header".into(), "true".into()); - let cmd = CreateExternalTable { + let cmd = CreateExternalTable::builder( name, - location: csv_file.path().to_str().unwrap().to_string(), - file_type: "csv".to_string(), - schema: Arc::new(DFSchema::empty()), - table_partition_cols: vec![], - if_not_exists: false, - temporary: false, - definition: None, - order_exprs: vec![], - unbounded: false, - options, - constraints: Constraints::empty(), - column_defaults: HashMap::new(), - }; + csv_file.path().to_str().unwrap().to_string(), + "csv", + Arc::new(DFSchema::empty()), + ) + .with_options(options) + .build(); let table_provider = factory.create(&state, &cmd).await.unwrap(); let listing_table = table_provider .as_any() @@ -257,4 +308,349 @@ mod tests { let listing_options = 
listing_table.options(); assert_eq!(".tbl", listing_options.file_extension); } + + /// Validates that CreateExternalTable with compression + /// searches for gzipped files in a directory location + #[tokio::test] + async fn test_create_using_folder_with_compression() { + let dir = tempfile::tempdir().unwrap(); + + let factory = ListingTableFactory::new(); + let context = SessionContext::new(); + let state = context.state(); + let name = TableReference::bare("foo"); + + let mut options = HashMap::new(); + options.insert("format.schema_infer_max_rec".to_owned(), "1000".to_owned()); + options.insert("format.has_header".into(), "true".into()); + options.insert("format.compression".into(), "gzip".into()); + let cmd = CreateExternalTable::builder( + name, + dir.path().to_str().unwrap().to_string(), + "csv", + Arc::new(DFSchema::empty()), + ) + .with_options(options) + .build(); + let table_provider = factory.create(&state, &cmd).await.unwrap(); + let listing_table = table_provider + .as_any() + .downcast_ref::() + .unwrap(); + + // Verify compression is used + let format = listing_table.options().format.clone(); + let csv_format = format.as_any().downcast_ref::().unwrap(); + let csv_options = csv_format.options().clone(); + assert_eq!(csv_options.compression, CompressionTypeVariant::GZIP); + + let listing_options = listing_table.options(); + assert_eq!("", listing_options.file_extension); + // Glob pattern is set to search for gzipped files + let table_path = listing_table.table_paths().first().unwrap(); + assert_eq!( + table_path.get_glob().clone().unwrap(), + Pattern::new("*.csv.gz").unwrap() + ); + } + + /// Validates that CreateExternalTable without compression + /// searches for normal files in a directory location + #[tokio::test] + async fn test_create_using_folder_without_compression() { + let dir = tempfile::tempdir().unwrap(); + + let factory = ListingTableFactory::new(); + let context = SessionContext::new(); + let state = context.state(); + let name = 
TableReference::bare("foo"); + + let mut options = HashMap::new(); + options.insert("format.schema_infer_max_rec".to_owned(), "1000".to_owned()); + options.insert("format.has_header".into(), "true".into()); + let cmd = CreateExternalTable::builder( + name, + dir.path().to_str().unwrap().to_string(), + "csv", + Arc::new(DFSchema::empty()), + ) + .with_options(options) + .build(); + let table_provider = factory.create(&state, &cmd).await.unwrap(); + let listing_table = table_provider + .as_any() + .downcast_ref::() + .unwrap(); + + let listing_options = listing_table.options(); + assert_eq!("", listing_options.file_extension); + // Glob pattern is set to search for gzipped files + let table_path = listing_table.table_paths().first().unwrap(); + assert_eq!( + table_path.get_glob().clone().unwrap(), + Pattern::new("*.csv").unwrap() + ); + } + + #[tokio::test] + async fn test_odd_directory_names() { + let dir = tempfile::tempdir().unwrap(); + let mut path = PathBuf::from(dir.path()); + path.extend(["odd.v1", "odd.v2"]); + fs::create_dir_all(&path).unwrap(); + + let factory = ListingTableFactory::new(); + let context = SessionContext::new(); + let state = context.state(); + let name = TableReference::bare("foo"); + + let cmd = CreateExternalTable::builder( + name, + String::from(path.to_str().unwrap()), + "parquet", + Arc::new(DFSchema::empty()), + ) + .build(); + let table_provider = factory.create(&state, &cmd).await.unwrap(); + let listing_table = table_provider + .as_any() + .downcast_ref::() + .unwrap(); + + let listing_options = listing_table.options(); + assert_eq!("", listing_options.file_extension); + } + + #[tokio::test] + async fn test_create_with_hive_partitions() { + let dir = tempfile::tempdir().unwrap(); + let mut path = PathBuf::from(dir.path()); + path.extend(["key1=value1", "key2=value2"]); + fs::create_dir_all(&path).unwrap(); + path.push("data.parquet"); + fs::File::create_new(&path).unwrap(); + + let factory = ListingTableFactory::new(); + let 
context = SessionContext::new(); + let state = context.state(); + let name = TableReference::bare("foo"); + + let cmd = CreateExternalTable::builder( + name, + dir.path().to_str().unwrap(), + "parquet", + Arc::new(DFSchema::empty()), + ) + .build(); + let table_provider = factory.create(&state, &cmd).await.unwrap(); + let listing_table = table_provider + .as_any() + .downcast_ref::() + .unwrap(); + + let listing_options = listing_table.options(); + let dtype = + DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8)); + let expected_cols = vec![ + (String::from("key1"), dtype.clone()), + (String::from("key2"), dtype.clone()), + ]; + assert_eq!(expected_cols, listing_options.table_partition_cols); + + // Ensure partition detection can be disabled via config + let factory = ListingTableFactory::new(); + let mut cfg = SessionConfig::new(); + cfg.options_mut() + .execution + .listing_table_factory_infer_partitions = false; + let context = SessionContext::new_with_config(cfg); + let state = context.state(); + let name = TableReference::bare("foo"); + + let cmd = CreateExternalTable::builder( + name, + dir.path().to_str().unwrap().to_string(), + "parquet", + Arc::new(DFSchema::empty()), + ) + .build(); + let table_provider = factory.create(&state, &cmd).await.unwrap(); + let listing_table = table_provider + .as_any() + .downcast_ref::() + .unwrap(); + + let listing_options = listing_table.options(); + assert!(listing_options.table_partition_cols.is_empty()); + } + + #[tokio::test] + async fn test_statistics_cache_prewarming() { + let factory = ListingTableFactory::new(); + + let location = PathBuf::from(parquet_test_data()) + .join("alltypes_tiny_pages_plain.parquet") + .to_string_lossy() + .to_string(); + + // Test with collect_statistics enabled + let file_statistics_cache = Arc::new(DefaultFileStatisticsCache::default()); + let cache_config = CacheManagerConfig::default() + .with_files_statistics_cache(Some(file_statistics_cache.clone())); + let 
runtime = RuntimeEnvBuilder::new() + .with_cache_manager(cache_config) + .build_arc() + .unwrap(); + + let mut config = SessionConfig::new(); + config.options_mut().execution.collect_statistics = true; + let context = SessionContext::new_with_config_rt(config, runtime); + let state = context.state(); + let name = TableReference::bare("test"); + + let cmd = CreateExternalTable::builder( + name, + location.clone(), + "parquet", + Arc::new(DFSchema::empty()), + ) + .build(); + + let _table_provider = factory.create(&state, &cmd).await.unwrap(); + + assert!( + file_statistics_cache.len() > 0, + "Statistics cache should be pre-warmed when collect_statistics is enabled" + ); + + // Test with collect_statistics disabled + let file_statistics_cache = Arc::new(DefaultFileStatisticsCache::default()); + let cache_config = CacheManagerConfig::default() + .with_files_statistics_cache(Some(file_statistics_cache.clone())); + let runtime = RuntimeEnvBuilder::new() + .with_cache_manager(cache_config) + .build_arc() + .unwrap(); + + let mut config = SessionConfig::new(); + config.options_mut().execution.collect_statistics = false; + let context = SessionContext::new_with_config_rt(config, runtime); + let state = context.state(); + let name = TableReference::bare("test"); + + let cmd = CreateExternalTable::builder( + name, + location, + "parquet", + Arc::new(DFSchema::empty()), + ) + .build(); + + let _table_provider = factory.create(&state, &cmd).await.unwrap(); + + assert_eq!( + file_statistics_cache.len(), + 0, + "Statistics cache should not be pre-warmed when collect_statistics is disabled" + ); + } + + #[tokio::test] + async fn test_create_with_invalid_session() { + use async_trait::async_trait; + use datafusion_catalog::Session; + use datafusion_common::Result; + use datafusion_common::config::TableOptions; + use datafusion_execution::TaskContext; + use datafusion_execution::config::SessionConfig; + use datafusion_physical_expr::PhysicalExpr; + use 
datafusion_physical_plan::ExecutionPlan; + use std::any::Any; + use std::collections::HashMap; + use std::sync::Arc; + + // A mock Session that is NOT SessionState + #[derive(Debug)] + struct MockSession; + + #[async_trait] + impl Session for MockSession { + fn session_id(&self) -> &str { + "mock_session" + } + fn config(&self) -> &SessionConfig { + unimplemented!() + } + async fn create_physical_plan( + &self, + _logical_plan: &datafusion_expr::LogicalPlan, + ) -> Result> { + unimplemented!() + } + fn create_physical_expr( + &self, + _expr: datafusion_expr::Expr, + _df_schema: &DFSchema, + ) -> Result> { + unimplemented!() + } + fn scalar_functions( + &self, + ) -> &HashMap> { + unimplemented!() + } + fn aggregate_functions( + &self, + ) -> &HashMap> { + unimplemented!() + } + fn window_functions( + &self, + ) -> &HashMap> { + unimplemented!() + } + fn runtime_env(&self) -> &Arc { + unimplemented!() + } + fn execution_props( + &self, + ) -> &datafusion_expr::execution_props::ExecutionProps { + unimplemented!() + } + fn as_any(&self) -> &dyn Any { + self + } + fn table_options(&self) -> &TableOptions { + unimplemented!() + } + fn table_options_mut(&mut self) -> &mut TableOptions { + unimplemented!() + } + fn task_ctx(&self) -> Arc { + unimplemented!() + } + } + + let factory = ListingTableFactory::new(); + let mock_session = MockSession; + + let name = TableReference::bare("foo"); + let cmd = CreateExternalTable::builder( + name, + "foo.csv".to_string(), + "csv", + Arc::new(DFSchema::empty()), + ) + .build(); + + // This should return an error, not panic + let result = factory.create(&mock_session, &cmd).await; + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .strip_backtrace() + .contains("Internal error: ListingTableFactory requires SessionState") + ); + } } diff --git a/datafusion/core/src/datasource/memory_test.rs b/datafusion/core/src/datasource/memory_test.rs index 381000ab8ee1e..c7721cafb02ea 100644 --- 
a/datafusion/core/src/datasource/memory_test.rs +++ b/datafusion/core/src/datasource/memory_test.rs @@ -19,7 +19,7 @@ mod tests { use crate::datasource::MemTable; - use crate::datasource::{provider_as_source, DefaultTableSource}; + use crate::datasource::{DefaultTableSource, provider_as_source}; use crate::physical_plan::collect; use crate::prelude::SessionContext; use arrow::array::{AsArray, Int32Array}; @@ -29,8 +29,8 @@ mod tests { use arrow_schema::SchemaRef; use datafusion_catalog::TableProvider; use datafusion_common::{DataFusionError, Result}; - use datafusion_expr::dml::InsertOp; use datafusion_expr::LogicalPlanBuilder; + use datafusion_expr::dml::InsertOp; use futures::StreamExt; use std::collections::HashMap; use std::sync::Arc; @@ -130,12 +130,15 @@ mod tests { .scan(&session_ctx.state(), Some(&projection), &[], None) .await { - Err(DataFusionError::ArrowError(ArrowError::SchemaError(e), _)) => { - assert_eq!( - "\"project index 4 out of bounds, max field 3\"", - format!("{e:?}") - ) - } + Err(DataFusionError::ArrowError(err, _)) => match err.as_ref() { + ArrowError::SchemaError(e) => { + assert_eq!( + "\"project index 4 out of bounds, max field 3\"", + format!("{e:?}") + ) + } + _ => panic!("unexpected error"), + }, res => panic!("Scan should failed on invalid projection, got {res:?}"), }; @@ -326,12 +329,11 @@ mod tests { ); let col = batch.column(0).as_primitive::(); assert_eq!(col.len(), 1, "expected 1 row, got {}", col.len()); - let val = col - .iter() + + col.iter() .next() .expect("had value") - .expect("expected non null"); - val + .expect("expected non null") } // Test inserting a single batch of data into a single partition @@ -443,7 +445,7 @@ mod tests { .unwrap_err(); // Ensure that there is a descriptive error message assert_eq!( - "Error during planning: Cannot insert into MemTable with zero partitions", + "Error during planning: No partitions provided, expected at least one partition", experiment_result.strip_backtrace() ); Ok(()) diff 
--git a/datafusion/core/src/datasource/mod.rs b/datafusion/core/src/datasource/mod.rs index f0c6771515a7f..c0cb9b5fa0fe6 100644 --- a/datafusion/core/src/datasource/mod.rs +++ b/datafusion/core/src/datasource/mod.rs @@ -31,7 +31,7 @@ mod view_test; // backwards compatibility pub use self::default_table_source::{ - provider_as_source, source_as_provider, DefaultTableSource, + DefaultTableSource, provider_as_source, source_as_provider, }; pub use self::memory::MemTable; pub use self::view::ViewTable; @@ -45,40 +45,46 @@ pub use datafusion_catalog::view; pub use datafusion_datasource::schema_adapter; pub use datafusion_datasource::sink; pub use datafusion_datasource::source; +pub use datafusion_datasource::table_schema; pub use datafusion_execution::object_store; pub use datafusion_physical_expr::create_ordering; #[cfg(all(test, feature = "parquet"))] mod tests { - use datafusion_datasource::schema_adapter::{ - DefaultSchemaAdapterFactory, SchemaAdapter, SchemaAdapterFactory, SchemaMapper, - }; - use crate::prelude::SessionContext; + use ::object_store::{ObjectMeta, path::Path}; use arrow::{ - array::{Int32Array, StringArray}, + array::Int32Array, datatypes::{DataType, Field, Schema, SchemaRef}, record_batch::RecordBatch, }; - use datafusion_common::{record_batch, test_util::batches_to_sort_string}; + use datafusion_common::{ + Result, ScalarValue, + test_util::batches_to_sort_string, + tree_node::{Transformed, TransformedResult, TreeNode}, + }; use datafusion_datasource::{ - file::FileSource, file_scan_config::FileScanConfigBuilder, - source::DataSourceExec, PartitionedFile, + PartitionedFile, file_scan_config::FileScanConfigBuilder, source::DataSourceExec, }; use datafusion_datasource_parquet::source::ParquetSource; - use datafusion_execution::object_store::ObjectStoreUrl; + use datafusion_physical_expr::expressions::{Column, Literal}; + use datafusion_physical_expr_adapter::{ + PhysicalExprAdapter, PhysicalExprAdapterFactory, + }; + use 
datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_plan::collect; - use object_store::{path::Path, ObjectMeta}; use std::{fs, sync::Arc}; use tempfile::TempDir; + use url::Url; #[tokio::test] - async fn can_override_schema_adapter() { - // Test shows that SchemaAdapter can add a column that doesn't existing in the - // record batches returned from parquet. This can be useful for schema evolution + async fn can_override_physical_expr_adapter() { + // Test shows that PhysicalExprAdapter can add a column that doesn't exist in the + // record batches returned from parquet. This can be useful for schema evolution // where older files may not have all columns. + use datafusion_execution::object_store::ObjectStoreUrl; let tmp_dir = TempDir::new().unwrap(); let table_dir = tmp_dir.path().join("parquet_test"); fs::DirBuilder::new().create(table_dir.as_path()).unwrap(); @@ -98,7 +104,8 @@ mod tests { writer.write(&rec_batch).unwrap(); writer.close().unwrap(); - let location = Path::parse(path.to_str().unwrap()).unwrap(); + let url = Url::from_file_path(path.canonicalize().unwrap()).unwrap(); + let location = Path::from_url_path(url.path()).unwrap(); let metadata = fs::metadata(path.as_path()).expect("Local file metadata"); let meta = ObjectMeta { location, @@ -108,28 +115,18 @@ mod tests { version: None, }; - let partitioned_file = PartitionedFile { - object_meta: meta, - partition_values: vec![], - range: None, - statistics: None, - extensions: None, - metadata_size_hint: None, - }; + let partitioned_file = PartitionedFile::new_from_meta(meta); let f1 = Field::new("id", DataType::Int32, true); let f2 = Field::new("extra_column", DataType::Utf8, true); let schema = Arc::new(Schema::new(vec![f1.clone(), f2.clone()])); - let source = ParquetSource::default() - .with_schema_adapter_factory(Arc::new(TestSchemaAdapterFactory {})); - let base_conf = FileScanConfigBuilder::new( - ObjectStoreUrl::local_filesystem(), - schema, - source, - ) - 
.with_file(partitioned_file) - .build(); + let source = Arc::new(ParquetSource::new(Arc::clone(&schema))); + let base_conf = + FileScanConfigBuilder::new(ObjectStoreUrl::local_filesystem(), source) + .with_file(partitioned_file) + .with_expr_adapter(Some(Arc::new(TestPhysicalExprAdapterFactory))) + .build(); let parquet_exec = DataSourceExec::from_data_source(base_conf); @@ -137,134 +134,52 @@ mod tests { let task_ctx = session_ctx.task_ctx(); let read = collect(parquet_exec, task_ctx).await.unwrap(); - insta::assert_snapshot!(batches_to_sort_string(&read),@r###" + insta::assert_snapshot!(batches_to_sort_string(&read),@r" +----+--------------+ | id | extra_column | +----+--------------+ | 1 | foo | +----+--------------+ - "###); - } - - #[test] - fn default_schema_adapter() { - let table_schema = Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Utf8, true), - ]); - - // file has a subset of the table schema fields and different type - let file_schema = Schema::new(vec![ - Field::new("c", DataType::Float64, true), // not in table schema - Field::new("b", DataType::Float64, true), - ]); - - let adapter = DefaultSchemaAdapterFactory::from_schema(Arc::new(table_schema)); - let (mapper, indices) = adapter.map_schema(&file_schema).unwrap(); - assert_eq!(indices, vec![1]); - - let file_batch = record_batch!(("b", Float64, vec![1.0, 2.0])).unwrap(); - - let mapped_batch = mapper.map_batch(file_batch).unwrap(); - - // the mapped batch has the correct schema and the "b" column has been cast to Utf8 - let expected_batch = record_batch!( - ("a", Int32, vec![None, None]), // missing column filled with nulls - ("b", Utf8, vec!["1.0", "2.0"]) // b was cast to string and order was changed - ) - .unwrap(); - assert_eq!(mapped_batch, expected_batch); - } - - #[test] - fn default_schema_adapter_non_nullable_columns() { - let table_schema = Schema::new(vec![ - Field::new("a", DataType::Int32, false), // "a"" is declared non nullable - 
Field::new("b", DataType::Utf8, true), - ]); - let file_schema = Schema::new(vec![ - // since file doesn't have "a" it will be filled with nulls - Field::new("b", DataType::Float64, true), - ]); - - let adapter = DefaultSchemaAdapterFactory::from_schema(Arc::new(table_schema)); - let (mapper, indices) = adapter.map_schema(&file_schema).unwrap(); - assert_eq!(indices, vec![0]); - - let file_batch = record_batch!(("b", Float64, vec![1.0, 2.0])).unwrap(); - - // Mapping fails because it tries to fill in a non-nullable column with nulls - let err = mapper.map_batch(file_batch).unwrap_err().to_string(); - assert!(err.contains("Invalid argument error: Column 'a' is declared as non-nullable but contains null values"), "{err}"); + "); } #[derive(Debug)] - struct TestSchemaAdapterFactory; + struct TestPhysicalExprAdapterFactory; - impl SchemaAdapterFactory for TestSchemaAdapterFactory { + impl PhysicalExprAdapterFactory for TestPhysicalExprAdapterFactory { fn create( &self, - projected_table_schema: SchemaRef, - _table_schema: SchemaRef, - ) -> Box { - Box::new(TestSchemaAdapter { - table_schema: projected_table_schema, - }) + _logical_file_schema: SchemaRef, + physical_file_schema: SchemaRef, + ) -> Result> { + Ok(Arc::new(TestPhysicalExprAdapter { + physical_file_schema, + })) } } - struct TestSchemaAdapter { - /// Schema for the table - table_schema: SchemaRef, + #[derive(Debug)] + struct TestPhysicalExprAdapter { + physical_file_schema: SchemaRef, } - impl SchemaAdapter for TestSchemaAdapter { - fn map_column_index(&self, index: usize, file_schema: &Schema) -> Option { - let field = self.table_schema.field(index); - Some(file_schema.fields.find(field.name())?.0) - } - - fn map_schema( - &self, - file_schema: &Schema, - ) -> datafusion_common::Result<(Arc, Vec)> { - let mut projection = Vec::with_capacity(file_schema.fields().len()); - - for (file_idx, file_field) in file_schema.fields.iter().enumerate() { - if self.table_schema.fields().find(file_field.name()).is_some() 
{ - projection.push(file_idx); + impl PhysicalExprAdapter for TestPhysicalExprAdapter { + fn rewrite(&self, expr: Arc) -> Result> { + expr.transform(|e| { + if let Some(column) = e.as_any().downcast_ref::() { + // If column is "extra_column" and missing from physical schema, inject "foo" + if column.name() == "extra_column" + && self.physical_file_schema.index_of("extra_column").is_err() + { + return Ok(Transformed::yes(Arc::new(Literal::new( + ScalarValue::Utf8(Some("foo".to_string())), + )) + as Arc)); + } } - } - - Ok((Arc::new(TestSchemaMapping {}), projection)) - } - } - - #[derive(Debug)] - struct TestSchemaMapping {} - - impl SchemaMapper for TestSchemaMapping { - fn map_batch( - &self, - batch: RecordBatch, - ) -> datafusion_common::Result { - let f1 = Field::new("id", DataType::Int32, true); - let f2 = Field::new("extra_column", DataType::Utf8, true); - - let schema = Arc::new(Schema::new(vec![f1, f2])); - - let extra_column = Arc::new(StringArray::from(vec!["foo"])); - let mut new_columns = batch.columns().to_vec(); - new_columns.push(extra_column); - - Ok(RecordBatch::try_new(schema, new_columns).unwrap()) - } - - fn map_column_statistics( - &self, - _file_col_statistics: &[datafusion_common::ColumnStatistics], - ) -> datafusion_common::Result> { - unimplemented!() + Ok(Transformed::no(e)) + }) + .data() } } } diff --git a/datafusion/core/src/datasource/physical_plan/arrow.rs b/datafusion/core/src/datasource/physical_plan/arrow.rs new file mode 100644 index 0000000000000..392eaa8c4be49 --- /dev/null +++ b/datafusion/core/src/datasource/physical_plan/arrow.rs @@ -0,0 +1,23 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Reexports the [`datafusion_datasource_arrow::source`] module, containing [Arrow] based [`FileSource`]. +//! +//! [Arrow]: https://arrow.apache.org/docs/python/ipc.html +//! [`FileSource`]: datafusion_datasource::file::FileSource + +pub use datafusion_datasource_arrow::source::*; diff --git a/datafusion/core/src/datasource/physical_plan/arrow_file.rs b/datafusion/core/src/datasource/physical_plan/arrow_file.rs deleted file mode 100644 index 6de72aa8ff720..0000000000000 --- a/datafusion/core/src/datasource/physical_plan/arrow_file.rs +++ /dev/null @@ -1,239 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -use std::any::Any; -use std::sync::Arc; - -use crate::datasource::physical_plan::{FileMeta, FileOpenFuture, FileOpener}; -use crate::error::Result; -use datafusion_datasource::schema_adapter::SchemaAdapterFactory; -use datafusion_datasource::{as_file_source, impl_schema_adapter_methods}; - -use arrow::buffer::Buffer; -use arrow::datatypes::SchemaRef; -use arrow_ipc::reader::FileDecoder; -use datafusion_common::Statistics; -use datafusion_datasource::file::FileSource; -use datafusion_datasource::file_scan_config::FileScanConfig; -use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; - -use futures::StreamExt; -use itertools::Itertools; -use object_store::{GetOptions, GetRange, GetResultPayload, ObjectStore}; - -/// Arrow configuration struct that is given to DataSourceExec -/// Does not hold anything special, since [`FileScanConfig`] is sufficient for arrow -#[derive(Clone, Default)] -pub struct ArrowSource { - metrics: ExecutionPlanMetricsSet, - projected_statistics: Option, - schema_adapter_factory: Option>, -} - -impl From for Arc { - fn from(source: ArrowSource) -> Self { - as_file_source(source) - } -} - -impl FileSource for ArrowSource { - fn create_file_opener( - &self, - object_store: Arc, - base_config: &FileScanConfig, - _partition: usize, - ) -> Arc { - Arc::new(ArrowOpener { - object_store, - projection: base_config.file_column_projection_indices(), - }) - } - - fn as_any(&self) -> &dyn Any { - self - } - - fn with_batch_size(&self, _batch_size: usize) -> Arc { - Arc::new(Self { ..self.clone() }) - } - - fn with_schema(&self, _schema: SchemaRef) -> Arc { - Arc::new(Self { ..self.clone() }) - } - fn with_statistics(&self, statistics: Statistics) -> Arc { - let mut conf = self.clone(); - conf.projected_statistics = Some(statistics); - Arc::new(conf) - } - - fn with_projection(&self, _config: &FileScanConfig) -> Arc { - Arc::new(Self { ..self.clone() }) - } - - fn metrics(&self) -> &ExecutionPlanMetricsSet { - &self.metrics - } - - fn 
statistics(&self) -> Result { - let statistics = &self.projected_statistics; - Ok(statistics - .clone() - .expect("projected_statistics must be set")) - } - - fn file_type(&self) -> &str { - "arrow" - } - - impl_schema_adapter_methods!(); -} - -/// The struct arrow that implements `[FileOpener]` trait -pub struct ArrowOpener { - pub object_store: Arc, - pub projection: Option>, -} - -impl FileOpener for ArrowOpener { - fn open(&self, file_meta: FileMeta) -> Result { - let object_store = Arc::clone(&self.object_store); - let projection = self.projection.clone(); - Ok(Box::pin(async move { - let range = file_meta.range.clone(); - match range { - None => { - let r = object_store.get(file_meta.location()).await?; - match r.payload { - #[cfg(not(target_arch = "wasm32"))] - GetResultPayload::File(file, _) => { - let arrow_reader = arrow::ipc::reader::FileReader::try_new( - file, projection, - )?; - Ok(futures::stream::iter(arrow_reader).boxed()) - } - GetResultPayload::Stream(_) => { - let bytes = r.bytes().await?; - let cursor = std::io::Cursor::new(bytes); - let arrow_reader = arrow::ipc::reader::FileReader::try_new( - cursor, projection, - )?; - Ok(futures::stream::iter(arrow_reader).boxed()) - } - } - } - Some(range) => { - // range is not none, the file maybe split into multiple parts to scan in parallel - // get footer_len firstly - let get_option = GetOptions { - range: Some(GetRange::Suffix(10)), - ..Default::default() - }; - let get_result = object_store - .get_opts(file_meta.location(), get_option) - .await?; - let footer_len_buf = get_result.bytes().await?; - let footer_len = arrow_ipc::reader::read_footer_length( - footer_len_buf[..].try_into().unwrap(), - )?; - // read footer according to footer_len - let get_option = GetOptions { - range: Some(GetRange::Suffix(10 + (footer_len as u64))), - ..Default::default() - }; - let get_result = object_store - .get_opts(file_meta.location(), get_option) - .await?; - let footer_buf = get_result.bytes().await?; - let 
footer = arrow_ipc::root_as_footer( - footer_buf[..footer_len].try_into().unwrap(), - ) - .map_err(|err| { - arrow::error::ArrowError::ParseError(format!( - "Unable to get root as footer: {err:?}" - )) - })?; - // build decoder according to footer & projection - let schema = - arrow_ipc::convert::fb_to_schema(footer.schema().unwrap()); - let mut decoder = FileDecoder::new(schema.into(), footer.version()); - if let Some(projection) = projection { - decoder = decoder.with_projection(projection); - } - let dict_ranges = footer - .dictionaries() - .iter() - .flatten() - .map(|block| { - let block_len = - block.bodyLength() as u64 + block.metaDataLength() as u64; - let block_offset = block.offset() as u64; - block_offset..block_offset + block_len - }) - .collect_vec(); - let dict_results = object_store - .get_ranges(file_meta.location(), &dict_ranges) - .await?; - for (dict_block, dict_result) in - footer.dictionaries().iter().flatten().zip(dict_results) - { - decoder - .read_dictionary(dict_block, &Buffer::from(dict_result))?; - } - - // filter recordbatches according to range - let recordbatches = footer - .recordBatches() - .iter() - .flatten() - .filter(|block| { - let block_offset = block.offset() as u64; - block_offset >= range.start as u64 - && block_offset < range.end as u64 - }) - .copied() - .collect_vec(); - - let recordbatch_ranges = recordbatches - .iter() - .map(|block| { - let block_len = - block.bodyLength() as u64 + block.metaDataLength() as u64; - let block_offset = block.offset() as u64; - block_offset..block_offset + block_len - }) - .collect_vec(); - - let recordbatch_results = object_store - .get_ranges(file_meta.location(), &recordbatch_ranges) - .await?; - - Ok(futures::stream::iter( - recordbatches - .into_iter() - .zip(recordbatch_results) - .filter_map(move |(block, data)| { - decoder - .read_record_batch(&block, &Buffer::from(data)) - .transpose() - }), - ) - .boxed()) - } - } - })) - } -} diff --git 
a/datafusion/core/src/datasource/physical_plan/avro.rs b/datafusion/core/src/datasource/physical_plan/avro.rs index 8a00af959ccc9..2954a47403299 100644 --- a/datafusion/core/src/datasource/physical_plan/avro.rs +++ b/datafusion/core/src/datasource/physical_plan/avro.rs @@ -31,21 +31,21 @@ mod tests { use crate::test::object_store::local_unpartitioned_file; use arrow::datatypes::{DataType, Field, SchemaBuilder}; use datafusion_common::test_util::batches_to_string; - use datafusion_common::{test_util, Result, ScalarValue}; + use datafusion_common::{Result, ScalarValue, test_util}; use datafusion_datasource::file_format::FileFormat; use datafusion_datasource::file_scan_config::FileScanConfigBuilder; - use datafusion_datasource::PartitionedFile; - use datafusion_datasource_avro::source::AvroSource; + use datafusion_datasource::{PartitionedFile, TableSchema}; use datafusion_datasource_avro::AvroFormat; + use datafusion_datasource_avro::source::AvroSource; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_physical_plan::ExecutionPlan; use datafusion_datasource::source::DataSourceExec; use futures::StreamExt; use insta::assert_snapshot; + use object_store::ObjectStore; use object_store::chunked::ChunkedStore; use object_store::local::LocalFileSystem; - use object_store::ObjectStore; use rstest::*; use url::Url; @@ -81,15 +81,11 @@ mod tests { .infer_schema(&state, &store, std::slice::from_ref(&meta)) .await?; - let source = Arc::new(AvroSource::new()); - let conf = FileScanConfigBuilder::new( - ObjectStoreUrl::local_filesystem(), - file_schema, - source, - ) - .with_file(meta.into()) - .with_projection(Some(vec![0, 1, 2])) - .build(); + let source = Arc::new(AvroSource::new(Arc::clone(&file_schema))); + let conf = FileScanConfigBuilder::new(ObjectStoreUrl::local_filesystem(), source) + .with_file(meta.into()) + .with_projection_indices(Some(vec![0, 1, 2]))? 
+ .build(); let source_exec = DataSourceExec::from_data_source(conf); assert_eq!( @@ -109,20 +105,20 @@ mod tests { .expect("plan iterator empty") .expect("plan iterator returned an error"); - insta::allow_duplicates! {assert_snapshot!(batches_to_string(&[batch]), @r###" - +----+----------+-------------+ - | id | bool_col | tinyint_col | - +----+----------+-------------+ - | 4 | true | 0 | - | 5 | false | 1 | - | 6 | true | 0 | - | 7 | false | 1 | - | 2 | true | 0 | - | 3 | false | 1 | - | 0 | true | 0 | - | 1 | false | 1 | - +----+----------+-------------+ - "###);} + insta::allow_duplicates! {assert_snapshot!(batches_to_string(&[batch]), @r" + +----+----------+-------------+ + | id | bool_col | tinyint_col | + +----+----------+-------------+ + | 4 | true | 0 | + | 5 | false | 1 | + | 6 | true | 0 | + | 7 | false | 1 | + | 2 | true | 0 | + | 3 | false | 1 | + | 0 | true | 0 | + | 1 | false | 1 | + +----+----------+-------------+ + ");} let batch = results.next().await; assert!(batch.is_none()); @@ -157,10 +153,10 @@ mod tests { // Include the missing column in the projection let projection = Some(vec![0, 1, 2, actual_schema.fields().len()]); - let source = Arc::new(AvroSource::new()); - let conf = FileScanConfigBuilder::new(object_store_url, file_schema, source) + let source = Arc::new(AvroSource::new(Arc::clone(&file_schema))); + let conf = FileScanConfigBuilder::new(object_store_url, source) .with_file(meta.into()) - .with_projection(projection) + .with_projection_indices(projection)? .build(); let source_exec = DataSourceExec::from_data_source(conf); @@ -182,20 +178,20 @@ mod tests { .expect("plan iterator empty") .expect("plan iterator returned an error"); - insta::allow_duplicates! 
{assert_snapshot!(batches_to_string(&[batch]), @r###" - +----+----------+-------------+-------------+ - | id | bool_col | tinyint_col | missing_col | - +----+----------+-------------+-------------+ - | 4 | true | 0 | | - | 5 | false | 1 | | - | 6 | true | 0 | | - | 7 | false | 1 | | - | 2 | true | 0 | | - | 3 | false | 1 | | - | 0 | true | 0 | | - | 1 | false | 1 | | - +----+----------+-------------+-------------+ - "###);} + insta::allow_duplicates! {assert_snapshot!(batches_to_string(&[batch]), @r" + +----+----------+-------------+-------------+ + | id | bool_col | tinyint_col | missing_col | + +----+----------+-------------+-------------+ + | 4 | true | 0 | | + | 5 | false | 1 | | + | 6 | true | 0 | | + | 7 | false | 1 | | + | 2 | true | 0 | | + | 3 | false | 1 | | + | 0 | true | 0 | | + | 1 | false | 1 | | + +----+----------+-------------+-------------+ + ");} let batch = results.next().await; assert!(batch.is_none()); @@ -227,13 +223,16 @@ mod tests { partitioned_file.partition_values = vec![ScalarValue::from("2021-10-26")]; let projection = Some(vec![0, 1, file_schema.fields().len(), 2]); - let source = Arc::new(AvroSource::new()); - let conf = FileScanConfigBuilder::new(object_store_url, file_schema, source) + let table_schema = TableSchema::new( + file_schema.clone(), + vec![Arc::new(Field::new("date", DataType::Utf8, false))], + ); + let source = Arc::new(AvroSource::new(table_schema.clone())); + let conf = FileScanConfigBuilder::new(object_store_url, source) // select specific columns of the files as well as the partitioning // column which is supposed to be the last column in the table schema. - .with_projection(projection) + .with_projection_indices(projection)? 
.with_file(partitioned_file) - .with_table_partition_cols(vec![Field::new("date", DataType::Utf8, false)]) .build(); let source_exec = DataSourceExec::from_data_source(conf); @@ -256,20 +255,20 @@ mod tests { .expect("plan iterator empty") .expect("plan iterator returned an error"); - insta::allow_duplicates! {assert_snapshot!(batches_to_string(&[batch]), @r###" - +----+----------+------------+-------------+ - | id | bool_col | date | tinyint_col | - +----+----------+------------+-------------+ - | 4 | true | 2021-10-26 | 0 | - | 5 | false | 2021-10-26 | 1 | - | 6 | true | 2021-10-26 | 0 | - | 7 | false | 2021-10-26 | 1 | - | 2 | true | 2021-10-26 | 0 | - | 3 | false | 2021-10-26 | 1 | - | 0 | true | 2021-10-26 | 0 | - | 1 | false | 2021-10-26 | 1 | - +----+----------+------------+-------------+ - "###);} + insta::allow_duplicates! {assert_snapshot!(batches_to_string(&[batch]), @r" + +----+----------+------------+-------------+ + | id | bool_col | date | tinyint_col | + +----+----------+------------+-------------+ + | 4 | true | 2021-10-26 | 0 | + | 5 | false | 2021-10-26 | 1 | + | 6 | true | 2021-10-26 | 0 | + | 7 | false | 2021-10-26 | 1 | + | 2 | true | 2021-10-26 | 0 | + | 3 | false | 2021-10-26 | 1 | + | 0 | true | 2021-10-26 | 0 | + | 1 | false | 2021-10-26 | 1 | + +----+----------+------------+-------------+ + ");} let batch = results.next().await; assert!(batch.is_none()); diff --git a/datafusion/core/src/datasource/physical_plan/csv.rs b/datafusion/core/src/datasource/physical_plan/csv.rs index 3ef4030134520..82c47b6c7281c 100644 --- a/datafusion/core/src/datasource/physical_plan/csv.rs +++ b/datafusion/core/src/datasource/physical_plan/csv.rs @@ -29,18 +29,21 @@ mod tests { use std::io::Write; use std::sync::Arc; + use datafusion_datasource::TableSchema; use datafusion_datasource_csv::CsvFormat; - use object_store::ObjectStore; + use object_store::{ObjectStore, ObjectStoreExt}; + use crate::datasource::file_format::FileFormat; use 
crate::prelude::CsvReadOptions; use crate::prelude::SessionContext; use crate::test::partitioned_file_groups; + use datafusion_common::config::CsvOptions; use datafusion_common::test_util::arrow_test_data; use datafusion_common::test_util::batches_to_string; - use datafusion_common::{assert_batches_eq, Result}; + use datafusion_common::{Result, assert_batches_eq}; use datafusion_execution::config::SessionConfig; - use datafusion_physical_plan::metrics::MetricsSet; use datafusion_physical_plan::ExecutionPlan; + use datafusion_physical_plan::metrics::MetricsSet; #[cfg(feature = "compression")] use datafusion_datasource::file_compression_type::FileCompressionType; @@ -94,34 +97,41 @@ mod tests { async fn csv_exec_with_projection( file_compression_type: FileCompressionType, ) -> Result<()> { + use datafusion_datasource::TableSchema; + let session_ctx = SessionContext::new(); let task_ctx = session_ctx.task_ctx(); let file_schema = aggr_test_schema(); let path = format!("{}/csv", arrow_test_data()); let filename = "aggregate_test_100.csv"; let tmp_dir = TempDir::new()?; + let csv_format: Arc = Arc::new(CsvFormat::default()); let file_groups = partitioned_file_groups( path.as_str(), filename, 1, - Arc::new(CsvFormat::default()), + &csv_format, file_compression_type.to_owned(), tmp_dir.path(), )?; - let source = Arc::new(CsvSource::new(true, b',', b'"')); - let config = FileScanConfigBuilder::from(partitioned_csv_config( - file_schema, - file_groups, - source, - )) - .with_file_compression_type(file_compression_type) - .with_newlines_in_values(false) - .with_projection(Some(vec![0, 2, 4])) - .build(); - - assert_eq!(13, config.file_schema.fields().len()); + let options = CsvOptions { + has_header: Some(true), + delimiter: b',', + quote: b'"', + ..Default::default() + }; + let table_schema = TableSchema::from_file_schema(Arc::clone(&file_schema)); + let source = + Arc::new(CsvSource::new(table_schema.clone()).with_csv_options(options)); + let config = + 
FileScanConfigBuilder::from(partitioned_csv_config(file_groups, source)?) + .with_file_compression_type(file_compression_type) + .with_projection_indices(Some(vec![0, 2, 4]))? + .build(); + + assert_eq!(13, config.file_schema().fields().len()); let csv = DataSourceExec::from_data_source(config); assert_eq!(3, csv.schema().fields().len()); @@ -131,17 +141,17 @@ mod tests { assert_eq!(3, batch.num_columns()); assert_eq!(100, batch.num_rows()); - insta::allow_duplicates! {assert_snapshot!(batches_to_string(&[batch.slice(0, 5)]), @r###" - +----+-----+------------+ - | c1 | c3 | c5 | - +----+-----+------------+ - | c | 1 | 2033001162 | - | d | -40 | 706441268 | - | b | 29 | 994303988 | - | a | -85 | 1171968280 | - | b | -82 | 1824882165 | - +----+-----+------------+ - "###);} + insta::allow_duplicates! {assert_snapshot!(batches_to_string(&[batch.slice(0, 5)]), @r" + +----+-----+------------+ + | c1 | c3 | c5 | + +----+-----+------------+ + | c | 1 | 2033001162 | + | d | -40 | 706441268 | + | b | 29 | 994303988 | + | a | -85 | 1171968280 | + | b | -82 | 1824882165 | + +----+-----+------------+ + ");} Ok(()) } @@ -158,6 +168,8 @@ mod tests { async fn csv_exec_with_mixed_order_projection( file_compression_type: FileCompressionType, ) -> Result<()> { + use datafusion_datasource::TableSchema; + let cfg = SessionConfig::new().set_str("datafusion.catalog.has_header", "true"); let session_ctx = SessionContext::new_with_config(cfg); let task_ctx = session_ctx.task_ctx(); @@ -165,27 +177,32 @@ mod tests { let path = format!("{}/csv", arrow_test_data()); let filename = "aggregate_test_100.csv"; let tmp_dir = TempDir::new()?; + let csv_format: Arc = Arc::new(CsvFormat::default()); let file_groups = partitioned_file_groups( path.as_str(), filename, 1, - Arc::new(CsvFormat::default()), + &csv_format, file_compression_type.to_owned(), tmp_dir.path(), )?; - let source = Arc::new(CsvSource::new(true, b',', b'"')); - let config = FileScanConfigBuilder::from(partitioned_csv_config( - 
file_schema, - file_groups, - source, - )) - .with_newlines_in_values(false) - .with_file_compression_type(file_compression_type.to_owned()) - .with_projection(Some(vec![4, 0, 2])) - .build(); - assert_eq!(13, config.file_schema.fields().len()); + let options = CsvOptions { + has_header: Some(true), + delimiter: b',', + quote: b'"', + ..Default::default() + }; + let table_schema = TableSchema::from_file_schema(Arc::clone(&file_schema)); + let source = + Arc::new(CsvSource::new(table_schema.clone()).with_csv_options(options)); + let config = + FileScanConfigBuilder::from(partitioned_csv_config(file_groups, source)?) + .with_file_compression_type(file_compression_type.to_owned()) + .with_projection_indices(Some(vec![4, 0, 2]))? + .build(); + assert_eq!(13, config.file_schema().fields().len()); let csv = DataSourceExec::from_data_source(config); assert_eq!(3, csv.schema().fields().len()); @@ -194,17 +211,17 @@ mod tests { assert_eq!(3, batch.num_columns()); assert_eq!(100, batch.num_rows()); - insta::allow_duplicates! {assert_snapshot!(batches_to_string(&[batch.slice(0, 5)]), @r###" - +------------+----+-----+ - | c5 | c1 | c3 | - +------------+----+-----+ - | 2033001162 | c | 1 | - | 706441268 | d | -40 | - | 994303988 | b | 29 | - | 1171968280 | a | -85 | - | 1824882165 | b | -82 | - +------------+----+-----+ - "###);} + insta::allow_duplicates! 
{assert_snapshot!(batches_to_string(&[batch.slice(0, 5)]), @r" + +------------+----+-----+ + | c5 | c1 | c3 | + +------------+----+-----+ + | 2033001162 | c | 1 | + | 706441268 | d | -40 | + | 994303988 | b | 29 | + | 1171968280 | a | -85 | + | 1824882165 | b | -82 | + +------------+----+-----+ + ");} Ok(()) } @@ -221,6 +238,7 @@ mod tests { async fn csv_exec_with_limit( file_compression_type: FileCompressionType, ) -> Result<()> { + use datafusion_datasource::TableSchema; use futures::StreamExt; let cfg = SessionConfig::new().set_str("datafusion.catalog.has_header", "true"); @@ -230,27 +248,32 @@ mod tests { let path = format!("{}/csv", arrow_test_data()); let filename = "aggregate_test_100.csv"; let tmp_dir = TempDir::new()?; + let csv_format: Arc = Arc::new(CsvFormat::default()); let file_groups = partitioned_file_groups( path.as_str(), filename, 1, - Arc::new(CsvFormat::default()), + &csv_format, file_compression_type.to_owned(), tmp_dir.path(), )?; - let source = Arc::new(CsvSource::new(true, b',', b'"')); - let config = FileScanConfigBuilder::from(partitioned_csv_config( - file_schema, - file_groups, - source, - )) - .with_newlines_in_values(false) - .with_file_compression_type(file_compression_type.to_owned()) - .with_limit(Some(5)) - .build(); - assert_eq!(13, config.file_schema.fields().len()); + let options = CsvOptions { + has_header: Some(true), + delimiter: b',', + quote: b'"', + ..Default::default() + }; + let table_schema = TableSchema::from_file_schema(Arc::clone(&file_schema)); + let source = + Arc::new(CsvSource::new(table_schema.clone()).with_csv_options(options)); + let config = + FileScanConfigBuilder::from(partitioned_csv_config(file_groups, source)?) 
+ .with_file_compression_type(file_compression_type.to_owned()) + .with_limit(Some(5)) + .build(); + assert_eq!(13, config.file_schema().fields().len()); let csv = DataSourceExec::from_data_source(config); assert_eq!(13, csv.schema().fields().len()); @@ -259,17 +282,17 @@ mod tests { assert_eq!(13, batch.num_columns()); assert_eq!(5, batch.num_rows()); - insta::allow_duplicates! {assert_snapshot!(batches_to_string(&[batch]), @r###" - +----+----+-----+--------+------------+----------------------+-----+-------+------------+----------------------+-------------+---------------------+--------------------------------+ - | c1 | c2 | c3 | c4 | c5 | c6 | c7 | c8 | c9 | c10 | c11 | c12 | c13 | - +----+----+-----+--------+------------+----------------------+-----+-------+------------+----------------------+-------------+---------------------+--------------------------------+ - | c | 2 | 1 | 18109 | 2033001162 | -6513304855495910254 | 25 | 43062 | 1491205016 | 5863949479783605708 | 0.110830784 | 0.9294097332465232 | 6WfVFBVGJSQb7FhA7E0lBwdvjfZnSW | - | d | 5 | -40 | 22614 | 706441268 | -7542719935673075327 | 155 | 14337 | 3373581039 | 11720144131976083864 | 0.69632107 | 0.3114712539863804 | C2GT5KVyOPZpgKVl110TyZO0NcJ434 | - | b | 1 | 29 | -18218 | 994303988 | 5983957848665088916 | 204 | 9489 | 3275293996 | 14857091259186476033 | 0.53840446 | 0.17909035118828576 | AyYVExXK6AR2qUTxNZ7qRHQOVGMLcz | - | a | 1 | -85 | -15154 | 1171968280 | 1919439543497968449 | 77 | 52286 | 774637006 | 12101411955859039553 | 0.12285209 | 0.6864391962767343 | 0keZ5G8BffGwgF2RwQD59TFzMStxCB | - | b | 5 | -82 | 22080 | 1824882165 | 7373730676428214987 | 208 | 34331 | 3342719438 | 3330177516592499461 | 0.82634634 | 0.40975383525297016 | Ig1QcuKsjHXkproePdERo2w0mYzIqd | - +----+----+-----+--------+------------+----------------------+-----+-------+------------+----------------------+-------------+---------------------+--------------------------------+ - "###);} + insta::allow_duplicates! 
{assert_snapshot!(batches_to_string(&[batch]), @r" + +----+----+-----+--------+------------+----------------------+-----+-------+------------+----------------------+-------------+---------------------+--------------------------------+ + | c1 | c2 | c3 | c4 | c5 | c6 | c7 | c8 | c9 | c10 | c11 | c12 | c13 | + +----+----+-----+--------+------------+----------------------+-----+-------+------------+----------------------+-------------+---------------------+--------------------------------+ + | c | 2 | 1 | 18109 | 2033001162 | -6513304855495910254 | 25 | 43062 | 1491205016 | 5863949479783605708 | 0.110830784 | 0.9294097332465232 | 6WfVFBVGJSQb7FhA7E0lBwdvjfZnSW | + | d | 5 | -40 | 22614 | 706441268 | -7542719935673075327 | 155 | 14337 | 3373581039 | 11720144131976083864 | 0.69632107 | 0.3114712539863804 | C2GT5KVyOPZpgKVl110TyZO0NcJ434 | + | b | 1 | 29 | -18218 | 994303988 | 5983957848665088916 | 204 | 9489 | 3275293996 | 14857091259186476033 | 0.53840446 | 0.17909035118828576 | AyYVExXK6AR2qUTxNZ7qRHQOVGMLcz | + | a | 1 | -85 | -15154 | 1171968280 | 1919439543497968449 | 77 | 52286 | 774637006 | 12101411955859039553 | 0.12285209 | 0.6864391962767343 | 0keZ5G8BffGwgF2RwQD59TFzMStxCB | + | b | 5 | -82 | 22080 | 1824882165 | 7373730676428214987 | 208 | 34331 | 3342719438 | 3330177516592499461 | 0.82634634 | 0.40975383525297016 | Ig1QcuKsjHXkproePdERo2w0mYzIqd | + +----+----+-----+--------+------------+----------------------+-----+-------+------------+----------------------+-------------+---------------------+--------------------------------+ + ");} Ok(()) } @@ -287,33 +310,40 @@ mod tests { async fn csv_exec_with_missing_column( file_compression_type: FileCompressionType, ) -> Result<()> { + use datafusion_datasource::TableSchema; + let session_ctx = SessionContext::new(); let task_ctx = session_ctx.task_ctx(); let file_schema = aggr_test_schema_with_missing_col(); let path = format!("{}/csv", arrow_test_data()); let filename = "aggregate_test_100.csv"; let tmp_dir = 
TempDir::new()?; + let csv_format: Arc = Arc::new(CsvFormat::default()); let file_groups = partitioned_file_groups( path.as_str(), filename, 1, - Arc::new(CsvFormat::default()), + &csv_format, file_compression_type.to_owned(), tmp_dir.path(), )?; - let source = Arc::new(CsvSource::new(true, b',', b'"')); - let config = FileScanConfigBuilder::from(partitioned_csv_config( - file_schema, - file_groups, - source, - )) - .with_newlines_in_values(false) - .with_file_compression_type(file_compression_type.to_owned()) - .with_limit(Some(5)) - .build(); - assert_eq!(14, config.file_schema.fields().len()); + let options = CsvOptions { + has_header: Some(true), + delimiter: b',', + quote: b'"', + ..Default::default() + }; + let table_schema = TableSchema::from_file_schema(Arc::clone(&file_schema)); + let source = + Arc::new(CsvSource::new(table_schema.clone()).with_csv_options(options)); + let config = + FileScanConfigBuilder::from(partitioned_csv_config(file_groups, source)?) + .with_file_compression_type(file_compression_type.to_owned()) + .with_limit(Some(5)) + .build(); + assert_eq!(14, config.file_schema().fields().len()); let csv = DataSourceExec::from_data_source(config); assert_eq!(14, csv.schema().fields().len()); @@ -341,6 +371,7 @@ mod tests { file_compression_type: FileCompressionType, ) -> Result<()> { use datafusion_common::ScalarValue; + use datafusion_datasource::TableSchema; let session_ctx = SessionContext::new(); let task_ctx = session_ctx.task_ctx(); @@ -348,38 +379,45 @@ mod tests { let path = format!("{}/csv", arrow_test_data()); let filename = "aggregate_test_100.csv"; let tmp_dir = TempDir::new()?; + let csv_format: Arc = Arc::new(CsvFormat::default()); - let file_groups = partitioned_file_groups( + let mut file_groups = partitioned_file_groups( path.as_str(), filename, 1, - Arc::new(CsvFormat::default()), + &csv_format, file_compression_type.to_owned(), tmp_dir.path(), )?; - - let source = Arc::new(CsvSource::new(true, b',', b'"')); - let mut config = 
FileScanConfigBuilder::from(partitioned_csv_config( - file_schema, - file_groups, - source, - )) - .with_newlines_in_values(false) - .with_file_compression_type(file_compression_type.to_owned()) - .build(); - - // Add partition columns - config.table_partition_cols = vec![Field::new("date", DataType::Utf8, false)]; - config.file_groups[0][0].partition_values = vec![ScalarValue::from("2021-10-26")]; - - // We should be able to project on the partition column - // Which is supposed to be after the file fields - config.projection = Some(vec![0, config.file_schema.fields().len()]); + // Add partition columns / values + file_groups[0][0].partition_values = vec![ScalarValue::from("2021-10-26")]; + + let num_file_schema_fields = file_schema.fields().len(); + + let options = CsvOptions { + has_header: Some(true), + delimiter: b',', + quote: b'"', + ..Default::default() + }; + let table_schema = TableSchema::new( + Arc::clone(&file_schema), + vec![Arc::new(Field::new("date", DataType::Utf8, false))], + ); + let source = + Arc::new(CsvSource::new(table_schema.clone()).with_csv_options(options)); + let config = + FileScanConfigBuilder::from(partitioned_csv_config(file_groups, source)?) + .with_file_compression_type(file_compression_type.to_owned()) + // We should be able to project on the partition column + // Which is supposed to be after the file fields + .with_projection_indices(Some(vec![0, num_file_schema_fields]))? + .build(); // we don't have `/date=xx/` in the path but that is ok because // partitions are resolved during scan anyway - assert_eq!(13, config.file_schema.fields().len()); + assert_eq!(13, config.file_schema().fields().len()); let csv = DataSourceExec::from_data_source(config); assert_eq!(2, csv.schema().fields().len()); @@ -388,17 +426,17 @@ mod tests { assert_eq!(2, batch.num_columns()); assert_eq!(100, batch.num_rows()); - insta::allow_duplicates! 
{assert_snapshot!(batches_to_string(&[batch.slice(0, 5)]), @r###" - +----+------------+ - | c1 | date | - +----+------------+ - | c | 2021-10-26 | - | d | 2021-10-26 | - | b | 2021-10-26 | - | a | 2021-10-26 | - | b | 2021-10-26 | - +----+------------+ - "###);} + insta::allow_duplicates! {assert_snapshot!(batches_to_string(&[batch.slice(0, 5)]), @r" + +----+------------+ + | c1 | date | + +----+------------+ + | c | 2021-10-26 | + | d | 2021-10-26 | + | b | 2021-10-26 | + | a | 2021-10-26 | + | b | 2021-10-26 | + +----+------------+ + ");} let metrics = csv.metrics().expect("doesn't found metrics"); let time_elapsed_processing = get_value(&metrics, "time_elapsed_processing"); @@ -452,26 +490,31 @@ mod tests { let path = format!("{}/csv", arrow_test_data()); let filename = "aggregate_test_100.csv"; let tmp_dir = TempDir::new()?; + let csv_format: Arc = Arc::new(CsvFormat::default()); let file_groups = partitioned_file_groups( path.as_str(), filename, 1, - Arc::new(CsvFormat::default()), + &csv_format, file_compression_type.to_owned(), tmp_dir.path(), ) .unwrap(); - let source = Arc::new(CsvSource::new(true, b',', b'"')); - let config = FileScanConfigBuilder::from(partitioned_csv_config( - file_schema, - file_groups, - source, - )) - .with_newlines_in_values(false) - .with_file_compression_type(file_compression_type.to_owned()) - .build(); + let options = CsvOptions { + has_header: Some(true), + delimiter: b',', + quote: b'"', + ..Default::default() + }; + let table_schema = TableSchema::from_file_schema(Arc::clone(&file_schema)); + let source = + Arc::new(CsvSource::new(table_schema.clone()).with_csv_options(options)); + let config = + FileScanConfigBuilder::from(partitioned_csv_config(file_groups, source)?) 
+ .with_file_compression_type(file_compression_type.to_owned()) + .build(); let csv = DataSourceExec::from_data_source(config); let it = csv.execute(0, task_ctx).unwrap(); @@ -527,14 +570,14 @@ mod tests { let result = df.collect().await.unwrap(); - assert_snapshot!(batches_to_string(&result), @r###" - +---+---+ - | a | b | - +---+---+ - | 1 | 2 | - | 3 | 4 | - +---+---+ - "###); + assert_snapshot!(batches_to_string(&result), @r" + +---+---+ + | a | b | + +---+---+ + | 1 | 2 | + | 3 | 4 | + +---+---+ + "); } #[tokio::test] @@ -556,14 +599,14 @@ mod tests { let result = df.collect().await.unwrap(); - assert_snapshot!(batches_to_string(&result),@r###" - +---+---+ - | a | b | - +---+---+ - | 1 | 2 | - | 3 | 4 | - +---+---+ - "###); + assert_snapshot!(batches_to_string(&result),@r" + +---+---+ + | a | b | + +---+---+ + | 1 | 2 | + | 3 | 4 | + +---+---+ + "); let e = session_ctx .read_csv("memory:///", CsvReadOptions::new().terminator(Some(b'\n'))) @@ -572,7 +615,10 @@ mod tests { .collect() .await .unwrap_err(); - assert_eq!(e.strip_backtrace(), "Arrow error: Csv error: incorrect number of fields for line 1, expected 2 got more than 2") + assert_eq!( + e.strip_backtrace(), + "Arrow error: Csv error: incorrect number of fields for line 1, expected 2 got more than 2" + ) } #[tokio::test] @@ -593,22 +639,22 @@ mod tests { .await?; let df = ctx.sql(r#"select * from t1"#).await?.collect().await?; - assert_snapshot!(batches_to_string(&df),@r###" - +------+--------+ - | col1 | col2 | - +------+--------+ - | id0 | value0 | - | id1 | value1 | - | id2 | value2 | - | id3 | value3 | - +------+--------+ - "###); + assert_snapshot!(batches_to_string(&df),@r" + +------+--------+ + | col1 | col2 | + +------+--------+ + | id0 | value0 | + | id1 | value1 | + | id2 | value2 | + | id3 | value3 | + +------+--------+ + "); Ok(()) } #[tokio::test] - async fn test_create_external_table_with_terminator_with_newlines_in_values( - ) -> Result<()> { + async fn 
test_create_external_table_with_terminator_with_newlines_in_values() + -> Result<()> { let ctx = SessionContext::new(); ctx.sql(r#" CREATE EXTERNAL TABLE t1 ( @@ -658,7 +704,10 @@ mod tests { ) .await .expect_err("should fail because input file does not match inferred schema"); - assert_eq!(e.strip_backtrace(), "Arrow error: Parser error: Error while parsing value 'd' as type 'Int64' for column 0 at line 4. Row data: '[d,4]'"); + assert_eq!( + e.strip_backtrace(), + "Arrow error: Parser error: Error while parsing value 'd' as type 'Int64' for column 0 at line 4. Row data: '[d,4]'" + ); Ok(()) } diff --git a/datafusion/core/src/datasource/physical_plan/json.rs b/datafusion/core/src/datasource/physical_plan/json.rs index 0d45711c76fb0..b70791c7b2390 100644 --- a/datafusion/core/src/datasource/physical_plan/json.rs +++ b/datafusion/core/src/datasource/physical_plan/json.rs @@ -32,11 +32,11 @@ mod tests { use crate::dataframe::DataFrameWriteOptions; use crate::execution::SessionState; - use crate::prelude::{CsvReadOptions, NdJsonReadOptions, SessionContext}; + use crate::prelude::{CsvReadOptions, JsonReadOptions, SessionContext}; use crate::test::partitioned_file_groups; + use datafusion_common::Result; use datafusion_common::cast::{as_int32_array, as_int64_array, as_string_array}; use datafusion_common::test_util::batches_to_string; - use datafusion_common::Result; use datafusion_datasource::file_compression_type::FileCompressionType; use datafusion_datasource::file_format::FileFormat; use datafusion_datasource_json::JsonFormat; @@ -51,9 +51,9 @@ mod tests { use datafusion_datasource::file_scan_config::FileScanConfigBuilder; use datafusion_datasource::source::DataSourceExec; use insta::assert_snapshot; + use object_store::ObjectStore; use object_store::chunked::ChunkedStore; use object_store::local::LocalFileSystem; - use object_store::ObjectStore; use rstest::*; use tempfile::TempDir; use url::Url; @@ -69,11 +69,13 @@ mod tests { let store = 
state.runtime_env().object_store(&store_url).unwrap(); let filename = "1.json"; + let json_format: Arc = Arc::new(JsonFormat::default()); + let file_groups = partitioned_file_groups( TEST_DATA_BASE, filename, 1, - Arc::new(JsonFormat::default()), + &json_format, file_compression_type.to_owned(), work_dir, ) @@ -104,11 +106,13 @@ mod tests { ctx.register_object_store(&url, store.clone()); let filename = "1.json"; let tmp_dir = TempDir::new()?; + let json_format: Arc = Arc::new(JsonFormat::default()); + let file_groups = partitioned_file_groups( TEST_DATA_BASE, filename, 1, - Arc::new(JsonFormat::default()), + &json_format, file_compression_type.to_owned(), tmp_dir.path(), ) @@ -132,22 +136,22 @@ mod tests { .get_ext_with_compression(&file_compression_type) .unwrap(); - let read_options = NdJsonReadOptions::default() + let read_options = JsonReadOptions::default() .file_extension(ext.as_str()) .file_compression_type(file_compression_type.to_owned()); let frame = ctx.read_json(path, read_options).await.unwrap(); let results = frame.collect().await.unwrap(); - insta::allow_duplicates! {assert_snapshot!(batches_to_string(&results), @r###" - +-----+------------------+---------------+------+ - | a | b | c | d | - +-----+------------------+---------------+------+ - | 1 | [2.0, 1.3, -6.1] | [false, true] | 4 | - | -10 | [2.0, 1.3, -6.1] | [true, true] | 4 | - | 2 | [2.0, , -6.1] | [false, ] | text | - | | | | | - +-----+------------------+---------------+------+ - "###);} + insta::allow_duplicates! 
{assert_snapshot!(batches_to_string(&results), @r" + +-----+------------------+---------------+------+ + | a | b | c | d | + +-----+------------------+---------------+------+ + | 1 | [2.0, 1.3, -6.1] | [false, true] | 4 | + | -10 | [2.0, 1.3, -6.1] | [true, true] | 4 | + | 2 | [2.0, , -6.1] | [false, ] | text | + | | | | | + +-----+------------------+---------------+------+ + ");} Ok(()) } @@ -176,8 +180,8 @@ mod tests { let (object_store_url, file_groups, file_schema) = prepare_store(&state, file_compression_type.to_owned(), tmp_dir.path()).await; - let source = Arc::new(JsonSource::new()); - let conf = FileScanConfigBuilder::new(object_store_url, file_schema, source) + let source = Arc::new(JsonSource::new(Arc::clone(&file_schema))); + let conf = FileScanConfigBuilder::new(object_store_url, source) .with_file_groups(file_groups) .with_limit(Some(3)) .with_file_compression_type(file_compression_type.to_owned()) @@ -251,8 +255,8 @@ mod tests { let file_schema = Arc::new(builder.finish()); let missing_field_idx = file_schema.fields.len() - 1; - let source = Arc::new(JsonSource::new()); - let conf = FileScanConfigBuilder::new(object_store_url, file_schema, source) + let source = Arc::new(JsonSource::new(Arc::clone(&file_schema))); + let conf = FileScanConfigBuilder::new(object_store_url, source) .with_file_groups(file_groups) .with_limit(Some(3)) .with_file_compression_type(file_compression_type.to_owned()) @@ -294,10 +298,11 @@ mod tests { let (object_store_url, file_groups, file_schema) = prepare_store(&state, file_compression_type.to_owned(), tmp_dir.path()).await; - let source = Arc::new(JsonSource::new()); - let conf = FileScanConfigBuilder::new(object_store_url, file_schema, source) + let source = Arc::new(JsonSource::new(Arc::clone(&file_schema))); + let conf = FileScanConfigBuilder::new(object_store_url, source) .with_file_groups(file_groups) - .with_projection(Some(vec![0, 2])) + .with_projection_indices(Some(vec![0, 2])) + .unwrap() 
.with_file_compression_type(file_compression_type.to_owned()) .build(); let exec = DataSourceExec::from_data_source(conf); @@ -342,10 +347,10 @@ mod tests { let (object_store_url, file_groups, file_schema) = prepare_store(&state, file_compression_type.to_owned(), tmp_dir.path()).await; - let source = Arc::new(JsonSource::new()); - let conf = FileScanConfigBuilder::new(object_store_url, file_schema, source) + let source = Arc::new(JsonSource::new(Arc::clone(&file_schema))); + let conf = FileScanConfigBuilder::new(object_store_url, source) .with_file_groups(file_groups) - .with_projection(Some(vec![3, 0, 2])) + .with_projection_indices(Some(vec![3, 0, 2]))? .with_file_compression_type(file_compression_type.to_owned()) .build(); let exec = DataSourceExec::from_data_source(conf); @@ -384,7 +389,7 @@ mod tests { let path = format!("{TEST_DATA_BASE}/1.json"); // register json file with the execution context - ctx.register_json("test", path.as_str(), NdJsonReadOptions::default()) + ctx.register_json("test", path.as_str(), JsonReadOptions::default()) .await?; // register a local file system object store for /tmp directory @@ -426,7 +431,7 @@ mod tests { } // register each partition as well as the top level dir - let json_read_option = NdJsonReadOptions::default(); + let json_read_option = JsonReadOptions::default(); ctx.register_json( "part0", &format!("{out_dir}/{part_0_name}"), @@ -494,7 +499,10 @@ mod tests { .write_json(out_dir_url, DataFrameWriteOptions::new(), None) .await .expect_err("should fail because input file does not match inferred schema"); - assert_eq!(e.strip_backtrace(), "Arrow error: Parser error: Error while parsing value 'd' as type 'Int64' for column 0 at line 4. Row data: '[d,4]'"); + assert_eq!( + e.strip_backtrace(), + "Arrow error: Parser error: Error while parsing value 'd' as type 'Int64' for column 0 at line 4. 
Row data: '[d,4]'" + ); Ok(()) } @@ -503,7 +511,7 @@ mod tests { async fn read_test_data(schema_infer_max_records: usize) -> Result { let ctx = SessionContext::new(); - let options = NdJsonReadOptions { + let options = JsonReadOptions { schema_infer_max_records, ..Default::default() }; @@ -579,7 +587,7 @@ mod tests { .get_ext_with_compression(&file_compression_type) .unwrap(); - let read_option = NdJsonReadOptions::default() + let read_option = JsonReadOptions::default() .file_compression_type(file_compression_type) .file_extension(ext.as_str()); diff --git a/datafusion/core/src/datasource/physical_plan/mod.rs b/datafusion/core/src/datasource/physical_plan/mod.rs index 3f71b253d9695..8e4855afa66bb 100644 --- a/datafusion/core/src/datasource/physical_plan/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/mod.rs @@ -17,7 +17,7 @@ //! Execution plans that read file formats -mod arrow_file; +pub mod arrow; pub mod csv; pub mod json; @@ -35,156 +35,19 @@ pub use datafusion_datasource_parquet::source::ParquetSource; #[cfg(feature = "parquet")] pub use datafusion_datasource_parquet::{ParquetFileMetrics, ParquetFileReaderFactory}; -pub use arrow_file::ArrowSource; - pub use json::{JsonOpener, JsonSource}; +pub use arrow::{ArrowOpener, ArrowSource}; pub use csv::{CsvOpener, CsvSource}; pub use datafusion_datasource::file::FileSource; pub use datafusion_datasource::file_groups::FileGroup; pub use datafusion_datasource::file_groups::FileGroupPartitioner; -pub use datafusion_datasource::file_meta::FileMeta; pub use datafusion_datasource::file_scan_config::{ - wrap_partition_type_in_dict, wrap_partition_value_in_dict, FileScanConfig, - FileScanConfigBuilder, + FileScanConfig, FileScanConfigBuilder, wrap_partition_type_in_dict, + wrap_partition_value_in_dict, }; pub use datafusion_datasource::file_sink_config::*; pub use datafusion_datasource::file_stream::{ - FileOpenFuture, FileOpener, FileStream, OnError, + FileOpenFuture, FileOpener, FileStream, FileStreamBuilder, 
OnError, }; - -#[cfg(test)] -mod tests { - use std::sync::Arc; - - use arrow::array::{ - cast::AsArray, - types::{Float32Type, Float64Type, UInt32Type}, - BinaryArray, BooleanArray, Float32Array, Int32Array, Int64Array, RecordBatch, - StringArray, UInt64Array, - }; - use arrow::datatypes::{DataType, Field, Schema}; - use arrow_schema::SchemaRef; - - use crate::datasource::schema_adapter::{ - DefaultSchemaAdapterFactory, SchemaAdapterFactory, - }; - - #[test] - fn schema_mapping_map_batch() { - let table_schema = Arc::new(Schema::new(vec![ - Field::new("c1", DataType::Utf8, true), - Field::new("c2", DataType::UInt32, true), - Field::new("c3", DataType::Float64, true), - ])); - - let adapter = DefaultSchemaAdapterFactory - .create(table_schema.clone(), table_schema.clone()); - - let file_schema = Schema::new(vec![ - Field::new("c1", DataType::Utf8, true), - Field::new("c2", DataType::UInt64, true), - Field::new("c3", DataType::Float32, true), - ]); - - let (mapping, _) = adapter.map_schema(&file_schema).expect("map schema failed"); - - let c1 = StringArray::from(vec!["hello", "world"]); - let c2 = UInt64Array::from(vec![9_u64, 5_u64]); - let c3 = Float32Array::from(vec![2.0_f32, 7.0_f32]); - let batch = RecordBatch::try_new( - Arc::new(file_schema), - vec![Arc::new(c1), Arc::new(c2), Arc::new(c3)], - ) - .unwrap(); - - let mapped_batch = mapping.map_batch(batch).unwrap(); - - assert_eq!(mapped_batch.schema(), table_schema); - assert_eq!(mapped_batch.num_columns(), 3); - assert_eq!(mapped_batch.num_rows(), 2); - - let c1 = mapped_batch.column(0).as_string::(); - let c2 = mapped_batch.column(1).as_primitive::(); - let c3 = mapped_batch.column(2).as_primitive::(); - - assert_eq!(c1.value(0), "hello"); - assert_eq!(c1.value(1), "world"); - assert_eq!(c2.value(0), 9_u32); - assert_eq!(c2.value(1), 5_u32); - assert_eq!(c3.value(0), 2.0_f64); - assert_eq!(c3.value(1), 7.0_f64); - } - - #[test] - fn schema_adapter_map_schema_with_projection() { - let table_schema = 
Arc::new(Schema::new(vec![ - Field::new("c0", DataType::Utf8, true), - Field::new("c1", DataType::Utf8, true), - Field::new("c2", DataType::Float64, true), - Field::new("c3", DataType::Int32, true), - Field::new("c4", DataType::Float32, true), - ])); - - let file_schema = Schema::new(vec![ - Field::new("id", DataType::Int32, true), - Field::new("c1", DataType::Boolean, true), - Field::new("c2", DataType::Float32, true), - Field::new("c3", DataType::Binary, true), - Field::new("c4", DataType::Int64, true), - ]); - - let indices = vec![1, 2, 4]; - let schema = SchemaRef::from(table_schema.project(&indices).unwrap()); - let adapter = DefaultSchemaAdapterFactory.create(schema, table_schema.clone()); - let (mapping, projection) = adapter.map_schema(&file_schema).unwrap(); - - let id = Int32Array::from(vec![Some(1), Some(2), Some(3)]); - let c1 = BooleanArray::from(vec![Some(true), Some(false), Some(true)]); - let c2 = Float32Array::from(vec![Some(2.0_f32), Some(7.0_f32), Some(3.0_f32)]); - let c3 = BinaryArray::from_opt_vec(vec![ - Some(b"hallo"), - Some(b"danke"), - Some(b"super"), - ]); - let c4 = Int64Array::from(vec![1, 2, 3]); - let batch = RecordBatch::try_new( - Arc::new(file_schema), - vec![ - Arc::new(id), - Arc::new(c1), - Arc::new(c2), - Arc::new(c3), - Arc::new(c4), - ], - ) - .unwrap(); - let rows_num = batch.num_rows(); - let projected = batch.project(&projection).unwrap(); - let mapped_batch = mapping.map_batch(projected).unwrap(); - - assert_eq!( - mapped_batch.schema(), - Arc::new(table_schema.project(&indices).unwrap()) - ); - assert_eq!(mapped_batch.num_columns(), indices.len()); - assert_eq!(mapped_batch.num_rows(), rows_num); - - let c1 = mapped_batch.column(0).as_string::(); - let c2 = mapped_batch.column(1).as_primitive::(); - let c4 = mapped_batch.column(2).as_primitive::(); - - assert_eq!(c1.value(0), "true"); - assert_eq!(c1.value(1), "false"); - assert_eq!(c1.value(2), "true"); - - assert_eq!(c2.value(0), 2.0_f64); - assert_eq!(c2.value(1), 
7.0_f64); - assert_eq!(c2.value(2), 3.0_f64); - - assert_eq!(c4.value(0), 1.0_f32); - assert_eq!(c4.value(1), 2.0_f32); - assert_eq!(c4.value(2), 3.0_f32); - } -} diff --git a/datafusion/core/src/datasource/physical_plan/parquet.rs b/datafusion/core/src/datasource/physical_plan/parquet.rs index 8dee79ad61b23..4c6d915d5bcaa 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet.rs @@ -38,34 +38,35 @@ mod tests { use crate::prelude::{ParquetReadOptions, SessionConfig, SessionContext}; use crate::test::object_store::local_unpartitioned_file; use arrow::array::{ - ArrayRef, AsArray, Date64Array, Int32Array, Int64Array, Int8Array, StringArray, - StringViewArray, StructArray, + ArrayRef, AsArray, Date64Array, DictionaryArray, Int8Array, Int32Array, + Int64Array, StringArray, StringViewArray, StructArray, TimestampNanosecondArray, }; - use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaBuilder}; + use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaBuilder, UInt16Type}; use arrow::record_batch::RecordBatch; use arrow::util::pretty::pretty_format_batches; use arrow_schema::{SchemaRef, TimeUnit}; use bytes::{BufMut, BytesMut}; use datafusion_common::config::TableParquetOptions; use datafusion_common::test_util::{batches_to_sort_string, batches_to_string}; - use datafusion_common::{assert_contains, Result, ScalarValue}; + use datafusion_common::{Result, ScalarValue, assert_contains}; use datafusion_datasource::file_format::FileFormat; - use datafusion_datasource::file_meta::FileMeta; use datafusion_datasource::file_scan_config::FileScanConfigBuilder; use datafusion_datasource::source::DataSourceExec; use datafusion_datasource::file::FileSource; - use datafusion_datasource::{FileRange, PartitionedFile}; + use datafusion_datasource::{PartitionedFile, TableSchema}; use datafusion_datasource_parquet::source::ParquetSource; use datafusion_datasource_parquet::{ 
DefaultParquetFileReaderFactory, ParquetFileReaderFactory, ParquetFormat, }; use datafusion_execution::object_store::ObjectStoreUrl; - use datafusion_expr::{col, lit, when, Expr}; + use datafusion_expr::{Expr, col, lit, when}; use datafusion_physical_expr::planner::logical2physical; use datafusion_physical_plan::analyze::AnalyzeExec; use datafusion_physical_plan::collect; - use datafusion_physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet}; + use datafusion_physical_plan::metrics::{ + ExecutionPlanMetricsSet, MetricType, MetricValue, MetricsSet, + }; use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties}; use chrono::{TimeZone, Utc}; @@ -160,7 +161,7 @@ mod tests { .as_ref() .map(|p| logical2physical(p, &table_schema)); - let mut source = ParquetSource::default(); + let mut source = ParquetSource::new(table_schema); if let Some(predicate) = predicate { source = source.with_predicate(predicate); } @@ -185,23 +186,20 @@ mod tests { source = source.with_bloom_filter_on_read(false); } - source.with_schema(Arc::clone(&table_schema)) + Arc::new(source) } fn build_parquet_exec( &self, - file_schema: SchemaRef, file_group: FileGroup, source: Arc, ) -> Arc { - let base_config = FileScanConfigBuilder::new( - ObjectStoreUrl::local_filesystem(), - file_schema, - source, - ) - .with_file_group(file_group) - .with_projection(self.projection.clone()) - .build(); + let base_config = + FileScanConfigBuilder::new(ObjectStoreUrl::local_filesystem(), source) + .with_file_group(file_group) + .with_projection_indices(self.projection.clone()) + .unwrap() + .build(); DataSourceExec::from_data_source(base_config) } @@ -230,18 +228,15 @@ mod tests { // build a ParquetExec to return the results let parquet_source = self.build_file_source(Arc::clone(table_schema)); - let parquet_exec = self.build_parquet_exec( - Arc::clone(table_schema), - file_group.clone(), - Arc::clone(&parquet_source), - ); + let parquet_exec = + self.build_parquet_exec(file_group.clone(), 
Arc::clone(&parquet_source)); let analyze_exec = Arc::new(AnalyzeExec::new( false, false, + vec![MetricType::SUMMARY, MetricType::DEV], // use a new ParquetSource to avoid sharing execution metrics self.build_parquet_exec( - Arc::clone(table_schema), file_group.clone(), self.build_file_source(Arc::clone(table_schema)), ), @@ -311,7 +306,7 @@ mod tests { let batch = RecordBatch::try_new(file_schema.clone(), vec![c1]).unwrap(); - // Since c2 is missing from the file and we didn't supply a custom `SchemaAdapterFactory`, + // Since c2 is missing from the file and we didn't supply a custom `PhysicalExprAdapterFactory`, // the default behavior is to fill in missing columns with nulls. // Thus this predicate will come back as false. let filter = col("c2").eq(lit(1_i32)); @@ -332,7 +327,7 @@ mod tests { let metric = get_value(&metrics, "pushdown_rows_pruned"); assert_eq!(metric, 3, "Expected all rows to be pruned"); - // If we excplicitly allow nulls the rest of the predicate should work + // If we explicitly allow nulls the rest of the predicate should work let filter = col("c2").is_null().and(col("c1").eq(lit(1_i32))); let rt = RoundTrip::new() .with_table_schema(table_schema.clone()) @@ -342,13 +337,13 @@ mod tests { .await; let batches = rt.batches.unwrap(); - insta::assert_snapshot!(batches_to_sort_string(&batches),@r###" + insta::assert_snapshot!(batches_to_sort_string(&batches),@r" +----+----+ | c1 | c2 | +----+----+ | 1 | | +----+----+ - "###); + "); let metrics = rt.parquet_exec.metrics().unwrap(); let metric = get_value(&metrics, "pushdown_rows_pruned"); @@ -369,7 +364,7 @@ mod tests { let batch = RecordBatch::try_new(file_schema.clone(), vec![c1]).unwrap(); - // Since c2 is missing from the file and we didn't supply a custom `SchemaAdapterFactory`, + // Since c2 is missing from the file and we didn't supply a custom `PhysicalExprAdapterFactory`, // the default behavior is to fill in missing columns with nulls. // Thus this predicate will come back as false. 
let filter = col("c2").eq(lit("abc")); @@ -390,7 +385,7 @@ mod tests { let metric = get_value(&metrics, "pushdown_rows_pruned"); assert_eq!(metric, 3, "Expected all rows to be pruned"); - // If we excplicitly allow nulls the rest of the predicate should work + // If we explicitly allow nulls the rest of the predicate should work let filter = col("c2").is_null().and(col("c1").eq(lit(1_i32))); let rt = RoundTrip::new() .with_table_schema(table_schema.clone()) @@ -400,13 +395,13 @@ mod tests { .await; let batches = rt.batches.unwrap(); - insta::assert_snapshot!(batches_to_sort_string(&batches),@r###" + insta::assert_snapshot!(batches_to_sort_string(&batches),@r" +----+----+ | c1 | c2 | +----+----+ | 1 | | +----+----+ - "###); + "); let metrics = rt.parquet_exec.metrics().unwrap(); let metric = get_value(&metrics, "pushdown_rows_pruned"); @@ -431,7 +426,7 @@ mod tests { let batch = RecordBatch::try_new(file_schema.clone(), vec![c1, c3]).unwrap(); - // Since c2 is missing from the file and we didn't supply a custom `SchemaAdapterFactory`, + // Since c2 is missing from the file and we didn't supply a custom `PhysicalExprAdapterFactory`, // the default behavior is to fill in missing columns with nulls. // Thus this predicate will come back as false. 
let filter = col("c2").eq(lit("abc")); @@ -452,7 +447,7 @@ mod tests { let metric = get_value(&metrics, "pushdown_rows_pruned"); assert_eq!(metric, 3, "Expected all rows to be pruned"); - // If we excplicitly allow nulls the rest of the predicate should work + // If we explicitly allow nulls the rest of the predicate should work let filter = col("c2").is_null().and(col("c1").eq(lit(1_i32))); let rt = RoundTrip::new() .with_table_schema(table_schema.clone()) @@ -462,13 +457,13 @@ mod tests { .await; let batches = rt.batches.unwrap(); - insta::assert_snapshot!(batches_to_sort_string(&batches),@r###" + insta::assert_snapshot!(batches_to_sort_string(&batches),@r" +----+----+----+ | c1 | c2 | c3 | +----+----+----+ | 1 | | 7 | +----+----+----+ - "###); + "); let metrics = rt.parquet_exec.metrics().unwrap(); let metric = get_value(&metrics, "pushdown_rows_pruned"); @@ -493,7 +488,7 @@ mod tests { let batch = RecordBatch::try_new(file_schema.clone(), vec![c3.clone(), c3]).unwrap(); - // Since c2 is missing from the file and we didn't supply a custom `SchemaAdapterFactory`, + // Since c2 is missing from the file and we didn't supply a custom `PhysicalExprAdapterFactory`, // the default behavior is to fill in missing columns with nulls. // Thus this predicate will come back as false. 
let filter = col("c2").eq(lit("abc")); @@ -514,7 +509,7 @@ mod tests { let metric = get_value(&metrics, "pushdown_rows_pruned"); assert_eq!(metric, 3, "Expected all rows to be pruned"); - // If we excplicitly allow nulls the rest of the predicate should work + // If we explicitly allow nulls the rest of the predicate should work let filter = col("c2").is_null().and(col("c3").eq(lit(7_i32))); let rt = RoundTrip::new() .with_table_schema(table_schema.clone()) @@ -524,13 +519,13 @@ mod tests { .await; let batches = rt.batches.unwrap(); - insta::assert_snapshot!(batches_to_sort_string(&batches),@r###" + insta::assert_snapshot!(batches_to_sort_string(&batches),@r" +----+----+----+ | c1 | c2 | c3 | +----+----+----+ | | | 7 | +----+----+----+ - "###); + "); let metrics = rt.parquet_exec.metrics().unwrap(); let metric = get_value(&metrics, "pushdown_rows_pruned"); @@ -573,13 +568,13 @@ mod tests { let batches = rt.batches.unwrap(); - insta::assert_snapshot!(batches_to_sort_string(&batches),@r###" + insta::assert_snapshot!(batches_to_sort_string(&batches),@r" +----+----+----+ | c1 | c2 | c3 | +----+----+----+ | 1 | | 10 | +----+----+----+ - "###); + "); let metrics = rt.parquet_exec.metrics().unwrap(); let metric = get_value(&metrics, "pushdown_rows_pruned"); @@ -603,7 +598,7 @@ mod tests { let batches = rt.batches.unwrap(); - insta::assert_snapshot!(batches_to_sort_string(&batches),@r###" + insta::assert_snapshot!(batches_to_sort_string(&batches),@r" +----+----+----+ | c1 | c2 | c3 | +----+----+----+ @@ -611,7 +606,7 @@ mod tests { | 4 | | 40 | | 5 | | 50 | +----+----+----+ - "###); + "); let metrics = rt.parquet_exec.metrics().unwrap(); let metric = get_value(&metrics, "pushdown_rows_pruned"); @@ -640,7 +635,7 @@ mod tests { .await .unwrap(); - insta::assert_snapshot!(batches_to_sort_string(&read), @r###" + insta::assert_snapshot!(batches_to_sort_string(&read), @r" +-----+----+----+ | c1 | c2 | c3 | +-----+----+----+ @@ -654,7 +649,7 @@ mod tests { | bar | | | | bar | | | 
+-----+----+----+ - "###); + "); } #[tokio::test] @@ -755,18 +750,18 @@ mod tests { .await .unwrap(); - insta::assert_snapshot!(batches_to_sort_string(&read),@r###" - +-----+----+----+ - | c1 | c3 | c2 | - +-----+----+----+ - | | | | - | | 10 | 1 | - | | 20 | | - | | 20 | 2 | - | Foo | 10 | | - | bar | | | - +-----+----+----+ - "###); + insta::assert_snapshot!(batches_to_sort_string(&read),@r" + +-----+----+----+ + | c1 | c3 | c2 | + +-----+----+----+ + | | | | + | | 10 | 1 | + | | 20 | | + | | 20 | 2 | + | Foo | 10 | | + | bar | | | + +-----+----+----+ + "); } #[tokio::test] @@ -787,14 +782,14 @@ mod tests { .round_trip(vec![batch1, batch2]) .await; - insta::assert_snapshot!(batches_to_sort_string(&rt.batches.unwrap()), @r###" + insta::assert_snapshot!(batches_to_sort_string(&rt.batches.unwrap()), @r" +----+----+----+ | c1 | c3 | c2 | +----+----+----+ | | 10 | 1 | | | 20 | 2 | +----+----+----+ - "###); + "); let metrics = rt.parquet_exec.metrics().unwrap(); // Note there are were 6 rows in total (across three batches) assert_eq!(get_value(&metrics, "pushdown_rows_pruned"), 4); @@ -830,7 +825,7 @@ mod tests { .await .unwrap(); - insta::assert_snapshot!(batches_to_sort_string(&read), @r###" + insta::assert_snapshot!(batches_to_sort_string(&read), @r" +-----+-----+ | c1 | c4 | +-----+-----+ @@ -841,7 +836,7 @@ mod tests { | bar | | | bar | | +-----+-----+ - "###); + "); } #[tokio::test] @@ -960,6 +955,73 @@ mod tests { assert_eq!(read, 2, "Expected 2 rows to match the predicate"); } + #[tokio::test] + async fn evolved_schema_column_type_filter_timestamp_units() { + // The table and filter have a common data type + // The table schema is in milliseconds, but the file schema is in nanoseconds + let c1: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![ + Some(1_000_000_000), // 1970-01-01T00:00:01Z + Some(2_000_000_000), // 1970-01-01T00:00:02Z + Some(3_000_000_000), // 1970-01-01T00:00:03Z + Some(4_000_000_000), // 1970-01-01T00:00:04Z + ])); + let batch = 
create_batch(vec![("c1", c1.clone())]); + let table_schema = Arc::new(Schema::new(vec![Field::new( + "c1", + DataType::Timestamp(TimeUnit::Millisecond, Some("UTC".into())), + false, + )])); + // One row should match, 2 pruned via page index, 1 pruned via filter pushdown + let filter = col("c1").eq(lit(ScalarValue::TimestampMillisecond( + Some(1_000), + Some("UTC".into()), + ))); + let rt = RoundTrip::new() + .with_predicate(filter) + .with_pushdown_predicate() + .with_page_index_predicate() // produces pages with 2 rows each (2 pages total for our data) + .with_table_schema(table_schema.clone()) + .round_trip(vec![batch.clone()]) + .await; + // There should be no predicate evaluation errors and we keep 1 row + let metrics = rt.parquet_exec.metrics().unwrap(); + assert_eq!(get_value(&metrics, "predicate_evaluation_errors"), 0); + let read = rt + .batches + .unwrap() + .iter() + .map(|b| b.num_rows()) + .sum::(); + assert_eq!(read, 1, "Expected 1 rows to match the predicate"); + assert_eq!(get_value(&metrics, "row_groups_pruned_statistics"), 0); + assert_eq!(get_value(&metrics, "page_index_rows_pruned"), 2); + assert_eq!(get_value(&metrics, "page_index_pages_pruned"), 1); + assert_eq!(get_value(&metrics, "pushdown_rows_pruned"), 1); + // If we filter with a value that is completely out of the range of the data + // we prune at the row group level. 
+ let filter = col("c1").eq(lit(ScalarValue::TimestampMillisecond( + Some(5_000), + Some("UTC".into()), + ))); + let rt = RoundTrip::new() + .with_predicate(filter) + .with_pushdown_predicate() + .with_table_schema(table_schema) + .round_trip(vec![batch]) + .await; + // There should be no predicate evaluation errors and we keep 0 rows + let metrics = rt.parquet_exec.metrics().unwrap(); + assert_eq!(get_value(&metrics, "predicate_evaluation_errors"), 0); + let read = rt + .batches + .unwrap() + .iter() + .map(|b| b.num_rows()) + .sum::(); + assert_eq!(read, 0, "Expected 0 rows to match the predicate"); + assert_eq!(get_value(&metrics, "row_groups_pruned_statistics"), 1); + } + #[tokio::test] async fn evolved_schema_disjoint_schema_filter() { let c1: ArrayRef = @@ -988,18 +1050,18 @@ mod tests { // In a real query where this predicate was pushed down from a filter stage instead of created directly in the `DataSourceExec`, // the filter stage would be preserved as a separate execution plan stage so the actual query results would be as expected. 
- insta::assert_snapshot!(batches_to_sort_string(&read),@r###" - +-----+----+ - | c1 | c2 | - +-----+----+ - | | | - | | | - | | 1 | - | | 2 | - | Foo | | - | bar | | - +-----+----+ - "###); + insta::assert_snapshot!(batches_to_sort_string(&read),@r" + +-----+----+ + | c1 | c2 | + +-----+----+ + | | | + | | | + | | 1 | + | | 2 | + | Foo | | + | bar | | + +-----+----+ + "); } #[tokio::test] @@ -1024,13 +1086,13 @@ mod tests { .round_trip(vec![batch1, batch2]) .await; - insta::assert_snapshot!(batches_to_sort_string(&rt.batches.unwrap()), @r###" + insta::assert_snapshot!(batches_to_sort_string(&rt.batches.unwrap()), @r" +----+----+ | c1 | c2 | +----+----+ | | 1 | +----+----+ - "###); + "); let metrics = rt.parquet_exec.metrics().unwrap(); // Note there are were 6 rows in total (across three batches) assert_eq!(get_value(&metrics, "pushdown_rows_pruned"), 5); @@ -1084,7 +1146,7 @@ mod tests { .round_trip(vec![batch1, batch2, batch3, batch4]) .await; - insta::assert_snapshot!(batches_to_sort_string(&rt.batches.unwrap()), @r###" + insta::assert_snapshot!(batches_to_sort_string(&rt.batches.unwrap()), @r" +------+----+ | c1 | c2 | +------+----+ @@ -1101,14 +1163,22 @@ mod tests { | Foo2 | | | Foo3 | | +------+----+ - "###); + "); let metrics = rt.parquet_exec.metrics().unwrap(); // There are 4 rows pruned in each of batch2, batch3, and // batch4 for a total of 12. 
batch1 had no pruning as c2 was // filled in as null - assert_eq!(get_value(&metrics, "page_index_rows_pruned"), 12); - assert_eq!(get_value(&metrics, "page_index_rows_matched"), 6); + let (page_index_rows_pruned, page_index_rows_matched) = + get_pruning_metric(&metrics, "page_index_rows_pruned"); + assert_eq!(page_index_rows_pruned, 12); + assert_eq!(page_index_rows_matched, 6); + + // each page has 2 rows, so the num of pages is 1/2 the number of rows + let (page_index_pages_pruned, page_index_pages_matched) = + get_pruning_metric(&metrics, "page_index_pages_pruned"); + assert_eq!(page_index_pages_pruned, 6); + assert_eq!(page_index_pages_matched, 3); } #[tokio::test] @@ -1131,14 +1201,14 @@ mod tests { .await .unwrap(); - insta::assert_snapshot!(batches_to_sort_string(&read),@r###" - +-----+----+ - | c1 | c2 | - +-----+----+ - | Foo | 1 | - | bar | | - +-----+----+ - "###); + insta::assert_snapshot!(batches_to_sort_string(&read),@r" + +-----+----+ + | c1 | c2 | + +-----+----+ + | Foo | 1 | + | bar | | + +-----+----+ + "); } #[tokio::test] @@ -1161,15 +1231,15 @@ mod tests { .await .unwrap(); - insta::assert_snapshot!(batches_to_sort_string(&read),@r###" - +-----+----+ - | c1 | c2 | - +-----+----+ - | | 2 | - | Foo | 1 | - | bar | | - +-----+----+ - "###); + insta::assert_snapshot!(batches_to_sort_string(&read),@r" + +-----+----+ + | c1 | c2 | + +-----+----+ + | | 2 | + | Foo | 1 | + | bar | | + +-----+----+ + "); } #[tokio::test] @@ -1194,7 +1264,7 @@ mod tests { ("c3", c3.clone()), ]); - // batch2: c3(int8), c2(int64), c1(string), c4(string) + // batch2: c3(date64), c2(int64), c1(string) let batch2 = create_batch(vec![("c3", c4), ("c2", c2), ("c1", c1)]); let table_schema = Schema::new(vec![ @@ -1208,8 +1278,10 @@ mod tests { .with_table_schema(Arc::new(table_schema)) .round_trip_to_batches(vec![batch1, batch2]) .await; - assert_contains!(read.unwrap_err().to_string(), - "Cannot cast file schema field c3 of type Date64 to table schema field of type Int8"); + 
assert_contains!( + read.unwrap_err().to_string(), + "Cannot cast column 'c3' from 'Date64' (physical data type) to 'Int8' (logical data type)" + ); } #[tokio::test] @@ -1259,7 +1331,7 @@ mod tests { async fn parquet_exec_with_int96_from_spark() -> Result<()> { // arrow-rs relies on the chrono library to convert between timestamps and strings, so // instead compare as Int64. The underlying type should be a PrimitiveArray of Int64 - // anyway, so this should be a zero-copy non-modifying cast at the SchemaAdapter. + // anyway, so this should be a zero-copy non-modifying cast. let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, true)])); let testdata = datafusion_common::test_util::parquet_test_data(); @@ -1462,14 +1534,7 @@ mod tests { #[tokio::test] async fn parquet_exec_with_range() -> Result<()> { fn file_range(meta: &ObjectMeta, start: i64, end: i64) -> PartitionedFile { - PartitionedFile { - object_meta: meta.clone(), - partition_values: vec![], - range: Some(FileRange { start, end }), - statistics: None, - extensions: None, - metadata_size_hint: None, - } + PartitionedFile::new_from_meta(meta.clone()).with_range(start, end) } async fn assert_parquet_read( @@ -1480,8 +1545,7 @@ mod tests { ) -> Result<()> { let config = FileScanConfigBuilder::new( ObjectStoreUrl::local_filesystem(), - file_schema, - Arc::new(ParquetSource::default()), + Arc::new(ParquetSource::new(file_schema)), ) .with_file_groups(file_groups) .build(); @@ -1552,21 +1616,15 @@ mod tests { .await .unwrap(); - let partitioned_file = PartitionedFile { - object_meta: meta, - partition_values: vec![ + let partitioned_file = PartitionedFile::new_from_meta(meta) + .with_partition_values(vec![ ScalarValue::from("2021"), ScalarValue::UInt8(Some(10)), ScalarValue::Dictionary( Box::new(DataType::UInt16), Box::new(ScalarValue::from("26")), ), - ], - range: None, - statistics: None, - extensions: None, - metadata_size_hint: None, - }; + ]); let expected_schema = Schema::new(vec![ 
Field::new("id", DataType::Int32, true), @@ -1583,23 +1641,27 @@ mod tests { ), ]); - let source = Arc::new(ParquetSource::default()); - let config = FileScanConfigBuilder::new(object_store_url, schema.clone(), source) - .with_file(partitioned_file) - // file has 10 cols so index 12 should be month and 13 should be day - .with_projection(Some(vec![0, 1, 2, 12, 13])) - .with_table_partition_cols(vec![ - Field::new("year", DataType::Utf8, false), - Field::new("month", DataType::UInt8, false), - Field::new( + let table_schema = TableSchema::new( + Arc::clone(&schema), + vec![ + Arc::new(Field::new("year", DataType::Utf8, false)), + Arc::new(Field::new("month", DataType::UInt8, false)), + Arc::new(Field::new( "day", DataType::Dictionary( Box::new(DataType::UInt16), Box::new(DataType::Utf8), ), false, - ), - ]) + )), + ], + ); + let source = Arc::new(ParquetSource::new(table_schema.clone())); + let config = FileScanConfigBuilder::new(object_store_url, source) + .with_file(partitioned_file) + // file has 10 cols so index 12 should be month and 13 should be day + .with_projection_indices(Some(vec![0, 1, 2, 12, 13])) + .unwrap() .build(); let parquet_exec = DataSourceExec::from_data_source(config); @@ -1614,20 +1676,20 @@ mod tests { let batch = results.next().await.unwrap()?; assert_eq!(batch.schema().as_ref(), &expected_schema); - assert_snapshot!(batches_to_string(&[batch]),@r###" - +----+----------+-------------+-------+-----+ - | id | bool_col | tinyint_col | month | day | - +----+----------+-------------+-------+-----+ - | 4 | true | 0 | 10 | 26 | - | 5 | false | 1 | 10 | 26 | - | 6 | true | 0 | 10 | 26 | - | 7 | false | 1 | 10 | 26 | - | 2 | true | 0 | 10 | 26 | - | 3 | false | 1 | 10 | 26 | - | 0 | true | 0 | 10 | 26 | - | 1 | false | 1 | 10 | 26 | - +----+----------+-------------+-------+-----+ - "###); + assert_snapshot!(batches_to_string(&[batch]),@r" + +----+----------+-------------+-------+-----+ + | id | bool_col | tinyint_col | month | day | + 
+----+----------+-------------+-------+-----+ + | 4 | true | 0 | 10 | 26 | + | 5 | false | 1 | 10 | 26 | + | 6 | true | 0 | 10 | 26 | + | 7 | false | 1 | 10 | 26 | + | 2 | true | 0 | 10 | 26 | + | 3 | false | 1 | 10 | 26 | + | 0 | true | 0 | 10 | 26 | + | 1 | false | 1 | 10 | 26 | + +----+----------+-------------+-------+-----+ + "); let batch = results.next().await; assert!(batch.is_none()); @@ -1643,26 +1705,18 @@ mod tests { .unwrap() .child("invalid.parquet"); - let partitioned_file = PartitionedFile { - object_meta: ObjectMeta { - location, - last_modified: Utc.timestamp_nanos(0), - size: 1337, - e_tag: None, - version: None, - }, - partition_values: vec![], - range: None, - statistics: None, - extensions: None, - metadata_size_hint: None, - }; + let partitioned_file = PartitionedFile::new_from_meta(ObjectMeta { + location, + last_modified: Utc.timestamp_nanos(0), + size: 1337, + e_tag: None, + version: None, + }); let file_schema = Arc::new(Schema::empty()); let config = FileScanConfigBuilder::new( ObjectStoreUrl::local_filesystem(), - file_schema, - Arc::new(ParquetSource::default()), + Arc::new(ParquetSource::new(file_schema)), ) .with_file(partitioned_file) .build(); @@ -1687,6 +1741,7 @@ mod tests { Some(3), Some(4), Some(5), + Some(6), // last page with only one row ])); let batch1 = create_batch(vec![("int", c1.clone())]); @@ -1695,25 +1750,53 @@ mod tests { let rt = RoundTrip::new() .with_predicate(filter) .with_page_index_predicate() - .round_trip(vec![batch1]) + .round_trip(vec![batch1.clone()]) .await; let metrics = rt.parquet_exec.metrics().unwrap(); - assert_snapshot!(batches_to_sort_string(&rt.batches.unwrap()),@r###" - +-----+ - | int | - +-----+ - | 4 | - | 5 | - +-----+ - "###); - assert_eq!(get_value(&metrics, "page_index_rows_pruned"), 4); - assert_eq!(get_value(&metrics, "page_index_rows_matched"), 2); + assert_snapshot!(batches_to_sort_string(&rt.batches.unwrap()),@r" + +-----+ + | int | + +-----+ + | 4 | + | 5 | + +-----+ + "); + let 
(page_index_rows_pruned, page_index_rows_matched) = + get_pruning_metric(&metrics, "page_index_rows_pruned"); + assert_eq!(page_index_rows_pruned, 5); + assert_eq!(page_index_rows_matched, 2); assert!( get_value(&metrics, "page_index_eval_time") > 0, "no eval time in metrics: {metrics:#?}" ); + + // each page has 2 rows, so the num of pages is 1/2 the number of rows + let (page_index_pages_pruned, page_index_pages_matched) = + get_pruning_metric(&metrics, "page_index_pages_pruned"); + assert_eq!(page_index_pages_pruned, 3); + assert_eq!(page_index_pages_matched, 1); + + // test with a filter that matches the page with one row + let filter = col("int").eq(lit(6_i32)); + let rt = RoundTrip::new() + .with_predicate(filter) + .with_page_index_predicate() + .round_trip(vec![batch1]) + .await; + + let metrics = rt.parquet_exec.metrics().unwrap(); + + let (page_index_rows_pruned, page_index_rows_matched) = + get_pruning_metric(&metrics, "page_index_rows_pruned"); + assert_eq!(page_index_rows_pruned, 6); + assert_eq!(page_index_rows_matched, 1); + + let (page_index_pages_pruned, page_index_pages_matched) = + get_pruning_metric(&metrics, "page_index_pages_pruned"); + assert_eq!(page_index_pages_pruned, 3); + assert_eq!(page_index_pages_matched, 1); } /// Returns a string array with contents: @@ -1751,14 +1834,14 @@ mod tests { let metrics = rt.parquet_exec.metrics().unwrap(); // assert the batches and some metrics - assert_snapshot!(batches_to_string(&rt.batches.unwrap()),@r###" - +-----+ - | c1 | - +-----+ - | Foo | - | zzz | - +-----+ - "###); + assert_snapshot!(batches_to_string(&rt.batches.unwrap()),@r" + +-----+ + | c1 | + +-----+ + | Foo | + | zzz | + +-----+ + "); // pushdown predicates have eliminated all 4 bar rows and the // null row for 5 rows total @@ -1798,13 +1881,109 @@ mod tests { assert_contains!(&explain, "predicate=c1@0 != bar"); // there's a single row group, but we can check that it matched - // if no pruning was done this would be 0 instead of 1 - 
assert_contains!(&explain, "row_groups_matched_statistics=1"); + assert_contains!( + &explain, + "row_groups_pruned_statistics=1 total \u{2192} 1 matched" + ); // check the projection assert_contains!(&explain, "projection=[c1]"); } + #[tokio::test] + async fn parquet_exec_metrics_with_multiple_predicates() { + // Test that metrics are correctly calculated when multiple predicates + // are pushed down (connected with AND). This ensures we don't double-count + // rows when multiple predicates filter the data sequentially. + + // Create a batch with two columns: c1 (string) and c2 (int32) + // Total: 10 rows + let c1: ArrayRef = Arc::new(StringArray::from(vec![ + Some("foo"), // 0 - passes c1 filter, fails c2 filter (5 <= 10) + Some("bar"), // 1 - fails c1 filter + Some("bar"), // 2 - fails c1 filter + Some("baz"), // 3 - passes both filters (20 > 10) + Some("foo"), // 4 - passes both filters (12 > 10) + Some("bar"), // 5 - fails c1 filter + Some("baz"), // 6 - passes both filters (25 > 10) + Some("foo"), // 7 - passes c1 filter, fails c2 filter (7 <= 10) + Some("bar"), // 8 - fails c1 filter + Some("qux"), // 9 - passes both filters (30 > 10) + ])); + + let c2: ArrayRef = Arc::new(Int32Array::from(vec![ + Some(5), + Some(15), + Some(8), + Some(20), + Some(12), + Some(9), + Some(25), + Some(7), + Some(18), + Some(30), + ])); + + let batch = create_batch(vec![("c1", c1), ("c2", c2)]); + + // Create filter: c1 != 'bar' AND c2 > 10 + // + // First predicate (c1 != 'bar'): + // - Rows passing: 0, 3, 4, 6, 7, 9 (6 rows) + // - Rows pruned: 1, 2, 5, 8 (4 rows) + // + // Second predicate (c2 > 10) on remaining 6 rows: + // - Rows passing: 3, 4, 6, 9 (4 rows with c2 = 20, 12, 25, 30) + // - Rows pruned: 0, 7 (2 rows with c2 = 5, 7) + // + // Expected final metrics: + // - pushdown_rows_matched: 4 (final result) + // - pushdown_rows_pruned: 4 + 2 = 6 (cumulative) + // - Total: 4 + 6 = 10 + + let filter = col("c1").not_eq(lit("bar")).and(col("c2").gt(lit(10))); + + let rt = 
RoundTrip::new() + .with_predicate(filter) + .with_pushdown_predicate() + .round_trip(vec![batch]) + .await; + + let metrics = rt.parquet_exec.metrics().unwrap(); + + // Verify the result rows + assert_snapshot!(batches_to_string(&rt.batches.unwrap()),@r" + +-----+----+ + | c1 | c2 | + +-----+----+ + | baz | 20 | + | foo | 12 | + | baz | 25 | + | qux | 30 | + +-----+----+ + "); + + // Verify metrics - this is the key test + let pushdown_rows_matched = get_value(&metrics, "pushdown_rows_matched"); + let pushdown_rows_pruned = get_value(&metrics, "pushdown_rows_pruned"); + + assert_eq!( + pushdown_rows_matched, 4, + "Expected 4 rows to pass both predicates" + ); + assert_eq!( + pushdown_rows_pruned, 6, + "Expected 6 rows to be pruned (4 by first predicate + 2 by second predicate)" + ); + + // The sum should equal the total number of rows + assert_eq!( + pushdown_rows_matched + pushdown_rows_pruned, + 10, + "matched + pruned should equal total rows" + ); + } + #[tokio::test] async fn parquet_exec_has_no_pruning_predicate_if_can_not_prune() { // batch1: c1(string) @@ -1830,8 +2009,10 @@ mod tests { // When both matched and pruned are 0, it means that the pruning predicate // was not used at all. - assert_contains!(&explain, "row_groups_matched_statistics=0"); - assert_contains!(&explain, "row_groups_pruned_statistics=0"); + assert_contains!( + &explain, + "row_groups_pruned_statistics=1 total \u{2192} 1 matched" + ); // But pushdown predicate should be present assert_contains!( @@ -1884,7 +2065,12 @@ mod tests { /// Panics if no such metric. fn get_value(metrics: &MetricsSet, metric_name: &str) -> usize { match metrics.sum_by_name(metric_name) { - Some(v) => v.as_usize(), + Some(v) => match v { + MetricValue::PruningMetrics { + pruning_metrics, .. + } => pruning_metrics.pruned(), + _ => v.as_usize(), + }, _ => { panic!( "Expected metric not found. 
Looking for '{metric_name}' in\n\n{metrics:#?}" @@ -1893,6 +2079,20 @@ mod tests { } } + fn get_pruning_metric(metrics: &MetricsSet, metric_name: &str) -> (usize, usize) { + match metrics.sum_by_name(metric_name) { + Some(MetricValue::PruningMetrics { + pruning_metrics, .. + }) => (pruning_metrics.pruned(), pruning_metrics.matched()), + Some(_) => panic!( + "Metric '{metric_name}' is not a pruning metric in\n\n{metrics:#?}" + ), + None => panic!( + "Expected metric not found. Looking for '{metric_name}' in\n\n{metrics:#?}" + ), + } + } + fn populate_csv_partitions( tmp_dir: &TempDir, partition_count: usize, @@ -1952,14 +2152,14 @@ mod tests { let out_dir = tmp_dir.as_ref().to_str().unwrap().to_string() + "/out"; fs::create_dir(&out_dir).unwrap(); let df = ctx.sql("SELECT c1, c2 FROM test").await?; - let schema: Schema = df.schema().into(); + let schema = Arc::clone(df.schema().inner()); // Register a listing table - this will use all files in the directory as data sources // for the query ctx.register_listing_table( "my_table", &out_dir, listing_options, - Some(Arc::new(schema)), + Some(schema), None, ) .await @@ -2024,13 +2224,13 @@ mod tests { let sql = "select * from base_table where name='test02'"; let batch = ctx.sql(sql).await.unwrap().collect().await.unwrap(); assert_eq!(batch.len(), 1); - insta::assert_snapshot!(batches_to_string(&batch),@r###" - +---------------------+----+--------+ - | struct | id | name | - +---------------------+----+--------+ - | {id: 4, name: aaa2} | 2 | test02 | - +---------------------+----+--------+ - "###); + insta::assert_snapshot!(batches_to_string(&batch),@r" + +---------------------+----+--------+ + | struct | id | name | + +---------------------+----+--------+ + | {id: 4, name: aaa2} | 2 | test02 | + +---------------------+----+--------+ + "); Ok(()) } @@ -2053,13 +2253,55 @@ mod tests { let sql = "select * from base_table where name='test02'"; let batch = ctx.sql(sql).await.unwrap().collect().await.unwrap(); 
assert_eq!(batch.len(), 1); - insta::assert_snapshot!(batches_to_string(&batch),@r###" - +---------------------+----+--------+ - | struct | id | name | - +---------------------+----+--------+ - | {id: 4, name: aaa2} | 2 | test02 | - +---------------------+----+--------+ - "###); + insta::assert_snapshot!(batches_to_string(&batch),@r" + +---------------------+----+--------+ + | struct | id | name | + +---------------------+----+--------+ + | {id: 4, name: aaa2} | 2 | test02 | + +---------------------+----+--------+ + "); + Ok(()) + } + + /// Tests that constant dictionary columns (where min == max in statistics) + /// are correctly handled. This reproduced a bug where the constant value + /// from statistics had type Utf8 but the schema expected Dictionary. + #[tokio::test] + async fn test_constant_dictionary_column_parquet() -> Result<()> { + let tmp_dir = TempDir::new()?; + let path = tmp_dir.path().to_str().unwrap().to_string() + "/test.parquet"; + + // Write parquet with dictionary column where all values are the same + let schema = Arc::new(Schema::new(vec![Field::new( + "status", + DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8)), + false, + )])); + let status: DictionaryArray = + vec!["active", "active"].into_iter().collect(); + let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(status)])?; + let file = File::create(&path)?; + let props = WriterProperties::builder() + .set_statistics_enabled(parquet::file::properties::EnabledStatistics::Page) + .build(); + let mut writer = ArrowWriter::try_new(file, schema, Some(props))?; + writer.write(&batch)?; + writer.close()?; + + // Query the constant dictionary column + let ctx = SessionContext::new(); + ctx.register_parquet("t", &path, ParquetReadOptions::default()) + .await?; + let result = ctx.sql("SELECT status FROM t").await?.collect().await?; + + insta::assert_snapshot!(batches_to_string(&result),@r" + +--------+ + | status | + +--------+ + | active | + | active | + 
+--------+ + "); Ok(()) } @@ -2141,7 +2383,7 @@ mod tests { fn create_reader( &self, partition_index: usize, - file_meta: FileMeta, + partitioned_file: PartitionedFile, metadata_size_hint: Option, metrics: &ExecutionPlanMetricsSet, ) -> Result> @@ -2152,7 +2394,7 @@ mod tests { .push(metadata_size_hint); self.inner.create_reader( partition_index, - file_meta, + partitioned_file, metadata_size_hint, metrics, ) @@ -2184,42 +2426,28 @@ mod tests { let size_hint_calls = reader_factory.metadata_size_hint_calls.clone(); let source = Arc::new( - ParquetSource::default() + ParquetSource::new(Arc::clone(&schema)) .with_parquet_file_reader_factory(reader_factory) .with_metadata_size_hint(456), ); - let config = FileScanConfigBuilder::new(store_url, schema, source) + let config = FileScanConfigBuilder::new(store_url, source) .with_file( - PartitionedFile { - object_meta: ObjectMeta { - location: Path::from(name_1), - last_modified: Utc::now(), - size: total_size_1, - e_tag: None, - version: None, - }, - partition_values: vec![], - range: None, - statistics: None, - extensions: None, - metadata_size_hint: None, - } - .with_metadata_size_hint(123), - ) - .with_file(PartitionedFile { - object_meta: ObjectMeta { - location: Path::from(name_2), + PartitionedFile::new_from_meta(ObjectMeta { + location: Path::from(name_1), last_modified: Utc::now(), - size: total_size_2, + size: total_size_1, e_tag: None, version: None, - }, - partition_values: vec![], - range: None, - statistics: None, - extensions: None, - metadata_size_hint: None, - }) + }) + .with_metadata_size_hint(123), + ) + .with_file(PartitionedFile::new_from_meta(ObjectMeta { + location: Path::from(name_2), + last_modified: Utc::now(), + size: total_size_2, + e_tag: None, + version: None, + })) .build(); let exec = DataSourceExec::from_data_source(config); diff --git a/datafusion/core/src/datasource/view_test.rs b/datafusion/core/src/datasource/view_test.rs index 85ad9ff664ade..35418d6dea632 100644 --- 
a/datafusion/core/src/datasource/view_test.rs +++ b/datafusion/core/src/datasource/view_test.rs @@ -46,13 +46,13 @@ mod tests { .collect() .await?; - insta::assert_snapshot!(batches_to_string(&results),@r###" + insta::assert_snapshot!(batches_to_string(&results),@r" +---+ | b | +---+ | 2 | +---+ - "###); + "); Ok(()) } @@ -96,14 +96,14 @@ mod tests { .collect() .await?; - insta::assert_snapshot!(batches_to_string(&results),@r###" + insta::assert_snapshot!(batches_to_string(&results),@r" +---------+---------+---------+ | column1 | column2 | column3 | +---------+---------+---------+ | 1 | 2 | 3 | | 4 | 5 | 6 | +---------+---------+---------+ - "###); + "); let view_sql = "CREATE VIEW replace_xyz AS SELECT * REPLACE (column1*2 as column1) FROM xyz"; @@ -115,14 +115,14 @@ mod tests { .collect() .await?; - insta::assert_snapshot!(batches_to_string(&results),@r###" + insta::assert_snapshot!(batches_to_string(&results),@r" +---------+---------+---------+ | column1 | column2 | column3 | +---------+---------+---------+ | 2 | 2 | 3 | | 8 | 5 | 6 | +---------+---------+---------+ - "###); + "); Ok(()) } @@ -146,14 +146,14 @@ mod tests { .collect() .await?; - insta::assert_snapshot!(batches_to_string(&results),@r###" + insta::assert_snapshot!(batches_to_string(&results),@r" +---------------+ | column1_alias | +---------------+ | 1 | | 4 | +---------------+ - "###); + "); Ok(()) } @@ -177,14 +177,14 @@ mod tests { .collect() .await?; - insta::assert_snapshot!(batches_to_string(&results),@r###" + insta::assert_snapshot!(batches_to_string(&results),@r" +---------------+---------------+ | column2_alias | column1_alias | +---------------+---------------+ | 2 | 1 | | 5 | 4 | +---------------+---------------+ - "###); + "); Ok(()) } @@ -213,14 +213,14 @@ mod tests { .collect() .await?; - insta::assert_snapshot!(batches_to_string(&results),@r###" + insta::assert_snapshot!(batches_to_string(&results),@r" +---------+ | column1 | +---------+ | 1 | | 4 | +---------+ - "###); + "); Ok(()) 
} @@ -249,13 +249,13 @@ mod tests { .collect() .await?; - insta::assert_snapshot!(batches_to_string(&results),@r###" + insta::assert_snapshot!(batches_to_string(&results),@r" +---------+ | column1 | +---------+ | 4 | +---------+ - "###); + "); Ok(()) } @@ -287,14 +287,14 @@ mod tests { .collect() .await?; - insta::assert_snapshot!(batches_to_string(&results),@r###" + insta::assert_snapshot!(batches_to_string(&results),@r" +---------+---------+---------+ | column2 | column1 | column3 | +---------+---------+---------+ | 2 | 1 | 3 | | 5 | 4 | 6 | +---------+---------+---------+ - "###); + "); Ok(()) } @@ -358,7 +358,10 @@ mod tests { .to_string(); assert!(formatted.contains("DataSourceExec: ")); assert!(formatted.contains("file_type=parquet")); - assert!(formatted.contains("projection=[bool_col, int_col], limit=10")); + assert!( + formatted.contains("projection=[bool_col, int_col], limit=10"), + "{formatted}" + ); Ok(()) } @@ -442,14 +445,14 @@ mod tests { .collect() .await?; - insta::assert_snapshot!(batches_to_string(&results),@r###" + insta::assert_snapshot!(batches_to_string(&results),@r" +---------+ | column1 | +---------+ | 1 | | 4 | +---------+ - "###); + "); Ok(()) } diff --git a/datafusion/core/src/execution/context/csv.rs b/datafusion/core/src/execution/context/csv.rs index 15d6d21f038a0..e6f95886e91d1 100644 --- a/datafusion/core/src/execution/context/csv.rs +++ b/datafusion/core/src/execution/context/csv.rs @@ -37,9 +37,16 @@ impl SessionContext { /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); /// // You can read a single file using `read_csv` - /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; + /// let df = ctx + /// .read_csv("tests/data/example.csv", CsvReadOptions::new()) + /// .await?; /// // you can also read multiple files: - /// let df = ctx.read_csv(vec!["tests/data/example.csv", "tests/data/example.csv"], CsvReadOptions::new()).await?; + /// let df = ctx + /// .read_csv( + /// 
vec!["tests/data/example.csv", "tests/data/example.csv"], + /// CsvReadOptions::new(), + /// ) + /// .await?; /// # Ok(()) /// # } /// ``` diff --git a/datafusion/core/src/execution/context/json.rs b/datafusion/core/src/execution/context/json.rs index e9d799400863d..f7df2ad7a1cd6 100644 --- a/datafusion/core/src/execution/context/json.rs +++ b/datafusion/core/src/execution/context/json.rs @@ -15,13 +15,13 @@ // specific language governing permissions and limitations // under the License. +use super::super::options::ReadOptions; +use super::{DataFilePaths, DataFrame, ExecutionPlan, Result, SessionContext}; +use crate::execution::options::JsonReadOptions; use datafusion_common::TableReference; use datafusion_datasource_json::source::plan_to_json; use std::sync::Arc; -use super::super::options::{NdJsonReadOptions, ReadOptions}; -use super::{DataFilePaths, DataFrame, ExecutionPlan, Result, SessionContext}; - impl SessionContext { /// Creates a [`DataFrame`] for reading an JSON data source. /// @@ -32,7 +32,7 @@ impl SessionContext { pub async fn read_json( &self, table_paths: P, - options: NdJsonReadOptions<'_>, + options: JsonReadOptions<'_>, ) -> Result { self._read_type(table_paths, options).await } @@ -43,7 +43,7 @@ impl SessionContext { &self, table_ref: impl Into, table_path: impl AsRef, - options: NdJsonReadOptions<'_>, + options: JsonReadOptions<'_>, ) -> Result<()> { let listing_options = options .to_listing_options(&self.copied_config(), self.copied_table_options()); diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 5ef666b61e547..5dbae61fc534d 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -20,6 +20,7 @@ use std::collections::HashSet; use std::fmt::Debug; use std::sync::{Arc, Weak}; +use std::time::Duration; use super::options::ReadOptions; use crate::datasource::dynamic_file::DynamicListTableFactory; @@ -33,20 +34,20 @@ use crate::{ 
datasource::listing::{ ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl, }, - datasource::{provider_as_source, MemTable, ViewTable}, - error::{DataFusionError, Result}, + datasource::{MemTable, ViewTable, provider_as_source}, + error::Result, execution::{ + FunctionRegistry, options::ArrowReadOptions, runtime_env::{RuntimeEnv, RuntimeEnvBuilder}, - FunctionRegistry, }, logical_expr::AggregateUDF, logical_expr::ScalarUDF, logical_expr::{ CreateCatalog, CreateCatalogSchema, CreateExternalTable, CreateFunction, CreateMemoryTable, CreateView, DropCatalogSchema, DropFunction, DropTable, - DropView, Execute, LogicalPlan, LogicalPlanBuilder, Prepare, SetVariable, - TableType, UNNAMED_TABLE, + DropView, Execute, LogicalPlan, LogicalPlanBuilder, Prepare, ResetVariable, + SetVariable, TableType, UNNAMED_TABLE, }, physical_expr::PhysicalExpr, physical_plan::ExecutionPlan, @@ -58,30 +59,43 @@ pub use crate::execution::session_state::SessionState; use arrow::datatypes::{Schema, SchemaRef}; use arrow::record_batch::RecordBatch; -use datafusion_catalog::memory::MemorySchemaProvider; use datafusion_catalog::MemoryCatalogProvider; +use datafusion_catalog::memory::MemorySchemaProvider; use datafusion_catalog::{ DynamicFileCatalog, TableFunction, TableFunctionImpl, UrlTableFactory, }; -use datafusion_common::config::ConfigOptions; +use datafusion_common::config::{ConfigField, ConfigOptions}; +use datafusion_common::metadata::ScalarAndMetadata; use datafusion_common::{ + DFSchema, DataFusionError, ParamValues, SchemaReference, TableReference, config::{ConfigExtension, TableOptions}, - exec_datafusion_err, exec_err, not_impl_err, plan_datafusion_err, plan_err, + exec_datafusion_err, exec_err, internal_datafusion_err, not_impl_err, + plan_datafusion_err, plan_err, tree_node::{TreeNodeRecursion, TreeNodeVisitor}, - DFSchema, ParamValues, ScalarValue, SchemaReference, TableReference, +}; +pub use datafusion_execution::TaskContext; +use 
datafusion_execution::cache::cache_manager::{ + DEFAULT_LIST_FILES_CACHE_MEMORY_LIMIT, DEFAULT_LIST_FILES_CACHE_TTL, + DEFAULT_METADATA_CACHE_LIMIT, }; pub use datafusion_execution::config::SessionConfig; +use datafusion_execution::disk_manager::{ + DEFAULT_MAX_TEMP_DIRECTORY_SIZE, DiskManagerBuilder, +}; use datafusion_execution::registry::SerializerRegistry; -pub use datafusion_execution::TaskContext; pub use datafusion_expr::execution_props::ExecutionProps; +#[cfg(feature = "sql")] +use datafusion_expr::planner::RelationPlanner; +use datafusion_expr::simplify::SimplifyContext; use datafusion_expr::{ + Expr, UserDefinedLogicalNode, WindowUDF, expr_rewriter::FunctionRewrite, logical_plan::{DdlStatement, Statement}, planner::ExprPlanner, - Expr, UserDefinedLogicalNode, WindowUDF, }; use datafusion_optimizer::analyzer::type_coercion::TypeCoercion; -use datafusion_optimizer::Analyzer; +use datafusion_optimizer::simplify_expressions::ExprSimplifier; +use datafusion_optimizer::{Analyzer, OptimizerContext}; use datafusion_optimizer::{AnalyzerRule, OptimizerRule}; use datafusion_session::SessionStore; @@ -164,22 +178,23 @@ where /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); -/// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; -/// let df = df.filter(col("a").lt_eq(col("b")))? -/// .aggregate(vec![col("a")], vec![min(col("b"))])? -/// .limit(0, Some(100))?; -/// let results = df -/// .collect() -/// .await?; +/// let df = ctx +/// .read_csv("tests/data/example.csv", CsvReadOptions::new()) +/// .await?; +/// let df = df +/// .filter(col("a").lt_eq(col("b")))? +/// .aggregate(vec![col("a")], vec![min(col("b"))])? 
+/// .limit(0, Some(100))?; +/// let results = df.collect().await?; /// assert_batches_eq!( -/// &[ -/// "+---+----------------+", -/// "| a | min(?table?.b) |", -/// "+---+----------------+", -/// "| 1 | 2 |", -/// "+---+----------------+", -/// ], -/// &results +/// &[ +/// "+---+----------------+", +/// "| a | min(?table?.b) |", +/// "+---+----------------+", +/// "| 1 | 2 |", +/// "+---+----------------+", +/// ], +/// &results /// ); /// # Ok(()) /// # } @@ -195,21 +210,22 @@ where /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); -/// ctx.register_csv("example", "tests/data/example.csv", CsvReadOptions::new()).await?; +/// ctx.register_csv("example", "tests/data/example.csv", CsvReadOptions::new()) +/// .await?; /// let results = ctx -/// .sql("SELECT a, min(b) FROM example GROUP BY a LIMIT 100") -/// .await? -/// .collect() -/// .await?; +/// .sql("SELECT a, min(b) FROM example GROUP BY a LIMIT 100") +/// .await? +/// .collect() +/// .await?; /// assert_batches_eq!( -/// &[ -/// "+---+----------------+", -/// "| a | min(example.b) |", -/// "+---+----------------+", -/// "| 1 | 2 |", -/// "+---+----------------+", -/// ], -/// &results +/// &[ +/// "+---+----------------+", +/// "| a | min(example.b) |", +/// "+---+----------------+", +/// "| 1 | 2 |", +/// "+---+----------------+", +/// ], +/// &results /// ); /// # Ok(()) /// # } @@ -226,21 +242,21 @@ where /// # use datafusion::execution::SessionStateBuilder; /// # use datafusion_execution::runtime_env::RuntimeEnvBuilder; /// // Configure a 4k batch size -/// let config = SessionConfig::new() .with_batch_size(4 * 1024); +/// let config = SessionConfig::new().with_batch_size(4 * 1024); /// /// // configure a memory limit of 1GB with 20% slop -/// let runtime_env = RuntimeEnvBuilder::new() +/// let runtime_env = RuntimeEnvBuilder::new() /// .with_memory_limit(1024 * 1024 * 1024, 0.80) /// .build_arc() /// .unwrap(); /// /// // Create a SessionState using the 
config and runtime_env /// let state = SessionStateBuilder::new() -/// .with_config(config) -/// .with_runtime_env(runtime_env) -/// // include support for built in functions and configurations -/// .with_default_features() -/// .build(); +/// .with_config(config) +/// .with_runtime_env(runtime_env) +/// // include support for built in functions and configurations +/// .with_default_features() +/// .build(); /// /// // Create a SessionContext /// let ctx = SessionContext::from(state); @@ -297,13 +313,13 @@ impl SessionContext { pub async fn refresh_catalogs(&self) -> Result<()> { let cat_names = self.catalog_names().clone(); for cat_name in cat_names.iter() { - let cat = self.catalog(cat_name.as_str()).ok_or_else(|| { - DataFusionError::Internal("Catalog not found!".to_string()) - })?; + let cat = self + .catalog(cat_name.as_str()) + .ok_or_else(|| internal_datafusion_err!("Catalog not found!"))?; for schema_name in cat.schema_names() { - let schema = cat.schema(schema_name.as_str()).ok_or_else(|| { - DataFusionError::Internal("Schema not found!".to_string()) - })?; + let schema = cat + .schema(schema_name.as_str()) + .ok_or_else(|| internal_datafusion_err!("Schema not found!"))?; let lister = schema.as_any().downcast_ref::(); if let Some(lister) = lister { lister.refresh(&self.state()).await?; @@ -426,16 +442,14 @@ impl SessionContext { /// # use datafusion::prelude::*; /// # use datafusion::execution::SessionStateBuilder; /// # use datafusion_optimizer::push_down_filter::PushDownFilter; - /// let my_rule = PushDownFilter{}; // pretend it is a new rule - /// // Create a new builder with a custom optimizer rule + /// let my_rule = PushDownFilter {}; // pretend it is a new rule + /// // Create a new builder with a custom optimizer rule /// let context: SessionContext = SessionStateBuilder::new() - /// .with_optimizer_rule(Arc::new(my_rule)) - /// .build() - /// .into(); + /// .with_optimizer_rule(Arc::new(my_rule)) + /// .build() + /// .into(); /// // Enable local 
file access and convert context back to a builder - /// let builder = context - /// .enable_url_table() - /// .into_state_builder(); + /// let builder = context.enable_url_table().into_state_builder(); /// ``` pub fn into_state_builder(self) -> SessionStateBuilder { let SessionContext { @@ -474,6 +488,11 @@ impl SessionContext { self.state.write().append_optimizer_rule(optimizer_rule); } + /// Removes an optimizer rule by name, returning `true` if it existed. + pub fn remove_optimizer_rule(&self, name: &str) -> bool { + self.state.write().remove_optimizer_rule(name) + } + /// Adds an analyzer rule to the end of the existing rules. /// /// See [`SessionState`] for more control of when the rule is applied. @@ -504,19 +523,21 @@ impl SessionContext { self.runtime_env().register_object_store(url, object_store) } - /// Registers the [`RecordBatch`] as the specified table name + /// Deregisters an [`ObjectStore`] associated with the specific URL prefix. + /// + /// See [`RuntimeEnv::deregister_object_store`] for more details. + pub fn deregister_object_store(&self, url: &Url) -> Result> { + self.runtime_env().deregister_object_store(url) + } + + /// Registers the given [`RecordBatch`] as the specified table reference. pub fn register_batch( &self, - table_name: &str, + table_ref: impl Into, batch: RecordBatch, ) -> Result>> { let table = MemTable::try_new(batch.schema(), vec![vec![batch]])?; - self.register_table( - TableReference::Bare { - table: table_name.into(), - }, - Arc::new(table), - ) + self.register_table(table_ref, Arc::new(table)) } /// Return the [RuntimeEnv] used to run queries with this `SessionContext` @@ -576,15 +597,15 @@ impl SessionContext { /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); - /// ctx - /// .sql("CREATE TABLE foo (x INTEGER)") - /// .await? - /// .collect() - /// .await?; + /// ctx.sql("CREATE TABLE foo (x INTEGER)") + /// .await? 
+ /// .collect() + /// .await?; /// assert!(ctx.table_exist("foo").unwrap()); /// # Ok(()) /// # } /// ``` + #[cfg(feature = "sql")] pub async fn sql(&self, sql: &str) -> Result { self.sql_with_options(sql, SQLOptions::new()).await } @@ -604,17 +625,18 @@ impl SessionContext { /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); - /// let options = SQLOptions::new() - /// .with_allow_ddl(false); - /// let err = ctx.sql_with_options("CREATE TABLE foo (x INTEGER)", options) - /// .await - /// .unwrap_err(); - /// assert!( - /// err.to_string().starts_with("Error during planning: DDL not supported: CreateMemoryTable") - /// ); + /// let options = SQLOptions::new().with_allow_ddl(false); + /// let err = ctx + /// .sql_with_options("CREATE TABLE foo (x INTEGER)", options) + /// .await + /// .unwrap_err(); + /// assert!(err + /// .to_string() + /// .starts_with("Error during planning: DDL not supported: CreateMemoryTable")); /// # Ok(()) /// # } /// ``` + #[cfg(feature = "sql")] pub async fn sql_with_options( &self, sql: &str, @@ -642,12 +664,12 @@ impl SessionContext { /// // provide type information that `a` is an Int32 /// let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); /// let df_schema = DFSchema::try_from(schema).unwrap(); - /// let expr = SessionContext::new() - /// .parse_sql_expr(sql, &df_schema)?; + /// let expr = SessionContext::new().parse_sql_expr(sql, &df_schema)?; /// assert_eq!(expected, expr); /// # Ok(()) /// # } /// ``` + #[cfg(feature = "sql")] pub fn parse_sql_expr(&self, sql: &str, df_schema: &DFSchema) -> Result { self.state.read().create_logical_expr(sql, df_schema) } @@ -668,7 +690,7 @@ impl SessionContext { match ddl { DdlStatement::CreateExternalTable(cmd) => { (Box::pin(async move { self.create_external_table(&cmd).await }) - as std::pin::Pin + Send>>) + as std::pin::Pin + Send>>) .await } DdlStatement::CreateMemoryTable(cmd) => { @@ -699,30 +721,42 @@ impl SessionContext { } 
// TODO what about the other statements (like TransactionStart and TransactionEnd) LogicalPlan::Statement(Statement::SetVariable(stmt)) => { - self.set_variable(stmt).await + self.set_variable(stmt).await?; + self.return_empty_dataframe() + } + LogicalPlan::Statement(Statement::ResetVariable(stmt)) => { + self.reset_variable(stmt).await?; + self.return_empty_dataframe() } LogicalPlan::Statement(Statement::Prepare(Prepare { name, input, - data_types, + fields, })) => { // The number of parameters must match the specified data types length. - if !data_types.is_empty() { + if !fields.is_empty() { let param_names = input.get_parameter_names()?; - if param_names.len() != data_types.len() { + if param_names.len() != fields.len() { return plan_err!( "Prepare specifies {} data types but query has {} parameters", - data_types.len(), + fields.len(), param_names.len() ); } } - // Store the unoptimized plan into the session state. Although storing the - // optimized plan or the physical plan would be more efficient, doing so is - // not currently feasible. This is because `now()` would be optimized to a - // constant value, causing each EXECUTE to yield the same result, which is - // incorrect behavior. 
- self.state.write().store_prepared(name, data_types, input)?; + // Optimize the plan without evaluating expressions like now() + let optimizer_context = OptimizerContext::new_with_config_options( + Arc::clone(self.state().config().options()), + ) + .without_query_execution_start_time(); + let plan = self.state().optimizer().optimize( + Arc::unwrap_or_clone(input), + &optimizer_context, + |_1, _2| {}, + )?; + self.state + .write() + .store_prepared(name, fields, Arc::new(plan))?; self.return_empty_dataframe() } LogicalPlan::Statement(Statement::Execute(execute)) => { @@ -764,7 +798,7 @@ impl SessionContext { /// * [`SessionState::create_physical_expr`] for a lower level API /// /// [simplified]: datafusion_optimizer::simplify_expressions - /// [expr_api]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/expr_api.rs + /// [expr_api]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/query_planning/expr_api.rs pub fn create_physical_expr( &self, expr: Expr, @@ -789,19 +823,44 @@ impl SessionContext { return not_impl_err!("Temporary tables not supported"); } - if exist { - match cmd.if_not_exists { - true => return self.return_empty_dataframe(), - false => { - return exec_err!("Table '{}' already exists", cmd.name); + match (cmd.if_not_exists, cmd.or_replace, exist) { + (true, false, true) => self.return_empty_dataframe(), + (false, true, true) => { + let result = self + .find_and_deregister(cmd.name.clone(), TableType::Base) + .await; + + match result { + Ok(true) => { + let table_provider: Arc = + self.create_custom_table(cmd).await?; + self.register_table(cmd.name.clone(), table_provider)?; + self.return_empty_dataframe() + } + Ok(false) => { + let table_provider: Arc = + self.create_custom_table(cmd).await?; + self.register_table(cmd.name.clone(), table_provider)?; + self.return_empty_dataframe() + } + Err(e) => { + exec_err!("Errored while deregistering external table: {}", e) + } } } + (true, true, true) => { 
+ exec_err!("'IF NOT EXISTS' cannot coexist with 'REPLACE'") + } + (_, _, false) => { + let table_provider: Arc = + self.create_custom_table(cmd).await?; + self.register_table(cmd.name.clone(), table_provider)?; + self.return_empty_dataframe() + } + (false, false, true) => { + exec_err!("External table '{}' already exists", cmd.name) + } } - - let table_provider: Arc = - self.create_custom_table(cmd).await?; - self.register_table(cmd.name.clone(), table_provider)?; - self.return_empty_dataframe() } async fn create_memory_table(&self, cmd: CreateMemoryTable) -> Result { @@ -827,7 +886,7 @@ impl SessionContext { (true, false, Ok(_)) => self.return_empty_dataframe(), (false, true, Ok(_)) => { self.deregister_table(name.clone())?; - let schema = Arc::new(input.schema().as_ref().into()); + let schema = Arc::clone(input.schema().inner()); let physical = DataFrame::new(self.state(), input); let batches: Vec<_> = physical.collect_partitioned().await?; @@ -845,8 +904,7 @@ impl SessionContext { exec_err!("'IF NOT EXISTS' cannot coexist with 'REPLACE'") } (_, _, Err(_)) => { - let df_schema = input.schema(); - let schema = Arc::new(df_schema.as_ref().into()); + let schema = Arc::clone(input.schema().inner()); let physical = DataFrame::new(self.state(), input); let batches: Vec<_> = physical.collect_partitioned().await?; @@ -914,7 +972,7 @@ impl SessionContext { .. 
} = cmd; - // sqlparser doesnt accept database / catalog as parameter to CREATE SCHEMA + // sqlparser doesn't accept database / catalog as parameter to CREATE SCHEMA // so for now, we default to default catalog let tokens: Vec<&str> = schema_name.split('.').collect(); let (catalog, schema_name) = match tokens.len() { @@ -922,17 +980,15 @@ impl SessionContext { let state = self.state.read(); let name = &state.config().options().catalog.default_catalog; let catalog = state.catalog_list().catalog(name).ok_or_else(|| { - DataFusionError::Execution(format!( - "Missing default catalog '{name}'" - )) + exec_datafusion_err!("Missing default catalog '{name}'") })?; (catalog, tokens[0]) } 2 => { let name = &tokens[0]; - let catalog = self.catalog(name).ok_or_else(|| { - DataFusionError::Execution(format!("Missing catalog '{name}'")) - })?; + let catalog = self + .catalog(name) + .ok_or_else(|| exec_datafusion_err!("Missing catalog '{name}'"))?; (catalog, tokens[1]) } _ => return exec_err!("Unable to parse catalog from {schema_name}"), @@ -1020,22 +1076,22 @@ impl SessionContext { } else if allow_missing { return self.return_empty_dataframe(); } else { - return self.schema_doesnt_exist_err(name); + return self.schema_doesnt_exist_err(&name); } }; let dereg = catalog.deregister_schema(name.schema_name(), cascade)?; match (dereg, allow_missing) { (None, true) => self.return_empty_dataframe(), - (None, false) => self.schema_doesnt_exist_err(name), + (None, false) => self.schema_doesnt_exist_err(&name), (Some(_), _) => self.return_empty_dataframe(), } } - fn schema_doesnt_exist_err(&self, schemaref: SchemaReference) -> Result { + fn schema_doesnt_exist_err(&self, schemaref: &SchemaReference) -> Result { exec_err!("Schema '{schemaref}' doesn't exist.") } - async fn set_variable(&self, stmt: SetVariable) -> Result { + async fn set_variable(&self, stmt: SetVariable) -> Result<()> { let SetVariable { variable, value, .. 
} = stmt; @@ -1046,33 +1102,132 @@ impl SessionContext { } else { let mut state = self.state.write(); state.config_mut().options_mut().set(&variable, &value)?; - drop(state); + + // Re-initialize any UDFs that depend on configuration + // This allows both built-in and custom functions to respond to configuration changes + let config_options = state.config().options(); + + // Collect updated UDFs in a separate vector + let udfs_to_update: Vec<_> = state + .scalar_functions() + .values() + .filter_map(|udf| { + udf.inner() + .with_updated_config(config_options) + .map(Arc::new) + }) + .collect(); + + for udf in udfs_to_update { + state.register_udf(udf)?; + } } - self.return_empty_dataframe() + Ok(()) + } + + async fn reset_variable(&self, stmt: ResetVariable) -> Result<()> { + let variable = stmt.variable; + if variable.starts_with("datafusion.runtime.") { + return self.reset_runtime_variable(&variable); + } + + let mut state = self.state.write(); + state.config_mut().options_mut().reset(&variable)?; + + // Refresh UDFs to ensure configuration-dependent behavior updates + let config_options = state.config().options(); + let udfs_to_update: Vec<_> = state + .scalar_functions() + .values() + .filter_map(|udf| { + udf.inner() + .with_updated_config(config_options) + .map(Arc::new) + }) + .collect(); + + for udf in udfs_to_update { + state.register_udf(udf)?; + } + + Ok(()) } fn set_runtime_variable(&self, variable: &str, value: &str) -> Result<()> { let key = variable.strip_prefix("datafusion.runtime.").unwrap(); + let mut state = self.state.write(); + + let mut builder = RuntimeEnvBuilder::from_runtime_env(state.runtime_env()); + builder = match key { + "memory_limit" => { + let memory_limit = Self::parse_capacity_limit(variable, value)?; + builder.with_memory_limit(memory_limit, 1.0) + } + "max_temp_directory_size" => { + let directory_size = Self::parse_capacity_limit(variable, value)?; + builder.with_max_temp_directory_size(directory_size as u64) + } + 
"temp_directory" => builder.with_temp_file_path(value), + "metadata_cache_limit" => { + let limit = Self::parse_capacity_limit(variable, value)?; + builder.with_metadata_cache_limit(limit) + } + "list_files_cache_limit" => { + let limit = Self::parse_capacity_limit(variable, value)?; + builder.with_object_list_cache_limit(limit) + } + "list_files_cache_ttl" => { + let duration = Self::parse_duration(variable, value)?; + builder.with_object_list_cache_ttl(Some(duration)) + } + _ => return plan_err!("Unknown runtime configuration: {variable}"), + // Remember to update `reset_runtime_variable()` when adding new options + }; + + *state = SessionStateBuilder::from(state.clone()) + .with_runtime_env(Arc::new(builder.build()?)) + .build(); + + Ok(()) + } + + fn reset_runtime_variable(&self, variable: &str) -> Result<()> { + let key = variable.strip_prefix("datafusion.runtime.").unwrap(); + + let mut state = self.state.write(); + + let mut builder = RuntimeEnvBuilder::from_runtime_env(state.runtime_env()); match key { "memory_limit" => { - let memory_limit = Self::parse_memory_limit(value)?; - - let mut state = self.state.write(); - let mut builder = - RuntimeEnvBuilder::from_runtime_env(state.runtime_env()); - builder = builder.with_memory_limit(memory_limit, 1.0); - *state = SessionStateBuilder::from(state.clone()) - .with_runtime_env(Arc::new(builder.build()?)) - .build(); + builder.memory_pool = None; } - _ => { - return Err(DataFusionError::Plan(format!( - "Unknown runtime configuration: {variable}" - ))) + "max_temp_directory_size" => { + builder = + builder.with_max_temp_directory_size(DEFAULT_MAX_TEMP_DIRECTORY_SIZE); } - } + "temp_directory" => { + builder.disk_manager_builder = Some(DiskManagerBuilder::default()); + } + "metadata_cache_limit" => { + builder = builder.with_metadata_cache_limit(DEFAULT_METADATA_CACHE_LIMIT); + } + "list_files_cache_limit" => { + builder = builder + .with_object_list_cache_limit(DEFAULT_LIST_FILES_CACHE_MEMORY_LIMIT); + } + 
"list_files_cache_ttl" => { + builder = + builder.with_object_list_cache_ttl(DEFAULT_LIST_FILES_CACHE_TTL); + } + _ => return plan_err!("Unknown runtime configuration: {variable}"), + }; + + *state = SessionStateBuilder::from(state.clone()) + .with_runtime_env(Arc::new(builder.build()?)) + .build(); + Ok(()) } @@ -1083,27 +1238,146 @@ impl SessionContext { /// ``` /// use datafusion::execution::context::SessionContext; /// - /// assert_eq!(SessionContext::parse_memory_limit("1M").unwrap(), 1024 * 1024); - /// assert_eq!(SessionContext::parse_memory_limit("1.5G").unwrap(), (1.5 * 1024.0 * 1024.0 * 1024.0) as usize); + /// assert_eq!( + /// SessionContext::parse_memory_limit("1M").unwrap(), + /// 1024 * 1024 + /// ); + /// assert_eq!( + /// SessionContext::parse_memory_limit("1.5G").unwrap(), + /// (1.5 * 1024.0 * 1024.0 * 1024.0) as usize + /// ); /// ``` + #[deprecated( + since = "53.0.0", + note = "please use `parse_capacity_limit` function instead." + )] pub fn parse_memory_limit(limit: &str) -> Result { + if limit.trim().is_empty() { + return Err(plan_datafusion_err!("Empty limit value found!")); + } let (number, unit) = limit.split_at(limit.len() - 1); let number: f64 = number.parse().map_err(|_| { - DataFusionError::Plan(format!( - "Failed to parse number from memory limit '{limit}'" - )) + plan_datafusion_err!("Failed to parse number from memory limit '{limit}'") })?; + if number.is_sign_negative() || number.is_infinite() { + return Err(plan_datafusion_err!( + "Limit value should be positive finite number" + )); + } match unit { "K" => Ok((number * 1024.0) as usize), "M" => Ok((number * 1024.0 * 1024.0) as usize), "G" => Ok((number * 1024.0 * 1024.0 * 1024.0) as usize), - _ => Err(DataFusionError::Plan(format!( - "Unsupported unit '{unit}' in memory limit '{limit}'" - ))), + _ => plan_err!("Unsupported unit '{unit}' in memory limit '{limit}'"), } } + /// Parse capacity limit from string to number of bytes by allowing units: K, M and G. 
+ /// Supports formats like '1.5G', '100M', '512K' + /// + /// # Examples + /// ``` + /// use datafusion::execution::context::SessionContext; + /// + /// assert_eq!( + /// SessionContext::parse_capacity_limit("datafusion.runtime.memory_limit", "1M").unwrap(), + /// 1024 * 1024 + /// ); + /// assert_eq!( + /// SessionContext::parse_capacity_limit("datafusion.runtime.memory_limit", "1.5G").unwrap(), + /// (1.5 * 1024.0 * 1024.0 * 1024.0) as usize + /// ); + /// ``` + pub fn parse_capacity_limit(config_name: &str, limit: &str) -> Result { + if limit.trim().is_empty() { + return Err(plan_datafusion_err!( + "Empty limit value found for '{config_name}'" + )); + } + let (number, unit) = limit.split_at(limit.len() - 1); + let number: f64 = number.parse().map_err(|_| { + plan_datafusion_err!( + "Failed to parse number from '{config_name}', limit '{limit}'" + ) + })?; + if number.is_sign_negative() || number.is_infinite() { + return Err(plan_datafusion_err!( + "Limit value should be positive finite number for '{config_name}'" + )); + } + + match unit { + "K" => Ok((number * 1024.0) as usize), + "M" => Ok((number * 1024.0 * 1024.0) as usize), + "G" => Ok((number * 1024.0 * 1024.0 * 1024.0) as usize), + _ => plan_err!( + "Unsupported unit '{unit}' in '{config_name}', limit '{limit}'. 
\ + Unit must be one of: 'K', 'M', 'G'" + ), + } + } + + fn parse_duration(config_name: &str, duration: &str) -> Result { + if duration.trim().is_empty() { + return Err(plan_datafusion_err!( + "Duration should not be empty or blank for '{config_name}'" + )); + } + + let mut minutes = None; + let mut seconds = None; + + for duration in duration.split_inclusive(&['m', 's']) { + let (number, unit) = duration.split_at(duration.len() - 1); + let number: u64 = number.parse().map_err(|_| { + plan_datafusion_err!("Failed to parse number from duration '{duration}' for '{config_name}'") + })?; + + match unit { + "m" if minutes.is_none() && seconds.is_none() => minutes = Some(number), + "s" if seconds.is_none() => seconds = Some(number), + other => plan_err!( + "Invalid duration unit: '{other}'. The unit must be either 'm' (minutes), or 's' (seconds), and be in the correct order for '{config_name}'" + )?, + } + } + + let secs = Self::check_overflow(config_name, minutes, 60, seconds)?; + let duration = Duration::from_secs(secs); + + if duration.is_zero() { + return plan_err!( + "Duration must be greater than 0 seconds for '{config_name}'" + ); + } + + Ok(duration) + } + + fn check_overflow( + config_name: &str, + mins: Option, + multiplier: u64, + secs: Option, + ) -> Result { + let first_part_of_secs = mins.unwrap_or_default().checked_mul(multiplier); + if first_part_of_secs.is_none() { + plan_err!( + "Duration has overflowed allowed maximum limit due to 'mins * {multiplier}' when setting '{config_name}'" + )? + } + let second_part_of_secs = first_part_of_secs + .unwrap() + .checked_add(secs.unwrap_or_default()); + if second_part_of_secs.is_none() { + plan_err!( + "Duration has overflowed allowed maximum limit due to 'mins * {multiplier} + secs' when setting '{config_name}'" + )? 
+ } + Ok(second_part_of_secs.unwrap()) + } + async fn create_custom_table( &self, cmd: &CreateExternalTable, @@ -1115,10 +1389,7 @@ impl SessionContext { .table_factories() .get(file_type.as_str()) .ok_or_else(|| { - DataFusionError::Execution(format!( - "Unable to find factory for {}", - cmd.file_type - )) + exec_datafusion_err!("Unable to find factory for {}", cmd.file_type) })?; let table = (*factory).create(&state, cmd).await?; Ok(table) @@ -1133,20 +1404,24 @@ impl SessionContext { let table = table_ref.table().to_owned(); let maybe_schema = { let state = self.state.read(); - let resolved = state.resolve_table_ref(table_ref); + let resolved = state.resolve_table_ref(table_ref.clone()); state .catalog_list() .catalog(&resolved.catalog) .and_then(|c| c.schema(&resolved.schema)) }; - if let Some(schema) = maybe_schema { - if let Some(table_provider) = schema.table(&table).await? { - if table_provider.table_type() == table_type { - schema.deregister_table(&table)?; - return Ok(true); - } + if let Some(schema) = maybe_schema + && let Some(table_provider) = schema.table(&table).await? 
+ && table_provider.table_type() == table_type + { + schema.deregister_table(&table)?; + if table_type == TableType::Base + && let Some(lfc) = self.runtime_env().cache_manager.get_list_files_cache() + { + lfc.drop_table_entries(&Some(table_ref))?; } + return Ok(true); } Ok(false) @@ -1159,9 +1434,11 @@ impl SessionContext { match function_factory { Some(f) => f.create(&state, stmt).await?, - _ => Err(DataFusionError::Configuration( - "Function factory has not been configured".into(), - ))?, + _ => { + return Err(DataFusionError::Configuration( + "Function factory has not been configured".to_string(), + )); + } } }; @@ -1210,29 +1487,40 @@ impl SessionContext { exec_datafusion_err!("Prepared statement '{}' does not exist", name) })?; + let state = self.state.read(); + let context = SimplifyContext::default() + .with_schema(Arc::clone(prepared.plan.schema())) + .with_config_options(Arc::clone(state.config_options())) + .with_query_execution_start_time( + state.execution_props().query_execution_start_time, + ); + let simplifier = ExprSimplifier::new(context); + // Only allow literals as parameters for now. - let mut params: Vec = parameters + let mut params: Vec = parameters .into_iter() - .map(|e| match e { - Expr::Literal(scalar) => Ok(scalar), - _ => not_impl_err!("Unsupported parameter type: {}", e), + .map(|e| match simplifier.simplify(e)? { + Expr::Literal(scalar, metadata) => { + Ok(ScalarAndMetadata::new(scalar, metadata)) + } + e => not_impl_err!("Unsupported parameter type: {e}"), }) .collect::>()?; // If the prepared statement provides data types, cast the params to those types. 
- if !prepared.data_types.is_empty() { - if params.len() != prepared.data_types.len() { + if !prepared.fields.is_empty() { + if params.len() != prepared.fields.len() { return exec_err!( "Prepared statement '{}' expects {} parameters, but {} provided", name, - prepared.data_types.len(), + prepared.fields.len(), params.len() ); } params = params .into_iter() - .zip(prepared.data_types.iter()) - .map(|(e, dt)| e.cast_to(dt)) + .zip(prepared.fields.iter()) + .map(|(e, dt)| -> Result<_> { e.cast_storage_to(dt.data_type()) }) .collect::>()?; } @@ -1298,6 +1586,18 @@ impl SessionContext { self.state.write().register_udwf(Arc::new(f)).ok(); } + #[cfg(feature = "sql")] + /// Registers a [`RelationPlanner`] to customize SQL table-factor planning. + /// + /// Planners are invoked in reverse registration order, allowing newer + /// planners to take precedence over existing ones. + pub fn register_relation_planner( + &self, + planner: Arc, + ) -> Result<()> { + self.state.write().register_relation_planner(planner) + } + /// Deregisters a UDF within this context. pub fn deregister_udf(&self, name: &str) { self.state.write().deregister_udf(name).ok(); @@ -1483,15 +1783,14 @@ impl SessionContext { /// SQL statements executed against this context. 
pub async fn register_arrow( &self, - name: &str, - table_path: &str, + table_ref: impl Into, + table_path: impl AsRef, options: ArrowReadOptions<'_>, ) -> Result<()> { let listing_options = options .to_listing_options(&self.copied_config(), self.copied_table_options()); - self.register_listing_table( - name, + table_ref, table_path, listing_options, options.schema.map(|s| Arc::new(s.to_owned())), @@ -1640,7 +1939,7 @@ impl SessionContext { /// [`ConfigOptions`]: crate::config::ConfigOptions pub fn state(&self) -> SessionState { let mut state = self.state.read().clone(); - state.execution_props_mut().start_execution(); + state.mark_start_execution(); state } @@ -1717,6 +2016,20 @@ impl FunctionRegistry for SessionContext { ) -> Result<()> { self.state.write().register_expr_planner(expr_planner) } + + fn udafs(&self) -> HashSet { + self.state.read().udafs() + } + + fn udwfs(&self) -> HashSet { + self.state.read().udwfs() + } +} + +impl datafusion_execution::TaskContextProvider for SessionContext { + fn task_ctx(&self) -> Arc { + SessionContext::task_ctx(self) + } } /// Create a new task context instance from SessionContext @@ -1741,7 +2054,7 @@ impl From for SessionStateBuilder { /// A planner used to add extensions to DataFusion logical and physical plans. #[async_trait] pub trait QueryPlanner: Debug { - /// Given a `LogicalPlan`, create an [`ExecutionPlan`] suitable for execution + /// Given a [`LogicalPlan`], create an [`ExecutionPlan`] suitable for execution async fn create_physical_plan( &self, logical_plan: &LogicalPlan, @@ -1749,12 +2062,46 @@ pub trait QueryPlanner: Debug { ) -> Result>; } -/// A pluggable interface to handle `CREATE FUNCTION` statements -/// and interact with [SessionState] to registers new udf, udaf or udwf. +/// Interface for handling `CREATE FUNCTION` statements and interacting with +/// [SessionState] to create and register functions ([`ScalarUDF`], +/// [`AggregateUDF`], [`WindowUDF`], and [`TableFunctionImpl`]) dynamically. 
+/// +/// Implement this trait to create user-defined functions in a custom way, such +/// as loading from external libraries or defining them programmatically. +/// DataFusion will parse `CREATE FUNCTION` statements into [`CreateFunction`] +/// structs and pass them to the [`create`](Self::create) method. +/// +/// Note there is no default implementation of this trait provided in DataFusion, +/// because the implementation and requirements vary widely. Please see +/// [function_factory example] for a reference implementation. +/// +/// [function_factory example]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/builtin_functions/function_factory.rs +/// +/// # Examples of syntax that can be supported +/// +/// ```sql +/// CREATE FUNCTION f1(BIGINT) +/// RETURNS BIGINT +/// RETURN $1 + 1; +/// ``` +/// or +/// ```sql +/// CREATE FUNCTION to_miles(DOUBLE) +/// RETURNS DOUBLE +/// LANGUAGE PYTHON +/// AS ' +/// import pyarrow.compute as pc +/// +/// conversation_rate_multiplier = 0.62137119 +/// +/// def to_miles(km_data): +/// return pc.multiply(km_data, conversation_rate_multiplier) +/// ' +/// ``` #[async_trait] pub trait FunctionFactory: Debug + Sync + Send { - /// Handles creation of user defined function specified in [CreateFunction] statement + /// Creates a new dynamic function from the SQL in the [CreateFunction] statement async fn create( &self, state: &SessionState, @@ -1762,7 +2109,8 @@ pub trait FunctionFactory: Debug + Sync + Send { ) -> Result; } -/// Type of function to create +/// The result of processing a [`CreateFunction`] statement with [`FunctionFactory`]. 
+#[derive(Debug, Clone)] pub enum RegisterFunction { /// Scalar user defined function Scalar(Arc), @@ -1894,6 +2242,9 @@ mod tests { use crate::test; use crate::test_util::{plan_and_collect, populate_csv_partitions}; use arrow::datatypes::{DataType, TimeUnit}; + use arrow_schema::FieldRef; + use datafusion_common::DataFusionError; + use datafusion_common::datatype::DataTypeExt; use std::error::Error; use std::path::PathBuf; @@ -1918,7 +2269,7 @@ mod tests { // configure with same memory / disk manager let memory_pool = ctx1.runtime_env().memory_pool.clone(); - let mut reservation = MemoryConsumer::new("test").register(&memory_pool); + let reservation = MemoryConsumer::new("test").register(&memory_pool); reservation.grow(100); let disk_manager = ctx1.runtime_env().disk_manager.clone(); @@ -2410,7 +2761,7 @@ mod tests { struct MyTypePlanner {} impl TypePlanner for MyTypePlanner { - fn plan_type(&self, sql_type: &ast::DataType) -> Result> { + fn plan_type_field(&self, sql_type: &ast::DataType) -> Result> { match sql_type { ast::DataType::Datetime(precision) => { let precision = match precision { @@ -2420,10 +2771,213 @@ mod tests { None | Some(9) => TimeUnit::Nanosecond, _ => unreachable!(), }; - Ok(Some(DataType::Timestamp(precision, None))) + Ok(Some( + DataType::Timestamp(precision, None).into_nullable_field_ref(), + )) } _ => Ok(None), } } } + + #[tokio::test] + async fn remove_optimizer_rule() -> Result<()> { + let get_optimizer_rules = |ctx: &SessionContext| { + ctx.state() + .optimizer() + .rules + .iter() + .map(|r| r.name().to_owned()) + .collect::>() + }; + + let ctx = SessionContext::new(); + assert!(get_optimizer_rules(&ctx).contains("simplify_expressions")); + + // default plan + let plan = ctx + .sql("select 1 + 1") + .await? + .into_optimized_plan()? 
+ .to_string(); + assert_snapshot!(plan, @r" + Projection: Int64(2) AS Int64(1) + Int64(1) + EmptyRelation: rows=1 + "); + + assert!(ctx.remove_optimizer_rule("simplify_expressions")); + assert!(!get_optimizer_rules(&ctx).contains("simplify_expressions")); + + // plan without the simplify_expressions rule + let plan = ctx + .sql("select 1 + 1") + .await? + .into_optimized_plan()? + .to_string(); + assert_snapshot!(plan, @r" + Projection: Int64(1) + Int64(1) + EmptyRelation: rows=1 + "); + + // attempting to remove a non-existing rule returns false + assert!(!ctx.remove_optimizer_rule("simplify_expressions")); + + Ok(()) + } + + #[test] + fn test_parse_duration() { + const LIST_FILES_CACHE_TTL: &str = "datafusion.runtime.list_files_cache_ttl"; + + // Valid durations + for (duration, want) in [ + ("1s", Duration::from_secs(1)), + ("1m", Duration::from_secs(60)), + ("1m0s", Duration::from_secs(60)), + ("1m1s", Duration::from_secs(61)), + ] { + let have = + SessionContext::parse_duration(LIST_FILES_CACHE_TTL, duration).unwrap(); + assert_eq!(want, have); + } + + // Invalid durations + for duration in [ + "0s", "0m", "1s0m", "1s1m", "XYZ", "1h", "XYZm2s", "", " ", "-1m", "1m 1s", + "1m1s ", " 1m1s", + ] { + let have = SessionContext::parse_duration(LIST_FILES_CACHE_TTL, duration); + assert!(have.is_err()); + assert!( + have.unwrap_err() + .message() + .to_string() + .contains(LIST_FILES_CACHE_TTL) + ); + } + } + + #[test] + fn test_parse_duration_with_overflow_check() { + const LIST_FILES_CACHE_TTL: &str = "datafusion.runtime.list_files_cache_ttl"; + + // Valid durations which are close to max allowed limit + for (duration, want) in [ + ( + "18446744073709551615s", + Duration::from_secs(18446744073709551615), + ), + ( + "307445734561825860m", + Duration::from_secs(307445734561825860 * 60), + ), + ( + "307445734561825860m10s", + Duration::from_secs(307445734561825860 * 60 + 10), + ), + ( + "1m18446744073709551555s", + Duration::from_secs(60 + 18446744073709551555), + ), 
+ ] { + let have = + SessionContext::parse_duration(LIST_FILES_CACHE_TTL, duration).unwrap(); + assert_eq!(want, have); + } + + // Invalid durations which overflow max allowed limit + for (duration, error_message_prefix) in [ + ( + "18446744073709551616s", + "Failed to parse number from duration", + ), + ( + "307445734561825861m", + "Duration has overflowed allowed maximum limit due to", + ), + ( + "307445734561825860m60s", + "Duration has overflowed allowed maximum limit due to", + ), + ( + "1m18446744073709551556s", + "Duration has overflowed allowed maximum limit due to", + ), + ] { + let have = SessionContext::parse_duration(LIST_FILES_CACHE_TTL, duration); + assert!(have.is_err()); + let error_message = have.unwrap_err().message().to_string(); + assert!( + error_message.contains(error_message_prefix) + && error_message.contains(LIST_FILES_CACHE_TTL) + ); + } + } + + #[test] + fn test_parse_memory_limit() { + // Valid memory_limit + for (limit, want) in [ + ("1.5K", (1.5 * 1024.0) as usize), + ("2M", (2f64 * 1024.0 * 1024.0) as usize), + ("1G", (1f64 * 1024.0 * 1024.0 * 1024.0) as usize), + ] { + #[expect(deprecated)] + let have = SessionContext::parse_memory_limit(limit).unwrap(); + assert_eq!(want, have); + } + + // Invalid memory_limit + for limit in [ + "1B", + "1T", + "", + " ", + "XYZG", + "-1G", + "infG", + "-infG", + "G", + "1024B", + "invalid_size", + ] { + #[expect(deprecated)] + let have = SessionContext::parse_memory_limit(limit); + assert!(have.is_err()); + } + } + + #[test] + fn test_parse_capacity_limit() { + const MEMORY_LIMIT: &str = "datafusion.runtime.memory_limit"; + + // Valid capacity_limit + for (limit, want) in [ + ("1.5K", (1.5 * 1024.0) as usize), + ("2M", (2f64 * 1024.0 * 1024.0) as usize), + ("1G", (1f64 * 1024.0 * 1024.0 * 1024.0) as usize), + ] { + let have = SessionContext::parse_capacity_limit(MEMORY_LIMIT, limit).unwrap(); + assert_eq!(want, have); + } + + // Invalid capacity_limit + for limit in [ + "1B", + "1T", + "", + " ", + 
"XYZG", + "-1G", + "infG", + "-infG", + "G", + "1024B", + "invalid_size", + ] { + let have = SessionContext::parse_capacity_limit(MEMORY_LIMIT, limit); + assert!(have.is_err()); + assert!(have.unwrap_err().to_string().contains(MEMORY_LIMIT)); + } + } } diff --git a/datafusion/core/src/execution/context/parquet.rs b/datafusion/core/src/execution/context/parquet.rs index eea2b804770a3..823dc946ea732 100644 --- a/datafusion/core/src/execution/context/parquet.rs +++ b/datafusion/core/src/execution/context/parquet.rs @@ -34,13 +34,12 @@ impl SessionContext { /// /// # Note: Statistics /// - /// NOTE: by default, statistics are not collected when reading the Parquet - /// files as this can slow down the initial DataFrame creation. However, - /// collecting statistics can greatly accelerate queries with certain - /// filters. + /// NOTE: by default, statistics are collected when reading the Parquet + /// files This can slow down the initial DataFrame creation while + /// greatly accelerating queries with certain filters. /// - /// To enable collect statistics, set the [config option] - /// `datafusion.execution.collect_statistics` to `true`. See + /// To disable statistics collection, set the [config option] + /// `datafusion.execution.collect_statistics` to `false`. See /// [`ConfigOptions`] and [`ExecutionOptions::collect_statistics`] for more /// details. 
/// @@ -108,11 +107,13 @@ mod tests { use crate::test_util::parquet_test_data; use arrow::util::pretty::pretty_format_batches; - use datafusion_common::assert_contains; use datafusion_common::config::TableParquetOptions; + use datafusion_common::{ + assert_batches_eq, assert_batches_sorted_eq, assert_contains, + }; use datafusion_execution::config::SessionConfig; - use tempfile::tempdir; + use tempfile::{TempDir, tempdir}; #[tokio::test] async fn read_with_glob_path() -> Result<()> { @@ -171,28 +172,28 @@ mod tests { #[tokio::test] async fn register_parquet_respects_collect_statistics_config() -> Result<()> { - // The default is false + // The default is true let mut config = SessionConfig::new(); config.options_mut().explain.physical_plan_only = true; config.options_mut().explain.show_statistics = true; let content = explain_query_all_with_config(config).await?; - assert_contains!(content, "statistics=[Rows=Absent,"); + assert_contains!(content, "statistics=[Rows=Exact("); - // Explicitly set to false + // Explicitly set to true let mut config = SessionConfig::new(); config.options_mut().explain.physical_plan_only = true; config.options_mut().explain.show_statistics = true; - config.options_mut().execution.collect_statistics = false; + config.options_mut().execution.collect_statistics = true; let content = explain_query_all_with_config(config).await?; - assert_contains!(content, "statistics=[Rows=Absent,"); + assert_contains!(content, "statistics=[Rows=Exact("); - // Explicitly set to true + // Explicitly set to false let mut config = SessionConfig::new(); config.options_mut().explain.physical_plan_only = true; config.options_mut().explain.show_statistics = true; - config.options_mut().execution.collect_statistics = true; + config.options_mut().execution.collect_statistics = false; let content = explain_query_all_with_config(config).await?; - assert_contains!(content, "statistics=[Rows=Exact(10),"); + assert_contains!(content, "statistics=[Rows=Absent,"); Ok(()) } 
@@ -354,7 +355,9 @@ mod tests { let expected_path = binding[0].as_str(); assert_eq!( read_df.unwrap_err().strip_backtrace(), - format!("Execution error: File path '{expected_path}' does not match the expected extension '.parquet'") + format!( + "Execution error: File path '{expected_path}' does not match the expected extension '.parquet'" + ) ); // Read the dataframe from 'output3.parquet.snappy.parquet' with the correct file extension. @@ -401,4 +404,124 @@ mod tests { assert_eq!(total_rows, 5); Ok(()) } + + #[tokio::test] + async fn read_from_parquet_folder() -> Result<()> { + let ctx = SessionContext::new(); + let tmp_dir = TempDir::new()?; + let test_path = tmp_dir.path().to_str().unwrap().to_string(); + + ctx.sql("SELECT 1 a") + .await? + .write_parquet(&test_path, DataFrameWriteOptions::default(), None) + .await?; + + ctx.sql("SELECT 2 a") + .await? + .write_parquet(&test_path, DataFrameWriteOptions::default(), None) + .await?; + + // Adding CSV to check it is not read with Parquet reader + ctx.sql("SELECT 3 a") + .await? + .write_csv(&test_path, DataFrameWriteOptions::default(), None) + .await?; + + let actual = ctx + .read_parquet(&test_path, ParquetReadOptions::default()) + .await? + .collect() + .await?; + + #[cfg_attr(any(), rustfmt::skip)] + assert_batches_sorted_eq!(&[ + "+---+", + "| a |", + "+---+", + "| 2 |", + "| 1 |", + "+---+", + ], &actual); + + let actual = ctx + .read_parquet(test_path, ParquetReadOptions::default()) + .await? + .collect() + .await?; + + #[cfg_attr(any(), rustfmt::skip)] + assert_batches_sorted_eq!(&[ + "+---+", + "| a |", + "+---+", + "| 2 |", + "| 1 |", + "+---+", + ], &actual); + + Ok(()) + } + + #[tokio::test] + async fn read_from_parquet_folder_table() -> Result<()> { + let ctx = SessionContext::new(); + let tmp_dir = TempDir::new()?; + let test_path = tmp_dir.path().to_str().unwrap().to_string(); + + ctx.sql("SELECT 1 a") + .await? 
+ .write_parquet(&test_path, DataFrameWriteOptions::default(), None) + .await?; + + ctx.sql("SELECT 2 a") + .await? + .write_parquet(&test_path, DataFrameWriteOptions::default(), None) + .await?; + + // Adding CSV to check it is not read with Parquet reader + ctx.sql("SELECT 3 a") + .await? + .write_csv(&test_path, DataFrameWriteOptions::default(), None) + .await?; + + ctx.sql(format!("CREATE EXTERNAL TABLE parquet_folder_t1 STORED AS PARQUET LOCATION '{test_path}'").as_ref()) + .await?; + + let actual = ctx + .sql("select * from parquet_folder_t1") + .await? + .collect() + .await?; + #[cfg_attr(any(), rustfmt::skip)] + assert_batches_sorted_eq!(&[ + "+---+", + "| a |", + "+---+", + "| 2 |", + "| 1 |", + "+---+", + ], &actual); + + Ok(()) + } + + #[tokio::test] + async fn read_dummy_folder() -> Result<()> { + let ctx = SessionContext::new(); + let test_path = "/foo/"; + + let actual = ctx + .read_parquet(test_path, ParquetReadOptions::default()) + .await? + .collect() + .await?; + + #[cfg_attr(any(), rustfmt::skip)] + assert_batches_eq!(&[ + "++", + "++", + ], &actual); + + Ok(()) + } } diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index 8aa812cc5258a..9560616c1b6da 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -24,61 +24,67 @@ use std::fmt::Debug; use std::sync::Arc; use crate::catalog::{CatalogProviderList, SchemaProvider, TableProviderFactory}; -use crate::datasource::cte_worktable::CteWorkTable; -use crate::datasource::file_format::{format_as_file_type, FileFormatFactory}; +use crate::datasource::file_format::FileFormatFactory; +#[cfg(feature = "sql")] use crate::datasource::provider_as_source; -use crate::execution::context::{EmptySerializerRegistry, FunctionFactory, QueryPlanner}; use crate::execution::SessionStateDefaults; +use crate::execution::context::{EmptySerializerRegistry, FunctionFactory, QueryPlanner}; use 
crate::physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner}; +use arrow_schema::{DataType, FieldRef}; +use datafusion_catalog::MemoryCatalogProviderList; use datafusion_catalog::information_schema::{ - InformationSchemaProvider, INFORMATION_SCHEMA, + INFORMATION_SCHEMA, InformationSchemaProvider, }; - -use arrow::datatypes::{DataType, SchemaRef}; -use datafusion_catalog::MemoryCatalogProviderList; use datafusion_catalog::{TableFunction, TableFunctionImpl}; use datafusion_common::alias::AliasGenerator; +#[cfg(feature = "sql")] +use datafusion_common::config::Dialect; use datafusion_common::config::{ConfigExtension, ConfigOptions, TableOptions}; use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan}; -use datafusion_common::file_options::file_type::FileType; use datafusion_common::tree_node::TreeNode; use datafusion_common::{ - config_err, exec_err, not_impl_err, plan_datafusion_err, DFSchema, DataFusionError, - ResolvedTableReference, TableReference, + DFSchema, DataFusionError, ResolvedTableReference, TableReference, config_err, + exec_err, plan_datafusion_err, }; +use datafusion_execution::TaskContext; use datafusion_execution::config::SessionConfig; use datafusion_execution::runtime_env::RuntimeEnv; -use datafusion_execution::TaskContext; +#[cfg(feature = "sql")] +use datafusion_expr::TableSource; use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::expr_rewriter::FunctionRewrite; -use datafusion_expr::planner::{ExprPlanner, TypePlanner}; +use datafusion_expr::planner::ExprPlanner; +#[cfg(feature = "sql")] +use datafusion_expr::planner::{RelationPlanner, TypePlanner}; use datafusion_expr::registry::{FunctionRegistry, SerializerRegistry}; -use datafusion_expr::simplify::SimplifyInfo; -use datafusion_expr::var_provider::{is_system_variables, VarType}; -use datafusion_expr::{ - AggregateUDF, Explain, Expr, ExprSchemable, LogicalPlan, ScalarUDF, TableSource, - WindowUDF, -}; +use 
datafusion_expr::simplify::SimplifyContext; +use datafusion_expr::{AggregateUDF, Explain, Expr, LogicalPlan, ScalarUDF, WindowUDF}; use datafusion_optimizer::simplify_expressions::ExprSimplifier; use datafusion_optimizer::{ Analyzer, AnalyzerRule, Optimizer, OptimizerConfig, OptimizerRule, }; use datafusion_physical_expr::create_physical_expr; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; -use datafusion_physical_optimizer::optimizer::PhysicalOptimizer; use datafusion_physical_optimizer::PhysicalOptimizerRule; +use datafusion_physical_optimizer::optimizer::PhysicalOptimizer; use datafusion_physical_plan::ExecutionPlan; use datafusion_session::Session; -use datafusion_sql::parser::{DFParserBuilder, Statement}; -use datafusion_sql::planner::{ContextProvider, ParserOptions, PlannerContext, SqlToRel}; +#[cfg(feature = "sql")] +use datafusion_sql::{ + parser::{DFParserBuilder, Statement}, + planner::{ContextProvider, ParserOptions, PlannerContext, SqlToRel}, +}; use async_trait::async_trait; use chrono::{DateTime, Utc}; use itertools::Itertools; use log::{debug, info}; use object_store::ObjectStore; -use sqlparser::ast::{Expr as SQLExpr, ExprWithAlias as SQLExprWithAlias}; -use sqlparser::dialect::dialect_from_str; +#[cfg(feature = "sql")] +use sqlparser::{ + ast::{Expr as SQLExpr, ExprWithAlias as SQLExprWithAlias}, + dialect::dialect_from_str, +}; use url::Url; use uuid::Uuid; @@ -108,12 +114,12 @@ use uuid::Uuid; /// # use std::sync::Arc; /// # #[tokio::main] /// # async fn main() -> Result<()> { -/// let state = SessionStateBuilder::new() -/// .with_config(SessionConfig::new()) -/// .with_runtime_env(Arc::new(RuntimeEnv::default())) -/// .with_default_features() -/// .build(); -/// Ok(()) +/// let state = SessionStateBuilder::new() +/// .with_config(SessionConfig::new()) +/// .with_runtime_env(Arc::new(RuntimeEnv::default())) +/// .with_default_features() +/// .build(); +/// Ok(()) /// # } /// ``` /// @@ -131,7 +137,10 @@ pub struct SessionState 
{ analyzer: Analyzer, /// Provides support for customizing the SQL planner, e.g. to add support for custom operators like `->>` or `?` expr_planners: Vec>, + #[cfg(feature = "sql")] + relation_planners: Vec>, /// Provides support for customizing the SQL type planning + #[cfg(feature = "sql")] type_planner: Option>, /// Responsible for optimizing a logical plan optimizer: Optimizer, @@ -176,6 +185,7 @@ pub struct SessionState { /// It will be invoked on `CREATE FUNCTION` statements. /// thus, changing dialect o PostgreSql is required function_factory: Option>, + cache_factory: Option>, /// Cache logical plans of prepared statements for later execution. /// Key is the prepared statement name. prepared_plans: HashMap>, @@ -185,7 +195,8 @@ impl Debug for SessionState { /// Prefer having short fields at the top and long vector fields near the end /// Group fields by fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("SessionState") + let mut debug_struct = f.debug_struct("SessionState"); + let ret = debug_struct .field("session_id", &self.session_id) .field("config", &self.config) .field("runtime_env", &self.runtime_env) @@ -196,9 +207,16 @@ impl Debug for SessionState { .field("table_options", &self.table_options) .field("table_factories", &self.table_factories) .field("function_factory", &self.function_factory) - .field("expr_planners", &self.expr_planners) - .field("type_planner", &self.type_planner) - .field("query_planners", &self.query_planner) + .field("cache_factory", &self.cache_factory) + .field("expr_planners", &self.expr_planners); + + #[cfg(feature = "sql")] + let ret = ret.field("relation_planners", &self.relation_planners); + + #[cfg(feature = "sql")] + let ret = ret.field("type_planner", &self.type_planner); + + ret.field("query_planners", &self.query_planner) .field("analyzer", &self.analyzer) .field("optimizer", &self.optimizer) .field("physical_optimizers", &self.physical_optimizers) @@ -274,17 +292,6 @@ impl Session 
for SessionState { } impl SessionState { - /// Returns new [`SessionState`] using the provided - /// [`SessionConfig`] and [`RuntimeEnv`]. - #[deprecated(since = "41.0.0", note = "Use SessionStateBuilder")] - pub fn new_with_config_rt(config: SessionConfig, runtime: Arc) -> Self { - SessionStateBuilder::new() - .with_config(config) - .with_runtime_env(runtime) - .with_default_features() - .build() - } - pub(crate) fn resolve_table_ref( &self, table_ref: impl Into, @@ -343,6 +350,13 @@ impl SessionState { self.optimizer.rules.push(optimizer_rule); } + /// Removes an optimizer rule by name, returning `true` if it existed. + pub(crate) fn remove_optimizer_rule(&mut self, name: &str) -> bool { + let original_len = self.optimizer.rules.len(); + self.optimizer.rules.retain(|r| r.name() != name); + self.optimizer.rules.len() < original_len + } + /// Registers a [`FunctionFactory`] to handle `CREATE FUNCTION` statements pub fn set_function_factory(&mut self, function_factory: Arc) { self.function_factory = Some(function_factory); @@ -353,6 +367,16 @@ impl SessionState { self.function_factory.as_ref() } + /// Register a [`CacheFactory`] for custom caching strategy + pub fn set_cache_factory(&mut self, cache_factory: Arc) { + self.cache_factory = Some(cache_factory); + } + + /// Get the cache factory + pub fn cache_factory(&self) -> Option<&Arc> { + self.cache_factory.as_ref() + } + /// Get the table factories pub fn table_factories(&self) -> &HashMap> { &self.table_factories @@ -369,10 +393,11 @@ impl SessionState { /// [`Statement`]. See [`SessionContext::sql`] for running queries. 
/// /// [`SessionContext::sql`]: crate::execution::context::SessionContext::sql + #[cfg(feature = "sql")] pub fn sql_to_statement( &self, sql: &str, - dialect: &str, + dialect: &Dialect, ) -> datafusion_common::Result { let dialect = dialect_from_str(dialect).ok_or_else(|| { plan_datafusion_err!( @@ -391,7 +416,7 @@ impl SessionState { .parse_statements()?; if statements.len() > 1 { - return not_impl_err!( + return datafusion_common::not_impl_err!( "The context currently only supports a single SQL statement" ); } @@ -405,10 +430,11 @@ impl SessionState { /// parse a sql string into a sqlparser-rs AST [`SQLExpr`]. /// /// See [`Self::create_logical_expr`] for parsing sql to [`Expr`]. + #[cfg(feature = "sql")] pub fn sql_to_expr( &self, sql: &str, - dialect: &str, + dialect: &Dialect, ) -> datafusion_common::Result { self.sql_to_expr_with_alias(sql, dialect).map(|x| x.expr) } @@ -416,10 +442,11 @@ impl SessionState { /// parse a sql string into a sqlparser-rs AST [`SQLExprWithAlias`]. /// /// See [`Self::create_logical_expr`] for parsing sql to [`Expr`]. + #[cfg(feature = "sql")] pub fn sql_to_expr_with_alias( &self, sql: &str, - dialect: &str, + dialect: &Dialect, ) -> datafusion_common::Result { let dialect = dialect_from_str(dialect).ok_or_else(|| { plan_datafusion_err!( @@ -434,7 +461,7 @@ impl SessionState { .with_dialect(dialect.as_ref()) .with_recursion_limit(recursion_limit) .build()? - .parse_expr()?; + .parse_into_expr()?; Ok(expr) } @@ -444,6 +471,7 @@ impl SessionState { /// See [`datafusion_sql::resolve::resolve_table_references`] for more information. 
/// /// [`datafusion_sql::resolve::resolve_table_references`]: datafusion_sql::resolve::resolve_table_references + #[cfg(feature = "sql")] pub fn resolve_table_references( &self, statement: &Statement, @@ -458,6 +486,7 @@ impl SessionState { } /// Convert an AST Statement into a LogicalPlan + #[cfg(feature = "sql")] pub async fn statement_to_plan( &self, statement: Statement, @@ -473,10 +502,10 @@ impl SessionState { let resolved = self.resolve_table_ref(reference); if let Entry::Vacant(v) = provider.tables.entry(resolved) { let resolved = v.key(); - if let Ok(schema) = self.schema_for_ref(resolved.clone()) { - if let Some(table) = schema.table(&resolved.table).await? { - v.insert(provider_as_source(table)); - } + if let Ok(schema) = self.schema_for_ref(resolved.clone()) + && let Some(table) = schema.table(&resolved.table).await? + { + v.insert(provider_as_source(table)); } } } @@ -485,6 +514,7 @@ impl SessionState { query.statement_to_plan(statement) } + #[cfg(feature = "sql")] fn get_parser_options(&self) -> ParserOptions { let sql_parser_options = &self.config.options().sql_parser; @@ -494,8 +524,12 @@ impl SessionState { enable_options_value_normalization: sql_parser_options .enable_options_value_normalization, support_varchar_with_length: sql_parser_options.support_varchar_with_length, - map_varchar_to_utf8view: sql_parser_options.map_varchar_to_utf8view, + map_string_types_to_utf8view: sql_parser_options.map_string_types_to_utf8view, collect_spans: sql_parser_options.collect_spans, + default_null_ordering: sql_parser_options + .default_null_ordering + .as_str() + .into(), } } @@ -511,12 +545,13 @@ impl SessionState { /// /// [`SessionContext::sql`]: crate::execution::context::SessionContext::sql /// [`SessionContext::sql_with_options`]: crate::execution::context::SessionContext::sql_with_options + #[cfg(feature = "sql")] pub async fn create_logical_plan( &self, sql: &str, ) -> datafusion_common::Result { - let dialect = 
self.config.options().sql_parser.dialect.as_str(); - let statement = self.sql_to_statement(sql, dialect)?; + let dialect = self.config.options().sql_parser.dialect; + let statement = self.sql_to_statement(sql, &dialect)?; let plan = self.statement_to_plan(statement).await?; Ok(plan) } @@ -524,15 +559,26 @@ impl SessionState { /// Creates a datafusion style AST [`Expr`] from a SQL string. /// /// See example on [SessionContext::parse_sql_expr](crate::execution::context::SessionContext::parse_sql_expr) + #[cfg(feature = "sql")] pub fn create_logical_expr( &self, sql: &str, df_schema: &DFSchema, ) -> datafusion_common::Result { - let dialect = self.config.options().sql_parser.dialect.as_str(); + let dialect = self.config.options().sql_parser.dialect; - let sql_expr = self.sql_to_expr_with_alias(sql, dialect)?; + let sql_expr = self.sql_to_expr_with_alias(sql, &dialect)?; + self.create_logical_expr_from_sql_expr(sql_expr, df_schema) + } + + /// Creates a datafusion style AST [`Expr`] from a SQL expression. + #[cfg(feature = "sql")] + pub fn create_logical_expr_from_sql_expr( + &self, + sql_expr: SQLExprWithAlias, + df_schema: &DFSchema, + ) -> datafusion_common::Result { let provider = SessionContextProvider { state: self, tables: HashMap::new(), @@ -557,6 +603,24 @@ impl SessionState { &self.expr_planners } + #[cfg(feature = "sql")] + /// Returns the registered relation planners in priority order. + pub fn relation_planners(&self) -> &[Arc] { + &self.relation_planners + } + + #[cfg(feature = "sql")] + /// Registers a [`RelationPlanner`] to customize SQL relation planning. + /// + /// Newly registered planners are given higher priority than existing ones. 
+ pub fn register_relation_planner( + &mut self, + planner: Arc, + ) -> datafusion_common::Result<()> { + self.relation_planners.insert(0, planner); + Ok(()) + } + /// Returns the [`QueryPlanner`] for this session pub fn query_planner(&self) -> &Arc { &self.query_planner @@ -570,7 +634,7 @@ impl SessionState { // analyze & capture output of each rule let analyzer_result = self.analyzer.execute_and_check( e.plan.as_ref().clone(), - self.options(), + &self.options(), |analyzed_plan, analyzer| { let analyzer_name = analyzer.name().to_string(); let plan_type = PlanType::AnalyzedLogicalPlan { analyzer_name }; @@ -632,7 +696,7 @@ impl SessionState { } else { let analyzed_plan = self.analyzer.execute_and_check( plan.clone(), - self.options(), + &self.options(), |_, _| {}, )?; self.optimizer.optimize(analyzed_plan, self, |_, _| {}) @@ -671,20 +735,25 @@ impl SessionState { /// * [`create_physical_expr`] for a lower-level API /// /// [simplified]: datafusion_optimizer::simplify_expressions - /// [expr_api]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/expr_api.rs + /// [expr_api]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/query_planning/expr_api.rs /// [`SessionContext::create_physical_expr`]: crate::execution::context::SessionContext::create_physical_expr pub fn create_physical_expr( &self, expr: Expr, df_schema: &DFSchema, ) -> datafusion_common::Result> { - let simplifier = - ExprSimplifier::new(SessionSimplifyProvider::new(self, df_schema)); + let config_options = self.config_options(); + let simplify_context = SimplifyContext::default() + .with_schema(Arc::new(df_schema.clone())) + .with_config_options(Arc::clone(config_options)) + .with_query_execution_start_time( + self.execution_props().query_execution_start_time, + ); + let simplifier = ExprSimplifier::new(simplify_context); // apply type coercion here to ensure types match let mut expr = simplifier.coerce(expr, df_schema)?; // rewrite Exprs to 
functions if necessary - let config_options = self.config_options(); for rewrite in self.analyzer.function_rewrites() { expr = expr .transform_up(|expr| rewrite.rewrite(expr, df_schema, config_options))? @@ -734,10 +803,16 @@ impl SessionState { } /// return the configuration options - pub fn config_options(&self) -> &ConfigOptions { + pub fn config_options(&self) -> &Arc { self.config.options() } + /// Mark the start of the execution + pub fn mark_start_execution(&mut self) { + let config = Arc::clone(self.config.options()); + self.execution_props.mark_start_execution(config); + } + /// Return the table options pub fn table_options(&self) -> &TableOptions { &self.table_options @@ -768,10 +843,18 @@ impl SessionState { overwrite: bool, ) -> Result<(), DataFusionError> { let ext = file_format.get_ext().to_lowercase(); - match (self.file_formats.entry(ext.clone()), overwrite){ - (Entry::Vacant(e), _) => {e.insert(file_format);}, - (Entry::Occupied(mut e), true) => {e.insert(file_format);}, - (Entry::Occupied(_), false) => return config_err!("File type already registered for extension {ext}. Set overwrite to true to replace this extension."), + match (self.file_formats.entry(ext.clone()), overwrite) { + (Entry::Vacant(e), _) => { + e.insert(file_format); + } + (Entry::Occupied(mut e), true) => { + e.insert(file_format); + } + (Entry::Occupied(_), false) => { + return config_err!( + "File type already registered for extension {ext}. Set overwrite to true to replace this extension." 
+ ); + } }; Ok(()) } @@ -795,11 +878,8 @@ impl SessionState { &self.catalog_list } - /// set the catalog list - pub(crate) fn register_catalog_list( - &mut self, - catalog_list: Arc, - ) { + /// Set the catalog list + pub fn register_catalog_list(&mut self, catalog_list: Arc) { self.catalog_list = catalog_list; } @@ -854,12 +934,12 @@ impl SessionState { pub(crate) fn store_prepared( &mut self, name: String, - data_types: Vec, + fields: Vec, plan: Arc, ) -> datafusion_common::Result<()> { match self.prepared_plans.entry(name) { Entry::Vacant(e) => { - e.insert(Arc::new(PreparedPlan { data_types, plan })); + e.insert(Arc::new(PreparedPlan { fields, plan })); Ok(()) } Entry::Occupied(e) => { @@ -889,10 +969,14 @@ impl SessionState { /// be used for all values unless explicitly provided. /// /// See example on [`SessionState`] +#[derive(Clone)] pub struct SessionStateBuilder { session_id: Option, analyzer: Option, expr_planners: Option>>, + #[cfg(feature = "sql")] + relation_planners: Option>>, + #[cfg(feature = "sql")] type_planner: Option>, optimizer: Option, physical_optimizers: Option, @@ -910,6 +994,7 @@ pub struct SessionStateBuilder { table_factories: Option>>, runtime_env: Option>, function_factory: Option>, + cache_factory: Option>, // fields to support convenience functions analyzer_rules: Option>>, optimizer_rules: Option>>, @@ -929,6 +1014,9 @@ impl SessionStateBuilder { session_id: None, analyzer: None, expr_planners: None, + #[cfg(feature = "sql")] + relation_planners: None, + #[cfg(feature = "sql")] type_planner: None, optimizer: None, physical_optimizers: None, @@ -946,6 +1034,7 @@ impl SessionStateBuilder { table_factories: None, runtime_env: None, function_factory: None, + cache_factory: None, // fields to support convenience functions analyzer_rules: None, optimizer_rules: None, @@ -978,6 +1067,9 @@ impl SessionStateBuilder { session_id: None, analyzer: Some(existing.analyzer), expr_planners: Some(existing.expr_planners), + #[cfg(feature = "sql")] + 
relation_planners: Some(existing.relation_planners), + #[cfg(feature = "sql")] type_planner: existing.type_planner, optimizer: Some(existing.optimizer), physical_optimizers: Some(existing.physical_optimizers), @@ -997,7 +1089,7 @@ impl SessionStateBuilder { table_factories: Some(existing.table_factories), runtime_env: Some(existing.runtime_env), function_factory: existing.function_factory, - + cache_factory: existing.cache_factory, // fields to support convenience functions analyzer_rules: None, optimizer_rules: None, @@ -1118,7 +1210,18 @@ impl SessionStateBuilder { self } + #[cfg(feature = "sql")] + /// Sets the [`RelationPlanner`]s used to customize SQL relation planning. + pub fn with_relation_planners( + mut self, + relation_planners: Vec>, + ) -> Self { + self.relation_planners = Some(relation_planners); + self + } + /// Set the [`TypePlanner`] used to customize the behavior of the SQL planner. + #[cfg(feature = "sql")] pub fn with_type_planner(mut self, type_planner: Arc) -> Self { self.type_planner = Some(type_planner); self @@ -1285,6 +1388,15 @@ impl SessionStateBuilder { self } + /// Set a [`CacheFactory`] for custom caching strategy + pub fn with_cache_factory( + mut self, + cache_factory: Option>, + ) -> Self { + self.cache_factory = cache_factory; + self + } + /// Register an `ObjectStore` to the [`RuntimeEnv`]. See [`RuntimeEnv::register_object_store`] /// for more details. 
/// @@ -1300,7 +1412,7 @@ impl SessionStateBuilder { /// let url = Url::try_from("file://").unwrap(); /// let object_store = object_store::local::LocalFileSystem::new(); /// let state = SessionStateBuilder::new() - /// .with_config(SessionConfig::new()) + /// .with_config(SessionConfig::new()) /// .with_object_store(&url, Arc::new(object_store)) /// .with_default_features() /// .build(); @@ -1330,6 +1442,9 @@ impl SessionStateBuilder { session_id, analyzer, expr_planners, + #[cfg(feature = "sql")] + relation_planners, + #[cfg(feature = "sql")] type_planner, optimizer, physical_optimizers, @@ -1347,6 +1462,7 @@ impl SessionStateBuilder { table_factories, runtime_env, function_factory, + cache_factory, analyzer_rules, optimizer_rules, physical_optimizer_rules, @@ -1359,6 +1475,9 @@ impl SessionStateBuilder { session_id: session_id.unwrap_or_else(|| Uuid::new_v4().to_string()), analyzer: analyzer.unwrap_or_default(), expr_planners: expr_planners.unwrap_or_default(), + #[cfg(feature = "sql")] + relation_planners: relation_planners.unwrap_or_default(), + #[cfg(feature = "sql")] type_planner, optimizer: optimizer.unwrap_or_default(), physical_optimizers: physical_optimizers.unwrap_or_default(), @@ -1382,6 +1501,7 @@ impl SessionStateBuilder { table_factories: table_factories.unwrap_or_default(), runtime_env, function_factory, + cache_factory, prepared_plans: HashMap::new(), }; @@ -1394,12 +1514,31 @@ impl SessionStateBuilder { } if let Some(scalar_functions) = scalar_functions { - scalar_functions.into_iter().for_each(|udf| { - let existing_udf = state.register_udf(udf); - if let Ok(Some(existing_udf)) = existing_udf { - debug!("Overwrote an existing UDF: {}", existing_udf.name()); + for udf in scalar_functions { + let config_options = state.config().options(); + match udf.inner().with_updated_config(config_options) { + Some(new_udf) => { + if let Err(err) = state.register_udf(Arc::new(new_udf)) { + debug!( + "Failed to re-register updated UDF '{}': {}", + udf.name(), + 
err + ); + } + } + None => match state.register_udf(Arc::clone(&udf)) { + Ok(Some(existing)) => { + debug!("Overwrote existing UDF '{}'", existing.name()); + } + Ok(None) => { + debug!("Registered UDF '{}'", udf.name()); + } + Err(err) => { + debug!("Failed to register UDF '{}': {}", udf.name(), err); + } + }, } - }); + } } if let Some(aggregate_functions) = aggregate_functions { @@ -1476,7 +1615,14 @@ impl SessionStateBuilder { &mut self.expr_planners } + #[cfg(feature = "sql")] + /// Returns a mutable reference to the current [`RelationPlanner`] list. + pub fn relation_planners(&mut self) -> &mut Option>> { + &mut self.relation_planners + } + /// Returns the current type_planner value + #[cfg(feature = "sql")] pub fn type_planner(&mut self) -> &mut Option> { &mut self.type_planner } @@ -1565,6 +1711,11 @@ impl SessionStateBuilder { &mut self.function_factory } + /// Returns the cache factory + pub fn cache_factory(&mut self) -> &mut Option> { + &mut self.cache_factory + } + /// Returns the current analyzer_rules value pub fn analyzer_rules( &mut self, @@ -1591,7 +1742,8 @@ impl Debug for SessionStateBuilder { /// Prefer having short fields at the top and long vector fields near the end /// Group fields by fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("SessionStateBuilder") + let mut debug_struct = f.debug_struct("SessionStateBuilder"); + let ret = debug_struct .field("session_id", &self.session_id) .field("config", &self.config) .field("runtime_env", &self.runtime_env) @@ -1602,9 +1754,11 @@ impl Debug for SessionStateBuilder { .field("table_options", &self.table_options) .field("table_factories", &self.table_factories) .field("function_factory", &self.function_factory) - .field("expr_planners", &self.expr_planners) - .field("type_planner", &self.type_planner) - .field("query_planners", &self.query_planner) + .field("cache_factory", &self.cache_factory) + .field("expr_planners", &self.expr_planners); + #[cfg(feature = 
"sql")] + let ret = ret.field("type_planner", &self.type_planner); + ret.field("query_planners", &self.query_planner) .field("analyzer_rules", &self.analyzer_rules) .field("analyzer", &self.analyzer) .field("optimizer_rules", &self.optimizer_rules) @@ -1635,16 +1789,22 @@ impl From for SessionStateBuilder { /// /// This is used so the SQL planner can access the state of the session without /// having a direct dependency on the [`SessionState`] struct (and core crate) +#[cfg(feature = "sql")] struct SessionContextProvider<'a> { state: &'a SessionState, tables: HashMap>, } +#[cfg(feature = "sql")] impl ContextProvider for SessionContextProvider<'_> { fn get_expr_planners(&self) -> &[Arc] { self.state.expr_planners() } + fn get_relation_planners(&self) -> &[Arc] { + self.state.relation_planners() + } + fn get_type_planner(&self) -> Option> { if let Some(type_planner) = &self.state.type_planner { Some(Arc::clone(type_planner)) @@ -1675,6 +1835,21 @@ impl ContextProvider for SessionContextProvider<'_> { .get(name) .cloned() .ok_or_else(|| plan_datafusion_err!("table function '{name}' not found"))?; + let simplify_context = SimplifyContext::default() + .with_config_options(Arc::clone(self.state.config_options())) + .with_query_execution_start_time( + self.state.execution_props().query_execution_start_time, + ); + let simplifier = ExprSimplifier::new(simplify_context); + let schema = DFSchema::empty(); + let args = args + .into_iter() + .map(|arg| { + simplifier + .coerce(arg, &schema) + .and_then(|e| simplifier.simplify(e)) + }) + .collect::>>()?; let provider = tbl_func.create_table_provider(&args)?; Ok(provider_as_source(provider)) @@ -1686,9 +1861,11 @@ impl ContextProvider for SessionContextProvider<'_> { fn create_cte_work_table( &self, name: &str, - schema: SchemaRef, + schema: arrow::datatypes::SchemaRef, ) -> datafusion_common::Result> { - let table = Arc::new(CteWorkTable::new(name, schema)); + let table = 
Arc::new(crate::datasource::cte_worktable::CteWorkTable::new( + name, schema, + )); Ok(provider_as_source(table)) } @@ -1705,6 +1882,8 @@ impl ContextProvider for SessionContextProvider<'_> { } fn get_variable_type(&self, variable_names: &[String]) -> Option { + use datafusion_expr::var_provider::{VarType, is_system_variables}; + if variable_names.is_empty() { return None; } @@ -1738,14 +1917,21 @@ impl ContextProvider for SessionContextProvider<'_> { self.state.window_functions().keys().cloned().collect() } - fn get_file_type(&self, ext: &str) -> datafusion_common::Result> { + fn get_file_type( + &self, + ext: &str, + ) -> datafusion_common::Result< + Arc, + > { self.state .file_formats .get(&ext.to_lowercase()) .ok_or(plan_datafusion_err!( "There is no registered file format with ext {ext}" )) - .map(|file_type| format_as_file_type(Arc::clone(file_type))) + .map(|file_type| { + crate::datasource::file_format::format_as_file_type(Arc::clone(file_type)) + }) } } @@ -1869,10 +2055,24 @@ impl FunctionRegistry for SessionState { self.expr_planners.push(expr_planner); Ok(()) } + + fn udafs(&self) -> HashSet { + self.aggregate_functions.keys().cloned().collect() + } + + fn udwfs(&self) -> HashSet { + self.window_functions.keys().cloned().collect() + } +} + +impl datafusion_execution::TaskContextProvider for SessionState { + fn task_ctx(&self) -> Arc { + SessionState::task_ctx(self) + } } impl OptimizerConfig for SessionState { - fn query_execution_start_time(&self) -> DateTime { + fn query_execution_start_time(&self) -> Option> { self.execution_props.query_execution_start_time } @@ -1880,8 +2080,8 @@ impl OptimizerConfig for SessionState { &self.execution_props.alias_generator } - fn options(&self) -> &ConfigOptions { - self.config_options() + fn options(&self) -> Arc { + Arc::clone(self.config.options()) } fn function_registry(&self) -> Option<&dyn FunctionRegistry> { @@ -1924,51 +2124,35 @@ impl QueryPlanner for DefaultQueryPlanner { } } -struct 
SessionSimplifyProvider<'a> { - state: &'a SessionState, - df_schema: &'a DFSchema, -} - -impl<'a> SessionSimplifyProvider<'a> { - fn new(state: &'a SessionState, df_schema: &'a DFSchema) -> Self { - Self { state, df_schema } - } -} - -impl SimplifyInfo for SessionSimplifyProvider<'_> { - fn is_boolean_type(&self, expr: &Expr) -> datafusion_common::Result { - Ok(expr.get_type(self.df_schema)? == DataType::Boolean) - } - - fn nullable(&self, expr: &Expr) -> datafusion_common::Result { - expr.nullable(self.df_schema) - } - - fn execution_props(&self) -> &ExecutionProps { - self.state.execution_props() - } - - fn get_data_type(&self, expr: &Expr) -> datafusion_common::Result { - expr.get_type(self.df_schema) - } -} - #[derive(Debug)] pub(crate) struct PreparedPlan { /// Data types of the parameters - pub(crate) data_types: Vec, + pub(crate) fields: Vec, /// The prepared logical plan pub(crate) plan: Arc, } +/// A [`CacheFactory`] can be registered via [`SessionState`] +/// to create a custom logical plan for [`crate::dataframe::DataFrame::cache`]. +/// Additionally, a custom [`crate::physical_planner::ExtensionPlanner`]/[`QueryPlanner`] +/// may need to be implemented to handle such plans. 
+pub trait CacheFactory: Debug + Send + Sync { + /// Create a logical plan for caching + fn create( + &self, + plan: LogicalPlan, + session_state: &SessionState, + ) -> datafusion_common::Result; +} + #[cfg(test)] mod tests { use super::{SessionContextProvider, SessionStateBuilder}; use crate::common::assert_contains; use crate::config::ConfigOptions; + use crate::datasource::MemTable; use crate::datasource::empty::EmptyTable; use crate::datasource::provider_as_source; - use crate::datasource::MemTable; use crate::execution::context::SessionState; use crate::logical_expr::planner::ExprPlanner; use crate::logical_expr::{AggregateUDF, ScalarUDF, TableSource, WindowUDF}; @@ -1980,18 +2164,21 @@ mod tests { use datafusion_catalog::MemoryCatalogProviderList; use datafusion_common::DFSchema; use datafusion_common::Result; + use datafusion_common::config::Dialect; use datafusion_execution::config::SessionConfig; use datafusion_expr::Expr; - use datafusion_optimizer::optimizer::OptimizerRule; use datafusion_optimizer::Optimizer; + use datafusion_optimizer::optimizer::OptimizerRule; use datafusion_physical_plan::display::DisplayableExecutionPlan; use datafusion_sql::planner::{PlannerContext, SqlToRel}; use std::collections::HashMap; use std::sync::Arc; #[test] + #[cfg(feature = "sql")] fn test_session_state_with_default_features() { // test array planners with and without builtin planners + #[cfg(feature = "sql")] fn sql_to_expr(state: &SessionState) -> Result { let provider = SessionContextProvider { state, @@ -2001,8 +2188,8 @@ mod tests { let sql = "[1,2,3]"; let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); let df_schema = DFSchema::try_from(schema)?; - let dialect = state.config.options().sql_parser.dialect.as_str(); - let sql_expr = state.sql_to_expr(sql, dialect)?; + let dialect = state.config.options().sql_parser.dialect; + let sql_expr = state.sql_to_expr(sql, &dialect)?; let query = SqlToRel::new_with_options(&provider, 
state.get_parser_options()); query.sql_to_expr(sql_expr, &df_schema, &mut PlannerContext::new()) @@ -2018,6 +2205,36 @@ mod tests { assert!(sql_to_expr(&state).is_err()) } + #[test] + #[cfg(feature = "sql")] + fn test_create_logical_expr_from_sql_expr() { + let state = SessionStateBuilder::new().with_default_features().build(); + + let provider = SessionContextProvider { + state: &state, + tables: HashMap::new(), + }; + + let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + let df_schema = DFSchema::try_from(schema).unwrap(); + let dialect = state.config.options().sql_parser.dialect; + let query = SqlToRel::new_with_options(&provider, state.get_parser_options()); + + for sql in ["[1,2,3]", "a > 10", "SUM(a)"] { + let sql_expr = state.sql_to_expr(sql, &dialect).unwrap(); + let from_str = query + .sql_to_expr(sql_expr, &df_schema, &mut PlannerContext::new()) + .unwrap(); + + let sql_expr_with_alias = + state.sql_to_expr_with_alias(sql, &dialect).unwrap(); + let from_expr = state + .create_logical_expr_from_sql_expr(sql_expr_with_alias, &df_schema) + .unwrap(); + assert_eq!(from_str, from_expr); + } + } + #[test] fn test_from_existing() -> Result<()> { fn employee_batch() -> RecordBatch { @@ -2058,13 +2275,15 @@ mod tests { .table_exist("employee"); assert!(is_exist); let new_state = SessionStateBuilder::new_from_existing(session_state).build(); - assert!(new_state - .catalog_list() - .catalog(default_catalog.as_str()) - .unwrap() - .schema(default_schema.as_str()) - .unwrap() - .table_exist("employee")); + assert!( + new_state + .catalog_list() + .catalog(default_catalog.as_str()) + .unwrap() + .schema(default_schema.as_str()) + .unwrap() + .table_exist("employee") + ); // if `with_create_default_catalog_and_schema` is disabled, the new one shouldn't create default catalog and schema let disable_create_default = @@ -2072,10 +2291,12 @@ mod tests { let without_default_state = SessionStateBuilder::new() .with_config(disable_create_default) 
.build(); - assert!(without_default_state - .catalog_list() - .catalog(&default_catalog) - .is_none()); + assert!( + without_default_state + .catalog_list() + .catalog(&default_catalog) + .is_none() + ); let new_state = SessionStateBuilder::new_from_existing(without_default_state).build(); assert!(new_state.catalog_list().catalog(&default_catalog).is_none()); @@ -2160,7 +2381,8 @@ mod tests { } let state = &context_provider.state; - let statement = state.sql_to_statement("select count(*) from t", "mysql")?; + let statement = + state.sql_to_statement("select count(*) from t", &Dialect::MySQL)?; let plan = SqlToRel::new(&context_provider).statement_to_plan(statement)?; state.create_physical_plan(&plan).await } diff --git a/datafusion/core/src/execution/session_state_defaults.rs b/datafusion/core/src/execution/session_state_defaults.rs index a241738bd3a42..721710d4e057e 100644 --- a/datafusion/core/src/execution/session_state_defaults.rs +++ b/datafusion/core/src/execution/session_state_defaults.rs @@ -17,6 +17,7 @@ use crate::catalog::listing_schema::ListingSchemaProvider; use crate::catalog::{CatalogProvider, TableProviderFactory}; +use crate::datasource::file_format::FileFormatFactory; use crate::datasource::file_format::arrow::ArrowFormatFactory; #[cfg(feature = "avro")] use crate::datasource::file_format::avro::AvroFormatFactory; @@ -24,7 +25,6 @@ use crate::datasource::file_format::csv::CsvFormatFactory; use crate::datasource::file_format::json::JsonFormatFactory; #[cfg(feature = "parquet")] use crate::datasource::file_format::parquet::ParquetFormatFactory; -use crate::datasource::file_format::FileFormatFactory; use crate::datasource::provider::DefaultTableFactory; use crate::execution::context::SessionState; #[cfg(feature = "nested_expressions")] @@ -90,11 +90,10 @@ impl SessionStateDefaults { Arc::new(functions_nested::planner::NestedFunctionPlanner), #[cfg(feature = "nested_expressions")] Arc::new(functions_nested::planner::FieldAccessPlanner), - #[cfg(any( - 
feature = "datetime_expressions", - feature = "unicode_expressions" - ))] - Arc::new(functions::planner::UserDefinedFunctionPlanner), + #[cfg(feature = "datetime_expressions")] + Arc::new(functions::datetime::planner::DatetimeFunctionPlanner), + #[cfg(feature = "unicode_expressions")] + Arc::new(functions::unicode::planner::UnicodeFunctionPlanner), Arc::new(functions_aggregate::planner::AggregateFunctionPlanner), Arc::new(functions_window::planner::WindowFunctionPlanner), ]; @@ -102,9 +101,9 @@ impl SessionStateDefaults { expr_planners } - /// returns the list of default [`ScalarUDF']'s + /// returns the list of default [`ScalarUDF`]s pub fn default_scalar_functions() -> Vec> { - #[cfg_attr(not(feature = "nested_expressions"), allow(unused_mut))] + #[cfg_attr(not(feature = "nested_expressions"), expect(unused_mut))] let mut functions: Vec> = functions::all_default_functions(); #[cfg(feature = "nested_expressions")] @@ -113,12 +112,12 @@ impl SessionStateDefaults { functions } - /// returns the list of default [`AggregateUDF']'s + /// returns the list of default [`AggregateUDF`]s pub fn default_aggregate_functions() -> Vec> { functions_aggregate::all_default_aggregate_functions() } - /// returns the list of default [`WindowUDF']'s + /// returns the list of default [`WindowUDF`]s pub fn default_window_functions() -> Vec> { functions_window::all_default_window_functions() } @@ -128,7 +127,7 @@ impl SessionStateDefaults { functions_table::all_default_table_functions() } - /// returns the list of default [`FileFormatFactory']'s + /// returns the list of default [`FileFormatFactory`]s pub fn default_file_formats() -> Vec> { let file_formats: Vec> = vec![ #[cfg(feature = "parquet")] @@ -156,7 +155,7 @@ impl SessionStateDefaults { } /// registers all the builtin array functions - #[cfg_attr(not(feature = "nested_expressions"), allow(unused_variables))] + #[cfg_attr(not(feature = "nested_expressions"), expect(unused_variables))] pub fn register_array_functions(state: &mut 
SessionState) { // register crate of array expressions (if enabled) #[cfg(feature = "nested_expressions")] diff --git a/datafusion/core/src/lib.rs b/datafusion/core/src/lib.rs index 6956108e2df3f..349eee5592abe 100644 --- a/datafusion/core/src/lib.rs +++ b/datafusion/core/src/lib.rs @@ -19,7 +19,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 // @@ -35,11 +35,15 @@ ) )] #![warn(missing_docs, clippy::needless_borrow)] +// Use `allow` instead of `expect` for test configuration to explicitly +// disable the lint for all test code rather than expecting violations +#![cfg_attr(test, allow(clippy::needless_pass_by_value))] //! [DataFusion] is an extensible query engine written in Rust that //! uses [Apache Arrow] as its in-memory format. DataFusion's target users are //! developers building fast and feature rich database and analytic systems, -//! customized to particular workloads. See [use cases] for examples. +//! customized to particular workloads. Please see the [DataFusion website] for +//! additional documentation, [use cases] and examples. //! //! "Out of the box," DataFusion offers [SQL] and [`Dataframe`] APIs, //! excellent [performance], built-in support for CSV, Parquet, JSON, and Avro, @@ -53,6 +57,7 @@ //! See the [Architecture] section below for more details. //! //! [DataFusion]: https://datafusion.apache.org/ +//! [DataFusion website]: https://datafusion.apache.org //! [Apache Arrow]: https://arrow.apache.org //! 
[use cases]: https://datafusion.apache.org/user-guide/introduction.html#use-cases //! [SQL]: https://datafusion.apache.org/user-guide/sql/index.html @@ -84,26 +89,29 @@ //! let ctx = SessionContext::new(); //! //! // create the dataframe -//! let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; +//! let df = ctx +//! .read_csv("tests/data/example.csv", CsvReadOptions::new()) +//! .await?; //! //! // create a plan -//! let df = df.filter(col("a").lt_eq(col("b")))? -//! .aggregate(vec![col("a")], vec![min(col("b"))])? -//! .limit(0, Some(100))?; +//! let df = df +//! .filter(col("a").lt_eq(col("b")))? +//! .aggregate(vec![col("a")], vec![min(col("b"))])? +//! .limit(0, Some(100))?; //! //! // execute the plan //! let results: Vec = df.collect().await?; //! //! // format the results -//! let pretty_results = arrow::util::pretty::pretty_format_batches(&results)? -//! .to_string(); +//! let pretty_results = +//! arrow::util::pretty::pretty_format_batches(&results)?.to_string(); //! //! let expected = vec![ //! "+---+----------------+", //! "| a | min(?table?.b) |", //! "+---+----------------+", //! "| 1 | 2 |", -//! "+---+----------------+" +//! "+---+----------------+", //! ]; //! //! assert_eq!(pretty_results.trim().lines().collect::>(), expected); @@ -124,24 +132,27 @@ //! # async fn main() -> Result<()> { //! let ctx = SessionContext::new(); //! -//! ctx.register_csv("example", "tests/data/example.csv", CsvReadOptions::new()).await?; +//! ctx.register_csv("example", "tests/data/example.csv", CsvReadOptions::new()) +//! .await?; //! //! // create a plan -//! let df = ctx.sql("SELECT a, MIN(b) FROM example WHERE a <= b GROUP BY a LIMIT 100").await?; +//! let df = ctx +//! .sql("SELECT a, MIN(b) FROM example WHERE a <= b GROUP BY a LIMIT 100") +//! .await?; //! //! // execute the plan //! let results: Vec = df.collect().await?; //! //! // format the results -//! let pretty_results = arrow::util::pretty::pretty_format_batches(&results)? -//! 
.to_string(); +//! let pretty_results = +//! arrow::util::pretty::pretty_format_batches(&results)?.to_string(); //! //! let expected = vec![ //! "+---+----------------+", //! "| a | min(example.b) |", //! "+---+----------------+", //! "| 1 | 2 |", -//! "+---+----------------+" +//! "+---+----------------+", //! ]; //! //! assert_eq!(pretty_results.trim().lines().collect::>(), expected); @@ -311,17 +322,17 @@ //! ``` //! //! A [`TableProvider`] provides information for planning and -//! an [`ExecutionPlan`] for execution. DataFusion includes [`ListingTable`], -//! a [`TableProvider`] which reads individual files or directories of files -//! ("partitioned datasets") of the same file format. Users can add -//! support for new file formats by implementing the [`TableProvider`] -//! trait. +//! an [`ExecutionPlan`] for execution. DataFusion includes two built-in +//! table providers that support common file formats and require no runtime services, +//! [`ListingTable`] and [`MemTable`]. You can add support for any other data +//! source and/or file formats by implementing the [`TableProvider`] trait. //! //! See also: //! //! 1. [`ListingTable`]: Reads data from one or more Parquet, JSON, CSV, or AVRO -//! files supporting HIVE style partitioning, optional compression, directly -//! reading from remote object store and more. +//! files in one or more local or remote directories. Supports HIVE style +//! partitioning, optional compression, directly reading from remote +//! object store, file metadata caching, and more. //! //! 2. [`MemTable`]: Reads data from in memory [`RecordBatch`]es. //! @@ -350,7 +361,7 @@ //! [`TreeNode`]: datafusion_common::tree_node::TreeNode //! [`tree_node module`]: datafusion_expr::logical_plan::tree_node //! [`ExprSimplifier`]: crate::optimizer::simplify_expressions::ExprSimplifier -//! [`expr_api`.rs]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/expr_api.rs +//! 
[`expr_api`.rs]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/query_planning/expr_api.rs //! //! ### Physical Plans //! @@ -441,7 +452,30 @@ //! other operators read a single [`RecordBatch`] from their input to produce a //! single [`RecordBatch`] as output. //! -//! For example, given this SQL query: +//! For example, given this SQL: +//! +//! ```sql +//! SELECT name FROM 'data.parquet' WHERE id > 10 +//! ``` +//! +//! A simplified DataFusion execution plan is shown below. It first reads +//! data from the Parquet file, then applies the filter, then the projection, +//! and finally produces output. Each step processes one [`RecordBatch`] at a +//! time. Multiple batches are processed concurrently on different CPU cores +//! for plans with multiple partitions. +//! +//! ```text +//! ┌─────────────┐ ┌──────────────┐ ┌────────────────┐ ┌──────────────────┐ ┌──────────┐ +//! │ Parquet │───▶│ DataSource │───▶│ FilterExec │───▶│ ProjectionExec │───▶│ Results │ +//! │ File │ │ │ │ │ │ │ │ │ +//! └─────────────┘ └──────────────┘ └────────────────┘ └──────────────────┘ └──────────┘ +//! (reads data) (id > 10) (keeps "name" col) +//! RecordBatch ───▶ RecordBatch ────▶ RecordBatch ────▶ RecordBatch +//! ``` +//! +//! DataFusion uses the classic "pull" based control flow (explained more in the +//! next section) to implement streaming execution. As an example, +//! consider the following SQL query: //! //! ```sql //! SELECT date_trunc('month', time) FROM data WHERE id IN (10,20,30); @@ -498,10 +532,21 @@ //! While preparing for execution, DataFusion tries to create this many distinct //! `async` [`Stream`]s for each [`ExecutionPlan`]. //! The [`Stream`]s for certain [`ExecutionPlan`]s, such as [`RepartitionExec`] -//! and [`CoalescePartitionsExec`], spawn [Tokio] [`task`]s, that are run by +//! and [`CoalescePartitionsExec`], spawn [Tokio] [`task`]s, that run on //! threads managed by the [`Runtime`]. //! 
Many DataFusion [`Stream`]s perform CPU intensive processing. //! +//! ### Cooperative Scheduling +//! +//! DataFusion uses cooperative scheduling, which means that each [`Stream`] +//! is responsible for yielding control back to the [`Runtime`] after +//! some amount of work is done. Please see the [`coop`] module documentation +//! for more details. +//! +//! [`coop`]: datafusion_physical_plan::coop +//! +//! ### Network I/O and CPU intensive tasks +//! //! Using `async` for CPU intensive tasks makes it easy for [`TableProvider`]s //! to perform network I/O using standard Rust `async` during execution. //! However, this design also makes it very easy to mix CPU intensive and latency @@ -510,17 +555,20 @@ //! initial development and processing local files, but it can lead to problems //! under load and/or when reading from network sources such as AWS S3. //! +//! ### Optimizing Latency: Throttled CPU / IO under Highly Concurrent Load +//! //! If your system does not fully utilize either the CPU or network bandwidth //! during execution, or you see significantly higher tail (e.g. p99) latencies //! responding to network requests, **it is likely you need to use a different -//! [`Runtime`] for CPU intensive DataFusion plans**. This effect can be especially -//! pronounced when running several queries concurrently. +//! [`Runtime`] for DataFusion plans**. The [thread_pools example] +//! has an example of how to do so. //! -//! As shown in the following figure, using the same [`Runtime`] for both CPU -//! intensive processing and network requests can introduce significant -//! delays in responding to those network requests. Delays in processing network -//! requests can and does lead network flow control to throttle the available -//! bandwidth in response. +//! As shown below, using the same [`Runtime`] for both CPU intensive processing +//! and network requests can introduce significant delays in responding to +//! those network requests. 
Delays in processing network requests can and do +//! lead network flow control to throttle the available bandwidth in response. +//! This effect can be especially pronounced when running multiple queries +//! concurrently. //! //! ```text //! Legend @@ -591,7 +639,7 @@ //! └─────────────┘ ┗━━━━━━━━━━━━━━━━━━━┻━━━━━━━━━━━━━━━━━━━┻━━━━━━━━━━━━━━┛ //! ─────────────────────────────────────────────────────────────▶ //! time -//!``` +//! ``` //! //! Note that DataFusion does not use [`tokio::task::spawn_blocking`] for //! CPU-bounded work, because `spawn_blocking` is designed for blocking **IO**, @@ -602,6 +650,7 @@ //! //! [Tokio]: https://tokio.rs //! [`Runtime`]: tokio::runtime::Runtime +//! [thread_pools example]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/query_planning/thread_pools.rs //! [`task`]: tokio::task //! [Using Rustlang’s Async Tokio Runtime for CPU-Bound Tasks]: https://thenewstack.io/using-rustlangs-async-tokio-runtime-for-cpu-bound-tasks/ //! [`RepartitionExec`]: physical_plan::repartition::RepartitionExec @@ -717,6 +766,8 @@ pub const DATAFUSION_VERSION: &str = env!("CARGO_PKG_VERSION"); extern crate core; + +#[cfg(feature = "sql")] extern crate sqlparser; pub mod dataframe; @@ -727,11 +778,16 @@ pub mod physical_planner; pub mod prelude; pub mod scalar; -// re-export dependencies from arrow-rs to minimize version maintenance for crate users +// Re-export dependencies that are part of DataFusion public API (e.g. via DataFusionError) pub use arrow; +pub use object_store; + #[cfg(feature = "parquet")] pub use parquet; +#[cfg(feature = "avro")] +pub use datafusion_datasource_avro::apache_avro; + // re-export DataFusion sub-crates at the top level. 
Use `pub use *` // so that the contents of the subcrates appears in rustdocs // for details, see https://github.com/apache/datafusion/issues/6648 @@ -786,6 +842,11 @@ pub mod physical_expr { pub use datafusion_physical_expr::*; } +/// re-export of [`datafusion_physical_expr_adapter`] crate +pub mod physical_expr_adapter { + pub use datafusion_physical_expr_adapter::*; +} + /// re-export of [`datafusion_physical_plan`] crate pub mod physical_plan { pub use datafusion_physical_plan::*; @@ -796,6 +857,7 @@ pub use datafusion_common::assert_batches_eq; pub use datafusion_common::assert_batches_sorted_eq; /// re-export of [`datafusion_sql`] crate +#[cfg(feature = "sql")] pub mod sql { pub use datafusion_sql::*; } @@ -811,13 +873,6 @@ pub mod functions_nested { pub use datafusion_functions_nested::*; } -/// re-export of [`datafusion_functions_nested`] crate as [`functions_array`] for backward compatibility, if "nested_expressions" feature is enabled -#[deprecated(since = "41.0.0", note = "use datafusion-functions-nested instead")] -pub mod functions_array { - #[cfg(feature = "nested_expressions")] - pub use datafusion_functions_nested::*; -} - /// re-export of [`datafusion_functions_aggregate`] crate pub mod functions_aggregate { pub use datafusion_functions_aggregate::*; @@ -876,20 +931,20 @@ doc_comment::doctest!("../../../README.md", readme_example_test); // #[cfg(doctest)] doc_comment::doctest!( - "../../../docs/source/user-guide/concepts-readings-events.md", - user_guide_concepts_readings_events + "../../../docs/source/user-guide/arrow-introduction.md", + user_guide_arrow_introduction ); #[cfg(doctest)] doc_comment::doctest!( - "../../../docs/source/user-guide/configs.md", - user_guide_configs + "../../../docs/source/user-guide/concepts-readings-events.md", + user_guide_concepts_readings_events ); #[cfg(doctest)] doc_comment::doctest!( - "../../../docs/source/user-guide/runtime_configs.md", - user_guide_runtime_configs + "../../../docs/source/user-guide/configs.md", 
+ user_guide_configs ); #[cfg(doctest)] @@ -1047,8 +1102,14 @@ doc_comment::doctest!( #[cfg(doctest)] doc_comment::doctest!( - "../../../docs/source/library-user-guide/adding-udfs.md", - library_user_guide_adding_udfs + "../../../docs/source/library-user-guide/functions/adding-udfs.md", + library_user_guide_functions_adding_udfs +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/library-user-guide/functions/spark.md", + library_user_guide_functions_spark ); #[cfg(doctest)] @@ -1119,8 +1180,56 @@ doc_comment::doctest!( #[cfg(doctest)] doc_comment::doctest!( - "../../../docs/source/library-user-guide/upgrading.md", - library_user_guide_upgrading + "../../../docs/source/library-user-guide/upgrading/46.0.0.md", + library_user_guide_upgrading_46_0_0 +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/library-user-guide/upgrading/47.0.0.md", + library_user_guide_upgrading_47_0_0 +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/library-user-guide/upgrading/48.0.0.md", + library_user_guide_upgrading_48_0_0 +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/library-user-guide/upgrading/48.0.1.md", + library_user_guide_upgrading_48_0_1 +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/library-user-guide/upgrading/49.0.0.md", + library_user_guide_upgrading_49_0_0 +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/library-user-guide/upgrading/50.0.0.md", + library_user_guide_upgrading_50_0_0 +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/library-user-guide/upgrading/51.0.0.md", + library_user_guide_upgrading_51_0_0 +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/library-user-guide/upgrading/52.0.0.md", + library_user_guide_upgrading_52_0_0 +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/library-user-guide/upgrading/53.0.0.md", + library_user_guide_upgrading_53_0_0 ); #[cfg(doctest)] diff 
--git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index fbb4250fc4dfb..b4fb44f670e8d 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -18,13 +18,13 @@ //! Planner for [`LogicalPlan`] to [`ExecutionPlan`] use std::borrow::Cow; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::sync::Arc; use crate::datasource::file_format::file_type_to_format; use crate::datasource::listing::ListingTableUrl; -use crate::datasource::physical_plan::FileSinkConfig; -use crate::datasource::{source_as_provider, DefaultTableSource}; +use crate::datasource::physical_plan::{FileOutputMode, FileSinkConfig}; +use crate::datasource::{DefaultTableSource, source_as_provider}; use crate::error::{DataFusionError, Result}; use crate::execution::context::{ExecutionProps, SessionState}; use crate::logical_expr::utils::generate_sort_key; @@ -39,64 +39,76 @@ use crate::physical_expr::{create_physical_expr, create_physical_exprs}; use crate::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; use crate::physical_plan::analyze::AnalyzeExec; use crate::physical_plan::explain::ExplainExec; -use crate::physical_plan::filter::FilterExec; +use crate::physical_plan::filter::FilterExecBuilder; use crate::physical_plan::joins::utils as join_utils; use crate::physical_plan::joins::{ CrossJoinExec, HashJoinExec, NestedLoopJoinExec, PartitionMode, SortMergeJoinExec, }; use crate::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; -use crate::physical_plan::projection::ProjectionExec; +use crate::physical_plan::projection::{ProjectionExec, ProjectionExpr}; use crate::physical_plan::repartition::RepartitionExec; use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::union::UnionExec; use crate::physical_plan::unnest::UnnestExec; use crate::physical_plan::windows::{BoundedWindowAggExec, WindowAggExec}; use crate::physical_plan::{ - displayable, 
windows, ExecutionPlan, ExecutionPlanProperties, InputOrderMode, - Partitioning, PhysicalExpr, WindowExpr, + ExecutionPlan, ExecutionPlanProperties, InputOrderMode, Partitioning, PhysicalExpr, + WindowExpr, displayable, windows, }; -use datafusion_physical_plan::empty::EmptyExec; -use datafusion_physical_plan::recursive_query::RecursiveQueryExec; +use crate::schema_equivalence::schema_satisfied_by; -use arrow::array::{builder::StringBuilder, RecordBatch}; +use arrow::array::{RecordBatch, builder::StringBuilder}; use arrow::compute::SortOptions; -use arrow::datatypes::{Schema, SchemaRef}; +use arrow::datatypes::Schema; +use arrow_schema::Field; +use datafusion_catalog::ScanArgs; +use datafusion_common::Column; use datafusion_common::display::ToStringifiedPlan; +use datafusion_common::format::ExplainAnalyzeLevel; use datafusion_common::tree_node::{ - Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeVisitor, + Transformed, TreeNode, TreeNodeRecursion, TreeNodeVisitor, +}; +use datafusion_common::{ + DFSchema, DFSchemaRef, ScalarValue, exec_err, internal_datafusion_err, internal_err, + not_impl_err, plan_err, }; use datafusion_common::{ - exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err, DFSchema, - ScalarValue, + TableReference, assert_eq_or_internal_err, assert_or_internal_err, }; +use datafusion_datasource::file_groups::FileGroup; use datafusion_datasource::memory::MemorySourceConfig; use datafusion_expr::dml::{CopyTo, InsertOp}; use datafusion_expr::expr::{ - physical_name, AggregateFunction, AggregateFunctionParams, Alias, GroupingSet, - WindowFunction, WindowFunctionParams, + AggregateFunction, AggregateFunctionParams, Alias, GroupingSet, NullTreatment, + WindowFunction, WindowFunctionParams, physical_name, }; use datafusion_expr::expr_rewriter::unnormalize_cols; use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary; +use datafusion_expr::utils::{expr_to_columns, split_conjunction}; use 
datafusion_expr::{ - Analyze, DescribeTable, DmlStatement, Explain, ExplainFormat, Extension, FetchType, - Filter, JoinType, RecursiveQuery, SkipType, StringifiedPlan, WindowFrame, - WindowFrameBound, WriteOp, + Analyze, BinaryExpr, DescribeTable, DmlStatement, Explain, ExplainFormat, Extension, + FetchType, Filter, JoinType, Operator, RecursiveQuery, SkipType, StringifiedPlan, + WindowFrame, WindowFrameBound, WriteOp, }; use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr}; -use datafusion_physical_expr::expressions::{Column, Literal}; -use datafusion_physical_expr::LexOrdering; +use datafusion_physical_expr::expressions::Literal; +use datafusion_physical_expr::{ + LexOrdering, PhysicalSortExpr, create_physical_sort_exprs, +}; use datafusion_physical_optimizer::PhysicalOptimizerRule; +use datafusion_physical_plan::empty::EmptyExec; use datafusion_physical_plan::execution_plan::InvariantLevel; +use datafusion_physical_plan::joins::PiecewiseMergeJoinExec; +use datafusion_physical_plan::metrics::MetricType; use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; +use datafusion_physical_plan::recursive_query::RecursiveQueryExec; use datafusion_physical_plan::unnest::ListUnnest; -use crate::schema_equivalence::schema_satisfied_by; use async_trait::async_trait; -use datafusion_datasource::file_groups::FileGroup; +use datafusion_physical_plan::async_func::{AsyncFuncExec, AsyncMapper}; use futures::{StreamExt, TryStreamExt}; -use itertools::{multiunzip, Itertools}; -use log::{debug, trace}; -use sqlparser::ast::NullTreatment; +use itertools::{Itertools, multiunzip}; +use log::debug; use tokio::sync::Mutex; /// Physical query planner that converts a `LogicalPlan` to an @@ -145,6 +157,80 @@ pub trait ExtensionPlanner { physical_inputs: &[Arc], session_state: &SessionState, ) -> Result>>; + + /// Create a physical plan for a [`LogicalPlan::TableScan`]. 
+ /// + /// This is useful for planning valid [`TableSource`]s that are not [`TableProvider`]s. + /// + /// Returns: + /// * `Ok(Some(plan))` if the planner knows how to plan the `scan` + /// * `Ok(None)` if the planner does not know how to plan the `scan` and wants to delegate the planning to another [`ExtensionPlanner`] + /// * `Err` if the planner knows how to plan the `scan` but errors while doing so + /// + /// # Example + /// + /// ```rust,ignore + /// use std::sync::Arc; + /// use datafusion::physical_plan::ExecutionPlan; + /// use datafusion::logical_expr::TableScan; + /// use datafusion::execution::context::SessionState; + /// use datafusion::error::Result; + /// use datafusion_physical_planner::{ExtensionPlanner, PhysicalPlanner}; + /// use async_trait::async_trait; + /// + /// // Your custom table source type + /// struct MyCustomTableSource { /* ... */ } + /// + /// // Your custom execution plan + /// struct MyCustomExec { /* ... */ } + /// + /// struct MyExtensionPlanner; + /// + /// #[async_trait] + /// impl ExtensionPlanner for MyExtensionPlanner { + /// async fn plan_extension( + /// &self, + /// _planner: &dyn PhysicalPlanner, + /// _node: &dyn UserDefinedLogicalNode, + /// _logical_inputs: &[&LogicalPlan], + /// _physical_inputs: &[Arc], + /// _session_state: &SessionState, + /// ) -> Result>> { + /// Ok(None) + /// } + /// + /// async fn plan_table_scan( + /// &self, + /// _planner: &dyn PhysicalPlanner, + /// scan: &TableScan, + /// _session_state: &SessionState, + /// ) -> Result>> { + /// // Check if this is your custom table source + /// if scan.source.as_any().is::() { + /// // Create a custom execution plan for your table source + /// let exec = MyCustomExec::new( + /// scan.table_name.clone(), + /// Arc::clone(scan.projected_schema.inner()), + /// ); + /// Ok(Some(Arc::new(exec))) + /// } else { + /// // Return None to let other extension planners handle it + /// Ok(None) + /// } + /// } + /// } + /// ``` + /// + /// [`TableSource`]: 
datafusion_expr::TableSource + /// [`TableProvider`]: datafusion_catalog::TableProvider + async fn plan_table_scan( + &self, + _planner: &dyn PhysicalPlanner, + _scan: &TableScan, + _session_state: &SessionState, + ) -> Result>> { + Ok(None) + } } /// Default single node physical query planner that converts a @@ -266,7 +352,8 @@ struct LogicalNode<'a> { impl DefaultPhysicalPlanner { /// Create a physical planner that uses `extension_planners` to - /// plan user-defined logical nodes [`LogicalPlan::Extension`]. + /// plan user-defined logical nodes [`LogicalPlan::Extension`] + /// or user-defined table sources in [`LogicalPlan::TableScan`]. /// The planner uses the first [`ExtensionPlanner`] to return a non-`None` /// plan. pub fn with_extension_planners( @@ -275,6 +362,24 @@ impl DefaultPhysicalPlanner { Self { extension_planners } } + fn ensure_schema_matches( + &self, + logical_schema: &DFSchemaRef, + physical_plan: &Arc, + context: &str, + ) -> Result<()> { + if !logical_schema.matches_arrow_schema(&physical_plan.schema()) { + return plan_err!( + "{} created an ExecutionPlan with mismatched schema. 
\ + LogicalPlan schema: {:?}, ExecutionPlan schema: {:?}", + context, + logical_schema, + physical_plan.schema() + ); + } + Ok(()) + } + /// Create a physical plan from a logical plan async fn create_initial_plan( &self, @@ -341,11 +446,11 @@ impl DefaultPhysicalPlanner { .flatten() .collect::>(); // Ideally this never happens if we have a valid LogicalPlan tree - if outputs.len() != 1 { - return internal_err!( - "Failed to convert LogicalPlan to ExecutionPlan: More than one root detected" - ); - } + assert_eq_or_internal_err!( + outputs.len(), + 1, + "Failed to convert LogicalPlan to ExecutionPlan: More than one root detected" + ); let plan = outputs.pop().unwrap(); Ok(plan) } @@ -443,24 +548,55 @@ impl DefaultPhysicalPlanner { ) -> Result> { let exec_node: Arc = match node { // Leaves (no children) - LogicalPlan::TableScan(TableScan { - source, - projection, - filters, - fetch, - .. - }) => { - let source = source_as_provider(source)?; - // Remove all qualifiers from the scan as the provider - // doesn't know (nor should care) how the relation was - // referred to in the query - let filters = unnormalize_cols(filters.iter().cloned()); - source - .scan(session_state, projection.as_ref(), &filters, *fetch) - .await? + LogicalPlan::TableScan(scan) => { + let TableScan { + source, + projection, + filters, + fetch, + projected_schema, + .. 
+ } = scan; + + if let Ok(source) = source_as_provider(source) { + // Remove all qualifiers from the scan as the provider + // doesn't know (nor should care) how the relation was + // referred to in the query + let filters = unnormalize_cols(filters.iter().cloned()); + let filters_vec = filters.into_iter().collect::>(); + let opts = ScanArgs::default() + .with_projection(projection.as_deref()) + .with_filters(Some(&filters_vec)) + .with_limit(*fetch); + let res = source.scan_with_args(session_state, opts).await?; + Arc::clone(res.plan()) + } else { + let mut maybe_plan = None; + for planner in &self.extension_planners { + if maybe_plan.is_some() { + break; + } + + maybe_plan = + planner.plan_table_scan(self, scan, session_state).await?; + } + + let plan = match maybe_plan { + Some(plan) => plan, + None => { + return plan_err!( + "No installed planner was able to plan TableScan for custom TableSource: {:?}", + scan.table_name + ); + } + }; + let context = + format!("Extension planner for table scan {}", scan.table_name); + self.ensure_schema_matches(projected_schema, &plan, &context)?; + plan + } } LogicalPlan::Values(Values { values, schema }) => { - let exec_schema = schema.as_ref().to_owned().into(); let exprs = values .iter() .map(|row| { @@ -471,27 +607,23 @@ impl DefaultPhysicalPlanner { .collect::>>>() }) .collect::>>()?; - MemorySourceConfig::try_new_as_values(SchemaRef::new(exec_schema), exprs)? + MemorySourceConfig::try_new_as_values(Arc::clone(schema.inner()), exprs)? 
as _ } LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, schema, - }) => Arc::new(EmptyExec::new(SchemaRef::new( - schema.as_ref().to_owned().into(), - ))), + }) => Arc::new(EmptyExec::new(Arc::clone(schema.inner()))), LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: true, schema, - }) => Arc::new(PlaceholderRowExec::new(SchemaRef::new( - schema.as_ref().to_owned().into(), - ))), + }) => Arc::new(PlaceholderRowExec::new(Arc::clone(schema.inner()))), LogicalPlan::DescribeTable(DescribeTable { schema, output_schema, }) => { - let output_schema: Schema = output_schema.as_ref().into(); - self.plan_describe(Arc::clone(schema), Arc::new(output_schema))? + let output_schema = Arc::clone(output_schema.inner()); + self.plan_describe(&Arc::clone(schema), output_schema)? } // 1 Child @@ -501,13 +633,14 @@ impl DefaultPhysicalPlanner { file_type, partition_by, options: source_option_tuples, + output_schema: _, }) => { let original_url = output_url.clone(); let input_exec = children.one()?; let parsed_url = ListingTableUrl::parse(output_url)?; let object_store_url = parsed_url.object_store(); - let schema: Schema = (**input.schema()).clone().into(); + let schema = Arc::clone(input.schema().inner()); // Note: the DataType passed here is ignored for the purposes of writing and inferred instead // from the schema of the RecordBatch being written. 
This allows COPY statements to specify only @@ -519,16 +652,56 @@ impl DefaultPhysicalPlanner { let keep_partition_by_columns = match source_option_tuples .get("execution.keep_partition_by_columns") - .map(|v| v.trim()) { - None => session_state.config().options().execution.keep_partition_by_columns, + .map(|v| v.trim()) + { + None => { + session_state + .config() + .options() + .execution + .keep_partition_by_columns + } Some("true") => true, Some("false") => false, - Some(value) => - return Err(DataFusionError::Configuration(format!("provided value for 'execution.keep_partition_by_columns' was not recognized: \"{value}\""))), + Some(value) => { + return Err(DataFusionError::Configuration(format!( + "provided value for 'execution.keep_partition_by_columns' was not recognized: \"{value}\"" + ))); + } + }; + + // Parse single_file_output option if explicitly set + let file_output_mode = match source_option_tuples + .get("single_file_output") + .map(|v| v.trim()) + { + None => FileOutputMode::Automatic, + Some("true") => FileOutputMode::SingleFile, + Some("false") => FileOutputMode::Directory, + Some(value) => { + return Err(DataFusionError::Configuration(format!( + "provided value for 'single_file_output' was not recognized: \"{value}\"" + ))); + } }; + // Filter out sink-related options that are not format options + let format_options: HashMap = source_option_tuples + .iter() + .filter(|(k, _)| k.as_str() != "single_file_output") + .map(|(k, v)| (k.clone(), v.clone())) + .collect(); + let sink_format = file_type_to_format(file_type)? 
- .create(session_state, source_option_tuples)?; + .create(session_state, &format_options)?; + + // Determine extension based on format extension and compression + let file_extension = match sink_format.compression_type() { + Some(compression_type) => sink_format + .get_ext_with_compression(&compression_type) + .unwrap_or_else(|_| sink_format.get_ext()), + None => sink_format.get_ext(), + }; // Set file sink related options let config = FileSinkConfig { @@ -536,15 +709,23 @@ impl DefaultPhysicalPlanner { object_store_url, table_paths: vec![parsed_url], file_group: FileGroup::default(), - output_schema: Arc::new(schema), + output_schema: schema, table_partition_cols, insert_op: InsertOp::Append, keep_partition_by_columns, - file_extension: sink_format.get_ext(), + file_extension, + file_output_mode, }; + let ordering = input_exec.properties().output_ordering().cloned(); + sink_format - .create_writer_physical_plan(input_exec, session_state, config, None) + .create_writer_physical_plan( + input_exec, + session_state, + config, + ordering.map(Into::into), + ) .await? } LogicalPlan::Dml(DmlStatement { @@ -566,35 +747,110 @@ impl DefaultPhysicalPlanner { ); } } - LogicalPlan::Window(Window { window_expr, .. }) => { - if window_expr.is_empty() { - return internal_err!("Impossibly got empty window expression"); + LogicalPlan::Dml(DmlStatement { + table_name, + target, + op: WriteOp::Delete, + input, + .. + }) => { + if let Some(provider) = + target.as_any().downcast_ref::() + { + let filters = extract_dml_filters(input, table_name)?; + provider + .table_provider + .delete_from(session_state, filters) + .await + .map_err(|e| { + e.context(format!("DELETE operation on table '{table_name}'")) + })? + } else { + return exec_err!( + "Table source can't be downcasted to DefaultTableSource" + ); + } + } + LogicalPlan::Dml(DmlStatement { + table_name, + target, + op: WriteOp::Update, + input, + .. 
+ }) => { + if let Some(provider) = + target.as_any().downcast_ref::() + { + // For UPDATE, the assignments are encoded in the projection of input + // We pass the filters and let the provider handle the projection + let filters = extract_dml_filters(input, table_name)?; + // Extract assignments from the projection in input plan + let assignments = extract_update_assignments(input)?; + provider + .table_provider + .update(session_state, assignments, filters) + .await + .map_err(|e| { + e.context(format!("UPDATE operation on table '{table_name}'")) + })? + } else { + return exec_err!( + "Table source can't be downcasted to DefaultTableSource" + ); + } + } + LogicalPlan::Dml(DmlStatement { + table_name, + target, + op: WriteOp::Truncate, + .. + }) => { + if let Some(provider) = + target.as_any().downcast_ref::() + { + provider + .table_provider + .truncate(session_state) + .await + .map_err(|e| { + e.context(format!( + "TRUNCATE operation on table '{table_name}'" + )) + })? + } else { + return exec_err!( + "Table source can't be downcasted to DefaultTableSource" + ); } + } + LogicalPlan::Window(Window { window_expr, .. }) => { + assert_or_internal_err!( + !window_expr.is_empty(), + "Impossibly got empty window expression" + ); let input_exec = children.one()?; let get_sort_keys = |expr: &Expr| match expr { - Expr::WindowFunction(WindowFunction { - params: - WindowFunctionParams { - ref partition_by, - ref order_by, - .. - }, - .. - }) => generate_sort_key(partition_by, order_by), + Expr::WindowFunction(window_fun) => { + let WindowFunctionParams { + partition_by, + order_by, + .. + } = &window_fun.as_ref().params; + generate_sort_key(partition_by, order_by) + } Expr::Alias(Alias { expr, .. }) => { // Convert &Box to &T match &**expr { - Expr::WindowFunction(WindowFunction { - params: - WindowFunctionParams { - ref partition_by, - ref order_by, - .. - }, - .. 
- }) => generate_sort_key(partition_by, order_by), + Expr::WindowFunction(window_fun) => { + let WindowFunctionParams { + partition_by, + order_by, + .. + } = &window_fun.as_ref().params; + generate_sort_key(partition_by, order_by) + } _ => unreachable!(), } } @@ -603,11 +859,11 @@ impl DefaultPhysicalPlanner { let sort_keys = get_sort_keys(&window_expr[0])?; if window_expr.len() > 1 { debug_assert!( - window_expr[1..] - .iter() - .all(|expr| get_sort_keys(expr).unwrap() == sort_keys), - "all window expressions shall have the same sort keys, as guaranteed by logical planning" - ); + window_expr[1..] + .iter() + .all(|expr| get_sort_keys(expr).unwrap() == sort_keys), + "all window expressions shall have the same sort keys, as guaranteed by logical planning" + ); } let logical_schema = node.schema(); @@ -664,6 +920,17 @@ impl DefaultPhysicalPlanner { ) { let mut differences = Vec::new(); + + if physical_input_schema.metadata() + != physical_input_schema_from_logical.metadata() + { + differences.push(format!( + "schema metadata differs: (physical) {:?} vs (logical) {:?}", + physical_input_schema.metadata(), + physical_input_schema_from_logical.metadata() + )); + } + if physical_input_schema.fields().len() != physical_input_schema_from_logical.fields().len() { @@ -693,11 +960,20 @@ impl DefaultPhysicalPlanner { if physical_field.is_nullable() && !logical_field.is_nullable() { differences.push(format!("field nullability at index {} [{}]: (physical) {} vs (logical) {}", i, physical_field.name(), physical_field.is_nullable(), logical_field.is_nullable())); } + if physical_field.metadata() != logical_field.metadata() { + differences.push(format!( + "field metadata at index {} [{}]: (physical) {:?} vs (logical) {:?}", + i, + physical_field.name(), + physical_field.metadata(), + logical_field.metadata() + )); + } } - return internal_err!("Physical input schema should be the same as the one converted from logical input schema. 
Differences: {}", differences - .iter() - .map(|s| format!("\n\t- {s}")) - .join("")); + return internal_err!( + "Physical input schema should be the same as the one converted from logical input schema. Differences: {}", + differences.iter().map(|s| format!("\n\t- {s}")).join("") + ); } let groups = self.create_grouping_physical_expr( @@ -719,9 +995,54 @@ impl DefaultPhysicalPlanner { }) .collect::>>()?; - let (aggregates, filters, _order_bys): (Vec<_>, Vec<_>, Vec<_>) = + let (mut aggregates, filters, _order_bys): (Vec<_>, Vec<_>, Vec<_>) = multiunzip(agg_filter); + let mut async_exprs = Vec::new(); + let num_input_columns = physical_input_schema.fields().len(); + + for agg_func in &mut aggregates { + match self.try_plan_async_exprs( + num_input_columns, + PlannedExprResult::Expr(agg_func.expressions()), + physical_input_schema.as_ref(), + )? { + PlanAsyncExpr::Async( + async_map, + PlannedExprResult::Expr(physical_exprs), + ) => { + async_exprs.extend(async_map.async_exprs); + + if let Some(new_agg_func) = agg_func.with_new_expressions( + physical_exprs, + agg_func + .order_bys() + .iter() + .cloned() + .map(|x| x.expr) + .collect(), + ) { + *agg_func = Arc::new(new_agg_func); + } else { + return internal_err!("Failed to plan async expression"); + } + } + PlanAsyncExpr::Sync(PlannedExprResult::Expr(_)) => { + // Do nothing + } + _ => { + return internal_err!( + "Unexpected result from try_plan_async_exprs" + ); + } + } + } + let input_exec = if !async_exprs.is_empty() { + Arc::new(AsyncFuncExec::try_new(async_exprs, input_exec)?) 
+ } else { + input_exec + }; + let initial_aggr = Arc::new(AggregateExec::try_new( AggregateMode::Partial, groups.clone(), @@ -777,12 +1098,53 @@ impl DefaultPhysicalPlanner { let runtime_expr = self.create_physical_expr(predicate, input_dfschema, session_state)?; + + let input_schema = input.schema(); + let filter = match self.try_plan_async_exprs( + input_schema.fields().len(), + PlannedExprResult::Expr(vec![runtime_expr]), + input_schema.as_arrow(), + )? { + PlanAsyncExpr::Sync(PlannedExprResult::Expr(runtime_expr)) => { + FilterExecBuilder::new( + Arc::clone(&runtime_expr[0]), + physical_input, + ) + .with_batch_size(session_state.config().batch_size()) + .build()? + } + PlanAsyncExpr::Async( + async_map, + PlannedExprResult::Expr(runtime_expr), + ) => { + let async_exec = AsyncFuncExec::try_new( + async_map.async_exprs, + physical_input, + )?; + FilterExecBuilder::new( + Arc::clone(&runtime_expr[0]), + Arc::new(async_exec), + ) + // project the output columns excluding the async functions + // The async functions are always appended to the end of the schema. + .apply_projection(Some( + (0..input.schema().fields().len()).collect::>(), + ))? + .with_batch_size(session_state.config().batch_size()) + .build()? + } + _ => { + return internal_err!( + "Unexpected result from try_plan_async_exprs" + ); + } + }; + let selectivity = session_state .config() .options() .optimizer .default_filter_selectivity; - let filter = FilterExec::try_new(runtime_expr, physical_input)?; Arc::new(filter.with_default_selectivity(selectivity)?) 
} LogicalPlan::Repartition(Repartition { @@ -824,13 +1186,17 @@ impl DefaultPhysicalPlanner { }) => { let physical_input = children.one()?; let input_dfschema = input.as_ref().schema(); - let sort_expr = create_physical_sort_exprs( + let sort_exprs = create_physical_sort_exprs( expr, input_dfschema, session_state.execution_props(), )?; - let new_sort = - SortExec::new(sort_expr, physical_input).with_fetch(*fetch); + let Some(ordering) = LexOrdering::new(sort_exprs) else { + return internal_err!( + "SortExec requires at least one sort expression" + ); + }; + let new_sort = SortExec::new(ordering, physical_input).with_fetch(*fetch); Arc::new(new_sort) } LogicalPlan::Subquery(_) => todo!(), @@ -873,7 +1239,7 @@ impl DefaultPhysicalPlanner { .. }) => { let input = children.one()?; - let schema = SchemaRef::new(schema.as_ref().to_owned().into()); + let schema = Arc::clone(schema.inner()); let list_column_indices = list_type_columns .iter() .map(|(index, unnesting)| ListUnnest { @@ -887,22 +1253,21 @@ impl DefaultPhysicalPlanner { struct_type_columns.clone(), schema, options.clone(), - )) + )?) } // 2 Children LogicalPlan::Join(Join { - left, - right, + left: original_left, + right: original_right, on: keys, filter, join_type, - null_equals_null, + null_equality, + null_aware, schema: join_schema, .. }) => { - let null_equals_null = *null_equals_null; - let [physical_left, physical_right] = children.two()?; // If join has expression equijoin keys, add physical projection. 
@@ -918,23 +1283,25 @@ impl DefaultPhysicalPlanner { let (left, left_col_keys, left_projected) = wrap_projection_for_join_if_necessary( &left_keys, - left.as_ref().clone(), + original_left.as_ref().clone(), )?; let (right, right_col_keys, right_projected) = wrap_projection_for_join_if_necessary( &right_keys, - right.as_ref().clone(), + original_right.as_ref().clone(), )?; let column_on = (left_col_keys, right_col_keys); let left = Arc::new(left); let right = Arc::new(right); - let new_join = LogicalPlan::Join(Join::try_new_with_project_input( + let (new_join, requalified) = Join::try_new_with_project_input( node, Arc::clone(&left), Arc::clone(&right), column_on, - )?); + )?; + + let new_join = LogicalPlan::Join(new_join); // If inputs were projected then create ExecutionPlan for these new // LogicalPlan nodes. @@ -967,8 +1334,24 @@ impl DefaultPhysicalPlanner { // Remove temporary projected columns if left_projected || right_projected { - let final_join_result = - join_schema.iter().map(Expr::from).collect::>(); + // Re-qualify the join schema only if the inputs were previously requalified in + // `try_new_with_project_input`. This ensures that when building the Projection + // it can correctly resolve field nullability and data types + // by disambiguating fields from the left and right sides of the join. + let qualified_join_schema = if requalified { + Arc::new(qualify_join_schema_sides( + join_schema, + original_left, + original_right, + )?) 
+ } else { + Arc::clone(join_schema) + }; + + let final_join_result = qualified_join_schema + .iter() + .map(Expr::from) + .collect::>(); let projection = LogicalPlan::Projection(Projection::try_new( final_join_result, Arc::new(new_join), @@ -1017,8 +1400,42 @@ impl DefaultPhysicalPlanner { }) .collect::>()?; + // TODO: `num_range_filters` can be used later on for ASOF joins (`num_range_filters > 1`) + let mut num_range_filters = 0; + let mut range_filters: Vec = Vec::new(); + let mut total_filters = 0; + let join_filter = match filter { Some(expr) => { + let split_expr = split_conjunction(expr); + for expr in split_expr.iter() { + match *expr { + Expr::BinaryExpr(BinaryExpr { + left: _, + right: _, + op, + }) => { + if matches!( + op, + Operator::Lt + | Operator::LtEq + | Operator::Gt + | Operator::GtEq + ) { + range_filters.push((**expr).clone()); + num_range_filters += 1; + } + total_filters += 1; + } + // TODO: Want to deal with `Expr::Between` for IEJoins, it counts as two range predicates + // which is why it is not dealt with in PWMJ + // Expr::Between(_) => {}, + _ => { + total_filters += 1; + } + } + } + // Extract columns from filter expression and saved in a HashSet let cols = expr.column_refs(); @@ -1055,7 +1472,7 @@ impl DefaultPhysicalPlanner { let filter_df_fields = filter_df_fields .into_iter() .map(|(qualifier, field)| { - (qualifier.cloned(), Arc::new(field.clone())) + (qualifier.cloned(), Arc::clone(field)) }) .collect(); @@ -1074,6 +1491,7 @@ impl DefaultPhysicalPlanner { )?; let filter_schema = Schema::new_with_metadata(filter_fields, metadata); + let filter_expr = create_physical_expr( expr, &filter_df_schema, @@ -1096,10 +1514,123 @@ impl DefaultPhysicalPlanner { let prefer_hash_join = session_state.config_options().optimizer.prefer_hash_join; + // TODO: Allow PWMJ to deal with residual equijoin conditions let join: Arc = if join_on.is_empty() { - if join_filter.is_none() && matches!(join_type, JoinType::Inner) { + if join_filter.is_none() && 
*join_type == JoinType::Inner { // cross join if there is no join conditions and no join filter set Arc::new(CrossJoinExec::new(physical_left, physical_right)) + } else if num_range_filters == 1 + && total_filters == 1 + && !matches!( + join_type, + JoinType::LeftSemi + | JoinType::RightSemi + | JoinType::LeftAnti + | JoinType::RightAnti + | JoinType::LeftMark + | JoinType::RightMark + ) + && session_state + .config_options() + .optimizer + .enable_piecewise_merge_join + { + let Expr::BinaryExpr(be) = &range_filters[0] else { + return plan_err!( + "Unsupported expression for PWMJ: Expected `Expr::BinaryExpr`" + ); + }; + + let mut op = be.op; + if !matches!( + op, + Operator::Lt | Operator::LtEq | Operator::Gt | Operator::GtEq + ) { + return plan_err!( + "Unsupported operator for PWMJ: {:?}. Expected one of <, <=, >, >=", + op + ); + } + + fn reverse_ineq(op: Operator) -> Operator { + match op { + Operator::Lt => Operator::Gt, + Operator::LtEq => Operator::GtEq, + Operator::Gt => Operator::Lt, + Operator::GtEq => Operator::LtEq, + _ => op, + } + } + + #[derive(Clone, Copy, Debug, PartialEq, Eq)] + enum Side { + Left, + Right, + Both, + } + + let side_of = |e: &Expr| -> Result { + let cols = e.column_refs(); + let any_left = cols + .iter() + .any(|c| left_df_schema.index_of_column(c).is_ok()); + let any_right = cols + .iter() + .any(|c| right_df_schema.index_of_column(c).is_ok()); + + Ok(match (any_left, any_right) { + (true, false) => Side::Left, + (false, true) => Side::Right, + (true, true) => Side::Both, + _ => unreachable!(), + }) + }; + + let mut lhs_logical = &be.left; + let mut rhs_logical = &be.right; + + let left_side = side_of(lhs_logical)?; + let right_side = side_of(rhs_logical)?; + if left_side == Side::Both || right_side == Side::Both { + return Ok(Arc::new(NestedLoopJoinExec::try_new( + physical_left, + physical_right, + join_filter, + join_type, + None, + )?)); + } + + if left_side == Side::Right && right_side == Side::Left { + std::mem::swap(&mut 
lhs_logical, &mut rhs_logical); + op = reverse_ineq(op); + } else if !(left_side == Side::Left && right_side == Side::Right) + { + return plan_err!( + "Unsupported operator for PWMJ: {:?}. Expected one of <, <=, >, >=", + op + ); + } + + let on_left = create_physical_expr( + lhs_logical, + left_df_schema, + session_state.execution_props(), + )?; + let on_right = create_physical_expr( + rhs_logical, + right_df_schema, + session_state.execution_props(), + )?; + + Arc::new(PiecewiseMergeJoinExec::try_new( + physical_left, + physical_right, + (on_left, on_right), + op, + *join_type, + session_state.config().target_partitions(), + )?) } else { // there is no equal join condition, use the nested loop join Arc::new(NestedLoopJoinExec::try_new( @@ -1123,11 +1654,13 @@ impl DefaultPhysicalPlanner { join_filter, *join_type, vec![SortOptions::default(); join_on_len], - null_equals_null, + *null_equality, )?) } else if session_state.config().target_partitions() > 1 && session_state.config().repartition_joins() && prefer_hash_join + && !*null_aware + // Null-aware joins must use CollectLeft { Arc::new(HashJoinExec::try_new( physical_left, @@ -1137,7 +1670,8 @@ impl DefaultPhysicalPlanner { join_type, None, PartitionMode::Auto, - null_equals_null, + *null_equality, + *null_aware, )?) } else { Arc::new(HashJoinExec::try_new( @@ -1148,7 +1682,8 @@ impl DefaultPhysicalPlanner { join_type, None, PartitionMode::CollectLeft, - null_equals_null, + *null_equality, + *null_aware, )?) 
}; @@ -1173,7 +1708,7 @@ impl DefaultPhysicalPlanner { } // N Children - LogicalPlan::Union(_) => Arc::new(UnionExec::new(children.vec())), + LogicalPlan::Union(_) => UnionExec::try_new(children.vec())?, LogicalPlan::Extension(Extension { node }) => { let mut maybe_plan = None; let children = children.vec(); @@ -1195,22 +1730,16 @@ impl DefaultPhysicalPlanner { } let plan = match maybe_plan { - Some(v) => Ok(v), - _ => plan_err!("No installed planner was able to convert the custom node to an execution plan: {:?}", node) - }?; - - // Ensure the ExecutionPlan's schema matches the - // declared logical schema to catch and warn about - // logic errors when creating user defined plans. - if !node.schema().matches_arrow_schema(&plan.schema()) { - return plan_err!( - "Extension planner for {:?} created an ExecutionPlan with mismatched schema. \ - LogicalPlan schema: {:?}, ExecutionPlan schema: {:?}", - node, node.schema(), plan.schema() - ); - } else { - plan - } + Some(v) => Ok(v), + _ => plan_err!( + "No installed planner was able to convert the custom node to an execution plan: {:?}", + node + ), + }?; + + let context = format!("Extension planner for {node:?}"); + self.ensure_schema_matches(node.schema(), &plan, &context)?; + plan } // Other @@ -1234,17 +1763,17 @@ impl DefaultPhysicalPlanner { LogicalPlan::Explain(_) => { return internal_err!( "Unsupported logical plan: Explain must be root of the plan" - ) + ); } LogicalPlan::Distinct(_) => { return internal_err!( "Unsupported logical plan: Distinct should be replaced to Aggregate" - ) + ); } LogicalPlan::Analyze(_) => { return internal_err!( "Unsupported logical plan: Analyze must be root of the plan" - ) + ); } }; Ok(exec_node) @@ -1286,6 +1815,10 @@ impl DefaultPhysicalPlanner { physical_name(expr), ))?])), } + } else if group_expr.is_empty() { + // No GROUP BY clause - create empty PhysicalGroupBy + // no expressions, no null expressions and no grouping expressions + Ok(PhysicalGroupBy::new(vec![], vec![], 
vec![], false)) } else { Ok(PhysicalGroupBy::new_single( group_expr @@ -1357,6 +1890,7 @@ fn merge_grouping_set_physical_expr( grouping_set_expr, null_exprs, merged_sets, + true, )) } @@ -1399,7 +1933,7 @@ fn create_cube_physical_expr( } } - Ok(PhysicalGroupBy::new(all_exprs, null_exprs, groups)) + Ok(PhysicalGroupBy::new(all_exprs, null_exprs, groups, true)) } /// Expand and align a ROLLUP expression. This is a special case of GROUPING SETS @@ -1444,7 +1978,7 @@ fn create_rollup_physical_expr( groups.push(group) } - Ok(PhysicalGroupBy::new(all_exprs, null_exprs, groups)) + Ok(PhysicalGroupBy::new(all_exprs, null_exprs, groups, true)) } /// For a given logical expr, get a properly typed NULL ScalarValue physical expression @@ -1465,6 +1999,62 @@ fn get_null_physical_expr_pair( Ok((Arc::new(null_value), physical_name)) } +/// Qualifies the fields in a join schema with "left" and "right" qualifiers +/// without mutating the original schema. This function should only be used when +/// the join inputs have already been requalified earlier in `try_new_with_project_input`. +/// +/// The purpose is to avoid ambiguity errors later in planning (e.g., in nullability or data type resolution) +/// when converting expressions to fields. +fn qualify_join_schema_sides( + join_schema: &DFSchema, + left: &LogicalPlan, + right: &LogicalPlan, +) -> Result { + let left_fields = left.schema().fields(); + let right_fields = right.schema().fields(); + let join_fields = join_schema.fields(); + + // Validate lengths + assert_eq_or_internal_err!( + join_fields.len(), + left_fields.len() + right_fields.len(), + "Join schema field count must match left and right field count." 
+ ); + + // Validate field names match + for (i, (field, expected)) in join_fields + .iter() + .zip(left_fields.iter().chain(right_fields.iter())) + .enumerate() + { + assert_eq_or_internal_err!( + field.name(), + expected.name(), + "Field name mismatch at index {}", + i + ); + } + + // qualify sides + let qualifiers = join_fields + .iter() + .enumerate() + .map(|(i, _)| { + if i < left_fields.len() { + Some(TableReference::Bare { + table: Arc::from("left"), + }) + } else { + Some(TableReference::Bare { + table: Arc::from("right"), + }) + } + }) + .collect(); + + join_schema.with_field_specific_qualified_schema(qualifiers) +} + fn get_physical_expr_pair( expr: &Expr, input_dfschema: &DFSchema, @@ -1476,47 +2066,276 @@ fn get_physical_expr_pair( Ok((physical_expr, physical_name)) } -/// Check if window bounds are valid after schema information is available, and -/// window_frame bounds are casted to the corresponding column type. -/// queries like: -/// OVER (ORDER BY a RANGES BETWEEN 3 PRECEDING AND 5 PRECEDING) -/// OVER (ORDER BY a RANGES BETWEEN INTERVAL '3 DAY' PRECEDING AND '5 DAY' PRECEDING) are rejected -pub fn is_window_frame_bound_valid(window_frame: &WindowFrame) -> bool { - match (&window_frame.start_bound, &window_frame.end_bound) { - (WindowFrameBound::Following(_), WindowFrameBound::Preceding(_)) - | (WindowFrameBound::Following(_), WindowFrameBound::CurrentRow) - | (WindowFrameBound::CurrentRow, WindowFrameBound::Preceding(_)) => false, - (WindowFrameBound::Preceding(lhs), WindowFrameBound::Preceding(rhs)) => { - !rhs.is_null() && (lhs.is_null() || (lhs >= rhs)) +/// Extract filter predicates from a DML input plan (DELETE/UPDATE). +/// +/// Walks the logical plan tree and collects Filter predicates and any filters +/// pushed down into TableScan nodes, splitting AND conjunctions into individual expressions. +/// +/// For UPDATE...FROM queries involving multiple tables, this function only extracts predicates +/// that reference the target table. 
Filters from source table scans are excluded to prevent +/// incorrect filter semantics. +/// +/// Column qualifiers are stripped so expressions can be evaluated against the TableProvider's +/// schema. Deduplication is performed because filters may appear in both Filter nodes and +/// TableScan.filters when the optimizer performs partial (Inexact) filter pushdown. +/// +/// # Parameters +/// - `input`: The logical plan tree to extract filters from (typically a DELETE or UPDATE plan) +/// - `target`: The target table reference to scope filter extraction (prevents multi-table filter leakage) +/// +/// # Returns +/// A vector of unqualified filter expressions that can be passed to the TableProvider for execution. +/// Returns an empty vector if no applicable filters are found. +/// +fn extract_dml_filters( + input: &Arc, + target: &TableReference, +) -> Result> { + let mut filters = Vec::new(); + let mut allowed_refs = vec![target.clone()]; + + // First pass: collect any alias references to the target table + input.apply(|node| { + if let LogicalPlan::SubqueryAlias(alias) = node + // Check if this alias points to the target table + && let LogicalPlan::TableScan(scan) = alias.input.as_ref() + && scan.table_name.resolved_eq(target) + { + allowed_refs.push(TableReference::bare(alias.alias.to_string())); } - (WindowFrameBound::Following(lhs), WindowFrameBound::Following(rhs)) => { - !lhs.is_null() && (rhs.is_null() || (lhs <= rhs)) + Ok(TreeNodeRecursion::Continue) + })?; + + input.apply(|node| { + match node { + LogicalPlan::Filter(filter) => { + // Split AND predicates into individual expressions + for predicate in split_conjunction(&filter.predicate) { + if predicate_is_on_target_multi(predicate, &allowed_refs)? { + filters.push(predicate.clone()); + } + } + } + LogicalPlan::TableScan(TableScan { + table_name, + filters: scan_filters, + .. + }) => { + // Only extract filters from the target table scan. 
+ // This prevents incorrect filter extraction in UPDATE...FROM scenarios + // where multiple table scans may have filters. + if table_name.resolved_eq(target) { + for filter in scan_filters { + filters.extend(split_conjunction(filter).into_iter().cloned()); + } + } + } + // Plans without filter information + LogicalPlan::EmptyRelation(_) + | LogicalPlan::Values(_) + | LogicalPlan::DescribeTable(_) + | LogicalPlan::Explain(_) + | LogicalPlan::Analyze(_) + | LogicalPlan::Distinct(_) + | LogicalPlan::Extension(_) + | LogicalPlan::Statement(_) + | LogicalPlan::Dml(_) + | LogicalPlan::Ddl(_) + | LogicalPlan::Copy(_) + | LogicalPlan::Unnest(_) + | LogicalPlan::RecursiveQuery(_) => { + // No filters to extract from leaf/meta plans + } + // Plans with inputs (may contain filters in children) + LogicalPlan::Projection(_) + | LogicalPlan::SubqueryAlias(_) + | LogicalPlan::Limit(_) + | LogicalPlan::Sort(_) + | LogicalPlan::Union(_) + | LogicalPlan::Join(_) + | LogicalPlan::Repartition(_) + | LogicalPlan::Aggregate(_) + | LogicalPlan::Window(_) + | LogicalPlan::Subquery(_) => { + // Filter information may appear in child nodes; continue traversal + // to extract filters from Filter/TableScan nodes deeper in the plan + } } - _ => true, - } + Ok(TreeNodeRecursion::Continue) + })?; + + // Strip qualifiers and deduplicate. This ensures: + // 1. Only target-table predicates are retained from Filter nodes + // 2. Qualifiers stripped for TableProvider compatibility + // 3. Duplicates removed (from Filter nodes + TableScan.filters) + // + // Deduplication is necessary because filters may appear in both Filter nodes + // and TableScan.filters when the optimizer performs partial (Inexact) pushdown. 
+ let mut seen_filters = HashSet::new(); + filters + .into_iter() + .try_fold(Vec::new(), |mut deduped, filter| { + let unqualified = strip_column_qualifiers(filter).map_err(|e| { + e.context(format!( + "Failed to strip column qualifiers for DML filter on table '{target}'" + )) + })?; + if seen_filters.insert(unqualified.clone()) { + deduped.push(unqualified); + } + Ok(deduped) + }) } -/// Create a window expression with a name from a logical expression -pub fn create_window_expr_with_name( - e: &Expr, - name: impl Into, - logical_schema: &DFSchema, +/// Determine whether a predicate references only columns from the target table +/// or its aliases. +/// +/// Columns may be qualified with the target table name or any of its aliases. +/// Unqualified columns are also accepted as they implicitly belong to the target table. +fn predicate_is_on_target_multi( + expr: &Expr, + allowed_refs: &[TableReference], +) -> Result { + let mut columns = HashSet::new(); + expr_to_columns(expr, &mut columns)?; + + // Short-circuit on first mismatch: returns false if any column references a table not in allowed_refs. + // Columns are accepted if: + // 1. They are unqualified (no relation specified), OR + // 2. Their relation matches one of the allowed table references using resolved equality + Ok(!columns.iter().any(|column| { + column.relation.as_ref().is_some_and(|relation| { + !allowed_refs + .iter() + .any(|allowed| relation.resolved_eq(allowed)) + }) + })) +} + +/// Strip table qualifiers from column references in an expression. +/// This is needed because DML filter expressions contain qualified column names +/// (e.g., "table.column") but the TableProvider's schema only has simple names. 
+fn strip_column_qualifiers(expr: Expr) -> Result { + expr.transform(|e| { + if let Expr::Column(col) = &e + && col.relation.is_some() + { + // Strip the qualifier + return Ok(Transformed::yes(Expr::Column(Column::new_unqualified( + col.name.clone(), + )))); + } + Ok(Transformed::no(e)) + }) + .map(|t| t.data) +} + +/// Extract column assignments from an UPDATE input plan. +/// For UPDATE statements, the SQL planner encodes assignments as a projection +/// over the source table. This function extracts column name and expression pairs +/// from the projection. Column qualifiers are stripped from the expressions. +/// +fn extract_update_assignments(input: &Arc) -> Result> { + // The UPDATE input plan structure is: + // Projection(updated columns as expressions with aliases) + // Filter(optional WHERE clause) + // TableScan + // + // Each projected expression has an alias matching the column name + let mut assignments = Vec::new(); + + // Find the top-level projection + if let LogicalPlan::Projection(projection) = input.as_ref() { + for expr in &projection.expr { + if let Expr::Alias(alias) = expr { + // The alias name is the column name being updated + // The inner expression is the new value + let column_name = alias.name.clone(); + // Only include if it's not just a column reference to itself + // (those are columns that aren't being updated) + if !is_identity_assignment(&alias.expr, &column_name) { + // Strip qualifiers from the assignment expression + let stripped_expr = strip_column_qualifiers((*alias.expr).clone())?; + assignments.push((column_name, stripped_expr)); + } + } + } + } else { + // Try to find projection deeper in the plan + input.apply(|node| { + if let LogicalPlan::Projection(projection) = node { + for expr in &projection.expr { + if let Expr::Alias(alias) = expr { + let column_name = alias.name.clone(); + if !is_identity_assignment(&alias.expr, &column_name) { + let stripped_expr = + strip_column_qualifiers((*alias.expr).clone())?; + 
assignments.push((column_name, stripped_expr)); + } + } + } + return Ok(TreeNodeRecursion::Stop); + } + Ok(TreeNodeRecursion::Continue) + })?; + } + + Ok(assignments) +} + +/// Check if an assignment is an identity assignment (column = column) +/// These are columns that are not being modified in the UPDATE +fn is_identity_assignment(expr: &Expr, column_name: &str) -> bool { + match expr { + Expr::Column(col) => col.name == column_name, + _ => false, + } +} + +/// Check if window bounds are valid after schema information is available, and +/// window_frame bounds are casted to the corresponding column type. +/// queries like: +/// OVER (ORDER BY a RANGES BETWEEN 3 PRECEDING AND 5 PRECEDING) +/// OVER (ORDER BY a RANGES BETWEEN INTERVAL '3 DAY' PRECEDING AND '5 DAY' PRECEDING) are rejected +pub fn is_window_frame_bound_valid(window_frame: &WindowFrame) -> bool { + match (&window_frame.start_bound, &window_frame.end_bound) { + (WindowFrameBound::Following(_), WindowFrameBound::Preceding(_)) + | (WindowFrameBound::Following(_), WindowFrameBound::CurrentRow) + | (WindowFrameBound::CurrentRow, WindowFrameBound::Preceding(_)) => false, + (WindowFrameBound::Preceding(lhs), WindowFrameBound::Preceding(rhs)) => { + !rhs.is_null() && (lhs.is_null() || (lhs >= rhs)) + } + (WindowFrameBound::Following(lhs), WindowFrameBound::Following(rhs)) => { + !lhs.is_null() && (rhs.is_null() || (lhs <= rhs)) + } + _ => true, + } +} + +/// Create a window expression with a name from a logical expression +pub fn create_window_expr_with_name( + e: &Expr, + name: impl Into, + logical_schema: &DFSchema, execution_props: &ExecutionProps, ) -> Result> { let name = name.into(); - let physical_schema: &Schema = &logical_schema.into(); + let physical_schema = Arc::clone(logical_schema.inner()); match e { - Expr::WindowFunction(WindowFunction { - fun, - params: - WindowFunctionParams { - args, - partition_by, - order_by, - window_frame, - null_treatment, - }, - }) => { + 
Expr::WindowFunction(window_fun) => { + let WindowFunction { + fun, + params: + WindowFunctionParams { + args, + partition_by, + order_by, + window_frame, + null_treatment, + distinct, + filter, + }, + } = window_fun.as_ref(); let physical_args = create_physical_exprs(args, logical_schema, execution_props)?; let partition_by = @@ -1526,23 +2345,31 @@ pub fn create_window_expr_with_name( if !is_window_frame_bound_valid(window_frame) { return plan_err!( - "Invalid window frame: start bound ({}) cannot be larger than end bound ({})", - window_frame.start_bound, window_frame.end_bound - ); + "Invalid window frame: start bound ({}) cannot be larger than end bound ({})", + window_frame.start_bound, + window_frame.end_bound + ); } let window_frame = Arc::new(window_frame.clone()); let ignore_nulls = null_treatment.unwrap_or(NullTreatment::RespectNulls) == NullTreatment::IgnoreNulls; + let physical_filter = filter + .as_ref() + .map(|f| create_physical_expr(f, logical_schema, execution_props)) + .transpose()?; + windows::create_window_expr( fun, name, &physical_args, &partition_by, - order_by.as_ref(), + &order_by, window_frame, physical_schema, ignore_nulls, + *distinct, + physical_filter, ) } other => plan_err!("Invalid window expression '{other:?}'"), @@ -1567,8 +2394,8 @@ type AggregateExprWithOptionalArgs = ( Arc, // The filter clause, if any Option>, - // Ordering requirements, if any - Option, + // Expressions in the ORDER BY clause + Vec, ); /// Create an aggregate expression with a name from a logical expression @@ -1612,22 +2439,16 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( let ignore_nulls = null_treatment.unwrap_or(NullTreatment::RespectNulls) == NullTreatment::IgnoreNulls; - let (agg_expr, filter, order_by) = { - let physical_sort_exprs = match order_by { - Some(exprs) => Some(create_physical_sort_exprs( - exprs, - logical_input_schema, - execution_props, - )?), - None => None, - }; - - let ordering_reqs: LexOrdering = - 
physical_sort_exprs.clone().unwrap_or_default(); + let (agg_expr, filter, order_bys) = { + let order_bys = create_physical_sort_exprs( + order_by, + logical_input_schema, + execution_props, + )?; let agg_expr = AggregateExprBuilder::new(func.to_owned(), physical_args.to_vec()) - .order_by(ordering_reqs) + .order_by(order_bys.clone()) .schema(Arc::new(physical_input_schema.to_owned())) .alias(name) .human_display(human_displan) @@ -1636,10 +2457,10 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( .build() .map(Arc::new)?; - (agg_expr, filter, physical_sort_exprs) + (agg_expr, filter, order_bys) }; - Ok((agg_expr, filter, order_by)) + Ok((agg_expr, filter, order_bys)) } other => internal_err!("Invalid aggregate expression '{other:?}'"), } @@ -1652,21 +2473,24 @@ pub fn create_aggregate_expr_and_maybe_filter( physical_input_schema: &Schema, execution_props: &ExecutionProps, ) -> Result { - // unpack (nested) aliased logical expressions, e.g. "sum(col) as total" + // Unpack (potentially nested) aliased logical expressions, e.g. "sum(col) as total" + // Some functions like `count_all()` create internal aliases, + // Unwrap all alias layers to get to the underlying aggregate function let (name, human_display, e) = match e { - Expr::Alias(Alias { expr, name, .. }) => { - (Some(name.clone()), String::default(), expr.as_ref()) + Expr::Alias(Alias { name, .. 
}) => { + let unaliased = e.clone().unalias_nested().data; + (Some(name.clone()), e.human_display().to_string(), unaliased) } Expr::AggregateFunction(_) => ( Some(e.schema_name().to_string()), e.human_display().to_string(), - e, + e.clone(), ), - _ => (None, String::default(), e), + _ => (None, String::default(), e.clone()), }; create_aggregate_expr_with_name_and_maybe_filter( - e, + &e, name, human_display, logical_input_schema, @@ -1675,14 +2499,6 @@ pub fn create_aggregate_expr_and_maybe_filter( ) } -#[deprecated( - since = "47.0.0", - note = "use datafusion::{create_physical_sort_expr, create_physical_sort_exprs}" -)] -pub use datafusion_physical_expr::{ - create_physical_sort_expr, create_physical_sort_exprs, -}; - impl DefaultPhysicalPlanner { /// Handles capturing the various plans for EXPLAIN queries /// @@ -1739,6 +2555,7 @@ impl DefaultPhysicalPlanner { stringified_plans.push(StringifiedPlan::new( FinalPhysicalPlan, displayable(optimized_plan.as_ref()) + .set_tree_maximum_render_width(config.tree_maximum_render_width) .tree_render() .to_string(), )); @@ -1896,11 +2713,17 @@ impl DefaultPhysicalPlanner { session_state: &SessionState, ) -> Result> { let input = self.create_physical_plan(&a.input, session_state).await?; - let schema = SchemaRef::new((*a.schema).clone().into()); + let schema = Arc::clone(a.schema.inner()); let show_statistics = session_state.config_options().explain.show_statistics; + let analyze_level = session_state.config_options().explain.analyze_level; + let metric_types = match analyze_level { + ExplainAnalyzeLevel::Summary => vec![MetricType::SUMMARY], + ExplainAnalyzeLevel::Dev => vec![MetricType::SUMMARY, MetricType::DEV], + }; Ok(Arc::new(AnalyzeExec::new( a.verbose, show_statistics, + metric_types, input, schema, ))) @@ -1908,6 +2731,7 @@ impl DefaultPhysicalPlanner { /// Optimize a physical plan by applying each physical optimizer, /// calling observer(plan, optimizer after each one) + #[expect(clippy::needless_pass_by_value)] pub 
fn optimize_physical_plan( &self, plan: Arc, @@ -1922,7 +2746,7 @@ impl DefaultPhysicalPlanner { "Input physical plan:\n{}\n", displayable(plan.as_ref()).indent(false) ); - trace!( + debug!( "Detailed input physical plan:\n{}", displayable(plan.as_ref()).indent(true) ); @@ -1942,9 +2766,9 @@ impl DefaultPhysicalPlanner { // This only checks the schema in release build, and performs additional checks in debug mode. OptimizationInvariantChecker::new(optimizer) - .check(&new_plan, before_schema)?; + .check(&new_plan, &before_schema)?; - trace!( + debug!( "Optimized physical plan by {}:\n{}\n", optimizer.name(), displayable(new_plan.as_ref()).indent(false) @@ -1960,14 +2784,22 @@ impl DefaultPhysicalPlanner { "Optimized physical plan:\n{}\n", displayable(new_plan.as_ref()).indent(false) ); - trace!("Detailed optimized physical plan:\n{new_plan:?}"); + + // Don't print new_plan directly, as that may overflow the stack. + // For example: + // thread 'tokio-runtime-worker' has overflowed its stack + // fatal runtime error: stack overflow, aborting + debug!( + "Detailed optimized physical plan:\n{}\n", + displayable(new_plan.as_ref()).indent(true) + ); Ok(new_plan) } // return an record_batch which describes a table's schema. fn plan_describe( &self, - table_schema: Arc, + table_schema: &Arc, output_schema: Arc, ) -> Result> { let mut column_names = StringBuilder::new(); @@ -1978,7 +2810,7 @@ impl DefaultPhysicalPlanner { // "System supplied type" --> Use debug format of the datatype let data_type = field.data_type(); - data_types.append_value(format!("{data_type:?}")); + data_types.append_value(format!("{data_type}")); // "YES if the column is possibly nullable, NO if it is known not nullable. 
" let nullable_str = if field.is_nullable() { "YES" } else { "NO" }; @@ -2044,21 +2876,105 @@ impl DefaultPhysicalPlanner { let physical_expr = self.create_physical_expr(e, input_logical_schema, session_state); - // Check for possible column name mismatches - let final_physical_expr = - maybe_fix_physical_column_name(physical_expr, &input_physical_schema); - - tuple_err((final_physical_expr, physical_name)) + tuple_err((physical_expr, physical_name)) }) .collect::>>()?; - Ok(Arc::new(ProjectionExec::try_new( - physical_exprs, - input_exec, - )?)) + let num_input_columns = input_exec.schema().fields().len(); + + match self.try_plan_async_exprs( + num_input_columns, + PlannedExprResult::ExprWithName(physical_exprs), + input_physical_schema.as_ref(), + )? { + PlanAsyncExpr::Sync(PlannedExprResult::ExprWithName(physical_exprs)) => { + let proj_exprs: Vec = physical_exprs + .into_iter() + .map(|(expr, alias)| ProjectionExpr { expr, alias }) + .collect(); + Ok(Arc::new(ProjectionExec::try_new(proj_exprs, input_exec)?)) + } + PlanAsyncExpr::Async( + async_map, + PlannedExprResult::ExprWithName(physical_exprs), + ) => { + let async_exec = + AsyncFuncExec::try_new(async_map.async_exprs, input_exec)?; + let proj_exprs: Vec = physical_exprs + .into_iter() + .map(|(expr, alias)| ProjectionExpr { expr, alias }) + .collect(); + let new_proj_exec = + ProjectionExec::try_new(proj_exprs, Arc::new(async_exec))?; + Ok(Arc::new(new_proj_exec)) + } + _ => internal_err!("Unexpected PlanAsyncExpressions variant"), + } + } + + fn try_plan_async_exprs( + &self, + num_input_columns: usize, + physical_expr: PlannedExprResult, + schema: &Schema, + ) -> Result { + let mut async_map = AsyncMapper::new(num_input_columns); + match &physical_expr { + PlannedExprResult::ExprWithName(exprs) => { + exprs + .iter() + .try_for_each(|(expr, _)| async_map.find_references(expr, schema))?; + } + PlannedExprResult::Expr(exprs) => { + exprs + .iter() + .try_for_each(|expr| async_map.find_references(expr, 
schema))?; + } + } + + if async_map.is_empty() { + return Ok(PlanAsyncExpr::Sync(physical_expr)); + } + + let new_exprs = match physical_expr { + PlannedExprResult::ExprWithName(exprs) => PlannedExprResult::ExprWithName( + exprs + .iter() + .map(|(expr, column_name)| { + let new_expr = Arc::clone(expr) + .transform_up(|e| Ok(async_map.map_expr(e)))?; + Ok((new_expr.data, column_name.to_string())) + }) + .collect::>()?, + ), + PlannedExprResult::Expr(exprs) => PlannedExprResult::Expr( + exprs + .iter() + .map(|expr| { + let new_expr = Arc::clone(expr) + .transform_up(|e| Ok(async_map.map_expr(e)))?; + Ok(new_expr.data) + }) + .collect::>()?, + ), + }; + // rewrite the projection's expressions in terms of the columns with the result of async evaluation + Ok(PlanAsyncExpr::Async(async_map, new_exprs)) } } +#[derive(Debug)] +enum PlannedExprResult { + ExprWithName(Vec<(Arc, String)>), + Expr(Vec>), +} + +#[derive(Debug)] +enum PlanAsyncExpr { + Sync(PlannedExprResult), + Async(AsyncMapper, PlannedExprResult), +} + fn tuple_err(value: (Result, Result)) -> Result<(T, R)> { match value { (Ok(e), Ok(e1)) => Ok((e, e1)), @@ -2068,47 +2984,6 @@ fn tuple_err(value: (Result, Result)) -> Result<(T, R)> { } } -// Handle the case where the name of a physical column expression does not match the corresponding physical input fields names. -// Physical column names are derived from the physical schema, whereas physical column expressions are derived from the logical column names. -// -// This is a special case that applies only to column expressions. Logical plans may slightly modify column names by appending a suffix (e.g., using ':'), -// to avoid duplicates—since DFSchemas do not allow duplicate names. For example: `count(Int64(1)):1`. 
-fn maybe_fix_physical_column_name( - expr: Result>, - input_physical_schema: &SchemaRef, -) -> Result> { - let Ok(expr) = expr else { return expr }; - expr.transform_down(|node| { - if let Some(column) = node.as_any().downcast_ref::() { - let idx = column.index(); - let physical_field = input_physical_schema.field(idx); - let expr_col_name = column.name(); - let physical_name = physical_field.name(); - - if expr_col_name != physical_name { - // handle edge cases where the physical_name contains ':'. - let colon_count = physical_name.matches(':').count(); - let mut splits = expr_col_name.match_indices(':'); - let split_pos = splits.nth(colon_count); - - if let Some((i, _)) = split_pos { - let base_name = &expr_col_name[..i]; - if base_name == physical_name { - let updated_column = Column::new(physical_name, idx); - return Ok(Transformed::yes(Arc::new(updated_column))); - } - } - } - - // If names already match or fix is not possible, just leave it as it is - Ok(Transformed::no(node)) - } else { - Ok(Transformed::no(node)) - } - }) - .data() -} - struct OptimizationInvariantChecker<'a> { rule: &'a Arc, } @@ -2127,11 +3002,14 @@ impl<'a> OptimizationInvariantChecker<'a> { pub fn check( &mut self, plan: &Arc, - previous_schema: Arc, + previous_schema: &Arc, ) -> Result<()> { // if the rule is not permitted to change the schema, confirm that it did not change. - if self.rule.schema_check() && plan.schema() != previous_schema { - internal_err!("PhysicalOptimizer rule '{}' failed. Schema mismatch. Expected original schema: {:?}, got new schema: {:?}", + if self.rule.schema_check() + && !is_allowed_schema_change(previous_schema.as_ref(), plan.schema().as_ref()) + { + internal_err!( + "PhysicalOptimizer rule '{}' failed. Schema mismatch. 
Expected original schema: {}, got new schema: {}", self.rule.name(), previous_schema, plan.schema() @@ -2146,12 +3024,44 @@ impl<'a> OptimizationInvariantChecker<'a> { } } +/// Checks if the change from `old` schema to `new` is allowed or not. +/// +/// The current implementation only allows nullability of individual fields to change +/// from 'nullable' to 'not nullable'. This can happen due to physical expressions knowing +/// more about their null-ness than their logical counterparts. +/// This change is allowed because for any field the non-nullable domain `F` is a strict subset +/// of the nullable domain `F ∪ { NULL }`. A physical schema that guarantees a stricter subset +/// of values will not violate any assumptions made based on the less strict schema. +fn is_allowed_schema_change(old: &Schema, new: &Schema) -> bool { + if new.metadata != old.metadata { + return false; + } + + if new.fields.len() != old.fields.len() { + return false; + } + + let new_fields = new.fields.iter().map(|f| f.as_ref()); + let old_fields = old.fields.iter().map(|f| f.as_ref()); + old_fields + .zip(new_fields) + .all(|(old, new)| is_allowed_field_change(old, new)) +} + +fn is_allowed_field_change(old_field: &Field, new_field: &Field) -> bool { + new_field.name() == old_field.name() + && new_field.data_type() == old_field.data_type() + && new_field.metadata() == old_field.metadata() + && (new_field.is_nullable() == old_field.is_nullable() + || !new_field.is_nullable()) +} + impl<'n> TreeNodeVisitor<'n> for OptimizationInvariantChecker<'_> { type Node = Arc; fn f_down(&mut self, node: &'n Self::Node) -> Result { // Checks for the more permissive `InvariantLevel::Always`. - // Plans are not guarenteed to be executable after each physical optimizer run. + // Plans are not guaranteed to be executable after each physical optimizer run. 
node.check_invariants(InvariantLevel::Always).map_err(|e| e.context(format!("Invariant for ExecutionPlan node '{}' failed for PhysicalOptimizer rule '{}'", node.name(), self.rule.name())) )?; @@ -2194,11 +3104,11 @@ mod tests { use std::ops::{BitAnd, Not}; use super::*; - use crate::datasource::file_format::options::CsvReadOptions; use crate::datasource::MemTable; + use crate::datasource::file_format::options::CsvReadOptions; use crate::physical_plan::{ - expressions, DisplayAs, DisplayFormatType, PlanProperties, - SendableRecordBatchStream, + DisplayAs, DisplayFormatType, PlanProperties, SendableRecordBatchStream, + expressions, }; use crate::prelude::{SessionConfig, SessionContext}; use crate::test_util::{scan_empty, scan_empty_with_partitions}; @@ -2206,17 +3116,19 @@ mod tests { use crate::execution::session_state::SessionStateBuilder; use arrow::array::{ArrayRef, DictionaryArray, Int32Array}; use arrow::datatypes::{DataType, Field, Int32Type}; + use arrow_schema::SchemaRef; use datafusion_common::config::ConfigOptions; use datafusion_common::{ - assert_contains, DFSchemaRef, TableReference, ToDFSchema as _, + DFSchemaRef, TableReference, ToDFSchema as _, assert_contains, }; - use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_execution::TaskContext; + use datafusion_execution::runtime_env::RuntimeEnv; + use datafusion_expr::builder::subquery_alias; use datafusion_expr::{ - col, lit, LogicalPlanBuilder, Operator, UserDefinedLogicalNodeCore, + LogicalPlanBuilder, TableSource, UserDefinedLogicalNodeCore, col, lit, }; + use datafusion_functions_aggregate::count::count_all; use datafusion_functions_aggregate::expr_fn::sum; - use datafusion_physical_expr::expressions::{BinaryExpr, IsNotNullExpr}; use datafusion_physical_expr::EquivalenceProperties; use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; @@ -2258,8 +3170,9 @@ mod tests { // verify that the plan correctly casts u8 to i64 // the cast from u8 to i64 for literal will be 
simplified, and get lit(int64(5)) // the cast here is implicit so has CastOptions with safe=true - let expected = "BinaryExpr { left: Column { name: \"c7\", index: 2 }, op: Lt, right: Literal { value: Int64(5) }, fail_on_overflow: false }"; - assert!(format!("{exec_plan:?}").contains(expected)); + let expected = r#"BinaryExpr { left: Column { name: "c7", index: 2 }, op: Lt, right: Literal { value: Int64(5), field: Field { name: "lit", data_type: Int64 } }, fail_on_overflow: false"#; + + assert_contains!(format!("{exec_plan:?}"), expected); Ok(()) } @@ -2283,9 +3196,113 @@ mod tests { &session_state, ); - let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL) }, "c1"), (Literal { value: Int64(NULL) }, "c2"), (Literal { value: Int64(NULL) }, "c3")], groups: [[false, false, false], [true, false, false], [false, true, false], [false, false, true], [true, true, false], [true, false, true], [false, true, true], [true, true, true]] })"#; - - assert_eq!(format!("{cube:?}"), expected); + insta::assert_debug_snapshot!(cube, @r#" + Ok( + PhysicalGroupBy { + expr: [ + ( + Column { + name: "c1", + index: 0, + }, + "c1", + ), + ( + Column { + name: "c2", + index: 1, + }, + "c2", + ), + ( + Column { + name: "c3", + index: 2, + }, + "c3", + ), + ], + null_expr: [ + ( + Literal { + value: Utf8(NULL), + field: Field { + name: "lit", + data_type: Utf8, + nullable: true, + }, + }, + "c1", + ), + ( + Literal { + value: Int64(NULL), + field: Field { + name: "lit", + data_type: Int64, + nullable: true, + }, + }, + "c2", + ), + ( + Literal { + value: Int64(NULL), + field: Field { + name: "lit", + data_type: Int64, + nullable: true, + }, + }, + "c3", + ), + ], + groups: [ + [ + false, + false, + false, + ], + [ + true, + false, + false, + ], + [ + false, + true, + false, + ], + [ + false, + false, + true, + ], + [ + true, + true, + 
false, + ], + [ + true, + false, + true, + ], + [ + false, + true, + true, + ], + [ + true, + true, + true, + ], + ], + has_grouping_set: true, + }, + ) + "#); Ok(()) } @@ -2310,9 +3327,93 @@ mod tests { &session_state, ); - let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL) }, "c1"), (Literal { value: Int64(NULL) }, "c2"), (Literal { value: Int64(NULL) }, "c3")], groups: [[true, true, true], [false, true, true], [false, false, true], [false, false, false]] })"#; - - assert_eq!(format!("{rollup:?}"), expected); + insta::assert_debug_snapshot!(rollup, @r#" + Ok( + PhysicalGroupBy { + expr: [ + ( + Column { + name: "c1", + index: 0, + }, + "c1", + ), + ( + Column { + name: "c2", + index: 1, + }, + "c2", + ), + ( + Column { + name: "c3", + index: 2, + }, + "c3", + ), + ], + null_expr: [ + ( + Literal { + value: Utf8(NULL), + field: Field { + name: "lit", + data_type: Utf8, + nullable: true, + }, + }, + "c1", + ), + ( + Literal { + value: Int64(NULL), + field: Field { + name: "lit", + data_type: Int64, + nullable: true, + }, + }, + "c2", + ), + ( + Literal { + value: Int64(NULL), + field: Field { + name: "lit", + data_type: Int64, + nullable: true, + }, + }, + "c3", + ), + ], + groups: [ + [ + true, + true, + true, + ], + [ + false, + true, + true, + ], + [ + false, + false, + true, + ], + [ + false, + false, + false, + ], + ], + has_grouping_set: true, + }, + ) + "#); Ok(()) } @@ -2427,8 +3528,7 @@ mod tests { .create_physical_plan(&logical_plan, &session_state) .await; - let expected_error = - "No installed planner was able to convert the custom node to an execution plan: NoOp"; + let expected_error = "No installed planner was able to convert the custom node to an execution plan: NoOp"; match plan { Ok(_) => panic!("Expected planning failure"), Err(e) => assert!( @@ -2450,35 +3550,13 @@ mod tests { let 
logical_plan = LogicalPlan::Extension(Extension { node: Arc::new(NoOpExtensionNode::default()), }); - let plan = planner + let e = planner .create_physical_plan(&logical_plan, &session_state) - .await; + .await + .expect_err("planning error") + .strip_backtrace(); - let expected_error: &str = "Error during planning: \ - Extension planner for NoOp created an ExecutionPlan with mismatched schema. \ - LogicalPlan schema: \ - DFSchema { inner: Schema { fields: \ - [Field { name: \"a\", \ - data_type: Int32, \ - nullable: false, \ - dict_id: 0, \ - dict_is_ordered: false, metadata: {} }], \ - metadata: {} }, field_qualifiers: [None], \ - functional_dependencies: FunctionalDependencies { deps: [] } }, \ - ExecutionPlan schema: Schema { fields: \ - [Field { name: \"b\", \ - data_type: Int32, \ - nullable: false, \ - dict_id: 0, \ - dict_is_ordered: false, metadata: {} }], \ - metadata: {} }"; - match plan { - Ok(_) => panic!("Expected planning failure"), - Err(e) => assert!( - e.to_string().contains(expected_error), - "Error '{e}' did not contain expected error '{expected_error}'" - ), - } + insta::assert_snapshot!(e, @r#"Error during planning: Extension planner for NoOp created an ExecutionPlan with mismatched schema. LogicalPlan schema: DFSchema { inner: Schema { fields: [Field { name: "a", data_type: Int32 }], metadata: {} }, field_qualifiers: [None], functional_dependencies: FunctionalDependencies { deps: [] } }, ExecutionPlan schema: Schema { fields: [Field { name: "b", data_type: Int32 }], metadata: {} }"#); } #[tokio::test] @@ -2494,10 +3572,9 @@ mod tests { let execution_plan = plan(&logical_plan).await?; // verify that the plan correctly adds cast from Int64(1) to Utf8, and the const will be evaluated. 
- let expected = "expr: [(BinaryExpr { left: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"a\") }, fail_on_overflow: false }, op: Or, right: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"1\") }, fail_on_overflow: false }, fail_on_overflow: false }"; + let expected = r#"expr: BinaryExpr { left: BinaryExpr { left: Column { name: "c1", index: 0 }, op: Eq, right: Literal { value: Utf8("a"), field: Field { name: "lit", data_type: Utf8 } }, fail_on_overflow: false }"#; - let actual = format!("{execution_plan:?}"); - assert!(actual.contains(expected), "{}", actual); + assert_contains!(format!("{execution_plan:?}"), expected); Ok(()) } @@ -2517,7 +3594,7 @@ mod tests { assert_contains!( &e, - r#"Error during planning: Can not find compatible types to compare Boolean with [Struct([Field { name: "foo", data_type: Boolean, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }]), Utf8]"# + r#"Error during planning: Can not find compatible types to compare Boolean with [Struct("foo": non-null Boolean), Utf8]"# ); Ok(()) @@ -2674,6 +3751,25 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_aggregate_count_all_with_alias() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Utf8, false), + Field::new("c2", DataType::UInt32, false), + ])); + + let logical_plan = scan_empty(None, schema.as_ref(), None)? + .aggregate(Vec::::new(), vec![count_all().alias("total_rows")])? 
+ .build()?; + + let physical_plan = plan(&logical_plan).await?; + assert_eq!( + "total_rows", + physical_plan.schema().field(0).name().as_str() + ); + Ok(()) + } + #[tokio::test] async fn test_explain() { let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]); @@ -2689,18 +3785,27 @@ mod tests { if let Some(plan) = plan.as_any().downcast_ref::() { let stringified_plans = plan.stringified_plans(); assert!(stringified_plans.len() >= 4); - assert!(stringified_plans - .iter() - .any(|p| matches!(p.plan_type, PlanType::FinalLogicalPlan))); - assert!(stringified_plans - .iter() - .any(|p| matches!(p.plan_type, PlanType::InitialPhysicalPlan))); - assert!(stringified_plans - .iter() - .any(|p| matches!(p.plan_type, PlanType::OptimizedPhysicalPlan { .. }))); - assert!(stringified_plans - .iter() - .any(|p| matches!(p.plan_type, PlanType::FinalPhysicalPlan))); + assert!( + stringified_plans + .iter() + .any(|p| p.plan_type == PlanType::FinalLogicalPlan) + ); + assert!( + stringified_plans + .iter() + .any(|p| p.plan_type == PlanType::InitialPhysicalPlan) + ); + assert!( + stringified_plans.iter().any(|p| matches!( + p.plan_type, + PlanType::OptimizedPhysicalPlan { .. 
} + )) + ); + assert!( + stringified_plans + .iter() + .any(|p| p.plan_type == PlanType::FinalPhysicalPlan) + ); } else { panic!( "Plan was not an explain plan: {}", @@ -2757,71 +3862,6 @@ mod tests { } } - #[tokio::test] - async fn test_maybe_fix_colon_in_physical_name() { - // The physical schema has a field name with a colon - let schema = Schema::new(vec![Field::new("metric:avg", DataType::Int32, false)]); - let schema_ref: SchemaRef = Arc::new(schema); - - // What might happen after deduplication - let logical_col_name = "metric:avg:1"; - let expr_with_suffix = - Arc::new(Column::new(logical_col_name, 0)) as Arc; - let expr_result = Ok(expr_with_suffix); - - // Call function under test - let fixed_expr = - maybe_fix_physical_column_name(expr_result, &schema_ref).unwrap(); - - // Downcast back to Column so we can check the name - let col = fixed_expr - .as_any() - .downcast_ref::() - .expect("Column"); - - assert_eq!(col.name(), "metric:avg"); - } - - #[tokio::test] - async fn test_maybe_fix_nested_column_name_with_colon() { - let schema = Schema::new(vec![Field::new("column", DataType::Int32, false)]); - let schema_ref: SchemaRef = Arc::new(schema); - - // Construct the nested expr - let col_expr = Arc::new(Column::new("column:1", 0)) as Arc; - let is_not_null_expr = Arc::new(IsNotNullExpr::new(col_expr.clone())); - - // Create a binary expression and put the column inside - let binary_expr = Arc::new(BinaryExpr::new( - is_not_null_expr.clone(), - Operator::Or, - is_not_null_expr.clone(), - )) as Arc; - - let fixed_expr = - maybe_fix_physical_column_name(Ok(binary_expr), &schema_ref).unwrap(); - - let bin = fixed_expr - .as_any() - .downcast_ref::() - .expect("Expected BinaryExpr"); - - // Check that both sides where renamed - for expr in &[bin.left(), bin.right()] { - let is_not_null = expr - .as_any() - .downcast_ref::() - .expect("Expected IsNotNull"); - - let col = is_not_null - .arg() - .as_any() - .downcast_ref::() - .expect("Expected Column"); - - 
assert_eq!(col.name(), "column"); - } - } struct ErrorExtensionPlanner {} #[async_trait] @@ -2908,13 +3948,15 @@ mod tests { #[derive(Debug)] struct NoOpExecutionPlan { - cache: PlanProperties, + cache: Arc, } impl NoOpExecutionPlan { fn new(schema: SchemaRef) -> Self { let cache = Self::compute_properties(schema); - Self { cache } + Self { + cache: Arc::new(cache), + } } /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. @@ -2952,7 +3994,7 @@ mod tests { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -2974,6 +4016,20 @@ mod tests { ) -> Result { unimplemented!("NoOpExecutionPlan::execute"); } + + fn apply_expressions( + &self, + f: &mut dyn FnMut(&dyn PhysicalExpr) -> Result, + ) -> Result { + // Visit expressions in the output ordering from equivalence properties + let mut tnr = TreeNodeRecursion::Continue; + if let Some(ordering) = self.cache.output_ordering() { + for sort_expr in ordering { + tnr = tnr.visit_sibling(|| f(sort_expr.expr.as_ref()))?; + } + } + Ok(tnr) + } } // Produces an execution plan where the schema is mismatched from @@ -3106,7 +4162,7 @@ digraph { fn children(&self) -> Vec<&Arc> { self.0.iter().collect::>() } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { unimplemented!() } fn execute( @@ -3116,6 +4172,12 @@ digraph { ) -> Result { unimplemented!() } + fn apply_expressions( + &self, + _f: &mut dyn FnMut(&dyn PhysicalExpr) -> Result, + ) -> Result { + Ok(TreeNodeRecursion::Continue) + } } impl DisplayAs for OkExtensionNode { fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { @@ -3132,8 +4194,12 @@ digraph { } fn check_invariants(&self, check: InvariantLevel) -> Result<()> { match check { - InvariantLevel::Always => plan_err!("extension node failed it's user-defined always-invariant check"), - InvariantLevel::Executable => panic!("the 
OptimizationInvariantChecker should not be checking for executableness"), + InvariantLevel::Always => plan_err!( + "extension node failed it's user-defined always-invariant check" + ), + InvariantLevel::Executable => panic!( + "the OptimizationInvariantChecker should not be checking for executableness" + ), } } fn schema(&self) -> SchemaRef { @@ -3151,7 +4217,7 @@ digraph { fn children(&self) -> Vec<&Arc> { unimplemented!() } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { unimplemented!() } fn execute( @@ -3161,6 +4227,12 @@ digraph { ) -> Result { unimplemented!() } + fn apply_expressions( + &self, + _f: &mut dyn FnMut(&dyn PhysicalExpr) -> Result, + ) -> Result { + Ok(TreeNodeRecursion::Continue) + } } impl DisplayAs for InvariantFailsExtensionNode { fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { @@ -3202,24 +4274,26 @@ digraph { // Test: check should pass with same schema let equal_schema = ok_plan.schema(); - OptimizationInvariantChecker::new(&rule).check(&ok_plan, equal_schema)?; + OptimizationInvariantChecker::new(&rule).check(&ok_plan, &equal_schema)?; // Test: should fail with schema changed let different_schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Boolean, false)])); let expected_err = OptimizationInvariantChecker::new(&rule) - .check(&ok_plan, different_schema) + .check(&ok_plan, &different_schema) .unwrap_err(); assert!(expected_err.to_string().contains("PhysicalOptimizer rule 'OptimizerRuleWithSchemaCheck' failed. Schema mismatch. 
Expected original schema")); // Test: should fail when extension node fails it's own invariant check let failing_node: Arc = Arc::new(InvariantFailsExtensionNode); let expected_err = OptimizationInvariantChecker::new(&rule) - .check(&failing_node, ok_plan.schema()) + .check(&failing_node, &ok_plan.schema()) .unwrap_err(); - assert!(expected_err - .to_string() - .contains("extension node failed it's user-defined always-invariant check")); + assert!( + expected_err.to_string().contains( + "extension node failed it's user-defined always-invariant check" + ) + ); // Test: should fail when descendent extension node fails let failing_node: Arc = Arc::new(InvariantFailsExtensionNode); @@ -3228,11 +4302,13 @@ digraph { Arc::clone(&child), ])?; let expected_err = OptimizationInvariantChecker::new(&rule) - .check(&invalid_plan, ok_plan.schema()) + .check(&invalid_plan, &ok_plan.schema()) .unwrap_err(); - assert!(expected_err - .to_string() - .contains("extension node failed it's user-defined always-invariant check")); + assert!( + expected_err.to_string().contains( + "extension node failed it's user-defined always-invariant check" + ) + ); Ok(()) } @@ -3268,7 +4344,7 @@ digraph { fn children(&self) -> Vec<&Arc> { vec![] } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { unimplemented!() } fn execute( @@ -3278,6 +4354,12 @@ digraph { ) -> Result { unimplemented!() } + fn apply_expressions( + &self, + _f: &mut dyn FnMut(&dyn PhysicalExpr) -> Result, + ) -> Result { + Ok(TreeNodeRecursion::Continue) + } } impl DisplayAs for ExecutableInvariantFails { fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { @@ -3318,4 +4400,358 @@ digraph { Ok(()) } + + // Reproducer for DataFusion issue #17405: + // + // The following SQL is semantically invalid. 
Notably, the `SELECT left_table.a, right_table.a` + // clause is missing from the explicit logical plan: + // + // SELECT a FROM ( + // -- SELECT left_table.a, right_table.a + // FROM left_table + // FULL JOIN right_table ON left_table.a = right_table.a + // ) AS alias + // GROUP BY a; + // + // As a result, the variables within `alias` subquery are not properly distinguished, which + // leads to a bug for logical and physical planning. + // + // The fix is to implicitly insert a Projection node to represent the missing SELECT clause to + // ensure each field is correctly aliased to a unique name when the SubqueryAlias node is added. + #[tokio::test] + async fn subquery_alias_confusing_the_optimizer() -> Result<()> { + let state = make_session_state(); + + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let schema = Arc::new(schema); + + let table = MemTable::try_new(schema.clone(), vec![vec![]])?; + let table = Arc::new(table); + + let source = DefaultTableSource::new(table); + let source = Arc::new(source); + + let left = LogicalPlanBuilder::scan("left", source.clone(), None)?; + let right = LogicalPlanBuilder::scan("right", source, None)?.build()?; + + let join_keys = ( + vec![Column::new(Some("left"), "a")], + vec![Column::new(Some("right"), "a")], + ); + + let join = left.join(right, JoinType::Full, join_keys, None)?.build()?; + + let alias = subquery_alias(join, "alias")?; + + let planner = DefaultPhysicalPlanner::default(); + + let logical_plan = LogicalPlanBuilder::new(alias) + .aggregate(vec![col("a:1")], Vec::::new())? 
+ .build()?; + let _physical_plan = planner.create_physical_plan(&logical_plan, &state).await?; + + let optimized_logical_plan = state.optimize(&logical_plan)?; + let _optimized_physical_plan = planner + .create_physical_plan(&optimized_logical_plan, &state) + .await?; + + Ok(()) + } + + // --- Tests for aggregate schema mismatch error messages --- + + use crate::catalog::TableProvider; + use datafusion_catalog::Session; + use datafusion_expr::TableType; + + /// A TableProvider that returns schemas for logical planning vs physical planning. + /// Used to test schema mismatch error messages. + #[derive(Debug)] + struct MockSchemaTableProvider { + logical_schema: SchemaRef, + physical_schema: SchemaRef, + } + + #[async_trait] + impl TableProvider for MockSchemaTableProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + Arc::clone(&self.logical_schema) + } + + fn table_type(&self) -> TableType { + TableType::Base + } + + async fn scan( + &self, + _state: &dyn Session, + _projection: Option<&Vec>, + _filters: &[Expr], + _limit: Option, + ) -> Result> { + Ok(Arc::new(NoOpExecutionPlan::new(Arc::clone( + &self.physical_schema, + )))) + } + } + + /// Attempts to plan a query with potentially mismatched schemas. + async fn plan_with_schemas( + logical_schema: SchemaRef, + physical_schema: SchemaRef, + query: &str, + ) -> Result> { + let provider = MockSchemaTableProvider { + logical_schema, + physical_schema, + }; + let ctx = SessionContext::new(); + ctx.register_table("test", Arc::new(provider)).unwrap(); + + ctx.sql(query).await.unwrap().create_physical_plan().await + } + + #[tokio::test] + // When schemas match, planning proceeds past the schema_satisfied_by check. + // It then panics on unimplemented error in NoOpExecutionPlan. 
+ #[should_panic(expected = "NoOpExecutionPlan")] + async fn test_aggregate_schema_check_passes() { + let schema = + Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, false)])); + + plan_with_schemas( + Arc::clone(&schema), + schema, + "SELECT count(*) FROM test GROUP BY c1", + ) + .await + .unwrap(); + } + + #[tokio::test] + async fn test_aggregate_schema_mismatch_metadata() { + let logical_schema = + Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, false)])); + let physical_schema = Arc::new( + Schema::new(vec![Field::new("c1", DataType::Int32, false)]) + .with_metadata(HashMap::from([("key".into(), "value".into())])), + ); + + let err = plan_with_schemas( + logical_schema, + physical_schema, + "SELECT count(*) FROM test GROUP BY c1", + ) + .await + .unwrap_err(); + + assert_contains!(err.to_string(), "schema metadata differs"); + } + + #[tokio::test] + async fn test_aggregate_schema_mismatch_field_count() { + let logical_schema = + Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, false)])); + let physical_schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Int32, false), + Field::new("c2", DataType::Int32, false), + ])); + + let err = plan_with_schemas( + logical_schema, + physical_schema, + "SELECT count(*) FROM test GROUP BY c1", + ) + .await + .unwrap_err(); + + assert_contains!(err.to_string(), "Different number of fields"); + } + + #[tokio::test] + async fn test_aggregate_schema_mismatch_field_name() { + let logical_schema = + Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, false)])); + let physical_schema = Arc::new(Schema::new(vec![Field::new( + "different_name", + DataType::Int32, + false, + )])); + + let err = plan_with_schemas( + logical_schema, + physical_schema, + "SELECT count(*) FROM test GROUP BY c1", + ) + .await + .unwrap_err(); + + assert_contains!(err.to_string(), "field name at index"); + } + + #[tokio::test] + async fn test_aggregate_schema_mismatch_field_type() { + let logical_schema 
= + Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, false)])); + let physical_schema = + Arc::new(Schema::new(vec![Field::new("c1", DataType::Int64, false)])); + + let err = plan_with_schemas( + logical_schema, + physical_schema, + "SELECT count(*) FROM test GROUP BY c1", + ) + .await + .unwrap_err(); + + assert_contains!(err.to_string(), "field data type at index"); + } + + #[tokio::test] + async fn test_aggregate_schema_mismatch_field_nullability() { + let logical_schema = + Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, false)])); + let physical_schema = + Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, true)])); + + let err = plan_with_schemas( + logical_schema, + physical_schema, + "SELECT count(*) FROM test GROUP BY c1", + ) + .await + .unwrap_err(); + + assert_contains!(err.to_string(), "field nullability at index"); + } + + #[tokio::test] + async fn test_aggregate_schema_mismatch_field_metadata() { + let logical_schema = + Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, false)])); + let physical_schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Int32, false) + .with_metadata(HashMap::from([("key".into(), "value".into())])), + ])); + + let err = plan_with_schemas( + logical_schema, + physical_schema, + "SELECT count(*) FROM test GROUP BY c1", + ) + .await + .unwrap_err(); + + assert_contains!(err.to_string(), "field metadata at index"); + } + + #[tokio::test] + async fn test_aggregate_schema_mismatch_multiple() { + let logical_schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Int32, false), + Field::new("c2", DataType::Utf8, false), + ])); + let physical_schema = Arc::new( + Schema::new(vec![ + Field::new("c1", DataType::Int64, true) + .with_metadata(HashMap::from([("key".into(), "value".into())])), + Field::new("c2", DataType::Utf8, false), + ]) + .with_metadata(HashMap::from([( + "schema_key".into(), + "schema_value".into(), + )])), + ); + + let err = plan_with_schemas( + 
logical_schema, + physical_schema, + "SELECT count(*) FROM test GROUP BY c1", + ) + .await + .unwrap_err(); + + // Verify all applicable error fragments are present + let err_str = err.to_string(); + assert_contains!(&err_str, "schema metadata differs"); + assert_contains!(&err_str, "field data type at index"); + assert_contains!(&err_str, "field nullability at index"); + assert_contains!(&err_str, "field metadata at index"); + } + + #[derive(Debug)] + struct MockTableSource { + schema: SchemaRef, + } + + impl TableSource for MockTableSource { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } + } + + struct MockTableScanExtensionPlanner; + + #[async_trait] + impl ExtensionPlanner for MockTableScanExtensionPlanner { + async fn plan_extension( + &self, + _planner: &dyn PhysicalPlanner, + _node: &dyn UserDefinedLogicalNode, + _logical_inputs: &[&LogicalPlan], + _physical_inputs: &[Arc], + _session_state: &SessionState, + ) -> Result>> { + Ok(None) + } + + async fn plan_table_scan( + &self, + _planner: &dyn PhysicalPlanner, + scan: &TableScan, + _session_state: &SessionState, + ) -> Result>> { + if scan.source.as_any().is::() { + Ok(Some(Arc::new(EmptyExec::new(Arc::clone( + scan.projected_schema.inner(), + ))))) + } else { + Ok(None) + } + } + } + + #[tokio::test] + async fn test_table_scan_extension_planner() { + let session_state = make_session_state(); + let planner = Arc::new(MockTableScanExtensionPlanner); + let physical_planner = + DefaultPhysicalPlanner::with_extension_planners(vec![planner]); + + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + + let table_source = Arc::new(MockTableSource { + schema: Arc::clone(&schema), + }); + let logical_plan = LogicalPlanBuilder::scan("test", table_source, None) + .unwrap() + .build() + .unwrap(); + + let plan = physical_planner + .create_physical_plan(&logical_plan, &session_state) + .await + .unwrap(); + + 
assert_eq!(plan.schema(), schema); + assert!(plan.as_any().is::()); + } } diff --git a/datafusion/core/src/prelude.rs b/datafusion/core/src/prelude.rs index d723620d32323..31d9d7eb471f0 100644 --- a/datafusion/core/src/prelude.rs +++ b/datafusion/core/src/prelude.rs @@ -29,15 +29,15 @@ pub use crate::dataframe; pub use crate::dataframe::DataFrame; pub use crate::execution::context::{SQLOptions, SessionConfig, SessionContext}; pub use crate::execution::options::{ - AvroReadOptions, CsvReadOptions, NdJsonReadOptions, ParquetReadOptions, + AvroReadOptions, CsvReadOptions, JsonReadOptions, ParquetReadOptions, }; pub use datafusion_common::Column; pub use datafusion_expr::{ + Expr, expr_fn::*, lit, lit_timestamp_nano, logical_plan::{JoinType, Partitioning}, - Expr, }; pub use datafusion_functions::expr_fn::*; #[cfg(feature = "nested_expressions")] diff --git a/datafusion/core/src/test/mod.rs b/datafusion/core/src/test/mod.rs index 8719a16f4919f..717182f1d3d5b 100644 --- a/datafusion/core/src/test/mod.rs +++ b/datafusion/core/src/test/mod.rs @@ -25,9 +25,9 @@ use std::io::{BufReader, BufWriter}; use std::path::Path; use std::sync::Arc; +use crate::datasource::file_format::FileFormat; use crate::datasource::file_format::csv::CsvFormat; use crate::datasource::file_format::file_compression_type::FileCompressionType; -use crate::datasource::file_format::FileFormat; use crate::datasource::physical_plan::CsvSource; use crate::datasource::{MemTable, TableProvider}; @@ -35,27 +35,31 @@ use crate::error::Result; use crate::logical_expr::LogicalPlan; use crate::test_util::{aggr_test_schema, arrow_test_data}; +use datafusion_common::config::CsvOptions; + use arrow::array::{self, Array, ArrayRef, Decimal128Builder, Int32Array}; use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; +#[cfg(feature = "compression")] use datafusion_common::DataFusionError; +use datafusion_datasource::TableSchema; use datafusion_datasource::source::DataSourceExec; 
-#[cfg(feature = "compression")] -use bzip2::write::BzEncoder; #[cfg(feature = "compression")] use bzip2::Compression as BzCompression; +#[cfg(feature = "compression")] +use bzip2::write::BzEncoder; use datafusion_datasource::file_groups::FileGroup; use datafusion_datasource::file_scan_config::FileScanConfigBuilder; use datafusion_datasource_csv::partitioned_csv_config; #[cfg(feature = "compression")] +use flate2::Compression as GzCompression; +#[cfg(feature = "compression")] use flate2::write::GzEncoder; #[cfg(feature = "compression")] -use flate2::Compression as GzCompression; +use liblzma::write::XzEncoder; use object_store::local_unpartitioned_file; #[cfg(feature = "compression")] -use xz2::write::XzEncoder; -#[cfg(feature = "compression")] use zstd::Encoder as ZstdEncoder; pub fn create_table_dual() -> Arc { @@ -83,17 +87,26 @@ pub fn scan_partitioned_csv( let schema = aggr_test_schema(); let filename = "aggregate_test_100.csv"; let path = format!("{}/csv", arrow_test_data()); + let csv_format: Arc = Arc::new(CsvFormat::default()); + let file_groups = partitioned_file_groups( path.as_str(), filename, partitions, - Arc::new(CsvFormat::default()), + &csv_format, FileCompressionType::UNCOMPRESSED, work_dir, )?; - let source = Arc::new(CsvSource::new(true, b'"', b'"')); + let options = CsvOptions { + has_header: Some(true), + delimiter: b',', + quote: b'"', + ..Default::default() + }; + let table_schema = TableSchema::from_file_schema(schema); + let source = Arc::new(CsvSource::new(table_schema.clone()).with_csv_options(options)); let config = - FileScanConfigBuilder::from(partitioned_csv_config(schema, file_groups, source)) + FileScanConfigBuilder::from(partitioned_csv_config(file_groups, source)?) 
.with_file_compression_type(FileCompressionType::UNCOMPRESSED) .build(); Ok(DataSourceExec::from_data_source(config)) @@ -104,7 +117,7 @@ pub fn partitioned_file_groups( path: &str, filename: &str, partitions: usize, - file_format: Arc, + file_format: &Arc, file_compression_type: FileCompressionType, work_dir: &Path, ) -> Result> { @@ -188,7 +201,7 @@ pub fn partitioned_file_groups( .collect::>()) } -pub fn assert_fields_eq(plan: &LogicalPlan, expected: Vec<&str>) { +pub fn assert_fields_eq(plan: &LogicalPlan, expected: &[&str]) { let actual: Vec = plan .schema() .fields() diff --git a/datafusion/core/src/test/object_store.rs b/datafusion/core/src/test/object_store.rs index ed8474bbfc812..62c6699f8fcd1 100644 --- a/datafusion/core/src/test/object_store.rs +++ b/datafusion/core/src/test/object_store.rs @@ -17,21 +17,24 @@ //! Object store implementation used for testing -use crate::execution::context::SessionState; -use crate::execution::session_state::SessionStateBuilder; -use crate::prelude::SessionContext; -use futures::stream::BoxStream; -use futures::FutureExt; -use object_store::{ - memory::InMemory, path::Path, Error, GetOptions, GetResult, ListResult, - MultipartUpload, ObjectMeta, ObjectStore, PutMultipartOpts, PutOptions, PutPayload, - PutResult, +use crate::{ + execution::{context::SessionState, session_state::SessionStateBuilder}, + object_store::{ + Error, GetOptions, GetResult, ListResult, MultipartUpload, ObjectMeta, + ObjectStore, PutMultipartOptions, PutOptions, PutPayload, PutResult, + memory::InMemory, path::Path, + }, + prelude::SessionContext, +}; +use futures::{FutureExt, stream::BoxStream}; +use object_store::{CopyOptions, ObjectStoreExt}; +use std::{ + fmt::{Debug, Display, Formatter}, + sync::Arc, }; -use std::fmt::{Debug, Display, Formatter}; -use std::sync::Arc; use tokio::{ sync::Barrier, - time::{timeout, Duration}, + time::{Duration, timeout}, }; use url::Url; @@ -118,7 +121,7 @@ impl ObjectStore for BlockingObjectStore { async fn 
put_multipart_opts( &self, location: &Path, - opts: PutMultipartOpts, + opts: PutMultipartOptions, ) -> object_store::Result> { self.inner.put_multipart_opts(location, opts).await } @@ -128,39 +131,40 @@ impl ObjectStore for BlockingObjectStore { location: &Path, options: GetOptions, ) -> object_store::Result { - self.inner.get_opts(location, options).await - } - - async fn head(&self, location: &Path) -> object_store::Result { - println!( - "{} received head call for {location}", - BlockingObjectStore::NAME - ); - // Wait until the expected number of concurrent calls is reached, but timeout after 1 second to avoid hanging failing tests. - let wait_result = timeout(Duration::from_secs(1), self.barrier.wait()).await; - match wait_result { - Ok(_) => println!( - "{} barrier reached for {location}", + if options.head { + println!( + "{} received head call for {location}", BlockingObjectStore::NAME - ), - Err(_) => { - let error_message = format!( - "{} barrier wait timed out for {location}", + ); + // Wait until the expected number of concurrent calls is reached, but timeout after 1 second to avoid hanging failing tests. + let wait_result = timeout(Duration::from_secs(1), self.barrier.wait()).await; + match wait_result { + Ok(_) => println!( + "{} barrier reached for {location}", BlockingObjectStore::NAME - ); - log::error!("{error_message}"); - return Err(Error::Generic { - store: BlockingObjectStore::NAME, - source: error_message.into(), - }); + ), + Err(_) => { + let error_message = format!( + "{} barrier wait timed out for {location}", + BlockingObjectStore::NAME + ); + log::error!("{error_message}"); + return Err(Error::Generic { + store: BlockingObjectStore::NAME, + source: error_message.into(), + }); + } } } + // Forward the call to the inner object store. 
- self.inner.head(location).await + self.inner.get_opts(location, options).await } - - async fn delete(&self, location: &Path) -> object_store::Result<()> { - self.inner.delete(location).await + fn delete_stream( + &self, + locations: BoxStream<'static, object_store::Result>, + ) -> BoxStream<'static, object_store::Result> { + self.inner.delete_stream(locations) } fn list( @@ -177,15 +181,12 @@ impl ObjectStore for BlockingObjectStore { self.inner.list_with_delimiter(prefix).await } - async fn copy(&self, from: &Path, to: &Path) -> object_store::Result<()> { - self.inner.copy(from, to).await - } - - async fn copy_if_not_exists( + async fn copy_opts( &self, from: &Path, to: &Path, + options: CopyOptions, ) -> object_store::Result<()> { - self.inner.copy_if_not_exists(from, to).await + self.inner.copy_opts(from, to, options).await } } diff --git a/datafusion/core/src/test_util/mod.rs b/datafusion/core/src/test_util/mod.rs index d6865ca3d532a..466ee38a426fd 100644 --- a/datafusion/core/src/test_util/mod.rs +++ b/datafusion/core/src/test_util/mod.rs @@ -22,27 +22,36 @@ pub mod parquet; pub mod csv; +use futures::Stream; use std::any::Any; use std::collections::HashMap; +use std::fmt::Formatter; use std::fs::File; use std::io::Write; use std::path::Path; use std::sync::Arc; +use std::task::{Context, Poll}; use crate::catalog::{TableProvider, TableProviderFactory}; use crate::dataframe::DataFrame; use crate::datasource::stream::{FileStreamProvider, StreamConfig, StreamTable}; use crate::datasource::{empty::EmptyTable, provider_as_source}; use crate::error::Result; +use crate::execution::session_state::CacheFactory; use crate::logical_expr::{LogicalPlanBuilder, UNNAMED_TABLE}; use crate::physical_plan::ExecutionPlan; use crate::prelude::{CsvReadOptions, SessionContext}; +use crate::execution::{SendableRecordBatchStream, SessionState, SessionStateBuilder}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use 
datafusion_catalog::Session; -use datafusion_common::TableReference; -use datafusion_expr::{CreateExternalTable, Expr, SortExpr, TableType}; +use datafusion_common::{DFSchemaRef, TableReference}; +use datafusion_expr::{ + CreateExternalTable, Expr, LogicalPlan, SortExpr, TableType, + UserDefinedLogicalNodeCore, +}; +use std::pin::Pin; use async_trait::async_trait; @@ -52,6 +61,8 @@ use tempfile::TempDir; pub use datafusion_common::test_util::parquet_test_data; pub use datafusion_common::test_util::{arrow_test_data, get_data_dir}; +use crate::execution::RecordBatchStream; + /// Scan an empty data source, mainly used in tests pub fn scan_empty( name: Option<&str>, @@ -129,6 +140,7 @@ pub async fn test_table() -> Result { } /// Execute SQL and return results +#[cfg(feature = "sql")] pub async fn plan_and_collect( ctx: &SessionContext, sql: &str, @@ -178,7 +190,7 @@ impl TableProviderFactory for TestTableFactory { ) -> Result> { Ok(Arc::new(TestTableProvider { url: cmd.location.to_string(), - schema: Arc::new(cmd.schema.as_ref().into()), + schema: Arc::clone(cmd.schema.inner()), })) } } @@ -234,3 +246,108 @@ pub fn register_unbounded_file_with_ordering( ctx.register_table(table_name, Arc::new(StreamTable::new(Arc::new(config))))?; Ok(()) } + +/// Creates a bounded stream that emits the same record batch a specified number of times. +/// This is useful for testing purposes. 
+pub fn bounded_stream( + record_batch: RecordBatch, + limit: usize, +) -> SendableRecordBatchStream { + Box::pin(BoundedStream { + record_batch, + count: 0, + limit, + }) +} + +struct BoundedStream { + record_batch: RecordBatch, + count: usize, + limit: usize, +} + +impl Stream for BoundedStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + if self.count >= self.limit { + Poll::Ready(None) + } else { + self.count += 1; + Poll::Ready(Some(Ok(self.record_batch.clone()))) + } + } +} + +impl RecordBatchStream for BoundedStream { + fn schema(&self) -> SchemaRef { + self.record_batch.schema() + } +} + +#[derive(Hash, Eq, PartialEq, PartialOrd, Debug)] +struct CacheNode { + input: LogicalPlan, +} + +impl UserDefinedLogicalNodeCore for CacheNode { + fn name(&self) -> &str { + "CacheNode" + } + + fn inputs(&self) -> Vec<&LogicalPlan> { + vec![&self.input] + } + + fn schema(&self) -> &DFSchemaRef { + self.input.schema() + } + + fn expressions(&self) -> Vec { + vec![] + } + + fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result { + write!(f, "CacheNode") + } + + fn with_exprs_and_inputs( + &self, + _exprs: Vec, + inputs: Vec, + ) -> Result { + assert_eq!(inputs.len(), 1, "input size inconsistent"); + Ok(Self { + input: inputs[0].clone(), + }) + } +} + +#[derive(Debug)] +struct TestCacheFactory {} + +impl CacheFactory for TestCacheFactory { + fn create( + &self, + plan: LogicalPlan, + _session_state: &SessionState, + ) -> Result { + Ok(LogicalPlan::Extension(datafusion_expr::Extension { + node: Arc::new(CacheNode { input: plan }), + })) + } +} + +/// Create a test table registered to a session context with an associated cache factory +pub async fn test_table_with_cache_factory() -> Result { + let session_state = SessionStateBuilder::new() + .with_cache_factory(Some(Arc::new(TestCacheFactory {}))) + .build(); + let ctx = SessionContext::new_with_state(session_state); + let name = 
"aggregate_test_100"; + register_aggregate_csv(&ctx, name).await?; + ctx.table(name).await +} diff --git a/datafusion/core/src/test_util/parquet.rs b/datafusion/core/src/test_util/parquet.rs index eb4c61c025248..dba017f83ba1e 100644 --- a/datafusion/core/src/test_util/parquet.rs +++ b/datafusion/core/src/test_util/parquet.rs @@ -32,16 +32,15 @@ use crate::logical_expr::execution_props::ExecutionProps; use crate::logical_expr::simplify::SimplifyContext; use crate::optimizer::simplify_expressions::ExprSimplifier; use crate::physical_expr::create_physical_expr; +use crate::physical_plan::ExecutionPlan; use crate::physical_plan::filter::FilterExec; use crate::physical_plan::metrics::MetricsSet; -use crate::physical_plan::ExecutionPlan; use crate::prelude::{Expr, SessionConfig, SessionContext}; -use datafusion_datasource::file::FileSource; use datafusion_datasource::file_scan_config::FileScanConfigBuilder; use datafusion_datasource::source::DataSourceExec; -use object_store::path::Path; use object_store::ObjectMeta; +use object_store::path::Path; use parquet::arrow::ArrowWriter; use parquet::file::properties::WriterProperties; @@ -156,26 +155,18 @@ impl TestParquetFile { maybe_filter: Option, ) -> Result> { let parquet_options = ctx.copied_table_options().parquet; - let source = Arc::new(ParquetSource::new(parquet_options.clone())); - let scan_config_builder = FileScanConfigBuilder::new( - self.object_store_url.clone(), - Arc::clone(&self.schema), - source, - ) - .with_file(PartitionedFile { - object_meta: self.object_meta.clone(), - partition_values: vec![], - range: None, - statistics: None, - extensions: None, - metadata_size_hint: None, - }); + let source = Arc::new( + ParquetSource::new(Arc::clone(&self.schema)) + .with_table_parquet_options(parquet_options.clone()), + ); + let scan_config_builder = + FileScanConfigBuilder::new(self.object_store_url.clone(), source) + .with_file(PartitionedFile::new_from_meta(self.object_meta.clone())); let df_schema = 
Arc::clone(&self.schema).to_dfschema_ref()?; // run coercion on the filters to coerce types etc. - let props = ExecutionProps::new(); - let context = SimplifyContext::new(&props).with_schema(Arc::clone(&df_schema)); + let context = SimplifyContext::default().with_schema(Arc::clone(&df_schema)); if let Some(filter) = maybe_filter { let simplifier = ExprSimplifier::new(context); let filter = simplifier.coerce(filter, &df_schema).unwrap(); @@ -183,10 +174,10 @@ impl TestParquetFile { create_physical_expr(&filter, &df_schema, &ExecutionProps::default())?; let source = Arc::new( - ParquetSource::new(parquet_options) + ParquetSource::new(Arc::clone(&self.schema)) + .with_table_parquet_options(parquet_options) .with_predicate(Arc::clone(&physical_filter_expr)), - ) - .with_schema(Arc::clone(&self.schema)); + ); let config = scan_config_builder.with_source(source).build(); let parquet_exec = DataSourceExec::from_data_source(config); @@ -203,13 +194,12 @@ impl TestParquetFile { /// Recursively searches for DataSourceExec and returns the metrics /// on the first one it finds pub fn parquet_metrics(plan: &Arc) -> Option { - if let Some(data_source_exec) = plan.as_any().downcast_ref::() { - if data_source_exec + if let Some(data_source_exec) = plan.as_any().downcast_ref::() + && data_source_exec .downcast_to_file_source::() .is_some() - { - return data_source_exec.metrics(); - } + { + return data_source_exec.metrics(); } for child in plan.children() { diff --git a/datafusion/core/tests/catalog/memory.rs b/datafusion/core/tests/catalog/memory.rs index b0753eb5c9494..5258f3bf97574 100644 --- a/datafusion/core/tests/catalog/memory.rs +++ b/datafusion/core/tests/catalog/memory.rs @@ -19,7 +19,7 @@ use arrow::datatypes::Schema; use datafusion::catalog::CatalogProvider; use datafusion::datasource::empty::EmptyTable; use datafusion::datasource::listing::{ - ListingTable, ListingTableConfig, ListingTableUrl, + ListingTable, ListingTableConfig, ListingTableConfigExt, ListingTableUrl, 
}; use datafusion::prelude::SessionContext; use datafusion_catalog::memory::*; @@ -47,6 +47,20 @@ fn memory_catalog_dereg_nonempty_schema() { assert!(cat.deregister_schema("foo", true).unwrap().is_some()); } +#[test] +fn memory_catalog_dereg_nonempty_schema_with_table_removal() { + let cat = Arc::new(MemoryCatalogProvider::new()) as Arc; + + let schema = Arc::new(MemorySchemaProvider::new()) as Arc; + let test_table = + Arc::new(EmptyTable::new(Arc::new(Schema::empty()))) as Arc; + schema.register_table("t".into(), test_table).unwrap(); + + cat.register_schema("foo", schema.clone()).unwrap(); + schema.deregister_table("t").unwrap(); + assert!(cat.deregister_schema("foo", false).unwrap().is_some()); +} + #[test] fn memory_catalog_dereg_empty_schema() { let cat = Arc::new(MemoryCatalogProvider::new()) as Arc; @@ -102,14 +116,16 @@ async fn test_mem_provider() { assert!(provider.deregister_table(table_name).unwrap().is_none()); let test_table = EmptyTable::new(Arc::new(Schema::empty())); // register table successfully - assert!(provider - .register_table(table_name.to_string(), Arc::new(test_table)) - .unwrap() - .is_none()); + assert!( + provider + .register_table(table_name.to_string(), Arc::new(test_table)) + .unwrap() + .is_none() + ); assert!(provider.table_exist(table_name)); let other_table = EmptyTable::new(Arc::new(Schema::empty())); let result = provider.register_table(table_name.to_string(), Arc::new(other_table)); - assert!(result.is_err()); + assert!(result.is_err(), "The table test_table_exist already exists"); } #[tokio::test] diff --git a/datafusion/core/tests/catalog_listing/mod.rs b/datafusion/core/tests/catalog_listing/mod.rs new file mode 100644 index 0000000000000..cb6cac4fb0672 --- /dev/null +++ b/datafusion/core/tests/catalog_listing/mod.rs @@ -0,0 +1,18 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod pruned_partition_list; diff --git a/datafusion/core/tests/catalog_listing/pruned_partition_list.rs b/datafusion/core/tests/catalog_listing/pruned_partition_list.rs new file mode 100644 index 0000000000000..8f93dc17dbad2 --- /dev/null +++ b/datafusion/core/tests/catalog_listing/pruned_partition_list.rs @@ -0,0 +1,251 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::sync::Arc; + +use arrow_schema::DataType; +use futures::{FutureExt, StreamExt as _, TryStreamExt as _}; +use object_store::{ObjectStoreExt, memory::InMemory, path::Path}; + +use datafusion::execution::SessionStateBuilder; +use datafusion_catalog_listing::helpers::{ + describe_partition, list_partitions, pruned_partition_list, +}; +use datafusion_common::ScalarValue; +use datafusion_datasource::ListingTableUrl; +use datafusion_expr::{Expr, col, lit}; +use datafusion_session::Session; + +#[tokio::test] +async fn test_pruned_partition_list_empty() { + let (store, state) = make_test_store_and_state(&[ + ("tablepath/mypartition=val1/notparquetfile", 100), + ("tablepath/mypartition=val1/ignoresemptyfile.parquet", 0), + ("tablepath/file.parquet", 100), + ("tablepath/notapartition/file.parquet", 100), + ("tablepath/notmypartition=val1/file.parquet", 100), + ]); + let filter = Expr::eq(col("mypartition"), lit("val1")); + let pruned = pruned_partition_list( + state.as_ref(), + store.as_ref(), + &ListingTableUrl::parse("file:///tablepath/").unwrap(), + &[filter], + ".parquet", + &[(String::from("mypartition"), DataType::Utf8)], + ) + .await + .expect("partition pruning failed") + .collect::>() + .await; + + assert_eq!(pruned.len(), 0); +} + +#[tokio::test] +async fn test_pruned_partition_list() { + let (store, state) = make_test_store_and_state(&[ + ("tablepath/mypartition=val1/file.parquet", 100), + ("tablepath/mypartition=val2/file.parquet", 100), + ("tablepath/mypartition=val1/ignoresemptyfile.parquet", 0), + ("tablepath/mypartition=val1/other=val3/file.parquet", 100), + ("tablepath/notapartition/file.parquet", 100), + ("tablepath/notmypartition=val1/file.parquet", 100), + ]); + let filter = Expr::eq(col("mypartition"), lit("val1")); + let pruned = pruned_partition_list( + state.as_ref(), + store.as_ref(), + &ListingTableUrl::parse("file:///tablepath/").unwrap(), + &[filter], + ".parquet", + &[(String::from("mypartition"), DataType::Utf8)], + ) + .await + 
.expect("partition pruning failed") + .try_collect::>() + .await + .unwrap(); + + assert_eq!(pruned.len(), 2); + let f1 = &pruned[0]; + assert_eq!( + f1.object_meta.location.as_ref(), + "tablepath/mypartition=val1/file.parquet" + ); + assert_eq!(&f1.partition_values, &[ScalarValue::from("val1")]); + let f2 = &pruned[1]; + assert_eq!( + f2.object_meta.location.as_ref(), + "tablepath/mypartition=val1/other=val3/file.parquet" + ); + assert_eq!(f2.partition_values, &[ScalarValue::from("val1"),]); +} + +#[tokio::test] +async fn test_pruned_partition_list_multi() { + let (store, state) = make_test_store_and_state(&[ + ("tablepath/part1=p1v1/file.parquet", 100), + ("tablepath/part1=p1v2/part2=p2v1/file1.parquet", 100), + ("tablepath/part1=p1v2/part2=p2v1/file2.parquet", 100), + ("tablepath/part1=p1v3/part2=p2v1/file2.parquet", 100), + ("tablepath/part1=p1v2/part2=p2v2/file2.parquet", 100), + ]); + let filter1 = Expr::eq(col("part1"), lit("p1v2")); + let filter2 = Expr::eq(col("part2"), lit("p2v1")); + let pruned = pruned_partition_list( + state.as_ref(), + store.as_ref(), + &ListingTableUrl::parse("file:///tablepath/").unwrap(), + &[filter1, filter2], + ".parquet", + &[ + (String::from("part1"), DataType::Utf8), + (String::from("part2"), DataType::Utf8), + ], + ) + .await + .expect("partition pruning failed") + .try_collect::>() + .await + .unwrap(); + + assert_eq!(pruned.len(), 2); + let f1 = &pruned[0]; + assert_eq!( + f1.object_meta.location.as_ref(), + "tablepath/part1=p1v2/part2=p2v1/file1.parquet" + ); + assert_eq!( + &f1.partition_values, + &[ScalarValue::from("p1v2"), ScalarValue::from("p2v1"),] + ); + let f2 = &pruned[1]; + assert_eq!( + f2.object_meta.location.as_ref(), + "tablepath/part1=p1v2/part2=p2v1/file2.parquet" + ); + assert_eq!( + &f2.partition_values, + &[ScalarValue::from("p1v2"), ScalarValue::from("p2v1")] + ); +} + +#[tokio::test] +async fn test_list_partition() { + let (store, _) = make_test_store_and_state(&[ + 
("tablepath/part1=p1v1/file.parquet", 100), + ("tablepath/part1=p1v2/part2=p2v1/file1.parquet", 100), + ("tablepath/part1=p1v2/part2=p2v1/file2.parquet", 100), + ("tablepath/part1=p1v3/part2=p2v1/file3.parquet", 100), + ("tablepath/part1=p1v2/part2=p2v2/file4.parquet", 100), + ("tablepath/part1=p1v2/part2=p2v2/empty.parquet", 0), + ]); + + let partitions = list_partitions( + store.as_ref(), + &ListingTableUrl::parse("file:///tablepath/").unwrap(), + 0, + None, + ) + .await + .expect("listing partitions failed"); + + assert_eq!( + &partitions + .iter() + .map(describe_partition) + .collect::>(), + &vec![ + ("tablepath", 0, vec![]), + ("tablepath/part1=p1v1", 1, vec![]), + ("tablepath/part1=p1v2", 1, vec![]), + ("tablepath/part1=p1v3", 1, vec![]), + ] + ); + + let partitions = list_partitions( + store.as_ref(), + &ListingTableUrl::parse("file:///tablepath/").unwrap(), + 1, + None, + ) + .await + .expect("listing partitions failed"); + + assert_eq!( + &partitions + .iter() + .map(describe_partition) + .collect::>(), + &vec![ + ("tablepath", 0, vec![]), + ("tablepath/part1=p1v1", 1, vec!["file.parquet"]), + ("tablepath/part1=p1v2", 1, vec![]), + ("tablepath/part1=p1v2/part2=p2v1", 2, vec![]), + ("tablepath/part1=p1v2/part2=p2v2", 2, vec![]), + ("tablepath/part1=p1v3", 1, vec![]), + ("tablepath/part1=p1v3/part2=p2v1", 2, vec![]), + ] + ); + + let partitions = list_partitions( + store.as_ref(), + &ListingTableUrl::parse("file:///tablepath/").unwrap(), + 2, + None, + ) + .await + .expect("listing partitions failed"); + + assert_eq!( + &partitions + .iter() + .map(describe_partition) + .collect::>(), + &vec![ + ("tablepath", 0, vec![]), + ("tablepath/part1=p1v1", 1, vec!["file.parquet"]), + ("tablepath/part1=p1v2", 1, vec![]), + ("tablepath/part1=p1v3", 1, vec![]), + ( + "tablepath/part1=p1v2/part2=p2v1", + 2, + vec!["file1.parquet", "file2.parquet"] + ), + ("tablepath/part1=p1v2/part2=p2v2", 2, vec!["file4.parquet"]), + ("tablepath/part1=p1v3/part2=p2v1", 2, 
vec!["file3.parquet"]), + ] + ); +} + +pub fn make_test_store_and_state( + files: &[(&str, u64)], +) -> (Arc, Arc) { + let memory = InMemory::new(); + + for (name, size) in files { + memory + .put(&Path::from(*name), vec![0; *size as usize].into()) + .now_or_never() + .unwrap() + .unwrap(); + } + + let state = SessionStateBuilder::new().build(); + (Arc::new(memory), Arc::new(state)) +} diff --git a/datafusion/core/tests/config_from_env.rs b/datafusion/core/tests/config_from_env.rs index 976597c8a9ac5..6375d4e25d8eb 100644 --- a/datafusion/core/tests/config_from_env.rs +++ b/datafusion/core/tests/config_from_env.rs @@ -20,35 +20,43 @@ use std::env; #[test] fn from_env() { - // Note: these must be a single test to avoid interference from concurrent execution - let env_key = "DATAFUSION_OPTIMIZER_FILTER_NULL_JOIN_KEYS"; - // valid testing in different cases - for bool_option in ["true", "TRUE", "True", "tRUe"] { - env::set_var(env_key, bool_option); - let config = ConfigOptions::from_env().unwrap(); - env::remove_var(env_key); - assert!(config.optimizer.filter_null_join_keys); - } + unsafe { + // Note: these must be a single test to avoid interference from concurrent execution + let env_key = "DATAFUSION_OPTIMIZER_FILTER_NULL_JOIN_KEYS"; + // valid testing in different cases + for bool_option in ["true", "TRUE", "True", "tRUe"] { + env::set_var(env_key, bool_option); + let config = ConfigOptions::from_env().unwrap(); + env::remove_var(env_key); + assert!(config.optimizer.filter_null_join_keys); + } - // invalid testing - env::set_var(env_key, "ttruee"); - let err = ConfigOptions::from_env().unwrap_err().strip_backtrace(); - assert_eq!(err, "Error parsing 'ttruee' as bool\ncaused by\nExternal error: provided string was not `true` or `false`"); - env::remove_var(env_key); + // invalid testing + env::set_var(env_key, "ttruee"); + let err = ConfigOptions::from_env().unwrap_err().strip_backtrace(); + assert_eq!( + err, + "Error parsing 'ttruee' as bool\ncaused by\nExternal 
error: provided string was not `true` or `false`" + ); + env::remove_var(env_key); - let env_key = "DATAFUSION_EXECUTION_BATCH_SIZE"; + let env_key = "DATAFUSION_EXECUTION_BATCH_SIZE"; - // for valid testing - env::set_var(env_key, "4096"); - let config = ConfigOptions::from_env().unwrap(); - assert_eq!(config.execution.batch_size, 4096); + // for valid testing + env::set_var(env_key, "4096"); + let config = ConfigOptions::from_env().unwrap(); + assert_eq!(config.execution.batch_size, 4096); - // for invalid testing - env::set_var(env_key, "abc"); - let err = ConfigOptions::from_env().unwrap_err().strip_backtrace(); - assert_eq!(err, "Error parsing 'abc' as usize\ncaused by\nExternal error: invalid digit found in string"); + // for invalid testing + env::set_var(env_key, "abc"); + let err = ConfigOptions::from_env().unwrap_err().strip_backtrace(); + assert_eq!( + err, + "Error parsing 'abc' as usize\ncaused by\nExternal error: invalid digit found in string" + ); - env::remove_var(env_key); - let config = ConfigOptions::from_env().unwrap(); - assert_eq!(config.execution.batch_size, 8192); // set to its default value + env::remove_var(env_key); + let config = ConfigOptions::from_env().unwrap(); + assert_eq!(config.execution.batch_size, 8192); // set to its default value + } } diff --git a/datafusion/core/tests/core_integration.rs b/datafusion/core/tests/core_integration.rs index 250538b133703..bdbe72245323d 100644 --- a/datafusion/core/tests/core_integration.rs +++ b/datafusion/core/tests/core_integration.rs @@ -21,6 +21,9 @@ mod sql; /// Run all tests that are found in the `dataframe` directory mod dataframe; +/// Run all tests that are found in the `datasource` directory +mod datasource; + /// Run all tests that are found in the `macro_hygiene` directory mod macro_hygiene; @@ -51,6 +54,9 @@ mod serde; /// Run all tests that are found in the `catalog` directory mod catalog; +/// Run all tests that are found in the `catalog_listing` directory +mod catalog_listing; + 
/// Run all tests that are found in the `tracing` directory mod tracing; diff --git a/datafusion/core/tests/custom_sources_cases/dml_planning.rs b/datafusion/core/tests/custom_sources_cases/dml_planning.rs new file mode 100644 index 0000000000000..8c4bae5e98b36 --- /dev/null +++ b/datafusion/core/tests/custom_sources_cases/dml_planning.rs @@ -0,0 +1,819 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Tests for DELETE, UPDATE, and TRUNCATE planning to verify filter and assignment extraction. + +use std::any::Any; +use std::sync::{Arc, Mutex}; + +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use async_trait::async_trait; +use datafusion::datasource::{TableProvider, TableType}; +use datafusion::error::Result; +use datafusion::execution::context::{SessionConfig, SessionContext}; +use datafusion::logical_expr::{ + Expr, LogicalPlan, TableProviderFilterPushDown, TableScan, +}; +use datafusion_catalog::Session; +use datafusion_common::ScalarValue; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; +use datafusion_physical_plan::ExecutionPlan; +use datafusion_physical_plan::empty::EmptyExec; + +/// A TableProvider that captures the filters passed to delete_from(). 
+struct CaptureDeleteProvider { + schema: SchemaRef, + received_filters: Arc>>>, + filter_pushdown: TableProviderFilterPushDown, + per_filter_pushdown: Option>, +} + +impl CaptureDeleteProvider { + fn new(schema: SchemaRef) -> Self { + Self { + schema, + received_filters: Arc::new(Mutex::new(None)), + filter_pushdown: TableProviderFilterPushDown::Unsupported, + per_filter_pushdown: None, + } + } + + fn new_with_filter_pushdown( + schema: SchemaRef, + filter_pushdown: TableProviderFilterPushDown, + ) -> Self { + Self { + schema, + received_filters: Arc::new(Mutex::new(None)), + filter_pushdown, + per_filter_pushdown: None, + } + } + + fn new_with_per_filter_pushdown( + schema: SchemaRef, + per_filter_pushdown: Vec, + ) -> Self { + Self { + schema, + received_filters: Arc::new(Mutex::new(None)), + filter_pushdown: TableProviderFilterPushDown::Unsupported, + per_filter_pushdown: Some(per_filter_pushdown), + } + } + + fn captured_filters(&self) -> Option> { + self.received_filters.lock().unwrap().clone() + } +} + +impl std::fmt::Debug for CaptureDeleteProvider { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("CaptureDeleteProvider") + .field("schema", &self.schema) + .finish() + } +} + +#[async_trait] +impl TableProvider for CaptureDeleteProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } + + fn table_type(&self) -> TableType { + TableType::Base + } + + async fn scan( + &self, + _state: &dyn Session, + _projection: Option<&Vec>, + _filters: &[Expr], + _limit: Option, + ) -> Result> { + Ok(Arc::new(EmptyExec::new(Arc::clone(&self.schema)))) + } + + async fn delete_from( + &self, + _state: &dyn Session, + filters: Vec, + ) -> Result> { + *self.received_filters.lock().unwrap() = Some(filters); + Ok(Arc::new(EmptyExec::new(Arc::new(Schema::new(vec![ + Field::new("count", DataType::UInt64, false), + ]))))) + } + + fn supports_filters_pushdown( + &self, + filters: 
&[&Expr], + ) -> Result> { + if let Some(per_filter) = &self.per_filter_pushdown + && per_filter.len() == filters.len() + { + return Ok(per_filter.clone()); + } + + Ok(vec![self.filter_pushdown.clone(); filters.len()]) + } +} + +/// A TableProvider that captures filters and assignments passed to update(). +#[expect(clippy::type_complexity)] +struct CaptureUpdateProvider { + schema: SchemaRef, + received_filters: Arc>>>, + received_assignments: Arc>>>, + filter_pushdown: TableProviderFilterPushDown, + per_filter_pushdown: Option>, +} + +impl CaptureUpdateProvider { + fn new(schema: SchemaRef) -> Self { + Self { + schema, + received_filters: Arc::new(Mutex::new(None)), + received_assignments: Arc::new(Mutex::new(None)), + filter_pushdown: TableProviderFilterPushDown::Unsupported, + per_filter_pushdown: None, + } + } + + fn new_with_filter_pushdown( + schema: SchemaRef, + filter_pushdown: TableProviderFilterPushDown, + ) -> Self { + Self { + schema, + received_filters: Arc::new(Mutex::new(None)), + received_assignments: Arc::new(Mutex::new(None)), + filter_pushdown, + per_filter_pushdown: None, + } + } + + fn captured_filters(&self) -> Option> { + self.received_filters.lock().unwrap().clone() + } + + fn captured_assignments(&self) -> Option> { + self.received_assignments.lock().unwrap().clone() + } +} + +impl std::fmt::Debug for CaptureUpdateProvider { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("CaptureUpdateProvider") + .field("schema", &self.schema) + .finish() + } +} + +#[async_trait] +impl TableProvider for CaptureUpdateProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } + + fn table_type(&self) -> TableType { + TableType::Base + } + + async fn scan( + &self, + _state: &dyn Session, + _projection: Option<&Vec>, + _filters: &[Expr], + _limit: Option, + ) -> Result> { + Ok(Arc::new(EmptyExec::new(Arc::clone(&self.schema)))) + } + + async fn update( + 
&self, + _state: &dyn Session, + assignments: Vec<(String, Expr)>, + filters: Vec, + ) -> Result> { + *self.received_filters.lock().unwrap() = Some(filters); + *self.received_assignments.lock().unwrap() = Some(assignments); + Ok(Arc::new(EmptyExec::new(Arc::new(Schema::new(vec![ + Field::new("count", DataType::UInt64, false), + ]))))) + } + + fn supports_filters_pushdown( + &self, + filters: &[&Expr], + ) -> Result> { + if let Some(per_filter) = &self.per_filter_pushdown + && per_filter.len() == filters.len() + { + return Ok(per_filter.clone()); + } + + Ok(vec![self.filter_pushdown.clone(); filters.len()]) + } +} + +/// A TableProvider that captures whether truncate() was called. +struct CaptureTruncateProvider { + schema: SchemaRef, + truncate_called: Arc>, +} + +impl CaptureTruncateProvider { + fn new(schema: SchemaRef) -> Self { + Self { + schema, + truncate_called: Arc::new(Mutex::new(false)), + } + } + + fn was_truncated(&self) -> bool { + *self.truncate_called.lock().unwrap() + } +} + +impl std::fmt::Debug for CaptureTruncateProvider { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("CaptureTruncateProvider") + .field("schema", &self.schema) + .finish() + } +} + +#[async_trait] +impl TableProvider for CaptureTruncateProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } + + fn table_type(&self) -> TableType { + TableType::Base + } + + async fn scan( + &self, + _state: &dyn Session, + _projection: Option<&Vec>, + _filters: &[Expr], + _limit: Option, + ) -> Result> { + Ok(Arc::new(EmptyExec::new(Arc::clone(&self.schema)))) + } + + async fn truncate(&self, _state: &dyn Session) -> Result> { + *self.truncate_called.lock().unwrap() = true; + + Ok(Arc::new(EmptyExec::new(Arc::new(Schema::new(vec![ + Field::new("count", DataType::UInt64, false), + ]))))) + } +} + +fn test_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, 
false), + Field::new("status", DataType::Utf8, true), + Field::new("value", DataType::Int32, true), + ])) +} + +#[tokio::test] +async fn test_delete_single_filter() -> Result<()> { + let provider = Arc::new(CaptureDeleteProvider::new(test_schema())); + let ctx = SessionContext::new(); + ctx.register_table("t", Arc::clone(&provider) as Arc)?; + + ctx.sql("DELETE FROM t WHERE id = 1") + .await? + .collect() + .await?; + + let filters = provider + .captured_filters() + .expect("filters should be captured"); + assert_eq!(filters.len(), 1); + assert!(filters[0].to_string().contains("id")); + Ok(()) +} + +#[tokio::test] +async fn test_delete_multiple_filters() -> Result<()> { + let provider = Arc::new(CaptureDeleteProvider::new(test_schema())); + let ctx = SessionContext::new(); + ctx.register_table("t", Arc::clone(&provider) as Arc)?; + + ctx.sql("DELETE FROM t WHERE id = 1 AND status = 'x'") + .await? + .collect() + .await?; + + let filters = provider + .captured_filters() + .expect("filters should be captured"); + assert!(!filters.is_empty()); + Ok(()) +} + +#[tokio::test] +async fn test_delete_no_filters() -> Result<()> { + let provider = Arc::new(CaptureDeleteProvider::new(test_schema())); + let ctx = SessionContext::new(); + ctx.register_table("t", Arc::clone(&provider) as Arc)?; + + ctx.sql("DELETE FROM t").await?.collect().await?; + + let filters = provider + .captured_filters() + .expect("filters should be captured"); + assert!( + filters.is_empty(), + "DELETE without WHERE should have empty filters" + ); + Ok(()) +} + +#[tokio::test] +async fn test_delete_complex_expr() -> Result<()> { + let provider = Arc::new(CaptureDeleteProvider::new(test_schema())); + let ctx = SessionContext::new(); + ctx.register_table("t", Arc::clone(&provider) as Arc)?; + + ctx.sql("DELETE FROM t WHERE id > 5 AND (status = 'a' OR status = 'b')") + .await? 
+ .collect() + .await?; + + let filters = provider + .captured_filters() + .expect("filters should be captured"); + assert!(!filters.is_empty()); + Ok(()) +} + +#[tokio::test] +async fn test_delete_filter_pushdown_extracts_table_scan_filters() -> Result<()> { + let provider = Arc::new(CaptureDeleteProvider::new_with_filter_pushdown( + test_schema(), + TableProviderFilterPushDown::Exact, + )); + let ctx = SessionContext::new(); + ctx.register_table("t", Arc::clone(&provider) as Arc)?; + + let df = ctx.sql("DELETE FROM t WHERE id = 1").await?; + let optimized_plan = df.clone().into_optimized_plan()?; + + let mut scan_filters = Vec::new(); + optimized_plan.apply(|node| { + if let LogicalPlan::TableScan(TableScan { filters, .. }) = node { + scan_filters.extend(filters.clone()); + } + Ok(TreeNodeRecursion::Continue) + })?; + + assert_eq!(scan_filters.len(), 1); + assert!(scan_filters[0].to_string().contains("id")); + + df.collect().await?; + + let filters = provider + .captured_filters() + .expect("filters should be captured"); + assert_eq!(filters.len(), 1); + assert!(filters[0].to_string().contains("id")); + Ok(()) +} + +#[tokio::test] +async fn test_delete_compound_filters_with_pushdown() -> Result<()> { + let provider = Arc::new(CaptureDeleteProvider::new_with_filter_pushdown( + test_schema(), + TableProviderFilterPushDown::Exact, + )); + let ctx = SessionContext::new(); + ctx.register_table("t", Arc::clone(&provider) as Arc)?; + + ctx.sql("DELETE FROM t WHERE id = 1 AND status = 'active'") + .await? 
+ .collect() + .await?; + + let filters = provider + .captured_filters() + .expect("filters should be captured"); + // Should receive both filters, not deduplicate valid separate predicates + assert_eq!( + filters.len(), + 2, + "compound filters should not be over-suppressed" + ); + + let filter_strs: Vec = filters.iter().map(|f| f.to_string()).collect(); + assert!( + filter_strs.iter().any(|s| s.contains("id")), + "should contain id filter" + ); + assert!( + filter_strs.iter().any(|s| s.contains("status")), + "should contain status filter" + ); + Ok(()) +} + +#[tokio::test] +async fn test_delete_mixed_filter_locations() -> Result<()> { + // Test mixed-location filters: some in Filter node, some in TableScan.filters + // This happens when provider uses TableProviderFilterPushDown::Inexact, + // meaning it can push down some predicates but not others. + let provider = Arc::new(CaptureDeleteProvider::new_with_filter_pushdown( + test_schema(), + TableProviderFilterPushDown::Inexact, + )); + let ctx = SessionContext::new(); + ctx.register_table("t", Arc::clone(&provider) as Arc)?; + + // Execute DELETE with compound WHERE clause + ctx.sql("DELETE FROM t WHERE id = 1 AND status = 'active'") + .await? 
+ .collect() + .await?; + + // Verify that both predicates are extracted and passed to delete_from(), + // even though they may be split between Filter node and TableScan.filters + let filters = provider + .captured_filters() + .expect("filters should be captured"); + assert_eq!( + filters.len(), + 2, + "should extract both predicates (union of Filter and TableScan.filters)" + ); + + let filter_strs: Vec = filters.iter().map(|f| f.to_string()).collect(); + assert!( + filter_strs.iter().any(|s| s.contains("id")), + "should contain id filter" + ); + assert!( + filter_strs.iter().any(|s| s.contains("status")), + "should contain status filter" + ); + Ok(()) +} + +#[tokio::test] +async fn test_delete_per_filter_pushdown_mixed_locations() -> Result<()> { + // Force per-filter pushdown decisions to exercise mixed locations in one query. + // First predicate is pushed down (Exact), second stays as residual (Unsupported). + let provider = Arc::new(CaptureDeleteProvider::new_with_per_filter_pushdown( + test_schema(), + vec![ + TableProviderFilterPushDown::Exact, + TableProviderFilterPushDown::Unsupported, + ], + )); + + let ctx = SessionContext::new(); + ctx.register_table("t", Arc::clone(&provider) as Arc)?; + + let df = ctx + .sql("DELETE FROM t WHERE id = 1 AND status = 'active'") + .await?; + let optimized_plan = df.clone().into_optimized_plan()?; + + // Only the first predicate should be pushed to TableScan.filters. + let mut scan_filters = Vec::new(); + optimized_plan.apply(|node| { + if let LogicalPlan::TableScan(TableScan { filters, .. }) = node { + scan_filters.extend(filters.clone()); + } + Ok(TreeNodeRecursion::Continue) + })?; + assert_eq!(scan_filters.len(), 1); + assert!(scan_filters[0].to_string().contains("id")); + + // Both predicates should still reach the provider (union + dedup behavior). 
+ df.collect().await?; + + let filters = provider + .captured_filters() + .expect("filters should be captured"); + assert_eq!(filters.len(), 2); + + let filter_strs: Vec = filters.iter().map(|f| f.to_string()).collect(); + assert!( + filter_strs.iter().any(|s| s.contains("id")), + "should contain pushed-down id filter" + ); + assert!( + filter_strs.iter().any(|s| s.contains("status")), + "should contain residual status filter" + ); + + Ok(()) +} + +#[tokio::test] +async fn test_update_assignments() -> Result<()> { + let provider = Arc::new(CaptureUpdateProvider::new(test_schema())); + let ctx = SessionContext::new(); + ctx.register_table("t", Arc::clone(&provider) as Arc)?; + + ctx.sql("UPDATE t SET value = 100, status = 'updated' WHERE id = 5") + .await? + .collect() + .await?; + + let assignments = provider + .captured_assignments() + .expect("assignments should be captured"); + assert_eq!(assignments.len(), 2, "should have 2 assignments"); + + let filters = provider + .captured_filters() + .expect("filters should be captured"); + assert!(!filters.is_empty(), "should have filter for WHERE clause"); + Ok(()) +} + +#[tokio::test] +async fn test_update_filter_pushdown_extracts_table_scan_filters() -> Result<()> { + let provider = Arc::new(CaptureUpdateProvider::new_with_filter_pushdown( + test_schema(), + TableProviderFilterPushDown::Exact, + )); + let ctx = SessionContext::new(); + ctx.register_table("t", Arc::clone(&provider) as Arc)?; + + let df = ctx.sql("UPDATE t SET value = 100 WHERE id = 1").await?; + let optimized_plan = df.clone().into_optimized_plan()?; + + // Verify that the optimizer pushed down the filter into TableScan + let mut scan_filters = Vec::new(); + optimized_plan.apply(|node| { + if let LogicalPlan::TableScan(TableScan { filters, .. 
}) = node { + scan_filters.extend(filters.clone()); + } + Ok(TreeNodeRecursion::Continue) + })?; + + assert_eq!(scan_filters.len(), 1); + assert!(scan_filters[0].to_string().contains("id")); + + // Execute the UPDATE and verify filters were extracted and passed to update() + df.collect().await?; + + let filters = provider + .captured_filters() + .expect("filters should be captured"); + assert_eq!(filters.len(), 1); + assert!(filters[0].to_string().contains("id")); + Ok(()) +} + +#[tokio::test] +async fn test_update_filter_pushdown_passes_table_scan_filters() -> Result<()> { + let provider = Arc::new(CaptureUpdateProvider::new_with_filter_pushdown( + test_schema(), + TableProviderFilterPushDown::Exact, + )); + let ctx = SessionContext::new(); + ctx.register_table("t", Arc::clone(&provider) as Arc)?; + + let df = ctx + .sql("UPDATE t SET value = 42 WHERE status = 'ready'") + .await?; + let optimized_plan = df.clone().into_optimized_plan()?; + + let mut scan_filters = Vec::new(); + optimized_plan.apply(|node| { + if let LogicalPlan::TableScan(TableScan { filters, .. 
}) = node { + scan_filters.extend(filters.clone()); + } + Ok(TreeNodeRecursion::Continue) + })?; + + assert!( + !scan_filters.is_empty(), + "expected filter pushdown to populate TableScan filters" + ); + + df.collect().await?; + + let filters = provider + .captured_filters() + .expect("filters should be captured"); + assert!( + !filters.is_empty(), + "expected filters extracted from TableScan during UPDATE" + ); + Ok(()) +} + +#[tokio::test] +async fn test_truncate_calls_provider() -> Result<()> { + let provider = Arc::new(CaptureTruncateProvider::new(test_schema())); + let config = SessionConfig::new().set( + "datafusion.optimizer.max_passes", + &ScalarValue::UInt64(Some(0)), + ); + + let ctx = SessionContext::new_with_config(config); + + ctx.register_table("t", Arc::clone(&provider) as Arc)?; + + ctx.sql("TRUNCATE TABLE t").await?.collect().await?; + + assert!( + provider.was_truncated(), + "truncate() should be called on the TableProvider" + ); + + Ok(()) +} + +#[tokio::test] +async fn test_unsupported_table_delete() -> Result<()> { + let schema = test_schema(); + let ctx = SessionContext::new(); + + let empty_table = datafusion::datasource::empty::EmptyTable::new(schema); + ctx.register_table("empty_t", Arc::new(empty_table))?; + + let result = ctx.sql("DELETE FROM empty_t WHERE id = 1").await; + assert!(result.is_err() || result.unwrap().collect().await.is_err()); + Ok(()) +} + +#[tokio::test] +async fn test_unsupported_table_update() -> Result<()> { + let schema = test_schema(); + let ctx = SessionContext::new(); + + let empty_table = datafusion::datasource::empty::EmptyTable::new(schema); + ctx.register_table("empty_t", Arc::new(empty_table))?; + + let result = ctx.sql("UPDATE empty_t SET value = 1 WHERE id = 1").await; + + assert!(result.is_err() || result.unwrap().collect().await.is_err()); + Ok(()) +} + +#[tokio::test] +async fn test_delete_target_table_scoping() -> Result<()> { + // Test that DELETE only extracts filters from the target table, + // not 
from other tables (important for DELETE...FROM safety) + let target_provider = Arc::new(CaptureDeleteProvider::new_with_filter_pushdown( + test_schema(), + TableProviderFilterPushDown::Exact, + )); + let ctx = SessionContext::new(); + ctx.register_table( + "target_t", + Arc::clone(&target_provider) as Arc, + )?; + + // For now, we test single-table DELETE + // and validate that the scoping logic is correct + let df = ctx.sql("DELETE FROM target_t WHERE id > 5").await?; + df.collect().await?; + + let filters = target_provider + .captured_filters() + .expect("filters should be captured"); + assert_eq!(filters.len(), 1); + assert!( + filters[0].to_string().contains("id"), + "Filter should be for id column" + ); + assert!( + filters[0].to_string().contains("5"), + "Filter should contain the value 5" + ); + Ok(()) +} + +#[tokio::test] +async fn test_update_from_drops_non_target_predicates() -> Result<()> { + // UPDATE ... FROM is currently not working + // TODO fix https://github.com/apache/datafusion/issues/19950 + let target_provider = Arc::new(CaptureUpdateProvider::new_with_filter_pushdown( + test_schema(), + TableProviderFilterPushDown::Exact, + )); + let ctx = SessionContext::new(); + ctx.register_table("t1", Arc::clone(&target_provider) as Arc)?; + + let source_schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("status", DataType::Utf8, true), + // t2-only column to avoid false negatives after qualifier stripping + Field::new("src_only", DataType::Utf8, true), + ])); + let source_table = datafusion::datasource::empty::EmptyTable::new(source_schema); + ctx.register_table("t2", Arc::new(source_table))?; + + let result = ctx + .sql( + "UPDATE t1 SET value = 1 FROM t2 \ + WHERE t1.id = t2.id AND t2.src_only = 'active' AND t1.value > 10", + ) + .await; + + // Verify UPDATE ... 
FROM is rejected with appropriate error + // TODO fix https://github.com/apache/datafusion/issues/19950 + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!( + err.to_string().contains("UPDATE ... FROM is not supported"), + "Expected 'UPDATE ... FROM is not supported' error, got: {err}" + ); + Ok(()) +} + +#[tokio::test] +async fn test_delete_qualifier_stripping_and_validation() -> Result<()> { + // Test that filter qualifiers are properly stripped and validated + // Unqualified predicates should work fine + let provider = Arc::new(CaptureDeleteProvider::new_with_filter_pushdown( + test_schema(), + TableProviderFilterPushDown::Exact, + )); + let ctx = SessionContext::new(); + ctx.register_table("t", Arc::clone(&provider) as Arc)?; + + // Execute DELETE with unqualified column reference + // (After parsing, the planner adds qualifiers, but our validation should accept them) + let df = ctx.sql("DELETE FROM t WHERE id = 1").await?; + df.collect().await?; + + let filters = provider + .captured_filters() + .expect("filters should be captured"); + assert!(!filters.is_empty(), "Should have extracted filter"); + + // Verify qualifiers are stripped: check that Column expressions have no qualifier + let has_qualified_column = filters[0] + .exists(|expr| Ok(matches!(expr, Expr::Column(col) if col.relation.is_some())))?; + assert!( + !has_qualified_column, + "Filter should have unqualified columns after stripping" + ); + + // Also verify the string representation doesn't contain table qualifiers + let filter_str = filters[0].to_string(); + assert!( + !filter_str.contains("t.id"), + "Filter should not contain qualified column reference, got: {filter_str}" + ); + assert!( + filter_str.contains("id") || filter_str.contains("1"), + "Filter should reference id column or the value 1, got: {filter_str}" + ); + Ok(()) +} + +#[tokio::test] +async fn test_unsupported_table_truncate() -> Result<()> { + let schema = test_schema(); + let ctx = SessionContext::new(); + + 
let empty_table = datafusion::datasource::empty::EmptyTable::new(schema); + ctx.register_table("empty_t", Arc::new(empty_table))?; + + let result = ctx.sql("TRUNCATE TABLE empty_t").await; + + assert!(result.is_err() || result.unwrap().collect().await.is_err()); + + Ok(()) +} diff --git a/datafusion/core/tests/custom_sources_cases/mod.rs b/datafusion/core/tests/custom_sources_cases/mod.rs index cbdc4a448ea41..6919d9794b29e 100644 --- a/datafusion/core/tests/custom_sources_cases/mod.rs +++ b/datafusion/core/tests/custom_sources_cases/mod.rs @@ -28,25 +28,27 @@ use datafusion::datasource::{TableProvider, TableType}; use datafusion::error::Result; use datafusion::execution::context::{SessionContext, TaskContext}; use datafusion::logical_expr::{ - col, Expr, LogicalPlan, LogicalPlanBuilder, TableScan, UNNAMED_TABLE, + Expr, LogicalPlan, LogicalPlanBuilder, TableScan, UNNAMED_TABLE, col, }; use datafusion::physical_plan::{ - collect, ColumnStatistics, DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, - RecordBatchStream, SendableRecordBatchStream, Statistics, + ColumnStatistics, DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, + RecordBatchStream, SendableRecordBatchStream, Statistics, collect, }; use datafusion::scalar::ScalarValue; use datafusion_catalog::Session; use datafusion_common::cast::as_primitive_array; use datafusion_common::project_schema; use datafusion_common::stats::Precision; +use datafusion_common::tree_node::TreeNodeRecursion; use datafusion_physical_expr::EquivalenceProperties; +use datafusion_physical_plan::PlanProperties; use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; -use datafusion_physical_plan::PlanProperties; use async_trait::async_trait; use futures::stream::Stream; +mod dml_planning; mod provider_filter_pushdown; mod statistics; @@ -78,7 +80,7 @@ struct CustomTableProvider; #[derive(Debug, Clone)] struct CustomExecutionPlan { 
projection: Option>, - cache: PlanProperties, + cache: Arc, } impl CustomExecutionPlan { @@ -87,7 +89,10 @@ impl CustomExecutionPlan { let schema = project_schema(&schema, projection.as_ref()).expect("projected schema"); let cache = Self::compute_properties(schema); - Self { projection, cache } + Self { + projection, + cache: Arc::new(cache), + } } /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. @@ -156,7 +161,7 @@ impl ExecutionPlan for CustomExecutionPlan { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -179,16 +184,12 @@ impl ExecutionPlan for CustomExecutionPlan { Ok(Box::pin(TestCustomRecordBatchStream { nb_batch: 1 })) } - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - - fn partition_statistics(&self, partition: Option) -> Result { + fn partition_statistics(&self, partition: Option) -> Result> { if partition.is_some() { - return Ok(Statistics::new_unknown(&self.schema())); + return Ok(Arc::new(Statistics::new_unknown(&self.schema()))); } let batch = TEST_CUSTOM_RECORD_BATCH!().unwrap(); - Ok(Statistics { + Ok(Arc::new(Statistics { num_rows: Precision::Exact(batch.num_rows()), total_byte_size: Precision::Absent, column_statistics: self @@ -207,7 +208,23 @@ impl ExecutionPlan for CustomExecutionPlan { ..Default::default() }) .collect(), - }) + })) + } + + fn apply_expressions( + &self, + f: &mut dyn FnMut( + &dyn datafusion::physical_plan::PhysicalExpr, + ) -> Result, + ) -> Result { + // Visit expressions in the output ordering from equivalence properties + let mut tnr = TreeNodeRecursion::Continue; + if let Some(ordering) = self.cache.output_ordering() { + for sort_expr in ordering { + tnr = tnr.visit_sibling(|| f(sort_expr.expr.as_ref()))?; + } + } + Ok(tnr) } } @@ -316,6 +333,7 @@ async fn optimizers_catch_all_statistics() { assert_eq!(format!("{:?}", actual[0]), 
format!("{expected:?}")); } +#[expect(clippy::needless_pass_by_value)] fn contains_place_holder_exec(plan: Arc) -> bool { if plan.as_any().is::() { true diff --git a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs index f68bcfaf15507..8078b0a7ec158 100644 --- a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs +++ b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs @@ -29,13 +29,14 @@ use datafusion::logical_expr::TableProviderFilterPushDown; use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion::physical_plan::{ DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, - SendableRecordBatchStream, Statistics, + SendableRecordBatchStream, }; use datafusion::prelude::*; use datafusion::scalar::ScalarValue; use datafusion_catalog::Session; use datafusion_common::cast::as_primitive_array; -use datafusion_common::{internal_err, not_impl_err}; +use datafusion_common::tree_node::TreeNodeRecursion; +use datafusion_common::{DataFusionError, internal_err, not_impl_err}; use datafusion_expr::expr::{BinaryExpr, Cast}; use datafusion_functions_aggregate::expr_fn::count; use datafusion_physical_expr::EquivalenceProperties; @@ -62,13 +63,16 @@ fn create_batch(value: i32, num_rows: usize) -> Result { #[derive(Debug)] struct CustomPlan { batches: Vec, - cache: PlanProperties, + cache: Arc, } impl CustomPlan { fn new(schema: SchemaRef, batches: Vec) -> Self { let cache = Self::compute_properties(schema); - Self { batches, cache } + Self { + batches, + cache: Arc::new(cache), + } } /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. 
@@ -109,7 +113,7 @@ impl ExecutionPlan for CustomPlan { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -134,16 +138,36 @@ impl ExecutionPlan for CustomPlan { _partition: usize, _context: Arc, ) -> Result { + let schema_captured = self.schema().clone(); Ok(Box::pin(RecordBatchStreamAdapter::new( self.schema(), - futures::stream::iter(self.batches.clone().into_iter().map(Ok)), + futures::stream::iter(self.batches.clone().into_iter().map(move |batch| { + let projection: Vec = schema_captured + .fields() + .iter() + .filter_map(|field| batch.schema().index_of(field.name()).ok()) + .collect(); + batch + .project(&projection) + .map_err(|e| DataFusionError::ArrowError(Box::new(e), None)) + })), ))) } - fn statistics(&self) -> Result { - // here we could provide more accurate statistics - // but we want to test the filter pushdown not the CBOs - Ok(Statistics::new_unknown(&self.schema())) + fn apply_expressions( + &self, + f: &mut dyn FnMut( + &dyn datafusion::physical_plan::PhysicalExpr, + ) -> Result, + ) -> Result { + // Visit expressions in the output ordering from equivalence properties + let mut tnr = TreeNodeRecursion::Continue; + if let Some(ordering) = self.cache.output_ordering() { + for sort_expr in ordering { + tnr = tnr.visit_sibling(|| f(sort_expr.expr.as_ref()))?; + } + } + Ok(tnr) } } @@ -179,12 +203,12 @@ impl TableProvider for CustomProvider { match &filters[0] { Expr::BinaryExpr(BinaryExpr { right, .. 
}) => { let int_value = match &**right { - Expr::Literal(ScalarValue::Int8(Some(i))) => *i as i64, - Expr::Literal(ScalarValue::Int16(Some(i))) => *i as i64, - Expr::Literal(ScalarValue::Int32(Some(i))) => *i as i64, - Expr::Literal(ScalarValue::Int64(Some(i))) => *i, - Expr::Cast(Cast { expr, data_type: _ }) => match expr.deref() { - Expr::Literal(lit_value) => match lit_value { + Expr::Literal(ScalarValue::Int8(Some(i)), _) => *i as i64, + Expr::Literal(ScalarValue::Int16(Some(i)), _) => *i as i64, + Expr::Literal(ScalarValue::Int32(Some(i)), _) => *i as i64, + Expr::Literal(ScalarValue::Int64(Some(i)), _) => *i, + Expr::Cast(Cast { expr, field: _ }) => match expr.deref() { + Expr::Literal(lit_value, _) => match lit_value { ScalarValue::Int8(Some(v)) => *v as i64, ScalarValue::Int16(Some(v)) => *v as i64, ScalarValue::Int32(Some(v)) => *v as i64, diff --git a/datafusion/core/tests/custom_sources_cases/statistics.rs b/datafusion/core/tests/custom_sources_cases/statistics.rs index f9b0db0e808c0..561c6b3b246ff 100644 --- a/datafusion/core/tests/custom_sources_cases/statistics.rs +++ b/datafusion/core/tests/custom_sources_cases/statistics.rs @@ -33,6 +33,7 @@ use datafusion::{ scalar::ScalarValue, }; use datafusion_catalog::Session; +use datafusion_common::tree_node::TreeNodeRecursion; use datafusion_common::{project_schema, stats::Precision}; use datafusion_physical_expr::EquivalenceProperties; use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; @@ -45,7 +46,7 @@ use async_trait::async_trait; struct StatisticsValidation { stats: Statistics, schema: Arc, - cache: PlanProperties, + cache: Arc, } impl StatisticsValidation { @@ -59,7 +60,7 @@ impl StatisticsValidation { Self { stats, schema, - cache, + cache: Arc::new(cache), } } @@ -158,7 +159,7 @@ impl ExecutionPlan for StatisticsValidation { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -181,16 +182,28 @@ impl ExecutionPlan for 
StatisticsValidation { unimplemented!("This plan only serves for testing statistics") } - fn statistics(&self) -> Result { - Ok(self.stats.clone()) - } - - fn partition_statistics(&self, partition: Option) -> Result { + fn partition_statistics(&self, partition: Option) -> Result> { if partition.is_some() { - Ok(Statistics::new_unknown(&self.schema)) + Ok(Arc::new(Statistics::new_unknown(&self.schema))) } else { - Ok(self.stats.clone()) + Ok(Arc::new(self.stats.clone())) + } + } + + fn apply_expressions( + &self, + f: &mut dyn FnMut( + &dyn datafusion::physical_plan::PhysicalExpr, + ) -> Result, + ) -> Result { + // Visit expressions in the output ordering from equivalence properties + let mut tnr = TreeNodeRecursion::Continue; + if let Some(ordering) = self.cache.output_ordering() { + for sort_expr in ordering { + tnr = tnr.visit_sibling(|| f(sort_expr.expr.as_ref()))?; + } } + Ok(tnr) } } @@ -214,6 +227,7 @@ fn fully_defined() -> (Statistics, Schema) { min_value: Precision::Exact(ScalarValue::Int32(Some(-24))), sum_value: Precision::Exact(ScalarValue::Int64(Some(10))), null_count: Precision::Exact(0), + byte_size: Precision::Absent, }, ColumnStatistics { distinct_count: Precision::Exact(13), @@ -221,6 +235,7 @@ fn fully_defined() -> (Statistics, Schema) { min_value: Precision::Exact(ScalarValue::Int64(Some(-6783))), sum_value: Precision::Exact(ScalarValue::Int64(Some(10))), null_count: Precision::Exact(5), + byte_size: Precision::Absent, }, ], }, @@ -240,7 +255,7 @@ async fn sql_basic() -> Result<()> { let physical_plan = df.create_physical_plan().await.unwrap(); // the statistics should be those of the source - assert_eq!(stats, physical_plan.partition_statistics(None)?); + assert_eq!(stats, *physical_plan.partition_statistics(None)?); Ok(()) } @@ -265,20 +280,22 @@ async fn sql_filter() -> Result<()> { #[tokio::test] async fn sql_limit() -> Result<()> { let (stats, schema) = fully_defined(); - let col_stats = Statistics::unknown_column(&schema); let ctx = 
init_ctx(stats.clone(), schema)?; let df = ctx.sql("SELECT * FROM stats_table LIMIT 5").await.unwrap(); let physical_plan = df.create_physical_plan().await.unwrap(); - // when the limit is smaller than the original number of lines - // we loose all statistics except the for number of rows which becomes the limit + // when the limit is smaller than the original number of lines we mark the statistics as inexact assert_eq!( Statistics { num_rows: Precision::Exact(5), - column_statistics: col_stats, + column_statistics: stats + .column_statistics + .iter() + .map(|c| c.clone().to_inexact()) + .collect(), total_byte_size: Precision::Absent }, - physical_plan.partition_statistics(None)? + *physical_plan.partition_statistics(None)? ); let df = ctx @@ -287,7 +304,7 @@ async fn sql_limit() -> Result<()> { .unwrap(); let physical_plan = df.create_physical_plan().await.unwrap(); // when the limit is larger than the original number of lines, statistics remain unchanged - assert_eq!(stats, physical_plan.partition_statistics(None)?); + assert_eq!(stats, *physical_plan.partition_statistics(None)?); Ok(()) } @@ -307,7 +324,7 @@ async fn sql_window() -> Result<()> { let result = physical_plan.partition_statistics(None)?; assert_eq!(stats.num_rows, result.num_rows); - let col_stats = result.column_statistics; + let col_stats = &result.column_statistics; assert_eq!(2, col_stats.len()); assert_eq!(stats.column_statistics[1], col_stats[0]); diff --git a/datafusion/core/tests/data/empty_files/some_empty_with_header/a_empty.csv b/datafusion/core/tests/data/empty_files/some_empty_with_header/a_empty.csv new file mode 100644 index 0000000000000..f1968a0906d09 --- /dev/null +++ b/datafusion/core/tests/data/empty_files/some_empty_with_header/a_empty.csv @@ -0,0 +1 @@ +c1,c2,c3 diff --git a/datafusion/core/tests/data/empty_files/some_empty_with_header/b.csv b/datafusion/core/tests/data/empty_files/some_empty_with_header/b.csv new file mode 100644 index 0000000000000..ff596071444c3 --- 
/dev/null +++ b/datafusion/core/tests/data/empty_files/some_empty_with_header/b.csv @@ -0,0 +1,3 @@ +c1,c2,c3 +1,1,1 +2,2,2 diff --git a/datafusion/core/tests/data/empty_files/some_empty_with_header/c_nulls_column.csv b/datafusion/core/tests/data/empty_files/some_empty_with_header/c_nulls_column.csv new file mode 100644 index 0000000000000..bf86844cb0293 --- /dev/null +++ b/datafusion/core/tests/data/empty_files/some_empty_with_header/c_nulls_column.csv @@ -0,0 +1,2 @@ +c1,c2,c3 +3,3, diff --git a/datafusion/core/tests/data/json_array.json b/datafusion/core/tests/data/json_array.json new file mode 100644 index 0000000000000..1a8716dbf4beb --- /dev/null +++ b/datafusion/core/tests/data/json_array.json @@ -0,0 +1,5 @@ +[ + {"a": 1, "b": "hello"}, + {"a": 2, "b": "world"}, + {"a": 3, "b": "test"} +] diff --git a/datafusion/core/tests/data/json_empty_array.json b/datafusion/core/tests/data/json_empty_array.json new file mode 100644 index 0000000000000..fe51488c7066f --- /dev/null +++ b/datafusion/core/tests/data/json_empty_array.json @@ -0,0 +1 @@ +[] diff --git a/datafusion/core/tests/data/partitioned_table_arrow_stream/part=123/data.arrow b/datafusion/core/tests/data/partitioned_table_arrow_stream/part=123/data.arrow new file mode 100644 index 0000000000000..bad9e3de4a57f Binary files /dev/null and b/datafusion/core/tests/data/partitioned_table_arrow_stream/part=123/data.arrow differ diff --git a/datafusion/core/tests/data/partitioned_table_arrow_stream/part=456/data.arrow b/datafusion/core/tests/data/partitioned_table_arrow_stream/part=456/data.arrow new file mode 100644 index 0000000000000..4a07fbfa47f32 Binary files /dev/null and b/datafusion/core/tests/data/partitioned_table_arrow_stream/part=456/data.arrow differ diff --git a/datafusion/core/tests/data/recursive_cte/closure.csv b/datafusion/core/tests/data/recursive_cte/closure.csv new file mode 100644 index 0000000000000..a31e2bfbf36b6 --- /dev/null +++ b/datafusion/core/tests/data/recursive_cte/closure.csv @@ 
-0,0 +1,6 @@ +start,end +1,2 +2,3 +2,4 +2,4 +4,1 \ No newline at end of file diff --git a/datafusion/core/tests/data/tpch_customer_small.parquet b/datafusion/core/tests/data/tpch_customer_small.parquet new file mode 100644 index 0000000000000..3d5f73ef3a066 Binary files /dev/null and b/datafusion/core/tests/data/tpch_customer_small.parquet differ diff --git a/datafusion/core/tests/data/tpch_lineitem_small.parquet b/datafusion/core/tests/data/tpch_lineitem_small.parquet new file mode 100644 index 0000000000000..5e98706669d3b Binary files /dev/null and b/datafusion/core/tests/data/tpch_lineitem_small.parquet differ diff --git a/datafusion/core/tests/data/tpch_nation_small.parquet b/datafusion/core/tests/data/tpch_nation_small.parquet new file mode 100644 index 0000000000000..99da99594cf89 Binary files /dev/null and b/datafusion/core/tests/data/tpch_nation_small.parquet differ diff --git a/datafusion/core/tests/data/tpch_orders_small.parquet b/datafusion/core/tests/data/tpch_orders_small.parquet new file mode 100644 index 0000000000000..79e043137caf6 Binary files /dev/null and b/datafusion/core/tests/data/tpch_orders_small.parquet differ diff --git a/datafusion/core/tests/data/tpch_part_small.parquet b/datafusion/core/tests/data/tpch_part_small.parquet new file mode 100644 index 0000000000000..d8e1d7d680aa2 Binary files /dev/null and b/datafusion/core/tests/data/tpch_part_small.parquet differ diff --git a/datafusion/core/tests/data/tpch_partsupp_small.parquet b/datafusion/core/tests/data/tpch_partsupp_small.parquet new file mode 100644 index 0000000000000..711d58dda7493 Binary files /dev/null and b/datafusion/core/tests/data/tpch_partsupp_small.parquet differ diff --git a/datafusion/core/tests/data/tpch_region_small.parquet b/datafusion/core/tests/data/tpch_region_small.parquet new file mode 100644 index 0000000000000..5e00a1f6da1d9 Binary files /dev/null and b/datafusion/core/tests/data/tpch_region_small.parquet differ diff --git 
a/datafusion/core/tests/data/tpch_supplier_small.parquet b/datafusion/core/tests/data/tpch_supplier_small.parquet new file mode 100644 index 0000000000000..18323395fcbed Binary files /dev/null and b/datafusion/core/tests/data/tpch_supplier_small.parquet differ diff --git a/datafusion/core/tests/dataframe/dataframe_functions.rs b/datafusion/core/tests/dataframe/dataframe_functions.rs index 40590d74ad910..014f356cd64cd 100644 --- a/datafusion/core/tests/dataframe/dataframe_functions.rs +++ b/datafusion/core/tests/dataframe/dataframe_functions.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{types::Int32Type, ListArray}; +use arrow::array::{ListArray, types::Int32Type}; use arrow::datatypes::SchemaRef; use arrow::datatypes::{DataType, Field, Schema}; use arrow::{ @@ -31,7 +31,7 @@ use datafusion::prelude::*; use datafusion_common::test_util::batches_to_string; use datafusion_common::{DFSchema, ScalarValue}; use datafusion_expr::expr::Alias; -use datafusion_expr::{table_scan, ExprSchemable, LogicalPlanBuilder}; +use datafusion_expr::{ExprSchemable, LogicalPlanBuilder, table_scan}; use datafusion_functions_aggregate::expr_fn::{approx_median, approx_percentile_cont}; use datafusion_functions_nested::map::map; use insta::assert_snapshot; @@ -274,6 +274,33 @@ async fn test_nvl2() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn test_nvl2_short_circuit() -> Result<()> { + let expr = nvl2( + col("a"), + arrow_cast(lit("1"), lit("Int32")), + arrow_cast(col("a"), lit("Int32")), + ); + + let batches = get_batches(expr).await?; + + assert_snapshot!( + batches_to_string(&batches), + @r#" + +-----------------------------------------------------------------------------------+ + | nvl2(test.a,arrow_cast(Utf8("1"),Utf8("Int32")),arrow_cast(test.a,Utf8("Int32"))) | + +-----------------------------------------------------------------------------------+ + | 1 | + | 1 | + | 1 | + | 1 | + 
+-----------------------------------------------------------------------------------+ + "# + ); + + Ok(()) +} #[tokio::test] async fn test_fn_arrow_typeof() -> Result<()> { let expr = arrow_typeof(col("l")); @@ -282,16 +309,16 @@ async fn test_fn_arrow_typeof() -> Result<()> { assert_snapshot!( batches_to_string(&batches), - @r#" - +------------------------------------------------------------------------------------------------------------------+ - | arrow_typeof(test.l) | - +------------------------------------------------------------------------------------------------------------------+ - | List(Field { name: "item", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) | - | List(Field { name: "item", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) | - | List(Field { name: "item", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) | - | List(Field { name: "item", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) | - +------------------------------------------------------------------------------------------------------------------+ - "#); + @r" + +----------------------+ + | arrow_typeof(test.l) | + +----------------------+ + | List(Int32) | + | List(Int32) | + | List(Int32) | + | List(Int32) | + +----------------------+ + "); Ok(()) } @@ -1215,7 +1242,7 @@ async fn test_fn_decode() -> Result<()> { // Note that the decode function returns binary, and the default display of // binary is "hexadecimal" and therefore the output looks like decode did // nothing. So compare to a constant. 
- let df_schema = DFSchema::try_from(test_schema().as_ref().clone())?; + let df_schema = DFSchema::try_from(test_schema())?; let expr = decode(encode(col("a"), lit("hex")), lit("hex")) // need to cast to utf8 otherwise the default display of binary array is hex // so it looks like nothing is done @@ -1316,3 +1343,28 @@ async fn test_count_wildcard() -> Result<()> { Ok(()) } + +/// Call count wildcard with alias from dataframe API +#[tokio::test] +async fn test_count_wildcard_with_alias() -> Result<()> { + let df = create_test_table().await?; + let result_df = df.aggregate(vec![], vec![count_all().alias("total_count")])?; + + let schema = result_df.schema(); + assert_eq!(schema.fields().len(), 1); + assert_eq!(schema.field(0).name(), "total_count"); + assert_eq!(*schema.field(0).data_type(), DataType::Int64); + + let batches = result_df.collect().await?; + assert_eq!(batches.len(), 1); + assert_eq!(batches[0].num_rows(), 1); + + let count_array = batches[0] + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(count_array.value(0), 4); + + Ok(()) +} diff --git a/datafusion/core/tests/dataframe/describe.rs b/datafusion/core/tests/dataframe/describe.rs index 9bd69dfa72b4c..c61fe4fed1615 100644 --- a/datafusion/core/tests/dataframe/describe.rs +++ b/datafusion/core/tests/dataframe/describe.rs @@ -17,7 +17,7 @@ use datafusion::prelude::{ParquetReadOptions, SessionContext}; use datafusion_common::test_util::batches_to_string; -use datafusion_common::{test_util::parquet_test_data, Result}; +use datafusion_common::{Result, test_util::parquet_test_data}; use insta::assert_snapshot; #[tokio::test] diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 089ff8808134d..80bbde1f6ba14 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -20,30 +20,33 @@ mod dataframe_functions; mod describe; use arrow::array::{ - record_batch, Array, ArrayRef, BooleanArray, DictionaryArray, 
FixedSizeListArray, - FixedSizeListBuilder, Float32Array, Float64Array, Int32Array, Int32Builder, - Int8Array, LargeListArray, ListArray, ListBuilder, RecordBatch, StringArray, - StringBuilder, StructBuilder, UInt32Array, UInt32Builder, UnionArray, + Array, ArrayRef, BooleanArray, DictionaryArray, FixedSizeListArray, + FixedSizeListBuilder, Float32Array, Float64Array, Int8Array, Int32Array, + Int32Builder, LargeListArray, ListArray, ListBuilder, RecordBatch, StringArray, + StringBuilder, StructBuilder, UInt32Array, UInt32Builder, UnionArray, record_batch, }; use arrow::buffer::ScalarBuffer; use arrow::datatypes::{ - DataType, Field, Float32Type, Int32Type, Schema, SchemaRef, UInt64Type, UnionFields, - UnionMode, + DataType, Field, Float32Type, Int32Type, Schema, UInt64Type, UnionFields, UnionMode, }; use arrow::error::ArrowError; use arrow::util::pretty::pretty_format_batches; +use arrow_schema::{SortOptions, TimeUnit}; use datafusion::{assert_batches_eq, dataframe}; +use datafusion_common::metadata::FieldMetadata; use datafusion_functions_aggregate::count::{count_all, count_all_window}; use datafusion_functions_aggregate::expr_fn::{ - array_agg, avg, count, count_distinct, max, median, min, sum, + array_agg, avg, avg_distinct, count, count_distinct, max, median, min, sum, + sum_distinct, }; use datafusion_functions_nested::make_array::make_array_udf; -use datafusion_functions_window::expr_fn::{first_value, row_number}; +use datafusion_functions_window::expr_fn::{first_value, lead, row_number}; use insta::assert_snapshot; use object_store::local::LocalFileSystem; -use sqlparser::ast::NullTreatment; +use rstest::rstest; use std::collections::HashMap; use std::fs; +use std::path::Path; use std::sync::Arc; use tempfile::TempDir; use url::Url; @@ -54,34 +57,43 @@ use datafusion::error::Result; use datafusion::execution::context::SessionContext; use datafusion::execution::session_state::SessionStateBuilder; use datafusion::logical_expr::{ColumnarValue, Volatility}; -use 
datafusion::prelude::{ - CsvReadOptions, JoinType, NdJsonReadOptions, ParquetReadOptions, -}; +use datafusion::prelude::{CsvReadOptions, JoinType, ParquetReadOptions}; use datafusion::test_util::{ parquet_test_data, populate_csv_partitions, register_aggregate_csv, test_table, - test_table_with_name, + test_table_with_cache_factory, test_table_with_name, }; use datafusion_catalog::TableProvider; use datafusion_common::test_util::{batches_to_sort_string, batches_to_string}; use datafusion_common::{ - assert_contains, Constraint, Constraints, DataFusionError, ParamValues, ScalarValue, - TableReference, UnnestOptions, + Constraint, Constraints, DFSchema, DataFusionError, ScalarValue, SchemaError, + TableReference, UnnestOptions, assert_contains, internal_datafusion_err, }; use datafusion_common_runtime::SpawnedTask; +use datafusion_datasource::file_format::format_as_file_type; use datafusion_execution::config::SessionConfig; use datafusion_execution::runtime_env::RuntimeEnv; -use datafusion_expr::expr::{GroupingSet, Sort, WindowFunction}; +use datafusion_expr::expr::{GroupingSet, NullTreatment, Sort, WindowFunction}; use datafusion_expr::var_provider::{VarProvider, VarType}; use datafusion_expr::{ - cast, col, create_udf, exists, in_subquery, lit, out_ref_col, placeholder, - scalar_subquery, when, wildcard, Expr, ExprFunctionExt, ExprSchemable, LogicalPlan, - ScalarFunctionImplementation, WindowFrame, WindowFrameBound, WindowFrameUnits, - WindowFunctionDefinition, + Expr, ExprFunctionExt, ExprSchemable, LogicalPlan, LogicalPlanBuilder, + ScalarFunctionImplementation, SortExpr, TableType, WindowFrame, WindowFrameBound, + WindowFrameUnits, WindowFunctionDefinition, cast, col, create_udf, exists, + in_subquery, lit, out_ref_col, placeholder, scalar_subquery, when, wildcard, }; -use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::Partitioning; +use datafusion_physical_expr::aggregate::AggregateExprBuilder; +use 
datafusion_physical_expr::expressions::Column; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; -use datafusion_physical_plan::{displayable, ExecutionPlanProperties}; +use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; +use datafusion_physical_plan::aggregates::{ + AggregateExec, AggregateMode, PhysicalGroupBy, +}; +use datafusion_physical_plan::empty::EmptyExec; +use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties, displayable}; + +use datafusion::error::Result as DataFusionResult; +use datafusion::execution::options::JsonReadOptions; +use datafusion_functions_window::expr_fn::lag; // Get string representation of the plan async fn physical_plan_to_string(df: &DataFrame) -> String { @@ -91,8 +103,8 @@ async fn physical_plan_to_string(df: &DataFrame) -> String { .await .expect("Error creating physical plan"); - let formated = displayable(physical_plan.as_ref()).indent(true); - formated.to_string() + let formatted = displayable(physical_plan.as_ref()).indent(true); + formatted.to_string() } pub fn table_with_constraints() -> Arc { @@ -117,8 +129,7 @@ pub fn table_with_constraints() -> Arc { } async fn assert_logical_expr_schema_eq_physical_expr_schema(df: DataFrame) -> Result<()> { - let logical_expr_dfschema = df.schema(); - let logical_expr_schema = SchemaRef::from(logical_expr_dfschema.to_owned()); + let logical_expr_schema = Arc::clone(df.schema().inner()); let batches = df.collect().await?; let physical_expr_schema = batches[0].schema(); assert_eq!(logical_expr_schema, physical_expr_schema); @@ -150,6 +161,46 @@ async fn test_array_agg_ord_schema() -> Result<()> { Ok(()) } +type WindowFnCase = (fn() -> Expr, &'static str); + +#[tokio::test] +async fn with_column_window_functions() -> DataFusionResult<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]))], + )?; + + let 
ctx = SessionContext::new(); + + let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch]])?; + ctx.register_table("t", Arc::new(provider))?; + + // Define test cases: (expr builder, alias name) + let test_cases: Vec = vec![ + (|| lag(col("a"), Some(1), None), "lag_val"), + (|| lead(col("a"), Some(1), None), "lead_val"), + (row_number, "row_num"), + ]; + + for (make_expr, alias) in test_cases { + let df = ctx.table("t").await?; + let expr = make_expr(); + let df_with = df.with_column(alias, expr)?; + let df_schema = df_with.schema().clone(); + + assert!( + df_schema.has_column_with_unqualified_name(alias), + "Schema does not contain expected column {alias}", + ); + + assert_eq!(2, df_schema.columns().len()); + } + + Ok(()) +} + #[tokio::test] async fn test_coalesce_schema() -> Result<()> { let ctx = SessionContext::new(); @@ -254,6 +305,27 @@ async fn select_columns() -> Result<()> { Ok(()) } +#[tokio::test] +async fn select_columns_with_nonexistent_columns() -> Result<()> { + let t = test_table().await?; + let t2 = t.select_columns(&["canada", "c2", "rocks"]); + + match t2 { + Err(DataFusionError::SchemaError(boxed_err, _)) => { + // Verify it's the first invalid column + match boxed_err.as_ref() { + SchemaError::FieldNotFound { field, .. } => { + assert_eq!(field.name(), "canada"); + } + _ => panic!("Expected SchemaError::FieldNotFound for 'canada'"), + } + } + _ => panic!("Expected SchemaError"), + } + + Ok(()) +} + #[tokio::test] async fn select_expr() -> Result<()> { // build plan using Table API @@ -341,16 +413,65 @@ async fn select_with_periods() -> Result<()> { assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" +------+ | f.c1 | +------+ | 1 | | 10 | +------+ - "### + " + ); + + Ok(()) +} + +#[tokio::test] +async fn select_columns_duplicated_names_from_different_qualifiers() -> Result<()> { + let t1 = test_table_with_name("t1") + .await? + .select_columns(&["c1"])? 
+ .limit(0, Some(3))?; + let t2 = test_table_with_name("t2") + .await? + .select_columns(&["c1"])? + .limit(3, Some(3))?; + let t3 = test_table_with_name("t3") + .await? + .select_columns(&["c1"])? + .limit(6, Some(3))?; + + let join_res = t1 + .join(t2, JoinType::Left, &["t1.c1"], &["t2.c1"], None)? + .join(t3, JoinType::Left, &["t1.c1"], &["t3.c1"], None)?; + assert_snapshot!( + batches_to_sort_string(&join_res.clone().collect().await.unwrap()), + @r" + +----+----+----+ + | c1 | c1 | c1 | + +----+----+----+ + | b | b | | + | b | b | | + | c | | | + | d | | d | + +----+----+----+ + " ); + let select_res = join_res.select_columns(&["c1"])?; + assert_snapshot!( + batches_to_sort_string(&select_res.clone().collect().await.unwrap()), + @r" + +----+----+----+ + | c1 | c1 | c1 | + +----+----+----+ + | b | b | | + | b | b | | + | c | | | + | d | | d | + +----+----+----+ + " + ); Ok(()) } @@ -413,7 +534,8 @@ async fn drop_columns_with_nonexistent_columns() -> Result<()> { async fn drop_columns_with_empty_array() -> Result<()> { // build plan using Table API let t = test_table().await?; - let t2 = t.drop_columns(&[])?; + let drop_columns = vec![] as Vec<&str>; + let t2 = t.drop_columns(&drop_columns)?; let plan = t2.logical_plan().clone(); // build query using SQL @@ -428,6 +550,107 @@ async fn drop_columns_with_empty_array() -> Result<()> { Ok(()) } +#[tokio::test] +async fn drop_columns_qualified() -> Result<()> { + // build plan using Table API + let mut t = test_table().await?; + t = t.select_columns(&["c1", "c2", "c11"])?; + let mut t2 = test_table_with_name("another_table").await?; + t2 = t2.select_columns(&["c1", "c2", "c11"])?; + let mut t3 = t.join_on( + t2, + JoinType::Inner, + [col("aggregate_test_100.c1").eq(col("another_table.c1"))], + )?; + t3 = t3.drop_columns(&["another_table.c2", "another_table.c11"])?; + + let plan = t3.logical_plan().clone(); + + let sql = "SELECT aggregate_test_100.c1, aggregate_test_100.c2, aggregate_test_100.c11, another_table.c1 FROM 
(SELECT c1, c2, c11 FROM aggregate_test_100) INNER JOIN (SELECT c1, c2, c11 FROM another_table) ON aggregate_test_100.c1 = another_table.c1"; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx, "aggregate_test_100").await?; + register_aggregate_csv(&ctx, "another_table").await?; + let sql_plan = ctx.sql(sql).await?.into_unoptimized_plan(); + + // the two plans should be identical + assert_same_plan(&plan, &sql_plan); + + Ok(()) +} + +#[tokio::test] +async fn drop_columns_qualified_find_qualified() -> Result<()> { + // build plan using Table API + let mut t = test_table().await?; + t = t.select_columns(&["c1", "c2", "c11"])?; + let mut t2 = test_table_with_name("another_table").await?; + t2 = t2.select_columns(&["c1", "c2", "c11"])?; + let mut t3 = t.join_on( + t2.clone(), + JoinType::Inner, + [col("aggregate_test_100.c1").eq(col("another_table.c1"))], + )?; + t3 = t3.drop_columns(&t2.find_qualified_columns(&["c2", "c11"])?)?; + + let plan = t3.logical_plan().clone(); + + let sql = "SELECT aggregate_test_100.c1, aggregate_test_100.c2, aggregate_test_100.c11, another_table.c1 FROM (SELECT c1, c2, c11 FROM aggregate_test_100) INNER JOIN (SELECT c1, c2, c11 FROM another_table) ON aggregate_test_100.c1 = another_table.c1"; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx, "aggregate_test_100").await?; + register_aggregate_csv(&ctx, "another_table").await?; + let sql_plan = ctx.sql(sql).await?.into_unoptimized_plan(); + + // the two plans should be identical + assert_same_plan(&plan, &sql_plan); + + Ok(()) +} + +#[tokio::test] +async fn test_find_qualified_names() -> Result<()> { + let t = test_table().await?; + let column_names = ["c1", "c2", "c3"]; + let columns = t.find_qualified_columns(&column_names)?; + + // Expected results for each column + let binding = TableReference::bare("aggregate_test_100"); + let expected = [ + (Some(&binding), "c1"), + (Some(&binding), "c2"), + (Some(&binding), "c3"), + ]; + + // Verify we got the expected 
number of results + assert_eq!( + columns.len(), + expected.len(), + "Expected {} columns, got {}", + expected.len(), + columns.len() + ); + + // Iterate over the results and check each one individually + for (i, (actual, expected)) in columns.iter().zip(expected.iter()).enumerate() { + let (actual_table_ref, actual_field_ref) = actual; + let (expected_table_ref, expected_field_name) = expected; + + // Check table reference + assert_eq!( + actual_table_ref, expected_table_ref, + "Column {i}: expected table reference {expected_table_ref:?}, got {actual_table_ref:?}" + ); + + // Check field name + assert_eq!( + actual_field_ref.name(), + *expected_field_name, + "Column {i}: expected field name '{expected_field_name}', got '{actual_field_ref}'" + ); + } + + Ok(()) +} + #[tokio::test] async fn drop_with_quotes() -> Result<()> { // define data with a column name that has a "." in it: @@ -447,14 +670,14 @@ async fn drop_with_quotes() -> Result<()> { assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r#" +------+ | f"c2 | +------+ | 11 | | 2 | +------+ - "### + "# ); Ok(()) @@ -473,20 +696,68 @@ async fn drop_with_periods() -> Result<()> { let ctx = SessionContext::new(); ctx.register_batch("t", batch)?; - let df = ctx.table("t").await?.drop_columns(&["f.c1"])?; + let df = ctx.table("t").await?.drop_columns(&["\"f.c1\""])?; let df_results = df.collect().await?; assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" +------+ | f.c2 | +------+ | 11 | | 2 | +------+ - "### + " + ); + + Ok(()) +} + +#[tokio::test] +async fn drop_columns_duplicated_names_from_different_qualifiers() -> Result<()> { + let t1 = test_table_with_name("t1") + .await? + .select_columns(&["c1"])? + .limit(0, Some(3))?; + let t2 = test_table_with_name("t2") + .await? + .select_columns(&["c1"])? + .limit(3, Some(3))?; + let t3 = test_table_with_name("t3") + .await? + .select_columns(&["c1"])? 
+ .limit(6, Some(3))?; + + let join_res = t1 + .join(t2, JoinType::LeftMark, &["c1"], &["c1"], None)? + .join(t3, JoinType::LeftMark, &["c1"], &["c1"], None)?; + assert_snapshot!( + batches_to_sort_string(&join_res.clone().collect().await.unwrap()), + @r" + +----+-------+-------+ + | c1 | mark | mark | + +----+-------+-------+ + | b | true | false | + | c | false | false | + | d | false | true | + +----+-------+-------+ + " + ); + + let drop_res = join_res.drop_columns(&["mark"])?; + assert_snapshot!( + batches_to_sort_string(&drop_res.clone().collect().await.unwrap()), + @r" + +----+ + | c1 | + +----+ + | b | + | c | + | d | + +----+ + " ); Ok(()) @@ -495,32 +766,35 @@ async fn drop_with_periods() -> Result<()> { #[tokio::test] async fn aggregate() -> Result<()> { // build plan using DataFrame API - let df = test_table().await?; + // union so some of the distincts have a clearly distinct result + let df = test_table().await?.union(test_table().await?)?; let group_expr = vec![col("c1")]; let aggr_expr = vec![ - min(col("c12")), - max(col("c12")), - avg(col("c12")), - sum(col("c12")), - count(col("c12")), - count_distinct(col("c12")), + min(col("c4")).alias("min(c4)"), + max(col("c4")).alias("max(c4)"), + avg(col("c4")).alias("avg(c4)"), + avg_distinct(col("c4")).alias("avg_distinct(c4)"), + sum(col("c4")).alias("sum(c4)"), + sum_distinct(col("c4")).alias("sum_distinct(c4)"), + count(col("c4")).alias("count(c4)"), + count_distinct(col("c4")).alias("count_distinct(c4)"), ]; let df: Vec = df.aggregate(group_expr, aggr_expr)?.collect().await?; assert_snapshot!( batches_to_sort_string(&df), - @r###" - +----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+ - | c1 | min(aggregate_test_100.c12) | max(aggregate_test_100.c12) | avg(aggregate_test_100.c12) | sum(aggregate_test_100.c12) | count(aggregate_test_100.c12) | count(DISTINCT 
aggregate_test_100.c12) | - +----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+ - | a | 0.02182578039211991 | 0.9800193410444061 | 0.48754517466109415 | 10.238448667882977 | 21 | 21 | - | b | 0.04893135681998029 | 0.9185813970744787 | 0.41040709263815384 | 7.797734760124923 | 19 | 19 | - | c | 0.0494924465469434 | 0.991517828651004 | 0.6600456536439784 | 13.860958726523545 | 21 | 21 | - | d | 0.061029375346466685 | 0.9748360509016578 | 0.48855379387549824 | 8.793968289758968 | 18 | 18 | - | e | 0.01479305307777301 | 0.9965400387585364 | 0.48600669271341534 | 10.206140546981722 | 21 | 21 | - +----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+ - "### + @r" + +----+---------+---------+---------------------+---------------------+---------+------------------+-----------+--------------------+ + | c1 | min(c4) | max(c4) | avg(c4) | avg_distinct(c4) | sum(c4) | sum_distinct(c4) | count(c4) | count_distinct(c4) | + +----+---------+---------+---------------------+---------------------+---------+------------------+-----------+--------------------+ + | a | -28462 | 32064 | 306.04761904761904 | 306.04761904761904 | 12854 | 6427 | 42 | 21 | + | b | -28070 | 25286 | 7732.315789473684 | 7732.315789473684 | 293828 | 146914 | 38 | 19 | + | c | -30508 | 29106 | -1320.5238095238096 | -1320.5238095238096 | -55462 | -27731 | 42 | 21 | + | d | -24558 | 31106 | 10890.111111111111 | 10890.111111111111 | 392044 | 196022 | 36 | 18 | + | e | -31500 | 32514 | -4268.333333333333 | -4268.333333333333 | -179270 | -89635 | 42 | 21 | + +----+---------+---------+---------------------+---------------------+---------+------------------+-----------+--------------------+ + " ); Ok(()) @@ -535,7 +809,9 @@ async fn 
aggregate_assert_no_empty_batches() -> Result<()> { min(col("c12")), max(col("c12")), avg(col("c12")), + avg_distinct(col("c12")), sum(col("c12")), + sum_distinct(col("c12")), count(col("c12")), count_distinct(col("c12")), median(col("c12")), @@ -570,23 +846,23 @@ async fn test_aggregate_with_pk() -> Result<()> { assert_snapshot!( physical_plan_to_string(&df).await, - @r###" + @r" AggregateExec: mode=Single, gby=[id@0 as id, name@1 as name], aggr=[] DataSourceExec: partitions=1, partition_sizes=[1] - "### + " ); let df_results = df.collect().await?; assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" +----+------+ | id | name | +----+------+ | 1 | a | +----+------+ - "### + " ); Ok(()) @@ -611,12 +887,11 @@ async fn test_aggregate_with_pk2() -> Result<()> { let df = df.filter(predicate)?; assert_snapshot!( physical_plan_to_string(&df).await, - @r###" - CoalesceBatchesExec: target_batch_size=8192 + @r" + AggregateExec: mode=Single, gby=[id@0 as id, name@1 as name], aggr=[], ordering_mode=Sorted FilterExec: id@0 = 1 AND name@1 = a - AggregateExec: mode=Single, gby=[id@0 as id, name@1 as name], aggr=[] - DataSourceExec: partitions=1, partition_sizes=[1] - "### + DataSourceExec: partitions=1, partition_sizes=[1] + " ); // Since id and name are functionally dependant, we can use name among expression @@ -625,13 +900,13 @@ async fn test_aggregate_with_pk2() -> Result<()> { assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" +----+------+ | id | name | +----+------+ | 1 | a | +----+------+ - "### + " ); Ok(()) @@ -660,12 +935,11 @@ async fn test_aggregate_with_pk3() -> Result<()> { let df = df.select(vec![col("id"), col("name")])?; assert_snapshot!( physical_plan_to_string(&df).await, - @r###" - CoalesceBatchesExec: target_batch_size=8192 + @r" + AggregateExec: mode=Single, gby=[id@0 as id, name@1 as name], aggr=[], ordering_mode=PartiallySorted([0]) FilterExec: id@0 = 1 - AggregateExec: mode=Single, gby=[id@0 as id, name@1 as name], 
aggr=[] - DataSourceExec: partitions=1, partition_sizes=[1] - "### + DataSourceExec: partitions=1, partition_sizes=[1] + " ); // Since id and name are functionally dependant, we can use name among expression @@ -674,13 +948,13 @@ async fn test_aggregate_with_pk3() -> Result<()> { assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" +----+------+ | id | name | +----+------+ | 1 | a | +----+------+ - "### + " ); Ok(()) @@ -711,25 +985,24 @@ async fn test_aggregate_with_pk4() -> Result<()> { // columns are not used. assert_snapshot!( physical_plan_to_string(&df).await, - @r###" - CoalesceBatchesExec: target_batch_size=8192 + @r" + AggregateExec: mode=Single, gby=[id@0 as id], aggr=[], ordering_mode=Sorted FilterExec: id@0 = 1 - AggregateExec: mode=Single, gby=[id@0 as id], aggr=[] - DataSourceExec: partitions=1, partition_sizes=[1] - "### + DataSourceExec: partitions=1, partition_sizes=[1] + " ); let df_results = df.collect().await?; assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" +----+ | id | +----+ | 1 | +----+ - "### + " ); Ok(()) @@ -751,7 +1024,7 @@ async fn test_aggregate_alias() -> Result<()> { assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" +----+ | c2 | +----+ @@ -761,7 +1034,7 @@ async fn test_aggregate_alias() -> Result<()> { | 5 | | 6 | +----+ - "### + " ); Ok(()) @@ -798,7 +1071,7 @@ async fn test_aggregate_with_union() -> Result<()> { assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" +----+------------+ | c1 | sum_result | +----+------------+ @@ -808,7 +1081,7 @@ async fn test_aggregate_with_union() -> Result<()> { | d | 126 | | e | 121 | +----+------------+ - "### + " ); Ok(()) } @@ -834,7 +1107,7 @@ async fn test_aggregate_subexpr() -> Result<()> { assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" +----------------+------+ | c2 + Int32(10) | sum | +----------------+------+ @@ -844,7 +1117,7 @@ async fn test_aggregate_subexpr() -> Result<()> { | 15 | 95 
| | 16 | -146 | +----------------+------+ - "### + " ); Ok(()) @@ -867,7 +1140,7 @@ async fn test_aggregate_name_collision() -> Result<()> { // The select expr has the same display_name as the group_expr, // but since they are different expressions, it should fail. .expect_err("Expected error"); - assert_snapshot!(df.strip_backtrace(), @r###"Schema error: No field named aggregate_test_100.c2. Valid fields are "aggregate_test_100.c2 + aggregate_test_100.c3"."###); + assert_snapshot!(df.strip_backtrace(), @r#"Schema error: No field named aggregate_test_100.c2. Valid fields are "aggregate_test_100.c2 + aggregate_test_100.c3"."#); Ok(()) } @@ -907,7 +1180,7 @@ async fn window_using_aggregates() -> Result<()> { vec![col("c3")], ); - Expr::WindowFunction(w) + Expr::from(w) .null_treatment(NullTreatment::IgnoreNulls) .order_by(vec![col("c2").sort(true, true), col("c3").sort(true, true)]) .window_frame(WindowFrame::new_bounds( @@ -926,33 +1199,110 @@ async fn window_using_aggregates() -> Result<()> { assert_snapshot!( batches_to_sort_string(&df), - @r###" + @r" +-------------+----------+-----------------+---------------+--------+-----+------+----+------+ | first_value | last_val | approx_distinct | approx_median | median | max | min | c2 | c3 | +-------------+----------+-----------------+---------------+--------+-----+------+----+------+ | | | | | | | | 1 | -85 | - | -85 | -101 | 14 | -12 | -101 | 83 | -101 | 4 | -54 | - | -85 | -101 | 17 | -25 | -101 | 83 | -101 | 5 | -31 | - | -85 | -12 | 10 | -32 | -12 | 83 | -85 | 3 | 13 | - | -85 | -25 | 3 | -56 | -25 | -25 | -85 | 1 | -5 | - | -85 | -31 | 18 | -29 | -31 | 83 | -101 | 5 | 36 | - | -85 | -38 | 16 | -25 | -38 | 83 | -101 | 4 | 65 | + | -85 | -101 | 14 | -12 | -12 | 83 | -101 | 4 | -54 | + | -85 | -101 | 17 | -25 | -25 | 83 | -101 | 5 | -31 | + | -85 | -12 | 10 | -32 | -34 | 83 | -85 | 3 | 13 | + | -85 | -25 | 3 | -56 | -56 | -25 | -85 | 1 | -5 | + | -85 | -31 | 18 | -29 | -28 | 83 | -101 | 5 | 36 | + | -85 | -38 | 16 | 
-25 | -25 | 83 | -101 | 4 | 65 | | -85 | -43 | 7 | -43 | -43 | 83 | -85 | 2 | 45 | - | -85 | -48 | 6 | -35 | -48 | 83 | -85 | 2 | -43 | - | -85 | -5 | 4 | -37 | -5 | -5 | -85 | 1 | 83 | - | -85 | -54 | 15 | -17 | -54 | 83 | -101 | 4 | -38 | - | -85 | -56 | 2 | -70 | -56 | -56 | -85 | 1 | -25 | - | -85 | -72 | 9 | -43 | -72 | 83 | -85 | 3 | -12 | + | -85 | -48 | 6 | -35 | -36 | 83 | -85 | 2 | -43 | + | -85 | -5 | 4 | -37 | -40 | -5 | -85 | 1 | 83 | + | -85 | -54 | 15 | -17 | -18 | 83 | -101 | 4 | -38 | + | -85 | -56 | 2 | -70 | -70 | -56 | -85 | 1 | -25 | + | -85 | -72 | 9 | -43 | -43 | 83 | -85 | 3 | -12 | | -85 | -85 | 1 | -85 | -85 | -85 | -85 | 1 | -56 | - | -85 | 13 | 11 | -17 | 13 | 83 | -85 | 3 | 14 | - | -85 | 13 | 11 | -25 | 13 | 83 | -85 | 3 | 13 | - | -85 | 14 | 12 | -12 | 14 | 83 | -85 | 3 | 17 | - | -85 | 17 | 13 | -11 | 17 | 83 | -85 | 4 | -101 | - | -85 | 45 | 8 | -34 | 45 | 83 | -85 | 3 | -72 | - | -85 | 65 | 17 | -17 | 65 | 83 | -101 | 5 | -101 | - | -85 | 83 | 5 | -25 | 83 | 83 | -85 | 2 | -48 | + | -85 | 13 | 11 | -17 | -18 | 83 | -85 | 3 | 14 | + | -85 | 13 | 11 | -25 | -25 | 83 | -85 | 3 | 13 | + | -85 | 14 | 12 | -12 | -12 | 83 | -85 | 3 | 17 | + | -85 | 17 | 13 | -11 | -8 | 83 | -85 | 4 | -101 | + | -85 | 45 | 8 | -34 | -34 | 83 | -85 | 3 | -72 | + | -85 | 65 | 17 | -17 | -18 | 83 | -101 | 5 | -101 | + | -85 | 83 | 5 | -25 | -25 | 83 | -85 | 2 | -48 | +-------------+----------+-----------------+---------------+--------+-----+------+----+------+ - "### + " + ); + + Ok(()) +} + +#[tokio::test] +async fn window_aggregates_with_filter() -> Result<()> { + // Define a small in-memory table to make expected values clear + let ts: Int32Array = [1, 2, 3, 4, 5].into_iter().collect(); + let val: Int32Array = [-3, -2, 1, 4, -1].into_iter().collect(); + let batch = RecordBatch::try_from_iter(vec![ + ("ts", Arc::new(ts) as _), + ("val", Arc::new(val) as _), + ])?; + + let ctx = SessionContext::new(); + ctx.register_batch("t", batch)?; + + let df = 
ctx.table("t").await?; + + // Build filtered window aggregates over ORDER BY ts ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW + let mut exprs = vec![ + (datafusion_functions_aggregate::sum::sum_udaf(), "sum_pos"), + ( + datafusion_functions_aggregate::average::avg_udaf(), + "avg_pos", + ), + ( + datafusion_functions_aggregate::min_max::min_udaf(), + "min_pos", + ), + ( + datafusion_functions_aggregate::min_max::max_udaf(), + "max_pos", + ), + ( + datafusion_functions_aggregate::count::count_udaf(), + "cnt_pos", + ), + ] + .into_iter() + .map(|(func, alias)| { + let w = WindowFunction::new( + WindowFunctionDefinition::AggregateUDF(func), + vec![col("val")], + ); + + Expr::from(w) + .order_by(vec![col("ts").sort(true, true)]) + .window_frame(WindowFrame::new_bounds( + WindowFrameUnits::Rows, + WindowFrameBound::Preceding(ScalarValue::UInt64(None)), + WindowFrameBound::CurrentRow, + )) + .filter(col("val").gt(lit(0))) + .build() + .unwrap() + .alias(alias) + }) + .collect::>(); + exprs.extend_from_slice(&[col("ts"), col("val")]); + + let results = df.select(exprs)?.collect().await?; + + assert_snapshot!( + batches_to_string(&results), + @r" + +---------+---------+---------+---------+---------+----+-----+ + | sum_pos | avg_pos | min_pos | max_pos | cnt_pos | ts | val | + +---------+---------+---------+---------+---------+----+-----+ + | | | | | 0 | 1 | -3 | + | | | | | 0 | 2 | -2 | + | 1 | 1.0 | 1 | 1 | 1 | 3 | 1 | + | 5 | 2.5 | 1 | 4 | 2 | 4 | 4 | + | 5 | 2.5 | 1 | 4 | 2 | 5 | -1 | + +---------+---------+---------+---------+---------+----+-----+ + " ); Ok(()) @@ -1008,7 +1358,7 @@ async fn test_distinct_sort_by() -> Result<()> { assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" +----+ | c1 | +----+ @@ -1018,7 +1368,7 @@ async fn test_distinct_sort_by() -> Result<()> { | d | | e | +----+ - "### + " ); Ok(()) @@ -1056,7 +1406,7 @@ async fn test_distinct_on() -> Result<()> { assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" 
+----+ | c1 | +----+ @@ -1066,7 +1416,7 @@ async fn test_distinct_on() -> Result<()> { | d | | e | +----+ - "### + " ); Ok(()) @@ -1091,7 +1441,7 @@ async fn test_distinct_on_sort_by() -> Result<()> { assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" +----+ | c1 | +----+ @@ -1101,7 +1451,7 @@ async fn test_distinct_on_sort_by() -> Result<()> { | d | | e | +----+ - "### + " ); Ok(()) @@ -1165,13 +1515,13 @@ async fn join_coercion_unnamed() -> Result<()> { assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" +----+------+ | id | name | +----+------+ | 10 | d | +----+------+ - "### + " ); Ok(()) } @@ -1190,13 +1540,13 @@ async fn join_on() -> Result<()> { [col("a.c1").not_eq(col("b.c1")), col("a.c2").eq(col("b.c2"))], )?; - assert_snapshot!(join.logical_plan(), @r###" + assert_snapshot!(join.logical_plan(), @r" Inner Join: Filter: a.c1 != b.c1 AND a.c2 = b.c2 Projection: a.c1, a.c2 TableScan: a Projection: b.c1, b.c2 TableScan: b - "###); + "); Ok(()) } @@ -1210,16 +1560,20 @@ async fn join_on_filter_datatype() -> Result<()> { let join = left.clone().join_on( right.clone(), JoinType::Inner, - Some(Expr::Literal(ScalarValue::Null)), + Some(Expr::Literal(ScalarValue::Null, None)), )?; - assert_snapshot!(join.into_optimized_plan().unwrap(), @"EmptyRelation"); + assert_snapshot!(join.into_optimized_plan().unwrap(), @"EmptyRelation: rows=0"); // JOIN ON expression must be boolean type let join = left.join_on(right, JoinType::Inner, Some(lit("TRUE")))?; let err = join.into_optimized_plan().unwrap_err(); assert_snapshot!( err.strip_backtrace(), - @"type_coercion\ncaused by\nError during planning: Join condition must be boolean type, but got Utf8" + @r" + type_coercion + caused by + Error during planning: Join condition must be boolean type, but got Utf8 + " ); Ok(()) } @@ -1360,6 +1714,36 @@ async fn except() -> Result<()> { Ok(()) } +#[tokio::test] +async fn except_distinct() -> Result<()> { + let df = 
test_table().await?.select_columns(&["c1", "c3"])?; + let d2 = df.clone(); + let plan = df.except_distinct(d2)?; + let result = plan.logical_plan().clone(); + let expected = create_plan( + "SELECT c1, c3 FROM aggregate_test_100 + EXCEPT DISTINCT SELECT c1, c3 FROM aggregate_test_100", + ) + .await?; + assert_same_plan(&result, &expected); + Ok(()) +} + +#[tokio::test] +async fn intersect_distinct() -> Result<()> { + let df = test_table().await?.select_columns(&["c1", "c3"])?; + let d2 = df.clone(); + let plan = df.intersect_distinct(d2)?; + let result = plan.logical_plan().clone(); + let expected = create_plan( + "SELECT c1, c3 FROM aggregate_test_100 + INTERSECT DISTINCT SELECT c1, c3 FROM aggregate_test_100", + ) + .await?; + assert_same_plan(&result, &expected); + Ok(()) +} + #[tokio::test] async fn register_table() -> Result<()> { let df = test_table().await?.select_columns(&["c1", "c12"])?; @@ -1367,7 +1751,9 @@ async fn register_table() -> Result<()> { let df_impl = DataFrame::new(ctx.state(), df.logical_plan().clone()); // register a dataframe as a table - ctx.register_table("test_table", df_impl.clone().into_view())?; + let table_provider = df_impl.clone().into_view(); + assert_eq!(table_provider.table_type(), TableType::View); + ctx.register_table("test_table", table_provider)?; // pull the table out let table = ctx.table("test_table").await?; @@ -1384,7 +1770,7 @@ async fn register_table() -> Result<()> { assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" +----+-----------------------------+ | c1 | sum(aggregate_test_100.c12) | +----+-----------------------------+ @@ -1394,13 +1780,13 @@ async fn register_table() -> Result<()> { | d | 8.793968289758968 | | e | 10.206140546981722 | +----+-----------------------------+ - "### + " ); // the results are the same as the results from the view, modulo the leaf table name assert_snapshot!( batches_to_sort_string(table_results), - @r###" + @r" +----+---------------------+ | c1 | 
sum(test_table.c12) | +----+---------------------+ @@ -1410,11 +1796,28 @@ async fn register_table() -> Result<()> { | d | 8.793968289758968 | | e | 10.206140546981722 | +----+---------------------+ - "### + " ); Ok(()) } +#[tokio::test] +async fn register_temporary_table() -> Result<()> { + let df = test_table().await?.select_columns(&["c1", "c12"])?; + let ctx = SessionContext::new(); + let df_impl = DataFrame::new(ctx.state(), df.logical_plan().clone()); + + let df_table_provider = df_impl.clone().into_temporary_view(); + + // check that we set the correct table_type + assert_eq!(df_table_provider.table_type(), TableType::Temporary); + + // check that we can register a dataframe as a temporary table + ctx.register_table("test_table", df_table_provider)?; + + Ok(()) +} + /// Compare the formatted string representation of two plans for equality fn assert_same_plan(plan1: &LogicalPlan, plan2: &LogicalPlan) { assert_eq!(format!("{plan1:?}"), format!("{plan2:?}")); @@ -1442,7 +1845,7 @@ async fn with_column() -> Result<()> { assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" +----+----+-----+-----+ | c1 | c2 | c3 | sum | +----+----+-----+-----+ @@ -1453,7 +1856,7 @@ async fn with_column() -> Result<()> { | a | 3 | 14 | 17 | | a | 3 | 17 | 20 | +----+----+-----+-----+ - "### + " ); // check that col with the same name overwritten @@ -1465,7 +1868,7 @@ async fn with_column() -> Result<()> { assert_snapshot!( batches_to_sort_string(&df_results_overwrite), - @r###" + @r" +-----+----+-----+-----+ | c1 | c2 | c3 | sum | +-----+----+-----+-----+ @@ -1476,7 +1879,7 @@ async fn with_column() -> Result<()> { | 17 | 3 | 14 | 17 | | 20 | 3 | 17 | 20 | +-----+----+-----+-----+ - "### + " ); // check that col with the same name overwritten using same name as reference @@ -1488,7 +1891,7 @@ async fn with_column() -> Result<()> { assert_snapshot!( batches_to_sort_string(&df_results_overwrite_self), - @r###" + @r" +----+----+-----+-----+ | c1 | c2 | c3 | sum | 
+----+----+-----+-----+ @@ -1499,7 +1902,7 @@ async fn with_column() -> Result<()> { | a | 4 | 14 | 17 | | a | 4 | 17 | 20 | +----+----+-----+-----+ - "### + " ); Ok(()) @@ -1527,14 +1930,14 @@ async fn test_window_function_with_column() -> Result<()> { let df_results = df.clone().collect().await?; assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" +----+----+-----+-----+---+ | c1 | c2 | c3 | s | r | +----+----+-----+-----+---+ | c | 2 | 1 | 3 | 1 | | d | 5 | -40 | -35 | 2 | +----+----+-----+-----+---+ - "### + " ); Ok(()) @@ -1569,13 +1972,13 @@ async fn with_column_join_same_columns() -> Result<()> { let df_results = df.clone().collect().await?; assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" +----+----+ | c1 | c1 | +----+----+ | a | a | +----+----+ - "### + " ); let df_with_column = df.clone().with_column("new_column", lit(true))?; @@ -1598,7 +2001,7 @@ async fn with_column_join_same_columns() -> Result<()> { assert_snapshot!( df_with_column.clone().into_optimized_plan().unwrap(), - @r###" + @r" Projection: t1.c1, t2.c1, Boolean(true) AS new_column Sort: t1.c1 ASC NULLS FIRST, fetch=1 Inner Join: t1.c1 = t2.c1 @@ -1606,20 +2009,20 @@ async fn with_column_join_same_columns() -> Result<()> { TableScan: aggregate_test_100 projection=[c1] SubqueryAlias: t2 TableScan: aggregate_test_100 projection=[c1] - "### + " ); let df_results = df_with_column.collect().await?; assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" +----+----+------------+ | c1 | c1 | new_column | +----+----+------------+ | a | a | true | +----+----+------------+ - "### + " ); Ok(()) @@ -1669,13 +2072,13 @@ async fn with_column_renamed() -> Result<()> { assert_snapshot!( batches_to_sort_string(batches), - @r###" + @r" +-----+-----+-----+-------+ | one | two | c3 | total | +-----+-----+-----+-------+ | a | 3 | -72 | -69 | +-----+-----+-----+-------+ - "### + " ); Ok(()) @@ -1740,13 +2143,13 @@ async fn with_column_renamed_join() -> Result<()> 
{ let df_results = df.clone().collect().await?; assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" +----+----+-----+----+----+-----+ | c1 | c2 | c3 | c1 | c2 | c3 | +----+----+-----+----+----+-----+ | a | 1 | -85 | a | 1 | -85 | +----+----+-----+----+----+-----+ - "### + " ); let df_renamed = df.clone().with_column_renamed("t1.c1", "AAA")?; @@ -1769,7 +2172,7 @@ async fn with_column_renamed_join() -> Result<()> { assert_snapshot!( df_renamed.clone().into_optimized_plan().unwrap(), - @r###" + @r" Projection: t1.c1 AS AAA, t1.c2, t1.c3, t2.c1, t2.c2, t2.c3 Sort: t1.c1 ASC NULLS FIRST, t1.c2 ASC NULLS FIRST, t1.c3 ASC NULLS FIRST, t2.c1 ASC NULLS FIRST, t2.c2 ASC NULLS FIRST, t2.c3 ASC NULLS FIRST, fetch=1 Inner Join: t1.c1 = t2.c1 @@ -1777,20 +2180,20 @@ async fn with_column_renamed_join() -> Result<()> { TableScan: aggregate_test_100 projection=[c1, c2, c3] SubqueryAlias: t2 TableScan: aggregate_test_100 projection=[c1, c2, c3] - "### + " ); let df_results = df_renamed.collect().await?; assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" +-----+----+-----+----+----+-----+ | AAA | c2 | c3 | c1 | c2 | c3 | +-----+----+-----+----+----+-----+ | a | 1 | -85 | a | 1 | -85 | +-----+----+-----+----+----+-----+ - "### + " ); Ok(()) @@ -1825,13 +2228,13 @@ async fn with_column_renamed_case_sensitive() -> Result<()> { assert_snapshot!( batches_to_sort_string(res), - @r###" + @r" +---------+ | CoLuMn1 | +---------+ | a | +---------+ - "### + " ); let df_renamed = df_renamed @@ -1841,13 +2244,13 @@ async fn with_column_renamed_case_sensitive() -> Result<()> { assert_snapshot!( batches_to_sort_string(&df_renamed), - @r###" + @r" +----+ | c1 | +----+ | a | +----+ - "### + " ); Ok(()) @@ -1885,19 +2288,19 @@ async fn describe_lookup_via_quoted_identifier() -> Result<()> { .await?; assert_snapshot!( batches_to_sort_string(&describe_result.clone().collect().await?), - @r###" - +------------+--------------+ - | describe | CoLu.Mn["1"] | - 
+------------+--------------+ - | count | 1 | - | max | a | - | mean | null | - | median | null | - | min | a | - | null_count | 0 | - | std | null | - +------------+--------------+ - "### + @r#" + +------------+--------------+ + | describe | CoLu.Mn["1"] | + +------------+--------------+ + | count | 1 | + | max | a | + | mean | null | + | median | null | + | min | a | + | null_count | 0 | + | std | null | + +------------+--------------+ + "# ); Ok(()) @@ -1915,13 +2318,13 @@ async fn cast_expr_test() -> Result<()> { df.clone().show().await?; assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" +----+----+-----+ | c2 | c3 | sum | +----+----+-----+ | 2 | 1 | 3 | +----+----+-----+ - "### + " ); Ok(()) @@ -1937,12 +2340,14 @@ async fn row_writer_resize_test() -> Result<()> { let data = RecordBatch::try_new( schema, - vec![ - Arc::new(StringArray::from(vec![ - Some("2a0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), - Some("3a0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000800"), - ])) - ], + vec![Arc::new(StringArray::from(vec![ + Some( + "2a0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", + ), + Some( + "3a0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000800", + ), + ]))], )?; let ctx = SessionContext::new(); @@ -1981,14 +2386,14 @@ async fn with_column_name() -> Result<()> { assert_snapshot!( batches_to_sort_string(&df_results), - @r###" + @r" +------+-------+ | f.c1 | f.c2 | +------+-------+ | 1 | hello | | 10 | hello | +------+-------+ - "### + " ); Ok(()) @@ -2024,13 +2429,13 @@ async fn cache_test() -> Result<()> { let cached_df_results = cached_df.collect().await?; assert_snapshot!( 
batches_to_sort_string(&cached_df_results), - @r###" + @r" +----+----+-----+ | c2 | c3 | sum | +----+----+-----+ | 2 | 1 | 3 | +----+----+-----+ - "### + " ); assert_eq!(&df_results, &cached_df_results); @@ -2038,6 +2443,29 @@ async fn cache_test() -> Result<()> { Ok(()) } +#[tokio::test] +async fn cache_producer_test() -> Result<()> { + let df = test_table_with_cache_factory() + .await? + .select_columns(&["c2", "c3"])? + .limit(0, Some(1))? + .with_column("sum", cast(col("c2") + col("c3"), DataType::Int64))?; + + let cached_df = df.clone().cache().await?; + + assert_snapshot!( + cached_df.clone().into_optimized_plan().unwrap(), + @r" + CacheNode + Projection: aggregate_test_100.c2, aggregate_test_100.c3, CAST(CAST(aggregate_test_100.c2 AS Int64) + CAST(aggregate_test_100.c3 AS Int64) AS Int64) AS sum + Projection: aggregate_test_100.c2, aggregate_test_100.c3 + Limit: skip=0, fetch=1 + TableScan: aggregate_test_100, fetch=1 + " + ); + Ok(()) +} + #[tokio::test] async fn partition_aware_union() -> Result<()> { let left = test_table().await?.select_columns(&["c1", "c2"])?; @@ -2145,6 +2573,7 @@ async fn verify_join_output_partitioning() -> Result<()> { JoinType::LeftAnti, JoinType::RightAnti, JoinType::LeftMark, + JoinType::RightMark, ]; let default_partition_count = SessionConfig::new().target_partitions(); @@ -2178,7 +2607,8 @@ async fn verify_join_output_partitioning() -> Result<()> { JoinType::Inner | JoinType::Right | JoinType::RightSemi - | JoinType::RightAnti => { + | JoinType::RightAnti + | JoinType::RightMark => { let right_exprs: Vec> = vec![ Arc::new(Column::new_with_schema("c2_c1", &join_schema)?), Arc::new(Column::new_with_schema("c2_c2", &join_schema)?), @@ -2300,18 +2730,18 @@ async fn filtered_aggr_with_param_values() -> Result<()> { let df = ctx .sql("select count (c2) filter (where c3 > $1) from table1") .await? 
- .with_param_values(ParamValues::List(vec![ScalarValue::from(10u64)])); + .with_param_values(vec![ScalarValue::from(10u64)]); let df_results = df?.collect().await?; assert_snapshot!( batches_to_string(&df_results), - @r###" + @r" +------------------------------------------------+ | count(table1.c2) FILTER (WHERE table1.c3 > $1) | +------------------------------------------------+ | 54 | +------------------------------------------------+ - "### + " ); Ok(()) @@ -2359,7 +2789,7 @@ async fn write_parquet_with_order() -> Result<()> { assert_snapshot!( batches_to_string(&results), - @r###" + @r" +---+---+ | a | b | +---+---+ @@ -2369,7 +2799,7 @@ async fn write_parquet_with_order() -> Result<()> { | 5 | 3 | | 7 | 4 | +---+---+ - "### + " ); Ok(()) @@ -2417,7 +2847,7 @@ async fn write_csv_with_order() -> Result<()> { assert_snapshot!( batches_to_string(&results), - @r###" + @r" +---+---+ | a | b | +---+---+ @@ -2427,7 +2857,7 @@ async fn write_csv_with_order() -> Result<()> { | 5 | 3 | | 7 | 4 | +---+---+ - "### + " ); Ok(()) } @@ -2465,7 +2895,7 @@ async fn write_json_with_order() -> Result<()> { ctx.register_json( "data", test_path.to_str().unwrap(), - NdJsonReadOptions::default().schema(&schema), + JsonReadOptions::default().schema(&schema), ) .await?; @@ -2474,7 +2904,7 @@ async fn write_json_with_order() -> Result<()> { assert_snapshot!( batches_to_string(&results), - @r###" + @r" +---+---+ | a | b | +---+---+ @@ -2484,7 +2914,7 @@ async fn write_json_with_order() -> Result<()> { | 5 | 3 | | 7 | 4 | +---+---+ - "### + " ); Ok(()) } @@ -2533,7 +2963,7 @@ async fn write_table_with_order() -> Result<()> { assert_snapshot!( batches_to_string(&results), - @r###" + @r" +-----------+ | tablecol1 | +-----------+ @@ -2543,7 +2973,7 @@ async fn write_table_with_order() -> Result<()> { | x | | z | +-----------+ - "### + " ); Ok(()) } @@ -2570,50 +3000,44 @@ async fn test_count_wildcard_on_sort() -> Result<()> { assert_snapshot!( pretty_format_batches(&sql_results).unwrap(), - 
@r###" - +---------------+------------------------------------------------------------------------------------------------------------+ - | plan_type | plan | - +---------------+------------------------------------------------------------------------------------------------------------+ - | logical_plan | Projection: t1.b, count(*) | - | | Sort: count(Int64(1)) AS count(*) AS count(*) ASC NULLS LAST | - | | Projection: t1.b, count(Int64(1)) AS count(*), count(Int64(1)) | - | | Aggregate: groupBy=[[t1.b]], aggr=[[count(Int64(1))]] | - | | TableScan: t1 projection=[b] | - | physical_plan | ProjectionExec: expr=[b@0 as b, count(*)@1 as count(*)] | - | | SortPreservingMergeExec: [count(Int64(1))@2 ASC NULLS LAST] | - | | SortExec: expr=[count(Int64(1))@2 ASC NULLS LAST], preserve_partitioning=[true] | - | | ProjectionExec: expr=[b@0 as b, count(Int64(1))@1 as count(*), count(Int64(1))@1 as count(Int64(1))] | - | | AggregateExec: mode=FinalPartitioned, gby=[b@0 as b], aggr=[count(Int64(1))] | - | | CoalesceBatchesExec: target_batch_size=8192 | - | | RepartitionExec: partitioning=Hash([b@0], 4), input_partitions=4 | - | | RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 | - | | AggregateExec: mode=Partial, gby=[b@0 as b], aggr=[count(Int64(1))] | - | | DataSourceExec: partitions=1, partition_sizes=[1] | - | | | - +---------------+------------------------------------------------------------------------------------------------------------+ - "### + @r" + +---------------+------------------------------------------------------------------------------------+ + | plan_type | plan | + +---------------+------------------------------------------------------------------------------------+ + | logical_plan | Sort: count(*) ASC NULLS LAST | + | | Projection: t1.b, count(Int64(1)) AS count(*) | + | | Aggregate: groupBy=[[t1.b]], aggr=[[count(Int64(1))]] | + | | TableScan: t1 projection=[b] | + | physical_plan | SortPreservingMergeExec: [count(*)@1 ASC NULLS LAST] 
| + | | SortExec: expr=[count(*)@1 ASC NULLS LAST], preserve_partitioning=[true] | + | | ProjectionExec: expr=[b@0 as b, count(Int64(1))@1 as count(*)] | + | | AggregateExec: mode=FinalPartitioned, gby=[b@0 as b], aggr=[count(Int64(1))] | + | | RepartitionExec: partitioning=Hash([b@0], 4), input_partitions=1 | + | | AggregateExec: mode=Partial, gby=[b@0 as b], aggr=[count(Int64(1))] | + | | DataSourceExec: partitions=1, partition_sizes=[1] | + | | | + +---------------+------------------------------------------------------------------------------------+ + " ); assert_snapshot!( pretty_format_batches(&df_results).unwrap(), - @r###" - +---------------+--------------------------------------------------------------------------------+ - | plan_type | plan | - +---------------+--------------------------------------------------------------------------------+ - | logical_plan | Sort: count(*) ASC NULLS LAST | - | | Aggregate: groupBy=[[t1.b]], aggr=[[count(Int64(1)) AS count(*)]] | - | | TableScan: t1 projection=[b] | - | physical_plan | SortPreservingMergeExec: [count(*)@1 ASC NULLS LAST] | - | | SortExec: expr=[count(*)@1 ASC NULLS LAST], preserve_partitioning=[true] | - | | AggregateExec: mode=FinalPartitioned, gby=[b@0 as b], aggr=[count(*)] | - | | CoalesceBatchesExec: target_batch_size=8192 | - | | RepartitionExec: partitioning=Hash([b@0], 4), input_partitions=4 | - | | RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 | - | | AggregateExec: mode=Partial, gby=[b@0 as b], aggr=[count(*)] | - | | DataSourceExec: partitions=1, partition_sizes=[1] | - | | | - +---------------+--------------------------------------------------------------------------------+ - "### + @r" + +---------------+----------------------------------------------------------------------------+ + | plan_type | plan | + +---------------+----------------------------------------------------------------------------+ + | logical_plan | Sort: count(*) AS count(*) ASC NULLS LAST | + | | 
Aggregate: groupBy=[[t1.b]], aggr=[[count(Int64(1)) AS count(*)]] | + | | TableScan: t1 projection=[b] | + | physical_plan | SortPreservingMergeExec: [count(*)@1 ASC NULLS LAST] | + | | SortExec: expr=[count(*)@1 ASC NULLS LAST], preserve_partitioning=[true] | + | | AggregateExec: mode=FinalPartitioned, gby=[b@0 as b], aggr=[count(*)] | + | | RepartitionExec: partitioning=Hash([b@0], 4), input_partitions=1 | + | | AggregateExec: mode=Partial, gby=[b@0 as b], aggr=[count(*)] | + | | DataSourceExec: partitions=1, partition_sizes=[1] | + | | | + +---------------+----------------------------------------------------------------------------+ + " ); Ok(()) } @@ -2631,23 +3055,22 @@ async fn test_count_wildcard_on_where_in() -> Result<()> { assert_snapshot!( pretty_format_batches(&sql_results).unwrap(), @r" - +---------------+------------------------------------------------------------------------------------------------------------------------+ - | plan_type | plan | - +---------------+------------------------------------------------------------------------------------------------------------------------+ - | logical_plan | LeftSemi Join: CAST(t1.a AS Int64) = __correlated_sq_1.count(*) | - | | TableScan: t1 projection=[a, b] | - | | SubqueryAlias: __correlated_sq_1 | - | | Projection: count(Int64(1)) AS count(*) | - | | Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] | - | | TableScan: t2 projection=[] | - | physical_plan | CoalesceBatchesExec: target_batch_size=8192 | - | | HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(count(*)@0, CAST(t1.a AS Int64)@2)], projection=[a@0, b@1] | - | | ProjectionExec: expr=[4 as count(*)] | - | | PlaceholderRowExec | - | | ProjectionExec: expr=[a@0 as a, b@1 as b, CAST(a@0 AS Int64) as CAST(t1.a AS Int64)] | - | | DataSourceExec: partitions=1, partition_sizes=[1] | - | | | - +---------------+------------------------------------------------------------------------------------------------------------------------+ + 
+---------------+----------------------------------------------------------------------------------------------------------------------+ + | plan_type | plan | + +---------------+----------------------------------------------------------------------------------------------------------------------+ + | logical_plan | LeftSemi Join: CAST(t1.a AS Int64) = __correlated_sq_1.count(*) | + | | TableScan: t1 projection=[a, b] | + | | SubqueryAlias: __correlated_sq_1 | + | | Projection: count(Int64(1)) AS count(*) | + | | Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] | + | | TableScan: t2 projection=[] | + | physical_plan | HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(count(*)@0, CAST(t1.a AS Int64)@2)], projection=[a@0, b@1] | + | | ProjectionExec: expr=[4 as count(*)] | + | | PlaceholderRowExec | + | | ProjectionExec: expr=[a@0 as a, b@1 as b, CAST(a@0 AS Int64) as CAST(t1.a AS Int64)] | + | | DataSourceExec: partitions=1, partition_sizes=[1] | + | | | + +---------------+----------------------------------------------------------------------------------------------------------------------+ " ); @@ -2677,22 +3100,21 @@ async fn test_count_wildcard_on_where_in() -> Result<()> { assert_snapshot!( pretty_format_batches(&df_results).unwrap(), @r" - +---------------+------------------------------------------------------------------------------------------------------------------------+ - | plan_type | plan | - +---------------+------------------------------------------------------------------------------------------------------------------------+ - | logical_plan | LeftSemi Join: CAST(t1.a AS Int64) = __correlated_sq_1.count(*) | - | | TableScan: t1 projection=[a, b] | - | | SubqueryAlias: __correlated_sq_1 | - | | Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] | - | | TableScan: t2 projection=[] | - | physical_plan | CoalesceBatchesExec: target_batch_size=8192 | - | | HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(count(*)@0, CAST(t1.a 
AS Int64)@2)], projection=[a@0, b@1] | - | | ProjectionExec: expr=[4 as count(*)] | - | | PlaceholderRowExec | - | | ProjectionExec: expr=[a@0 as a, b@1 as b, CAST(a@0 AS Int64) as CAST(t1.a AS Int64)] | - | | DataSourceExec: partitions=1, partition_sizes=[1] | - | | | - +---------------+------------------------------------------------------------------------------------------------------------------------+ + +---------------+----------------------------------------------------------------------------------------------------------------------+ + | plan_type | plan | + +---------------+----------------------------------------------------------------------------------------------------------------------+ + | logical_plan | LeftSemi Join: CAST(t1.a AS Int64) = __correlated_sq_1.count(*) | + | | TableScan: t1 projection=[a, b] | + | | SubqueryAlias: __correlated_sq_1 | + | | Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] | + | | TableScan: t2 projection=[] | + | physical_plan | HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(count(*)@0, CAST(t1.a AS Int64)@2)], projection=[a@0, b@1] | + | | ProjectionExec: expr=[4 as count(*)] | + | | PlaceholderRowExec | + | | ProjectionExec: expr=[a@0 as a, b@1 as b, CAST(a@0 AS Int64) as CAST(t1.a AS Int64)] | + | | DataSourceExec: partitions=1, partition_sizes=[1] | + | | | + +---------------+----------------------------------------------------------------------------------------------------------------------+ " ); @@ -2711,23 +3133,20 @@ async fn test_count_wildcard_on_where_exist() -> Result<()> { assert_snapshot!( pretty_format_batches(&sql_results).unwrap(), - @r###" - +---------------+---------------------------------------------------------+ - | plan_type | plan | - +---------------+---------------------------------------------------------+ - | logical_plan | LeftSemi Join: | - | | TableScan: t1 projection=[a, b] | - | | SubqueryAlias: __correlated_sq_1 | - | | Projection: | - | | Aggregate: 
groupBy=[[]], aggr=[[count(Int64(1))]] | - | | TableScan: t2 projection=[] | - | physical_plan | NestedLoopJoinExec: join_type=RightSemi | - | | ProjectionExec: expr=[] | - | | PlaceholderRowExec | - | | DataSourceExec: partitions=1, partition_sizes=[1] | - | | | - +---------------+---------------------------------------------------------+ - "### + @r" + +---------------+-----------------------------------------------------+ + | plan_type | plan | + +---------------+-----------------------------------------------------+ + | logical_plan | LeftSemi Join: | + | | TableScan: t1 projection=[a, b] | + | | SubqueryAlias: __correlated_sq_1 | + | | EmptyRelation: rows=1 | + | physical_plan | NestedLoopJoinExec: join_type=RightSemi | + | | PlaceholderRowExec | + | | DataSourceExec: partitions=1, partition_sizes=[1] | + | | | + +---------------+-----------------------------------------------------+ + " ); let df_results = ctx @@ -2750,92 +3169,194 @@ async fn test_count_wildcard_on_where_exist() -> Result<()> { assert_snapshot!( pretty_format_batches(&df_results).unwrap(), - @r###" - +---------------+---------------------------------------------------------------------+ - | plan_type | plan | - +---------------+---------------------------------------------------------------------+ - | logical_plan | LeftSemi Join: | - | | TableScan: t1 projection=[a, b] | - | | SubqueryAlias: __correlated_sq_1 | - | | Projection: | - | | Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] | - | | TableScan: t2 projection=[] | - | physical_plan | NestedLoopJoinExec: join_type=RightSemi | - | | ProjectionExec: expr=[] | - | | PlaceholderRowExec | - | | DataSourceExec: partitions=1, partition_sizes=[1] | - | | | - +---------------+---------------------------------------------------------------------+ - "### + @r" + +---------------+-----------------------------------------------------+ + | plan_type | plan | + +---------------+-----------------------------------------------------+ + | 
logical_plan | LeftSemi Join: | + | | TableScan: t1 projection=[a, b] | + | | SubqueryAlias: __correlated_sq_1 | + | | EmptyRelation: rows=1 | + | physical_plan | NestedLoopJoinExec: join_type=RightSemi | + | | PlaceholderRowExec | + | | DataSourceExec: partitions=1, partition_sizes=[1] | + | | | + +---------------+-----------------------------------------------------+ + " ); Ok(()) } -#[tokio::test] -async fn test_count_wildcard_on_window() -> Result<()> { - let ctx = create_join_context()?; +#[tokio::test] +async fn test_count_wildcard_on_window() -> Result<()> { + let ctx = create_join_context()?; + + let sql_results = ctx + .sql("select count(*) OVER(ORDER BY a DESC RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING) from t1") + .await? + .explain(false, false)? + .collect() + .await?; + + assert_snapshot!( + pretty_format_batches(&sql_results).unwrap(), + @r#" + +---------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + | plan_type | plan | + +---------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + | logical_plan | Projection: count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING AS count(*) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING | + | | WindowAggr: windowExpr=[[count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING]] | + | | TableScan: t1 projection=[a] | + | physical_plan | ProjectionExec: expr=[count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] 
RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING@1 as count(*) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING] | + | | BoundedWindowAggExec: wdw=[count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING: Field { "count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING": Int64 }, frame: RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING], mode=[Sorted] | + | | SortExec: expr=[a@0 DESC], preserve_partitioning=[false] | + | | DataSourceExec: partitions=1, partition_sizes=[1] | + | | | + +---------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + "# + ); + + let df_results = ctx + .table("t1") + .await? + .select(vec![ + count_all_window() + .order_by(vec![Sort::new(col("a"), false, true)]) + .window_frame(WindowFrame::new_bounds( + WindowFrameUnits::Range, + WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))), + WindowFrameBound::Following(ScalarValue::UInt32(Some(2))), + )) + .build() + .unwrap(), + ])? + .explain(false, false)? 
+ .collect() + .await?; + + assert_snapshot!( + pretty_format_batches(&df_results).unwrap(), + @r#" + +---------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + | plan_type | plan | + +---------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + | logical_plan | Projection: count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING | + | | WindowAggr: windowExpr=[[count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING]] | + | | TableScan: t1 projection=[a] | + | physical_plan | ProjectionExec: expr=[count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING@1 as count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING] | + | | BoundedWindowAggExec: wdw=[count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING: Field { "count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING": Int64 }, frame: RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING], mode=[Sorted] | + | | SortExec: expr=[a@0 DESC], preserve_partitioning=[false] | + | | DataSourceExec: partitions=1, partition_sizes=[1] | + | | | + 
+---------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + "# + ); + + Ok(()) +} + +#[tokio::test] +// Test with `repartition_sorts` disabled, causing a full resort of the data +async fn union_with_mix_of_presorted_and_explicitly_resorted_inputs_with_repartition_sorts_false() +-> Result<()> { + assert_snapshot!( + union_with_mix_of_presorted_and_explicitly_resorted_inputs_impl(false).await?, + @r" + AggregateExec: mode=Final, gby=[id@0 as id], aggr=[], ordering_mode=Sorted + SortPreservingMergeExec: [id@0 ASC NULLS LAST] + AggregateExec: mode=Partial, gby=[id@0 as id], aggr=[], ordering_mode=Sorted + UnionExec + DataSourceExec: file_groups={1 group: [[{testdata}/alltypes_tiny_pages.parquet]]}, projection=[id], output_ordering=[id@0 ASC NULLS LAST], file_type=parquet + SortExec: expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[{testdata}/alltypes_tiny_pages.parquet]]}, projection=[id], file_type=parquet + "); + Ok(()) +} + +#[tokio::test] +// Test with `repartition_sorts` enabled to preserve pre-sorted partitions and avoid resorting +async fn union_with_mix_of_presorted_and_explicitly_resorted_inputs_with_repartition_sorts_true() +-> Result<()> { + assert_snapshot!( + union_with_mix_of_presorted_and_explicitly_resorted_inputs_impl(true).await?, + @r" + AggregateExec: mode=Final, gby=[id@0 as id], aggr=[], ordering_mode=Sorted + SortPreservingMergeExec: [id@0 ASC NULLS LAST] + AggregateExec: mode=Partial, gby=[id@0 as id], aggr=[], ordering_mode=Sorted + UnionExec + DataSourceExec: file_groups={1 group: [[{testdata}/alltypes_tiny_pages.parquet]]}, projection=[id], output_ordering=[id@0 ASC NULLS LAST], file_type=parquet + SortExec: expr=[id@0 ASC NULLS LAST], 
preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[{testdata}/alltypes_tiny_pages.parquet]]}, projection=[id], file_type=parquet + "); + + Ok(()) +} + +async fn union_with_mix_of_presorted_and_explicitly_resorted_inputs_impl( + repartition_sorts: bool, +) -> Result { + let config = SessionConfig::default() + .with_target_partitions(1) + .with_repartition_sorts(repartition_sorts); + let ctx = SessionContext::new_with_config(config); + + let testdata = parquet_test_data(); + + // Register "sorted" table, that is sorted + ctx.register_parquet( + "sorted", + &format!("{testdata}/alltypes_tiny_pages.parquet"), + ParquetReadOptions::default() + .file_sort_order(vec![vec![col("id").sort(true, false)]]), + ) + .await?; + + // Register "unsorted" table + ctx.register_parquet( + "unsorted", + &format!("{testdata}/alltypes_tiny_pages.parquet"), + ParquetReadOptions::default(), + ) + .await?; + + let source_sorted = ctx + .table("sorted") + .await + .unwrap() + .select(vec![col("id")]) + .unwrap(); + + let source_unsorted = ctx + .table("unsorted") + .await + .unwrap() + .select(vec![col("id")]) + .unwrap(); + + let source_unsorted_resorted = + source_unsorted.sort(vec![col("id").sort(true, false)])?; + + let union = source_sorted.union(source_unsorted_resorted)?; - let sql_results = ctx - .sql("select count(*) OVER(ORDER BY a DESC RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING) from t1") - .await? - .explain(false, false)? 
- .collect() - .await?; + let agg = union.aggregate(vec![col("id")], vec![])?; - assert_snapshot!( - pretty_format_batches(&sql_results).unwrap(), - @r###" - +---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ - | plan_type | plan | - +---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ - | logical_plan | Projection: count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING AS count(*) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING | - | | WindowAggr: windowExpr=[[count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING]] | - | | TableScan: t1 projection=[a] | - | physical_plan | ProjectionExec: expr=[count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING@1 as count(*) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING] | - | | BoundedWindowAggExec: wdw=[count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING: Ok(Field { name: "count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING", data_type: Int64, 
nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt32(6)), end_bound: Following(UInt32(2)), is_causal: false }], mode=[Sorted] | - | | SortExec: expr=[a@0 DESC], preserve_partitioning=[false] | - | | DataSourceExec: partitions=1, partition_sizes=[1] | - | | | - +---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ - "### - ); + let df = agg; - let df_results = ctx - .table("t1") - .await? - .select(vec![count_all_window() - .order_by(vec![Sort::new(col("a"), false, true)]) - .window_frame(WindowFrame::new_bounds( - WindowFrameUnits::Range, - WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))), - WindowFrameBound::Following(ScalarValue::UInt32(Some(2))), - )) - .build() - .unwrap()])? - .explain(false, false)? 
- .collect() - .await?; + // To be able to remove user specific paths from the plan, for stable assertions + let testdata_clean = Path::new(&testdata).canonicalize()?.display().to_string(); + let testdata_clean = testdata_clean.strip_prefix("/").unwrap_or(&testdata_clean); - assert_snapshot!( - pretty_format_batches(&df_results).unwrap(), - @r###" - +---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ - | plan_type | plan | - +---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ - | logical_plan | Projection: count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING | - | | WindowAggr: windowExpr=[[count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING]] | - | | TableScan: t1 projection=[a] | - | physical_plan | ProjectionExec: expr=[count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING@1 as count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING] | - | | BoundedWindowAggExec: wdw=[count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING: Ok(Field { name: 
"count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt32(6)), end_bound: Following(UInt32(2)), is_causal: false }], mode=[Sorted] | - | | SortExec: expr=[a@0 DESC], preserve_partitioning=[false] | - | | DataSourceExec: partitions=1, partition_sizes=[1] | - | | | - +---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ - "### - ); + // Use displayable() rather than explain().collect() to avoid table formatting issues. We need + // to replace machine-specific paths with variable lengths, which breaks table alignment and + // causes snapshot mismatches. 
+ let physical_plan = df.create_physical_plan().await?; + let displayable_plan = displayable(physical_plan.as_ref()) + .indent(true) + .to_string() + .replace(testdata_clean, "{testdata}"); - Ok(()) + Ok(displayable_plan) } #[tokio::test] @@ -2852,7 +3373,7 @@ async fn test_count_wildcard_on_aggregate() -> Result<()> { assert_snapshot!( pretty_format_batches(&sql_results).unwrap(), - @r###" + @r" +---------------+-----------------------------------------------------+ | plan_type | plan | +---------------+-----------------------------------------------------+ @@ -2863,7 +3384,7 @@ async fn test_count_wildcard_on_aggregate() -> Result<()> { | | PlaceholderRowExec | | | | +---------------+-----------------------------------------------------+ - "### + " ); // add `.select(vec![count_wildcard()])?` to make sure we can analyze all node instead of just top node. @@ -2878,7 +3399,7 @@ async fn test_count_wildcard_on_aggregate() -> Result<()> { assert_snapshot!( pretty_format_batches(&df_results).unwrap(), - @r###" + @r" +---------------+---------------------------------------------------------------+ | plan_type | plan | +---------------+---------------------------------------------------------------+ @@ -2888,7 +3409,7 @@ async fn test_count_wildcard_on_aggregate() -> Result<()> { | | PlaceholderRowExec | | | | +---------------+---------------------------------------------------------------+ - "### + " ); Ok(()) @@ -2908,32 +3429,31 @@ async fn test_count_wildcard_on_where_scalar_subquery() -> Result<()> { assert_snapshot!( pretty_format_batches(&sql_results).unwrap(), @r" - +---------------+---------------------------------------------------------------------------------------------------------------------------+ - | plan_type | plan | - +---------------+---------------------------------------------------------------------------------------------------------------------------+ - | logical_plan | Projection: t1.a, t1.b | - | | Filter: CASE WHEN 
__scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE __scalar_sq_1.count(*) END > Int64(0) | - | | Projection: t1.a, t1.b, __scalar_sq_1.count(*), __scalar_sq_1.__always_true | - | | Left Join: t1.a = __scalar_sq_1.a | - | | TableScan: t1 projection=[a, b] | - | | SubqueryAlias: __scalar_sq_1 | - | | Projection: count(Int64(1)) AS count(*), t2.a, Boolean(true) AS __always_true | - | | Aggregate: groupBy=[[t2.a]], aggr=[[count(Int64(1))]] | - | | TableScan: t2 projection=[a] | - | physical_plan | CoalesceBatchesExec: target_batch_size=8192 | - | | FilterExec: CASE WHEN __always_true@3 IS NULL THEN 0 ELSE count(*)@2 END > 0, projection=[a@0, b@1] | - | | CoalesceBatchesExec: target_batch_size=8192 | - | | HashJoinExec: mode=CollectLeft, join_type=Left, on=[(a@0, a@1)], projection=[a@0, b@1, count(*)@2, __always_true@4] | - | | DataSourceExec: partitions=1, partition_sizes=[1] | - | | ProjectionExec: expr=[count(Int64(1))@1 as count(*), a@0 as a, true as __always_true] | - | | AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[count(Int64(1))] | - | | CoalesceBatchesExec: target_batch_size=8192 | - | | RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=4 | - | | RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 | - | | AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[count(Int64(1))] | - | | DataSourceExec: partitions=1, partition_sizes=[1] | - | | | - +---------------+---------------------------------------------------------------------------------------------------------------------------+ + +---------------+----------------------------------------------------------------------------------------------------------------------------+ + | plan_type | plan | + +---------------+----------------------------------------------------------------------------------------------------------------------------+ + | logical_plan | Projection: t1.a, t1.b | + | | Filter: CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE 
__scalar_sq_1.count(*) END > Int64(0) | + | | Projection: t1.a, t1.b, __scalar_sq_1.count(*), __scalar_sq_1.__always_true | + | | Left Join: t1.a = __scalar_sq_1.a | + | | TableScan: t1 projection=[a, b] | + | | SubqueryAlias: __scalar_sq_1 | + | | Projection: count(Int64(1)) AS count(*), t2.a, Boolean(true) AS __always_true | + | | Aggregate: groupBy=[[t2.a]], aggr=[[count(Int64(1))]] | + | | TableScan: t2 projection=[a] | + | physical_plan | FilterExec: CASE WHEN __always_true@3 IS NULL THEN 0 ELSE count(*)@2 END > 0, projection=[a@0, b@1] | + | | RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 | + | | ProjectionExec: expr=[a@2 as a, b@3 as b, count(*)@0 as count(*), __always_true@1 as __always_true] | + | | HashJoinExec: mode=CollectLeft, join_type=Right, on=[(a@1, a@0)], projection=[count(*)@0, __always_true@2, a@3, b@4] | + | | CoalescePartitionsExec | + | | ProjectionExec: expr=[count(Int64(1))@1 as count(*), a@0 as a, true as __always_true] | + | | AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[count(Int64(1))] | + | | RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=1 | + | | AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[count(Int64(1))] | + | | DataSourceExec: partitions=1, partition_sizes=[1] | + | | DataSourceExec: partitions=1, partition_sizes=[1] | + | | | + +---------------+----------------------------------------------------------------------------------------------------------------------------+ " ); @@ -2965,32 +3485,31 @@ async fn test_count_wildcard_on_where_scalar_subquery() -> Result<()> { assert_snapshot!( pretty_format_batches(&df_results).unwrap(), @r" - +---------------+---------------------------------------------------------------------------------------------------------------------------+ - | plan_type | plan | - +---------------+---------------------------------------------------------------------------------------------------------------------------+ - | logical_plan | Projection: 
t1.a, t1.b | - | | Filter: CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE __scalar_sq_1.count(*) END > Int64(0) | - | | Projection: t1.a, t1.b, __scalar_sq_1.count(*), __scalar_sq_1.__always_true | - | | Left Join: t1.a = __scalar_sq_1.a | - | | TableScan: t1 projection=[a, b] | - | | SubqueryAlias: __scalar_sq_1 | - | | Projection: count(*), t2.a, Boolean(true) AS __always_true | - | | Aggregate: groupBy=[[t2.a]], aggr=[[count(Int64(1)) AS count(*)]] | - | | TableScan: t2 projection=[a] | - | physical_plan | CoalesceBatchesExec: target_batch_size=8192 | - | | FilterExec: CASE WHEN __always_true@3 IS NULL THEN 0 ELSE count(*)@2 END > 0, projection=[a@0, b@1] | - | | CoalesceBatchesExec: target_batch_size=8192 | - | | HashJoinExec: mode=CollectLeft, join_type=Left, on=[(a@0, a@1)], projection=[a@0, b@1, count(*)@2, __always_true@4] | - | | DataSourceExec: partitions=1, partition_sizes=[1] | - | | ProjectionExec: expr=[count(*)@1 as count(*), a@0 as a, true as __always_true] | - | | AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[count(*)] | - | | CoalesceBatchesExec: target_batch_size=8192 | - | | RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=4 | - | | RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 | - | | AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[count(*)] | - | | DataSourceExec: partitions=1, partition_sizes=[1] | - | | | - +---------------+---------------------------------------------------------------------------------------------------------------------------+ + +---------------+----------------------------------------------------------------------------------------------------------------------------+ + | plan_type | plan | + +---------------+----------------------------------------------------------------------------------------------------------------------------+ + | logical_plan | Projection: t1.a, t1.b | + | | Filter: CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) 
ELSE __scalar_sq_1.count(*) END > Int64(0) | + | | Projection: t1.a, t1.b, __scalar_sq_1.count(*), __scalar_sq_1.__always_true | + | | Left Join: t1.a = __scalar_sq_1.a | + | | TableScan: t1 projection=[a, b] | + | | SubqueryAlias: __scalar_sq_1 | + | | Projection: count(*), t2.a, Boolean(true) AS __always_true | + | | Aggregate: groupBy=[[t2.a]], aggr=[[count(Int64(1)) AS count(*)]] | + | | TableScan: t2 projection=[a] | + | physical_plan | FilterExec: CASE WHEN __always_true@3 IS NULL THEN 0 ELSE count(*)@2 END > 0, projection=[a@0, b@1] | + | | RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 | + | | ProjectionExec: expr=[a@2 as a, b@3 as b, count(*)@0 as count(*), __always_true@1 as __always_true] | + | | HashJoinExec: mode=CollectLeft, join_type=Right, on=[(a@1, a@0)], projection=[count(*)@0, __always_true@2, a@3, b@4] | + | | CoalescePartitionsExec | + | | ProjectionExec: expr=[count(*)@1 as count(*), a@0 as a, true as __always_true] | + | | AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[count(*)] | + | | RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=1 | + | | AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[count(*)] | + | | DataSourceExec: partitions=1, partition_sizes=[1] | + | | DataSourceExec: partitions=1, partition_sizes=[1] | + | | | + +---------------+----------------------------------------------------------------------------------------------------------------------------+ " ); @@ -3075,7 +3594,7 @@ async fn sort_on_unprojected_columns() -> Result<()> { assert_snapshot!( batches_to_string(&results), - @r###" + @r" +-----+ | a | +-----+ @@ -3084,7 +3603,7 @@ async fn sort_on_unprojected_columns() -> Result<()> { | 10 | | 1 | +-----+ - "### + " ); Ok(()) @@ -3122,7 +3641,7 @@ async fn sort_on_distinct_columns() -> Result<()> { assert_snapshot!( batches_to_string(&results), - @r###" + @r" +-----+ | a | +-----+ @@ -3130,7 +3649,7 @@ async fn sort_on_distinct_columns() -> Result<()> { | 10 | | 1 | 
+-----+ - "### + " ); Ok(()) } @@ -3261,14 +3780,14 @@ async fn filter_with_alias_overwrite() -> Result<()> { assert_snapshot!( batches_to_string(&results), - @r###" + @r" +------+ | a | +------+ | true | | true | +------+ - "### + " ); Ok(()) @@ -3297,7 +3816,7 @@ async fn select_with_alias_overwrite() -> Result<()> { assert_snapshot!( batches_to_string(&results), - @r###" + @r" +-------+ | a | +-------+ @@ -3306,7 +3825,7 @@ async fn select_with_alias_overwrite() -> Result<()> { | true | | false | +-------+ - "### + " ); Ok(()) @@ -3332,7 +3851,7 @@ async fn test_grouping_sets() -> Result<()> { assert_snapshot!( batches_to_string(&results), - @r###" + @r" +-----------+-----+---------------+ | a | b | count(test.a) | +-----------+-----+---------------+ @@ -3348,7 +3867,7 @@ async fn test_grouping_sets() -> Result<()> { | 123AbcDef | | 1 | | 123AbcDef | 100 | 1 | +-----------+-----+---------------+ - "### + " ); Ok(()) @@ -3375,7 +3894,7 @@ async fn test_grouping_sets_count() -> Result<()> { assert_snapshot!( batches_to_string(&results), - @r###" + @r" +----+----+-----------------+ | c1 | c2 | count(Int32(1)) | +----+----+-----------------+ @@ -3390,7 +3909,7 @@ async fn test_grouping_sets_count() -> Result<()> { | b | | 19 | | a | | 21 | +----+----+-----------------+ - "### + " ); Ok(()) @@ -3424,7 +3943,7 @@ async fn test_grouping_set_array_agg_with_overflow() -> Result<()> { assert_snapshot!( batches_to_string(&results), - @r###" + @r" +----+----+--------+---------------------+ | c1 | c2 | sum_c3 | avg_c3 | +----+----+--------+---------------------+ @@ -3464,7 +3983,7 @@ async fn test_grouping_set_array_agg_with_overflow() -> Result<()> { | a | 2 | -46 | -15.333333333333334 | | a | 1 | -88 | -17.6 | +----+----+--------+---------------------+ - "### + " ); Ok(()) @@ -3501,25 +4020,25 @@ async fn join_with_alias_filter() -> Result<()> { let actual = formatted.trim(); assert_snapshot!( actual, - @r###" + @r" Projection: t1.a, t2.a, t1.b, t1.c, t2.b, t2.c [a:UInt32, 
a:UInt32, b:Utf8, c:Int32, b:Utf8, c:Int32] Inner Join: t1.a + UInt32(3) = t2.a + UInt32(1) [a:UInt32, b:Utf8, c:Int32, a:UInt32, b:Utf8, c:Int32] TableScan: t1 projection=[a, b, c] [a:UInt32, b:Utf8, c:Int32] TableScan: t2 projection=[a, b, c] [a:UInt32, b:Utf8, c:Int32] - "### + " ); let results = df.collect().await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" +----+----+---+----+---+---+ | a | a | b | c | b | c | +----+----+---+----+---+---+ | 1 | 3 | a | 10 | a | 1 | | 11 | 13 | c | 30 | c | 3 | +----+----+---+----+---+---+ - "### + " ); Ok(()) @@ -3546,27 +4065,27 @@ async fn right_semi_with_alias_filter() -> Result<()> { let actual = formatted.trim(); assert_snapshot!( actual, - @r###" + @r" RightSemi Join: t1.a = t2.a [a:UInt32, b:Utf8, c:Int32] Projection: t1.a [a:UInt32] Filter: t1.c > Int32(1) [a:UInt32, c:Int32] TableScan: t1 projection=[a, c] [a:UInt32, c:Int32] Filter: t2.c > Int32(1) [a:UInt32, b:Utf8, c:Int32] TableScan: t2 projection=[a, b, c] [a:UInt32, b:Utf8, c:Int32] - "### + " ); let results = df.collect().await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" +-----+---+---+ | a | b | c | +-----+---+---+ | 10 | b | 2 | | 100 | d | 4 | +-----+---+---+ - "### + " ); Ok(()) @@ -3593,26 +4112,26 @@ async fn right_anti_filter_push_down() -> Result<()> { let actual = formatted.trim(); assert_snapshot!( actual, - @r###" + @r" RightAnti Join: t1.a = t2.a Filter: t2.c > Int32(1) [a:UInt32, b:Utf8, c:Int32] Projection: t1.a [a:UInt32] Filter: t1.c > Int32(1) [a:UInt32, c:Int32] TableScan: t1 projection=[a, c] [a:UInt32, c:Int32] TableScan: t2 projection=[a, b, c] [a:UInt32, b:Utf8, c:Int32] - "### + " ); let results = df.collect().await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" +----+---+---+ | a | b | c | +----+---+---+ | 13 | c | 3 | | 3 | a | 1 | +----+---+---+ - "### + " ); Ok(()) @@ -3625,37 +4144,37 @@ async fn unnest_columns() -> Result<()> { let results = df.collect().await?; 
assert_snapshot!( batches_to_sort_string(&results), - @r###" - +----------+---------------------------------+--------------------------+ - | shape_id | points | tags | - +----------+---------------------------------+--------------------------+ - | 1 | [{x: 5, y: -8}, {x: -3, y: -4}] | [tag1] | - | 2 | [{x: 6, y: 2}, {x: -2, y: -8}] | [tag1] | - | 3 | [{x: -9, y: -7}, {x: -2, y: 5}] | [tag1, tag2, tag3, tag4] | - | 4 | | [tag1, tag2, tag3] | - +----------+---------------------------------+--------------------------+ - "###); + @r" + +----------+---------------------------------+--------------------------+ + | shape_id | points | tags | + +----------+---------------------------------+--------------------------+ + | 1 | [{x: 5, y: -8}, {x: -3, y: -4}] | [tag1] | + | 2 | [{x: 6, y: 2}, {x: -2, y: -8}] | [tag1] | + | 3 | [{x: -9, y: -7}, {x: -2, y: 5}] | [tag1, tag2, tag3, tag4] | + | 4 | | [tag1, tag2, tag3] | + +----------+---------------------------------+--------------------------+ + "); // Unnest tags let df = table_with_nested_types(NUM_ROWS).await?; let results = df.unnest_columns(&["tags"])?.collect().await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" - +----------+---------------------------------+------+ - | shape_id | points | tags | - +----------+---------------------------------+------+ - | 1 | [{x: 5, y: -8}, {x: -3, y: -4}] | tag1 | - | 2 | [{x: 6, y: 2}, {x: -2, y: -8}] | tag1 | - | 3 | [{x: -9, y: -7}, {x: -2, y: 5}] | tag1 | - | 3 | [{x: -9, y: -7}, {x: -2, y: 5}] | tag2 | - | 3 | [{x: -9, y: -7}, {x: -2, y: 5}] | tag3 | - | 3 | [{x: -9, y: -7}, {x: -2, y: 5}] | tag4 | - | 4 | | tag1 | - | 4 | | tag2 | - | 4 | | tag3 | - +----------+---------------------------------+------+ - "###); + @r" + +----------+---------------------------------+------+ + | shape_id | points | tags | + +----------+---------------------------------+------+ + | 1 | [{x: 5, y: -8}, {x: -3, y: -4}] | tag1 | + | 2 | [{x: 6, y: 2}, {x: -2, y: -8}] | tag1 | + | 3 | 
[{x: -9, y: -7}, {x: -2, y: 5}] | tag1 | + | 3 | [{x: -9, y: -7}, {x: -2, y: 5}] | tag2 | + | 3 | [{x: -9, y: -7}, {x: -2, y: 5}] | tag3 | + | 3 | [{x: -9, y: -7}, {x: -2, y: 5}] | tag4 | + | 4 | | tag1 | + | 4 | | tag2 | + | 4 | | tag3 | + +----------+---------------------------------+------+ + "); // Test aggregate results for tags. let df = table_with_nested_types(NUM_ROWS).await?; @@ -3667,19 +4186,19 @@ async fn unnest_columns() -> Result<()> { let results = df.unnest_columns(&["points"])?.collect().await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" - +----------+----------------+--------------------------+ - | shape_id | points | tags | - +----------+----------------+--------------------------+ - | 1 | {x: -3, y: -4} | [tag1] | - | 1 | {x: 5, y: -8} | [tag1] | - | 2 | {x: -2, y: -8} | [tag1] | - | 2 | {x: 6, y: 2} | [tag1] | - | 3 | {x: -2, y: 5} | [tag1, tag2, tag3, tag4] | - | 3 | {x: -9, y: -7} | [tag1, tag2, tag3, tag4] | - | 4 | | [tag1, tag2, tag3] | - +----------+----------------+--------------------------+ - "###); + @r" + +----------+----------------+--------------------------+ + | shape_id | points | tags | + +----------+----------------+--------------------------+ + | 1 | {x: -3, y: -4} | [tag1] | + | 1 | {x: 5, y: -8} | [tag1] | + | 2 | {x: -2, y: -8} | [tag1] | + | 2 | {x: 6, y: 2} | [tag1] | + | 3 | {x: -2, y: 5} | [tag1, tag2, tag3, tag4] | + | 3 | {x: -9, y: -7} | [tag1, tag2, tag3, tag4] | + | 4 | | [tag1, tag2, tag3] | + +----------+----------------+--------------------------+ + "); // Test aggregate results for points. 
let df = table_with_nested_types(NUM_ROWS).await?; @@ -3695,27 +4214,27 @@ async fn unnest_columns() -> Result<()> { .await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" - +----------+----------------+------+ - | shape_id | points | tags | - +----------+----------------+------+ - | 1 | {x: -3, y: -4} | tag1 | - | 1 | {x: 5, y: -8} | tag1 | - | 2 | {x: -2, y: -8} | tag1 | - | 2 | {x: 6, y: 2} | tag1 | - | 3 | {x: -2, y: 5} | tag1 | - | 3 | {x: -2, y: 5} | tag2 | - | 3 | {x: -2, y: 5} | tag3 | - | 3 | {x: -2, y: 5} | tag4 | - | 3 | {x: -9, y: -7} | tag1 | - | 3 | {x: -9, y: -7} | tag2 | - | 3 | {x: -9, y: -7} | tag3 | - | 3 | {x: -9, y: -7} | tag4 | - | 4 | | tag1 | - | 4 | | tag2 | - | 4 | | tag3 | - +----------+----------------+------+ - "###); + @r" + +----------+----------------+------+ + | shape_id | points | tags | + +----------+----------------+------+ + | 1 | {x: -3, y: -4} | tag1 | + | 1 | {x: 5, y: -8} | tag1 | + | 2 | {x: -2, y: -8} | tag1 | + | 2 | {x: 6, y: 2} | tag1 | + | 3 | {x: -2, y: 5} | tag1 | + | 3 | {x: -2, y: 5} | tag2 | + | 3 | {x: -2, y: 5} | tag3 | + | 3 | {x: -2, y: 5} | tag4 | + | 3 | {x: -9, y: -7} | tag1 | + | 3 | {x: -9, y: -7} | tag2 | + | 3 | {x: -9, y: -7} | tag3 | + | 3 | {x: -9, y: -7} | tag4 | + | 4 | | tag1 | + | 4 | | tag2 | + | 4 | | tag3 | + +----------+----------------+------+ + "); // Test aggregate results for points and tags. 
let df = table_with_nested_types(NUM_ROWS).await?; @@ -3755,7 +4274,7 @@ async fn unnest_dict_encoded_columns() -> Result<()> { let results = df.collect().await.unwrap(); assert_snapshot!( batches_to_string(&results), - @r###" + @r" +-----------------+---------+ | make_array_expr | column1 | +-----------------+---------+ @@ -3763,7 +4282,7 @@ async fn unnest_dict_encoded_columns() -> Result<()> { | y | y | | z | z | +-----------------+---------+ - "### + " ); // make_array(dict_encoded_string,literal string) @@ -3783,7 +4302,7 @@ async fn unnest_dict_encoded_columns() -> Result<()> { let results = df.collect().await.unwrap(); assert_snapshot!( batches_to_string(&results), - @r###" + @r" +-----------------+---------+ | make_array_expr | column1 | +-----------------+---------+ @@ -3794,7 +4313,7 @@ async fn unnest_dict_encoded_columns() -> Result<()> { | z | z | | fixed_string | z | +-----------------+---------+ - "### + " ); Ok(()) } @@ -3805,7 +4324,7 @@ async fn unnest_column_nulls() -> Result<()> { let results = df.clone().collect().await?; assert_snapshot!( batches_to_string(&results), - @r###" + @r" +--------+----+ | list | id | +--------+----+ @@ -3814,7 +4333,7 @@ async fn unnest_column_nulls() -> Result<()> { | [] | C | | [3] | D | +--------+----+ - "### + " ); // Unnest, preserving nulls (row with B is preserved) @@ -3827,7 +4346,7 @@ async fn unnest_column_nulls() -> Result<()> { .await?; assert_snapshot!( batches_to_string(&results), - @r###" + @r" +------+----+ | list | id | +------+----+ @@ -3836,7 +4355,7 @@ async fn unnest_column_nulls() -> Result<()> { | | B | | 3 | D | +------+----+ - "### + " ); let options = UnnestOptions::new().with_preserve_nulls(false); @@ -3846,7 +4365,7 @@ async fn unnest_column_nulls() -> Result<()> { .await?; assert_snapshot!( batches_to_string(&results), - @r###" + @r" +------+----+ | list | id | +------+----+ @@ -3854,7 +4373,7 @@ async fn unnest_column_nulls() -> Result<()> { | 2 | A | | 3 | D | +------+----+ - "### + " 
); Ok(()) @@ -3871,7 +4390,7 @@ async fn unnest_fixed_list() -> Result<()> { let results = df.clone().collect().await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" +----------+----------------+ | shape_id | tags | +----------+----------------+ @@ -3882,7 +4401,7 @@ async fn unnest_fixed_list() -> Result<()> { | 5 | [tag51, tag52] | | 6 | [tag61, tag62] | +----------+----------------+ - "### + " ); let options = UnnestOptions::new().with_preserve_nulls(true); @@ -3893,7 +4412,7 @@ async fn unnest_fixed_list() -> Result<()> { .await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" +----------+-------+ | shape_id | tags | +----------+-------+ @@ -3908,7 +4427,7 @@ async fn unnest_fixed_list() -> Result<()> { | 6 | tag61 | | 6 | tag62 | +----------+-------+ - "### + " ); Ok(()) @@ -3925,7 +4444,7 @@ async fn unnest_fixed_list_drop_nulls() -> Result<()> { let results = df.clone().collect().await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" +----------+----------------+ | shape_id | tags | +----------+----------------+ @@ -3936,7 +4455,7 @@ async fn unnest_fixed_list_drop_nulls() -> Result<()> { | 5 | [tag51, tag52] | | 6 | [tag61, tag62] | +----------+----------------+ - "### + " ); let options = UnnestOptions::new().with_preserve_nulls(false); @@ -3947,7 +4466,7 @@ async fn unnest_fixed_list_drop_nulls() -> Result<()> { .await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" +----------+-------+ | shape_id | tags | +----------+-------+ @@ -3960,7 +4479,7 @@ async fn unnest_fixed_list_drop_nulls() -> Result<()> { | 6 | tag61 | | 6 | tag62 | +----------+-------+ - "### + " ); Ok(()) @@ -3996,7 +4515,7 @@ async fn unnest_fixed_list_non_null() -> Result<()> { let results = df.clone().collect().await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" +----------+----------------+ | shape_id | tags | +----------+----------------+ @@ -4007,7 +4526,7 @@ async fn 
unnest_fixed_list_non_null() -> Result<()> { | 5 | [tag51, tag52] | | 6 | [tag61, tag62] | +----------+----------------+ - "### + " ); let options = UnnestOptions::new().with_preserve_nulls(true); @@ -4017,7 +4536,7 @@ async fn unnest_fixed_list_non_null() -> Result<()> { .await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" +----------+-------+ | shape_id | tags | +----------+-------+ @@ -4034,7 +4553,7 @@ async fn unnest_fixed_list_non_null() -> Result<()> { | 6 | tag61 | | 6 | tag62 | +----------+-------+ - "### + " ); Ok(()) @@ -4048,17 +4567,17 @@ async fn unnest_aggregate_columns() -> Result<()> { let results = df.select_columns(&["tags"])?.collect().await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" - +--------------------------+ - | tags | - +--------------------------+ - | [tag1, tag2, tag3, tag4] | - | [tag1, tag2, tag3] | - | [tag1, tag2] | - | [tag1] | - | [tag1] | - +--------------------------+ - "### + @r" + +--------------------------+ + | tags | + +--------------------------+ + | [tag1, tag2, tag3, tag4] | + | [tag1, tag2, tag3] | + | [tag1, tag2] | + | [tag1] | + | [tag1] | + +--------------------------+ + " ); let df = table_with_nested_types(NUM_ROWS).await?; @@ -4069,13 +4588,13 @@ async fn unnest_aggregate_columns() -> Result<()> { .await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" +-------------+ | count(tags) | +-------------+ | 11 | +-------------+ - "### + " ); Ok(()) @@ -4148,7 +4667,7 @@ async fn unnest_array_agg() -> Result<()> { assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" +----------+--------+ | shape_id | tag_id | +----------+--------+ @@ -4162,7 +4681,7 @@ async fn unnest_array_agg() -> Result<()> { | 3 | 32 | | 3 | 33 | +----------+--------+ - "### + " ); // Doing an `array_agg` by `shape_id` produces: @@ -4176,7 +4695,7 @@ async fn unnest_array_agg() -> Result<()> { .await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" 
+----------+--------------+ | shape_id | tag_id | +----------+--------------+ @@ -4184,7 +4703,7 @@ async fn unnest_array_agg() -> Result<()> { | 2 | [21, 22, 23] | | 3 | [31, 32, 33] | +----------+--------------+ - "### + " ); // Unnesting again should produce the original batch. @@ -4200,7 +4719,7 @@ async fn unnest_array_agg() -> Result<()> { .await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" +----------+--------+ | shape_id | tag_id | +----------+--------+ @@ -4214,7 +4733,7 @@ async fn unnest_array_agg() -> Result<()> { | 3 | 32 | | 3 | 33 | +----------+--------+ - "### + " ); Ok(()) @@ -4244,7 +4763,7 @@ async fn unnest_with_redundant_columns() -> Result<()> { let results = df.clone().collect().await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" +----------+--------+ | shape_id | tag_id | +----------+--------+ @@ -4258,7 +4777,7 @@ async fn unnest_with_redundant_columns() -> Result<()> { | 3 | 32 | | 3 | 33 | +----------+--------+ - "### + " ); // Doing an `array_agg` by `shape_id` produces: @@ -4277,18 +4796,18 @@ async fn unnest_with_redundant_columns() -> Result<()> { let actual = formatted.trim(); assert_snapshot!( actual, - @r###" + @r" Projection: shapes.shape_id [shape_id:UInt32] Unnest: lists[shape_id2|depth=1] structs[] [shape_id:UInt32, shape_id2:UInt32;N] - Aggregate: groupBy=[[shapes.shape_id]], aggr=[[array_agg(shapes.shape_id) AS shape_id2]] [shape_id:UInt32, shape_id2:List(Field { name: "item", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} });N] + Aggregate: groupBy=[[shapes.shape_id]], aggr=[[array_agg(shapes.shape_id) AS shape_id2]] [shape_id:UInt32, shape_id2:List(UInt32);N] TableScan: shapes projection=[shape_id] [shape_id:UInt32] - "### + " ); let results = df.collect().await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" +----------+ | shape_id | +----------+ @@ -4302,7 +4821,7 @@ async fn unnest_with_redundant_columns() -> 
Result<()> { | 3 | | 3 | +----------+ - "### + " ); Ok(()) @@ -4343,7 +4862,7 @@ async fn unnest_multiple_columns() -> Result<()> { // string: a, b, c, d assert_snapshot!( batches_to_string(&results), - @r###" + @r" +------+------------+------------+--------+ | list | large_list | fixed_list | string | +------+------------+------------+--------+ @@ -4357,7 +4876,7 @@ async fn unnest_multiple_columns() -> Result<()> { | | | 4 | c | | | | | d | +------+------------+------------+--------+ - "### + " ); // Test with `preserve_nulls = false`` @@ -4374,7 +4893,7 @@ async fn unnest_multiple_columns() -> Result<()> { // string: a, b, c, d assert_snapshot!( batches_to_string(&results), - @r###" + @r" +------+------------+------------+--------+ | list | large_list | fixed_list | string | +------+------------+------------+--------+ @@ -4387,7 +4906,7 @@ async fn unnest_multiple_columns() -> Result<()> { | | | 3 | c | | | | 4 | c | +------+------------+------------+--------+ - "### + " ); Ok(()) @@ -4416,7 +4935,7 @@ async fn unnest_non_nullable_list() -> Result<()> { assert_snapshot!( batches_to_string(&results), - @r###" + @r" +----+ | c1 | +----+ @@ -4424,7 +4943,7 @@ async fn unnest_non_nullable_list() -> Result<()> { | 2 | | | +----+ - "### + " ); Ok(()) @@ -4469,7 +4988,7 @@ async fn test_read_batches() -> Result<()> { let results = df.collect().await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" +----+--------+ | id | number | +----+--------+ @@ -4482,7 +5001,7 @@ async fn test_read_batches() -> Result<()> { | 5 | 3.33 | | 5 | 6.66 | +----+--------+ - "### + " ); Ok(()) } @@ -4503,10 +5022,10 @@ async fn test_read_batches_empty() -> Result<()> { let results = df.collect().await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" ++ ++ - "### + " ); Ok(()) } @@ -4527,7 +5046,10 @@ async fn consecutive_projection_same_schema() -> Result<()> { // Add `t` column full of nulls let df = df - .with_column("t", 
cast(Expr::Literal(ScalarValue::Null), DataType::Int32)) + .with_column( + "t", + cast(Expr::Literal(ScalarValue::Null, None), DataType::Int32), + ) .unwrap(); df.clone().show().await.unwrap(); @@ -4552,14 +5074,14 @@ async fn consecutive_projection_same_schema() -> Result<()> { let results = df.collect().await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" +----+----+----+ | id | t | t2 | +----+----+----+ | 0 | | | | 1 | 10 | 10 | +----+----+----+ - "### + " ); Ok(()) @@ -4846,7 +5368,7 @@ async fn use_var_provider() -> Result<()> { Field::new("bar", DataType::Int64, false), ])); - let mem_table = Arc::new(MemTable::try_new(schema, vec![])?); + let mem_table = Arc::new(MemTable::try_new(schema, vec![vec![]])?); let config = SessionConfig::new() .with_target_partitions(4) @@ -4873,13 +5395,13 @@ async fn test_array_agg() -> Result<()> { assert_snapshot!( batches_to_string(&results), - @r###" + @r" +-------------------------------------+ | array_agg(test.a) | +-------------------------------------+ | [abcDEF, abc123, CBAdef, 123AbcDef] | +-------------------------------------+ - "### + " ); Ok(()) @@ -4904,11 +5426,11 @@ async fn test_dataframe_placeholder_missing_param_values() -> Result<()> { assert_snapshot!( actual, - @r###" + @r" Filter: a = $0 [a:Int32] Projection: Int32(1) AS a [a:Int32] - EmptyRelation [] - "### + EmptyRelation: rows=1 [] + " ); // Executing LogicalPlans with placeholders that don't have bound values @@ -4937,20 +5459,20 @@ async fn test_dataframe_placeholder_missing_param_values() -> Result<()> { assert_snapshot!( actual, - @r###" + @r" Filter: a = Int32(3) [a:Int32] Projection: Int32(1) AS a [a:Int32] - EmptyRelation [] - "### + EmptyRelation: rows=1 [] + " ); // N.B., the test is basically `SELECT 1 as a WHERE a = 3;` which returns no results. 
assert_snapshot!( batches_to_string(&df.collect().await.unwrap()), - @r###" + @r" ++ ++ - "### + " ); Ok(()) @@ -4968,10 +5490,10 @@ async fn test_dataframe_placeholder_column_parameter() -> Result<()> { assert_snapshot!( actual, - @r###" + @r" Projection: $1 [$1:Null;N] - EmptyRelation [] - "### + EmptyRelation: rows=1 [] + " ); // Executing LogicalPlans with placeholders that don't have bound values @@ -4998,21 +5520,21 @@ async fn test_dataframe_placeholder_column_parameter() -> Result<()> { assert_snapshot!( actual, - @r###" - Projection: Int32(3) AS $1 [$1:Null;N] - EmptyRelation [] - "### + @r" + Projection: Int32(3) AS $1 [$1:Int32] + EmptyRelation: rows=1 [] + " ); assert_snapshot!( batches_to_string(&df.collect().await.unwrap()), - @r###" + @r" +----+ | $1 | +----+ | 3 | +----+ - "### + " ); Ok(()) @@ -5037,11 +5559,11 @@ async fn test_dataframe_placeholder_like_expression() -> Result<()> { assert_snapshot!( actual, - @r###" + @r#" Filter: a LIKE $1 [a:Utf8] Projection: Utf8("foo") AS a [a:Utf8] - EmptyRelation [] - "### + EmptyRelation: rows=1 [] + "# ); // Executing LogicalPlans with placeholders that don't have bound values @@ -5070,51 +5592,54 @@ async fn test_dataframe_placeholder_like_expression() -> Result<()> { assert_snapshot!( actual, - @r###" + @r#" Filter: a LIKE Utf8("f%") [a:Utf8] Projection: Utf8("foo") AS a [a:Utf8] - EmptyRelation [] - "### + EmptyRelation: rows=1 [] + "# ); assert_snapshot!( batches_to_string(&df.collect().await.unwrap()), - @r###" + @r" +-----+ | a | +-----+ | foo | +-----+ - "### + " ); Ok(()) } +#[rstest] +#[case(DataType::Utf8)] +#[case(DataType::LargeUtf8)] +#[case(DataType::Utf8View)] #[tokio::test] -async fn write_partitioned_parquet_results() -> Result<()> { - // create partitioned input file and context - let tmp_dir = TempDir::new()?; - - let ctx = SessionContext::new(); - +async fn write_partitioned_parquet_results(#[case] string_type: DataType) -> Result<()> { // Create an in memory table with schema C1 and 
C2, both strings let schema = Arc::new(Schema::new(vec![ - Field::new("c1", DataType::Utf8, false), - Field::new("c2", DataType::Utf8, false), + Field::new("c1", string_type.clone(), false), + Field::new("c2", string_type.clone(), false), ])); - let record_batch = RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(StringArray::from(vec!["abc", "def"])), - Arc::new(StringArray::from(vec!["123", "456"])), - ], - )?; + let columns = [ + Arc::new(StringArray::from(vec!["abc", "def"])) as ArrayRef, + Arc::new(StringArray::from(vec!["123", "456"])) as ArrayRef, + ] + .map(|col| arrow::compute::cast(&col, &string_type).unwrap()) + .to_vec(); + + let record_batch = RecordBatch::try_new(schema.clone(), columns)?; let mem_table = Arc::new(MemTable::try_new(schema, vec![vec![record_batch]])?); // Register the table in the context + // create partitioned input file and context + let tmp_dir = TempDir::new()?; + let ctx = SessionContext::new(); ctx.register_table("test", mem_table)?; let local = Arc::new(LocalFileSystem::new_with_prefix(&tmp_dir)?); @@ -5141,16 +5666,17 @@ async fn write_partitioned_parquet_results() -> Result<()> { // Check that the c2 column is gone and that c1 is abc. let results = filter_df.collect().await?; + insta::allow_duplicates! { assert_snapshot!( batches_to_string(&results), - @r###" + @r" +-----+ | c1 | +-----+ | abc | +-----+ - "### - ); + " + )}; // Read the entire set of parquet files let df = ctx @@ -5163,17 +5689,19 @@ async fn write_partitioned_parquet_results() -> Result<()> { // Check that the df has the entire set of data let results = df.collect().await?; - assert_snapshot!( - batches_to_sort_string(&results), - @r###" + insta::allow_duplicates! 
{ + assert_snapshot!( + batches_to_sort_string(&results), + @r" +-----+-----+ | c1 | c2 | +-----+-----+ | abc | 123 | | def | 456 | +-----+-----+ - "### - ); + " + ) + }; Ok(()) } @@ -5284,11 +5812,11 @@ async fn union_literal_is_null_and_not_null() -> Result<()> { for batch in batches { // Verify schema is the same for all batches if !schema.contains(&batch.schema()) { - return Err(DataFusionError::Internal(format!( + return Err(internal_datafusion_err!( "Schema mismatch. Previously had\n{:#?}\n\nGot:\n{:#?}", &schema, batch.schema() - ))); + )); } } @@ -5329,7 +5857,7 @@ async fn sparse_union_is_null() { // view_all assert_snapshot!( batches_to_sort_string(&df.clone().collect().await.unwrap()), - @r###" + @r" +----------+ | my_union | +----------+ @@ -5340,14 +5868,14 @@ async fn sparse_union_is_null() { | {C=a} | | {C=} | +----------+ - "### + " ); // filter where is null let result_df = df.clone().filter(col("my_union").is_null()).unwrap(); assert_snapshot!( batches_to_sort_string(&result_df.collect().await.unwrap()), - @r###" + @r" +----------+ | my_union | +----------+ @@ -5355,14 +5883,14 @@ async fn sparse_union_is_null() { | {B=} | | {C=} | +----------+ - "### + " ); // filter where is not null let result_df = df.filter(col("my_union").is_not_null()).unwrap(); assert_snapshot!( batches_to_sort_string(&result_df.collect().await.unwrap()), - @r###" + @r" +----------+ | my_union | +----------+ @@ -5370,7 +5898,7 @@ async fn sparse_union_is_null() { | {B=3.2} | | {C=a} | +----------+ - "### + " ); } @@ -5412,7 +5940,7 @@ async fn dense_union_is_null() { // view_all assert_snapshot!( batches_to_sort_string(&df.clone().collect().await.unwrap()), - @r###" + @r" +----------+ | my_union | +----------+ @@ -5423,14 +5951,14 @@ async fn dense_union_is_null() { | {C=a} | | {C=} | +----------+ - "### + " ); // filter where is null let result_df = df.clone().filter(col("my_union").is_null()).unwrap(); assert_snapshot!( 
batches_to_sort_string(&result_df.collect().await.unwrap()), - @r###" + @r" +----------+ | my_union | +----------+ @@ -5438,14 +5966,14 @@ async fn dense_union_is_null() { | {B=} | | {C=} | +----------+ - "### + " ); // filter where is not null let result_df = df.filter(col("my_union").is_not_null()).unwrap(); assert_snapshot!( batches_to_sort_string(&result_df.collect().await.unwrap()), - @r###" + @r" +----------+ | my_union | +----------+ @@ -5453,7 +5981,7 @@ async fn dense_union_is_null() { | {B=3.2} | | {C=a} | +----------+ - "### + " ); } @@ -5485,7 +6013,7 @@ async fn boolean_dictionary_as_filter() { // view_all assert_snapshot!( batches_to_string(&df.clone().collect().await.unwrap()), - @r###" + @r" +---------+ | my_dict | +---------+ @@ -5497,14 +6025,14 @@ async fn boolean_dictionary_as_filter() { | true | | false | +---------+ - "### + " ); let result_df = df.clone().filter(col("my_dict")).unwrap(); assert_snapshot!( batches_to_string(&result_df.collect().await.unwrap()), - @r###" + @r" +---------+ | my_dict | +---------+ @@ -5512,7 +6040,7 @@ async fn boolean_dictionary_as_filter() { | true | | true | +---------+ - "### + " ); // test nested dictionary @@ -5543,26 +6071,26 @@ async fn boolean_dictionary_as_filter() { // view_all assert_snapshot!( batches_to_string(&df.clone().collect().await.unwrap()), - @r###" + @r" +----------------+ | my_nested_dict | +----------------+ | true | | false | +----------------+ - "### + " ); let result_df = df.clone().filter(col("my_nested_dict")).unwrap(); assert_snapshot!( batches_to_string(&result_df.collect().await.unwrap()), - @r###" + @r" +----------------+ | my_nested_dict | +----------------+ | true | +----------------+ - "### + " ); } @@ -5630,7 +6158,7 @@ async fn test_alias() -> Result<()> { .await? .select(vec![col("a"), col("test.b"), lit(1).alias("one")])? 
.alias("table_alias")?; - // All ouput column qualifiers are changed to "table_alias" + // All output column qualifiers are changed to "table_alias" df.schema().columns().iter().for_each(|c| { assert_eq!(c.relation, Some("table_alias".into())); }); @@ -5640,11 +6168,11 @@ async fn test_alias() -> Result<()> { .into_unoptimized_plan() .display_indent_schema() .to_string(); - assert_snapshot!(plan, @r###" + assert_snapshot!(plan, @r" SubqueryAlias: table_alias [a:Utf8, b:Int32, one:Int32] Projection: test.a, test.b, Int32(1) AS one [a:Utf8, b:Int32, one:Int32] TableScan: test [a:Utf8, b:Int32] - "###); + "); // Select over the aliased DataFrame let df = df.select(vec![ @@ -5653,7 +6181,7 @@ async fn test_alias() -> Result<()> { ])?; assert_snapshot!( batches_to_sort_string(&df.collect().await.unwrap()), - @r###" + @r" +-----------+---------------------------------+ | a | table_alias.b + table_alias.one | +-----------+---------------------------------+ @@ -5662,7 +6190,7 @@ async fn test_alias() -> Result<()> { | abc123 | 11 | | abcDEF | 2 | +-----------+---------------------------------+ - "### + " ); Ok(()) } @@ -5671,6 +6199,7 @@ async fn test_alias() -> Result<()> { async fn test_alias_with_metadata() -> Result<()> { let mut metadata = HashMap::new(); metadata.insert(String::from("k"), String::from("v")); + let metadata = FieldMetadata::from(metadata); let df = create_test_table("test") .await? .select(vec![col("a").alias_with_metadata("b", Some(metadata))])? 
@@ -5691,7 +6220,7 @@ async fn test_alias_self_join() -> Result<()> { let joined = left.join(right, JoinType::Full, &["a"], &["a"], None)?; assert_snapshot!( batches_to_sort_string(&joined.collect().await.unwrap()), - @r###" + @r" +-----------+-----+-----------+-----+ | a | b | a | b | +-----------+-----+-----------+-----+ @@ -5700,7 +6229,7 @@ async fn test_alias_self_join() -> Result<()> { | abc123 | 10 | abc123 | 10 | | abcDEF | 1 | abcDEF | 1 | +-----------+-----+-----------+-----+ - "### + " ); Ok(()) } @@ -5713,14 +6242,14 @@ async fn test_alias_empty() -> Result<()> { .into_unoptimized_plan() .display_indent_schema() .to_string(); - assert_snapshot!(plan, @r###" + assert_snapshot!(plan, @r" SubqueryAlias: [a:Utf8, b:Int32] TableScan: test [a:Utf8, b:Int32] - "###); + "); assert_snapshot!( batches_to_sort_string(&df.select(vec![col("a"), col("b")])?.collect().await.unwrap()), - @r###" + @r" +-----------+-----+ | a | b | +-----------+-----+ @@ -5729,7 +6258,7 @@ async fn test_alias_empty() -> Result<()> { | abc123 | 10 | | abcDEF | 1 | +-----------+-----+ - "### + " ); Ok(()) @@ -5748,12 +6277,12 @@ async fn test_alias_nested() -> Result<()> { .into_optimized_plan()? 
.display_indent_schema() .to_string(); - assert_snapshot!(plan, @r###" + assert_snapshot!(plan, @r" SubqueryAlias: alias2 [a:Utf8, b:Int32, one:Int32] SubqueryAlias: alias1 [a:Utf8, b:Int32, one:Int32] Projection: test.a, test.b, Int32(1) AS one [a:Utf8, b:Int32, one:Int32] TableScan: test projection=[a, b] [a:Utf8, b:Int32] - "###); + "); // Select over the aliased DataFrame let select1 = df @@ -5762,7 +6291,7 @@ async fn test_alias_nested() -> Result<()> { assert_snapshot!( batches_to_sort_string(&select1.collect().await.unwrap()), - @r###" + @r" +-----------+-----------------------+ | a | alias2.b + alias2.one | +-----------+-----------------------+ @@ -5771,7 +6300,7 @@ async fn test_alias_nested() -> Result<()> { | abc123 | 11 | | abcDEF | 2 | +-----------+-----------------------+ - "### + " ); // Only the outermost alias is visible @@ -5790,7 +6319,7 @@ async fn register_non_json_file() { .register_json( "data", "tests/data/test_binary.parquet", - NdJsonReadOptions::default(), + JsonReadOptions::default(), ) .await; assert_contains!( @@ -5891,7 +6420,10 @@ async fn test_insert_into_checking() -> Result<()> { .await .unwrap_err(); - assert_contains!(e.to_string(), "Inserting query schema mismatch: Expected table field 'a' with type Int64, but got 'column1' with type Utf8"); + assert_contains!( + e.to_string(), + "Inserting query schema mismatch: Expected table field 'a' with type Int64, but got 'column1' with type Utf8" + ); Ok(()) } @@ -5938,7 +6470,7 @@ async fn test_fill_null() -> Result<()> { let results = df_filled.collect().await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" +---+---------+ | a | b | +---+---------+ @@ -5946,7 +6478,7 @@ async fn test_fill_null() -> Result<()> { | 1 | x | | 3 | z | +---+---------+ - "### + " ); Ok(()) @@ -5966,7 +6498,7 @@ async fn test_fill_null_all_columns() -> Result<()> { assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" +---+---------+ | a | b | +---+---------+ @@ -5974,7 
+6506,7 @@ async fn test_fill_null_all_columns() -> Result<()> { | 1 | x | | 3 | z | +---+---------+ - "### + " ); // Fill column "a" null values with a value that cannot be cast to Int32. @@ -5983,7 +6515,7 @@ async fn test_fill_null_all_columns() -> Result<()> { let results = df_filled.collect().await?; assert_snapshot!( batches_to_sort_string(&results), - @r###" + @r" +---+---------+ | a | b | +---+---------+ @@ -5991,7 +6523,7 @@ async fn test_fill_null_all_columns() -> Result<()> { | 1 | x | | 3 | z | +---+---------+ - "### + " ); Ok(()) } @@ -6000,7 +6532,7 @@ async fn test_fill_null_all_columns() -> Result<()> { async fn test_insert_into_casting_support() -> Result<()> { // Testing case1: // Inserting query schema mismatch: Expected table field 'a' with type Float16, but got 'a' with type Utf8. - // And the cast is not supported from Utf8 to Float16. + // And the cast is not supported from Binary to Float16. // Create a new schema with one field called "a" of type Float16, and setting nullable to false let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float16, false)])); @@ -6011,7 +6543,10 @@ async fn test_insert_into_casting_support() -> Result<()> { let initial_table = Arc::new(MemTable::try_new(schema.clone(), vec![vec![]])?); session_ctx.register_table("t", initial_table.clone())?; - let mut write_df = session_ctx.sql("values ('a123'), ('b456')").await.unwrap(); + let mut write_df = session_ctx + .sql("values (x'a123'), (x'b456')") + .await + .unwrap(); write_df = write_df .clone() @@ -6023,7 +6558,10 @@ async fn test_insert_into_casting_support() -> Result<()> { .await .unwrap_err(); - assert_contains!(e.to_string(), "Inserting query schema mismatch: Expected table field 'a' with type Float16, but got 'a' with type Utf8."); + assert_contains!( + e.to_string(), + "Inserting query schema mismatch: Expected table field 'a' with type Float16, but got 'a' with type Binary." 
+ ); // Testing case2: // Inserting query schema mismatch: Expected table field 'a' with type Utf8View, but got 'a' with type Utf8. @@ -6061,14 +6599,14 @@ async fn test_insert_into_casting_support() -> Result<()> { assert_snapshot!( batches_to_string(&res), - @r###" + @r" +------+ | a | +------+ | a123 | | b456 | +------+ - "### + " ); Ok(()) } @@ -6131,3 +6669,188 @@ async fn test_dataframe_macro() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn test_copy_schema() -> Result<()> { + let tmp_dir = TempDir::new()?; + + let session_state = SessionStateBuilder::new_with_default_features().build(); + + let session_ctx = SessionContext::new_with_state(session_state); + + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, true)])); + + // Create and register the source table with the provided schema and data + let source_table = Arc::new(MemTable::try_new(schema.clone(), vec![vec![]])?); + session_ctx.register_table("source_table", source_table.clone())?; + + let target_path = tmp_dir.path().join("target.csv"); + + let query = format!( + "COPY source_table TO '{}' STORED AS csv", + target_path.to_str().unwrap() + ); + + let result = session_ctx.sql(&query).await?; + assert_logical_expr_schema_eq_physical_expr_schema(result).await?; + Ok(()) +} + +#[tokio::test] +async fn test_copy_to_preserves_order() -> Result<()> { + let tmp_dir = TempDir::new()?; + + let session_state = SessionStateBuilder::new_with_default_features().build(); + let session_ctx = SessionContext::new_with_state(session_state); + + let target_path = tmp_dir.path().join("target_ordered.csv"); + let csv_file_format = session_ctx + .state() + .get_file_format_factory("csv") + .map(format_as_file_type) + .unwrap(); + + let ordered_select_plan = LogicalPlanBuilder::values(vec![ + vec![lit(1u64)], + vec![lit(10u64)], + vec![lit(20u64)], + vec![lit(100u64)], + ])? + .sort(vec![SortExpr::new(col("column1"), false, true)])? 
+ .build()?; + + let copy_to_plan = LogicalPlanBuilder::copy_to( + ordered_select_plan, + target_path.to_str().unwrap().to_string(), + csv_file_format, + HashMap::new(), + vec![], + )? + .build()?; + + let union_side_branch = LogicalPlanBuilder::values(vec![vec![lit(1u64)]])?.build()?; + let union_plan = LogicalPlanBuilder::from(copy_to_plan) + .union(union_side_branch)? + .build()?; + + let frame = session_ctx.execute_logical_plan(union_plan).await?; + let physical_plan = frame.create_physical_plan().await?; + + let physical_plan_format = + displayable(physical_plan.as_ref()).indent(true).to_string(); + + // Expect that input to the DataSinkExec is sorted correctly + assert_snapshot!( + physical_plan_format, + @r" + UnionExec + DataSinkExec: sink=CsvSink(file_groups=[]) + SortExec: expr=[column1@0 DESC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[1] + DataSourceExec: partitions=1, partition_sizes=[1] + " + ); + Ok(()) +} + +#[tokio::test] +async fn test_duplicate_state_fields_for_dfschema_construct() -> Result<()> { + let ctx = SessionContext::new(); + + // Simple schema with just the fields we need + let file_schema = Arc::new(Schema::new(vec![ + Field::new( + "timestamp", + DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())), + true, + ), + Field::new("ticker", DataType::Utf8, true), + Field::new("value", DataType::Float64, true), + Field::new("date", DataType::Utf8, false), + ])); + + let df_schema = DFSchema::try_from(file_schema.clone())?; + + let timestamp = col("timestamp"); + let value = col("value"); + let ticker = col("ticker"); + let date = col("date"); + + let mock_exec = Arc::new(EmptyExec::new(file_schema.clone())); + + // Build first_value aggregate + let first_value = Arc::new( + AggregateExprBuilder::new( + datafusion_functions_aggregate::first_last::first_value_udaf(), + vec![ctx.create_physical_expr(value.clone(), &df_schema)?], + ) + .alias("first_value(value)") + 
.order_by(vec![PhysicalSortExpr::new( + ctx.create_physical_expr(timestamp.clone(), &df_schema)?, + SortOptions::new(false, false), + )]) + .schema(file_schema.clone()) + .build() + .expect("Failed to build first_value"), + ); + + // Build last_value aggregate + let last_value = Arc::new( + AggregateExprBuilder::new( + datafusion_functions_aggregate::first_last::last_value_udaf(), + vec![ctx.create_physical_expr(value.clone(), &df_schema)?], + ) + .alias("last_value(value)") + .order_by(vec![PhysicalSortExpr::new( + ctx.create_physical_expr(timestamp.clone(), &df_schema)?, + SortOptions::new(false, false), + )]) + .schema(file_schema.clone()) + .build() + .expect("Failed to build last_value"), + ); + + let partial_agg = AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::new_single(vec![ + ( + ctx.create_physical_expr(date.clone(), &df_schema)?, + "date".to_string(), + ), + ( + ctx.create_physical_expr(ticker.clone(), &df_schema)?, + "ticker".to_string(), + ), + ]), + vec![first_value, last_value], + vec![None, None], + mock_exec, + file_schema, + ) + .expect("Failed to build partial agg"); + + // Assert that the schema field names match the expected names + let expected_field_names = vec![ + "date", + "ticker", + "first_value(value)[first_value]", + "timestamp@0", + "first_value(value)[first_value_is_set]", + "last_value(value)[last_value]", + "timestamp@0", + "last_value(value)[last_value_is_set]", + ]; + + let binding = partial_agg.schema(); + let actual_field_names: Vec<_> = binding.fields().iter().map(|f| f.name()).collect(); + assert_eq!(actual_field_names, expected_field_names); + + // Ensure that DFSchema::try_from does not fail + let partial_agg_exec_schema = DFSchema::try_from(partial_agg.schema()); + assert!( + partial_agg_exec_schema.is_ok(), + "Expected get AggregateExec schema to succeed with duplicate state fields" + ); + + Ok(()) +} diff --git a/datafusion/core/tests/datasource/csv.rs b/datafusion/core/tests/datasource/csv.rs new 
file mode 100644 index 0000000000000..2e1daa113b096 --- /dev/null +++ b/datafusion/core/tests/datasource/csv.rs @@ -0,0 +1,122 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Test for CSV schema inference with different column counts (GitHub issue #17516) + +use datafusion::error::Result; +use datafusion::prelude::*; +use datafusion_common::test_util::batches_to_sort_string; +use insta::assert_snapshot; +use std::fs; +use tempfile::TempDir; + +#[tokio::test] +async fn test_csv_schema_inference_different_column_counts() -> Result<()> { + // Create temporary directory for test files + let temp_dir = TempDir::new().expect("Failed to create temp dir"); + let temp_path = temp_dir.path(); + + // Create CSV file 1 with 3 columns (simulating older railway services format) + let csv1_content = r#"service_id,route_type,agency_id +1,bus,agency1 +2,rail,agency2 +3,bus,agency3 +"#; + fs::write(temp_path.join("services_2024.csv"), csv1_content)?; + + // Create CSV file 2 with 6 columns (simulating newer railway services format) + let csv2_content = r#"service_id,route_type,agency_id,stop_platform_change,stop_planned_platform,stop_actual_platform +4,rail,agency2,true,Platform A,Platform B +5,bus,agency1,false,Stop 1,Stop 
1 +6,rail,agency3,true,Platform C,Platform D +"#; + fs::write(temp_path.join("services_2025.csv"), csv2_content)?; + + // Create DataFusion context + let ctx = SessionContext::new(); + + // This should now work (previously would have failed with column count mismatch) + // Enable truncated_rows to handle files with different column counts + let df = ctx + .read_csv( + temp_path.to_str().unwrap(), + CsvReadOptions::new().truncated_rows(true), + ) + .await + .expect("Should successfully read CSV directory with different column counts"); + + // Verify the schema contains all 6 columns (union of both files) + let df_clone = df.clone(); + let schema = df_clone.schema(); + assert_eq!( + schema.fields().len(), + 6, + "Schema should contain all 6 columns" + ); + + // Check that we have all expected columns + let field_names: Vec<&str> = + schema.fields().iter().map(|f| f.name().as_str()).collect(); + assert!(field_names.contains(&"service_id")); + assert!(field_names.contains(&"route_type")); + assert!(field_names.contains(&"agency_id")); + assert!(field_names.contains(&"stop_platform_change")); + assert!(field_names.contains(&"stop_planned_platform")); + assert!(field_names.contains(&"stop_actual_platform")); + + // All fields should be nullable since they don't appear in all files + for field in schema.fields() { + assert!( + field.is_nullable(), + "Field {} should be nullable", + field.name() + ); + } + + // Verify we can actually read the data + let results = df.collect().await?; + + // Calculate total rows across all batches + let total_rows: usize = results.iter().map(|batch| batch.num_rows()).sum(); + assert_eq!(total_rows, 6, "Should have 6 total rows across all batches"); + + // All batches should have 6 columns (the union schema) + for batch in &results { + assert_eq!(batch.num_columns(), 6, "All batches should have 6 columns"); + assert_eq!( + batch.schema().fields().len(), + 6, + "Each batch should use the union schema with 6 fields" + ); + } + + // Verify the 
actual content of the data using snapshot testing + assert_snapshot!(batches_to_sort_string(&results), @r" + +------------+------------+-----------+----------------------+-----------------------+----------------------+ + | service_id | route_type | agency_id | stop_platform_change | stop_planned_platform | stop_actual_platform | + +------------+------------+-----------+----------------------+-----------------------+----------------------+ + | 1 | bus | agency1 | | | | + | 2 | rail | agency2 | | | | + | 3 | bus | agency3 | | | | + | 4 | rail | agency2 | true | Platform A | Platform B | + | 5 | bus | agency1 | false | Stop 1 | Stop 1 | + | 6 | rail | agency3 | true | Platform C | Platform D | + +------------+------------+-----------+----------------------+-----------------------+----------------------+ + "); + + Ok(()) +} diff --git a/datafusion/core/tests/datasource/mod.rs b/datafusion/core/tests/datasource/mod.rs new file mode 100644 index 0000000000000..3785aa0766182 --- /dev/null +++ b/datafusion/core/tests/datasource/mod.rs @@ -0,0 +1,24 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Tests for various DataSources +//! +//! 
Note tests for the Parquet format are in `parquet_integration` binary + +// Include tests in csv module +mod csv; +mod object_store_access; diff --git a/datafusion/core/tests/datasource/object_store_access.rs b/datafusion/core/tests/datasource/object_store_access.rs new file mode 100644 index 0000000000000..30654c687f8d2 --- /dev/null +++ b/datafusion/core/tests/datasource/object_store_access.rs @@ -0,0 +1,971 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Tests for object store access patterns with [`ListingTable`] +//! +//! These tests set up a `ListingTable` backed by an in-memory object store +//! that counts the number of requests made against it and then do +//! various operations (table creation, queries with and without predicates) +//! to verify the expected object store access patterns. +//! +//!
[`ListingTable`]: datafusion::datasource::listing::ListingTable + +use arrow::array::{ArrayRef, Int32Array, RecordBatch}; +use async_trait::async_trait; +use bytes::Bytes; +use datafusion::prelude::{CsvReadOptions, ParquetReadOptions, SessionContext}; +use datafusion_catalog_listing::{ListingOptions, ListingTable, ListingTableConfig}; +use datafusion_datasource::ListingTableUrl; +use datafusion_datasource_csv::CsvFormat; +use futures::stream::BoxStream; +use insta::assert_snapshot; +use object_store::memory::InMemory; +use object_store::path::Path; +use object_store::{ + CopyOptions, GetOptions, GetRange, GetResult, ListResult, MultipartUpload, + ObjectMeta, ObjectStore, ObjectStoreExt, PutMultipartOptions, PutOptions, PutPayload, + PutResult, +}; +use parking_lot::Mutex; +use std::fmt; +use std::fmt::{Display, Formatter}; +use std::ops::Range; +use std::sync::Arc; +use url::Url; + +#[tokio::test] +async fn create_single_csv_file() { + let test = Test::new().with_single_file_csv().await; + assert_snapshot!( + test.requests(), + @r" + RequestCountingObjectStore() + Total Requests: 2 + - GET (opts) path=csv_table.csv head=true + - GET (opts) path=csv_table.csv + " + ); +} + +#[tokio::test] +async fn query_single_csv_file() { + let test = Test::new().with_single_file_csv().await; + assert_snapshot!( + test.query("select * from csv_table").await, + @r" + ------- Query Output (2 rows) ------- + +---------+-------+-------+ + | c1 | c2 | c3 | + +---------+-------+-------+ + | 0.00001 | 5e-12 | true | + | 0.00002 | 4e-12 | false | + +---------+-------+-------+ + ------- Object Store Request Summary ------- + RequestCountingObjectStore() + Total Requests: 2 + - GET (opts) path=csv_table.csv head=true + - GET (opts) path=csv_table.csv + " + ); +} + +#[tokio::test] +async fn create_multi_file_csv_file() { + let test = Test::new().with_multi_file_csv().await; + assert_snapshot!( + test.requests(), + @r" + RequestCountingObjectStore() + Total Requests: 4 + - LIST prefix=data + 
- GET (opts) path=data/file_0.csv + - GET (opts) path=data/file_1.csv + - GET (opts) path=data/file_2.csv + " + ); +} + +#[tokio::test] +async fn multi_query_multi_file_csv_file() { + let test = Test::new().with_multi_file_csv().await; + assert_snapshot!( + test.query("select * from csv_table").await, + @r" + ------- Query Output (6 rows) ------- + +---------+-------+-------+ + | c1 | c2 | c3 | + +---------+-------+-------+ + | 0.0 | 0.0 | true | + | 0.00003 | 5e-12 | false | + | 0.00001 | 1e-12 | true | + | 0.00003 | 5e-12 | false | + | 0.00002 | 2e-12 | true | + | 0.00003 | 5e-12 | false | + +---------+-------+-------+ + ------- Object Store Request Summary ------- + RequestCountingObjectStore() + Total Requests: 3 + - GET (opts) path=data/file_0.csv + - GET (opts) path=data/file_1.csv + - GET (opts) path=data/file_2.csv + " + ); + + // Force a cache eviction by removing the data limit for the cache + assert_snapshot!( + test.query("set datafusion.runtime.list_files_cache_limit=\"0K\"").await, + @r" + ------- Query Output (0 rows) ------- + ++ + ++ + ------- Object Store Request Summary ------- + RequestCountingObjectStore() + Total Requests: 0 + " + ); + + // Then re-enable the cache + assert_snapshot!( + test.query("set datafusion.runtime.list_files_cache_limit=\"1M\"").await, + @r" + ------- Query Output (0 rows) ------- + ++ + ++ + ------- Object Store Request Summary ------- + RequestCountingObjectStore() + Total Requests: 0 + " + ); + + // this query should list the table since the cache entries were evicted + assert_snapshot!( + test.query("select * from csv_table").await, + @r" + ------- Query Output (6 rows) ------- + +---------+-------+-------+ + | c1 | c2 | c3 | + +---------+-------+-------+ + | 0.0 | 0.0 | true | + | 0.00003 | 5e-12 | false | + | 0.00001 | 1e-12 | true | + | 0.00003 | 5e-12 | false | + | 0.00002 | 2e-12 | true | + | 0.00003 | 5e-12 | false | + +---------+-------+-------+ + ------- Object Store Request Summary ------- + 
RequestCountingObjectStore() + Total Requests: 4 + - LIST prefix=data + - GET (opts) path=data/file_0.csv + - GET (opts) path=data/file_1.csv + - GET (opts) path=data/file_2.csv + " + ); + + // this query should not list the table since the entries were added in the previous query + assert_snapshot!( + test.query("select * from csv_table").await, + @r" + ------- Query Output (6 rows) ------- + +---------+-------+-------+ + | c1 | c2 | c3 | + +---------+-------+-------+ + | 0.0 | 0.0 | true | + | 0.00003 | 5e-12 | false | + | 0.00001 | 1e-12 | true | + | 0.00003 | 5e-12 | false | + | 0.00002 | 2e-12 | true | + | 0.00003 | 5e-12 | false | + +---------+-------+-------+ + ------- Object Store Request Summary ------- + RequestCountingObjectStore() + Total Requests: 3 + - GET (opts) path=data/file_0.csv + - GET (opts) path=data/file_1.csv + - GET (opts) path=data/file_2.csv + " + ); +} + +#[tokio::test] +async fn query_multi_csv_file() { + let test = Test::new().with_multi_file_csv().await; + assert_snapshot!( + test.query("select * from csv_table").await, + @r" + ------- Query Output (6 rows) ------- + +---------+-------+-------+ + | c1 | c2 | c3 | + +---------+-------+-------+ + | 0.0 | 0.0 | true | + | 0.00003 | 5e-12 | false | + | 0.00001 | 1e-12 | true | + | 0.00003 | 5e-12 | false | + | 0.00002 | 2e-12 | true | + | 0.00003 | 5e-12 | false | + +---------+-------+-------+ + ------- Object Store Request Summary ------- + RequestCountingObjectStore() + Total Requests: 3 + - GET (opts) path=data/file_0.csv + - GET (opts) path=data/file_1.csv + - GET (opts) path=data/file_2.csv + " + ); +} + +#[tokio::test] +async fn query_partitioned_csv_file() { + let test = Test::new().with_partitioned_csv().await; + assert_snapshot!( + test.query("select * from csv_table_partitioned").await, + @r" + ------- Query Output (6 rows) ------- + +---------+-------+-------+---+----+-----+ + | d1 | d2 | d3 | a | b | c | + +---------+-------+-------+---+----+-----+ + | 0.00001 | 1e-12 | true | 
1 | 10 | 100 | + | 0.00003 | 5e-12 | false | 1 | 10 | 100 | + | 0.00002 | 2e-12 | true | 2 | 20 | 200 | + | 0.00003 | 5e-12 | false | 2 | 20 | 200 | + | 0.00003 | 3e-12 | true | 3 | 30 | 300 | + | 0.00003 | 5e-12 | false | 3 | 30 | 300 | + +---------+-------+-------+---+----+-----+ + ------- Object Store Request Summary ------- + RequestCountingObjectStore() + Total Requests: 3 + - GET (opts) path=data/a=1/b=10/c=100/file_1.csv + - GET (opts) path=data/a=2/b=20/c=200/file_2.csv + - GET (opts) path=data/a=3/b=30/c=300/file_3.csv + " + ); + + assert_snapshot!( + test.query("select * from csv_table_partitioned WHERE a=2").await, + @r" + ------- Query Output (2 rows) ------- + +---------+-------+-------+---+----+-----+ + | d1 | d2 | d3 | a | b | c | + +---------+-------+-------+---+----+-----+ + | 0.00002 | 2e-12 | true | 2 | 20 | 200 | + | 0.00003 | 5e-12 | false | 2 | 20 | 200 | + +---------+-------+-------+---+----+-----+ + ------- Object Store Request Summary ------- + RequestCountingObjectStore() + Total Requests: 1 + - GET (opts) path=data/a=2/b=20/c=200/file_2.csv + " + ); + + assert_snapshot!( + test.query("select * from csv_table_partitioned WHERE b=20").await, + @r" + ------- Query Output (2 rows) ------- + +---------+-------+-------+---+----+-----+ + | d1 | d2 | d3 | a | b | c | + +---------+-------+-------+---+----+-----+ + | 0.00002 | 2e-12 | true | 2 | 20 | 200 | + | 0.00003 | 5e-12 | false | 2 | 20 | 200 | + +---------+-------+-------+---+----+-----+ + ------- Object Store Request Summary ------- + RequestCountingObjectStore() + Total Requests: 1 + - GET (opts) path=data/a=2/b=20/c=200/file_2.csv + " + ); + + assert_snapshot!( + test.query("select * from csv_table_partitioned WHERE c=200").await, + @r" + ------- Query Output (2 rows) ------- + +---------+-------+-------+---+----+-----+ + | d1 | d2 | d3 | a | b | c | + +---------+-------+-------+---+----+-----+ + | 0.00002 | 2e-12 | true | 2 | 20 | 200 | + | 0.00003 | 5e-12 | false | 2 | 20 | 200 | + 
+---------+-------+-------+---+----+-----+ + ------- Object Store Request Summary ------- + RequestCountingObjectStore() + Total Requests: 1 + - GET (opts) path=data/a=2/b=20/c=200/file_2.csv + " + ); + + assert_snapshot!( + test.query("select * from csv_table_partitioned WHERE a=2 AND b=20").await, + @r" + ------- Query Output (2 rows) ------- + +---------+-------+-------+---+----+-----+ + | d1 | d2 | d3 | a | b | c | + +---------+-------+-------+---+----+-----+ + | 0.00002 | 2e-12 | true | 2 | 20 | 200 | + | 0.00003 | 5e-12 | false | 2 | 20 | 200 | + +---------+-------+-------+---+----+-----+ + ------- Object Store Request Summary ------- + RequestCountingObjectStore() + Total Requests: 1 + - GET (opts) path=data/a=2/b=20/c=200/file_2.csv + " + ); + + assert_snapshot!( + test.query("select * from csv_table_partitioned WHERE a<2 AND b=10 AND c=100").await, + @r" + ------- Query Output (2 rows) ------- + +---------+-------+-------+---+----+-----+ + | d1 | d2 | d3 | a | b | c | + +---------+-------+-------+---+----+-----+ + | 0.00001 | 1e-12 | true | 1 | 10 | 100 | + | 0.00003 | 5e-12 | false | 1 | 10 | 100 | + +---------+-------+-------+---+----+-----+ + ------- Object Store Request Summary ------- + RequestCountingObjectStore() + Total Requests: 1 + - GET (opts) path=data/a=1/b=10/c=100/file_1.csv + " + ); +} + +#[tokio::test] +async fn create_single_parquet_file_default() { + // The default metadata size hint is 512KB + // which is enough to fetch the entire footer metadata and PageIndex + // in a single GET request. 
+ let test = Test::new().with_single_file_parquet().await; + // expect 1 get request which reads the footer metadata and page index + assert_snapshot!( + test.requests(), + @r" + RequestCountingObjectStore() + Total Requests: 2 + - GET (opts) path=parquet_table.parquet head=true + - GET (opts) path=parquet_table.parquet range=0-2994 + " + ); +} + +#[tokio::test] +async fn create_single_parquet_file_prefetch() { + // Explicitly specify a prefetch hint that is adequate for the footer and page index + let test = Test::new() + .with_parquet_metadata_size_hint(Some(1000)) + .with_single_file_parquet() + .await; + // expect 1 1000 byte request which reads the footer metadata and page index + assert_snapshot!( + test.requests(), + @r" + RequestCountingObjectStore() + Total Requests: 2 + - GET (opts) path=parquet_table.parquet head=true + - GET (opts) path=parquet_table.parquet range=1994-2994 + " + ); +} + +#[tokio::test] +async fn create_single_parquet_file_too_small_prefetch() { + // configure a prefetch size that is too small to fetch the footer + // metadata + // + // Using the ranges from the test below (with no_prefetch), + // pick a number less than 730: + // -------- + // 2286-2294: (8 bytes) footer + length + // 2264-2986: (722 bytes) footer metadata + let test = Test::new() + .with_parquet_metadata_size_hint(Some(500)) + .with_single_file_parquet() + .await; + // expect three get requests: + // 1. read the footer (500 bytes per hint, not enough for the footer metadata) + // 2. Read the footer metadata + // 3. 
reads the PageIndex + assert_snapshot!( + test.requests(), + @r" + RequestCountingObjectStore() + Total Requests: 4 + - GET (opts) path=parquet_table.parquet head=true + - GET (opts) path=parquet_table.parquet range=2494-2994 + - GET (opts) path=parquet_table.parquet range=2264-2986 + - GET (opts) path=parquet_table.parquet range=2124-2264 + " + ); +} + +#[tokio::test] +async fn create_single_parquet_file_small_prefetch() { + // configure a prefetch size that is large enough for the footer + // metadata but **not** the PageIndex + // + // Using the ranges from the test below (with no_prefetch), + // the 730 is determined as follows; + // -------- + // 2286-2294: (8 bytes) footer + length + // 2264-2986: (722 bytes) footer metadata + let test = Test::new() + // 740 is enough to get both the footer + length (8 bytes) + // but not the entire PageIndex + .with_parquet_metadata_size_hint(Some(740)) + .with_single_file_parquet() + .await; + // expect two get requests: + // 1. read the footer metadata + // 2. reads the PageIndex + assert_snapshot!( + test.requests(), + @r" + RequestCountingObjectStore() + Total Requests: 3 + - GET (opts) path=parquet_table.parquet head=true + - GET (opts) path=parquet_table.parquet range=2254-2994 + - GET (opts) path=parquet_table.parquet range=2124-2264 + " + ); +} + +#[tokio::test] +async fn create_single_parquet_file_no_prefetch() { + let test = Test::new() + // force no prefetch by setting size hint to None + .with_parquet_metadata_size_hint(None) + .with_single_file_parquet() + .await; + // Without a metadata size hint, the parquet reader + // does *three* range requests to read the footer metadata: + // 1. The footer length (last 8 bytes) + // 2. The footer metadata + // 3. 
The PageIndex metadata + assert_snapshot!( + test.requests(), + @r" + RequestCountingObjectStore() + Total Requests: 2 + - GET (opts) path=parquet_table.parquet head=true + - GET (opts) path=parquet_table.parquet range=0-2994 + " + ); +} + +#[tokio::test] +async fn query_single_parquet_file() { + let test = Test::new().with_single_file_parquet().await; + assert_snapshot!( + test.query("select count(distinct a), count(b) from parquet_table").await, + @r" + ------- Query Output (1 rows) ------- + +---------------------------------+------------------------+ + | count(DISTINCT parquet_table.a) | count(parquet_table.b) | + +---------------------------------+------------------------+ + | 200 | 200 | + +---------------------------------+------------------------+ + ------- Object Store Request Summary ------- + RequestCountingObjectStore() + Total Requests: 3 + - GET (opts) path=parquet_table.parquet head=true + - GET (ranges) path=parquet_table.parquet ranges=4-534,534-1064 + - GET (ranges) path=parquet_table.parquet ranges=1064-1594,1594-2124 + " + ); +} + +#[tokio::test] +async fn query_single_parquet_file_with_single_predicate() { + let test = Test::new().with_single_file_parquet().await; + // Note that evaluating predicates requires additional object store requests + // (to evaluate predicates) + assert_snapshot!( + test.query("select min(a), max(b) from parquet_table WHERE a > 150").await, + @r" + ------- Query Output (1 rows) ------- + +----------------------+----------------------+ + | min(parquet_table.a) | max(parquet_table.b) | + +----------------------+----------------------+ + | 151 | 1199 | + +----------------------+----------------------+ + ------- Object Store Request Summary ------- + RequestCountingObjectStore() + Total Requests: 2 + - GET (opts) path=parquet_table.parquet head=true + - GET (ranges) path=parquet_table.parquet ranges=1064-1481,1481-1594,1594-2011,2011-2124 + " + ); +} + +#[tokio::test] +async fn 
query_single_parquet_file_multi_row_groups_multiple_predicates() { + let test = Test::new().with_single_file_parquet().await; + + // Note that evaluating predicates requires additional object store requests + // (to evaluate predicates) + assert_snapshot!( + test.query("select min(a), max(b) from parquet_table WHERE a > 50 AND b < 1150").await, + @r" + ------- Query Output (1 rows) ------- + +----------------------+----------------------+ + | min(parquet_table.a) | max(parquet_table.b) | + +----------------------+----------------------+ + | 51 | 1149 | + +----------------------+----------------------+ + ------- Object Store Request Summary ------- + RequestCountingObjectStore() + Total Requests: 3 + - GET (opts) path=parquet_table.parquet head=true + - GET (ranges) path=parquet_table.parquet ranges=4-421,421-534,534-951,951-1064 + - GET (ranges) path=parquet_table.parquet ranges=1064-1481,1481-1594,1594-2011,2011-2124 + " + ); +} + +/// Runs tests with a request counting object store +struct Test { + object_store: Arc, + session_context: SessionContext, + /// metadata size hint to use when registering parquet files + /// + /// * `None`: uses the default (does not set a size_hint) + /// * `Some(None)`L: set prefetch hint to None (prefetching) + /// * `Some(Some(size))`: set prefetch hint to size + parquet_metadata_size_hint: Option>, +} + +impl Test { + fn new() -> Self { + let object_store = Arc::new(RequestCountingObjectStore::new()); + let session_context = SessionContext::new(); + session_context + .runtime_env() + .register_object_store(&Url::parse("mem://").unwrap(), object_store.clone()); + Self { + object_store, + session_context, + parquet_metadata_size_hint: None, + } + } + + /// Specify the metadata size hint to use when registering parquet files + fn with_parquet_metadata_size_hint(mut self, size_hint: Option) -> Self { + self.parquet_metadata_size_hint = Some(size_hint); + self + } + + /// Returns a string representation of all recorded requests thus 
far + fn requests(&self) -> String { + format!("{}", self.object_store) + } + + /// Store the specified bytes at the given path + async fn with_bytes(self, path: &str, bytes: impl Into) -> Self { + let path = Path::from(path); + self.object_store + .inner + .put(&path, PutPayload::from(bytes.into())) + .await + .unwrap(); + self + } + + /// Register a CSV file at the given path + async fn register_csv(self, table_name: &str, path: &str) -> Self { + let mut options = CsvReadOptions::new(); + options.has_header = true; + let url = format!("mem://{path}"); + self.session_context + .register_csv(table_name, url, options) + .await + .unwrap(); + self + } + + /// Register a partitioned CSV table at the given path + async fn register_partitioned_csv(self, table_name: &str, path: &str) -> Self { + let file_format = Arc::new(CsvFormat::default().with_has_header(true)); + let options = ListingOptions::new(file_format); + + let url = format!("mem://{path}").parse().unwrap(); + let table_url = ListingTableUrl::try_new(url, None).unwrap(); + + let session_state = self.session_context.state(); + let mut config = ListingTableConfig::new(table_url).with_listing_options(options); + config = config + .infer_partitions_from_path(&session_state) + .await + .unwrap(); + config = config.infer_schema(&session_state).await.unwrap(); + + let table = Arc::new(ListingTable::try_new(config).unwrap()); + self.session_context + .register_table(table_name, table) + .unwrap(); + self + } + + /// Register a Parquet file at the given path + async fn register_parquet(self, table_name: &str, path: &str) -> Self { + let path = format!("mem://{path}"); + let mut options: ParquetReadOptions<'_> = ParquetReadOptions::new(); + + // If a metadata size hint was specified, apply it + if let Some(parquet_metadata_size_hint) = self.parquet_metadata_size_hint { + options = options.metadata_size_hint(parquet_metadata_size_hint); + } + + self.session_context + .register_parquet(table_name, path, options) + .await 
+ .unwrap(); + self + } + + /// Register a single CSV file with three columns and two row named + /// `csv_table` + async fn with_single_file_csv(self) -> Test { + // upload CSV data to object store + let csv_data = r#"c1,c2,c3 +0.00001,5e-12,true +0.00002,4e-12,false +"#; + self.with_bytes("/csv_table.csv", csv_data) + .await + .register_csv("csv_table", "/csv_table.csv") + .await + } + + /// Register three CSV files in a directory, called `csv_table` + async fn with_multi_file_csv(mut self) -> Test { + // upload CSV data to object store + for i in 0..3 { + let csv_data1 = format!( + r#"c1,c2,c3 +0.0000{i},{i}e-12,true +0.00003,5e-12,false +"# + ); + self = self + .with_bytes(&format!("/data/file_{i}.csv"), csv_data1) + .await; + } + // register table + self.register_csv("csv_table", "/data/").await + } + + /// Register three CSV files in a partitioned directory structure, called + /// `csv_table_partitioned` + async fn with_partitioned_csv(mut self) -> Test { + for i in 1..4 { + // upload CSV data to object store + let csv_data1 = format!( + r#"d1,d2,d3 +0.0000{i},{i}e-12,true +0.00003,5e-12,false +"# + ); + self = self + .with_bytes( + &format!("/data/a={i}/b={}/c={}/file_{i}.csv", i * 10, i * 100,), + csv_data1, + ) + .await; + } + // register table + self.register_partitioned_csv("csv_table_partitioned", "/data/") + .await + } + + /// Add a single parquet file that has two columns and two row groups named `parquet_table` + /// + /// Column "a": Int32 with values 0-100] in row group 1 + /// and [101-200] in row group 2 + /// + /// Column "b": Int32 with values 1000-1100] in row group 1 + /// and [1101-1200] in row group 2 + async fn with_single_file_parquet(self) -> Test { + // Create parquet bytes + let a: ArrayRef = Arc::new(Int32Array::from_iter_values(0..200)); + let b: ArrayRef = Arc::new(Int32Array::from_iter_values(1000..1200)); + let batch = RecordBatch::try_from_iter([("a", a), ("b", b)]).unwrap(); + + let mut buffer = vec![]; + let props = 
parquet::file::properties::WriterProperties::builder() + .set_max_row_group_row_count(Some(100)) + .build(); + let mut writer = parquet::arrow::ArrowWriter::try_new( + &mut buffer, + batch.schema(), + Some(props), + ) + .unwrap(); + writer.write(&batch).unwrap(); + writer.close().unwrap(); + + self.with_bytes("/parquet_table.parquet", buffer) + .await + .register_parquet("parquet_table", "/parquet_table.parquet") + .await + } + + /// Runs the specified query and returns a string representation of the results + /// suitable for comparison with insta snapshots + /// + /// Clears all recorded requests before running the query + async fn query(&self, sql: &str) -> String { + self.object_store.clear_requests(); + let results = self + .session_context + .sql(sql) + .await + .unwrap() + .collect() + .await + .unwrap(); + + let num_rows = results.iter().map(|batch| batch.num_rows()).sum::(); + let formatted_result = + arrow::util::pretty::pretty_format_batches(&results).unwrap(); + + let object_store = &self.object_store; + + format!( + r#"------- Query Output ({num_rows} rows) ------- +{formatted_result} +------- Object Store Request Summary ------- +{object_store} +"# + ) + } +} + +/// Details of individual requests made through the [`RequestCountingObjectStore`] +#[derive(Clone, Debug)] +enum RequestDetails { + GetOpts { path: Path, get_options: GetOptions }, + GetRanges { path: Path, ranges: Vec> }, + List { prefix: Option }, + ListWithDelimiter { prefix: Option }, + ListWithOffset { prefix: Option, offset: Path }, +} + +fn display_range(range: &Range) -> impl Display + '_ { + struct Wrapper<'a>(&'a Range); + impl Display for Wrapper<'_> { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "{}-{}", self.0.start, self.0.end) + } + } + Wrapper(range) +} +impl Display for RequestDetails { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + RequestDetails::GetOpts { path, get_options } => { + write!(f, "GET (opts) path={path}")?; + if 
let Some(range) = &get_options.range { + match range { + GetRange::Bounded(range) => { + let range = display_range(range); + write!(f, " range={range}")?; + } + GetRange::Offset(offset) => { + write!(f, " range=offset:{offset}")?; + } + GetRange::Suffix(suffix) => { + write!(f, " range=suffix:{suffix}")?; + } + } + } + if let Some(version) = &get_options.version { + write!(f, " version={version}")?; + } + if get_options.head { + write!(f, " head=true")?; + } + Ok(()) + } + RequestDetails::GetRanges { path, ranges } => { + write!(f, "GET (ranges) path={path}")?; + if !ranges.is_empty() { + write!(f, " ranges=")?; + for (i, range) in ranges.iter().enumerate() { + if i > 0 { + write!(f, ",")?; + } + write!(f, "{}", display_range(range))?; + } + } + Ok(()) + } + RequestDetails::List { prefix } => { + write!(f, "LIST")?; + if let Some(prefix) = prefix { + write!(f, " prefix={prefix}")?; + } + Ok(()) + } + RequestDetails::ListWithDelimiter { prefix } => { + write!(f, "LIST (with delimiter)")?; + if let Some(prefix) = prefix { + write!(f, " prefix={prefix}")?; + } + Ok(()) + } + RequestDetails::ListWithOffset { prefix, offset } => { + write!(f, "LIST (with offset) offset={offset}")?; + if let Some(prefix) = prefix { + write!(f, " prefix={prefix}")?; + } + Ok(()) + } + } + } +} + +#[derive(Debug)] +struct RequestCountingObjectStore { + /// Inner (memory) store + inner: Arc, + requests: Mutex>, +} + +impl Display for RequestCountingObjectStore { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "RequestCountingObjectStore()")?; + let requests = self.recorded_requests(); + write!(f, "\nTotal Requests: {}", requests.len())?; + for request in requests { + write!(f, "\n- {request}")?; + } + Ok(()) + } +} + +impl RequestCountingObjectStore { + pub fn new() -> Self { + let inner = Arc::new(InMemory::new()); + Self { + inner, + requests: Mutex::new(vec![]), + } + } + + pub fn clear_requests(&self) { + self.requests.lock().clear(); + } + + /// Return a copy of the 
recorded requests normalized + /// by removing the path prefix + pub fn recorded_requests(&self) -> Vec { + self.requests.lock().to_vec() + } +} + +#[async_trait] +impl ObjectStore for RequestCountingObjectStore { + async fn put_opts( + &self, + _location: &Path, + _payload: PutPayload, + _opts: PutOptions, + ) -> object_store::Result { + unimplemented!() + } + + async fn put_multipart_opts( + &self, + _location: &Path, + _opts: PutMultipartOptions, + ) -> object_store::Result> { + unimplemented!() + } + + async fn get_opts( + &self, + location: &Path, + options: GetOptions, + ) -> object_store::Result { + let result = self.inner.get_opts(location, options.clone()).await?; + self.requests.lock().push(RequestDetails::GetOpts { + path: location.to_owned(), + get_options: options, + }); + Ok(result) + } + + async fn get_ranges( + &self, + location: &Path, + ranges: &[Range], + ) -> object_store::Result> { + let result = self.inner.get_ranges(location, ranges).await?; + self.requests.lock().push(RequestDetails::GetRanges { + path: location.to_owned(), + ranges: ranges.to_vec(), + }); + Ok(result) + } + + fn list( + &self, + prefix: Option<&Path>, + ) -> BoxStream<'static, object_store::Result> { + self.requests.lock().push(RequestDetails::List { + prefix: prefix.map(|p| p.to_owned()), + }); + + self.inner.list(prefix) + } + + fn list_with_offset( + &self, + prefix: Option<&Path>, + offset: &Path, + ) -> BoxStream<'static, object_store::Result> { + self.requests.lock().push(RequestDetails::ListWithOffset { + prefix: prefix.map(|p| p.to_owned()), + offset: offset.to_owned(), + }); + self.inner.list_with_offset(prefix, offset) + } + + async fn list_with_delimiter( + &self, + prefix: Option<&Path>, + ) -> object_store::Result { + self.requests + .lock() + .push(RequestDetails::ListWithDelimiter { + prefix: prefix.map(|p| p.to_owned()), + }); + self.inner.list_with_delimiter(prefix).await + } + + fn delete_stream( + &self, + _locations: BoxStream<'static, 
object_store::Result>, + ) -> BoxStream<'static, object_store::Result> { + unimplemented!() + } + + async fn copy_opts( + &self, + _from: &Path, + _to: &Path, + _options: CopyOptions, + ) -> object_store::Result<()> { + unimplemented!() + } +} diff --git a/datafusion/core/tests/execution/coop.rs b/datafusion/core/tests/execution/coop.rs new file mode 100644 index 0000000000000..e02364a0530cc --- /dev/null +++ b/datafusion/core/tests/execution/coop.rs @@ -0,0 +1,835 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow::array::{Int64Array, RecordBatch}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use arrow_schema::SortOptions; +use datafusion::common::NullEquality; +use datafusion::functions_aggregate::sum; +use datafusion::physical_expr::aggregate::AggregateExprBuilder; +use datafusion::physical_plan; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::physical_plan::aggregates::{ + AggregateExec, AggregateMode, LimitOptions, PhysicalGroupBy, +}; +use datafusion::physical_plan::execution_plan::Boundedness; +use datafusion::prelude::SessionContext; +use datafusion_common::{DataFusionError, JoinType, ScalarValue, exec_datafusion_err}; +use datafusion_execution::{SendableRecordBatchStream, TaskContext}; +use datafusion_expr_common::operator::Operator; +use datafusion_expr_common::operator::Operator::{Divide, Eq, Gt, Modulo}; +use datafusion_functions_aggregate::min_max; +use datafusion_physical_expr::Partitioning; +use datafusion_physical_expr::expressions::{ + BinaryExpr, Column, Literal, binary, col, lit, +}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; +use datafusion_physical_optimizer::PhysicalOptimizerRule; +use datafusion_physical_optimizer::ensure_coop::EnsureCooperative; +use datafusion_physical_plan::coop::make_cooperative; +use datafusion_physical_plan::filter::FilterExec; +use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode, SortMergeJoinExec}; +use datafusion_physical_plan::memory::{LazyBatchGenerator, LazyMemoryExec}; +use datafusion_physical_plan::projection::{ProjectionExec, ProjectionExpr}; +use datafusion_physical_plan::repartition::RepartitionExec; +use datafusion_physical_plan::sorts::sort::SortExec; +use datafusion_physical_plan::stream::RecordBatchStreamAdapter; +use datafusion_physical_plan::union::InterleaveExec; +use futures::StreamExt; +use parking_lot::RwLock; +use rstest::rstest; +use 
std::any::Any; +use std::error::Error; +use std::fmt::Formatter; +use std::ops::Range; +use std::sync::Arc; +use std::task::Poll; +use std::time::Duration; +use tokio::runtime::{Handle, Runtime}; +use tokio::select; + +#[derive(Debug, Clone)] +struct RangeBatchGenerator { + schema: SchemaRef, + value_range: Range, + boundedness: Boundedness, + batch_size: usize, + poll_count: usize, + original_range: Range, +} + +impl std::fmt::Display for RangeBatchGenerator { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { + // Display current counter + write!(f, "InfiniteGenerator(counter={})", self.poll_count) + } +} + +impl LazyBatchGenerator for RangeBatchGenerator { + fn as_any(&self) -> &dyn Any { + self + } + + fn boundedness(&self) -> Boundedness { + self.boundedness + } + + /// Generate the next RecordBatch. + fn generate_next_batch(&mut self) -> datafusion_common::Result> { + self.poll_count += 1; + + let mut builder = Int64Array::builder(self.batch_size); + for _ in 0..self.batch_size { + match self.value_range.next() { + None => break, + Some(v) => builder.append_value(v), + } + } + let array = builder.finish(); + + if array.is_empty() { + return Ok(None); + } + + let batch = + RecordBatch::try_new(Arc::clone(&self.schema), vec![Arc::new(array)])?; + Ok(Some(batch)) + } + + fn reset_state(&self) -> Arc> { + let mut new = self.clone(); + new.poll_count = 0; + new.value_range = new.original_range.clone(); + Arc::new(RwLock::new(new)) + } +} + +fn make_lazy_exec(column_name: &str, pretend_infinite: bool) -> LazyMemoryExec { + make_lazy_exec_with_range(column_name, i64::MIN..i64::MAX, pretend_infinite) +} + +fn make_lazy_exec_with_range( + column_name: &str, + range: Range, + pretend_infinite: bool, +) -> LazyMemoryExec { + let schema = Arc::new(Schema::new(vec![Field::new( + column_name, + DataType::Int64, + false, + )])); + + let boundedness = if pretend_infinite { + Boundedness::Unbounded { + requires_infinite_memory: false, + } + } else { + 
Boundedness::Bounded + }; + + // Instantiate the generator with the batch and limit + let batch_gen = RangeBatchGenerator { + schema: Arc::clone(&schema), + boundedness, + value_range: range.clone(), + batch_size: 8192, + poll_count: 0, + original_range: range, + }; + + // Wrap the generator in a trait object behind Arc> + let generator: Arc> = Arc::new(RwLock::new(batch_gen)); + + // Create a LazyMemoryExec with one partition using our generator + let mut exec = LazyMemoryExec::try_new(schema, vec![generator]).unwrap(); + + exec.add_ordering(vec![PhysicalSortExpr::new( + Arc::new(Column::new(column_name, 0)), + SortOptions::new(false, true), + )]); + + exec +} + +#[rstest] +#[tokio::test] +async fn agg_no_grouping_yields( + #[values(false, true)] pretend_infinite: bool, +) -> Result<(), Box> { + // build session + let session_ctx = SessionContext::new(); + + // set up an aggregation without grouping + let inf = Arc::new(make_lazy_exec("value", pretend_infinite)); + let aggr = Arc::new(AggregateExec::try_new( + AggregateMode::Single, + PhysicalGroupBy::new(vec![], vec![], vec![], false), + vec![Arc::new( + AggregateExprBuilder::new( + sum::sum_udaf(), + vec![col("value", &inf.schema())?], + ) + .schema(inf.schema()) + .alias("sum") + .build()?, + )], + vec![None], + inf.clone(), + inf.schema(), + )?); + + query_yields(aggr, session_ctx.task_ctx()).await +} + +#[rstest] +#[tokio::test] +async fn agg_grouping_yields( + #[values(false, true)] pretend_infinite: bool, +) -> Result<(), Box> { + // build session + let session_ctx = SessionContext::new(); + + // set up an aggregation with grouping + let inf = Arc::new(make_lazy_exec("value", pretend_infinite)); + + let value_col = col("value", &inf.schema())?; + let group = binary(value_col.clone(), Divide, lit(1000000i64), &inf.schema())?; + + let aggr = Arc::new(AggregateExec::try_new( + AggregateMode::Single, + PhysicalGroupBy::new(vec![(group, "group".to_string())], vec![], vec![], false), + vec![Arc::new( + 
AggregateExprBuilder::new(sum::sum_udaf(), vec![value_col.clone()]) + .schema(inf.schema()) + .alias("sum") + .build()?, + )], + vec![None], + inf.clone(), + inf.schema(), + )?); + + query_yields(aggr, session_ctx.task_ctx()).await +} + +#[rstest] +#[tokio::test] +async fn agg_grouped_topk_yields( + #[values(false, true)] pretend_infinite: bool, +) -> Result<(), Box> { + // build session + + let session_ctx = SessionContext::new(); + + // set up a top-k aggregation + let inf = Arc::new(make_lazy_exec("value", pretend_infinite)); + + let value_col = col("value", &inf.schema())?; + let group = binary(value_col.clone(), Divide, lit(1000000i64), &inf.schema())?; + + let aggr = Arc::new( + AggregateExec::try_new( + AggregateMode::Single, + PhysicalGroupBy::new( + vec![(group, "group".to_string())], + vec![], + vec![vec![false]], + false, + ), + vec![Arc::new( + AggregateExprBuilder::new(min_max::max_udaf(), vec![value_col.clone()]) + .schema(inf.schema()) + .alias("max") + .build()?, + )], + vec![None], + inf.clone(), + inf.schema(), + )? 
+ .with_limit_options(Some(LimitOptions::new(100))), + ); + + query_yields(aggr, session_ctx.task_ctx()).await +} + +#[rstest] +#[tokio::test] +// A test that mocks the behavior of `SpillManager::read_spill_as_stream` without file access +// to verify that a cooperative stream would properly yields in a spill file read scenario +async fn spill_reader_stream_yield() -> Result<(), Box> { + use datafusion_physical_plan::common::spawn_buffered; + + // A mock stream that always returns `Poll::Ready(Some(...))` immediately + let always_ready = + make_lazy_exec("value", false).execute(0, SessionContext::new().task_ctx())?; + + // this function makes a consumer stream that resembles how read_stream from spill file is constructed + let stream = make_cooperative(always_ready); + + // Set large buffer so that buffer always has free space for the producer/sender + let buffer_capacity = 100_000; + let mut mock_stream = spawn_buffered(stream, buffer_capacity); + let schema = mock_stream.schema(); + + let consumer_stream = futures::stream::poll_fn(move |cx| { + let mut collected = vec![]; + // To make sure that inner stream is polled multiple times, loop until the buffer is full + // Ideally, the stream will yield before the loop ends + for _ in 0..buffer_capacity { + match mock_stream.as_mut().poll_next(cx) { + Poll::Ready(Some(Ok(batch))) => { + collected.push(batch); + } + Poll::Ready(Some(Err(e))) => { + return Poll::Ready(Some(Err(e))); + } + Poll::Ready(None) => { + break; + } + Poll::Pending => { + // polling inner stream may return Pending only when it reaches budget, since + // we intentionally made ProducerStream always return Ready + return Poll::Pending; + } + } + } + + // This should be unreachable since the stream is canceled + unreachable!("Expected the stream to be canceled, but it continued polling"); + }); + + let consumer_record_batch_stream = + Box::pin(RecordBatchStreamAdapter::new(schema, consumer_stream)); + + 
stream_yields(consumer_record_batch_stream).await +} + +#[rstest] +#[tokio::test] +async fn sort_yields( + #[values(false, true)] pretend_infinite: bool, +) -> Result<(), Box> { + // build session + let session_ctx = SessionContext::new(); + + // set up the infinite source + let inf = Arc::new(make_lazy_exec("value", pretend_infinite)); + + // set up a SortExec that will not be able to finish in time because input is very large + let sort_expr = PhysicalSortExpr::new( + col("value", &inf.schema())?, + SortOptions { + descending: true, + nulls_first: true, + }, + ); + + let lex_ordering = LexOrdering::new(vec![sort_expr]).unwrap(); + let sort_exec = Arc::new(SortExec::new(lex_ordering, inf.clone())); + + query_yields(sort_exec, session_ctx.task_ctx()).await +} + +#[rstest] +#[tokio::test] +async fn sort_merge_join_yields( + #[values(false, true)] pretend_infinite: bool, +) -> Result<(), Box> { + // build session + let session_ctx = SessionContext::new(); + + // set up the join sources + let inf1 = Arc::new(make_lazy_exec_with_range( + "value1", + i64::MIN..0, + pretend_infinite, + )); + let inf2 = Arc::new(make_lazy_exec_with_range( + "value2", + 0..i64::MAX, + pretend_infinite, + )); + + // set up a SortMergeJoinExec that will take a long time skipping left side content to find + // the first right side match + let join = Arc::new(SortMergeJoinExec::try_new( + inf1.clone(), + inf2.clone(), + vec![( + col("value1", &inf1.schema())?, + col("value2", &inf2.schema())?, + )], + None, + JoinType::Inner, + vec![inf1.properties().eq_properties.output_ordering().unwrap()[0].options], + NullEquality::NullEqualsNull, + )?); + + query_yields(join, session_ctx.task_ctx()).await +} + +#[rstest] +#[tokio::test] +async fn filter_yields( + #[values(false, true)] pretend_infinite: bool, +) -> Result<(), Box> { + // build session + let session_ctx = SessionContext::new(); + + // set up the infinite source + let inf = Arc::new(make_lazy_exec("value", pretend_infinite)); + + // set up 
a FilterExec that will filter out entire batches + let filter_expr = binary( + col("value", &inf.schema())?, + Operator::Lt, + lit(i64::MIN), + &inf.schema(), + )?; + let filter = Arc::new(FilterExec::try_new(filter_expr, inf.clone())?); + + query_yields(filter, session_ctx.task_ctx()).await +} + +#[rstest] +#[tokio::test] +async fn filter_reject_all_batches_yields( + #[values(false, true)] pretend_infinite: bool, +) -> Result<(), Box> { + // Create a Session, Schema, and an 8K-row RecordBatch + let session_ctx = SessionContext::new(); + + // Wrap this batch in an InfiniteExec + let infinite = make_lazy_exec_with_range("value", i64::MIN..0, pretend_infinite); + + // 2b) Construct a FilterExec that is always false: “value > 10000” (no rows pass) + let false_predicate = Arc::new(BinaryExpr::new( + Arc::new(Column::new("value", 0)), + Gt, + Arc::new(Literal::new(ScalarValue::Int64(Some(0)))), + )); + let filtered = Arc::new(FilterExec::try_new(false_predicate, Arc::new(infinite))?); + + query_yields(filtered, session_ctx.task_ctx()).await +} + +#[rstest] +#[tokio::test] +async fn interleave_then_filter_all_yields( + #[values(false, true)] pretend_infinite: bool, +) -> Result<(), Box> { + // Build a session and a schema with one i64 column. + let session_ctx = SessionContext::new(); + + // Create multiple infinite sources, each filtered by a different threshold. + // This ensures InterleaveExec has many children. + let mut infinite_children = vec![]; + + // Use 32 distinct thresholds (each >0 and <8 192) to force 32 infinite inputs + for threshold in 1..32 { + // One infinite exec: + let mut inf = make_lazy_exec_with_range("value", 0..i64::MAX, pretend_infinite); + + // Now repartition so that all children share identical Hash partitioning + // on “value” into 1 bucket. This is required for InterleaveExec::try_new. 
+ let exprs = vec![Arc::new(Column::new("value", 0)) as _]; + let partitioning = Partitioning::Hash(exprs, 1); + inf.try_set_partitioning(partitioning)?; + + // Apply a FilterExec: “(value / 8192) % threshold == 0”. + let filter_expr = binary( + binary( + binary( + col("value", &inf.schema())?, + Divide, + lit(8192i64), + &inf.schema(), + )?, + Modulo, + lit(threshold as i64), + &inf.schema(), + )?, + Eq, + lit(0i64), + &inf.schema(), + )?; + let filtered = Arc::new(FilterExec::try_new(filter_expr, Arc::new(inf))?); + + infinite_children.push(filtered as _); + } + + // Build an InterleaveExec over all infinite children. + let interleave = Arc::new(InterleaveExec::try_new(infinite_children)?); + + // Wrap the InterleaveExec in a FilterExec that always returns false, + // ensuring that no rows are ever emitted. + let always_false = Arc::new(Literal::new(ScalarValue::Boolean(Some(false)))); + let filtered_interleave = Arc::new(FilterExec::try_new(always_false, interleave)?); + + query_yields(filtered_interleave, session_ctx.task_ctx()).await +} + +#[rstest] +#[tokio::test] +async fn interleave_then_aggregate_yields( + #[values(false, true)] pretend_infinite: bool, +) -> Result<(), Box> { + // Build session, schema, and a sample batch. + let session_ctx = SessionContext::new(); + + // Create N infinite sources, each filtered by a different predicate. + // That way, the InterleaveExec will have multiple children. + let mut infinite_children = vec![]; + + // Use 32 distinct thresholds (each >0 and <8 192) to force 32 infinite inputs + for threshold in 1..32 { + // One infinite exec: + let mut inf = make_lazy_exec_with_range("value", 0..i64::MAX, pretend_infinite); + + // Now repartition so that all children share identical Hash partitioning + // on “value” into 1 bucket. This is required for InterleaveExec::try_new. 
+ let exprs = vec![Arc::new(Column::new("value", 0)) as _]; + let partitioning = Partitioning::Hash(exprs, 1); + inf.try_set_partitioning(partitioning)?; + + // Apply a FilterExec: “(value / 8192) % threshold == 0”. + let filter_expr = binary( + binary( + binary( + col("value", &inf.schema())?, + Divide, + lit(8192i64), + &inf.schema(), + )?, + Modulo, + lit(threshold as i64), + &inf.schema(), + )?, + Eq, + lit(0i64), + &inf.schema(), + )?; + let filtered = Arc::new(FilterExec::try_new(filter_expr, Arc::new(inf))?); + + infinite_children.push(filtered as _); + } + + // Build an InterleaveExec over all N children. + // Since each child now has Partitioning::Hash([col "value"], 1), InterleaveExec::try_new succeeds. + let interleave = Arc::new(InterleaveExec::try_new(infinite_children)?); + let interleave_schema = interleave.schema(); + + // Build a global AggregateExec that sums “value” over all rows. + // Because we use `AggregateMode::Single` with no GROUP BY columns, this plan will + // only produce one “final” row once all inputs finish. But our inputs never finish, + // so we should never get any output. 
+ let aggregate_expr = AggregateExprBuilder::new( + sum::sum_udaf(), + vec![Arc::new(Column::new("value", 0))], + ) + .schema(interleave_schema.clone()) + .alias("total") + .build()?; + + let aggr = Arc::new(AggregateExec::try_new( + AggregateMode::Single, + PhysicalGroupBy::new( + vec![], // no GROUP BY columns + vec![], // no GROUP BY expressions + vec![], // no GROUP BY physical expressions + false, + ), + vec![Arc::new(aggregate_expr)], + vec![None], // no “distinct” flags + interleave, + interleave_schema, + )?); + + query_yields(aggr, session_ctx.task_ctx()).await +} + +#[rstest] +#[tokio::test] +async fn join_yields( + #[values(false, true)] pretend_infinite: bool, +) -> Result<(), Box> { + // Session, schema, and a single 8 K‐row batch for each side + let session_ctx = SessionContext::new(); + + // on the right side, we’ll shift each value by +1 so that not everything joins, + // but plenty of matching keys exist (e.g. 0 on left matches 1 on right, etc.) + let infinite_left = make_lazy_exec_with_range("value", -10..10, false); + let infinite_right = + make_lazy_exec_with_range("value", 0..i64::MAX, pretend_infinite); + + // Create Join keys → join on “value” = “value” + let left_keys: Vec> = vec![Arc::new(Column::new("value", 0))]; + let right_keys: Vec> = vec![Arc::new(Column::new("value", 0))]; + + let part_left = Partitioning::Hash(left_keys, 1); + let part_right = Partitioning::Hash(right_keys, 1); + + // Wrap each side in Repartition so they are both hashed into 1 partition + let hashed_left = Arc::new(RepartitionExec::try_new( + Arc::new(infinite_left), + part_left, + )?); + let hashed_right = Arc::new(RepartitionExec::try_new( + Arc::new(infinite_right), + part_right, + )?); + + // Build an Inner HashJoinExec → left.value = right.value + let join = Arc::new(HashJoinExec::try_new( + hashed_left, + hashed_right, + vec![( + Arc::new(Column::new("value", 0)), + Arc::new(Column::new("value", 0)), + )], + None, + &JoinType::Inner, + None, + 
PartitionMode::CollectLeft, + NullEquality::NullEqualsNull, + false, + )?); + + query_yields(join, session_ctx.task_ctx()).await +} + +#[rstest] +#[tokio::test] +async fn join_agg_yields( + #[values(false, true)] pretend_infinite: bool, +) -> Result<(), Box> { + // Session, schema, and a single 8 K‐row batch for each side + let session_ctx = SessionContext::new(); + + // on the right side, we’ll shift each value by +1 so that not everything joins, + // but plenty of matching keys exist (e.g. 0 on left matches 1 on right, etc.) + let infinite_left = make_lazy_exec_with_range("value", -10..10, false); + let infinite_right = + make_lazy_exec_with_range("value", 0..i64::MAX, pretend_infinite); + + // 2b) Create Join keys → join on “value” = “value” + let left_keys: Vec> = vec![Arc::new(Column::new("value", 0))]; + let right_keys: Vec> = vec![Arc::new(Column::new("value", 0))]; + + let part_left = Partitioning::Hash(left_keys, 1); + let part_right = Partitioning::Hash(right_keys, 1); + + // Wrap each side in Repartition so they are both hashed into 1 partition + let hashed_left = Arc::new(RepartitionExec::try_new( + Arc::new(infinite_left), + part_left, + )?); + let hashed_right = Arc::new(RepartitionExec::try_new( + Arc::new(infinite_right), + part_right, + )?); + + // Build an Inner HashJoinExec → left.value = right.value + let join = Arc::new(HashJoinExec::try_new( + hashed_left, + hashed_right, + vec![( + Arc::new(Column::new("value", 0)), + Arc::new(Column::new("value", 0)), + )], + None, + &JoinType::Inner, + None, + PartitionMode::CollectLeft, + NullEquality::NullEqualsNull, + false, + )?); + + // Project only one column (“value” from the left side) because we just want to sum that + let input_schema = join.schema(); + + let proj_expr = vec![ProjectionExpr::new( + Arc::new(Column::new_with_schema("value", &input_schema)?) 
as _, + "value", + )]; + + let projection = Arc::new(ProjectionExec::try_new(proj_expr, join)?); + let projection_schema = projection.schema(); + + let output_fields = vec![Field::new("total", DataType::Int64, true)]; + let output_schema = Arc::new(Schema::new(output_fields)); + + // 4) Global aggregate (Single) over “value” + let aggregate_expr = AggregateExprBuilder::new( + sum::sum_udaf(), + vec![Arc::new(Column::new_with_schema( + "value", + &projection.schema(), + )?)], + ) + .schema(output_schema) + .alias("total") + .build()?; + + let aggr = Arc::new(AggregateExec::try_new( + AggregateMode::Single, + PhysicalGroupBy::new(vec![], vec![], vec![], false), + vec![Arc::new(aggregate_expr)], + vec![None], + projection, + projection_schema, + )?); + + query_yields(aggr, session_ctx.task_ctx()).await +} + +#[rstest] +#[tokio::test] +async fn hash_join_yields( + #[values(false, true)] pretend_infinite: bool, +) -> Result<(), Box> { + // build session + let session_ctx = SessionContext::new(); + + // set up the join sources + let inf1 = Arc::new(make_lazy_exec("value1", pretend_infinite)); + let inf2 = Arc::new(make_lazy_exec("value2", pretend_infinite)); + + // set up a HashJoinExec that will take a long time in the build phase + let join = Arc::new(HashJoinExec::try_new( + inf1.clone(), + inf2.clone(), + vec![( + col("value1", &inf1.schema())?, + col("value2", &inf2.schema())?, + )], + None, + &JoinType::Left, + None, + PartitionMode::CollectLeft, + NullEquality::NullEqualsNull, + false, + )?); + + query_yields(join, session_ctx.task_ctx()).await +} + +#[rstest] +#[tokio::test] +async fn hash_join_without_repartition_and_no_agg( + #[values(false, true)] pretend_infinite: bool, +) -> Result<(), Box> { + // Create Session, schema, and an 8K-row RecordBatch for each side + let session_ctx = SessionContext::new(); + + // on the right side, we’ll shift each value by +1 so that not everything joins, + // but plenty of matching keys exist (e.g. 
0 on left matches 1 on right, etc.) + let infinite_left = make_lazy_exec_with_range("value", -10..10, false); + let infinite_right = + make_lazy_exec_with_range("value", 0..i64::MAX, pretend_infinite); + + // Directly feed `infinite_left` and `infinite_right` into HashJoinExec. + // Do not use aggregation or repartition. + let join = Arc::new(HashJoinExec::try_new( + Arc::new(infinite_left), + Arc::new(infinite_right), + vec![( + Arc::new(Column::new("value", 0)), + Arc::new(Column::new("value", 0)), + )], + /* filter */ None, + &JoinType::Inner, + /* output64 */ None, + // Using CollectLeft is fine—just avoid RepartitionExec's partitioned channels. + PartitionMode::CollectLeft, + NullEquality::NullEqualsNull, + false, + )?); + + query_yields(join, session_ctx.task_ctx()).await +} + +#[derive(Debug)] +enum Yielded { + ReadyOrPending, + Err(#[expect(dead_code)] DataFusionError), + Timeout, +} + +async fn stream_yields( + mut stream: SendableRecordBatchStream, +) -> Result<(), Box> { + // Create an independent executor pool + let child_runtime = Runtime::new()?; + + // Spawn a task that tries to poll the stream + // The task returns Ready when the stream yielded with either Ready or Pending + let join_handle = child_runtime.spawn(std::future::poll_fn(move |cx| { + match stream.poll_next_unpin(cx) { + Poll::Ready(Some(Ok(_))) => Poll::Ready(Poll::Ready(Ok(()))), + Poll::Ready(Some(Err(e))) => Poll::Ready(Poll::Ready(Err(e))), + Poll::Ready(None) => Poll::Ready(Poll::Ready(Ok(()))), + Poll::Pending => Poll::Ready(Poll::Pending), + } + })); + + let abort_handle = join_handle.abort_handle(); + + // Now select on the join handle of the task running in the child executor with a timeout + let yielded = select! 
{ + result = join_handle => { + match result { + Ok(Poll::Pending) => Yielded::ReadyOrPending, + Ok(Poll::Ready(Ok(_))) => Yielded::ReadyOrPending, + Ok(Poll::Ready(Err(e))) => Yielded::Err(e), + Err(_) => Yielded::Err(exec_datafusion_err!("join error")), + } + }, + _ = tokio::time::sleep(Duration::from_secs(10)) => { + Yielded::Timeout + } + }; + + // Try to abort the poll task and shutdown the child runtime + abort_handle.abort(); + Handle::current().spawn_blocking(move || { + child_runtime.shutdown_timeout(Duration::from_secs(5)); + }); + + // Finally, check if poll_next yielded + assert!( + matches!(yielded, Yielded::ReadyOrPending), + "Result is not Ready or Pending: {yielded:?}" + ); + Ok(()) +} + +async fn query_yields( + plan: Arc, + task_ctx: Arc, +) -> Result<(), Box> { + // Run plan through EnsureCooperative + let optimized = + EnsureCooperative::new().optimize(plan, task_ctx.session_config().options())?; + + // Get the stream + let stream = physical_plan::execute_stream(optimized, task_ctx)?; + + // Spawn a task that tries to poll the stream and check whether given stream yields + stream_yields(stream).await +} diff --git a/datafusion/core/tests/execution/datasource_split.rs b/datafusion/core/tests/execution/datasource_split.rs new file mode 100644 index 0000000000000..370249cd8044e --- /dev/null +++ b/datafusion/core/tests/execution/datasource_split.rs @@ -0,0 +1,123 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::{ + array::{ArrayRef, Int32Array}, + datatypes::{DataType, Field, Schema}, + record_batch::RecordBatch, +}; +use datafusion_datasource::memory::MemorySourceConfig; +use datafusion_execution::TaskContext; +use datafusion_physical_plan::{ExecutionPlan, common::collect}; +use std::sync::Arc; + +/// Helper function to create a memory source with the given batch size and collect all batches +async fn create_and_collect_batches( + batch_size: usize, +) -> datafusion_common::Result> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let array = Int32Array::from_iter_values(0..batch_size as i32); + let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(array) as ArrayRef])?; + let exec = MemorySourceConfig::try_new_exec(&[vec![batch]], schema, None)?; + let ctx = Arc::new(TaskContext::default()); + let stream = exec.execute(0, ctx)?; + collect(stream).await +} + +/// Helper function to create a memory source with multiple batches and collect all results +async fn create_and_collect_multiple_batches( + input_batches: Vec, +) -> datafusion_common::Result> { + let schema = input_batches[0].schema(); + let exec = MemorySourceConfig::try_new_exec(&[input_batches], schema, None)?; + let ctx = Arc::new(TaskContext::default()); + let stream = exec.execute(0, ctx)?; + collect(stream).await +} + +#[tokio::test] +async fn datasource_splits_large_batches() -> datafusion_common::Result<()> { + let batch_size = 20000; + let batches = create_and_collect_batches(batch_size).await?; + + assert!(batches.len() > 1); 
+ let max = batches.iter().map(|b| b.num_rows()).max().unwrap(); + assert!( + max <= datafusion_execution::config::SessionConfig::new() + .options() + .execution + .batch_size + ); + let total: usize = batches.iter().map(|b| b.num_rows()).sum(); + assert_eq!(total, batch_size); + Ok(()) +} + +#[tokio::test] +async fn datasource_exact_batch_size_no_split() -> datafusion_common::Result<()> { + let session_config = datafusion_execution::config::SessionConfig::new(); + let configured_batch_size = session_config.options().execution.batch_size; + + let batches = create_and_collect_batches(configured_batch_size).await?; + + // Should not split when exactly equal to batch_size + assert_eq!(batches.len(), 1); + assert_eq!(batches[0].num_rows(), configured_batch_size); + Ok(()) +} + +#[tokio::test] +async fn datasource_small_batch_no_split() -> datafusion_common::Result<()> { + // Test with batch smaller than the batch size (8192) + let small_batch_size = 512; // Less than 8192 + + let batches = create_and_collect_batches(small_batch_size).await?; + + // Should not split small batches below the batch size + assert_eq!(batches.len(), 1); + assert_eq!(batches[0].num_rows(), small_batch_size); + Ok(()) +} + +#[tokio::test] +async fn datasource_empty_batch_clean_termination() -> datafusion_common::Result<()> { + let batches = create_and_collect_batches(0).await?; + + // Empty batch should result in one empty batch + assert_eq!(batches.len(), 1); + assert_eq!(batches[0].num_rows(), 0); + Ok(()) +} + +#[tokio::test] +async fn datasource_multiple_empty_batches() -> datafusion_common::Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let empty_array = Int32Array::from_iter_values(std::iter::empty::()); + let empty_batch = + RecordBatch::try_new(schema.clone(), vec![Arc::new(empty_array) as ArrayRef])?; + + // Create multiple empty batches + let input_batches = vec![empty_batch.clone(), empty_batch.clone(), empty_batch]; + let batches 
= create_and_collect_multiple_batches(input_batches).await?; + + // Should preserve empty batches without issues + assert_eq!(batches.len(), 3); + for batch in &batches { + assert_eq!(batch.num_rows(), 0); + } + Ok(()) +} diff --git a/datafusion/core/tests/execution/logical_plan.rs b/datafusion/core/tests/execution/logical_plan.rs index 97bb2a727bbfe..3eaa3fb2ed5e6 100644 --- a/datafusion/core/tests/execution/logical_plan.rs +++ b/datafusion/core/tests/execution/logical_plan.rs @@ -20,7 +20,7 @@ use arrow::array::Int64Array; use arrow::datatypes::{DataType, Field, Schema}; -use datafusion::datasource::{provider_as_source, ViewTable}; +use datafusion::datasource::{ViewTable, provider_as_source}; use datafusion::execution::session_state::SessionStateBuilder; use datafusion_common::{Column, DFSchema, DFSchemaRef, Result, ScalarValue, Spans}; use datafusion_execution::TaskContext; @@ -47,9 +47,9 @@ async fn count_only_nulls() -> Result<()> { let input = Arc::new(LogicalPlan::Values(Values { schema: input_schema, values: vec![ - vec![Expr::Literal(ScalarValue::Null)], - vec![Expr::Literal(ScalarValue::Null)], - vec![Expr::Literal(ScalarValue::Null)], + vec![Expr::Literal(ScalarValue::Null, None)], + vec![Expr::Literal(ScalarValue::Null, None)], + vec![Expr::Literal(ScalarValue::Null, None)], ], })); let input_col_ref = Expr::Column(Column { @@ -68,7 +68,7 @@ async fn count_only_nulls() -> Result<()> { args: vec![input_col_ref], distinct: false, filter: None, - order_by: None, + order_by: vec![], null_treatment: None, }, })], @@ -128,7 +128,7 @@ fn inline_scan_projection_test() -> Result<()> { @r" SubqueryAlias: ?table? 
Projection: a - EmptyRelation + EmptyRelation: rows=0 " ); diff --git a/datafusion/core/tests/execution/mod.rs b/datafusion/core/tests/execution/mod.rs index 8169db1a4611e..f33ef87aa3023 100644 --- a/datafusion/core/tests/execution/mod.rs +++ b/datafusion/core/tests/execution/mod.rs @@ -15,4 +15,7 @@ // specific language governing permissions and limitations // under the License. +mod coop; +mod datasource_split; mod logical_plan; +mod register_arrow; diff --git a/datafusion/core/tests/execution/register_arrow.rs b/datafusion/core/tests/execution/register_arrow.rs new file mode 100644 index 0000000000000..4ce16dc0906c1 --- /dev/null +++ b/datafusion/core/tests/execution/register_arrow.rs @@ -0,0 +1,90 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Integration tests for register_arrow API + +use datafusion::{execution::options::ArrowReadOptions, prelude::*}; +use datafusion_common::Result; + +#[tokio::test] +async fn test_register_arrow_auto_detects_format() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.register_arrow( + "file_format", + "../../datafusion/datasource-arrow/tests/data/example.arrow", + ArrowReadOptions::default(), + ) + .await?; + + ctx.register_arrow( + "stream_format", + "../../datafusion/datasource-arrow/tests/data/example_stream.arrow", + ArrowReadOptions::default(), + ) + .await?; + + let file_result = ctx.sql("SELECT * FROM file_format ORDER BY f0").await?; + let stream_result = ctx.sql("SELECT * FROM stream_format ORDER BY f0").await?; + + let file_batches = file_result.collect().await?; + let stream_batches = stream_result.collect().await?; + + assert_eq!(file_batches.len(), stream_batches.len()); + assert_eq!(file_batches[0].schema(), stream_batches[0].schema()); + + let file_rows: usize = file_batches.iter().map(|b| b.num_rows()).sum(); + let stream_rows: usize = stream_batches.iter().map(|b| b.num_rows()).sum(); + assert_eq!(file_rows, stream_rows); + + Ok(()) +} + +#[tokio::test] +async fn test_register_arrow_join_file_and_stream() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.register_arrow( + "file_table", + "../../datafusion/datasource-arrow/tests/data/example.arrow", + ArrowReadOptions::default(), + ) + .await?; + + ctx.register_arrow( + "stream_table", + "../../datafusion/datasource-arrow/tests/data/example_stream.arrow", + ArrowReadOptions::default(), + ) + .await?; + + let result = ctx + .sql( + "SELECT a.f0, a.f1, b.f0, b.f1 + FROM file_table a + JOIN stream_table b ON a.f0 = b.f0 + WHERE a.f0 <= 2 + ORDER BY a.f0", + ) + .await?; + let batches = result.collect().await?; + + let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum(); + assert_eq!(total_rows, 2); + + Ok(()) +} diff --git a/datafusion/core/tests/expr_api/mod.rs 
b/datafusion/core/tests/expr_api/mod.rs index a9cf7f04bb3a2..91dd5de7fcd64 100644 --- a/datafusion/core/tests/expr_api/mod.rs +++ b/datafusion/core/tests/expr_api/mod.rs @@ -16,26 +16,26 @@ // under the License. use arrow::array::{ - builder::{ListBuilder, StringBuilder}, ArrayRef, Int64Array, RecordBatch, StringArray, StructArray, + builder::{ListBuilder, StringBuilder}, }; use arrow::datatypes::{DataType, Field}; use arrow::util::pretty::{pretty_format_batches, pretty_format_columns}; use datafusion::prelude::*; use datafusion_common::{DFSchema, ScalarValue}; -use datafusion_expr::execution_props::ExecutionProps; -use datafusion_expr::simplify::SimplifyContext; use datafusion_expr::ExprFunctionExt; +use datafusion_expr::expr::NullTreatment; +use datafusion_expr::simplify::SimplifyContext; use datafusion_functions::core::expr_ext::FieldAccessor; use datafusion_functions_aggregate::first_last::first_value_udaf; use datafusion_functions_aggregate::sum::sum_udaf; use datafusion_functions_nested::expr_ext::{IndexAccessor, SliceAccessor}; use datafusion_optimizer::simplify_expressions::ExprSimplifier; -use sqlparser::ast::NullTreatment; /// Tests of using and evaluating `Expr`s outside the context of a LogicalPlan use std::sync::{Arc, LazyLock}; mod parse_sql_expr; +#[expect(clippy::needless_pass_by_value)] mod simplification; #[test] @@ -320,6 +320,26 @@ async fn test_create_physical_expr() { create_simplified_expr_test(lit(1i32) + lit(2i32), "3"); } +#[test] +fn test_create_physical_expr_nvl2() { + let batch = &TEST_BATCH; + let df_schema = DFSchema::try_from(batch.schema()).unwrap(); + let ctx = SessionContext::new(); + + let expect_err = |expr| { + let physical_expr = ctx.create_physical_expr(expr, &df_schema).unwrap(); + let err = physical_expr.evaluate(batch).unwrap_err(); + assert!( + err.to_string() + .contains("nvl2 should have been simplified to case"), + "unexpected error: {err:?}" + ); + }; + + expect_err(nvl2(col("i"), lit(1i64), lit(0i64))); + 
expect_err(nvl2(lit(1i64), col("i"), lit(0i64))); +} + #[tokio::test] async fn test_create_physical_expr_coercion() { // create_physical_expr does apply type coercion and unwrapping in cast @@ -364,6 +384,7 @@ async fn evaluate_agg_test(expr: Expr, expected_lines: Vec<&str>) { /// Converts the `Expr` to a `PhysicalExpr`, evaluates it against the provided /// `RecordBatch` and compares the result to the expected result. +#[expect(clippy::needless_pass_by_value)] fn evaluate_expr_test(expr: Expr, expected_lines: Vec<&str>) { let batch = &TEST_BATCH; let df_schema = DFSchema::try_from(batch.schema()).unwrap(); @@ -400,9 +421,7 @@ fn create_simplified_expr_test(expr: Expr, expected_expr: &str) { let df_schema = DFSchema::try_from(batch.schema()).unwrap(); // Simplify the expression first - let props = ExecutionProps::new(); - let simplify_context = - SimplifyContext::new(&props).with_schema(df_schema.clone().into()); + let simplify_context = SimplifyContext::default().with_schema(Arc::new(df_schema)); let simplifier = ExprSimplifier::new(simplify_context).with_max_cycles(10); let simplified = simplifier.simplify(expr).unwrap(); create_expr_test(simplified, expected_expr); diff --git a/datafusion/core/tests/expr_api/parse_sql_expr.rs b/datafusion/core/tests/expr_api/parse_sql_expr.rs index 92c18204324f7..b0d8b3a349ae2 100644 --- a/datafusion/core/tests/expr_api/parse_sql_expr.rs +++ b/datafusion/core/tests/expr_api/parse_sql_expr.rs @@ -19,9 +19,9 @@ use arrow::datatypes::{DataType, Field, Schema}; use datafusion::prelude::{CsvReadOptions, SessionContext}; use datafusion_common::DFSchema; use datafusion_common::{DFSchemaRef, Result, ToDFSchema}; +use datafusion_expr::Expr; use datafusion_expr::col; use datafusion_expr::lit; -use datafusion_expr::Expr; use datafusion_sql::unparser::Unparser; /// A schema like: /// diff --git a/datafusion/core/tests/expr_api/simplification.rs b/datafusion/core/tests/expr_api/simplification.rs index 34e0487f312fb..02f2503faf22a 100644 --- 
a/datafusion/core/tests/expr_api/simplification.rs +++ b/datafusion/core/tests/expr_api/simplification.rs @@ -17,20 +17,22 @@ //! This program demonstrates the DataFusion expression simplification API. +use insta::assert_snapshot; + use arrow::array::types::IntervalDayTime; use arrow::array::{ArrayRef, Int32Array}; use arrow::datatypes::{DataType, Field, Schema}; use chrono::{DateTime, TimeZone, Utc}; -use datafusion::{error::Result, execution::context::ExecutionProps, prelude::*}; -use datafusion_common::cast::as_int32_array; +use datafusion::{error::Result, prelude::*}; use datafusion_common::ScalarValue; +use datafusion_common::cast::as_int32_array; use datafusion_common::{DFSchemaRef, ToDFSchema}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::logical_plan::builder::table_scan_with_filters; -use datafusion_expr::simplify::SimplifyInfo; +use datafusion_expr::simplify::SimplifyContext; use datafusion_expr::{ - table_scan, Cast, ColumnarValue, ExprSchemable, LogicalPlan, LogicalPlanBuilder, - ScalarUDF, Volatility, + Cast, ColumnarValue, ExprSchemable, LogicalPlan, LogicalPlanBuilder, Projection, + ScalarUDF, Volatility, table_scan, }; use datafusion_functions::math; use datafusion_optimizer::optimizer::Optimizer; @@ -38,50 +40,6 @@ use datafusion_optimizer::simplify_expressions::{ExprSimplifier, SimplifyExpress use datafusion_optimizer::{OptimizerContext, OptimizerRule}; use std::sync::Arc; -/// In order to simplify expressions, DataFusion must have information -/// about the expressions. 
-/// -/// You can provide that information using DataFusion [DFSchema] -/// objects or from some other implementation -struct MyInfo { - /// The input schema - schema: DFSchemaRef, - - /// Execution specific details needed for constant evaluation such - /// as the current time for `now()` and [VariableProviders] - execution_props: ExecutionProps, -} - -impl SimplifyInfo for MyInfo { - fn is_boolean_type(&self, expr: &Expr) -> Result { - Ok(matches!( - expr.get_type(self.schema.as_ref())?, - DataType::Boolean - )) - } - - fn nullable(&self, expr: &Expr) -> Result { - expr.nullable(self.schema.as_ref()) - } - - fn execution_props(&self) -> &ExecutionProps { - &self.execution_props - } - - fn get_data_type(&self, expr: &Expr) -> Result { - expr.get_type(self.schema.as_ref()) - } -} - -impl From for MyInfo { - fn from(schema: DFSchemaRef) -> Self { - Self { - schema, - execution_props: ExecutionProps::new(), - } - } -} - /// A schema like: /// /// a: Int32 (possibly with nulls) @@ -130,14 +88,10 @@ fn test_evaluate_with_start_time( expected_expr: Expr, date_time: &DateTime, ) { - let execution_props = - ExecutionProps::new().with_query_execution_start_time(*date_time); - - let info: MyInfo = MyInfo { - schema: schema(), - execution_props, - }; - let simplifier = ExprSimplifier::new(info); + let context = SimplifyContext::default() + .with_schema(schema()) + .with_query_execution_start_time(Some(*date_time)); + let simplifier = ExprSimplifier::new(context); let simplified_expr = simplifier .simplify(input_expr.clone()) .expect("successfully evaluated"); @@ -199,7 +153,9 @@ fn to_timestamp_expr(arg: impl Into) -> Expr { #[test] fn basic() { - let info: MyInfo = schema().into(); + let context = SimplifyContext::default() + .with_schema(schema()) + .with_query_execution_start_time(Some(Utc::now())); // The `Expr` is a core concept in DataFusion, and DataFusion can // help simplify it. 
@@ -208,21 +164,21 @@ fn basic() { // optimize form `a < 5` automatically let expr = col("a").lt(lit(2i32) + lit(3i32)); - let simplifier = ExprSimplifier::new(info); + let simplifier = ExprSimplifier::new(context); let simplified = simplifier.simplify(expr).unwrap(); assert_eq!(simplified, col("a").lt(lit(5i32))); } #[test] fn fold_and_simplify() { - let info: MyInfo = schema().into(); + let context = SimplifyContext::default().with_schema(schema()); // What will it do with the expression `concat('foo', 'bar') == 'foobar')`? let expr = concat(vec![lit("foo"), lit("bar")]).eq(lit("foobar")); // Since datafusion applies both simplification *and* rewriting // some expressions can be entirely simplified - let simplifier = ExprSimplifier::new(info); + let simplifier = ExprSimplifier::new(context); let simplified = simplifier.simplify(expr).unwrap(); assert_eq!(simplified, lit(true)) } @@ -237,11 +193,15 @@ fn to_timestamp_expr_folded() -> Result<()> { .project(proj)? .build()?; - let expected = "Projection: TimestampNanosecond(1599566400000000000, None) AS to_timestamp(Utf8(\"2020-09-08T12:00:00+00:00\"))\ - \n TableScan: test" - .to_string(); - let actual = get_optimized_plan_formatted(plan, &Utc::now()); - assert_eq!(expected, actual); + let formatted = get_optimized_plan_formatted(plan, &Utc::now()); + let actual = formatted.trim(); + assert_snapshot!( + actual, + @r#" + Projection: TimestampNanosecond(1599566400000000000, None) AS to_timestamp(Utf8("2020-09-08T12:00:00+00:00")) + TableScan: test + "# + ); Ok(()) } @@ -262,11 +222,16 @@ fn now_less_than_timestamp() -> Result<()> { // Note that constant folder runs and folds the entire // expression down to a single constant (true) - let expected = "Filter: Boolean(true)\ - \n TableScan: test"; - let actual = get_optimized_plan_formatted(plan, &time); - - assert_eq!(expected, actual); + let formatted = get_optimized_plan_formatted(plan, &time); + let actual = formatted.trim(); + + assert_snapshot!( + actual, + @r" + 
Filter: Boolean(true) + TableScan: test + " + ); Ok(()) } @@ -282,10 +247,13 @@ fn select_date_plus_interval() -> Result<()> { let date_plus_interval_expr = to_timestamp_expr(ts_string) .cast_to(&DataType::Date32, schema)? - + Expr::Literal(ScalarValue::IntervalDayTime(Some(IntervalDayTime { - days: 123, - milliseconds: 0, - }))); + + Expr::Literal( + ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days: 123, + milliseconds: 0, + })), + None, + ); let plan = LogicalPlanBuilder::from(table_scan.clone()) .project(vec![date_plus_interval_expr])? @@ -293,11 +261,16 @@ fn select_date_plus_interval() -> Result<()> { // Note that constant folder runs and folds the entire // expression down to a single constant (true) - let expected = r#"Projection: Date32("2021-01-09") AS to_timestamp(Utf8("2020-09-08T12:05:00+00:00")) + IntervalDayTime("IntervalDayTime { days: 123, milliseconds: 0 }") - TableScan: test"#; - let actual = get_optimized_plan_formatted(plan, &time); - - assert_eq!(expected, actual); + let formatted = get_optimized_plan_formatted(plan, &time); + let actual = formatted.trim(); + + assert_snapshot!( + actual, + @r#" + Projection: Date32("2021-01-09") AS to_timestamp(Utf8("2020-09-08T12:05:00+00:00")) + IntervalDayTime("IntervalDayTime { days: 123, milliseconds: 0 }") + TableScan: test + "# + ); Ok(()) } @@ -311,10 +284,15 @@ fn simplify_project_scalar_fn() -> Result<()> { // before simplify: power(t.f, 1.0) // after simplify: t.f as "power(t.f, 1.0)" - let expected = "Projection: test.f AS power(test.f,Float64(1))\ - \n TableScan: test"; - let actual = get_optimized_plan_formatted(plan, &Utc::now()); - assert_eq!(expected, actual); + let formatter = get_optimized_plan_formatted(plan, &Utc::now()); + let actual = formatter.trim(); + assert_snapshot!( + actual, + @r" + Projection: test.f AS power(test.f,Float64(1)) + TableScan: test + " + ); Ok(()) } @@ -334,9 +312,9 @@ fn simplify_scan_predicate() -> Result<()> { // before simplify: t.g = power(t.f, 1.0) 
// after simplify: t.g = t.f" - let expected = "TableScan: test, full_filters=[g = f]"; - let actual = get_optimized_plan_formatted(plan, &Utc::now()); - assert_eq!(expected, actual); + let formatted = get_optimized_plan_formatted(plan, &Utc::now()); + let actual = formatted.trim(); + assert_snapshot!(actual, @"TableScan: test, full_filters=[g = f]"); Ok(()) } @@ -490,8 +468,7 @@ fn multiple_now() -> Result<()> { // expect the same timestamp appears in both exprs let actual = get_optimized_plan_formatted(plan, &time); let expected = format!( - "Projection: TimestampNanosecond({}, Some(\"+00:00\")) AS now(), TimestampNanosecond({}, Some(\"+00:00\")) AS t2\ - \n TableScan: test", + "Projection: TimestampNanosecond({}, None) AS now(), TimestampNanosecond({}, None) AS t2\n TableScan: test", time.timestamp_nanos_opt().unwrap(), time.timestamp_nanos_opt().unwrap() ); @@ -500,6 +477,72 @@ fn multiple_now() -> Result<()> { Ok(()) } +/// Unwraps an alias expression to get the inner expression +fn unrwap_aliases(expr: &Expr) -> &Expr { + match expr { + Expr::Alias(alias) => unrwap_aliases(&alias.expr), + expr => expr, + } +} + +/// Test that `now()` is simplified to a literal when execution start time is set, +/// but remains as an expression when no execution start time is available. +#[test] +fn now_simplification_with_and_without_start_time() { + let plan = LogicalPlanBuilder::empty(false) + .project(vec![now()]) + .unwrap() + .build() + .unwrap(); + + // Case 1: With execution start time set, now() should be simplified to a literal + { + let time = DateTime::::from_timestamp_nanos(123); + let ctx: OptimizerContext = + OptimizerContext::new().with_query_execution_start_time(time); + let optimizer = SimplifyExpressions {}; + let simplified = optimizer + .rewrite(plan.clone(), &ctx) + .expect("rewrite should succeed") + .data; + let LogicalPlan::Projection(Projection { expr, .. 
}) = simplified else { + panic!("Expected Projection plan"); + }; + assert_eq!(expr.len(), 1); + let simplified = unrwap_aliases(expr.first().unwrap()); + // Should be a literal timestamp + match simplified { + Expr::Literal(ScalarValue::TimestampNanosecond(Some(ts), _), _) => { + assert_eq!(*ts, time.timestamp_nanos_opt().unwrap()); + } + other => panic!("Expected timestamp literal, got: {other:?}"), + } + } + + // Case 2: Without execution start time, now() should remain as a function call + { + let ctx: OptimizerContext = + OptimizerContext::new().without_query_execution_start_time(); + let optimizer = SimplifyExpressions {}; + let simplified = optimizer + .rewrite(plan, &ctx) + .expect("rewrite should succeed") + .data; + let LogicalPlan::Projection(Projection { expr, .. }) = simplified else { + panic!("Expected Projection plan"); + }; + assert_eq!(expr.len(), 1); + let simplified = unrwap_aliases(expr.first().unwrap()); + // Should still be a now() function call + match simplified { + Expr::ScalarFunction(ScalarFunction { func, .. 
}) => { + assert_eq!(func.name(), "now"); + } + other => panic!("Expected now() function call, got: {other:?}"), + } + } +} + // ------------------------------ // --- Simplifier tests ----- // ------------------------------ @@ -522,11 +565,8 @@ fn expr_test_schema() -> DFSchemaRef { } fn test_simplify(input_expr: Expr, expected_expr: Expr) { - let info: MyInfo = MyInfo { - schema: expr_test_schema(), - execution_props: ExecutionProps::new(), - }; - let simplifier = ExprSimplifier::new(info); + let context = SimplifyContext::default().with_schema(expr_test_schema()); + let simplifier = ExprSimplifier::new(context); let simplified_expr = simplifier .simplify(input_expr.clone()) .expect("successfully evaluated"); @@ -541,11 +581,10 @@ fn test_simplify_with_cycle_count( expected_expr: Expr, expected_count: u32, ) { - let info: MyInfo = MyInfo { - schema: expr_test_schema(), - execution_props: ExecutionProps::new(), - }; - let simplifier = ExprSimplifier::new(info); + let context = SimplifyContext::default() + .with_schema(expr_test_schema()) + .with_query_execution_start_time(Some(Utc::now())); + let simplifier = ExprSimplifier::new(context); let (simplified_expr, count) = simplifier .simplify_with_cycle_count_transformed(input_expr.clone()) .expect("successfully evaluated"); diff --git a/datafusion/core/tests/fifo/mod.rs b/datafusion/core/tests/fifo/mod.rs index 141a3f3b75586..3d99cc72fa590 100644 --- a/datafusion/core/tests/fifo/mod.rs +++ b/datafusion/core/tests/fifo/mod.rs @@ -22,21 +22,21 @@ mod unix_test { use std::fs::File; use std::path::PathBuf; - use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; + use std::sync::atomic::{AtomicBool, Ordering}; use std::time::Duration; use arrow::array::Array; use arrow::csv::ReaderBuilder; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; - use datafusion::datasource::stream::{FileStreamProvider, StreamConfig, StreamTable}; use datafusion::datasource::TableProvider; + use 
datafusion::datasource::stream::{FileStreamProvider, StreamConfig, StreamTable}; use datafusion::{ prelude::{CsvReadOptions, SessionConfig, SessionContext}, test_util::{aggr_test_schema, arrow_test_data}, }; use datafusion_common::instant::Instant; - use datafusion_common::{exec_err, Result}; + use datafusion_common::{Result, exec_err}; use datafusion_expr::SortExpr; use futures::StreamExt; @@ -44,7 +44,7 @@ mod unix_test { use nix::unistd; use tempfile::TempDir; use tokio::io::AsyncWriteExt; - use tokio::task::{spawn_blocking, JoinHandle}; + use tokio::task::{JoinHandle, spawn_blocking}; /// Makes a TableProvider for a fifo file fn fifo_table( @@ -94,7 +94,6 @@ mod unix_test { /// This function creates a writing task for the FIFO file. To verify /// incremental processing, it waits for a signal to continue writing after /// a certain number of lines are written. - #[allow(clippy::disallowed_methods)] fn create_writing_task( file_path: PathBuf, header: String, @@ -105,6 +104,7 @@ mod unix_test { // Timeout for a long period of BrokenPipe error let broken_pipe_timeout = Duration::from_secs(10); // Spawn a new task to write to the FIFO file + #[expect(clippy::disallowed_methods)] tokio::spawn(async move { let mut file = tokio::fs::OpenOptions::new() .write(true) @@ -357,7 +357,7 @@ mod unix_test { (sink_fifo_path.clone(), sink_fifo_path.display()); // Spawn a new thread to read sink EXTERNAL TABLE. - #[allow(clippy::disallowed_methods)] // spawn allowed only in tests + #[expect(clippy::disallowed_methods)] // spawn allowed only in tests tasks.push(spawn_blocking(move || { let file = File::open(sink_fifo_path_thread).unwrap(); let schema = Arc::new(Schema::new(vec![ diff --git a/datafusion/core/tests/fuzz.rs b/datafusion/core/tests/fuzz.rs index 92646e8b37636..5e94f12b5805d 100644 --- a/datafusion/core/tests/fuzz.rs +++ b/datafusion/core/tests/fuzz.rs @@ -15,7 +15,10 @@ // specific language governing permissions and limitations // under the License. 
-/// Run all tests that are found in the `fuzz_cases` directory +/// Run all tests that are found in the `fuzz_cases` directory. +/// Fuzz tests are slow and gated behind the `extended_tests` feature. +/// Run with: cargo test --features extended_tests +#[cfg(feature = "extended_tests")] mod fuzz_cases; #[cfg(test)] diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index 7e5ad011b5dd8..d64223abdb767 100644 --- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -17,47 +17,44 @@ use std::sync::Arc; +use super::record_batch_generator::get_supported_types_columns; use crate::fuzz_cases::aggregation_fuzzer::query_builder::QueryBuilder; use crate::fuzz_cases::aggregation_fuzzer::{ AggregationFuzzerBuilder, DatasetGeneratorConfig, }; use arrow::array::{ - types::Int64Type, Array, ArrayRef, AsArray, Int32Array, Int64Array, RecordBatch, - StringArray, + Array, ArrayRef, AsArray, Int32Array, Int64Array, RecordBatch, StringArray, + types::Int64Type, }; -use arrow::compute::{concat_batches, SortOptions}; +use arrow::compute::concat_batches; use arrow::datatypes::DataType; use arrow::util::pretty::pretty_format_batches; use arrow_schema::{Field, Schema, SchemaRef}; -use datafusion::common::Result; +use datafusion::datasource::MemTable; use datafusion::datasource::memory::MemorySourceConfig; use datafusion::datasource::source::DataSourceExec; -use datafusion::datasource::MemTable; -use datafusion::physical_expr::aggregate::AggregateExprBuilder; -use datafusion::physical_plan::aggregates::{ - AggregateExec, AggregateMode, PhysicalGroupBy, -}; -use datafusion::physical_plan::{collect, displayable, ExecutionPlan}; use datafusion::prelude::{DataFrame, SessionConfig, SessionContext}; use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor}; -use datafusion_common::HashMap; +use datafusion_common::{HashMap, Result}; use 
datafusion_common_runtime::JoinSet; use datafusion_functions_aggregate::sum::sum_udaf; -use datafusion_physical_expr::expressions::{col, lit, Column}; use datafusion_physical_expr::PhysicalSortExpr; -use datafusion_physical_expr_common::sort_expr::LexOrdering; +use datafusion_physical_expr::expressions::{Column, col, lit}; use datafusion_physical_plan::InputOrderMode; -use test_utils::{add_empty_batches, StringBatchGenerator}; +use test_utils::{StringBatchGenerator, add_empty_batches}; +use datafusion_execution::TaskContext; use datafusion_execution::memory_pool::FairSpillPool; use datafusion_execution::runtime_env::RuntimeEnvBuilder; -use datafusion_execution::TaskContext; +use datafusion_physical_expr::aggregate::AggregateExprBuilder; +use datafusion_physical_plan::aggregates::{ + AggregateExec, AggregateMode, PhysicalGroupBy, +}; use datafusion_physical_plan::metrics::MetricValue; +use datafusion_physical_plan::{ExecutionPlan, collect, displayable}; use rand::rngs::StdRng; -use rand::{random, rng, Rng, SeedableRng}; - -use super::record_batch_generator::get_supported_types_columns; +use rand::{Rng, SeedableRng, random, rng}; // ======================================================================== // The new aggregation fuzz tests based on [`AggregationFuzzer`] @@ -254,6 +251,12 @@ fn baseline_config() -> DatasetGeneratorConfig { // low cardinality to try and get many repeated runs vec![String::from("u8_low")], vec![String::from("utf8_low"), String::from("u8_low")], + vec![String::from("dictionary_utf8_low")], + vec![ + String::from("dictionary_utf8_low"), + String::from("utf8_low"), + String::from("u8_low"), + ], ], } } @@ -303,13 +306,9 @@ async fn run_aggregate_test(input1: Vec, group_by_columns: Vec<&str let schema = input1[0].schema(); let session_config = SessionConfig::new().with_batch_size(50); let ctx = SessionContext::new_with_config(session_config); - let mut sort_keys = LexOrdering::default(); - for ordering_col in ["a", "b", "c"] { - 
sort_keys.push(PhysicalSortExpr { - expr: col(ordering_col, &schema).unwrap(), - options: SortOptions::default(), - }) - } + let sort_keys = ["a", "b", "c"].map(|ordering_col| { + PhysicalSortExpr::new_default(col(ordering_col, &schema).unwrap()) + }); let concat_input_record = concat_batches(&schema, &input1).unwrap(); @@ -321,24 +320,23 @@ async fn run_aggregate_test(input1: Vec, group_by_columns: Vec<&str .unwrap(); let running_source = DataSourceExec::from_data_source( - MemorySourceConfig::try_new(&[input1.clone()], schema.clone(), None) + MemorySourceConfig::try_new(std::slice::from_ref(&input1), schema.clone(), None) .unwrap() - .try_with_sort_information(vec![sort_keys]) + .try_with_sort_information(vec![sort_keys.into()]) .unwrap(), ); - let aggregate_expr = - vec![ - AggregateExprBuilder::new(sum_udaf(), vec![col("d", &schema).unwrap()]) - .schema(Arc::clone(&schema)) - .alias("sum1") - .build() - .map(Arc::new) - .unwrap(), - ]; + let aggregate_expr = vec![ + AggregateExprBuilder::new(sum_udaf(), vec![col("d", &schema).unwrap()]) + .schema(Arc::clone(&schema)) + .alias("sum1") + .build() + .map(Arc::new) + .unwrap(), + ]; let expr = group_by_columns .iter() - .map(|elem| (col(elem, &schema).unwrap(), elem.to_string())) + .map(|elem| (col(elem, &schema).unwrap(), (*elem).to_string())) .collect::>(); let group_by = PhysicalGroupBy::new_single(expr); @@ -404,7 +402,7 @@ async fn run_aggregate_test(input1: Vec, group_by_columns: Vec<&str Left Plan:\n{}\n\ Right Plan:\n{}\n\ schema:\n{schema}\n\ - Left Ouptut:\n{}\n\ + Left Output:\n{}\n\ Right Output:\n{}\n\ input:\n{}\n\ ", @@ -556,7 +554,7 @@ async fn verify_ordered_aggregate(frame: &DataFrame, expected_sort: bool) { InputOrderMode::PartiallySorted(_) | InputOrderMode::Sorted )); } else { - assert!(matches!(exec.input_order_mode(), InputOrderMode::Linear)); + assert_eq!(*exec.input_order_mode(), InputOrderMode::Linear); } } Ok(TreeNodeRecursion::Continue) @@ -633,8 +631,11 @@ fn 
extract_result_counts(results: Vec) -> HashMap, i output } -fn assert_spill_count_metric(expect_spill: bool, single_aggregate: Arc) { - if let Some(metrics_set) = single_aggregate.metrics() { +pub(crate) fn assert_spill_count_metric( + expect_spill: bool, + plan_that_spills: Arc, +) -> usize { + if let Some(metrics_set) = plan_that_spills.metrics() { let mut spill_count = 0; // Inspect metrics for SpillCount @@ -648,8 +649,12 @@ fn assert_spill_count_metric(expect_spill: bool, single_aggregate: Arc 0 { - panic!("Expected no spill but found SpillCount metric with value greater than 0."); + panic!( + "Expected no spill but found SpillCount metric with value greater than 0." + ); } + + spill_count } else { panic!("No metrics returned from the operator; cannot verify spilling."); } @@ -657,7 +662,7 @@ fn assert_spill_count_metric(expect_spill: bool, single_aggregate: Arc Result<()> { +async fn test_single_mode_aggregate_single_mode_aggregate_with_spill() -> Result<()> { let scan_schema = Arc::new(Schema::new(vec![ Field::new("col_0", DataType::Int64, true), Field::new("col_1", DataType::Utf8, true), diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/context_generator.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/context_generator.rs index 2abfcd8417cbc..fe31098622c58 100644 --- a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/context_generator.rs +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/context_generator.rs @@ -25,7 +25,7 @@ use datafusion_catalog::TableProvider; use datafusion_common::ScalarValue; use datafusion_common::{error::Result, utils::get_available_parallelism}; use datafusion_expr::col; -use rand::{rng, Rng}; +use rand::{Rng, rng}; use crate::fuzz_cases::aggregation_fuzzer::data_generator::Dataset; @@ -44,7 +44,6 @@ use crate::fuzz_cases::aggregation_fuzzer::data_generator::Dataset; /// - hint `sorted` or not /// - `spilling` or not (TODO, I think a special `MemoryPool` may be needed /// to support this) -/// pub 
struct SessionContextGenerator { /// Current testing dataset dataset: Arc, @@ -215,7 +214,7 @@ impl GeneratedSessionContextBuilder { /// The generated params for [`SessionContext`] #[derive(Debug)] -#[allow(dead_code)] +#[expect(dead_code)] pub struct SessionContextParams { batch_size: usize, target_partitions: usize, diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs index 82bfe199234ef..e49cffa89b04e 100644 --- a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs @@ -18,7 +18,7 @@ use arrow::array::RecordBatch; use arrow::datatypes::DataType; use datafusion_common::Result; -use datafusion_physical_expr::{expressions::col, PhysicalSortExpr}; +use datafusion_physical_expr::{PhysicalSortExpr, expressions::col}; use datafusion_physical_expr_common::sort_expr::LexOrdering; use datafusion_physical_plan::sorts::sort::sort_batch; use test_utils::stagger_batch; @@ -39,7 +39,6 @@ use crate::fuzz_cases::record_batch_generator::{ColumnDescr, RecordBatchGenerato /// will generate one `base dataset` firstly. Then the `base dataset` will be sorted /// based on each `sort_key` respectively. And finally `len(sort_keys) + 1` datasets /// will be returned -/// #[derive(Debug, Clone)] pub struct DatasetGeneratorConfig { /// Descriptions of columns in datasets, it's `required` @@ -115,7 +114,6 @@ impl DatasetGeneratorConfig { /// /// - Split each batch to multiple batches which each sub-batch in has the randomly `rows num`, /// and this multiple batches will be used to create the `Dataset`. 
-/// pub struct DatasetGenerator { batch_generator: RecordBatchGenerator, sort_keys_set: Vec>, @@ -149,14 +147,14 @@ impl DatasetGenerator { for sort_keys in self.sort_keys_set.clone() { let sort_exprs = sort_keys .iter() - .map(|key| { - let col_expr = col(key, schema)?; - Ok(PhysicalSortExpr::new_default(col_expr)) - }) - .collect::>()?; - let sorted_batch = sort_batch(&base_batch, sort_exprs.as_ref(), None)?; - - let batches = stagger_batch(sorted_batch); + .map(|key| col(key, schema).map(PhysicalSortExpr::new_default)) + .collect::>>()?; + let batch = if let Some(ordering) = LexOrdering::new(sort_exprs) { + sort_batch(&base_batch, &ordering, None)? + } else { + base_batch.clone() + }; + let batches = stagger_batch(batch); let dataset = Dataset::new(batches, sort_keys); datasets.push(dataset); } @@ -211,8 +209,8 @@ mod test { sort_keys_set: vec![vec!["b".to_string()]], }; - let mut gen = DatasetGenerator::new(config); - let datasets = gen.generate().unwrap(); + let mut data_gen = DatasetGenerator::new(config); + let datasets = data_gen.generate().unwrap(); // Should Generate 2 datasets assert_eq!(datasets.len(), 2); diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs index cfb3c1c6a1b98..430762b1c28db 100644 --- a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs @@ -19,9 +19,9 @@ use std::sync::Arc; use arrow::array::RecordBatch; use arrow::util::pretty::pretty_format_batches; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{Result, internal_datafusion_err}; use datafusion_common_runtime::JoinSet; -use rand::{rng, Rng}; +use rand::{Rng, rng}; use crate::fuzz_cases::aggregation_fuzzer::query_builder::QueryBuilder; use crate::fuzz_cases::aggregation_fuzzer::{ @@ -171,7 +171,7 @@ impl AggregationFuzzer { let datasets = self .dataset_generator .generate() - .expect("should 
success to generate dataset"); + .expect("should succeed to generate dataset"); // Then for each of them, we random select a test sql for it let query_groups = datasets @@ -197,7 +197,7 @@ impl AggregationFuzzer { while let Some(join_handle) = join_set.join_next().await { // propagate errors join_handle.map_err(|e| { - DataFusionError::Internal(format!("AggregationFuzzer task error: {e:?}")) + internal_datafusion_err!("AggregationFuzzer task error: {e:?}") })??; } Ok(()) @@ -216,16 +216,16 @@ impl AggregationFuzzer { // Generate the baseline context, and get the baseline result firstly let baseline_ctx_with_params = ctx_generator .generate_baseline() - .expect("should success to generate baseline session context"); + .expect("should succeed to generate baseline session context"); let baseline_result = run_sql(&sql, &baseline_ctx_with_params.ctx) .await - .expect("should success to run baseline sql"); + .expect("should succeed to run baseline sql"); let baseline_result = Arc::new(baseline_result); // Generate test tasks for _ in 0..CTX_GEN_ROUNDS { let ctx_with_params = ctx_generator .generate() - .expect("should success to generate session context"); + .expect("should succeed to generate session context"); let task = AggregationFuzzTestTask { dataset_ref: dataset_ref.clone(), expected_result: baseline_result.clone(), @@ -253,7 +253,6 @@ impl AggregationFuzzer { /// /// - `dataset_ref`, the input dataset, store it for error reported when found /// the inconsistency between the one for `ctx` and `expected results`. 
-/// struct AggregationFuzzTestTask { /// Generated session context in current test case ctx_with_params: SessionContextWithParams, @@ -308,7 +307,7 @@ impl AggregationFuzzTestTask { format_batches_with_limit(expected_result), format_batches_with_limit(&self.dataset_ref.batches), ); - DataFusionError::Internal(message) + internal_datafusion_err!("{message}") }) } diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/mod.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/mod.rs index 04b764e46a96b..e7ce557d2267d 100644 --- a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/mod.rs +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/mod.rs @@ -77,8 +77,8 @@ pub(crate) fn check_equality_of_batches( if lhs_row != rhs_row { return Err(InconsistentResult { row_idx, - lhs_row: lhs_row.to_string(), - rhs_row: rhs_row.to_string(), + lhs_row: (*lhs_row).to_string(), + rhs_row: (*rhs_row).to_string(), }); } } diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/query_builder.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/query_builder.rs index 209278385b7b5..7bb6177c31010 100644 --- a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/query_builder.rs +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/query_builder.rs @@ -17,14 +17,14 @@ use std::{collections::HashSet, str::FromStr}; -use rand::{rng, seq::SliceRandom, Rng}; +use rand::{Rng, rng, seq::SliceRandom}; /// Random aggregate query builder /// /// Creates queries like /// ```sql /// SELECT AGG(..) FROM table_name GROUP BY -///``` +/// ``` #[derive(Debug, Default, Clone)] pub struct QueryBuilder { // =================================== @@ -95,7 +95,6 @@ pub struct QueryBuilder { /// More details can see [`GroupOrdering`]. 
/// /// [`GroupOrdering`]: datafusion_physical_plan::aggregates::order::GroupOrdering - /// dataset_sort_keys: Vec>, /// If we will also test the no grouping case like: @@ -103,7 +102,6 @@ pub struct QueryBuilder { /// ```text /// SELECT aggr FROM t; /// ``` - /// no_grouping: bool, // ==================================== @@ -184,13 +182,13 @@ impl QueryBuilder { /// Add max columns num in group by(default: 3), for example if it is set to 1, /// the generated sql will group by at most 1 column - #[allow(dead_code)] + #[expect(dead_code)] pub fn with_max_group_by_columns(mut self, max_group_by_columns: usize) -> Self { self.max_group_by_columns = max_group_by_columns; self } - #[allow(dead_code)] + #[expect(dead_code)] pub fn with_min_group_by_columns(mut self, min_group_by_columns: usize) -> Self { self.min_group_by_columns = min_group_by_columns; self @@ -204,7 +202,7 @@ impl QueryBuilder { } /// Add if also test the no grouping aggregation case(default: true) - #[allow(dead_code)] + #[expect(dead_code)] pub fn with_no_grouping(mut self, no_grouping: bool) -> Self { self.no_grouping = no_grouping; self diff --git a/datafusion/core/tests/fuzz_cases/distinct_count_string_fuzz.rs b/datafusion/core/tests/fuzz_cases/distinct_count_string_fuzz.rs index 3049631d4b3fe..92adda200d1a5 100644 --- a/datafusion/core/tests/fuzz_cases/distinct_count_string_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/distinct_count_string_fuzz.rs @@ -19,7 +19,7 @@ use std::sync::Arc; -use arrow::array::{cast::AsArray, Array, OffsetSizeTrait, RecordBatch}; +use arrow::array::{Array, OffsetSizeTrait, RecordBatch, cast::AsArray}; use datafusion::datasource::MemTable; use datafusion_common_runtime::JoinSet; diff --git a/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs b/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs index d12d0a130c0c0..a57095066ee12 100644 --- a/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs +++ 
b/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs @@ -16,15 +16,19 @@ // under the License. use crate::fuzz_cases::equivalence::utils::{ - convert_to_orderings, create_random_schema, create_test_params, create_test_schema_2, + TestScalarUDF, create_random_schema, create_test_params, create_test_schema_2, generate_table_for_eq_properties, generate_table_for_orderings, - is_table_same_after_sort, TestScalarUDF, + is_table_same_after_sort, }; use arrow::compute::SortOptions; use datafusion_common::Result; +use datafusion_common::config::ConfigOptions; use datafusion_expr::{Operator, ScalarUDF}; -use datafusion_physical_expr::expressions::{col, BinaryExpr}; use datafusion_physical_expr::ScalarFunctionExpr; +use datafusion_physical_expr::equivalence::{ + convert_to_orderings, convert_to_sort_exprs, +}; +use datafusion_physical_expr::expressions::{BinaryExpr, col}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; use itertools::Itertools; @@ -55,26 +59,25 @@ fn test_ordering_satisfy_with_equivalence_random() -> Result<()> { col("f", &test_schema)?, ]; - for n_req in 0..=col_exprs.len() { + for n_req in 1..=col_exprs.len() { for exprs in col_exprs.iter().combinations(n_req) { - let requirement = exprs + let sort_exprs = exprs .into_iter() - .map(|expr| PhysicalSortExpr { - expr: Arc::clone(expr), - options: SORT_OPTIONS, - }) - .collect::(); + .map(|expr| PhysicalSortExpr::new(Arc::clone(expr), SORT_OPTIONS)); + let Some(ordering) = LexOrdering::new(sort_exprs) else { + unreachable!("Test should always produce non-degenerate orderings"); + }; let expected = is_table_same_after_sort( - requirement.clone(), - table_data_with_properties.clone(), + ordering.clone(), + &table_data_with_properties, )?; let err_msg = format!( - "Error in test case requirement:{requirement:?}, expected: {expected:?}, eq_properties {eq_properties}" + "Error in test case 
requirement:{ordering:?}, expected: {expected:?}, eq_properties {eq_properties}" ); // Check whether ordering_satisfy API result and // experimental result matches. assert_eq!( - eq_properties.ordering_satisfy(requirement.as_ref()), + eq_properties.ordering_satisfy(ordering)?, expected, "{err_msg}" ); @@ -108,6 +111,7 @@ fn test_ordering_satisfy_with_equivalence_complex_random() -> Result<()> { Arc::clone(&test_fun), vec![col_a], &test_schema, + Arc::new(ConfigOptions::default()), )?); let a_plus_b = Arc::new(BinaryExpr::new( col("a", &test_schema)?, @@ -125,27 +129,26 @@ fn test_ordering_satisfy_with_equivalence_complex_random() -> Result<()> { a_plus_b, ]; - for n_req in 0..=exprs.len() { + for n_req in 1..=exprs.len() { for exprs in exprs.iter().combinations(n_req) { - let requirement = exprs + let sort_exprs = exprs .into_iter() - .map(|expr| PhysicalSortExpr { - expr: Arc::clone(expr), - options: SORT_OPTIONS, - }) - .collect::(); + .map(|expr| PhysicalSortExpr::new(Arc::clone(expr), SORT_OPTIONS)); + let Some(ordering) = LexOrdering::new(sort_exprs) else { + unreachable!("Test should always produce non-degenerate orderings"); + }; let expected = is_table_same_after_sort( - requirement.clone(), - table_data_with_properties.clone(), + ordering.clone(), + &table_data_with_properties, )?; let err_msg = format!( - "Error in test case requirement:{requirement:?}, expected: {expected:?}, eq_properties: {eq_properties}", + "Error in test case requirement:{ordering:?}, expected: {expected:?}, eq_properties: {eq_properties}", ); // Check whether ordering_satisfy API result and // experimental result matches. 
assert_eq!( - eq_properties.ordering_satisfy(requirement.as_ref()), + eq_properties.ordering_satisfy(ordering)?, (expected | false), "{err_msg}" ); @@ -300,25 +303,19 @@ fn test_ordering_satisfy_with_equivalence() -> Result<()> { ]; for (cols, expected) in requirements { - let err_msg = format!("Error in test case:{cols:?}"); - let required = cols - .into_iter() - .map(|(expr, options)| PhysicalSortExpr { - expr: Arc::clone(expr), - options, - }) - .collect::(); + let err_msg = format!("Error in test case: {cols:?}"); + let sort_exprs = convert_to_sort_exprs(&cols); + let Some(ordering) = LexOrdering::new(sort_exprs) else { + unreachable!("Test should always produce non-degenerate orderings"); + }; // Check expected result with experimental result. assert_eq!( - is_table_same_after_sort( - required.clone(), - table_data_with_properties.clone() - )?, + is_table_same_after_sort(ordering.clone(), &table_data_with_properties)?, expected ); assert_eq!( - eq_properties.ordering_satisfy(required.as_ref()), + eq_properties.ordering_satisfy(ordering)?, expected, "{err_msg}" ); @@ -371,7 +368,7 @@ fn test_ordering_satisfy_on_data() -> Result<()> { (col_d, option_asc), ]; let ordering = convert_to_orderings(&[ordering])[0].clone(); - assert!(!is_table_same_after_sort(ordering, batch.clone())?); + assert!(!is_table_same_after_sort(ordering, &batch)?); // [a ASC, b ASC, d ASC] cannot be deduced let ordering = vec![ @@ -380,12 +377,12 @@ fn test_ordering_satisfy_on_data() -> Result<()> { (col_d, option_asc), ]; let ordering = convert_to_orderings(&[ordering])[0].clone(); - assert!(!is_table_same_after_sort(ordering, batch.clone())?); + assert!(!is_table_same_after_sort(ordering, &batch)?); // [a ASC, b ASC] can be deduced let ordering = vec![(col_a, option_asc), (col_b, option_asc)]; let ordering = convert_to_orderings(&[ordering])[0].clone(); - assert!(is_table_same_after_sort(ordering, batch.clone())?); + assert!(is_table_same_after_sort(ordering, &batch)?); Ok(()) } diff 
--git a/datafusion/core/tests/fuzz_cases/equivalence/projection.rs b/datafusion/core/tests/fuzz_cases/equivalence/projection.rs index 38e66387a02cd..2f67e211ce915 100644 --- a/datafusion/core/tests/fuzz_cases/equivalence/projection.rs +++ b/datafusion/core/tests/fuzz_cases/equivalence/projection.rs @@ -16,14 +16,15 @@ // under the License. use crate::fuzz_cases::equivalence::utils::{ - apply_projection, create_random_schema, generate_table_for_eq_properties, - is_table_same_after_sort, TestScalarUDF, + TestScalarUDF, apply_projection, create_random_schema, + generate_table_for_eq_properties, is_table_same_after_sort, }; use arrow::compute::SortOptions; use datafusion_common::Result; +use datafusion_common::config::ConfigOptions; use datafusion_expr::{Operator, ScalarUDF}; use datafusion_physical_expr::equivalence::ProjectionMapping; -use datafusion_physical_expr::expressions::{col, BinaryExpr}; +use datafusion_physical_expr::expressions::{BinaryExpr, col}; use datafusion_physical_expr::{PhysicalExprRef, ScalarFunctionExpr}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; @@ -49,6 +50,7 @@ fn project_orderings_random() -> Result<()> { Arc::clone(&test_fun), vec![col_a], &test_schema, + Arc::new(ConfigOptions::default()), )?); // a + b let a_plus_b = Arc::new(BinaryExpr::new( @@ -56,7 +58,7 @@ fn project_orderings_random() -> Result<()> { Operator::Plus, col("b", &test_schema)?, )) as Arc; - let proj_exprs = vec![ + let proj_exprs = [ (col("a", &test_schema)?, "a_new"), (col("b", &test_schema)?, "b_new"), (col("c", &test_schema)?, "c_new"), @@ -71,7 +73,7 @@ fn project_orderings_random() -> Result<()> { for proj_exprs in proj_exprs.iter().combinations(n_req) { let proj_exprs = proj_exprs .into_iter() - .map(|(expr, name)| (Arc::clone(expr), name.to_string())) + .map(|(expr, name)| (Arc::clone(expr), (*name).to_string())) .collect::>(); let (projected_batch, 
projected_eq) = apply_projection( proj_exprs.clone(), @@ -87,10 +89,7 @@ fn project_orderings_random() -> Result<()> { // Since ordered section satisfies schema, we expect // that result will be same after sort (e.g sort was unnecessary). assert!( - is_table_same_after_sort( - ordering.clone(), - projected_batch.clone(), - )?, + is_table_same_after_sort(ordering.clone(), &projected_batch)?, "{}", err_msg ); @@ -125,6 +124,7 @@ fn ordering_satisfy_after_projection_random() -> Result<()> { Arc::clone(&test_fun), vec![col_a], &test_schema, + Arc::new(ConfigOptions::default()), )?) as PhysicalExprRef; // a + b let a_plus_b = Arc::new(BinaryExpr::new( @@ -132,7 +132,7 @@ fn ordering_satisfy_after_projection_random() -> Result<()> { Operator::Plus, col("b", &test_schema)?, )) as Arc; - let proj_exprs = vec![ + let proj_exprs = [ (col("a", &test_schema)?, "a_new"), (col("b", &test_schema)?, "b_new"), (col("c", &test_schema)?, "c_new"), @@ -147,8 +147,7 @@ fn ordering_satisfy_after_projection_random() -> Result<()> { for proj_exprs in proj_exprs.iter().combinations(n_req) { let proj_exprs = proj_exprs .into_iter() - .map(|(expr, name)| (Arc::clone(expr), name.to_string())) - .collect::>(); + .map(|(expr, name)| (Arc::clone(expr), (*name).to_string())); let (projected_batch, projected_eq) = apply_projection( proj_exprs.clone(), &table_data_with_properties, @@ -156,33 +155,34 @@ fn ordering_satisfy_after_projection_random() -> Result<()> { )?; let projection_mapping = - ProjectionMapping::try_new(&proj_exprs, &test_schema)?; + ProjectionMapping::try_new(proj_exprs, &test_schema)?; let projected_exprs = projection_mapping .iter() - .map(|(_source, target)| Arc::clone(target)) + .flat_map(|(_, targets)| { + targets.iter().map(|(target, _)| Arc::clone(target)) + }) .collect::>(); - for n_req in 0..=projected_exprs.len() { + for n_req in 1..=projected_exprs.len() { for exprs in projected_exprs.iter().combinations(n_req) { - let requirement = exprs - .into_iter() - .map(|expr| 
PhysicalSortExpr { - expr: Arc::clone(expr), - options: SORT_OPTIONS, - }) - .collect::(); - let expected = is_table_same_after_sort( - requirement.clone(), - projected_batch.clone(), - )?; + let sort_exprs = exprs.into_iter().map(|expr| { + PhysicalSortExpr::new(Arc::clone(expr), SORT_OPTIONS) + }); + let Some(ordering) = LexOrdering::new(sort_exprs) else { + unreachable!( + "Test should always produce non-degenerate orderings" + ); + }; + let expected = + is_table_same_after_sort(ordering.clone(), &projected_batch)?; let err_msg = format!( - "Error in test case requirement:{requirement:?}, expected: {expected:?}, eq_properties: {eq_properties}, projected_eq: {projected_eq}, projection_mapping: {projection_mapping:?}" + "Error in test case requirement:{ordering:?}, expected: {expected:?}, eq_properties: {eq_properties}, projected_eq: {projected_eq}, projection_mapping: {projection_mapping:?}" ); // Check whether ordering_satisfy API result and // experimental result matches. assert_eq!( - projected_eq.ordering_satisfy(requirement.as_ref()), + projected_eq.ordering_satisfy(ordering)?, expected, "{err_msg}" ); diff --git a/datafusion/core/tests/fuzz_cases/equivalence/properties.rs b/datafusion/core/tests/fuzz_cases/equivalence/properties.rs index 9a21464157495..1490eb08a0291 100644 --- a/datafusion/core/tests/fuzz_cases/equivalence/properties.rs +++ b/datafusion/core/tests/fuzz_cases/equivalence/properties.rs @@ -15,18 +15,21 @@ // specific language governing permissions and limitations // under the License. 
+use std::sync::Arc; + use crate::fuzz_cases::equivalence::utils::{ - create_random_schema, generate_table_for_eq_properties, is_table_same_after_sort, - TestScalarUDF, + TestScalarUDF, create_random_schema, generate_table_for_eq_properties, + is_table_same_after_sort, }; + use datafusion_common::Result; use datafusion_expr::{Operator, ScalarUDF}; -use datafusion_physical_expr::expressions::{col, BinaryExpr}; -use datafusion_physical_expr::{PhysicalExprRef, ScalarFunctionExpr}; -use datafusion_physical_expr_common::physical_expr::PhysicalExpr; -use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; +use datafusion_physical_expr::expressions::{BinaryExpr, col}; +use datafusion_physical_expr::{LexOrdering, ScalarFunctionExpr}; +use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; + +use datafusion_common::config::ConfigOptions; use itertools::Itertools; -use std::sync::Arc; #[test] fn test_find_longest_permutation_random() -> Result<()> { @@ -47,13 +50,14 @@ fn test_find_longest_permutation_random() -> Result<()> { Arc::clone(&test_fun), vec![col_a], &test_schema, - )?) as PhysicalExprRef; + Arc::new(ConfigOptions::default()), + )?) 
as _; let a_plus_b = Arc::new(BinaryExpr::new( col("a", &test_schema)?, Operator::Plus, col("b", &test_schema)?, - )) as Arc; + )) as _; let exprs = [ col("a", &test_schema)?, col("b", &test_schema)?, @@ -68,16 +72,16 @@ fn test_find_longest_permutation_random() -> Result<()> { for n_req in 0..=exprs.len() { for exprs in exprs.iter().combinations(n_req) { let exprs = exprs.into_iter().cloned().collect::>(); - let (ordering, indices) = eq_properties.find_longest_permutation(&exprs); + let (ordering, indices) = + eq_properties.find_longest_permutation(&exprs)?; // Make sure that find_longest_permutation return values are consistent let ordering2 = indices .iter() .zip(ordering.iter()) - .map(|(&idx, sort_expr)| PhysicalSortExpr { - expr: Arc::clone(&exprs[idx]), - options: sort_expr.options, + .map(|(&idx, sort_expr)| { + PhysicalSortExpr::new(Arc::clone(&exprs[idx]), sort_expr.options) }) - .collect::(); + .collect::>(); assert_eq!( ordering, ordering2, "indices and lexicographical ordering do not match" @@ -89,11 +93,11 @@ fn test_find_longest_permutation_random() -> Result<()> { assert_eq!(ordering.len(), indices.len(), "{err_msg}"); // Since ordered section satisfies schema, we expect // that result will be same after sort (e.g sort was unnecessary). + let Some(ordering) = LexOrdering::new(ordering) else { + continue; + }; assert!( - is_table_same_after_sort( - ordering.clone(), - table_data_with_properties.clone(), - )?, + is_table_same_after_sort(ordering, &table_data_with_properties)?, "{}", err_msg ); diff --git a/datafusion/core/tests/fuzz_cases/equivalence/utils.rs b/datafusion/core/tests/fuzz_cases/equivalence/utils.rs index a906648f872dc..580a226721083 100644 --- a/datafusion/core/tests/fuzz_cases/equivalence/utils.rs +++ b/datafusion/core/tests/fuzz_cases/equivalence/utils.rs @@ -15,55 +15,50 @@ // specific language governing permissions and limitations // under the License. 
-use datafusion::physical_plan::expressions::col; -use datafusion::physical_plan::expressions::Column; -use datafusion_physical_expr::{ConstExpr, EquivalenceProperties, PhysicalSortExpr}; use std::any::Any; use std::cmp::Ordering; use std::sync::Arc; use arrow::array::{ArrayRef, Float32Array, Float64Array, RecordBatch, UInt32Array}; -use arrow::compute::SortOptions; -use arrow::compute::{lexsort_to_indices, take_record_batch, SortColumn}; +use arrow::compute::{SortColumn, SortOptions, lexsort_to_indices, take_record_batch}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::utils::{compare_rows, get_row_at_idx}; -use datafusion_common::{exec_err, plan_datafusion_err, DataFusionError, Result}; +use datafusion_common::{Result, exec_err, internal_datafusion_err, plan_err}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, }; -use datafusion_physical_expr::equivalence::{EquivalenceClass, ProjectionMapping}; +use datafusion_physical_expr::equivalence::{ + EquivalenceClass, ProjectionMapping, convert_to_orderings, +}; +use datafusion_physical_expr::{ConstExpr, EquivalenceProperties}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; -use datafusion_physical_expr_common::sort_expr::LexOrdering; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; +use datafusion_physical_plan::expressions::{Column, col}; use itertools::izip; use rand::prelude::*; +/// Projects the input schema based on the given projection mapping. pub fn output_schema( mapping: &ProjectionMapping, input_schema: &Arc, ) -> Result { - // Calculate output schema - let fields: Result> = mapping - .iter() - .map(|(source, target)| { - let name = target - .as_any() - .downcast_ref::() - .ok_or_else(|| plan_datafusion_err!("Expects to have column"))? 
- .name(); - let field = Field::new( - name, - source.data_type(input_schema)?, - source.nullable(input_schema)?, - ); - - Ok(field) - }) - .collect(); + // Calculate output schema: + let mut fields = vec![]; + for (source, targets) in mapping.iter() { + let data_type = source.data_type(input_schema)?; + let nullable = source.nullable(input_schema)?; + for (target, _) in targets.iter() { + let Some(column) = target.as_any().downcast_ref::() else { + return plan_err!("Expects to have column"); + }; + fields.push(Field::new(column.name(), data_type.clone(), nullable)); + } + } let output_schema = Arc::new(Schema::new_with_metadata( - fields?, + fields, input_schema.metadata().clone(), )); @@ -100,9 +95,9 @@ pub fn create_random_schema(seed: u64) -> Result<(SchemaRef, EquivalenceProperti let mut eq_properties = EquivalenceProperties::new(Arc::clone(&test_schema)); // Define a and f are aliases - eq_properties.add_equal_conditions(col_a, col_f)?; + eq_properties.add_equal_conditions(Arc::clone(col_a), Arc::clone(col_f))?; // Column e has constant value. 
- eq_properties = eq_properties.with_constants([ConstExpr::from(col_e)]); + eq_properties.add_constants([ConstExpr::from(Arc::clone(col_e))])?; // Randomly order columns for sorting let mut rng = StdRng::seed_from_u64(seed); @@ -114,18 +109,18 @@ pub fn create_random_schema(seed: u64) -> Result<(SchemaRef, EquivalenceProperti }; while !remaining_exprs.is_empty() { - let n_sort_expr = rng.random_range(0..remaining_exprs.len() + 1); + let n_sort_expr = rng.random_range(1..remaining_exprs.len() + 1); remaining_exprs.shuffle(&mut rng); - let ordering = remaining_exprs - .drain(0..n_sort_expr) - .map(|expr| PhysicalSortExpr { - expr: Arc::clone(expr), - options: options_asc, - }) - .collect(); + let ordering = + remaining_exprs + .drain(0..n_sort_expr) + .map(|expr| PhysicalSortExpr { + expr: Arc::clone(expr), + options: options_asc, + }); - eq_properties.add_new_orderings([ordering]); + eq_properties.add_ordering(ordering); } Ok((test_schema, eq_properties)) @@ -133,12 +128,12 @@ pub fn create_random_schema(seed: u64) -> Result<(SchemaRef, EquivalenceProperti // Apply projection to the input_data, return projected equivalence properties and record batch pub fn apply_projection( - proj_exprs: Vec<(Arc, String)>, + proj_exprs: impl IntoIterator, String)>, input_data: &RecordBatch, input_eq_properties: &EquivalenceProperties, ) -> Result<(RecordBatch, EquivalenceProperties)> { let input_schema = input_data.schema(); - let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &input_schema)?; + let projection_mapping = ProjectionMapping::try_new(proj_exprs, &input_schema)?; let output_schema = output_schema(&projection_mapping, &input_schema)?; let num_rows = input_data.num_rows(); @@ -168,49 +163,49 @@ fn add_equal_conditions_test() -> Result<()> { ])); let mut eq_properties = EquivalenceProperties::new(schema); - let col_a_expr = Arc::new(Column::new("a", 0)) as Arc; - let col_b_expr = Arc::new(Column::new("b", 1)) as Arc; - let col_c_expr = 
Arc::new(Column::new("c", 2)) as Arc; - let col_x_expr = Arc::new(Column::new("x", 3)) as Arc; - let col_y_expr = Arc::new(Column::new("y", 4)) as Arc; + let col_a = Arc::new(Column::new("a", 0)) as _; + let col_b = Arc::new(Column::new("b", 1)) as _; + let col_c = Arc::new(Column::new("c", 2)) as _; + let col_x = Arc::new(Column::new("x", 3)) as _; + let col_y = Arc::new(Column::new("y", 4)) as _; // a and b are aliases - eq_properties.add_equal_conditions(&col_a_expr, &col_b_expr)?; + eq_properties.add_equal_conditions(Arc::clone(&col_a), Arc::clone(&col_b))?; assert_eq!(eq_properties.eq_group().len(), 1); // This new entry is redundant, size shouldn't increase - eq_properties.add_equal_conditions(&col_b_expr, &col_a_expr)?; + eq_properties.add_equal_conditions(Arc::clone(&col_b), Arc::clone(&col_a))?; assert_eq!(eq_properties.eq_group().len(), 1); let eq_groups = eq_properties.eq_group().iter().next().unwrap(); assert_eq!(eq_groups.len(), 2); - assert!(eq_groups.contains(&col_a_expr)); - assert!(eq_groups.contains(&col_b_expr)); + assert!(eq_groups.contains(&col_a)); + assert!(eq_groups.contains(&col_b)); // b and c are aliases. Existing equivalence class should expand, // however there shouldn't be any new equivalence class - eq_properties.add_equal_conditions(&col_b_expr, &col_c_expr)?; + eq_properties.add_equal_conditions(Arc::clone(&col_b), Arc::clone(&col_c))?; assert_eq!(eq_properties.eq_group().len(), 1); let eq_groups = eq_properties.eq_group().iter().next().unwrap(); assert_eq!(eq_groups.len(), 3); - assert!(eq_groups.contains(&col_a_expr)); - assert!(eq_groups.contains(&col_b_expr)); - assert!(eq_groups.contains(&col_c_expr)); + assert!(eq_groups.contains(&col_a)); + assert!(eq_groups.contains(&col_b)); + assert!(eq_groups.contains(&col_c)); // This is a new set of equality. Hence equivalent class count should be 2. 
- eq_properties.add_equal_conditions(&col_x_expr, &col_y_expr)?; + eq_properties.add_equal_conditions(Arc::clone(&col_x), Arc::clone(&col_y))?; assert_eq!(eq_properties.eq_group().len(), 2); // This equality bridges distinct equality sets. // Hence equivalent class count should decrease from 2 to 1. - eq_properties.add_equal_conditions(&col_x_expr, &col_a_expr)?; + eq_properties.add_equal_conditions(Arc::clone(&col_x), Arc::clone(&col_a))?; assert_eq!(eq_properties.eq_group().len(), 1); let eq_groups = eq_properties.eq_group().iter().next().unwrap(); assert_eq!(eq_groups.len(), 5); - assert!(eq_groups.contains(&col_a_expr)); - assert!(eq_groups.contains(&col_b_expr)); - assert!(eq_groups.contains(&col_c_expr)); - assert!(eq_groups.contains(&col_x_expr)); - assert!(eq_groups.contains(&col_y_expr)); + assert!(eq_groups.contains(&col_a)); + assert!(eq_groups.contains(&col_b)); + assert!(eq_groups.contains(&col_c)); + assert!(eq_groups.contains(&col_x)); + assert!(eq_groups.contains(&col_y)); Ok(()) } @@ -226,7 +221,7 @@ fn add_equal_conditions_test() -> Result<()> { /// already sorted according to `required_ordering` to begin with. 
pub fn is_table_same_after_sort( mut required_ordering: LexOrdering, - batch: RecordBatch, + batch: &RecordBatch, ) -> Result { // Clone the original schema and columns let original_schema = batch.schema(); @@ -327,7 +322,7 @@ pub fn create_test_params() -> Result<(SchemaRef, EquivalenceProperties)> { let col_f = &col("f", &test_schema)?; let col_g = &col("g", &test_schema)?; let mut eq_properties = EquivalenceProperties::new(Arc::clone(&test_schema)); - eq_properties.add_equal_conditions(col_a, col_c)?; + eq_properties.add_equal_conditions(Arc::clone(col_a), Arc::clone(col_c))?; let option_asc = SortOptions { descending: false, @@ -350,7 +345,7 @@ pub fn create_test_params() -> Result<(SchemaRef, EquivalenceProperties)> { ], ]; let orderings = convert_to_orderings(&orderings); - eq_properties.add_new_orderings(orderings); + eq_properties.add_orderings(orderings); Ok((test_schema, eq_properties)) } @@ -376,7 +371,7 @@ pub fn generate_table_for_eq_properties( // Fill constant columns for constant in eq_properties.constants() { - let col = constant.expr().as_any().downcast_ref::().unwrap(); + let col = constant.expr.as_any().downcast_ref::().unwrap(); let (idx, _field) = schema.column_with_name(col.name()).unwrap(); let arr = Arc::new(Float64Array::from_iter_values(vec![0 as f64; n_elem])) as ArrayRef; @@ -461,7 +456,7 @@ pub fn generate_table_for_orderings( let batch = RecordBatch::try_from_iter(arrays)?; // Sort batch according to first ordering expression - let sort_columns = get_sort_columns(&batch, orderings[0].as_ref())?; + let sort_columns = get_sort_columns(&batch, &orderings[0])?; let sort_indices = lexsort_to_indices(&sort_columns, None)?; let mut batch = take_record_batch(&batch, &sort_indices)?; @@ -494,29 +489,6 @@ pub fn generate_table_for_orderings( Ok(batch) } -// Convert each tuple to PhysicalSortExpr -pub fn convert_to_sort_exprs( - in_data: &[(&Arc, SortOptions)], -) -> LexOrdering { - in_data - .iter() - .map(|(expr, options)| PhysicalSortExpr { - 
expr: Arc::clone(*expr), - options: *options, - }) - .collect() -} - -// Convert each inner tuple to PhysicalSortExpr -pub fn convert_to_orderings( - orderings: &[Vec<(&Arc, SortOptions)>], -) -> Vec { - orderings - .iter() - .map(|sort_exprs| convert_to_sort_exprs(sort_exprs)) - .collect() -} - // Utility function to generate random f64 array fn generate_random_f64_array( n_elems: usize, @@ -540,7 +512,7 @@ fn get_sort_columns( .collect::>>() } -#[derive(Debug, Clone)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct TestScalarUDF { pub(crate) signature: Signature, } @@ -590,11 +562,11 @@ impl ScalarUDFImpl for TestScalarUDF { DataType::Float64 => Arc::new({ let arg = &args[0].as_any().downcast_ref::().ok_or_else( || { - DataFusionError::Internal(format!( + internal_datafusion_err!( "could not cast {} to {}", self.name(), std::any::type_name::() - )) + ) }, )?; @@ -605,11 +577,11 @@ impl ScalarUDFImpl for TestScalarUDF { DataType::Float32 => Arc::new({ let arg = &args[0].as_any().downcast_ref::().ok_or_else( || { - DataFusionError::Internal(format!( + internal_datafusion_err!( "could not cast {} to {}", self.name(), std::any::type_name::() - )) + ) }, )?; diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs b/datafusion/core/tests/fuzz_cases/join_fuzz.rs index 82ee73b525cb1..669b98e39fec1 100644 --- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs @@ -20,7 +20,7 @@ use std::time::SystemTime; use crate::fuzz_cases::join_fuzz::JoinTestType::{HjSmj, NljHj}; -use arrow::array::{ArrayRef, Int32Array}; +use arrow::array::{ArrayRef, BinaryArray, Int32Array}; use arrow::compute::SortOptions; use arrow::datatypes::Schema; use arrow::record_batch::RecordBatch; @@ -37,9 +37,9 @@ use datafusion::physical_plan::joins::{ HashJoinExec, NestedLoopJoinExec, PartitionMode, SortMergeJoinExec, }; use datafusion::prelude::{SessionConfig, SessionContext}; -use datafusion_common::ScalarValue; -use 
datafusion_physical_expr::expressions::Literal; +use datafusion_common::{NullEquality, ScalarValue}; use datafusion_physical_expr::PhysicalExprRef; +use datafusion_physical_expr::expressions::Literal; use itertools::Itertools; use rand::Rng; @@ -91,218 +91,564 @@ fn col_lt_col_filter(schema1: Arc, schema2: Arc) -> JoinFilter { #[tokio::test] async fn test_inner_join_1k_filtered() { - JoinFuzzTestCase::new( - make_staggered_batches(1000), - make_staggered_batches(1000), - JoinType::Inner, - Some(Box::new(col_lt_col_filter)), - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000, left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::Inner, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_inner_join_1k() { - JoinFuzzTestCase::new( - make_staggered_batches(1000), - make_staggered_batches(1000), - JoinType::Inner, - None, - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000, left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::Inner, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_left_join_1k() { - JoinFuzzTestCase::new( - make_staggered_batches(1000), - make_staggered_batches(1000), - JoinType::Left, - None, - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000, left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::Left, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_left_join_1k_filtered() { - JoinFuzzTestCase::new( - make_staggered_batches(1000), 
- make_staggered_batches(1000), - JoinType::Left, - Some(Box::new(col_lt_col_filter)), - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000, left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::Left, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_right_join_1k() { - JoinFuzzTestCase::new( - make_staggered_batches(1000), - make_staggered_batches(1000), - JoinType::Right, - None, - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000, left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::Right, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_right_join_1k_filtered() { - JoinFuzzTestCase::new( - make_staggered_batches(1000), - make_staggered_batches(1000), - JoinType::Right, - Some(Box::new(col_lt_col_filter)), - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000, left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::Right, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_full_join_1k() { - JoinFuzzTestCase::new( - make_staggered_batches(1000), - make_staggered_batches(1000), - JoinType::Full, - None, - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000, left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::Full, + None, + ) + .run_test(&[HjSmj, NljHj], false) + 
.await + } } #[tokio::test] async fn test_full_join_1k_filtered() { - JoinFuzzTestCase::new( - make_staggered_batches(1000), - make_staggered_batches(1000), - JoinType::Full, - Some(Box::new(col_lt_col_filter)), - ) - .run_test(&[NljHj, HjSmj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000, left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::Full, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[NljHj, HjSmj], false) + .await + } } #[tokio::test] async fn test_left_semi_join_1k() { - JoinFuzzTestCase::new( - make_staggered_batches(1000), - make_staggered_batches(1000), - JoinType::LeftSemi, - None, - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000, left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::LeftSemi, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_left_semi_join_1k_filtered() { - JoinFuzzTestCase::new( - make_staggered_batches(1000), - make_staggered_batches(1000), - JoinType::LeftSemi, - Some(Box::new(col_lt_col_filter)), - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000, left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::LeftSemi, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_right_semi_join_1k() { - JoinFuzzTestCase::new( - make_staggered_batches(1000), - make_staggered_batches(1000), - JoinType::RightSemi, - None, - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + 
make_staggered_batches_i32(1000, left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::RightSemi, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_right_semi_join_1k_filtered() { - JoinFuzzTestCase::new( - make_staggered_batches(1000), - make_staggered_batches(1000), - JoinType::RightSemi, - Some(Box::new(col_lt_col_filter)), - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000, left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::RightSemi, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_left_anti_join_1k() { - JoinFuzzTestCase::new( - make_staggered_batches(1000), - make_staggered_batches(1000), - JoinType::LeftAnti, - None, - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000, left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::LeftAnti, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_left_anti_join_1k_filtered() { - JoinFuzzTestCase::new( - make_staggered_batches(1000), - make_staggered_batches(1000), - JoinType::LeftAnti, - Some(Box::new(col_lt_col_filter)), - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000, left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::LeftAnti, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_right_anti_join_1k() { - JoinFuzzTestCase::new( - make_staggered_batches(1000), - make_staggered_batches(1000), - 
JoinType::RightAnti, - None, - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000, left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::RightAnti, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_right_anti_join_1k_filtered() { - JoinFuzzTestCase::new( - make_staggered_batches(1000), - make_staggered_batches(1000), - JoinType::RightAnti, - Some(Box::new(col_lt_col_filter)), - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000, left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::RightAnti, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_left_mark_join_1k() { - JoinFuzzTestCase::new( - make_staggered_batches(1000), - make_staggered_batches(1000), - JoinType::LeftMark, - None, - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000, left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::LeftMark, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_left_mark_join_1k_filtered() { - JoinFuzzTestCase::new( - make_staggered_batches(1000), - make_staggered_batches(1000), - JoinType::LeftMark, - Some(Box::new(col_lt_col_filter)), - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000, left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::LeftMark, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, 
NljHj], false) + .await + } +} + +// todo: add JoinTestType::HjSmj after Right mark SortMergeJoin support +#[tokio::test] +async fn test_right_mark_join_1k() { + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000, left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::RightMark, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await + } +} + +#[tokio::test] +async fn test_right_mark_join_1k_filtered() { + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000, left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::RightMark, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await + } +} + +#[tokio::test] +async fn test_inner_join_1k_binary_filtered() { + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, left_extra), + make_staggered_batches_binary(1000, right_extra), + JoinType::Inner, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await + } +} + +#[tokio::test] +async fn test_inner_join_1k_binary() { + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, left_extra), + make_staggered_batches_binary(1000, right_extra), + JoinType::Inner, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await + } +} + +#[tokio::test] +async fn test_left_join_1k_binary() { + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, left_extra), + make_staggered_batches_binary(1000, right_extra), + JoinType::Left, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await + } +} + +#[tokio::test] +async fn test_left_join_1k_binary_filtered() { + for 
(left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, left_extra), + make_staggered_batches_binary(1000, right_extra), + JoinType::Left, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await + } +} + +#[tokio::test] +async fn test_right_join_1k_binary() { + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, left_extra), + make_staggered_batches_binary(1000, right_extra), + JoinType::Right, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await + } +} + +#[tokio::test] +async fn test_right_join_1k_binary_filtered() { + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, left_extra), + make_staggered_batches_binary(1000, right_extra), + JoinType::Right, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await + } +} + +#[tokio::test] +async fn test_full_join_1k_binary() { + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, left_extra), + make_staggered_batches_binary(1000, right_extra), + JoinType::Full, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await + } +} + +#[tokio::test] +async fn test_full_join_1k_binary_filtered() { + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, left_extra), + make_staggered_batches_binary(1000, right_extra), + JoinType::Full, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[NljHj, HjSmj], false) + .await + } +} + +#[tokio::test] +async fn test_left_semi_join_1k_binary() { + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, 
left_extra), + make_staggered_batches_binary(1000, right_extra), + JoinType::LeftSemi, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await + } +} + +#[tokio::test] +async fn test_left_semi_join_1k_binary_filtered() { + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, left_extra), + make_staggered_batches_binary(1000, right_extra), + JoinType::LeftSemi, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await + } +} + +#[tokio::test] +async fn test_right_semi_join_1k_binary() { + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, left_extra), + make_staggered_batches_binary(1000, right_extra), + JoinType::RightSemi, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await + } +} + +#[tokio::test] +async fn test_right_semi_join_1k_binary_filtered() { + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, left_extra), + make_staggered_batches_binary(1000, right_extra), + JoinType::RightSemi, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await + } +} + +#[tokio::test] +async fn test_left_anti_join_1k_binary() { + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, left_extra), + make_staggered_batches_binary(1000, right_extra), + JoinType::LeftAnti, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await + } +} + +#[tokio::test] +async fn test_left_anti_join_1k_binary_filtered() { + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, left_extra), + make_staggered_batches_binary(1000, right_extra), + JoinType::LeftAnti, + 
Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await + } +} + +#[tokio::test] +async fn test_right_anti_join_1k_binary() { + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, left_extra), + make_staggered_batches_binary(1000, right_extra), + JoinType::RightAnti, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await + } +} + +#[tokio::test] +async fn test_right_anti_join_1k_binary_filtered() { + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, left_extra), + make_staggered_batches_binary(1000, right_extra), + JoinType::RightAnti, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await + } +} + +#[tokio::test] +async fn test_left_mark_join_1k_binary() { + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, left_extra), + make_staggered_batches_binary(1000, right_extra), + JoinType::LeftMark, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await + } +} + +#[tokio::test] +async fn test_left_mark_join_1k_binary_filtered() { + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, left_extra), + make_staggered_batches_binary(1000, right_extra), + JoinType::LeftMark, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await + } +} + +// todo: add JoinTestType::HjSmj after Right mark SortMergeJoin support +#[tokio::test] +async fn test_right_mark_join_1k_binary() { + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, left_extra), + make_staggered_batches_binary(1000, right_extra), + JoinType::RightMark, + None, + ) + 
.run_test(&[HjSmj, NljHj], false) + .await + } +} + +#[tokio::test] +async fn test_right_mark_join_1k_binary_filtered() { + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, left_extra), + make_staggered_batches_binary(1000, right_extra), + JoinType::RightMark, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } type JoinFilterBuilder = Box, Arc) -> JoinFilter>; @@ -452,12 +798,18 @@ impl JoinFuzzTestCase { fn left_right(&self) -> (Arc, Arc) { let schema1 = self.input1[0].schema(); let schema2 = self.input2[0].schema(); - let left = - MemorySourceConfig::try_new_exec(&[self.input1.clone()], schema1, None) - .unwrap(); - let right = - MemorySourceConfig::try_new_exec(&[self.input2.clone()], schema2, None) - .unwrap(); + let left = MemorySourceConfig::try_new_exec( + std::slice::from_ref(&self.input1), + schema1, + None, + ) + .unwrap(); + let right = MemorySourceConfig::try_new_exec( + std::slice::from_ref(&self.input2), + schema2, + None, + ) + .unwrap(); (left, right) } @@ -479,7 +831,7 @@ impl JoinFuzzTestCase { self.join_filter(), self.join_type, vec![SortOptions::default(); self.on_columns().len()], - false, + NullEquality::NullEqualsNothing, ) .unwrap(), ) @@ -496,6 +848,7 @@ impl JoinFuzzTestCase { &self.join_type, None, PartitionMode::Partitioned, + NullEquality::NullEqualsNothing, false, ) .unwrap(), @@ -569,7 +922,9 @@ impl JoinFuzzTestCase { std::fs::remove_dir_all(fuzz_debug).unwrap_or(()); std::fs::create_dir_all(fuzz_debug).unwrap(); let out_dir_name = &format!("{fuzz_debug}/batch_size_{batch_size}"); - println!("Test result data mismatch found. HJ rows {hj_rows}, SMJ rows {smj_rows}, NLJ rows {nlj_rows}"); + println!( + "Test result data mismatch found. HJ rows {hj_rows}, SMJ rows {smj_rows}, NLJ rows {nlj_rows}" + ); println!("The debug is ON. 
Input data will be saved to {out_dir_name}"); Self::save_partitioned_batches_as_parquet( @@ -588,7 +943,6 @@ impl JoinFuzzTestCase { hj_formatted_sorted.iter().for_each(|s| println!("{s}")); println!("=============== NestedLoopJoinExec =================="); nlj_formatted_sorted.iter().for_each(|s| println!("{s}")); - Self::save_partitioned_batches_as_parquet( &nlj_collected, out_dir_name, @@ -621,10 +975,18 @@ impl JoinFuzzTestCase { } if join_tests.contains(&NljHj) { - let err_msg_rowcnt = format!("NestedLoopJoinExec and HashJoinExec produced different row counts, batch_size: {batch_size}"); + let err_msg_rowcnt = format!( + "NestedLoopJoinExec and HashJoinExec produced different row counts, batch_size: {batch_size}" + ); assert_eq!(nlj_rows, hj_rows, "{}", err_msg_rowcnt.as_str()); + if nlj_rows == 0 && hj_rows == 0 { + // both joins returned no rows, skip content comparison + continue; + } - let err_msg_contents = format!("NestedLoopJoinExec and HashJoinExec produced different results, batch_size: {batch_size}"); + let err_msg_contents = format!( + "NestedLoopJoinExec and HashJoinExec produced different results, batch_size: {batch_size}" + ); // row level compare if any of joins returns the result // the reason is different formatting when there is no rows for (i, (nlj_line, hj_line)) in nlj_formatted_sorted @@ -642,10 +1004,16 @@ impl JoinFuzzTestCase { } if join_tests.contains(&HjSmj) { - let err_msg_row_cnt = format!("HashJoinExec and SortMergeJoinExec produced different row counts, batch_size: {}", &batch_size); + let err_msg_row_cnt = format!( + "HashJoinExec and SortMergeJoinExec produced different row counts, batch_size: {}", + &batch_size + ); assert_eq!(hj_rows, smj_rows, "{}", err_msg_row_cnt.as_str()); - let err_msg_contents = format!("SortMergeJoinExec and HashJoinExec produced different results, batch_size: {}", &batch_size); + let err_msg_contents = format!( + "SortMergeJoinExec and HashJoinExec produced different results, batch_size: {}", + 
&batch_size + ); // row level compare if any of joins returns the result // the reason is different formatting when there is no rows if smj_rows > 0 || hj_rows > 0 { @@ -719,7 +1087,7 @@ impl JoinFuzzTestCase { /// Files can be of different sizes /// The method can be useful to read partitions have been saved by `save_partitioned_batches_as_parquet` /// for test debugging purposes - #[allow(dead_code)] + #[expect(dead_code)] async fn load_partitioned_batches_from_parquet( dir: &str, ) -> std::io::Result> { @@ -760,7 +1128,7 @@ impl JoinFuzzTestCase { /// Return randomly sized record batches with: /// two sorted int32 columns 'a', 'b' ranged from 0..99 as join columns /// two random int32 columns 'x', 'y' as other columns -fn make_staggered_batches(len: usize) -> Vec { +fn make_staggered_batches_i32(len: usize, with_extra_column: bool) -> Vec { let mut rng = rand::rng(); let mut input12: Vec<(i32, i32)> = vec![(0, 0); len]; let mut input3: Vec = vec![0; len]; @@ -776,15 +1144,66 @@ fn make_staggered_batches(len: usize) -> Vec { let input3 = Int32Array::from_iter_values(input3); let input4 = Int32Array::from_iter_values(input4); - // split into several record batches - let batch = RecordBatch::try_from_iter(vec![ + let mut columns = vec![ ("a", Arc::new(input1) as ArrayRef), ("b", Arc::new(input2) as ArrayRef), ("x", Arc::new(input3) as ArrayRef), - ("y", Arc::new(input4) as ArrayRef), - ]) - .unwrap(); + ]; + + if with_extra_column { + columns.push(("y", Arc::new(input4) as ArrayRef)); + } + + // split into several record batches + let batch = RecordBatch::try_from_iter(columns).unwrap(); // use a random number generator to pick a random sized output stagger_batch_with_seed(batch, 42) } + +fn rand_bytes(rng: &mut R, min: usize, max: usize) -> Vec { + let n = rng.random_range(min..=max); + let mut v = vec![0u8; n]; + rng.fill(&mut v[..]); + v +} + +/// Return randomly sized record batches with: +/// two sorted binary columns 'a', 'b' (lexicographically) as join 
columns +/// two random binary columns 'x', 'y' as other columns +fn make_staggered_batches_binary( + len: usize, + with_extra_column: bool, +) -> Vec { + let mut rng = rand::rng(); + + // produce (a,b) pairs then sort lexicographically so SMJ has naturally sorted keys + let mut input12: Vec<(Vec, Vec)> = (0..len) + .map(|_| (rand_bytes(&mut rng, 4, 16), rand_bytes(&mut rng, 4, 16))) + .collect(); + input12.sort_unstable(); // lexicographic on Vec + + // payload cols (also binary so the existing x < x filter is well-typed) + let input3: Vec> = (0..len).map(|_| rand_bytes(&mut rng, 4, 24)).collect(); + let input4: Vec> = (0..len).map(|_| rand_bytes(&mut rng, 4, 24)).collect(); + + let a = BinaryArray::from_iter_values(input12.iter().map(|k| &k.0)); + let b = BinaryArray::from_iter_values(input12.iter().map(|k| &k.1)); + let x = BinaryArray::from_iter_values(input3.iter()); + let y = BinaryArray::from_iter_values(input4.iter()); + + let mut columns = vec![ + ("a", Arc::new(a) as ArrayRef), + ("b", Arc::new(b) as ArrayRef), + ("x", Arc::new(x) as ArrayRef), + ]; + + if with_extra_column { + columns.push(("y", Arc::new(y) as ArrayRef)); + } + + let batch = RecordBatch::try_from_iter(columns).unwrap(); + + // preserve your existing randomized partitioning + stagger_batch_with_seed(batch, 42) +} diff --git a/datafusion/core/tests/fuzz_cases/limit_fuzz.rs b/datafusion/core/tests/fuzz_cases/limit_fuzz.rs index 4c5ebf0402414..1c5741e7a21b3 100644 --- a/datafusion/core/tests/fuzz_cases/limit_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/limit_fuzz.rs @@ -24,7 +24,7 @@ use arrow::util::pretty::pretty_format_batches; use datafusion::datasource::MemTable; use datafusion::prelude::SessionContext; use datafusion_common::assert_contains; -use rand::{rng, Rng}; +use rand::{Rng, rng}; use std::sync::Arc; use test_utils::stagger_batch; diff --git a/datafusion/core/tests/fuzz_cases/merge_fuzz.rs b/datafusion/core/tests/fuzz_cases/merge_fuzz.rs index 92f3755250663..59430a98cc4b4 
100644 --- a/datafusion/core/tests/fuzz_cases/merge_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/merge_fuzz.rs @@ -27,11 +27,10 @@ use arrow::{ use datafusion::datasource::memory::MemorySourceConfig; use datafusion::physical_plan::{ collect, - expressions::{col, PhysicalSortExpr}, + expressions::{PhysicalSortExpr, col}, sorts::sort_preserving_merge::SortPreservingMergeExec, }; use datafusion::prelude::{SessionConfig, SessionContext}; -use datafusion_physical_expr_common::sort_expr::LexOrdering; use test_utils::{batches_to_vec, partitions_to_sorted_vec, stagger_batch_with_seed}; @@ -109,13 +108,14 @@ async fn run_merge_test(input: Vec>) { .expect("at least one batch"); let schema = first_batch.schema(); - let sort = LexOrdering::new(vec![PhysicalSortExpr { + let sort = [PhysicalSortExpr { expr: col("x", &schema).unwrap(), options: SortOptions { descending: false, nulls_first: true, }, - }]); + }] + .into(); let exec = MemorySourceConfig::try_new_exec(&input, schema, None).unwrap(); let merge = Arc::new(SortPreservingMergeExec::new(sort, exec)); diff --git a/datafusion/core/tests/fuzz_cases/mod.rs b/datafusion/core/tests/fuzz_cases/mod.rs index 8ccc2a5bc1310..edb53df382c62 100644 --- a/datafusion/core/tests/fuzz_cases/mod.rs +++ b/datafusion/core/tests/fuzz_cases/mod.rs @@ -15,21 +15,30 @@ // specific language governing permissions and limitations // under the License. 
+#[expect(clippy::needless_pass_by_value)] mod aggregate_fuzz; mod distinct_count_string_fuzz; +#[expect(clippy::needless_pass_by_value)] mod join_fuzz; mod merge_fuzz; +#[expect(clippy::needless_pass_by_value)] mod sort_fuzz; +#[expect(clippy::needless_pass_by_value)] mod sort_query_fuzz; +mod topk_filter_pushdown; mod aggregation_fuzzer; +#[expect(clippy::needless_pass_by_value)] mod equivalence; mod pruning; mod limit_fuzz; +#[expect(clippy::needless_pass_by_value)] mod sort_preserving_repartition_fuzz; mod window_fuzz; // Utility modules +mod once_exec; mod record_batch_generator; +mod spilling_fuzz_in_memory_constrained_env; diff --git a/datafusion/core/tests/fuzz_cases/once_exec.rs b/datafusion/core/tests/fuzz_cases/once_exec.rs new file mode 100644 index 0000000000000..eed172f09f994 --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/once_exec.rs @@ -0,0 +1,130 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow_schema::SchemaRef; +use datafusion_common::internal_datafusion_err; +use datafusion_common::tree_node::TreeNodeRecursion; +use datafusion_execution::{SendableRecordBatchStream, TaskContext}; +use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; +use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; +use datafusion_physical_plan::{ + DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, +}; +use std::any::Any; +use std::fmt::{Debug, Formatter}; +use std::sync::{Arc, Mutex}; + +/// Execution plan that return the stream on the call to `execute`. further calls to `execute` will +/// return an error +pub struct OnceExec { + /// the results to send back + stream: Mutex>, + cache: Arc, +} + +impl Debug for OnceExec { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "OnceExec") + } +} + +impl OnceExec { + pub fn new(stream: SendableRecordBatchStream) -> Self { + let cache = Self::compute_properties(stream.schema()); + Self { + stream: Mutex::new(Some(stream)), + cache: Arc::new(cache), + } + } + + /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. 
+ fn compute_properties(schema: SchemaRef) -> PlanProperties { + PlanProperties::new( + EquivalenceProperties::new(schema), + Partitioning::UnknownPartitioning(1), + EmissionType::Incremental, + Boundedness::Bounded, + ) + } +} + +impl DisplayAs for OnceExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "OnceExec:") + } + DisplayFormatType::TreeRender => { + write!(f, "") + } + } + } +} + +impl ExecutionPlan for OnceExec { + fn name(&self) -> &'static str { + Self::static_name() + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &Arc { + &self.cache + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + _: Vec>, + ) -> datafusion_common::Result> { + unimplemented!() + } + + /// Returns a stream which yields data + fn execute( + &self, + partition: usize, + _context: Arc, + ) -> datafusion_common::Result { + assert_eq!(partition, 0); + + let stream = self.stream.lock().unwrap().take(); + + stream.ok_or_else(|| internal_datafusion_err!("Stream already consumed")) + } + + fn apply_expressions( + &self, + f: &mut dyn FnMut( + &dyn datafusion_physical_plan::PhysicalExpr, + ) -> datafusion_common::Result, + ) -> datafusion_common::Result { + // Visit expressions in the output ordering from equivalence properties + let mut tnr = TreeNodeRecursion::Continue; + if let Some(ordering) = self.cache.output_ordering() { + for sort_expr in ordering { + tnr = tnr.visit_sibling(|| f(sort_expr.expr.as_ref()))?; + } + } + Ok(tnr) + } +} diff --git a/datafusion/core/tests/fuzz_cases/pruning.rs b/datafusion/core/tests/fuzz_cases/pruning.rs index 6e624d458bd93..8ce5207f91190 100644 --- a/datafusion/core/tests/fuzz_cases/pruning.rs +++ b/datafusion/core/tests/fuzz_cases/pruning.rs @@ -29,9 +29,11 @@ use datafusion_datasource::file_scan_config::FileScanConfigBuilder; use 
datafusion_datasource::source::DataSourceExec; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_physical_expr::PhysicalExpr; -use datafusion_physical_plan::{collect, filter::FilterExec, ExecutionPlan}; +use datafusion_physical_plan::{ExecutionPlan, collect, filter::FilterExec}; use itertools::Itertools; -use object_store::{memory::InMemory, path::Path, ObjectStore, PutPayload}; +use object_store::{ + ObjectStore, ObjectStoreExt, PutPayload, memory::InMemory, path::Path, +}; use parquet::{ arrow::ArrowWriter, file::properties::{EnabledStatistics, WriterProperties}, @@ -201,7 +203,7 @@ impl Utf8Test { } } - /// all combinations of interesting charactes with lengths ranging from 1 to 4 + /// all combinations of interesting characters with lengths ranging from 1 to 4 fn values() -> &'static [String] { &VALUES } @@ -276,13 +278,12 @@ async fn execute_with_predicate( ctx: &SessionContext, ) -> Vec { let parquet_source = if prune_stats { - ParquetSource::default().with_predicate(predicate.clone()) + ParquetSource::new(schema.clone()).with_predicate(predicate.clone()) } else { - ParquetSource::default() + ParquetSource::new(schema.clone()) }; let config = FileScanConfigBuilder::new( ObjectStoreUrl::parse("memory://").unwrap(), - schema.clone(), Arc::new(parquet_source), ) .with_file_group( @@ -319,14 +320,9 @@ async fn write_parquet_file( row_groups: Vec>, ) -> Bytes { let mut buf = BytesMut::new().writer(); - let mut props = WriterProperties::builder(); - if let Some(truncation_length) = truncation_length { - props = { - #[allow(deprecated)] - props.set_max_statistics_size(truncation_length) - } - } - props = props.set_statistics_enabled(EnabledStatistics::Chunk); // row group level + let props = WriterProperties::builder() + .set_statistics_enabled(EnabledStatistics::Chunk) // row group level + .set_statistics_truncate_length(truncation_length); let props = props.build(); { let mut writer = diff --git 
a/datafusion/core/tests/fuzz_cases/record_batch_generator.rs b/datafusion/core/tests/fuzz_cases/record_batch_generator.rs index 7b48eadf77e09..22b145f5095a7 100644 --- a/datafusion/core/tests/fuzz_cases/record_batch_generator.rs +++ b/datafusion/core/tests/fuzz_cases/record_batch_generator.rs @@ -17,23 +17,25 @@ use std::sync::Arc; -use arrow::array::{ArrayRef, RecordBatch}; +use arrow::array::{ArrayRef, DictionaryArray, PrimitiveArray, RecordBatch}; use arrow::datatypes::{ - BooleanType, DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, - DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType, - DurationSecondType, Field, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, - Int8Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, - IntervalYearMonthType, Schema, Time32MillisecondType, Time32SecondType, - Time64MicrosecondType, Time64NanosecondType, TimeUnit, TimestampMicrosecondType, - TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type, - UInt32Type, UInt64Type, UInt8Type, + ArrowPrimitiveType, BooleanType, DataType, Date32Type, Date64Type, Decimal32Type, + Decimal64Type, Decimal128Type, Decimal256Type, DurationMicrosecondType, + DurationMillisecondType, DurationNanosecondType, DurationSecondType, Field, + Float32Type, Float64Type, Int8Type, Int16Type, Int32Type, Int64Type, + IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType, + Schema, Time32MillisecondType, Time32SecondType, Time64MicrosecondType, + Time64NanosecondType, TimeUnit, TimestampMicrosecondType, TimestampMillisecondType, + TimestampNanosecondType, TimestampSecondType, UInt8Type, UInt16Type, UInt32Type, + UInt64Type, }; use arrow_schema::{ - DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, - DECIMAL256_MAX_SCALE, + DECIMAL32_MAX_PRECISION, DECIMAL32_MAX_SCALE, DECIMAL64_MAX_PRECISION, + DECIMAL64_MAX_SCALE, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, + 
DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, }; -use datafusion_common::{arrow_datafusion_err, DataFusionError, Result}; -use rand::{rng, rngs::StdRng, Rng, SeedableRng}; +use datafusion_common::{Result, arrow_datafusion_err}; +use rand::{Rng, SeedableRng, rng, rngs::StdRng}; use test_utils::array_gen::{ BinaryArrayGenerator, BooleanArrayGenerator, DecimalArrayGenerator, PrimitiveArrayGenerator, StringArrayGenerator, @@ -103,6 +105,20 @@ pub fn get_supported_types_columns(rng_seed: u64) -> Vec { "duration_nanosecond", DataType::Duration(TimeUnit::Nanosecond), ), + ColumnDescr::new("decimal32", { + let precision: u8 = rng.random_range(1..=DECIMAL32_MAX_PRECISION); + let scale: i8 = rng.random_range( + i8::MIN..=std::cmp::min(precision as i8, DECIMAL32_MAX_SCALE), + ); + DataType::Decimal32(precision, scale) + }), + ColumnDescr::new("decimal64", { + let precision: u8 = rng.random_range(1..=DECIMAL64_MAX_PRECISION); + let scale: i8 = rng.random_range( + i8::MIN..=std::cmp::min(precision as i8, DECIMAL64_MAX_SCALE), + ); + DataType::Decimal64(precision, scale) + }), ColumnDescr::new("decimal128", { let precision: u8 = rng.random_range(1..=DECIMAL128_MAX_PRECISION); let scale: i8 = rng.random_range( @@ -126,6 +142,11 @@ pub fn get_supported_types_columns(rng_seed: u64) -> Vec { ColumnDescr::new("binary", DataType::Binary), ColumnDescr::new("large_binary", DataType::LargeBinary), ColumnDescr::new("binaryview", DataType::BinaryView), + ColumnDescr::new( + "dictionary_utf8_low", + DataType::Dictionary(Box::new(DataType::UInt64), Box::new(DataType::Utf8)), + ) + .with_max_num_distinct(10), ] } @@ -185,17 +206,13 @@ pub struct RecordBatchGenerator { } macro_rules! 
generate_decimal_array { - ($SELF:ident, $NUM_ROWS:ident, $MAX_NUM_DISTINCT: expr, $BATCH_GEN_RNG:ident, $ARRAY_GEN_RNG:ident, $PRECISION: ident, $SCALE: ident, $ARROW_TYPE: ident) => {{ - let null_pct_idx = - $BATCH_GEN_RNG.random_range(0..$SELF.candidate_null_pcts.len()); - let null_pct = $SELF.candidate_null_pcts[null_pct_idx]; - + ($SELF:ident, $NUM_ROWS:ident, $MAX_NUM_DISTINCT: expr, $NULL_PCT:ident, $BATCH_GEN_RNG:ident, $ARRAY_GEN_RNG:ident, $PRECISION: ident, $SCALE: ident, $ARROW_TYPE: ident) => {{ let mut generator = DecimalArrayGenerator { precision: $PRECISION, scale: $SCALE, num_decimals: $NUM_ROWS, num_distinct_decimals: $MAX_NUM_DISTINCT, - null_pct, + null_pct: $NULL_PCT, rng: $ARRAY_GEN_RNG, }; @@ -205,18 +222,13 @@ macro_rules! generate_decimal_array { // Generating `BooleanArray` due to it being a special type in Arrow (bit-packed) macro_rules! generate_boolean_array { - ($SELF:ident, $NUM_ROWS:ident, $MAX_NUM_DISTINCT:expr, $BATCH_GEN_RNG:ident, $ARRAY_GEN_RNG:ident, $ARROW_TYPE: ident) => {{ - // Select a null percentage from the candidate percentages - let null_pct_idx = - $BATCH_GEN_RNG.random_range(0..$SELF.candidate_null_pcts.len()); - let null_pct = $SELF.candidate_null_pcts[null_pct_idx]; - + ($SELF:ident, $NUM_ROWS:ident, $MAX_NUM_DISTINCT:expr, $NULL_PCT:ident, $BATCH_GEN_RNG:ident, $ARRAY_GEN_RNG:ident, $ARROW_TYPE: ident) => {{ let num_distinct_booleans = if $MAX_NUM_DISTINCT >= 2 { 2 } else { 1 }; let mut generator = BooleanArrayGenerator { num_booleans: $NUM_ROWS, num_distinct_booleans, - null_pct, + null_pct: $NULL_PCT, rng: $ARRAY_GEN_RNG, }; @@ -225,15 +237,11 @@ macro_rules! generate_boolean_array { } macro_rules! 
generate_primitive_array { - ($SELF:ident, $NUM_ROWS:ident, $MAX_NUM_DISTINCT:expr, $BATCH_GEN_RNG:ident, $ARRAY_GEN_RNG:ident, $ARROW_TYPE:ident) => {{ - let null_pct_idx = - $BATCH_GEN_RNG.random_range(0..$SELF.candidate_null_pcts.len()); - let null_pct = $SELF.candidate_null_pcts[null_pct_idx]; - + ($SELF:ident, $NUM_ROWS:ident, $MAX_NUM_DISTINCT:expr, $NULL_PCT:ident, $BATCH_GEN_RNG:ident, $ARRAY_GEN_RNG:ident, $ARROW_TYPE:ident) => {{ let mut generator = PrimitiveArrayGenerator { num_primitives: $NUM_ROWS, num_distinct_primitives: $MAX_NUM_DISTINCT, - null_pct, + null_pct: $NULL_PCT, rng: $ARRAY_GEN_RNG, }; @@ -241,6 +249,28 @@ macro_rules! generate_primitive_array { }}; } +macro_rules! generate_dict { + ($SELF:ident, $NUM_ROWS:ident, $MAX_NUM_DISTINCT:expr, $NULL_PCT:ident, $BATCH_GEN_RNG:ident, $ARRAY_GEN_RNG:ident, $ARROW_TYPE:ident, $VALUES: ident) => {{ + debug_assert_eq!($VALUES.len(), $MAX_NUM_DISTINCT); + let keys: PrimitiveArray<$ARROW_TYPE> = (0..$NUM_ROWS) + .map(|_| { + if $BATCH_GEN_RNG.random::() < $NULL_PCT { + None + } else if $MAX_NUM_DISTINCT > 1 { + let range = 0..($MAX_NUM_DISTINCT + as <$ARROW_TYPE as ArrowPrimitiveType>::Native); + Some($ARRAY_GEN_RNG.random_range(range)) + } else { + Some(0) + } + }) + .collect(); + + let dict = DictionaryArray::new(keys, $VALUES); + Arc::new(dict) as ArrayRef + }}; +} + impl RecordBatchGenerator { /// Create a new `RecordBatchGenerator` with a random seed. The generated /// batches will be different each time. 
@@ -302,6 +332,25 @@ impl RecordBatchGenerator { num_rows: usize, batch_gen_rng: &mut StdRng, array_gen_rng: StdRng, + ) -> ArrayRef { + let null_pct_idx = batch_gen_rng.random_range(0..self.candidate_null_pcts.len()); + let null_pct = self.candidate_null_pcts[null_pct_idx]; + + Self::generate_array_of_type_inner( + col, + num_rows, + batch_gen_rng, + array_gen_rng, + null_pct, + ) + } + + fn generate_array_of_type_inner( + col: &ColumnDescr, + num_rows: usize, + batch_gen_rng: &mut StdRng, + array_gen_rng: StdRng, + null_pct: f64, ) -> ArrayRef { let num_distinct = if num_rows > 1 { batch_gen_rng.random_range(1..num_rows) @@ -320,6 +369,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, Int8Type @@ -330,6 +380,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, Int16Type @@ -340,6 +391,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, Int32Type @@ -350,6 +402,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, Int64Type @@ -360,6 +413,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, UInt8Type @@ -370,6 +424,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, UInt16Type @@ -380,6 +435,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, UInt32Type @@ -390,6 +446,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, UInt64Type @@ -400,6 +457,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, Float32Type @@ -410,6 +468,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, Float64Type @@ -420,6 +479,7 @@ 
impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, Date32Type @@ -430,6 +490,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, Date64Type @@ -440,6 +501,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, Time32SecondType @@ -450,6 +512,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, Time32MillisecondType @@ -460,6 +523,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, Time64MicrosecondType @@ -470,6 +534,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, Time64NanosecondType @@ -480,6 +545,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, IntervalYearMonthType @@ -490,6 +556,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, IntervalDayTimeType @@ -500,6 +567,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, IntervalMonthDayNanoType @@ -510,6 +578,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, DurationSecondType @@ -520,6 +589,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, DurationMillisecondType @@ -530,6 +600,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, DurationMicrosecondType @@ -540,6 +611,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, DurationNanosecondType @@ -550,6 +622,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, TimestampSecondType @@ 
-560,6 +633,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, TimestampMillisecondType @@ -570,6 +644,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, TimestampMicrosecondType @@ -580,15 +655,13 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, TimestampNanosecondType ) } DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => { - let null_pct_idx = - batch_gen_rng.random_range(0..self.candidate_null_pcts.len()); - let null_pct = self.candidate_null_pcts[null_pct_idx]; let max_len = batch_gen_rng.random_range(1..50); let mut generator = StringArrayGenerator { @@ -607,9 +680,6 @@ impl RecordBatchGenerator { } } DataType::Binary | DataType::LargeBinary | DataType::BinaryView => { - let null_pct_idx = - batch_gen_rng.random_range(0..self.candidate_null_pcts.len()); - let null_pct = self.candidate_null_pcts[null_pct_idx]; let max_len = batch_gen_rng.random_range(1..100); let mut generator = BinaryArrayGenerator { @@ -627,11 +697,38 @@ impl RecordBatchGenerator { _ => unreachable!(), } } + DataType::Decimal32(precision, scale) => { + generate_decimal_array!( + self, + num_rows, + max_num_distinct, + null_pct, + batch_gen_rng, + array_gen_rng, + precision, + scale, + Decimal32Type + ) + } + DataType::Decimal64(precision, scale) => { + generate_decimal_array!( + self, + num_rows, + max_num_distinct, + null_pct, + batch_gen_rng, + array_gen_rng, + precision, + scale, + Decimal64Type + ) + } DataType::Decimal128(precision, scale) => { generate_decimal_array!( self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, precision, @@ -644,6 +741,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, precision, @@ -656,11 +754,41 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, 
array_gen_rng, BooleanType } } + DataType::Dictionary(ref key_type, ref value_type) + if key_type.is_dictionary_key_type() => + { + // We generate just num_distinct values because they will be reused by different keys + let mut array_gen_rng = array_gen_rng; + debug_assert!((0.0..=1.0).contains(&null_pct)); + let values = Self::generate_array_of_type_inner( + &ColumnDescr::new("values", *value_type.clone()), + num_distinct, + batch_gen_rng, + array_gen_rng.clone(), + null_pct, // generate some null values + ); + + match key_type.as_ref() { + // new key types can be added here + DataType::UInt64 => generate_dict!( + self, + num_rows, + num_distinct, + null_pct, + batch_gen_rng, + array_gen_rng, + UInt64Type, + values + ), + _ => panic!("Invalid dictionary keys type: {key_type}"), + } + } _ => { panic!("Unsupported data generator type: {}", col.column_type) } diff --git a/datafusion/core/tests/fuzz_cases/sort_fuzz.rs b/datafusion/core/tests/fuzz_cases/sort_fuzz.rs index 703b8715821a8..0d8a066d432dd 100644 --- a/datafusion/core/tests/fuzz_cases/sort_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/sort_fuzz.rs @@ -20,7 +20,7 @@ use std::sync::Arc; use arrow::{ - array::{as_string_array, ArrayRef, Int32Array, StringArray}, + array::{ArrayRef, Int32Array, StringArray, as_string_array}, compute::SortOptions, record_batch::RecordBatch, }; @@ -28,7 +28,7 @@ use datafusion::datasource::memory::MemorySourceConfig; use datafusion::execution::runtime_env::RuntimeEnvBuilder; use datafusion::physical_plan::expressions::PhysicalSortExpr; use datafusion::physical_plan::sorts::sort::SortExec; -use datafusion::physical_plan::{collect, ExecutionPlan}; +use datafusion::physical_plan::{ExecutionPlan, collect}; use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_common::cast::as_int32_array; use datafusion_execution::memory_pool::GreedyMemoryPool; @@ -188,7 +188,7 @@ impl SortTest { } fn with_sort_columns(mut self, sort_columns: Vec<&str>) -> Self { - 
self.sort_columns = sort_columns.iter().map(|s| s.to_string()).collect(); + self.sort_columns = sort_columns.iter().map(|s| (*s).to_string()).collect(); self } @@ -232,18 +232,15 @@ impl SortTest { .expect("at least one batch"); let schema = first_batch.schema(); - let sort_ordering = LexOrdering::new( - self.sort_columns - .iter() - .map(|c| PhysicalSortExpr { - expr: col(c, &schema).unwrap(), - options: SortOptions { - descending: false, - nulls_first: true, - }, - }) - .collect(), - ); + let sort_ordering = + LexOrdering::new(self.sort_columns.iter().map(|c| PhysicalSortExpr { + expr: col(c, &schema).unwrap(), + options: SortOptions { + descending: false, + nulls_first: true, + }, + })) + .unwrap(); let exec = MemorySourceConfig::try_new_exec(&input, schema, None).unwrap(); let sort = Arc::new(SortExec::new(sort_ordering, exec)); diff --git a/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs b/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs index cf6867758edc7..8f3b8ea05324c 100644 --- a/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs @@ -20,35 +20,33 @@ mod sp_repartition_fuzz_tests { use std::sync::Arc; use arrow::array::{ArrayRef, Int64Array, RecordBatch, UInt64Array}; - use arrow::compute::{concat_batches, lexsort, SortColumn, SortOptions}; + use arrow::compute::{SortColumn, SortOptions, concat_batches, lexsort}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; + use datafusion::datasource::memory::MemorySourceConfig; + use datafusion::datasource::source::DataSourceExec; use datafusion::physical_plan::{ - collect, + ExecutionPlan, Partitioning, collect, metrics::{BaselineMetrics, ExecutionPlanMetricsSet}, repartition::RepartitionExec, sorts::sort_preserving_merge::SortPreservingMergeExec, sorts::streaming_merge::StreamingMergeBuilder, stream::RecordBatchStreamAdapter, - ExecutionPlan, Partitioning, }; use 
datafusion::prelude::SessionContext; use datafusion_common::Result; - use datafusion_execution::{ - config::SessionConfig, memory_pool::MemoryConsumer, SendableRecordBatchStream, - }; - use datafusion_physical_expr::{ - equivalence::{EquivalenceClass, EquivalenceProperties}, - expressions::{col, Column}, - ConstExpr, PhysicalExpr, PhysicalSortExpr, + use datafusion_execution::{config::SessionConfig, memory_pool::MemoryConsumer}; + use datafusion_physical_expr::ConstExpr; + use datafusion_physical_expr::equivalence::{ + EquivalenceClass, EquivalenceProperties, }; + use datafusion_physical_expr::expressions::{Column, col}; + use datafusion_physical_expr_common::physical_expr::PhysicalExpr; + use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; use test_utils::add_empty_batches; - use datafusion::datasource::memory::MemorySourceConfig; - use datafusion::datasource::source::DataSourceExec; - use datafusion_physical_expr_common::sort_expr::LexOrdering; use itertools::izip; - use rand::{rngs::StdRng, seq::SliceRandom, Rng, SeedableRng}; + use rand::{Rng, SeedableRng, rngs::StdRng, seq::SliceRandom}; // Generate a schema which consists of 6 columns (a, b, c, d, e, f) fn create_test_schema() -> Result { @@ -80,9 +78,9 @@ mod sp_repartition_fuzz_tests { let mut eq_properties = EquivalenceProperties::new(test_schema.clone()); // Define a and f are aliases - eq_properties.add_equal_conditions(col_a, col_f)?; + eq_properties.add_equal_conditions(Arc::clone(col_a), Arc::clone(col_f))?; // Column e has constant value. 
- eq_properties = eq_properties.with_constants([ConstExpr::from(col_e)]); + eq_properties.add_constants([ConstExpr::from(Arc::clone(col_e))])?; // Randomly order columns for sorting let mut rng = StdRng::seed_from_u64(seed); @@ -94,18 +92,18 @@ mod sp_repartition_fuzz_tests { }; while !remaining_exprs.is_empty() { - let n_sort_expr = rng.random_range(0..remaining_exprs.len() + 1); + let n_sort_expr = rng.random_range(1..remaining_exprs.len() + 1); remaining_exprs.shuffle(&mut rng); - let ordering = remaining_exprs - .drain(0..n_sort_expr) - .map(|expr| PhysicalSortExpr { - expr: expr.clone(), - options: options_asc, - }) - .collect(); + let ordering = + remaining_exprs + .drain(0..n_sort_expr) + .map(|expr| PhysicalSortExpr { + expr: expr.clone(), + options: options_asc, + }); - eq_properties.add_new_orderings([ordering]); + eq_properties.add_ordering(ordering); } Ok((test_schema, eq_properties)) @@ -151,7 +149,7 @@ mod sp_repartition_fuzz_tests { // Fill constant columns for constant in eq_properties.constants() { - let col = constant.expr().as_any().downcast_ref::().unwrap(); + let col = constant.expr.as_any().downcast_ref::().unwrap(); let (idx, _field) = schema.column_with_name(col.name()).unwrap(); let arr = Arc::new(UInt64Array::from_iter_values(vec![0; n_elem])) as ArrayRef; @@ -227,21 +225,21 @@ mod sp_repartition_fuzz_tests { let table_data_with_properties = generate_table_for_eq_properties(&eq_properties, N_ELEM, N_DISTINCT)?; let schema = table_data_with_properties.schema(); - let streams: Vec = (0..N_PARTITION) + let streams = (0..N_PARTITION) .map(|_idx| { let batch = table_data_with_properties.clone(); Box::pin(RecordBatchStreamAdapter::new( schema.clone(), futures::stream::once(async { Ok(batch) }), - )) as SendableRecordBatchStream + )) as _ }) .collect::>(); - // Returns concatenated version of the all available orderings - let exprs = eq_properties - .oeq_class() - .output_ordering() - .unwrap_or_default(); + // Returns concatenated version of the 
all available orderings: + let Some(exprs) = eq_properties.oeq_class().output_ordering() else { + // We always should have an ordering due to the way we generate the schema: + unreachable!("No ordering found in eq_properties: {:?}", eq_properties); + }; let context = SessionContext::new().task_ctx(); let mem_reservation = @@ -303,7 +301,7 @@ mod sp_repartition_fuzz_tests { let mut handles = Vec::new(); for seed in seed_start..seed_end { - #[allow(clippy::disallowed_methods)] // spawn allowed only in tests + #[expect(clippy::disallowed_methods)] // spawn allowed only in tests let job = tokio::spawn(run_sort_preserving_repartition_test( make_staggered_batches::(n_row, n_distinct, seed as u64), is_first_roundrobin, @@ -347,20 +345,16 @@ mod sp_repartition_fuzz_tests { let schema = input1[0].schema(); let session_config = SessionConfig::new().with_batch_size(50); let ctx = SessionContext::new_with_config(session_config); - let mut sort_keys = LexOrdering::default(); - for ordering_col in ["a", "b", "c"] { - sort_keys.push(PhysicalSortExpr { - expr: col(ordering_col, &schema).unwrap(), - options: SortOptions::default(), - }) - } + let sort_keys = ["a", "b", "c"].map(|ordering_col| { + PhysicalSortExpr::new_default(col(ordering_col, &schema).unwrap()) + }); let concat_input_record = concat_batches(&schema, &input1).unwrap(); let running_source = Arc::new( - MemorySourceConfig::try_new(&[input1.clone()], schema.clone(), None) + MemorySourceConfig::try_new(&[input1], schema.clone(), None) .unwrap() - .try_with_sort_information(vec![sort_keys.clone()]) + .try_with_sort_information(vec![sort_keys.clone().into()]) .unwrap(), ); let running_source = Arc::new(DataSourceExec::new(running_source)); @@ -381,7 +375,7 @@ mod sp_repartition_fuzz_tests { sort_preserving_repartition_exec_hash(intermediate, hash_exprs.clone()) }; - let final_plan = sort_preserving_merge_exec(sort_keys.clone(), intermediate); + let final_plan = sort_preserving_merge_exec(sort_keys.into(), intermediate); 
let task_ctx = ctx.task_ctx(); let collected_running = collect(final_plan, task_ctx.clone()).await.unwrap(); @@ -428,10 +422,9 @@ mod sp_repartition_fuzz_tests { } fn sort_preserving_merge_exec( - sort_exprs: impl IntoIterator, + sort_exprs: LexOrdering, input: Arc, ) -> Arc { - let sort_exprs = sort_exprs.into_iter().collect(); Arc::new(SortPreservingMergeExec::new(sort_exprs, input)) } diff --git a/datafusion/core/tests/fuzz_cases/sort_query_fuzz.rs b/datafusion/core/tests/fuzz_cases/sort_query_fuzz.rs index d2d3a5e0c22fa..376306f3e0659 100644 --- a/datafusion/core/tests/fuzz_cases/sort_query_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/sort_query_fuzz.rs @@ -24,24 +24,22 @@ use arrow::array::RecordBatch; use arrow_schema::SchemaRef; use datafusion::datasource::MemTable; use datafusion::prelude::{SessionConfig, SessionContext}; -use datafusion_common::{instant::Instant, Result}; +use datafusion_common::{Result, human_readable_size, instant::Instant}; use datafusion_execution::disk_manager::DiskManagerBuilder; -use datafusion_execution::memory_pool::{ - human_readable_size, MemoryPool, UnboundedMemoryPool, -}; +use datafusion_execution::memory_pool::{MemoryPool, UnboundedMemoryPool}; use datafusion_expr::display_schema; use datafusion_physical_plan::spill::get_record_batch_memory_size; use std::time::Duration; use datafusion_execution::{memory_pool::FairSpillPool, runtime_env::RuntimeEnvBuilder}; -use rand::prelude::IndexedRandom; use rand::Rng; -use rand::{rngs::StdRng, SeedableRng}; +use rand::prelude::IndexedRandom; +use rand::{SeedableRng, rngs::StdRng}; use crate::fuzz_cases::aggregation_fuzzer::check_equality_of_batches; use super::aggregation_fuzzer::ColumnDescr; -use super::record_batch_generator::{get_supported_types_columns, RecordBatchGenerator}; +use super::record_batch_generator::{RecordBatchGenerator, get_supported_types_columns}; /// Entry point for executing the sort query fuzzer. 
/// @@ -177,16 +175,16 @@ impl SortQueryFuzzer { n_round: usize, n_query: usize, ) -> bool { - if let Some(time_limit) = self.time_limit { - if Instant::now().duration_since(start_time) > time_limit { - println!( - "[SortQueryFuzzer] Time limit reached: {} queries ({} random configs each) in {} rounds", - n_round * self.queries_per_round + n_query, - self.config_variations_per_query, - n_round - ); - return true; - } + if let Some(time_limit) = self.time_limit + && Instant::now().duration_since(start_time) > time_limit + { + println!( + "[SortQueryFuzzer] Time limit reached: {} queries ({} random configs each) in {} rounds", + n_round * self.queries_per_round + n_query, + self.config_variations_per_query, + n_round + ); + return true; } false } @@ -220,7 +218,7 @@ impl SortQueryFuzzer { .test_gen .fuzzer_run(init_seed, query_seed, config_seed) .await?; - println!("\n"); // Seperator between tested runs + println!("\n"); // Separator between tested runs if expected_results.is_none() { expected_results = Some(results); @@ -428,7 +426,7 @@ impl SortFuzzerTestGenerator { .collect(); let mut order_by_clauses = Vec::new(); - for col in selected_columns { + for col in &selected_columns { let mut clause = col.name.clone(); if rng.random_bool(0.5) { let order = if rng.random_bool(0.5) { "ASC" } else { "DESC" }; @@ -463,7 +461,12 @@ impl SortFuzzerTestGenerator { let limit_clause = limit.map_or(String::new(), |l| format!(" LIMIT {l}")); let query = format!( - "SELECT * FROM {} ORDER BY {}{}", + "SELECT {} FROM {} ORDER BY {}{}", + selected_columns + .iter() + .map(|col| col.name.clone()) + .collect::>() + .join(", "), self.table_name, order_by_clauses.join(", "), limit_clause diff --git a/datafusion/core/tests/fuzz_cases/spilling_fuzz_in_memory_constrained_env.rs b/datafusion/core/tests/fuzz_cases/spilling_fuzz_in_memory_constrained_env.rs new file mode 100644 index 0000000000000..d401557e966d6 --- /dev/null +++ 
b/datafusion/core/tests/fuzz_cases/spilling_fuzz_in_memory_constrained_env.rs @@ -0,0 +1,658 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Fuzz Test for different operators in memory constrained environment + +use std::pin::Pin; +use std::sync::Arc; + +use crate::fuzz_cases::aggregate_fuzz::assert_spill_count_metric; +use crate::fuzz_cases::once_exec::OnceExec; +use arrow::array::UInt64Array; +use arrow::{array::StringArray, compute::SortOptions, record_batch::RecordBatch}; +use arrow_schema::{DataType, Field, Schema}; +use datafusion::common::Result; +use datafusion::execution::runtime_env::RuntimeEnvBuilder; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::physical_plan::expressions::PhysicalSortExpr; +use datafusion::physical_plan::sorts::sort::SortExec; +use datafusion::prelude::SessionConfig; +use datafusion_common::units::{KB, MB}; +use datafusion_execution::memory_pool::{ + FairSpillPool, MemoryConsumer, MemoryReservation, +}; +use datafusion_execution::{SendableRecordBatchStream, TaskContext}; +use datafusion_functions_aggregate::array_agg::array_agg_udaf; +use datafusion_physical_expr::aggregate::AggregateExprBuilder; +use datafusion_physical_expr::expressions::{Column, col}; +use 
datafusion_physical_expr_common::sort_expr::LexOrdering; +use datafusion_physical_plan::aggregates::{ + AggregateExec, AggregateMode, PhysicalGroupBy, +}; +use datafusion_physical_plan::stream::RecordBatchStreamAdapter; +use futures::StreamExt; + +#[tokio::test] +async fn test_sort_with_limited_memory() -> Result<()> { + let record_batch_size = 8192; + let pool_size = 2 * MB as usize; + let task_ctx = { + let memory_pool = Arc::new(FairSpillPool::new(pool_size)); + TaskContext::default() + .with_session_config( + SessionConfig::new() + .with_batch_size(record_batch_size) + .with_sort_spill_reservation_bytes(1), + ) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) + }; + + let record_batch_size = pool_size / 16; + + // Basic test with a lot of groups that cannot all fit in memory and 1 record batch + // from each spill file is too much memory + let spill_count = run_sort_test_with_limited_memory(RunTestWithLimitedMemoryArgs { + pool_size, + task_ctx: Arc::new(task_ctx), + number_of_record_batches: 100, + get_size_of_record_batch_to_generate: Box::pin(move |_| record_batch_size), + memory_behavior: Default::default(), + }) + .await?; + + let total_spill_files_size = spill_count * record_batch_size; + assert!( + total_spill_files_size > pool_size, + "Total spill files size {total_spill_files_size} should be greater than pool size {pool_size}", + ); + + Ok(()) +} + +#[tokio::test] +async fn test_sort_with_limited_memory_and_different_sizes_of_record_batch() -> Result<()> +{ + let record_batch_size = 8192; + let pool_size = 2 * MB as usize; + let task_ctx = { + let memory_pool = Arc::new(FairSpillPool::new(pool_size)); + TaskContext::default() + .with_session_config( + SessionConfig::new() + .with_batch_size(record_batch_size) + .with_sort_spill_reservation_bytes(1), + ) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) + }; + + 
run_sort_test_with_limited_memory(RunTestWithLimitedMemoryArgs { + pool_size, + task_ctx: Arc::new(task_ctx), + number_of_record_batches: 100, + get_size_of_record_batch_to_generate: Box::pin(move |i| { + if i % 25 == 1 { + pool_size / 6 + } else { + 16 * KB as usize + } + }), + memory_behavior: Default::default(), + }) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn test_sort_with_limited_memory_and_different_sizes_of_record_batch_and_changing_memory_reservation() +-> Result<()> { + let record_batch_size = 8192; + let pool_size = 2 * MB as usize; + let task_ctx = { + let memory_pool = Arc::new(FairSpillPool::new(pool_size)); + TaskContext::default() + .with_session_config( + SessionConfig::new() + .with_batch_size(record_batch_size) + .with_sort_spill_reservation_bytes(1), + ) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) + }; + + run_sort_test_with_limited_memory(RunTestWithLimitedMemoryArgs { + pool_size, + task_ctx: Arc::new(task_ctx), + number_of_record_batches: 100, + get_size_of_record_batch_to_generate: Box::pin(move |i| { + if i % 25 == 1 { + pool_size / 6 + } else { + 16 * KB as usize + } + }), + memory_behavior: MemoryBehavior::TakeAllMemoryAndReleaseEveryNthBatch(10), + }) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn test_sort_with_limited_memory_and_different_sizes_of_record_batch_and_take_all_memory() +-> Result<()> { + let record_batch_size = 8192; + let pool_size = 2 * MB as usize; + let task_ctx = { + let memory_pool = Arc::new(FairSpillPool::new(pool_size)); + TaskContext::default() + .with_session_config( + SessionConfig::new() + .with_batch_size(record_batch_size) + .with_sort_spill_reservation_bytes(1), + ) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) + }; + + run_sort_test_with_limited_memory(RunTestWithLimitedMemoryArgs { + pool_size, + task_ctx: Arc::new(task_ctx), + number_of_record_batches: 100, + 
get_size_of_record_batch_to_generate: Box::pin(move |i| { + if i % 25 == 1 { + pool_size / 6 + } else { + 16 * KB as usize + } + }), + memory_behavior: MemoryBehavior::TakeAllMemoryAtTheBeginning, + }) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn test_sort_with_limited_memory_and_large_record_batch() -> Result<()> { + let record_batch_size = 8192; + let pool_size = 2 * MB as usize; + let task_ctx = { + let memory_pool = Arc::new(FairSpillPool::new(pool_size)); + TaskContext::default() + .with_session_config( + SessionConfig::new() + .with_batch_size(record_batch_size) + .with_sort_spill_reservation_bytes(1), + ) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) + }; + + // Test that the merge degree of multi level merge sort cannot be fixed size when there is not enough memory + run_sort_test_with_limited_memory(RunTestWithLimitedMemoryArgs { + pool_size, + task_ctx: Arc::new(task_ctx), + number_of_record_batches: 100, + get_size_of_record_batch_to_generate: Box::pin(move |_| pool_size / 6), + memory_behavior: Default::default(), + }) + .await?; + + Ok(()) +} + +struct RunTestWithLimitedMemoryArgs { + pool_size: usize, + task_ctx: Arc, + number_of_record_batches: usize, + get_size_of_record_batch_to_generate: + Pin usize + Send + 'static>>, + memory_behavior: MemoryBehavior, +} + +#[derive(Default)] +enum MemoryBehavior { + #[default] + AsIs, + TakeAllMemoryAtTheBeginning, + TakeAllMemoryAndReleaseEveryNthBatch(usize), +} + +async fn run_sort_test_with_limited_memory( + mut args: RunTestWithLimitedMemoryArgs, +) -> Result { + let get_size_of_record_batch_to_generate = std::mem::replace( + &mut args.get_size_of_record_batch_to_generate, + Box::pin(move |_| unreachable!("should not be called after take")), + ); + + let scan_schema = Arc::new(Schema::new(vec![ + Field::new("col_0", DataType::UInt64, true), + Field::new("col_1", DataType::Utf8, true), + ])); + + let record_batch_size = 
args.task_ctx.session_config().batch_size() as u64; + + let schema = Arc::clone(&scan_schema); + let plan: Arc = + Arc::new(OnceExec::new(Box::pin(RecordBatchStreamAdapter::new( + Arc::clone(&schema), + futures::stream::iter((0..args.number_of_record_batches as u64).map( + move |index| { + let mut record_batch_memory_size = + get_size_of_record_batch_to_generate(index as usize); + record_batch_memory_size = record_batch_memory_size + .saturating_sub(size_of::() * record_batch_size as usize); + + let string_item_size = + record_batch_memory_size / record_batch_size as usize; + let string_array = + Arc::new(StringArray::from_iter_values(std::iter::repeat_n( + "a".repeat(string_item_size), + record_batch_size as usize, + ))); + + RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(UInt64Array::from_iter_values( + (index * record_batch_size) + ..(index * record_batch_size) + record_batch_size, + )), + string_array, + ], + ) + .map_err(|err| err.into()) + }, + )), + )))); + let sort_exec = Arc::new(SortExec::new( + LexOrdering::new(vec![PhysicalSortExpr { + expr: col("col_0", &scan_schema).unwrap(), + options: SortOptions { + descending: false, + nulls_first: true, + }, + }]) + .unwrap(), + plan, + )); + + let result = sort_exec.execute(0, Arc::clone(&args.task_ctx))?; + + run_test(args, sort_exec, result).await +} + +fn grow_memory_as_much_as_possible( + memory_step: usize, + memory_reservation: &mut MemoryReservation, +) -> Result { + let mut was_able_to_grow = false; + while memory_reservation.try_grow(memory_step).is_ok() { + was_able_to_grow = true; + } + + Ok(was_able_to_grow) +} + +#[tokio::test] +async fn test_aggregate_with_high_cardinality_with_limited_memory() -> Result<()> { + let record_batch_size = 8192; + let pool_size = 2 * MB as usize; + let task_ctx = { + let memory_pool = Arc::new(FairSpillPool::new(pool_size)); + TaskContext::default() + .with_session_config(SessionConfig::new().with_batch_size(record_batch_size)) + 
.with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) + }; + + let record_batch_size = pool_size / 16; + + // Basic test with a lot of groups that cannot all fit in memory and 1 record batch + // from each spill file is too much memory + let spill_count = + run_test_aggregate_with_high_cardinality(RunTestWithLimitedMemoryArgs { + pool_size, + task_ctx: Arc::new(task_ctx), + number_of_record_batches: 100, + get_size_of_record_batch_to_generate: Box::pin(move |_| record_batch_size), + memory_behavior: Default::default(), + }) + .await?; + + let total_spill_files_size = spill_count * record_batch_size; + assert!( + total_spill_files_size > pool_size, + "Total spill files size {total_spill_files_size} should be greater than pool size {pool_size}", + ); + + Ok(()) +} + +#[tokio::test] +async fn test_aggregate_with_high_cardinality_with_limited_memory_and_different_sizes_of_record_batch() +-> Result<()> { + let record_batch_size = 8192; + let pool_size = 2 * MB as usize; + let task_ctx = { + let memory_pool = Arc::new(FairSpillPool::new(pool_size)); + TaskContext::default() + .with_session_config(SessionConfig::new().with_batch_size(record_batch_size)) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) + }; + + run_test_aggregate_with_high_cardinality(RunTestWithLimitedMemoryArgs { + pool_size, + task_ctx: Arc::new(task_ctx), + number_of_record_batches: 100, + get_size_of_record_batch_to_generate: Box::pin(move |i| { + if i % 25 == 1 { + pool_size / 6 + } else { + (16 * KB) as usize + } + }), + memory_behavior: Default::default(), + }) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn test_aggregate_with_high_cardinality_with_limited_memory_and_different_sizes_of_record_batch_and_changing_memory_reservation() +-> Result<()> { + let record_batch_size = 8192; + let pool_size = 2 * MB as usize; + let task_ctx = { + let memory_pool = 
Arc::new(FairSpillPool::new(pool_size)); + TaskContext::default() + .with_session_config(SessionConfig::new().with_batch_size(record_batch_size)) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) + }; + + run_test_aggregate_with_high_cardinality(RunTestWithLimitedMemoryArgs { + pool_size, + task_ctx: Arc::new(task_ctx), + number_of_record_batches: 100, + get_size_of_record_batch_to_generate: Box::pin(move |i| { + if i % 25 == 1 { + pool_size / 6 + } else { + (16 * KB) as usize + } + }), + memory_behavior: MemoryBehavior::TakeAllMemoryAndReleaseEveryNthBatch(10), + }) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn test_aggregate_with_high_cardinality_with_limited_memory_and_different_sizes_of_record_batch_and_take_all_memory() +-> Result<()> { + let record_batch_size = 8192; + let pool_size = 2 * MB as usize; + let task_ctx = { + let memory_pool = Arc::new(FairSpillPool::new(pool_size)); + TaskContext::default() + .with_session_config(SessionConfig::new().with_batch_size(record_batch_size)) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) + }; + + run_test_aggregate_with_high_cardinality(RunTestWithLimitedMemoryArgs { + pool_size, + task_ctx: Arc::new(task_ctx), + number_of_record_batches: 100, + get_size_of_record_batch_to_generate: Box::pin(move |i| { + if i % 25 == 1 { + pool_size / 6 + } else { + (16 * KB) as usize + } + }), + memory_behavior: MemoryBehavior::TakeAllMemoryAtTheBeginning, + }) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn test_aggregate_with_high_cardinality_with_limited_memory_and_large_record_batch() +-> Result<()> { + let record_batch_size = 8192; + let pool_size = 2 * MB as usize; + let task_ctx = { + let memory_pool = Arc::new(FairSpillPool::new(pool_size)); + TaskContext::default() + .with_session_config(SessionConfig::new().with_batch_size(record_batch_size)) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + 
.with_memory_pool(memory_pool) + .build()?, + )) + }; + + // Test that the merge degree of multi level merge sort cannot be fixed size when there is not enough memory + run_test_aggregate_with_high_cardinality(RunTestWithLimitedMemoryArgs { + pool_size, + task_ctx: Arc::new(task_ctx), + number_of_record_batches: 100, + get_size_of_record_batch_to_generate: Box::pin(move |_| pool_size / 6), + memory_behavior: Default::default(), + }) + .await?; + + Ok(()) +} + +async fn run_test_aggregate_with_high_cardinality( + mut args: RunTestWithLimitedMemoryArgs, +) -> Result { + let get_size_of_record_batch_to_generate = std::mem::replace( + &mut args.get_size_of_record_batch_to_generate, + Box::pin(move |_| unreachable!("should not be called after take")), + ); + let scan_schema = Arc::new(Schema::new(vec![ + Field::new("col_0", DataType::UInt64, true), + Field::new("col_1", DataType::Utf8, true), + ])); + + let group_by = PhysicalGroupBy::new_single(vec![( + Arc::new(Column::new("col_0", 0)), + "col_0".to_string(), + )]); + + let aggregate_expressions = vec![Arc::new( + AggregateExprBuilder::new( + array_agg_udaf(), + vec![col("col_1", &scan_schema).unwrap()], + ) + .schema(Arc::clone(&scan_schema)) + .alias("array_agg(col_1)") + .build()?, + )]; + + let record_batch_size = args.task_ctx.session_config().batch_size() as u64; + + let schema = Arc::clone(&scan_schema); + let plan: Arc = + Arc::new(OnceExec::new(Box::pin(RecordBatchStreamAdapter::new( + Arc::clone(&schema), + futures::stream::iter((0..args.number_of_record_batches as u64).map( + move |index| { + let mut record_batch_memory_size = + get_size_of_record_batch_to_generate(index as usize); + record_batch_memory_size = record_batch_memory_size + .saturating_sub(size_of::() * record_batch_size as usize); + + let string_item_size = + record_batch_memory_size / record_batch_size as usize; + let string_array = + Arc::new(StringArray::from_iter_values(std::iter::repeat_n( + "a".repeat(string_item_size), + 
record_batch_size as usize, + ))); + + RecordBatch::try_new( + Arc::clone(&schema), + vec![ + // Grouping key + Arc::new(UInt64Array::from_iter_values( + (index * record_batch_size) + ..(index * record_batch_size) + record_batch_size, + )), + // Grouping value + string_array, + ], + ) + .map_err(|err| err.into()) + }, + )), + )))); + + let aggregate_exec = Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + group_by.clone(), + aggregate_expressions.clone(), + vec![None; aggregate_expressions.len()], + plan, + Arc::clone(&scan_schema), + )?); + let aggregate_final = Arc::new(AggregateExec::try_new( + AggregateMode::Final, + group_by, + aggregate_expressions.clone(), + vec![None; aggregate_expressions.len()], + aggregate_exec, + Arc::clone(&scan_schema), + )?); + + let result = aggregate_final.execute(0, Arc::clone(&args.task_ctx))?; + + run_test(args, aggregate_final, result).await +} + +async fn run_test( + args: RunTestWithLimitedMemoryArgs, + plan: Arc, + result_stream: SendableRecordBatchStream, +) -> Result { + let number_of_record_batches = args.number_of_record_batches; + + consume_stream_and_simulate_other_running_memory_consumers(args, result_stream) + .await?; + + let spill_count = assert_spill_count_metric(true, plan); + + assert!( + spill_count > 0, + "Expected spill, but did not, number of record batches: {number_of_record_batches}", + ); + + Ok(spill_count) +} + +/// Consume the stream and change the amount of memory used while consuming it based on the [`MemoryBehavior`] provided +async fn consume_stream_and_simulate_other_running_memory_consumers( + args: RunTestWithLimitedMemoryArgs, + mut result_stream: SendableRecordBatchStream, +) -> Result<()> { + let mut number_of_rows = 0; + let record_batch_size = args.task_ctx.session_config().batch_size() as u64; + + let memory_pool = args.task_ctx.memory_pool(); + let memory_consumer = MemoryConsumer::new("mock_memory_consumer"); + let mut memory_reservation = 
memory_consumer.register(memory_pool); + + let mut index = 0; + let mut memory_took = false; + + while let Some(batch) = result_stream.next().await { + match args.memory_behavior { + MemoryBehavior::AsIs => { + // Do nothing + } + MemoryBehavior::TakeAllMemoryAtTheBeginning => { + if !memory_took { + memory_took = true; + grow_memory_as_much_as_possible(10, &mut memory_reservation)?; + } + } + MemoryBehavior::TakeAllMemoryAndReleaseEveryNthBatch(n) => { + if !memory_took { + memory_took = true; + grow_memory_as_much_as_possible( + args.pool_size, + &mut memory_reservation, + )?; + } else if index % n == 0 { + // release memory + memory_reservation.free(); + } + } + } + + let batch = batch?; + number_of_rows += batch.num_rows(); + + index += 1; + } + + assert_eq!( + number_of_rows, + args.number_of_record_batches * record_batch_size as usize + ); + + Ok(()) +} diff --git a/datafusion/core/tests/fuzz_cases/topk_filter_pushdown.rs b/datafusion/core/tests/fuzz_cases/topk_filter_pushdown.rs new file mode 100644 index 0000000000000..d14afaf1b3267 --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/topk_filter_pushdown.rs @@ -0,0 +1,387 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::collections::HashMap; +use std::sync::{Arc, LazyLock}; + +use arrow::array::{Int32Array, StringArray, StringDictionaryBuilder}; +use arrow::datatypes::Int32Type; +use arrow::record_batch::RecordBatch; +use arrow::util::pretty::pretty_format_batches; +use arrow_schema::{DataType, Field, Schema}; +use datafusion::datasource::listing::{ListingOptions, ListingTable, ListingTableConfig}; +use datafusion::prelude::{SessionConfig, SessionContext}; +use datafusion_datasource::ListingTableUrl; +use datafusion_datasource_parquet::ParquetFormat; +use datafusion_execution::object_store::ObjectStoreUrl; +use itertools::Itertools; +use object_store::memory::InMemory; +use object_store::path::Path; +use object_store::{ObjectStore, ObjectStoreExt, PutPayload}; +use parquet::arrow::ArrowWriter; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; +use tokio::sync::Mutex; +use tokio::task::JoinSet; + +#[derive(Clone)] +struct TestDataSet { + store: Arc, + schema: Arc, +} + +/// List of in memory parquet files with UTF8 data +// Use a mutex rather than LazyLock to allow for async initialization +static TESTFILES: LazyLock>> = + LazyLock::new(|| Mutex::new(vec![])); + +async fn test_files() -> Vec { + let files_mutex = &TESTFILES; + let mut files = files_mutex.lock().await; + if !files.is_empty() { + return (*files).clone(); + } + + let mut rng = StdRng::seed_from_u64(0); + + for nulls_in_ids in [false, true] { + for nulls_in_names in [false, true] { + for nulls_in_departments in [false, true] { + let store = Arc::new(InMemory::new()); + + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, nulls_in_ids), + Field::new("name", DataType::Utf8, nulls_in_names), + Field::new( + "department", + DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::Utf8), + ), + nulls_in_departments, + ), + ])); + + let name_choices = if nulls_in_names { + [Some("Alice"), Some("Bob"), None, Some("David"), None] + } else { + [ + Some("Alice"), + 
Some("Bob"), + Some("Charlie"), + Some("David"), + Some("Eve"), + ] + }; + + let department_choices = if nulls_in_departments { + [ + Some("Theater"), + Some("Engineering"), + None, + Some("Arts"), + None, + ] + } else { + [ + Some("Theater"), + Some("Engineering"), + Some("Healthcare"), + Some("Arts"), + Some("Music"), + ] + }; + + // Generate 5 files, some with overlapping or repeated ids some without + for i in 0..5 { + let num_batches = rng.random_range(1..3); + let mut batches = Vec::with_capacity(num_batches); + for _ in 0..num_batches { + let num_rows = 25; + let ids = Int32Array::from_iter((0..num_rows).map(|file| { + if nulls_in_ids { + if rng.random_bool(1.0 / 10.0) { + None + } else { + Some(rng.random_range(file..file + 5)) + } + } else { + Some(rng.random_range(file..file + 5)) + } + })); + let names = StringArray::from_iter((0..num_rows).map(|_| { + // randomly select a name + let idx = rng.random_range(0..name_choices.len()); + name_choices[idx].map(|s| s.to_string()) + })); + let mut departments = StringDictionaryBuilder::::new(); + for _ in 0..num_rows { + // randomly select a department + let idx = rng.random_range(0..department_choices.len()); + departments.append_option(department_choices[idx].as_ref()); + } + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(ids), + Arc::new(names), + Arc::new(departments.finish()), + ], + ) + .unwrap(); + batches.push(batch); + } + let mut buf = vec![]; + { + let mut writer = + ArrowWriter::try_new(&mut buf, schema.clone(), None).unwrap(); + for batch in batches { + writer.write(&batch).unwrap(); + writer.flush().unwrap(); + } + writer.flush().unwrap(); + writer.finish().unwrap(); + } + let payload = PutPayload::from(buf); + let path = Path::from(format!("file_{i}.parquet")); + store.put(&path, payload).await.unwrap(); + } + files.push(TestDataSet { store, schema }); + } + } + } + (*files).clone() +} + +struct RunResult { + results: Vec, + explain_plan: String, +} + +async fn 
run_query_with_config( + query: &str, + config: SessionConfig, + dataset: TestDataSet, +) -> RunResult { + let store = dataset.store; + let schema = dataset.schema; + let ctx = SessionContext::new_with_config(config); + let url = ObjectStoreUrl::parse("memory://").unwrap(); + ctx.register_object_store(url.as_ref(), store.clone()); + + let format = Arc::new( + ParquetFormat::default() + .with_options(ctx.state().table_options().parquet.clone()), + ); + let options = ListingOptions::new(format); + let table_path = ListingTableUrl::parse("memory:///").unwrap(); + let config = ListingTableConfig::new(table_path) + .with_listing_options(options) + .with_schema(schema); + let table = Arc::new(ListingTable::try_new(config).unwrap()); + + ctx.register_table("test_table", table).unwrap(); + + let results = ctx.sql(query).await.unwrap().collect().await.unwrap(); + let explain_batches = ctx + .sql(&format!("EXPLAIN ANALYZE {query}")) + .await + .unwrap() + .collect() + .await + .unwrap(); + let explain_plan = pretty_format_batches(&explain_batches).unwrap().to_string(); + RunResult { + results, + explain_plan, + } +} + +#[derive(Debug)] +struct RunQueryResult { + query: String, + result: Vec, + expected: Vec, +} + +impl RunQueryResult { + fn expected_formatted(&self) -> String { + format!("{}", pretty_format_batches(&self.expected).unwrap()) + } + + fn result_formatted(&self) -> String { + format!("{}", pretty_format_batches(&self.result).unwrap()) + } + + fn is_ok(&self) -> bool { + self.expected_formatted() == self.result_formatted() + } +} + +/// Iterate over each line in the plan and check that one of them has `DataSourceExec` and `DynamicFilter` in the same line. 
+fn has_dynamic_filter_expr_pushdown(plan: &str) -> bool { + for line in plan.lines() { + if line.contains("DataSourceExec") && line.contains("DynamicFilter") { + return true; + } + } + false +} + +async fn run_query( + query: String, + cfg: SessionConfig, + dataset: TestDataSet, +) -> RunQueryResult { + let cfg_with_dynamic_filters = cfg + .clone() + .set_bool("datafusion.optimizer.enable_dynamic_filter_pushdown", true); + let cfg_without_dynamic_filters = cfg + .clone() + .set_bool("datafusion.optimizer.enable_dynamic_filter_pushdown", false); + + let expected_result = + run_query_with_config(&query, cfg_without_dynamic_filters, dataset.clone()).await; + let result = + run_query_with_config(&query, cfg_with_dynamic_filters, dataset.clone()).await; + // Check that dynamic filters were actually pushed down + if !has_dynamic_filter_expr_pushdown(&result.explain_plan) { + panic!( + "Dynamic filter was not pushed down in query: {query}\n\n{}", + result.explain_plan + ); + } + + RunQueryResult { + query: query.to_string(), + result: result.results, + expected: expected_result.results, + } +} + +struct TestCase { + query: String, + cfg: SessionConfig, + dataset: TestDataSet, +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_fuzz_topk_filter_pushdown() { + let order_columns = ["id", "name", "department"]; + let order_directions = ["ASC", "DESC"]; + let null_orders = ["NULLS FIRST", "NULLS LAST"]; + + let start = datafusion_common::instant::Instant::now(); + let mut orders: HashMap> = HashMap::new(); + for order_column in &order_columns { + for order_direction in &order_directions { + for null_order in &null_orders { + // if there is a vec for this column insert the order, otherwise create a new vec + let ordering = format!("{order_column} {order_direction} {null_order}"); + match orders.get_mut(*order_column) { + Some(order_vec) => { + order_vec.push(ordering); + } + None => { + orders.insert((*order_column).to_string(), vec![ordering]); + } + } + } + } + } + 
+ let mut queries = vec![]; + + for limit in [1, 10] { + for num_order_by_columns in [1, 2, 3] { + for order_columns in ["id", "name", "department"] + .iter() + .combinations(num_order_by_columns) + { + for orderings in order_columns + .iter() + .map(|col| orders.get(**col).unwrap()) + .multi_cartesian_product() + { + let query = format!( + "SELECT * FROM test_table ORDER BY {} LIMIT {}", + orderings.into_iter().join(", "), + limit + ); + queries.push(query); + } + } + } + } + + queries.sort_unstable(); + println!( + "Generated {} queries in {:?}", + queries.len(), + start.elapsed() + ); + + let start = datafusion_common::instant::Instant::now(); + let datasets = test_files().await; + println!("Generated test files in {:?}", start.elapsed()); + + let mut test_cases = vec![]; + for enable_filter_pushdown in [true, false] { + for query in &queries { + for dataset in &datasets { + let mut cfg = SessionConfig::new(); + cfg = cfg.set_bool( + "datafusion.optimizer.enable_dynamic_filter_pushdown", + enable_filter_pushdown, + ); + test_cases.push(TestCase { + query: query.to_string(), + cfg, + dataset: dataset.clone(), + }); + } + } + } + + let start = datafusion_common::instant::Instant::now(); + let mut join_set = JoinSet::new(); + for tc in test_cases { + join_set.spawn(run_query(tc.query, tc.cfg, tc.dataset)); + } + let mut results = join_set.join_all().await; + results.sort_unstable_by(|a, b| a.query.cmp(&b.query)); + println!("Ran {} test cases in {:?}", results.len(), start.elapsed()); + + let failures = results + .iter() + .filter(|result| !result.is_ok()) + .collect::>(); + + for failure in &failures { + println!("Failure:"); + println!("Query:\n{}", failure.query); + println!("\nExpected:\n{}", failure.expected_formatted()); + println!("\nResult:\n{}", failure.result_formatted()); + println!("\n\n"); + } + + if !failures.is_empty() { + panic!("Some test cases failed"); + } else { + println!("All test cases passed"); + } +} diff --git 
a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs index 5bd2e457b42a5..82b6d0e4e9d89 100644 --- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs @@ -18,24 +18,24 @@ use std::sync::Arc; use arrow::array::{ArrayRef, Int32Array, StringArray}; -use arrow::compute::{concat_batches, SortOptions}; +use arrow::compute::{SortOptions, concat_batches}; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use arrow::util::pretty::pretty_format_batches; use datafusion::datasource::memory::MemorySourceConfig; use datafusion::datasource::source::DataSourceExec; use datafusion::functions_window::row_number::row_number_udwf; +use datafusion::physical_plan::InputOrderMode::{Linear, PartiallySorted, Sorted}; use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::windows::{ - create_window_expr, schema_add_window_field, BoundedWindowAggExec, WindowAggExec, + BoundedWindowAggExec, WindowAggExec, create_window_expr, schema_add_window_field, }; -use datafusion::physical_plan::InputOrderMode::{Linear, PartiallySorted, Sorted}; -use datafusion::physical_plan::{collect, InputOrderMode}; +use datafusion::physical_plan::{InputOrderMode, collect}; use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_common::HashMap; use datafusion_common::{Result, ScalarValue}; use datafusion_common_runtime::SpawnedTask; -use datafusion_expr::type_coercion::functions::fields_with_aggregate_udf; +use datafusion_expr::type_coercion::functions::fields_with_udf; use datafusion_expr::{ WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; @@ -252,7 +252,6 @@ async fn bounded_window_causal_non_causal() -> Result<()> { ]; let partitionby_exprs = vec![]; - let orderby_exprs = LexOrdering::default(); // Window frame starts with "UNBOUNDED PRECEDING": let start_bound = WindowFrameBound::Preceding(ScalarValue::UInt64(None)); @@ 
-285,10 +284,12 @@ async fn bounded_window_causal_non_causal() -> Result<()> { fn_name.to_string(), &args, &partitionby_exprs, - orderby_exprs.as_ref(), + &[], Arc::new(window_frame), - &extended_schema, + extended_schema, false, + false, + None, )?; let running_window_exec = Arc::new(BoundedWindowAggExec::try_new( vec![window_expr], @@ -444,17 +445,17 @@ fn get_random_function( let fn_name = window_fn_map.keys().collect::>()[rand_fn_idx]; let (window_fn, args) = window_fn_map.values().collect::>()[rand_fn_idx]; let mut args = args.clone(); - if let WindowFunctionDefinition::AggregateUDF(udf) = window_fn { - if !args.is_empty() { - // Do type coercion first argument - let a = args[0].clone(); - let dt = a.return_field(schema.as_ref()).unwrap(); - let coerced = fields_with_aggregate_udf(&[dt], udf).unwrap(); - args[0] = cast(a, schema, coerced[0].data_type().clone()).unwrap(); - } + if let WindowFunctionDefinition::AggregateUDF(udf) = window_fn + && !args.is_empty() + { + // Do type coercion first argument + let a = args[0].clone(); + let dt = a.return_field(schema.as_ref()).unwrap(); + let coerced = fields_with_udf(&[dt], udf.as_ref()).unwrap(); + args[0] = cast(a, schema, coerced[0].data_type().clone()).unwrap(); } - (window_fn.clone(), args, fn_name.to_string()) + (window_fn.clone(), args, (*fn_name).to_string()) } fn get_random_window_frame(rng: &mut StdRng, is_linear: bool) -> WindowFrame { @@ -568,10 +569,11 @@ fn convert_bound_to_current_row_if_applicable( ) { match bound { WindowFrameBound::Preceding(value) | WindowFrameBound::Following(value) => { - if let Ok(zero) = ScalarValue::new_zero(&value.data_type()) { - if value == &zero && rng.random_range(0..2) == 0 { - *bound = WindowFrameBound::CurrentRow; - } + if let Ok(zero) = ScalarValue::new_zero(&value.data_type()) + && value == &zero + && rng.random_range(0..2) == 0 + { + *bound = WindowFrameBound::CurrentRow; } } _ => {} @@ -587,14 +589,14 @@ async fn run_window_test( orderby_columns: Vec<&str>, 
search_mode: InputOrderMode, ) -> Result<()> { - let is_linear = !matches!(search_mode, Sorted); + let is_linear = search_mode != Sorted; let mut rng = StdRng::seed_from_u64(random_seed); let schema = input1[0].schema(); let session_config = SessionConfig::new().with_batch_size(50); let ctx = SessionContext::new_with_config(session_config); let (window_fn, args, fn_name) = get_random_function(&schema, &mut rng, is_linear); let window_frame = get_random_window_frame(&mut rng, is_linear); - let mut orderby_exprs = LexOrdering::default(); + let mut orderby_exprs = vec![]; for column in &orderby_columns { orderby_exprs.push(PhysicalSortExpr { expr: col(column, &schema)?, @@ -602,13 +604,13 @@ async fn run_window_test( }) } if orderby_exprs.len() > 1 && !window_frame.can_accept_multi_orderby() { - orderby_exprs = LexOrdering::new(orderby_exprs[0..1].to_vec()); + orderby_exprs.truncate(1); } let mut partitionby_exprs = vec![]; for column in &partition_by_columns { partitionby_exprs.push(col(column, &schema)?); } - let mut sort_keys = LexOrdering::default(); + let mut sort_keys = vec![]; for partition_by_expr in &partitionby_exprs { sort_keys.push(PhysicalSortExpr { expr: partition_by_expr.clone(), @@ -622,7 +624,7 @@ async fn run_window_test( } let concat_input_record = concat_batches(&schema, &input1)?; - let source_sort_keys = LexOrdering::new(vec![ + let source_sort_keys: LexOrdering = [ PhysicalSortExpr { expr: col("a", &schema)?, options: Default::default(), @@ -635,15 +637,16 @@ async fn run_window_test( expr: col("c", &schema)?, options: Default::default(), }, - ]); + ] + .into(); let mut exec1 = DataSourceExec::from_data_source( MemorySourceConfig::try_new(&[vec![concat_input_record]], schema.clone(), None)? .try_with_sort_information(vec![source_sort_keys.clone()])?, ) as _; // Table is ordered according to ORDER BY a, b, c In linear test we use PARTITION BY b, ORDER BY a // For WindowAggExec to produce correct result it need table to be ordered by b,a. 
Hence add a sort. - if is_linear { - exec1 = Arc::new(SortExec::new(sort_keys, exec1)) as _; + if is_linear && let Some(ordering) = LexOrdering::new(sort_keys) { + exec1 = Arc::new(SortExec::new(ordering, exec1)) as _; } let extended_schema = schema_add_window_field(&args, &schema, &window_fn, &fn_name)?; @@ -654,17 +657,19 @@ async fn run_window_test( fn_name.clone(), &args, &partitionby_exprs, - orderby_exprs.as_ref(), + &orderby_exprs.clone(), Arc::new(window_frame.clone()), - &extended_schema, + Arc::clone(&extended_schema), false, + false, + None, )?], exec1, false, )?) as _; let exec2 = DataSourceExec::from_data_source( - MemorySourceConfig::try_new(&[input1.clone()], schema.clone(), None)? - .try_with_sort_information(vec![source_sort_keys.clone()])?, + MemorySourceConfig::try_new(&[input1], schema, None)? + .try_with_sort_information(vec![source_sort_keys])?, ); let running_window_exec = Arc::new(BoundedWindowAggExec::try_new( vec![create_window_expr( @@ -672,10 +677,12 @@ async fn run_window_test( fn_name, &args, &partitionby_exprs, - orderby_exprs.as_ref(), + &orderby_exprs, Arc::new(window_frame.clone()), - &extended_schema, + extended_schema, false, + false, + None, )?], exec2, search_mode.clone(), @@ -691,7 +698,9 @@ async fn run_window_test( // BoundedWindowAggExec should produce more chunk than the usual WindowAggExec. // Otherwise it means that we cannot generate result in running mode. 
- let err_msg = format!("Inconsistent result for window_frame: {window_frame:?}, window_fn: {window_fn:?}, args:{args:?}, random_seed: {random_seed:?}, search_mode: {search_mode:?}, partition_by_columns:{partition_by_columns:?}, orderby_columns: {orderby_columns:?}"); + let err_msg = format!( + "Inconsistent result for window_frame: {window_frame:?}, window_fn: {window_fn:?}, args:{args:?}, random_seed: {random_seed:?}, search_mode: {search_mode:?}, partition_by_columns:{partition_by_columns:?}, orderby_columns: {orderby_columns:?}" + ); // Below check makes sure that, streaming execution generates more chunks than the bulk execution. // Since algorithms and operators works on sliding windows in the streaming execution. // However, in the current test setup for some random generated window frame clauses: It is not guaranteed @@ -723,8 +732,12 @@ async fn run_window_test( .enumerate() { if !usual_line.eq(running_line) { - println!("Inconsistent result for window_frame at line:{i:?}: {window_frame:?}, window_fn: {window_fn:?}, args:{args:?}, pb_cols:{partition_by_columns:?}, ob_cols:{orderby_columns:?}, search_mode:{search_mode:?}"); - println!("--------usual_formatted_sorted----------------running_formatted_sorted--------"); + println!( + "Inconsistent result for window_frame at line:{i:?}: {window_frame:?}, window_fn: {window_fn:?}, args:{args:?}, pb_cols:{partition_by_columns:?}, ob_cols:{orderby_columns:?}, search_mode:{search_mode:?}" + ); + println!( + "--------usual_formatted_sorted----------------running_formatted_sorted--------" + ); for (line1, line2) in usual_formatted_sorted.iter().zip(running_formatted_sorted) { diff --git a/datafusion/core/tests/integration_tests/schema_adapter_integration_tests.rs b/datafusion/core/tests/integration_tests/schema_adapter_integration_tests.rs deleted file mode 100644 index 38c2ee582a616..0000000000000 --- a/datafusion/core/tests/integration_tests/schema_adapter_integration_tests.rs +++ /dev/null @@ -1,260 +0,0 @@ -// 
Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Integration test for schema adapter factory functionality - -use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; -use arrow::record_batch::RecordBatch; -use datafusion::datasource::object_store::ObjectStoreUrl; -use datafusion::datasource::physical_plan::arrow_file::ArrowSource; -use datafusion::prelude::*; -use datafusion_common::Result; -use datafusion_datasource::file::FileSource; -use datafusion_datasource::file_scan_config::FileScanConfigBuilder; -use datafusion_datasource::schema_adapter::{SchemaAdapter, SchemaAdapterFactory}; -use datafusion_datasource::source::DataSourceExec; -use datafusion_datasource::PartitionedFile; -use std::sync::Arc; -use tempfile::TempDir; - -#[cfg(feature = "parquet")] -use datafusion_datasource_parquet::ParquetSource; -#[cfg(feature = "parquet")] -use parquet::arrow::ArrowWriter; -#[cfg(feature = "parquet")] -use parquet::file::properties::WriterProperties; - -#[cfg(feature = "csv")] -use datafusion_datasource_csv::CsvSource; - -/// A schema adapter factory that transforms column names to uppercase -#[derive(Debug)] -struct UppercaseAdapterFactory {} - -impl SchemaAdapterFactory for UppercaseAdapterFactory { - fn 
create(&self, schema: &Schema) -> Result> { - Ok(Box::new(UppercaseAdapter { - input_schema: Arc::new(schema.clone()), - })) - } -} - -/// Schema adapter that transforms column names to uppercase -#[derive(Debug)] -struct UppercaseAdapter { - input_schema: SchemaRef, -} - -impl SchemaAdapter for UppercaseAdapter { - fn adapt(&self, record_batch: RecordBatch) -> Result { - // In a real adapter, we might transform the data too - // For this test, we're just passing through the batch - Ok(record_batch) - } - - fn output_schema(&self) -> SchemaRef { - let fields = self - .input_schema - .fields() - .iter() - .map(|f| { - Field::new( - f.name().to_uppercase().as_str(), - f.data_type().clone(), - f.is_nullable(), - ) - }) - .collect(); - - Arc::new(Schema::new(fields)) - } -} - -#[cfg(feature = "parquet")] -#[tokio::test] -async fn test_parquet_integration_with_schema_adapter() -> Result<()> { - // Create a temporary directory for our test file - let tmp_dir = TempDir::new()?; - let file_path = tmp_dir.path().join("test.parquet"); - let file_path_str = file_path.to_str().unwrap(); - - // Create test data - let schema = Arc::new(Schema::new(vec![ - Field::new("id", DataType::Int32, false), - Field::new("name", DataType::Utf8, true), - ])); - - let batch = RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(arrow::array::Int32Array::from(vec![1, 2, 3])), - Arc::new(arrow::array::StringArray::from(vec!["a", "b", "c"])), - ], - )?; - - // Write test parquet file - let file = std::fs::File::create(file_path_str)?; - let props = WriterProperties::builder().build(); - let mut writer = ArrowWriter::try_new(file, schema.clone(), Some(props))?; - writer.write(&batch)?; - writer.close()?; - - // Create a session context - let ctx = SessionContext::new(); - - // Create a ParquetSource with the adapter factory - let source = ParquetSource::default() - .with_schema_adapter_factory(Arc::new(UppercaseAdapterFactory {})); - - // Create a scan config - let config = 
FileScanConfigBuilder::new( - ObjectStoreUrl::parse(&format!("file://{}", file_path_str))?, - schema.clone(), - ) - .with_source(source) - .build(); - - // Create a data source executor - let exec = DataSourceExec::from_data_source(config); - - // Collect results - let task_ctx = ctx.task_ctx(); - let stream = exec.execute(0, task_ctx)?; - let batches = datafusion::physical_plan::common::collect(stream).await?; - - // There should be one batch - assert_eq!(batches.len(), 1); - - // Verify the schema has uppercase column names - let result_schema = batches[0].schema(); - assert_eq!(result_schema.field(0).name(), "ID"); - assert_eq!(result_schema.field(1).name(), "NAME"); - - Ok(()) -} - -#[tokio::test] -async fn test_multi_source_schema_adapter_reuse() -> Result<()> { - // This test verifies that the same schema adapter factory can be reused - // across different file source types. This is important for ensuring that: - // 1. The schema adapter factory interface works uniformly across all source types - // 2. The factory can be shared and cloned efficiently using Arc - // 3. 
Various data source implementations correctly implement the schema adapter factory pattern - - // Create a test factory - let factory = Arc::new(UppercaseAdapterFactory {}); - - // Apply the same adapter to different source types - let arrow_source = - ArrowSource::default().with_schema_adapter_factory(factory.clone()); - - #[cfg(feature = "parquet")] - let parquet_source = - ParquetSource::default().with_schema_adapter_factory(factory.clone()); - - #[cfg(feature = "csv")] - let csv_source = CsvSource::default().with_schema_adapter_factory(factory.clone()); - - // Verify adapters were properly set - assert!(arrow_source.schema_adapter_factory().is_some()); - - #[cfg(feature = "parquet")] - assert!(parquet_source.schema_adapter_factory().is_some()); - - #[cfg(feature = "csv")] - assert!(csv_source.schema_adapter_factory().is_some()); - - Ok(()) -} - -// Helper function to test From for Arc implementations -fn test_from_impl> + Default>(expected_file_type: &str) { - let source = T::default(); - let file_source: Arc = source.into(); - assert_eq!(file_source.file_type(), expected_file_type); -} - -#[test] -fn test_from_implementations() { - // Test From implementation for various sources - test_from_impl::("arrow"); - - #[cfg(feature = "parquet")] - test_from_impl::("parquet"); - - #[cfg(feature = "csv")] - test_from_impl::("csv"); - - #[cfg(feature = "json")] - test_from_impl::("json"); -} - -/// A simple test schema adapter factory that doesn't modify the schema -#[derive(Debug)] -struct TestSchemaAdapterFactory {} - -impl SchemaAdapterFactory for TestSchemaAdapterFactory { - fn create(&self, schema: &Schema) -> Result> { - Ok(Box::new(TestSchemaAdapter { - input_schema: Arc::new(schema.clone()), - })) - } -} - -/// A test schema adapter that passes through data unmodified -#[derive(Debug)] -struct TestSchemaAdapter { - input_schema: SchemaRef, -} - -impl SchemaAdapter for TestSchemaAdapter { - fn adapt(&self, record_batch: RecordBatch) -> Result { - // Just pass 
through the batch unmodified - Ok(record_batch) - } - - fn output_schema(&self) -> SchemaRef { - self.input_schema.clone() - } -} - -#[cfg(feature = "parquet")] -#[test] -fn test_schema_adapter_preservation() { - // Create a test schema - let schema = Arc::new(Schema::new(vec![ - Field::new("id", DataType::Int32, false), - Field::new("name", DataType::Utf8, true), - ])); - - // Create source with schema adapter factory - let source = ParquetSource::default(); - let factory = Arc::new(TestSchemaAdapterFactory {}); - let file_source = source.with_schema_adapter_factory(factory); - - // Create a FileScanConfig with the source - let config_builder = - FileScanConfigBuilder::new(ObjectStoreUrl::local_filesystem(), schema.clone()) - .with_source(file_source.clone()) - // Add a file to make it valid - .with_file(PartitionedFile::new("test.parquet", 100)); - - let config = config_builder.build(); - - // Verify the schema adapter factory is present in the file source - assert!(config.source().schema_adapter_factory().is_some()); -} diff --git a/datafusion/core/tests/macro_hygiene/mod.rs b/datafusion/core/tests/macro_hygiene/mod.rs index 9196efec972c1..9fd60cd1f06f3 100644 --- a/datafusion/core/tests/macro_hygiene/mod.rs +++ b/datafusion/core/tests/macro_hygiene/mod.rs @@ -65,3 +65,43 @@ mod config_namespace { } } } + +mod config_field { + // NO other imports! 
+ use datafusion_common::config_field; + + #[test] + fn test_macro() { + #[derive(Debug)] + #[expect(dead_code)] + struct E; + + impl std::fmt::Display for E { + fn fmt(&self, _f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + unimplemented!() + } + } + + impl std::error::Error for E {} + + #[expect(dead_code)] + #[derive(Default)] + struct S; + + impl std::str::FromStr for S { + type Err = E; + + fn from_str(_s: &str) -> Result { + unimplemented!() + } + } + + impl std::fmt::Display for S { + fn fmt(&self, _f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + unimplemented!() + } + } + + config_field!(S); + } +} diff --git a/datafusion/core/tests/memory_limit/memory_limit_validation/sort_mem_validation.rs b/datafusion/core/tests/memory_limit/memory_limit_validation/sort_mem_validation.rs index 64ab1378340aa..e1d5f1b1ab198 100644 --- a/datafusion/core/tests/memory_limit/memory_limit_validation/sort_mem_validation.rs +++ b/datafusion/core/tests/memory_limit/memory_limit_validation/sort_mem_validation.rs @@ -31,7 +31,7 @@ static INIT: Once = Once::new(); // =========================================================================== // Test runners: -// Runners are splitted into multiple tests to run in parallel +// Runners are split into multiple tests to run in parallel // =========================================================================== #[test] @@ -98,11 +98,9 @@ fn init_once() { fn spawn_test_process(test: &str) { init_once(); - let test_path = format!( - "memory_limit::memory_limit_validation::sort_mem_validation::{}", - test - ); - info!("Running test: {}", test_path); + let test_path = + format!("memory_limit::memory_limit_validation::sort_mem_validation::{test}"); + info!("Running test: {test_path}"); // Run the test command let output = Command::new("cargo") @@ -125,7 +123,7 @@ fn spawn_test_process(test: &str) { let stdout = str::from_utf8(&output.stdout).unwrap_or(""); let stderr = str::from_utf8(&output.stderr).unwrap_or(""); - 
info!("{}", stdout); + info!("{stdout}"); assert!( output.status.success(), diff --git a/datafusion/core/tests/memory_limit/memory_limit_validation/utils.rs b/datafusion/core/tests/memory_limit/memory_limit_validation/utils.rs index 7b157b707a6de..2c9fae20c8606 100644 --- a/datafusion/core/tests/memory_limit/memory_limit_validation/utils.rs +++ b/datafusion/core/tests/memory_limit/memory_limit_validation/utils.rs @@ -16,16 +16,14 @@ // under the License. use datafusion_common_runtime::SpawnedTask; -use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; use sysinfo::{ProcessRefreshKind, ProcessesToUpdate, System}; -use tokio::time::{interval, Duration}; +use tokio::time::{Duration, interval}; use datafusion::prelude::{SessionConfig, SessionContext}; -use datafusion_execution::{ - memory_pool::{human_readable_size, FairSpillPool}, - runtime_env::RuntimeEnvBuilder, -}; +use datafusion_common::human_readable_size; +use datafusion_execution::{memory_pool::FairSpillPool, runtime_env::RuntimeEnvBuilder}; /// Measures the maximum RSS (in bytes) during the execution of an async task. RSS /// will be sampled every 7ms. 
@@ -40,7 +38,7 @@ use datafusion_execution::{ async fn measure_max_rss(f: F) -> (T, usize) where F: FnOnce() -> Fut, - Fut: std::future::Future, + Fut: Future, { // Initialize system information let mut system = System::new_all(); diff --git a/datafusion/core/tests/memory_limit/mod.rs b/datafusion/core/tests/memory_limit/mod.rs index 7695cc0969d87..ff8c512cbd22e 100644 --- a/datafusion/core/tests/memory_limit/mod.rs +++ b/datafusion/core/tests/memory_limit/mod.rs @@ -23,11 +23,13 @@ use std::sync::{Arc, LazyLock}; #[cfg(feature = "extended_tests")] mod memory_limit_validation; +mod repartition_mem_limit; use arrow::array::{ArrayRef, DictionaryArray, Int32Array, RecordBatch, StringViewArray}; use arrow::compute::SortOptions; use arrow::datatypes::{Int32Type, SchemaRef}; use arrow_schema::{DataType, Field, Schema}; use datafusion::assert_batches_eq; +use datafusion::config::SpillCompression; use datafusion::datasource::memory::MemorySourceConfig; use datafusion::datasource::source::DataSourceExec; use datafusion::datasource::{MemTable, TableProvider}; @@ -37,19 +39,19 @@ use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion::physical_plan::streaming::PartitionStream; use datafusion::physical_plan::{ExecutionPlan, SendableRecordBatchStream}; use datafusion::prelude::{SessionConfig, SessionContext}; -use datafusion_catalog::streaming::StreamingTable; use datafusion_catalog::Session; -use datafusion_common::{assert_contains, Result}; +use datafusion_catalog::streaming::StreamingTable; +use datafusion_common::{Result, assert_contains}; +use datafusion_execution::TaskContext; use datafusion_execution::disk_manager::{DiskManagerBuilder, DiskManagerMode}; use datafusion_execution::memory_pool::{ FairSpillPool, GreedyMemoryPool, MemoryPool, TrackConsumersPool, }; use datafusion_execution::runtime_env::RuntimeEnv; -use datafusion_execution::TaskContext; use datafusion_expr::{Expr, TableType}; use datafusion_physical_expr::{LexOrdering, 
PhysicalSortExpr}; -use datafusion_physical_optimizer::join_selection::JoinSelection; use datafusion_physical_optimizer::PhysicalOptimizerRule; +use datafusion_physical_optimizer::join_selection::JoinSelection; use datafusion_physical_plan::collect as collect_batches; use datafusion_physical_plan::common::collect; use datafusion_physical_plan::spill::get_record_batch_memory_size; @@ -84,7 +86,8 @@ async fn group_by_none() { TestCase::new() .with_query("select median(request_bytes) from t") .with_expected_errors(vec![ - "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as:\n AggregateStream" + "Resources exhausted: Additional allocation failed", + "with top memory consumers (across reservations) as:\n AggregateStream", ]) .with_memory_limit(2_000) .run() @@ -96,7 +99,7 @@ async fn group_by_row_hash() { TestCase::new() .with_query("select count(*) from t GROUP BY response_bytes") .with_expected_errors(vec![ - "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as:\n GroupedHashAggregateStream" + "Resources exhausted: Additional allocation failed", "with top memory consumers (across reservations) as:\n GroupedHashAggregateStream" ]) .with_memory_limit(2_000) .run() @@ -109,7 +112,7 @@ async fn group_by_hash() { // group by dict column .with_query("select count(*) from t GROUP BY service, host, pod, container") .with_expected_errors(vec![ - "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as:\n GroupedHashAggregateStream" + "Resources exhausted: Additional allocation failed", "with top memory consumers (across reservations) as:\n GroupedHashAggregateStream" ]) .with_memory_limit(1_000) .run() @@ -122,7 +125,8 @@ async fn join_by_key_multiple_partitions() { TestCase::new() .with_query("select t1.* from t t1 JOIN t t2 ON t1.service = t2.service") .with_expected_errors(vec![ - "Resources exhausted: Additional allocation failed 
with top memory consumers (across reservations) as:\n HashJoinInput", + "Resources exhausted: Additional allocation failed", + "with top memory consumers (across reservations) as:\n HashJoinInput", ]) .with_memory_limit(1_000) .with_config(config) @@ -136,7 +140,8 @@ async fn join_by_key_single_partition() { TestCase::new() .with_query("select t1.* from t t1 JOIN t t2 ON t1.service = t2.service") .with_expected_errors(vec![ - "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as:\n HashJoinInput", + "Resources exhausted: Additional allocation failed", + "with top memory consumers (across reservations) as:\n HashJoinInput", ]) .with_memory_limit(1_000) .with_config(config) @@ -149,7 +154,7 @@ async fn join_by_expression() { TestCase::new() .with_query("select t1.* from t t1 JOIN t t2 ON t1.service != t2.service") .with_expected_errors(vec![ - "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as:\n NestedLoopJoinLoad[0]", + "Resources exhausted: Additional allocation failed", "with top memory consumers (across reservations) as:\n NestedLoopJoinLoad[0]", ]) .with_memory_limit(1_000) .run() @@ -161,7 +166,8 @@ async fn cross_join() { TestCase::new() .with_query("select t1.*, t2.* from t t1 CROSS JOIN t t2") .with_expected_errors(vec![ - "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as:\n CrossJoinExec", + "Resources exhausted: Additional allocation failed", + "with top memory consumers (across reservations) as:\n CrossJoinExec", ]) .with_memory_limit(1_000) .run() @@ -217,7 +223,7 @@ async fn symmetric_hash_join() { "select t1.* from t t1 JOIN t t2 ON t1.pod = t2.pod AND t1.time = t2.time", ) .with_expected_errors(vec![ - "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as:\n SymmetricHashJoinStream", + "Resources exhausted: Additional allocation failed", "with top memory 
consumers (across reservations) as:\n SymmetricHashJoinStream", ]) .with_memory_limit(1_000) .with_scenario(Scenario::AccessLogStreaming) @@ -235,7 +241,7 @@ async fn sort_preserving_merge() { // so only a merge is needed .with_query("select * from t ORDER BY a ASC NULLS LAST, b ASC NULLS LAST LIMIT 10") .with_expected_errors(vec![ - "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as:\n SortPreservingMergeExec", + "Resources exhausted: Additional allocation failed", "with top memory consumers (across reservations) as:\n SortPreservingMergeExec", ]) // provide insufficient memory to merge .with_memory_limit(partition_size / 2) @@ -314,7 +320,8 @@ async fn sort_spill_reservation() { test.clone() .with_expected_errors(vec![ - "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as:", + "Resources exhausted: Additional allocation failed", + "with top memory consumers (across reservations) as:", "B for ExternalSorterMerge", ]) .with_config(config) @@ -344,7 +351,8 @@ async fn oom_recursive_cte() { SELECT * FROM nodes;", ) .with_expected_errors(vec![ - "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as:\n RecursiveQuery", + "Resources exhausted: Additional allocation failed", + "with top memory consumers (across reservations) as:\n RecursiveQuery", ]) .with_memory_limit(2_000) .run() @@ -396,7 +404,7 @@ async fn oom_with_tracked_consumer_pool() { .with_expected_errors(vec![ "Failed to allocate additional", "for ParquetSink(ArrowColumnWriter)", - "Additional allocation failed with top memory consumers (across reservations) as:\n ParquetSink(ArrowColumnWriter)" + "Additional allocation failed", "with top memory consumers (across reservations) as:\n ParquetSink(ArrowColumnWriter)" ]) .with_memory_pool(Arc::new( TrackConsumersPool::new( @@ -545,10 +553,11 @@ async fn test_external_sort_zero_merge_reservation() { // Tests for disk 
limit (`max_temp_directory_size` in `DiskManager`) // ------------------------------------------------------------------ -// Create a new `SessionContext` with speicified disk limit and memory pool limit +// Create a new `SessionContext` with specified disk limit, memory pool limit, and spill compression codec async fn setup_context( disk_limit: u64, memory_pool_limit: usize, + spill_compression: SpillCompression, ) -> Result { let disk_manager = DiskManagerBuilder::default() .with_mode(DiskManagerMode::OsTmpDirectory) @@ -565,11 +574,16 @@ async fn setup_context( disk_manager: Arc::new(disk_manager), cache_manager: runtime.cache_manager.clone(), object_store_registry: runtime.object_store_registry.clone(), + #[cfg(feature = "parquet_encryption")] + parquet_encryption_factory_registry: runtime + .parquet_encryption_factory_registry + .clone(), }); let config = SessionConfig::new() .with_sort_spill_reservation_bytes(64 * 1024) // 256KB .with_sort_in_place_threshold_bytes(0) + .with_spill_compression(spill_compression) .with_batch_size(64) // To reduce test memory usage .with_target_partitions(1); @@ -580,18 +594,24 @@ async fn setup_context( /// (specified by `max_temp_directory_size` in `DiskManager`) #[tokio::test] async fn test_disk_spill_limit_reached() -> Result<()> { - let ctx = setup_context(1024 * 1024, 1024 * 1024).await?; // 1MB disk limit, 1MB memory limit + let spill_compression = SpillCompression::Uncompressed; + let ctx = setup_context(1024 * 1024, 1024 * 1024, spill_compression).await?; // 1MB disk limit, 1MB memory limit let df = ctx .sql("select * from generate_series(1, 1000000000000) as t1(v1) order by v1") .await .unwrap(); - let err = df.collect().await.unwrap_err(); - assert_contains!( - err.to_string(), - "The used disk space during the spilling process has exceeded the allowable limit" - ); + let error_message = df.collect().await.unwrap_err().to_string(); + for expected in [ + "The used disk space during the spilling process has exceeded the 
allowable limit", + "datafusion.runtime.max_temp_directory_size", + ] { + assert!( + error_message.contains(expected), + "'{expected}' is not contained by '{error_message}'" + ); + } Ok(()) } @@ -602,7 +622,8 @@ async fn test_disk_spill_limit_reached() -> Result<()> { #[tokio::test] async fn test_disk_spill_limit_not_reached() -> Result<()> { let disk_spill_limit = 1024 * 1024; // 1MB - let ctx = setup_context(disk_spill_limit, 128 * 1024).await?; // 1MB disk limit, 128KB memory limit + let spill_compression = SpillCompression::Uncompressed; + let ctx = setup_context(disk_spill_limit, 128 * 1024, spill_compression).await?; // 1MB disk limit, 128KB memory limit let df = ctx .sql("select * from generate_series(1, 10000) as t1(v1) order by v1") @@ -630,6 +651,77 @@ async fn test_disk_spill_limit_not_reached() -> Result<()> { Ok(()) } +/// External query should succeed using zstd as spill compression codec and +/// and all temporary spill files are properly cleaned up after execution. +/// Note: This test does not inspect file contents (e.g. magic number), +/// as spill files are automatically deleted on drop. 
+#[tokio::test] +async fn test_spill_file_compressed_with_zstd() -> Result<()> { + let disk_spill_limit = 1024 * 1024; // 1MB + let spill_compression = SpillCompression::Zstd; + let ctx = setup_context(disk_spill_limit, 128 * 1024, spill_compression).await?; // 1MB disk limit, 128KB memory limit, zstd + + let df = ctx + .sql("select * from generate_series(1, 100000) as t1(v1) order by v1") + .await + .unwrap(); + let plan = df.create_physical_plan().await.unwrap(); + + let task_ctx = ctx.task_ctx(); + let _ = collect_batches(Arc::clone(&plan), task_ctx) + .await + .expect("Query execution failed"); + + let spill_count = plan.metrics().unwrap().spill_count().unwrap(); + let spilled_bytes = plan.metrics().unwrap().spilled_bytes().unwrap(); + + println!("spill count {spill_count}"); + assert!(spill_count > 0); + assert!((spilled_bytes as u64) < disk_spill_limit); + + // Verify that all temporary files have been properly cleaned up by checking + // that the total disk usage tracked by the disk manager is zero + let current_disk_usage = ctx.runtime_env().disk_manager.used_disk_space(); + assert_eq!(current_disk_usage, 0); + + Ok(()) +} + +/// External query should succeed using lz4_frame as spill compression codec and +/// and all temporary spill files are properly cleaned up after execution. +/// Note: This test does not inspect file contents (e.g. magic number), +/// as spill files are automatically deleted on drop. 
+#[tokio::test] +async fn test_spill_file_compressed_with_lz4_frame() -> Result<()> { + let disk_spill_limit = 1024 * 1024; // 1MB + let spill_compression = SpillCompression::Lz4Frame; + let ctx = setup_context(disk_spill_limit, 128 * 1024, spill_compression).await?; // 1MB disk limit, 128KB memory limit, lz4_frame + + let df = ctx + .sql("select * from generate_series(1, 100000) as t1(v1) order by v1") + .await + .unwrap(); + let plan = df.create_physical_plan().await.unwrap(); + + let task_ctx = ctx.task_ctx(); + let _ = collect_batches(Arc::clone(&plan), task_ctx) + .await + .expect("Query execution failed"); + + let spill_count = plan.metrics().unwrap().spill_count().unwrap(); + let spilled_bytes = plan.metrics().unwrap().spilled_bytes().unwrap(); + + println!("spill count {spill_count}"); + assert!(spill_count > 0); + assert!((spilled_bytes as u64) < disk_spill_limit); + + // Verify that all temporary files have been properly cleaned up by checking + // that the total disk usage tracked by the disk manager is zero + let current_disk_usage = ctx.runtime_env().disk_manager.used_disk_space(); + assert_eq!(current_disk_usage, 0); + + Ok(()) +} /// Run the query with the specified memory limit, /// and verifies the expected errors are returned #[derive(Clone, Debug)] @@ -726,7 +818,7 @@ impl TestCase { /// Specify an expected plan to review pub fn with_expected_plan(mut self, expected_plan: &[&str]) -> Self { - self.expected_plan = expected_plan.iter().map(|s| s.to_string()).collect(); + self.expected_plan = expected_plan.iter().map(|s| (*s).to_string()).collect(); self } @@ -890,16 +982,13 @@ impl Scenario { descending: false, nulls_first: false, }; - let sort_information = vec![LexOrdering::new(vec![ - PhysicalSortExpr { - expr: col("a", &schema).unwrap(), - options, - }, - PhysicalSortExpr { - expr: col("b", &schema).unwrap(), - options, - }, - ])]; + let sort_information = vec![ + [ + PhysicalSortExpr::new(col("a", &schema).unwrap(), options), + 
PhysicalSortExpr::new(col("b", &schema).unwrap(), options), + ] + .into(), + ]; let table = SortedTableProvider::new(batches, sort_information); Arc::new(table) @@ -975,7 +1064,7 @@ fn make_dict_batches() -> Vec { let batch_size = 50; let mut i = 0; - let gen = std::iter::from_fn(move || { + let batch_gen = std::iter::from_fn(move || { // create values like // 0000000001 // 0000000002 @@ -998,7 +1087,7 @@ fn make_dict_batches() -> Vec { let num_batches = 5; - let batches: Vec<_> = gen.take(num_batches).collect(); + let batches: Vec<_> = batch_gen.take(num_batches).collect(); batches.iter().enumerate().for_each(|(i, batch)| { println!("Dict batch[{i}] size is: {}", batch.get_array_memory_size()); @@ -1013,9 +1102,9 @@ fn batches_byte_size(batches: &[RecordBatch]) -> usize { } #[derive(Debug)] -struct DummyStreamPartition { - schema: SchemaRef, - batches: Vec, +pub(crate) struct DummyStreamPartition { + pub(crate) schema: SchemaRef, + pub(crate) batches: Vec, } impl PartitionStream for DummyStreamPartition { diff --git a/datafusion/core/tests/memory_limit/repartition_mem_limit.rs b/datafusion/core/tests/memory_limit/repartition_mem_limit.rs new file mode 100644 index 0000000000000..b21bffebaf95e --- /dev/null +++ b/datafusion/core/tests/memory_limit/repartition_mem_limit.rs @@ -0,0 +1,119 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::array::{ArrayRef, Int32Array, RecordBatch}; +use datafusion::{ + assert_batches_sorted_eq, + prelude::{SessionConfig, SessionContext}, +}; +use datafusion_catalog::MemTable; +use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_execution::runtime_env::RuntimeEnvBuilder; +use datafusion_physical_plan::{ExecutionPlanProperties, repartition::RepartitionExec}; +use futures::TryStreamExt; +use itertools::Itertools; + +/// End to end test for spilling in RepartitionExec. +/// The idea is to make a real world query with a relatively low memory limit and +/// then drive one partition at a time, simulating dissimilar execution speed in partitions. +/// Just as some examples of real world scenarios where this can happen consider +/// lopsided groups in a group by especially if one partitions spills and others don't, +/// or in distributed systems if one upstream node is slower than others. 
+#[tokio::test] +async fn test_repartition_memory_limit() { + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(1024 * 1024, 1.0) + .build() + .unwrap(); + let config = SessionConfig::new() + .with_batch_size(32) + .with_target_partitions(2); + let ctx = SessionContext::new_with_config_rt(config, Arc::new(runtime)); + let batches = vec![ + RecordBatch::try_from_iter(vec![( + "c1", + Arc::new(Int32Array::from_iter_values((0..10).cycle().take(100_000))) + as ArrayRef, + )]) + .unwrap(), + ]; + let table = Arc::new(MemTable::try_new(batches[0].schema(), vec![batches]).unwrap()); + ctx.register_table("t", table).unwrap(); + let plan = ctx + .state() + .create_logical_plan("SELECT c1, count(*) as c FROM t GROUP BY c1;") + .await + .unwrap(); + let plan = ctx.state().create_physical_plan(&plan).await.unwrap(); + assert_eq!(plan.output_partitioning().partition_count(), 2); + // Execute partition 0, this should cause items going into the rest of the partitions to queue up and because + // of the low memory limit should spill to disk. 
+ let batches0 = Arc::clone(&plan) + .execute(0, ctx.task_ctx()) + .unwrap() + .try_collect::>() + .await + .unwrap(); + + let mut metrics = None; + Arc::clone(&plan) + .transform_down(|node| { + if node.as_any().is::() { + metrics = node.metrics(); + } + Ok(Transformed::no(node)) + }) + .unwrap(); + + let metrics = metrics.unwrap(); + assert!(metrics.spilled_bytes().unwrap() > 0); + assert!(metrics.spilled_rows().unwrap() > 0); + assert!(metrics.spill_count().unwrap() > 0); + + // Execute the other partition + let batches1 = Arc::clone(&plan) + .execute(1, ctx.task_ctx()) + .unwrap() + .try_collect::>() + .await + .unwrap(); + + let all_batches = batches0 + .into_iter() + .chain(batches1.into_iter()) + .collect_vec(); + #[rustfmt::skip] + let expected = &[ + "+----+-------+", + "| c1 | c |", + "+----+-------+", + "| 0 | 10000 |", + "| 1 | 10000 |", + "| 2 | 10000 |", + "| 3 | 10000 |", + "| 4 | 10000 |", + "| 5 | 10000 |", + "| 6 | 10000 |", + "| 7 | 10000 |", + "| 8 | 10000 |", + "| 9 | 10000 |", + "+----+-------+", + ]; + assert_batches_sorted_eq!(expected, &all_batches); +} diff --git a/datafusion/core/tests/optimizer/mod.rs b/datafusion/core/tests/optimizer/mod.rs index 2daed4fe36bbe..6466e9ad96d17 100644 --- a/datafusion/core/tests/optimizer/mod.rs +++ b/datafusion/core/tests/optimizer/mod.rs @@ -18,6 +18,7 @@ //! Tests for the DataFusion SQL query planner that require functions from the //! datafusion-functions crate. 
+use insta::assert_snapshot; use std::any::Any; use std::collections::HashMap; use std::sync::Arc; @@ -26,17 +27,16 @@ use arrow::datatypes::{ DataType, Field, Fields, Schema, SchemaBuilder, SchemaRef, TimeUnit, }; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{TransformedResult, TreeNode}; -use datafusion_common::{plan_err, DFSchema, Result, ScalarValue, TableReference}; +use datafusion_common::tree_node::TransformedResult; +use datafusion_common::{DFSchema, Result, ScalarValue, TableReference, plan_err}; use datafusion_expr::interval_arithmetic::{Interval, NullableInterval}; use datafusion_expr::{ - col, lit, AggregateUDF, BinaryExpr, Expr, ExprSchemable, LogicalPlan, Operator, - ScalarUDF, TableSource, WindowUDF, + AggregateUDF, BinaryExpr, Expr, ExprSchemable, LogicalPlan, Operator, ScalarUDF, + TableSource, WindowUDF, col, lit, }; use datafusion_functions::core::expr_ext::FieldAccessor; use datafusion_optimizer::analyzer::Analyzer; use datafusion_optimizer::optimizer::Optimizer; -use datafusion_optimizer::simplify_expressions::GuaranteeRewriter; use datafusion_optimizer::{OptimizerConfig, OptimizerContext}; use datafusion_sql::planner::{ContextProvider, SqlToRel}; use datafusion_sql::sqlparser::ast::Statement; @@ -44,6 +44,7 @@ use datafusion_sql::sqlparser::dialect::GenericDialect; use datafusion_sql::sqlparser::parser::Parser; use chrono::DateTime; +use datafusion_expr::expr_rewriter::rewrite_with_guarantees; use datafusion_functions::datetime; #[cfg(test)] @@ -56,9 +57,14 @@ fn init() { #[test] fn select_arrow_cast() { let sql = "SELECT arrow_cast(1234, 'Float64') as f64, arrow_cast('foo', 'LargeUtf8') as large"; - let expected = "Projection: Float64(1234) AS f64, LargeUtf8(\"foo\") AS large\ - \n EmptyRelation"; - quick_test(sql, expected); + let plan = test_sql(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: Float64(1234) AS f64, LargeUtf8("foo") AS large + EmptyRelation: rows=1 + "# + ); } #[test] fn 
timestamp_nano_ts_none_predicates() -> Result<()> { @@ -68,11 +74,15 @@ fn timestamp_nano_ts_none_predicates() -> Result<()> { // a scan should have the now()... predicate folded to a single // constant and compared to the column without a cast so it can be // pushed down / pruned - let expected = - "Projection: test.col_int32\ - \n Filter: test.col_ts_nano_none < TimestampNanosecond(1666612093000000000, None)\ - \n TableScan: test projection=[col_int32, col_ts_nano_none]"; - quick_test(sql, expected); + let plan = test_sql(sql).unwrap(); + assert_snapshot!( + plan, + @r" + Projection: test.col_int32 + Filter: test.col_ts_nano_none < TimestampNanosecond(1666612093000000000, None) + TableScan: test projection=[col_int32, col_ts_nano_none] + " + ); Ok(()) } @@ -84,10 +94,15 @@ fn timestamp_nano_ts_utc_predicates() { // a scan should have the now()... predicate folded to a single // constant and compared to the column without a cast so it can be // pushed down / pruned - let expected = - "Projection: test.col_int32\n Filter: test.col_ts_nano_utc < TimestampNanosecond(1666612093000000000, Some(\"+00:00\"))\ - \n TableScan: test projection=[col_int32, col_ts_nano_utc]"; - quick_test(sql, expected); + let plan = test_sql(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: test.col_int32 + Filter: test.col_ts_nano_utc < TimestampNanosecond(1666612093000000000, Some("+00:00")) + TableScan: test projection=[col_int32, col_ts_nano_utc] + "# + ); } #[test] @@ -95,10 +110,14 @@ fn concat_literals() -> Result<()> { let sql = "SELECT concat(true, col_int32, false, null, 'hello', col_utf8, 12, 3.4) \ AS col FROM test"; - let expected = - "Projection: concat(Utf8(\"true\"), CAST(test.col_int32 AS Utf8), Utf8(\"falsehello\"), test.col_utf8, Utf8(\"123.4\")) AS col\ - \n TableScan: test projection=[col_int32, col_utf8]"; - quick_test(sql, expected); + let plan = test_sql(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: concat(Utf8("true"), 
CAST(test.col_int32 AS Utf8), Utf8("falsehello"), test.col_utf8, Utf8("123.4")) AS col + TableScan: test projection=[col_int32, col_utf8] + "# + ); Ok(()) } @@ -107,16 +126,15 @@ fn concat_ws_literals() -> Result<()> { let sql = "SELECT concat_ws('-', true, col_int32, false, null, 'hello', col_utf8, 12, '', 3.4) \ AS col FROM test"; - let expected = - "Projection: concat_ws(Utf8(\"-\"), Utf8(\"true\"), CAST(test.col_int32 AS Utf8), Utf8(\"false-hello\"), test.col_utf8, Utf8(\"12--3.4\")) AS col\ - \n TableScan: test projection=[col_int32, col_utf8]"; - quick_test(sql, expected); - Ok(()) -} - -fn quick_test(sql: &str, expected_plan: &str) { let plan = test_sql(sql).unwrap(); - assert_eq!(expected_plan, format!("{plan}")); + assert_snapshot!( + plan, + @r#" + Projection: concat_ws(Utf8("-"), Utf8("true"), CAST(test.col_int32 AS Utf8), Utf8("false-hello"), test.col_utf8, Utf8("12--3.4")) AS col + TableScan: test projection=[col_int32, col_utf8] + "# + ); + Ok(()) } fn test_sql(sql: &str) -> Result { @@ -126,8 +144,9 @@ fn test_sql(sql: &str) -> Result { let statement = &ast[0]; // create a logical query plan + let config = ConfigOptions::default(); let context_provider = MyContextProvider::default() - .with_udf(datetime::now()) + .with_udf(datetime::now(&config)) .with_udf(datafusion_functions::core::arrow_cast()) .with_udf(datafusion_functions::string::concat()) .with_udf(datafusion_functions::string::concat_ws()); @@ -142,7 +161,7 @@ fn test_sql(sql: &str) -> Result { let analyzer = Analyzer::new(); let optimizer = Optimizer::new(); // analyze and optimize the logical plan - let plan = analyzer.execute_and_check(plan, config.options(), |_, _| {})?; + let plan = analyzer.execute_and_check(plan, &config.options(), |_, _| {})?; optimizer.optimize(plan, &config, |_, _| {}) } @@ -268,7 +287,7 @@ fn test_nested_schema_nullability() { #[test] fn test_inequalities_non_null_bounded() { - let guarantees = vec![ + let guarantees = [ // x ∈ [1, 3] (not null) ( col("x"), @@ 
-285,8 +304,6 @@ fn test_inequalities_non_null_bounded() { ), ]; - let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); - // (original_expr, expected_simplification) let simplified_cases = &[ (col("x").lt(lit(0)), false), @@ -318,7 +335,7 @@ fn test_inequalities_non_null_bounded() { ), ]; - validate_simplified_cases(&mut rewriter, simplified_cases); + validate_simplified_cases(&guarantees, simplified_cases); let unchanged_cases = &[ col("x").gt(lit(2)), @@ -329,16 +346,20 @@ fn test_inequalities_non_null_bounded() { col("x").not_between(lit(3), lit(10)), ]; - validate_unchanged_cases(&mut rewriter, unchanged_cases); + validate_unchanged_cases(&guarantees, unchanged_cases); } -fn validate_simplified_cases(rewriter: &mut GuaranteeRewriter, cases: &[(Expr, T)]) -where +fn validate_simplified_cases( + guarantees: &[(Expr, NullableInterval)], + cases: &[(Expr, T)], +) where ScalarValue: From, T: Clone, { for (expr, expected_value) in cases { - let output = expr.clone().rewrite(rewriter).data().unwrap(); + let output = rewrite_with_guarantees(expr.clone(), guarantees) + .data() + .unwrap(); let expected = lit(ScalarValue::from(expected_value.clone())); assert_eq!( output, expected, @@ -346,9 +367,11 @@ where ); } } -fn validate_unchanged_cases(rewriter: &mut GuaranteeRewriter, cases: &[Expr]) { +fn validate_unchanged_cases(guarantees: &[(Expr, NullableInterval)], cases: &[Expr]) { for expr in cases { - let output = expr.clone().rewrite(rewriter).data().unwrap(); + let output = rewrite_with_guarantees(expr.clone(), guarantees) + .data() + .unwrap(); assert_eq!( &output, expr, "{expr} was simplified to {output}, but expected it to be unchanged" diff --git a/datafusion/core/tests/parquet/custom_reader.rs b/datafusion/core/tests/parquet/custom_reader.rs index 761a78a29fd3a..ae11fa9a11334 100644 --- a/datafusion/core/tests/parquet/custom_reader.rs +++ b/datafusion/core/tests/parquet/custom_reader.rs @@ -20,33 +20,33 @@ use std::ops::Range; use std::sync::Arc; use 
std::time::SystemTime; -use arrow::array::{ArrayRef, Int64Array, Int8Array, StringArray}; +use arrow::array::{ArrayRef, Int8Array, Int64Array, StringArray}; use arrow::datatypes::{Field, Schema, SchemaBuilder}; use arrow::record_batch::RecordBatch; -use datafusion::datasource::file_format::parquet::fetch_parquet_metadata; use datafusion::datasource::listing::PartitionedFile; use datafusion::datasource::object_store::ObjectStoreUrl; use datafusion::datasource::physical_plan::{ - FileMeta, ParquetFileMetrics, ParquetFileReaderFactory, ParquetSource, + ParquetFileMetrics, ParquetFileReaderFactory, ParquetSource, }; use datafusion::physical_plan::collect; use datafusion::physical_plan::metrics::ExecutionPlanMetricsSet; use datafusion::prelude::SessionContext; -use datafusion_common::test_util::batches_to_sort_string; use datafusion_common::Result; +use datafusion_common::test_util::batches_to_sort_string; use bytes::Bytes; use datafusion_datasource::file_scan_config::FileScanConfigBuilder; use datafusion_datasource::source::DataSourceExec; +use datafusion_datasource_parquet::metadata::DFParquetMetadata; use futures::future::BoxFuture; use futures::{FutureExt, TryFutureExt}; use insta::assert_snapshot; use object_store::memory::InMemory; use object_store::path::Path; -use object_store::{ObjectMeta, ObjectStore}; +use object_store::{ObjectMeta, ObjectStore, ObjectStoreExt}; +use parquet::arrow::ArrowWriter; use parquet::arrow::arrow_reader::ArrowReaderOptions; use parquet::arrow::async_reader::AsyncFileReader; -use parquet::arrow::ArrowWriter; use parquet::errors::ParquetError; use parquet::file::metadata::ParquetMetaData; @@ -69,18 +69,14 @@ async fn route_data_access_ops_to_parquet_file_reader_factory() { store_parquet_in_memory(vec![batch]).await; let file_group = parquet_files_meta .into_iter() - .map(|meta| PartitionedFile { - object_meta: meta, - partition_values: vec![], - range: None, - statistics: None, - extensions: 
Some(Arc::new(String::from(EXPECTED_USER_DEFINED_METADATA))), - metadata_size_hint: None, + .map(|meta| { + PartitionedFile::new_from_meta(meta) + .with_extensions(Arc::new(String::from(EXPECTED_USER_DEFINED_METADATA))) }) .collect(); let source = Arc::new( - ParquetSource::default() + ParquetSource::new(file_schema.clone()) // prepare the scan .with_parquet_file_reader_factory(Arc::new( InMemoryParquetFileReaderFactory(Arc::clone(&in_memory_object_store)), @@ -89,7 +85,6 @@ async fn route_data_access_ops_to_parquet_file_reader_factory() { let base_config = FileScanConfigBuilder::new( // just any url that doesn't point to in memory object store ObjectStoreUrl::local_filesystem(), - file_schema, source, ) .with_file_group(file_group) @@ -119,11 +114,11 @@ impl ParquetFileReaderFactory for InMemoryParquetFileReaderFactory { fn create_reader( &self, partition_index: usize, - file_meta: FileMeta, + partitioned_file: PartitionedFile, metadata_size_hint: Option, metrics: &ExecutionPlanMetricsSet, ) -> Result> { - let metadata = file_meta + let metadata = partitioned_file .extensions .as_ref() .expect("has user defined metadata"); @@ -135,13 +130,13 @@ impl ParquetFileReaderFactory for InMemoryParquetFileReaderFactory { let parquet_file_metrics = ParquetFileMetrics::new( partition_index, - file_meta.location().as_ref(), + partitioned_file.object_meta.location.as_ref(), metrics, ); Ok(Box::new(ParquetFileReader { store: Arc::clone(&self.0), - meta: file_meta.object_meta, + meta: partitioned_file.object_meta, metrics: parquet_file_metrics, metadata_size_hint, })) @@ -237,18 +232,16 @@ impl AsyncFileReader for ParquetFileReader { _options: Option<&ArrowReaderOptions>, ) -> BoxFuture<'_, parquet::errors::Result>> { Box::pin(async move { - let metadata = fetch_parquet_metadata( - self.store.as_ref(), - &self.meta, - self.metadata_size_hint, - ) - .await - .map_err(|e| { - ParquetError::General(format!( - "AsyncChunkReader::get_metadata error: {e}" - )) - })?; - 
Ok(Arc::new(metadata)) + let metadata = DFParquetMetadata::new(self.store.as_ref(), &self.meta) + .with_metadata_size_hint(self.metadata_size_hint) + .fetch_metadata() + .await + .map_err(|e| { + ParquetError::General(format!( + "AsyncChunkReader::get_metadata error: {e}" + )) + })?; + Ok(metadata) }) } } diff --git a/datafusion/core/tests/parquet/encryption.rs b/datafusion/core/tests/parquet/encryption.rs new file mode 100644 index 0000000000000..8b3170e367457 --- /dev/null +++ b/datafusion/core/tests/parquet/encryption.rs @@ -0,0 +1,370 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Tests for reading and writing Parquet files that use Parquet modular encryption + +use arrow::array::{ArrayRef, Int32Array, StringArray}; +use arrow::record_batch::RecordBatch; +use arrow_schema::{DataType, SchemaRef}; +use async_trait::async_trait; +use datafusion::dataframe::DataFrameWriteOptions; +use datafusion::datasource::listing::ListingOptions; +use datafusion::prelude::{ParquetReadOptions, SessionContext}; +use datafusion_common::config::{EncryptionFactoryOptions, TableParquetOptions}; +use datafusion_common::{DataFusionError, assert_batches_sorted_eq, exec_datafusion_err}; +use datafusion_datasource_parquet::ParquetFormat; +use datafusion_execution::parquet_encryption::EncryptionFactory; +use parquet::arrow::ArrowWriter; +use parquet::arrow::arrow_reader::{ArrowReaderMetadata, ArrowReaderOptions}; +use parquet::encryption::decrypt::FileDecryptionProperties; +use parquet::encryption::encrypt::FileEncryptionProperties; +use parquet::file::column_crypto_metadata::ColumnCryptoMetaData; +use parquet::file::properties::WriterProperties; +use std::collections::HashMap; +use std::fs::File; +use std::path::{Path, PathBuf}; +use std::sync::atomic::{AtomicU8, Ordering}; +use std::sync::{Arc, Mutex}; +use tempfile::TempDir; + +async fn read_parquet_test_data<'a, T: Into>( + path: T, + ctx: &SessionContext, + options: ParquetReadOptions<'a>, +) -> Vec { + ctx.read_parquet(path.into(), options) + .await + .unwrap() + .collect() + .await + .unwrap() +} + +#[expect(clippy::needless_pass_by_value)] +pub fn write_batches( + path: PathBuf, + props: WriterProperties, + batches: impl IntoIterator, +) -> datafusion_common::Result { + let mut batches = batches.into_iter(); + let first_batch = batches.next().expect("need at least one record batch"); + let schema = first_batch.schema(); + + let file = File::create(&path)?; + let mut writer = ArrowWriter::try_new(file, Arc::clone(&schema), Some(props))?; + + writer.write(&first_batch)?; + let mut num_rows = first_batch.num_rows(); 
+ + for batch in batches { + writer.write(&batch)?; + num_rows += batch.num_rows(); + } + writer.close()?; + Ok(num_rows) +} + +#[tokio::test] +async fn round_trip_encryption() { + let ctx: SessionContext = SessionContext::new(); + + let options = ParquetReadOptions::default(); + let batches = read_parquet_test_data( + "tests/data/filter_pushdown/single_file.gz.parquet", + &ctx, + options, + ) + .await; + + let schema = batches[0].schema(); + let footer_key = b"0123456789012345".to_vec(); // 128bit/16 + let column_key = b"1234567890123450".to_vec(); // 128bit/16 + + let mut encrypt = FileEncryptionProperties::builder(footer_key.clone()); + let mut decrypt = FileDecryptionProperties::builder(footer_key.clone()); + + for field in schema.fields.iter() { + encrypt = encrypt.with_column_key(field.name().as_str(), column_key.clone()); + decrypt = decrypt.with_column_key(field.name().as_str(), column_key.clone()); + } + let encrypt = encrypt.build().unwrap(); + let decrypt = decrypt.build().unwrap(); + + // Write encrypted parquet + let props = WriterProperties::builder() + .with_file_encryption_properties(encrypt) + .build(); + + let tempdir = TempDir::new_in(Path::new(".")).unwrap(); + let tempfile = tempdir.path().join("data.parquet"); + let num_rows_written = write_batches(tempfile.clone(), props, batches).unwrap(); + + // Read encrypted parquet + let ctx: SessionContext = SessionContext::new(); + let options = + ParquetReadOptions::default().file_decryption_properties((&decrypt).into()); + + let encrypted_batches = read_parquet_test_data( + tempfile.into_os_string().into_string().unwrap(), + &ctx, + options, + ) + .await; + + let num_rows_read = encrypted_batches + .iter() + .fold(0, |acc, x| acc + x.num_rows()); + + assert_eq!(num_rows_written, num_rows_read); +} + +#[tokio::test] +async fn round_trip_parquet_with_encryption_factory() { + let ctx = SessionContext::new(); + let encryption_factory = Arc::new(MockEncryptionFactory::default()); + 
ctx.runtime_env().register_parquet_encryption_factory( + "test_encryption_factory", + Arc::clone(&encryption_factory) as Arc, + ); + + let tmpdir = TempDir::new().unwrap(); + + // Register some simple test data + let strings: ArrayRef = + Arc::new(StringArray::from(vec!["a", "b", "c", "a", "b", "c"])); + let x1: ArrayRef = Arc::new(Int32Array::from(vec![1, 10, 11, 100, 101, 111])); + let x2: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5, 6])); + let batch = + RecordBatch::try_from_iter(vec![("string", strings), ("x1", x1), ("x2", x2)]) + .unwrap(); + let test_data_schema = batch.schema(); + ctx.register_batch("test_data", batch).unwrap(); + let df = ctx.table("test_data").await.unwrap(); + + // Write encrypted Parquet, partitioned by string column into separate files + let mut parquet_options = TableParquetOptions::new(); + parquet_options.crypto.factory_id = Some("test_encryption_factory".to_string()); + parquet_options + .crypto + .factory_options + .options + .insert("test_key".to_string(), "test value".to_string()); + + let df_write_options = + DataFrameWriteOptions::default().with_partition_by(vec!["string".to_string()]); + df.write_parquet( + tmpdir.path().to_str().unwrap(), + df_write_options, + Some(parquet_options.clone()), + ) + .await + .unwrap(); + + // Crypto factory should have generated one key per partition file + assert_eq!(encryption_factory.encryption_keys.lock().unwrap().len(), 3); + + verify_table_encrypted(tmpdir.path(), &encryption_factory) + .await + .unwrap(); + + // Registering table without decryption properties should fail + let table_path = format!("file://{}/", tmpdir.path().to_str().unwrap()); + let without_decryption_register = ctx + .register_listing_table( + "parquet_missing_decryption", + &table_path, + ListingOptions::new(Arc::new(ParquetFormat::default())), + None, + None, + ) + .await; + assert!(matches!( + without_decryption_register.unwrap_err(), + DataFusionError::ParquetError(_) + )); + + // Registering table 
succeeds if schema is provided + ctx.register_listing_table( + "parquet_missing_decryption", + &table_path, + ListingOptions::new(Arc::new(ParquetFormat::default())), + Some(test_data_schema), + None, + ) + .await + .unwrap(); + + // But trying to read from the table should fail + let without_decryption_read = ctx + .table("parquet_missing_decryption") + .await + .unwrap() + .collect() + .await; + assert!(matches!( + without_decryption_read.unwrap_err(), + DataFusionError::ParquetError(_) + )); + + // Register table with encryption factory specified + let listing_options = ListingOptions::new(Arc::new( + ParquetFormat::default().with_options(parquet_options), + )) + .with_table_partition_cols(vec![("string".to_string(), DataType::Utf8)]); + ctx.register_listing_table( + "parquet_with_decryption", + &table_path, + listing_options, + None, + None, + ) + .await + .unwrap(); + + // Can read correct data when encryption factory has been specified + let table = ctx + .table("parquet_with_decryption") + .await + .unwrap() + .collect() + .await + .unwrap(); + + let expected = [ + "+-----+----+--------+", + "| x1 | x2 | string |", + "+-----+----+--------+", + "| 1 | 1 | a |", + "| 100 | 4 | a |", + "| 10 | 2 | b |", + "| 101 | 5 | b |", + "| 11 | 3 | c |", + "| 111 | 6 | c |", + "+-----+----+--------+", + ]; + assert_batches_sorted_eq!(expected, &table); +} + +async fn verify_table_encrypted( + table_path: &Path, + encryption_factory: &Arc, +) -> datafusion_common::Result<()> { + let mut directories = vec![table_path.to_path_buf()]; + let mut files_visited = 0; + while let Some(directory) = directories.pop() { + for entry in std::fs::read_dir(&directory)? 
{ + let path = entry?.path(); + if path.is_dir() { + directories.push(path); + } else { + verify_file_encrypted(&path, encryption_factory).await?; + files_visited += 1; + } + } + } + assert!(files_visited > 0); + Ok(()) +} + +async fn verify_file_encrypted( + file_path: &Path, + encryption_factory: &Arc, +) -> datafusion_common::Result<()> { + let mut options = EncryptionFactoryOptions::default(); + options + .options + .insert("test_key".to_string(), "test value".to_string()); + + let file_path_str = if cfg!(target_os = "windows") { + // Windows backslashes are eventually converted to slashes when writing the Parquet files, + // through `ListingTableUrl::parse`, making `encryption_factory.encryption_keys` store them + // in that format. So we also replace backslashes here to ensure they match. + file_path.to_str().unwrap().replace("\\", "/") + } else { + file_path.to_str().unwrap().to_owned() + }; + + let object_path = object_store::path::Path::from(file_path_str); + let decryption_properties = encryption_factory + .get_file_decryption_properties(&options, &object_path) + .await? 
+ .unwrap(); + + let reader_options = + ArrowReaderOptions::new().with_file_decryption_properties(decryption_properties); + let file = File::open(file_path)?; + let reader_metadata = ArrowReaderMetadata::load(&file, reader_options)?; + let metadata = reader_metadata.metadata(); + assert!(metadata.num_row_groups() > 0); + for row_group in metadata.row_groups() { + assert!(row_group.num_columns() > 0); + for col in row_group.columns() { + assert!(matches!( + col.crypto_metadata(), + Some(ColumnCryptoMetaData::ENCRYPTION_WITH_FOOTER_KEY) + )); + } + } + Ok(()) +} + +/// Encryption factory implementation for use in tests, +/// which generates encryption keys in a sequence +#[derive(Debug, Default)] +struct MockEncryptionFactory { + pub encryption_keys: Mutex>>, + pub counter: AtomicU8, +} + +#[async_trait] +impl EncryptionFactory for MockEncryptionFactory { + async fn get_file_encryption_properties( + &self, + config: &EncryptionFactoryOptions, + _schema: &SchemaRef, + file_path: &object_store::path::Path, + ) -> datafusion_common::Result>> { + assert_eq!( + config.options.get("test_key"), + Some(&"test value".to_string()) + ); + let file_idx = self.counter.fetch_add(1, Ordering::Relaxed); + let key = vec![file_idx, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; + let mut keys = self.encryption_keys.lock().unwrap(); + keys.insert(file_path.clone(), key.clone()); + let encryption_properties = FileEncryptionProperties::builder(key).build()?; + Ok(Some(encryption_properties)) + } + + async fn get_file_decryption_properties( + &self, + config: &EncryptionFactoryOptions, + file_path: &object_store::path::Path, + ) -> datafusion_common::Result>> { + assert_eq!( + config.options.get("test_key"), + Some(&"test value".to_string()) + ); + let keys = self.encryption_keys.lock().unwrap(); + let key = keys + .get(file_path) + .ok_or_else(|| exec_datafusion_err!("No key for file {file_path:?}"))?; + let decryption_properties = + 
FileDecryptionProperties::builder(key.clone()).build()?; + Ok(Some(decryption_properties)) + } +} diff --git a/datafusion/core/tests/parquet/expr_adapter.rs b/datafusion/core/tests/parquet/expr_adapter.rs new file mode 100644 index 0000000000000..f412cdf9bd7a6 --- /dev/null +++ b/datafusion/core/tests/parquet/expr_adapter.rs @@ -0,0 +1,608 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::sync::Arc; + +use arrow::array::{ + Array, ArrayRef, BooleanArray, Int32Array, Int64Array, RecordBatch, StringArray, + StructArray, record_batch, +}; +use arrow_schema::{DataType, Field, Fields, Schema, SchemaRef}; +use bytes::{BufMut, BytesMut}; +use datafusion::assert_batches_eq; +use datafusion::common::Result; +use datafusion::datasource::listing::{ + ListingTable, ListingTableConfig, ListingTableConfigExt, +}; +use datafusion::prelude::{SessionConfig, SessionContext}; +use datafusion_common::DataFusionError; +use datafusion_common::ScalarValue; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_datasource::ListingTableUrl; +use datafusion_execution::object_store::ObjectStoreUrl; +use datafusion_physical_expr::PhysicalExpr; +use datafusion_physical_expr::expressions::{self, Column}; +use datafusion_physical_expr_adapter::{ + DefaultPhysicalExprAdapter, DefaultPhysicalExprAdapterFactory, PhysicalExprAdapter, + PhysicalExprAdapterFactory, +}; +use object_store::{ObjectStore, ObjectStoreExt, memory::InMemory, path::Path}; +use parquet::arrow::ArrowWriter; + +async fn write_parquet(batch: RecordBatch, store: Arc, path: &str) { + let mut out = BytesMut::new().writer(); + { + let mut writer = ArrowWriter::try_new(&mut out, batch.schema(), None).unwrap(); + writer.write(&batch).unwrap(); + writer.finish().unwrap(); + } + let data = out.into_inner().freeze(); + store.put(&Path::from(path), data.into()).await.unwrap(); +} + +// Implement a custom PhysicalExprAdapterFactory that fills in missing columns with +// the default value for the field type: +// - Int64 columns are filled with `1` +// - Utf8 columns are filled with `'b'` +#[derive(Debug)] +struct CustomPhysicalExprAdapterFactory; + +impl PhysicalExprAdapterFactory for CustomPhysicalExprAdapterFactory { + fn create( + &self, + logical_file_schema: SchemaRef, + physical_file_schema: SchemaRef, + ) -> Result> { + Ok(Arc::new(CustomPhysicalExprAdapter { + 
logical_file_schema: Arc::clone(&logical_file_schema), + physical_file_schema: Arc::clone(&physical_file_schema), + inner: Arc::new(DefaultPhysicalExprAdapter::new( + logical_file_schema, + physical_file_schema, + )), + })) + } +} + +#[derive(Debug, Clone)] +struct CustomPhysicalExprAdapter { + logical_file_schema: SchemaRef, + physical_file_schema: SchemaRef, + inner: Arc, +} + +impl PhysicalExprAdapter for CustomPhysicalExprAdapter { + fn rewrite(&self, mut expr: Arc) -> Result> { + expr = expr + .transform(|expr| { + if let Some(column) = expr.as_any().downcast_ref::() { + let field_name = column.name(); + if self + .physical_file_schema + .field_with_name(field_name) + .ok() + .is_none() + { + let field = self + .logical_file_schema + .field_with_name(field_name) + .map_err(|_| { + DataFusionError::Plan(format!( + "Field '{field_name}' not found in logical file schema", + )) + })?; + // If the field does not exist, create a default value expression + // Note that we use slightly different logic here to create a default value so that we can see different behavior in tests + let default_value = match field.data_type() { + DataType::Int64 => ScalarValue::Int64(Some(1)), + DataType::Utf8 => ScalarValue::Utf8(Some("b".to_string())), + _ => unimplemented!( + "Unsupported data type: {}", + field.data_type() + ), + }; + return Ok(Transformed::yes(Arc::new( + expressions::Literal::new(default_value), + ))); + } + } + + Ok(Transformed::no(expr)) + }) + .data()?; + self.inner.rewrite(expr) + } +} + +#[tokio::test] +async fn test_custom_schema_adapter_and_custom_expression_adapter() { + let batch = + record_batch!(("extra", Int64, [1, 2, 3]), ("c1", Int32, [1, 2, 3])).unwrap(); + + let store = Arc::new(InMemory::new()) as Arc; + let store_url = ObjectStoreUrl::parse("memory://").unwrap(); + let path = "test.parquet"; + write_parquet(batch, store.clone(), path).await; + + let table_schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Int64, false), + 
Field::new("c2", DataType::Utf8, true), + ])); + + let mut cfg = SessionConfig::new() + // Disable statistics collection for this test otherwise early pruning makes it hard to demonstrate data adaptation + .with_collect_statistics(false) + .with_parquet_pruning(false) + .with_parquet_page_index_pruning(false); + cfg.options_mut().execution.parquet.pushdown_filters = true; + let ctx = SessionContext::new_with_config(cfg); + ctx.register_object_store(store_url.as_ref(), Arc::clone(&store)); + assert!( + !ctx.state() + .config_mut() + .options_mut() + .execution + .collect_statistics + ); + assert!(!ctx.state().config().collect_statistics()); + + // Test with DefaultPhysicalExprAdapterFactory - missing columns are filled with NULL + let listing_table_config = + ListingTableConfig::new(ListingTableUrl::parse("memory:///").unwrap()) + .infer_options(&ctx.state()) + .await + .unwrap() + .with_schema(table_schema.clone()) + .with_expr_adapter_factory(Arc::new(DefaultPhysicalExprAdapterFactory)); + + let table = ListingTable::try_new(listing_table_config).unwrap(); + ctx.register_table("t", Arc::new(table)).unwrap(); + + let batches = ctx + .sql("SELECT c2, c1 FROM t WHERE c1 = 2 AND c2 IS NULL") + .await + .unwrap() + .collect() + .await + .unwrap(); + + let expected = [ + "+----+----+", + "| c2 | c1 |", + "+----+----+", + "| | 2 |", + "+----+----+", + ]; + assert_batches_eq!(expected, &batches); + + // Test with a custom physical expr adapter + // PhysicalExprAdapterFactory now handles both predicates AND projections + // CustomPhysicalExprAdapterFactory fills missing columns with 'b' for Utf8 + let listing_table_config = + ListingTableConfig::new(ListingTableUrl::parse("memory:///").unwrap()) + .infer_options(&ctx.state()) + .await + .unwrap() + .with_schema(table_schema.clone()) + .with_expr_adapter_factory(Arc::new(CustomPhysicalExprAdapterFactory)); + let table = ListingTable::try_new(listing_table_config).unwrap(); + ctx.deregister_table("t").unwrap(); + 
ctx.register_table("t", Arc::new(table)).unwrap(); + let batches = ctx + .sql("SELECT c2, c1 FROM t WHERE c1 = 2 AND c2 = 'b'") + .await + .unwrap() + .collect() + .await + .unwrap(); + // With CustomPhysicalExprAdapterFactory, missing column c2 is filled with 'b' + // in both the predicate (c2 = 'b' becomes 'b' = 'b' -> true) and the projection + let expected = [ + "+----+----+", + "| c2 | c1 |", + "+----+----+", + "| b | 2 |", + "+----+----+", + ]; + assert_batches_eq!(expected, &batches); +} + +/// Test demonstrating how to implement a custom PhysicalExprAdapterFactory +/// that fills missing columns with non-null default values. +/// +/// PhysicalExprAdapterFactory rewrites expressions to use literals for +/// missing columns, handling schema evolution efficiently at planning time. +#[tokio::test] +async fn test_physical_expr_adapter_with_non_null_defaults() { + // File only has c1 column + let batch = record_batch!(("c1", Int32, [10, 20, 30])).unwrap(); + + let store = Arc::new(InMemory::new()) as Arc; + let store_url = ObjectStoreUrl::parse("memory://").unwrap(); + write_parquet(batch, store.clone(), "defaults_test.parquet").await; + + // Table schema has additional columns c2 (Utf8) and c3 (Int64) that don't exist in file + let table_schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Int64, false), // type differs from file (Int32 vs Int64) + Field::new("c2", DataType::Utf8, true), // missing from file + Field::new("c3", DataType::Int64, true), // missing from file + ])); + + let mut cfg = SessionConfig::new() + .with_collect_statistics(false) + .with_parquet_pruning(false); + cfg.options_mut().execution.parquet.pushdown_filters = true; + let ctx = SessionContext::new_with_config(cfg); + ctx.register_object_store(store_url.as_ref(), Arc::clone(&store)); + + // CustomPhysicalExprAdapterFactory fills: + // - missing Utf8 columns with 'b' + // - missing Int64 columns with 1 + let listing_table_config = + 
ListingTableConfig::new(ListingTableUrl::parse("memory:///").unwrap()) + .infer_options(&ctx.state()) + .await + .unwrap() + .with_schema(table_schema.clone()) + .with_expr_adapter_factory(Arc::new(CustomPhysicalExprAdapterFactory)); + + let table = ListingTable::try_new(listing_table_config).unwrap(); + ctx.register_table("t", Arc::new(table)).unwrap(); + + // Query all columns - missing columns should have default values + let batches = ctx + .sql("SELECT c1, c2, c3 FROM t ORDER BY c1") + .await + .unwrap() + .collect() + .await + .unwrap(); + + // c1 is cast from Int32 to Int64, c2 defaults to 'b', c3 defaults to 1 + let expected = [ + "+----+----+----+", + "| c1 | c2 | c3 |", + "+----+----+----+", + "| 10 | b | 1 |", + "| 20 | b | 1 |", + "| 30 | b | 1 |", + "+----+----+----+", + ]; + assert_batches_eq!(expected, &batches); + + // Verify predicates work with default values + // c3 = 1 should match all rows since default is 1 + let batches = ctx + .sql("SELECT c1 FROM t WHERE c3 = 1 ORDER BY c1") + .await + .unwrap() + .collect() + .await + .unwrap(); + + #[rustfmt::skip] + let expected = [ + "+----+", + "| c1 |", + "+----+", + "| 10 |", + "| 20 |", + "| 30 |", + "+----+", + ]; + assert_batches_eq!(expected, &batches); + + // c3 = 999 should match no rows + let batches = ctx + .sql("SELECT c1 FROM t WHERE c3 = 999") + .await + .unwrap() + .collect() + .await + .unwrap(); + + #[rustfmt::skip] + let expected = [ + "++", + "++", + ]; + assert_batches_eq!(expected, &batches); +} + +#[tokio::test] +async fn test_struct_schema_evolution_projection_and_filter() -> Result<()> { + use std::collections::HashMap; + + // Physical struct: {id: Int32, name: Utf8} + let physical_struct_fields: Fields = vec![ + Arc::new(Field::new("id", DataType::Int32, false)), + Arc::new(Field::new("name", DataType::Utf8, true)), + ] + .into(); + + let struct_array = StructArray::new( + physical_struct_fields.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef, + 
Arc::new(StringArray::from(vec!["a", "b", "c"])) as ArrayRef, + ], + None, + ); + + let physical_schema = Arc::new(Schema::new(vec![Field::new( + "s", + DataType::Struct(physical_struct_fields), + true, + )])); + + let batch = + RecordBatch::try_new(Arc::clone(&physical_schema), vec![Arc::new(struct_array)])?; + + let store = Arc::new(InMemory::new()) as Arc; + let store_url = ObjectStoreUrl::parse("memory://").unwrap(); + write_parquet(batch, store.clone(), "struct_evolution.parquet").await; + + // Logical struct: {id: Int64?, name: Utf8?, extra: Boolean?} + metadata + let logical_struct_fields: Fields = vec![ + Arc::new(Field::new("id", DataType::Int64, true)), + Arc::new(Field::new("name", DataType::Utf8, true)), + Arc::new(Field::new("extra", DataType::Boolean, true).with_metadata( + HashMap::from([("nested_meta".to_string(), "1".to_string())]), + )), + ] + .into(); + + let table_schema = Arc::new(Schema::new(vec![ + Field::new("s", DataType::Struct(logical_struct_fields), false) + .with_metadata(HashMap::from([("top_meta".to_string(), "1".to_string())])), + ])); + + let mut cfg = SessionConfig::new() + .with_collect_statistics(false) + .with_parquet_pruning(false) + .with_parquet_page_index_pruning(false); + cfg.options_mut().execution.parquet.pushdown_filters = true; + + let ctx = SessionContext::new_with_config(cfg); + ctx.register_object_store(store_url.as_ref(), Arc::clone(&store)); + + let listing_table_config = + ListingTableConfig::new(ListingTableUrl::parse("memory:///").unwrap()) + .infer_options(&ctx.state()) + .await + .unwrap() + .with_schema(table_schema.clone()) + .with_expr_adapter_factory(Arc::new(DefaultPhysicalExprAdapterFactory)); + + let table = ListingTable::try_new(listing_table_config).unwrap(); + ctx.register_table("t", Arc::new(table)).unwrap(); + + let batches = ctx + .sql("SELECT s FROM t") + .await + .unwrap() + .collect() + .await + .unwrap(); + assert_eq!(batches.len(), 1); + + // Verify top-level metadata propagation + let 
output_schema = batches[0].schema(); + let s_field = output_schema.field_with_name("s").unwrap(); + assert_eq!( + s_field.metadata().get("top_meta").map(String::as_str), + Some("1") + ); + + // Verify nested struct type/field propagation + values + let s_array = batches[0] + .column(0) + .as_any() + .downcast_ref::() + .expect("expected struct array"); + + let id_array = s_array + .column_by_name("id") + .expect("id column") + .as_any() + .downcast_ref::() + .expect("id should be cast to Int64"); + assert_eq!(id_array.values(), &[1, 2, 3]); + + let extra_array = s_array.column_by_name("extra").expect("extra column"); + assert_eq!(extra_array.null_count(), 3); + + // Verify nested field metadata propagation + let extra_field = match s_field.data_type() { + DataType::Struct(fields) => fields + .iter() + .find(|f| f.name() == "extra") + .expect("extra field"), + other => panic!("expected struct type for s, got {other:?}"), + }; + assert_eq!( + extra_field + .metadata() + .get("nested_meta") + .map(String::as_str), + Some("1") + ); + + // Smoke test: filtering on a missing nested field evaluates correctly + let filtered = ctx + .sql("SELECT get_field(s, 'extra') AS extra FROM t WHERE get_field(s, 'extra') IS NULL") + .await + .unwrap() + .collect() + .await + .unwrap(); + assert_eq!(filtered.len(), 1); + assert_eq!(filtered[0].num_rows(), 3); + let extra = filtered[0] + .column(0) + .as_any() + .downcast_ref::() + .expect("extra should be a boolean array"); + assert_eq!(extra.null_count(), 3); + + Ok(()) +} + +/// Test demonstrating that a single PhysicalExprAdapterFactory instance can be +/// reused across multiple ListingTable instances. +/// +/// This addresses the concern: "This is important for ListingTable. A test for +/// ListingTable would add assurance that the functionality is retained [i.e. 
we +/// can re-use a PhysicalExprAdapterFactory]" +#[tokio::test] +async fn test_physical_expr_adapter_factory_reuse_across_tables() { + // Create two parquet files with the same schema but different data + // File 1: has column c1 only + let batch1 = record_batch!(("c1", Int32, [1, 2, 3])).unwrap(); + // File 2: has column c1 only but different data + let batch2 = record_batch!(("c1", Int32, [10, 20, 30])).unwrap(); + + let store = Arc::new(InMemory::new()) as Arc; + let store_url = ObjectStoreUrl::parse("memory://").unwrap(); + + // Write files to different paths + write_parquet(batch1, store.clone(), "table1/data.parquet").await; + write_parquet(batch2, store.clone(), "table2/data.parquet").await; + + // Table schema has additional columns that don't exist in files + let table_schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Int64, false), + Field::new("c2", DataType::Utf8, true), // missing from files + ])); + + let mut cfg = SessionConfig::new() + .with_collect_statistics(false) + .with_parquet_pruning(false); + cfg.options_mut().execution.parquet.pushdown_filters = true; + let ctx = SessionContext::new_with_config(cfg); + ctx.register_object_store(store_url.as_ref(), Arc::clone(&store)); + + // Create ONE factory instance wrapped in Arc - this will be REUSED + let factory: Arc = + Arc::new(CustomPhysicalExprAdapterFactory); + + // Create ListingTable 1 using the shared factory + let listing_table_config1 = + ListingTableConfig::new(ListingTableUrl::parse("memory:///table1/").unwrap()) + .infer_options(&ctx.state()) + .await + .unwrap() + .with_schema(table_schema.clone()) + .with_expr_adapter_factory(Arc::clone(&factory)); // Clone the Arc, not create new factory + + let table1 = ListingTable::try_new(listing_table_config1).unwrap(); + ctx.register_table("t1", Arc::new(table1)).unwrap(); + + // Create ListingTable 2 using the SAME factory instance + let listing_table_config2 = + 
ListingTableConfig::new(ListingTableUrl::parse("memory:///table2/").unwrap()) + .infer_options(&ctx.state()) + .await + .unwrap() + .with_schema(table_schema.clone()) + .with_expr_adapter_factory(Arc::clone(&factory)); // Reuse same factory + + let table2 = ListingTable::try_new(listing_table_config2).unwrap(); + ctx.register_table("t2", Arc::new(table2)).unwrap(); + + // Verify table 1 works correctly with the shared factory + // CustomPhysicalExprAdapterFactory fills missing Utf8 columns with 'b' + let batches = ctx + .sql("SELECT c1, c2 FROM t1 ORDER BY c1") + .await + .unwrap() + .collect() + .await + .unwrap(); + + let expected = [ + "+----+----+", + "| c1 | c2 |", + "+----+----+", + "| 1 | b |", + "| 2 | b |", + "| 3 | b |", + "+----+----+", + ]; + assert_batches_eq!(expected, &batches); + + // Verify table 2 also works correctly with the SAME shared factory + let batches = ctx + .sql("SELECT c1, c2 FROM t2 ORDER BY c1") + .await + .unwrap() + .collect() + .await + .unwrap(); + + let expected = [ + "+----+----+", + "| c1 | c2 |", + "+----+----+", + "| 10 | b |", + "| 20 | b |", + "| 30 | b |", + "+----+----+", + ]; + assert_batches_eq!(expected, &batches); + + // Verify predicates work on both tables with the shared factory + let batches = ctx + .sql("SELECT c1 FROM t1 WHERE c2 = 'b' ORDER BY c1") + .await + .unwrap() + .collect() + .await + .unwrap(); + + #[rustfmt::skip] + let expected = [ + "+----+", + "| c1 |", + "+----+", + "| 1 |", + "| 2 |", + "| 3 |", + "+----+", + ]; + assert_batches_eq!(expected, &batches); + + let batches = ctx + .sql("SELECT c1 FROM t2 WHERE c2 = 'b' ORDER BY c1") + .await + .unwrap() + .collect() + .await + .unwrap(); + + #[rustfmt::skip] + let expected = [ + "+----+", + "| c1 |", + "+----+", + "| 10 |", + "| 20 |", + "| 30 |", + "+----+", + ]; + assert_batches_eq!(expected, &batches); +} diff --git a/datafusion/core/tests/parquet/external_access_plan.rs b/datafusion/core/tests/parquet/external_access_plan.rs index 
a5397c5a397ca..9ff8137687c95 100644 --- a/datafusion/core/tests/parquet/external_access_plan.rs +++ b/datafusion/core/tests/parquet/external_access_plan.rs @@ -21,7 +21,7 @@ use std::path::Path; use std::sync::Arc; use crate::parquet::utils::MetricsFinder; -use crate::parquet::{create_data_batch, Scenario}; +use crate::parquet::{Scenario, create_data_batch}; use arrow::datatypes::SchemaRef; use arrow::util::pretty::pretty_format_batches; @@ -29,17 +29,17 @@ use datafusion::common::Result; use datafusion::datasource::listing::PartitionedFile; use datafusion::datasource::physical_plan::ParquetSource; use datafusion::prelude::SessionContext; -use datafusion_common::{assert_contains, DFSchema}; +use datafusion_common::{DFSchema, assert_contains}; use datafusion_datasource_parquet::{ParquetAccessPlan, RowGroupAccess}; use datafusion_execution::object_store::ObjectStoreUrl; -use datafusion_expr::{col, lit, Expr}; -use datafusion_physical_plan::metrics::MetricsSet; +use datafusion_expr::{Expr, col, lit}; use datafusion_physical_plan::ExecutionPlan; +use datafusion_physical_plan::metrics::{MetricValue, MetricsSet}; use datafusion_datasource::file_scan_config::FileScanConfigBuilder; use datafusion_datasource::source::DataSourceExec; -use parquet::arrow::arrow_reader::{RowSelection, RowSelector}; use parquet::arrow::ArrowWriter; +use parquet::arrow::arrow_reader::{RowSelection, RowSelector}; use parquet::file::properties::WriterProperties; use tempfile::NamedTempFile; @@ -178,12 +178,21 @@ async fn plan_and_filter() { .unwrap(); // Verify that row group pruning still happens for just that group - let row_groups_pruned_statistics = - metric_value(&parquet_metrics, "row_groups_pruned_statistics").unwrap(); - assert_eq!( - row_groups_pruned_statistics, 1, - "metrics : {parquet_metrics:#?}", - ); + let row_groups_pruned_statistics = parquet_metrics + .sum_by_name("row_groups_pruned_statistics") + .unwrap(); + if let MetricValue::PruningMetrics { + pruning_metrics, .. 
+ } = row_groups_pruned_statistics + { + assert_eq!( + pruning_metrics.pruned(), + 1, + "metrics : {parquet_metrics:#?}", + ); + } else { + unreachable!("metrics `row_groups_pruned_statistics` should exist") + } } #[tokio::test] @@ -248,7 +257,10 @@ async fn bad_selection() { .await .unwrap_err(); let err_string = err.to_string(); - assert_contains!(&err_string, "Internal error: Invalid ParquetAccessPlan Selection. Row group 0 has 5 rows but selection only specifies 4 rows"); + assert_contains!( + &err_string, + "Row group 0 has 5 rows but selection only specifies 4 rows." + ); } /// Return a RowSelection of 1 rows from a row group of 5 rows @@ -346,11 +358,11 @@ impl TestFull { let source = if let Some(predicate) = predicate { let df_schema = DFSchema::try_from(schema.clone())?; let predicate = ctx.create_physical_expr(predicate, &df_schema)?; - Arc::new(ParquetSource::default().with_predicate(predicate)) + Arc::new(ParquetSource::new(schema.clone()).with_predicate(predicate)) } else { - Arc::new(ParquetSource::default()) + Arc::new(ParquetSource::new(schema.clone())) }; - let config = FileScanConfigBuilder::new(object_store_url, schema.clone(), source) + let config = FileScanConfigBuilder::new(object_store_url, source) .with_file(partitioned_file) .build(); @@ -397,7 +409,7 @@ fn get_test_data() -> TestData { .expect("tempfile creation"); let props = WriterProperties::builder() - .set_max_row_group_size(row_per_group) + .set_max_row_group_row_count(Some(row_per_group)) .build(); let batches = create_data_batch(scenario); diff --git a/datafusion/core/tests/parquet/file_statistics.rs b/datafusion/core/tests/parquet/file_statistics.rs index a60beaf665e55..fdefdafa00aa4 100644 --- a/datafusion/core/tests/parquet/file_statistics.rs +++ b/datafusion/core/tests/parquet/file_statistics.rs @@ -18,30 +18,30 @@ use std::fs; use std::sync::Arc; +use datafusion::datasource::TableProvider; use datafusion::datasource::file_format::parquet::ParquetFormat; use 
datafusion::datasource::listing::{ ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl, }; use datafusion::datasource::source::DataSourceExec; -use datafusion::datasource::TableProvider; use datafusion::execution::context::SessionState; use datafusion::execution::session_state::SessionStateBuilder; use datafusion::prelude::SessionContext; -use datafusion_common::stats::Precision; use datafusion_common::DFSchema; +use datafusion_common::stats::Precision; +use datafusion_execution::cache::DefaultListFilesCache; use datafusion_execution::cache::cache_manager::CacheManagerConfig; -use datafusion_execution::cache::cache_unit::{ - DefaultFileStatisticsCache, DefaultListFilesCache, -}; +use datafusion_execution::cache::cache_unit::DefaultFileStatisticsCache; use datafusion_execution::config::SessionConfig; use datafusion_execution::runtime_env::RuntimeEnvBuilder; -use datafusion_expr::{col, lit, Expr}; +use datafusion_expr::{Expr, col, lit}; use datafusion::datasource::physical_plan::FileScanConfig; -use datafusion_physical_optimizer::filter_pushdown::FilterPushdown; +use datafusion_common::config::ConfigOptions; use datafusion_physical_optimizer::PhysicalOptimizerRule; -use datafusion_physical_plan::filter::FilterExec; +use datafusion_physical_optimizer::filter_pushdown::FilterPushdown; use datafusion_physical_plan::ExecutionPlan; +use datafusion_physical_plan::filter::FilterExec; use tempfile::tempdir; #[tokio::test] @@ -55,7 +55,7 @@ async fn check_stats_precision_with_filter_pushdown() { let table = get_listing_table(&table_path, None, &opt).await; let (_, _, state) = get_cache_runtime_state(); - let mut options = state.config().options().clone(); + let mut options: ConfigOptions = state.config().options().as_ref().clone(); options.execution.parquet.pushdown_filters = true; // Scan without filter, stats are exact @@ -71,7 +71,7 @@ async fn check_stats_precision_with_filter_pushdown() { // source operator after the appropriate optimizer pass. 
let filter_expr = Expr::gt(col("id"), lit(1)); let exec_with_filter = table - .scan(&state, None, &[filter_expr.clone()], None) + .scan(&state, None, std::slice::from_ref(&filter_expr), None) .await .unwrap(); @@ -126,8 +126,9 @@ async fn load_table_stats_with_session_level_cache() { ); assert_eq!( exec1.partition_statistics(None).unwrap().total_byte_size, - // TODO correct byte size: https://github.com/apache/datafusion/issues/14936 - Precision::Exact(671), + // Byte size is absent because we cannot estimate the output size + // of the Arrow data since there are variable length columns. + Precision::Absent, ); assert_eq!(get_static_cache_size(&state1), 1); @@ -141,8 +142,8 @@ async fn load_table_stats_with_session_level_cache() { ); assert_eq!( exec2.partition_statistics(None).unwrap().total_byte_size, - // TODO correct byte size: https://github.com/apache/datafusion/issues/14936 - Precision::Exact(671), + // Absent because the data contains variable length columns + Precision::Absent, ); assert_eq!(get_static_cache_size(&state2), 1); @@ -156,8 +157,8 @@ async fn load_table_stats_with_session_level_cache() { ); assert_eq!( exec3.partition_statistics(None).unwrap().total_byte_size, - // TODO correct byte size: https://github.com/apache/datafusion/issues/14936 - Precision::Exact(671), + // Absent because the data contains variable length columns + Precision::Absent, ); // List same file no increase assert_eq!(get_static_cache_size(&state1), 1); diff --git a/datafusion/core/tests/parquet/filter_pushdown.rs b/datafusion/core/tests/parquet/filter_pushdown.rs index b8d570916c7c5..e6266b2c088d7 100644 --- a/datafusion/core/tests/parquet/filter_pushdown.rs +++ b/datafusion/core/tests/parquet/filter_pushdown.rs @@ -26,18 +26,19 @@ //! select * from data limit 10; //! 
``` -use std::path::Path; - use arrow::compute::concat_batches; use arrow::record_batch::RecordBatch; use datafusion::physical_plan::collect; -use datafusion::physical_plan::metrics::MetricsSet; +use datafusion::physical_plan::metrics::{MetricValue, MetricsSet}; use datafusion::prelude::{ - col, lit, lit_timestamp_nano, Expr, ParquetReadOptions, SessionContext, + Expr, ParquetReadOptions, SessionContext, col, lit, lit_timestamp_nano, }; use datafusion::test_util::parquet::{ParquetScanOptions, TestParquetFile}; use datafusion_expr::utils::{conjunction, disjunction, split_conjunction}; +use std::path::Path; +use datafusion_common::test_util::parquet_test_data; +use datafusion_execution::config::SessionConfig; use itertools::Itertools; use parquet::file::properties::WriterProperties; use tempfile::TempDir; @@ -62,7 +63,7 @@ async fn single_file() { // Set the row group size smaller so can test with fewer rows let props = WriterProperties::builder() - .set_max_row_group_size(1024) + .set_max_row_group_row_count(Some(1024)) .build(); // Only create the parquet file once as it is fairly large @@ -219,7 +220,6 @@ async fn single_file() { } #[tokio::test] -#[allow(dead_code)] async fn single_file_small_data_pages() { let batches = read_parquet_test_data( "tests/data/filter_pushdown/single_file_small_pages.gz.parquet", @@ -230,7 +230,7 @@ async fn single_file_small_data_pages() { // Set a low row count limit to improve page filtering let props = WriterProperties::builder() - .set_max_row_group_size(2048) + .set_max_row_group_row_count(Some(2048)) .set_data_page_row_count_limit(512) .set_write_batch_size(512) .build(); @@ -562,9 +562,9 @@ impl<'a> TestCase<'a> { } }; - let page_index_rows_pruned = get_value(&metrics, "page_index_rows_pruned"); + let (page_index_rows_pruned, page_index_rows_matched) = + get_pruning_metrics(&metrics, "page_index_rows_pruned"); println!(" page_index_rows_pruned: {page_index_rows_pruned}"); - let page_index_rows_matched = get_value(&metrics, 
"page_index_rows_matched"); println!(" page_index_rows_matched: {page_index_rows_matched}"); let page_index_filtering_expected = if scan_options.enable_page_index { @@ -591,13 +591,158 @@ impl<'a> TestCase<'a> { } } +fn get_pruning_metrics(metrics: &MetricsSet, metric_name: &str) -> (usize, usize) { + match metrics.sum_by_name(metric_name) { + Some(MetricValue::PruningMetrics { + pruning_metrics, .. + }) => (pruning_metrics.pruned(), pruning_metrics.matched()), + Some(_) => { + panic!("Metric '{metric_name}' is not a pruning metric in\n\n{metrics:#?}") + } + None => panic!( + "Expected metric not found. Looking for '{metric_name}' in\n\n{metrics:#?}" + ), + } +} + fn get_value(metrics: &MetricsSet, metric_name: &str) -> usize { match metrics.sum_by_name(metric_name) { + Some(MetricValue::PruningMetrics { + pruning_metrics, .. + }) => pruning_metrics.pruned(), Some(v) => v.as_usize(), - _ => { - panic!( - "Expected metric not found. Looking for '{metric_name}' in\n\n{metrics:#?}" - ); - } + None => panic!( + "Expected metric not found. 
Looking for '{metric_name}' in\n\n{metrics:#?}" + ), + } +} + +#[tokio::test] +async fn predicate_cache_default() -> datafusion_common::Result<()> { + let ctx = SessionContext::new(); + // The cache is on by default, but not used unless filter pushdown is enabled + PredicateCacheTest { + expected_inner_records: 0, + expected_records: 0, + } + .run(&ctx) + .await +} + +#[tokio::test] +async fn predicate_cache_pushdown_default() -> datafusion_common::Result<()> { + let mut config = SessionConfig::new(); + config.options_mut().execution.parquet.pushdown_filters = true; + let ctx = SessionContext::new_with_config(config); + // The cache is on by default, and used when filter pushdown is enabled + PredicateCacheTest { + expected_inner_records: 8, + expected_records: 7, // reads more than necessary from the cache as then another bitmap is applied + } + .run(&ctx) + .await +} + +#[tokio::test] +async fn predicate_cache_stats_issue_19561() -> datafusion_common::Result<()> { + let mut config = SessionConfig::new(); + config.options_mut().execution.parquet.pushdown_filters = true; + // force to get multiple batches to trigger repeated metric compound bug + config.options_mut().execution.batch_size = 1; + let ctx = SessionContext::new_with_config(config); + // The cache is on by default, and used when filter pushdown is enabled + PredicateCacheTest { + expected_inner_records: 8, + expected_records: 4, + } + .run(&ctx) + .await +} + +#[tokio::test] +async fn predicate_cache_pushdown_default_selections_only() +-> datafusion_common::Result<()> { + let mut config = SessionConfig::new(); + config.options_mut().execution.parquet.pushdown_filters = true; + // forcing filter selections minimizes the number of rows read from the cache + config + .options_mut() + .execution + .parquet + .force_filter_selections = true; + let ctx = SessionContext::new_with_config(config); + // The cache is on by default, and used when filter pushdown is enabled + PredicateCacheTest { + 
expected_inner_records: 8, + expected_records: 4, + } + .run(&ctx) + .await +} + +#[tokio::test] +async fn predicate_cache_pushdown_disable() -> datafusion_common::Result<()> { + // Can disable the cache even with filter pushdown by setting the size to 0. + // This results in no records read from the cache and no metrics reported + let mut config = SessionConfig::new(); + config.options_mut().execution.parquet.pushdown_filters = true; + config + .options_mut() + .execution + .parquet + .max_predicate_cache_size = Some(0); + let ctx = SessionContext::new_with_config(config); + // Since the cache is disabled, there is no reporting or use of the cache + PredicateCacheTest { + expected_inner_records: 0, + expected_records: 0, + } + .run(&ctx) + .await +} + +/// Runs the query "SELECT * FROM alltypes_plain WHERE double_col != 0.0" +/// with a given SessionContext and asserts that the predicate cache metrics +/// are as expected +#[derive(Debug)] +struct PredicateCacheTest { + /// Expected records read from the underlying reader (to evaluate filters) + /// -- this is the total number of records in the file + expected_inner_records: usize, + /// Expected records to be read from the cache (after filtering) + expected_records: usize, +} + +impl PredicateCacheTest { + async fn run(self, ctx: &SessionContext) -> datafusion_common::Result<()> { + let Self { + expected_inner_records, + expected_records, + } = self; + // Create a dataframe that scans the "alltypes_plain.parquet" file with + // a filter on `double_col != 0.0` + let path = parquet_test_data() + "/alltypes_plain.parquet"; + let exec = ctx + .read_parquet(path, ParquetReadOptions::default()) + .await? + .filter(col("double_col").not_eq(lit(0.0)))? 
+ .create_physical_plan() + .await?; + + // run the plan to completion + let _ = collect(exec.clone(), ctx.task_ctx()).await?; // run plan + let metrics = + TestParquetFile::parquet_metrics(&exec).expect("found parquet metrics"); + + // verify the predicate cache metrics + assert_eq!( + get_value(&metrics, "predicate_cache_inner_records"), + expected_inner_records + ); + assert_eq!( + get_value(&metrics, "predicate_cache_records"), + expected_records + ); + Ok(()) } } diff --git a/datafusion/core/tests/parquet/mod.rs b/datafusion/core/tests/parquet/mod.rs index 87a5ed33f127d..0535ddd9247d4 100644 --- a/datafusion/core/tests/parquet/mod.rs +++ b/datafusion/core/tests/parquet/mod.rs @@ -19,33 +19,39 @@ use crate::parquet::utils::MetricsFinder; use arrow::{ array::{ - make_array, Array, ArrayRef, BinaryArray, Date32Array, Date64Array, - Decimal128Array, DictionaryArray, FixedSizeBinaryArray, Float64Array, Int16Array, - Int32Array, Int64Array, Int8Array, LargeBinaryArray, LargeStringArray, - StringArray, TimestampMicrosecondArray, TimestampMillisecondArray, - TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array, - UInt64Array, UInt8Array, + Array, ArrayRef, BinaryArray, Date32Array, Date64Array, Decimal128Array, + DictionaryArray, FixedSizeBinaryArray, Float64Array, Int8Array, Int16Array, + Int32Array, Int64Array, LargeBinaryArray, LargeStringArray, StringArray, + TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, + TimestampSecondArray, UInt8Array, UInt16Array, UInt32Array, UInt64Array, + make_array, }, datatypes::{DataType, Field, Schema}, record_batch::RecordBatch, util::pretty::pretty_format_batches, }; +use arrow_schema::SchemaRef; use chrono::{Datelike, Duration, TimeDelta}; use datafusion::{ - datasource::{provider_as_source, TableProvider}, + datasource::{TableProvider, provider_as_source}, physical_plan::metrics::MetricsSet, prelude::{ParquetReadOptions, SessionConfig, SessionContext}, }; use 
datafusion_expr::{Expr, LogicalPlan, LogicalPlanBuilder}; +use datafusion_physical_plan::metrics::MetricValue; use parquet::arrow::ArrowWriter; use parquet::file::properties::{EnabledStatistics, WriterProperties}; use std::sync::Arc; use tempfile::NamedTempFile; mod custom_reader; +#[cfg(feature = "parquet_encryption")] +mod encryption; +mod expr_adapter; mod external_access_plan; mod file_statistics; mod filter_pushdown; +mod ordering; mod page_pruning; mod row_group_pruning; mod schema; @@ -105,13 +111,33 @@ struct ContextWithParquet { ctx: SessionContext, } +struct PruningMetric { + total_pruned: usize, + total_matched: usize, + total_fully_matched: usize, +} + +impl PruningMetric { + pub fn total_pruned(&self) -> usize { + self.total_pruned + } + + pub fn total_matched(&self) -> usize { + self.total_matched + } + + pub fn total_fully_matched(&self) -> usize { + self.total_fully_matched + } +} + /// The output of running one of the test cases struct TestOutput { - /// The input string + /// The input query SQL sql: String, /// Execution metrics for the Parquet Scan parquet_metrics: MetricsSet, - /// number of rows in results + /// number of actual rows in results result_rows: usize, /// the contents of the input, as a string pretty_input: String, @@ -122,9 +148,50 @@ struct TestOutput { impl TestOutput { /// retrieve the value of the named metric, if any fn metric_value(&self, metric_name: &str) -> Option { + if let Some(pm) = self.pruning_metric(metric_name) { + return Some(pm.total_pruned()); + } + self.parquet_metrics .sum(|metric| metric.value().name() == metric_name) - .map(|v| v.as_usize()) + .map(|v| match v { + MetricValue::PruningMetrics { + pruning_metrics, .. 
+ } => pruning_metrics.pruned(), + _ => v.as_usize(), + }) + } + + fn pruning_metric(&self, metric_name: &str) -> Option { + let mut total_pruned = 0; + let mut total_matched = 0; + let mut total_fully_matched = 0; + let mut found = false; + + for metric in self.parquet_metrics.iter() { + let metric = metric.as_ref(); + if metric.value().name() == metric_name + && let MetricValue::PruningMetrics { + pruning_metrics, .. + } = metric.value() + { + total_pruned += pruning_metrics.pruned(); + total_matched += pruning_metrics.matched(); + total_fully_matched += pruning_metrics.fully_matched(); + + found = true; + } + } + + if found { + Some(PruningMetric { + total_pruned, + total_matched, + total_fully_matched, + }) + } else { + None + } } /// The number of times the pruning predicate evaluation errors @@ -132,43 +199,63 @@ impl TestOutput { self.metric_value("predicate_evaluation_errors") } - /// The number of row_groups matched by bloom filter - fn row_groups_matched_bloom_filter(&self) -> Option { - self.metric_value("row_groups_matched_bloom_filter") - } - - /// The number of row_groups pruned by bloom filter - fn row_groups_pruned_bloom_filter(&self) -> Option { - self.metric_value("row_groups_pruned_bloom_filter") + /// The number of row_groups pruned / matched by bloom filter + fn row_groups_bloom_filter(&self) -> Option { + self.pruning_metric("row_groups_pruned_bloom_filter") } /// The number of row_groups matched by statistics fn row_groups_matched_statistics(&self) -> Option { - self.metric_value("row_groups_matched_statistics") + self.pruning_metric("row_groups_pruned_statistics") + .map(|pm| pm.total_matched()) + } + + /// The number of row_groups fully matched by statistics + fn row_groups_fully_matched_statistics(&self) -> Option { + self.pruning_metric("row_groups_pruned_statistics") + .map(|pm| pm.total_fully_matched()) } /// The number of row_groups pruned by statistics fn row_groups_pruned_statistics(&self) -> Option { - 
self.metric_value("row_groups_pruned_statistics") + self.pruning_metric("row_groups_pruned_statistics") + .map(|pm| pm.total_pruned()) + } + + /// Metric `files_ranges_pruned_statistics` tracks both pruned and matched count, + /// for testing purpose, here it only aggregate the `pruned` count. + fn files_ranges_pruned_statistics(&self) -> Option { + self.pruning_metric("files_ranges_pruned_statistics") + .map(|pm| pm.total_pruned()) } /// The number of row_groups matched by bloom filter or statistics + /// + /// E.g. starting with 10 row groups, statistics: 10 total -> 7 matched, bloom + /// filter: 7 total -> 3 matched, this function returns 3 for the final matched + /// count. fn row_groups_matched(&self) -> Option { - self.row_groups_matched_bloom_filter() - .zip(self.row_groups_matched_statistics()) - .map(|(a, b)| a + b) + self.row_groups_bloom_filter().map(|pm| pm.total_matched()) } /// The number of row_groups pruned fn row_groups_pruned(&self) -> Option { - self.row_groups_pruned_bloom_filter() + self.row_groups_bloom_filter() + .map(|pm| pm.total_pruned()) .zip(self.row_groups_pruned_statistics()) .map(|(a, b)| a + b) } /// The number of row pages pruned fn row_pages_pruned(&self) -> Option { - self.metric_value("page_index_rows_pruned") + self.pruning_metric("page_index_rows_pruned") + .map(|pm| pm.total_pruned()) + } + + /// The number of row groups pruned by limit pruning + fn limit_pruned_row_groups(&self) -> Option { + self.pruning_metric("limit_pruned_row_groups") + .map(|pm| pm.total_pruned()) } fn description(&self) -> String { @@ -184,18 +271,41 @@ impl TestOutput { /// and the appropriate scenario impl ContextWithParquet { async fn new(scenario: Scenario, unit: Unit) -> Self { - Self::with_config(scenario, unit, SessionConfig::new()).await + Self::with_config(scenario, unit, SessionConfig::new(), None, None).await + } + + /// Set custom schema and batches for the test + pub async fn with_custom_data( + scenario: Scenario, + unit: Unit, + schema: 
Arc, + batches: Vec, + ) -> Self { + Self::with_config( + scenario, + unit, + SessionConfig::new(), + Some(schema), + Some(batches), + ) + .await } async fn with_config( scenario: Scenario, unit: Unit, mut config: SessionConfig, + custom_schema: Option, + custom_batches: Option>, ) -> Self { + // Use a single partition for deterministic results no matter how many CPUs the host has + config = config.with_target_partitions(1); let file = match unit { Unit::RowGroup(row_per_group) => { config = config.with_parquet_bloom_filter_pruning(true); - make_test_file_rg(scenario, row_per_group).await + config.options_mut().execution.parquet.pushdown_filters = true; + make_test_file_rg(scenario, row_per_group, custom_schema, custom_batches) + .await } Unit::Page(row_per_page) => { config = config.with_parquet_page_index_pruning(true); @@ -466,9 +576,9 @@ fn make_uint_batches(start: u8, end: u8) -> RecordBatch { Field::new("u64", DataType::UInt64, true), ])); let v8: Vec = (start..end).collect(); - let v16: Vec = (start as _..end as _).collect(); - let v32: Vec = (start as _..end as _).collect(); - let v64: Vec = (start as _..end as _).collect(); + let v16: Vec = (start as u16..end as u16).collect(); + let v32: Vec = (start as u32..end as u32).collect(); + let v64: Vec = (start as u64..end as u64).collect(); RecordBatch::try_new( schema, vec![ @@ -602,6 +712,7 @@ fn make_date_batch(offset: Duration) -> RecordBatch { /// of the column. It is *not* a table named service.name /// /// name | service.name +#[expect(clippy::needless_pass_by_value)] fn make_bytearray_batch( name: &str, string_values: Vec<&str>, @@ -657,6 +768,7 @@ fn make_bytearray_batch( /// of the column. 
It is *not* a table named service.name /// /// name | service.name +#[expect(clippy::needless_pass_by_value)] fn make_names_batch(name: &str, service_name_values: Vec<&str>) -> RecordBatch { let num_rows = service_name_values.len(); let name: StringArray = std::iter::repeat_n(Some(name), num_rows).collect(); @@ -741,6 +853,7 @@ fn make_utf8_batch(value: Vec>) -> RecordBatch { .unwrap() } +#[expect(clippy::needless_pass_by_value)] fn make_dictionary_batch(strings: Vec<&str>, integers: Vec) -> RecordBatch { let keys = Int32Array::from_iter(0..strings.len() as i32); let small_keys = Int16Array::from_iter(0..strings.len() as i16); @@ -789,6 +902,7 @@ fn make_dictionary_batch(strings: Vec<&str>, integers: Vec) -> RecordBatch .unwrap() } +#[expect(clippy::needless_pass_by_value)] fn create_data_batch(scenario: Scenario) -> Vec { match scenario { Scenario::Timestamps => { @@ -1021,7 +1135,12 @@ fn create_data_batch(scenario: Scenario) -> Vec { } /// Create a test parquet file with various data types -async fn make_test_file_rg(scenario: Scenario, row_per_group: usize) -> NamedTempFile { +async fn make_test_file_rg( + scenario: Scenario, + row_per_group: usize, + custom_schema: Option, + custom_batches: Option>, +) -> NamedTempFile { let mut output_file = tempfile::Builder::new() .prefix("parquet_pruning") .suffix(".parquet") @@ -1029,13 +1148,19 @@ async fn make_test_file_rg(scenario: Scenario, row_per_group: usize) -> NamedTem .expect("tempfile creation"); let props = WriterProperties::builder() - .set_max_row_group_size(row_per_group) + .set_max_row_group_row_count(Some(row_per_group)) .set_bloom_filter_enabled(true) .set_statistics_enabled(EnabledStatistics::Page) .build(); - let batches = create_data_batch(scenario); - let schema = batches[0].schema(); + let (batches, schema) = + if let (Some(schema), Some(batches)) = (custom_schema, custom_batches) { + (batches, schema) + } else { + let batches = create_data_batch(scenario); + let schema = batches[0].schema(); + 
(batches, schema) + }; let mut writer = ArrowWriter::try_new(&mut output_file, schema, Some(props)).unwrap(); diff --git a/datafusion/core/tests/parquet/ordering.rs b/datafusion/core/tests/parquet/ordering.rs new file mode 100644 index 0000000000000..faecb4ca6a861 --- /dev/null +++ b/datafusion/core/tests/parquet/ordering.rs @@ -0,0 +1,103 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Tests for ordering in Parquet sorting_columns metadata + +use datafusion::prelude::SessionContext; +use datafusion_common::Result; +use tempfile::tempdir; + +/// Test that CREATE TABLE ... 
WITH ORDER writes sorting_columns to Parquet metadata +#[tokio::test] +async fn test_create_table_with_order_writes_sorting_columns() -> Result<()> { + use parquet::file::reader::FileReader; + use parquet::file::serialized_reader::SerializedFileReader; + use std::fs::File; + + let ctx = SessionContext::new(); + let tmp_dir = tempdir()?; + let table_path = tmp_dir.path().join("sorted_table"); + std::fs::create_dir_all(&table_path)?; + + // Create external table with ordering + let create_table_sql = format!( + "CREATE EXTERNAL TABLE sorted_data (a INT, b VARCHAR) \ + STORED AS PARQUET \ + LOCATION '{}' \ + WITH ORDER (a ASC NULLS FIRST, b DESC NULLS LAST)", + table_path.display() + ); + ctx.sql(&create_table_sql).await?; + + // Insert sorted data + ctx.sql("INSERT INTO sorted_data VALUES (1, 'x'), (2, 'y'), (3, 'z')") + .await? + .collect() + .await?; + + // Find the parquet file that was written + let parquet_files: Vec<_> = std::fs::read_dir(&table_path)? + .filter_map(|e| e.ok()) + .filter(|e| e.path().extension().is_some_and(|ext| ext == "parquet")) + .collect(); + + assert!( + !parquet_files.is_empty(), + "Expected at least one parquet file in {}", + table_path.display() + ); + + // Read the parquet file and verify sorting_columns metadata + let file = File::open(parquet_files[0].path())?; + let reader = SerializedFileReader::new(file)?; + let metadata = reader.metadata(); + + // Check that row group has sorting_columns + let row_group = metadata.row_group(0); + let sorting_columns = row_group.sorting_columns(); + + assert!( + sorting_columns.is_some(), + "Expected sorting_columns in row group metadata" + ); + let sorting = sorting_columns.unwrap(); + assert_eq!(sorting.len(), 2, "Expected 2 sorting columns"); + + // First column: a ASC NULLS FIRST (column_idx = 0) + assert_eq!(sorting[0].column_idx, 0, "First sort column should be 'a'"); + assert!( + !sorting[0].descending, + "First column should be ASC (descending=false)" + ); + assert!( + 
sorting[0].nulls_first, + "First column should have NULLS FIRST" + ); + + // Second column: b DESC NULLS LAST (column_idx = 1) + assert_eq!(sorting[1].column_idx, 1, "Second sort column should be 'b'"); + assert!( + sorting[1].descending, + "Second column should be DESC (descending=true)" + ); + assert!( + !sorting[1].nulls_first, + "Second column should have NULLS LAST" + ); + + Ok(()) +} diff --git a/datafusion/core/tests/parquet/page_pruning.rs b/datafusion/core/tests/parquet/page_pruning.rs index 9da879a32f6b5..a41803191ad05 100644 --- a/datafusion/core/tests/parquet/page_pruning.rs +++ b/datafusion/core/tests/parquet/page_pruning.rs @@ -20,27 +20,35 @@ use std::sync::Arc; use crate::parquet::Unit::Page; use crate::parquet::{ContextWithParquet, Scenario}; -use datafusion::datasource::file_format::parquet::ParquetFormat; +use arrow::array::{Int32Array, RecordBatch}; +use arrow::datatypes::{DataType, Field, Schema}; use datafusion::datasource::file_format::FileFormat; +use datafusion::datasource::file_format::parquet::ParquetFormat; use datafusion::datasource::listing::PartitionedFile; use datafusion::datasource::object_store::ObjectStoreUrl; use datafusion::datasource::physical_plan::ParquetSource; use datafusion::datasource::source::DataSourceExec; use datafusion::execution::context::SessionState; -use datafusion::physical_plan::metrics::MetricValue; use datafusion::physical_plan::ExecutionPlan; -use datafusion::prelude::SessionContext; +use datafusion::physical_plan::metrics::MetricValue; +use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_common::{ScalarValue, ToDFSchema}; use datafusion_expr::execution_props::ExecutionProps; -use datafusion_expr::{col, lit, Expr}; +use datafusion_expr::{Expr, col, lit}; use datafusion_physical_expr::create_physical_expr; use datafusion_datasource::file_scan_config::FileScanConfigBuilder; use futures::StreamExt; -use object_store::path::Path; use object_store::ObjectMeta; - -async fn 
get_parquet_exec(state: &SessionState, filter: Expr) -> DataSourceExec { +use object_store::path::Path; +use parquet::arrow::ArrowWriter; +use parquet::file::properties::WriterProperties; + +async fn get_parquet_exec( + state: &SessionState, + filter: Expr, + pushdown_filters: bool, +) -> DataSourceExec { let object_store_url = ObjectStoreUrl::local_filesystem(); let store = state.runtime_env().object_store(&object_store_url).unwrap(); @@ -62,63 +70,63 @@ async fn get_parquet_exec(state: &SessionState, filter: Expr) -> DataSourceExec .await .unwrap(); - let partitioned_file = PartitionedFile { - object_meta: meta, - partition_values: vec![], - range: None, - statistics: None, - extensions: None, - metadata_size_hint: None, - }; + let partitioned_file = PartitionedFile::new_from_meta(meta); let df_schema = schema.clone().to_dfschema().unwrap(); let execution_props = ExecutionProps::new(); let predicate = create_physical_expr(&filter, &df_schema, &execution_props).unwrap(); let source = Arc::new( - ParquetSource::default() + ParquetSource::new(schema.clone()) .with_predicate(predicate) - .with_enable_page_index(true), + .with_enable_page_index(true) + .with_pushdown_filters(pushdown_filters), ); - let base_config = FileScanConfigBuilder::new(object_store_url, schema, source) + let base_config = FileScanConfigBuilder::new(object_store_url, source) .with_file(partitioned_file) .build(); DataSourceExec::new(Arc::new(base_config)) } +async fn get_filter_results( + state: &SessionState, + filter: Expr, + pushdown_filters: bool, +) -> Vec { + let parquet_exec = get_parquet_exec(state, filter, pushdown_filters).await; + let task_ctx = state.task_ctx(); + let mut results = parquet_exec.execute(0, task_ctx.clone()).unwrap(); + let mut batches = Vec::new(); + while let Some(Ok(batch)) = results.next().await { + batches.push(batch); + } + batches +} + #[tokio::test] async fn page_index_filter_one_col() { let session_ctx = SessionContext::new(); let state = session_ctx.state(); 
- let task_ctx = state.task_ctx(); // 1.create filter month == 1; let filter = col("month").eq(lit(1_i32)); - let parquet_exec = get_parquet_exec(&state, filter).await; - - let mut results = parquet_exec.execute(0, task_ctx.clone()).unwrap(); - - let batch = results.next().await.unwrap().unwrap(); - + let batches = get_filter_results(&state, filter.clone(), false).await; // `month = 1` from the page index should create below RowSelection // vec.push(RowSelector::select(312)); // vec.push(RowSelector::skip(3330)); // vec.push(RowSelector::select(339)); // vec.push(RowSelector::skip(3319)); // total 651 row - assert_eq!(batch.num_rows(), 651); + assert_eq!(batches[0].num_rows(), 651); + + let batches = get_filter_results(&state, filter, true).await; + assert_eq!(batches[0].num_rows(), 620); // 2. create filter month == 1 or month == 2; let filter = col("month").eq(lit(1_i32)).or(col("month").eq(lit(2_i32))); - - let parquet_exec = get_parquet_exec(&state, filter).await; - - let mut results = parquet_exec.execute(0, task_ctx.clone()).unwrap(); - - let batch = results.next().await.unwrap().unwrap(); - + let batches = get_filter_results(&state, filter.clone(), false).await; // `month = 1` or `month = 2` from the page index should create below RowSelection // vec.push(RowSelector::select(312)); // vec.push(RowSelector::skip(900)); @@ -128,95 +136,78 @@ async fn page_index_filter_one_col() { // vec.push(RowSelector::skip(873)); // vec.push(RowSelector::select(318)); // vec.push(RowSelector::skip(2128)); - assert_eq!(batch.num_rows(), 1281); + assert_eq!(batches[0].num_rows(), 1281); + + let batches = get_filter_results(&state, filter, true).await; + assert_eq!(batches[0].num_rows(), 1180); // 3. 
create filter month == 1 and month == 12; let filter = col("month") .eq(lit(1_i32)) .and(col("month").eq(lit(12_i32))); + let batches = get_filter_results(&state, filter.clone(), false).await; + assert!(batches.is_empty()); - let parquet_exec = get_parquet_exec(&state, filter).await; - - let mut results = parquet_exec.execute(0, task_ctx.clone()).unwrap(); - - let batch = results.next().await; - - assert!(batch.is_none()); + let batches = get_filter_results(&state, filter, true).await; + assert!(batches.is_empty()); // 4.create filter 0 < month < 2 ; let filter = col("month").gt(lit(0_i32)).and(col("month").lt(lit(2_i32))); - - let parquet_exec = get_parquet_exec(&state, filter).await; - - let mut results = parquet_exec.execute(0, task_ctx.clone()).unwrap(); - - let batch = results.next().await.unwrap().unwrap(); - + let batches = get_filter_results(&state, filter.clone(), false).await; // should same with `month = 1` - assert_eq!(batch.num_rows(), 651); - - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + assert_eq!(batches[0].num_rows(), 651); + let batches = get_filter_results(&state, filter, true).await; + assert_eq!(batches[0].num_rows(), 620); // 5.create filter date_string_col == "01/01/09"`; // Note this test doesn't apply type coercion so the literal must match the actual view type let filter = col("date_string_col").eq(lit(ScalarValue::new_utf8view("01/01/09"))); - let parquet_exec = get_parquet_exec(&state, filter).await; - let mut results = parquet_exec.execute(0, task_ctx.clone()).unwrap(); - let batch = results.next().await.unwrap().unwrap(); + let batches = get_filter_results(&state, filter.clone(), false).await; + assert_eq!(batches[0].num_rows(), 14); // there should only two pages match the filter // min max // page-20 0 01/01/09 01/02/09 // page-21 0 01/01/09 01/01/09 // each 7 rows - assert_eq!(batch.num_rows(), 14); + assert_eq!(batches[0].num_rows(), 14); + let batches = get_filter_results(&state, filter, 
true).await; + assert_eq!(batches[0].num_rows(), 10); } #[tokio::test] async fn page_index_filter_multi_col() { let session_ctx = SessionContext::new(); let state = session_ctx.state(); - let task_ctx = session_ctx.task_ctx(); // create filter month == 1 and year = 2009; let filter = col("month").eq(lit(1_i32)).and(col("year").eq(lit(2009))); - - let parquet_exec = get_parquet_exec(&state, filter).await; - - let mut results = parquet_exec.execute(0, task_ctx.clone()).unwrap(); - - let batch = results.next().await.unwrap().unwrap(); - + let batches = get_filter_results(&state, filter.clone(), false).await; // `year = 2009` from the page index should create below RowSelection // vec.push(RowSelector::select(3663)); // vec.push(RowSelector::skip(3642)); // combine with `month = 1` total 333 row - assert_eq!(batch.num_rows(), 333); + assert_eq!(batches[0].num_rows(), 333); + let batches = get_filter_results(&state, filter, true).await; + assert_eq!(batches[0].num_rows(), 310); // create filter (year = 2009 or id = 1) and month = 1; // this should only use `month = 1` to evaluate the page index. 
let filter = col("month") .eq(lit(1_i32)) .and(col("year").eq(lit(2009)).or(col("id").eq(lit(1)))); - - let parquet_exec = get_parquet_exec(&state, filter).await; - - let mut results = parquet_exec.execute(0, task_ctx.clone()).unwrap(); - - let batch = results.next().await.unwrap().unwrap(); - assert_eq!(batch.num_rows(), 651); + let batches = get_filter_results(&state, filter.clone(), false).await; + assert_eq!(batches[0].num_rows(), 651); + let batches = get_filter_results(&state, filter, true).await; + assert_eq!(batches[0].num_rows(), 310); // create filter (year = 2009 or id = 1) // this filter use two columns will not push down let filter = col("year").eq(lit(2009)).or(col("id").eq(lit(1))); - - let parquet_exec = get_parquet_exec(&state, filter).await; - - let mut results = parquet_exec.execute(0, task_ctx.clone()).unwrap(); - - let batch = results.next().await.unwrap().unwrap(); - assert_eq!(batch.num_rows(), 7300); + let batches = get_filter_results(&state, filter.clone(), false).await; + assert_eq!(batches[0].num_rows(), 7300); + let batches = get_filter_results(&state, filter, true).await; + assert_eq!(batches[0].num_rows(), 3650); // create filter (year = 2009 and id = 1) or (year = 2010) // this filter use two columns will not push down @@ -226,13 +217,10 @@ async fn page_index_filter_multi_col() { .eq(lit(2009)) .and(col("id").eq(lit(1))) .or(col("year").eq(lit(2010))); - - let parquet_exec = get_parquet_exec(&state, filter).await; - - let mut results = parquet_exec.execute(0, task_ctx.clone()).unwrap(); - - let batch = results.next().await.unwrap().unwrap(); - assert_eq!(batch.num_rows(), 7300); + let batches = get_filter_results(&state, filter.clone(), false).await; + assert_eq!(batches[0].num_rows(), 7300); + let batches = get_filter_results(&state, filter, true).await; + assert_eq!(batches[0].num_rows(), 3651); } async fn test_prune( @@ -378,281 +366,367 @@ async fn prune_date64() { } macro_rules! int_tests { - ($bits:expr) => { - paste::item! 
{ - #[tokio::test] - // null count min max - // page-0 0 -5 -1 - // page-1 0 -4 0 - // page-2 0 0 4 - // page-3 0 5 9 - async fn []() { - test_prune( - Scenario::Int, - &format!("SELECT * FROM t where i{} < 1", $bits), - Some(0), - Some(5), - 11, - 5, - ) - .await; - // result of sql "SELECT * FROM t where i < 1" is same as - // "SELECT * FROM t where -i > -1" - test_prune( - Scenario::Int, - &format!("SELECT * FROM t where -i{} > -1", $bits), - Some(0), - Some(5), - 11, - 5, - ) - .await; - } - - #[tokio::test] - async fn []() { - test_prune( - Scenario::Int, - &format!("SELECT * FROM t where i{} > 8", $bits), - Some(0), - Some(15), - 1, - 5, - ) - .await; - - test_prune( - Scenario::Int, - &format!("SELECT * FROM t where -i{} < -8", $bits), - Some(0), - Some(15), - 1, - 5, - ) - .await; - } - - #[tokio::test] - async fn []() { - test_prune( - Scenario::Int, - &format!("SELECT * FROM t where i{} = 1", $bits), - Some(0), - Some(15), - 1, - 5 - ) - .await; - } - #[tokio::test] - async fn []() { - test_prune( - Scenario::Int, - &format!("SELECT * FROM t where abs(i{}) = 1 and i{} = 1", $bits, $bits), - Some(0), - Some(15), - 1, - 5 - ) - .await; - } - - #[tokio::test] - async fn []() { - test_prune( - Scenario::Int, - &format!("SELECT * FROM t where abs(i{}) = 1", $bits), - Some(0), - Some(0), - 3, - 5 - ) - .await; - } - - #[tokio::test] - async fn []() { - test_prune( - Scenario::Int, - &format!("SELECT * FROM t where i{}+1 = 1", $bits), - Some(0), - Some(0), - 2, - 5 - ) - .await; - } - - #[tokio::test] - async fn []() { - test_prune( - Scenario::Int, - &format!("SELECT * FROM t where 1-i{} > 1", $bits), - Some(0), - Some(0), - 9, - 5 - ) - .await; - } - - #[tokio::test] - async fn []() { - // result of sql "SELECT * FROM t where in (1)" - test_prune( - Scenario::Int, - &format!("SELECT * FROM t where i{} in (1)", $bits), - Some(0), - Some(15), - 1, - 5 - ) - .await; - } - - #[tokio::test] - async fn []() { - // result of sql "SELECT * FROM t where not in (1)" 
prune nothing - test_prune( - Scenario::Int, - &format!("SELECT * FROM t where i{} not in (1)", $bits), - Some(0), - Some(0), - 19, - 5 - ) - .await; - } + ($bits:expr, $fn_lt:ident, $fn_gt:ident, $fn_eq:ident, $fn_scalar_fun_and_eq:ident, $fn_scalar_fun:ident, $fn_complex_expr:ident, $fn_complex_expr_subtract:ident, $fn_eq_in_list:ident, $fn_eq_in_list_negated:ident) => { + #[tokio::test] + // null count min max + // page-0 0 -5 -1 + // page-1 0 -4 0 + // page-2 0 0 4 + // page-3 0 5 9 + async fn $fn_lt() { + test_prune( + Scenario::Int, + &format!("SELECT * FROM t where i{} < 1", $bits), + Some(0), + Some(5), + 11, + 5, + ) + .await; + // result of sql "SELECT * FROM t where i < 1" is same as + // "SELECT * FROM t where -i > -1" + test_prune( + Scenario::Int, + &format!("SELECT * FROM t where -i{} > -1", $bits), + Some(0), + Some(5), + 11, + 5, + ) + .await; } - } + + #[tokio::test] + async fn $fn_gt() { + test_prune( + Scenario::Int, + &format!("SELECT * FROM t where i{} > 8", $bits), + Some(0), + Some(15), + 1, + 5, + ) + .await; + + test_prune( + Scenario::Int, + &format!("SELECT * FROM t where -i{} < -8", $bits), + Some(0), + Some(15), + 1, + 5, + ) + .await; + } + + #[tokio::test] + async fn $fn_eq() { + test_prune( + Scenario::Int, + &format!("SELECT * FROM t where i{} = 1", $bits), + Some(0), + Some(15), + 1, + 5, + ) + .await; + } + #[tokio::test] + async fn $fn_scalar_fun_and_eq() { + test_prune( + Scenario::Int, + &format!( + "SELECT * FROM t where abs(i{}) = 1 and i{} = 1", + $bits, $bits + ), + Some(0), + Some(15), + 1, + 5, + ) + .await; + } + + #[tokio::test] + async fn $fn_scalar_fun() { + test_prune( + Scenario::Int, + &format!("SELECT * FROM t where abs(i{}) = 1", $bits), + Some(0), + Some(0), + 3, + 5, + ) + .await; + } + + #[tokio::test] + async fn $fn_complex_expr() { + test_prune( + Scenario::Int, + &format!("SELECT * FROM t where i{}+1 = 1", $bits), + Some(0), + Some(0), + 2, + 5, + ) + .await; + } + + #[tokio::test] + async fn 
$fn_complex_expr_subtract() { + test_prune( + Scenario::Int, + &format!("SELECT * FROM t where 1-i{} > 1", $bits), + Some(0), + Some(0), + 9, + 5, + ) + .await; + } + + #[tokio::test] + async fn $fn_eq_in_list() { + // result of sql "SELECT * FROM t where in (1)" + test_prune( + Scenario::Int, + &format!("SELECT * FROM t where i{} in (1)", $bits), + Some(0), + Some(15), + 1, + 5, + ) + .await; + } + + #[tokio::test] + async fn $fn_eq_in_list_negated() { + // result of sql "SELECT * FROM t where not in (1)" prune nothing + test_prune( + Scenario::Int, + &format!("SELECT * FROM t where i{} not in (1)", $bits), + Some(0), + Some(0), + 19, + 5, + ) + .await; + } + }; } -int_tests!(8); -int_tests!(16); -int_tests!(32); -int_tests!(64); +int_tests!( + 8, + prune_int8_lt, + prune_int8_gt, + prune_int8_eq, + prune_int8_scalar_fun_and_eq, + prune_int8_scalar_fun, + prune_int8_complex_expr, + prune_int8_complex_expr_subtract, + prune_int8_eq_in_list, + prune_int8_eq_in_list_negated +); +int_tests!( + 16, + prune_int16_lt, + prune_int16_gt, + prune_int16_eq, + prune_int16_scalar_fun_and_eq, + prune_int16_scalar_fun, + prune_int16_complex_expr, + prune_int16_complex_expr_subtract, + prune_int16_eq_in_list, + prune_int16_eq_in_list_negated +); +int_tests!( + 32, + prune_int32_lt, + prune_int32_gt, + prune_int32_eq, + prune_int32_scalar_fun_and_eq, + prune_int32_scalar_fun, + prune_int32_complex_expr, + prune_int32_complex_expr_subtract, + prune_int32_eq_in_list, + prune_int32_eq_in_list_negated +); +int_tests!( + 64, + prune_int64_lt, + prune_int64_gt, + prune_int64_eq, + prune_int64_scalar_fun_and_eq, + prune_int64_scalar_fun, + prune_int64_complex_expr, + prune_int64_complex_expr_subtract, + prune_int64_eq_in_list, + prune_int64_eq_in_list_negated +); macro_rules! uint_tests { - ($bits:expr) => { - paste::item! 
{ - #[tokio::test] - // null count min max - // page-0 0 0 4 - // page-1 0 1 5 - // page-2 0 5 9 - // page-3 0 250 254 - async fn []() { - test_prune( - Scenario::UInt, - &format!("SELECT * FROM t where u{} < 6", $bits), - Some(0), - Some(5), - 11, - 5 - ) - .await; - } - - #[tokio::test] - async fn []() { - test_prune( - Scenario::UInt, - &format!("SELECT * FROM t where u{} > 253", $bits), - Some(0), - Some(15), - 1, - 5 - ) - .await; - } - - #[tokio::test] - async fn []() { - test_prune( - Scenario::UInt, - &format!("SELECT * FROM t where u{} = 6", $bits), - Some(0), - Some(15), - 1, - 5 - ) - .await; - } - - #[tokio::test] - async fn []() { - test_prune( - Scenario::UInt, - &format!("SELECT * FROM t where power(u{}, 2) = 36 and u{} = 6", $bits, $bits), - Some(0), - Some(15), - 1, - 5 - ) - .await; - } - - #[tokio::test] - async fn []() { - test_prune( - Scenario::UInt, - &format!("SELECT * FROM t where power(u{}, 2) = 25", $bits), - Some(0), - Some(0), - 2, - 5 - ) - .await; - } - - #[tokio::test] - async fn []() { - test_prune( - Scenario::UInt, - &format!("SELECT * FROM t where u{}+1 = 6", $bits), - Some(0), - Some(0), - 2, - 5 - ) - .await; - } - - #[tokio::test] - async fn []() { - // result of sql "SELECT * FROM t where in (1)" - test_prune( - Scenario::UInt, - &format!("SELECT * FROM t where u{} in (6)", $bits), - Some(0), - Some(15), - 1, - 5 - ) - .await; - } - - #[tokio::test] - async fn []() { - // result of sql "SELECT * FROM t where not in (6)" prune nothing - test_prune( - Scenario::UInt, - &format!("SELECT * FROM t where u{} not in (6)", $bits), - Some(0), - Some(0), - 19, - 5 - ) - .await; - } + ($bits:expr, $fn_lt:ident, $fn_gt:ident, $fn_eq:ident, $fn_scalar_fun_and_eq:ident, $fn_scalar_fun:ident, $fn_complex_expr:ident, $fn_eq_in_list:ident, $fn_eq_in_list_negated:ident) => { + #[tokio::test] + // null count min max + // page-0 0 0 4 + // page-1 0 1 5 + // page-2 0 5 9 + // page-3 0 250 254 + async fn $fn_lt() { + test_prune( + Scenario::UInt, 
+ &format!("SELECT * FROM t where u{} < 6", $bits), + Some(0), + Some(5), + 11, + 5, + ) + .await; } - } + + #[tokio::test] + async fn $fn_gt() { + test_prune( + Scenario::UInt, + &format!("SELECT * FROM t where u{} > 253", $bits), + Some(0), + Some(15), + 1, + 5, + ) + .await; + } + + #[tokio::test] + async fn $fn_eq() { + test_prune( + Scenario::UInt, + &format!("SELECT * FROM t where u{} = 6", $bits), + Some(0), + Some(15), + 1, + 5, + ) + .await; + } + + #[tokio::test] + async fn $fn_scalar_fun_and_eq() { + test_prune( + Scenario::UInt, + &format!( + "SELECT * FROM t where power(u{}, 2) = 36 and u{} = 6", + $bits, $bits + ), + Some(0), + Some(15), + 1, + 5, + ) + .await; + } + + #[tokio::test] + async fn $fn_scalar_fun() { + test_prune( + Scenario::UInt, + &format!("SELECT * FROM t where power(u{}, 2) = 25", $bits), + Some(0), + Some(0), + 2, + 5, + ) + .await; + } + + #[tokio::test] + async fn $fn_complex_expr() { + test_prune( + Scenario::UInt, + &format!("SELECT * FROM t where u{}+1 = 6", $bits), + Some(0), + Some(0), + 2, + 5, + ) + .await; + } + + #[tokio::test] + async fn $fn_eq_in_list() { + // result of sql "SELECT * FROM t where in (1)" + test_prune( + Scenario::UInt, + &format!("SELECT * FROM t where u{} in (6)", $bits), + Some(0), + Some(15), + 1, + 5, + ) + .await; + } + + #[tokio::test] + async fn $fn_eq_in_list_negated() { + // result of sql "SELECT * FROM t where not in (6)" prune nothing + test_prune( + Scenario::UInt, + &format!("SELECT * FROM t where u{} not in (6)", $bits), + Some(0), + Some(0), + 19, + 5, + ) + .await; + } + }; } -uint_tests!(8); -uint_tests!(16); -uint_tests!(32); -uint_tests!(64); +uint_tests!( + 8, + prune_uint8_lt, + prune_uint8_gt, + prune_uint8_eq, + prune_uint8_scalar_fun_and_eq, + prune_uint8_scalar_fun, + prune_uint8_complex_expr, + prune_uint8_eq_in_list, + prune_uint8_eq_in_list_negated +); +uint_tests!( + 16, + prune_uint16_lt, + prune_uint16_gt, + prune_uint16_eq, + prune_uint16_scalar_fun_and_eq, + 
prune_uint16_scalar_fun, + prune_uint16_complex_expr, + prune_uint16_eq_in_list, + prune_uint16_eq_in_list_negated +); +uint_tests!( + 32, + prune_uint32_lt, + prune_uint32_gt, + prune_uint32_eq, + prune_uint32_scalar_fun_and_eq, + prune_uint32_scalar_fun, + prune_uint32_complex_expr, + prune_uint32_eq_in_list, + prune_uint32_eq_in_list_negated +); +uint_tests!( + 64, + prune_uint64_lt, + prune_uint64_gt, + prune_uint64_eq, + prune_uint64_scalar_fun_and_eq, + prune_uint64_scalar_fun, + prune_uint64_complex_expr, + prune_uint64_eq_in_list, + prune_uint64_eq_in_list_negated +); #[tokio::test] // null count min max @@ -911,8 +985,8 @@ async fn without_pushdown_filter() { ) .unwrap(); - // Without filter will not read pageIndex. - assert!(bytes_scanned_with_filter > bytes_scanned_without_filter); + // Same amount of bytes are scanned when defaulting to cache parquet metadata + assert_eq!(bytes_scanned_with_filter, bytes_scanned_without_filter); } #[tokio::test] @@ -976,3 +1050,56 @@ fn cast_count_metric(metric: MetricValue) -> Option { _ => None, } } + +#[tokio::test] +async fn test_parquet_opener_without_page_index() { + // Defines a simple schema and batch + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + ) + .unwrap(); + + // Create a temp file + let file = tempfile::Builder::new() + .suffix(".parquet") + .tempfile() + .unwrap(); + let path = file.path().to_str().unwrap().to_string(); + + // Write parquet WITHOUT page index + // The default WriterProperties does not write page index, but we set it explicitly + // to be robust against future changes in defaults as requested by reviewers. 
+ let props = WriterProperties::builder() + .set_statistics_enabled(parquet::file::properties::EnabledStatistics::None) + .build(); + + let file_fs = std::fs::File::create(&path).unwrap(); + let mut writer = ArrowWriter::try_new(file_fs, batch.schema(), Some(props)).unwrap(); + writer.write(&batch).unwrap(); + writer.close().unwrap(); + + // Setup SessionContext with PageIndex enabled + // This triggers the ParquetOpener to try and load page index if available + let config = SessionConfig::new().with_parquet_page_index_pruning(true); + + let ctx = SessionContext::new_with_config(config); + + // Register the table + ctx.register_parquet("t", &path, Default::default()) + .await + .unwrap(); + + // Query the table + // If the bug exists, this might fail because Opener tries to load PageIndex forcefully + let df = ctx.sql("SELECT * FROM t").await.unwrap(); + let batches = df + .collect() + .await + .expect("Failed to read parquet file without page index"); + + // We expect this to succeed, but currently it might fail + assert_eq!(batches.len(), 1); + assert_eq!(batches[0].num_rows(), 3); +} diff --git a/datafusion/core/tests/parquet/row_group_pruning.rs b/datafusion/core/tests/parquet/row_group_pruning.rs index 5a85f47c015a9..3ec3541af977a 100644 --- a/datafusion/core/tests/parquet/row_group_pruning.rs +++ b/datafusion/core/tests/parquet/row_group_pruning.rs @@ -18,8 +18,12 @@ //! This file contains an end to end test of parquet pruning. It writes //! data into a parquet file and then verifies row groups are pruned as //! expected. 
+use std::sync::Arc; + +use arrow::array::{ArrayRef, Int32Array, RecordBatch}; +use arrow_schema::{DataType, Field, Schema}; use datafusion::prelude::SessionConfig; -use datafusion_common::ScalarValue; +use datafusion_common::{DataFusionError, ScalarValue}; use itertools::Itertools; use crate::parquet::Unit::RowGroup; @@ -30,10 +34,13 @@ struct RowGroupPruningTest { query: String, expected_errors: Option, expected_row_group_matched_by_statistics: Option, + expected_row_group_fully_matched_by_statistics: Option, expected_row_group_pruned_by_statistics: Option, + expected_files_pruned_by_statistics: Option, expected_row_group_matched_by_bloom_filter: Option, expected_row_group_pruned_by_bloom_filter: Option, - expected_results: usize, + expected_limit_pruned_row_groups: Option, + expected_rows: usize, } impl RowGroupPruningTest { // Start building the test configuration @@ -44,9 +51,12 @@ impl RowGroupPruningTest { expected_errors: None, expected_row_group_matched_by_statistics: None, expected_row_group_pruned_by_statistics: None, + expected_row_group_fully_matched_by_statistics: None, + expected_files_pruned_by_statistics: None, expected_row_group_matched_by_bloom_filter: None, expected_row_group_pruned_by_bloom_filter: None, - expected_results: 0, + expected_limit_pruned_row_groups: None, + expected_rows: 0, } } @@ -74,12 +84,26 @@ impl RowGroupPruningTest { self } + // Set the expected fully matched row groups by statistics + fn with_fully_matched_by_stats( + mut self, + fully_matched_by_stats: Option, + ) -> Self { + self.expected_row_group_fully_matched_by_statistics = fully_matched_by_stats; + self + } + // Set the expected pruned row groups by statistics fn with_pruned_by_stats(mut self, pruned_by_stats: Option) -> Self { self.expected_row_group_pruned_by_statistics = pruned_by_stats; self } + fn with_pruned_files(mut self, pruned_files: Option) -> Self { + self.expected_files_pruned_by_statistics = pruned_files; + self + } + // Set the expected matched row 
groups by bloom filter fn with_matched_by_bloom_filter(mut self, matched_by_bf: Option) -> Self { self.expected_row_group_matched_by_bloom_filter = matched_by_bf; @@ -92,9 +116,14 @@ impl RowGroupPruningTest { self } - // Set the expected rows for the test + fn with_limit_pruned_row_groups(mut self, pruned_by_limit: Option) -> Self { + self.expected_limit_pruned_row_groups = pruned_by_limit; + self + } + + /// Set the number of expected rows from the output of this test fn with_expected_rows(mut self, rows: usize) -> Self { - self.expected_results = rows; + self.expected_rows = rows; self } @@ -122,19 +151,86 @@ impl RowGroupPruningTest { "mismatched row_groups_pruned_statistics", ); assert_eq!( - output.row_groups_matched_bloom_filter(), + output.files_ranges_pruned_statistics(), + self.expected_files_pruned_by_statistics, + "mismatched files_ranges_pruned_statistics", + ); + let bloom_filter_metrics = output.row_groups_bloom_filter(); + assert_eq!( + bloom_filter_metrics.as_ref().map(|pm| pm.total_matched()), self.expected_row_group_matched_by_bloom_filter, "mismatched row_groups_matched_bloom_filter", ); assert_eq!( - output.row_groups_pruned_bloom_filter(), + bloom_filter_metrics.map(|pm| pm.total_pruned()), self.expected_row_group_pruned_by_bloom_filter, "mismatched row_groups_pruned_bloom_filter", ); + assert_eq!( output.result_rows, - self.expected_results, - "mismatched expected rows: {}", + self.expected_rows, + "Expected {} rows, got {}: {}", + output.result_rows, + self.expected_rows, + output.description(), + ); + } + + // Execute the test with the current configuration + async fn test_row_group_prune_with_custom_data( + self, + schema: Arc, + batches: Vec, + max_row_per_group: usize, + ) { + let output = ContextWithParquet::with_custom_data( + self.scenario, + RowGroup(max_row_per_group), + schema, + batches, + ) + .await + .query(&self.query) + .await; + + println!("{}", output.description()); + assert_eq!( + output.predicate_evaluation_errors(), + 
self.expected_errors, + "mismatched predicate_evaluation error" + ); + assert_eq!( + output.row_groups_matched_statistics(), + self.expected_row_group_matched_by_statistics, + "mismatched row_groups_matched_statistics", + ); + assert_eq!( + output.row_groups_fully_matched_statistics(), + self.expected_row_group_fully_matched_by_statistics, + "mismatched row_groups_fully_matched_statistics", + ); + assert_eq!( + output.row_groups_pruned_statistics(), + self.expected_row_group_pruned_by_statistics, + "mismatched row_groups_pruned_statistics", + ); + assert_eq!( + output.files_ranges_pruned_statistics(), + self.expected_files_pruned_by_statistics, + "mismatched files_ranges_pruned_statistics", + ); + assert_eq!( + output.limit_pruned_row_groups(), + self.expected_limit_pruned_row_groups, + "mismatched limit_pruned_row_groups", + ); + assert_eq!( + output.result_rows, + self.expected_rows, + "Expected {} rows, got {}: {}", + output.result_rows, + self.expected_rows, output.description(), ); } @@ -148,7 +244,8 @@ async fn prune_timestamps_nanos() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(1)) - .with_matched_by_bloom_filter(Some(0)) + .with_pruned_files(Some(0)) + .with_matched_by_bloom_filter(Some(3)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(10) .test_row_group_prune() @@ -165,7 +262,8 @@ async fn prune_timestamps_micros() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(1)) - .with_matched_by_bloom_filter(Some(0)) + .with_pruned_files(Some(0)) + .with_matched_by_bloom_filter(Some(3)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(10) .test_row_group_prune() @@ -182,7 +280,8 @@ async fn prune_timestamps_millis() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(1)) - .with_matched_by_bloom_filter(Some(0)) + .with_pruned_files(Some(0)) + .with_matched_by_bloom_filter(Some(3)) 
.with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(10) .test_row_group_prune() @@ -199,7 +298,8 @@ async fn prune_timestamps_seconds() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(1)) - .with_matched_by_bloom_filter(Some(0)) + .with_pruned_files(Some(0)) + .with_matched_by_bloom_filter(Some(3)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(10) .test_row_group_prune() @@ -214,7 +314,8 @@ async fn prune_date32() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(3)) - .with_matched_by_bloom_filter(Some(0)) + .with_pruned_files(Some(0)) + .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(1) .test_row_group_prune() @@ -243,8 +344,9 @@ async fn prune_date64() { println!("{}", output.description()); // This should prune out groups without error assert_eq!(output.predicate_evaluation_errors(), Some(0)); - assert_eq!(output.row_groups_matched(), Some(1)); - assert_eq!(output.row_groups_pruned(), Some(3)); + // 'dates' table has 4 row groups, and only the first one is matched by the predicate + assert_eq!(output.row_groups_matched_statistics(), Some(1)); + assert_eq!(output.row_groups_pruned_statistics(), Some(3)); assert_eq!(output.result_rows, 1, "{}", output.description()); } @@ -256,7 +358,8 @@ async fn prune_disabled() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(1)) - .with_matched_by_bloom_filter(Some(0)) + .with_pruned_files(Some(0)) + .with_matched_by_bloom_filter(Some(3)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(10) .test_row_group_prune() @@ -267,16 +370,21 @@ async fn prune_disabled() { let expected_rows = 10; let config = SessionConfig::new().with_parquet_pruning(false); - let output = - ContextWithParquet::with_config(Scenario::Timestamps, RowGroup(5), config) - .await - .query(query) - .await; + let output = ContextWithParquet::with_config( + 
Scenario::Timestamps, + RowGroup(5), + config, + None, + None, + ) + .await + .query(query) + .await; println!("{}", output.description()); // This should not prune any assert_eq!(output.predicate_evaluation_errors(), Some(0)); - assert_eq!(output.row_groups_matched(), Some(0)); + assert_eq!(output.row_groups_matched(), Some(4)); assert_eq!(output.row_groups_pruned(), Some(0)); assert_eq!( output.result_rows, @@ -291,303 +399,365 @@ async fn prune_disabled() { // https://github.com/apache/datafusion/issues/9779 bug so that tests pass // if and only if Bloom filters on Int8 and Int16 columns are still buggy. macro_rules! int_tests { - ($bits:expr) => { - paste::item! { - #[tokio::test] - async fn []() { - RowGroupPruningTest::new() - .with_scenario(Scenario::Int) - .with_query(&format!("SELECT * FROM t where i{} < 1", $bits)) - .with_expected_errors(Some(0)) - .with_matched_by_stats(Some(3)) - .with_pruned_by_stats(Some(1)) - .with_matched_by_bloom_filter(Some(0)) - .with_pruned_by_bloom_filter(Some(0)) - .with_expected_rows(11) - .test_row_group_prune() - .await; - - // result of sql "SELECT * FROM t where i < 1" is same as - // "SELECT * FROM t where -i > -1" - RowGroupPruningTest::new() - .with_scenario(Scenario::Int) - .with_query(&format!("SELECT * FROM t where -i{} > -1", $bits)) - .with_expected_errors(Some(0)) - .with_matched_by_stats(Some(3)) - .with_pruned_by_stats(Some(1)) - .with_matched_by_bloom_filter(Some(0)) - .with_pruned_by_bloom_filter(Some(0)) - .with_expected_rows(11) - .test_row_group_prune() - .await; - } - - #[tokio::test] - async fn []() { - RowGroupPruningTest::new() - .with_scenario(Scenario::Int) - .with_query(&format!("SELECT * FROM t where i{} = 1", $bits)) - .with_expected_errors(Some(0)) - .with_matched_by_stats(Some(1)) - .with_pruned_by_stats(Some(3)) - .with_matched_by_bloom_filter(Some(1)) - .with_pruned_by_bloom_filter(Some(0)) - .with_expected_rows(1) - .test_row_group_prune() - .await; - } - #[tokio::test] - async fn []() { - 
RowGroupPruningTest::new() - .with_scenario(Scenario::Int) - .with_query(&format!("SELECT * FROM t where abs(i{}) = 1 and i{} = 1", $bits, $bits)) - .with_expected_errors(Some(0)) - .with_matched_by_stats(Some(1)) - .with_pruned_by_stats(Some(3)) - .with_matched_by_bloom_filter(Some(1)) - .with_pruned_by_bloom_filter(Some(0)) - .with_expected_rows(1) - .test_row_group_prune() - .await; - } - - #[tokio::test] - async fn []() { - RowGroupPruningTest::new() - .with_scenario(Scenario::Int) - .with_query(&format!("SELECT * FROM t where abs(i{}) = 1", $bits)) - .with_expected_errors(Some(0)) - .with_matched_by_stats(Some(0)) - .with_pruned_by_stats(Some(0)) - .with_matched_by_bloom_filter(Some(0)) - .with_pruned_by_bloom_filter(Some(0)) - .with_expected_rows(3) - .test_row_group_prune() - .await; - } - - #[tokio::test] - async fn []() { - RowGroupPruningTest::new() - .with_scenario(Scenario::Int) - .with_query(&format!("SELECT * FROM t where i{}+1 = 1", $bits)) - .with_expected_errors(Some(0)) - .with_matched_by_stats(Some(0)) - .with_pruned_by_stats(Some(0)) - .with_matched_by_bloom_filter(Some(0)) - .with_pruned_by_bloom_filter(Some(0)) - .with_expected_rows(2) - .test_row_group_prune() - .await; - } - - #[tokio::test] - async fn []() { - RowGroupPruningTest::new() - .with_scenario(Scenario::Int) - .with_query(&format!("SELECT * FROM t where 1-i{} > 1", $bits)) - .with_expected_errors(Some(0)) - .with_matched_by_stats(Some(0)) - .with_pruned_by_stats(Some(0)) - .with_matched_by_bloom_filter(Some(0)) - .with_pruned_by_bloom_filter(Some(0)) - .with_expected_rows(9) - .test_row_group_prune() - .await; - } - - #[tokio::test] - async fn []() { - // result of sql "SELECT * FROM t where in (1)" - RowGroupPruningTest::new() - .with_scenario(Scenario::Int) - .with_query(&format!("SELECT * FROM t where i{} in (1)", $bits)) - .with_expected_errors(Some(0)) - .with_matched_by_stats(Some(1)) - .with_pruned_by_stats(Some(3)) - .with_matched_by_bloom_filter(Some(1)) - 
.with_pruned_by_bloom_filter(Some(0)) - .with_expected_rows(1) - .test_row_group_prune() - .await; - } - - #[tokio::test] - async fn []() { - // result of sql "SELECT * FROM t where in (1000)", prune all - // test whether statistics works - RowGroupPruningTest::new() - .with_scenario(Scenario::Int) - .with_query(&format!("SELECT * FROM t where i{} in (100)", $bits)) - .with_expected_errors(Some(0)) - .with_matched_by_stats(Some(0)) - .with_pruned_by_stats(Some(4)) - .with_matched_by_bloom_filter(Some(0)) - .with_pruned_by_bloom_filter(Some(0)) - .with_expected_rows(0) - .test_row_group_prune() - .await; - } - - #[tokio::test] - async fn []() { - // result of sql "SELECT * FROM t where not in (1)" prune nothing - RowGroupPruningTest::new() - .with_scenario(Scenario::Int) - .with_query(&format!("SELECT * FROM t where i{} not in (1)", $bits)) - .with_expected_errors(Some(0)) - .with_matched_by_stats(Some(4)) - .with_pruned_by_stats(Some(0)) - .with_matched_by_bloom_filter(Some(4)) - .with_pruned_by_bloom_filter(Some(0)) - .with_expected_rows(19) - .test_row_group_prune() - .await; - } + ($bits:expr, $fn_lt:ident, $fn_eq:ident, $fn_scalar_fun_and_eq:ident, $fn_scalar_fun:ident, $fn_complex_expr:ident, $fn_complex_expr_subtract:ident, $fn_eq_in_list:ident, $fn_eq_in_list_2:ident, $fn_eq_in_list_negated:ident) => { + #[tokio::test] + async fn $fn_lt() { + RowGroupPruningTest::new() + .with_scenario(Scenario::Int) + .with_query(&format!("SELECT * FROM t where i{} < 1", $bits)) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(3)) + .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) + .with_matched_by_bloom_filter(Some(3)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(11) + .test_row_group_prune() + .await; + + // result of sql "SELECT * FROM t where i < 1" is same as + // "SELECT * FROM t where -i > -1" + RowGroupPruningTest::new() + .with_scenario(Scenario::Int) + .with_query(&format!("SELECT * FROM t where -i{} > -1", $bits)) + 
.with_expected_errors(Some(0)) + .with_matched_by_stats(Some(3)) + .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) + .with_matched_by_bloom_filter(Some(3)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(11) + .test_row_group_prune() + .await; + } + + #[tokio::test] + async fn $fn_eq() { + RowGroupPruningTest::new() + .with_scenario(Scenario::Int) + .with_query(&format!("SELECT * FROM t where i{} = 1", $bits)) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(3)) + .with_pruned_files(Some(0)) + .with_matched_by_bloom_filter(Some(1)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(1) + .test_row_group_prune() + .await; + } + #[tokio::test] + async fn $fn_scalar_fun_and_eq() { + RowGroupPruningTest::new() + .with_scenario(Scenario::Int) + .with_query(&format!( + "SELECT * FROM t where abs(i{}) = 1 and i{} = 1", + $bits, $bits + )) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(3)) + .with_pruned_files(Some(0)) + .with_matched_by_bloom_filter(Some(1)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(1) + .test_row_group_prune() + .await; + } + + #[tokio::test] + async fn $fn_scalar_fun() { + RowGroupPruningTest::new() + .with_scenario(Scenario::Int) + .with_query(&format!("SELECT * FROM t where abs(i{}) = 1", $bits)) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(4)) + .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(0)) + .with_matched_by_bloom_filter(Some(4)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(3) + .test_row_group_prune() + .await; + } + + #[tokio::test] + async fn $fn_complex_expr() { + RowGroupPruningTest::new() + .with_scenario(Scenario::Int) + .with_query(&format!("SELECT * FROM t where i{}+1 = 1", $bits)) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(4)) + .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(0)) + 
.with_matched_by_bloom_filter(Some(4)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(2) + .test_row_group_prune() + .await; + } + + #[tokio::test] + async fn $fn_complex_expr_subtract() { + RowGroupPruningTest::new() + .with_scenario(Scenario::Int) + .with_query(&format!("SELECT * FROM t where 1-i{} > 1", $bits)) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(4)) + .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(0)) + .with_matched_by_bloom_filter(Some(4)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(9) + .test_row_group_prune() + .await; + } + + #[tokio::test] + async fn $fn_eq_in_list() { + // result of sql "SELECT * FROM t where in (1)" + RowGroupPruningTest::new() + .with_scenario(Scenario::Int) + .with_query(&format!("SELECT * FROM t where i{} in (1)", $bits)) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(3)) + .with_pruned_files(Some(0)) + .with_matched_by_bloom_filter(Some(1)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(1) + .test_row_group_prune() + .await; + } + + #[tokio::test] + async fn $fn_eq_in_list_2() { + // result of sql "SELECT * FROM t where in (1000)", prune all + // test whether statistics works + RowGroupPruningTest::new() + .with_scenario(Scenario::Int) + .with_query(&format!("SELECT * FROM t where i{} in (100)", $bits)) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(0)) + .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(1)) + .with_matched_by_bloom_filter(Some(0)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(0) + .test_row_group_prune() + .await; + } + + #[tokio::test] + async fn $fn_eq_in_list_negated() { + // result of sql "SELECT * FROM t where not in (1)" prune nothing + RowGroupPruningTest::new() + .with_scenario(Scenario::Int) + .with_query(&format!("SELECT * FROM t where i{} not in (1)", $bits)) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(4)) + 
.with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(0)) + .with_matched_by_bloom_filter(Some(4)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(19) + .test_row_group_prune() + .await; } }; } // int8/int16 are incorrect: https://github.com/apache/datafusion/issues/9779 -int_tests!(32); -int_tests!(64); +int_tests!( + 32, + prune_int32_lt, + prune_int32_eq, + prune_int32_scalar_fun_and_eq, + prune_int32_scalar_fun, + prune_int32_complex_expr, + prune_int32_complex_expr_subtract, + prune_int32_eq_in_list, + prune_int32_eq_in_list_2, + prune_int32_eq_in_list_negated +); +int_tests!( + 64, + prune_int64_lt, + prune_int64_eq, + prune_int64_scalar_fun_and_eq, + prune_int64_scalar_fun, + prune_int64_complex_expr, + prune_int64_complex_expr_subtract, + prune_int64_eq_in_list, + prune_int64_eq_in_list_2, + prune_int64_eq_in_list_negated +); // $bits: number of bits of the integer to test (8, 16, 32, 64) // $correct_bloom_filters: if false, replicates the // https://github.com/apache/datafusion/issues/9779 bug so that tests pass // if and only if Bloom filters on UInt8 and UInt16 columns are still buggy. macro_rules! uint_tests { - ($bits:expr) => { - paste::item! 
{ - #[tokio::test] - async fn []() { - RowGroupPruningTest::new() - .with_scenario(Scenario::UInt) - .with_query(&format!("SELECT * FROM t where u{} < 6", $bits)) - .with_expected_errors(Some(0)) - .with_matched_by_stats(Some(3)) - .with_pruned_by_stats(Some(1)) - .with_matched_by_bloom_filter(Some(0)) - .with_pruned_by_bloom_filter(Some(0)) - .with_expected_rows(11) - .test_row_group_prune() - .await; - } - - #[tokio::test] - async fn []() { - RowGroupPruningTest::new() - .with_scenario(Scenario::UInt) - .with_query(&format!("SELECT * FROM t where u{} = 6", $bits)) - .with_expected_errors(Some(0)) - .with_matched_by_stats(Some(1)) - .with_pruned_by_stats(Some(3)) - .with_matched_by_bloom_filter(Some(1)) - .with_pruned_by_bloom_filter(Some(0)) - .with_expected_rows(1) - .test_row_group_prune() - .await; - } - #[tokio::test] - async fn []() { - RowGroupPruningTest::new() - .with_scenario(Scenario::UInt) - .with_query(&format!("SELECT * FROM t where power(u{}, 2) = 36 and u{} = 6", $bits, $bits)) - .with_expected_errors(Some(0)) - .with_matched_by_stats(Some(1)) - .with_pruned_by_stats(Some(3)) - .with_matched_by_bloom_filter(Some(1)) - .with_pruned_by_bloom_filter(Some(0)) - .with_expected_rows(1) - .test_row_group_prune() - .await; - } - - #[tokio::test] - async fn []() { - RowGroupPruningTest::new() - .with_scenario(Scenario::UInt) - .with_query(&format!("SELECT * FROM t where power(u{}, 2) = 25", $bits)) - .with_expected_errors(Some(0)) - .with_matched_by_stats(Some(0)) - .with_pruned_by_stats(Some(0)) - .with_matched_by_bloom_filter(Some(0)) - .with_pruned_by_bloom_filter(Some(0)) - .with_expected_rows(2) - .test_row_group_prune() - .await; - } - - #[tokio::test] - async fn []() { - RowGroupPruningTest::new() - .with_scenario(Scenario::UInt) - .with_query(&format!("SELECT * FROM t where u{}+1 = 6", $bits)) - .with_expected_errors(Some(0)) - .with_matched_by_stats(Some(0)) - .with_pruned_by_stats(Some(0)) - .with_matched_by_bloom_filter(Some(0)) - 
.with_pruned_by_bloom_filter(Some(0)) - .with_expected_rows(2) - .test_row_group_prune() - .await; - } - - #[tokio::test] - async fn []() { - // result of sql "SELECT * FROM t where in (1)" - RowGroupPruningTest::new() - .with_scenario(Scenario::UInt) - .with_query(&format!("SELECT * FROM t where u{} in (6)", $bits)) - .with_expected_errors(Some(0)) - .with_matched_by_stats(Some(1)) - .with_pruned_by_stats(Some(3)) - .with_matched_by_bloom_filter(Some(1)) - .with_pruned_by_bloom_filter(Some(0)) - .with_expected_rows(1) - .test_row_group_prune() - .await; - } - - #[tokio::test] - async fn []() { - // result of sql "SELECT * FROM t where in (1000)", prune all - // test whether statistics works - RowGroupPruningTest::new() - .with_scenario(Scenario::UInt) - .with_query(&format!("SELECT * FROM t where u{} in (100)", $bits)) - .with_expected_errors(Some(0)) - .with_matched_by_stats(Some(0)) - .with_pruned_by_stats(Some(4)) - .with_matched_by_bloom_filter(Some(0)) - .with_pruned_by_bloom_filter(Some(0)) - .with_expected_rows(0) - .test_row_group_prune() - .await; - } - - #[tokio::test] - async fn []() { - // result of sql "SELECT * FROM t where not in (1)" prune nothing - RowGroupPruningTest::new() - .with_scenario(Scenario::UInt) - .with_query(&format!("SELECT * FROM t where u{} not in (6)", $bits)) - .with_expected_errors(Some(0)) - .with_matched_by_stats(Some(4)) - .with_pruned_by_stats(Some(0)) - .with_matched_by_bloom_filter(Some(4)) - .with_pruned_by_bloom_filter(Some(0)) - .with_expected_rows(19) - .test_row_group_prune() - .await; - } + ($bits:expr, $fn_lt:ident, $fn_eq:ident, $fn_scalar_fun_and_eq:ident, $fn_scalar_fun:ident, $fn_complex_expr:ident, $fn_eq_in_list:ident, $fn_eq_in_list_2:ident, $fn_eq_in_list_negated:ident) => { + #[tokio::test] + async fn $fn_lt() { + RowGroupPruningTest::new() + .with_scenario(Scenario::UInt) + .with_query(&format!("SELECT * FROM t where u{} < 6", $bits)) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(3)) + 
.with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) + .with_matched_by_bloom_filter(Some(3)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(11) + .test_row_group_prune() + .await; + } + + #[tokio::test] + async fn $fn_eq() { + RowGroupPruningTest::new() + .with_scenario(Scenario::UInt) + .with_query(&format!("SELECT * FROM t where u{} = 6", $bits)) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(3)) + .with_pruned_files(Some(0)) + .with_matched_by_bloom_filter(Some(1)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(1) + .test_row_group_prune() + .await; + } + #[tokio::test] + async fn $fn_scalar_fun_and_eq() { + RowGroupPruningTest::new() + .with_scenario(Scenario::UInt) + .with_query(&format!( + "SELECT * FROM t where power(u{}, 2) = 36 and u{} = 6", + $bits, $bits + )) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(3)) + .with_pruned_files(Some(0)) + .with_matched_by_bloom_filter(Some(1)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(1) + .test_row_group_prune() + .await; + } + + #[tokio::test] + async fn $fn_scalar_fun() { + RowGroupPruningTest::new() + .with_scenario(Scenario::UInt) + .with_query(&format!("SELECT * FROM t where power(u{}, 2) = 25", $bits)) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(4)) + .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(0)) + .with_matched_by_bloom_filter(Some(4)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(2) + .test_row_group_prune() + .await; + } + + #[tokio::test] + async fn $fn_complex_expr() { + RowGroupPruningTest::new() + .with_scenario(Scenario::UInt) + .with_query(&format!("SELECT * FROM t where u{}+1 = 6", $bits)) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(4)) + .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(0)) + .with_matched_by_bloom_filter(Some(4)) + .with_pruned_by_bloom_filter(Some(0)) + 
.with_expected_rows(2) + .test_row_group_prune() + .await; + } + + #[tokio::test] + async fn $fn_eq_in_list() { + // result of sql "SELECT * FROM t where in (1)" + RowGroupPruningTest::new() + .with_scenario(Scenario::UInt) + .with_query(&format!("SELECT * FROM t where u{} in (6)", $bits)) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(3)) + .with_pruned_files(Some(0)) + .with_matched_by_bloom_filter(Some(1)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(1) + .test_row_group_prune() + .await; + } + + #[tokio::test] + async fn $fn_eq_in_list_2() { + // result of sql "SELECT * FROM t where in (1000)", prune all + // test whether statistics works + RowGroupPruningTest::new() + .with_scenario(Scenario::UInt) + .with_query(&format!("SELECT * FROM t where u{} in (100)", $bits)) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(0)) + .with_pruned_by_stats(Some(4)) + .with_pruned_files(Some(0)) + .with_matched_by_bloom_filter(Some(0)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(0) + .test_row_group_prune() + .await; + } + + #[tokio::test] + async fn $fn_eq_in_list_negated() { + // result of sql "SELECT * FROM t where not in (1)" prune nothing + RowGroupPruningTest::new() + .with_scenario(Scenario::UInt) + .with_query(&format!("SELECT * FROM t where u{} not in (6)", $bits)) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(4)) + .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(0)) + .with_matched_by_bloom_filter(Some(4)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(19) + .test_row_group_prune() + .await; } }; } // uint8/uint16 are incorrect: https://github.com/apache/datafusion/issues/9779 -uint_tests!(32); -uint_tests!(64); +uint_tests!( + 32, + prune_uint32_lt, + prune_uint32_eq, + prune_uint32_scalar_fun_and_eq, + prune_uint32_scalar_fun, + prune_uint32_complex_expr, + prune_uint32_eq_in_list, + prune_uint32_eq_in_list_2, + 
prune_uint32_eq_in_list_negated +); +uint_tests!( + 64, + prune_uint64_lt, + prune_uint64_eq, + prune_uint64_scalar_fun_and_eq, + prune_uint64_scalar_fun, + prune_uint64_complex_expr, + prune_uint64_eq_in_list, + prune_uint64_eq_in_list_2, + prune_uint64_eq_in_list_negated +); #[tokio::test] async fn prune_int32_eq_large_in_list() { @@ -604,6 +774,7 @@ async fn prune_int32_eq_large_in_list() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(1)) .with_expected_rows(0) @@ -626,6 +797,7 @@ async fn prune_uint32_eq_large_in_list() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(1)) .with_expected_rows(0) @@ -641,7 +813,8 @@ async fn prune_f64_lt() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(1)) - .with_matched_by_bloom_filter(Some(0)) + .with_pruned_files(Some(0)) + .with_matched_by_bloom_filter(Some(3)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(11) .test_row_group_prune() @@ -652,7 +825,8 @@ async fn prune_f64_lt() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(1)) - .with_matched_by_bloom_filter(Some(0)) + .with_pruned_files(Some(0)) + .with_matched_by_bloom_filter(Some(3)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(11) .test_row_group_prune() @@ -669,7 +843,8 @@ async fn prune_f64_scalar_fun_and_gt() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(2)) - .with_matched_by_bloom_filter(Some(0)) + .with_pruned_files(Some(0)) + .with_matched_by_bloom_filter(Some(2)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(1) .test_row_group_prune() @@ -683,9 +858,10 @@ async fn prune_f64_scalar_fun() { 
.with_scenario(Scenario::Float64) .with_query("SELECT * FROM t where abs(f-1) <= 0.000001") .with_expected_errors(Some(0)) - .with_matched_by_stats(Some(0)) + .with_matched_by_stats(Some(4)) .with_pruned_by_stats(Some(0)) - .with_matched_by_bloom_filter(Some(0)) + .with_pruned_files(Some(0)) + .with_matched_by_bloom_filter(Some(4)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(1) .test_row_group_prune() @@ -699,9 +875,10 @@ async fn prune_f64_complex_expr() { .with_scenario(Scenario::Float64) .with_query("SELECT * FROM t where f+1 > 1.1") .with_expected_errors(Some(0)) - .with_matched_by_stats(Some(0)) + .with_matched_by_stats(Some(4)) .with_pruned_by_stats(Some(0)) - .with_matched_by_bloom_filter(Some(0)) + .with_pruned_files(Some(0)) + .with_matched_by_bloom_filter(Some(4)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(9) .test_row_group_prune() @@ -715,9 +892,10 @@ async fn prune_f64_complex_expr_subtract() { .with_scenario(Scenario::Float64) .with_query("SELECT * FROM t where 1-f > 1") .with_expected_errors(Some(0)) - .with_matched_by_stats(Some(0)) + .with_matched_by_stats(Some(4)) .with_pruned_by_stats(Some(0)) - .with_matched_by_bloom_filter(Some(0)) + .with_pruned_files(Some(0)) + .with_matched_by_bloom_filter(Some(4)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(9) .test_row_group_prune() @@ -735,7 +913,8 @@ async fn prune_decimal_lt() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) - .with_matched_by_bloom_filter(Some(0)) + .with_pruned_files(Some(0)) + .with_matched_by_bloom_filter(Some(2)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(6) .test_row_group_prune() @@ -746,7 +925,8 @@ async fn prune_decimal_lt() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) - .with_matched_by_bloom_filter(Some(0)) + .with_pruned_files(Some(0)) + .with_matched_by_bloom_filter(Some(2)) .with_pruned_by_bloom_filter(Some(0)) 
.with_expected_rows(8) .test_row_group_prune() @@ -757,7 +937,8 @@ async fn prune_decimal_lt() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) - .with_matched_by_bloom_filter(Some(0)) + .with_pruned_files(Some(0)) + .with_matched_by_bloom_filter(Some(2)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(6) .test_row_group_prune() @@ -768,7 +949,8 @@ async fn prune_decimal_lt() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) - .with_matched_by_bloom_filter(Some(0)) + .with_pruned_files(Some(0)) + .with_matched_by_bloom_filter(Some(2)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(8) .test_row_group_prune() @@ -786,6 +968,7 @@ async fn prune_decimal_eq() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(2)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(2) @@ -797,6 +980,7 @@ async fn prune_decimal_eq() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(2)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(2) @@ -809,6 +993,7 @@ async fn prune_decimal_eq() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(2)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(2) @@ -820,6 +1005,7 @@ async fn prune_decimal_eq() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(2)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(2) @@ -839,7 +1025,8 @@ async fn prune_decimal_in_list() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) - 
.with_matched_by_bloom_filter(Some(0)) + .with_pruned_files(Some(0)) + .with_matched_by_bloom_filter(Some(2)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(5) .test_row_group_prune() @@ -850,7 +1037,8 @@ async fn prune_decimal_in_list() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) - .with_matched_by_bloom_filter(Some(0)) + .with_pruned_files(Some(0)) + .with_matched_by_bloom_filter(Some(2)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(6) .test_row_group_prune() @@ -861,7 +1049,8 @@ async fn prune_decimal_in_list() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) - .with_matched_by_bloom_filter(Some(0)) + .with_pruned_files(Some(0)) + .with_matched_by_bloom_filter(Some(2)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(5) .test_row_group_prune() @@ -872,7 +1061,8 @@ async fn prune_decimal_in_list() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) - .with_matched_by_bloom_filter(Some(0)) + .with_pruned_files(Some(0)) + .with_matched_by_bloom_filter(Some(2)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(6) .test_row_group_prune() @@ -885,6 +1075,7 @@ async fn prune_decimal_in_list() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(2)) .with_expected_rows(1) @@ -898,6 +1089,7 @@ async fn prune_decimal_in_list() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(2)) .with_expected_rows(1) @@ -911,6 +1103,7 @@ async fn prune_decimal_in_list() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(0)) 
.with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(2)) .with_expected_rows(1) @@ -929,6 +1122,7 @@ async fn prune_string_eq_match() { // false positive on 'all backends' batch: 'backend five' < 'backend one' < 'backend three' .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(1)) .with_expected_rows(1) @@ -947,6 +1141,7 @@ async fn prune_string_eq_no_match() { // false positive on 'all backends' batch: 'backend five' < 'backend one' < 'backend three' .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(2)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(1)) .with_expected_rows(0) @@ -963,6 +1158,7 @@ async fn prune_string_eq_no_match() { // false positive on 'mixed' batch: 'backend one' < 'frontend nine' < 'frontend six' .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(2)) .with_expected_rows(0) @@ -980,6 +1176,7 @@ async fn prune_string_neq() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(3)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(14) @@ -998,7 +1195,8 @@ async fn prune_string_lt() { // matches 'all backends' only .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(2)) - .with_matched_by_bloom_filter(Some(0)) + .with_pruned_files(Some(0)) + .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(3) .test_row_group_prune() @@ -1012,7 +1210,8 @@ async fn prune_string_lt() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) - .with_matched_by_bloom_filter(Some(0)) + .with_pruned_files(Some(0)) + .with_matched_by_bloom_filter(Some(2)) 
.with_pruned_by_bloom_filter(Some(0)) // all backends from 'mixed' and 'all backends' .with_expected_rows(8) @@ -1031,6 +1230,7 @@ async fn prune_binary_eq_match() { // false positive on 'all backends' batch: 'backend five' < 'backend one' < 'backend three' .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(1)) .with_expected_rows(1) @@ -1049,6 +1249,7 @@ async fn prune_binary_eq_no_match() { // false positive on 'all backends' batch: 'backend five' < 'backend one' < 'backend three' .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(2)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(1)) .with_expected_rows(0) @@ -1065,6 +1266,7 @@ async fn prune_binary_eq_no_match() { // false positive on 'mixed' batch: 'backend one' < 'frontend nine' < 'frontend six' .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(2)) .with_expected_rows(0) @@ -1082,6 +1284,7 @@ async fn prune_binary_neq() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(3)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(14) @@ -1100,7 +1303,8 @@ async fn prune_binary_lt() { // matches 'all backends' only .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(2)) - .with_matched_by_bloom_filter(Some(0)) + .with_pruned_files(Some(0)) + .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(3) .test_row_group_prune() @@ -1114,7 +1318,8 @@ async fn prune_binary_lt() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) - .with_matched_by_bloom_filter(Some(0)) + .with_pruned_files(Some(0)) + 
.with_matched_by_bloom_filter(Some(2)) .with_pruned_by_bloom_filter(Some(0)) // all backends from 'mixed' and 'all backends' .with_expected_rows(8) @@ -1133,6 +1338,7 @@ async fn prune_fixedsizebinary_eq_match() { // false positive on 'all frontends' batch: 'fe1' < 'fe6' < 'fe7' .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(1)) .with_expected_rows(1) @@ -1148,6 +1354,7 @@ async fn prune_fixedsizebinary_eq_match() { // false positive on 'all frontends' batch: 'fe1' < 'fe6' < 'fe7' .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(1)) .with_expected_rows(1) @@ -1166,6 +1373,7 @@ async fn prune_fixedsizebinary_eq_no_match() { // false positive on 'mixed' batch: 'be1' < 'be9' < 'fe4' .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(2)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(1)) .with_expected_rows(0) @@ -1183,6 +1391,7 @@ async fn prune_fixedsizebinary_neq() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(3)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(14) @@ -1201,7 +1410,8 @@ async fn prune_fixedsizebinary_lt() { // matches 'all backends' only .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(2)) - .with_matched_by_bloom_filter(Some(0)) + .with_pruned_files(Some(0)) + .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(2) .test_row_group_prune() @@ -1215,7 +1425,8 @@ async fn prune_fixedsizebinary_lt() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) - .with_matched_by_bloom_filter(Some(0)) + .with_pruned_files(Some(0)) + 
.with_matched_by_bloom_filter(Some(2)) .with_pruned_by_bloom_filter(Some(0)) // all backends from 'mixed' and 'all backends' .with_expected_rows(8) @@ -1235,6 +1446,7 @@ async fn prune_periods_in_column_names() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(2)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(7) @@ -1246,6 +1458,7 @@ async fn prune_periods_in_column_names() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(2)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(5) @@ -1257,6 +1470,7 @@ async fn prune_periods_in_column_names() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(2)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(2) @@ -1277,9 +1491,10 @@ async fn test_row_group_with_null_values() { .with_query("SELECT * FROM t WHERE \"i8\" <= 5") .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_pruned_by_stats(Some(2)) .with_expected_rows(5) - .with_matched_by_bloom_filter(Some(0)) + .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(0)) .test_row_group_prune() .await; @@ -1290,9 +1505,10 @@ async fn test_row_group_with_null_values() { .with_query("SELECT * FROM t WHERE \"i8\" is Null") .with_expected_errors(Some(0)) .with_matched_by_stats(Some(2)) + .with_pruned_files(Some(0)) .with_pruned_by_stats(Some(1)) .with_expected_rows(10) - .with_matched_by_bloom_filter(Some(0)) + .with_matched_by_bloom_filter(Some(2)) .with_pruned_by_bloom_filter(Some(0)) .test_row_group_prune() .await; @@ -1303,9 +1519,10 @@ async fn test_row_group_with_null_values() { .with_query("SELECT * FROM t WHERE \"i16\" is Not Null") 
.with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_pruned_by_stats(Some(2)) .with_expected_rows(5) - .with_matched_by_bloom_filter(Some(0)) + .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(0)) .test_row_group_prune() .await; @@ -1316,7 +1533,8 @@ async fn test_row_group_with_null_values() { .with_query("SELECT * FROM t WHERE \"i32\" > 7") .with_expected_errors(Some(0)) .with_matched_by_stats(Some(0)) - .with_pruned_by_stats(Some(3)) + .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(1)) .with_expected_rows(0) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) @@ -1332,6 +1550,7 @@ async fn test_bloom_filter_utf8_dict() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_expected_rows(1) .with_pruned_by_bloom_filter(Some(0)) .with_matched_by_bloom_filter(Some(1)) @@ -1344,6 +1563,7 @@ async fn test_bloom_filter_utf8_dict() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_expected_rows(0) .with_pruned_by_bloom_filter(Some(1)) .with_matched_by_bloom_filter(Some(0)) @@ -1356,6 +1576,7 @@ async fn test_bloom_filter_utf8_dict() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_expected_rows(1) .with_pruned_by_bloom_filter(Some(0)) .with_matched_by_bloom_filter(Some(1)) @@ -1368,6 +1589,7 @@ async fn test_bloom_filter_utf8_dict() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_expected_rows(0) .with_pruned_by_bloom_filter(Some(1)) .with_matched_by_bloom_filter(Some(0)) @@ -1383,6 +1605,7 @@ async fn test_bloom_filter_integer_dict() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(1)) + 
.with_pruned_files(Some(0)) .with_expected_rows(1) .with_pruned_by_bloom_filter(Some(0)) .with_matched_by_bloom_filter(Some(1)) @@ -1395,6 +1618,7 @@ async fn test_bloom_filter_integer_dict() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_expected_rows(0) .with_pruned_by_bloom_filter(Some(1)) .with_matched_by_bloom_filter(Some(0)) @@ -1407,6 +1631,7 @@ async fn test_bloom_filter_integer_dict() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_expected_rows(1) .with_pruned_by_bloom_filter(Some(0)) .with_matched_by_bloom_filter(Some(1)) @@ -1419,6 +1644,7 @@ async fn test_bloom_filter_integer_dict() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_expected_rows(0) .with_pruned_by_bloom_filter(Some(1)) .with_matched_by_bloom_filter(Some(0)) @@ -1434,6 +1660,7 @@ async fn test_bloom_filter_unsigned_integer_dict() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_expected_rows(1) .with_pruned_by_bloom_filter(Some(0)) .with_matched_by_bloom_filter(Some(1)) @@ -1446,6 +1673,7 @@ async fn test_bloom_filter_unsigned_integer_dict() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_expected_rows(0) .with_pruned_by_bloom_filter(Some(1)) .with_matched_by_bloom_filter(Some(0)) @@ -1461,6 +1689,7 @@ async fn test_bloom_filter_binary_dict() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_expected_rows(1) .with_pruned_by_bloom_filter(Some(0)) .with_matched_by_bloom_filter(Some(1)) @@ -1473,6 +1702,7 @@ async fn test_bloom_filter_binary_dict() { .with_expected_errors(Some(0)) 
.with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_expected_rows(0) .with_pruned_by_bloom_filter(Some(1)) .with_matched_by_bloom_filter(Some(0)) @@ -1485,6 +1715,7 @@ async fn test_bloom_filter_binary_dict() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_expected_rows(1) .with_pruned_by_bloom_filter(Some(0)) .with_matched_by_bloom_filter(Some(1)) @@ -1499,6 +1730,7 @@ async fn test_bloom_filter_binary_dict() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_expected_rows(0) .with_pruned_by_bloom_filter(Some(1)) .with_matched_by_bloom_filter(Some(0)) @@ -1514,6 +1746,7 @@ async fn test_bloom_filter_decimal_dict() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_expected_rows(1) .with_pruned_by_bloom_filter(Some(0)) .with_matched_by_bloom_filter(Some(1)) @@ -1526,9 +1759,247 @@ async fn test_bloom_filter_decimal_dict() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_expected_rows(0) .with_pruned_by_bloom_filter(Some(1)) .with_matched_by_bloom_filter(Some(0)) .test_row_group_prune() .await; } + +// Helper function to create a batch with a single Int32 column. 
+fn make_i32_batch( + name: &str, + values: Vec, +) -> datafusion_common::error::Result { + let schema = Arc::new(Schema::new(vec![Field::new(name, DataType::Int32, false)])); + let array: ArrayRef = Arc::new(Int32Array::from(values)); + RecordBatch::try_new(schema, vec![array]).map_err(DataFusionError::from) +} + +// Helper function to create a batch with two Int32 columns +fn make_two_col_i32_batch( + name_a: &str, + name_b: &str, + values_a: Vec, + values_b: Vec, +) -> datafusion_common::error::Result { + let schema = Arc::new(Schema::new(vec![ + Field::new(name_a, DataType::Int32, false), + Field::new(name_b, DataType::Int32, false), + ])); + let array_a: ArrayRef = Arc::new(Int32Array::from(values_a)); + let array_b: ArrayRef = Arc::new(Int32Array::from(values_b)); + RecordBatch::try_new(schema, vec![array_a, array_b]).map_err(DataFusionError::from) +} + +#[tokio::test] +async fn test_limit_pruning_basic() -> datafusion_common::error::Result<()> { + // Scenario: Simple integer column, multiple row groups + // Query: SELECT c1 FROM t WHERE c1 >= 0 LIMIT 2 + // We expect 2 rows in total. + + // Row Group 0: c1 = [0, -2] -> Partially matched, 1 row + // Row Group 1: c1 = [0, 0] -> Fully matched, 2 rows + // Row Group 2: c1 = [0, 0] -> Fully matched, 2 rows + // Row Group 3: c1 = [0, 0] -> Fully matched, 2 rows + // Row Group 4: c1 = [-1, -2] -> Not matched + + // If limit = 2, and RG1 is fully matched and has 2 rows, we should + // only scan RG1 and prune the other matched row groups. + // RG4 is pruned by statistics. RG0, RG2 and RG3 are pruned by limit. + // So 3 row groups are effectively pruned due to limit pruning.
+ + let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, false)])); + let query = "SELECT c1 FROM t WHERE c1 >= 0 LIMIT 2"; + + let batches = vec![ + make_i32_batch("c1", vec![0, -2])?, + make_i32_batch("c1", vec![0, 0])?, + make_i32_batch("c1", vec![0, 0])?, + make_i32_batch("c1", vec![0, 0])?, + make_i32_batch("c1", vec![-1, -2])?, + ]; + + RowGroupPruningTest::new() + .with_scenario(Scenario::Int) // Assuming Scenario::Int can handle this data + .with_query(query) + .with_expected_errors(Some(0)) + .with_expected_rows(2) + .with_pruned_files(Some(0)) + .with_matched_by_stats(Some(4)) + .with_fully_matched_by_stats(Some(3)) + .with_pruned_by_stats(Some(1)) + .with_limit_pruned_row_groups(Some(3)) + .test_row_group_prune_with_custom_data(schema, batches, 2) + .await; + + Ok(()) +} + +#[tokio::test] +async fn test_limit_pruning_complex_filter() -> datafusion_common::error::Result<()> { + // Test Case 1: Complex filter with two columns (a = 1 AND b > 1 AND b < 4) + // Row Group 0: a=[1,1,1], b=[0,2,3] -> Partially matched, 2 rows match (b=2,3) + // Row Group 1: a=[1,1,1], b=[2,2,2] -> Fully matched, 3 rows + // Row Group 2: a=[1,1,1], b=[2,3,3] -> Fully matched, 3 rows + // Row Group 3: a=[1,1,1], b=[2,2,3] -> Fully matched, 3 rows + // Row Group 4: a=[2,2,2], b=[2,2,2] -> Not matched (a != 1) + // Row Group 5: a=[1,1,1], b=[5,6,7] -> Not matched (b >= 4) + + // With LIMIT 5, we need RG1 (3 rows) + RG2 (2 rows from 3) = 5 rows + // RG4 and RG5 should be pruned by statistics + // RG3 should be pruned by limit + // RG0 is partially matched, so it depends on the order + + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ])); + let query = "SELECT a, b FROM t WHERE a = 1 AND b > 1 AND b < 4 LIMIT 5"; + + let batches = vec![ + make_two_col_i32_batch("a", "b", vec![1, 1, 1], vec![0, 2, 3])?, + make_two_col_i32_batch("a", "b", vec![1, 1, 1], vec![2, 2, 2])?, + 
make_two_col_i32_batch("a", "b", vec![1, 1, 1], vec![2, 3, 3])?, + make_two_col_i32_batch("a", "b", vec![1, 1, 1], vec![2, 2, 3])?, + make_two_col_i32_batch("a", "b", vec![2, 2, 2], vec![2, 2, 2])?, + make_two_col_i32_batch("a", "b", vec![1, 1, 1], vec![5, 6, 7])?, + ]; + + RowGroupPruningTest::new() + .with_scenario(Scenario::Int) + .with_query(query) + .with_expected_errors(Some(0)) + .with_expected_rows(5) + .with_pruned_files(Some(0)) + .with_matched_by_stats(Some(4)) // RG0,1,2,3 are matched + .with_fully_matched_by_stats(Some(3)) + .with_pruned_by_stats(Some(2)) // RG4,5 are pruned + .with_limit_pruned_row_groups(Some(2)) // RG0, RG3 is pruned by limit + .test_row_group_prune_with_custom_data(schema, batches, 3) + .await; + + Ok(()) +} + +#[tokio::test] +async fn test_limit_pruning_multiple_fully_matched() +-> datafusion_common::error::Result<()> { + // Test Case 2: Limit requires multiple fully matched row groups + // Row Group 0: a=[5,5,5,5] -> Fully matched, 4 rows + // Row Group 1: a=[5,5,5,5] -> Fully matched, 4 rows + // Row Group 2: a=[5,5,5,5] -> Fully matched, 4 rows + // Row Group 3: a=[5,5,5,5] -> Fully matched, 4 rows + // Row Group 4: a=[1,2,3,4] -> Not matched + + // With LIMIT 8, we need RG0 (4 rows) + RG1 (4 rows) 8 rows + // RG2,3 should be pruned by limit + // RG4 should be pruned by statistics + + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let query = "SELECT a FROM t WHERE a = 5 LIMIT 8"; + + let batches = vec![ + make_i32_batch("a", vec![5, 5, 5, 5])?, + make_i32_batch("a", vec![5, 5, 5, 5])?, + make_i32_batch("a", vec![5, 5, 5, 5])?, + make_i32_batch("a", vec![5, 5, 5, 5])?, + make_i32_batch("a", vec![1, 2, 3, 4])?, + ]; + + RowGroupPruningTest::new() + .with_scenario(Scenario::Int) + .with_query(query) + .with_expected_errors(Some(0)) + .with_expected_rows(8) + .with_pruned_files(Some(0)) + .with_matched_by_stats(Some(4)) // RG0,1,2,3 matched + .with_fully_matched_by_stats(Some(4)) + 
.with_pruned_by_stats(Some(1)) // RG4 pruned + .with_limit_pruned_row_groups(Some(2)) // RG2,3 pruned by limit + .test_row_group_prune_with_custom_data(schema, batches, 4) + .await; + + Ok(()) +} + +#[tokio::test] +async fn test_limit_pruning_no_fully_matched() -> datafusion_common::error::Result<()> { + // Test Case 3: No fully matched row groups - all are partially matched + // Row Group 0: a=[1,2,3] -> Partially matched, 1 row (a=2) + // Row Group 1: a=[2,3,4] -> Partially matched, 1 row (a=2) + // Row Group 2: a=[2,5,6] -> Partially matched, 1 row (a=2) + // Row Group 3: a=[2,7,8] -> Partially matched, 1 row (a=2) + // Row Group 4: a=[9,10,11] -> Not matched + + // With LIMIT 3, we need to scan RG0,1,2 to get 3 matching rows + // Cannot prune much by limit since all matching RGs are partial + // RG4 should be pruned by statistics + + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let query = "SELECT a FROM t WHERE a = 2 LIMIT 3"; + + let batches = vec![ + make_i32_batch("a", vec![1, 2, 3])?, + make_i32_batch("a", vec![2, 3, 4])?, + make_i32_batch("a", vec![2, 5, 6])?, + make_i32_batch("a", vec![2, 7, 8])?, + make_i32_batch("a", vec![9, 10, 11])?, + ]; + + RowGroupPruningTest::new() + .with_scenario(Scenario::Int) + .with_query(query) + .with_expected_errors(Some(0)) + .with_expected_rows(3) + .with_pruned_files(Some(0)) + .with_matched_by_stats(Some(4)) // RG0,1,2,3 matched + .with_fully_matched_by_stats(Some(0)) + .with_pruned_by_stats(Some(1)) // RG4 pruned + .with_limit_pruned_row_groups(Some(0)) // RG3 pruned by limit + .test_row_group_prune_with_custom_data(schema, batches, 3) + .await; + + Ok(()) +} + +#[tokio::test] +async fn test_limit_pruning_exceeds_fully_matched() -> datafusion_common::error::Result<()> +{ + // Test Case 4: Limit exceeds all fully matched rows, need partially matched + // Row Group 0: a=[10,11,12,12] -> Partially matched, 1 row (a=10) + // Row Group 1: a=[10,10,10,10] -> Fully matched, 4 rows + 
// Row Group 2: a=[10,10,10,10] -> Fully matched, 4 rows + // Row Group 3: a=[10,13,14,11] -> Partially matched, 1 row (a=10) + // Row Group 4: a=[20,21,22,22] -> Not matched + + // With LIMIT 10, we need RG1 (4) + RG2 (4) = 8 from fully matched + // Still need 2 more, so we need to scan partially matched RG0 and RG3 + // All matching row groups should be scanned, only RG4 pruned by statistics + + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let query = "SELECT a FROM t WHERE a = 10 LIMIT 10"; + + let batches = vec![ + make_i32_batch("a", vec![10, 11, 12, 12])?, + make_i32_batch("a", vec![10, 10, 10, 10])?, + make_i32_batch("a", vec![10, 10, 10, 10])?, + make_i32_batch("a", vec![10, 13, 14, 11])?, + make_i32_batch("a", vec![20, 21, 22, 22])?, + ]; + + RowGroupPruningTest::new() + .with_scenario(Scenario::Int) + .with_query(query) + .with_expected_errors(Some(0)) + .with_expected_rows(10) // Total: 1 + 4 + 4 + 1 = 10 + .with_pruned_files(Some(0)) + .with_matched_by_stats(Some(4)) // RG0,1,2,3 matched + .with_fully_matched_by_stats(Some(2)) + .with_pruned_by_stats(Some(1)) // RG4 pruned + .with_limit_pruned_row_groups(Some(0)) // No limit pruning since we need all RGs + .test_row_group_prune_with_custom_data(schema, batches, 4) + .await; + Ok(()) +} diff --git a/datafusion/core/tests/parquet/schema_coercion.rs b/datafusion/core/tests/parquet/schema_coercion.rs index 59cbf4b0872ea..6f7e2e328d0c3 100644 --- a/datafusion/core/tests/parquet/schema_coercion.rs +++ b/datafusion/core/tests/parquet/schema_coercion.rs @@ -18,16 +18,16 @@ use std::sync::Arc; use arrow::array::{ - types::Int32Type, ArrayRef, DictionaryArray, Float32Array, Int64Array, RecordBatch, - StringArray, + ArrayRef, DictionaryArray, Float32Array, Int64Array, RecordBatch, StringArray, + types::Int32Type, }; use arrow::datatypes::{DataType, Field, Schema}; use datafusion::datasource::physical_plan::ParquetSource; use datafusion::physical_plan::collect; use 
datafusion::prelude::SessionContext; use datafusion::test::object_store::local_unpartitioned_file; -use datafusion_common::test_util::batches_to_sort_string; use datafusion_common::Result; +use datafusion_common::test_util::batches_to_sort_string; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_datasource::file_scan_config::FileScanConfigBuilder; @@ -62,14 +62,10 @@ async fn multi_parquet_coercion() { Field::new("c2", DataType::Int32, true), Field::new("c3", DataType::Float64, true), ])); - let source = Arc::new(ParquetSource::default()); - let conf = FileScanConfigBuilder::new( - ObjectStoreUrl::local_filesystem(), - file_schema, - source, - ) - .with_file_group(file_group) - .build(); + let source = Arc::new(ParquetSource::new(file_schema.clone())); + let conf = FileScanConfigBuilder::new(ObjectStoreUrl::local_filesystem(), source) + .with_file_group(file_group) + .build(); let parquet_exec = DataSourceExec::from_data_source(conf); @@ -122,11 +118,11 @@ async fn multi_parquet_coercion_projection() { ])); let config = FileScanConfigBuilder::new( ObjectStoreUrl::local_filesystem(), - file_schema, - Arc::new(ParquetSource::default()), + Arc::new(ParquetSource::new(file_schema)), ) .with_file_group(file_group) - .with_projection(Some(vec![1, 0, 2])) + .with_projection_indices(Some(vec![1, 0, 2])) + .unwrap() .build(); let parquet_exec = DataSourceExec::from_data_source(config); diff --git a/datafusion/core/tests/parquet/utils.rs b/datafusion/core/tests/parquet/utils.rs index 24b6cadc148f8..e5e0026ec1f16 100644 --- a/datafusion/core/tests/parquet/utils.rs +++ b/datafusion/core/tests/parquet/utils.rs @@ -20,7 +20,7 @@ use datafusion::datasource::physical_plan::ParquetSource; use datafusion::datasource::source::DataSourceExec; use datafusion_physical_plan::metrics::MetricsSet; -use datafusion_physical_plan::{accept, ExecutionPlan, ExecutionPlanVisitor}; +use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanVisitor, accept}; /// Find the 
metrics from the first DataSourceExec encountered in the plan #[derive(Debug)] @@ -47,13 +47,12 @@ impl MetricsFinder { impl ExecutionPlanVisitor for MetricsFinder { type Error = std::convert::Infallible; fn pre_visit(&mut self, plan: &dyn ExecutionPlan) -> Result { - if let Some(data_source_exec) = plan.as_any().downcast_ref::() { - if data_source_exec + if let Some(data_source_exec) = plan.as_any().downcast_ref::() + && data_source_exec .downcast_to_file_source::() .is_some() - { - self.metrics = data_source_exec.metrics(); - } + { + self.metrics = data_source_exec.metrics(); } // stop searching once we have found the metrics Ok(self.metrics.is_none()) diff --git a/datafusion/core/tests/parquet_config.rs b/datafusion/core/tests/parquet_integration.rs similarity index 100% rename from datafusion/core/tests/parquet_config.rs rename to datafusion/core/tests/parquet_integration.rs diff --git a/datafusion/core/tests/physical_optimizer/aggregate_statistics.rs b/datafusion/core/tests/physical_optimizer/aggregate_statistics.rs index a79d743cb253d..850f9d187780b 100644 --- a/datafusion/core/tests/physical_optimizer/aggregate_statistics.rs +++ b/datafusion/core/tests/physical_optimizer/aggregate_statistics.rs @@ -20,26 +20,38 @@ use std::sync::Arc; use crate::physical_optimizer::test_utils::TestAggregate; use arrow::array::Int32Array; +use arrow::array::{Int64Array, StringArray}; use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; +use datafusion::datasource::listing::PartitionedFile; +use datafusion::datasource::memory::MemTable; use datafusion::datasource::memory::MemorySourceConfig; +use datafusion::datasource::physical_plan::ParquetSource; use datafusion::datasource::source::DataSourceExec; +use datafusion::prelude::{SessionConfig, SessionContext}; +use datafusion_common::assert_batches_eq; use datafusion_common::cast::as_int64_array; use datafusion_common::config::ConfigOptions; -use datafusion_common::Result; +use 
datafusion_common::stats::Precision; +use datafusion_common::{ColumnStatistics, Result, Statistics}; +use datafusion_datasource::file_scan_config::FileScanConfigBuilder; use datafusion_execution::TaskContext; +use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_expr::Operator; +use datafusion_functions_aggregate::count::count_udaf; +use datafusion_physical_expr::aggregate::AggregateExprBuilder; use datafusion_physical_expr::expressions::{self, cast}; -use datafusion_physical_optimizer::aggregate_statistics::AggregateStatistics; use datafusion_physical_optimizer::PhysicalOptimizerRule; +use datafusion_physical_optimizer::aggregate_statistics::AggregateStatistics; +use datafusion_physical_plan::ExecutionPlan; use datafusion_physical_plan::aggregates::AggregateExec; use datafusion_physical_plan::aggregates::AggregateMode; use datafusion_physical_plan::aggregates::PhysicalGroupBy; use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion_physical_plan::common; +use datafusion_physical_plan::displayable; use datafusion_physical_plan::filter::FilterExec; use datafusion_physical_plan::projection::ProjectionExec; -use datafusion_physical_plan::ExecutionPlan; /// Mock data using a MemorySourceConfig which has an exact count statistic fn mock_data() -> Result> { @@ -316,3 +328,228 @@ async fn test_count_with_nulls_inexact_stat() -> Result<()> { Ok(()) } + +/// Tests that TopK aggregation correctly handles UTF-8 (string) types in both grouping keys and aggregate values. +/// +/// The TopK optimization is designed to efficiently handle `GROUP BY ... ORDER BY aggregate LIMIT n` queries +/// by maintaining only the top K groups during aggregation. However, not all type combinations are supported. +/// +/// This test verifies two scenarios: +/// 1. **Supported case**: UTF-8 grouping key with numeric aggregate (max/min) - should use TopK optimization +/// 2. 
**Unsupported case**: UTF-8 grouping key with UTF-8 aggregate value - must gracefully fall back to +/// standard aggregation without panicking +/// +/// The fallback behavior is critical because attempting to use TopK with unsupported types could cause +/// runtime panics. This test ensures the optimizer correctly detects incompatible types and chooses +/// the appropriate execution path. +#[tokio::test] +async fn utf8_grouping_min_max_limit_fallbacks() -> Result<()> { + let mut config = SessionConfig::new(); + config.options_mut().optimizer.enable_topk_aggregation = true; + let ctx = SessionContext::new_with_config(config); + + let batch = RecordBatch::try_new( + Arc::new(Schema::new(vec![ + Field::new("g", DataType::Utf8, false), + Field::new("val_str", DataType::Utf8, false), + Field::new("val_num", DataType::Int64, false), + ])), + vec![ + Arc::new(StringArray::from(vec!["a", "b", "a"])), + Arc::new(StringArray::from(vec!["alpha", "bravo", "charlie"])), + Arc::new(Int64Array::from(vec![1, 2, 3])), + ], + )?; + let table = MemTable::try_new(batch.schema(), vec![vec![batch]])?; + ctx.register_table("t", Arc::new(table))?; + + // Supported path: numeric min/max with UTF-8 grouping should still use TopK aggregation + // and return correct results. + let supported_df = ctx + .sql("SELECT g, max(val_num) AS m FROM t GROUP BY g ORDER BY m DESC LIMIT 1") + .await?; + let supported_batches = supported_df.collect().await?; + assert_batches_eq!( + &[ + "+---+---+", + "| g | m |", + "+---+---+", + "| a | 3 |", + "+---+---+" + ], + &supported_batches + ); + + // Unsupported TopK value type: string min/max should fall back without panicking. + let unsupported_df = ctx + .sql("SELECT g, max(val_str) AS s FROM t GROUP BY g ORDER BY s DESC LIMIT 1") + .await?; + let unsupported_plan = unsupported_df.clone().create_physical_plan().await?; + let unsupported_batches = unsupported_df.collect().await?; + + // Ensure the plan avoided the TopK-specific stream implementation. 
+ let plan_display = displayable(unsupported_plan.as_ref()) + .indent(true) + .to_string(); + assert!( + !plan_display.contains("GroupedTopKAggregateStream"), + "Unsupported UTF-8 aggregate value should not use TopK: {plan_display}" + ); + + assert_batches_eq!( + &[ + "+---+---------+", + "| g | s |", + "+---+---------+", + "| a | charlie |", + "+---+---------+" + ], + &unsupported_batches + ); + + Ok(()) +} + +#[tokio::test] +async fn test_count_distinct_optimization() -> Result<()> { + struct TestCase { + name: &'static str, + distinct_count: Precision, + use_column_expr: bool, + expect_optimized: bool, + expected_value: Option, + } + + let cases = vec![ + TestCase { + name: "exact statistics", + distinct_count: Precision::Exact(42), + use_column_expr: true, + expect_optimized: true, + expected_value: Some(42), + }, + TestCase { + name: "absent statistics", + distinct_count: Precision::Absent, + use_column_expr: true, + expect_optimized: false, + expected_value: None, + }, + TestCase { + name: "inexact statistics", + distinct_count: Precision::Inexact(42), + use_column_expr: true, + expect_optimized: false, + expected_value: None, + }, + TestCase { + name: "non-column expression with exact statistics", + distinct_count: Precision::Exact(42), + use_column_expr: false, + expect_optimized: false, + expected_value: None, + }, + ]; + + for case in cases { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ])); + + let statistics = Statistics { + num_rows: Precision::Exact(100), + total_byte_size: Precision::Absent, + column_statistics: vec![ + ColumnStatistics { + distinct_count: case.distinct_count, + null_count: Precision::Exact(10), + ..Default::default() + }, + ColumnStatistics::default(), + ], + }; + + let config = FileScanConfigBuilder::new( + ObjectStoreUrl::parse("test:///").unwrap(), + Arc::new(ParquetSource::new(Arc::clone(&schema))), + ) + 
.with_file(PartitionedFile::new("x".to_string(), 100)) + .with_statistics(statistics) + .build(); + + let source: Arc = DataSourceExec::from_data_source(config); + let schema = source.schema(); + + let (agg_args, alias): (Vec>, _) = + if case.use_column_expr { + (vec![expressions::col("a", &schema)?], "COUNT(DISTINCT a)") + } else { + ( + vec![expressions::binary( + expressions::col("a", &schema)?, + Operator::Plus, + expressions::col("b", &schema)?, + &schema, + )?], + "COUNT(DISTINCT a + b)", + ) + }; + + let count_distinct_expr = AggregateExprBuilder::new(count_udaf(), agg_args) + .schema(Arc::clone(&schema)) + .alias(alias) + .distinct() + .build()?; + + let partial_agg = AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::default(), + vec![Arc::new(count_distinct_expr.clone())], + vec![None], + source, + Arc::clone(&schema), + )?; + + let final_agg = AggregateExec::try_new( + AggregateMode::Final, + PhysicalGroupBy::default(), + vec![Arc::new(count_distinct_expr)], + vec![None], + Arc::new(partial_agg), + Arc::clone(&schema), + )?; + + let conf = ConfigOptions::new(); + let optimized = + AggregateStatistics::new().optimize(Arc::new(final_agg), &conf)?; + + if case.expect_optimized { + assert!( + optimized.as_any().is::(), + "'{}': expected ProjectionExec", + case.name + ); + + if let Some(expected_val) = case.expected_value { + let task_ctx = Arc::new(TaskContext::default()); + let result = common::collect(optimized.execute(0, task_ctx)?).await?; + assert_eq!(result.len(), 1, "'{}': expected 1 batch", case.name); + assert_eq!( + as_int64_array(result[0].column(0)).unwrap().values(), + &[expected_val], + "'{}': unexpected value", + case.name + ); + } + } else { + assert!( + optimized.as_any().is::(), + "'{}': expected AggregateExec (not optimized)", + case.name + ); + } + } + + Ok(()) +} diff --git a/datafusion/core/tests/physical_optimizer/combine_partial_final_agg.rs b/datafusion/core/tests/physical_optimizer/combine_partial_final_agg.rs 
index 568be0d18f245..9e63c341c92d9 100644 --- a/datafusion/core/tests/physical_optimizer/combine_partial_final_agg.rs +++ b/datafusion/core/tests/physical_optimizer/combine_partial_final_agg.rs @@ -20,45 +20,40 @@ //! Note these tests are not in the same module as the optimizer pass because //! they rely on `DataSourceExec` which is in the core crate. +use insta::assert_snapshot; use std::sync::Arc; -use crate::physical_optimizer::test_utils::{parquet_exec, trim_plan_display}; +use crate::physical_optimizer::test_utils::parquet_exec; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::config::ConfigOptions; use datafusion_functions_aggregate::count::count_udaf; use datafusion_functions_aggregate::sum::sum_udaf; +use datafusion_physical_expr::Partitioning; use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr}; use datafusion_physical_expr::expressions::{col, lit}; -use datafusion_physical_expr::Partitioning; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; -use datafusion_physical_optimizer::combine_partial_final_agg::CombinePartialFinalAggregate; use datafusion_physical_optimizer::PhysicalOptimizerRule; +use datafusion_physical_optimizer::combine_partial_final_agg::CombinePartialFinalAggregate; +use datafusion_physical_plan::ExecutionPlan; use datafusion_physical_plan::aggregates::{ - AggregateExec, AggregateMode, PhysicalGroupBy, + AggregateExec, AggregateMode, LimitOptions, PhysicalGroupBy, }; use datafusion_physical_plan::displayable; use datafusion_physical_plan::repartition::RepartitionExec; -use datafusion_physical_plan::ExecutionPlan; /// Runs the CombinePartialFinalAggregate optimizer and asserts the plan against the expected macro_rules! assert_optimized { - ($EXPECTED_LINES: expr, $PLAN: expr) => { - let expected_lines: Vec<&str> = $EXPECTED_LINES.iter().map(|s| *s).collect(); - + ($PLAN: expr, @ $EXPECTED_LINES: literal $(,)?) 
=> { // run optimizer let optimizer = CombinePartialFinalAggregate {}; let config = ConfigOptions::new(); let optimized = optimizer.optimize($PLAN, &config)?; // Now format correctly let plan = displayable(optimized.as_ref()).indent(true).to_string(); - let actual_lines = trim_plan_display(&plan); + let actual_lines = plan.trim(); - assert_eq!( - &expected_lines, &actual_lines, - "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", - expected_lines, actual_lines - ); + assert_snapshot!(actual_lines, @ $EXPECTED_LINES); }; } @@ -136,7 +131,7 @@ fn aggregations_not_combined() -> datafusion_common::Result<()> { let plan = final_aggregate_exec( repartition_exec(partial_aggregate_exec( - parquet_exec(&schema), + parquet_exec(schema.clone()), PhysicalGroupBy::default(), aggr_expr.clone(), )), @@ -144,20 +139,22 @@ fn aggregations_not_combined() -> datafusion_common::Result<()> { aggr_expr, ); // should not combine the Partial/Final AggregateExecs - let expected = &[ - "AggregateExec: mode=Final, gby=[], aggr=[COUNT(1)]", - "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - "AggregateExec: mode=Partial, gby=[], aggr=[COUNT(1)]", - "DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c], file_type=parquet", - ]; - assert_optimized!(expected, plan); + assert_optimized!( + plan, + @ r" + AggregateExec: mode=Final, gby=[], aggr=[COUNT(1)] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + AggregateExec: mode=Partial, gby=[], aggr=[COUNT(1)] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c], file_type=parquet + " + ); let aggr_expr1 = vec![count_expr(lit(1i8), "COUNT(1)", &schema)]; let aggr_expr2 = vec![count_expr(lit(1i8), "COUNT(2)", &schema)]; let plan = final_aggregate_exec( partial_aggregate_exec( - parquet_exec(&schema), + parquet_exec(schema), PhysicalGroupBy::default(), aggr_expr1, ), @@ -165,13 +162,14 @@ fn aggregations_not_combined() -> datafusion_common::Result<()> { aggr_expr2, ); // should 
not combine the Partial/Final AggregateExecs - let expected = &[ - "AggregateExec: mode=Final, gby=[], aggr=[COUNT(2)]", - "AggregateExec: mode=Partial, gby=[], aggr=[COUNT(1)]", - "DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c], file_type=parquet", - ]; - - assert_optimized!(expected, plan); + assert_optimized!( + plan, + @ r" + AggregateExec: mode=Final, gby=[], aggr=[COUNT(2)] + AggregateExec: mode=Partial, gby=[], aggr=[COUNT(1)] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c], file_type=parquet + " + ); Ok(()) } @@ -183,7 +181,7 @@ fn aggregations_combined() -> datafusion_common::Result<()> { let plan = final_aggregate_exec( partial_aggregate_exec( - parquet_exec(&schema), + parquet_exec(schema), PhysicalGroupBy::default(), aggr_expr.clone(), ), @@ -191,12 +189,13 @@ fn aggregations_combined() -> datafusion_common::Result<()> { aggr_expr, ); // should combine the Partial/Final AggregateExecs to the Single AggregateExec - let expected = &[ - "AggregateExec: mode=Single, gby=[], aggr=[COUNT(1)]", - "DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c], file_type=parquet", - ]; - - assert_optimized!(expected, plan); + assert_optimized!( + plan, + @ r" + AggregateExec: mode=Single, gby=[], aggr=[COUNT(1)] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c], file_type=parquet + " + ); Ok(()) } @@ -215,11 +214,8 @@ fn aggregations_with_group_combined() -> datafusion_common::Result<()> { vec![(col("c", &schema)?, "c".to_string())]; let partial_group_by = PhysicalGroupBy::new_single(groups); - let partial_agg = partial_aggregate_exec( - parquet_exec(&schema), - partial_group_by, - aggr_expr.clone(), - ); + let partial_agg = + partial_aggregate_exec(parquet_exec(schema), partial_group_by, aggr_expr.clone()); let groups: Vec<(Arc, String)> = vec![(col("c", &partial_agg.schema())?, "c".to_string())]; @@ -227,12 +223,13 @@ fn aggregations_with_group_combined() -> datafusion_common::Result<()> { let 
plan = final_aggregate_exec(partial_agg, final_group_by, aggr_expr); // should combine the Partial/Final AggregateExecs to the Single AggregateExec - let expected = &[ - "AggregateExec: mode=Single, gby=[c@2 as c], aggr=[Sum(b)]", - "DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c], file_type=parquet", - ]; - - assert_optimized!(expected, plan); + assert_optimized!( + plan, + @ r" + AggregateExec: mode=Single, gby=[c@2 as c], aggr=[Sum(b)] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c], file_type=parquet + " + ); Ok(()) } @@ -245,11 +242,8 @@ fn aggregations_with_limit_combined() -> datafusion_common::Result<()> { vec![(col("c", &schema)?, "c".to_string())]; let partial_group_by = PhysicalGroupBy::new_single(groups); - let partial_agg = partial_aggregate_exec( - parquet_exec(&schema), - partial_group_by, - aggr_expr.clone(), - ); + let partial_agg = + partial_aggregate_exec(parquet_exec(schema), partial_group_by, aggr_expr.clone()); let groups: Vec<(Arc, String)> = vec![(col("c", &partial_agg.schema())?, "c".to_string())]; @@ -266,16 +260,17 @@ fn aggregations_with_limit_combined() -> datafusion_common::Result<()> { schema, ) .unwrap() - .with_limit(Some(5)), + .with_limit_options(Some(LimitOptions::new(5))), ); let plan: Arc = final_agg; // should combine the Partial/Final AggregateExecs to a Single AggregateExec // with the final limit preserved - let expected = &[ - "AggregateExec: mode=Single, gby=[c@2 as c], aggr=[], lim=[5]", - "DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c], file_type=parquet", - ]; - - assert_optimized!(expected, plan); + assert_optimized!( + plan, + @ r" + AggregateExec: mode=Single, gby=[c@2 as c], aggr=[], lim=[5] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c], file_type=parquet + " + ); Ok(()) } diff --git a/datafusion/core/tests/physical_optimizer/enforce_distribution.rs b/datafusion/core/tests/physical_optimizer/enforce_distribution.rs index 
4034800c30cba..993798ff7539f 100644 --- a/datafusion/core/tests/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/tests/physical_optimizer/enforce_distribution.rs @@ -20,59 +20,100 @@ use std::ops::Deref; use std::sync::Arc; use crate::physical_optimizer::test_utils::{ - check_integrity, coalesce_partitions_exec, repartition_exec, schema, - sort_merge_join_exec, sort_preserving_merge_exec, -}; -use crate::physical_optimizer::test_utils::{ - parquet_exec_with_sort, parquet_exec_with_stats, + check_integrity, coalesce_partitions_exec, parquet_exec_with_sort, + parquet_exec_with_stats, repartition_exec, schema, sort_exec, + sort_exec_with_preserve_partitioning, sort_merge_join_exec, + sort_preserving_merge_exec, union_exec, }; -use arrow::array::{RecordBatch, UInt64Array, UInt8Array}; +use arrow::array::{RecordBatch, UInt8Array, UInt64Array}; use arrow::compute::SortOptions; use arrow_schema::{DataType, Field, Schema, SchemaRef}; use datafusion::config::ConfigOptions; +use datafusion::datasource::MemTable; use datafusion::datasource::file_format::file_compression_type::FileCompressionType; use datafusion::datasource::listing::PartitionedFile; use datafusion::datasource::object_store::ObjectStoreUrl; use datafusion::datasource::physical_plan::{CsvSource, ParquetSource}; use datafusion::datasource::source::DataSourceExec; -use datafusion::datasource::MemTable; use datafusion::prelude::{SessionConfig, SessionContext}; -use datafusion_common::error::Result; -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::ScalarValue; +use datafusion_common::config::CsvOptions; +use datafusion_common::error::Result; +use datafusion_common::tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, +}; use datafusion_datasource::file_groups::FileGroup; use datafusion_datasource::file_scan_config::FileScanConfigBuilder; use datafusion_expr::{JoinType, Operator}; -use 
datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal}; -use datafusion_physical_expr::PhysicalExpr; -use datafusion_physical_expr::{ - expressions::binary, expressions::lit, LexOrdering, PhysicalSortExpr, +use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal, binary, lit}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_physical_expr_common::sort_expr::{ + LexOrdering, OrderingRequirements, PhysicalSortExpr, }; -use datafusion_physical_expr_common::sort_expr::LexRequirement; +use datafusion_physical_optimizer::PhysicalOptimizerRule; use datafusion_physical_optimizer::enforce_distribution::*; use datafusion_physical_optimizer::enforce_sorting::EnforceSorting; use datafusion_physical_optimizer::output_requirements::OutputRequirements; -use datafusion_physical_optimizer::PhysicalOptimizerRule; use datafusion_physical_plan::aggregates::{ AggregateExec, AggregateMode, PhysicalGroupBy, }; -use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec; + use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion_physical_plan::execution_plan::ExecutionPlan; use datafusion_physical_plan::expressions::col; use datafusion_physical_plan::filter::FilterExec; use datafusion_physical_plan::joins::utils::JoinOn; use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; -use datafusion_physical_plan::projection::ProjectionExec; -use datafusion_physical_plan::sorts::sort::SortExec; +use datafusion_physical_plan::projection::{ProjectionExec, ProjectionExpr}; use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use datafusion_physical_plan::union::UnionExec; -use datafusion_physical_plan::ExecutionPlanProperties; -use datafusion_physical_plan::PlanProperties; use datafusion_physical_plan::{ - get_plan_string, DisplayAs, DisplayFormatType, Statistics, + DisplayAs, DisplayFormatType, ExecutionPlanProperties, PlanProperties, displayable, 
}; +use insta::Settings; + +/// Helper function to replace only the first occurrence of a regex pattern in a plan +/// Returns (captured_group_1, modified_string) +fn hide_first( + plan: &dyn ExecutionPlan, + regex: &str, + replacement: &str, +) -> (String, String) { + let plan_str = displayable(plan).indent(true).to_string(); + let pattern = regex::Regex::new(regex).unwrap(); + + if let Some(captures) = pattern.captures(&plan_str) { + let full_match = captures.get(0).unwrap(); + let captured_value = captures + .get(1) + .map(|m| m.as_str().to_string()) + .unwrap_or_default(); + let pos = full_match.start(); + let end_pos = full_match.end(); + let mut result = String::with_capacity(plan_str.len()); + result.push_str(&plan_str[..pos]); + result.push_str(replacement); + result.push_str(&plan_str[end_pos..]); + (captured_value, result) + } else { + (String::new(), plan_str) + } +} + +macro_rules! assert_plan { + ($plan: expr, @ $expected:literal) => { + insta::assert_snapshot!( + displayable($plan.as_ref()).indent(true).to_string(), + @ $expected + ) + }; + ($plan: expr, $another_plan: expr) => { + let plan1 = displayable($plan.as_ref()).indent(true).to_string(); + let plan2 = displayable($another_plan.as_ref()).indent(true).to_string(); + assert_eq!(plan1, plan2); + } +} /// Models operators like BoundedWindowExec that require an input /// ordering but is easy to construct @@ -80,7 +121,7 @@ use datafusion_physical_plan::{ struct SortRequiredExec { input: Arc, expr: LexOrdering, - cache: PlanProperties, + cache: Arc, } impl SortRequiredExec { @@ -92,7 +133,7 @@ impl SortRequiredExec { Self { input, expr: requirement, - cache, + cache: Arc::new(cache), } } @@ -134,7 +175,7 @@ impl ExecutionPlan for SortRequiredExec { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -147,12 +188,8 @@ impl ExecutionPlan for SortRequiredExec { } // model that it requires the output ordering of its input - fn 
required_input_ordering(&self) -> Vec> { - if self.expr.is_empty() { - vec![None] - } else { - vec![Some(LexRequirement::from(self.expr.clone()))] - } + fn required_input_ordering(&self) -> Vec> { + vec![Some(OrderingRequirements::from(self.expr.clone()))] } fn with_new_children( @@ -167,6 +204,20 @@ impl ExecutionPlan for SortRequiredExec { ))) } + fn apply_expressions( + &self, + f: &mut dyn FnMut(&dyn PhysicalExpr) -> Result, + ) -> Result { + // Visit expressions in the output ordering from equivalence properties + let mut tnr = TreeNodeRecursion::Continue; + if let Some(ordering) = self.cache.output_ordering() { + for sort_expr in ordering { + tnr = tnr.visit_sibling(|| f(sort_expr.expr.as_ref()))?; + } + } + Ok(tnr) + } + fn execute( &self, _partition: usize, @@ -174,14 +225,10 @@ impl ExecutionPlan for SortRequiredExec { ) -> Result { unreachable!(); } - - fn statistics(&self) -> Result { - self.input.partition_statistics(None) - } } fn parquet_exec() -> Arc { - parquet_exec_with_sort(vec![]) + parquet_exec_with_sort(schema(), vec![]) } fn parquet_exec_multiple() -> Arc { @@ -194,8 +241,7 @@ fn parquet_exec_multiple_sorted( ) -> Arc { let config = FileScanConfigBuilder::new( ObjectStoreUrl::parse("test:///").unwrap(), - schema(), - Arc::new(ParquetSource::default()), + Arc::new(ParquetSource::new(schema())), ) .with_file_groups(vec![ FileGroup::new(vec![PartitionedFile::new("x".to_string(), 100)]), @@ -212,14 +258,19 @@ fn csv_exec() -> Arc { } fn csv_exec_with_sort(output_ordering: Vec) -> Arc { - let config = FileScanConfigBuilder::new( - ObjectStoreUrl::parse("test:///").unwrap(), - schema(), - Arc::new(CsvSource::new(false, b',', b'"')), - ) - .with_file(PartitionedFile::new("x".to_string(), 100)) - .with_output_ordering(output_ordering) - .build(); + let config = + FileScanConfigBuilder::new(ObjectStoreUrl::parse("test:///").unwrap(), { + let options = CsvOptions { + has_header: Some(false), + delimiter: b',', + quote: b'"', + ..Default::default() + }; 
+ Arc::new(CsvSource::new(schema()).with_csv_options(options)) + }) + .with_file(PartitionedFile::new("x".to_string(), 100)) + .with_output_ordering(output_ordering) + .build(); DataSourceExec::from_data_source(config) } @@ -230,17 +281,22 @@ fn csv_exec_multiple() -> Arc { // Created a sorted parquet exec with multiple files fn csv_exec_multiple_sorted(output_ordering: Vec) -> Arc { - let config = FileScanConfigBuilder::new( - ObjectStoreUrl::parse("test:///").unwrap(), - schema(), - Arc::new(CsvSource::new(false, b',', b'"')), - ) - .with_file_groups(vec![ - FileGroup::new(vec![PartitionedFile::new("x".to_string(), 100)]), - FileGroup::new(vec![PartitionedFile::new("y".to_string(), 100)]), - ]) - .with_output_ordering(output_ordering) - .build(); + let config = + FileScanConfigBuilder::new(ObjectStoreUrl::parse("test:///").unwrap(), { + let options = CsvOptions { + has_header: Some(false), + delimiter: b',', + quote: b'"', + ..Default::default() + }; + Arc::new(CsvSource::new(schema()).with_csv_options(options)) + }) + .with_file_groups(vec![ + FileGroup::new(vec![PartitionedFile::new("x".to_string(), 100)]), + FileGroup::new(vec![PartitionedFile::new("y".to_string(), 100)]), + ]) + .with_output_ordering(output_ordering) + .build(); DataSourceExec::from_data_source(config) } @@ -251,7 +307,10 @@ fn projection_exec_with_alias( ) -> Arc { let mut exprs = vec![]; for (column, alias) in alias_pairs.iter() { - exprs.push((col(column, &input.schema()).unwrap(), alias.to_string())); + exprs.push(ProjectionExpr { + expr: col(column, &input.schema()).unwrap(), + alias: alias.to_string(), + }); } Arc::new(ProjectionExec::try_new(exprs, input).unwrap()) } @@ -327,16 +386,6 @@ fn filter_exec(input: Arc) -> Arc { Arc::new(FilterExec::try_new(predicate, input).unwrap()) } -fn sort_exec( - sort_exprs: LexOrdering, - input: Arc, - preserve_partitioning: bool, -) -> Arc { - let new_sort = SortExec::new(sort_exprs, input) - .with_preserve_partitioning(preserve_partitioning); - 
Arc::new(new_sort) -} - fn limit_exec(input: Arc) -> Arc { Arc::new(GlobalLimitExec::new( Arc::new(LocalLimitExec::new(input, 100)), @@ -345,10 +394,6 @@ fn limit_exec(input: Arc) -> Arc { )) } -fn union_exec(input: Vec>) -> Arc { - Arc::new(UnionExec::new(input)) -} - fn sort_required_exec_with_req( input: Arc, sort_exprs: LexOrdering, @@ -371,22 +416,6 @@ fn ensure_distribution_helper( ensure_distribution(distribution_context, &config).map(|item| item.data.plan) } -/// Test whether plan matches with expected plan -macro_rules! plans_matches_expected { - ($EXPECTED_LINES: expr, $PLAN: expr) => { - let physical_plan = $PLAN; - let actual = get_plan_string(&physical_plan); - - let expected_plan_lines: Vec<&str> = $EXPECTED_LINES - .iter().map(|s| *s).collect(); - - assert_eq!( - expected_plan_lines, actual, - "\n**Original Plan Mismatch\n\nexpected:\n\n{expected_plan_lines:#?}\nactual:\n\n{actual:#?}\n\n" - ); - } -} - fn test_suite_default_config_options() -> ConfigOptions { let mut config = ConfigOptions::new(); @@ -463,15 +492,12 @@ impl TestConfig { /// Perform a series of runs using the current [`TestConfig`], /// assert the expected plan result, - /// and return the result plan (for potentional subsequent runs). - fn run( + /// and return the result plan (for potential subsequent runs). 
+ fn try_to_plan( &self, - expected_lines: &[&str], plan: Arc, optimizers_to_run: &[Run], ) -> Result> { - let expected_lines: Vec<&str> = expected_lines.to_vec(); - // Add the ancillary output requirements operator at the start: let optimizer = OutputRequirements::new_add_mode(); let mut optimized = optimizer.optimize(plan.clone(), &self.config)?; @@ -526,30 +552,16 @@ impl TestConfig { let optimizer = OutputRequirements::new_remove_mode(); let optimized = optimizer.optimize(optimized, &self.config)?; - // Now format correctly - let actual_lines = get_plan_string(&optimized); - - assert_eq!( - &expected_lines, &actual_lines, - "\n\nexpected:\n\n{expected_lines:#?}\nactual:\n\n{actual_lines:#?}\n\n" - ); - Ok(optimized) } -} - -macro_rules! assert_plan_txt { - ($EXPECTED_LINES: expr, $PLAN: expr) => { - let expected_lines: Vec<&str> = $EXPECTED_LINES.iter().map(|s| *s).collect(); - // Now format correctly - let actual_lines = get_plan_string(&$PLAN); - assert_eq!( - &expected_lines, &actual_lines, - "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", - expected_lines, actual_lines - ); - }; + fn to_plan( + &self, + plan: Arc, + optimizers_to_run: &[Run], + ) -> Arc { + self.try_to_plan(plan, optimizers_to_run).unwrap() + } } #[test] @@ -575,6 +587,8 @@ fn multi_hash_joins() -> Result<()> { JoinType::RightAnti, ]; + let settings = Settings::clone_current(); + // Join on (a == b1) let join_on = vec![( Arc::new(Column::new_with_schema("a", &schema()).unwrap()) as _, @@ -583,11 +597,17 @@ fn multi_hash_joins() -> Result<()> { for join_type in join_types { let join = hash_join_exec(left.clone(), right.clone(), &join_on, &join_type); - let join_plan = |shift| -> String { - format!("{}HashJoinExec: mode=Partitioned, join_type={join_type}, on=[(a@0, b1@1)]", " ".repeat(shift)) - }; - let join_plan_indent2 = join_plan(2); - let join_plan_indent4 = join_plan(4); + + let mut settings = settings.clone(); + settings.add_filter( + // join_type={} replace with join_type=... 
to avoid snapshot name issue + format!("join_type={join_type}").as_str(), + "join_type=...", + ); + + insta::allow_duplicates! { + settings.bind( || { + match join_type { JoinType::Inner @@ -608,57 +628,60 @@ fn multi_hash_joins() -> Result<()> { &top_join_on, &join_type, ); - let top_join_plan = - format!("HashJoinExec: mode=Partitioned, join_type={join_type}, on=[(a@0, c@2)]"); - let expected = match join_type { + let test_config = TestConfig::default(); + let plan_distrib = test_config.to_plan(top_join.clone(), &DISTRIB_DISTRIB_SORT); + + match join_type { // Should include 3 RepartitionExecs - JoinType::Inner | JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => vec![ - top_join_plan.as_str(), - &join_plan_indent2, - " RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - ], + JoinType::Inner | JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => { + + assert_plan!(plan_distrib, @r" + HashJoinExec: mode=Partitioned, join_type=..., on=[(a@0, c@2)] + HashJoinExec: mode=Partitioned, join_type=..., on=[(a@0, b1@1)] + RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, 
e], file_type=parquet + RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=1 + ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); + }, // Should include 4 RepartitionExecs - _ => vec![ - top_join_plan.as_str(), - " RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10", - &join_plan_indent4, - " RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - ], + _ => { + assert_plan!(plan_distrib, @r" + HashJoinExec: mode=Partitioned, join_type=..., on=[(a@0, c@2)] + RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10 + HashJoinExec: mode=Partitioned, join_type=..., on=[(a@0, b1@1)] + RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=1 + ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1] + 
DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); + }, }; - let test_config = TestConfig::default(); - test_config.run(&expected, top_join.clone(), &DISTRIB_DISTRIB_SORT)?; - test_config.run(&expected, top_join, &SORT_DISTRIB_DISTRIB)?; + + let plan_sort = test_config.to_plan(top_join, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_distrib, plan_sort); } - JoinType::RightSemi | JoinType::RightAnti => {} + JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => {} } + + match join_type { JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full | JoinType::RightSemi - | JoinType::RightAnti => { + | JoinType::RightAnti + | JoinType::RightMark => { // This time we use (b1 == c) for top join // Join on (b1 == c) let top_join_on = vec![( @@ -668,55 +691,58 @@ fn multi_hash_joins() -> Result<()> { let top_join = hash_join_exec(join, parquet_exec(), &top_join_on, &join_type); - let top_join_plan = match join_type { - JoinType::RightSemi | JoinType::RightAnti => - format!("HashJoinExec: mode=Partitioned, join_type={join_type}, on=[(b1@1, c@2)]"), - _ => - format!("HashJoinExec: mode=Partitioned, join_type={join_type}, on=[(b1@6, c@2)]"), - }; - let expected = match join_type { + let test_config = TestConfig::default(); + let plan_distrib = test_config.to_plan(top_join.clone(), &DISTRIB_DISTRIB_SORT); + + match join_type { // Should include 3 RepartitionExecs - JoinType::Inner | JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => - vec![ - top_join_plan.as_str(), - &join_plan_indent2, - " RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " 
RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - ], + JoinType::Inner | JoinType::Right => { + assert_plan!(parquet_exec(), @"DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet"); + }, + // Should include 3 RepartitionExecs but have a different "on" + JoinType::RightSemi | JoinType::RightAnti => { + assert_plan!(plan_distrib, @r" + HashJoinExec: mode=Partitioned, join_type=..., on=[(b1@1, c@2)] + HashJoinExec: mode=Partitioned, join_type=..., on=[(a@0, b1@1)] + RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=1 + ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); + + } + // Should include 4 RepartitionExecs - _ => - vec![ - top_join_plan.as_str(), - " RepartitionExec: partitioning=Hash([b1@6], 10), input_partitions=10", - &join_plan_indent4, - " RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, 
b, c, d, e], file_type=parquet", - " RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - ], + _ => { + assert_plan!(plan_distrib, @r" + HashJoinExec: mode=Partitioned, join_type=..., on=[(b1@6, c@2)] + RepartitionExec: partitioning=Hash([b1@6], 10), input_partitions=10 + HashJoinExec: mode=Partitioned, join_type=..., on=[(a@0, b1@1)] + RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=1 + ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); + + }, }; - let test_config = TestConfig::default(); - test_config.run(&expected, top_join.clone(), &DISTRIB_DISTRIB_SORT)?; - test_config.run(&expected, top_join, &SORT_DISTRIB_DISTRIB)?; + + let plan_sort = test_config.to_plan(top_join, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_distrib, plan_sort); } JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => {} } + + }); + } } Ok(()) @@ -755,23 +781,24 @@ fn multi_joins_after_alias() -> Result<()> { ); // Output partition need to respect the Alias and should not introduce additional RepartitionExec - 
let expected = &[ - "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a1@0, c@2)]", - " ProjectionExec: expr=[a@0 as a1, a@0 as a2]", - " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, b@1)]", - " RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " RepartitionExec: partitioning=Hash([b@1], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - ]; let test_config = TestConfig::default(); - test_config.run(expected, top_join.clone(), &DISTRIB_DISTRIB_SORT)?; - test_config.run(expected, top_join, &SORT_DISTRIB_DISTRIB)?; + let plan_distrib = test_config.to_plan(top_join.clone(), &DISTRIB_DISTRIB_SORT); + assert_plan!( + plan_distrib, + @r" + HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a1@0, c@2)] + ProjectionExec: expr=[a@0 as a1, a@0 as a2] + HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, b@1)] + RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([b@1], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + " + ); + let plan_sort = test_config.to_plan(top_join, &SORT_DISTRIB_DISTRIB); 
+ assert_plan!(plan_distrib, plan_sort); // Join on (a2 == c) let top_join_on = vec![( @@ -782,23 +809,24 @@ fn multi_joins_after_alias() -> Result<()> { let top_join = hash_join_exec(projection, right, &top_join_on, &JoinType::Inner); // Output partition need to respect the Alias and should not introduce additional RepartitionExec - let expected = &[ - "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a2@1, c@2)]", - " ProjectionExec: expr=[a@0 as a1, a@0 as a2]", - " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, b@1)]", - " RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " RepartitionExec: partitioning=Hash([b@1], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - ]; let test_config = TestConfig::default(); - test_config.run(expected, top_join.clone(), &DISTRIB_DISTRIB_SORT)?; - test_config.run(expected, top_join, &SORT_DISTRIB_DISTRIB)?; + let plan_distrib = test_config.to_plan(top_join.clone(), &DISTRIB_DISTRIB_SORT); + assert_plan!( + plan_distrib, + @r" + HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a2@1, c@2)] + ProjectionExec: expr=[a@0 as a1, a@0 as a2] + HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, b@1)] + RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([b@1], 10), input_partitions=1 
+ DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + " + ); + let plan_sort = test_config.to_plan(top_join, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_distrib, plan_sort); Ok(()) } @@ -834,26 +862,26 @@ fn multi_joins_after_multi_alias() -> Result<()> { // The Column 'a' has different meaning now after the two Projections // The original Output partition can not satisfy the Join requirements and need to add an additional RepartitionExec - let expected = &[ - "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, c@2)]", - " RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10", - " ProjectionExec: expr=[c1@0 as a]", - " ProjectionExec: expr=[c@2 as c1]", - " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, b@1)]", - " RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " RepartitionExec: partitioning=Hash([b@1], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - ]; - let test_config = TestConfig::default(); - test_config.run(expected, top_join.clone(), &DISTRIB_DISTRIB_SORT)?; - test_config.run(expected, top_join, &SORT_DISTRIB_DISTRIB)?; + let plan_distrib = test_config.to_plan(top_join.clone(), &DISTRIB_DISTRIB_SORT); + assert_plan!( + plan_distrib, + 
@r" + HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, c@2)] + RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10 + ProjectionExec: expr=[c1@0 as a] + ProjectionExec: expr=[c@2 as c1] + HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, b@1)] + RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([b@1], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + " + ); + let plan_sort = test_config.to_plan(top_join, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_distrib, plan_sort); Ok(()) } @@ -879,22 +907,26 @@ fn join_after_agg_alias() -> Result<()> { let join = hash_join_exec(left, right.clone(), &join_on, &JoinType::Inner); // Only two RepartitionExecs added - let expected = &[ - "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a1@0, a2@0)]", - " AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[]", - " RepartitionExec: partitioning=Hash([a1@0], 10), input_partitions=10", - " AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[]", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " AggregateExec: mode=FinalPartitioned, gby=[a2@0 as a2], aggr=[]", - " RepartitionExec: partitioning=Hash([a2@0], 10), input_partitions=10", - " AggregateExec: mode=Partial, gby=[a@0 as a2], aggr=[]", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - ]; let test_config = TestConfig::default(); - test_config.run(expected, 
join.clone(), &DISTRIB_DISTRIB_SORT)?; - test_config.run(expected, join, &SORT_DISTRIB_DISTRIB)?; + let plan_distrib = test_config.to_plan(join.clone(), &DISTRIB_DISTRIB_SORT); + assert_plan!( + plan_distrib, + @r" + HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a1@0, a2@0)] + AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[] + RepartitionExec: partitioning=Hash([a1@0], 10), input_partitions=10 + AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + AggregateExec: mode=FinalPartitioned, gby=[a2@0 as a2], aggr=[] + RepartitionExec: partitioning=Hash([a2@0], 10), input_partitions=10 + AggregateExec: mode=Partial, gby=[a@0 as a2], aggr=[] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + " + ); + let plan_sort = test_config.to_plan(join, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_distrib, plan_sort); Ok(()) } @@ -932,23 +964,27 @@ fn hash_join_key_ordering() -> Result<()> { let join = hash_join_exec(left, right.clone(), &join_on, &JoinType::Inner); // Only two RepartitionExecs added - let expected = &[ - "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(b1@1, b@0), (a1@0, a@1)]", - " ProjectionExec: expr=[a1@1 as a1, b1@0 as b1]", - " AggregateExec: mode=FinalPartitioned, gby=[b1@0 as b1, a1@1 as a1], aggr=[]", - " RepartitionExec: partitioning=Hash([b1@0, a1@1], 10), input_partitions=10", - " AggregateExec: mode=Partial, gby=[b@1 as b1, a@0 as a1], aggr=[]", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " AggregateExec: mode=FinalPartitioned, gby=[b@0 as b, a@1 as a], aggr=[]", - " RepartitionExec: partitioning=Hash([b@0, a@1], 
10), input_partitions=10", - " AggregateExec: mode=Partial, gby=[b@1 as b, a@0 as a], aggr=[]", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - ]; let test_config = TestConfig::default(); - test_config.run(expected, join.clone(), &DISTRIB_DISTRIB_SORT)?; - test_config.run(expected, join, &SORT_DISTRIB_DISTRIB)?; + let plan_distrib = test_config.to_plan(join.clone(), &DISTRIB_DISTRIB_SORT); + assert_plan!( + plan_distrib, + @r" + HashJoinExec: mode=Partitioned, join_type=Inner, on=[(b1@1, b@0), (a1@0, a@1)] + ProjectionExec: expr=[a1@1 as a1, b1@0 as b1] + AggregateExec: mode=FinalPartitioned, gby=[b1@0 as b1, a1@1 as a1], aggr=[] + RepartitionExec: partitioning=Hash([b1@0, a1@1], 10), input_partitions=10 + AggregateExec: mode=Partial, gby=[b@1 as b1, a@0 as a1], aggr=[] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + AggregateExec: mode=FinalPartitioned, gby=[b@0 as b, a@1 as a], aggr=[] + RepartitionExec: partitioning=Hash([b@0, a@1], 10), input_partitions=10 + AggregateExec: mode=Partial, gby=[b@1 as b, a@0 as a], aggr=[] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + " + ); + let plan_sort = test_config.to_plan(join, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_distrib, plan_sort); Ok(()) } @@ -1052,30 +1088,31 @@ fn multi_hash_join_key_ordering() -> Result<()> { Arc::new(FilterExec::try_new(predicate, top_join)?); // The bottom joins' join key ordering is adjusted based on the top join. 
And the top join should not introduce additional RepartitionExec - let expected = &[ - "FilterExec: c@6 > 1", - " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(B@2, b1@6), (C@3, c@2), (AA@1, a1@5)]", - " ProjectionExec: expr=[a@0 as A, a@0 as AA, b@1 as B, c@2 as C]", - " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(b@1, b1@1), (c@2, c1@2), (a@0, a1@0)]", - " RepartitionExec: partitioning=Hash([b@1, c@2, a@0], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " RepartitionExec: partitioning=Hash([b1@1, c1@2, a1@0], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(b@1, b1@1), (c@2, c1@2), (a@0, a1@0)]", - " RepartitionExec: partitioning=Hash([b@1, c@2, a@0], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " RepartitionExec: partitioning=Hash([b1@1, c1@2, a1@0], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - ]; let test_config = TestConfig::default(); - test_config.run(expected, filter_top_join.clone(), &DISTRIB_DISTRIB_SORT)?; - test_config.run(expected, filter_top_join, &SORT_DISTRIB_DISTRIB)?; + let plan_distrib = + test_config.to_plan(filter_top_join.clone(), &DISTRIB_DISTRIB_SORT); + assert_plan!( + plan_distrib, + @r" + FilterExec: c@6 > 1 + HashJoinExec: 
mode=Partitioned, join_type=Inner, on=[(B@2, b1@6), (C@3, c@2), (AA@1, a1@5)] + ProjectionExec: expr=[a@0 as A, a@0 as AA, b@1 as B, c@2 as C] + HashJoinExec: mode=Partitioned, join_type=Inner, on=[(b@1, b1@1), (c@2, c1@2), (a@0, a1@0)] + RepartitionExec: partitioning=Hash([b@1, c@2, a@0], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([b1@1, c1@2, a1@0], 10), input_partitions=1 + ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + HashJoinExec: mode=Partitioned, join_type=Inner, on=[(b@1, b1@1), (c@2, c1@2), (a@0, a1@0)] + RepartitionExec: partitioning=Hash([b@1, c@2, a@0], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([b1@1, c1@2, a1@0], 10), input_partitions=1 + ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + " + ); + let plan_sort = test_config.to_plan(filter_top_join, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_distrib, plan_sort); Ok(()) } @@ -1186,34 +1223,30 @@ fn reorder_join_keys_to_left_input() -> Result<()> { &top_join_on, &join_type, ); - let top_join_plan = - format!("HashJoinExec: mode=Partitioned, join_type={:?}, on=[(AA@1, a1@5), (B@2, b1@6), (C@3, c@2)]", &join_type); - let reordered = reorder_join_keys_to_inputs(top_join)?; + let reordered = reorder_join_keys_to_inputs(top_join).unwrap(); // The top joins' join key ordering is adjusted based on the children inputs. 
- let expected = &[ - top_join_plan.as_str(), - " ProjectionExec: expr=[a@0 as A, a@0 as AA, b@1 as B, c@2 as C]", - " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, a1@0), (b@1, b1@1), (c@2, c1@2)]", - " RepartitionExec: partitioning=Hash([a@0, b@1, c@2], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " RepartitionExec: partitioning=Hash([a1@0, b1@1, c1@2], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@2, c1@2), (b@1, b1@1), (a@0, a1@0)]", - " RepartitionExec: partitioning=Hash([c@2, b@1, a@0], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " RepartitionExec: partitioning=Hash([c1@2, b1@1, a1@0], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - ]; - - assert_plan_txt!(expected, reordered); + let (captured_join_type, modified_plan) = + hide_first(reordered.as_ref(), r"join_type=(\w+)", "join_type=..."); + assert_eq!(captured_join_type, join_type.to_string()); + + insta::allow_duplicates! 
{insta::assert_snapshot!(modified_plan, @r" + HashJoinExec: mode=Partitioned, join_type=..., on=[(AA@1, a1@5), (B@2, b1@6), (C@3, c@2)] + ProjectionExec: expr=[a@0 as A, a@0 as AA, b@1 as B, c@2 as C] + HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, a1@0), (b@1, b1@1), (c@2, c1@2)] + RepartitionExec: partitioning=Hash([a@0, b@1, c@2], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([a1@0, b1@1, c1@2], 10), input_partitions=1 + ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@2, c1@2), (b@1, b1@1), (a@0, a1@0)] + RepartitionExec: partitioning=Hash([c@2, b@1, a@0], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([c1@2, b1@1, a1@0], 10), input_partitions=1 + ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + ");} } Ok(()) @@ -1320,34 +1353,28 @@ fn reorder_join_keys_to_right_input() -> Result<()> { &top_join_on, &join_type, ); - let top_join_plan = - format!("HashJoinExec: mode=Partitioned, join_type={:?}, on=[(C@3, c@2), (B@2, b1@6), (AA@1, a1@5)]", &join_type); - let reordered = reorder_join_keys_to_inputs(top_join)?; + let reordered = reorder_join_keys_to_inputs(top_join).unwrap(); // The top joins' join key ordering is adjusted based on the children inputs. 
- let expected = &[ - top_join_plan.as_str(), - " ProjectionExec: expr=[a@0 as A, a@0 as AA, b@1 as B, c@2 as C]", - " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, a1@0), (b@1, b1@1)]", - " RepartitionExec: partitioning=Hash([a@0, b@1], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " RepartitionExec: partitioning=Hash([a1@0, b1@1], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@2, c1@2), (b@1, b1@1), (a@0, a1@0)]", - " RepartitionExec: partitioning=Hash([c@2, b@1, a@0], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " RepartitionExec: partitioning=Hash([c1@2, b1@1, a1@0], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - ]; - - assert_plan_txt!(expected, reordered); + let (_, plan_str) = + hide_first(reordered.as_ref(), r"join_type=(\w+)", "join_type=..."); + insta::allow_duplicates! 
{insta::assert_snapshot!(plan_str, @r" + HashJoinExec: mode=Partitioned, join_type=..., on=[(C@3, c@2), (B@2, b1@6), (AA@1, a1@5)] + ProjectionExec: expr=[a@0 as A, a@0 as AA, b@1 as B, c@2 as C] + HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, a1@0), (b@1, b1@1)] + RepartitionExec: partitioning=Hash([a@0, b@1], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([a1@0, b1@1], 10), input_partitions=1 + ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@2, c1@2), (b@1, b1@1), (a@0, a1@0)] + RepartitionExec: partitioning=Hash([c@2, b@1, a@0], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([c1@2, b1@1, a1@0], 10), input_partitions=1 + ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + ");} } Ok(()) @@ -1387,15 +1414,6 @@ fn multi_smj_joins() -> Result<()> { for join_type in join_types { let join = sort_merge_join_exec(left.clone(), right.clone(), &join_on, &join_type); - let join_plan = |shift| -> String { - format!( - "{}SortMergeJoin: join_type={join_type}, on=[(a@0, b1@1)]", - " ".repeat(shift) - ) - }; - let join_plan_indent2 = join_plan(2); - let join_plan_indent6 = join_plan(6); - let join_plan_indent10 = join_plan(10); // Top join on (a == c) let top_join_on = vec![( @@ -1404,235 +1422,220 @@ fn multi_smj_joins() -> Result<()> { )]; let top_join = sort_merge_join_exec(join.clone(), parquet_exec(), &top_join_on, &join_type); - let top_join_plan = - format!("SortMergeJoin: join_type={join_type}, on=[(a@0, c@2)]"); - - let expected = match join_type { - // Should include 6 RepartitionExecs (3 
hash, 3 round-robin), 3 SortExecs - JoinType::Inner | JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => - vec![ - top_join_plan.as_str(), - &join_plan_indent2, - " SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", - " RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " SortExec: expr=[b1@1 ASC], preserve_partitioning=[true]", - " RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " SortExec: expr=[c@2 ASC], preserve_partitioning=[true]", - " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - ], - // Should include 7 RepartitionExecs (4 hash, 3 round-robin), 4 SortExecs - // Since ordering of the left child is not preserved after SortMergeJoin - // when mode is Right, RightSemi, RightAnti, Full - // - We need to add one additional SortExec after SortMergeJoin in contrast the test cases - // when mode is Inner, Left, LeftSemi, LeftAnti - // Similarly, since partitioning of the left side is not preserved - // when mode is Right, RightSemi, RightAnti, Full - // - We need to add one additional Hash Repartition after SortMergeJoin in contrast the test - // cases when mode is Inner, Left, LeftSemi, LeftAnti - _ => vec![ - top_join_plan.as_str(), - // Below 2 operators are differences introduced, when join mode is changed - " SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", - " RepartitionExec: 
partitioning=Hash([a@0], 10), input_partitions=10", - &join_plan_indent6, - " SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", - " RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " SortExec: expr=[b1@1 ASC], preserve_partitioning=[true]", - " RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " SortExec: expr=[c@2 ASC], preserve_partitioning=[true]", - " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - ], - }; - // TODO(wiedld): show different test result if enforce sorting first. 
- test_config.run(&expected, top_join.clone(), &DISTRIB_DISTRIB_SORT)?; - - let expected_first_sort_enforcement = match join_type { - // Should include 6 RepartitionExecs (3 hash, 3 round-robin), 3 SortExecs - JoinType::Inner | JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => - vec![ - top_join_plan.as_str(), - &join_plan_indent2, - " RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10, preserve_order=true, sort_exprs=b1@1 ASC", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " SortExec: expr=[b1@1 ASC], preserve_partitioning=[false]", - " ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, preserve_order=true, sort_exprs=c@2 ASC", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " SortExec: expr=[c@2 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - ], - // Should include 8 RepartitionExecs (4 hash, 8 round-robin), 4 SortExecs - // Since ordering of the left child is not preserved after SortMergeJoin - // when mode is Right, RightSemi, RightAnti, Full - // - We need to add one additional SortExec after SortMergeJoin in contrast the test cases - // when mode is Inner, Left, LeftSemi, LeftAnti - // Similarly, since partitioning of the left side is not preserved - // when mode is Right, RightSemi, RightAnti, Full - // - We need to add one additional Hash 
Repartition and Roundrobin repartition after - // SortMergeJoin in contrast the test cases when mode is Inner, Left, LeftSemi, LeftAnti - _ => vec![ - top_join_plan.as_str(), - // Below 4 operators are differences introduced, when join mode is changed - " RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", - " CoalescePartitionsExec", - &join_plan_indent10, - " RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10, preserve_order=true, sort_exprs=b1@1 ASC", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " SortExec: expr=[b1@1 ASC], preserve_partitioning=[false]", - " ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, preserve_order=true, sort_exprs=c@2 ASC", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " SortExec: expr=[c@2 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - ], - }; - // TODO(wiedld): show different test result if enforce distribution first. 
- test_config.run( - &expected_first_sort_enforcement, - top_join, - &SORT_DISTRIB_DISTRIB, - )?; - match join_type { - JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => { - // This time we use (b1 == c) for top join - // Join on (b1 == c) - let top_join_on = vec![( - Arc::new(Column::new_with_schema("b1", &join.schema()).unwrap()) as _, - Arc::new(Column::new_with_schema("c", &schema()).unwrap()) as _, - )]; - let top_join = - sort_merge_join_exec(join, parquet_exec(), &top_join_on, &join_type); - let top_join_plan = - format!("SortMergeJoin: join_type={join_type}, on=[(b1@6, c@2)]"); - - let expected = match join_type { - // Should include 6 RepartitionExecs(3 hash, 3 round-robin) and 3 SortExecs - JoinType::Inner | JoinType::Right => vec![ - top_join_plan.as_str(), - &join_plan_indent2, - " SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", - " RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " SortExec: expr=[b1@1 ASC], preserve_partitioning=[true]", - " RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " SortExec: expr=[c@2 ASC], preserve_partitioning=[true]", - " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - ], - // Should include 7 RepartitionExecs (4 hash, 3 round-robin) and 4 SortExecs - JoinType::Left | JoinType::Full => vec![ - top_join_plan.as_str(), - " SortExec: 
expr=[b1@6 ASC], preserve_partitioning=[true]", - " RepartitionExec: partitioning=Hash([b1@6], 10), input_partitions=10", - &join_plan_indent6, - " SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", - " RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " SortExec: expr=[b1@1 ASC], preserve_partitioning=[true]", - " RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " SortExec: expr=[c@2 ASC], preserve_partitioning=[true]", - " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - ], - // this match arm cannot be reached - _ => unreachable!() - }; - // TODO(wiedld): show different test result if enforce sorting first. 
- test_config.run(&expected, top_join.clone(), &DISTRIB_DISTRIB_SORT)?; - - let expected_first_sort_enforcement = match join_type { - // Should include 6 RepartitionExecs (3 of them preserves order) and 3 SortExecs - JoinType::Inner | JoinType::Right => vec![ - top_join_plan.as_str(), - &join_plan_indent2, - " RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10, preserve_order=true, sort_exprs=b1@1 ASC", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " SortExec: expr=[b1@1 ASC], preserve_partitioning=[false]", - " ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, preserve_order=true, sort_exprs=c@2 ASC", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " SortExec: expr=[c@2 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - ], - // Should include 8 RepartitionExecs (4 of them preserves order) and 4 SortExecs - JoinType::Left | JoinType::Full => vec![ - top_join_plan.as_str(), - " RepartitionExec: partitioning=Hash([b1@6], 10), input_partitions=10, preserve_order=true, sort_exprs=b1@6 ASC", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " SortExec: expr=[b1@6 ASC], preserve_partitioning=[false]", - " CoalescePartitionsExec", - &join_plan_indent10, - " RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10, 
preserve_order=true, sort_exprs=a@0 ASC", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10, preserve_order=true, sort_exprs=b1@1 ASC", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " SortExec: expr=[b1@1 ASC], preserve_partitioning=[false]", - " ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, preserve_order=true, sort_exprs=c@2 ASC", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " SortExec: expr=[c@2 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - ], - // this match arm cannot be reached - _ => unreachable!() - }; + let mut settings = Settings::clone_current(); + settings.add_filter(&format!("join_type={join_type}"), "join_type=..."); + + #[rustfmt::skip] + insta::allow_duplicates! 
{ + settings.bind(|| { + let plan_distrib = test_config.to_plan(top_join.clone(), &DISTRIB_DISTRIB_SORT); + + match join_type { + // Should include 6 RepartitionExecs (3 hash, 3 round-robin), 3 SortExecs + JoinType::Inner | JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => { + assert_plan!(plan_distrib, @r" + SortMergeJoinExec: join_type=..., on=[(a@0, c@2)] + SortMergeJoinExec: join_type=..., on=[(a@0, b1@1)] + SortExec: expr=[a@0 ASC], preserve_partitioning=[true] + RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + SortExec: expr=[b1@1 ASC], preserve_partitioning=[true] + RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=1 + ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + SortExec: expr=[c@2 ASC], preserve_partitioning=[true] + RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); + } + // Should include 7 RepartitionExecs (4 hash, 3 round-robin), 4 SortExecs + // Since ordering of the left child is not preserved after SortMergeJoinExec + // when mode is Right, RightSemi, RightAnti, Full + // - We need to add one additional SortExec after SortMergeJoinExec in contrast the test cases + // when mode is Inner, Left, LeftSemi, LeftAnti + // Similarly, since partitioning of the left side is not preserved + // when mode is Right, RightSemi, RightAnti, Full + // - We need to add one additional Hash Repartition after SortMergeJoinExec in contrast the test + // cases when mode is Inner, Left, LeftSemi, LeftAnti + _ => { + assert_plan!(plan_distrib, @r" + SortMergeJoinExec: join_type=..., on=[(a@0, c@2)] + SortExec: expr=[a@0 ASC], preserve_partitioning=[true] + RepartitionExec: partitioning=Hash([a@0], 
10), input_partitions=10 + SortMergeJoinExec: join_type=..., on=[(a@0, b1@1)] + SortExec: expr=[a@0 ASC], preserve_partitioning=[true] + RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + SortExec: expr=[b1@1 ASC], preserve_partitioning=[true] + RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=1 + ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + SortExec: expr=[c@2 ASC], preserve_partitioning=[true] + RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); + } + } - // TODO(wiedld): show different test result if enforce distribution first. - test_config.run( - &expected_first_sort_enforcement, - top_join, - &SORT_DISTRIB_DISTRIB, - )?; - } - _ => {} + let plan_sort = test_config.to_plan(top_join.clone(), &SORT_DISTRIB_DISTRIB); + + match join_type { + // Should include 6 RepartitionExecs (3 hash, 3 round-robin), 3 SortExecs + JoinType::Inner | JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => { + // TODO(wiedld): show different test result if enforce distribution first. 
+ assert_plan!(plan_sort, @r" + SortMergeJoinExec: join_type=..., on=[(a@0, c@2)] + SortMergeJoinExec: join_type=..., on=[(a@0, b1@1)] + RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1, maintains_sort_order=true + SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=1, maintains_sort_order=true + SortExec: expr=[b1@1 ASC], preserve_partitioning=[false] + ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=1, maintains_sort_order=true + SortExec: expr=[c@2 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); + } + // Should include 8 RepartitionExecs (4 hash, 8 round-robin), 4 SortExecs + // Since ordering of the left child is not preserved after SortMergeJoinExec + // when mode is Right, RightSemi, RightAnti, Full + // - We need to add one additional SortExec after SortMergeJoinExec in contrast the test cases + // when mode is Inner, Left, LeftSemi, LeftAnti + // Similarly, since partitioning of the left side is not preserved + // when mode is Right, RightSemi, RightAnti, Full + // - We need to add one additional Hash Repartition and Roundrobin repartition after + // SortMergeJoinExec in contrast the test cases when mode is Inner, Left, LeftSemi, LeftAnti + _ => { + // TODO(wiedld): show different test result if enforce distribution first. 
+ assert_plan!(plan_sort, @r" + SortMergeJoinExec: join_type=..., on=[(a@0, c@2)] + RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1, maintains_sort_order=true + SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + CoalescePartitionsExec + SortMergeJoinExec: join_type=..., on=[(a@0, b1@1)] + RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1, maintains_sort_order=true + SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=1, maintains_sort_order=true + SortExec: expr=[b1@1 ASC], preserve_partitioning=[false] + ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=1, maintains_sort_order=true + SortExec: expr=[c@2 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); + } + } + + match join_type { + JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => { + // This time we use (b1 == c) for top join + // Join on (b1 == c) + let top_join_on = vec![( + Arc::new(Column::new_with_schema("b1", &join.schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("c", &schema()).unwrap()) as _, + )]; + let top_join = sort_merge_join_exec(join, parquet_exec(), &top_join_on, &join_type); + + let plan_distrib = test_config.to_plan(top_join.clone(), &DISTRIB_DISTRIB_SORT); + + match join_type { + // Should include 6 RepartitionExecs(3 hash, 3 round-robin) and 3 SortExecs + JoinType::Inner | JoinType::Right => { + // TODO(wiedld): show different test result if enforce sorting first. 
+ assert_plan!(plan_distrib, @r" + SortMergeJoinExec: join_type=..., on=[(b1@6, c@2)] + SortMergeJoinExec: join_type=..., on=[(a@0, b1@1)] + SortExec: expr=[a@0 ASC], preserve_partitioning=[true] + RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + SortExec: expr=[b1@1 ASC], preserve_partitioning=[true] + RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=1 + ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + SortExec: expr=[c@2 ASC], preserve_partitioning=[true] + RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); + } + // Should include 7 RepartitionExecs (4 hash, 3 round-robin) and 4 SortExecs + JoinType::Left | JoinType::Full => { + // TODO(wiedld): show different test result if enforce sorting first. 
+ assert_plan!(plan_distrib, @r" + SortMergeJoinExec: join_type=..., on=[(b1@6, c@2)] + SortExec: expr=[b1@6 ASC], preserve_partitioning=[true] + RepartitionExec: partitioning=Hash([b1@6], 10), input_partitions=10 + SortMergeJoinExec: join_type=..., on=[(a@0, b1@1)] + SortExec: expr=[a@0 ASC], preserve_partitioning=[true] + RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + SortExec: expr=[b1@1 ASC], preserve_partitioning=[true] + RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=1 + ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + SortExec: expr=[c@2 ASC], preserve_partitioning=[true] + RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); + } + // this match arm cannot be reached + _ => unreachable!() + } + + let plan_sort = test_config.to_plan(top_join, &SORT_DISTRIB_DISTRIB); + + match join_type { + // Should include 6 RepartitionExecs (3 of them preserves order) and 3 SortExecs + JoinType::Inner | JoinType::Right => { + // TODO(wiedld): show different test result if enforce distribution first. 
+ assert_plan!(plan_sort, @r" + SortMergeJoinExec: join_type=..., on=[(b1@6, c@2)] + SortMergeJoinExec: join_type=..., on=[(a@0, b1@1)] + RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1, maintains_sort_order=true + SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=1, maintains_sort_order=true + SortExec: expr=[b1@1 ASC], preserve_partitioning=[false] + ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=1, maintains_sort_order=true + SortExec: expr=[c@2 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); + } + // Should include 8 RepartitionExecs (4 of them preserves order) and 4 SortExecs + JoinType::Left | JoinType::Full => { + // TODO(wiedld): show different test result if enforce distribution first. 
+ assert_plan!(plan_sort, @r" + SortMergeJoinExec: join_type=..., on=[(b1@6, c@2)] + RepartitionExec: partitioning=Hash([b1@6], 10), input_partitions=1, maintains_sort_order=true + SortExec: expr=[b1@6 ASC], preserve_partitioning=[false] + CoalescePartitionsExec + SortMergeJoinExec: join_type=..., on=[(a@0, b1@1)] + RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1, maintains_sort_order=true + SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=1, maintains_sort_order=true + SortExec: expr=[b1@1 ASC], preserve_partitioning=[false] + ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=1, maintains_sort_order=true + SortExec: expr=[c@2 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); + } + // this match arm cannot be reached + _ => unreachable!() + } + } + _ => {} + } + }); } } - Ok(()) } @@ -1688,52 +1691,50 @@ fn smj_join_key_ordering() -> Result<()> { // Test: run EnforceDistribution, then EnforceSort. 
// Only two RepartitionExecs added - let expected = &[ - "SortMergeJoin: join_type=Inner, on=[(b3@1, b2@1), (a3@0, a2@0)]", - " SortExec: expr=[b3@1 ASC, a3@0 ASC], preserve_partitioning=[true]", - " ProjectionExec: expr=[a1@0 as a3, b1@1 as b3]", - " ProjectionExec: expr=[a1@1 as a1, b1@0 as b1]", - " AggregateExec: mode=FinalPartitioned, gby=[b1@0 as b1, a1@1 as a1], aggr=[]", - " RepartitionExec: partitioning=Hash([b1@0, a1@1], 10), input_partitions=10", - " AggregateExec: mode=Partial, gby=[b@1 as b1, a@0 as a1], aggr=[]", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " SortExec: expr=[b2@1 ASC, a2@0 ASC], preserve_partitioning=[true]", - " ProjectionExec: expr=[a@1 as a2, b@0 as b2]", - " AggregateExec: mode=FinalPartitioned, gby=[b@0 as b, a@1 as a], aggr=[]", - " RepartitionExec: partitioning=Hash([b@0, a@1], 10), input_partitions=10", - " AggregateExec: mode=Partial, gby=[b@1 as b, a@0 as a], aggr=[]", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - ]; - test_config.run(expected, join.clone(), &DISTRIB_DISTRIB_SORT)?; + let plan_distrib = test_config.to_plan(join.clone(), &DISTRIB_DISTRIB_SORT); + assert_plan!(plan_distrib, @r" + SortMergeJoinExec: join_type=Inner, on=[(b3@1, b2@1), (a3@0, a2@0)] + SortExec: expr=[b3@1 ASC, a3@0 ASC], preserve_partitioning=[true] + ProjectionExec: expr=[a1@0 as a3, b1@1 as b3] + ProjectionExec: expr=[a1@1 as a1, b1@0 as b1] + AggregateExec: mode=FinalPartitioned, gby=[b1@0 as b1, a1@1 as a1], aggr=[] + RepartitionExec: partitioning=Hash([b1@0, a1@1], 10), input_partitions=10 + AggregateExec: mode=Partial, gby=[b@1 as b1, a@0 as a1], aggr=[] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, 
projection=[a, b, c, d, e], file_type=parquet + SortExec: expr=[b2@1 ASC, a2@0 ASC], preserve_partitioning=[true] + ProjectionExec: expr=[a@1 as a2, b@0 as b2] + AggregateExec: mode=FinalPartitioned, gby=[b@0 as b, a@1 as a], aggr=[] + RepartitionExec: partitioning=Hash([b@0, a@1], 10), input_partitions=10 + AggregateExec: mode=Partial, gby=[b@1 as b, a@0 as a], aggr=[] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); // Test: result IS DIFFERENT, if EnforceSorting is run first: - let expected_first_sort_enforcement = &[ - "SortMergeJoin: join_type=Inner, on=[(b3@1, b2@1), (a3@0, a2@0)]", - " RepartitionExec: partitioning=Hash([b3@1, a3@0], 10), input_partitions=10, preserve_order=true, sort_exprs=b3@1 ASC, a3@0 ASC", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " SortExec: expr=[b3@1 ASC, a3@0 ASC], preserve_partitioning=[false]", - " CoalescePartitionsExec", - " ProjectionExec: expr=[a1@0 as a3, b1@1 as b3]", - " ProjectionExec: expr=[a1@1 as a1, b1@0 as b1]", - " AggregateExec: mode=FinalPartitioned, gby=[b1@0 as b1, a1@1 as a1], aggr=[]", - " RepartitionExec: partitioning=Hash([b1@0, a1@1], 10), input_partitions=10", - " AggregateExec: mode=Partial, gby=[b@1 as b1, a@0 as a1], aggr=[]", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " RepartitionExec: partitioning=Hash([b2@1, a2@0], 10), input_partitions=10, preserve_order=true, sort_exprs=b2@1 ASC, a2@0 ASC", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " SortExec: expr=[b2@1 ASC, a2@0 ASC], preserve_partitioning=[false]", - " CoalescePartitionsExec", - " ProjectionExec: expr=[a@1 as a2, b@0 as b2]", - " AggregateExec: mode=FinalPartitioned, gby=[b@0 as b, a@1 as a], aggr=[]", - " RepartitionExec: 
partitioning=Hash([b@0, a@1], 10), input_partitions=10", - " AggregateExec: mode=Partial, gby=[b@1 as b, a@0 as a], aggr=[]", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - ]; - test_config.run(expected_first_sort_enforcement, join, &SORT_DISTRIB_DISTRIB)?; + let plan_sort = test_config.to_plan(join, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_sort, @r" + SortMergeJoinExec: join_type=Inner, on=[(b3@1, b2@1), (a3@0, a2@0)] + RepartitionExec: partitioning=Hash([b3@1, a3@0], 10), input_partitions=1, maintains_sort_order=true + SortExec: expr=[b3@1 ASC, a3@0 ASC], preserve_partitioning=[false] + CoalescePartitionsExec + ProjectionExec: expr=[a1@0 as a3, b1@1 as b3] + ProjectionExec: expr=[a1@1 as a1, b1@0 as b1] + AggregateExec: mode=FinalPartitioned, gby=[b1@0 as b1, a1@1 as a1], aggr=[] + RepartitionExec: partitioning=Hash([b1@0, a1@1], 10), input_partitions=10 + AggregateExec: mode=Partial, gby=[b@1 as b1, a@0 as a1], aggr=[] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + RepartitionExec: partitioning=Hash([b2@1, a2@0], 10), input_partitions=1, maintains_sort_order=true + SortExec: expr=[b2@1 ASC, a2@0 ASC], preserve_partitioning=[false] + CoalescePartitionsExec + ProjectionExec: expr=[a@1 as a2, b@0 as b2] + AggregateExec: mode=FinalPartitioned, gby=[b@0 as b, a@1 as a], aggr=[] + RepartitionExec: partitioning=Hash([b@0, a@1], 10), input_partitions=10 + AggregateExec: mode=Partial, gby=[b@1 as b, a@0 as a], aggr=[] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); Ok(()) } @@ -1742,17 +1743,15 @@ fn smj_join_key_ordering() -> Result<()> { fn merge_does_not_need_sort() -> Result<()> { // see 
https://github.com/apache/datafusion/issues/4331 let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("a", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("a", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); // Scan some sorted parquet files let exec = parquet_exec_multiple_sorted(vec![sort_key.clone()]); - // CoalesceBatchesExec to mimic behavior after a filter - let exec = Arc::new(CoalesceBatchesExec::new(exec, 4096)); - // Merge from multiple parquet files and keep the data sorted let exec: Arc = Arc::new(SortPreservingMergeExec::new(sort_key, exec)); @@ -1761,13 +1760,13 @@ fn merge_does_not_need_sort() -> Result<()> { // // The optimizer should not add an additional SortExec as the // data is already sorted - let expected = &[ - "SortPreservingMergeExec: [a@0 ASC]", - " CoalesceBatchesExec: target_batch_size=4096", - " DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet", - ]; let test_config = TestConfig::default(); - test_config.run(expected, exec.clone(), &DISTRIB_DISTRIB_SORT)?; + let plan_distrib = test_config.to_plan(exec.clone(), &DISTRIB_DISTRIB_SORT); + assert_plan!(plan_distrib, + @r" + SortPreservingMergeExec: [a@0 ASC] + DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + "); // Test: result IS DIFFERENT, if EnforceSorting is run first: // @@ -1775,13 +1774,13 @@ fn merge_does_not_need_sort() -> Result<()> { // (according to flag: PREFER_EXISTING_SORT) // hence in this case ordering lost during CoalescePartitionsExec and re-introduced with // SortExec at the top. 
- let expected_first_sort_enforcement = &[ - "SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", - " CoalescePartitionsExec", - " CoalesceBatchesExec: target_batch_size=4096", - " DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet", - ]; - test_config.run(expected_first_sort_enforcement, exec, &SORT_DISTRIB_DISTRIB)?; + let plan_sort = test_config.to_plan(exec, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_sort, + @r" + SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + CoalescePartitionsExec + DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + "); Ok(()) } @@ -1800,32 +1799,33 @@ fn union_to_interleave() -> Result<()> { ); // Union - let plan = Arc::new(UnionExec::new(vec![left, right])); + let plan = UnionExec::try_new(vec![left, right])?; // final agg let plan = aggregate_exec_with_alias(plan, vec![("a1".to_string(), "a2".to_string())]); // Only two RepartitionExecs added, no final RepartitionExec required - let expected = &[ - "AggregateExec: mode=FinalPartitioned, gby=[a2@0 as a2], aggr=[]", - " AggregateExec: mode=Partial, gby=[a1@0 as a2], aggr=[]", - " InterleaveExec", - " AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[]", - " RepartitionExec: partitioning=Hash([a1@0], 10), input_partitions=10", - " AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[]", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[]", - " RepartitionExec: partitioning=Hash([a1@0], 10), input_partitions=10", - " AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[]", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], 
file_type=parquet", - ]; - let test_config = TestConfig::default(); - test_config.run(expected, plan.clone(), &DISTRIB_DISTRIB_SORT)?; - test_config.run(expected, plan, &SORT_DISTRIB_DISTRIB)?; + let plan_distrib = test_config.to_plan(plan.clone(), &DISTRIB_DISTRIB_SORT); + assert_plan!(plan_distrib, + @r" + AggregateExec: mode=FinalPartitioned, gby=[a2@0 as a2], aggr=[] + AggregateExec: mode=Partial, gby=[a1@0 as a2], aggr=[] + InterleaveExec + AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[] + RepartitionExec: partitioning=Hash([a1@0], 10), input_partitions=10 + AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[] + RepartitionExec: partitioning=Hash([a1@0], 10), input_partitions=10 + AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); + let plan_sort = test_config.to_plan(plan, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_distrib, plan_sort); Ok(()) } @@ -1844,35 +1844,36 @@ fn union_not_to_interleave() -> Result<()> { ); // Union - let plan = Arc::new(UnionExec::new(vec![left, right])); + let plan = UnionExec::try_new(vec![left, right])?; // final agg let plan = aggregate_exec_with_alias(plan, vec![("a1".to_string(), "a2".to_string())]); // Only two RepartitionExecs added, no final RepartitionExec required - let expected = &[ - "AggregateExec: mode=FinalPartitioned, gby=[a2@0 as a2], aggr=[]", - " RepartitionExec: partitioning=Hash([a2@0], 10), input_partitions=20", - " AggregateExec: mode=Partial, gby=[a1@0 as a2], aggr=[]", - " UnionExec", - " AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[]", - " RepartitionExec: 
partitioning=Hash([a1@0], 10), input_partitions=10", - " AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[]", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[]", - " RepartitionExec: partitioning=Hash([a1@0], 10), input_partitions=10", - " AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[]", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - ]; - // TestConfig: Prefer existing union. let test_config = TestConfig::default().with_prefer_existing_union(); - test_config.run(expected, plan.clone(), &DISTRIB_DISTRIB_SORT)?; - test_config.run(expected, plan, &SORT_DISTRIB_DISTRIB)?; + let plan_distrib = test_config.to_plan(plan.clone(), &DISTRIB_DISTRIB_SORT); + assert_plan!(plan_distrib, + @r" + AggregateExec: mode=FinalPartitioned, gby=[a2@0 as a2], aggr=[] + RepartitionExec: partitioning=Hash([a2@0], 10), input_partitions=20 + AggregateExec: mode=Partial, gby=[a1@0 as a2], aggr=[] + UnionExec + AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[] + RepartitionExec: partitioning=Hash([a1@0], 10), input_partitions=10 + AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[] + RepartitionExec: partitioning=Hash([a1@0], 10), input_partitions=10 + AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); + let plan_sort = test_config.to_plan(plan, 
&SORT_DISTRIB_DISTRIB); + assert_plan!(plan_distrib, plan_sort); Ok(()) } @@ -1882,17 +1883,18 @@ fn added_repartition_to_single_partition() -> Result<()> { let alias = vec![("a".to_string(), "a".to_string())]; let plan = aggregate_exec_with_alias(parquet_exec(), alias); - let expected = [ - "AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[]", - " RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10", - " AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[]", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - ]; - let test_config = TestConfig::default(); - test_config.run(&expected, plan.clone(), &DISTRIB_DISTRIB_SORT)?; - test_config.run(&expected, plan, &SORT_DISTRIB_DISTRIB)?; + let plan_distrib = test_config.to_plan(plan.clone(), &DISTRIB_DISTRIB_SORT); + assert_plan!(plan_distrib, + @r" + AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] + RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10 + AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); + let plan_sort = test_config.to_plan(plan, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_distrib, plan_sort); Ok(()) } @@ -1902,18 +1904,19 @@ fn repartition_deepest_node() -> Result<()> { let alias = vec![("a".to_string(), "a".to_string())]; let plan = aggregate_exec_with_alias(filter_exec(parquet_exec()), alias); - let expected = &[ - "AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[]", - " RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10", - " AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[]", - " FilterExec: c@2 = 0", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: 
[[x]]}, projection=[a, b, c, d, e], file_type=parquet", - ]; - let test_config = TestConfig::default(); - test_config.run(expected, plan.clone(), &DISTRIB_DISTRIB_SORT)?; - test_config.run(expected, plan, &SORT_DISTRIB_DISTRIB)?; + let plan_distrib = test_config.to_plan(plan.clone(), &DISTRIB_DISTRIB_SORT); + assert_plan!(plan_distrib, + @r" + AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] + RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10 + AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); + let plan_sort = test_config.to_plan(plan, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_distrib, plan_sort); Ok(()) } @@ -1922,19 +1925,20 @@ fn repartition_deepest_node() -> Result<()> { fn repartition_unsorted_limit() -> Result<()> { let plan = limit_exec(filter_exec(parquet_exec())); - let expected = &[ - "GlobalLimitExec: skip=0, fetch=100", - " CoalescePartitionsExec", - " LocalLimitExec: fetch=100", - " FilterExec: c@2 = 0", - // nothing sorts the data, so the local limit doesn't require sorted data either - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - ]; - let test_config = TestConfig::default(); - test_config.run(expected, plan.clone(), &DISTRIB_DISTRIB_SORT)?; - test_config.run(expected, plan, &SORT_DISTRIB_DISTRIB)?; + let plan_distrib = test_config.to_plan(plan.clone(), &DISTRIB_DISTRIB_SORT); + assert_plan!(plan_distrib, + @r" + GlobalLimitExec: skip=0, fetch=100 + CoalescePartitionsExec + LocalLimitExec: fetch=100 + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); + // 
nothing sorts the data, so the local limit doesn't require sorted data either + let plan_sort = test_config.to_plan(plan, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_distrib, plan_sort); Ok(()) } @@ -1942,23 +1946,25 @@ fn repartition_unsorted_limit() -> Result<()> { #[test] fn repartition_sorted_limit() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); - let plan = limit_exec(sort_exec(sort_key, parquet_exec(), false)); - - let expected = &[ - "GlobalLimitExec: skip=0, fetch=100", - " LocalLimitExec: fetch=100", - // data is sorted so can't repartition here - " SortExec: expr=[c@2 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - ]; + }] + .into(); + let plan = limit_exec(sort_exec(sort_key, parquet_exec())); let test_config = TestConfig::default(); - test_config.run(expected, plan.clone(), &DISTRIB_DISTRIB_SORT)?; - test_config.run(expected, plan, &SORT_DISTRIB_DISTRIB)?; + let plan_distrib = test_config.to_plan(plan.clone(), &DISTRIB_DISTRIB_SORT); + assert_plan!(plan_distrib, + @r" + GlobalLimitExec: skip=0, fetch=100 + LocalLimitExec: fetch=100 + SortExec: expr=[c@2 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); + // data is sorted so can't repartition here + let plan_sort = test_config.to_plan(plan, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_distrib, plan_sort); Ok(()) } @@ -1966,28 +1972,30 @@ fn repartition_sorted_limit() -> Result<()> { #[test] fn repartition_sorted_limit_with_filter() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("c", &schema)?, 
options: SortOptions::default(), - }]); + }] + .into(); let plan = sort_required_exec_with_req( - filter_exec(sort_exec(sort_key.clone(), parquet_exec(), false)), + filter_exec(sort_exec(sort_key.clone(), parquet_exec())), sort_key, ); - let expected = &[ - "SortRequiredExec: [c@2 ASC]", - " FilterExec: c@2 = 0", - // We can use repartition here, ordering requirement by SortRequiredExec - // is still satisfied. - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " SortExec: expr=[c@2 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - ]; - let test_config = TestConfig::default(); - test_config.run(expected, plan.clone(), &DISTRIB_DISTRIB_SORT)?; - test_config.run(expected, plan, &SORT_DISTRIB_DISTRIB)?; + let plan_distrib = test_config.to_plan(plan.clone(), &DISTRIB_DISTRIB_SORT); + assert_plan!(plan_distrib, + @r" + SortRequiredExec: [c@2 ASC] + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true + SortExec: expr=[c@2 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); + // We can use repartition here, ordering requirement by SortRequiredExec + // is still satisfied. 
+ let plan_sort = test_config.to_plan(plan, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_distrib, plan_sort); Ok(()) } @@ -2000,26 +2008,28 @@ fn repartition_ignores_limit() -> Result<()> { alias, ); - let expected = &[ - "AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[]", - " RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10", - " AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[]", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " GlobalLimitExec: skip=0, fetch=100", - " CoalescePartitionsExec", - " LocalLimitExec: fetch=100", - " FilterExec: c@2 = 0", - // repartition should happen prior to the filter to maximize parallelism - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " GlobalLimitExec: skip=0, fetch=100", - " LocalLimitExec: fetch=100", - // Expect no repartition to happen for local limit - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - ]; - let test_config = TestConfig::default(); - test_config.run(expected, plan.clone(), &DISTRIB_DISTRIB_SORT)?; - test_config.run(expected, plan, &SORT_DISTRIB_DISTRIB)?; + let plan_distrib = test_config.to_plan(plan.clone(), &DISTRIB_DISTRIB_SORT); + assert_plan!(plan_distrib, + @r" + AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] + RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10 + AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + GlobalLimitExec: skip=0, fetch=100 + CoalescePartitionsExec + LocalLimitExec: fetch=100 + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + GlobalLimitExec: skip=0, fetch=100 + LocalLimitExec: fetch=100 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); + // repartition should happen prior to the filter to maximize parallelism + // Expect no repartition to 
happen for local limit (DataSourceExec) + + let plan_sort = test_config.to_plan(plan, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_distrib, plan_sort); Ok(()) } @@ -2028,19 +2038,20 @@ fn repartition_ignores_limit() -> Result<()> { fn repartition_ignores_union() -> Result<()> { let plan = union_exec(vec![parquet_exec(); 5]); - let expected = &[ - "UnionExec", - // Expect no repartition of DataSourceExec - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - ]; - let test_config = TestConfig::default(); - test_config.run(expected, plan.clone(), &DISTRIB_DISTRIB_SORT)?; - test_config.run(expected, plan, &SORT_DISTRIB_DISTRIB)?; + let plan_distrib = test_config.to_plan(plan.clone(), &DISTRIB_DISTRIB_SORT); + assert_plan!(plan_distrib, + @r" + UnionExec + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); + // Expect no repartition of DataSourceExec + let plan_sort = test_config.to_plan(plan, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_distrib, plan_sort); Ok(()) } @@ -2049,21 +2060,22 @@ fn repartition_ignores_union() -> Result<()> { fn repartition_through_sort_preserving_merge() -> Result<()> { // sort preserving merge with 
non-sorted input let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let plan = sort_preserving_merge_exec(sort_key, parquet_exec()); - // need resort as the data was not sorted correctly - let expected = &[ - "SortExec: expr=[c@2 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - ]; - let test_config = TestConfig::default(); - test_config.run(expected, plan.clone(), &DISTRIB_DISTRIB_SORT)?; - test_config.run(expected, plan, &SORT_DISTRIB_DISTRIB)?; + let plan_distrib = test_config.to_plan(plan.clone(), &DISTRIB_DISTRIB_SORT); + assert_plan!(plan_distrib, + @r" + SortExec: expr=[c@2 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); + let plan_sort = test_config.to_plan(plan, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_distrib, plan_sort); Ok(()) } @@ -2072,33 +2084,35 @@ fn repartition_through_sort_preserving_merge() -> Result<()> { fn repartition_ignores_sort_preserving_merge() -> Result<()> { // sort preserving merge already sorted input, let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let plan = sort_preserving_merge_exec( sort_key.clone(), parquet_exec_multiple_sorted(vec![sort_key]), ); + let test_config = TestConfig::default(); + let plan_distrib = test_config.to_plan(plan.clone(), &DISTRIB_DISTRIB_SORT); // Test: run EnforceDistribution, then EnforceSort - // + assert_plan!(plan_distrib, + @r" + SortPreservingMergeExec: [c@2 ASC] + DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, 
e], output_ordering=[c@2 ASC], file_type=parquet + "); // should not sort (as the data was already sorted) // should not repartition, since increased parallelism is not beneficial for SortPReservingMerge - let expected = &[ - "SortPreservingMergeExec: [c@2 ASC]", - " DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet", - ]; - let test_config = TestConfig::default(); - test_config.run(expected, plan.clone(), &DISTRIB_DISTRIB_SORT)?; // Test: result IS DIFFERENT, if EnforceSorting is run first: - let expected_first_sort_enforcement = &[ - "SortExec: expr=[c@2 ASC], preserve_partitioning=[false]", - " CoalescePartitionsExec", - " DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet", - ]; - test_config.run(expected_first_sort_enforcement, plan, &SORT_DISTRIB_DISTRIB)?; + let plan_sort = test_config.to_plan(plan, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_sort, + @r" + SortExec: expr=[c@2 ASC], preserve_partitioning=[false] + CoalescePartitionsExec + DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet + "); Ok(()) } @@ -2107,34 +2121,40 @@ fn repartition_ignores_sort_preserving_merge() -> Result<()> { fn repartition_ignores_sort_preserving_merge_with_union() -> Result<()> { // 2 sorted parquet files unioned (partitions are concatenated, sort is preserved) let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); - let input = union_exec(vec![parquet_exec_with_sort(vec![sort_key.clone()]); 2]); + }] + .into(); + let input = union_exec(vec![ + parquet_exec_with_sort(schema, vec![sort_key.clone()]); + 2 + ]); let plan = sort_preserving_merge_exec(sort_key, input); + let test_config 
= TestConfig::default(); + let plan_distrib = test_config.to_plan(plan.clone(), &DISTRIB_DISTRIB_SORT); // Test: run EnforceDistribution, then EnforceSort. + assert_plan!(plan_distrib, + @r" + SortPreservingMergeExec: [c@2 ASC] + UnionExec + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet + "); // // should not repartition / sort (as the data was already sorted) - let expected = &[ - "SortPreservingMergeExec: [c@2 ASC]", - " UnionExec", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet", - ]; - let test_config = TestConfig::default(); - test_config.run(expected, plan.clone(), &DISTRIB_DISTRIB_SORT)?; // test: result IS DIFFERENT, if EnforceSorting is run first: - let expected_first_sort_enforcement = &[ - "SortExec: expr=[c@2 ASC], preserve_partitioning=[false]", - " CoalescePartitionsExec", - " UnionExec", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet", - ]; - test_config.run(expected_first_sort_enforcement, plan, &SORT_DISTRIB_DISTRIB)?; + let plan_sort = test_config.to_plan(plan, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_sort, + @r" + SortExec: expr=[c@2 ASC], preserve_partitioning=[false] + CoalescePartitionsExec + UnionExec + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet + "); Ok(()) } 
@@ -2145,28 +2165,30 @@ fn repartition_does_not_destroy_sort() -> Result<()> { // SortRequired // Parquet(sorted) let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("d", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("d", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let plan = sort_required_exec_with_req( - filter_exec(parquet_exec_with_sort(vec![sort_key.clone()])), + filter_exec(parquet_exec_with_sort(schema, vec![sort_key.clone()])), sort_key, ); // TestConfig: Prefer existing sort. let test_config = TestConfig::default().with_prefer_existing_sort(); + let plan_distrib = test_config.to_plan(plan.clone(), &DISTRIB_DISTRIB_SORT); + assert_plan!(plan_distrib, + @r" + SortRequiredExec: [d@3 ASC] + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[d@3 ASC], file_type=parquet + "); // during repartitioning ordering is preserved - let expected = &[ - "SortRequiredExec: [d@3 ASC]", - " FilterExec: c@2 = 0", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[d@3 ASC], file_type=parquet", - ]; - - test_config.run(expected, plan.clone(), &DISTRIB_DISTRIB_SORT)?; - test_config.run(expected, plan, &SORT_DISTRIB_DISTRIB)?; + let plan_sort = test_config.to_plan(plan, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_distrib, plan_sort); Ok(()) } @@ -2183,33 +2205,37 @@ fn repartition_does_not_destroy_sort_more_complex() -> Result<()> { // Parquet(unsorted) let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let 
input1 = sort_required_exec_with_req( - parquet_exec_with_sort(vec![sort_key.clone()]), + parquet_exec_with_sort(schema, vec![sort_key.clone()]), sort_key, ); let input2 = filter_exec(parquet_exec()); let plan = union_exec(vec![input1, input2]); + let test_config = TestConfig::default(); + let plan_distrib = test_config.to_plan(plan.clone(), &DISTRIB_DISTRIB_SORT); + assert_plan!(plan_distrib, + @r" + UnionExec + SortRequiredExec: [c@2 ASC] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); + // union input 1: no repartitioning + // union input 2: should repartition + // // should not repartition below the SortRequired as that // branch doesn't benefit from increased parallelism - let expected = &[ - "UnionExec", - // union input 1: no repartitioning - " SortRequiredExec: [c@2 ASC]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet", - // union input 2: should repartition - " FilterExec: c@2 = 0", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - ]; - let test_config = TestConfig::default(); - test_config.run(expected, plan.clone(), &DISTRIB_DISTRIB_SORT)?; - test_config.run(expected, plan, &SORT_DISTRIB_DISTRIB)?; + let plan_sort = test_config.to_plan(plan, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_distrib, plan_sort); Ok(()) } @@ -2217,44 +2243,45 @@ fn repartition_does_not_destroy_sort_more_complex() -> Result<()> { #[test] fn repartition_transitively_with_projection() -> Result<()> { let schema = schema(); - let proj_exprs = vec![( - Arc::new(BinaryExpr::new( - col("a", &schema).unwrap(), + let 
proj_exprs = vec![ProjectionExpr { + expr: Arc::new(BinaryExpr::new( + col("a", &schema)?, Operator::Plus, - col("b", &schema).unwrap(), - )) as Arc, - "sum".to_string(), - )]; + col("b", &schema)?, + )) as _, + alias: "sum".to_string(), + }]; // non sorted input let proj = Arc::new(ProjectionExec::try_new(proj_exprs, parquet_exec())?); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("sum", &proj.schema()).unwrap(), + let sort_key = [PhysicalSortExpr { + expr: col("sum", &proj.schema())?, options: SortOptions::default(), - }]); + }] + .into(); let plan = sort_preserving_merge_exec(sort_key, proj); - // Test: run EnforceDistribution, then EnforceSort. - let expected = &[ - "SortPreservingMergeExec: [sum@0 ASC]", - " SortExec: expr=[sum@0 ASC], preserve_partitioning=[true]", - // Since this projection is not trivial, increasing parallelism is beneficial - " ProjectionExec: expr=[a@0 + b@1 as sum]", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - ]; let test_config = TestConfig::default(); - test_config.run(expected, plan.clone(), &DISTRIB_DISTRIB_SORT)?; + let plan_distrib = test_config.to_plan(plan.clone(), &DISTRIB_DISTRIB_SORT); + assert_plan!(plan_distrib, + @r" + SortPreservingMergeExec: [sum@0 ASC] + SortExec: expr=[sum@0 ASC], preserve_partitioning=[true] + ProjectionExec: expr=[a@0 + b@1 as sum] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); // Test: result IS DIFFERENT, if EnforceSorting is run first: - let expected_first_sort_enforcement = &[ - "SortExec: expr=[sum@0 ASC], preserve_partitioning=[false]", - " CoalescePartitionsExec", - // Since this projection is not trivial, increasing parallelism is beneficial - " ProjectionExec: expr=[a@0 + b@1 as sum]", - " RepartitionExec: 
partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - ]; - test_config.run(expected_first_sort_enforcement, plan, &SORT_DISTRIB_DISTRIB)?; + let plan_sort = test_config.to_plan(plan, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_sort, + @r" + SortExec: expr=[sum@0 ASC], preserve_partitioning=[false] + CoalescePartitionsExec + ProjectionExec: expr=[a@0 + b@1 as sum] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); + // Since this projection is not trivial, increasing parallelism is beneficial Ok(()) } @@ -2262,10 +2289,11 @@ fn repartition_transitively_with_projection() -> Result<()> { #[test] fn repartition_ignores_transitively_with_projection() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let alias = vec![ ("a".to_string(), "a".to_string()), ("b".to_string(), "b".to_string()), @@ -2280,16 +2308,18 @@ fn repartition_ignores_transitively_with_projection() -> Result<()> { sort_key, ); - let expected = &[ - "SortRequiredExec: [c@2 ASC]", - // Since this projection is trivial, increasing parallelism is not beneficial - " ProjectionExec: expr=[a@0 as a, b@1 as b, c@2 as c]", - " DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet", - ]; - let test_config = TestConfig::default(); - test_config.run(expected, plan.clone(), &DISTRIB_DISTRIB_SORT)?; - test_config.run(expected, plan, &SORT_DISTRIB_DISTRIB)?; + let plan_distrib = test_config.to_plan(plan.clone(), &DISTRIB_DISTRIB_SORT); + assert_plan!(plan_distrib, + @r" + SortRequiredExec: [c@2 ASC] + ProjectionExec: 
expr=[a@0 as a, b@1 as b, c@2 as c] + DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet + "); + // Since this projection is trivial, increasing parallelism is not beneficial + + let plan_sort = test_config.to_plan(plan, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_distrib, plan_sort); Ok(()) } @@ -2297,10 +2327,11 @@ fn repartition_ignores_transitively_with_projection() -> Result<()> { #[test] fn repartition_transitively_past_sort_with_projection() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let alias = vec![ ("a".to_string(), "a".to_string()), ("b".to_string(), "b".to_string()), @@ -2308,23 +2339,23 @@ fn repartition_transitively_past_sort_with_projection() -> Result<()> { ]; let plan = sort_preserving_merge_exec( sort_key.clone(), - sort_exec( + sort_exec_with_preserve_partitioning( sort_key, projection_exec_with_alias(parquet_exec(), alias), - true, ), ); - let expected = &[ - "SortExec: expr=[c@2 ASC], preserve_partitioning=[false]", - // Since this projection is trivial, increasing parallelism is not beneficial - " ProjectionExec: expr=[a@0 as a, b@1 as b, c@2 as c]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - ]; - let test_config = TestConfig::default(); - test_config.run(expected, plan.clone(), &DISTRIB_DISTRIB_SORT)?; - test_config.run(expected, plan, &SORT_DISTRIB_DISTRIB)?; + let plan_distrib = test_config.to_plan(plan.clone(), &DISTRIB_DISTRIB_SORT); + assert_plan!(plan_distrib, + @r" + SortExec: expr=[c@2 ASC], preserve_partitioning=[false] + ProjectionExec: expr=[a@0 as a, b@1 as b, c@2 as c] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); + // Since 
this projection is trivial, increasing parallelism is not beneficial + let plan_sort = test_config.to_plan(plan, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_distrib, plan_sort); Ok(()) } @@ -2332,34 +2363,37 @@ fn repartition_transitively_past_sort_with_projection() -> Result<()> { #[test] fn repartition_transitively_past_sort_with_filter() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("a", &schema).unwrap(), + let sort_key = [PhysicalSortExpr { + expr: col("a", &schema)?, options: SortOptions::default(), - }]); - let plan = sort_exec(sort_key, filter_exec(parquet_exec()), false); + }] + .into(); + let plan = sort_exec(sort_key, filter_exec(parquet_exec())); - // Test: run EnforceDistribution, then EnforceSort. - let expected = &[ - "SortPreservingMergeExec: [a@0 ASC]", - " SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", - // Expect repartition on the input to the sort (as it can benefit from additional parallelism) - " FilterExec: c@2 = 0", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - ]; let test_config = TestConfig::default(); - test_config.run(expected, plan.clone(), &DISTRIB_DISTRIB_SORT)?; + let plan_distrib = test_config.to_plan(plan.clone(), &DISTRIB_DISTRIB_SORT); + assert_plan!(plan_distrib, + @r" + SortPreservingMergeExec: [a@0 ASC] + SortExec: expr=[a@0 ASC], preserve_partitioning=[true] + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); + + // Expect repartition on the input to the sort (as it can benefit from additional parallelism) // Test: result IS DIFFERENT, if EnforceSorting is run first: - let expected_first_sort_enforcement = &[ - "SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", - " 
CoalescePartitionsExec", - " FilterExec: c@2 = 0", - // Expect repartition on the input of the filter (as it can benefit from additional parallelism) - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - ]; - test_config.run(expected_first_sort_enforcement, plan, &SORT_DISTRIB_DISTRIB)?; + let plan_sort = test_config.to_plan(plan, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_sort, + @r" + SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + CoalescePartitionsExec + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); + // Expect repartition on the input of the filter (as it can benefit from additional parallelism) Ok(()) } @@ -2368,10 +2402,11 @@ fn repartition_transitively_past_sort_with_filter() -> Result<()> { #[cfg(feature = "parquet")] fn repartition_transitively_past_sort_with_projection_and_filter() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("a", &schema).unwrap(), + let sort_key = [PhysicalSortExpr { + expr: col("a", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let plan = sort_exec( sort_key, projection_exec_with_alias( @@ -2382,33 +2417,34 @@ fn repartition_transitively_past_sort_with_projection_and_filter() -> Result<()> ("c".to_string(), "c".to_string()), ], ), - false, ); - // Test: run EnforceDistribution, then EnforceSort. 
- let expected = &[ - "SortPreservingMergeExec: [a@0 ASC]", - // Expect repartition on the input to the sort (as it can benefit from additional parallelism) - " SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", - " ProjectionExec: expr=[a@0 as a, b@1 as b, c@2 as c]", - " FilterExec: c@2 = 0", - // repartition is lowest down - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - ]; let test_config = TestConfig::default(); - test_config.run(expected, plan.clone(), &DISTRIB_DISTRIB_SORT)?; + let plan_distrib = test_config.to_plan(plan.clone(), &DISTRIB_DISTRIB_SORT); + assert_plan!(plan_distrib, + @r" + SortPreservingMergeExec: [a@0 ASC] + SortExec: expr=[a@0 ASC], preserve_partitioning=[true] + ProjectionExec: expr=[a@0 as a, b@1 as b, c@2 as c] + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); + + // Expect repartition on the input to the sort (as it can benefit from additional parallelism) + // repartition is lowest down // Test: result IS DIFFERENT, if EnforceSorting is run first: - let expected_first_sort_enforcement = &[ - "SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", - " CoalescePartitionsExec", - " ProjectionExec: expr=[a@0 as a, b@1 as b, c@2 as c]", - " FilterExec: c@2 = 0", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - ]; - test_config.run(expected_first_sort_enforcement, plan, &SORT_DISTRIB_DISTRIB)?; + let plan_sort = test_config.to_plan(plan, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_sort, + @r" + SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + CoalescePartitionsExec + ProjectionExec: expr=[a@0 as a, b@1 as b, c@2 as c] + FilterExec: 
c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); Ok(()) } @@ -2424,28 +2460,29 @@ fn parallelization_single_partition() -> Result<()> { .with_query_execution_partitions(2); // Test: with parquet - let expected_parquet = [ - "AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[]", - " RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2", - " AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[]", - " DataSourceExec: file_groups={2 groups: [[x:0..50], [x:50..100]]}, projection=[a, b, c, d, e], file_type=parquet", - ]; - test_config.run( - &expected_parquet, - plan_parquet.clone(), - &DISTRIB_DISTRIB_SORT, - )?; - test_config.run(&expected_parquet, plan_parquet, &SORT_DISTRIB_DISTRIB)?; + let plan_parquet_distrib = + test_config.to_plan(plan_parquet.clone(), &DISTRIB_DISTRIB_SORT); + assert_plan!(plan_parquet_distrib, + @r" + AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] + RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2 + AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] + DataSourceExec: file_groups={2 groups: [[x:0..50], [x:50..100]]}, projection=[a, b, c, d, e], file_type=parquet + "); + let plan_parquet_sort = test_config.to_plan(plan_parquet, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_parquet_distrib, plan_parquet_sort); // Test: with csv - let expected_csv = [ - "AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[]", - " RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2", - " AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[]", - " DataSourceExec: file_groups={2 groups: [[x:0..50], [x:50..100]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", - ]; - test_config.run(&expected_csv, plan_csv.clone(), &DISTRIB_DISTRIB_SORT)?; - test_config.run(&expected_csv, plan_csv, &SORT_DISTRIB_DISTRIB)?; + let plan_csv_distrib = 
test_config.to_plan(plan_csv.clone(), &DISTRIB_DISTRIB_SORT); + assert_plan!(plan_csv_distrib, + @r" + AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] + RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2 + AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] + DataSourceExec: file_groups={2 groups: [[x:0..50], [x:50..100]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + "); + let plan_csv_sort = test_config.to_plan(plan_csv, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_csv_distrib, plan_csv_sort); Ok(()) } @@ -2453,10 +2490,11 @@ fn parallelization_single_partition() -> Result<()> { #[test] fn parallelization_multiple_files() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("a", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("a", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let plan = filter_exec(parquet_exec_multiple_sorted(vec![sort_key.clone()])); let plan = sort_required_exec_with_req(plan, sort_key); @@ -2468,40 +2506,31 @@ fn parallelization_multiple_files() -> Result<()> { // The groups must have only contiguous ranges of rows from the same file // if any group has rows from multiple files, the data is no longer sorted destroyed // https://github.com/apache/datafusion/issues/8451 - let expected_with_3_target_partitions = [ - "SortRequiredExec: [a@0 ASC]", - " FilterExec: c@2 = 0", - " DataSourceExec: file_groups={3 groups: [[x:0..50], [y:0..100], [x:50..100]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet", - ]; let test_config_concurrency_3 = test_config.clone().with_query_execution_partitions(3); - test_config_concurrency_3.run( - &expected_with_3_target_partitions, - plan.clone(), - &DISTRIB_DISTRIB_SORT, - )?; - test_config_concurrency_3.run( - &expected_with_3_target_partitions, - plan.clone(), - &SORT_DISTRIB_DISTRIB, - )?; + let plan_3_distrib = + 
test_config_concurrency_3.to_plan(plan.clone(), &DISTRIB_DISTRIB_SORT); + assert_plan!(plan_3_distrib, + @r" + SortRequiredExec: [a@0 ASC] + FilterExec: c@2 = 0 + DataSourceExec: file_groups={3 groups: [[x:0..50], [y:0..100], [x:50..100]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + "); + let plan_3_sort = + test_config_concurrency_3.to_plan(plan.clone(), &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_3_distrib, plan_3_sort); - let expected_with_8_target_partitions = [ - "SortRequiredExec: [a@0 ASC]", - " FilterExec: c@2 = 0", - " DataSourceExec: file_groups={8 groups: [[x:0..25], [y:0..25], [x:25..50], [y:25..50], [x:50..75], [y:50..75], [x:75..100], [y:75..100]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet", - ]; let test_config_concurrency_8 = test_config.with_query_execution_partitions(8); - test_config_concurrency_8.run( - &expected_with_8_target_partitions, - plan.clone(), - &DISTRIB_DISTRIB_SORT, - )?; - test_config_concurrency_8.run( - &expected_with_8_target_partitions, - plan, - &SORT_DISTRIB_DISTRIB, - )?; + let plan_8_distrib = + test_config_concurrency_8.to_plan(plan.clone(), &DISTRIB_DISTRIB_SORT); + assert_plan!(plan_8_distrib, + @r" + SortRequiredExec: [a@0 ASC] + FilterExec: c@2 = 0 + DataSourceExec: file_groups={8 groups: [[x:0..25], [y:0..25], [x:25..50], [y:25..50], [x:50..75], [y:50..75], [x:75..100], [y:75..100]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + "); + let plan_8_sort = test_config_concurrency_8.to_plan(plan, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_8_distrib, plan_8_sort); Ok(()) } @@ -2518,46 +2547,55 @@ fn parallelization_compressed_csv() -> Result<()> { FileCompressionType::UNCOMPRESSED, ]; - let expected_not_partitioned = [ - "AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[]", - " RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2", - " AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[]", - " 
RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", - ]; - - let expected_partitioned = [ - "AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[]", - " RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2", - " AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[]", - " DataSourceExec: file_groups={2 groups: [[x:0..50], [x:50..100]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", - ]; + #[rustfmt::skip] + insta::allow_duplicates! { + for compression_type in compression_types { + let plan = aggregate_exec_with_alias( + DataSourceExec::from_data_source( + FileScanConfigBuilder::new(ObjectStoreUrl::parse("test:///").unwrap(), { + let options = CsvOptions { + has_header: Some(false), + delimiter: b',', + quote: b'"', + ..Default::default() + }; + Arc::new(CsvSource::new(schema()).with_csv_options(options)) + }) + .with_file(PartitionedFile::new("x".to_string(), 100)) + .with_file_compression_type(compression_type) + .build(), + ), + vec![("a".to_string(), "a".to_string())], + ); + let test_config = TestConfig::default() + .with_query_execution_partitions(2) + .with_prefer_repartition_file_scans(10); + + let plan_distrib = test_config.to_plan(plan.clone(), &DISTRIB_DISTRIB_SORT); + if compression_type.is_compressed() { + // Compressed files cannot be partitioned + assert_plan!(plan_distrib, + @r" + AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] + RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2 + AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] + RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + "); + } else { + // Uncompressed files can be partitioned + assert_plan!(plan_distrib, + @r" + AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], 
aggr=[] + RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2 + AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] + DataSourceExec: file_groups={2 groups: [[x:0..50], [x:50..100]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + "); + } - for compression_type in compression_types { - let expected = if compression_type.is_compressed() { - &expected_not_partitioned[..] - } else { - &expected_partitioned[..] - }; - - let plan = aggregate_exec_with_alias( - DataSourceExec::from_data_source( - FileScanConfigBuilder::new( - ObjectStoreUrl::parse("test:///").unwrap(), - schema(), - Arc::new(CsvSource::new(false, b',', b'"')), - ) - .with_file(PartitionedFile::new("x".to_string(), 100)) - .with_file_compression_type(compression_type) - .build(), - ), - vec![("a".to_string(), "a".to_string())], - ); - let test_config = TestConfig::default() - .with_query_execution_partitions(2) - .with_prefer_repartition_file_scans(10); - test_config.run(expected, plan.clone(), &DISTRIB_DISTRIB_SORT)?; - test_config.run(expected, plan, &SORT_DISTRIB_DISTRIB)?; + let plan_sort = test_config.to_plan(plan, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_distrib, plan_sort); + } } Ok(()) } @@ -2573,30 +2611,30 @@ fn parallelization_two_partitions() -> Result<()> { .with_prefer_repartition_file_scans(10); // Test: with parquet - let expected_parquet = [ - "AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[]", - " RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2", - " AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[]", - // Plan already has two partitions - " DataSourceExec: file_groups={2 groups: [[x:0..100], [y:0..100]]}, projection=[a, b, c, d, e], file_type=parquet", - ]; - test_config.run( - &expected_parquet, - plan_parquet.clone(), - &DISTRIB_DISTRIB_SORT, - )?; - test_config.run(&expected_parquet, plan_parquet, &SORT_DISTRIB_DISTRIB)?; + let plan_parquet_distrib = + test_config.to_plan(plan_parquet.clone(), &DISTRIB_DISTRIB_SORT); + 
assert_plan!(plan_parquet_distrib, + @r" + AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] + RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2 + AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] + DataSourceExec: file_groups={2 groups: [[x:0..100], [y:0..100]]}, projection=[a, b, c, d, e], file_type=parquet + "); + // Plan already has two partitions + let plan_parquet_sort = test_config.to_plan(plan_parquet, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_parquet_distrib, plan_parquet_sort); // Test: with csv - let expected_csv = [ - "AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[]", - " RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2", - " AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[]", - // Plan already has two partitions - " DataSourceExec: file_groups={2 groups: [[x:0..100], [y:0..100]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", - ]; - test_config.run(&expected_csv, plan_csv.clone(), &DISTRIB_DISTRIB_SORT)?; - test_config.run(&expected_csv, plan_csv, &SORT_DISTRIB_DISTRIB)?; + let plan_csv_distrib = test_config.to_plan(plan_csv.clone(), &DISTRIB_DISTRIB_SORT); + assert_plan!(plan_csv_distrib, @r" + AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] + RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2 + AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] + DataSourceExec: file_groups={2 groups: [[x:0..100], [y:0..100]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + "); + // Plan already has two partitions + let plan_csv_sort = test_config.to_plan(plan_csv, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_csv_distrib, plan_csv_sort); Ok(()) } @@ -2612,30 +2650,32 @@ fn parallelization_two_partitions_into_four() -> Result<()> { .with_prefer_repartition_file_scans(10); // Test: with parquet - let expected_parquet = [ - "AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[]", - " RepartitionExec: partitioning=Hash([a@0], 4), 
input_partitions=4", - " AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[]", - // Multiple source files splitted across partitions - " DataSourceExec: file_groups={4 groups: [[x:0..50], [x:50..100], [y:0..50], [y:50..100]]}, projection=[a, b, c, d, e], file_type=parquet", - ]; - test_config.run( - &expected_parquet, - plan_parquet.clone(), - &DISTRIB_DISTRIB_SORT, - )?; - test_config.run(&expected_parquet, plan_parquet, &SORT_DISTRIB_DISTRIB)?; + let plan_parquet_distrib = + test_config.to_plan(plan_parquet.clone(), &DISTRIB_DISTRIB_SORT); + // Multiple source files split across partitions + assert_plan!(plan_parquet_distrib, + @r" + AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] + RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=4 + AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] + DataSourceExec: file_groups={4 groups: [[x:0..50], [x:50..100], [y:0..50], [y:50..100]]}, projection=[a, b, c, d, e], file_type=parquet + "); + // Multiple source files split across partitions + let plan_parquet_sort = test_config.to_plan(plan_parquet, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_parquet_distrib, plan_parquet_sort); // Test: with csv - let expected_csv = [ - "AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[]", - " RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=4", - " AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[]", - // Multiple source files splitted across partitions - " DataSourceExec: file_groups={4 groups: [[x:0..50], [x:50..100], [y:0..50], [y:50..100]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", - ]; - test_config.run(&expected_csv, plan_csv.clone(), &DISTRIB_DISTRIB_SORT)?; - test_config.run(&expected_csv, plan_csv, &SORT_DISTRIB_DISTRIB)?; + let plan_csv_distrib = test_config.to_plan(plan_csv.clone(), &DISTRIB_DISTRIB_SORT); + // Multiple source files split across partitions + assert_plan!(plan_csv_distrib, @r" + AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], 
aggr=[] + RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=4 + AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] + DataSourceExec: file_groups={4 groups: [[x:0..50], [x:50..100], [y:0..50], [y:50..100]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + "); + // Multiple source files split across partitions + let plan_csv_sort = test_config.to_plan(plan_csv, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_csv_distrib, plan_csv_sort); Ok(()) } @@ -2643,42 +2683,43 @@ fn parallelization_two_partitions_into_four() -> Result<()> { #[test] fn parallelization_sorted_limit() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); - let plan_parquet = limit_exec(sort_exec(sort_key.clone(), parquet_exec(), false)); - let plan_csv = limit_exec(sort_exec(sort_key, csv_exec(), false)); + }] + .into(); + let plan_parquet = limit_exec(sort_exec(sort_key.clone(), parquet_exec())); + let plan_csv = limit_exec(sort_exec(sort_key, csv_exec())); let test_config = TestConfig::default(); // Test: with parquet - let expected_parquet = &[ - "GlobalLimitExec: skip=0, fetch=100", - " LocalLimitExec: fetch=100", - // data is sorted so can't repartition here - " SortExec: expr=[c@2 ASC], preserve_partitioning=[false]", - // Doesn't parallelize for SortExec without preserve_partitioning - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - ]; - test_config.run( - expected_parquet, - plan_parquet.clone(), - &DISTRIB_DISTRIB_SORT, - )?; - test_config.run(expected_parquet, plan_parquet, &SORT_DISTRIB_DISTRIB)?; + let plan_parquet_distrib = + test_config.to_plan(plan_parquet.clone(), &DISTRIB_DISTRIB_SORT); + assert_plan!(plan_parquet_distrib, @r" + GlobalLimitExec: skip=0, fetch=100 + LocalLimitExec: fetch=100 + SortExec: expr=[c@2 
ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); + // data is sorted so can't repartition here + // Doesn't parallelize for SortExec without preserve_partitioning + let plan_parquet_sort = test_config.to_plan(plan_parquet, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_parquet_distrib, plan_parquet_sort); // Test: with csv - let expected_csv = &[ - "GlobalLimitExec: skip=0, fetch=100", - " LocalLimitExec: fetch=100", - // data is sorted so can't repartition here - " SortExec: expr=[c@2 ASC], preserve_partitioning=[false]", - // Doesn't parallelize for SortExec without preserve_partitioning - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", - ]; - test_config.run(expected_csv, plan_csv.clone(), &DISTRIB_DISTRIB_SORT)?; - test_config.run(expected_csv, plan_csv, &SORT_DISTRIB_DISTRIB)?; + let plan_csv_distrib = test_config.to_plan(plan_csv.clone(), &DISTRIB_DISTRIB_SORT); + assert_plan!(plan_csv_distrib, + @r" + GlobalLimitExec: skip=0, fetch=100 + LocalLimitExec: fetch=100 + SortExec: expr=[c@2 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + "); + // data is sorted so can't repartition here + // Doesn't parallelize for SortExec without preserve_partitioning + let plan_csv_sort = test_config.to_plan(plan_csv, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_csv_distrib, plan_csv_sort); Ok(()) } @@ -2686,54 +2727,53 @@ fn parallelization_sorted_limit() -> Result<()> { #[test] fn parallelization_limit_with_filter() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); - let plan_parquet = limit_exec(filter_exec(sort_exec( - sort_key.clone(), - 
parquet_exec(), - false, - ))); - let plan_csv = limit_exec(filter_exec(sort_exec(sort_key, csv_exec(), false))); + }] + .into(); + let plan_parquet = + limit_exec(filter_exec(sort_exec(sort_key.clone(), parquet_exec()))); + let plan_csv = limit_exec(filter_exec(sort_exec(sort_key, csv_exec()))); let test_config = TestConfig::default(); // Test: with parquet - let expected_parquet = &[ - "GlobalLimitExec: skip=0, fetch=100", - " CoalescePartitionsExec", - " LocalLimitExec: fetch=100", - " FilterExec: c@2 = 0", - // even though data is sorted, we can use repartition here. Since - // ordering is not used in subsequent stages anyway. - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " SortExec: expr=[c@2 ASC], preserve_partitioning=[false]", - // SortExec doesn't benefit from input partitioning - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - ]; - test_config.run( - expected_parquet, - plan_parquet.clone(), - &DISTRIB_DISTRIB_SORT, - )?; - test_config.run(expected_parquet, plan_parquet, &SORT_DISTRIB_DISTRIB)?; + let plan_parquet_distrib = + test_config.to_plan(plan_parquet.clone(), &DISTRIB_DISTRIB_SORT); + // even though data is sorted, we can use repartition here. Since + // ordering is not used in subsequent stages anyway. 
+ // SortExec doesn't benefit from input partitioning + assert_plan!(plan_parquet_distrib, + @r" + GlobalLimitExec: skip=0, fetch=100 + CoalescePartitionsExec + LocalLimitExec: fetch=100 + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true + SortExec: expr=[c@2 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); + let plan_parquet_sort = test_config.to_plan(plan_parquet, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_parquet_distrib, plan_parquet_sort); // Test: with csv - let expected_csv = &[ - "GlobalLimitExec: skip=0, fetch=100", - " CoalescePartitionsExec", - " LocalLimitExec: fetch=100", - " FilterExec: c@2 = 0", - // even though data is sorted, we can use repartition here. Since - // ordering is not used in subsequent stages anyway. - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " SortExec: expr=[c@2 ASC], preserve_partitioning=[false]", - // SortExec doesn't benefit from input partitioning - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", - ]; - test_config.run(expected_csv, plan_csv.clone(), &DISTRIB_DISTRIB_SORT)?; - test_config.run(expected_csv, plan_csv, &SORT_DISTRIB_DISTRIB)?; + let plan_csv_distrib = test_config.to_plan(plan_csv.clone(), &DISTRIB_DISTRIB_SORT); + // even though data is sorted, we can use repartition here. Since + // ordering is not used in subsequent stages anyway. 
+ // SortExec doesn't benefit from input partitioning + assert_plan!(plan_csv_distrib, + @r" + GlobalLimitExec: skip=0, fetch=100 + CoalescePartitionsExec + LocalLimitExec: fetch=100 + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true + SortExec: expr=[c@2 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + "); + let plan_csv_sort = test_config.to_plan(plan_csv, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_csv_distrib, plan_csv_sort); Ok(()) } @@ -2751,48 +2791,49 @@ fn parallelization_ignores_limit() -> Result<()> { let test_config = TestConfig::default(); // Test: with parquet - let expected_parquet = &[ - "AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[]", - " RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10", - " AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[]", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " GlobalLimitExec: skip=0, fetch=100", - " CoalescePartitionsExec", - " LocalLimitExec: fetch=100", - " FilterExec: c@2 = 0", - // repartition should happen prior to the filter to maximize parallelism - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " GlobalLimitExec: skip=0, fetch=100", - // Limit doesn't benefit from input partitioning - no parallelism - " LocalLimitExec: fetch=100", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - ]; - test_config.run( - expected_parquet, - plan_parquet.clone(), - &DISTRIB_DISTRIB_SORT, - )?; - test_config.run(expected_parquet, plan_parquet, &SORT_DISTRIB_DISTRIB)?; + let plan_parquet_distrib = + test_config.to_plan(plan_parquet.clone(), &DISTRIB_DISTRIB_SORT); + assert_plan!(plan_parquet_distrib, + @r" + AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] + RepartitionExec: partitioning=Hash([a@0], 10), 
input_partitions=10 + AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + GlobalLimitExec: skip=0, fetch=100 + CoalescePartitionsExec + LocalLimitExec: fetch=100 + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + GlobalLimitExec: skip=0, fetch=100 + LocalLimitExec: fetch=100 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); + // repartition should happen prior to the filter to maximize parallelism + // Limit doesn't benefit from input partitioning - no parallelism + let plan_parquet_sort = test_config.to_plan(plan_parquet, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_parquet_distrib, plan_parquet_sort); // Test: with csv - let expected_csv = &[ - "AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[]", - " RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10", - " AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[]", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " GlobalLimitExec: skip=0, fetch=100", - " CoalescePartitionsExec", - " LocalLimitExec: fetch=100", - " FilterExec: c@2 = 0", - // repartition should happen prior to the filter to maximize parallelism - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " GlobalLimitExec: skip=0, fetch=100", - // Limit doesn't benefit from input partitioning - no parallelism - " LocalLimitExec: fetch=100", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", - ]; - test_config.run(expected_csv, plan_csv.clone(), &DISTRIB_DISTRIB_SORT)?; - test_config.run(expected_csv, plan_csv, &SORT_DISTRIB_DISTRIB)?; + let plan_csv_distrib = test_config.to_plan(plan_csv.clone(), &DISTRIB_DISTRIB_SORT); + assert_plan!(plan_csv_distrib, + @r" + AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] + RepartitionExec: 
partitioning=Hash([a@0], 10), input_partitions=10 + AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + GlobalLimitExec: skip=0, fetch=100 + CoalescePartitionsExec + LocalLimitExec: fetch=100 + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + GlobalLimitExec: skip=0, fetch=100 + LocalLimitExec: fetch=100 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + "); + // repartition should happen prior to the filter to maximize parallelism + // Limit doesn't benefit from input partitioning - no parallelism + let plan_csv_sort = test_config.to_plan(plan_csv, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_csv_distrib, plan_csv_sort); Ok(()) } @@ -2805,34 +2846,35 @@ fn parallelization_union_inputs() -> Result<()> { let test_config = TestConfig::default(); // Test: with parquet - let expected_parquet = &[ - "UnionExec", - // Union doesn't benefit from input partitioning - no parallelism - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - ]; - test_config.run( - expected_parquet, - plan_parquet.clone(), - &DISTRIB_DISTRIB_SORT, - )?; - test_config.run(expected_parquet, plan_parquet, &SORT_DISTRIB_DISTRIB)?; + let plan_parquet_distrib = + test_config.to_plan(plan_parquet.clone(), &DISTRIB_DISTRIB_SORT); + assert_plan!(plan_parquet_distrib, + @r" + UnionExec + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + DataSourceExec: 
file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); + // Union doesn't benefit from input partitioning - no parallelism + let plan_parquet_sort = test_config.to_plan(plan_parquet, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_parquet_distrib, plan_parquet_sort); // Test: with csv - let expected_csv = &[ - "UnionExec", - // Union doesn't benefit from input partitioning - no parallelism - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", - ]; - test_config.run(expected_csv, plan_csv.clone(), &DISTRIB_DISTRIB_SORT)?; - test_config.run(expected_csv, plan_csv, &SORT_DISTRIB_DISTRIB)?; + let plan_csv_distrib = test_config.to_plan(plan_csv.clone(), &DISTRIB_DISTRIB_SORT); + assert_plan!(plan_csv_distrib, + @r" + UnionExec + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + DataSourceExec: 
file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + "); + // Union doesn't benefit from input partitioning - no parallelism + let plan_csv_sort = test_config.to_plan(plan_csv, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_csv_distrib, plan_csv_sort); Ok(()) } @@ -2840,14 +2882,15 @@ fn parallelization_union_inputs() -> Result<()> { #[test] fn parallelization_prior_to_sort_preserving_merge() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); // sort preserving merge already sorted input, let plan_parquet = sort_preserving_merge_exec( sort_key.clone(), - parquet_exec_with_sort(vec![sort_key.clone()]), + parquet_exec_with_sort(schema, vec![sort_key.clone()]), ); let plan_csv = sort_preserving_merge_exec(sort_key.clone(), csv_exec_with_sort(vec![sort_key])); @@ -2858,22 +2901,21 @@ fn parallelization_prior_to_sort_preserving_merge() -> Result<()> { // parallelization is not beneficial for SortPreservingMerge // Test: with parquet - let expected_parquet = &[ - "DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet", - ]; - test_config.run( - expected_parquet, - plan_parquet.clone(), - &DISTRIB_DISTRIB_SORT, - )?; - test_config.run(expected_parquet, plan_parquet, &SORT_DISTRIB_DISTRIB)?; + let plan_parquet_distrib = + test_config.to_plan(plan_parquet.clone(), &DISTRIB_DISTRIB_SORT); + assert_plan!(plan_parquet_distrib, + @"DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet" + ); + let plan_parquet_sort = test_config.to_plan(plan_parquet, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_parquet_distrib, plan_parquet_sort); // Test: with csv - let expected_csv = &[ - "DataSourceExec: file_groups={1 
group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=csv, has_header=false", - ]; - test_config.run(expected_csv, plan_csv.clone(), &DISTRIB_DISTRIB_SORT)?; - test_config.run(expected_csv, plan_csv, &SORT_DISTRIB_DISTRIB)?; + let plan_csv_distrib = test_config.to_plan(plan_csv.clone(), &DISTRIB_DISTRIB_SORT); + assert_plan!(plan_csv_distrib, + @"DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=csv, has_header=false" + ); + let plan_csv_sort = test_config.to_plan(plan_csv, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_csv_distrib, plan_csv_sort); Ok(()) } @@ -2881,13 +2923,17 @@ fn parallelization_prior_to_sort_preserving_merge() -> Result<()> { #[test] fn parallelization_sort_preserving_merge_with_union() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); // 2 sorted parquet files unioned (partitions are concatenated, sort is preserved) let input_parquet = - union_exec(vec![parquet_exec_with_sort(vec![sort_key.clone()]); 2]); + union_exec(vec![ + parquet_exec_with_sort(schema, vec![sort_key.clone()]); + 2 + ]); let input_csv = union_exec(vec![csv_exec_with_sort(vec![sort_key.clone()]); 2]); let plan_parquet = sort_preserving_merge_exec(sort_key.clone(), input_parquet); let plan_csv = sort_preserving_merge_exec(sort_key, input_csv); @@ -2899,54 +2945,47 @@ fn parallelization_sort_preserving_merge_with_union() -> Result<()> { // should not sort (as the data was already sorted) // Test: with parquet - let expected_parquet = &[ - "SortPreservingMergeExec: [c@2 ASC]", - " UnionExec", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, 
c, d, e], output_ordering=[c@2 ASC], file_type=parquet", - ]; - test_config.run( - expected_parquet, - plan_parquet.clone(), - &DISTRIB_DISTRIB_SORT, - )?; - let expected_parquet_first_sort_enforcement = &[ - // no SPM - "SortExec: expr=[c@2 ASC], preserve_partitioning=[false]", - // has coalesce - " CoalescePartitionsExec", - " UnionExec", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet", - ]; - test_config.run( - expected_parquet_first_sort_enforcement, - plan_parquet, - &SORT_DISTRIB_DISTRIB, - )?; + let plan_parquet_distrib = + test_config.to_plan(plan_parquet.clone(), &DISTRIB_DISTRIB_SORT); + assert_plan!(plan_parquet_distrib, + @r" + SortPreservingMergeExec: [c@2 ASC] + UnionExec + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet + "); + let plan_parquet_sort = test_config.to_plan(plan_parquet, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_parquet_sort, + @r" + SortExec: expr=[c@2 ASC], preserve_partitioning=[false] + CoalescePartitionsExec + UnionExec + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet + "); + // no SPM + // has coalesce // Test: with csv - let expected_csv = &[ - "SortPreservingMergeExec: [c@2 ASC]", - " UnionExec", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=csv, has_header=false", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=csv, 
has_header=false", - ]; - test_config.run(expected_csv, plan_csv.clone(), &DISTRIB_DISTRIB_SORT)?; - let expected_csv_first_sort_enforcement = &[ - // no SPM - "SortExec: expr=[c@2 ASC], preserve_partitioning=[false]", - // has coalesce - " CoalescePartitionsExec", - " UnionExec", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=csv, has_header=false", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=csv, has_header=false", - ]; - test_config.run( - expected_csv_first_sort_enforcement, - plan_csv.clone(), - &SORT_DISTRIB_DISTRIB, - )?; + let plan_csv_distrib = test_config.to_plan(plan_csv.clone(), &DISTRIB_DISTRIB_SORT); + assert_plan!(plan_csv_distrib, + @r" + SortPreservingMergeExec: [c@2 ASC] + UnionExec + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=csv, has_header=false + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=csv, has_header=false + "); + let plan_csv_sort = test_config.to_plan(plan_csv.clone(), &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_csv_sort, + @r" + SortExec: expr=[c@2 ASC], preserve_partitioning=[false] + CoalescePartitionsExec + UnionExec + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=csv, has_header=false + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=csv, has_header=false + "); + // no SPM + // has coalesce Ok(()) } @@ -2954,14 +2993,15 @@ fn parallelization_sort_preserving_merge_with_union() -> Result<()> { #[test] fn parallelization_does_not_benefit() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("c", 
&schema)?, options: SortOptions::default(), - }]); + }] + .into(); // SortRequired // Parquet(sorted) let plan_parquet = sort_required_exec_with_req( - parquet_exec_with_sort(vec![sort_key.clone()]), + parquet_exec_with_sort(schema, vec![sort_key.clone()]), sort_key.clone(), ); let plan_csv = @@ -2973,24 +3013,25 @@ fn parallelization_does_not_benefit() -> Result<()> { // no parallelization, because SortRequiredExec doesn't benefit from increased parallelism // Test: with parquet - let expected_parquet = &[ - "SortRequiredExec: [c@2 ASC]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet", - ]; - test_config.run( - expected_parquet, - plan_parquet.clone(), - &DISTRIB_DISTRIB_SORT, - )?; - test_config.run(expected_parquet, plan_parquet, &SORT_DISTRIB_DISTRIB)?; + let plan_parquet_distrib = + test_config.to_plan(plan_parquet.clone(), &DISTRIB_DISTRIB_SORT); + assert_plan!(plan_parquet_distrib, + @r" + SortRequiredExec: [c@2 ASC] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet + "); + let plan_parquet_sort = test_config.to_plan(plan_parquet, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_parquet_distrib, plan_parquet_sort); // Test: with csv - let expected_csv = &[ - "SortRequiredExec: [c@2 ASC]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=csv, has_header=false", - ]; - test_config.run(expected_csv, plan_csv.clone(), &DISTRIB_DISTRIB_SORT)?; - test_config.run(expected_csv, plan_csv, &SORT_DISTRIB_DISTRIB)?; + let plan_csv_distrib = test_config.to_plan(plan_csv.clone(), &DISTRIB_DISTRIB_SORT); + assert_plan!(plan_csv_distrib, + @r" + SortRequiredExec: [c@2 ASC] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=csv, has_header=false + "); + let plan_csv_sort = test_config.to_plan(plan_csv, 
&SORT_DISTRIB_DISTRIB); + assert_plan!(plan_csv_distrib, plan_csv_sort); Ok(()) } @@ -2999,44 +3040,48 @@ fn parallelization_does_not_benefit() -> Result<()> { fn parallelization_ignores_transitively_with_projection_parquet() -> Result<()> { // sorted input let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); //Projection(a as a2, b as b2) let alias_pairs: Vec<(String, String)> = vec![ ("a".to_string(), "a2".to_string()), ("c".to_string(), "c2".to_string()), ]; - let proj_parquet = - projection_exec_with_alias(parquet_exec_with_sort(vec![sort_key]), alias_pairs); - let sort_key_after_projection = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c2", &proj_parquet.schema()).unwrap(), + let proj_parquet = projection_exec_with_alias( + parquet_exec_with_sort(schema, vec![sort_key]), + alias_pairs, + ); + let sort_key_after_projection = [PhysicalSortExpr { + expr: col("c2", &proj_parquet.schema())?, options: SortOptions::default(), - }]); + }] + .into(); let plan_parquet = sort_preserving_merge_exec(sort_key_after_projection, proj_parquet); - let expected = &[ - "SortPreservingMergeExec: [c2@1 ASC]", - " ProjectionExec: expr=[a@0 as a2, c@2 as c2]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet", - ]; - plans_matches_expected!(expected, &plan_parquet); + assert_plan!(plan_parquet, + @r" + SortPreservingMergeExec: [c2@1 ASC] + ProjectionExec: expr=[a@0 as a2, c@2 as c2] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet + "); + + let test_config = TestConfig::default(); + let plan_parquet_distrib = + test_config.to_plan(plan_parquet.clone(), &DISTRIB_DISTRIB_SORT); // Expected Outcome: // data should not be repartitioned / resorted - let 
expected_parquet = &[ - "ProjectionExec: expr=[a@0 as a2, c@2 as c2]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet", - ]; - let test_config = TestConfig::default(); - test_config.run( - expected_parquet, - plan_parquet.clone(), - &DISTRIB_DISTRIB_SORT, - )?; - test_config.run(expected_parquet, plan_parquet, &SORT_DISTRIB_DISTRIB)?; + assert_plan!(plan_parquet_distrib, + @r" + ProjectionExec: expr=[a@0 as a2, c@2 as c2] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet + "); + let plan_parquet_sort = test_config.to_plan(plan_parquet, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_parquet_distrib, plan_parquet_sort); Ok(()) } @@ -3045,10 +3090,11 @@ fn parallelization_ignores_transitively_with_projection_parquet() -> Result<()> fn parallelization_ignores_transitively_with_projection_csv() -> Result<()> { // sorted input let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); //Projection(a as a2, b as b2) let alias_pairs: Vec<(String, String)> = vec![ @@ -3058,27 +3104,30 @@ fn parallelization_ignores_transitively_with_projection_csv() -> Result<()> { let proj_csv = projection_exec_with_alias(csv_exec_with_sort(vec![sort_key]), alias_pairs); - let sort_key_after_projection = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c2", &proj_csv.schema()).unwrap(), + let sort_key_after_projection = [PhysicalSortExpr { + expr: col("c2", &proj_csv.schema())?, options: SortOptions::default(), - }]); + }] + .into(); let plan_csv = sort_preserving_merge_exec(sort_key_after_projection, proj_csv); - let expected = &[ - "SortPreservingMergeExec: [c2@1 ASC]", - " ProjectionExec: expr=[a@0 as a2, c@2 as c2]", - " DataSourceExec: file_groups={1 group: 
[[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=csv, has_header=false", - ]; - plans_matches_expected!(expected, &plan_csv); + assert_plan!(plan_csv, + @r" + SortPreservingMergeExec: [c2@1 ASC] + ProjectionExec: expr=[a@0 as a2, c@2 as c2] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=csv, has_header=false + "); + let test_config = TestConfig::default(); + let plan_distrib = test_config.to_plan(plan_csv.clone(), &DISTRIB_DISTRIB_SORT); + assert_plan!(plan_distrib, + @r" + ProjectionExec: expr=[a@0 as a2, c@2 as c2] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=csv, has_header=false + "); // Expected Outcome: // data should not be repartitioned / resorted - let expected_csv = &[ - "ProjectionExec: expr=[a@0 as a2, c@2 as c2]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=csv, has_header=false", - ]; - let test_config = TestConfig::default(); - test_config.run(expected_csv, plan_csv.clone(), &DISTRIB_DISTRIB_SORT)?; - test_config.run(expected_csv, plan_csv, &SORT_DISTRIB_DISTRIB)?; + let plan_sort = test_config.to_plan(plan_csv, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_distrib, plan_sort); Ok(()) } @@ -3088,24 +3137,25 @@ fn remove_redundant_roundrobins() -> Result<()> { let input = parquet_exec(); let repartition = repartition_exec(repartition_exec(input)); let physical_plan = repartition_exec(filter_exec(repartition)); - let expected = &[ - "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10", - " FilterExec: c@2 = 0", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - ]; - plans_matches_expected!(expected, 
&physical_plan); - - let expected = &[ - "FilterExec: c@2 = 0", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - ]; + assert_plan!(physical_plan, + @r" + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10 + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); let test_config = TestConfig::default(); - test_config.run(expected, physical_plan.clone(), &DISTRIB_DISTRIB_SORT)?; - test_config.run(expected, physical_plan, &SORT_DISTRIB_DISTRIB)?; + let plan_distrib = test_config.to_plan(physical_plan.clone(), &DISTRIB_DISTRIB_SORT); + assert_plan!(plan_distrib, + @r" + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); + let plan_sort = test_config.to_plan(physical_plan, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_distrib, plan_sort); Ok(()) } @@ -3114,28 +3164,30 @@ fn remove_redundant_roundrobins() -> Result<()> { #[test] fn remove_unnecessary_spm_after_filter() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let input = parquet_exec_multiple_sorted(vec![sort_key.clone()]); let physical_plan = sort_preserving_merge_exec(sort_key, filter_exec(input)); // TestConfig: Prefer existing sort. 
let test_config = TestConfig::default().with_prefer_existing_sort(); + let plan_distrib = test_config.to_plan(physical_plan.clone(), &DISTRIB_DISTRIB_SORT); // Expected Outcome: // Original plan expects its output to be ordered by c@2 ASC. // This is still satisfied since, after filter that column is constant. - let expected = &[ - "CoalescePartitionsExec", - " FilterExec: c@2 = 0", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2, preserve_order=true, sort_exprs=c@2 ASC", - " DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet", - ]; - - test_config.run(expected, physical_plan.clone(), &DISTRIB_DISTRIB_SORT)?; - test_config.run(expected, physical_plan, &SORT_DISTRIB_DISTRIB)?; + assert_plan!(plan_distrib, + @r" + CoalescePartitionsExec + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2, preserve_order=true, sort_exprs=c@2 ASC + DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet + "); + let plan_sort = test_config.to_plan(physical_plan, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_distrib, plan_sort); Ok(()) } @@ -3144,24 +3196,27 @@ fn remove_unnecessary_spm_after_filter() -> Result<()> { #[test] fn preserve_ordering_through_repartition() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("d", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("d", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let input = parquet_exec_multiple_sorted(vec![sort_key.clone()]); let physical_plan = sort_preserving_merge_exec(sort_key, filter_exec(input)); // TestConfig: Prefer existing sort. 
let test_config = TestConfig::default().with_prefer_existing_sort(); - let expected = &[ - "SortPreservingMergeExec: [d@3 ASC]", - " FilterExec: c@2 = 0", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2, preserve_order=true, sort_exprs=d@3 ASC", - " DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[d@3 ASC], file_type=parquet", - ]; - test_config.run(expected, physical_plan.clone(), &DISTRIB_DISTRIB_SORT)?; - test_config.run(expected, physical_plan, &SORT_DISTRIB_DISTRIB)?; + let plan_distrib = test_config.to_plan(physical_plan.clone(), &DISTRIB_DISTRIB_SORT); + assert_plan!(plan_distrib, + @r" + SortPreservingMergeExec: [d@3 ASC] + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2, preserve_order=true, sort_exprs=d@3 ASC + DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[d@3 ASC], file_type=parquet + "); + let plan_sort = test_config.to_plan(physical_plan, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_distrib, plan_sort); Ok(()) } @@ -3169,38 +3224,37 @@ fn preserve_ordering_through_repartition() -> Result<()> { #[test] fn do_not_preserve_ordering_through_repartition() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("a", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("a", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let input = parquet_exec_multiple_sorted(vec![sort_key.clone()]); let physical_plan = sort_preserving_merge_exec(sort_key, filter_exec(input)); let test_config = TestConfig::default(); + let plan_distrib = test_config.to_plan(physical_plan.clone(), &DISTRIB_DISTRIB_SORT); // Test: run EnforceDistribution, then EnforceSort. 
- let expected = &[ - "SortPreservingMergeExec: [a@0 ASC]", - " SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", - " FilterExec: c@2 = 0", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", - " DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet", - ]; - test_config.run(expected, physical_plan.clone(), &DISTRIB_DISTRIB_SORT)?; + assert_plan!(plan_distrib, + @r" + SortPreservingMergeExec: [a@0 ASC] + SortExec: expr=[a@0 ASC], preserve_partitioning=[true] + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2 + DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + "); // Test: result IS DIFFERENT, if EnforceSorting is run first: - let expected_first_sort_enforcement = &[ - "SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", - " CoalescePartitionsExec", - " FilterExec: c@2 = 0", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", - " DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet", - ]; - test_config.run( - expected_first_sort_enforcement, - physical_plan, - &SORT_DISTRIB_DISTRIB, - )?; + let plan_sort = test_config.to_plan(physical_plan, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_sort, + @r" + SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + CoalescePartitionsExec + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2 + DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + "); Ok(()) } @@ -3208,24 +3262,26 @@ fn do_not_preserve_ordering_through_repartition() -> Result<()> { #[test] fn no_need_for_sort_after_filter() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: 
col("c", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let input = parquet_exec_multiple_sorted(vec![sort_key.clone()]); let physical_plan = sort_preserving_merge_exec(sort_key, filter_exec(input)); - let expected = &[ - // After CoalescePartitionsExec c is still constant. Hence c@2 ASC ordering is already satisfied. - "CoalescePartitionsExec", - // Since after this stage c is constant. c@2 ASC ordering is already satisfied. - " FilterExec: c@2 = 0", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", - " DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet", - ]; let test_config = TestConfig::default(); - test_config.run(expected, physical_plan.clone(), &DISTRIB_DISTRIB_SORT)?; - test_config.run(expected, physical_plan, &SORT_DISTRIB_DISTRIB)?; + let plan_distrib = test_config.to_plan(physical_plan.clone(), &DISTRIB_DISTRIB_SORT); + assert_plan!(plan_distrib, @r" + CoalescePartitionsExec + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2 + DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet + "); + let plan_sort = test_config.to_plan(physical_plan, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_distrib, plan_sort); + // After CoalescePartitionsExec c is still constant. Hence c@2 ASC ordering is already satisfied. + // Since after this stage c is constant. c@2 ASC ordering is already satisfied. 
Ok(()) } @@ -3233,44 +3289,44 @@ fn no_need_for_sort_after_filter() -> Result<()> { #[test] fn do_not_preserve_ordering_through_repartition2() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let input = parquet_exec_multiple_sorted(vec![sort_key]); - let sort_req = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("a", &schema).unwrap(), + let sort_req = [PhysicalSortExpr { + expr: col("a", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let physical_plan = sort_preserving_merge_exec(sort_req, filter_exec(input)); let test_config = TestConfig::default(); + let plan_distrib = test_config.to_plan(physical_plan.clone(), &DISTRIB_DISTRIB_SORT); // Test: run EnforceDistribution, then EnforceSort. - let expected = &[ - "SortPreservingMergeExec: [a@0 ASC]", - " SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", - " FilterExec: c@2 = 0", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", - " DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet", - ]; - test_config.run(expected, physical_plan.clone(), &DISTRIB_DISTRIB_SORT)?; + assert_plan!(plan_distrib, + @r" + SortPreservingMergeExec: [a@0 ASC] + SortExec: expr=[a@0 ASC], preserve_partitioning=[true] + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2 + DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet + "); // Test: result IS DIFFERENT, if EnforceSorting is run first: - let expected_first_sort_enforcement = &[ - "SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", - " CoalescePartitionsExec", - " SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", - " FilterExec: c@2 = 
0", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", - " DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet", - ]; - test_config.run( - expected_first_sort_enforcement, - physical_plan, - &SORT_DISTRIB_DISTRIB, - )?; + let plan_sort = test_config.to_plan(physical_plan, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_sort, + @r" + SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + CoalescePartitionsExec + SortExec: expr=[a@0 ASC], preserve_partitioning=[true] + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2 + DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet + "); Ok(()) } @@ -3278,21 +3334,24 @@ fn do_not_preserve_ordering_through_repartition2() -> Result<()> { #[test] fn do_not_preserve_ordering_through_repartition3() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let input = parquet_exec_multiple_sorted(vec![sort_key]); let physical_plan = filter_exec(input); - let expected = &[ - "FilterExec: c@2 = 0", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", - " DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet", - ]; let test_config = TestConfig::default(); - test_config.run(expected, physical_plan.clone(), &DISTRIB_DISTRIB_SORT)?; - test_config.run(expected, physical_plan, &SORT_DISTRIB_DISTRIB)?; + let plan_distrib = test_config.to_plan(physical_plan.clone(), &DISTRIB_DISTRIB_SORT); + assert_plan!(plan_distrib, + @r" + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2 + DataSourceExec: file_groups={2 
groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet + "); + let plan_sort = test_config.to_plan(physical_plan, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_distrib, plan_sort); Ok(()) } @@ -3300,36 +3359,34 @@ fn do_not_preserve_ordering_through_repartition3() -> Result<()> { #[test] fn do_not_put_sort_when_input_is_invalid() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("a", &schema).unwrap(), + let sort_key = [PhysicalSortExpr { + expr: col("a", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let input = parquet_exec(); let physical_plan = sort_required_exec_with_req(filter_exec(input), sort_key); - let expected = &[ - // Ordering requirement of sort required exec is NOT satisfied - // by existing ordering at the source. - "SortRequiredExec: [a@0 ASC]", - " FilterExec: c@2 = 0", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - ]; - assert_plan_txt!(expected, physical_plan); - - let expected = &[ - "SortRequiredExec: [a@0 ASC]", - // Since at the start of the rule ordering requirement is not satisfied - // EnforceDistribution rule doesn't satisfy this requirement either. - " FilterExec: c@2 = 0", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - ]; + // Ordering requirement of sort required exec is NOT satisfied + // by existing ordering at the source. 
+ assert_plan!(physical_plan, @r" + SortRequiredExec: [a@0 ASC] + FilterExec: c@2 = 0 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); let mut config = ConfigOptions::new(); config.execution.target_partitions = 10; config.optimizer.enable_round_robin_repartition = true; config.optimizer.prefer_existing_sort = false; let dist_plan = EnforceDistribution::new().optimize(physical_plan, &config)?; - assert_plan_txt!(expected, dist_plan); + // Since at the start of the rule ordering requirement is not satisfied + // EnforceDistribution rule doesn't satisfy this requirement either. + assert_plan!(dist_plan, @r" + SortRequiredExec: [a@0 ASC] + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); Ok(()) } @@ -3337,36 +3394,34 @@ fn do_not_put_sort_when_input_is_invalid() -> Result<()> { #[test] fn put_sort_when_input_is_valid() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("a", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("a", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let input = parquet_exec_multiple_sorted(vec![sort_key.clone()]); let physical_plan = sort_required_exec_with_req(filter_exec(input), sort_key); - let expected = &[ - // Ordering requirement of sort required exec is satisfied - // by existing ordering at the source. - "SortRequiredExec: [a@0 ASC]", - " FilterExec: c@2 = 0", - " DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet", - ]; - assert_plan_txt!(expected, physical_plan); - - let expected = &[ - // Since at the start of the rule ordering requirement is satisfied - // EnforceDistribution rule satisfy this requirement also. 
- "SortRequiredExec: [a@0 ASC]", - " FilterExec: c@2 = 0", - " DataSourceExec: file_groups={10 groups: [[x:0..20], [y:0..20], [x:20..40], [y:20..40], [x:40..60], [y:40..60], [x:60..80], [y:60..80], [x:80..100], [y:80..100]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet", - ]; + // Ordering requirement of sort required exec is satisfied + // by existing ordering at the source. + assert_plan!(physical_plan, @r" + SortRequiredExec: [a@0 ASC] + FilterExec: c@2 = 0 + DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + "); let mut config = ConfigOptions::new(); config.execution.target_partitions = 10; config.optimizer.enable_round_robin_repartition = true; config.optimizer.prefer_existing_sort = false; let dist_plan = EnforceDistribution::new().optimize(physical_plan, &config)?; - assert_plan_txt!(expected, dist_plan); + // Since at the start of the rule ordering requirement is satisfied + // EnforceDistribution rule satisfy this requirement also. 
+ assert_plan!(dist_plan, @r" + SortRequiredExec: [a@0 ASC] + FilterExec: c@2 = 0 + DataSourceExec: file_groups={10 groups: [[x:0..20], [y:0..20], [x:20..40], [y:20..40], [x:40..60], [y:40..60], [x:60..80], [y:60..80], [x:80..100], [y:80..100]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + "); Ok(()) } @@ -3374,25 +3429,28 @@ fn put_sort_when_input_is_valid() -> Result<()> { #[test] fn do_not_add_unnecessary_hash() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let alias = vec![("a".to_string(), "a".to_string())]; - let input = parquet_exec_with_sort(vec![sort_key]); + let input = parquet_exec_with_sort(schema, vec![sort_key]); let physical_plan = aggregate_exec_with_alias(input, alias); // TestConfig: // Make sure target partition number is 1. In this case hash repartition is unnecessary. 
let test_config = TestConfig::default().with_query_execution_partitions(1); - let expected = &[ - "AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[]", - " AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet", - ]; - test_config.run(expected, physical_plan.clone(), &DISTRIB_DISTRIB_SORT)?; - test_config.run(expected, physical_plan, &SORT_DISTRIB_DISTRIB)?; + let plan_distrib = test_config.to_plan(physical_plan.clone(), &DISTRIB_DISTRIB_SORT); + assert_plan!(plan_distrib, + @r" + AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] + AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet + "); + let plan_sort = test_config.to_plan(physical_plan, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_distrib, plan_sort); Ok(()) } @@ -3400,10 +3458,11 @@ fn do_not_add_unnecessary_hash() -> Result<()> { #[test] fn do_not_add_unnecessary_hash2() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let alias = vec![("a".to_string(), "a".to_string())]; let input = parquet_exec_multiple_sorted(vec![sort_key]); let aggregate = aggregate_exec_with_alias(input, alias.clone()); @@ -3413,19 +3472,21 @@ fn do_not_add_unnecessary_hash2() -> Result<()> { // Make sure target partition number is larger than 2 (e.g partition number at the source). let test_config = TestConfig::default().with_query_execution_partitions(4); - let expected = &[ - "AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[]", - // Since hash requirements of this operator is satisfied. 
There shouldn't be - // a hash repartition here - " AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[]", - " AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[]", - " RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=4", - " AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[]", - " RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=2", - " DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet", - ]; - test_config.run(expected, physical_plan.clone(), &DISTRIB_DISTRIB_SORT)?; - test_config.run(expected, physical_plan, &SORT_DISTRIB_DISTRIB)?; + let plan_distrib = test_config.to_plan(physical_plan.clone(), &DISTRIB_DISTRIB_SORT); + assert_plan!(plan_distrib, + @r" + AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] + AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] + AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[] + RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=4 + AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[] + RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=2 + DataSourceExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC], file_type=parquet + "); + // Since hash requirements of this operator is satisfied. 
There shouldn't be + // a hash repartition here + let plan_sort = test_config.to_plan(physical_plan, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_distrib, plan_sort); Ok(()) } @@ -3433,19 +3494,19 @@ fn do_not_add_unnecessary_hash2() -> Result<()> { #[test] fn optimize_away_unnecessary_repartition() -> Result<()> { let physical_plan = coalesce_partitions_exec(repartition_exec(parquet_exec())); - let expected = &[ - "CoalescePartitionsExec", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - ]; - plans_matches_expected!(expected, physical_plan.clone()); - - let expected = - &["DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet"]; + assert_plan!(physical_plan, + @r" + CoalescePartitionsExec + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); let test_config = TestConfig::default(); - test_config.run(expected, physical_plan.clone(), &DISTRIB_DISTRIB_SORT)?; - test_config.run(expected, physical_plan, &SORT_DISTRIB_DISTRIB)?; + let plan_distrib = test_config.to_plan(physical_plan.clone(), &DISTRIB_DISTRIB_SORT); + assert_plan!(plan_distrib, + @"DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet"); + let plan_sort = test_config.to_plan(physical_plan, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_distrib, plan_sort); Ok(()) } @@ -3455,25 +3516,27 @@ fn optimize_away_unnecessary_repartition2() -> Result<()> { let physical_plan = filter_exec(repartition_exec(coalesce_partitions_exec( filter_exec(repartition_exec(parquet_exec())), ))); - let expected = &[ - "FilterExec: c@2 = 0", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " CoalescePartitionsExec", - " FilterExec: c@2 = 0", - " RepartitionExec: 
partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - ]; - plans_matches_expected!(expected, physical_plan.clone()); + assert_plan!(physical_plan, + @r" + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + CoalescePartitionsExec + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); - let expected = &[ - "FilterExec: c@2 = 0", - " FilterExec: c@2 = 0", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - ]; let test_config = TestConfig::default(); - test_config.run(expected, physical_plan.clone(), &DISTRIB_DISTRIB_SORT)?; - test_config.run(expected, physical_plan, &SORT_DISTRIB_DISTRIB)?; + let plan_distrib = test_config.to_plan(physical_plan.clone(), &DISTRIB_DISTRIB_SORT); + assert_plan!(plan_distrib, + @r" + FilterExec: c@2 = 0 + FilterExec: c@2 = 0 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); + let plan_sort = test_config.to_plan(physical_plan, &SORT_DISTRIB_DISTRIB); + assert_plan!(plan_distrib, plan_sort); Ok(()) } @@ -3489,34 +3552,35 @@ async fn test_distribute_sort_parquet() -> Result<()> { ); let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), - options: SortOptions::default(), - }]); - let physical_plan = sort_exec(sort_key, parquet_exec_with_stats(10000 * 8192), false); + let sort_key = [PhysicalSortExpr::new_default(col("c", &schema)?)].into(); + let physical_plan = sort_exec(sort_key, parquet_exec_with_stats(10000 * 8192)); // prior to optimization, this is the 
starting plan - let starting = &[ - "SortExec: expr=[c@2 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", - ]; - plans_matches_expected!(starting, physical_plan.clone()); + assert_plan!(physical_plan, + @r" + SortExec: expr=[c@2 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + "); // what the enforce distribution run does. - let expected = &[ - "SortExec: expr=[c@2 ASC], preserve_partitioning=[false]", - " CoalescePartitionsExec", - " DataSourceExec: file_groups={10 groups: [[x:0..8192000], [x:8192000..16384000], [x:16384000..24576000], [x:24576000..32768000], [x:32768000..40960000], [x:40960000..49152000], [x:49152000..57344000], [x:57344000..65536000], [x:65536000..73728000], [x:73728000..81920000]]}, projection=[a, b, c, d, e], file_type=parquet", - ]; - test_config.run(expected, physical_plan.clone(), &[Run::Distribution])?; + let plan_distribution = + test_config.to_plan(physical_plan.clone(), &[Run::Distribution]); + assert_plan!(plan_distribution, + @r" + SortExec: expr=[c@2 ASC], preserve_partitioning=[false] + CoalescePartitionsExec + DataSourceExec: file_groups={10 groups: [[x:0..8192000], [x:8192000..16384000], [x:16384000..24576000], [x:24576000..32768000], [x:32768000..40960000], [x:40960000..49152000], [x:49152000..57344000], [x:57344000..65536000], [x:65536000..73728000], [x:73728000..81920000]]}, projection=[a, b, c, d, e], file_type=parquet + "); // what the sort parallelization (in enforce sorting), does after the enforce distribution changes - let expected = &[ - "SortPreservingMergeExec: [c@2 ASC]", - " SortExec: expr=[c@2 ASC], preserve_partitioning=[true]", - " DataSourceExec: file_groups={10 groups: [[x:0..8192000], [x:8192000..16384000], [x:16384000..24576000], [x:24576000..32768000], [x:32768000..40960000], [x:40960000..49152000], [x:49152000..57344000], 
[x:57344000..65536000], [x:65536000..73728000], [x:73728000..81920000]]}, projection=[a, b, c, d, e], file_type=parquet", - ]; - test_config.run(expected, physical_plan, &[Run::Distribution, Run::Sorting])?; + let plan_both = + test_config.to_plan(physical_plan, &[Run::Distribution, Run::Sorting]); + assert_plan!(plan_both, + @r" + SortPreservingMergeExec: [c@2 ASC] + SortExec: expr=[c@2 ASC], preserve_partitioning=[true] + DataSourceExec: file_groups={10 groups: [[x:0..8192000], [x:8192000..16384000], [x:16384000..24576000], [x:24576000..32768000], [x:32768000..40960000], [x:40960000..49152000], [x:49152000..57344000], [x:57344000..65536000], [x:65536000..73728000], [x:73728000..81920000]]}, projection=[a, b, c, d, e], file_type=parquet + "); Ok(()) } @@ -3541,12 +3605,12 @@ async fn test_distribute_sort_memtable() -> Result<()> { let physical_plan = dataframe.create_physical_plan().await?; // this is the final, optimized plan - let expected = &[ - "SortPreservingMergeExec: [id@0 ASC NULLS LAST]", - " SortExec: expr=[id@0 ASC NULLS LAST], preserve_partitioning=[true]", - " DataSourceExec: partitions=3, partition_sizes=[34, 33, 33]", - ]; - plans_matches_expected!(expected, physical_plan); + assert_plan!(physical_plan, + @r" + SortPreservingMergeExec: [id@0 ASC NULLS LAST] + SortExec: expr=[id@0 ASC NULLS LAST], preserve_partitioning=[true] + DataSourceExec: partitions=3, partition_sizes=[34, 33, 33] + "); Ok(()) } @@ -3583,16 +3647,12 @@ fn test_replace_order_preserving_variants_with_fetch() -> Result<()> { // Create a base plan let parquet_exec = parquet_exec(); - let sort_expr = PhysicalSortExpr { - expr: Arc::new(Column::new("id", 0)), - options: SortOptions::default(), - }; - - let ordering = LexOrdering::new(vec![sort_expr]); + let sort_expr = PhysicalSortExpr::new_default(Arc::new(Column::new("id", 0))); // Create a SortPreservingMergeExec with fetch=5 let spm_exec = Arc::new( - SortPreservingMergeExec::new(ordering, 
parquet_exec.clone()).with_fetch(Some(5)), + SortPreservingMergeExec::new([sort_expr].into(), parquet_exec.clone()) + .with_fetch(Some(5)), ); // Create distribution context diff --git a/datafusion/core/tests/physical_optimizer/enforce_sorting.rs b/datafusion/core/tests/physical_optimizer/enforce_sorting.rs index f7668c8aab11f..6349ff1cd109f 100644 --- a/datafusion/core/tests/physical_optimizer/enforce_sorting.rs +++ b/datafusion/core/tests/physical_optimizer/enforce_sorting.rs @@ -17,130 +17,119 @@ use std::sync::Arc; +use crate::memory_limit::DummyStreamPartition; use crate::physical_optimizer::test_utils::{ - aggregate_exec, bounded_window_exec, check_integrity, coalesce_batches_exec, - coalesce_partitions_exec, create_test_schema, create_test_schema2, - create_test_schema3, filter_exec, global_limit_exec, hash_join_exec, limit_exec, - local_limit_exec, memory_exec, parquet_exec, repartition_exec, sort_exec, + RequirementsTestExec, aggregate_exec, bounded_window_exec, + bounded_window_exec_with_partition, check_integrity, coalesce_partitions_exec, + create_test_schema, create_test_schema2, create_test_schema3, filter_exec, + global_limit_exec, hash_join_exec, local_limit_exec, memory_exec, parquet_exec, + parquet_exec_with_sort, projection_exec, repartition_exec, sort_exec, sort_exec_with_fetch, sort_expr, sort_expr_options, sort_merge_join_exec, sort_preserving_merge_exec, sort_preserving_merge_exec_with_fetch, - spr_repartition_exec, stream_exec_ordered, union_exec, RequirementsTestExec, + spr_repartition_exec, stream_exec_ordered, union_exec, }; -use arrow::compute::SortOptions; +use arrow::compute::{SortOptions}; use arrow::datatypes::{DataType, SchemaRef}; -use datafusion_common::config::ConfigOptions; +use datafusion_common::config::{ConfigOptions, CsvOptions}; use datafusion_common::tree_node::{TreeNode, TransformedResult}; -use datafusion_common::{Result, ScalarValue}; -use datafusion_expr::{JoinType, WindowFrame, WindowFrameBound, WindowFrameUnits, 
WindowFunctionDefinition}; +use datafusion_common::{create_array, Result, TableReference}; +use datafusion_datasource::file_scan_config::FileScanConfigBuilder; +use datafusion_datasource::source::DataSourceExec; +use datafusion_expr_common::operator::Operator; +use datafusion_expr::{JoinType, SortExpr}; use datafusion_execution::object_store::ObjectStoreUrl; -use datafusion_physical_expr_common::physical_expr::PhysicalExpr; -use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; -use datafusion_physical_expr::expressions::{col, Column, NotExpr}; -use datafusion_physical_expr::Partitioning; -use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; +use datafusion_physical_expr_common::sort_expr::{ + LexOrdering, PhysicalSortExpr, PhysicalSortRequirement, OrderingRequirements +}; +use datafusion_physical_expr::{Distribution, Partitioning}; +use datafusion_physical_expr::expressions::{col, BinaryExpr, Column, NotExpr}; use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use datafusion_physical_plan::repartition::RepartitionExec; use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use datafusion_physical_plan::sorts::sort::SortExec; -use datafusion_physical_plan::windows::{create_window_expr, BoundedWindowAggExec, WindowAggExec}; -use datafusion_physical_plan::{displayable, get_plan_string, ExecutionPlan, InputOrderMode}; -use datafusion::datasource::physical_plan::{CsvSource, ParquetSource}; +use datafusion_physical_plan::{displayable, get_plan_string, ExecutionPlan}; +use datafusion::datasource::physical_plan::CsvSource; use datafusion::datasource::listing::PartitionedFile; use datafusion_physical_optimizer::enforce_sorting::{EnforceSorting, PlanWithCorrespondingCoalescePartitions, PlanWithCorrespondingSort, parallelize_sorts, ensure_sorting}; use 
datafusion_physical_optimizer::enforce_sorting::replace_with_order_preserving_variants::{replace_with_order_preserving_variants, OrderPreservationContext}; use datafusion_physical_optimizer::enforce_sorting::sort_pushdown::{SortPushDown, assign_initial_requirements, pushdown_sorts}; use datafusion_physical_optimizer::enforce_distribution::EnforceDistribution; +use datafusion_physical_optimizer::output_requirements::OutputRequirementExec; use datafusion_physical_optimizer::PhysicalOptimizerRule; -use datafusion_functions_aggregate::average::avg_udaf; -use datafusion_functions_aggregate::count::count_udaf; -use datafusion_functions_aggregate::min_max::{max_udaf, min_udaf}; - -use datafusion_datasource::file_scan_config::FileScanConfigBuilder; -use datafusion_datasource::source::DataSourceExec; -use rstest::rstest; - -/// Create a csv exec for tests -fn csv_exec_ordered( - schema: &SchemaRef, - sort_exprs: impl IntoIterator, -) -> Arc { - let sort_exprs = sort_exprs.into_iter().collect(); - - let config = FileScanConfigBuilder::new( - ObjectStoreUrl::parse("test:///").unwrap(), - schema.clone(), - Arc::new(CsvSource::new(true, 0, b'"')), - ) - .with_file(PartitionedFile::new("file_path".to_string(), 100)) - .with_output_ordering(vec![sort_exprs]) - .build(); - - DataSourceExec::from_data_source(config) -} - -/// Created a sorted parquet exec -pub fn parquet_exec_sorted( - schema: &SchemaRef, - sort_exprs: impl IntoIterator, -) -> Arc { - let sort_exprs = sort_exprs.into_iter().collect(); - - let source = Arc::new(ParquetSource::default()); - let config = FileScanConfigBuilder::new( - ObjectStoreUrl::parse("test:///").unwrap(), - schema.clone(), - source, - ) - .with_file(PartitionedFile::new("x".to_string(), 100)) - .with_output_ordering(vec![sort_exprs]) - .build(); +use datafusion::prelude::*; +use arrow::array::{record_batch, ArrayRef, Int32Array, RecordBatch}; +use arrow::datatypes::{Field}; +use arrow_schema::Schema; +use datafusion_execution::TaskContext; +use 
datafusion_catalog::streaming::StreamingTable; - DataSourceExec::from_data_source(config) -} +use futures::StreamExt; +use insta::{Settings, assert_snapshot}; /// Create a sorted Csv exec fn csv_exec_sorted( schema: &SchemaRef, sort_exprs: impl IntoIterator, ) -> Arc { - let sort_exprs = sort_exprs.into_iter().collect(); - - let config = FileScanConfigBuilder::new( + let options = CsvOptions { + has_header: Some(false), + delimiter: 0, + quote: 0, + ..Default::default() + }; + let mut builder = FileScanConfigBuilder::new( ObjectStoreUrl::parse("test:///").unwrap(), - schema.clone(), - Arc::new(CsvSource::new(false, 0, 0)), + Arc::new(CsvSource::new(schema.clone()).with_csv_options(options)), ) - .with_file(PartitionedFile::new("x".to_string(), 100)) - .with_output_ordering(vec![sort_exprs]) - .build(); + .with_file(PartitionedFile::new("x".to_string(), 100)); + if let Some(ordering) = LexOrdering::new(sort_exprs) { + builder = builder.with_output_ordering(vec![ordering]); + } + let config = builder.build(); DataSourceExec::from_data_source(config) } /// Runs the sort enforcement optimizer and asserts the plan /// against the original and expected plans -/// -/// `$EXPECTED_PLAN_LINES`: input plan -/// `$EXPECTED_OPTIMIZED_PLAN_LINES`: optimized plan -/// `$PLAN`: the plan to optimized -/// `REPARTITION_SORTS`: Flag to set `config.options.optimizer.repartition_sorts` option. -/// -macro_rules! 
assert_optimized { - ($EXPECTED_PLAN_LINES: expr, $EXPECTED_OPTIMIZED_PLAN_LINES: expr, $PLAN: expr, $REPARTITION_SORTS: expr) => { +pub(crate) struct EnforceSortingTest { + plan: Arc, + repartition_sorts: bool, +} + +impl EnforceSortingTest { + pub(crate) fn new(plan: Arc) -> Self { + Self { + plan, + repartition_sorts: false, + } + } + + /// Set whether to repartition sorts + pub(crate) fn with_repartition_sorts(mut self, repartition_sorts: bool) -> Self { + self.repartition_sorts = repartition_sorts; + self + } + + /// Runs the enforce sorting test and returns a string with the input and + /// optimized plan as strings for snapshot comparison using insta + pub(crate) fn run(&self) -> String { let mut config = ConfigOptions::new(); - config.optimizer.repartition_sorts = $REPARTITION_SORTS; + config.optimizer.repartition_sorts = self.repartition_sorts; // This file has 4 rules that use tree node, apply these rules as in the // EnforceSorting::optimize implementation // After these operations tree nodes should be in a consistent state. // This code block makes sure that these rules doesn't violate tree node integrity. { - let plan_requirements = PlanWithCorrespondingSort::new_default($PLAN.clone()); + let plan_requirements = + PlanWithCorrespondingSort::new_default(Arc::clone(&self.plan)); let adjusted = plan_requirements .transform_up(ensure_sorting) .data() - .and_then(check_integrity)?; + .and_then(check_integrity) + .expect("check_integrity failed after ensure_sorting"); // TODO: End state payloads will be checked here. let new_plan = if config.optimizer.repartition_sorts { @@ -149,60 +138,60 @@ macro_rules! assert_optimized { let parallel = plan_with_coalesce_partitions .transform_up(parallelize_sorts) .data() - .and_then(check_integrity)?; + .and_then(check_integrity) + .expect("check_integrity failed after parallelize_sorts"); // TODO: End state payloads will be checked here. 
parallel.plan } else { adjusted.plan }; - let plan_with_pipeline_fixer = OrderPreservationContext::new_default(new_plan); + let plan_with_pipeline_fixer = + OrderPreservationContext::new_default(new_plan); let updated_plan = plan_with_pipeline_fixer .transform_up(|plan_with_pipeline_fixer| { replace_with_order_preserving_variants( plan_with_pipeline_fixer, false, true, - &config, + &config, ) }) .data() - .and_then(check_integrity)?; + .and_then(check_integrity) + .expect( + "check_integrity failed after replace_with_order_preserving_variants", + ); // TODO: End state payloads will be checked here. let mut sort_pushdown = SortPushDown::new_default(updated_plan.plan); assign_initial_requirements(&mut sort_pushdown); - check_integrity(pushdown_sorts(sort_pushdown)?)?; + check_integrity( + pushdown_sorts(sort_pushdown).expect("pushdown_sorts failed"), + ) + .expect("check_integrity failed after pushdown_sorts"); // TODO: End state payloads will be checked here. } - - let physical_plan = $PLAN; - let formatted = displayable(physical_plan.as_ref()).indent(true).to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - - let expected_plan_lines: Vec<&str> = $EXPECTED_PLAN_LINES - .iter().map(|s| *s).collect(); - - assert_eq!( - expected_plan_lines, actual, - "\n**Original Plan Mismatch\n\nexpected:\n\n{expected_plan_lines:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - let expected_optimized_lines: Vec<&str> = $EXPECTED_OPTIMIZED_PLAN_LINES - .iter().map(|s| *s).collect(); + let input_plan_string = displayable(self.plan.as_ref()).indent(true).to_string(); // Run the actual optimizer - let optimized_physical_plan = - EnforceSorting::new().optimize(physical_plan,&config)?; + let optimized_physical_plan = EnforceSorting::new() + .optimize(Arc::clone(&self.plan), &config) + .expect("enforce_sorting failed"); // Get string representation of the plan - let actual = get_plan_string(&optimized_physical_plan); - assert_eq!( - expected_optimized_lines, actual, - 
"\n**Optimized Plan Mismatch\n\nexpected:\n\n{expected_optimized_lines:#?}\nactual:\n\n{actual:#?}\n\n" - ); + let optimized_plan_string = displayable(optimized_physical_plan.as_ref()) + .indent(true) + .to_string(); - }; + if input_plan_string == optimized_plan_string { + format!("Input / Optimized Plan:\n{input_plan_string}",) + } else { + format!( + "Input Plan:\n{input_plan_string}\nOptimized Plan:\n{optimized_plan_string}", + ) + } + } } #[tokio::test] @@ -210,96 +199,97 @@ async fn test_remove_unnecessary_sort5() -> Result<()> { let left_schema = create_test_schema2()?; let right_schema = create_test_schema3()?; let left_input = memory_exec(&left_schema); - let parquet_sort_exprs = vec![sort_expr("a", &right_schema)]; - let right_input = parquet_exec_sorted(&right_schema, parquet_sort_exprs); - + let parquet_ordering = [sort_expr("a", &right_schema)].into(); + let right_input = + parquet_exec_with_sort(right_schema.clone(), vec![parquet_ordering]); let on = vec![( Arc::new(Column::new_with_schema("col_a", &left_schema)?) as _, Arc::new(Column::new_with_schema("c", &right_schema)?) 
as _, )]; let join = hash_join_exec(left_input, right_input, on, None, &JoinType::Inner)?; - let physical_plan = sort_exec(vec![sort_expr("a", &join.schema())], join); - - let expected_input = ["SortExec: expr=[a@2 ASC], preserve_partitioning=[false]", - " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(col_a@0, c@2)]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet"]; - - let expected_optimized = ["HashJoinExec: mode=Partitioned, join_type=Inner, on=[(col_a@0, c@2)]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet"]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); - + let physical_plan = sort_exec([sort_expr("a", &join.schema())].into(), join); + + let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r" + Input Plan: + SortExec: expr=[a@2 ASC], preserve_partitioning=[false] + HashJoinExec: mode=Partitioned, join_type=Inner, on=[(col_a@0, c@2)] + DataSourceExec: partitions=1, partition_sizes=[0] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + + Optimized Plan: + HashJoinExec: mode=Partitioned, join_type=Inner, on=[(col_a@0, c@2)] + DataSourceExec: partitions=1, partition_sizes=[0] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + "); Ok(()) } #[tokio::test] async fn test_do_not_remove_sort_with_limit() -> Result<()> { let schema = create_test_schema()?; - - let source1 = parquet_exec(&schema); - let sort_exprs = vec![ + let source1 = parquet_exec(schema.clone()); + let ordering: LexOrdering = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", 
&schema), - ]; - let sort = sort_exec(sort_exprs.clone(), source1); - let limit = limit_exec(sort); - - let parquet_sort_exprs = vec![sort_expr("nullable_col", &schema)]; - let source2 = parquet_exec_sorted(&schema, parquet_sort_exprs); - + ] + .into(); + let sort = sort_exec(ordering.clone(), source1); + let limit = local_limit_exec(sort, 100); + let parquet_ordering = [sort_expr("nullable_col", &schema)].into(); + let source2 = parquet_exec_with_sort(schema, vec![parquet_ordering]); let union = union_exec(vec![source2, limit]); let repartition = repartition_exec(union); - let physical_plan = sort_preserving_merge_exec(sort_exprs, repartition); - - let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", - " UnionExec", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet", - " GlobalLimitExec: skip=0, fetch=100", - " LocalLimitExec: fetch=100", - " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; - + let physical_plan = sort_preserving_merge_exec(ordering, repartition); + + let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); + + assert_snapshot!(test.run(), @r" + Input Plan: + SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2 + UnionExec + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet + LocalLimitExec: fetch=100 + SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, 
non_nullable_col], file_type=parquet + + Optimized Plan: + SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC] + SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[true] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2 + UnionExec + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet + LocalLimitExec: fetch=100 + SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + "); // We should keep the bottom `SortExec`. - let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", - " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[true]", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", - " UnionExec", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet", - " GlobalLimitExec: skip=0, fetch=100", - " LocalLimitExec: fetch=100", - " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); - Ok(()) } #[tokio::test] async fn test_union_inputs_sorted() -> Result<()> { let schema = create_test_schema()?; - - let source1 = parquet_exec(&schema); - let sort_exprs = vec![sort_expr("nullable_col", &schema)]; - let sort = sort_exec(sort_exprs.clone(), source1); - - let source2 = parquet_exec_sorted(&schema, sort_exprs.clone()); - + let source1 = parquet_exec(schema.clone()); + let ordering: LexOrdering = [sort_expr("nullable_col", &schema)].into(); 
+ let sort = sort_exec(ordering.clone(), source1); + let source2 = parquet_exec_with_sort(schema, vec![ordering.clone()]); let union = union_exec(vec![source2, sort]); - let physical_plan = sort_preserving_merge_exec(sort_exprs, union); + let physical_plan = sort_preserving_merge_exec(ordering, union); // one input to the union is already sorted, one is not. - let expected_input = vec![ - "SortPreservingMergeExec: [nullable_col@0 ASC]", - " UnionExec", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet", - " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", - ]; + let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); + + assert_snapshot!(test.run(), @r" + Input / Optimized Plan: + SortPreservingMergeExec: [nullable_col@0 ASC] + UnionExec + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + "); // should not add a sort at the output of the union, input plan should not be changed - let expected_optimized = expected_input.clone(); - assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -307,31 +297,30 @@ async fn test_union_inputs_sorted() -> Result<()> { #[tokio::test] async fn test_union_inputs_different_sorted() -> Result<()> { let schema = create_test_schema()?; - - let source1 = parquet_exec(&schema); - let sort_exprs = vec![sort_expr("nullable_col", &schema)]; - let sort = sort_exec(sort_exprs.clone(), source1); - - let parquet_sort_exprs = vec![ + let source1 = parquet_exec(schema.clone()); + let 
ordering: LexOrdering = [sort_expr("nullable_col", &schema)].into(); + let sort = sort_exec(ordering.clone(), source1); + let parquet_ordering = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), - ]; - let source2 = parquet_exec_sorted(&schema, parquet_sort_exprs); - + ] + .into(); + let source2 = parquet_exec_with_sort(schema, vec![parquet_ordering]); let union = union_exec(vec![source2, sort]); - let physical_plan = sort_preserving_merge_exec(sort_exprs, union); + let physical_plan = sort_preserving_merge_exec(ordering, union); // one input to the union is already sorted, one is not. - let expected_input = vec![ - "SortPreservingMergeExec: [nullable_col@0 ASC]", - " UnionExec", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC, non_nullable_col@1 ASC], file_type=parquet", - " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", - ]; + let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); + + assert_snapshot!(test.run(), @r" + Input / Optimized Plan: + SortPreservingMergeExec: [nullable_col@0 ASC] + UnionExec + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC, non_nullable_col@1 ASC], file_type=parquet + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + "); // should not add a sort at the output of the union, input plan should not be changed - let expected_optimized = expected_input.clone(); - assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -339,120 +328,216 @@ async fn test_union_inputs_different_sorted() -> Result<()> { #[tokio::test] async fn 
test_union_inputs_different_sorted2() -> Result<()> { let schema = create_test_schema()?; - - let source1 = parquet_exec(&schema); - let sort_exprs = vec![ + let source1 = parquet_exec(schema.clone()); + let sort_exprs: LexOrdering = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), - ]; + ] + .into(); let sort = sort_exec(sort_exprs.clone(), source1); - - let parquet_sort_exprs = vec![sort_expr("nullable_col", &schema)]; - let source2 = parquet_exec_sorted(&schema, parquet_sort_exprs); - + let parquet_ordering = [sort_expr("nullable_col", &schema)].into(); + let source2 = parquet_exec_with_sort(schema, vec![parquet_ordering]); let union = union_exec(vec![source2, sort]); let physical_plan = sort_preserving_merge_exec(sort_exprs, union); // Input is an invalid plan. In this case rule should add required sorting in appropriate places. // First DataSourceExec has output ordering(nullable_col@0 ASC). However, it doesn't satisfy the // required ordering of SortPreservingMergeExec. 
- let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", - " UnionExec", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet", - " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; - - let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", - " UnionExec", - " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet", - " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r" + Input Plan: + SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC] + UnionExec + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet + SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + + Optimized Plan: + SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC] + UnionExec + SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], 
output_ordering=[nullable_col@0 ASC], file_type=parquet + SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + "); Ok(()) } #[tokio::test] -async fn test_union_inputs_different_sorted3() -> Result<()> { +// Test with `repartition_sorts` enabled to preserve pre-sorted partitions and avoid resorting +async fn union_with_mix_of_presorted_and_explicitly_resorted_inputs_with_repartition_sorts_true() +-> Result<()> { + assert_snapshot!( + union_with_mix_of_presorted_and_explicitly_resorted_inputs_impl(true).await?, + @r" + Input Plan: + OutputRequirementExec: order_by=[(nullable_col@0, asc)], dist_by=SinglePartition + CoalescePartitionsExec + UnionExec + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet + + Optimized Plan: + OutputRequirementExec: order_by=[(nullable_col@0, asc)], dist_by=SinglePartition + SortPreservingMergeExec: [nullable_col@0 ASC] + UnionExec + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet + "); + Ok(()) +} + +#[tokio::test] +// Test with `repartition_sorts` disabled, causing a full resort of the data +async fn union_with_mix_of_presorted_and_explicitly_resorted_inputs_with_repartition_sorts_false() +-> Result<()> { + assert_snapshot!( + union_with_mix_of_presorted_and_explicitly_resorted_inputs_impl(false).await?, + @r" + Input Plan: + 
OutputRequirementExec: order_by=[(nullable_col@0, asc)], dist_by=SinglePartition + CoalescePartitionsExec + UnionExec + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet + + Optimized Plan: + OutputRequirementExec: order_by=[(nullable_col@0, asc)], dist_by=SinglePartition + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + CoalescePartitionsExec + UnionExec + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet + "); + Ok(()) +} + +async fn union_with_mix_of_presorted_and_explicitly_resorted_inputs_impl( + repartition_sorts: bool, +) -> Result { let schema = create_test_schema()?; - let source1 = parquet_exec(&schema); - let sort_exprs1 = vec![ - sort_expr("nullable_col", &schema), - sort_expr("non_nullable_col", &schema), - ]; - let sort1 = sort_exec(sort_exprs1, source1.clone()); - let sort_exprs2 = vec![sort_expr("nullable_col", &schema)]; - let sort2 = sort_exec(sort_exprs2, source1); + // Source 1, will be sorted explicitly (on `nullable_col`) + let source1 = parquet_exec(schema.clone()); + let ordering1 = [sort_expr("nullable_col", &schema)].into(); + let sort1 = sort_exec(ordering1, source1.clone()); + + // Source 2, pre-sorted (on `nullable_col`) + let parquet_ordering: LexOrdering = [sort_expr("nullable_col", &schema)].into(); + let source2 = parquet_exec_with_sort(schema.clone(), vec![parquet_ordering.clone()]); + + let union = union_exec(vec![sort1, source2]); + + let coalesced = coalesce_partitions_exec(union); + + // Required sorted / single partitioned output 
+ let requirement = [PhysicalSortRequirement::new( + col("nullable_col", &schema)?, + Some(SortOptions::new(false, true)), + )] + .into(); + let physical_plan = Arc::new(OutputRequirementExec::new( + coalesced, + Some(OrderingRequirements::new(requirement)), + Distribution::SinglePartition, + None, + )); - let parquet_sort_exprs = vec![sort_expr("nullable_col", &schema)]; - let source2 = parquet_exec_sorted(&schema, parquet_sort_exprs.clone()); + let test = + EnforceSortingTest::new(physical_plan).with_repartition_sorts(repartition_sorts); + Ok(test.run()) +} +#[tokio::test] +async fn test_union_inputs_different_sorted3() -> Result<()> { + let schema = create_test_schema()?; + let source1 = parquet_exec(schema.clone()); + let ordering1 = [ + sort_expr("nullable_col", &schema), + sort_expr("non_nullable_col", &schema), + ] + .into(); + let sort1 = sort_exec(ordering1, source1.clone()); + let ordering2 = [sort_expr("nullable_col", &schema)].into(); + let sort2 = sort_exec(ordering2, source1); + let parquet_ordering: LexOrdering = [sort_expr("nullable_col", &schema)].into(); + let source2 = parquet_exec_with_sort(schema, vec![parquet_ordering.clone()]); let union = union_exec(vec![sort1, source2, sort2]); - let physical_plan = sort_preserving_merge_exec(parquet_sort_exprs, union); + let physical_plan = sort_preserving_merge_exec(parquet_ordering, union); // First input to the union is not Sorted (SortExec is finer than required ordering by the SortPreservingMergeExec above). // Second input to the union is already Sorted (matches with the required ordering by the SortPreservingMergeExec above). // Third input to the union is not Sorted (SortExec is matches required ordering by the SortPreservingMergeExec above). 
- let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC]", - " UnionExec", - " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet", - " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; + let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); + + assert_snapshot!(test.run(), @r" + Input Plan: + SortPreservingMergeExec: [nullable_col@0 ASC] + UnionExec + SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + + Optimized Plan: + SortPreservingMergeExec: [nullable_col@0 ASC] + UnionExec + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + "); // should adjust sorting in the 
first input of the union such that it is not unnecessarily fine - let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC]", - " UnionExec", - " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet", - " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); - Ok(()) } #[tokio::test] async fn test_union_inputs_different_sorted4() -> Result<()> { let schema = create_test_schema()?; - - let source1 = parquet_exec(&schema); - let sort_exprs1 = vec![ + let source1 = parquet_exec(schema.clone()); + let ordering1 = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), - ]; - let sort_exprs2 = vec![sort_expr("nullable_col", &schema)]; - let sort1 = sort_exec(sort_exprs2.clone(), source1.clone()); - let sort2 = sort_exec(sort_exprs2.clone(), source1); - - let source2 = parquet_exec_sorted(&schema, sort_exprs2); - + ] + .into(); + let ordering2: LexOrdering = [sort_expr("nullable_col", &schema)].into(); + let sort1 = sort_exec(ordering2.clone(), source1.clone()); + let sort2 = sort_exec(ordering2.clone(), source1); + let source2 = parquet_exec_with_sort(schema, vec![ordering2]); let union = union_exec(vec![sort1, source2, sort2]); - let physical_plan = sort_preserving_merge_exec(sort_exprs1, union); + let physical_plan = sort_preserving_merge_exec(ordering1, union); // Ordering requirement of the `SortPreservingMergeExec` is not met. 
// Should modify the plan to ensure that all three inputs to the // `UnionExec` satisfy the ordering, OR add a single sort after // the `UnionExec` (both of which are equally good for this example). - let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", - " UnionExec", - " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet", - " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; - let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", - " UnionExec", - " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", - " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet", - " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r" + Input Plan: + SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC] + UnionExec + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + 
DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + + Optimized Plan: + SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC] + UnionExec + SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet + SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + "); Ok(()) } @@ -460,13 +545,13 @@ async fn test_union_inputs_different_sorted4() -> Result<()> { #[tokio::test] async fn test_union_inputs_different_sorted5() -> Result<()> { let schema = create_test_schema()?; - - let source1 = parquet_exec(&schema); - let sort_exprs1 = vec![ + let source1 = parquet_exec(schema.clone()); + let ordering1 = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), - ]; - let sort_exprs2 = vec![ + ] + .into(); + let ordering2 = [ sort_expr("nullable_col", &schema), sort_expr_options( "non_nullable_col", @@ -476,30 +561,35 @@ async fn test_union_inputs_different_sorted5() -> Result<()> { nulls_first: false, }, ), - ]; - let sort_exprs3 = vec![sort_expr("nullable_col", &schema)]; - let sort1 = sort_exec(sort_exprs1, source1.clone()); - let 
sort2 = sort_exec(sort_exprs2, source1); - + ] + .into(); + let ordering3 = [sort_expr("nullable_col", &schema)].into(); + let sort1 = sort_exec(ordering1, source1.clone()); + let sort2 = sort_exec(ordering2, source1); let union = union_exec(vec![sort1, sort2]); - let physical_plan = sort_preserving_merge_exec(sort_exprs3, union); + let physical_plan = sort_preserving_merge_exec(ordering3, union); // The `UnionExec` doesn't preserve any of the inputs ordering in the // example below. However, we should be able to change the unnecessarily // fine `SortExec`s below with required `SortExec`s that are absolutely necessary. - let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC]", - " UnionExec", - " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", - " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 DESC NULLS LAST], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; - let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC]", - " UnionExec", - " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", - " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r" + Input Plan: + SortPreservingMergeExec: [nullable_col@0 ASC] + UnionExec + SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false] + 
DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 DESC NULLS LAST], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + + Optimized Plan: + SortPreservingMergeExec: [nullable_col@0 ASC] + UnionExec + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + "); Ok(()) } @@ -507,22 +597,20 @@ async fn test_union_inputs_different_sorted5() -> Result<()> { #[tokio::test] async fn test_union_inputs_different_sorted6() -> Result<()> { let schema = create_test_schema()?; - - let source1 = parquet_exec(&schema); - let sort_exprs1 = vec![sort_expr("nullable_col", &schema)]; - let sort1 = sort_exec(sort_exprs1, source1.clone()); - let sort_exprs2 = vec![ + let source1 = parquet_exec(schema.clone()); + let ordering1 = [sort_expr("nullable_col", &schema)].into(); + let sort1 = sort_exec(ordering1, source1.clone()); + let ordering2 = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), - ]; + ] + .into(); let repartition = repartition_exec(source1); - let spm = sort_preserving_merge_exec(sort_exprs2, repartition); - - let parquet_sort_exprs = vec![sort_expr("nullable_col", &schema)]; - let source2 = parquet_exec_sorted(&schema, parquet_sort_exprs.clone()); - + let spm = sort_preserving_merge_exec(ordering2, repartition); + let parquet_ordering: LexOrdering = [sort_expr("nullable_col", &schema)].into(); + let source2 = parquet_exec_with_sort(schema, vec![parquet_ordering.clone()]); let union = union_exec(vec![sort1, source2, spm]); - let physical_plan = 
sort_preserving_merge_exec(parquet_sort_exprs, union); + let physical_plan = sort_preserving_merge_exec(parquet_ordering, union); // The plan is not valid as it is -- the input ordering requirement // of the `SortPreservingMergeExec` under the third child of the @@ -530,25 +618,30 @@ async fn test_union_inputs_different_sorted6() -> Result<()> { // At the same time, this ordering requirement is unnecessarily fine. // The final plan should be valid AND the ordering of the third child // shouldn't be finer than necessary. - let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC]", - " UnionExec", - " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet", - " SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; + let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r" + Input Plan: + SortPreservingMergeExec: [nullable_col@0 ASC] + UnionExec + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet + SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet 
+ + Optimized Plan: + SortPreservingMergeExec: [nullable_col@0 ASC] + UnionExec + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[true] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + "); // Should adjust the requirement in the third input of the union so // that it is not unnecessarily fine. - let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC]", - " UnionExec", - " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet", - " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[true]", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -556,34 +649,30 @@ async fn test_union_inputs_different_sorted6() -> Result<()> { #[tokio::test] async fn test_union_inputs_different_sorted7() -> Result<()> { let schema = create_test_schema()?; - - let source1 = parquet_exec(&schema); - let sort_exprs1 = vec![ + let source1 = parquet_exec(schema.clone()); + let ordering1: LexOrdering = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), - ]; - let 
sort_exprs3 = vec![sort_expr("nullable_col", &schema)]; - let sort1 = sort_exec(sort_exprs1.clone(), source1.clone()); - let sort2 = sort_exec(sort_exprs1, source1); - + ] + .into(); + let sort1 = sort_exec(ordering1.clone(), source1.clone()); + let sort2 = sort_exec(ordering1, source1); let union = union_exec(vec![sort1, sort2]); - let physical_plan = sort_preserving_merge_exec(sort_exprs3, union); + let ordering2 = [sort_expr("nullable_col", &schema)].into(); + let physical_plan = sort_preserving_merge_exec(ordering2, union); // Union has unnecessarily fine ordering below it. We should be able to replace them with absolutely necessary ordering. - let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC]", - " UnionExec", - " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", - " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; - // Union preserves the inputs ordering and we should not change any of the SortExecs under UnionExec - let expected_output = ["SortPreservingMergeExec: [nullable_col@0 ASC]", - " UnionExec", - " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", - " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; - assert_optimized!(expected_input, expected_output, physical_plan, true); + let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r" + Input / Optimized Plan: + SortPreservingMergeExec: [nullable_col@0 
ASC] + UnionExec + SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + "); + // Union preserves the inputs ordering, and we should not change any of the SortExecs under UnionExec Ok(()) } @@ -591,13 +680,13 @@ async fn test_union_inputs_different_sorted7() -> Result<()> { #[tokio::test] async fn test_union_inputs_different_sorted8() -> Result<()> { let schema = create_test_schema()?; - - let source1 = parquet_exec(&schema); - let sort_exprs1 = vec![ + let source1 = parquet_exec(schema.clone()); + let ordering1 = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), - ]; - let sort_exprs2 = vec![ + ] + .into(); + let ordering2 = [ sort_expr_options( "nullable_col", &schema, @@ -614,75 +703,484 @@ async fn test_union_inputs_different_sorted8() -> Result<()> { nulls_first: false, }, ), - ]; - let sort1 = sort_exec(sort_exprs1, source1.clone()); - let sort2 = sort_exec(sort_exprs2, source1); - + ] + .into(); + let sort1 = sort_exec(ordering1, source1.clone()); + let sort2 = sort_exec(ordering2, source1); let physical_plan = union_exec(vec![sort1, sort2]); // The `UnionExec` doesn't preserve any of the inputs ordering in the // example below. 
- let expected_input = ["UnionExec", - " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", - " SortExec: expr=[nullable_col@0 DESC NULLS LAST, non_nullable_col@1 DESC NULLS LAST], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; + let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r" + Input Plan: + UnionExec + SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + SortExec: expr=[nullable_col@0 DESC NULLS LAST, non_nullable_col@1 DESC NULLS LAST], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + + Optimized Plan: + UnionExec + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + "); // Since `UnionExec` doesn't preserve ordering in the plan above. // We shouldn't keep SortExecs in the plan. 
- let expected_optimized = ["UnionExec", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } #[tokio::test] -async fn test_window_multi_path_sort() -> Result<()> { +async fn test_soft_hard_requirements_remove_soft_requirement() -> Result<()> { + let schema = create_test_schema()?; + let source = parquet_exec(schema.clone()); + let sort_exprs = [sort_expr_options( + "nullable_col", + &schema, + SortOptions { + descending: true, + nulls_first: false, + }, + )] + .into(); + let sort = sort_exec(sort_exprs, source); + let partition_bys = &[col("nullable_col", &schema)?]; + let physical_plan = + bounded_window_exec_with_partition("nullable_col", vec![], partition_bys, sort); + + let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r#" + Input Plan: + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + + Optimized Plan: + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + SortExec: expr=[nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + "#); + // TODO When sort pushdown respects to the alternatives, and removes soft SortExecs this should be changed + // let expected_optimized = [ + // "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: 
false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Linear]", + // " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + // ]; + Ok(()) +} + +#[tokio::test] +async fn test_soft_hard_requirements_remove_soft_requirement_without_pushdowns() +-> Result<()> { let schema = create_test_schema()?; + let source = parquet_exec(schema.clone()); + let ordering = [sort_expr_options( + "nullable_col", + &schema, + SortOptions { + descending: true, + nulls_first: false, + }, + )] + .into(); + let sort = sort_exec(ordering, source.clone()); + let proj_exprs = vec![( + Arc::new(BinaryExpr::new( + col("nullable_col", &schema)?, + Operator::Plus, + col("non_nullable_col", &schema)?, + )) as _, + "count".to_string(), + )]; + let partition_bys = &[col("nullable_col", &schema)?]; + let bounded_window = + bounded_window_exec_with_partition("nullable_col", vec![], partition_bys, sort); + let physical_plan = projection_exec(proj_exprs, bounded_window)?; + + let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r#" + Input Plan: + ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as count] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + + Optimized Plan: + ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as count] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + SortExec: expr=[nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + DataSourceExec: 
file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + "#); + // TODO When sort pushdown respects to the alternatives, and removes soft SortExecs this should be changed + // let expected_optimized = [ + // "ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as count]", + // " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Linear]", + // " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + // ]; + + let ordering = [sort_expr_options( + "nullable_col", + &schema, + SortOptions { + descending: true, + nulls_first: false, + }, + )] + .into(); + let sort = sort_exec(ordering, source); + let proj_exprs = vec![( + Arc::new(BinaryExpr::new( + col("nullable_col", &schema)?, + Operator::Plus, + col("non_nullable_col", &schema)?, + )) as _, + "nullable_col".to_string(), + )]; + let partition_bys = &[col("nullable_col", &schema)?]; + let projection = projection_exec(proj_exprs, sort)?; + let physical_plan = bounded_window_exec_with_partition( + "nullable_col", + vec![], + partition_bys, + projection, + ); - let sort_exprs1 = vec![ - sort_expr("nullable_col", &schema), - sort_expr("non_nullable_col", &schema), - ]; - let sort_exprs2 = vec![sort_expr("nullable_col", &schema)]; - // reverse sorting of sort_exprs2 - let sort_exprs3 = vec![sort_expr_options( + let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r#" + Input Plan: + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as nullable_col] + SortExec: expr=[nullable_col@0 DESC NULLS LAST], 
preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + + Optimized Plan: + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + SortExec: expr=[nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as nullable_col] + SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + "#); + // TODO When sort pushdown respects to the alternatives, and removes soft SortExecs this should be changed + // let expected_optimized = [ + // "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Linear]", + // " ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as nullable_col]", + // " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + // ]; + Ok(()) +} + +#[tokio::test] +async fn test_soft_hard_requirements_multiple_soft_requirements() -> Result<()> { + let schema = create_test_schema()?; + let source = parquet_exec(schema.clone()); + let ordering = [sort_expr_options( + "nullable_col", + &schema, + SortOptions { + descending: true, + nulls_first: false, + }, + )] + .into(); + let sort = sort_exec(ordering, source.clone()); + let proj_exprs = vec![( + Arc::new(BinaryExpr::new( + col("nullable_col", &schema)?, + Operator::Plus, + col("non_nullable_col", &schema)?, + )) as _, + "nullable_col".to_string(), + )]; + let partition_bys = &[col("nullable_col", &schema)?]; + let projection = projection_exec(proj_exprs, sort)?; + let 
bounded_window = bounded_window_exec_with_partition( + "nullable_col", + vec![], + partition_bys, + projection, + ); + let physical_plan = bounded_window_exec_with_partition( + "count", + vec![], + partition_bys, + bounded_window, + ); + + let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r#" + Input Plan: + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as nullable_col] + SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + + Optimized Plan: + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + SortExec: expr=[nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as nullable_col] + SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + "#); + // TODO When sort pushdown respects to the alternatives, and removes soft SortExecs this should be changed + // let expected_optimized = [ + // "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Linear]", + // " 
BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Linear]", + // " ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as nullable_col]", + // " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + // ]; + + let ordering = [sort_expr_options( + "nullable_col", + &schema, + SortOptions { + descending: true, + nulls_first: false, + }, + )] + .into(); + let sort = sort_exec(ordering, source); + let proj_exprs = vec![( + Arc::new(BinaryExpr::new( + col("nullable_col", &schema)?, + Operator::Plus, + col("non_nullable_col", &schema)?, + )) as _, + "nullable_col".to_string(), + )]; + let partition_bys = &[col("nullable_col", &schema)?]; + let projection = projection_exec(proj_exprs, sort)?; + let bounded_window = bounded_window_exec_with_partition( + "nullable_col", + vec![], + partition_bys, + projection, + ); + + let ordering2: LexOrdering = [sort_expr_options( + "nullable_col", + &schema, + SortOptions { + descending: true, + nulls_first: false, + }, + )] + .into(); + let sort2 = sort_exec(ordering2.clone(), bounded_window); + let sort3 = sort_exec(ordering2, sort2); + let physical_plan = + bounded_window_exec_with_partition("count", vec![], partition_bys, sort3); + let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r#" + Input Plan: + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false] + SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING 
AND CURRENT ROW], mode=[Sorted] + ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as nullable_col] + SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + + Optimized Plan: + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + SortExec: expr=[nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as nullable_col] + SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + "#); + // TODO When sort pushdown respects to the alternatives, and removes soft SortExecs this should be changed + // let expected_optimized = [ + // "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Linear]", + // " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Linear]", + // " ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as nullable_col]", + // " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + // ]; + Ok(()) +} + +#[tokio::test] +async fn test_soft_hard_requirements_multiple_sorts() -> Result<()> { + let schema = 
create_test_schema()?; + let source = parquet_exec(schema.clone()); + let ordering = [sort_expr_options( "nullable_col", &schema, SortOptions { descending: true, nulls_first: false, }, + )] + .into(); + let sort = sort_exec(ordering, source); + let proj_exprs = vec![( + Arc::new(BinaryExpr::new( + col("nullable_col", &schema)?, + Operator::Plus, + col("non_nullable_col", &schema)?, + )) as _, + "nullable_col".to_string(), )]; - let source1 = parquet_exec_sorted(&schema, sort_exprs1); - let source2 = parquet_exec_sorted(&schema, sort_exprs2); - let sort1 = sort_exec(sort_exprs3.clone(), source1); - let sort2 = sort_exec(sort_exprs3.clone(), source2); + let partition_bys = &[col("nullable_col", &schema)?]; + let projection = projection_exec(proj_exprs, sort)?; + let bounded_window = bounded_window_exec_with_partition( + "nullable_col", + vec![], + partition_bys, + projection, + ); + let ordering2: LexOrdering = [sort_expr_options( + "nullable_col", + &schema, + SortOptions { + descending: true, + nulls_first: false, + }, + )] + .into(); + let sort2 = sort_exec(ordering2.clone(), bounded_window); + let physical_plan = sort_exec(ordering2, sort2); + let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r#" + Input Plan: + SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false] + SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as nullable_col] + SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + + Optimized Plan: + SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[count: 
Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + SortExec: expr=[nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as nullable_col] + SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + "#); + // TODO When sort pushdown respects to the alternatives, and removes soft SortExecs this should be changed + // let expected_optimized = [ + // "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Linear]", + // " ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as nullable_col]", + // " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + // ]; + Ok(()) +} + +#[tokio::test] +async fn test_soft_hard_requirements_with_multiple_soft_requirements_and_output_requirement() +-> Result<()> { + let schema = create_test_schema()?; + let source = parquet_exec(schema.clone()); + let ordering = [sort_expr_options( + "nullable_col", + &schema, + SortOptions { + descending: true, + nulls_first: false, + }, + )] + .into(); + let sort = sort_exec(ordering, source); + let partition_bys1 = &[col("nullable_col", &schema)?]; + let bounded_window = + bounded_window_exec_with_partition("nullable_col", vec![], partition_bys1, sort); + let partition_bys2 = &[col("non_nullable_col", &schema)?]; + let bounded_window2 = bounded_window_exec_with_partition( + "non_nullable_col", + vec![], + partition_bys2, + bounded_window, + ); + let requirement = [PhysicalSortRequirement::new( + col("non_nullable_col", &schema)?, + Some(SortOptions::new(false, true)), + )] + 
.into(); + let physical_plan = Arc::new(OutputRequirementExec::new( + bounded_window2, + Some(OrderingRequirements::new(requirement)), + Distribution::SinglePartition, + None, + )); + let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r#" + Input Plan: + OutputRequirementExec: order_by=[(non_nullable_col@1, asc)], dist_by=SinglePartition + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + + Optimized Plan: + OutputRequirementExec: order_by=[(non_nullable_col@1, asc)], dist_by=SinglePartition + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + SortExec: expr=[non_nullable_col@1 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + SortExec: expr=[nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + "#); + // TODO When sort pushdown respects to the alternatives, and removes soft SortExecs this should be changed + // let expected_optimized = [ + // "OutputRequirementExec", + // " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + // " SortExec: 
expr=[non_nullable_col@1 ASC NULLS LAST], preserve_partitioning=[false]", + // " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Linear]", + // " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + // ]; + Ok(()) +} +#[tokio::test] +async fn test_window_multi_path_sort() -> Result<()> { + let schema = create_test_schema()?; + let ordering1 = [ + sort_expr("nullable_col", &schema), + sort_expr("non_nullable_col", &schema), + ] + .into(); + let ordering2 = [sort_expr("nullable_col", &schema)].into(); + // Reverse of the above + let ordering3: LexOrdering = [sort_expr_options( + "nullable_col", + &schema, + SortOptions { + descending: true, + nulls_first: false, + }, + )] + .into(); + let source1 = parquet_exec_with_sort(schema.clone(), vec![ordering1]); + let source2 = parquet_exec_with_sort(schema, vec![ordering2]); + let sort1 = sort_exec(ordering3.clone(), source1); + let sort2 = sort_exec(ordering3.clone(), source2); let union = union_exec(vec![sort1, sort2]); - let spm = sort_preserving_merge_exec(sort_exprs3.clone(), union); - let physical_plan = bounded_window_exec("nullable_col", sort_exprs3, spm); + let spm = sort_preserving_merge_exec(ordering3.clone(), union); + let physical_plan = bounded_window_exec("nullable_col", ordering3, spm); // The `WindowAggExec` gets its sorting from multiple children jointly. // During the removal of `SortExec`s, it should be able to remove the // corresponding SortExecs together. Also, the inputs of these `SortExec`s // are not necessarily the same to be able to remove them. 
- let expected_input = [ - "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", - " SortPreservingMergeExec: [nullable_col@0 DESC NULLS LAST]", - " UnionExec", - " SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC, non_nullable_col@1 ASC], file_type=parquet", - " SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet"]; - let expected_optimized = [ - "WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " SortPreservingMergeExec: [nullable_col@0 ASC]", - " UnionExec", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC, non_nullable_col@1 ASC], file_type=parquet", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet"]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r#" + Input Plan: + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + SortPreservingMergeExec: [nullable_col@0 DESC NULLS LAST] + UnionExec + SortExec: 
expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC, non_nullable_col@1 ASC], file_type=parquet + SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet + + Optimized Plan: + WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64 }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + SortPreservingMergeExec: [nullable_col@0 ASC] + UnionExec + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC, non_nullable_col@1 ASC], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet + "#); Ok(()) } @@ -690,36 +1188,40 @@ async fn test_window_multi_path_sort() -> Result<()> { #[tokio::test] async fn test_window_multi_path_sort2() -> Result<()> { let schema = create_test_schema()?; - - let sort_exprs1 = LexOrdering::new(vec![ + let ordering1: LexOrdering = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), - ]); - let sort_exprs2 = vec![sort_expr("nullable_col", &schema)]; - let source1 = parquet_exec_sorted(&schema, sort_exprs2.clone()); - let source2 = parquet_exec_sorted(&schema, sort_exprs2.clone()); - let sort1 = sort_exec(sort_exprs1.clone(), source1); - let sort2 = sort_exec(sort_exprs1.clone(), source2); - + ] + .into(); + let ordering2: LexOrdering = [sort_expr("nullable_col", &schema)].into(); + let source1 = parquet_exec_with_sort(schema.clone(), vec![ordering2.clone()]); + let source2 = parquet_exec_with_sort(schema, vec![ordering2.clone()]); + let sort1 = 
sort_exec(ordering1.clone(), source1); + let sort2 = sort_exec(ordering1.clone(), source2); let union = union_exec(vec![sort1, sort2]); - let spm = Arc::new(SortPreservingMergeExec::new(sort_exprs1, union)) as _; - let physical_plan = bounded_window_exec("nullable_col", sort_exprs2, spm); + let spm = Arc::new(SortPreservingMergeExec::new(ordering1, union)) as _; + let physical_plan = bounded_window_exec("nullable_col", ordering2, spm); // The `WindowAggExec` can get its required sorting from the leaf nodes directly. // The unnecessary SortExecs should be removed - let expected_input = ["BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", - " SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", - " UnionExec", - " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet", - " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet"]; - let expected_optimized = ["BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", - " SortPreservingMergeExec: [nullable_col@0 ASC]", - " UnionExec", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet", - " 
DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet"]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r#" + Input Plan: + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC] + UnionExec + SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet + SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet + + Optimized Plan: + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + SortPreservingMergeExec: [nullable_col@0 ASC] + UnionExec + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet + "#); Ok(()) } @@ -727,13 +1229,13 @@ async fn test_window_multi_path_sort2() -> Result<()> { #[tokio::test] async fn test_union_inputs_different_sorted_with_limit() -> Result<()> { let schema = create_test_schema()?; - - let source1 = parquet_exec(&schema); - let sort_exprs1 = vec![ + let source1 = parquet_exec(schema.clone()); + let ordering1 = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", 
&schema), - ]; - let sort_exprs2 = vec![ + ] + .into(); + let ordering2 = [ sort_expr("nullable_col", &schema), sort_expr_options( "non_nullable_col", @@ -743,35 +1245,39 @@ async fn test_union_inputs_different_sorted_with_limit() -> Result<()> { nulls_first: false, }, ), - ]; - let sort_exprs3 = vec![sort_expr("nullable_col", &schema)]; - let sort1 = sort_exec(sort_exprs1, source1.clone()); - - let sort2 = sort_exec(sort_exprs2, source1); - let limit = local_limit_exec(sort2); - let limit = global_limit_exec(limit); - + ] + .into(); + let sort1 = sort_exec(ordering1, source1.clone()); + let sort2 = sort_exec(ordering2, source1); + let limit = local_limit_exec(sort2, 100); + let limit = global_limit_exec(limit, 0, Some(100)); let union = union_exec(vec![sort1, limit]); - let physical_plan = sort_preserving_merge_exec(sort_exprs3, union); + let ordering3 = [sort_expr("nullable_col", &schema)].into(); + let physical_plan = sort_preserving_merge_exec(ordering3, union); + let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); // Should not change the unnecessarily fine `SortExec`s because there is `LimitExec` - let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC]", - " UnionExec", - " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", - " GlobalLimitExec: skip=0, fetch=100", - " LocalLimitExec: fetch=100", - " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 DESC NULLS LAST], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; - let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC]", - " UnionExec", - " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, 
non_nullable_col], file_type=parquet", - " GlobalLimitExec: skip=0, fetch=100", - " LocalLimitExec: fetch=100", - " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 DESC NULLS LAST], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + assert_snapshot!(test.run(), @r" + Input Plan: + SortPreservingMergeExec: [nullable_col@0 ASC] + UnionExec + SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + GlobalLimitExec: skip=0, fetch=100 + LocalLimitExec: fetch=100 + SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 DESC NULLS LAST], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + + Optimized Plan: + SortPreservingMergeExec: [nullable_col@0 ASC] + UnionExec + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + GlobalLimitExec: skip=0, fetch=100 + LocalLimitExec: fetch=100 + SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 DESC NULLS LAST], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + "); Ok(()) } @@ -781,15 +1287,17 @@ async fn test_sort_merge_join_order_by_left() -> Result<()> { let left_schema = create_test_schema()?; let right_schema = create_test_schema2()?; - let left = parquet_exec(&left_schema); - let right = parquet_exec(&right_schema); + let left = parquet_exec(left_schema); + let right = parquet_exec(right_schema); // Join on (nullable_col == col_a) let join_on = vec![( - 
Arc::new(Column::new_with_schema("nullable_col", &left.schema()).unwrap()) as _, - Arc::new(Column::new_with_schema("col_a", &right.schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("nullable_col", &left.schema())?) as _, + Arc::new(Column::new_with_schema("col_a", &right.schema())?) as _, )]; + let settings = Settings::clone_current(); + let join_types = vec![ JoinType::Inner, JoinType::Left, @@ -801,49 +1309,69 @@ async fn test_sort_merge_join_order_by_left() -> Result<()> { for join_type in join_types { let join = sort_merge_join_exec(left.clone(), right.clone(), &join_on, &join_type); - let sort_exprs = vec![ + let ordering = [ sort_expr("nullable_col", &join.schema()), sort_expr("non_nullable_col", &join.schema()), - ]; - let physical_plan = sort_preserving_merge_exec(sort_exprs.clone(), join); + ] + .into(); + let physical_plan = sort_preserving_merge_exec(ordering, join); - let join_plan = format!( - "SortMergeJoin: join_type={join_type}, on=[(nullable_col@0, col_a@0)]" - ); - let join_plan2 = format!( - " SortMergeJoin: join_type={join_type}, on=[(nullable_col@0, col_a@0)]" + let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); + + let mut settings = settings.clone(); + + settings.add_filter( + // join_type={} replace with join_type=... to avoid snapshot name issue + format!("join_type={join_type}").as_str(), + "join_type=...", ); - let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", - join_plan2.as_str(), - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet"]; - let expected_optimized = match join_type { + + insta::allow_duplicates! 
{ + settings.bind( || { + + + match join_type { JoinType::Inner | JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => { // can push down the sort requirements and save 1 SortExec - vec![ - join_plan.as_str(), - " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", - " SortExec: expr=[col_a@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet", - ] + assert_snapshot!(test.run(), @r" + Input Plan: + SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC] + SortMergeJoinExec: join_type=..., on=[(nullable_col@0, col_a@0)] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet + + Optimized Plan: + SortMergeJoinExec: join_type=..., on=[(nullable_col@0, col_a@0)] + SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + SortExec: expr=[col_a@0 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet + "); } _ => { // can not push down the sort requirements - vec![ - "SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", - join_plan2.as_str(), - " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", - " SortExec: expr=[col_a@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet", - ] + 
assert_snapshot!(test.run(), @r" + Input Plan: + SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC] + SortMergeJoinExec: join_type=..., on=[(nullable_col@0, col_a@0)] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet + + Optimized Plan: + SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false] + SortMergeJoinExec: join_type=..., on=[(nullable_col@0, col_a@0)] + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + SortExec: expr=[col_a@0 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet + "); } }; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + }) + } } Ok(()) } @@ -853,15 +1381,17 @@ async fn test_sort_merge_join_order_by_right() -> Result<()> { let left_schema = create_test_schema()?; let right_schema = create_test_schema2()?; - let left = parquet_exec(&left_schema); - let right = parquet_exec(&right_schema); + let left = parquet_exec(left_schema); + let right = parquet_exec(right_schema); // Join on (nullable_col == col_a) let join_on = vec![( - Arc::new(Column::new_with_schema("nullable_col", &left.schema()).unwrap()) as _, - Arc::new(Column::new_with_schema("col_a", &right.schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("nullable_col", &left.schema())?) as _, + Arc::new(Column::new_with_schema("col_a", &right.schema())?) 
as _, )]; + let settings = Settings::clone_current(); + let join_types = vec![ JoinType::Inner, JoinType::Left, @@ -872,50 +1402,83 @@ async fn test_sort_merge_join_order_by_right() -> Result<()> { for join_type in join_types { let join = sort_merge_join_exec(left.clone(), right.clone(), &join_on, &join_type); - let sort_exprs = vec![ + let ordering = [ sort_expr("col_a", &join.schema()), sort_expr("col_b", &join.schema()), - ]; - let physical_plan = sort_preserving_merge_exec(sort_exprs, join); + ] + .into(); + let physical_plan = sort_preserving_merge_exec(ordering, join); - let join_plan = format!( - "SortMergeJoin: join_type={join_type}, on=[(nullable_col@0, col_a@0)]" - ); - let spm_plan = match join_type { - JoinType::RightAnti => "SortPreservingMergeExec: [col_a@0 ASC, col_b@1 ASC]", - _ => "SortPreservingMergeExec: [col_a@2 ASC, col_b@3 ASC]", - }; - let join_plan2 = format!( - " SortMergeJoin: join_type={join_type}, on=[(nullable_col@0, col_a@0)]" + let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); + + let mut settings = settings.clone(); + + settings.add_filter( + // join_type={} replace with join_type=... to avoid snapshot name issue + format!("join_type={join_type}").as_str(), + "join_type=...", ); - let expected_input = [spm_plan, - join_plan2.as_str(), - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet"]; - let expected_optimized = match join_type { - JoinType::Inner | JoinType::Right | JoinType::RightAnti => { + + insta::allow_duplicates! 
{ + settings.bind( || { + + + match join_type { + JoinType::Inner | JoinType::Right => { + // can push down the sort requirements and save 1 SortExec + assert_snapshot!(test.run(), @r" + Input Plan: + SortPreservingMergeExec: [col_a@2 ASC, col_b@3 ASC] + SortMergeJoinExec: join_type=..., on=[(nullable_col@0, col_a@0)] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet + + Optimized Plan: + SortMergeJoinExec: join_type=..., on=[(nullable_col@0, col_a@0)] + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + SortExec: expr=[col_a@0 ASC, col_b@1 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet + "); + } + JoinType::RightAnti => { // can push down the sort requirements and save 1 SortExec - vec![ - join_plan.as_str(), - " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", - " SortExec: expr=[col_a@0 ASC, col_b@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet", - ] + assert_snapshot!(test.run(), @r" + Input Plan: + SortPreservingMergeExec: [col_a@0 ASC, col_b@1 ASC] + SortMergeJoinExec: join_type=..., on=[(nullable_col@0, col_a@0)] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet + + Optimized Plan: + SortMergeJoinExec: join_type=..., on=[(nullable_col@0, col_a@0)] + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + 
DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + SortExec: expr=[col_a@0 ASC, col_b@1 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet + "); } _ => { // can not push down the sort requirements for Left and Full join. - vec![ - "SortExec: expr=[col_a@2 ASC, col_b@3 ASC], preserve_partitioning=[false]", - join_plan2.as_str(), - " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", - " SortExec: expr=[col_a@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet", - ] + assert_snapshot!(test.run(), @r" + Input Plan: + SortPreservingMergeExec: [col_a@2 ASC, col_b@3 ASC] + SortMergeJoinExec: join_type=..., on=[(nullable_col@0, col_a@0)] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet + + Optimized Plan: + SortExec: expr=[col_a@2 ASC, col_b@3 ASC], preserve_partitioning=[false] + SortMergeJoinExec: join_type=..., on=[(nullable_col@0, col_a@0)] + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + SortExec: expr=[col_a@0 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet + "); } }; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + }) + } } Ok(()) } @@ -925,59 +1488,69 @@ async fn test_sort_merge_join_complex_order_by() -> Result<()> { let left_schema = create_test_schema()?; let right_schema = create_test_schema2()?; - let left = 
parquet_exec(&left_schema); - let right = parquet_exec(&right_schema); + let left = parquet_exec(left_schema); + let right = parquet_exec(right_schema); // Join on (nullable_col == col_a) let join_on = vec![( - Arc::new(Column::new_with_schema("nullable_col", &left.schema()).unwrap()) as _, - Arc::new(Column::new_with_schema("col_a", &right.schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("nullable_col", &left.schema())?) as _, + Arc::new(Column::new_with_schema("col_a", &right.schema())?) as _, )]; let join = sort_merge_join_exec(left, right, &join_on, &JoinType::Inner); // order by (col_b, col_a) - let sort_exprs1 = vec![ + let ordering = [ sort_expr("col_b", &join.schema()), sort_expr("col_a", &join.schema()), - ]; - let physical_plan = sort_preserving_merge_exec(sort_exprs1, join.clone()); - - let expected_input = ["SortPreservingMergeExec: [col_b@3 ASC, col_a@2 ASC]", - " SortMergeJoin: join_type=Inner, on=[(nullable_col@0, col_a@0)]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet"]; - + ] + .into(); + let physical_plan = sort_preserving_merge_exec(ordering, join.clone()); + + let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); + + assert_snapshot!(test.run(), @r" + Input Plan: + SortPreservingMergeExec: [col_b@3 ASC, col_a@2 ASC] + SortMergeJoinExec: join_type=Inner, on=[(nullable_col@0, col_a@0)] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet + + Optimized Plan: + SortExec: expr=[col_b@3 ASC, nullable_col@0 ASC], preserve_partitioning=[false] + SortMergeJoinExec: join_type=Inner, on=[(nullable_col@0, col_a@0)] + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + DataSourceExec: 
file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + SortExec: expr=[col_a@0 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet + "); // can not push down the sort requirements, need to add SortExec - let expected_optimized = ["SortExec: expr=[col_b@3 ASC, col_a@2 ASC], preserve_partitioning=[false]", - " SortMergeJoin: join_type=Inner, on=[(nullable_col@0, col_a@0)]", - " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", - " SortExec: expr=[col_a@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet"]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); // order by (nullable_col, col_b, col_a) - let sort_exprs2 = vec![ + let ordering2 = [ sort_expr("nullable_col", &join.schema()), sort_expr("col_b", &join.schema()), sort_expr("col_a", &join.schema()), - ]; - let physical_plan = sort_preserving_merge_exec(sort_exprs2, join); - - let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC, col_b@3 ASC, col_a@2 ASC]", - " SortMergeJoin: join_type=Inner, on=[(nullable_col@0, col_a@0)]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet"]; - - // can not push down the sort requirements, need to add SortExec - let expected_optimized = ["SortExec: expr=[nullable_col@0 ASC, col_b@3 ASC, col_a@2 ASC], preserve_partitioning=[false]", - " SortMergeJoin: join_type=Inner, on=[(nullable_col@0, col_a@0)]", - " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, 
non_nullable_col], file_type=parquet", - " SortExec: expr=[col_a@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet"]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + ] + .into(); + let physical_plan = sort_preserving_merge_exec(ordering2, join); + let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); + + assert_snapshot!(test.run(), @r" + Input Plan: + SortPreservingMergeExec: [nullable_col@0 ASC, col_b@3 ASC, col_a@2 ASC] + SortMergeJoinExec: join_type=Inner, on=[(nullable_col@0, col_a@0)] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet + + Optimized Plan: + SortMergeJoinExec: join_type=Inner, on=[(nullable_col@0, col_a@0)] + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + SortExec: expr=[col_a@0 ASC, col_b@1 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet + "); + // Can push down the sort requirements since col_a = nullable_col Ok(()) } @@ -985,152 +1558,136 @@ async fn test_sort_merge_join_complex_order_by() -> Result<()> { #[tokio::test] async fn test_multilayer_coalesce_partitions() -> Result<()> { let schema = create_test_schema()?; - - let source1 = parquet_exec(&schema); + let source1 = parquet_exec(schema.clone()); let repartition = repartition_exec(source1); - let coalesce = Arc::new(CoalescePartitionsExec::new(repartition)) as _; + let coalesce = coalesce_partitions_exec(repartition) as _; // Add dummy layer propagating Sort above, to test whether sort can be removed from multi layer before let filter = filter_exec( - Arc::new(NotExpr::new( - 
col("non_nullable_col", schema.as_ref()).unwrap(), - )), + Arc::new(NotExpr::new(col("non_nullable_col", schema.as_ref())?)), coalesce, ); - let sort_exprs = vec![sort_expr("nullable_col", &schema)]; - let physical_plan = sort_exec(sort_exprs, filter); + let ordering = [sort_expr("nullable_col", &schema)].into(); + let physical_plan = sort_exec(ordering, filter); // CoalescePartitionsExec and SortExec are not directly consecutive. In this case // we should be able to parallelize Sorting also (given that executors in between don't require) // single partition. - let expected_input = ["SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " FilterExec: NOT non_nullable_col@1", - " CoalescePartitionsExec", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; - let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC]", - " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[true]", - " FilterExec: NOT non_nullable_col@1", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r" + Input Plan: + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + FilterExec: NOT non_nullable_col@1 + CoalescePartitionsExec + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + + Optimized Plan: + SortPreservingMergeExec: [nullable_col@0 ASC] + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[true] + FilterExec: NOT non_nullable_col@1 
+ RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + "); Ok(()) } -#[tokio::test] -async fn test_with_lost_ordering_bounded() -> Result<()> { +fn create_lost_ordering_plan(source_unbounded: bool) -> Result> { let schema = create_test_schema3()?; - let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs); - let repartition_rr = repartition_exec(source); - let repartition_hash = Arc::new(RepartitionExec::try_new( - repartition_rr, - Partitioning::Hash(vec![col("c", &schema).unwrap()], 10), - )?) as _; - let coalesce_partitions = coalesce_partitions_exec(repartition_hash); - let physical_plan = sort_exec(vec![sort_expr("a", &schema)], coalesce_partitions); - - let expected_input = ["SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", - " CoalescePartitionsExec", - " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=csv, has_header=false"]; - let expected_optimized = ["SortPreservingMergeExec: [a@0 ASC]", - " SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", - " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=csv, has_header=false"]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); - - Ok(()) -} - -#[rstest] -#[tokio::test] -async fn test_with_lost_ordering_unbounded_bounded( - #[values(false, true)] source_unbounded: bool, -) -> Result<()> { - let schema = create_test_schema3()?; - let sort_exprs = vec![sort_expr("a", &schema)]; + let 
sort_exprs = [sort_expr("a", &schema)]; // create either bounded or unbounded source let source = if source_unbounded { - stream_exec_ordered(&schema, sort_exprs) + stream_exec_ordered(&schema, sort_exprs.clone().into()) } else { - csv_exec_ordered(&schema, sort_exprs) + csv_exec_sorted(&schema, sort_exprs.clone()) }; let repartition_rr = repartition_exec(source); let repartition_hash = Arc::new(RepartitionExec::try_new( repartition_rr, - Partitioning::Hash(vec![col("c", &schema).unwrap()], 10), + Partitioning::Hash(vec![col("c", &schema)?], 10), )?) as _; let coalesce_partitions = coalesce_partitions_exec(repartition_hash); - let physical_plan = sort_exec(vec![sort_expr("a", &schema)], coalesce_partitions); - - // Expected inputs unbounded and bounded - let expected_input_unbounded = vec![ - "SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", - " CoalescePartitionsExec", - " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC]", - ]; - let expected_input_bounded = vec![ - "SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", - " CoalescePartitionsExec", - " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[file_path]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=csv, has_header=true", - ]; + let physical_plan = sort_exec(sort_exprs.into(), coalesce_partitions); + Ok(physical_plan) +} - // Expected unbounded result (same for with and without flag) - let expected_optimized_unbounded = vec![ - "SortPreservingMergeExec: [a@0 ASC]", - " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC", - " RepartitionExec: 
partitioning=RoundRobinBatch(10), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC]", - ]; +#[tokio::test] +async fn test_with_lost_ordering_unbounded() -> Result<()> { + let physical_plan = create_lost_ordering_plan(true)?; + + let test_no_repartition_sorts = + EnforceSortingTest::new(physical_plan.clone()).with_repartition_sorts(false); + + assert_snapshot!(test_no_repartition_sorts.run(), @r" + Input Plan: + SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + CoalescePartitionsExec + RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true + StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC] + + Optimized Plan: + SortPreservingMergeExec: [a@0 ASC] + RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true + StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC] + "); + + let test_with_repartition_sorts = + EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); + + assert_snapshot!(test_with_repartition_sorts.run(), @r" + Input Plan: + SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + CoalescePartitionsExec + RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true + StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC] + + Optimized Plan: + SortPreservingMergeExec: [a@0 ASC] + RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC + RepartitionExec: 
partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true + StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC] + "); - // Expected bounded results with and without flag - let expected_optimized_bounded = vec![ - "SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", - " CoalescePartitionsExec", - " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[file_path]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=csv, has_header=true", - ]; - let expected_optimized_bounded_parallelize_sort = vec![ - "SortPreservingMergeExec: [a@0 ASC]", - " SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", - " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[file_path]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=csv, has_header=true", - ]; - let (expected_input, expected_optimized, expected_optimized_sort_parallelize) = - if source_unbounded { - ( - expected_input_unbounded, - expected_optimized_unbounded.clone(), - expected_optimized_unbounded, - ) - } else { - ( - expected_input_bounded, - expected_optimized_bounded, - expected_optimized_bounded_parallelize_sort, - ) - }; - assert_optimized!( - expected_input, - expected_optimized, - physical_plan.clone(), - false - ); - assert_optimized!( - expected_input, - expected_optimized_sort_parallelize, - physical_plan, - true - ); + Ok(()) +} + +#[tokio::test] +async fn test_with_lost_ordering_bounded() -> Result<()> { + let physical_plan = create_lost_ordering_plan(false)?; + + let test_no_repartition_sorts = + EnforceSortingTest::new(physical_plan.clone()).with_repartition_sorts(false); + + 
assert_snapshot!(test_no_repartition_sorts.run(), @r" + Input / Optimized Plan: + SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + CoalescePartitionsExec + RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=csv, has_header=false + "); + + let test_with_repartition_sorts = + EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); + + assert_snapshot!(test_with_repartition_sorts.run(), @r" + Input Plan: + SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + CoalescePartitionsExec + RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=csv, has_header=false + + Optimized Plan: + SortPreservingMergeExec: [a@0 ASC] + SortExec: expr=[a@0 ASC], preserve_partitioning=[true] + RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=csv, has_header=false + "); Ok(()) } @@ -1138,21 +1695,21 @@ async fn test_with_lost_ordering_unbounded_bounded( #[tokio::test] async fn test_do_not_pushdown_through_spm() -> Result<()> { let schema = create_test_schema3()?; - let sort_exprs = vec![sort_expr("a", &schema), sort_expr("b", &schema)]; + let sort_exprs = [sort_expr("a", &schema), sort_expr("b", &schema)]; let source = csv_exec_sorted(&schema, sort_exprs.clone()); let repartition_rr = repartition_exec(source); - let spm = sort_preserving_merge_exec(sort_exprs, repartition_rr); - let 
physical_plan = sort_exec(vec![sort_expr("b", &schema)], spm); - - let expected_input = ["SortExec: expr=[b@1 ASC], preserve_partitioning=[false]", - " SortPreservingMergeExec: [a@0 ASC, b@1 ASC]", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC, b@1 ASC], file_type=csv, has_header=false",]; - let expected_optimized = ["SortExec: expr=[b@1 ASC], preserve_partitioning=[false]", - " SortPreservingMergeExec: [a@0 ASC, b@1 ASC]", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC, b@1 ASC], file_type=csv, has_header=false",]; - assert_optimized!(expected_input, expected_optimized, physical_plan, false); + let spm = sort_preserving_merge_exec(sort_exprs.into(), repartition_rr); + let physical_plan = sort_exec([sort_expr("b", &schema)].into(), spm); + + let test = + EnforceSortingTest::new(physical_plan.clone()).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r" + Input / Optimized Plan: + SortExec: expr=[b@1 ASC], preserve_partitioning=[false] + SortPreservingMergeExec: [a@0 ASC, b@1 ASC] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC, b@1 ASC], file_type=csv, has_header=false + "); Ok(()) } @@ -1160,192 +1717,115 @@ async fn test_do_not_pushdown_through_spm() -> Result<()> { #[tokio::test] async fn test_pushdown_through_spm() -> Result<()> { let schema = create_test_schema3()?; - let sort_exprs = vec![sort_expr("a", &schema), sort_expr("b", &schema)]; + let sort_exprs = [sort_expr("a", &schema), sort_expr("b", &schema)]; let source = csv_exec_sorted(&schema, sort_exprs.clone()); let repartition_rr = repartition_exec(source); - let spm = 
sort_preserving_merge_exec(sort_exprs, repartition_rr); + let spm = sort_preserving_merge_exec(sort_exprs.into(), repartition_rr); let physical_plan = sort_exec( - vec![ + [ sort_expr("a", &schema), sort_expr("b", &schema), sort_expr("c", &schema), - ], + ] + .into(), spm, ); - - let expected_input = ["SortExec: expr=[a@0 ASC, b@1 ASC, c@2 ASC], preserve_partitioning=[false]", - " SortPreservingMergeExec: [a@0 ASC, b@1 ASC]", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC, b@1 ASC], file_type=csv, has_header=false",]; - let expected_optimized = ["SortPreservingMergeExec: [a@0 ASC, b@1 ASC]", - " SortExec: expr=[a@0 ASC, b@1 ASC, c@2 ASC], preserve_partitioning=[true]", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC, b@1 ASC], file_type=csv, has_header=false",]; - assert_optimized!(expected_input, expected_optimized, physical_plan, false); - + let test = + EnforceSortingTest::new(physical_plan.clone()).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r" + Input Plan: + SortExec: expr=[a@0 ASC, b@1 ASC, c@2 ASC], preserve_partitioning=[false] + SortPreservingMergeExec: [a@0 ASC, b@1 ASC] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC, b@1 ASC], file_type=csv, has_header=false + + Optimized Plan: + SortPreservingMergeExec: [a@0 ASC, b@1 ASC] + SortExec: expr=[a@0 ASC, b@1 ASC, c@2 ASC], preserve_partitioning=[true] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC, b@1 ASC], file_type=csv, 
has_header=false + "); Ok(()) } #[tokio::test] async fn test_window_multi_layer_requirement() -> Result<()> { let schema = create_test_schema3()?; - let sort_exprs = vec![sort_expr("a", &schema), sort_expr("b", &schema)]; + let sort_exprs = [sort_expr("a", &schema), sort_expr("b", &schema)]; let source = csv_exec_sorted(&schema, vec![]); - let sort = sort_exec(sort_exprs.clone(), source); + let sort = sort_exec(sort_exprs.clone().into(), source); let repartition = repartition_exec(sort); let repartition = spr_repartition_exec(repartition); - let spm = sort_preserving_merge_exec(sort_exprs.clone(), repartition); - + let spm = sort_preserving_merge_exec(sort_exprs.clone().into(), repartition); let physical_plan = bounded_window_exec("a", sort_exprs, spm); - let expected_input = [ - "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", - " SortPreservingMergeExec: [a@0 ASC, b@1 ASC]", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC, b@1 ASC", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " SortExec: expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", - ]; - let expected_optimized = [ - "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", - " SortExec: expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[false]", - " CoalescePartitionsExec", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10", - " 
RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", - ]; - assert_optimized!(expected_input, expected_optimized, physical_plan, false); + let test = + EnforceSortingTest::new(physical_plan.clone()).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r#" + Input Plan: + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + SortPreservingMergeExec: [a@0 ASC, b@1 ASC] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC, b@1 ASC + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true + SortExec: expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + + Optimized Plan: + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + SortPreservingMergeExec: [a@0 ASC, b@1 ASC] + SortExec: expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[true] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + "#); Ok(()) } #[tokio::test] async fn test_not_replaced_with_partial_sort_for_bounded_input() -> Result<()> { let schema = create_test_schema3()?; - let input_sort_exprs = vec![sort_expr("b", &schema), sort_expr("c", &schema)]; - let parquet_input = parquet_exec_sorted(&schema, input_sort_exprs); - + let parquet_ordering = [sort_expr("b", &schema), sort_expr("c", &schema)].into(); + let parquet_input = parquet_exec_with_sort(schema.clone(), vec![parquet_ordering]); let physical_plan = 
sort_exec( - vec![ + [ sort_expr("a", &schema), sort_expr("b", &schema), sort_expr("c", &schema), - ], + ] + .into(), parquet_input, ); - let expected_input = [ - "SortExec: expr=[a@0 ASC, b@1 ASC, c@2 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[b@1 ASC, c@2 ASC], file_type=parquet" - ]; - let expected_no_change = expected_input; - assert_optimized!(expected_input, expected_no_change, physical_plan, false); - Ok(()) -} - -/// Runs the sort enforcement optimizer and asserts the plan -/// against the original and expected plans -/// -/// `$EXPECTED_PLAN_LINES`: input plan -/// `$EXPECTED_OPTIMIZED_PLAN_LINES`: optimized plan -/// `$PLAN`: the plan to optimized -/// `REPARTITION_SORTS`: Flag to set `config.options.optimizer.repartition_sorts` option. -/// `$CASE_NUMBER` (optional): The test case number to print on failure. -macro_rules! assert_optimized { - ($EXPECTED_PLAN_LINES: expr, $EXPECTED_OPTIMIZED_PLAN_LINES: expr, $PLAN: expr, $REPARTITION_SORTS: expr $(, $CASE_NUMBER: expr)?) => { - let mut config = ConfigOptions::new(); - config.optimizer.repartition_sorts = $REPARTITION_SORTS; - - // This file has 4 rules that use tree node, apply these rules as in the - // EnforceSorting::optimize implementation - // After these operations tree nodes should be in a consistent state. - // This code block makes sure that these rules doesn't violate tree node integrity. - { - let plan_requirements = PlanWithCorrespondingSort::new_default($PLAN.clone()); - let adjusted = plan_requirements - .transform_up(ensure_sorting) - .data() - .and_then(check_integrity)?; - // TODO: End state payloads will be checked here. 
- - let new_plan = if config.optimizer.repartition_sorts { - let plan_with_coalesce_partitions = - PlanWithCorrespondingCoalescePartitions::new_default(adjusted.plan); - let parallel = plan_with_coalesce_partitions - .transform_up(parallelize_sorts) - .data() - .and_then(check_integrity)?; - // TODO: End state payloads will be checked here. - parallel.plan - } else { - adjusted.plan - }; - - let plan_with_pipeline_fixer = OrderPreservationContext::new_default(new_plan); - let updated_plan = plan_with_pipeline_fixer - .transform_up(|plan_with_pipeline_fixer| { - replace_with_order_preserving_variants( - plan_with_pipeline_fixer, - false, - true, - &config, - ) - }) - .data() - .and_then(check_integrity)?; - // TODO: End state payloads will be checked here. - - let mut sort_pushdown = SortPushDown::new_default(updated_plan.plan); - assign_initial_requirements(&mut sort_pushdown); - check_integrity(pushdown_sorts(sort_pushdown)?)?; - // TODO: End state payloads will be checked here. - } - - let physical_plan = $PLAN; - let formatted = displayable(physical_plan.as_ref()).indent(true).to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - - let expected_plan_lines: Vec<&str> = $EXPECTED_PLAN_LINES - .iter().map(|s| *s).collect(); - - if expected_plan_lines != actual { - $(println!("\n**Original Plan Mismatch in case {}**", $CASE_NUMBER);)? 
- println!("\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", expected_plan_lines, actual); - assert_eq!(expected_plan_lines, actual); - } - - let expected_optimized_lines: Vec<&str> = $EXPECTED_OPTIMIZED_PLAN_LINES - .iter().map(|s| *s).collect(); + let test = + EnforceSortingTest::new(physical_plan.clone()).with_repartition_sorts(false); - // Run the actual optimizer - let optimized_physical_plan = - EnforceSorting::new().optimize(physical_plan, &config)?; + assert_snapshot!(test.run(), @r" + Input / Optimized Plan: + SortExec: expr=[a@0 ASC, b@1 ASC, c@2 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[b@1 ASC, c@2 ASC], file_type=parquet + "); - // Get string representation of the plan - let actual = get_plan_string(&optimized_physical_plan); - if expected_optimized_lines != actual { - $(println!("\n**Optimized Plan Mismatch in case {}**", $CASE_NUMBER);)? - println!("\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", expected_optimized_lines, actual); - assert_eq!(expected_optimized_lines, actual); - } - }; + Ok(()) } #[tokio::test] async fn test_remove_unnecessary_sort() -> Result<()> { let schema = create_test_schema()?; let source = memory_exec(&schema); - let input = sort_exec(vec![sort_expr("non_nullable_col", &schema)], source); - let physical_plan = sort_exec(vec![sort_expr("nullable_col", &schema)], input); - - let expected_input = [ - "SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " SortExec: expr=[non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ]; - let expected_optimized = [ - "SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + let input = sort_exec([sort_expr("non_nullable_col", &schema)].into(), source); + let physical_plan = 
sort_exec([sort_expr("nullable_col", &schema)].into(), input); + + let test = + EnforceSortingTest::new(physical_plan.clone()).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r" + Input Plan: + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + SortExec: expr=[non_nullable_col@1 ASC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] + + Optimized Plan: + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] + "); Ok(()) } @@ -1354,58 +1834,52 @@ async fn test_remove_unnecessary_sort() -> Result<()> { async fn test_remove_unnecessary_sort_window_multilayer() -> Result<()> { let schema = create_test_schema()?; let source = memory_exec(&schema); - - let sort_exprs = vec![sort_expr_options( + let ordering: LexOrdering = [sort_expr_options( "non_nullable_col", &source.schema(), SortOptions { descending: true, nulls_first: true, }, - )]; - let sort = sort_exec(sort_exprs.clone(), source); - // Add dummy layer propagating Sort above, to test whether sort can be removed from multi layer before - let coalesce_batches = coalesce_batches_exec(sort); - - let window_agg = - bounded_window_exec("non_nullable_col", sort_exprs, coalesce_batches); - - let sort_exprs = vec![sort_expr_options( + )] + .into(); + let sort = sort_exec(ordering.clone(), source); + let window_agg = bounded_window_exec("non_nullable_col", ordering, sort); + let ordering2: LexOrdering = [sort_expr_options( "non_nullable_col", &window_agg.schema(), SortOptions { descending: false, nulls_first: false, }, - )]; - - let sort = sort_exec(sort_exprs.clone(), window_agg); - + )] + .into(); + let sort = sort_exec(ordering2.clone(), window_agg); // Add dummy layer propagating Sort above, to test whether sort can be removed from multi layer before let filter = filter_exec( - Arc::new(NotExpr::new( - col("non_nullable_col", schema.as_ref()).unwrap(), - )), + 
Arc::new(NotExpr::new(col("non_nullable_col", schema.as_ref())?)), sort, ); - - let physical_plan = bounded_window_exec("non_nullable_col", sort_exprs, filter); - - let expected_input = ["BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", - " FilterExec: NOT non_nullable_col@1", - " SortExec: expr=[non_nullable_col@1 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", - " CoalesceBatchesExec: target_batch_size=128", - " SortExec: expr=[non_nullable_col@1 DESC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]"]; - - let expected_optimized = ["WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " FilterExec: NOT non_nullable_col@1", - " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", - " CoalesceBatchesExec: target_batch_size=128", - " SortExec: expr=[non_nullable_col@1 DESC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]"]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + let physical_plan = bounded_window_exec("non_nullable_col", ordering2, 
filter); + + let test = + EnforceSortingTest::new(physical_plan.clone()).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r#" + Input Plan: + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + FilterExec: NOT non_nullable_col@1 + SortExec: expr=[non_nullable_col@1 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + SortExec: expr=[non_nullable_col@1 DESC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] + + Optimized Plan: + WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64 }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + FilterExec: NOT non_nullable_col@1 + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + SortExec: expr=[non_nullable_col@1 DESC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] + "#); Ok(()) } @@ -1414,20 +1888,20 @@ async fn test_remove_unnecessary_sort_window_multilayer() -> Result<()> { async fn test_add_required_sort() -> Result<()> { let schema = create_test_schema()?; let source = memory_exec(&schema); + let ordering = [sort_expr("nullable_col", &schema)].into(); + let physical_plan = sort_preserving_merge_exec(ordering, source); - let sort_exprs = vec![sort_expr("nullable_col", &schema)]; - - let physical_plan = sort_preserving_merge_exec(sort_exprs, source); + let test = + EnforceSortingTest::new(physical_plan.clone()).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r" + Input Plan: + SortPreservingMergeExec: [nullable_col@0 ASC] + DataSourceExec: partitions=1, partition_sizes=[0] - let expected_input = [ - "SortPreservingMergeExec: [nullable_col@0 ASC]", - " 
DataSourceExec: partitions=1, partition_sizes=[0]", - ]; - let expected_optimized = [ - "SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + Optimized Plan: + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] + "); Ok(()) } @@ -1436,25 +1910,26 @@ async fn test_add_required_sort() -> Result<()> { async fn test_remove_unnecessary_sort1() -> Result<()> { let schema = create_test_schema()?; let source = memory_exec(&schema); - let sort_exprs = vec![sort_expr("nullable_col", &schema)]; - let sort = sort_exec(sort_exprs.clone(), source); - let spm = sort_preserving_merge_exec(sort_exprs, sort); - - let sort_exprs = vec![sort_expr("nullable_col", &schema)]; - let sort = sort_exec(sort_exprs.clone(), spm); - let physical_plan = sort_preserving_merge_exec(sort_exprs, sort); - let expected_input = [ - "SortPreservingMergeExec: [nullable_col@0 ASC]", - " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " SortPreservingMergeExec: [nullable_col@0 ASC]", - " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ]; - let expected_optimized = [ - "SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + let ordering: LexOrdering = [sort_expr("nullable_col", &schema)].into(); + let sort = sort_exec(ordering.clone(), source); + let spm = sort_preserving_merge_exec(ordering.clone(), sort); + let sort = sort_exec(ordering.clone(), spm); + let physical_plan = sort_preserving_merge_exec(ordering, sort); + + let test = + EnforceSortingTest::new(physical_plan.clone()).with_repartition_sorts(true); + 
assert_snapshot!(test.run(), @r" + Input Plan: + SortPreservingMergeExec: [nullable_col@0 ASC] + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + SortPreservingMergeExec: [nullable_col@0 ASC] + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] + + Optimized Plan: + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] + "); Ok(()) } @@ -1463,38 +1938,38 @@ async fn test_remove_unnecessary_sort1() -> Result<()> { async fn test_remove_unnecessary_sort2() -> Result<()> { let schema = create_test_schema()?; let source = memory_exec(&schema); - let sort_exprs = vec![sort_expr("non_nullable_col", &schema)]; - let sort = sort_exec(sort_exprs.clone(), source); - let spm = sort_preserving_merge_exec(sort_exprs, sort); - - let sort_exprs = vec![ + let ordering: LexOrdering = [sort_expr("non_nullable_col", &schema)].into(); + let sort = sort_exec(ordering.clone(), source); + let spm = sort_preserving_merge_exec(ordering, sort); + let ordering2: LexOrdering = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), - ]; - let sort2 = sort_exec(sort_exprs.clone(), spm); - let spm2 = sort_preserving_merge_exec(sort_exprs, sort2); - - let sort_exprs = vec![sort_expr("nullable_col", &schema)]; - let sort3 = sort_exec(sort_exprs, spm2); + ] + .into(); + let sort2 = sort_exec(ordering2.clone(), spm); + let spm2 = sort_preserving_merge_exec(ordering2, sort2); + let ordering3 = [sort_expr("nullable_col", &schema)].into(); + let sort3 = sort_exec(ordering3, spm2); let physical_plan = repartition_exec(repartition_exec(sort3)); - let expected_input = [ - "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " SortPreservingMergeExec: 
[nullable_col@0 ASC, non_nullable_col@1 ASC]", - " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", - " SortPreservingMergeExec: [non_nullable_col@1 ASC]", - " SortExec: expr=[non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ]; - - let expected_optimized = [ - "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + let test = + EnforceSortingTest::new(physical_plan.clone()).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r" + Input Plan: + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC] + SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false] + SortPreservingMergeExec: [non_nullable_col@1 ASC] + SortExec: expr=[non_nullable_col@1 ASC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] + + Optimized Plan: + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: partitions=1, partition_sizes=[0] + "); Ok(()) } @@ -1503,43 +1978,43 @@ async fn test_remove_unnecessary_sort2() -> Result<()> { async fn test_remove_unnecessary_sort3() -> Result<()> { let schema = create_test_schema()?; let source = memory_exec(&schema); - let sort_exprs = vec![sort_expr("non_nullable_col", &schema)]; - let sort = sort_exec(sort_exprs.clone(), source); - let spm = sort_preserving_merge_exec(sort_exprs, sort); 
- - let sort_exprs = LexOrdering::new(vec![ + let ordering: LexOrdering = [sort_expr("non_nullable_col", &schema)].into(); + let sort = sort_exec(ordering.clone(), source); + let spm = sort_preserving_merge_exec(ordering, sort); + let repartition_exec = repartition_exec(spm); + let ordering2: LexOrdering = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), - ]); - let repartition_exec = repartition_exec(spm); + ] + .into(); let sort2 = Arc::new( - SortExec::new(sort_exprs.clone(), repartition_exec) + SortExec::new(ordering2.clone(), repartition_exec) .with_preserve_partitioning(true), ) as _; - let spm2 = sort_preserving_merge_exec(sort_exprs, sort2); - + let spm2 = sort_preserving_merge_exec(ordering2, sort2); let physical_plan = aggregate_exec(spm2); // When removing a `SortPreservingMergeExec`, make sure that partitioning // requirements are not violated. In some cases, we may need to replace // it with a `CoalescePartitionsExec` instead of directly removing it. 
- let expected_input = [ - "AggregateExec: mode=Final, gby=[], aggr=[]", - " SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", - " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[true]", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " SortPreservingMergeExec: [non_nullable_col@1 ASC]", - " SortExec: expr=[non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ]; - - let expected_optimized = [ - "AggregateExec: mode=Final, gby=[], aggr=[]", - " CoalescePartitionsExec", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + let test = + EnforceSortingTest::new(physical_plan.clone()).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r" + Input Plan: + AggregateExec: mode=Final, gby=[], aggr=[] + SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC] + SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[true] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true + SortPreservingMergeExec: [non_nullable_col@1 ASC] + SortExec: expr=[non_nullable_col@1 ASC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] + + Optimized Plan: + AggregateExec: mode=Final, gby=[], aggr=[] + CoalescePartitionsExec + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: partitions=1, partition_sizes=[0] + "); Ok(()) } @@ -1548,52 +2023,51 @@ async fn test_remove_unnecessary_sort3() -> Result<()> { async fn test_remove_unnecessary_sort4() -> Result<()> { let schema = create_test_schema()?; let source1 = repartition_exec(memory_exec(&schema)); - let source2 = repartition_exec(memory_exec(&schema)); let union = 
union_exec(vec![source1, source2]); - - let sort_exprs = LexOrdering::new(vec![sort_expr("non_nullable_col", &schema)]); - // let sort = sort_exec(sort_exprs.clone(), union); - let sort = Arc::new( - SortExec::new(sort_exprs.clone(), union).with_preserve_partitioning(true), - ) as _; - let spm = sort_preserving_merge_exec(sort_exprs, sort); - + let ordering: LexOrdering = [sort_expr("non_nullable_col", &schema)].into(); + let sort = + Arc::new(SortExec::new(ordering.clone(), union).with_preserve_partitioning(true)) + as _; + let spm = sort_preserving_merge_exec(ordering, sort); let filter = filter_exec( - Arc::new(NotExpr::new( - col("non_nullable_col", schema.as_ref()).unwrap(), - )), + Arc::new(NotExpr::new(col("non_nullable_col", schema.as_ref())?)), spm, ); - - let sort_exprs = vec![ + let ordering2 = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), - ]; - let physical_plan = sort_exec(sort_exprs, filter); + ] + .into(); + let physical_plan = sort_exec(ordering2, filter); // When removing a `SortPreservingMergeExec`, make sure that partitioning // requirements are not violated. In some cases, we may need to replace // it with a `CoalescePartitionsExec` instead of directly removing it. 
- let expected_input = ["SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", - " FilterExec: NOT non_nullable_col@1", - " SortPreservingMergeExec: [non_nullable_col@1 ASC]", - " SortExec: expr=[non_nullable_col@1 ASC], preserve_partitioning=[true]", - " UnionExec", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[0]", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[0]"]; - - let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", - " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[true]", - " FilterExec: NOT non_nullable_col@1", - " UnionExec", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[0]", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[0]"]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + let test = + EnforceSortingTest::new(physical_plan.clone()).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r" + Input Plan: + SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false] + FilterExec: NOT non_nullable_col@1 + SortPreservingMergeExec: [non_nullable_col@1 ASC] + SortExec: expr=[non_nullable_col@1 ASC], preserve_partitioning=[true] + UnionExec + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: partitions=1, partition_sizes=[0] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: partitions=1, partition_sizes=[0] + + Optimized Plan: + SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC] + SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], 
preserve_partitioning=[true] + FilterExec: NOT non_nullable_col@1 + UnionExec + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: partitions=1, partition_sizes=[0] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: partitions=1, partition_sizes=[0] + "); Ok(()) } @@ -1602,31 +2076,31 @@ async fn test_remove_unnecessary_sort4() -> Result<()> { async fn test_remove_unnecessary_sort6() -> Result<()> { let schema = create_test_schema()?; let source = memory_exec(&schema); - let input = Arc::new( - SortExec::new( - LexOrdering::new(vec![sort_expr("non_nullable_col", &schema)]), - source, - ) - .with_fetch(Some(2)), + let input = sort_exec_with_fetch( + [sort_expr("non_nullable_col", &schema)].into(), + Some(2), + source, ); let physical_plan = sort_exec( - vec![ + [ sort_expr("non_nullable_col", &schema), sort_expr("nullable_col", &schema), - ], + ] + .into(), input, ); - - let expected_input = [ - "SortExec: expr=[non_nullable_col@1 ASC, nullable_col@0 ASC], preserve_partitioning=[false]", - " SortExec: TopK(fetch=2), expr=[non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ]; - let expected_optimized = [ - "SortExec: TopK(fetch=2), expr=[non_nullable_col@1 ASC, nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + let test = + EnforceSortingTest::new(physical_plan.clone()).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r" + Input Plan: + SortExec: expr=[non_nullable_col@1 ASC, nullable_col@0 ASC], preserve_partitioning=[false] + SortExec: TopK(fetch=2), expr=[non_nullable_col@1 ASC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] + + Optimized Plan: + SortExec: TopK(fetch=2), expr=[non_nullable_col@1 ASC, nullable_col@0 ASC], 
preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] + "); Ok(()) } @@ -1635,33 +2109,33 @@ async fn test_remove_unnecessary_sort6() -> Result<()> { async fn test_remove_unnecessary_sort7() -> Result<()> { let schema = create_test_schema()?; let source = memory_exec(&schema); - let input = Arc::new(SortExec::new( - LexOrdering::new(vec![ + let input = sort_exec( + [ sort_expr("non_nullable_col", &schema), sort_expr("nullable_col", &schema), - ]), + ] + .into(), source, - )); + ); + let physical_plan = sort_exec_with_fetch( + [sort_expr("non_nullable_col", &schema)].into(), + Some(2), + input, + ); - let physical_plan = Arc::new( - SortExec::new( - LexOrdering::new(vec![sort_expr("non_nullable_col", &schema)]), - input, - ) - .with_fetch(Some(2)), - ) as Arc; + let test = + EnforceSortingTest::new(physical_plan.clone()).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r" + Input Plan: + SortExec: TopK(fetch=2), expr=[non_nullable_col@1 ASC], preserve_partitioning=[false], sort_prefix=[non_nullable_col@1 ASC] + SortExec: expr=[non_nullable_col@1 ASC, nullable_col@0 ASC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] - let expected_input = [ - "SortExec: TopK(fetch=2), expr=[non_nullable_col@1 ASC], preserve_partitioning=[false], sort_prefix=[non_nullable_col@1 ASC]", - " SortExec: expr=[non_nullable_col@1 ASC, nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ]; - let expected_optimized = [ - "GlobalLimitExec: skip=0, fetch=2", - " SortExec: expr=[non_nullable_col@1 ASC, nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + Optimized Plan: + GlobalLimitExec: skip=0, fetch=2 + SortExec: expr=[non_nullable_col@1 ASC, nullable_col@0 ASC], preserve_partitioning=[false] + DataSourceExec: 
partitions=1, partition_sizes=[0] + "); Ok(()) } @@ -1670,31 +2144,31 @@ async fn test_remove_unnecessary_sort7() -> Result<()> { async fn test_remove_unnecessary_sort8() -> Result<()> { let schema = create_test_schema()?; let source = memory_exec(&schema); - let input = Arc::new(SortExec::new( - LexOrdering::new(vec![sort_expr("non_nullable_col", &schema)]), - source, - )); + let input = sort_exec([sort_expr("non_nullable_col", &schema)].into(), source); let limit = Arc::new(LocalLimitExec::new(input, 2)); let physical_plan = sort_exec( - vec![ + [ sort_expr("non_nullable_col", &schema), sort_expr("nullable_col", &schema), - ], + ] + .into(), limit, ); - let expected_input = [ - "SortExec: expr=[non_nullable_col@1 ASC, nullable_col@0 ASC], preserve_partitioning=[false]", - " LocalLimitExec: fetch=2", - " SortExec: expr=[non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ]; - let expected_optimized = [ - "LocalLimitExec: fetch=2", - " SortExec: TopK(fetch=2), expr=[non_nullable_col@1 ASC, nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + let test = + EnforceSortingTest::new(physical_plan.clone()).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r" + Input Plan: + SortExec: expr=[non_nullable_col@1 ASC, nullable_col@0 ASC], preserve_partitioning=[false] + LocalLimitExec: fetch=2 + SortExec: expr=[non_nullable_col@1 ASC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] + + Optimized Plan: + LocalLimitExec: fetch=2 + SortExec: TopK(fetch=2), expr=[non_nullable_col@1 ASC, nullable_col@0 ASC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] + "); Ok(()) } @@ -1703,27 +2177,19 @@ async fn test_remove_unnecessary_sort8() -> Result<()> { async fn test_do_not_pushdown_through_limit() -> 
Result<()> { let schema = create_test_schema()?; let source = memory_exec(&schema); - // let input = sort_exec(vec![sort_expr("non_nullable_col", &schema)], source); - let input = Arc::new(SortExec::new( - LexOrdering::new(vec![sort_expr("non_nullable_col", &schema)]), - source, - )); + let input = sort_exec([sort_expr("non_nullable_col", &schema)].into(), source); let limit = Arc::new(GlobalLimitExec::new(input, 0, Some(5))) as _; - let physical_plan = sort_exec(vec![sort_expr("nullable_col", &schema)], limit); - - let expected_input = [ - "SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " GlobalLimitExec: skip=0, fetch=5", - " SortExec: expr=[non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ]; - let expected_optimized = [ - "SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " GlobalLimitExec: skip=0, fetch=5", - " SortExec: expr=[non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + let physical_plan = sort_exec([sort_expr("nullable_col", &schema)].into(), limit); + + let test = + EnforceSortingTest::new(physical_plan.clone()).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r" + Input / Optimized Plan: + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + GlobalLimitExec: skip=0, fetch=5 + SortExec: expr=[non_nullable_col@1 ASC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] + "); Ok(()) } @@ -1732,24 +2198,25 @@ async fn test_do_not_pushdown_through_limit() -> Result<()> { async fn test_remove_unnecessary_spm1() -> Result<()> { let schema = create_test_schema()?; let source = memory_exec(&schema); - let input = - sort_preserving_merge_exec(vec![sort_expr("non_nullable_col", &schema)], source); - let input2 = - 
sort_preserving_merge_exec(vec![sort_expr("non_nullable_col", &schema)], input); + let ordering: LexOrdering = [sort_expr("non_nullable_col", &schema)].into(); + let input = sort_preserving_merge_exec(ordering.clone(), source); + let input2 = sort_preserving_merge_exec(ordering, input); let physical_plan = - sort_preserving_merge_exec(vec![sort_expr("nullable_col", &schema)], input2); - - let expected_input = [ - "SortPreservingMergeExec: [nullable_col@0 ASC]", - " SortPreservingMergeExec: [non_nullable_col@1 ASC]", - " SortPreservingMergeExec: [non_nullable_col@1 ASC]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ]; - let expected_optimized = [ - "SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + sort_preserving_merge_exec([sort_expr("nullable_col", &schema)].into(), input2); + + let test = + EnforceSortingTest::new(physical_plan.clone()).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r" + Input Plan: + SortPreservingMergeExec: [nullable_col@0 ASC] + SortPreservingMergeExec: [non_nullable_col@1 ASC] + SortPreservingMergeExec: [non_nullable_col@1 ASC] + DataSourceExec: partitions=1, partition_sizes=[0] + + Optimized Plan: + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] + "); Ok(()) } @@ -1759,21 +2226,22 @@ async fn test_remove_unnecessary_spm2() -> Result<()> { let schema = create_test_schema()?; let source = memory_exec(&schema); let input = sort_preserving_merge_exec_with_fetch( - vec![sort_expr("non_nullable_col", &schema)], + [sort_expr("non_nullable_col", &schema)].into(), source, 100, ); - let expected_input = [ - "SortPreservingMergeExec: [non_nullable_col@1 ASC], fetch=100", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ]; - let expected_optimized = [ - "LocalLimitExec: fetch=100", - " 
SortExec: expr=[non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ]; - assert_optimized!(expected_input, expected_optimized, input, true); + let test = EnforceSortingTest::new(input.clone()).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r" + Input Plan: + SortPreservingMergeExec: [non_nullable_col@1 ASC], fetch=100 + DataSourceExec: partitions=1, partition_sizes=[0] + + Optimized Plan: + LocalLimitExec: fetch=100 + SortExec: expr=[non_nullable_col@1 ASC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] + "); Ok(()) } @@ -1782,22 +2250,25 @@ async fn test_remove_unnecessary_spm2() -> Result<()> { async fn test_change_wrong_sorting() -> Result<()> { let schema = create_test_schema()?; let source = memory_exec(&schema); - let sort_exprs = vec![ + let sort_exprs = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), ]; - let sort = sort_exec(vec![sort_exprs[0].clone()], source); - let physical_plan = sort_preserving_merge_exec(sort_exprs, sort); - let expected_input = [ - "SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", - " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ]; - let expected_optimized = [ - "SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + let sort = sort_exec([sort_exprs[0].clone()].into(), source); + let physical_plan = sort_preserving_merge_exec(sort_exprs.into(), sort); + + let test = + EnforceSortingTest::new(physical_plan.clone()).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r" + Input Plan: + SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC] + SortExec: expr=[nullable_col@0 ASC], 
preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] + + Optimized Plan: + SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] + "); Ok(()) } @@ -1806,25 +2277,26 @@ async fn test_change_wrong_sorting() -> Result<()> { async fn test_change_wrong_sorting2() -> Result<()> { let schema = create_test_schema()?; let source = memory_exec(&schema); - let sort_exprs = vec![ + let sort_exprs = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), ]; - let spm1 = sort_preserving_merge_exec(sort_exprs.clone(), source); - let sort2 = sort_exec(vec![sort_exprs[0].clone()], spm1); - let physical_plan = sort_preserving_merge_exec(vec![sort_exprs[1].clone()], sort2); - - let expected_input = [ - "SortPreservingMergeExec: [non_nullable_col@1 ASC]", - " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ]; - let expected_optimized = [ - "SortExec: expr=[non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + let spm1 = sort_preserving_merge_exec(sort_exprs.clone().into(), source); + let sort2 = sort_exec([sort_exprs[0].clone()].into(), spm1); + let physical_plan = sort_preserving_merge_exec([sort_exprs[1].clone()].into(), sort2); + let test = + EnforceSortingTest::new(physical_plan.clone()).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r" + Input Plan: + SortPreservingMergeExec: [non_nullable_col@1 ASC] + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC] + DataSourceExec: partitions=1, partition_sizes=[0] + + Optimized Plan: + SortExec: 
expr=[non_nullable_col@1 ASC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] + "); Ok(()) } @@ -1833,32 +2305,34 @@ async fn test_change_wrong_sorting2() -> Result<()> { async fn test_multiple_sort_window_exec() -> Result<()> { let schema = create_test_schema()?; let source = memory_exec(&schema); - - let sort_exprs1 = vec![sort_expr("nullable_col", &schema)]; - let sort_exprs2 = vec![ + let ordering1 = [sort_expr("nullable_col", &schema)]; + let sort1 = sort_exec(ordering1.clone().into(), source); + let window_agg1 = bounded_window_exec("non_nullable_col", ordering1.clone(), sort1); + let ordering2 = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), ]; - - let sort1 = sort_exec(sort_exprs1.clone(), source); - let window_agg1 = bounded_window_exec("non_nullable_col", sort_exprs1.clone(), sort1); - let window_agg2 = bounded_window_exec("non_nullable_col", sort_exprs2, window_agg1); - // let filter_exec = sort_exec; - let physical_plan = bounded_window_exec("non_nullable_col", sort_exprs1, window_agg2); - - let expected_input = ["BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", - " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", - " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", - " SortExec: expr=[nullable_col@0 ASC], 
preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]"]; - - let expected_optimized = ["BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", - " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", - " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", - " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]"]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + let window_agg2 = bounded_window_exec("non_nullable_col", ordering2, window_agg1); + let physical_plan = bounded_window_exec("non_nullable_col", ordering1, window_agg2); + + let test = + EnforceSortingTest::new(physical_plan.clone()).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r#" + Input Plan: + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + SortExec: 
expr=[nullable_col@0 ASC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] + + Optimized Plan: + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] + "#); Ok(()) } @@ -1871,47 +2345,38 @@ async fn test_multiple_sort_window_exec() -> Result<()> { // EnforceDistribution may invalidate ordering invariant. async fn test_commutativity() -> Result<()> { let schema = create_test_schema()?; - let config = ConfigOptions::new(); - let memory_exec = memory_exec(&schema); - let sort_exprs = LexOrdering::new(vec![sort_expr("nullable_col", &schema)]); + let sort_exprs = [sort_expr("nullable_col", &schema)]; let window = bounded_window_exec("nullable_col", sort_exprs.clone(), memory_exec); let repartition = repartition_exec(window); + let orig_plan = sort_exec(sort_exprs.into(), repartition); - let orig_plan = - Arc::new(SortExec::new(sort_exprs, repartition)) as Arc; - let actual = get_plan_string(&orig_plan); - let expected_input = vec![ - "SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", - " DataSourceExec: partitions=1, 
partition_sizes=[0]", - ]; - assert_eq!( - expected_input, actual, - "\n**Original Plan Mismatch\n\nexpected:\n\n{expected_input:#?}\nactual:\n\n{actual:#?}\n\n" - ); + assert_snapshot!(displayable(orig_plan.as_ref()).indent(true), @r#" + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: partitions=1, partition_sizes=[0] + "#); - let mut plan = orig_plan.clone(); + let config = ConfigOptions::new(); let rules = vec![ Arc::new(EnforceDistribution::new()) as Arc, Arc::new(EnforceSorting::new()) as Arc, ]; + let mut first_plan = orig_plan.clone(); for rule in rules { - plan = rule.optimize(plan, &config)?; + first_plan = rule.optimize(first_plan, &config)?; } - let first_plan = plan.clone(); - let mut plan = orig_plan.clone(); let rules = vec![ Arc::new(EnforceSorting::new()) as Arc, Arc::new(EnforceDistribution::new()) as Arc, Arc::new(EnforceSorting::new()) as Arc, ]; + let mut second_plan = orig_plan.clone(); for rule in rules { - plan = rule.optimize(plan, &config)?; + second_plan = rule.optimize(second_plan, &config)?; } - let second_plan = plan.clone(); assert_eq!(get_plan_string(&first_plan), get_plan_string(&second_plan)); Ok(()) @@ -1922,35 +2387,37 @@ async fn test_coalesce_propagate() -> Result<()> { let schema = create_test_schema()?; let source = memory_exec(&schema); let repartition = repartition_exec(source); - let coalesce_partitions = Arc::new(CoalescePartitionsExec::new(repartition)); + let coalesce_partitions = coalesce_partitions_exec(repartition); let repartition = repartition_exec(coalesce_partitions); - let sort_exprs = LexOrdering::new(vec![sort_expr("nullable_col", &schema)]); + let ordering: LexOrdering = [sort_expr("nullable_col", &schema)].into(); // Add local sort let sort = 
Arc::new( - SortExec::new(sort_exprs.clone(), repartition).with_preserve_partitioning(true), + SortExec::new(ordering.clone(), repartition).with_preserve_partitioning(true), ) as _; - let spm = sort_preserving_merge_exec(sort_exprs.clone(), sort); - let sort = sort_exec(sort_exprs, spm); + let spm = sort_preserving_merge_exec(ordering.clone(), sort); + let sort = sort_exec(ordering, spm); let physical_plan = sort.clone(); // Sort Parallelize rule should end Coalesce + Sort linkage when Sort is Global Sort // Also input plan is not valid as it is. We need to add SortExec before SortPreservingMergeExec. - let expected_input = [ - "SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " SortPreservingMergeExec: [nullable_col@0 ASC]", - " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[true]", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " CoalescePartitionsExec", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ]; - let expected_optimized = [ - "SortPreservingMergeExec: [nullable_col@0 ASC]", - " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[true]", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + let test = + EnforceSortingTest::new(physical_plan.clone()).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r" + Input Plan: + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + SortPreservingMergeExec: [nullable_col@0 ASC] + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[true] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + CoalescePartitionsExec + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: partitions=1, partition_sizes=[0] + + Optimized 
Plan: + SortPreservingMergeExec: [nullable_col@0 ASC] + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[true] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: partitions=1, partition_sizes=[0] + "); Ok(()) } @@ -1958,1425 +2425,167 @@ async fn test_coalesce_propagate() -> Result<()> { #[tokio::test] async fn test_replace_with_partial_sort2() -> Result<()> { let schema = create_test_schema3()?; - let input_sort_exprs = vec![sort_expr("a", &schema), sort_expr("c", &schema)]; - let unbounded_input = stream_exec_ordered(&schema, input_sort_exprs); - + let input_ordering = [sort_expr("a", &schema), sort_expr("c", &schema)].into(); + let unbounded_input = stream_exec_ordered(&schema, input_ordering); let physical_plan = sort_exec( - vec![ + [ sort_expr("a", &schema), sort_expr("c", &schema), sort_expr("d", &schema), - ], + ] + .into(), unbounded_input, ); + let test = + EnforceSortingTest::new(physical_plan.clone()).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r" + Input Plan: + SortExec: expr=[a@0 ASC, c@2 ASC, d@3 ASC], preserve_partitioning=[false] + StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC, c@2 ASC] + + Optimized Plan: + PartialSortExec: expr=[a@0 ASC, c@2 ASC, d@3 ASC], common_prefix_length=[2] + StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC, c@2 ASC] + "); - let expected_input = [ - "SortExec: expr=[a@0 ASC, c@2 ASC, d@3 ASC], preserve_partitioning=[false]", - " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC, c@2 ASC]" - ]; - // let optimized - let expected_optimized = [ - "PartialSortExec: expr=[a@0 ASC, c@2 ASC, d@3 ASC], common_prefix_length=[2]", - " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC, c@2 ASC]", - ]; - 
assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } #[tokio::test] async fn test_push_with_required_input_ordering_prohibited() -> Result<()> { - // SortExec: expr=[b] <-- can't push this down - // RequiredInputOrder expr=[a] <-- this requires input sorted by a, and preserves the input order - // SortExec: expr=[a] - // DataSourceExec let schema = create_test_schema3()?; - let sort_exprs_a = LexOrdering::new(vec![sort_expr("a", &schema)]); - let sort_exprs_b = LexOrdering::new(vec![sort_expr("b", &schema)]); + let ordering_a: LexOrdering = [sort_expr("a", &schema)].into(); + let ordering_b: LexOrdering = [sort_expr("b", &schema)].into(); let plan = memory_exec(&schema); - let plan = sort_exec(sort_exprs_a.clone(), plan); + let plan = sort_exec(ordering_a.clone(), plan); let plan = RequirementsTestExec::new(plan) - .with_required_input_ordering(sort_exprs_a) + .with_required_input_ordering(Some(ordering_a)) .with_maintains_input_order(true) .into_arc(); - let plan = sort_exec(sort_exprs_b, plan); - - let expected_input = [ - "SortExec: expr=[b@1 ASC], preserve_partitioning=[false]", - " RequiredInputOrderingExec", - " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ]; + let plan = sort_exec(ordering_b, plan); + let test = EnforceSortingTest::new(plan.clone()).with_repartition_sorts(true); // should not be able to push shorts - let expected_no_change = expected_input; - assert_optimized!(expected_input, expected_no_change, plan, true); + + assert_snapshot!(test.run(), @r" + Input / Optimized Plan: + SortExec: expr=[b@1 ASC], preserve_partitioning=[false] + RequiredInputOrderingExec + SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] + "); Ok(()) } // test when the required input ordering is satisfied so could push through #[tokio::test] async fn test_push_with_required_input_ordering_allowed() -> Result<()> { - 
// SortExec: expr=[a,b] <-- can push this down (as it is compatible with the required input ordering) - // RequiredInputOrder expr=[a] <-- this requires input sorted by a, and preserves the input order - // SortExec: expr=[a] - // DataSourceExec let schema = create_test_schema3()?; - let sort_exprs_a = LexOrdering::new(vec![sort_expr("a", &schema)]); - let sort_exprs_ab = - LexOrdering::new(vec![sort_expr("a", &schema), sort_expr("b", &schema)]); + let ordering_a: LexOrdering = [sort_expr("a", &schema)].into(); + let ordering_ab = [sort_expr("a", &schema), sort_expr("b", &schema)].into(); let plan = memory_exec(&schema); - let plan = sort_exec(sort_exprs_a.clone(), plan); + let plan = sort_exec(ordering_a.clone(), plan); let plan = RequirementsTestExec::new(plan) - .with_required_input_ordering(sort_exprs_a) + .with_required_input_ordering(Some(ordering_a)) .with_maintains_input_order(true) .into_arc(); - let plan = sort_exec(sort_exprs_ab, plan); + let plan = sort_exec(ordering_ab, plan); + /* let expected_input = [ - "SortExec: expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[false]", - " RequiredInputOrderingExec", + "SortExec: expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[false]", // <-- can push this down (as it is compatible with the required input ordering) + " RequiredInputOrderingExec", // <-- this requires input sorted by a, and preserves the input order " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", " DataSourceExec: partitions=1, partition_sizes=[0]", ]; - // should able to push shorts - let expected = [ - "RequiredInputOrderingExec", - " SortExec: expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ]; - assert_optimized!(expected_input, expected, plan, true); + */ + let test = EnforceSortingTest::new(plan.clone()).with_repartition_sorts(true); + + assert_snapshot!(test.run(), @r" + Input Plan: + SortExec: expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[false] + 
RequiredInputOrderingExec + SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] + + Optimized Plan: + RequiredInputOrderingExec + SortExec: expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] + "); + // Should be able to push down Ok(()) } #[tokio::test] async fn test_replace_with_partial_sort() -> Result<()> { let schema = create_test_schema3()?; - let input_sort_exprs = vec![sort_expr("a", &schema)]; - let unbounded_input = stream_exec_ordered(&schema, input_sort_exprs); - + let input_ordering = [sort_expr("a", &schema)].into(); + let unbounded_input = stream_exec_ordered(&schema, input_ordering); let physical_plan = sort_exec( - vec![sort_expr("a", &schema), sort_expr("c", &schema)], + [sort_expr("a", &schema), sort_expr("c", &schema)].into(), unbounded_input, ); - let expected_input = [ - "SortExec: expr=[a@0 ASC, c@2 ASC], preserve_partitioning=[false]", - " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC]" - ]; - let expected_optimized = [ - "PartialSortExec: expr=[a@0 ASC, c@2 ASC], common_prefix_length=[1]", - " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC]", - ]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + let test = + EnforceSortingTest::new(physical_plan.clone()).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r" + Input Plan: + SortExec: expr=[a@0 ASC, c@2 ASC], preserve_partitioning=[false] + StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC] + + Optimized Plan: + PartialSortExec: expr=[a@0 ASC, c@2 ASC], common_prefix_length=[1] + StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC] + "); Ok(()) } #[tokio::test] async fn 
test_not_replaced_with_partial_sort_for_unbounded_input() -> Result<()> { let schema = create_test_schema3()?; - let input_sort_exprs = vec![sort_expr("b", &schema), sort_expr("c", &schema)]; - let unbounded_input = stream_exec_ordered(&schema, input_sort_exprs); - + let input_ordering = [sort_expr("b", &schema), sort_expr("c", &schema)].into(); + let unbounded_input = stream_exec_ordered(&schema, input_ordering); let physical_plan = sort_exec( - vec![ + [ sort_expr("a", &schema), sort_expr("b", &schema), sort_expr("c", &schema), - ], + ] + .into(), unbounded_input, ); - let expected_input = [ - "SortExec: expr=[a@0 ASC, b@1 ASC, c@2 ASC], preserve_partitioning=[false]", - " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[b@1 ASC, c@2 ASC]" - ]; - let expected_no_change = expected_input; - assert_optimized!(expected_input, expected_no_change, physical_plan, true); - Ok(()) -} - -#[tokio::test] -async fn test_window_partial_constant_and_set_monotonicity() -> Result<()> { - let input_schema = create_test_schema()?; - let sort_exprs = vec![sort_expr_options( - "nullable_col", - &input_schema, - SortOptions { - descending: false, - nulls_first: false, - }, - )]; - let source = parquet_exec_sorted(&input_schema, sort_exprs); - - // Function definition - Alias of the resulting column - Arguments of the function - #[derive(Clone)] - struct WindowFuncParam(WindowFunctionDefinition, String, Vec>); - let function_arg_ordered = vec![col("nullable_col", &input_schema)?]; - let function_arg_unordered = vec![col("non_nullable_col", &input_schema)?]; - let fn_count_on_ordered = WindowFuncParam( - WindowFunctionDefinition::AggregateUDF(count_udaf()), - "count".to_string(), - function_arg_ordered.clone(), - ); - let fn_max_on_ordered = WindowFuncParam( - WindowFunctionDefinition::AggregateUDF(max_udaf()), - "max".to_string(), - function_arg_ordered.clone(), - ); - let fn_min_on_ordered = WindowFuncParam( - 
WindowFunctionDefinition::AggregateUDF(min_udaf()), - "min".to_string(), - function_arg_ordered.clone(), - ); - let fn_avg_on_ordered = WindowFuncParam( - WindowFunctionDefinition::AggregateUDF(avg_udaf()), - "avg".to_string(), - function_arg_ordered, - ); - let fn_count_on_unordered = WindowFuncParam( - WindowFunctionDefinition::AggregateUDF(count_udaf()), - "count".to_string(), - function_arg_unordered.clone(), - ); - let fn_max_on_unordered = WindowFuncParam( - WindowFunctionDefinition::AggregateUDF(max_udaf()), - "max".to_string(), - function_arg_unordered.clone(), - ); - let fn_min_on_unordered = WindowFuncParam( - WindowFunctionDefinition::AggregateUDF(min_udaf()), - "min".to_string(), - function_arg_unordered.clone(), - ); - let fn_avg_on_unordered = WindowFuncParam( - WindowFunctionDefinition::AggregateUDF(avg_udaf()), - "avg".to_string(), - function_arg_unordered, - ); - struct TestCase<'a> { - // Whether window expression has a partition_by expression or not. - // If it does, it will be on the ordered column -- `nullable_col`. - partition_by: bool, - // Whether the frame is unbounded in both directions, or unbounded in - // only one direction (when set-monotonicity has a meaning), or it is - // a sliding window. 
- window_frame: Arc, - // Function definition - Alias of the resulting column - Arguments of the function - func: WindowFuncParam, - // Global sort requirement at the root and its direction, - // which is required to be removed or preserved -- (asc, nulls_first) - required_sort_columns: Vec<(&'a str, bool, bool)>, - initial_plan: Vec<&'a str>, - expected_plan: Vec<&'a str>, - } - let test_cases = vec![ - // ============================================REGION STARTS============================================ - // WindowAggExec + Plain(unbounded preceding, unbounded following) + no partition_by + on ordered column - // Case 0: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new(None)), - func: fn_count_on_ordered.clone(), - required_sort_columns: vec![("nullable_col", true, false), ("count", true, false)], - initial_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 1: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new(None)), - func: 
fn_max_on_ordered.clone(), - required_sort_columns: vec![("nullable_col", true, false), ("max", false, false)], - initial_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 DESC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "WindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 2: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new(None)), - func: fn_min_on_ordered.clone(), - required_sort_columns: vec![("min", false, false), ("nullable_col", true, false)], - initial_plan: vec![ - "SortExec: expr=[min@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "WindowAggExec: wdw=[min: Ok(Field { 
name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 3: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new(None)), - func: fn_avg_on_ordered.clone(), - required_sort_columns: vec![("avg", true, false), ("nullable_col", true, false)], - initial_plan: vec![ - "SortExec: expr=[avg@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "WindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // =============================================REGION ENDS============================================= - // = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = - // ============================================REGION STARTS============================================ - // 
WindowAggExec + Plain(unbounded preceding, unbounded following) + no partition_by + on unordered column - // Case 4: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new(None)), - func: fn_count_on_unordered.clone(), - required_sort_columns: vec![("non_nullable_col", true, false), ("count", true, false)], - initial_plan: vec![ - "SortExec: expr=[non_nullable_col@1 ASC NULLS LAST, count@2 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[non_nullable_col@1 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 5: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new(None)), - func: fn_max_on_unordered.clone(), - required_sort_columns: vec![("non_nullable_col", false, false), ("max", false, false)], - initial_plan: vec![ - "SortExec: expr=[non_nullable_col@1 DESC NULLS LAST, max@2 DESC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { 
units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[non_nullable_col@1 DESC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 6: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new(None)), - func: fn_min_on_unordered.clone(), - required_sort_columns: vec![("min", true, false), ("non_nullable_col", true, false)], - initial_plan: vec![ - "SortExec: expr=[min@2 ASC NULLS LAST, non_nullable_col@1 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[non_nullable_col@1 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), 
is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 7: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new(None)), - func: fn_avg_on_unordered.clone(), - required_sort_columns: vec![("avg", false, false), ("nullable_col", false, false)], - initial_plan: vec![ - "SortExec: expr=[avg@2 DESC NULLS LAST, nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // =============================================REGION ENDS============================================= - // = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = - // ============================================REGION STARTS============================================ - // WindowAggExec + Plain(unbounded preceding, unbounded following) + partition_by + on ordered column - // Case 8: - TestCase { - 
partition_by: true, - window_frame: Arc::new(WindowFrame::new(None)), - func: fn_count_on_ordered.clone(), - required_sort_columns: vec![("nullable_col", true, false), ("count", true, false)], - initial_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 9: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new(None)), - func: fn_max_on_ordered.clone(), - required_sort_columns: vec![("nullable_col", true, false), ("max", false, false)], - initial_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 DESC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS 
LAST], file_type=parquet", - ], - expected_plan: vec![ - "WindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 10: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new(None)), - func: fn_min_on_ordered.clone(), - required_sort_columns: vec![("min", false, false), ("nullable_col", true, false)], - initial_plan: vec![ - "SortExec: expr=[min@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[min@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 11: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new(None)), - func: fn_avg_on_ordered.clone(), - 
required_sort_columns: vec![("avg", true, false), ("nullable_col", true, false)], - initial_plan: vec![ - "SortExec: expr=[avg@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[avg@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // =============================================REGION ENDS============================================= - // = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = - // ============================================REGION STARTS============================================ - // WindowAggExec + Plain(unbounded preceding, unbounded following) + partition_by + on unordered column - // Case 12: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new(None)), - func: fn_count_on_unordered.clone(), - required_sort_columns: vec![("non_nullable_col", true, false), ("count", true, false)], - initial_plan: vec![ - "SortExec: expr=[non_nullable_col@1 ASC NULLS LAST, count@2 ASC NULLS LAST], 
preserve_partitioning=[false]", - " WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[non_nullable_col@1 ASC NULLS LAST, count@2 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 13: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new(None)), - func: fn_max_on_unordered.clone(), - required_sort_columns: vec![("non_nullable_col", true, false), ("max", false, false)], - initial_plan: vec![ - "SortExec: expr=[non_nullable_col@1 ASC NULLS LAST, max@2 DESC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[non_nullable_col@1 ASC NULLS LAST, max@2 DESC NULLS LAST], preserve_partitioning=[false]", - " 
WindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 14: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new(None)), - func: fn_min_on_unordered.clone(), - required_sort_columns: vec![("min", false, false), ("non_nullable_col", true, false)], - initial_plan: vec![ - "SortExec: expr=[min@2 DESC NULLS LAST, non_nullable_col@1 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[min@2 DESC NULLS LAST, non_nullable_col@1 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 15: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new(None)), - func: fn_avg_on_unordered.clone(), - required_sort_columns: vec![("avg", true, 
false), ("nullable_col", true, false)], - initial_plan: vec![ - "SortExec: expr=[avg@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[avg@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // =============================================REGION ENDS============================================= - // = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = - // ============================================REGION STARTS============================================ - // WindowAggExec + Sliding(current row, unbounded following) + no partition_by + on ordered column - // Case 16: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), - func: fn_count_on_ordered.clone(), - required_sort_columns: vec![("nullable_col", true, false), ("count", false, false)], - initial_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 DESC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: 
wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 17: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), - func: fn_max_on_ordered.clone(), - required_sort_columns: vec![("max", false, true), ("nullable_col", true, false)], - initial_plan: vec![ - "SortExec: expr=[max@2 DESC, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "WindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], 
output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 18: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), - func: fn_min_on_ordered.clone(), - required_sort_columns: vec![("min", true, true), ("nullable_col", true, false)], - initial_plan: vec![ - "SortExec: expr=[min@2 ASC, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "WindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 19: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), - func: fn_avg_on_ordered.clone(), - required_sort_columns: vec![("avg", false, false), ("nullable_col", true, false)], - initial_plan: vec![ - "SortExec: expr=[avg@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, 
non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[avg@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // =============================================REGION ENDS============================================= - // = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = - // ============================================REGION STARTS============================================ - // WindowAggExec + Sliding(current row, unbounded following) + no partition_by + on unordered column - // Case 20: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), - func: fn_count_on_unordered.clone(), - required_sort_columns: vec![("nullable_col", true, false), ("count", true, false)], - initial_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 ASC NULLS LAST], preserve_partitioning=[false]", - 
" WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 21: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), - func: fn_max_on_unordered.clone(), - required_sort_columns: vec![("nullable_col", true, false), ("max", false, true)], - initial_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 DESC], preserve_partitioning=[false]", - " WindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "WindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 22: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), - func: fn_min_on_unordered.clone(), - required_sort_columns: vec![("min", true, false), ("nullable_col", true, false)], - initial_plan: vec![ - "SortExec: expr=[min@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], 
preserve_partitioning=[false]", - " WindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[min@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 23: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), - func: fn_avg_on_unordered.clone(), - required_sort_columns: vec![("avg", false, false), ("nullable_col", true, false)], - initial_plan: vec![ - "SortExec: expr=[avg@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[avg@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, 
nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // =============================================REGION ENDS============================================= - // = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = - // ============================================REGION STARTS============================================ - // WindowAggExec + Sliding(current row, unbounded following) + partition_by + on ordered column - // Case 24: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), - func: fn_count_on_ordered.clone(), - required_sort_columns: vec![("nullable_col", true, false), ("count", false, false)], - initial_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 DESC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], 
file_type=parquet", - ], - }, - // Case 25: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), - func: fn_max_on_ordered.clone(), - required_sort_columns: vec![("nullable_col", true, false), ("max", true, false)], - initial_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 26: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), - func: fn_min_on_ordered.clone(), - required_sort_columns: vec![("min", false, false)], - initial_plan: vec![ - "SortExec: expr=[min@2 DESC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, 
non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[min@2 DESC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 27: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), - func: fn_avg_on_ordered.clone(), - required_sort_columns: vec![("avg", false, false)], - initial_plan: vec![ - "SortExec: expr=[avg@2 DESC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[avg@2 DESC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // =============================================REGION ENDS============================================= - // = = = = = = = = = = = = 
= = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = - // ============================================REGION STARTS============================================ - // WindowAggExec + Sliding(current row, unbounded following) + partition_by + on unordered column - // Case 28: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), - func: fn_count_on_unordered.clone(), - required_sort_columns: vec![("count", false, false), ("nullable_col", true, false)], - initial_plan: vec![ - "SortExec: expr=[count@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[count@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet" - ], - }, - // Case 29: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), - func: fn_max_on_unordered.clone(), - required_sort_columns: vec![("nullable_col", true, false), ("max", false, true)], - initial_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 DESC], preserve_partitioning=[false]", - " 
WindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "WindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 30: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), - func: fn_min_on_unordered.clone(), - required_sort_columns: vec![("min", false, false)], - initial_plan: vec![ - "SortExec: expr=[min@2 DESC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[min@2 DESC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, 
projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 31: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), - func: fn_avg_on_unordered.clone(), - required_sort_columns: vec![("nullable_col", true, false), ("avg", true, false)], - initial_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, avg@2 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet" - ], - expected_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, avg@2 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet" - ], - }, - // =============================================REGION ENDS============================================= - // = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = - // ============================================REGION STARTS============================================ - // BoundedWindowAggExec + Plain(unbounded preceding, unbounded following) + no partition_by + on ordered column - // Case 32: - TestCase { - partition_by: false, - window_frame: 
Arc::new(WindowFrame::new(Some(true))), - func: fn_count_on_ordered.clone(), - required_sort_columns: vec![("nullable_col", true, false), ("count", true, false)], - initial_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 33: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new(Some(true))), - func: fn_max_on_ordered.clone(), - required_sort_columns: vec![("max", false, false), ("nullable_col", true, false)], - initial_plan: vec![ - "SortExec: expr=[max@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], 
file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[max@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet" - ], - }, - // Case 34: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new(Some(true))), - func: fn_min_on_ordered.clone(), - required_sort_columns: vec![("min", false, false), ("nullable_col", true, false)], - initial_plan: vec![ - "SortExec: expr=[min@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet" - ], - expected_plan: vec![ - "BoundedWindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 35: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new(Some(true))), - func: 
fn_avg_on_ordered.clone(), - required_sort_columns: vec![("nullable_col", true, false), ("avg", true, false)], - initial_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, avg@2 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, avg@2 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // =============================================REGION ENDS============================================= - // = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = - // ============================================REGION STARTS============================================ - // BoundedWindowAggExec + Plain(unbounded preceding, unbounded following) + no partition_by + on unordered column - // Case 36: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new(Some(true))), - func: fn_count_on_unordered.clone(), - required_sort_columns: vec![("nullable_col", true, false), ("count", true, true)], - initial_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, 
count@2 ASC], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 37: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new(Some(true))), - func: fn_max_on_unordered.clone(), - required_sort_columns: vec![("max", true, false), ("nullable_col", true, false)], - initial_plan: vec![ - "SortExec: expr=[max@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "BoundedWindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), 
end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 38: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new(Some(true))), - func: fn_min_on_unordered.clone(), - required_sort_columns: vec![("min", false, true), ("nullable_col", true, false)], - initial_plan: vec![ - "SortExec: expr=[min@2 DESC, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[min@2 DESC, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 39: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new(Some(true))), - func: fn_avg_on_unordered.clone(), - required_sort_columns: vec![("avg", true, false)], - initial_plan: vec![ - "SortExec: expr=[avg@2 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, 
dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[avg@2 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // =============================================REGION ENDS============================================= - // = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = - // ============================================REGION STARTS============================================ - // BoundedWindowAggExec + Plain(unbounded preceding, unbounded following) + partition_by + on ordered column - // Case 40: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new(Some(true))), - func: fn_count_on_ordered.clone(), - required_sort_columns: vec![("nullable_col", true, false), ("count", true, false)], - initial_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, 
projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 41: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new(Some(true))), - func: fn_max_on_ordered.clone(), - required_sort_columns: vec![("max", true, false), ("nullable_col", true, false)], - initial_plan: vec![ - "SortExec: expr=[max@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet" - ], - expected_plan: vec![ - "SortExec: expr=[max@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet" - ], - }, - // Case 42: - TestCase { - partition_by: 
true, - window_frame: Arc::new(WindowFrame::new(Some(true))), - func: fn_min_on_ordered.clone(), - required_sort_columns: vec![("min", false, false), ("nullable_col", true, false)], - initial_plan: vec![ - "SortExec: expr=[min@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[min@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 43: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new(Some(true))), - func: fn_avg_on_ordered.clone(), - required_sort_columns: vec![("nullable_col", true, false), ("avg", true, false)], - initial_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, avg@2 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: 
[[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, avg@2 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // =============================================REGION ENDS============================================= - // = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = - // ============================================REGION STARTS============================================ - // BoundedWindowAggExec + Plain(unbounded preceding, unbounded following) + partition_by + on unordered column - // Case 44: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new(Some(true))), - func: fn_count_on_unordered.clone(), - required_sort_columns: vec![ ("count", true, true)], - initial_plan: vec![ - "SortExec: expr=[count@2 ASC], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[count@2 ASC], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[count: 
Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", ], - }, - // Case 45: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new(Some(true))), - func: fn_max_on_unordered.clone(), - required_sort_columns: vec![("nullable_col", true, false), ("max", false, false)], - initial_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 DESC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 DESC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 46: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new(Some(true))), - func: fn_min_on_unordered.clone(), - required_sort_columns: vec![("nullable_col", true, false), ("min", false, 
false)], - initial_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, min@2 DESC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "BoundedWindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 47: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new(Some(true))), - func: fn_avg_on_unordered.clone(), - required_sort_columns: vec![("nullable_col", true, false)], - initial_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "BoundedWindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: 
Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // =============================================REGION ENDS============================================= - // = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = - // ============================================REGION STARTS============================================ - // BoundedWindowAggExec + Sliding(bounded preceding, bounded following) + no partition_by + on ordered column - // Case 48: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::CurrentRow)), - func: fn_count_on_ordered.clone(), - required_sort_columns: vec![("count", true, false), ("nullable_col", true, false)], - initial_plan: vec![ - "SortExec: expr=[count@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, 
projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 49: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::Following(ScalarValue::new_one(&DataType::UInt32)?))), - func: fn_max_on_ordered.clone(), - required_sort_columns: vec![("max", true, false)], - initial_plan: vec![ - "SortExec: expr=[max@2 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: Following(UInt32(1)), is_causal: false }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[max@2 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: Following(UInt32(1)), is_causal: false }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 50: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::CurrentRow)), - func: fn_min_on_ordered.clone(), - required_sort_columns: vec![("nullable_col", true, false), ("min", false, false)], - initial_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC 
NULLS LAST, min@2 DESC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "BoundedWindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 51: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::CurrentRow)), - func: fn_avg_on_ordered.clone(), - required_sort_columns: vec![("avg", true, false)], - initial_plan: vec![ - "SortExec: expr=[avg@2 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[avg@2 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[avg: Ok(Field { name: 
\"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // =============================================REGION ENDS============================================= - // = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = - // ============================================REGION STARTS============================================ - // BoundedWindowAggExec + Sliding(bounded preceding, bounded following) + no partition_by + on unordered column - // Case 52: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::Following(ScalarValue::new_one(&DataType::UInt32)?))), - func: fn_count_on_unordered.clone(), - required_sort_columns: vec![("count", true, false), ("nullable_col", true, false)], - initial_plan: vec![ - "SortExec: expr=[count@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: Following(UInt32(1)), is_causal: false }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[count@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[count: Ok(Field { 
name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: Following(UInt32(1)), is_causal: false }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet" - ], - }, - // Case 53: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::CurrentRow)), - func: fn_max_on_unordered.clone(), - required_sort_columns: vec![("nullable_col", true, false), ("max", true, false)], - initial_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 54: - TestCase { - partition_by: false, - window_frame: 
Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::CurrentRow)), - func: fn_min_on_unordered.clone(), - required_sort_columns: vec![("min", true, false)], - initial_plan: vec![ - "SortExec: expr=[min@2 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[min@2 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 55: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::Following(ScalarValue::new_one(&DataType::UInt32)?))), - func: fn_avg_on_unordered.clone(), - required_sort_columns: vec![("nullable_col", true, false)], - initial_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, 
start_bound: Preceding(UInt32(1)), end_bound: Following(UInt32(1)), is_causal: false }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "BoundedWindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: Following(UInt32(1)), is_causal: false }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // =============================================REGION ENDS============================================= - // = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = - // ============================================REGION STARTS============================================ - // BoundedWindowAggExec + Sliding(bounded preceding, bounded following) + partition_by + on ordered column - // Case 56: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::CurrentRow)), - func: fn_count_on_ordered.clone(), - required_sort_columns: vec![("count", true, false), ("nullable_col", true, false)], - initial_plan: vec![ - "SortExec: expr=[count@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, 
projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 57: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::Following(ScalarValue::new_one(&DataType::UInt32)?))), - func: fn_max_on_ordered.clone(), - required_sort_columns: vec![("nullable_col", true, false), ("max", true, false)], - initial_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: Following(UInt32(1)), is_causal: false }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: Following(UInt32(1)), is_causal: false }], mode=[Sorted]", - " 
DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 58: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::CurrentRow)), - func: fn_min_on_ordered.clone(), - required_sort_columns: vec![("min", false, false), ("nullable_col", true, false)], - initial_plan: vec![ - "SortExec: expr=[min@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[min@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 59: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::CurrentRow)), - func: fn_avg_on_ordered.clone(), - required_sort_columns: vec![("avg", true, false)], - initial_plan: 
vec![ - "SortExec: expr=[avg@2 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[avg@2 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // =============================================REGION ENDS============================================= - // = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = - // ============================================REGION STARTS============================================ - // BoundedWindowAggExec + Sliding(bounded preceding, bounded following) + partition_by + on unordered column - // Case 60: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::CurrentRow)), - func: fn_count_on_unordered.clone(), - required_sort_columns: vec![("nullable_col", true, false), ("count", true, false)], - initial_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 ASC NULLS LAST], preserve_partitioning=[false]", - " 
BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 61: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::CurrentRow)), - func: fn_max_on_unordered.clone(), - required_sort_columns: vec![("nullable_col", true, false), ("max", true, true)], - initial_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 ASC], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[nullable_col@0 
ASC NULLS LAST, max@2 ASC], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 62: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::CurrentRow)), - func: fn_min_on_unordered.clone(), - required_sort_columns: vec![("nullable_col", true, false), ("min", false, false)], - initial_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, min@2 DESC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, min@2 DESC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], 
file_type=parquet", - ], - }, - // Case 63: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::CurrentRow)), - func: fn_avg_on_unordered.clone(), - required_sort_columns: vec![("nullable_col", true, false)], - initial_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "BoundedWindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // =============================================REGION ENDS============================================= - ]; - - for (case_idx, case) in test_cases.into_iter().enumerate() { - let partition_by = if case.partition_by { - vec![col("nullable_col", &input_schema)?] 
- } else { - vec![] - }; - let window_expr = create_window_expr( - &case.func.0, - case.func.1, - &case.func.2, - &partition_by, - &LexOrdering::default(), - case.window_frame, - input_schema.as_ref(), - false, - )?; - let window_exec = if window_expr.uses_bounded_memory() { - Arc::new(BoundedWindowAggExec::try_new( - vec![window_expr], - Arc::clone(&source), - InputOrderMode::Sorted, - case.partition_by, - )?) as Arc - } else { - Arc::new(WindowAggExec::try_new( - vec![window_expr], - Arc::clone(&source), - case.partition_by, - )?) as _ - }; - let output_schema = window_exec.schema(); - let sort_expr = case - .required_sort_columns - .iter() - .map(|(col_name, asc, nf)| { - sort_expr_options( - col_name, - &output_schema, - SortOptions { - descending: !asc, - nulls_first: *nf, - }, - ) - }) - .collect::>(); - let physical_plan = sort_exec(sort_expr, window_exec); - - assert_optimized!( - case.initial_plan, - case.expected_plan, - physical_plan, - true, - case_idx - ); - } - + let test = + EnforceSortingTest::new(physical_plan.clone()).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r" + Input / Optimized Plan: + SortExec: expr=[a@0 ASC, b@1 ASC, c@2 ASC], preserve_partitioning=[false] + StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[b@1 ASC, c@2 ASC] + "); Ok(()) } #[test] fn test_removes_unused_orthogonal_sort() -> Result<()> { let schema = create_test_schema3()?; - let input_sort_exprs = vec![sort_expr("b", &schema), sort_expr("c", &schema)]; - let unbounded_input = stream_exec_ordered(&schema, input_sort_exprs.clone()); - - let orthogonal_sort = sort_exec(vec![sort_expr("a", &schema)], unbounded_input); - let output_sort = sort_exec(input_sort_exprs, orthogonal_sort); // same sort as data source + let input_ordering: LexOrdering = + [sort_expr("b", &schema), sort_expr("c", &schema)].into(); + let unbounded_input = stream_exec_ordered(&schema, input_ordering.clone()); + let orthogonal_sort 
= sort_exec([sort_expr("a", &schema)].into(), unbounded_input); + let output_sort = sort_exec(input_ordering, orthogonal_sort); // same sort as data source // Test scenario/input has an orthogonal sort: - let expected_input = [ - "SortExec: expr=[b@1 ASC, c@2 ASC], preserve_partitioning=[false]", - " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", - " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[b@1 ASC, c@2 ASC]" - ]; - assert_eq!(get_plan_string(&output_sort), expected_input,); + let test = EnforceSortingTest::new(output_sort).with_repartition_sorts(true); + + assert_snapshot!(test.run(), @r" + Input Plan: + SortExec: expr=[b@1 ASC, c@2 ASC], preserve_partitioning=[false] + SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[b@1 ASC, c@2 ASC] + Optimized Plan: + StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[b@1 ASC, c@2 ASC] + "); // Test: should remove orthogonal sort, and the uppermost (unneeded) sort: - let expected_optimized = [ - "StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[b@1 ASC, c@2 ASC]" - ]; - assert_optimized!(expected_input, expected_optimized, output_sort, true); Ok(()) } @@ -3384,24 +2593,23 @@ fn test_removes_unused_orthogonal_sort() -> Result<()> { #[test] fn test_keeps_used_orthogonal_sort() -> Result<()> { let schema = create_test_schema3()?; - let input_sort_exprs = vec![sort_expr("b", &schema), sort_expr("c", &schema)]; - let unbounded_input = stream_exec_ordered(&schema, input_sort_exprs.clone()); - + let input_ordering: LexOrdering = + [sort_expr("b", &schema), sort_expr("c", &schema)].into(); + let unbounded_input = stream_exec_ordered(&schema, input_ordering.clone()); let orthogonal_sort = - sort_exec_with_fetch(vec![sort_expr("a", 
&schema)], Some(3), unbounded_input); // has fetch, so this orthogonal sort changes the output - let output_sort = sort_exec(input_sort_exprs, orthogonal_sort); + sort_exec_with_fetch([sort_expr("a", &schema)].into(), Some(3), unbounded_input); // has fetch, so this orthogonal sort changes the output + let output_sort = sort_exec(input_ordering, orthogonal_sort); // Test scenario/input has an orthogonal sort: - let expected_input = [ - "SortExec: expr=[b@1 ASC, c@2 ASC], preserve_partitioning=[false]", - " SortExec: TopK(fetch=3), expr=[a@0 ASC], preserve_partitioning=[false]", - " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[b@1 ASC, c@2 ASC]" - ]; - assert_eq!(get_plan_string(&output_sort), expected_input,); + let test = EnforceSortingTest::new(output_sort).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r" + Input / Optimized Plan: + SortExec: expr=[b@1 ASC, c@2 ASC], preserve_partitioning=[false] + SortExec: TopK(fetch=3), expr=[a@0 ASC], preserve_partitioning=[false] + StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[b@1 ASC, c@2 ASC] + "); // Test: should keep the orthogonal sort, since it modifies the output: - let expected_optimized = expected_input; - assert_optimized!(expected_input, expected_optimized, output_sort, true); Ok(()) } @@ -3409,35 +2617,36 @@ fn test_keeps_used_orthogonal_sort() -> Result<()> { #[test] fn test_handles_multiple_orthogonal_sorts() -> Result<()> { let schema = create_test_schema3()?; - let input_sort_exprs = vec![sort_expr("b", &schema), sort_expr("c", &schema)]; - let unbounded_input = stream_exec_ordered(&schema, input_sort_exprs.clone()); - - let orthogonal_sort_0 = sort_exec(vec![sort_expr("c", &schema)], unbounded_input); // has no fetch, so can be removed + let input_ordering: LexOrdering = + [sort_expr("b", &schema), sort_expr("c", &schema)].into(); + let unbounded_input = 
stream_exec_ordered(&schema, input_ordering.clone()); + let ordering0: LexOrdering = [sort_expr("c", &schema)].into(); + let orthogonal_sort_0 = sort_exec(ordering0.clone(), unbounded_input); // has no fetch, so can be removed + let ordering1: LexOrdering = [sort_expr("a", &schema)].into(); let orthogonal_sort_1 = - sort_exec_with_fetch(vec![sort_expr("a", &schema)], Some(3), orthogonal_sort_0); // has fetch, so this orthogonal sort changes the output - let orthogonal_sort_2 = sort_exec(vec![sort_expr("c", &schema)], orthogonal_sort_1); // has no fetch, so can be removed - let orthogonal_sort_3 = sort_exec(vec![sort_expr("a", &schema)], orthogonal_sort_2); // has no fetch, so can be removed - let output_sort = sort_exec(input_sort_exprs, orthogonal_sort_3); // final sort + sort_exec_with_fetch(ordering1.clone(), Some(3), orthogonal_sort_0); // has fetch, so this orthogonal sort changes the output + let orthogonal_sort_2 = sort_exec(ordering0, orthogonal_sort_1); // has no fetch, so can be removed + let orthogonal_sort_3 = sort_exec(ordering1, orthogonal_sort_2); // has no fetch, so can be removed + let output_sort = sort_exec(input_ordering, orthogonal_sort_3); // final sort // Test scenario/input has an orthogonal sort: - let expected_input = [ - "SortExec: expr=[b@1 ASC, c@2 ASC], preserve_partitioning=[false]", - " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", - " SortExec: expr=[c@2 ASC], preserve_partitioning=[false]", - " SortExec: TopK(fetch=3), expr=[a@0 ASC], preserve_partitioning=[false]", - " SortExec: expr=[c@2 ASC], preserve_partitioning=[false]", - " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[b@1 ASC, c@2 ASC]", - ]; - assert_eq!(get_plan_string(&output_sort), expected_input,); + let test = EnforceSortingTest::new(output_sort.clone()).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r" + Input Plan: + SortExec: expr=[b@1 ASC, c@2 ASC], preserve_partitioning=[false] 
+ SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + SortExec: expr=[c@2 ASC], preserve_partitioning=[false] + SortExec: TopK(fetch=3), expr=[a@0 ASC], preserve_partitioning=[false] + SortExec: expr=[c@2 ASC], preserve_partitioning=[false] + StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[b@1 ASC, c@2 ASC] + + Optimized Plan: + SortExec: expr=[b@1 ASC, c@2 ASC], preserve_partitioning=[false] + SortExec: TopK(fetch=3), expr=[a@0 ASC], preserve_partitioning=[false] + StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[b@1 ASC, c@2 ASC] + "); // Test: should keep only the needed orthogonal sort, and remove the unneeded ones: - let expected_optimized = [ - "SortExec: expr=[b@1 ASC, c@2 ASC], preserve_partitioning=[false]", - " SortExec: TopK(fetch=3), expr=[a@0 ASC], preserve_partitioning=[false]", - " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[b@1 ASC, c@2 ASC]", - ]; - assert_optimized!(expected_input, expected_optimized, output_sort, true); - Ok(()) } @@ -3445,13 +2654,14 @@ fn test_handles_multiple_orthogonal_sorts() -> Result<()> { fn test_parallelize_sort_preserves_fetch() -> Result<()> { // Create a schema let schema = create_test_schema3()?; - let parquet_exec = parquet_exec(&schema); - let coalesced = Arc::new(CoalescePartitionsExec::new(parquet_exec.clone())); - let top_coalesced = - Arc::new(CoalescePartitionsExec::new(coalesced.clone()).with_fetch(Some(10))); + let parquet_exec = parquet_exec(schema); + let coalesced = coalesce_partitions_exec(parquet_exec.clone()); + let top_coalesced = coalesce_partitions_exec(coalesced.clone()) + .with_fetch(Some(10)) + .unwrap(); let requirements = PlanWithCorrespondingCoalescePartitions::new( - top_coalesced.clone(), + top_coalesced, true, vec![PlanWithCorrespondingCoalescePartitions::new( coalesced, @@ -3474,3 +2684,168 @@ fn 
test_parallelize_sort_preserves_fetch() -> Result<()> { ); Ok(()) } + +#[tokio::test] +async fn test_partial_sort_with_homogeneous_batches() -> Result<()> { + // Create schema for the table + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Int32, false), + ])); + + // Create homogeneous batches - each batch has the same values for columns a and b + let batch1 = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 1, 1])), + Arc::new(Int32Array::from(vec![1, 1, 1])), + Arc::new(Int32Array::from(vec![3, 2, 1])), + ], + )?; + let batch2 = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![2, 2, 2])), + Arc::new(Int32Array::from(vec![2, 2, 2])), + Arc::new(Int32Array::from(vec![4, 6, 5])), + ], + )?; + let batch3 = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![3, 3, 3])), + Arc::new(Int32Array::from(vec![3, 3, 3])), + Arc::new(Int32Array::from(vec![9, 7, 8])), + ], + )?; + + // Create session with batch size of 3 to match our homogeneous batch pattern + let session_config = SessionConfig::new() + .with_batch_size(3) + .with_target_partitions(1); + let ctx = SessionContext::new_with_config(session_config); + + let sort_order = vec![ + SortExpr::new( + Expr::Column(datafusion_common::Column::new( + Option::::None, + "a", + )), + true, + false, + ), + SortExpr::new( + Expr::Column(datafusion_common::Column::new( + Option::::None, + "b", + )), + true, + false, + ), + ]; + let batches = Arc::new(DummyStreamPartition { + schema: schema.clone(), + batches: vec![batch1, batch2, batch3], + }) as _; + let provider = StreamingTable::try_new(schema.clone(), vec![batches])? 
+ .with_sort_order(sort_order) + .with_infinite_table(true); + ctx.register_table("test_table", Arc::new(provider))?; + + let sql = "SELECT * FROM test_table ORDER BY a ASC, c ASC"; + let df = ctx.sql(sql).await?; + + let physical_plan = df.create_physical_plan().await?; + + // Verify that PartialSortExec is used + let plan_str = displayable(physical_plan.as_ref()).indent(true).to_string(); + assert!( + plan_str.contains("PartialSortExec"), + "Expected PartialSortExec in plan:\n{plan_str}", + ); + + let task_ctx = Arc::new(TaskContext::default()); + let mut stream = physical_plan.execute(0, task_ctx.clone())?; + + let mut collected_batches = Vec::new(); + while let Some(batch) = stream.next().await { + let batch = batch?; + if batch.num_rows() > 0 { + collected_batches.push(batch); + } + } + + // Assert we got 3 separate batches (not concatenated into fewer) + assert_eq!( + collected_batches.len(), + 3, + "Expected 3 separate batches, got {}", + collected_batches.len() + ); + + // Verify each batch has been sorted within itself + let expected_values = [vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]]; + + for (i, batch) in collected_batches.iter().enumerate() { + let c_array = batch + .column(2) + .as_any() + .downcast_ref::() + .unwrap(); + let actual = c_array.values().iter().copied().collect::>(); + assert_eq!(actual, expected_values[i], "Batch {i} not sorted correctly",); + } + + assert_eq!( + task_ctx.runtime_env().memory_pool.reserved(), + 0, + "Memory should be released after execution" + ); + + Ok(()) +} + +#[tokio::test] +async fn test_sort_with_streaming_table() -> Result<()> { + let batch = record_batch!(("a", Int32, [1, 2, 3]), ("b", Int32, [1, 2, 3]))?; + + let ctx = SessionContext::new(); + + let sort_order = vec![ + SortExpr::new( + Expr::Column(datafusion_common::Column::new( + Option::::None, + "a", + )), + true, + false, + ), + SortExpr::new( + Expr::Column(datafusion_common::Column::new( + Option::::None, + "b", + )), + true, + false, + ), + ]; + let 
schema = batch.schema(); + let batches = Arc::new(DummyStreamPartition { + schema: schema.clone(), + batches: vec![batch], + }) as _; + let provider = StreamingTable::try_new(schema.clone(), vec![batches])? + .with_sort_order(sort_order); + ctx.register_table("test_table", Arc::new(provider))?; + + let sql = "SELECT a FROM test_table GROUP BY a ORDER BY a"; + let results = ctx.sql(sql).await?.collect().await?; + + assert_eq!(results.len(), 1); + assert_eq!(results[0].num_columns(), 1); + let expected = create_array!(Int32, vec![1, 2, 3]) as ArrayRef; + assert_eq!(results[0].column(0), &expected); + + Ok(()) +} diff --git a/datafusion/core/tests/physical_optimizer/enforce_sorting_monotonicity.rs b/datafusion/core/tests/physical_optimizer/enforce_sorting_monotonicity.rs new file mode 100644 index 0000000000000..de7611ff211a5 --- /dev/null +++ b/datafusion/core/tests/physical_optimizer/enforce_sorting_monotonicity.rs @@ -0,0 +1,1715 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::physical_optimizer::test_utils::{ + create_test_schema, parquet_exec_with_sort, sort_exec, sort_expr_options, +}; +use arrow::datatypes::DataType; +use arrow_schema::SortOptions; +use datafusion::common::ScalarValue; +use datafusion::logical_expr::WindowFrameBound; +use datafusion::logical_expr::WindowFrameUnits; +use datafusion_expr::{WindowFrame, WindowFunctionDefinition}; +use datafusion_functions_aggregate::average::avg_udaf; +use datafusion_functions_aggregate::count::count_udaf; +use datafusion_functions_aggregate::min_max::{max_udaf, min_udaf}; +use datafusion_physical_expr::expressions::col; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_physical_expr_common::sort_expr::LexOrdering; +use datafusion_physical_plan::windows::{ + BoundedWindowAggExec, WindowAggExec, create_window_expr, +}; +use datafusion_physical_plan::{ExecutionPlan, InputOrderMode}; +use insta::assert_snapshot; +use std::sync::{Arc, LazyLock}; + +// Function definition - Alias of the resulting column - Arguments of the function +#[derive(Clone)] +struct WindowFuncParam( + WindowFunctionDefinition, + &'static str, + Vec>, +); + +fn function_arg_ordered() -> Vec> { + let input_schema = create_test_schema().unwrap(); + vec![col("nullable_col", &input_schema).unwrap()] +} +fn function_arg_unordered() -> Vec> { + let input_schema = create_test_schema().unwrap(); + vec![col("non_nullable_col", &input_schema).unwrap()] +} + +fn fn_count_on_ordered() -> WindowFuncParam { + WindowFuncParam( + WindowFunctionDefinition::AggregateUDF(count_udaf()), + "count", + function_arg_ordered(), + ) +} + +fn fn_max_on_ordered() -> WindowFuncParam { + WindowFuncParam( + WindowFunctionDefinition::AggregateUDF(max_udaf()), + "max", + function_arg_ordered(), + ) +} + +fn fn_min_on_ordered() -> WindowFuncParam { + WindowFuncParam( + WindowFunctionDefinition::AggregateUDF(min_udaf()), + "min", + function_arg_ordered(), + ) +} + +fn fn_avg_on_ordered() -> WindowFuncParam 
{ + WindowFuncParam( + WindowFunctionDefinition::AggregateUDF(avg_udaf()), + "avg", + function_arg_ordered(), + ) +} + +fn fn_count_on_unordered() -> WindowFuncParam { + WindowFuncParam( + WindowFunctionDefinition::AggregateUDF(count_udaf()), + "count", + function_arg_unordered(), + ) +} + +fn fn_max_on_unordered() -> WindowFuncParam { + WindowFuncParam( + WindowFunctionDefinition::AggregateUDF(max_udaf()), + "max", + function_arg_unordered(), + ) +} +fn fn_min_on_unordered() -> WindowFuncParam { + WindowFuncParam( + WindowFunctionDefinition::AggregateUDF(min_udaf()), + "min", + function_arg_unordered(), + ) +} + +fn fn_avg_on_unordered() -> WindowFuncParam { + WindowFuncParam( + WindowFunctionDefinition::AggregateUDF(avg_udaf()), + "avg", + function_arg_unordered(), + ) +} + +struct TestWindowCase { + partition_by: bool, + window_frame: Arc, + func: WindowFuncParam, + required_sort: Vec<(&'static str, bool, bool)>, // (column name, ascending, nulls_first) +} +impl TestWindowCase { + fn source() -> Arc { + static SOURCE: LazyLock> = LazyLock::new(|| { + let input_schema = create_test_schema().unwrap(); + let ordering = [sort_expr_options( + "nullable_col", + &input_schema, + SortOptions { + descending: false, + nulls_first: false, + }, + )] + .into(); + parquet_exec_with_sort(input_schema.clone(), vec![ordering]) + }); + Arc::clone(&SOURCE) + } + + // runs the window test case and returns the string representation of the plan + fn run(self) -> String { + let input_schema = create_test_schema().unwrap(); + let source = Self::source(); + + let Self { + partition_by, + window_frame, + func: WindowFuncParam(func_def, func_name, func_args), + required_sort, + } = self; + let partition_by_exprs = if partition_by { + vec![col("nullable_col", &input_schema).unwrap()] + } else { + vec![] + }; + + let window_expr = create_window_expr( + &func_def, + func_name.to_string(), + &func_args, + &partition_by_exprs, + &[], + window_frame, + Arc::clone(&input_schema), + false, + 
false, + None, + ) + .unwrap(); + + let window_exec = if window_expr.uses_bounded_memory() { + Arc::new( + BoundedWindowAggExec::try_new( + vec![window_expr], + Arc::clone(&source), + InputOrderMode::Sorted, + partition_by, + ) + .unwrap(), + ) as Arc + } else { + Arc::new( + WindowAggExec::try_new( + vec![window_expr], + Arc::clone(&source), + partition_by, + ) + .unwrap(), + ) as Arc + }; + + let output_schema = window_exec.schema(); + let sort_expr = required_sort.into_iter().map(|(col, asc, nulls_first)| { + sort_expr_options( + col, + &output_schema, + SortOptions { + descending: !asc, + nulls_first, + }, + ) + }); + let ordering = LexOrdering::new(sort_expr).unwrap(); + let physical_plan = sort_exec(ordering, window_exec); + + crate::physical_optimizer::enforce_sorting::EnforceSortingTest::new(physical_plan) + .with_repartition_sorts(true) + .run() + } +} +#[test] +fn test_window_partial_constant_and_set_monotonicity_0() { + // ============================================REGION STARTS============================================ + // WindowAggExec + Plain(unbounded preceding, unbounded following) + no partition_by + on ordered column + // Case 0: + assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: Arc::new(WindowFrame::new(None)), + func: fn_count_on_ordered(), + required_sort: vec![ + ("nullable_col", true, false), + ("count", true, false), + ], + }.run(), + @ r#" + Input Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64 }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + + Optimized Plan: + WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64 }), 
frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +#[test] +fn test_window_partial_constant_and_set_monotonicity_1() { + assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: Arc::new(WindowFrame::new(None)), + func: fn_max_on_ordered(), + required_sort: vec![ + ("nullable_col", true, false), + ("max", false, false), + ], + }.run(), + @ r#" + Input Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 DESC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[max: Ok(Field { name: "max", data_type: Int32, nullable: true }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + + Optimized Plan: + WindowAggExec: wdw=[max: Ok(Field { name: "max", data_type: Int32, nullable: true }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +#[test] +fn test_window_partial_constant_and_set_monotonicity_2() { + assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: Arc::new(WindowFrame::new(None)), + func: fn_min_on_ordered(), + required_sort: vec![ + ("min", false, false), + ("nullable_col", true, false), + ], + }.run(), + @ r#" + Input Plan: + SortExec: expr=[min@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[min: Ok(Field { name: "min", 
data_type: Int32, nullable: true }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + + Optimized Plan: + WindowAggExec: wdw=[min: Ok(Field { name: "min", data_type: Int32, nullable: true }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +#[test] +fn test_window_partial_constant_and_set_monotonicity_3() { + assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: Arc::new(WindowFrame::new(None)), + func: fn_avg_on_ordered(), + required_sort: vec![ + ("avg", true, false), + ("nullable_col", true, false), + ], + }.run(), + @ r#" + Input Plan: + SortExec: expr=[avg@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[avg: Ok(Field { name: "avg", data_type: Float64, nullable: true }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + + Optimized Plan: + WindowAggExec: wdw=[avg: Ok(Field { name: "avg", data_type: Float64, nullable: true }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +#[test] +fn 
test_window_partial_constant_and_set_monotonicity_4() { + assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: Arc::new(WindowFrame::new(None)), + func: fn_count_on_unordered(), + required_sort: vec![ + ("non_nullable_col", true, false), + ("count", true, false), + ], + }.run(), + @ r#" + Input Plan: + SortExec: expr=[non_nullable_col@1 ASC NULLS LAST, count@2 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64 }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + + Optimized Plan: + SortExec: expr=[non_nullable_col@1 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64 }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +#[test] +fn test_window_partial_constant_and_set_monotonicity_5() { + assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: Arc::new(WindowFrame::new(None)), + func: fn_max_on_unordered(), + required_sort: vec![ + ("non_nullable_col", false, false), + ("max", false, false), + ], + }.run(), + @ r#" + Input Plan: + SortExec: expr=[non_nullable_col@1 DESC NULLS LAST, max@2 DESC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[max: Ok(Field { name: "max", data_type: Int32, nullable: true }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, 
projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + + Optimized Plan: + SortExec: expr=[non_nullable_col@1 DESC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[max: Ok(Field { name: "max", data_type: Int32, nullable: true }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +#[test] +fn test_window_partial_constant_and_set_monotonicity_6() { + assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: Arc::new(WindowFrame::new(None)), + func: fn_min_on_unordered(), + required_sort: vec![ + ("min", true, false), + ("non_nullable_col", true, false), + ], + }.run(), + @ r#" + Input Plan: + SortExec: expr=[min@2 ASC NULLS LAST, non_nullable_col@1 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[min: Ok(Field { name: "min", data_type: Int32, nullable: true }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + + Optimized Plan: + SortExec: expr=[non_nullable_col@1 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[min: Ok(Field { name: "min", data_type: Int32, nullable: true }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +#[test] +fn test_window_partial_constant_and_set_monotonicity_7() { + 
assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: Arc::new(WindowFrame::new(None)), + func: fn_avg_on_unordered(), + required_sort: vec![ + ("avg", false, false), + ("nullable_col", false, false), + ], + }.run(), + @ r#" + Input Plan: + SortExec: expr=[avg@2 DESC NULLS LAST, nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[avg: Ok(Field { name: "avg", data_type: Float64, nullable: true }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + + Optimized Plan: + SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[avg: Ok(Field { name: "avg", data_type: Float64, nullable: true }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// =============================================REGION ENDS============================================= +// = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = +// ============================================REGION STARTS============================================ + +#[test] +fn test_window_partial_constant_and_set_monotonicity_8() { + assert_snapshot!(TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new(None)), + func: fn_count_on_ordered(), + required_sort: vec![ + ("nullable_col", true, false), + ("count", true, false), + ], + }.run(), + @ r#" + Input Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 ASC NULLS LAST], preserve_partitioning=[false] + 
WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64 }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + + Optimized Plan: + WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64 }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +#[test] +fn test_window_partial_constant_and_set_monotonicity_9() { + assert_snapshot!(TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new(None)), + func: fn_max_on_ordered(), + required_sort: vec![ + ("nullable_col", true, false), + ("max", false, false), + ], + }.run(), + @ r#" + Input Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 DESC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[max: Ok(Field { name: "max", data_type: Int32, nullable: true }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + + Optimized Plan: + WindowAggExec: wdw=[max: Ok(Field { name: "max", data_type: Int32, nullable: true }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +#[test] +fn 
test_window_partial_constant_and_set_monotonicity_10() { + assert_snapshot!(TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new(None)), + func: fn_min_on_ordered(), + required_sort: vec![ + ("min", false, false), + ("nullable_col", true, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[min@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[min: Ok(Field { name: "min", data_type: Int32, nullable: true }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +#[test] +fn test_window_partial_constant_and_set_monotonicity_11() { + assert_snapshot!(TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new(None)), + func: fn_avg_on_ordered(), + required_sort: vec![ + ("avg", true, false), + ("nullable_col", true, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[avg@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[avg: Ok(Field { name: "avg", data_type: Float64, nullable: true }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// =============================================REGION ENDS============================================= +// = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = +// ============================================REGION STARTS============================================ +// WindowAggExec + 
Plain(unbounded preceding, unbounded following) + partition_by + on unordered column +// Case 12: +#[test] +fn test_window_partial_constant_and_set_monotonicity_12() { + assert_snapshot!(TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new(None)), + func: fn_count_on_unordered(), + required_sort: vec![ + ("non_nullable_col", true, false), + ("count", true, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[non_nullable_col@1 ASC NULLS LAST, count@2 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64 }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 13: +#[test] +fn test_window_partial_constant_and_set_monotonicity_13() { + assert_snapshot!(TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new(None)), + func: fn_max_on_unordered(), + required_sort: vec![ + ("non_nullable_col", true, false), + ("max", false, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[non_nullable_col@1 ASC NULLS LAST, max@2 DESC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[max: Ok(Field { name: "max", data_type: Int32, nullable: true }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 14: +#[test] +fn test_window_partial_constant_and_set_monotonicity_14() { + assert_snapshot!(TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new(None)), + func: 
fn_min_on_unordered(), + required_sort: vec![ + ("min", false, false), + ("non_nullable_col", true, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[min@2 DESC NULLS LAST, non_nullable_col@1 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[min: Ok(Field { name: "min", data_type: Int32, nullable: true }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 15: +#[test] +fn test_window_partial_constant_and_set_monotonicity_15() { + assert_snapshot!(TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new(None)), + func: fn_avg_on_unordered(), + required_sort: vec![ + ("avg", true, false), + ("nullable_col", true, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[avg@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[avg: Ok(Field { name: "avg", data_type: Float64, nullable: true }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// =============================================REGION ENDS============================================= +// = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = +// ============================================REGION STARTS============================================ +// WindowAggExec + Sliding(current row, unbounded following) + no partition_by + on ordered column +// Case 16: +#[test] +fn 
test_window_partial_constant_and_set_monotonicity_16() { + assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), + func: fn_count_on_ordered(), + required_sort: vec![ + ("nullable_col", true, false), + ("count", false, false), + ], + }.run(), + @ r#" + Input Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 DESC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64 }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + + Optimized Plan: + WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64 }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 17: +#[test] +fn test_window_partial_constant_and_set_monotonicity_17() { + assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), + func: fn_max_on_ordered(), + required_sort: vec![ + ("max", false, true), + ("nullable_col", true, false), + ], + }.run(), + @ r#" + Input Plan: + SortExec: expr=[max@2 DESC, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[max: Ok(Field { name: "max", data_type: Int32, nullable: true }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + + Optimized 
Plan: + WindowAggExec: wdw=[max: Ok(Field { name: "max", data_type: Int32, nullable: true }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 18: +#[test] +fn test_window_partial_constant_and_set_monotonicity_18() { + assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), + func: fn_min_on_ordered(), + required_sort: vec![ + ("min", true, true), + ("nullable_col", true, false), + ], + }.run(), + @ r#" + Input Plan: + SortExec: expr=[min@2 ASC, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[min: Ok(Field { name: "min", data_type: Int32, nullable: true }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + + Optimized Plan: + WindowAggExec: wdw=[min: Ok(Field { name: "min", data_type: Int32, nullable: true }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 19: +#[test] +fn test_window_partial_constant_and_set_monotonicity_19() { + assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), + func: fn_avg_on_ordered(), + required_sort: vec![ + ("avg", false, false), + ("nullable_col", true, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[avg@2 DESC NULLS LAST, 
nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[avg: Ok(Field { name: "avg", data_type: Float64, nullable: true }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// =============================================REGION ENDS============================================= +// = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = +// ============================================REGION STARTS============================================ +// WindowAggExec + Sliding(current row, unbounded following) + no partition_by + on unordered column +// Case 20: +#[test] +fn test_window_partial_constant_and_set_monotonicity_20() { + assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), + func: fn_count_on_unordered(), + required_sort: vec![ + ("nullable_col", true, false), + ("count", true, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64 }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 21: +#[test] +fn test_window_partial_constant_and_set_monotonicity_21() { + assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), + func: fn_max_on_unordered(), + required_sort: vec![ + ("nullable_col", true, false), 
+ ("max", false, true), + ], + }.run(), + @ r#" + Input Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 DESC], preserve_partitioning=[false] + WindowAggExec: wdw=[max: Ok(Field { name: "max", data_type: Int32, nullable: true }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + + Optimized Plan: + WindowAggExec: wdw=[max: Ok(Field { name: "max", data_type: Int32, nullable: true }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 22: +#[test] +fn test_window_partial_constant_and_set_monotonicity_22() { + assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), + func: fn_min_on_unordered(), + required_sort: vec![ + ("min", true, false), + ("nullable_col", true, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[min@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[min: Ok(Field { name: "min", data_type: Int32, nullable: true }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 23: +#[test] +fn test_window_partial_constant_and_set_monotonicity_23() { + assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), + func: 
fn_avg_on_unordered(), + required_sort: vec![ + ("avg", false, false), + ("nullable_col", true, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[avg@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[avg: Ok(Field { name: "avg", data_type: Float64, nullable: true }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// =============================================REGION ENDS============================================= +// = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = +// ============================================REGION STARTS============================================ +// WindowAggExec + Sliding(current row, unbounded following) + partition_by + on ordered column +// Case 24: +#[test] +fn test_window_partial_constant_and_set_monotonicity_24() { + assert_snapshot!(TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), + func: fn_count_on_ordered(), + required_sort: vec![ + ("nullable_col", true, false), + ("count", false, false), + ], + }.run(), + @ r#" + Input Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 DESC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64 }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + + Optimized Plan: + WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64 }), frame: WindowFrame { units: Rows, 
start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 25: +#[test] +fn test_window_partial_constant_and_set_monotonicity_25() { + assert_snapshot!(TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), + func: fn_max_on_ordered(), + required_sort: vec![ + ("nullable_col", true, false), + ("max", true, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[max: Ok(Field { name: "max", data_type: Int32, nullable: true }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 26: +#[test] +fn test_window_partial_constant_and_set_monotonicity_26() { + assert_snapshot!(TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), + func: fn_min_on_ordered(), + required_sort: vec![ + ("min", false, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[min@2 DESC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[min: Ok(Field { name: "min", data_type: Int32, nullable: true }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "#); +} + +// Case 27: +#[test] +fn test_window_partial_constant_and_set_monotonicity_27() { + assert_snapshot!( + 
TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), + func: fn_avg_on_ordered(), + required_sort: vec![ + ("avg", false, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[avg@2 DESC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[avg: Ok(Field { name: "avg", data_type: Float64, nullable: true }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "#); +} + +// =============================================REGION ENDS============================================= +// = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = +// ============================================REGION STARTS============================================ +// WindowAggExec + Sliding(current row, unbounded following) + partition_by + on unordered column + +// Case 28: +#[test] +fn test_window_partial_constant_and_set_monotonicity_28() { + assert_snapshot!( + TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), + func: fn_count_on_unordered(), + required_sort: vec![ + ("count", false, false), + ("nullable_col", true, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[count@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64 }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 29: +#[test] +fn 
test_window_partial_constant_and_set_monotonicity_29() { + assert_snapshot!(TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), + func: fn_max_on_unordered(), + required_sort: vec![ + ("nullable_col", true, false), + ("max", false, true), + ], + }.run(), + @ r#" + Input Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 DESC], preserve_partitioning=[false] + WindowAggExec: wdw=[max: Ok(Field { name: "max", data_type: Int32, nullable: true }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + + Optimized Plan: + WindowAggExec: wdw=[max: Ok(Field { name: "max", data_type: Int32, nullable: true }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "#) +} + +// Case 30: +#[test] +fn test_window_partial_constant_and_set_monotonicity_30() { + assert_snapshot!(TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), + func: fn_min_on_unordered(), + required_sort: vec![ + ("min", false, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[min@2 DESC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[min: Ok(Field { name: "min", data_type: Int32, nullable: true }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "#); +} + +// Case 31: +#[test] +fn 
test_window_partial_constant_and_set_monotonicity_31() { + assert_snapshot!(TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), + func: fn_avg_on_unordered(), + required_sort: vec![ + ("nullable_col", true, false), + ("avg", true, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST, avg@2 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[avg: Ok(Field { name: "avg", data_type: Float64, nullable: true }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// =============================================REGION ENDS============================================= +// = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = +// ============================================REGION STARTS============================================ +// BoundedWindowAggExec + Plain(unbounded preceding, current row) + no partition_by + on ordered column + +// Case 32: +#[test] +fn test_window_partial_constant_and_set_monotonicity_32() { + assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: Arc::new(WindowFrame::new(Some(true))), + func: fn_count_on_ordered(), + required_sort: vec![ + ("nullable_col", true, false), + ("count", true, false), + ], + }.run(), + @ r#" + Input Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], 
file_type=parquet + + Optimized Plan: + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 33: +#[test] +fn test_window_partial_constant_and_set_monotonicity_33() { + assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: Arc::new(WindowFrame::new(Some(true))), + func: fn_max_on_ordered(), + required_sort: vec![ + ("max", false, false), + ("nullable_col", true, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[max@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[max: Field { "max": nullable Int32 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 34: +#[test] +fn test_window_partial_constant_and_set_monotonicity_34() { + assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: Arc::new(WindowFrame::new(Some(true))), + func: fn_min_on_ordered(), + required_sort: vec![ + ("min", false, false), + ("nullable_col", true, false), + ], + }.run(), + @ r#" + Input Plan: + SortExec: expr=[min@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[min: Field { "min": nullable Int32 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + + Optimized Plan: + BoundedWindowAggExec: wdw=[min: Field { "min": nullable Int32 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING 
AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} +// Case 35: +#[test] +fn test_window_partial_constant_and_set_monotonicity_35() { + assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: Arc::new(WindowFrame::new(Some(true))), + func: fn_avg_on_ordered(), + required_sort: vec![ + ("nullable_col", true, false), + ("avg", true, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST, avg@2 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[avg: Field { "avg": nullable Float64 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// =============================================REGION ENDS============================================= +// = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = +// ============================================REGION STARTS============================================ +// BoundedWindowAggExec + Plain(unbounded preceding, unbounded following) + no partition_by + on unordered column + +// Case 36: +#[test] +fn test_window_partial_constant_and_set_monotonicity_36() { + assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: Arc::new(WindowFrame::new(Some(true))), + func: fn_count_on_unordered(), + required_sort: vec![ + ("nullable_col", true, false), + ("count", true, true), + ], + }.run(), + @ r#" + Input Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 ASC], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + 
DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + + Optimized Plan: + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 37: +#[test] +fn test_window_partial_constant_and_set_monotonicity_37() { + assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: Arc::new(WindowFrame::new(Some(true))), + func: fn_max_on_unordered(), + required_sort: vec![ + ("max", true, false), + ("nullable_col", true, false), + ], + }.run(), + @ r#" + Input Plan: + SortExec: expr=[max@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[max: Field { "max": nullable Int32 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + + Optimized Plan: + BoundedWindowAggExec: wdw=[max: Field { "max": nullable Int32 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 38: +#[test] +fn test_window_partial_constant_and_set_monotonicity_38() { + assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: Arc::new(WindowFrame::new(Some(true))), + func: fn_min_on_unordered(), + required_sort: vec![ + ("min", false, true), + ("nullable_col", true, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[min@2 DESC, nullable_col@0 ASC NULLS LAST], 
preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[min: Field { "min": nullable Int32 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 39: +#[test] +fn test_window_partial_constant_and_set_monotonicity_39() { + assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: Arc::new(WindowFrame::new(Some(true))), + func: fn_avg_on_unordered(), + required_sort: vec![ + ("avg", true, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[avg@2 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[avg: Field { "avg": nullable Float64 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// =============================================REGION ENDS============================================= +// = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = +// ============================================REGION STARTS============================================ +// BoundedWindowAggExec + Plain(unbounded preceding, unbounded following) + partition_by + on ordered column + +// Case 40: +#[test] +fn test_window_partial_constant_and_set_monotonicity_40() { + assert_snapshot!(TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new(Some(true))), + func: fn_count_on_ordered(), + required_sort: vec![ + ("nullable_col", true, false), + ("count", true, false), + ], + }.run(), + @ r#" + Input Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 
}, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + + Optimized Plan: + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 41: +#[test] +fn test_window_partial_constant_and_set_monotonicity_41() { + assert_snapshot!(TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new(Some(true))), + func: fn_max_on_ordered(), + required_sort: vec![ + ("max", true, false), + ("nullable_col", true, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[max@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[max: Field { "max": nullable Int32 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 42: +#[test] +fn test_window_partial_constant_and_set_monotonicity_42() { + assert_snapshot!(TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new(Some(true))), + func: fn_min_on_ordered(), + required_sort: vec![ + ("min", false, false), + ("nullable_col", true, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[min@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[min: Field { "min": nullable Int32 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, 
projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 43: +#[test] +fn test_window_partial_constant_and_set_monotonicity_43() { + assert_snapshot!(TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new(Some(true))), + func: fn_avg_on_ordered(), + required_sort: vec![ + ("nullable_col", true, false), + ("avg", true, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST, avg@2 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[avg: Field { "avg": nullable Float64 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// =============================================REGION ENDS============================================= +// = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = +// ============================================REGION STARTS============================================ +// BoundedWindowAggExec + Plain(unbounded preceding, unbounded following) + partition_by + on unordered column + +// Case 44: +#[test] +fn test_window_partial_constant_and_set_monotonicity_44() { + assert_snapshot!(TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new(Some(true))), + func: fn_count_on_unordered(), + required_sort: vec![ + ("count", true, true), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[count@2 ASC], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], 
file_type=parquet + "# + ); +} + +// Case 45: +#[test] +fn test_window_partial_constant_and_set_monotonicity_45() { + assert_snapshot!(TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new(Some(true))), + func: fn_max_on_unordered(), + required_sort: vec![ + ("nullable_col", true, false), + ("max", false, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 DESC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[max: Field { "max": nullable Int32 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 46: +#[test] +fn test_window_partial_constant_and_set_monotonicity_46() { + assert_snapshot!(TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new(Some(true))), + func: fn_min_on_unordered(), + required_sort: vec![ + ("nullable_col", true, false), + ("min", false, false), + ], + }.run(), + @ r#" + Input Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST, min@2 DESC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[min: Field { "min": nullable Int32 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + + Optimized Plan: + BoundedWindowAggExec: wdw=[min: Field { "min": nullable Int32 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 47: +#[test] +fn test_window_partial_constant_and_set_monotonicity_47() { + 
assert_snapshot!(TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new(Some(true))), + func: fn_avg_on_unordered(), + required_sort: vec![ + ("nullable_col", true, false), + ], + }.run(), + @ r#" + Input Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[avg: Field { "avg": nullable Float64 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + + Optimized Plan: + BoundedWindowAggExec: wdw=[avg: Field { "avg": nullable Float64 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// =============================================REGION ENDS============================================= +// = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = +// ============================================REGION STARTS============================================ +// BoundedWindowAggExec + Sliding(bounded preceding, bounded following) + no partition_by + on ordered column + +// Case 48: +#[test] +fn test_window_partial_constant_and_set_monotonicity_48() { + assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32).unwrap()), WindowFrameBound::CurrentRow)), + func: fn_count_on_ordered(), + required_sort: vec![ + ("count", true, false), + ("nullable_col", true, false), + ], + }.run(), + @ r#" + Input Plan: + SortExec: expr=[count@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: 
wdw=[count: Field { "count": Int64 }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + + Optimized Plan: + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 49: +#[test] +fn test_window_partial_constant_and_set_monotonicity_49() { + assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32).unwrap()), WindowFrameBound::Following(ScalarValue::new_one(&DataType::UInt32).unwrap()))), + func: fn_max_on_ordered(), + required_sort: vec![ + ("max", true, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[max@2 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[max: Field { "max": nullable Int32 }, frame: ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 50: +#[test] +fn test_window_partial_constant_and_set_monotonicity_50() { + assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32).unwrap()), WindowFrameBound::CurrentRow)), + func: fn_min_on_ordered(), + required_sort: vec![ + ("nullable_col", true, false), + ("min", false, false), + ], + }.run(), + @ r#" + Input Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST, min@2 DESC 
NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[min: Field { "min": nullable Int32 }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + + Optimized Plan: + BoundedWindowAggExec: wdw=[min: Field { "min": nullable Int32 }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 51: +#[test] +fn test_window_partial_constant_and_set_monotonicity_51() { + assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32).unwrap()), WindowFrameBound::CurrentRow)), + func: fn_avg_on_ordered(), + required_sort: vec![ + ("avg", true, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[avg@2 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[avg: Field { "avg": nullable Float64 }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// =============================================REGION ENDS============================================= +// = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = +// ============================================REGION STARTS============================================ +// BoundedWindowAggExec + Sliding(bounded preceding, bounded following) + no partition_by + on unordered column + +// Case 52: +#[test] +fn 
test_window_partial_constant_and_set_monotonicity_52() { + assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32).unwrap()), WindowFrameBound::Following(ScalarValue::new_one(&DataType::UInt32).unwrap()))), + func: fn_count_on_unordered(), + required_sort: vec![ + ("count", true, false), + ("nullable_col", true, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[count@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 53: +#[test] +fn test_window_partial_constant_and_set_monotonicity_53() { + assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32).unwrap()), WindowFrameBound::CurrentRow)), + func: fn_max_on_unordered(), + required_sort: vec![ + ("nullable_col", true, false), + ("max", true, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[max: Field { "max": nullable Int32 }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 54: +#[test] +fn test_window_partial_constant_and_set_monotonicity_54() { + assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: 
Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32).unwrap()), WindowFrameBound::CurrentRow)), + func: fn_min_on_unordered(), + required_sort: vec![ + ("min", true, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[min@2 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[min: Field { "min": nullable Int32 }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 55: +#[test] +fn test_window_partial_constant_and_set_monotonicity_55() { + assert_snapshot!(TestWindowCase { + partition_by: false, + window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32).unwrap()), WindowFrameBound::Following(ScalarValue::new_one(&DataType::UInt32).unwrap()))), + func: fn_avg_on_unordered(), + required_sort: vec![ + ("nullable_col", true, false), + ], + }.run(), + @ r#" + Input Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[avg: Field { "avg": nullable Float64 }, frame: ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + + Optimized Plan: + BoundedWindowAggExec: wdw=[avg: Field { "avg": nullable Float64 }, frame: ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// =============================================REGION ENDS============================================= +// = = = = = = = = 
= = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = +// ============================================REGION STARTS============================================ +// BoundedWindowAggExec + Sliding(bounded preceding, bounded following) + partition_by + on ordered column + +// Case 56: +#[test] +fn test_window_partial_constant_and_set_monotonicity_56() { + assert_snapshot!(TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32).unwrap()), WindowFrameBound::CurrentRow)), + func: fn_count_on_ordered(), + required_sort: vec![ + ("count", true, false), + ("nullable_col", true, false), + ], + }.run(), + @ r#" + Input Plan: + SortExec: expr=[count@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + + Optimized Plan: + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 57: +#[test] +fn test_window_partial_constant_and_set_monotonicity_57() { + assert_snapshot!(TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32).unwrap()), WindowFrameBound::Following(ScalarValue::new_one(&DataType::UInt32).unwrap()))), + func: fn_max_on_ordered(), + required_sort: vec![ + ("nullable_col", true, false), + ("max", true, false), + ], + }.run(), + @ r#" + 
Input / Optimized Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[max: Field { "max": nullable Int32 }, frame: ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 58: +#[test] +fn test_window_partial_constant_and_set_monotonicity_58() { + assert_snapshot!(TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32).unwrap()), WindowFrameBound::CurrentRow)), + func: fn_min_on_ordered(), + required_sort: vec![ + ("min", false, false), + ("nullable_col", true, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[min@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[min: Field { "min": nullable Int32 }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 59: +#[test] +fn test_window_partial_constant_and_set_monotonicity_59() { + assert_snapshot!(TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32).unwrap()), WindowFrameBound::CurrentRow)), + func: fn_avg_on_ordered(), + required_sort: vec![ + ("avg", true, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[avg@2 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[avg: Field { "avg": nullable Float64 }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: 
file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// =============================================REGION ENDS============================================= +// = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = +// ============================================REGION STARTS============================================ +// BoundedWindowAggExec + Sliding(bounded preceding, bounded following) + partition_by + on unordered column + +// Case 60: +#[test] +fn test_window_partial_constant_and_set_monotonicity_60() { + assert_snapshot!(TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32).unwrap()), WindowFrameBound::CurrentRow)), + func: fn_count_on_unordered(), + required_sort: vec![ + ("nullable_col", true, false), + ("count", true, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 61: +#[test] +fn test_window_partial_constant_and_set_monotonicity_61() { + assert_snapshot!(TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32).unwrap()), WindowFrameBound::CurrentRow)), + func: fn_max_on_unordered(), + required_sort: vec![ + ("nullable_col", true, false), + ("max", true, true), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[nullable_col@0 
ASC NULLS LAST, max@2 ASC], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[max: Field { "max": nullable Int32 }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 62: +#[test] +fn test_window_partial_constant_and_set_monotonicity_62() { + assert_snapshot!(TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32).unwrap()), WindowFrameBound::CurrentRow)), + func: fn_min_on_unordered(), + required_sort: vec![ + ("nullable_col", true, false), + ("min", false, false), + ], + }.run(), + @ r#" + Input / Optimized Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST, min@2 DESC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[min: Field { "min": nullable Int32 }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} + +// Case 63: +#[test] +fn test_window_partial_constant_and_set_monotonicity_63() { + assert_snapshot!(TestWindowCase { + partition_by: true, + window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32).unwrap()), WindowFrameBound::CurrentRow)), + func: fn_avg_on_unordered(), + required_sort: vec![ + ("nullable_col", true, false), + ], + }.run(), + @ r#" + Input Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[avg: Field { "avg": nullable Float64 }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, 
non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + + Optimized Plan: + BoundedWindowAggExec: wdw=[avg: Field { "avg": nullable Float64 }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + ); +} +// =============================================REGION ENDS============================================= diff --git a/datafusion/core/tests/physical_optimizer/filter_pushdown.rs b/datafusion/core/tests/physical_optimizer/filter_pushdown.rs new file mode 100644 index 0000000000000..8f430f7753ef6 --- /dev/null +++ b/datafusion/core/tests/physical_optimizer/filter_pushdown.rs @@ -0,0 +1,4463 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::sync::{Arc, LazyLock}; + +use arrow::{ + array::{Float64Array, Int32Array, RecordBatch, StringArray, record_batch}, + datatypes::{DataType, Field, Schema, SchemaRef}, + util::pretty::pretty_format_batches, +}; +use arrow_schema::SortOptions; +use datafusion::{ + assert_batches_eq, + logical_expr::Operator, + physical_plan::{ + PhysicalExpr, + expressions::{BinaryExpr, Column, Literal}, + }, + prelude::{ParquetReadOptions, SessionConfig, SessionContext}, + scalar::ScalarValue, +}; +use datafusion_catalog::memory::DataSourceExec; +use datafusion_common::config::ConfigOptions; +use datafusion_datasource::{ + PartitionedFile, file_groups::FileGroup, file_scan_config::FileScanConfigBuilder, +}; +use datafusion_execution::object_store::ObjectStoreUrl; +use datafusion_expr::ScalarUDF; +use datafusion_functions::math::random::RandomFunc; +use datafusion_functions_aggregate::{ + count::count_udaf, + min_max::{max_udaf, min_udaf}, +}; +use datafusion_physical_expr::{LexOrdering, PhysicalSortExpr, expressions::col}; +use datafusion_physical_expr::{ + Partitioning, ScalarFunctionExpr, + aggregate::{AggregateExprBuilder, AggregateFunctionExpr}, +}; +use datafusion_physical_optimizer::{ + PhysicalOptimizerRule, filter_pushdown::FilterPushdown, +}; +use datafusion_physical_plan::{ + ExecutionPlan, + aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}, + coalesce_partitions::CoalescePartitionsExec, + collect, + filter::{FilterExec, FilterExecBuilder}, + projection::ProjectionExec, + repartition::RepartitionExec, + sorts::sort::SortExec, +}; + +use super::pushdown_utils::{ + OptimizationTest, TestNode, TestScanBuilder, TestSource, format_plan_for_test, +}; +use datafusion_physical_plan::union::UnionExec; +use futures::StreamExt; +use object_store::{ObjectStore, memory::InMemory}; +use regex::Regex; + +#[test] +fn test_pushdown_into_scan() { + let scan = TestScanBuilder::new(schema()).with_support(true).build(); + let predicate = col_lit_predicate("a", "foo", 
&schema()); + let plan = Arc::new(FilterExec::try_new(predicate, scan).unwrap()); + + // expect the predicate to be pushed down into the DataSource + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown::new(), true), + @r" + OptimizationTest: + input: + - FilterExec: a@0 = foo + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo + " + ); +} + +#[test] +fn test_pushdown_volatile_functions_not_allowed() { + // Test that we do not push down filters with volatile functions + // Use random() as an example of a volatile function + let scan = TestScanBuilder::new(schema()).with_support(true).build(); + let cfg = Arc::new(ConfigOptions::default()); + let predicate = Arc::new(BinaryExpr::new( + Arc::new(Column::new_with_schema("a", &schema()).unwrap()), + Operator::Eq, + Arc::new( + ScalarFunctionExpr::try_new( + Arc::new(ScalarUDF::from(RandomFunc::new())), + vec![], + &schema(), + cfg, + ) + .unwrap(), + ), + )) as Arc; + let plan = Arc::new(FilterExec::try_new(predicate, scan).unwrap()); + // expect the filter to not be pushed down + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown::new(), true), + @r" + OptimizationTest: + input: + - FilterExec: a@0 = random() + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - FilterExec: a@0 = random() + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + ", + ); +} + +/// Show that we can use config options to determine how to do pushdown. 
+#[test] +fn test_pushdown_into_scan_with_config_options() { + let scan = TestScanBuilder::new(schema()).with_support(true).build(); + let predicate = col_lit_predicate("a", "foo", &schema()); + let plan = Arc::new(FilterExec::try_new(predicate, scan).unwrap()) as _; + + let mut cfg = ConfigOptions::default(); + insta::assert_snapshot!( + OptimizationTest::new( + Arc::clone(&plan), + FilterPushdown::new(), + false + ), + @r" + OptimizationTest: + input: + - FilterExec: a@0 = foo + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - FilterExec: a@0 = foo + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + " + ); + + cfg.execution.parquet.pushdown_filters = true; + insta::assert_snapshot!( + OptimizationTest::new( + plan, + FilterPushdown::new(), + true + ), + @r" + OptimizationTest: + input: + - FilterExec: a@0 = foo + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo + " + ); +} + +#[tokio::test] +async fn test_dynamic_filter_pushdown_through_hash_join_with_topk() { + use datafusion_common::JoinType; + use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; + + // Create build side with limited values + let build_batches = vec![ + record_batch!( + ("a", Utf8, ["aa", "ab"]), + ("b", Utf8View, ["ba", "bb"]), + ("c", Float64, [1.0, 2.0]) + ) + .unwrap(), + ]; + let build_side_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8View, false), + Field::new("c", DataType::Float64, false), + ])); + let build_scan = TestScanBuilder::new(Arc::clone(&build_side_schema)) + .with_support(true) + .with_batches(build_batches) 
+ .build(); + + // Create probe side with more values + let probe_batches = vec![ + record_batch!( + ("d", Utf8, ["aa", "ab", "ac", "ad"]), + ("e", Utf8View, ["ba", "bb", "bc", "bd"]), + ("f", Float64, [1.0, 2.0, 3.0, 4.0]) + ) + .unwrap(), + ]; + let probe_side_schema = Arc::new(Schema::new(vec![ + Field::new("d", DataType::Utf8, false), + Field::new("e", DataType::Utf8View, false), + Field::new("f", DataType::Float64, false), + ])); + let probe_scan = TestScanBuilder::new(Arc::clone(&probe_side_schema)) + .with_support(true) + .with_batches(probe_batches) + .build(); + + // Create HashJoinExec + let on = vec![( + col("a", &build_side_schema).unwrap(), + col("d", &probe_side_schema).unwrap(), + )]; + let join = Arc::new( + HashJoinExec::try_new( + build_scan, + probe_scan, + on, + None, + &JoinType::Inner, + None, + PartitionMode::Partitioned, + datafusion_common::NullEquality::NullEqualsNothing, + false, + ) + .unwrap(), + ); + + let join_schema = join.schema(); + + // Finally let's add a SortExec on the outside to test pushdown of dynamic filters + let sort_expr = + PhysicalSortExpr::new(col("e", &join_schema).unwrap(), SortOptions::default()); + let plan = Arc::new( + SortExec::new(LexOrdering::new(vec![sort_expr]).unwrap(), join) + .with_fetch(Some(2)), + ) as Arc; + + let mut config = ConfigOptions::default(); + config.optimizer.enable_dynamic_filter_pushdown = true; + config.execution.parquet.pushdown_filters = true; + + // Apply the FilterPushdown optimizer rule + let plan = FilterPushdown::new_post_optimization() + .optimize(Arc::clone(&plan), &config) + .unwrap(); + + // Test that filters are pushed down correctly to each side of the join + insta::assert_snapshot!( + format_plan_for_test(&plan), + @r" + - SortExec: TopK(fetch=2), expr=[e@4 ASC], preserve_partitioning=[false] + - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, d@0)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, 
pushdown_supported=true + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, e, f], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ] AND DynamicFilter [ empty ] + " + ); + + // Put some data through the plan to check that the filter is updated to reflect the TopK state + let session_ctx = SessionContext::new_with_config(SessionConfig::new()); + session_ctx.register_object_store( + ObjectStoreUrl::parse("test://").unwrap().as_ref(), + Arc::new(InMemory::new()), + ); + let state = session_ctx.state(); + let task_ctx = state.task_ctx(); + let mut stream = plan.execute(0, Arc::clone(&task_ctx)).unwrap(); + // Iterate one batch + stream.next().await.unwrap().unwrap(); + + // Test that filters are pushed down correctly to each side of the join + // NOTE: We dropped the CASE expression here because we now optimize that away if there's only 1 partition + insta::assert_snapshot!( + format_plan_for_test(&plan), + @r" + - SortExec: TopK(fetch=2), expr=[e@4 ASC], preserve_partitioning=[false], filter=[e@4 IS NULL OR e@4 < bb] + - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, d@0)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, e, f], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ d@0 >= aa AND d@0 <= ab AND d@0 IN (SET) ([aa, ab]) ] AND DynamicFilter [ e@1 IS NULL OR e@1 < bb ] + " + ); +} + +// Test both static and dynamic filter pushdown in HashJoinExec. +// Note that static filter pushdown is rare: it should have already happened in the logical optimizer phase. +// However users may manually construct plans that could result in a FilterExec -> HashJoinExec -> Scan setup. +// Dynamic filters arise in cases such as nested inner joins or TopK -> HashJoinExec -> Scan setups. 
+#[tokio::test] +async fn test_static_filter_pushdown_through_hash_join() { + use datafusion_common::JoinType; + use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; + + // Create build side with limited values + let build_batches = vec![ + record_batch!( + ("a", Utf8, ["aa", "ab"]), + ("b", Utf8View, ["ba", "bb"]), + ("c", Float64, [1.0, 2.0]) + ) + .unwrap(), + ]; + let build_side_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8View, false), + Field::new("c", DataType::Float64, false), + ])); + let build_scan = TestScanBuilder::new(Arc::clone(&build_side_schema)) + .with_support(true) + .with_batches(build_batches) + .build(); + + // Create probe side with more values + let probe_batches = vec![ + record_batch!( + ("d", Utf8, ["aa", "ab", "ac", "ad"]), + ("e", Utf8View, ["ba", "bb", "bc", "bd"]), + ("f", Float64, [1.0, 2.0, 3.0, 4.0]) + ) + .unwrap(), + ]; + let probe_side_schema = Arc::new(Schema::new(vec![ + Field::new("d", DataType::Utf8, false), + Field::new("e", DataType::Utf8View, false), + Field::new("f", DataType::Float64, false), + ])); + let probe_scan = TestScanBuilder::new(Arc::clone(&probe_side_schema)) + .with_support(true) + .with_batches(probe_batches) + .build(); + + // Create HashJoinExec + let on = vec![( + col("a", &build_side_schema).unwrap(), + col("d", &probe_side_schema).unwrap(), + )]; + let join = Arc::new( + HashJoinExec::try_new( + build_scan, + probe_scan, + on, + None, + &JoinType::Inner, + None, + PartitionMode::Partitioned, + datafusion_common::NullEquality::NullEqualsNothing, + false, + ) + .unwrap(), + ); + + // Create filters that can be pushed down to different sides + // We need to create filters in the context of the join output schema + let join_schema = join.schema(); + + // Filter on build side column: a = 'aa' + let left_filter = col_lit_predicate("a", "aa", &join_schema); + // Filter on probe side column: e = 'ba' + let right_filter = 
col_lit_predicate("e", "ba", &join_schema); + // Filter that references both sides: a = d (should not be pushed down) + let cross_filter = Arc::new(BinaryExpr::new( + col("a", &join_schema).unwrap(), + Operator::Eq, + col("d", &join_schema).unwrap(), + )) as Arc; + + let filter = + Arc::new(FilterExec::try_new(left_filter, Arc::clone(&join) as _).unwrap()); + let filter = Arc::new(FilterExec::try_new(right_filter, filter).unwrap()); + let plan = Arc::new(FilterExec::try_new(cross_filter, filter).unwrap()) + as Arc; + + // Test that filters are pushed down correctly to each side of the join + insta::assert_snapshot!( + OptimizationTest::new(Arc::clone(&plan), FilterPushdown::new(), true), + @r" + OptimizationTest: + input: + - FilterExec: a@0 = d@3 + - FilterExec: e@4 = ba + - FilterExec: a@0 = aa + - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, d@0)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, e, f], file_type=test, pushdown_supported=true + output: + Ok: + - FilterExec: a@0 = d@3 + - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, d@0)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = aa + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, e, f], file_type=test, pushdown_supported=true, predicate=e@1 = ba + " + ); + + // Test left join: filter on preserved (build) side is pushed down, + // filter on non-preserved (probe) side is NOT pushed down. 
+ let join = Arc::new( + HashJoinExec::try_new( + TestScanBuilder::new(Arc::clone(&build_side_schema)) + .with_support(true) + .build(), + TestScanBuilder::new(Arc::clone(&probe_side_schema)) + .with_support(true) + .build(), + vec![( + col("a", &build_side_schema).unwrap(), + col("d", &probe_side_schema).unwrap(), + )], + None, + &JoinType::Left, + None, + PartitionMode::Partitioned, + datafusion_common::NullEquality::NullEqualsNothing, + false, + ) + .unwrap(), + ); + + let join_schema = join.schema(); + // Filter on build side column (preserved): should be pushed down + let left_filter = col_lit_predicate("a", "aa", &join_schema); + // Filter on probe side column (not preserved): should NOT be pushed down + let right_filter = col_lit_predicate("e", "ba", &join_schema); + let filter = + Arc::new(FilterExec::try_new(left_filter, Arc::clone(&join) as _).unwrap()); + let plan = Arc::new(FilterExec::try_new(right_filter, filter).unwrap()) + as Arc; + + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown::new(), true), + @r" + OptimizationTest: + input: + - FilterExec: e@4 = ba + - FilterExec: a@0 = aa + - HashJoinExec: mode=Partitioned, join_type=Left, on=[(a@0, d@0)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, e, f], file_type=test, pushdown_supported=true + output: + Ok: + - FilterExec: e@4 = ba + - HashJoinExec: mode=Partitioned, join_type=Left, on=[(a@0, d@0)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = aa + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, e, f], file_type=test, pushdown_supported=true + " + ); +} + +#[test] +fn test_filter_collapse() { + // filter should be pushed down into the parquet scan with two filters + let scan = 
TestScanBuilder::new(schema()).with_support(true).build(); + let predicate1 = col_lit_predicate("a", "foo", &schema()); + let filter1 = Arc::new(FilterExec::try_new(predicate1, scan).unwrap()); + let predicate2 = col_lit_predicate("b", "bar", &schema()); + let plan = Arc::new(FilterExec::try_new(predicate2, filter1).unwrap()); + + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown::new(), true), + @r" + OptimizationTest: + input: + - FilterExec: b@1 = bar + - FilterExec: a@0 = foo + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo AND b@1 = bar + " + ); +} + +#[test] +fn test_filter_with_projection() { + let scan = TestScanBuilder::new(schema()).with_support(true).build(); + let projection = vec![1, 0]; + let predicate = col_lit_predicate("a", "foo", &schema()); + let plan = Arc::new( + FilterExecBuilder::new(predicate, Arc::clone(&scan)) + .apply_projection(Some(projection)) + .unwrap() + .build() + .unwrap(), + ); + + // expect the predicate to be pushed down into the DataSource but the FilterExec to be converted to ProjectionExec + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown::new(), true), + @r" + OptimizationTest: + input: + - FilterExec: a@0 = foo, projection=[b@1, a@0] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - ProjectionExec: expr=[b@1 as b, a@0 as a] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo + ", + ); + + // add a test where the filter is on a column that isn't included in the output + let projection = vec![1]; + let predicate = col_lit_predicate("a", "foo", &schema()); + let plan = 
Arc::new( + FilterExecBuilder::new(predicate, scan) + .apply_projection(Some(projection)) + .unwrap() + .build() + .unwrap(), + ); + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown::new(),true), + @r" + OptimizationTest: + input: + - FilterExec: a@0 = foo, projection=[b@1] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - ProjectionExec: expr=[b@1 as b] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo + " + ); +} + +#[test] +fn test_push_down_through_transparent_nodes() { + // expect the predicate to be pushed down into the DataSource + let scan = TestScanBuilder::new(schema()).with_support(true).build(); + let predicate = col_lit_predicate("a", "foo", &schema()); + let filter = Arc::new(FilterExec::try_new(predicate, scan).unwrap()); + let repartition = Arc::new( + RepartitionExec::try_new(filter, Partitioning::RoundRobinBatch(1)).unwrap(), + ); + let predicate = col_lit_predicate("b", "bar", &schema()); + let plan = Arc::new(FilterExec::try_new(predicate, repartition).unwrap()); + + // expect the predicate to be pushed down into the DataSource + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown::new(),true), + @r" + OptimizationTest: + input: + - FilterExec: b@1 = bar + - RepartitionExec: partitioning=RoundRobinBatch(1), input_partitions=1 + - FilterExec: a@0 = foo + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - RepartitionExec: partitioning=RoundRobinBatch(1), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo AND b@1 = bar + " + ); +} + +#[test] +fn test_pushdown_through_aggregates_on_grouping_columns() { + // Test that filters on 
grouping columns can be pushed through AggregateExec. + // This test has two filters: + // 1. An inner filter (a@0 = foo) below the aggregate - gets pushed to DataSource + // 2. An outer filter (b@1 = bar) above the aggregate - also gets pushed through because 'b' is a grouping column + let scan = TestScanBuilder::new(schema()).with_support(true).build(); + + let filter = Arc::new( + FilterExecBuilder::new(col_lit_predicate("a", "foo", &schema()), scan) + .with_batch_size(10) + .build() + .unwrap(), + ); + + let aggregate_expr = vec![ + AggregateExprBuilder::new(count_udaf(), vec![col("a", &schema()).unwrap()]) + .schema(schema()) + .alias("cnt") + .build() + .map(Arc::new) + .unwrap(), + ]; + let group_by = PhysicalGroupBy::new_single(vec![ + (col("a", &schema()).unwrap(), "a".to_string()), + (col("b", &schema()).unwrap(), "b".to_string()), + ]); + let aggregate = Arc::new( + AggregateExec::try_new( + AggregateMode::Final, + group_by, + aggregate_expr.clone(), + vec![None], + filter, + schema(), + ) + .unwrap(), + ); + + let predicate = col_lit_predicate("b", "bar", &schema()); + let plan = Arc::new( + FilterExecBuilder::new(predicate, aggregate) + .with_batch_size(100) + .build() + .unwrap(), + ); + + // Both filters should be pushed down to the DataSource since both reference grouping columns + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown::new(), true), + @r" + OptimizationTest: + input: + - FilterExec: b@1 = bar + - AggregateExec: mode=Final, gby=[a@0 as a, b@1 as b], aggr=[cnt], ordering_mode=PartiallySorted([0]) + - FilterExec: a@0 = foo + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - AggregateExec: mode=Final, gby=[a@0 as a, b@1 as b], aggr=[cnt], ordering_mode=Sorted + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo AND b@1 = bar + " + ); +} + 
+/// Test various combinations of handling of child pushdown results +/// in an ExecutionPlan in combination with support/not support in a DataSource. +#[test] +fn test_node_handles_child_pushdown_result() { + // If we set `with_support(true)` + `inject_filter = true` then the filter is pushed down to the DataSource + // and no FilterExec is created. + let scan = TestScanBuilder::new(schema()).with_support(true).build(); + let predicate = col_lit_predicate("a", "foo", &schema()); + let plan = Arc::new(TestNode::new(true, Arc::clone(&scan), predicate)); + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown::new(), true), + @r" + OptimizationTest: + input: + - TestInsertExec { inject_filter: true } + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - TestInsertExec { inject_filter: true } + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo + ", + ); + + // If we set `with_support(false)` + `inject_filter = true` then the filter is not pushed down to the DataSource + // and a FilterExec is created. 
+ let scan = TestScanBuilder::new(schema()).with_support(false).build(); + let predicate = col_lit_predicate("a", "foo", &schema()); + let plan = Arc::new(TestNode::new(true, Arc::clone(&scan), predicate)); + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown::new(), true), + @r" + OptimizationTest: + input: + - TestInsertExec { inject_filter: true } + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false + output: + Ok: + - TestInsertExec { inject_filter: false } + - FilterExec: a@0 = foo + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false + ", + ); + + // If we set `with_support(false)` + `inject_filter = false` then the filter is not pushed down to the DataSource + // and no FilterExec is created. + let scan = TestScanBuilder::new(schema()).with_support(false).build(); + let predicate = col_lit_predicate("a", "foo", &schema()); + let plan = Arc::new(TestNode::new(false, Arc::clone(&scan), predicate)); + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown::new(), true), + @r" + OptimizationTest: + input: + - TestInsertExec { inject_filter: false } + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false + output: + Ok: + - TestInsertExec { inject_filter: false } + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false + ", + ); +} + +#[tokio::test] +async fn test_topk_dynamic_filter_pushdown() { + let batches = vec![ + record_batch!( + ("a", Utf8, ["aa", "ab"]), + ("b", Utf8, ["bd", "bc"]), + ("c", Float64, [1.0, 2.0]) + ) + .unwrap(), + record_batch!( + ("a", Utf8, ["ac", "ad"]), + ("b", Utf8, ["bb", "ba"]), + ("c", Float64, [2.0, 1.0]) + ) + .unwrap(), + ]; + let scan = TestScanBuilder::new(schema()) + .with_support(true) + .with_batches(batches) + 
.build(); + let plan = Arc::new( + SortExec::new( + LexOrdering::new(vec![PhysicalSortExpr::new( + col("b", &schema()).unwrap(), + SortOptions::new(true, false), // descending, nulls_first + )]) + .unwrap(), + Arc::clone(&scan), + ) + .with_fetch(Some(1)), + ) as Arc; + + // expect the predicate to be pushed down into the DataSource + insta::assert_snapshot!( + OptimizationTest::new(Arc::clone(&plan), FilterPushdown::new_post_optimization(), true), + @r" + OptimizationTest: + input: + - SortExec: TopK(fetch=1), expr=[b@1 DESC NULLS LAST], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - SortExec: TopK(fetch=1), expr=[b@1 DESC NULLS LAST], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ] + " + ); + + // Actually apply the optimization to the plan and put some data through it to check that the filter is updated to reflect the TopK state + let mut config = ConfigOptions::default(); + config.execution.parquet.pushdown_filters = true; + let plan = FilterPushdown::new_post_optimization() + .optimize(plan, &config) + .unwrap(); + let config = SessionConfig::new().with_batch_size(2); + let session_ctx = SessionContext::new_with_config(config); + session_ctx.register_object_store( + ObjectStoreUrl::parse("test://").unwrap().as_ref(), + Arc::new(InMemory::new()), + ); + let state = session_ctx.state(); + let task_ctx = state.task_ctx(); + let mut stream = plan.execute(0, Arc::clone(&task_ctx)).unwrap(); + // Iterate one batch + stream.next().await.unwrap().unwrap(); + // Now check what our filter looks like + insta::assert_snapshot!( + format!("{}", format_plan_for_test(&plan)), + @r" + - SortExec: TopK(fetch=1), expr=[b@1 DESC NULLS LAST], preserve_partitioning=[false], filter=[b@1 > bd] + - DataSourceExec: 
file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ b@1 > bd ] + " + ); +} + +#[tokio::test] +async fn test_topk_dynamic_filter_pushdown_multi_column_sort() { + let batches = vec![ + // We are going to do ORDER BY b ASC NULLS LAST, a DESC + // And we put the values in such a way that the first batch will fill the TopK + // and we skip the second batch. + record_batch!( + ("a", Utf8, ["ac", "ad"]), + ("b", Utf8, ["bb", "ba"]), + ("c", Float64, [2.0, 1.0]) + ) + .unwrap(), + record_batch!( + ("a", Utf8, ["aa", "ab"]), + ("b", Utf8, ["bc", "bd"]), + ("c", Float64, [1.0, 2.0]) + ) + .unwrap(), + ]; + let scan = TestScanBuilder::new(schema()) + .with_support(true) + .with_batches(batches) + .build(); + let plan = Arc::new( + SortExec::new( + LexOrdering::new(vec![ + PhysicalSortExpr::new( + col("b", &schema()).unwrap(), + SortOptions::default().asc().nulls_last(), + ), + PhysicalSortExpr::new( + col("a", &schema()).unwrap(), + SortOptions::default().desc().nulls_first(), + ), + ]) + .unwrap(), + Arc::clone(&scan), + ) + .with_fetch(Some(2)), + ) as Arc; + + // expect the predicate to be pushed down into the DataSource + insta::assert_snapshot!( + OptimizationTest::new(Arc::clone(&plan), FilterPushdown::new_post_optimization(), true), + @r" + OptimizationTest: + input: + - SortExec: TopK(fetch=2), expr=[b@1 ASC NULLS LAST, a@0 DESC], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - SortExec: TopK(fetch=2), expr=[b@1 ASC NULLS LAST, a@0 DESC], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ] + " + ); + + // Actually apply the optimization to the plan and put some data through it to check that the filter is updated to reflect the TopK 
state + let mut config = ConfigOptions::default(); + config.execution.parquet.pushdown_filters = true; + let plan = FilterPushdown::new_post_optimization() + .optimize(plan, &config) + .unwrap(); + let config = SessionConfig::new().with_batch_size(2); + let session_ctx = SessionContext::new_with_config(config); + session_ctx.register_object_store( + ObjectStoreUrl::parse("test://").unwrap().as_ref(), + Arc::new(InMemory::new()), + ); + let state = session_ctx.state(); + let task_ctx = state.task_ctx(); + let mut stream = plan.execute(0, Arc::clone(&task_ctx)).unwrap(); + // Iterate one batch + let res = stream.next().await.unwrap().unwrap(); + #[rustfmt::skip] + let expected = [ + "+----+----+-----+", + "| a | b | c |", + "+----+----+-----+", + "| ad | ba | 1.0 |", + "| ac | bb | 2.0 |", + "+----+----+-----+", + ]; + assert_batches_eq!(expected, &[res]); + // Now check what our filter looks like + insta::assert_snapshot!( + format!("{}", format_plan_for_test(&plan)), + @r" + - SortExec: TopK(fetch=2), expr=[b@1 ASC NULLS LAST, a@0 DESC], preserve_partitioning=[false], filter=[b@1 < bb OR b@1 = bb AND (a@0 IS NULL OR a@0 > ac)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ b@1 < bb OR b@1 = bb AND (a@0 IS NULL OR a@0 > ac) ] + " + ); + // There should be no more batches + assert!(stream.next().await.is_none()); +} + +#[tokio::test] +async fn test_topk_filter_passes_through_coalesce_partitions() { + // Create multiple batches for different partitions + let batches = vec![ + record_batch!( + ("a", Utf8, ["aa", "ab"]), + ("b", Utf8, ["bd", "bc"]), + ("c", Float64, [1.0, 2.0]) + ) + .unwrap(), + record_batch!( + ("a", Utf8, ["ac", "ad"]), + ("b", Utf8, ["bb", "ba"]), + ("c", Float64, [2.0, 1.0]) + ) + .unwrap(), + ]; + + // Create a source that supports all batches + let source = Arc::new(TestSource::new(schema(), true, batches)); + + let base_config = + 
FileScanConfigBuilder::new(ObjectStoreUrl::parse("test://").unwrap(), source) + .with_file_groups(vec![ + // Partition 0 + FileGroup::new(vec![PartitionedFile::new("test1.parquet", 123)]), + // Partition 1 + FileGroup::new(vec![PartitionedFile::new("test2.parquet", 123)]), + ]) + .build(); + + let scan = DataSourceExec::from_data_source(base_config); + + // Add CoalescePartitionsExec to merge the two partitions + let coalesce = Arc::new(CoalescePartitionsExec::new(scan)) as Arc; + + // Add SortExec with TopK + let plan = Arc::new( + SortExec::new( + LexOrdering::new(vec![PhysicalSortExpr::new( + col("b", &schema()).unwrap(), + SortOptions::new(true, false), + )]) + .unwrap(), + coalesce, + ) + .with_fetch(Some(1)), + ) as Arc; + + // Test optimization - the filter SHOULD pass through CoalescePartitionsExec + // if it properly implements from_children (not all_unsupported) + insta::assert_snapshot!( + OptimizationTest::new(Arc::clone(&plan), FilterPushdown::new_post_optimization(), true), + @r" + OptimizationTest: + input: + - SortExec: TopK(fetch=1), expr=[b@1 DESC NULLS LAST], preserve_partitioning=[false] + - CoalescePartitionsExec + - DataSourceExec: file_groups={2 groups: [[test1.parquet], [test2.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - SortExec: TopK(fetch=1), expr=[b@1 DESC NULLS LAST], preserve_partitioning=[false] + - CoalescePartitionsExec + - DataSourceExec: file_groups={2 groups: [[test1.parquet], [test2.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ] + " + ); +} + +#[tokio::test] +async fn test_hashjoin_dynamic_filter_pushdown() { + use datafusion_common::JoinType; + use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; + + // Create build side with limited values + let build_batches = vec![ + record_batch!( + ("a", Utf8, ["aa", "ab"]), + ("b", Utf8, ["ba", "bb"]), + ("c", Float64, [1.0, 2.0]) // Extra column not used in join + 
) + .unwrap(), + ]; + let build_side_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, false), + Field::new("c", DataType::Float64, false), + ])); + let build_scan = TestScanBuilder::new(Arc::clone(&build_side_schema)) + .with_support(true) + .with_batches(build_batches) + .build(); + + // Create probe side with more values + let probe_batches = vec![ + record_batch!( + ("a", Utf8, ["aa", "ab", "ac", "ad"]), + ("b", Utf8, ["ba", "bb", "bc", "bd"]), + ("e", Float64, [1.0, 2.0, 3.0, 4.0]) // Extra column not used in join + ) + .unwrap(), + ]; + let probe_side_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, false), + Field::new("e", DataType::Float64, false), + ])); + let probe_scan = TestScanBuilder::new(Arc::clone(&probe_side_schema)) + .with_support(true) + .with_batches(probe_batches) + .build(); + + // Create HashJoinExec with dynamic filter + let on = vec![ + ( + col("a", &build_side_schema).unwrap(), + col("a", &probe_side_schema).unwrap(), + ), + ( + col("b", &build_side_schema).unwrap(), + col("b", &probe_side_schema).unwrap(), + ), + ]; + let plan = Arc::new( + HashJoinExec::try_new( + build_scan, + probe_scan, + on, + None, + &JoinType::Inner, + None, + PartitionMode::CollectLeft, + datafusion_common::NullEquality::NullEqualsNothing, + false, + ) + .unwrap(), + ) as Arc; + + // expect the predicate to be pushed down into the probe side DataSource + insta::assert_snapshot!( + OptimizationTest::new(Arc::clone(&plan), FilterPushdown::new_post_optimization(), true), + @r" + OptimizationTest: + input: + - HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true + output: + Ok: 
+ - HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ] + ", + ); + + // Actually apply the optimization to the plan and execute to see the filter in action + let mut config = ConfigOptions::default(); + config.execution.parquet.pushdown_filters = true; + config.optimizer.enable_dynamic_filter_pushdown = true; + let plan = FilterPushdown::new_post_optimization() + .optimize(plan, &config) + .unwrap(); + + // Test for https://github.com/apache/datafusion/pull/17371: dynamic filter linking survives `with_new_children` + let children = plan.children().into_iter().map(Arc::clone).collect(); + let plan = plan.with_new_children(children).unwrap(); + + let config = SessionConfig::new().with_batch_size(10); + let session_ctx = SessionContext::new_with_config(config); + session_ctx.register_object_store( + ObjectStoreUrl::parse("test://").unwrap().as_ref(), + Arc::new(InMemory::new()), + ); + let state = session_ctx.state(); + let task_ctx = state.task_ctx(); + let mut stream = plan.execute(0, Arc::clone(&task_ctx)).unwrap(); + // Iterate one batch + stream.next().await.unwrap().unwrap(); + + // Now check what our filter looks like + insta::assert_snapshot!( + format!("{}", format_plan_for_test(&plan)), + @r" + - HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 >= aa AND a@0 <= ab AND b@1 >= ba AND b@1 <= bb AND struct(a@0, b@1) IN (SET) ([{c0:aa,c1:ba}, {c0:ab,c1:bb}]) ] + 
" + ); +} + +#[tokio::test] +async fn test_hashjoin_dynamic_filter_pushdown_partitioned() { + use datafusion_common::JoinType; + use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; + + // Rough sketch of the MRE we're trying to recreate: + // COPY (select i as k from generate_series(1, 10000000) as t(i)) + // TO 'test_files/scratch/push_down_filter/t1.parquet' + // STORED AS PARQUET; + // COPY (select i as k, i as v from generate_series(1, 10000000) as t(i)) + // TO 'test_files/scratch/push_down_filter/t2.parquet' + // STORED AS PARQUET; + // create external table t1 stored as parquet location 'test_files/scratch/push_down_filter/t1.parquet'; + // create external table t2 stored as parquet location 'test_files/scratch/push_down_filter/t2.parquet'; + // explain + // select * + // from t1 + // join t2 on t1.k = t2.k; + // +---------------+------------------------------------------------------------+ + // | plan_type | plan | + // +---------------+------------------------------------------------------------+ + // | physical_plan | ┌───────────────────────────┐ | + // | | │ HashJoinExec │ | + // | | │ -------------------- ├──────────────┐ | + // | | │ on: (k = k) │ │ | + // | | └─────────────┬─────────────┘ │ | + // | | ┌─────────────┴─────────────┐┌─────────────┴─────────────┐ | + // | | │ RepartitionExec ││ RepartitionExec │ | + // | | │ -------------------- ││ -------------------- │ | + // | | │ partition_count(in->out): ││ partition_count(in->out): │ | + // | | │ 12 -> 12 ││ 12 -> 12 │ | + // | | │ ││ │ | + // | | │ partitioning_scheme: ││ partitioning_scheme: │ | + // | | │ Hash([k@0], 12) ││ Hash([k@0], 12) │ | + // | | └─────────────┬─────────────┘└─────────────┬─────────────┘ | + // | | ┌─────────────┴─────────────┐┌─────────────┴─────────────┐ | + // | | │ DataSourceExec ││ DataSourceExec │ | + // | | │ -------------------- ││ -------------------- │ | + // | | │ files: 12 ││ files: 12 │ | + // | | │ format: parquet ││ format: parquet │ | + // | 
| │ ││ predicate: true │ | + // | | └───────────────────────────┘└───────────────────────────┘ | + // | | | + // +---------------+------------------------------------------------------------+ + + // Create build side with limited values + let build_batches = vec![ + record_batch!( + ("a", Utf8, ["aa", "ab"]), + ("b", Utf8, ["ba", "bb"]), + ("c", Float64, [1.0, 2.0]) // Extra column not used in join + ) + .unwrap(), + ]; + let build_side_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, false), + Field::new("c", DataType::Float64, false), + ])); + let build_scan = TestScanBuilder::new(Arc::clone(&build_side_schema)) + .with_support(true) + .with_batches(build_batches) + .build(); + + // Create probe side with more values + let probe_batches = vec![ + record_batch!( + ("a", Utf8, ["aa", "ab", "ac", "ad"]), + ("b", Utf8, ["ba", "bb", "bc", "bd"]), + ("e", Float64, [1.0, 2.0, 3.0, 4.0]) // Extra column not used in join + ) + .unwrap(), + ]; + let probe_side_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, false), + Field::new("e", DataType::Float64, false), + ])); + let probe_scan = TestScanBuilder::new(Arc::clone(&probe_side_schema)) + .with_support(true) + .with_batches(probe_batches) + .build(); + + // Create RepartitionExec nodes for both sides with hash partitioning on join keys + let partition_count = 12; + + // Build side: DataSource -> RepartitionExec (Hash) + let build_hash_exprs = vec![ + col("a", &build_side_schema).unwrap(), + col("b", &build_side_schema).unwrap(), + ]; + let build_repartition = Arc::new( + RepartitionExec::try_new( + build_scan, + Partitioning::Hash(build_hash_exprs, partition_count), + ) + .unwrap(), + ); + + // Probe side: DataSource -> RepartitionExec (Hash) + let probe_hash_exprs = vec![ + col("a", &probe_side_schema).unwrap(), + col("b", &probe_side_schema).unwrap(), + ]; + let probe_repartition = Arc::new( + 
RepartitionExec::try_new( + Arc::clone(&probe_scan), + Partitioning::Hash(probe_hash_exprs, partition_count), + ) + .unwrap(), + ); + + // Create HashJoinExec with partitioned inputs + let on = vec![ + ( + col("a", &build_side_schema).unwrap(), + col("a", &probe_side_schema).unwrap(), + ), + ( + col("b", &build_side_schema).unwrap(), + col("b", &probe_side_schema).unwrap(), + ), + ]; + let hash_join = Arc::new( + HashJoinExec::try_new( + build_repartition, + probe_repartition, + on, + None, + &JoinType::Inner, + None, + PartitionMode::Partitioned, + datafusion_common::NullEquality::NullEqualsNothing, + false, + ) + .unwrap(), + ); + + // Top-level CoalescePartitionsExec + let cp = Arc::new(CoalescePartitionsExec::new(hash_join)) as Arc; + // Add a sort for deterministic output + let plan = Arc::new(SortExec::new( + LexOrdering::new(vec![PhysicalSortExpr::new( + col("a", &probe_side_schema).unwrap(), + SortOptions::new(true, false), // descending, nulls_first + )]) + .unwrap(), + cp, + )) as Arc; + + // expect the predicate to be pushed down into the probe side DataSource + insta::assert_snapshot!( + OptimizationTest::new(Arc::clone(&plan), FilterPushdown::new_post_optimization(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - CoalescePartitionsExec + - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] + - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true + output: + Ok: + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - CoalescePartitionsExec + - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, 
a@0), (b@1, b@1)] + - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ] + " + ); + + // Actually apply the optimization to the plan and execute to see the filter in action + let mut config = ConfigOptions::default(); + config.execution.parquet.pushdown_filters = true; + config.optimizer.enable_dynamic_filter_pushdown = true; + let plan = FilterPushdown::new_post_optimization() + .optimize(plan, &config) + .unwrap(); + let config = SessionConfig::new().with_batch_size(10); + let session_ctx = SessionContext::new_with_config(config); + session_ctx.register_object_store( + ObjectStoreUrl::parse("test://").unwrap().as_ref(), + Arc::new(InMemory::new()), + ); + let state = session_ctx.state(); + let task_ctx = state.task_ctx(); + let batches = collect(Arc::clone(&plan), Arc::clone(&task_ctx)) + .await + .unwrap(); + + // Now check what our filter looks like + #[cfg(not(feature = "force_hash_collisions"))] + insta::assert_snapshot!( + format!("{}", format_plan_for_test(&plan)), + @r" + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - CoalescePartitionsExec + - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] + - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ CASE hash_repartition % 12 
WHEN 5 THEN a@0 >= ab AND a@0 <= ab AND b@1 >= bb AND b@1 <= bb AND struct(a@0, b@1) IN (SET) ([{c0:ab,c1:bb}]) WHEN 8 THEN a@0 >= aa AND a@0 <= aa AND b@1 >= ba AND b@1 <= ba AND struct(a@0, b@1) IN (SET) ([{c0:aa,c1:ba}]) ELSE false END ] + " + ); + + // When hash collisions force all data into a single partition, we optimize away the CASE expression. + // This avoids calling create_hashes() for every row on the probe side, since hash % 1 == 0 always, + // meaning the WHEN 0 branch would always match. This optimization is also important for primary key + // joins or any scenario where all build-side data naturally lands in one partition. + #[cfg(feature = "force_hash_collisions")] + insta::assert_snapshot!( + format!("{}", format_plan_for_test(&plan)), + @r" + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - CoalescePartitionsExec + - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] + - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 >= aa AND a@0 <= ab AND b@1 >= ba AND b@1 <= bb AND struct(a@0, b@1) IN (SET) ([{c0:aa,c1:ba}, {c0:ab,c1:bb}]) ] + " + ); + + let result = format!("{}", pretty_format_batches(&batches).unwrap()); + + let probe_scan_metrics = probe_scan.metrics().unwrap(); + + // The probe side had 4 rows, but after applying the dynamic filter only 2 rows should remain. + // The number of output rows from the probe side scan should stay consistent across executions. 
+ // Issue: https://github.com/apache/datafusion/issues/17451 + assert_eq!(probe_scan_metrics.output_rows().unwrap(), 2); + + insta::assert_snapshot!( + result, + @r" + +----+----+-----+----+----+-----+ + | a | b | c | a | b | e | + +----+----+-----+----+----+-----+ + | ab | bb | 2.0 | ab | bb | 2.0 | + | aa | ba | 1.0 | aa | ba | 1.0 | + +----+----+-----+----+----+-----+ + ", + ); +} + +#[tokio::test] +async fn test_hashjoin_dynamic_filter_pushdown_collect_left() { + use datafusion_common::JoinType; + use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; + + let build_batches = vec![ + record_batch!( + ("a", Utf8, ["aa", "ab"]), + ("b", Utf8, ["ba", "bb"]), + ("c", Float64, [1.0, 2.0]) // Extra column not used in join + ) + .unwrap(), + ]; + let build_side_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, false), + Field::new("c", DataType::Float64, false), + ])); + let build_scan = TestScanBuilder::new(Arc::clone(&build_side_schema)) + .with_support(true) + .with_batches(build_batches) + .build(); + + // Create probe side with more values + let probe_batches = vec![ + record_batch!( + ("a", Utf8, ["aa", "ab", "ac", "ad"]), + ("b", Utf8, ["ba", "bb", "bc", "bd"]), + ("e", Float64, [1.0, 2.0, 3.0, 4.0]) // Extra column not used in join + ) + .unwrap(), + ]; + let probe_side_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, false), + Field::new("e", DataType::Float64, false), + ])); + let probe_scan = TestScanBuilder::new(Arc::clone(&probe_side_schema)) + .with_support(true) + .with_batches(probe_batches) + .build(); + + // Create RepartitionExec nodes for both sides with hash partitioning on join keys + let partition_count = 12; + + // Probe side: DataSource -> RepartitionExec(Hash) + let probe_hash_exprs = vec![ + col("a", &probe_side_schema).unwrap(), + col("b", &probe_side_schema).unwrap(), + ]; + let probe_repartition = 
Arc::new( + RepartitionExec::try_new( + Arc::clone(&probe_scan), + Partitioning::Hash(probe_hash_exprs, partition_count), // create multi partitions on probSide + ) + .unwrap(), + ); + + let on = vec![ + ( + col("a", &build_side_schema).unwrap(), + col("a", &probe_side_schema).unwrap(), + ), + ( + col("b", &build_side_schema).unwrap(), + col("b", &probe_side_schema).unwrap(), + ), + ]; + let hash_join = Arc::new( + HashJoinExec::try_new( + build_scan, + probe_repartition, + on, + None, + &JoinType::Inner, + None, + PartitionMode::CollectLeft, + datafusion_common::NullEquality::NullEqualsNothing, + false, + ) + .unwrap(), + ); + + // Top-level CoalescePartitionsExec + let cp = Arc::new(CoalescePartitionsExec::new(hash_join)) as Arc; + // Add a sort for deterministic output + let plan = Arc::new(SortExec::new( + LexOrdering::new(vec![PhysicalSortExpr::new( + col("a", &probe_side_schema).unwrap(), + SortOptions::new(true, false), // descending, nulls_first + )]) + .unwrap(), + cp, + )) as Arc; + + // expect the predicate to be pushed down into the probe side DataSource + insta::assert_snapshot!( + OptimizationTest::new(Arc::clone(&plan), FilterPushdown::new_post_optimization(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - CoalescePartitionsExec + - HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true + output: + Ok: + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - CoalescePartitionsExec + - HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, 
projection=[a, b, c], file_type=test, pushdown_supported=true + - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ] + " + ); + + // Actually apply the optimization to the plan and execute to see the filter in action + let mut config = ConfigOptions::default(); + config.execution.parquet.pushdown_filters = true; + config.optimizer.enable_dynamic_filter_pushdown = true; + let plan = FilterPushdown::new_post_optimization() + .optimize(plan, &config) + .unwrap(); + let config = SessionConfig::new().with_batch_size(10); + let session_ctx = SessionContext::new_with_config(config); + session_ctx.register_object_store( + ObjectStoreUrl::parse("test://").unwrap().as_ref(), + Arc::new(InMemory::new()), + ); + let state = session_ctx.state(); + let task_ctx = state.task_ctx(); + let batches = collect(Arc::clone(&plan), Arc::clone(&task_ctx)) + .await + .unwrap(); + + // Now check what our filter looks like + insta::assert_snapshot!( + format!("{}", format_plan_for_test(&plan)), + @r" + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - CoalescePartitionsExec + - HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 >= aa AND a@0 <= ab AND b@1 >= ba AND b@1 <= bb AND struct(a@0, b@1) IN (SET) ([{c0:aa,c1:ba}, {c0:ab,c1:bb}]) ] + " + ); + + let result = format!("{}", pretty_format_batches(&batches).unwrap()); + + let probe_scan_metrics = probe_scan.metrics().unwrap(); + + // The probe side had 4 rows, but after 
applying the dynamic filter only 2 rows should remain. + // The number of output rows from the probe side scan should stay consistent across executions. + // Issue: https://github.com/apache/datafusion/issues/17451 + assert_eq!(probe_scan_metrics.output_rows().unwrap(), 2); + + insta::assert_snapshot!( + result, + @r" + +----+----+-----+----+----+-----+ + | a | b | c | a | b | e | + +----+----+-----+----+----+-----+ + | ab | bb | 2.0 | ab | bb | 2.0 | + | aa | ba | 1.0 | aa | ba | 1.0 | + +----+----+-----+----+----+-----+ + ", + ); +} + +#[tokio::test] +async fn test_nested_hashjoin_dynamic_filter_pushdown() { + use datafusion_common::JoinType; + use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; + + // Create test data for three tables: t1, t2, t3 + // t1: small table with limited values (will be build side of outer join) + let t1_batches = vec![ + record_batch!(("a", Utf8, ["aa", "ab"]), ("x", Float64, [1.0, 2.0])).unwrap(), + ]; + let t1_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("x", DataType::Float64, false), + ])); + let t1_scan = TestScanBuilder::new(Arc::clone(&t1_schema)) + .with_support(true) + .with_batches(t1_batches) + .build(); + + // t2: larger table (will be probe side of inner join, build side of outer join) + let t2_batches = vec![ + record_batch!( + ("b", Utf8, ["aa", "ab", "ac", "ad", "ae"]), + ("c", Utf8, ["ca", "cb", "cc", "cd", "ce"]), + ("y", Float64, [1.0, 2.0, 3.0, 4.0, 5.0]) + ) + .unwrap(), + ]; + let t2_schema = Arc::new(Schema::new(vec![ + Field::new("b", DataType::Utf8, false), + Field::new("c", DataType::Utf8, false), + Field::new("y", DataType::Float64, false), + ])); + let t2_scan = TestScanBuilder::new(Arc::clone(&t2_schema)) + .with_support(true) + .with_batches(t2_batches) + .build(); + + // t3: largest table (will be probe side of inner join) + let t3_batches = vec![ + record_batch!( + ("d", Utf8, ["ca", "cb", "cc", "cd", "ce", "cf", "cg", "ch"]), + ("z", Float64, 
[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) + ) + .unwrap(), + ]; + let t3_schema = Arc::new(Schema::new(vec![ + Field::new("d", DataType::Utf8, false), + Field::new("z", DataType::Float64, false), + ])); + let t3_scan = TestScanBuilder::new(Arc::clone(&t3_schema)) + .with_support(true) + .with_batches(t3_batches) + .build(); + + // Create nested join structure: + // Join (t1.a = t2.b) + // / \ + // t1 Join(t2.c = t3.d) + // / \ + // t2 t3 + + // First create inner join: t2.c = t3.d + let inner_join_on = + vec![(col("c", &t2_schema).unwrap(), col("d", &t3_schema).unwrap())]; + let inner_join = Arc::new( + HashJoinExec::try_new( + t2_scan, + t3_scan, + inner_join_on, + None, + &JoinType::Inner, + None, + PartitionMode::Partitioned, + datafusion_common::NullEquality::NullEqualsNothing, + false, + ) + .unwrap(), + ); + + // Then create outer join: t1.a = t2.b (from inner join result) + let outer_join_on = vec![( + col("a", &t1_schema).unwrap(), + col("b", &inner_join.schema()).unwrap(), + )]; + let outer_join = Arc::new( + HashJoinExec::try_new( + t1_scan, + inner_join as Arc, + outer_join_on, + None, + &JoinType::Inner, + None, + PartitionMode::Partitioned, + datafusion_common::NullEquality::NullEqualsNothing, + false, + ) + .unwrap(), + ) as Arc; + + // Test that dynamic filters are pushed down correctly through nested joins + insta::assert_snapshot!( + OptimizationTest::new(Arc::clone(&outer_join), FilterPushdown::new_post_optimization(), true), + @r" + OptimizationTest: + input: + - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, b@0)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, x], file_type=test, pushdown_supported=true + - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@1, d@0)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[b, c, y], file_type=test, pushdown_supported=true + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, z], file_type=test, 
pushdown_supported=true + output: + Ok: + - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, b@0)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, x], file_type=test, pushdown_supported=true + - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@1, d@0)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[b, c, y], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, z], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ] + ", + ); + + // Execute the plan to verify the dynamic filters are properly updated + let mut config = ConfigOptions::default(); + config.execution.parquet.pushdown_filters = true; + config.optimizer.enable_dynamic_filter_pushdown = true; + let plan = FilterPushdown::new_post_optimization() + .optimize(outer_join, &config) + .unwrap(); + let config = SessionConfig::new().with_batch_size(10); + let session_ctx = SessionContext::new_with_config(config); + session_ctx.register_object_store( + ObjectStoreUrl::parse("test://").unwrap().as_ref(), + Arc::new(InMemory::new()), + ); + let state = session_ctx.state(); + let task_ctx = state.task_ctx(); + let mut stream = plan.execute(0, Arc::clone(&task_ctx)).unwrap(); + // Execute to populate the dynamic filters + stream.next().await.unwrap().unwrap(); + + // Verify that both the inner and outer join have updated dynamic filters + insta::assert_snapshot!( + format!("{}", format_plan_for_test(&plan)), + @r" + - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, b@0)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, x], file_type=test, pushdown_supported=true + - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@1, d@0)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[b, c, y], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ b@0 >= aa AND b@0 
<= ab AND b@0 IN (SET) ([aa, ab]) ] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, z], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ d@0 >= ca AND d@0 <= cb AND d@0 IN (SET) ([ca, cb]) ] + " + ); +} + +#[tokio::test] +async fn test_hashjoin_parent_filter_pushdown() { + use datafusion_common::JoinType; + use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; + + // Create build side with limited values + let build_batches = vec![ + record_batch!( + ("a", Utf8, ["aa", "ab"]), + ("b", Utf8, ["ba", "bb"]), + ("c", Float64, [1.0, 2.0]) + ) + .unwrap(), + ]; + let build_side_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, false), + Field::new("c", DataType::Float64, false), + ])); + let build_scan = TestScanBuilder::new(Arc::clone(&build_side_schema)) + .with_support(true) + .with_batches(build_batches) + .build(); + + // Create probe side with more values + let probe_batches = vec![ + record_batch!( + ("d", Utf8, ["aa", "ab", "ac", "ad"]), + ("e", Utf8, ["ba", "bb", "bc", "bd"]), + ("f", Float64, [1.0, 2.0, 3.0, 4.0]) + ) + .unwrap(), + ]; + let probe_side_schema = Arc::new(Schema::new(vec![ + Field::new("d", DataType::Utf8, false), + Field::new("e", DataType::Utf8, false), + Field::new("f", DataType::Float64, false), + ])); + let probe_scan = TestScanBuilder::new(Arc::clone(&probe_side_schema)) + .with_support(true) + .with_batches(probe_batches) + .build(); + + // Create HashJoinExec + let on = vec![( + col("a", &build_side_schema).unwrap(), + col("d", &probe_side_schema).unwrap(), + )]; + let join = Arc::new( + HashJoinExec::try_new( + build_scan, + probe_scan, + on, + None, + &JoinType::Inner, + None, + PartitionMode::Partitioned, + datafusion_common::NullEquality::NullEqualsNothing, + false, + ) + .unwrap(), + ); + + // Create filters that can be pushed down to different sides + // We need to create filters in the context of the join output 
schema + let join_schema = join.schema(); + + // Filter on build side column: a = 'aa' + let left_filter = col_lit_predicate("a", "aa", &join_schema); + // Filter on probe side column: e = 'ba' + let right_filter = col_lit_predicate("e", "ba", &join_schema); + // Filter that references both sides: a = d (should not be pushed down) + let cross_filter = Arc::new(BinaryExpr::new( + col("a", &join_schema).unwrap(), + Operator::Eq, + col("d", &join_schema).unwrap(), + )) as Arc; + + let filter = + Arc::new(FilterExec::try_new(left_filter, Arc::clone(&join) as _).unwrap()); + let filter = Arc::new(FilterExec::try_new(right_filter, filter).unwrap()); + let plan = Arc::new(FilterExec::try_new(cross_filter, filter).unwrap()) + as Arc; + + // Test that filters are pushed down correctly to each side of the join + insta::assert_snapshot!( + OptimizationTest::new(Arc::clone(&plan), FilterPushdown::new(), true), + @r" + OptimizationTest: + input: + - FilterExec: a@0 = d@3 + - FilterExec: e@4 = ba + - FilterExec: a@0 = aa + - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, d@0)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, e, f], file_type=test, pushdown_supported=true + output: + Ok: + - FilterExec: a@0 = d@3 + - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, d@0)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = aa + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, e, f], file_type=test, pushdown_supported=true, predicate=e@1 = ba + " + ); +} + +#[test] +fn test_hashjoin_parent_filter_pushdown_same_column_names() { + use datafusion_common::JoinType; + use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; + + let build_side_schema = Arc::new(Schema::new(vec![ + Field::new("id", 
DataType::Utf8, false), + Field::new("build_val", DataType::Utf8, false), + ])); + let build_scan = TestScanBuilder::new(Arc::clone(&build_side_schema)) + .with_support(true) + .build(); + + let probe_side_schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("probe_val", DataType::Utf8, false), + ])); + let probe_scan = TestScanBuilder::new(Arc::clone(&probe_side_schema)) + .with_support(true) + .build(); + + let on = vec![( + col("id", &build_side_schema).unwrap(), + col("id", &probe_side_schema).unwrap(), + )]; + let join = Arc::new( + HashJoinExec::try_new( + build_scan, + probe_scan, + on, + None, + &JoinType::Inner, + None, + PartitionMode::Partitioned, + datafusion_common::NullEquality::NullEqualsNothing, + false, + ) + .unwrap(), + ); + + let join_schema = join.schema(); + + let build_id_filter = col_lit_predicate("id", "aa", &join_schema); + let probe_val_filter = col_lit_predicate("probe_val", "x", &join_schema); + + let filter = + Arc::new(FilterExec::try_new(build_id_filter, Arc::clone(&join) as _).unwrap()); + let plan = Arc::new(FilterExec::try_new(probe_val_filter, filter).unwrap()) + as Arc; + + insta::assert_snapshot!( + OptimizationTest::new(Arc::clone(&plan), FilterPushdown::new(), true), + @r" + OptimizationTest: + input: + - FilterExec: probe_val@3 = x + - FilterExec: id@0 = aa + - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(id@0, id@0)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[id, build_val], file_type=test, pushdown_supported=true + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[id, probe_val], file_type=test, pushdown_supported=true + output: + Ok: + - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(id@0, id@0)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[id, build_val], file_type=test, pushdown_supported=true, predicate=id@0 = aa + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, 
projection=[id, probe_val], file_type=test, pushdown_supported=true, predicate=probe_val@1 = x + " + ); +} + +#[test] +fn test_hashjoin_parent_filter_pushdown_mark_join() { + use datafusion_common::JoinType; + use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; + + let left_schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("val", DataType::Utf8, false), + ])); + let left_scan = TestScanBuilder::new(Arc::clone(&left_schema)) + .with_support(true) + .build(); + + let right_schema = + Arc::new(Schema::new(vec![Field::new("id", DataType::Utf8, false)])); + let right_scan = TestScanBuilder::new(Arc::clone(&right_schema)) + .with_support(true) + .build(); + + let on = vec![( + col("id", &left_schema).unwrap(), + col("id", &right_schema).unwrap(), + )]; + let join = Arc::new( + HashJoinExec::try_new( + left_scan, + right_scan, + on, + None, + &JoinType::LeftMark, + None, + PartitionMode::Partitioned, + datafusion_common::NullEquality::NullEqualsNothing, + false, + ) + .unwrap(), + ); + + let join_schema = join.schema(); + + let left_filter = col_lit_predicate("val", "x", &join_schema); + let mark_filter = col_lit_predicate("mark", true, &join_schema); + + let filter = + Arc::new(FilterExec::try_new(left_filter, Arc::clone(&join) as _).unwrap()); + let plan = Arc::new(FilterExec::try_new(mark_filter, filter).unwrap()) + as Arc; + + insta::assert_snapshot!( + OptimizationTest::new(Arc::clone(&plan), FilterPushdown::new(), true), + @r" + OptimizationTest: + input: + - FilterExec: mark@2 = true + - FilterExec: val@1 = x + - HashJoinExec: mode=Partitioned, join_type=LeftMark, on=[(id@0, id@0)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[id, val], file_type=test, pushdown_supported=true + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[id], file_type=test, pushdown_supported=true + output: + Ok: + - FilterExec: mark@2 = true + - HashJoinExec: mode=Partitioned, 
join_type=LeftMark, on=[(id@0, id@0)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[id, val], file_type=test, pushdown_supported=true, predicate=val@1 = x + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[id], file_type=test, pushdown_supported=true + " + ); +} + +/// Test that filters on join key columns are pushed to both sides of semi/anti joins. +/// For LeftSemi/LeftAnti, the output only contains left columns, but filters on +/// join key columns can also be pushed to the right (non-preserved) side because +/// the equijoin condition guarantees the key values match. +#[test] +fn test_hashjoin_parent_filter_pushdown_semi_anti_join() { + use datafusion_common::JoinType; + use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; + + let left_schema = Arc::new(Schema::new(vec![ + Field::new("k", DataType::Utf8, false), + Field::new("v", DataType::Utf8, false), + ])); + let left_scan = TestScanBuilder::new(Arc::clone(&left_schema)) + .with_support(true) + .build(); + + let right_schema = Arc::new(Schema::new(vec![ + Field::new("k", DataType::Utf8, false), + Field::new("w", DataType::Utf8, false), + ])); + let right_scan = TestScanBuilder::new(Arc::clone(&right_schema)) + .with_support(true) + .build(); + + let on = vec![( + col("k", &left_schema).unwrap(), + col("k", &right_schema).unwrap(), + )]; + + let join = Arc::new( + HashJoinExec::try_new( + left_scan, + right_scan, + on, + None, + &JoinType::LeftSemi, + None, + PartitionMode::Partitioned, + datafusion_common::NullEquality::NullEqualsNothing, + false, + ) + .unwrap(), + ); + + let join_schema = join.schema(); + // Filter on join key column: k = 'x' — should be pushed to BOTH sides + let key_filter = col_lit_predicate("k", "x", &join_schema); + // Filter on non-key column: v = 'y' — should only be pushed to the left side + let val_filter = col_lit_predicate("v", "y", &join_schema); + + let filter = + Arc::new(FilterExec::try_new(key_filter, 
Arc::clone(&join) as _).unwrap()); + let plan = Arc::new(FilterExec::try_new(val_filter, filter).unwrap()) + as Arc; + + insta::assert_snapshot!( + OptimizationTest::new(Arc::clone(&plan), FilterPushdown::new(), true), + @r" + OptimizationTest: + input: + - FilterExec: v@1 = y + - FilterExec: k@0 = x + - HashJoinExec: mode=Partitioned, join_type=LeftSemi, on=[(k@0, k@0)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[k, v], file_type=test, pushdown_supported=true + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[k, w], file_type=test, pushdown_supported=true + output: + Ok: + - HashJoinExec: mode=Partitioned, join_type=LeftSemi, on=[(k@0, k@0)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[k, v], file_type=test, pushdown_supported=true, predicate=k@0 = x AND v@1 = y + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[k, w], file_type=test, pushdown_supported=true, predicate=k@0 = x + " + ); +} + +/// Integration test for dynamic filter pushdown with TopK. +/// We use an integration test because there are complex interactions in the optimizer rules +/// that the unit tests applying a single optimizer rule do not cover. 
+#[tokio::test] +async fn test_topk_dynamic_filter_pushdown_integration() { + let store = Arc::new(InMemory::new()) as Arc; + let mut cfg = SessionConfig::new(); + cfg.options_mut().execution.parquet.pushdown_filters = true; + cfg.options_mut().execution.parquet.max_row_group_size = 128; + let ctx = SessionContext::new_with_config(cfg); + ctx.register_object_store( + ObjectStoreUrl::parse("memory://").unwrap().as_ref(), + Arc::clone(&store), + ); + ctx.sql( + r" +COPY ( + SELECT 1372708800 + value AS t + FROM generate_series(0, 99999) + ORDER BY t + ) TO 'memory:///1.parquet' +STORED AS PARQUET; + ", + ) + .await + .unwrap() + .collect() + .await + .unwrap(); + + // Register the file with the context + ctx.register_parquet( + "topk_pushdown", + "memory:///1.parquet", + ParquetReadOptions::default(), + ) + .await + .unwrap(); + + // Create a TopK query that will use dynamic filter pushdown + // Note that we use t * t as the order by expression to avoid + // the order pushdown optimizer from optimizing away the TopK. 
+ let df = ctx + .sql(r"EXPLAIN ANALYZE SELECT t FROM topk_pushdown ORDER BY t * t LIMIT 10;") + .await + .unwrap(); + let batches = df.collect().await.unwrap(); + let explain = format!("{}", pretty_format_batches(&batches).unwrap()); + + assert!(explain.contains("output_rows=128")); // Read 1 row group + assert!(explain.contains("t@0 < 1884329474306198481")); // Dynamic filter was applied + assert!( + explain.contains("pushdown_rows_matched=128, pushdown_rows_pruned=99.87 K"), + "{explain}" + ); + // Pushdown pruned most rows +} + +#[test] +fn test_filter_pushdown_through_union() { + let scan1 = TestScanBuilder::new(schema()).with_support(true).build(); + let scan2 = TestScanBuilder::new(schema()).with_support(true).build(); + + let union = UnionExec::try_new(vec![scan1, scan2]).unwrap(); + + let predicate = col_lit_predicate("a", "foo", &schema()); + let plan = Arc::new(FilterExec::try_new(predicate, union).unwrap()); + + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown::new(), true), + @r" + OptimizationTest: + input: + - FilterExec: a@0 = foo + - UnionExec + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - UnionExec + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo + " + ); +} + +#[test] +fn test_filter_pushdown_through_union_mixed_support() { + // Test case where one child supports filter pushdown and one doesn't + let scan1 = TestScanBuilder::new(schema()).with_support(true).build(); + let scan2 = TestScanBuilder::new(schema()).with_support(false).build(); + + let union = 
UnionExec::try_new(vec![scan1, scan2]).unwrap(); + + let predicate = col_lit_predicate("a", "foo", &schema()); + let plan = Arc::new(FilterExec::try_new(predicate, union).unwrap()); + + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown::new(), true), + @r" + OptimizationTest: + input: + - FilterExec: a@0 = foo + - UnionExec + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false + output: + Ok: + - UnionExec + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo + - FilterExec: a@0 = foo + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false + " + ); +} + +#[test] +fn test_filter_pushdown_through_union_does_not_support() { + // Test case where neither child supports filter pushdown + let scan1 = TestScanBuilder::new(schema()).with_support(false).build(); + let scan2 = TestScanBuilder::new(schema()).with_support(false).build(); + + let union = UnionExec::try_new(vec![scan1, scan2]).unwrap(); + + let predicate = col_lit_predicate("a", "foo", &schema()); + let plan = Arc::new(FilterExec::try_new(predicate, union).unwrap()); + + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown::new(), true), + @" + OptimizationTest: + input: + - FilterExec: a@0 = foo + - UnionExec + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false + output: + Ok: + - UnionExec + - FilterExec: a@0 = foo + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, 
pushdown_supported=false + - FilterExec: a@0 = foo + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false + " + ); +} + +/// Schema: +/// a: String +/// b: String +/// c: f64 +static TEST_SCHEMA: LazyLock = LazyLock::new(|| { + let fields = vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, false), + Field::new("c", DataType::Float64, false), + ]; + Arc::new(Schema::new(fields)) +}); + +fn schema() -> SchemaRef { + Arc::clone(&TEST_SCHEMA) +} + +struct ProjectionDynFilterTestCase { + schema: SchemaRef, + batches: Vec, + projection: Vec<(Arc, String)>, + sort_expr: PhysicalSortExpr, + expected_plans: Vec, +} + +async fn run_projection_dyn_filter_case(case: ProjectionDynFilterTestCase) { + let ProjectionDynFilterTestCase { + schema, + batches, + projection, + sort_expr, + expected_plans, + } = case; + + let scan = TestScanBuilder::new(Arc::clone(&schema)) + .with_support(true) + .with_batches(batches) + .build(); + + let projection_exec = Arc::new(ProjectionExec::try_new(projection, scan).unwrap()); + + let sort = Arc::new( + SortExec::new(LexOrdering::new(vec![sort_expr]).unwrap(), projection_exec) + .with_fetch(Some(2)), + ) as Arc; + + let mut config = ConfigOptions::default(); + config.execution.parquet.pushdown_filters = true; + config.optimizer.enable_dynamic_filter_pushdown = true; + + let optimized_plan = FilterPushdown::new_post_optimization() + .optimize(Arc::clone(&sort), &config) + .unwrap(); + + pretty_assertions::assert_eq!( + format_plan_for_test(&optimized_plan).trim(), + expected_plans[0].trim() + ); + + let config = SessionConfig::new().with_batch_size(2); + let session_ctx = SessionContext::new_with_config(config); + session_ctx.register_object_store( + ObjectStoreUrl::parse("test://").unwrap().as_ref(), + Arc::new(InMemory::new()), + ); + let state = session_ctx.state(); + let task_ctx = state.task_ctx(); + let mut stream = optimized_plan.execute(0, 
Arc::clone(&task_ctx)).unwrap(); + for (idx, expected_plan) in expected_plans.iter().enumerate().skip(1) { + stream.next().await.unwrap().unwrap(); + let formatted_plan = format_plan_for_test(&optimized_plan); + pretty_assertions::assert_eq!( + formatted_plan.trim(), + expected_plan.trim(), + "Mismatch at iteration {}", + idx + ); + } +} + +#[tokio::test] +async fn test_topk_with_projection_transformation_on_dyn_filter() { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Utf8, false), + Field::new("c", DataType::Float64, false), + ])); + let simple_abc = vec![ + record_batch!( + ("a", Int32, [1, 2, 3]), + ("b", Utf8, ["x", "y", "z"]), + ("c", Float64, [1.0, 2.0, 3.0]) + ) + .unwrap(), + ]; + + // Case 1: Reordering [b, a] + run_projection_dyn_filter_case(ProjectionDynFilterTestCase { + schema: Arc::clone(&schema), + batches: simple_abc.clone(), + projection: vec![ + (col("b", &schema).unwrap(), "b".to_string()), + (col("a", &schema).unwrap(), "a".to_string()), + ], + sort_expr: PhysicalSortExpr::new( + Arc::new(Column::new("a", 1)), + SortOptions::default(), + ), + expected_plans: vec![ +r#" - SortExec: TopK(fetch=2), expr=[a@1 ASC], preserve_partitioning=[false] + - ProjectionExec: expr=[b@1 as b, a@0 as a] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ]"#.to_string(), +r#" - SortExec: TopK(fetch=2), expr=[a@1 ASC], preserve_partitioning=[false], filter=[a@1 IS NULL OR a@1 < 2] + - ProjectionExec: expr=[b@1 as b, a@0 as a] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 IS NULL OR a@0 < 2 ]"#.to_string()] + }) + .await; + + // Case 2: Pruning [a] + run_projection_dyn_filter_case(ProjectionDynFilterTestCase { + schema: Arc::clone(&schema), + batches: simple_abc.clone(), + projection: 
vec![(col("a", &schema).unwrap(), "a".to_string())], + sort_expr: PhysicalSortExpr::new( + Arc::new(Column::new("a", 0)), + SortOptions::default(), + ), + expected_plans: vec![ + r#" - SortExec: TopK(fetch=2), expr=[a@0 ASC], preserve_partitioning=[false] + - ProjectionExec: expr=[a@0 as a] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ]"#.to_string(), + r#" - SortExec: TopK(fetch=2), expr=[a@0 ASC], preserve_partitioning=[false], filter=[a@0 IS NULL OR a@0 < 2] + - ProjectionExec: expr=[a@0 as a] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 IS NULL OR a@0 < 2 ]"#.to_string(), + ], + }) + .await; + + // Case 3: Identity [a, b] + run_projection_dyn_filter_case(ProjectionDynFilterTestCase { + schema: Arc::clone(&schema), + batches: simple_abc.clone(), + projection: vec![ + (col("a", &schema).unwrap(), "a".to_string()), + (col("b", &schema).unwrap(), "b".to_string()), + ], + sort_expr: PhysicalSortExpr::new( + Arc::new(Column::new("a", 0)), + SortOptions::default(), + ), + expected_plans: vec![ + r#" - SortExec: TopK(fetch=2), expr=[a@0 ASC], preserve_partitioning=[false] + - ProjectionExec: expr=[a@0 as a, b@1 as b] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ]"#.to_string(), + r#" - SortExec: TopK(fetch=2), expr=[a@0 ASC], preserve_partitioning=[false], filter=[a@0 IS NULL OR a@0 < 2] + - ProjectionExec: expr=[a@0 as a, b@1 as b] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 IS NULL OR a@0 < 2 ]"#.to_string(), + ], + }) + .await; + + // Case 4: Expressions [a + 1, b] + run_projection_dyn_filter_case(ProjectionDynFilterTestCase { + 
schema: Arc::clone(&schema), + batches: simple_abc.clone(), + projection: vec![ + ( + Arc::new(BinaryExpr::new( + col("a", &schema).unwrap(), + Operator::Plus, + Arc::new(Literal::new(ScalarValue::Int32(Some(1)))), + )), + "a_plus_1".to_string(), + ), + (col("b", &schema).unwrap(), "b".to_string()), + ], + sort_expr: PhysicalSortExpr::new( + Arc::new(Column::new("a_plus_1", 0)), + SortOptions::default(), + ), + expected_plans: vec![ + r#" - SortExec: TopK(fetch=2), expr=[a_plus_1@0 ASC], preserve_partitioning=[false] + - ProjectionExec: expr=[a@0 + 1 as a_plus_1, b@1 as b] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ]"#.to_string(), + r#" - SortExec: TopK(fetch=2), expr=[a_plus_1@0 ASC], preserve_partitioning=[false], filter=[a_plus_1@0 IS NULL OR a_plus_1@0 < 3] + - ProjectionExec: expr=[a@0 + 1 as a_plus_1, b@1 as b] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 + 1 IS NULL OR a@0 + 1 < 3 ]"#.to_string(), + ], + }) + .await; + + // Case 5: [a as b, b as a] (swapped columns) + run_projection_dyn_filter_case(ProjectionDynFilterTestCase { + schema: Arc::clone(&schema), + batches: simple_abc.clone(), + projection: vec![ + (col("a", &schema).unwrap(), "b".to_string()), + (col("b", &schema).unwrap(), "a".to_string()), + ], + sort_expr: PhysicalSortExpr::new( + Arc::new(Column::new("b", 0)), + SortOptions::default(), + ), + expected_plans: vec![ + r#" - SortExec: TopK(fetch=2), expr=[b@0 ASC], preserve_partitioning=[false] + - ProjectionExec: expr=[a@0 as b, b@1 as a] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ]"#.to_string(), + r#" - SortExec: TopK(fetch=2), expr=[b@0 ASC], preserve_partitioning=[false], filter=[b@0 IS NULL OR b@0 < 2] + - 
ProjectionExec: expr=[a@0 as b, b@1 as a] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 IS NULL OR a@0 < 2 ]"#.to_string(), + ], + }) + .await; + + // Case 6: Confusing expr [a + 1 as a, b] + run_projection_dyn_filter_case(ProjectionDynFilterTestCase { + schema: Arc::clone(&schema), + batches: simple_abc.clone(), + projection: vec![ + ( + Arc::new(BinaryExpr::new( + col("a", &schema).unwrap(), + Operator::Plus, + Arc::new(Literal::new(ScalarValue::Int32(Some(1)))), + )), + "a".to_string(), + ), + (col("b", &schema).unwrap(), "b".to_string()), + ], + sort_expr: PhysicalSortExpr::new( + Arc::new(Column::new("a", 0)), + SortOptions::default(), + ), + expected_plans: vec![ + r#" - SortExec: TopK(fetch=2), expr=[a@0 ASC], preserve_partitioning=[false] + - ProjectionExec: expr=[a@0 + 1 as a, b@1 as b] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ]"#.to_string(), + r#" - SortExec: TopK(fetch=2), expr=[a@0 ASC], preserve_partitioning=[false], filter=[a@0 IS NULL OR a@0 < 3] + - ProjectionExec: expr=[a@0 + 1 as a, b@1 as b] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 + 1 IS NULL OR a@0 + 1 < 3 ]"#.to_string(), + ], + }) + .await; +} + +/// Returns a predicate that is a binary expression col = lit +fn col_lit_predicate( + column_name: &str, + scalar_value: impl Into, + schema: &Schema, +) -> Arc { + let scalar_value = scalar_value.into(); + Arc::new(BinaryExpr::new( + Arc::new(Column::new_with_schema(column_name, schema).unwrap()), + Operator::Eq, + Arc::new(Literal::new(scalar_value)), + )) +} + +// ==== Aggregate Dynamic Filter tests ==== + +// ---- Test Utilities ---- +struct AggregateDynFilterCase<'a> { + schema: SchemaRef, + batches: Vec, + 
aggr_exprs: Vec, + expected_before: Option<&'a str>, + expected_after: Option<&'a str>, + scan_support: bool, +} + +async fn run_aggregate_dyn_filter_case(case: AggregateDynFilterCase<'_>) { + let AggregateDynFilterCase { + schema, + batches, + aggr_exprs, + expected_before, + expected_after, + scan_support, + } = case; + + let scan = TestScanBuilder::new(Arc::clone(&schema)) + .with_support(scan_support) + .with_batches(batches) + .build(); + + let aggr_exprs: Vec<_> = aggr_exprs + .into_iter() + .map(|expr| Arc::new(expr) as Arc) + .collect(); + let aggr_len = aggr_exprs.len(); + + let plan: Arc = Arc::new( + AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::new_single(vec![]), + aggr_exprs, + vec![None; aggr_len], + scan, + Arc::clone(&schema), + ) + .unwrap(), + ); + + let mut config = ConfigOptions::default(); + config.execution.parquet.pushdown_filters = true; + config.optimizer.enable_dynamic_filter_pushdown = true; + + let optimized = FilterPushdown::new_post_optimization() + .optimize(plan, &config) + .unwrap(); + + let before = format_plan_for_test(&optimized); + if let Some(expected) = expected_before { + assert!( + before.contains(expected), + "expected `{expected}` before execution, got: {before}" + ); + } else { + assert!( + !before.contains("DynamicFilter ["), + "dynamic filter unexpectedly present before execution: {before}" + ); + } + + let session_ctx = SessionContext::new(); + session_ctx.register_object_store( + ObjectStoreUrl::parse("test://").unwrap().as_ref(), + Arc::new(InMemory::new()), + ); + let task_ctx = session_ctx.state().task_ctx(); + let mut stream = optimized.execute(0, Arc::clone(&task_ctx)).unwrap(); + let _ = stream.next().await.transpose().unwrap(); + + let after = format_plan_for_test(&optimized); + if let Some(expected) = expected_after { + assert!( + after.contains(expected), + "expected `{expected}` after execution, got: {after}" + ); + } else { + assert!( + !after.contains("DynamicFilter ["), + "dynamic 
filter unexpectedly present after execution: {after}" + ); + } +} + +// ---- Test Cases ---- +// Cases covered below: +// 1. `min(a)` and `max(a)` baseline. +// 2. Unsupported expression input (`min(a+1)`). +// 3. Multiple supported columns (same column vs different columns). +// 4. Mixed supported + unsupported aggregates. +// 5. Entirely NULL input to surface current bound behavior. +// 6. End-to-end tests on parquet files + +/// `MIN(a)`: able to pushdown dynamic filter +#[tokio::test] +async fn test_aggregate_dynamic_filter_min_simple() { + // Single min(a) showcases the base case. + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)])); + let batches = vec![record_batch!(("a", Int32, [5, 1, 3, 8])).unwrap()]; + + let min_expr = + AggregateExprBuilder::new(min_udaf(), vec![col("a", &schema).unwrap()]) + .schema(Arc::clone(&schema)) + .alias("min_a") + .build() + .unwrap(); + + run_aggregate_dyn_filter_case(AggregateDynFilterCase { + schema, + batches, + aggr_exprs: vec![min_expr], + expected_before: Some("DynamicFilter [ empty ]"), + expected_after: Some("DynamicFilter [ a@0 < 1 ]"), + scan_support: true, + }) + .await; +} + +/// `MAX(a)`: able to pushdown dynamic filter +#[tokio::test] +async fn test_aggregate_dynamic_filter_max_simple() { + // Single max(a) mirrors the base case on the upper bound. 
+ let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)])); + let batches = vec![record_batch!(("a", Int32, [5, 1, 3, 8])).unwrap()]; + + let max_expr = + AggregateExprBuilder::new(max_udaf(), vec![col("a", &schema).unwrap()]) + .schema(Arc::clone(&schema)) + .alias("max_a") + .build() + .unwrap(); + + run_aggregate_dyn_filter_case(AggregateDynFilterCase { + schema, + batches, + aggr_exprs: vec![max_expr], + expected_before: Some("DynamicFilter [ empty ]"), + expected_after: Some("DynamicFilter [ a@0 > 8 ]"), + scan_support: true, + }) + .await; +} + +/// `MIN(a+1)`: Can't pushdown dynamic filter +#[tokio::test] +async fn test_aggregate_dynamic_filter_min_expression_not_supported() { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)])); + let batches = vec![record_batch!(("a", Int32, [5, 1, 3, 8])).unwrap()]; + + let expr: Arc = Arc::new(BinaryExpr::new( + col("a", &schema).unwrap(), + Operator::Plus, + Arc::new(Literal::new(ScalarValue::Int32(Some(1)))), + )); + let min_expr = AggregateExprBuilder::new(min_udaf(), vec![expr]) + .schema(Arc::clone(&schema)) + .alias("min_a_plus_one") + .build() + .unwrap(); + + run_aggregate_dyn_filter_case(AggregateDynFilterCase { + schema, + batches, + aggr_exprs: vec![min_expr], + expected_before: None, + expected_after: None, + scan_support: true, + }) + .await; +} + +/// `MIN(a), MAX(a)`: Pushdown dynamic filter like `(a<1) or (a>8)` +#[tokio::test] +async fn test_aggregate_dynamic_filter_min_max_same_column() { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)])); + let batches = vec![record_batch!(("a", Int32, [5, 1, 3, 8])).unwrap()]; + + let min_expr = + AggregateExprBuilder::new(min_udaf(), vec![col("a", &schema).unwrap()]) + .schema(Arc::clone(&schema)) + .alias("min_a") + .build() + .unwrap(); + let max_expr = + AggregateExprBuilder::new(max_udaf(), vec![col("a", &schema).unwrap()]) + .schema(Arc::clone(&schema)) + .alias("max_a") + 
.build() + .unwrap(); + + run_aggregate_dyn_filter_case(AggregateDynFilterCase { + schema, + batches, + aggr_exprs: vec![min_expr, max_expr], + expected_before: Some("DynamicFilter [ empty ]"), + expected_after: Some("DynamicFilter [ a@0 < 1 OR a@0 > 8 ]"), + scan_support: true, + }) + .await; +} + +/// `MIN(a), MAX(b)`: Pushdown dynamic filter like `(a<1) or (b>9)` +#[tokio::test] +async fn test_aggregate_dynamic_filter_min_max_different_columns() { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ])); + let batches = vec![ + record_batch!(("a", Int32, [5, 1, 3, 8]), ("b", Int32, [7, 2, 4, 9])).unwrap(), + ]; + + let min_expr = + AggregateExprBuilder::new(min_udaf(), vec![col("a", &schema).unwrap()]) + .schema(Arc::clone(&schema)) + .alias("min_a") + .build() + .unwrap(); + let max_expr = + AggregateExprBuilder::new(max_udaf(), vec![col("b", &schema).unwrap()]) + .schema(Arc::clone(&schema)) + .alias("max_b") + .build() + .unwrap(); + + run_aggregate_dyn_filter_case(AggregateDynFilterCase { + schema, + batches, + aggr_exprs: vec![min_expr, max_expr], + expected_before: Some("DynamicFilter [ empty ]"), + expected_after: Some("DynamicFilter [ a@0 < 1 OR b@1 > 9 ]"), + scan_support: true, + }) + .await; +} + +/// Mix of supported/unsupported aggregates retains only the valid ones. 
+/// `MIN(a), MAX(a), MAX(b), MIN(c+1)`: Pushdown dynamic filter like `(a<1) or (a>8) OR (b>12)` +#[tokio::test] +async fn test_aggregate_dynamic_filter_multiple_mixed_expressions() { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + ])); + let batches = vec![ + record_batch!( + ("a", Int32, [5, 1, 3, 8]), + ("b", Int32, [10, 4, 6, 12]), + ("c", Int32, [100, 70, 90, 110]) + ) + .unwrap(), + ]; + + let min_a = AggregateExprBuilder::new(min_udaf(), vec![col("a", &schema).unwrap()]) + .schema(Arc::clone(&schema)) + .alias("min_a") + .build() + .unwrap(); + let max_a = AggregateExprBuilder::new(max_udaf(), vec![col("a", &schema).unwrap()]) + .schema(Arc::clone(&schema)) + .alias("max_a") + .build() + .unwrap(); + let max_b = AggregateExprBuilder::new(max_udaf(), vec![col("b", &schema).unwrap()]) + .schema(Arc::clone(&schema)) + .alias("max_b") + .build() + .unwrap(); + let expr_c: Arc = Arc::new(BinaryExpr::new( + col("c", &schema).unwrap(), + Operator::Plus, + Arc::new(Literal::new(ScalarValue::Int32(Some(1)))), + )); + let min_c_expr = AggregateExprBuilder::new(min_udaf(), vec![expr_c]) + .schema(Arc::clone(&schema)) + .alias("min_c_plus_one") + .build() + .unwrap(); + + run_aggregate_dyn_filter_case(AggregateDynFilterCase { + schema, + batches, + aggr_exprs: vec![min_a, max_a, max_b, min_c_expr], + expected_before: Some("DynamicFilter [ empty ]"), + expected_after: Some("DynamicFilter [ a@0 < 1 OR a@0 > 8 OR b@1 > 12 ]"), + scan_support: true, + }) + .await; +} + +/// Don't tighten the dynamic filter if all inputs are null +#[tokio::test] +async fn test_aggregate_dynamic_filter_min_all_nulls() { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)])); + let batches = vec![record_batch!(("a", Int32, [None, None, None, None])).unwrap()]; + + let min_expr = + AggregateExprBuilder::new(min_udaf(), vec![col("a", 
&schema).unwrap()]) + .schema(Arc::clone(&schema)) + .alias("min_a") + .build() + .unwrap(); + + run_aggregate_dyn_filter_case(AggregateDynFilterCase { + schema, + batches, + aggr_exprs: vec![min_expr], + expected_before: Some("DynamicFilter [ empty ]"), + // After reading the input there is no meaningful bound to update, so the + // predicate is `true`, meaning nothing is filtered out + expected_after: Some("DynamicFilter [ true ]"), + scan_support: true, + }) + .await; +} + +/// Tests that the aggregate dynamic filter works when reading parquet files +/// +/// Runs 'select max(id) from test_table where id > 1' and ensures some file ranges +/// are pruned by the dynamic filter. +#[tokio::test] +async fn test_aggregate_dynamic_filter_parquet_e2e() { + let config = SessionConfig::new() + .with_collect_statistics(true) + .with_target_partitions(2) + .set_bool("datafusion.optimizer.enable_dynamic_filter_pushdown", true) + .set_bool("datafusion.execution.parquet.pushdown_filters", true); + let ctx = SessionContext::new_with_config(config); + + let data_path = format!( + "{}/tests/data/test_statistics_per_partition/", + env!("CARGO_MANIFEST_DIR") + ); + + ctx.register_parquet("test_table", &data_path, ParquetReadOptions::default()) + .await + .unwrap(); + + // partition 1: + // files: ..03-01(id=4), ..03-02(id=3) + // partition 2: + // files: ..03-03(id=2), ..03-04(id=1) + // + // In partition 1, after reading the first file, the dynamic filter will be updated + // to "id > 4", so the `..03-02` file should get pruned out + let df = ctx + .sql("explain analyze select max(id) from test_table where id > 1") + .await + .unwrap(); + + let result = df.collect().await.unwrap(); + + let formatted = pretty_format_batches(&result).unwrap(); + let explain_analyze = format!("{formatted}"); + + // Capture "2" from "files_ranges_pruned_statistics=4 total → 2 matched" + let re = Regex::new( + r"files_ranges_pruned_statistics\s*=\s*(\d+)\s*total\s*[→>\-]\s*(\d+)\s*matched", + ) + .unwrap(); 
+ + if let Some(caps) = re.captures(&explain_analyze) { + let matched_num: i32 = caps[2].parse().unwrap(); + assert!( + matched_num < 4, + "Total 4 files, if some pruned, the matched count is < 4" + ); + } else { + unreachable!("metrics should exist") + } +} + +/// Non-partial (Single) aggregates should skip dynamic filter initialization. +#[test] +fn test_aggregate_dynamic_filter_not_created_for_single_mode() { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)])); + let batches = vec![record_batch!(("a", Int32, [5, 1, 3, 8])).unwrap()]; + + let scan = TestScanBuilder::new(Arc::clone(&schema)) + .with_support(true) + .with_batches(batches) + .build(); + + let min_expr = + AggregateExprBuilder::new(min_udaf(), vec![col("a", &schema).unwrap()]) + .schema(Arc::clone(&schema)) + .alias("min_a") + .build() + .unwrap(); + + let plan: Arc = Arc::new( + AggregateExec::try_new( + AggregateMode::Single, + PhysicalGroupBy::new_single(vec![]), + vec![min_expr.into()], + vec![None], + scan, + Arc::clone(&schema), + ) + .unwrap(), + ); + + let mut config = ConfigOptions::default(); + config.execution.parquet.pushdown_filters = true; + config.optimizer.enable_dynamic_filter_pushdown = true; + + let optimized = FilterPushdown::new_post_optimization() + .optimize(plan, &config) + .unwrap(); + + let formatted = format_plan_for_test(&optimized); + assert!( + !formatted.contains("DynamicFilter ["), + "dynamic filter should not be created for AggregateMode::Single: {formatted}" + ); +} + +#[tokio::test] +async fn test_aggregate_filter_pushdown() { + // Test that filters can pass through AggregateExec even with aggregate functions + // when the filter references grouping columns + // Simulates: SELECT a, COUNT(b) FROM table WHERE a = 'x' GROUP BY a + + let batches = vec![ + record_batch!(("a", Utf8, ["x", "y"]), ("b", Utf8, ["foo", "bar"])).unwrap(), + ]; + + let scan = TestScanBuilder::new(schema()) + .with_support(true) + .with_batches(batches) + 
.build(); + + // Create an aggregate: GROUP BY a with COUNT(b) + let group_by = PhysicalGroupBy::new_single(vec![( + col("a", &schema()).unwrap(), + "a".to_string(), + )]); + + // Add COUNT aggregate + let count_expr = + AggregateExprBuilder::new(count_udaf(), vec![col("b", &schema()).unwrap()]) + .schema(schema()) + .alias("count") + .build() + .unwrap(); + + let aggregate = Arc::new( + AggregateExec::try_new( + AggregateMode::Partial, + group_by, + vec![count_expr.into()], // Has aggregate function + vec![None], // No filter on the aggregate function + Arc::clone(&scan), + schema(), + ) + .unwrap(), + ); + + // Add a filter on the grouping column 'a' + let predicate = col_lit_predicate("a", "x", &schema()); + let plan = Arc::new(FilterExec::try_new(predicate, aggregate).unwrap()) + as Arc; + + // Even with aggregate functions, filter on grouping column should be pushed through + insta::assert_snapshot!( + OptimizationTest::new(Arc::clone(&plan), FilterPushdown::new(), true), + @r" + OptimizationTest: + input: + - FilterExec: a@0 = x + - AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[count] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[count], ordering_mode=Sorted + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = x + " + ); +} + +#[tokio::test] +async fn test_no_pushdown_filter_on_aggregate_result() { + // Test that filters on aggregate results (not grouping columns) are NOT pushed through + // SELECT a, COUNT(b) as cnt FROM table GROUP BY a HAVING cnt > 5 + // The filter on 'cnt' cannot be pushed down because it's an aggregate result + + let batches = vec![ + record_batch!(("a", Utf8, ["x", "y"]), ("b", Utf8, ["foo", "bar"])).unwrap(), + ]; + + let scan = TestScanBuilder::new(schema()) + .with_support(true) + 
.with_batches(batches) + .build(); + + // Create an aggregate: GROUP BY a with COUNT(b) + let group_by = PhysicalGroupBy::new_single(vec![( + col("a", &schema()).unwrap(), + "a".to_string(), + )]); + + // Add COUNT aggregate + let count_expr = + AggregateExprBuilder::new(count_udaf(), vec![col("b", &schema()).unwrap()]) + .schema(schema()) + .alias("count") + .build() + .unwrap(); + + let aggregate = Arc::new( + AggregateExec::try_new( + AggregateMode::Partial, + group_by, + vec![count_expr.into()], + vec![None], + Arc::clone(&scan), + schema(), + ) + .unwrap(), + ); + + // Add a filter on the aggregate output column + // This simulates filtering on COUNT result, which should NOT be pushed through + let agg_schema = aggregate.schema(); + let predicate = Arc::new(BinaryExpr::new( + Arc::new(Column::new_with_schema("count[count]", &agg_schema).unwrap()), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int64(Some(5)))), + )); + let plan = Arc::new(FilterExec::try_new(predicate, aggregate).unwrap()) + as Arc; + + // The filter should NOT be pushed through the aggregate since it's on an aggregate result + insta::assert_snapshot!( + OptimizationTest::new(Arc::clone(&plan), FilterPushdown::new(), true), + @r" + OptimizationTest: + input: + - FilterExec: count[count]@1 > 5 + - AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[count] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - FilterExec: count[count]@1 > 5 + - AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[count] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + " + ); +} + +#[test] +fn test_pushdown_filter_on_non_first_grouping_column() { + // Test that filters on non-first grouping columns are still pushed down + // SELECT a, b, count(*) as cnt FROM table GROUP BY a, b HAVING b = 'bar' + // The filter is on 'b' (second grouping column), 
should push down + let scan = TestScanBuilder::new(schema()).with_support(true).build(); + + let aggregate_expr = vec![ + AggregateExprBuilder::new(count_udaf(), vec![col("c", &schema()).unwrap()]) + .schema(schema()) + .alias("cnt") + .build() + .map(Arc::new) + .unwrap(), + ]; + + let group_by = PhysicalGroupBy::new_single(vec![ + (col("a", &schema()).unwrap(), "a".to_string()), + (col("b", &schema()).unwrap(), "b".to_string()), + ]); + + let aggregate = Arc::new( + AggregateExec::try_new( + AggregateMode::Final, + group_by, + aggregate_expr.clone(), + vec![None], + scan, + schema(), + ) + .unwrap(), + ); + + let predicate = col_lit_predicate("b", "bar", &schema()); + let plan = Arc::new(FilterExec::try_new(predicate, aggregate).unwrap()); + + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown::new(), true), + @r" + OptimizationTest: + input: + - FilterExec: b@1 = bar + - AggregateExec: mode=Final, gby=[a@0 as a, b@1 as b], aggr=[cnt] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - AggregateExec: mode=Final, gby=[a@0 as a, b@1 as b], aggr=[cnt], ordering_mode=PartiallySorted([1]) + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=b@1 = bar + " + ); +} + +#[test] +fn test_no_pushdown_grouping_sets_filter_on_missing_column() { + // Test that filters on columns missing from some grouping sets are NOT pushed through + let scan = TestScanBuilder::new(schema()).with_support(true).build(); + + let aggregate_expr = vec![ + AggregateExprBuilder::new(count_udaf(), vec![col("c", &schema()).unwrap()]) + .schema(schema()) + .alias("cnt") + .build() + .map(Arc::new) + .unwrap(), + ]; + + // Create GROUPING SETS with (a, b) and (b) + let group_by = PhysicalGroupBy::new( + vec![ + (col("a", &schema()).unwrap(), "a".to_string()), + (col("b", &schema()).unwrap(), "b".to_string()), + ], + 
vec![ + ( + Arc::new(Literal::new(ScalarValue::Utf8(None))), + "a".to_string(), + ), + ( + Arc::new(Literal::new(ScalarValue::Utf8(None))), + "b".to_string(), + ), + ], + vec![ + vec![false, false], // (a, b) - both present + vec![true, false], // (b) - a is NULL, b present + ], + true, + ); + + let aggregate = Arc::new( + AggregateExec::try_new( + AggregateMode::Final, + group_by, + aggregate_expr.clone(), + vec![None], + scan, + schema(), + ) + .unwrap(), + ); + + // Filter on column 'a' which is missing in the second grouping set, should not be pushed down + let predicate = col_lit_predicate("a", "foo", &schema()); + let plan = Arc::new(FilterExec::try_new(predicate, aggregate).unwrap()); + + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown::new(), true), + @r" + OptimizationTest: + input: + - FilterExec: a@0 = foo + - AggregateExec: mode=Final, gby=[(a@0 as a, b@1 as b), (NULL as a, b@1 as b)], aggr=[cnt] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - FilterExec: a@0 = foo + - AggregateExec: mode=Final, gby=[(a@0 as a, b@1 as b), (NULL as a, b@1 as b)], aggr=[cnt] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + " + ); +} + +#[test] +fn test_pushdown_grouping_sets_filter_on_common_column() { + // Test that filters on columns present in ALL grouping sets ARE pushed through + let scan = TestScanBuilder::new(schema()).with_support(true).build(); + + let aggregate_expr = vec![ + AggregateExprBuilder::new(count_udaf(), vec![col("c", &schema()).unwrap()]) + .schema(schema()) + .alias("cnt") + .build() + .map(Arc::new) + .unwrap(), + ]; + + // Create GROUPING SETS with (a, b) and (b) + let group_by = PhysicalGroupBy::new( + vec![ + (col("a", &schema()).unwrap(), "a".to_string()), + (col("b", &schema()).unwrap(), "b".to_string()), + ], + vec![ + ( + 
Arc::new(Literal::new(ScalarValue::Utf8(None))), + "a".to_string(), + ), + ( + Arc::new(Literal::new(ScalarValue::Utf8(None))), + "b".to_string(), + ), + ], + vec![ + vec![false, false], // (a, b) - both present + vec![true, false], // (b) - a is NULL, b present + ], + true, + ); + + let aggregate = Arc::new( + AggregateExec::try_new( + AggregateMode::Final, + group_by, + aggregate_expr.clone(), + vec![None], + scan, + schema(), + ) + .unwrap(), + ); + + // Filter on column 'b' which is present in all grouping sets will be pushed down + let predicate = col_lit_predicate("b", "bar", &schema()); + let plan = Arc::new(FilterExec::try_new(predicate, aggregate).unwrap()); + + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown::new(), true), + @r" + OptimizationTest: + input: + - FilterExec: b@1 = bar + - AggregateExec: mode=Final, gby=[(a@0 as a, b@1 as b), (NULL as a, b@1 as b)], aggr=[cnt] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - AggregateExec: mode=Final, gby=[(a@0 as a, b@1 as b), (NULL as a, b@1 as b)], aggr=[cnt], ordering_mode=PartiallySorted([1]) + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=b@1 = bar + " + ); +} + +#[test] +fn test_pushdown_with_empty_group_by() { + // Test that filters can be pushed down when GROUP BY is empty (no grouping columns) + // SELECT count(*) as cnt FROM table WHERE a = 'foo' + // There are no grouping columns, so the filter should still push down + let scan = TestScanBuilder::new(schema()).with_support(true).build(); + + let aggregate_expr = vec![ + AggregateExprBuilder::new(count_udaf(), vec![col("c", &schema()).unwrap()]) + .schema(schema()) + .alias("cnt") + .build() + .map(Arc::new) + .unwrap(), + ]; + + // Empty GROUP BY - no grouping columns + let group_by = PhysicalGroupBy::new_single(vec![]); + + let aggregate = 
Arc::new( + AggregateExec::try_new( + AggregateMode::Final, + group_by, + aggregate_expr.clone(), + vec![None], + scan, + schema(), + ) + .unwrap(), + ); + + // Filter on 'a' + let predicate = col_lit_predicate("a", "foo", &schema()); + let plan = Arc::new(FilterExec::try_new(predicate, aggregate).unwrap()); + + // The filter should be pushed down even with empty GROUP BY + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown::new(), true), + @r" + OptimizationTest: + input: + - FilterExec: a@0 = foo + - AggregateExec: mode=Final, gby=[], aggr=[cnt] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - AggregateExec: mode=Final, gby=[], aggr=[cnt] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo + " + ); +} + +#[test] +fn test_pushdown_with_computed_grouping_key() { + // Test filter pushdown with computed grouping expression + // SELECT (c + 1.0) as c_plus_1, count(*) FROM table WHERE c > 5.0 GROUP BY (c + 1.0) + + let scan = TestScanBuilder::new(schema()).with_support(true).build(); + + let predicate = Arc::new(BinaryExpr::new( + col("c", &schema()).unwrap(), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Float64(Some(5.0)))), + )) as Arc; + let filter = Arc::new(FilterExec::try_new(predicate, scan).unwrap()); + + let aggregate_expr = vec![ + AggregateExprBuilder::new(count_udaf(), vec![col("a", &schema()).unwrap()]) + .schema(schema()) + .alias("cnt") + .build() + .map(Arc::new) + .unwrap(), + ]; + + let c_plus_one = Arc::new(BinaryExpr::new( + col("c", &schema()).unwrap(), + Operator::Plus, + Arc::new(Literal::new(ScalarValue::Float64(Some(1.0)))), + )) as Arc; + + let group_by = + PhysicalGroupBy::new_single(vec![(c_plus_one, "c_plus_1".to_string())]); + + let plan = Arc::new( + AggregateExec::try_new( + AggregateMode::Final, + group_by, + 
aggregate_expr.clone(), + vec![None], + filter, + schema(), + ) + .unwrap(), + ); + + // The filter should be pushed down because 'c' is extracted from the grouping expression (c + 1.0) + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown::new(), true), + @r" + OptimizationTest: + input: + - AggregateExec: mode=Final, gby=[c@2 + 1 as c_plus_1], aggr=[cnt] + - FilterExec: c@2 > 5 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - AggregateExec: mode=Final, gby=[c@2 + 1 as c_plus_1], aggr=[cnt] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=c@2 > 5 + " + ); +} + +#[tokio::test] +async fn test_hashjoin_dynamic_filter_all_partitions_empty() { + use datafusion_common::JoinType; + use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; + + // Test scenario where all build-side partitions are empty + // This validates the code path that sets the filter to `false` when no rows can match + + // Create empty build side + let build_batches = vec![]; + let build_side_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, false), + ])); + let build_scan = TestScanBuilder::new(Arc::clone(&build_side_schema)) + .with_support(true) + .with_batches(build_batches) + .build(); + + // Create probe side with some data + let probe_batches = vec![ + record_batch!( + ("a", Utf8, ["aa", "ab", "ac"]), + ("b", Utf8, ["ba", "bb", "bc"]) + ) + .unwrap(), + ]; + let probe_side_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, false), + ])); + let probe_scan = TestScanBuilder::new(Arc::clone(&probe_side_schema)) + .with_support(true) + .with_batches(probe_batches) + .build(); + + // Create RepartitionExec nodes for both sides + let partition_count = 4; + + let 
build_hash_exprs = vec![ + col("a", &build_side_schema).unwrap(), + col("b", &build_side_schema).unwrap(), + ]; + let build_repartition = Arc::new( + RepartitionExec::try_new( + build_scan, + Partitioning::Hash(build_hash_exprs, partition_count), + ) + .unwrap(), + ); + + let probe_hash_exprs = vec![ + col("a", &probe_side_schema).unwrap(), + col("b", &probe_side_schema).unwrap(), + ]; + let probe_repartition = Arc::new( + RepartitionExec::try_new( + Arc::clone(&probe_scan), + Partitioning::Hash(probe_hash_exprs, partition_count), + ) + .unwrap(), + ); + + // Create HashJoinExec + let on = vec![ + ( + col("a", &build_side_schema).unwrap(), + col("a", &probe_side_schema).unwrap(), + ), + ( + col("b", &build_side_schema).unwrap(), + col("b", &probe_side_schema).unwrap(), + ), + ]; + let plan = Arc::new( + HashJoinExec::try_new( + build_repartition, + probe_repartition, + on, + None, + &JoinType::Inner, + None, + PartitionMode::Partitioned, + datafusion_common::NullEquality::NullEqualsNothing, + false, + ) + .unwrap(), + ); + + // Apply the filter pushdown optimizer + let mut config = SessionConfig::new(); + config.options_mut().execution.parquet.pushdown_filters = true; + let optimizer = FilterPushdown::new_post_optimization(); + let plan = optimizer.optimize(plan, config.options()).unwrap(); + + insta::assert_snapshot!( + format_plan_for_test(&plan), + @r" + - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] + - RepartitionExec: partitioning=Hash([a@0, b@1], 4), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b], file_type=test, pushdown_supported=true + - RepartitionExec: partitioning=Hash([a@0, b@1], 4), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ] + " + ); + + // Put some data through the plan to check that the filter is updated to reflect the TopK state + let 
session_ctx = SessionContext::new_with_config(config); + session_ctx.register_object_store( + ObjectStoreUrl::parse("test://").unwrap().as_ref(), + Arc::new(InMemory::new()), + ); + let state = session_ctx.state(); + let task_ctx = state.task_ctx(); + // Execute all partitions (required for partitioned hash join coordination) + let _batches = collect(Arc::clone(&plan), Arc::clone(&task_ctx)) + .await + .unwrap(); + + // Test that filters are pushed down correctly to each side of the join + insta::assert_snapshot!( + format_plan_for_test(&plan), + @r" + - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] + - RepartitionExec: partitioning=Hash([a@0, b@1], 4), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b], file_type=test, pushdown_supported=true + - RepartitionExec: partitioning=Hash([a@0, b@1], 4), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ false ] + " + ); +} + +#[tokio::test] +async fn test_hashjoin_dynamic_filter_with_nulls() { + use datafusion_common::JoinType; + use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; + + // Test scenario where build side has NULL values in join keys + // This validates NULL handling in bounds computation and filter generation + + // Create build side with NULL values + let build_batch = RecordBatch::try_new( + Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, true), // nullable + Field::new("b", DataType::Int32, true), // nullable + ])), + vec![ + Arc::new(StringArray::from(vec![Some("aa"), None, Some("ab")])), + Arc::new(Int32Array::from(vec![Some(1), Some(2), None])), + ], + ) + .unwrap(); + let build_batches = vec![build_batch]; + let build_side_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, true), + Field::new("b", DataType::Int32, true), + ])); + let build_scan = 
TestScanBuilder::new(Arc::clone(&build_side_schema)) + .with_support(true) + .with_batches(build_batches) + .build(); + + // Create probe side with nullable fields + let probe_batch = RecordBatch::try_new( + Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Float64, false), + ])), + vec![ + Arc::new(StringArray::from(vec![ + Some("aa"), + Some("ab"), + Some("ac"), + None, + ])), + Arc::new(Int32Array::from(vec![Some(1), Some(3), Some(4), Some(5)])), + Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])), + ], + ) + .unwrap(); + let probe_batches = vec![probe_batch]; + let probe_side_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Float64, false), + ])); + let probe_scan = TestScanBuilder::new(Arc::clone(&probe_side_schema)) + .with_support(true) + .with_batches(probe_batches) + .build(); + + // Create HashJoinExec in CollectLeft mode (simpler for this test) + let on = vec![ + ( + col("a", &build_side_schema).unwrap(), + col("a", &probe_side_schema).unwrap(), + ), + ( + col("b", &build_side_schema).unwrap(), + col("b", &probe_side_schema).unwrap(), + ), + ]; + let plan = Arc::new( + HashJoinExec::try_new( + build_scan, + Arc::clone(&probe_scan), + on, + None, + &JoinType::Inner, + None, + PartitionMode::CollectLeft, + datafusion_common::NullEquality::NullEqualsNothing, + false, + ) + .unwrap(), + ); + + // Apply the filter pushdown optimizer + let mut config = SessionConfig::new(); + config.options_mut().execution.parquet.pushdown_filters = true; + let optimizer = FilterPushdown::new_post_optimization(); + let plan = optimizer.optimize(plan, config.options()).unwrap(); + + insta::assert_snapshot!( + format_plan_for_test(&plan), + @r" + - HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, 
projection=[a, b], file_type=test, pushdown_supported=true + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ] + " + ); + + // Put some data through the plan to check that the filter is updated to reflect the TopK state + let session_ctx = SessionContext::new_with_config(config); + session_ctx.register_object_store( + ObjectStoreUrl::parse("test://").unwrap().as_ref(), + Arc::new(InMemory::new()), + ); + let state = session_ctx.state(); + let task_ctx = state.task_ctx(); + // Execute all partitions (required for partitioned hash join coordination) + let batches = collect(Arc::clone(&plan), Arc::clone(&task_ctx)) + .await + .unwrap(); + + // Test that filters are pushed down correctly to each side of the join + insta::assert_snapshot!( + format_plan_for_test(&plan), + @r" + - HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b], file_type=test, pushdown_supported=true + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 >= aa AND a@0 <= ab AND b@1 >= 1 AND b@1 <= 2 AND struct(a@0, b@1) IN (SET) ([{c0:aa,c1:1}, {c0:,c1:2}, {c0:ab,c1:}]) ] + " + ); + + #[rustfmt::skip] + let expected = [ + "+----+---+----+---+-----+", + "| a | b | a | b | c |", + "+----+---+----+---+-----+", + "| aa | 1 | aa | 1 | 1.0 |", + "+----+---+----+---+-----+", + ]; + assert_batches_eq!(&expected, &batches); +} + +/// Test that when hash_join_inlist_pushdown_max_size is set to a very small value, +/// the HashTable strategy is used instead of InList strategy, even with small build sides. +/// This test is identical to test_hashjoin_dynamic_filter_pushdown_partitioned except +/// for the config setting that forces the HashTable strategy. 
+#[tokio::test] +async fn test_hashjoin_hash_table_pushdown_partitioned() { + use datafusion_common::JoinType; + use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; + + // Create build side with limited values + let build_batches = vec![ + record_batch!( + ("a", Utf8, ["aa", "ab"]), + ("b", Utf8, ["ba", "bb"]), + ("c", Float64, [1.0, 2.0]) // Extra column not used in join + ) + .unwrap(), + ]; + let build_side_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, false), + Field::new("c", DataType::Float64, false), + ])); + let build_scan = TestScanBuilder::new(Arc::clone(&build_side_schema)) + .with_support(true) + .with_batches(build_batches) + .build(); + + // Create probe side with more values + let probe_batches = vec![ + record_batch!( + ("a", Utf8, ["aa", "ab", "ac", "ad"]), + ("b", Utf8, ["ba", "bb", "bc", "bd"]), + ("e", Float64, [1.0, 2.0, 3.0, 4.0]) // Extra column not used in join + ) + .unwrap(), + ]; + let probe_side_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, false), + Field::new("e", DataType::Float64, false), + ])); + let probe_scan = TestScanBuilder::new(Arc::clone(&probe_side_schema)) + .with_support(true) + .with_batches(probe_batches) + .build(); + + // Create RepartitionExec nodes for both sides with hash partitioning on join keys + let partition_count = 12; + + // Build side: DataSource -> RepartitionExec (Hash) + let build_hash_exprs = vec![ + col("a", &build_side_schema).unwrap(), + col("b", &build_side_schema).unwrap(), + ]; + let build_repartition = Arc::new( + RepartitionExec::try_new( + build_scan, + Partitioning::Hash(build_hash_exprs, partition_count), + ) + .unwrap(), + ); + + // Probe side: DataSource -> RepartitionExec (Hash) + let probe_hash_exprs = vec![ + col("a", &probe_side_schema).unwrap(), + col("b", &probe_side_schema).unwrap(), + ]; + let probe_repartition = Arc::new( + 
RepartitionExec::try_new( + Arc::clone(&probe_scan), + Partitioning::Hash(probe_hash_exprs, partition_count), + ) + .unwrap(), + ); + + // Create HashJoinExec with partitioned inputs + let on = vec![ + ( + col("a", &build_side_schema).unwrap(), + col("a", &probe_side_schema).unwrap(), + ), + ( + col("b", &build_side_schema).unwrap(), + col("b", &probe_side_schema).unwrap(), + ), + ]; + let hash_join = Arc::new( + HashJoinExec::try_new( + build_repartition, + probe_repartition, + on, + None, + &JoinType::Inner, + None, + PartitionMode::Partitioned, + datafusion_common::NullEquality::NullEqualsNothing, + false, + ) + .unwrap(), + ); + + // Top-level CoalescePartitionsExec + let cp = Arc::new(CoalescePartitionsExec::new(hash_join)) as Arc; + // Add a sort for deterministic output + let plan = Arc::new(SortExec::new( + LexOrdering::new(vec![PhysicalSortExpr::new( + col("a", &probe_side_schema).unwrap(), + SortOptions::new(true, false), // descending, nulls_first + )]) + .unwrap(), + cp, + )) as Arc; + + // Apply the optimization with config setting that forces HashTable strategy + let session_config = SessionConfig::default() + .with_batch_size(10) + .set_usize("datafusion.optimizer.hash_join_inlist_pushdown_max_size", 1) + .set_bool("datafusion.execution.parquet.pushdown_filters", true) + .set_bool("datafusion.optimizer.enable_dynamic_filter_pushdown", true); + let plan = FilterPushdown::new_post_optimization() + .optimize(plan, session_config.options()) + .unwrap(); + let session_ctx = SessionContext::new_with_config(session_config); + session_ctx.register_object_store( + ObjectStoreUrl::parse("test://").unwrap().as_ref(), + Arc::new(InMemory::new()), + ); + let state = session_ctx.state(); + let task_ctx = state.task_ctx(); + let batches = collect(Arc::clone(&plan), Arc::clone(&task_ctx)) + .await + .unwrap(); + + // Verify that hash_lookup is used instead of IN (SET) + let plan_str = format_plan_for_test(&plan).to_string(); + assert!( + 
plan_str.contains("hash_lookup"), + "Expected hash_lookup in plan but got: {plan_str}" + ); + assert!( + !plan_str.contains("IN (SET)"), + "Expected no IN (SET) in plan but got: {plan_str}" + ); + + let result = format!("{}", pretty_format_batches(&batches).unwrap()); + + let probe_scan_metrics = probe_scan.metrics().unwrap(); + + // The probe side had 4 rows, but after applying the dynamic filter only 2 rows should remain. + assert_eq!(probe_scan_metrics.output_rows().unwrap(), 2); + + // Results should be identical to the InList version + insta::assert_snapshot!( + result, + @r" + +----+----+-----+----+----+-----+ + | a | b | c | a | b | e | + +----+----+-----+----+----+-----+ + | ab | bb | 2.0 | ab | bb | 2.0 | + | aa | ba | 1.0 | aa | ba | 1.0 | + +----+----+-----+----+----+-----+ + ", + ); +} + +/// Test that when hash_join_inlist_pushdown_max_size is set to a very small value, +/// the HashTable strategy is used instead of InList strategy in CollectLeft mode. +/// This test is identical to test_hashjoin_dynamic_filter_pushdown_collect_left except +/// for the config setting that forces the HashTable strategy. 
+#[tokio::test] +async fn test_hashjoin_hash_table_pushdown_collect_left() { + use datafusion_common::JoinType; + use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; + + let build_batches = vec![ + record_batch!( + ("a", Utf8, ["aa", "ab"]), + ("b", Utf8, ["ba", "bb"]), + ("c", Float64, [1.0, 2.0]) // Extra column not used in join + ) + .unwrap(), + ]; + let build_side_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, false), + Field::new("c", DataType::Float64, false), + ])); + let build_scan = TestScanBuilder::new(Arc::clone(&build_side_schema)) + .with_support(true) + .with_batches(build_batches) + .build(); + + // Create probe side with more values + let probe_batches = vec![ + record_batch!( + ("a", Utf8, ["aa", "ab", "ac", "ad"]), + ("b", Utf8, ["ba", "bb", "bc", "bd"]), + ("e", Float64, [1.0, 2.0, 3.0, 4.0]) // Extra column not used in join + ) + .unwrap(), + ]; + let probe_side_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, false), + Field::new("e", DataType::Float64, false), + ])); + let probe_scan = TestScanBuilder::new(Arc::clone(&probe_side_schema)) + .with_support(true) + .with_batches(probe_batches) + .build(); + + // Create RepartitionExec nodes for both sides with hash partitioning on join keys + let partition_count = 12; + + // Probe side: DataSource -> RepartitionExec(Hash) + let probe_hash_exprs = vec![ + col("a", &probe_side_schema).unwrap(), + col("b", &probe_side_schema).unwrap(), + ]; + let probe_repartition = Arc::new( + RepartitionExec::try_new( + Arc::clone(&probe_scan), + Partitioning::Hash(probe_hash_exprs, partition_count), // create multi partitions on probSide + ) + .unwrap(), + ); + + let on = vec![ + ( + col("a", &build_side_schema).unwrap(), + col("a", &probe_side_schema).unwrap(), + ), + ( + col("b", &build_side_schema).unwrap(), + col("b", &probe_side_schema).unwrap(), + ), + ]; + let 
hash_join = Arc::new( + HashJoinExec::try_new( + build_scan, + probe_repartition, + on, + None, + &JoinType::Inner, + None, + PartitionMode::CollectLeft, + datafusion_common::NullEquality::NullEqualsNothing, + false, + ) + .unwrap(), + ); + + // Top-level CoalescePartitionsExec + let cp = Arc::new(CoalescePartitionsExec::new(hash_join)) as Arc; + // Add a sort for deterministic output + let plan = Arc::new(SortExec::new( + LexOrdering::new(vec![PhysicalSortExpr::new( + col("a", &probe_side_schema).unwrap(), + SortOptions::new(true, false), // descending, nulls_first + )]) + .unwrap(), + cp, + )) as Arc; + + // Apply the optimization with config setting that forces HashTable strategy + let session_config = SessionConfig::default() + .with_batch_size(10) + .set_usize("datafusion.optimizer.hash_join_inlist_pushdown_max_size", 1) + .set_bool("datafusion.execution.parquet.pushdown_filters", true) + .set_bool("datafusion.optimizer.enable_dynamic_filter_pushdown", true); + let plan = FilterPushdown::new_post_optimization() + .optimize(plan, session_config.options()) + .unwrap(); + let session_ctx = SessionContext::new_with_config(session_config); + session_ctx.register_object_store( + ObjectStoreUrl::parse("test://").unwrap().as_ref(), + Arc::new(InMemory::new()), + ); + let state = session_ctx.state(); + let task_ctx = state.task_ctx(); + let batches = collect(Arc::clone(&plan), Arc::clone(&task_ctx)) + .await + .unwrap(); + + // Verify that hash_lookup is used instead of IN (SET) + let plan_str = format_plan_for_test(&plan).to_string(); + assert!( + plan_str.contains("hash_lookup"), + "Expected hash_lookup in plan but got: {plan_str}" + ); + assert!( + !plan_str.contains("IN (SET)"), + "Expected no IN (SET) in plan but got: {plan_str}" + ); + + let result = format!("{}", pretty_format_batches(&batches).unwrap()); + + let probe_scan_metrics = probe_scan.metrics().unwrap(); + + // The probe side had 4 rows, but after applying the dynamic filter only 2 rows should remain. 
+ assert_eq!(probe_scan_metrics.output_rows().unwrap(), 2); + + // Results should be identical to the InList version + insta::assert_snapshot!( + result, + @r" + +----+----+-----+----+----+-----+ + | a | b | c | a | b | e | + +----+----+-----+----+----+-----+ + | ab | bb | 2.0 | ab | bb | 2.0 | + | aa | ba | 1.0 | aa | ba | 1.0 | + +----+----+-----+----+----+-----+ + ", + ); +} + +/// Test HashTable strategy with integer multi-column join keys. +/// Verifies that hash_lookup works correctly with integer data types. +#[tokio::test] +async fn test_hashjoin_hash_table_pushdown_integer_keys() { + use datafusion_common::JoinType; + use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; + + // Create build side with integer keys + let build_batches = vec![ + record_batch!( + ("id1", Int32, [1, 2]), + ("id2", Int32, [10, 20]), + ("value", Float64, [100.0, 200.0]) + ) + .unwrap(), + ]; + let build_side_schema = Arc::new(Schema::new(vec![ + Field::new("id1", DataType::Int32, false), + Field::new("id2", DataType::Int32, false), + Field::new("value", DataType::Float64, false), + ])); + let build_scan = TestScanBuilder::new(Arc::clone(&build_side_schema)) + .with_support(true) + .with_batches(build_batches) + .build(); + + // Create probe side with more integer rows + let probe_batches = vec![ + record_batch!( + ("id1", Int32, [1, 2, 3, 4]), + ("id2", Int32, [10, 20, 30, 40]), + ("data", Utf8, ["a", "b", "c", "d"]) + ) + .unwrap(), + ]; + let probe_side_schema = Arc::new(Schema::new(vec![ + Field::new("id1", DataType::Int32, false), + Field::new("id2", DataType::Int32, false), + Field::new("data", DataType::Utf8, false), + ])); + let probe_scan = TestScanBuilder::new(Arc::clone(&probe_side_schema)) + .with_support(true) + .with_batches(probe_batches) + .build(); + + // Create join on multiple integer columns + let on = vec![ + ( + col("id1", &build_side_schema).unwrap(), + col("id1", &probe_side_schema).unwrap(), + ), + ( + col("id2", &build_side_schema).unwrap(), 
+ col("id2", &probe_side_schema).unwrap(), + ), + ]; + let plan = Arc::new( + HashJoinExec::try_new( + build_scan, + Arc::clone(&probe_scan), + on, + None, + &JoinType::Inner, + None, + PartitionMode::CollectLeft, + datafusion_common::NullEquality::NullEqualsNothing, + false, + ) + .unwrap(), + ); + + // Apply optimization with forced HashTable strategy + let session_config = SessionConfig::default() + .with_batch_size(10) + .set_usize("datafusion.optimizer.hash_join_inlist_pushdown_max_size", 1) + .set_bool("datafusion.execution.parquet.pushdown_filters", true) + .set_bool("datafusion.optimizer.enable_dynamic_filter_pushdown", true); + let plan = FilterPushdown::new_post_optimization() + .optimize(plan, session_config.options()) + .unwrap(); + let session_ctx = SessionContext::new_with_config(session_config); + session_ctx.register_object_store( + ObjectStoreUrl::parse("test://").unwrap().as_ref(), + Arc::new(InMemory::new()), + ); + let state = session_ctx.state(); + let task_ctx = state.task_ctx(); + let batches = collect(Arc::clone(&plan), Arc::clone(&task_ctx)) + .await + .unwrap(); + + // Verify hash_lookup is used + let plan_str = format_plan_for_test(&plan).to_string(); + assert!( + plan_str.contains("hash_lookup"), + "Expected hash_lookup in plan but got: {plan_str}" + ); + assert!( + !plan_str.contains("IN (SET)"), + "Expected no IN (SET) in plan but got: {plan_str}" + ); + + let result = format!("{}", pretty_format_batches(&batches).unwrap()); + + let probe_scan_metrics = probe_scan.metrics().unwrap(); + // Only 2 rows from probe side match the build side + assert_eq!(probe_scan_metrics.output_rows().unwrap(), 2); + + insta::assert_snapshot!( + result, + @r" + +-----+-----+-------+-----+-----+------+ + | id1 | id2 | value | id1 | id2 | data | + +-----+-----+-------+-----+-----+------+ + | 1 | 10 | 100.0 | 1 | 10 | a | + | 2 | 20 | 200.0 | 2 | 20 | b | + +-----+-----+-------+-----+-----+------+ + ", + ); +} + +#[tokio::test] +async fn 
test_hashjoin_dynamic_filter_pushdown_is_used() { + use datafusion_common::JoinType; + use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; + + // Test both cases: probe side with and without filter pushdown support + for (probe_supports_pushdown, expected_is_used) in [(false, false), (true, true)] { + let build_side_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, false), + ])); + let build_scan = TestScanBuilder::new(Arc::clone(&build_side_schema)) + .with_support(true) + .with_batches(vec![ + record_batch!(("a", Utf8, ["aa", "ab"]), ("b", Utf8, ["ba", "bb"])) + .unwrap(), + ]) + .build(); + + let probe_side_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, false), + ])); + let probe_scan = TestScanBuilder::new(Arc::clone(&probe_side_schema)) + .with_support(probe_supports_pushdown) + .with_batches(vec![ + record_batch!( + ("a", Utf8, ["aa", "ab", "ac", "ad"]), + ("b", Utf8, ["ba", "bb", "bc", "bd"]) + ) + .unwrap(), + ]) + .build(); + + let on = vec![ + ( + col("a", &build_side_schema).unwrap(), + col("a", &probe_side_schema).unwrap(), + ), + ( + col("b", &build_side_schema).unwrap(), + col("b", &probe_side_schema).unwrap(), + ), + ]; + let plan = Arc::new( + HashJoinExec::try_new( + build_scan, + probe_scan, + on, + None, + &JoinType::Inner, + None, + PartitionMode::CollectLeft, + datafusion_common::NullEquality::NullEqualsNothing, + false, + ) + .unwrap(), + ) as Arc; + + // Apply filter pushdown optimization + let mut config = ConfigOptions::default(); + config.execution.parquet.pushdown_filters = true; + config.optimizer.enable_dynamic_filter_pushdown = true; + let plan = FilterPushdown::new_post_optimization() + .optimize(plan, &config) + .unwrap(); + + // Get the HashJoinExec to check the dynamic filter + let hash_join = plan + .as_any() + .downcast_ref::() + .expect("Plan should be HashJoinExec"); + + // Verify that a 
dynamic filter was created + let dynamic_filter = hash_join + .dynamic_filter_for_test() + .expect("Dynamic filter should be created"); + + // Verify that is_used() returns the expected value based on probe side support. + // When probe_supports_pushdown=false: no consumer holds a reference (is_used=false) + // When probe_supports_pushdown=true: probe side holds a reference (is_used=true) + assert_eq!( + dynamic_filter.is_used(), + expected_is_used, + "is_used() should return {expected_is_used} when probe side support is {probe_supports_pushdown}" + ); + } +} + +/// Regression test for https://github.com/apache/datafusion/issues/20109 +#[tokio::test] +async fn test_filter_with_projection_pushdown() { + use arrow::array::{Int64Array, RecordBatch, StringArray}; + use datafusion_physical_plan::collect; + use datafusion_physical_plan::filter::FilterExecBuilder; + + // Create schema: [time, event, size] + let schema = Arc::new(Schema::new(vec![ + Field::new("time", DataType::Int64, false), + Field::new("event", DataType::Utf8, false), + Field::new("size", DataType::Int64, false), + ])); + + // Create sample data + let timestamps = vec![100i64, 200, 300, 400, 500]; + let events = vec!["Ingestion", "Ingestion", "Query", "Ingestion", "Query"]; + let sizes = vec![10i64, 20, 30, 40, 50]; + + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int64Array::from(timestamps)), + Arc::new(StringArray::from(events)), + Arc::new(Int64Array::from(sizes)), + ], + ) + .unwrap(); + + // Create data source + let memory_exec = datafusion_datasource::memory::MemorySourceConfig::try_new_exec( + &[vec![batch]], + schema.clone(), + None, + ) + .unwrap(); + + // First FilterExec: time < 350 with projection=[event@1, size@2] + let time_col = col("time", &memory_exec.schema()).unwrap(); + let time_filter = Arc::new(BinaryExpr::new( + time_col, + Operator::Lt, + Arc::new(Literal::new(ScalarValue::Int64(Some(350)))), + )); + let filter1 = Arc::new( + 
FilterExecBuilder::new(time_filter, memory_exec) + .apply_projection(Some(vec![1, 2])) + .unwrap() + .build() + .unwrap(), + ); + + // Second FilterExec: event = 'Ingestion' with projection=[size@1] + let event_col = col("event", &filter1.schema()).unwrap(); + let event_filter = Arc::new(BinaryExpr::new( + event_col, + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Utf8(Some( + "Ingestion".to_string(), + )))), + )); + let filter2 = Arc::new( + FilterExecBuilder::new(event_filter, filter1) + .apply_projection(Some(vec![1])) + .unwrap() + .build() + .unwrap(), + ); + + // Apply filter pushdown optimization + let config = ConfigOptions::default(); + let optimized_plan = FilterPushdown::new() + .optimize(Arc::clone(&filter2) as Arc, &config) + .unwrap(); + + // Execute the optimized plan - this should not error + let ctx = SessionContext::new(); + let result = collect(optimized_plan, ctx.task_ctx()).await.unwrap(); + + // Verify results: should return rows where time < 350 AND event = 'Ingestion' + // That's rows with time=100,200 (both have event='Ingestion'), so sizes 10,20 + let expected = [ + "+------+", "| size |", "+------+", "| 10 |", "| 20 |", "+------+", + ]; + assert_batches_eq!(expected, &result); +} + +/// Test that ExecutionPlan::apply_expressions() can discover dynamic filters across the plan tree +#[tokio::test] +async fn test_discover_dynamic_filters_via_expressions_api() { + use datafusion_common::JoinType; + use datafusion_common::tree_node::TreeNodeRecursion; + use datafusion_physical_expr::expressions::DynamicFilterPhysicalExpr; + use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; + + fn count_dynamic_filters(plan: &Arc) -> usize { + let mut count = 0; + + // Check expressions from this node using apply_expressions + let _ = plan.apply_expressions(&mut |expr| { + if let Some(_df) = expr.as_any().downcast_ref::() { + count += 1; + } + Ok(TreeNodeRecursion::Continue) + }); + + // Recursively visit children + for child in 
plan.children() { + count += count_dynamic_filters(child); + } + + count + } + + // Create build side (left) + let build_batches = + vec![record_batch!(("a", Utf8, ["foo", "bar"]), ("b", Int32, [1, 2])).unwrap()]; + let build_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Int32, false), + ])); + let build_scan = TestScanBuilder::new(build_schema.clone()) + .with_support(true) + .with_batches(build_batches) + .build(); + + // Create probe side (right) + let probe_batches = vec![ + record_batch!( + ("a", Utf8, ["foo", "bar", "baz", "qux"]), + ("c", Float64, [1.0, 2.0, 3.0, 4.0]) + ) + .unwrap(), + ]; + let probe_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("c", DataType::Float64, false), + ])); + let probe_scan = TestScanBuilder::new(probe_schema.clone()) + .with_support(true) + .with_batches(probe_batches) + .build(); + + // Create HashJoinExec + let plan = Arc::new( + HashJoinExec::try_new( + build_scan, + probe_scan, + vec![( + col("a", &build_schema).unwrap(), + col("a", &probe_schema).unwrap(), + )], + None, + &JoinType::Inner, + None, + PartitionMode::CollectLeft, + datafusion_common::NullEquality::NullEqualsNothing, + false, + ) + .unwrap(), + ) as Arc; + + // Before optimization: no dynamic filters + let count_before = count_dynamic_filters(&plan); + assert_eq!( + count_before, 0, + "Before optimization, should have no dynamic filters" + ); + + // Apply filter pushdown optimization (this creates dynamic filters) + let mut config = ConfigOptions::default(); + config.optimizer.enable_dynamic_filter_pushdown = true; + config.execution.parquet.pushdown_filters = true; + let optimized_plan = FilterPushdown::new_post_optimization() + .optimize(plan, &config) + .unwrap(); + + // After optimization: should discover dynamic filters + // We expect 2 dynamic filters: + // 1. In the HashJoinExec (producer) + // 2. 
In the DataSourceExec (consumer, pushed down to the probe side) + let count_after = count_dynamic_filters(&optimized_plan); + assert_eq!( + count_after, 2, + "After optimization, should discover exactly 2 dynamic filters (1 in HashJoinExec, 1 in DataSourceExec), found {count_after}" + ); +} + +#[tokio::test] +async fn test_hashjoin_dynamic_filter_pushdown_left_join() { + use datafusion_common::JoinType; + use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; + + // Create build side with limited values + let build_batches = vec![ + record_batch!( + ("a", Utf8, ["aa", "ab"]), + ("b", Utf8, ["ba", "bb"]), + ("c", Float64, [1.0, 2.0]) + ) + .unwrap(), + ]; + let build_side_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, false), + Field::new("c", DataType::Float64, false), + ])); + let build_scan = TestScanBuilder::new(Arc::clone(&build_side_schema)) + .with_support(true) + .with_batches(build_batches) + .build(); + + // Create probe side with more values (some won't match) + let probe_batches = vec![ + record_batch!( + ("a", Utf8, ["aa", "ab", "ac", "ad"]), + ("b", Utf8, ["ba", "bb", "bc", "bd"]), + ("e", Float64, [1.0, 2.0, 3.0, 4.0]) + ) + .unwrap(), + ]; + let probe_side_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, false), + Field::new("e", DataType::Float64, false), + ])); + let probe_scan = TestScanBuilder::new(Arc::clone(&probe_side_schema)) + .with_support(true) + .with_batches(probe_batches) + .build(); + + // Create HashJoinExec with Left join and CollectLeft mode + let on = vec![ + ( + col("a", &build_side_schema).unwrap(), + col("a", &probe_side_schema).unwrap(), + ), + ( + col("b", &build_side_schema).unwrap(), + col("b", &probe_side_schema).unwrap(), + ), + ]; + let plan = Arc::new( + HashJoinExec::try_new( + build_scan, + Arc::clone(&probe_scan), + on, + None, + &JoinType::Left, + None, + 
PartitionMode::CollectLeft, + datafusion_common::NullEquality::NullEqualsNothing, + false, + ) + .unwrap(), + ) as Arc; + + // Expect the dynamic filter predicate to be pushed down into the probe side DataSource + insta::assert_snapshot!( + OptimizationTest::new(Arc::clone(&plan), FilterPushdown::new_post_optimization(), true), + @r" + OptimizationTest: + input: + - HashJoinExec: mode=CollectLeft, join_type=Left, on=[(a@0, a@0), (b@1, b@1)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true + output: + Ok: + - HashJoinExec: mode=CollectLeft, join_type=Left, on=[(a@0, a@0), (b@1, b@1)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ] + ", + ); + + // Actually apply the optimization and execute the plan + let mut config = ConfigOptions::default(); + config.execution.parquet.pushdown_filters = true; + config.optimizer.enable_dynamic_filter_pushdown = true; + let plan = FilterPushdown::new_post_optimization() + .optimize(plan, &config) + .unwrap(); + + // Test that dynamic filter linking survives with_new_children + let children = plan.children().into_iter().map(Arc::clone).collect(); + let plan = plan.with_new_children(children).unwrap(); + + let config = SessionConfig::new().with_batch_size(10); + let session_ctx = SessionContext::new_with_config(config); + session_ctx.register_object_store( + ObjectStoreUrl::parse("test://").unwrap().as_ref(), + Arc::new(InMemory::new()), + ); + let state = session_ctx.state(); + let task_ctx = state.task_ctx(); + let batches = collect(Arc::clone(&plan), Arc::clone(&task_ctx)) + .await + .unwrap(); + + // After 
execution, verify the dynamic filter was populated with bounds and IN-list + insta::assert_snapshot!( + format!("{}", format_plan_for_test(&plan)), + @r" + - HashJoinExec: mode=CollectLeft, join_type=Left, on=[(a@0, a@0), (b@1, b@1)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 >= aa AND a@0 <= ab AND b@1 >= ba AND b@1 <= bb AND struct(a@0, b@1) IN (SET) ([{c0:aa,c1:ba}, {c0:ab,c1:bb}]) ] + " + ); + + // Verify result correctness: left join preserves all build (left) rows. + // All build rows match probe rows here, so we get 2 matched rows. + // The dynamic filter pruned unmatched probe rows (ac, ad) at scan time, + // which is safe because those probe rows would never match any build row. + let result = format!("{}", pretty_format_batches(&batches).unwrap()); + insta::assert_snapshot!( + result, + @r" + +----+----+-----+----+----+-----+ + | a | b | c | a | b | e | + +----+----+-----+----+----+-----+ + | aa | ba | 1.0 | aa | ba | 1.0 | + | ab | bb | 2.0 | ab | bb | 2.0 | + +----+----+-----+----+----+-----+ + " + ); +} + +#[tokio::test] +async fn test_hashjoin_dynamic_filter_pushdown_left_semi_join() { + use datafusion_common::JoinType; + use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; + + // Create build side with limited values + let build_batches = vec![ + record_batch!( + ("a", Utf8, ["aa", "ab"]), + ("b", Utf8, ["ba", "bb"]), + ("c", Float64, [1.0, 2.0]) + ) + .unwrap(), + ]; + let build_side_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, false), + Field::new("c", DataType::Float64, false), + ])); + let build_scan = TestScanBuilder::new(Arc::clone(&build_side_schema)) + .with_support(true) + .with_batches(build_batches) + .build(); + + // Create 
probe side with more values (some won't match) + let probe_batches = vec![ + record_batch!( + ("a", Utf8, ["aa", "ab", "ac", "ad"]), + ("b", Utf8, ["ba", "bb", "bc", "bd"]), + ("e", Float64, [1.0, 2.0, 3.0, 4.0]) + ) + .unwrap(), + ]; + let probe_side_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, false), + Field::new("e", DataType::Float64, false), + ])); + let probe_scan = TestScanBuilder::new(Arc::clone(&probe_side_schema)) + .with_support(true) + .with_batches(probe_batches) + .build(); + + // Create HashJoinExec with LeftSemi join and CollectLeft mode + let on = vec![ + ( + col("a", &build_side_schema).unwrap(), + col("a", &probe_side_schema).unwrap(), + ), + ( + col("b", &build_side_schema).unwrap(), + col("b", &probe_side_schema).unwrap(), + ), + ]; + let plan = Arc::new( + HashJoinExec::try_new( + build_scan, + Arc::clone(&probe_scan), + on, + None, + &JoinType::LeftSemi, + None, + PartitionMode::CollectLeft, + datafusion_common::NullEquality::NullEqualsNothing, + false, + ) + .unwrap(), + ) as Arc; + + // Expect the dynamic filter predicate to be pushed down into the probe side DataSource + insta::assert_snapshot!( + OptimizationTest::new(Arc::clone(&plan), FilterPushdown::new_post_optimization(), true), + @r" + OptimizationTest: + input: + - HashJoinExec: mode=CollectLeft, join_type=LeftSemi, on=[(a@0, a@0), (b@1, b@1)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true + output: + Ok: + - HashJoinExec: mode=CollectLeft, join_type=LeftSemi, on=[(a@0, a@0), (b@1, b@1)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, 
pushdown_supported=true, predicate=DynamicFilter [ empty ] + ", + ); + + // Actually apply the optimization and execute the plan + let mut config = ConfigOptions::default(); + config.execution.parquet.pushdown_filters = true; + config.optimizer.enable_dynamic_filter_pushdown = true; + let plan = FilterPushdown::new_post_optimization() + .optimize(plan, &config) + .unwrap(); + + // Test that dynamic filter linking survives with_new_children + let children = plan.children().into_iter().map(Arc::clone).collect(); + let plan = plan.with_new_children(children).unwrap(); + + let config = SessionConfig::new().with_batch_size(10); + let session_ctx = SessionContext::new_with_config(config); + session_ctx.register_object_store( + ObjectStoreUrl::parse("test://").unwrap().as_ref(), + Arc::new(InMemory::new()), + ); + let state = session_ctx.state(); + let task_ctx = state.task_ctx(); + let batches = collect(Arc::clone(&plan), Arc::clone(&task_ctx)) + .await + .unwrap(); + + // After execution, verify the dynamic filter was populated with bounds and IN-list + insta::assert_snapshot!( + format!("{}", format_plan_for_test(&plan)), + @r" + - HashJoinExec: mode=CollectLeft, join_type=LeftSemi, on=[(a@0, a@0), (b@1, b@1)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 >= aa AND a@0 <= ab AND b@1 >= ba AND b@1 <= bb AND struct(a@0, b@1) IN (SET) ([{c0:aa,c1:ba}, {c0:ab,c1:bb}]) ] + " + ); + + // Verify result correctness: left semi join returns only build (left) rows + // that have at least one matching probe row. Output schema is build-side columns only. 
+ let result = format!("{}", pretty_format_batches(&batches).unwrap()); + insta::assert_snapshot!( + result, + @r" + +----+----+-----+ + | a | b | c | + +----+----+-----+ + | aa | ba | 1.0 | + | ab | bb | 2.0 | + +----+----+-----+ + " + ); +} diff --git a/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs b/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs deleted file mode 100644 index a28933d97bcd1..0000000000000 --- a/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs +++ /dev/null @@ -1,378 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -use std::sync::{Arc, LazyLock}; - -use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; -use datafusion::{ - logical_expr::Operator, - physical_plan::{ - expressions::{BinaryExpr, Column, Literal}, - PhysicalExpr, - }, - scalar::ScalarValue, -}; -use datafusion_common::config::ConfigOptions; -use datafusion_functions_aggregate::count::count_udaf; -use datafusion_physical_expr::expressions::col; -use datafusion_physical_expr::{aggregate::AggregateExprBuilder, Partitioning}; -use datafusion_physical_optimizer::filter_pushdown::FilterPushdown; -use datafusion_physical_plan::{ - aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}, - coalesce_batches::CoalesceBatchesExec, - filter::FilterExec, - repartition::RepartitionExec, -}; - -use util::{OptimizationTest, TestNode, TestScanBuilder}; - -mod util; - -#[test] -fn test_pushdown_into_scan() { - let scan = TestScanBuilder::new(schema()).with_support(true).build(); - let predicate = col_lit_predicate("a", "foo", &schema()); - let plan = Arc::new(FilterExec::try_new(predicate, scan).unwrap()); - - // expect the predicate to be pushed down into the DataSource - insta::assert_snapshot!( - OptimizationTest::new(plan, FilterPushdown{}, true), - @r" - OptimizationTest: - input: - - FilterExec: a@0 = foo - - DataSourceExec: file_groups={1 group: [[test.paqruet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - output: - Ok: - - DataSourceExec: file_groups={1 group: [[test.paqruet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo - " - ); -} - -/// Show that we can use config options to determine how to do pushdown. 
-#[test] -fn test_pushdown_into_scan_with_config_options() { - let scan = TestScanBuilder::new(schema()).with_support(true).build(); - let predicate = col_lit_predicate("a", "foo", &schema()); - let plan = Arc::new(FilterExec::try_new(predicate, scan).unwrap()) as _; - - let mut cfg = ConfigOptions::default(); - insta::assert_snapshot!( - OptimizationTest::new( - Arc::clone(&plan), - FilterPushdown {}, - false - ), - @r" - OptimizationTest: - input: - - FilterExec: a@0 = foo - - DataSourceExec: file_groups={1 group: [[test.paqruet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - output: - Ok: - - FilterExec: a@0 = foo - - DataSourceExec: file_groups={1 group: [[test.paqruet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - " - ); - - cfg.execution.parquet.pushdown_filters = true; - insta::assert_snapshot!( - OptimizationTest::new( - plan, - FilterPushdown {}, - true - ), - @r" - OptimizationTest: - input: - - FilterExec: a@0 = foo - - DataSourceExec: file_groups={1 group: [[test.paqruet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - output: - Ok: - - DataSourceExec: file_groups={1 group: [[test.paqruet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo - " - ); -} - -#[test] -fn test_filter_collapse() { - // filter should be pushed down into the parquet scan with two filters - let scan = TestScanBuilder::new(schema()).with_support(true).build(); - let predicate1 = col_lit_predicate("a", "foo", &schema()); - let filter1 = Arc::new(FilterExec::try_new(predicate1, scan).unwrap()); - let predicate2 = col_lit_predicate("b", "bar", &schema()); - let plan = Arc::new(FilterExec::try_new(predicate2, filter1).unwrap()); - - insta::assert_snapshot!( - OptimizationTest::new(plan, FilterPushdown{}, true), - @r" - OptimizationTest: - input: - - FilterExec: b@1 = bar - - FilterExec: a@0 = foo - - DataSourceExec: file_groups={1 group: [[test.paqruet]]}, projection=[a, b, c], 
file_type=test, pushdown_supported=true - output: - Ok: - - DataSourceExec: file_groups={1 group: [[test.paqruet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo AND b@1 = bar - " - ); -} - -#[test] -fn test_filter_with_projection() { - let scan = TestScanBuilder::new(schema()).with_support(true).build(); - let projection = vec![1, 0]; - let predicate = col_lit_predicate("a", "foo", &schema()); - let plan = Arc::new( - FilterExec::try_new(predicate, Arc::clone(&scan)) - .unwrap() - .with_projection(Some(projection)) - .unwrap(), - ); - - // expect the predicate to be pushed down into the DataSource but the FilterExec to be converted to ProjectionExec - insta::assert_snapshot!( - OptimizationTest::new(plan, FilterPushdown{}, true), - @r" - OptimizationTest: - input: - - FilterExec: a@0 = foo, projection=[b@1, a@0] - - DataSourceExec: file_groups={1 group: [[test.paqruet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - output: - Ok: - - ProjectionExec: expr=[b@1 as b, a@0 as a] - - DataSourceExec: file_groups={1 group: [[test.paqruet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo - ", - ); - - // add a test where the filter is on a column that isn't included in the output - let projection = vec![1]; - let predicate = col_lit_predicate("a", "foo", &schema()); - let plan = Arc::new( - FilterExec::try_new(predicate, scan) - .unwrap() - .with_projection(Some(projection)) - .unwrap(), - ); - insta::assert_snapshot!( - OptimizationTest::new(plan, FilterPushdown{},true), - @r" - OptimizationTest: - input: - - FilterExec: a@0 = foo, projection=[b@1] - - DataSourceExec: file_groups={1 group: [[test.paqruet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - output: - Ok: - - ProjectionExec: expr=[b@1 as b] - - DataSourceExec: file_groups={1 group: [[test.paqruet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo - " - ); -} - 
-#[test] -fn test_push_down_through_transparent_nodes() { - // expect the predicate to be pushed down into the DataSource - let scan = TestScanBuilder::new(schema()).with_support(true).build(); - let coalesce = Arc::new(CoalesceBatchesExec::new(scan, 1)); - let predicate = col_lit_predicate("a", "foo", &schema()); - let filter = Arc::new(FilterExec::try_new(predicate, coalesce).unwrap()); - let repartition = Arc::new( - RepartitionExec::try_new(filter, Partitioning::RoundRobinBatch(1)).unwrap(), - ); - let predicate = col_lit_predicate("b", "bar", &schema()); - let plan = Arc::new(FilterExec::try_new(predicate, repartition).unwrap()); - - // expect the predicate to be pushed down into the DataSource - insta::assert_snapshot!( - OptimizationTest::new(plan, FilterPushdown{},true), - @r" - OptimizationTest: - input: - - FilterExec: b@1 = bar - - RepartitionExec: partitioning=RoundRobinBatch(1), input_partitions=1 - - FilterExec: a@0 = foo - - CoalesceBatchesExec: target_batch_size=1 - - DataSourceExec: file_groups={1 group: [[test.paqruet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - output: - Ok: - - RepartitionExec: partitioning=RoundRobinBatch(1), input_partitions=1 - - CoalesceBatchesExec: target_batch_size=1 - - DataSourceExec: file_groups={1 group: [[test.paqruet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo AND b@1 = bar - " - ); -} - -#[test] -fn test_no_pushdown_through_aggregates() { - // There are 2 important points here: - // 1. The outer filter **is not** pushed down at all because we haven't implemented pushdown support - // yet for AggregateExec. - // 2. The inner filter **is** pushed down into the DataSource. 
- let scan = TestScanBuilder::new(schema()).with_support(true).build(); - - let coalesce = Arc::new(CoalesceBatchesExec::new(scan, 10)); - - let filter = Arc::new( - FilterExec::try_new(col_lit_predicate("a", "foo", &schema()), coalesce).unwrap(), - ); - - let aggregate_expr = - vec![ - AggregateExprBuilder::new(count_udaf(), vec![col("a", &schema()).unwrap()]) - .schema(schema()) - .alias("cnt") - .build() - .map(Arc::new) - .unwrap(), - ]; - let group_by = PhysicalGroupBy::new_single(vec![ - (col("a", &schema()).unwrap(), "a".to_string()), - (col("b", &schema()).unwrap(), "b".to_string()), - ]); - let aggregate = Arc::new( - AggregateExec::try_new( - AggregateMode::Final, - group_by, - aggregate_expr.clone(), - vec![None], - filter, - schema(), - ) - .unwrap(), - ); - - let coalesce = Arc::new(CoalesceBatchesExec::new(aggregate, 100)); - - let predicate = col_lit_predicate("b", "bar", &schema()); - let plan = Arc::new(FilterExec::try_new(predicate, coalesce).unwrap()); - - // expect the predicate to be pushed down into the DataSource - insta::assert_snapshot!( - OptimizationTest::new(plan, FilterPushdown{}, true), - @r" - OptimizationTest: - input: - - FilterExec: b@1 = bar - - CoalesceBatchesExec: target_batch_size=100 - - AggregateExec: mode=Final, gby=[a@0 as a, b@1 as b], aggr=[cnt], ordering_mode=PartiallySorted([0]) - - FilterExec: a@0 = foo - - CoalesceBatchesExec: target_batch_size=10 - - DataSourceExec: file_groups={1 group: [[test.paqruet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - output: - Ok: - - FilterExec: b@1 = bar - - CoalesceBatchesExec: target_batch_size=100 - - AggregateExec: mode=Final, gby=[a@0 as a, b@1 as b], aggr=[cnt] - - CoalesceBatchesExec: target_batch_size=10 - - DataSourceExec: file_groups={1 group: [[test.paqruet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo - " - ); -} - -/// Test various combinations of handling of child pushdown results -/// in an ExectionPlan in 
combination with support/not support in a DataSource. -#[test] -fn test_node_handles_child_pushdown_result() { - // If we set `with_support(true)` + `inject_filter = true` then the filter is pushed down to the DataSource - // and no FilterExec is created. - let scan = TestScanBuilder::new(schema()).with_support(true).build(); - let predicate = col_lit_predicate("a", "foo", &schema()); - let plan = Arc::new(TestNode::new(true, Arc::clone(&scan), predicate)); - insta::assert_snapshot!( - OptimizationTest::new(plan, FilterPushdown{}, true), - @r" - OptimizationTest: - input: - - TestInsertExec { inject_filter: true } - - DataSourceExec: file_groups={1 group: [[test.paqruet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - output: - Ok: - - TestInsertExec { inject_filter: true } - - DataSourceExec: file_groups={1 group: [[test.paqruet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo - ", - ); - - // If we set `with_support(false)` + `inject_filter = true` then the filter is not pushed down to the DataSource - // and a FilterExec is created. - let scan = TestScanBuilder::new(schema()).with_support(false).build(); - let predicate = col_lit_predicate("a", "foo", &schema()); - let plan = Arc::new(TestNode::new(true, Arc::clone(&scan), predicate)); - insta::assert_snapshot!( - OptimizationTest::new(plan, FilterPushdown{}, true), - @r" - OptimizationTest: - input: - - TestInsertExec { inject_filter: true } - - DataSourceExec: file_groups={1 group: [[test.paqruet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false - output: - Ok: - - TestInsertExec { inject_filter: false } - - FilterExec: a@0 = foo - - DataSourceExec: file_groups={1 group: [[test.paqruet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false - ", - ); - - // If we set `with_support(false)` + `inject_filter = false` then the filter is not pushed down to the DataSource - // and no FilterExec is created. 
- let scan = TestScanBuilder::new(schema()).with_support(false).build(); - let predicate = col_lit_predicate("a", "foo", &schema()); - let plan = Arc::new(TestNode::new(false, Arc::clone(&scan), predicate)); - insta::assert_snapshot!( - OptimizationTest::new(plan, FilterPushdown{}, true), - @r" - OptimizationTest: - input: - - TestInsertExec { inject_filter: false } - - DataSourceExec: file_groups={1 group: [[test.paqruet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false - output: - Ok: - - TestInsertExec { inject_filter: false } - - DataSourceExec: file_groups={1 group: [[test.paqruet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false - ", - ); -} - -/// Schema: -/// a: String -/// b: String -/// c: f64 -static TEST_SCHEMA: LazyLock = LazyLock::new(|| { - let fields = vec![ - Field::new("a", DataType::Utf8, false), - Field::new("b", DataType::Utf8, false), - Field::new("c", DataType::Float64, false), - ]; - Arc::new(Schema::new(fields)) -}); - -fn schema() -> SchemaRef { - Arc::clone(&TEST_SCHEMA) -} - -/// Returns a predicate that is a binary expression col = lit -fn col_lit_predicate( - column_name: &str, - scalar_value: impl Into, - schema: &Schema, -) -> Arc { - let scalar_value = scalar_value.into(); - Arc::new(BinaryExpr::new( - Arc::new(Column::new_with_schema(column_name, schema).unwrap()), - Operator::Eq, - Arc::new(Literal::new(scalar_value)), - )) -} diff --git a/datafusion/core/tests/physical_optimizer/join_selection.rs b/datafusion/core/tests/physical_optimizer/join_selection.rs index d8c0c142f7fb6..1c94a7bd1e91c 100644 --- a/datafusion/core/tests/physical_optimizer/join_selection.rs +++ b/datafusion/core/tests/physical_optimizer/join_selection.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. 
+use insta::assert_snapshot; use std::sync::Arc; use std::{ any::Any, @@ -25,29 +26,28 @@ use std::{ use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_common::config::ConfigOptions; -use datafusion_common::JoinSide; -use datafusion_common::{stats::Precision, ColumnStatistics, JoinType, ScalarValue}; +use datafusion_common::tree_node::TreeNodeRecursion; +use datafusion_common::{ColumnStatistics, JoinType, ScalarValue, stats::Precision}; +use datafusion_common::{JoinSide, NullEquality}; use datafusion_common::{Result, Statistics}; use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext}; use datafusion_expr::Operator; +use datafusion_physical_expr::PhysicalExprRef; use datafusion_physical_expr::expressions::col; use datafusion_physical_expr::expressions::{BinaryExpr, Column, NegativeExpr}; use datafusion_physical_expr::intervals::utils::check_support; -use datafusion_physical_expr::PhysicalExprRef; use datafusion_physical_expr::{EquivalenceProperties, Partitioning, PhysicalExpr}; -use datafusion_physical_optimizer::join_selection::{ - hash_join_swap_subrule, JoinSelection, -}; use datafusion_physical_optimizer::PhysicalOptimizerRule; +use datafusion_physical_optimizer::join_selection::JoinSelection; +use datafusion_physical_plan::ExecutionPlanProperties; use datafusion_physical_plan::displayable; use datafusion_physical_plan::joins::utils::ColumnIndex; use datafusion_physical_plan::joins::utils::JoinFilter; use datafusion_physical_plan::joins::{HashJoinExec, NestedLoopJoinExec, PartitionMode}; use datafusion_physical_plan::projection::ProjectionExec; -use datafusion_physical_plan::ExecutionPlanProperties; use datafusion_physical_plan::{ - execution_plan::{Boundedness, EmissionType}, DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, + execution_plan::{Boundedness, EmissionType}, }; use futures::Stream; @@ -222,6 +222,7 @@ async fn test_join_with_swap() { 
&JoinType::Left, None, PartitionMode::CollectLeft, + NullEquality::NullEqualsNothing, false, ) .unwrap(), @@ -237,12 +238,12 @@ async fn test_join_with_swap() { .expect("A proj is required to swap columns back to their original order"); assert_eq!(swapping_projection.expr().len(), 2); - let (col, name) = &swapping_projection.expr()[0]; - assert_eq!(name, "big_col"); - assert_col_expr(col, "big_col", 1); - let (col, name) = &swapping_projection.expr()[1]; - assert_eq!(name, "small_col"); - assert_col_expr(col, "small_col", 0); + let proj_expr = &swapping_projection.expr()[0]; + assert_eq!(proj_expr.alias, "big_col"); + assert_col_expr(&proj_expr.expr, "big_col", 1); + let proj_expr = &swapping_projection.expr()[1]; + assert_eq!(proj_expr.alias, "small_col"); + assert_col_expr(&proj_expr.expr, "small_col", 0); let swapped_join = swapping_projection .input() @@ -284,6 +285,7 @@ async fn test_left_join_no_swap() { &JoinType::Left, None, PartitionMode::CollectLeft, + NullEquality::NullEqualsNothing, false, ) .unwrap(), @@ -333,6 +335,7 @@ async fn test_join_with_swap_semi() { &join_type, None, PartitionMode::Partitioned, + NullEquality::NullEqualsNothing, false, ) .unwrap(); @@ -371,10 +374,65 @@ async fn test_join_with_swap_semi() { } } +#[tokio::test] +async fn test_join_with_swap_mark() { + let join_types = [JoinType::LeftMark, JoinType::RightMark]; + for join_type in join_types { + let (big, small) = create_big_and_small(); + + let join = HashJoinExec::try_new( + Arc::clone(&big), + Arc::clone(&small), + vec![( + Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()), + Arc::new(Column::new_with_schema("small_col", &small.schema()).unwrap()), + )], + None, + &join_type, + None, + PartitionMode::Partitioned, + NullEquality::NullEqualsNothing, + false, + ) + .unwrap(); + + let original_schema = join.schema(); + + let optimized_join = JoinSelection::new() + .optimize(Arc::new(join), &ConfigOptions::new()) + .unwrap(); + + let swapped_join = optimized_join 
+ .as_any() + .downcast_ref::() + .expect( + "A proj is not required to swap columns back to their original order", + ); + + assert_eq!(swapped_join.schema().fields().len(), 2); + assert_eq!( + swapped_join + .left() + .partition_statistics(None) + .unwrap() + .total_byte_size, + Precision::Inexact(8192) + ); + assert_eq!( + swapped_join + .right() + .partition_statistics(None) + .unwrap() + .total_byte_size, + Precision::Inexact(2097152) + ); + assert_eq!(original_schema, swapped_join.schema()); + } +} + /// Compare the input plan with the plan after running the probe order optimizer. macro_rules! assert_optimized { - ($EXPECTED_LINES: expr, $PLAN: expr) => { - let expected_lines = $EXPECTED_LINES.iter().map(|s| *s).collect::>(); + ($PLAN: expr, @$EXPECTED_LINES: literal $(,)?) => { let plan = Arc::new($PLAN); let optimized = JoinSelection::new() @@ -382,12 +440,11 @@ macro_rules! assert_optimized { .unwrap(); let plan_string = displayable(optimized.as_ref()).indent(true).to_string(); - let actual_lines = plan_string.split("\n").collect::>(); + let actual = plan_string.trim(); - assert_eq!( - &expected_lines, &actual_lines, - "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", - expected_lines, actual_lines + assert_snapshot!( + actual, + @$EXPECTED_LINES ); }; } @@ -408,6 +465,7 @@ async fn test_nested_join_swap() { &JoinType::Inner, None, PartitionMode::CollectLeft, + NullEquality::NullEqualsNothing, false, ) .unwrap(); @@ -425,6 +483,7 @@ async fn test_nested_join_swap() { &JoinType::Left, None, PartitionMode::CollectLeft, + NullEquality::NullEqualsNothing, false, ) .unwrap(); @@ -436,17 +495,18 @@ async fn test_nested_join_swap() { // The first hash join's left is 'small' table (with 1000 rows), and the second hash join's // left is the F(small IJ big) which has an estimated cardinality of 2000 rows (vs medium which // has an exact cardinality of 10_000 rows). 
- let expected = [ - "ProjectionExec: expr=[medium_col@2 as medium_col, big_col@0 as big_col, small_col@1 as small_col]", - " HashJoinExec: mode=CollectLeft, join_type=Right, on=[(small_col@1, medium_col@0)]", - " ProjectionExec: expr=[big_col@1 as big_col, small_col@0 as small_col]", - " HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(small_col@0, big_col@0)]", - " StatisticsExec: col_count=1, row_count=Inexact(1000)", - " StatisticsExec: col_count=1, row_count=Inexact(100000)", - " StatisticsExec: col_count=1, row_count=Inexact(10000)", - "", - ]; - assert_optimized!(expected, join); + assert_optimized!( + join, + @r" + ProjectionExec: expr=[medium_col@2 as medium_col, big_col@0 as big_col, small_col@1 as small_col] + HashJoinExec: mode=CollectLeft, join_type=Right, on=[(small_col@1, medium_col@0)] + ProjectionExec: expr=[big_col@1 as big_col, small_col@0 as small_col] + HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(small_col@0, big_col@0)] + StatisticsExec: col_count=1, row_count=Inexact(1000) + StatisticsExec: col_count=1, row_count=Inexact(100000) + StatisticsExec: col_count=1, row_count=Inexact(10000) + " + ); } #[tokio::test] @@ -464,6 +524,7 @@ async fn test_join_no_swap() { &JoinType::Inner, None, PartitionMode::CollectLeft, + NullEquality::NullEqualsNothing, false, ) .unwrap(), @@ -528,12 +589,12 @@ async fn test_nl_join_with_swap(join_type: JoinType) { .expect("A proj is required to swap columns back to their original order"); assert_eq!(swapping_projection.expr().len(), 2); - let (col, name) = &swapping_projection.expr()[0]; - assert_eq!(name, "big_col"); - assert_col_expr(col, "big_col", 1); - let (col, name) = &swapping_projection.expr()[1]; - assert_eq!(name, "small_col"); - assert_col_expr(col, "small_col", 0); + let proj_expr = &swapping_projection.expr()[0]; + assert_eq!(proj_expr.alias, "big_col"); + assert_col_expr(&proj_expr.expr, "big_col", 1); + let proj_expr = &swapping_projection.expr()[1]; + assert_eq!(proj_expr.alias, 
"small_col"); + assert_col_expr(&proj_expr.expr, "small_col", 0); let swapped_join = swapping_projection .input() @@ -578,7 +639,8 @@ async fn test_nl_join_with_swap(join_type: JoinType) { case::left_semi(JoinType::LeftSemi), case::left_anti(JoinType::LeftAnti), case::right_semi(JoinType::RightSemi), - case::right_anti(JoinType::RightAnti) + case::right_anti(JoinType::RightAnti), + case::right_mark(JoinType::RightMark) )] #[tokio::test] async fn test_nl_join_with_swap_no_proj(join_type: JoinType) { @@ -690,6 +752,7 @@ async fn test_hash_join_swap_on_joins_with_projections( &join_type, Some(projection), PartitionMode::Partitioned, + NullEquality::NullEqualsNothing, false, )?); @@ -700,7 +763,7 @@ async fn test_hash_join_swap_on_joins_with_projections( "ProjectionExec won't be added above if HashJoinExec contains embedded projection", ); - assert_eq!(swapped_join.projection, Some(vec![0_usize])); + assert_eq!(swapped_join.projection.as_deref().unwrap(), &[0_usize]); assert_eq!(swapped.schema().fields.len(), 1); assert_eq!(swapped.schema().fields[0].name(), "small_col"); Ok(()) @@ -851,6 +914,7 @@ fn check_join_partition_mode( &JoinType::Inner, None, PartitionMode::Auto, + NullEquality::NullEqualsNothing, false, ) .unwrap(), @@ -895,10 +959,10 @@ impl Stream for UnboundedStream { mut self: Pin<&mut Self>, _cx: &mut Context<'_>, ) -> Poll> { - if let Some(val) = self.batch_produce { - if val <= self.count { - return Poll::Ready(None); - } + if let Some(val) = self.batch_produce + && val <= self.count + { + return Poll::Ready(None); } self.count += 1; Poll::Ready(Some(Ok(self.batch.clone()))) @@ -916,7 +980,7 @@ impl RecordBatchStream for UnboundedStream { pub struct UnboundedExec { batch_produce: Option, batch: RecordBatch, - cache: PlanProperties, + cache: Arc, } impl UnboundedExec { @@ -932,7 +996,7 @@ impl UnboundedExec { Self { batch_produce, batch, - cache, + cache: Arc::new(cache), } } @@ -989,7 +1053,7 @@ impl ExecutionPlan for UnboundedExec { self } - fn 
properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -1015,6 +1079,20 @@ impl ExecutionPlan for UnboundedExec { batch: self.batch.clone(), })) } + + fn apply_expressions( + &self, + f: &mut dyn FnMut(&dyn PhysicalExpr) -> Result, + ) -> Result { + // Visit expressions in the output ordering from equivalence properties + let mut tnr = TreeNodeRecursion::Continue; + if let Some(ordering) = self.cache.output_ordering() { + for sort_expr in ordering { + tnr = tnr.visit_sibling(|| f(sort_expr.expr.as_ref()))?; + } + } + Ok(tnr) + } } #[derive(Eq, PartialEq, Debug)] @@ -1028,20 +1106,21 @@ pub enum SourceType { pub struct StatisticsExec { stats: Statistics, schema: Arc, - cache: PlanProperties, + cache: Arc, } impl StatisticsExec { pub fn new(stats: Statistics, schema: Schema) -> Self { assert_eq!( - stats.column_statistics.len(), schema.fields().len(), - "if defined, the column statistics vector length should be the number of fields" - ); + stats.column_statistics.len(), + schema.fields().len(), + "if defined, the column statistics vector length should be the number of fields" + ); let cache = Self::compute_properties(Arc::new(schema.clone())); Self { stats, schema: Arc::new(schema), - cache, + cache: Arc::new(cache), } } @@ -1089,7 +1168,7 @@ impl ExecutionPlan for StatisticsExec { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -1112,16 +1191,26 @@ impl ExecutionPlan for StatisticsExec { unimplemented!("This plan only serves for testing statistics") } - fn statistics(&self) -> Result { - Ok(self.stats.clone()) - } - - fn partition_statistics(&self, partition: Option) -> Result { - Ok(if partition.is_some() { + fn partition_statistics(&self, partition: Option) -> Result> { + Ok(Arc::new(if partition.is_some() { Statistics::new_unknown(&self.schema) } else { self.stats.clone() - }) + })) + } + + fn apply_expressions( + &self, + f: &mut dyn FnMut(&dyn PhysicalExpr) -> Result, + ) -> 
Result { + // Visit expressions in the output ordering from equivalence properties + let mut tnr = TreeNodeRecursion::Continue; + if let Some(ordering) = self.cache.output_ordering() { + for sort_expr in ordering { + tnr = tnr.visit_sibling(|| f(sort_expr.expr.as_ref()))?; + } + } + Ok(tnr) } } @@ -1498,10 +1587,12 @@ async fn test_join_with_maybe_swap_unbounded_case(t: TestCase) -> Result<()> { &t.initial_join_type, None, t.initial_mode, + NullEquality::NullEqualsNothing, false, )?) as _; - let optimized_join_plan = hash_join_swap_subrule(join, &ConfigOptions::new())?; + let optimized_join_plan = + JoinSelection::new().optimize(Arc::clone(&join), &ConfigOptions::new())?; // If swap did happen let projection_added = optimized_join_plan.as_any().is::(); diff --git a/datafusion/core/tests/physical_optimizer/limit_pushdown.rs b/datafusion/core/tests/physical_optimizer/limit_pushdown.rs index dd2c1960a6580..b8c4d6d6f0d7a 100644 --- a/datafusion/core/tests/physical_optimizer/limit_pushdown.rs +++ b/datafusion/core/tests/physical_optimizer/limit_pushdown.rs @@ -17,28 +17,28 @@ use std::sync::Arc; +use crate::physical_optimizer::test_utils::{ + coalesce_partitions_exec, global_limit_exec, hash_join_exec, local_limit_exec, + sort_exec, sort_preserving_merge_exec, stream_exec, +}; + use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::config::ConfigOptions; use datafusion_common::error::Result; -use datafusion_execution::{SendableRecordBatchStream, TaskContext}; -use datafusion_expr::Operator; -use datafusion_physical_expr::expressions::BinaryExpr; -use datafusion_physical_expr::expressions::{col, lit}; -use datafusion_physical_expr::{Partitioning, PhysicalSortExpr}; -use datafusion_physical_optimizer::limit_pushdown::LimitPushdown; +use datafusion_expr::{JoinType, Operator}; +use datafusion_physical_expr::Partitioning; +use datafusion_physical_expr::expressions::{BinaryExpr, col, lit}; +use 
datafusion_physical_expr_common::physical_expr::PhysicalExprRef; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; use datafusion_physical_optimizer::PhysicalOptimizerRule; -use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec; -use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; +use datafusion_physical_optimizer::limit_pushdown::LimitPushdown; use datafusion_physical_plan::empty::EmptyExec; use datafusion_physical_plan::filter::FilterExec; -use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; +use datafusion_physical_plan::joins::NestedLoopJoinExec; use datafusion_physical_plan::projection::ProjectionExec; use datafusion_physical_plan::repartition::RepartitionExec; -use datafusion_physical_plan::sorts::sort::SortExec; -use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; -use datafusion_physical_plan::streaming::{PartitionStream, StreamingTableExec}; -use datafusion_physical_plan::{get_plan_string, ExecutionPlan, ExecutionPlanProperties}; +use datafusion_physical_plan::{ExecutionPlan, get_plan_string}; fn create_schema() -> SchemaRef { Arc::new(Schema::new(vec![ @@ -48,48 +48,6 @@ fn create_schema() -> SchemaRef { ])) } -fn streaming_table_exec(schema: SchemaRef) -> Result> { - Ok(Arc::new(StreamingTableExec::try_new( - Arc::clone(&schema), - vec![Arc::new(DummyStreamPartition { schema }) as _], - None, - None, - true, - None, - )?)) -} - -fn global_limit_exec( - input: Arc, - skip: usize, - fetch: Option, -) -> Arc { - Arc::new(GlobalLimitExec::new(input, skip, fetch)) -} - -fn local_limit_exec( - input: Arc, - fetch: usize, -) -> Arc { - Arc::new(LocalLimitExec::new(input, fetch)) -} - -fn sort_exec( - sort_exprs: impl IntoIterator, - input: Arc, -) -> Arc { - let sort_exprs = sort_exprs.into_iter().collect(); - Arc::new(SortExec::new(sort_exprs, input)) -} - -fn sort_preserving_merge_exec( - sort_exprs: impl IntoIterator, - input: Arc, -) -> 
Arc { - let sort_exprs = sort_exprs.into_iter().collect(); - Arc::new(SortPreservingMergeExec::new(sort_exprs, input)) -} - fn projection_exec( schema: SchemaRef, input: Arc, @@ -118,16 +76,6 @@ fn filter_exec( )?)) } -fn coalesce_batches_exec(input: Arc) -> Arc { - Arc::new(CoalesceBatchesExec::new(input, 8192)) -} - -fn coalesce_partitions_exec( - local_limit: Arc, -) -> Arc { - Arc::new(CoalescePartitionsExec::new(local_limit)) -} - fn repartition_exec( streaming_table: Arc, ) -> Result> { @@ -141,168 +89,272 @@ fn empty_exec(schema: SchemaRef) -> Arc { Arc::new(EmptyExec::new(schema)) } -#[derive(Debug)] -struct DummyStreamPartition { - schema: SchemaRef, +fn nested_loop_join_exec( + left: Arc, + right: Arc, + join_type: JoinType, +) -> Result> { + Ok(Arc::new(NestedLoopJoinExec::try_new( + left, right, None, &join_type, None, + )?)) } -impl PartitionStream for DummyStreamPartition { - fn schema(&self) -> &SchemaRef { - &self.schema - } - fn execute(&self, _ctx: Arc) -> SendableRecordBatchStream { - unreachable!() - } + +fn format_plan(plan: &Arc) -> String { + get_plan_string(plan).join("\n") } #[test] fn transforms_streaming_table_exec_into_fetching_version_when_skip_is_zero() -> Result<()> { let schema = create_schema(); - let streaming_table = streaming_table_exec(schema)?; + let streaming_table = stream_exec(&schema); let global_limit = global_limit_exec(streaming_table, 0, Some(5)); - let initial = get_plan_string(&global_limit); - let expected_initial = [ - "GlobalLimitExec: skip=0, fetch=5", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" - ]; - assert_eq!(initial, expected_initial); + let initial = format_plan(&global_limit); + insta::assert_snapshot!( + initial, + @r" + GlobalLimitExec: skip=0, fetch=5 + StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true + " + ); let after_optimize = LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; - let expected = [ - 
"StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true, fetch=5" - ]; - assert_eq!(get_plan_string(&after_optimize), expected); + let optimized = format_plan(&after_optimize); + insta::assert_snapshot!( + optimized, + @"StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true, fetch=5" + ); Ok(()) } #[test] -fn transforms_streaming_table_exec_into_fetching_version_and_keeps_the_global_limit_when_skip_is_nonzero( -) -> Result<()> { +fn transforms_streaming_table_exec_into_fetching_version_and_keeps_the_global_limit_when_skip_is_nonzero() +-> Result<()> { let schema = create_schema(); - let streaming_table = streaming_table_exec(schema)?; + let streaming_table = stream_exec(&schema); let global_limit = global_limit_exec(streaming_table, 2, Some(5)); - let initial = get_plan_string(&global_limit); - let expected_initial = [ - "GlobalLimitExec: skip=2, fetch=5", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" - ]; - assert_eq!(initial, expected_initial); + let initial = format_plan(&global_limit); + insta::assert_snapshot!( + initial, + @r" + GlobalLimitExec: skip=2, fetch=5 + StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true + " + ); let after_optimize = LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; - let expected = [ - "GlobalLimitExec: skip=2, fetch=5", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true, fetch=7" - ]; - assert_eq!(get_plan_string(&after_optimize), expected); + let optimized = format_plan(&after_optimize); + insta::assert_snapshot!( + optimized, + @r" + GlobalLimitExec: skip=2, fetch=5 + StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true, fetch=7 + " + ); Ok(()) } +fn join_on_columns( + left_col: &str, + right_col: &str, +) -> Vec<(PhysicalExprRef, PhysicalExprRef)> { + vec![( + 
Arc::new(datafusion_physical_expr::expressions::Column::new( + left_col, 0, + )) as _, + Arc::new(datafusion_physical_expr::expressions::Column::new( + right_col, 0, + )) as _, + )] +} + #[test] -fn transforms_coalesce_batches_exec_into_fetching_version_and_removes_local_limit( -) -> Result<()> { +fn absorbs_limit_into_hash_join_inner() -> Result<()> { + // HashJoinExec with Inner join should absorb limit via with_fetch let schema = create_schema(); - let streaming_table = streaming_table_exec(Arc::clone(&schema))?; - let repartition = repartition_exec(streaming_table)?; - let filter = filter_exec(schema, repartition)?; - let coalesce_batches = coalesce_batches_exec(filter); - let local_limit = local_limit_exec(coalesce_batches, 5); - let coalesce_partitions = coalesce_partitions_exec(local_limit); - let global_limit = global_limit_exec(coalesce_partitions, 0, Some(5)); - - let initial = get_plan_string(&global_limit); - let expected_initial = [ - "GlobalLimitExec: skip=0, fetch=5", - " CoalescePartitionsExec", - " LocalLimitExec: fetch=5", - " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c3@2 > 0", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" - ]; - assert_eq!(initial, expected_initial); + let left = empty_exec(Arc::clone(&schema)); + let right = empty_exec(Arc::clone(&schema)); + let on = join_on_columns("c1", "c1"); + let hash_join = hash_join_exec(left, right, on, None, &JoinType::Inner)?; + let global_limit = global_limit_exec(hash_join, 0, Some(5)); + + let initial = format_plan(&global_limit); + insta::assert_snapshot!( + initial, + @r" + GlobalLimitExec: skip=0, fetch=5 + HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c1@0, c1@0)] + EmptyExec + EmptyExec + " + ); let after_optimize = LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; - - let expected = [ - "CoalescePartitionsExec: fetch=5", - " 
CoalesceBatchesExec: target_batch_size=8192, fetch=5", - " FilterExec: c3@2 > 0", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" - ]; - assert_eq!(get_plan_string(&after_optimize), expected); + let optimized = format_plan(&after_optimize); + // The limit should be absorbed by the hash join (not pushed to children) + insta::assert_snapshot!( + optimized, + @r" + HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c1@0, c1@0)], fetch=5 + EmptyExec + EmptyExec + " + ); Ok(()) } #[test] -fn pushes_global_limit_exec_through_projection_exec() -> Result<()> { +fn absorbs_limit_into_hash_join_right() -> Result<()> { + // HashJoinExec with Right join should absorb limit via with_fetch let schema = create_schema(); - let streaming_table = streaming_table_exec(Arc::clone(&schema))?; - let filter = filter_exec(Arc::clone(&schema), streaming_table)?; - let projection = projection_exec(schema, filter)?; - let global_limit = global_limit_exec(projection, 0, Some(5)); + let left = empty_exec(Arc::clone(&schema)); + let right = empty_exec(Arc::clone(&schema)); + let on = join_on_columns("c1", "c1"); + let hash_join = hash_join_exec(left, right, on, None, &JoinType::Right)?; + let global_limit = global_limit_exec(hash_join, 0, Some(10)); + + let initial = format_plan(&global_limit); + insta::assert_snapshot!( + initial, + @r" + GlobalLimitExec: skip=0, fetch=10 + HashJoinExec: mode=Partitioned, join_type=Right, on=[(c1@0, c1@0)] + EmptyExec + EmptyExec + " + ); + + let after_optimize = + LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; + let optimized = format_plan(&after_optimize); + // The limit should be absorbed by the hash join + insta::assert_snapshot!( + optimized, + @r" + HashJoinExec: mode=Partitioned, join_type=Right, on=[(c1@0, c1@0)], fetch=10 + EmptyExec + EmptyExec + " + ); + + Ok(()) +} - let initial = 
get_plan_string(&global_limit); - let expected_initial = [ - "GlobalLimitExec: skip=0, fetch=5", - " ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3]", - " FilterExec: c3@2 > 0", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" - ]; - assert_eq!(initial, expected_initial); +#[test] +fn absorbs_limit_into_hash_join_left() -> Result<()> { + // during probing, then unmatched rows at the end, stopping when limit is reached + let schema = create_schema(); + let left = empty_exec(Arc::clone(&schema)); + let right = empty_exec(Arc::clone(&schema)); + let on = join_on_columns("c1", "c1"); + let hash_join = hash_join_exec(left, right, on, None, &JoinType::Left)?; + let global_limit = global_limit_exec(hash_join, 0, Some(5)); + + let initial = format_plan(&global_limit); + insta::assert_snapshot!( + initial, + @r" + GlobalLimitExec: skip=0, fetch=5 + HashJoinExec: mode=Partitioned, join_type=Left, on=[(c1@0, c1@0)] + EmptyExec + EmptyExec + " + ); let after_optimize = LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; + let optimized = format_plan(&after_optimize); + // Left join now absorbs the limit + insta::assert_snapshot!( + optimized, + @r" + HashJoinExec: mode=Partitioned, join_type=Left, on=[(c1@0, c1@0)], fetch=5 + EmptyExec + EmptyExec + " + ); - let expected = [ - "ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3]", - " GlobalLimitExec: skip=0, fetch=5", - " FilterExec: c3@2 > 0", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" - ]; - assert_eq!(get_plan_string(&after_optimize), expected); + Ok(()) +} + +#[test] +fn absorbs_limit_with_skip_into_hash_join() -> Result<()> { + let schema = create_schema(); + let left = empty_exec(Arc::clone(&schema)); + let right = empty_exec(Arc::clone(&schema)); + let on = join_on_columns("c1", "c1"); + let hash_join = hash_join_exec(left, right, on, None, &JoinType::Inner)?; + let global_limit = 
global_limit_exec(hash_join, 3, Some(5)); + + let initial = format_plan(&global_limit); + insta::assert_snapshot!( + initial, + @r" + GlobalLimitExec: skip=3, fetch=5 + HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c1@0, c1@0)] + EmptyExec + EmptyExec + " + ); + + let after_optimize = + LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; + let optimized = format_plan(&after_optimize); + // With skip, GlobalLimit is kept but fetch (skip + limit = 8) is absorbed by the join + insta::assert_snapshot!( + optimized, + @r" + GlobalLimitExec: skip=3, fetch=5 + HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c1@0, c1@0)], fetch=8 + EmptyExec + EmptyExec + " + ); Ok(()) } #[test] -fn pushes_global_limit_exec_through_projection_exec_and_transforms_coalesce_batches_exec_into_fetching_version( -) -> Result<()> { +fn pushes_global_limit_exec_through_projection_exec() -> Result<()> { let schema = create_schema(); - let streaming_table = streaming_table_exec(Arc::clone(&schema)).unwrap(); - let coalesce_batches = coalesce_batches_exec(streaming_table); - let projection = projection_exec(schema, coalesce_batches)?; + let streaming_table = stream_exec(&schema); + let filter = filter_exec(Arc::clone(&schema), streaming_table)?; + let projection = projection_exec(schema, filter)?; let global_limit = global_limit_exec(projection, 0, Some(5)); - let initial = get_plan_string(&global_limit); - let expected_initial = [ - "GlobalLimitExec: skip=0, fetch=5", - " ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3]", - " CoalesceBatchesExec: target_batch_size=8192", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" - ]; - - assert_eq!(initial, expected_initial); + let initial = format_plan(&global_limit); + insta::assert_snapshot!( + initial, + @r" + GlobalLimitExec: skip=0, fetch=5 + ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3] + FilterExec: c3@2 > 0 + StreamingTableExec: partition_sizes=1, 
projection=[c1, c2, c3], infinite_source=true + " + ); let after_optimize = LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; - let expected = [ - "ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3]", - " CoalesceBatchesExec: target_batch_size=8192, fetch=5", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" - ]; - assert_eq!(get_plan_string(&after_optimize), expected); + let optimized = format_plan(&after_optimize); + insta::assert_snapshot!( + optimized, + @r" + ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3] + FilterExec: c3@2 > 0, fetch=5 + StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true + " + ); Ok(()) } @@ -310,45 +362,45 @@ fn pushes_global_limit_exec_through_projection_exec_and_transforms_coalesce_batc #[test] fn pushes_global_limit_into_multiple_fetch_plans() -> Result<()> { let schema = create_schema(); - let streaming_table = streaming_table_exec(Arc::clone(&schema)).unwrap(); - let coalesce_batches = coalesce_batches_exec(streaming_table); - let projection = projection_exec(Arc::clone(&schema), coalesce_batches)?; + let streaming_table = stream_exec(&schema); + let projection = projection_exec(Arc::clone(&schema), streaming_table)?; let repartition = repartition_exec(projection)?; - let sort = sort_exec( - vec![PhysicalSortExpr { - expr: col("c1", &schema)?, - options: SortOptions::default(), - }], - repartition, - ); - let spm = sort_preserving_merge_exec(sort.output_ordering().unwrap().to_vec(), sort); + let ordering: LexOrdering = [PhysicalSortExpr { + expr: col("c1", &schema)?, + options: SortOptions::default(), + }] + .into(); + let sort = sort_exec(ordering.clone(), repartition); + let spm = sort_preserving_merge_exec(ordering, sort); let global_limit = global_limit_exec(spm, 0, Some(5)); - let initial = get_plan_string(&global_limit); - let expected_initial = [ - "GlobalLimitExec: skip=0, fetch=5", - " SortPreservingMergeExec: [c1@0 
ASC]", - " SortExec: expr=[c1@0 ASC], preserve_partitioning=[false]", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3]", - " CoalesceBatchesExec: target_batch_size=8192", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" - ]; - - assert_eq!(initial, expected_initial); + let initial = format_plan(&global_limit); + insta::assert_snapshot!( + initial, + @r" + GlobalLimitExec: skip=0, fetch=5 + SortPreservingMergeExec: [c1@0 ASC] + SortExec: expr=[c1@0 ASC], preserve_partitioning=[false] + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3] + StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true + " + ); let after_optimize = LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; - let expected = [ - "SortPreservingMergeExec: [c1@0 ASC], fetch=5", - " SortExec: TopK(fetch=5), expr=[c1@0 ASC], preserve_partitioning=[false]", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3]", - " CoalesceBatchesExec: target_batch_size=8192", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" - ]; - assert_eq!(get_plan_string(&after_optimize), expected); + let optimized = format_plan(&after_optimize); + insta::assert_snapshot!( + optimized, + @r" + SortPreservingMergeExec: [c1@0 ASC], fetch=5 + SortExec: TopK(fetch=5), expr=[c1@0 ASC], preserve_partitioning=[false] + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3] + StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true + " + ); Ok(()) } @@ -357,32 +409,37 @@ fn pushes_global_limit_into_multiple_fetch_plans() -> Result<()> { fn 
keeps_pushed_local_limit_exec_when_there_are_multiple_input_partitions() -> Result<()> { let schema = create_schema(); - let streaming_table = streaming_table_exec(Arc::clone(&schema))?; + let streaming_table = stream_exec(&schema); let repartition = repartition_exec(streaming_table)?; let filter = filter_exec(schema, repartition)?; let coalesce_partitions = coalesce_partitions_exec(filter); let global_limit = global_limit_exec(coalesce_partitions, 0, Some(5)); - let initial = get_plan_string(&global_limit); - let expected_initial = [ - "GlobalLimitExec: skip=0, fetch=5", - " CoalescePartitionsExec", - " FilterExec: c3@2 > 0", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" - ]; - assert_eq!(initial, expected_initial); + let initial = format_plan(&global_limit); + insta::assert_snapshot!( + initial, + @r" + GlobalLimitExec: skip=0, fetch=5 + CoalescePartitionsExec + FilterExec: c3@2 > 0 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true + " + ); let after_optimize = LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; - let expected = [ - "CoalescePartitionsExec: fetch=5", - " FilterExec: c3@2 > 0", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" - ]; - assert_eq!(get_plan_string(&after_optimize), expected); + let optimized = format_plan(&after_optimize); + insta::assert_snapshot!( + optimized, + @r" + CoalescePartitionsExec: fetch=5 + FilterExec: c3@2 > 0, fetch=5 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true + " + ); Ok(()) } @@ -394,20 +451,27 @@ fn merges_local_limit_with_local_limit() -> Result<()> { 
let child_local_limit = local_limit_exec(empty_exec, 10); let parent_local_limit = local_limit_exec(child_local_limit, 20); - let initial = get_plan_string(&parent_local_limit); - let expected_initial = [ - "LocalLimitExec: fetch=20", - " LocalLimitExec: fetch=10", - " EmptyExec", - ]; - - assert_eq!(initial, expected_initial); + let initial = format_plan(&parent_local_limit); + insta::assert_snapshot!( + initial, + @r" + LocalLimitExec: fetch=20 + LocalLimitExec: fetch=10 + EmptyExec + " + ); let after_optimize = LimitPushdown::new().optimize(parent_local_limit, &ConfigOptions::new())?; - let expected = ["GlobalLimitExec: skip=0, fetch=10", " EmptyExec"]; - assert_eq!(get_plan_string(&after_optimize), expected); + let optimized = format_plan(&after_optimize); + insta::assert_snapshot!( + optimized, + @r" + GlobalLimitExec: skip=0, fetch=10 + EmptyExec + " + ); Ok(()) } @@ -419,20 +483,27 @@ fn merges_global_limit_with_global_limit() -> Result<()> { let child_global_limit = global_limit_exec(empty_exec, 10, Some(30)); let parent_global_limit = global_limit_exec(child_global_limit, 10, Some(20)); - let initial = get_plan_string(&parent_global_limit); - let expected_initial = [ - "GlobalLimitExec: skip=10, fetch=20", - " GlobalLimitExec: skip=10, fetch=30", - " EmptyExec", - ]; - - assert_eq!(initial, expected_initial); + let initial = format_plan(&parent_global_limit); + insta::assert_snapshot!( + initial, + @r" + GlobalLimitExec: skip=10, fetch=20 + GlobalLimitExec: skip=10, fetch=30 + EmptyExec + " + ); let after_optimize = LimitPushdown::new().optimize(parent_global_limit, &ConfigOptions::new())?; - let expected = ["GlobalLimitExec: skip=20, fetch=20", " EmptyExec"]; - assert_eq!(get_plan_string(&after_optimize), expected); + let optimized = format_plan(&after_optimize); + insta::assert_snapshot!( + optimized, + @r" + GlobalLimitExec: skip=20, fetch=20 + EmptyExec + " + ); Ok(()) } @@ -444,20 +515,27 @@ fn merges_global_limit_with_local_limit() -> Result<()> { 
let local_limit = local_limit_exec(empty_exec, 40); let global_limit = global_limit_exec(local_limit, 20, Some(30)); - let initial = get_plan_string(&global_limit); - let expected_initial = [ - "GlobalLimitExec: skip=20, fetch=30", - " LocalLimitExec: fetch=40", - " EmptyExec", - ]; - - assert_eq!(initial, expected_initial); + let initial = format_plan(&global_limit); + insta::assert_snapshot!( + initial, + @r" + GlobalLimitExec: skip=20, fetch=30 + LocalLimitExec: fetch=40 + EmptyExec + " + ); let after_optimize = LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; - let expected = ["GlobalLimitExec: skip=20, fetch=20", " EmptyExec"]; - assert_eq!(get_plan_string(&after_optimize), expected); + let optimized = format_plan(&after_optimize); + insta::assert_snapshot!( + optimized, + @r" + GlobalLimitExec: skip=20, fetch=20 + EmptyExec + " + ); Ok(()) } @@ -469,20 +547,138 @@ fn merges_local_limit_with_global_limit() -> Result<()> { let global_limit = global_limit_exec(empty_exec, 20, Some(30)); let local_limit = local_limit_exec(global_limit, 20); - let initial = get_plan_string(&local_limit); - let expected_initial = [ - "LocalLimitExec: fetch=20", - " GlobalLimitExec: skip=20, fetch=30", - " EmptyExec", - ]; - - assert_eq!(initial, expected_initial); + let initial = format_plan(&local_limit); + insta::assert_snapshot!( + initial, + @r" + LocalLimitExec: fetch=20 + GlobalLimitExec: skip=20, fetch=30 + EmptyExec + " + ); let after_optimize = LimitPushdown::new().optimize(local_limit, &ConfigOptions::new())?; - let expected = ["GlobalLimitExec: skip=20, fetch=20", " EmptyExec"]; - assert_eq!(get_plan_string(&after_optimize), expected); + let optimized = format_plan(&after_optimize); + insta::assert_snapshot!( + optimized, + @r" + GlobalLimitExec: skip=20, fetch=20 + EmptyExec + " + ); + + Ok(()) +} + +#[test] +fn preserves_nested_global_limit() -> Result<()> { + // If there are multiple limits in an execution plan, they all need to be + // preserved 
in the optimized plan. + // + // Plan structure: + // GlobalLimitExec: skip=1, fetch=1 + // NestedLoopJoinExec (Left) + // EmptyExec (left side) + // GlobalLimitExec: skip=2, fetch=1 + // NestedLoopJoinExec (Right) + // EmptyExec (left side) + // EmptyExec (right side) + let schema = create_schema(); + + // Build inner join: NestedLoopJoin(Empty, Empty) + let inner_left = empty_exec(Arc::clone(&schema)); + let inner_right = empty_exec(Arc::clone(&schema)); + let inner_join = nested_loop_join_exec(inner_left, inner_right, JoinType::Right)?; + + // Add inner limit: GlobalLimitExec: skip=2, fetch=1 + let inner_limit = global_limit_exec(inner_join, 2, Some(1)); + + // Build outer join: NestedLoopJoin(Empty, GlobalLimit) + let outer_left = empty_exec(Arc::clone(&schema)); + let outer_join = nested_loop_join_exec(outer_left, inner_limit, JoinType::Left)?; + + // Add outer limit: GlobalLimitExec: skip=1, fetch=1 + let outer_limit = global_limit_exec(outer_join, 1, Some(1)); + + let initial = format_plan(&outer_limit); + insta::assert_snapshot!( + initial, + @r" + GlobalLimitExec: skip=1, fetch=1 + NestedLoopJoinExec: join_type=Left + EmptyExec + GlobalLimitExec: skip=2, fetch=1 + NestedLoopJoinExec: join_type=Right + EmptyExec + EmptyExec + " + ); + + let after_optimize = + LimitPushdown::new().optimize(outer_limit, &ConfigOptions::new())?; + let optimized = format_plan(&after_optimize); + insta::assert_snapshot!( + optimized, + @r" + GlobalLimitExec: skip=1, fetch=1 + NestedLoopJoinExec: join_type=Left + EmptyExec + GlobalLimitExec: skip=2, fetch=1 + NestedLoopJoinExec: join_type=Right + EmptyExec + EmptyExec + " + ); + + Ok(()) +} + +#[test] +fn preserves_skip_before_sort() -> Result<()> { + // If there's a limit with skip before a node that (1) supports fetch but + // (2) does not support limit pushdown, that limit should not be removed. 
+ // + // Plan structure: + // GlobalLimitExec: skip=1, fetch=None + // SortExec: TopK(fetch=4) + // EmptyExec + let schema = create_schema(); + + let empty = empty_exec(Arc::clone(&schema)); + + let ordering = [PhysicalSortExpr { + expr: col("c1", &schema)?, + options: SortOptions::default(), + }]; + let sort = sort_exec(ordering.into(), empty) + .with_fetch(Some(4)) + .unwrap(); + + let outer_limit = global_limit_exec(sort, 1, None); + + let initial = format_plan(&outer_limit); + insta::assert_snapshot!( + initial, + @r" + GlobalLimitExec: skip=1, fetch=None + SortExec: TopK(fetch=4), expr=[c1@0 ASC], preserve_partitioning=[false] + EmptyExec + " + ); + + let after_optimize = + LimitPushdown::new().optimize(outer_limit, &ConfigOptions::new())?; + let optimized = format_plan(&after_optimize); + insta::assert_snapshot!( + optimized, + @r" + GlobalLimitExec: skip=1, fetch=3 + SortExec: TopK(fetch=4), expr=[c1@0 ASC], preserve_partitioning=[false] + EmptyExec + " + ); Ok(()) } diff --git a/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs b/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs index f9810eab8f594..c523b4a752a82 100644 --- a/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs +++ b/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs @@ -17,11 +17,12 @@ //! 
Integration tests for [`LimitedDistinctAggregation`] physical optimizer rule +use insta::assert_snapshot; use std::sync::Arc; use crate::physical_optimizer::test_utils::{ - assert_plan_matches_expected, build_group_by, mock_data, parquet_exec_with_sort, - schema, TestAggregate, + TestAggregate, build_group_by, get_optimized_plan, mock_data, parquet_exec_with_sort, + schema, }; use arrow::datatypes::DataType; @@ -30,26 +31,21 @@ use datafusion::prelude::SessionContext; use datafusion_common::Result; use datafusion_execution::config::SessionConfig; use datafusion_expr::Operator; -use datafusion_physical_expr::expressions::cast; -use datafusion_physical_expr::{expressions, expressions::col, PhysicalSortExpr}; -use datafusion_physical_expr_common::sort_expr::LexOrdering; +use datafusion_physical_expr::expressions::{self, cast, col}; +use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; use datafusion_physical_plan::{ + ExecutionPlan, aggregates::{AggregateExec, AggregateMode}, collect, limit::{GlobalLimitExec, LocalLimitExec}, - ExecutionPlan, }; -async fn assert_results_match_expected( - plan: Arc, - expected: &str, -) -> Result<()> { +async fn run_plan_and_format(plan: Arc) -> Result { let cfg = SessionConfig::new().with_target_partitions(1); let ctx = SessionContext::new_with_config(cfg); let batches = collect(plan, ctx.task_ctx()).await?; let actual = format!("{}", pretty_format_batches(&batches)?); - assert_eq!(actual, expected); - Ok(()) + Ok(actual) } #[tokio::test] @@ -78,27 +74,33 @@ async fn test_partial_final() -> Result<()> { Arc::new(final_agg), 4, // fetch ); - // expected to push the limit to the Partial and Final AggregateExecs - let expected = [ - "LocalLimitExec: fetch=4", - "AggregateExec: mode=Final, gby=[a@0 as a], aggr=[], lim=[4]", - "AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[], lim=[4]", - "DataSourceExec: partitions=1, partition_sizes=[1]", - ]; let plan: Arc = Arc::new(limit_exec); - assert_plan_matches_expected(&plan, 
&expected)?; - let expected = r#" -+---+ -| a | -+---+ -| 1 | -| 2 | -| | -| 4 | -+---+ -"# - .trim(); - assert_results_match_expected(plan, expected).await?; + let formatted = get_optimized_plan(&plan)?; + let actual = formatted.trim(); + assert_snapshot!( + actual, + @r" + LocalLimitExec: fetch=4 + AggregateExec: mode=Final, gby=[a@0 as a], aggr=[], lim=[4] + AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[], lim=[4] + DataSourceExec: partitions=1, partition_sizes=[1] + " + ); + let expected = run_plan_and_format(plan).await?; + assert_snapshot!( + expected, + @r" + +---+ + | a | + +---+ + | 1 | + | 2 | + | | + | 4 | + +---+ + " + ); + Ok(()) } @@ -121,25 +123,31 @@ async fn test_single_local() -> Result<()> { 4, // fetch ); // expected to push the limit to the AggregateExec - let expected = [ - "LocalLimitExec: fetch=4", - "AggregateExec: mode=Single, gby=[a@0 as a], aggr=[], lim=[4]", - "DataSourceExec: partitions=1, partition_sizes=[1]", - ]; let plan: Arc = Arc::new(limit_exec); - assert_plan_matches_expected(&plan, &expected)?; - let expected = r#" -+---+ -| a | -+---+ -| 1 | -| 2 | -| | -| 4 | -+---+ -"# - .trim(); - assert_results_match_expected(plan, expected).await?; + let formatted = get_optimized_plan(&plan)?; + let actual = formatted.trim(); + assert_snapshot!( + actual, + @r" + LocalLimitExec: fetch=4 + AggregateExec: mode=Single, gby=[a@0 as a], aggr=[], lim=[4] + DataSourceExec: partitions=1, partition_sizes=[1] + " + ); + let expected = run_plan_and_format(plan).await?; + assert_snapshot!( + expected, + @r" + +---+ + | a | + +---+ + | 1 | + | 2 | + | | + | 4 | + +---+ + " + ); Ok(()) } @@ -163,24 +171,30 @@ async fn test_single_global() -> Result<()> { Some(3), // fetch ); // expected to push the skip+fetch limit to the AggregateExec - let expected = [ - "GlobalLimitExec: skip=1, fetch=3", - "AggregateExec: mode=Single, gby=[a@0 as a], aggr=[], lim=[4]", - "DataSourceExec: partitions=1, partition_sizes=[1]", - ]; let plan: Arc = 
Arc::new(limit_exec); - assert_plan_matches_expected(&plan, &expected)?; - let expected = r#" -+---+ -| a | -+---+ -| 2 | -| | -| 4 | -+---+ -"# - .trim(); - assert_results_match_expected(plan, expected).await?; + let formatted = get_optimized_plan(&plan)?; + let actual = formatted.trim(); + assert_snapshot!( + actual, + @r" + GlobalLimitExec: skip=1, fetch=3 + AggregateExec: mode=Single, gby=[a@0 as a], aggr=[], lim=[4] + DataSourceExec: partitions=1, partition_sizes=[1] + " + ); + let expected = run_plan_and_format(plan).await?; + assert_snapshot!( + expected, + @r" + +---+ + | a | + +---+ + | 2 | + | | + | 4 | + +---+ + " + ); Ok(()) } @@ -211,37 +225,44 @@ async fn test_distinct_cols_different_than_group_by_cols() -> Result<()> { 4, // fetch ); // expected to push the limit to the outer AggregateExec only - let expected = [ - "LocalLimitExec: fetch=4", - "AggregateExec: mode=Single, gby=[a@0 as a], aggr=[], lim=[4]", - "AggregateExec: mode=Single, gby=[a@0 as a, b@1 as b], aggr=[]", - "DataSourceExec: partitions=1, partition_sizes=[1]", - ]; let plan: Arc = Arc::new(limit_exec); - assert_plan_matches_expected(&plan, &expected)?; - let expected = r#" -+---+ -| a | -+---+ -| 1 | -| 2 | -| | -| 4 | -+---+ -"# - .trim(); - assert_results_match_expected(plan, expected).await?; + let formatted = get_optimized_plan(&plan)?; + let actual = formatted.trim(); + assert_snapshot!( + actual, + @r" + LocalLimitExec: fetch=4 + AggregateExec: mode=Single, gby=[a@0 as a], aggr=[], lim=[4] + AggregateExec: mode=Single, gby=[a@0 as a, b@1 as b], aggr=[] + DataSourceExec: partitions=1, partition_sizes=[1] + " + ); + let expected = run_plan_and_format(plan).await?; + assert_snapshot!( + expected, + @r" + +---+ + | a | + +---+ + | 1 | + | 2 | + | | + | 4 | + +---+ + " + ); Ok(()) } #[test] fn test_has_order_by() -> Result<()> { - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("a", &schema()).unwrap(), + let schema = schema(); + let sort_key = [PhysicalSortExpr { 
+ expr: col("a", &schema)?, options: SortOptions::default(), - }]); - let source = parquet_exec_with_sort(vec![sort_key]); - let schema = source.schema(); + }] + .into(); + let source = parquet_exec_with_sort(schema.clone(), vec![sort_key]); // `SELECT a FROM DataSourceExec WHERE a > 1 GROUP BY a LIMIT 10;`, Single AggregateExec // the `a > 1` filter is applied in the AggregateExec @@ -258,13 +279,17 @@ fn test_has_order_by() -> Result<()> { 10, // fetch ); // expected not to push the limit to the AggregateExec - let expected = [ - "LocalLimitExec: fetch=10", - "AggregateExec: mode=Single, gby=[a@0 as a], aggr=[], ordering_mode=Sorted", - "DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet", - ]; let plan: Arc = Arc::new(limit_exec); - assert_plan_matches_expected(&plan, &expected)?; + let formatted = get_optimized_plan(&plan)?; + let actual = formatted.trim(); + assert_snapshot!( + actual, + @r" + LocalLimitExec: fetch=10 + AggregateExec: mode=Single, gby=[a@0 as a], aggr=[], ordering_mode=Sorted + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + " + ); Ok(()) } @@ -287,13 +312,17 @@ fn test_no_group_by() -> Result<()> { 10, // fetch ); // expected not to push the limit to the AggregateExec - let expected = [ - "LocalLimitExec: fetch=10", - "AggregateExec: mode=Single, gby=[], aggr=[]", - "DataSourceExec: partitions=1, partition_sizes=[1]", - ]; let plan: Arc = Arc::new(limit_exec); - assert_plan_matches_expected(&plan, &expected)?; + let formatted = get_optimized_plan(&plan)?; + let actual = formatted.trim(); + assert_snapshot!( + actual, + @r" + LocalLimitExec: fetch=10 + AggregateExec: mode=Single, gby=[], aggr=[] + DataSourceExec: partitions=1, partition_sizes=[1] + " + ); Ok(()) } @@ -317,13 +346,17 @@ fn test_has_aggregate_expression() -> Result<()> { 10, // fetch ); // expected not to push the limit to the AggregateExec 
- let expected = [ - "LocalLimitExec: fetch=10", - "AggregateExec: mode=Single, gby=[a@0 as a], aggr=[COUNT(*)]", - "DataSourceExec: partitions=1, partition_sizes=[1]", - ]; let plan: Arc = Arc::new(limit_exec); - assert_plan_matches_expected(&plan, &expected)?; + let formatted = get_optimized_plan(&plan)?; + let actual = formatted.trim(); + assert_snapshot!( + actual, + @r" + LocalLimitExec: fetch=10 + AggregateExec: mode=Single, gby=[a@0 as a], aggr=[COUNT(*)] + DataSourceExec: partitions=1, partition_sizes=[1] + " + ); Ok(()) } @@ -355,12 +388,16 @@ fn test_has_filter() -> Result<()> { ); // expected not to push the limit to the AggregateExec // TODO(msirek): open an issue for `filter_expr` of `AggregateExec` not printing out - let expected = [ - "LocalLimitExec: fetch=10", - "AggregateExec: mode=Single, gby=[a@0 as a], aggr=[COUNT(*)]", - "DataSourceExec: partitions=1, partition_sizes=[1]", - ]; let plan: Arc = Arc::new(limit_exec); - assert_plan_matches_expected(&plan, &expected)?; + let formatted = get_optimized_plan(&plan)?; + let actual = formatted.trim(); + assert_snapshot!( + actual, + @r" + LocalLimitExec: fetch=10 + AggregateExec: mode=Single, gby=[a@0 as a], aggr=[COUNT(*)] + DataSourceExec: partitions=1, partition_sizes=[1] + " + ); Ok(()) } diff --git a/datafusion/core/tests/physical_optimizer/mod.rs b/datafusion/core/tests/physical_optimizer/mod.rs index 98e7b87ad2157..cf179cb727cf1 100644 --- a/datafusion/core/tests/physical_optimizer/mod.rs +++ b/datafusion/core/tests/physical_optimizer/mod.rs @@ -17,16 +17,25 @@ //! 
Physical Optimizer integration tests +#[expect(clippy::needless_pass_by_value)] mod aggregate_statistics; mod combine_partial_final_agg; +#[expect(clippy::needless_pass_by_value)] mod enforce_distribution; mod enforce_sorting; +mod enforce_sorting_monotonicity; mod filter_pushdown; mod join_selection; +#[expect(clippy::needless_pass_by_value)] mod limit_pushdown; mod limited_distinct_aggregation; mod partition_statistics; mod projection_pushdown; +mod pushdown_sort; mod replace_with_order_preserving_variants; mod sanity_checker; +#[expect(clippy::needless_pass_by_value)] mod test_utils; +mod window_optimize; + +mod pushdown_utils; diff --git a/datafusion/core/tests/physical_optimizer/partition_statistics.rs b/datafusion/core/tests/physical_optimizer/partition_statistics.rs index 62f04f2fe740e..42c1e84534b6d 100644 --- a/datafusion/core/tests/physical_optimizer/partition_statistics.rs +++ b/datafusion/core/tests/physical_optimizer/partition_statistics.rs @@ -17,40 +17,52 @@ #[cfg(test)] mod test { + use insta::assert_snapshot; + use std::sync::Arc; + use arrow::array::{Int32Array, RecordBatch}; use arrow_schema::{DataType, Field, Schema, SortOptions}; use datafusion::datasource::listing::ListingTable; use datafusion::prelude::SessionContext; use datafusion_catalog::TableProvider; - use datafusion_common::stats::Precision; use datafusion_common::Result; - use datafusion_common::{ColumnStatistics, ScalarValue, Statistics}; - use datafusion_execution::config::SessionConfig; + use datafusion_common::stats::Precision; + use datafusion_common::{ + ColumnStatistics, JoinType, NullEquality, ScalarValue, Statistics, + }; use datafusion_execution::TaskContext; + use datafusion_execution::config::SessionConfig; + use datafusion_expr::{WindowFrame, WindowFunctionDefinition}; use datafusion_expr_common::operator::Operator; use datafusion_functions_aggregate::count::count_udaf; + use datafusion_physical_expr::Partitioning; use 
datafusion_physical_expr::aggregate::AggregateExprBuilder; - use datafusion_physical_expr::expressions::{binary, col, lit, Column}; + use datafusion_physical_expr::expressions::{Column, binary, col, lit}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; - use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; + use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; use datafusion_physical_plan::aggregates::{ AggregateExec, AggregateMode, PhysicalGroupBy, }; - use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec; use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; + use datafusion_physical_plan::common::compute_record_batch_statistics; use datafusion_physical_plan::empty::EmptyExec; use datafusion_physical_plan::filter::FilterExec; - use datafusion_physical_plan::joins::CrossJoinExec; + use datafusion_physical_plan::joins::{ + CrossJoinExec, HashJoinExec, NestedLoopJoinExec, PartitionMode, + }; use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; - use datafusion_physical_plan::projection::ProjectionExec; + use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; + use datafusion_physical_plan::projection::{ProjectionExec, ProjectionExpr}; + use datafusion_physical_plan::repartition::RepartitionExec; use datafusion_physical_plan::sorts::sort::SortExec; - use datafusion_physical_plan::union::UnionExec; + use datafusion_physical_plan::union::{InterleaveExec, UnionExec}; + use datafusion_physical_plan::windows::{WindowAggExec, create_window_expr}; use datafusion_physical_plan::{ - execute_stream_partitioned, get_plan_string, ExecutionPlan, - ExecutionPlanProperties, + ExecutionPlan, ExecutionPlanProperties, execute_stream_partitioned, + get_plan_string, }; + use futures::TryStreamExt; - use std::sync::Arc; /// Creates a test table with statistics from the test data directory. 
/// @@ -60,8 +72,9 @@ mod test { /// - Each partition has an "id" column (INT) with the following values: /// - First partition: [3, 4] /// - Second partition: [1, 2] - /// - Each row is 110 bytes in size + /// - Each partition has 16 bytes total (Int32 id: 4 bytes × 2 rows + Date32 date: 4 bytes × 2 rows) /// + /// @param create_table_sql Optional parameter to set the create table SQL /// @param target_partition Optional parameter to set the target partitions /// @return ExecutionPlan representing the scan of the table with statistics async fn create_scan_exec_with_statistics( @@ -104,29 +117,53 @@ mod test { .unwrap() } + // Date32 values for test data (days since 1970-01-01): + // 2025-03-01 = 20148 + // 2025-03-02 = 20149 + // 2025-03-03 = 20150 + // 2025-03-04 = 20151 + const DATE_2025_03_01: i32 = 20148; + const DATE_2025_03_02: i32 = 20149; + const DATE_2025_03_03: i32 = 20150; + const DATE_2025_03_04: i32 = 20151; + /// Helper function to create expected statistics for a partition with Int32 column + /// + /// If `date_range` is provided, includes exact statistics for the partition date column. + /// Partition column statistics are exact because all rows in a partition share the same value. 
fn create_partition_statistics( num_rows: usize, total_byte_size: usize, min_value: i32, max_value: i32, - include_date_column: bool, + date_range: Option<(i32, i32)>, ) -> Statistics { + // Int32 is 4 bytes per row + let int32_byte_size = num_rows * 4; let mut column_stats = vec![ColumnStatistics { null_count: Precision::Exact(0), max_value: Precision::Exact(ScalarValue::Int32(Some(max_value))), min_value: Precision::Exact(ScalarValue::Int32(Some(min_value))), sum_value: Precision::Absent, distinct_count: Precision::Absent, + byte_size: Precision::Exact(int32_byte_size), }]; - if include_date_column { + if let Some((min_date, max_date)) = date_range { + // Partition column stats are computed from partition values: + // - null_count = 0 (partition values from paths are never null) + // - min/max are the merged partition values across files in the group + // - byte_size = num_rows * 4 (Date32 is 4 bytes per row) + // - distinct_count = Inexact(1) per partition file (single partition value per file), + // preserved via max() when merging stats across partitions + let date32_byte_size = num_rows * 4; column_stats.push(ColumnStatistics { - null_count: Precision::Absent, - max_value: Precision::Absent, - min_value: Precision::Absent, + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Date32(Some(max_date))), + min_value: Precision::Exact(ScalarValue::Date32(Some(min_date))), sum_value: Precision::Absent, - distinct_count: Precision::Absent, + distinct_count: Precision::Inexact(1), + byte_size: Precision::Exact(date32_byte_size), }); } @@ -206,14 +243,26 @@ mod test { let statistics = (0..scan.output_partitioning().partition_count()) .map(|idx| scan.partition_statistics(Some(idx))) .collect::>>()?; - let expected_statistic_partition_1 = - create_partition_statistics(2, 110, 3, 4, true); - let expected_statistic_partition_2 = - create_partition_statistics(2, 110, 1, 2, true); + // Partition 1: ids [3,4], dates [2025-03-01, 2025-03-02] + let 
expected_statistic_partition_1 = create_partition_statistics( + 2, + 16, + 3, + 4, + Some((DATE_2025_03_01, DATE_2025_03_02)), + ); + // Partition 2: ids [1,2], dates [2025-03-03, 2025-03-04] + let expected_statistic_partition_2 = create_partition_statistics( + 2, + 16, + 1, + 2, + Some((DATE_2025_03_03, DATE_2025_03_04)), + ); // Check the statistics of each partition assert_eq!(statistics.len(), 2); - assert_eq!(statistics[0], expected_statistic_partition_1); - assert_eq!(statistics[1], expected_statistic_partition_2); + assert_eq!(*statistics[0], expected_statistic_partition_1); + assert_eq!(*statistics[1], expected_statistic_partition_2); // Check the statistics_by_partition with real results let expected_stats = vec![ @@ -229,21 +278,24 @@ mod test { async fn test_statistics_by_partition_of_projection() -> Result<()> { let scan = create_scan_exec_with_statistics(None, Some(2)).await; // Add projection execution plan - let exprs: Vec<(Arc, String)> = - vec![(Arc::new(Column::new("id", 0)), "id".to_string())]; + let exprs = vec![ProjectionExpr { + expr: Arc::new(Column::new("id", 0)) as Arc, + alias: "id".to_string(), + }]; let projection: Arc = Arc::new(ProjectionExec::try_new(exprs, scan)?); let statistics = (0..projection.output_partitioning().partition_count()) .map(|idx| projection.partition_statistics(Some(idx))) .collect::>>()?; + // Projection only includes id column, not the date partition column let expected_statistic_partition_1 = - create_partition_statistics(2, 8, 3, 4, false); + create_partition_statistics(2, 8, 3, 4, None); let expected_statistic_partition_2 = - create_partition_statistics(2, 8, 1, 2, false); + create_partition_statistics(2, 8, 1, 2, None); // Check the statistics of each partition assert_eq!(statistics.len(), 2); - assert_eq!(statistics[0], expected_statistic_partition_1); - assert_eq!(statistics[1], expected_statistic_partition_2); + assert_eq!(*statistics[0], expected_statistic_partition_1); + assert_eq!(*statistics[1], 
expected_statistic_partition_2); // Check the statistics_by_partition with real results let expected_stats = vec![ @@ -258,24 +310,25 @@ mod test { async fn test_statistics_by_partition_of_sort() -> Result<()> { let scan_1 = create_scan_exec_with_statistics(None, Some(1)).await; // Add sort execution plan - let sort = SortExec::new( - LexOrdering::new(vec![PhysicalSortExpr { - expr: Arc::new(Column::new("id", 0)), - options: SortOptions { - descending: false, - nulls_first: false, - }, - }]), - scan_1, - ); - let sort_exec: Arc = Arc::new(sort.clone()); + let ordering = [PhysicalSortExpr::new( + Arc::new(Column::new("id", 0)), + SortOptions::new(false, false), + )]; + let sort = SortExec::new(ordering.clone().into(), scan_1); + let sort_exec: Arc = Arc::new(sort); let statistics = (0..sort_exec.output_partitioning().partition_count()) .map(|idx| sort_exec.partition_statistics(Some(idx))) .collect::>>()?; - let expected_statistic_partition = - create_partition_statistics(4, 220, 1, 4, true); + // All 4 files merged: ids [1-4], dates [2025-03-01, 2025-03-04] + let expected_statistic_partition = create_partition_statistics( + 4, + 32, + 1, + 4, + Some((DATE_2025_03_01, DATE_2025_03_04)), + ); assert_eq!(statistics.len(), 1); - assert_eq!(statistics[0], expected_statistic_partition); + assert_eq!(*statistics[0], expected_statistic_partition); // Check the statistics_by_partition with real results let expected_stats = vec![ExpectedStatistics::NonEmpty(1, 4, 4)]; validate_statistics_with_data(sort_exec.clone(), expected_stats, 0).await?; @@ -284,28 +337,30 @@ mod test { let scan_2 = create_scan_exec_with_statistics(None, Some(2)).await; // Add sort execution plan let sort_exec: Arc = Arc::new( - SortExec::new( - LexOrdering::new(vec![PhysicalSortExpr { - expr: Arc::new(Column::new("id", 0)), - options: SortOptions { - descending: false, - nulls_first: false, - }, - }]), - scan_2, - ) - .with_preserve_partitioning(true), + SortExec::new(ordering.into(), 
scan_2).with_preserve_partitioning(true), + ); + // Partition 1: ids [3,4], dates [2025-03-01, 2025-03-02] + let expected_statistic_partition_1 = create_partition_statistics( + 2, + 16, + 3, + 4, + Some((DATE_2025_03_01, DATE_2025_03_02)), + ); + // Partition 2: ids [1,2], dates [2025-03-03, 2025-03-04] + let expected_statistic_partition_2 = create_partition_statistics( + 2, + 16, + 1, + 2, + Some((DATE_2025_03_03, DATE_2025_03_04)), ); - let expected_statistic_partition_1 = - create_partition_statistics(2, 110, 3, 4, true); - let expected_statistic_partition_2 = - create_partition_statistics(2, 110, 1, 2, true); let statistics = (0..sort_exec.output_partitioning().partition_count()) .map(|idx| sort_exec.partition_statistics(Some(idx))) .collect::>>()?; assert_eq!(statistics.len(), 2); - assert_eq!(statistics[0], expected_statistic_partition_1); - assert_eq!(statistics[1], expected_statistic_partition_2); + assert_eq!(*statistics[0], expected_statistic_partition_1); + assert_eq!(*statistics[1], expected_statistic_partition_2); // Check the statistics_by_partition with real results let expected_stats = vec![ @@ -329,34 +384,61 @@ mod test { let filter: Arc = Arc::new(FilterExec::try_new(predicate, scan)?); let full_statistics = filter.partition_statistics(None)?; + // Filter preserves original total_rows and byte_size from input + // (4 total rows = 2 partitions * 2 rows each, byte_size = 4 * 4 = 16 bytes for int32) let expected_full_statistic = Statistics { num_rows: Precision::Inexact(0), total_byte_size: Precision::Inexact(0), column_statistics: vec![ ColumnStatistics { null_count: Precision::Exact(0), - max_value: Precision::Exact(ScalarValue::Null), - min_value: Precision::Exact(ScalarValue::Null), - sum_value: Precision::Exact(ScalarValue::Null), + max_value: Precision::Exact(ScalarValue::Int32(None)), + min_value: Precision::Exact(ScalarValue::Int32(None)), + sum_value: Precision::Exact(ScalarValue::Int32(None)), distinct_count: Precision::Exact(0), + 
byte_size: Precision::Exact(16), }, ColumnStatistics { null_count: Precision::Exact(0), - max_value: Precision::Exact(ScalarValue::Null), - min_value: Precision::Exact(ScalarValue::Null), - sum_value: Precision::Exact(ScalarValue::Null), + max_value: Precision::Exact(ScalarValue::Date32(None)), + min_value: Precision::Exact(ScalarValue::Date32(None)), + sum_value: Precision::Exact(ScalarValue::Date32(None)), distinct_count: Precision::Exact(0), + byte_size: Precision::Exact(16), // 4 rows * 4 bytes (Date32) }, ], }; - assert_eq!(full_statistics, expected_full_statistic); + assert_eq!(*full_statistics, expected_full_statistic); let statistics = (0..filter.output_partitioning().partition_count()) .map(|idx| filter.partition_statistics(Some(idx))) .collect::>>()?; assert_eq!(statistics.len(), 2); - assert_eq!(statistics[0], expected_full_statistic); - assert_eq!(statistics[1], expected_full_statistic); + // Per-partition stats: each partition has 2 rows, byte_size = 2 * 4 = 8 + let expected_partition_statistic = Statistics { + num_rows: Precision::Inexact(0), + total_byte_size: Precision::Inexact(0), + column_statistics: vec![ + ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Int32(None)), + min_value: Precision::Exact(ScalarValue::Int32(None)), + sum_value: Precision::Exact(ScalarValue::Int32(None)), + distinct_count: Precision::Exact(0), + byte_size: Precision::Exact(8), + }, + ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Date32(None)), + min_value: Precision::Exact(ScalarValue::Date32(None)), + sum_value: Precision::Exact(ScalarValue::Date32(None)), + distinct_count: Precision::Exact(0), + byte_size: Precision::Exact(8), // 2 rows * 4 bytes (Date32) + }, + ], + }; + assert_eq!(*statistics[0], expected_partition_statistic); + assert_eq!(*statistics[1], expected_partition_statistic); Ok(()) } @@ -364,24 +446,36 @@ mod test { async fn 
test_statistic_by_partition_of_union() -> Result<()> { let scan = create_scan_exec_with_statistics(None, Some(2)).await; let union_exec: Arc = - Arc::new(UnionExec::new(vec![scan.clone(), scan])); + UnionExec::try_new(vec![scan.clone(), scan])?; let statistics = (0..union_exec.output_partitioning().partition_count()) .map(|idx| union_exec.partition_statistics(Some(idx))) .collect::>>()?; // Check that we have 4 partitions (2 from each scan) assert_eq!(statistics.len(), 4); - let expected_statistic_partition_1 = - create_partition_statistics(2, 110, 3, 4, true); - let expected_statistic_partition_2 = - create_partition_statistics(2, 110, 1, 2, true); + // Partition 1: ids [3,4], dates [2025-03-01, 2025-03-02] + let expected_statistic_partition_1 = create_partition_statistics( + 2, + 16, + 3, + 4, + Some((DATE_2025_03_01, DATE_2025_03_02)), + ); + // Partition 2: ids [1,2], dates [2025-03-03, 2025-03-04] + let expected_statistic_partition_2 = create_partition_statistics( + 2, + 16, + 1, + 2, + Some((DATE_2025_03_03, DATE_2025_03_04)), + ); // Verify first partition (from first scan) - assert_eq!(statistics[0], expected_statistic_partition_1); + assert_eq!(*statistics[0], expected_statistic_partition_1); // Verify second partition (from first scan) - assert_eq!(statistics[1], expected_statistic_partition_2); + assert_eq!(*statistics[1], expected_statistic_partition_2); // Verify third partition (from second scan - same as first partition) - assert_eq!(statistics[2], expected_statistic_partition_1); + assert_eq!(*statistics[2], expected_statistic_partition_1); // Verify fourth partition (from second scan - same as second partition) - assert_eq!(statistics[3], expected_statistic_partition_2); + assert_eq!(*statistics[3], expected_statistic_partition_2); // Check the statistics_by_partition with real results let expected_stats = vec![ @@ -394,6 +488,64 @@ mod test { Ok(()) } + #[tokio::test] + async fn test_statistics_by_partition_of_interleave() -> Result<()> { + let 
scan1 = create_scan_exec_with_statistics(None, Some(1)).await; + let scan2 = create_scan_exec_with_statistics(None, Some(1)).await; + + // Create same hash partitioning on the 'id' column as InterleaveExec + // requires all children have a consistent hash partitioning + let hash_expr1 = vec![col("id", &scan1.schema())?]; + let repartition1 = Arc::new(RepartitionExec::try_new( + scan1, + Partitioning::Hash(hash_expr1, 2), + )?); + let hash_expr2 = vec![col("id", &scan2.schema())?]; + let repartition2 = Arc::new(RepartitionExec::try_new( + scan2, + Partitioning::Hash(hash_expr2, 2), + )?); + + let interleave: Arc = + Arc::new(InterleaveExec::try_new(vec![repartition1, repartition2])?); + + // Verify the result of partition statistics + let stats = (0..interleave.output_partitioning().partition_count()) + .map(|idx| interleave.partition_statistics(Some(idx))) + .collect::>>()?; + assert_eq!(stats.len(), 2); + + // Each partition gets half of combined input, total_rows per partition = 4 + let expected_stats = Statistics { + num_rows: Precision::Inexact(4), + total_byte_size: Precision::Inexact(32), + column_statistics: vec![ + ColumnStatistics::new_unknown(), + ColumnStatistics::new_unknown(), + ], + }; + assert_eq!(*stats[0], expected_stats); + assert_eq!(*stats[1], expected_stats); + + // Verify the execution results + let partitions = execute_stream_partitioned( + interleave.clone(), + Arc::new(TaskContext::default()), + )?; + assert_eq!(partitions.len(), 2); + + let mut partition_row_counts = Vec::new(); + for partition_stream in partitions.into_iter() { + let results: Vec = partition_stream.try_collect().await?; + let total_rows: usize = results.iter().map(|batch| batch.num_rows()).sum(); + partition_row_counts.push(total_rows); + } + assert_eq!(partition_row_counts.len(), 2); + assert_eq!(partition_row_counts.iter().sum::(), 8); + + Ok(()) + } + #[tokio::test] async fn test_statistic_by_partition_of_cross_join() -> Result<()> { let left_scan = 
create_scan_exec_with_statistics(None, Some(1)).await; @@ -409,30 +561,78 @@ mod test { .collect::>>()?; // Check that we have 2 partitions assert_eq!(statistics.len(), 2); - let mut expected_statistic_partition_1 = - create_partition_statistics(8, 48400, 1, 4, true); - expected_statistic_partition_1 - .column_statistics - .push(ColumnStatistics { - null_count: Precision::Exact(0), - max_value: Precision::Exact(ScalarValue::Int32(Some(4))), - min_value: Precision::Exact(ScalarValue::Int32(Some(3))), - sum_value: Precision::Absent, - distinct_count: Precision::Absent, - }); - let mut expected_statistic_partition_2 = - create_partition_statistics(8, 48400, 1, 4, true); - expected_statistic_partition_2 - .column_statistics - .push(ColumnStatistics { - null_count: Precision::Exact(0), - max_value: Precision::Exact(ScalarValue::Int32(Some(2))), - min_value: Precision::Exact(ScalarValue::Int32(Some(1))), - sum_value: Precision::Absent, - distinct_count: Precision::Absent, - }); - assert_eq!(statistics[0], expected_statistic_partition_1); - assert_eq!(statistics[1], expected_statistic_partition_2); + // Cross join output schema: [left.id, left.date, right.id] + // Cross join doesn't propagate Column's byte_size + let expected_statistic_partition_1 = Statistics { + num_rows: Precision::Exact(8), + total_byte_size: Precision::Exact(512), + column_statistics: vec![ + // column 0: left.id (Int32, file column from t1) + ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Int32(Some(4))), + min_value: Precision::Exact(ScalarValue::Int32(Some(1))), + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Absent, + }, + // column 1: left.date (Date32, partition column from t1) + // Partition column statistics are exact because all rows in a partition share the same value. 
+ ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Date32(Some(20151))), + min_value: Precision::Exact(ScalarValue::Date32(Some(20148))), + sum_value: Precision::Absent, + distinct_count: Precision::Inexact(1), + byte_size: Precision::Absent, + }, + // column 2: right.id (Int32, file column from t2) - right partition 0: ids [3,4] + ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Int32(Some(4))), + min_value: Precision::Exact(ScalarValue::Int32(Some(3))), + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Absent, + }, + ], + }; + let expected_statistic_partition_2 = Statistics { + num_rows: Precision::Exact(8), + total_byte_size: Precision::Exact(512), + column_statistics: vec![ + // column 0: left.id (Int32, file column from t1) + ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Int32(Some(4))), + min_value: Precision::Exact(ScalarValue::Int32(Some(1))), + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Absent, + }, + // column 1: left.date (Date32, partition column from t1) + // Partition column statistics are exact because all rows in a partition share the same value. 
+ ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Date32(Some(20151))), + min_value: Precision::Exact(ScalarValue::Date32(Some(20148))), + sum_value: Precision::Absent, + distinct_count: Precision::Inexact(1), + byte_size: Precision::Absent, + }, + // column 2: right.id (Int32, file column from t2) - right partition 1: ids [1,2] + ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Int32(Some(2))), + min_value: Precision::Exact(ScalarValue::Int32(Some(1))), + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Absent, + }, + ], + }; + assert_eq!(*statistics[0], expected_statistic_partition_1); + assert_eq!(*statistics[1], expected_statistic_partition_2); // Check the statistics_by_partition with real results let expected_stats = vec![ @@ -444,28 +644,77 @@ mod test { } #[tokio::test] - async fn test_statistic_by_partition_of_coalesce_batches() -> Result<()> { - let scan = create_scan_exec_with_statistics(None, Some(2)).await; - dbg!(scan.partition_statistics(Some(0))?); - let coalesce_batches: Arc = - Arc::new(CoalesceBatchesExec::new(scan, 2)); - let expected_statistic_partition_1 = - create_partition_statistics(2, 110, 3, 4, true); - let expected_statistic_partition_2 = - create_partition_statistics(2, 110, 1, 2, true); - let statistics = (0..coalesce_batches.output_partitioning().partition_count()) - .map(|idx| coalesce_batches.partition_statistics(Some(idx))) + async fn test_statistic_by_partition_of_nested_loop_join() -> Result<()> { + use datafusion_expr::JoinType; + + let left_scan = create_scan_exec_with_statistics(None, Some(2)).await; + let left_scan_coalesced: Arc = + Arc::new(CoalescePartitionsExec::new(left_scan)); + + let right_scan = create_scan_exec_with_statistics(None, Some(2)).await; + + let nested_loop_join: Arc = + Arc::new(NestedLoopJoinExec::try_new( + left_scan_coalesced, + right_scan, + None, + 
&JoinType::RightSemi, + None, + )?); + + // Test partition_statistics(None) - returns overall statistics + // For RightSemi join, output columns come from right side only + let full_statistics = nested_loop_join.partition_statistics(None)?; + // With empty join columns, estimate_join_statistics returns Inexact row count + // based on the outer side (right side for RightSemi) + let mut expected_full_statistics = create_partition_statistics( + 4, + 32, + 1, + 4, + Some((DATE_2025_03_01, DATE_2025_03_04)), + ); + expected_full_statistics.num_rows = Precision::Inexact(4); + expected_full_statistics.total_byte_size = Precision::Absent; + assert_eq!(*full_statistics, expected_full_statistics); + + // Test partition_statistics(Some(idx)) - returns partition-specific statistics + // Partition 1: ids [3,4], dates [2025-03-01, 2025-03-02] + let mut expected_statistic_partition_1 = create_partition_statistics( + 2, + 16, + 3, + 4, + Some((DATE_2025_03_01, DATE_2025_03_02)), + ); + expected_statistic_partition_1.num_rows = Precision::Inexact(2); + expected_statistic_partition_1.total_byte_size = Precision::Absent; + + // Partition 2: ids [1,2], dates [2025-03-03, 2025-03-04] + let mut expected_statistic_partition_2 = create_partition_statistics( + 2, + 16, + 1, + 2, + Some((DATE_2025_03_03, DATE_2025_03_04)), + ); + expected_statistic_partition_2.num_rows = Precision::Inexact(2); + expected_statistic_partition_2.total_byte_size = Precision::Absent; + + let statistics = (0..nested_loop_join.output_partitioning().partition_count()) + .map(|idx| nested_loop_join.partition_statistics(Some(idx))) .collect::>>()?; assert_eq!(statistics.len(), 2); - assert_eq!(statistics[0], expected_statistic_partition_1); - assert_eq!(statistics[1], expected_statistic_partition_2); + assert_eq!(*statistics[0], expected_statistic_partition_1); + assert_eq!(*statistics[1], expected_statistic_partition_2); // Check the statistics_by_partition with real results let expected_stats = vec![ 
ExpectedStatistics::NonEmpty(3, 4, 2), ExpectedStatistics::NonEmpty(1, 2, 2), ]; - validate_statistics_with_data(coalesce_batches, expected_stats, 0).await?; + validate_statistics_with_data(nested_loop_join, expected_stats, 0).await?; + Ok(()) } @@ -474,13 +723,19 @@ mod test { let scan = create_scan_exec_with_statistics(None, Some(2)).await; let coalesce_partitions: Arc = Arc::new(CoalescePartitionsExec::new(scan)); - let expected_statistic_partition = - create_partition_statistics(4, 220, 1, 4, true); + // All files merged: ids [1-4], dates [2025-03-01, 2025-03-04] + let expected_statistic_partition = create_partition_statistics( + 4, + 32, + 1, + 4, + Some((DATE_2025_03_01, DATE_2025_03_04)), + ); let statistics = (0..coalesce_partitions.output_partitioning().partition_count()) .map(|idx| coalesce_partitions.partition_statistics(Some(idx))) .collect::>>()?; assert_eq!(statistics.len(), 1); - assert_eq!(statistics[0], expected_statistic_partition); + assert_eq!(*statistics[0], expected_statistic_partition); // Check the statistics_by_partition with real results let expected_stats = vec![ExpectedStatistics::NonEmpty(1, 4, 4)]; @@ -497,11 +752,20 @@ mod test { .map(|idx| local_limit.partition_statistics(Some(idx))) .collect::>>()?; assert_eq!(statistics.len(), 2); - let schema = scan.schema(); - let mut expected_statistic_partition = Statistics::new_unknown(&schema); - expected_statistic_partition.num_rows = Precision::Exact(1); - assert_eq!(statistics[0], expected_statistic_partition); - assert_eq!(statistics[1], expected_statistic_partition); + let mut expected_0 = Statistics::clone(&statistics[0]); + expected_0.column_statistics = expected_0 + .column_statistics + .into_iter() + .map(|c| c.to_inexact()) + .collect(); + let mut expected_1 = Statistics::clone(&statistics[1]); + expected_1.column_statistics = expected_1 + .column_statistics + .into_iter() + .map(|c| c.to_inexact()) + .collect(); + assert_eq!(*statistics[0], expected_0); + assert_eq!(*statistics[1], 
expected_1); Ok(()) } @@ -515,9 +779,15 @@ mod test { .map(|idx| global_limit.partition_statistics(Some(idx))) .collect::>>()?; assert_eq!(statistics.len(), 1); - let expected_statistic_partition = - create_partition_statistics(2, 110, 3, 4, true); - assert_eq!(statistics[0], expected_statistic_partition); + // GlobalLimit takes from first partition: ids [3,4], dates [2025-03-01, 2025-03-02] + let expected_statistic_partition = create_partition_statistics( + 2, + 16, + 3, + 4, + Some((DATE_2025_03_01, DATE_2025_03_02)), + ); + assert_eq!(*statistics[0], expected_statistic_partition); Ok(()) } @@ -541,34 +811,36 @@ mod test { ), ]); - let aggr_expr = vec![AggregateExprBuilder::new(count_udaf(), vec![lit(1)]) - .schema(Arc::clone(&scan_schema)) - .alias(String::from("COUNT(c)")) - .build() - .map(Arc::new)?]; - - let aggregate_exec_partial = Arc::new(AggregateExec::try_new( - AggregateMode::Partial, - group_by.clone(), - aggr_expr.clone(), - vec![None], - Arc::clone(&scan), - scan_schema.clone(), - )?) 
as _; - - let mut plan_string = get_plan_string(&aggregate_exec_partial); - let _ = plan_string.swap_remove(1); - let expected_plan = vec![ - "AggregateExec: mode=Partial, gby=[id@0 as id, 1 + id@0 as expr], aggr=[COUNT(c)]", - //" DataSourceExec: file_groups={2 groups: [[.../datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-01/j5fUeSDQo22oPyPU.parquet, .../datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-02/j5fUeSDQo22oPyPU.parquet], [.../datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-03/j5fUeSDQo22oPyPU.parquet, .../datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-04/j5fUeSDQo22oPyPU.parquet]]}, projection=[id, date], file_type=parquet + let aggr_expr = vec![ + AggregateExprBuilder::new(count_udaf(), vec![lit(1)]) + .schema(Arc::clone(&scan_schema)) + .alias(String::from("COUNT(c)")) + .build() + .map(Arc::new)?, ]; - assert_eq!(plan_string, expected_plan); + + let aggregate_exec_partial: Arc = + Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + group_by.clone(), + aggr_expr.clone(), + vec![None], + Arc::clone(&scan), + scan_schema.clone(), + )?) 
as _; + + let plan_string = get_plan_string(&aggregate_exec_partial).swap_remove(0); + assert_snapshot!( + plan_string, + @"AggregateExec: mode=Partial, gby=[id@0 as id, 1 + id@0 as expr], aggr=[COUNT(c)]" + ); let p0_statistics = aggregate_exec_partial.partition_statistics(Some(0))?; + // Aggregate doesn't propagate num_rows and ColumnStatistics byte_size from input let expected_p0_statistics = Statistics { num_rows: Precision::Inexact(2), - total_byte_size: Precision::Absent, + total_byte_size: Precision::Inexact(16), column_statistics: vec![ ColumnStatistics { null_count: Precision::Absent, @@ -576,17 +848,18 @@ mod test { min_value: Precision::Exact(ScalarValue::Int32(Some(3))), sum_value: Precision::Absent, distinct_count: Precision::Absent, + byte_size: Precision::Absent, }, ColumnStatistics::new_unknown(), ColumnStatistics::new_unknown(), ], }; - assert_eq!(&p0_statistics, &expected_p0_statistics); + assert_eq!(*p0_statistics, expected_p0_statistics); let expected_p1_statistics = Statistics { num_rows: Precision::Inexact(2), - total_byte_size: Precision::Absent, + total_byte_size: Precision::Inexact(16), column_statistics: vec![ ColumnStatistics { null_count: Precision::Absent, @@ -594,6 +867,7 @@ mod test { min_value: Precision::Exact(ScalarValue::Int32(Some(1))), sum_value: Precision::Absent, distinct_count: Precision::Absent, + byte_size: Precision::Absent, }, ColumnStatistics::new_unknown(), ColumnStatistics::new_unknown(), @@ -601,7 +875,7 @@ mod test { }; let p1_statistics = aggregate_exec_partial.partition_statistics(Some(1))?; - assert_eq!(&p1_statistics, &expected_p1_statistics); + assert_eq!(*p1_statistics, expected_p1_statistics); validate_statistics_with_data( aggregate_exec_partial.clone(), @@ -623,10 +897,10 @@ mod test { )?); let p0_statistics = agg_final.partition_statistics(Some(0))?; - assert_eq!(&p0_statistics, &expected_p0_statistics); + assert_eq!(*p0_statistics, expected_p0_statistics); let p1_statistics = 
agg_final.partition_statistics(Some(1))?; - assert_eq!(&p1_statistics, &expected_p1_statistics); + assert_eq!(*p1_statistics, expected_p1_statistics); validate_statistics_with_data( agg_final.clone(), @@ -652,7 +926,10 @@ mod test { )?) as _; let agg_plan = get_plan_string(&agg_partial).remove(0); - assert_eq!("AggregateExec: mode=Partial, gby=[id@0 as id, 1 + id@0 as expr], aggr=[COUNT(c)]",agg_plan); + assert_snapshot!( + agg_plan, + @"AggregateExec: mode=Partial, gby=[id@0 as id, 1 + id@0 as expr], aggr=[COUNT(c)]" + ); let empty_stat = Statistics { num_rows: Precision::Exact(0), @@ -664,8 +941,8 @@ mod test { ], }; - assert_eq!(&empty_stat, &agg_partial.partition_statistics(Some(0))?); - assert_eq!(&empty_stat, &agg_partial.partition_statistics(Some(1))?); + assert_eq!(empty_stat, *agg_partial.partition_statistics(Some(0))?); + assert_eq!(empty_stat, *agg_partial.partition_statistics(Some(1))?); validate_statistics_with_data( agg_partial.clone(), vec![ExpectedStatistics::Empty, ExpectedStatistics::Empty], @@ -691,8 +968,8 @@ mod test { agg_partial.schema(), )?); - assert_eq!(&empty_stat, &agg_final.partition_statistics(Some(0))?); - assert_eq!(&empty_stat, &agg_final.partition_statistics(Some(1))?); + assert_eq!(empty_stat, *agg_final.partition_statistics(Some(0))?); + assert_eq!(empty_stat, *agg_final.partition_statistics(Some(1))?); validate_statistics_with_data( agg_final, @@ -728,7 +1005,7 @@ mod test { column_statistics: vec![ColumnStatistics::new_unknown()], }; - assert_eq!(&expect_stat, &agg_final.partition_statistics(Some(0))?); + assert_eq!(expect_stat, *agg_final.partition_statistics(Some(0))?); // Verify that the aggregate final result has exactly one partition with one row let mut partitions = execute_stream_partitioned( @@ -741,4 +1018,594 @@ mod test { Ok(()) } + + #[tokio::test] + async fn test_statistic_by_partition_of_placeholder_rows() -> Result<()> { + let schema = + Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)])); + let 
plan = Arc::new(PlaceholderRowExec::new(schema).with_partitions(2)) + as Arc; + let schema = plan.schema(); + + let ctx = TaskContext::default(); + let partitions = execute_stream_partitioned(Arc::clone(&plan), Arc::new(ctx))?; + + let mut all_batches = vec![]; + for (i, partition_stream) in partitions.into_iter().enumerate() { + let batches: Vec = partition_stream.try_collect().await?; + let actual = plan.partition_statistics(Some(i))?; + let expected = compute_record_batch_statistics( + std::slice::from_ref(&batches), + &schema, + None, + ); + assert_eq!(*actual, expected); + all_batches.push(batches); + } + + let actual = plan.partition_statistics(None)?; + let expected = compute_record_batch_statistics(&all_batches, &schema, None); + assert_eq!(*actual, expected); + + Ok(()) + } + + #[tokio::test] + async fn test_statistic_by_partition_of_repartition() -> Result<()> { + let scan = create_scan_exec_with_statistics(None, Some(2)).await; + + let repartition = Arc::new(RepartitionExec::try_new( + scan.clone(), + Partitioning::RoundRobinBatch(3), + )?); + + let statistics = (0..repartition.partitioning().partition_count()) + .map(|idx| repartition.partition_statistics(Some(idx))) + .collect::>>()?; + assert_eq!(statistics.len(), 3); + + // Repartition preserves original total_rows from input (4 rows total) + let expected_stats = Statistics { + num_rows: Precision::Inexact(1), + total_byte_size: Precision::Inexact(10), + column_statistics: vec![ + ColumnStatistics::new_unknown(), + ColumnStatistics::new_unknown(), + ], + }; + + // All partitions should have the same statistics + for stat in statistics.iter() { + assert_eq!(**stat, expected_stats); + } + + // Verify that the result has exactly 3 partitions + let partitions = execute_stream_partitioned( + repartition.clone(), + Arc::new(TaskContext::default()), + )?; + assert_eq!(partitions.len(), 3); + + // Collect row counts from each partition + let mut partition_row_counts = Vec::new(); + for partition_stream in 
partitions.into_iter() { + let results: Vec = partition_stream.try_collect().await?; + let total_rows: usize = results.iter().map(|batch| batch.num_rows()).sum(); + partition_row_counts.push(total_rows); + } + assert_eq!(partition_row_counts.len(), 3); + assert_eq!(partition_row_counts[0], 1); + assert_eq!(partition_row_counts[1], 2); + assert_eq!(partition_row_counts[2], 1); + + Ok(()) + } + + #[tokio::test] + async fn test_statistic_by_partition_of_repartition_invalid_partition() -> Result<()> + { + let scan = create_scan_exec_with_statistics(None, Some(2)).await; + + let repartition = Arc::new(RepartitionExec::try_new( + scan.clone(), + Partitioning::RoundRobinBatch(2), + )?); + + let result = repartition.partition_statistics(Some(2)); + assert!(result.is_err()); + let error = result.unwrap_err(); + assert!( + error + .to_string() + .contains("RepartitionExec invalid partition 2 (expected less than 2)") + ); + + let partitions = execute_stream_partitioned( + repartition.clone(), + Arc::new(TaskContext::default()), + )?; + assert_eq!(partitions.len(), 2); + + Ok(()) + } + + #[tokio::test] + async fn test_statistic_by_partition_of_repartition_zero_partitions() -> Result<()> { + let scan = create_scan_exec_with_statistics(None, Some(2)).await; + let scan_schema = scan.schema(); + + // Create a repartition with 0 partitions + let repartition = Arc::new(RepartitionExec::try_new( + Arc::new(EmptyExec::new(scan_schema.clone())), + Partitioning::RoundRobinBatch(0), + )?); + + let result = repartition.partition_statistics(Some(0))?; + assert_eq!(*result, Statistics::new_unknown(&scan_schema)); + + // Verify that the result has exactly 0 partitions + let partitions = execute_stream_partitioned( + repartition.clone(), + Arc::new(TaskContext::default()), + )?; + assert_eq!(partitions.len(), 0); + + Ok(()) + } + + #[tokio::test] + async fn test_statistic_by_partition_of_repartition_hash_partitioning() -> Result<()> + { + let scan = create_scan_exec_with_statistics(None, 
Some(1)).await; + + // Create hash partitioning on the 'id' column + let hash_expr = vec![col("id", &scan.schema())?]; + let repartition = Arc::new(RepartitionExec::try_new( + scan, + Partitioning::Hash(hash_expr, 2), + )?); + + // Verify the result of partition statistics of repartition + let stats = (0..repartition.partitioning().partition_count()) + .map(|idx| repartition.partition_statistics(Some(idx))) + .collect::>>()?; + assert_eq!(stats.len(), 2); + + // Repartition preserves original total_rows from input (4 rows total) + let expected_stats = Statistics { + num_rows: Precision::Inexact(2), + total_byte_size: Precision::Inexact(16), + column_statistics: vec![ + ColumnStatistics::new_unknown(), + ColumnStatistics::new_unknown(), + ], + }; + assert_eq!(*stats[0], expected_stats); + assert_eq!(*stats[1], expected_stats); + + // Verify the repartition execution results + let partitions = + execute_stream_partitioned(repartition, Arc::new(TaskContext::default()))?; + assert_eq!(partitions.len(), 2); + + let mut partition_row_counts = Vec::new(); + for partition_stream in partitions.into_iter() { + let results: Vec = partition_stream.try_collect().await?; + let total_rows: usize = results.iter().map(|batch| batch.num_rows()).sum(); + partition_row_counts.push(total_rows); + } + assert_eq!(partition_row_counts.len(), 2); + assert_eq!(partition_row_counts.iter().sum::(), 4); + + Ok(()) + } + + #[tokio::test] + async fn test_statistic_by_partition_of_window_agg() -> Result<()> { + let scan = create_scan_exec_with_statistics(None, Some(2)).await; + + let window_expr = create_window_expr( + &WindowFunctionDefinition::AggregateUDF(count_udaf()), + "count".to_owned(), + &[col("id", &scan.schema())?], + &[], // no partition by + &[PhysicalSortExpr::new( + col("id", &scan.schema())?, + SortOptions::default(), + )], + Arc::new(WindowFrame::new(Some(false))), + scan.schema(), + false, + false, + None, + )?; + + let window_agg: Arc = + 
Arc::new(WindowAggExec::try_new(vec![window_expr], scan, true)?); + + // Verify partition statistics are properly propagated (not unknown) + let statistics = (0..window_agg.output_partitioning().partition_count()) + .map(|idx| window_agg.partition_statistics(Some(idx))) + .collect::>>()?; + + assert_eq!(statistics.len(), 2); + + // Window functions preserve input row counts and column statistics + // but add unknown statistics for the new window column + let expected_statistic_partition_1 = Statistics { + num_rows: Precision::Exact(2), + total_byte_size: Precision::Absent, + column_statistics: vec![ + ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Int32(Some(4))), + min_value: Precision::Exact(ScalarValue::Int32(Some(3))), + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Exact(8), + }, + ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Date32(Some( + DATE_2025_03_02, + ))), + min_value: Precision::Exact(ScalarValue::Date32(Some( + DATE_2025_03_01, + ))), + sum_value: Precision::Absent, + distinct_count: Precision::Inexact(1), + byte_size: Precision::Exact(8), + }, + ColumnStatistics::new_unknown(), // window column + ], + }; + + let expected_statistic_partition_2 = Statistics { + num_rows: Precision::Exact(2), + total_byte_size: Precision::Absent, + column_statistics: vec![ + ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Int32(Some(2))), + min_value: Precision::Exact(ScalarValue::Int32(Some(1))), + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Exact(8), + }, + ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Date32(Some( + DATE_2025_03_04, + ))), + min_value: Precision::Exact(ScalarValue::Date32(Some( + DATE_2025_03_03, + ))), + sum_value: Precision::Absent, + distinct_count: 
Precision::Inexact(1), + byte_size: Precision::Exact(8), + }, + ColumnStatistics::new_unknown(), // window column + ], + }; + + assert_eq!(*statistics[0], expected_statistic_partition_1); + assert_eq!(*statistics[1], expected_statistic_partition_2); + + // Verify the statistics match actual execution results + let expected_stats = vec![ + ExpectedStatistics::NonEmpty(3, 4, 2), + ExpectedStatistics::NonEmpty(1, 2, 2), + ]; + validate_statistics_with_data(window_agg, expected_stats, 0).await?; + + Ok(()) + } + + #[tokio::test] + async fn test_statistics_by_partition_of_empty_exec() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + ])); + + // Try to test with single partition + let empty_single = Arc::new(EmptyExec::new(Arc::clone(&schema))); + + let stats = empty_single.partition_statistics(Some(0))?; + assert_eq!(stats.num_rows, Precision::Exact(0)); + assert_eq!(stats.total_byte_size, Precision::Exact(0)); + assert_eq!(stats.column_statistics.len(), 2); + + for col_stat in &stats.column_statistics { + assert_eq!(col_stat.null_count, Precision::Exact(0)); + assert_eq!(col_stat.distinct_count, Precision::Exact(0)); + assert_eq!(col_stat.byte_size, Precision::Exact(0)); + assert_eq!(col_stat.min_value, Precision::::Absent); + assert_eq!(col_stat.max_value, Precision::::Absent); + assert_eq!(col_stat.sum_value, Precision::::Absent); + assert_eq!(col_stat.byte_size, Precision::Exact(0)); + } + + let overall_stats = empty_single.partition_statistics(None)?; + assert_eq!(stats, overall_stats); + + validate_statistics_with_data(empty_single, vec![ExpectedStatistics::Empty], 0) + .await?; + + // Test with multiple partitions + let empty_multi: Arc = + Arc::new(EmptyExec::new(Arc::clone(&schema)).with_partitions(3)); + + let statistics = (0..empty_multi.output_partitioning().partition_count()) + .map(|idx| empty_multi.partition_statistics(Some(idx))) + .collect::>>()?; + + 
assert_eq!(statistics.len(), 3); + + for stat in &statistics { + assert_eq!(stat.num_rows, Precision::Exact(0)); + assert_eq!(stat.total_byte_size, Precision::Exact(0)); + } + + validate_statistics_with_data( + empty_multi, + vec![ + ExpectedStatistics::Empty, + ExpectedStatistics::Empty, + ExpectedStatistics::Empty, + ], + 0, + ) + .await?; + + Ok(()) + } + + #[tokio::test] + async fn test_hash_join_partition_statistics() -> Result<()> { + // Create left table scan and coalesce to 1 partition for CollectLeft mode + let left_scan = create_scan_exec_with_statistics(None, Some(2)).await; + let left_scan_coalesced = Arc::new(CoalescePartitionsExec::new(left_scan.clone())) + as Arc; + + // Create right table scan with different table name + let right_create_table_sql = "CREATE EXTERNAL TABLE t2 (id INT NOT NULL, date DATE) \ + STORED AS PARQUET LOCATION './tests/data/test_statistics_per_partition'\ + PARTITIONED BY (date) \ + WITH ORDER (id ASC);"; + let right_scan = + create_scan_exec_with_statistics(Some(right_create_table_sql), Some(2)).await; + + // Create join condition: t1.id = t2.id + let on = vec![( + Arc::new(Column::new("id", 0)) as Arc, + Arc::new(Column::new("id", 0)) as Arc, + )]; + + // Test CollectLeft mode - left child must have 1 partition + let collect_left_join = Arc::new(HashJoinExec::try_new( + left_scan_coalesced, + Arc::clone(&right_scan), + on.clone(), + None, + &JoinType::Inner, + None, + PartitionMode::CollectLeft, + NullEquality::NullEqualsNothing, + false, + )?) as Arc; + + // Test partition statistics for CollectLeft mode + let statistics = (0..collect_left_join.output_partitioning().partition_count()) + .map(|idx| collect_left_join.partition_statistics(Some(idx))) + .collect::>>()?; + + // Check that we have the expected number of partitions + assert_eq!(statistics.len(), 2); + + // For collect left mode, the min/max values are from the entire left table and the specific partition of the right table. 
+ let expected_p0_statistics = Statistics { + num_rows: Precision::Inexact(2), + total_byte_size: Precision::Absent, + column_statistics: vec![ + // Left id column: all partitions (id 1..4) + ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Int32(Some(4))), + min_value: Precision::Exact(ScalarValue::Int32(Some(1))), + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Exact(16), + }, + // Left date column: all partitions (2025-03-01..2025-03-04) + // NDV is Inexact(1) because each Hive partition has exactly 1 distinct date value, + // and merging takes max as a conservative lower bound + ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Date32(Some( + DATE_2025_03_04, + ))), + min_value: Precision::Exact(ScalarValue::Date32(Some( + DATE_2025_03_01, + ))), + sum_value: Precision::Absent, + distinct_count: Precision::Inexact(1), + byte_size: Precision::Exact(16), + }, + // Right id column: partition 0 only (id 3..4) + ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Int32(Some(4))), + min_value: Precision::Exact(ScalarValue::Int32(Some(3))), + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Exact(8), + }, + // Right date column: partition 0 only (2025-03-01..2025-03-02) + // NDV is Inexact(1) from the single Hive partition's date value + ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Date32(Some( + DATE_2025_03_02, + ))), + min_value: Precision::Exact(ScalarValue::Date32(Some( + DATE_2025_03_01, + ))), + sum_value: Precision::Absent, + distinct_count: Precision::Inexact(1), + byte_size: Precision::Exact(8), + }, + ], + }; + assert_eq!(*statistics[0], expected_p0_statistics); + + // Test Partitioned mode + let partitioned_join = Arc::new(HashJoinExec::try_new( + Arc::clone(&left_scan), + 
Arc::clone(&right_scan), + on.clone(), + None, + &JoinType::Inner, + None, + PartitionMode::Partitioned, + NullEquality::NullEqualsNothing, + false, + )?) as Arc; + + // Test partition statistics for Partitioned mode + let statistics = (0..partitioned_join.output_partitioning().partition_count()) + .map(|idx| partitioned_join.partition_statistics(Some(idx))) + .collect::>>()?; + + // Check that we have the expected number of partitions + assert_eq!(statistics.len(), 2); + + // For partitioned mode, the min/max values are from the specific partition for each side. + let expected_p0_statistics = Statistics { + num_rows: Precision::Inexact(2), + total_byte_size: Precision::Absent, + column_statistics: vec![ + // Left id column: partition 0 only (id 3..4) + ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Int32(Some(4))), + min_value: Precision::Exact(ScalarValue::Int32(Some(3))), + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Exact(8), + }, + // Left date column: partition 0 only (2025-03-01..2025-03-02) + ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Date32(Some( + DATE_2025_03_02, + ))), + min_value: Precision::Exact(ScalarValue::Date32(Some( + DATE_2025_03_01, + ))), + sum_value: Precision::Absent, + distinct_count: Precision::Inexact(1), + byte_size: Precision::Exact(8), + }, + // Right id column: partition 0 only (id 3..4) + ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Int32(Some(4))), + min_value: Precision::Exact(ScalarValue::Int32(Some(3))), + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Exact(8), + }, + // Right date column: partition 0 only (2025-03-01..2025-03-02) + ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Date32(Some( + DATE_2025_03_02, + ))), + min_value: 
Precision::Exact(ScalarValue::Date32(Some( + DATE_2025_03_01, + ))), + sum_value: Precision::Absent, + distinct_count: Precision::Inexact(1), + byte_size: Precision::Exact(8), + }, + ], + }; + assert_eq!(*statistics[0], expected_p0_statistics); + + // Test Auto mode - should fall back to getting all partition statistics + let auto_join = Arc::new(HashJoinExec::try_new( + Arc::clone(&left_scan), + Arc::clone(&right_scan), + on, + None, + &JoinType::Inner, + None, + PartitionMode::Auto, + NullEquality::NullEqualsNothing, + false, + )?) as Arc; + + // Test partition statistics for Auto mode + let statistics = (0..auto_join.output_partitioning().partition_count()) + .map(|idx| auto_join.partition_statistics(Some(idx))) + .collect::>>()?; + + // Check that we have the expected number of partitions + assert_eq!(statistics.len(), 2); + + // For auto mode, the min/max values are from the entire left and right tables. + let expected_p0_statistics = Statistics { + num_rows: Precision::Inexact(4), + total_byte_size: Precision::Absent, + column_statistics: vec![ + // Left id column: all partitions (id 1..4) + ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Int32(Some(4))), + min_value: Precision::Exact(ScalarValue::Int32(Some(1))), + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Exact(16), + }, + // Left date column: all partitions (2025-03-01..2025-03-04) + ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Date32(Some( + DATE_2025_03_04, + ))), + min_value: Precision::Exact(ScalarValue::Date32(Some( + DATE_2025_03_01, + ))), + sum_value: Precision::Absent, + distinct_count: Precision::Inexact(1), + byte_size: Precision::Exact(16), + }, + // Right id column: all partitions (id 1..4) + ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Int32(Some(4))), + min_value: 
Precision::Exact(ScalarValue::Int32(Some(1))), + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Exact(16), + }, + // Right date column: all partitions (2025-03-01..2025-03-04) + ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Date32(Some( + DATE_2025_03_04, + ))), + min_value: Precision::Exact(ScalarValue::Date32(Some( + DATE_2025_03_01, + ))), + sum_value: Precision::Absent, + distinct_count: Precision::Inexact(1), + byte_size: Precision::Exact(16), + }, + ], + }; + assert_eq!(*statistics[0], expected_p0_statistics); + Ok(()) + } } diff --git a/datafusion/core/tests/physical_optimizer/projection_pushdown.rs b/datafusion/core/tests/physical_optimizer/projection_pushdown.rs index 7c00d323a8e69..00e016ae02cad 100644 --- a/datafusion/core/tests/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/tests/physical_optimizer/projection_pushdown.rs @@ -24,47 +24,48 @@ use datafusion::datasource::listing::PartitionedFile; use datafusion::datasource::memory::MemorySourceConfig; use datafusion::datasource::physical_plan::CsvSource; use datafusion::datasource::source::DataSourceExec; -use datafusion_common::config::ConfigOptions; -use datafusion_common::Result; -use datafusion_common::{JoinSide, JoinType, ScalarValue}; +use datafusion_common::config::{ConfigOptions, CsvOptions}; +use datafusion_common::{JoinSide, JoinType, NullEquality, Result, ScalarValue}; +use datafusion_datasource::TableSchema; +use datafusion_datasource::file_scan_config::FileScanConfigBuilder; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_expr::{ Operator, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility, }; +use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_physical_expr::expressions::{ - binary, cast, col, BinaryExpr, CaseExpr, CastExpr, Column, Literal, 
NegativeExpr, + BinaryExpr, CaseExpr, CastExpr, Column, Literal, NegativeExpr, binary, cast, col, }; -use datafusion_physical_expr::ScalarFunctionExpr; -use datafusion_physical_expr::{ - Distribution, Partitioning, PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement, +use datafusion_physical_expr::{Distribution, Partitioning, ScalarFunctionExpr}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_physical_expr_common::sort_expr::{ + OrderingRequirements, PhysicalSortExpr, PhysicalSortRequirement, }; -use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement}; +use datafusion_physical_optimizer::PhysicalOptimizerRule; use datafusion_physical_optimizer::output_requirements::OutputRequirementExec; use datafusion_physical_optimizer::projection_pushdown::ProjectionPushdown; -use datafusion_physical_optimizer::PhysicalOptimizerRule; use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; +use datafusion_physical_plan::coop::CooperativeExec; use datafusion_physical_plan::filter::FilterExec; use datafusion_physical_plan::joins::utils::{ColumnIndex, JoinFilter}; use datafusion_physical_plan::joins::{ HashJoinExec, NestedLoopJoinExec, PartitionMode, StreamJoinPartitionMode, SymmetricHashJoinExec, }; -use datafusion_physical_plan::projection::{update_expr, ProjectionExec}; +use datafusion_physical_plan::projection::{ProjectionExec, ProjectionExpr, update_expr}; use datafusion_physical_plan::repartition::RepartitionExec; use datafusion_physical_plan::sorts::sort::SortExec; use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; -use datafusion_physical_plan::streaming::PartitionStream; -use datafusion_physical_plan::streaming::StreamingTableExec; +use datafusion_physical_plan::streaming::{PartitionStream, StreamingTableExec}; use datafusion_physical_plan::union::UnionExec; -use datafusion_physical_plan::{get_plan_string, ExecutionPlan}; +use 
datafusion_physical_plan::{ExecutionPlan, displayable}; -use datafusion_datasource::file_scan_config::FileScanConfigBuilder; -use datafusion_expr_common::columnar_value::ColumnarValue; +use insta::assert_snapshot; use itertools::Itertools; /// Mocked UDF -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] struct DummyUDF { signature: Signature, } @@ -129,6 +130,7 @@ fn test_update_matching_exprs() -> Result<()> { )), ], Field::new("f", DataType::Int32, true).into(), + Arc::new(ConfigOptions::default()), )), Arc::new(CaseExpr::try_new( Some(Arc::new(Column::new("d", 2))), @@ -194,6 +196,7 @@ fn test_update_matching_exprs() -> Result<()> { )), ], Field::new("f", DataType::Int32, true).into(), + Arc::new(ConfigOptions::default()), )), Arc::new(CaseExpr::try_new( Some(Arc::new(Column::new("d", 3))), @@ -223,10 +226,16 @@ fn test_update_matching_exprs() -> Result<()> { )?), ]; + let child_exprs: Vec = child + .iter() + .map(|(expr, alias)| ProjectionExpr::new(expr.clone(), alias.clone())) + .collect(); for (expr, expected_expr) in exprs.into_iter().zip(expected_exprs.into_iter()) { - assert!(update_expr(&expr, &child, true)? - .unwrap() - .eq(&expected_expr)); + assert!( + update_expr(&expr, &child_exprs, true)? 
+ .unwrap() + .eq(&expected_expr) + ); } Ok(()) @@ -262,6 +271,7 @@ fn test_update_projected_exprs() -> Result<()> { )), ], Field::new("f", DataType::Int32, true).into(), + Arc::new(ConfigOptions::default()), )), Arc::new(CaseExpr::try_new( Some(Arc::new(Column::new("d", 2))), @@ -327,6 +337,7 @@ fn test_update_projected_exprs() -> Result<()> { )), ], Field::new("f", DataType::Int32, true).into(), + Arc::new(ConfigOptions::default()), )), Arc::new(CaseExpr::try_new( Some(Arc::new(Column::new("d_new", 3))), @@ -356,10 +367,16 @@ fn test_update_projected_exprs() -> Result<()> { )?), ]; + let proj_exprs: Vec = projected_exprs + .iter() + .map(|(expr, alias)| ProjectionExpr::new(expr.clone(), alias.clone())) + .collect(); for (expr, expected_expr) in exprs.into_iter().zip(expected_exprs.into_iter()) { - assert!(update_expr(&expr, &projected_exprs, false)? - .unwrap() - .eq(&expected_expr)); + assert!( + update_expr(&expr, &proj_exprs, false)? + .unwrap() + .eq(&expected_expr) + ); } Ok(()) @@ -373,14 +390,20 @@ fn create_simple_csv_exec() -> Arc { Field::new("d", DataType::Int32, true), Field::new("e", DataType::Int32, true), ])); - let config = FileScanConfigBuilder::new( - ObjectStoreUrl::parse("test:///").unwrap(), - schema, - Arc::new(CsvSource::new(false, 0, 0)), - ) - .with_file(PartitionedFile::new("x".to_string(), 100)) - .with_projection(Some(vec![0, 1, 2, 3, 4])) - .build(); + let config = + FileScanConfigBuilder::new(ObjectStoreUrl::parse("test:///").unwrap(), { + let options = CsvOptions { + has_header: Some(false), + delimiter: 0, + quote: 0, + ..Default::default() + }; + Arc::new(CsvSource::new(schema.clone()).with_csv_options(options)) + }) + .with_file(PartitionedFile::new("x", 100)) + .with_projection_indices(Some(vec![0, 1, 2, 3, 4])) + .unwrap() + .build(); DataSourceExec::from_data_source(config) } @@ -392,14 +415,20 @@ fn create_projecting_csv_exec() -> Arc { Field::new("c", DataType::Int32, true), Field::new("d", DataType::Int32, true), ])); - let 
config = FileScanConfigBuilder::new( - ObjectStoreUrl::parse("test:///").unwrap(), - schema, - Arc::new(CsvSource::new(false, 0, 0)), - ) - .with_file(PartitionedFile::new("x".to_string(), 100)) - .with_projection(Some(vec![3, 2, 1])) - .build(); + let config = + FileScanConfigBuilder::new(ObjectStoreUrl::parse("test:///").unwrap(), { + let options = CsvOptions { + has_header: Some(false), + delimiter: 0, + quote: 0, + ..Default::default() + }; + Arc::new(CsvSource::new(schema.clone()).with_csv_options(options)) + }) + .with_file(PartitionedFile::new("x", 100)) + .with_projection_indices(Some(vec![3, 2, 1])) + .unwrap() + .build(); DataSourceExec::from_data_source(config) } @@ -421,24 +450,34 @@ fn test_csv_after_projection() -> Result<()> { let csv = create_projecting_csv_exec(); let projection: Arc = Arc::new(ProjectionExec::try_new( vec![ - (Arc::new(Column::new("b", 2)), "b".to_string()), - (Arc::new(Column::new("d", 0)), "d".to_string()), + ProjectionExpr::new(Arc::new(Column::new("b", 2)), "b"), + ProjectionExpr::new(Arc::new(Column::new("d", 0)), "d"), ], csv.clone(), )?); - let initial = get_plan_string(&projection); - let expected_initial = [ - "ProjectionExec: expr=[b@2 as b, d@0 as d]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[d, c, b], file_type=csv, has_header=false", - ]; - assert_eq!(initial, expected_initial); + let initial = displayable(projection.as_ref()).indent(true).to_string(); + let actual = initial.trim(); + + assert_snapshot!( + actual, + @r" + ProjectionExec: expr=[b@2 as b, d@0 as d] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[d, c, b], file_type=csv, has_header=false + " + ); let after_optimize = ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; - let expected = - ["DataSourceExec: file_groups={1 group: [[x]]}, projection=[b, d], file_type=csv, has_header=false"]; - assert_eq!(get_plan_string(&after_optimize), expected); + let after_optimize_string = 
displayable(after_optimize.as_ref()) + .indent(true) + .to_string(); + let actual = after_optimize_string.trim(); + + assert_snapshot!( + actual, + @"DataSourceExec: file_groups={1 group: [[x]]}, projection=[b, d], file_type=csv, has_header=false" + ); Ok(()) } @@ -448,24 +487,36 @@ fn test_memory_after_projection() -> Result<()> { let memory = create_projecting_memory_exec(); let projection: Arc = Arc::new(ProjectionExec::try_new( vec![ - (Arc::new(Column::new("d", 2)), "d".to_string()), - (Arc::new(Column::new("e", 3)), "e".to_string()), - (Arc::new(Column::new("a", 1)), "a".to_string()), + ProjectionExpr::new(Arc::new(Column::new("d", 2)), "d"), + ProjectionExpr::new(Arc::new(Column::new("e", 3)), "e"), + ProjectionExpr::new(Arc::new(Column::new("a", 1)), "a"), ], memory.clone(), )?); - let initial = get_plan_string(&projection); - let expected_initial = [ - "ProjectionExec: expr=[d@2 as d, e@3 as e, a@1 as a]", - " DataSourceExec: partitions=0, partition_sizes=[]", - ]; - assert_eq!(initial, expected_initial); + let initial = displayable(projection.as_ref()).indent(true).to_string(); + let actual = initial.trim(); + + assert_snapshot!( + actual, + @r" + ProjectionExec: expr=[d@2 as d, e@3 as e, a@1 as a] + DataSourceExec: partitions=0, partition_sizes=[] + " + ); let after_optimize = ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; - let expected = ["DataSourceExec: partitions=0, partition_sizes=[]"]; - assert_eq!(get_plan_string(&after_optimize), expected); + let after_optimize_string = displayable(after_optimize.as_ref()) + .indent(true) + .to_string(); + let actual = after_optimize_string.trim(); + + assert_snapshot!( + actual, + @"DataSourceExec: partitions=0, partition_sizes=[]" + ); + assert_eq!( after_optimize .clone() @@ -519,7 +570,7 @@ fn test_streaming_table_after_projection() -> Result<()> { }) as _], Some(&vec![0_usize, 2, 4, 3]), vec![ - LexOrdering::new(vec![ + [ PhysicalSortExpr { expr: Arc::new(Column::new("e", 2)), 
options: SortOptions::default(), @@ -528,11 +579,13 @@ fn test_streaming_table_after_projection() -> Result<()> { expr: Arc::new(Column::new("a", 0)), options: SortOptions::default(), }, - ]), - LexOrdering::new(vec![PhysicalSortExpr { + ] + .into(), + [PhysicalSortExpr { expr: Arc::new(Column::new("d", 3)), options: SortOptions::default(), - }]), + }] + .into(), ] .into_iter(), true, @@ -540,9 +593,9 @@ fn test_streaming_table_after_projection() -> Result<()> { )?; let projection = Arc::new(ProjectionExec::try_new( vec![ - (Arc::new(Column::new("d", 3)), "d".to_string()), - (Arc::new(Column::new("e", 2)), "e".to_string()), - (Arc::new(Column::new("a", 0)), "a".to_string()), + ProjectionExpr::new(Arc::new(Column::new("d", 3)), "d"), + ProjectionExpr::new(Arc::new(Column::new("e", 2)), "e"), + ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a"), ], Arc::new(streaming_table) as _, )?) as _; @@ -579,7 +632,7 @@ fn test_streaming_table_after_projection() -> Result<()> { assert_eq!( result.projected_output_ordering().into_iter().collect_vec(), vec![ - LexOrdering::new(vec![ + [ PhysicalSortExpr { expr: Arc::new(Column::new("e", 1)), options: SortOptions::default(), @@ -588,11 +641,13 @@ fn test_streaming_table_after_projection() -> Result<()> { expr: Arc::new(Column::new("a", 2)), options: SortOptions::default(), }, - ]), - LexOrdering::new(vec![PhysicalSortExpr { + ] + .into(), + [PhysicalSortExpr { expr: Arc::new(Column::new("d", 0)), options: SortOptions::default(), - }]), + }] + .into(), ] ); assert!(result.is_infinite()); @@ -605,45 +660,55 @@ fn test_projection_after_projection() -> Result<()> { let csv = create_simple_csv_exec(); let child_projection: Arc = Arc::new(ProjectionExec::try_new( vec![ - (Arc::new(Column::new("c", 2)), "c".to_string()), - (Arc::new(Column::new("e", 4)), "new_e".to_string()), - (Arc::new(Column::new("a", 0)), "a".to_string()), - (Arc::new(Column::new("b", 1)), "new_b".to_string()), + ProjectionExpr::new(Arc::new(Column::new("c", 
2)), "c"), + ProjectionExpr::new(Arc::new(Column::new("e", 4)), "new_e"), + ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a"), + ProjectionExpr::new(Arc::new(Column::new("b", 1)), "new_b"), ], csv.clone(), )?); let top_projection: Arc = Arc::new(ProjectionExec::try_new( vec![ - (Arc::new(Column::new("new_b", 3)), "new_b".to_string()), - ( + ProjectionExpr::new(Arc::new(Column::new("new_b", 3)), "new_b"), + ProjectionExpr::new( Arc::new(BinaryExpr::new( Arc::new(Column::new("c", 0)), Operator::Plus, Arc::new(Column::new("new_e", 1)), )), - "binary".to_string(), + "binary", ), - (Arc::new(Column::new("new_b", 3)), "newest_b".to_string()), + ProjectionExpr::new(Arc::new(Column::new("new_b", 3)), "newest_b"), ], child_projection.clone(), )?); - let initial = get_plan_string(&top_projection); - let expected_initial = [ - "ProjectionExec: expr=[new_b@3 as new_b, c@0 + new_e@1 as binary, new_b@3 as newest_b]", - " ProjectionExec: expr=[c@2 as c, e@4 as new_e, a@0 as a, b@1 as new_b]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false" - ]; - assert_eq!(initial, expected_initial); + let initial = displayable(top_projection.as_ref()) + .indent(true) + .to_string(); + let actual = initial.trim(); + + assert_snapshot!( + actual, + @r" + ProjectionExec: expr=[new_b@3 as new_b, c@0 + new_e@1 as binary, new_b@3 as newest_b] + ProjectionExec: expr=[c@2 as c, e@4 as new_e, a@0 as a, b@1 as new_b] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + " + ); let after_optimize = ProjectionPushdown::new().optimize(top_projection, &ConfigOptions::new())?; - let expected = [ - "ProjectionExec: expr=[b@1 as new_b, c@2 + e@4 as binary, b@1 as newest_b]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false" - ]; - assert_eq!(get_plan_string(&after_optimize), expected); + let after_optimize_string = 
displayable(after_optimize.as_ref()) + .indent(true) + .to_string(); + let actual = after_optimize_string.trim(); + + assert_snapshot!( + actual, + @"DataSourceExec: file_groups={1 group: [[x]]}, projection=[b@1 as new_b, c@2 + e@4 as binary, b@1 as newest_b], file_type=csv, has_header=false" + ); Ok(()) } @@ -652,67 +717,84 @@ fn test_projection_after_projection() -> Result<()> { fn test_output_req_after_projection() -> Result<()> { let csv = create_simple_csv_exec(); let sort_req: Arc = Arc::new(OutputRequirementExec::new( - csv.clone(), - Some(LexRequirement::new(vec![ - PhysicalSortRequirement { - expr: Arc::new(Column::new("b", 1)), - options: Some(SortOptions::default()), - }, - PhysicalSortRequirement { - expr: Arc::new(BinaryExpr::new( - Arc::new(Column::new("c", 2)), - Operator::Plus, - Arc::new(Column::new("a", 0)), - )), - options: Some(SortOptions::default()), - }, - ])), + csv, + Some(OrderingRequirements::new( + [ + PhysicalSortRequirement::new( + Arc::new(Column::new("b", 1)), + Some(SortOptions::default()), + ), + PhysicalSortRequirement::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("c", 2)), + Operator::Plus, + Arc::new(Column::new("a", 0)), + )), + Some(SortOptions::default()), + ), + ] + .into(), + )), Distribution::HashPartitioned(vec![ Arc::new(Column::new("a", 0)), Arc::new(Column::new("b", 1)), ]), + None, )); let projection: Arc = Arc::new(ProjectionExec::try_new( vec![ - (Arc::new(Column::new("c", 2)), "c".to_string()), - (Arc::new(Column::new("a", 0)), "new_a".to_string()), - (Arc::new(Column::new("b", 1)), "b".to_string()), + ProjectionExpr::new(Arc::new(Column::new("c", 2)), "c"), + ProjectionExpr::new(Arc::new(Column::new("a", 0)), "new_a"), + ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b"), ], sort_req.clone(), )?); - let initial = get_plan_string(&projection); - let expected_initial = [ - "ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", - " OutputRequirementExec", - " DataSourceExec: file_groups={1 group: 
[[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false" - ]; - assert_eq!(initial, expected_initial); + let initial = displayable(projection.as_ref()).indent(true).to_string(); + let actual = initial.trim(); + + assert_snapshot!( + actual, + @r" + ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b] + OutputRequirementExec: order_by=[(b@1, asc), (c@2 + a@0, asc)], dist_by=HashPartitioned[[a@0, b@1]]) + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + " + ); let after_optimize = ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; - let expected: [&str; 3] = [ - "OutputRequirementExec", - " ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false" - ]; - - assert_eq!(get_plan_string(&after_optimize), expected); - let expected_reqs = LexRequirement::new(vec![ - PhysicalSortRequirement { - expr: Arc::new(Column::new("b", 2)), - options: Some(SortOptions::default()), - }, - PhysicalSortRequirement { - expr: Arc::new(BinaryExpr::new( - Arc::new(Column::new("c", 0)), - Operator::Plus, - Arc::new(Column::new("new_a", 1)), - )), - options: Some(SortOptions::default()), - }, - ]); + let after_optimize_string = displayable(after_optimize.as_ref()) + .indent(true) + .to_string(); + let actual = after_optimize_string.trim(); + + assert_snapshot!( + actual, + @r" + OutputRequirementExec: order_by=[(b@2, asc), (c@0 + new_a@1, asc)], dist_by=HashPartitioned[[new_a@1, b@2]]) + DataSourceExec: file_groups={1 group: [[x]]}, projection=[c, a@0 as new_a, b], file_type=csv, has_header=false + " + ); + + let expected_reqs = OrderingRequirements::new( + [ + PhysicalSortRequirement::new( + Arc::new(Column::new("b", 2)), + Some(SortOptions::default()), + ), + PhysicalSortRequirement::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("c", 0)), + Operator::Plus, + 
Arc::new(Column::new("new_a", 1)), + )), + Some(SortOptions::default()), + ), + ] + .into(), + ); assert_eq!( after_optimize .as_any() @@ -734,10 +816,11 @@ fn test_output_req_after_projection() -> Result<()> { .required_input_distribution()[0] .clone() { - assert!(vec - .iter() - .zip(expected_distribution) - .all(|(actual, expected)| actual.eq(&expected))); + assert!( + vec.iter() + .zip(expected_distribution) + .all(|(actual, expected)| actual.eq(&expected)) + ); } else { panic!("Expected HashPartitioned distribution!"); }; @@ -752,29 +835,39 @@ fn test_coalesce_partitions_after_projection() -> Result<()> { Arc::new(CoalescePartitionsExec::new(csv)); let projection: Arc = Arc::new(ProjectionExec::try_new( vec![ - (Arc::new(Column::new("b", 1)), "b".to_string()), - (Arc::new(Column::new("a", 0)), "a_new".to_string()), - (Arc::new(Column::new("d", 3)), "d".to_string()), + ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b"), + ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a_new"), + ProjectionExpr::new(Arc::new(Column::new("d", 3)), "d"), ], coalesce_partitions, )?); - let initial = get_plan_string(&projection); - let expected_initial = [ - "ProjectionExec: expr=[b@1 as b, a@0 as a_new, d@3 as d]", - " CoalescePartitionsExec", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", - ]; - assert_eq!(initial, expected_initial); + let initial = displayable(projection.as_ref()).indent(true).to_string(); + let actual = initial.trim(); + + assert_snapshot!( + actual, + @r" + ProjectionExec: expr=[b@1 as b, a@0 as a_new, d@3 as d] + CoalescePartitionsExec + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + " + ); let after_optimize = ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; - let expected = [ - "CoalescePartitionsExec", - " ProjectionExec: expr=[b@1 as b, a@0 as a_new, d@3 as d]", - " DataSourceExec: file_groups={1 
group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", - ]; - assert_eq!(get_plan_string(&after_optimize), expected); + let after_optimize_string = displayable(after_optimize.as_ref()) + .indent(true) + .to_string(); + let actual = after_optimize_string.trim(); + + assert_snapshot!( + actual, + @r" + CoalescePartitionsExec + DataSourceExec: file_groups={1 group: [[x]]}, projection=[b, a@0 as a_new, d], file_type=csv, has_header=false + " + ); Ok(()) } @@ -795,33 +888,43 @@ fn test_filter_after_projection() -> Result<()> { Arc::new(Column::new("a", 0)), )), )); - let filter: Arc = Arc::new(FilterExec::try_new(predicate, csv)?); + let filter = Arc::new(FilterExec::try_new(predicate, csv)?); let projection: Arc = Arc::new(ProjectionExec::try_new( vec![ - (Arc::new(Column::new("a", 0)), "a_new".to_string()), - (Arc::new(Column::new("b", 1)), "b".to_string()), - (Arc::new(Column::new("d", 3)), "d".to_string()), + ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a_new"), + ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b"), + ProjectionExpr::new(Arc::new(Column::new("d", 3)), "d"), ], filter.clone(), - )?); + )?) 
as _; + + let initial = displayable(projection.as_ref()).indent(true).to_string(); + let actual = initial.trim(); - let initial = get_plan_string(&projection); - let expected_initial = [ - "ProjectionExec: expr=[a@0 as a_new, b@1 as b, d@3 as d]", - " FilterExec: b@1 - a@0 > d@3 - a@0", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", - ]; - assert_eq!(initial, expected_initial); + assert_snapshot!( + actual, + @r" + ProjectionExec: expr=[a@0 as a_new, b@1 as b, d@3 as d] + FilterExec: b@1 - a@0 > d@3 - a@0 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + " + ); let after_optimize = ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; - let expected = [ - "FilterExec: b@1 - a_new@0 > d@2 - a_new@0", - " ProjectionExec: expr=[a@0 as a_new, b@1 as b, d@3 as d]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", - ]; - assert_eq!(get_plan_string(&after_optimize), expected); + let after_optimize_string = displayable(after_optimize.as_ref()) + .indent(true) + .to_string(); + let actual = after_optimize_string.trim(); + + assert_snapshot!( + actual, + @r" + FilterExec: b@1 - a_new@0 > d@2 - a_new@0 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a@0 as a_new, b, d], file_type=csv, has_header=false + " + ); Ok(()) } @@ -875,41 +978,50 @@ fn test_join_after_projection() -> Result<()> { ])), )), &JoinType::Inner, - true, + NullEquality::NullEqualsNull, None, None, StreamJoinPartitionMode::SinglePartition, )?); let projection: Arc = Arc::new(ProjectionExec::try_new( vec![ - (Arc::new(Column::new("c", 2)), "c_from_left".to_string()), - (Arc::new(Column::new("b", 1)), "b_from_left".to_string()), - (Arc::new(Column::new("a", 0)), "a_from_left".to_string()), - (Arc::new(Column::new("a", 5)), "a_from_right".to_string()), - (Arc::new(Column::new("c", 7)), 
"c_from_right".to_string()), + ProjectionExpr::new(Arc::new(Column::new("c", 2)), "c_from_left"), + ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b_from_left"), + ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a_from_left"), + ProjectionExpr::new(Arc::new(Column::new("a", 5)), "a_from_right"), + ProjectionExpr::new(Arc::new(Column::new("c", 7)), "c_from_right"), ], join, - )?); - let initial = get_plan_string(&projection); - let expected_initial = [ - "ProjectionExec: expr=[c@2 as c_from_left, b@1 as b_from_left, a@0 as a_from_left, a@5 as a_from_right, c@7 as c_from_right]", - " SymmetricHashJoinExec: mode=SinglePartition, join_type=Inner, on=[(b@1, c@2)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false" - ]; - assert_eq!(initial, expected_initial); + )?) 
as _; + let initial = displayable(projection.as_ref()).indent(true).to_string(); + let actual = initial.trim(); + + assert_snapshot!( + actual, + @r" + ProjectionExec: expr=[c@2 as c_from_left, b@1 as b_from_left, a@0 as a_from_left, a@5 as a_from_right, c@7 as c_from_right] + SymmetricHashJoinExec: mode=SinglePartition, join_type=Inner, on=[(b@1, c@2)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + " + ); let after_optimize = ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; - let expected = [ - "SymmetricHashJoinExec: mode=SinglePartition, join_type=Inner, on=[(b_from_left@1, c_from_right@1)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2", - " ProjectionExec: expr=[c@2 as c_from_left, b@1 as b_from_left, a@0 as a_from_left]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", - " ProjectionExec: expr=[a@0 as a_from_right, c@2 as c_from_right]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false" - ]; - assert_eq!(get_plan_string(&after_optimize), expected); + let after_optimize_string = displayable(after_optimize.as_ref()) + .indent(true) + .to_string(); + let actual = after_optimize_string.trim(); + + assert_snapshot!( + actual, + @r" + SymmetricHashJoinExec: mode=SinglePartition, join_type=Inner, on=[(b_from_left@1, c_from_right@1)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[c@2 as c_from_left, b@1 as b_from_left, a@0 as a_from_left], file_type=csv, has_header=false + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a@0 as 
a_from_right, c@2 as c_from_right], file_type=csv, has_header=false + " + ); let expected_filter_col_ind = vec![ ColumnIndex { @@ -945,7 +1057,7 @@ fn test_join_after_required_projection() -> Result<()> { let left_csv = create_simple_csv_exec(); let right_csv = create_simple_csv_exec(); - let join: Arc = Arc::new(SymmetricHashJoinExec::try_new( + let join = Arc::new(SymmetricHashJoinExec::try_new( left_csv, right_csv, vec![(Arc::new(Column::new("b", 1)), Arc::new(Column::new("c", 2)))], @@ -989,45 +1101,56 @@ fn test_join_after_required_projection() -> Result<()> { ])), )), &JoinType::Inner, - true, + NullEquality::NullEqualsNull, None, None, StreamJoinPartitionMode::SinglePartition, )?); let projection: Arc = Arc::new(ProjectionExec::try_new( vec![ - (Arc::new(Column::new("a", 5)), "a".to_string()), - (Arc::new(Column::new("b", 6)), "b".to_string()), - (Arc::new(Column::new("c", 7)), "c".to_string()), - (Arc::new(Column::new("d", 8)), "d".to_string()), - (Arc::new(Column::new("e", 9)), "e".to_string()), - (Arc::new(Column::new("a", 0)), "a".to_string()), - (Arc::new(Column::new("b", 1)), "b".to_string()), - (Arc::new(Column::new("c", 2)), "c".to_string()), - (Arc::new(Column::new("d", 3)), "d".to_string()), - (Arc::new(Column::new("e", 4)), "e".to_string()), + ProjectionExpr::new(Arc::new(Column::new("a", 5)), "a"), + ProjectionExpr::new(Arc::new(Column::new("b", 6)), "b"), + ProjectionExpr::new(Arc::new(Column::new("c", 7)), "c"), + ProjectionExpr::new(Arc::new(Column::new("d", 8)), "d"), + ProjectionExpr::new(Arc::new(Column::new("e", 9)), "e"), + ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a"), + ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b"), + ProjectionExpr::new(Arc::new(Column::new("c", 2)), "c"), + ProjectionExpr::new(Arc::new(Column::new("d", 3)), "d"), + ProjectionExpr::new(Arc::new(Column::new("e", 4)), "e"), ], join, - )?); - let initial = get_plan_string(&projection); - let expected_initial = [ - "ProjectionExec: expr=[a@5 as a, b@6 
as b, c@7 as c, d@8 as d, e@9 as e, a@0 as a, b@1 as b, c@2 as c, d@3 as d, e@4 as e]", - " SymmetricHashJoinExec: mode=SinglePartition, join_type=Inner, on=[(b@1, c@2)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false" - ]; - assert_eq!(initial, expected_initial); + )?) as _; + let initial = displayable(projection.as_ref()).indent(true).to_string(); + let actual = initial.trim(); + + assert_snapshot!( + actual, + @r" + ProjectionExec: expr=[a@5 as a, b@6 as b, c@7 as c, d@8 as d, e@9 as e, a@0 as a, b@1 as b, c@2 as c, d@3 as d, e@4 as e] + SymmetricHashJoinExec: mode=SinglePartition, join_type=Inner, on=[(b@1, c@2)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + " + ); let after_optimize = ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; - let expected = [ - "ProjectionExec: expr=[a@5 as a, b@6 as b, c@7 as c, d@8 as d, e@9 as e, a@0 as a, b@1 as b, c@2 as c, d@3 as d, e@4 as e]", - " SymmetricHashJoinExec: mode=SinglePartition, join_type=Inner, on=[(b@1, c@2)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false" - ]; - assert_eq!(get_plan_string(&after_optimize), expected); + let after_optimize_string = displayable(after_optimize.as_ref()) + .indent(true) + .to_string(); + let actual = 
after_optimize_string.trim(); + + assert_snapshot!( + actual, + @r" + ProjectionExec: expr=[a@5 as a, b@6 as b, c@7 as c, d@8 as d, e@9 as e, a@0 as a, b@1 as b, c@2 as c, d@3 as d, e@4 as e] + SymmetricHashJoinExec: mode=SinglePartition, join_type=Inner, on=[(b@1, c@2)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + " + ); Ok(()) } @@ -1061,7 +1184,7 @@ fn test_nested_loop_join_after_projection() -> Result<()> { Field::new("c", DataType::Int32, true), ]); - let join: Arc = Arc::new(NestedLoopJoinExec::try_new( + let join = Arc::new(NestedLoopJoinExec::try_new( left_csv, right_csv, Some(JoinFilter::new( @@ -1071,29 +1194,39 @@ fn test_nested_loop_join_after_projection() -> Result<()> { )), &JoinType::Inner, None, - )?); + )?) as _; let projection: Arc = Arc::new(ProjectionExec::try_new( - vec![(col_left_c, "c".to_string())], + vec![ProjectionExpr::new(col_left_c, "c")], Arc::clone(&join), - )?); - let initial = get_plan_string(&projection); - let expected_initial = [ - "ProjectionExec: expr=[c@2 as c]", - " NestedLoopJoinExec: join_type=Inner, filter=a@0 < b@1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", - ]; - assert_eq!(initial, expected_initial); + )?) 
as _; + let initial = displayable(projection.as_ref()).indent(true).to_string(); + let actual = initial.trim(); + assert_snapshot!( + actual, + @r" + ProjectionExec: expr=[c@2 as c] + NestedLoopJoinExec: join_type=Inner, filter=a@0 < b@1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + " + ); - let after_optimize = + let after_optimize_string = ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; - let expected = [ - "NestedLoopJoinExec: join_type=Inner, filter=a@0 < b@1, projection=[c@2]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", - ]; - assert_eq!(get_plan_string(&after_optimize), expected); + let after_optimize_string = displayable(after_optimize_string.as_ref()) + .indent(true) + .to_string(); + let actual = after_optimize_string.trim(); + assert_snapshot!( + actual, + @r" + NestedLoopJoinExec: join_type=Inner, filter=a@0 < b@1, projection=[c@2] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + " + + ); Ok(()) } @@ -1104,7 +1237,7 @@ fn test_hash_join_after_projection() -> Result<()> { let left_csv = create_simple_csv_exec(); let right_csv = create_simple_csv_exec(); - let join: Arc = Arc::new(HashJoinExec::try_new( + let join = Arc::new(HashJoinExec::try_new( left_csv, right_csv, vec![(Arc::new(Column::new("b", 1)), Arc::new(Column::new("c", 2)))], @@ -1150,46 +1283,74 @@ fn test_hash_join_after_projection() -> Result<()> { &JoinType::Inner, None, PartitionMode::Auto, - true, + NullEquality::NullEqualsNothing, + false, )?); let 
projection: Arc = Arc::new(ProjectionExec::try_new( vec![ - (Arc::new(Column::new("c", 2)), "c_from_left".to_string()), - (Arc::new(Column::new("b", 1)), "b_from_left".to_string()), - (Arc::new(Column::new("a", 0)), "a_from_left".to_string()), - (Arc::new(Column::new("c", 7)), "c_from_right".to_string()), + ProjectionExpr::new(Arc::new(Column::new("c", 2)), "c_from_left"), + ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b_from_left"), + ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a_from_left"), + ProjectionExpr::new(Arc::new(Column::new("c", 7)), "c_from_right"), ], join.clone(), - )?); - let initial = get_plan_string(&projection); - let expected_initial = [ - "ProjectionExec: expr=[c@2 as c_from_left, b@1 as b_from_left, a@0 as a_from_left, c@7 as c_from_right]", " HashJoinExec: mode=Auto, join_type=Inner, on=[(b@1, c@2)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false" - ]; - assert_eq!(initial, expected_initial); + )?) 
as _; + let initial = displayable(projection.as_ref()).indent(true).to_string(); + let actual = initial.trim(); + assert_snapshot!( + actual, + @r" + ProjectionExec: expr=[c@2 as c_from_left, b@1 as b_from_left, a@0 as a_from_left, c@7 as c_from_right] + HashJoinExec: mode=Auto, join_type=Inner, on=[(b@1, c@2)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + " + ); let after_optimize = ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; + let after_optimize_string = displayable(after_optimize.as_ref()) + .indent(true) + .to_string(); + let actual = after_optimize_string.trim(); // HashJoinExec only returns result after projection. Because there are some alias columns in the projection, the ProjectionExec is not removed. - let expected = ["ProjectionExec: expr=[c@2 as c_from_left, b@1 as b_from_left, a@0 as a_from_left, c@3 as c_from_right]", " HashJoinExec: mode=Auto, join_type=Inner, on=[(b@1, c@2)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2, projection=[a@0, b@1, c@2, c@7]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false"]; - assert_eq!(get_plan_string(&after_optimize), expected); + assert_snapshot!( + actual, + @r" + ProjectionExec: expr=[c@2 as c_from_left, b@1 as b_from_left, a@0 as a_from_left, c@3 as c_from_right] + HashJoinExec: mode=Auto, join_type=Inner, on=[(b@1, c@2)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2, projection=[a@0, b@1, c@2, c@7] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, 
has_header=false + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + " + ); - let projection: Arc = Arc::new(ProjectionExec::try_new( + let projection = Arc::new(ProjectionExec::try_new( vec![ - (Arc::new(Column::new("a", 0)), "a".to_string()), - (Arc::new(Column::new("b", 1)), "b".to_string()), - (Arc::new(Column::new("c", 2)), "c".to_string()), - (Arc::new(Column::new("c", 7)), "c".to_string()), + ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a"), + ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b"), + ProjectionExpr::new(Arc::new(Column::new("c", 2)), "c"), + ProjectionExpr::new(Arc::new(Column::new("c", 7)), "c"), ], join.clone(), )?); let after_optimize = ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; + let after_optimize_string = displayable(after_optimize.as_ref()) + .indent(true) + .to_string(); + let actual = after_optimize_string.trim(); // Comparing to the previous result, this projection don't have alias columns either change the order of output fields. So the ProjectionExec is removed. 
- let expected = ["HashJoinExec: mode=Auto, join_type=Inner, on=[(b@1, c@2)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2, projection=[a@0, b@1, c@2, c@7]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false"]; - assert_eq!(get_plan_string(&after_optimize), expected); + assert_snapshot!( + actual, + @r" + HashJoinExec: mode=Auto, join_type=Inner, on=[(b@1, c@2)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2, projection=[a@0, b@1, c@2, c@7] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + " + ); Ok(()) } @@ -1197,7 +1358,7 @@ fn test_hash_join_after_projection() -> Result<()> { #[test] fn test_repartition_after_projection() -> Result<()> { let csv = create_simple_csv_exec(); - let repartition: Arc = Arc::new(RepartitionExec::try_new( + let repartition = Arc::new(RepartitionExec::try_new( csv, Partitioning::Hash( vec![ @@ -1210,29 +1371,37 @@ fn test_repartition_after_projection() -> Result<()> { )?); let projection: Arc = Arc::new(ProjectionExec::try_new( vec![ - (Arc::new(Column::new("b", 1)), "b_new".to_string()), - (Arc::new(Column::new("a", 0)), "a".to_string()), - (Arc::new(Column::new("d", 3)), "d_new".to_string()), + ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b_new"), + ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a"), + ProjectionExpr::new(Arc::new(Column::new("d", 3)), "d_new"), ], repartition, - )?); - let initial = get_plan_string(&projection); - let expected_initial = [ - "ProjectionExec: expr=[b@1 as b_new, a@0 as a, d@3 as d_new]", - " RepartitionExec: partitioning=Hash([a@0, b@1, d@3], 6), input_partitions=1", - " DataSourceExec: 
file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", - ]; - assert_eq!(initial, expected_initial); + )?) as _; + let initial = displayable(projection.as_ref()).indent(true).to_string(); + let actual = initial.trim(); + assert_snapshot!( + actual, + @r" + ProjectionExec: expr=[b@1 as b_new, a@0 as a, d@3 as d_new] + RepartitionExec: partitioning=Hash([a@0, b@1, d@3], 6), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + " + ); let after_optimize = ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; - let expected = [ - "RepartitionExec: partitioning=Hash([a@1, b_new@0, d_new@2], 6), input_partitions=1", - " ProjectionExec: expr=[b@1 as b_new, a@0 as a, d@3 as d_new]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", - ]; - assert_eq!(get_plan_string(&after_optimize), expected); + let after_optimize_string = displayable(after_optimize.as_ref()) + .indent(true) + .to_string(); + let actual = after_optimize_string.trim(); + assert_snapshot!( + actual, + @r" + RepartitionExec: partitioning=Hash([a@1, b_new@0, d_new@2], 6), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[b@1 as b_new, a, d@3 as d_new], file_type=csv, has_header=false + " + ); assert_eq!( after_optimize @@ -1257,49 +1426,52 @@ fn test_repartition_after_projection() -> Result<()> { #[test] fn test_sort_after_projection() -> Result<()> { let csv = create_simple_csv_exec(); - let sort_req: Arc = Arc::new(SortExec::new( - LexOrdering::new(vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("b", 1)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(BinaryExpr::new( - Arc::new(Column::new("c", 2)), - Operator::Plus, - Arc::new(Column::new("a", 0)), - )), - options: SortOptions::default(), - }, - ]), - csv.clone(), - )); + let sort_exec = 
SortExec::new( + [ + PhysicalSortExpr::new_default(Arc::new(Column::new("b", 1))), + PhysicalSortExpr::new_default(Arc::new(BinaryExpr::new( + Arc::new(Column::new("c", 2)), + Operator::Plus, + Arc::new(Column::new("a", 0)), + ))), + ] + .into(), + csv, + ); let projection: Arc = Arc::new(ProjectionExec::try_new( vec![ - (Arc::new(Column::new("c", 2)), "c".to_string()), - (Arc::new(Column::new("a", 0)), "new_a".to_string()), - (Arc::new(Column::new("b", 1)), "b".to_string()), + ProjectionExpr::new(Arc::new(Column::new("c", 2)), "c"), + ProjectionExpr::new(Arc::new(Column::new("a", 0)), "new_a"), + ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b"), ], - sort_req.clone(), - )?); + Arc::new(sort_exec), + )?) as _; - let initial = get_plan_string(&projection); - let expected_initial = [ - "ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", - " SortExec: expr=[b@1 ASC, c@2 + a@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false" - ]; - assert_eq!(initial, expected_initial); + let initial = displayable(projection.as_ref()).indent(true).to_string(); + let actual = initial.trim(); + assert_snapshot!( + actual, + @r" + ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b] + SortExec: expr=[b@1 ASC, c@2 + a@0 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + " + ); let after_optimize = ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; - let expected = [ - "SortExec: expr=[b@2 ASC, c@0 + new_a@1 ASC], preserve_partitioning=[false]", - " ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false" - ]; - assert_eq!(get_plan_string(&after_optimize), expected); + let after_optimize_string = displayable(after_optimize.as_ref()) + .indent(true) + 
.to_string(); + let actual = after_optimize_string.trim(); + assert_snapshot!( + actual, + @r" + SortExec: expr=[b@2 ASC, c@0 + new_a@1 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[c, a@0 as new_a, b], file_type=csv, has_header=false + " + ); Ok(()) } @@ -1307,49 +1479,52 @@ fn test_sort_after_projection() -> Result<()> { #[test] fn test_sort_preserving_after_projection() -> Result<()> { let csv = create_simple_csv_exec(); - let sort_req: Arc = Arc::new(SortPreservingMergeExec::new( - LexOrdering::new(vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("b", 1)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(BinaryExpr::new( - Arc::new(Column::new("c", 2)), - Operator::Plus, - Arc::new(Column::new("a", 0)), - )), - options: SortOptions::default(), - }, - ]), - csv.clone(), - )); + let sort_exec = SortPreservingMergeExec::new( + [ + PhysicalSortExpr::new_default(Arc::new(Column::new("b", 1))), + PhysicalSortExpr::new_default(Arc::new(BinaryExpr::new( + Arc::new(Column::new("c", 2)), + Operator::Plus, + Arc::new(Column::new("a", 0)), + ))), + ] + .into(), + csv, + ); let projection: Arc = Arc::new(ProjectionExec::try_new( vec![ - (Arc::new(Column::new("c", 2)), "c".to_string()), - (Arc::new(Column::new("a", 0)), "new_a".to_string()), - (Arc::new(Column::new("b", 1)), "b".to_string()), + ProjectionExpr::new(Arc::new(Column::new("c", 2)), "c"), + ProjectionExpr::new(Arc::new(Column::new("a", 0)), "new_a"), + ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b"), ], - sort_req.clone(), - )?); + Arc::new(sort_exec), + )?) 
as _; - let initial = get_plan_string(&projection); - let expected_initial = [ - "ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", - " SortPreservingMergeExec: [b@1 ASC, c@2 + a@0 ASC]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false" - ]; - assert_eq!(initial, expected_initial); + let initial = displayable(projection.as_ref()).indent(true).to_string(); + let actual = initial.trim(); + assert_snapshot!( + actual, + @r" + ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b] + SortPreservingMergeExec: [b@1 ASC, c@2 + a@0 ASC] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + " + ); let after_optimize = ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; - let expected = [ - "SortPreservingMergeExec: [b@2 ASC, c@0 + new_a@1 ASC]", - " ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false" - ]; - assert_eq!(get_plan_string(&after_optimize), expected); + let after_optimize_string = displayable(after_optimize.as_ref()) + .indent(true) + .to_string(); + let actual = after_optimize_string.trim(); + assert_snapshot!( + actual, + @r" + SortPreservingMergeExec: [b@2 ASC, c@0 + new_a@1 ASC] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[c, a@0 as new_a, b], file_type=csv, has_header=false + " + ); Ok(()) } @@ -1357,40 +1532,45 @@ fn test_sort_preserving_after_projection() -> Result<()> { #[test] fn test_union_after_projection() -> Result<()> { let csv = create_simple_csv_exec(); - let union: Arc = - Arc::new(UnionExec::new(vec![csv.clone(), csv.clone(), csv])); + let union = UnionExec::try_new(vec![csv.clone(), csv.clone(), csv])?; let projection: Arc = Arc::new(ProjectionExec::try_new( vec![ - (Arc::new(Column::new("c", 2)), "c".to_string()), - (Arc::new(Column::new("a", 0)), "new_a".to_string()), - 
(Arc::new(Column::new("b", 1)), "b".to_string()), + ProjectionExpr::new(Arc::new(Column::new("c", 2)), "c"), + ProjectionExpr::new(Arc::new(Column::new("a", 0)), "new_a"), + ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b"), ], union.clone(), - )?); + )?) as _; - let initial = get_plan_string(&projection); - let expected_initial = [ - "ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", - " UnionExec", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false" - ]; - assert_eq!(initial, expected_initial); + let initial = displayable(projection.as_ref()).indent(true).to_string(); + let actual = initial.trim(); + assert_snapshot!( + actual, + @r" + ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b] + UnionExec + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + " + ); let after_optimize = ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; - let expected = [ - "UnionExec", - " ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", - " ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", - " ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false" - ]; - 
assert_eq!(get_plan_string(&after_optimize), expected); + let after_optimize_string = displayable(after_optimize.as_ref()) + .indent(true) + .to_string(); + let actual = after_optimize_string.trim(); + assert_snapshot!( + actual, + @r" + UnionExec + DataSourceExec: file_groups={1 group: [[x]]}, projection=[c, a@0 as new_a, b], file_type=csv, has_header=false + DataSourceExec: file_groups={1 group: [[x]]}, projection=[c, a@0 as new_a, b], file_type=csv, has_header=false + DataSourceExec: file_groups={1 group: [[x]]}, projection=[c, a@0 as new_a, b], file_type=csv, has_header=false + " + ); Ok(()) } @@ -1403,14 +1583,23 @@ fn partitioned_data_source() -> Arc { Field::new("string_col", DataType::Utf8, true), ])); + let options = CsvOptions { + has_header: Some(false), + delimiter: b',', + quote: b'"', + ..Default::default() + }; + let table_schema = TableSchema::new( + Arc::clone(&file_schema), + vec![Arc::new(Field::new("partition_col", DataType::Utf8, true))], + ); let config = FileScanConfigBuilder::new( ObjectStoreUrl::parse("test:///").unwrap(), - file_schema.clone(), - Arc::new(CsvSource::default()), + Arc::new(CsvSource::new(table_schema).with_csv_options(options)), ) - .with_file(PartitionedFile::new("x".to_string(), 100)) - .with_table_partition_cols(vec![Field::new("partition_col", DataType::Utf8, true)]) - .with_projection(Some(vec![0, 1, 2])) + .with_file(PartitionedFile::new("x", 100)) + .with_projection_indices(Some(vec![0, 1, 2])) + .unwrap() .build(); DataSourceExec::from_data_source(config) @@ -1421,20 +1610,17 @@ fn test_partition_col_projection_pushdown() -> Result<()> { let source = partitioned_data_source(); let partitioned_schema = source.schema(); - let projection = Arc::new(ProjectionExec::try_new( + let projection: Arc = Arc::new(ProjectionExec::try_new( vec![ - ( + ProjectionExpr::new( col("string_col", partitioned_schema.as_ref())?, - "string_col".to_string(), + "string_col", ), - ( + ProjectionExpr::new( col("partition_col", 
partitioned_schema.as_ref())?, - "partition_col".to_string(), - ), - ( - col("int_col", partitioned_schema.as_ref())?, - "int_col".to_string(), + "partition_col", ), + ProjectionExpr::new(col("int_col", partitioned_schema.as_ref())?, "int_col"), ], source, )?); @@ -1442,11 +1628,14 @@ fn test_partition_col_projection_pushdown() -> Result<()> { let after_optimize = ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; - let expected = [ - "ProjectionExec: expr=[string_col@1 as string_col, partition_col@2 as partition_col, int_col@0 as int_col]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[int_col, string_col, partition_col], file_type=csv, has_header=false" - ]; - assert_eq!(get_plan_string(&after_optimize), expected); + let after_optimize_string = displayable(after_optimize.as_ref()) + .indent(true) + .to_string(); + let actual = after_optimize_string.trim(); + assert_snapshot!( + actual, + @"DataSourceExec: file_groups={1 group: [[x]]}, projection=[string_col, partition_col, int_col], file_type=csv, has_header=false" + ); Ok(()) } @@ -1456,25 +1645,22 @@ fn test_partition_col_projection_pushdown_expr() -> Result<()> { let source = partitioned_data_source(); let partitioned_schema = source.schema(); - let projection = Arc::new(ProjectionExec::try_new( + let projection: Arc = Arc::new(ProjectionExec::try_new( vec![ - ( + ProjectionExpr::new( col("string_col", partitioned_schema.as_ref())?, - "string_col".to_string(), + "string_col", ), - ( + ProjectionExpr::new( // CAST(partition_col, Utf8View) cast( col("partition_col", partitioned_schema.as_ref())?, partitioned_schema.as_ref(), DataType::Utf8View, )?, - "partition_col".to_string(), - ), - ( - col("int_col", partitioned_schema.as_ref())?, - "int_col".to_string(), + "partition_col", ), + ProjectionExpr::new(col("int_col", partitioned_schema.as_ref())?, "int_col"), ], source, )?); @@ -1482,11 +1668,102 @@ fn test_partition_col_projection_pushdown_expr() -> Result<()> { let 
after_optimize = ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; - let expected = [ - "ProjectionExec: expr=[string_col@1 as string_col, CAST(partition_col@2 AS Utf8View) as partition_col, int_col@0 as int_col]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[int_col, string_col, partition_col], file_type=csv, has_header=false" - ]; - assert_eq!(get_plan_string(&after_optimize), expected); + let after_optimize_string = displayable(after_optimize.as_ref()) + .indent(true) + .to_string(); + let actual = after_optimize_string.trim(); + assert_snapshot!( + actual, + @"DataSourceExec: file_groups={1 group: [[x]]}, projection=[string_col, CAST(partition_col@2 AS Utf8View) as partition_col, int_col], file_type=csv, has_header=false" + ); + + Ok(()) +} + +#[test] +fn test_cooperative_exec_after_projection() -> Result<()> { + let csv = create_simple_csv_exec(); + let cooperative: Arc = Arc::new(CooperativeExec::new(csv)); + let projection: Arc = Arc::new(ProjectionExec::try_new( + vec![ + ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a"), + ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b"), + ], + cooperative, + )?); + + let initial = displayable(projection.as_ref()).indent(true).to_string(); + let actual = initial.trim(); + + assert_snapshot!( + actual, + @r" + ProjectionExec: expr=[a@0 as a, b@1 as b] + CooperativeExec + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + " + ); + + let after_optimize = + ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; + + let after_optimize_string = displayable(after_optimize.as_ref()) + .indent(true) + .to_string(); + let actual = after_optimize_string.trim(); + + // Projection should be pushed down through CooperativeExec + assert_snapshot!( + actual, + @r" + CooperativeExec + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b], file_type=csv, has_header=false + " + ); + + Ok(()) +} + +#[test] 
+fn test_hash_join_empty_projection_embeds() -> Result<()> { + let left_csv = create_simple_csv_exec(); + let right_csv = create_simple_csv_exec(); + + let join = Arc::new(HashJoinExec::try_new( + left_csv, + right_csv, + vec![(Arc::new(Column::new("a", 0)), Arc::new(Column::new("a", 0)))], + None, + &JoinType::Right, + None, + PartitionMode::CollectLeft, + NullEquality::NullEqualsNothing, + false, + )?); + + // Empty projection: no columns needed from the join output + let projection: Arc = Arc::new(ProjectionExec::try_new( + vec![] as Vec, + join, + )?); + + let after_optimize = + ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; + let after_optimize_string = displayable(after_optimize.as_ref()) + .indent(true) + .to_string(); + let actual = after_optimize_string.trim(); + + // The empty projection should be embedded into the HashJoinExec, + // resulting in projection=[] on the join and no ProjectionExec wrapper. + assert_snapshot!( + actual, + @r" + HashJoinExec: mode=CollectLeft, join_type=Right, on=[(a@0, a@0)], projection=[] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + " + ); Ok(()) } diff --git a/datafusion/core/tests/physical_optimizer/pushdown_sort.rs b/datafusion/core/tests/physical_optimizer/pushdown_sort.rs new file mode 100644 index 0000000000000..d6fd4d8d00ae4 --- /dev/null +++ b/datafusion/core/tests/physical_optimizer/pushdown_sort.rs @@ -0,0 +1,998 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Tests for sort pushdown optimizer rule (Phase 1) +//! +//! Phase 1 tests verify that: +//! 1. Reverse scan is enabled (reverse_row_groups=true) +//! 2. SortExec is kept (because ordering is inexact) +//! 3. output_ordering remains unchanged +//! 4. Early termination is enabled for TopK queries +//! 5. Prefix matching works correctly + +use datafusion_physical_expr::expressions; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_physical_expr_common::sort_expr::LexOrdering; +use datafusion_physical_optimizer::PhysicalOptimizerRule; +use datafusion_physical_optimizer::pushdown_sort::PushdownSort; +use std::sync::Arc; + +use crate::physical_optimizer::test_utils::{ + OptimizationTest, coalesce_partitions_exec, parquet_exec, parquet_exec_with_sort, + projection_exec, projection_exec_with_alias, repartition_exec, schema, + simple_projection_exec, sort_exec, sort_exec_with_fetch, sort_expr, sort_expr_named, + test_scan_with_ordering, +}; + +#[test] +fn test_sort_pushdown_disabled() { + // When pushdown is disabled, plan should remain unchanged + let schema = schema(); + let source = parquet_exec(schema.clone()); + let sort_exprs = LexOrdering::new(vec![sort_expr("a", &schema)]).unwrap(); + let plan = sort_exec(sort_exprs, source); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), false), + @r" + OptimizationTest: + input: + - SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + output: + 
Ok: + - SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + " + ); +} + +#[test] +fn test_sort_pushdown_basic_phase1() { + // Phase 1: Reverse scan enabled, Sort kept, output_ordering unchanged + let schema = schema(); + + // Source has ASC NULLS LAST ordering (default) + let a = sort_expr("a", &schema); + let source_ordering = LexOrdering::new(vec![a.clone()]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + + // Request DESC NULLS LAST ordering (exact reverse) + let desc_ordering = LexOrdering::new(vec![a.reverse()]).unwrap(); + let plan = sort_exec(desc_ordering, source); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + output: + Ok: + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true + " + ); +} + +#[test] +fn test_sort_with_limit_phase1() { + // Phase 1: Sort with fetch enables early termination but keeps Sort + let schema = schema(); + + // Source has ASC ordering + let a = sort_expr("a", &schema); + let source_ordering = LexOrdering::new(vec![a.clone()]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + + // Request DESC ordering with limit + let desc_ordering = LexOrdering::new(vec![a.reverse()]).unwrap(); + let plan = sort_exec_with_fetch(desc_ordering, Some(10), source); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: TopK(fetch=10), expr=[a@0 DESC NULLS LAST], 
preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + output: + Ok: + - SortExec: TopK(fetch=10), expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true + " + ); +} + +#[test] +fn test_sort_multiple_columns_phase1() { + // Phase 1: Sort on multiple columns - reverse multi-column ordering + let schema = schema(); + + // Source has [a DESC NULLS LAST, b ASC] ordering + let a = sort_expr("a", &schema); + let b = sort_expr("b", &schema); + let source_ordering = LexOrdering::new(vec![a.clone().reverse(), b.clone()]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + + // Request [a ASC NULLS FIRST, b DESC] ordering (exact reverse) + let reverse_ordering = + LexOrdering::new(vec![a.clone().asc().nulls_first(), b.reverse()]).unwrap(); + let plan = sort_exec(reverse_ordering, source); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[a@0 ASC, b@1 DESC NULLS LAST], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 DESC NULLS LAST, b@1 ASC], file_type=parquet + output: + Ok: + - SortExec: expr=[a@0 ASC, b@1 DESC NULLS LAST], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true + " + ); +} + +// ============================================================================ +// PREFIX MATCHING TESTS +// ============================================================================ + +#[test] +fn test_prefix_match_single_column() { + // Test prefix matching: source has [a DESC, b ASC], query needs [a ASC] + // After reverse: [a ASC, b DESC] which satisfies [a 
ASC] prefix + let schema = schema(); + + // Source has [a DESC NULLS LAST, b ASC NULLS LAST] ordering + let a = sort_expr("a", &schema); + let b = sort_expr("b", &schema); + let source_ordering = LexOrdering::new(vec![a.clone().reverse(), b]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + + // Request only [a ASC NULLS FIRST] - a prefix of the reversed ordering + let prefix_ordering = LexOrdering::new(vec![a.clone().asc().nulls_first()]).unwrap(); + let plan = sort_exec(prefix_ordering, source); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 DESC NULLS LAST, b@1 ASC], file_type=parquet + output: + Ok: + - SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true + " + ); +} + +#[test] +fn test_prefix_match_with_limit() { + // Test prefix matching with LIMIT - important for TopK optimization + let schema = schema(); + + // Source has [a ASC, b DESC, c ASC] ordering + let a = sort_expr("a", &schema); + let b = sort_expr("b", &schema); + let c = sort_expr("c", &schema); + let source_ordering = + LexOrdering::new(vec![a.clone(), b.clone().reverse(), c]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + + // Request [a DESC NULLS LAST, b ASC NULLS FIRST] with LIMIT 100 + // This is a prefix (2 columns) of the reversed 3-column ordering + let prefix_ordering = + LexOrdering::new(vec![a.reverse(), b.clone().asc().nulls_first()]).unwrap(); + let plan = sort_exec_with_fetch(prefix_ordering, Some(100), source); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: 
TopK(fetch=100), expr=[a@0 DESC NULLS LAST, b@1 ASC], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC, b@1 DESC NULLS LAST, c@2 ASC], file_type=parquet + output: + Ok: + - SortExec: TopK(fetch=100), expr=[a@0 DESC NULLS LAST, b@1 ASC], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true + " + ); +} + +#[test] +fn test_prefix_match_through_transparent_nodes() { + // Test prefix matching works through transparent nodes + let schema = schema(); + + // Source has [a DESC NULLS LAST, b ASC, c DESC] ordering + let a = sort_expr("a", &schema); + let b = sort_expr("b", &schema); + let c = sort_expr("c", &schema); + let source_ordering = + LexOrdering::new(vec![a.clone().reverse(), b, c.reverse()]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + let repartition = repartition_exec(source); + + // Request only [a ASC NULLS FIRST] - prefix of reversed ordering + let prefix_ordering = LexOrdering::new(vec![a.clone().asc().nulls_first()]).unwrap(); + let plan = sort_exec(prefix_ordering, repartition); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 DESC NULLS LAST, b@1 ASC, c@2 DESC NULLS LAST], file_type=parquet + output: + Ok: + - SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true + " + ); +} + +#[test] +fn 
test_no_prefix_match_wrong_direction() { + // Test that prefix matching does NOT work if the direction is wrong + let schema = schema(); + + // Source has [a DESC, b ASC] ordering + let a = sort_expr("a", &schema); + let b = sort_expr("b", &schema); + let source_ordering = LexOrdering::new(vec![a.clone().reverse(), b]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + + // Request [a DESC] - same direction as source, NOT a reverse prefix + let same_direction = LexOrdering::new(vec![a.clone().reverse()]).unwrap(); + let plan = sort_exec(same_direction, source); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 DESC NULLS LAST, b@1 ASC], file_type=parquet + output: + Ok: + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 DESC NULLS LAST, b@1 ASC], file_type=parquet + " + ); +} + +#[test] +fn test_no_prefix_match_longer_than_source() { + // Test that prefix matching does NOT work if requested is longer than source + let schema = schema(); + + // Source has [a DESC] ordering (single column) + let a = sort_expr("a", &schema); + let b = sort_expr("b", &schema); + let source_ordering = LexOrdering::new(vec![a.clone().reverse()]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + + // Request [a ASC, b DESC] - longer than source, can't be a prefix + let longer_ordering = + LexOrdering::new(vec![a.clone().asc().nulls_first(), b.reverse()]).unwrap(); + let plan = sort_exec(longer_ordering, source); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[a@0 
ASC, b@1 DESC NULLS LAST], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 DESC NULLS LAST], file_type=parquet + output: + Ok: + - SortExec: expr=[a@0 ASC, b@1 DESC NULLS LAST], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 DESC NULLS LAST], file_type=parquet + " + ); +} + +// ============================================================================ +// ORIGINAL TESTS +// ============================================================================ + +#[test] +fn test_sort_through_repartition() { + // Sort should push through RepartitionExec + let schema = schema(); + let a = sort_expr("a", &schema); + let source_ordering = LexOrdering::new(vec![a.clone()]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + let repartition = repartition_exec(source); + + let desc_ordering = LexOrdering::new(vec![a.reverse()]).unwrap(); + let plan = sort_exec(desc_ordering, repartition); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + output: + Ok: + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true + " + ); +} + +#[test] +fn test_nested_sorts() { + // Nested sort operations - only innermost can be optimized + let schema = schema(); + let a = sort_expr("a", &schema); + let b = sort_expr("b", &schema); + 
let source_ordering = LexOrdering::new(vec![a.clone()]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + + let desc_ordering = LexOrdering::new(vec![a.reverse()]).unwrap(); + let inner_sort = sort_exec(desc_ordering, source); + + let sort_exprs2 = LexOrdering::new(vec![b]).unwrap(); + let plan = sort_exec(sort_exprs2, inner_sort); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[b@1 ASC], preserve_partitioning=[false] + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + output: + Ok: + - SortExec: expr=[b@1 ASC], preserve_partitioning=[false] + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true + " + ); +} + +#[test] +fn test_non_sort_plans_unchanged() { + // Plans without SortExec should pass through unchanged + let schema = schema(); + let plan = parquet_exec(schema.clone()); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + output: + Ok: + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + " + ); +} + +#[test] +fn test_optimizer_properties() { + // Test optimizer metadata + let optimizer = PushdownSort::new(); + + assert_eq!(optimizer.name(), "PushdownSort"); + assert!(optimizer.schema_check()); +} + +#[test] +fn test_sort_through_coalesce_partitions() { + // Sort should push through CoalescePartitionsExec + let schema = schema(); + let a = sort_expr("a", &schema); + let source_ordering = LexOrdering::new(vec![a.clone()]).unwrap(); + let 
source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + let repartition = repartition_exec(source); + let coalesce_parts = coalesce_partitions_exec(repartition); + + let desc_ordering = LexOrdering::new(vec![a.reverse()]).unwrap(); + let plan = sort_exec(desc_ordering, coalesce_parts); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - CoalescePartitionsExec + - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + output: + Ok: + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - CoalescePartitionsExec + - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true + " + ); +} + +#[test] +fn test_complex_plan_with_multiple_operators() { + // Test a complex plan with multiple operators between sort and source + let schema = schema(); + let a = sort_expr("a", &schema); + let source_ordering = LexOrdering::new(vec![a.clone()]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + let repartition = repartition_exec(source); + let coalesce_parts = coalesce_partitions_exec(repartition); + + let desc_ordering = LexOrdering::new(vec![a.reverse()]).unwrap(); + let plan = sort_exec(desc_ordering, coalesce_parts); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - CoalescePartitionsExec + - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true + - DataSourceExec: 
file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + output: + Ok: + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - CoalescePartitionsExec + - RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true + " + ); +} + +#[test] +fn test_multiple_sorts_different_columns() { + // Test nested sorts on different columns - only innermost can optimize + let schema = schema(); + let a = sort_expr("a", &schema); + let c = sort_expr("c", &schema); + let source_ordering = LexOrdering::new(vec![a.clone()]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + + // First sort by column 'a' DESC (reverse of source) + let desc_ordering = LexOrdering::new(vec![a.reverse()]).unwrap(); + let sort1 = sort_exec(desc_ordering, source); + + // Then sort by column 'c' (different column, can't optimize) + let sort_exprs2 = LexOrdering::new(vec![c]).unwrap(); + let plan = sort_exec(sort_exprs2, sort1); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[c@2 ASC], preserve_partitioning=[false] + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + output: + Ok: + - SortExec: expr=[c@2 ASC], preserve_partitioning=[false] + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true + " + ); +} + +#[test] +fn test_no_pushdown_for_unordered_source() { + // Verify pushdown does NOT happen for sources without ordering + let schema = schema(); + let source = parquet_exec(schema.clone()); // No 
output_ordering + let sort_exprs = LexOrdering::new(vec![sort_expr("a", &schema)]).unwrap(); + let plan = sort_exec(sort_exprs, source); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + output: + Ok: + - SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet + " + ); +} + +#[test] +fn test_no_pushdown_for_non_reverse_sort() { + // Verify pushdown does NOT happen when sort doesn't reverse source ordering + let schema = schema(); + + // Source sorted by 'a' ASC + let a = sort_expr("a", &schema); + let b = sort_expr("b", &schema); + let source_ordering = LexOrdering::new(vec![a]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + + // Request sort by 'b' (different column) + let sort_exprs = LexOrdering::new(vec![b]).unwrap(); + let plan = sort_exec(sort_exprs, source); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[b@1 ASC], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + output: + Ok: + - SortExec: expr=[b@1 ASC], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + " + ); +} + +#[test] +fn test_pushdown_through_blocking_node() { + // Test that pushdown works for inner sort even when outer sort is blocked + // Structure: Sort -> Aggregate (blocks pushdown) -> Sort -> Scan + // The outer sort can't push through aggregate, but the inner sort should still optimize + use 
datafusion_functions_aggregate::count::count_udaf; + use datafusion_physical_expr::aggregate::AggregateExprBuilder; + use datafusion_physical_plan::aggregates::{ + AggregateExec, AggregateMode, PhysicalGroupBy, + }; + use std::sync::Arc; + + let schema = schema(); + + // Bottom: DataSource with [a ASC NULLS FIRST] ordering + let a = sort_expr("a", &schema); + let source_ordering = LexOrdering::new(vec![a.clone()]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + + // Inner Sort: [a DESC NULLS LAST] - exact reverse, CAN push down to source + let inner_sort_ordering = LexOrdering::new(vec![a.clone().reverse()]).unwrap(); + let inner_sort = sort_exec(inner_sort_ordering, source); + + // Middle: Aggregate (blocks pushdown from outer sort) + // GROUP BY a, COUNT(b) + let group_by = PhysicalGroupBy::new_single(vec![( + Arc::new(expressions::Column::new("a", 0)) as _, + "a".to_string(), + )]); + + let count_expr = Arc::new( + AggregateExprBuilder::new( + count_udaf(), + vec![Arc::new(expressions::Column::new("b", 1)) as _], + ) + .schema(Arc::clone(&schema)) + .alias("COUNT(b)") + .build() + .unwrap(), + ); + + let aggregate = Arc::new( + AggregateExec::try_new( + AggregateMode::Final, + group_by, + vec![count_expr], + vec![None], + inner_sort, + Arc::clone(&schema), + ) + .unwrap(), + ); + + // Outer Sort: [a ASC] - this CANNOT push down through aggregate + let outer_sort_ordering = LexOrdering::new(vec![a.clone()]).unwrap(); + let plan = sort_exec(outer_sort_ordering, aggregate); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + - AggregateExec: mode=Final, gby=[a@0 as a], aggr=[COUNT(b)], ordering_mode=Sorted + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], 
file_type=parquet + output: + Ok: + - SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + - AggregateExec: mode=Final, gby=[a@0 as a], aggr=[COUNT(b)], ordering_mode=Sorted + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true + " + ); +} + +// ============================================================================ +// PROJECTION TESTS +// ============================================================================ + +#[test] +fn test_sort_pushdown_through_simple_projection() { + // Sort pushes through projection with simple column references + let schema = schema(); + + // Source has [a ASC] ordering + let a = sort_expr("a", &schema); + let source_ordering = LexOrdering::new(vec![a.clone()]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + + // Projection: SELECT a, b (simple column references) + let projection = simple_projection_exec(source, vec![0, 1]); // columns a, b + + // Request [a DESC] - should push through projection to source + let desc_ordering = LexOrdering::new(vec![a.reverse()]).unwrap(); + let plan = sort_exec(desc_ordering, projection); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - ProjectionExec: expr=[a@0 as a, b@1 as b] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + output: + Ok: + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - ProjectionExec: expr=[a@0 as a, b@1 as b] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true + " + ); +} + +#[test] +fn test_sort_pushdown_through_projection_with_alias() { + // Sort pushes through projection 
with column aliases + let schema = schema(); + + // Source has [a ASC] ordering + let a = sort_expr("a", &schema); + let source_ordering = LexOrdering::new(vec![a.clone()]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + + // Projection: SELECT a AS id, b AS value + let projection = projection_exec_with_alias(source, vec![(0, "id"), (1, "value")]); + + // Request [id DESC] - should map to [a DESC] and push down + let id_expr = sort_expr_named("id", 0); + let desc_ordering = LexOrdering::new(vec![id_expr.reverse()]).unwrap(); + let plan = sort_exec(desc_ordering, projection); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[id@0 DESC NULLS LAST], preserve_partitioning=[false] + - ProjectionExec: expr=[a@0 as id, b@1 as value] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + output: + Ok: + - SortExec: expr=[id@0 DESC NULLS LAST], preserve_partitioning=[false] + - ProjectionExec: expr=[a@0 as id, b@1 as value] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true + " + ); +} + +#[test] +fn test_no_sort_pushdown_through_computed_projection() { + use datafusion_expr::Operator; + + // Sort should NOT push through projection with computed columns + let schema = schema(); + + // Source has [a ASC] ordering + let a = sort_expr("a", &schema); + let source_ordering = LexOrdering::new(vec![a.clone()]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + + // Projection: SELECT a+b as sum, c + let projection = projection_exec( + vec![ + ( + Arc::new(expressions::BinaryExpr::new( + Arc::new(expressions::Column::new("a", 0)), + Operator::Plus, + Arc::new(expressions::Column::new("b", 1)), + )) as Arc, + "sum".to_string(), + ), + ( + 
Arc::new(expressions::Column::new("c", 2)) as Arc, + "c".to_string(), + ), + ], + source, + ) + .unwrap(); + + // Request [sum DESC] - should NOT push down (sum is computed) + let sum_expr = sort_expr_named("sum", 0); + let desc_ordering = LexOrdering::new(vec![sum_expr.reverse()]).unwrap(); + let plan = sort_exec(desc_ordering, projection); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[sum@0 DESC NULLS LAST], preserve_partitioning=[false] + - ProjectionExec: expr=[a@0 + b@1 as sum, c@2 as c] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + output: + Ok: + - SortExec: expr=[sum@0 DESC NULLS LAST], preserve_partitioning=[false] + - ProjectionExec: expr=[a@0 + b@1 as sum, c@2 as c] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + " + ); +} + +#[test] +fn test_sort_pushdown_projection_reordered_columns() { + // Sort pushes through projection that reorders columns + let schema = schema(); + + // Source has [a ASC] ordering + let a = sort_expr("a", &schema); + let source_ordering = LexOrdering::new(vec![a.clone()]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + + // Projection: SELECT c, b, a (columns reordered) + let projection = simple_projection_exec(source, vec![2, 1, 0]); // c, b, a + + // Request [a DESC] where a is now at index 2 in projection output + let a_expr_at_2 = sort_expr_named("a", 2); + let desc_ordering = LexOrdering::new(vec![a_expr_at_2.reverse()]).unwrap(); + let plan = sort_exec(desc_ordering, projection); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[a@2 DESC NULLS LAST], preserve_partitioning=[false] + - ProjectionExec: expr=[c@2 as c, b@1 as b, a@0 as a] + - 
DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + output: + Ok: + - SortExec: expr=[a@2 DESC NULLS LAST], preserve_partitioning=[false] + - ProjectionExec: expr=[c@2 as c, b@1 as b, a@0 as a] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true + " + ); +} + +#[test] +fn test_sort_pushdown_projection_with_limit() { + // Sort with LIMIT pushes through simple projection + let schema = schema(); + + // Source has [a ASC] ordering + let a = sort_expr("a", &schema); + let source_ordering = LexOrdering::new(vec![a.clone()]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + + // Projection: SELECT a, b + let projection = simple_projection_exec(source, vec![0, 1]); + + // Request [a DESC] with LIMIT 10 + let desc_ordering = LexOrdering::new(vec![a.reverse()]).unwrap(); + let plan = sort_exec_with_fetch(desc_ordering, Some(10), projection); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: TopK(fetch=10), expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - ProjectionExec: expr=[a@0 as a, b@1 as b] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + output: + Ok: + - SortExec: TopK(fetch=10), expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - ProjectionExec: expr=[a@0 as a, b@1 as b] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true + " + ); +} + +#[test] +fn test_sort_pushdown_through_projection() { + // Sort pushes through a projection sitting directly on top of the source + let schema = schema(); + + // Source has [a ASC] ordering + let a = sort_expr("a", &schema); + let source_ordering = LexOrdering::new(vec![a.clone()]).unwrap(); + let source = 
parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + + // Projection: SELECT a, b + let projection = simple_projection_exec(source, vec![0, 1]); + + // Request [a DESC] + let desc_ordering = LexOrdering::new(vec![a.reverse()]).unwrap(); + let plan = sort_exec(desc_ordering, projection); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - ProjectionExec: expr=[a@0 as a, b@1 as b] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + output: + Ok: + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - ProjectionExec: expr=[a@0 as a, b@1 as b] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true + " + ); +} + +#[test] +fn test_sort_pushdown_projection_subset_of_columns() { + // Sort pushes through projection that selects subset of columns + let schema = schema(); + + // Source has [a ASC, b ASC] ordering + let a = sort_expr("a", &schema); + let b = sort_expr("b", &schema); + let source_ordering = LexOrdering::new(vec![a.clone(), b.clone()]).unwrap(); + let source = parquet_exec_with_sort(schema.clone(), vec![source_ordering]); + + // Projection: SELECT a (subset of columns) + let projection = simple_projection_exec(source, vec![0]); + + // Request [a DESC] + let desc_ordering = LexOrdering::new(vec![a.reverse()]).unwrap(); + let plan = sort_exec(desc_ordering, projection); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - ProjectionExec: expr=[a@0 as a] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC, b@1 ASC], file_type=parquet + output: + Ok: + - 
SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - ProjectionExec: expr=[a@0 as a] + - DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet, reverse_row_groups=true + " + ); +} + +// ============================================================================ +// TESTSCAN DEMONSTRATION TESTS +// ============================================================================ +// These tests use TestScan to demonstrate how sort pushdown works more clearly +// than ParquetExec. TestScan can accept ANY ordering (not just reverse) and +// displays the requested ordering explicitly in the output. + +#[test] +fn test_sort_pushdown_with_test_scan_basic() { + // Demonstrates TestScan showing requested ordering clearly + let schema = schema(); + + // Source has [a ASC] ordering + let a = sort_expr("a", &schema); + let source_ordering = LexOrdering::new(vec![a.clone()]).unwrap(); + let source = test_scan_with_ordering(schema.clone(), source_ordering); + + // Request [a DESC] ordering + let desc_ordering = LexOrdering::new(vec![a.reverse()]).unwrap(); + let plan = sort_exec(desc_ordering, source); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - TestScan: output_ordering=[a@0 ASC] + output: + Ok: + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - TestScan: output_ordering=[a@0 ASC], requested_ordering=[a@0 DESC NULLS LAST] + " + ); +} + +#[test] +fn test_sort_pushdown_with_test_scan_multi_column() { + // Demonstrates TestScan with multi-column ordering + let schema = schema(); + + // Source has [a ASC, b DESC] ordering + let a = sort_expr("a", &schema); + let b = sort_expr("b", &schema); + let source_ordering = LexOrdering::new(vec![a.clone(), b.clone().reverse()]).unwrap(); + let source = test_scan_with_ordering(schema.clone(), source_ordering); + + 
// Request [a DESC, b ASC] ordering (reverse of source) + let reverse_ordering = LexOrdering::new(vec![a.reverse(), b]).unwrap(); + let plan = sort_exec(reverse_ordering, source); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[a@0 DESC NULLS LAST, b@1 ASC], preserve_partitioning=[false] + - TestScan: output_ordering=[a@0 ASC, b@1 DESC NULLS LAST] + output: + Ok: + - SortExec: expr=[a@0 DESC NULLS LAST, b@1 ASC], preserve_partitioning=[false] + - TestScan: output_ordering=[a@0 ASC, b@1 DESC NULLS LAST], requested_ordering=[a@0 DESC NULLS LAST, b@1 ASC] + " + ); +} + +#[test] +fn test_sort_pushdown_with_test_scan_arbitrary_ordering() { + // Demonstrates that TestScan can accept ANY ordering (not just reverse) + // This is different from ParquetExec which only supports reverse scans + let schema = schema(); + + // Source has [a ASC, b ASC] ordering + let a = sort_expr("a", &schema); + let b = sort_expr("b", &schema); + let source_ordering = LexOrdering::new(vec![a.clone(), b.clone()]).unwrap(); + let source = test_scan_with_ordering(schema.clone(), source_ordering); + + // Request [a ASC, b DESC] - NOT a simple reverse, but TestScan accepts it + let mixed_ordering = LexOrdering::new(vec![a, b.reverse()]).unwrap(); + let plan = sort_exec(mixed_ordering, source); + + insta::assert_snapshot!( + OptimizationTest::new(plan, PushdownSort::new(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[a@0 ASC, b@1 DESC NULLS LAST], preserve_partitioning=[false] + - TestScan: output_ordering=[a@0 ASC, b@1 ASC] + output: + Ok: + - SortExec: expr=[a@0 ASC, b@1 DESC NULLS LAST], preserve_partitioning=[false] + - TestScan: output_ordering=[a@0 ASC, b@1 ASC], requested_ordering=[a@0 ASC, b@1 DESC NULLS LAST] + " + ); +} diff --git a/datafusion/core/tests/physical_optimizer/filter_pushdown/util.rs b/datafusion/core/tests/physical_optimizer/pushdown_utils.rs similarity index 65% 
rename from datafusion/core/tests/physical_optimizer/filter_pushdown/util.rs rename to datafusion/core/tests/physical_optimizer/pushdown_utils.rs index dc4d77194c082..ce2cb04b64a5f 100644 --- a/datafusion/core/tests/physical_optimizer/filter_pushdown/util.rs +++ b/datafusion/core/tests/physical_optimizer/pushdown_utils.rs @@ -16,31 +16,30 @@ // under the License. use arrow::datatypes::SchemaRef; -use arrow::error::ArrowError; use arrow::{array::RecordBatch, compute::concat_batches}; use datafusion::{datasource::object_store::ObjectStoreUrl, physical_plan::PhysicalExpr}; -use datafusion_common::{config::ConfigOptions, internal_err, Result, Statistics}; +use datafusion_common::tree_node::TreeNodeRecursion; +use datafusion_common::{Result, config::ConfigOptions, internal_err}; use datafusion_datasource::{ - file::FileSource, file_meta::FileMeta, file_scan_config::FileScanConfig, + PartitionedFile, file::FileSource, file_scan_config::FileScanConfig, file_scan_config::FileScanConfigBuilder, file_stream::FileOpenFuture, - file_stream::FileOpener, impl_schema_adapter_methods, - schema_adapter::DefaultSchemaAdapterFactory, schema_adapter::SchemaAdapterFactory, - source::DataSourceExec, PartitionedFile, + file_stream::FileOpener, source::DataSourceExec, }; -use datafusion_physical_expr::conjunction; +use datafusion_physical_expr::projection::ProjectionExprs; use datafusion_physical_expr_common::physical_expr::fmt_sql; use datafusion_physical_optimizer::PhysicalOptimizerRule; +use datafusion_physical_plan::filter::batch_filter; +use datafusion_physical_plan::filter_pushdown::{FilterPushdownPhase, PushedDown}; use datafusion_physical_plan::{ - displayable, + DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, displayable, filter::FilterExec, filter_pushdown::{ - ChildPushdownResult, FilterDescription, FilterPushdownPropagation, - PredicateSupport, PredicateSupports, + ChildFilterDescription, ChildPushdownResult, FilterDescription, + FilterPushdownPropagation, }, 
metrics::ExecutionPlanMetricsSet, - DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, }; -use futures::stream::BoxStream; +use futures::StreamExt; use futures::{FutureExt, Stream}; use object_store::ObjectStore; use std::{ @@ -53,13 +52,17 @@ use std::{ pub struct TestOpener { batches: Vec, batch_size: Option, - schema: Option, - projection: Option>, + projection: Option, + predicate: Option>, } impl FileOpener for TestOpener { - fn open(&self, _file_meta: FileMeta) -> Result { + fn open(&self, _partitioned_file: PartitionedFile) -> Result { let mut batches = self.batches.clone(); + if self.batches.is_empty() { + return Ok((async { Ok(TestStream::new(vec![]).boxed()) }).boxed()); + } + let schema = self.batches[0].schema(); if let Some(batch_size) = self.batch_size { let batch = concat_batches(&batches[0].schema(), &batches)?; let mut new_batches = Vec::new(); @@ -70,56 +73,55 @@ impl FileOpener for TestOpener { } batches = new_batches.into_iter().collect(); } - if let Some(schema) = &self.schema { - let factory = DefaultSchemaAdapterFactory::from_schema(Arc::clone(schema)); - let (mapper, projection) = factory.map_schema(&batches[0].schema()).unwrap(); - let mut new_batches = Vec::new(); - for batch in batches { - let batch = batch.project(&projection).unwrap(); - let batch = mapper.map_batch(batch).unwrap(); - new_batches.push(batch); - } - batches = new_batches; + + let mut new_batches = Vec::new(); + for batch in batches { + let batch = if let Some(predicate) = &self.predicate { + batch_filter(&batch, predicate)? 
+ } else { + batch + }; + new_batches.push(batch); } + batches = new_batches; + if let Some(projection) = &self.projection { + let projector = projection.make_projector(&schema)?; batches = batches .into_iter() - .map(|batch| batch.project(projection).unwrap()) + .map(|batch| projector.project_batch(&batch).unwrap()) .collect(); } let stream = TestStream::new(batches); - Ok((async { - let stream: BoxStream<'static, Result> = - Box::pin(stream); - Ok(stream) - }) - .boxed()) + Ok((async { Ok(stream.boxed()) }).boxed()) } } /// A placeholder data source that accepts filter pushdown -#[derive(Clone, Default)] +#[derive(Clone)] pub struct TestSource { support: bool, predicate: Option>, - statistics: Option, batch_size: Option, batches: Vec, - schema: Option, metrics: ExecutionPlanMetricsSet, - projection: Option>, - schema_adapter_factory: Option>, + projection: Option, + table_schema: datafusion_datasource::TableSchema, } impl TestSource { - fn new(support: bool, batches: Vec) -> Self { + pub fn new(schema: SchemaRef, support: bool, batches: Vec) -> Self { + let table_schema = datafusion_datasource::TableSchema::new(schema, vec![]); Self { support, metrics: ExecutionPlanMetricsSet::new(), batches, - ..Default::default() + predicate: None, + batch_size: None, + projection: None, + table_schema, } } } @@ -130,13 +132,17 @@ impl FileSource for TestSource { _object_store: Arc, _base_config: &FileScanConfig, _partition: usize, - ) -> Arc { - Arc::new(TestOpener { + ) -> Result> { + Ok(Arc::new(TestOpener { batches: self.batches.clone(), batch_size: self.batch_size, - schema: self.schema.clone(), projection: self.projection.clone(), - }) + predicate: self.predicate.clone(), + })) + } + + fn filter(&self) -> Option> { + self.predicate.clone() } fn as_any(&self) -> &dyn Any { @@ -150,39 +156,10 @@ impl FileSource for TestSource { }) } - fn with_schema(&self, schema: SchemaRef) -> Arc { - Arc::new(TestSource { - schema: Some(schema), - ..self.clone() - }) - } - - fn 
with_projection(&self, config: &FileScanConfig) -> Arc { - Arc::new(TestSource { - projection: config.projection.clone(), - ..self.clone() - }) - } - - fn with_statistics(&self, statistics: Statistics) -> Arc { - Arc::new(TestSource { - statistics: Some(statistics), - ..self.clone() - }) - } - fn metrics(&self) -> &ExecutionPlanMetricsSet { &self.metrics } - fn statistics(&self) -> Result { - Ok(self - .statistics - .as_ref() - .expect("statistics not set") - .clone()) - } - fn file_type(&self) -> &str { "test" } @@ -220,19 +197,68 @@ impl FileSource for TestSource { filters.push(Arc::clone(internal)); } let new_node = Arc::new(TestSource { - predicate: Some(conjunction(filters.clone())), + predicate: datafusion_physical_expr::utils::conjunction_opt( + filters.clone(), + ), ..self.clone() }); - Ok(FilterPushdownPropagation { - filters: PredicateSupports::all_supported(filters), - updated_node: Some(new_node), - }) + Ok(FilterPushdownPropagation::with_parent_pushdown_result( + vec![PushedDown::Yes; filters.len()], + ) + .with_updated_node(new_node)) } else { - Ok(FilterPushdownPropagation::unsupported(filters)) + Ok(FilterPushdownPropagation::with_parent_pushdown_result( + vec![PushedDown::No; filters.len()], + )) } } - impl_schema_adapter_methods!(); + fn try_pushdown_projection( + &self, + projection: &ProjectionExprs, + ) -> Result>> { + if let Some(existing_projection) = &self.projection { + // Combine existing projection with new projection + let combined_projection = existing_projection.try_merge(projection)?; + Ok(Some(Arc::new(TestSource { + projection: Some(combined_projection), + table_schema: self.table_schema.clone(), + ..self.clone() + }))) + } else { + Ok(Some(Arc::new(TestSource { + projection: Some(projection.clone()), + ..self.clone() + }))) + } + } + + fn projection(&self) -> Option<&ProjectionExprs> { + self.projection.as_ref() + } + + fn table_schema(&self) -> &datafusion_datasource::TableSchema { + &self.table_schema + } + + fn 
apply_expressions( + &self, + f: &mut dyn FnMut(&dyn PhysicalExpr) -> Result, + ) -> Result { + // Visit predicate (filter) expression if present + if let Some(predicate) = &self.predicate { + f(predicate.as_ref())?; + } + + // Visit projection expressions if present + if let Some(projection) = &self.projection { + for proj_expr in projection { + f(proj_expr.expr.as_ref())?; + } + } + + Ok(TreeNodeRecursion::Continue) + } } #[derive(Debug, Clone)] @@ -256,15 +282,21 @@ impl TestScanBuilder { self } + pub fn with_batches(mut self, batches: Vec) -> Self { + self.batches = batches; + self + } + pub fn build(self) -> Arc { - let source = Arc::new(TestSource::new(self.support, self.batches)); - let base_config = FileScanConfigBuilder::new( - ObjectStoreUrl::parse("test://").unwrap(), + let source = Arc::new(TestSource::new( Arc::clone(&self.schema), - source, - ) - .with_file(PartitionedFile::new("test.paqruet", 123)) - .build(); + self.support, + self.batches, + )); + let base_config = + FileScanConfigBuilder::new(ObjectStoreUrl::parse("test://").unwrap(), source) + .with_file(PartitionedFile::new("test.parquet", 123)) + .build(); DataSourceExec::from_data_source(base_config) } } @@ -303,11 +335,12 @@ impl TestStream { /// least one entry in data (for the schema) pub fn new(data: Vec) -> Self { // check that there is at least one entry in data and that all batches have the same schema - assert!(!data.is_empty(), "data must not be empty"); - assert!( - data.iter().all(|batch| batch.schema() == data[0].schema()), - "all batches must have the same schema" - ); + if let Some(first) = data.first() { + assert!( + data.iter().all(|batch| batch.schema() == first.schema()), + "all batches must have the same schema" + ); + } Self { data, ..Default::default() @@ -316,7 +349,7 @@ impl TestStream { } impl Stream for TestStream { - type Item = Result; + type Item = Result; fn poll_next(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { let next_batch = self.index.value(); @@ 
-345,6 +378,7 @@ pub struct OptimizationTest { } impl OptimizationTest { + #[expect(clippy::needless_pass_by_value)] pub fn new( input_plan: Arc, opt: O, @@ -411,6 +445,15 @@ fn format_lines(s: &str) -> Vec { s.trim().split('\n').map(|s| s.to_string()).collect() } +pub fn format_plan_for_test(plan: &Arc) -> String { + let mut out = String::new(); + for line in format_execution_plan(plan) { + out.push_str(&format!(" - {line}\n")); + } + out.push('\n'); + out +} + #[derive(Debug)] pub(crate) struct TestNode { inject_filter: bool, @@ -451,7 +494,7 @@ impl ExecutionPlan for TestNode { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { self.input.properties() } @@ -481,16 +524,21 @@ impl ExecutionPlan for TestNode { fn gather_filters_for_pushdown( &self, + _phase: FilterPushdownPhase, parent_filters: Vec>, _config: &ConfigOptions, ) -> Result { - Ok(FilterDescription::new_with_child_count(1) - .all_parent_filters_supported(parent_filters) - .with_self_filter(Arc::clone(&self.predicate))) + // Since TestNode marks all parent filters as supported and adds its own filter, + // we use from_child to create a description with all parent filters supported + let child = &self.input; + let child_desc = ChildFilterDescription::from_child(&parent_filters, child)? 
+ .with_self_filter(Arc::clone(&self.predicate)); + Ok(FilterDescription::new().with_child(child_desc)) } fn handle_child_pushdown_result( &self, + _phase: FilterPushdownPhase, child_pushdown_result: ChildPushdownResult, _config: &ConfigOptions, ) -> Result>> { @@ -502,29 +550,41 @@ impl ExecutionPlan for TestNode { let self_pushdown_result = child_pushdown_result.self_filters[0].clone(); // And pushed down 1 filter assert_eq!(self_pushdown_result.len(), 1); - let self_pushdown_result = self_pushdown_result.into_inner(); + let self_pushdown_result: Vec<_> = self_pushdown_result.into_iter().collect(); + + let first_pushdown_result = self_pushdown_result[0].clone(); - match &self_pushdown_result[0] { - PredicateSupport::Unsupported(filter) => { + match &first_pushdown_result.discriminant { + PushedDown::No => { // We have a filter to push down - let new_child = - FilterExec::try_new(Arc::clone(filter), Arc::clone(&self.input))?; + let new_child = FilterExec::try_new( + Arc::clone(&first_pushdown_result.predicate), + Arc::clone(&self.input), + )?; let new_self = TestNode::new(false, Arc::new(new_child), self.predicate.clone()); let mut res = - FilterPushdownPropagation::transparent(child_pushdown_result); + FilterPushdownPropagation::if_all(child_pushdown_result); res.updated_node = Some(Arc::new(new_self) as Arc); Ok(res) } - PredicateSupport::Supported(_) => { - let res = - FilterPushdownPropagation::transparent(child_pushdown_result); + PushedDown::Yes => { + let res = FilterPushdownPropagation::if_all(child_pushdown_result); Ok(res) } } } else { - let res = FilterPushdownPropagation::transparent(child_pushdown_result); + let res = FilterPushdownPropagation::if_all(child_pushdown_result); Ok(res) } } + + fn apply_expressions( + &self, + f: &mut dyn FnMut(&dyn PhysicalExpr) -> Result, + ) -> Result { + // Visit the predicate expression + f(self.predicate.as_ref())?; + Ok(TreeNodeRecursion::Continue) + } } diff --git 
a/datafusion/core/tests/physical_optimizer/replace_with_order_preserving_variants.rs b/datafusion/core/tests/physical_optimizer/replace_with_order_preserving_variants.rs index 71b9757604ecf..cdfed5011696e 100644 --- a/datafusion/core/tests/physical_optimizer/replace_with_order_preserving_variants.rs +++ b/datafusion/core/tests/physical_optimizer/replace_with_order_preserving_variants.rs @@ -18,7 +18,9 @@ use std::sync::Arc; use crate::physical_optimizer::test_utils::{ - check_integrity, create_test_schema3, sort_preserving_merge_exec, + check_integrity, coalesce_partitions_exec, create_test_schema3, + parquet_exec_with_sort, sort_exec, sort_exec_with_preserve_partitioning, + sort_preserving_merge_exec, sort_preserving_merge_exec_with_fetch, stream_exec_ordered_with_projection, }; @@ -27,1101 +29,1044 @@ use arrow::array::{ArrayRef, Int32Array}; use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; +use insta::{allow_duplicates, assert_snapshot}; +use datafusion_common::tree_node::{TransformedResult, TreeNode}; +use datafusion_common::{assert_contains, NullEquality, Result}; +use datafusion_common::config::ConfigOptions; +use datafusion_datasource::source::DataSourceExec; use datafusion_execution::TaskContext; -use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec; -use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; -use datafusion_physical_plan::collect; +use datafusion_expr::{JoinType, Operator}; +use datafusion_physical_expr::expressions::{self, col, Column}; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; +use datafusion_physical_optimizer::enforce_sorting::replace_with_order_preserving_variants::{ + plan_with_order_breaking_variants, plan_with_order_preserving_variants, replace_with_order_preserving_variants, OrderPreservationContext +}; use datafusion_physical_plan::filter::FilterExec; use 
datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; use datafusion::datasource::memory::MemorySourceConfig; use datafusion_physical_plan::repartition::RepartitionExec; -use datafusion_physical_plan::sorts::sort::SortExec; use datafusion_physical_plan::{ - displayable, get_plan_string, ExecutionPlan, Partitioning, + collect, displayable, ExecutionPlan, Partitioning, }; -use datafusion::datasource::source::DataSourceExec; -use datafusion_common::tree_node::{TransformedResult, TreeNode}; -use datafusion_common::{assert_contains, Result}; -use datafusion_expr::{JoinType, Operator}; -use datafusion_physical_expr::expressions::{self, col, Column}; -use datafusion_physical_expr::PhysicalSortExpr; -use datafusion_physical_optimizer::enforce_sorting::replace_with_order_preserving_variants::{plan_with_order_breaking_variants, plan_with_order_preserving_variants, replace_with_order_preserving_variants, OrderPreservationContext}; -use datafusion_common::config::ConfigOptions; -use crate::physical_optimizer::enforce_sorting::parquet_exec_sorted; -use datafusion_physical_expr_common::sort_expr::LexOrdering; -use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; +use object_store::ObjectStoreExt; use object_store::memory::InMemory; -use object_store::ObjectStore; use rstest::rstest; use url::Url; -/// Runs the `replace_with_order_preserving_variants` sub-rule and asserts -/// the plan against the original and expected plans. -/// -/// # Parameters -/// -/// * `$EXPECTED_PLAN_LINES`: Expected input plan. -/// * `EXPECTED_OPTIMIZED_PLAN_LINES`: Optimized plan when the flag -/// `prefer_existing_sort` is `false`. -/// * `EXPECTED_PREFER_SORT_ON_OPTIMIZED_PLAN_LINES`: Optimized plan when -/// the flag `prefer_existing_sort` is `true`. -/// * `$PLAN`: The plan to optimize. -macro_rules! 
assert_optimized_prefer_sort_on_off { - ($EXPECTED_PLAN_LINES: expr, $EXPECTED_OPTIMIZED_PLAN_LINES: expr, $EXPECTED_PREFER_SORT_ON_OPTIMIZED_PLAN_LINES: expr, $PLAN: expr, $PREFER_EXISTING_SORT: expr, $SOURCE_UNBOUNDED: expr) => { - if $PREFER_EXISTING_SORT { - assert_optimized!( - $EXPECTED_PLAN_LINES, - $EXPECTED_PREFER_SORT_ON_OPTIMIZED_PLAN_LINES, - $PLAN, - $PREFER_EXISTING_SORT, - $SOURCE_UNBOUNDED - ); - } else { - assert_optimized!( - $EXPECTED_PLAN_LINES, - $EXPECTED_OPTIMIZED_PLAN_LINES, - $PLAN, - $PREFER_EXISTING_SORT, - $SOURCE_UNBOUNDED - ); - } - }; +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum Boundedness { + Unbounded, + Bounded, } -/// Runs the `replace_with_order_preserving_variants` sub-rule and asserts -/// the plan against the original and expected plans for both bounded and -/// unbounded cases. -/// -/// # Parameters -/// -/// * `EXPECTED_UNBOUNDED_PLAN_LINES`: Expected input unbounded plan. -/// * `EXPECTED_BOUNDED_PLAN_LINES`: Expected input bounded plan. -/// * `EXPECTED_UNBOUNDED_OPTIMIZED_PLAN_LINES`: Optimized plan, which is -/// the same regardless of the value of the `prefer_existing_sort` flag. -/// * `EXPECTED_BOUNDED_OPTIMIZED_PLAN_LINES`: Optimized plan when the flag -/// `prefer_existing_sort` is `false` for bounded cases. -/// * `EXPECTED_BOUNDED_PREFER_SORT_ON_OPTIMIZED_PLAN_LINES`: Optimized plan -/// when the flag `prefer_existing_sort` is `true` for bounded cases. -/// * `$PLAN`: The plan to optimize. -/// * `$SOURCE_UNBOUNDED`: Whether the given plan contains an unbounded source. -macro_rules! 
assert_optimized_in_all_boundedness_situations { - ($EXPECTED_UNBOUNDED_PLAN_LINES: expr, $EXPECTED_BOUNDED_PLAN_LINES: expr, $EXPECTED_UNBOUNDED_OPTIMIZED_PLAN_LINES: expr, $EXPECTED_BOUNDED_OPTIMIZED_PLAN_LINES: expr, $EXPECTED_BOUNDED_PREFER_SORT_ON_OPTIMIZED_PLAN_LINES: expr, $PLAN: expr, $SOURCE_UNBOUNDED: expr, $PREFER_EXISTING_SORT: expr) => { - if $SOURCE_UNBOUNDED { - assert_optimized_prefer_sort_on_off!( - $EXPECTED_UNBOUNDED_PLAN_LINES, - $EXPECTED_UNBOUNDED_OPTIMIZED_PLAN_LINES, - $EXPECTED_UNBOUNDED_OPTIMIZED_PLAN_LINES, - $PLAN, - $PREFER_EXISTING_SORT, - $SOURCE_UNBOUNDED - ); - } else { - assert_optimized_prefer_sort_on_off!( - $EXPECTED_BOUNDED_PLAN_LINES, - $EXPECTED_BOUNDED_OPTIMIZED_PLAN_LINES, - $EXPECTED_BOUNDED_PREFER_SORT_ON_OPTIMIZED_PLAN_LINES, - $PLAN, - $PREFER_EXISTING_SORT, - $SOURCE_UNBOUNDED - ); - } - }; +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum SortPreference { + PreserveOrder, + MaximizeParallelism, } -/// Runs the `replace_with_order_preserving_variants` sub-rule and asserts -/// the plan against the original and expected plans. -/// -/// # Parameters -/// -/// * `$EXPECTED_PLAN_LINES`: Expected input plan. -/// * `$EXPECTED_OPTIMIZED_PLAN_LINES`: Expected optimized plan. -/// * `$PLAN`: The plan to optimize. -/// * `$PREFER_EXISTING_SORT`: Value of the `prefer_existing_sort` flag. -#[macro_export] -macro_rules! 
assert_optimized { - ($EXPECTED_PLAN_LINES: expr, $EXPECTED_OPTIMIZED_PLAN_LINES: expr, $PLAN: expr, $PREFER_EXISTING_SORT: expr, $SOURCE_UNBOUNDED: expr) => { - let physical_plan = $PLAN; - let formatted = displayable(physical_plan.as_ref()).indent(true).to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - - let expected_plan_lines: Vec<&str> = $EXPECTED_PLAN_LINES - .iter().map(|s| *s).collect(); - - assert_eq!( - expected_plan_lines, actual, - "\n**Original Plan Mismatch\n\nexpected:\n\n{expected_plan_lines:#?}\nactual:\n\n{actual:#?}\n\n" - ); +struct ReplaceTest { + plan: Arc, + boundedness: Boundedness, + sort_preference: SortPreference, +} - let expected_optimized_lines: Vec<&str> = $EXPECTED_OPTIMIZED_PLAN_LINES.iter().map(|s| *s).collect(); +impl ReplaceTest { + fn new(plan: Arc) -> Self { + Self { + plan, + boundedness: Boundedness::Bounded, + sort_preference: SortPreference::MaximizeParallelism, + } + } - // Run the rule top-down - let mut config = ConfigOptions::new(); - config.optimizer.prefer_existing_sort=$PREFER_EXISTING_SORT; - let plan_with_pipeline_fixer = OrderPreservationContext::new_default(physical_plan); - let parallel = plan_with_pipeline_fixer.transform_up(|plan_with_pipeline_fixer| replace_with_order_preserving_variants(plan_with_pipeline_fixer, false, false, &config)).data().and_then(check_integrity)?; - let optimized_physical_plan = parallel.plan; + fn with_boundedness(mut self, boundedness: Boundedness) -> Self { + self.boundedness = boundedness; + self + } - // Get string representation of the plan - let actual = get_plan_string(&optimized_physical_plan); - assert_eq!( - expected_optimized_lines, actual, - "\n**Optimized Plan Mismatch\n\nexpected:\n\n{expected_optimized_lines:#?}\nactual:\n\n{actual:#?}\n\n" + fn with_sort_preference(mut self, sort_preference: SortPreference) -> Self { + self.sort_preference = sort_preference; + self + } + + async fn execute_plan(&self) -> String { + let mut config = 
ConfigOptions::new(); + config.optimizer.prefer_existing_sort = + self.sort_preference == SortPreference::PreserveOrder; + + let plan_with_pipeline_fixer = OrderPreservationContext::new_default( + self.plan.clone().reset_state().unwrap(), + ); + + let parallel = plan_with_pipeline_fixer + .transform_up(|plan_with_pipeline_fixer| { + replace_with_order_preserving_variants( + plan_with_pipeline_fixer, + false, + false, + &config, + ) + }) + .data() + .and_then(check_integrity) + .unwrap(); + + let optimized_physical_plan = parallel.plan; + let optimized_plan_string = displayable(optimized_physical_plan.as_ref()) + .indent(true) + .to_string(); + + if self.boundedness == Boundedness::Bounded { + let ctx = SessionContext::new(); + let object_store = InMemory::new(); + object_store + .put( + &object_store::path::Path::from("file_path"), + bytes::Bytes::from("").into(), + ) + .await + .expect("could not create object store"); + ctx.register_object_store( + &Url::parse("test://").unwrap(), + Arc::new(object_store), ); + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let res = collect(optimized_physical_plan, task_ctx).await; + assert!( + res.is_ok(), + "Some errors occurred while executing the optimized physical plan: {:?}\nPlan: {}", + res.unwrap_err(), + optimized_plan_string + ); + } + + optimized_plan_string + } + + async fn run(&self) -> String { + let input_plan_string = displayable(self.plan.as_ref()).indent(true).to_string(); - if !$SOURCE_UNBOUNDED { - let ctx = SessionContext::new(); - let object_store = InMemory::new(); - object_store.put(&object_store::path::Path::from("file_path"), bytes::Bytes::from("").into()).await?; - ctx.register_object_store(&Url::parse("test://").unwrap(), Arc::new(object_store)); - let task_ctx = Arc::new(TaskContext::from(&ctx)); - let res = collect(optimized_physical_plan, task_ctx).await; - assert!( - res.is_ok(), - "Some errors occurred while executing the optimized physical plan: {:?}", res.unwrap_err() - ); - } - }; + let 
optimized = self.execute_plan().await; + + if input_plan_string == optimized { + format!("Input / Optimized:\n{input_plan_string}") + } else { + format!("Input:\n{input_plan_string}\nOptimized:\n{optimized}") + } } +} #[rstest] #[tokio::test] // Searches for a simple sort and a repartition just after it, the second repartition with 1 input partition should not be affected async fn test_replace_multiple_input_repartition_1( - #[values(false, true)] source_unbounded: bool, - #[values(false, true)] prefer_existing_sort: bool, + #[values(Boundedness::Unbounded, Boundedness::Bounded)] boundedness: Boundedness, + #[values(SortPreference::PreserveOrder, SortPreference::MaximizeParallelism)] + sort_pref: SortPreference, ) -> Result<()> { let schema = create_test_schema()?; - let sort_exprs = vec![sort_expr("a", &schema)]; - let source = if source_unbounded { - stream_exec_ordered_with_projection(&schema, sort_exprs) - } else { - memory_exec_sorted(&schema, sort_exprs) + let sort_exprs: LexOrdering = [sort_expr("a", &schema)].into(); + let source = match boundedness { + Boundedness::Unbounded => { + stream_exec_ordered_with_projection(&schema, sort_exprs.clone()) + } + Boundedness::Bounded => memory_exec_sorted(&schema, sort_exprs.clone()), }; let repartition = repartition_exec_hash(repartition_exec_round_robin(source)); - let sort = sort_exec(vec![sort_expr("a", &schema)], repartition, true); - - let physical_plan = sort_preserving_merge_exec(vec![sort_expr("a", &schema)], sort); - - // Expected inputs unbounded and bounded - let expected_input_unbounded = [ - "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true]", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", - ]; - let 
expected_input_bounded = [ - "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true]", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", - ]; - - // Expected unbounded result (same for with and without flag) - let expected_optimized_unbounded = [ - "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", - ]; - - // Expected bounded results with and without flag - let expected_optimized_bounded = [ - "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true]", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", - ]; - let expected_optimized_bounded_sort_preserve = [ - "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", - ]; - assert_optimized_in_all_boundedness_situations!( - expected_input_unbounded, - expected_input_bounded, - expected_optimized_unbounded, - expected_optimized_bounded, - expected_optimized_bounded_sort_preserve, - physical_plan, - source_unbounded, - 
prefer_existing_sort - ); + let sort = sort_exec_with_preserve_partitioning(sort_exprs.clone(), repartition); + let physical_plan = sort_preserving_merge_exec(sort_exprs, sort); + + let run = ReplaceTest::new(physical_plan) + .with_boundedness(boundedness) + .with_sort_preference(sort_pref); + + let physical_plan = run.run().await; + + allow_duplicates! { + match (boundedness, sort_pref) { + (Boundedness::Bounded, SortPreference::MaximizeParallelism) => { + assert_snapshot!(physical_plan, @r" + Input / Optimized: + SortPreservingMergeExec: [a@0 ASC NULLS LAST] + SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + "); + }, + (Boundedness::Unbounded, _) => { + assert_snapshot!(physical_plan, @r" + Input: + SortPreservingMergeExec: [a@0 ASC NULLS LAST] + SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] + + Optimized: + SortPreservingMergeExec: [a@0 ASC NULLS LAST] + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] + "); + }, + (Boundedness::Bounded, SortPreference::PreserveOrder) => { + assert_snapshot!(physical_plan, @r" + Input: + SortPreservingMergeExec: [a@0 ASC NULLS LAST] + SortExec: expr=[a@0 ASC NULLS LAST], 
preserve_partitioning=[true] + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + + Optimized: + SortPreservingMergeExec: [a@0 ASC NULLS LAST] + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + "); + } + } + } + Ok(()) } #[rstest] #[tokio::test] async fn test_with_inter_children_change_only( - #[values(false, true)] source_unbounded: bool, - #[values(false, true)] prefer_existing_sort: bool, + #[values(Boundedness::Unbounded, Boundedness::Bounded)] boundedness: Boundedness, + #[values(SortPreference::PreserveOrder, SortPreference::MaximizeParallelism)] + sort_pref: SortPreference, ) -> Result<()> { let schema = create_test_schema()?; - let sort_exprs = vec![sort_expr_default("a", &schema)]; - let source = if source_unbounded { - stream_exec_ordered_with_projection(&schema, sort_exprs) - } else { - memory_exec_sorted(&schema, sort_exprs) + let ordering: LexOrdering = [sort_expr_default("a", &schema)].into(); + let source = match boundedness { + Boundedness::Unbounded => { + stream_exec_ordered_with_projection(&schema, ordering.clone()) + } + Boundedness::Bounded => memory_exec_sorted(&schema, ordering.clone()), }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let coalesce_partitions = coalesce_partitions_exec(repartition_hash); - let sort = sort_exec( - vec![sort_expr_default("a", &coalesce_partitions.schema())], - coalesce_partitions, - false, - ); + let sort = sort_exec(ordering.clone(), coalesce_partitions); let repartition_rr2 = 
repartition_exec_round_robin(sort); let repartition_hash2 = repartition_exec_hash(repartition_rr2); let filter = filter_exec(repartition_hash2); - let sort2 = sort_exec(vec![sort_expr_default("a", &filter.schema())], filter, true); - - let physical_plan = - sort_preserving_merge_exec(vec![sort_expr_default("a", &sort2.schema())], sort2); - - // Expected inputs unbounded and bounded - let expected_input_unbounded = [ - "SortPreservingMergeExec: [a@0 ASC]", - " SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", - " CoalescePartitionsExec", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC]", - ]; - let expected_input_bounded = [ - "SortPreservingMergeExec: [a@0 ASC]", - " SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", - " CoalescePartitionsExec", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC", - ]; - - // Expected unbounded result (same for with and without flag) - let expected_optimized_unbounded = [ - "SortPreservingMergeExec: [a@0 ASC]", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC", - " RepartitionExec: 
partitioning=RoundRobinBatch(8), input_partitions=1", - " SortPreservingMergeExec: [a@0 ASC]", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC]", - ]; - - // Expected bounded results with and without flag - let expected_optimized_bounded = [ - "SortPreservingMergeExec: [a@0 ASC]", - " SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", - " CoalescePartitionsExec", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC", - ]; - let expected_optimized_bounded_sort_preserve = [ - "SortPreservingMergeExec: [a@0 ASC]", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " SortPreservingMergeExec: [a@0 ASC]", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC", - ]; - assert_optimized_in_all_boundedness_situations!( - expected_input_unbounded, - expected_input_bounded, - expected_optimized_unbounded, - expected_optimized_bounded, - expected_optimized_bounded_sort_preserve, - physical_plan, - source_unbounded, - prefer_existing_sort - ); + let sort2 = 
sort_exec_with_preserve_partitioning(ordering.clone(), filter); + + let physical_plan = sort_preserving_merge_exec(ordering, sort2); + + let run = ReplaceTest::new(physical_plan) + .with_boundedness(boundedness) + .with_sort_preference(sort_pref); + + let physical_plan = run.run().await; + + allow_duplicates! { + match (boundedness, sort_pref) { + (Boundedness::Unbounded, _) => { + assert_snapshot!(physical_plan, @r" + Input: + SortPreservingMergeExec: [a@0 ASC] + SortExec: expr=[a@0 ASC], preserve_partitioning=[true] + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + CoalescePartitionsExec + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC] + + Optimized: + SortPreservingMergeExec: [a@0 ASC] + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + SortPreservingMergeExec: [a@0 ASC] + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC] + "); + }, + (Boundedness::Bounded, SortPreference::MaximizeParallelism) => { + assert_snapshot!(physical_plan, @r" + Input / Optimized: + SortPreservingMergeExec: [a@0 ASC] + SortExec: expr=[a@0 ASC], preserve_partitioning=[true] + FilterExec: c@1 > 3 + RepartitionExec: 
partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + CoalescePartitionsExec + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC + "); + }, + (Boundedness::Bounded, SortPreference::PreserveOrder) => { + assert_snapshot!(physical_plan, @r" + Input: + SortPreservingMergeExec: [a@0 ASC] + SortExec: expr=[a@0 ASC], preserve_partitioning=[true] + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + CoalescePartitionsExec + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC + + Optimized: + SortPreservingMergeExec: [a@0 ASC] + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + SortPreservingMergeExec: [a@0 ASC] + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC + "); + } + } + } + Ok(()) } #[rstest] #[tokio::test] async fn test_replace_multiple_input_repartition_2( - #[values(false, true)] source_unbounded: bool, - #[values(false, true)] prefer_existing_sort: 
bool, + #[values(Boundedness::Unbounded, Boundedness::Bounded)] boundedness: Boundedness, + #[values(SortPreference::PreserveOrder, SortPreference::MaximizeParallelism)] + sort_pref: SortPreference, ) -> Result<()> { let schema = create_test_schema()?; - let sort_exprs = vec![sort_expr("a", &schema)]; - let source = if source_unbounded { - stream_exec_ordered_with_projection(&schema, sort_exprs) - } else { - memory_exec_sorted(&schema, sort_exprs) + let ordering: LexOrdering = [sort_expr("a", &schema)].into(); + let source = match boundedness { + Boundedness::Unbounded => { + stream_exec_ordered_with_projection(&schema, ordering.clone()) + } + Boundedness::Bounded => memory_exec_sorted(&schema, ordering.clone()), }; let repartition_rr = repartition_exec_round_robin(source); let filter = filter_exec(repartition_rr); let repartition_hash = repartition_exec_hash(filter); - let sort = sort_exec(vec![sort_expr("a", &schema)], repartition_hash, true); - - let physical_plan = sort_preserving_merge_exec(vec![sort_expr("a", &schema)], sort); - - // Expected inputs unbounded and bounded - let expected_input_unbounded = [ - "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true]", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", - ]; - let expected_input_bounded = [ - "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true]", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", - ]; - - // Expected 
unbounded result (same for with and without flag) - let expected_optimized_unbounded = [ - "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", - ]; - - // Expected bounded results with and without flag - let expected_optimized_bounded = [ - "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true]", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", - ]; - let expected_optimized_bounded_sort_preserve = [ - "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", - ]; - assert_optimized_in_all_boundedness_situations!( - expected_input_unbounded, - expected_input_bounded, - expected_optimized_unbounded, - expected_optimized_bounded, - expected_optimized_bounded_sort_preserve, - physical_plan, - source_unbounded, - prefer_existing_sort - ); + let sort = sort_exec_with_preserve_partitioning(ordering.clone(), repartition_hash); + let physical_plan = sort_preserving_merge_exec(ordering, sort); + + let run = ReplaceTest::new(physical_plan) + .with_boundedness(boundedness) + .with_sort_preference(sort_pref); + + let physical_plan = run.run().await; + + 
allow_duplicates! { + match (boundedness, sort_pref) { + (Boundedness::Unbounded, _) => { + assert_snapshot!(physical_plan, @r" + Input: + SortPreservingMergeExec: [a@0 ASC NULLS LAST] + SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + FilterExec: c@1 > 3 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] + + Optimized: + SortPreservingMergeExec: [a@0 ASC NULLS LAST] + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST + FilterExec: c@1 > 3 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] + "); + }, + (Boundedness::Bounded, SortPreference::MaximizeParallelism) => { + assert_snapshot!(physical_plan, @r" + Input / Optimized: + SortPreservingMergeExec: [a@0 ASC NULLS LAST] + SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + FilterExec: c@1 > 3 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + "); + }, + (Boundedness::Bounded, SortPreference::PreserveOrder) => { + assert_snapshot!(physical_plan, @r" + Input: + SortPreservingMergeExec: [a@0 ASC NULLS LAST] + SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + FilterExec: c@1 > 3 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + DataSourceExec: partitions=1, partition_sizes=[1], 
output_ordering=a@0 ASC NULLS LAST + + Optimized: + SortPreservingMergeExec: [a@0 ASC NULLS LAST] + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST + FilterExec: c@1 > 3 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + "); + } + } + } + Ok(()) } #[rstest] #[tokio::test] async fn test_replace_multiple_input_repartition_with_extra_steps( - #[values(false, true)] source_unbounded: bool, - #[values(false, true)] prefer_existing_sort: bool, + #[values(Boundedness::Unbounded, Boundedness::Bounded)] boundedness: Boundedness, + #[values(SortPreference::PreserveOrder, SortPreference::MaximizeParallelism)] + sort_pref: SortPreference, ) -> Result<()> { let schema = create_test_schema()?; - let sort_exprs = vec![sort_expr("a", &schema)]; - let source = if source_unbounded { - stream_exec_ordered_with_projection(&schema, sort_exprs) - } else { - memory_exec_sorted(&schema, sort_exprs) + let ordering: LexOrdering = [sort_expr("a", &schema)].into(); + let source = match boundedness { + Boundedness::Unbounded => { + stream_exec_ordered_with_projection(&schema, ordering.clone()) + } + Boundedness::Bounded => memory_exec_sorted(&schema, ordering.clone()), }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let filter = filter_exec(repartition_hash); - let coalesce_batches_exec: Arc = coalesce_batches_exec(filter); - let sort = sort_exec(vec![sort_expr("a", &schema)], coalesce_batches_exec, true); - - let physical_plan = sort_preserving_merge_exec(vec![sort_expr("a", &schema)], sort); - - // Expected inputs unbounded and bounded - let expected_input_unbounded = [ - "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true]", - " CoalesceBatchesExec: 
target_batch_size=8192", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", - ]; - let expected_input_bounded = [ - "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true]", - " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", - ]; - - // Expected unbounded result (same for with and without flag) - let expected_optimized_unbounded = [ - "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", - ]; - - // Expected bounded results with and without flag - let expected_optimized_bounded = [ - "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true]", - " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", - ]; - let expected_optimized_bounded_sort_preserve = [ - "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", 
- " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", - ]; - assert_optimized_in_all_boundedness_situations!( - expected_input_unbounded, - expected_input_bounded, - expected_optimized_unbounded, - expected_optimized_bounded, - expected_optimized_bounded_sort_preserve, - physical_plan, - source_unbounded, - prefer_existing_sort - ); + let sort = sort_exec_with_preserve_partitioning(ordering.clone(), filter); + let physical_plan = sort_preserving_merge_exec(ordering, sort); + + let run = ReplaceTest::new(physical_plan) + .with_boundedness(boundedness) + .with_sort_preference(sort_pref); + + let physical_plan = run.run().await; + + allow_duplicates! { + match (boundedness, sort_pref) { + (Boundedness::Unbounded, _) => { + assert_snapshot!(physical_plan, @r" + Input: + SortPreservingMergeExec: [a@0 ASC NULLS LAST] + SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] + + Optimized: + SortPreservingMergeExec: [a@0 ASC NULLS LAST] + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] + "); + }, + (Boundedness::Bounded, 
SortPreference::MaximizeParallelism) => { + assert_snapshot!(physical_plan, @r" + Input / Optimized: + SortPreservingMergeExec: [a@0 ASC NULLS LAST] + SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + "); + }, + (Boundedness::Bounded, SortPreference::PreserveOrder) => { + assert_snapshot!(physical_plan, @r" + Input: + SortPreservingMergeExec: [a@0 ASC NULLS LAST] + SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + + Optimized: + SortPreservingMergeExec: [a@0 ASC NULLS LAST] + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + "); + } + } + } + Ok(()) } #[rstest] #[tokio::test] async fn test_replace_multiple_input_repartition_with_extra_steps_2( - #[values(false, true)] source_unbounded: bool, - #[values(false, true)] prefer_existing_sort: bool, + #[values(Boundedness::Unbounded, Boundedness::Bounded)] boundedness: Boundedness, + #[values(SortPreference::PreserveOrder, SortPreference::MaximizeParallelism)] + sort_pref: SortPreference, ) -> Result<()> { let schema = create_test_schema()?; - let sort_exprs = vec![sort_expr("a", &schema)]; - let source = if source_unbounded { - stream_exec_ordered_with_projection(&schema, 
sort_exprs) - } else { - memory_exec_sorted(&schema, sort_exprs) + let ordering: LexOrdering = [sort_expr("a", &schema)].into(); + let source = match boundedness { + Boundedness::Unbounded => { + stream_exec_ordered_with_projection(&schema, ordering.clone()) + } + Boundedness::Bounded => memory_exec_sorted(&schema, ordering.clone()), }; let repartition_rr = repartition_exec_round_robin(source); - let coalesce_batches_exec_1 = coalesce_batches_exec(repartition_rr); - let repartition_hash = repartition_exec_hash(coalesce_batches_exec_1); + let repartition_hash = repartition_exec_hash(repartition_rr); let filter = filter_exec(repartition_hash); - let coalesce_batches_exec_2 = coalesce_batches_exec(filter); - let sort = sort_exec(vec![sort_expr("a", &schema)], coalesce_batches_exec_2, true); - - let physical_plan = sort_preserving_merge_exec(vec![sort_expr("a", &schema)], sort); - - // Expected inputs unbounded and bounded - let expected_input_unbounded = [ - "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true]", - " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " CoalesceBatchesExec: target_batch_size=8192", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", - ]; - let expected_input_bounded = [ - "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true]", - " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " CoalesceBatchesExec: target_batch_size=8192", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 
ASC NULLS LAST", - ]; - - // Expected unbounded result (same for with and without flag) - let expected_optimized_unbounded = [ - "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", - " CoalesceBatchesExec: target_batch_size=8192", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", - ]; - - // Expected bounded results with and without flag - let expected_optimized_bounded = [ - "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true]", - " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " CoalesceBatchesExec: target_batch_size=8192", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", - ]; - let expected_optimized_bounded_sort_preserve = [ - "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", - " CoalesceBatchesExec: target_batch_size=8192", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", - ]; - assert_optimized_in_all_boundedness_situations!( - expected_input_unbounded, - expected_input_bounded, - expected_optimized_unbounded, - expected_optimized_bounded, - expected_optimized_bounded_sort_preserve, - physical_plan, - source_unbounded, - 
prefer_existing_sort - ); + let sort = sort_exec_with_preserve_partitioning(ordering.clone(), filter); + let physical_plan = sort_preserving_merge_exec(ordering, sort); + + let run = ReplaceTest::new(physical_plan) + .with_boundedness(boundedness) + .with_sort_preference(sort_pref); + + let physical_plan = run.run().await; + + allow_duplicates! { + match (boundedness, sort_pref) { + (Boundedness::Unbounded, _) => { + assert_snapshot!(physical_plan, @r" + Input: + SortPreservingMergeExec: [a@0 ASC NULLS LAST] + SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] + + Optimized: + SortPreservingMergeExec: [a@0 ASC NULLS LAST] + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] + "); + }, + (Boundedness::Bounded, SortPreference::MaximizeParallelism) => { + assert_snapshot!(physical_plan, @r" + Input / Optimized: + SortPreservingMergeExec: [a@0 ASC NULLS LAST] + SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + "); + }, + (Boundedness::Bounded, SortPreference::PreserveOrder) => { + assert_snapshot!(physical_plan, @r" + Input: + SortPreservingMergeExec: [a@0 ASC 
NULLS LAST] + SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + + Optimized: + SortPreservingMergeExec: [a@0 ASC NULLS LAST] + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + "); + } + } + } + Ok(()) } #[rstest] #[tokio::test] async fn test_not_replacing_when_no_need_to_preserve_sorting( - #[values(false, true)] source_unbounded: bool, - #[values(false, true)] prefer_existing_sort: bool, + #[values(Boundedness::Unbounded, Boundedness::Bounded)] boundedness: Boundedness, + #[values(SortPreference::PreserveOrder, SortPreference::MaximizeParallelism)] + sort_pref: SortPreference, ) -> Result<()> { let schema = create_test_schema()?; - let sort_exprs = vec![sort_expr("a", &schema)]; - let source = if source_unbounded { - stream_exec_ordered_with_projection(&schema, sort_exprs) - } else { - memory_exec_sorted(&schema, sort_exprs) + let ordering: LexOrdering = [sort_expr("a", &schema)].into(); + let source = match boundedness { + Boundedness::Unbounded => stream_exec_ordered_with_projection(&schema, ordering), + Boundedness::Bounded => memory_exec_sorted(&schema, ordering), }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let filter = filter_exec(repartition_hash); - let coalesce_batches_exec: Arc = coalesce_batches_exec(filter); - - let physical_plan: Arc = - coalesce_partitions_exec(coalesce_batches_exec); - - // Expected inputs 
unbounded and bounded - let expected_input_unbounded = [ - "CoalescePartitionsExec", - " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", - ]; - let expected_input_bounded = [ - "CoalescePartitionsExec", - " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", - ]; - - // Expected unbounded result (same for with and without flag) - let expected_optimized_unbounded = [ - "CoalescePartitionsExec", - " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", - ]; - - // Expected bounded results same with and without flag, because there is no executor with ordering requirement - let expected_optimized_bounded = [ - "CoalescePartitionsExec", - " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", - ]; - let expected_optimized_bounded_sort_preserve = expected_optimized_bounded; - - assert_optimized_in_all_boundedness_situations!( - expected_input_unbounded, - expected_input_bounded, - 
expected_optimized_unbounded, - expected_optimized_bounded, - expected_optimized_bounded_sort_preserve, - physical_plan, - source_unbounded, - prefer_existing_sort - ); + let physical_plan = coalesce_partitions_exec(filter); + + let run = ReplaceTest::new(physical_plan) + .with_boundedness(boundedness) + .with_sort_preference(sort_pref); + + let physical_plan = run.run().await; + + allow_duplicates! { + match (boundedness, sort_pref) { + (Boundedness::Unbounded, _) => { + assert_snapshot!(physical_plan, @r" + Input / Optimized: + CoalescePartitionsExec + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] + "); + }, + (Boundedness::Bounded, SortPreference::MaximizeParallelism) => { + assert_snapshot!(physical_plan, @r" + Input / Optimized: + CoalescePartitionsExec + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + "); + // Expected bounded results same with and without flag, because there is no executor with ordering requirement + }, + (Boundedness::Bounded, SortPreference::PreserveOrder) => { + assert_snapshot!(physical_plan, @r" + Input / Optimized: + CoalescePartitionsExec + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + "); + } + } + } + Ok(()) } #[rstest] #[tokio::test] -async fn test_with_multiple_replacable_repartitions( - #[values(false, true)] source_unbounded: 
bool, - #[values(false, true)] prefer_existing_sort: bool, +async fn test_with_multiple_replaceable_repartitions( + #[values(Boundedness::Unbounded, Boundedness::Bounded)] boundedness: Boundedness, + #[values(SortPreference::PreserveOrder, SortPreference::MaximizeParallelism)] + sort_pref: SortPreference, ) -> Result<()> { let schema = create_test_schema()?; - let sort_exprs = vec![sort_expr("a", &schema)]; - let source = if source_unbounded { - stream_exec_ordered_with_projection(&schema, sort_exprs) - } else { - memory_exec_sorted(&schema, sort_exprs) + let ordering: LexOrdering = [sort_expr("a", &schema)].into(); + let source = match boundedness { + Boundedness::Unbounded => { + stream_exec_ordered_with_projection(&schema, ordering.clone()) + } + Boundedness::Bounded => memory_exec_sorted(&schema, ordering.clone()), }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let filter = filter_exec(repartition_hash); - let coalesce_batches = coalesce_batches_exec(filter); - let repartition_hash_2 = repartition_exec_hash(coalesce_batches); - let sort = sort_exec(vec![sort_expr("a", &schema)], repartition_hash_2, true); - - let physical_plan = sort_preserving_merge_exec(vec![sort_expr("a", &schema)], sort); - - // Expected inputs unbounded and bounded - let expected_input_unbounded = [ - "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true]", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", - ]; - let expected_input_bounded = [ - "SortPreservingMergeExec: [a@0 ASC NULLS 
LAST]", - " SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true]", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", - ]; - - // Expected unbounded result (same for with and without flag) - let expected_optimized_unbounded = [ - "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", - " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", - ]; - - // Expected bounded results with and without flag - let expected_optimized_bounded = [ - "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true]", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", - ]; - let expected_optimized_bounded_sort_preserve = [ - "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", - " 
CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", - ]; - assert_optimized_in_all_boundedness_situations!( - expected_input_unbounded, - expected_input_bounded, - expected_optimized_unbounded, - expected_optimized_bounded, - expected_optimized_bounded_sort_preserve, - physical_plan, - source_unbounded, - prefer_existing_sort - ); + let repartition_hash_2 = repartition_exec_hash(filter); + let sort = sort_exec_with_preserve_partitioning(ordering.clone(), repartition_hash_2); + let physical_plan = sort_preserving_merge_exec(ordering, sort); + + let run = ReplaceTest::new(physical_plan) + .with_boundedness(boundedness) + .with_sort_preference(sort_pref); + + let physical_plan = run.run().await; + + allow_duplicates! 
{ + match (boundedness, sort_pref) { + (Boundedness::Unbounded, _) => { + assert_snapshot!(physical_plan, @r" + Input: + SortPreservingMergeExec: [a@0 ASC NULLS LAST] + SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] + + Optimized: + SortPreservingMergeExec: [a@0 ASC NULLS LAST] + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] + "); + }, + (Boundedness::Bounded, SortPreference::MaximizeParallelism) => { + assert_snapshot!(physical_plan, @r" + Input / Optimized: + SortPreservingMergeExec: [a@0 ASC NULLS LAST] + SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + "); + }, + (Boundedness::Bounded, SortPreference::PreserveOrder) => { + assert_snapshot!(physical_plan, @r" + Input: + SortPreservingMergeExec: [a@0 ASC NULLS LAST] + SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] + 
RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + + Optimized: + SortPreservingMergeExec: [a@0 ASC NULLS LAST] + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + "); + } + } + } + Ok(()) } #[rstest] #[tokio::test] async fn test_not_replace_with_different_orderings( - #[values(false, true)] source_unbounded: bool, - #[values(false, true)] prefer_existing_sort: bool, + #[values(Boundedness::Unbounded, Boundedness::Bounded)] boundedness: Boundedness, + #[values(SortPreference::PreserveOrder, SortPreference::MaximizeParallelism)] + sort_pref: SortPreference, ) -> Result<()> { + use datafusion_physical_expr::LexOrdering; + let schema = create_test_schema()?; - let sort_exprs = vec![sort_expr("a", &schema)]; - let source = if source_unbounded { - stream_exec_ordered_with_projection(&schema, sort_exprs) - } else { - memory_exec_sorted(&schema, sort_exprs) + let ordering_a = [sort_expr("a", &schema)].into(); + let source = match boundedness { + Boundedness::Unbounded => { + stream_exec_ordered_with_projection(&schema, ordering_a) + } + Boundedness::Bounded => memory_exec_sorted(&schema, ordering_a), }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); - let sort = sort_exec( - vec![sort_expr_default("c", 
&repartition_hash.schema())], - repartition_hash, - true, - ); + let ordering_c: LexOrdering = + [sort_expr_default("c", &repartition_hash.schema())].into(); + let sort = sort_exec_with_preserve_partitioning(ordering_c.clone(), repartition_hash); + let physical_plan = sort_preserving_merge_exec(ordering_c, sort); + + let run = ReplaceTest::new(physical_plan) + .with_boundedness(boundedness) + .with_sort_preference(sort_pref); + + let physical_plan = run.run().await; + + allow_duplicates! { + match (boundedness, sort_pref) { + (Boundedness::Unbounded, _) => { + assert_snapshot!(physical_plan, @r" + Input / Optimized: + SortPreservingMergeExec: [c@1 ASC] + SortExec: expr=[c@1 ASC], preserve_partitioning=[true] + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] + "); + }, + (Boundedness::Bounded, SortPreference::MaximizeParallelism) => { + assert_snapshot!(physical_plan, @r" + Input / Optimized: + SortPreservingMergeExec: [c@1 ASC] + SortExec: expr=[c@1 ASC], preserve_partitioning=[true] + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + "); + // Expected bounded results same with and without flag, because ordering requirement of the executor is + // different from the existing ordering. 
+ }, + (Boundedness::Bounded, SortPreference::PreserveOrder) => { + assert_snapshot!(physical_plan, @r" + Input / Optimized: + SortPreservingMergeExec: [c@1 ASC] + SortExec: expr=[c@1 ASC], preserve_partitioning=[true] + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + "); + } + } + } - let physical_plan = - sort_preserving_merge_exec(vec![sort_expr_default("c", &sort.schema())], sort); - - // Expected inputs unbounded and bounded - let expected_input_unbounded = [ - "SortPreservingMergeExec: [c@1 ASC]", - " SortExec: expr=[c@1 ASC], preserve_partitioning=[true]", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", - ]; - let expected_input_bounded = [ - "SortPreservingMergeExec: [c@1 ASC]", - " SortExec: expr=[c@1 ASC], preserve_partitioning=[true]", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", - ]; - - // Expected unbounded result (same for with and without flag) - let expected_optimized_unbounded = [ - "SortPreservingMergeExec: [c@1 ASC]", - " SortExec: expr=[c@1 ASC], preserve_partitioning=[true]", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", - ]; - - // Expected bounded results same with and without flag, because ordering 
requirement of the executor is different than the existing ordering. - let expected_optimized_bounded = [ - "SortPreservingMergeExec: [c@1 ASC]", - " SortExec: expr=[c@1 ASC], preserve_partitioning=[true]", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", - ]; - let expected_optimized_bounded_sort_preserve = expected_optimized_bounded; - - assert_optimized_in_all_boundedness_situations!( - expected_input_unbounded, - expected_input_bounded, - expected_optimized_unbounded, - expected_optimized_bounded, - expected_optimized_bounded_sort_preserve, - physical_plan, - source_unbounded, - prefer_existing_sort - ); Ok(()) } #[rstest] #[tokio::test] async fn test_with_lost_ordering( - #[values(false, true)] source_unbounded: bool, - #[values(false, true)] prefer_existing_sort: bool, + #[values(Boundedness::Unbounded, Boundedness::Bounded)] boundedness: Boundedness, + #[values(SortPreference::PreserveOrder, SortPreference::MaximizeParallelism)] + sort_pref: SortPreference, ) -> Result<()> { let schema = create_test_schema()?; - let sort_exprs = vec![sort_expr("a", &schema)]; - let source = if source_unbounded { - stream_exec_ordered_with_projection(&schema, sort_exprs) - } else { - memory_exec_sorted(&schema, sort_exprs) + let ordering: LexOrdering = [sort_expr("a", &schema)].into(); + let source = match boundedness { + Boundedness::Unbounded => { + stream_exec_ordered_with_projection(&schema, ordering.clone()) + } + Boundedness::Bounded => memory_exec_sorted(&schema, ordering.clone()), }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let coalesce_partitions = coalesce_partitions_exec(repartition_hash); - let physical_plan = - sort_exec(vec![sort_expr("a", &schema)], coalesce_partitions, false); - - // Expected 
inputs unbounded and bounded - let expected_input_unbounded = [ - "SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[false]", - " CoalescePartitionsExec", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", - ]; - let expected_input_bounded = [ - "SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[false]", - " CoalescePartitionsExec", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", - ]; - - // Expected unbounded result (same for with and without flag) - let expected_optimized_unbounded = [ - "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", - ]; - - // Expected bounded results with and without flag - let expected_optimized_bounded = [ - "SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[false]", - " CoalescePartitionsExec", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", - ]; - let expected_optimized_bounded_sort_preserve = [ - "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", - " RepartitionExec: 
partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", - ]; - assert_optimized_in_all_boundedness_situations!( - expected_input_unbounded, - expected_input_bounded, - expected_optimized_unbounded, - expected_optimized_bounded, - expected_optimized_bounded_sort_preserve, - physical_plan, - source_unbounded, - prefer_existing_sort - ); + let physical_plan = sort_exec(ordering, coalesce_partitions); + + let run = ReplaceTest::new(physical_plan) + .with_boundedness(boundedness) + .with_sort_preference(sort_pref); + + let physical_plan = run.run().await; + + allow_duplicates! { + match (boundedness, sort_pref) { + (Boundedness::Unbounded, _) => { + assert_snapshot!(physical_plan, @r" + Input: + SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[false] + CoalescePartitionsExec + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] + + Optimized: + SortPreservingMergeExec: [a@0 ASC NULLS LAST] + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] + "); + }, + (Boundedness::Bounded, SortPreference::MaximizeParallelism) => { + assert_snapshot!(physical_plan, @r" + Input / Optimized: + SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[false] + CoalescePartitionsExec + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + DataSourceExec: partitions=1, 
partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + "); + }, + (Boundedness::Bounded, SortPreference::PreserveOrder) => { + assert_snapshot!(physical_plan, @r" + Input: + SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[false] + CoalescePartitionsExec + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + + Optimized: + SortPreservingMergeExec: [a@0 ASC NULLS LAST] + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + "); + } + } + } + Ok(()) } #[rstest] #[tokio::test] async fn test_with_lost_and_kept_ordering( - #[values(false, true)] source_unbounded: bool, - #[values(false, true)] prefer_existing_sort: bool, + #[values(Boundedness::Unbounded, Boundedness::Bounded)] boundedness: Boundedness, + #[values(SortPreference::PreserveOrder, SortPreference::MaximizeParallelism)] + sort_pref: SortPreference, ) -> Result<()> { + use datafusion_physical_expr::LexOrdering; + let schema = create_test_schema()?; - let sort_exprs = vec![sort_expr("a", &schema)]; - let source = if source_unbounded { - stream_exec_ordered_with_projection(&schema, sort_exprs) - } else { - memory_exec_sorted(&schema, sort_exprs) + let ordering_a = [sort_expr("a", &schema)].into(); + let source = match boundedness { + Boundedness::Unbounded => { + stream_exec_ordered_with_projection(&schema, ordering_a) + } + Boundedness::Bounded => memory_exec_sorted(&schema, ordering_a), }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let coalesce_partitions = 
coalesce_partitions_exec(repartition_hash); - let sort = sort_exec( - vec![sort_expr_default("c", &coalesce_partitions.schema())], - coalesce_partitions, - false, - ); + let ordering_c: LexOrdering = + [sort_expr_default("c", &coalesce_partitions.schema())].into(); + let sort = sort_exec(ordering_c.clone(), coalesce_partitions); let repartition_rr2 = repartition_exec_round_robin(sort); let repartition_hash2 = repartition_exec_hash(repartition_rr2); let filter = filter_exec(repartition_hash2); - let sort2 = sort_exec(vec![sort_expr_default("c", &filter.schema())], filter, true); - - let physical_plan = - sort_preserving_merge_exec(vec![sort_expr_default("c", &sort2.schema())], sort2); - - // Expected inputs unbounded and bounded - let expected_input_unbounded = [ - "SortPreservingMergeExec: [c@1 ASC]", - " SortExec: expr=[c@1 ASC], preserve_partitioning=[true]", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " SortExec: expr=[c@1 ASC], preserve_partitioning=[false]", - " CoalescePartitionsExec", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", - ]; - let expected_input_bounded = [ - "SortPreservingMergeExec: [c@1 ASC]", - " SortExec: expr=[c@1 ASC], preserve_partitioning=[true]", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " SortExec: expr=[c@1 ASC], preserve_partitioning=[false]", - " CoalescePartitionsExec", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, 
partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", - ]; - - // Expected unbounded result (same for with and without flag) - let expected_optimized_unbounded = [ - "SortPreservingMergeExec: [c@1 ASC]", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=c@1 ASC", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " SortExec: expr=[c@1 ASC], preserve_partitioning=[false]", - " CoalescePartitionsExec", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", - ]; - - // Expected bounded results with and without flag - let expected_optimized_bounded = [ - "SortPreservingMergeExec: [c@1 ASC]", - " SortExec: expr=[c@1 ASC], preserve_partitioning=[true]", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " SortExec: expr=[c@1 ASC], preserve_partitioning=[false]", - " CoalescePartitionsExec", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", - ]; - let expected_optimized_bounded_sort_preserve = [ - "SortPreservingMergeExec: [c@1 ASC]", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=c@1 ASC", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " SortExec: expr=[c@1 ASC], preserve_partitioning=[false]", - " CoalescePartitionsExec", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: 
partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", - ]; - assert_optimized_in_all_boundedness_situations!( - expected_input_unbounded, - expected_input_bounded, - expected_optimized_unbounded, - expected_optimized_bounded, - expected_optimized_bounded_sort_preserve, - physical_plan, - source_unbounded, - prefer_existing_sort - ); + let sort2 = sort_exec_with_preserve_partitioning(ordering_c.clone(), filter); + let physical_plan = sort_preserving_merge_exec(ordering_c, sort2); + + let run = ReplaceTest::new(physical_plan) + .with_boundedness(boundedness) + .with_sort_preference(sort_pref); + + let physical_plan = run.run().await; + + allow_duplicates! { + match (boundedness, sort_pref) { + (Boundedness::Unbounded, _) => { + assert_snapshot!(physical_plan, @r" + Input: + SortPreservingMergeExec: [c@1 ASC] + SortExec: expr=[c@1 ASC], preserve_partitioning=[true] + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + SortExec: expr=[c@1 ASC], preserve_partitioning=[false] + CoalescePartitionsExec + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] + + Optimized: + SortPreservingMergeExec: [c@1 ASC] + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=c@1 ASC + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + SortExec: expr=[c@1 ASC], preserve_partitioning=[false] + CoalescePartitionsExec + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: 
partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] + "); + }, + (Boundedness::Bounded, SortPreference::MaximizeParallelism) => { + assert_snapshot!(physical_plan, @r" + Input / Optimized: + SortPreservingMergeExec: [c@1 ASC] + SortExec: expr=[c@1 ASC], preserve_partitioning=[true] + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + SortExec: expr=[c@1 ASC], preserve_partitioning=[false] + CoalescePartitionsExec + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + "); + }, + (Boundedness::Bounded, SortPreference::PreserveOrder) => { + assert_snapshot!(physical_plan, @r" + Input: + SortPreservingMergeExec: [c@1 ASC] + SortExec: expr=[c@1 ASC], preserve_partitioning=[true] + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + SortExec: expr=[c@1 ASC], preserve_partitioning=[false] + CoalescePartitionsExec + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + + Optimized: + SortPreservingMergeExec: [c@1 ASC] + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=c@1 ASC + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + SortExec: expr=[c@1 ASC], 
preserve_partitioning=[false] + CoalescePartitionsExec + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + "); + } + } + } + Ok(()) } #[rstest] #[tokio::test] async fn test_with_multiple_child_trees( - #[values(false, true)] source_unbounded: bool, - #[values(false, true)] prefer_existing_sort: bool, + #[values(Boundedness::Unbounded, Boundedness::Bounded)] boundedness: Boundedness, + #[values(SortPreference::PreserveOrder, SortPreference::MaximizeParallelism)] + sort_pref: SortPreference, ) -> Result<()> { let schema = create_test_schema()?; - let left_sort_exprs = vec![sort_expr("a", &schema)]; - let left_source = if source_unbounded { - stream_exec_ordered_with_projection(&schema, left_sort_exprs) - } else { - memory_exec_sorted(&schema, left_sort_exprs) + let left_ordering = [sort_expr("a", &schema)].into(); + let left_source = match boundedness { + Boundedness::Unbounded => { + stream_exec_ordered_with_projection(&schema, left_ordering) + } + Boundedness::Bounded => memory_exec_sorted(&schema, left_ordering), }; let left_repartition_rr = repartition_exec_round_robin(left_source); let left_repartition_hash = repartition_exec_hash(left_repartition_rr); - let left_coalesce_partitions = - Arc::new(CoalesceBatchesExec::new(left_repartition_hash, 4096)); - - let right_sort_exprs = vec![sort_expr("a", &schema)]; - let right_source = if source_unbounded { - stream_exec_ordered_with_projection(&schema, right_sort_exprs) - } else { - memory_exec_sorted(&schema, right_sort_exprs) + + let right_ordering = [sort_expr("a", &schema)].into(); + let right_source = match boundedness { + Boundedness::Unbounded => { + stream_exec_ordered_with_projection(&schema, right_ordering) + } + Boundedness::Bounded => memory_exec_sorted(&schema, right_ordering), }; let right_repartition_rr = 
repartition_exec_round_robin(right_source); let right_repartition_hash = repartition_exec_hash(right_repartition_rr); - let right_coalesce_partitions = - Arc::new(CoalesceBatchesExec::new(right_repartition_hash, 4096)); - - let hash_join_exec = - hash_join_exec(left_coalesce_partitions, right_coalesce_partitions); - let sort = sort_exec( - vec![sort_expr_default("a", &hash_join_exec.schema())], - hash_join_exec, - true, - ); - let physical_plan = - sort_preserving_merge_exec(vec![sort_expr_default("a", &sort.schema())], sort); - - // Expected inputs unbounded and bounded - let expected_input_unbounded = [ - "SortPreservingMergeExec: [a@0 ASC]", - " SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", - " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@1, c@1)]", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", - ]; - let expected_input_bounded = [ - "SortPreservingMergeExec: [a@0 ASC]", - " SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", - " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@1, c@1)]", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: 
partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", - ]; - - // Expected unbounded result (same for with and without flag) - let expected_optimized_unbounded = [ - "SortPreservingMergeExec: [a@0 ASC]", - " SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", - " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@1, c@1)]", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", - ]; - - // Expected bounded results same with and without flag, because ordering get lost during intermediate executor anyway. Hence no need to preserve - // existing ordering. 
- let expected_optimized_bounded = [ - "SortPreservingMergeExec: [a@0 ASC]", - " SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", - " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@1, c@1)]", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", - ]; - let expected_optimized_bounded_sort_preserve = expected_optimized_bounded; - - assert_optimized_in_all_boundedness_situations!( - expected_input_unbounded, - expected_input_bounded, - expected_optimized_unbounded, - expected_optimized_bounded, - expected_optimized_bounded_sort_preserve, - physical_plan, - source_unbounded, - prefer_existing_sort - ); + let hash_join_exec = hash_join_exec(left_repartition_hash, right_repartition_hash); + let ordering: LexOrdering = [sort_expr_default("a", &hash_join_exec.schema())].into(); + let sort = sort_exec_with_preserve_partitioning(ordering.clone(), hash_join_exec); + let physical_plan = sort_preserving_merge_exec(ordering, sort); + + let run = ReplaceTest::new(physical_plan) + .with_boundedness(boundedness) + .with_sort_preference(sort_pref); + + let physical_plan = run.run().await; + + allow_duplicates! 
{ + match (boundedness, sort_pref) { + (Boundedness::Unbounded, _) => { + assert_snapshot!(physical_plan, @r" + Input / Optimized: + SortPreservingMergeExec: [a@0 ASC] + SortExec: expr=[a@0 ASC], preserve_partitioning=[true] + HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@1, c@1)] + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] + "); + }, + (Boundedness::Bounded, _) => { + assert_snapshot!(physical_plan, @r" + Input / Optimized: + SortPreservingMergeExec: [a@0 ASC] + SortExec: expr=[a@0 ASC], preserve_partitioning=[true] + HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@1, c@1)] + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + "); + // Expected bounded results same with and without flag, because ordering get lost during intermediate executor anyway. + // Hence, no need to preserve existing ordering. 
+ } + } + } + Ok(()) } @@ -1149,18 +1094,6 @@ fn sort_expr_options( } } -fn sort_exec( - sort_exprs: impl IntoIterator, - input: Arc, - preserve_partitioning: bool, -) -> Arc { - let sort_exprs = sort_exprs.into_iter().collect(); - Arc::new( - SortExec::new(sort_exprs, input) - .with_preserve_partitioning(preserve_partitioning), - ) -} - fn repartition_exec_round_robin(input: Arc) -> Arc { Arc::new(RepartitionExec::try_new(input, Partitioning::RoundRobinBatch(8)).unwrap()) } @@ -1188,14 +1121,6 @@ fn filter_exec(input: Arc) -> Arc { Arc::new(FilterExec::try_new(predicate, input).unwrap()) } -fn coalesce_batches_exec(input: Arc) -> Arc { - Arc::new(CoalesceBatchesExec::new(input, 8192)) -} - -fn coalesce_partitions_exec(input: Arc) -> Arc { - Arc::new(CoalescePartitionsExec::new(input)) -} - fn hash_join_exec( left: Arc, right: Arc, @@ -1213,6 +1138,7 @@ fn hash_join_exec( &JoinType::Inner, None, PartitionMode::Partitioned, + NullEquality::NullEqualsNothing, false, ) .unwrap(), @@ -1233,7 +1159,7 @@ fn create_test_schema() -> Result { // projection parameter is given static due to testing needs fn memory_exec_sorted( schema: &SchemaRef, - sort_exprs: impl IntoIterator, + ordering: LexOrdering, ) -> Arc { pub fn make_partition(schema: &SchemaRef, sz: i32) -> RecordBatch { let values = (0..sz).collect::>(); @@ -1249,7 +1175,6 @@ fn memory_exec_sorted( let rows = 5; let partitions = 1; - let sort_exprs = sort_exprs.into_iter().collect(); Arc::new({ let data: Vec> = (0..partitions) .map(|_| vec![make_partition(schema, rows)]) @@ -1258,7 +1183,7 @@ fn memory_exec_sorted( DataSourceExec::new(Arc::new( MemorySourceConfig::try_new(&data, schema.clone(), Some(projection)) .unwrap() - .try_with_sort_information(vec![sort_exprs]) + .try_with_sort_information(vec![ordering]) .unwrap(), )) }) @@ -1268,12 +1193,11 @@ fn memory_exec_sorted( fn test_plan_with_order_preserving_variants_preserves_fetch() -> Result<()> { // Create a schema let schema = create_test_schema3()?; - let 
parquet_sort_exprs = vec![crate::physical_optimizer::test_utils::sort_expr( - "a", &schema, - )]; - let parquet_exec = parquet_exec_sorted(&schema, parquet_sort_exprs); - let coalesced = - Arc::new(CoalescePartitionsExec::new(parquet_exec.clone()).with_fetch(Some(10))); + let parquet_sort_exprs = vec![[sort_expr("a", &schema)].into()]; + let parquet_exec = parquet_exec_with_sort(schema, parquet_sort_exprs); + let coalesced = coalesce_partitions_exec(parquet_exec.clone()) + .with_fetch(Some(10)) + .unwrap(); // Test sort's fetch is greater than coalesce fetch, return error because it's not reasonable let requirements = OrderPreservationContext::new( @@ -1286,7 +1210,10 @@ fn test_plan_with_order_preserving_variants_preserves_fetch() -> Result<()> { )], ); let res = plan_with_order_preserving_variants(requirements, false, true, Some(15)); - assert_contains!(res.unwrap_err().to_string(), "CoalescePartitionsExec fetch [10] should be greater than or equal to SortExec fetch [15]"); + assert_contains!( + res.unwrap_err().to_string(), + "CoalescePartitionsExec fetch [10] should be greater than or equal to SortExec fetch [15]" + ); // Test sort is without fetch, expected to get the fetch value from the coalesced let requirements = OrderPreservationContext::new( @@ -1315,17 +1242,15 @@ fn test_plan_with_order_preserving_variants_preserves_fetch() -> Result<()> { #[test] fn test_plan_with_order_breaking_variants_preserves_fetch() -> Result<()> { let schema = create_test_schema3()?; - let parquet_sort_exprs = vec![crate::physical_optimizer::test_utils::sort_expr( - "a", &schema, - )]; - let parquet_exec = parquet_exec_sorted(&schema, parquet_sort_exprs.clone()); - let spm = SortPreservingMergeExec::new( - LexOrdering::new(parquet_sort_exprs), + let parquet_sort_exprs: LexOrdering = [sort_expr("a", &schema)].into(); + let parquet_exec = parquet_exec_with_sort(schema, vec![parquet_sort_exprs.clone()]); + let spm = sort_preserving_merge_exec_with_fetch( + parquet_sort_exprs, 
parquet_exec.clone(), - ) - .with_fetch(Some(10)); + 10, + ); let requirements = OrderPreservationContext::new( - Arc::new(spm), + spm, true, vec![OrderPreservationContext::new( parquet_exec.clone(), diff --git a/datafusion/core/tests/physical_optimizer/sanity_checker.rs b/datafusion/core/tests/physical_optimizer/sanity_checker.rs index a73d084a081f3..217570846d56e 100644 --- a/datafusion/core/tests/physical_optimizer/sanity_checker.rs +++ b/datafusion/core/tests/physical_optimizer/sanity_checker.rs @@ -15,11 +15,13 @@ // specific language governing permissions and limitations // under the License. +use insta::assert_snapshot; use std::sync::Arc; use crate::physical_optimizer::test_utils::{ bounded_window_exec, global_limit_exec, local_limit_exec, memory_exec, - repartition_exec, sort_exec, sort_expr_options, sort_merge_join_exec, + projection_exec, repartition_exec, sort_exec, sort_expr, sort_expr_options, + sort_merge_join_exec, sort_preserving_merge_exec, union_exec, }; use arrow::compute::SortOptions; @@ -27,13 +29,14 @@ use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion::datasource::stream::{FileStreamProvider, StreamConfig, StreamTable}; use datafusion::prelude::{CsvReadOptions, SessionContext}; use datafusion_common::config::ConfigOptions; -use datafusion_common::{JoinType, Result}; -use datafusion_physical_expr::expressions::col; +use datafusion_common::{JoinType, Result, ScalarValue}; use datafusion_physical_expr::Partitioning; -use datafusion_physical_optimizer::sanity_checker::SanityCheckPlan; +use datafusion_physical_expr::expressions::{Literal, col}; +use datafusion_physical_expr_common::sort_expr::LexOrdering; use datafusion_physical_optimizer::PhysicalOptimizerRule; +use datafusion_physical_optimizer::sanity_checker::SanityCheckPlan; use datafusion_physical_plan::repartition::RepartitionExec; -use datafusion_physical_plan::{displayable, ExecutionPlan}; +use datafusion_physical_plan::{ExecutionPlan, displayable}; use 
async_trait::async_trait; @@ -397,34 +400,32 @@ fn assert_sanity_check(plan: &Arc, is_sane: bool) { ); } -/// Check if the plan we created is as expected by comparing the plan -/// formatted as a string. -fn assert_plan(plan: &dyn ExecutionPlan, expected_lines: Vec<&str>) { - let plan_str = displayable(plan).indent(true).to_string(); - let actual_lines: Vec<&str> = plan_str.trim().lines().collect(); - assert_eq!(actual_lines, expected_lines); -} - #[tokio::test] /// Tests that plan is valid when the sort requirements are satisfied. async fn test_bounded_window_agg_sort_requirement() -> Result<()> { let schema = create_test_schema(); let source = memory_exec(&schema); - let sort_exprs = vec![sort_expr_options( + let ordering: LexOrdering = [sort_expr_options( "c9", &source.schema(), SortOptions { descending: false, nulls_first: false, }, - )]; - let sort = sort_exec(sort_exprs.clone(), source); - let bw = bounded_window_exec("c9", sort_exprs, sort); - assert_plan(bw.as_ref(), vec![ - "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", - " SortExec: expr=[c9@0 ASC NULLS LAST], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]" - ]); + )] + .into(); + let sort = sort_exec(ordering.clone(), source); + let bw = bounded_window_exec("c9", ordering, sort); + let plan_str = displayable(bw.as_ref()).indent(true).to_string(); + let actual = plan_str.trim(); + assert_snapshot!( + actual, + @r#" + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + SortExec: expr=[c9@0 ASC NULLS LAST], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] + "# + ); assert_sanity_check(&bw, true); Ok(()) } @@ -443,10 +444,15 @@ 
async fn test_bounded_window_agg_no_sort_requirement() -> Result<()> { }, )]; let bw = bounded_window_exec("c9", sort_exprs, source); - assert_plan(bw.as_ref(), vec![ - "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", - " DataSourceExec: partitions=1, partition_sizes=[0]" - ]); + let plan_str = displayable(bw.as_ref()).indent(true).to_string(); + let actual = plan_str.trim(); + assert_snapshot!( + actual, + @r#" + BoundedWindowAggExec: wdw=[count: Field { "count": Int64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: partitions=1, partition_sizes=[0] + "# + ); // Order requirement of the `BoundedWindowAggExec` is not satisfied. We expect to receive error during sanity check. assert_sanity_check(&bw, false); Ok(()) @@ -458,14 +464,16 @@ async fn test_bounded_window_agg_no_sort_requirement() -> Result<()> { async fn test_global_limit_single_partition() -> Result<()> { let schema = create_test_schema(); let source = memory_exec(&schema); - let limit = global_limit_exec(source); - - assert_plan( - limit.as_ref(), - vec![ - "GlobalLimitExec: skip=0, fetch=100", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ], + let limit = global_limit_exec(source, 0, Some(100)); + + let plan_str = displayable(limit.as_ref()).indent(true).to_string(); + let actual = plan_str.trim(); + assert_snapshot!( + actual, + @r" + GlobalLimitExec: skip=0, fetch=100 + DataSourceExec: partitions=1, partition_sizes=[0] + " ); assert_sanity_check(&limit, true); Ok(()) @@ -477,15 +485,17 @@ async fn test_global_limit_single_partition() -> Result<()> { async fn test_global_limit_multi_partition() -> Result<()> { let schema = create_test_schema(); let source = memory_exec(&schema); - let limit = 
global_limit_exec(repartition_exec(source)); - - assert_plan( - limit.as_ref(), - vec![ - "GlobalLimitExec: skip=0, fetch=100", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ], + let limit = global_limit_exec(repartition_exec(source), 0, Some(100)); + + let plan_str = displayable(limit.as_ref()).indent(true).to_string(); + let actual = plan_str.trim(); + assert_snapshot!( + actual, + @r" + GlobalLimitExec: skip=0, fetch=100 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: partitions=1, partition_sizes=[0] + " ); // Distribution requirement of the `GlobalLimitExec` is not satisfied. We expect to receive error during sanity check. assert_sanity_check(&limit, false); @@ -497,14 +507,16 @@ async fn test_global_limit_multi_partition() -> Result<()> { async fn test_local_limit() -> Result<()> { let schema = create_test_schema(); let source = memory_exec(&schema); - let limit = local_limit_exec(source); - - assert_plan( - limit.as_ref(), - vec![ - "LocalLimitExec: fetch=100", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ], + let limit = local_limit_exec(source, 100); + + let plan_str = displayable(limit.as_ref()).indent(true).to_string(); + let actual = plan_str.trim(); + assert_snapshot!( + actual, + @r" + LocalLimitExec: fetch=100 + DataSourceExec: partitions=1, partition_sizes=[0] + " ); assert_sanity_check(&limit, true); Ok(()) @@ -518,12 +530,12 @@ async fn test_sort_merge_join_satisfied() -> Result<()> { let source1 = memory_exec(&schema1); let source2 = memory_exec(&schema2); let sort_opts = SortOptions::default(); - let sort_exprs1 = vec![sort_expr_options("c9", &source1.schema(), sort_opts)]; - let sort_exprs2 = vec![sort_expr_options("a", &source2.schema(), sort_opts)]; - let left = sort_exec(sort_exprs1, source1); - let right = sort_exec(sort_exprs2, source2); - let left_jcol = col("c9", &left.schema()).unwrap(); - let 
right_jcol = col("a", &right.schema()).unwrap(); + let ordering1 = [sort_expr_options("c9", &source1.schema(), sort_opts)].into(); + let ordering2 = [sort_expr_options("a", &source2.schema(), sort_opts)].into(); + let left = sort_exec(ordering1, source1); + let right = sort_exec(ordering2, source2); + let left_jcol = col("c9", &left.schema())?; + let right_jcol = col("a", &right.schema())?; let left = Arc::new(RepartitionExec::try_new( left, Partitioning::Hash(vec![left_jcol.clone()], 10), @@ -538,17 +550,19 @@ async fn test_sort_merge_join_satisfied() -> Result<()> { let join_ty = JoinType::Inner; let smj = sort_merge_join_exec(left, right, &join_on, &join_ty); - assert_plan( - smj.as_ref(), - vec![ - "SortMergeJoin: join_type=Inner, on=[(c9@0, a@0)]", - " RepartitionExec: partitioning=Hash([c9@0], 10), input_partitions=1", - " SortExec: expr=[c9@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - " RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1", - " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ], + let plan_str = displayable(smj.as_ref()).indent(true).to_string(); + let actual = plan_str.trim(); + assert_snapshot!( + actual, + @r" + SortMergeJoinExec: join_type=Inner, on=[(c9@0, a@0)] + RepartitionExec: partitioning=Hash([c9@0], 10), input_partitions=1, maintains_sort_order=true + SortExec: expr=[c9@0 ASC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] + RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1, maintains_sort_order=true + SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] + " ); assert_sanity_check(&smj, true); Ok(()) @@ -562,15 +576,16 @@ async fn test_sort_merge_join_order_missing() -> Result<()> { let schema2 = create_test_schema2(); let source1 = memory_exec(&schema1); let right = memory_exec(&schema2); - 
let sort_exprs1 = vec![sort_expr_options( + let ordering1 = [sort_expr_options( "c9", &source1.schema(), SortOptions::default(), - )]; - let left = sort_exec(sort_exprs1, source1); + )] + .into(); + let left = sort_exec(ordering1, source1); // Missing sort of the right child here.. - let left_jcol = col("c9", &left.schema()).unwrap(); - let right_jcol = col("a", &right.schema()).unwrap(); + let left_jcol = col("c9", &left.schema())?; + let right_jcol = col("a", &right.schema())?; let left = Arc::new(RepartitionExec::try_new( left, Partitioning::Hash(vec![left_jcol.clone()], 10), @@ -585,16 +600,18 @@ async fn test_sort_merge_join_order_missing() -> Result<()> { let join_ty = JoinType::Inner; let smj = sort_merge_join_exec(left, right, &join_on, &join_ty); - assert_plan( - smj.as_ref(), - vec![ - "SortMergeJoin: join_type=Inner, on=[(c9@0, a@0)]", - " RepartitionExec: partitioning=Hash([c9@0], 10), input_partitions=1", - " SortExec: expr=[c9@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - " RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ], + let plan_str = displayable(smj.as_ref()).indent(true).to_string(); + let actual = plan_str.trim(); + assert_snapshot!( + actual, + @r" + SortMergeJoinExec: join_type=Inner, on=[(c9@0, a@0)] + RepartitionExec: partitioning=Hash([c9@0], 10), input_partitions=1, maintains_sort_order=true + SortExec: expr=[c9@0 ASC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] + RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1 + DataSourceExec: partitions=1, partition_sizes=[0] + " ); // Order requirement for the `SortMergeJoin` is not satisfied for right child. We expect to receive error during sanity check. 
assert_sanity_check(&smj, false); @@ -610,16 +627,16 @@ async fn test_sort_merge_join_dist_missing() -> Result<()> { let source1 = memory_exec(&schema1); let source2 = memory_exec(&schema2); let sort_opts = SortOptions::default(); - let sort_exprs1 = vec![sort_expr_options("c9", &source1.schema(), sort_opts)]; - let sort_exprs2 = vec![sort_expr_options("a", &source2.schema(), sort_opts)]; - let left = sort_exec(sort_exprs1, source1); - let right = sort_exec(sort_exprs2, source2); + let ordering1 = [sort_expr_options("c9", &source1.schema(), sort_opts)].into(); + let ordering2 = [sort_expr_options("a", &source2.schema(), sort_opts)].into(); + let left = sort_exec(ordering1, source1); + let right = sort_exec(ordering2, source2); let right = Arc::new(RepartitionExec::try_new( right, Partitioning::RoundRobinBatch(10), )?); - let left_jcol = col("c9", &left.schema()).unwrap(); - let right_jcol = col("a", &right.schema()).unwrap(); + let left_jcol = col("c9", &left.schema())?; + let right_jcol = col("a", &right.schema())?; let left = Arc::new(RepartitionExec::try_new( left, Partitioning::Hash(vec![left_jcol.clone()], 10), @@ -631,19 +648,95 @@ async fn test_sort_merge_join_dist_missing() -> Result<()> { let join_ty = JoinType::Inner; let smj = sort_merge_join_exec(left, right, &join_on, &join_ty); - assert_plan( - smj.as_ref(), - vec![ - "SortMergeJoin: join_type=Inner, on=[(c9@0, a@0)]", - " RepartitionExec: partitioning=Hash([c9@0], 10), input_partitions=1", - " SortExec: expr=[c9@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ], + let plan_str = displayable(smj.as_ref()).indent(true).to_string(); + let actual = plan_str.trim(); + assert_snapshot!( + actual, + @r" + SortMergeJoinExec: join_type=Inner, on=[(c9@0, a@0)] + 
RepartitionExec: partitioning=Hash([c9@0], 10), input_partitions=1, maintains_sort_order=true + SortExec: expr=[c9@0 ASC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1, maintains_sort_order=true + SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] + " ); // Distribution requirement for the `SortMergeJoin` is not satisfied for right child (has round-robin partitioning). We expect to receive error during sanity check. assert_sanity_check(&smj, false); Ok(()) } + +/// A particular edge case. +/// +/// See . +#[tokio::test] +async fn test_union_with_sorts_and_constants() -> Result<()> { + let schema_in = create_test_schema2(); + + let proj_exprs_1 = vec![ + ( + Arc::new(Literal::new(ScalarValue::Utf8(Some("foo".to_owned())))) as _, + "const_1".to_owned(), + ), + ( + Arc::new(Literal::new(ScalarValue::Utf8(Some("foo".to_owned())))) as _, + "const_2".to_owned(), + ), + (col("a", &schema_in).unwrap(), "a".to_owned()), + ]; + let proj_exprs_2 = vec![ + ( + Arc::new(Literal::new(ScalarValue::Utf8(Some("foo".to_owned())))) as _, + "const_1".to_owned(), + ), + ( + Arc::new(Literal::new(ScalarValue::Utf8(Some("bar".to_owned())))) as _, + "const_2".to_owned(), + ), + (col("a", &schema_in).unwrap(), "a".to_owned()), + ]; + + let source_1 = memory_exec(&schema_in); + let source_1 = projection_exec(proj_exprs_1.clone(), source_1).unwrap(); + let schema_sources = source_1.schema(); + let ordering_sources: LexOrdering = + [sort_expr("a", &schema_sources).nulls_last()].into(); + let source_1 = sort_exec(ordering_sources.clone(), source_1); + + let source_2 = memory_exec(&schema_in); + let source_2 = projection_exec(proj_exprs_2, source_2).unwrap(); + let source_2 = sort_exec(ordering_sources.clone(), source_2); + + let plan = union_exec(vec![source_1, source_2]); + + let schema_out = plan.schema(); + let ordering_out: 
LexOrdering = [ + sort_expr("const_1", &schema_out).nulls_last(), + sort_expr("const_2", &schema_out).nulls_last(), + sort_expr("a", &schema_out).nulls_last(), + ] + .into(); + + let plan = sort_preserving_merge_exec(ordering_out, plan); + + let plan_str = displayable(plan.as_ref()).indent(true).to_string(); + let plan_str = plan_str.trim(); + assert_snapshot!( + plan_str, + @r" + SortPreservingMergeExec: [const_1@0 ASC NULLS LAST, const_2@1 ASC NULLS LAST, a@2 ASC NULLS LAST] + UnionExec + SortExec: expr=[a@2 ASC NULLS LAST], preserve_partitioning=[false] + ProjectionExec: expr=[foo as const_1, foo as const_2, a@0 as a] + DataSourceExec: partitions=1, partition_sizes=[0] + SortExec: expr=[a@2 ASC NULLS LAST], preserve_partitioning=[false] + ProjectionExec: expr=[foo as const_1, bar as const_2, a@0 as a] + DataSourceExec: partitions=1, partition_sizes=[0] + " + ); + + assert_sanity_check(&plan, true); + + Ok(()) +} diff --git a/datafusion/core/tests/physical_optimizer/test_utils.rs b/datafusion/core/tests/physical_optimizer/test_utils.rs index 955486a310309..8d9e7b68b8c96 100644 --- a/datafusion/core/tests/physical_optimizer/test_utils.rs +++ b/datafusion/core/tests/physical_optimizer/test_utils.rs @@ -18,8 +18,8 @@ //! 
Test utilities for physical optimizer tests use std::any::Any; -use std::fmt::Formatter; -use std::sync::Arc; +use std::fmt::{Display, Formatter}; +use std::sync::{Arc, LazyLock}; use arrow::array::Int32Array; use arrow::compute::SortOptions; @@ -31,49 +31,54 @@ use datafusion::datasource::physical_plan::ParquetSource; use datafusion::datasource::source::DataSourceExec; use datafusion_common::config::ConfigOptions; use datafusion_common::stats::Precision; -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, +}; use datafusion_common::utils::expr::COUNT_STAR_EXPANSION; -use datafusion_common::{ColumnStatistics, JoinType, Result, Statistics}; +use datafusion_common::{ + ColumnStatistics, JoinType, NullEquality, Result, Statistics, internal_err, +}; use datafusion_datasource::file_scan_config::FileScanConfigBuilder; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_expr::{WindowFrame, WindowFunctionDefinition}; use datafusion_functions_aggregate::count::count_udaf; +use datafusion_physical_expr::EquivalenceProperties; use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr}; -use datafusion_physical_expr::expressions::col; -use datafusion_physical_expr::{expressions, PhysicalExpr}; +use datafusion_physical_expr::expressions::{self, col}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::{ - LexOrdering, LexRequirement, PhysicalSortExpr, + LexOrdering, OrderingRequirements, PhysicalSortExpr, }; -use datafusion_physical_optimizer::limited_distinct_aggregation::LimitedDistinctAggregation; use datafusion_physical_optimizer::PhysicalOptimizerRule; +use datafusion_physical_optimizer::limited_distinct_aggregation::LimitedDistinctAggregation; use 
datafusion_physical_plan::aggregates::{ AggregateExec, AggregateMode, PhysicalGroupBy, }; -use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec; use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; +use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; use datafusion_physical_plan::filter::FilterExec; use datafusion_physical_plan::joins::utils::{JoinFilter, JoinOn}; use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode, SortMergeJoinExec}; use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; +use datafusion_physical_plan::projection::{ProjectionExec, ProjectionExpr}; use datafusion_physical_plan::repartition::RepartitionExec; use datafusion_physical_plan::sorts::sort::SortExec; use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use datafusion_physical_plan::streaming::{PartitionStream, StreamingTableExec}; use datafusion_physical_plan::tree_node::PlanContext; use datafusion_physical_plan::union::UnionExec; -use datafusion_physical_plan::windows::{create_window_expr, BoundedWindowAggExec}; +use datafusion_physical_plan::windows::{BoundedWindowAggExec, create_window_expr}; use datafusion_physical_plan::{ - displayable, DisplayAs, DisplayFormatType, ExecutionPlan, InputOrderMode, - Partitioning, PlanProperties, + DisplayAs, DisplayFormatType, ExecutionPlan, InputOrderMode, Partitioning, + PlanProperties, SortOrderPushdownResult, displayable, }; /// Create a non sorted parquet exec -pub fn parquet_exec(schema: &SchemaRef) -> Arc { +pub fn parquet_exec(schema: SchemaRef) -> Arc { let config = FileScanConfigBuilder::new( ObjectStoreUrl::parse("test:///").unwrap(), - schema.clone(), - Arc::new(ParquetSource::default()), + Arc::new(ParquetSource::new(schema)), ) .with_file(PartitionedFile::new("x".to_string(), 100)) .build(); @@ -83,12 +88,12 @@ pub fn parquet_exec(schema: &SchemaRef) -> Arc { /// Create a single parquet file that is sorted pub(crate) 
fn parquet_exec_with_sort( + schema: SchemaRef, output_ordering: Vec, ) -> Arc { let config = FileScanConfigBuilder::new( ObjectStoreUrl::parse("test:///").unwrap(), - schema(), - Arc::new(ParquetSource::default()), + Arc::new(ParquetSource::new(schema)), ) .with_file(PartitionedFile::new("x".to_string(), 100)) .with_output_ordering(output_ordering) @@ -104,6 +109,7 @@ fn int64_stats() -> ColumnStatistics { max_value: Precision::Exact(1_000_000.into()), min_value: Precision::Exact(0.into()), distinct_count: Precision::Absent, + byte_size: Precision::Absent, } } @@ -125,52 +131,60 @@ pub(crate) fn parquet_exec_with_stats(file_size: u64) -> Arc { let config = FileScanConfigBuilder::new( ObjectStoreUrl::parse("test:///").unwrap(), - schema(), - Arc::new(ParquetSource::new(Default::default())), + Arc::new(ParquetSource::new(schema())), ) .with_file(PartitionedFile::new("x".to_string(), file_size)) .with_statistics(statistics) .build(); - assert_eq!( - config.file_source.statistics().unwrap().num_rows, - Precision::Inexact(10000) - ); + assert_eq!(config.statistics().num_rows, Precision::Inexact(10000)); DataSourceExec::from_data_source(config) } pub fn schema() -> SchemaRef { - Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int64, true), - Field::new("b", DataType::Int64, true), - Field::new("c", DataType::Int64, true), - Field::new("d", DataType::Int32, true), - Field::new("e", DataType::Boolean, true), - ])) + static SCHEMA: LazyLock = LazyLock::new(|| { + Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, true), + Field::new("b", DataType::Int64, true), + Field::new("c", DataType::Int64, true), + Field::new("d", DataType::Int32, true), + Field::new("e", DataType::Boolean, true), + ])) + }); + Arc::clone(&SCHEMA) } pub fn create_test_schema() -> Result { - let nullable_column = Field::new("nullable_col", DataType::Int32, true); - let non_nullable_column = Field::new("non_nullable_col", DataType::Int32, false); - let schema = 
Arc::new(Schema::new(vec![nullable_column, non_nullable_column])); + static SCHEMA: LazyLock = LazyLock::new(|| { + let nullable_column = Field::new("nullable_col", DataType::Int32, true); + let non_nullable_column = Field::new("non_nullable_col", DataType::Int32, false); + Arc::new(Schema::new(vec![nullable_column, non_nullable_column])) + }); + let schema = Arc::clone(&SCHEMA); Ok(schema) } pub fn create_test_schema2() -> Result { - let col_a = Field::new("col_a", DataType::Int32, true); - let col_b = Field::new("col_b", DataType::Int32, true); - let schema = Arc::new(Schema::new(vec![col_a, col_b])); + static SCHEMA: LazyLock = LazyLock::new(|| { + let col_a = Field::new("col_a", DataType::Int32, true); + let col_b = Field::new("col_b", DataType::Int32, true); + Arc::new(Schema::new(vec![col_a, col_b])) + }); + let schema = Arc::clone(&SCHEMA); Ok(schema) } // Generate a schema which consists of 5 columns (a, b, c, d, e) pub fn create_test_schema3() -> Result { - let a = Field::new("a", DataType::Int32, true); - let b = Field::new("b", DataType::Int32, false); - let c = Field::new("c", DataType::Int32, true); - let d = Field::new("d", DataType::Int32, false); - let e = Field::new("e", DataType::Int32, false); - let schema = Arc::new(Schema::new(vec![a, b, c, d, e])); + static SCHEMA: LazyLock = LazyLock::new(|| { + let a = Field::new("a", DataType::Int32, true); + let b = Field::new("b", DataType::Int32, false); + let c = Field::new("c", DataType::Int32, true); + let d = Field::new("d", DataType::Int32, false); + let e = Field::new("e", DataType::Int32, false); + Arc::new(Schema::new(vec![a, b, c, d, e])) + }); + let schema = Arc::clone(&SCHEMA); Ok(schema) } @@ -188,7 +202,7 @@ pub fn sort_merge_join_exec( None, *join_type, vec![SortOptions::default(); join_on.len()], - false, + NullEquality::NullEqualsNothing, ) .unwrap(), ) @@ -234,7 +248,8 @@ pub fn hash_join_exec( join_type, None, PartitionMode::Partitioned, - true, + NullEquality::NullEqualsNothing, + 
false, )?)) } @@ -243,17 +258,28 @@ pub fn bounded_window_exec( sort_exprs: impl IntoIterator, input: Arc, ) -> Arc { - let sort_exprs: LexOrdering = sort_exprs.into_iter().collect(); + bounded_window_exec_with_partition(col_name, sort_exprs, &[], input) +} + +pub fn bounded_window_exec_with_partition( + col_name: &str, + sort_exprs: impl IntoIterator, + partition_by: &[Arc], + input: Arc, +) -> Arc { + let sort_exprs = sort_exprs.into_iter().collect::>(); let schema = input.schema(); let window_expr = create_window_expr( &WindowFunctionDefinition::AggregateUDF(count_udaf()), "count".to_owned(), &[col(col_name, &schema).unwrap()], - &[], - sort_exprs.as_ref(), + partition_by, + &sort_exprs, Arc::new(WindowFrame::new(Some(false))), - schema.as_ref(), + schema, + false, false, + None, ) .unwrap(); @@ -276,36 +302,37 @@ pub fn filter_exec( } pub fn sort_preserving_merge_exec( - sort_exprs: impl IntoIterator, + ordering: LexOrdering, input: Arc, ) -> Arc { - let sort_exprs = sort_exprs.into_iter().collect(); - Arc::new(SortPreservingMergeExec::new(sort_exprs, input)) + Arc::new(SortPreservingMergeExec::new(ordering, input)) } pub fn sort_preserving_merge_exec_with_fetch( - sort_exprs: impl IntoIterator, + ordering: LexOrdering, input: Arc, fetch: usize, ) -> Arc { - let sort_exprs = sort_exprs.into_iter().collect(); - Arc::new(SortPreservingMergeExec::new(sort_exprs, input).with_fetch(Some(fetch))) + Arc::new(SortPreservingMergeExec::new(ordering, input).with_fetch(Some(fetch))) } pub fn union_exec(input: Vec>) -> Arc { - Arc::new(UnionExec::new(input)) -} - -pub fn limit_exec(input: Arc) -> Arc { - global_limit_exec(local_limit_exec(input)) + UnionExec::try_new(input).unwrap() } -pub fn local_limit_exec(input: Arc) -> Arc { - Arc::new(LocalLimitExec::new(input, 100)) +pub fn local_limit_exec( + input: Arc, + fetch: usize, +) -> Arc { + Arc::new(LocalLimitExec::new(input, fetch)) } -pub fn global_limit_exec(input: Arc) -> Arc { - Arc::new(GlobalLimitExec::new(input, 0, 
Some(100))) +pub fn global_limit_exec( + input: Arc, + skip: usize, + fetch: Option, +) -> Arc { + Arc::new(GlobalLimitExec::new(input, skip, fetch)) } pub fn repartition_exec(input: Arc) -> Arc { @@ -335,30 +362,43 @@ pub fn aggregate_exec(input: Arc) -> Arc { ) } -pub fn coalesce_batches_exec(input: Arc) -> Arc { - Arc::new(CoalesceBatchesExec::new(input, 128)) +pub fn sort_exec( + ordering: LexOrdering, + input: Arc, +) -> Arc { + sort_exec_with_fetch(ordering, None, input) } -pub fn sort_exec( - sort_exprs: impl IntoIterator, +pub fn sort_exec_with_preserve_partitioning( + ordering: LexOrdering, input: Arc, ) -> Arc { - sort_exec_with_fetch(sort_exprs, None, input) + Arc::new(SortExec::new(ordering, input).with_preserve_partitioning(true)) } pub fn sort_exec_with_fetch( - sort_exprs: impl IntoIterator, + ordering: LexOrdering, fetch: Option, input: Arc, ) -> Arc { - let sort_exprs = sort_exprs.into_iter().collect(); - Arc::new(SortExec::new(sort_exprs, input).with_fetch(fetch)) + Arc::new(SortExec::new(ordering, input).with_fetch(fetch)) +} + +pub fn projection_exec( + expr: Vec<(Arc, String)>, + input: Arc, +) -> Result> { + let proj_exprs: Vec = expr + .into_iter() + .map(|(expr, alias)| ProjectionExpr { expr, alias }) + .collect(); + Ok(Arc::new(ProjectionExec::try_new(proj_exprs, input)?)) } /// A test [`ExecutionPlan`] whose requirements can be configured. 
#[derive(Debug)] pub struct RequirementsTestExec { - required_input_ordering: LexOrdering, + required_input_ordering: Option, maintains_input_order: bool, input: Arc, } @@ -366,7 +406,7 @@ pub struct RequirementsTestExec { impl RequirementsTestExec { pub fn new(input: Arc) -> Self { Self { - required_input_ordering: LexOrdering::default(), + required_input_ordering: None, maintains_input_order: true, input, } @@ -375,7 +415,7 @@ impl RequirementsTestExec { /// sets the required input ordering pub fn with_required_input_ordering( mut self, - required_input_ordering: LexOrdering, + required_input_ordering: Option, ) -> Self { self.required_input_ordering = required_input_ordering; self @@ -416,13 +456,16 @@ impl ExecutionPlan for RequirementsTestExec { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { self.input.properties() } - fn required_input_ordering(&self) -> Vec> { - let requirement = LexRequirement::from(self.required_input_ordering.clone()); - vec![Some(requirement)] + fn required_input_ordering(&self) -> Vec> { + vec![ + self.required_input_ordering + .as_ref() + .map(|ordering| OrderingRequirements::from(ordering.clone())), + ] } fn maintains_input_order(&self) -> Vec { @@ -451,6 +494,20 @@ impl ExecutionPlan for RequirementsTestExec { ) -> Result { unimplemented!("Test exec does not support execution") } + + fn apply_expressions( + &self, + f: &mut dyn FnMut(&dyn PhysicalExpr) -> Result, + ) -> Result { + // Visit expressions in required_input_ordering if present + let mut tnr = TreeNodeRecursion::Continue; + if let Some(ordering) = &self.required_input_ordering { + for sort_expr in ordering { + tnr = tnr.visit_sibling(|| f(sort_expr.expr.as_ref()))?; + } + } + Ok(tnr) + } } /// A [`PlanContext`] object is susceptible to being left in an inconsistent state after @@ -479,13 +536,6 @@ pub fn check_integrity(context: PlanContext) -> Result Vec<&str> { - plan.split('\n') - .map(|s| s.trim()) - .filter(|s| !s.is_empty()) - 
.collect() -} - // construct a stream partition for test purposes #[derive(Debug)] pub struct TestStreamPartition { @@ -501,13 +551,28 @@ impl PartitionStream for TestStreamPartition { } } -/// Create an unbounded stream exec +/// Create an unbounded stream table without data ordering. +pub fn stream_exec(schema: &SchemaRef) -> Arc { + Arc::new( + StreamingTableExec::try_new( + Arc::clone(schema), + vec![Arc::new(TestStreamPartition { + schema: Arc::clone(schema), + }) as _], + None, + vec![], + true, + None, + ) + .unwrap(), + ) +} + +/// Create an unbounded stream table with data ordering. pub fn stream_exec_ordered( schema: &SchemaRef, - sort_exprs: impl IntoIterator, + ordering: LexOrdering, ) -> Arc { - let sort_exprs = sort_exprs.into_iter().collect(); - Arc::new( StreamingTableExec::try_new( Arc::clone(schema), @@ -515,7 +580,7 @@ pub fn stream_exec_ordered( schema: Arc::clone(schema), }) as _], None, - vec![sort_exprs], + vec![ordering], true, None, ) @@ -523,12 +588,11 @@ pub fn stream_exec_ordered( ) } -// Creates a stream exec source for the test purposes +/// Create an unbounded stream table with data ordering and built-in projection. 
pub fn stream_exec_ordered_with_projection( schema: &SchemaRef, - sort_exprs: impl IntoIterator, + ordering: LexOrdering, ) -> Arc { - let sort_exprs = sort_exprs.into_iter().collect(); let projection: Vec = vec![0, 2, 3]; Arc::new( @@ -538,7 +602,7 @@ pub fn stream_exec_ordered_with_projection( schema: Arc::clone(schema), }) as _], Some(&projection), - vec![sort_exprs], + vec![ordering], true, None, ) @@ -585,25 +649,15 @@ pub fn build_group_by(input_schema: &SchemaRef, columns: Vec) -> Physica PhysicalGroupBy::new_single(group_by_expr.clone()) } -pub fn assert_plan_matches_expected( - plan: &Arc, - expected: &[&str], -) -> Result<()> { - let expected_lines: Vec<&str> = expected.to_vec(); +pub fn get_optimized_plan(plan: &Arc) -> Result { let config = ConfigOptions::new(); let optimized = LimitedDistinctAggregation::new().optimize(Arc::clone(plan), &config)?; let optimized_result = displayable(optimized.as_ref()).indent(true).to_string(); - let actual_lines = trim_plan_display(&optimized_result); - assert_eq!( - &expected_lines, &actual_lines, - "\n\nexpected:\n\n{expected_lines:#?}\nactual:\n\n{actual_lines:#?}\n\n" - ); - - Ok(()) + Ok(optimized_result) } /// Describe the type of aggregate being tested @@ -659,3 +713,300 @@ impl TestAggregate { } } } + +/// A harness for testing physical optimizers. 
+#[derive(Debug)] +pub struct OptimizationTest { + input: Vec, + output: Result, String>, +} + +impl OptimizationTest { + pub fn new( + input_plan: Arc, + opt: O, + enable_sort_pushdown: bool, + ) -> Self + where + O: PhysicalOptimizerRule, + { + let input = format_execution_plan(&input_plan); + let input_schema = input_plan.schema(); + + let mut config = ConfigOptions::new(); + config.optimizer.enable_sort_pushdown = enable_sort_pushdown; + let output_result = opt.optimize(input_plan, &config); + let output = output_result + .and_then(|plan| { + if opt.schema_check() && (plan.schema() != input_schema) { + internal_err!( + "Schema mismatch:\n\nBefore:\n{:?}\n\nAfter:\n{:?}", + input_schema, + plan.schema() + ) + } else { + Ok(plan) + } + }) + .map(|plan| format_execution_plan(&plan)) + .map_err(|e| e.to_string()); + + Self { input, output } + } +} + +impl Display for OptimizationTest { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + writeln!(f, "OptimizationTest:")?; + writeln!(f, " input:")?; + for line in &self.input { + writeln!(f, " - {line}")?; + } + writeln!(f, " output:")?; + match &self.output { + Ok(output) => { + writeln!(f, " Ok:")?; + for line in output { + writeln!(f, " - {line}")?; + } + } + Err(err) => { + writeln!(f, " Err: {err}")?; + } + } + Ok(()) + } +} + +pub fn format_execution_plan(plan: &Arc) -> Vec { + format_lines(&displayable(plan.as_ref()).indent(false).to_string()) +} + +fn format_lines(s: &str) -> Vec { + s.trim().split('\n').map(|s| s.to_string()).collect() +} + +/// Create a simple ProjectionExec with column indices (simplified version) +pub fn simple_projection_exec( + input: Arc, + columns: Vec, +) -> Arc { + let schema = input.schema(); + let exprs: Vec<(Arc, String)> = columns + .iter() + .map(|&i| { + let field = schema.field(i); + ( + Arc::new(expressions::Column::new(field.name(), i)) + as Arc, + field.name().to_string(), + ) + }) + .collect(); + + projection_exec(exprs, input).unwrap() +} + +/// Create a 
ProjectionExec with column aliases +pub fn projection_exec_with_alias( + input: Arc, + columns: Vec<(usize, &str)>, +) -> Arc { + let schema = input.schema(); + let exprs: Vec<(Arc, String)> = columns + .iter() + .map(|&(i, alias)| { + ( + Arc::new(expressions::Column::new(schema.field(i).name(), i)) + as Arc, + alias.to_string(), + ) + }) + .collect(); + + projection_exec(exprs, input).unwrap() +} + +/// Create a sort expression with custom name and index +pub fn sort_expr_named(name: &str, index: usize) -> PhysicalSortExpr { + PhysicalSortExpr { + expr: Arc::new(expressions::Column::new(name, index)), + options: SortOptions::default(), + } +} + +/// A test data source that can display any requested ordering +/// This is useful for testing sort pushdown behavior +#[derive(Debug, Clone)] +pub struct TestScan { + schema: SchemaRef, + output_ordering: Vec, + plan_properties: Arc, + // Store the requested ordering for display + requested_ordering: Option, +} + +impl TestScan { + /// Create a new TestScan with the given schema and output ordering + pub fn new(schema: SchemaRef, output_ordering: Vec) -> Self { + let eq_properties = if !output_ordering.is_empty() { + // Convert Vec to the format expected by new_with_orderings + // We need to extract the inner Vec from each LexOrdering + let orderings: Vec> = output_ordering + .iter() + .map(|lex_ordering| { + // LexOrdering implements IntoIterator, so we can collect it + lex_ordering.iter().cloned().collect() + }) + .collect(); + + EquivalenceProperties::new_with_orderings(Arc::clone(&schema), orderings) + } else { + EquivalenceProperties::new(Arc::clone(&schema)) + }; + + let plan_properties = PlanProperties::new( + eq_properties, + Partitioning::UnknownPartitioning(1), + EmissionType::Incremental, + Boundedness::Bounded, + ); + + Self { + schema, + output_ordering, + plan_properties: Arc::new(plan_properties), + requested_ordering: None, + } + } + + /// Create a TestScan with a single output ordering + pub fn 
with_ordering(schema: SchemaRef, ordering: LexOrdering) -> Self { + Self::new(schema, vec![ordering]) + } +} + +impl DisplayAs for TestScan { + fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "TestScan")?; + if !self.output_ordering.is_empty() { + write!(f, ": output_ordering=[")?; + // Format the ordering in a readable way + for (i, sort_expr) in self.output_ordering[0].iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{sort_expr}")?; + } + write!(f, "]")?; + } + // This is the key part - show what ordering was requested + if let Some(ref req) = self.requested_ordering { + write!(f, ", requested_ordering=[")?; + for (i, sort_expr) in req.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{sort_expr}")?; + } + write!(f, "]")?; + } + Ok(()) + } + DisplayFormatType::TreeRender => { + write!(f, "TestScan") + } + } + } +} + +impl ExecutionPlan for TestScan { + fn name(&self) -> &str { + "TestScan" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &Arc { + &self.plan_properties + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + if children.is_empty() { + Ok(self) + } else { + internal_err!("TestScan should have no children") + } + } + + fn execute( + &self, + _partition: usize, + _context: Arc, + ) -> Result { + internal_err!("TestScan is for testing optimizer only, not for execution") + } + + fn partition_statistics(&self, _partition: Option) -> Result> { + Ok(Arc::new(Statistics::new_unknown(&self.schema))) + } + + // This is the key method - implement sort pushdown + fn try_pushdown_sort( + &self, + order: &[PhysicalSortExpr], + ) -> Result>> { + // For testing purposes, accept ANY ordering request + // and create a new TestScan that shows what was requested + let requested_ordering = 
LexOrdering::new(order.to_vec()); + + let mut new_scan = self.clone(); + new_scan.requested_ordering = requested_ordering; + + // Always return Inexact to keep the Sort node (like Phase 1 behavior) + Ok(SortOrderPushdownResult::Inexact { + inner: Arc::new(new_scan), + }) + } + + fn apply_expressions( + &self, + f: &mut dyn FnMut(&dyn PhysicalExpr) -> Result, + ) -> Result { + // Visit expressions in output_ordering + let mut tnr = TreeNodeRecursion::Continue; + for ordering in &self.output_ordering { + for sort_expr in ordering { + tnr = tnr.visit_sibling(|| f(sort_expr.expr.as_ref()))?; + } + } + + // Visit expressions in requested_ordering if present + if let Some(ordering) = &self.requested_ordering { + for sort_expr in ordering { + tnr = tnr.visit_sibling(|| f(sort_expr.expr.as_ref()))?; + } + } + + Ok(tnr) + } +} + +/// Helper function to create a TestScan with ordering +pub fn test_scan_with_ordering( + schema: SchemaRef, + ordering: LexOrdering, +) -> Arc { + Arc::new(TestScan::with_ordering(schema, ordering)) +} diff --git a/datafusion/core/tests/physical_optimizer/window_optimize.rs b/datafusion/core/tests/physical_optimizer/window_optimize.rs new file mode 100644 index 0000000000000..796f6b6259716 --- /dev/null +++ b/datafusion/core/tests/physical_optimizer/window_optimize.rs @@ -0,0 +1,95 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#[cfg(test)] +mod test { + use arrow::array::{Int32Array, RecordBatch}; + use arrow_schema::{DataType, Field, Schema}; + use datafusion_common::Result; + use datafusion_datasource::memory::MemorySourceConfig; + use datafusion_datasource::source::DataSourceExec; + use datafusion_execution::TaskContext; + use datafusion_expr::WindowFrame; + use datafusion_functions_aggregate::count::count_udaf; + use datafusion_physical_expr::aggregate::AggregateExprBuilder; + use datafusion_physical_expr::expressions::{Column, col}; + use datafusion_physical_expr::window::PlainAggregateWindowExpr; + use datafusion_physical_plan::windows::BoundedWindowAggExec; + use datafusion_physical_plan::{ExecutionPlan, InputOrderMode, common}; + use std::sync::Arc; + + /// Test case for + #[tokio::test] + async fn test_window_constant_aggregate() -> Result<()> { + let source = mock_data()?; + let schema = source.schema(); + let c = Arc::new(Column::new("b", 1)); + let cnt = AggregateExprBuilder::new(count_udaf(), vec![c]) + .schema(schema.clone()) + .alias("t") + .build()?; + let partition = [col("a", &schema)?]; + let frame = WindowFrame::new(None); + let plain = PlainAggregateWindowExpr::new( + Arc::new(cnt), + &partition, + &[], + Arc::new(frame), + None, + ); + + let bounded_agg_exec = BoundedWindowAggExec::try_new( + vec![Arc::new(plain)], + source, + InputOrderMode::Linear, + true, + )?; + let task_ctx = Arc::new(TaskContext::default()); + common::collect(bounded_agg_exec.execute(0, task_ctx)?).await?; + + Ok(()) + } + + pub fn mock_data() -> Result> { + let schema = 
Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ])); + + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![ + Some(1), + Some(1), + Some(3), + Some(2), + Some(1), + ])), + Arc::new(Int32Array::from(vec![ + Some(1), + Some(6), + Some(2), + Some(8), + Some(9), + ])), + ], + )?; + + MemorySourceConfig::try_new_exec(&[vec![batch]], Arc::clone(&schema), None) + } +} diff --git a/datafusion/core/tests/set_comparison.rs b/datafusion/core/tests/set_comparison.rs new file mode 100644 index 0000000000000..464d6c937b328 --- /dev/null +++ b/datafusion/core/tests/set_comparison.rs @@ -0,0 +1,193 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::sync::Arc; + +use arrow::array::{Int32Array, StringArray}; +use arrow::datatypes::{DataType, Field, Schema}; +use arrow::record_batch::RecordBatch; +use datafusion::prelude::SessionContext; +use datafusion_common::{Result, assert_batches_eq, assert_contains}; + +fn build_table(values: &[i32]) -> Result { + let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int32, true)])); + let array = + Arc::new(Int32Array::from(values.to_vec())) as Arc; + RecordBatch::try_new(schema, vec![array]).map_err(Into::into) +} + +#[tokio::test] +async fn set_comparison_any() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.register_batch("t", build_table(&[1, 6, 10])?)?; + // Include a NULL in the subquery input to ensure we propagate UNKNOWN correctly. + ctx.register_batch("s", { + let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int32, true)])); + let array = Arc::new(Int32Array::from(vec![Some(5), None])) + as Arc; + RecordBatch::try_new(schema, vec![array])? 
+ })?; + + let df = ctx + .sql("select v from t where v > any(select v from s)") + .await?; + let results = df.collect().await?; + + assert_batches_eq!( + &["+----+", "| v |", "+----+", "| 6 |", "| 10 |", "+----+",], + &results + ); + Ok(()) +} + +#[tokio::test] +async fn set_comparison_any_aggregate_subquery() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.register_batch("t", build_table(&[1, 7])?)?; + ctx.register_batch("s", build_table(&[1, 2, 3])?)?; + + let df = ctx + .sql( + "select v from t where v > any(select sum(v) from s group by v % 2) order by v", + ) + .await?; + let results = df.collect().await?; + + assert_batches_eq!(&["+---+", "| v |", "+---+", "| 7 |", "+---+",], &results); + Ok(()) +} + +#[tokio::test] +async fn set_comparison_all_empty() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.register_batch("t", build_table(&[1, 6, 10])?)?; + ctx.register_batch( + "e", + RecordBatch::new_empty(Arc::new(Schema::new(vec![Field::new( + "v", + DataType::Int32, + true, + )]))), + )?; + + let df = ctx + .sql("select v from t where v < all(select v from e)") + .await?; + let results = df.collect().await?; + + assert_batches_eq!( + &[ + "+----+", "| v |", "+----+", "| 1 |", "| 6 |", "| 10 |", "+----+", + ], + &results + ); + Ok(()) +} + +#[tokio::test] +async fn set_comparison_type_mismatch() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.register_batch("t", build_table(&[1])?)?; + ctx.register_batch("strings", { + let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)])); + let array = Arc::new(StringArray::from(vec![Some("a"), Some("b")])) + as Arc; + RecordBatch::try_new(schema, vec![array])? 
+ })?; + + let df = ctx + .sql("select v from t where v > any(select s from strings)") + .await?; + let err = df.collect().await.unwrap_err(); + assert_contains!( + err.to_string(), + "expr type Int32 can't cast to Utf8 in SetComparison" + ); + Ok(()) +} + +#[tokio::test] +async fn set_comparison_multiple_operators() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.register_batch("t", build_table(&[1, 2, 3, 4])?)?; + ctx.register_batch("s", build_table(&[2, 3])?)?; + + let df = ctx + .sql("select v from t where v = any(select v from s) order by v") + .await?; + let results = df.collect().await?; + assert_batches_eq!( + &["+---+", "| v |", "+---+", "| 2 |", "| 3 |", "+---+",], + &results + ); + + let df = ctx + .sql("select v from t where v != all(select v from s) order by v") + .await?; + let results = df.collect().await?; + assert_batches_eq!( + &["+---+", "| v |", "+---+", "| 1 |", "| 4 |", "+---+",], + &results + ); + + let df = ctx + .sql("select v from t where v >= all(select v from s) order by v") + .await?; + let results = df.collect().await?; + assert_batches_eq!( + &["+---+", "| v |", "+---+", "| 3 |", "| 4 |", "+---+",], + &results + ); + + let df = ctx + .sql("select v from t where v <= any(select v from s) order by v") + .await?; + let results = df.collect().await?; + assert_batches_eq!( + &[ + "+---+", "| v |", "+---+", "| 1 |", "| 2 |", "| 3 |", "+---+", + ], + &results + ); + Ok(()) +} + +#[tokio::test] +async fn set_comparison_null_semantics_all() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.register_batch("t", build_table(&[5])?)?; + ctx.register_batch("s", { + let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int32, true)])); + let array = Arc::new(Int32Array::from(vec![Some(1), None])) + as Arc; + RecordBatch::try_new(schema, vec![array])? 
+ })?; + + let df = ctx + .sql("select v from t where v != all(select v from s)") + .await?; + let results = df.collect().await?; + let row_count: usize = results.iter().map(|batch| batch.num_rows()).sum(); + assert_eq!(0, row_count); + Ok(()) +} diff --git a/datafusion/core/tests/sql/aggregates.rs b/datafusion/core/tests/sql/aggregates/basic.rs similarity index 78% rename from datafusion/core/tests/sql/aggregates.rs rename to datafusion/core/tests/sql/aggregates/basic.rs index 52372e01d41ac..d1b376b735ab9 100644 --- a/datafusion/core/tests/sql/aggregates.rs +++ b/datafusion/core/tests/sql/aggregates/basic.rs @@ -16,7 +16,10 @@ // under the License. use super::*; -use datafusion::scalar::ScalarValue; +use datafusion::common::test_util::batches_to_string; +use datafusion_catalog::MemTable; +use datafusion_common::ScalarValue; +use insta::assert_snapshot; #[tokio::test] async fn csv_query_array_agg_distinct() -> Result<()> { @@ -45,11 +48,11 @@ async fn csv_query_array_agg_distinct() -> Result<()> { let column = actual[0].column(0); assert_eq!(column.len(), 1); let scalar_vec = ScalarValue::convert_array_to_scalar_vec(&column)?; - let mut scalars = scalar_vec[0].clone(); + let mut scalars = scalar_vec[0].as_ref().unwrap().clone(); // workaround lack of Ord of ScalarValue let cmp = |a: &ScalarValue, b: &ScalarValue| { - a.partial_cmp(b).expect("Can compare ScalarValues") + a.try_cmp(b).expect("Can compare ScalarValues") }; scalars.sort_by(cmp); assert_eq!( @@ -321,3 +324,120 @@ async fn test_accumulator_row_accumulator() -> Result<()> { Ok(()) } + +/// Test that COUNT(DISTINCT) correctly handles dictionary arrays with all null values. +/// Verifies behavior across both single and multiple partitions. 
+#[tokio::test] +async fn count_distinct_dictionary_all_null_values() -> Result<()> { + let n: usize = 5; + let num = Arc::new(Int32Array::from_iter(0..n as i32)) as ArrayRef; + + // Create dictionary where all indices point to a null value (index 0) + let dict_values = StringArray::from(vec![None, Some("abc")]); + let dict_indices = Int32Array::from(vec![0; n]); + let dict = DictionaryArray::new(dict_indices, Arc::new(dict_values)); + + let schema = Arc::new(Schema::new(vec![ + Field::new("num1", DataType::Int32, false), + Field::new("num2", DataType::Int32, false), + Field::new( + "dict", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + true, + ), + ])); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![num.clone(), num.clone(), Arc::new(dict)], + )?; + + // Test with single partition + let ctx = + SessionContext::new_with_config(SessionConfig::new().with_target_partitions(1)); + let provider = MemTable::try_new(schema.clone(), vec![vec![batch.clone()]])?; + ctx.register_table("t", Arc::new(provider))?; + + let df = ctx + .sql("SELECT count(distinct dict) as cnt, count(num2) FROM t GROUP BY num1") + .await?; + let results = df.collect().await?; + + assert_snapshot!( + batches_to_string(&results), + @r" + +-----+---------------+ + | cnt | count(t.num2) | + +-----+---------------+ + | 0 | 1 | + | 0 | 1 | + | 0 | 1 | + | 0 | 1 | + | 0 | 1 | + +-----+---------------+ + " + ); + + // Test with multiple partitions + let ctx_multi = + SessionContext::new_with_config(SessionConfig::new().with_target_partitions(2)); + let provider_multi = MemTable::try_new(schema, vec![vec![batch]])?; + ctx_multi.register_table("t", Arc::new(provider_multi))?; + + let df_multi = ctx_multi + .sql("SELECT count(distinct dict) as cnt, count(num2) FROM t GROUP BY num1") + .await?; + let results_multi = df_multi.collect().await?; + + // Results should be identical across partition configurations + assert_eq!( + batches_to_string(&results), + 
batches_to_string(&results_multi) + ); + + Ok(()) +} + +/// Test COUNT(DISTINCT) with mixed null and non-null dictionary values +#[tokio::test] +async fn count_distinct_dictionary_mixed_values() -> Result<()> { + let n: usize = 6; + let num = Arc::new(Int32Array::from_iter(0..n as i32)) as ArrayRef; + + // Dictionary values array with nulls and non-nulls + let dict_values = StringArray::from(vec![None, Some("abc"), Some("def"), None]); + // Create indices that point to both null and non-null values + let dict_indices = Int32Array::from(vec![0, 1, 2, 0, 1, 3]); + let dict = DictionaryArray::new(dict_indices, Arc::new(dict_values)); + + let schema = Arc::new(Schema::new(vec![ + Field::new("num1", DataType::Int32, false), + Field::new( + "dict", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + true, + ), + ])); + + let batch = RecordBatch::try_new(schema.clone(), vec![num, Arc::new(dict)])?; + let provider = MemTable::try_new(schema, vec![vec![batch]])?; + let ctx = SessionContext::new(); + ctx.register_table("t", Arc::new(provider))?; + + // COUNT(DISTINCT) should only count non-null values "abc" and "def" + let df = ctx.sql("SELECT count(distinct dict) FROM t").await?; + let results = df.collect().await?; + + assert_snapshot!( + batches_to_string(&results), + @r" + +------------------------+ + | count(DISTINCT t.dict) | + +------------------------+ + | 2 | + +------------------------+ + " + ); + + Ok(()) +} diff --git a/datafusion/core/tests/sql/aggregates/dict_nulls.rs b/datafusion/core/tests/sql/aggregates/dict_nulls.rs new file mode 100644 index 0000000000000..f9e15a71a20f8 --- /dev/null +++ b/datafusion/core/tests/sql/aggregates/dict_nulls.rs @@ -0,0 +1,454 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::*; +use datafusion::common::test_util::batches_to_string; +use insta::assert_snapshot; + +/// Comprehensive test for aggregate functions with null values and dictionary columns +/// Tests COUNT, SUM, MIN, and MEDIAN null handling in single comprehensive test +#[tokio::test] +async fn test_aggregates_null_handling_comprehensive() -> Result<()> { + let test_data_basic = TestData::new(); + let test_data_extended = TestData::new_extended(); + let test_data_min_max = TestData::new_for_min_max(); + let test_data_median = TestData::new_for_median(); + + // Test COUNT null exclusion with basic data + let sql_count = "SELECT dict_null_keys, COUNT(value) as cnt FROM t GROUP BY dict_null_keys ORDER BY dict_null_keys NULLS FIRST"; + let results_count = run_snapshot_test(&test_data_basic, sql_count).await?; + + assert_snapshot!( + batches_to_string(&results_count), + @r" + +----------------+-----+ + | dict_null_keys | cnt | + +----------------+-----+ + | | 0 | + | group_a | 2 | + | group_b | 1 | + +----------------+-----+ + " + ); + + // Test SUM null handling with extended data + let sql_sum = "SELECT dict_null_vals, SUM(value) as total FROM t GROUP BY dict_null_vals ORDER BY dict_null_vals NULLS FIRST"; + let results_sum = run_snapshot_test(&test_data_extended, sql_sum).await?; + + assert_snapshot!( + batches_to_string(&results_sum), + @r" + +----------------+-------+ + | 
dict_null_vals | total | + +----------------+-------+ + | | 4 | + | group_x | 4 | + | group_y | 2 | + | group_z | 5 | + +----------------+-------+ + " + ); + + // Test MIN null handling with min/max data + let sql_min = "SELECT dict_null_keys, MIN(value) as minimum FROM t GROUP BY dict_null_keys ORDER BY dict_null_keys NULLS FIRST"; + let results_min = run_snapshot_test(&test_data_min_max, sql_min).await?; + + assert_snapshot!( + batches_to_string(&results_min), + @r" + +----------------+---------+ + | dict_null_keys | minimum | + +----------------+---------+ + | | 2 | + | group_a | 3 | + | group_b | 1 | + | group_c | 7 | + +----------------+---------+ + " + ); + + // Test MEDIAN null handling with median data + let sql_median = "SELECT dict_null_vals, MEDIAN(value) as median_value FROM t GROUP BY dict_null_vals ORDER BY dict_null_vals NULLS FIRST"; + let results_median = run_snapshot_test(&test_data_median, sql_median).await?; + + assert_snapshot!( + batches_to_string(&results_median), + @r" + +----------------+--------------+ + | dict_null_vals | median_value | + +----------------+--------------+ + | | 3 | + | group_x | 1 | + | group_y | 5 | + | group_z | 7 | + +----------------+--------------+ + "); + + Ok(()) +} + +/// Test FIRST_VAL and LAST_VAL with null values and GROUP BY dict with null keys and null values - may return null if first/last value is null (single and multiple partitions) +#[tokio::test] +async fn test_first_last_val_null_handling() -> Result<()> { + let test_data = TestData::new_for_first_last(); + + // Test FIRST_VALUE and LAST_VALUE with window functions over groups + let sql = "SELECT dict_null_keys, value, FIRST_VALUE(value) OVER (PARTITION BY dict_null_keys ORDER BY value NULLS FIRST) as first_val, LAST_VALUE(value) OVER (PARTITION BY dict_null_keys ORDER BY value NULLS FIRST ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) as last_val FROM t ORDER BY dict_null_keys NULLS FIRST, value NULLS FIRST"; + + let results_single = 
run_snapshot_test(&test_data, sql).await?; + + assert_snapshot!(batches_to_string(&results_single), @r" + +----------------+-------+-----------+----------+ + | dict_null_keys | value | first_val | last_val | + +----------------+-------+-----------+----------+ + | | 1 | 1 | 3 | + | | 3 | 1 | 3 | + | group_a | | | | + | group_a | | | | + | group_b | 2 | 2 | 2 | + +----------------+-------+-----------+----------+ + "); + + Ok(()) +} + +/// Test FIRST_VALUE and LAST_VALUE with ORDER BY - comprehensive null handling +#[tokio::test] +async fn test_first_last_value_order_by_null_handling() -> Result<()> { + let ctx = SessionContext::new(); + + // Create test data with nulls mixed in + let dict_keys = create_test_dict( + &[Some("group_a"), Some("group_b"), Some("group_c")], + &[Some(0), Some(1), Some(2), Some(0), Some(1)], + ); + + let values = Int32Array::from(vec![None, Some(10), Some(20), Some(5), None]); + + let schema = Arc::new(Schema::new(vec![ + Field::new("dict_group", string_dict_type(), true), + Field::new("value", DataType::Int32, true), + ])); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(dict_keys), Arc::new(values)], + )?; + + let table = MemTable::try_new(schema, vec![vec![batch]])?; + ctx.register_table("test_data", Arc::new(table))?; + + // Test all combinations of FIRST_VALUE and LAST_VALUE with null handling + let sql = "SELECT + dict_group, + value, + FIRST_VALUE(value IGNORE NULLS) OVER (ORDER BY value NULLS LAST) as first_ignore_nulls, + FIRST_VALUE(value RESPECT NULLS) OVER (ORDER BY value NULLS FIRST) as first_respect_nulls, + LAST_VALUE(value IGNORE NULLS) OVER (ORDER BY value NULLS LAST ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) as last_ignore_nulls, + LAST_VALUE(value RESPECT NULLS) OVER (ORDER BY value NULLS LAST ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) as last_respect_nulls + FROM test_data + ORDER BY value NULLS LAST"; + + let df = ctx.sql(sql).await?; + let results = 
df.collect().await?; + + assert_snapshot!( + batches_to_string(&results), + @r" + +------------+-------+--------------------+---------------------+-------------------+--------------------+ + | dict_group | value | first_ignore_nulls | first_respect_nulls | last_ignore_nulls | last_respect_nulls | + +------------+-------+--------------------+---------------------+-------------------+--------------------+ + | group_a | 5 | 5 | | 20 | | + | group_b | 10 | 5 | | 20 | | + | group_c | 20 | 5 | | 20 | | + | group_a | | 5 | | 20 | | + | group_b | | 5 | | 20 | | + +------------+-------+--------------------+---------------------+-------------------+--------------------+ + " + ); + + Ok(()) +} + +/// Test GROUP BY with dictionary columns containing null keys and values for FIRST_VALUE/LAST_VALUE +#[tokio::test] +async fn test_first_last_value_group_by_dict_nulls() -> Result<()> { + let ctx = SessionContext::new(); + + // Create dictionary with null keys + let dict_null_keys = create_test_dict( + &[Some("group_a"), Some("group_b")], + &[ + Some(0), // group_a + None, // null key + Some(1), // group_b + None, // null key + Some(0), // group_a + ], + ); + + // Create dictionary with null values + let dict_null_vals = create_test_dict( + &[Some("val_x"), None, Some("val_y")], + &[ + Some(0), // val_x + Some(1), // null value + Some(2), // val_y + Some(1), // null value + Some(0), // val_x + ], + ); + + // Create test values + let values = Int32Array::from(vec![Some(10), Some(20), Some(30), Some(40), Some(50)]); + + let schema = Arc::new(Schema::new(vec![ + Field::new("dict_null_keys", string_dict_type(), true), + Field::new("dict_null_vals", string_dict_type(), true), + Field::new("value", DataType::Int32, true), + ])); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(dict_null_keys), + Arc::new(dict_null_vals), + Arc::new(values), + ], + )?; + + let table = MemTable::try_new(schema, vec![vec![batch]])?; + ctx.register_table("test_data", 
Arc::new(table))?; + + // Test GROUP BY with null keys + let sql = "SELECT + dict_null_keys, + FIRST_VALUE(value) as first_val, + LAST_VALUE(value) as last_val, + COUNT(*) as cnt + FROM test_data + GROUP BY dict_null_keys + ORDER BY dict_null_keys NULLS FIRST"; + + let df = ctx.sql(sql).await?; + let results = df.collect().await?; + + assert_snapshot!( + batches_to_string(&results), + @r" + +----------------+-----------+----------+-----+ + | dict_null_keys | first_val | last_val | cnt | + +----------------+-----------+----------+-----+ + | | 20 | 40 | 2 | + | group_a | 10 | 50 | 2 | + | group_b | 30 | 30 | 1 | + +----------------+-----------+----------+-----+ + " + ); + + // Test GROUP BY with null values in dictionary + let sql2 = "SELECT + dict_null_vals, + FIRST_VALUE(value) as first_val, + LAST_VALUE(value) as last_val, + COUNT(*) as cnt + FROM test_data + GROUP BY dict_null_vals + ORDER BY dict_null_vals NULLS FIRST"; + + let df2 = ctx.sql(sql2).await?; + let results2 = df2.collect().await?; + + assert_snapshot!( + batches_to_string(&results2), + @r" + +----------------+-----------+----------+-----+ + | dict_null_vals | first_val | last_val | cnt | + +----------------+-----------+----------+-----+ + | | 20 | 40 | 2 | + | val_x | 10 | 50 | 2 | + | val_y | 30 | 30 | 1 | + +----------------+-----------+----------+-----+ + " + ); + + Ok(()) +} + +/// Test MAX with dictionary columns containing null keys and values as specified in the SQL query +#[tokio::test] +async fn test_max_with_fuzz_table_dict_nulls() -> Result<()> { + let (ctx_single, ctx_multi) = setup_fuzz_test_contexts().await?; + + // Execute the SQL query with MAX aggregations + let sql = "SELECT + u8_low, + dictionary_utf8_low, + utf8_low, + max(utf8_low) as col1, + max(utf8) as col2 + FROM + fuzz_table + GROUP BY + u8_low, + dictionary_utf8_low, + utf8_low + ORDER BY u8_low, dictionary_utf8_low NULLS FIRST, utf8_low"; + + let results = test_query_consistency(&ctx_single, &ctx_multi, sql).await?; + + 
assert_snapshot!( + batches_to_string(&results), + @r" + +--------+---------------------+----------+-------+---------+ + | u8_low | dictionary_utf8_low | utf8_low | col1 | col2 | + +--------+---------------------+----------+-------+---------+ + | 1 | | str_b | str_b | value_2 | + | 1 | dict_a | str_a | str_a | value_5 | + | 2 | | str_c | str_c | value_7 | + | 2 | | str_d | str_d | value_4 | + | 2 | dict_b | str_c | str_c | value_3 | + | 3 | | str_e | str_e | | + | 3 | dict_c | str_f | str_f | value_6 | + +--------+---------------------+----------+-------+---------+ + "); + + Ok(()) +} + +/// Test MIN with fuzz table containing dictionary columns with null keys and values and timestamp data (single and multiple partitions) +#[tokio::test] +async fn test_min_timestamp_with_fuzz_table_dict_nulls() -> Result<()> { + let (ctx_single, ctx_multi) = setup_fuzz_timestamp_test_contexts().await?; + + // Execute the SQL query with MIN aggregation on timestamp + let sql = "SELECT + utf8_low, + u8_low, + dictionary_utf8_low, + min(timestamp_us) as col1 + FROM + fuzz_table + GROUP BY + utf8_low, + u8_low, + dictionary_utf8_low + ORDER BY utf8_low, u8_low, dictionary_utf8_low NULLS FIRST"; + + let results = test_query_consistency(&ctx_single, &ctx_multi, sql).await?; + + assert_snapshot!( + batches_to_string(&results), + @r" + +----------+--------+---------------------+-------------------------+ + | utf8_low | u8_low | dictionary_utf8_low | col1 | + +----------+--------+---------------------+-------------------------+ + | alpha | 10 | dict_x | 1970-01-01T00:00:01 | + | beta | 20 | | 1970-01-01T00:00:02 | + | delta | 20 | | 1970-01-01T00:00:03.500 | + | epsilon | 40 | | 1970-01-01T00:00:04 | + | gamma | 30 | dict_y | 1970-01-01T00:00:02.800 | + | zeta | 30 | dict_z | 1970-01-01T00:00:02.500 | + +----------+--------+---------------------+-------------------------+ + " + ); + + Ok(()) +} + +/// Test COUNT and COUNT DISTINCT with fuzz table containing dictionary columns with null keys 
and values (single and multiple partitions) +#[tokio::test] +async fn test_count_distinct_with_fuzz_table_dict_nulls() -> Result<()> { + let (ctx_single, ctx_multi) = setup_fuzz_count_test_contexts().await?; + + // Execute the SQL query with COUNT and COUNT DISTINCT aggregations + let sql = "SELECT + u8_low, + utf8_low, + dictionary_utf8_low, + count(duration_nanosecond) as col1, + count(DISTINCT large_binary) as col2 + FROM + fuzz_table + GROUP BY + u8_low, + utf8_low, + dictionary_utf8_low + ORDER BY u8_low, utf8_low, dictionary_utf8_low NULLS FIRST"; + + let results = test_query_consistency(&ctx_single, &ctx_multi, sql).await?; + + assert_snapshot!( + batches_to_string(&results), + @r" + +--------+----------+---------------------+------+------+ + | u8_low | utf8_low | dictionary_utf8_low | col1 | col2 | + +--------+----------+---------------------+------+------+ + | 5 | text_a | group_alpha | 3 | 1 | + | 10 | text_b | | 1 | 1 | + | 10 | text_d | | 2 | 0 | + | 15 | text_c | group_beta | 1 | 1 | + | 20 | text_e | | 0 | 1 | + | 25 | text_f | group_gamma | 1 | 1 | + +--------+----------+---------------------+------+------+ + " + ); + + Ok(()) +} + +/// Test MEDIAN and MEDIAN DISTINCT with fuzz table containing various numeric types and dictionary columns with null keys and values (single and multiple partitions) +#[tokio::test] +async fn test_median_distinct_with_fuzz_table_dict_nulls() -> Result<()> { + let (ctx_single, ctx_multi) = setup_fuzz_median_test_contexts().await?; + + // Execute the SQL query with MEDIAN and MEDIAN DISTINCT aggregations + let sql = "SELECT + u8_low, + dictionary_utf8_low, + median(DISTINCT u64) as col1, + median(DISTINCT u16) as col2, + median(u64) as col3, + median(decimal128) as col4, + median(DISTINCT u32) as col5 + FROM + fuzz_table + GROUP BY + u8_low, + dictionary_utf8_low + ORDER BY u8_low, dictionary_utf8_low NULLS FIRST"; + + let results = test_query_consistency(&ctx_single, &ctx_multi, sql).await?; + + assert_snapshot!( + 
batches_to_string(&results), + @r" + +--------+---------------------+------+------+------+--------+--------+ + | u8_low | dictionary_utf8_low | col1 | col2 | col3 | col4 | col5 | + +--------+---------------------+------+------+------+--------+--------+ + | 50 | | | 30 | | 987.65 | 400000 | + | 50 | group_three | 5000 | 50 | 5000 | 555.55 | 500000 | + | 75 | | 4000 | | 4000 | | 450000 | + | 100 | group_one | 1100 | 11 | 1000 | 123.45 | 110000 | + | 100 | group_two | 1500 | 15 | 1500 | 111.11 | 150000 | + | 200 | | 2500 | 22 | 2500 | 506.11 | 250000 | + +--------+---------------------+------+------+------+--------+--------+ + " + ); + + Ok(()) +} diff --git a/datafusion/core/tests/sql/aggregates/mod.rs b/datafusion/core/tests/sql/aggregates/mod.rs new file mode 100644 index 0000000000000..ede40d5c4ceca --- /dev/null +++ b/datafusion/core/tests/sql/aggregates/mod.rs @@ -0,0 +1,1026 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Aggregate function tests + +use super::*; +use arrow::{ + array::{ + Decimal128Array, DictionaryArray, DurationNanosecondArray, Int32Array, + LargeBinaryArray, StringArray, TimestampMicrosecondArray, UInt8Array, + UInt16Array, UInt32Array, UInt64Array, types::UInt32Type, + }, + datatypes::{DataType, Field, Schema, TimeUnit}, + record_batch::RecordBatch, +}; +use datafusion::{ + common::{Result, test_util::batches_to_string}, + execution::{config::SessionConfig, context::SessionContext}, +}; +use datafusion_catalog::MemTable; +use std::{cmp::min, sync::Arc}; +/// Helper function to create the commonly used UInt32 indexed UTF-8 dictionary data type +pub fn string_dict_type() -> DataType { + DataType::Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Utf8)) +} + +/// Helper functions for aggregate tests with dictionary columns and nulls +/// Creates a dictionary array with null values in the dictionary +pub fn create_test_dict( + values: &[Option<&str>], + indices: &[Option], +) -> DictionaryArray { + let dict_values = StringArray::from(values.to_vec()); + let dict_indices = UInt32Array::from(indices.to_vec()); + DictionaryArray::new(dict_indices, Arc::new(dict_values)) +} + +/// Creates test data with both dictionary columns and value column +pub struct TestData { + pub dict_null_keys: DictionaryArray, + pub dict_null_vals: DictionaryArray, + pub values: Int32Array, + pub schema: Arc, +} + +impl TestData { + pub fn new() -> Self { + // Create dictionary with null keys + let dict_null_keys = create_test_dict( + &[Some("group_a"), Some("group_b")], + &[ + Some(0), // group_a + None, // null key + Some(1), // group_b + None, // null key + Some(0), // group_a + ], + ); + + // Create dictionary with null values + let dict_null_vals = create_test_dict( + &[Some("group_x"), None, Some("group_y")], + &[ + Some(0), // group_x + Some(1), // null value + Some(2), // group_y + Some(1), // null value + Some(0), // group_x + ], + ); + + // Create test data with nulls + let 
values = Int32Array::from(vec![Some(1), None, Some(2), None, Some(3)]); + + let schema = Arc::new(Schema::new(vec![ + Field::new("dict_null_keys", string_dict_type(), true), + Field::new("dict_null_vals", string_dict_type(), true), + Field::new("value", DataType::Int32, true), + ])); + + Self { + dict_null_keys, + dict_null_vals, + values, + schema, + } + } + + /// Creates extended test data for more comprehensive testing + pub fn new_extended() -> Self { + // Create dictionary with null values in the dictionary array + let dict_null_vals = create_test_dict( + &[Some("group_a"), None, Some("group_b")], + &[ + Some(0), // group_a + Some(1), // null value + Some(2), // group_b + Some(1), // null value + Some(0), // group_a + Some(1), // null value + Some(2), // group_b + Some(1), // null value + ], + ); + + // Create dictionary with null keys + let dict_null_keys = create_test_dict( + &[Some("group_x"), Some("group_y"), Some("group_z")], + &[ + Some(0), // group_x + None, // null key + Some(1), // group_y + None, // null key + Some(0), // group_x + None, // null key + Some(2), // group_z + None, // null key + ], + ); + + // Create test data with nulls + let values = Int32Array::from(vec![ + Some(1), + None, + Some(2), + None, + Some(3), + Some(4), + Some(5), + None, + ]); + + let schema = Arc::new(Schema::new(vec![ + Field::new("dict_null_vals", string_dict_type(), true), + Field::new("dict_null_keys", string_dict_type(), true), + Field::new("value", DataType::Int32, true), + ])); + + Self { + dict_null_keys, + dict_null_vals, + values, + schema, + } + } + + /// Creates test data for MIN/MAX testing with varied values + pub fn new_for_min_max() -> Self { + let dict_null_keys = create_test_dict( + &[Some("group_a"), Some("group_b"), Some("group_c")], + &[ + Some(0), + Some(1), + Some(0), + Some(2), + None, + None, // group_a, group_b, group_a, group_c, null, null + ], + ); + + let dict_null_vals = create_test_dict( + &[Some("group_x"), None, Some("group_y")], + &[ + 
Some(0), + Some(1), + Some(0), + Some(2), + Some(1), + Some(1), // group_x, null, group_x, group_y, null, null + ], + ); + + let values = + Int32Array::from(vec![Some(5), Some(1), Some(3), Some(7), Some(2), None]); + + let schema = Arc::new(Schema::new(vec![ + Field::new("dict_null_keys", string_dict_type(), true), + Field::new("dict_null_vals", string_dict_type(), true), + Field::new("value", DataType::Int32, true), + ])); + + Self { + dict_null_keys, + dict_null_vals, + values, + schema, + } + } + + /// Creates test data for MEDIAN testing with varied values + pub fn new_for_median() -> Self { + let dict_null_vals = create_test_dict( + &[Some("group_a"), None, Some("group_b")], + &[Some(0), Some(1), Some(2), Some(1), Some(0)], + ); + + let dict_null_keys = create_test_dict( + &[Some("group_x"), Some("group_y"), Some("group_z")], + &[Some(0), None, Some(1), None, Some(2)], + ); + + let values = Int32Array::from(vec![Some(1), None, Some(5), Some(3), Some(7)]); + + let schema = Arc::new(Schema::new(vec![ + Field::new("dict_null_vals", string_dict_type(), true), + Field::new("dict_null_keys", string_dict_type(), true), + Field::new("value", DataType::Int32, true), + ])); + + Self { + dict_null_keys, + dict_null_vals, + values, + schema, + } + } + + /// Creates test data for FIRST_VALUE/LAST_VALUE testing + pub fn new_for_first_last() -> Self { + let dict_null_keys = create_test_dict( + &[Some("group_a"), Some("group_b")], + &[Some(0), None, Some(1), None, Some(0)], + ); + + let dict_null_vals = create_test_dict( + &[Some("group_x"), None, Some("group_y")], + &[Some(0), Some(1), Some(2), Some(1), Some(0)], + ); + + let values = Int32Array::from(vec![None, Some(1), Some(2), Some(3), None]); + + let schema = Arc::new(Schema::new(vec![ + Field::new("dict_null_keys", string_dict_type(), true), + Field::new("dict_null_vals", string_dict_type(), true), + Field::new("value", DataType::Int32, true), + ])); + + Self { + dict_null_keys, + dict_null_vals, + values, + schema, + } 
+ } +} +

/// Builds the pair of contexts used to cross-check a query: one with a single
/// target partition and one with three, both exposing `test_data` as table `t`.
///
/// # Errors
///
/// Propagates any error from [`create_context_with_partitions`].
pub async fn setup_test_contexts(
    test_data: &TestData,
) -> Result<(SessionContext, SessionContext)> {
    let ctx_single = create_context_with_partitions(test_data, 1).await?;
    let ctx_multi = create_context_with_partitions(test_data, 3).await?;

    Ok((ctx_single, ctx_multi))
}

/// Creates a `SessionContext` with `num_partitions` target partitions and
/// registers `test_data` as the in-memory table `t`, split into at most
/// `num_partitions` partitions.
///
/// # Errors
///
/// Returns an error if the record batches cannot be built, the `MemTable`
/// cannot be created, or the table cannot be registered.
///
/// # Panics
///
/// Panics if `num_partitions` is 0 (see [`split_test_data_into_batches`]).
pub async fn create_context_with_partitions(
    test_data: &TestData,
    num_partitions: usize,
) -> Result<SessionContext> {
    let ctx = SessionContext::new_with_config(
        SessionConfig::new().with_target_partitions(num_partitions),
    );

    let batches = split_test_data_into_batches(test_data, num_partitions)?;
    let provider = MemTable::try_new(test_data.schema.clone(), batches)?;
    ctx.register_table("t", Arc::new(provider))?;

    Ok(ctx)
}

/// Splits `test_data` into contiguous row ranges, one single-batch partition
/// per range, yielding at most `num_partitions` partitions (none when the
/// data is empty).
///
/// Columns are always emitted as `dict_null_keys`, `dict_null_vals`, `values`,
/// in that order, and the schema's field names are applied positionally.
/// `TestData` builders whose schema lists the two dictionary fields in the
/// opposite order therefore expose each array under the other field's name;
/// the expected outputs of the existing tests rely on this, so do not reorder.
///
/// # Errors
///
/// Returns an error if a `RecordBatch` cannot be built from the slices.
///
/// # Panics
///
/// Panics if `num_partitions` is 0.
pub fn split_test_data_into_batches(
    test_data: &TestData,
    num_partitions: usize,
) -> Result<Vec<Vec<RecordBatch>>> {
    // A plain `assert!`: without it release builds would still panic, but with
    // an opaque divide-by-zero from `div_ceil` below.
    assert!(num_partitions > 0, "num_partitions must be greater than 0");
    let total_len = test_data.values.len();
    // Round up so the final, possibly shorter, chunk covers the remainder.
    let chunk_size = total_len.div_ceil(num_partitions);

    (0..total_len)
        // `chunk_size` is 0 only when there are no rows; the range is empty
        // then, but `step_by` itself rejects a zero step.
        .step_by(chunk_size.max(1))
        .map(|start| -> Result<Vec<RecordBatch>> {
            let len = min(chunk_size, total_len - start);
            let batch = RecordBatch::try_new(
                test_data.schema.clone(),
                vec![
                    Arc::new(test_data.dict_null_keys.slice(start, len)),
                    Arc::new(test_data.dict_null_vals.slice(start, len)),
                    Arc::new(test_data.values.slice(start, len)),
                ],
            )?;
            Ok(vec![batch])
        })
        .collect()
}

/// Runs `sql` against both contexts, asserts that the rendered results are
/// identical, and returns the single-partition batches.
///
/// The comparison is made on the text produced by `batches_to_string`, so the
/// query should impose a deterministic row order.
///
/// # Errors
///
/// Returns an error if planning or execution fails on either context.
///
/// # Panics
///
/// Panics if the two rendered results differ.
pub async fn test_query_consistency(
    ctx_single: &SessionContext,
    ctx_multi: &SessionContext,
    sql: &str,
) -> Result<Vec<RecordBatch>> {
    let results_single = ctx_single.sql(sql).await?.collect().await?;
    let results_multi = ctx_multi.sql(sql).await?.collect().await?;

    assert_eq!(
        batches_to_string(&results_single),
        batches_to_string(&results_multi),
        "Results should be identical between single and multiple partitions"
    );

    Ok(results_single)
}

/// Registers `test_data` in a single- and a multi-partition context, runs
/// `sql` on both via [`test_query_consistency`], and returns the batches for
/// the caller to snapshot. Collapses the repeated
/// "setup data → SQL → assert_snapshot!" pattern into one call.
///
/// # Errors
///
/// Returns an error if context setup or query execution fails.
pub async fn run_snapshot_test(
    test_data: &TestData,
    sql: &str,
) -> Result<Vec<RecordBatch>> {
    let (ctx_single, ctx_multi) = setup_test_contexts(test_data).await?;
    test_query_consistency(&ctx_single, &ctx_multi, sql).await
}

/// Columns of the `fuzz_table` fixture used by the MAX tests: a low-cardinality
/// `u8`, a string dictionary with null keys and null values, and two string
/// columns.
pub struct FuzzTestData {
    pub schema: Arc<Schema>,
    pub u8_low: UInt8Array,
    pub dictionary_utf8_low: DictionaryArray<UInt32Type>,
    pub utf8_low: StringArray,
    pub utf8: StringArray,
}

+impl FuzzTestData { + pub fn new() -> Self { + // Create dictionary columns with null keys and values + let dictionary_utf8_low = create_test_dict( + &[Some("dict_a"), None, Some("dict_b"), Some("dict_c")], + &[ + Some(0), // dict_a + Some(1), // null value + Some(2), // dict_b + None, // null key + Some(0), // dict_a + Some(1), // null value + Some(3), // dict_c + None, // null key + ], + ); + + let u8_low = UInt8Array::from(vec![ + Some(1), + Some(1), + Some(2), + Some(2), + Some(1), + Some(3), + Some(3), + Some(2), + ]); + + let utf8_low = StringArray::from(vec![ + Some("str_a"), + Some("str_b"), + Some("str_c"), + Some("str_d"), + Some("str_a"), + Some("str_e"), + Some("str_f"), + Some("str_c"), + ]); + + let utf8 = StringArray::from(vec![ + Some("value_1"), + Some("value_2"), + 
Some("value_3"), + Some("value_4"), + Some("value_5"), + None, + Some("value_6"), + Some("value_7"), + ]); + + let schema = Arc::new(Schema::new(vec![ + Field::new("u8_low", DataType::UInt8, true), + Field::new("dictionary_utf8_low", string_dict_type(), true), + Field::new("utf8_low", DataType::Utf8, true), + Field::new("utf8", DataType::Utf8, true), + ])); + + Self { + schema, + u8_low, + dictionary_utf8_low, + utf8_low, + utf8, + } + } +} +

/// Builds the single-partition and three-partition contexts for the
/// `fuzz_table` fixture defined by `FuzzTestData`.
///
/// # Errors
///
/// Propagates any error from [`create_fuzz_context_with_partitions`].
pub async fn setup_fuzz_test_contexts() -> Result<(SessionContext, SessionContext)> {
    let test_data = FuzzTestData::new();

    let ctx_single = create_fuzz_context_with_partitions(&test_data, 1).await?;
    let ctx_multi = create_fuzz_context_with_partitions(&test_data, 3).await?;

    Ok((ctx_single, ctx_multi))
}

/// Creates a `SessionContext` with `num_partitions` target partitions and
/// registers `test_data` as the in-memory table `fuzz_table`, split into at
/// most `num_partitions` partitions.
///
/// # Errors
///
/// Returns an error if the record batches cannot be built, the `MemTable`
/// cannot be created, or the table cannot be registered.
///
/// # Panics
///
/// Panics if `num_partitions` is 0 (see [`split_fuzz_data_into_batches`]).
pub async fn create_fuzz_context_with_partitions(
    test_data: &FuzzTestData,
    num_partitions: usize,
) -> Result<SessionContext> {
    let ctx = SessionContext::new_with_config(
        SessionConfig::new().with_target_partitions(num_partitions),
    );

    let batches = split_fuzz_data_into_batches(test_data, num_partitions)?;
    let provider = MemTable::try_new(test_data.schema.clone(), batches)?;
    ctx.register_table("fuzz_table", Arc::new(provider))?;

    Ok(ctx)
}

/// Splits the fuzz fixture into contiguous row ranges, one single-batch
/// partition per range, yielding at most `num_partitions` partitions (none
/// when the data is empty). Columns are emitted in schema order: `u8_low`,
/// `dictionary_utf8_low`, `utf8_low`, `utf8`.
///
/// # Errors
///
/// Returns an error if a `RecordBatch` cannot be built from the slices.
///
/// # Panics
///
/// Panics if `num_partitions` is 0.
pub fn split_fuzz_data_into_batches(
    test_data: &FuzzTestData,
    num_partitions: usize,
) -> Result<Vec<Vec<RecordBatch>>> {
    // A plain `assert!`: without it release builds would still panic, but with
    // an opaque divide-by-zero from `div_ceil` below.
    assert!(num_partitions > 0, "num_partitions must be greater than 0");
    let total_len = test_data.u8_low.len();
    // Round up so the final, possibly shorter, chunk covers the remainder.
    let chunk_size = total_len.div_ceil(num_partitions);

    (0..total_len)
        // `chunk_size` is 0 only when there are no rows; the range is empty
        // then, but `step_by` itself rejects a zero step.
        .step_by(chunk_size.max(1))
        .map(|start| -> Result<Vec<RecordBatch>> {
            let len = min(chunk_size, total_len - start);
            let batch = RecordBatch::try_new(
                test_data.schema.clone(),
                vec![
                    Arc::new(test_data.u8_low.slice(start, len)),
                    Arc::new(test_data.dictionary_utf8_low.slice(start, len)),
                    Arc::new(test_data.utf8_low.slice(start, len)),
                    Arc::new(test_data.utf8.slice(start, len)),
                ],
            )?;
            Ok(vec![batch])
        })
        .collect()
}

/// Columns of the `fuzz_table` fixture used by the COUNT / COUNT DISTINCT
/// tests: grouping columns (including a string dictionary with null keys and
/// null values) plus nullable duration and large-binary columns.
pub struct FuzzCountTestData {
    pub schema: Arc<Schema>,
    pub u8_low: UInt8Array,
    pub utf8_low: StringArray,
    pub dictionary_utf8_low: DictionaryArray<UInt32Type>,
    pub duration_nanosecond: DurationNanosecondArray,
    pub large_binary: LargeBinaryArray,
}

+impl FuzzCountTestData { + pub fn new() -> Self { + // Create dictionary columns with null keys and values + let dictionary_utf8_low = create_test_dict( + &[ + Some("group_alpha"), + None, + Some("group_beta"), + Some("group_gamma"), + ], + &[ + Some(0), // group_alpha + Some(1), // null value + Some(2), // group_beta + None, // null key + Some(0), // group_alpha + Some(1), // null value + Some(3), // group_gamma + None, // null key + Some(2), // group_beta + Some(0), // group_alpha + ], + ); + + let u8_low = UInt8Array::from(vec![ + Some(5), + Some(10), + Some(15), + Some(10), + Some(5), + Some(20), + Some(25), + Some(10), + Some(15), + Some(5), + ]); + + let utf8_low = StringArray::from(vec![ + Some("text_a"), + Some("text_b"), + Some("text_c"), + Some("text_d"), + Some("text_a"), + Some("text_e"), + Some("text_f"), + Some("text_d"), + Some("text_c"), + Some("text_a"), + ]); + + // Create duration data with some nulls (nanoseconds) + let duration_nanosecond = DurationNanosecondArray::from(vec![ + Some(1000000000), // 1 second + Some(2000000000), // 2 seconds + None, // null duration + Some(3000000000), // 3 seconds + Some(1500000000), // 1.5 seconds + None, // null duration + Some(4000000000), // 4 seconds + Some(2500000000), // 2.5 seconds + Some(3500000000), // 3.5 
seconds + Some(1200000000), // 1.2 seconds + ]); + + // Create large binary data with some nulls and duplicates + let large_binary = LargeBinaryArray::from(vec![ + Some(b"binary_data_1".as_slice()), + Some(b"binary_data_2".as_slice()), + Some(b"binary_data_3".as_slice()), + None, // null binary + Some(b"binary_data_1".as_slice()), // duplicate + Some(b"binary_data_4".as_slice()), + Some(b"binary_data_5".as_slice()), + None, // null binary + Some(b"binary_data_3".as_slice()), // duplicate + Some(b"binary_data_1".as_slice()), // duplicate + ]); + + let schema = Arc::new(Schema::new(vec![ + Field::new("u8_low", DataType::UInt8, true), + Field::new("utf8_low", DataType::Utf8, true), + Field::new("dictionary_utf8_low", string_dict_type(), true), + Field::new( + "duration_nanosecond", + DataType::Duration(TimeUnit::Nanosecond), + true, + ), + Field::new("large_binary", DataType::LargeBinary, true), + ])); + + Self { + schema, + u8_low, + utf8_low, + dictionary_utf8_low, + duration_nanosecond, + large_binary, + } + } +} + +/// Sets up test contexts for fuzz table with duration/binary columns and both single and multiple partitions +pub async fn setup_fuzz_count_test_contexts() -> Result<(SessionContext, SessionContext)> +{ + let test_data = FuzzCountTestData::new(); + + // Single partition context + let ctx_single = create_fuzz_count_context_with_partitions(&test_data, 1).await?; + + // Multiple partition context + let ctx_multi = create_fuzz_count_context_with_partitions(&test_data, 3).await?; + + Ok((ctx_single, ctx_multi)) +} + +/// Creates a session context with fuzz count table partitioned into specified number of partitions +pub async fn create_fuzz_count_context_with_partitions( + test_data: &FuzzCountTestData, + num_partitions: usize, +) -> Result { + let ctx = SessionContext::new_with_config( + SessionConfig::new().with_target_partitions(num_partitions), + ); + + let batches = split_fuzz_count_data_into_batches(test_data, num_partitions)?; + let provider = 
MemTable::try_new(test_data.schema.clone(), batches)?; + ctx.register_table("fuzz_table", Arc::new(provider))?; + + Ok(ctx) +} + +/// Splits fuzz count test data into multiple batches for partitioning +pub fn split_fuzz_count_data_into_batches( + test_data: &FuzzCountTestData, + num_partitions: usize, +) -> Result>> { + debug_assert!(num_partitions > 0, "num_partitions must be greater than 0"); + let total_len = test_data.u8_low.len(); + let chunk_size = total_len.div_ceil(num_partitions); + + let mut batches = Vec::new(); + let mut start = 0; + + while start < total_len { + let end = min(start + chunk_size, total_len); + let len = end - start; + + if len > 0 { + let batch = RecordBatch::try_new( + test_data.schema.clone(), + vec![ + Arc::new(test_data.u8_low.slice(start, len)), + Arc::new(test_data.utf8_low.slice(start, len)), + Arc::new(test_data.dictionary_utf8_low.slice(start, len)), + Arc::new(test_data.duration_nanosecond.slice(start, len)), + Arc::new(test_data.large_binary.slice(start, len)), + ], + )?; + batches.push(vec![batch]); + } + start = end; + } + + Ok(batches) +} + +/// Test data structure for fuzz table with numeric types for median testing and dictionary columns containing nulls +pub struct FuzzMedianTestData { + pub schema: Arc, + pub u8_low: UInt8Array, + pub dictionary_utf8_low: DictionaryArray, + pub u64: UInt64Array, + pub u16: UInt16Array, + pub u32: UInt32Array, + pub decimal128: Decimal128Array, +} + +impl FuzzMedianTestData { + pub fn new() -> Self { + // Create dictionary columns with null keys and values + let dictionary_utf8_low = create_test_dict( + &[ + Some("group_one"), + None, + Some("group_two"), + Some("group_three"), + ], + &[ + Some(0), // group_one + Some(1), // null value + Some(2), // group_two + None, // null key + Some(0), // group_one + Some(1), // null value + Some(3), // group_three + None, // null key + Some(2), // group_two + Some(0), // group_one + Some(1), // null value + Some(3), // group_three + ], + ); + + let 
u8_low = UInt8Array::from(vec![ + Some(100), + Some(200), + Some(100), + Some(200), + Some(100), + Some(50), + Some(50), + Some(200), + Some(100), + Some(100), + Some(75), + Some(50), + ]); + + // Create u64 data with some nulls and duplicates for DISTINCT testing + let u64 = UInt64Array::from(vec![ + Some(1000), + Some(2000), + Some(1500), + Some(3000), + Some(1000), // duplicate + None, // null + Some(5000), + Some(2500), + Some(1500), // duplicate + Some(1200), + Some(4000), + Some(5000), // duplicate + ]); + + // Create u16 data with some nulls and duplicates + let u16 = UInt16Array::from(vec![ + Some(10), + Some(20), + Some(15), + None, // null + Some(10), // duplicate + Some(30), + Some(50), + Some(25), + Some(15), // duplicate + Some(12), + None, // null + Some(50), // duplicate + ]); + + // Create u32 data with some nulls and duplicates + let u32 = UInt32Array::from(vec![ + Some(100000), + Some(200000), + Some(150000), + Some(300000), + Some(100000), // duplicate + Some(400000), + Some(500000), + None, // null + Some(150000), // duplicate + Some(120000), + Some(450000), + None, // null + ]); + + // Create decimal128 data with precision 10, scale 2 + let decimal128 = Decimal128Array::from(vec![ + Some(12345), // 123.45 + Some(67890), // 678.90 + Some(11111), // 111.11 + None, // null + Some(12345), // 123.45 duplicate + Some(98765), // 987.65 + Some(55555), // 555.55 + Some(33333), // 333.33 + Some(11111), // 111.11 duplicate + Some(12500), // 125.00 + None, // null + Some(55555), // 555.55 duplicate + ]) + .with_precision_and_scale(10, 2) + .unwrap(); + + let schema = Arc::new(Schema::new(vec![ + Field::new("u8_low", DataType::UInt8, true), + Field::new("dictionary_utf8_low", string_dict_type(), true), + Field::new("u64", DataType::UInt64, true), + Field::new("u16", DataType::UInt16, true), + Field::new("u32", DataType::UInt32, true), + Field::new("decimal128", DataType::Decimal128(10, 2), true), + ])); + + Self { + schema, + u8_low, + dictionary_utf8_low, 
+ u64, + u16, + u32, + decimal128, + } + } +} + +/// Sets up test contexts for fuzz table with numeric types for median testing and both single and multiple partitions +pub async fn setup_fuzz_median_test_contexts() -> Result<(SessionContext, SessionContext)> +{ + let test_data = FuzzMedianTestData::new(); + + // Single partition context + let ctx_single = create_fuzz_median_context_with_partitions(&test_data, 1).await?; + + // Multiple partition context + let ctx_multi = create_fuzz_median_context_with_partitions(&test_data, 3).await?; + + Ok((ctx_single, ctx_multi)) +} + +/// Creates a session context with fuzz median table partitioned into specified number of partitions +pub async fn create_fuzz_median_context_with_partitions( + test_data: &FuzzMedianTestData, + num_partitions: usize, +) -> Result { + let ctx = SessionContext::new_with_config( + SessionConfig::new().with_target_partitions(num_partitions), + ); + + let batches = split_fuzz_median_data_into_batches(test_data, num_partitions)?; + let provider = MemTable::try_new(test_data.schema.clone(), batches)?; + ctx.register_table("fuzz_table", Arc::new(provider))?; + + Ok(ctx) +} + +/// Splits fuzz median test data into multiple batches for partitioning +pub fn split_fuzz_median_data_into_batches( + test_data: &FuzzMedianTestData, + num_partitions: usize, +) -> Result>> { + debug_assert!(num_partitions > 0, "num_partitions must be greater than 0"); + let total_len = test_data.u8_low.len(); + let chunk_size = total_len.div_ceil(num_partitions); + + let mut batches = Vec::new(); + let mut start = 0; + + while start < total_len { + let end = min(start + chunk_size, total_len); + let len = end - start; + + if len > 0 { + let batch = RecordBatch::try_new( + test_data.schema.clone(), + vec![ + Arc::new(test_data.u8_low.slice(start, len)), + Arc::new(test_data.dictionary_utf8_low.slice(start, len)), + Arc::new(test_data.u64.slice(start, len)), + Arc::new(test_data.u16.slice(start, len)), + 
Arc::new(test_data.u32.slice(start, len)), + Arc::new(test_data.decimal128.slice(start, len)), + ], + )?; + batches.push(vec![batch]); + } + start = end; + } + + Ok(batches) +} + +/// Test data structure for fuzz table with timestamp and dictionary columns containing nulls +pub struct FuzzTimestampTestData { + pub schema: Arc, + pub utf8_low: StringArray, + pub u8_low: UInt8Array, + pub dictionary_utf8_low: DictionaryArray, + pub timestamp_us: TimestampMicrosecondArray, +} + +impl FuzzTimestampTestData { + pub fn new() -> Self { + // Create dictionary columns with null keys and values + let dictionary_utf8_low = create_test_dict( + &[Some("dict_x"), None, Some("dict_y"), Some("dict_z")], + &[ + Some(0), // dict_x + Some(1), // null value + Some(2), // dict_y + None, // null key + Some(0), // dict_x + Some(1), // null value + Some(3), // dict_z + None, // null key + Some(2), // dict_y + ], + ); + + let utf8_low = StringArray::from(vec![ + Some("alpha"), + Some("beta"), + Some("gamma"), + Some("delta"), + Some("alpha"), + Some("epsilon"), + Some("zeta"), + Some("delta"), + Some("gamma"), + ]); + + let u8_low = UInt8Array::from(vec![ + Some(10), + Some(20), + Some(30), + Some(20), + Some(10), + Some(40), + Some(30), + Some(20), + Some(30), + ]); + + // Create timestamp data with some nulls + let timestamp_us = TimestampMicrosecondArray::from(vec![ + Some(1000000), // 1970-01-01 00:00:01 + Some(2000000), // 1970-01-01 00:00:02 + Some(3000000), // 1970-01-01 00:00:03 + None, // null timestamp + Some(1500000), // 1970-01-01 00:00:01.5 + Some(4000000), // 1970-01-01 00:00:04 + Some(2500000), // 1970-01-01 00:00:02.5 + Some(3500000), // 1970-01-01 00:00:03.5 + Some(2800000), // 1970-01-01 00:00:02.8 + ]); + + let schema = Arc::new(Schema::new(vec![ + Field::new("utf8_low", DataType::Utf8, true), + Field::new("u8_low", DataType::UInt8, true), + Field::new("dictionary_utf8_low", string_dict_type(), true), + Field::new( + "timestamp_us", + 
DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ), + ])); + + Self { + schema, + utf8_low, + u8_low, + dictionary_utf8_low, + timestamp_us, + } + } +} + +/// Sets up test contexts for fuzz table with timestamps and both single and multiple partitions +pub async fn setup_fuzz_timestamp_test_contexts() +-> Result<(SessionContext, SessionContext)> { + let test_data = FuzzTimestampTestData::new(); + + // Single partition context + let ctx_single = create_fuzz_timestamp_context_with_partitions(&test_data, 1).await?; + + // Multiple partition context + let ctx_multi = create_fuzz_timestamp_context_with_partitions(&test_data, 3).await?; + + Ok((ctx_single, ctx_multi)) +} + +/// Creates a session context with fuzz timestamp table partitioned into specified number of partitions +pub async fn create_fuzz_timestamp_context_with_partitions( + test_data: &FuzzTimestampTestData, + num_partitions: usize, +) -> Result { + let ctx = SessionContext::new_with_config( + SessionConfig::new().with_target_partitions(num_partitions), + ); + + let batches = split_fuzz_timestamp_data_into_batches(test_data, num_partitions)?; + let provider = MemTable::try_new(test_data.schema.clone(), batches)?; + ctx.register_table("fuzz_table", Arc::new(provider))?; + + Ok(ctx) +} + +/// Splits fuzz timestamp test data into multiple batches for partitioning +pub fn split_fuzz_timestamp_data_into_batches( + test_data: &FuzzTimestampTestData, + num_partitions: usize, +) -> Result>> { + debug_assert!(num_partitions > 0, "num_partitions must be greater than 0"); + let total_len = test_data.utf8_low.len(); + let chunk_size = total_len.div_ceil(num_partitions); + + let mut batches = Vec::new(); + let mut start = 0; + + while start < total_len { + let end = min(start + chunk_size, total_len); + let len = end - start; + + if len > 0 { + let batch = RecordBatch::try_new( + test_data.schema.clone(), + vec![ + Arc::new(test_data.utf8_low.slice(start, len)), + Arc::new(test_data.u8_low.slice(start, len)), 
+ Arc::new(test_data.dictionary_utf8_low.slice(start, len)), + Arc::new(test_data.timestamp_us.slice(start, len)), + ], + )?; + batches.push(vec![batch]); + } + start = end; + } + + Ok(batches) +} + +pub mod basic; +pub mod dict_nulls; diff --git a/datafusion/core/tests/sql/create_drop.rs b/datafusion/core/tests/sql/create_drop.rs index 83712053b9542..4a60a79ff5de3 100644 --- a/datafusion/core/tests/sql/create_drop.rs +++ b/datafusion/core/tests/sql/create_drop.rs @@ -61,8 +61,31 @@ async fn create_external_table_with_ddl() -> Result<()> { assert_eq!(3, table_schema.fields().len()); assert_eq!(&DataType::Int32, table_schema.field(0).data_type()); - assert_eq!(&DataType::Utf8, table_schema.field(1).data_type()); + assert_eq!(&DataType::Utf8View, table_schema.field(1).data_type()); assert_eq!(&DataType::Boolean, table_schema.field(2).data_type()); Ok(()) } + +#[tokio::test] +async fn create_drop_table() -> Result<()> { + let ctx = SessionContext::new(); + + let sql = "CREATE TABLE dt (a_id integer, a_str string, a_bool boolean);"; + ctx.sql(sql).await.unwrap(); + + let cat = ctx.catalog("datafusion").unwrap(); + let schema = cat.schema("public").unwrap(); + + let exists = schema.table_exist("dt"); + assert!(exists, "Table should have been created!"); + + // Drop the table + let sql = "DROP TABLE dt;"; + ctx.sql(sql).await.unwrap(); + + let exists = schema.table_exist("dt"); + assert!(!exists, "Table should have been dropped!"); + + Ok(()) +} diff --git a/datafusion/core/tests/sql/explain_analyze.rs b/datafusion/core/tests/sql/explain_analyze.rs index 70e94227cfad8..5f62f7204eff1 100644 --- a/datafusion/core/tests/sql/explain_analyze.rs +++ b/datafusion/core/tests/sql/explain_analyze.rs @@ -16,11 +16,14 @@ // under the License. 
use super::*; +use insta::assert_snapshot; use rstest::rstest; use datafusion::config::ConfigOptions; use datafusion::physical_plan::display::DisplayableExecutionPlan; use datafusion::physical_plan::metrics::Timestamp; +use datafusion_common::format::ExplainAnalyzeLevel; +use object_store::path::Path; #[tokio::test] async fn explain_analyze_baseline_metrics() { @@ -52,42 +55,84 @@ async fn explain_analyze_baseline_metrics() { let formatted = arrow::util::pretty::pretty_format_batches(&results) .unwrap() .to_string(); + println!("Query Output:\n\n{formatted}"); assert_metrics!( &formatted, "AggregateExec: mode=Partial, gby=[]", - "metrics=[output_rows=3, elapsed_compute=" + "metrics=[output_rows=3, elapsed_compute=", + "output_bytes=", + "output_batches=3" ); + assert_metrics!( &formatted, - "AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1]", - "metrics=[output_rows=5, elapsed_compute=" + "AggregateExec: mode=Partial, gby=[c1@0 as c1]", + "reduction_factor=5.1% (5/99)" ); + + { + let expected_batch_count_after_repartition = + if cfg!(not(feature = "force_hash_collisions")) { + "output_batches=3" + } else { + "output_batches=1" + }; + + assert_metrics!( + &formatted, + "AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1]", + "metrics=[output_rows=5, elapsed_compute=", + "output_bytes=", + expected_batch_count_after_repartition + ); + + assert_metrics!( + &formatted, + "RepartitionExec: partitioning=Hash([c1@0], 3), input_partitions=3", + "metrics=[output_rows=5, elapsed_compute=", + "output_bytes=", + expected_batch_count_after_repartition + ); + + assert_metrics!( + &formatted, + "ProjectionExec: expr=[]", + "metrics=[output_rows=5, elapsed_compute=", + "output_bytes=", + expected_batch_count_after_repartition + ); + } + assert_metrics!( &formatted, "FilterExec: c13@1 != C2GT5KVyOPZpgKVl110TyZO0NcJ434", - "metrics=[output_rows=99, elapsed_compute=" - ); - assert_metrics!( - &formatted, - "ProjectionExec: expr=[]", - "metrics=[output_rows=5, 
elapsed_compute=" + "metrics=[output_rows=99, elapsed_compute=", + "output_bytes=", + "output_batches=1" ); + assert_metrics!( &formatted, - "CoalesceBatchesExec: target_batch_size=4096", - "metrics=[output_rows=5, elapsed_compute" + "FilterExec: c13@1 != C2GT5KVyOPZpgKVl110TyZO0NcJ434", + "selectivity=99% (99/100)" ); + assert_metrics!( &formatted, "UnionExec", - "metrics=[output_rows=3, elapsed_compute=" + "metrics=[output_rows=3, elapsed_compute=", + "output_bytes=", + "output_batches=3" ); + assert_metrics!( &formatted, "WindowAggExec", - "metrics=[output_rows=1, elapsed_compute=" + "metrics=[output_rows=1, elapsed_compute=", + "output_bytes=", + "output_batches=1" ); fn expected_to_have_metrics(plan: &dyn ExecutionPlan) -> bool { @@ -99,7 +144,6 @@ async fn explain_analyze_baseline_metrics() { || plan.as_any().downcast_ref::().is_some() || plan.as_any().downcast_ref::().is_some() || plan.as_any().downcast_ref::().is_some() - || plan.as_any().downcast_ref::().is_some() || plan.as_any().downcast_ref::().is_some() || plan.as_any().downcast_ref::().is_some() || plan.as_any().downcast_ref::().is_some() @@ -155,6 +199,116 @@ async fn explain_analyze_baseline_metrics() { fn nanos_from_timestamp(ts: &Timestamp) -> i64 { ts.value().unwrap().timestamp_nanos_opt().unwrap() } + +// Test different detail level for config `datafusion.explain.analyze_level` + +async fn collect_plan_with_context( + sql_str: &str, + ctx: &SessionContext, + level: ExplainAnalyzeLevel, +) -> String { + { + let state = ctx.state_ref(); + let mut state = state.write(); + state.config_mut().options_mut().explain.analyze_level = level; + } + let dataframe = ctx.sql(sql_str).await.unwrap(); + let batches = dataframe.collect().await.unwrap(); + arrow::util::pretty::pretty_format_batches(&batches) + .unwrap() + .to_string() +} + +async fn collect_plan(sql_str: &str, level: ExplainAnalyzeLevel) -> String { + let ctx = SessionContext::new(); + collect_plan_with_context(sql_str, &ctx, level).await +} + 
+#[tokio::test] +async fn explain_analyze_level() { + let sql = "EXPLAIN ANALYZE \ + SELECT * \ + FROM generate_series(10) as t1(v1) \ + ORDER BY v1 DESC"; + + for (level, needle, should_contain) in [ + (ExplainAnalyzeLevel::Summary, "spill_count", false), + (ExplainAnalyzeLevel::Summary, "output_batches", false), + (ExplainAnalyzeLevel::Summary, "output_rows", true), + (ExplainAnalyzeLevel::Summary, "output_bytes", true), + (ExplainAnalyzeLevel::Dev, "spill_count", true), + (ExplainAnalyzeLevel::Dev, "output_rows", true), + (ExplainAnalyzeLevel::Dev, "output_bytes", true), + (ExplainAnalyzeLevel::Dev, "output_batches", true), + ] { + let plan = collect_plan(sql, level).await; + assert_eq!( + plan.contains(needle), + should_contain, + "plan for level {level:?} unexpected content: {plan}" + ); + } +} + +#[tokio::test] +async fn explain_analyze_level_datasource_parquet() { + let table_name = "tpch_lineitem_small"; + let parquet_path = "tests/data/tpch_lineitem_small.parquet"; + let sql = format!("EXPLAIN ANALYZE SELECT * FROM {table_name}"); + + // Register test parquet file into context + let ctx = SessionContext::new(); + ctx.register_parquet(table_name, parquet_path, ParquetReadOptions::default()) + .await + .expect("register parquet table for explain analyze test"); + + for (level, needle, should_contain) in [ + (ExplainAnalyzeLevel::Summary, "metadata_load_time", true), + (ExplainAnalyzeLevel::Summary, "page_index_eval_time", false), + (ExplainAnalyzeLevel::Dev, "metadata_load_time", true), + (ExplainAnalyzeLevel::Dev, "page_index_eval_time", true), + ] { + let plan = collect_plan_with_context(&sql, &ctx, level).await; + + assert_eq!( + plan.contains(needle), + should_contain, + "plan for level {level:?} unexpected content: {plan}" + ); + } +} + +#[tokio::test] +async fn explain_analyze_parquet_pruning_metrics() { + let table_name = "tpch_lineitem_small"; + let parquet_path = "tests/data/tpch_lineitem_small.parquet"; + let ctx = SessionContext::new(); + 
ctx.register_parquet(table_name, parquet_path, ParquetReadOptions::default()) + .await + .expect("register parquet table for explain analyze test"); + + // Test scenario: + // This table's l_orderkey has range [1, 7] + // So the following query can't prune the file: + // select * from tpch_lineitem_small where l_orderkey = 5; + // If change filter to `l_orderkey=10`, the whole file can be pruned using stat. + for (l_orderkey, expected_pruning_metrics) in + [(5, "1 total → 1 matched"), (10, "1 total → 0 matched")] + { + let sql = format!( + "explain analyze select * from {table_name} where l_orderkey = {l_orderkey};" + ); + + let plan = + collect_plan_with_context(&sql, &ctx, ExplainAnalyzeLevel::Summary).await; + + let expected_metrics = + format!("files_ranges_pruned_statistics={expected_pruning_metrics}"); + + assert_metrics!(&plan, "DataSourceExec", &expected_metrics); + } +} + #[tokio::test] async fn csv_explain_plans() { // This test verify the look of each plan in its full cycle plan creation @@ -174,69 +328,66 @@ async fn csv_explain_plans() { println!("SQL: {sql}"); // // Verify schema - let expected = vec![ - "Explain [plan_type:Utf8, plan:Utf8]", - " Projection: aggregate_test_100.c1 [c1:Utf8View]", - " Filter: aggregate_test_100.c2 > Int64(10) [c1:Utf8View, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8View]", - " TableScan: aggregate_test_100 [c1:Utf8View, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8View]", - ]; let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + let actual = formatted.trim(); + assert_snapshot!( + actual, + @r" + Explain [plan_type:Utf8, plan:Utf8] + Projection: aggregate_test_100.c1 [c1:Utf8View] + Filter: 
aggregate_test_100.c2 > Int64(10) [c1:Utf8View, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8View] + TableScan: aggregate_test_100 [c1:Utf8View, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8View] + " ); // // Verify the text format of the plan - let expected = vec![ - "Explain", - " Projection: aggregate_test_100.c1", - " Filter: aggregate_test_100.c2 > Int64(10)", - " TableScan: aggregate_test_100", - ]; let formatted = plan.display_indent().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + let actual = formatted.trim(); + assert_snapshot!( + actual, + @r" + Explain + Projection: aggregate_test_100.c1 + Filter: aggregate_test_100.c2 > Int64(10) + TableScan: aggregate_test_100 + " ); // // verify the grahviz format of the plan - let expected = vec![ - "// Begin DataFusion GraphViz Plan,", - "// display it online here: https://dreampuf.github.io/GraphvizOnline", - "", - "digraph {", - " subgraph cluster_1", - " {", - " graph[label=\"LogicalPlan\"]", - " 2[shape=box label=\"Explain\"]", - " 3[shape=box label=\"Projection: aggregate_test_100.c1\"]", - " 2 -> 3 [arrowhead=none, arrowtail=normal, dir=back]", - " 4[shape=box label=\"Filter: aggregate_test_100.c2 > Int64(10)\"]", - " 3 -> 4 [arrowhead=none, arrowtail=normal, dir=back]", - " 5[shape=box label=\"TableScan: aggregate_test_100\"]", - " 4 -> 5 [arrowhead=none, arrowtail=normal, dir=back]", - " }", - " subgraph cluster_6", - " {", - " graph[label=\"Detailed LogicalPlan\"]", - " 7[shape=box label=\"Explain\\nSchema: [plan_type:Utf8, plan:Utf8]\"]", - " 8[shape=box label=\"Projection: aggregate_test_100.c1\\nSchema: [c1:Utf8View]\"]", - " 7 -> 8 [arrowhead=none, arrowtail=normal, dir=back]", - " 9[shape=box label=\"Filter: 
aggregate_test_100.c2 > Int64(10)\\nSchema: [c1:Utf8View, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8View]\"]", - " 8 -> 9 [arrowhead=none, arrowtail=normal, dir=back]", - " 10[shape=box label=\"TableScan: aggregate_test_100\\nSchema: [c1:Utf8View, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8View]\"]", - " 9 -> 10 [arrowhead=none, arrowtail=normal, dir=back]", - " }", - "}", - "// End DataFusion GraphViz Plan", - ]; let formatted = plan.display_graphviz().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + let actual = formatted.trim(); + assert_snapshot!( + actual, + @r#" + // Begin DataFusion GraphViz Plan, + // display it online here: https://dreampuf.github.io/GraphvizOnline + + digraph { + subgraph cluster_1 + { + graph[label="LogicalPlan"] + 2[shape=box label="Explain"] + 3[shape=box label="Projection: aggregate_test_100.c1"] + 2 -> 3 [arrowhead=none, arrowtail=normal, dir=back] + 4[shape=box label="Filter: aggregate_test_100.c2 > Int64(10)"] + 3 -> 4 [arrowhead=none, arrowtail=normal, dir=back] + 5[shape=box label="TableScan: aggregate_test_100"] + 4 -> 5 [arrowhead=none, arrowtail=normal, dir=back] + } + subgraph cluster_6 + { + graph[label="Detailed LogicalPlan"] + 7[shape=box label="Explain\nSchema: [plan_type:Utf8, plan:Utf8]"] + 8[shape=box label="Projection: aggregate_test_100.c1\nSchema: [c1:Utf8View]"] + 7 -> 8 [arrowhead=none, arrowtail=normal, dir=back] + 9[shape=box label="Filter: aggregate_test_100.c2 > Int64(10)\nSchema: [c1:Utf8View, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8View]"] + 8 -> 9 [arrowhead=none, arrowtail=normal, dir=back] + 10[shape=box label="TableScan: 
aggregate_test_100\nSchema: [c1:Utf8View, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8View]"] + 9 -> 10 [arrowhead=none, arrowtail=normal, dir=back] + } + } + // End DataFusion GraphViz Plan + "# ); // Optimized logical plan @@ -248,69 +399,66 @@ async fn csv_explain_plans() { assert_eq!(logical_schema, optimized_logical_schema.as_ref()); // // Verify schema - let expected = vec![ - "Explain [plan_type:Utf8, plan:Utf8]", - " Projection: aggregate_test_100.c1 [c1:Utf8View]", - " Filter: aggregate_test_100.c2 > Int8(10) [c1:Utf8View, c2:Int8]", - " TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)] [c1:Utf8View, c2:Int8]", - ]; let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + let actual = formatted.trim(); + assert_snapshot!( + actual, + @r" + Explain [plan_type:Utf8, plan:Utf8] + Projection: aggregate_test_100.c1 [c1:Utf8View] + Filter: aggregate_test_100.c2 > Int8(10) [c1:Utf8View, c2:Int8] + TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)] [c1:Utf8View, c2:Int8] + " ); // // Verify the text format of the plan - let expected = vec![ - "Explain", - " Projection: aggregate_test_100.c1", - " Filter: aggregate_test_100.c2 > Int8(10)", - " TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)]", - ]; let formatted = plan.display_indent().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + let actual = formatted.trim(); + assert_snapshot!( + actual, + @r" + Explain + Projection: aggregate_test_100.c1 + Filter: aggregate_test_100.c2 > Int8(10) + TableScan: 
aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)] + " ); // // verify the grahviz format of the plan - let expected = vec![ - "// Begin DataFusion GraphViz Plan,", - "// display it online here: https://dreampuf.github.io/GraphvizOnline", - "", - "digraph {", - " subgraph cluster_1", - " {", - " graph[label=\"LogicalPlan\"]", - " 2[shape=box label=\"Explain\"]", - " 3[shape=box label=\"Projection: aggregate_test_100.c1\"]", - " 2 -> 3 [arrowhead=none, arrowtail=normal, dir=back]", - " 4[shape=box label=\"Filter: aggregate_test_100.c2 > Int8(10)\"]", - " 3 -> 4 [arrowhead=none, arrowtail=normal, dir=back]", - " 5[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)]\"]", - " 4 -> 5 [arrowhead=none, arrowtail=normal, dir=back]", - " }", - " subgraph cluster_6", - " {", - " graph[label=\"Detailed LogicalPlan\"]", - " 7[shape=box label=\"Explain\\nSchema: [plan_type:Utf8, plan:Utf8]\"]", - " 8[shape=box label=\"Projection: aggregate_test_100.c1\\nSchema: [c1:Utf8View]\"]", - " 7 -> 8 [arrowhead=none, arrowtail=normal, dir=back]", - " 9[shape=box label=\"Filter: aggregate_test_100.c2 > Int8(10)\\nSchema: [c1:Utf8View, c2:Int8]\"]", - " 8 -> 9 [arrowhead=none, arrowtail=normal, dir=back]", - " 10[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)]\\nSchema: [c1:Utf8View, c2:Int8]\"]", - " 9 -> 10 [arrowhead=none, arrowtail=normal, dir=back]", - " }", - "}", - "// End DataFusion GraphViz Plan", - ]; let formatted = plan.display_graphviz().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + let actual = formatted.trim(); + assert_snapshot!( + actual, + @r#" + // Begin DataFusion GraphViz Plan, + // display it online here: https://dreampuf.github.io/GraphvizOnline + + digraph { + 
subgraph cluster_1 + { + graph[label="LogicalPlan"] + 2[shape=box label="Explain"] + 3[shape=box label="Projection: aggregate_test_100.c1"] + 2 -> 3 [arrowhead=none, arrowtail=normal, dir=back] + 4[shape=box label="Filter: aggregate_test_100.c2 > Int8(10)"] + 3 -> 4 [arrowhead=none, arrowtail=normal, dir=back] + 5[shape=box label="TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)]"] + 4 -> 5 [arrowhead=none, arrowtail=normal, dir=back] + } + subgraph cluster_6 + { + graph[label="Detailed LogicalPlan"] + 7[shape=box label="Explain\nSchema: [plan_type:Utf8, plan:Utf8]"] + 8[shape=box label="Projection: aggregate_test_100.c1\nSchema: [c1:Utf8View]"] + 7 -> 8 [arrowhead=none, arrowtail=normal, dir=back] + 9[shape=box label="Filter: aggregate_test_100.c2 > Int8(10)\nSchema: [c1:Utf8View, c2:Int8]"] + 8 -> 9 [arrowhead=none, arrowtail=normal, dir=back] + 10[shape=box label="TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)]\nSchema: [c1:Utf8View, c2:Int8]"] + 9 -> 10 [arrowhead=none, arrowtail=normal, dir=back] + } + } + // End DataFusion GraphViz Plan + "# ); // Physical plan @@ -396,69 +544,66 @@ async fn csv_explain_verbose_plans() { // // Verify schema - let expected = vec![ - "Explain [plan_type:Utf8, plan:Utf8]", - " Projection: aggregate_test_100.c1 [c1:Utf8View]", - " Filter: aggregate_test_100.c2 > Int64(10) [c1:Utf8View, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8View]", - " TableScan: aggregate_test_100 [c1:Utf8View, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8View]", - ]; let formatted = dataframe.logical_plan().display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - 
"\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + let actual = formatted.trim(); + assert_snapshot!( + actual, + @r" + Explain [plan_type:Utf8, plan:Utf8] + Projection: aggregate_test_100.c1 [c1:Utf8View] + Filter: aggregate_test_100.c2 > Int64(10) [c1:Utf8View, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8View] + TableScan: aggregate_test_100 [c1:Utf8View, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8View] + " ); // // Verify the text format of the plan - let expected = vec![ - "Explain", - " Projection: aggregate_test_100.c1", - " Filter: aggregate_test_100.c2 > Int64(10)", - " TableScan: aggregate_test_100", - ]; let formatted = dataframe.logical_plan().display_indent().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + let actual = formatted.trim(); + assert_snapshot!( + actual, + @r" + Explain + Projection: aggregate_test_100.c1 + Filter: aggregate_test_100.c2 > Int64(10) + TableScan: aggregate_test_100 + " ); // // verify the grahviz format of the plan - let expected = vec![ - "// Begin DataFusion GraphViz Plan,", - "// display it online here: https://dreampuf.github.io/GraphvizOnline", - "", - "digraph {", - " subgraph cluster_1", - " {", - " graph[label=\"LogicalPlan\"]", - " 2[shape=box label=\"Explain\"]", - " 3[shape=box label=\"Projection: aggregate_test_100.c1\"]", - " 2 -> 3 [arrowhead=none, arrowtail=normal, dir=back]", - " 4[shape=box label=\"Filter: aggregate_test_100.c2 > Int64(10)\"]", - " 3 -> 4 [arrowhead=none, arrowtail=normal, dir=back]", - " 5[shape=box label=\"TableScan: aggregate_test_100\"]", - " 4 -> 5 [arrowhead=none, arrowtail=normal, dir=back]", - " }", - " subgraph cluster_6", - " {", - " graph[label=\"Detailed LogicalPlan\"]", - " 
7[shape=box label=\"Explain\\nSchema: [plan_type:Utf8, plan:Utf8]\"]", - " 8[shape=box label=\"Projection: aggregate_test_100.c1\\nSchema: [c1:Utf8View]\"]", - " 7 -> 8 [arrowhead=none, arrowtail=normal, dir=back]", - " 9[shape=box label=\"Filter: aggregate_test_100.c2 > Int64(10)\\nSchema: [c1:Utf8View, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8View]\"]", - " 8 -> 9 [arrowhead=none, arrowtail=normal, dir=back]", - " 10[shape=box label=\"TableScan: aggregate_test_100\\nSchema: [c1:Utf8View, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8View]\"]", - " 9 -> 10 [arrowhead=none, arrowtail=normal, dir=back]", - " }", - "}", - "// End DataFusion GraphViz Plan", - ]; let formatted = dataframe.logical_plan().display_graphviz().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + let actual = formatted.trim(); + assert_snapshot!( + actual, + @r#" + // Begin DataFusion GraphViz Plan, + // display it online here: https://dreampuf.github.io/GraphvizOnline + + digraph { + subgraph cluster_1 + { + graph[label="LogicalPlan"] + 2[shape=box label="Explain"] + 3[shape=box label="Projection: aggregate_test_100.c1"] + 2 -> 3 [arrowhead=none, arrowtail=normal, dir=back] + 4[shape=box label="Filter: aggregate_test_100.c2 > Int64(10)"] + 3 -> 4 [arrowhead=none, arrowtail=normal, dir=back] + 5[shape=box label="TableScan: aggregate_test_100"] + 4 -> 5 [arrowhead=none, arrowtail=normal, dir=back] + } + subgraph cluster_6 + { + graph[label="Detailed LogicalPlan"] + 7[shape=box label="Explain\nSchema: [plan_type:Utf8, plan:Utf8]"] + 8[shape=box label="Projection: aggregate_test_100.c1\nSchema: [c1:Utf8View]"] + 7 -> 8 [arrowhead=none, arrowtail=normal, dir=back] + 9[shape=box label="Filter: aggregate_test_100.c2 
> Int64(10)\nSchema: [c1:Utf8View, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8View]"] + 8 -> 9 [arrowhead=none, arrowtail=normal, dir=back] + 10[shape=box label="TableScan: aggregate_test_100\nSchema: [c1:Utf8View, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8View]"] + 9 -> 10 [arrowhead=none, arrowtail=normal, dir=back] + } + } + // End DataFusion GraphViz Plan + "# ); // Optimized logical plan @@ -470,69 +615,66 @@ async fn csv_explain_verbose_plans() { assert_eq!(&logical_schema, optimized_logical_schema.as_ref()); // // Verify schema - let expected = vec![ - "Explain [plan_type:Utf8, plan:Utf8]", - " Projection: aggregate_test_100.c1 [c1:Utf8View]", - " Filter: aggregate_test_100.c2 > Int8(10) [c1:Utf8View, c2:Int8]", - " TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)] [c1:Utf8View, c2:Int8]", - ]; let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + let actual = formatted.trim(); + assert_snapshot!( + actual, + @r" + Explain [plan_type:Utf8, plan:Utf8] + Projection: aggregate_test_100.c1 [c1:Utf8View] + Filter: aggregate_test_100.c2 > Int8(10) [c1:Utf8View, c2:Int8] + TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)] [c1:Utf8View, c2:Int8] + " ); // // Verify the text format of the plan - let expected = vec![ - "Explain", - " Projection: aggregate_test_100.c1", - " Filter: aggregate_test_100.c2 > Int8(10)", - " TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)]", - ]; let formatted = plan.display_indent().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - 
assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + let actual = formatted.trim(); + assert_snapshot!( + actual, + @r" + Explain + Projection: aggregate_test_100.c1 + Filter: aggregate_test_100.c2 > Int8(10) + TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)] + " ); // // verify the grahviz format of the plan - let expected = vec![ - "// Begin DataFusion GraphViz Plan,", - "// display it online here: https://dreampuf.github.io/GraphvizOnline", - "", - "digraph {", - " subgraph cluster_1", - " {", - " graph[label=\"LogicalPlan\"]", - " 2[shape=box label=\"Explain\"]", - " 3[shape=box label=\"Projection: aggregate_test_100.c1\"]", - " 2 -> 3 [arrowhead=none, arrowtail=normal, dir=back]", - " 4[shape=box label=\"Filter: aggregate_test_100.c2 > Int8(10)\"]", - " 3 -> 4 [arrowhead=none, arrowtail=normal, dir=back]", - " 5[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)]\"]", - " 4 -> 5 [arrowhead=none, arrowtail=normal, dir=back]", - " }", - " subgraph cluster_6", - " {", - " graph[label=\"Detailed LogicalPlan\"]", - " 7[shape=box label=\"Explain\\nSchema: [plan_type:Utf8, plan:Utf8]\"]", - " 8[shape=box label=\"Projection: aggregate_test_100.c1\\nSchema: [c1:Utf8View]\"]", - " 7 -> 8 [arrowhead=none, arrowtail=normal, dir=back]", - " 9[shape=box label=\"Filter: aggregate_test_100.c2 > Int8(10)\\nSchema: [c1:Utf8View, c2:Int8]\"]", - " 8 -> 9 [arrowhead=none, arrowtail=normal, dir=back]", - " 10[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)]\\nSchema: [c1:Utf8View, c2:Int8]\"]", - " 9 -> 10 [arrowhead=none, arrowtail=normal, dir=back]", - " }", - "}", - "// End DataFusion GraphViz Plan", - ]; let formatted = plan.display_graphviz().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, 
actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + let actual = formatted.trim(); + assert_snapshot!( + actual, + @r#" + // Begin DataFusion GraphViz Plan, + // display it online here: https://dreampuf.github.io/GraphvizOnline + + digraph { + subgraph cluster_1 + { + graph[label="LogicalPlan"] + 2[shape=box label="Explain"] + 3[shape=box label="Projection: aggregate_test_100.c1"] + 2 -> 3 [arrowhead=none, arrowtail=normal, dir=back] + 4[shape=box label="Filter: aggregate_test_100.c2 > Int8(10)"] + 3 -> 4 [arrowhead=none, arrowtail=normal, dir=back] + 5[shape=box label="TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)]"] + 4 -> 5 [arrowhead=none, arrowtail=normal, dir=back] + } + subgraph cluster_6 + { + graph[label="Detailed LogicalPlan"] + 7[shape=box label="Explain\nSchema: [plan_type:Utf8, plan:Utf8]"] + 8[shape=box label="Projection: aggregate_test_100.c1\nSchema: [c1:Utf8View]"] + 7 -> 8 [arrowhead=none, arrowtail=normal, dir=back] + 9[shape=box label="Filter: aggregate_test_100.c2 > Int8(10)\nSchema: [c1:Utf8View, c2:Int8]"] + 8 -> 9 [arrowhead=none, arrowtail=normal, dir=back] + 10[shape=box label="TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)]\nSchema: [c1:Utf8View, c2:Int8]"] + 9 -> 10 [arrowhead=none, arrowtail=normal, dir=back] + } + } + // End DataFusion GraphViz Plan + "# ); // Physical plan @@ -602,19 +744,6 @@ async fn test_physical_plan_display_indent() { LIMIT 10"; let dataframe = ctx.sql(sql).await.unwrap(); let physical_plan = dataframe.create_physical_plan().await.unwrap(); - let expected = vec![ - "SortPreservingMergeExec: [the_min@2 DESC], fetch=10", - " SortExec: TopK(fetch=10), expr=[the_min@2 DESC], preserve_partitioning=[true]", - " ProjectionExec: expr=[c1@0 as c1, max(aggregate_test_100.c12)@1 as max(aggregate_test_100.c12), min(aggregate_test_100.c12)@2 as the_min]", - " AggregateExec: mode=FinalPartitioned, 
gby=[c1@0 as c1], aggr=[max(aggregate_test_100.c12), min(aggregate_test_100.c12)]", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([c1@0], 9000), input_partitions=9000", - " AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[max(aggregate_test_100.c12), min(aggregate_test_100.c12)]", - " CoalesceBatchesExec: target_batch_size=4096", - " FilterExec: c12@1 < 10", - " RepartitionExec: partitioning=RoundRobinBatch(9000), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[ARROW_TEST_DATA/csv/aggregate_test_100.csv]]}, projection=[c1, c12], file_type=csv, has_header=true", - ]; let normalizer = ExplainNormalizer::new(); let actual = format!("{}", displayable(physical_plan.as_ref()).indent(true)) @@ -622,10 +751,22 @@ async fn test_physical_plan_display_indent() { .lines() // normalize paths .map(|s| normalizer.normalize(s)) - .collect::>(); - assert_eq!( - expected, actual, - "expected:\n{expected:#?}\nactual:\n\n{actual:#?}\n" + .collect::>() + .join("\n"); + + assert_snapshot!( + actual, + @r" + SortPreservingMergeExec: [the_min@2 DESC], fetch=10 + SortExec: TopK(fetch=10), expr=[the_min@2 DESC], preserve_partitioning=[true] + ProjectionExec: expr=[c1@0 as c1, max(aggregate_test_100.c12)@1 as max(aggregate_test_100.c12), min(aggregate_test_100.c12)@2 as the_min] + AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[max(aggregate_test_100.c12), min(aggregate_test_100.c12)] + RepartitionExec: partitioning=Hash([c1@0], 9000), input_partitions=9000 + AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[max(aggregate_test_100.c12), min(aggregate_test_100.c12)] + FilterExec: c12@1 < 10 + RepartitionExec: partitioning=RoundRobinBatch(9000), input_partitions=1 + DataSourceExec: file_groups={1 group: [[ARROW_TEST_DATA/csv/aggregate_test_100.csv]]}, projection=[c1, c12], file_type=csv, has_header=true + " ); } @@ -647,19 +788,6 @@ async fn test_physical_plan_display_indent_multi_children() { let dataframe = 
ctx.sql(sql).await.unwrap(); let physical_plan = dataframe.create_physical_plan().await.unwrap(); - let expected = vec![ - "CoalesceBatchesExec: target_batch_size=4096", - " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c1@0, c2@0)], projection=[c1@0]", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([c1@0], 9000), input_partitions=9000", - " RepartitionExec: partitioning=RoundRobinBatch(9000), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[ARROW_TEST_DATA/csv/aggregate_test_100.csv]]}, projection=[c1], file_type=csv, has_header=true", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([c2@0], 9000), input_partitions=9000", - " RepartitionExec: partitioning=RoundRobinBatch(9000), input_partitions=1", - " ProjectionExec: expr=[c1@0 as c2]", - " DataSourceExec: file_groups={1 group: [[ARROW_TEST_DATA/csv/aggregate_test_100.csv]]}, projection=[c1], file_type=csv, has_header=true", - ]; let normalizer = ExplainNormalizer::new(); let actual = format!("{}", displayable(physical_plan.as_ref()).indent(true)) @@ -667,11 +795,18 @@ async fn test_physical_plan_display_indent_multi_children() { .lines() // normalize paths .map(|s| normalizer.normalize(s)) - .collect::>(); + .collect::>() + .join("\n"); - assert_eq!( - expected, actual, - "expected:\n{expected:#?}\nactual:\n\n{actual:#?}\n" + assert_snapshot!( + actual, + @r" + HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c1@0, c2@0)], projection=[c1@0] + RepartitionExec: partitioning=Hash([c1@0], 9000), input_partitions=1 + DataSourceExec: file_groups={1 group: [[ARROW_TEST_DATA/csv/aggregate_test_100.csv]]}, projection=[c1], file_type=csv, has_header=true + RepartitionExec: partitioning=Hash([c2@0], 9000), input_partitions=1 + DataSourceExec: file_groups={1 group: [[ARROW_TEST_DATA/csv/aggregate_test_100.csv]]}, projection=[c1@0 as c2], file_type=csv, has_header=true + " ); } @@ -710,8 +845,7 @@ async fn 
csv_explain_analyze_order_by() { // Ensure that the ordering is not optimized away from the plan // https://github.com/apache/datafusion/issues/6379 - let needle = - "SortExec: expr=[c1@0 ASC NULLS LAST], preserve_partitioning=[false], metrics=[output_rows=100, elapsed_compute"; + let needle = "SortExec: expr=[c1@0 ASC NULLS LAST], preserve_partitioning=[false], metrics=[output_rows=100, elapsed_compute"; assert_contains!(&formatted, needle); } @@ -729,10 +863,153 @@ async fn parquet_explain_analyze() { // should contain aggregated stats assert_contains!(&formatted, "output_rows=8"); - assert_contains!(&formatted, "row_groups_matched_bloom_filter=0"); - assert_contains!(&formatted, "row_groups_pruned_bloom_filter=0"); - assert_contains!(&formatted, "row_groups_matched_statistics=1"); - assert_contains!(&formatted, "row_groups_pruned_statistics=0"); + assert_contains!( + &formatted, + "row_groups_pruned_bloom_filter=1 total \u{2192} 1 matched" + ); + assert_contains!( + &formatted, + "row_groups_pruned_statistics=1 total \u{2192} 1 matched" + ); + assert_contains!(&formatted, "scan_efficiency_ratio=14%"); + + // The order of metrics is expected to be the same as the actual pruning order + // (file-> row-group -> page) + let i_file = formatted.find("files_ranges_pruned_statistics").unwrap(); + let i_rowgroup_stat = formatted.find("row_groups_pruned_statistics").unwrap(); + let i_rowgroup_bloomfilter = + formatted.find("row_groups_pruned_bloom_filter").unwrap(); + let i_page_rows = formatted.find("page_index_rows_pruned").unwrap(); + let i_page_pages = formatted.find("page_index_pages_pruned").unwrap(); + + assert!( + (i_file < i_rowgroup_stat) + && (i_rowgroup_stat < i_rowgroup_bloomfilter) + && (i_rowgroup_bloomfilter < i_page_pages && i_page_pages < i_page_rows), + "The parquet pruning metrics should be displayed in an order of: file range -> row group statistics -> row group bloom filter -> page index." 
+ ); +} + +// This test reproduces the behavior described in +// https://github.com/apache/datafusion/issues/16684 where projection +// pushdown with recursive CTEs could fail to remove unused columns +// (e.g. nested/recursive expansion causing full schema to be scanned). +// Keeping this test ensures we don't regress that behavior. +#[tokio::test] +#[cfg_attr(tarpaulin, ignore)] +async fn parquet_recursive_projection_pushdown() -> Result<()> { + use parquet::arrow::arrow_writer::ArrowWriter; + use parquet::file::properties::WriterProperties; + + let temp_dir = TempDir::new().unwrap(); + let parquet_path = temp_dir.path().join("hierarchy.parquet"); + + let ids = Int64Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); + let parent_ids = Int64Array::from(vec![0, 1, 1, 2, 2, 3, 4, 5, 6, 7]); + let values = Int64Array::from(vec![10, 20, 30, 40, 50, 60, 70, 80, 90, 100]); + + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("parent_id", DataType::Int64, true), + Field::new("value", DataType::Int64, false), + ])); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(ids), Arc::new(parent_ids), Arc::new(values)], + ) + .unwrap(); + + let file = File::create(&parquet_path).unwrap(); + let props = WriterProperties::builder().build(); + let mut writer = ArrowWriter::try_new(file, schema, Some(props)).unwrap(); + writer.write(&batch).unwrap(); + writer.close().unwrap(); + + let ctx = SessionContext::new(); + ctx.register_parquet( + "hierarchy", + parquet_path.to_str().unwrap(), + ParquetReadOptions::default(), + ) + .await?; + + let sql = r#" + WITH RECURSIVE number_series AS ( + SELECT id, 1 as level + FROM hierarchy + WHERE id = 1 + + UNION ALL + + SELECT ns.id + 1, ns.level + 1 + FROM number_series ns + WHERE ns.id < 10 + ) + SELECT * FROM number_series ORDER BY id + "#; + + let dataframe = ctx.sql(sql).await?; + let physical_plan = dataframe.create_physical_plan().await?; + + let normalizer = 
ExplainNormalizer::new(); + let mut actual = format!("{}", displayable(physical_plan.as_ref()).indent(true)) + .trim() + .lines() + .map(|line| normalizer.normalize(line)) + .collect::>() + .join("\n"); + + fn replace_path_variants(actual: &mut String, path: &str) { + let mut candidates = vec![path.to_string()]; + + let trimmed = path.trim_start_matches(std::path::MAIN_SEPARATOR); + if trimmed != path { + candidates.push(trimmed.to_string()); + } + + let forward_slash = path.replace('\\', "/"); + if forward_slash != path { + candidates.push(forward_slash.clone()); + + let trimmed_forward = forward_slash.trim_start_matches('/'); + if trimmed_forward != forward_slash { + candidates.push(trimmed_forward.to_string()); + } + } + + for candidate in candidates { + *actual = actual.replace(&candidate, "TMP_DIR"); + } + } + + let temp_dir_path = temp_dir.path(); + let fs_path = temp_dir_path.to_string_lossy().to_string(); + replace_path_variants(&mut actual, &fs_path); + + if let Ok(url_path) = Path::from_filesystem_path(temp_dir_path) { + replace_path_variants(&mut actual, url_path.as_ref()); + } + + assert_snapshot!( + actual, + @r" + SortExec: expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] + RecursiveQueryExec: name=number_series, is_distinct=false + CoalescePartitionsExec + ProjectionExec: expr=[id@0 as id, 1 as level] + FilterExec: id@0 = 1 + RepartitionExec: partitioning=RoundRobinBatch(NUM_CORES), input_partitions=1 + DataSourceExec: file_groups={1 group: [[TMP_DIR/hierarchy.parquet]]}, projection=[id], file_type=parquet, predicate=id@0 = 1, pruning_predicate=id_null_count@2 != row_count@3 AND id_min@0 <= 1 AND 1 <= id_max@1, required_guarantees=[id in (1)] + CoalescePartitionsExec + ProjectionExec: expr=[id@0 + 1 as ns.id + Int64(1), level@1 + 1 as ns.level + Int64(1)] + FilterExec: id@0 < 10 + RepartitionExec: partitioning=RoundRobinBatch(NUM_CORES), input_partitions=1 + WorkTableExec: name=number_series + " + ); + + Ok(()) } #[tokio::test] @@ -748,9 
+1025,7 @@ async fn parquet_explain_analyze_verbose() { .to_string(); // should contain the raw per file stats (with the label) - assert_contains!(&formatted, "row_groups_matched_bloom_filter{partition=0"); assert_contains!(&formatted, "row_groups_pruned_bloom_filter{partition=0"); - assert_contains!(&formatted, "row_groups_matched_statistics{partition=0"); assert_contains!(&formatted, "row_groups_pruned_statistics{partition=0"); } @@ -779,14 +1054,19 @@ async fn explain_logical_plan_only() { let sql = "EXPLAIN select count(*) from (values ('a', 1, 100), ('a', 2, 150)) as t (c1,c2,c3)"; let actual = execute(&ctx, sql).await; let actual = normalize_vec_for_explain(actual); + let actual = actual.into_iter().map(|r| r.join("\n")).collect::(); - let expected = vec![ - vec!["logical_plan", "Projection: count(Int64(1)) AS count(*)\ - \n Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]]\ - \n SubqueryAlias: t\ - \n Projection:\ - \n Values: (Utf8(\"a\"), Int64(1), Int64(100)), (Utf8(\"a\"), Int64(2), Int64(150))"]]; - assert_eq!(expected, actual); + assert_snapshot!( + actual, + @r#" + logical_plan + Projection: count(Int64(1)) AS count(*) + Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] + SubqueryAlias: t + Projection: + Values: (Utf8("a"), Int64(1), Int64(100)), (Utf8("a"), Int64(2), Int64(150)) + "# + ); } #[tokio::test] @@ -797,14 +1077,16 @@ async fn explain_physical_plan_only() { let sql = "EXPLAIN select count(*) from (values ('a', 1, 100), ('a', 2, 150)) as t (c1,c2,c3)"; let actual = execute(&ctx, sql).await; let actual = normalize_vec_for_explain(actual); + let actual = actual.into_iter().map(|r| r.join("\n")).collect::(); - let expected = vec![vec![ - "physical_plan", - "ProjectionExec: expr=[2 as count(*)]\ - \n PlaceholderRowExec\ - \n", - ]]; - assert_eq!(expected, actual); + assert_snapshot!( + actual, + @r" + physical_plan + ProjectionExec: expr=[2 as count(*)] + PlaceholderRowExec + " + ); } #[tokio::test] @@ -827,3 +1109,54 @@ async fn 
csv_explain_analyze_with_statistics() { ", statistics=[Rows=Absent, Bytes=Absent, [(Col[0]:)]]" ); } + +#[tokio::test] +async fn nested_loop_join_selectivity() { + for (join_type, expected_selectivity) in [ + ("INNER", "1% (1/100)"), + ("LEFT", "10% (10/100)"), + ("RIGHT", "10% (10/100)"), + // 1 match + 9 left + 9 right = 19 + ("FULL", "19% (19/100)"), + ] { + let ctx = SessionContext::new(); + let sql = format!( + "EXPLAIN ANALYZE SELECT * \ + FROM generate_series(1, 10) as t1(a) \ + {join_type} JOIN generate_series(1, 10) as t2(b) \ + ON (t1.a + t2.b) = 20" + ); + + let actual = execute_to_batches(&ctx, sql.as_str()).await; + let formatted = arrow::util::pretty::pretty_format_batches(&actual) + .unwrap() + .to_string(); + + assert_metrics!( + &formatted, + "NestedLoopJoinExec", + &format!("selectivity={expected_selectivity}") + ); + } +} + +#[tokio::test] +async fn explain_analyze_hash_join() { + let sql = "EXPLAIN ANALYZE \ + SELECT * \ + FROM generate_series(10) as t1(a) \ + JOIN generate_series(20) as t2(b) \ + ON t1.a=t2.b"; + + for (level, needle, should_contain) in [ + (ExplainAnalyzeLevel::Summary, "probe_hit_rate", true), + (ExplainAnalyzeLevel::Summary, "avg_fanout", true), + ] { + let plan = collect_plan(sql, level).await; + assert_eq!( + plan.contains(needle), + should_contain, + "plan for level {level:?} unexpected content: {plan}" + ); + } +} diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index 77eec20eac006..7c0e89ee96418 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -15,8 +15,13 @@ // specific language governing permissions and limitations // under the License. 
+use insta::assert_snapshot; + +use datafusion::assert_batches_eq; +use datafusion::catalog::MemTable; use datafusion::datasource::stream::{FileStreamProvider, StreamConfig, StreamTable}; use datafusion::test_util::register_unbounded_file_with_ordering; +use datafusion_sql::unparser::plan_to_sql; use super::*; @@ -33,14 +38,16 @@ async fn join_change_in_planner() -> Result<()> { Field::new("a2", DataType::UInt32, false), ])); // Specify the ordering: - let file_sort_order = vec![[col("a1")] - .into_iter() - .map(|e| { - let ascending = true; - let nulls_first = false; - e.sort(ascending, nulls_first) - }) - .collect::>()]; + let file_sort_order = vec![ + [col("a1")] + .into_iter() + .map(|e| { + let ascending = true; + let nulls_first = false; + e.sort(ascending, nulls_first) + }) + .collect::>(), + ]; register_unbounded_file_with_ordering( &ctx, schema.clone(), @@ -61,28 +68,17 @@ async fn join_change_in_planner() -> Result<()> { let dataframe = ctx.sql(sql).await?; let physical_plan = dataframe.create_physical_plan().await?; let formatted = displayable(physical_plan.as_ref()).indent(true).to_string(); - let expected = { - [ - "SymmetricHashJoinExec: mode=Partitioned, join_type=Full, on=[(a2@1, a2@1)], filter=CAST(a1@0 AS Int64) > CAST(a1@1 AS Int64) + 3 AND CAST(a1@0 AS Int64) < CAST(a1@1 AS Int64) + 10", - " CoalesceBatchesExec: target_batch_size=8192", - " RepartitionExec: partitioning=Hash([a2@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a1@0 ASC NULLS LAST", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - // " DataSourceExec: file_groups={1 group: [[tempdir/left.csv]]}, projection=[a1, a2], file_type=csv, has_header=false", - " CoalesceBatchesExec: target_batch_size=8192", - " RepartitionExec: partitioning=Hash([a2@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a1@0 ASC NULLS LAST", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - // " DataSourceExec: file_groups={1 group: 
[[tempdir/right.csv]]}, projection=[a1, a2], file_type=csv, has_header=false" - ] - }; - let mut actual: Vec<&str> = formatted.trim().lines().collect(); - // Remove CSV lines - actual.remove(4); - actual.remove(7); - - assert_eq!( - expected, - actual[..], - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + let actual = formatted.trim(); + + assert_snapshot!( + actual, + @r" + SymmetricHashJoinExec: mode=Partitioned, join_type=Full, on=[(a2@1, a2@1)], filter=CAST(a1@0 AS Int64) > CAST(a1@1 AS Int64) + 3 AND CAST(a1@0 AS Int64) < CAST(a1@1 AS Int64) + 10 + RepartitionExec: partitioning=Hash([a2@1], 8), input_partitions=1, maintains_sort_order=true + StreamingTableExec: partition_sizes=1, projection=[a1, a2], infinite_source=true, output_ordering=[a1@0 ASC NULLS LAST] + RepartitionExec: partitioning=Hash([a2@1], 8), input_partitions=1, maintains_sort_order=true + StreamingTableExec: partition_sizes=1, projection=[a1, a2], infinite_source=true, output_ordering=[a1@0 ASC NULLS LAST] + " ); Ok(()) } @@ -101,14 +97,16 @@ async fn join_no_order_on_filter() -> Result<()> { Field::new("a3", DataType::UInt32, false), ])); // Specify the ordering: - let file_sort_order = vec![[col("a1")] - .into_iter() - .map(|e| { - let ascending = true; - let nulls_first = false; - e.sort(ascending, nulls_first) - }) - .collect::>()]; + let file_sort_order = vec![ + [col("a1")] + .into_iter() + .map(|e| { + let ascending = true; + let nulls_first = false; + e.sort(ascending, nulls_first) + }) + .collect::>(), + ]; register_unbounded_file_with_ordering( &ctx, schema.clone(), @@ -129,28 +127,17 @@ async fn join_no_order_on_filter() -> Result<()> { let dataframe = ctx.sql(sql).await?; let physical_plan = dataframe.create_physical_plan().await?; let formatted = displayable(physical_plan.as_ref()).indent(true).to_string(); - let expected = { - [ - "SymmetricHashJoinExec: mode=Partitioned, join_type=Full, on=[(a2@1, a2@1)], filter=CAST(a3@0 AS Int64) > CAST(a3@1 AS Int64) + 3 AND 
CAST(a3@0 AS Int64) < CAST(a3@1 AS Int64) + 10", - " CoalesceBatchesExec: target_batch_size=8192", - " RepartitionExec: partitioning=Hash([a2@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - // " DataSourceExec: file_groups={1 group: [[tempdir/left.csv]]}, projection=[a1, a2], file_type=csv, has_header=false", - " CoalesceBatchesExec: target_batch_size=8192", - " RepartitionExec: partitioning=Hash([a2@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - // " DataSourceExec: file_groups={1 group: [[tempdir/right.csv]]}, projection=[a1, a2], file_type=csv, has_header=false" - ] - }; - let mut actual: Vec<&str> = formatted.trim().lines().collect(); - // Remove CSV lines - actual.remove(4); - actual.remove(7); - - assert_eq!( - expected, - actual[..], - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + let actual = formatted.trim(); + + assert_snapshot!( + actual, + @r" + SymmetricHashJoinExec: mode=Partitioned, join_type=Full, on=[(a2@1, a2@1)], filter=CAST(a3@0 AS Int64) > CAST(a3@1 AS Int64) + 3 AND CAST(a3@0 AS Int64) < CAST(a3@1 AS Int64) + 10 + RepartitionExec: partitioning=Hash([a2@1], 8), input_partitions=1, maintains_sort_order=true + StreamingTableExec: partition_sizes=1, projection=[a1, a2, a3], infinite_source=true, output_ordering=[a1@0 ASC NULLS LAST] + RepartitionExec: partitioning=Hash([a2@1], 8), input_partitions=1, maintains_sort_order=true + StreamingTableExec: partition_sizes=1, projection=[a1, a2, a3], infinite_source=true, output_ordering=[a1@0 ASC NULLS LAST] + " ); Ok(()) } @@ -179,28 +166,17 @@ async fn join_change_in_planner_without_sort() -> Result<()> { let dataframe = ctx.sql(sql).await?; let physical_plan = dataframe.create_physical_plan().await?; let formatted = displayable(physical_plan.as_ref()).indent(true).to_string(); - let expected = { - [ - "SymmetricHashJoinExec: mode=Partitioned, join_type=Full, on=[(a2@1, 
a2@1)], filter=CAST(a1@0 AS Int64) > CAST(a1@1 AS Int64) + 3 AND CAST(a1@0 AS Int64) < CAST(a1@1 AS Int64) + 10", - " CoalesceBatchesExec: target_batch_size=8192", - " RepartitionExec: partitioning=Hash([a2@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - // " DataSourceExec: file_groups={1 group: [[tempdir/left.csv]]}, projection=[a1, a2], file_type=csv, has_header=false", - " CoalesceBatchesExec: target_batch_size=8192", - " RepartitionExec: partitioning=Hash([a2@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - // " DataSourceExec: file_groups={1 group: [[tempdir/right.csv]]}, projection=[a1, a2], file_type=csv, has_header=false" - ] - }; - let mut actual: Vec<&str> = formatted.trim().lines().collect(); - // Remove CSV lines - actual.remove(4); - actual.remove(7); - - assert_eq!( - expected, - actual[..], - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + let actual = formatted.trim(); + + assert_snapshot!( + actual, + @r" + SymmetricHashJoinExec: mode=Partitioned, join_type=Full, on=[(a2@1, a2@1)], filter=CAST(a1@0 AS Int64) > CAST(a1@1 AS Int64) + 3 AND CAST(a1@0 AS Int64) < CAST(a1@1 AS Int64) + 10 + RepartitionExec: partitioning=Hash([a2@1], 8), input_partitions=1 + StreamingTableExec: partition_sizes=1, projection=[a1, a2], infinite_source=true + RepartitionExec: partitioning=Hash([a2@1], 8), input_partitions=1 + StreamingTableExec: partition_sizes=1, projection=[a1, a2], infinite_source=true + " ); Ok(()) } @@ -230,8 +206,96 @@ async fn join_change_in_planner_without_sort_not_allowed() -> Result<()> { match df.create_physical_plan().await { Ok(_) => panic!("Expecting error."), Err(e) => { - assert_eq!(e.strip_backtrace(), "SanityCheckPlan\ncaused by\nError during planning: Join operation cannot operate on a non-prunable stream without enabling the 'allow_symmetric_joins_without_pruning' configuration flag") + assert_eq!( + 
e.strip_backtrace(), + "SanityCheckPlan\ncaused by\nError during planning: Join operation cannot operate on a non-prunable stream without enabling the 'allow_symmetric_joins_without_pruning' configuration flag" + ) } } Ok(()) } + +#[tokio::test] +async fn join_using_uppercase_column() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new( + "UPPER", + DataType::UInt32, + false, + )])); + let tmp_dir = TempDir::new()?; + let file_path = tmp_dir.path().join("uppercase-column.csv"); + let mut file = File::create(file_path.clone())?; + file.write_all("0".as_bytes())?; + drop(file); + + let ctx = SessionContext::new(); + ctx.register_csv( + "test", + file_path.to_str().unwrap(), + CsvReadOptions::new().schema(&schema).has_header(false), + ) + .await?; + + let dataframe = ctx + .sql( + r#" + SELECT test."UPPER" FROM "test" + INNER JOIN ( + SELECT test."UPPER" FROM "test" + ) AS selection USING ("UPPER") + ; + "#, + ) + .await?; + + assert_batches_eq!( + [ + "+-------+", + "| UPPER |", + "+-------+", + "| 0 |", + "+-------+", + ], + &dataframe.collect().await? 
+ ); + + Ok(()) +} + +// Issue #17359: https://github.com/apache/datafusion/issues/17359 +#[tokio::test] +async fn unparse_cross_join() -> Result<()> { + let ctx = SessionContext::new(); + + let j1_schema = Arc::new(Schema::new(vec![ + Field::new("j1_id", DataType::Int32, true), + Field::new("j1_string", DataType::Utf8, true), + ])); + let j2_schema = Arc::new(Schema::new(vec![ + Field::new("j2_id", DataType::Int32, true), + Field::new("j2_string", DataType::Utf8, true), + ])); + + ctx.register_table("j1", Arc::new(MemTable::try_new(j1_schema, vec![vec![]])?))?; + ctx.register_table("j2", Arc::new(MemTable::try_new(j2_schema, vec![vec![]])?))?; + + let df = ctx + .sql( + r#" + select j1.j1_id, j2.j2_string + from j1, j2 + where j2.j2_id = 0 + "#, + ) + .await?; + + let unopt_sql = plan_to_sql(df.logical_plan())?; + assert_snapshot!(unopt_sql, @"SELECT j1.j1_id, j2.j2_string FROM j1 CROSS JOIN j2 WHERE (j2.j2_id = 0)"); + + let optimized_plan = df.into_optimized_plan()?; + + let opt_sql = plan_to_sql(&optimized_plan)?; + assert_snapshot!(opt_sql, @"SELECT j1.j1_id, j2.j2_string FROM j1 CROSS JOIN j2 WHERE (j2.j2_id = 0)"); + + Ok(()) +} diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs index 2a5597b9fb7ee..9a1dc5502ee60 100644 --- a/datafusion/core/tests/sql/mod.rs +++ b/datafusion/core/tests/sql/mod.rs @@ -24,36 +24,40 @@ use arrow::{ use datafusion::error::Result; use datafusion::logical_expr::{Aggregate, LogicalPlan, TableScan}; -use datafusion::physical_plan::collect; -use datafusion::physical_plan::metrics::MetricValue; use datafusion::physical_plan::ExecutionPlan; use datafusion::physical_plan::ExecutionPlanVisitor; +use datafusion::physical_plan::collect; +use datafusion::physical_plan::metrics::MetricValue; use datafusion::prelude::*; use datafusion::test_util; use datafusion::{execution::context::SessionContext, physical_plan::displayable}; use datafusion_common::test_util::batches_to_sort_string; use 
datafusion_common::utils::get_available_parallelism; use datafusion_common::{assert_contains, assert_not_contains}; -use insta::assert_snapshot; use object_store::path::Path; use std::fs::File; use std::io::Write; use std::path::PathBuf; use tempfile::TempDir; -/// A macro to assert that some particular line contains two substrings -/// -/// Usage: `assert_metrics!(actual, operator_name, metrics)` +/// A macro to assert that some particular line contains the given substrings /// +/// Usage: `assert_metrics!(actual, operator_name, metrics_1, metrics_2, ...)` macro_rules! assert_metrics { - ($ACTUAL: expr, $OPERATOR_NAME: expr, $METRICS: expr) => { + ($ACTUAL: expr, $OPERATOR_NAME: expr, $($METRICS: expr),+) => { let found = $ACTUAL .lines() - .any(|line| line.contains($OPERATOR_NAME) && line.contains($METRICS)); + .any(|line| line.contains($OPERATOR_NAME) $( && line.contains($METRICS))+); + + let mut metrics = String::new(); + $(metrics.push_str(format!(" '{}',", $METRICS).as_str());)+ + // remove the last `,` from the string + metrics.pop(); + assert!( found, - "Can not find a line with both '{}' and '{}' in\n\n{}", - $OPERATOR_NAME, $METRICS, $ACTUAL + "Cannot find a line with operator name '{}' and metrics containing values {} in :\n\n{}", + $OPERATOR_NAME, metrics, $ACTUAL ); }; } @@ -66,6 +70,7 @@ mod path_partition; mod runtime_config; pub mod select; mod sql_api; +mod unparser; async fn register_aggregate_csv_by_sql(ctx: &SessionContext) { let testdata = test_util::arrow_test_data(); @@ -331,8 +336,7 @@ async fn nyc() -> Result<()> { match &optimized_plan { LogicalPlan::Aggregate(Aggregate { input, .. }) => match input.as_ref() { LogicalPlan::TableScan(TableScan { - ref projected_schema, - .. + projected_schema, .. 
}) => { assert_eq!(2, projected_schema.fields().len()); assert_eq!(projected_schema.field(0).name(), "passenger_count"); diff --git a/datafusion/core/tests/sql/path_partition.rs b/datafusion/core/tests/sql/path_partition.rs index 5e9748d23d8cd..1afab529f019c 100644 --- a/datafusion/core/tests/sql/path_partition.rs +++ b/datafusion/core/tests/sql/path_partition.rs @@ -20,7 +20,6 @@ use std::collections::BTreeSet; use std::fs::File; use std::io::{Read, Seek, SeekFrom}; -use std::ops::Range; use std::sync::Arc; use arrow::datatypes::DataType; @@ -31,26 +30,28 @@ use datafusion::{ listing::{ListingOptions, ListingTable, ListingTableConfig}, }, error::Result, - physical_plan::ColumnStatistics, prelude::SessionContext, test_util::{self, arrow_test_data, parquet_test_data}, }; use datafusion_catalog::TableProvider; +use datafusion_common::ScalarValue; use datafusion_common::stats::Precision; use datafusion_common::test_util::batches_to_sort_string; -use datafusion_common::ScalarValue; use datafusion_execution::config::SessionConfig; use async_trait::async_trait; use bytes::Bytes; use chrono::{TimeZone, Utc}; +use futures::StreamExt; use futures::stream::{self, BoxStream}; use insta::assert_snapshot; use object_store::{ - path::Path, GetOptions, GetResult, GetResultPayload, ListResult, ObjectMeta, - ObjectStore, PutOptions, PutResult, + Attributes, CopyOptions, GetRange, MultipartUpload, PutMultipartOptions, PutPayload, +}; +use object_store::{ + GetOptions, GetResult, GetResultPayload, ListResult, ObjectMeta, ObjectStore, + PutOptions, PutResult, path::Path, }; -use object_store::{Attributes, MultipartUpload, PutMultipartOpts, PutPayload}; use url::Url; #[tokio::test] @@ -460,14 +461,26 @@ async fn parquet_statistics() -> Result<()> { let schema = physical_plan.schema(); assert_eq!(schema.fields().len(), 4); - let stat_cols = physical_plan.partition_statistics(None)?.column_statistics; + let stat_cols = physical_plan + .partition_statistics(None)? 
+ .column_statistics + .clone(); assert_eq!(stat_cols.len(), 4); // stats for the first col are read from the parquet file assert_eq!(stat_cols[0].null_count, Precision::Exact(3)); - // TODO assert partition column (1,2,3) stats once implemented (#1186) - assert_eq!(stat_cols[1], ColumnStatistics::new_unknown(),); - assert_eq!(stat_cols[2], ColumnStatistics::new_unknown(),); - assert_eq!(stat_cols[3], ColumnStatistics::new_unknown(),); + // Partition column statistics (year=2021 for all 3 rows) + assert_eq!(stat_cols[1].null_count, Precision::Exact(0)); + assert_eq!( + stat_cols[1].min_value, + Precision::Exact(ScalarValue::Int32(Some(2021))) + ); + assert_eq!( + stat_cols[1].max_value, + Precision::Exact(ScalarValue::Int32(Some(2021))) + ); + // month and day are Utf8 partition columns with statistics + assert_eq!(stat_cols[2].null_count, Precision::Exact(0)); + assert_eq!(stat_cols[3].null_count, Precision::Exact(0)); //// WITH PROJECTION //// let dataframe = ctx.sql("SELECT mycol, day FROM t WHERE day='28'").await?; @@ -475,12 +488,23 @@ async fn parquet_statistics() -> Result<()> { let schema = physical_plan.schema(); assert_eq!(schema.fields().len(), 2); - let stat_cols = physical_plan.partition_statistics(None)?.column_statistics; + let stat_cols = physical_plan + .partition_statistics(None)? 
+ .column_statistics + .clone(); assert_eq!(stat_cols.len(), 2); // stats for the first col are read from the parquet file assert_eq!(stat_cols[0].null_count, Precision::Exact(1)); - // TODO assert partition column stats once implemented (#1186) - assert_eq!(stat_cols[1], ColumnStatistics::new_unknown()); + // Partition column statistics for day='28' (1 row) + assert_eq!(stat_cols[1].null_count, Precision::Exact(0)); + assert_eq!( + stat_cols[1].min_value, + Precision::Exact(ScalarValue::Utf8(Some("28".to_string()))) + ); + assert_eq!( + stat_cols[1].max_value, + Precision::Exact(ScalarValue::Utf8(Some("28".to_string()))) + ); Ok(()) } @@ -604,7 +628,7 @@ async fn create_partitioned_alltypes_parquet_table( } #[derive(Debug)] -/// An object store implem that is mirrors a given file to multiple paths. +/// An object store implem that mirrors a given file to multiple paths. pub struct MirroringObjectStore { /// The `(path,size)` of the files that "exist" in the store files: Vec, @@ -645,7 +669,7 @@ impl ObjectStore for MirroringObjectStore { async fn put_multipart_opts( &self, _location: &Path, - _opts: PutMultipartOpts, + _opts: PutMultipartOptions, ) -> object_store::Result> { unimplemented!() } @@ -653,12 +677,13 @@ impl ObjectStore for MirroringObjectStore { async fn get_opts( &self, location: &Path, - _options: GetOptions, + options: GetOptions, ) -> object_store::Result { self.files.iter().find(|x| *x == location).unwrap(); let path = std::path::PathBuf::from(&self.mirrored_file); let file = File::open(&path).unwrap(); let metadata = file.metadata().unwrap(); + let meta = ObjectMeta { location: location.clone(), last_modified: metadata.modified().map(chrono::DateTime::from).unwrap(), @@ -667,37 +692,35 @@ impl ObjectStore for MirroringObjectStore { version: None, }; + let payload = if options.head { + // no content for head requests + GetResultPayload::Stream(stream::empty().boxed()) + } else if let Some(range) = options.range { + let GetRange::Bounded(range) = 
range else { + unimplemented!("Unbounded range not supported in MirroringObjectStore"); + }; + let mut file = File::open(path).unwrap(); + file.seek(SeekFrom::Start(range.start)).unwrap(); + + let to_read = range.end - range.start; + let to_read: usize = to_read.try_into().unwrap(); + let mut data = Vec::with_capacity(to_read); + let read = file.take(to_read as u64).read_to_end(&mut data).unwrap(); + assert_eq!(read, to_read); + let stream = stream::once(async move { Ok(Bytes::from(data)) }).boxed(); + GetResultPayload::Stream(stream) + } else { + GetResultPayload::File(file, path) + }; + Ok(GetResult { range: 0..meta.size, - payload: GetResultPayload::File(file, path), + payload, meta, attributes: Attributes::default(), }) } - async fn get_range( - &self, - location: &Path, - range: Range, - ) -> object_store::Result { - self.files.iter().find(|x| *x == location).unwrap(); - let path = std::path::PathBuf::from(&self.mirrored_file); - let mut file = File::open(path).unwrap(); - file.seek(SeekFrom::Start(range.start)).unwrap(); - - let to_read = range.end - range.start; - let to_read: usize = to_read.try_into().unwrap(); - let mut data = Vec::with_capacity(to_read); - let read = file.take(to_read as u64).read_to_end(&mut data).unwrap(); - assert_eq!(read, to_read); - - Ok(data.into()) - } - - async fn delete(&self, _location: &Path) -> object_store::Result<()> { - unimplemented!() - } - fn list( &self, prefix: Option<&Path>, @@ -712,6 +735,8 @@ impl ObjectStore for MirroringObjectStore { .map(|mut x| x.next().is_some()) .unwrap_or(false); + #[expect(clippy::result_large_err)] + // closure only ever returns Ok; Err type is never constructed filter.then(|| { Ok(ObjectMeta { location, @@ -767,14 +792,18 @@ impl ObjectStore for MirroringObjectStore { }) } - async fn copy(&self, _from: &Path, _to: &Path) -> object_store::Result<()> { + fn delete_stream( + &self, + _locations: BoxStream<'static, object_store::Result>, + ) -> BoxStream<'static, object_store::Result> { 
unimplemented!() } - async fn copy_if_not_exists( + async fn copy_opts( &self, _from: &Path, _to: &Path, + _options: CopyOptions, ) -> object_store::Result<()> { unimplemented!() } diff --git a/datafusion/core/tests/sql/runtime_config.rs b/datafusion/core/tests/sql/runtime_config.rs index 18e07bb61ed94..cf5237d725805 100644 --- a/datafusion/core/tests/sql/runtime_config.rs +++ b/datafusion/core/tests/sql/runtime_config.rs @@ -18,9 +18,14 @@ //! Tests for runtime configuration SQL interface use std::sync::Arc; +use std::time::Duration; use datafusion::execution::context::SessionContext; use datafusion::execution::context::TaskContext; +use datafusion::prelude::SessionConfig; +use datafusion_execution::cache::DefaultListFilesCache; +use datafusion_execution::cache::cache_manager::CacheManagerConfig; +use datafusion_execution::runtime_env::RuntimeEnvBuilder; use datafusion_physical_plan::common::collect; #[tokio::test] @@ -140,7 +145,7 @@ async fn test_memory_limit_enforcement() { } #[tokio::test] -async fn test_invalid_memory_limit() { +async fn test_invalid_memory_limit_when_unit_is_invalid() { let ctx = SessionContext::new(); let result = ctx @@ -149,7 +154,194 @@ async fn test_invalid_memory_limit() { assert!(result.is_err()); let error_message = result.unwrap_err().to_string(); - assert!(error_message.contains("Unsupported unit 'X'")); + assert!( + error_message + .contains("Unsupported unit 'X' in 'datafusion.runtime.memory_limit'") + && error_message.contains("Unit must be one of: 'K', 'M', 'G'") + ); +} + +#[tokio::test] +async fn test_invalid_memory_limit_when_limit_is_not_numeric() { + let ctx = SessionContext::new(); + + let result = ctx + .sql("SET datafusion.runtime.memory_limit = 'invalid_memory_limit'") + .await; + + assert!(result.is_err()); + let error_message = result.unwrap_err().to_string(); + assert!(error_message.contains( + "Failed to parse number from 'datafusion.runtime.memory_limit', limit 'invalid_memory_limit'" + )); +} + +#[tokio::test] 
+async fn test_max_temp_directory_size_enforcement() { + let ctx = SessionContext::new(); + + ctx.sql("SET datafusion.runtime.memory_limit = '1M'") + .await + .unwrap() + .collect() + .await + .unwrap(); + + ctx.sql("SET datafusion.execution.sort_spill_reservation_bytes = 0") + .await + .unwrap() + .collect() + .await + .unwrap(); + + ctx.sql("SET datafusion.runtime.max_temp_directory_size = '0K'") + .await + .unwrap() + .collect() + .await + .unwrap(); + + let query = "select * from generate_series(1,100000) as t1(v1) order by v1;"; + let result = ctx.sql(query).await.unwrap().collect().await; + + assert!( + result.is_err(), + "Should fail due to max temp directory size limit" + ); + + ctx.sql("SET datafusion.runtime.max_temp_directory_size = '1M'") + .await + .unwrap() + .collect() + .await + .unwrap(); + + let result = ctx.sql(query).await.unwrap().collect().await; + + assert!( + result.is_ok(), + "Should not fail due to max temp directory size limit" + ); +} + +#[tokio::test] +async fn test_test_metadata_cache_limit() { + let ctx = SessionContext::new(); + + let update_limit = async |ctx: &SessionContext, limit: &str| { + ctx.sql( + format!("SET datafusion.runtime.metadata_cache_limit = '{limit}'").as_str(), + ) + .await + .unwrap() + .collect() + .await + .unwrap(); + }; + + let get_limit = |ctx: &SessionContext| -> usize { + ctx.task_ctx() + .runtime_env() + .cache_manager + .get_file_metadata_cache() + .cache_limit() + }; + + update_limit(&ctx, "100M").await; + assert_eq!(get_limit(&ctx), 100 * 1024 * 1024); + + update_limit(&ctx, "2G").await; + assert_eq!(get_limit(&ctx), 2 * 1024 * 1024 * 1024); + + update_limit(&ctx, "123K").await; + assert_eq!(get_limit(&ctx), 123 * 1024); +} + +#[tokio::test] +async fn test_list_files_cache_limit() { + let list_files_cache = Arc::new(DefaultListFilesCache::default()); + + let rt = RuntimeEnvBuilder::new() + .with_cache_manager( + CacheManagerConfig::default().with_list_files_cache(Some(list_files_cache)), + ) + 
.build_arc() + .unwrap(); + + let ctx = SessionContext::new_with_config_rt(SessionConfig::default(), rt); + + let update_limit = async |ctx: &SessionContext, limit: &str| { + ctx.sql( + format!("SET datafusion.runtime.list_files_cache_limit = '{limit}'").as_str(), + ) + .await + .unwrap() + .collect() + .await + .unwrap(); + }; + + let get_limit = |ctx: &SessionContext| -> usize { + ctx.task_ctx() + .runtime_env() + .cache_manager + .get_list_files_cache() + .unwrap() + .cache_limit() + }; + + update_limit(&ctx, "100M").await; + assert_eq!(get_limit(&ctx), 100 * 1024 * 1024); + + update_limit(&ctx, "2G").await; + assert_eq!(get_limit(&ctx), 2 * 1024 * 1024 * 1024); + + update_limit(&ctx, "123K").await; + assert_eq!(get_limit(&ctx), 123 * 1024); +} + +#[tokio::test] +async fn test_list_files_cache_ttl() { + let list_files_cache = Arc::new(DefaultListFilesCache::default()); + + let rt = RuntimeEnvBuilder::new() + .with_cache_manager( + CacheManagerConfig::default().with_list_files_cache(Some(list_files_cache)), + ) + .build_arc() + .unwrap(); + + let ctx = SessionContext::new_with_config_rt(SessionConfig::default(), rt); + + let update_limit = async |ctx: &SessionContext, limit: &str| { + ctx.sql( + format!("SET datafusion.runtime.list_files_cache_ttl = '{limit}'").as_str(), + ) + .await + .unwrap() + .collect() + .await + .unwrap(); + }; + + let get_limit = |ctx: &SessionContext| -> Duration { + ctx.task_ctx() + .runtime_env() + .cache_manager + .get_list_files_cache() + .unwrap() + .cache_ttl() + .unwrap() + }; + + update_limit(&ctx, "1m").await; + assert_eq!(get_limit(&ctx), Duration::from_secs(60)); + + update_limit(&ctx, "30s").await; + assert_eq!(get_limit(&ctx), Duration::from_secs(30)); + + update_limit(&ctx, "1m30s").await; + assert_eq!(get_limit(&ctx), Duration::from_secs(90)); } #[tokio::test] diff --git a/datafusion/core/tests/sql/select.rs b/datafusion/core/tests/sql/select.rs index f874dd7c08428..6126793145efd 100644 --- 
a/datafusion/core/tests/sql/select.rs +++ b/datafusion/core/tests/sql/select.rs @@ -15,8 +15,11 @@ // specific language governing permissions and limitations // under the License. +use std::collections::HashMap; + use super::*; -use datafusion_common::ScalarValue; +use datafusion_common::{ParamValues, ScalarValue, metadata::ScalarAndMetadata}; +use insta::assert_snapshot; #[tokio::test] async fn test_list_query_parameters() -> Result<()> { @@ -217,10 +220,12 @@ async fn test_parameter_invalid_types() -> Result<()> { .with_param_values(vec![ScalarValue::from(4_i32)])? .collect() .await; - assert_eq!( - results.unwrap_err().strip_backtrace(), - "type_coercion\ncaused by\nError during planning: Cannot infer common argument type for comparison operation List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) = Int32" -); + assert_snapshot!(results.unwrap_err().strip_backtrace(), + @r" + type_coercion + caused by + Error during planning: Cannot infer common argument type for comparison operation List(Int32) = Int32 + "); Ok(()) } @@ -314,6 +319,47 @@ async fn test_named_parameter_not_bound() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_query_parameters_with_metadata() -> Result<()> { + let ctx = SessionContext::new(); + + let df = ctx.sql("SELECT $1, $2").await.unwrap(); + + let metadata1 = HashMap::from([("some_key".to_string(), "some_value".to_string())]); + let metadata2 = + HashMap::from([("some_other_key".to_string(), "some_other_value".to_string())]); + + let df_with_params_replaced = df + .with_param_values(ParamValues::List(vec![ + ScalarAndMetadata::new( + ScalarValue::UInt32(Some(1)), + Some(metadata1.clone().into()), + ), + ScalarAndMetadata::new( + ScalarValue::Utf8(Some("two".to_string())), + Some(metadata2.clone().into()), + ), + ])) + .unwrap(); + + let schema = df_with_params_replaced.schema(); + assert_eq!(schema.field(0).data_type(), &DataType::UInt32); + 
assert_eq!(schema.field(0).metadata(), &metadata1); + assert_eq!(schema.field(1).data_type(), &DataType::Utf8); + assert_eq!(schema.field(1).metadata(), &metadata2); + + let batches = df_with_params_replaced.collect().await.unwrap(); + assert_snapshot!(batches_to_sort_string(&batches), @r" + +----+-----+ + | $1 | $2 | + +----+-----+ + | 1 | two | + +----+-----+ + "); + + Ok(()) +} + #[tokio::test] async fn test_version_function() { let expected_version = format!( @@ -343,3 +389,45 @@ async fn test_version_function() { assert_eq!(version.value(0), expected_version); } + +/// Regression test for https://github.com/apache/datafusion/issues/17513 +/// See https://github.com/apache/datafusion/pull/17520 +#[tokio::test] +async fn test_select_no_projection() -> Result<()> { + let tmp_dir = TempDir::new()?; + // `create_ctx_with_partition` creates 10 rows per partition and we chose 1 partition + let ctx = create_ctx_with_partition(&tmp_dir, 1).await?; + + let results = ctx.sql("SELECT FROM test").await?.collect().await?; + // We should get all of the rows, just without any columns + let total_rows: usize = results.iter().map(|b| b.num_rows()).sum(); + assert_eq!(total_rows, 10); + // Check that none of the batches have any columns + for batch in &results { + assert_eq!(batch.num_columns(), 0); + } + // Sanity check the output, should be just empty columns + assert_snapshot!(batches_to_sort_string(&results), @r" + ++ + ++ + ++ + "); + Ok(()) +} + +#[tokio::test] +async fn test_select_cast_date_literal_to_timestamp_overflow() -> Result<()> { + let ctx = SessionContext::new(); + let err = ctx + .sql("SELECT CAST(DATE '9999-12-31' AS TIMESTAMP)") + .await? 
+ .collect() + .await + .unwrap_err(); + + assert_contains!( + err.to_string(), + "Cannot cast Date32 value 2932896 to Timestamp(ns): converted value exceeds the representable i64 range" + ); + Ok(()) +} diff --git a/datafusion/core/tests/sql/sql_api.rs b/datafusion/core/tests/sql/sql_api.rs index ec086bcc50c76..b87afd27ddea7 100644 --- a/datafusion/core/tests/sql/sql_api.rs +++ b/datafusion/core/tests/sql/sql_api.rs @@ -84,8 +84,8 @@ async fn dml_output_schema() { ctx.sql("CREATE TABLE test (x int)").await.unwrap(); let sql = "INSERT INTO test VALUES (1)"; let df = ctx.sql(sql).await.unwrap(); - let count_schema = Schema::new(vec![Field::new("count", DataType::UInt64, false)]); - assert_eq!(Schema::from(df.schema()), count_schema); + let count_schema = &Schema::new(vec![Field::new("count", DataType::UInt64, false)]); + assert_eq!(df.schema().as_arrow(), count_schema); } #[tokio::test] diff --git a/datafusion/core/tests/sql/unparser.rs b/datafusion/core/tests/sql/unparser.rs new file mode 100644 index 0000000000000..e9bad71843ff2 --- /dev/null +++ b/datafusion/core/tests/sql/unparser.rs @@ -0,0 +1,456 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! SQL Unparser Roundtrip Integration Tests +//! +//! 
This module tests the [`Unparser`] by running queries through a complete roundtrip: +//! the original SQL is parsed into a logical plan, unparsed back to SQL, then that +//! generated SQL is parsed and executed. The results are compared to verify semantic +//! equivalence. +//! +//! ## Test Strategy +//! +//! Uses real-world benchmark queries (TPC-H and Clickbench) to validate that: +//! 1. The unparser produces syntactically valid SQL +//! 2. The unparsed SQL is semantically equivalent (produces identical results) +//! +//! ## Query Suites +//! +//! - **TPC-H**: Standard decision-support benchmark with 22 complex analytical queries +//! - **Clickbench**: Web analytics benchmark with 43 queries against a denormalized schema +//! +//! [`Unparser`]: datafusion_sql::unparser::Unparser + +use std::fs::ReadDir; +use std::future::Future; + +use arrow::array::RecordBatch; +use datafusion::common::Result; +use datafusion::prelude::{ParquetReadOptions, SessionContext}; +use datafusion_common::Column; +use datafusion_expr::Expr; +use datafusion_sql::unparser::Unparser; +use datafusion_sql::unparser::dialect::DefaultDialect; +use itertools::Itertools; +use recursive::{set_minimum_stack_size, set_stack_allocation_size}; + +/// Paths to benchmark query files (supports running from repo root or different working directories). +const BENCHMARK_PATHS: &[&str] = &["../../benchmarks/", "./benchmarks/"]; + +/// Reads all `.sql` files from a directory and converts them to test queries. +/// +/// Skips files that: +/// - Are not regular files +/// - Don't have a `.sql` extension +/// - Contain multiple SQL statements (indicated by `;\n`) +/// +/// Multi-statement files are skipped because the unparser doesn't support +/// DML statements like `CREATE VIEW` that appear in multi-statement Clickbench queries. 
+fn iterate_queries(dir: ReadDir) -> Vec { + let mut queries = vec![]; + for entry in dir.flatten() { + let Ok(file_type) = entry.file_type() else { + continue; + }; + if !file_type.is_file() { + continue; + } + let path = entry.path(); + let Some(ext) = path.extension() else { + continue; + }; + if ext != "sql" { + continue; + } + let name = path.file_stem().unwrap().to_string_lossy().to_string(); + if let Ok(mut contents) = std::fs::read_to_string(entry.path()) { + // If the query contains ;\n it has DML statements like CREATE VIEW which the unparser doesn't support; skip it + contents = contents.trim().to_string(); + if contents.contains(";\n") { + println!("Skipping query with multiple statements: {name}"); + continue; + } + queries.push(TestQuery { + sql: contents, + name, + }); + } + } + queries +} + +/// A SQL query loaded from a benchmark file for roundtrip testing. +/// +/// Each query is identified by its filename (without extension) and contains +/// the full SQL text to be tested. +struct TestQuery { + /// The SQL query text to test. + sql: String, + /// The query identifier (typically the filename without .sql extension). + name: String, +} + +/// Collect SQL for Clickbench queries. +fn clickbench_queries() -> Vec { + let mut queries = vec![]; + for path in BENCHMARK_PATHS { + let dir = format!("{path}queries/clickbench/queries/"); + println!("Reading Clickbench queries from {dir}"); + if let Ok(dir) = std::fs::read_dir(dir) { + let read = iterate_queries(dir); + println!("Found {} Clickbench queries", read.len()); + queries.extend(read); + } + } + queries.sort_unstable_by_key(|q| { + q.name + .split('q') + .next_back() + .and_then(|num| num.parse::().ok()) + }); + queries +} + +/// Collect SQL for TPC-H queries. 
+fn tpch_queries() -> Vec { + let mut queries = vec![]; + for path in BENCHMARK_PATHS { + let dir = format!("{path}queries/"); + println!("Reading TPC-H queries from {dir}"); + if let Ok(dir) = std::fs::read_dir(dir) { + let read = iterate_queries(dir); + queries.extend(read); + } + } + println!("Total TPC-H queries found: {}", queries.len()); + queries.sort_unstable_by_key(|q| q.name.clone()); + queries +} + +/// Create a new SessionContext for testing that has all Clickbench tables registered. +async fn clickbench_test_context() -> Result { + let ctx = SessionContext::new(); + ctx.register_parquet( + "hits", + "tests/data/clickbench_hits_10.parquet", + ParquetReadOptions::default(), + ) + .await?; + // Sanity check we found the table by querying it's schema, it should not be empty + // Otherwise if the path is wrong the tests will all fail in confusing ways + let df = ctx.sql("SELECT * FROM hits LIMIT 1").await?; + assert!( + !df.schema().fields().is_empty(), + "Clickbench 'hits' table not registered correctly" + ); + Ok(ctx) +} + +/// Create a new SessionContext for testing that has all TPC-H tables registered. +async fn tpch_test_context() -> Result { + let ctx = SessionContext::new(); + let data_dir = "tests/data/"; + // All tables have the pattern "tpch__small.parquet" + for table in [ + "customer", "lineitem", "nation", "orders", "part", "partsupp", "region", + "supplier", + ] { + let path = format!("{data_dir}tpch_{table}_small.parquet"); + ctx.register_parquet(table, &path, ParquetReadOptions::default()) + .await?; + // Sanity check we found the table by querying it's schema, it should not be empty + // Otherwise if the path is wrong the tests will all fail in confusing ways + let df = ctx.sql(&format!("SELECT * FROM {table} LIMIT 1")).await?; + assert!( + !df.schema().fields().is_empty(), + "TPC-H '{table}' table not registered correctly" + ); + } + Ok(ctx) +} + +/// Sorts record batches by all columns for deterministic comparison. 
+/// +/// When comparing query results, we need a canonical ordering so that +/// semantically equivalent results compare as equal. This function sorts +/// by all columns in the schema to achieve that. +async fn sort_batches( + ctx: &SessionContext, + batches: Vec, +) -> Result> { + let mut df = ctx.read_batches(batches)?; + let schema = df.schema().as_arrow().clone(); + let sort_exprs = schema + .fields() + .iter() + // Use Column directly, col() causes the column names to be normalized to lowercase + .map(|f| { + Expr::Column(Column::new_unqualified(f.name().to_string())).sort(true, false) + }) + .collect_vec(); + if !sort_exprs.is_empty() { + df = df.sort(sort_exprs)?; + } + df.collect().await +} + +/// The outcome of running a single roundtrip test. +/// +/// A successful test produces [`TestCaseResult::Success`]. +/// All other variants capture different failure modes with enough context to diagnose the issue. +enum TestCaseResult { + /// The unparsed SQL produced identical results to the original. + Success, + + /// Both queries executed but produced different results. + /// + /// This indicates a semantic bug in the unparser where the generated SQL + /// has different meaning than the original. + ResultsMismatch { original: String, unparsed: String }, + + /// The unparser failed to convert the logical plan to SQL. + /// + /// This may indicate an unsupported SQL feature or a bug in the unparser. + UnparseError { original: String, error: String }, + + /// The original SQL failed to execute. + /// + /// This indicates a problem with the test setup (missing tables, + /// invalid test data) rather than an unparser issue. + ExecutionError { original: String, error: String }, + + /// The unparsed SQL failed to execute, even though the original succeeded. + /// + /// This indicates the unparser generated syntactically invalid SQL or SQL + /// that references non-existent columns/tables. 
+ UnparsedExecutionError { + original: String, + unparsed: String, + error: String, + }, +} + +impl TestCaseResult { + /// Returns true if the test case represents a failure + /// (anything other than [`TestCaseResult::Success`]). + fn is_failure(&self) -> bool { + !matches!(self, TestCaseResult::Success) + } + + /// Formats a detailed error message for the test case into a string. + fn format_error(&self, name: &str) -> String { + match self { + TestCaseResult::Success => String::new(), + TestCaseResult::ResultsMismatch { original, unparsed } => { + format!( + "Results mismatch for {name}.\nOriginal SQL:\n{original}\n\nUnparsed SQL:\n{unparsed}" + ) + } + TestCaseResult::UnparseError { original, error } => { + format!("Unparse error for {name}: {error}\nOriginal SQL:\n{original}") + } + TestCaseResult::ExecutionError { original, error } => { + format!("Execution error for {name}: {error}\nOriginal SQL:\n{original}") + } + TestCaseResult::UnparsedExecutionError { + original, + unparsed, + error, + } => { + format!( + "Unparsed execution error for {name}: {error}\nOriginal SQL:\n{original}\n\nUnparsed SQL:\n{unparsed}" + ) + } + } + } +} + +/// Executes a roundtrip test for a single SQL query. +/// +/// This is the core test logic that: +/// 1. Parses the original SQL and creates a logical plan +/// 2. Unparses the logical plan back to SQL +/// 3. Executes both the original and unparsed queries +/// 4. Compares the results (sorting if the query has no ORDER BY) +/// +/// This always uses [`DefaultDialect`] for unparsing. +/// +/// # Arguments +/// +/// * `ctx` - Session context with tables registered +/// * `original` - The original SQL query to test +/// +/// # Returns +/// +/// A [`TestCaseResult`] indicating success or the specific failure mode. 
+async fn collect_results(ctx: &SessionContext, original: &str) -> TestCaseResult { + let unparser = Unparser::new(&DefaultDialect {}); + + // Parse and create logical plan from original SQL + let df = match ctx.sql(original).await { + Ok(df) => df, + Err(e) => { + return TestCaseResult::ExecutionError { + original: original.to_string(), + error: e.to_string(), + }; + } + }; + + // Unparse the logical plan back to SQL + let unparsed = match unparser.plan_to_sql(df.logical_plan()) { + Ok(sql) => format!("{sql:#}"), + Err(e) => { + return TestCaseResult::UnparseError { + original: original.to_string(), + error: e.to_string(), + }; + } + }; + + // Collect results from original query + let mut expected = match df.collect().await { + Ok(batches) => batches, + Err(e) => { + return TestCaseResult::ExecutionError { + original: original.to_string(), + error: e.to_string(), + }; + } + }; + + // Parse and execute the unparsed SQL + let actual_df = match ctx.sql(&unparsed).await { + Ok(df) => df, + Err(e) => { + return TestCaseResult::UnparsedExecutionError { + original: original.to_string(), + unparsed, + error: e.to_string(), + }; + } + }; + + // Collect results from unparsed query + let mut actual = match actual_df.collect().await { + Ok(batches) => batches, + Err(e) => { + return TestCaseResult::UnparsedExecutionError { + original: original.to_string(), + unparsed, + error: e.to_string(), + }; + } + }; + + // Always sort for deterministic comparison — even "sorted" results can have + // tied rows in different order between original and unparsed SQL. 
+ { + expected = match sort_batches(ctx, expected).await { + Ok(batches) => batches, + Err(e) => { + return TestCaseResult::ExecutionError { + original: original.to_string(), + error: format!("Failed to sort expected results: {e}"), + }; + } + }; + actual = match sort_batches(ctx, actual).await { + Ok(batches) => batches, + Err(e) => { + return TestCaseResult::UnparsedExecutionError { + original: original.to_string(), + unparsed, + error: format!("Failed to sort actual results: {e}"), + }; + } + }; + } + + if expected != actual { + TestCaseResult::ResultsMismatch { + original: original.to_string(), + unparsed, + } + } else { + TestCaseResult::Success + } +} + +/// Runs roundtrip tests for a collection of queries and reports results. +/// +/// Iterates through all queries, running each through [`collect_results`]. +/// Prints colored status (green checkmark for success, red X for failure) +/// and panics at the end if any tests failed, with detailed error messages. +/// +/// # Type Parameters +/// +/// * `F` - Factory function that creates fresh session contexts +/// * `Fut` - Future type returned by the context factory +/// +/// # Panics +/// +/// Panics if any query fails the roundtrip test, displaying all failures. 
+async fn run_roundtrip_tests( + suite_name: &str, + queries: Vec, + create_context: F, +) where + F: Fn() -> Fut, + Fut: Future>, +{ + let mut errors: Vec = vec![]; + for sql in queries { + let ctx = match create_context().await { + Ok(ctx) => ctx, + Err(e) => { + println!("\x1b[31m✗\x1b[0m {} query: {}", suite_name, sql.name); + errors.push(format!("Failed to create context for {}: {}", sql.name, e)); + continue; + } + }; + let result = collect_results(&ctx, &sql.sql).await; + if result.is_failure() { + println!("\x1b[31m✗\x1b[0m {} query: {}", suite_name, sql.name); + errors.push(result.format_error(&sql.name)); + } else { + println!("\x1b[32m✓\x1b[0m {} query: {}", suite_name, sql.name); + } + } + if !errors.is_empty() { + panic!( + "{} {} test(s) failed:\n\n{}", + errors.len(), + suite_name, + errors.join("\n\n---\n\n") + ); + } +} + +#[tokio::test] +async fn test_clickbench_unparser_roundtrip() { + run_roundtrip_tests("Clickbench", clickbench_queries(), clickbench_test_context) + .await; +} + +#[tokio::test] +async fn test_tpch_unparser_roundtrip() { + // Grow stacker segments earlier to avoid deep unparser recursion overflow in q20. + set_minimum_stack_size(512 * 1024); + set_stack_allocation_size(8 * 1024 * 1024); + run_roundtrip_tests("TPC-H", tpch_queries(), tpch_test_context).await; +} diff --git a/datafusion/core/tests/test_adapter_updated.rs b/datafusion/core/tests/test_adapter_updated.rs deleted file mode 100644 index c85b9a3447361..0000000000000 --- a/datafusion/core/tests/test_adapter_updated.rs +++ /dev/null @@ -1,214 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. 
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; -use arrow::record_batch::RecordBatch; -use datafusion_common::{ColumnStatistics, DataFusionError, Result, Statistics}; -use datafusion_datasource::file::FileSource; -use datafusion_datasource::file_scan_config::FileScanConfig; -use datafusion_datasource::file_stream::FileOpener; -use datafusion_datasource::schema_adapter::{ - SchemaAdapter, SchemaAdapterFactory, SchemaMapper, -}; -use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; -use object_store::ObjectStore; -use std::any::Any; -use std::fmt::Debug; -use std::sync::Arc; - -/// A test source for testing schema adapters -#[derive(Debug, Clone)] -struct TestSource { - schema_adapter_factory: Option>, -} - -impl TestSource { - fn new() -> Self { - Self { - schema_adapter_factory: None, - } - } -} - -impl FileSource for TestSource { - fn file_type(&self) -> &str { - "test" - } - - fn as_any(&self) -> &dyn Any { - self - } - - fn create_file_opener( - &self, - _store: Arc, - _conf: &FileScanConfig, - _index: usize, - ) -> Arc { - unimplemented!("Not needed for this test") - } - - fn with_batch_size(&self, _batch_size: usize) -> Arc { - Arc::new(self.clone()) - } - - fn with_schema(&self, _schema: SchemaRef) -> Arc { - Arc::new(self.clone()) - } - - fn with_projection(&self, _projection: &FileScanConfig) -> Arc { - Arc::new(self.clone()) - } - - fn with_statistics(&self, _statistics: Statistics) -> Arc { - Arc::new(self.clone()) - } - - fn metrics(&self) -> &ExecutionPlanMetricsSet { - unimplemented!("Not needed for this 
test") - } - - fn statistics(&self) -> Result { - Ok(Statistics::default()) - } - - fn with_schema_adapter_factory( - &self, - schema_adapter_factory: Arc, - ) -> Arc { - Arc::new(Self { - schema_adapter_factory: Some(schema_adapter_factory), - }) - } - - fn schema_adapter_factory(&self) -> Option> { - self.schema_adapter_factory.clone() - } -} - -/// A test schema adapter factory -#[derive(Debug)] -struct TestSchemaAdapterFactory {} - -impl SchemaAdapterFactory for TestSchemaAdapterFactory { - fn create( - &self, - projected_table_schema: SchemaRef, - _table_schema: SchemaRef, - ) -> Box { - Box::new(TestSchemaAdapter { - table_schema: projected_table_schema, - }) - } -} - -/// A test schema adapter implementation -#[derive(Debug)] -struct TestSchemaAdapter { - table_schema: SchemaRef, -} - -impl SchemaAdapter for TestSchemaAdapter { - fn map_column_index(&self, index: usize, file_schema: &Schema) -> Option { - let field = self.table_schema.field(index); - file_schema.fields.find(field.name()).map(|(i, _)| i) - } - - fn map_schema( - &self, - file_schema: &Schema, - ) -> Result<(Arc, Vec)> { - let mut projection = Vec::with_capacity(file_schema.fields().len()); - for (file_idx, file_field) in file_schema.fields().iter().enumerate() { - if self.table_schema.fields().find(file_field.name()).is_some() { - projection.push(file_idx); - } - } - - Ok((Arc::new(TestSchemaMapping {}), projection)) - } -} - -/// A test schema mapper implementation -#[derive(Debug)] -struct TestSchemaMapping {} - -impl SchemaMapper for TestSchemaMapping { - fn map_batch(&self, batch: RecordBatch) -> Result { - // For testing, just return the original batch - Ok(batch) - } - - fn map_column_statistics( - &self, - stats: &[ColumnStatistics], - ) -> Result> { - // For testing, just return the input statistics - Ok(stats.to_vec()) - } -} - -#[test] -fn test_schema_adapter() { - // This test verifies the functionality of the SchemaAdapter and SchemaAdapterFactory - // components used in 
DataFusion's file sources. - // - // The test specifically checks: - // 1. Creating and attaching a schema adapter factory to a file source - // 2. Creating a schema adapter using the factory - // 3. The schema adapter's ability to map column indices between a table schema and a file schema - // 4. The schema adapter's ability to create a projection that selects only the columns - // from the file schema that are present in the table schema - // - // Schema adapters are used when the schema of data in files doesn't exactly match - // the schema expected by the query engine, allowing for field mapping and data transformation. - - // Create a test schema - let table_schema = Arc::new(Schema::new(vec![ - Field::new("id", DataType::Int32, false), - Field::new("name", DataType::Utf8, true), - ])); - - // Create a file schema - let file_schema = Schema::new(vec![ - Field::new("id", DataType::Int32, false), - Field::new("name", DataType::Utf8, true), - Field::new("extra", DataType::Int64, true), - ]); - - // Create a TestSource - let source = TestSource::new(); - assert!(source.schema_adapter_factory().is_none()); - - // Add a schema adapter factory - let factory = Arc::new(TestSchemaAdapterFactory {}); - let source_with_adapter = source.with_schema_adapter_factory(factory); - assert!(source_with_adapter.schema_adapter_factory().is_some()); - - // Create a schema adapter - let adapter_factory = source_with_adapter.schema_adapter_factory().unwrap(); - let adapter = - adapter_factory.create(Arc::clone(&table_schema), Arc::clone(&table_schema)); - - // Test mapping column index - assert_eq!(adapter.map_column_index(0, &file_schema), Some(0)); - assert_eq!(adapter.map_column_index(1, &file_schema), Some(1)); - - // Test creating schema mapper - let (_mapper, projection) = adapter.map_schema(&file_schema).unwrap(); - assert_eq!(projection, vec![0, 1]); -} diff --git a/datafusion/core/tests/tpc-ds/30.sql b/datafusion/core/tests/tpc-ds/30.sql index 78f34b807e5b5..80624f49006a9 
100644 --- a/datafusion/core/tests/tpc-ds/30.sql +++ b/datafusion/core/tests/tpc-ds/30.sql @@ -14,7 +14,7 @@ with customer_total_return as ,ca_state) select c_customer_id,c_salutation,c_first_name,c_last_name,c_preferred_cust_flag ,c_birth_day,c_birth_month,c_birth_year,c_birth_country,c_login,c_email_address - ,c_last_review_date_sk,ctr_total_return + ,c_last_review_date,ctr_total_return from customer_total_return ctr1 ,customer_address ,customer @@ -26,7 +26,7 @@ with customer_total_return as and ctr1.ctr_customer_sk = c_customer_sk order by c_customer_id,c_salutation,c_first_name,c_last_name,c_preferred_cust_flag ,c_birth_day,c_birth_month,c_birth_year,c_birth_country,c_login,c_email_address - ,c_last_review_date_sk,ctr_total_return + ,c_last_review_date,ctr_total_return limit 100; diff --git a/datafusion/core/tests/tpcds_planning.rs b/datafusion/core/tests/tpcds_planning.rs index 252d76d0f9d92..3ad74962bc2c0 100644 --- a/datafusion/core/tests/tpcds_planning.rs +++ b/datafusion/core/tests/tpcds_planning.rs @@ -1052,9 +1052,12 @@ async fn regression_test(query_no: u8, create_physical: bool) -> Result<()> { for sql in &sql { let df = ctx.sql(sql).await?; let (state, plan) = df.into_parts(); - let plan = state.optimize(&plan)?; if create_physical { let _ = state.create_physical_plan(&plan).await?; + } else { + // Run the logical optimizer even if we are not creating the physical plan + // to ensure it will properly succeed + let _ = state.optimize(&plan)?; } } diff --git a/datafusion/core/tests/tracing/asserting_tracer.rs b/datafusion/core/tests/tracing/asserting_tracer.rs index 292e066e5f121..700f9f3308466 100644 --- a/datafusion/core/tests/tracing/asserting_tracer.rs +++ b/datafusion/core/tests/tracing/asserting_tracer.rs @@ -21,7 +21,7 @@ use std::ops::Deref; use std::sync::{Arc, LazyLock}; use datafusion_common::{HashMap, HashSet}; -use datafusion_common_runtime::{set_join_set_tracer, JoinSetTracer}; +use datafusion_common_runtime::{JoinSetTracer, 
set_join_set_tracer}; use futures::future::BoxFuture; use tokio::sync::{Mutex, MutexGuard}; diff --git a/datafusion/core/tests/tracing/mod.rs b/datafusion/core/tests/tracing/mod.rs index df8a28c021d1c..0b66a49eea9f4 100644 --- a/datafusion/core/tests/tracing/mod.rs +++ b/datafusion/core/tests/tracing/mod.rs @@ -76,7 +76,13 @@ async fn run_query() { let ctx = SessionContext::new(); // Get the test data directory - let test_data = parquet_test_data(); + let test_data = if cfg!(target_os = "windows") { + // Prefix Windows paths with "/", since they start with :/ but the URI should be + // test:///C:/... (https://datatracker.ietf.org/doc/html/rfc8089#appendix-E.2) + format!("/{}", parquet_test_data()) + } else { + parquet_test_data() + }; // Define a Parquet file format with pruning enabled let file_format = ParquetFormat::default().with_enable_pruning(true); diff --git a/datafusion/core/tests/tracing/traceable_object_store.rs b/datafusion/core/tests/tracing/traceable_object_store.rs index dfcafc3a63da1..71a61dbf8772a 100644 --- a/datafusion/core/tests/tracing/traceable_object_store.rs +++ b/datafusion/core/tests/tracing/traceable_object_store.rs @@ -18,10 +18,11 @@ //! 
Object store implementation used for testing use crate::tracing::asserting_tracer::assert_traceability; +use futures::StreamExt; use futures::stream::BoxStream; use object_store::{ - path::Path, GetOptions, GetResult, ListResult, MultipartUpload, ObjectMeta, - ObjectStore, PutMultipartOpts, PutOptions, PutPayload, PutResult, + CopyOptions, GetOptions, GetResult, ListResult, MultipartUpload, ObjectMeta, + ObjectStore, PutMultipartOptions, PutOptions, PutPayload, PutResult, path::Path, }; use std::fmt::{Debug, Display, Formatter}; use std::sync::Arc; @@ -68,7 +69,7 @@ impl ObjectStore for TraceableObjectStore { async fn put_multipart_opts( &self, location: &Path, - opts: PutMultipartOpts, + opts: PutMultipartOptions, ) -> object_store::Result> { assert_traceability().await; self.inner.put_multipart_opts(location, opts).await @@ -83,14 +84,17 @@ impl ObjectStore for TraceableObjectStore { self.inner.get_opts(location, options).await } - async fn head(&self, location: &Path) -> object_store::Result { - assert_traceability().await; - self.inner.head(location).await - } - - async fn delete(&self, location: &Path) -> object_store::Result<()> { - assert_traceability().await; - self.inner.delete(location).await + fn delete_stream( + &self, + locations: BoxStream<'static, object_store::Result>, + ) -> BoxStream<'static, object_store::Result> { + self.inner + .delete_stream(locations) + .then(|res| async { + futures::executor::block_on(assert_traceability()); + res + }) + .boxed() } fn list( @@ -109,17 +113,13 @@ impl ObjectStore for TraceableObjectStore { self.inner.list_with_delimiter(prefix).await } - async fn copy(&self, from: &Path, to: &Path) -> object_store::Result<()> { - assert_traceability().await; - self.inner.copy(from, to).await - } - - async fn copy_if_not_exists( + async fn copy_opts( &self, from: &Path, to: &Path, + options: CopyOptions, ) -> object_store::Result<()> { assert_traceability().await; - self.inner.copy_if_not_exists(from, to).await + 
self.inner.copy_opts(from, to, options).await } } diff --git a/datafusion/core/tests/user_defined/expr_planner.rs b/datafusion/core/tests/user_defined/expr_planner.rs index 1fc6d14c5b229..c5e5af731359f 100644 --- a/datafusion/core/tests/user_defined/expr_planner.rs +++ b/datafusion/core/tests/user_defined/expr_planner.rs @@ -26,9 +26,9 @@ use datafusion::logical_expr::Operator; use datafusion::prelude::*; use datafusion::sql::sqlparser::ast::BinaryOperator; use datafusion_common::ScalarValue; +use datafusion_expr::BinaryExpr; use datafusion_expr::expr::Alias; use datafusion_expr::planner::{ExprPlanner, PlannerResult, RawBinaryExpr}; -use datafusion_expr::BinaryExpr; #[derive(Debug)] struct MyCustomPlanner; @@ -56,7 +56,7 @@ impl ExprPlanner for MyCustomPlanner { } BinaryOperator::Question => { Ok(PlannerResult::Planned(Expr::Alias(Alias::new( - Expr::Literal(ScalarValue::Boolean(Some(true))), + Expr::Literal(ScalarValue::Boolean(Some(true)), None), None::<&str>, format!("{} ? {}", expr.left, expr.right), )))) @@ -77,25 +77,25 @@ async fn plan_and_collect(sql: &str) -> Result> { #[tokio::test] async fn test_custom_operators_arrow() { let actual = plan_and_collect("select 'foo'->'bar';").await.unwrap(); - insta::assert_snapshot!(batches_to_string(&actual), @r###" + insta::assert_snapshot!(batches_to_string(&actual), @r#" +----------------------------+ | Utf8("foo") || Utf8("bar") | +----------------------------+ | foobar | +----------------------------+ - "###); + "#); } #[tokio::test] async fn test_custom_operators_long_arrow() { let actual = plan_and_collect("select 1->>2;").await.unwrap(); - insta::assert_snapshot!(batches_to_string(&actual), @r###" + insta::assert_snapshot!(batches_to_string(&actual), @r" +---------------------+ | Int64(1) + Int64(2) | +---------------------+ | 3 | +---------------------+ - "###); + "); } #[tokio::test] @@ -103,13 +103,13 @@ async fn test_question_select() { let actual = plan_and_collect("select a ? 
2 from (select 1 as a);") .await .unwrap(); - insta::assert_snapshot!(batches_to_string(&actual), @r###" + insta::assert_snapshot!(batches_to_string(&actual), @r" +--------------+ | a ? Int64(2) | +--------------+ | true | +--------------+ - "###); + "); } #[tokio::test] @@ -117,11 +117,11 @@ async fn test_question_filter() { let actual = plan_and_collect("select a from (select 1 as a) where a ? 2;") .await .unwrap(); - insta::assert_snapshot!(batches_to_string(&actual), @r###" + insta::assert_snapshot!(batches_to_string(&actual), @r" +---+ | a | +---+ | 1 | +---+ - "###); + "); } diff --git a/datafusion/core/tests/user_defined/insert_operation.rs b/datafusion/core/tests/user_defined/insert_operation.rs index 12f700ce572ba..2a2aed82f0af3 100644 --- a/datafusion/core/tests/user_defined/insert_operation.rs +++ b/datafusion/core/tests/user_defined/insert_operation.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use std::{any::Any, sync::Arc}; +use std::{any::Any, str::FromStr, sync::Arc}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use async_trait::async_trait; @@ -24,11 +24,14 @@ use datafusion::{ prelude::{SessionConfig, SessionContext}, }; use datafusion_catalog::{Session, TableProvider}; -use datafusion_expr::{dml::InsertOp, Expr, TableType}; +use datafusion_common::config::Dialect; +use datafusion_common::tree_node::TreeNodeRecursion; +use datafusion_expr::{Expr, TableType, dml::InsertOp}; use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; +use datafusion_physical_plan::execution_plan::SchedulingType; use datafusion_physical_plan::{ - execution_plan::{Boundedness, EmissionType}, DisplayAs, ExecutionPlan, PlanProperties, + execution_plan::{Boundedness, EmissionType}, }; #[tokio::test] @@ -62,7 +65,7 @@ async fn assert_insert_op(ctx: &SessionContext, sql: &str, insert_op: InsertOp) fn session_ctx_with_dialect(dialect: impl Into) -> SessionContext { let mut config = 
SessionConfig::new(); let options = config.options_mut(); - options.sql_parser.dialect = dialect.into(); + options.sql_parser.dialect = Dialect::from_str(&dialect.into()).unwrap(); SessionContext::new_with_config(config) } @@ -120,18 +123,21 @@ impl TableProvider for TestInsertTableProvider { #[derive(Debug)] struct TestInsertExec { op: InsertOp, - plan_properties: PlanProperties, + plan_properties: Arc, } impl TestInsertExec { fn new(op: InsertOp) -> Self { Self { op, - plan_properties: PlanProperties::new( - EquivalenceProperties::new(make_count_schema()), - Partitioning::UnknownPartitioning(1), - EmissionType::Incremental, - Boundedness::Bounded, + plan_properties: Arc::new( + PlanProperties::new( + EquivalenceProperties::new(make_count_schema()), + Partitioning::UnknownPartitioning(1), + EmissionType::Incremental, + Boundedness::Bounded, + ) + .with_scheduling_type(SchedulingType::Cooperative), ), } } @@ -156,7 +162,7 @@ impl ExecutionPlan for TestInsertExec { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.plan_properties } @@ -179,6 +185,22 @@ impl ExecutionPlan for TestInsertExec { ) -> Result { unimplemented!("TestInsertExec is a stub for testing.") } + + fn apply_expressions( + &self, + f: &mut dyn FnMut( + &dyn datafusion_physical_plan::PhysicalExpr, + ) -> Result, + ) -> Result { + // Visit expressions in the output ordering from equivalence properties + let mut tnr = TreeNodeRecursion::Continue; + if let Some(ordering) = self.plan_properties.output_ordering() { + for sort_expr in ordering { + tnr = tnr.visit_sibling(|| f(sort_expr.expr.as_ref()))?; + } + } + Ok(tnr) + } } fn make_count_schema() -> SchemaRef { diff --git a/datafusion/core/tests/user_defined/mod.rs b/datafusion/core/tests/user_defined/mod.rs index 5d84cdb692830..bc9949f5d681c 100644 --- a/datafusion/core/tests/user_defined/mod.rs +++ b/datafusion/core/tests/user_defined/mod.rs @@ -15,6 +15,9 @@ // specific language governing permissions and 
limitations // under the License. +/// Tests for user defined Async Scalar functions +mod user_defined_async_scalar_functions; + /// Tests for user defined Scalar functions mod user_defined_scalar_functions; @@ -33,5 +36,8 @@ mod user_defined_table_functions; /// Tests for Expression Planner mod expr_planner; +/// Tests for Relation Planner extensions +mod relation_planner; + /// Tests for insert operations mod insert_operation; diff --git a/datafusion/core/tests/user_defined/relation_planner.rs b/datafusion/core/tests/user_defined/relation_planner.rs new file mode 100644 index 0000000000000..54af53ad858d4 --- /dev/null +++ b/datafusion/core/tests/user_defined/relation_planner.rs @@ -0,0 +1,531 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Tests for the RelationPlanner extension point + +use std::sync::Arc; + +use arrow::array::{Int64Array, RecordBatch, StringArray}; +use arrow::datatypes::{DataType, Field, Schema}; +use datafusion::catalog::memory::MemTable; +use datafusion::common::test_util::batches_to_string; +use datafusion::prelude::*; +use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::Expr; +use datafusion_expr::logical_plan::builder::LogicalPlanBuilder; +use datafusion_expr::planner::{ + PlannedRelation, RelationPlanner, RelationPlannerContext, RelationPlanning, +}; +use datafusion_sql::sqlparser::ast::TableFactor; +use insta::assert_snapshot; + +// ============================================================================ +// Test Planners - Example Implementations +// ============================================================================ + +// The planners in this section are deliberately minimal, static examples used +// only for tests. In real applications a `RelationPlanner` would typically +// construct richer logical plans tailored to external systems or custom +// semantics rather than hard-coded in-memory tables. +// +// For more realistic examples, see `datafusion-examples/examples/relation_planner/`: +// - `table_sample.rs`: Full TABLESAMPLE implementation (parsing → execution) +// - `pivot_unpivot.rs`: PIVOT/UNPIVOT via SQL rewriting +// - `match_recognize.rs`: MATCH_RECOGNIZE logical planning + +/// Helper to build simple static values-backed virtual tables used by the +/// example planners below. +fn plan_static_values_table( + relation: TableFactor, + table_name: &str, + column_name: &str, + values: Vec, +) -> Result { + match relation { + TableFactor::Table { name, alias, .. } + if name.to_string().eq_ignore_ascii_case(table_name) => + { + let rows = values + .into_iter() + .map(|v| vec![Expr::Literal(v, None)]) + .collect::>(); + + let plan = LogicalPlanBuilder::values(rows)? + .project(vec![col("column1").alias(column_name)])? 
+ .build()?; + + Ok(RelationPlanning::Planned(Box::new(PlannedRelation::new( + plan, alias, + )))) + } + other => Ok(RelationPlanning::Original(Box::new(other))), + } +} + +/// Example planner that provides a virtual `numbers` table with values +/// 1, 2, 3. +#[derive(Debug)] +struct NumbersPlanner; + +impl RelationPlanner for NumbersPlanner { + fn plan_relation( + &self, + relation: TableFactor, + _context: &mut dyn RelationPlannerContext, + ) -> Result { + plan_static_values_table( + relation, + "numbers", + "number", + vec![ + ScalarValue::Int64(Some(1)), + ScalarValue::Int64(Some(2)), + ScalarValue::Int64(Some(3)), + ], + ) + } +} + +/// Example planner that provides a virtual `colors` table with three string +/// values: `red`, `green`, `blue`. +#[derive(Debug)] +struct ColorsPlanner; + +impl RelationPlanner for ColorsPlanner { + fn plan_relation( + &self, + relation: TableFactor, + _context: &mut dyn RelationPlannerContext, + ) -> Result { + plan_static_values_table( + relation, + "colors", + "color", + vec![ + ScalarValue::Utf8(Some("red".into())), + ScalarValue::Utf8(Some("green".into())), + ScalarValue::Utf8(Some("blue".into())), + ], + ) + } +} + +/// Alternative implementation of `numbers` (returns 100, 200) used to +/// demonstrate planner precedence (last registered planner wins). +#[derive(Debug)] +struct AlternativeNumbersPlanner; + +impl RelationPlanner for AlternativeNumbersPlanner { + fn plan_relation( + &self, + relation: TableFactor, + _context: &mut dyn RelationPlannerContext, + ) -> Result { + plan_static_values_table( + relation, + "numbers", + "number", + vec![ScalarValue::Int64(Some(100)), ScalarValue::Int64(Some(200))], + ) + } +} + +/// Example planner that intercepts nested joins and samples both sides (limit 2) +/// before joining, demonstrating recursive planning with `context.plan()`. 
+#[derive(Debug)] +struct SamplingJoinPlanner; + +impl RelationPlanner for SamplingJoinPlanner { + fn plan_relation( + &self, + relation: TableFactor, + context: &mut dyn RelationPlannerContext, + ) -> Result { + match relation { + TableFactor::NestedJoin { + table_with_joins, + alias, + .. + } if table_with_joins.joins.len() == 1 => { + // Use context.plan() to recursively plan both sides + // This ensures other planners (like NumbersPlanner) can handle them + let left = context.plan(table_with_joins.relation.clone())?; + let right = context.plan(table_with_joins.joins[0].relation.clone())?; + + // Sample each table to 2 rows + let left_sampled = + LogicalPlanBuilder::from(left).limit(0, Some(2))?.build()?; + + let right_sampled = + LogicalPlanBuilder::from(right).limit(0, Some(2))?.build()?; + + // Cross join: 2 rows × 2 rows = 4 rows (instead of 3×3=9 without sampling) + let plan = LogicalPlanBuilder::from(left_sampled) + .cross_join(right_sampled)? + .build()?; + + Ok(RelationPlanning::Planned(Box::new(PlannedRelation::new( + plan, alias, + )))) + } + other => Ok(RelationPlanning::Original(Box::new(other))), + } + } +} + +/// Example planner that never handles any relation and always delegates by +/// returning `RelationPlanning::Original`. +#[derive(Debug)] +struct PassThroughPlanner; + +impl RelationPlanner for PassThroughPlanner { + fn plan_relation( + &self, + relation: TableFactor, + _context: &mut dyn RelationPlannerContext, + ) -> Result { + // Never handles anything - always delegates + Ok(RelationPlanning::Original(Box::new(relation))) + } +} + +/// Example planner that shows how planners can block specific constructs and +/// surface custom error messages by rejecting `UNNEST` relations (here framed +/// as a mock premium feature check). 
+#[derive(Debug)] +struct PremiumFeaturePlanner; + +impl RelationPlanner for PremiumFeaturePlanner { + fn plan_relation( + &self, + relation: TableFactor, + _context: &mut dyn RelationPlannerContext, + ) -> Result { + match relation { + TableFactor::UNNEST { .. } => Err(datafusion_common::DataFusionError::Plan( + "UNNEST is a premium feature! Please upgrade to DataFusion Pro™ \ + to unlock advanced array operations." + .to_string(), + )), + other => Ok(RelationPlanning::Original(Box::new(other))), + } + } +} + +// ============================================================================ +// Test Helpers - SQL Execution +// ============================================================================ + +/// Execute SQL and return results with better error messages. +async fn execute_sql(ctx: &SessionContext, sql: &str) -> Result> { + let df = ctx.sql(sql).await?; + df.collect().await +} + +/// Execute SQL and convert to string format for snapshot comparison. +async fn execute_sql_to_string(ctx: &SessionContext, sql: &str) -> String { + let batches = execute_sql(ctx, sql) + .await + .expect("SQL execution should succeed"); + batches_to_string(&batches) +} + +// ============================================================================ +// Test Helpers - Context Builders +// ============================================================================ + +/// Create a SessionContext with a catalog table containing Int64 and Utf8 columns. +/// +/// Creates a table with the specified name and sample data for fallback/integration tests. 
+fn create_context_with_catalog_table( + table_name: &str, + id_values: Vec, + name_values: Vec<&str>, +) -> SessionContext { + let ctx = SessionContext::new(); + + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("name", DataType::Utf8, false), + ])); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int64Array::from(id_values)), + Arc::new(StringArray::from(name_values)), + ], + ) + .unwrap(); + + let table = MemTable::try_new(schema, vec![vec![batch]]).unwrap(); + ctx.register_table(table_name, Arc::new(table)).unwrap(); + + ctx +} + +/// Create a SessionContext with a simple single-column Int64 table. +/// +/// Useful for basic tests that need a real catalog table. +fn create_context_with_simple_table( + table_name: &str, + values: Vec, +) -> SessionContext { + let ctx = SessionContext::new(); + + let schema = Arc::new(Schema::new(vec![Field::new( + "value", + DataType::Int64, + true, + )])); + + let batch = + RecordBatch::try_new(schema.clone(), vec![Arc::new(Int64Array::from(values))]) + .unwrap(); + + let table = MemTable::try_new(schema, vec![vec![batch]]).unwrap(); + ctx.register_table(table_name, Arc::new(table)).unwrap(); + + ctx +} + +// ============================================================================ +// TESTS: Ordered from Basic to Complex +// ============================================================================ + +/// Comprehensive test suite for RelationPlanner extension point. +/// Tests are ordered from simplest smoke test to most complex scenarios. +#[cfg(test)] +mod tests { + use super::*; + + /// Small extension trait to make test setup read fluently. + trait TestSessionExt { + fn with_planner(self, planner: P) -> Self; + } + + impl TestSessionExt for SessionContext { + fn with_planner(self, planner: P) -> Self { + self.register_relation_planner(Arc::new(planner)).unwrap(); + self + } + } + + /// Session context with only the `NumbersPlanner` registered. 
+ fn ctx_with_numbers() -> SessionContext { + SessionContext::new().with_planner(NumbersPlanner) + } + + /// Session context with virtual tables (`numbers`, `colors`) and the + /// `SamplingJoinPlanner` registered for nested joins. + fn ctx_with_virtual_tables_and_sampling() -> SessionContext { + SessionContext::new() + .with_planner(NumbersPlanner) + .with_planner(ColorsPlanner) + .with_planner(SamplingJoinPlanner) + } + + // Basic smoke test: virtual table can be queried like a regular table. + #[tokio::test] + async fn virtual_table_basic_select() { + let ctx = ctx_with_numbers(); + + let result = execute_sql_to_string(&ctx, "SELECT * FROM numbers").await; + + assert_snapshot!(result, @r" + +--------+ + | number | + +--------+ + | 1 | + | 2 | + | 3 | + +--------+ + "); + } + + // Virtual table supports standard SQL operations (projection, filter, aggregation). + #[tokio::test] + async fn virtual_table_filters_and_aggregation() { + let ctx = ctx_with_numbers(); + + let filtered = execute_sql_to_string( + &ctx, + "SELECT number * 10 AS scaled FROM numbers WHERE number > 1", + ) + .await; + + assert_snapshot!(filtered, @r" + +--------+ + | scaled | + +--------+ + | 20 | + | 30 | + +--------+ + "); + + let aggregated = execute_sql_to_string( + &ctx, + "SELECT COUNT(*) as count, SUM(number) as total, AVG(number) as average \ + FROM numbers", + ) + .await; + + assert_snapshot!(aggregated, @r" + +-------+-------+---------+ + | count | total | average | + +-------+-------+---------+ + | 3 | 6 | 2.0 | + +-------+-------+---------+ + "); + } + + // Multiple planners can coexist and each handles its own virtual table. 
+ #[tokio::test] + async fn multiple_planners_virtual_tables() { + let ctx = SessionContext::new() + .with_planner(NumbersPlanner) + .with_planner(ColorsPlanner); + + let result1 = execute_sql_to_string(&ctx, "SELECT * FROM numbers").await; + assert_snapshot!(result1, @r" + +--------+ + | number | + +--------+ + | 1 | + | 2 | + | 3 | + +--------+ + "); + + let result2 = execute_sql_to_string(&ctx, "SELECT * FROM colors").await; + assert_snapshot!(result2, @r" + +-------+ + | color | + +-------+ + | red | + | green | + | blue | + +-------+ + "); + } + + // Last registered planner for the same table name takes precedence (LIFO). + #[tokio::test] + async fn lifo_precedence_last_planner_wins() { + let ctx = SessionContext::new() + .with_planner(AlternativeNumbersPlanner) + .with_planner(NumbersPlanner); + + let result = execute_sql_to_string(&ctx, "SELECT * FROM numbers").await; + + // CustomValuesPlanner registered last, should win (returns 1,2,3 not 100,200) + assert_snapshot!(result, @r" + +--------+ + | number | + +--------+ + | 1 | + | 2 | + | 3 | + +--------+ + "); + } + + // Pass-through planner delegates to the catalog without changing behavior. + #[tokio::test] + async fn delegation_pass_through_to_catalog() { + let ctx = create_context_with_simple_table("real_table", vec![42]) + .with_planner(PassThroughPlanner); + + let result = execute_sql_to_string(&ctx, "SELECT * FROM real_table").await; + + assert_snapshot!(result, @r" + +-------+ + | value | + +-------+ + | 42 | + +-------+ + "); + } + + // Catalog is used when no planner claims the relation. 
+ #[tokio::test] + async fn catalog_fallback_when_no_planner() { + let ctx = + create_context_with_catalog_table("users", vec![1, 2], vec!["Alice", "Bob"]) + .with_planner(NumbersPlanner); + + let result = execute_sql_to_string(&ctx, "SELECT * FROM users ORDER BY id").await; + + assert_snapshot!(result, @r" + +----+-------+ + | id | name | + +----+-------+ + | 1 | Alice | + | 2 | Bob | + +----+-------+ + "); + } + + // Planners can block specific constructs and surface custom error messages. + #[tokio::test] + async fn error_handling_premium_feature_blocking() { + // Verify UNNEST works without planner + let ctx_without_planner = SessionContext::new(); + let result = + execute_sql(&ctx_without_planner, "SELECT * FROM UNNEST(ARRAY[1, 2, 3])") + .await + .expect("UNNEST should work by default"); + assert_eq!(result.len(), 1); + + // Same query with blocking planner registered + let ctx = SessionContext::new().with_planner(PremiumFeaturePlanner); + + // Verify UNNEST is now rejected + let error = execute_sql(&ctx, "SELECT * FROM UNNEST(ARRAY[1, 2, 3])") + .await + .expect_err("UNNEST should be rejected"); + + let error_msg = error.to_string(); + assert!( + error_msg.contains("premium feature") && error_msg.contains("DataFusion Pro"), + "Expected custom rejection message, got: {error_msg}" + ); + } + + // SamplingJoinPlanner recursively calls `context.plan()` on both sides of a + // nested join before sampling, exercising recursive relation planning. 
+ #[tokio::test] + async fn recursive_planning_sampling_join() { + let ctx = ctx_with_virtual_tables_and_sampling(); + + let result = + execute_sql_to_string(&ctx, "SELECT * FROM (numbers JOIN colors ON true)") + .await; + + // SamplingJoinPlanner limits each side to 2 rows: 2×2=4 (not 3×3=9) + assert_snapshot!(result, @r" + +--------+-------+ + | number | color | + +--------+-------+ + | 1 | red | + | 1 | green | + | 2 | red | + | 2 | green | + +--------+-------+ + "); + } +} diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index ae517795ab955..e7bd2241398ad 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -20,16 +20,16 @@ use std::any::Any; use std::collections::HashMap; -use std::hash::{DefaultHasher, Hash, Hasher}; +use std::hash::{Hash, Hasher}; use std::mem::{size_of, size_of_val}; use std::sync::{ - atomic::{AtomicBool, Ordering}, Arc, + atomic::{AtomicBool, Ordering}, }; use arrow::array::{ - record_batch, types::UInt64Type, Array, AsArray, Int32Array, PrimitiveArray, - StringArray, StructArray, UInt64Array, + Array, AsArray, Int32Array, PrimitiveArray, StringArray, StructArray, UInt64Array, + record_batch, types::UInt64Type, }; use arrow::datatypes::{Fields, Schema}; use arrow_schema::FieldRef; @@ -53,10 +53,11 @@ use datafusion::{ }; use datafusion_common::{assert_contains, exec_datafusion_err}; use datafusion_common::{cast::as_primitive_array, exec_err}; + use datafusion_expr::expr::WindowFunction; use datafusion_expr::{ - col, create_udaf, function::AccumulatorArgs, AggregateUDFImpl, Expr, - GroupsAccumulator, LogicalPlanBuilder, SimpleAggregateUDF, WindowFunctionDefinition, + AggregateUDFImpl, Expr, GroupsAccumulator, LogicalPlanBuilder, SimpleAggregateUDF, + WindowFunctionDefinition, col, create_udaf, function::AccumulatorArgs, }; use 
datafusion_functions_aggregate::average::AvgAccumulator; @@ -68,7 +69,7 @@ async fn test_setup() { let actual = execute(&ctx, sql).await.unwrap(); - insta::assert_snapshot!(batches_to_string(&actual), @r###" + insta::assert_snapshot!(batches_to_string(&actual), @r" +-------+----------------------------+ | value | time | +-------+----------------------------+ @@ -78,7 +79,7 @@ async fn test_setup() { | 5.0 | 1970-01-01T00:00:00.000005 | | 5.0 | 1970-01-01T00:00:00.000005 | +-------+----------------------------+ - "###); + "); } /// Basic user defined aggregate @@ -90,13 +91,13 @@ async fn test_udaf() { let actual = execute(&ctx, sql).await.unwrap(); - insta::assert_snapshot!(batches_to_string(&actual), @r###" + insta::assert_snapshot!(batches_to_string(&actual), @r" +----------------------------+ | time_sum(t.time) | +----------------------------+ | 1970-01-01T00:00:00.000019 | +----------------------------+ - "###); + "); // normal aggregates call update_batch assert!(test_state.update_batch()); @@ -111,7 +112,7 @@ async fn test_udaf_as_window() { let actual = execute(&ctx, sql).await.unwrap(); - insta::assert_snapshot!(batches_to_string(&actual), @r###" + insta::assert_snapshot!(batches_to_string(&actual), @r" +----------------------------+ | time_sum | +----------------------------+ @@ -121,7 +122,7 @@ async fn test_udaf_as_window() { | 1970-01-01T00:00:00.000019 | | 1970-01-01T00:00:00.000019 | +----------------------------+ - "###); + "); // aggregate over the entire window function call update_batch assert!(test_state.update_batch()); @@ -136,7 +137,7 @@ async fn test_udaf_as_window_with_frame() { let actual = execute(&ctx, sql).await.unwrap(); - insta::assert_snapshot!(batches_to_string(&actual), @r###" + insta::assert_snapshot!(batches_to_string(&actual), @r" +----------------------------+ | time_sum | +----------------------------+ @@ -146,7 +147,7 @@ async fn test_udaf_as_window_with_frame() { | 1970-01-01T00:00:00.000014 | | 1970-01-01T00:00:00.000010 | 
+----------------------------+ - "###); + "); // user defined aggregates with window frame should be calling retract batch assert!(test_state.update_batch()); @@ -163,7 +164,10 @@ async fn test_udaf_as_window_with_frame_without_retract_batch() { let sql = "SELECT time_sum(time) OVER(ORDER BY time ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as time_sum from t"; // Note if this query ever does start working let err = execute(&ctx, sql).await.unwrap_err(); - assert_contains!(err.to_string(), "This feature is not implemented: Aggregate can not be used as a sliding accumulator because `retract_batch` is not implemented: time_sum(t.time) ORDER BY [t.time ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING"); + assert_contains!( + err.to_string(), + "This feature is not implemented: Aggregate can not be used as a sliding accumulator because `retract_batch` is not implemented: time_sum(t.time) ORDER BY [t.time ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING" + ); } /// Basic query for with a udaf returning a structure @@ -174,13 +178,13 @@ async fn test_udaf_returning_struct() { let actual = execute(&ctx, sql).await.unwrap(); - insta::assert_snapshot!(batches_to_string(&actual), @r###" + insta::assert_snapshot!(batches_to_string(&actual), @r" +------------------------------------------------+ | first(t.value,t.time) | +------------------------------------------------+ | {value: 2.0, time: 1970-01-01T00:00:00.000002} | +------------------------------------------------+ - "###); + "); } /// Demonstrate extracting the fields from a structure using a subquery @@ -191,13 +195,13 @@ async fn test_udaf_returning_struct_subquery() { let actual = execute(&ctx, sql).await.unwrap(); - insta::assert_snapshot!(batches_to_string(&actual), @r###" + insta::assert_snapshot!(batches_to_string(&actual), @r" +-----------------+----------------------------+ | sq.first[value] | sq.first[time] | +-----------------+----------------------------+ | 2.0 | 1970-01-01T00:00:00.000002 | 
+-----------------+----------------------------+ - "###); + "); } #[tokio::test] @@ -211,13 +215,13 @@ async fn test_udaf_shadows_builtin_fn() { // compute with builtin `sum` aggregator let actual = execute(&ctx, sql).await.unwrap(); - insta::assert_snapshot!(batches_to_string(&actual), @r###" + insta::assert_snapshot!(batches_to_string(&actual), @r#" +---------------------------------------+ | sum(arrow_cast(t.time,Utf8("Int64"))) | +---------------------------------------+ | 19000 | +---------------------------------------+ - "###); + "#); // Register `TimeSum` with name `sum`. This will shadow the builtin one TimeSum::register(&mut ctx, test_state.clone(), "sum"); @@ -225,13 +229,13 @@ async fn test_udaf_shadows_builtin_fn() { let actual = execute(&ctx, sql).await.unwrap(); - insta::assert_snapshot!(batches_to_string(&actual), @r###" + insta::assert_snapshot!(batches_to_string(&actual), @r" +----------------------------+ | sum(t.time) | +----------------------------+ | 1970-01-01T00:00:00.000019 | +----------------------------+ - "###); + "); } async fn execute(ctx: &SessionContext, sql: &str) -> Result> { @@ -271,13 +275,13 @@ async fn simple_udaf() -> Result<()> { let result = ctx.sql("SELECT MY_AVG(a) FROM t").await?.collect().await?; - insta::assert_snapshot!(batches_to_string(&result), @r###" + insta::assert_snapshot!(batches_to_string(&result), @r" +-------------+ | my_avg(t.a) | +-------------+ | 3.0 | +-------------+ - "###); + "); Ok(()) } @@ -297,10 +301,12 @@ async fn deregister_udaf() -> Result<()> { ctx.register_udaf(my_avg); assert!(ctx.state().aggregate_functions().contains_key("my_avg")); + assert!(datafusion_execution::FunctionRegistry::udafs(&ctx).contains("my_avg")); ctx.deregister_udaf("my_avg"); assert!(!ctx.state().aggregate_functions().contains_key("my_avg")); + assert!(!datafusion_execution::FunctionRegistry::udafs(&ctx).contains("my_avg")); Ok(()) } @@ -326,9 +332,10 @@ async fn case_sensitive_identifiers_user_defined_aggregates() -> 
Result<()> { // doesn't work as it was registered as non lowercase let err = ctx.sql("SELECT MY_AVG(i) FROM t").await.unwrap_err(); - assert!(err - .to_string() - .contains("Error during planning: Invalid function \'my_avg\'")); + assert!( + err.to_string() + .contains("Error during planning: Invalid function \'my_avg\'") + ); // Can call it if you put quotes let result = ctx @@ -337,13 +344,13 @@ async fn case_sensitive_identifiers_user_defined_aggregates() -> Result<()> { .collect() .await?; - insta::assert_snapshot!(batches_to_string(&result), @r###" + insta::assert_snapshot!(batches_to_string(&result), @r" +-------------+ | MY_AVG(t.i) | +-------------+ | 1.0 | +-------------+ - "###); + "); Ok(()) } @@ -369,23 +376,23 @@ async fn test_user_defined_functions_with_alias() -> Result<()> { let result = plan_and_collect(&ctx, "SELECT dummy(i) FROM t").await?; - insta::assert_snapshot!(batches_to_string(&result), @r###" + insta::assert_snapshot!(batches_to_string(&result), @r" +------------+ | dummy(t.i) | +------------+ | 1.0 | +------------+ - "###); + "); let alias_result = plan_and_collect(&ctx, "SELECT dummy_alias(i) FROM t").await?; - insta::assert_snapshot!(batches_to_string(&alias_result), @r###" - +------------+ - | dummy(t.i) | - +------------+ - | 1.0 | - +------------+ - "###); + insta::assert_snapshot!(batches_to_string(&alias_result), @r" + +------------------+ + | dummy_alias(t.i) | + +------------------+ + | 1.0 | + +------------------+ + "); Ok(()) } @@ -446,13 +453,13 @@ async fn test_parameterized_aggregate_udf() -> Result<()> { let actual = DataFrame::new(ctx.state(), plan).collect().await?; - insta::assert_snapshot!(batches_to_string(&actual), @r###" + insta::assert_snapshot!(batches_to_string(&actual), @r" +------+---+---+ | text | a | b | +------+---+---+ | foo | 1 | 2 | +------+---+---+ - "###); + "); ctx.deregister_table("t")?; Ok(()) @@ -566,6 +573,7 @@ impl TimeSum { Self { sum: 0, test_state } } + #[expect(clippy::needless_pass_by_value)] 
fn register(ctx: &mut SessionContext, test_state: Arc, name: &str) { let timestamp_type = DataType::Timestamp(TimeUnit::Nanosecond, None); let input_type = vec![timestamp_type.clone()]; @@ -757,11 +765,11 @@ impl Accumulator for FirstSelector { // Update the actual values for (value, time) in v.iter().zip(t.iter()) { - if let (Some(time), Some(value)) = (time, value) { - if time < self.time { - self.value = value; - self.time = time; - } + if let (Some(time), Some(value)) = (time, value) + && time < self.time + { + self.value = value; + self.time = time; } } @@ -778,7 +786,7 @@ impl Accumulator for FirstSelector { } } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] struct TestGroupsAccumulator { signature: Signature, result: u64, @@ -816,21 +824,6 @@ impl AggregateUDFImpl for TestGroupsAccumulator { ) -> Result> { Ok(Box::new(self.clone())) } - - fn equals(&self, other: &dyn AggregateUDFImpl) -> bool { - if let Some(other) = other.as_any().downcast_ref::() { - self.result == other.result && self.signature == other.signature - } else { - false - } - } - - fn hash_value(&self) -> u64 { - let hasher = &mut DefaultHasher::new(); - self.signature.hash(hasher); - self.result.hash(hasher); - hasher.finish() - } } impl Accumulator for TestGroupsAccumulator { @@ -902,6 +895,32 @@ struct MetadataBasedAggregateUdf { metadata: HashMap, } +impl PartialEq for MetadataBasedAggregateUdf { + fn eq(&self, other: &Self) -> bool { + let Self { + name, + signature, + metadata, + } = self; + name == &other.name + && signature == &other.signature + && metadata == &other.metadata + } +} +impl Eq for MetadataBasedAggregateUdf {} +impl Hash for MetadataBasedAggregateUdf { + fn hash(&self, state: &mut H) { + let Self { + name, + signature, + metadata: _, // unhashable + } = self; + std::any::type_name::().hash(state); + name.hash(state); + signature.hash(state); + } +} + impl MetadataBasedAggregateUdf { fn new(metadata: HashMap) -> Self { // The name we return must be 
unique. Otherwise we will not call distinct @@ -940,13 +959,7 @@ impl AggregateUDFImpl for MetadataBasedAggregateUdf { } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { - let input_expr = acc_args - .exprs - .first() - .ok_or(exec_datafusion_err!("Expected one argument"))?; - let input_field = input_expr.return_field(acc_args.schema)?; - - let double_output = input_field + let double_output = acc_args.expr_fields[0] .metadata() .get("modify_values") .map(|v| v == "double_output") @@ -1106,22 +1119,22 @@ async fn test_metadata_based_aggregate_as_window() -> Result<()> { ))); let df = df.select(vec![ - Expr::WindowFunction(WindowFunction::new( + Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(Arc::clone(&no_output_meta_udf)), vec![col("no_metadata")], )) .alias("meta_no_in_no_out"), - Expr::WindowFunction(WindowFunction::new( + Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(no_output_meta_udf), vec![col("with_metadata")], )) .alias("meta_with_in_no_out"), - Expr::WindowFunction(WindowFunction::new( + Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(Arc::clone(&with_output_meta_udf)), vec![col("no_metadata")], )) .alias("meta_no_in_with_out"), - Expr::WindowFunction(WindowFunction::new( + Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(with_output_meta_udf), vec![col("with_metadata")], )) diff --git a/datafusion/core/tests/user_defined/user_defined_async_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_async_scalar_functions.rs new file mode 100644 index 0000000000000..31af4445ace08 --- /dev/null +++ b/datafusion/core/tests/user_defined/user_defined_async_scalar_functions.rs @@ -0,0 +1,171 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::array::{Int32Array, RecordBatch, StringArray}; +use arrow::datatypes::{DataType, Field, Schema}; +use async_trait::async_trait; +use datafusion::prelude::*; +use datafusion_common::test_util::format_batches; +use datafusion_common::{Result, assert_batches_eq}; +use datafusion_expr::async_udf::{AsyncScalarUDF, AsyncScalarUDFImpl}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; + +fn register_table_and_udf() -> Result { + let num_rows = 3; + let batch_size = 2; + + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("prompt", DataType::Utf8, false), + ])); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from((0..num_rows).collect::>())), + Arc::new(StringArray::from( + (0..num_rows) + .map(|i| format!("prompt{i}")) + .collect::>(), + )), + ], + )?; + + let ctx = SessionContext::new(); + ctx.register_batch("test_table", batch)?; + + ctx.register_udf( + AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl::new(batch_size))) + .into_scalar_udf(), + ); + + Ok(ctx) +} + +// This test checks the case where batch_size doesn't evenly divide +// the number of rows. 
+#[tokio::test] +async fn test_async_udf_with_non_modular_batch_size() -> Result<()> { + let ctx = register_table_and_udf()?; + + let df = ctx + .sql("SELECT id, test_async_udf(prompt) as result FROM test_table") + .await?; + + let result = df.collect().await?; + + assert_batches_eq!( + &[ + "+----+---------+", + "| id | result |", + "+----+---------+", + "| 0 | prompt0 |", + "| 1 | prompt1 |", + "| 2 | prompt2 |", + "+----+---------+" + ], + &result + ); + + Ok(()) +} + +// This test checks if metrics are printed for `AsyncFuncExec` +#[tokio::test] +async fn test_async_udf_metrics() -> Result<()> { + let ctx = register_table_and_udf()?; + + let df = ctx + .sql( + "EXPLAIN ANALYZE SELECT id, test_async_udf(prompt) as result FROM test_table", + ) + .await?; + + let result = df.collect().await?; + + let explain_analyze_str = format_batches(&result)?.to_string(); + let async_func_exec_without_metrics = + explain_analyze_str.split("\n").any(|metric_line| { + metric_line.contains("AsyncFuncExec") + && !metric_line.contains("output_rows=3") + }); + + assert!(!async_func_exec_without_metrics); + + Ok(()) +} + +#[derive(Debug, PartialEq, Eq, Hash, Clone)] +struct TestAsyncUDFImpl { + batch_size: usize, + signature: Signature, +} + +impl TestAsyncUDFImpl { + fn new(batch_size: usize) -> Self { + Self { + batch_size, + signature: Signature::exact(vec![DataType::Utf8], Volatility::Volatile), + } + } +} + +impl ScalarUDFImpl for TestAsyncUDFImpl { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "test_async_udf" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Utf8) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + panic!("Call invoke_async_with_args instead") + } +} + +#[async_trait] +impl AsyncScalarUDFImpl for TestAsyncUDFImpl { + fn ideal_batch_size(&self) -> Option { + Some(self.batch_size) + } + async fn 
invoke_async_with_args( + &self, + args: ScalarFunctionArgs, + ) -> Result { + let arg1 = &args.args[0]; + let results = call_external_service(arg1.clone()).await?; + Ok(results) + } +} + +/// Simulates calling an async external service +async fn call_external_service(arg1: ColumnarValue) -> Result { + Ok(arg1) +} diff --git a/datafusion/core/tests/user_defined/user_defined_plan.rs b/datafusion/core/tests/user_defined/user_defined_plan.rs index b68ef6aca0931..6e4ed69e508d3 100644 --- a/datafusion/core/tests/user_defined/user_defined_plan.rs +++ b/datafusion/core/tests/user_defined/user_defined_plan.rs @@ -56,7 +56,6 @@ //! //! The same answer can be produced by simply keeping track of the top //! N elements, reducing the total amount of required buffer memory. -//! use std::fmt::Debug; use std::hash::Hash; @@ -71,7 +70,7 @@ use arrow::{ use datafusion::execution::session_state::SessionStateBuilder; use datafusion::{ common::cast::as_int64_array, - common::{arrow_datafusion_err, internal_err, DFSchemaRef}, + common::{DFSchemaRef, arrow_datafusion_err}, error::{DataFusionError, Result}, execution::{ context::{QueryPlanner, SessionState, TaskContext}, @@ -85,17 +84,19 @@ use datafusion::{ physical_expr::EquivalenceProperties, physical_plan::{ DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, - PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics, + PlanProperties, RecordBatchStream, SendableRecordBatchStream, }, physical_planner::{DefaultPhysicalPlanner, ExtensionPlanner, PhysicalPlanner}, prelude::{SessionConfig, SessionContext}, }; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::ScalarValue; +use datafusion_common::tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, +}; +use datafusion_common::{ScalarValue, assert_eq_or_internal_err, assert_or_internal_err}; use datafusion_expr::{FetchType, 
InvariantLevel, Projection, SortExpr}; -use datafusion_optimizer::optimizer::ApplyOrder; use datafusion_optimizer::AnalyzerRule; +use datafusion_optimizer::optimizer::ApplyOrder; use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; use async_trait::async_trait; @@ -162,7 +163,7 @@ async fn run_and_compare_query(ctx: SessionContext, description: &str) -> Result insta::with_settings!({ description => description, }, { - insta::assert_snapshot!(actual, @r###" + insta::assert_snapshot!(actual, @r" +-------------+---------+ | customer_id | revenue | +-------------+---------+ @@ -170,7 +171,7 @@ async fn run_and_compare_query(ctx: SessionContext, description: &str) -> Result | jorge | 200 | | andy | 150 | +-------------+---------+ - "###); + "); }); } @@ -189,13 +190,13 @@ async fn run_and_compare_query_with_analyzer_rule( insta::with_settings!({ description => description, }, { - insta::assert_snapshot!(actual, @r###" + insta::assert_snapshot!(actual, @r" +------------+--------------------------+ | UInt64(42) | arrow_typeof(UInt64(42)) | +------------+--------------------------+ | 42 | UInt64 | +------------+--------------------------+ - "###); + "); }); Ok(()) @@ -213,7 +214,7 @@ async fn run_and_compare_query_with_auto_schemas( insta::with_settings!({ description => description, }, { - insta::assert_snapshot!(actual, @r###" + insta::assert_snapshot!(actual, @r" +----------+----------+ | column_1 | column_2 | +----------+----------+ @@ -221,7 +222,7 @@ async fn run_and_compare_query_with_auto_schemas( | jorge | 200 | | andy | 150 | +----------+----------+ - "###); + "); }); Ok(()) @@ -434,21 +435,21 @@ impl OptimizerRule for OptimizerMakeExtensionNodeInvalid { plan: LogicalPlan, _config: &dyn OptimizerConfig, ) -> Result, DataFusionError> { - if let LogicalPlan::Extension(Extension { node }) = &plan { - if let Some(prev) = node.as_any().downcast_ref::() { - return Ok(Transformed::yes(LogicalPlan::Extension(Extension { - node: Arc::new(TopKPlanNode 
{ - k: prev.k, - input: prev.input.clone(), - expr: prev.expr.clone(), - // In a real use case, this rewriter could have change the number of inputs, etc - invariant_mock: Some(InvariantMock { - should_fail_invariant: true, - kind: InvariantLevel::Always, - }), + if let LogicalPlan::Extension(Extension { node }) = &plan + && let Some(prev) = node.as_any().downcast_ref::() + { + return Ok(Transformed::yes(LogicalPlan::Extension(Extension { + node: Arc::new(TopKPlanNode { + k: prev.k, + input: prev.input.clone(), + expr: prev.expr.clone(), + // In a real use case, this rewriter could have change the number of inputs, etc + invariant_mock: Some(InvariantMock { + should_fail_invariant: true, + kind: InvariantLevel::Always, }), - }))); - } + }), + }))); }; Ok(Transformed::no(plan)) @@ -516,23 +517,18 @@ impl OptimizerRule for TopKOptimizerRule { return Ok(Transformed::no(plan)); }; - if let LogicalPlan::Sort(Sort { - ref expr, - ref input, - .. - }) = limit.input.as_ref() + if let LogicalPlan::Sort(Sort { expr, input, .. 
}) = limit.input.as_ref() + && expr.len() == 1 { - if expr.len() == 1 { - // we found a sort with a single sort expr, replace with a a TopK - return Ok(Transformed::yes(LogicalPlan::Extension(Extension { - node: Arc::new(TopKPlanNode { - k: fetch, - input: input.as_ref().clone(), - expr: expr[0].clone(), - invariant_mock: self.invariant_mock.clone(), - }), - }))); - } + // we found a sort with a single sort expr, replace with a a TopK + return Ok(Transformed::yes(LogicalPlan::Extension(Extension { + node: Arc::new(TopKPlanNode { + k: fetch, + input: input.as_ref().clone(), + expr: expr[0].clone(), + invariant_mock: self.invariant_mock.clone(), + }), + }))); } Ok(Transformed::no(plan)) @@ -580,15 +576,16 @@ impl UserDefinedLogicalNodeCore for TopKPlanNode { self.input.schema() } - fn check_invariants(&self, check: InvariantLevel, _plan: &LogicalPlan) -> Result<()> { + fn check_invariants(&self, check: InvariantLevel) -> Result<()> { if let Some(InvariantMock { should_fail_invariant, kind, }) = self.invariant_mock.clone() { - if should_fail_invariant && check == kind { - return internal_err!("node fails check, such as improper inputs"); - } + assert_or_internal_err!( + !(should_fail_invariant && check == kind), + "node fails check, such as improper inputs" + ); } Ok(()) } @@ -658,13 +655,17 @@ struct TopKExec { input: Arc, /// The maximum number of values k: usize, - cache: PlanProperties, + cache: Arc, } impl TopKExec { fn new(input: Arc, k: usize) -> Self { let cache = Self::compute_properties(input.schema()); - Self { input, k, cache } + Self { + input, + k, + cache: Arc::new(cache), + } } /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. 
@@ -709,7 +710,7 @@ impl ExecutionPlan for TopKExec { self } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -734,9 +735,11 @@ impl ExecutionPlan for TopKExec { partition: usize, context: Arc, ) -> Result { - if 0 != partition { - return internal_err!("TopKExec invalid partition {partition}"); - } + assert_eq_or_internal_err!( + partition, + 0, + "TopKExec invalid partition {partition}" + ); Ok(Box::pin(TopKReader { input: self.input.execute(partition, context)?, @@ -746,10 +749,20 @@ impl ExecutionPlan for TopKExec { })) } - fn statistics(&self) -> Result { - // to improve the optimizability of this plan - // better statistics inference could be provided - Ok(Statistics::new_unknown(&self.schema())) + fn apply_expressions( + &self, + f: &mut dyn FnMut( + &dyn datafusion::physical_plan::PhysicalExpr, + ) -> Result, + ) -> Result { + // Visit expressions in the output ordering from equivalence properties + let mut tnr = TreeNodeRecursion::Continue; + if let Some(ordering) = self.cache.output_ordering() { + for sort_expr in ordering { + tnr = tnr.visit_sibling(|| f(sort_expr.expr.as_ref()))?; + } + } + Ok(tnr) } } @@ -912,11 +925,12 @@ impl MyAnalyzerRule { .map(|e| { e.transform(|e| { Ok(match e { - Expr::Literal(ScalarValue::Int64(i)) => { + Expr::Literal(ScalarValue::Int64(i), _) => { // transform to UInt64 - Transformed::yes(Expr::Literal(ScalarValue::UInt64( - i.map(|i| i as u64), - ))) + Transformed::yes(Expr::Literal( + ScalarValue::UInt64(i.map(|i| i as u64)), + None, + )) } _ => Transformed::no(e), }) diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index 25458efa4fa55..025ee9767c694 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -17,35 +17,38 @@ use std::any::Any; use std::collections::HashMap; -use 
std::hash::{DefaultHasher, Hash, Hasher}; +use std::hash::{Hash, Hasher}; use std::sync::Arc; -use arrow::array::{as_string_array, record_batch, Int8Array, UInt64Array}; use arrow::array::{ - builder::BooleanBuilder, cast::AsArray, Array, ArrayRef, Float32Array, Float64Array, - Int32Array, RecordBatch, StringArray, + Array, ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch, StringArray, + builder::BooleanBuilder, cast::AsArray, }; +use arrow::array::{Int8Array, UInt64Array, as_string_array, create_array, record_batch}; use arrow::compute::kernels::numeric::add; use arrow::datatypes::{DataType, Field, Schema}; use arrow_schema::extension::{Bool8, CanonicalExtensionType, ExtensionType}; -use arrow_schema::{ArrowError, FieldRef}; +use arrow_schema::{ArrowError, FieldRef, SchemaRef}; use datafusion::common::test_util::batches_to_string; use datafusion::execution::context::{FunctionFactory, RegisterFunction, SessionState}; use datafusion::prelude::*; use datafusion::{execution::registry::FunctionRegistry, test_util}; use datafusion_common::cast::{as_float64_array, as_int32_array}; +use datafusion_common::metadata::FieldMetadata; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::utils::take_function_args; use datafusion_common::{ - assert_batches_eq, assert_batches_sorted_eq, assert_contains, exec_err, not_impl_err, - plan_err, DFSchema, DataFusionError, Result, ScalarValue, + DFSchema, DataFusionError, Result, ScalarValue, assert_batches_eq, + assert_batches_sorted_eq, assert_contains, exec_datafusion_err, exec_err, + not_impl_err, plan_err, }; -use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; use datafusion_expr::{ Accumulator, ColumnarValue, CreateFunction, CreateFunctionBody, LogicalPlanBuilder, OperateFunctionArg, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, - Signature, Volatility, + Signature, Volatility, 
lit_with_metadata, }; +use datafusion_expr_common::signature::TypeSignature; use datafusion_functions_nested::range::range_udf; use parking_lot::Mutex; use regex::Regex; @@ -62,13 +65,13 @@ async fn csv_query_custom_udf_with_cast() -> Result<()> { let sql = "SELECT avg(custom_sqrt(c11)) FROM aggregate_test_100"; let actual = plan_and_collect(&ctx, sql).await?; - insta::assert_snapshot!(batches_to_string(&actual), @r###" + insta::assert_snapshot!(batches_to_string(&actual), @r" +------------------------------------------+ | avg(custom_sqrt(aggregate_test_100.c11)) | +------------------------------------------+ | 0.6584408483418835 | +------------------------------------------+ - "###); + "); Ok(()) } @@ -81,13 +84,13 @@ async fn csv_query_avg_sqrt() -> Result<()> { let sql = "SELECT avg(custom_sqrt(c12)) FROM aggregate_test_100"; let actual = plan_and_collect(&ctx, sql).await?; - insta::assert_snapshot!(batches_to_string(&actual), @r###" + insta::assert_snapshot!(batches_to_string(&actual), @r" +------------------------------------------+ | avg(custom_sqrt(aggregate_test_100.c12)) | +------------------------------------------+ | 0.6706002946036459 | +------------------------------------------+ - "###); + "); Ok(()) } @@ -152,7 +155,7 @@ async fn scalar_udf() -> Result<()> { let result = DataFrame::new(ctx.state(), plan).collect().await?; - insta::assert_snapshot!(batches_to_string(&result), @r###" + insta::assert_snapshot!(batches_to_string(&result), @r" +-----+-----+-----------------+ | a | b | my_add(t.a,t.b) | +-----+-----+-----------------+ @@ -161,7 +164,7 @@ async fn scalar_udf() -> Result<()> { | 10 | 12 | 22 | | 100 | 120 | 220 | +-----+-----+-----------------+ - "###); + "); let batch = &result[0]; let a = as_int32_array(batch.column(0))?; @@ -180,6 +183,7 @@ async fn scalar_udf() -> Result<()> { Ok(()) } +#[derive(PartialEq, Eq, Hash)] struct Simple0ArgsScalarUDF { name: String, signature: Signature, @@ -277,7 +281,7 @@ async fn scalar_udf_zero_params() -> 
Result<()> { ctx.register_udf(ScalarUDF::from(get_100_udf)); let result = plan_and_collect(&ctx, "select get_100() a from t").await?; - insta::assert_snapshot!(batches_to_string(&result), @r###" + insta::assert_snapshot!(batches_to_string(&result), @r" +-----+ | a | +-----+ @@ -286,22 +290,22 @@ async fn scalar_udf_zero_params() -> Result<()> { | 100 | | 100 | +-----+ - "###); + "); let result = plan_and_collect(&ctx, "select get_100() a").await?; - insta::assert_snapshot!(batches_to_string(&result), @r###" + insta::assert_snapshot!(batches_to_string(&result), @r" +-----+ | a | +-----+ | 100 | +-----+ - "###); + "); let result = plan_and_collect(&ctx, "select get_100() from t where a=999").await?; - insta::assert_snapshot!(batches_to_string(&result), @r###" + insta::assert_snapshot!(batches_to_string(&result), @r" ++ ++ - "###); + "); Ok(()) } @@ -328,13 +332,13 @@ async fn scalar_udf_override_built_in_scalar_function() -> Result<()> { // Make sure that the UDF is used instead of the built-in function let result = plan_and_collect(&ctx, "select abs(a) a from t").await?; - insta::assert_snapshot!(batches_to_string(&result), @r###" + insta::assert_snapshot!(batches_to_string(&result), @r" +---+ | a | +---+ | 1 | +---+ - "###); + "); Ok(()) } @@ -423,20 +427,21 @@ async fn case_sensitive_identifiers_user_defined_functions() -> Result<()> { let err = plan_and_collect(&ctx, "SELECT MY_FUNC(i) FROM t") .await .unwrap_err(); - assert!(err - .to_string() - .contains("Error during planning: Invalid function \'my_func\'")); + assert!( + err.to_string() + .contains("Error during planning: Invalid function \'my_func\'") + ); // Can call it if you put quotes let result = plan_and_collect(&ctx, "SELECT \"MY_FUNC\"(i) FROM t").await?; - insta::assert_snapshot!(batches_to_string(&result), @r###" + insta::assert_snapshot!(batches_to_string(&result), @r" +--------------+ | MY_FUNC(t.i) | +--------------+ | 1 | +--------------+ - "###); + "); Ok(()) } @@ -467,28 +472,28 @@ async fn 
test_user_defined_functions_with_alias() -> Result<()> { ctx.register_udf(udf); let result = plan_and_collect(&ctx, "SELECT dummy(i) FROM t").await?; - insta::assert_snapshot!(batches_to_string(&result), @r###" + insta::assert_snapshot!(batches_to_string(&result), @r" +------------+ | dummy(t.i) | +------------+ | 1 | +------------+ - "###); + "); let alias_result = plan_and_collect(&ctx, "SELECT dummy_alias(i) FROM t").await?; - insta::assert_snapshot!(batches_to_string(&alias_result), @r###" - +------------+ - | dummy(t.i) | - +------------+ - | 1 | - +------------+ - "###); + insta::assert_snapshot!(batches_to_string(&alias_result), @r" + +------------------+ + | dummy_alias(t.i) | + +------------------+ + | 1 | + +------------------+ + "); Ok(()) } /// Volatile UDF that should append a different value to each row -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] struct AddIndexToStringVolatileScalarUDF { name: String, signature: Signature, @@ -659,7 +664,7 @@ async fn volatile_scalar_udf_with_params() -> Result<()> { Ok(()) } -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] struct CastToI64UDF { signature: Signature, } @@ -694,7 +699,7 @@ impl ScalarUDFImpl for CastToI64UDF { fn simplify( &self, mut args: Vec, - info: &dyn SimplifyInfo, + info: &SimplifyContext, ) -> Result { // DataFusion should have ensured the function is called with just a // single argument @@ -710,10 +715,7 @@ impl ScalarUDFImpl for CastToI64UDF { arg } else { // need to use an actual cast to get the correct type - Expr::Cast(datafusion_expr::Cast { - expr: Box::new(arg), - data_type: DataType::Int64, - }) + Expr::Cast(datafusion_expr::Cast::new(Box::new(arg), DataType::Int64)) }; // return the newly written argument to DataFusion Ok(ExprSimplifyResult::Simplified(new_expr)) @@ -773,15 +775,17 @@ async fn deregister_udf() -> Result<()> { ctx.register_udf(cast2i64); assert!(ctx.udfs().contains("cast_to_i64")); + assert!(FunctionRegistry::udfs(&ctx).contains("cast_to_i64")); 
ctx.deregister_udf("cast_to_i64"); assert!(!ctx.udfs().contains("cast_to_i64")); + assert!(!FunctionRegistry::udfs(&ctx).contains("cast_to_i64")); Ok(()) } -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] struct TakeUDF { signature: Signature, } @@ -935,12 +939,13 @@ impl FunctionFactory for CustomFunctionFactory { // // it also defines custom [ScalarUDFImpl::simplify()] // to replace ScalarUDF expression with one instance contains. -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] struct ScalarFunctionWrapper { name: String, expr: Expr, signature: Signature, return_type: DataType, + defaults: Vec>, } impl ScalarUDFImpl for ScalarFunctionWrapper { @@ -967,21 +972,21 @@ impl ScalarUDFImpl for ScalarFunctionWrapper { fn simplify( &self, args: Vec, - _info: &dyn SimplifyInfo, + _info: &SimplifyContext, ) -> Result { - let replacement = Self::replacement(&self.expr, &args)?; + let replacement = Self::replacement(&self.expr, &args, &self.defaults)?; Ok(ExprSimplifyResult::Simplified(replacement)) } - - fn aliases(&self) -> &[String] { - &[] - } } impl ScalarFunctionWrapper { // replaces placeholders with actual arguments - fn replacement(expr: &Expr, args: &[Expr]) -> Result { + fn replacement( + expr: &Expr, + args: &[Expr], + defaults: &[Option], + ) -> Result { let result = expr.clone().transform(|e| { let r = match e { Expr::Placeholder(placeholder) => { @@ -989,11 +994,19 @@ impl ScalarFunctionWrapper { Self::parse_placeholder_identifier(&placeholder.id)?; if placeholder_position < args.len() { Transformed::yes(args[placeholder_position].clone()) - } else { + } else if placeholder_position >= defaults.len() { exec_err!( - "Function argument {} not provided, argument missing!", + "Invalid placeholder, out of range: {}", placeholder.id )? 
+ } else { + match defaults[placeholder_position] { + Some(ref default) => Transformed::yes(default.clone()), + None => exec_err!( + "Function argument {} not provided, argument missing!", + placeholder.id + )?, + } } } _ => Transformed::no(e), @@ -1009,9 +1022,7 @@ impl ScalarFunctionWrapper { fn parse_placeholder_identifier(placeholder: &str) -> Result { if let Some(value) = placeholder.strip_prefix('$') { Ok(value.parse().map(|v: usize| v - 1).map_err(|e| { - DataFusionError::Execution(format!( - "Placeholder `{placeholder}` parsing error: {e}!" - )) + exec_datafusion_err!("Placeholder `{placeholder}` parsing error: {e}!") })?) } else { exec_err!("Placeholder should start with `$`!") @@ -1023,6 +1034,32 @@ impl TryFrom for ScalarFunctionWrapper { type Error = DataFusionError; fn try_from(definition: CreateFunction) -> std::result::Result { + let args = definition.args.unwrap_or_default(); + let defaults: Vec> = + args.iter().map(|a| a.default_expr.clone()).collect(); + let signature: Signature = match defaults.iter().position(|v| v.is_some()) { + Some(pos) => { + let mut type_signatures: Vec = vec![]; + // Generate all valid signatures + for n in pos..defaults.len() + 1 { + if n == 0 { + type_signatures.push(TypeSignature::Nullary) + } else { + type_signatures.push(TypeSignature::Exact( + args.iter().take(n).map(|a| a.data_type.clone()).collect(), + )) + } + } + Signature::one_of( + type_signatures, + definition.params.behavior.unwrap_or(Volatility::Volatile), + ) + } + None => Signature::exact( + args.iter().map(|a| a.data_type.clone()).collect(), + definition.params.behavior.unwrap_or(Volatility::Volatile), + ), + }; Ok(Self { name: definition.name, expr: definition @@ -1032,15 +1069,8 @@ impl TryFrom for ScalarFunctionWrapper { return_type: definition .return_type .expect("Return type has to be defined!"), - signature: Signature::exact( - definition - .args - .unwrap_or_default() - .into_iter() - .map(|a| a.data_type) - .collect(), - 
definition.params.behavior.unwrap_or(Volatility::Volatile), - ), + signature, + defaults, }) } } @@ -1063,10 +1093,11 @@ async fn create_scalar_function_from_sql_statement() -> Result<()> { // Create the `better_add` function dynamically via CREATE FUNCTION statement assert!(ctx.sql(sql).await.is_ok()); // try to `drop function` when sql options have allow ddl disabled - assert!(ctx - .sql_with_options("drop function better_add", options) - .await - .is_err()); + assert!( + ctx.sql_with_options("drop function better_add", options) + .await + .is_err() + ); let result = ctx .sql("select better_add(2.0, 2.0)") @@ -1111,6 +1142,175 @@ async fn create_scalar_function_from_sql_statement() -> Result<()> { "#; assert!(ctx.sql(bad_definition_sql).await.is_err()); + // FIXME: Definitions with invalid placeholders are allowed, fail at runtime + let bad_expression_sql = r#" + CREATE FUNCTION better_add(DOUBLE, DOUBLE) + RETURNS DOUBLE + RETURN $1 + $3 + "#; + assert!(ctx.sql(bad_expression_sql).await.is_ok()); + + let err = ctx + .sql("select better_add(2.0, 2.0)") + .await? + .collect() + .await + .expect_err("unknown placeholder"); + let expected = "Optimizer rule 'simplify_expressions' failed\ncaused by\nExecution error: Invalid placeholder, out of range: $3"; + assert!(expected.starts_with(&err.strip_backtrace())); + + Ok(()) +} + +#[tokio::test] +async fn create_scalar_function_from_sql_statement_named_arguments() -> Result<()> { + let function_factory = Arc::new(CustomFunctionFactory::default()); + let ctx = SessionContext::new().with_function_factory(function_factory.clone()); + + let sql = r#" + CREATE FUNCTION better_add(a DOUBLE, b DOUBLE) + RETURNS DOUBLE + RETURN $a + $b + "#; + + assert!(ctx.sql(sql).await.is_ok()); + + let result = ctx + .sql("select better_add(2.0, 2.0)") + .await? 
+ .collect() + .await?; + + assert_batches_eq!( + &[ + "+-----------------------------------+", + "| better_add(Float64(2),Float64(2)) |", + "+-----------------------------------+", + "| 4.0 |", + "+-----------------------------------+", + ], + &result + ); + + // cannot mix named and positional style + let bad_expression_sql = r#" + CREATE FUNCTION bad_expression_fun(DOUBLE, b DOUBLE) + RETURNS DOUBLE + RETURN $1 + $b + "#; + let err = ctx + .sql(bad_expression_sql) + .await + .expect_err("cannot mix named and positional style"); + let expected = "Error during planning: All function arguments must use either named or positional style."; + assert!(expected.starts_with(&err.strip_backtrace())); + + Ok(()) +} + +#[tokio::test] +async fn create_scalar_function_from_sql_statement_default_arguments() -> Result<()> { + let function_factory = Arc::new(CustomFunctionFactory::default()); + let ctx = SessionContext::new().with_function_factory(function_factory.clone()); + + let sql = r#" + CREATE FUNCTION better_add(a DOUBLE = 2.0, b DOUBLE = 2.0) + RETURNS DOUBLE + RETURN $a + $b + "#; + + assert!(ctx.sql(sql).await.is_ok()); + + // Check all function arity supported + let result = ctx.sql("select better_add()").await?.collect().await?; + + assert_batches_eq!( + &[ + "+--------------+", + "| better_add() |", + "+--------------+", + "| 4.0 |", + "+--------------+", + ], + &result + ); + + let result = ctx.sql("select better_add(2.0)").await?.collect().await?; + + assert_batches_eq!( + &[ + "+------------------------+", + "| better_add(Float64(2)) |", + "+------------------------+", + "| 4.0 |", + "+------------------------+", + ], + &result + ); + + let result = ctx + .sql("select better_add(2.0, 2.0)") + .await? 
+ .collect() + .await?; + + assert_batches_eq!( + &[ + "+-----------------------------------+", + "| better_add(Float64(2),Float64(2)) |", + "+-----------------------------------+", + "| 4.0 |", + "+-----------------------------------+", + ], + &result + ); + + assert!(ctx.sql("select better_add(2.0, 2.0, 2.0)").await.is_err()); + assert!(ctx.sql("drop function better_add").await.is_ok()); + + // works with positional style + let sql = r#" + CREATE FUNCTION better_add(DOUBLE, DOUBLE = 2.0) + RETURNS DOUBLE + RETURN $1 + $2 + "#; + assert!(ctx.sql(sql).await.is_ok()); + + assert!(ctx.sql("select better_add()").await.is_err()); + let result = ctx.sql("select better_add(2.0)").await?.collect().await?; + assert_batches_eq!( + &[ + "+------------------------+", + "| better_add(Float64(2)) |", + "+------------------------+", + "| 4.0 |", + "+------------------------+", + ], + &result + ); + + // non-default argument cannot follow default argument + let bad_expression_sql = r#" + CREATE FUNCTION bad_expression_fun(a DOUBLE = 2.0, b DOUBLE) + RETURNS DOUBLE + RETURN $a + $b + "#; + let err = ctx + .sql(bad_expression_sql) + .await + .expect_err("non-default argument cannot follow default argument"); + let expected = + "Error during planning: Non-default arguments cannot follow default arguments."; + assert!(expected.starts_with(&err.strip_backtrace())); + + let expression_sql = r#" + CREATE FUNCTION bad_expression_fun(DOUBLE, DOUBLE DEFAULT 2.0) + RETURNS DOUBLE + RETURN $1 + $2 + "#; + let result = ctx.sql(expression_sql).await; + + assert!(result.is_ok()); Ok(()) } @@ -1184,7 +1384,7 @@ async fn create_scalar_function_from_sql_statement_postgres_syntax() -> Result<( quote_style: None, span: Span::empty(), }), - data_type: DataType::Utf8, + data_type: DataType::Utf8View, default_expr: None, }]), return_type: Some(DataType::Int32), @@ -1211,6 +1411,22 @@ struct MyRegexUdf { regex: Regex, } +impl PartialEq for MyRegexUdf { + fn eq(&self, other: &Self) -> bool { + let Self { 
signature, regex } = self; + signature == &other.signature && regex.as_str() == other.regex.as_str() + } +} +impl Eq for MyRegexUdf {} + +impl Hash for MyRegexUdf { + fn hash(&self, state: &mut H) { + let Self { signature, regex } = self; + signature.hash(state); + regex.as_str().hash(state); + } +} + impl MyRegexUdf { fn new(pattern: &str) -> Self { Self { @@ -1262,20 +1478,6 @@ impl ScalarUDFImpl for MyRegexUdf { _ => exec_err!("regex_udf only accepts a Utf8 arguments"), } } - - fn equals(&self, other: &dyn ScalarUDFImpl) -> bool { - if let Some(other) = other.as_any().downcast_ref::() { - self.regex.as_str() == other.regex.as_str() - } else { - false - } - } - - fn hash_value(&self) -> u64 { - let hasher = &mut DefaultHasher::new(); - self.regex.as_str().hash(hasher); - hasher.finish() - } } #[tokio::test] @@ -1373,13 +1575,25 @@ async fn plan_and_collect(ctx: &SessionContext, sql: &str) -> Result, } +impl Hash for MetadataBasedUdf { + fn hash(&self, state: &mut H) { + let Self { + name, + signature, + metadata: _, // unhashable + } = self; + name.hash(state); + signature.hash(state); + } +} + impl MetadataBasedUdf { fn new(metadata: HashMap) -> Self { // The name we return must be unique. 
Otherwise we will not call distinct @@ -1426,7 +1640,7 @@ impl ScalarUDFImpl for MetadataBasedUdf { .get("modify_values") .map(|v| v == "double_output") .unwrap_or(false); - let mulitplier = if should_double { 2 } else { 1 }; + let multiplier = if should_double { 2 } else { 1 }; match &args.args[0] { ColumnarValue::Array(array) => { @@ -1435,7 +1649,7 @@ impl ScalarUDFImpl for MetadataBasedUdf { .downcast_ref::() .unwrap() .iter() - .map(|v| v.map(|x| x * mulitplier)) + .map(|v| v.map(|x| x * multiplier)) .collect(); let array_ref = Arc::new(UInt64Array::from(array_values)) as ArrayRef; Ok(ColumnarValue::Array(array_ref)) @@ -1446,15 +1660,11 @@ impl ScalarUDFImpl for MetadataBasedUdf { }; Ok(ColumnarValue::Scalar(ScalarValue::UInt64( - value.map(|v| v * mulitplier), + value.map(|v| v * multiplier), ))) } } } - - fn equals(&self, other: &dyn ScalarUDFImpl) -> bool { - self.name == other.name() - } } #[tokio::test] @@ -1529,11 +1739,71 @@ async fn test_metadata_based_udf() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_metadata_based_udf_with_literal() -> Result<()> { + let ctx = SessionContext::new(); + let input_metadata: HashMap = + [("modify_values".to_string(), "double_output".to_string())] + .into_iter() + .collect(); + let input_metadata = FieldMetadata::from(input_metadata); + let df = ctx.sql("select 0;").await?.select(vec![ + lit(5u64).alias_with_metadata("lit_with_doubling", Some(input_metadata.clone())), + lit(5u64).alias("lit_no_doubling"), + lit_with_metadata(5u64, Some(input_metadata)) + .alias("lit_with_double_no_alias_metadata"), + ])?; + + let output_metadata: HashMap = + [("output_metatype".to_string(), "custom_value".to_string())] + .into_iter() + .collect(); + let custom_udf = ScalarUDF::from(MetadataBasedUdf::new(output_metadata.clone())); + + let plan = LogicalPlanBuilder::from(df.into_optimized_plan()?) 
+ .project(vec![ + custom_udf + .call(vec![col("lit_with_doubling")]) + .alias("doubled_output"), + custom_udf + .call(vec![col("lit_no_doubling")]) + .alias("not_doubled_output"), + custom_udf + .call(vec![col("lit_with_double_no_alias_metadata")]) + .alias("double_without_alias_metadata"), + ])? + .build()?; + + let actual = DataFrame::new(ctx.state(), plan).collect().await?; + + let schema = Arc::new(Schema::new(vec![ + Field::new("doubled_output", DataType::UInt64, false) + .with_metadata(output_metadata.clone()), + Field::new("not_doubled_output", DataType::UInt64, false) + .with_metadata(output_metadata.clone()), + Field::new("double_without_alias_metadata", DataType::UInt64, false) + .with_metadata(output_metadata.clone()), + ])); + + let expected = RecordBatch::try_new( + schema, + vec![ + create_array!(UInt64, [10]), + create_array!(UInt64, [5]), + create_array!(UInt64, [10]), + ], + )?; + + assert_eq!(expected, actual[0]); + + Ok(()) +} + /// This UDF is to test extension handling, both on the input and output /// sides. For the input, we will handle the data differently if there is /// the canonical extension type Bool8. For the output we will add a /// user defined extension type. 
-#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] struct ExtensionBasedUdf { name: String, signature: Signature, @@ -1566,7 +1836,7 @@ impl ScalarUDFImpl for ExtensionBasedUdf { fn return_field_from_args(&self, _args: ReturnFieldArgs) -> Result { Ok(Field::new("canonical_extension_udf", DataType::Utf8, true) - .with_extension_type(MyUserExtentionType {}) + .with_extension_type(MyUserExtensionType {}) .into()) } @@ -1612,16 +1882,12 @@ impl ScalarUDFImpl for ExtensionBasedUdf { } } } - - fn equals(&self, other: &dyn ScalarUDFImpl) -> bool { - self.name == other.name() - } } -struct MyUserExtentionType {} +struct MyUserExtensionType {} -impl ExtensionType for MyUserExtentionType { - const NAME: &'static str = "my_user_extention_type"; +impl ExtensionType for MyUserExtensionType { + const NAME: &'static str = "my_user_Extension_type"; type Metadata = (); fn metadata(&self) -> &Self::Metadata { @@ -1693,9 +1959,9 @@ async fn test_extension_based_udf() -> Result<()> { // To test for input extensions handling, we check the strings returned let expected_schema = Schema::new(vec![ Field::new("without_bool8_extension", DataType::Utf8, true) - .with_extension_type(MyUserExtentionType {}), + .with_extension_type(MyUserExtensionType {}), Field::new("with_bool8_extension", DataType::Utf8, true) - .with_extension_type(MyUserExtentionType {}), + .with_extension_type(MyUserExtensionType {}), ]); let expected = record_batch!( @@ -1713,3 +1979,237 @@ async fn test_extension_based_udf() -> Result<()> { ctx.deregister_table("t")?; Ok(()) } + +#[tokio::test] +async fn test_config_options_work_for_scalar_func() -> Result<()> { + #[derive(Debug, PartialEq, Eq, Hash)] + struct TestScalarUDF { + signature: Signature, + } + + impl ScalarUDFImpl for TestScalarUDF { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "TestScalarUDF" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + 
Ok(DataType::Utf8) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let tz = args.config_options.execution.time_zone.clone(); + Ok(ColumnarValue::Scalar(ScalarValue::from(tz))) + } + } + + let udf = ScalarUDF::from(TestScalarUDF { + signature: Signature::uniform(1, vec![DataType::Utf8], Volatility::Stable), + }); + + let mut config = SessionConfig::new(); + config.options_mut().execution.time_zone = Some("AEST".into()); + + let ctx = SessionContext::new_with_config(config); + + ctx.register_udf(udf.clone()); + + let df = ctx.read_empty()?; + let df = df.select(vec![udf.call(vec![lit("a")]).alias("a")])?; + let actual = df.collect().await?; + + let expected_schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]); + let expected = RecordBatch::try_new( + SchemaRef::from(expected_schema), + vec![create_array!(Utf8, ["AEST"])], + )?; + + assert_eq!(expected, actual[0]); + + Ok(()) +} + +/// https://github.com/apache/datafusion/issues/17425 +#[tokio::test] +async fn test_extension_metadata_preserve_in_sql_values() -> Result<()> { + #[derive(Debug, Hash, PartialEq, Eq)] + struct MakeExtension { + signature: Signature, + } + + impl Default for MakeExtension { + fn default() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + } + } + } + + impl ScalarUDFImpl for MakeExtension { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "make_extension" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + Ok(arg_types.to_vec()) + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + unreachable!("This shouldn't have been called") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + Ok(args.arg_fields[0] + .as_ref() + .clone() + .with_metadata(HashMap::from([( + "ARROW:extension:metadata".to_string(), + "foofy.foofy".to_string(), + )])) + .into()) + } + + fn 
invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + Ok(args.args[0].clone()) + } + } + + let ctx = SessionContext::new(); + ctx.register_udf(MakeExtension::default().into()); + + let batches = ctx + .sql( + " +SELECT extension FROM (VALUES + ('one', make_extension('foofy one')), + ('two', make_extension('foofy two')), + ('three', make_extension('foofy three'))) +AS t(string, extension) + ", + ) + .await? + .collect() + .await?; + + assert_eq!( + batches[0] + .schema() + .field(0) + .metadata() + .get("ARROW:extension:metadata"), + Some(&"foofy.foofy".into()) + ); + Ok(()) +} + +/// https://github.com/apache/datafusion/issues/17422 +#[tokio::test] +async fn test_extension_metadata_preserve_in_subquery() -> Result<()> { + #[derive(Debug, PartialEq, Eq, Hash)] + struct ExtensionScalarPredicate { + signature: Signature, + } + + impl Default for ExtensionScalarPredicate { + fn default() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + } + } + } + + impl ScalarUDFImpl for ExtensionScalarPredicate { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "extension_predicate" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + Ok(arg_types.to_vec()) + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + unreachable!("This shouldn't have been called") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + for arg in args.arg_fields { + assert!(arg.metadata().contains_key("ARROW:extension:name")); + } + + Ok(Field::new("", DataType::Boolean, true).into()) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + for arg in args.arg_fields { + assert!(arg.metadata().contains_key("ARROW:extension:name")); + } + + let array = + ScalarValue::Boolean(Some(true)).to_array_of_size(args.number_rows)?; + Ok(ColumnarValue::Array(array)) + } + } + + let schema = Schema::new(vec![ + 
Field::new("id", DataType::Int64, true), + Field::new("geometry", DataType::Utf8, true).with_metadata(HashMap::from([( + "ARROW:extension:name".to_string(), + "foofy.foofy".to_string(), + )])), + ]); + + let batch_lhs = RecordBatch::try_new( + schema.clone().into(), + vec![ + create_array!(Int64, [1, 2]), + create_array!(Utf8, [Some("item1"), Some("item2")]), + ], + )?; + + let batch_rhs = RecordBatch::try_new( + schema.clone().into(), + vec![ + create_array!(Int64, [2, 3]), + create_array!(Utf8, [Some("item2"), Some("item3")]), + ], + )?; + + let ctx = SessionContext::new(); + ctx.register_batch("l", batch_lhs)?; + ctx.register_batch("r", batch_rhs)?; + ctx.register_udf(ExtensionScalarPredicate::default().into()); + + let df = ctx + .sql( + " + SELECT L.id l_id FROM L + WHERE EXISTS (SELECT 1 FROM R WHERE extension_predicate(L.geometry, R.geometry)) + ORDER BY l_id + ", + ) + .await?; + assert!(!df.collect().await?.is_empty()); + Ok(()) +} diff --git a/datafusion/core/tests/user_defined/user_defined_table_functions.rs b/datafusion/core/tests/user_defined/user_defined_table_functions.rs index e4aff0b00705d..95694d00a6c30 100644 --- a/datafusion/core/tests/user_defined/user_defined_table_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_table_functions.rs @@ -21,17 +21,17 @@ use std::path::Path; use std::sync::Arc; use arrow::array::Int64Array; -use arrow::csv::reader::Format; use arrow::csv::ReaderBuilder; +use arrow::csv::reader::Format; use datafusion::arrow::datatypes::SchemaRef; use datafusion::arrow::record_batch::RecordBatch; use datafusion::common::test_util::batches_to_string; -use datafusion::datasource::memory::MemorySourceConfig; use datafusion::datasource::TableProvider; +use datafusion::datasource::memory::MemorySourceConfig; use datafusion::error::Result; use datafusion::execution::TaskContext; -use datafusion::physical_plan::{collect, ExecutionPlan}; +use datafusion::physical_plan::{ExecutionPlan, collect}; use 
datafusion::prelude::SessionContext; use datafusion_catalog::Session; use datafusion_catalog::TableFunctionImpl; @@ -55,7 +55,7 @@ async fn test_simple_read_csv_udtf() -> Result<()> { .collect() .await?; - insta::assert_snapshot!(batches_to_string(&rbs), @r###" + insta::assert_snapshot!(batches_to_string(&rbs), @r" +-------------+-----------+-------------+-------------------------------------------------------------------------------------------------------------+ | n_nationkey | n_name | n_regionkey | n_comment | +-------------+-----------+-------------+-------------------------------------------------------------------------------------------------------------+ @@ -65,7 +65,7 @@ async fn test_simple_read_csv_udtf() -> Result<()> { | 4 | EGYPT | 4 | y above the carefully unusual theodolites. final dugouts are quickly across the furiously regular d | | 5 | ETHIOPIA | 0 | ven packages wake quickly. regu | +-------------+-----------+-------------+-------------------------------------------------------------------------------------------------------------+ - "###); + "); // just run, return all rows let rbs = ctx @@ -74,7 +74,7 @@ async fn test_simple_read_csv_udtf() -> Result<()> { .collect() .await?; - insta::assert_snapshot!(batches_to_string(&rbs), @r###" + insta::assert_snapshot!(batches_to_string(&rbs), @r" +-------------+-----------+-------------+--------------------------------------------------------------------------------------------------------------------+ | n_nationkey | n_name | n_regionkey | n_comment | +-------------+-----------+-------------+--------------------------------------------------------------------------------------------------------------------+ @@ -89,7 +89,7 @@ async fn test_simple_read_csv_udtf() -> Result<()> { | 9 | INDONESIA | 2 | slyly express asymptotes. regular deposits haggle slyly. carefully ironic hockey players sleep blithely. carefull | | 10 | IRAN | 4 | efully alongside of the slyly final dependencies. 
| +-------------+-----------+-------------+--------------------------------------------------------------------------------------------------------------------+ - "###); + "); Ok(()) } @@ -205,7 +205,7 @@ impl TableFunctionImpl for SimpleCsvTableFunc { let mut filepath = String::new(); for expr in exprs { match expr { - Expr::Literal(ScalarValue::Utf8(Some(ref path))) => { + Expr::Literal(ScalarValue::Utf8(Some(path)), _) => { filepath.clone_from(path); } expr => new_exprs.push(expr.clone()), @@ -221,6 +221,31 @@ impl TableFunctionImpl for SimpleCsvTableFunc { } } +/// Test that expressions passed to UDTFs are properly type-coerced +/// This is a regression test for https://github.com/apache/datafusion/issues/19914 +#[tokio::test] +async fn test_udtf_type_coercion() -> Result<()> { + use datafusion::datasource::MemTable; + + #[derive(Debug)] + struct NoOpTableFunc; + + impl TableFunctionImpl for NoOpTableFunc { + fn call(&self, _: &[Expr]) -> Result> { + let schema = Arc::new(arrow::datatypes::Schema::empty()); + Ok(Arc::new(MemTable::try_new(schema, vec![vec![]])?)) + } + } + + let ctx = SessionContext::new(); + ctx.register_udtf("f", Arc::new(NoOpTableFunc)); + + // This should not panic - the array elements should be coerced to Float64 + let _ = ctx.sql("SELECT * FROM f(ARRAY[0.1, 1, 2])").await?; + + Ok(()) +} + fn read_csv_batches(csv_path: impl AsRef) -> Result<(SchemaRef, Vec)> { let mut file = File::open(csv_path)?; let (schema, _) = Format::default() diff --git a/datafusion/core/tests/user_defined/user_defined_window_functions.rs b/datafusion/core/tests/user_defined/user_defined_window_functions.rs index bcd2c3945e392..775325a337184 100644 --- a/datafusion/core/tests/user_defined/user_defined_window_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_window_functions.rs @@ -19,8 +19,8 @@ //! 
user defined window functions use arrow::array::{ - record_batch, Array, ArrayRef, AsArray, Int64Array, RecordBatch, StringArray, - UInt64Array, + Array, ArrayRef, AsArray, Int64Array, RecordBatch, StringArray, UInt64Array, + record_batch, }; use arrow::datatypes::{DataType, Field, Schema}; use arrow_schema::FieldRef; @@ -28,24 +28,27 @@ use datafusion::common::test_util::batches_to_string; use datafusion::common::{Result, ScalarValue}; use datafusion::prelude::SessionContext; use datafusion_common::exec_datafusion_err; +use datafusion_expr::ptr_eq::PtrEq; use datafusion_expr::{ - PartitionEvaluator, Signature, TypeSignature, Volatility, WindowUDF, WindowUDFImpl, + LimitEffect, PartitionEvaluator, Signature, TypeSignature, Volatility, WindowUDF, + WindowUDFImpl, }; use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; use datafusion_functions_window_common::{ expr::ExpressionArgs, field::WindowUDFFieldArgs, }; use datafusion_physical_expr::{ - expressions::{col, lit}, PhysicalExpr, + expressions::{col, lit}, }; use std::collections::HashMap; +use std::hash::{Hash, Hasher}; use std::{ any::Any, ops::Range, sync::{ - atomic::{AtomicUsize, Ordering}, Arc, + atomic::{AtomicUsize, Ordering}, }, }; @@ -59,8 +62,7 @@ const UNBOUNDED_WINDOW_QUERY_WITH_ALIAS: &str = "SELECT x, y, val, \ from t ORDER BY x, y"; /// A query with a window function evaluated over a moving window -const BOUNDED_WINDOW_QUERY: &str = - "SELECT x, y, val, \ +const BOUNDED_WINDOW_QUERY: &str = "SELECT x, y, val, \ odd_counter(val) OVER (PARTITION BY x ORDER BY y ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) \ from t ORDER BY x, y"; @@ -72,22 +74,22 @@ async fn test_setup() { let sql = "SELECT * from t order by x, y"; let actual = execute(&ctx, sql).await.unwrap(); - insta::assert_snapshot!(batches_to_string(&actual), @r###" - +---+---+-----+ - | x | y | val | - +---+---+-----+ - | 1 | a | 0 | - | 1 | b | 1 | - | 1 | c | 2 | - | 2 | d | 3 | - | 2 | e | 4 | - | 2 | f | 5 | - | 2 | g | 
6 | - | 2 | h | 6 | - | 2 | i | 6 | - | 2 | j | 6 | - +---+---+-----+ - "###); + insta::assert_snapshot!(batches_to_string(&actual), @r" + +---+---+-----+ + | x | y | val | + +---+---+-----+ + | 1 | a | 0 | + | 1 | b | 1 | + | 1 | c | 2 | + | 2 | d | 3 | + | 2 | e | 4 | + | 2 | f | 5 | + | 2 | g | 6 | + | 2 | h | 6 | + | 2 | i | 6 | + | 2 | j | 6 | + +---+---+-----+ + "); } /// Basic user defined window function @@ -98,22 +100,22 @@ async fn test_udwf() { let actual = execute(&ctx, UNBOUNDED_WINDOW_QUERY).await.unwrap(); - insta::assert_snapshot!(batches_to_string(&actual), @r###" - +---+---+-----+-----------------------------------------------------------------------------------------------------------------------+ - | x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW | - +---+---+-----+-----------------------------------------------------------------------------------------------------------------------+ - | 1 | a | 0 | 1 | - | 1 | b | 1 | 1 | - | 1 | c | 2 | 1 | - | 2 | d | 3 | 2 | - | 2 | e | 4 | 2 | - | 2 | f | 5 | 2 | - | 2 | g | 6 | 2 | - | 2 | h | 6 | 2 | - | 2 | i | 6 | 2 | - | 2 | j | 6 | 2 | - +---+---+-----+-----------------------------------------------------------------------------------------------------------------------+ - "###); + insta::assert_snapshot!(batches_to_string(&actual), @r" + +---+---+-----+-----------------------------------------------------------------------------------------------------------------------+ + | x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW | + +---+---+-----+-----------------------------------------------------------------------------------------------------------------------+ + | 1 | a | 0 | 1 | + | 1 | b | 1 | 1 | + | 1 | c | 2 | 1 | + | 2 | d | 3 | 2 | + | 2 | e | 4 | 2 | + | 2 | f | 5 | 2 | + | 2 | g | 6 | 2 | + | 2 | h | 6 | 2 | + | 2 | i | 6 | 2 
| + | 2 | j | 6 | 2 | + +---+---+-----+-----------------------------------------------------------------------------------------------------------------------+ + "); // evaluated on two distinct batches assert_eq!(test_state.evaluate_all_called(), 2); @@ -126,10 +128,12 @@ async fn test_deregister_udwf() -> Result<()> { OddCounter::register(&mut ctx, Arc::clone(&test_state)); assert!(ctx.state().window_functions().contains_key("odd_counter")); + assert!(datafusion_execution::FunctionRegistry::udwfs(&ctx).contains("odd_counter")); ctx.deregister_udwf("odd_counter"); assert!(!ctx.state().window_functions().contains_key("odd_counter")); + assert!(!datafusion_execution::FunctionRegistry::udwfs(&ctx).contains("odd_counter")); Ok(()) } @@ -143,22 +147,22 @@ async fn test_udwf_with_alias() { .await .unwrap(); - insta::assert_snapshot!(batches_to_string(&actual), @r###" - +---+---+-----+-----------------------------------------------------------------------------------------------------------------------+ - | x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW | - +---+---+-----+-----------------------------------------------------------------------------------------------------------------------+ - | 1 | a | 0 | 1 | - | 1 | b | 1 | 1 | - | 1 | c | 2 | 1 | - | 2 | d | 3 | 2 | - | 2 | e | 4 | 2 | - | 2 | f | 5 | 2 | - | 2 | g | 6 | 2 | - | 2 | h | 6 | 2 | - | 2 | i | 6 | 2 | - | 2 | j | 6 | 2 | - +---+---+-----+-----------------------------------------------------------------------------------------------------------------------+ - "###); + insta::assert_snapshot!(batches_to_string(&actual), @r" + +---+---+-----+--------------------------+ + | x | y | val | odd_counter_alias(t.val) | + +---+---+-----+--------------------------+ + | 1 | a | 0 | 1 | + | 1 | b | 1 | 1 | + | 1 | c | 2 | 1 | + | 2 | d | 3 | 2 | + | 2 | e | 4 | 2 | + | 2 | f | 5 | 2 | + | 2 | g | 6 | 2 | + | 2 | h | 6 | 2 | + | 2 | i 
| 6 | 2 | + | 2 | j | 6 | 2 | + +---+---+-----+--------------------------+ + "); } /// Basic user defined window function with bounded window @@ -170,22 +174,22 @@ async fn test_udwf_bounded_window_ignores_frame() { // Since the UDWF doesn't say it needs the window frame, the frame is ignored let actual = execute(&ctx, BOUNDED_WINDOW_QUERY).await.unwrap(); - insta::assert_snapshot!(batches_to_string(&actual), @r###" - +---+---+-----+--------------------------------------------------------------------------------------------------------------+ - | x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING | - +---+---+-----+--------------------------------------------------------------------------------------------------------------+ - | 1 | a | 0 | 1 | - | 1 | b | 1 | 1 | - | 1 | c | 2 | 1 | - | 2 | d | 3 | 2 | - | 2 | e | 4 | 2 | - | 2 | f | 5 | 2 | - | 2 | g | 6 | 2 | - | 2 | h | 6 | 2 | - | 2 | i | 6 | 2 | - | 2 | j | 6 | 2 | - +---+---+-----+--------------------------------------------------------------------------------------------------------------+ - "###); + insta::assert_snapshot!(batches_to_string(&actual), @r" + +---+---+-----+--------------------------------------------------------------------------------------------------------------+ + | x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING | + +---+---+-----+--------------------------------------------------------------------------------------------------------------+ + | 1 | a | 0 | 1 | + | 1 | b | 1 | 1 | + | 1 | c | 2 | 1 | + | 2 | d | 3 | 2 | + | 2 | e | 4 | 2 | + | 2 | f | 5 | 2 | + | 2 | g | 6 | 2 | + | 2 | h | 6 | 2 | + | 2 | i | 6 | 2 | + | 2 | j | 6 | 2 | + +---+---+-----+--------------------------------------------------------------------------------------------------------------+ + "); // evaluated on 2 distinct batches (when x=1 and x=2) 
assert_eq!(test_state.evaluate_called(), 0); @@ -200,22 +204,22 @@ async fn test_udwf_bounded_window() { let actual = execute(&ctx, BOUNDED_WINDOW_QUERY).await.unwrap(); - insta::assert_snapshot!(batches_to_string(&actual), @r###" - +---+---+-----+--------------------------------------------------------------------------------------------------------------+ - | x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING | - +---+---+-----+--------------------------------------------------------------------------------------------------------------+ - | 1 | a | 0 | 1 | - | 1 | b | 1 | 1 | - | 1 | c | 2 | 1 | - | 2 | d | 3 | 1 | - | 2 | e | 4 | 2 | - | 2 | f | 5 | 1 | - | 2 | g | 6 | 1 | - | 2 | h | 6 | 0 | - | 2 | i | 6 | 0 | - | 2 | j | 6 | 0 | - +---+---+-----+--------------------------------------------------------------------------------------------------------------+ - "###); + insta::assert_snapshot!(batches_to_string(&actual), @r" + +---+---+-----+--------------------------------------------------------------------------------------------------------------+ + | x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING | + +---+---+-----+--------------------------------------------------------------------------------------------------------------+ + | 1 | a | 0 | 1 | + | 1 | b | 1 | 1 | + | 1 | c | 2 | 1 | + | 2 | d | 3 | 1 | + | 2 | e | 4 | 2 | + | 2 | f | 5 | 1 | + | 2 | g | 6 | 1 | + | 2 | h | 6 | 0 | + | 2 | i | 6 | 0 | + | 2 | j | 6 | 0 | + +---+---+-----+--------------------------------------------------------------------------------------------------------------+ + "); // Evaluate is called for each input rows assert_eq!(test_state.evaluate_called(), 10); @@ -232,22 +236,22 @@ async fn test_stateful_udwf() { let actual = execute(&ctx, UNBOUNDED_WINDOW_QUERY).await.unwrap(); - insta::assert_snapshot!(batches_to_string(&actual), 
@r###" - +---+---+-----+-----------------------------------------------------------------------------------------------------------------------+ - | x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW | - +---+---+-----+-----------------------------------------------------------------------------------------------------------------------+ - | 1 | a | 0 | 0 | - | 1 | b | 1 | 1 | - | 1 | c | 2 | 1 | - | 2 | d | 3 | 1 | - | 2 | e | 4 | 1 | - | 2 | f | 5 | 2 | - | 2 | g | 6 | 2 | - | 2 | h | 6 | 2 | - | 2 | i | 6 | 2 | - | 2 | j | 6 | 2 | - +---+---+-----+-----------------------------------------------------------------------------------------------------------------------+ - "###); + insta::assert_snapshot!(batches_to_string(&actual), @r" + +---+---+-----+-----------------------------------------------------------------------------------------------------------------------+ + | x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW | + +---+---+-----+-----------------------------------------------------------------------------------------------------------------------+ + | 1 | a | 0 | 0 | + | 1 | b | 1 | 1 | + | 1 | c | 2 | 1 | + | 2 | d | 3 | 1 | + | 2 | e | 4 | 1 | + | 2 | f | 5 | 2 | + | 2 | g | 6 | 2 | + | 2 | h | 6 | 2 | + | 2 | i | 6 | 2 | + | 2 | j | 6 | 2 | + +---+---+-----+-----------------------------------------------------------------------------------------------------------------------+ + "); assert_eq!(test_state.evaluate_called(), 10); assert_eq!(test_state.evaluate_all_called(), 0); @@ -263,22 +267,22 @@ async fn test_stateful_udwf_bounded_window() { let actual = execute(&ctx, BOUNDED_WINDOW_QUERY).await.unwrap(); - insta::assert_snapshot!(batches_to_string(&actual), @r###" - 
+---+---+-----+--------------------------------------------------------------------------------------------------------------+ - | x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING | - +---+---+-----+--------------------------------------------------------------------------------------------------------------+ - | 1 | a | 0 | 1 | - | 1 | b | 1 | 1 | - | 1 | c | 2 | 1 | - | 2 | d | 3 | 1 | - | 2 | e | 4 | 2 | - | 2 | f | 5 | 1 | - | 2 | g | 6 | 1 | - | 2 | h | 6 | 0 | - | 2 | i | 6 | 0 | - | 2 | j | 6 | 0 | - +---+---+-----+--------------------------------------------------------------------------------------------------------------+ - "###); + insta::assert_snapshot!(batches_to_string(&actual), @r" + +---+---+-----+--------------------------------------------------------------------------------------------------------------+ + | x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING | + +---+---+-----+--------------------------------------------------------------------------------------------------------------+ + | 1 | a | 0 | 1 | + | 1 | b | 1 | 1 | + | 1 | c | 2 | 1 | + | 2 | d | 3 | 1 | + | 2 | e | 4 | 2 | + | 2 | f | 5 | 1 | + | 2 | g | 6 | 1 | + | 2 | h | 6 | 0 | + | 2 | i | 6 | 0 | + | 2 | j | 6 | 0 | + +---+---+-----+--------------------------------------------------------------------------------------------------------------+ + "); // Evaluate and update_state is called for each input row assert_eq!(test_state.evaluate_called(), 10); @@ -293,22 +297,22 @@ async fn test_udwf_query_include_rank() { let actual = execute(&ctx, UNBOUNDED_WINDOW_QUERY).await.unwrap(); - insta::assert_snapshot!(batches_to_string(&actual), @r###" - +---+---+-----+-----------------------------------------------------------------------------------------------------------------------+ - | x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER 
BY [t.y ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW | - +---+---+-----+-----------------------------------------------------------------------------------------------------------------------+ - | 1 | a | 0 | 3 | - | 1 | b | 1 | 2 | - | 1 | c | 2 | 1 | - | 2 | d | 3 | 7 | - | 2 | e | 4 | 6 | - | 2 | f | 5 | 5 | - | 2 | g | 6 | 4 | - | 2 | h | 6 | 3 | - | 2 | i | 6 | 2 | - | 2 | j | 6 | 1 | - +---+---+-----+-----------------------------------------------------------------------------------------------------------------------+ - "###); + insta::assert_snapshot!(batches_to_string(&actual), @r" + +---+---+-----+-----------------------------------------------------------------------------------------------------------------------+ + | x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW | + +---+---+-----+-----------------------------------------------------------------------------------------------------------------------+ + | 1 | a | 0 | 3 | + | 1 | b | 1 | 2 | + | 1 | c | 2 | 1 | + | 2 | d | 3 | 7 | + | 2 | e | 4 | 6 | + | 2 | f | 5 | 5 | + | 2 | g | 6 | 4 | + | 2 | h | 6 | 3 | + | 2 | i | 6 | 2 | + | 2 | j | 6 | 1 | + +---+---+-----+-----------------------------------------------------------------------------------------------------------------------+ + "); assert_eq!(test_state.evaluate_called(), 0); assert_eq!(test_state.evaluate_all_called(), 0); @@ -324,22 +328,22 @@ async fn test_udwf_bounded_query_include_rank() { let actual = execute(&ctx, BOUNDED_WINDOW_QUERY).await.unwrap(); - insta::assert_snapshot!(batches_to_string(&actual), @r###" - +---+---+-----+--------------------------------------------------------------------------------------------------------------+ - | x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING | - 
+---+---+-----+--------------------------------------------------------------------------------------------------------------+ - | 1 | a | 0 | 3 | - | 1 | b | 1 | 2 | - | 1 | c | 2 | 1 | - | 2 | d | 3 | 7 | - | 2 | e | 4 | 6 | - | 2 | f | 5 | 5 | - | 2 | g | 6 | 4 | - | 2 | h | 6 | 3 | - | 2 | i | 6 | 2 | - | 2 | j | 6 | 1 | - +---+---+-----+--------------------------------------------------------------------------------------------------------------+ - "###); + insta::assert_snapshot!(batches_to_string(&actual), @r" + +---+---+-----+--------------------------------------------------------------------------------------------------------------+ + | x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING | + +---+---+-----+--------------------------------------------------------------------------------------------------------------+ + | 1 | a | 0 | 3 | + | 1 | b | 1 | 2 | + | 1 | c | 2 | 1 | + | 2 | d | 3 | 7 | + | 2 | e | 4 | 6 | + | 2 | f | 5 | 5 | + | 2 | g | 6 | 4 | + | 2 | h | 6 | 3 | + | 2 | i | 6 | 2 | + | 2 | j | 6 | 1 | + +---+---+-----+--------------------------------------------------------------------------------------------------------------+ + "); assert_eq!(test_state.evaluate_called(), 0); assert_eq!(test_state.evaluate_all_called(), 0); @@ -357,22 +361,22 @@ async fn test_udwf_bounded_window_returns_null() { let actual = execute(&ctx, BOUNDED_WINDOW_QUERY).await.unwrap(); - insta::assert_snapshot!(batches_to_string(&actual), @r###" - +---+---+-----+--------------------------------------------------------------------------------------------------------------+ - | x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING | - +---+---+-----+--------------------------------------------------------------------------------------------------------------+ - | 1 | a | 0 | 1 | - | 1 | b | 1 | 1 | - | 1 | c | 2 | 1 | - | 2 | d | 3 
| 1 | - | 2 | e | 4 | 2 | - | 2 | f | 5 | 1 | - | 2 | g | 6 | 1 | - | 2 | h | 6 | | - | 2 | i | 6 | | - | 2 | j | 6 | | - +---+---+-----+--------------------------------------------------------------------------------------------------------------+ - "###); + insta::assert_snapshot!(batches_to_string(&actual), @r" + +---+---+-----+--------------------------------------------------------------------------------------------------------------+ + | x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING | + +---+---+-----+--------------------------------------------------------------------------------------------------------------+ + | 1 | a | 0 | 1 | + | 1 | b | 1 | 1 | + | 1 | c | 2 | 1 | + | 2 | d | 3 | 1 | + | 2 | e | 4 | 2 | + | 2 | f | 5 | 1 | + | 2 | g | 6 | 1 | + | 2 | h | 6 | | + | 2 | i | 6 | | + | 2 | j | 6 | | + +---+---+-----+--------------------------------------------------------------------------------------------------------------+ + "); // Evaluate is called for each input rows assert_eq!(test_state.evaluate_called(), 10); @@ -522,20 +526,20 @@ impl OddCounter { } fn register(ctx: &mut SessionContext, test_state: Arc) { - #[derive(Debug, Clone)] + #[derive(Debug, Clone, PartialEq, Eq, Hash)] struct SimpleWindowUDF { signature: Signature, - test_state: Arc, + test_state: PtrEq>, aliases: Vec, } impl SimpleWindowUDF { fn new(test_state: Arc) -> Self { let signature = - Signature::exact(vec![DataType::Float64], Volatility::Immutable); + Signature::exact(vec![DataType::Int64], Volatility::Immutable); Self { signature, - test_state, + test_state: test_state.into(), aliases: vec!["odd_counter_alias".to_string()], } } @@ -568,6 +572,10 @@ impl OddCounter { fn field(&self, field_args: WindowUDFFieldArgs) -> Result { Ok(Field::new(field_args.name(), DataType::Int64, true).into()) } + + fn limit_effect(&self, _args: &[Arc]) -> LimitEffect { + LimitEffect::Unknown + } } 
ctx.register_udwf(WindowUDF::from(SimpleWindowUDF::new(test_state))) @@ -607,7 +615,9 @@ impl PartitionEvaluator for OddCounter { ranks_in_partition: &[Range], ) -> Result { self.test_state.inc_evaluate_all_with_rank_called(); - println!("evaluate_all_with_rank, values: {num_rows:#?}, ranks_in_partitions: {ranks_in_partition:?}"); + println!( + "evaluate_all_with_rank, values: {num_rows:#?}, ranks_in_partitions: {ranks_in_partition:?}" + ); // when evaluating with ranks, just return the inverse rank instead let array: Int64Array = ranks_in_partition .iter() @@ -643,7 +653,7 @@ fn odd_count_arr(arr: &Int64Array, num_rows: usize) -> ArrayRef { Arc::new(array) } -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] struct VariadicWindowUDF { signature: Signature, } @@ -687,6 +697,10 @@ impl WindowUDFImpl for VariadicWindowUDF { fn field(&self, _: WindowUDFFieldArgs) -> Result { unimplemented!("unnecessary for testing"); } + + fn limit_effect(&self, _args: &[Arc]) -> LimitEffect { + LimitEffect::Unknown + } } #[test] @@ -765,6 +779,31 @@ struct MetadataBasedWindowUdf { metadata: HashMap, } +impl PartialEq for MetadataBasedWindowUdf { + fn eq(&self, other: &Self) -> bool { + let Self { + name, + signature, + metadata, + } = self; + name == &other.name + && signature == &other.signature + && metadata == &other.metadata + } +} +impl Eq for MetadataBasedWindowUdf {} +impl Hash for MetadataBasedWindowUdf { + fn hash(&self, state: &mut H) { + let Self { + name, + signature, + metadata: _, // unhashable + } = self; + name.hash(state); + signature.hash(state); + } +} + impl MetadataBasedWindowUdf { fn new(metadata: HashMap) -> Self { // The name we return must be unique. 
Otherwise we will not call distinct @@ -815,6 +854,10 @@ impl WindowUDFImpl for MetadataBasedWindowUdf { .with_metadata(self.metadata.clone()) .into()) } + + fn limit_effect(&self, _args: &[Arc]) -> LimitEffect { + LimitEffect::Unknown + } } #[derive(Debug)] diff --git a/datafusion/datasource-arrow/Cargo.toml b/datafusion/datasource-arrow/Cargo.toml new file mode 100644 index 0000000000000..fbadc8708ca69 --- /dev/null +++ b/datafusion/datasource-arrow/Cargo.toml @@ -0,0 +1,67 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +[package] +name = "datafusion-datasource-arrow" +description = "datafusion-datasource-arrow" +readme = "README.md" +authors.workspace = true +edition.workspace = true +homepage.workspace = true +license.workspace = true +repository.workspace = true +rust-version.workspace = true +version.workspace = true + +[package.metadata.docs.rs] +all-features = true + +[dependencies] +arrow = { workspace = true } +arrow-ipc = { workspace = true } +async-trait = { workspace = true } +bytes = { workspace = true } +datafusion-common = { workspace = true, features = ["object_store"] } +datafusion-common-runtime = { workspace = true } +datafusion-datasource = { workspace = true } +datafusion-execution = { workspace = true } +datafusion-expr = { workspace = true } +datafusion-physical-expr-common = { workspace = true } +datafusion-physical-plan = { workspace = true } +datafusion-session = { workspace = true } +futures = { workspace = true } +itertools = { workspace = true } +object_store = { workspace = true } +tokio = { workspace = true } + +[dev-dependencies] +chrono = { workspace = true } + +# Note: add additional linter rules in lib.rs. +# Rust does not support workspace + new linter rules in subcrates yet +# https://github.com/rust-lang/cargo/issues/13157 +[lints] +workspace = true + +[lib] +name = "datafusion_datasource_arrow" +path = "src/mod.rs" + +[features] +compression = [ + "arrow-ipc/zstd", +] diff --git a/datafusion/datasource-arrow/LICENSE.txt b/datafusion/datasource-arrow/LICENSE.txt new file mode 100644 index 0000000000000..d74c6b599d2ae --- /dev/null +++ b/datafusion/datasource-arrow/LICENSE.txt @@ -0,0 +1,212 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. 
+ + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of 
the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + +This project includes code from Apache Aurora. + +* dev/release/{release,changelog,release-candidate} are based on the scripts from + Apache Aurora + +Copyright: 2016 The Apache Software Foundation. +Home page: https://aurora.apache.org/ +License: http://www.apache.org/licenses/LICENSE-2.0 diff --git a/datafusion/datasource-arrow/NOTICE.txt b/datafusion/datasource-arrow/NOTICE.txt new file mode 100644 index 0000000000000..0bd2d52368fea --- /dev/null +++ b/datafusion/datasource-arrow/NOTICE.txt @@ -0,0 +1,5 @@ +Apache DataFusion +Copyright 2019-2026 The Apache Software Foundation + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). 
diff --git a/datafusion/datasource-arrow/README.md b/datafusion/datasource-arrow/README.md new file mode 100644 index 0000000000000..9901b52105dd4 --- /dev/null +++ b/datafusion/datasource-arrow/README.md @@ -0,0 +1,34 @@ + + +# Apache DataFusion Arrow DataSource + +[Apache DataFusion] is an extensible query execution framework, written in Rust, that uses [Apache Arrow] as its in-memory format. + +This crate is a submodule of DataFusion that defines an Arrow-based file source. +It works with files following the [Arrow IPC format]. + +Most projects should use the [`datafusion`] crate directly, which re-exports +this module. If you are already using the [`datafusion`] crate, there is no +reason to use this crate directly in your project as well. + +[apache arrow]: https://arrow.apache.org/ +[apache datafusion]: https://datafusion.apache.org/ +[`datafusion`]: https://crates.io/crates/datafusion +[arrow ipc format]: https://arrow.apache.org/docs/format/Columnar.html#ipc-file-format diff --git a/datafusion/datasource-arrow/src/file_format.rs b/datafusion/datasource-arrow/src/file_format.rs new file mode 100644 index 0000000000000..f60bce3249935 --- /dev/null +++ b/datafusion/datasource-arrow/src/file_format.rs @@ -0,0 +1,782 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ArrowFormat`]: Apache Arrow [`FileFormat`] abstractions +//! +//! Works with files following the [Arrow IPC format](https://arrow.apache.org/docs/format/Columnar.html#ipc-file-format) + +use std::any::Any; +use std::collections::HashMap; +use std::fmt::{self, Debug}; +use std::io::{Seek, SeekFrom}; +use std::sync::Arc; + +use arrow::datatypes::{Schema, SchemaRef}; +use arrow::error::ArrowError; +use arrow::ipc::convert::fb_to_schema; +use arrow::ipc::reader::{FileReader, StreamReader}; +use arrow::ipc::writer::IpcWriteOptions; +use arrow::ipc::{CompressionType, root_as_message}; +use datafusion_common::error::Result; +use datafusion_common::parsers::CompressionTypeVariant; +use datafusion_common::{ + DEFAULT_ARROW_EXTENSION, DataFusionError, GetExt, Statistics, + internal_datafusion_err, not_impl_err, +}; +use datafusion_common_runtime::{JoinSet, SpawnedTask}; +use datafusion_datasource::TableSchema; +use datafusion_datasource::display::FileGroupDisplay; +use datafusion_datasource::file::FileSource; +use datafusion_datasource::file_scan_config::{FileScanConfig, FileScanConfigBuilder}; +use datafusion_datasource::sink::{DataSink, DataSinkExec}; +use datafusion_datasource::write::{ + ObjectWriterBuilder, SharedBuffer, get_writer_schema, +}; +use datafusion_execution::{SendableRecordBatchStream, TaskContext}; +use datafusion_expr::dml::InsertOp; +use datafusion_physical_expr_common::sort_expr::LexRequirement; + +use crate::source::ArrowSource; +use async_trait::async_trait; +use bytes::Bytes; +use datafusion_datasource::file_compression_type::FileCompressionType; +use datafusion_datasource::file_format::{FileFormat, FileFormatFactory}; +use datafusion_datasource::file_sink_config::{FileSink, FileSinkConfig}; +use datafusion_datasource::source::DataSourceExec; +use datafusion_datasource::write::demux::DemuxedStreamReceiver; +use 
datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan}; +use datafusion_session::Session; +use futures::StreamExt; +use futures::stream::BoxStream; +use object_store::{ + GetOptions, GetRange, GetResultPayload, ObjectMeta, ObjectStore, ObjectStoreExt, + path::Path, +}; +use tokio::io::AsyncWriteExt; + +/// Initial writing buffer size. Note this is just a size hint for efficiency. It +/// will grow beyond the set value if needed. +const INITIAL_BUFFER_BYTES: usize = 1048576; + +/// If the buffered Arrow data exceeds this size, it is flushed to object store +const BUFFER_FLUSH_BYTES: usize = 1024000; + +/// Factory struct used to create [`ArrowFormat`] +#[derive(Default, Debug)] +pub struct ArrowFormatFactory; + +impl ArrowFormatFactory { + /// Creates an instance of [ArrowFormatFactory] + pub fn new() -> Self { + Self {} + } +} + +impl FileFormatFactory for ArrowFormatFactory { + fn create( + &self, + _state: &dyn Session, + _format_options: &HashMap, + ) -> Result> { + Ok(Arc::new(ArrowFormat)) + } + + fn default(&self) -> Arc { + Arc::new(ArrowFormat) + } + + fn as_any(&self) -> &dyn Any { + self + } +} + +impl GetExt for ArrowFormatFactory { + fn get_ext(&self) -> String { + // Removes the leading dot, i.e. ".arrow" -> "arrow" + DEFAULT_ARROW_EXTENSION[1..].to_string() + } +} + +/// Arrow [`FileFormat`] implementation. +#[derive(Default, Debug)] +pub struct ArrowFormat; + +#[async_trait] +impl FileFormat for ArrowFormat { + fn as_any(&self) -> &dyn Any { + self + } + + fn get_ext(&self) -> String { + ArrowFormatFactory::new().get_ext() + } + + fn get_ext_with_compression( + &self, + file_compression_type: &FileCompressionType, + ) -> Result { + let ext = self.get_ext(); + match file_compression_type.get_variant() { + CompressionTypeVariant::UNCOMPRESSED => Ok(ext), + _ => Err(internal_datafusion_err!( + "Arrow FileFormat does not support compression." 
+ )), + } + } + + fn compression_type(&self) -> Option { + None + } + + async fn infer_schema( + &self, + _state: &dyn Session, + store: &Arc, + objects: &[ObjectMeta], + ) -> Result { + let mut schemas = vec![]; + for object in objects { + let r = store.as_ref().get(&object.location).await?; + let schema = match r.payload { + #[cfg(not(target_arch = "wasm32"))] + GetResultPayload::File(mut file, _) => { + match FileReader::try_new(&mut file, None) { + Ok(reader) => reader.schema(), + Err(file_error) => { + // not in the file format, but FileReader read some bytes + // while trying to parse the file and so we need to rewind + // it to the beginning of the file + file.seek(SeekFrom::Start(0))?; + match StreamReader::try_new(&mut file, None) { + Ok(reader) => reader.schema(), + Err(stream_error) => { + return Err(internal_datafusion_err!( + "Failed to parse Arrow file as either file format or stream format. File format error: {file_error}. Stream format error: {stream_error}" + )); + } + } + } + } + } + GetResultPayload::Stream(stream) => infer_stream_schema(stream).await?, + }; + schemas.push(schema.as_ref().clone()); + } + let merged_schema = Schema::try_merge(schemas)?; + Ok(Arc::new(merged_schema)) + } + + async fn infer_stats( + &self, + _state: &dyn Session, + _store: &Arc, + table_schema: SchemaRef, + _object: &ObjectMeta, + ) -> Result { + Ok(Statistics::new_unknown(&table_schema)) + } + + async fn create_physical_plan( + &self, + state: &dyn Session, + conf: FileScanConfig, + ) -> Result> { + let object_store = state.runtime_env().object_store(&conf.object_store_url)?; + let object_location = &conf + .file_groups + .first() + .ok_or_else(|| internal_datafusion_err!("No files found in file group"))? + .files() + .first() + .ok_or_else(|| internal_datafusion_err!("No files found in file group"))? 
+ .object_meta + .location; + + let table_schema = TableSchema::new( + Arc::clone(conf.file_schema()), + conf.table_partition_cols().clone(), + ); + + let mut source: Arc = + match is_object_in_arrow_ipc_file_format(object_store, object_location).await + { + Ok(true) => Arc::new(ArrowSource::new_file_source(table_schema)), + Ok(false) => Arc::new(ArrowSource::new_stream_file_source(table_schema)), + Err(e) => Err(e)?, + }; + + // Preserve projection from the original file source + if let Some(projection) = conf.file_source.projection() + && let Some(new_source) = source.try_pushdown_projection(projection)? + { + source = new_source; + } + + let config = FileScanConfigBuilder::from(conf) + .with_source(source) + .build(); + + Ok(DataSourceExec::from_data_source(config)) + } + + async fn create_writer_physical_plan( + &self, + input: Arc, + _state: &dyn Session, + conf: FileSinkConfig, + order_requirements: Option, + ) -> Result> { + if conf.insert_op != InsertOp::Append { + return not_impl_err!("Overwrites are not implemented yet for Arrow format"); + } + + let sink = Arc::new(ArrowFileSink::new(conf)); + + Ok(Arc::new(DataSinkExec::new(input, sink, order_requirements)) as _) + } + + fn file_source(&self, table_schema: TableSchema) -> Arc { + Arc::new(ArrowSource::new_file_source(table_schema)) + } +} + +/// Implements [`FileSink`] for Arrow IPC files +struct ArrowFileSink { + config: FileSinkConfig, +} + +impl ArrowFileSink { + fn new(config: FileSinkConfig) -> Self { + Self { config } + } +} + +#[async_trait] +impl FileSink for ArrowFileSink { + fn config(&self) -> &FileSinkConfig { + &self.config + } + + async fn spawn_writer_tasks_and_join( + &self, + context: &Arc, + demux_task: SpawnedTask>, + mut file_stream_rx: DemuxedStreamReceiver, + object_store: Arc, + ) -> Result { + let mut file_write_tasks: JoinSet> = + JoinSet::new(); + + let ipc_options = + IpcWriteOptions::try_new(64, false, arrow_ipc::MetadataVersion::V5)? 
+ .try_with_compression(Some(CompressionType::LZ4_FRAME))?; + while let Some((path, mut rx)) = file_stream_rx.recv().await { + let shared_buffer = SharedBuffer::new(INITIAL_BUFFER_BYTES); + let mut arrow_writer = arrow_ipc::writer::FileWriter::try_new_with_options( + shared_buffer.clone(), + &get_writer_schema(&self.config), + ipc_options.clone(), + )?; + let mut object_store_writer = ObjectWriterBuilder::new( + FileCompressionType::UNCOMPRESSED, + &path, + Arc::clone(&object_store), + ) + .with_buffer_size(Some( + context + .session_config() + .options() + .execution + .objectstore_writer_buffer_size, + )) + .build()?; + file_write_tasks.spawn(async move { + let mut row_count = 0; + while let Some(batch) = rx.recv().await { + row_count += batch.num_rows(); + arrow_writer.write(&batch)?; + let mut buff_to_flush = shared_buffer.buffer.try_lock().unwrap(); + if buff_to_flush.len() > BUFFER_FLUSH_BYTES { + object_store_writer + .write_all(buff_to_flush.as_slice()) + .await?; + buff_to_flush.clear(); + } + } + arrow_writer.finish()?; + let final_buff = shared_buffer.buffer.try_lock().unwrap(); + + object_store_writer.write_all(final_buff.as_slice()).await?; + object_store_writer.shutdown().await?; + Ok(row_count) + }); + } + + let mut row_count = 0; + while let Some(result) = file_write_tasks.join_next().await { + match result { + Ok(r) => { + row_count += r?; + } + Err(e) => { + if e.is_panic() { + std::panic::resume_unwind(e.into_panic()); + } else { + unreachable!(); + } + } + } + } + + demux_task + .join_unwind() + .await + .map_err(|e| DataFusionError::ExecutionJoin(Box::new(e)))??; + Ok(row_count as u64) + } +} + +impl Debug for ArrowFileSink { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ArrowFileSink").finish() + } +} + +impl DisplayAs for ArrowFileSink { + fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, 
"ArrowFileSink(file_groups=",)?; + FileGroupDisplay(&self.config.file_group).fmt_as(t, f)?; + write!(f, ")") + } + DisplayFormatType::TreeRender => { + writeln!(f, "format: arrow")?; + write!(f, "file={}", &self.config.original_url) + } + } + } +} + +#[async_trait] +impl DataSink for ArrowFileSink { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> &SchemaRef { + self.config.output_schema() + } + + async fn write_all( + &self, + data: SendableRecordBatchStream, + context: &Arc, + ) -> Result { + FileSink::write_all(self, data, context).await + } +} + +// Custom implementation of inferring schema. Should eventually be moved upstream to arrow-rs. +// See + +const ARROW_MAGIC: [u8; 6] = [b'A', b'R', b'R', b'O', b'W', b'1']; +const CONTINUATION_MARKER: [u8; 4] = [0xff; 4]; + +async fn infer_stream_schema( + mut stream: BoxStream<'static, object_store::Result>, +) -> Result { + // IPC streaming format. + // See https://arrow.apache.org/docs/format/Columnar.html#ipc-streaming-format + // + // + // + // ... + // + // + // ... + // + // ... + // + // ... + // + // + + // The streaming format is made up of a sequence of encapsulated messages. + // See https://arrow.apache.org/docs/format/Columnar.html#encapsulated-message-format + // + // (added in v0.15.0) + // + // + // + // + // + // The first message is the schema. + + // IPC file format is a wrapper around the streaming format with indexing information. + // See https://arrow.apache.org/docs/format/Columnar.html#ipc-file-format + // + // + // + // + //