Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
undo bad format
Browse files Browse the repository at this point in the history
  • Loading branch information
JohnLemmonMedely committed Feb 23, 2024
1 parent f7eb331 commit 494f89b
Showing 1 changed file with 16 additions and 51 deletions.
67 changes: 16 additions & 51 deletions prefect_gcp/workers/cloud_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,7 @@
if PYDANTIC_VERSION.startswith("2."):
from pydantic.v1 import Field, validator
else:
from pydantic import (
Field,
validator,
)
from pydantic import Field, validator

from prefect_gcp.cloud_run import Execution, Job
from prefect_gcp.credentials import GcpCredentials
Expand All @@ -194,7 +191,7 @@
from prefect.server.schemas.responses import DeploymentResponse


def _get_default_job_body_template() -> (Dict[str, Any]):
def _get_default_job_body_template() -> Dict[str, Any]:
"""Returns the default job body template used by the Cloud Run Job."""
return {
"apiVersion": "run.googleapis.com/v1",
Expand Down Expand Up @@ -242,7 +239,7 @@ def _get_default_job_body_template() -> (Dict[str, Any]):
}


def _get_base_job_body() -> (Dict[str, Any]):
def _get_base_job_body() -> Dict[str, Any]:
"""Returns a base job body to use for job body validation."""
return {
"apiVersion": "run.googleapis.com/v1",
Expand Down Expand Up @@ -282,8 +279,7 @@ class CloudRunWorkerJobConfiguration(BaseJobConfiguration):
"""

region: str = Field(
default="us-central1",
description="The region where the Cloud Run Job resides.",
default="us-central1", description="The region where the Cloud Run Job resides."
)
credentials: Optional[GcpCredentials] = Field(
title="GCP Credentials",
Expand Down Expand Up @@ -354,9 +350,7 @@ def _populate_envs(self):
"env"
] = envs

def _populate_name_if_not_present(
self,
):
def _populate_name_if_not_present(self):
"""Adds the flow run name to the job if one is not already provided."""
try:
if "name" not in self.job_body["metadata"]:
Expand All @@ -366,9 +360,7 @@ def _populate_name_if_not_present(
except KeyError:
raise ValueError("Unable to verify name due to invalid job body template.")

def _populate_image_if_not_present(
self,
):
def _populate_image_if_not_present(self):
"""Adds the latest prefect image to the job if one is not already provided."""
try:
if (
Expand All @@ -383,9 +375,7 @@ def _populate_image_if_not_present(
except KeyError:
raise ValueError("Unable to verify image due to invalid job body template.")

def _populate_or_format_command(
self,
):
def _populate_or_format_command(self):
"""
Ensures that the command is present in the job manifest. Populates the command
with the `prefect -m prefect.engine` if a command is not present.
Expand Down Expand Up @@ -619,10 +609,7 @@ async def run(
logger,
)
job_execution = await run_sync_in_worker_thread(
self._begin_job_execution,
configuration,
client,
logger,
self._begin_job_execution, configuration, client, logger
)

if task_status:
Expand All @@ -637,21 +624,15 @@ async def run(
)
return result

def _get_client(
self,
configuration: CloudRunWorkerJobConfiguration,
) -> Resource:
def _get_client(self, configuration: CloudRunWorkerJobConfiguration) -> Resource:
"""Get the base client needed for interacting with GCP APIs."""
# region needed for 'v1' API
api_endpoint = f"https://{configuration.region}-run.googleapis.com"
gcp_creds = configuration.credentials.get_credentials_from_service_account()
options = ClientOptions(api_endpoint=api_endpoint)

return discovery.build(
"run",
"v1",
client_options=options,
credentials=gcp_creds,
"run", "v1", client_options=options, credentials=gcp_creds
).namespaces()

def _create_job_and_wait_for_registration(
Expand All @@ -669,14 +650,12 @@ def _create_job_and_wait_for_registration(
namespace=configuration.credentials.project,
body=configuration.job_body,
)
except (googleapiclient.errors.HttpError) as exc:
except googleapiclient.errors.HttpError as exc:
self._create_job_error(exc, configuration)

try:
self._wait_for_job_creation(
client=client,
configuration=configuration,
logger=logger,
client=client, configuration=configuration, logger=logger
)
except Exception:
logger.exception(
Expand Down Expand Up @@ -781,16 +760,11 @@ def _watch_job_execution_and_get_result(
)

return CloudRunWorkerResult(
identifier=configuration.job_name,
status_code=status_code,
identifier=configuration.job_name, status_code=status_code
)

def _watch_job_execution(
self,
client,
job_execution: Execution,
timeout: int,
poll_interval: int = 5,
self, client, job_execution: Execution, timeout: int, poll_interval: int = 5
):
"""
Update job_execution status until it is no longer running or timeout is reached.
Expand Down Expand Up @@ -879,18 +853,9 @@ async def kill_infrastructure(
job_name=infrastructure_pid,
)

def _stop_job(
self,
client: Resource,
namespace: str,
job_name: str,
):
def _stop_job(self, client: Resource, namespace: str, job_name: str):
try:
Job.delete(
client=client,
namespace=namespace,
job_name=job_name,
)
Job.delete(client=client, namespace=namespace, job_name=job_name)
except Exception as exc:
if "does not exist" in str(exc):
raise InfrastructureNotFound(
Expand Down

0 comments on commit 494f89b

Please sign in to comment.