Skip to content

Commit

Permalink
Refactored message validation
Browse files Browse the repository at this point in the history
  • Loading branch information
antoninoLorenzo committed Aug 12, 2024
1 parent 9fa7151 commit 3003b8f
Showing 1 changed file with 42 additions and 24 deletions.
66 changes: 42 additions & 24 deletions src/agent/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field

import httpx
import requests.exceptions
from requests import Session
from ollama import Client, ResponseError
Expand Down Expand Up @@ -61,6 +62,32 @@ def query(self, messages: list):
def tool_query(self, messages: list, tools: list | None = None):
"""Implement for LLM tool calling"""

@staticmethod
def verify_messages_format(messages: list[dict]):
"""Format validation for messages."""
# check types
message_types_dict = [isinstance(msg, dict) for msg in messages]
if not isinstance(messages, list) or \
len(messages) == 0 or \
False in message_types_dict:
raise TypeError(f'messages must be a list of dictionaries, found: \n {messages}')

# check format
valid_roles = [str(role) for role in [Role.SYS, Role.USER, Role.ASSISTANT]]
err_message = f'messages must follow the format {{"role": "{valid_roles}", "content": "..."}}'

# check format - keys
message_keys = [list(msg.keys()) for msg in messages]
valid_keys = ['role' in keys and 'content' in keys and len(keys) == 2 for keys in message_keys]
if False in valid_keys:
raise ValueError(err_message + f'\nMessage Keys: {message_keys}')

# check format = values
message_roles = [msg['role'] in valid_roles for msg in messages]
message_content = [len(msg['content']) != 0 for msg in messages]
if False in message_roles or False in message_content:
raise ValueError(err_message)


class ProviderError(Exception):
"""Just a wrapper to Exception for error handling
Expand All @@ -82,28 +109,10 @@ def __post_init__(self):

def query(self, messages: list):
"""Generator that returns response chunks."""
# check types
message_types_dict = [isinstance(msg, dict) for msg in messages]
if not isinstance(messages, list) or \
len(messages) == 0 or \
False in message_types_dict:
raise TypeError(f'messages must be a list of dictionaries, found: \n {messages}')

# check format
valid_roles = [str(role) for role in [Role.SYS, Role.USER, Role.ASSISTANT]]
err_message = f'messages must follow the format {{"role": "{valid_roles}", "content": "..."}}'

# check format - keys
message_keys = [list(msg.keys()) for msg in messages]
valid_keys = ['role' in keys and 'content' and len(keys) == 2 in keys for keys in message_keys]
if False in valid_keys:
raise ValueError(err_message)

# check format = values
message_roles = [msg['role'] in valid_roles for msg in messages]
message_content = [len(msg['content']) != 0 for msg in messages]
if False in message_roles or False in message_content:
raise ValueError(err_message)
try:
self.verify_messages_format(messages)
except (TypeError, ValueError) as input_err:
raise input_err from input_err

try:
stream = self.client.chat(
Expand All @@ -114,11 +123,20 @@ def query(self, messages: list):
)
for chunk in stream:
yield chunk['message']['content']
except ResponseError as err:
except (ResponseError, httpx.ConnectError) as err:
raise ProviderError(err)

def tool_query(self, messages: list, tools: list | None = None):
""""""
"""Implements LLM tool calling.
:param messages:
The current conversation provided as a list of messages in the
format [{"role": "assistant/user/system", "content": "..."}, ...]
:param tools:
A list tools in the format specified by `ollama-python`, the
conversion is managed by `ToolRegistry` from `tool-parse` library.
:return
Ollama response with "message" : {"tool_calls": ...} or None.
"""
if not AVAILABLE_MODELS[self.model]['tools']:
raise NotImplementedError(f'Model {self.model} do not implement tool calling')

Expand Down

0 comments on commit 3003b8f

Please sign in to comment.