Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrade ipa support #879

Merged
merged 8 commits into from
May 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from ..utils.booster_utils import is_using_oneflow_backend
from ._config import comfyui_instantid_hijacker, comfyui_instantid_pt
from ..hijack_ipadapter_plus.set_model_patch_replace import set_model_patch_replace
from ..hijack_ipadapter_plus.set_model_patch_replace import set_model_patch_replace_v2

set_model_patch_replace_fn_pt = comfyui_instantid_pt.InstantID._set_model_patch_replace


Expand All @@ -9,5 +10,5 @@ def cond_func(org_fn, model, *args, **kwargs):


comfyui_instantid_hijacker.register(
set_model_patch_replace_fn_pt, set_model_patch_replace, cond_func
set_model_patch_replace_fn_pt, set_model_patch_replace_v2, cond_func
)
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""hijack ComfyUI/custom_nodes/ComfyUI_IPAdapter_plus/IPAdapterPlus.py"""
from ..utils.booster_utils import is_using_oneflow_backend
from ._config import ipadapter_plus_hijacker, ipadapter_plus_pt
from .set_model_patch_replace import set_model_patch_replace
from .set_model_patch_replace import set_model_patch_replace_v2

set_model_patch_replace_fn_pt = ipadapter_plus_pt.IPAdapterPlus.set_model_patch_replace


Expand All @@ -10,5 +11,5 @@ def cond_func(org_fn, model, *args, **kwargs):


ipadapter_plus_hijacker.register(
set_model_patch_replace_fn_pt, set_model_patch_replace, cond_func
set_model_patch_replace_fn_pt, set_model_patch_replace_v2, cond_func
)
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,26 @@ Please Refer to the Readme in the Respective Repositories for Installation Instr

- ComfyUI:
- github: https://github.com/comfyanonymous/ComfyUI
- commit: 4bd7d55b9028d79829a645edfe8259f7b7a049c0
- Date: Thu Apr 11 22:43:05 2024 -0400
- commit: 2d4164271634476627aae31fbec251ca748a0ae0
- Date: Wed May 15 02:40:06 2024 -0400

- ComfyUI_IPAdapter_plus:
- github: https://github.com/cubiq/ComfyUI_IPAdapter_plus
- commit 417d806e7a2153c98613e86407c1941b2b348e88
- Date: Wed Apr 10 13:28:41 2024 +0200

- ComfyUI_InstantID$ git log
- commit e9cc7597b2a7cd441065418a975a2de4aa2450df
- Date: Tue Apr 9 14:05:15 2024 +0200
- commit 20125bf9394b1bc98ef3228277a31a3a52c72fc2
- Date: Wed May 8 16:10:20 2024 +0200

- ComfyUI_InstantID$
- commit d8c70a0cd8ce0d4d62e78653674320c9c3084ec1
- Date: Wed May 8 16:55:55 2024 +0200

- ComfyUI-AnimateDiff-Evolved$ git log
- commit f9e0343f4c4606ee6365a9af4a7e16118f1c45e1
- Date: Sat Apr 6 17:32:15 2024 -0500

- OneDiff:
- github: https://github.com/siliconflow/onediff


### Quick Start

> Recommend running the official example of ComfyUI_IPAdapter_plus now, and then trying OneDiff acceleration.
Expand Down Expand Up @@ -56,7 +57,6 @@ As follows:
| ------------------ | --------- |
| Dynamic Shape | Yes |
| Dynamic Batch Size | No |
| Vae Speed Up | Yes |

## Contact

Expand Down
Original file line number Diff line number Diff line change
@@ -1,59 +1,91 @@
from comfy.ldm.modules.attention import attention_pytorch
from register_comfy.CrossAttentionPatch import \
CrossAttentionPatch as CrossAttentionPatch_PT
from register_comfy.CrossAttentionPatch import Attn2Replace, ipadapter_attention

from onediff.infer_compiler.transform import torch2oflow
from ..utils.booster_utils import clear_deployable_module_cache_and_unbind
from ..patch_management import PatchType, create_patch_executor


def set_model_patch_replace(org_fn, model, patch_kwargs, key):
# from onediff.infer_compiler.utils.cost_util import cost_time
# @cost_time(debug=True, message="set_model_patch_replace_v2")
def set_model_patch_replace_v2(org_fn, model, patch_kwargs, key):
diff_model = model.model.diffusion_model
cache_patch_executor = create_patch_executor(PatchType.CachedCrossAttentionPatch)
masks_patch_executor = create_patch_executor(PatchType.CrossAttentionForwardMasksPatch)
unet_extra_options_patch_executor = create_patch_executor(
PatchType.UNetExtraInputOptions
)
cache_dict = cache_patch_executor.get_patch(diff_model)
cache_key = create_patch_executor(PatchType.UiNodeWithIndexPatch).get_patch(model)
to = model.model_options["transformer_options"]

ui_cache_key = create_patch_executor(PatchType.UiNodeWithIndexPatch).get_patch(
model
)
unet_extra_options = unet_extra_options_patch_executor.get_patch(diff_model)

if "attn2" not in unet_extra_options:
unet_extra_options["attn2"] = {}

to = model.model_options["transformer_options"].copy()
if "patches_replace" not in to:
to["patches_replace"] = {}
else:
to["patches_replace"] = to["patches_replace"].copy()

if "attn2" not in to["patches_replace"]:
to["patches_replace"]["attn2"] = {}

masks_dict = masks_patch_executor.get_patch(diff_model)
else:
to["patches_replace"]["attn2"] = to["patches_replace"]["attn2"].copy()

if key in cache_dict:
patch: CrossAttentionPatch_PT = cache_dict[key]
if patch.retrieve_from_cache(cache_key) is not None:
if patch.update(cache_key, torch2oflow(patch_kwargs)):
patch.update_mask(cache_key, masks_dict, patch_kwargs["mask"])
return
def split_patch_kwargs(patch_kwargs):
split1dict = {}
split2dict = {}
for k, v in patch_kwargs.items():
if k in ["cond", "uncond", "mask", "weight"]:
split1dict[k] = v
else:
clear_deployable_module_cache_and_unbind(model)
split2dict[k] = v

return split1dict, split2dict

new_patch_kwargs, patch_kwargs = split_patch_kwargs(patch_kwargs)
# update patch_kwargs
if key in cache_dict:
try:
attn2_m = cache_dict[key]
index = attn2_m.cache_map.get(ui_cache_key, None)
if index is not None:
unet_extra_options["attn2"][attn2_m.forward_patch_key][
index
] = new_patch_kwargs
return
except Exception as e:
clear_deployable_module_cache_and_unbind(model)

if key not in to["patches_replace"]["attn2"]:
if key not in cache_dict:
patch_pt = CrossAttentionPatch_PT(**patch_kwargs)
patch_pt.optimized_attention = attention_pytorch
patch_of = torch2oflow(patch_pt)
patch_of.bind_model(patch_pt)

patch = patch_of
cache_dict[key] = patch
patch.set_cache(cache_key, len(patch.weights) - 1)
patch.append_mask(masks_dict, patch_kwargs["mask"])

patch: CrossAttentionPatch_PT = cache_dict[key]
to["patches_replace"]["attn2"][key] = patch
attn2_m_pt = Attn2Replace(ipadapter_attention, **patch_kwargs)
attn2_m_of = torch2oflow(attn2_m_pt, bypass_check=True)

cache_dict[key] = attn2_m_of
attn2_m: Attn2Replace = attn2_m_of
index = len(attn2_m.callback) - 1
attn2_m.cache_map[ui_cache_key] = index
unet_extra_options["attn2"][attn2_m.forward_patch_key] = [new_patch_kwargs]

# QuantizedInputPatch
attn2_m._bind_model = attn2_m_pt
else:
attn2_m = cache_dict[key]

to["patches_replace"]["attn2"][key] = attn2_m
model.model_options["transformer_options"] = to
else:
patch = to["patches_replace"]["attn2"][key]
patch.set_new_condition(**torch2oflow(patch_kwargs))
patch.set_cache(cache_key, len(patch.weights) - 1)
patch.append_mask(masks_dict, patch_kwargs["mask"])

if patch.get_bind_model() is not None:
bind_model: CrossAttentionPatch_PT = patch.get_bind_model()
bind_model.set_new_condition(**patch_kwargs)


create_patch_executor(PatchType.QuantizedInputPatch).set_patch()
attn2_m: Attn2Replace = to["patches_replace"]["attn2"][key]
attn2_m.add(attn2_m.callback[0], **torch2oflow(patch_kwargs))
unet_extra_options["attn2"][attn2_m.forward_patch_key].append(
new_patch_kwargs
) # update last patch
attn2_m.cache_map[ui_cache_key] = len(attn2_m.callback) - 1

if attn2_m.get_bind_model() is not None:
bind_model: Attn2Replace = attn2_m.get_bind_model()
bind_model.add(bind_model.callback[0], **patch_kwargs)

if not create_patch_executor(PatchType.QuantizedInputPatch).check_patch():
create_patch_executor(PatchType.QuantizedInputPatch).set_patch()
4 changes: 2 additions & 2 deletions onediff_comfy_nodes/modules/oneflow/hijack_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ def calc_cond_batch_of(orig_func, model, conds, x_in, timestep, model_options):

if create_patch_executor(PatchType.CachedCrossAttentionPatch).check_patch(diff_model):
transformer_options["sigmas"] = timestep[0].item()
masks_patch_executor = create_patch_executor(PatchType.CrossAttentionForwardMasksPatch)
transformer_options["_masks"] = masks_patch_executor.get_patch(diff_model)
patch_executor = create_patch_executor(PatchType.UNetExtraInputOptions)
transformer_options["_attn2"] = patch_executor.get_patch(diff_model)["attn2"]
else:
transformer_options["sigmas"] = timestep

Expand Down
Loading