Skip to content

Commit

Permalink
feat (core): support https and data url strings in image parts (#1944)
Browse files Browse the repository at this point in the history
  • Loading branch information
lgrammel committed Jun 13, 2024
1 parent 17e5bbb commit 0612350
Show file tree
Hide file tree
Showing 13 changed files with 159 additions and 49 deletions.
5 changes: 5 additions & 0 deletions .changeset/smart-ducks-fold.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'ai': patch
---

feat (core): support https and data url strings in image parts
29 changes: 19 additions & 10 deletions content/docs/03-ai-sdk-core/03-prompts.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -76,15 +76,27 @@ const result = await generateText({
Instead of sending a text in the `content` property, you can send an array of parts that include text and other data types.
Currently image and text parts are supported.

For models that support multi-modal inputs, user messages can include images. An `image` can be a base64-encoded image (`string`), an `ArrayBuffer`, a `Uint8Array`,
a `Buffer`, or a `URL` object. It is possible to mix text and multiple images.
For models that support multi-modal inputs, user messages can include images. An `image` can be one of the following:

- base64-encoded image:
- `string` with base-64 encoded content
- data URL `string`, e.g. `data:image/png;base64,...`
- binary image:
- `ArrayBuffer`
- `Uint8Array`
- `Buffer`
- URL:
- http(s) URL `string`, e.g. `https://example.com/image.png`
- `URL` object, e.g. `new URL('https://example.com/image.png')`

It is possible to mix text and multiple images.

<Note type="warning">
Not all models support all types of multi-modal inputs. Check the model's
capabilities before using this feature.
</Note>

#### Example: Buffer images
#### Example: Binary image (Buffer)

```ts highlight="8-11"
const result = await generateText({
Expand All @@ -104,9 +116,7 @@ const result = await generateText({
});
```

#### Example: Base-64 encoded images

<Note>You do not need a `data:...` prefix for the base64-encoded image.</Note>
#### Example: Base-64 encoded image (string)

```ts highlight="8-11"
const result = await generateText({
Expand All @@ -126,9 +136,9 @@ const result = await generateText({
});
```

#### Example: Image URLs
#### Example: Image URL (string)

```ts highlight="8-13"
```ts highlight="8-12"
const result = await generateText({
model: yourModel,
messages: [
Expand All @@ -138,9 +148,8 @@ const result = await generateText({
{ type: 'text', text: 'Describe the image in detail.' },
{
type: 'image',
image: new URL(
image:
'https://github.com/vercel/ai/blob/main/examples/ai-core/data/comic-cat.png?raw=true',
),
},
],
},
Expand Down
2 changes: 1 addition & 1 deletion content/docs/07-reference/ai-sdk-core/01-generate-text.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ console.log(text);
name: 'image',
type: 'string | Uint8Array | Buffer | ArrayBuffer | URL',
description:
'The image content of the message part. String are base64 encoded content. URLs need to be represented with a URL object',
'The image content of the message part. String are either base64 encoded content, base64 data URLs, or http(s) URLs.',
},
],
},
Expand Down
2 changes: 1 addition & 1 deletion content/docs/07-reference/ai-sdk-core/02-stream-text.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ for await (const textPart of textStream) {
name: 'image',
type: 'string | Uint8Array | Buffer | ArrayBuffer | URL',
description:
'The image content of the message part. String are base64 encoded content. URLs need to be represented with a URL object',
'The image content of the message part. String are either base64 encoded content, base64 data URLs, or http(s) URLs.',
},
],
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ console.log(JSON.stringify(object, null, 2));
name: 'image',
type: 'string | Uint8Array | Buffer | ArrayBuffer | URL',
description:
'The image content of the message part. String are base64 encoded content. URLs need to be represented with a URL object'
'The image content of the message part. String are either base64 encoded content, base64 data URLs, or http(s) URLs.'
}
]
}
Expand Down
2 changes: 1 addition & 1 deletion content/docs/07-reference/ai-sdk-core/04-stream-object.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ for await (const partialObject of partialObjectStream) {
name: 'image',
type: 'string | Uint8Array | Buffer | ArrayBuffer | URL',
description:
'The image content of the message part. String are base64 encoded content. URLs need to be represented with a URL object'
'The image content of the message part. String are either base64 encoded content, base64 data URLs, or http(s) URLs.'
}
]
}
Expand Down
2 changes: 1 addition & 1 deletion content/docs/07-reference/ai-sdk-rsc/01-stream-ui.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ A helper function to create a streamable UI from LLM providers. This function is
name: 'image',
type: 'string | Uint8Array | Buffer | ArrayBuffer | URL',
description:
'The image content of the message part. String are base64 encoded content. URLs need to be represented with a URL object',
'The image content of the message part. String are either base64 encoded content, base64 data URLs, or http(s) URLs.',
},
],
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@ async function main() {
{ type: 'text', text: 'Describe the image in detail.' },
{
type: 'image',
image: new URL(
image:
'https://github.com/vercel/ai/blob/main/examples/ai-core/data/comic-cat.png?raw=true',
),
},
],
},
Expand Down
3 changes: 1 addition & 2 deletions examples/ai-core/src/generate-text/google-multimodal-url.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@ async function main() {
{ type: 'text', text: 'Describe the image in detail.' },
{
type: 'image',
image: new URL(
image:
'https://github.com/vercel/ai/blob/main/examples/ai-core/data/comic-cat.png?raw=true',
),
},
],
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,8 @@ async function main() {
{ type: 'text', text: 'Describe the image in detail.' },
{
type: 'image',
image: new URL(
image:
'https://github.com/vercel/ai/blob/main/examples/ai-core/data/comic-cat.png?raw=true',
),
},
],
},
Expand Down
3 changes: 1 addition & 2 deletions examples/ai-core/src/generate-text/openai-multimodal-url.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@ async function main() {
{ type: 'text', text: 'Describe the image in detail.' },
{
type: 'image',
image: new URL(
image:
'https://github.com/vercel/ai/blob/main/examples/ai-core/data/comic-cat.png?raw=true',
),
},
],
},
Expand Down
103 changes: 77 additions & 26 deletions packages/core/core/prompt/convert-to-language-model-prompt.test.ts
Original file line number Diff line number Diff line change
@@ -1,34 +1,85 @@
import { convertToLanguageModelMessage } from './convert-to-language-model-prompt';

