Skip to content

Commit

Permalink
feat: Automate the adapters manifests (#463)
Browse files Browse the repository at this point in the history
**Reason for Change**:
<!-- What does this PR improve or fix in Kaito? Why is it needed? -->

**Requirements**

- [ ] added unit tests and e2e tests (if applicable).

**Issue Fixed**:
<!-- If this PR fixes GitHub issue 4321, add "Fixes #4321" to the next
line. -->

**Notes for Reviewers**:

---------

Signed-off-by: Bangqi Zhu <[email protected]>
Co-authored-by: Bangqi Zhu <[email protected]>
  • Loading branch information
bangqipropel and Bangqi Zhu committed Jun 12, 2024
1 parent 5f03f40 commit 4db58ce
Show file tree
Hide file tree
Showing 10 changed files with 385 additions and 12 deletions.
73 changes: 72 additions & 1 deletion api/v1alpha1/workspace_validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"reflect"
"regexp"
"sort"
"strconv"
"strings"

"github.com/azure/kaito/pkg/utils"
Expand All @@ -27,6 +28,7 @@ const (

DefaultLoraConfigMapTemplate = "lora-params-template"
DefaultQloraConfigMapTemplate = "qlora-params-template"
MaxAdaptersNumber = 10
)

func (w *Workspace) SupportedVerbs() []admissionregistrationv1.OperationType {
Expand Down Expand Up @@ -58,7 +60,6 @@ func (w *Workspace) Validate(ctx context.Context) (errs *apis.FieldError) {
w.Resource.validateUpdate(&old.Resource).ViaField("resource"),
)
if w.Inference != nil {
// TODO: Add Adapter Spec Validation - Including DataSource Validation for Adapter
errs = errs.Also(w.Inference.validateUpdate(old.Inference).ViaField("inference"))
}
if w.Tuning != nil {
Expand Down Expand Up @@ -89,6 +90,44 @@ func (w *Workspace) validateUpdate(old *Workspace) (errs *apis.FieldError) {
return errs
}

func ValidateDNSSubdomain(name string) bool {
var dnsSubDomainRegexp = regexp.MustCompile(`^(?i:[a-z0-9]([-a-z0-9]*[a-z0-9])?)$`)
if len(name) < 1 || len(name) > 253 {
return false
}
return dnsSubDomainRegexp.MatchString(name)
}

func (r *AdapterSpec) validateCreateorUpdate() (errs *apis.FieldError) {
if r.Source == nil {
errs = errs.Also(apis.ErrMissingField("Source"))
} else {
errs = errs.Also(r.Source.validateCreate().ViaField("Adapters"))

if r.Source.Name == "" {
errs = errs.Also(apis.ErrMissingField("Name of Adapter field must be specified"))
} else if !ValidateDNSSubdomain(r.Source.Name) {
errs = errs.Also(apis.ErrMissingField("Name of Adapter must be a valid DNS subdomain value"))
}
if r.Source.Image == "" {
errs = errs.Also(apis.ErrMissingField("Image of Adapter field must be specified"))
}
if r.Strength == nil {
var defaultStrength = "1.0"
r.Strength = &defaultStrength
}
strength, err := strconv.ParseFloat(*r.Strength, 64)
if err != nil {
errs = errs.Also(apis.ErrGeneric(fmt.Sprintf("Invalid strength value for Adapter '%s': %v", r.Source.Name, err), "adapter"))
}
if strength < 0 || strength > 1.0 {
errs = errs.Also(apis.ErrGeneric(fmt.Sprintf("Strength value for Adapter '%s' must be between 0 and 1", r.Source.Name), "adapter"))
}

}
return errs
}

func (r *TuningSpec) validateCreate(ctx context.Context, workspaceNamespace string) (errs *apis.FieldError) {
methodLowerCase := strings.ToLower(string(r.Method))
if methodLowerCase != string(TuningMethodLora) && methodLowerCase != string(TuningMethodQLora) {
Expand Down Expand Up @@ -346,6 +385,16 @@ func (i *InferenceSpec) validateCreate() (errs *apis.FieldError) {
}
// Note: we don't enforce private access mode to have image secrets, in case anonymous pulling is enabled
}
if len(i.Adapters) > MaxAdaptersNumber {
errs = errs.Also(apis.ErrGeneric(fmt.Sprintf("Number of Adapters exceeds the maximum limit, maximum of %s allowed", strconv.Itoa(MaxAdaptersNumber))))
}

// check if adapter names are duplicate
if len(i.Adapters) > 0 {
nameMap := make(map[string]bool)
errs = errs.Also(validateDuplicateName(i.Adapters, nameMap))
}

return errs
}

Expand All @@ -358,5 +407,27 @@ func (i *InferenceSpec) validateUpdate(old *InferenceSpec) (errs *apis.FieldErro
errs = errs.Also(apis.ErrGeneric("field cannot be unset/set if it was set/unset", "template"))
}

// check if adapter names are duplicate
for _, adapter := range i.Adapters {
errs = errs.Also(adapter.validateCreateorUpdate())
}

// check if adapter names are duplicate

if len(i.Adapters) > 0 {
nameMap := make(map[string]bool)
errs = errs.Also(validateDuplicateName(i.Adapters, nameMap))
}
return errs
}

func validateDuplicateName(adapters []AdapterSpec, nameMap map[string]bool) (errs *apis.FieldError) {
for _, adapter := range adapters {
if _, ok := nameMap[adapter.Source.Name]; ok {
errs = errs.Also(apis.ErrGeneric(fmt.Sprintf("Duplicate adapter source name found: %s", adapter.Source.Name)))
} else {
nameMap[adapter.Source.Name] = true
}
}
return errs
}
142 changes: 141 additions & 1 deletion api/v1alpha1/workspace_validation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package v1alpha1

import (
"context"
"fmt"
"os"
"reflect"
"sort"
Expand All @@ -24,6 +25,10 @@ import (

const DefaultReleaseNamespace = "kaito-workspace"

var ValidStrength string = "0.5"
var InvalidStrength1 string = "invalid"
var InvalidStrength2 string = "1.5"

var gpuCountRequirement string
var totalGPUMemoryRequirement string
var perGPUMemoryRequirement string
Expand Down Expand Up @@ -474,6 +479,56 @@ func TestInferenceSpecValidateCreate(t *testing.T) {
errContent: "This preset only supports private AccessMode, AccessMode must be private to continue",
expectErrs: true,
},
{
name: "Adapeters more than 10",
inferenceSpec: func() *InferenceSpec {
spec := &InferenceSpec{
Preset: &PresetSpec{
PresetMeta: PresetMeta{
Name: ModelName("test-validation"),
AccessMode: ModelImageAccessModePublic,
},
},
}
for i := 1; i <= 11; i++ {
spec.Adapters = append(spec.Adapters, AdapterSpec{
Source: &DataSource{
Name: fmt.Sprintf("Adapter-%d", i),
Image: fmt.Sprintf("fake.kaito.com/kaito-image:0.0.%d", i),
},
Strength: &ValidStrength,
})
}
return spec
}(),
errContent: "Number of Adapters exceeds the maximum limit, maximum of 10 allowed",
expectErrs: true,
},
{
name: "Adapeters names are duplicated",
inferenceSpec: func() *InferenceSpec {
spec := &InferenceSpec{
Preset: &PresetSpec{
PresetMeta: PresetMeta{
Name: ModelName("test-validation"),
AccessMode: ModelImageAccessModePublic,
},
},
}
for i := 1; i <= 2; i++ {
spec.Adapters = append(spec.Adapters, AdapterSpec{
Source: &DataSource{
Name: "Adapter",
Image: fmt.Sprintf("fake.kaito.com/kaito-image:0.0.%d", i),
},
Strength: &ValidStrength,
})
}
return spec
}(),
errContent: "",
expectErrs: true,
},
{
name: "Valid Preset",
inferenceSpec: &InferenceSpec{
Expand All @@ -484,7 +539,7 @@ func TestInferenceSpecValidateCreate(t *testing.T) {
},
},
},
errContent: "",
errContent: "Duplicate adapter source name found:",
expectErrs: false,
},
}
Expand Down Expand Up @@ -520,6 +575,91 @@ func TestInferenceSpecValidateCreate(t *testing.T) {
}
}

func TestAdapterSpecValidateCreateorUpdate(t *testing.T) {
RegisterValidationTestModels()
tests := []struct {
name string
adapterSpec *AdapterSpec
errContent string // Content expected error to include, if any
expectErrs bool
}{
{
name: "Missing Source",
adapterSpec: &AdapterSpec{
Strength: &ValidStrength,
},
errContent: "Source",
expectErrs: true,
},
{
name: "Missing Source Name",
adapterSpec: &AdapterSpec{
Source: &DataSource{
Image: "fake.kaito.com/kaito-image:0.0.1",
},
Strength: &ValidStrength,
},
errContent: "Name of Adapter field must be specified",
expectErrs: true,
},
{
name: "Invalid Strength, not a number",
adapterSpec: &AdapterSpec{
Source: &DataSource{
Name: "Adapter-1",
Image: "fake.kaito.com/kaito-image:0.0.1",
},
Strength: &InvalidStrength1,
},
errContent: "Invalid strength value for Adapter 'Adapter-1'",
expectErrs: true,
},
{
name: "Invalid Strength, larger than 1",
adapterSpec: &AdapterSpec{
Source: &DataSource{
Name: "Adapter-1",
Image: "fake.kaito.com/kaito-image:0.0.1",
},
Strength: &InvalidStrength2,
},
errContent: "Strength value for Adapter 'Adapter-1' must be between 0 and 1",
expectErrs: true,
},
{
name: "Valid Adapter",
adapterSpec: &AdapterSpec{
Source: &DataSource{
Name: "Adapter-1",
Image: "fake.kaito.com/kaito-image:0.0.1",
},
Strength: &ValidStrength,
},
errContent: "",
expectErrs: false,
},
}

// Run the tests
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
errs := tc.adapterSpec.validateCreateorUpdate()
hasErrs := errs != nil
if hasErrs != tc.expectErrs {
t.Errorf("validateUpdate() errors = %v, expectErrs %v", errs, tc.expectErrs)
}

// If there is an error and errContent is not empty, check that the error contains the expected content.
if hasErrs && tc.errContent != "" {
errMsg := errs.Error()
if !strings.Contains(errMsg, tc.errContent) {
t.Errorf("validateUpdate() error message = %v, expected to contain = %v", errMsg, tc.errContent)
}
}
})
}
}

