Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion src/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2019-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright 2019-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
Expand Down Expand Up @@ -59,6 +59,13 @@ if(TRITON_ENABLE_GRPC)
set(TRITON_COMMON_ENABLE_GRPC ON)
endif() # TRITON_ENABLE_GRPC

# protobuf
#
# Turn on TRITON_COMMON_ENABLE_PROTOBUF whenever the HTTP, metrics,
# SageMaker or Vertex AI endpoint is enabled. The variable is set ahead of
# the FetchContent_MakeAvailable(repo-common ...) call that follows.
# NOTE(review): presumably repo-common reads this option to decide whether
# to build its protobuf-based targets -- confirm against the common repo.
if(TRITON_ENABLE_HTTP OR TRITON_ENABLE_METRICS OR TRITON_ENABLE_SAGEMAKER OR
TRITON_ENABLE_VERTEX_AI)
set(TRITON_COMMON_ENABLE_PROTOBUF ON)
endif()

FetchContent_MakeAvailable(repo-common repo-core repo-backend)

# CUDA
Expand Down Expand Up @@ -406,6 +413,17 @@ if(${TRITON_ENABLE_HTTP}
re2::re2
)

# model_config.h (GetElementCount, etc.) needs Protobuf + generated protos.
# The link is guarded by if(TARGET ...) so it is only added when the
# triton-common-model-config target exists in the current configuration.
# PUBLIC propagates these link dependencies to anything that links against
# http-endpoint-library.
if(TARGET triton-common-model-config)
target_link_libraries(
http-endpoint-library
PUBLIC
triton-common-model-config
proto-library
protobuf::libprotobuf
)
endif()

target_include_directories(
http-endpoint-library
PRIVATE $<TARGET_PROPERTY:libevhtp::evhtp,INTERFACE_INCLUDE_DIRECTORIES>
Expand Down
31 changes: 1 addition & 30 deletions src/common.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2020-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright 2020-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
Expand Down Expand Up @@ -105,35 +105,6 @@ ShapeToString(const std::vector<int64_t>& shape)
return ShapeToString(shape.data(), shape.size());
}

// Compute the number of elements described by 'dims'.
//
// Dimensions are visited in order and the first special condition met
// decides the result:
//   -1  a wildcard dimension (WILDCARD_DIM) was found
//   -2  a negative, non-wildcard dimension was found
//    0  a zero dimension was found, or 'dims' is empty
//   -3  the running product would not fit in an int64_t
// Otherwise the product of all dimensions is returned.
int64_t
GetElementCount(const std::vector<int64_t>& dims)
{
  // An empty shape reports zero elements.
  if (dims.empty()) {
    return 0;
  }

  int64_t total = 1;
  for (const int64_t extent : dims) {
    if (extent == WILDCARD_DIM) {
      return -1;
    }
    if (extent < 0) {  // invalid dim
      return -2;
    }
    if (extent == 0) {
      return 0;
    }

    // 'extent' is strictly positive here, so the division is safe. Reject
    // the multiplication up front if it would exceed INT64_MAX.
    if (total > (INT64_MAX / extent)) {
      return -3;
    }
    total *= extent;
  }

  return total;
}

