Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
adding boot_disk_type and boot_disk_size_gb (#177)
Browse files Browse the repository at this point in the history
* adding boot_disk_type and boot_disk_size_gb

there params now will be standard on vertexai block, with defaults set to how they were implicitly in the code before.

* param fix and linting

* test now passes

---------

Co-authored-by: Alexander Streed <[email protected]>
  • Loading branch information
acgourley and desertaxle committed May 23, 2023
1 parent 1a586e0 commit 56bd219
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 3 deletions.
21 changes: 18 additions & 3 deletions prefect_gcp/aiplatform.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@
)
from google.cloud.aiplatform_v1.types.job_service import CancelCustomJobRequest
from google.cloud.aiplatform_v1.types.job_state import JobState
from google.cloud.aiplatform_v1.types.machine_resources import MachineSpec
from google.cloud.aiplatform_v1.types.machine_resources import DiskSpec, MachineSpec
from google.protobuf.duration_pb2 import Duration
except ModuleNotFoundError:
pass
Expand Down Expand Up @@ -141,6 +141,16 @@ class VertexAICustomTrainingJob(Infrastructure):
accelerator_count: Optional[int] = Field(
default=None, description="The number of accelerators to attach to the machine."
)
boot_disk_type: str = Field(
default="pd-ssd",
title="Boot Disk Type",
description="The type of boot disk to attach to the machine.",
)
boot_disk_size_gb: int = Field(
default=100,
title="Boot Disk Size",
description="The size of the boot disk to attach to the machine, in gigabytes.",
)
maximum_run_time: datetime.timedelta = Field(
default=datetime.timedelta(days=7), description="The maximum job running time."
)
Expand Down Expand Up @@ -223,9 +233,14 @@ def _build_job_spec(self) -> "CustomJobSpec":
accelerator_count=self.accelerator_count,
)
worker_pool_spec = WorkerPoolSpec(
container_spec=container_spec, machine_spec=machine_spec, replica_count=1
container_spec=container_spec,
machine_spec=machine_spec,
replica_count=1,
disk_spec=DiskSpec(
boot_disk_type=self.boot_disk_type,
boot_disk_size_gb=self.boot_disk_size_gb,
),
)

# look for service account
service_account = (
self.service_account or self.gcp_credentials._service_account_email
Expand Down
4 changes: 4 additions & 0 deletions tests/test_aiplatform.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ def test_preview(
machine_type: "n1-standard-4"
}
replica_count: 1
disk_spec {
boot_disk_type: "pd-ssd"
boot_disk_size_gb: 100
}
}
scheduling {
}
Expand Down

0 comments on commit 56bd219

Please sign in to comment.