Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion model_analyzer/config/input/config_command_profile.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2021-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import argparse
Expand Down Expand Up @@ -317,6 +317,32 @@ def _fill_config(self):
)
)

def _preprocess_triton_http_headers(value):
    """
    Convert the raw triton_http_headers value into a dict of headers.

    Parameters
    ----------
    value : str or any
        A JSON string (as received from the CLI) or an already-parsed
        value. Non-string values are returned unchanged.

    Returns
    -------
    dict or any
        The decoded JSON object when `value` is a string, otherwise
        `value` itself.

    Raises
    ------
    TritonModelAnalyzerException
        If `value` is a string that is not valid JSON, or is valid JSON
        that does not decode to an object (key-value pairs).
    """
    if not isinstance(value, str):
        return value

    import json

    try:
        headers = json.loads(value)
    except json.JSONDecodeError as e:
        # Chain the cause so the original parse position stays in the traceback
        raise TritonModelAnalyzerException(
            f"Failed to parse triton_http_headers as JSON: {e}"
        ) from e

    # Valid JSON is not enough: headers must be a mapping, so reject
    # lists, strings, numbers, etc. with an explicit message.
    if not isinstance(headers, dict):
        raise TritonModelAnalyzerException(
            "Failed to parse triton_http_headers as JSON: expected a JSON "
            f"object of key-value pairs, got {type(headers).__name__}"
        )
    return headers

self._add_config(
ConfigField(
"triton_http_headers",
flags=["--triton-http-headers"],
field_type=ConfigObject(
schema={"*": ConfigPrimitive(str)},
preprocess=_preprocess_triton_http_headers,
),
default_value={},
description="HTTP headers to send to Triton Server (key-value pairs)",
)
)

self._add_repository_configs()
self._add_client_configs()
self._add_profile_models_configs()
Expand Down
6 changes: 6 additions & 0 deletions model_analyzer/config/input/config_object.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
#!/usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

# Copyright 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
Expand Down Expand Up @@ -78,6 +80,10 @@ def set_value(self, value):
1 on success, and 0 on failure
"""

# Apply preprocessing if defined (e.g., to parse JSON strings)
if self._preprocess:
value = self._preprocess(value)

new_value = {}
schema = self._schema

Expand Down
6 changes: 5 additions & 1 deletion model_analyzer/entrypoint.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
#!/usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

# Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
Expand Down Expand Up @@ -50,7 +52,9 @@ def get_client_handle(config):
if config.client_protocol == "http":
http_ssl_options = get_http_ssl_options(config)
client = TritonClientFactory.create_http_client(
server_url=config.triton_http_endpoint, ssl_options=http_ssl_options
server_url=config.triton_http_endpoint,
ssl_options=http_ssl_options,
headers=config.triton_http_headers,
)
elif config.client_protocol == "grpc":
grpc_ssl_options = get_grpc_ssl_options(config)
Expand Down
18 changes: 12 additions & 6 deletions model_analyzer/triton/client/client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
#!/usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

# Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
Expand Down Expand Up @@ -57,7 +59,7 @@ def wait_for_server_ready(
retries = num_retries
while retries > 0:
try:
if self._client.is_server_ready():
if self._client.is_server_ready(headers=self._headers):
time.sleep(sleep_time)
return
else:
Expand Down Expand Up @@ -105,7 +107,9 @@ def load_model(self, model_name, variant_name="", config_str=None):
variant_name = variant_name if variant_name else model_name

try:
self._client.load_model(model_name, config=config_str)
self._client.load_model(
model_name, config=config_str, headers=self._headers
)
logger.debug(f"Model {variant_name} loaded")
return None
except Exception as e:
Expand Down Expand Up @@ -139,7 +143,7 @@ def unload_model(self, model_name):
"""

try:
self._client.unload_model(model_name)
self._client.unload_model(model_name, headers=self._headers)
logger.debug(f"Model {model_name} unloaded")
return None
except Exception as e:
Expand Down Expand Up @@ -175,7 +179,7 @@ def wait_for_model_ready(self, model_name, num_retries, sleep_time=1):
error = None
while retries > 0:
try:
if self._client.is_model_ready(model_name):
if self._client.is_model_ready(model_name, headers=self._headers):
return None
else:
time.sleep(sleep_time)
Expand Down Expand Up @@ -207,14 +211,16 @@ def get_model_config(self, model_name, num_retries):
"""

self.wait_for_model_ready(model_name, num_retries)
model_config_dict = self._client.get_model_config(model_name)
model_config_dict = self._client.get_model_config(
model_name, headers=self._headers
)
return model_config_dict

def is_server_ready(self):
    """
    Returns true if the server is ready. Else False

    The headers stored on this client (self._headers) are forwarded to
    the underlying client's readiness check.
    """
    # Single header-aware call; the earlier header-less return made
    # this line unreachable.
    return self._client.is_server_ready(headers=self._headers)

def _check_for_triton_log_errors(self, log_file):
if not log_file or log_file == DEVNULL:
Expand Down
10 changes: 8 additions & 2 deletions model_analyzer/triton/client/client_factory.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
#!/usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

# Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
Expand Down Expand Up @@ -41,17 +43,21 @@ def create_grpc_client(server_url, ssl_options={}):
return TritonGRPCClient(server_url=server_url, ssl_options=ssl_options)

@staticmethod
def create_http_client(server_url, ssl_options=None, headers=None):
    """
    Parameters
    ----------
    server_url : str
        The url for Triton server's HTTP endpoint
    ssl_options : dict
        Dictionary of SSL options for HTTP python client.
        Defaults to an empty dict.
    headers : dict
        Dictionary of HTTP headers to send to Triton Server.
        Defaults to an empty dict.

    Returns
    -------
    TritonHTTPClient
    """
    # Build fresh dicts per call: a `{}` default in the signature is a
    # single shared object, and the headers dict is retained by the
    # client, so mutations would leak between clients.
    ssl_options = {} if ssl_options is None else ssl_options
    headers = {} if headers is None else headers

    return TritonHTTPClient(
        server_url=server_url, ssl_options=ssl_options, headers=headers
    )
4 changes: 4 additions & 0 deletions model_analyzer/triton/client/grpc_client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
#!/usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

# Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
Expand Down Expand Up @@ -69,6 +71,8 @@ def __init__(self, server_url, ssl_options={}):
channel_args=channel_args,
)

self._headers = {}

def get_model_config(self, model_name, num_retries):
"""
Model name to get the config for.
Expand Down
12 changes: 9 additions & 3 deletions model_analyzer/triton/client/http_client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
#!/usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

# Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
Expand Down Expand Up @@ -28,14 +30,16 @@ class TritonHTTPClient(TritonClient):
for HTTP
"""

def __init__(self, server_url, ssl_options={}):
def __init__(self, server_url, ssl_options={}, headers={}):
"""
Parameters
----------
server_url : str
The url for Triton server's HTTP endpoint
ssl_options : dict
Dictionary of SSL options for HTTP python client
headers : dict
Dictionary of HTTP headers to send to Triton Server
"""

ssl = False
Expand Down Expand Up @@ -86,6 +90,8 @@ def __init__(self, server_url, ssl_options={}):
ssl_context_factory = None
insecure = False

self._headers = headers

self._client = httpclient.InferenceServerClient(
url=server_url,
ssl=ssl,
Expand All @@ -98,10 +104,10 @@ def get_model_repository_index(self):
"""
Returns the JSON dict holding the model repository index.
"""
return self._client.get_model_repository_index()
return self._client.get_model_repository_index(headers=self._headers)

def is_model_ready(self, model_name: str) -> bool:
    """
    Returns true if the model is loaded on the server

    Parameters
    ----------
    model_name : str
        Name of the model to query.

    The headers stored on this client (self._headers) are forwarded
    with the request.
    """
    # Single header-aware call; the earlier header-less return made
    # this line unreachable.
    return self._client.is_model_ready(model_name, headers=self._headers)
74 changes: 73 additions & 1 deletion tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import copy
Expand Down Expand Up @@ -235,6 +235,15 @@ def get_test_options():
"local",
"SHOULD_FAIL",
),
OptionStruct(
"dict",
"profile",
"--triton-http-headers",
None,
'{"Authorization": "Bearer token123"}',
"{}",
None,
),
OptionStruct(
"string",
"profile",
Expand Down Expand Up @@ -574,6 +583,8 @@ def test_all_options(self):
self._test_string_option(option)
elif option.type in ["intlist", "stringlist"]:
self._test_list_option(option)
elif option.type == "dict":
self._test_dict_option(option)
elif option.type in ["noop"]:
pass
else:
Expand Down Expand Up @@ -765,6 +776,62 @@ def _test_list_option(self, option_struct):
expected_default_value_converted,
)

def _test_dict_option(self, option_struct):
    """
    Exercise a CLI option whose value is a JSON string parsed into a dict.

    Checks, in order: the long flag, the short flag (delegated to
    _test_short_flag), the default value when one is specified, and
    finally that the failing value (when one is specified) makes
    parsing raise TritonModelAnalyzerException.
    """
    flag = option_struct.long_flag
    short_flag = option_struct.short_flag
    raw_value = option_struct.expected_value
    raw_default = option_struct.expected_default_value
    bad_value = option_struct.expected_failing_value
    extras = option_struct.extra_commands

    has_default = raw_default is not None

    # The option struct carries JSON text; the parsed config holds dicts,
    # so decode before comparing.
    parsed_value = self._convert_json_string_to_dict(raw_value)
    parsed_default = (
        self._convert_json_string_to_dict(raw_default) if has_default else None
    )

    config_key = self._convert_flag_to_use_underscores(flag)

    # Long flag
    self._test_long_flag(
        flag,
        option_struct.cli_subcommand,
        raw_value,
        config_key,
        parsed_value,
        extras,
    )

    # Short flag (if it exists)
    self._test_short_flag(
        short_flag,
        option_struct.cli_subcommand,
        raw_value,
        config_key,
        parsed_value,
    )

    # Default value
    if has_default:
        self._test_expected_default_value(
            option_struct.cli_subcommand,
            config_key,
            parsed_default,
        )

    # Nothing more to do when no failing value is supplied
    if bad_value is None:
        return

    # A failing value must make parsing raise
    cli = option_struct.cli_subcommand()
    cli.args.extend([flag, bad_value])
    with self.assertRaises(TritonModelAnalyzerException):
        _, config = cli.parse()

# Helper methods

def _test_long_flag(
Expand Down Expand Up @@ -827,6 +894,11 @@ def _convert_string_to_string_list(self, list_values):
return ret_val[0]
return ret_val

def _convert_json_string_to_dict(self, json_string):
    """Decode a JSON document string and return the resulting Python object."""
    from json import loads as parse_json

    return parse_json(json_string)


if __name__ == "__main__":
unittest.main()
Loading