diff --git a/agentrun/__init__.py b/agentrun/__init__.py index 316a2e5..24961ac 100644 --- a/agentrun/__init__.py +++ b/agentrun/__init__.py @@ -109,6 +109,10 @@ SandboxClient, Template, ) +# Tool +from agentrun.tool import Tool as ToolResource +from agentrun.tool import ToolClient as ToolResourceClient +from agentrun.tool import ToolControlAPI as ToolResourceControlAPI # ToolSet from agentrun.toolset import ToolSet, ToolSetClient from agentrun.utils.config import Config @@ -247,6 +251,10 @@ "AioSandbox", "CustomSandbox", "Template", + ######## Tool ######## + "ToolResource", + "ToolResourceClient", + "ToolResourceControlAPI", ######## ToolSet ######## "ToolSetClient", "ToolSet", diff --git a/agentrun/integration/agentscope/__init__.py b/agentrun/integration/agentscope/__init__.py index d9e108f..91ba476 100644 --- a/agentrun/integration/agentscope/__init__.py +++ b/agentrun/integration/agentscope/__init__.py @@ -3,11 +3,20 @@ 提供 AgentRun 模型与沙箱工具的 AgentScope 适配入口。 / 提供 AgentRun 模型with沙箱工具的 AgentScope 适配入口。 """ -from .builtin import knowledgebase_toolset, model, sandbox_toolset, toolset +from .builtin import ( + knowledgebase_toolset, + model, + sandbox_toolset, + skill_tools, + tool_resource, + toolset, +) __all__ = [ "model", "toolset", "sandbox_toolset", "knowledgebase_toolset", + "tool_resource", + "skill_tools", ] diff --git a/agentrun/integration/agentscope/builtin.py b/agentrun/integration/agentscope/builtin.py index 1a94e7f..17f7aff 100644 --- a/agentrun/integration/agentscope/builtin.py +++ b/agentrun/integration/agentscope/builtin.py @@ -14,10 +14,13 @@ from agentrun.integration.builtin import model as _model from agentrun.integration.builtin import ModelArgs from agentrun.integration.builtin import sandbox_toolset as _sandbox_toolset +from agentrun.integration.builtin import skill_tools as _skill_tools +from agentrun.integration.builtin import tool_resource as _tool_resource from agentrun.integration.builtin import toolset as _toolset from agentrun.integration.utils.tool import Tool from agentrun.model import ModelProxy, ModelService from agentrun.sandbox import TemplateType +from agentrun.tool.tool import Tool as ToolResourceType from agentrun.toolset import ToolSet from agentrun.utils.config import Config @@ -50,6 +53,24 @@ def toolset( ) +def tool_resource( + name: Union[str, ToolResourceType], + *, + prefix: Optional[str] = None, + modify_tool_name: Optional[Callable[[Tool], Tool]] = None, + filter_tools_by_name: Optional[Callable[[str], bool]] = None, + config: Optional[Config] = None, +) -> List[Any]: + """将 ToolResource 封装为 AgentScope 工具列表。 / AgentScope Built-in ToolResource Integration""" + + ts = _tool_resource(input=name, config=config) + return ts.to_agentscope( + prefix=prefix, + modify_tool_name=modify_tool_name, + filter_tools_by_name=filter_tools_by_name, + ) + + def sandbox_toolset( template_name: str, *, @@ -86,3 +107,18 @@ def knowledgebase_toolset( modify_tool_name=modify_tool_name, filter_tools_by_name=filter_tools_by_name, ) + + +def skill_tools( + name: Optional[Union[str, List[str]]] = None, + *, + skills_dir: str = ".skills", + prefix: Optional[str] = None, + config: Optional[Config] = None, +) -> List[Any]: + """将 Skill 封装为 AgentScope 工具列表。 / AgentScope Built-in Skill Integration""" + + ts = _skill_tools(name=name, skills_dir=skills_dir, config=config) + return ts.to_agentscope( + prefix=prefix, + ) diff --git a/agentrun/integration/builtin/__init__.py b/agentrun/integration/builtin/__init__.py index 49f4258..8b7f084 100644 --- a/agentrun/integration/builtin/__init__.py +++ b/agentrun/integration/builtin/__init__.py @@ -7,12 +7,16 @@ from .knowledgebase import knowledgebase_toolset from .model import model, ModelArgs from .sandbox import sandbox_toolset +from .skill import skill_tools +from .tool_resource import tool_resource from .toolset import toolset __all__ = [ "model", "ModelArgs", "toolset", + "tool_resource", "sandbox_toolset", "knowledgebase_toolset", + "skill_tools", ] diff --git a/agentrun/integration/builtin/skill.py b/agentrun/integration/builtin/skill.py new file mode 100644 index 0000000..d0f057c --- /dev/null +++ b/agentrun/integration/builtin/skill.py @@ -0,0 +1,16 @@ +"""内置 Skill 集成函数 / Built-in Skill Integration Functions + +提供快速创建 Skill 工具集对象的便捷函数。 +Provides convenient functions for quickly creating Skill toolset objects. +""" + +from typing import List, Optional, Union + +from agentrun.integration.utils.skill_loader import skill_tools as _skill_tools +from agentrun.integration.utils.tool import CommonToolSet +from agentrun.utils.config import Config + +# Re-export for convenience +skill_tools = _skill_tools + +__all__ = ["skill_tools"] diff --git a/agentrun/integration/builtin/tool_resource.py b/agentrun/integration/builtin/tool_resource.py new file mode 100644 index 0000000..18d01b8 --- /dev/null +++ b/agentrun/integration/builtin/tool_resource.py @@ -0,0 +1,48 @@ +"""内置 ToolResource 集成函数 / Built-in ToolResource Integration Functions + +提供快速创建通用工具集对象的便捷函数(基于新版 Tool 模块)。 +Provides convenient functions for quickly creating common toolset objects (based on new Tool module). +""" + +from typing import Optional, Union + +from agentrun.integration.utils.tool import CommonToolSet +from agentrun.tool.client import ToolClient +from agentrun.tool.tool import Tool as ToolResourceType +from agentrun.utils.config import Config + + +def tool_resource( + input: Union[str, ToolResourceType], config: Optional[Config] = None +) -> CommonToolSet: + """将 ToolResource 封装为通用工具集 / Wrap ToolResource as CommonToolSet + + 支持从工具名称或 ToolResource 实例创建通用工具集。 + Supports creating CommonToolSet from tool name or ToolResource instance. + + Args: + input: 工具名称或 ToolResource 实例 / Tool name or ToolResource instance + config: 配置对象 / Configuration object + + Returns: + CommonToolSet: 通用工具集实例 / CommonToolSet instance + + Examples: + >>> # 从工具名称创建 / Create from tool name + >>> ts = tool_resource("my-tool") + >>> + >>> # 从 ToolResource 实例创建 / Create from ToolResource instance + >>> tool = ToolClient().get(name="my-tool") + >>> ts = tool_resource(tool) + >>> + >>> # 转换为 LangChain 工具 / Convert to LangChain tools + >>> lc_tools = ts.to_langchain() + """ + + resource = ( + input + if isinstance(input, ToolResourceType) + else ToolClient().get(name=input, config=config) + ) + + return CommonToolSet.from_agentrun_tool(resource, config=config) diff --git a/agentrun/integration/crewai/__init__.py b/agentrun/integration/crewai/__init__.py index 46ab61d..f2e581e 100644 --- a/agentrun/integration/crewai/__init__.py +++ b/agentrun/integration/crewai/__init__.py @@ -4,10 +4,18 @@ CrewAI 与 LangChain 兼容,因此直接复用 LangChain 的转换逻辑。 / CrewAI with LangChain 兼容,因此直接复用 LangChain 的转换逻辑。 """ -from .builtin import knowledgebase_toolset, model, sandbox_toolset +from .builtin import ( + knowledgebase_toolset, + model, + sandbox_toolset, + skill_tools, + tool_resource, +) __all__ = [ "model", "sandbox_toolset", "knowledgebase_toolset", + "tool_resource", + "skill_tools", ] diff --git a/agentrun/integration/crewai/builtin.py b/agentrun/integration/crewai/builtin.py index beda5a7..1c8aadb 100644 --- a/agentrun/integration/crewai/builtin.py +++ b/agentrun/integration/crewai/builtin.py @@ -14,10 +14,13 @@ from agentrun.integration.builtin import model as _model from agentrun.integration.builtin import ModelArgs from agentrun.integration.builtin import sandbox_toolset as _sandbox_toolset +from agentrun.integration.builtin import skill_tools as _skill_tools +from agentrun.integration.builtin import tool_resource as _tool_resource from agentrun.integration.builtin import toolset as _toolset from agentrun.integration.utils.tool import Tool from agentrun.model import ModelProxy, ModelService from agentrun.sandbox import TemplateType +from agentrun.tool.tool import Tool as ToolResourceType from agentrun.toolset import ToolSet from agentrun.utils.config import Config @@ -50,6 +53,24 @@ def toolset( ) +def tool_resource( + name: Union[str, ToolResourceType], + *, + prefix: Optional[str] = None, + modify_tool_name: Optional[Callable[[Tool], Tool]] = None, + filter_tools_by_name: Optional[Callable[[str], bool]] = None, + config: Optional[Config] = None, +) -> List[Any]: + """将 ToolResource 封装为 CrewAI 工具列表。 / CrewAI Built-in ToolResource Integration""" + + ts = _tool_resource(input=name, config=config) + return ts.to_crewai( + prefix=prefix, + modify_tool_name=modify_tool_name, + filter_tools_by_name=filter_tools_by_name, + ) + + def sandbox_toolset( template_name: str, *, @@ -86,3 +107,18 @@ def knowledgebase_toolset( modify_tool_name=modify_tool_name, filter_tools_by_name=filter_tools_by_name, ) + + +def skill_tools( + name: Optional[Union[str, List[str]]] = None, + *, + skills_dir: str = ".skills", + prefix: Optional[str] = None, + config: Optional[Config] = None, +) -> List[Any]: + """将 Skill 封装为 CrewAI 工具列表。 / CrewAI Built-in Skill Integration""" + + ts = _skill_tools(name=name, skills_dir=skills_dir, config=config) + return ts.to_crewai( + prefix=prefix, + ) diff --git a/agentrun/integration/google_adk/__init__.py b/agentrun/integration/google_adk/__init__.py index 372f64d..fad29ba 100644 --- a/agentrun/integration/google_adk/__init__.py +++ b/agentrun/integration/google_adk/__init__.py @@ -3,11 +3,20 @@ 提供与 Google Agent Development Kit 的模型与沙箱工具集成。 / 提供with Google Agent Development Kit 的模型with沙箱工具集成。 """ -from .builtin import knowledgebase_toolset, model, sandbox_toolset, toolset +from .builtin import ( + knowledgebase_toolset, + model, + sandbox_toolset, + skill_tools, + tool_resource, + toolset, +) __all__ = [ "model", "toolset", "sandbox_toolset", "knowledgebase_toolset", + "tool_resource", + "skill_tools", ] diff --git a/agentrun/integration/google_adk/builtin.py b/agentrun/integration/google_adk/builtin.py index e655f8f..9622565 100644 --- a/agentrun/integration/google_adk/builtin.py +++ b/agentrun/integration/google_adk/builtin.py @@ -14,10 +14,13 @@ from agentrun.integration.builtin import model as _model from agentrun.integration.builtin import ModelArgs from agentrun.integration.builtin import sandbox_toolset as _sandbox_toolset +from agentrun.integration.builtin import skill_tools as _skill_tools +from agentrun.integration.builtin import tool_resource as _tool_resource from agentrun.integration.builtin import toolset as _toolset from agentrun.integration.utils.tool import Tool from agentrun.model import ModelProxy, ModelService from agentrun.sandbox import TemplateType +from agentrun.tool.tool import Tool as ToolResourceType from agentrun.toolset import ToolSet from agentrun.utils.config import Config @@ -50,6 +53,24 @@ def toolset( ) +def tool_resource( + name: Union[str, ToolResourceType], + *, + prefix: Optional[str] = None, + modify_tool_name: Optional[Callable[[Tool], Tool]] = None, + filter_tools_by_name: Optional[Callable[[str], bool]] = None, + config: Optional[Config] = None, +) -> List[Any]: + """将 ToolResource 封装为 Google ADK 工具列表。 / Google ADK Built-in ToolResource Integration""" + + ts = _tool_resource(input=name, config=config) + return ts.to_google_adk( + prefix=prefix, + modify_tool_name=modify_tool_name, + filter_tools_by_name=filter_tools_by_name, + ) + + def sandbox_toolset( template_name: str, *, @@ -86,3 +107,18 @@ def knowledgebase_toolset( modify_tool_name=modify_tool_name, filter_tools_by_name=filter_tools_by_name, ) + + +def skill_tools( + name: Optional[Union[str, List[str]]] = None, + *, + skills_dir: str = ".skills", + prefix: Optional[str] = None, + config: Optional[Config] = None, +) -> List[Any]: + """将 Skill 封装为 Google ADK 工具列表。 / Google ADK Built-in Skill Integration""" + + ts = _skill_tools(name=name, skills_dir=skills_dir, config=config) + return ts.to_google_adk( + prefix=prefix, + ) diff --git a/agentrun/integration/langchain/__init__.py b/agentrun/integration/langchain/__init__.py index 3e48086..9cad7e6 100644 --- a/agentrun/integration/langchain/__init__.py +++ b/agentrun/integration/langchain/__init__.py @@ -20,7 +20,14 @@ AgentRunConverter, ) # 向后兼容 -from .builtin import knowledgebase_toolset, model, sandbox_toolset, toolset +from .builtin import ( + knowledgebase_toolset, + model, + sandbox_toolset, + skill_tools, + tool_resource, + toolset, +) __all__ = [ "AgentRunConverter", @@ -28,4 +35,6 @@ "toolset", "sandbox_toolset", "knowledgebase_toolset", + "tool_resource", + "skill_tools", ] diff --git a/agentrun/integration/langchain/builtin.py b/agentrun/integration/langchain/builtin.py index c18e479..9c6b9ab 100644 --- a/agentrun/integration/langchain/builtin.py +++ b/agentrun/integration/langchain/builtin.py @@ -14,10 +14,13 @@ from agentrun.integration.builtin import model as _model from agentrun.integration.builtin import ModelArgs from agentrun.integration.builtin import sandbox_toolset as _sandbox_toolset +from agentrun.integration.builtin import skill_tools as _skill_tools +from agentrun.integration.builtin import tool_resource as _tool_resource from agentrun.integration.builtin import toolset as _toolset from agentrun.integration.utils.tool import Tool from agentrun.model import ModelProxy, ModelService from agentrun.sandbox import TemplateType +from agentrun.tool.tool import Tool as ToolResourceType from agentrun.toolset import ToolSet from agentrun.utils.config import Config @@ -50,6 +53,24 @@ def toolset( ) +def tool_resource( + name: Union[str, ToolResourceType], + *, + prefix: Optional[str] = None, + modify_tool_name: Optional[Callable[[Tool], Tool]] = None, + filter_tools_by_name: Optional[Callable[[str], bool]] = None, + config: Optional[Config] = None, +) -> List[Any]: + """将 ToolResource 封装为 LangChain ``StructuredTool`` 列表。 / LangChain Built-in ToolResource Integration""" + + ts = _tool_resource(input=name, config=config) + return ts.to_langchain( + prefix=prefix, + modify_tool_name=modify_tool_name, + filter_tools_by_name=filter_tools_by_name, + ) + + def sandbox_toolset( template_name: str, *, @@ -92,3 +113,22 @@ def knowledgebase_toolset( modify_tool_name=modify_tool_name, filter_tools_by_name=filter_tools_by_name, ) + + +def skill_tools( + name: Optional[Union[str, List[str]]] = None, + *, + skills_dir: str = ".skills", + prefix: Optional[str] = None, + modify_tool_name: Optional[Callable[[Tool], Tool]] = None, + filter_tools_by_name: Optional[Callable[[str], bool]] = None, + config: Optional[Config] = None, +) -> List[Any]: + """将 Skill 封装为 LangChain ``StructuredTool`` 列表。 / LangChain Built-in Skill Integration""" + + ts = _skill_tools(name=name, skills_dir=skills_dir, config=config) + return ts.to_langchain( + prefix=prefix, + modify_tool_name=modify_tool_name, + filter_tools_by_name=filter_tools_by_name, + ) diff --git a/agentrun/integration/langgraph/__init__.py b/agentrun/integration/langgraph/__init__.py index 71fa409..141e6c6 100644 --- a/agentrun/integration/langgraph/__init__.py +++ b/agentrun/integration/langgraph/__init__.py @@ -15,8 +15,8 @@ >>> from agentrun.integration.langgraph import AgentRunConverter >>> >>> async for event in agent.astream_events(input_data, version="v2"): - ... for item in AgentRunConverter.to_agui_events(event): - ... yield item + ... for item in AgentRunConverter.to_agui_events(event): + ... yield item 支持多种调用方式: - agent.astream_events(input, version="v2") - 支持 token by token @@ -25,7 +25,14 @@ """ from .agent_converter import AgentRunConverter -from .builtin import knowledgebase_toolset, model, sandbox_toolset, toolset +from .builtin import ( + knowledgebase_toolset, + model, + sandbox_toolset, + skill_tools, + tool_resource, + toolset, +) __all__ = [ "AgentRunConverter", @@ -33,4 +40,6 @@ "toolset", "sandbox_toolset", "knowledgebase_toolset", + "tool_resource", + "skill_tools", ] diff --git a/agentrun/integration/langgraph/builtin.py b/agentrun/integration/langgraph/builtin.py index a9efaae..5b06979 100644 --- a/agentrun/integration/langgraph/builtin.py +++ b/agentrun/integration/langgraph/builtin.py @@ -14,10 +14,13 @@ from agentrun.integration.builtin import model as _model from agentrun.integration.builtin import ModelArgs from agentrun.integration.builtin import sandbox_toolset as _sandbox_toolset +from agentrun.integration.builtin import skill_tools as _skill_tools +from agentrun.integration.builtin import tool_resource as _tool_resource from agentrun.integration.builtin import toolset as _toolset from agentrun.integration.utils.tool import Tool from agentrun.model import ModelProxy, ModelService from agentrun.sandbox import TemplateType +from agentrun.tool.tool import Tool as ToolResourceType from agentrun.toolset import ToolSet from agentrun.utils.config import Config @@ -50,6 +53,24 @@ def toolset( ) +def tool_resource( + name: Union[str, ToolResourceType], + *, + prefix: Optional[str] = None, + modify_tool_name: Optional[Callable[[Tool], Tool]] = None, + filter_tools_by_name: Optional[Callable[[str], bool]] = None, + config: Optional[Config] = None, +) -> List[Any]: + """将 ToolResource 封装为 LangGraph 工具列表。 / LangGraph Built-in ToolResource Integration""" + + ts = _tool_resource(input=name, config=config) + return ts.to_langgraph( + prefix=prefix, + modify_tool_name=modify_tool_name, + filter_tools_by_name=filter_tools_by_name, + ) + + def sandbox_toolset( template_name: str, *, @@ -86,3 +107,22 @@ def knowledgebase_toolset( modify_tool_name=modify_tool_name, filter_tools_by_name=filter_tools_by_name, ) + + +def skill_tools( + name: Optional[Union[str, List[str]]] = None, + *, + skills_dir: str = ".skills", + prefix: Optional[str] = None, + modify_tool_name: Optional[Callable[[Tool], Tool]] = None, + filter_tools_by_name: Optional[Callable[[str], bool]] = None, + config: Optional[Config] = None, +) -> List[Any]: + """将 Skill 封装为 LangGraph 工具列表。 / LangGraph Built-in Skill Integration""" + + ts = _skill_tools(name=name, skills_dir=skills_dir, config=config) + return ts.to_langgraph( + prefix=prefix, + modify_tool_name=modify_tool_name, + filter_tools_by_name=filter_tools_by_name, + ) diff --git a/agentrun/integration/pydantic_ai/__init__.py b/agentrun/integration/pydantic_ai/__init__.py index 5a04376..0af972e 100644 --- a/agentrun/integration/pydantic_ai/__init__.py +++ b/agentrun/integration/pydantic_ai/__init__.py @@ -3,11 +3,20 @@ 提供 AgentRun 模型与沙箱工具的 PydanticAI 适配入口。 / 提供 AgentRun 模型with沙箱工具的 PydanticAI 适配入口。 """ -from .builtin import knowledgebase_toolset, model, sandbox_toolset, toolset +from .builtin import ( + knowledgebase_toolset, + model, + sandbox_toolset, + skill_tools, + tool_resource, + toolset, +) __all__ = [ "model", "toolset", "sandbox_toolset", "knowledgebase_toolset", + "tool_resource", + "skill_tools", ] diff --git a/agentrun/integration/pydantic_ai/builtin.py b/agentrun/integration/pydantic_ai/builtin.py index a5e5b05..eb235f9 100644 --- a/agentrun/integration/pydantic_ai/builtin.py +++ b/agentrun/integration/pydantic_ai/builtin.py @@ -14,10 +14,13 @@ from agentrun.integration.builtin import model as _model from agentrun.integration.builtin import ModelArgs from agentrun.integration.builtin import sandbox_toolset as _sandbox_toolset +from agentrun.integration.builtin import skill_tools as _skill_tools +from agentrun.integration.builtin import tool_resource as _tool_resource from agentrun.integration.builtin import toolset as _toolset from agentrun.integration.utils.tool import Tool from agentrun.model import ModelProxy, ModelService from agentrun.sandbox import TemplateType +from agentrun.tool.tool import Tool as ToolResourceType from agentrun.toolset import ToolSet from agentrun.utils.config import Config @@ -50,6 +53,24 @@ def toolset( ) +def tool_resource( + name: Union[str, ToolResourceType], + *, + prefix: Optional[str] = None, + modify_tool_name: Optional[Callable[[Tool], Tool]] = None, + filter_tools_by_name: Optional[Callable[[str], bool]] = None, + config: Optional[Config] = None, +) -> List[Any]: + """将 ToolResource 封装为 PydanticAI 工具列表。 / PydanticAI Built-in ToolResource Integration""" + + ts = _tool_resource(input=name, config=config) + return ts.to_pydantic_ai( + prefix=prefix, + modify_tool_name=modify_tool_name, + filter_tools_by_name=filter_tools_by_name, + ) + + def sandbox_toolset( template_name: str, *, @@ -86,3 +107,18 @@ def knowledgebase_toolset( modify_tool_name=modify_tool_name, filter_tools_by_name=filter_tools_by_name, ) + + +def skill_tools( + name: Optional[Union[str, List[str]]] = None, + *, + skills_dir: str = ".skills", + prefix: Optional[str] = None, + config: Optional[Config] = None, +) -> List[Any]: + """将 Skill 封装为 PydanticAI 工具列表。 / PydanticAI Built-in Skill Integration""" + + ts = _skill_tools(name=name, skills_dir=skills_dir, config=config) + return ts.to_pydantic_ai( + prefix=prefix, + ) diff --git a/agentrun/integration/utils/skill_loader.py b/agentrun/integration/utils/skill_loader.py new file mode 100644 index 0000000..ca178ef --- /dev/null +++ b/agentrun/integration/utils/skill_loader.py @@ -0,0 +1,459 @@ +"""Skill 加载器模块 / Skill Loader Module + +提供从本地 .skills 目录加载 Skill 包的能力,并构造 load_skills 工具供 Agent 运行时调用。 +Provides the ability to load Skill packages from a local .skills directory +and construct a load_skills tool for Agent runtime invocation. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +import json +import os +import re +from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union + +from agentrun.integration.utils.tool import CommonToolSet, Tool, ToolParameter +from agentrun.utils.log import logger + +if TYPE_CHECKING: + from agentrun.tool.tool import Tool as ToolResource + from agentrun.utils.config import Config + + +@dataclass +class SkillInfo: + """Skill 摘要信息 / Skill summary information + + Attributes: + name: skill 名称 / skill name + description: skill 描述 / skill description + version: skill 版本 / skill version + path: 本地目录路径 / local directory path + """ + + name: str + description: str = "" + version: str = "" + path: str = "" + + +@dataclass +class SkillDetail(SkillInfo): + """Skill 详细信息 / Skill detail information + + Attributes: + instruction: SKILL.md 全文内容 / full content of SKILL.md + files: 目录下的文件/文件夹列表 / list of files/folders in the directory + """ + + instruction: str = "" + files: List[str] = field(default_factory=list) + + +def _parse_frontmatter(content: str) -> Dict[str, str]: + """解析 SKILL.md 的 YAML frontmatter / Parse YAML frontmatter from SKILL.md + + 使用简单的正则解析,避免引入 PyYAML 依赖。 + Uses simple regex parsing to avoid introducing PyYAML dependency. + + Args: + content: SKILL.md 文件内容 / SKILL.md file content + + Returns: + 解析出的 key-value 字典 / parsed key-value dictionary + """ + match = re.match(r"^---\s*\n(.*?)\n---", content, re.DOTALL) + if not match: + return {} + result: Dict[str, str] = {} + for line in match.group(1).split("\n"): + line = line.strip() + if not line or ":" not in line: + continue + key, _, value = line.partition(":") + key = key.strip() + value = value.strip().strip('"').strip("'") + if key: + result[key] = value + return result + + +class SkillLoader: + """Skill 加载器 / Skill Loader + + 负责扫描本地 .skills 目录、解析 skill 元信息、读取 skill 指令内容, + 并构造 load_skills 工具供 Agent 运行时调用。 + + Responsible for scanning the local .skills directory, parsing skill metadata, + reading skill instruction content, and constructing the load_skills tool + for Agent runtime invocation. + + Args: + skills_dir: 本地 skill 目录路径 / local skill directory path + remote_skill_names: 需要从远程下载的 skill 名称列表 / list of remote skill names to download + config: 配置对象 / configuration object + """ + + def __init__( + self, + skills_dir: str = ".skills", + remote_skill_names: Optional[List[str]] = None, + config: Optional["Config"] = None, + ): + self._skills_dir = skills_dir + self._remote_skill_names = remote_skill_names or [] + self._config = config + self._skills_cache: Optional[List[SkillInfo]] = None + + def _ensure_skills_available(self) -> None: + """确保远程 skill 已下载到本地 / Ensure remote skills are downloaded locally + + 对每个 remote_skill_name,检查本地是否已存在对应目录, + 不存在则通过 ToolClient 下载。 + + For each remote_skill_name, check if the local directory exists, + download via ToolClient if not. + """ + if not self._remote_skill_names: + return + + from agentrun.tool.client import ToolClient + + for skill_name in self._remote_skill_names: + skill_path = os.path.join(self._skills_dir, skill_name) + if os.path.isdir(skill_path): + logger.debug( + f"Skill '{skill_name}' already exists at {skill_path}, " + "skipping download" + ) + continue + logger.info( + f"Downloading remote skill '{skill_name}' to {self._skills_dir}" + ) + tool_resource = ToolClient().get( + name=skill_name, config=self._config + ) + tool_resource.download_skill( + target_dir=self._skills_dir, config=self._config + ) + + def _parse_skill_metadata(self, skill_dir: str) -> SkillInfo: + """解析 skill 元信息 / Parse skill metadata + + 按以下优先级获取 skill 的 name 和 description: + 1. SKILL.md 的 YAML frontmatter + 2. package.json + 3. 目录名作为 name,description 为空字符串 + + Priority for getting skill name and description: + 1. SKILL.md YAML frontmatter + 2. package.json + 3. Directory name as name, empty string as description + + Args: + skill_dir: skill 目录的完整路径 / full path to skill directory + + Returns: + SkillInfo 实例 / SkillInfo instance + """ + dir_name = os.path.basename(skill_dir) + name = dir_name + description = "" + version = "" + + skill_md_path = os.path.join(skill_dir, "SKILL.md") + if os.path.isfile(skill_md_path): + try: + with open(skill_md_path, "r", encoding="utf-8") as file_handle: + content = file_handle.read() + frontmatter = _parse_frontmatter(content) + if frontmatter.get("name"): + name = frontmatter["name"] + if frontmatter.get("description"): + description = frontmatter["description"] + if frontmatter.get("version"): + version = frontmatter["version"] + if name != dir_name or description or version: + return SkillInfo( + name=name, + description=description, + version=version, + path=skill_dir, + ) + except (OSError, UnicodeDecodeError) as error: + logger.warning( + f"Failed to read SKILL.md in {skill_dir}: {error}" + ) + + package_json_path = os.path.join(skill_dir, "package.json") + if os.path.isfile(package_json_path): + try: + with open( + package_json_path, "r", encoding="utf-8" + ) as file_handle: + package_data = json.load(file_handle) + if package_data.get("name"): + name = package_data["name"] + if package_data.get("description"): + description = package_data["description"] + if package_data.get("version"): + version = package_data["version"] + except (OSError, json.JSONDecodeError, UnicodeDecodeError) as error: + logger.warning( + f"Failed to read package.json in {skill_dir}: {error}" + ) + + return SkillInfo( + name=name, description=description, version=version, path=skill_dir + ) + + def scan_skills(self) -> List[SkillInfo]: + """扫描 .skills/ 目录,返回所有 skill 的摘要信息 / Scan .skills/ directory and return all skill summaries + + Returns: + SkillInfo 列表 / list of SkillInfo + """ + if self._skills_cache is not None: + return self._skills_cache + + self._ensure_skills_available() + + if not os.path.isdir(self._skills_dir): + self._skills_cache = [] + return self._skills_cache + + skills: List[SkillInfo] = [] + try: + entries = sorted(os.listdir(self._skills_dir)) + except OSError as error: + logger.warning( + f"Failed to list skills directory {self._skills_dir}: {error}" + ) + self._skills_cache = [] + return self._skills_cache + + for entry in entries: + entry_path = os.path.join(self._skills_dir, entry) + if os.path.isdir(entry_path) and not entry.startswith("."): + skill_info = self._parse_skill_metadata(entry_path) + skills.append(skill_info) + + self._skills_cache = skills + return self._skills_cache + + def load_skill(self, name: str) -> Optional[SkillDetail]: + """加载指定 skill 的详细信息 / Load detailed information for a specific skill + + Args: + name: skill 名称 / skill name + + Returns: + SkillDetail 实例,如果 skill 不存在则返回 None / + SkillDetail instance, or None if skill does not exist + """ + skills = self.scan_skills() + target_skill: Optional[SkillInfo] = None + for skill in skills: + if skill.name == name: + target_skill = skill + break + + if target_skill is None: + return None + + instruction = "" + skill_md_path = os.path.join(target_skill.path, "SKILL.md") + if os.path.isfile(skill_md_path): + try: + with open(skill_md_path, "r", encoding="utf-8") as file_handle: + instruction = file_handle.read() + except (OSError, UnicodeDecodeError) as error: + logger.warning( + f"Failed to read SKILL.md for skill '{name}': {error}" + ) + + files: List[str] = [] + try: + for entry in sorted(os.listdir(target_skill.path)): + if not entry.startswith("."): + entry_path = os.path.join(target_skill.path, entry) + if os.path.isdir(entry_path): + files.append(entry + "/") + else: + files.append(entry) + except OSError as error: + logger.warning(f"Failed to list files for skill '{name}': {error}") + + return SkillDetail( + name=target_skill.name, + description=target_skill.description, + version=target_skill.version, + path=target_skill.path, + instruction=instruction, + files=files, + ) + + def _build_tool_description(self, skills: List[SkillInfo]) -> str: + """构建 load_skills 工具的 description / Build the description for the load_skills tool + + 将所有可用 skill 的名称和描述写入工具描述中。 + Writes all available skill names and descriptions into the tool description. + + Args: + skills: skill 摘要列表 / list of skill summaries + + Returns: + 工具描述字符串 / tool description string + """ + if not skills: + return ( + "Load skill instructions for the agent. " + "No skills available in the configured directory." + ) + + skill_lines = [] + for skill in skills: + desc_part = f": {skill.description}" if skill.description else "" + skill_lines.append(f"- {skill.name}{desc_part}") + + skills_list = "\n".join(skill_lines) + return ( + "Load skill instructions for the agent. " + "Call without arguments to list all skills, " + "or with a skill name to get detailed instructions.\n\n" + f"Available skills:\n{skills_list}" + ) + + def _load_skills_func(self, name: Optional[str] = None) -> str: + """load_skills 工具的执行函数 / Execution function for the load_skills tool + + Args: + name: skill 名称(可选)/ skill name (optional) + + Returns: + JSON 字符串 / JSON string + """ + if name is None or name == "": + skills = self.scan_skills() + result: Dict[str, Any] = { + "skills": [ + {"name": skill.name, "description": skill.description} + for skill in skills + ] + } + return json.dumps(result, ensure_ascii=False) + + detail = self.load_skill(name) + if detail is None: + available = [skill.name for skill in self.scan_skills()] + available_str = ", ".join(available) if available else "none" + error_result: Dict[str, str] = { + "error": ( + f"Skill '{name}' not found. " + f"Available skills: {available_str}" + ) + } + return json.dumps(error_result, ensure_ascii=False) + + detail_result: Dict[str, Any] = { + "name": detail.name, + "description": detail.description, + "instruction": detail.instruction, + "files": detail.files, + } + return json.dumps(detail_result, ensure_ascii=False) + + def to_common_toolset(self) -> CommonToolSet: + """构造包含 load_skills 工具的 CommonToolSet / Construct CommonToolSet with load_skills tool + + Returns: + CommonToolSet 实例 / CommonToolSet instance + """ + skills = self.scan_skills() + description = self._build_tool_description(skills) + + load_skills_tool = Tool( + name="load_skills", + description=description, + parameters=[ + ToolParameter( + name="name", + param_type="string", + description=( + "The name of the skill to load. " + "If omitted, returns a list of all available skills." + ), + required=False, + ), + ], + func=self._load_skills_func, + ) + + return CommonToolSet(tools_list=[load_skills_tool]) + + +def skill_tools( + name: Optional[Union[str, List[str], "ToolResource"]] = None, + *, + skills_dir: str = ".skills", + config: Optional["Config"] = None, +) -> CommonToolSet: + """将 Skill 封装为通用工具集 / Wrap Skills as CommonToolSet + + 支持从工具名称、名称列表或 ToolResource 实例创建通用工具集。 + Supports creating CommonToolSet from tool name, name list, or ToolResource instance. + + Args: + name: 远程 skill 名称、名称列表或 ToolResource 实例(可选)/ + Remote skill name, name list, or ToolResource instance (optional). + 如果提供,会先下载到 skills_dir 再加载 / + If provided, downloads to skills_dir before loading. + 如果不提供,仅从 skills_dir 加载本地已有的 skill / + If not provided, only loads local skills from skills_dir. + skills_dir: 本地 skill 目录,默认 ".skills" / Local skill directory, default ".skills" + config: 配置对象 / Configuration object + + Returns: + CommonToolSet: 包含 load_skills 工具的通用工具集 / + CommonToolSet containing the load_skills tool + + Examples: + >>> # 仅加载本地 skill / Load local skills only + >>> ts = skill_tools(skills_dir=".skills") + >>> + >>> # 下载远程 skill 后加载 / Download remote skill then load + >>> ts = skill_tools("my-remote-skill") + >>> + >>> # 下载多个远程 skill / Download multiple remote skills + >>> ts = skill_tools(["skill-a", "skill-b"]) + >>> + >>> # 转换为 LangChain 工具 / Convert to LangChain tools + >>> lc_tools = ts.to_langchain() + """ + remote_names: List[str] = [] + + if name is not None: + if isinstance(name, str): + remote_names = [name] + elif isinstance(name, list): + remote_names = name + else: + # ToolResource instance — extract its name and download + tool_resource_instance = name + resource_name = getattr( + tool_resource_instance, "name", None + ) or getattr(tool_resource_instance, "tool_name", None) + if resource_name: + skill_path = os.path.join(skills_dir, resource_name) + if not os.path.isdir(skill_path): + tool_resource_instance.download_skill( + target_dir=skills_dir, config=config + ) + + loader = SkillLoader( + skills_dir=skills_dir, + remote_skill_names=remote_names, + config=config, + ) + return loader.to_common_toolset() diff --git a/agentrun/integration/utils/tool.py b/agentrun/integration/utils/tool.py index bde72a5..c479846 100644 --- a/agentrun/integration/utils/tool.py +++ b/agentrun/integration/utils/tool.py @@ -47,6 +47,7 @@ ) if TYPE_CHECKING: + from agentrun.tool.tool import Tool as ToolResource from agentrun.toolset import ToolSet from agentrun.utils.log import logger @@ -859,6 +860,57 @@ def from_agentrun_toolset( return CommonToolSet(integration_tools) + @classmethod + def from_agentrun_tool( + cls, + tool_resource: "ToolResource", + config: Optional[Any] = None, + refresh: bool = False, + ) -> "CommonToolSet": + """从 AgentRun ToolResource 创建通用工具集 / Create CommonToolSet from AgentRun ToolResource + + Args: + tool_resource: agentrun.tool.tool.Tool (ToolResource) 实例 / ToolResource instance + config: 额外的请求配置,调用工具时会自动合并 / Extra request config, merged automatically when calling tools + refresh: 是否先刷新最新信息 / Whether to refresh latest info first + + Returns: + 通用 ToolSet 实例,可直接调用 .to_openai_function()、.to_langchain() 等 + CommonToolSet instance, can directly call .to_openai_function(), .to_langchain(), etc. + + Example: + >>> from agentrun import ToolResource, ToolResourceClient + >>> from agentrun.integration.utils.tool import CommonToolSet + >>> + >>> client = ToolResourceClient() + >>> tool = client.get(name="my-tool") + >>> common_toolset = CommonToolSet.from_agentrun_tool(tool) + >>> + >>> openai_tools = common_toolset.to_openai_function() + >>> langchain_tools = common_toolset.to_langchain() + """ + + if refresh: + tool_resource = tool_resource.get(config=config) + + tools_meta = tool_resource.list_tools(config=config) or [] + integration_tools: List[Tool] = [] + seen_names: set = set() + + for meta in tools_meta: + tool = _build_tool_from_meta(tool_resource, meta, config) + if tool: + if tool.name in seen_names: + logger.warning( + f"Duplicate tool name '{tool.name}' detected, " + "second occurrence will be skipped" + ) + continue + seen_names.add(tool.name) + integration_tools.append(tool) + + return CommonToolSet(integration_tools) + def to_openai_function( self, prefix: Optional[str] = None, diff --git a/agentrun/tool/__client_async_template.py b/agentrun/tool/__client_async_template.py new file mode 100644 index 0000000..2f504a9 --- /dev/null +++ b/agentrun/tool/__client_async_template.py @@ -0,0 +1,53 @@ +"""Tool 客户端 / Tool Client + +此模块提供工具的客户端 API。 +This module provides the client API for tools. +""" + +from typing import Any, Dict, List, Optional + +from agentrun.tool.api.control import ToolControlAPI +from agentrun.utils.config import Config +from agentrun.utils.exception import HTTPError + +from .tool import Tool + + +class ToolClient: + """Tool 客户端 / Tool Client + + 提供工具的获取功能。 + Provides get function for tools. + """ + + def __init__(self, config: Optional[Config] = None): + """初始化客户端 / Initialize client + + Args: + config: 配置对象,可选 / Configuration object, optional + """ + self.__control_api = ToolControlAPI(config) + + async def get_async( + self, + name: str, + config: Optional[Config] = None, + ) -> "Tool": + """异步获取工具 / Get tool asynchronously + + Args: + name: 工具名称 / Tool name + config: 配置对象,可选 / Configuration object, optional + + Returns: + Tool: 工具资源对象 / Tool resource object + """ + try: + result = await self.__control_api.get_tool_async( + name=name, + config=config, + ) + except HTTPError as e: + raise e.to_resource_error("Tool", name) from e + + return Tool.from_inner_object(result) diff --git a/agentrun/tool/__init__.py b/agentrun/tool/__init__.py new file mode 100644 index 0000000..cce0e04 --- /dev/null +++ b/agentrun/tool/__init__.py @@ -0,0 +1,41 @@ +"""Tool 模块 / Tool Module + +此模块提供工具管理功能。 +This module provides tool management functionality. +""" + +from .api.control import ToolControlAPI +from .api.mcp import ToolMCPSession +from .api.openapi import ToolOpenAPIClient +from .client import ToolClient +from .model import ( + McpConfig, + ToolCodeConfiguration, + ToolContainerConfiguration, + ToolInfo, + ToolLogConfiguration, + ToolNASConfig, + ToolNetworkConfiguration, + ToolOSSMountConfig, + ToolSchema, + ToolType, +) +from .tool import Tool + +__all__ = [ + "ToolControlAPI", + "ToolMCPSession", + "ToolOpenAPIClient", + "ToolClient", + "Tool", + "ToolType", + "McpConfig", + "ToolCodeConfiguration", + "ToolContainerConfiguration", + "ToolInfo", + "ToolLogConfiguration", + "ToolNASConfig", + "ToolNetworkConfiguration", + "ToolOSSMountConfig", + "ToolSchema", +] diff --git a/agentrun/tool/__tool_async_template.py b/agentrun/tool/__tool_async_template.py new file mode 100644 index 0000000..74dd049 --- /dev/null +++ b/agentrun/tool/__tool_async_template.py @@ -0,0 +1,396 @@ +"""Tool 资源类 / Tool Resource Class + +提供工具资源的面向对象封装和完整生命周期管理。 +Provides object-oriented wrapper and complete lifecycle management for tool resources. +""" + +import io +import os +import shutil +from typing import Any, Dict, List, Optional +import zipfile + +import httpx +import pydash + +from agentrun.utils.config import Config +from agentrun.utils.log import logger +from agentrun.utils.model import BaseModel + +from .model import ( + McpConfig, + ToolCodeConfiguration, + ToolContainerConfiguration, + ToolInfo, + ToolLogConfiguration, + ToolNetworkConfiguration, + ToolOSSMountConfig, + ToolSchema, + ToolType, +) + + +class Tool(BaseModel): + """工具资源 / Tool Resource + + 提供工具的查询、调用等功能。 + Provides query, invocation and other functionality for tools. + + Attributes: + code_configuration: 代码包配置 / Code configuration + container_configuration: 容器配置 / Container configuration + created_time: 创建时间 / Creation time + data_endpoint: 数据链路端点 / Data endpoint + description: 描述 / Description + environment_variables: 环境变量 / Environment variables + gpu: GPU 配置 / GPU configuration + internet_access: 是否允许公网访问 / Whether internet access is allowed + last_modified_time: 最后修改时间 / Last modified time + log_configuration: 日志配置 / Log configuration + mcp_config: MCP 配置 / MCP configuration + memory: 内存大小(MB) / Memory size in MB + name: 工具名称 / Tool name + network_config: 网络配置 / Network configuration + oss_mount_config: OSS 挂载配置 / OSS mount configuration + protocol_spec: 协议规格(OpenAPI JSON) / Protocol spec (OpenAPI JSON) + protocol_type: 协议类型 / Protocol type + status: 状态 / Status + timeout: 超时时间(秒) / Timeout in seconds + tool_id: 工具 ID / Tool ID + tool_name: 工具名称 / Tool name + tool_type: 工具类型(MCP/FUNCTIONCALL) / Tool type + version_id: 版本 ID / Version ID + """ + + code_configuration: Optional[ToolCodeConfiguration] = None + """代码包配置 / Code configuration""" + + container_configuration: Optional[ToolContainerConfiguration] = None + """容器配置 / Container configuration""" + + created_time: Optional[str] = None + """创建时间 / Creation time""" + + data_endpoint: Optional[str] = None + """数据链路端点 / Data endpoint""" + + description: Optional[str] = None + """描述 / Description""" + + environment_variables: Optional[Dict[str, str]] = None + """环境变量 / Environment variables""" + + gpu: Optional[str] = None + """GPU 配置 / GPU configuration""" + + internet_access: Optional[bool] = None + """是否允许公网访问 / Whether internet access is allowed""" + + last_modified_time: Optional[str] = None + """最后修改时间 / Last modified time""" + + log_configuration: Optional[ToolLogConfiguration] = None + """日志配置 / Log configuration""" + + mcp_config: Optional[McpConfig] = None + """MCP 配置 / MCP configuration""" + + memory: Optional[int] = None + """内存大小(MB) / Memory size in MB""" + + name: Optional[str] = None + """工具名称 / Tool name""" + + network_config: Optional[ToolNetworkConfiguration] = None + """网络配置 / Network configuration""" + + oss_mount_config: Optional[ToolOSSMountConfig] = None + """OSS 挂载配置 / OSS mount configuration""" + + protocol_spec: Optional[str] = None + """协议规格(OpenAPI JSON 字符串) / Protocol spec (OpenAPI JSON string)""" + + protocol_type: Optional[str] = None + """协议类型 / Protocol type""" + + status: Optional[str] = None + """状态 / Status""" + + timeout: Optional[int] = None + """超时时间(秒) / Timeout in seconds""" + + tool_id: Optional[str] = None + """工具 ID / Tool ID""" + + tool_name: Optional[str] = None + """工具名称 / Tool name""" + + tool_type: Optional[str] = None + """工具类型(MCP/FUNCTIONCALL) / Tool type (MCP/FUNCTIONCALL)""" + + version_id: Optional[str] = None + """版本 ID / Version ID""" + + @classmethod + def __get_client(cls, config: Optional[Config] = None): + from .client import ToolClient + + return ToolClient(config) + + @classmethod + async def get_by_name_async( + cls, name: str, config: Optional[Config] = None + ) -> "Tool": + """异步通过名称获取工具 / Get tool by name asynchronously""" + cli = cls.__get_client(config) + return await cli.get_async(name=name) + + async def get_async(self, config: Optional[Config] = None) -> "Tool": + """异步刷新工具信息 / Refresh tool info asynchronously""" + effective_name = self.tool_name or self.name + if effective_name is None: + raise ValueError("Tool name is required to get the Tool.") + + result = await self.get_by_name_async( + name=effective_name, config=config + ) + return self.update_self(result) + + def _get_functioncall_server_url( + self, config: Optional[Config] = None + ) -> Optional[str]: + """获取 FunctionCall 工具的 fallback server URL / Get fallback server URL for FunctionCall tools + + 当 OpenAPI spec 中没有 servers 字段时,使用 data_endpoint 构造 URL。 + Constructs URL from data_endpoint when servers is not present in OpenAPI spec. + """ + effective_name = self.tool_name or self.name + data_endpoint = self.data_endpoint + if not data_endpoint: + cfg = Config.with_configs(config) + data_endpoint = cfg._data_endpoint + if not data_endpoint or not effective_name: + return None + return f"{data_endpoint}/tools/{effective_name}" + + def _get_tool_type(self) -> Optional[ToolType]: + """获取工具类型 / Get tool type""" + raw_type = self.tool_type + if raw_type: + try: + return ToolType(raw_type) + except ValueError: + return None + return None + + def _get_mcp_endpoint( + self, config: Optional[Config] = None + ) -> Optional[str]: + """获取 MCP 数据链路 URL / Get MCP data endpoint URL + + 根据 session_affinity 决定使用 /mcp 还是 /sse 路径。 + 如果 self.data_endpoint 为空,则从 Config 中获取。 + Determines /mcp or /sse path based on session_affinity. + Falls back to Config if self.data_endpoint is not set. + """ + effective_name = self.tool_name or self.name + data_endpoint = self.data_endpoint + if not data_endpoint: + cfg = Config.with_configs(config) + data_endpoint = cfg._data_endpoint + if not data_endpoint or not effective_name: + return None + + session_affinity = pydash.get( + self, "mcp_config.session_affinity", "MCP_SSE" + ) + + if session_affinity == "MCP_STREAMABLE": + return f"{data_endpoint}/tools/{effective_name}/mcp" + return f"{data_endpoint}/tools/{effective_name}/sse" + + async def list_tools_async( + self, config: Optional[Config] = None + ) -> List[ToolInfo]: + """异步获取子工具列表 / Get sub-tool list asynchronously + + 对于 MCP 类型,通过 MCP 协议获取工具列表。 + 对于 FUNCTIONCALL 类型,解析 protocol_spec 获取工具列表。 + For MCP type, gets tool list via MCP protocol. + For FUNCTIONCALL type, parses protocol_spec to get tool list. + + Returns: + List[ToolInfo]: 子工具信息列表 / List of sub-tool information + """ + tool_type = self._get_tool_type() + + if tool_type == ToolType.MCP: + from .api.mcp import ToolMCPSession + + mcp_endpoint = self._get_mcp_endpoint(config) + if not mcp_endpoint: + logger.warning( + "MCP endpoint not available for tool %s", self.name + ) + return [] + + session_affinity = pydash.get( + self, "mcp_config.session_affinity", "MCP_SSE" + ) + + cfg = Config.with_configs(config) + session = ToolMCPSession( + endpoint=mcp_endpoint, + session_affinity=session_affinity, + headers=cfg.get_headers(), + ) + return await session.list_tools_async() + + elif tool_type == ToolType.FUNCTIONCALL: + from .api.openapi import ToolOpenAPIClient + + openapi_client = ToolOpenAPIClient( + protocol_spec=self.protocol_spec, + fallback_server_url=self._get_functioncall_server_url(config), + ) + return await openapi_client.list_tools_async() + + return [] + + async def call_tool_async( + self, + name: str, + arguments: Optional[Dict[str, Any]] = None, + config: Optional[Config] = None, + ) -> Any: + """异步调用子工具 / Call sub-tool asynchronously + + Args: + name: 子工具名称 / Sub-tool name + arguments: 调用参数 / Call arguments + config: 配置对象,可选 / Configuration object, optional + + Returns: + Any: 工具执行结果 / Tool execution result + """ + tool_type = self._get_tool_type() + + logger.debug("invoke tool %s with arguments %s", name, arguments) + + if tool_type == ToolType.MCP: + from .api.mcp import ToolMCPSession + + mcp_endpoint = self._get_mcp_endpoint(config) + if not mcp_endpoint: + raise ValueError( + f"MCP endpoint not available for tool {self.name}" + ) + + session_affinity = pydash.get( + self, "mcp_config.session_affinity", "MCP_SSE" + ) + + cfg = Config.with_configs(config) + session = ToolMCPSession( + endpoint=mcp_endpoint, + session_affinity=session_affinity, + headers=cfg.get_headers(), + ) + result = await session.call_tool_async(name, arguments) + logger.debug("invoke tool %s got result %s", name, result) + return result + + elif tool_type == ToolType.FUNCTIONCALL: + from .api.openapi import ToolOpenAPIClient + + cfg = Config.with_configs(config) + openapi_client = ToolOpenAPIClient( + protocol_spec=self.protocol_spec, + headers=cfg.get_headers(), + fallback_server_url=self._get_functioncall_server_url(config), + ) + result = await openapi_client.call_tool_async(name, arguments) + logger.debug("invoke tool %s got result %s", name, result) + return result + + raise ValueError(f"Unsupported tool type: {self.tool_type}") + + def _get_skill_download_url( + self, config: Optional[Config] = None + ) -> Optional[str]: + """获取 Skill 工具的下载 URL / Get download URL for Skill tools + + 根据 data_endpoint 和 tool_name 构造下载地址。 + Constructs download URL from data_endpoint and tool_name. + + Returns: + Optional[str]: 下载 URL / Download URL + """ + effective_name = self.tool_name or self.name + data_endpoint = self.data_endpoint + if not data_endpoint: + cfg = Config.with_configs(config) + data_endpoint = cfg._data_endpoint + if not data_endpoint or not effective_name: + return None + return f"{data_endpoint}/tools/{effective_name}/download" + + async def download_skill_async( + self, + target_dir: str = ".skills", + config: Optional[Config] = None, + ) -> str: + """异步下载 Skill 包并解压到本地目录 / Download skill package and extract to local directory asynchronously + + 从数据链路下载 skill 的 zip 包,并解压到 {target_dir}/{tool_name}/ 目录下。 + Downloads skill zip package from data endpoint and extracts to {target_dir}/{tool_name}/ directory. + + Args: + target_dir: 目标根目录,默认为 ".skills" / Target root directory, defaults to ".skills" + config: 配置对象,可选 / Configuration object, optional + + Returns: + str: 解压后的 skill 目录路径 / Extracted skill directory path + + Raises: + ValueError: 工具类型不是 SKILL 或缺少必要信息 / Tool type is not SKILL or missing required info + httpx.HTTPStatusError: 下载失败 / Download failed + """ + tool_type = self._get_tool_type() + if tool_type != ToolType.SKILL: + raise ValueError( + "download_skill is only available for SKILL type tools," + f" got {self.tool_type}" + ) + + download_url = self._get_skill_download_url(config) + if not download_url: + raise ValueError( + "Cannot construct download URL: data_endpoint or tool_name" + " is missing" + ) + + effective_name = self.tool_name or self.name + skill_dir = os.path.join(target_dir, effective_name or "unknown_skill") + + logger.debug("downloading skill from %s to %s", download_url, skill_dir) + + cfg = Config.with_configs(config) + headers = cfg.get_headers() + + async with httpx.AsyncClient( + timeout=300, follow_redirects=True + ) as http_client: + response = await http_client.get(download_url, headers=headers) + response.raise_for_status() + + if os.path.exists(skill_dir): + shutil.rmtree(skill_dir) + os.makedirs(skill_dir, exist_ok=True) + + zip_buffer = io.BytesIO(response.content) + with zipfile.ZipFile(zip_buffer, "r") as zip_file: + zip_file.extractall(skill_dir) + + logger.info("skill downloaded and extracted to %s", skill_dir) + return skill_dir diff --git a/agentrun/tool/api/__init__.py b/agentrun/tool/api/__init__.py new file mode 100644 index 0000000..fd1b7de --- /dev/null +++ b/agentrun/tool/api/__init__.py @@ -0,0 +1 @@ +"""Tool API 模块 / Tool API Module""" diff --git a/agentrun/tool/api/control.py b/agentrun/tool/api/control.py new file mode 100644 index 0000000..b2630ba --- /dev/null +++ b/agentrun/tool/api/control.py @@ -0,0 +1,128 @@ +"""Tool 管控链路 API / Tool Control API + +通过底层 agentrun20250910 SDK 与平台交互,获取 Tool 资源。 +Interacts with the platform via the agentrun20250910 SDK to get Tool resources. +""" + +from typing import Dict, Optional + +from alibabacloud_agentrun20250910.models import Tool as InnerTool +from alibabacloud_tea_openapi.exceptions._client import ClientException +from alibabacloud_tea_openapi.exceptions._server import ServerException +from darabonba.runtime import RuntimeOptions +import pydash + +from agentrun.utils.config import Config +from agentrun.utils.control_api import ControlAPI +from agentrun.utils.exception import ClientError, ServerError +from agentrun.utils.log import logger + + +class ToolControlAPI(ControlAPI): + """Tool 管控链路 API / Tool Control API""" + + def __init__(self, config: Optional[Config] = None): + """初始化 API 客户端 / Initialize API client + + Args: + config: 全局配置对象 / Global configuration object + """ + super().__init__(config) + + def get_tool( + self, + name: str, + headers: Optional[Dict[str, str]] = None, + config: Optional[Config] = None, + ) -> InnerTool: + """获取工具 / Get tool + + Args: + name: Tool 名称 / Tool name + headers: 请求头 / Request headers + config: 配置 / Configuration + + Returns: + InnerTool: 底层 SDK 的 Tool 对象 / Inner SDK Tool object + + Raises: + ClientError: 客户端错误 / Client error + ServerError: 服务器错误 / Server error + """ + try: + client = self._get_client(config) + response = client.get_tool_with_options( + name, + headers=headers or {}, + runtime=RuntimeOptions(), + ) + + logger.debug( + "request api get_tool, request Request ID:" + f" {response.headers['x-acs-request-id'] if response.headers else ''}\n" + f" request: {[name]}\n response: {response.body.data}" + ) + + return response.body.data + except ClientException as e: + raise ClientError( + e.status_code, + pydash.get(e, "data.message", pydash.get(e, "message", "")), + request_id=e.request_id, + request=[name], + ) from e + except ServerException as e: + raise ServerError( + e.status_code, + pydash.get(e, "data.message", pydash.get(e, "message", "")), + request_id=e.request_id, + ) from e + + async def get_tool_async( + self, + name: str, + headers: Optional[Dict[str, str]] = None, + config: Optional[Config] = None, + ) -> InnerTool: + """异步获取工具 / Get tool asynchronously + + Args: + name: Tool 名称 / Tool name + headers: 请求头 / Request headers + config: 配置 / Configuration + + Returns: + InnerTool: 底层 SDK 的 Tool 对象 / Inner SDK Tool object + + Raises: + ClientError: 客户端错误 / Client error + ServerError: 服务器错误 / Server error + """ + try: + client = self._get_client(config) + response = await client.get_tool_with_options_async( + name, + headers=headers or {}, + runtime=RuntimeOptions(), + ) + + logger.debug( + "request api get_tool, request Request ID:" + f" {response.headers['x-acs-request-id'] if response.headers else ''}\n" + f" request: {[name]}\n response: {response.body.data}" + ) + + return response.body.data + except ClientException as e: + raise ClientError( + e.status_code, + pydash.get(e, "data.message", pydash.get(e, "message", "")), + request_id=e.request_id, + request=[name], + ) from e + except ServerException as e: + raise ServerError( + e.status_code, + pydash.get(e, "data.message", pydash.get(e, "message", "")), + request_id=e.request_id, + ) from e diff --git a/agentrun/tool/api/mcp.py b/agentrun/tool/api/mcp.py new file mode 100644 index 0000000..a0cef61 --- /dev/null +++ b/agentrun/tool/api/mcp.py @@ -0,0 +1,167 @@ +"""Tool MCP 数据链路 / Tool MCP Data API + +通过 MCP 协议与 Tool 的数据链路交互,支持 SSE 和 Streamable HTTP 两种传输方式。 +Interacts with Tool data endpoints via MCP protocol, supporting SSE and Streamable HTTP transports. +""" + +import asyncio +from typing import Any, Dict, List, Optional + +from agentrun.tool.model import ToolInfo, ToolSchema +from agentrun.utils.log import logger + + +class ToolMCPSession: + """Tool MCP 会话管理 / Tool MCP Session Manager + + 独立实现的 MCP 会话管理,支持 SSE 和 Streamable HTTP 两种传输方式。 + Independent MCP session manager supporting SSE and Streamable HTTP transports. + """ + + def __init__( + self, + endpoint: str, + session_affinity: Optional[str] = None, + headers: Optional[Dict[str, str]] = None, + ): + """初始化 MCP 会话 / Initialize MCP session + + Args: + endpoint: MCP 数据链路 URL / MCP data endpoint URL + session_affinity: 会话亲和性策略 / Session affinity strategy + headers: 请求头 / Request headers + """ + self.endpoint = endpoint + self.session_affinity = session_affinity + self.headers = headers or {} + + @property + def is_streamable(self) -> bool: + """是否使用 Streamable HTTP 传输 / Whether to use Streamable HTTP transport""" + return self.session_affinity == "MCP_STREAMABLE" + + async def list_tools_async(self) -> List[ToolInfo]: + """异步获取工具列表 / Get tool list asynchronously + + Returns: + List[ToolInfo]: 工具信息列表 / List of tool information + """ + try: + from mcp import ClientSession + + if self.is_streamable: + from mcp.client.streamable_http import streamablehttp_client + + async with streamablehttp_client( + self.endpoint, headers=self.headers + ) as (read_stream, write_stream, _): + async with ClientSession( + read_stream, write_stream + ) as session: + await session.initialize() + result = await session.list_tools() + return [ + ToolInfo.from_mcp_tool(tool) + for tool in result.tools + ] + else: + from mcp.client.sse import sse_client + + async with sse_client(self.endpoint, headers=self.headers) as ( + read_stream, + write_stream, + ): + async with ClientSession( + read_stream, write_stream + ) as session: + await session.initialize() + result = await session.list_tools() + return [ + ToolInfo.from_mcp_tool(tool) + for tool in result.tools + ] + except ImportError: + logger.warning( + "mcp package is not installed. Install it with: pip install mcp" + ) + return [] + + def list_tools(self) -> List[ToolInfo]: + """同步获取工具列表 / Get tool list synchronously + + Returns: + List[ToolInfo]: 工具信息列表 / List of tool information + """ + return asyncio.get_event_loop().run_until_complete( + self.list_tools_async() + ) + + async def call_tool_async( + self, + name: str, + arguments: Optional[Dict[str, Any]] = None, + ) -> Any: + """异步调用工具 / Call tool asynchronously + + Args: + name: 子工具名称 / Sub-tool name + arguments: 调用参数 / Call arguments + + Returns: + Any: 工具执行结果 / Tool execution result + """ + try: + from mcp import ClientSession + + if self.is_streamable: + from mcp.client.streamable_http import streamablehttp_client + + async with streamablehttp_client( + self.endpoint, headers=self.headers + ) as (read_stream, write_stream, _): + async with ClientSession( + read_stream, write_stream + ) as session: + await session.initialize() + result = await session.call_tool( + name, arguments=arguments or {} + ) + return result + else: + from mcp.client.sse import sse_client + + async with sse_client(self.endpoint, headers=self.headers) as ( + read_stream, + write_stream, + ): + async with ClientSession( + read_stream, write_stream + ) as session: + await session.initialize() + result = await session.call_tool( + name, arguments=arguments or {} + ) + return result + except ImportError: + raise ImportError( + "mcp package is required for MCP tool calls. " + "Install it with: pip install mcp" + ) + + def call_tool( + self, + name: str, + arguments: Optional[Dict[str, Any]] = None, + ) -> Any: + """同步调用工具 / Call tool synchronously + + Args: + name: 子工具名称 / Sub-tool name + arguments: 调用参数 / Call arguments + + Returns: + Any: 工具执行结果 / Tool execution result + """ + return asyncio.get_event_loop().run_until_complete( + self.call_tool_async(name, arguments) + ) diff --git a/agentrun/tool/api/openapi.py b/agentrun/tool/api/openapi.py new file mode 100644 index 0000000..5873c7f --- /dev/null +++ b/agentrun/tool/api/openapi.py @@ -0,0 +1,337 @@ +"""Tool OpenAPI 数据链路 / Tool OpenAPI Data API + +解析 FunctionCall 类型 Tool 的 protocol_spec(OpenAPI JSON), +提取 operations 转换为 ToolInfo 列表,并通过 Server URL 发起 HTTP 调用。 +Parses protocol_spec (OpenAPI JSON) for FunctionCall type Tools, +extracts operations as ToolInfo list, and makes HTTP calls via Server URL. +""" + +import json +from typing import Any, Dict, List, Optional + +import httpx + +from agentrun.tool.model import ToolInfo, ToolSchema +from agentrun.utils.log import logger + + +class ToolOpenAPIClient: + """FunctionCall 类型 Tool 的 OpenAPI 客户端 / OpenAPI Client for FunctionCall Tools + + 解析 protocol_spec 中的 OpenAPI Schema,提供 list_tools 和 call_tool 能力。 + Parses OpenAPI Schema from protocol_spec, provides list_tools and call_tool capabilities. + """ + + def __init__( + self, + protocol_spec: Optional[str] = None, + headers: Optional[Dict[str, str]] = None, + fallback_server_url: Optional[str] = None, + ): + """初始化 OpenAPI 客户端 / Initialize OpenAPI client + + Args: + protocol_spec: OpenAPI JSON 字符串 / OpenAPI JSON string + headers: 请求头 / Request headers + fallback_server_url: 当 OpenAPI spec 中没有 servers 时的备用 URL / + Fallback URL when servers is not present in OpenAPI spec + """ + self.headers = headers or {} + self._fallback_server_url = fallback_server_url + self._spec: Optional[Dict[str, Any]] = None + self._operations: Optional[List[Dict[str, Any]]] = None + + if protocol_spec: + try: + self._spec = json.loads(protocol_spec) + except (json.JSONDecodeError, TypeError): + logger.warning("Failed to parse protocol_spec as JSON") + + @property + def server_url(self) -> Optional[str]: + """获取 OpenAPI Schema 中的 Server URL / Get Server URL from OpenAPI Schema + + 优先从 spec.servers 获取,如果不存在则使用 fallback_server_url。 + Prefers spec.servers, falls back to fallback_server_url if not present. + """ + if self._spec: + servers = self._spec.get("servers", []) + if servers and isinstance(servers, list): + url = servers[0].get("url") + if url: + return url + return self._fallback_server_url + + def _resolve_ref(self, ref: str) -> Dict[str, Any]: + """解析 $ref 引用 / Resolve $ref reference + + 支持 JSON Pointer 格式的本地引用,如 #/components/schemas/WeatherRequest。 + Supports local JSON Pointer references like #/components/schemas/WeatherRequest. + + Args: + ref: $ref 字符串 / $ref string + + Returns: + 解析后的 schema 字典 / Resolved schema dict + """ + if not self._spec or not ref.startswith("#/"): + return {} + + parts = ref[2:].split("/") + current: Any = self._spec + for part in parts: + if isinstance(current, dict): + current = current.get(part, {}) + else: + return {} + return current if isinstance(current, dict) else {} + + def _resolve_schema( + self, schema: Optional[Dict[str, Any]] + ) -> Optional[Dict[str, Any]]: + """递归解析 schema 中的所有 $ref 引用 / Recursively resolve all $ref in schema + + Args: + schema: 可能包含 $ref 的 schema / Schema that may contain $ref + + Returns: + 解析后的完整 schema / Fully resolved schema + """ + if not schema or not isinstance(schema, dict): + return schema + + if "$ref" in schema: + resolved = self._resolve_ref(schema["$ref"]) + return self._resolve_schema(resolved) + + result = {} + for key, value in schema.items(): + if key == "properties" and isinstance(value, dict): + result[key] = { + prop_name: self._resolve_schema(prop_schema) or prop_schema + for prop_name, prop_schema in value.items() + } + elif key in ("items", "additionalProperties") and isinstance( + value, dict + ): + result[key] = self._resolve_schema(value) or value + elif key in ("anyOf", "oneOf", "allOf") and isinstance(value, list): + result[key] = [ + self._resolve_schema(item) or item for item in value + ] + else: + result[key] = value + + return result + + def _parse_operations(self) -> List[Dict[str, Any]]: + """解析 OpenAPI Schema 中的所有 operations / Parse all operations from OpenAPI Schema""" + if self._operations is not None: + return self._operations + + self._operations = [] + if not self._spec: + return self._operations + + paths = self._spec.get("paths", {}) + for path, path_item in paths.items(): + if not isinstance(path_item, dict): + continue + for method in ("get", "post", "put", "delete", "patch"): + operation = path_item.get(method) + if not operation or not isinstance(operation, dict): + continue + + operation_id = operation.get("operationId", f"{method}_{path}") + summary = operation.get("summary", "") + description = operation.get("description", "") + + request_body_schema = None + request_body = operation.get("requestBody", {}) + if isinstance(request_body, dict): + content = request_body.get("content", {}) + json_content = content.get("application/json", {}) + raw_schema = json_content.get("schema") + request_body_schema = self._resolve_schema(raw_schema) + + parameters_schema = None + parameters = operation.get("parameters", []) + if parameters and isinstance(parameters, list): + props = {} + required_params = [] + for param in parameters: + if not isinstance(param, dict): + continue + param_name = param.get("name", "") + param_schema = param.get("schema", {"type": "string"}) + param_schema["description"] = param.get( + "description", "" + ) + props[param_name] = param_schema + if param.get("required"): + required_params.append(param_name) + if props: + parameters_schema = { + "type": "object", + "properties": props, + } + if required_params: + parameters_schema["required"] = required_params + + input_schema = request_body_schema or parameters_schema + + self._operations.append({ + "operation_id": operation_id, + "summary": summary, + "description": description, + "method": method.upper(), + "path": path, + "input_schema": input_schema, + }) + + return self._operations + + def list_tools(self) -> List[ToolInfo]: + """获取工具列表 / Get tool list + + Returns: + List[ToolInfo]: 工具信息列表 / List of tool information + """ + operations = self._parse_operations() + tools = [] + for operation in operations: + parameters = None + if operation.get("input_schema"): + parameters = ToolSchema.from_any_openapi_schema( + operation["input_schema"] + ) + + tool_description = operation["summary"] or operation["description"] + tools.append( + ToolInfo( + name=operation["operation_id"], + description=tool_description, + parameters=parameters + or ToolSchema(type="object", properties={}), + ) + ) + return tools + + async def list_tools_async(self) -> List[ToolInfo]: + """异步获取工具列表 / Get tool list asynchronously""" + return self.list_tools() + + def call_tool( + self, + name: str, + arguments: Optional[Dict[str, Any]] = None, + ) -> Any: + """调用工具 / Call tool + + Args: + name: operationId / Operation ID + arguments: 调用参数 / Call arguments + + Returns: + Any: 调用结果 / Call result + + Raises: + ValueError: operation 不存在 / Operation not found + """ + operations = self._parse_operations() + target_operation = None + for operation in operations: + if operation["operation_id"] == name: + target_operation = operation + break + + if not target_operation: + raise ValueError( + f"Operation '{name}' not found in OpenAPI spec. Available" + f" operations: {[op['operation_id'] for op in operations]}" + ) + + base_url = self.server_url + if not base_url: + raise ValueError("No server URL found in OpenAPI spec") + + url = f"{base_url.rstrip('/')}{target_operation['path']}" + method = target_operation["method"] + + logger.debug( + f"Calling FunctionCall tool: {method} {url} with args={arguments}" + ) + + with httpx.Client(headers=self.headers, timeout=30.0) as client: + if method in ("POST", "PUT", "PATCH"): + response = client.request(method, url, json=arguments or {}) + else: + response = client.request(method, url, params=arguments or {}) + + response.raise_for_status() + + content_type = response.headers.get("content-type", "") + if "application/json" in content_type: + return response.json() + return response.text + + async def call_tool_async( + self, + name: str, + arguments: Optional[Dict[str, Any]] = None, + ) -> Any: + """异步调用工具 / Call tool asynchronously + + Args: + name: operationId / Operation ID + arguments: 调用参数 / Call arguments + + Returns: + Any: 调用结果 / Call result + + Raises: + ValueError: operation 不存在 / Operation not found + """ + operations = self._parse_operations() + target_operation = None + for operation in operations: + if operation["operation_id"] == name: + target_operation = operation + break + + if not target_operation: + raise ValueError( + f"Operation '{name}' not found in OpenAPI spec. Available" + f" operations: {[op['operation_id'] for op in operations]}" + ) + + base_url = self.server_url + if not base_url: + raise ValueError("No server URL found in OpenAPI spec") + + url = f"{base_url.rstrip('/')}{target_operation['path']}" + method = target_operation["method"] + + logger.debug( + f"Calling FunctionCall tool async: {method} {url} with" + f" args={arguments}" + ) + + async with httpx.AsyncClient( + headers=self.headers, timeout=30.0 + ) as client: + if method in ("POST", "PUT", "PATCH"): + response = await client.request( + method, url, json=arguments or {} + ) + else: + response = await client.request( + method, url, params=arguments or {} + ) + + response.raise_for_status() + + content_type = response.headers.get("content-type", "") + if "application/json" in content_type: + return response.json() + return response.text diff --git a/agentrun/tool/client.py b/agentrun/tool/client.py new file mode 100644 index 0000000..048de3b --- /dev/null +++ b/agentrun/tool/client.py @@ -0,0 +1,87 @@ +""" +This file is auto generated by the code generation script. +Do not modify this file manually. +Use the `make codegen` command to regenerate. + +当前文件为自动生成的控制 API 客户端代码。请勿手动修改此文件。 +使用 `make codegen` 命令重新生成。 + +source: agentrun/tool/__client_async_template.py + +Tool 客户端 / Tool Client + +此模块提供工具的客户端 API。 +This module provides the client API for tools. +""" + +from typing import Any, Dict, List, Optional + +from agentrun.tool.api.control import ToolControlAPI +from agentrun.utils.config import Config +from agentrun.utils.exception import HTTPError + +from .tool import Tool + + +class ToolClient: + """Tool 客户端 / Tool Client + + 提供工具的获取功能。 + Provides get function for tools. + """ + + def __init__(self, config: Optional[Config] = None): + """初始化客户端 / Initialize client + + Args: + config: 配置对象,可选 / Configuration object, optional + """ + self.__control_api = ToolControlAPI(config) + + async def get_async( + self, + name: str, + config: Optional[Config] = None, + ) -> "Tool": + """异步获取工具 / Get tool asynchronously + + Args: + name: 工具名称 / Tool name + config: 配置对象,可选 / Configuration object, optional + + Returns: + Tool: 工具资源对象 / Tool resource object + """ + try: + result = await self.__control_api.get_tool_async( + name=name, + config=config, + ) + except HTTPError as e: + raise e.to_resource_error("Tool", name) from e + + return Tool.from_inner_object(result) + + def get( + self, + name: str, + config: Optional[Config] = None, + ) -> "Tool": + """同步获取工具 / Get tool synchronously + + Args: + name: 工具名称 / Tool name + config: 配置对象,可选 / Configuration object, optional + + Returns: + Tool: 工具资源对象 / Tool resource object + """ + try: + result = self.__control_api.get_tool( + name=name, + config=config, + ) + except HTTPError as e: + raise e.to_resource_error("Tool", name) from e + + return Tool.from_inner_object(result) diff --git a/agentrun/tool/model.py b/agentrun/tool/model.py new file mode 100644 index 0000000..2d8cc81 --- /dev/null +++ b/agentrun/tool/model.py @@ -0,0 +1,408 @@ +"""Tool 模型定义 / Tool Model Definitions + +定义工具相关的数据模型和枚举。 +Defines data models and enumerations related to tools. +""" + +from enum import Enum +from typing import Any, Dict, List, Optional + +from agentrun.utils.model import BaseModel + + +class ToolType(str, Enum): + """工具类型 / Tool Type""" + + MCP = "MCP" + """MCP 协议工具 / MCP Protocol Tool""" + FUNCTIONCALL = "FUNCTIONCALL" + """函数调用工具 / Function Call Tool""" + SKILL = "SKILL" + """技能工具 / Skill Tool""" + + +class McpConfig(BaseModel): + """MCP 工具配置 / MCP Tool Configuration + + 包含 MCP 工具的会话亲和性、代理配置等信息。 + Contains session affinity, proxy configuration, etc. for MCP tools. + """ + + session_affinity: Optional[str] = None + """会话亲和性策略 / Session affinity strategy + NONE: 无亲和性 / No affinity + MCP_SSE: 基于 SSE 的会话亲和性 / SSE-based session affinity + MCP_STREAMABLE: 基于流式 HTTP 的会话亲和性 / Streamable HTTP-based session affinity + """ + + session_affinity_config: Optional[str] = None + """会话亲和性的详细配置,JSON 格式字符串 / Session affinity config, JSON string""" + + proxy_enabled: Optional[bool] = None + """是否启用 MCP 代理 / Whether MCP proxy is enabled""" + + bound_configuration: Optional[Dict[str, Any]] = None + """工具的绑定配置 / Tool binding configuration""" + + mcp_proxy_configuration: Optional[Dict[str, Any]] = None + """MCP 代理的详细配置 / MCP proxy detailed configuration""" + + +class ToolCodeConfiguration(BaseModel): + """代码包配置 / Code Configuration + + 代码包类型工具的配置信息。 + Configuration for code-package type tools. + """ + + code_checksum: Optional[str] = None + """代码校验和 / Code checksum""" + + code_size: Optional[int] = None + """代码大小(字节)/ Code size in bytes""" + + command: Optional[List[str]] = None + """启动命令 / Startup command""" + + language: Optional[str] = None + """编程语言 / Programming language""" + + oss_bucket_name: Optional[str] = None + """OSS 存储桶名称 / OSS bucket name""" + + oss_object_name: Optional[str] = None + """OSS 对象名称 / OSS object name""" + + +class ToolContainerConfiguration(BaseModel): + """容器配置 / Container Configuration + + 容器类型工具的配置信息。 + Configuration for container type tools. + """ + + args: Optional[List[str]] = None + """容器启动参数 / Container startup arguments""" + + command: Optional[List[str]] = None + """容器启动命令 / Container startup command""" + + image: Optional[str] = None + """容器镜像地址 / Container image URL""" + + port: Optional[int] = None + """容器端口 / Container port""" + + +class ToolLogConfiguration(BaseModel): + """日志配置 / Log Configuration + + 工具的日志配置信息。 + Log configuration for tools. + """ + + log_store: Optional[str] = None + """SLS 日志库 / SLS log store""" + + project: Optional[str] = None + """SLS 项目 / SLS project""" + + +class ToolNASConfig(BaseModel): + """NAS 文件存储配置 / NAS Configuration + + 工具访问 NAS 文件系统的配置。 + Configuration for tool access to NAS file system. + """ + + group_id: Optional[int] = None + """组 ID / Group ID""" + + mount_points: Optional[List[Dict[str, Any]]] = None + """挂载点列表 / Mount points list""" + + user_id: Optional[int] = None + """用户 ID / User ID""" + + +class ToolNetworkConfiguration(BaseModel): + """网络配置 / Network Configuration + + 工具的网络配置信息。 + Network configuration for tools. + """ + + security_group_id: Optional[str] = None + """安全组 ID / Security group ID""" + + vpc_id: Optional[str] = None + """VPC ID""" + + vswitch_ids: Optional[List[str]] = None + """交换机 ID 列表 / VSwitch IDs""" + + +class ToolOSSMountConfig(BaseModel): + """OSS 挂载配置 / OSS Mount Configuration + + 工具访问 OSS 存储的挂载配置。 + Configuration for tool access to OSS storage. + """ + + mount_points: Optional[List[Dict[str, Any]]] = None + """挂载点列表 / Mount points list""" + + +class ToolSchema(BaseModel): + """JSON Schema 兼容的工具参数描述 / JSON Schema Compatible Tool Parameter Description + + 支持完整的 JSON Schema 字段,能够描述复杂的嵌套数据结构。 + Supports full JSON Schema fields for describing complex nested data structures. + """ + + type: Optional[str] = None + """数据类型 / Data type""" + + description: Optional[str] = None + """描述信息 / Description""" + + title: Optional[str] = None + """标题 / Title""" + + properties: Optional[Dict[str, "ToolSchema"]] = None + """对象属性 / Object properties""" + + required: Optional[List[str]] = None + """必填字段 / Required fields""" + + additional_properties: Optional[bool] = None + """是否允许额外属性 / Whether additional properties are allowed""" + + items: Optional["ToolSchema"] = None + """数组元素类型 / Array item type""" + + min_items: Optional[int] = None + """数组最小长度 / Minimum array length""" + + max_items: Optional[int] = None + """数组最大长度 / Maximum array length""" + + pattern: Optional[str] = None + """字符串正则模式 / String regex pattern""" + + min_length: Optional[int] = None + """字符串最小长度 / Minimum string length""" + + max_length: Optional[int] = None + """字符串最大长度 / Maximum string length""" + + format: Optional[str] = None + """字符串格式 / String format (date, date-time, email, uri, etc.)""" + + enum: Optional[List[Any]] = None + """枚举值 / Enum values""" + + minimum: Optional[float] = None + """数值最小值 / Minimum numeric value""" + + maximum: Optional[float] = None + """数值最大值 / Maximum numeric value""" + + exclusive_minimum: Optional[float] = None + """数值排他最小值 / Exclusive minimum numeric value""" + + exclusive_maximum: Optional[float] = None + """数值排他最大值 / Exclusive maximum numeric value""" + + any_of: Optional[List["ToolSchema"]] = None + """任一匹配 / Any of""" + + one_of: Optional[List["ToolSchema"]] = None + """唯一匹配 / One of""" + + all_of: Optional[List["ToolSchema"]] = None + """全部匹配 / All of""" + + default: Optional[Any] = None + """默认值 / Default value""" + + @classmethod + def from_any_openapi_schema(cls, schema: Any) -> "ToolSchema": + """从任意 OpenAPI/JSON Schema 创建 ToolSchema / Create ToolSchema from any OpenAPI/JSON Schema + + 递归解析所有嵌套结构,保留完整的 schema 信息。 + Recursively parses all nested structures, preserving complete schema information. + """ + if not schema or not isinstance(schema, dict): + return cls(type="string") + + from pydash import get as pydash_get + + properties_raw = pydash_get(schema, "properties", {}) + properties = ( + { + key: cls.from_any_openapi_schema(value) + for key, value in properties_raw.items() + } + if properties_raw + else None + ) + + items_raw = pydash_get(schema, "items") + items = cls.from_any_openapi_schema(items_raw) if items_raw else None + + any_of_raw = pydash_get(schema, "anyOf") + any_of = ( + [cls.from_any_openapi_schema(s) for s in any_of_raw] + if any_of_raw + else None + ) + + one_of_raw = pydash_get(schema, "oneOf") + one_of = ( + [cls.from_any_openapi_schema(s) for s in one_of_raw] + if one_of_raw + else None + ) + + all_of_raw = pydash_get(schema, "allOf") + all_of = ( + [cls.from_any_openapi_schema(s) for s in all_of_raw] + if all_of_raw + else None + ) + + return cls( + type=pydash_get(schema, "type"), + description=pydash_get(schema, "description"), + title=pydash_get(schema, "title"), + properties=properties, + required=pydash_get(schema, "required"), + additional_properties=pydash_get(schema, "additionalProperties"), + items=items, + min_items=pydash_get(schema, "minItems"), + max_items=pydash_get(schema, "maxItems"), + pattern=pydash_get(schema, "pattern"), + min_length=pydash_get(schema, "minLength"), + max_length=pydash_get(schema, "maxLength"), + format=pydash_get(schema, "format"), + enum=pydash_get(schema, "enum"), + minimum=pydash_get(schema, "minimum"), + maximum=pydash_get(schema, "maximum"), + exclusive_minimum=pydash_get(schema, "exclusiveMinimum"), + exclusive_maximum=pydash_get(schema, "exclusiveMaximum"), + any_of=any_of, + one_of=one_of, + all_of=all_of, + default=pydash_get(schema, "default"), + ) + + def to_json_schema(self) -> Dict[str, Any]: + """转换为标准 JSON Schema 格式 / Convert to standard JSON Schema format""" + result: Dict[str, Any] = {} + + if self.type: + result["type"] = self.type + if self.description: + result["description"] = self.description + if self.title: + result["title"] = self.title + + if self.properties: + result["properties"] = { + k: v.to_json_schema() for k, v in self.properties.items() + } + if self.required: + result["required"] = self.required + if self.additional_properties is not None: + result["additionalProperties"] = self.additional_properties + + if self.items: + result["items"] = self.items.to_json_schema() + if self.min_items is not None: + result["minItems"] = self.min_items + if self.max_items is not None: + result["maxItems"] = self.max_items + + if self.pattern: + result["pattern"] = self.pattern + if self.min_length is not None: + result["minLength"] = self.min_length + if self.max_length is not None: + result["maxLength"] = self.max_length + if self.format: + result["format"] = self.format + if self.enum: + result["enum"] = self.enum + + if self.minimum is not None: + result["minimum"] = self.minimum + if self.maximum is not None: + result["maximum"] = self.maximum + if self.exclusive_minimum is not None: + result["exclusiveMinimum"] = self.exclusive_minimum + if self.exclusive_maximum is not None: + result["exclusiveMaximum"] = self.exclusive_maximum + + if self.any_of: + result["anyOf"] = [s.to_json_schema() for s in self.any_of] + if self.one_of: + result["oneOf"] = [s.to_json_schema() for s in self.one_of] + if self.all_of: + result["allOf"] = [s.to_json_schema() for s in self.all_of] + + if self.default is not None: + result["default"] = self.default + + return result + + +class ToolInfo(BaseModel): + """工具信息 / Tool Information + + 描述单个工具的名称、描述和参数 schema。 + Describes a single tool's name, description, and parameter schema. + """ + + name: Optional[str] = None + """工具名称 / Tool name""" + + description: Optional[str] = None + """工具描述 / Tool description""" + + parameters: Optional[ToolSchema] = None + """工具参数 schema / Tool parameter schema""" + + @classmethod + def from_mcp_tool(cls, tool: Any) -> "ToolInfo": + """从 MCP tool 创建 ToolInfo / Create ToolInfo from MCP tool""" + if hasattr(tool, "name"): + tool_name = tool.name + tool_description = getattr(tool, "description", None) + input_schema = getattr(tool, "inputSchema", None) or getattr( + tool, "input_schema", None + ) + elif isinstance(tool, dict): + tool_name = tool.get("name") + tool_description = tool.get("description") + input_schema = tool.get("inputSchema") or tool.get("input_schema") + else: + raise ValueError(f"Unsupported MCP tool format: {type(tool)}") + + if not tool_name: + raise ValueError("MCP tool must have a name") + + parameters = None + if input_schema: + if isinstance(input_schema, dict): + parameters = ToolSchema.from_any_openapi_schema(input_schema) + elif hasattr(input_schema, "model_dump"): + parameters = ToolSchema.from_any_openapi_schema( + input_schema.model_dump() + ) + + return cls( + name=tool_name, + description=tool_description, + parameters=parameters or ToolSchema(type="object", properties={}), + ) diff --git a/agentrun/tool/tool.py b/agentrun/tool/tool.py new file mode 100644 index 0000000..b9eb565 --- /dev/null +++ b/agentrun/tool/tool.py @@ -0,0 +1,583 @@ +""" +This file is auto generated by the code generation script. +Do not modify this file manually. +Use the `make codegen` command to regenerate. + +当前文件为自动生成的控制 API 客户端代码。请勿手动修改此文件。 +使用 `make codegen` 命令重新生成。 + +source: agentrun/tool/__tool_async_template.py + +Tool 资源类 / Tool Resource Class + +提供工具资源的面向对象封装和完整生命周期管理。 +Provides object-oriented wrapper and complete lifecycle management for tool resources. +""" + +import io +import os +import shutil +from typing import Any, Dict, List, Optional +import zipfile + +import httpx +import pydash + +from agentrun.utils.config import Config +from agentrun.utils.log import logger +from agentrun.utils.model import BaseModel + +from .model import ( + McpConfig, + ToolCodeConfiguration, + ToolContainerConfiguration, + ToolInfo, + ToolLogConfiguration, + ToolNetworkConfiguration, + ToolOSSMountConfig, + ToolSchema, + ToolType, +) + + +class Tool(BaseModel): + """工具资源 / Tool Resource + + 提供工具的查询、调用等功能。 + Provides query, invocation and other functionality for tools. + + Attributes: + code_configuration: 代码包配置 / Code configuration + container_configuration: 容器配置 / Container configuration + created_time: 创建时间 / Creation time + data_endpoint: 数据链路端点 / Data endpoint + description: 描述 / Description + environment_variables: 环境变量 / Environment variables + gpu: GPU 配置 / GPU configuration + internet_access: 是否允许公网访问 / Whether internet access is allowed + last_modified_time: 最后修改时间 / Last modified time + log_configuration: 日志配置 / Log configuration + mcp_config: MCP 配置 / MCP configuration + memory: 内存大小(MB) / Memory size in MB + name: 工具名称 / Tool name + network_config: 网络配置 / Network configuration + oss_mount_config: OSS 挂载配置 / OSS mount configuration + protocol_spec: 协议规格(OpenAPI JSON) / Protocol spec (OpenAPI JSON) + protocol_type: 协议类型 / Protocol type + status: 状态 / Status + timeout: 超时时间(秒) / Timeout in seconds + tool_id: 工具 ID / Tool ID + tool_name: 工具名称 / Tool name + tool_type: 工具类型(MCP/FUNCTIONCALL) / Tool type + version_id: 版本 ID / Version ID + """ + + code_configuration: Optional[ToolCodeConfiguration] = None + """代码包配置 / Code configuration""" + + container_configuration: Optional[ToolContainerConfiguration] = None + """容器配置 / Container configuration""" + + created_time: Optional[str] = None + """创建时间 / Creation time""" + + data_endpoint: Optional[str] = None + """数据链路端点 / Data endpoint""" + + description: Optional[str] = None + """描述 / Description""" + + environment_variables: Optional[Dict[str, str]] = None + """环境变量 / Environment variables""" + + gpu: Optional[str] = None + """GPU 配置 / GPU configuration""" + + internet_access: Optional[bool] = None + """是否允许公网访问 / Whether internet access is allowed""" + + last_modified_time: Optional[str] = None + """最后修改时间 / Last modified time""" + + log_configuration: Optional[ToolLogConfiguration] = None + """日志配置 / Log configuration""" + + mcp_config: Optional[McpConfig] = None + """MCP 配置 / MCP configuration""" + + memory: Optional[int] = None + """内存大小(MB) / Memory size in MB""" + + name: Optional[str] = None + """工具名称 / Tool name""" + + network_config: Optional[ToolNetworkConfiguration] = None + """网络配置 / Network configuration""" + + oss_mount_config: Optional[ToolOSSMountConfig] = None + """OSS 挂载配置 / OSS mount configuration""" + + protocol_spec: Optional[str] = None + """协议规格(OpenAPI JSON 字符串) / Protocol spec (OpenAPI JSON string)""" + + protocol_type: Optional[str] = None + """协议类型 / Protocol type""" + + status: Optional[str] = None + """状态 / Status""" + + timeout: Optional[int] = None + """超时时间(秒) / Timeout in seconds""" + + tool_id: Optional[str] = None + """工具 ID / Tool ID""" + + tool_name: Optional[str] = None + """工具名称 / Tool name""" + + tool_type: Optional[str] = None + """工具类型(MCP/FUNCTIONCALL) / Tool type (MCP/FUNCTIONCALL)""" + + version_id: Optional[str] = None + """版本 ID / Version ID""" + + @classmethod + def __get_client(cls, config: Optional[Config] = None): + from .client import ToolClient + + return ToolClient(config) + + @classmethod + async def get_by_name_async( + cls, name: str, config: Optional[Config] = None + ) -> "Tool": + """异步通过名称获取工具 / Get tool by name asynchronously""" + cli = cls.__get_client(config) + return await cli.get_async(name=name) + + @classmethod + def get_by_name(cls, name: str, config: Optional[Config] = None) -> "Tool": + """同步通过名称获取工具 / Get tool by name synchronously""" + cli = cls.__get_client(config) + return cli.get(name=name) + + async def get_async(self, config: Optional[Config] = None) -> "Tool": + """异步刷新工具信息 / Refresh tool info asynchronously""" + effective_name = self.tool_name or self.name + if effective_name is None: + raise ValueError("Tool name is required to get the Tool.") + + result = await self.get_by_name_async( + name=effective_name, config=config + ) + return self.update_self(result) + + def get(self, config: Optional[Config] = None) -> "Tool": + """同步刷新工具信息 / Refresh tool info synchronously""" + effective_name = self.tool_name or self.name + if effective_name is None: + raise ValueError("Tool name is required to get the Tool.") + + result = self.get_by_name(name=effective_name, config=config) + return self.update_self(result) + + def _get_functioncall_server_url( + self, config: Optional[Config] = None + ) -> Optional[str]: + """获取 FunctionCall 工具的 fallback server URL / Get fallback server URL for FunctionCall tools + + 当 OpenAPI spec 中没有 servers 字段时,使用 data_endpoint 构造 URL。 + Constructs URL from data_endpoint when servers is not present in OpenAPI spec. + """ + effective_name = self.tool_name or self.name + data_endpoint = self.data_endpoint + if not data_endpoint: + cfg = Config.with_configs(config) + data_endpoint = cfg._data_endpoint + if not data_endpoint or not effective_name: + return None + return f"{data_endpoint}/tools/{effective_name}" + + def _get_tool_type(self) -> Optional[ToolType]: + """获取工具类型 / Get tool type""" + raw_type = self.tool_type + if raw_type: + try: + return ToolType(raw_type) + except ValueError: + return None + return None + + def _get_mcp_endpoint( + self, config: Optional[Config] = None + ) -> Optional[str]: + """获取 MCP 数据链路 URL / Get MCP data endpoint URL + + 根据 session_affinity 决定使用 /mcp 还是 /sse 路径。 + 如果 self.data_endpoint 为空,则从 Config 中获取。 + Determines /mcp or /sse path based on session_affinity. + Falls back to Config if self.data_endpoint is not set. + """ + effective_name = self.tool_name or self.name + data_endpoint = self.data_endpoint + if not data_endpoint: + cfg = Config.with_configs(config) + data_endpoint = cfg._data_endpoint + if not data_endpoint or not effective_name: + return None + + session_affinity = pydash.get( + self, "mcp_config.session_affinity", "MCP_SSE" + ) + + if session_affinity == "MCP_STREAMABLE": + return f"{data_endpoint}/tools/{effective_name}/mcp" + return f"{data_endpoint}/tools/{effective_name}/sse" + + async def list_tools_async( + self, config: Optional[Config] = None + ) -> List[ToolInfo]: + """异步获取子工具列表 / Get sub-tool list asynchronously + + 对于 MCP 类型,通过 MCP 协议获取工具列表。 + 对于 FUNCTIONCALL 类型,解析 protocol_spec 获取工具列表。 + For MCP type, gets tool list via MCP protocol. + For FUNCTIONCALL type, parses protocol_spec to get tool list. + + Returns: + List[ToolInfo]: 子工具信息列表 / List of sub-tool information + """ + tool_type = self._get_tool_type() + + if tool_type == ToolType.MCP: + from .api.mcp import ToolMCPSession + + mcp_endpoint = self._get_mcp_endpoint(config) + if not mcp_endpoint: + logger.warning( + "MCP endpoint not available for tool %s", self.name + ) + return [] + + session_affinity = pydash.get( + self, "mcp_config.session_affinity", "MCP_SSE" + ) + + cfg = Config.with_configs(config) + session = ToolMCPSession( + endpoint=mcp_endpoint, + session_affinity=session_affinity, + headers=cfg.get_headers(), + ) + return await session.list_tools_async() + + elif tool_type == ToolType.FUNCTIONCALL: + from .api.openapi import ToolOpenAPIClient + + openapi_client = ToolOpenAPIClient( + protocol_spec=self.protocol_spec, + fallback_server_url=self._get_functioncall_server_url(config), + ) + return await openapi_client.list_tools_async() + + return [] + + def list_tools(self, config: Optional[Config] = None) -> List[ToolInfo]: + """同步获取子工具列表 / Get sub-tool list synchronously + + 对于 MCP 类型,通过 MCP 协议获取工具列表。 + 对于 FUNCTIONCALL 类型,解析 protocol_spec 获取工具列表。 + For MCP type, gets tool list via MCP protocol. + For FUNCTIONCALL type, parses protocol_spec to get tool list. + + Returns: + List[ToolInfo]: 子工具信息列表 / List of sub-tool information + """ + tool_type = self._get_tool_type() + + if tool_type == ToolType.MCP: + from .api.mcp import ToolMCPSession + + mcp_endpoint = self._get_mcp_endpoint(config) + if not mcp_endpoint: + logger.warning( + "MCP endpoint not available for tool %s", self.name + ) + return [] + + session_affinity = pydash.get( + self, "mcp_config.session_affinity", "MCP_SSE" + ) + + cfg = Config.with_configs(config) + session = ToolMCPSession( + endpoint=mcp_endpoint, + session_affinity=session_affinity, + headers=cfg.get_headers(), + ) + return session.list_tools() + + elif tool_type == ToolType.FUNCTIONCALL: + from .api.openapi import ToolOpenAPIClient + + openapi_client = ToolOpenAPIClient( + protocol_spec=self.protocol_spec, + fallback_server_url=self._get_functioncall_server_url(config), + ) + return openapi_client.list_tools() + + return [] + + async def call_tool_async( + self, + name: str, + arguments: Optional[Dict[str, Any]] = None, + config: Optional[Config] = None, + ) -> Any: + """异步调用子工具 / Call sub-tool asynchronously + + Args: + name: 子工具名称 / Sub-tool name + arguments: 调用参数 / Call arguments + config: 配置对象,可选 / Configuration object, optional + + Returns: + Any: 工具执行结果 / Tool execution result + """ + tool_type = self._get_tool_type() + + logger.debug("invoke tool %s with arguments %s", name, arguments) + + if tool_type == ToolType.MCP: + from .api.mcp import ToolMCPSession + + mcp_endpoint = self._get_mcp_endpoint(config) + if not mcp_endpoint: + raise ValueError( + f"MCP endpoint not available for tool {self.name}" + ) + + session_affinity = pydash.get( + self, "mcp_config.session_affinity", "MCP_SSE" + ) + + cfg = Config.with_configs(config) + session = ToolMCPSession( + endpoint=mcp_endpoint, + session_affinity=session_affinity, + headers=cfg.get_headers(), + ) + result = await session.call_tool_async(name, arguments) + logger.debug("invoke tool %s got result %s", name, result) + return result + + elif tool_type == ToolType.FUNCTIONCALL: + from .api.openapi import ToolOpenAPIClient + + cfg = Config.with_configs(config) + openapi_client = ToolOpenAPIClient( + protocol_spec=self.protocol_spec, + headers=cfg.get_headers(), + fallback_server_url=self._get_functioncall_server_url(config), + ) + result = await openapi_client.call_tool_async(name, arguments) + logger.debug("invoke tool %s got result %s", name, result) + return result + + raise ValueError(f"Unsupported tool type: {self.tool_type}") + + def call_tool( + self, + name: str, + arguments: Optional[Dict[str, Any]] = None, + config: Optional[Config] = None, + ) -> Any: + """同步调用子工具 / Call sub-tool synchronously + + Args: + name: 子工具名称 / Sub-tool name + arguments: 调用参数 / Call arguments + config: 配置对象,可选 / Configuration object, optional + + Returns: + Any: 工具执行结果 / Tool execution result + """ + tool_type = self._get_tool_type() + + logger.debug("invoke tool %s with arguments %s", name, arguments) + + if tool_type == ToolType.MCP: + from .api.mcp import ToolMCPSession + + mcp_endpoint = self._get_mcp_endpoint(config) + if not mcp_endpoint: + raise ValueError( + f"MCP endpoint not available for tool {self.name}" + ) + + session_affinity = pydash.get( + self, "mcp_config.session_affinity", "MCP_SSE" + ) + + cfg = Config.with_configs(config) + session = ToolMCPSession( + endpoint=mcp_endpoint, + session_affinity=session_affinity, + headers=cfg.get_headers(), + ) + result = session.call_tool(name, arguments) + logger.debug("invoke tool %s got result %s", name, result) + return result + + elif tool_type == ToolType.FUNCTIONCALL: + from .api.openapi import ToolOpenAPIClient + + cfg = Config.with_configs(config) + openapi_client = ToolOpenAPIClient( + protocol_spec=self.protocol_spec, + headers=cfg.get_headers(), + fallback_server_url=self._get_functioncall_server_url(config), + ) + result = openapi_client.call_tool(name, arguments) + logger.debug("invoke tool %s got result %s", name, result) + return result + + raise ValueError(f"Unsupported tool type: {self.tool_type}") + + def _get_skill_download_url( + self, config: Optional[Config] = None + ) -> Optional[str]: + """获取 Skill 工具的下载 URL / Get download URL for Skill tools + + 根据 data_endpoint 和 tool_name 构造下载地址。 + Constructs download URL from data_endpoint and tool_name. + + Returns: + Optional[str]: 下载 URL / Download URL + """ + effective_name = self.tool_name or self.name + data_endpoint = self.data_endpoint + if not data_endpoint: + cfg = Config.with_configs(config) + data_endpoint = cfg._data_endpoint + if not data_endpoint or not effective_name: + return None + return f"{data_endpoint}/tools/{effective_name}/download" + + async def download_skill_async( + self, + target_dir: str = ".skills", + config: Optional[Config] = None, + ) -> str: + """异步下载 Skill 包并解压到本地目录 / Download skill package and extract to local directory asynchronously + + 从数据链路下载 skill 的 zip 包,并解压到 {target_dir}/{tool_name}/ 目录下。 + Downloads skill zip package from data endpoint and extracts to {target_dir}/{tool_name}/ directory. + + Args: + target_dir: 目标根目录,默认为 ".skills" / Target root directory, defaults to ".skills" + config: 配置对象,可选 / Configuration object, optional + + Returns: + str: 解压后的 skill 目录路径 / Extracted skill directory path + + Raises: + ValueError: 工具类型不是 SKILL 或缺少必要信息 / Tool type is not SKILL or missing required info + httpx.HTTPStatusError: 下载失败 / Download failed + """ + tool_type = self._get_tool_type() + if tool_type != ToolType.SKILL: + raise ValueError( + "download_skill is only available for SKILL type tools," + f" got {self.tool_type}" + ) + + download_url = self._get_skill_download_url(config) + if not download_url: + raise ValueError( + "Cannot construct download URL: data_endpoint or tool_name" + " is missing" + ) + + effective_name = self.tool_name or self.name + skill_dir = os.path.join(target_dir, effective_name or "unknown_skill") + + logger.debug("downloading skill from %s to %s", download_url, skill_dir) + + cfg = Config.with_configs(config) + headers = cfg.get_headers() + + async with httpx.AsyncClient( + timeout=300, follow_redirects=True + ) as http_client: + response = await http_client.get(download_url, headers=headers) + response.raise_for_status() + + if os.path.exists(skill_dir): + shutil.rmtree(skill_dir) + os.makedirs(skill_dir, exist_ok=True) + + zip_buffer = io.BytesIO(response.content) + with zipfile.ZipFile(zip_buffer, "r") as zip_file: + zip_file.extractall(skill_dir) + + logger.info("skill downloaded and extracted to %s", skill_dir) + return skill_dir + + def download_skill( + self, + target_dir: str = ".skills", + config: Optional[Config] = None, + ) -> str: + """同步下载 Skill 包并解压到本地目录 / Download skill package and extract to local directory synchronously + + 从数据链路下载 skill 的 zip 包,并解压到 {target_dir}/{tool_name}/ 目录下。 + Downloads skill zip package from data endpoint and extracts to {target_dir}/{tool_name}/ directory. + + Args: + target_dir: 目标根目录,默认为 ".skills" / Target root directory, defaults to ".skills" + config: 配置对象,可选 / Configuration object, optional + + Returns: + str: 解压后的 skill 目录路径 / Extracted skill directory path + + Raises: + ValueError: 工具类型不是 SKILL 或缺少必要信息 / Tool type is not SKILL or missing required info + httpx.HTTPStatusError: 下载失败 / Download failed + """ + tool_type = self._get_tool_type() + if tool_type != ToolType.SKILL: + raise ValueError( + "download_skill is only available for SKILL type tools," + f" got {self.tool_type}" + ) + + download_url = self._get_skill_download_url(config) + if not download_url: + raise ValueError( + "Cannot construct download URL: data_endpoint or tool_name" + " is missing" + ) + + effective_name = self.tool_name or self.name + skill_dir = os.path.join(target_dir, effective_name or "unknown_skill") + + logger.debug("downloading skill from %s to %s", download_url, skill_dir) + + cfg = Config.with_configs(config) + headers = cfg.get_headers() + + with httpx.Client(timeout=300, follow_redirects=True) as http_client: + response = http_client.get(download_url, headers=headers) + response.raise_for_status() + + if os.path.exists(skill_dir): + shutil.rmtree(skill_dir) + os.makedirs(skill_dir, exist_ok=True) + + zip_buffer = io.BytesIO(response.content) + with zipfile.ZipFile(zip_buffer, "r") as zip_file: + zip_file.extractall(skill_dir) + + logger.info("skill downloaded and extracted to %s", skill_dir) + return skill_dir diff --git a/pyproject.toml b/pyproject.toml index b3950bb..2c5c075 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ dependencies = [ "litellm>=1.79.3", "alibabacloud-devs20230714>=2.4.1", "pydash>=8.0.5", - "alibabacloud-agentrun20250910>=5.3.1", + "alibabacloud-agentrun20250910>=5.6.0", "alibabacloud_tea_openapi>=0.4.2", "alibabacloud_bailian20231229>=2.6.2", "agentrun-mem0ai>=0.0.10", diff --git a/tests/unittests/integration/test_skill_loader.py b/tests/unittests/integration/test_skill_loader.py new file mode 100644 index 0000000..830bba1 --- /dev/null +++ b/tests/unittests/integration/test_skill_loader.py @@ -0,0 +1,911 @@ +"""SkillLoader 单元测试 / SkillLoader Unit Tests + +测试 Skill 加载器的核心功能: +- _parse_frontmatter() 函数 +- SkillLoader 类(scan_skills / load_skill / to_common_toolset) +- skill_tools() 入口函数 +- builtin/skill.py 导出 +- 各框架 builtin 中的 skill_tools() 函数 +""" + +import json +import os +import sys +from typing import Any, Dict, List, Optional +from unittest.mock import MagicMock, patch + +import pytest + +import agentrun.integration.builtin.skill as _builtin_skill_mod +from agentrun.integration.utils.skill_loader import ( + _parse_frontmatter, + skill_tools, + SkillDetail, + SkillInfo, + SkillLoader, +) +from agentrun.integration.utils.tool import CommonToolSet + +# ============================================================================= +# Helper: 创建临时 skill 目录结构 +# ============================================================================= + + +def _create_skill_dir( + base_dir: str, + skill_name: str, + *, + skill_md_content: Optional[str] = None, + package_json: Optional[Dict[str, Any]] = None, + extra_files: Optional[Dict[str, str]] = None, +) -> str: + """在 base_dir 下创建一个 skill 子目录,写入可选的 SKILL.md / package.json / 其他文件。""" + skill_path = os.path.join(base_dir, skill_name) + os.makedirs(skill_path, exist_ok=True) + + if skill_md_content is not None: + with open( + os.path.join(skill_path, "SKILL.md"), "w", encoding="utf-8" + ) as fh: + fh.write(skill_md_content) + + if package_json is not None: + with open( + os.path.join(skill_path, "package.json"), "w", encoding="utf-8" + ) as fh: + json.dump(package_json, fh) + + if extra_files: + for filename, content in extra_files.items(): + file_path = os.path.join(skill_path, filename) + sub_dir = os.path.dirname(file_path) + if sub_dir and not os.path.isdir(sub_dir): + os.makedirs(sub_dir, exist_ok=True) + with open(file_path, "w", encoding="utf-8") as fh: + fh.write(content) + + return skill_path + + +# ============================================================================= +# 1. _parse_frontmatter 测试 +# ============================================================================= + + +class TestParseFrontmatter: + """测试 YAML frontmatter 解析函数""" + + def test_valid_frontmatter(self) -> None: + content = ( + "---\nname: my-skill\ndescription: A test skill\nversion:" + " 1.0.0\n---\n# Body" + ) + result = _parse_frontmatter(content) + assert result["name"] == "my-skill" + assert result["description"] == "A test skill" + assert result["version"] == "1.0.0" + + def test_no_frontmatter(self) -> None: + content = "# Just a markdown file\nNo frontmatter here." + result = _parse_frontmatter(content) + assert result == {} + + def test_empty_string(self) -> None: + result = _parse_frontmatter("") + assert result == {} + + def test_quoted_values(self) -> None: + content = ( + "---\nname: \"quoted-name\"\ndescription: 'single-quoted'\n---\n" + ) + result = _parse_frontmatter(content) + assert result["name"] == "quoted-name" + assert result["description"] == "single-quoted" + + def test_empty_value(self) -> None: + content = "---\nname: my-skill\ndescription:\n---\n" + result = _parse_frontmatter(content) + assert result["name"] == "my-skill" + assert result["description"] == "" + + def test_colon_in_value(self) -> None: + content = ( + "---\nname: my-skill\ndescription: A skill: does things\n---\n" + ) + result = _parse_frontmatter(content) + assert result["description"] == "A skill: does things" + + def test_blank_lines_in_frontmatter(self) -> None: + content = "---\nname: my-skill\n\ndescription: test\n---\n" + result = _parse_frontmatter(content) + assert result["name"] == "my-skill" + assert result["description"] == "test" + + def test_no_closing_delimiter(self) -> None: + content = "---\nname: my-skill\ndescription: test\n" + result = _parse_frontmatter(content) + assert result == {} + + +# ============================================================================= +# 2. SkillLoader.scan_skills 测试 +# ============================================================================= + + +class TestScanSkills: + """测试 SkillLoader.scan_skills()""" + + def test_empty_directory(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + loader = SkillLoader(skills_dir=skills_dir) + result = loader.scan_skills() + assert result == [] + + def test_nonexistent_directory(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "nonexistent") + loader = SkillLoader(skills_dir=skills_dir) + result = loader.scan_skills() + assert result == [] + + def test_skill_with_frontmatter(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + _create_skill_dir( + skills_dir, + "my-skill", + skill_md_content=( + "---\nname: custom-name\ndescription: A great skill\nversion:" + " 2.0\n---\n# Skill" + ), + ) + loader = SkillLoader(skills_dir=skills_dir) + result = loader.scan_skills() + assert len(result) == 1 + assert result[0].name == "custom-name" + assert result[0].description == "A great skill" + assert result[0].version == "2.0" + + def test_skill_with_package_json_only(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + _create_skill_dir( + skills_dir, + "pkg-skill", + package_json={ + "name": "pkg-name", + "description": "From package.json", + "version": "3.0", + }, + ) + loader = SkillLoader(skills_dir=skills_dir) + result = loader.scan_skills() + assert len(result) == 1 + assert result[0].name == "pkg-name" + assert result[0].description == "From package.json" + assert result[0].version == "3.0" + + def test_skill_with_no_metadata(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + _create_skill_dir(skills_dir, "bare-skill") + loader = SkillLoader(skills_dir=skills_dir) + result = loader.scan_skills() + assert len(result) == 1 + assert result[0].name == "bare-skill" + assert result[0].description == "" + + def test_frontmatter_takes_priority_over_package_json( + self, tmp_path: Any + ) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + _create_skill_dir( + skills_dir, + "priority-skill", + skill_md_content=( + "---\nname: from-frontmatter\ndescription: FM desc\n---\n" + ), + package_json={"name": "from-pkg", "description": "PKG desc"}, + ) + loader = SkillLoader(skills_dir=skills_dir) + result = loader.scan_skills() + assert len(result) == 1 + assert result[0].name == "from-frontmatter" + assert result[0].description == "FM desc" + + def test_multiple_skills_sorted(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + _create_skill_dir(skills_dir, "beta-skill") + _create_skill_dir(skills_dir, "alpha-skill") + _create_skill_dir(skills_dir, "gamma-skill") + loader = SkillLoader(skills_dir=skills_dir) + result = loader.scan_skills() + assert len(result) == 3 + assert [s.name for s in result] == [ + "alpha-skill", + "beta-skill", + "gamma-skill", + ] + + def test_hidden_directories_skipped(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + _create_skill_dir(skills_dir, ".hidden-skill") + _create_skill_dir(skills_dir, "visible-skill") + loader = SkillLoader(skills_dir=skills_dir) + result = loader.scan_skills() + assert len(result) == 1 + assert result[0].name == "visible-skill" + + def test_files_in_root_are_skipped(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + with open(os.path.join(skills_dir, "not-a-skill.txt"), "w") as fh: + fh.write("just a file") + _create_skill_dir(skills_dir, "real-skill") + loader = SkillLoader(skills_dir=skills_dir) + result = loader.scan_skills() + assert len(result) == 1 + assert result[0].name == "real-skill" + + def test_cache_is_used(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + _create_skill_dir(skills_dir, "cached-skill") + loader = SkillLoader(skills_dir=skills_dir) + first_result = loader.scan_skills() + # Add another skill after first scan + _create_skill_dir(skills_dir, "new-skill") + second_result = loader.scan_skills() + # Should return cached result + assert first_result is second_result + assert len(second_result) == 1 + + def test_malformed_package_json(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + skill_path = os.path.join(skills_dir, "bad-pkg") + os.makedirs(skill_path) + with open(os.path.join(skill_path, "package.json"), "w") as fh: + fh.write("{invalid json") + loader = SkillLoader(skills_dir=skills_dir) + result = loader.scan_skills() + assert len(result) == 1 + assert result[0].name == "bad-pkg" + + def test_skill_md_without_frontmatter_falls_to_package_json( + self, tmp_path: Any + ) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + _create_skill_dir( + skills_dir, + "fallback-skill", + skill_md_content="# No frontmatter here\nJust content.", + package_json={"name": "pkg-fallback", "description": "From pkg"}, + ) + loader = SkillLoader(skills_dir=skills_dir) + result = loader.scan_skills() + assert len(result) == 1 + assert result[0].name == "pkg-fallback" + assert result[0].description == "From pkg" + + +# ============================================================================= +# 3. SkillLoader.load_skill 测试 +# ============================================================================= + + +class TestLoadSkill: + """测试 SkillLoader.load_skill()""" + + def test_load_existing_skill(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + md_content = ( + "---\nname: test-skill\ndescription: Test\n---\n# Instructions\nDo" + " stuff." + ) + _create_skill_dir( + skills_dir, + "test-skill", + skill_md_content=md_content, + extra_files={"helper.py": "print('hello')"}, + ) + loader = SkillLoader(skills_dir=skills_dir) + detail = loader.load_skill("test-skill") + assert detail is not None + assert detail.name == "test-skill" + assert detail.description == "Test" + assert detail.instruction == md_content + assert "SKILL.md" in detail.files + assert "helper.py" in detail.files + + def test_load_nonexistent_skill(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + loader = SkillLoader(skills_dir=skills_dir) + detail = loader.load_skill("nonexistent") + assert detail is None + + def test_load_skill_with_subdirectory(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + skill_path = _create_skill_dir( + skills_dir, + "dir-skill", + skill_md_content="---\nname: dir-skill\n---\n", + ) + sub_dir = os.path.join(skill_path, "scripts") + os.makedirs(sub_dir) + with open(os.path.join(sub_dir, "run.sh"), "w") as fh: + fh.write("#!/bin/bash") + loader = SkillLoader(skills_dir=skills_dir) + detail = loader.load_skill("dir-skill") + assert detail is not None + assert "scripts/" in detail.files + + def test_load_skill_hidden_files_excluded(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + _create_skill_dir( + skills_dir, + "hidden-files-skill", + skill_md_content="---\nname: hidden-files-skill\n---\n", + extra_files={".hidden": "secret", "visible.txt": "public"}, + ) + loader = SkillLoader(skills_dir=skills_dir) + detail = loader.load_skill("hidden-files-skill") + assert detail is not None + assert ".hidden" not in detail.files + assert "visible.txt" in detail.files + + def test_load_skill_without_skill_md(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + _create_skill_dir( + skills_dir, + "no-md-skill", + extra_files={"readme.txt": "hello"}, + ) + loader = SkillLoader(skills_dir=skills_dir) + detail = loader.load_skill("no-md-skill") + assert detail is not None + assert detail.instruction == "" + assert "readme.txt" in detail.files + + +# ============================================================================= +# 4. SkillLoader._build_tool_description 测试 +# ============================================================================= + + +class TestBuildToolDescription: + """测试 load_skills 工具描述的构建""" + + def test_no_skills(self) -> None: + loader = SkillLoader(skills_dir="/nonexistent") + desc = loader._build_tool_description([]) + assert "No skills available" in desc + + def test_with_skills(self) -> None: + loader = SkillLoader(skills_dir="/nonexistent") + skills = [ + SkillInfo(name="alpha", description="Alpha skill"), + SkillInfo(name="beta", description=""), + ] + desc = loader._build_tool_description(skills) + assert "alpha: Alpha skill" in desc + assert "- beta" in desc + assert "Available skills:" in desc + + +# ============================================================================= +# 5. SkillLoader._load_skills_func 测试 +# ============================================================================= + + +class TestLoadSkillsFunc: + """测试 load_skills 工具的执行函数""" + + def test_list_all_skills(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + _create_skill_dir( + skills_dir, + "skill-a", + skill_md_content="---\nname: skill-a\ndescription: Skill A\n---\n", + ) + _create_skill_dir( + skills_dir, + "skill-b", + skill_md_content="---\nname: skill-b\ndescription: Skill B\n---\n", + ) + loader = SkillLoader(skills_dir=skills_dir) + result_json = loader._load_skills_func(name=None) + result = json.loads(result_json) + assert "skills" in result + assert len(result["skills"]) == 2 + names = [s["name"] for s in result["skills"]] + assert "skill-a" in names + assert "skill-b" in names + + def test_list_with_empty_string(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + _create_skill_dir(skills_dir, "only-skill") + loader = SkillLoader(skills_dir=skills_dir) + result_json = loader._load_skills_func(name="") + result = json.loads(result_json) + assert "skills" in result + + def test_load_specific_skill(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + md_content = ( + "---\nname: target\ndescription: Target skill\n---\n# Instructions" + ) + _create_skill_dir( + skills_dir, + "target", + skill_md_content=md_content, + ) + loader = SkillLoader(skills_dir=skills_dir) + result_json = loader._load_skills_func(name="target") + result = json.loads(result_json) + assert result["name"] == "target" + assert result["description"] == "Target skill" + assert "instruction" in result + assert "files" in result + + def test_load_nonexistent_skill_returns_error(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + _create_skill_dir(skills_dir, "existing") + loader = SkillLoader(skills_dir=skills_dir) + result_json = loader._load_skills_func(name="missing") + result = json.loads(result_json) + assert "error" in result + assert "missing" in result["error"] + assert "existing" in result["error"] + + +# ============================================================================= +# 6. SkillLoader.to_common_toolset 测试 +# ============================================================================= + + +class TestToCommonToolset: + """测试 to_common_toolset() 返回的 CommonToolSet""" + + def test_returns_common_toolset(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + _create_skill_dir( + skills_dir, + "test-skill", + skill_md_content="---\nname: test-skill\ndescription: Test\n---\n", + ) + loader = SkillLoader(skills_dir=skills_dir) + toolset = loader.to_common_toolset() + assert isinstance(toolset, CommonToolSet) + + def test_toolset_has_load_skills_tool(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + _create_skill_dir(skills_dir, "test-skill") + loader = SkillLoader(skills_dir=skills_dir) + toolset = loader.to_common_toolset() + tools_list = toolset.tools() + assert len(tools_list) == 1 + assert tools_list[0].name == "load_skills" + + def test_tool_description_contains_skill_names(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + _create_skill_dir( + skills_dir, + "alpha", + skill_md_content="---\nname: alpha\ndescription: Alpha desc\n---\n", + ) + loader = SkillLoader(skills_dir=skills_dir) + toolset = loader.to_common_toolset() + tool = toolset.tools()[0] + assert "alpha" in tool.description + assert "Alpha desc" in tool.description + + def test_tool_has_name_parameter(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + _create_skill_dir(skills_dir, "test-skill") + loader = SkillLoader(skills_dir=skills_dir) + toolset = loader.to_common_toolset() + tool = toolset.tools()[0] + # CanonicalTool.parameters is a JSON schema dict + assert "properties" in tool.parameters + assert "name" in tool.parameters["properties"] + name_prop = tool.parameters["properties"]["name"] + assert name_prop["type"] == "string" + + def test_tool_func_is_callable(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + _create_skill_dir( + skills_dir, + "callable-skill", + skill_md_content="---\nname: callable-skill\n---\n", + ) + loader = SkillLoader(skills_dir=skills_dir) + toolset = loader.to_common_toolset() + tool = toolset.tools()[0] + result_json = tool.func() + result = json.loads(result_json) + assert "skills" in result + + def test_empty_skills_dir_still_returns_toolset( + self, tmp_path: Any + ) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + loader = SkillLoader(skills_dir=skills_dir) + toolset = loader.to_common_toolset() + assert isinstance(toolset, CommonToolSet) + tools_list = toolset.tools() + assert len(tools_list) == 1 + assert "No skills available" in tools_list[0].description + + +# ============================================================================= +# 7. skill_tools() 入口函数测试 +# ============================================================================= + + +class TestSkillToolsFunction: + """测试 skill_tools() 入口函数""" + + def test_local_only(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + _create_skill_dir( + skills_dir, + "local-skill", + skill_md_content=( + "---\nname: local-skill\ndescription: Local\n---\n" + ), + ) + toolset = skill_tools(skills_dir=skills_dir) + assert isinstance(toolset, CommonToolSet) + assert len(toolset.tools()) == 1 + + def test_with_string_name_triggers_remote_download( + self, tmp_path: Any + ) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + + mock_tool_client = MagicMock() + mock_tool_resource = MagicMock() + mock_tool_client.return_value.get.return_value = mock_tool_resource + + with patch( + "agentrun.integration.utils.skill_loader.SkillLoader._ensure_skills_available" + ) as mock_ensure: + toolset = skill_tools(name="remote-skill", skills_dir=skills_dir) + assert isinstance(toolset, CommonToolSet) + + def test_with_list_of_names(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + + with patch( + "agentrun.integration.utils.skill_loader.SkillLoader._ensure_skills_available" + ): + toolset = skill_tools( + name=["skill-a", "skill-b"], skills_dir=skills_dir + ) + assert isinstance(toolset, CommonToolSet) + + def test_with_tool_resource_instance(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + + mock_resource = MagicMock() + mock_resource.name = "resource-skill" + + toolset = skill_tools(name=mock_resource, skills_dir=skills_dir) + assert isinstance(toolset, CommonToolSet) + mock_resource.download_skill.assert_called_once() + + def test_with_tool_resource_already_downloaded(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + # Pre-create the skill directory so download is skipped + _create_skill_dir(skills_dir, "existing-resource") + + mock_resource = MagicMock() + mock_resource.name = "existing-resource" + + toolset = skill_tools(name=mock_resource, skills_dir=skills_dir) + assert isinstance(toolset, CommonToolSet) + mock_resource.download_skill.assert_not_called() + + def test_none_name_loads_local_only(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + _create_skill_dir(skills_dir, "local-only") + toolset = skill_tools(name=None, skills_dir=skills_dir) + assert isinstance(toolset, CommonToolSet) + + +# ============================================================================= +# 8. _ensure_skills_available 测试 +# ============================================================================= + + +class TestEnsureSkillsAvailable: + """测试远程 skill 下载逻辑""" + + def test_no_remote_names_does_nothing(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + loader = SkillLoader(skills_dir=skills_dir, remote_skill_names=[]) + # Should not raise + loader._ensure_skills_available() + + def test_existing_skill_skips_download(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + _create_skill_dir(skills_dir, "already-here") + + loader = SkillLoader( + skills_dir=skills_dir, remote_skill_names=["already-here"] + ) + with patch("agentrun.tool.client.ToolClient") as mock_client: + loader._ensure_skills_available() + mock_client.assert_not_called() + + def test_missing_skill_triggers_download(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + + mock_tool_resource = MagicMock() + mock_client_instance = MagicMock() + mock_client_instance.get.return_value = mock_tool_resource + + loader = SkillLoader( + skills_dir=skills_dir, remote_skill_names=["new-skill"] + ) + with patch( + "agentrun.tool.client.ToolClient", + return_value=mock_client_instance, + ): + loader._ensure_skills_available() + mock_client_instance.get.assert_called_once_with( + name="new-skill", config=None + ) + mock_tool_resource.download_skill.assert_called_once_with( + target_dir=skills_dir, config=None + ) + + +# ============================================================================= +# 9. builtin/skill.py 导出测试 +# ============================================================================= + + +class TestBuiltinSkillExport: + """测试 builtin/skill.py 的导出""" + + def test_skill_tools_is_exported(self) -> None: + assert hasattr(_builtin_skill_mod, "skill_tools") + assert callable(_builtin_skill_mod.skill_tools) + + def test_skill_tools_in_all(self) -> None: + assert "skill_tools" in _builtin_skill_mod.__all__ + + def test_import_from_builtin_init(self) -> None: + from agentrun.integration.builtin import skill_tools as imported_fn + + assert callable(imported_fn) + + +# ============================================================================= +# 10. 各框架 builtin skill_tools 测试 +# ============================================================================= + + +class TestFrameworkBuiltinSkillTools: + """测试各框架 builtin 中的 skill_tools() 函数""" + + def _run_framework_test(self, framework_module_path: str) -> None: + """通用框架测试:mock builtin skill_tools 返回 CommonToolSet, + 验证框架 skill_tools 调用了正确的转换方法。""" + mock_toolset = MagicMock(spec=CommonToolSet) + mock_toolset.to_langchain.return_value = [MagicMock()] + mock_toolset.to_google_adk.return_value = [MagicMock()] + mock_toolset.to_crewai.return_value = [MagicMock()] + mock_toolset.to_langgraph.return_value = [MagicMock()] + mock_toolset.to_pydantic_ai.return_value = [MagicMock()] + mock_toolset.to_agentscope.return_value = [MagicMock()] + + with patch( + f"{framework_module_path}._skill_tools", + return_value=mock_toolset, + ): + module = sys.modules.get(framework_module_path) + if module is None: + import importlib + + module = importlib.import_module(framework_module_path) + result = module.skill_tools(skills_dir=".test-skills") + assert isinstance(result, list) + assert len(result) == 1 + + def test_langchain_skill_tools(self) -> None: + self._run_framework_test("agentrun.integration.langchain.builtin") + + def test_google_adk_skill_tools(self) -> None: + self._run_framework_test("agentrun.integration.google_adk.builtin") + + def test_crewai_skill_tools(self) -> None: + self._run_framework_test("agentrun.integration.crewai.builtin") + + def test_langgraph_skill_tools(self) -> None: + self._run_framework_test("agentrun.integration.langgraph.builtin") + + def test_pydantic_ai_skill_tools(self) -> None: + self._run_framework_test("agentrun.integration.pydantic_ai.builtin") + + def test_agentscope_skill_tools(self) -> None: + self._run_framework_test("agentrun.integration.agentscope.builtin") + + def test_framework_import_from_init(self) -> None: + """验证各框架 __init__.py 正确导出 skill_tools""" + from agentrun.integration.agentscope import skill_tools as as_fn + from agentrun.integration.crewai import skill_tools as crew_fn + from agentrun.integration.google_adk import skill_tools as adk_fn + from agentrun.integration.langchain import skill_tools as lc_fn + from agentrun.integration.langgraph import skill_tools as lg_fn + from agentrun.integration.pydantic_ai import skill_tools as pai_fn + + assert callable(lc_fn) + assert callable(adk_fn) + assert callable(crew_fn) + assert callable(lg_fn) + assert callable(pai_fn) + assert callable(as_fn) + + +# ============================================================================= +# 11. SkillInfo / SkillDetail 数据类测试 +# ============================================================================= + + +class TestDataClasses: + """测试 SkillInfo 和 SkillDetail 数据类""" + + def test_skill_info_defaults(self) -> None: + info = SkillInfo(name="test") + assert info.name == "test" + assert info.description == "" + assert info.version == "" + assert info.path == "" + + def test_skill_info_with_all_fields(self) -> None: + info = SkillInfo( + name="full", description="desc", version="1.0", path="/path" + ) + assert info.name == "full" + assert info.description == "desc" + assert info.version == "1.0" + assert info.path == "/path" + + def test_skill_detail_defaults(self) -> None: + detail = SkillDetail(name="test") + assert detail.name == "test" + assert detail.instruction == "" + assert detail.files == [] + + def test_skill_detail_inherits_skill_info(self) -> None: + detail = SkillDetail( + name="full", + description="desc", + version="1.0", + path="/path", + instruction="# Do stuff", + files=["a.py", "b.py"], + ) + assert isinstance(detail, SkillInfo) + assert detail.instruction == "# Do stuff" + assert detail.files == ["a.py", "b.py"] + + +# ============================================================================= +# 12. 端到端集成测试 +# ============================================================================= + + +class TestEndToEnd: + """端到端测试:从创建 skill 目录到调用 load_skills 工具""" + + def test_full_workflow(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + + md_content = ( + "---\n" + "name: e2e-skill\n" + "description: End-to-end test skill\n" + "version: 1.0.0\n" + "---\n" + "\n" + "# E2E Skill\n" + "\n" + "Follow these instructions to use the skill.\n" + ) + _create_skill_dir( + skills_dir, + "e2e-skill", + skill_md_content=md_content, + extra_files={"scripts/run.sh": "#!/bin/bash\necho hello"}, + ) + + toolset = skill_tools(skills_dir=skills_dir) + assert isinstance(toolset, CommonToolSet) + tools_list = toolset.tools() + assert len(tools_list) == 1 + + tool = tools_list[0] + assert tool.name == "load_skills" + assert "e2e-skill" in tool.description + + # List all skills + list_result = json.loads(tool.func()) + assert len(list_result["skills"]) == 1 + assert list_result["skills"][0]["name"] == "e2e-skill" + + # Load specific skill + detail_result = json.loads(tool.func(name="e2e-skill")) + assert detail_result["name"] == "e2e-skill" + assert detail_result["description"] == "End-to-end test skill" + assert "Follow these instructions" in detail_result["instruction"] + assert "SKILL.md" in detail_result["files"] + assert "scripts/" in detail_result["files"] + + # Load nonexistent skill + error_result = json.loads(tool.func(name="nonexistent")) + assert "error" in error_result + assert "e2e-skill" in error_result["error"] + + def test_multiple_skills_workflow(self, tmp_path: Any) -> None: + skills_dir = str(tmp_path / "skills") + os.makedirs(skills_dir) + + _create_skill_dir( + skills_dir, + "skill-alpha", + skill_md_content=( + "---\nname: skill-alpha\ndescription: Alpha\n---\n# Alpha" + ), + ) + _create_skill_dir( + skills_dir, + "skill-beta", + package_json={"name": "skill-beta", "description": "Beta"}, + ) + + toolset = skill_tools(skills_dir=skills_dir) + tool = toolset.tools()[0] + + list_result = json.loads(tool.func()) + assert len(list_result["skills"]) == 2 + + alpha = json.loads(tool.func(name="skill-alpha")) + assert alpha["name"] == "skill-alpha" + assert "# Alpha" in alpha["instruction"] + + beta = json.loads(tool.func(name="skill-beta")) + assert beta["name"] == "skill-beta" + assert beta["instruction"] == "" diff --git a/tests/unittests/tool/__init__.py b/tests/unittests/tool/__init__.py new file mode 100644 index 0000000..745294f --- /dev/null +++ b/tests/unittests/tool/__init__.py @@ -0,0 +1,5 @@ +"""Tool 模块单元测试 / Tool Module Unit Tests + +测试 tool 模块中数据模型、API 客户端和资源类的相关功能。 +Tests data models, API clients, and resource classes in the tool module. +""" diff --git a/tests/unittests/tool/test_mcp.py b/tests/unittests/tool/test_mcp.py new file mode 100644 index 0000000..83dc3fa --- /dev/null +++ b/tests/unittests/tool/test_mcp.py @@ -0,0 +1,308 @@ +"""Tool MCP 会话单元测试 / Tool MCP Session Unit Tests + +测试 ToolMCPSession 的 MCP 协议交互功能。 +Tests MCP protocol interaction functionality of ToolMCPSession. +""" + +import sys +from unittest.mock import AsyncMock, MagicMock, Mock, patch + +import pytest + +from agentrun.tool.api.mcp import ToolMCPSession +from agentrun.tool.model import ToolInfo + + +class TestToolMCPSessionInit: + """测试 ToolMCPSession 初始化 / Test ToolMCPSession initialization""" + + def test_init_with_defaults(self): + """测试使用默认参数初始化""" + session = ToolMCPSession(endpoint="http://example.com/mcp") + assert session.endpoint == "http://example.com/mcp" + assert session.session_affinity is None + assert session.headers == {} + + def test_init_with_all_parameters(self): + """测试使用所有参数初始化""" + headers = {"Authorization": "Bearer token"} + session = ToolMCPSession( + endpoint="http://example.com/mcp", + session_affinity="MCP_STREAMABLE", + headers=headers, + ) + assert session.endpoint == "http://example.com/mcp" + assert session.session_affinity == "MCP_STREAMABLE" + assert session.headers == headers + + +class TestToolMCPSessionIsStreamable: + """测试 is_streamable 属性""" + + def test_is_streamable_returns_true_for_mcp_streamable(self): + """测试 MCP_STREAMABLE 返回 True""" + session = ToolMCPSession( + endpoint="http://example.com/mcp", + session_affinity="MCP_STREAMABLE", + ) + assert session.is_streamable is True + + def test_is_streamable_returns_false_for_other_values(self): + """测试其他值返回 False""" + for value in [None, "MCP_SSE", "OTHER", ""]: + session = ToolMCPSession( + endpoint="http://example.com/mcp", + session_affinity=value, + ) + assert session.is_streamable is False + + +def _make_mock_mcp_tool(name: str, description: str) -> MagicMock: + """创建 mock MCP tool 对象""" + tool = MagicMock() + tool.name = name + tool.description = description + tool.inputSchema = {"type": "object", "properties": {}} + return tool + + +def _setup_mock_mcp_modules( + mock_session: AsyncMock, +) -> dict: + """设置 mock mcp 模块,返回需要注入到 sys.modules 的字典""" + mock_client_session_cls = MagicMock() + mock_session_ctx = AsyncMock() + mock_session_ctx.__aenter__.return_value = mock_session + mock_session_ctx.__aexit__.return_value = None + mock_client_session_cls.return_value = mock_session_ctx + + # mock streamablehttp_client + mock_streamable_fn = MagicMock() + mock_streamable_ctx = AsyncMock() + mock_streamable_ctx.__aenter__.return_value = ( + AsyncMock(), + AsyncMock(), + MagicMock(), + ) + mock_streamable_ctx.__aexit__.return_value = None + mock_streamable_fn.return_value = mock_streamable_ctx + + # mock sse_client + mock_sse_fn = MagicMock() + mock_sse_ctx = AsyncMock() + mock_sse_ctx.__aenter__.return_value = (AsyncMock(), AsyncMock()) + mock_sse_ctx.__aexit__.return_value = None + mock_sse_fn.return_value = mock_sse_ctx + + mock_mcp = MagicMock() + mock_mcp.ClientSession = mock_client_session_cls + + mock_mcp_client_streamable = MagicMock() + mock_mcp_client_streamable.streamablehttp_client = mock_streamable_fn + + mock_mcp_client_sse = MagicMock() + mock_mcp_client_sse.sse_client = mock_sse_fn + + return { + "mcp": mock_mcp, + "mcp.client": MagicMock(), + "mcp.client.streamable_http": mock_mcp_client_streamable, + "mcp.client.sse": mock_mcp_client_sse, + } + + +class TestToolMCPSessionListToolsAsync: + """测试 list_tools_async 方法""" + + @pytest.mark.asyncio + async def test_list_tools_async_streamable_mode(self): + """测试 Streamable 模式下获取工具列表""" + mock_tool = _make_mock_mcp_tool("tool1", "Test tool 1") + mock_result = MagicMock() + mock_result.tools = [mock_tool] + + mock_session = AsyncMock() + mock_session.initialize = AsyncMock() + mock_session.list_tools = AsyncMock(return_value=mock_result) + + mock_modules = _setup_mock_mcp_modules(mock_session) + + with patch.dict(sys.modules, mock_modules): + session = ToolMCPSession( + endpoint="http://example.com/mcp", + session_affinity="MCP_STREAMABLE", + headers={"Authorization": "Bearer token"}, + ) + tools = await session.list_tools_async() + + assert len(tools) == 1 + assert isinstance(tools[0], ToolInfo) + assert tools[0].name == "tool1" + + @pytest.mark.asyncio + async def test_list_tools_async_sse_mode(self): + """测试 SSE 模式下获取工具列表""" + mock_tool = _make_mock_mcp_tool("tool1", "Test tool 1") + mock_result = MagicMock() + mock_result.tools = [mock_tool] + + mock_session = AsyncMock() + mock_session.initialize = AsyncMock() + mock_session.list_tools = AsyncMock(return_value=mock_result) + + mock_modules = _setup_mock_mcp_modules(mock_session) + + with patch.dict(sys.modules, mock_modules): + session = ToolMCPSession( + endpoint="http://example.com/mcp", + session_affinity="MCP_SSE", + ) + tools = await session.list_tools_async() + + assert len(tools) == 1 + assert isinstance(tools[0], ToolInfo) + + @pytest.mark.asyncio + async def test_list_tools_async_import_error(self): + """测试 mcp 未安装时返回空列表""" + saved_modules = {} + modules_to_remove = [ + k for k in sys.modules if k == "mcp" or k.startswith("mcp.") + ] + for key in modules_to_remove: + saved_modules[key] = sys.modules.pop(key) + + original_import = __builtins__.__import__ if hasattr(__builtins__, "__import__") else __import__ # type: ignore + + def mock_import(name, *args, **kwargs): + if name == "mcp" or name.startswith("mcp."): + raise ImportError(f"No module named '{name}'") + return original_import(name, *args, **kwargs) + + with patch("builtins.__import__", side_effect=mock_import): + session = ToolMCPSession(endpoint="http://example.com/mcp") + tools = await session.list_tools_async() + + sys.modules.update(saved_modules) + assert tools == [] + + +class TestToolMCPSessionListTools: + """测试 list_tools 同步方法""" + + def test_list_tools_synchronous(self): + """测试同步获取工具列表""" + expected_tools = [ToolInfo(name="tool1", description="Test")] + + with patch.object( + ToolMCPSession, + "list_tools_async", + new_callable=AsyncMock, + return_value=expected_tools, + ): + with patch("asyncio.get_event_loop") as mock_get_loop: + mock_loop = MagicMock() + mock_loop.run_until_complete.return_value = expected_tools + mock_get_loop.return_value = mock_loop + + session = ToolMCPSession(endpoint="http://example.com/mcp") + tools = session.list_tools() + + assert tools == expected_tools + mock_loop.run_until_complete.assert_called_once() + + +class TestToolMCPSessionCallToolAsync: + """测试 call_tool_async 方法""" + + @pytest.mark.asyncio + async def test_call_tool_async_streamable_mode(self): + """测试 Streamable 模式下调用工具""" + mock_call_result = MagicMock() + mock_call_result.content = [{"type": "text", "text": "result"}] + + mock_session = AsyncMock() + mock_session.initialize = AsyncMock() + mock_session.call_tool = AsyncMock(return_value=mock_call_result) + + mock_modules = _setup_mock_mcp_modules(mock_session) + + with patch.dict(sys.modules, mock_modules): + session = ToolMCPSession( + endpoint="http://example.com/mcp", + session_affinity="MCP_STREAMABLE", + ) + result = await session.call_tool_async( + "test_tool", {"param": "value"} + ) + + assert result == mock_call_result + + @pytest.mark.asyncio + async def test_call_tool_async_sse_mode(self): + """测试 SSE 模式下调用工具""" + mock_call_result = MagicMock() + + mock_session = AsyncMock() + mock_session.initialize = AsyncMock() + mock_session.call_tool = AsyncMock(return_value=mock_call_result) + + mock_modules = _setup_mock_mcp_modules(mock_session) + + with patch.dict(sys.modules, mock_modules): + session = ToolMCPSession( + endpoint="http://example.com/mcp", + session_affinity="MCP_SSE", + ) + result = await session.call_tool_async("test_tool", {"key": "val"}) + + assert result == mock_call_result + + @pytest.mark.asyncio + async def test_call_tool_async_import_error(self): + """测试 mcp 未安装时抛出 ImportError""" + saved_modules = {} + modules_to_remove = [ + k for k in sys.modules if k == "mcp" or k.startswith("mcp.") + ] + for key in modules_to_remove: + saved_modules[key] = sys.modules.pop(key) + + original_import = __builtins__.__import__ if hasattr(__builtins__, "__import__") else __import__ # type: ignore + + def mock_import(name, *args, **kwargs): + if name == "mcp" or name.startswith("mcp."): + raise ImportError(f"No module named '{name}'") + return original_import(name, *args, **kwargs) + + with patch("builtins.__import__", side_effect=mock_import): + session = ToolMCPSession(endpoint="http://example.com/mcp") + with pytest.raises(ImportError): + await session.call_tool_async("test_tool") + + sys.modules.update(saved_modules) + + +class TestToolMCPSessionCallTool: + """测试 call_tool 同步方法""" + + def test_call_tool_synchronous(self): + """测试同步调用工具""" + expected_result = {"result": "success"} + + with patch.object( + ToolMCPSession, + "call_tool_async", + new_callable=AsyncMock, + return_value=expected_result, + ): + with patch("asyncio.get_event_loop") as mock_get_loop: + mock_loop = MagicMock() + mock_loop.run_until_complete.return_value = expected_result + mock_get_loop.return_value = mock_loop + + session = ToolMCPSession(endpoint="http://example.com/mcp") + result = session.call_tool("test_tool", {"param": "value"}) + + assert result == expected_result + mock_loop.run_until_complete.assert_called_once() diff --git a/tests/unittests/tool/test_model.py b/tests/unittests/tool/test_model.py new file mode 100644 index 0000000..b195d0e --- /dev/null +++ b/tests/unittests/tool/test_model.py @@ -0,0 +1,661 @@ +"""Tool 模型单元测试 / Tool Model Unit Tests + +测试 tool 模块中数据模型和工具 schema 的相关功能。 +Tests data models and tool schema functionality in the tool module. +""" + +import pytest + +from agentrun.tool.model import ( + McpConfig, + ToolCodeConfiguration, + ToolContainerConfiguration, + ToolInfo, + ToolLogConfiguration, + ToolNASConfig, + ToolNetworkConfiguration, + ToolOSSMountConfig, + ToolSchema, + ToolType, +) + + +class TestToolType: + """测试 ToolType 枚举""" + + def test_mcp_type(self): + """测试 MCP 类型""" + assert ToolType.MCP == "MCP" + assert ToolType.MCP.value == "MCP" + + def test_functioncall_type(self): + """测试 FUNCTIONCALL 类型""" + assert ToolType.FUNCTIONCALL == "FUNCTIONCALL" + assert ToolType.FUNCTIONCALL.value == "FUNCTIONCALL" + + +class TestMcpConfig: + """测试 McpConfig 模型""" + + def test_default_values(self): + """测试默认值""" + config = McpConfig() + assert config.session_affinity is None + assert config.session_affinity_config is None + assert config.proxy_enabled is None + assert config.bound_configuration is None + assert config.mcp_proxy_configuration is None + + def test_with_values(self): + """测试带值创建""" + config = McpConfig( + session_affinity="MCP_SSE", + session_affinity_config='{"key": "value"}', + proxy_enabled=True, + bound_configuration={"key": "value"}, + mcp_proxy_configuration={"proxy": "config"}, + ) + assert config.session_affinity == "MCP_SSE" + assert config.session_affinity_config == '{"key": "value"}' + assert config.proxy_enabled is True + assert config.bound_configuration == {"key": "value"} + assert config.mcp_proxy_configuration == {"proxy": "config"} + + +class TestToolCodeConfiguration: + """测试 ToolCodeConfiguration 模型""" + + def test_default_values(self): + """测试默认值""" + config = ToolCodeConfiguration() + assert config.code_checksum is None + assert config.code_size is None + assert config.command is None + assert config.language is None + assert config.oss_bucket_name is None + assert config.oss_object_name is None + + def test_with_values(self): + """测试带值创建""" + config = ToolCodeConfiguration( + code_checksum="abc123", + code_size=1024, + command=["python", "app.py"], + language="python3.10", + oss_bucket_name="my-bucket", + oss_object_name="code.zip", + ) + assert config.code_checksum == "abc123" + assert config.code_size == 1024 + assert config.command == ["python", "app.py"] + assert config.language == "python3.10" + assert config.oss_bucket_name == "my-bucket" + assert config.oss_object_name == "code.zip" + + +class TestToolContainerConfiguration: + """测试 ToolContainerConfiguration 模型""" + + def test_default_values(self): + """测试默认值""" + config = ToolContainerConfiguration() + assert config.args is None + assert config.command is None + assert config.image is None + assert config.port is None + + def test_with_values(self): + """测试带值创建""" + config = ToolContainerConfiguration( + args=["--arg1", "value1"], + command=["python", "app.py"], + image="registry.example.com/tool:latest", + port=8080, + ) + assert config.args == ["--arg1", "value1"] + assert config.command == ["python", "app.py"] + assert config.image == "registry.example.com/tool:latest" + assert config.port == 8080 + + +class TestToolLogConfiguration: + """测试 ToolLogConfiguration 模型""" + + def test_default_values(self): + """测试默认值""" + config = ToolLogConfiguration() + assert config.log_store is None + assert config.project is None + + def test_with_values(self): + """测试带值创建""" + config = ToolLogConfiguration( + log_store="my-log-store", + project="my-project", + ) + assert config.log_store == "my-log-store" + assert config.project == "my-project" + + +class TestToolNASConfig: + """测试 ToolNASConfig 模型""" + + def test_default_values(self): + """测试默认值""" + config = ToolNASConfig() + assert config.group_id is None + assert config.mount_points is None + assert config.user_id is None + + def test_with_values(self): + """测试带值创建""" + config = ToolNASConfig( + group_id=1001, + mount_points=[{"path": "/mnt/nas", "nas_id": "nas-123"}], + user_id=1000, + ) + assert config.group_id == 1001 + assert config.mount_points == [ + {"path": "/mnt/nas", "nas_id": "nas-123"} + ] + assert config.user_id == 1000 + + +class TestToolNetworkConfiguration: + """测试 ToolNetworkConfiguration 模型""" + + def test_default_values(self): + """测试默认值""" + config = ToolNetworkConfiguration() + assert config.security_group_id is None + assert config.vpc_id is None + assert config.vswitch_ids is None + + def test_with_values(self): + """测试带值创建""" + config = ToolNetworkConfiguration( + security_group_id="sg-123", + vpc_id="vpc-456", + vswitch_ids=["vsw-789", "vsw-012"], + ) + assert config.security_group_id == "sg-123" + assert config.vpc_id == "vpc-456" + assert config.vswitch_ids == ["vsw-789", "vsw-012"] + + +class TestToolOSSMountConfig: + """测试 ToolOSSMountConfig 模型""" + + def test_default_values(self): + """测试默认值""" + config = ToolOSSMountConfig() + assert config.mount_points is None + + def test_with_values(self): + """测试带值创建""" + config = ToolOSSMountConfig( + mount_points=[{ + "bucket": "my-bucket", + "endpoint": "oss-cn-hangzhou.aliyuncs.com", + }] + ) + assert config.mount_points == [ + {"bucket": "my-bucket", "endpoint": "oss-cn-hangzhou.aliyuncs.com"} + ] + + +class TestToolSchema: + """测试 ToolSchema 模型""" + + def test_default_values(self): + """测试默认值""" + schema = ToolSchema() + assert schema.type is None + assert schema.description is None + assert schema.properties is None + assert schema.required is None + assert schema.items is None + assert schema.any_of is None + assert schema.one_of is None + assert schema.all_of is None + + def test_from_any_openapi_schema_simple(self): + """测试从简单 OpenAPI Schema 创建""" + openapi_schema = { + "type": "object", + "description": "A simple object", + "properties": { + "name": {"type": "string", "description": "Name field"}, + "age": {"type": "integer", "description": "Age field"}, + }, + "required": ["name"], + } + schema = ToolSchema.from_any_openapi_schema(openapi_schema) + assert schema.type == "object" + assert schema.description == "A simple object" + assert schema.properties is not None + assert "name" in schema.properties + assert "age" in schema.properties + assert schema.required == ["name"] + + def test_from_any_openapi_schema_nested(self): + """测试从嵌套 OpenAPI Schema 创建""" + openapi_schema = { + "type": "object", + "properties": { + "user": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "email": {"type": "string"}, + }, + } + }, + } + schema = ToolSchema.from_any_openapi_schema(openapi_schema) + assert schema.type == "object" + assert schema.properties is not None + assert "user" in schema.properties + assert schema.properties["user"].type == "object" + assert schema.properties["user"].properties is not None + assert "name" in schema.properties["user"].properties + + def test_from_any_openapi_schema_empty(self): + """测试从空 Schema 创建""" + schema = ToolSchema.from_any_openapi_schema(None) + assert schema.type == "string" + + # Empty dict creates a ToolSchema with None type (pydash_get returns None for missing keys) + schema = ToolSchema.from_any_openapi_schema({}) + # Actually returns "string" due to the check at the beginning of the method + assert schema.type == "string" + + def test_from_any_openapi_schema_non_dict(self): + """测试从非 dict 输入创建""" + schema = ToolSchema.from_any_openapi_schema("invalid") + assert schema.type == "string" + + schema = ToolSchema.from_any_openapi_schema(123) + assert schema.type == "string" + + def test_from_any_openapi_schema_array(self): + """测试从数组 Schema 创建""" + openapi_schema = { + "type": "array", + "items": {"type": "string"}, + "minItems": 1, + "maxItems": 10, + } + schema = ToolSchema.from_any_openapi_schema(openapi_schema) + assert schema.type == "array" + assert schema.items is not None + assert schema.items.type == "string" + assert schema.min_items == 1 + assert schema.max_items == 10 + + def test_from_any_openapi_schema_anyof(self): + """测试 anyOf 支持""" + openapi_schema = { + "anyOf": [ + {"type": "string"}, + {"type": "integer"}, + ] + } + schema = ToolSchema.from_any_openapi_schema(openapi_schema) + assert schema.any_of is not None + assert len(schema.any_of) == 2 + assert schema.any_of[0].type == "string" + assert schema.any_of[1].type == "integer" + + def test_from_any_openapi_schema_oneof(self): + """测试 oneOf 支持""" + openapi_schema = { + "oneOf": [ + {"type": "string"}, + {"type": "boolean"}, + ] + } + schema = ToolSchema.from_any_openapi_schema(openapi_schema) + assert schema.one_of is not None + assert len(schema.one_of) == 2 + assert schema.one_of[0].type == "string" + assert schema.one_of[1].type == "boolean" + + def test_from_any_openapi_schema_allof(self): + """测试 allOf 支持""" + openapi_schema = { + "allOf": [ + {"type": "object", "properties": {"name": {"type": "string"}}}, + {"type": "object", "properties": {"age": {"type": "integer"}}}, + ] + } + schema = ToolSchema.from_any_openapi_schema(openapi_schema) + assert schema.all_of is not None + assert len(schema.all_of) == 2 + assert schema.all_of[0].type == "object" + assert schema.all_of[1].type == "object" + + def test_to_json_schema_simple(self): + """测试转换为 JSON Schema - 简单情况""" + schema = ToolSchema( + type="string", + description="A string field", + min_length=1, + max_length=100, + ) + json_schema = schema.to_json_schema() + assert json_schema["type"] == "string" + assert json_schema["description"] == "A string field" + assert json_schema["minLength"] == 1 + assert json_schema["maxLength"] == 100 + + def test_to_json_schema_nested(self): + """测试转换为 JSON Schema - 嵌套情况""" + schema = ToolSchema( + type="object", + properties={ + "user": ToolSchema( + type="object", + properties={ + "name": ToolSchema(type="string"), + }, + ) + }, + ) + json_schema = schema.to_json_schema() + assert json_schema["type"] == "object" + assert "properties" in json_schema + assert "user" in json_schema["properties"] + assert json_schema["properties"]["user"]["type"] == "object" + assert "properties" in json_schema["properties"]["user"] + + def test_to_json_schema_roundtrip(self): + """测试完整往返转换""" + openapi_schema = { + "type": "object", + "description": "Test schema", + "properties": { + "name": { + "type": "string", + "description": "Name", + "minLength": 1, + }, + "age": { + "type": "integer", + "minimum": 0, + "maximum": 150, + }, + }, + "required": ["name"], + } + schema = ToolSchema.from_any_openapi_schema(openapi_schema) + json_schema = schema.to_json_schema() + assert json_schema["type"] == "object" + assert json_schema["description"] == "Test schema" + assert "properties" in json_schema + assert "name" in json_schema["properties"] + assert "age" in json_schema["properties"] + assert json_schema["required"] == ["name"] + + def test_to_json_schema_with_anyof(self): + """测试转换包含 anyOf 的 schema""" + schema = ToolSchema( + any_of=[ + ToolSchema(type="string"), + ToolSchema(type="integer"), + ] + ) + json_schema = schema.to_json_schema() + assert "anyOf" in json_schema + assert len(json_schema["anyOf"]) == 2 + assert json_schema["anyOf"][0]["type"] == "string" + assert json_schema["anyOf"][1]["type"] == "integer" + + def test_recursive_properties(self): + """测试递归嵌套 properties""" + schema = ToolSchema( + type="object", + properties={ + "level1": ToolSchema( + type="object", + properties={ + "level2": ToolSchema( + type="object", + properties={ + "level3": ToolSchema(type="string"), + }, + ), + }, + ), + }, + ) + json_schema = schema.to_json_schema() + assert json_schema["type"] == "object" + assert json_schema["properties"]["level1"]["type"] == "object" + assert ( + json_schema["properties"]["level1"]["properties"]["level2"]["type"] + == "object" + ) + assert ( + json_schema["properties"]["level1"]["properties"]["level2"][ + "properties" + ]["level3"]["type"] + == "string" + ) + + def test_to_json_schema_with_string_constraints(self): + """测试 pattern, min_length, max_length, format""" + schema = ToolSchema( + type="string", + pattern="^[a-zA-Z]+$", + min_length=1, + max_length=100, + format="email", + ) + json_schema = schema.to_json_schema() + assert json_schema["type"] == "string" + assert json_schema["pattern"] == "^[a-zA-Z]+$" + assert json_schema["minLength"] == 1 + assert json_schema["maxLength"] == 100 + assert json_schema["format"] == "email" + + def test_to_json_schema_with_number_constraints(self): + """测试 minimum, maximum, exclusive_minimum, exclusive_maximum""" + schema = ToolSchema( + type="number", + minimum=0, + maximum=100, + exclusive_minimum=0, + exclusive_maximum=100, + ) + json_schema = schema.to_json_schema() + assert json_schema["type"] == "number" + assert json_schema["minimum"] == 0 + assert json_schema["maximum"] == 100 + assert json_schema["exclusiveMinimum"] == 0 + assert json_schema["exclusiveMaximum"] == 100 + + def test_to_json_schema_with_enum(self): + """测试 enum 字段""" + schema = ToolSchema( + type="string", + enum=["red", "green", "blue"], + ) + json_schema = schema.to_json_schema() + assert json_schema["type"] == "string" + assert json_schema["enum"] == ["red", "green", "blue"] + + def test_to_json_schema_with_additional_properties(self): + """测试 additionalProperties""" + schema = ToolSchema( + type="object", + additional_properties=True, + ) + json_schema = schema.to_json_schema() + assert json_schema["type"] == "object" + assert json_schema["additionalProperties"] is True + + def test_to_json_schema_with_default(self): + """测试 default 字段""" + schema = ToolSchema( + type="string", + default="default_value", + ) + json_schema = schema.to_json_schema() + assert json_schema["type"] == "string" + assert json_schema["default"] == "default_value" + + def test_to_json_schema_with_title(self): + """测试 title 字段""" + schema = ToolSchema( + type="string", + title="String Field", + ) + json_schema = schema.to_json_schema() + assert json_schema["type"] == "string" + assert json_schema["title"] == "String Field" + + def test_to_json_schema_with_one_of(self): + """测试 oneOf 序列化""" + schema = ToolSchema( + one_of=[ + ToolSchema(type="string"), + ToolSchema(type="integer"), + ], + ) + json_schema = schema.to_json_schema() + assert "oneOf" in json_schema + assert len(json_schema["oneOf"]) == 2 + assert json_schema["oneOf"][0]["type"] == "string" + assert json_schema["oneOf"][1]["type"] == "integer" + + def test_to_json_schema_with_all_of(self): + """测试 allOf 序列化""" + schema = ToolSchema( + all_of=[ + ToolSchema( + type="object", + properties={"name": ToolSchema(type="string")}, + ), + ToolSchema( + type="object", + properties={"age": ToolSchema(type="integer")}, + ), + ], + ) + json_schema = schema.to_json_schema() + assert "allOf" in json_schema + assert len(json_schema["allOf"]) == 2 + assert json_schema["allOf"][0]["type"] == "object" + assert json_schema["allOf"][1]["type"] == "object" + + +class TestToolInfo: + """测试 ToolInfo 模型""" + + def test_default_values(self): + """测试默认值""" + info = ToolInfo() + assert info.name is None + assert info.description is None + assert info.parameters is None + + def test_with_values(self): + """测试带值创建""" + info = ToolInfo( + name="test_tool", + description="A test tool", + parameters=ToolSchema( + type="object", + properties={ + "input": ToolSchema(type="string"), + }, + ), + ) + assert info.name == "test_tool" + assert info.description == "A test tool" + assert info.parameters is not None + assert info.parameters.type == "object" + + def test_from_mcp_tool_with_object(self): + """测试从 MCP 工具对象创建""" + mcp_tool = { + "name": "mcp_tool", + "description": "An MCP tool", + "inputSchema": { + "type": "object", + "properties": { + "param": {"type": "string"}, + }, + }, + } + info = ToolInfo.from_mcp_tool(mcp_tool) + assert info.name == "mcp_tool" + assert info.description == "An MCP tool" + assert info.parameters is not None + assert info.parameters.type == "object" + + def test_from_mcp_tool_with_dict(self): + """测试从 MCP 工具字典创建""" + mcp_tool = { + "name": "dict_tool", + "description": "Dict tool", + "inputSchema": { + "type": "string", + }, + } + info = ToolInfo.from_mcp_tool(mcp_tool) + assert info.name == "dict_tool" + assert info.description == "Dict tool" + assert info.parameters is not None + assert info.parameters.type == "string" + + def test_from_mcp_tool_without_name(self): + """测试从没有 name 的 MCP 工具创建""" + mcp_tool = { + "description": "Tool without name", + "inputSchema": {"type": "string"}, + } + with pytest.raises(ValueError, match="name"): + ToolInfo.from_mcp_tool(mcp_tool) + + def test_from_mcp_tool_with_empty_schema(self): + """测试从空 schema 的 MCP 工具创建""" + mcp_tool = { + "name": "empty_schema_tool", + "description": "Tool with empty schema", + } + info = ToolInfo.from_mcp_tool(mcp_tool) + assert info.name == "empty_schema_tool" + assert info.description == "Tool with empty schema" + assert info.parameters is not None + assert info.parameters.type == "object" + + def test_from_mcp_tool_with_model_dump(self): + """测试 from_mcp_tool 当 input_schema 有 model_dump 方法时""" + + class MockInputSchema: + + def model_dump(self): + return { + "type": "object", + "properties": { + "param1": {"type": "string"}, + "param2": {"type": "integer"}, + }, + "required": ["param1"], + } + + mcp_tool = { + "name": "tool_with_model_dump", + "description": "Tool with model_dump input schema", + "inputSchema": MockInputSchema(), + } + info = ToolInfo.from_mcp_tool(mcp_tool) + assert info.name == "tool_with_model_dump" + assert info.description == "Tool with model_dump input schema" + assert info.parameters is not None + assert info.parameters.type == "object" + assert "param1" in info.parameters.properties + assert "param2" in info.parameters.properties + assert info.parameters.required == ["param1"] diff --git a/tests/unittests/tool/test_openapi.py b/tests/unittests/tool/test_openapi.py new file mode 100644 index 0000000..995c40d --- /dev/null +++ b/tests/unittests/tool/test_openapi.py @@ -0,0 +1,694 @@ +"""Tool OpenAPI 客户端单元测试 / Tool OpenAPI Client Unit Tests + +测试 ToolOpenAPIClient 的 OpenAPI Schema 解析和 HTTP 调用功能。 +Tests OpenAPI Schema parsing and HTTP call functionality of ToolOpenAPIClient. +""" + +import json +from unittest.mock import Mock, patch + +import httpx +import pytest + +from agentrun.tool.api.openapi import ToolOpenAPIClient +from agentrun.tool.model import ToolInfo, ToolSchema + + +class TestToolOpenAPIClient: + """测试 ToolOpenAPIClient""" + + @pytest.fixture + def sample_openapi_spec(self): + """示例 OpenAPI Spec""" + return json.dumps({ + "openapi": "3.0.0", + "info": { + "title": "Test API", + "version": "1.0.0", + }, + "servers": [ + {"url": "https://api.example.com/v1"}, + ], + "paths": { + "/users": { + "get": { + "operationId": "listUsers", + "summary": "List all users", + "description": "Get a list of users", + "parameters": [ + { + "name": "limit", + "in": "query", + "required": False, + "schema": {"type": "integer"}, + }, + ], + }, + "post": { + "operationId": "createUser", + "summary": "Create a user", + "description": "Create a new user", + "requestBody": { + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "email": {"type": "string"}, + }, + "required": ["name"], + }, + } + } + }, + }, + }, + "/users/{id}": { + "get": { + "operationId": "getUser", + "summary": "Get user by ID", + }, + "put": { + "operationId": "updateUser", + "summary": "Update user", + "requestBody": { + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "name": {"type": "string"}, + }, + }, + } + } + }, + }, + "delete": { + "operationId": "deleteUser", + "summary": "Delete user", + }, + }, + }, + }) + + @pytest.fixture + def sample_openapi_spec_no_servers(self): + """没有 servers 的 OpenAPI Spec""" + return json.dumps({ + "openapi": "3.0.0", + "info": { + "title": "Test API", + "version": "1.0.0", + }, + "paths": {}, + }) + + def test_init_with_valid_json(self, sample_openapi_spec): + """测试使用有效 JSON 初始化""" + client = ToolOpenAPIClient(protocol_spec=sample_openapi_spec) + assert client._spec is not None + assert client._spec["openapi"] == "3.0.0" + + def test_init_with_invalid_json(self): + """测试使用无效 JSON 初始化""" + client = ToolOpenAPIClient(protocol_spec="invalid json") + assert client._spec is None + + def test_init_with_none(self): + """测试使用 None 初始化""" + client = ToolOpenAPIClient(protocol_spec=None) + assert client._spec is None + + def test_init_with_headers(self, sample_openapi_spec): + """测试带 headers 初始化""" + headers = {"Authorization": "Bearer token"} + client = ToolOpenAPIClient( + protocol_spec=sample_openapi_spec, + headers=headers, + ) + assert client.headers == headers + + def test_server_url(self, sample_openapi_spec): + """测试获取 server URL""" + client = ToolOpenAPIClient(protocol_spec=sample_openapi_spec) + assert client.server_url == "https://api.example.com/v1" + + def test_server_url_no_servers(self, sample_openapi_spec_no_servers): + """测试没有 servers 时的 server URL""" + client = ToolOpenAPIClient(protocol_spec=sample_openapi_spec_no_servers) + assert client.server_url is None + + def test_server_url_no_spec(self): + """测试没有 spec 时的 server URL""" + client = ToolOpenAPIClient(protocol_spec=None) + assert client.server_url is None + + def test_parse_operations_get_method(self, sample_openapi_spec): + """测试解析 GET 方法""" + client = ToolOpenAPIClient(protocol_spec=sample_openapi_spec) + operations = client._parse_operations() + + get_operation = next( + (op for op in operations if op["operation_id"] == "listUsers"), + None, + ) + assert get_operation is not None + assert get_operation["method"] == "GET" + assert get_operation["path"] == "/users" + assert get_operation["summary"] == "List all users" + assert get_operation["input_schema"] is not None + assert "properties" in get_operation["input_schema"] + + def test_parse_operations_post_method_with_request_body( + self, sample_openapi_spec + ): + """测试解析 POST 方法(带 requestBody)""" + client = ToolOpenAPIClient(protocol_spec=sample_openapi_spec) + operations = client._parse_operations() + + post_operation = next( + (op for op in operations if op["operation_id"] == "createUser"), + None, + ) + assert post_operation is not None + assert post_operation["method"] == "POST" + assert post_operation["path"] == "/users" + assert post_operation["input_schema"] is not None + assert post_operation["input_schema"]["type"] == "object" + + def test_parse_operations_multiple_methods(self, sample_openapi_spec): + """测试解析多个 HTTP 方法""" + client = ToolOpenAPIClient(protocol_spec=sample_openapi_spec) + operations = client._parse_operations() + + operation_ids = [op["operation_id"] for op in operations] + assert "listUsers" in operation_ids + assert "createUser" in operation_ids + assert "getUser" in operation_ids + assert "updateUser" in operation_ids + assert "deleteUser" in operation_ids + + def test_parse_operations_parameters(self, sample_openapi_spec): + """测试解析 parameters""" + client = ToolOpenAPIClient(protocol_spec=sample_openapi_spec) + operations = client._parse_operations() + + list_users_op = next( + (op for op in operations if op["operation_id"] == "listUsers"), + None, + ) + assert list_users_op is not None + assert list_users_op["input_schema"] is not None + assert "properties" in list_users_op["input_schema"] + assert "limit" in list_users_op["input_schema"]["properties"] + + def test_parse_operations_no_spec(self): + """测试没有 spec 时的解析""" + client = ToolOpenAPIClient(protocol_spec=None) + operations = client._parse_operations() + assert operations == [] + + def test_list_tools(self, sample_openapi_spec): + """测试获取工具列表""" + client = ToolOpenAPIClient(protocol_spec=sample_openapi_spec) + tools = client.list_tools() + + assert len(tools) > 0 + assert all(isinstance(tool, ToolInfo) for tool in tools) + + # 检查特定工具 + list_users_tool = next( + (t for t in tools if t.name == "listUsers"), + None, + ) + assert list_users_tool is not None + assert list_users_tool.description == "List all users" + assert list_users_tool.parameters is not None + + def test_list_tools_empty_spec(self): + """测试空 spec 时的工具列表""" + client = ToolOpenAPIClient(protocol_spec='{"paths": {}}') + tools = client.list_tools() + assert tools == [] + + @patch("agentrun.tool.api.openapi.httpx.Client") + def test_call_tool_post_method( + self, mock_client_class, sample_openapi_spec + ): + """测试调用 POST 方法""" + # Mock httpx response + mock_response = Mock() + mock_response.json.return_value = {"id": 123, "name": "Test User"} + mock_response.headers = {"content-type": "application/json"} + mock_response.raise_for_status = Mock() + + mock_client_instance = Mock() + mock_client_instance.request.return_value = mock_response + mock_client_instance.__enter__ = Mock(return_value=mock_client_instance) + mock_client_instance.__exit__ = Mock(return_value=False) + mock_client_class.return_value = mock_client_instance + + client = ToolOpenAPIClient(protocol_spec=sample_openapi_spec) + result = client.call_tool( + "createUser", {"name": "Test User", "email": "test@example.com"} + ) + + assert result == {"id": 123, "name": "Test User"} + mock_client_instance.request.assert_called_once() + call_args = mock_client_instance.request.call_args + assert call_args[0][0] == "POST" + assert "https://api.example.com/v1/users" in call_args[0][1] + + @patch("agentrun.tool.api.openapi.httpx.Client") + def test_call_tool_get_method(self, mock_client_class, sample_openapi_spec): + """测试调用 GET 方法""" + # Mock httpx response + mock_response = Mock() + mock_response.json.return_value = {"id": 123, "name": "Test User"} + mock_response.headers = {"content-type": "application/json"} + mock_response.raise_for_status = Mock() + + mock_client_instance = Mock() + mock_client_instance.request.return_value = mock_response + mock_client_instance.__enter__ = Mock(return_value=mock_client_instance) + mock_client_instance.__exit__ = Mock(return_value=False) + mock_client_class.return_value = mock_client_instance + + client = ToolOpenAPIClient(protocol_spec=sample_openapi_spec) + result = client.call_tool("listUsers", {"limit": 10}) + + assert result == {"id": 123, "name": "Test User"} + mock_client_instance.request.assert_called_once() + call_args = mock_client_instance.request.call_args + assert call_args[0][0] == "GET" + assert "limit" in call_args[1]["params"] + + @patch("agentrun.tool.api.openapi.httpx.Client") + def test_call_tool_text_response( + self, mock_client_class, sample_openapi_spec + ): + """测试调用工具返回文本响应""" + # Mock httpx response + mock_response = Mock() + mock_response.text = "Plain text response" + mock_response.headers = {"content-type": "text/plain"} + mock_response.raise_for_status = Mock() + + mock_client_instance = Mock() + mock_client_instance.request.return_value = mock_response + mock_client_instance.__enter__ = Mock(return_value=mock_client_instance) + mock_client_instance.__exit__ = Mock(return_value=False) + mock_client_class.return_value = mock_client_instance + + client = ToolOpenAPIClient(protocol_spec=sample_openapi_spec) + result = client.call_tool("listUsers", {}) + + assert result == "Plain text response" + + def test_call_tool_operation_not_found(self, sample_openapi_spec): + """测试调用不存在的 operation""" + client = ToolOpenAPIClient(protocol_spec=sample_openapi_spec) + with pytest.raises( + ValueError, match="Operation 'nonExistent' not found" + ): + client.call_tool("nonExistent", {}) + + def test_call_tool_no_server_url(self): + """测试没有 server URL 时调用工具""" + spec_without_server = json.dumps({ + "openapi": "3.0.0", + "paths": { + "/test": { + "get": { + "operationId": "testOp", + }, + }, + }, + }) + client = ToolOpenAPIClient(protocol_spec=spec_without_server) + with pytest.raises(ValueError, match="No server URL found"): + client.call_tool("testOp", {}) + + @patch("httpx.AsyncClient") + async def test_call_tool_async_post_method( + self, mock_async_client_class, sample_openapi_spec + ): + """测试异步调用 POST 方法""" + # Mock async httpx response + mock_response = Mock() + mock_response.json = Mock(return_value={"id": 123}) + mock_response.headers = {"content-type": "application/json"} + mock_response.raise_for_status = Mock() + + # Create a proper async context manager mock + mock_client_instance = AsyncMock() + mock_client_instance.request = AsyncMock(return_value=mock_response) + mock_client_instance.__aenter__ = AsyncMock( + return_value=mock_client_instance + ) + mock_client_instance.__aexit__ = AsyncMock() + mock_async_client_class.return_value = mock_client_instance + + client = ToolOpenAPIClient(protocol_spec=sample_openapi_spec) + result = await client.call_tool_async("createUser", {"name": "Test"}) + + assert result == {"id": 123} + + @patch("agentrun.tool.api.openapi.httpx.AsyncClient") + async def test_call_tool_async_operation_not_found( + self, mock_async_client_class, sample_openapi_spec + ): + """测试异步调用不存在的 operation""" + mock_client_instance = AsyncMock() + mock_async_client_class.return_value = mock_client_instance + + client = ToolOpenAPIClient(protocol_spec=sample_openapi_spec) + with pytest.raises( + ValueError, match="Operation 'nonExistent' not found" + ): + await client.call_tool_async("nonExistent", {}) + + @patch("agentrun.tool.api.openapi.httpx.AsyncClient") + async def test_call_tool_async_no_server_url(self, mock_async_client_class): + """测试异步调用没有 server URL""" + spec_without_server = json.dumps({ + "openapi": "3.0.0", + "paths": { + "/test": { + "get": { + "operationId": "testOp", + }, + }, + }, + }) + mock_client_instance = AsyncMock() + mock_async_client_class.return_value = mock_client_instance + + client = ToolOpenAPIClient(protocol_spec=spec_without_server) + with pytest.raises(ValueError, match="No server URL found"): + await client.call_tool_async("testOp", {}) + + async def test_list_tools_async(self, sample_openapi_spec): + """测试异步获取工具列表""" + client = ToolOpenAPIClient(protocol_spec=sample_openapi_spec) + tools = await client.list_tools_async() + + assert len(tools) > 0 + assert all(isinstance(tool, ToolInfo) for tool in tools) + + def test_resolve_ref(self): + """测试 _resolve_ref 解析 $ref 引用""" + spec = json.dumps({ + "openapi": "3.0.0", + "components": { + "schemas": { + "User": { + "type": "object", + "properties": { + "name": {"type": "string"}, + }, + } + } + }, + }) + client = ToolOpenAPIClient(protocol_spec=spec) + ref = "#/components/schemas/User" + resolved = client._resolve_ref(ref) + assert resolved is not None + assert resolved["type"] == "object" + assert "name" in resolved["properties"] + + def test_resolve_ref_invalid(self): + """测试 _resolve_ref 无效引用""" + spec = json.dumps({ + "openapi": "3.0.0", + "components": {"schemas": {"User": {"type": "object"}}}, + }) + client = ToolOpenAPIClient(protocol_spec=spec) + ref = "#/components/schemas/NonExistent" + resolved = client._resolve_ref(ref) + assert resolved == {} + + def test_resolve_ref_no_spec(self): + """测试 _resolve_ref 没有 spec""" + client = ToolOpenAPIClient(protocol_spec=None) + ref = "#/components/schemas/User" + resolved = client._resolve_ref(ref) + assert resolved == {} + + def test_resolve_schema_with_ref(self): + """测试 _resolve_schema 递归解析 $ref""" + spec = json.dumps({ + "openapi": "3.0.0", + "components": { + "schemas": { + "User": { + "type": "object", + "properties": {"name": {"type": "string"}}, + } + } + }, + }) + client = ToolOpenAPIClient(protocol_spec=spec) + schema = {"$ref": "#/components/schemas/User"} + resolved = client._resolve_schema(schema) + assert resolved is not None + assert resolved["type"] == "object" + + def test_resolve_schema_none(self): + """测试 _resolve_schema 传入 None""" + spec = json.dumps({"openapi": "3.0.0"}) + client = ToolOpenAPIClient(protocol_spec=spec) + resolved = client._resolve_schema(None) + assert resolved is None + + def test_resolve_schema_with_items(self): + """测试 _resolve_schema 解析 items 中的 $ref""" + spec = json.dumps({ + "openapi": "3.0.0", + "components": {"schemas": {"Item": {"type": "string"}}}, + }) + client = ToolOpenAPIClient(protocol_spec=spec) + schema = { + "type": "array", + "items": {"$ref": "#/components/schemas/Item"}, + } + resolved = client._resolve_schema(schema) + assert resolved is not None + assert resolved["type"] == "array" + assert resolved["items"]["type"] == "string" + + def test_resolve_schema_with_anyof(self): + """测试 _resolve_schema 解析 anyOf 中的 $ref""" + spec = json.dumps({ + "openapi": "3.0.0", + "components": { + "schemas": { + "StringType": {"type": "string"}, + "NumberType": {"type": "number"}, + } + }, + }) + client = ToolOpenAPIClient(protocol_spec=spec) + schema = { + "anyOf": [ + {"$ref": "#/components/schemas/StringType"}, + {"$ref": "#/components/schemas/NumberType"}, + ] + } + resolved = client._resolve_schema(schema) + assert resolved is not None + assert "anyOf" in resolved + assert len(resolved["anyOf"]) == 2 + assert resolved["anyOf"][0]["type"] == "string" + assert resolved["anyOf"][1]["type"] == "number" + + def test_server_url_fallback(self): + """测试 server_url 使用 fallback_server_url""" + spec = json.dumps( + {"openapi": "3.0.0", "info": {"title": "Test API"}, "paths": {}} + ) + client = ToolOpenAPIClient( + protocol_spec=spec, + fallback_server_url="https://fallback.example.com", + ) + assert client.server_url == "https://fallback.example.com" + + def test_server_url_empty_servers_list(self): + """测试 servers 为空列表时使用 fallback""" + spec = json.dumps({ + "openapi": "3.0.0", + "info": {"title": "Test API"}, + "servers": [], + "paths": {}, + }) + client = ToolOpenAPIClient( + protocol_spec=spec, + fallback_server_url="https://fallback.example.com", + ) + assert client.server_url == "https://fallback.example.com" + + @patch("agentrun.tool.api.openapi.httpx.Client") + def test_call_tool_put_method(self, mock_client_class, sample_openapi_spec): + """测试 PUT 方法调用(走 POST/PUT/PATCH 分支)""" + mock_response = Mock() + mock_response.json.return_value = {"success": True} + mock_response.headers = {"content-type": "application/json"} + mock_response.raise_for_status = Mock() + + mock_client_instance = Mock() + mock_client_instance.request.return_value = mock_response + mock_client_instance.__enter__ = Mock(return_value=mock_client_instance) + mock_client_instance.__exit__ = Mock(return_value=False) + mock_client_class.return_value = mock_client_instance + + client = ToolOpenAPIClient(protocol_spec=sample_openapi_spec) + result = client.call_tool("updateUser", {"name": "Updated Name"}) + + assert result == {"success": True} + mock_client_instance.request.assert_called_once() + call_args = mock_client_instance.request.call_args + assert call_args[0][0] == "PUT" + + @patch("agentrun.tool.api.openapi.httpx.Client") + def test_call_tool_delete_method( + self, mock_client_class, sample_openapi_spec + ): + """测试 DELETE 方法调用(走 GET/DELETE 分支)""" + mock_response = Mock() + mock_response.json.return_value = {"success": True} + mock_response.headers = {"content-type": "application/json"} + mock_response.raise_for_status = Mock() + + mock_client_instance = Mock() + mock_client_instance.request.return_value = mock_response + mock_client_instance.__enter__ = Mock(return_value=mock_client_instance) + mock_client_instance.__exit__ = Mock(return_value=False) + mock_client_class.return_value = mock_client_instance + + client = ToolOpenAPIClient(protocol_spec=sample_openapi_spec) + result = client.call_tool("deleteUser", {}) + + assert result == {"success": True} + mock_client_instance.request.assert_called_once() + call_args = mock_client_instance.request.call_args + assert call_args[0][0] == "DELETE" + + @patch("agentrun.tool.api.openapi.httpx.AsyncClient") + async def test_call_tool_async_get_method( + self, mock_async_client_class, sample_openapi_spec + ): + """测试异步 GET 方法调用""" + mock_response = Mock() + mock_response.json.return_value = {"id": 123} + mock_response.headers = {"content-type": "application/json"} + mock_response.raise_for_status = Mock() + + mock_client_instance = Mock() + mock_client_instance.request = AsyncMock(return_value=mock_response) + mock_client_instance.__aenter__ = AsyncMock( + return_value=mock_client_instance + ) + mock_client_instance.__aexit__ = AsyncMock(return_value=False) + mock_async_client_class.return_value = mock_client_instance + + client = ToolOpenAPIClient(protocol_spec=sample_openapi_spec) + result = await client.call_tool_async("getUser", {"id": "123"}) + + assert result == {"id": 123} + mock_client_instance.request.assert_called_once() + call_args = mock_client_instance.request.call_args + assert call_args[0][0] == "GET" + + @patch("agentrun.tool.api.openapi.httpx.AsyncClient") + async def test_call_tool_async_text_response( + self, mock_async_client_class, sample_openapi_spec + ): + """测试异步调用返回 text 响应""" + mock_response = Mock() + mock_response.text = "plain text response" + mock_response.headers = {"content-type": "text/plain"} + mock_response.json.side_effect = ValueError("No JSON") + mock_response.raise_for_status = Mock() + + mock_client_instance = Mock() + mock_client_instance.request = AsyncMock(return_value=mock_response) + mock_client_instance.__aenter__ = AsyncMock( + return_value=mock_client_instance + ) + mock_client_instance.__aexit__ = AsyncMock(return_value=False) + mock_async_client_class.return_value = mock_client_instance + + client = ToolOpenAPIClient(protocol_spec=sample_openapi_spec) + result = await client.call_tool_async("listUsers", {"limit": 10}) + + assert result == "plain text response" + + def test_parse_operations_no_operation_id(self): + """测试没有 operationId 时使用默认值""" + spec = json.dumps({ + "openapi": "3.0.0", + "info": {"title": "Test API"}, + "servers": [{"url": "https://api.example.com"}], + "paths": { + "/test": {"get": {"summary": "Test without operationId"}} + }, + }) + client = ToolOpenAPIClient(protocol_spec=spec) + operations = client._parse_operations() + + assert len(operations) == 1 + assert operations[0]["operation_id"] is not None + assert operations[0]["method"] == "GET" + + def test_parse_operations_invalid_path_item(self): + """测试无效的 path_item(非 dict)""" + spec = json.dumps({ + "openapi": "3.0.0", + "info": {"title": "Test API"}, + "servers": [{"url": "https://api.example.com"}], + "paths": {"/test": "invalid"}, + }) + client = ToolOpenAPIClient(protocol_spec=spec) + operations = client._parse_operations() + + assert operations == [] + + def test_parse_operations_required_parameters(self): + """测试 required 参数的解析""" + spec = json.dumps({ + "openapi": "3.0.0", + "info": {"title": "Test API"}, + "servers": [{"url": "https://api.example.com"}], + "paths": { + "/users/{id}": { + "get": { + "operationId": "getUserById", + "parameters": [{ + "name": "id", + "in": "path", + "required": True, + "schema": {"type": "string"}, + }], + } + } + }, + }) + client = ToolOpenAPIClient(protocol_spec=spec) + operations = client._parse_operations() + + assert len(operations) == 1 + op = operations[0] + assert op["operation_id"] == "getUserById" + assert op["input_schema"] is not None + assert "id" in op["input_schema"]["properties"] + assert "id" in op["input_schema"]["required"] + + +class AsyncMock(Mock): + """Async mock helper""" + + async def __call__(self, *args, **kwargs): + return super().__call__(*args, **kwargs) diff --git a/tests/unittests/tool/test_tool.py b/tests/unittests/tool/test_tool.py new file mode 100644 index 0000000..e3e2bf2 --- /dev/null +++ b/tests/unittests/tool/test_tool.py @@ -0,0 +1,1018 @@ +"""Tool 资源类和客户端单元测试 / Tool Resource Class and Client Unit Tests + +测试 Tool 资源类和 ToolClient 的功能。 +Tests functionality of Tool resource class and ToolClient. +""" + +from unittest.mock import AsyncMock, Mock, patch + +import pytest + +from agentrun.tool.client import ToolClient +from agentrun.tool.model import ( + McpConfig, + ToolCodeConfiguration, + ToolContainerConfiguration, + ToolInfo, + ToolLogConfiguration, + ToolNetworkConfiguration, + ToolOSSMountConfig, + ToolSchema, + ToolType, +) +from agentrun.tool.tool import Tool + + +class TestTool: + """测试 Tool 资源类""" + + def test_tool_attributes_default(self): + """测试 Tool 默认属性""" + tool = Tool() + assert tool.tool_id is None + assert tool.name is None + assert tool.tool_name is None + assert tool.description is None + assert tool.tool_type is None + assert tool.status is None + assert tool.code_configuration is None + assert tool.container_configuration is None + assert tool.mcp_config is None + assert tool.log_configuration is None + assert tool.network_config is None + assert tool.oss_mount_config is None + assert tool.data_endpoint is None + assert tool.protocol_spec is None + assert tool.protocol_type is None + assert tool.memory is None + assert tool.gpu is None + assert tool.timeout is None + assert tool.internet_access is None + assert tool.environment_variables is None + assert tool.created_time is None + assert tool.last_modified_time is None + assert tool.version_id is None + + def test_tool_attributes_with_values(self): + """测试 Tool 带值创建""" + tool = Tool( + tool_id="tool-123", + name="my-tool", + tool_name="my-tool", + description="A test tool", + tool_type="MCP", + status="READY", + data_endpoint="https://example.com/data", + memory=1024, + gpu="T4", + timeout=60, + internet_access=True, + environment_variables={"KEY": "value"}, + ) + assert tool.tool_id == "tool-123" + assert tool.name == "my-tool" + assert tool.tool_name == "my-tool" + assert tool.description == "A test tool" + assert tool.tool_type == "MCP" + assert tool.status == "READY" + assert tool.data_endpoint == "https://example.com/data" + assert tool.memory == 1024 + assert tool.gpu == "T4" + assert tool.timeout == 60 + assert tool.internet_access is True + assert tool.environment_variables == {"KEY": "value"} + + def test_get_tool_type_mcp(self): + """测试获取 MCP 工具类型""" + tool = Tool(tool_type="MCP") + assert tool._get_tool_type() == ToolType.MCP + + def test_get_tool_type_functioncall(self): + """测试获取 FUNCTIONCALL 工具类型""" + tool = Tool(tool_type="FUNCTIONCALL") + assert tool._get_tool_type() == ToolType.FUNCTIONCALL + + def test_get_tool_type_invalid(self): + """测试获取无效工具类型""" + tool = Tool(tool_type="INVALID") + assert tool._get_tool_type() is None + + def test_get_tool_type_none(self): + """测试获取 None 工具类型""" + tool = Tool() + assert tool._get_tool_type() is None + + def test_get_mcp_endpoint_sse(self): + """测试获取 MCP SSE endpoint""" + tool = Tool( + tool_name="my-tool", + data_endpoint="https://example.com", + mcp_config=McpConfig(session_affinity="MCP_SSE"), + ) + endpoint = tool._get_mcp_endpoint() + assert endpoint == "https://example.com/tools/my-tool/sse" + + def test_get_mcp_endpoint_streamable(self): + """测试获取 MCP Streamable endpoint""" + tool = Tool( + tool_name="my-tool", + data_endpoint="https://example.com", + mcp_config=McpConfig(session_affinity="MCP_STREAMABLE"), + ) + endpoint = tool._get_mcp_endpoint() + assert endpoint == "https://example.com/tools/my-tool/mcp" + + def test_get_mcp_endpoint_default(self): + """测试获取 MCP endpoint(默认 SSE)""" + tool = Tool( + tool_name="my-tool", + data_endpoint="https://example.com", + ) + endpoint = tool._get_mcp_endpoint() + assert endpoint == "https://example.com/tools/my-tool/sse" + + def test_get_mcp_endpoint_no_name(self): + """测试没有 name 时获取 MCP endpoint""" + tool = Tool( + data_endpoint="https://example.com", + ) + endpoint = tool._get_mcp_endpoint() + assert endpoint is None + + def test_get_mcp_endpoint_no_data_endpoint(self): + """测试没有 data_endpoint 时获取 MCP endpoint""" + tool = Tool( + tool_name="my-tool", + ) + endpoint = tool._get_mcp_endpoint() + assert endpoint is None + + def test_from_inner_object(self): + """测试从内部对象创建 Tool""" + inner_tool = Mock() + inner_tool.tool_id = "tool-123" + inner_tool.name = "my-tool" + inner_tool.description = "Test tool" + inner_tool.tool_type = "MCP" + inner_tool.status = "READY" + inner_tool.data_endpoint = "https://example.com/data" + inner_tool.memory = 1024 + inner_tool.gpu = "T4" + inner_tool.timeout = 60 + inner_tool.internet_access = True + inner_tool.environment_variables = {"KEY": "value"} + inner_tool.created_time = "2024-01-01T00:00:00Z" + inner_tool.last_modified_time = "2024-01-02T00:00:00Z" + inner_tool.version_id = "version-123" + inner_tool.protocol_spec = '{"openapi": "3.0.0"}' + inner_tool.protocol_type = "openapi" + + # Mock configurations + inner_tool.code_configuration = None + inner_tool.container_configuration = None + inner_tool.mcp_config = None + inner_tool.log_configuration = None + inner_tool.network_config = None + inner_tool.oss_mount_config = None + + # Mock to_map method + inner_tool.to_map = Mock( + return_value={ + "toolId": "tool-123", + "name": "my-tool", + "description": "Test tool", + "toolType": "MCP", + "status": "READY", + "dataEndpoint": "https://example.com/data", + "memory": 1024, + "gpu": "T4", + "timeout": 60, + "internetAccess": True, + "environmentVariables": {"KEY": "value"}, + "createdTime": "2024-01-01T00:00:00Z", + "lastModifiedTime": "2024-01-02T00:00:00Z", + "versionId": "version-123", + "protocolSpec": '{"openapi": "3.0.0"}', + "protocolType": "openapi", + } + ) + + tool = Tool.from_inner_object(inner_tool) + + assert tool.tool_id == "tool-123" + assert tool.name == "my-tool" + assert tool.description == "Test tool" + assert tool.tool_type == "MCP" + assert tool.status == "READY" + assert tool.data_endpoint == "https://example.com/data" + assert tool.memory == 1024 + assert tool.gpu == "T4" + assert tool.timeout == 60 + assert tool.internet_access is True + assert tool.environment_variables == {"KEY": "value"} + assert tool.created_time == "2024-01-01T00:00:00Z" + assert tool.last_modified_time == "2024-01-02T00:00:00Z" + assert tool.version_id == "version-123" + assert tool.protocol_spec == '{"openapi": "3.0.0"}' + assert tool.protocol_type == "openapi" + + @patch("agentrun.tool.api.mcp.ToolMCPSession") + @patch("agentrun.utils.config.Config") + def test_list_tools_mcp(self, mock_config_class, mock_mcp_session_class): + """测试获取 MCP 工具列表""" + mock_session = Mock() + mock_session.list_tools.return_value = [ + ToolInfo(name="tool1", description="Tool 1"), + ToolInfo(name="tool2", description="Tool 2"), + ] + mock_mcp_session_class.return_value = mock_session + + mock_config = Mock() + mock_config.get_headers.return_value = {} + mock_config_class.with_configs.return_value = mock_config + + tool = Tool( + tool_name="my-tool", + tool_type="MCP", + data_endpoint="https://example.com", + mcp_config=McpConfig(session_affinity="MCP_SSE"), + ) + + tools = tool.list_tools() + + assert len(tools) == 2 + assert tools[0].name == "tool1" + assert tools[1].name == "tool2" + + @patch("agentrun.tool.api.openapi.ToolOpenAPIClient") + def test_list_tools_functioncall(self, mock_openapi_client_class): + """测试获取 FUNCTIONCALL 工具列表""" + mock_client = Mock() + mock_client.list_tools.return_value = [ + ToolInfo(name="tool1", description="Tool 1"), + ToolInfo(name="tool2", description="Tool 2"), + ] + mock_openapi_client_class.return_value = mock_client + + tool = Tool( + tool_type="FUNCTIONCALL", + protocol_spec='{"openapi": "3.0.0"}', + ) + + tools = tool.list_tools() + + assert len(tools) == 2 + assert tools[0].name == "tool1" + assert tools[1].name == "tool2" + + def test_list_tools_no_type(self): + """测试没有工具类型时获取工具列表""" + tool = Tool() + tools = tool.list_tools() + assert tools == [] + + @patch("agentrun.tool.api.mcp.ToolMCPSession") + @patch("agentrun.utils.config.Config") + def test_call_tool_mcp(self, mock_config_class, mock_mcp_session_class): + """测试调用 MCP 工具""" + mock_session = Mock() + mock_session.call_tool.return_value = {"result": "success"} + mock_mcp_session_class.return_value = mock_session + + mock_config = Mock() + mock_config.get_headers.return_value = {} + mock_config_class.with_configs.return_value = mock_config + + tool = Tool( + tool_name="my-tool", + tool_type="MCP", + data_endpoint="https://example.com", + mcp_config=McpConfig(session_affinity="MCP_SSE"), + ) + + result = tool.call_tool("tool1", {"param": "value"}) + + assert result == {"result": "success"} + + @patch("agentrun.tool.api.openapi.ToolOpenAPIClient") + @patch("agentrun.utils.config.Config") + def test_call_tool_functioncall( + self, mock_config_class, mock_openapi_client_class + ): + """测试调用 FUNCTIONCALL 工具""" + mock_client = Mock() + mock_client.call_tool.return_value = {"result": "success"} + mock_openapi_client_class.return_value = mock_client + + mock_config = Mock() + mock_config.get_headers.return_value = {} + mock_config_class.with_configs.return_value = mock_config + + tool = Tool( + tool_type="FUNCTIONCALL", + protocol_spec='{"openapi": "3.0.0"}', + ) + + result = tool.call_tool("tool1", {"param": "value"}) + + assert result == {"result": "success"} + + def test_call_tool_unsupported_type(self): + """测试调用不支持的类型工具""" + tool = Tool(tool_type="UNSUPPORTED") + with pytest.raises(ValueError, match="Unsupported tool type"): + tool.call_tool("tool1", {}) + + @patch("agentrun.tool.api.mcp.ToolMCPSession") + @patch("agentrun.utils.config.Config") + async def test_list_tools_async_mcp( + self, mock_config_class, mock_mcp_session_class + ): + """测试异步获取 MCP 工具列表""" + mock_session = Mock() + mock_session.list_tools_async = AsyncMock( + return_value=[ + ToolInfo(name="tool1", description="Tool 1"), + ] + ) + mock_mcp_session_class.return_value = mock_session + + mock_config = Mock() + mock_config.get_headers.return_value = {} + mock_config_class.with_configs.return_value = mock_config + + tool = Tool( + tool_name="my-tool", + tool_type="MCP", + data_endpoint="https://example.com", + mcp_config=McpConfig(session_affinity="MCP_SSE"), + ) + + tools = await tool.list_tools_async() + + assert len(tools) == 1 + assert tools[0].name == "tool1" + + @patch("agentrun.tool.api.mcp.ToolMCPSession") + @patch("agentrun.utils.config.Config") + async def test_call_tool_async_mcp( + self, mock_config_class, mock_mcp_session_class + ): + """测试异步调用 MCP 工具""" + mock_session = Mock() + mock_session.call_tool_async = AsyncMock( + return_value={"result": "success"} + ) + mock_mcp_session_class.return_value = mock_session + + mock_config = Mock() + mock_config.get_headers.return_value = {} + mock_config_class.with_configs.return_value = mock_config + + tool = Tool( + tool_name="my-tool", + tool_type="MCP", + data_endpoint="https://example.com", + mcp_config=McpConfig(session_affinity="MCP_SSE"), + ) + + result = await tool.call_tool_async("tool1", {"param": "value"}) + + assert result == {"result": "success"} + + # ==================== SKILL 相关测试 ==================== + + def test_get_tool_type_skill(self): + """测试获取 SKILL 工具类型""" + tool = Tool(tool_type="SKILL") + assert tool._get_tool_type() == ToolType.SKILL + + def test_get_skill_download_url_with_data_endpoint(self): + """测试使用 data_endpoint 构造 skill 下载 URL""" + tool = Tool( + tool_name="my-skill", + data_endpoint="https://example.com", + ) + url = tool._get_skill_download_url() + assert url == "https://example.com/tools/my-skill/download" + + def test_get_skill_download_url_uses_name_fallback(self): + """测试 tool_name 为空时使用 name 作为 fallback""" + tool = Tool( + name="fallback-skill", + data_endpoint="https://example.com", + ) + url = tool._get_skill_download_url() + assert url == "https://example.com/tools/fallback-skill/download" + + def test_get_skill_download_url_tool_name_takes_priority(self): + """测试 tool_name 优先于 name""" + tool = Tool( + tool_name="primary-skill", + name="fallback-skill", + data_endpoint="https://example.com", + ) + url = tool._get_skill_download_url() + assert url == "https://example.com/tools/primary-skill/download" + + @patch("agentrun.tool.tool.Config") + def test_get_skill_download_url_config_fallback(self, mock_config_class): + """测试 data_endpoint 为空时从 Config 获取""" + mock_config = Mock() + mock_config._data_endpoint = "https://config-endpoint.com" + mock_config_class.with_configs.return_value = mock_config + + tool = Tool(tool_name="my-skill") + url = tool._get_skill_download_url() + assert url == "https://config-endpoint.com/tools/my-skill/download" + + def test_get_skill_download_url_no_name(self): + """测试没有 name 时返回 None""" + tool = Tool(data_endpoint="https://example.com") + url = tool._get_skill_download_url() + assert url is None + + @patch("agentrun.tool.tool.Config") + def test_get_skill_download_url_no_endpoint(self, mock_config_class): + """测试没有 data_endpoint 且 Config 也没有时返回 None""" + mock_config = Mock() + mock_config._data_endpoint = None + mock_config_class.with_configs.return_value = mock_config + + tool = Tool(tool_name="my-skill") + url = tool._get_skill_download_url() + assert url is None + + @patch("httpx.AsyncClient") + @patch("agentrun.utils.config.Config") + async def test_download_skill_async_success( + self, mock_config_class, mock_async_client_class + ): + """测试成功下载并解压 skill 包""" + import io + import os + import shutil + import tempfile + import zipfile + + mock_config = Mock() + mock_config.get_headers.return_value = {"Authorization": "Bearer token"} + mock_config_class.with_configs.return_value = mock_config + + zip_buffer = io.BytesIO() + with zipfile.ZipFile(zip_buffer, "w") as zf: + zf.writestr("SKILL.md", "# Test Skill") + zf.writestr("main.py", "print('hello')") + zip_content = zip_buffer.getvalue() + + mock_response = Mock() + mock_response.content = zip_content + mock_response.raise_for_status = Mock() + + mock_client_instance = AsyncMock() + mock_client_instance.get.return_value = mock_response + mock_client_instance.__aenter__ = AsyncMock( + return_value=mock_client_instance + ) + mock_client_instance.__aexit__ = AsyncMock(return_value=False) + mock_async_client_class.return_value = mock_client_instance + + tool = Tool( + tool_name="test-skill", + tool_type="SKILL", + data_endpoint="https://example.com", + ) + + tmp_dir = tempfile.mkdtemp() + try: + result = await tool.download_skill_async(target_dir=tmp_dir) + + expected_dir = os.path.join(tmp_dir, "test-skill") + assert result == expected_dir + assert os.path.exists(expected_dir) + assert os.path.isfile(os.path.join(expected_dir, "SKILL.md")) + assert os.path.isfile(os.path.join(expected_dir, "main.py")) + + with open(os.path.join(expected_dir, "SKILL.md")) as f: + assert f.read() == "# Test Skill" + finally: + shutil.rmtree(tmp_dir) + + @patch("httpx.AsyncClient") + @patch("agentrun.utils.config.Config") + async def test_download_skill_async_overwrites_existing( + self, mock_config_class, mock_async_client_class + ): + """测试下载 skill 时覆盖已存在的目录""" + import io + import os + import shutil + import tempfile + import zipfile + + mock_config = Mock() + mock_config.get_headers.return_value = {} + mock_config_class.with_configs.return_value = mock_config + + zip_buffer = io.BytesIO() + with zipfile.ZipFile(zip_buffer, "w") as zf: + zf.writestr("new_file.txt", "new content") + zip_content = zip_buffer.getvalue() + + mock_response = Mock() + mock_response.content = zip_content + mock_response.raise_for_status = Mock() + + mock_client_instance = AsyncMock() + mock_client_instance.get.return_value = mock_response + mock_client_instance.__aenter__ = AsyncMock( + return_value=mock_client_instance + ) + mock_client_instance.__aexit__ = AsyncMock(return_value=False) + mock_async_client_class.return_value = mock_client_instance + + tool = Tool( + tool_name="test-skill", + tool_type="SKILL", + data_endpoint="https://example.com", + ) + + tmp_dir = tempfile.mkdtemp() + try: + existing_dir = os.path.join(tmp_dir, "test-skill") + os.makedirs(existing_dir) + with open(os.path.join(existing_dir, "old_file.txt"), "w") as f: + f.write("old content") + + result = await tool.download_skill_async(target_dir=tmp_dir) + + assert os.path.isfile(os.path.join(result, "new_file.txt")) + assert not os.path.exists(os.path.join(result, "old_file.txt")) + finally: + shutil.rmtree(tmp_dir) + + async def test_download_skill_async_wrong_type(self): + """测试非 SKILL 类型调用 download_skill_async 抛出 ValueError""" + tool = Tool(tool_type="MCP", tool_name="my-tool") + + with pytest.raises(ValueError, match="only available for SKILL"): + await tool.download_skill_async() + + async def test_download_skill_async_no_url(self): + """测试无法构造下载 URL 时抛出 ValueError""" + tool = Tool(tool_type="SKILL") + + with pytest.raises(ValueError, match="Cannot construct download URL"): + await tool.download_skill_async() + + @patch("httpx.AsyncClient") + @patch("agentrun.utils.config.Config") + async def test_download_skill_async_http_error( + self, mock_config_class, mock_async_client_class + ): + """测试下载失败时抛出 HTTPStatusError""" + import httpx + + mock_config = Mock() + mock_config.get_headers.return_value = {} + mock_config_class.with_configs.return_value = mock_config + + mock_response = Mock() + mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( + "Not Found", + request=Mock(), + response=Mock(status_code=404), + ) + + mock_client_instance = AsyncMock() + mock_client_instance.get.return_value = mock_response + mock_client_instance.__aenter__ = AsyncMock( + return_value=mock_client_instance + ) + mock_client_instance.__aexit__ = AsyncMock(return_value=False) + mock_async_client_class.return_value = mock_client_instance + + tool = Tool( + tool_name="test-skill", + tool_type="SKILL", + data_endpoint="https://example.com", + ) + + with pytest.raises(httpx.HTTPStatusError): + await tool.download_skill_async() + + @patch("agentrun.tool.tool.httpx.Client") + @patch("agentrun.tool.tool.Config") + def test_download_skill_sync_success( + self, mock_config_class, mock_client_class + ): + """测试同步版本 download_skill 成功""" + import io + import os + import shutil + import tempfile + import zipfile + + mock_config = Mock() + mock_config.get_headers.return_value = {} + mock_config_class.with_configs.return_value = mock_config + + zip_buffer = io.BytesIO() + with zipfile.ZipFile(zip_buffer, "w") as zf: + zf.writestr("skill.py", "print('skill')") + zip_content = zip_buffer.getvalue() + + mock_response = Mock() + mock_response.content = zip_content + mock_response.raise_for_status = Mock() + + mock_client_instance = Mock() + mock_client_instance.get.return_value = mock_response + mock_client_instance.__enter__ = Mock(return_value=mock_client_instance) + mock_client_instance.__exit__ = Mock(return_value=False) + mock_client_class.return_value = mock_client_instance + + tool = Tool( + tool_name="sync-skill", + tool_type="SKILL", + data_endpoint="https://example.com", + ) + + tmp_dir = tempfile.mkdtemp() + try: + result = tool.download_skill(target_dir=tmp_dir) + + expected_dir = os.path.join(tmp_dir, "sync-skill") + assert result == expected_dir + assert os.path.isfile(os.path.join(expected_dir, "skill.py")) + finally: + shutil.rmtree(tmp_dir) + + def test_download_skill_sync_wrong_type(self): + """测试同步版本非 SKILL 类型抛出 ValueError""" + tool = Tool(tool_type="FUNCTIONCALL", tool_name="my-tool") + + with pytest.raises(ValueError, match="only available for SKILL"): + tool.download_skill() + + +class TestToolClient: + """测试 ToolClient""" + + def test_client_init(self): + """测试客户端初始化""" + client = ToolClient() + assert client is not None + + @patch("agentrun.tool.client.ToolControlAPI") + def test_get(self, mock_control_api_class): + """测试获取工具""" + # Mock inner tool + inner_tool = Mock() + inner_tool.tool_id = "tool-123" + inner_tool.name = "my-tool" + inner_tool.description = "Test tool" + inner_tool.tool_type = "MCP" + inner_tool.status = "READY" + inner_tool.data_endpoint = "https://example.com/data" + inner_tool.memory = 1024 + inner_tool.gpu = None + inner_tool.timeout = 60 + inner_tool.internet_access = True + inner_tool.environment_variables = None + inner_tool.created_time = None + inner_tool.last_modified_time = None + inner_tool.version_id = None + inner_tool.protocol_spec = None + inner_tool.protocol_type = None + inner_tool.code_configuration = None + inner_tool.container_configuration = None + inner_tool.mcp_config = None + inner_tool.log_configuration = None + inner_tool.network_config = None + inner_tool.oss_mount_config = None + + # Mock to_map method + inner_tool.to_map = Mock( + return_value={ + "toolId": "tool-123", + "name": "my-tool", + "description": "Test tool", + "toolType": "MCP", + "status": "READY", + "dataEndpoint": "https://example.com/data", + "memory": 1024, + "timeout": 60, + "internetAccess": True, + } + ) + + mock_api = Mock() + mock_api.get_tool.return_value = inner_tool + mock_control_api_class.return_value = mock_api + + client = ToolClient() + tool = client.get(name="my-tool") + + assert tool.tool_id == "tool-123" + assert tool.name == "my-tool" + assert tool.tool_type == "MCP" + mock_api.get_tool.assert_called_once_with(name="my-tool", config=None) + + @patch("agentrun.tool.client.ToolControlAPI") + async def test_get_async(self, mock_control_api_class): + """测试异步获取工具""" + # Mock inner tool + inner_tool = Mock() + inner_tool.tool_id = "tool-123" + inner_tool.name = "my-tool" + inner_tool.description = "Test tool" + inner_tool.tool_type = "MCP" + inner_tool.status = "READY" + inner_tool.data_endpoint = "https://example.com/data" + inner_tool.memory = 1024 + inner_tool.gpu = None + inner_tool.timeout = 60 + inner_tool.internet_access = True + inner_tool.environment_variables = None + inner_tool.created_time = None + inner_tool.last_modified_time = None + inner_tool.version_id = None + inner_tool.protocol_spec = None + inner_tool.protocol_type = None + inner_tool.code_configuration = None + inner_tool.container_configuration = None + inner_tool.mcp_config = None + inner_tool.log_configuration = None + inner_tool.network_config = None + inner_tool.oss_mount_config = None + + # Mock to_map method + inner_tool.to_map = Mock( + return_value={ + "toolId": "tool-123", + "name": "my-tool", + "description": "Test tool", + "toolType": "MCP", + "status": "READY", + "dataEndpoint": "https://example.com/data", + "memory": 1024, + "timeout": 60, + "internetAccess": True, + } + ) + + mock_api = Mock() + mock_api.get_tool_async = AsyncMock(return_value=inner_tool) + mock_control_api_class.return_value = mock_api + + client = ToolClient() + tool = await client.get_async(name="my-tool") + + assert tool.tool_id == "tool-123" + assert tool.name == "my-tool" + assert tool.tool_type == "MCP" + mock_api.get_tool_async.assert_called_once_with( + name="my-tool", config=None + ) + + @patch("agentrun.tool.client.ToolControlAPI") + def test_get_http_error(self, mock_control_api_class): + """测试 get() 遇到 HTTPError 时的异常转换""" + from agentrun.utils.exception import HTTPError + + mock_resource_error = Exception("Resource not found") + mock_resource_error.message = "Resource not found" # type: ignore + mock_resource_error.error_code = "ResourceNotFound" # type: ignore + + mock_http_error = HTTPError.__new__(HTTPError) + mock_http_error.to_resource_error = Mock(return_value=mock_resource_error) # type: ignore + + mock_api = Mock() + mock_api.get_tool.side_effect = mock_http_error + mock_control_api_class.return_value = mock_api + + client = ToolClient() + + with pytest.raises(Exception) as exc_info: + client.get(name="my-tool") + assert exc_info.value.message == "Resource not found" # type: ignore + + @patch("agentrun.tool.client.ToolControlAPI") + async def test_get_async_http_error(self, mock_control_api_class): + """测试 get_async() 遇到 HTTPError 时的异常转换""" + from agentrun.utils.exception import HTTPError + + mock_resource_error = Exception("Resource not found") + mock_resource_error.message = "Resource not found" # type: ignore + mock_resource_error.error_code = "ResourceNotFound" # type: ignore + + mock_http_error = HTTPError.__new__(HTTPError) + mock_http_error.to_resource_error = Mock(return_value=mock_resource_error) # type: ignore + + mock_api = Mock() + mock_api.get_tool_async = AsyncMock(side_effect=mock_http_error) + mock_control_api_class.return_value = mock_api + + client = ToolClient() + + with pytest.raises(Exception) as exc_info: + await client.get_async(name="my-tool") + assert exc_info.value.message == "Resource not found" # type: ignore + + @patch("agentrun.tool.tool.Tool._Tool__get_client") + def test_get_by_name(self, mock_get_client): + """测试类方法 get_by_name""" + mock_client = Mock() + mock_tool = Tool(tool_id="tool-123", name="my-tool", tool_type="MCP") + mock_client.get.return_value = mock_tool + mock_get_client.return_value = mock_client + + tool = Tool.get_by_name("my-tool") + + assert tool.tool_id == "tool-123" + assert tool.name == "my-tool" + mock_client.get.assert_called_once_with(name="my-tool") + + @patch("agentrun.tool.tool.Tool._Tool__get_client") + async def test_get_by_name_async(self, mock_get_client): + """测试类方法 get_by_name_async""" + mock_client = Mock() + mock_tool = Tool(tool_id="tool-123", name="my-tool", tool_type="MCP") + mock_client.get_async = AsyncMock(return_value=mock_tool) + mock_get_client.return_value = mock_client + + tool = await Tool.get_by_name_async("my-tool") + + assert tool.tool_id == "tool-123" + assert tool.name == "my-tool" + mock_client.get_async.assert_called_once_with(name="my-tool") + + @patch("agentrun.tool.tool.Tool.get_by_name") + def test_get_sync(self, mock_get_by_name): + """测试实例方法 get()""" + mock_tool = Tool(tool_id="tool-123", name="my-tool", tool_type="MCP") + mock_get_by_name.return_value = mock_tool + + tool = Tool(tool_name="my-tool") + result = tool.get() + + assert result.tool_id == "tool-123" + mock_get_by_name.assert_called_once_with(name="my-tool", config=None) + + def test_get_sync_no_name(self): + """测试 get() 没有 name 时抛出 ValueError""" + tool = Tool() + + with pytest.raises(ValueError, match="Tool name is required"): + tool.get() + + @patch("agentrun.tool.tool.Tool.get_by_name_async") + async def test_get_async_method(self, mock_get_by_name_async): + """测试实例方法 get_async()""" + mock_tool = Tool(tool_id="tool-123", name="my-tool", tool_type="MCP") + mock_get_by_name_async.return_value = mock_tool + + tool = Tool(tool_name="my-tool") + result = await tool.get_async() + + assert result.tool_id == "tool-123" + mock_get_by_name_async.assert_called_once_with( + name="my-tool", config=None + ) + + def test_get_async_no_name(self): + """测试 get_async() 没有 name 时抛出 ValueError""" + tool = Tool() + + with pytest.raises(ValueError, match="Tool name is required"): + import asyncio + + asyncio.run(tool.get_async()) + + def test_get_functioncall_server_url(self): + """测试 _get_functioncall_server_url 有 data_endpoint""" + tool = Tool( + tool_name="my-tool", data_endpoint="https://example.com/data" + ) + url = tool._get_functioncall_server_url() + + assert url == "https://example.com/data/tools/my-tool" + + def test_get_functioncall_server_url_no_endpoint(self): + """测试 _get_functioncall_server_url 没有 data_endpoint 和 name 时返回 None""" + tool = Tool() + url = tool._get_functioncall_server_url() + + assert url is None + + @patch("agentrun.utils.config.Config") + async def test_list_tools_async_mcp_no_endpoint(self, mock_config_class): + """测试 MCP 类型但没有 endpoint 时返回空列表""" + tool = Tool(tool_name="my-tool", tool_type="MCP") + + tools = await tool.list_tools_async() + + assert tools == [] + + @patch("agentrun.tool.api.openapi.ToolOpenAPIClient") + async def test_list_tools_async_functioncall( + self, mock_openapi_client_class + ): + """测试 FUNCTIONCALL 类型的 list_tools_async""" + mock_client = Mock() + mock_client.list_tools_async = AsyncMock( + return_value=[ + ToolInfo(name="tool1", description="Tool 1"), + ToolInfo(name="tool2", description="Tool 2"), + ] + ) + mock_openapi_client_class.return_value = mock_client + + tool = Tool( + tool_type="FUNCTIONCALL", + protocol_spec='{"openapi": "3.0.0"}', + ) + + tools = await tool.list_tools_async() + + assert len(tools) == 2 + assert tools[0].name == "tool1" + assert tools[1].name == "tool2" + + async def test_list_tools_async_no_type(self): + """测试没有类型时 list_tools_async 返回空列表""" + tool = Tool() + tools = await tool.list_tools_async() + assert tools == [] + + @patch("agentrun.tool.api.openapi.ToolOpenAPIClient") + @patch("agentrun.utils.config.Config") + async def test_call_tool_async_functioncall( + self, mock_config_class, mock_openapi_client_class + ): + """测试 FUNCTIONCALL 类型的 call_tool_async""" + mock_client = Mock() + mock_client.call_tool_async = AsyncMock( + return_value={"result": "success"} + ) + mock_openapi_client_class.return_value = mock_client + + mock_config = Mock() + mock_config.get_headers.return_value = {} + mock_config_class.with_configs.return_value = mock_config + + tool = Tool( + tool_type="FUNCTIONCALL", + protocol_spec='{"openapi": "3.0.0"}', + ) + + result = await tool.call_tool_async("tool1", {"param": "value"}) + + assert result == {"result": "success"} + + async def test_call_tool_async_mcp_no_endpoint(self): + """测试 MCP 类型但没有 endpoint 时 call_tool_async 抛出 ValueError""" + tool = Tool(tool_name="my-tool", tool_type="MCP") + + with pytest.raises(ValueError, match="MCP endpoint not available"): + await tool.call_tool_async("tool1", {"param": "value"}) + + @patch("agentrun.tool.api.openapi.ToolOpenAPIClient") + @patch("agentrun.utils.config.Config") + def test_call_tool_functioncall( + self, mock_config_class, mock_openapi_client_class + ): + """测试 FUNCTIONCALL 类型的 call_tool(同步)""" + mock_client = Mock() + mock_client.call_tool.return_value = {"result": "success"} + mock_openapi_client_class.return_value = mock_client + + mock_config = Mock() + mock_config.get_headers.return_value = {} + mock_config_class.with_configs.return_value = mock_config + + tool = Tool( + tool_type="FUNCTIONCALL", + protocol_spec='{"openapi": "3.0.0"}', + ) + + result = tool.call_tool("tool1", {"param": "value"}) + + assert result == {"result": "success"} + + def test_call_tool_mcp_no_endpoint(self): + """测试 MCP 类型但没有 endpoint 时 call_tool 抛出 ValueError""" + tool = Tool(tool_name="my-tool", tool_type="MCP") + + with pytest.raises(ValueError, match="MCP endpoint not available"): + tool.call_tool("tool1", {"param": "value"}) + + @patch("agentrun.utils.config.Config") + def test_list_tools_mcp_no_endpoint(self, mock_config_class): + """测试 MCP 类型但没有 endpoint 时 list_tools 返回空列表""" + tool = Tool(tool_name="my-tool", tool_type="MCP") + + tools = tool.list_tools() + + assert tools == []