From 9ba9b06f303b0dc2a1aea94aa0f4915f1f951735 Mon Sep 17 00:00:00 2001 From: "jake@prefect.io" Date: Mon, 22 Apr 2024 12:17:02 -0400 Subject: [PATCH] fix function, add more tests --- prefect_gcp/workers/cloud_run_v2.py | 15 ++++++++++----- tests/test_cloud_run_worker_v2.py | 23 +++++++++++++++++++---- 2 files changed, 29 insertions(+), 9 deletions(-) diff --git a/prefect_gcp/workers/cloud_run_v2.py b/prefect_gcp/workers/cloud_run_v2.py index eb97afa..01355f8 100644 --- a/prefect_gcp/workers/cloud_run_v2.py +++ b/prefect_gcp/workers/cloud_run_v2.py @@ -239,14 +239,19 @@ def _remove_vpc_access_if_unset(self): """ Removes vpcAccess if unset. """ - vpc_access = self.job_body["template"]["template"].get("vpcAccess") - if not vpc_access: + if "vpcAccess" not in self.job_body["template"]["template"]: return - # if connector is the only key and it's not set, we'll remove it. - # otherwise we'll pass whatever the user has provided. - if len(vpc_access) == 1 and vpc_access.get("connector") is None: + vpc_access = self.job_body["template"]["template"]["vpcAccess"] + + # if vpcAccess is unset or connector is unset, remove the entire vpcAccess block + # otherwise leave the user provided value. + if not vpc_access or ( + len(vpc_access) == 1 + and "connector" in vpc_access + and vpc_access["connector"] is None + ): self.job_body["template"]["template"].pop("vpcAccess") # noinspection PyMethodParameters diff --git a/tests/test_cloud_run_worker_v2.py b/tests/test_cloud_run_worker_v2.py index ab3309c..75eb4ba 100644 --- a/tests/test_cloud_run_worker_v2.py +++ b/tests/test_cloud_run_worker_v2.py @@ -124,10 +124,25 @@ def test_format_args_if_present(self, cloud_run_worker_v2_job_config): "containers" ][0]["args"] == ["-m", "prefect.engine"] - def test_remove_vpc_access_if_unset(self, cloud_run_worker_v2_job_config): - assert cloud_run_worker_v2_job_config.job_body["template"]["template"][ + @pytest.mark.parametrize("vpc_access", [{"connector": None}, {}, None]) + def test_remove_vpc_access_if_connector_unset( + self, cloud_run_worker_v2_job_config, vpc_access + ): + cloud_run_worker_v2_job_config.job_body["template"]["template"][ + "vpcAccess" + ] = vpc_access + + cloud_run_worker_v2_job_config._remove_vpc_access_if_unset() + + assert ( "vpcAccess" - ] == {"connector": None} + not in cloud_run_worker_v2_job_config.job_body["template"]["template"] + ) + + def test_remove_vpc_access_originally_not_present( + self, cloud_run_worker_v2_job_config + ): + cloud_run_worker_v2_job_config.job_body["template"]["template"].pop("vpcAccess") cloud_run_worker_v2_job_config._remove_vpc_access_if_unset() @@ -148,7 +163,7 @@ def test_vpc_access_left_alone_if_connector_set( assert cloud_run_worker_v2_job_config.job_body["template"]["template"][ "vpcAccess" ] == { - "connector": "projects/my_project/locations/us-central1/connectors/my-connector" # noqa: E501 + "connector": "projects/my_project/locations/us-central1/connectors/my-connector" # noqa E501 } def test_vpc_access_left_alone_if_network_config_set(