diff --git a/examples/batch_chat.py b/examples/batch_chat.py
index ff3a066..146280a 100644
--- a/examples/batch_chat.py
+++ b/examples/batch_chat.py
@@ -24,7 +24,7 @@
 
 img_paths = ['examples/image1.webp', 'examples/image1.webp']
 questions = ['Please describe this image in detail.',
-             'What is the text in this images?']
+             'What is the text in this image? Please describe it in detail.']
 
 assert len(img_paths) == len(questions)
 
@@ -66,7 +66,8 @@
 batch_inputs, batch_masks, batch_atten_masks = [], [], []
 for inputs, im_mask in zip(inputs_list, masks_list):
     if im_mask.shape[1] < max_len:
-        pad_embeds = torch.cat([pad_embed]*(max_len - im_mask.shape[1]))
+        pad_length = max_len - im_mask.shape[1]
+        pad_embeds = pad_embed.repeat(1, pad_length, 1)
         pad_masks = torch.tensor([0]*(max_len - im_mask.shape[1])).unsqueeze(0).cuda()
         inputs = torch.cat([pad_embeds, inputs['inputs_embeds']], dim=1)
         atten_masks = torch.cat([pad_masks, torch.ones_like(im_mask)], dim=1)
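
Context for the padding change: torch.cat over a Python list concatenates along dim 0, so the old code produced a padding tensor with the wrong batch/sequence layout, while repeat(1, pad_length, 1) tiles the single padding embedding along the sequence dimension. The snippet below is a minimal sketch of that shape difference, assuming pad_embed has shape (1, 1, hidden_size) and inputs['inputs_embeds'] has shape (1, seq_len, hidden_size); the concrete sizes are illustrative only, not taken from batch_chat.py.

import torch

hidden_size = 8                                   # illustrative size only
pad_embed = torch.zeros(1, 1, hidden_size)        # assumed shape (1, 1, hidden)
inputs_embeds = torch.randn(1, 5, hidden_size)    # assumed shape (1, seq_len, hidden)
pad_length = 3                                    # stands in for max_len - im_mask.shape[1]

# Old approach: torch.cat over a list stacks along dim 0, giving
# (pad_length, 1, hidden_size), which cannot be concatenated with
# inputs_embeds along dim=1.
old_pad = torch.cat([pad_embed] * pad_length)
print(old_pad.shape)                              # torch.Size([3, 1, 8])

# New approach: repeating along the sequence dimension keeps the batch dim,
# giving (1, pad_length, hidden_size).
new_pad = pad_embed.repeat(1, pad_length, 1)
print(new_pad.shape)                              # torch.Size([1, 3, 8])

# Left-padding now lines up with the embeddings for batched inference.
padded = torch.cat([new_pad, inputs_embeds], dim=1)
print(padded.shape)                               # torch.Size([1, 8, 8])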