Merge pull request #2 from antoninoLorenzo/rag_integration
Rag Integration
antoninoLorenzo committed Aug 4, 2024
2 parents 0c893d6 + 0d33410 commit e96b294
Showing 8 changed files with 193 additions and 85 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -1,4 +1,4 @@
-![pylint](https://img.shields.io/badge/PyLint-8.77-yellow?logo=python&logoColor=white)
+![pylint](https://img.shields.io/badge/PyLint-8.46-yellow?logo=python&logoColor=white)
 
 🚧 *Under Development* 🚧
3 changes: 2 additions & 1 deletion requirements.txt
@@ -1,5 +1,5 @@
 fastapi~=0.111.0
-ollama~=0.2.1
+ollama~=0.3.1
 qdrant-client
 spacy~=3.7.5
 uvicorn
@@ -17,3 +17,4 @@ numpy~=1.26.4
 google-generativeai
 pydantic_settings
 httpx
+tool-parse
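Note: the new `tool-parse` dependency backs the tool-calling flow added in `src/agent/agent.py` below. A minimal sketch of the registry API as this PR uses it — `len()`, `marshal('base')` and `compile(name=..., arguments=...)` appear in the diff; the decorator-style `register` is an assumption from the library's docs, and the `echo` tool is hypothetical:

```python
from tool_parse import ToolRegistry

tr = ToolRegistry()

@tr.register  # assumed registration style; not shown in this diff
def echo(text: str) -> str:
    """Return the input text unchanged."""
    return text

print(len(tr))                # 1 -- Agent enables tools only when the registry is non-empty
schemas = tr.marshal('base')  # JSON-schema tool definitions handed to the LLM
result = tr.compile(name='echo', arguments={'text': 'hi'})
```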
50 changes: 24 additions & 26 deletions src/__init__.py
@@ -4,7 +4,7 @@
 from src.agent.knowledge import Collection, Document, Store, Topic
 
 
-def upload_knowledge(path: str, vdb: Store):
+def initialize_knowledge(path: str, vdb: Store):
     """Used to initialize and keep the Knowledge Base updated.
     Already existing Collections will not be overwritten.
     :param path: where the JSON datasets are located.
@@ -15,29 +15,27 @@ def upload_knowledge(path: str, vdb: Store):
         if not (p.is_file() and p.suffix == '.json'):
             continue
 
-        if p.name in ['hack_tricks.json', 'null_byte.json']:
-            continue
-
-        with open(str(p), 'r', encoding='utf-8') as file:
-            data = json.load(file)
-
-        documents = []
-        topics = set()
-        for item in data:
-            topic = Topic(item['category'])
-            topics.add(topic)
-
-            document = Document(
-                name=item['title'],
-                content=item['content'],
-                topic=topic
-            )
-            documents.append(document)
-
-        collection = Collection(
-            collection_id=i,
-            title=p.name,
-            documents=documents,
-            topics=list(topics)
-        )
-        vdb.create_collection(collection)
+        if p.name in ['owasp.json']:
+            with open(str(p), 'r', encoding='utf-8') as file:
+                data = json.load(file)
+
+            documents = []
+            topics = set()
+            for item in data:
+                topic = Topic(item['category'])
+                topics.add(topic)
+
+                document = Document(
+                    name=item['title'],
+                    content=item['content'],
+                    topic=topic
+                )
+                documents.append(document)
+
+            collection = Collection(
+                collection_id=i,
+                title=p.name,
+                documents=documents,
+                topics=list(topics)
+            )
+            vdb.create_collection(collection)
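For reference, a sketch of driving the renamed initializer (the dataset directory and the JSON shape are assumptions inferred from the keys read above — `category`, `title`, `content`):

```python
from src import initialize_knowledge
from src.agent.knowledge import Store

# hypothetical datasets/owasp.json entry:
# [{"category": "Web", "title": "SQL Injection", "content": "..."}]
store = Store(in_memory=True)            # assumes a reachable embedding endpoint
initialize_knowledge('datasets', store)  # only owasp.json is ingested now
```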
90 changes: 58 additions & 32 deletions src/agent/agent.py
@@ -3,6 +3,8 @@
 import json
 from json import JSONDecodeError
 
+from tool_parse import ToolRegistry, NotRegisteredError
+
 from src.agent.knowledge import Store
 from src.agent.llm import LLM, AVAILABLE_PROVIDERS
 from src.agent.memory import Memory, Message, Role
@@ -16,10 +18,18 @@ class Agent:
 
     def __init__(self, model: str,
                  tools_docs: str = '',
-                 knowledge_base: Store = None,
                  llm_endpoint: str = 'http://localhost:11434',
                  provider: str = 'ollama',
-                 provider_key: str = ''):
+                 provider_key: str = '',
+                 tool_registry: ToolRegistry | None = None):
+        """
+        :param model: llm model name
+        :param tools_docs: documentation of penetration testing tools
+        :param llm_endpoint: llm endpoint
+        :param provider: llm provider
+        :param provider_key: if the provider requires an API key, it must be provided here
+        :param tool_registry: ToolRegistry holding tools the LLM can execute automatically
+        """
         # Pre-conditions
         if provider not in AVAILABLE_PROVIDERS.keys():
             raise RuntimeError(f'{provider} not supported.')
@@ -37,7 +47,11 @@ def __init__(self, model: str,
             api_key=provider_key
         )
         self.mem = Memory()
-        self.vdb: Store | None = knowledge_base
+        self.tr: ToolRegistry | None = tool_registry
+        if tool_registry is not None and len(tool_registry) > 0:
+            self.tools = [tool for tool in self.tr.marshal('base')]
+        else:
+            self.tools = []
 
         # Prompts
         self._available_tools = tools_docs
@@ -48,18 +62,10 @@ def __init__(self, model: str,
         self.system_plan_con = PROMPTS[model]['plan_conversion']['system']
         self.user_plan_con = PROMPTS[model]['plan_conversion']['user']
 
-    def query(self, sid: int, user_in: str, rag=True):
-        """Performs a query to the Large Language Model,
-        set `rag=True` to leverage Retrieval Augmented Generation."""
-        if rag:
-            context = self._retrieve(user_in)
-            prompt = self.user_plan_gen.format(
-                user_input=user_in,
-                context=context
-            )
-        else:
-            prompt = '\n'.join(self.user_plan_gen.split('\n')[:-3])
-            prompt = prompt.format(user_input=user_in)
+    def query(self, sid: int, user_in: str):
+        """Performs a query to the Large Language Model; RAG is used
+        when the registry provides the necessary search tool."""
+        prompt = self.user_plan_gen.format(user_input=user_in)
 
         # ensure session is initialized (otherwise llm has no system prompt)
         if sid not in self.mem.sessions.keys():
@@ -71,24 +77,53 @@ def query(self, sid: int, user_in: str, rag=True):
             )
         messages = self.mem.get_session(sid).messages_to_dict_list()
 
+        # call tools
+        if self.tools:
+            tool_response = self.llm.tool_query(
+                messages,
+                tools=self.tools
+            )
+            if tool_response['message'].get('tool_calls'):
+                results = self.invoke_tools(tool_response)
+                messages.extend(results)
+
+        # generate response
         response = ''
         # prompt_tokens = 0
         response_tokens = 0
         for chunk in self.llm.query(messages):
             # if chunk['done']:
             #     prompt_tokens = chunk['prompt_eval_count'] if 'prompt_eval_count' in chunk else None
             #     response_tokens = chunk['eval_count']
-            yield chunk  # ['message']['content']
-
-            response += chunk  # ['message']['content']
+            yield chunk
+            response += chunk
 
         # self.mem.get_session(sid).messages[-1].tokens = prompt_tokens
         self.mem.store_message(
             sid,
             Message(Role.ASSISTANT, response, tokens=response_tokens)
         )
 
+    def invoke_tools(self, tool_response):
+        """Execute tools (e.g. a RAG search) requested in the llm response."""
+        results = []
+
+        call_stack = []
+        for tool in tool_response['message']['tool_calls']:
+            tool_meta = {
+                'name': tool['function']['name'],
+                'args': tool['function']['arguments']
+            }
+
+            if tool_meta in call_stack:
+                continue
+            try:
+                res = self.tr.compile(
+                    name=tool_meta['name'],
+                    arguments=tool_meta['args']
+                )
+                call_stack.append(tool_meta)
+                results.append({'role': 'tool', 'content': str(res)})
+            except NotRegisteredError:
+                pass
+
+        return results
+
     def new_session(self, sid: int):
         """Initializes a new conversation"""
         self.mem.store_message(sid, Message(Role.SYS, self.system_plan_gen))
@@ -178,12 +213,3 @@ def execute_plan(self, sid):
 
         self.mem.store_plan(sid, plan)
 
-    def _retrieve(self, user_in: str):
-        """Get context from Qdrant"""
-        if not self.vdb:
-            raise RuntimeError('RAG is not initialized')
-        context = ''
-        for retrieved in self.vdb.retrieve(user_in):
-            context += (f"{retrieved.payload['title']}:"
-                        f"\n{retrieved.payload['text']}\n\n")
-        return context
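The deleted `_retrieve` helper is superseded by tool calling; a sketch of the new wiring (the `search_rag` function, its decorator-style registration, and the collection name are assumptions, not part of this diff):

```python
from tool_parse import ToolRegistry

from src.agent.agent import Agent
from src.agent.knowledge import Store

store = Store(in_memory=True)
tr = ToolRegistry()

@tr.register  # assumed registration style
def search_rag(search_query: str) -> str:
    """Search the pentesting knowledge base."""
    return '\n\n'.join(store.retrieve_from(search_query, 'owasp.json'))

agent = Agent(model='llama3', tool_registry=tr)
for chunk in agent.query(sid=1, user_in='How do I test for SQL injection?'):
    print(chunk, end='')
```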
9 changes: 5 additions & 4 deletions src/agent/knowledge/store.py
@@ -18,6 +18,7 @@ class Store:
     def __init__(self,
                  url: str = 'http://localhost:6333',
                  embedding_url: str = 'http://localhost:11434',
+                 embedding_model: str = 'nomic-embed-text',
                  in_memory: bool = False,
                  router: Router = None
                  ):
@@ -40,7 +41,7 @@ def __init__(self,
         self._collections: Dict[str: Collection] = coll
 
         self._encoder = ollama.Client(host=embedding_url).embeddings
-        self._embedding_model: str = 'nomic-embed-text'
+        self._embedding_model: str = embedding_model
         self._embedding_size: int = len(
             self._encoder(
                 self._embedding_model,
@@ -111,14 +112,14 @@ def upload(self, document: Document, collection_name: str):
             'title': document.name,
             'topic': str(document.topic),
             'text': ch,
-            'embedding': self._encoder(self._embedding_model, ch)
+            'embedding': self._encoder(self._embedding_model, ch)['embedding']
         } for ch in doc_chunks]
         current_len = self._collections[collection_name].size
 
         points = [
             models.PointStruct(
                 id=current_len + i,
-                vector=item['embedding']['embedding'],
+                vector=item['embedding'],
                 payload={'text': item['text'], 'title': item['title'], 'topic': item['topic']}
             )
             for i, item in enumerate(emb_chunks)
@@ -154,7 +155,7 @@ def retrieve_from(self, query: str, collection_name: str, limit: int = 3):
             limit=limit,
             score_threshold=0.5
         )
-        return hits
+        return [point.payload['text'] for point in hits]
 
     def get_available_collections(self):
         """Makes a query to Qdrant and uses collections metadata to get
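A small usage sketch for the new `embedding_model` parameter and the simplified `retrieve_from` return type (the collection name and query are illustrative; assumes Ollama serving the embedding model):

```python
from src.agent.knowledge import Store

store = Store(
    embedding_url='http://localhost:11434',
    embedding_model='nomic-embed-text',  # now configurable instead of hard-coded
    in_memory=True
)

# retrieve_from now returns plain chunk texts instead of raw Qdrant hits
chunks = store.retrieve_from('SQL injection testing', 'owasp.json')
context = '\n\n'.join(chunks)
```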
45 changes: 37 additions & 8 deletions src/agent/llm/llm.py
@@ -8,34 +8,37 @@
 
 import requests.exceptions
 from requests import Session
-from ollama import Client
-from ollama._types import ResponseError
+from ollama import Client, ResponseError
 from httpx import ConnectError
 
 AVAILABLE_MODELS = {
     'llama3': {
         'options': {
             'temperature': 0.5,
             'num_ctx': 8192
-        }
-    },
+        },
+        'tools': True
+    },
     'gemma:7b': {
         'options': {
             'temperature': 0.5,
             'num_ctx': 8192
-        }
-    },
+        },
+        'tools': False
+    },
     'gemma2:9b': {
         'options': {
             'temperature': 0.5,
             'num_ctx': 8192
-        }
-    },
+        },
+        'tools': False
+    },
     'mistral': {
         'options': {
             'temperature': 0.5,
             'num_ctx': 8192
-        }
-    },
+        },
+        'tools': True
+    },
 }

@@ -51,6 +54,10 @@ class Provider(ABC):
     def query(self, messages: list):
         """Implement to make queries to the LLM provider"""
 
+    @abstractmethod
+    def tool_query(self, messages: list, tools: list | None = None):
+        """Implement for LLM tool calling"""
+
 
 class ProviderError(Exception):
     """Just a wrapper to Exception for error handling
@@ -67,7 +74,7 @@ def __post_init__(self):
             raise ValueError(f'Model {self.model} is not available')
         self.client = Client(self.client_url)
 
-    def query(self, messages: list):
+    def query(self, messages: list, stream=True, tools: list | None = None):
         """Generator that returns response chunks."""
         try:
             stream = self.client.chat(
@@ -81,6 +88,21 @@ def query(self, messages: list):
         except ResponseError as err:
             raise ProviderError(err)
 
+    def tool_query(self, messages: list, tools: list | None = None):
+        """Makes a tool-calling chat request to the Ollama model."""
+        if not AVAILABLE_MODELS[self.model]['tools']:
+            raise NotImplementedError(f'Model {self.model} does not support tool calling')
+
+        if not tools:
+            # TODO: should add validation for tools
+            raise ValueError('Empty tool list')
+
+        return self.client.chat(
+            model=self.model,
+            messages=messages,
+            tools=tools
+        )


@dataclass
class OpenRouter(Provider):
@@ -94,7 +116,7 @@ def __post_init__(self):
             'mistral': 'mistralai/mistral-7b-instruct:free'
         }
 
-    def query(self, messages: list):
+    def query(self, messages: list, stream=True, tools: list | None = None):
         """Generator that returns response chunks."""
         response = self.session.post(
             url=self.client_url,
@@ -120,6 +142,9 @@ def query(self, messages: list):
 
         return output
 
+    def tool_query(self, messages: list, tools: list | None = None):
+        raise NotImplementedError("Tool calling is not available for OpenRouter")
+


@dataclass
class LLM:
@@ -141,3 +166,7 @@ def query(self, messages: list):
         """Generator that returns response chunks."""
         for chunk in self.provider.query(messages):
             yield chunk
+
+    def tool_query(self, messages: list, tools: list | None = None):
+        """Delegates tool calling to the underlying provider."""
+        return self.provider.tool_query(messages, tools)
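To illustrate the new provider-level API, a sketch that calls `tool_query` directly (the tool schema is a hypothetical example of what `ToolRegistry.marshal('base')` might produce; `LLM(model='llama3')` assumes the dataclass defaults for endpoint and provider, which `Agent` passes explicitly):

```python
from src.agent.llm import LLM

llm = LLM(model='llama3')  # a model with 'tools': True in AVAILABLE_MODELS

# hypothetical tool definition in the OpenAI-style format Ollama accepts
tools = [{
    'type': 'function',
    'function': {
        'name': 'search_rag',
        'description': 'Search the pentesting knowledge base.',
        'parameters': {
            'type': 'object',
            'properties': {'search_query': {'type': 'string'}},
            'required': ['search_query']
        }
    }
}]

response = llm.tool_query(
    messages=[{'role': 'user', 'content': 'Find OWASP notes on SQLi'}],
    tools=tools
)
if response['message'].get('tool_calls'):
    print(response['message']['tool_calls'])
```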