Skip to content

Commit

Permalink
chore: Refactor to implement model plugin - part 5 (#191)
Browse files Browse the repository at this point in the history
This PR updates the crd validation UT to use the test model.

Co-authored-by: guofei <[email protected]>
  • Loading branch information
Fei-Guo and Fei-Guo committed Dec 15, 2023
1 parent c1151fd commit a0f963e
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 115 deletions.
2 changes: 1 addition & 1 deletion api/v1alpha1/workspace_validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func (r *ResourceSpec) validateCreate(inference InferenceSpec) (errs *apis.Field
}
skuPerGPUMemory := skuConfig.GPUMem / skuConfig.GPUCount
if int64(skuPerGPUMemory) < modelPerGPUMemory.ScaledValue(resource.Giga) {
errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Insufficient GPU memory: Instance type %s provides %d per GPU, but preset %s requires at least %d per GPU", instanceType, skuPerGPUMemory, presetName, modelPerGPUMemory.ScaledValue(resource.Giga)), "instanceType"))
errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Insufficient per GPU memory: Instance type %s provides %d per GPU, but preset %s requires at least %d per GPU", instanceType, skuPerGPUMemory, presetName, modelPerGPUMemory.ScaledValue(resource.Giga)), "instanceType"))
}
if int64(totalGPUMem) < modelTotalGPUMemory.ScaledValue(resource.Giga) {
errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Insufficient total GPU memory: Instance type %s has a total of %d, but preset %s requires at least %d", instanceType, totalGPUMem, presetName, modelTotalGPUMemory.ScaledValue(resource.Giga)), "instanceType"))
Expand Down
214 changes: 100 additions & 114 deletions api/v1alpha1/workspace_validation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,44 +3,84 @@

package v1alpha1

/*
import (
"reflect"
"sort"
"strings"
"testing"

"github.com/azure/kaito/pkg/model"
"github.com/azure/kaito/pkg/utils/plugin"
v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)

var gpuCountRequirement string
var totalGPUMemoryRequirement string
var perGPUMemoryRequirement string

type testModel struct{}

func (*testModel) GetInferenceParameters() *model.PresetInferenceParam {
return &model.PresetInferenceParam{
GPUCountRequirement: gpuCountRequirement,
TotalGPUMemoryRequirement: totalGPUMemoryRequirement,
PerGPUMemoryRequirement: perGPUMemoryRequirement,
}
}
func (*testModel) SupportDistributedInference() bool {
return false
}

func RegisterValidationTestModel() {
var test testModel
plugin.KaitoModelRegister.Register(&plugin.Registration{
Name: "test-validation",
Instance: &test,
})
}

func pointerToInt(i int) *int {
return &i
}

func TestResourceSpecValidateCreate(t *testing.T) {
RegisterValidationTestModel()
tests := []struct {
name string
resourceSpec *ResourceSpec
inferenceSpec *InferenceSpec
errContent string // Content expect error to include, if any
expectErrs bool
name string
resourceSpec *ResourceSpec
modelGPUCount string
modelPerGPUMemory string
modelTotalGPUMemory string
preset bool
errContent string // Content expect error to include, if any
expectErrs bool
}{
{
name: "Valid resource",
resourceSpec: &ResourceSpec{
InstanceType: "Standard_ND96asr_v4",
Count: pointerToInt(1),
},
modelGPUCount: "8",
modelPerGPUMemory: "19Gi",
modelTotalGPUMemory: "152Gi",
preset: true,
errContent: "",
expectErrs: false,
},
{
name: "Insufficient total GPU memory",
resourceSpec: &ResourceSpec{
InstanceType: "Standard_NC6",
Count: pointerToInt(1),
},
inferenceSpec: &InferenceSpec{
Preset: &PresetSpec{
PresetMeta: PresetMeta{
Name: ModelName("falcon-7b"),
},
},
},
errContent: "Insufficient total GPU memory",
expectErrs: true,
modelGPUCount: "1",
modelPerGPUMemory: "0",
modelTotalGPUMemory: "14Gi",
preset: true,
errContent: "Insufficient total GPU memory",
expectErrs: true,
},

{
Expand All @@ -49,15 +89,25 @@ func TestResourceSpecValidateCreate(t *testing.T) {
InstanceType: "Standard_NC24ads_A100_v4",
Count: pointerToInt(1),
},
inferenceSpec: &InferenceSpec{
Preset: &PresetSpec{
PresetMeta: PresetMeta{
Name: ModelName("llama-2-13b-chat"),
},
},
},
errContent: "Insufficient number of GPUs",
expectErrs: true,
modelGPUCount: "2",
modelPerGPUMemory: "15Gi",
modelTotalGPUMemory: "30Gi",
preset: true,
errContent: "Insufficient number of GPUs",
expectErrs: true,
},
{
name: "Insufficient per GPU memory",
resourceSpec: &ResourceSpec{
InstanceType: "Standard_NC6",
Count: pointerToInt(2),
},
modelGPUCount: "1",
modelPerGPUMemory: "15Gi",
modelTotalGPUMemory: "15Gi",
preset: true,
errContent: "Insufficient per GPU memory",
expectErrs: true,
},

{
Expand All @@ -66,13 +116,6 @@ func TestResourceSpecValidateCreate(t *testing.T) {
InstanceType: "Standard_invalid_sku",
Count: pointerToInt(1),
},
inferenceSpec: &InferenceSpec{
Preset: &PresetSpec{
PresetMeta: PresetMeta{
Name: ModelName("llama-2-70b"),
},
},
},
errContent: "Unsupported instance",
expectErrs: true,
},
Expand All @@ -82,59 +125,16 @@ func TestResourceSpecValidateCreate(t *testing.T) {
InstanceType: "Standard_NV12s_v3",
Count: pointerToInt(1),
},
inferenceSpec: &InferenceSpec{
Template: &v1.PodTemplateSpec{}, // Assuming a non-nil TemplateSpec implies it's set
},
preset: false,
errContent: "",
expectErrs: false,
},
{
name: "Invalid Preset",
resourceSpec: &ResourceSpec{
InstanceType: "Standard_NV12s_v3",
Count: pointerToInt(1),
},
inferenceSpec: &InferenceSpec{
Preset: &PresetSpec{
PresetMeta: PresetMeta{
Name: ModelName("invalid-preset"),
},
},
},
errContent: "Unsupported preset",
expectErrs: true,
},
{
name: "Invalid SKU",
resourceSpec: &ResourceSpec{
InstanceType: "Standard_invalid_sku",
Count: pointerToInt(1),
},
inferenceSpec: &InferenceSpec{
Preset: &PresetSpec{
PresetMeta: PresetMeta{
Name: ModelName("llama-2-70b"),
},
},
},
errContent: "Unsupported instance",
expectErrs: true,
},
{
name: "N-Prefix SKU",
resourceSpec: &ResourceSpec{
InstanceType: "Standard_Nsku",
Count: pointerToInt(1),
},
inferenceSpec: &InferenceSpec{
Preset: &PresetSpec{
PresetMeta: PresetMeta{
Name: ModelName("llama-2-7b"),
},
},
},
errContent: "",
expectErrs: false,
},
Expand All @@ -145,21 +145,34 @@ func TestResourceSpecValidateCreate(t *testing.T) {
InstanceType: "Standard_Dsku",
Count: pointerToInt(1),
},
inferenceSpec: &InferenceSpec{
Preset: &PresetSpec{
PresetMeta: PresetMeta{
Name: ModelName("llama-2-7b"),
},
},
},
errContent: "",
expectErrs: false,
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
errs := tc.resourceSpec.validateCreate(*tc.inferenceSpec)
var spec InferenceSpec

if tc.preset {
spec = InferenceSpec{
Preset: &PresetSpec{
PresetMeta: PresetMeta{
Name: ModelName("test-validation"),
},
},
}
} else {
spec = InferenceSpec{
Template: &v1.PodTemplateSpec{}, // Assuming a non-nil TemplateSpec implies it's set
}
}

gpuCountRequirement = tc.modelGPUCount
totalGPUMemoryRequirement = tc.modelTotalGPUMemory
perGPUMemoryRequirement = tc.modelPerGPUMemory

errs := tc.resourceSpec.validateCreate(spec)
hasErrs := errs != nil
if hasErrs != tc.expectErrs {
t.Errorf("validateCreate() errors = %v, expectErrs %v", errs, tc.expectErrs)
Expand Down Expand Up @@ -256,6 +269,7 @@ func TestResourceSpecValidateUpdate(t *testing.T) {
}

func TestInferenceSpecValidateCreate(t *testing.T) {
RegisterValidationTestModel()
tests := []struct {
name string
inferenceSpec *InferenceSpec
Expand Down Expand Up @@ -293,7 +307,7 @@ func TestInferenceSpecValidateCreate(t *testing.T) {
inferenceSpec: &InferenceSpec{
Preset: &PresetSpec{
PresetMeta: PresetMeta{
Name: ModelName("falcon-7b"),
Name: ModelName("test-validation"),
},
},
Template: &v1.PodTemplateSpec{}, // Assuming a non-nil TemplateSpec implies it's set
Expand All @@ -306,7 +320,7 @@ func TestInferenceSpecValidateCreate(t *testing.T) {
inferenceSpec: &InferenceSpec{
Preset: &PresetSpec{
PresetMeta: PresetMeta{
Name: ModelName("llama-2-7b"),
Name: ModelName("test-validation"),
AccessMode: "private",
},
PresetOptions: PresetOptions{},
Expand All @@ -320,7 +334,7 @@ func TestInferenceSpecValidateCreate(t *testing.T) {
inferenceSpec: &InferenceSpec{
Preset: &PresetSpec{
PresetMeta: PresetMeta{
Name: ModelName("falcon-7b"),
Name: ModelName("test-validation"),
AccessMode: "public",
},
},
Expand Down Expand Up @@ -475,31 +489,3 @@ func TestGetSupportedSKUs(t *testing.T) {
})
}
}
func TestIsValidPreset(t *testing.T) {
tests := []struct {
name string
preset string
expectValid bool
}{
{
name: "valid preset",
preset: "falcon-7b",
expectValid: true,
},
{
name: "invalid preset",
preset: "nonexistent-preset",
expectValid: false,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
if valid := isValidPreset(test.preset); valid != test.expectValid {
t.Errorf("isValidPreset(%s) = %v, want %v", test.preset, valid, test.expectValid)
}
})
}
}
*/

0 comments on commit a0f963e

Please sign in to comment.