Skip to content

Commit

Permalink
test: add ut for CreatePresetInference (#151)
Browse files Browse the repository at this point in the history
Brutally checking the final command line for happy cases only.

Co-authored-by: guofei <[email protected]>
  • Loading branch information
Fei-Guo and Fei-Guo committed Nov 8, 2023
1 parent 87a4f10 commit 34c730c
Show file tree
Hide file tree
Showing 2 changed files with 212 additions and 1 deletion.
209 changes: 209 additions & 0 deletions pkg/inference/preset-inferences_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

package inference

import (
"context"
"reflect"
"strings"
"testing"

"github.com/azure/kaito/pkg/utils"
"github.com/stretchr/testify/mock"
appsv1 "k8s.io/api/apps/v1"
corev1 "k8s.io/api/core/v1"
v1 "k8s.io/apimachinery/pkg/apis/meta/v1"

kaitov1alpha1 "github.com/azure/kaito/api/v1alpha1"
)

// TestCreatePresetInference checks, for every supported preset model, that
// CreatePresetInference creates the expected workload kind (a Deployment for
// falcon models, a StatefulSet for llama models) and assembles the expected
// container command line. Only happy paths are covered: every mocked client
// call succeeds.
//
// The command line is compared in two parts: the leading command (everything
// before the first "--") must match exactly, while the "--key=value"
// parameters are compared as a set so that their ordering is irrelevant.
func TestCreatePresetInference(t *testing.T) {

	testcases := map[string]struct {
		nodeCount   int                       // workspace node count (drives torchrun --nnodes for llama)
		modelName   string                    // preset key into the *PresetInferences maps
		callMocks   func(c *utils.MockClient) // per-case mock expectations
		workload    string                    // expected created object kind: "Deployment" or "StatefulSet"
		expectedCmd string                    // expected container command line
	}{

		"falcon-7b": {
			nodeCount: 1,
			modelName: "falcon-7b",
			callMocks: func(c *utils.MockClient) {
				c.On("Create", mock.IsType(context.Background()), mock.IsType(&appsv1.Deployment{}), mock.Anything).Return(nil)
			},
			workload:    "Deployment",
			expectedCmd: "/bin/sh -c accelerate launch --use_deepspeed --config_file=config.yaml --num_processes=1 --num_machines=1 --machine_rank=0 --gpu_ids=all inference-api.py",
		},
		"falcon-7b-instruct": {
			nodeCount: 1,
			modelName: "falcon-7b-instruct",
			callMocks: func(c *utils.MockClient) {
				c.On("Create", mock.IsType(context.Background()), mock.IsType(&appsv1.Deployment{}), mock.Anything).Return(nil)
			},
			workload:    "Deployment",
			expectedCmd: "/bin/sh -c accelerate launch --use_deepspeed --config_file=config.yaml --num_processes=1 --num_machines=1 --machine_rank=0 --gpu_ids=all inference-api.py",
		},
		"falcon-40b": {
			nodeCount: 1,
			modelName: "falcon-40b",
			callMocks: func(c *utils.MockClient) {
				c.On("Create", mock.IsType(context.Background()), mock.IsType(&appsv1.Deployment{}), mock.Anything).Return(nil)
			},
			workload:    "Deployment",
			expectedCmd: "/bin/sh -c accelerate launch --use_deepspeed --num_machines=1 --machine_rank=0 --gpu_ids=all --config_file=config.yaml --num_processes=1 inference-api.py",
		},
		"falcon-40b-instruct": {
			nodeCount: 1,
			modelName: "falcon-40b-instruct",
			callMocks: func(c *utils.MockClient) {
				c.On("Create", mock.IsType(context.Background()), mock.IsType(&appsv1.Deployment{}), mock.Anything).Return(nil)
			},
			workload:    "Deployment",
			expectedCmd: "/bin/sh -c accelerate launch --use_deepspeed --config_file=config.yaml --num_processes=1 --num_machines=1 --machine_rank=0 --gpu_ids=all inference-api.py",
		},

		"llama-7b-chat": {
			nodeCount: 1,
			modelName: "llama-2-7b-chat",
			callMocks: func(c *utils.MockClient) {
				c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&corev1.Service{}), mock.Anything).Return(nil)
				c.On("Create", mock.IsType(context.Background()), mock.IsType(&appsv1.StatefulSet{}), mock.Anything).Return(nil)
			},
			workload:    "StatefulSet",
			expectedCmd: "/bin/sh -c cd /workspace/llama/llama-2 && torchrun --nnodes=1 --nproc_per_node=1 --node_rank=$(echo $HOSTNAME | grep -o '[^-]*$') --master_addr=10.0.0.1 --master_port=29500 --max_restarts=3 --rdzv_id=job --rdzv_backend=c10d --rdzv_endpoint=testWorkspace-0.testWorkspace-headless.default.svc.cluster.local:29500 inference-api.py --max_seq_len=512 --max_batch_size=8",
		},
		"llama-13b-chat": {
			nodeCount: 1,
			modelName: "llama-2-13b-chat",
			callMocks: func(c *utils.MockClient) {
				c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&corev1.Service{}), mock.Anything).Return(nil)
				c.On("Create", mock.IsType(context.Background()), mock.IsType(&appsv1.StatefulSet{}), mock.Anything).Return(nil)
			},
			workload:    "StatefulSet",
			expectedCmd: "/bin/sh -c cd /workspace/llama/llama-2 && torchrun --nnodes=1 --nproc_per_node=2 --node_rank=$(echo $HOSTNAME | grep -o '[^-]*$') --master_addr=10.0.0.1 --master_port=29500 --max_restarts=3 --rdzv_id=job --rdzv_backend=c10d --rdzv_endpoint=testWorkspace-0.testWorkspace-headless.default.svc.cluster.local:29500 inference-api.py --max_seq_len=512 --max_batch_size=8",
		},
		"llama-70b-chat": {
			nodeCount: 2,
			modelName: "llama-2-70b-chat",
			callMocks: func(c *utils.MockClient) {
				c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&corev1.Service{}), mock.Anything).Return(nil)
				c.On("Create", mock.IsType(context.Background()), mock.IsType(&appsv1.StatefulSet{}), mock.Anything).Return(nil)
			},
			workload:    "StatefulSet",
			expectedCmd: "/bin/sh -c cd /workspace/llama/llama-2 && torchrun --nproc_per_node=4 --node_rank=$(echo $HOSTNAME | grep -o '[^-]*$') --master_addr=10.0.0.1 --master_port=29500 --nnodes=2 --rdzv_backend=c10d --rdzv_endpoint=testWorkspace-0.testWorkspace-headless.default.svc.cluster.local:29500 --max_restarts=3 --rdzv_id=job inference-api.py --max_seq_len=512 --max_batch_size=8",
		},
		"llama-7b": {
			nodeCount: 1,
			modelName: "llama-2-7b",
			callMocks: func(c *utils.MockClient) {
				c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&corev1.Service{}), mock.Anything).Return(nil)
				c.On("Create", mock.IsType(context.Background()), mock.IsType(&appsv1.StatefulSet{}), mock.Anything).Return(nil)
			},
			workload:    "StatefulSet",
			expectedCmd: "/bin/sh -c cd /workspace/llama/llama-2 && torchrun --nnodes=1 --nproc_per_node=1 --node_rank=$(echo $HOSTNAME | grep -o '[^-]*$') --master_addr=10.0.0.1 --master_port=29500 --max_restarts=3 --rdzv_id=job --rdzv_backend=c10d --rdzv_endpoint=testWorkspace-0.testWorkspace-headless.default.svc.cluster.local:29500 inference-api.py --max_seq_len=512 --max_batch_size=8",
		},
		"llama-13b": {
			nodeCount: 1,
			modelName: "llama-2-13b",
			callMocks: func(c *utils.MockClient) {
				c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&corev1.Service{}), mock.Anything).Return(nil)
				c.On("Create", mock.IsType(context.Background()), mock.IsType(&appsv1.StatefulSet{}), mock.Anything).Return(nil)
			},
			workload:    "StatefulSet",
			expectedCmd: "/bin/sh -c cd /workspace/llama/llama-2 && torchrun --node_rank=$(echo $HOSTNAME | grep -o '[^-]*$') --master_addr=10.0.0.1 --master_port=29500 --nnodes=1 --nproc_per_node=2 --max_restarts=3 --rdzv_id=job --rdzv_backend=c10d --rdzv_endpoint=testWorkspace-0.testWorkspace-headless.default.svc.cluster.local:29500 inference-api.py --max_batch_size=8 --max_seq_len=512",
		},
		"llama-70b": {
			nodeCount: 2,
			modelName: "llama-2-70b",
			callMocks: func(c *utils.MockClient) {
				c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&corev1.Service{}), mock.Anything).Return(nil)
				c.On("Create", mock.IsType(context.Background()), mock.IsType(&appsv1.StatefulSet{}), mock.Anything).Return(nil)
			},
			workload:    "StatefulSet",
			expectedCmd: "/bin/sh -c cd /workspace/llama/llama-2 && torchrun --nproc_per_node=4 --node_rank=$(echo $HOSTNAME | grep -o '[^-]*$') --master_addr=10.0.0.1 --master_port=29500 --nnodes=2 --rdzv_backend=c10d --rdzv_endpoint=testWorkspace-0.testWorkspace-headless.default.svc.cluster.local:29500 --max_restarts=3 --rdzv_id=job inference-api.py --max_seq_len=512 --max_batch_size=8",
		},
	}

	for k, tc := range testcases {
		t.Run(k, func(t *testing.T) {
			mockClient := utils.NewClient()
			tc.callMocks(mockClient)

			workspace := utils.MockWorkspace
			workspace.Resource.Count = &tc.nodeCount

			useHeadlessSvc := false
			var inferenceObj PresetInferenceParam
			if strings.HasPrefix(tc.modelName, "llama") {
				inferenceObj = Llama2PresetInferences[kaitov1alpha1.ModelName(tc.modelName)]
				useHeadlessSvc = true
			} else {
				inferenceObj = FalconPresetInferences[kaitov1alpha1.ModelName(tc.modelName)]
			}

			// Seed the mock client's object map so the headless-service
			// lookup performed for llama models resolves to a cluster IP.
			svc := &corev1.Service{
				ObjectMeta: v1.ObjectMeta{
					Name:      workspace.Name,
					Namespace: workspace.Namespace,
				},
				Spec: corev1.ServiceSpec{
					ClusterIP: "10.0.0.1",
				},
			}
			mockClient.CreateOrUpdateObjectInMap(svc)

			// FIX: the error was previously discarded with `_`; a failing
			// create would have gone unnoticed and later assertions would
			// have panicked on a nil createdObject.
			createdObject, err := CreatePresetInference(context.TODO(), workspace, inferenceObj, useHeadlessSvc, mockClient)
			if err != nil {
				t.Fatalf("%s: CreatePresetInference returned unexpected error: %v", k, err)
			}

			createdWorkload := ""
			switch createdObject.(type) {
			case *appsv1.Deployment:
				createdWorkload = "Deployment"
			case *appsv1.StatefulSet:
				createdWorkload = "StatefulSet"
			}
			// Fail fast on a kind mismatch: the type assertions below would
			// panic if we continued with the wrong workload type.
			if tc.workload != createdWorkload {
				t.Fatalf("%s: returned workload type is wrong, got %s, expect %s", k, createdWorkload, tc.workload)
			}

			var workloadCmd string
			if tc.workload == "Deployment" {
				workloadCmd = strings.Join(createdObject.(*appsv1.Deployment).Spec.Template.Spec.Containers[0].Command, " ")
			} else {
				workloadCmd = strings.Join(createdObject.(*appsv1.StatefulSet).Spec.Template.Spec.Containers[0].Command, " ")
			}

			mainCmd := strings.Split(workloadCmd, "--")[0]
			params := toParameterMap(strings.Split(workloadCmd, "--")[1:])

			expectedMainCmd := strings.Split(tc.expectedCmd, "--")[0]
			// BUG FIX: the expected parameters must come from tc.expectedCmd.
			// The original built them from workloadCmd, so the DeepEqual
			// below compared the actual output with itself and could never fail.
			expectedParams := toParameterMap(strings.Split(tc.expectedCmd, "--")[1:])

			if mainCmd != expectedMainCmd {
				t.Errorf("%s main cmdline is not expected, got %s, expect %s", k, workloadCmd, tc.expectedCmd)
			}

			if !reflect.DeepEqual(params, expectedParams) {
				t.Errorf("%s parameters are not expected, got %s, expect %s", k, params, expectedParams)
			}
		})
	}
}

// toParameterMap converts a slice of "key=value" tokens into a map.
// A token with no '=' maps to the empty string. Only the first '=' is
// treated as the separator, so values may themselves contain '='
// (e.g. "opt=a=b" yields ret["opt"] == "a=b"); the previous
// implementation truncated such values at the second '='.
func toParameterMap(in []string) map[string]string {
	ret := make(map[string]string, len(in))
	for _, each := range in {
		// Cut splits on the first '=' only and returns found=false
		// (with an empty value) when there is no separator at all.
		k, v, _ := strings.Cut(each, "=")
		ret[k] = v
	}
	return ret
}
4 changes: 3 additions & 1 deletion pkg/utils/mockClient.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@ func (m *MockClient) GetObjectFromMap(obj k8sClient.Object, key types.Namespaced
// k8s Client interface
func (m *MockClient) Get(ctx context.Context, key types.NamespacedName, obj k8sClient.Object, opts ...k8sClient.GetOption) error {
//make any necessary changes to the object
m.UpdateCb(key)
if m.UpdateCb != nil {
m.UpdateCb(key)
}

m.GetObjectFromMap(obj, key)

Expand Down

0 comments on commit 34c730c

Please sign in to comment.