feat: Enable bfloat16 input/output tensor dtype in Python client #880
Conversation
Pull request overview
This PR enables native bfloat16 (BF16) tensor support in the Triton Python client by integrating the ml_dtypes library, replacing the previous workaround that used float32 as an intermediate representation. The change allows BF16 tensors to be serialized and deserialized directly using standard numpy operations like tobytes() and np.frombuffer().
Changes:
- Added ml_dtypes dependency (>=0.5.4) to support native bfloat16 dtype
- Updated dtype conversion functions to map BF16 to ml_dtypes.bfloat16 instead of np.float32
- Removed special-case serialization/deserialization logic for BF16 tensors across HTTP and gRPC clients
- Updated copyright years to 2026
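The effect of the change can be sketched with a small round trip (illustrative only, not code from the PR; `ml_dtypes` registers `bfloat16` as a regular numpy dtype, so the standard serialization path just works — variable names here are hypothetical):

```python
import numpy as np
import ml_dtypes  # registers the bfloat16 dtype with numpy

# With ml_dtypes, a BF16 tensor round-trips through tobytes() /
# np.frombuffer() like any other dtype; no float32 intermediate needed.
# All three values are exactly representable in bfloat16.
original = np.array([1.5, -2.25, 3.0], dtype=ml_dtypes.bfloat16)

raw = original.tobytes()  # 2 bytes per element
assert len(raw) == original.size * 2

restored = np.frombuffer(raw, dtype=ml_dtypes.bfloat16).reshape(original.shape)
assert np.array_equal(restored, original)
```

This is the generic path the HTTP and gRPC clients now take for BF16, where previously a special-case `serialize_bf16_tensor` / `deserialize_bf16_tensor` pair was required.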
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| src/python/library/tritonclient/utils/__init__.py | Added ml_dtypes import; updated dtype conversion functions to use ml_dtypes.bfloat16 for BF16; removed FP32 fallback |
| src/python/library/tritonclient/http/_infer_result.py | Removed special BF16 deserialization logic; updated copyright |
| src/python/library/tritonclient/http/_infer_input.py | Removed special BF16 validation and serialization logic; updated copyright |
| src/python/library/tritonclient/grpc/_infer_result.py | Removed special BF16 deserialization logic; updated copyright |
| src/python/library/tritonclient/grpc/_infer_input.py | Removed special BF16 validation and serialization logic; updated copyright |
| src/python/library/requirements/requirements.txt | Added ml_dtypes>=0.5.4 dependency; updated copyright |
Pull request overview
Copilot reviewed 6 out of 6 changed files in this pull request and generated 3 comments.
Comments suppressed due to low confidence (1)
src/python/library/tritonclient/http/_infer_result.py:197
- With BF16 now handled via `triton_to_np_dtype(datatype)` in the generic `np.frombuffer` path, `deserialize_bf16_tensor` is no longer referenced in this method. The file still imports `deserialize_bf16_tensor`, which is now unused; please drop it to avoid unused imports.
else:
np_array = np.frombuffer(
self._buffer[start_index:end_index],
dtype=triton_to_np_dtype(datatype),
)
    dtype = np_to_triton_dtype(input_tensor.dtype)
    if self._datatype != dtype:
        raise_error(
            "got unexpected datatype {} from numpy array, expected {}".format(
                dtype, self._datatype
            )
        )
This change makes BF16 inputs require input_tensor.dtype to map to "BF16" (i.e., ml_dtypes.bfloat16). Previously the client accepted float32 tensors for BF16 (via triton_to_np_dtype("BF16") mapping) and handled conversion/serialization. If backward compatibility is desired, consider allowing float32 for BF16 here (and converting to BF16 bytes) or explicitly documenting this breaking change in the client behavior.
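A minimal sketch of the backward-compatible acceptance the comment suggests, using only numpy (function name and approach are hypothetical, not part of the PR): bfloat16 shares float32's sign and exponent layout, so the top 16 bits of each float32 word are a valid BF16 encoding. Note this keeps the truncated mantissa as-is, whereas `astype(ml_dtypes.bfloat16)` would round to nearest-even.

```python
import numpy as np

def float32_to_bf16_bytes(arr: np.ndarray) -> bytes:
    """Hypothetical compat helper: serialize a float32 array as raw BF16
    bytes by keeping the upper 16 bits of each float32 word (truncation,
    not round-to-nearest-even). Sketch only."""
    if arr.dtype != np.float32:
        raise TypeError(f"expected float32, got {arr.dtype}")
    # float32 layout: 1 sign + 8 exponent + 23 mantissa bits.
    # bfloat16 is the same layout with the mantissa cut to 7 bits,
    # i.e. exactly the top half of the float32 word.
    bits = arr.view(np.uint32)
    return (bits >> 16).astype(np.uint16).tobytes()

# Values whose mantissas already fit in 7 bits survive unchanged.
payload = float32_to_bf16_bytes(np.array([1.0, -2.5, 0.15625], dtype=np.float32))
```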
    if self._datatype == "BYTES":
        serialized_output = serialize_byte_tensor(input_tensor)
        if serialized_output.size > 0:
            self._raw_data = serialized_output.item()
        else:
            self._raw_data = b""
    elif self._datatype == "BF16":
        serialized_output = serialize_bf16_tensor(input_tensor)
        if serialized_output.size > 0:
            self._raw_data = serialized_output.item()
        else:
            self._raw_data = b""
    else:
        self._raw_data = input_tensor.tobytes()
With the BF16-specific serialization branch removed, this method now always uses input_tensor.tobytes() for non-BYTES dtypes. The file still imports serialize_bf16_tensor, which appears unused now; please remove that import to avoid unused/dead code paths.
    dtype = np_to_triton_dtype(input_tensor.dtype)
    if self._input.datatype != dtype:
        raise_error(
            "got unexpected datatype {} from numpy array, expected {}".format(
                dtype, self._input.datatype
            )
        )
This change makes BF16 inputs require an actual BF16 numpy dtype (via ml_dtypes.bfloat16 mapping). Previously, BF16 inputs could be provided as float32 (since triton_to_np_dtype("BF16") mapped to np.float32) and the client handled BF16 serialization. If existing users rely on the float32 workaround, consider accepting both float32 and bfloat16 for BF16 inputs (converting float32 to BF16 bytes) or clearly documenting the breaking behavior change.
Pull request overview
Copilot reviewed 7 out of 7 changed files in this pull request and generated no new comments.
Drop support for type conversions between BF16 and FLOAT32. Users should use `ml_dtypes.bfloat16` instead.
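A sketch of the caller-side conversion this implies (illustrative; the variable names and the commented client call are hypothetical):

```python
import numpy as np
import ml_dtypes

# Explicit, caller-side conversion replaces the old implicit
# FLOAT32 <-> BF16 handling inside the client.
fp32 = np.array([0.5, 1.5, 3.25], dtype=np.float32)
bf16 = fp32.astype(ml_dtypes.bfloat16)  # round-to-nearest-even

assert bf16.dtype == ml_dtypes.bfloat16
# infer_input.set_data_from_numpy(bf16)  # client serializes BF16 directly
```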