Skip to content

Commit

Permalink
CLI Interface: integrated rich
Browse files Browse the repository at this point in the history
  • Loading branch information
antoninoLorenzo committed Jul 1, 2024
1 parent d4b35e6 commit 8a0dd7e
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 43 deletions.
131 changes: 89 additions & 42 deletions ai-ops-cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
import sys

import requests
from requests.exceptions import Timeout, ConnectionError
from rich.console import Console
from rich.tree import Tree
from rich.prompt import Prompt, InvalidResponse


class AgentClient:
Expand All @@ -14,9 +18,10 @@ class AgentClient:
def __init__(self, api_url: str = 'http://127.0.0.1:8000'):
self.api_url = api_url
self.client = requests.Session()
self.console = Console()
self.current_session = {'sid': 0, 'name': 'Undefined'}
self.commands = {
'help': AgentClient.help,
'help': self.help,
'chat': self.chat,
'new': self.new_session,
'save': self.save_session,
Expand All @@ -25,26 +30,44 @@ def __init__(self, api_url: str = 'http://127.0.0.1:8000'):
'list': self.list_sessions,
'load': self.load_session,
'exec': self.execute_plan,
'plans': self.list_plans
'plans': self.list_plans,
'bye': ''
}

# check API availability
try:
self.client.get(self.api_url, timeout=20)
except ConnectionError:
self.console.print('backend not available', style='red')
sys.exit(-1)

def run(self):
"""Runs the main loop of the client"""
while True:
user_input = input("> ")
user_input = user_input.strip()
try:
user_input = Prompt.ask(
'> ',
console=self.console,
choices=list(self.commands.keys()),
default='help',
show_choices=False,
show_default=False
)
except InvalidResponse:
self.console.print('Invalid command', style='bold red')
self.commands['help']()

if user_input == 'bye':
break

if user_input not in self.commands.keys():
print("Invalid Input")
continue

self.commands[user_input]()

def new_session(self):
"""Creates a new session and opens the related chat"""
session_name = input('Session Name: ')
session_name = Prompt.ask(
'Session Name',
console=self.console
)

response = self.client.get(
f'{self.api_url}/session/new',
Expand All @@ -61,14 +84,17 @@ def save_session(self):
f'{self.api_url}/session/{self.current_session["sid"]}/save'
)
if response.status_code != 200:
print(f'[!] Failed: {response.status_code}')
self.console.print(f'[!] Failed: {response.status_code}')
else:
print(f'[+] Saved')
self.console.print(f'[+] Saved')
self.chat()

def rename_session(self):
"""Renames the current session"""
session_name = input('New Name: ')
session_name = Prompt(
'New Name',
console=self.console
)

response = self.client.get(
f'{self.api_url}/session/{self.current_session["sid"]}/rename',
Expand All @@ -85,7 +111,7 @@ def delete_session(self):
)
response.raise_for_status()
body = response.json()
print(f'[{"+" if body["success"] else "-"}] {body["message"]}')
self.console.print(f'[{"+" if body["success"] else "-"}] {body["message"]}')

def list_sessions(self):
"""List all sessions"""
Expand All @@ -95,17 +121,21 @@ def list_sessions(self):
response.raise_for_status()
body = response.json()
if len(body) == 0:
print('[+] No sessions found')
self.console.print('[+] No sessions found')
else:
print(f'[+] Available Sessions: ')
tree = Tree("[+] Available Sessions:")
for session in body:
print(f'| - ({session["sid"]}) {session["name"]}')
tree.add(f'({session["sid"]}) {session["name"]}')
self.console.print(tree)

def load_session(self):
"""Opens an existing session"""
session_id = input('Enter session id: ')
session_id = Prompt.ask(
'Enter session ID',
console=self.console
)
if not session_id.isdigit():
print('[-] Not a number')
self.console.print('[-] Not a number', style='bold red')
self.load_session()

