Skip to content

Commit

Permalink
feat: Fine Tune (Part 10) - Updating fine tuning py (#371)
Browse files Browse the repository at this point in the history
**Reason for Change**:
Updating the fine tuning program. Include fine tuning and inference in one image.
  • Loading branch information
ishaansehgal99 committed May 28, 2024
1 parent e5344e7 commit 498f92b
Show file tree
Hide file tree
Showing 20 changed files with 648 additions and 254 deletions.
58 changes: 53 additions & 5 deletions api/v1alpha1/params_validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,17 @@ package v1alpha1
import (
"context"
"fmt"
"reflect"

"github.com/azure/kaito/pkg/k8sclient"
"github.com/azure/kaito/pkg/utils"
"gopkg.in/yaml.v2"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/runtime"
"knative.dev/pkg/apis"
"path/filepath"
"reflect"
"sigs.k8s.io/controller-runtime/pkg/client"
"strings"
)

type Config struct {
Expand Down Expand Up @@ -94,15 +95,59 @@ func (t *TrainingConfig) UnmarshalYAML(unmarshal func(interface{}) error) error
return nil
}

func validateMethodViaConfigMap(cm *corev1.ConfigMap, methodLowerCase string) *apis.FieldError {
func UnmarshalTrainingConfig(cm *corev1.ConfigMap) (*Config, *apis.FieldError) {
trainingConfigYAML, ok := cm.Data["training_config.yaml"]
if !ok {
return apis.ErrGeneric(fmt.Sprintf("ConfigMap '%s' does not contain 'training_config.yaml' in namespace '%s'", cm.Name, cm.Namespace), "config")
return nil, apis.ErrGeneric(fmt.Sprintf("ConfigMap '%s' does not contain 'training_config.yaml' in namespace '%s'", cm.Name, cm.Namespace), "config")
}

var config Config
if err := yaml.Unmarshal([]byte(trainingConfigYAML), &config); err != nil {
return apis.ErrGeneric(fmt.Sprintf("Failed to parse 'training_config.yaml' in ConfigMap '%s' in namespace '%s': %v", cm.Name, cm.Namespace, err), "config")
return nil, apis.ErrGeneric(fmt.Sprintf("Failed to parse 'training_config.yaml' in ConfigMap '%s' in namespace '%s': %v", cm.Name, cm.Namespace, err), "config")
}
return &config, nil
}

// validateTrainingArgsViaConfigMap validates the TrainingArguments section of
// the 'training_config.yaml' entry stored in cm.
//
// When an 'output_dir' is present it must be a string and, after being joined
// onto the base directory and cleaned, must resolve to a location strictly
// inside that base directory. A missing TrainingArguments section or a missing
// 'output_dir' is not an error.
//
// It returns nil when validation passes, the error produced by
// UnmarshalTrainingConfig when the ConfigMap cannot be parsed, or a FieldError
// scoped to "output_dir" when that value is malformed or escapes the base
// directory.
func validateTrainingArgsViaConfigMap(cm *corev1.ConfigMap) *apis.FieldError {
	// Base directory every user-specified output_dir is resolved against.
	const baseDir = "/mnt"

	config, err := UnmarshalTrainingConfig(cm)
	if err != nil {
		return err
	}

	// Indexing a nil map is safe in Go, so no separate nil check is required.
	trainingArgsRaw, hasTrainingArgs := config.TrainingConfig.TrainingArguments["TrainingArguments"]
	if !hasTrainingArgs {
		return nil
	}

	// If specified, ensure output dir is of type string.
	outputDirValue, found, searchErr := utils.SearchRawExtension(trainingArgsRaw, "output_dir")
	if searchErr != nil {
		return apis.ErrGeneric(fmt.Sprintf("Failed to parse 'output_dir' in ConfigMap '%s' in namespace '%s': %v", cm.Name, cm.Namespace, searchErr), "output_dir")
	}
	if found {
		userSpecifiedDir, isString := outputDirValue.(string)
		if !isString {
			return apis.ErrInvalidValue(fmt.Sprintf("output_dir is not a string in ConfigMap '%s' in namespace '%s'", cm.Name, cm.Namespace), "output_dir")
		}

		// Ensure the user-specified directory is strictly under baseDir.
		// The prefix must end with a path separator: a bare prefix match on
		// "/mnt" would also accept siblings such as "/mnt2/x" (reachable via
		// "../mnt2/x"). Requiring the separator also rejects baseDir itself.
		cleanPath := filepath.Clean(filepath.Join(baseDir, userSpecifiedDir))
		if !strings.HasPrefix(cleanPath, baseDir+string(filepath.Separator)) {
			return apis.ErrInvalidValue(fmt.Sprintf("Invalid output_dir specified: '%s', must be a directory", userSpecifiedDir), "output_dir")
		}
	}

	// TODO: Here we perform the tuning GPU Memory Checks!
	return nil
}

func validateMethodViaConfigMap(cm *corev1.ConfigMap, methodLowerCase string) *apis.FieldError {
config, err := UnmarshalTrainingConfig(cm)
if err != nil {
return err
}

// Validate QuantizationConfig if it exists
Expand Down Expand Up @@ -225,6 +270,9 @@ func (r *TuningSpec) validateConfigMap(ctx context.Context, namespace string, me
if err := validateMethodViaConfigMap(&cm, methodLowerCase); err != nil {
errs = errs.Also(err)
}
if err := validateTrainingArgsViaConfigMap(&cm); err != nil {
errs = errs.Also(err)
}
}
return errs
}
6 changes: 3 additions & 3 deletions api/v1alpha1/workspace_validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -336,12 +336,12 @@ func (i *InferenceSpec) validateCreate() (errs *apis.FieldError) {
errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Unsupported inference preset name %s", presetName), "presetName"))
}
// Validate private preset has private image specified
if plugin.KaitoModelRegister.MustGet(string(i.Preset.Name)).GetInferenceParameters().ImageAccessMode == "private" &&
i.Preset.PresetMeta.AccessMode != "private" {
if plugin.KaitoModelRegister.MustGet(string(i.Preset.Name)).GetInferenceParameters().ImageAccessMode == string(ModelImageAccessModePrivate) &&
i.Preset.PresetMeta.AccessMode != ModelImageAccessModePrivate {
errs = errs.Also(apis.ErrGeneric("This preset only supports private AccessMode, AccessMode must be private to continue"))
}
// Additional validations for Preset
if i.Preset.PresetMeta.AccessMode == "private" && i.Preset.PresetOptions.Image == "" {
if i.Preset.PresetMeta.AccessMode == ModelImageAccessModePrivate && i.Preset.PresetOptions.Image == "" {
errs = errs.Also(apis.ErrGeneric("When AccessMode is private, an image must be provided in PresetOptions"))
}
// Note: we don't enforce private access mode to have image secrets, in case anonymous pulling is enabled
Expand Down
12 changes: 6 additions & 6 deletions api/v1alpha1/workspace_validation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,15 @@ type testModelPrivate struct{}

func (*testModelPrivate) GetInferenceParameters() *model.PresetParam {
return &model.PresetParam{
ImageAccessMode: "private",
ImageAccessMode: string(ModelImageAccessModePrivate),
GPUCountRequirement: gpuCountRequirement,
TotalGPUMemoryRequirement: totalGPUMemoryRequirement,
PerGPUMemoryRequirement: perGPUMemoryRequirement,
}
}
func (*testModelPrivate) GetTuningParameters() *model.PresetParam {
return &model.PresetParam{
ImageAccessMode: "private",
ImageAccessMode: string(ModelImageAccessModePrivate),
GPUCountRequirement: gpuCountRequirement,
TotalGPUMemoryRequirement: totalGPUMemoryRequirement,
PerGPUMemoryRequirement: perGPUMemoryRequirement,
Expand Down Expand Up @@ -121,7 +121,7 @@ func defaultConfigMapManifest() *v1.ConfigMap {
bias: "none"
TrainingArguments:
output_dir: "."
output_dir: "output"
num_train_epochs: 4
auto_find_batch_size: true
ddp_find_unused_parameters: false
Expand Down Expand Up @@ -168,7 +168,7 @@ func qloraConfigMapManifest() *v1.ConfigMap {
bias: "none"
TrainingArguments:
output_dir: "."
output_dir: "output"
num_train_epochs: 4
auto_find_batch_size: true
ddp_find_unused_parameters: false
Expand Down Expand Up @@ -461,7 +461,7 @@ func TestInferenceSpecValidateCreate(t *testing.T) {
Preset: &PresetSpec{
PresetMeta: PresetMeta{
Name: ModelName("test-validation"),
AccessMode: "private",
AccessMode: ModelImageAccessModePrivate,
},
PresetOptions: PresetOptions{},
},
Expand All @@ -488,7 +488,7 @@ func TestInferenceSpecValidateCreate(t *testing.T) {
Preset: &PresetSpec{
PresetMeta: PresetMeta{
Name: ModelName("test-validation"),
AccessMode: "public",
AccessMode: ModelImageAccessModePublic,
},
},
},
Expand Down
9 changes: 4 additions & 5 deletions charts/kaito/workspace/templates/lora-params.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,13 @@ data:
load_in_4bit: false
LoraConfig:
r: 16
lora_alpha: 32
r: 8
lora_alpha: 8
target_modules: "query_key_value"
lora_dropout: 0.05
bias: "none"
lora_dropout: 0.0
TrainingArguments:
output_dir: "."
output_dir: "/mnt/results"
num_train_epochs: 4
auto_find_batch_size: true
ddp_find_unused_parameters: false
Expand Down
9 changes: 4 additions & 5 deletions charts/kaito/workspace/templates/qlora-params.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,13 @@ data:
bnb_4bit_use_double_quant: true
LoraConfig:
r: 16
lora_alpha: 32
r: 8
lora_alpha: 8
target_modules: "query_key_value"
lora_dropout: 0.05
bias: "none"
lora_dropout: 0.0
TrainingArguments:
output_dir: "."
output_dir: "/mnt/results"
num_train_epochs: 4
auto_find_batch_size: true
ddp_find_unused_parameters: false
Expand Down
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,20 @@ RUN echo $VERSION > /workspace/tfs/version.txt
# First, copy just the preset files and install dependencies
# This is done before copying the code to utilize Docker's layer caching and
# avoid reinstalling dependencies unless the requirements file changes.
COPY kaito/presets/inference/${MODEL_TYPE}/requirements.txt /workspace/tfs/requirements.txt
RUN pip install --no-cache-dir -r requirements.txt
# Inference
COPY kaito/presets/inference/${MODEL_TYPE}/requirements.txt /workspace/tfs/inference-requirements.txt
RUN pip install --no-cache-dir -r inference-requirements.txt

COPY kaito/presets/inference/${MODEL_TYPE}/inference_api.py /workspace/tfs/inference_api.py

# Fine Tuning
COPY kaito/presets/tuning/${MODEL_TYPE}/requirements.txt /workspace/tfs/tuning-requirements.txt
RUN pip install --no-cache-dir -r tuning-requirements.txt

COPY kaito/presets/tuning/${MODEL_TYPE}/cli.py /workspace/tfs/cli.py
COPY kaito/presets/tuning/${MODEL_TYPE}/fine_tuning.py /workspace/tfs/fine_tuning.py
COPY kaito/presets/tuning/${MODEL_TYPE}/parser.py /workspace/tfs/parser.py
COPY kaito/presets/tuning/${MODEL_TYPE}/dataset.py /workspace/tfs/dataset.py

# Copy the entire model weights to the weights directory
COPY ${WEIGHTS_PATH} /workspace/tfs/weights
23 changes: 0 additions & 23 deletions docker/presets/tuning/Dockerfile

This file was deleted.

2 changes: 1 addition & 1 deletion pkg/inference/preset-inferences.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func updateTorchParamsForDistributedInference(ctx context.Context, kubeClient cl

func GetInferenceImageInfo(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace, presetObj *model.PresetParam) (string, []corev1.LocalObjectReference) {
imagePullSecretRefs := []corev1.LocalObjectReference{}
if presetObj.ImageAccessMode == "private" {
if presetObj.ImageAccessMode == string(kaitov1alpha1.ModelImageAccessModePrivate) {
imageName := workspaceObj.Inference.Preset.PresetOptions.Image
for _, secretName := range workspaceObj.Inference.Preset.PresetOptions.ImagePullSecrets {
imagePullSecretRefs = append(imagePullSecretRefs, corev1.LocalObjectReference{Name: secretName})
Expand Down
Loading

0 comments on commit 498f92b

Please sign in to comment.