Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
fix vpc access for cloud2 v2 worker
Browse files Browse the repository at this point in the history
  • Loading branch information
jakekaplan committed Apr 18, 2024
1 parent 47d54ea commit 69e4758
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 15 deletions.
22 changes: 14 additions & 8 deletions prefect_gcp/workers/cloud_run_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ def _get_default_job_body_template() -> Dict[str, Any]:
"serviceAccount": "{{ service_account_name }}",
"maxRetries": "{{ max_retries }}",
"timeout": "{{ timeout }}",
"vpcAccess": "{{ vpc_connector_name }}",
"vpcAccess":
{"connector": "{{ vpc_connector_name }}"},
"containers": [
{
"env": [],
Expand Down Expand Up @@ -184,7 +185,7 @@ def prepare_for_flow_run(
self._format_args_if_present()
self._populate_image_if_not_present()
self._populate_timeout()
self._populate_vpc_if_present()
self._remove_vpc_access_if_unset()

def _populate_timeout(self):
"""
Expand Down Expand Up @@ -235,14 +236,19 @@ def _format_args_if_present(self):
"args"
] = shlex.split(args)

def _populate_vpc_if_present(self):
def _remove_vpc_access_if_unset(self):
"""
Populates the job body with the VPC connector if present.
Removes vpcAccess if unset.
"""
if self.job_body["template"]["template"].get("vpcAccess") is not None:
self.job_body["template"]["template"]["vpcAccess"] = {
"connector": self.job_body["template"]["template"]["vpcAccess"],
}
vpc_access = self.job_body["template"]["template"].get("vpcAccess")

if not vpc_access:
return

# if connector is the only key and it's not set, we'll remove it.
# otherwise we'll pass whatever the user has provided.
if len(vpc_access) == 1 and vpc_access.get("connector") is None:
self.job_body["template"]["template"]["vpcAccess"] = None

# noinspection PyMethodParameters
@validator("job_body")
Expand Down
52 changes: 45 additions & 7 deletions tests/test_cloud_run_worker_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ def job_body():
"template": {
"maxRetries": None,
"timeout": None,
"vpcAccess": "projects/my_project/locations/us-central1/connectors/my-connector", # noqa: E501
"vpcAccess": {
"connector": None,
},
"containers": [
{
"env": [],
Expand Down Expand Up @@ -122,12 +124,48 @@ def test_format_args_if_present(self, cloud_run_worker_v2_job_config):
"containers"
][0]["args"] == ["-m", "prefect.engine"]

def test_populate_vpc_if_present(self, cloud_run_worker_v2_job_config):
cloud_run_worker_v2_job_config._populate_vpc_if_present()

def test_remove_vpc_access_if_unset(self, cloud_run_worker_v2_job_config):
assert cloud_run_worker_v2_job_config.job_body["template"]["template"][
"vpcAccess"
] == {"connector": None}

cloud_run_worker_v2_job_config._remove_vpc_access_if_unset()

assert (
cloud_run_worker_v2_job_config.job_body["template"]["template"][
"vpcAccess"
]["connector"]
== "projects/my_project/locations/us-central1/connectors/my-connector"
cloud_run_worker_v2_job_config.job_body["template"]["template"]["vpcAccess"]
is None
)

def test_vpc_access_left_alone_if_connector_set(
self, cloud_run_worker_v2_job_config
):
cloud_run_worker_v2_job_config.job_body["template"]["template"]["vpcAccess"][
"connector"
] = "projects/my_project/locations/us-central1/connectors/my-connector"

cloud_run_worker_v2_job_config._remove_vpc_access_if_unset()

assert cloud_run_worker_v2_job_config.job_body["template"]["template"][
"vpcAccess"
] == {
"connector": "projects/my_project/locations/us-central1/connectors/my-connector"
}

def test_vpc_access_left_alone_if_network_config_set(
self, cloud_run_worker_v2_job_config
):
cloud_run_worker_v2_job_config.job_body["template"]["template"]["vpcAccess"][
"networkInterfaces"
] = [{"network": "projects/my_project/global/networks/my-network"}]

cloud_run_worker_v2_job_config._remove_vpc_access_if_unset()

assert cloud_run_worker_v2_job_config.job_body["template"]["template"][
"vpcAccess"
] == {
"connector": None,
"networkInterfaces": [
{"network": "projects/my_project/global/networks/my-network"}
],
}

0 comments on commit 69e4758

Please sign in to comment.