From 19a2ce7d2c549997e3d98ca66c689bfa7d8d32de Mon Sep 17 00:00:00 2001 From: Walter Korman Date: Tue, 7 Jan 2025 02:13:33 -0800 Subject: [PATCH] feat (ai/core): enhance generateImage, add initial Fireworks image support. (#4266) Co-authored-by: Lars Grammel --- .changeset/cuddly-kiwis-guess.md | 7 + .changeset/perfect-lobsters-guess.md | 5 + .changeset/poor-pets-obey.md | 6 + .../03-ai-sdk-core/35-image-generation.mdx | 77 ++++++- .../01-ai-sdk-core/10-generate-image.mdx | 13 ++ .../01-ai-sdk-providers/01-openai.mdx | 7 +- .../01-ai-sdk-providers/11-google-vertex.mdx | 19 +- .../01-ai-sdk-providers/26-fireworks.mdx | 52 ++++- .../ai-core/src/e2e/feature-test-suite.ts | 28 ++- examples/ai-core/src/e2e/fireworks.test.ts | 1 + .../ai-core/src/generate-image/fireworks.ts | 26 +++ .../src/generate-image/google-vertex.ts | 4 +- examples/ai-core/src/generate-image/openai.ts | 4 - .../generate-image/generate-image.test.ts | 41 +++- .../ai/core/generate-image/generate-image.ts | 47 +++- packages/ai/core/test/mock-image-model-v1.ts | 2 +- .../src/fireworks-image-model.test.ts | 209 ++++++++++++++++++ .../fireworks/src/fireworks-image-model.ts | 144 ++++++++++++ packages/fireworks/src/fireworks-provider.ts | 51 +++-- .../src/google-vertex-image-model.test.ts | 90 +++++++- .../src/google-vertex-image-model.ts | 32 ++- .../openai/src/openai-image-model.test.ts | 14 +- packages/openai/src/openai-image-model.ts | 34 ++- .../src/test/binary-test-server.ts | 67 ++++++ packages/provider-utils/src/test/index.ts | 1 + .../errors/unsupported-functionality-error.ts | 14 +- .../src/image-model/v1/image-model-v1.ts | 31 ++- 27 files changed, 935 insertions(+), 91 deletions(-) create mode 100644 .changeset/cuddly-kiwis-guess.md create mode 100644 .changeset/perfect-lobsters-guess.md create mode 100644 .changeset/poor-pets-obey.md create mode 100644 examples/ai-core/src/generate-image/fireworks.ts create mode 100644 packages/fireworks/src/fireworks-image-model.test.ts create mode 100644 packages/fireworks/src/fireworks-image-model.ts create mode 100644 packages/provider-utils/src/test/binary-test-server.ts diff --git a/.changeset/cuddly-kiwis-guess.md b/.changeset/cuddly-kiwis-guess.md new file mode 100644 index 000000000000..ecf8bb170af1 --- /dev/null +++ b/.changeset/cuddly-kiwis-guess.md @@ -0,0 +1,7 @@ +--- +'@ai-sdk/google-vertex': patch +'@ai-sdk/openai': patch +'ai': patch +--- + +feat (ai/core): add aspectRatio and seed options to generateImage diff --git a/.changeset/perfect-lobsters-guess.md b/.changeset/perfect-lobsters-guess.md new file mode 100644 index 000000000000..2009b380e69c --- /dev/null +++ b/.changeset/perfect-lobsters-guess.md @@ -0,0 +1,5 @@ +--- +'@ai-sdk/provider': patch +--- + +feat (provider): add message option to UnsupportedFunctionalityError diff --git a/.changeset/poor-pets-obey.md b/.changeset/poor-pets-obey.md new file mode 100644 index 000000000000..cd1d577fba9f --- /dev/null +++ b/.changeset/poor-pets-obey.md @@ -0,0 +1,6 @@ +--- +'@ai-sdk/fireworks': patch +'@ai-sdk/provider-utils': patch +--- + +feat (provider/fireworks): Add image model support. diff --git a/content/docs/03-ai-sdk-core/35-image-generation.mdx b/content/docs/03-ai-sdk-core/35-image-generation.mdx index 771857c2c2d0..f6a157f3f97e 100644 --- a/content/docs/03-ai-sdk-core/35-image-generation.mdx +++ b/content/docs/03-ai-sdk-core/35-image-generation.mdx @@ -17,7 +17,6 @@ import { openai } from '@ai-sdk/openai'; const { image } = await generateImage({ model: openai.image('dall-e-3'), prompt: 'Santa Claus driving a Cadillac', - size: '1024x1024', }); ``` @@ -28,11 +27,50 @@ const base64 = image.base64; // base64 image data const uint8Array = image.uint8Array; // Uint8Array image data ``` +### Size and Aspect Ratio + +Depending on the model, you can either specify the size or the aspect ratio. + +##### Size + +The size is specified as a string in the format `{width}x{height}`. +Models only support a few sizes, and the supported sizes are different for each model and provider. + +```tsx highlight={"7"} +import { experimental_generateImage as generateImage } from 'ai'; +import { openai } from '@ai-sdk/openai'; + +const { image } = await generateImage({ + model: openai.image('dall-e-3'), + prompt: 'Santa Claus driving a Cadillac', + size: '1024x1024', +}); +``` + +##### Aspect Ratio + +The aspect ratio is specified as a string in the format `{width}:{height}`. +Models only support a few aspect ratios, and the supported aspect ratios are different for each model and provider. + +```tsx highlight={"7"} +import { experimental_generateImage as generateImage } from 'ai'; +import { vertex } from '@ai-sdk/google-vertex'; + +const { image } = await generateImage({ + model: vertex.image('imagen-3.0-generate-001'), + prompt: 'Santa Claus driving a Cadillac', + aspectRatio: '16:9', +}); +``` + ### Generating Multiple Images `generateImage` also supports generating multiple images at once for models that support it: -```tsx highlight={"4"} +```tsx highlight={"7"} +import { experimental_generateImage as generateImage } from 'ai'; +import { openai } from '@ai-sdk/openai'; + const { images } = await generateImage({ model: openai.image('dall-e-2'), prompt: 'Santa Claus driving a Cadillac', @@ -40,6 +78,22 @@ const { images } = await generateImage({ }); ``` +### Providing a Seed + +You can provide a seed to the `generateImage` function to control the output of the image generation process. +If supported by the model, the same seed will always produce the same image. + +```tsx highlight={"7"} +import { experimental_generateImage as generateImage } from 'ai'; +import { openai } from '@ai-sdk/openai'; + +const { image } = await generateImage({ + model: openai.image('dall-e-3'), + prompt: 'Santa Claus driving a Cadillac', + seed: 1234567890, +}); +``` + ### Provider-specific Settings Image models often have provider- or even model-specific settings. @@ -47,7 +101,10 @@ You can pass such settings to the `generateImage` function using the `providerOptions` parameter. The options for the provider (`openai` in the example below) become request body properties. -```tsx highlight={"5-7"} +```tsx highlight={"9"} +import { experimental_generateImage as generateImage } from 'ai'; +import { openai } from '@ai-sdk/openai'; + const { image } = await generateImage({ model: openai.image('dall-e-3'), prompt: 'Santa Claus driving a Cadillac', @@ -93,9 +150,11 @@ const { image } = await generateImage({ ## Image Models -| Provider | Model | Supported Sizes | -| ----------------------------------------------------------------------- | ------------------------------ | ------------------------------------------------------------------------------------------------------------- | -| [Google Vertex](/providers/ai-sdk-providers/google-vertex#image-models) | `imagen-3.0-generate-001` | See [aspect ratios](https://cloud.google.com/vertex-ai/generative-ai/docs/image/generate-images#aspect-ratio) | -| [Google Vertex](/providers/ai-sdk-providers/google-vertex#image-models) | `imagen-3.0-fast-generate-001` | See [aspect ratios](https://cloud.google.com/vertex-ai/generative-ai/docs/image/generate-images#aspect-ratio) | -| [OpenAI](/providers/ai-sdk-providers/openai#image-models) | `dall-e-3` | 1024x1024, 1792x1024, 1024x1792 | -| [OpenAI](/providers/ai-sdk-providers/openai#image-models) | `dall-e-2` | 256x256, 512x512, 1024x1024 | +| Provider | Model | Sizes | Aspect Ratios | +| ----------------------------------------------------------------------- | ---------------------------------------------- | ------------------------------- | ----------------------------------------------- | +| [Google Vertex](/providers/ai-sdk-providers/google-vertex#image-models) | `imagen-3.0-generate-001` | Use aspect ratio | 1:1, 3:4, 4:3, 9:16, 16:9 | +| [Google Vertex](/providers/ai-sdk-providers/google-vertex#image-models) | `imagen-3.0-fast-generate-001` | Use aspect ratio | 1:1, 3:4, 4:3, 9:16, 16:9 | +| [OpenAI](/providers/ai-sdk-providers/openai#image-models) | `dall-e-3` | 1024x1024, 1792x1024, 1024x1792 | use size | +| [OpenAI](/providers/ai-sdk-providers/openai#image-models) | `dall-e-2` | 256x256, 512x512, 1024x1024 | use size | +| [Fireworks](/providers/ai-sdk-providers/fireworks#image-models) | `accounts/fireworks/models/flux-1-dev-fp8` | Use aspect ratio | 1:1, 2:3, 3:2, 4:5, 5:4, 16:9, 9:16, 9:21, 21:9 | +| [Fireworks](/providers/ai-sdk-providers/fireworks#image-models) | `accounts/fireworks/models/flux-1-schnell-fp8` | Use aspect ratio | 1:1, 2:3, 3:2, 4:5, 5:4, 16:9, 9:16, 9:21, 21:9 | diff --git a/content/docs/07-reference/01-ai-sdk-core/10-generate-image.mdx b/content/docs/07-reference/01-ai-sdk-core/10-generate-image.mdx index 52fe2bc03a90..7f2d25ae7926 100644 --- a/content/docs/07-reference/01-ai-sdk-core/10-generate-image.mdx +++ b/content/docs/07-reference/01-ai-sdk-core/10-generate-image.mdx @@ -61,6 +61,19 @@ console.log(images); description: 'Size of the images to generate. Format: `{width}x{height}`.', }, + { + name: 'aspectRatio', + type: 'string', + isOptional: true, + description: + 'Aspect ratio of the images to generate. Format: `{width}:{height}`.', + }, + { + name: 'seed', + type: 'number', + isOptional: true, + description: 'Seed for the image generation.', + }, { name: 'providerOptions', type: 'Record>', diff --git a/content/providers/01-ai-sdk-providers/01-openai.mdx b/content/providers/01-ai-sdk-providers/01-openai.mdx index 4b1637521396..ca2d5a8459d2 100644 --- a/content/providers/01-ai-sdk-providers/01-openai.mdx +++ b/content/providers/01-ai-sdk-providers/01-openai.mdx @@ -618,9 +618,14 @@ using the `.image()` factory method. const model = openai.image('dall-e-3'); ``` + + Dall-E models do not support the `aspectRatio` parameter. Use the `size` + parameter instead. + + ### Model Capabilities -| Model | Supported Sizes | +| Model | Sizes | | ---------- | ------------------------------- | | `dall-e-3` | 1024x1024, 1792x1024, 1024x1792 | | `dall-e-2` | 256x256, 512x512, 1024x1024 | diff --git a/content/providers/01-ai-sdk-providers/11-google-vertex.mdx b/content/providers/01-ai-sdk-providers/11-google-vertex.mdx index b7344cc493c1..22144033d5a8 100644 --- a/content/providers/01-ai-sdk-providers/11-google-vertex.mdx +++ b/content/providers/01-ai-sdk-providers/11-google-vertex.mdx @@ -564,8 +564,6 @@ The following optional settings are available for Google Vertex AI embedding mod You can create [Imagen](https://cloud.google.com/vertex-ai/generative-ai/docs/image/overview) models that call the [Imagen on Vertex AI API](https://cloud.google.com/vertex-ai/generative-ai/docs/image/generate-images) using the `.image()` factory method. For more on image generation with the AI SDK see [generateImage()](/docs/reference/ai-sdk-core/generate-image). -Note that Imagen does not support an explicit size parameter. Instead, size is driven by the [aspect ratio](https://cloud.google.com/vertex-ai/generative-ai/docs/image/generate-images#aspect-ratio) of the input image. - ```ts import { vertex } from '@ai-sdk/google-vertex'; import { experimental_generateImage as generateImage } from 'ai'; @@ -573,18 +571,21 @@ import { experimental_generateImage as generateImage } from 'ai'; const { image } = await generateImage({ model: vertex.image('imagen-3.0-generate-001'), prompt: 'A futuristic cityscape at sunset', - providerOptions: { - vertex: { aspectRatio: '16:9' }, - }, + aspectRatio: '16:9', }); ``` + + Imagen models do not support the `size` parameter. Use the `aspectRatio` + parameter instead. + + #### Model Capabilities -| Model | Supported Sizes | -| ------------------------------ | ------------------------------------------------------------------------------------------------------------- | -| `imagen-3.0-generate-001` | See [aspect ratios](https://cloud.google.com/vertex-ai/generative-ai/docs/image/generate-images#aspect-ratio) | -| `imagen-3.0-fast-generate-001` | See [aspect ratios](https://cloud.google.com/vertex-ai/generative-ai/docs/image/generate-images#aspect-ratio) | +| Model | Aspect Ratios | +| ------------------------------ | ------------------------- | +| `imagen-3.0-generate-001` | 1:1, 3:4, 4:3, 9:16, 16:9 | +| `imagen-3.0-fast-generate-001` | 1:1, 3:4, 4:3, 9:16, 16:9 | ## Google Vertex Anthropic Provider Usage diff --git a/content/providers/01-ai-sdk-providers/26-fireworks.mdx b/content/providers/01-ai-sdk-providers/26-fireworks.mdx index 3a8aa783d66f..16cc989f73b1 100644 --- a/content/providers/01-ai-sdk-providers/26-fireworks.mdx +++ b/content/providers/01-ai-sdk-providers/26-fireworks.mdx @@ -87,7 +87,7 @@ const { text } = await generateText({ Fireworks language models can also be used in the `streamText` and `streamUI` functions (see [AI SDK Core](/docs/ai-sdk-core) and [AI SDK RSC](/docs/ai-sdk-rsc)). -## Completion Models +### Completion Models You can create models that call the Fireworks completions API using the `.completion()` factory method: @@ -95,17 +95,7 @@ You can create models that call the Fireworks completions API using the `.comple const model = fireworks.completion('accounts/fireworks/models/firefunction-v1'); ``` -## Embedding Models - -You can create models that call the Fireworks embeddings API using the `.textEmbeddingModel()` factory method: - -```ts -const model = fireworks.textEmbeddingModel( - 'accounts/fireworks/models/nomic-embed-text-v1', -); -``` - -## Model Capabilities +### Model Capabilities | Model | Image Input | Object Generation | Tool Usage | Tool Streaming | | ---------------------------------------------------------- | ------------------- | ------------------- | ------------------- | ------------------- | @@ -124,3 +114,41 @@ const model = fireworks.textEmbeddingModel( The table above lists popular models. Please see the [Fireworks models page](https://fireworks.ai/models) for a full list of available models. + +## Embedding Models + +You can create models that call the Fireworks embeddings API using the `.textEmbeddingModel()` factory method: + +```ts +const model = fireworks.textEmbeddingModel( + 'accounts/fireworks/models/nomic-embed-text-v1', +); +``` + +## Image Models + +You can create Fireworks image models using the `.image()` factory method. +For more on image generation with the AI SDK see [generateImage()](/docs/reference/ai-sdk-core/generate-image). + +```ts +import { fireworks } from '@ai-sdk/fireworks'; +import { experimental_generateImage as generateImage } from 'ai'; + +const { image } = await generateImage({ + model: fireworks.image('accounts/fireworks/models/flux-1-dev-fp8'), + prompt: 'A futuristic cityscape at sunset', + aspectRatio: '16:9', +}); +``` + + + Fireworks models do not support the `size` parameter. Use the `aspectRatio` + parameter instead. + + +### Model Capabilities + +| Model | Aspect Ratios | +| ---------------------------------------------- | ----------------------------------------------- | +| `accounts/fireworks/models/flux-1-dev-fp8` | 1:1, 2:3, 3:2, 4:5, 5:4, 16:9, 9:16, 9:21, 21:9 | +| `accounts/fireworks/models/flux-1-schnell-fp8` | 1:1, 2:3, 3:2, 4:5, 5:4, 16:9, 9:16, 9:21, 21:9 | diff --git a/examples/ai-core/src/e2e/feature-test-suite.ts b/examples/ai-core/src/e2e/feature-test-suite.ts index 13820446de62..34701b772334 100644 --- a/examples/ai-core/src/e2e/feature-test-suite.ts +++ b/examples/ai-core/src/e2e/feature-test-suite.ts @@ -1,5 +1,6 @@ import { z } from 'zod'; import { + experimental_generateImage as generateImage, generateText, generateObject, streamText, @@ -10,12 +11,17 @@ import { } from 'ai'; import fs from 'fs'; import { describe, expect, it, vi } from 'vitest'; -import type { EmbeddingModelV1, LanguageModelV1 } from '@ai-sdk/provider'; +import type { + EmbeddingModelV1, + ImageModelV1, + LanguageModelV1, +} from '@ai-sdk/provider'; export interface ModelVariants { invalidModel?: LanguageModelV1; languageModels?: LanguageModelV1[]; embeddingModels?: EmbeddingModelV1[]; + imageModels?: ImageModelV1[]; } export interface TestSuiteOptions { @@ -369,5 +375,25 @@ export function createFeatureTestSuite({ ); } }); + + describe.each(createModelObjects(models.imageModels))( + 'Image Model: $modelId', + ({ model }) => { + it('should generate an image', async () => { + const result = await generateImage({ + model, + prompt: 'A cute cartoon cat', + }); + + // Verify we got a base64 string back + expect(result.image.base64).toBeTruthy(); + expect(typeof result.image.base64).toBe('string'); + + // Check the decoded length is reasonable (at least 10KB) + const decoded = Buffer.from(result.image.base64, 'base64'); + expect(decoded.length).toBeGreaterThan(10 * 1024); + }); + }, + ); }; } diff --git a/examples/ai-core/src/e2e/fireworks.test.ts b/examples/ai-core/src/e2e/fireworks.test.ts index af100d555956..880b062b5c47 100644 --- a/examples/ai-core/src/e2e/fireworks.test.ts +++ b/examples/ai-core/src/e2e/fireworks.test.ts @@ -21,6 +21,7 @@ createFeatureTestSuite({ embeddingModels: [ provider.textEmbeddingModel('nomic-ai/nomic-embed-text-v1.5'), ], + imageModels: [provider.image('accounts/fireworks/models/flux-1-dev-fp8')], }, timeout: 10000, customAssertions: { diff --git a/examples/ai-core/src/generate-image/fireworks.ts b/examples/ai-core/src/generate-image/fireworks.ts new file mode 100644 index 000000000000..109bd15426b0 --- /dev/null +++ b/examples/ai-core/src/generate-image/fireworks.ts @@ -0,0 +1,26 @@ +import 'dotenv/config'; +import { fireworks } from '@ai-sdk/fireworks'; +import { experimental_generateImage as generateImage } from 'ai'; +import fs from 'fs'; + +async function main() { + const { image } = await generateImage({ + model: fireworks.image('accounts/fireworks/models/flux-1-dev-fp8'), + prompt: 'A burrito launched through a tunnel', + aspectRatio: '4:3', + seed: 0, // 0 is random seed for this model + providerOptions: { + fireworks: { + // https://fireworks.ai/models/fireworks/flux-1-dev-fp8/playground + guidance_scale: 10, + num_inference_steps: 10, + }, + }, + }); + + const filename = `image-${Date.now()}.png`; + fs.writeFileSync(filename, image.uint8Array); + console.log(`Image saved to ${filename}`); +} + +main().catch(console.error); diff --git a/examples/ai-core/src/generate-image/google-vertex.ts b/examples/ai-core/src/generate-image/google-vertex.ts index 699ead66c91b..faa8eb458a28 100644 --- a/examples/ai-core/src/generate-image/google-vertex.ts +++ b/examples/ai-core/src/generate-image/google-vertex.ts @@ -7,9 +7,11 @@ async function main() { const { image } = await generateImage({ model: vertex.image('imagen-3.0-generate-001'), prompt: 'A burrito launched through a tunnel', + aspectRatio: '1:1', providerOptions: { vertex: { - aspectRatio: '16:9', + // https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/imagen-api#parameter_list + addWatermark: false, }, }, }); diff --git a/examples/ai-core/src/generate-image/openai.ts b/examples/ai-core/src/generate-image/openai.ts index 98a420e0c918..d035a33586e2 100644 --- a/examples/ai-core/src/generate-image/openai.ts +++ b/examples/ai-core/src/generate-image/openai.ts @@ -7,10 +7,6 @@ async function main() { const { image } = await generateImage({ model: openai.image('dall-e-3'), prompt: 'Santa Claus driving a Cadillac', - size: '1024x1024', - providerOptions: { - openai: { style: 'vivid', quality: 'hd' }, - }, }); const filename = `image-${Date.now()}.png`; diff --git a/packages/ai/core/generate-image/generate-image.test.ts b/packages/ai/core/generate-image/generate-image.test.ts index 51902348a410..7b815970bea9 100644 --- a/packages/ai/core/generate-image/generate-image.test.ts +++ b/packages/ai/core/generate-image/generate-image.test.ts @@ -1,7 +1,10 @@ import { ImageModelV1 } from '@ai-sdk/provider'; import { MockImageModelV1 } from '../test/mock-image-model-v1'; import { generateImage } from './generate-image'; -import { convertBase64ToUint8Array } from '@ai-sdk/provider-utils'; +import { + convertBase64ToUint8Array, + convertUint8ArrayToBase64, +} from '@ai-sdk/provider-utils'; const prompt = 'sunny day at the beach'; @@ -21,6 +24,8 @@ describe('generateImage', () => { }), prompt, size: '1024x1024', + aspectRatio: '16:9', + seed: 12345, providerOptions: { openai: { style: 'vivid' } }, headers: { 'custom-request-header': 'request-header-value' }, abortSignal, @@ -30,12 +35,46 @@ describe('generateImage', () => { n: 1, prompt, size: '1024x1024', + aspectRatio: '16:9', + seed: 12345, providerOptions: { openai: { style: 'vivid' } }, headers: { 'custom-request-header': 'request-header-value' }, abortSignal, }); }); + it('should handle base64 strings', async () => { + const base64String = 'SGVsbG8gV29ybGQ='; + const result = await generateImage({ + model: new MockImageModelV1({ + doGenerate: async () => ({ images: [base64String] }), + }), + prompt, + }); + expect(result.images).toStrictEqual([ + { + base64: base64String, + uint8Array: convertBase64ToUint8Array(base64String), + }, + ]); + }); + + it('should handle Uint8Arrays', async () => { + const uint8Array = new Uint8Array([72, 101, 108, 108, 111]); + const result = await generateImage({ + model: new MockImageModelV1({ + doGenerate: async () => ({ images: [uint8Array] }), + }), + prompt, + }); + expect(result.images).toStrictEqual([ + { + base64: convertUint8ArrayToBase64(uint8Array), + uint8Array: uint8Array, + }, + ]); + }); + it('should return generated images', async () => { const base64Images = [ 'SGVsbG8gV29ybGQ=', // "Hello World" in base64 diff --git a/packages/ai/core/generate-image/generate-image.ts b/packages/ai/core/generate-image/generate-image.ts index c1e11c5ff6da..ac8b1f086cd4 100644 --- a/packages/ai/core/generate-image/generate-image.ts +++ b/packages/ai/core/generate-image/generate-image.ts @@ -1,5 +1,8 @@ import { ImageModelV1, JSONValue } from '@ai-sdk/provider'; -import { convertBase64ToUint8Array } from '@ai-sdk/provider-utils'; +import { + convertBase64ToUint8Array, + convertUint8ArrayToBase64, +} from '@ai-sdk/provider-utils'; import { prepareRetries } from '../prompt/prepare-retries'; import { GeneratedImage, GenerateImageResult } from './generate-image-result'; @@ -10,6 +13,8 @@ Generates images using an image model. @param prompt - The prompt that should be used to generate the image. @param n - Number of images to generate. Default: 1. @param size - Size of the images to generate. Must have the format `{width}x{height}`. +@param aspectRatio - Aspect ratio of the images to generate. Must have the format `{width}:{height}`. +@param seed - Seed for the image generation. @param providerOptions - Additional provider-specific options that are passed through to the provider as body parameters. @param maxRetries - Maximum number of retries. Set to 0 to disable retries. Default: 2. @@ -23,6 +28,8 @@ export async function generateImage({ prompt, n, size, + aspectRatio, + seed, providerOptions, maxRetries: maxRetriesArg, abortSignal, @@ -44,10 +51,20 @@ Number of images to generate. n?: number; /** -Size of the images to generate. Must have the format `{width}x{height}`. +Size of the images to generate. Must have the format `{width}x{height}`. If not provided, the default size will be used. */ size?: `${number}x${number}`; + /** +Aspect ratio of the images to generate. Must have the format `{width}:{height}`. If not provided, the default aspect ratio will be used. + */ + aspectRatio?: `${number}:${number}`; + + /** +Seed for the image generation. If not provided, the default seed will be used. + */ + seed?: number; + /** Additional provider-specific options that are passed through to the provider as body parameters. @@ -91,23 +108,33 @@ Only applicable for HTTP-based providers. abortSignal, headers, size, + aspectRatio, + seed, providerOptions: providerOptions ?? {}, }), ); - return new DefaultGenerateImageResult({ base64Images: images }); + return new DefaultGenerateImageResult({ images }); } class DefaultGenerateImageResult implements GenerateImageResult { readonly images: Array; - constructor(options: { base64Images: Array }) { - this.images = options.base64Images.map(base64 => ({ - base64, - get uint8Array() { - return convertBase64ToUint8Array(this.base64); - }, - })); + constructor(options: { images: Array | Array }) { + this.images = options.images.map(image => { + const isUint8Array = image instanceof Uint8Array; + return { + // lazy conversion to base64 inside get to avoid unnecessary conversion overhead: + get base64() { + return isUint8Array ? convertUint8ArrayToBase64(image) : image; + }, + + // lazy conversion to uint8array inside get to avoid unnecessary conversion overhead: + get uint8Array() { + return isUint8Array ? image : convertBase64ToUint8Array(image); + }, + }; + }); } get image() { diff --git a/packages/ai/core/test/mock-image-model-v1.ts b/packages/ai/core/test/mock-image-model-v1.ts index 74997ad1f488..706b835925a7 100644 --- a/packages/ai/core/test/mock-image-model-v1.ts +++ b/packages/ai/core/test/mock-image-model-v1.ts @@ -3,9 +3,9 @@ import { notImplemented } from './not-implemented'; export class MockImageModelV1 implements ImageModelV1 { readonly specificationVersion = 'v1'; - readonly provider: ImageModelV1['provider']; readonly modelId: ImageModelV1['modelId']; + readonly maxImagesPerCall = 1; doGenerate: ImageModelV1['doGenerate']; diff --git a/packages/fireworks/src/fireworks-image-model.test.ts b/packages/fireworks/src/fireworks-image-model.test.ts new file mode 100644 index 000000000000..119bff09b6c2 --- /dev/null +++ b/packages/fireworks/src/fireworks-image-model.test.ts @@ -0,0 +1,209 @@ +import { APICallError } from '@ai-sdk/provider'; +import { BinaryTestServer } from '@ai-sdk/provider-utils/test'; +import { FireworksImageModel } from './fireworks-image-model'; +import { describe, it, expect } from 'vitest'; +import { UnsupportedFunctionalityError } from '@ai-sdk/provider'; + +const prompt = 'A cute baby sea otter'; + +const model = new FireworksImageModel( + 'accounts/fireworks/models/flux-1-dev-fp8', + { + provider: 'fireworks', + baseURL: 'https://api.example.com', + headers: () => ({ 'api-key': 'test-key' }), + }, +); + +describe('FireworksImageModel', () => { + describe('doGenerate', () => { + const server = new BinaryTestServer( + 'https://api.example.com/workflows/accounts/fireworks/models/flux-1-dev-fp8/text_to_image', + ); + + server.setupTestEnvironment(); + + function prepareBinaryResponse() { + const mockImageBuffer = Buffer.from('mock-image-data'); + server.responseBody = mockImageBuffer; + } + + it('should pass the correct parameters including aspect ratio and seed', async () => { + prepareBinaryResponse(); + + await model.doGenerate({ + prompt, + n: 1, + size: undefined, + aspectRatio: '16:9', + seed: 42, + providerOptions: { fireworks: { additional_param: 'value' } }, + }); + + expect(await server.getRequestBodyJson()).toStrictEqual({ + prompt, + aspect_ratio: '16:9', + seed: 42, + additional_param: 'value', + }); + }); + + it('should pass headers', async () => { + prepareBinaryResponse(); + + const modelWithHeaders = new FireworksImageModel( + 'accounts/fireworks/models/flux-1-dev-fp8', + { + provider: 'fireworks', + baseURL: 'https://api.example.com', + headers: () => ({ + 'Custom-Provider-Header': 'provider-header-value', + }), + }, + ); + + await modelWithHeaders.doGenerate({ + prompt, + n: 1, + size: undefined, + aspectRatio: undefined, + seed: undefined, + providerOptions: {}, + headers: { + 'Custom-Request-Header': 'request-header-value', + }, + }); + + const requestHeaders = await server.getRequestHeaders(); + + expect(requestHeaders).toStrictEqual({ + 'content-type': 'application/json', + 'custom-provider-header': 'provider-header-value', + 'custom-request-header': 'request-header-value', + }); + }); + + it('should return binary image data', async () => { + const mockImageBuffer = Buffer.from('mock-image-data'); + server.responseBody = mockImageBuffer; + + const result = await model.doGenerate({ + prompt, + n: 1, + size: undefined, + aspectRatio: undefined, + seed: undefined, + providerOptions: {}, + }); + + expect(result.images).toHaveLength(1); + expect(result.images[0]).toBeInstanceOf(Uint8Array); + expect(Buffer.from(result.images[0])).toEqual(mockImageBuffer); + }); + + it('should handle empty response body', async () => { + server.responseBody = null; + + await expect( + model.doGenerate({ + prompt, + n: 1, + size: undefined, + aspectRatio: undefined, + seed: undefined, + providerOptions: {}, + }), + ).rejects.toThrow(APICallError); + + await expect( + model.doGenerate({ + prompt, + n: 1, + size: undefined, + aspectRatio: undefined, + seed: undefined, + providerOptions: {}, + }), + ).rejects.toMatchObject({ + message: 'Response body is empty', + statusCode: 200, + url: 'https://api.example.com/workflows/accounts/fireworks/models/flux-1-dev-fp8/text_to_image', + requestBodyValues: { + prompt: 'A cute baby sea otter', + }, + }); + }); + + it('should handle API errors', async () => { + server.responseStatus = 400; + server.responseBody = Buffer.from('Bad Request'); + + await expect( + model.doGenerate({ + prompt, + n: 1, + size: undefined, + aspectRatio: undefined, + seed: undefined, + providerOptions: {}, + }), + ).rejects.toThrow(APICallError); + + await expect( + model.doGenerate({ + prompt, + n: 1, + size: undefined, + aspectRatio: undefined, + seed: undefined, + providerOptions: {}, + }), + ).rejects.toMatchObject({ + message: 'Bad Request', + statusCode: 400, + url: 'https://api.example.com/workflows/accounts/fireworks/models/flux-1-dev-fp8/text_to_image', + requestBodyValues: { + prompt: 'A cute baby sea otter', + }, + responseBody: 'Bad Request', + }); + }); + + it('should throw error when requesting more than one image', async () => { + await expect( + model.doGenerate({ + prompt, + n: 2, + size: undefined, + aspectRatio: undefined, + seed: undefined, + providerOptions: {}, + }), + ).rejects.toThrowError( + new UnsupportedFunctionalityError({ + functionality: 'generate multiple images', + message: `This model does not support generating more than 1 images at a time.`, + }), + ); + }); + + it('should throw error when specifying image size', async () => { + await expect( + model.doGenerate({ + prompt, + n: 1, + size: '512x512', + aspectRatio: undefined, + seed: undefined, + providerOptions: {}, + }), + ).rejects.toThrowError( + new UnsupportedFunctionalityError({ + functionality: 'image size', + message: + 'This model does not support the `size` option. Use `aspectRatio` instead.', + }), + ); + }); + }); +}); diff --git a/packages/fireworks/src/fireworks-image-model.ts b/packages/fireworks/src/fireworks-image-model.ts new file mode 100644 index 000000000000..8a880cb3dd39 --- /dev/null +++ b/packages/fireworks/src/fireworks-image-model.ts @@ -0,0 +1,144 @@ +import { + APICallError, + ImageModelV1, + UnsupportedFunctionalityError, +} from '@ai-sdk/provider'; +import { + combineHeaders, + extractResponseHeaders, + FetchFunction, + postJsonToApi, + ResponseHandler, +} from '@ai-sdk/provider-utils'; + +// https://fireworks.ai/models?type=image +export type FireworksImageModelId = + | 'accounts/fireworks/models/flux-1-dev-fp8' + | 'accounts/fireworks/models/flux-1-schnell-fp8' + | (string & {}); + +interface FireworksImageModelConfig { + provider: string; + baseURL: string; + headers: () => Record; + fetch?: FetchFunction; +} + +const createBinaryResponseHandler = + (): ResponseHandler => + async ({ response, url, requestBodyValues }) => { + const responseHeaders = extractResponseHeaders(response); + + if (!response.body) { + throw new APICallError({ + message: 'Response body is empty', + url, + requestBodyValues, + statusCode: response.status, + responseHeaders, + responseBody: undefined, + }); + } + + try { + const buffer = await response.arrayBuffer(); + return { + responseHeaders, + value: buffer, + }; + } catch (error) { + throw new APICallError({ + message: 'Failed to read response as array buffer', + url, + requestBodyValues, + statusCode: response.status, + responseHeaders, + responseBody: undefined, + cause: error, + }); + } + }; + +const statusCodeErrorResponseHandler: ResponseHandler = async ({ + response, + url, + requestBodyValues, +}) => { + const responseHeaders = extractResponseHeaders(response); + const responseBody = await response.text(); + + return { + responseHeaders, + value: new APICallError({ + message: response.statusText, + url, + requestBodyValues: requestBodyValues as Record, + statusCode: response.status, + responseHeaders, + responseBody, + }), + }; +}; + +export class FireworksImageModel implements ImageModelV1 { + readonly specificationVersion = 'v1'; + + get provider(): string { + return this.config.provider; + } + + readonly maxImagesPerCall = 1; + + constructor( + readonly modelId: FireworksImageModelId, + private config: FireworksImageModelConfig, + ) {} + + async doGenerate({ + prompt, + n, + size, + aspectRatio, + seed, + providerOptions, + headers, + abortSignal, + }: Parameters[0]): Promise< + Awaited> + > { + if (size != null) { + throw new UnsupportedFunctionalityError({ + functionality: 'image size', + message: + 'This model does not support the `size` option. Use `aspectRatio` instead.', + }); + } + + if (n > this.maxImagesPerCall) { + throw new UnsupportedFunctionalityError({ + functionality: `generate more than ${this.maxImagesPerCall} images`, + message: `This model does not support generating more than ${this.maxImagesPerCall} images at a time.`, + }); + } + + const url = `${this.config.baseURL}/workflows/${this.modelId}/text_to_image`; + const body = { + prompt, + aspect_ratio: aspectRatio, + seed, + ...(providerOptions.fireworks ?? {}), + }; + + const { value: response } = await postJsonToApi({ + url, + headers: combineHeaders(this.config.headers(), headers), + body, + failedResponseHandler: statusCodeErrorResponseHandler, + successfulResponseHandler: createBinaryResponseHandler(), + abortSignal, + fetch: this.config.fetch, + }); + + return { images: [new Uint8Array(response)] }; + } +} diff --git a/packages/fireworks/src/fireworks-provider.ts b/packages/fireworks/src/fireworks-provider.ts index d28ddca8922b..92d1501a3405 100644 --- a/packages/fireworks/src/fireworks-provider.ts +++ b/packages/fireworks/src/fireworks-provider.ts @@ -4,7 +4,11 @@ import { OpenAICompatibleEmbeddingModel, ProviderErrorStructure, } from '@ai-sdk/openai-compatible'; -import { EmbeddingModelV1, LanguageModelV1 } from '@ai-sdk/provider'; +import { + EmbeddingModelV1, + ImageModelV1, + LanguageModelV1, +} from '@ai-sdk/provider'; import { FetchFunction, loadApiKey, @@ -23,6 +27,10 @@ import { FireworksEmbeddingModelId, FireworksEmbeddingSettings, } from './fireworks-embedding-settings'; +import { + FireworksImageModel, + FireworksImageModelId, +} from './fireworks-image-model'; export type FireworksErrorData = z.infer; @@ -87,14 +95,19 @@ Creates a text embedding model for text generation. modelId: FireworksEmbeddingModelId, settings?: FireworksEmbeddingSettings, ): EmbeddingModelV1; + + /** +Creates a model for image generation. +*/ + image(modelId: FireworksImageModelId): ImageModelV1; } +const defaultBaseURL = 'https://api.fireworks.ai/inference/v1'; + export function createFireworks( options: FireworksProviderSettings = {}, ): FireworksProvider { - const baseURL = withoutTrailingSlash( - options.baseURL ?? 'https://api.fireworks.ai/inference/v1', - ); + const baseURL = withoutTrailingSlash(options.baseURL ?? defaultBaseURL); const getHeaders = () => ({ Authorization: `Bearer ${loadApiKey({ apiKey: options.apiKey, @@ -105,11 +118,10 @@ export function createFireworks( }); interface CommonModelConfig { - provider: `fireworks.${string}`; + provider: string; url: ({ path }: { path: string }) => string; headers: () => Record; fetch?: FetchFunction; - errorStructure?: ProviderErrorStructure; } const getCommonModelConfig = (modelType: string): CommonModelConfig => ({ @@ -117,7 +129,6 @@ export function createFireworks( url: ({ path }) => `${baseURL}${path}`, headers: getHeaders, fetch: options.fetch, - errorStructure: fireworksErrorStructure, }); const createChatModel = ( @@ -126,6 +137,7 @@ export function createFireworks( ) => { return new OpenAICompatibleChatLanguageModel(modelId, settings, { ...getCommonModelConfig('chat'), + errorStructure: fireworksErrorStructure, defaultObjectGenerationMode: 'json', }); }; @@ -134,21 +146,25 @@ export function createFireworks( modelId: FireworksCompletionModelId, settings: FireworksCompletionSettings = {}, ) => - new OpenAICompatibleCompletionLanguageModel( - modelId, - settings, - getCommonModelConfig('completion'), - ); + new OpenAICompatibleCompletionLanguageModel(modelId, settings, { + ...getCommonModelConfig('completion'), + errorStructure: fireworksErrorStructure, + }); const createTextEmbeddingModel = ( modelId: FireworksEmbeddingModelId, settings: FireworksEmbeddingSettings = {}, ) => - new OpenAICompatibleEmbeddingModel( - modelId, - settings, - getCommonModelConfig('embedding'), - ); + new OpenAICompatibleEmbeddingModel(modelId, settings, { + ...getCommonModelConfig('embedding'), + errorStructure: fireworksErrorStructure, + }); + + const createImageModel = (modelId: FireworksImageModelId) => + new FireworksImageModel(modelId, { + ...getCommonModelConfig('image'), + baseURL: baseURL ?? defaultBaseURL, + }); const provider = ( modelId: FireworksChatModelId, @@ -158,6 +174,7 @@ export function createFireworks( provider.completionModel = createCompletionModel; provider.chatModel = createChatModel; provider.textEmbeddingModel = createTextEmbeddingModel; + provider.image = createImageModel; return provider as FireworksProvider; } diff --git a/packages/google-vertex/src/google-vertex-image-model.test.ts b/packages/google-vertex/src/google-vertex-image-model.test.ts index e45ded4f6fd8..f28295cfd7ac 100644 --- a/packages/google-vertex/src/google-vertex-image-model.test.ts +++ b/packages/google-vertex/src/google-vertex-image-model.test.ts @@ -1,6 +1,7 @@ import { JsonTestServer } from '@ai-sdk/provider-utils/test'; +import { describe, expect, it } from 'vitest'; import { GoogleVertexImageModel } from './google-vertex-image-model'; -import { describe, it, expect, vi } from 'vitest'; +import { UnsupportedFunctionalityError } from '@ai-sdk/provider'; const prompt = 'A cute baby sea otter'; @@ -34,6 +35,8 @@ describe('GoogleVertexImageModel', () => { prompt, n: 2, size: undefined, + aspectRatio: undefined, + seed: undefined, providerOptions: { vertex: { aspectRatio: '1:1' } }, }); @@ -64,6 +67,8 @@ describe('GoogleVertexImageModel', () => { prompt, n: 2, size: undefined, + aspectRatio: undefined, + seed: undefined, providerOptions: {}, headers: { 'Custom-Request-Header': 'request-header-value', @@ -86,6 +91,8 @@ describe('GoogleVertexImageModel', () => { prompt, n: 2, size: undefined, + aspectRatio: undefined, + seed: undefined, providerOptions: {}, }); @@ -103,9 +110,17 @@ describe('GoogleVertexImageModel', () => { prompt: 'test prompt', n: 1, size: '1024x1024', + aspectRatio: undefined, + seed: undefined, providerOptions: {}, }), - ).rejects.toThrow(/Google Vertex does not support the `size` option./); + ).rejects.toThrow( + new UnsupportedFunctionalityError({ + functionality: 'image size', + message: + 'This model does not support the `size` option. Use `aspectRatio` instead.', + }), + ); }); it('sends aspect ratio in the request', async () => { @@ -115,6 +130,8 @@ describe('GoogleVertexImageModel', () => { prompt: 'test prompt', n: 1, size: undefined, + aspectRatio: undefined, + seed: undefined, providerOptions: { vertex: { aspectRatio: '16:9', @@ -130,5 +147,74 @@ describe('GoogleVertexImageModel', () => { }, }); }); + + it('should pass aspect ratio directly when specified', async () => { + prepareJsonResponse(); + + await model.doGenerate({ + prompt: 'test prompt', + n: 1, + size: undefined, + aspectRatio: '16:9', + seed: undefined, + providerOptions: {}, + }); + + expect(await server.getRequestBodyJson()).toStrictEqual({ + instances: [{ prompt: 'test prompt' }], + parameters: { + sampleCount: 1, + aspectRatio: '16:9', + }, + }); + }); + + it('should pass seed directly when specified', async () => { + prepareJsonResponse(); + + await model.doGenerate({ + prompt: 'test prompt', + n: 1, + size: undefined, + aspectRatio: undefined, + seed: 42, + providerOptions: {}, + }); + + expect(await server.getRequestBodyJson()).toStrictEqual({ + instances: [{ prompt: 'test prompt' }], + parameters: { + sampleCount: 1, + seed: 42, + }, + }); + }); + + it('should combine aspectRatio, seed and provider options', async () => { + prepareJsonResponse(); + + await model.doGenerate({ + prompt: 'test prompt', + n: 1, + size: undefined, + aspectRatio: '1:1', + seed: 42, + providerOptions: { + vertex: { + temperature: 0.8, + }, + }, + }); + + expect(await server.getRequestBodyJson()).toStrictEqual({ + instances: [{ prompt: 'test prompt' }], + parameters: { + sampleCount: 1, + aspectRatio: '1:1', + seed: 42, + temperature: 0.8, + }, + }); + }); }); }); diff --git a/packages/google-vertex/src/google-vertex-image-model.ts b/packages/google-vertex/src/google-vertex-image-model.ts index f93bdf6f1065..507d799b82f4 100644 --- a/packages/google-vertex/src/google-vertex-image-model.ts +++ b/packages/google-vertex/src/google-vertex-image-model.ts @@ -1,9 +1,9 @@ -import { ImageModelV1, JSONValue } from '@ai-sdk/provider'; +import { ImageModelV1, UnsupportedFunctionalityError } from '@ai-sdk/provider'; import { Resolvable, - postJsonToApi, combineHeaders, createJsonResponseHandler, + postJsonToApi, resolve, } from '@ai-sdk/provider-utils'; import { z } from 'zod'; @@ -25,6 +25,9 @@ interface GoogleVertexImageModelConfig { export class GoogleVertexImageModel implements ImageModelV1 { readonly specificationVersion = 'v1'; + // https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/imagen-api#parameter_list + readonly maxImagesPerCall = 4; + get provider(): string { return this.config.provider; } @@ -38,24 +41,35 @@ export class GoogleVertexImageModel implements ImageModelV1 { prompt, n, size, + aspectRatio, + seed, providerOptions, headers, abortSignal, }: Parameters[0]): Promise< Awaited> > { - if (size) { - throw new Error( - 'Google Vertex does not support the `size` option. Use ' + - '`providerOptions.vertex.aspectRatio` instead. See ' + - 'https://cloud.google.com/vertex-ai/generative-ai/docs/image/generate-images#aspect-ratio', - ); + if (size != null) { + throw new UnsupportedFunctionalityError({ + functionality: 'image size', + message: + 'This model does not support the `size` option. Use `aspectRatio` instead.', + }); + } + + if (n > this.maxImagesPerCall) { + throw new UnsupportedFunctionalityError({ + functionality: `generate more than ${this.maxImagesPerCall} images`, + message: `This model does not support generating more than ${this.maxImagesPerCall} images at a time.`, + }); } const body = { instances: [{ prompt }], parameters: { sampleCount: n, + ...(aspectRatio != null ? { aspectRatio } : {}), + ...(seed != null ? { seed } : {}), ...(providerOptions.vertex ?? {}), }, }; @@ -68,7 +82,7 @@ export class GoogleVertexImageModel implements ImageModelV1 { successfulResponseHandler: createJsonResponseHandler( vertexImageResponseSchema, ), - abortSignal: abortSignal, + abortSignal, fetch: this.config.fetch, }); diff --git a/packages/openai/src/openai-image-model.test.ts b/packages/openai/src/openai-image-model.test.ts index 2f3cc6210028..3e70e4f67a72 100644 --- a/packages/openai/src/openai-image-model.test.ts +++ b/packages/openai/src/openai-image-model.test.ts @@ -34,15 +34,17 @@ describe('doGenerate', () => { await model.doGenerate({ prompt, - n: 2, + n: 1, size: '1024x1024', + aspectRatio: undefined, + seed: undefined, providerOptions: { openai: { style: 'vivid' } }, }); expect(await server.getRequestBodyJson()).toStrictEqual({ model: 'dall-e-3', prompt, - n: 2, + n: 1, size: '1024x1024', style: 'vivid', response_format: 'b64_json', @@ -63,8 +65,10 @@ describe('doGenerate', () => { await provider.image('dall-e-3').doGenerate({ prompt, - n: 2, + n: 1, size: '1024x1024', + aspectRatio: undefined, + seed: undefined, providerOptions: { openai: { style: 'vivid' } }, headers: { 'Custom-Request-Header': 'request-header-value', @@ -88,8 +92,10 @@ describe('doGenerate', () => { const result = await model.doGenerate({ prompt, - n: 2, + n: 1, size: undefined, + aspectRatio: undefined, + seed: undefined, providerOptions: {}, }); diff --git a/packages/openai/src/openai-image-model.ts b/packages/openai/src/openai-image-model.ts index b9e9b7e9d79b..409ce120d8ca 100644 --- a/packages/openai/src/openai-image-model.ts +++ b/packages/openai/src/openai-image-model.ts @@ -1,4 +1,4 @@ -import { ImageModelV1 } from '@ai-sdk/provider'; +import { ImageModelV1, UnsupportedFunctionalityError } from '@ai-sdk/provider'; import { combineHeaders, createJsonResponseHandler, @@ -10,12 +10,22 @@ import { openaiFailedResponseHandler } from './openai-error'; export type OpenAIImageModelId = 'dall-e-3' | 'dall-e-2' | (string & {}); +// https://platform.openai.com/docs/guides/images +const modelMaxImagesPerCall: Record = { + 'dall-e-3': 1, + 'dall-e-2': 10, +}; + export class OpenAIImageModel implements ImageModelV1 { readonly specificationVersion = 'v1'; readonly modelId: OpenAIImageModelId; private readonly config: OpenAIConfig; + get maxImagesPerCall(): number { + return modelMaxImagesPerCall[this.modelId] ?? 1; + } + get provider(): string { return this.config.provider; } @@ -29,12 +39,34 @@ export class OpenAIImageModel implements ImageModelV1 { prompt, n, size, + aspectRatio, + seed, providerOptions, headers, abortSignal, }: Parameters[0]): Promise< Awaited> > { + if (aspectRatio != null) { + throw new UnsupportedFunctionalityError({ + functionality: 'image aspect ratio', + message: + 'This model does not support aspect ratio. Use `size` instead.', + }); + } + + if (seed != null) { + throw new UnsupportedFunctionalityError({ + functionality: 'image seed', + }); + } + + if (n > this.maxImagesPerCall) { + throw new UnsupportedFunctionalityError({ + functionality: `generate more than ${this.maxImagesPerCall} images`, + }); + } + const { value: response } = await postJsonToApi({ url: this.config.url({ path: '/images/generations', diff --git a/packages/provider-utils/src/test/binary-test-server.ts b/packages/provider-utils/src/test/binary-test-server.ts new file mode 100644 index 000000000000..ead68dd94609 --- /dev/null +++ b/packages/provider-utils/src/test/binary-test-server.ts @@ -0,0 +1,67 @@ +import { HttpResponse, http } from 'msw'; +import { SetupServer, setupServer } from 'msw/node'; + +export class BinaryTestServer { + readonly server: SetupServer; + + responseBody: Buffer | null = null; + responseHeaders: Record = {}; + responseStatus = 200; + + request: Request | undefined; + + constructor(url: string) { + this.server = setupServer( + http.post(url, ({ request }) => { + this.request = request; + + if (this.responseBody === null) { + return new HttpResponse(null, { status: this.responseStatus }); + } + + return new HttpResponse(this.responseBody, { + status: this.responseStatus, + headers: this.responseHeaders, + }); + }), + ); + } + + async getRequestBodyJson() { + expect(this.request).toBeDefined(); + return JSON.parse(await this.request!.text()); + } + + async getRequestHeaders() { + expect(this.request).toBeDefined(); + const requestHeaders = this.request!.headers; + + // convert headers to object for easier comparison + const headersObject: Record = {}; + requestHeaders.forEach((value, key) => { + headersObject[key] = value; + }); + + return headersObject; + } + + async getRequestUrlSearchParams() { + expect(this.request).toBeDefined(); + return new URL(this.request!.url).searchParams; + } + + async getRequestUrl() { + expect(this.request).toBeDefined(); + return new URL(this.request!.url).toString(); + } + + setupTestEnvironment() { + beforeAll(() => this.server.listen()); + beforeEach(() => { + this.responseBody = null; + this.request = undefined; + }); + afterEach(() => this.server.resetHandlers()); + afterAll(() => this.server.close()); + } +} diff --git a/packages/provider-utils/src/test/index.ts b/packages/provider-utils/src/test/index.ts index d88f31089f86..2af39f2f44e9 100644 --- a/packages/provider-utils/src/test/index.ts +++ b/packages/provider-utils/src/test/index.ts @@ -1,3 +1,4 @@ +export * from './binary-test-server'; export * from './convert-array-to-async-iterable'; export * from './convert-array-to-readable-stream'; export * from './convert-async-iterable-to-array'; diff --git a/packages/provider/src/errors/unsupported-functionality-error.ts b/packages/provider/src/errors/unsupported-functionality-error.ts index 3efb22e9354c..679240d593f6 100644 --- a/packages/provider/src/errors/unsupported-functionality-error.ts +++ b/packages/provider/src/errors/unsupported-functionality-error.ts @@ -9,12 +9,14 @@ export class UnsupportedFunctionalityError extends AISDKError { readonly functionality: string; - constructor({ functionality }: { functionality: string }) { - super({ - name, - message: `'${functionality}' functionality not supported.`, - }); - + constructor({ + functionality, + message = `'${functionality}' functionality not supported.`, + }: { + functionality: string; + message?: string; + }) { + super({ name, message }); this.functionality = functionality; } diff --git a/packages/provider/src/image-model/v1/image-model-v1.ts b/packages/provider/src/image-model/v1/image-model-v1.ts index 4cac218766e7..138f916f2efd 100644 --- a/packages/provider/src/image-model/v1/image-model-v1.ts +++ b/packages/provider/src/image-model/v1/image-model-v1.ts @@ -23,6 +23,12 @@ Provider-specific model ID for logging purposes. */ readonly modelId: string; + /** +Limit of how many images can be generated in a single API call. +If undefined, we will max generate one image per call. + */ + readonly maxImagesPerCall: number | undefined; + /** Generates an array of images. */ @@ -38,10 +44,25 @@ Number of images to generate. n: number; /** -Size of the images to generate. Must have the format `{width}x{height}`. +Size of the images to generate. +Must have the format `{width}x{height}`. +`undefined` will use the provider's default size. */ size: `${number}x${number}` | undefined; + /** +Aspect ratio of the images to generate. +Must have the format `{width}:{height}`. +`undefined` will use the provider's default aspect ratio. + */ + aspectRatio: `${number}:${number}` | undefined; + + /** +Seed for the image generation. +`undefined` will use the provider's default seed. + */ + seed: number | undefined; + /** Additional provider-specific options that are passed through to the provider as body parameters. @@ -70,8 +91,12 @@ Abort signal for cancelling the operation. headers?: Record; }): PromiseLike<{ /** -Generated images as base64 encoded strings. +Generated images as base64 encoded strings or binary data. +The images should be returned without any unnecessary conversion. +If the API returns base64 encoded strings, the images should be returned +as base64 encoded strings. If the API returns binary data, the images should +be returned as binary data. */ - images: Array; + images: Array | Array; }>; };