diff --git "a/1_\360\237\217\240_Home.py" "b/1_\360\237\217\240_Home.py" index 795f271..9a920b0 100644 --- "a/1_\360\237\217\240_Home.py" +++ "b/1_\360\237\217\240_Home.py" @@ -1,7 +1,7 @@ import streamlit as st from streamlit_pills import pills -from agent_utils import ( +from core.agent_builder import ( load_meta_agent_and_tools, load_agent_ids_from_directory, ) diff --git a/README.md b/README.md index c83bbd0..6b22ee3 100644 --- a/README.md +++ b/README.md @@ -83,7 +83,7 @@ It will be able to pick the right RAG tools (either top-k vector search or optio ### Builder Agent -By default the builder agent uses OpenAI. This is defined in the `builder_config.py` file. +By default the builder agent uses OpenAI. This is defined in the `core/builder_config.py` file. You can customize this to whatever LLM you want (an example is provided for Anthropic). diff --git a/agent_utils.py b/agent_utils.py deleted file mode 100644 index c6ff234..0000000 --- a/agent_utils.py +++ /dev/null @@ -1,786 +0,0 @@ -from llama_index.llms import OpenAI, ChatMessage, Anthropic, Replicate -from llama_index.llms.base import LLM -from llama_index.llms.utils import resolve_llm -from pydantic import BaseModel, Field -import os -from llama_index.agent import OpenAIAgent, ReActAgent -from llama_index.agent.react.prompts import REACT_CHAT_SYSTEM_HEADER -from llama_index import ( - VectorStoreIndex, - SummaryIndex, - ServiceContext, - StorageContext, - Document, - load_index_from_storage, -) -from llama_index.prompts import ChatPromptTemplate -from typing import List, cast, Optional -from llama_index import SimpleDirectoryReader -from llama_index.embeddings.utils import resolve_embed_model -from llama_index.tools import QueryEngineTool, ToolMetadata, FunctionTool -from llama_index.agent.types import BaseAgent -from llama_index.chat_engine.types import BaseChatEngine -from llama_index.agent.react.formatter import ReActChatFormatter -from llama_index.llms.openai_utils import is_function_calling_model -from llama_index.chat_engine import CondensePlusContextChatEngine -from builder_config import BUILDER_LLM -from typing import Dict, Tuple, Any, Callable -import streamlit as st -from pathlib import Path -import json -import uuid -from constants import AGENT_CACHE_DIR -import shutil - -from llama_index.callbacks import CallbackManager -from callback_manager import StreamlitFunctionsCallbackHandler - - -def _resolve_llm(llm_str: str) -> LLM: - """Resolve LLM.""" - # TODO: make this less hardcoded with if-else statements - # see if there's a prefix - # - if there isn't, assume it's an OpenAI model - # - if there is, resolve it - tokens = llm_str.split(":") - if len(tokens) == 1: - os.environ["OPENAI_API_KEY"] = st.secrets.openai_key - llm: LLM = OpenAI(model=llm_str) - elif tokens[0] == "local": - llm = resolve_llm(llm_str) - elif tokens[0] == "openai": - os.environ["OPENAI_API_KEY"] = st.secrets.openai_key - llm = OpenAI(model=tokens[1]) - elif tokens[0] == "anthropic": - os.environ["ANTHROPIC_API_KEY"] = st.secrets.anthropic_key - llm = Anthropic(model=tokens[1]) - elif tokens[0] == "replicate": - os.environ["REPLICATE_API_KEY"] = st.secrets.replicate_key - llm = Replicate(model=tokens[1]) - else: - raise ValueError(f"LLM {llm_str} not recognized.") - return llm - - -#################### -#### META TOOLS #### -#################### - - -# System prompt tool -GEN_SYS_PROMPT_STR = """\ -Task information is given below. 
- -Given the task, please generate a system prompt for an OpenAI-powered bot \ -to solve this task: -{task} \ - -Make sure the system prompt obeys the following requirements: -- Tells the bot to ALWAYS use tools given to solve the task. \ -NEVER give an answer without using a tool. -- Does not reference a specific data source. \ -The data source is implicit in any queries to the bot, \ -and telling the bot to analyze a specific data source might confuse it given a \ -user query. - -""" - -gen_sys_prompt_messages = [ - ChatMessage( - role="system", - content="You are helping to build a system prompt for another bot.", - ), - ChatMessage(role="user", content=GEN_SYS_PROMPT_STR), -] - -GEN_SYS_PROMPT_TMPL = ChatPromptTemplate(gen_sys_prompt_messages) - - -class RAGParams(BaseModel): - """RAG parameters. - - Parameters used to configure a RAG pipeline. - - """ - - include_summarization: bool = Field( - default=False, - description=( - "Whether to include summarization in the RAG pipeline. (only for GPT-4)" - ), - ) - top_k: int = Field( - default=2, description="Number of documents to retrieve from vector store." - ) - chunk_size: int = Field(default=1024, description="Chunk size for vector store.") - embed_model: str = Field( - default="default", description="Embedding model to use (default is OpenAI)" - ) - llm: str = Field( - default="gpt-4-1106-preview", description="LLM to use for summarization." - ) - - -def load_data( - file_names: Optional[List[str]] = None, urls: Optional[List[str]] = None -) -> List[Document]: - """Load data.""" - file_names = file_names or [] - urls = urls or [] - if not file_names and not urls: - raise ValueError("Must specify either file_names or urls.") - elif file_names and urls: - raise ValueError("Must specify only one of file_names or urls.") - elif file_names: - reader = SimpleDirectoryReader(input_files=file_names) - docs = reader.load_data() - elif urls: - from llama_hub.web.simple_web.base import SimpleWebPageReader - - # use simple web page reader from llamahub - loader = SimpleWebPageReader() - docs = loader.load_data(urls=urls) - else: - raise ValueError("Must specify either file_names or urls.") - - return docs - - -def load_agent( - tools: List, - llm: LLM, - system_prompt: str, - extra_kwargs: Optional[Dict] = None, - **kwargs: Any, -) -> BaseChatEngine: - """Load agent.""" - extra_kwargs = extra_kwargs or {} - if isinstance(llm, OpenAI) and is_function_calling_model(llm.model): - # TODO: use default msg handler - # TODO: separate this from agent_utils.py... - def _msg_handler(msg: str) -> None: - """Message handler.""" - st.info(msg) - st.session_state.agent_messages.append( - {"role": "assistant", "content": msg, "msg_type": "info"} - ) - - # add streamlit callbacks (to inject events) - handler = StreamlitFunctionsCallbackHandler(_msg_handler) - callback_manager = CallbackManager([handler]) - # get OpenAI Agent - agent: BaseChatEngine = OpenAIAgent.from_tools( - tools=tools, - llm=llm, - system_prompt=system_prompt, - **kwargs, - callback_manager=callback_manager, - ) - else: - if "vector_index" not in extra_kwargs: - raise ValueError( - "Must pass in vector index for CondensePlusContextChatEngine." 
- ) - vector_index = cast(VectorStoreIndex, extra_kwargs["vector_index"]) - rag_params = cast(RAGParams, extra_kwargs["rag_params"]) - # use condense + context chat engine - agent = CondensePlusContextChatEngine.from_defaults( - vector_index.as_retriever(similarity_top_k=rag_params.top_k), - ) - - return agent - - -def load_meta_agent( - tools: List, - llm: LLM, - system_prompt: str, - extra_kwargs: Optional[Dict] = None, - **kwargs: Any, -) -> BaseAgent: - """Load meta agent. - - TODO: consolidate with load_agent. - - The meta-agent *has* to perform tool-use. - - """ - extra_kwargs = extra_kwargs or {} - if isinstance(llm, OpenAI) and is_function_calling_model(llm.model): - # get OpenAI Agent - - agent: BaseAgent = OpenAIAgent.from_tools( - tools=tools, - llm=llm, - system_prompt=system_prompt, - **kwargs, - ) - else: - agent = ReActAgent.from_tools( - tools=tools, - llm=llm, - react_chat_formatter=ReActChatFormatter( - system_header=system_prompt + "\n" + REACT_CHAT_SYSTEM_HEADER, - ), - **kwargs, - ) - - return agent - - -def construct_agent( - system_prompt: str, - rag_params: RAGParams, - docs: List[Document], - vector_index: Optional[VectorStoreIndex] = None, - additional_tools: Optional[List] = None, -) -> Tuple[BaseChatEngine, Dict]: - """Construct agent from docs / parameters / indices.""" - extra_info = {} - additional_tools = additional_tools or [] - - # first resolve llm and embedding model - embed_model = resolve_embed_model(rag_params.embed_model) - # llm = resolve_llm(rag_params.llm) - # TODO: use OpenAI for now - # llm = OpenAI(model=rag_params.llm) - llm = _resolve_llm(rag_params.llm) - - # first let's index the data with the right parameters - service_context = ServiceContext.from_defaults( - chunk_size=rag_params.chunk_size, - llm=llm, - embed_model=embed_model, - ) - - if vector_index is None: - vector_index = VectorStoreIndex.from_documents( - docs, service_context=service_context - ) - else: - pass - - extra_info["vector_index"] = vector_index - - vector_query_engine = vector_index.as_query_engine( - similarity_top_k=rag_params.top_k - ) - all_tools = [] - vector_tool = QueryEngineTool( - query_engine=vector_query_engine, - metadata=ToolMetadata( - name="vector_tool", - description=("Use this tool to answer any user question over any data."), - ), - ) - all_tools.append(vector_tool) - if rag_params.include_summarization: - summary_index = SummaryIndex.from_documents( - docs, service_context=service_context - ) - summary_query_engine = summary_index.as_query_engine() - summary_tool = QueryEngineTool( - query_engine=summary_query_engine, - metadata=ToolMetadata( - name="summary_tool", - description=( - "Use this tool for any user questions that ask " - "for a summarization of content" - ), - ), - ) - all_tools.append(summary_tool) - - # then we add tools - all_tools.extend(additional_tools) - - # build agent - if system_prompt is None: - return "System prompt not set yet. Please set system prompt first." - - agent = load_agent( - all_tools, - llm=llm, - system_prompt=system_prompt, - verbose=True, - extra_kwargs={"vector_index": vector_index, "rag_params": rag_params}, - ) - return agent, extra_info - - -def get_web_agent_tool() -> QueryEngineTool: - """Get web agent tool. - - Wrap with our load and search tool spec. 
- - """ - from llama_hub.tools.metaphor.base import MetaphorToolSpec - - # TODO: set metaphor API key - metaphor_tool = MetaphorToolSpec( - api_key=st.secrets.metaphor_key, - ) - metaphor_tool_list = metaphor_tool.to_tool_list() - - # TODO: LoadAndSearch doesn't work yet - # The search_and_retrieve_documents tool is the third in the tool list, - # as seen above - # wrapped_retrieve = LoadAndSearchToolSpec.from_defaults( - # metaphor_tool_list[2], - # ) - - # NOTE: requires openai right now - # We don't give the Agent our unwrapped retrieve document tools - # instead passing the wrapped tools - web_agent = OpenAIAgent.from_tools( - # [*wrapped_retrieve.to_tool_list(), metaphor_tool_list[4]], - metaphor_tool_list, - llm=BUILDER_LLM, - verbose=True, - ) - - # return agent as a tool - # TODO: tune description - web_agent_tool = QueryEngineTool.from_defaults( - web_agent, - name="web_agent", - description=""" - This agent can answer questions by searching the web. \ -Use this tool if the answer is ONLY likely to be found by searching \ -the internet, especially for queries about recent events. - """, - ) - - return web_agent_tool - - -def get_tool_objects(tool_names: List[str]) -> List: - """Get tool objects from tool names.""" - # construct additional tools - tool_objs = [] - for tool_name in tool_names: - if tool_name == "web_search": - # build web agent - tool_objs.append(get_web_agent_tool()) - else: - raise ValueError(f"Tool {tool_name} not recognized.") - - return tool_objs - - -class ParamCache(BaseModel): - """Cache for RAG agent builder. - - Created a wrapper class around a dict in case we wanted to more explicitly - type different items in the cache. - - """ - - # arbitrary types - class Config: - arbitrary_types_allowed = True - - # system prompt - system_prompt: Optional[str] = Field( - default=None, description="System prompt for RAG agent." - ) - # data - file_names: List[str] = Field( - default_factory=list, description="File names as data source (if specified)" - ) - urls: List[str] = Field( - default_factory=list, description="URLs as data source (if specified)" - ) - docs: List = Field(default_factory=list, description="Documents for RAG agent.") - # tools - tools: List = Field( - default_factory=list, description="Additional tools for RAG agent (e.g. web)" - ) - # RAG params - rag_params: RAGParams = Field( - default_factory=RAGParams, description="RAG parameters for RAG agent." - ) - - # agent params - vector_index: Optional[VectorStoreIndex] = Field( - default=None, description="Vector index for RAG agent." 
- ) - agent_id: str = Field( - default_factory=lambda: f"Agent_{str(uuid.uuid4())}", - description="Agent ID for RAG agent.", - ) - agent: Optional[BaseChatEngine] = Field(default=None, description="RAG agent.") - - def save_to_disk(self, save_dir: str) -> None: - """Save cache to disk.""" - # NOTE: more complex than just calling dict() because we want to - # only store serializable fields and be space-efficient - - dict_to_serialize = { - "system_prompt": self.system_prompt, - "file_names": self.file_names, - "urls": self.urls, - # TODO: figure out tools - "tools": self.tools, - "rag_params": self.rag_params.dict(), - "agent_id": self.agent_id, - } - # store the vector store within the agent - if self.vector_index is None: - raise ValueError("Must specify vector index in order to save.") - self.vector_index.storage_context.persist(Path(save_dir) / "storage") - - # if save_path directories don't exist, create it - if not Path(save_dir).exists(): - Path(save_dir).mkdir(parents=True) - with open(Path(save_dir) / "cache.json", "w") as f: - json.dump(dict_to_serialize, f) - - @classmethod - def load_from_disk( - cls, - save_dir: str, - ) -> "ParamCache": - """Load cache from disk.""" - storage_context = StorageContext.from_defaults( - persist_dir=str(Path(save_dir) / "storage") - ) - vector_index = cast(VectorStoreIndex, load_index_from_storage(storage_context)) - - with open(Path(save_dir) / "cache.json", "r") as f: - cache_dict = json.load(f) - - # replace rag params with RAGParams object - cache_dict["rag_params"] = RAGParams(**cache_dict["rag_params"]) - - # add in the missing fields - # load docs - cache_dict["docs"] = load_data( - file_names=cache_dict["file_names"], urls=cache_dict["urls"] - ) - # load agent from index - additional_tools = get_tool_objects(cache_dict["tools"]) - agent, _ = construct_agent( - cache_dict["system_prompt"], - cache_dict["rag_params"], - cache_dict["docs"], - vector_index=vector_index, - additional_tools=additional_tools, - # TODO: figure out tools - ) - cache_dict["vector_index"] = vector_index - cache_dict["agent"] = agent - - return cls(**cache_dict) - - -def add_agent_id_to_directory(dir: str, agent_id: str) -> None: - """Save agent id to directory.""" - full_path = Path(dir) / "agent_ids.json" - if not full_path.exists(): - with open(full_path, "w") as f: - json.dump({"agent_ids": [agent_id]}, f) - else: - with open(full_path, "r") as f: - agent_ids = json.load(f)["agent_ids"] - if agent_id in agent_ids: - raise ValueError(f"Agent id {agent_id} already exists.") - agent_ids_set = set(agent_ids) - agent_ids_set.add(agent_id) - with open(full_path, "w") as f: - json.dump({"agent_ids": list(agent_ids_set)}, f) - - -def load_agent_ids_from_directory(dir: str) -> List[str]: - """Load agent ids file.""" - full_path = Path(dir) / "agent_ids.json" - if not full_path.exists(): - return [] - with open(full_path, "r") as f: - agent_ids = json.load(f)["agent_ids"] - - return agent_ids - - -def load_cache_from_directory( - dir: str, - agent_id: str, -) -> ParamCache: - """Load cache from directory.""" - full_path = Path(dir) / f"{agent_id}" - if not full_path.exists(): - raise ValueError(f"Cache for agent {agent_id} does not exist.") - cache = ParamCache.load_from_disk(str(full_path)) - return cache - - -def remove_agent_from_directory( - dir: str, - agent_id: str, -) -> None: - """Remove agent from directory.""" - - # modify / resave agent_ids - agent_ids = load_agent_ids_from_directory(dir) - new_agent_ids = [id for id in agent_ids if id != agent_id] - full_path = 
Path(dir) / "agent_ids.json" - with open(full_path, "w") as f: - json.dump({"agent_ids": new_agent_ids}, f) - - # remove agent cache - full_path = Path(dir) / f"{agent_id}" - if full_path.exists(): - # recursive delete - shutil.rmtree(full_path) - - -class RAGAgentBuilder: - """RAG Agent builder. - - Contains a set of functions to construct a RAG agent, including: - - setting system prompts - - loading data - - adding web search - - setting parameters (e.g. top-k) - - Must pass in a cache. This cache will be modified as the agent is built. - - """ - - def __init__( - self, cache: Optional[ParamCache] = None, cache_dir: Optional[str] = None - ) -> None: - """Init params.""" - self._cache = cache or ParamCache() - self._cache_dir = cache_dir or AGENT_CACHE_DIR - - @property - def cache(self) -> ParamCache: - """Cache.""" - return self._cache - - def create_system_prompt(self, task: str) -> str: - """Create system prompt for another agent given an input task.""" - llm = BUILDER_LLM - fmt_messages = GEN_SYS_PROMPT_TMPL.format_messages(task=task) - response = llm.chat(fmt_messages) - self._cache.system_prompt = response.message.content - - return f"System prompt created: {response.message.content}" - - def load_data( - self, file_names: Optional[List[str]] = None, urls: Optional[List[str]] = None - ) -> str: - """Load data for a given task. - - Only ONE of file_names or urls should be specified. - - Args: - file_names (Optional[List[str]]): List of file names to load. - Defaults to None. - urls (Optional[List[str]]): List of urls to load. - Defaults to None. - - """ - file_names = file_names or [] - urls = urls or [] - docs = load_data(file_names=file_names, urls=urls) - self._cache.docs = docs - self._cache.file_names = file_names - self._cache.urls = urls - return "Data loaded successfully." - - def add_web_tool(self) -> str: - """Add a web tool to enable agent to solve a task.""" - # TODO: make this not hardcoded to a web tool - # Set up Metaphor tool - if "web_search" in self._cache.tools: - return "Web tool already added." - else: - self._cache.tools.append("web_search") - return "Web tool added successfully." - - def get_rag_params(self) -> Dict: - """Get parameters used to configure the RAG pipeline. - - Should be called before `set_rag_params` so that the agent is aware of the - schema. - - """ - rag_params = self._cache.rag_params - return rag_params.dict() - - def set_rag_params(self, **rag_params: Dict) -> str: - """Set RAG parameters. - - These parameters will then be used to actually initialize the agent. - Should call `get_rag_params` first to get the schema of the input dictionary. - - Args: - **rag_params (Dict): dictionary of RAG parameters. - - """ - new_dict = self._cache.rag_params.dict() - new_dict.update(rag_params) - rag_params_obj = RAGParams(**new_dict) - self._cache.rag_params = rag_params_obj - return "RAG parameters set successfully." - - def create_agent(self, agent_id: Optional[str] = None) -> str: - """Create an agent. - - There are no parameters for this function because all the - functions should have already been called to set up the agent. 
- - """ - if self._cache.system_prompt is None: - raise ValueError("Must set system prompt before creating agent.") - - # construct additional tools - additional_tools = get_tool_objects(self.cache.tools) - agent, extra_info = construct_agent( - cast(str, self._cache.system_prompt), - cast(RAGParams, self._cache.rag_params), - self._cache.docs, - additional_tools=additional_tools, - ) - - # if agent_id not specified, randomly generate one - agent_id = agent_id or self._cache.agent_id or f"Agent_{str(uuid.uuid4())}" - self._cache.vector_index = extra_info["vector_index"] - self._cache.agent_id = agent_id - self._cache.agent = agent - - # save the cache to disk - agent_cache_path = f"{self._cache_dir}/{agent_id}" - self._cache.save_to_disk(agent_cache_path) - # save to agent ids - add_agent_id_to_directory(str(self._cache_dir), agent_id) - - return "Agent created successfully." - - def update_agent( - self, - agent_id: str, - system_prompt: Optional[str] = None, - include_summarization: Optional[bool] = None, - top_k: Optional[int] = None, - chunk_size: Optional[int] = None, - embed_model: Optional[str] = None, - llm: Optional[str] = None, - additional_tools: Optional[List] = None, - ) -> None: - """Update agent. - - Delete old agent by ID and create a new one. - Optionally update the system prompt and RAG parameters. - - NOTE: Currently is manually called, not meant for agent use. - - """ - # remove saved agent from directory, since we'll be re-saving - remove_agent_from_directory(str(AGENT_CACHE_DIR), self.cache.agent_id) - - # set agent id - self.cache.agent_id = agent_id - - # set system prompt - if system_prompt is not None: - self.cache.system_prompt = system_prompt - # get agent_builder - # We call set_rag_params and create_agent, which will - # update the cache - # TODO: decouple functions from tool functions exposed to the agent - rag_params_dict: Dict[str, Any] = {} - if include_summarization is not None: - rag_params_dict["include_summarization"] = include_summarization - if top_k is not None: - rag_params_dict["top_k"] = top_k - if chunk_size is not None: - rag_params_dict["chunk_size"] = chunk_size - if embed_model is not None: - rag_params_dict["embed_model"] = embed_model - if llm is not None: - rag_params_dict["llm"] = llm - - self.set_rag_params(**rag_params_dict) - - # update tools - if additional_tools is not None: - self.cache.tools = additional_tools - - # this will update the agent in the cache - self.create_agent() - - -#################### -#### META Agent #### -#################### - -RAG_BUILDER_SYS_STR = """\ -You are helping to construct an agent given a user-specified task. -You should generally use the tools in this rough order to build the agent. - -1) Create system prompt tool: to create the system prompt for the agent. -2) Load in user-specified data (based on file paths they specify). -3) Decide whether or not to add additional tools. -4) Set parameters for the RAG pipeline. -5) Build the agent - -This will be a back and forth conversation with the user. You should -continue asking users if there's anything else they want to do until -they say they're done. To help guide them on the process, -you can give suggestions on parameters they can set based on the tools they -have available (e.g. 
"Do you want to set the number of documents to retrieve?") - -""" - - -### DEFINE Agent #### -# NOTE: here we define a function that is dependent on the LLM, -# please make sure to update the LLM above if you change the function below - - -def _get_builder_agent_tools(agent_builder: RAGAgentBuilder) -> List[FunctionTool]: - """Get list of builder agent tools to pass to the builder agent.""" - # see if metaphor api key is set, otherwise don't add web tool - # TODO: refactor this later - - if "metaphor_key" in st.secrets: - fns: List[Callable] = [ - agent_builder.create_system_prompt, - agent_builder.load_data, - agent_builder.add_web_tool, - agent_builder.get_rag_params, - agent_builder.set_rag_params, - agent_builder.create_agent, - ] - else: - fns = [ - agent_builder.create_system_prompt, - agent_builder.load_data, - agent_builder.get_rag_params, - agent_builder.set_rag_params, - agent_builder.create_agent, - ] - - fn_tools: List[FunctionTool] = [FunctionTool.from_defaults(fn=fn) for fn in fns] - return fn_tools - - -# define agent -# @st.cache_resource -def load_meta_agent_and_tools( - cache: Optional[ParamCache] = None, -) -> Tuple[BaseAgent, RAGAgentBuilder]: - - # think of this as tools for the agent to use - agent_builder = RAGAgentBuilder(cache) - - fn_tools = _get_builder_agent_tools(agent_builder) - - builder_agent = load_meta_agent( - fn_tools, llm=BUILDER_LLM, system_prompt=RAG_BUILDER_SYS_STR, verbose=True - ) - - return builder_agent, agent_builder diff --git a/core/__init__.py b/core/__init__.py new file mode 100644 index 0000000..3a5547f --- /dev/null +++ b/core/__init__.py @@ -0,0 +1 @@ +"""Init file.""" \ No newline at end of file diff --git a/core/agent_builder.py b/core/agent_builder.py new file mode 100644 index 0000000..df0da44 --- /dev/null +++ b/core/agent_builder.py @@ -0,0 +1,373 @@ +"""Agent builder.""" + +from llama_index.llms import OpenAI, ChatMessage, Anthropic, Replicate +from llama_index.llms.base import LLM +from llama_index.llms.utils import resolve_llm +from pydantic import BaseModel, Field +from llama_index.prompts import ChatPromptTemplate +from typing import List, cast, Optional +from llama_index import SimpleDirectoryReader +from llama_index.embeddings.utils import resolve_embed_model +from llama_index.tools import QueryEngineTool, ToolMetadata, FunctionTool +from llama_index.agent.types import BaseAgent +from llama_index.chat_engine.types import BaseChatEngine +from llama_index.agent.react.formatter import ReActChatFormatter +from llama_index.llms.openai_utils import is_function_calling_model +from llama_index.chat_engine import CondensePlusContextChatEngine +from core.builder_config import BUILDER_LLM +from typing import Dict, Tuple, Any, Callable +import streamlit as st +from pathlib import Path +import json +import uuid +from core.constants import AGENT_CACHE_DIR +import shutil + +from llama_index.callbacks import CallbackManager +from callback_manager import StreamlitFunctionsCallbackHandler +from core.param_cache import ParamCache, RAGParams +from core.utils import ( + load_agent, load_data, get_tool_objects, construct_agent, load_meta_agent +) + + +def add_agent_id_to_directory(dir: str, agent_id: str) -> None: + """Save agent id to directory.""" + full_path = Path(dir) / "agent_ids.json" + if not full_path.exists(): + with open(full_path, "w") as f: + json.dump({"agent_ids": [agent_id]}, f) + else: + with open(full_path, "r") as f: + agent_ids = json.load(f)["agent_ids"] + if agent_id in agent_ids: + raise ValueError(f"Agent id {agent_id} 
already exists.") + agent_ids_set = set(agent_ids) + agent_ids_set.add(agent_id) + with open(full_path, "w") as f: + json.dump({"agent_ids": list(agent_ids_set)}, f) + + +def load_agent_ids_from_directory(dir: str) -> List[str]: + """Load agent ids file.""" + full_path = Path(dir) / "agent_ids.json" + if not full_path.exists(): + return [] + with open(full_path, "r") as f: + agent_ids = json.load(f)["agent_ids"] + + return agent_ids + + +def load_cache_from_directory( + dir: str, + agent_id: str, +) -> ParamCache: + """Load cache from directory.""" + full_path = Path(dir) / f"{agent_id}" + if not full_path.exists(): + raise ValueError(f"Cache for agent {agent_id} does not exist.") + cache = ParamCache.load_from_disk(str(full_path)) + return cache + + +def remove_agent_from_directory( + dir: str, + agent_id: str, +) -> None: + """Remove agent from directory.""" + + # modify / resave agent_ids + agent_ids = load_agent_ids_from_directory(dir) + new_agent_ids = [id for id in agent_ids if id != agent_id] + full_path = Path(dir) / "agent_ids.json" + with open(full_path, "w") as f: + json.dump({"agent_ids": new_agent_ids}, f) + + # remove agent cache + full_path = Path(dir) / f"{agent_id}" + if full_path.exists(): + # recursive delete + shutil.rmtree(full_path) + + + +# System prompt tool +GEN_SYS_PROMPT_STR = """\ +Task information is given below. + +Given the task, please generate a system prompt for an OpenAI-powered bot \ +to solve this task: +{task} \ + +Make sure the system prompt obeys the following requirements: +- Tells the bot to ALWAYS use tools given to solve the task. \ +NEVER give an answer without using a tool. +- Does not reference a specific data source. \ +The data source is implicit in any queries to the bot, \ +and telling the bot to analyze a specific data source might confuse it given a \ +user query. + +""" + +gen_sys_prompt_messages = [ + ChatMessage( + role="system", + content="You are helping to build a system prompt for another bot.", + ), + ChatMessage(role="user", content=GEN_SYS_PROMPT_STR), +] + +GEN_SYS_PROMPT_TMPL = ChatPromptTemplate(gen_sys_prompt_messages) + + +class RAGAgentBuilder: + """RAG Agent builder. + + Contains a set of functions to construct a RAG agent, including: + - setting system prompts + - loading data + - adding web search + - setting parameters (e.g. top-k) + + Must pass in a cache. This cache will be modified as the agent is built. + + """ + + def __init__( + self, cache: Optional[ParamCache] = None, cache_dir: Optional[str] = None + ) -> None: + """Init params.""" + self._cache = cache or ParamCache() + self._cache_dir = cache_dir or AGENT_CACHE_DIR + + @property + def cache(self) -> ParamCache: + """Cache.""" + return self._cache + + def create_system_prompt(self, task: str) -> str: + """Create system prompt for another agent given an input task.""" + llm = BUILDER_LLM + fmt_messages = GEN_SYS_PROMPT_TMPL.format_messages(task=task) + response = llm.chat(fmt_messages) + self._cache.system_prompt = response.message.content + + return f"System prompt created: {response.message.content}" + + def load_data( + self, file_names: Optional[List[str]] = None, urls: Optional[List[str]] = None + ) -> str: + """Load data for a given task. + + Only ONE of file_names or urls should be specified. + + Args: + file_names (Optional[List[str]]): List of file names to load. + Defaults to None. + urls (Optional[List[str]]): List of urls to load. + Defaults to None. 
+ + """ + file_names = file_names or [] + urls = urls or [] + docs = load_data(file_names=file_names, urls=urls) + self._cache.docs = docs + self._cache.file_names = file_names + self._cache.urls = urls + return "Data loaded successfully." + + def add_web_tool(self) -> str: + """Add a web tool to enable agent to solve a task.""" + # TODO: make this not hardcoded to a web tool + # Set up Metaphor tool + if "web_search" in self._cache.tools: + return "Web tool already added." + else: + self._cache.tools.append("web_search") + return "Web tool added successfully." + + def get_rag_params(self) -> Dict: + """Get parameters used to configure the RAG pipeline. + + Should be called before `set_rag_params` so that the agent is aware of the + schema. + + """ + rag_params = self._cache.rag_params + return rag_params.dict() + + def set_rag_params(self, **rag_params: Dict) -> str: + """Set RAG parameters. + + These parameters will then be used to actually initialize the agent. + Should call `get_rag_params` first to get the schema of the input dictionary. + + Args: + **rag_params (Dict): dictionary of RAG parameters. + + """ + new_dict = self._cache.rag_params.dict() + new_dict.update(rag_params) + rag_params_obj = RAGParams(**new_dict) + self._cache.rag_params = rag_params_obj + return "RAG parameters set successfully." + + def create_agent(self, agent_id: Optional[str] = None) -> str: + """Create an agent. + + There are no parameters for this function because all the + functions should have already been called to set up the agent. + + """ + if self._cache.system_prompt is None: + raise ValueError("Must set system prompt before creating agent.") + + # construct additional tools + additional_tools = get_tool_objects(self.cache.tools) + agent, extra_info = construct_agent( + cast(str, self._cache.system_prompt), + cast(RAGParams, self._cache.rag_params), + self._cache.docs, + additional_tools=additional_tools, + ) + + # if agent_id not specified, randomly generate one + agent_id = agent_id or self._cache.agent_id or f"Agent_{str(uuid.uuid4())}" + self._cache.vector_index = extra_info["vector_index"] + self._cache.agent_id = agent_id + self._cache.agent = agent + + # save the cache to disk + agent_cache_path = f"{self._cache_dir}/{agent_id}" + self._cache.save_to_disk(agent_cache_path) + # save to agent ids + add_agent_id_to_directory(str(self._cache_dir), agent_id) + + return "Agent created successfully." + + def update_agent( + self, + agent_id: str, + system_prompt: Optional[str] = None, + include_summarization: Optional[bool] = None, + top_k: Optional[int] = None, + chunk_size: Optional[int] = None, + embed_model: Optional[str] = None, + llm: Optional[str] = None, + additional_tools: Optional[List] = None, + ) -> None: + """Update agent. + + Delete old agent by ID and create a new one. + Optionally update the system prompt and RAG parameters. + + NOTE: Currently is manually called, not meant for agent use. 
+ + """ + # remove saved agent from directory, since we'll be re-saving + remove_agent_from_directory(str(AGENT_CACHE_DIR), self.cache.agent_id) + + # set agent id + self.cache.agent_id = agent_id + + # set system prompt + if system_prompt is not None: + self.cache.system_prompt = system_prompt + # get agent_builder + # We call set_rag_params and create_agent, which will + # update the cache + # TODO: decouple functions from tool functions exposed to the agent + rag_params_dict: Dict[str, Any] = {} + if include_summarization is not None: + rag_params_dict["include_summarization"] = include_summarization + if top_k is not None: + rag_params_dict["top_k"] = top_k + if chunk_size is not None: + rag_params_dict["chunk_size"] = chunk_size + if embed_model is not None: + rag_params_dict["embed_model"] = embed_model + if llm is not None: + rag_params_dict["llm"] = llm + + self.set_rag_params(**rag_params_dict) + + # update tools + if additional_tools is not None: + self.cache.tools = additional_tools + + # this will update the agent in the cache + self.create_agent() + + +#################### +#### META Agent #### +#################### + +RAG_BUILDER_SYS_STR = """\ +You are helping to construct an agent given a user-specified task. +You should generally use the tools in this rough order to build the agent. + +1) Create system prompt tool: to create the system prompt for the agent. +2) Load in user-specified data (based on file paths they specify). +3) Decide whether or not to add additional tools. +4) Set parameters for the RAG pipeline. +5) Build the agent + +This will be a back and forth conversation with the user. You should +continue asking users if there's anything else they want to do until +they say they're done. To help guide them on the process, +you can give suggestions on parameters they can set based on the tools they +have available (e.g. 
"Do you want to set the number of documents to retrieve?") + +""" + + +### DEFINE Agent #### +# NOTE: here we define a function that is dependent on the LLM, +# please make sure to update the LLM above if you change the function below + + +def _get_builder_agent_tools(agent_builder: RAGAgentBuilder) -> List[FunctionTool]: + """Get list of builder agent tools to pass to the builder agent.""" + # see if metaphor api key is set, otherwise don't add web tool + # TODO: refactor this later + + if "metaphor_key" in st.secrets: + fns: List[Callable] = [ + agent_builder.create_system_prompt, + agent_builder.load_data, + agent_builder.add_web_tool, + agent_builder.get_rag_params, + agent_builder.set_rag_params, + agent_builder.create_agent, + ] + else: + fns = [ + agent_builder.create_system_prompt, + agent_builder.load_data, + agent_builder.get_rag_params, + agent_builder.set_rag_params, + agent_builder.create_agent, + ] + + fn_tools: List[FunctionTool] = [FunctionTool.from_defaults(fn=fn) for fn in fns] + return fn_tools + + +# define agent +# @st.cache_resource +def load_meta_agent_and_tools( + cache: Optional[ParamCache] = None, +) -> Tuple[BaseAgent, RAGAgentBuilder]: + + # think of this as tools for the agent to use + agent_builder = RAGAgentBuilder(cache) + + fn_tools = _get_builder_agent_tools(agent_builder) + + builder_agent = load_meta_agent( + fn_tools, llm=BUILDER_LLM, system_prompt=RAG_BUILDER_SYS_STR, verbose=True + ) + + return builder_agent, agent_builder \ No newline at end of file diff --git a/builder_config.py b/core/builder_config.py similarity index 100% rename from builder_config.py rename to core/builder_config.py diff --git a/callback_manager.py b/core/callback_manager.py similarity index 100% rename from callback_manager.py rename to core/callback_manager.py diff --git a/constants.py b/core/constants.py similarity index 100% rename from constants.py rename to core/constants.py diff --git a/core/param_cache.py b/core/param_cache.py new file mode 100644 index 0000000..88936ca --- /dev/null +++ b/core/param_cache.py @@ -0,0 +1,175 @@ +"""Param cache.""" + +from llama_index.llms import OpenAI, ChatMessage, Anthropic, Replicate +from llama_index.llms.base import LLM +from llama_index.llms.utils import resolve_llm +from pydantic import BaseModel, Field +import os +from llama_index.agent import OpenAIAgent, ReActAgent +from llama_index.agent.react.prompts import REACT_CHAT_SYSTEM_HEADER +from llama_index import ( + VectorStoreIndex, + SummaryIndex, + ServiceContext, + StorageContext, + Document, + load_index_from_storage, +) +from llama_index.prompts import ChatPromptTemplate +from typing import List, cast, Optional +from llama_index import SimpleDirectoryReader +from llama_index.embeddings.utils import resolve_embed_model +from llama_index.tools import QueryEngineTool, ToolMetadata, FunctionTool +from llama_index.agent.types import BaseAgent +from llama_index.chat_engine.types import BaseChatEngine +from llama_index.agent.react.formatter import ReActChatFormatter +from llama_index.llms.openai_utils import is_function_calling_model +from llama_index.chat_engine import CondensePlusContextChatEngine +from core.builder_config import BUILDER_LLM +from typing import Dict, Tuple, Any, Callable +import streamlit as st +from pathlib import Path +import json +import uuid +from core.constants import AGENT_CACHE_DIR +import shutil +from core.utils import ( + load_data, + get_tool_objects, + construct_agent + +) + +from llama_index.callbacks import CallbackManager +from callback_manager import 
+
+    def save_to_disk(self, save_dir: str) -> None:
+        """Save cache to disk."""
+        # NOTE: more complex than just calling dict() because we want to
+        # only store serializable fields and be space-efficient
+
+        dict_to_serialize = {
+            "system_prompt": self.system_prompt,
+            "file_names": self.file_names,
+            "urls": self.urls,
+            # TODO: figure out tools
+            "tools": self.tools,
+            "rag_params": self.rag_params.dict(),
+            "agent_id": self.agent_id,
+        }
+        # store the vector store within the agent
+        if self.vector_index is None:
+            raise ValueError("Must specify vector index in order to save.")
+        self.vector_index.storage_context.persist(Path(save_dir) / "storage")
+
+        # if the save_dir directories don't exist, create them
+        if not Path(save_dir).exists():
+            Path(save_dir).mkdir(parents=True)
+        with open(Path(save_dir) / "cache.json", "w") as f:
+            json.dump(dict_to_serialize, f)
+
+    @classmethod
+    def load_from_disk(
+        cls,
+        save_dir: str,
+    ) -> "ParamCache":
+        """Load cache from disk."""
+        storage_context = StorageContext.from_defaults(
+            persist_dir=str(Path(save_dir) / "storage")
+        )
+        vector_index = cast(VectorStoreIndex, load_index_from_storage(storage_context))
+
+        with open(Path(save_dir) / "cache.json", "r") as f:
+            cache_dict = json.load(f)
+
+        # replace rag params with RAGParams object
+        cache_dict["rag_params"] = RAGParams(**cache_dict["rag_params"])
+
+        # add in the missing fields
+        # load docs
+        cache_dict["docs"] = load_data(
+            file_names=cache_dict["file_names"], urls=cache_dict["urls"]
+        )
+        # load agent from index
+        additional_tools = get_tool_objects(cache_dict["tools"])
+        agent, _ = construct_agent(
+            cache_dict["system_prompt"],
+            cache_dict["rag_params"],
+            cache_dict["docs"],
+            vector_index=vector_index,
+            additional_tools=additional_tools,
+            # TODO: figure out tools
+        )
+        cache_dict["vector_index"] = vector_index
+        cache_dict["agent"] = agent
+
+        return cls(**cache_dict)
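A sketch of the persistence round-trip (the save directory is hypothetical; note that `load_from_disk` re-reads the source documents and rebuilds the agent, so the original data paths and API keys must still be available):

    from core.param_cache import ParamCache

    cache = agent_builder.cache  # a populated ParamCache with a vector_index
    cache.save_to_disk("cache/agents/Agent_123")

    # later, e.g. on app restart: reloads the index, docs, and agent
    restored = ParamCache.load_from_disk("cache/agents/Agent_123")
    assert restored.agent_id == cache.agent_id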
diff --git a/core/utils.py b/core/utils.py
new file mode 100644
index 0000000..f367936
--- /dev/null
+++ b/core/utils.py
@@ -0,0 +1,369 @@
+"""Utils."""
+
+from llama_index.llms import OpenAI, Anthropic, Replicate
+from llama_index.llms.base import LLM
+from llama_index.llms.utils import resolve_llm
+from pydantic import BaseModel, Field
+import os
+from llama_index.agent import OpenAIAgent, ReActAgent
+from llama_index.agent.react.prompts import REACT_CHAT_SYSTEM_HEADER
+from llama_index import (
+    VectorStoreIndex,
+    SummaryIndex,
+    ServiceContext,
+    Document,
+)
+from typing import List, cast, Optional, Dict, Tuple, Any
+from llama_index import SimpleDirectoryReader
+from llama_index.embeddings.utils import resolve_embed_model
+from llama_index.tools import QueryEngineTool, ToolMetadata
+from llama_index.agent.types import BaseAgent
+from llama_index.chat_engine.types import BaseChatEngine
+from llama_index.agent.react.formatter import ReActChatFormatter
+from llama_index.llms.openai_utils import is_function_calling_model
+from llama_index.chat_engine import CondensePlusContextChatEngine
+from core.builder_config import BUILDER_LLM
+import streamlit as st
+
+from llama_index.callbacks import CallbackManager
+from core.callback_manager import StreamlitFunctionsCallbackHandler
+
+
+def _resolve_llm(llm_str: str) -> LLM:
+    """Resolve LLM."""
+    # TODO: make this less hardcoded with if-else statements
+    # see if there's a prefix
+    # - if there isn't, assume it's an OpenAI model
+    # - if there is, resolve it
+    tokens = llm_str.split(":")
+    if len(tokens) == 1:
+        os.environ["OPENAI_API_KEY"] = st.secrets.openai_key
+        llm: LLM = OpenAI(model=llm_str)
+    elif tokens[0] == "local":
+        llm = resolve_llm(llm_str)
+    elif tokens[0] == "openai":
+        os.environ["OPENAI_API_KEY"] = st.secrets.openai_key
+        llm = OpenAI(model=tokens[1])
+    elif tokens[0] == "anthropic":
+        os.environ["ANTHROPIC_API_KEY"] = st.secrets.anthropic_key
+        llm = Anthropic(model=tokens[1])
+    elif tokens[0] == "replicate":
+        # the replicate client reads REPLICATE_API_TOKEN, not REPLICATE_API_KEY
+        os.environ["REPLICATE_API_TOKEN"] = st.secrets.replicate_key
+        llm = Replicate(model=tokens[1])
+    else:
+        raise ValueError(f"LLM {llm_str} not recognized.")
+    return llm
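The accepted string convention, read straight off the branches above (model names are examples only):

    _resolve_llm("gpt-4-1106-preview")     # no prefix: assumed to be an OpenAI model
    _resolve_llm("openai:gpt-3.5-turbo")   # explicit OpenAI
    _resolve_llm("anthropic:claude-2")     # Anthropic
    _resolve_llm("replicate:meta/llama-2-70b-chat")  # Replicate
    _resolve_llm("local:/path/to/model")   # handed off to llama_index's resolve_llm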
+ +""" + +gen_sys_prompt_messages = [ + ChatMessage( + role="system", + content="You are helping to build a system prompt for another bot.", + ), + ChatMessage(role="user", content=GEN_SYS_PROMPT_STR), +] + +GEN_SYS_PROMPT_TMPL = ChatPromptTemplate(gen_sys_prompt_messages) + + +class RAGParams(BaseModel): + """RAG parameters. + + Parameters used to configure a RAG pipeline. + + """ + + include_summarization: bool = Field( + default=False, + description=( + "Whether to include summarization in the RAG pipeline. (only for GPT-4)" + ), + ) + top_k: int = Field( + default=2, description="Number of documents to retrieve from vector store." + ) + chunk_size: int = Field(default=1024, description="Chunk size for vector store.") + embed_model: str = Field( + default="default", description="Embedding model to use (default is OpenAI)" + ) + llm: str = Field( + default="gpt-4-1106-preview", description="LLM to use for summarization." + ) + + +def load_data( + file_names: Optional[List[str]] = None, urls: Optional[List[str]] = None +) -> List[Document]: + """Load data.""" + file_names = file_names or [] + urls = urls or [] + if not file_names and not urls: + raise ValueError("Must specify either file_names or urls.") + elif file_names and urls: + raise ValueError("Must specify only one of file_names or urls.") + elif file_names: + reader = SimpleDirectoryReader(input_files=file_names) + docs = reader.load_data() + elif urls: + from llama_hub.web.simple_web.base import SimpleWebPageReader + + # use simple web page reader from llamahub + loader = SimpleWebPageReader() + docs = loader.load_data(urls=urls) + else: + raise ValueError("Must specify either file_names or urls.") + + return docs + + +def load_agent( + tools: List, + llm: LLM, + system_prompt: str, + extra_kwargs: Optional[Dict] = None, + **kwargs: Any, +) -> BaseChatEngine: + """Load agent.""" + extra_kwargs = extra_kwargs or {} + if isinstance(llm, OpenAI) and is_function_calling_model(llm.model): + # TODO: use default msg handler + # TODO: separate this from agent_utils.py... + def _msg_handler(msg: str) -> None: + """Message handler.""" + st.info(msg) + st.session_state.agent_messages.append( + {"role": "assistant", "content": msg, "msg_type": "info"} + ) + + # add streamlit callbacks (to inject events) + handler = StreamlitFunctionsCallbackHandler(_msg_handler) + callback_manager = CallbackManager([handler]) + # get OpenAI Agent + agent: BaseChatEngine = OpenAIAgent.from_tools( + tools=tools, + llm=llm, + system_prompt=system_prompt, + **kwargs, + callback_manager=callback_manager, + ) + else: + if "vector_index" not in extra_kwargs: + raise ValueError( + "Must pass in vector index for CondensePlusContextChatEngine." + ) + vector_index = cast(VectorStoreIndex, extra_kwargs["vector_index"]) + rag_params = cast(RAGParams, extra_kwargs["rag_params"]) + # use condense + context chat engine + agent = CondensePlusContextChatEngine.from_defaults( + vector_index.as_retriever(similarity_top_k=rag_params.top_k), + ) + + return agent + + +def load_meta_agent( + tools: List, + llm: LLM, + system_prompt: str, + extra_kwargs: Optional[Dict] = None, + **kwargs: Any, +) -> BaseAgent: + """Load meta agent. + + TODO: consolidate with load_agent. + + The meta-agent *has* to perform tool-use. 
+ + """ + extra_kwargs = extra_kwargs or {} + if isinstance(llm, OpenAI) and is_function_calling_model(llm.model): + # get OpenAI Agent + + agent: BaseAgent = OpenAIAgent.from_tools( + tools=tools, + llm=llm, + system_prompt=system_prompt, + **kwargs, + ) + else: + agent = ReActAgent.from_tools( + tools=tools, + llm=llm, + react_chat_formatter=ReActChatFormatter( + system_header=system_prompt + "\n" + REACT_CHAT_SYSTEM_HEADER, + ), + **kwargs, + ) + + return agent + + +def construct_agent( + system_prompt: str, + rag_params: RAGParams, + docs: List[Document], + vector_index: Optional[VectorStoreIndex] = None, + additional_tools: Optional[List] = None, +) -> Tuple[BaseChatEngine, Dict]: + """Construct agent from docs / parameters / indices.""" + extra_info = {} + additional_tools = additional_tools or [] + + # first resolve llm and embedding model + embed_model = resolve_embed_model(rag_params.embed_model) + # llm = resolve_llm(rag_params.llm) + # TODO: use OpenAI for now + # llm = OpenAI(model=rag_params.llm) + llm = _resolve_llm(rag_params.llm) + + # first let's index the data with the right parameters + service_context = ServiceContext.from_defaults( + chunk_size=rag_params.chunk_size, + llm=llm, + embed_model=embed_model, + ) + + if vector_index is None: + vector_index = VectorStoreIndex.from_documents( + docs, service_context=service_context + ) + else: + pass + + extra_info["vector_index"] = vector_index + + vector_query_engine = vector_index.as_query_engine( + similarity_top_k=rag_params.top_k + ) + all_tools = [] + vector_tool = QueryEngineTool( + query_engine=vector_query_engine, + metadata=ToolMetadata( + name="vector_tool", + description=("Use this tool to answer any user question over any data."), + ), + ) + all_tools.append(vector_tool) + if rag_params.include_summarization: + summary_index = SummaryIndex.from_documents( + docs, service_context=service_context + ) + summary_query_engine = summary_index.as_query_engine() + summary_tool = QueryEngineTool( + query_engine=summary_query_engine, + metadata=ToolMetadata( + name="summary_tool", + description=( + "Use this tool for any user questions that ask " + "for a summarization of content" + ), + ), + ) + all_tools.append(summary_tool) + + # then we add tools + all_tools.extend(additional_tools) + + # build agent + if system_prompt is None: + return "System prompt not set yet. Please set system prompt first." + + agent = load_agent( + all_tools, + llm=llm, + system_prompt=system_prompt, + verbose=True, + extra_kwargs={"vector_index": vector_index, "rag_params": rag_params}, + ) + return agent, extra_info + + +def get_web_agent_tool() -> QueryEngineTool: + """Get web agent tool. + + Wrap with our load and search tool spec. 
+ + """ + from llama_hub.tools.metaphor.base import MetaphorToolSpec + + # TODO: set metaphor API key + metaphor_tool = MetaphorToolSpec( + api_key=st.secrets.metaphor_key, + ) + metaphor_tool_list = metaphor_tool.to_tool_list() + + # TODO: LoadAndSearch doesn't work yet + # The search_and_retrieve_documents tool is the third in the tool list, + # as seen above + # wrapped_retrieve = LoadAndSearchToolSpec.from_defaults( + # metaphor_tool_list[2], + # ) + + # NOTE: requires openai right now + # We don't give the Agent our unwrapped retrieve document tools + # instead passing the wrapped tools + web_agent = OpenAIAgent.from_tools( + # [*wrapped_retrieve.to_tool_list(), metaphor_tool_list[4]], + metaphor_tool_list, + llm=BUILDER_LLM, + verbose=True, + ) + + # return agent as a tool + # TODO: tune description + web_agent_tool = QueryEngineTool.from_defaults( + web_agent, + name="web_agent", + description=""" + This agent can answer questions by searching the web. \ +Use this tool if the answer is ONLY likely to be found by searching \ +the internet, especially for queries about recent events. + """, + ) + + return web_agent_tool + + +def get_tool_objects(tool_names: List[str]) -> List: + """Get tool objects from tool names.""" + # construct additional tools + tool_objs = [] + for tool_name in tool_names: + if tool_name == "web_search": + # build web agent + tool_objs.append(get_web_agent_tool()) + else: + raise ValueError(f"Tool {tool_name} not recognized.") + + return tool_objs + diff --git "a/pages/2_\342\232\231\357\270\217_RAG_Config.py" "b/pages/2_\342\232\231\357\270\217_RAG_Config.py" index 58d992e..b520855 100644 --- "a/pages/2_\342\232\231\357\270\217_RAG_Config.py" +++ "b/pages/2_\342\232\231\357\270\217_RAG_Config.py" @@ -2,10 +2,12 @@ import streamlit as st from typing import cast, Optional -from agent_utils import ( +from core.param_cache import ( RAGParams, - RAGAgentBuilder, ParamCache, +) +from core.agent_builder import ( + RAGAgentBuilder, remove_agent_from_directory, ) from st_utils import update_selected_agent_with_id diff --git "a/pages/3_\360\237\244\226_Generated_RAG_Agent.py" "b/pages/3_\360\237\244\226_Generated_RAG_Agent.py" index 4f897b0..cbff512 100644 --- "a/pages/3_\360\237\244\226_Generated_RAG_Agent.py" +++ "b/pages/3_\360\237\244\226_Generated_RAG_Agent.py" @@ -1,7 +1,8 @@ """Streamlit page showing builder config.""" import streamlit as st -from typing import cast, Optional -from agent_utils import RAGAgentBuilder, ParamCache +from typing import cast, Optional +from core.agent_builder import RAGAgentBuilder +from core.param_cache import ParamCache from st_utils import add_sidebar diff --git a/st_utils.py b/st_utils.py index e7f4cd7..c59835b 100644 --- a/st_utils.py +++ b/st_utils.py @@ -1,9 +1,9 @@ """Streamlit utils.""" -from agent_utils import ( +from core.agent_builder import ( load_agent_ids_from_directory, load_cache_from_directory, ) -from constants import ( +from core.constants import ( AGENT_CACHE_DIR, ) from typing import Optional