Skip to content

Commit

Permalink
cr
Browse files Browse the repository at this point in the history
  • Loading branch information
jerryjliu committed Dec 2, 2023
1 parent bf930dc commit 3f81f7e
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 7 deletions.
8 changes: 4 additions & 4 deletions core/agent_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from llama_index.llms.openai_utils import is_function_calling_model
from llama_index.chat_engine import CondensePlusContextChatEngine
from core.builder_config import BUILDER_LLM
from typing import Dict, Tuple, Any, Callable
from typing import Dict, Tuple, Any, Callable, Union
import streamlit as st
from pathlib import Path
import json
Expand All @@ -41,7 +41,7 @@ class AgentCacheRegistry:
"""

def __init__(self, dir: str) -> None:
def __init__(self, dir: Union[str, Path]) -> None:
"""Init params."""
self._dir = dir

Expand Down Expand Up @@ -92,7 +92,7 @@ def delete_agent_cache(self, agent_id: str) -> None:
# modify / resave agent_ids
agent_ids = self.get_agent_ids()
new_agent_ids = [id for id in agent_ids if id != agent_id]
full_path = Path(self.self._dir) / "agent_ids.json"
full_path = Path(self._dir) / "agent_ids.json"
with open(full_path, "w") as f:
json.dump({"agent_ids": new_agent_ids}, f)

Expand Down Expand Up @@ -152,7 +152,7 @@ def __init__(
) -> None:
"""Init params."""
self._cache = cache or ParamCache()
self._agent_registry = agent_registry or AgentCacheRegistry(AGENT_CACHE_DIR)
self._agent_registry = agent_registry or AgentCacheRegistry(str(AGENT_CACHE_DIR))

@property
def cache(self) -> ParamCache:
Expand Down
4 changes: 1 addition & 3 deletions st_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
"""Streamlit utils."""
from core.agent_builder import (
load_agent_ids_from_directory,
load_cache_from_directory,
load_meta_agent_and_tools,
AgentCacheRegistry,
RAGAgentBuilder,
Expand Down Expand Up @@ -85,7 +83,7 @@ def get_current_state() -> CurrentSessionState:
"""
# get agent registry
agent_registry = AgentCacheRegistry(AGENT_CACHE_DIR)
agent_registry = AgentCacheRegistry(str(AGENT_CACHE_DIR))
if "agent_registry" not in st.session_state.keys():
st.session_state.agent_registry = agent_registry

Expand Down

0 comments on commit 3f81f7e

Please sign in to comment.