Skip to content

Commit

Permalink
fix boostrap error
Browse files Browse the repository at this point in the history
  • Loading branch information
CodingWithTim committed Aug 29, 2024
1 parent 9ebb442 commit 8ead2ff
Showing 1 changed file with 13 additions and 4 deletions.
17 changes: 13 additions & 4 deletions fastchat/serve/monitor/elo_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,18 +495,27 @@ def construct_style_matrices(
return X, Y, models


def get_bootstrap_result_style_control(X, Y, models, func_compute_elo, num_round=1000):
def get_bootstrap_result_style_control(X, Y, battles, models, func_compute_elo, num_round=1000):
elos = []
coefs = []
assert X.shape[0] % 2 == 0 and X.shape[0] == Y.shape[0]
k = int(
X.shape[0] / 2
) # Since we duplicate the battles when constructing X and Y, we don't want to sample the duplicates

battles_tie_idx = (battles["winner"] == "tie") | (battles["winner"] == "tie (bothbad)")
for _ in tqdm(range(num_round), desc="bootstrap"):
indices = np.random.choice(list(range(k)), size=(k), replace=True)
_X = np.concatenate([X[indices], X[indices]])
_Y = np.concatenate([Y[indices], Y[indices]])

index2tie = np.zeros(k, dtype=bool)
index2tie[battles_tie_idx] = True

nontie_indices = indices[~index2tie[indices]]
tie_indices = np.concatenate([indices[index2tie[indices]], indices[index2tie[indices]]+k])

_X = np.concatenate([X[nontie_indices], X[nontie_indices], X[tie_indices]])
_Y = np.concatenate([Y[nontie_indices], Y[nontie_indices], Y[tie_indices]])

assert _X.shape == X.shape and _Y.shape == Y.shape

states = ~_X[:, : len(models)].any(axis=0)
Expand Down Expand Up @@ -585,7 +594,7 @@ def report_elo_analysis_results(
if style_control:
X, Y, models = construct_style_matrices(battles)
bootstrap_df, boostrap_coef = get_bootstrap_result_style_control(
X, Y, models, fit_mle_elo, num_round=num_bootstrap
X, Y, battles, models, fit_mle_elo, num_round=num_bootstrap
)
elo_rating_final, coef_final = fit_mle_elo(X, Y, models)
else:
Expand Down

0 comments on commit 8ead2ff

Please sign in to comment.