bool
Contains(const std::vector<std::string>& vec, const std::string& str)
{
Expand Down
15 changes: 1 addition & 14 deletions src/common.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2019-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright 2019-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
Expand Down Expand Up @@ -51,10 +51,6 @@ constexpr char kTritonSharedMemoryRegionPrefix[] =

/// Size limit constant for gRPC messages, pinned to INT32_MAX.
constexpr int MAX_GRPC_MESSAGE_SIZE = INT32_MAX;

/// Sentinel dimension value (-1) marking a dimension of a shape as
/// variable-size, i.e. one that can take on any size.
constexpr int WILDCARD_DIM = -1;

/// Maximum allowed nesting depth for JSON parsing.
constexpr int32_t HTTP_MAX_JSON_NESTING_DEPTH = 100;

Expand Down Expand Up @@ -162,15 +158,6 @@ TRITONSERVER_Error* GetModelVersionFromString(
std::string GetEnvironmentVariableOrDefault(
const std::string& variable_name, const std::string& default_value);

/// Get the number of elements in a shape.
///
/// Dimensions are examined in order; the first wildcard, invalid or zero
/// dimension, or the first overflowing product, decides the result.
///
/// \param dims The shape.
/// \return The number of elements (0 for an empty shape or when a
/// dimension is 0), -1 if the number of elements
/// cannot be determined because the shape contains one or more
/// wildcard dimensions, -2 if the shape contains an invalid dim,
/// or -3 if the number is too large to represent as an int64_t.
int64_t GetElementCount(const std::vector<int64_t>& dims);

/// Convert shape to string representation.
///
/// \param shape The shape as a vector.
Expand Down
18 changes: 13 additions & 5 deletions src/http_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
#define TRITONJSON_STATUSRETURN(M) \
return TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INTERNAL, (M).c_str())
#define TRITONJSON_STATUSSUCCESS nullptr
#include "triton/common/model_config.h"
#include "triton/common/triton_json.h"

namespace triton { namespace server {
Expand Down Expand Up @@ -2614,23 +2615,21 @@ HTTPAPIServer::ParseJsonTritonIO(
memory_type_id));
}
} else {
const int64_t element_cnt = GetElementCount(shape_vec);
const int64_t element_cnt = triton::common::GetElementCount(shape_vec);

if (element_cnt == 0) {
RETURN_IF_ERR(TRITONSERVER_InferenceRequestAppendInputData(
irequest, input_name, nullptr, 0 /* byte_size */,
TRITONSERVER_MEMORY_CPU, 0 /* memory_type_id */));
} else if (element_cnt == -2) {
// -2 indicates invalid dimension
} else if (element_cnt == triton::common::INVALID_SIZE) {
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INVALID_ARG,
std::string(
"invalid shape for input '" + std::string(input_name) +
"': shape " + ShapeToString(shape_vec) +
" contains one or more invalid dimensions")
.c_str());
} else if (element_cnt == -3) {
// -3 indicates integer overflow
} else if (element_cnt == triton::common::OVERFLOW_SIZE) {
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INVALID_ARG,
std::string(
Expand All @@ -2639,6 +2638,15 @@ HTTPAPIServer::ParseJsonTritonIO(
" causes total element count to exceed maximum size of " +
std::to_string(INT64_MAX))
.c_str());
} else if (element_cnt == triton::common::WILDCARD_SIZE) {
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INVALID_ARG,
std::string(
"invalid shape for input '" + std::string(input_name) +
"': shape " + ShapeToString(shape_vec) +
" contains one or more variable-size dimensions (-1); cannot "
"determine element count for JSON input")
.c_str());
} else {
// JSON... presence of "data" already validated but still
// checking here. Flow in this endpoint needs to be
Expand Down
131 changes: 119 additions & 12 deletions src/test/tensor_size_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -173,25 +173,35 @@ assert_get_byte_size_success(
{
int64_t size;
TRITONSERVER_Error* err;
inference::DataType core_dtype = tcore::TritonToDataType(dtype);

// Backend (old API)
// Backend (public old API)
ASSERT_EQ(expected_size, tb::GetByteSize(dtype, shape));

// Backend (new API)
// Backend (public new API)
err = tb::GetByteSize(dtype, shape, &size);
ASSERT_EQ(err, nullptr);
ASSERT_EQ(expected_size, size);

// Common
inference::DataType core_dtype = tcore::TritonToDataType(dtype);
// Common (public API)
ASSERT_EQ(tc::GetByteSize(core_dtype, shape), expected_size);

// Core
// Core (internal helper)
if (test_core) {
size = 0;
auto status = tcore::GetByteSize(core_dtype, shape, kTensorName, &size);
ASSERT_TRUE(status.IsOk()) << status.Message();
ASSERT_EQ(size, expected_size);
// Special case: rejects wildcard / non-fixed size with INVALID_ARG
if (expected_size == tc::WILDCARD_SIZE) {
ASSERT_FALSE(status.IsOk());
ASSERT_EQ(status.StatusCode(), triton::core::Status::Code::INVALID_ARG);
ASSERT_TRUE(
std::string(status.Message())
.find("contains one or more variable-size dimensions") !=
std::string::npos);
} else {
ASSERT_TRUE(status.IsOk()) << status.Message();
ASSERT_EQ(size, expected_size);
}
}
}

Expand Down Expand Up @@ -274,10 +284,6 @@ TEST_F(GetElementCountTest, GetElementCountWildcard)
// Test 3: multiple -1 dims
shape = {8, -1, -1};
assert_get_element_count_success(shape, expected_cnt);

// Test 4: -1 dim before overflow
shape = {-1, 1LL << 32, 1LL << 31};
assert_get_element_count_success(shape, expected_cnt);
}

TEST_F(GetElementCountTest, GetElementCountZero)
Expand Down Expand Up @@ -323,6 +329,12 @@ TEST_F(GetElementCountTest, GetElementCountInvalidDim)
error_msg = std::string("shape") + tb::ShapeToString(shape) +
" contains an invalid dim.";
assert_get_element_count_error(shape, ErrorCode::kInvalidDim, error_msg);

// Test 4: invalid dim after a wildcard
shape = {-1, -2};
error_msg = std::string("shape") + tb::ShapeToString(shape) +
" contains an invalid dim.";
assert_get_element_count_error(shape, ErrorCode::kInvalidDim, error_msg);
}

