Skip to content

Commit

Permalink
fix: filepath for inference file (#104)
Browse files Browse the repository at this point in the history
Small fix this filepath was standardized in the image and now needs to
be reflected in KAITO
  • Loading branch information
ishaansehgal99 committed Oct 25, 2023
1 parent b1a1dc1 commit e81db11
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions pkg/inference/preset-inference-types.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,7 @@ var (
presetFalcon40bImage = registryName + fmt.Sprintf("/%s:latest", kaitov1alpha1.PresetFalcon40BModel)
presetFalcon40bInstructImage = registryName + fmt.Sprintf("/%s:latest", kaitov1alpha1.PresetFalcon40BInstructModel)

baseCommandPresetLlama2AChat = fmt.Sprintf("cd /workspace/llama/%s && torchrun", kaitov1alpha1.PresetLlama2AChat)
baseCommandPresetLlama2BChat = fmt.Sprintf("cd /workspace/llama/%s && torchrun", kaitov1alpha1.PresetLlama2BChat)
baseCommandPresetLlama2CChat = fmt.Sprintf("cd /workspace/llama/%s && torchrun", kaitov1alpha1.PresetLlama2CChat)
baseCommandPresetLlama = "cd /workspace/llama/llama-2 && torchrun"
// llamaTextInferenceFile = "inference-api.py" TODO: To support Text Generation Llama Models
llamaChatInferenceFile = "inference-api.py"
llamaRunParams = map[string]string{
Expand Down Expand Up @@ -105,7 +103,7 @@ var (
ModelRunParams: llamaRunParams,
InferenceFile: llamaChatInferenceFile,
DeploymentTimeout: time.Duration(10) * time.Minute,
BaseCommand: baseCommandPresetLlama2AChat,
BaseCommand: baseCommandPresetLlama,
WorldSize: 1,
DefaultVolumeMountPath: "/dev/shm",
},
Expand All @@ -119,7 +117,7 @@ var (
ModelRunParams: llamaRunParams,
InferenceFile: llamaChatInferenceFile,
DeploymentTimeout: time.Duration(20) * time.Minute,
BaseCommand: baseCommandPresetLlama2BChat,
BaseCommand: baseCommandPresetLlama,
WorldSize: 2,
DefaultVolumeMountPath: "/dev/shm",
},
Expand All @@ -133,7 +131,7 @@ var (
ModelRunParams: llamaRunParams,
InferenceFile: llamaChatInferenceFile,
DeploymentTimeout: time.Duration(30) * time.Minute,
BaseCommand: baseCommandPresetLlama2CChat,
BaseCommand: baseCommandPresetLlama,
WorldSize: 8,
DefaultVolumeMountPath: "/dev/shm",
},
Expand Down

0 comments on commit e81db11

Please sign in to comment.