Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions src/python/examples/simple_grpc_aio_infer_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,14 @@
import argparse
import asyncio
import sys
from typing import List, Tuple

import numpy as np
import numpy.typing as npt
import tritonclient.grpc.aio as grpcclient


def get_triton_client(FLAGS):
def get_triton_client(FLAGS: argparse.Namespace) -> grpcclient.InferenceServerClient:
try:
triton_client = grpcclient.InferenceServerClient(
url=FLAGS.url,
Expand All @@ -50,7 +52,12 @@ def get_triton_client(FLAGS):
return triton_client


def get_inputs_and_outputs():
def get_inputs_and_outputs() -> Tuple[
List[grpcclient.InferInput],
List[grpcclient.InferRequestedOutput],
npt.NDArray[np.int32],
npt.NDArray[np.int32],
]:
# Infer
inputs = []
outputs = []
Expand All @@ -73,7 +80,11 @@ def get_inputs_and_outputs():
return inputs, outputs, input0_data, input1_data


def get_output(input0_data, input1_data, results):
def get_output(
input0_data: npt.NDArray[np.int32],
input1_data: npt.NDArray[np.int32],
results: grpcclient.InferResult,
) -> None:
# Get the output arrays from the results
output0_data = results.as_numpy("OUTPUT0")
output1_data = results.as_numpy("OUTPUT1")
Expand Down Expand Up @@ -102,7 +113,7 @@ def get_output(input0_data, input1_data, results):
sys.exit(1)


async def main(FLAGS):
async def main(FLAGS: argparse.Namespace) -> None:
# Initialize
triton_client = get_triton_client(FLAGS)
model_name = "simple"
Expand Down