diff --git a/prefect_gcp/workers/cloud_run_v2.py b/prefect_gcp/workers/cloud_run_v2.py index 956beb4..6a5a548 100644 --- a/prefect_gcp/workers/cloud_run_v2.py +++ b/prefect_gcp/workers/cloud_run_v2.py @@ -54,7 +54,8 @@ def _get_default_job_body_template() -> Dict[str, Any]: "serviceAccount": "{{ service_account_name }}", "maxRetries": "{{ max_retries }}", "timeout": "{{ timeout }}", - "vpcAccess": "{{ vpc_connector_name }}", + "vpcAccess": + {"connector": "{{ vpc_connector_name }}"}, "containers": [ { "env": [], @@ -184,7 +185,7 @@ def prepare_for_flow_run( self._format_args_if_present() self._populate_image_if_not_present() self._populate_timeout() - self._populate_vpc_if_present() + self._remove_vpc_access_if_unset() def _populate_timeout(self): """ @@ -235,14 +236,19 @@ def _format_args_if_present(self): "args" ] = shlex.split(args) - def _populate_vpc_if_present(self): + def _remove_vpc_access_if_unset(self): """ - Populates the job body with the VPC connector if present. + Removes vpcAccess if unset. """ - if self.job_body["template"]["template"].get("vpcAccess") is not None: - self.job_body["template"]["template"]["vpcAccess"] = { - "connector": self.job_body["template"]["template"]["vpcAccess"], - } + vpc_access = self.job_body["template"]["template"].get("vpcAccess") + + if not vpc_access: + return + + # if connector is the only key and it's not set, we'll remove it. + # otherwise we'll pass whatever the user has provided. + if len(vpc_access) == 1 and vpc_access.get("connector") is None: + self.job_body["template"]["template"]["vpcAccess"] = None # noinspection PyMethodParameters @validator("job_body") diff --git a/tests/test_cloud_run_worker_v2.py b/tests/test_cloud_run_worker_v2.py index 5422430..b45c705 100644 --- a/tests/test_cloud_run_worker_v2.py +++ b/tests/test_cloud_run_worker_v2.py @@ -15,7 +15,9 @@ def job_body(): "template": { "maxRetries": None, "timeout": None, - "vpcAccess": "projects/my_project/locations/us-central1/connectors/my-connector", # noqa: E501 + "vpcAccess": { + "connector": None, + }, "containers": [ { "env": [], @@ -122,12 +124,48 @@ def test_format_args_if_present(self, cloud_run_worker_v2_job_config): "containers" ][0]["args"] == ["-m", "prefect.engine"] - def test_populate_vpc_if_present(self, cloud_run_worker_v2_job_config): - cloud_run_worker_v2_job_config._populate_vpc_if_present() + + def test_remove_vpc_access_if_unset(self, cloud_run_worker_v2_job_config): + assert cloud_run_worker_v2_job_config.job_body["template"]["template"][ + "vpcAccess" + ] == {"connector": None} + + cloud_run_worker_v2_job_config._remove_vpc_access_if_unset() assert ( - cloud_run_worker_v2_job_config.job_body["template"]["template"][ - "vpcAccess" - ]["connector"] - == "projects/my_project/locations/us-central1/connectors/my-connector" + cloud_run_worker_v2_job_config.job_body["template"]["template"]["vpcAccess"] + is None ) + + def test_vpc_access_left_alone_if_connector_set( + self, cloud_run_worker_v2_job_config + ): + cloud_run_worker_v2_job_config.job_body["template"]["template"]["vpcAccess"][ + "connector" + ] = "projects/my_project/locations/us-central1/connectors/my-connector" + + cloud_run_worker_v2_job_config._remove_vpc_access_if_unset() + + assert cloud_run_worker_v2_job_config.job_body["template"]["template"][ + "vpcAccess" + ] == { + "connector": "projects/my_project/locations/us-central1/connectors/my-connector" + } + + def test_vpc_access_left_alone_if_network_config_set( + self, cloud_run_worker_v2_job_config + ): + cloud_run_worker_v2_job_config.job_body["template"]["template"]["vpcAccess"][ + "networkInterfaces" + ] = [{"network": "projects/my_project/global/networks/my-network"}] + + cloud_run_worker_v2_job_config._remove_vpc_access_if_unset() + + assert cloud_run_worker_v2_job_config.job_body["template"]["template"][ + "vpcAccess" + ] == { + "connector": None, + "networkInterfaces": [ + {"network": "projects/my_project/global/networks/my-network"} + ], + } \ No newline at end of file