diff --git a/test/benchmarks/rag/metrics.py b/test/benchmarks/rag/metrics.py
index 68467c0..46d8061 100644
--- a/test/benchmarks/rag/metrics.py
+++ b/test/benchmarks/rag/metrics.py
@@ -99,6 +99,18 @@ def compute(self, *args, **kwargs) -> float:
         """Needs to be implemented to evaluate a metric"""
         pass
 
+    def query(self, sys_prompt: str, usr_prompt: str) -> str:
+        """Sends the system/user prompts to the LLM, accumulates the streamed reply and returns the extracted json result."""
+        messages = [
+            {'role': 'system', 'content': sys_prompt},
+            {'role': 'user', 'content': usr_prompt}
+        ]
+        response = self.llm.query(messages)
+        result = ''
+        for chunk in response:
+            result += chunk
+        return self.extract_response(result)
+
     @staticmethod
     def extract_response(response):
         """Extracts the json results from response"""
@@ -119,16 +131,10 @@ class ContextRecall(Metric):
 
     def compute(self, answer: str, context: str):
         """Computes context recall given answer and context"""
-        messages = [
-            {'role': 'system', 'content': self.system_prompt},
-            {'role': 'user', 'content': self.user_prompt.format(answer=answer, context=context)}
-        ]
-
-        response = self.llm.query(messages)
-        result = ''
-        for chunk in response:
-            result += chunk
-        return self.extract_response(result)
+        return self.query(
+            self.system_prompt,
+            self.user_prompt.format(answer=answer, context=context)
+        )
 
 
 class ContextPrecision(Metric):
@@ -136,14 +142,8 @@ class ContextPrecision(Metric):
 
     def compute(self, question: str, answer: str, context: str):
         """Uses question, answer and context"""
-        messages = [
-            {'role': 'system', 'content': self.system_prompt},
-            {'role': 'user', 'content': self.user_prompt.format(question=question, answer=answer, context=context)}
-        ]
-
-        response = self.llm.query(messages)
-        result = ''
-        for chunk in response:
-            result += chunk
-        return self.extract_response(result)
+        return self.query(
+            self.system_prompt,
+            self.user_prompt.format(question=question, answer=answer, context=context)
+        )
 