Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add "strict" interface to replicate.run() #288

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -145,17 +145,18 @@ declare module "replicate" {
fetch: (input: Request | string, init?: RequestInit) => Promise<Response>;
fileEncodingStrategy: FileEncodingStrategy;

run(
run<T>(
identifier: `${string}/${string}` | `${string}/${string}:${string}`,
options: {
input: object;
wait?: { interval?: number };
strict?: boolean;
webhook?: string;
webhook_events_filter?: WebhookEventType[];
signal?: AbortSignal;
},
progress?: (prediction: Prediction) => void
): Promise<object>;
): Promise<T>;

stream(
identifier: `${string}/${string}` | `${string}/${string}:${string}`,
Expand Down
148 changes: 115 additions & 33 deletions index.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ const ApiError = require("./lib/error");
const ModelVersionIdentifier = require("./lib/identifier");
const { createReadableStream } = require("./lib/stream");
const {
LazyFile,
withAutomaticRetries,
validateWebhook,
parseProgressFromLogs,
Expand Down Expand Up @@ -50,11 +51,8 @@ class Replicate {
* @param {"default" | "upload" | "data-uri"} [options.fileEncodingStrategy] - Determines the file encoding strategy to use
*/
constructor(options = {}) {
this.auth =
options.auth ||
(typeof process !== "undefined" ? process.env.REPLICATE_API_TOKEN : null);
this.userAgent =
options.userAgent || `replicate-javascript/${packageJSON.version}`;
this.auth = options.auth || (typeof process !== "undefined" ? process.env.REPLICATE_API_TOKEN : null);
this.userAgent = options.userAgent || `replicate-javascript/${packageJSON.version}`;
this.baseUrl = options.baseUrl || "https://api.replicate.com/v1";
this.fetch = options.fetch || globalThis.fetch;
this.fileEncodingStrategy = options.fileEncodingStrategy ?? "default";
Expand Down Expand Up @@ -134,13 +132,14 @@ class Replicate {
* @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output
* @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
* @param {AbortSignal} [options.signal] - AbortSignal to cancel the prediction
* @param {boolean} [options.strict] - Boolean to indicate that return type should conform to output schema
* @param {Function} [progress] - Callback function that receives the prediction object as it's updated. The function is called when the prediction is created, each time it's updated while polling for completion, and when it's completed.
* @throws {Error} If the reference is invalid
* @throws {Error} If the prediction failed
* @returns {Promise<object>} - Resolves with the output of running the model
*/
async run(ref, options, progress) {
const { wait, signal, ...data } = options;
const { wait, signal, strict, ...data } = options;

const identifier = ModelVersionIdentifier.parse(ref);

Expand All @@ -154,6 +153,7 @@ class Replicate {
prediction = await this.predictions.create({
...data,
model: `${identifier.owner}/${identifier.name}`,
stream: true,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right now, this will cause all models that don't support streaming to respond with {"detail":"Streaming not supported for the output type of the requested version.","status":422}.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We've just fixed that, consumers no longer need to specify the stream: true flag. All streamable models will get a stream URL 🙌

});
} else {
throw new Error("Invalid model version identifier");
Expand All @@ -164,23 +164,30 @@ class Replicate {
progress(prediction);
}

prediction = await this.wait(
prediction,
wait || {},
async (updatedPrediction) => {
// Call progress callback with the updated prediction object
if (progress) {
progress(updatedPrediction);
}
if (strict && !identifier.version) {
// Language models only support streaming at the moment.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Clarify potential ambiguity in use of "only" here

Suggested change
// Language models only support streaming at the moment.
// Currently, only language models support streaming

const stream = createReadableStream({
url: prediction.urls.stream,
fetch: this.fetch,
...(signal ? { options: { signal } } : {}),
});

// We handle the cancel later in the function.
if (signal && signal.aborted) {
return true; // stop polling
}
return streamAsyncIterator(stream);
}

prediction = await this.wait(prediction, wait || {}, async (updatedPrediction) => {
// Call progress callback with the updated prediction object
if (progress) {
progress(updatedPrediction);
}

return false; // continue polling
// We handle the cancel later in the function.
if (signal && signal.aborted) {
return true; // stop polling
}
);

return false; // continue polling
});

if (signal && signal.aborted) {
prediction = await this.predictions.cancel(prediction.id);
Expand All @@ -195,7 +202,22 @@ class Replicate {
throw new Error(`Prediction failed: ${prediction.error}`);
}

return prediction.output;
if (!strict) {
return prediction.output;
}

const response = await this.models.versions.get(identifier.owner, identifier.name, identifier.version);
const { openapi_schema: schema } = response;
Comment on lines +209 to +210
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think fetching the schema in a separate request is a deal-breaker here. That's something the Python client has been doing from the start, and it's added a lot of unnecessary overhead.

It's not as magical, but could we instead implement coerced run on a model / version? That way, you'd avoid duplicate requests and have clearer expectations around how/when schemas are validated.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not as magical, but could we instead implement coerced run on a model / version? That way, you'd avoid duplicate requests and have a clearer expectations around how/when schemas are validated.

I'm not sure what you mean here?


try {
return coerceOutput(schema.components.schemas.Output, prediction.output);
} catch (err) {
if (err instanceof CoercionError) {
console.warn(err.message);
return prediction.output;
}
throw err;
}
}

/**
Expand All @@ -217,10 +239,7 @@ class Replicate {
if (route instanceof URL) {
url = route;
} else {
url = new URL(
route.startsWith("/") ? route.slice(1) : route,
baseUrl.endsWith("/") ? baseUrl : `${baseUrl}/`
);
url = new URL(route.startsWith("/") ? route.slice(1) : route, baseUrl.endsWith("/") ? baseUrl : `${baseUrl}/`);
}

const { method = "GET", params = {}, data } = options;
Expand Down Expand Up @@ -275,7 +294,7 @@ class Replicate {
throw new ApiError(
`Request to ${url} failed with status ${response.status} ${response.statusText}: ${responseText}.`,
request,
response
response,
);
}

Expand Down Expand Up @@ -344,8 +363,7 @@ class Replicate {
const response = await endpoint();
yield response.results;
if (response.next) {
const nextPage = () =>
this.request(response.next, { method: "GET" }).then((r) => r.json());
const nextPage = () => this.request(response.next, { method: "GET" }).then((r) => r.json());
yield* this.paginate(nextPage);
}
}
Expand All @@ -372,11 +390,7 @@ class Replicate {
throw new Error("Invalid prediction");
}

if (
prediction.status === "succeeded" ||
prediction.status === "failed" ||
prediction.status === "canceled"
) {
if (prediction.status === "succeeded" || prediction.status === "failed" || prediction.status === "canceled") {
return prediction;
}

Expand Down Expand Up @@ -413,3 +427,71 @@ class Replicate {
module.exports = Replicate;
module.exports.validateWebhook = validateWebhook;
module.exports.parseProgressFromLogs = parseProgressFromLogs;

// TODO: Extend to contain more information about fields/schema/outputs
class CoercionError extends Error {}

function coerceOutput(schema, output) {
if (schema.type === "array") {
if (!Array.isArray(output)) {
throw new CoercionError("output is not array type");
}

// TODO: Add helper to return iterable with a `display()` function
// that returns a string rather than an array taking into account
// the `x-cog-array-display` property.
if (schema["x-cog-array-type"] === "iterator") {
return (async function* () {
for (const url of output) {
yield coerceOutput(schema["items"], url);
}
})();
}
return output.map((entry) => coerceOutput(schema["items"], entry));
}

if (schema.type === "object") {
if (typeof output !== "object" && object !== null) {
throw new CoercionError("output is not object type");
}

const mapped = {};
for (const [property, subschema] of Object.entries(schema.properties)) {
if (output[property]) {
mapped[property] = coerceOutput(subschema, output[property]);
} else if (subschema.required && subschema.required.includes(property)) {
throw new CoercionError(`output is missing required property: ${property}`);
}
}
return mapped;
}

if (schema.type === "string") {
if (typeof output !== "string") {
throw new CoercionError("output is not string type");
}

if (schema.format === "uri") {
try {
return new LazyFile(new URL(output));
} catch (error) {
throw new CoercionError("output is not a valid uri format");
}
}

// TODO: Handle dates
}

if (schema.type === "integer" || schema.type === "number") {
if (typeof output !== "number") {
throw new CoercionError(`output is not ${schema.type} type`);
}
}

if (schema.type === "boolean") {
if (typeof output !== "boolean") {
throw new CoercionError("output is not boolean type");
}
}
return output;
}
Loading
Loading