diff --git a/discoart/create.py b/discoart/create.py index d9b451c..b73ba3e 100644 --- a/discoart/create.py +++ b/discoart/create.py @@ -80,6 +80,7 @@ def create( use_vertical_symmetry: Optional[bool] = False, visualize_cuts: Optional[bool] = False, width_height: Optional[List[int]] = [1280, 768], + disable_progress_grid: Optional[bool] = False, ) -> Optional['DocumentArray']: ... diff --git a/discoart/persist.py b/discoart/persist.py index 697a39d..ad47340 100644 --- a/discoart/persist.py +++ b/discoart/persist.py @@ -99,18 +99,23 @@ def _save_progress_thread(*args): return t -def _save_progress(da, da_gif, _nb, output_dir, fps, size_ratio): +def _save_progress(da, da_gif, _nb, output_dir, fps, size_ratio, disable_progress_grid): with threading.Lock(): try: for idx, d in enumerate(da): if d.chunks: - # only print the first image of the minibatch in progress - d.chunks.plot_image_sprites( - os.path.join(output_dir, f'{_nb}-progress-{idx}.png'), - skip_empty=True, - show_index=True, - keep_aspect_ratio=True, - ) + if disable_progress_grid: + #save only the latest image + file_name = os.path.join(output_dir, f'{_nb}-progress.png') + d.load_uri_to_image_tensor().save_image_tensor_to_file(file_name) + else: + # only print the first image of the minibatch in progress + d.chunks.plot_image_sprites( + os.path.join(output_dir, f'{_nb}-progress-{idx}.png'), + skip_empty=True, + show_index=True, + keep_aspect_ratio=True, + ) for idx, d_gif in enumerate(da_gif): if d_gif.chunks and fps > 0: d_gif.chunks.save_gif( diff --git a/discoart/resources/default.yml b/discoart/resources/default.yml index 895f573..778e79d 100644 --- a/discoart/resources/default.yml +++ b/discoart/resources/default.yml @@ -62,4 +62,5 @@ text_clip_on_cpu: False truncate_overlength_prompt: False image_output: True visualize_cuts: False -display_rate: 1 \ No newline at end of file +display_rate: 1 +disable_progress_grid: False \ No newline at end of file diff --git a/discoart/runner.py b/discoart/runner.py index 631753e..e0c6acf 100644 --- a/discoart/runner.py +++ b/discoart/runner.py @@ -453,6 +453,7 @@ def cond_fn(x, t, **kwargs): output_dir, args.gif_fps, args.gif_size_ratio, + args.disable_progress_grid, ) )