From d2e4ec6d13bb66bcc85d4e3f84170e111c398a2e Mon Sep 17 00:00:00 2001 From: mrcodigofuente Date: Tue, 9 Aug 2022 19:42:43 -0700 Subject: [PATCH 1/2] Add option disable_progress_grid to user arguments --- discoart/create.py | 1 + discoart/persist.py | 22 ++++++++++++++-------- discoart/resources/default.yml | 3 ++- discoart/runner.py | 1 + 4 files changed, 18 insertions(+), 9 deletions(-) diff --git a/discoart/create.py b/discoart/create.py index d9b451c..b73ba3e 100644 --- a/discoart/create.py +++ b/discoart/create.py @@ -80,6 +80,7 @@ def create( use_vertical_symmetry: Optional[bool] = False, visualize_cuts: Optional[bool] = False, width_height: Optional[List[int]] = [1280, 768], + disable_progress_grid: Optional[bool] = False, ) -> Optional['DocumentArray']: ... diff --git a/discoart/persist.py b/discoart/persist.py index 697a39d..fdd4a67 100644 --- a/discoart/persist.py +++ b/discoart/persist.py @@ -1,5 +1,6 @@ import os import threading +import base64 from threading import Thread import torchvision.transforms.functional as TF @@ -99,18 +100,23 @@ def _save_progress_thread(*args): return t -def _save_progress(da, da_gif, _nb, output_dir, fps, size_ratio): +def _save_progress(da, da_gif, _nb, output_dir, fps, size_ratio, disable_progress_grid): with threading.Lock(): try: for idx, d in enumerate(da): if d.chunks: - # only print the first image of the minibatch in progress - d.chunks.plot_image_sprites( - os.path.join(output_dir, f'{_nb}-progress-{idx}.png'), - skip_empty=True, - show_index=True, - keep_aspect_ratio=True, - ) + if disable_progress_grid: + #save only the latest image + file_name = os.path.join(output_dir, f'{_nb}-progress.png') + d.load_uri_to_image_tensor().save_image_tensor_to_file(file_name) + else: + # only print the first image of the minibatch in progress + d.chunks.plot_image_sprites( + os.path.join(output_dir, f'{_nb}-progress-{idx}.png'), + skip_empty=True, + show_index=True, + keep_aspect_ratio=True, + ) for idx, d_gif in enumerate(da_gif): if d_gif.chunks and fps > 0: d_gif.chunks.save_gif( diff --git a/discoart/resources/default.yml b/discoart/resources/default.yml index 895f573..778e79d 100644 --- a/discoart/resources/default.yml +++ b/discoart/resources/default.yml @@ -62,4 +62,5 @@ text_clip_on_cpu: False truncate_overlength_prompt: False image_output: True visualize_cuts: False -display_rate: 1 \ No newline at end of file +display_rate: 1 +disable_progress_grid: False \ No newline at end of file diff --git a/discoart/runner.py b/discoart/runner.py index 631753e..e0c6acf 100644 --- a/discoart/runner.py +++ b/discoart/runner.py @@ -453,6 +453,7 @@ def cond_fn(x, t, **kwargs): output_dir, args.gif_fps, args.gif_size_ratio, + args.disable_progress_grid, ) ) From e7ddd74338e573c1200405da99229f79ac981fe4 Mon Sep 17 00:00:00 2001 From: mrcodigofuente Date: Tue, 9 Aug 2022 20:07:03 -0700 Subject: [PATCH 2/2] Fix: remove import base64 from persist.py --- discoart/persist.py | 1 - 1 file changed, 1 deletion(-) diff --git a/discoart/persist.py b/discoart/persist.py index fdd4a67..ad47340 100644 --- a/discoart/persist.py +++ b/discoart/persist.py @@ -1,6 +1,5 @@ import os import threading -import base64 from threading import Thread import torchvision.transforms.functional as TF