Skip to content

Commit

Permalink
feat (ai/core): add caching to generated images (#4297)
Browse files Browse the repository at this point in the history
  • Loading branch information
lgrammel authored Jan 7, 2025
1 parent 19a2ce7 commit 8b422ea
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 75 deletions.
5 changes: 5 additions & 0 deletions .changeset/twelve-crabs-fry.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'ai': patch
---

feat (ai/core): add caching to generated images
132 changes: 71 additions & 61 deletions packages/ai/core/generate-image/generate-image.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,76 +43,86 @@ describe('generateImage', () => {
});
});

it('should handle base64 strings', async () => {
const base64String = 'SGVsbG8gV29ybGQ=';
const result = await generateImage({
model: new MockImageModelV1({
doGenerate: async () => ({ images: [base64String] }),
}),
prompt,
});
expect(result.images).toStrictEqual([
{
base64: base64String,
uint8Array: convertBase64ToUint8Array(base64String),
},
]);
});
describe('base64 image data', () => {
it('should return generated images', async () => {
const base64Images = [
'SGVsbG8gV29ybGQ=', // "Hello World" in base64
'VGVzdGluZw==', // "Testing" in base64
];

it('should handle Uint8Arrays', async () => {
const uint8Array = new Uint8Array([72, 101, 108, 108, 111]);
const result = await generateImage({
model: new MockImageModelV1({
doGenerate: async () => ({ images: [uint8Array] }),
}),
prompt,
const result = await generateImage({
model: new MockImageModelV1({
doGenerate: async () => ({ images: base64Images }),
}),
prompt,
});

expect(
result.images.map(image => ({
base64: image.base64,
uint8Array: image.uint8Array,
})),
).toStrictEqual([
{
base64: base64Images[0],
uint8Array: convertBase64ToUint8Array(base64Images[0]),
},
{
base64: base64Images[1],
uint8Array: convertBase64ToUint8Array(base64Images[1]),
},
]);
});
expect(result.images).toStrictEqual([
{
base64: convertUint8ArrayToBase64(uint8Array),
uint8Array: uint8Array,
},
]);
});

it('should return generated images', async () => {
const base64Images = [
'SGVsbG8gV29ybGQ=', // "Hello World" in base64
'VGVzdGluZw==', // "Testing" in base64
];
it('should return the first image', async () => {
const base64Image = 'SGVsbG8gV29ybGQ='; // "Hello World" in base64

const result = await generateImage({
model: new MockImageModelV1({
doGenerate: async () => ({ images: base64Images }),
}),
prompt,
});
const result = await generateImage({
model: new MockImageModelV1({
doGenerate: async () => ({ images: [base64Image, 'base64-image-2'] }),
}),
prompt,
});

expect(result.images).toStrictEqual([
{
base64: base64Images[0],
uint8Array: convertBase64ToUint8Array(base64Images[0]),
},
{
base64: base64Images[1],
uint8Array: convertBase64ToUint8Array(base64Images[1]),
},
]);
expect({
base64: result.image.base64,
uint8Array: result.image.uint8Array,
}).toStrictEqual({
base64: base64Image,
uint8Array: convertBase64ToUint8Array(base64Image),
});
});
});

it('should return the first image', async () => {
const base64Image = 'SGVsbG8gV29ybGQ='; // "Hello World" in base64
describe('uint8array image data', () => {
it('should return generated images', async () => {
const uint8ArrayImages = [
convertBase64ToUint8Array('SGVsbG8gV29ybGQ='), // "Hello World" in base64
convertBase64ToUint8Array('VGVzdGluZw=='), // "Testing" in base64
];

const result = await generateImage({
model: new MockImageModelV1({
doGenerate: async () => ({ images: [base64Image, 'base64-image-2'] }),
}),
prompt,
});
const result = await generateImage({
model: new MockImageModelV1({
doGenerate: async () => ({ images: uint8ArrayImages }),
}),
prompt,
});

expect(result.image).toStrictEqual({
base64: base64Image,
uint8Array: convertBase64ToUint8Array(base64Image),
expect(
result.images.map(image => ({
base64: image.base64,
uint8Array: image.uint8Array,
})),
).toStrictEqual([
{
base64: convertUint8ArrayToBase64(uint8ArrayImages[0]),
uint8Array: uint8ArrayImages[0],
},
{
base64: convertUint8ArrayToBase64(uint8ArrayImages[1]),
uint8Array: uint8ArrayImages[1],
},
]);
});
});
});
45 changes: 31 additions & 14 deletions packages/ai/core/generate-image/generate-image.ts
Original file line number Diff line number Diff line change
Expand Up @@ -121,23 +121,40 @@ class DefaultGenerateImageResult implements GenerateImageResult {
readonly images: Array<GeneratedImage>;

constructor(options: { images: Array<string> | Array<Uint8Array> }) {
this.images = options.images.map(image => {
const isUint8Array = image instanceof Uint8Array;
return {
// lazy conversion to base64 inside get to avoid unnecessary conversion overhead:
get base64() {
return isUint8Array ? convertUint8ArrayToBase64(image) : image;
},

// lazy conversion to uint8array inside get to avoid unnecessary conversion overhead:
get uint8Array() {
return isUint8Array ? image : convertBase64ToUint8Array(image);
},
};
});
this.images = options.images.map(
image => new DefaultGeneratedImage({ imageData: image }),
);
}

get image() {
return this.images[0];
}
}

class DefaultGeneratedImage implements GeneratedImage {
private base64Data: string | undefined;
private uint8ArrayData: Uint8Array | undefined;

constructor({ imageData }: { imageData: string | Uint8Array }) {
const isUint8Array = imageData instanceof Uint8Array;

this.base64Data = isUint8Array ? undefined : imageData;
this.uint8ArrayData = isUint8Array ? imageData : undefined;
}

// lazy conversion with caching to avoid unnecessary conversion overhead:
get base64() {
if (this.base64Data == null) {
this.base64Data = convertUint8ArrayToBase64(this.uint8ArrayData!);
}
return this.base64Data;
}

// lazy conversion with caching to avoid unnecessary conversion overhead:
get uint8Array() {
if (this.uint8ArrayData == null) {
this.uint8ArrayData = convertBase64ToUint8Array(this.base64Data!);
}
return this.uint8ArrayData;
}
}

0 comments on commit 8b422ea

Please sign in to comment.