chore: Refactor to implement model plugin - part 3 (#189)
This change activates all model plugins.

- Remove all hard-coded model names in workspace_controller.go by using the plugin model register (a sketch of the pattern follows this list).
- Fix the related unit tests.
- Change the model interface for better abstraction.
- Switch the llama2-7b models to use a Deployment.
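
A minimal sketch of the registry pattern these changes imply, inferred from the call sites in the diff below (plugin.KaitoModelRegister.MustGet(name).GetInferenceParameters() and model.SupportDistributedInference()). The struct fields, method signatures, and locking are assumptions for illustration, not kaito's actual API:

package plugin

import (
	"fmt"
	"sync"
)

// PresetInferenceParam stands in for the inference.PresetInferenceParam
// referenced in the diff; only fields the controller touches are shown.
type PresetInferenceParam struct {
	ModelName              string
	DiskStorageRequirement string
}

// Model is an assumed interface shape: the controller calls exactly these
// two methods on whatever MustGet returns.
type Model interface {
	GetInferenceParameters() *PresetInferenceParam
	SupportDistributedInference() bool
}

// ModelRegister maps preset names to registered model plugins.
type ModelRegister struct {
	mu     sync.RWMutex
	models map[string]Model
}

// KaitoModelRegister is the package-level registry the controller reads.
var KaitoModelRegister = ModelRegister{models: map[string]Model{}}

// Register adds a model under its preset name; a model package would call
// this from init() so that importing the package is enough to activate it.
func (r *ModelRegister) Register(name string, m Model) {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.models[name] = m
}

// MustGet returns the registered model or panics; the Must prefix matches
// treating an unknown preset as a programming error rather than a
// recoverable one (the old per-model switch returned an error instead).
func (r *ModelRegister) MustGet(name string) Model {
	r.mu.RLock()
	defer r.mu.RUnlock()
	m, ok := r.models[name]
	if !ok {
		panic(fmt.Sprintf("preset model %s is not registered", name))
	}
	return m
}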

---------

Co-authored-by: guofei <[email protected]>
Fei-Guo committed Dec 13, 2023
1 parent 2d0c593 commit e0f3a35
Showing 15 changed files with 162 additions and 256 deletions.
6 changes: 3 additions & 3 deletions Makefile
@@ -1,6 +1,6 @@

 # Image URL to use all building/pushing image targets
-REGISTRY ?= mcr.microsoft.com/aks/kaito
+REGISTRY ?= YOUR_REGISTRY
 IMG_NAME ?= workspace
 VERSION ?= v0.1.0
 IMG_TAG ?= $(subst v,,$(VERSION))
@@ -132,7 +132,7 @@ az-patch-install-helm: ## Update Azure client env vars and settings in helm valu

 .PHONY: build
 build: manifests generate fmt vet ## Build manager binary.
-	go build -o bin/manager cmd/main.go
+	go build -o bin/manager cmd/*.go

 .PHONY: run
 run: manifests generate fmt vet ## Run a controller from your host.
@@ -260,4 +260,4 @@ release-manifest:

 .PHONY: clean
 clean:
-	@rm -rf $(BIN_DIR)
\ No newline at end of file
+	@rm -rf $(BIN_DIR)
5 changes: 3 additions & 2 deletions docker/kaito/Dockerfile
@@ -16,9 +16,10 @@ RUN \
 	go mod download

 # Copy the go source
-COPY cmd/main.go cmd/main.go
+COPY cmd/ cmd/
 COPY api/ api/
 COPY pkg/ pkg/
+COPY presets/ presets/

 # Build
 # the GOARCH has not a default value to allow the binary be built according to the host where the command
@@ -27,7 +28,7 @@ COPY pkg/ pkg/
 # by leaving it empty we can ensure that the container and binary shipped on it will have the same platform.
 RUN --mount=type=cache,target=${GOCACHE} \
     --mount=type=cache,id=kaito-controller,sharing=locked,target=/go/pkg/mod \
-    CGO_ENABLED=0 GOOS=${TARGETOS:-linux} GOARCH=${TARGETARCH} GO111MODULE=on go build -a -o manager cmd/main.go
+    CGO_ENABLED=0 GOOS=${TARGETOS:-linux} GOARCH=${TARGETARCH} GO111MODULE=on go build -a -o manager cmd/*.go

 # Use distroless as minimal base image to package the manager binary
 # Refer to https://github.com/GoogleContainerTools/distroless for more details
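
These build changes fit the standard Go pattern for activating plugins through side-effect imports: each preset model package registers itself in an init() func, so the build now compiles every file under cmd/ (one of which would blank-import the preset packages) and the image now needs the presets/ sources. A hypothetical illustration, consistent with the registry sketch above; the file paths, package name, preset name, and parameter values are invented for the example:

// presets/models/falcon/model.go (hypothetical path)
package falcon

import "github.com/azure/kaito/pkg/utils/plugin"

type falcon7B struct{}

// The controller reads disk size and other knobs from these parameters.
func (falcon7B) GetInferenceParameters() *plugin.PresetInferenceParam {
	return &plugin.PresetInferenceParam{
		ModelName:              "falcon",
		DiskStorageRequirement: "50Gi", // illustrative value
	}
}

// A single-node model runs as a Deployment and needs no headless service.
func (falcon7B) SupportDistributedInference() bool { return false }

func init() {
	plugin.KaitoModelRegister.Register("falcon-7b", falcon7B{})
}

// cmd/models.go (hypothetical): blank imports pull in each preset package
// so its init() runs; this is why the build targets changed from
// cmd/main.go to cmd/*.go.
package main

import (
	_ "github.com/azure/kaito/presets/models/falcon"
)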
106 changes: 34 additions & 72 deletions pkg/controllers/workspace_controller.go
@@ -18,6 +18,7 @@ import (
"github.com/azure/kaito/pkg/machine"
"github.com/azure/kaito/pkg/resources"
"github.com/azure/kaito/pkg/utils"
"github.com/azure/kaito/pkg/utils/plugin"
"github.com/go-logr/logr"
"github.com/samber/lo"
corev1 "k8s.io/api/core/v1"
@@ -92,12 +93,8 @@ func (c *WorkspaceReconciler) addOrUpdateWorkspace(ctx context.Context, wObj *ka
 		}
 		return reconcile.Result{}, err
 	}
-	useHeadlessService := false
-	if wObj.Inference.Preset != nil && strings.Contains(string(wObj.Inference.Preset.Name), "llama") {
-		useHeadlessService = true
-	}

-	if err := c.ensureService(ctx, wObj, useHeadlessService); err != nil {
+	if err := c.ensureService(ctx, wObj); err != nil {
 		if updateErr := c.updateStatusConditionIfNotMatch(ctx, wObj, kaitov1alpha1.WorkspaceConditionTypeReady, metav1.ConditionFalse,
 			"workspaceFailed", err.Error()); updateErr != nil {
 			klog.ErrorS(updateErr, "failed to update workspace status", "workspace", klog.KObj(wObj))
@@ -106,7 +103,7 @@ func (c *WorkspaceReconciler) addOrUpdateWorkspace(ctx context.Context, wObj *ka
 		return reconcile.Result{}, err
 	}

-	if err = c.applyInference(ctx, wObj, useHeadlessService); err != nil {
+	if err = c.applyInference(ctx, wObj); err != nil {
 		if updateErr := c.updateStatusConditionIfNotMatch(ctx, wObj, kaitov1alpha1.WorkspaceConditionTypeReady, metav1.ConditionFalse,
 			"workspaceFailed", err.Error()); updateErr != nil {
 			klog.ErrorS(updateErr, "failed to update workspace status", "workspace", klog.KObj(wObj))
@@ -301,15 +298,8 @@ func (c *WorkspaceReconciler) validateNodeInstanceType(ctx context.Context, wObj
 func (c *WorkspaceReconciler) createAndValidateNode(ctx context.Context, wObj *kaitov1alpha1.Workspace) (*corev1.Node, error) {
 	var machineOSDiskSize string
 	if wObj.Inference.Preset != nil && wObj.Inference.Preset.Name != "" {
-		presetName := wObj.Inference.Preset.Name
-		if _, exists := inference.Llama2PresetInferences[presetName]; exists {
-			machineOSDiskSize = inference.Llama2PresetInferences[presetName].DiskStorageRequirement
-		} else if _, exists := inference.FalconPresetInferences[presetName]; exists {
-			machineOSDiskSize = inference.FalconPresetInferences[presetName].DiskStorageRequirement
-		} else {
-			err := fmt.Errorf("preset model %s is not supported", presetName)
-			return nil, err
-		}
+		presetName := string(wObj.Inference.Preset.Name)
+		machineOSDiskSize = plugin.KaitoModelRegister.MustGet(presetName).GetInferenceParameters().DiskStorageRequirement
 	}
 	if machineOSDiskSize == "" {
 		machineOSDiskSize = "0" // The default OS size is used
@@ -382,7 +372,7 @@ func (c *WorkspaceReconciler) ensureNodePlugins(ctx context.Context, wObj *kaito
 	}
 }

-func (c *WorkspaceReconciler) ensureService(ctx context.Context, wObj *kaitov1alpha1.Workspace, useHeadlessService bool) error {
+func (c *WorkspaceReconciler) ensureService(ctx context.Context, wObj *kaitov1alpha1.Workspace) error {
 	serviceType := corev1.ServiceTypeClusterIP
 	wAnnotation := wObj.GetAnnotations()

@@ -402,69 +392,41 @@ func (c *WorkspaceReconciler) ensureService(ctx context.Context, wObj *kaitov1al
 	} else {
 		return nil
 	}
-	var isStatefulSet bool
-
-	if wObj.Inference.Preset != nil {
-		isStatefulSet = strings.Contains(string(wObj.Inference.Preset.Name), "llama")
-	}
-	serviceObj := resources.GenerateServiceManifest(ctx, wObj, serviceType, isStatefulSet)
-	err = resources.CreateResource(ctx, serviceObj, c.Client)
-	if err != nil {
-		return err
-	}
-	if useHeadlessService {
-		headlessService := resources.GenerateHeadlessServiceManifest(ctx, wObj)
-		err = resources.CreateResource(ctx, headlessService, c.Client)
-		if err != nil {
-			return err
-		}
-	}
+	if wObj.Inference.Preset != nil {
+		presetName := string(wObj.Inference.Preset.Name)
+		model := plugin.KaitoModelRegister.MustGet(presetName)
+		serviceObj := resources.GenerateServiceManifest(ctx, wObj, serviceType, model.SupportDistributedInference())
+		err = resources.CreateResource(ctx, serviceObj, c.Client)
+		if err != nil {
+			return err
+		}
+		if model.SupportDistributedInference() {
+			headlessService := resources.GenerateHeadlessServiceManifest(ctx, wObj)
+			err = resources.CreateResource(ctx, headlessService, c.Client)
+			if err != nil {
+				return err
+			}
+		}
+	}
 	return nil
 }

-func (c *WorkspaceReconciler) getInferenceObjFromPreset(ctx context.Context, wObj *kaitov1alpha1.Workspace) (inference.PresetInferenceParam, error) {
-	presetName := wObj.Inference.Preset.Name
-	var inferenceObj inference.PresetInferenceParam
-	switch presetName {
-	case kaitov1alpha1.PresetLlama2AModel:
-		inferenceObj = inference.Llama2PresetInferences[kaitov1alpha1.PresetLlama2AModel]
-	case kaitov1alpha1.PresetLlama2BModel:
-		inferenceObj = inference.Llama2PresetInferences[kaitov1alpha1.PresetLlama2BModel]
-	case kaitov1alpha1.PresetLlama2CModel:
-		inferenceObj = inference.Llama2PresetInferences[kaitov1alpha1.PresetLlama2CModel]
-	case kaitov1alpha1.PresetLlama2AChat:
-		inferenceObj = inference.Llama2PresetInferences[kaitov1alpha1.PresetLlama2AChat]
-	case kaitov1alpha1.PresetLlama2BChat:
-		inferenceObj = inference.Llama2PresetInferences[kaitov1alpha1.PresetLlama2BChat]
-	case kaitov1alpha1.PresetLlama2CChat:
-		inferenceObj = inference.Llama2PresetInferences[kaitov1alpha1.PresetLlama2CChat]
-	case kaitov1alpha1.PresetFalcon7BModel:
-		inferenceObj = inference.FalconPresetInferences[kaitov1alpha1.PresetFalcon7BModel]
-	case kaitov1alpha1.PresetFalcon7BInstructModel:
-		inferenceObj = inference.FalconPresetInferences[kaitov1alpha1.PresetFalcon7BInstructModel]
-	case kaitov1alpha1.PresetFalcon40BModel:
-		inferenceObj = inference.FalconPresetInferences[kaitov1alpha1.PresetFalcon40BModel]
-	case kaitov1alpha1.PresetFalcon40BInstructModel:
-		inferenceObj = inference.FalconPresetInferences[kaitov1alpha1.PresetFalcon40BInstructModel]
-	default:
-		err := fmt.Errorf("preset model %s is not supported", presetName)
-		return inference.PresetInferenceParam{}, err
-	}
-
-	inferenceObj.AccessMode = string(wObj.Inference.Preset.PresetMeta.AccessMode)
-	if inferenceObj.AccessMode == "private" && wObj.Inference.Preset.PresetOptions.Image != "" {
-		inferenceObj.Image = wObj.Inference.Preset.PresetOptions.Image
+func (c *WorkspaceReconciler) updateInferenceParamFromWorkspace(ctx context.Context, wObj *kaitov1alpha1.Workspace, inferenceParam *inference.PresetInferenceParam) {
+	inferenceParam.AccessMode = string(wObj.Inference.Preset.PresetMeta.AccessMode)
+	if inferenceParam.AccessMode == "private" && wObj.Inference.Preset.PresetOptions.Image != "" {
+		inferenceParam.Image = wObj.Inference.Preset.PresetOptions.Image

 		imagePullSecretRefs := []corev1.LocalObjectReference{}
 		for _, secretName := range wObj.Inference.Preset.PresetOptions.ImagePullSecrets {
 			imagePullSecretRefs = append(imagePullSecretRefs, corev1.LocalObjectReference{Name: secretName})
 		}
-		inferenceObj.ImagePullSecrets = imagePullSecretRefs
+		inferenceParam.ImagePullSecrets = imagePullSecretRefs
 	}
-	return inferenceObj, nil
 }

 // applyInference applies inference spec.
-func (c *WorkspaceReconciler) applyInference(ctx context.Context, wObj *kaitov1alpha1.Workspace, useHeadlessService bool) error {
+func (c *WorkspaceReconciler) applyInference(ctx context.Context, wObj *kaitov1alpha1.Workspace) error {
 	var err error
 	func() {
 		if wObj.Inference.Template != nil {
@@ -478,16 +440,16 @@ func (c *WorkspaceReconciler) applyInference(ctx context.Context, wObj *kaitov1a
 				return
 			}
 		} else if wObj.Inference.Preset != nil {
+			presetName := string(wObj.Inference.Preset.Name)
+			model := plugin.KaitoModelRegister.MustGet(presetName)
+
+			inferenceParam := model.GetInferenceParameters()
+
+			c.updateInferenceParamFromWorkspace(ctx, wObj, inferenceParam)
 			// TODO: we only do create if it does not exist for preset model. Need to document it.
-			var inferenceObj inference.PresetInferenceParam
-			inferenceObj, err = c.getInferenceObjFromPreset(ctx, wObj)
-			if err != nil {
-				klog.ErrorS(err, "unable to retrieve inference object from preset", "workspace", klog.KObj(wObj))
-				return
-			}

 			var existingObj client.Object
-			if inferenceObj.ModelName == "LLaMa2" {
+			if model.SupportDistributedInference() {
 				existingObj = &appsv1.StatefulSet{}
 			} else {
 				existingObj = &appsv1.Deployment{}
 			}

 			if err = resources.GetResource(ctx, wObj.Name, wObj.Namespace, c.Client, existingObj); err == nil {
 				klog.InfoS("An inference workload already exists for workspace", "workspace", klog.KObj(wObj))
-				if err = resources.CheckResourceStatus(existingObj, c.Client, inferenceObj.DeploymentTimeout); err != nil {
+				if err = resources.CheckResourceStatus(existingObj, c.Client, inferenceParam.DeploymentTimeout); err != nil {
 					return
 				}
 			} else if apierrors.IsNotFound(err) {
 				var workloadObj client.Object
 				// Need to create a new workload
-				workloadObj, err = inference.CreatePresetInference(ctx, wObj, inferenceObj, useHeadlessService, c.Client)
+				workloadObj, err = inference.CreatePresetInference(ctx, wObj, inferenceParam, model.SupportDistributedInference(), c.Client)
 				if err != nil {
 					return
 				}
-				if err = resources.CheckResourceStatus(workloadObj, c.Client, inferenceObj.DeploymentTimeout); err != nil {
+				if err = resources.CheckResourceStatus(workloadObj, c.Client, inferenceParam.DeploymentTimeout); err != nil {
 					return
 				}
 			}
(The remaining 12 changed files are not shown.)
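
A design note on the SupportDistributedInference() branch above: distributed (llama) inference runs as a StatefulSet, and StatefulSet pods only get stable per-pod DNS names through a headless service, which is why the controller creates one in that case. A sketch of what such a manifest typically looks like with client-go types; the name suffix, selector label, and port are placeholders, not kaito's actual values:

package resources

import (
	corev1 "k8s.io/api/core/v1"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)

// headlessServiceFor sketches a headless Service for a workspace's
// StatefulSet pods: ClusterIP "None" means no virtual IP, so DNS
// resolves directly to the individual pod addresses.
func headlessServiceFor(name, namespace string) *corev1.Service {
	return &corev1.Service{
		ObjectMeta: metav1.ObjectMeta{
			Name:      name + "-headless", // naming is a guess
			Namespace: namespace,
		},
		Spec: corev1.ServiceSpec{
			ClusterIP: corev1.ClusterIPNone,
			Selector:  map[string]string{"kaito.sh/workspace": name}, // placeholder label
			Ports: []corev1.ServicePort{
				{Name: "torchrun", Port: 29500}, // placeholder inter-node port
			},
		},
	}
}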
