diff --git a/examples/agent_tool_search.py b/examples/agent_tool_search.py new file mode 100644 index 0000000..bca48f8 --- /dev/null +++ b/examples/agent_tool_search.py @@ -0,0 +1,185 @@ +"""Search and execute example: LLM-driven tool discovery and execution. + +There are two ways to give tools to an LLM: + +1. ``toolset.openai()`` — fetches ALL tools and converts them to OpenAI format. + Token cost scales with the number of tools in your catalog. + +2. ``toolset.openai(mode="search_and_execute")`` — returns just 2 tools + (tool_search + tool_execute). The LLM discovers and runs tools on-demand, + keeping token usage constant regardless of catalog size. + +This example demonstrates approach 2 with two patterns: +- Raw client (Gemini): manual agent loop with ``toolset.execute()`` +- LangChain: framework handles tool execution automatically + +Prerequisites: + - STACKONE_API_KEY environment variable + - STACKONE_ACCOUNT_ID environment variable + - GOOGLE_API_KEY environment variable (for Gemini/LangChain) + +Run with: + uv run python examples/agent_tool_search.py +""" + +from __future__ import annotations + +import json +import os + +try: + from dotenv import load_dotenv + + load_dotenv() +except ModuleNotFoundError: + pass + +from stackone_ai import StackOneToolSet + + +def example_gemini() -> None: + """Raw client: Gemini via OpenAI-compatible API. + + Shows: init toolset -> get OpenAI tools -> manual agent loop with toolset.execute(). + """ + print("=" * 60) + print("Example 1: Raw client (Gemini) — manual execution") + print("=" * 60) + print() + + try: + from openai import OpenAI + except ImportError: + print("Skipped: pip install openai") + print() + return + + google_key = os.getenv("GOOGLE_API_KEY") + if not google_key: + print("Skipped: Set GOOGLE_API_KEY to run this example.") + print() + return + + # 1. Init toolset + account_id = os.getenv("STACKONE_ACCOUNT_ID") + toolset = StackOneToolSet( + account_id=account_id, + search={"method": "semantic", "top_k": 3}, + execute={"account_ids": [account_id]} if account_id else None, + ) + + # 2. Get tools in OpenAI format + openai_tools = toolset.openai(mode="search_and_execute") + + # 3. Create Gemini client (OpenAI-compatible) and run agent loop + client = OpenAI( + api_key=google_key, + base_url="https://generativelanguage.googleapis.com/v1beta/openai/", + ) + messages: list[dict] = [ + {"role": "user", "content": "List my upcoming Calendly events for the next week."}, + ] + + for _step in range(10): + response = client.chat.completions.create( + model="gemini-3-pro-preview", + messages=messages, + tools=openai_tools, + tool_choice="auto", + ) + + choice = response.choices[0] + + # 4. If no tool calls, print final answer and stop + if not choice.message.tool_calls: + print(f"Answer: {choice.message.content}") + break + + # 5. Execute tool calls manually and feed results back + messages.append(choice.message.model_dump(exclude_none=True)) + for tool_call in choice.message.tool_calls: + print(f" -> {tool_call.function.name}({tool_call.function.arguments})") + result = toolset.execute(tool_call.function.name, tool_call.function.arguments) + messages.append( + { + "role": "tool", + "tool_call_id": tool_call.id, + "content": json.dumps(result), + } + ) + + print() + + +def example_langchain() -> None: + """Framework: LangChain with auto-execution. + + Shows: init toolset -> get LangChain tools -> bind to model -> framework executes tools. + No toolset.execute() needed — the framework calls _run() on tools automatically. + """ + print("=" * 60) + print("Example 2: LangChain — framework handles execution") + print("=" * 60) + print() + + try: + from langchain_core.messages import AIMessage, HumanMessage, ToolMessage + from langchain_google_genai import ChatGoogleGenerativeAI + except ImportError: + print("Skipped: pip install langchain-google-genai") + print() + return + + if not os.getenv("GOOGLE_API_KEY"): + print("Skipped: Set GOOGLE_API_KEY to run this example.") + print() + return + + # 1. Init toolset + account_id = os.getenv("STACKONE_ACCOUNT_ID") + toolset = StackOneToolSet( + account_id=account_id, + search={"method": "semantic", "top_k": 3}, + execute={"account_ids": [account_id]} if account_id else None, + ) + + # 2. Get tools in LangChain format and bind to model + langchain_tools = toolset.langchain(mode="search_and_execute") + tools_by_name = {tool.name: tool for tool in langchain_tools} + model = ChatGoogleGenerativeAI(model="gemini-3-pro-preview").bind_tools(langchain_tools) + + # 3. Run agent loop + messages = [HumanMessage(content="List my upcoming Calendly events for the next week.")] + + for _step in range(10): + response: AIMessage = model.invoke(messages) + + # 4. If no tool calls, print final answer and stop + if not response.tool_calls: + print(f"Answer: {response.content}") + break + + # 5. Framework-compatible execution — invoke LangChain tools directly + messages.append(response) + for tool_call in response.tool_calls: + print(f" -> {tool_call['name']}({json.dumps(tool_call['args'])})") + tool = tools_by_name[tool_call["name"]] + result = tool.invoke(tool_call["args"]) + messages.append(ToolMessage(content=json.dumps(result), tool_call_id=tool_call["id"])) + + print() + + +def main() -> None: + """Run all examples.""" + api_key = os.getenv("STACKONE_API_KEY") + if not api_key: + print("Set STACKONE_API_KEY to run these examples.") + return + + example_gemini() + example_langchain() + + +if __name__ == "__main__": + main() diff --git a/stackone_ai/__init__.py b/stackone_ai/__init__.py index f8fd6fb..b5ba7fd 100644 --- a/stackone_ai/__init__.py +++ b/stackone_ai/__init__.py @@ -7,12 +7,13 @@ SemanticSearchResponse, SemanticSearchResult, ) -from stackone_ai.toolset import SearchConfig, SearchMode, SearchTool, StackOneToolSet +from stackone_ai.toolset import ExecuteToolsConfig, SearchConfig, SearchMode, SearchTool, StackOneToolSet __all__ = [ "StackOneToolSet", "StackOneTool", "Tools", + "ExecuteToolsConfig", "SearchConfig", "SearchMode", "SearchTool", diff --git a/stackone_ai/models.py b/stackone_ai/models.py index aabc802..f38511d 100644 --- a/stackone_ai/models.py +++ b/stackone_ai/models.py @@ -422,6 +422,10 @@ def to_langchain(self) -> BaseTool: python_type = int elif type_str == "boolean": python_type = bool + elif type_str == "object": + python_type = dict + elif type_str == "array": + python_type = list field = Field(description=details.get("description", "")) else: diff --git a/stackone_ai/toolset.py b/stackone_ai/toolset.py index 998dbc0..815a06b 100644 --- a/stackone_ai/toolset.py +++ b/stackone_ai/toolset.py @@ -8,15 +8,19 @@ import logging import os import threading -from collections.abc import Coroutine +from collections.abc import Coroutine, Sequence from dataclasses import dataclass from importlib import metadata from typing import Any, Literal, TypedDict, TypeVar +from pydantic import BaseModel, Field, PrivateAttr, ValidationError, field_validator + from stackone_ai.constants import DEFAULT_BASE_URL from stackone_ai.models import ( ExecuteConfig, + JsonDict, ParameterLocation, + StackOneAPIError, StackOneTool, ToolParameters, Tools, @@ -52,6 +56,20 @@ class SearchConfig(TypedDict, total=False): """Minimum similarity score threshold 0-1.""" +class ExecuteToolsConfig(TypedDict, total=False): + """Execution configuration for the StackOneToolSet constructor. + + Controls default account scoping for tool execution. + + When set to ``None`` (default), no account scoping is applied. + When provided, ``account_ids`` flow through to ``openai(mode="search_and_execute")`` + and ``fetch_tools()`` as defaults. + """ + + account_ids: list[str] + """Account IDs to scope tool discovery and execution.""" + + _SEARCH_DEFAULT: SearchConfig = {"method": "auto"} try: @@ -68,6 +86,223 @@ class SearchConfig(TypedDict, total=False): _USER_AGENT = f"stackone-ai-python/{_SDK_VERSION}" +# --- Internal tool_search + tool_execute --- + + +class _SearchInput(BaseModel): + """Input validation for tool_search.""" + + query: str = Field(..., min_length=1) + connector: str | None = None + top_k: int | None = Field(default=None, ge=1, le=50) + + @field_validator("query") + @classmethod + def validate_query(cls, v: str) -> str: + trimmed = v.strip() + if not trimmed: + raise ValueError("query must be a non-empty string") + return trimmed + + +class _SearchTool(StackOneTool): + """LLM-callable tool that searches for available StackOne tools.""" + + _toolset: Any = PrivateAttr(default=None) + + def execute( + self, arguments: str | JsonDict | None = None, *, options: JsonDict | None = None + ) -> JsonDict: + try: + if isinstance(arguments, str): + raw_params = json.loads(arguments) + else: + raw_params = arguments or {} + + parsed = _SearchInput(**raw_params) + + search_config = self._toolset._search_config or {} + results = self._toolset.search_tools( + parsed.query, + connector=parsed.connector or search_config.get("connector"), + top_k=parsed.top_k or search_config.get("top_k") or 5, + min_similarity=search_config.get("min_similarity"), + search=search_config.get("method"), + account_ids=self._toolset._account_ids, + ) + + return { + "tools": [ + { + "name": t.name, + "description": t.description, + "parameters": t.parameters.properties, + } + for t in results + ], + "total": len(results), + "query": parsed.query, + } + except (json.JSONDecodeError, ValidationError) as exc: + return {"error": f"Invalid input: {exc}", "query": raw_params if "raw_params" in dir() else None} + + +class _ExecuteInput(BaseModel): + """Input validation for tool_execute.""" + + tool_name: str = Field(..., min_length=1) + parameters: dict[str, Any] = Field(default_factory=dict) + + @field_validator("tool_name") + @classmethod + def validate_tool_name(cls, v: str) -> str: + trimmed = v.strip() + if not trimmed: + raise ValueError("tool_name must be a non-empty string") + return trimmed + + +class _ExecuteTool(StackOneTool): + """LLM-callable tool that executes a StackOne tool by name.""" + + _toolset: Any = PrivateAttr(default=None) + _cached_tools: Any = PrivateAttr(default=None) + + def execute( + self, arguments: str | JsonDict | None = None, *, options: JsonDict | None = None + ) -> JsonDict: + tool_name = "unknown" + try: + if isinstance(arguments, str): + raw_params = json.loads(arguments) + else: + raw_params = arguments or {} + + parsed = _ExecuteInput(**raw_params) + tool_name = parsed.tool_name + + if self._cached_tools is None: + self._cached_tools = self._toolset.fetch_tools(account_ids=self._toolset._account_ids) + + target = self._cached_tools.get_tool(parsed.tool_name) + + if target is None: + return { + "error": ( + f'Tool "{parsed.tool_name}" not found. Use tool_search to find available tools.' + ), + } + + return target.execute(parsed.parameters, options=options) + except StackOneAPIError as exc: + return { + "error": str(exc), + "status_code": exc.status_code, + "response_body": exc.response_body, + "tool_name": tool_name, + } + except (json.JSONDecodeError, ValidationError) as exc: + return {"error": f"Invalid input: {exc}", "tool_name": tool_name} + + +def _create_search_tool(api_key: str) -> _SearchTool: + name = "tool_search" + description = ( + "Search for available tools by describing what you need. " + "Returns matching tool names, descriptions, and parameter schemas. " + "Use the returned parameter schemas to know exactly what to pass " + "when calling tool_execute." + ) + parameters = ToolParameters( + type="object", + properties={ + "query": { + "type": "string", + "description": ( + "Natural language description of what you need " + '(e.g. "create an employee", "list time off requests")' + ), + }, + "connector": { + "type": "string", + "description": 'Optional connector filter (e.g. "bamboohr")', + "nullable": True, + }, + "top_k": { + "type": "integer", + "description": "Max results to return (1-50, default 5)", + "minimum": 1, + "maximum": 50, + "nullable": True, + }, + }, + ) + execute_config = ExecuteConfig( + name=name, + method="POST", + url="local://meta/search", + parameter_locations={ + "query": ParameterLocation.BODY, + "connector": ParameterLocation.BODY, + "top_k": ParameterLocation.BODY, + }, + ) + + tool = _SearchTool.__new__(_SearchTool) + StackOneTool.__init__( + tool, + description=description, + parameters=parameters, + _execute_config=execute_config, + _api_key=api_key, + ) + return tool + + +def _create_execute_tool(api_key: str) -> _ExecuteTool: + name = "tool_execute" + description = ( + "Execute a tool by name with the given parameters. " + "Use tool_search first to find available tools. " + "The parameters field must match the parameter schema returned " + "by tool_search. Pass parameters as a nested object matching " + "the schema structure." + ) + parameters = ToolParameters( + type="object", + properties={ + "tool_name": { + "type": "string", + "description": "Exact tool name from tool_search results", + }, + "parameters": { + "type": "object", + "description": "Parameters for the tool, matching the schema from tool_search.", + "nullable": True, + }, + }, + ) + execute_config = ExecuteConfig( + name=name, + method="POST", + url="local://meta/execute", + parameter_locations={ + "tool_name": ParameterLocation.BODY, + "parameters": ParameterLocation.BODY, + }, + ) + + tool = _ExecuteTool.__new__(_ExecuteTool) + StackOneTool.__init__( + tool, + description=description, + parameters=parameters, + _execute_config=execute_config, + _api_key=api_key, + ) + return tool + + T = TypeVar("T") @@ -318,7 +553,8 @@ def __init__( api_key: str | None = None, account_id: str | None = None, base_url: str | None = None, - search: SearchConfig | None = _SEARCH_DEFAULT, + search: SearchConfig | None = None, + execute: ExecuteToolsConfig | None = None, ) -> None: """Initialize StackOne tools with authentication @@ -327,10 +563,14 @@ def __init__( account_id: Optional account ID base_url: Optional base URL override for API requests search: Search configuration. Controls default search behavior. - Omit or pass ``{}`` for defaults (method="auto"). - Pass ``None`` to disable search. + Pass ``None`` (default) to disable search — ``toolset.openai()`` + will return all regular tools. + Pass ``{}`` or ``{"method": "auto"}`` to enable search with defaults. Pass ``{"method": "semantic", "top_k": 5}`` for custom defaults. Per-call options always override these defaults. + execute: Execution configuration. Controls default account scoping + for tool execution. Pass ``{"account_ids": ["acc-1"]}`` to scope + meta tools to specific accounts. Raises: ToolsetConfigError: If no API key is provided or found in environment @@ -347,6 +587,8 @@ def __init__( self._account_ids: list[str] = [] self._semantic_client: SemanticSearchClient | None = None self._search_config: SearchConfig | None = search + self._execute_config: ExecuteToolsConfig | None = execute + self._tools_cache: Tools | None = None def set_accounts(self, account_ids: list[str]) -> StackOneToolSet: """Set account IDs for filtering tools @@ -393,6 +635,120 @@ def get_search_tool(self, *, search: SearchMode | None = None) -> SearchTool: return SearchTool(self, config=config) + def _build_tools(self, account_ids: list[str] | None = None) -> Tools: + """Build tool_search + tool_execute tools scoped to this toolset.""" + if self._search_config is None: + raise ToolsetConfigError( + "Search is disabled. Initialize StackOneToolSet with a search config to enable." + ) + + if account_ids: + self._account_ids = account_ids + + search_tool = _create_search_tool(self.api_key) + search_tool._toolset = self + + execute_tool = _create_execute_tool(self.api_key) + execute_tool._toolset = self + + return Tools([search_tool, execute_tool]) + + def openai( + self, + *, + mode: Literal["search_and_execute"] | None = None, + account_ids: list[str] | None = None, + ) -> list[dict[str, Any]]: + """Get tools in OpenAI function calling format. + + Args: + mode: Tool mode. + ``None`` (default): fetch all tools and convert to OpenAI format. + ``"search_and_execute"``: return two meta tools (tool_search + tool_execute) + that let the LLM discover and execute tools on-demand. + account_ids: Account IDs to scope tools. Overrides the ``execute`` + config from the constructor. + + Returns: + List of tool definitions in OpenAI function format. + + Examples:: + + # All tools + toolset = StackOneToolSet() + tools = toolset.openai() + + # Meta tools for agent-driven discovery + toolset = StackOneToolSet() + tools = toolset.openai(mode="search_and_execute") + """ + effective_account_ids = account_ids or ( + self._execute_config.get("account_ids") if self._execute_config else None + ) + + if mode == "search_and_execute": + return self._build_tools(account_ids=effective_account_ids).to_openai() + + return self.fetch_tools(account_ids=effective_account_ids).to_openai() + + def langchain( + self, + *, + mode: Literal["search_and_execute"] | None = None, + account_ids: list[str] | None = None, + ) -> Sequence[Any]: + """Get tools in LangChain format. + + Args: + mode: Tool mode. + ``None`` (default): fetch all tools and convert to LangChain format. + ``"search_and_execute"``: return two tools (tool_search + tool_execute) + that let the LLM discover and execute tools on-demand. + The framework handles tool execution automatically. + account_ids: Account IDs to scope tools. Overrides the ``execute`` + config from the constructor. + + Returns: + List of LangChain tool objects. + """ + effective_account_ids = account_ids or ( + self._execute_config.get("account_ids") if self._execute_config else None + ) + + if mode == "search_and_execute": + return self._build_tools(account_ids=effective_account_ids).to_langchain() + + return self.fetch_tools(account_ids=effective_account_ids).to_langchain() + + def execute( + self, + tool_name: str, + arguments: str | dict[str, Any] | None = None, + ) -> dict[str, Any]: + """Execute a tool by name. + + Use with ``openai(mode="search_and_execute")`` in manual agent loops — + pass the tool name and arguments from the LLM's tool call directly. + + Tools are cached after the first call. + + Args: + tool_name: The tool name from the LLM's tool call + (e.g. ``"tool_search"`` or ``"tool_execute"``). + arguments: The arguments from the LLM's tool call, + as a JSON string or dict. + + Returns: + Tool execution result as a dict. + """ + if self._tools_cache is None: + self._tools_cache = self._build_tools() + + tool = self._tools_cache.get_tool(tool_name) + if tool is None: + return {"error": f'Tool "{tool_name}" not found.'} + return tool.execute(arguments) + @property def semantic_client(self) -> SemanticSearchClient: """Lazy initialization of semantic search client. diff --git a/tests/test_agent_tools.py b/tests/test_agent_tools.py new file mode 100644 index 0000000..642d6b7 --- /dev/null +++ b/tests/test_agent_tools.py @@ -0,0 +1,536 @@ +"""Tests for tool_search + tool_execute (agent tool discovery).""" + +from __future__ import annotations + +import json +from unittest.mock import MagicMock, patch + +from stackone_ai.models import ( + ExecuteConfig, + StackOneAPIError, + StackOneTool, + ToolParameters, + Tools, +) +from stackone_ai.toolset import ( + StackOneToolSet, + _create_execute_tool, + _create_search_tool, + _ExecuteTool, + _SearchTool, +) + + +def _make_mock_tool(name: str = "test_tool", description: str = "A test tool") -> StackOneTool: + return StackOneTool( + description=description, + parameters=ToolParameters( + type="object", + properties={ + "id": {"type": "string", "description": "The ID"}, + "count": {"type": "integer", "description": "A count"}, + }, + ), + _execute_config=ExecuteConfig( + name=name, + method="GET", + url="http://localhost/test/{id}", + ), + _api_key="test-key", + ) + + +def _make_tools(toolset: MagicMock) -> Tools: + """Build tool_search + tool_execute using the private helpers, wiring in a mock toolset.""" + search_tool = _create_search_tool(toolset.api_key) + search_tool._toolset = toolset + + execute_tool = _create_execute_tool(toolset.api_key) + execute_tool._toolset = toolset + + return Tools([search_tool, execute_tool]) + + +def _make_mock_toolset(tools: list[StackOneTool] | None = None) -> MagicMock: + toolset = MagicMock() + toolset.api_key = "test-key" + toolset._search_config = {"method": "auto"} + toolset._account_ids = [] + + mock_tools = Tools(tools or [_make_mock_tool()]) + toolset.search_tools.return_value = mock_tools + toolset.fetch_tools.return_value = mock_tools + return toolset + + +class TestBuildMetaTools: + def test_returns_tools_collection(self): + toolset = _make_mock_toolset() + result = _make_tools(toolset) + + assert isinstance(result, Tools) + assert len(result) == 2 + + def test_tool_names(self): + toolset = _make_mock_toolset() + result = _make_tools(toolset) + + names = [t.name for t in result] + assert "tool_search" in names + assert "tool_execute" in names + + def test_search_tool_type(self): + toolset = _make_mock_toolset() + result = _make_tools(toolset) + search = result.get_tool("tool_search") + assert isinstance(search, _SearchTool) + + def test_execute_tool_type(self): + toolset = _make_mock_toolset() + result = _make_tools(toolset) + execute = result.get_tool("tool_execute") + assert isinstance(execute, _ExecuteTool) + + def test_private_attrs_excluded_from_serialization(self): + toolset = _make_mock_toolset() + result = _make_tools(toolset) + search = result.get_tool("tool_search") + + dumped = search.model_dump() + assert "_toolset" not in dumped + + +class TestToolSearch: + def test_delegates_to_search_tools(self): + toolset = _make_mock_toolset() + built = _make_tools(toolset) + search = built.get_tool("tool_search") + + search.execute({"query": "find employees"}) + + toolset.search_tools.assert_called_once() + call_args = toolset.search_tools.call_args + assert call_args[0][0] == "find employees" + + def test_returns_tool_names_descriptions_and_schemas(self): + mock_tool = _make_mock_tool(name="bamboohr_list_employees", description="List employees") + toolset = _make_mock_toolset([mock_tool]) + built = _make_tools(toolset) + search = built.get_tool("tool_search") + + result = search.execute({"query": "list employees"}) + + assert result["total"] == 1 + tool_info = result["tools"][0] + assert tool_info["name"] == "bamboohr_list_employees" + assert tool_info["description"] == "List employees" + assert "parameters" in tool_info + assert "id" in tool_info["parameters"] + + def test_reads_config_from_toolset(self): + toolset = _make_mock_toolset() + toolset._search_config = {"method": "semantic", "top_k": 3, "min_similarity": 0.5} + built = _make_tools(toolset) + search = built.get_tool("tool_search") + + search.execute({"query": "employees"}) + + call_kwargs = toolset.search_tools.call_args[1] + assert call_kwargs["search"] == "semantic" + assert call_kwargs["top_k"] == 3 + assert call_kwargs["min_similarity"] == 0.5 + + def test_reads_account_ids_from_toolset(self): + toolset = _make_mock_toolset() + toolset._account_ids = ["acc-1", "acc-2"] + built = _make_tools(toolset) + search = built.get_tool("tool_search") + + search.execute({"query": "employees"}) + + call_kwargs = toolset.search_tools.call_args[1] + assert call_kwargs["account_ids"] == ["acc-1", "acc-2"] + + def test_string_arguments(self): + toolset = _make_mock_toolset() + built = _make_tools(toolset) + search = built.get_tool("tool_search") + + result = search.execute(json.dumps({"query": "employees"})) + + assert "tools" in result + toolset.search_tools.assert_called_once() + + def test_validation_error_returns_error_dict(self): + toolset = _make_mock_toolset() + built = _make_tools(toolset) + search = built.get_tool("tool_search") + + result = search.execute({"query": ""}) + + assert "error" in result + toolset.search_tools.assert_not_called() + + def test_invalid_json_returns_error_dict(self): + toolset = _make_mock_toolset() + built = _make_tools(toolset) + search = built.get_tool("tool_search") + + result = search.execute("not valid json") + + assert "error" in result + + def test_missing_query_returns_error_dict(self): + toolset = _make_mock_toolset() + built = _make_tools(toolset) + search = built.get_tool("tool_search") + + result = search.execute({}) + + assert "error" in result + + +class TestToolExecute: + def test_delegates_to_fetch_and_execute(self): + toolset = MagicMock() + toolset.api_key = "test-key" + toolset._account_ids = [] + + mock_tool = MagicMock() + mock_tool.name = "test_tool" + mock_tools = MagicMock() + mock_tools.get_tool.return_value = mock_tool + mock_tool.execute.return_value = {"result": "ok"} + toolset.fetch_tools.return_value = mock_tools + + built = _make_tools(toolset) + execute = built.get_tool("tool_execute") + + result = execute.execute({"tool_name": "test_tool", "parameters": {"id": "123"}}) + + mock_tool.execute.assert_called_once() + assert result == {"result": "ok"} + + def test_tool_not_found_returns_error(self): + toolset = MagicMock() + toolset.api_key = "test-key" + toolset._account_ids = [] + mock_tools = MagicMock() + mock_tools.get_tool.return_value = None + toolset.fetch_tools.return_value = mock_tools + + built = _make_tools(toolset) + execute = built.get_tool("tool_execute") + + result = execute.execute({"tool_name": "nonexistent_tool"}) + + assert "error" in result + assert "not found" in result["error"] + + def test_api_error_returned_as_dict(self): + toolset = MagicMock() + toolset.api_key = "test-key" + toolset._account_ids = [] + + mock_tool = MagicMock() + mock_tool.name = "test_tool" + mock_tool.execute.side_effect = StackOneAPIError( + message="Bad Request", status_code=400, response_body="invalid" + ) + mock_tools = MagicMock() + mock_tools.get_tool.return_value = mock_tool + toolset.fetch_tools.return_value = mock_tools + + built = _make_tools(toolset) + execute = built.get_tool("tool_execute") + + result = execute.execute({"tool_name": "test_tool", "parameters": {}}) + + assert "error" in result + assert result["status_code"] == 400 + assert result["tool_name"] == "test_tool" + + def test_validation_error_returns_error_dict(self): + toolset = _make_mock_toolset() + built = _make_tools(toolset) + execute = built.get_tool("tool_execute") + + result = execute.execute({"tool_name": ""}) + + assert "error" in result + + def test_invalid_json_returns_error_dict(self): + toolset = _make_mock_toolset() + built = _make_tools(toolset) + execute = built.get_tool("tool_execute") + + result = execute.execute("not valid json") + + assert "error" in result + + def test_caches_fetched_tools(self): + toolset = MagicMock() + toolset.api_key = "test-key" + toolset._account_ids = [] + + mock_tool = MagicMock() + mock_tool.name = "test_tool" + mock_tool.execute.return_value = {"ok": True} + mock_tools = MagicMock() + mock_tools.get_tool.return_value = mock_tool + toolset.fetch_tools.return_value = mock_tools + + built = _make_tools(toolset) + execute = built.get_tool("tool_execute") + + execute.execute({"tool_name": "test_tool"}) + execute.execute({"tool_name": "test_tool"}) + + toolset.fetch_tools.assert_called_once() + + def test_passes_account_ids_from_toolset(self): + toolset = MagicMock() + toolset.api_key = "test-key" + toolset._account_ids = ["acc-1"] + + mock_tool = MagicMock() + mock_tool.name = "test_tool" + mock_tool.execute.return_value = {"ok": True} + mock_tools = MagicMock() + mock_tools.get_tool.return_value = mock_tool + toolset.fetch_tools.return_value = mock_tools + + built = _make_tools(toolset) + execute = built.get_tool("tool_execute") + + execute.execute({"tool_name": "test_tool"}) + + toolset.fetch_tools.assert_called_once_with(account_ids=["acc-1"]) + + def test_string_arguments(self): + toolset = MagicMock() + toolset.api_key = "test-key" + toolset._account_ids = [] + + mock_tool = MagicMock() + mock_tool.name = "test_tool" + mock_tool.execute.return_value = {"ok": True} + mock_tools = MagicMock() + mock_tools.get_tool.return_value = mock_tool + toolset.fetch_tools.return_value = mock_tools + + built = _make_tools(toolset) + execute = built.get_tool("tool_execute") + + result = execute.execute(json.dumps({"tool_name": "test_tool", "parameters": {}})) + + assert result == {"ok": True} + + +class TestLangChainConversion: + def test_tools_convert_to_langchain(self): + toolset = _make_mock_toolset() + built = _make_tools(toolset) + + langchain_tools = built.to_langchain() + + assert len(langchain_tools) == 2 + names = [t.name for t in langchain_tools] + assert "tool_search" in names + assert "tool_execute" in names + + def test_execute_tool_parameters_field_is_dict_type(self): + """The 'parameters' field of tool_execute should map to dict, not str.""" + toolset = _make_mock_toolset() + built = _make_tools(toolset) + execute_tool = built.get_tool("tool_execute") + + langchain_tool = execute_tool.to_langchain() + annotations = langchain_tool.args_schema.__annotations__ + + assert annotations["parameters"] is dict + + +class TestOpenAIConversion: + def test_tools_convert_to_openai(self): + toolset = _make_mock_toolset() + built = _make_tools(toolset) + + openai_tools = built.to_openai() + + assert len(openai_tools) == 2 + names = [t["function"]["name"] for t in openai_tools] + assert "tool_search" in names + assert "tool_execute" in names + + def test_nullable_fields_not_required(self): + toolset = _make_mock_toolset() + built = _make_tools(toolset) + + openai_tools = built.to_openai() + search_fn = next(t for t in openai_tools if t["function"]["name"] == "tool_search") + required = search_fn["function"]["parameters"].get("required", []) + + assert "query" in required + assert "connector" not in required + assert "top_k" not in required + + +class TestToolSetOpenAIMethod: + """Tests for StackOneToolSet.openai() convenience method.""" + + def test_openai_default_fetches_all_tools(self): + toolset = StackOneToolSet(api_key="test-key") + mock_tools = Tools([_make_mock_tool()]) + + with patch.object(toolset, "fetch_tools", return_value=mock_tools) as mock_fetch: + result = toolset.openai() + + mock_fetch.assert_called_once_with(account_ids=None) + assert len(result) == 1 + assert result[0]["function"]["name"] == "test_tool" + + def test_openai_search_and_execute_returns_tools(self): + toolset = StackOneToolSet(api_key="test-key") + mock_built = Tools([_make_mock_tool(name="tool_search"), _make_mock_tool(name="tool_execute")]) + + with patch.object(toolset, "_build_tools", return_value=mock_built) as mock_build: + result = toolset.openai(mode="search_and_execute") + + mock_build.assert_called_once_with(account_ids=None) + assert len(result) == 2 + names = [t["function"]["name"] for t in result] + assert "tool_search" in names + assert "tool_execute" in names + + def test_openai_passes_account_ids(self): + toolset = StackOneToolSet(api_key="test-key") + mock_tools = Tools([_make_mock_tool()]) + + with patch.object(toolset, "fetch_tools", return_value=mock_tools) as mock_fetch: + toolset.openai(account_ids=["acc-1"]) + + mock_fetch.assert_called_once_with(account_ids=["acc-1"]) + + def test_openai_uses_execute_config_account_ids(self): + toolset = StackOneToolSet(api_key="test-key", execute={"account_ids": ["acc-from-config"]}) + mock_tools = Tools([_make_mock_tool()]) + + with patch.object(toolset, "fetch_tools", return_value=mock_tools) as mock_fetch: + toolset.openai() + + mock_fetch.assert_called_once_with(account_ids=["acc-from-config"]) + + def test_openai_account_ids_overrides_execute_config(self): + toolset = StackOneToolSet(api_key="test-key", execute={"account_ids": ["from-config"]}) + mock_tools = Tools([_make_mock_tool()]) + + with patch.object(toolset, "fetch_tools", return_value=mock_tools) as mock_fetch: + toolset.openai(account_ids=["from-call"]) + + mock_fetch.assert_called_once_with(account_ids=["from-call"]) + + def test_openai_search_and_execute_with_execute_config(self): + toolset = StackOneToolSet(api_key="test-key", execute={"account_ids": ["acc-1"]}) + mock_built = Tools([_make_mock_tool(name="tool_search"), _make_mock_tool(name="tool_execute")]) + + with patch.object(toolset, "_build_tools", return_value=mock_built) as mock_build: + toolset.openai(mode="search_and_execute") + + mock_build.assert_called_once_with(account_ids=["acc-1"]) + + +class TestToolSetExecuteMethod: + """Tests for StackOneToolSet.execute() convenience method.""" + + def test_execute_delegates_to_tool(self): + toolset = StackOneToolSet(api_key="test-key", search={"method": "auto"}) + mock_tool = MagicMock() + mock_tool.execute.return_value = {"result": "ok"} + mock_built = MagicMock() + mock_built.get_tool.return_value = mock_tool + + with patch.object(toolset, "_build_tools", return_value=mock_built): + result = toolset.execute("tool_search", {"query": "employees"}) + + assert result == {"result": "ok"} + mock_built.get_tool.assert_called_once_with("tool_search") + mock_tool.execute.assert_called_once_with({"query": "employees"}) + + def test_execute_caches_tools(self): + toolset = StackOneToolSet(api_key="test-key", search={"method": "auto"}) + mock_tool = MagicMock() + mock_tool.execute.return_value = {"ok": True} + mock_built = MagicMock() + mock_built.get_tool.return_value = mock_tool + + with patch.object(toolset, "_build_tools", return_value=mock_built) as mock_build: + toolset.execute("tool_search", {"query": "a"}) + toolset.execute("tool_execute", {"tool_name": "b"}) + + mock_build.assert_called_once() + + def test_execute_returns_error_for_unknown_tool(self): + toolset = StackOneToolSet(api_key="test-key", search={"method": "auto"}) + mock_built = MagicMock() + mock_built.get_tool.return_value = None + + with patch.object(toolset, "_build_tools", return_value=mock_built): + result = toolset.execute("nonexistent", {}) + + assert "error" in result + + def test_execute_accepts_string_arguments(self): + toolset = StackOneToolSet(api_key="test-key", search={"method": "auto"}) + mock_tool = MagicMock() + mock_tool.execute.return_value = {"ok": True} + mock_built = MagicMock() + mock_built.get_tool.return_value = mock_tool + + with patch.object(toolset, "_build_tools", return_value=mock_built): + result = toolset.execute("tool_search", '{"query": "test"}') + + assert result == {"ok": True} + mock_tool.execute.assert_called_once_with('{"query": "test"}') + + +class TestToolSetLangChainMethod: + """Tests for StackOneToolSet.langchain() convenience method.""" + + def test_langchain_default_fetches_all_tools(self): + toolset = StackOneToolSet(api_key="test-key") + mock_tools = Tools([_make_mock_tool()]) + + with patch.object(toolset, "fetch_tools", return_value=mock_tools) as mock_fetch: + result = toolset.langchain() + + mock_fetch.assert_called_once_with(account_ids=None) + assert len(result) == 1 + + def test_langchain_search_and_execute_returns_tools(self): + toolset = StackOneToolSet(api_key="test-key") + mock_tools = Tools([_make_mock_tool(name="tool_search"), _make_mock_tool(name="tool_execute")]) + + with patch.object(toolset, "_build_tools", return_value=mock_tools) as mock_build: + result = toolset.langchain(mode="search_and_execute") + + mock_build.assert_called_once_with(account_ids=None) + assert len(result) == 2 + + def test_langchain_passes_account_ids(self): + toolset = StackOneToolSet(api_key="test-key") + mock_tools = Tools([_make_mock_tool()]) + + with patch.object(toolset, "fetch_tools", return_value=mock_tools) as mock_fetch: + toolset.langchain(account_ids=["acc-1"]) + + mock_fetch.assert_called_once_with(account_ids=["acc-1"]) + + def test_langchain_uses_execute_config_account_ids(self): + toolset = StackOneToolSet(api_key="test-key", execute={"account_ids": ["acc-from-config"]}) + mock_tools = Tools([_make_mock_tool()]) + + with patch.object(toolset, "fetch_tools", return_value=mock_tools) as mock_fetch: + toolset.langchain() + + mock_fetch.assert_called_once_with(account_ids=["acc-from-config"]) diff --git a/tests/test_semantic_search.py b/tests/test_semantic_search.py index 13bef94..8f397fc 100644 --- a/tests/test_semantic_search.py +++ b/tests/test_semantic_search.py @@ -304,7 +304,7 @@ def test_toolset_search_tools( ), ] - toolset = StackOneToolSet(api_key="test-key") + toolset = StackOneToolSet(api_key="test-key", search={"method": "auto"}) tools = toolset.search_tools("create employee", top_k=5) # Should only return tools for available connectors (bamboohr, hibob) @@ -352,7 +352,7 @@ def test_toolset_search_tools_fallback( ), ] - toolset = StackOneToolSet(api_key="test-key") + toolset = StackOneToolSet(api_key="test-key", search={"method": "auto"}) tools = toolset.search_tools("create employee", top_k=5, search="auto") # Should return results from the local BM25+TF-IDF fallback @@ -394,7 +394,7 @@ def test_toolset_search_tools_fallback_respects_connector( ), ] - toolset = StackOneToolSet(api_key="test-key") + toolset = StackOneToolSet(api_key="test-key", search={"method": "auto"}) tools = toolset.search_tools("create employee", connector="bamboohr", search="auto") assert len(tools) > 0 @@ -423,7 +423,7 @@ def test_toolset_search_tools_fallback_disabled( ), ] - toolset = StackOneToolSet(api_key="test-key") + toolset = StackOneToolSet(api_key="test-key", search={"method": "auto"}) with pytest.raises(SemanticSearchError): toolset.search_tools("create employee", search="semantic") @@ -458,7 +458,7 @@ def test_toolset_search_action_names( query="create employee", ) - toolset = StackOneToolSet(api_key="test-key") + toolset = StackOneToolSet(api_key="test-key", search={"method": "auto"}) results = toolset.search_action_names("create employee", min_similarity=0.5) # min_similarity is passed to server; mock returns both results @@ -511,7 +511,7 @@ def test_local_mode_skips_semantic_api( ), ] - toolset = StackOneToolSet(api_key="test-key") + toolset = StackOneToolSet(api_key="test-key", search={"method": "auto"}) tools = toolset.search_tools("create employee", top_k=5, search="local") assert len(tools) > 0 @@ -537,7 +537,7 @@ def test_semantic_mode_raises_on_failure( ), ] - toolset = StackOneToolSet(api_key="test-key") + toolset = StackOneToolSet(api_key="test-key", search={"method": "auto"}) with pytest.raises(SemanticSearchError): toolset.search_tools("create employee", search="semantic") @@ -561,7 +561,7 @@ def test_auto_mode_falls_back_to_local( ), ] - toolset = StackOneToolSet(api_key="test-key") + toolset = StackOneToolSet(api_key="test-key", search={"method": "auto"}) tools = toolset.search_tools("create employee", top_k=5, search="auto") assert len(tools) > 0 @@ -585,7 +585,7 @@ def test_search_tool_passes_search_mode( ), ] - toolset = StackOneToolSet(api_key="test-key") + toolset = StackOneToolSet(api_key="test-key", search={"method": "auto"}) search_tool = toolset.get_search_tool(search="local") tools = search_tool("list employees", top_k=5) @@ -752,7 +752,7 @@ def _search_side_effect( ), ] - toolset = StackOneToolSet(api_key="test-key") + toolset = StackOneToolSet(api_key="test-key", search={"method": "auto"}) results = toolset.search_action_names( "create employee", account_ids=["acc-123"], @@ -776,7 +776,7 @@ def test_search_action_names_returns_empty_on_failure(self, mock_search: MagicMo mock_search.side_effect = SemanticSearchError("API unavailable") - toolset = StackOneToolSet(api_key="test-key") + toolset = StackOneToolSet(api_key="test-key", search={"method": "auto"}) results = toolset.search_action_names("create employee") assert results == [] @@ -808,7 +808,7 @@ def test_searches_all_connectors_in_parallel(self, mock_fetch: MagicMock, mock_s ), ] - toolset = StackOneToolSet(api_key="test-key") + toolset = StackOneToolSet(api_key="test-key", search={"method": "auto"}) toolset.search_action_names( "test", account_ids=["acc-123"], @@ -855,7 +855,7 @@ def test_respects_top_k_after_filtering(self, mock_fetch: MagicMock, mock_search ), ] - toolset = StackOneToolSet(api_key="test-key") + toolset = StackOneToolSet(api_key="test-key", search={"method": "auto"}) results = toolset.search_action_names( "test", account_ids=["acc-123"], @@ -967,7 +967,7 @@ def test_search_tools_deduplicates_versions(self, mock_fetch: MagicMock, mock_se ), ] - toolset = StackOneToolSet(api_key="test-key") + toolset = StackOneToolSet(api_key="test-key", search={"method": "auto"}) tools = toolset.search_tools("list employees", top_k=5) # Should deduplicate: both breathehr versions -> breathehr_list_employees @@ -1002,7 +1002,7 @@ def test_search_action_names_normalizes_versions(self, mock_search: MagicMock) - query="list employees", ) - toolset = StackOneToolSet(api_key="test-key") + toolset = StackOneToolSet(api_key="test-key", search={"method": "auto"}) results = toolset.search_action_names("list employees", top_k=5) # Both results are returned with normalized names (no dedup in global path)