feat (providers): support custom fetch implementations (#1955)
lgrammel committed Jun 14, 2024
1 parent 1af45ce commit 7910ae8
Showing 24 changed files with 276 additions and 5 deletions.
10 changes: 10 additions & 0 deletions .changeset/nasty-goats-hammer.md
@@ -0,0 +1,10 @@
---
'@ai-sdk/provider-utils': patch
'@ai-sdk/anthropic': patch
'@ai-sdk/mistral': patch
'@ai-sdk/google': patch
'@ai-sdk/openai': patch
'@ai-sdk/azure': patch
---

feat (providers): support custom fetch implementations
7 changes: 7 additions & 0 deletions content/providers/01-ai-sdk-providers/01-openai.mdx
@@ -67,6 +67,13 @@ You can use the following optional settings to customize the OpenAI provider ins

Custom headers to include in the requests.

- **fetch** _(input: RequestInfo, init?: RequestInit) => Promise<Response>_

Custom [fetch](https://developer.mozilla.org/en-US/docs/Web/API/fetch) implementation.
Defaults to the global `fetch` function.
You can use it as middleware to intercept requests,
or to provide a custom fetch implementation for purposes such as testing.

- **compatibility** _"strict" | "compatible"_

OpenAI compatibility mode. Should be set to `strict` when using the OpenAI API,
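The testing use case mentioned in the documentation above is worth a concrete illustration. The sketch below is not part of this commit: it passes a mocked `fetch` to `createOpenAI` so that `generateText` never hits the network. The exact shape of the canned response body is an assumption and may need to be adjusted to whatever fields the provider's response parser requires.

import { createOpenAI } from '@ai-sdk/openai';
import { generateText } from 'ai';

// Sketch: a mocked fetch that never touches the network and returns a canned
// chat completion. The response fields below are an assumption; adjust them to
// whatever the provider's response parser expects.
const openai = createOpenAI({
  apiKey: 'test-key', // avoids reading OPENAI_API_KEY in tests
  fetch: async () =>
    new Response(
      JSON.stringify({
        id: 'chatcmpl-mock',
        object: 'chat.completion',
        created: 0,
        model: 'gpt-3.5-turbo',
        choices: [
          {
            index: 0,
            message: { role: 'assistant', content: 'Hello from the mock!' },
            finish_reason: 'stop',
          },
        ],
        usage: { prompt_tokens: 1, completion_tokens: 1, total_tokens: 2 },
      }),
      { status: 200, headers: { 'Content-Type': 'application/json' } },
    ),
});

async function main() {
  const result = await generateText({
    model: openai('gpt-3.5-turbo'),
    prompt: 'This prompt never reaches the real API.',
  });

  console.log(result.text); // "Hello from the mock!"
}

main().catch(console.error);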
7 changes: 7 additions & 0 deletions content/providers/01-ai-sdk-providers/02-azure.mdx
@@ -54,6 +54,13 @@ You can use the following optional settings to customize the OpenAI provider ins
API key that is sent using the `api-key` header.
It defaults to the `AZURE_API_KEY` environment variable.

- **fetch** _(input: RequestInfo, init?: RequestInit) => Promise<Response>_

Custom [fetch](https://developer.mozilla.org/en-US/docs/Web/API/fetch) implementation.
Defaults to the global `fetch` function.
You can use it as middleware to intercept requests,
or to provide a custom fetch implementation for purposes such as testing.

## Language Models

The Azure OpenAI provider instance is a function that you can invoke to create a language model:
7 changes: 7 additions & 0 deletions content/providers/01-ai-sdk-providers/05-anthropic.mdx
@@ -57,6 +57,13 @@ You can use the following optional settings to customize the Google Generative A

Custom headers to include in the requests.

- **fetch** _(input: RequestInfo, init?: RequestInit) => Promise<Response>_

Custom [fetch](https://developer.mozilla.org/en-US/docs/Web/API/fetch) implementation.
Defaults to the global `fetch` function.
You can use it as middleware to intercept requests,
or to provide a custom fetch implementation for purposes such as testing.

## Language Models

You can create models that call the [Anthropic Messages API](https://docs.anthropic.com/claude/reference/messages_post) using the provider instance.
@@ -57,6 +57,13 @@ You can use the following optional settings to customize the Google Generative A

Custom headers to include in the requests.

- **fetch** _(input: RequestInfo, init?: RequestInit) => Promise<Response>_

Custom [fetch](https://developer.mozilla.org/en-US/docs/Web/API/fetch) implementation.
Defaults to the global `fetch` function.
You can use it as middleware to intercept requests,
or to provide a custom fetch implementation for purposes such as testing.

## Language Models

You can create models that call the [Google Generative AI API](https://ai.google.dev/api/rest) using the provider instance.
7 changes: 7 additions & 0 deletions content/providers/01-ai-sdk-providers/20-mistral.mdx
@@ -58,6 +58,13 @@ You can use the following optional settings to customize the Mistral provider in

Custom headers to include in the requests.

- **fetch** _(input: RequestInfo, init?: RequestInit) => Promise<Response>_

Custom [fetch](https://developer.mozilla.org/en-US/docs/Web/API/fetch) implementation.
Defaults to the global `fetch` function.
You can use it as middleware to intercept requests,
or to provide a custom fetch implementation for purposes such as testing.

## Language Models

You can create models that call the [Mistral chat API](https://docs.mistral.ai/api/#operation/createChatCompletion) using the provider instance.
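The example files added below all use the same logging wrapper. Another common middleware pattern the `fetch` option enables is retrying. The following sketch is not part of this commit: it retries 429 and 5xx responses with a simple linear backoff. The attempt count and delays are arbitrary, and the AI SDK already exposes its own retry handling (`maxRetries`), so this is purely illustrative.

import { createMistral } from '@ai-sdk/mistral';

// Sketch: a fetch wrapper that retries 429 and 5xx responses (and network
// errors) with a simple linear backoff before giving up. Re-sending is safe
// here because the provider sends plain JSON string bodies.
const fetchWithRetry = async (
  input: RequestInfo | URL,
  init?: RequestInit,
): Promise<Response> => {
  const maxAttempts = 3;
  let lastError: unknown;

  for (let attempt = 1; attempt <= maxAttempts; attempt++) {
    try {
      const response = await fetch(input, init);
      if (response.status !== 429 && response.status < 500) {
        return response;
      }
      lastError = new Error(`HTTP ${response.status}`);
    } catch (error) {
      lastError = error;
    }
    await new Promise(resolve => setTimeout(resolve, 500 * attempt));
  }

  throw lastError;
};

const mistral = createMistral({ fetch: fetchWithRetry });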
27 changes: 27 additions & 0 deletions examples/ai-core/src/generate-text/anthropic-custom-fetch.ts
@@ -0,0 +1,27 @@
import { createAnthropic } from '@ai-sdk/anthropic';
import { generateText } from 'ai';
import dotenv from 'dotenv';

dotenv.config();

const anthropic = createAnthropic({
// example fetch wrapper that logs the URL:
fetch: async (url, options) => {
console.log(`Fetching ${url}`);
const result = await fetch(url, options);
console.log(`Fetched ${url}`);
console.log();
return result;
},
});

async function main() {
const result = await generateText({
model: anthropic('claude-3-haiku-20240307'),
prompt: 'Invent a new holiday and describe its traditions.',
});

console.log(result.text);
}

main().catch(console.error);
27 changes: 27 additions & 0 deletions examples/ai-core/src/generate-text/azure-custom-fetch.ts
@@ -0,0 +1,27 @@
import { createAzure } from '@ai-sdk/azure';
import { generateText } from 'ai';
import dotenv from 'dotenv';

dotenv.config();

const azure = createAzure({
// example fetch wrapper that logs the URL:
fetch: async (url, options) => {
console.log(`Fetching ${url}`);
const result = await fetch(url, options);
console.log(`Fetched ${url}`);
console.log();
return result;
},
});

async function main() {
const result = await generateText({
model: azure('v0-gpt-35-turbo'), // use your own deployment
prompt: 'Invent a new holiday and describe its traditions.',
});

console.log(result.text);
}

main().catch(console.error);
27 changes: 27 additions & 0 deletions examples/ai-core/src/generate-text/google-custom-fetch.ts
@@ -0,0 +1,27 @@
import { createGoogleGenerativeAI } from '@ai-sdk/google';
import { generateText } from 'ai';
import dotenv from 'dotenv';

dotenv.config();

const google = createGoogleGenerativeAI({
// example fetch wrapper that logs the URL:
fetch: async (url, options) => {
console.log(`Fetching ${url}`);
const result = await fetch(url, options);
console.log(`Fetched ${url}`);
console.log();
return result;
},
});

async function main() {
const result = await generateText({
model: google('models/gemini-pro'),
prompt: 'Invent a new holiday and describe its traditions.',
});

console.log(result.text);
}

main().catch(console.error);
27 changes: 27 additions & 0 deletions examples/ai-core/src/generate-text/mistral-custom-fetch.ts
@@ -0,0 +1,27 @@
import { createMistral } from '@ai-sdk/mistral';
import { generateText } from 'ai';
import dotenv from 'dotenv';

dotenv.config();

const mistral = createMistral({
// example fetch wrapper that logs the URL:
fetch: async (url, options) => {
console.log(`Fetching ${url}`);
const result = await fetch(url, options);
console.log(`Fetched ${url}`);
console.log();
return result;
},
});

async function main() {
const result = await generateText({
model: mistral('open-mistral-7b'),
prompt: 'Invent a new holiday and describe its traditions.',
});

console.log(result.text);
}

main().catch(console.error);
27 changes: 27 additions & 0 deletions examples/ai-core/src/generate-text/openai-custom-fetch.ts
@@ -0,0 +1,27 @@
import { createOpenAI } from '@ai-sdk/openai';
import { generateText } from 'ai';
import dotenv from 'dotenv';

dotenv.config();

const openai = createOpenAI({
// example fetch wrapper that logs the URL:
fetch: async (url, options) => {
console.log(`Fetching ${url}`);
const result = await fetch(url, options);
console.log(`Fetched ${url}`);
console.log();
return result;
},
});

async function main() {
const result = await generateText({
model: openai('gpt-3.5-turbo'),
prompt: 'Invent a new holiday and describe its traditions.',
});

console.log(result.text);
}

main().catch(console.error);
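Beyond logging, the same hook can rewrite outgoing requests before they are sent. The sketch below is not part of this commit: a middleware-style `fetch` that tags each request with a correlation-ID header and dumps the JSON body. The `x-request-id` header name is only an example, and Anthropic is used here simply as one of the providers that now accepts the option.

import { createAnthropic } from '@ai-sdk/anthropic';
import { randomUUID } from 'node:crypto';

// Sketch: middleware-style fetch that adds a correlation-ID header to every
// request and logs the outgoing JSON body before delegating to global fetch.
const anthropic = createAnthropic({
  fetch: async (input, init) => {
    const headers = new Headers(init?.headers);
    headers.set('x-request-id', randomUUID()); // example header, not required by the API

    if (typeof init?.body === 'string') {
      console.log('outgoing request body:', init.body);
    }

    return fetch(input, { ...init, headers });
  },
});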
3 changes: 3 additions & 0 deletions packages/anthropic/src/anthropic-messages-language-model.ts
@@ -25,6 +25,7 @@ type AnthropicMessagesConfig = {
provider: string;
baseURL: string;
headers: () => Record<string, string | undefined>;
fetch?: typeof fetch;
};

export class AnthropicMessagesLanguageModel implements LanguageModelV1 {
@@ -163,6 +164,7 @@ export class AnthropicMessagesLanguageModel implements LanguageModelV1 {
anthropicMessagesResponseSchema,
),
abortSignal: options.abortSignal,
fetch: this.config.fetch,
});

const { messages: rawPrompt, ...rawSettings } = args;
Expand Down Expand Up @@ -222,6 +224,7 @@ export class AnthropicMessagesLanguageModel implements LanguageModelV1 {
anthropicMessagesChunkSchema,
),
abortSignal: options.abortSignal,
fetch: this.config.fetch,
});

const { messages: rawPrompt, ...rawSettings } = args;
7 changes: 7 additions & 0 deletions packages/anthropic/src/anthropic-provider.ts
@@ -54,6 +54,12 @@ Custom headers to include in the requests.
*/
headers?: Record<string, string>;

/**
Custom fetch implementation. You can use it as middleware to intercept requests,
or to provide a custom fetch implementation for purposes such as testing.
*/
fetch?: typeof fetch;

generateId?: () => string;
}

@@ -86,6 +92,7 @@ export function createAnthropic(
provider: 'anthropic.messages',
baseURL,
headers: getHeaders,
fetch: options.fetch,
});

const provider = function (
7 changes: 7 additions & 0 deletions packages/azure/src/azure-openai-provider.ts
@@ -29,6 +29,12 @@ Name of the Azure OpenAI resource.
API key for authenticating requests.
*/
apiKey?: string;

/**
Custom fetch implementation. You can use it as middleware to intercept requests,
or to provide a custom fetch implementation for purposes such as testing.
*/
fetch?: typeof fetch;
}

/**
@@ -63,6 +69,7 @@ export function createAzure(
url: ({ path, modelId }) =>
`https://${getResourceName()}.openai.azure.com/openai/deployments/${modelId}${path}?api-version=2024-05-01-preview`,
compatibility: 'compatible',
fetch: options.fetch,
});

const provider = function (
3 changes: 3 additions & 0 deletions packages/google/src/google-generative-ai-language-model.ts
@@ -26,6 +26,7 @@ type GoogleGenerativeAIConfig = {
baseURL: string;
headers: () => Record<string, string | undefined>;
generateId: () => string;
fetch?: typeof fetch;
};

export class GoogleGenerativeAILanguageModel implements LanguageModelV1 {
@@ -156,6 +157,7 @@ export class GoogleGenerativeAILanguageModel implements LanguageModelV1 {
failedResponseHandler: googleFailedResponseHandler,
successfulResponseHandler: createJsonResponseHandler(responseSchema),
abortSignal: options.abortSignal,
fetch: this.config.fetch,
});

const { contents: rawPrompt, ...rawSettings } = args;
@@ -197,6 +199,7 @@ export class GoogleGenerativeAILanguageModel implements LanguageModelV1 {
failedResponseHandler: googleFailedResponseHandler,
successfulResponseHandler: createEventSourceResponseHandler(chunkSchema),
abortSignal: options.abortSignal,
fetch: this.config.fetch,
});

const { contents: rawPrompt, ...rawSettings } = args;
43 changes: 38 additions & 5 deletions packages/google/src/google-provider.ts
@@ -1,4 +1,8 @@
- import { Google } from './google-facade';
import {
generateId,
loadApiKey,
withoutTrailingSlash,
} from '@ai-sdk/provider-utils';
import { GoogleGenerativeAILanguageModel } from './google-generative-ai-language-model';
import {
GoogleGenerativeAIModelId,
@@ -48,6 +52,12 @@ Custom headers to include in the requests.
*/
headers?: Record<string, string>;

/**
Custom fetch implementation. You can use it as middleware to intercept requests,
or to provide a custom fetch implementation for purposes such as testing.
*/
fetch?: typeof fetch;

generateId?: () => string;
}

@@ -57,7 +67,30 @@ Create a Google Generative AI provider instance.
export function createGoogleGenerativeAI(
options: GoogleGenerativeAIProviderSettings = {},
): GoogleGenerativeAIProvider {
- const google = new Google(options);
const baseURL =
withoutTrailingSlash(options.baseURL ?? options.baseUrl) ??
'https://generativelanguage.googleapis.com/v1beta';

const getHeaders = () => ({
'x-goog-api-key': loadApiKey({
apiKey: options.apiKey,
environmentVariableName: 'GOOGLE_GENERATIVE_AI_API_KEY',
description: 'Google Generative AI',
}),
...options.headers,
});

const createChatModel = (
modelId: GoogleGenerativeAIModelId,
settings: GoogleGenerativeAISettings = {},
) =>
new GoogleGenerativeAILanguageModel(modelId, settings, {
provider: 'google.generative-ai',
baseURL,
headers: getHeaders,
generateId: options.generateId ?? generateId,
fetch: options.fetch,
});

const provider = function (
modelId: GoogleGenerativeAIModelId,
@@ -69,11 +102,11 @@ export function createGoogleGenerativeAI(
);
}

- return google.chat(modelId, settings);
return createChatModel(modelId, settings);
};

- provider.chat = google.chat.bind(google);
- provider.generativeAI = google.generativeAI.bind(google);
provider.chat = createChatModel;
provider.generativeAI = createChatModel;

return provider as GoogleGenerativeAIProvider;
}