Inference with peft and base model are the same #515

Closed

fozziethebeat opened this issue May 29, 2023 · 1 comment

Comments

@fozziethebeat

I trained a small pythia-160m model using a small variant of the example int8 training script in trl/blob/main/examples/sentiment/scripts/gpt-neox-20b_peft/clm_finetune_peft_imdb.py.

When doing inference on this LoRA model in a separate script, the adapter and the base model produce identical results. I've tried calling both get_base_model and disable_adapter, both of which seem like they should do the right thing after scanning the code.
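
As a sanity check that the adapter was actually trained (a sketch, assuming the standard PEFT LoRA parameter naming): PEFT initializes the lora_B matrices to zero, so if they are still all zero after training, the adapter is a mathematical no-op and identical outputs would be expected.

# Sketch: run after loading the PeftModel below. Assumes the standard
# "lora_B" parameter naming; lora_B starts at zero, so an all-zero
# lora_B means the adapter contributes nothing to the forward pass.
for name, param in model.named_parameters():
    if "lora_B" in name:
        print(name, "all zero:", bool((param == 0).all()))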

A sample demonstration script:

import torch
import os

from dataclasses import dataclass, field
from datasets import load_dataset
from peft import PeftModel, PeftConfig
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (
    HfArgumentParser,
    AutoModelForCausalLM,
    AutoTokenizer,
)


@dataclass
class ModelArguments:
    # NOTE: no trailing comma here; `final_model: str = field(...),` would
    # silently make the default a tuple and break HfArgumentParser.
    final_model: str = field(default="peft_done")


parser = HfArgumentParser(ModelArguments)

# Fun fact: parse_args_into_dataclasses always returns a tuple, even when
# you have only one argument class to parse, so it still needs unpacking.
(model_args,) = parser.parse_args_into_dataclasses()

peft_model_id = model_args.final_model
config = PeftConfig.from_pretrained(peft_model_id)
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)

model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    return_dict=True,
    load_in_8bit=True,
    device_map="auto",
)


def generate(model, prompt):
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    # Leave room for the new tokens within the 2048-token context window.
    if inputs["input_ids"].shape[1] >= (2048 - 128):
        return "Too Long"

    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=50)
        # Strip the prompt tokens so only the continuation is decoded.
        input_ids = inputs["input_ids"]
        generated_tokens = outputs[:, input_ids.shape[1]:]
        return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]


# Base model, before the adapter is attached.
print(generate(model, "I really enjoyed the "))
# Wrap the base model with the LoRA adapter. (If this injects the LoRA
# layers into the base model's modules in place, the original reference
# is mutated too.)
model = PeftModel.from_pretrained(model, peft_model_id)
print(generate(model, "I really enjoyed the "))
print(generate(model.get_base_model(), "I really enjoyed the "))

This results in the following output:

the book. I have been reading it for a while now, and I am
very excited to read it again.

I am very happy with the book. I am very excited to read it again.
I am very happy with
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
icing on the cake. I was so impressed with the way the icing was applied. I was so impressed with the way the icing was applied. I was so impressed with the way the icing was applied. I was so impressed with the
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
icing on the cake. I was so impressed with the way the icing was applied. I was so impressed with the way the icing was applied. I was so impressed with the way the icing was applied. I was so impressed with the

From what I can tell, the output from the base model before the adapter is loaded is unique, but inference on both the adapter and the base model after PeftModel.from_pretrained is exactly the same. My guess is that from_pretrained injects the LoRA layers into the base model's modules in place, so get_base_model() returns the same mutated modules rather than a clean copy.
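
One way to check that guess (a sketch; the LoraLayer import path is an assumption about peft's internals):

import peft.tuners.lora as lora

# If the LoRA layers were injected in place, the "base" model returned
# by get_base_model() still contains them in its forward path.
base = model.get_base_model()
print(any(isinstance(m, lora.LoraLayer) for m in base.modules()))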

How does one disable an adapter or get inference results from a base model while an adapter is active?
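
For what it's worth, a sketch of the disable_adapter usage I would expect, assuming it is implemented as a context manager, in which case calling it bare is a no-op:

# Assumption: PeftModel.disable_adapter() is a context manager, so calling
# it without `with` just returns the context manager and changes nothing.
with model.disable_adapter():
    print(generate(model, "I really enjoyed the "))  # base-model behavior
print(generate(model, "I really enjoyed the "))      # adapter active again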

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
