RAG Evaluation: refactored metrics and evaluation function.
antoninoLorenzo committed Jul 16, 2024
1 parent 0839866 commit 91bca64
Showing 8 changed files with 113 additions and 50 deletions.
Binary file modified data/rag_eval/results/plots/context_precision.png
Binary file modified data/rag_eval/results/plots/context_recall.png
4 changes: 4 additions & 0 deletions data/rag_eval/results/results.json
@@ -6,5 +6,9 @@
{
"context_precision": 0.968,
"context_recall": 0.9359999999999999
},
{
"context_precision": 0.9819999999999999,
"context_recall": 0.9400000000000002
}
]
Binary file modified test/benchmarks/rag/__pycache__/metrics.cpython-311.pyc
Binary file not shown.
100 changes: 57 additions & 43 deletions test/benchmarks/rag/evaluation.py
@@ -16,23 +16,32 @@

from src.agent.llm import LLM
from src.agent.knowledge import Store, Collection, Document, Topic
from test.benchmarks.rag.metrics import ContextRecall, ContextPrecision, EVAL_PROMPTS
from test.benchmarks.rag.metrics import (
ContextRecall,
ContextPrecision,
ContextRelevancy,
Metric,
EVAL_PROMPTS
)


GEN_PROMPT = {
'gemma2:9b': {
'sys': textwrap.dedent("""
You are a Cybersecurity professional assistant, your job is to provide an answer to context specific questions.
You will be provided with additional Context information to provide an answer.
"""),
'usr': textwrap.dedent("""
Question: {query}
Context:
{context}
""")
'sys': """
You are a Cybersecurity professional assistant, your job is to provide an answer to context specific questions.
You will be provided with additional Context information to provide an answer.""",
'usr': """Question: {query}
Context:
{context}"""
}
}

METRICS = {
'context_precision': ContextPrecision,
'context_recall': ContextRecall,
'context_relevancy': ContextRelevancy
}


def init_knowledge_base(data: dict[str, list[Topic]], embedding_url: str) -> Store:
"""Creates a connection to the Vector Database and
@@ -111,7 +120,7 @@ def gen_context_answer(question: str, llm: LLM):
return pd.DataFrame(eval_data)


def evaluate(vdb: Store, qa_paths: list, endpoint: str,
def evaluate(vdb: Store, qa_paths: list, endpoint: str, metrics: list,
generation_model: str = 'gemma2:9b',
evaluation_model: str = 'gemma2:9b'):
"""Given the Vector Database and the synthetic Q&A dataset
@@ -124,45 +133,49 @@ def evaluate(vdb: Store, qa_paths: list, endpoint: str,
- Evaluating the full contexts-question-answer-ground_truths dataset.
"""
if len(metrics) == 0:
raise ValueError('No metrics specified.')

# Setup evaluation metrics
llm = LLM(model='gemma2:9b', client_url=endpoint)
eval_metrics: dict[str, Metric] = {}
for metric in metrics:
if metric not in METRICS.keys():
raise ValueError(f'Invalid metric: {metric}.')

m = METRICS[metric](
EVAL_PROMPTS[evaluation_model][metric]['sys'],
EVAL_PROMPTS[evaluation_model][metric]['usr'],
llm
)
eval_metrics[metric] = m

# Evaluation Dataset
eval_dataset = generate_evaluation_dataset(
vdb=vdb,
qa_paths=qa_paths,
model=generation_model,
client_url=endpoint
)

# Setup evaluation metrics
llm = LLM(model='gemma2:9b', client_url=endpoint)
ctx_precision = ContextPrecision(
EVAL_PROMPTS[evaluation_model]['context_precision']['sys'],
EVAL_PROMPTS[evaluation_model]['context_precision']['usr'],
llm
)

ctx_recall = ContextRecall(
EVAL_PROMPTS[evaluation_model]['context_recall']['sys'],
EVAL_PROMPTS[evaluation_model]['context_recall']['usr'],
llm
)

# Run
recall = []
for i, item in tqdm(eval_dataset.iterrows(), total=len(eval_dataset), desc='Measuring Context Recall'):
ctx = '\n\n'.join(item.contexts)
ans = item.answer
recall.append(ctx_recall.compute(ans, ctx))

precision = []
for i, item in tqdm(eval_dataset.iterrows(), total=len(eval_dataset), desc='Measuring Context Recall'):
qst = item.question
ctx = '\n\n'.join(item.contexts)
ans = item.answer
precision.append(ctx_precision.compute(qst, ans, ctx))

metrics = pd.DataFrame({
'context_recall': recall,
'context_precision': precision
})
# Run Evaluation
results = {}
for metric_name, m in eval_metrics.items():
results[metric_name] = []
for i, item in tqdm(eval_dataset.iterrows(), total=len(eval_dataset), desc=f'evaluating {metric_name}'):
ctx = ''
for idx, chunk in enumerate(item.contexts):
ctx += f"[{idx}]: {chunk}\n\n"

data = {
'context': ctx,
'question': item.question,
'answer': item.answer,
'ground_truth': item.ground_truth
}
results[metric_name].append(m.compute(data))

metrics = pd.DataFrame(results)
return metrics, eval_dataset


@@ -226,6 +239,7 @@ def plot_eval(plot_df: pd.DataFrame, name: str):
]

metrics_df, eval_output_dataset = evaluate(
metrics=['context_precision', 'context_recall'],
vdb=knowledge_base,
qa_paths=synthetic_qa_paths,
endpoint=OLLAMA_ENDPOINT
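The core of the refactor in evaluation.py is the move from hard-coded ContextPrecision/ContextRecall setup to a name-based metric registry that evaluate() iterates over. Below is a minimal, self-contained sketch of that pattern; the SimpleRecall/SimplePrecision classes and their constant scores are illustrative stand-ins, not part of the repository, and the real metric classes additionally take prompts and an LLM client.

class SimpleRecall:
    """Illustrative stand-in for test.benchmarks.rag.metrics.ContextRecall."""
    def compute(self, data: dict) -> float:
        return 1.0  # placeholder score


class SimplePrecision:
    """Illustrative stand-in for test.benchmarks.rag.metrics.ContextPrecision."""
    def compute(self, data: dict) -> float:
        return 1.0  # placeholder score


REGISTRY = {
    'context_recall': SimpleRecall,
    'context_precision': SimplePrecision,
}


def evaluate_rows(rows: list[dict], metrics: list[str]) -> dict[str, list[float]]:
    """Validate the requested metric names, instantiate each metric once,
    then score every dataset row with every metric (mirrors evaluate())."""
    if len(metrics) == 0:
        raise ValueError('No metrics specified.')

    instances = {}
    for name in metrics:
        if name not in REGISTRY:
            raise ValueError(f'Invalid metric: {name}.')
        instances[name] = REGISTRY[name]()

    results = {name: [] for name in instances}
    for row in rows:
        # Number the retrieved chunks so the evaluator prompt can reference them.
        ctx = '\n\n'.join(f'[{i}]: {chunk}' for i, chunk in enumerate(row['contexts']))
        data = {
            'context': ctx,
            'question': row['question'],
            'answer': row['answer'],
            'ground_truth': row['ground_truth']
        }
        for name, metric in instances.items():
            results[name].append(metric.compute(data))
    return results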
57 changes: 52 additions & 5 deletions test/benchmarks/rag/metrics.py
@@ -73,6 +73,36 @@
IMPORTANT:
- Remember to follow the "Evaluation Guidelines"
- Provide only the JSON string, do not provide any explanation."""
},
'context_relevancy': {
'sys': """Given a question and multiple chunks of context, you should analyze each chunk of context to determine its relevancy in answering the question. Use the following categorical scoring system for your classification:
- "great" (1): The context chunk is fundamental and essential to provide an answer.
- "good" (0.7): The context chunk is useful and relevant to the answer, but not fundamental.
- "inaccurate" (0.3): The context chunk is on the same topic as the answer but isn't useful in providing a response.
- "bad" (0): The context chunk is on a different topic and not relevant to the question.
Your output should contain a list of categorical scores for each context chunk, formatted as a JSON string as follows:
{{"result": ["great" | "good" | "inaccurate" | "bad", ...]}}
Evaluation Guidelines:
- Only provide the JSON string in the specified format. Do not include any additional text.
- Ensure your assessment is based on how well each context chunk aligns with the given question and supports the answer.
- If a context chunk lacks sufficient information to be relevant to the question, your response should be "bad".
- Ensure your evaluation reflects the necessity and relevancy of each context chunk in addressing the query.""",
'usr': """Question:
{question}
Contexts:
{context}
Your output should contain a list of categorical scores for each context chunk, formatted as a JSON string as follows:
{{"result": ["great" | "good" | "inaccurate" | "bad", ...]}}
IMPORTANT:
- Remember to follow the "Evaluation Guidelines"
- Provide only the JSON string, do not provide any explanation.
"""
}
}
}
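The context_relevancy prompt above asks the evaluator model for one categorical label per retrieved chunk, and the prompt's own scale (great = 1, good = 0.7, inaccurate = 0.3, bad = 0) implies how those labels become a single score: map each label to its value and average. A small worked example, assuming a mapping equivalent to the METRICS_VALUES table used by extract_response below (its exact definition is not part of this diff):

import numpy as np

# Assumed label-to-score mapping, taken from the scale described in the prompt above.
LABEL_SCORES = {'great': 1.0, 'good': 0.7, 'inaccurate': 0.3, 'bad': 0.0}

labels = ['great', 'good', 'bad']                           # one label per retrieved chunk
scores = [LABEL_SCORES.get(label, 0) for label in labels]   # unknown labels fall back to 0
print(np.mean(scores))                                      # (1.0 + 0.7 + 0.0) / 3 ≈ 0.57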
@@ -115,7 +145,15 @@ def query(self, sys_prompt: str, usr_prompt: str) -> str:
def extract_response(response):
"""Extracts the json results from response"""
try:
# TODO: validate response type
result = json.loads(response)['result']
if isinstance(result, list):
# list of labels (e.g. context relevancy)
values = []
for label in result:
values.append(METRICS_VALUES[label] if label in METRICS_VALUES else 0)
return np.mean(values)

# single label (e.g. context precision)
label = result
return METRICS_VALUES[label] if label in METRICS_VALUES else 0
except JSONDecodeError:
@@ -129,21 +167,30 @@ def extract_response(response):
class ContextRecall(Metric):
"""Assesses how much the answer is based on the context"""

def compute(self, answer: str, context: str):
def compute(self, data: dict):
"""Computes context recall given answer and context"""
return self.query(
self.system_prompt,
self.user_prompt.format(answer=answer, context=context)
self.user_prompt.format(answer=data['answer'], context=data['context'])
)


class ContextPrecision(Metric):
"""Assesses how much the context was useful in generating the answer"""

def compute(self, question: str, answer: str, context: str):
def compute(self, data: dict):
"""Uses question, answer and context"""
return self.query(
self.system_prompt,
self.user_prompt.format(question=question, answer=answer, context=context)
self.user_prompt.format(question=data['question'], answer=data['answer'], context=data['context'])
)


class ContextRelevancy(Metric):
"""Assesses how much relevant is the retrieved context to the query"""

def compute(self, data: dict):
return self.query(
self.system_prompt,
self.user_prompt.format(question=data['question'], context=data['context'])
)
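After the refactor every metric class shares the same compute(data: dict) signature, so the evaluation loop can treat them uniformly. A hedged usage sketch follows: the endpoint URL and the sample question/context/answer strings are placeholders, the EVAL_PROMPTS keys assume the 'gemma2:9b' evaluation model used elsewhere in this commit, and whether compute() returns a parsed score or the raw LLM response depends on Metric.query, which is not shown in full here.

from src.agent.llm import LLM
from test.benchmarks.rag.metrics import ContextRelevancy, EVAL_PROMPTS

# Placeholder endpoint; the benchmark script reads it from OLLAMA_ENDPOINT.
llm = LLM(model='gemma2:9b', client_url='http://localhost:11434')

relevancy = ContextRelevancy(
    EVAL_PROMPTS['gemma2:9b']['context_relevancy']['sys'],
    EVAL_PROMPTS['gemma2:9b']['context_relevancy']['usr'],
    llm
)

# Every metric consumes the same dict; keys a given metric does not need are ignored.
data = {
    'question': 'What does the -sV flag do in nmap?',
    'context': '[0]: The -sV option enables service and version detection on scanned ports.',
    'answer': 'It probes open ports to identify the service and its version.',
    'ground_truth': 'Service/version detection.'
}
score = relevancy.compute(data)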
1 change: 0 additions & 1 deletion test/benchmarks/rag/tmp_eval_ds.json

This file was deleted.

1 change: 0 additions & 1 deletion test/benchmarks/rag/tmp_metrics.json

This file was deleted.