TEST_F(GetElementCountTest, GetElementCountOverflow)
Expand All @@ -339,11 +351,56 @@ TEST_F(GetElementCountTest, GetElementCountOverflow)
shape = {1LL << 32, 1LL << 31};
error_msg = "unexpected integer overflow while calculating element count.";
assert_get_element_count_error(shape, ErrorCode::kOverflow, error_msg);
}

TEST_F(GetElementCountTest, GetElementCountMixed)
{
  // Shapes that combine wildcard (-1), invalid (-2) and overflowing dims.
  // Every case must be rejected; 'overflows' selects the expected error:
  // overflow when true, invalid dim when false.
  struct MixedCase {
    std::vector<int64_t> dims;
    bool overflows;
  };
  std::vector<MixedCase> cases = {
      // Test 1: -1 dim before overflow
      {{-1, 1LL << 32, 1LL << 31}, true},
      // Test 2: -1 dim between the overflowing dims
      {{1LL << 32, -1, 1LL << 31}, true},
      // Test 3: overflows before -1 dim
      {{1LL << 32, 1LL << 31, -1}, true},
      // Test 4: -1 dim before invalid dim
      {{-1, -2}, false},
      // Test 5: invalid dim before -1 dim
      {{-2, -1}, false},
      // Test 6: invalid dim before the overflowing dims
      {{-2, 1LL << 32, 1LL << 31}, false},
      // Test 7: invalid dim between the overflowing dims
      {{1LL << 32, -2, 1LL << 31}, false},
      // Test 8: overflow before invalid dim
      {{1LL << 32, 1LL << 31, -2}, true},
  };

  for (auto& tc : cases) {
    std::string expected_msg;
    if (tc.overflows) {
      expected_msg =
          "unexpected integer overflow while calculating element count.";
      assert_get_element_count_error(
          tc.dims, ErrorCode::kOverflow, expected_msg);
    } else {
      expected_msg = std::string("shape") + tb::ShapeToString(tc.dims) +
                     " contains an invalid dim.";
      assert_get_element_count_error(
          tc.dims, ErrorCode::kInvalidDim, expected_msg);
    }
  }
}

class GetByteSizeTest : public ::testing::Test {
Expand Down Expand Up @@ -400,7 +457,6 @@ TEST_F(GetByteSizeTest, GetByteSizeWildcard)
ASSERT_TRUE(status.IsOk()) << status.Message();
ASSERT_EQ(size, sizeof(int32_t) * 8 * 8);


// Test 3: invalid shape and element count overflows
dtype = TRITONSERVER_TYPE_INVALID;
shape = {1LL << 40, 1LL << 40};
Expand Down Expand Up @@ -481,6 +537,57 @@ TEST_F(GetByteSizeTest, GetByteSizeOverflow)
assert_get_byte_size_error(dtype, shape, ErrorCode::kOverflow, error_msg);
}

TEST_F(GetByteSizeTest, GetByteSizeMixed)
{
  // INT32 shapes that combine wildcard (-1), invalid (-2) and overflowing
  // dims. Every case must be rejected; 'overflows' selects the expected
  // error: overflow when true, invalid dim when false.
  TRITONSERVER_DataType dtype = TRITONSERVER_TYPE_INT32;

  struct MixedCase {
    std::vector<int64_t> dims;
    bool overflows;
  };
  std::vector<MixedCase> cases = {
      // Test 1: wildcard dim before overflow
      {{-1, 1LL << 32, 1LL << 31}, true},
      // Test 2: wildcard dim between the overflowing dims
      {{1LL << 32, -1, 1LL << 31}, true},
      // Test 3: overflows before wildcard dim
      {{1LL << 32, 1LL << 31, -1}, true},
      // Test 4: wildcard dim before invalid dim
      {{-1, -2}, false},
      // Test 5: invalid dim before wildcard dim
      {{-2, -1}, false},
      // Test 6: invalid dim before the overflowing dims
      {{-2, 1LL << 32, 1LL << 31}, false},
      // Test 7: invalid dim between the overflowing dims
      {{1LL << 32, -2, 1LL << 31}, false},
      // Test 8: overflow before invalid dim
      {{1LL << 32, 1LL << 31, -2}, true},
  };

  for (auto& tc : cases) {
    std::string expected_msg;
    if (tc.overflows) {
      expected_msg = "unexpected integer overflow while calculating byte size.";
      assert_get_byte_size_error(
          dtype, tc.dims, ErrorCode::kOverflow, expected_msg);
    } else {
      expected_msg = std::string("shape") + tb::ShapeToString(tc.dims) +
                     " contains an invalid dim.";
      assert_get_byte_size_error(
          dtype, tc.dims, ErrorCode::kInvalidDim, expected_msg);
    }
  }
}

} // namespace

int
Expand Down
Loading