Skip to content

Commit

Permalink
RAG Integration
Browse files Browse the repository at this point in the history
  • Loading branch information
antoninoLorenzo committed Aug 4, 2024
1 parent 440a259 commit ff90c92
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 30 deletions.
50 changes: 24 additions & 26 deletions src/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from src.agent.knowledge import Collection, Document, Store, Topic


def upload_knowledge(path: str, vdb: Store):
def initialize_knowledge(path: str, vdb: Store):
"""Initialize the Knowledge Base and keep it up to date.
Already existing Collections will not be overwritten.
:param path: where the JSON datasets are located.
Expand All @@ -15,29 +15,27 @@ def upload_knowledge(path: str, vdb: Store):
if not (p.is_file() and p.suffix == '.json'):
continue

if p.name in ['hack_tricks.json', 'null_byte.json']:
continue

with open(str(p), 'r', encoding='utf-8') as file:
data = json.load(file)

documents = []
topics = set()
for item in data:
topic = Topic(item['category'])
topics.add(topic)

document = Document(
name=item['title'],
content=item['content'],
topic=topic
if p.name in ['owasp.json']:
with open(str(p), 'r', encoding='utf-8') as file:
data = json.load(file)

documents = []
topics = set()
for item in data:
topic = Topic(item['category'])
topics.add(topic)

document = Document(
name=item['title'],
content=item['content'],
topic=topic
)
documents.append(document)

collection = Collection(
collection_id=i,
title=p.name,
documents=documents,
topics=list(topics)
)
documents.append(document)

collection = Collection(
collection_id=i,
title=p.name,
documents=documents,
topics=list(topics)
)
vdb.create_collection(collection)
vdb.create_collection(collection)
58 changes: 54 additions & 4 deletions src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,20 @@
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic_settings import BaseSettings
from tool_parse import ToolRegistry, NotRegisteredError

# from src import upload_knowledge
from src import initialize_knowledge
from src.agent import Agent
# from src.agent.knowledge import Store
from src.agent.knowledge import Store
from src.agent.plan import TaskStatus
from src.agent.tools import TOOLS
from src.agent.llm import ProviderError

load_dotenv()
TR = ToolRegistry()


# --- Get AI-OPS Settings
class AgentSettings(BaseSettings):
"""Setup for AI Agent"""
MODEL: str = os.environ.get('MODEL', 'gemma:7b')
Expand All @@ -47,6 +50,16 @@ class AgentSettings(BaseSettings):
PROVIDER_KEY: str = os.environ.get('PROVIDER_KEY', '')


class RAGSettings(BaseSettings):
    """Settings for Qdrant vector database"""
    # URL of the vector database service; defaults to a local instance.
    RAG_URL: str = os.environ.get('RAG_URL', 'http://localhost:6333')
    # NOTE(review): os.environ.get returns a string whenever IN_MEMORY is set,
    # so this default is a real bool only when the variable is absent --
    # confirm that BaseSettings coerces the value as intended.
    IN_MEMORY: bool = os.environ.get('IN_MEMORY', True)
    # Name of the model used to compute embeddings.
    EMBEDDING_MODEL: str = os.environ.get('EMBEDDING_MODEL', 'nomic-embed-text')
    # Read from ENDPOINT (not a dedicated variable): the embedding URL is
    # assumed to be the same as the LLM provider's endpoint.
    EMBEDDING_URL: str = os.environ.get('ENDPOINT', 'http://localhost:11434')
    # Path handed to initialize_knowledge at startup.
    DOCS_BASE_PATH: str = os.environ.get('DOCS_BASE_PATH', './data/json/')


class APISettings(BaseSettings):
"""Setup for API"""
ORIGINS: list = [
Expand All @@ -57,15 +70,49 @@ class APISettings(BaseSettings):

agent_settings = AgentSettings()
api_settings = APISettings()
rag_settings = RAGSettings()


# --- Initialize RAG
# Build the vector store from the RAG settings.
store = Store(
    url=rag_settings.RAG_URL,
    embedding_url=rag_settings.EMBEDDING_URL,
    embedding_model=rag_settings.EMBEDDING_MODEL,
    in_memory=rag_settings.IN_MEMORY
)

# Populate the store from the configured documents path.
initialize_knowledge(rag_settings.DOCS_BASE_PATH, store)
# Build one "- 'collection': topic, topic" line per collection; this text is
# interpolated into the search_rag tool description further down.
available = ''
for name, coll in store.collections.items():
    topics = ", ".join([topic.name for topic in coll.topics])
    available += f"- '{name}': {topics}\n"


# Registered on TR, which is handed to the Agent as its tool registry. The
# description embeds the collection listing held in `available` when this
# module is loaded.
@TR.register(
    description=f"""Search documents in a Retrieval Augmented Generation Vector Database.
Available collections are:
{available}
"""
)
def search_rag(rag_query: str, collection: str) -> str:
    """
    :param rag_query: what should be searched
    :param collection: the collection name
    """
    # Join everything the store retrieves into one string, with a blank line
    # between consecutive results.
    return '\n\n'.join(store.retrieve_from(rag_query, collection))


# --- Initialize Agent
agent = Agent(
model=agent_settings.MODEL,
llm_endpoint=agent_settings.ENDPOINT,
tools_docs='\n'.join([tool.get_documentation() for tool in TOOLS]),
provider=agent_settings.PROVIDER,
provider_key=agent_settings.PROVIDER_KEY
provider_key=agent_settings.PROVIDER_KEY,
tool_registry=TR
)

# --- Initialize API
app = FastAPI()
app.add_middleware(
CORSMiddleware,
Expand Down Expand Up @@ -173,7 +220,7 @@ def delete_session(sid: int):

def query_generator(sid: int, q: str):
try:
stream = agent.query(sid, q, rag=False)
stream = agent.query(sid, q)
for chunk in stream:
yield chunk
except ProviderError as err:
Expand Down Expand Up @@ -271,3 +318,6 @@ def create_collection(title: str, base_path: str, topics: list):
...
]
"""
# TODO:
# when a new collection is uploaded the search_rag tool
# should be re-registered and the agent should be updated

0 comments on commit ff90c92

Please sign in to comment.