response = self.client.get(
Expand All @@ -115,21 +145,31 @@ def load_session(self):
response.raise_for_status()

body = response.json()
self.current_session = {'sid': body['sid'], 'name': body['name']}
print(f'+ {self.current_session["name"]} ({self.current_session["sid"]})')
sid = body['sid']
name = body['name']
self.current_session = {'sid': sid, 'name': name}
self.console.print(f'({sid}) [bold blue]{name}[/]')

for msg in body['messages']:
print(f'{msg["role"]}: {msg["content"]}\n')
self.console.print(f'[bold white]{msg["role"]}[/]: {msg["content"]}\n')
self.chat(print_name=False)

def chat(self, print_name=True):
"""Opens a chat with the Agent"""
sid = self.current_session["sid"]
query_url = f'{self.api_url}/session/{sid}/query'

if print_name:
print(f'+ {self.current_session["name"]} ({self.current_session["sid"]})')
query_url = f'{self.api_url}/session/{self.current_session["sid"]}/query'
name = self.current_session["name"]
self.console.print(f'({sid}) [bold blue]{name}[/]')

while True:
q = input('user: ')
q = Prompt.ask(
'[bold white]User[/]',
console=self.console,
default='-1',
show_default=False
)
if q == '-1':
break

Expand All @@ -139,7 +179,7 @@ def chat(self, print_name=True):
headers=None,
stream=True) as resp:
resp.raise_for_status()
print('assistant: ')
self.console.print('[bold white]Assistant[/]: ', end='')
for chunk in resp.iter_content():
if chunk:
print(chunk.decode(), end='', flush=True)
Expand All @@ -152,10 +192,10 @@ def execute_plan(self):
headers=None,
stream=True
) as resp:
print(f'[+] Tasks\n')
self.console.print(f'[+] Tasks\n')
for task_str in resp.iter_content():
if task_str:
print(task_str.decode(), end='')
self.console.print(task_str.decode(), end='')

def list_plans(self):
"""Retrieve the plans in the current session and
Expand All @@ -173,23 +213,30 @@ def list_plans(self):
if len(task['output']) > 0:
tasks += f'\n{task["output"]}\n'

print(f'[+] Plan {i}\n\n'
self.console.print(f'[+] Plan {i}\n\n'
f'{tasks}')

@staticmethod
def help():
"""Print help message"""
print(f'help : show available commands.\n'
f'chat : open chat with the agent.\n'
f'-1 : exit chat\n'
f'new : create a new session.\n'
f'save : saves the current session.\n'
f'delete : deletes the current session from persistent sessions.\n'
f'list : show the saved sessions.\n'
f'load : opens a session.\n'
f'exec : execute the last plan generated by the agent.\n'
f'plans : lists all plans in the current session.\n'
f'bye : exit the program')
def help(self):
"""Print help message"""
# Basic Commands
self.console.print("[bold white]Basic Commands[/]")
self.console.print("- [bold blue]help[/] : Show available commands.")
self.console.print("- [bold blue]bye[/] : Exit the program")

# Agent Related
self.console.print("\n[bold white]Agent Related[/]")
self.console.print("- [bold blue]chat[/] : Open chat with the agent.")
self.console.print("- [bold blue]-1[/] : Exit chat")
self.console.print("- [bold blue]exec[/] : Execute the last plan generated by the agent.")
self.console.print("- [bold blue]plans[/] : Lists all plans in the current session.")

# Session Related
self.console.print("\n[bold white]Session Related[/]")
self.console.print("- [bold blue]new[/] : Create a new session.")
self.console.print("- [bold blue]save[/] : Save the current session.")
self.console.print("- [bold blue]delete[/] : Delete the current session from persistent sessions.")
self.console.print("- [bold blue]list[/] : Show the saved sessions.")
self.console.print("- [bold blue]load[/] : Opens a session.")


if __name__ == "__main__":
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ requests
matplotlib
seaborn
yake
rich
2 changes: 1 addition & 1 deletion src/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def query(self, sid: int, user_in: str, rag=True, stream=True):
else:
prompt = '\n'.join(self.user_plan_gen.split('\n')[:-3])
prompt = prompt.format(user_input=user_in)

self.mem.store_message(
sid,
Message(Role.USER, prompt)
Expand Down
5 changes: 5 additions & 0 deletions src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ class APISettings(BaseSettings):
)


@app.get('/')
def ping():
return ''


# --- SESSION RELATED
@app.get('/session/list')
def list_sessions():
Expand Down

0 comments on commit 8a0dd7e

Please sign in to comment.