Skip to content

Commit

Permalink
Fixed plot
Browse files Browse the repository at this point in the history
  • Loading branch information
antoninoLorenzo committed Jul 16, 2024
1 parent dc6b941 commit e5ce191
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 26 deletions.
Binary file removed data/rag_eval/results/plots/context_precision.png
Binary file not shown.
Binary file removed data/rag_eval/results/plots/context_recall.png
Binary file not shown.
Binary file removed data/rag_eval/results/plots/context_relevancy.png
Binary file not shown.
Binary file added data/rag_eval/results/plots/plot.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
5 changes: 0 additions & 5 deletions data/rag_eval/results/results.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,6 @@
"context_recall": 0,
"context_relevancy": 0
},
{
"context_precision": 0.9819999999999999,
"context_recall": 0.9400000000000002,
"context_relevancy": 0
},
{
"context_precision": 0.9819999999999999,
"context_recall": 0.9400000000000002,
Expand Down
57 changes: 36 additions & 21 deletions test/benchmarks/rag/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,10 @@ def evaluate(vdb: Store, qa_paths: list, endpoint: str, metrics: list,
return metrics, eval_dataset


def update_evaluation_plots(results_df: pd.DataFrame, metrics: list, modified=True):
def update_evaluation_plots(results_df: pd.DataFrame, metrics: list,
modified=True,
rows: int = 1,
cols: int = 3):
if len(metrics) == 0:
raise ValueError('No metrics specified.')

Expand All @@ -206,8 +209,9 @@ def update_evaluation_plots(results_df: pd.DataFrame, metrics: list, modified=Tr

# Add new results
res: pd.Series = results_df.mean()
new_results = {metric_name: res[metric_name] if metric_name in res else content[len(content) - 1][metric_name]
for metric_name in content[0].keys()}
new_results = {
metric_name: res[metric_name] if metric_name in res else content[len(content) - 1][metric_name]
for metric_name in content[0].keys()}
content.append(new_results)

fp.seek(0)
Expand All @@ -216,33 +220,44 @@ def update_evaluation_plots(results_df: pd.DataFrame, metrics: list, modified=Tr
else:
history = results_df

def plot_eval(plot_df: pd.DataFrame, name: str):
# Ensure the grid has enough space for all metrics
total_metrics = len(metrics)
if rows * cols < total_metrics:
raise ValueError(f'Grid size ({rows}x{cols}) is too small for {total_metrics} metrics.')

# Create a single plot with subplots for each metric
fig, axes = plt.subplots(rows, cols, figsize=(cols * 5, rows * 5))
axes = axes.flatten() # Flatten in case of a single row or column

def plot_eval(ax, plot_df: pd.DataFrame, name: str):
"""Create a plot for an evaluation metric, the columns should be named 'x' and 'y'"""
sns.lineplot(data=plot_df, x='x', y='y', zorder=0)
plt.scatter(
sns.lineplot(data=plot_df, x='x', y='y', ax=ax, zorder=0)
ax.scatter(
plot_df.iloc[1:]['x'],
plot_df.iloc[1:]['y'],
color='#000000',
s=15,
zorder=1
)

plt.ylim(0, 1)
plt.xticks(range(0, len(plot_df)))

plt.title(f'RAG Evaluation: {name}')
plt.ylabel(name)
plt.xlabel('')
return plt
ax.set_ylim(0, 1)
ax.set_xticks(range(0, len(plot_df)))
ax.set_title(f'RAG Evaluation: {name}')
ax.set_ylabel(name)
ax.set_xlabel('')

# Output the updated evaluation plots
plots = {}
for col in history.columns:
for i, col in enumerate(history.columns):
values = history[col].to_list()
plots[col] = [{'x': i, 'y': val} for i, val in enumerate(values)]
metric_plot_df = pd.DataFrame(plots[col])
plot = plot_eval(metric_plot_df, col)
plot.savefig(f'../../../data/rag_eval/results/plots/{col}.png')
metric_plot_df = pd.DataFrame([{'x': i, 'y': val} for i, val in enumerate(values)])
plot_eval(axes[i], metric_plot_df, col)

# Hide any unused subplots
for j in range(i + 1, len(axes)):
fig.delaxes(axes[j])

plt.tight_layout()
plt.savefig(f'../../../data/rag_eval/results/plots/plot.png')
plt.close()


def main(plot_only=False):
Expand Down Expand Up @@ -279,4 +294,4 @@ def main(plot_only=False):


if __name__ == '__main__':
main(plot_only=False)
main(plot_only=True)

0 comments on commit e5ce191

Please sign in to comment.