Skip to content

Commit

Permalink
Fixed Ollama class after test
Browse files Browse the repository at this point in the history
  • Loading branch information
antoninoLorenzo committed Aug 11, 2024
1 parent 4a96697 commit 39d6309
Showing 1 changed file with 41 additions and 12 deletions.
53 changes: 41 additions & 12 deletions src/agent/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
from ollama import Client, ResponseError
from httpx import ConnectError

from src.agent.memory.base import Role


AVAILABLE_MODELS = {
'llama3': {
'options': {
Expand Down Expand Up @@ -72,10 +75,36 @@ class Ollama(Provider):
def __post_init__(self):
if self.model not in AVAILABLE_MODELS.keys():
raise ValueError(f'Model {self.model} is not available')
self.client = Client(self.client_url)
try:
self.client = Client(host=self.client_url)
except Exception as err:
raise RuntimeError(f'Something went wrong: {err}')

def query(self, messages: list):
"""Generator that returns response chunks."""
# check types
message_types_dict = [isinstance(msg, dict) for msg in messages]
if not isinstance(messages, list) or \
len(messages) == 0 or \
False in message_types_dict:
raise TypeError(f'messages must be a list of dictionaries, found: \n {messages}')

# check format
valid_roles = [str(role) for role in [Role.SYS, Role.USER, Role.ASSISTANT]]
err_message = f'messages must follow the format {{"role": "{valid_roles}", "content": "..."}}'

# check format - keys
message_keys = [list(msg.keys()) for msg in messages]
valid_keys = ['role' in keys and 'content' and len(keys) == 2 in keys for keys in message_keys]
if False in valid_keys:
raise ValueError(err_message)

# check format = values
message_roles = [msg['role'] in valid_roles for msg in messages]
message_content = [len(msg['content']) != 0 for msg in messages]
if False in message_roles or False in message_content:
raise ValueError(err_message)

try:
stream = self.client.chat(
model=self.model,
Expand Down Expand Up @@ -119,17 +148,17 @@ def __post_init__(self):
def query(self, messages: list):
"""Generator that returns response chunks."""
response = self.session.post(
url=self.client_url,
headers={
"Authorization": f"Bearer {self.api_key}",
"HTTP-Referer": 'https://github.com/antoninoLorenzo/AI-OPS',
"X-Title": 'AI-OPS',
},
data=json.dumps({
"model": self.models[self.model],
"messages": messages,
# 'stream': True how the fuck works
})
url=self.client_url,
headers={
"Authorization": f"Bearer {self.api_key}",
"HTTP-Referer": 'https://github.com/antoninoLorenzo/AI-OPS',
"X-Title": 'AI-OPS',
},
data=json.dumps({
"model": self.models[self.model],
"messages": messages,
# 'stream': True how the fuck works
})
)

try:
Expand Down

0 comments on commit 39d6309

Please sign in to comment.