Skip to content

Commit

Permalink
Upgrade IPAdapter support (#879)
Browse files Browse the repository at this point in the history
## Accelerating ComfyUI_IPAdapter_plus with OneDiff
### Environment
Please refer to the README in the respective repositories for
installation instructions.

- ComfyUI:
  - github: https://github.com/comfyanonymous/ComfyUI
  - commit: 2d4164271634476627aae31fbec251ca748a0ae0 
  - Date:   Wed May 15 02:40:06 2024 -0400
 
  
- ComfyUI_IPAdapter_plus:
  - github: https://github.com/cubiq/ComfyUI_IPAdapter_plus
  - commit: 20125bf9394b1bc98ef3228277a31a3a52c72fc2
  - Date:   Wed May 8 16:10:20 2024 +0200

- ComfyUI_InstantID:
  - github: https://github.com/cubiq/ComfyUI_InstantID
  - commit: d8c70a0cd8ce0d4d62e78653674320c9c3084ec1
  - Date:   Wed May 8 16:55:55 2024 +0200

- OneDiff:
  - github: https://github.com/siliconflow/onediff 

    ```shell
    # install onediff
    git clone https://github.com/siliconflow/onediff.git
    cd onediff && git checkout dev_support_ipadapter_v2
    pip install -e .
    
    # install onediff_comfy_nodes
    ln -s $(pwd)/onediff_comfy_nodes path/to/ComfyUI/custom_nodes/
    # or
    # cp -r onediff_comfy_nodes path/to/ComfyUI/custom_nodes/
    ```
  • Loading branch information
ccssu committed May 17, 2024
1 parent 0819aa4 commit d8a6a90
Show file tree
Hide file tree
Showing 16 changed files with 368 additions and 512 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from ..utils.booster_utils import is_using_oneflow_backend
from ._config import comfyui_instantid_hijacker, comfyui_instantid_pt
from ..hijack_ipadapter_plus.set_model_patch_replace import set_model_patch_replace
from ..hijack_ipadapter_plus.set_model_patch_replace import set_model_patch_replace_v2

set_model_patch_replace_fn_pt = comfyui_instantid_pt.InstantID._set_model_patch_replace


Expand All @@ -9,5 +10,5 @@ def cond_func(org_fn, model, *args, **kwargs):


comfyui_instantid_hijacker.register(
set_model_patch_replace_fn_pt, set_model_patch_replace, cond_func
set_model_patch_replace_fn_pt, set_model_patch_replace_v2, cond_func
)
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""hijack ComfyUI/custom_nodes/ComfyUI_IPAdapter_plus/IPAdapterPlus.py"""
from ..utils.booster_utils import is_using_oneflow_backend
from ._config import ipadapter_plus_hijacker, ipadapter_plus_pt
from .set_model_patch_replace import set_model_patch_replace
from .set_model_patch_replace import set_model_patch_replace_v2

set_model_patch_replace_fn_pt = ipadapter_plus_pt.IPAdapterPlus.set_model_patch_replace


Expand All @@ -10,5 +11,5 @@ def cond_func(org_fn, model, *args, **kwargs):


ipadapter_plus_hijacker.register(
set_model_patch_replace_fn_pt, set_model_patch_replace, cond_func
set_model_patch_replace_fn_pt, set_model_patch_replace_v2, cond_func
)
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,26 @@ Please Refer to the Readme in the Respective Repositories for Installation Instr

- ComfyUI:
- github: https://github.com/comfyanonymous/ComfyUI
- commit: 4bd7d55b9028d79829a645edfe8259f7b7a049c0
- Date: Thu Apr 11 22:43:05 2024 -0400
- commit: 2d4164271634476627aae31fbec251ca748a0ae0
- Date: Wed May 15 02:40:06 2024 -0400

- ComfyUI_IPAdapter_plus:
- github: https://github.com/cubiq/ComfyUI_IPAdapter_plus
- commit 417d806e7a2153c98613e86407c1941b2b348e88
- Date: Wed Apr 10 13:28:41 2024 +0200

- ComfyUI_InstantID$ git log
- commit e9cc7597b2a7cd441065418a975a2de4aa2450df
- Date: Tue Apr 9 14:05:15 2024 +0200
- commit 20125bf9394b1bc98ef3228277a31a3a52c72fc2
- Date: Wed May 8 16:10:20 2024 +0200

- ComfyUI_InstantID:
- commit d8c70a0cd8ce0d4d62e78653674320c9c3084ec1
- Date: Wed May 8 16:55:55 2024 +0200

- ComfyUI-AnimateDiff-Evolved:
- commit f9e0343f4c4606ee6365a9af4a7e16118f1c45e1
- Date: Sat Apr 6 17:32:15 2024 -0500

- OneDiff:
- github: https://github.com/siliconflow/onediff


### Quick Start

> Recommend running the official example of ComfyUI_IPAdapter_plus now, and then trying OneDiff acceleration.
Expand Down Expand Up @@ -56,7 +57,6 @@ As follows:
| ------------------ | --------- |
| Dynamic Shape | Yes |
| Dynamic Batch Size | No |
| Vae Speed Up | Yes |

## Contact

Expand Down
Original file line number Diff line number Diff line change
@@ -1,59 +1,91 @@
from comfy.ldm.modules.attention import attention_pytorch
from register_comfy.CrossAttentionPatch import \
CrossAttentionPatch as CrossAttentionPatch_PT
from register_comfy.CrossAttentionPatch import Attn2Replace, ipadapter_attention

from onediff.infer_compiler.transform import torch2oflow
from ..utils.booster_utils import clear_deployable_module_cache_and_unbind
from ..patch_management import PatchType, create_patch_executor


def set_model_patch_replace(org_fn, model, patch_kwargs, key):
# from onediff.infer_compiler.utils.cost_util import cost_time
# @cost_time(debug=True, message="set_model_patch_replace_v2")
def set_model_patch_replace_v2(org_fn, model, patch_kwargs, key):
diff_model = model.model.diffusion_model
cache_patch_executor = create_patch_executor(PatchType.CachedCrossAttentionPatch)
masks_patch_executor = create_patch_executor(PatchType.CrossAttentionForwardMasksPatch)
unet_extra_options_patch_executor = create_patch_executor(
PatchType.UNetExtraInputOptions
)
cache_dict = cache_patch_executor.get_patch(diff_model)
cache_key = create_patch_executor(PatchType.UiNodeWithIndexPatch).get_patch(model)
to = model.model_options["transformer_options"]

ui_cache_key = create_patch_executor(PatchType.UiNodeWithIndexPatch).get_patch(
model
)
unet_extra_options = unet_extra_options_patch_executor.get_patch(diff_model)

if "attn2" not in unet_extra_options:
unet_extra_options["attn2"] = {}

to = model.model_options["transformer_options"].copy()
if "patches_replace" not in to:
to["patches_replace"] = {}
else:
to["patches_replace"] = to["patches_replace"].copy()

if "attn2" not in to["patches_replace"]:
to["patches_replace"]["attn2"] = {}

masks_dict = masks_patch_executor.get_patch(diff_model)
else:
to["patches_replace"]["attn2"] = to["patches_replace"]["attn2"].copy()

if key in cache_dict:
patch: CrossAttentionPatch_PT = cache_dict[key]
if patch.retrieve_from_cache(cache_key) is not None:
if patch.update(cache_key, torch2oflow(patch_kwargs)):
patch.update_mask(cache_key, masks_dict, patch_kwargs["mask"])
return
def split_patch_kwargs(patch_kwargs):
split1dict = {}
split2dict = {}
for k, v in patch_kwargs.items():
if k in ["cond", "uncond", "mask", "weight"]:
split1dict[k] = v
else:
clear_deployable_module_cache_and_unbind(model)
split2dict[k] = v

return split1dict, split2dict

new_patch_kwargs, patch_kwargs = split_patch_kwargs(patch_kwargs)
# update patch_kwargs
if key in cache_dict:
try:
attn2_m = cache_dict[key]
index = attn2_m.cache_map.get(ui_cache_key, None)
if index is not None:
unet_extra_options["attn2"][attn2_m.forward_patch_key][
index
] = new_patch_kwargs
return
except Exception as e:
clear_deployable_module_cache_and_unbind(model)

if key not in to["patches_replace"]["attn2"]:
if key not in cache_dict:
patch_pt = CrossAttentionPatch_PT(**patch_kwargs)
patch_pt.optimized_attention = attention_pytorch
patch_of = torch2oflow(patch_pt)
patch_of.bind_model(patch_pt)

patch = patch_of
cache_dict[key] = patch
patch.set_cache(cache_key, len(patch.weights) - 1)
patch.append_mask(masks_dict, patch_kwargs["mask"])

patch: CrossAttentionPatch_PT = cache_dict[key]
to["patches_replace"]["attn2"][key] = patch
attn2_m_pt = Attn2Replace(ipadapter_attention, **patch_kwargs)
attn2_m_of = torch2oflow(attn2_m_pt, bypass_check=True)

cache_dict[key] = attn2_m_of
attn2_m: Attn2Replace = attn2_m_of
index = len(attn2_m.callback) - 1
attn2_m.cache_map[ui_cache_key] = index
unet_extra_options["attn2"][attn2_m.forward_patch_key] = [new_patch_kwargs]

# QuantizedInputPatch
attn2_m._bind_model = attn2_m_pt
else:
attn2_m = cache_dict[key]

to["patches_replace"]["attn2"][key] = attn2_m
model.model_options["transformer_options"] = to
else:
patch = to["patches_replace"]["attn2"][key]
patch.set_new_condition(**torch2oflow(patch_kwargs))
patch.set_cache(cache_key, len(patch.weights) - 1)
patch.append_mask(masks_dict, patch_kwargs["mask"])

if patch.get_bind_model() is not None:
bind_model: CrossAttentionPatch_PT = patch.get_bind_model()
bind_model.set_new_condition(**patch_kwargs)


create_patch_executor(PatchType.QuantizedInputPatch).set_patch()
attn2_m: Attn2Replace = to["patches_replace"]["attn2"][key]
attn2_m.add(attn2_m.callback[0], **torch2oflow(patch_kwargs))
unet_extra_options["attn2"][attn2_m.forward_patch_key].append(
new_patch_kwargs
) # update last patch
attn2_m.cache_map[ui_cache_key] = len(attn2_m.callback) - 1

if attn2_m.get_bind_model() is not None:
bind_model: Attn2Replace = attn2_m.get_bind_model()
bind_model.add(bind_model.callback[0], **patch_kwargs)

if not create_patch_executor(PatchType.QuantizedInputPatch).check_patch():
create_patch_executor(PatchType.QuantizedInputPatch).set_patch()
4 changes: 2 additions & 2 deletions onediff_comfy_nodes/modules/oneflow/hijack_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ def calc_cond_batch_of(orig_func, model, conds, x_in, timestep, model_options):

if create_patch_executor(PatchType.CachedCrossAttentionPatch).check_patch(diff_model):
transformer_options["sigmas"] = timestep[0].item()
masks_patch_executor = create_patch_executor(PatchType.CrossAttentionForwardMasksPatch)
transformer_options["_masks"] = masks_patch_executor.get_patch(diff_model)
patch_executor = create_patch_executor(PatchType.UNetExtraInputOptions)
transformer_options["_attn2"] = patch_executor.get_patch(diff_model)["attn2"]
else:
transformer_options["sigmas"] = timestep

Expand Down
Loading

0 comments on commit d8a6a90

Please sign in to comment.