Skip to content

Commit

Permalink
Implemented tool_query
Browse files Browse the repository at this point in the history
  • Loading branch information
antoninoLorenzo committed Aug 4, 2024
1 parent 525dd19 commit 93083c7
Showing 1 changed file with 37 additions and 8 deletions.
45 changes: 37 additions & 8 deletions src/agent/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,34 +8,37 @@

import requests.exceptions
from requests import Session
from ollama import Client
from ollama._types import ResponseError
from ollama import Client, ResponseError
from httpx import ConnectError

AVAILABLE_MODELS = {
'llama3': {
'options': {
'temperature': 0.5,
'num_ctx': 8192
}
},
'tools': True
},
'gemma:7b': {
'options': {
'temperature': 0.5,
'num_ctx': 8192
}
},
'tools': False
},
'gemma2:9b': {
'options': {
'temperature': 0.5,
'num_ctx': 8192
}
},
'tools': False
},
'mistral': {
'options': {
'temperature': 0.5,
'num_ctx': 8192
}
},
'tools': True
},
}

Expand All @@ -51,6 +54,10 @@ class Provider(ABC):
def query(self, messages: list):
"""Implement to makes query to the LLM provider"""

@abstractmethod
def tool_query(self, messages: list, tools: list | None = None):
"""Implement for LLM tool calling"""


class ProviderError(Exception):
"""Just a wrapper to Exception for error handling
Expand All @@ -67,7 +74,7 @@ def __post_init__(self):
raise ValueError(f'Model {self.model} is not available')
self.client = Client(self.client_url)

def query(self, messages: list):
def query(self, messages: list, stream=True, tools: list | None = None):
"""Generator that returns response chunks."""
try:
stream = self.client.chat(
Expand All @@ -81,6 +88,21 @@ def query(self, messages: list):
except ResponseError as err:
raise ProviderError(err)

def tool_query(self, messages: list, tools: list | None = None):
""""""
if not AVAILABLE_MODELS[self.model]['tools']:
raise NotImplementedError(f'Model {self.model} do not implement tool calling')

if not tools:
# TODO: should add validation for tools
raise ValueError('Empty tool list')

return self.client.chat(
model=self.model,
messages=messages,
tools=tools
)


@dataclass
class OpenRouter(Provider):
Expand All @@ -94,7 +116,7 @@ def __post_init__(self):
'mistral': 'mistralai/mistral-7b-instruct:free'
}

def query(self, messages: list):
def query(self, messages: list, stream=True, tools: list | None = None):
"""Generator that returns response chunks."""
response = self.session.post(
url=self.client_url,
Expand All @@ -120,6 +142,9 @@ def query(self, messages: list):

return output

def tool_query(self, messages: list, tools: list | None = None):
raise NotImplementedError("Tool Calling not available for OpenRouter")


@dataclass
class LLM:
Expand All @@ -141,3 +166,7 @@ def query(self, messages: list):
"""Generator that returns response chunks."""
for chunk in self.provider.query(messages):
yield chunk

def tool_query(self, messages: list, tools: list | None = None):
""""""
return self.provider.tool_query(messages, tools)

0 comments on commit 93083c7

Please sign in to comment.