-
Notifications
You must be signed in to change notification settings - Fork 192
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add "strict" interface to replicate.run()
#288
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -2,6 +2,7 @@ const ApiError = require("./lib/error"); | |||||
const ModelVersionIdentifier = require("./lib/identifier"); | ||||||
const { createReadableStream } = require("./lib/stream"); | ||||||
const { | ||||||
LazyFile, | ||||||
withAutomaticRetries, | ||||||
validateWebhook, | ||||||
parseProgressFromLogs, | ||||||
|
@@ -50,11 +51,8 @@ class Replicate { | |||||
* @param {"default" | "upload" | "data-uri"} [options.fileEncodingStrategy] - Determines the file encoding strategy to use | ||||||
*/ | ||||||
constructor(options = {}) { | ||||||
this.auth = | ||||||
options.auth || | ||||||
(typeof process !== "undefined" ? process.env.REPLICATE_API_TOKEN : null); | ||||||
this.userAgent = | ||||||
options.userAgent || `replicate-javascript/${packageJSON.version}`; | ||||||
this.auth = options.auth || (typeof process !== "undefined" ? process.env.REPLICATE_API_TOKEN : null); | ||||||
this.userAgent = options.userAgent || `replicate-javascript/${packageJSON.version}`; | ||||||
this.baseUrl = options.baseUrl || "https://api.replicate.com/v1"; | ||||||
this.fetch = options.fetch || globalThis.fetch; | ||||||
this.fileEncodingStrategy = options.fileEncodingStrategy ?? "default"; | ||||||
|
@@ -134,13 +132,14 @@ class Replicate { | |||||
* @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output | ||||||
* @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`) | ||||||
* @param {AbortSignal} [options.signal] - AbortSignal to cancel the prediction | ||||||
* @param {boolean} [options.strict] - Boolean to indicate that return type should conform to output schema | ||||||
* @param {Function} [progress] - Callback function that receives the prediction object as it's updated. The function is called when the prediction is created, each time it's updated while polling for completion, and when it's completed. | ||||||
* @throws {Error} If the reference is invalid | ||||||
* @throws {Error} If the prediction failed | ||||||
* @returns {Promise<object>} - Resolves with the output of running the model | ||||||
*/ | ||||||
async run(ref, options, progress) { | ||||||
const { wait, signal, ...data } = options; | ||||||
const { wait, signal, strict, ...data } = options; | ||||||
|
||||||
const identifier = ModelVersionIdentifier.parse(ref); | ||||||
|
||||||
|
@@ -154,6 +153,7 @@ class Replicate { | |||||
prediction = await this.predictions.create({ | ||||||
...data, | ||||||
model: `${identifier.owner}/${identifier.name}`, | ||||||
stream: true, | ||||||
}); | ||||||
} else { | ||||||
throw new Error("Invalid model version identifier"); | ||||||
|
@@ -164,23 +164,30 @@ class Replicate { | |||||
progress(prediction); | ||||||
} | ||||||
|
||||||
prediction = await this.wait( | ||||||
prediction, | ||||||
wait || {}, | ||||||
async (updatedPrediction) => { | ||||||
// Call progress callback with the updated prediction object | ||||||
if (progress) { | ||||||
progress(updatedPrediction); | ||||||
} | ||||||
if (strict && !identifier.version) { | ||||||
// Language models only support streaming at the moment. | ||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Clarify potential ambiguity in use of "only" here.
Suggested change
|
||||||
const stream = createReadableStream({ | ||||||
url: prediction.urls.stream, | ||||||
fetch: this.fetch, | ||||||
...(signal ? { options: { signal } } : {}), | ||||||
}); | ||||||
|
||||||
// We handle the cancel later in the function. | ||||||
if (signal && signal.aborted) { | ||||||
return true; // stop polling | ||||||
} | ||||||
return streamAsyncIterator(stream); | ||||||
} | ||||||
|
||||||
prediction = await this.wait(prediction, wait || {}, async (updatedPrediction) => { | ||||||
// Call progress callback with the updated prediction object | ||||||
if (progress) { | ||||||
progress(updatedPrediction); | ||||||
} | ||||||
|
||||||
return false; // continue polling | ||||||
// We handle the cancel later in the function. | ||||||
if (signal && signal.aborted) { | ||||||
return true; // stop polling | ||||||
} | ||||||
); | ||||||
|
||||||
return false; // continue polling | ||||||
}); | ||||||
|
||||||
if (signal && signal.aborted) { | ||||||
prediction = await this.predictions.cancel(prediction.id); | ||||||
|
@@ -195,7 +202,22 @@ class Replicate { | |||||
throw new Error(`Prediction failed: ${prediction.error}`); | ||||||
} | ||||||
|
||||||
return prediction.output; | ||||||
if (!strict) { | ||||||
return prediction.output; | ||||||
} | ||||||
|
||||||
const response = await this.models.versions.get(identifier.owner, identifier.name, identifier.version); | ||||||
const { openapi_schema: schema } = response; | ||||||
Comment on lines
+209
to
+210
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think fetching the schema in a separate request is a deal-breaker here. That's something the Python client has been doing from the start, and it's added a lot of unnecessary overhead. It's not as magical, but could we instead implement coerced There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
I'm not sure what you mean here? |
||||||
|
||||||
try { | ||||||
return coerceOutput(schema.components.schemas.Output, prediction.output); | ||||||
} catch (err) { | ||||||
if (err instanceof CoercionError) { | ||||||
console.warn(err.message); | ||||||
return prediction.output; | ||||||
} | ||||||
throw err; | ||||||
} | ||||||
} | ||||||
|
||||||
/** | ||||||
|
@@ -217,10 +239,7 @@ class Replicate { | |||||
if (route instanceof URL) { | ||||||
url = route; | ||||||
} else { | ||||||
url = new URL( | ||||||
route.startsWith("/") ? route.slice(1) : route, | ||||||
baseUrl.endsWith("/") ? baseUrl : `${baseUrl}/` | ||||||
); | ||||||
url = new URL(route.startsWith("/") ? route.slice(1) : route, baseUrl.endsWith("/") ? baseUrl : `${baseUrl}/`); | ||||||
} | ||||||
|
||||||
const { method = "GET", params = {}, data } = options; | ||||||
|
@@ -275,7 +294,7 @@ class Replicate { | |||||
throw new ApiError( | ||||||
`Request to ${url} failed with status ${response.status} ${response.statusText}: ${responseText}.`, | ||||||
request, | ||||||
response | ||||||
response, | ||||||
); | ||||||
} | ||||||
|
||||||
|
@@ -344,8 +363,7 @@ class Replicate { | |||||
const response = await endpoint(); | ||||||
yield response.results; | ||||||
if (response.next) { | ||||||
const nextPage = () => | ||||||
this.request(response.next, { method: "GET" }).then((r) => r.json()); | ||||||
const nextPage = () => this.request(response.next, { method: "GET" }).then((r) => r.json()); | ||||||
yield* this.paginate(nextPage); | ||||||
} | ||||||
} | ||||||
|
@@ -372,11 +390,7 @@ class Replicate { | |||||
throw new Error("Invalid prediction"); | ||||||
} | ||||||
|
||||||
if ( | ||||||
prediction.status === "succeeded" || | ||||||
prediction.status === "failed" || | ||||||
prediction.status === "canceled" | ||||||
) { | ||||||
if (prediction.status === "succeeded" || prediction.status === "failed" || prediction.status === "canceled") { | ||||||
return prediction; | ||||||
} | ||||||
|
||||||
|
@@ -413,3 +427,71 @@ class Replicate { | |||||
// The Replicate class is the module's main export; the webhook validation and
// log-progress parsing helpers are attached to it as additional properties.
module.exports = Replicate; | ||||||
module.exports.validateWebhook = validateWebhook; | ||||||
module.exports.parseProgressFromLogs = parseProgressFromLogs; | ||||||
|
||||||
// TODO: Extend to contain more information about fields/schema/outputs
/**
 * Error thrown by `coerceOutput` when a value does not match the schema it is
 * being coerced against.
 */
class CoercionError extends Error {
  /**
   * @param {string} message - Description of the mismatch
   * @param {object} [options] - Standard `Error` options (e.g. `cause`)
   */
  constructor(message, options) {
    super(message, options);
    this.name = "CoercionError";
  }
}

/**
 * Validate `output` against an OpenAPI/JSON-schema fragment and convert it to
 * a richer representation where one is available.
 *
 * - `array` schemas are coerced element by element; when the schema sets
 *   `x-cog-array-type` to `"iterator"` an async generator is returned instead
 *   of an array.
 * - `object` schemas are coerced property by property using
 *   `schema.properties`; properties not declared in the schema are omitted
 *   from the result. An object schema that declares no properties is
 *   returned unchanged.
 * - `string` schemas with `format: "uri"` are wrapped in a `LazyFile`.
 * - `integer`, `number` and `boolean` schemas are type-checked only.
 * - Any other schema type returns `output` untouched.
 *
 * @param {object} schema - Schema fragment describing the expected output
 * @param {*} output - Value to validate and coerce
 * @returns {*} The coerced value
 * @throws {CoercionError} If `output` does not match `schema`
 */
function coerceOutput(schema, output) {
  if (schema.type === "array") {
    if (!Array.isArray(output)) {
      throw new CoercionError("output is not array type");
    }

    // TODO: Add helper to return iterable with a `display()` function
    // that returns a string rather than an array taking into account
    // the `x-cog-array-display` property.
    if (schema["x-cog-array-type"] === "iterator") {
      return (async function* () {
        for (const entry of output) {
          yield coerceOutput(schema.items, entry);
        }
      })();
    }
    return output.map((entry) => coerceOutput(schema.items, entry));
  }

  if (schema.type === "object") {
    // `typeof null` is "object", so null has to be rejected explicitly.
    if (typeof output !== "object" || output === null || Array.isArray(output)) {
      throw new CoercionError("output is not object type");
    }

    // Nothing to coerce against when the schema declares no properties.
    if (!schema.properties) {
      return output;
    }

    // The list of required property names lives on the object schema itself,
    // not on the individual property schemas.
    const required = Array.isArray(schema.required) ? schema.required : [];

    const mapped = {};
    for (const [property, subschema] of Object.entries(schema.properties)) {
      const value = Object.prototype.hasOwnProperty.call(output, property)
        ? output[property]
        : undefined;

      // Only null/undefined count as absent; falsy values such as 0, false
      // and "" are legitimate outputs and must be preserved.
      if (value !== undefined && value !== null) {
        mapped[property] = coerceOutput(subschema, value);
      } else if (required.includes(property)) {
        throw new CoercionError(`output is missing required property: ${property}`);
      }
    }
    return mapped;
  }

  if (schema.type === "string") {
    if (typeof output !== "string") {
      throw new CoercionError("output is not string type");
    }

    if (schema.format === "uri") {
      try {
        return new LazyFile(new URL(output));
      } catch (error) {
        throw new CoercionError("output is not a valid uri format", { cause: error });
      }
    }

    // TODO: Handle dates
  }

  if (schema.type === "integer" || schema.type === "number") {
    if (typeof output !== "number") {
      throw new CoercionError(`output is not ${schema.type} type`);
    }
  }

  if (schema.type === "boolean") {
    if (typeof output !== "boolean") {
      throw new CoercionError("output is not boolean type");
    }
  }

  return output;
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right now, this will cause all models that don't support streaming to respond with
{"detail":"Streaming not supported for the output type of the requested version.","status":422}.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We've just fixed that, consumers no longer need to specify the
stream: true
flag. All streamable models will get a stream URL 🙌