Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions bolt/functions/flinksql/ElementAt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ class FlinkElementAtFunction : public SubscriptImpl<
/* isElementAt */ true,
/* nullOnNonConstantInvalidIndex */ true> {
public:
explicit FlinkElementAtFunction(bool allowCaching)
: SubscriptImpl(allowCaching) {}
FlinkElementAtFunction(bool allowCaching, bool constantIndexInvalidThrows)
: SubscriptImpl(allowCaching, constantIndexInvalidThrows) {}
};

} // namespace
Expand All @@ -60,13 +60,15 @@ void registerFlinkElementAtFunction(const std::string& name) {
[](const std::string&,
const std::vector<exec::VectorFunctionArg>& inputArgs,
const bolt::core::QueryConfig& config) {
static const auto kSubscriptStateLess =
std::make_shared<FlinkElementAtFunction>(false);
const bool constantIndexInvalidThrows =
inputArgs.size() > 1 && inputArgs[1].constantValue != nullptr;
if (inputArgs[0].type->isArray()) {
return kSubscriptStateLess;
return std::make_shared<FlinkElementAtFunction>(
false, constantIndexInvalidThrows);
} else {
return std::make_shared<FlinkElementAtFunction>(
config.isExpressionEvaluationCacheEnabled());
config.isExpressionEvaluationCacheEnabled(),
constantIndexInvalidThrows);
}
});
}
Expand Down
36 changes: 22 additions & 14 deletions bolt/functions/flinksql/tests/ElementAtTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@ class ElementAtTest : public FlinkFunctionBaseTest {};
/// 4. Index out of bounds (constant or non-constant) → returns NULL.
/// 5. Valid index → returns the corresponding element.

// ── Constant index: valid accesses ──────────────────────────────────────────

