diff --git a/bolt/shuffle/sparksql/BoltShuffleWriter.cpp b/bolt/shuffle/sparksql/BoltShuffleWriter.cpp index 11116c14d..9029e979f 100644 --- a/bolt/shuffle/sparksql/BoltShuffleWriter.cpp +++ b/bolt/shuffle/sparksql/BoltShuffleWriter.cpp @@ -756,6 +756,12 @@ arrow::Status BoltShuffleWriter::buildPartition2Row(uint32_t rowNum) { rowOffset2RowId_.resize(rowNum); for (auto row = 0; row < rowNum; ++row) { auto pid = row2Partition_[row]; + if (pid >= numPartitions_) { + return arrow::Status::Invalid( + "buildPartition2Row: invalid partition id " + std::to_string(pid) + + " for row " + std::to_string(row) + + ", numPartitions=" + std::to_string(numPartitions_)); + } rowOffset2RowId_[partition2RowOffsetBase_[pid]++] = row; } @@ -927,7 +933,7 @@ arrow::Status BoltShuffleWriter::splitBoolType( for (auto& pid : partitionUsed_) { uint8_t* dstaddr = dstAddrs[pid]; if (dstaddr != nullptr) { - splitBoolTypeInternal(srcAddr, dstaddr, pid); + RETURN_NOT_OK(splitBoolTypeInternal(srcAddr, dstaddr, pid)); } } return arrow::Status::OK(); @@ -939,15 +945,16 @@ arrow::Status BoltShuffleWriter::splitBoolType( for (auto& pid : partitionUsed_) { uint8_t* dstaddr = dstAddrs[pid][0]; BOLT_DCHECK(dstaddr != nullptr); - splitBoolTypeInternal(srcAddr, dstaddr, pid); + RETURN_NOT_OK(splitBoolTypeInternal(srcAddr, dstaddr, pid)); } return arrow::Status::OK(); } -void BoltShuffleWriter::splitBoolTypeInternal( +arrow::Status BoltShuffleWriter::splitBoolTypeInternal( const uint8_t* srcAddr, uint8_t* dstaddr, uint32_t pid) { + RETURN_NOT_OK(validatePartitionRowRange(pid, "splitBoolTypeInternal")); auto r = partition2RowOffsetBase_[pid]; /*8k*/ auto size = partition2RowOffsetBase_[pid + 1]; auto dstOffset = partitionBufferBase_[pid]; @@ -957,6 +964,7 @@ void BoltShuffleWriter::splitBoolTypeInternal( for (; r < size && dstIdxByte > 0; r++, dstIdxByte--) { auto srcOffset = rowOffset2RowId_[r]; /*16k*/ + RETURN_NOT_OK(validateSourceRowId(srcOffset, "splitBoolTypeInternal")); auto src = srcAddr[srcOffset >> 3]; src = src >> (srcOffset & 7) | 0xfe; // get the bit in bit 0, other bits set to 1 @@ -969,11 +977,15 @@ void BoltShuffleWriter::splitBoolTypeInternal( } dstaddr[dstOffset >> 3] = dst; if (r == size) { - return; + return arrow::Status::OK(); } dstOffset += dstOffsetInByte; // now dst_offset is 8 aligned for (; r + 8 < size; r += 8) { + for (auto i = 0; i < 8; ++i) { + RETURN_NOT_OK(validateSourceRowId( + rowOffset2RowId_[r + i], "splitBoolTypeInternal")); + } dst = extractBitsToByteSimd(srcAddr, &rowOffset2RowId_[r]); dstaddr[dstOffset >> 3] = dst; dstOffset += 8; @@ -984,6 +996,7 @@ void BoltShuffleWriter::splitBoolTypeInternal( dstIdxByte = 0; for (; r < size; r++, dstIdxByte++) { auto srcOffset = rowOffset2RowId_[r]; /*16k*/ + RETURN_NOT_OK(validateSourceRowId(srcOffset, "splitBoolTypeInternal")); auto src = srcAddr[srcOffset >> 3]; src = src >> (srcOffset & 7) | 0xfe; // get the bit in bit 0, other bits set to 1 @@ -995,6 +1008,7 @@ void BoltShuffleWriter::splitBoolTypeInternal( dst = dst & src; // only take the useful bit. } dstaddr[dstOffset >> 3] = dst; + return arrow::Status::OK(); } template @@ -1042,6 +1056,7 @@ arrow::Status BoltShuffleWriter::splitBinaryType( const auto* srcRawNulls = src.rawNulls(); for (auto& pid : partitionUsed_) { + RETURN_NOT_OK(validatePartitionRowRange(pid, "splitBinaryType")); auto& binaryBuf = dst[pid]; // use 32bit offset @@ -1058,6 +1073,7 @@ arrow::Status BoltShuffleWriter::splitBinaryType( for (auto i = 0; i < numRows; i++) { auto rowId = rowOffset2RowId_[rowOffsetBase + i]; + RETURN_NOT_OK(validateSourceRowId(rowId, "splitBinaryType")); auto& stringView = srcRawValues[rowId]; size_t isNull = srcRawNulls && bytedance::bolt::bits::isBitNull(srcRawNulls, rowId); diff --git a/bolt/shuffle/sparksql/BoltShuffleWriter.h b/bolt/shuffle/sparksql/BoltShuffleWriter.h index 3c59e7ee9..4e6fb8495 100644 --- a/bolt/shuffle/sparksql/BoltShuffleWriter.h +++ b/bolt/shuffle/sparksql/BoltShuffleWriter.h @@ -371,9 +371,38 @@ class BoltShuffleWriter : public ShuffleWriter { const uint8_t* srcAddr, const std::vector>& dstAddrs); - void + arrow::Status splitBoolTypeInternal(const uint8_t* srcAddr, uint8_t* dstaddr, uint32_t pid); + arrow::Status validateSourceRowId(uint32_t rowId, const char* context) const { + if (rowId < rowOffset2RowId_.size()) { + return arrow::Status::OK(); + } + return arrow::Status::Invalid( + std::string(context) + ": source rowId " + std::to_string(rowId) + + " is out of bounds for input row count " + + std::to_string(rowOffset2RowId_.size())); + } + + arrow::Status validatePartitionRowRange(uint32_t pid, const char* context) + const { + if (pid + 1 >= partition2RowOffsetBase_.size()) { + return arrow::Status::Invalid( + std::string(context) + ": partition id " + std::to_string(pid) + + " is out of bounds for partition row offsets"); + } + const auto begin = partition2RowOffsetBase_[pid]; + const auto end = partition2RowOffsetBase_[pid + 1]; + if (begin <= end && end <= rowOffset2RowId_.size()) { + return arrow::Status::OK(); + } + return arrow::Status::Invalid( + std::string(context) + ": partition " + std::to_string(pid) + + " has invalid row range [" + std::to_string(begin) + ", " + + std::to_string(end) + ") for input row count " + + std::to_string(rowOffset2RowId_.size())); + } + template arrow::Status splitValidityBuffer(const bytedance::bolt::RowVector& rv); @@ -395,12 +424,14 @@ class BoltShuffleWriter : public ShuffleWriter { const uint8_t* srcAddr, const std::vector& dstAddrs) { for (auto& pid : partitionUsed_) { + RETURN_NOT_OK(validatePartitionRowRange(pid, "splitFixedType")); auto dstPidBase = (T*)(dstAddrs[pid] + partitionBufferBase_[pid] * sizeof(T)); auto pos = partition2RowOffsetBase_[pid]; auto end = partition2RowOffsetBase_[pid + 1]; for (; pos < end; ++pos) { auto rowId = rowOffset2RowId_[pos]; + RETURN_NOT_OK(validateSourceRowId(rowId, "splitFixedType")); *dstPidBase++ = reinterpret_cast(srcAddr)[rowId]; // copy } } @@ -413,12 +444,14 @@ class BoltShuffleWriter : public ShuffleWriter { const uint8_t* srcAddr, const std::vector>& dstAddrs) { for (auto& pid : partitionUsed_) { + RETURN_NOT_OK(validatePartitionRowRange(pid, "splitFixedType")); auto dstPidBase = (T*)(dstAddrs[pid].back() + partitionBufferBaseInBatches_[pid] * sizeof(T)); auto pos = partition2RowOffsetBase_[pid]; auto end = partition2RowOffsetBase_[pid + 1]; for (; pos < end; ++pos) { auto rowId = rowOffset2RowId_[pos]; + RETURN_NOT_OK(validateSourceRowId(rowId, "splitFixedType")); *dstPidBase++ = reinterpret_cast(srcAddr)[rowId]; // copy } } diff --git a/bolt/shuffle/sparksql/partitioner/HashPartitioner.cpp b/bolt/shuffle/sparksql/partitioner/HashPartitioner.cpp index 9941952d8..15efdc09b 100644 --- a/bolt/shuffle/sparksql/partitioner/HashPartitioner.cpp +++ b/bolt/shuffle/sparksql/partitioner/HashPartitioner.cpp @@ -32,6 +32,18 @@ #include "bolt/shuffle/sparksql/partitioner/HashPartitioner.h" namespace bytedance::bolt::shuffle::sparksql { +namespace { + +inline int32_t normalizePid(int32_t value, int32_t numPartitions) { + auto pid = value % numPartitions; + if (pid < 0) { + pid += numPartitions; + } + return pid; +} + +} // namespace + arrow::Status HashPartitioner::compute( const int32_t* pidArr, const int64_t numRows, @@ -41,20 +53,7 @@ arrow::Status HashPartitioner::compute( std::fill(std::begin(partition2RowCount), std::end(partition2RowCount), 0); for (auto i = 0; i < numRows; ++i) { - auto pid = pidArr[i] % numPartitions_; -#if defined(__x86_64__) - // force to generate ASM - __asm__( - "lea (%[num_partitions],%[pid],1),%[tmp]\n" - "test %[pid],%[pid]\n" - "cmovs %[tmp],%[pid]\n" - : [pid] "+r"(pid) - : [num_partitions] "r"(numPartitions_), [tmp] "r"(0)); -#else - if (pid < 0) { - pid += numPartitions_; - } -#endif + auto pid = normalizePid(pidArr[i], numPartitions_); row2partition[i] = pid; } @@ -75,20 +74,7 @@ arrow::Status HashPartitioner::precompute( } for (auto i = 0; i < numRows; ++i) { - auto pid = pidArr[i] % numPartitions_; -#if defined(__x86_64__) - // force to generate ASM - __asm__( - "lea (%[num_partitions],%[pid],1),%[tmp]\n" - "test %[pid],%[pid]\n" - "cmovs %[tmp],%[pid]\n" - : [pid] "+r"(pid) - : [num_partitions] "r"(numPartitions_), [tmp] "r"(0)); -#else - if (pid < 0) { - pid += numPartitions_; - } -#endif + auto pid = normalizePid(pidArr[i], numPartitions_); pidArr[i] = pid; partition2RowCount[pid]++; } diff --git a/bolt/shuffle/sparksql/tests/ShuffleMiscTest.cpp b/bolt/shuffle/sparksql/tests/ShuffleMiscTest.cpp index ce4eeec92..60600d784 100644 --- a/bolt/shuffle/sparksql/tests/ShuffleMiscTest.cpp +++ b/bolt/shuffle/sparksql/tests/ShuffleMiscTest.cpp @@ -14,10 +14,111 @@ * limitations under the License. */ +#include +#include "bolt/exec/tests/utils/TempDirectoryPath.h" +#include "bolt/shuffle/sparksql/BoltShuffleWriter.h" +#include "bolt/shuffle/sparksql/partitioner/HashPartitioner.h" #include "bolt/shuffle/sparksql/tests/ShuffleTestBase.h" namespace bytedance::bolt::shuffle::sparksql::test { +namespace { + +class TestBoltShuffleWriter : public BoltShuffleWriter { + public: + TestBoltShuffleWriter( + ShuffleWriterOptions options, + bytedance::bolt::memory::MemoryPool* boltPool, + arrow::MemoryPool* pool) + : BoltShuffleWriter(std::move(options), boltPool, pool) {} + + arrow::Status initialize() { + return init(); + } + + arrow::Status buildPartitionRowsWithInjectedPid( + const bytedance::bolt::RowVector& rv, + uint32_t badPid) { + auto status = initFromRowVector(rv); + if (!status.ok()) { + return status; + } + row2Partition_.assign(rv.size(), badPid); + std::fill(partition2RowCount_.begin(), partition2RowCount_.end(), 0); + partition2RowCount_[0] = rv.size(); + return buildPartition2Row(rv.size()); + } + + arrow::Status splitWithInjectedRowId( + const bytedance::bolt::RowVector& rv, + uint32_t badRowId) { + auto status = initFromRowVector(rv); + if (!status.ok()) { + return status; + } + std::fill(partitionBufferBase_.begin(), partitionBufferBase_.end(), 0); + std::fill(partitionBufferSize_.begin(), partitionBufferSize_.end(), 0); + std::fill(partition2RowCount_.begin(), partition2RowCount_.end(), 0); + partition2RowCount_[0] = 1; + partitionUsed_ = {0}; + std::fill( + partition2RowOffsetBase_.begin(), partition2RowOffsetBase_.end(), 0); + partition2RowOffsetBase_[1] = 1; + rowOffset2RowId_.assign(1, badRowId); + status = updateInputHasNull(rv); + if (!status.ok()) { + return status; + } + status = preAllocPartitionBuffers(1); + if (!status.ok()) { + return status; + } + return splitRowVector(rv); + } + + arrow::Status splitWithInjectedRowRange( + const bytedance::bolt::RowVector& rv, + uint32_t endOffset) { + auto status = initFromRowVector(rv); + if (!status.ok()) { + return status; + } + std::fill(partitionBufferBase_.begin(), partitionBufferBase_.end(), 0); + std::fill(partitionBufferSize_.begin(), partitionBufferSize_.end(), 0); + std::fill(partition2RowCount_.begin(), partition2RowCount_.end(), 0); + partition2RowCount_[0] = endOffset; + partitionUsed_ = {0}; + std::fill( + partition2RowOffsetBase_.begin(), partition2RowOffsetBase_.end(), 0); + partition2RowOffsetBase_[1] = endOffset; + rowOffset2RowId_.assign(1, 0); + status = updateInputHasNull(rv); + if (!status.ok()) { + return status; + } + status = preAllocPartitionBuffers(endOffset); + if (!status.ok()) { + return status; + } + return splitRowVector(rv); + } +}; + +ShuffleWriterOptions makeWriterOptions(const std::string& baseDir) { + ShuffleWriterOptions options; + options.partitioning = Partitioning::kHash; + options.partitionWriterOptions.partitionWriterType = + PartitionWriterType::kLocal; + options.partitionWriterOptions.numPartitions = 2; + options.partitionWriterOptions.dataFile = baseDir + "/shuffle.bin"; + options.partitionWriterOptions.configuredDirs = {baseDir}; + options.partitionWriterOptions.numSubDirs = 1; + options.bufferSize = 4; + return options; +} + +} // namespace + class ShuffleMiscTest : public ShuffleTestBase {}; // End-to-end test: RoundRobin with Adaptive mode, >=8000 partitions and >=5 @@ -55,4 +156,68 @@ TEST_F(ShuffleMiscTest, AdaptiveRoundRobinLargePartitionsMixTypes) { executeTest(param); } +TEST_F(ShuffleMiscTest, HashPartitionerNormalizesNegativeHashes) { + HashPartitioner partitioner(4); + std::vector row2Partition; + std::vector partition2RowCount(4, 0); + const std::vector hashes = {-1, -2, -3, -4, -5, 0, 1, 5}; + + ASSERT_TRUE( + partitioner + .compute( + hashes.data(), hashes.size(), row2Partition, partition2RowCount) + .ok()); + + ASSERT_EQ(row2Partition.size(), hashes.size()); + for (size_t i = 0; i < hashes.size(); ++i) { + const auto expected = ((hashes[i] % 4) + 4) % 4; + EXPECT_EQ(row2Partition[i], expected) << "hash=" << hashes[i]; + } + EXPECT_EQ( + std::accumulate( + partition2RowCount.begin(), partition2RowCount.end(), uint32_t{0}), + hashes.size()); +} + +TEST_F(ShuffleMiscTest, BuildPartition2RowRejectsInvalidPartitionId) { + auto tempDir = exec::test::TempDirectoryPath::create(); + TestBoltShuffleWriter writer( + makeWriterOptions(tempDir->path), pool(), arrow::default_memory_pool()); + ASSERT_TRUE(writer.initialize().ok()); + + auto rv = makeRowVector({makeFlatVector({1, 2, 3})}); + auto status = writer.buildPartitionRowsWithInjectedPid(*rv, 2); + + ASSERT_FALSE(status.ok()); + EXPECT_NE( + status.message().find("buildPartition2Row: invalid partition id"), + std::string::npos); +} + +TEST_F(ShuffleMiscTest, SplitRowVectorRejectsOutOfBoundsRowId) { + auto tempDir = exec::test::TempDirectoryPath::create(); + TestBoltShuffleWriter writer( + makeWriterOptions(tempDir->path), pool(), arrow::default_memory_pool()); + ASSERT_TRUE(writer.initialize().ok()); + + auto rv = makeRowVector({makeFlatVector({1})}); + auto status = writer.splitWithInjectedRowId(*rv, rv->size()); + + ASSERT_FALSE(status.ok()); + EXPECT_NE(status.message().find("source rowId"), std::string::npos); +} + +TEST_F(ShuffleMiscTest, SplitRowVectorRejectsInvalidPartitionRowRange) { + auto tempDir = exec::test::TempDirectoryPath::create(); + TestBoltShuffleWriter writer( + makeWriterOptions(tempDir->path), pool(), arrow::default_memory_pool()); + ASSERT_TRUE(writer.initialize().ok()); + + auto rv = makeRowVector({makeFlatVector({1})}); + auto status = writer.splitWithInjectedRowRange(*rv, 2); + + ASSERT_FALSE(status.ok()); + EXPECT_NE(status.message().find("invalid row range"), std::string::npos); +} + } // namespace bytedance::bolt::shuffle::sparksql::test