func TestInferenceSpecValidateUpdate(t *testing.T) {
tests := []struct {
name string
Expand Down
1 change: 1 addition & 0 deletions examples/inference/kaito_workspace_falcon_7b.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ resource:
inference:
preset:
name: "falcon-7b"

18 changes: 18 additions & 0 deletions examples/inference/kaito_workspace_falcon_7b_with_adapters.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
apiVersion: kaito.sh/v1alpha1
kind: Workspace
metadata:
name: workspace-falcon-7b
resource:
instanceType: "Standard_NC12s_v3"
labelSelector:
matchLabels:
apps: falcon-7b
inference:
preset:
name: "falcon-7b"
adapters:
- source:
name: "falcon-7b-adapter"
image: "<YOUR_IMAGE>"
strength: "0.2"

21 changes: 15 additions & 6 deletions pkg/inference/preset-inferences.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@ package inference
import (
"context"
"fmt"
"github.com/azure/kaito/pkg/utils"
"os"
"strconv"

"github.com/azure/kaito/pkg/utils"

kaitov1alpha1 "github.com/azure/kaito/api/v1alpha1"
"github.com/azure/kaito/pkg/model"
"github.com/azure/kaito/pkg/resources"
Expand Down Expand Up @@ -56,13 +57,14 @@ var (
tolerations = []corev1.Toleration{
{
Effect: corev1.TaintEffectNoSchedule,
Operator: corev1.TolerationOpEqual,
Key: resources.GPUString,
Operator: corev1.TolerationOpExists,
Key: resources.CapacityNvidiaGPU,
},
{
Effect: corev1.TaintEffectNoSchedule,
Value: resources.GPUString,
Key: "sku",
Effect: corev1.TaintEffectNoSchedule,
Value: resources.GPUString,
Key: "sku",
Operator: corev1.TolerationOpEqual,
},
}
)
Expand Down Expand Up @@ -127,6 +129,13 @@ func CreatePresetInference(ctx context.Context, workspaceObj *kaitov1alpha1.Work
if shmVolumeMount.Name != "" {
volumeMounts = append(volumeMounts, shmVolumeMount)
}

if len(workspaceObj.Inference.Adapters) > 0 {
adapterVolume, adapterVolumeMount := utils.ConfigAdapterVolume()
volumes = append(volumes, adapterVolume)
volumeMounts = append(volumeMounts, adapterVolumeMount)
}

commands, resourceReq := prepareInferenceParameters(ctx, inferenceObj)
image, imagePullSecrets := GetInferenceImageInfo(ctx, workspaceObj, inferenceObj)

Expand Down
Loading

0 comments on commit 4db58ce

Please sign in to comment.