-
Notifications
You must be signed in to change notification settings - Fork 183
Expand file tree
/
Copy pathopenai_agents_tools.py
More file actions
97 lines (79 loc) · 3.03 KB
/
openai_agents_tools.py
File metadata and controls
97 lines (79 loc) · 3.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import asyncio
import logging
import os
import random
from datetime import datetime
from agents import Agent, OpenAIResponsesModel, Runner, function_tool, set_tracing_disabled
from azure.identity.aio import DefaultAzureCredential, get_bearer_token_provider
from dotenv import load_dotenv
from openai import AsyncOpenAI
from rich.logging import RichHandler
# Send log output through rich's handler; the root level of WARNING hides
# INFO records unless a logger's own level is lowered.
logging.basicConfig(
    level=logging.WARNING,
    format="%(message)s",
    datefmt="[%X]",
    handlers=[RichHandler()],
)
logger = logging.getLogger("weekend_planner")

# No supported tracing provider is configured, so keep the SDK's tracing off.
set_tracing_disabled(disabled=True)
# Set up the OpenAI-compatible client for Azure OpenAI, Ollama, or OpenAI.com.
# API_HOST selects the host: "azure" (the default), "ollama", or any other
# value for OpenAI.com. Values from a .env file override the environment.
load_dotenv(override=True)
API_HOST = os.getenv("API_HOST", "azure")

# Only the Azure branch creates a credential; main() closes it when it is set.
async_credential = None
if API_HOST == "azure":
    async_credential = DefaultAzureCredential()
    # api_key receives a token-provider callable rather than a static key.
    token_provider = get_bearer_token_provider(async_credential, "https://cognitiveservices.azure.com/.default")
    # Strip a trailing "/" so an endpoint such as "https://x.openai.azure.com/"
    # does not produce a "//openai/v1" base URL.
    client = AsyncOpenAI(
        base_url=os.environ["AZURE_OPENAI_ENDPOINT"].rstrip("/") + "/openai/v1",
        api_key=token_provider,
    )
    MODEL_NAME = os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT"]
elif API_HOST == "ollama":
    # "none" is a placeholder key (presumably ignored by Ollama's server).
    client = AsyncOpenAI(base_url=os.environ.get("OLLAMA_ENDPOINT", "http://localhost:11434/v1"), api_key="none")
    MODEL_NAME = os.environ["OLLAMA_MODEL"]
else:
    client = AsyncOpenAI(api_key=os.environ["OPENAI_API_KEY"])
    MODEL_NAME = os.environ.get("OPENAI_MODEL", "gpt-4o")
@function_tool
def get_weather(city: str) -> dict:
    """Return weather information for a city.

    Args:
        city: Name of the city to get the weather for.

    Returns:
        A dict with ``city``, ``temperature`` and ``description`` keys.
    """
    logger.info("Getting weather for %s", city)
    # Demo data, not a real forecast: roughly 5% of calls report sunny
    # weather and the rest report rain.
    if random.random() < 0.05:
        return {"city": city, "temperature": 72, "description": "Sunny"}
    return {"city": city, "temperature": 60, "description": "Rainy"}
@function_tool
def get_activities(city: str, date: str) -> list[dict]:
    """Return activities available in a city on a given date.

    Args:
        city: Name of the city to find activities in.
        date: The date to find activities for, such as a YYYY-MM-DD string.

    Returns:
        A list of dicts, each with ``name`` and ``location`` keys.
    """
    # Demo data: the same three activities are returned for every city and
    # date; ``date`` is currently used only in the log message.
    logger.info("Getting activities for %s on %s", city, date)
    return [
        {"name": "Hiking", "location": city},
        {"name": "Beach", "location": city},
        {"name": "Museum", "location": city},
    ]
@function_tool
def get_current_date() -> str:
    """Gets the current date and returns as a string in format YYYY-MM-DD."""
    # The docstring above is the tool description sent to the model, so it is
    # kept verbatim. The date comes from the local clock (naive datetime).
    logger.info("Getting current date")
    now = datetime.now()
    return f"{now:%Y-%m-%d}"
# The planner agent: the model may call the three tools defined above.
agent = Agent(
    name="Weekend Planner",
    # Adjacent literals are concatenated, so each sentence ends with a space
    # to keep the sentences from running together in the prompt.
    instructions=(
        "You help users plan their weekends and choose the best activities for the given weather. "
        "If an activity would be unpleasant in the weather, don't suggest it. "
        "Include the date of the weekend in your response."
    ),
    tools=[get_weather, get_activities, get_current_date],
    model=OpenAIResponsesModel(model=MODEL_NAME, openai_client=client),
)
async def main():
    """Run the planner once on a sample question and print its final answer.

    The Azure credential, if one was created, is closed even when the run
    raises.
    """
    try:
        result = await Runner.run(agent, input="hii what can I do this weekend in Seattle?")
        print(result.final_output)
    finally:
        if async_credential:
            await async_credential.close()
if __name__ == "__main__":
    # Lower this module's logger to INFO so the tools' log lines are shown
    # when the file is run as a script.
    logger.setLevel(logging.INFO)
    asyncio.run(main())