TEST_F(ElementAtTest, constantValidIndex) {
auto arrayVector = makeArrayVector<int64_t>({{10, 20, 30}});
EXPECT_EQ(
Expand All @@ -49,8 +47,6 @@ TEST_F(ElementAtTest, constantValidIndex) {
30);
}

// ── Constant index: index == 0 → throws ─────────────────────────────────────

TEST_F(ElementAtTest, constantZeroIndexThrows) {
auto arrayVector = makeArrayVector<int64_t>({{10, 20, 30}});
BOLT_ASSERT_THROW(
Expand All @@ -59,7 +55,17 @@ TEST_F(ElementAtTest, constantZeroIndexThrows) {
"SQL array indices start at 1");
}

// ── Constant index: negative index → throws ──────────────────────────────────
// A literal index of 0 in the expression must raise the "indices start at 1"
// error when the input batch holds several array rows, not just one.
TEST_F(ElementAtTest, constantZeroIndexThrowsForAllRows) {
  // Three rows, each a three-element array.
  auto threeRowArrays = makeArrayVector<int64_t>({
      {10, 20, 30},
      {40, 50, 60},
      {70, 80, 90},
  });

  BOLT_ASSERT_THROW(
      evaluate<SimpleVector<int64_t>>(
          "element_at(C0, 0)", makeRowVector({threeRowArrays})),
      "SQL array indices start at 1");
}

TEST_F(ElementAtTest, constantNegativeIndexThrows) {
auto arrayVector = makeArrayVector<int64_t>({{10, 20, 30}});
Expand All @@ -69,17 +75,13 @@ TEST_F(ElementAtTest, constantNegativeIndexThrows) {
"SQL array indices start at 1");
}

// ── Constant index: out-of-bounds → returns NULL ─────────────────────────────

// A literal index past the end of the array (99 against a three-element
// array) is expected to produce a NULL result row instead of an error.
TEST_F(ElementAtTest, constantOutOfBoundsReturnsNull) {
  auto singleRowArray = makeArrayVector<int64_t>({{10, 20, 30}});
  auto input = makeRowVector({singleRowArray});

  auto outOfBounds =
      evaluate<SimpleVector<int64_t>>("element_at(C0, 99)", input);

  EXPECT_TRUE(outOfBounds->isNullAt(0));
}

// ── Non-constant index: zero → returns NULL ──────────────────────────────────

TEST_F(ElementAtTest, nonConstantZeroIndexReturnsNull) {
// Build a flat array column and an index column where index = 0 at runtime.
auto arrayVector = makeArrayVector<int64_t>({{10, 20, 30}, {40, 50}});
Expand All @@ -90,7 +92,17 @@ TEST_F(ElementAtTest, nonConstantZeroIndexReturnsNull) {
EXPECT_TRUE(result->isNullAt(1));
}

// ── Non-constant index: negative → returns NULL ──────────────────────────────
// The index arrives as an input column (C1) that happens to be
// constant-encoded, rather than as a literal in the expression text. Every row
// sees index 0, and the expected outcome is NULL per row rather than an error.
TEST_F(ElementAtTest, runtimeConstantZeroIndexReturnsNull) {
  auto arrays = makeArrayVector<int64_t>({{10, 20, 30}, {40, 50}});

  // Two-row constant vector wrapping element 0 of a one-element flat vector.
  auto zeroSource = makeFlatVector<int64_t>({0});
  auto constantZeroIndex = BaseVector::wrapInConstant(2, 0, zeroSource);

  auto input = makeRowVector({arrays, constantZeroIndex});
  auto nulls = evaluate<SimpleVector<int64_t>>("element_at(C0, C1)", input);

  EXPECT_TRUE(nulls->isNullAt(0));
  EXPECT_TRUE(nulls->isNullAt(1));
}

TEST_F(ElementAtTest, nonConstantNegativeIndexReturnsNull) {
auto arrayVector = makeArrayVector<int64_t>({{10, 20, 30}, {40, 50}});
Expand All @@ -101,8 +113,6 @@ TEST_F(ElementAtTest, nonConstantNegativeIndexReturnsNull) {
EXPECT_TRUE(result->isNullAt(1));
}

// ── Non-constant index: out-of-bounds → returns NULL ─────────────────────────

TEST_F(ElementAtTest, nonConstantOutOfBoundsReturnsNull) {
auto arrayVector = makeArrayVector<int64_t>({{10, 20, 30}});
auto indexVector = makeFlatVector<int64_t>({99});
Expand All @@ -111,8 +121,6 @@ TEST_F(ElementAtTest, nonConstantOutOfBoundsReturnsNull) {
EXPECT_TRUE(result->isNullAt(0));
}

// ── Non-constant index: mixed valid / invalid rows ───────────────────────────

TEST_F(ElementAtTest, nonConstantMixedRows) {
auto arrayVector = makeArrayVector<int64_t>({{10, 20, 30}, {40, 50}, {60}});
auto indexVector = makeFlatVector<int64_t>({2, 0, -1});
Expand Down
34 changes: 28 additions & 6 deletions bolt/functions/lib/SubscriptUtil.h
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,11 @@ template <
bool nullOnNonConstantInvalidIndex = false>
class SubscriptImpl : public exec::Subscript {
public:
explicit SubscriptImpl(bool allowCaching)
: mapSubscript_(MapSubscript(allowCaching)) {}
explicit SubscriptImpl(
bool allowCaching,
bool constantIndexInvalidThrows = false)
: mapSubscript_(MapSubscript(allowCaching)),
constantIndexInvalidThrows_(constantIndexInvalidThrows) {}

void apply(
const SelectivityVector& rows,
Expand Down Expand Up @@ -313,19 +316,32 @@ class SubscriptImpl : public exec::Subscript {
isZeroSubscriptError,
zeroBasedArrayIndex);
if (isZeroSubscriptError) {
context.setErrors(rows, zeroSubscriptError());
allFailed = true;
if constexpr (nullOnNonConstantInvalidIndex) {
if (constantIndexInvalidThrows_) {
std::rethrow_exception(zeroSubscriptError());
}
return BaseVector::createNullConstant(
baseArray->elements()->type(), rows.end(), context.pool());
} else {
context.setErrors(rows, zeroSubscriptError());
allFailed = true;
}
}

if (!allFailed) {
rows.applyToSelected([&](auto row) {
const auto elementIndex = getIndex(
adjustedIndex, row, rawSizes, rawOffsets, arrayIndices, context);
rawIndices[row] = elementIndex;
if (elementIndex == -1) {
nullsBuilder.setNull(row);
rawIndices[row] = 0;
} else {
rawIndices[row] = elementIndex;
}
});
} else {
return BaseVector::createNullConstant(
baseArray->elements()->type(), rows.end(), context.pool());
}
} else {
rows.applyToSelected([&](auto row) {
Expand All @@ -345,9 +361,11 @@ class SubscriptImpl : public exec::Subscript {
}
const auto elementIndex = getIndex(
adjustedIndex, row, rawSizes, rawOffsets, arrayIndices, context);
rawIndices[row] = elementIndex;
if (elementIndex == -1) {
nullsBuilder.setNull(row);
rawIndices[row] = 0;
} else {
rawIndices[row] = elementIndex;
}
});
}
Expand Down Expand Up @@ -429,6 +447,9 @@ class SubscriptImpl : public exec::Subscript {
index += arraySize;
}
} else {
if constexpr (nullOnNonConstantInvalidIndex) {
return -1;
}
context.setBoltExceptionError(row, negativeSubscriptError());
return -1;
}
Expand All @@ -452,6 +473,7 @@ class SubscriptImpl : public exec::Subscript {

private:
MapSubscript mapSubscript_;
const bool constantIndexInvalidThrows_;
};

} // namespace bytedance::bolt::functions