Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Neg Prompt Travel #495

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 97 additions & 9 deletions scripts/animatediff_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,63 @@ class AnimateDiffPromptSchedule:

def __init__(self, p: StableDiffusionProcessing, params: AnimateDiffProcess):
self.prompt_map = None
self.neg_prompt_map = None
self.original_prompt = None
self.original_neg_prompt = None
self.parse_prompt(p, params)
self.parse_neg_prompt(p, params)


def save_infotext_img(self, p: StableDiffusionProcessing):
if self.prompt_map is not None:
p.prompts = [self.original_prompt for _ in range(p.batch_size)]
if self.neg_prompt_map is not None:
p.negative_prompts = [self.original_neg_prompt for _ in range(p.batch_size)]


def save_infotext_txt(self, res: Processed):
if self.prompt_map is not None:
parts = res.info.split('\nNegative prompt: ', 1)
if len(parts) > 1:
res.info = f"{self.original_prompt}\nNegative prompt: {parts[1]}"
for i in range(len(res.infotexts)):
parts = res.infotexts[i].split('\nNegative prompt: ', 1)
if len(parts) > 1:
res.infotexts[i] = f"{self.original_prompt}\nNegative prompt: {parts[1]}"
write_params_txt(res.info)
parts = res.info.split('\nSteps: ', 1)
prompts = parts[0].split('\nNegative prompt: ', 1)

if len(parts) > 1:
if self.original_prompt is not None:
prompt = self.original_prompt
else:
if len(prompts) > 1:
prompt = prompts[0]
else:
prompt = ""
if self.original_neg_prompt is not None:
neg_prompt = self.original_neg_prompt
else:
if len(prompts) > 1:
neg_prompt = prompts[1]
else:
neg_prompt = ""
res.info = f"{prompt}\nNegative prompt: {neg_prompt}\nSteps: {parts[1]}"

for i in range(len(res.infotexts)):
parts = res.infotexts[i].split('\nSteps: ', 1)
prompts = parts[0].split('\nNegative prompt: ', 1)

if len(parts) > 1:
if self.original_prompt is not None:
prompt = self.original_prompt
else:
if len(prompts) > 1:
prompt = prompts[0]
else:
prompt = ""
if self.original_neg_prompt is not None:
neg_prompt = self.original_neg_prompt
else:
if len(prompts) > 1:
neg_prompt = prompts[1]
else:
neg_prompt = ""
res.infotexts[i] = f"{prompt}\nNegative prompt: {neg_prompt}\nSteps: {parts[1]}"

write_params_txt(res.info)


def parse_prompt(self, p: StableDiffusionProcessing, params: AnimateDiffProcess):
Expand Down Expand Up @@ -82,6 +120,56 @@ def parse_prompt(self, p: StableDiffusionProcessing, params: AnimateDiffProcess)
p.prompt = prompt_list * p.n_iter


def parse_neg_prompt(self, p: StableDiffusionProcessing, params: AnimateDiffProcess):
if type(p.negative_prompt) is not str:
logger.warn("negative prompt is not str, cannot support prompt map")
return

lines = p.negative_prompt.strip().split('\n')
data = {
'head_prompts': [],
'mapp_prompts': {},
'tail_prompts': []
}

mode = 'head'
for line in lines:
if mode == 'head':
if re.match(r'^\d+:', line):
mode = 'mapp'
else:
data['head_prompts'].append(line)

if mode == 'mapp':
match = re.match(r'^(\d+): (.+)$', line)
if match:
frame, prompt = match.groups()
assert int(frame) < params.video_length, \
f"invalid negative prompt travel frame number: {int(frame)} >= number of frames ({params.video_length})"
data['mapp_prompts'][int(frame)] = prompt
else:
mode = 'tail'

if mode == 'tail':
data['tail_prompts'].append(line)

if data['mapp_prompts']:
logger.info("You are using negative prompt travel.")
self.neg_prompt_map = {}
prompt_list = []
last_frame = 0
current_prompt = ''
for frame, prompt in data['mapp_prompts'].items():
prompt_list += [current_prompt for _ in range(last_frame, frame)]
last_frame = frame
current_prompt = f"{', '.join(data['head_prompts'])}, {prompt}, {', '.join(data['tail_prompts'])}"
self.neg_prompt_map[frame] = current_prompt
prompt_list += [current_prompt for _ in range(last_frame, p.batch_size)]
assert len(prompt_list) == p.batch_size, f"negative prompt_list length {len(prompt_list)} != batch_size {p.batch_size}"
self.original_neg_prompt = p.negative_prompt
p.negative_prompt = prompt_list * p.n_iter


def single_cond(self, center_frame, video_length: int, cond: torch.Tensor, closed_loop = False):
if closed_loop:
key_prev = list(self.prompt_map.keys())[-1]
Expand Down