Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Add accelerator_count for VertexAICustomTrainingJob (#174)
Browse files Browse the repository at this point in the history
* add accelerator count

* update changelog

* added test
  • Loading branch information
jeremy-thomas-roc committed May 2, 2023
1 parent a4bfa72 commit 461aef7
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## Unreleased

### Added
- Added optional `accelerator_count` property for `VertexAICustomTrainingJob`.

### Changed

Expand Down
7 changes: 6 additions & 1 deletion prefect_gcp/aiplatform.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,9 @@ class VertexAICustomTrainingJob(Infrastructure):
accelerator_type: Optional[str] = Field(
default=None, description="The type of accelerator to attach to the machine."
)
accelerator_count: Optional[int] = Field(
default=None, description="The number of accelerators to attach to the machine."
)
maximum_run_time: datetime.timedelta = Field(
default=datetime.timedelta(days=7), description="The maximum job running time."
)
Expand Down Expand Up @@ -215,7 +218,9 @@ def _build_job_spec(self) -> "CustomJobSpec":
image_uri=self.image, command=self.command, args=[], env=env_list
)
machine_spec = MachineSpec(
machine_type=self.machine_type, accelerator_type=self.accelerator_type
machine_type=self.machine_type,
accelerator_type=self.accelerator_type,
accelerator_count=self.accelerator_count,
)
worker_pool_spec = WorkerPoolSpec(
container_spec=container_spec, machine_spec=machine_spec, replica_count=1
Expand Down
15 changes: 15 additions & 0 deletions tests/test_aiplatform.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from unittest.mock import MagicMock, patch

import pytest
from google.cloud.aiplatform_v1.types.accelerator_type import AcceleratorType
from google.cloud.aiplatform_v1.types.job_state import JobState
from prefect.exceptions import InfrastructureNotFound

Expand Down Expand Up @@ -140,3 +141,17 @@ def test_run_error(self, vertex_ai_custom_training_job: VertexAICustomTrainingJo
)
with pytest.raises(RuntimeError, match="my error msg"):
vertex_ai_custom_training_job.run()

def test_machine_spec(
self, vertex_ai_custom_training_job: VertexAICustomTrainingJob
):
vertex_ai_custom_training_job.accelerator_count = 1
vertex_ai_custom_training_job.accelerator_type = "NVIDIA_TESLA_T4"

job_spec = vertex_ai_custom_training_job._build_job_spec()

assert job_spec.worker_pool_specs[0].machine_spec.accelerator_count == 1
assert (
job_spec.worker_pool_specs[0].machine_spec.accelerator_type
== AcceleratorType.NVIDIA_TESLA_T4
)

0 comments on commit 461aef7

Please sign in to comment.