
feat: Enable bfloat16 input/output tensor dtype in Python client #880

Draft
yinggeh wants to merge 3 commits into main from
yinggeh/tgh-26-onnx-backend-does-not-support-bfloat16-inputs

Conversation

Contributor

@yinggeh yinggeh commented Feb 14, 2026

Drop support for type conversions between BF16 and FLOAT32. Users should use ml_dtypes.bfloat16 instead. For example:

import numpy as np
import ml_dtypes
import tritonclient.http as httpclient

np_input = np.ones(shape, dtype=ml_dtypes.bfloat16)
inputs = [
    httpclient.InferInput(
        "INPUT0", np_input.shape, "BF16"
    ).set_data_from_numpy(np_input)
]
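Because ml_dtypes.bfloat16 is a real numpy dtype, BF16 tensors can round-trip through the plain numpy serialization path this PR switches to. A minimal sketch (the shape and values here are illustrative, not from the PR):

```python
import numpy as np
import ml_dtypes

# Round trip through the generic path: tobytes() on the way out,
# np.frombuffer() on the way back, two bytes per BF16 element.
a = np.ones((2, 3), dtype=ml_dtypes.bfloat16)
raw = a.tobytes()
b = np.frombuffer(raw, dtype=ml_dtypes.bfloat16).reshape(a.shape)

assert len(raw) == a.size * 2  # BF16 is 16 bits wide
assert (b.astype(np.float32) == 1.0).all()
```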

@yinggeh yinggeh added the enhancement New feature or request label Feb 14, 2026
@yinggeh yinggeh requested a review from Copilot February 14, 2026 03:36

Copilot AI left a comment


Pull request overview

This PR enables native bfloat16 (BF16) tensor support in the Triton Python client by integrating the ml_dtypes library, replacing the previous workaround that used float32 as an intermediate representation. The change allows BF16 tensors to be serialized and deserialized directly using standard numpy operations like tobytes() and np.frombuffer().

Changes:

  • Added ml_dtypes dependency (>=0.5.4) to support native bfloat16 dtype
  • Updated dtype conversion functions to map BF16 to ml_dtypes.bfloat16 instead of np.float32
  • Removed special-case serialization/deserialization logic for BF16 tensors across HTTP and gRPC clients
  • Updated copyright years to 2026
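The dtype-mapping change in the list above can be pictured with a small sketch. The function name matches tritonclient.utils.triton_to_np_dtype, but the body is an illustrative subset under that assumption, not the actual implementation:

```python
import numpy as np
import ml_dtypes

def triton_to_np_dtype(dtype: str):
    # Illustrative subset of the Triton-to-numpy dtype table; the real
    # function in tritonclient.utils covers every Triton datatype.
    mapping = {
        "BOOL": bool,
        "FP16": np.float16,
        "FP32": np.float32,
        "BF16": ml_dtypes.bfloat16,  # previously mapped to np.float32
    }
    return mapping.get(dtype)

# BF16 now maps to a genuine 2-byte dtype instead of a 4-byte float32.
assert np.dtype(triton_to_np_dtype("BF16")).itemsize == 2
```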

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 1 comment.

Show a summary per file

  • src/python/library/tritonclient/utils/__init__.py - Added ml_dtypes import; updated dtype conversion functions to use ml_dtypes.bfloat16 for BF16; removed the FP32 fallback
  • src/python/library/tritonclient/http/_infer_result.py - Removed special BF16 deserialization logic; updated copyright
  • src/python/library/tritonclient/http/_infer_input.py - Removed special BF16 validation and serialization logic; updated copyright
  • src/python/library/tritonclient/grpc/_infer_result.py - Removed special BF16 deserialization logic; updated copyright
  • src/python/library/tritonclient/grpc/_infer_input.py - Removed special BF16 validation and serialization logic; updated copyright
  • src/python/library/requirements/requirements.txt - Added ml_dtypes>=0.5.4 dependency; updated copyright



Copilot AI left a comment


Pull request overview

Copilot reviewed 6 out of 6 changed files in this pull request and generated 3 comments.

Comments suppressed due to low confidence (1)

src/python/library/tritonclient/http/_infer_result.py:197

  • With BF16 now handled via triton_to_np_dtype(datatype) in the generic np.frombuffer path, deserialize_bf16_tensor is no longer referenced in this method. The file still imports deserialize_bf16_tensor, which is now unused; please drop it to avoid unused imports.
                                else:
                                    np_array = np.frombuffer(
                                        self._buffer[start_index:end_index],
                                        dtype=triton_to_np_dtype(datatype),
                                    )
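The generic path in the snippet above slices a shared result buffer and hands each slice to np.frombuffer. A standalone sketch of that pattern, with illustrative names (buffer, start_index, end_index mirror the snippet; the packing is invented for the example):

```python
import numpy as np
import ml_dtypes

# Two tensors packed back to back in one buffer, as in the client's
# result parsing; the second is recovered by byte-slicing before frombuffer.
t0 = np.zeros(4, dtype=ml_dtypes.bfloat16)
t1 = np.ones(4, dtype=ml_dtypes.bfloat16)
buffer = t0.tobytes() + t1.tobytes()

start_index, end_index = 8, 16  # byte range of the second tensor
np_array = np.frombuffer(buffer[start_index:end_index], dtype=ml_dtypes.bfloat16)
assert (np_array.astype(np.float32) == 1.0).all()
```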


Comment on lines +132 to +138
dtype = np_to_triton_dtype(input_tensor.dtype)
if self._datatype != dtype:
    raise_error(
        "got unexpected datatype {} from numpy array, expected {}".format(
            dtype, self._datatype
        )
    )

Copilot AI Feb 17, 2026


This change makes BF16 inputs require input_tensor.dtype to map to "BF16" (i.e., ml_dtypes.bfloat16). Previously the client accepted float32 tensors for BF16 (via triton_to_np_dtype("BF16") mapping) and handled conversion/serialization. If backward compatibility is desired, consider allowing float32 for BF16 here (and converting to BF16 bytes) or explicitly documenting this breaking change in the client behavior.
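One way to keep the old float32 workaround alive, as the comment suggests, would be to down-convert float32 inputs to BF16 bytes before serializing. A numpy-only sketch under that assumption (the helper name is hypothetical, and it truncates toward zero, whereas a proper conversion such as ml_dtypes' rounds to nearest even):

```python
import numpy as np

def float32_to_bf16_bytes(arr: np.ndarray) -> bytes:
    # bfloat16 is the upper 16 bits of an IEEE-754 float32, so dropping
    # the low 16 mantissa bits yields a valid (round-toward-zero) BF16
    # bit pattern for each element.
    assert arr.dtype == np.float32
    return (arr.view(np.uint32) >> 16).astype(np.uint16).tobytes()

raw = float32_to_bf16_bytes(np.asarray([1.0, -2.0], dtype=np.float32))
# 1.0f = 0x3F800000 -> 0x3F80; -2.0f = 0xC0000000 -> 0xC000
```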

Comment on lines 190 to 197
if self._datatype == "BYTES":
    serialized_output = serialize_byte_tensor(input_tensor)
    if serialized_output.size > 0:
        self._raw_data = serialized_output.item()
    else:
        self._raw_data = b""
elif self._datatype == "BF16":
    serialized_output = serialize_bf16_tensor(input_tensor)
    if serialized_output.size > 0:
        self._raw_data = serialized_output.item()
    else:
        self._raw_data = b""
else:
    self._raw_data = input_tensor.tobytes()

Copilot AI Feb 17, 2026


With the BF16-specific serialization branch removed, this method now always uses input_tensor.tobytes() for non-BYTES dtypes. The file still imports serialize_bf16_tensor, which appears unused now; please remove that import to avoid unused/dead code paths.

Comment on lines +127 to +133
dtype = np_to_triton_dtype(input_tensor.dtype)
if self._input.datatype != dtype:
    raise_error(
        "got unexpected datatype {} from numpy array, expected {}".format(
            dtype, self._input.datatype
        )
    )

Copilot AI Feb 17, 2026


This change makes BF16 inputs require an actual BF16 numpy dtype (via ml_dtypes.bfloat16 mapping). Previously, BF16 inputs could be provided as float32 (since triton_to_np_dtype("BF16") mapped to np.float32) and the client handled BF16 serialization. If existing users rely on the float32 workaround, consider accepting both float32 and bfloat16 for BF16 inputs (converting float32 to BF16 bytes) or clearly documenting the breaking behavior change.


Copilot AI left a comment


Pull request overview

Copilot reviewed 7 out of 7 changed files in this pull request and generated no new comments.



Contributor

@whoisj whoisj left a comment


LGTM

@yinggeh yinggeh marked this pull request as draft March 4, 2026 21:00

Labels

enhancement New feature or request


4 participants