diff --git a/models/common.py b/models/common.py
index ad35f908d865..91b600c7c55e 100644
--- a/models/common.py
+++ b/models/common.py
@@ -1,5 +1,7 @@
 # This file contains modules common to various models
 
+import base64
 import math
+from io import BytesIO
 from pathlib import Path
 
@@ -245,7 +247,8 @@ def __init__(self, imgs, pred, files, names=None):
         self.xywhn = [x / g for x, g in zip(self.xywh, gn)]  # xywh normalized
         self.n = len(self.pred)
 
-    def display(self, pprint=False, show=False, save=False, render=False, save_dir=''):
+    def display(self, pprint=False, show=False, save=False, render=False, save_dir='', base_64=False):
         colors = color_list()
+        base64_imgs = []  # base64-encoded JPEG of each annotated image, filled when base_64=True
         for i, (img, pred) in enumerate(zip(self.imgs, self.pred)):
             str = f'image {i + 1}/{len(self.pred)}: {img.shape[0]}x{img.shape[1]} '
@@ -253,7 +256,7 @@ def display(self, pprint=False, show=False, save=False, render=False, save_dir='
             for c in pred[:, -1].unique():
                 n = (pred[:, -1] == c).sum()  # detections per class
                 str += f"{n} {self.names[int(c)]}{'s' * (n > 1)}, "  # add to string
-            if show or save or render:
+            if show or save or render or base_64:
                 for *box, conf, cls in pred:  # xyxy, confidence, class
                     label = f'{self.names[int(cls)]} {conf:.2f}'
                     plot_one_box(box, img, label=label, color=colors[int(cls) % 10])
@@ -268,6 +271,12 @@ def display(self, pprint=False, show=False, save=False, render=False, save_dir='
                 print(f"{'Saving' * (i == 0)} {f},", end='' if i < self.n - 1 else ' done.\n')
             if render:
                 self.imgs[i] = np.asarray(img)
+            if base_64:
+                buffered = BytesIO()
+                img.save(buffered, format="JPEG")  # JPEG-encode the annotated PIL image in memory
+                base64_imgs.append(base64.b64encode(buffered.getvalue()).decode('utf-8'))
+        if base_64:
+            return base64_imgs  # one string per image, not just the first
 
     def print(self):
         self.display(pprint=True)  # print results
@@ -283,6 +292,9 @@ def render(self):
         self.display(render=True)  # render results
         return self.imgs
 
+    def tobase64(self):
+        return self.display(base_64=True)  # list of base64-encoded JPEG strings
+
     def __len__(self):
         return self.n
 
diff --git a/utils/datasets.py b/utils/datasets.py
index d6ab16518034..0790a3e120ac 100755
--- a/utils/datasets.py
+++ b/utils/datasets.py
@@ -1033,19 +1033,49 @@ def extract_boxes(path='../coco128/'):  # from utils.datasets import *; extract_
         assert cv2.imwrite(str(f), im[b[1]:b[3], b[0]:b[2]]), f'box failure in {f}'
 
 
-def autosplit(path='../coco128', weights=(0.9, 0.1, 0.0)):  # from utils.datasets import *; autosplit('../coco128')
+def autosplit(path='../coco128', weights=(0.9, 0.1, 0.0), annotated_only=False, bg_imgs_path=None, bg_imgs_ratio=0.05):  # from utils.datasets import *; autosplit('../coco128')
     """ Autosplit a dataset into train/val/test splits and save path/autosplit_*.txt files
     # Arguments
-        path:       Path to images directory
-        weights:    Train, val, test weights (list)
+        path:            Path to images directory
+        weights:         Train, val, test weights (list)
+        annotated_only:  Only use images that have an associated annotation *.txt file
+        bg_imgs_path:    Path to background images to add to the dataset (to reduce false positives)
+        bg_imgs_ratio:   Ratio of background images to add, relative to the number of images in each split
     """
     path = Path(path)  # images dir
-    files = list(path.rglob('*.*'))
+
+    # only work with image files
+    files = sum([list(path.rglob(f"*.{img_ext}")) for img_ext in img_formats], [])
     n = len(files)  # number of files
+
     indices = random.choices([0, 1, 2], weights=weights, k=n)  # assign each image to a split
+
     txt = ['autosplit_train.txt', 'autosplit_val.txt', 'autosplit_test.txt']  # 3 txt files
     [(path / x).unlink() for x in txt if (path / x).exists()]  # remove existing
+
+    if annotated_only:
+        print("Only images with an associated .txt annotation file will be used to create the dataset")
+
+    bg_dataset_sizes = [0, 0, 0]  # number of background images needed for each split
     for i, img in tqdm(zip(indices, files), total=n):
-        if img.suffix[1:] in img_formats:
+        # optionally keep only images that have an annotation file next to them
+        if not annotated_only or (img.parent / (img.stem + ".txt")).exists():
             with open(path / txt[i], 'a') as f:
                 f.write(str(img) + '\n')  # add image to txt file
+            bg_dataset_sizes[i] += bg_imgs_ratio
+
+    # automatically fill the dataset with background images
+    if bg_imgs_path:
+        assert 0 < bg_imgs_ratio <= 0.1, "bg_imgs_ratio must be between 0 (0%) and 0.1 (10%)"
+        bg_path = Path(bg_imgs_path)
+        bg_files = sum([list(bg_path.rglob(f"*.{img_ext}")) for img_ext in img_formats], [])
+
+        print("Filling the dataset with background images...")
+
+        # keep only as many background images as needed, capped by how many are available
+        bg_files = random.sample(bg_files, min(len(bg_files), round(sum(bg_dataset_sizes))))
+        bg_indices = random.choices([0, 1, 2], weights=weights, k=len(bg_files))
+
+        for i, img in tqdm(zip(bg_indices, bg_files), total=len(bg_files)):
+            with open(path / txt[i], 'a') as f:
+                f.write(str(img) + '\n')  # add background image to txt file
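Usage sketch for the new Detections.tobase64() API (illustrative, not part of the diff). It assumes the standard torch.hub entry point for YOLOv5; the model variant and image path are examples only, and with the list-returning fix above, tobase64() yields one string per input image:

    import base64
    import torch

    model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)  # pretrained model
    results = model('data/images/zidane.jpg')   # inference returns a Detections object
    b64_strings = results.tobase64()            # annotated images as base64 JPEG strings

    with open('annotated.jpg', 'wb') as f:
        f.write(base64.b64decode(b64_strings[0]))  # round-trip the first image back to JPEG bytes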
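And a sketch of the extended autosplit() under the same caveats (paths and ratio are examples; the function writes autosplit_train.txt, autosplit_val.txt and autosplit_test.txt into the images directory):

    from utils.datasets import autosplit

    autosplit(path='../coco128/images',   # images directory to split
              weights=(0.9, 0.1, 0.0),    # train/val/test proportions
              annotated_only=True,        # skip images without a .txt annotation file
              bg_imgs_path='../bg_pool',  # hypothetical folder of unlabeled background images
              bg_imgs_ratio=0.05)         # add ~5% background images per split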