Skip to content

Commit

Permalink
fix: Refactor Naming Conventions, Update Dependencies, Enhance Exampl…
Browse files Browse the repository at this point in the history
…es, and Add Volume Validation Check (#470)

**Reason for Change**:
This PR includes a bunch of renaming and chore changes. It includes: 

1. Package renaming to avoid hyphens in package names for phi-2 and
phi-3.
2. Updating pip dependencies required for phi-3
3. Use huggingface naming for plugins with hyphens
4. Add phi3 to examples folder
5. Readme: Update phi-3 example requests 
6. Change ConfigTemplate back to Config and update the comments
7. Fix a device error encountered by inference API UTs by explicitly
specifying device
8. Raise a "not implemented" error in validation when a volume is specified
  • Loading branch information
ishaansehgal99 committed Jun 18, 2024
1 parent e592eb3 commit f2b3504
Show file tree
Hide file tree
Showing 22 changed files with 156 additions and 137 deletions.
11 changes: 6 additions & 5 deletions api/v1alpha1/params_validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,18 @@ package v1alpha1
import (
"context"
"fmt"
"path/filepath"
"reflect"
"strings"

"github.com/azure/kaito/pkg/k8sclient"
"github.com/azure/kaito/pkg/utils"
"gopkg.in/yaml.v2"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/runtime"
"knative.dev/pkg/apis"
"path/filepath"
"reflect"
"sigs.k8s.io/controller-runtime/pkg/client"
"strings"
)

type Config struct {
Expand Down Expand Up @@ -257,9 +258,9 @@ func (r *TuningSpec) validateConfigMap(ctx context.Context, namespace string, me
err := k8sclient.Client.Get(ctx, client.ObjectKey{Name: configMapName, Namespace: namespace}, &cm)
if err != nil {
if errors.IsNotFound(err) {
errs = errs.Also(apis.ErrGeneric(fmt.Sprintf("ConfigMap '%s' specified in 'config' not found in namespace '%s'", r.ConfigTemplate, namespace), "config"))
errs = errs.Also(apis.ErrGeneric(fmt.Sprintf("ConfigMap '%s' specified in 'config' not found in namespace '%s'", r.Config, namespace), "config"))
} else {
errs = errs.Also(apis.ErrGeneric(fmt.Sprintf("Failed to get ConfigMap '%s' in namespace '%s': %v", r.ConfigTemplate, namespace, err), "config"))
errs = errs.Also(apis.ErrGeneric(fmt.Sprintf("Failed to get ConfigMap '%s' in namespace '%s': %v", r.Config, namespace, err), "config"))
}
} else {
if err := validateConfigMapSchema(&cm); err != nil {
Expand Down
9 changes: 4 additions & 5 deletions api/v1alpha1/workspace_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,12 +149,11 @@ type TuningSpec struct {
// Method specifies the Parameter-Efficient Fine-Tuning(PEFT) method, such as lora, qlora, used for the tuning.
// +optional
Method TuningMethod `json:"method,omitempty"`
// ConfigTemplate specifies the name of the configmap that contains the basic tuning arguments.
// A separate configmap will be generated based on the ConfigTemplate and the preset model name, and used by
// the tuning Job. If specified, the congfigmap needs to be in the same namespace of the workspace custom resource.
// If not specified, a default ConfigTemplate is used based on the specified tuning method.
// Config specifies the name of a custom ConfigMap that contains tuning arguments.
// If specified, the ConfigMap must be in the same namespace as the Workspace custom resource.
// If not specified, a default Config is used based on the specified tuning method.
// +optional
ConfigTemplate string `json:"configTemplate,omitempty"`
Config string `json:"config,omitempty"`
// Input describes the input used by the tuning method.
Input *DataSource `json:"input"`
// Output specified where to store the tuning output.
Expand Down
15 changes: 12 additions & 3 deletions api/v1alpha1/workspace_validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ func (r *TuningSpec) validateCreate(ctx context.Context, workspaceNamespace stri
if methodLowerCase != string(TuningMethodLora) && methodLowerCase != string(TuningMethodQLora) {
errs = errs.Also(apis.ErrInvalidValue(r.Method, "Method"))
}
if r.ConfigTemplate == "" {
if r.Config == "" {
klog.InfoS("Tuning config not specified. Using default based on method.")
releaseNamespace, err := utils.GetReleaseNamespace()
if err != nil {
Expand All @@ -149,7 +149,7 @@ func (r *TuningSpec) validateCreate(ctx context.Context, workspaceNamespace stri
errs = errs.Also(apis.ErrGeneric(fmt.Sprintf("Failed to evaluate validateConfigMap: %v", err), "Config"))
}
} else {
if err := r.validateConfigMap(ctx, workspaceNamespace, methodLowerCase, r.ConfigTemplate); err != nil {
if err := r.validateConfigMap(ctx, workspaceNamespace, methodLowerCase, r.Config); err != nil {
errs = errs.Also(apis.ErrGeneric(fmt.Sprintf("Failed to evaluate validateConfigMap: %v", err), "Config"))
}
}
Expand Down Expand Up @@ -200,6 +200,7 @@ func (r *DataSource) validateCreate() (errs *apis.FieldError) {
sourcesSpecified++
}
if r.Volume != nil {
errs = errs.Also(apis.ErrInvalidValue("Volume support is not implemented yet", "Volume"))
sourcesSpecified++
}
// Regex checks for a / and a colon followed by a tag
Expand All @@ -223,6 +224,9 @@ func (r *DataSource) validateUpdate(old *DataSource, isTuning bool) (errs *apis.
if isTuning && !reflect.DeepEqual(old.Name, r.Name) {
errs = errs.Also(apis.ErrInvalidValue("During tuning Name field cannot be changed once set", "Name"))
}
if r.Volume != nil {
errs = errs.Also(apis.ErrInvalidValue("Volume support is not implemented yet", "Volume"))
}
oldURLs := make([]string, len(old.URLs))
copy(oldURLs, old.URLs)
sort.Strings(oldURLs)
Expand Down Expand Up @@ -255,7 +259,9 @@ func (r *DataSource) validateUpdate(old *DataSource, isTuning bool) (errs *apis.

func (r *DataDestination) validateCreate() (errs *apis.FieldError) {
destinationsSpecified := 0
// TODO: Implement Volumes
if r.Volume != nil {
errs = errs.Also(apis.ErrInvalidValue("Volume support is not implemented yet", "Volume"))
destinationsSpecified++
}
if r.Image != "" {
Expand All @@ -279,7 +285,10 @@ func (r *DataDestination) validateCreate() (errs *apis.FieldError) {
}

func (r *DataDestination) validateUpdate(old *DataDestination) (errs *apis.FieldError) {
// TODO: Check if the Volume is changed.
// TODO: Implement Volumes
if r.Volume != nil {
errs = errs.Also(apis.ErrInvalidValue("Volume support is not implemented yet", "Volume"))
}
if old.Image != r.Image {
errs = errs.Also(apis.ErrInvalidValue("Image field cannot be changed once set", "Image"))
}
Expand Down
81 changes: 40 additions & 41 deletions api/v1alpha1/workspace_validation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -893,8 +893,8 @@ func TestTuningSpecValidateCreate(t *testing.T) {
{
name: "All fields valid",
tuningSpec: &TuningSpec{
Input: &DataSource{Name: "valid-input", Volume: &v1.VolumeSource{}},
Output: &DataDestination{Volume: &v1.VolumeSource{}},
Input: &DataSource{Name: "valid-input", Image: "AZURE_ACR.azurecr.io/test:0.0.0"},
Output: &DataDestination{Image: "AZURE_ACR.azurecr.io/test:0.0.0", ImagePushSecret: "secret"},
Preset: &PresetSpec{PresetMeta: PresetMeta{Name: ModelName("test-validation")}},
Method: TuningMethodLora,
},
Expand All @@ -904,8 +904,8 @@ func TestTuningSpecValidateCreate(t *testing.T) {
{
name: "Verify QLoRA Config",
tuningSpec: &TuningSpec{
Input: &DataSource{Name: "valid-input", Volume: &v1.VolumeSource{}},
Output: &DataDestination{Volume: &v1.VolumeSource{}},
Input: &DataSource{Name: "valid-input", Image: "AZURE_ACR.azurecr.io/test:0.0.0"},
Output: &DataDestination{Image: "AZURE_ACR.azurecr.io/test:0.0.0", ImagePushSecret: "secret"},
Preset: &PresetSpec{PresetMeta: PresetMeta{Name: ModelName("test-validation")}},
Method: TuningMethodQLora,
},
Expand All @@ -915,7 +915,7 @@ func TestTuningSpecValidateCreate(t *testing.T) {
{
name: "Missing Input",
tuningSpec: &TuningSpec{
Output: &DataDestination{Volume: &v1.VolumeSource{}},
Output: &DataDestination{Image: "AZURE_ACR.azurecr.io/test:0.0.0", ImagePushSecret: ""},
Preset: &PresetSpec{PresetMeta: PresetMeta{Name: ModelName("test-validation")}},
Method: TuningMethodLora,
},
Expand All @@ -936,7 +936,7 @@ func TestTuningSpecValidateCreate(t *testing.T) {
name: "Missing Preset",
tuningSpec: &TuningSpec{
Input: &DataSource{Name: "valid-input"},
Output: &DataDestination{Volume: &v1.VolumeSource{}},
Output: &DataDestination{Image: "AZURE_ACR.azurecr.io/test:0.0.0", ImagePushSecret: ""},
Method: TuningMethodLora,
},
wantErr: true,
Expand All @@ -946,7 +946,7 @@ func TestTuningSpecValidateCreate(t *testing.T) {
name: "Invalid Preset",
tuningSpec: &TuningSpec{
Input: &DataSource{Name: "valid-input"},
Output: &DataDestination{Volume: &v1.VolumeSource{}},
Output: &DataDestination{Image: "AZURE_ACR.azurecr.io/test:0.0.0", ImagePushSecret: ""},
Preset: &PresetSpec{PresetMeta: PresetMeta{Name: ModelName("invalid-preset")}},
Method: TuningMethodLora,
},
Expand All @@ -957,7 +957,7 @@ func TestTuningSpecValidateCreate(t *testing.T) {
name: "Invalid Method",
tuningSpec: &TuningSpec{
Input: &DataSource{Name: "valid-input"},
Output: &DataDestination{Volume: &v1.VolumeSource{}},
Output: &DataDestination{Image: "AZURE_ACR.azurecr.io/test:0.0.0", ImagePushSecret: ""},
Preset: &PresetSpec{PresetMeta: PresetMeta{Name: ModelName("test-validation")}},
Method: "invalid-method",
},
Expand Down Expand Up @@ -999,13 +999,13 @@ func TestTuningSpecValidateUpdate(t *testing.T) {
name: "No changes",
oldTuning: &TuningSpec{
Input: &DataSource{Name: "input1"},
Output: &DataDestination{Volume: &v1.VolumeSource{}},
Output: &DataDestination{Image: "AZURE_ACR.azurecr.io/test:0.0.0"},
Preset: &PresetSpec{PresetMeta: PresetMeta{Name: ModelName("test-validation")}},
Method: TuningMethodLora,
},
newTuning: &TuningSpec{
Input: &DataSource{Name: "input1"},
Output: &DataDestination{Volume: &v1.VolumeSource{}},
Output: &DataDestination{Image: "AZURE_ACR.azurecr.io/test:0.0.0"},
Preset: &PresetSpec{PresetMeta: PresetMeta{Name: ModelName("test-validation")}},
Method: TuningMethodLora,
},
Expand Down Expand Up @@ -1072,7 +1072,7 @@ func TestDataSourceValidateCreate(t *testing.T) {
{
name: "Volume specified only",
dataSource: &DataSource{
Volume: &v1.VolumeSource{},
Image: "AZURE_ACR.azurecr.io/test:0.0.0",
},
wantErr: false,
},
Expand Down Expand Up @@ -1105,20 +1105,19 @@ func TestDataSourceValidateCreate(t *testing.T) {
wantErr: true,
errField: "Exactly one of URLs, Volume, or Image must be specified",
},
{
name: "URLs and Volume specified",
dataSource: &DataSource{
URLs: []string{"http://example.com/data"},
Volume: &v1.VolumeSource{},
},
wantErr: true,
errField: "Exactly one of URLs, Volume, or Image must be specified",
},
// {
// name: "URLs and Volume specified",
// dataSource: &DataSource{
// URLs: []string{"http://example.com/data"},
// Volume: &v1.VolumeSource{},
// },
// wantErr: true,
// errField: "Exactly one of URLs, Volume, or Image must be specified",
// },
{
name: "All fields specified",
dataSource: &DataSource{
URLs: []string{"http://example.com/data"},
Volume: &v1.VolumeSource{},
Image: "aimodels.azurecr.io/data-image:latest",
},
wantErr: true,
Expand Down Expand Up @@ -1154,13 +1153,13 @@ func TestDataSourceValidateUpdate(t *testing.T) {
name: "No changes",
oldSource: &DataSource{
URLs: []string{"http://example.com/data1", "http://example.com/data2"},
Volume: &v1.VolumeSource{},
// Volume: &v1.VolumeSource{},
Image: "data-image:latest",
ImagePullSecrets: []string{"secret1", "secret2"},
},
newSource: &DataSource{
URLs: []string{"http://example.com/data2", "http://example.com/data1"}, // Note the different order, should not matter
Volume: &v1.VolumeSource{},
// Volume: &v1.VolumeSource{},
Image: "data-image:latest",
ImagePullSecrets: []string{"secret2", "secret1"}, // Note the different order, should not matter
},
Expand Down Expand Up @@ -1245,13 +1244,13 @@ func TestDataDestinationValidateCreate(t *testing.T) {
wantErr: true,
errField: "At least one of Volume or Image must be specified",
},
{
name: "Volume specified only",
dataDestination: &DataDestination{
Volume: &v1.VolumeSource{},
},
wantErr: false,
},
// {
// name: "Volume specified only",
// dataDestination: &DataDestination{
// Volume: &v1.VolumeSource{},
// },
// wantErr: false,
// },
{
name: "Image specified only",
dataDestination: &DataDestination{
Expand All @@ -1276,15 +1275,15 @@ func TestDataDestinationValidateCreate(t *testing.T) {
},
wantErr: true,
},
{
name: "Both fields specified",
dataDestination: &DataDestination{
Volume: &v1.VolumeSource{},
Image: "aimodels.azurecr.io/data-image:latest",
ImagePushSecret: "imagePushSecret",
},
wantErr: false,
},
// {
// name: "Both fields specified",
// dataDestination: &DataDestination{
// Volume: &v1.VolumeSource{},
// Image: "aimodels.azurecr.io/data-image:latest",
// ImagePushSecret: "imagePushSecret",
// },
// wantErr: false,
// },
}

for _, tt := range tests {
Expand Down Expand Up @@ -1314,12 +1313,12 @@ func TestDataDestinationValidateUpdate(t *testing.T) {
{
name: "No changes",
oldDest: &DataDestination{
Volume: &v1.VolumeSource{},
// Volume: &v1.VolumeSource{},
Image: "old-image:latest",
ImagePushSecret: "old-secret",
},
newDest: &DataDestination{
Volume: &v1.VolumeSource{},
// Volume: &v1.VolumeSource{},
Image: "old-image:latest",
ImagePushSecret: "old-secret",
},
Expand Down
9 changes: 4 additions & 5 deletions charts/kaito/workspace/crds/kaito.sh_workspaces.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -299,12 +299,11 @@ spec:
type: object
tuning:
properties:
configTemplate:
config:
description: |-
ConfigTemplate specifies the name of the configmap that contains the basic tuning arguments.
A separate configmap will be generated based on the ConfigTemplate and the preset model name, and used by
the tuning Job. If specified, the congfigmap needs to be in the same namespace of the workspace custom resource.
If not specified, a default ConfigTemplate is used based on the specified tuning method.
Config specifies the name of a custom ConfigMap that contains tuning arguments.
If specified, the ConfigMap must be in the same namespace as the Workspace custom resource.
If not specified, a default Config is used based on the specified tuning method.
type: string
input:
description: Input describes the input used by the tuning method.
Expand Down
4 changes: 2 additions & 2 deletions cmd/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@ import (
_ "github.com/azure/kaito/presets/models/llama2"
_ "github.com/azure/kaito/presets/models/llama2chat"
_ "github.com/azure/kaito/presets/models/mistral"
_ "github.com/azure/kaito/presets/models/phi-2"
_ "github.com/azure/kaito/presets/models/phi-3"
_ "github.com/azure/kaito/presets/models/phi2"
_ "github.com/azure/kaito/presets/models/phi3"
)
9 changes: 4 additions & 5 deletions config/crd/bases/kaito.sh_workspaces.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -299,12 +299,11 @@ spec:
type: object
tuning:
properties:
configTemplate:
config:
description: |-
ConfigTemplate specifies the name of the configmap that contains the basic tuning arguments.
A separate configmap will be generated based on the ConfigTemplate and the preset model name, and used by
the tuning Job. If specified, the congfigmap needs to be in the same namespace of the workspace custom resource.
If not specified, a default ConfigTemplate is used based on the specified tuning method.
Config specifies the name of a custom ConfigMap that contains tuning arguments.
If specified, the ConfigMap must be in the same namespace as the Workspace custom resource.
If not specified, a default Config is used based on the specified tuning method.
type: string
input:
description: Input describes the input used by the tuning method.
Expand Down
20 changes: 0 additions & 20 deletions examples/fine-tuning/kaito_workspace_tuning_falcon_7b.yaml

This file was deleted.

19 changes: 19 additions & 0 deletions examples/fine-tuning/kaito_workspace_tuning_phi_3.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Example Workspace manifest that fine-tunes the phi-3-mini-128k-instruct
# preset using the qlora tuning method.
apiVersion: kaito.sh/v1alpha1
kind: Workspace
metadata:
name: workspace-tuning-phi-3
resource:
instanceType: "Standard_NC6s_v3"
labelSelector:
matchLabels:
app: tuning-phi-3
# No `config` ConfigMap name is set here, so validation falls back to a
# default config chosen from the tuning method.
tuning:
preset:
name: phi-3-mini-128k-instruct
# Validation accepts only lora or qlora (compared in lower case).
method: qlora
# Input used by the tuning method. Exactly one of urls, volume, or image must
# be specified; volume is currently rejected as not implemented.
input:
urls:
- "https://huggingface.co/datasets/philschmid/dolly-15k-oai-style/resolve/main/data/train-00000-of-00001-54e3756291ca09c6.parquet?download=true"
# Where the tuning output is stored. The ACR_REPO_HERE, ADAPTER_HERE, and
# ACR_REGISTRY_SECRET_HERE values are placeholders to replace before applying.
output:
image: "ACR_REPO_HERE.azurecr.io/ADAPTER_HERE:0.0.1" # Tuning Output ACR Path
imagePushSecret: ACR_REGISTRY_SECRET_HERE
13 changes: 13 additions & 0 deletions examples/inference/kaito_workspace_phi_3.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Example Workspace manifest that runs inference with the
# phi-3-mini-4k-instruct preset.
apiVersion: kaito.sh/v1alpha1
kind: Workspace
metadata:
name: workspace-phi-3-mini
resource:
instanceType: "Standard_NC6s_v3"
labelSelector:
matchLabels:
apps: phi-3
# Only the preset name is given; no other inference settings are set in this example.
inference:
preset:
name: phi-3-mini-4k-instruct
# Note: This configuration also works with the phi-3-mini-128k-instruct preset
Loading

0 comments on commit f2b3504

Please sign in to comment.