Skip to content

Commit

Permalink
fix: Standardize Dataset Input (#460)
Browse files Browse the repository at this point in the history
  • Loading branch information
ishaansehgal99 committed Jun 4, 2024
1 parent ef93dcb commit faa1fc7
Show file tree
Hide file tree
Showing 9 changed files with 12 additions and 161 deletions.
2 changes: 0 additions & 2 deletions api/v1alpha1/params_validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ type Config struct {

type TrainingConfig struct {
ModelConfig map[string]runtime.RawExtension `yaml:"ModelConfig"`
TokenizerParams map[string]runtime.RawExtension `yaml:"TokenizerParams"`
QuantizationConfig map[string]runtime.RawExtension `yaml:"QuantizationConfig"`
LoraConfig map[string]runtime.RawExtension `yaml:"LoraConfig"`
TrainingArguments map[string]runtime.RawExtension `yaml:"TrainingArguments"`
Expand Down Expand Up @@ -76,7 +75,6 @@ func (t *TrainingConfig) UnmarshalYAML(unmarshal func(interface{}) error) error
target *map[string]runtime.RawExtension
}{
{"ModelConfig", &t.ModelConfig},
{"TokenizerParams", &t.TokenizerParams},
{"QuantizationConfig", &t.QuantizationConfig},
{"LoraConfig", &t.LoraConfig},
{"TrainingArguments", &t.TrainingArguments},
Expand Down
8 changes: 0 additions & 8 deletions api/v1alpha1/workspace_validation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,6 @@ func defaultConfigMapManifest() *v1.ConfigMap {
local_files_only: true
device_map: "auto"
TokenizerParams:
padding: true
truncation: true
QuantizationConfig:
load_in_4bit: false
Expand Down Expand Up @@ -150,10 +146,6 @@ func qloraConfigMapManifest() *v1.ConfigMap {
local_files_only: true
device_map: "auto"
TokenizerParams:
padding: true
truncation: true
QuantizationConfig:
load_in_4bit: true
bnb_4bit_quant_type: "nf4"
Expand Down
7 changes: 0 additions & 7 deletions api/v1alpha1/zz_generated.deepcopy.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 3 additions & 20 deletions charts/kaito/workspace/templates/lora-params.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,6 @@ data:
local_files_only: true
device_map: "auto"
TokenizerParams: # Configurable Parameters: https://huggingface.co/docs/transformers/v4.40.2/en/main_classes/tokenizer#transformers.PreTrainedTokenizer.__call__
padding: true # Default to true, generally recommended to pad to the longest sequence in the batch
truncation: true # Default to true to prevent errors from input sequences longer than max length
QuantizationConfig: # Configurable Parameters: https://huggingface.co/docs/transformers/v4.40.2/en/main_classes/quantization#transformers.BitsAndBytesConfig
load_in_4bit: false
Expand All @@ -36,20 +32,7 @@ data:
DatasetConfig: # Configurable Parameters: https://github.com/Azure/kaito/blob/main/presets/tuning/text-generation/cli.py#L44
shuffle_dataset: true
train_test_split: 1 # Default to using all data for fine-tuning due to strong pre-trained baseline and typically limited fine-tuning data.
# context_column: <Optional> For additional context or prompts, used in instruction fine-tuning.
# response_column: <Defaults to "text"> Main text column, required for general and instruction fine-tuning.
# messages_column: <Optional> For structured conversational data, used in chat fine-tuning.
# Expected Dataset format:
# {"messages": [{"role": "system", "content": "Marv is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "What's the capital of France?"}, {"role": "assistant", "content": "Paris, as if everyone doesn't know that already."}]}
# e.g. https://huggingface.co/datasets/philschmid/dolly-15k-oai-style
# Column usage examples:
# 1. General Fine-Tuning:
# - Required Field: response_column
# - Example: response_column: "text"
# - Example Dataset: https://huggingface.co/datasets/stanfordnlp/imdb
# 2. Instruction Fine-Tuning:
# - Required Fields: context_column, response_column
# - Example: context_column: "question", response_column: "response"
# - Example Dataset: https://huggingface.co/datasets/Open-Orca/OpenOrca
# 3. Chat Fine-Tuning:
# - Required Field: messages_column
# - Example: messages_column: "messages"
# - Example Dataset: https://huggingface.co/datasets/philschmid/dolly-15k-oai-style
24 changes: 3 additions & 21 deletions charts/kaito/workspace/templates/qlora-params.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,6 @@ data:
local_files_only: true
device_map: "auto"
TokenizerParams: # Configurable Parameters: https://huggingface.co/docs/transformers/v4.40.2/en/main_classes/tokenizer#transformers.PreTrainedTokenizer.__call__
padding: true # Default to true, generally recommended to pad to the longest sequence in the batch
truncation: true # Default to true to prevent errors from input sequences longer than max length
QuantizationConfig: # Configurable Parameters: https://huggingface.co/docs/transformers/v4.40.2/en/main_classes/quantization#transformers.BitsAndBytesConfig
load_in_4bit: true
bnb_4bit_quant_type: "nf4"
Expand All @@ -39,20 +35,6 @@ data:
DatasetConfig: # Configurable Parameters: https://github.com/Azure/kaito/blob/main/presets/tuning/text-generation/cli.py#L44
shuffle_dataset: true
train_test_split: 1 # Default to using all data for fine-tuning due to strong pre-trained baseline and typically limited fine-tuning data.
# context_column: <Optional> For additional context or prompts, used in instruction fine-tuning.
# response_column: <Defaults to "text"> Main text column, required for general and instruction fine-tuning.
# messages_column: <Optional> For structured conversational data, used in chat fine-tuning.
# Column usage examples:
# 1. General Fine-Tuning:
# - Required Field: response_column
# - Example: response_column: "text"
# - Example Dataset: https://huggingface.co/datasets/stanfordnlp/imdb
# 2. Instruction Fine-Tuning:
# - Required Fields: context_column, response_column
# - Example: context_column: "question", response_column: "response"
# - Example Dataset: https://huggingface.co/datasets/Open-Orca/OpenOrca
# 3. Chat Fine-Tuning:
# - Required Field: messages_column
# - Example: messages_column: "messages"
# - Example Dataset: https://huggingface.co/datasets/philschmid/dolly-15k-oai-style
# Expected Dataset format:
# {"messages": [{"role": "system", "content": "Marv is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "What's the capital of France?"}, {"role": "assistant", "content": "Paris, as if everyone doesn't know that already."}]}
# e.g. https://huggingface.co/datasets/philschmid/dolly-15k-oai-style
21 changes: 0 additions & 21 deletions presets/tuning/text-generation/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,27 +55,6 @@ class DatasetConfig:
messages_column: Optional[str] = field(default=None, metadata={"help": "Column containing structured conversational data in JSON format with roles and content, used for chatbot training."})
train_test_split: float = field(default=0.8, metadata={"help": "Split between test and training data (e.g. 0.8 means 80/20% train/test split)"})

@dataclass
class TokenizerParams:
    """
    Options supplied as keyword arguments when the tokenizer is called on dataset text.

    Field names and defaults follow the keyword arguments of
    transformers.PreTrainedTokenizer.__call__; consult that API for the meaning of each option.
    """
    add_special_tokens: bool = field(default=True, metadata={"help": ""})
    padding: bool = field(default=False, metadata={"help": ""})
    # The default is None (not False), so the annotation has to be Optional.
    truncation: Optional[bool] = field(default=None, metadata={"help": ""})
    max_length: Optional[int] = field(default=None, metadata={"help": ""})
    stride: int = field(default=0, metadata={"help": ""})
    is_split_into_words: bool = field(default=False, metadata={"help": ""})
    pad_to_multiple_of: Optional[int] = field(default=None, metadata={"help": ""})
    return_tensors: Optional[str] = field(default=None, metadata={"help": ""})
    return_token_type_ids: Optional[bool] = field(default=None, metadata={"help": ""})
    return_attention_mask: Optional[bool] = field(default=None, metadata={"help": ""})
    return_overflowing_tokens: bool = field(default=False, metadata={"help": ""})
    return_special_tokens_mask: bool = field(default=False, metadata={"help": ""})
    return_offsets_mapping: bool = field(default=False, metadata={"help": ""})
    return_length: bool = field(default=False, metadata={"help": ""})
    verbose: bool = field(default=True, metadata={"help": ""})

@dataclass
class ModelConfig:
"""
Expand Down
77 changes: 4 additions & 73 deletions presets/tuning/text-generation/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,8 @@
SUPPORTED_EXTENSIONS = {'csv', 'json', 'parquet', 'arrow', 'webdataset'}

class DatasetManager:
def __init__(self, config, tokenizer, tokenizer_params):
def __init__(self, config):
self.config = config
self.tokenizer_params = tokenizer_params
self.tokenizer = tokenizer
self.dataset = None
self.dataset_text_field = None # Set this field if dataset consists of singular text column

Expand All @@ -31,6 +29,9 @@ def select_and_rename_columns(self, columns_to_select, rename_map=None):
self.dataset = self.dataset.rename_column(old_name, new_name)

def load_data(self):
# OAI Compliant: https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset
# https://github.com/huggingface/trl/blob/main/trl/extras/dataset_formatting.py
# https://huggingface.co/docs/trl/en/sft_trainer#dataset-format-support
if self.config.dataset_path:
dataset_path = os.path.join("/mnt", self.config.dataset_path.strip("/"))
else:
Expand Down Expand Up @@ -85,73 +86,3 @@ def split_dataset(self):
# Return the dataset held in self.dataset.
# check_dataset_loaded() is called first; presumably it rejects access before
# load_data() has populated self.dataset (its body is not visible here - confirm).
def get_dataset(self):
self.check_dataset_loaded()
return self.dataset

# Pick a formatting routine from the configured column names, checked in this order:
#   1. messages_column                    -> format_conversational()
#   2. context_column and response_column -> format_instruct()
#   3. response_column only               -> tokenize that column, then format_text()
# NOTE(review): indentation was lost in this listing, so it cannot be confirmed from here
# whether the final format_text() call and the dataset_text_field assignment belong to
# branch 3 only or run unconditionally - verify against the original file.
def format_and_preprocess(self):
# OAI Compliant: https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset
# https://github.com/huggingface/trl/blob/main/trl/extras/dataset_formatting.py
# https://huggingface.co/docs/trl/en/sft_trainer#dataset-format-support
if self.config.messages_column:
self.format_conversational()
elif self.config.context_column and self.config.response_column:
self.format_instruct()
elif self.config.response_column:
# Run the tokenizer over the response column in batches, expanding
# self.tokenizer_params into keyword arguments of the tokenizer call.
self.dataset = self.dataset.map(
lambda example: self.tokenizer(example[self.config.response_column], **self.tokenizer_params),
batched=True
)
self.format_text()
# Record which column holds the text for single-text-column datasets.
self.dataset_text_field = self.config.response_column

def format_text(self):
    """Validate the configured response column and pass it alone to select_and_rename_columns."""
    self.check_dataset_loaded()
    text_column = self.config.response_column
    self.check_column_exists(text_column)
    self.select_and_rename_columns([text_column])

def format_instruct(self):
    """Validate the context/response columns and map them to "prompt"/"completion"."""
    self.check_dataset_loaded()
    cfg = self.config
    source_columns = [cfg.context_column, cfg.response_column]
    for name in source_columns:
        self.check_column_exists(name)

    # A rename entry is only needed for a column that does not already carry its target name.
    target_names = ("prompt", "completion")
    renames = {src: dst for src, dst in zip(source_columns, target_names) if src != dst}
    self.select_and_rename_columns(source_columns, renames)

def format_conversational(self):
    """Validate the configured messages column and select it under the name "messages"."""
    self.check_dataset_loaded()
    source = self.config.messages_column
    # The configured column must be present before it is selected.
    self.check_column_exists(source)

    # No rename entry is needed when the column is already called "messages".
    canonical = "messages"
    renames = {} if source == canonical else {source: canonical}
    self.select_and_rename_columns([source], renames)

# Consider supporting in future
# https://github.com/huggingface/trl/pull/444
# def format_instruction_based_fn(self, examples):
# output_text = []
# for i in range(len(examples[self.config.context_column])):
# instruction = examples[self.config.instruction_column][i]
# context = examples[self.config.context_column][i]
# response = examples[self.config.response_column][i]
# text = f'''Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

# ### Instruction:
# {instruction}

# ### Input:
# {context}

# ### Response:
# {response}
# '''
# output_text.append(text)
# return output_text
7 changes: 1 addition & 6 deletions presets/tuning/text-generation/fine_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
parsed_configs = parse_configs(CONFIG_YAML)

model_config = parsed_configs.get('ModelConfig')
tk_params = parsed_configs.get('TokenizerParams')
bnb_config = parsed_configs.get('QuantizationConfig')
ext_lora_config = parsed_configs.get('LoraConfig')
ta_args = parsed_configs.get('TrainingArguments')
Expand All @@ -41,9 +40,6 @@
bnb_config = BitsAndBytesConfig(**bnb_config_args)
enable_qlora = bnb_config.is_quantizable()

# Load Tokenizer Params
tk_params = asdict(tk_params)

# Load the Pre-Trained Tokenizer
tokenizer_args = {key: value for key, value in model_args.items() if key != "torch_dtype"}
tokenizer = AutoTokenizer.from_pretrained(**tokenizer_args)
Expand Down Expand Up @@ -82,7 +78,7 @@
model.config.use_cache = False
model.print_trainable_parameters()

dm = DatasetManager(ds_config, tokenizer, tk_params)
dm = DatasetManager(ds_config)
# Load the dataset
dm.load_data()
if not dm.get_dataset():
Expand All @@ -93,7 +89,6 @@
if ds_config.shuffle_dataset:
dm.shuffle_dataset()

dm.format_and_preprocess()
train_dataset, eval_dataset = dm.split_dataset()

# checkpoint_callback = CheckpointCallback()
Expand Down
4 changes: 1 addition & 3 deletions presets/tuning/text-generation/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,12 @@
from dataclasses import asdict, fields

import yaml
from cli import (DatasetConfig, ExtDataCollator, ExtLoraConfig, ModelConfig,
QuantizationConfig, TokenizerParams)
from cli import (DatasetConfig, ExtDataCollator, ExtLoraConfig, ModelConfig, QuantizationConfig)
from transformers import HfArgumentParser, TrainingArguments

# Mapping from config section names to data classes
CONFIG_CLASS_MAP = {
'ModelConfig': ModelConfig,
'TokenizerParams': TokenizerParams,
'QuantizationConfig': QuantizationConfig,
'LoraConfig': ExtLoraConfig,
'TrainingArguments': TrainingArguments,
Expand Down

0 comments on commit faa1fc7

Please sign in to comment.