Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions bolt/shuffle/sparksql/BoltShuffleWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -756,6 +756,12 @@ arrow::Status BoltShuffleWriter::buildPartition2Row(uint32_t rowNum) {
rowOffset2RowId_.resize(rowNum);
for (auto row = 0; row < rowNum; ++row) {
auto pid = row2Partition_[row];
if (pid >= numPartitions_) {
return arrow::Status::Invalid(
"buildPartition2Row: invalid partition id " + std::to_string(pid) +
" for row " + std::to_string(row) +
", numPartitions=" + std::to_string(numPartitions_));
}
rowOffset2RowId_[partition2RowOffsetBase_[pid]++] = row;
}

Expand Down Expand Up @@ -927,7 +933,7 @@ arrow::Status BoltShuffleWriter::splitBoolType(
for (auto& pid : partitionUsed_) {
uint8_t* dstaddr = dstAddrs[pid];
if (dstaddr != nullptr) {
splitBoolTypeInternal(srcAddr, dstaddr, pid);
RETURN_NOT_OK(splitBoolTypeInternal(srcAddr, dstaddr, pid));
}
}
return arrow::Status::OK();
Expand All @@ -939,15 +945,16 @@ arrow::Status BoltShuffleWriter::splitBoolType(
for (auto& pid : partitionUsed_) {
uint8_t* dstaddr = dstAddrs[pid][0];
BOLT_DCHECK(dstaddr != nullptr);
splitBoolTypeInternal(srcAddr, dstaddr, pid);
RETURN_NOT_OK(splitBoolTypeInternal(srcAddr, dstaddr, pid));
}
return arrow::Status::OK();
}

void BoltShuffleWriter::splitBoolTypeInternal(
arrow::Status BoltShuffleWriter::splitBoolTypeInternal(
const uint8_t* srcAddr,
uint8_t* dstaddr,
uint32_t pid) {
RETURN_NOT_OK(validatePartitionRowRange(pid, "splitBoolTypeInternal"));
auto r = partition2RowOffsetBase_[pid]; /*8k*/
auto size = partition2RowOffsetBase_[pid + 1];
auto dstOffset = partitionBufferBase_[pid];
Expand All @@ -957,6 +964,7 @@ void BoltShuffleWriter::splitBoolTypeInternal(

for (; r < size && dstIdxByte > 0; r++, dstIdxByte--) {
auto srcOffset = rowOffset2RowId_[r]; /*16k*/
RETURN_NOT_OK(validateSourceRowId(srcOffset, "splitBoolTypeInternal"));
auto src = srcAddr[srcOffset >> 3];
src = src >> (srcOffset & 7) |
0xfe; // get the bit in bit 0, other bits set to 1
Expand All @@ -969,11 +977,15 @@ void BoltShuffleWriter::splitBoolTypeInternal(
}
dstaddr[dstOffset >> 3] = dst;
if (r == size) {
return;
return arrow::Status::OK();
}
dstOffset += dstOffsetInByte;
// now dst_offset is 8 aligned
for (; r + 8 < size; r += 8) {
for (auto i = 0; i < 8; ++i) {
RETURN_NOT_OK(validateSourceRowId(
rowOffset2RowId_[r + i], "splitBoolTypeInternal"));
}
dst = extractBitsToByteSimd(srcAddr, &rowOffset2RowId_[r]);
dstaddr[dstOffset >> 3] = dst;
dstOffset += 8;
Expand All @@ -984,6 +996,7 @@ void BoltShuffleWriter::splitBoolTypeInternal(
dstIdxByte = 0;
for (; r < size; r++, dstIdxByte++) {
auto srcOffset = rowOffset2RowId_[r]; /*16k*/
RETURN_NOT_OK(validateSourceRowId(srcOffset, "splitBoolTypeInternal"));
auto src = srcAddr[srcOffset >> 3];
src = src >> (srcOffset & 7) |
0xfe; // get the bit in bit 0, other bits set to 1
Expand All @@ -995,6 +1008,7 @@ void BoltShuffleWriter::splitBoolTypeInternal(
dst = dst & src; // only take the useful bit.
}
dstaddr[dstOffset >> 3] = dst;
return arrow::Status::OK();
}

template <bool needAlloc>
Expand Down Expand Up @@ -1042,6 +1056,7 @@ arrow::Status BoltShuffleWriter::splitBinaryType(
const auto* srcRawNulls = src.rawNulls();

for (auto& pid : partitionUsed_) {
RETURN_NOT_OK(validatePartitionRowRange(pid, "splitBinaryType"));
auto& binaryBuf = dst[pid];

// use 32bit offset
Expand All @@ -1058,6 +1073,7 @@ arrow::Status BoltShuffleWriter::splitBinaryType(

for (auto i = 0; i < numRows; i++) {
auto rowId = rowOffset2RowId_[rowOffsetBase + i];
RETURN_NOT_OK(validateSourceRowId(rowId, "splitBinaryType"));
auto& stringView = srcRawValues[rowId];
size_t isNull =
srcRawNulls && bytedance::bolt::bits::isBitNull(srcRawNulls, rowId);
Expand Down
35 changes: 34 additions & 1 deletion bolt/shuffle/sparksql/BoltShuffleWriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -371,9 +371,38 @@ class BoltShuffleWriter : public ShuffleWriter {
const uint8_t* srcAddr,
const std::vector<std::vector<uint8_t*>>& dstAddrs);

void
arrow::Status
splitBoolTypeInternal(const uint8_t* srcAddr, uint8_t* dstaddr, uint32_t pid);

/// Verifies that a source row id is a usable index for the current input.
///
/// The id is compared against the number of entries in rowOffset2RowId_.
///
/// @param rowId   Source row id to check.
/// @param context Prefix for the error message; must be a valid C string.
/// @return Status::OK() when rowId is below the entry count, otherwise
///         Status::Invalid carrying the offending id and the entry count.
arrow::Status validateSourceRowId(uint32_t rowId, const char* context) const {
  const auto inputRowCount = rowOffset2RowId_.size();
  if (rowId >= inputRowCount) {
    // Assemble the diagnostic piece by piece; the text is unchanged.
    std::string message(context);
    message += ": source rowId ";
    message += std::to_string(rowId);
    message += " is out of bounds for input row count ";
    message += std::to_string(inputRowCount);
    return arrow::Status::Invalid(message);
  }
  return arrow::Status::OK();
}

/// Verifies that partition `pid` maps to a well-formed slice of
/// rowOffset2RowId_.
///
/// The slice is [partition2RowOffsetBase_[pid],
/// partition2RowOffsetBase_[pid + 1]). It is accepted when both offsets can be
/// read, the start does not exceed the end, and the end does not exceed
/// rowOffset2RowId_.size().
///
/// @param pid     Partition id to check.
/// @param context Prefix for the error message; must be a valid C string.
/// @return Status::OK() when the range is valid, otherwise Status::Invalid
///         describing either the out-of-range partition id or the bad range.
arrow::Status validatePartitionRowRange(uint32_t pid, const char* context)
    const {
  // Widen before adding one. In 32-bit arithmetic `pid + 1` wraps to 0 for
  // pid == UINT32_MAX, which would pass the bounds check and then index
  // partition2RowOffsetBase_ far out of range.
  const uint64_t nextPid = static_cast<uint64_t>(pid) + 1;
  if (nextPid >= partition2RowOffsetBase_.size()) {
    return arrow::Status::Invalid(
        std::string(context) + ": partition id " + std::to_string(pid) +
        " is out of bounds for partition row offsets");
  }
  const auto begin = partition2RowOffsetBase_[pid];
  const auto end = partition2RowOffsetBase_[nextPid];
  if (begin <= end && end <= rowOffset2RowId_.size()) {
    return arrow::Status::OK();
  }
  return arrow::Status::Invalid(
      std::string(context) + ": partition " + std::to_string(pid) +
      " has invalid row range [" + std::to_string(begin) + ", " +
      std::to_string(end) + ") for input row count " +
      std::to_string(rowOffset2RowId_.size()));
}

template <bool needAlloc>
arrow::Status splitValidityBuffer(const bytedance::bolt::RowVector& rv);

Expand All @@ -395,12 +424,14 @@ class BoltShuffleWriter : public ShuffleWriter {
const uint8_t* srcAddr,
const std::vector<uint8_t*>& dstAddrs) {
for (auto& pid : partitionUsed_) {
RETURN_NOT_OK(validatePartitionRowRange(pid, "splitFixedType"));
auto dstPidBase =
(T*)(dstAddrs[pid] + partitionBufferBase_[pid] * sizeof(T));
auto pos = partition2RowOffsetBase_[pid];
auto end = partition2RowOffsetBase_[pid + 1];
for (; pos < end; ++pos) {
auto rowId = rowOffset2RowId_[pos];
RETURN_NOT_OK(validateSourceRowId(rowId, "splitFixedType"));
*dstPidBase++ = reinterpret_cast<const T*>(srcAddr)[rowId]; // copy
}
}
Expand All @@ -413,12 +444,14 @@ class BoltShuffleWriter : public ShuffleWriter {
const uint8_t* srcAddr,
const std::vector<std::vector<uint8_t*>>& dstAddrs) {
for (auto& pid : partitionUsed_) {
RETURN_NOT_OK(validatePartitionRowRange(pid, "splitFixedType"));
auto dstPidBase =
(T*)(dstAddrs[pid].back() + partitionBufferBaseInBatches_[pid] * sizeof(T));
auto pos = partition2RowOffsetBase_[pid];
auto end = partition2RowOffsetBase_[pid + 1];
for (; pos < end; ++pos) {
auto rowId = rowOffset2RowId_[pos];
RETURN_NOT_OK(validateSourceRowId(rowId, "splitFixedType"));
*dstPidBase++ = reinterpret_cast<const T*>(srcAddr)[rowId]; // copy
}
}
Expand Down
42 changes: 14 additions & 28 deletions bolt/shuffle/sparksql/partitioner/HashPartitioner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,18 @@
#include "bolt/shuffle/sparksql/partitioner/HashPartitioner.h"
namespace bytedance::bolt::shuffle::sparksql {

namespace {

/// Reduces `value` modulo `numPartitions` and shifts a negative remainder up
/// by `numPartitions`, so a positive partition count always yields a result
/// in [0, numPartitions).
///
/// @param value         Raw partition id (may be negative).
/// @param numPartitions Partition count; must be non-zero, since the
///                      remainder operation is undefined for a zero divisor.
/// @return The normalized partition id.
inline int32_t normalizePid(int32_t value, int32_t numPartitions) {
  // C++ `%` takes the sign of the dividend, so negative inputs need a fix-up.
  const int32_t remainder = value % numPartitions;
  return remainder < 0 ? remainder + numPartitions : remainder;
}

} // namespace

arrow::Status HashPartitioner::compute(
const int32_t* pidArr,
const int64_t numRows,
Expand All @@ -41,20 +53,7 @@ arrow::Status HashPartitioner::compute(
std::fill(std::begin(partition2RowCount), std::end(partition2RowCount), 0);

for (auto i = 0; i < numRows; ++i) {
auto pid = pidArr[i] % numPartitions_;
#if defined(__x86_64__)
// force to generate ASM
__asm__(
"lea (%[num_partitions],%[pid],1),%[tmp]\n"
"test %[pid],%[pid]\n"
"cmovs %[tmp],%[pid]\n"
: [pid] "+r"(pid)
: [num_partitions] "r"(numPartitions_), [tmp] "r"(0));
#else
if (pid < 0) {
pid += numPartitions_;
}
#endif
auto pid = normalizePid(pidArr[i], numPartitions_);
row2partition[i] = pid;
}

Expand All @@ -75,20 +74,7 @@ arrow::Status HashPartitioner::precompute(
}

for (auto i = 0; i < numRows; ++i) {
auto pid = pidArr[i] % numPartitions_;
#if defined(__x86_64__)
// force to generate ASM
__asm__(
"lea (%[num_partitions],%[pid],1),%[tmp]\n"
"test %[pid],%[pid]\n"
"cmovs %[tmp],%[pid]\n"
: [pid] "+r"(pid)
: [num_partitions] "r"(numPartitions_), [tmp] "r"(0));
#else
if (pid < 0) {
pid += numPartitions_;
}
#endif
auto pid = normalizePid(pidArr[i], numPartitions_);
pidArr[i] = pid;
partition2RowCount[pid]++;
}
Expand Down
Loading