diff --git a/bolt/functions/flinksql/ElementAt.cpp b/bolt/functions/flinksql/ElementAt.cpp index d4fcfa0c7..26f3a1f24 100644 --- a/bolt/functions/flinksql/ElementAt.cpp +++ b/bolt/functions/flinksql/ElementAt.cpp @@ -47,8 +47,8 @@ class FlinkElementAtFunction : public SubscriptImpl< /* isElementAt */ true, /* nullOnNonConstantInvalidIndex */ true> { public: - explicit FlinkElementAtFunction(bool allowCaching) - : SubscriptImpl(allowCaching) {} + FlinkElementAtFunction(bool allowCaching, bool constantIndexInvalidThrows) + : SubscriptImpl(allowCaching, constantIndexInvalidThrows) {} }; } // namespace @@ -60,13 +60,15 @@ void registerFlinkElementAtFunction(const std::string& name) { [](const std::string&, const std::vector& inputArgs, const bolt::core::QueryConfig& config) { - static const auto kSubscriptStateLess = - std::make_shared(false); + const bool constantIndexInvalidThrows = + inputArgs.size() > 1 && inputArgs[1].constantValue != nullptr; if (inputArgs[0].type->isArray()) { - return kSubscriptStateLess; + return std::make_shared( + false, constantIndexInvalidThrows); } else { return std::make_shared( - config.isExpressionEvaluationCacheEnabled()); + config.isExpressionEvaluationCacheEnabled(), + constantIndexInvalidThrows); } }); } diff --git a/bolt/functions/flinksql/tests/ElementAtTest.cpp b/bolt/functions/flinksql/tests/ElementAtTest.cpp index 8a69110a3..fc79829d5 100644 --- a/bolt/functions/flinksql/tests/ElementAtTest.cpp +++ b/bolt/functions/flinksql/tests/ElementAtTest.cpp @@ -33,8 +33,6 @@ class ElementAtTest : public FlinkFunctionBaseTest {}; /// 4. Index out of bounds (constant or non-constant) → returns NULL. /// 5. Valid index → returns the corresponding element. -// ── Constant index: valid accesses ────────────────────────────────────────── - TEST_F(ElementAtTest, constantValidIndex) { auto arrayVector = makeArrayVector({{10, 20, 30}}); EXPECT_EQ( @@ -49,8 +47,6 @@ TEST_F(ElementAtTest, constantValidIndex) { 30); } -// ── Constant index: index == 0 → throws ───────────────────────────────────── - TEST_F(ElementAtTest, constantZeroIndexThrows) { auto arrayVector = makeArrayVector({{10, 20, 30}}); BOLT_ASSERT_THROW( @@ -59,7 +55,17 @@ TEST_F(ElementAtTest, constantZeroIndexThrows) { "SQL array indices start at 1"); } -// ── Constant index: negative index → throws ────────────────────────────────── +TEST_F(ElementAtTest, constantZeroIndexThrowsForAllRows) { + auto arrayVector = makeArrayVector({ + {10, 20, 30}, + {40, 50, 60}, + {70, 80, 90}, + }); + BOLT_ASSERT_THROW( + evaluate>( + "element_at(C0, 0)", makeRowVector({arrayVector})), + "SQL array indices start at 1"); +} TEST_F(ElementAtTest, constantNegativeIndexThrows) { auto arrayVector = makeArrayVector({{10, 20, 30}}); @@ -69,8 +75,6 @@ TEST_F(ElementAtTest, constantNegativeIndexThrows) { "SQL array indices start at 1"); } -// ── Constant index: out-of-bounds → returns NULL ───────────────────────────── - TEST_F(ElementAtTest, constantOutOfBoundsReturnsNull) { auto arrayVector = makeArrayVector({{10, 20, 30}}); auto result = evaluate>( @@ -78,8 +82,6 @@ TEST_F(ElementAtTest, constantOutOfBoundsReturnsNull) { EXPECT_TRUE(result->isNullAt(0)); } -// ── Non-constant index: zero → returns NULL ────────────────────────────────── - TEST_F(ElementAtTest, nonConstantZeroIndexReturnsNull) { // Build a flat array column and an index column where index = 0 at runtime. auto arrayVector = makeArrayVector({{10, 20, 30}, {40, 50}}); @@ -90,7 +92,17 @@ TEST_F(ElementAtTest, nonConstantZeroIndexReturnsNull) { EXPECT_TRUE(result->isNullAt(1)); } -// ── Non-constant index: negative → returns NULL ────────────────────────────── +TEST_F(ElementAtTest, runtimeConstantZeroIndexReturnsNull) { + auto arrayVector = makeArrayVector({{10, 20, 30}, {40, 50}}); + auto zeroIndexBase = makeFlatVector({0}); + auto indexVector = BaseVector::wrapInConstant(2, 0, zeroIndexBase); + + auto result = evaluate>( + "element_at(C0, C1)", makeRowVector({arrayVector, indexVector})); + + EXPECT_TRUE(result->isNullAt(0)); + EXPECT_TRUE(result->isNullAt(1)); +} TEST_F(ElementAtTest, nonConstantNegativeIndexReturnsNull) { auto arrayVector = makeArrayVector({{10, 20, 30}, {40, 50}}); @@ -101,8 +113,6 @@ TEST_F(ElementAtTest, nonConstantNegativeIndexReturnsNull) { EXPECT_TRUE(result->isNullAt(1)); } -// ── Non-constant index: out-of-bounds → returns NULL ───────────────────────── - TEST_F(ElementAtTest, nonConstantOutOfBoundsReturnsNull) { auto arrayVector = makeArrayVector({{10, 20, 30}}); auto indexVector = makeFlatVector({99}); @@ -111,8 +121,6 @@ TEST_F(ElementAtTest, nonConstantOutOfBoundsReturnsNull) { EXPECT_TRUE(result->isNullAt(0)); } -// ── Non-constant index: mixed valid / invalid rows ─────────────────────────── - TEST_F(ElementAtTest, nonConstantMixedRows) { auto arrayVector = makeArrayVector({{10, 20, 30}, {40, 50}, {60}}); auto indexVector = makeFlatVector({2, 0, -1}); diff --git a/bolt/functions/lib/SubscriptUtil.h b/bolt/functions/lib/SubscriptUtil.h index fa313adbc..4f2c0063e 100644 --- a/bolt/functions/lib/SubscriptUtil.h +++ b/bolt/functions/lib/SubscriptUtil.h @@ -199,8 +199,11 @@ template < bool nullOnNonConstantInvalidIndex = false> class SubscriptImpl : public exec::Subscript { public: - explicit SubscriptImpl(bool allowCaching) - : mapSubscript_(MapSubscript(allowCaching)) {} + explicit SubscriptImpl( + bool allowCaching, + bool constantIndexInvalidThrows = false) + : mapSubscript_(MapSubscript(allowCaching)), + constantIndexInvalidThrows_(constantIndexInvalidThrows) {} void apply( const SelectivityVector& rows, @@ -313,19 +316,32 @@ class SubscriptImpl : public exec::Subscript { isZeroSubscriptError, zeroBasedArrayIndex); if (isZeroSubscriptError) { - context.setErrors(rows, zeroSubscriptError()); - allFailed = true; + if constexpr (nullOnNonConstantInvalidIndex) { + if (constantIndexInvalidThrows_) { + std::rethrow_exception(zeroSubscriptError()); + } + return BaseVector::createNullConstant( + baseArray->elements()->type(), rows.end(), context.pool()); + } else { + context.setErrors(rows, zeroSubscriptError()); + allFailed = true; + } } if (!allFailed) { rows.applyToSelected([&](auto row) { const auto elementIndex = getIndex( adjustedIndex, row, rawSizes, rawOffsets, arrayIndices, context); - rawIndices[row] = elementIndex; if (elementIndex == -1) { nullsBuilder.setNull(row); + rawIndices[row] = 0; + } else { + rawIndices[row] = elementIndex; } }); + } else { + return BaseVector::createNullConstant( + baseArray->elements()->type(), rows.end(), context.pool()); } } else { rows.applyToSelected([&](auto row) { @@ -345,9 +361,11 @@ class SubscriptImpl : public exec::Subscript { } const auto elementIndex = getIndex( adjustedIndex, row, rawSizes, rawOffsets, arrayIndices, context); - rawIndices[row] = elementIndex; if (elementIndex == -1) { nullsBuilder.setNull(row); + rawIndices[row] = 0; + } else { + rawIndices[row] = elementIndex; } }); } @@ -429,6 +447,9 @@ class SubscriptImpl : public exec::Subscript { index += arraySize; } } else { + if constexpr (nullOnNonConstantInvalidIndex) { + return -1; + } context.setBoltExceptionError(row, negativeSubscriptError()); return -1; } @@ -452,6 +473,7 @@ class SubscriptImpl : public exec::Subscript { private: MapSubscript mapSubscript_; + const bool constantIndexInvalidThrows_; }; } // namespace bytedance::bolt::functions