Update rag benchmark
antoninoLorenzo committed Jul 16, 2024
1 parent d0a4324 commit c6c2da6
Showing 3 changed files with 44 additions and 30 deletions.
Binary file not shown.
44 changes: 27 additions & 17 deletions test/benchmarks/rag/evaluation.py
@@ -34,12 +34,13 @@
 }


-def init_knowledge_base(data: dict[str: list[Topic]]) -> Store:
+def init_knowledge_base(data: dict[str: list[Topic]], embedding_url: str) -> Store:
     """Creates a connection to the Vector Database and
     uploads the data used to generate the synthetic dataset.
     :param data: {path to a JSON file : topic list}
+    :param embedding_url: llm endpoint
     """
-    store = Store()
+    store = Store(in_memory=True, embedding_url=embedding_url)
     i = 0
     for p, topics in data.items():
         path = Path(p)
@@ -71,7 +72,8 @@ def init_knowledge_base(data: dict[str: list[Topic]]) -> Store:
     return store


-def generate_evaluation_dataset(vdb: Store, qa_paths: list, model: str = 'gemma2:9b'):
+def generate_evaluation_dataset(vdb: Store, qa_paths: list, client_url: str,
+                                model: str = 'gemma2:9b'):
     """Uses the RAG pipeline to generate an evaluation dataset composed of
     questions and ground truths from Q&A dataset and context + answers from
     the RAG pipeline."""
@@ -81,20 +83,22 @@ def generate_evaluation_dataset(vdb: Store, qa_paths: list, model: str = 'gemma2
     )

     def gen_context_answer(question: str, llm: LLM):
-        points = vdb.retrieve(question, 'owasp')
+        points = vdb.retrieve_from(question, 'owasp')
         context_list = [f'{p.payload["title"]}: {p.payload["text"]}' for p in points]
         context = '\n'.join(context_list)
-        answer = llm.query(
+        response = llm.query(
             messages=[
                 {'role': 'system', 'content': GEN_PROMPT[model]['sys']},
                 {'role': 'user', 'content': GEN_PROMPT[model]['usr'].format(query=question, context=context)}
             ],
-            stream=False
-        )['message']['content']
+        )
+        answer = ''
+        for chunk in response:
+            answer += chunk

         return context_list, answer

-    generator = LLM(model=model)
+    generator = LLM(model=model, client_url=client_url)
     eval_data = []
     for i, items in tqdm(qa.iterrows(), total=len(qa), desc='Retrieving context and generating answers.'):
         ctx, ans = gen_context_answer(items.question, generator)
@@ -120,13 +124,15 @@ def evaluate(vdb: Store, qa_paths: list, endpoint: str,
     - Evaluating the full contexts-question-answer-ground_truths dataset.
     """
-    eval_dataset = generate_evaluation_dataset(vdb, qa_paths, generation_model)
+    eval_dataset = generate_evaluation_dataset(
+        vdb=vdb,
+        qa_paths=qa_paths,
+        model=generation_model,
+        client_url=endpoint
+    )

     # Setup evaluation metrics
-    llm = LLM(
-        model='gemma2:9b',
-        client_url=endpoint,
-    )
+    llm = LLM(model='gemma2:9b', client_url=endpoint)
     ctx_recall = ContextRecall(
         EVAL_PROMPTS[evaluation_model]['context_recall']['sys'],
         EVAL_PROMPTS[evaluation_model]['context_recall']['usr'],
@@ -210,17 +216,21 @@ def plot_eval(plot_df: pd.DataFrame, name: str):
         raise RuntimeError('Missing environment variable "ENDPOINT"')

     knowledge_base: Store = init_knowledge_base({
-        '../../../data/json/owasp.json': [Topic.WebPenetrationTesting]
-    })
+        '../../../data/json/owasp.json': ['Web Pentesting'],
+    }, embedding_url=OLLAMA_ENDPOINT)

     synthetic_qa_paths = [
         '../../../data/rag_eval/owasp_100.json',
-        # '../../../data/rag_eval/owasp_100-200.json'
+        '../../../data/rag_eval/owasp_50.json',
     ]

     eval_results_df = evaluate(
         vdb=knowledge_base,
         qa_paths=synthetic_qa_paths,
         endpoint=OLLAMA_ENDPOINT
     )
+    print(eval_results_df.head())
+    eval_results_df.to_json('./tmp.json')

+    # eval_results_df = pd.read_json('./tmp.json')

     update_evaluation_plots(eval_results_df)
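Note: the chunk-accumulation loop adopted here appears at three call sites (gen_context_answer in evaluation.py and both compute methods in metrics.py). A small helper along the lines below could consolidate it. This is only a sketch, assuming LLM.query yields string chunks when streaming; join_stream is an illustrative name, not a function in the repository.

    from typing import Iterable


    def join_stream(chunks: Iterable[str]) -> str:
        """Accumulate streamed LLM response chunks into a single string."""
        answer = ''
        for chunk in chunks:
            answer += chunk
        return answer


    # Usage inside gen_context_answer would then reduce to:
    #     answer = join_stream(llm.query(messages=messages))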
30 changes: 17 additions & 13 deletions test/benchmarks/rag/metrics.py
@@ -11,9 +11,11 @@

 from src.agent.llm import LLM, Ollama

-
+# TODO:
+# rating could be done categorically instead of numerically
+# ex. great = 1; good = 0.7; inaccurate = 0.3; bad = 0
 EVAL_PROMPTS = {
-    'mistral:7b': {
+    'gemma2:9b': {
         'context_recall': {
             'sys': textwrap.dedent("""
             Given a context, and an answer, analyze each sentence in the answer and classify if the sentence can be attributed to the given context or not. Use only "Yes" (1) or "No" (0) as a binary classification.
@@ -89,20 +91,16 @@ def compute(self, *args, **kwargs) -> float:

     @staticmethod
     def extract_response(response):
-        """Extracts the json results from a HuggingFace Inference Endpoint response"""
-        print(response)
-        eval_json = response['message']['content']
-        # TODO: check
-        # [0]['generated_text'].split('\n')[-1]
-
+        """Extracts the json results from response"""
         try:
-            return np.mean(json.loads(eval_json)['result'])
+            # TODO: validate response response type
+            return np.mean(json.loads(response)['result'])
         except JSONDecodeError:
-            match = re.search(JSON_PATTERN, eval_json)
+            match = re.search(JSON_PATTERN, response)
             if match:
                 return np.mean(json.loads(match.group())['result'])
             else:
-                return eval_json
+                return response


 class ContextRecall(Metric):
@@ -115,7 +113,10 @@ def compute(self, answer: str, context: str):
             {'role': 'user', 'content': self.user_prompt.format(answer=answer, context=context)}
         ]

-        result = self.llm.query(messages)
+        response = self.llm.query(messages)
+        result = ''
+        for chunk in response:
+            result += chunk
         return self.extract_response(result)


@@ -129,6 +130,9 @@ def compute(self, question: str, answer: str, context: str):
             {'role': 'user', 'content': self.user_prompt.format(question=question, answer=answer, context=context)}
         ]

-        result = self.llm.query(messages)
+        response = self.llm.query(messages)
+        result = ''
+        for chunk in response:
+            result += chunk
         return self.extract_response(result)
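The TODO added at the top of metrics.py proposes rating categorically instead of numerically (great = 1, good = 0.7, inaccurate = 0.3, bad = 0). A minimal sketch of that mapping, assuming the evaluator LLM returns a short label; CATEGORY_SCORES and rate_categorical are illustrative names, not part of the repository.

    import re

    # Scores taken from the TODO comment in this commit.
    CATEGORY_SCORES = {'great': 1.0, 'good': 0.7, 'inaccurate': 0.3, 'bad': 0.0}


    def rate_categorical(label: str) -> float:
        """Map a categorical judgement from the evaluator LLM to a numeric score."""
        match = re.search('|'.join(CATEGORY_SCORES), label.lower())
        if match is None:
            raise ValueError(f'Unrecognized rating: {label!r}')
        return CATEGORY_SCORES[match.group()]


    # Example: rate_categorical('Rating: Good') -> 0.7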
