Skip to content

Commit

Permalink
Implemented tool usage
Browse files Browse the repository at this point in the history
  • Loading branch information
antoninoLorenzo committed Aug 4, 2024
1 parent 93083c7 commit 440a259
Showing 1 changed file with 61 additions and 22 deletions.
83 changes: 61 additions & 22 deletions src/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import json
from json import JSONDecodeError

from tool_parse import ToolRegistry, NotRegisteredError

from src.agent.knowledge import Store
from src.agent.llm import LLM, AVAILABLE_PROVIDERS
from src.agent.memory import Memory, Message, Role
Expand All @@ -19,7 +21,16 @@ def __init__(self, model: str,
knowledge_base: Store = None,
llm_endpoint: str = 'http://localhost:11434',
provider: str = 'ollama',
provider_key: str = ''):
provider_key: str = '',
tool_registry: ToolRegistry | None = None):
"""
:param model: llm model name
:param tools_docs: documentation of penetration testing tools
:param knowledge_base:
:param llm_endpoint: llm endpoint
:param provider: llm provider
:param provider_key: if provider requires an api key it must be provided here
"""
# Pre-conditions
if provider not in AVAILABLE_PROVIDERS.keys():
raise RuntimeError(f'{provider} not supported.')
Expand All @@ -37,7 +48,13 @@ def __init__(self, model: str,
api_key=provider_key
)
self.mem = Memory()
self.vdb: Store | None = knowledge_base
# TODO: remove Store use tool calling
# self.vdb: Store | None = knowledge_base
self.tr: ToolRegistry | None = tool_registry
if tool_registry is not None:
self.tools = [tool for tool in self.tr.marshal('base')]
else:
self.tools = []

# Prompts
self._available_tools = tools_docs
Expand All @@ -48,18 +65,10 @@ def __init__(self, model: str,
self.system_plan_con = PROMPTS[model]['plan_conversion']['system']
self.user_plan_con = PROMPTS[model]['plan_conversion']['user']

def query(self, sid: int, user_in: str, rag=True):
"""Performs a query to the Large Language Model,
set `rag=True` to leverage Retrieval Augmented Generation."""
if rag:
context = self._retrieve(user_in)
prompt = self.user_plan_gen.format(
user_input=user_in,
context=context
)
else:
prompt = '\n'.join(self.user_plan_gen.split('\n')[:-3])
prompt = prompt.format(user_input=user_in)
def query(self, sid: int, user_in: str):
"""Performs a query to the Large Language Model, will use RAG
if provided with the necessary tool to perform rag search"""
prompt = self.user_plan_gen.format(user_input=user_in)

# ensure session is initialized (otherwise llm has no system prompt)
if sid not in self.mem.sessions.keys():
Expand All @@ -71,19 +80,23 @@ def query(self, sid: int, user_in: str, rag=True):
)
messages = self.mem.get_session(sid).messages_to_dict_list()

# call tools
if self.tools:
tool_response = self.llm.tool_query(
messages,
tools=self.tools
)
if tool_response['message'].get('tool_calls'):
results = self.invoke_tools(tool_response)
messages.extend(results)

# generate response
response = ''
# prompt_tokens = 0
response_tokens = 0
for chunk in self.llm.query(messages):
# if chunk['done']:
# prompt_tokens = chunk['prompt_eval_count'] if 'prompt_eval_count' in chunk else None
# response_tokens = chunk['eval_count']
yield chunk # ['message']['content']

response += chunk # ['message']['content']
yield chunk
response += chunk

# self.mem.get_session(sid).messages[-1].tokens = prompt_tokens
self.mem.store_message(
sid,
Message(Role.ASSISTANT, response, tokens=response_tokens)
Expand Down Expand Up @@ -113,6 +126,32 @@ def rename_session(self, sid: int, session_name: str):
"""Rename the specified session"""
self.mem.rename_session(sid, session_name)

def invoke_tools(self, tool_response):
"""Execute tools (ex. RAG) from llm response"""
results = []

call_stack = []
for tool in tool_response['message']['tool_calls']:
tool_meta = {
'name': tool['function']['name'],
'args': tool['function']['arguments']
}

if tool_meta in call_stack:
continue
print(tool_meta)
try:
res = self.tr.compile(
name=tool_meta['name'],
arguments=tool_meta['args']
)
call_stack.append(tool_meta)
results.append({'role': 'tool', 'content': str(res)})
except NotRegisteredError:
pass

return results

def extract_plan(self, plan_nl):
"""Converts a structured LLM response in a Plan object"""
prompt = self.user_plan_con.format(query=plan_nl)
Expand Down

0 comments on commit 440a259

Please sign in to comment.