describe('convertToLanguageModelMessage', () => {
describe('assistant message', () => {
it('should ignore empty text parts', async () => {
const result = convertToLanguageModelMessage({
role: 'assistant',
content: [
{
type: 'text',
text: '',
},
{
type: 'tool-call',
toolName: 'toolName',
toolCallId: 'toolCallId',
args: {},
},
],
describe('user message', () => {
describe('image parts', () => {
it('should convert image string https url to URL object', async () => {
const result = convertToLanguageModelMessage({
role: 'user',
content: [
{
type: 'image',
image: 'https://example.com/image.jpg',
},
],
});

expect(result).toEqual({
role: 'user',
content: [
{
type: 'image',
image: new URL('https://example.com/image.jpg'),
},
],
});
});

it('should convert image string data url to base64 content', async () => {
const result = convertToLanguageModelMessage({
role: 'user',
content: [
{
type: 'image',
image: '',
},
],
});

expect(result).toEqual({
role: 'user',
content: [
{
type: 'image',
image: new Uint8Array([116, 101, 115, 116]),
mimeType: 'image/jpg',
},
],
});
});
});
});

describe('assistant message', () => {
describe('text parts', () => {
it('should ignore empty text parts', async () => {
const result = convertToLanguageModelMessage({
role: 'assistant',
content: [
{
type: 'text',
text: '',
},
{
type: 'tool-call',
toolName: 'toolName',
toolCallId: 'toolCallId',
args: {},
},
],
});

expect(result).toEqual({
role: 'assistant',
content: [
{
type: 'tool-call',
args: {},
toolCallId: 'toolCallId',
toolName: 'toolName',
},
],
expect(result).toEqual({
role: 'assistant',
content: [
{
type: 'tool-call',
args: {},
toolCallId: 'toolCallId',
toolName: 'toolName',
},
],
});
});
});
});
Expand Down
49 changes: 49 additions & 0 deletions packages/core/core/prompt/convert-to-language-model-prompt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import { detectImageMimeType } from '../util/detect-image-mimetype';
import { convertDataContentToUint8Array } from './data-content';
import { ValidatedPrompt } from './get-validated-prompt';
import { InvalidMessageRoleError } from './invalid-message-role-error';
import { getErrorMessage } from '@ai-sdk/provider-utils';

export function convertToLanguageModelPrompt(
prompt: ValidatedPrompt,
Expand Down Expand Up @@ -80,6 +81,54 @@ export function convertToLanguageModelMessage(
};
}

// try to convert string image parts to urls
if (typeof part.image === 'string') {
try {
const url = new URL(part.image);

switch (url.protocol) {
case 'http:':
case 'https:': {
return {
type: 'image',
image: url,
mimeType: part.mimeType,
};
}
case 'data:': {
try {
const [header, base64Content] = part.image.split(',');
const mimeType = header.split(';')[0].split(':')[1];

if (mimeType == null || base64Content == null) {
throw new Error('Invalid data URL format');
}

return {
type: 'image',
image:
convertDataContentToUint8Array(base64Content),
mimeType,
};
} catch (error) {
throw new Error(
`Error processing data URL: ${getErrorMessage(
message,
)}`,
);
}
}
default: {
throw new Error(
`Unsupported URL protocol: ${url.protocol}`,
);
}
}
} catch (_ignored) {
// not a URL
}
}

const imageUint8 = convertDataContentToUint8Array(part.image);

return {
Expand Down

0 comments on commit 0612350

Please sign in to comment.