From 951bed038edabccef35bf13d49bc40627900ab63 Mon Sep 17 00:00:00 2001
From: sugary199 <906940958@qq.com>
Date: Wed, 10 Apr 2024 11:01:32 +0800
Subject: [PATCH 1/2] Fix tensor dimension mismatch in padding operation

---
 examples/batch_chat.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/examples/batch_chat.py b/examples/batch_chat.py
index ff3a066..f99754e 100644
--- a/examples/batch_chat.py
+++ b/examples/batch_chat.py
@@ -66,7 +66,8 @@
 batch_inputs, batch_masks, batch_atten_masks = [], [], []
 for inputs, im_mask in zip(inputs_list, masks_list):
     if im_mask.shape[1] < max_len:
-        pad_embeds = torch.cat([pad_embed]*(max_len - im_mask.shape[1]))
+        pad_length = max_len - im_mask.shape[1]
+        pad_embeds = pad_embed.repeat(1, pad_length, 1)
         pad_masks = torch.tensor([0]*(max_len - im_mask.shape[1])).unsqueeze(0).cuda()
         inputs = torch.cat([pad_embeds, inputs['inputs_embeds']], dim=1)
         atten_masks = torch.cat([pad_masks, torch.ones_like(im_mask)], dim=1)

From 60ed6e25f4b82ce6aeffc51b6be552bce695899b Mon Sep 17 00:00:00 2001
From: sugary199 <906940958@qq.com>
Date: Wed, 10 Apr 2024 11:05:49 +0800
Subject: [PATCH 2/2] Increase the difference in token numbers between the two examples

---
 examples/batch_chat.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/batch_chat.py b/examples/batch_chat.py
index f99754e..146280a 100644
--- a/examples/batch_chat.py
+++ b/examples/batch_chat.py
@@ -24,7 +24,7 @@
 img_paths = ['examples/image1.webp',
              'examples/image1.webp']
 questions = ['Please describe this image in detail.',
-             'What is the text in this images?']
+             'What is the text in this images? Please describe it in detail.']

 assert len(img_paths) == len(questions)

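
Below is a minimal sketch of the shape mismatch that PATCH 1/2 addresses, assuming pad_embed has shape (1, 1, hidden_size), i.e. a single padding embedding with a batch dimension (the sizes used here are hypothetical, for illustration only). torch.cat with its default dim=0 stacks the copies along the batch axis, so the later torch.cat(..., dim=1) against inputs_embeds fails; repeat(1, pad_length, 1) tiles the pad embedding along the sequence axis instead, keeping the batch dimension intact.

import torch

# Hypothetical sizes for illustration only.
hidden_size, seq_len, pad_length = 8, 5, 3

pad_embed = torch.zeros(1, 1, hidden_size)       # assumed shape: (1, 1, hidden_size)
inputs_embeds = torch.randn(1, seq_len, hidden_size)

# Old code: cat along the default dim=0 stacks copies on the batch axis.
old_pad = torch.cat([pad_embed] * pad_length)
print(old_pad.shape)                             # torch.Size([3, 1, 8])
# torch.cat([old_pad, inputs_embeds], dim=1) would now fail: batch sizes 3 vs 1.

# Fixed code: repeat along the sequence axis keeps the batch dimension.
new_pad = pad_embed.repeat(1, pad_length, 1)
print(new_pad.shape)                             # torch.Size([1, 3, 8])
padded = torch.cat([new_pad, inputs_embeds], dim=1)
print(padded.shape)                              # torch.Size([1, 8, 8])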