Skip to content

Commit

Permalink
Add prediction field to ModelError (#326)
Browse files Browse the repository at this point in the history
This PR extends #325 to add the prediction object itself to
`ModelError`, as opposed to just its ID. This makes it convenient to
introspect logs and other information to determine how to handle the
failure.

```python
import replicate
from replicate.exceptions import ModelError

try:
  output = replicate.run("stability-ai/stable-diffusion-3", { "prompt": "..." })
except ModelError as e:
  if "(some known issue)" in e.prediction.logs:
    pass

  print("Failed prediction: " + e.prediction.id)
```

---------

Signed-off-by: Rohan Mehta <[email protected]>
Signed-off-by: Mattt Zmuda <[email protected]>
Co-authored-by: Rohan Mehta <[email protected]>
  • Loading branch information
mattt and rohan-mehta committed Jul 18, 2024
1 parent ecfedfb commit 71c124d
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 6 deletions.
18 changes: 18 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,24 @@ or a handle to a file on your local device.
"an astronaut riding a horse"
```
`replicate.run` raises `ModelError` if the prediction fails.
You can access the exception's `prediction` property
to get more information about the failure.

```python
import replicate
from replicate.exceptions import ModelError

try:
  output = replicate.run("stability-ai/stable-diffusion-3", { "prompt": "An astronaut riding a rainbow unicorn" })
except ModelError as e:
  if "(some known issue)" in e.prediction.logs:
    pass

  print("Failed prediction: " + e.prediction.id)
```


## Run a model and stream its output

Replicate’s API supports server-sent event streams (SSEs) for language models.
Expand Down
11 changes: 10 additions & 1 deletion replicate/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from typing import Optional
from typing import TYPE_CHECKING, Optional

import httpx

if TYPE_CHECKING:
from replicate.prediction import Prediction


class ReplicateException(Exception):
"""A base class for all Replicate exceptions."""
Expand All @@ -10,6 +13,12 @@ class ReplicateException(Exception):
class ModelError(ReplicateException):
    """
    An error from user's code in a model.

    The prediction that failed is kept on the exception as `prediction`, and
    the exception message is that prediction's `error` value.
    """

    # The prediction object this error was raised for.
    prediction: "Prediction"

    def __init__(self, prediction: "Prediction") -> None:
        """Build the error from `prediction`, using `prediction.error` as the message."""
        super().__init__(prediction.error)
        self.prediction = prediction


class ReplicateError(ReplicateException):
"""
Expand Down
4 changes: 2 additions & 2 deletions replicate/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def output_iterator(self) -> Iterator[Any]:
self.reload()

if self.status == "failed":
raise ModelError(self.error)
raise ModelError(self)

output = self.output or []
new_output = output[len(previous_output) :]
Expand All @@ -272,7 +272,7 @@ async def async_output_iterator(self) -> AsyncIterator[Any]:
await self.async_reload()

if self.status == "failed":
raise ModelError(self.error)
raise ModelError(self)

output = self.output or []
new_output = output[len(previous_output) :]
Expand Down
4 changes: 2 additions & 2 deletions replicate/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def run(
prediction.wait()

if prediction.status == "failed":
raise ModelError(prediction.error)
raise ModelError(prediction)

return prediction.output

Expand Down Expand Up @@ -97,7 +97,7 @@ async def async_run(
await prediction.async_wait()

if prediction.status == "failed":
raise ModelError(prediction.error)
raise ModelError(prediction)

return prediction.output

Expand Down
71 changes: 70 additions & 1 deletion tests/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import replicate
from replicate.client import Client
from replicate.exceptions import ReplicateError
from replicate.exceptions import ModelError, ReplicateError


@pytest.mark.vcr("run.yaml")
Expand Down Expand Up @@ -184,3 +184,72 @@ def prediction_with_status(status: str) -> dict:
)

assert output == "Hello, world!"


@pytest.mark.asyncio
async def test_run_with_model_error(mock_replicate_api_token):
    """
    Check that `client.run` raises `ModelError` for a failed prediction.

    The mocked API creates the prediction as "processing" and reports it as
    "failed" (error "OOM") when it is fetched again. The raised exception must
    use that error as its message and expose the failed prediction.
    """

    def make_prediction(status: str) -> dict:
        # Only a failed prediction carries an error message.
        error = "OOM" if status == "failed" else None
        return {
            "id": "p1",
            "model": "test/example",
            "version": "v1",
            "urls": {
                "get": "https://api.replicate.com/v1/predictions/p1",
                "cancel": "https://api.replicate.com/v1/predictions/p1/cancel",
            },
            "created_at": "2023-10-05T12:00:00.000000Z",
            "source": "api",
            "status": status,
            "input": {"text": "world"},
            "output": None,
            "error": error,
            "logs": "",
        }

    created_response = httpx.Response(201, json=make_prediction("processing"))
    failed_response = httpx.Response(200, json=make_prediction("failed"))
    version_response = httpx.Response(
        201,
        json={
            "id": "f2d6b24e6002f25f77ae89c2b0a5987daa6d0bf751b858b94b8416e8542434d1",
            "created_at": "2024-07-18T00:35:56.210272Z",
            "cog_version": "0.9.10",
            "openapi_schema": {
                "openapi": "3.0.2",
            },
        },
    )

    router = respx.Router(base_url="https://api.replicate.com/v1")

    # Creating the prediction returns it in the "processing" state.
    create_route = router.route(method="POST", path="/predictions")
    create_route.mock(return_value=created_response)

    # Fetching the prediction afterwards reports the failure.
    reload_route = router.route(method="GET", path="/predictions/p1")
    reload_route.mock(return_value=failed_response)

    version_route = router.route(
        method="GET",
        path="/models/test/example/versions/v1",
    )
    version_route.mock(return_value=version_response)

    # Registered last, after the specific mocks above.
    router.route(host="api.replicate.com").pass_through()

    transport = httpx.MockTransport(router.handler)
    client = Client(api_token="test-token", transport=transport)
    client.poll_interval = 0.001

    with pytest.raises(ModelError) as excinfo:
        client.run(
            "test/example:v1",
            input={
                "text": "Hello, world!",
            },
        )

    raised = excinfo.value
    failed_prediction = raised.prediction

    assert str(raised) == "OOM"
    assert failed_prediction.error == "OOM"
    assert failed_prediction.status == "failed"

0 comments on commit 71c124d

Please sign in to comment.