Skip to content

Commit

Permalink
feat(server): add copilot prompts management api (#7082)
Browse files Browse the repository at this point in the history
  • Loading branch information
forehalo committed May 28, 2024
1 parent 1a269a4 commit 4b30fbc
Show file tree
Hide file tree
Showing 5 changed files with 173 additions and 5 deletions.
7 changes: 6 additions & 1 deletion packages/backend/server/src/plugins/copilot/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@ import {
OpenAIProvider,
registerCopilotProvider,
} from './providers';
import { CopilotResolver, UserCopilotResolver } from './resolver';
import {
CopilotResolver,
PromptsManagementResolver,
UserCopilotResolver,
} from './resolver';
import { ChatSessionService } from './session';
import { CopilotStorage } from './storage';

Expand All @@ -34,6 +38,7 @@ registerCopilotProvider(OpenAIProvider);
PromptService,
CopilotProviderService,
CopilotStorage,
PromptsManagementResolver,
],
controllers: [CopilotController],
contributesTo: ServerFeature.Copilot,
Expand Down
22 changes: 21 additions & 1 deletion packages/backend/server/src/plugins/copilot/prompt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -142,12 +142,32 @@ export class PromptService {
* list prompt names
* @returns prompt names
*/
async list() {
async listNames() {
return this.db.aiPrompt
.findMany({ select: { name: true } })
.then(prompts => Array.from(new Set(prompts.map(p => p.name))));
}

async list() {
return this.db.aiPrompt.findMany({
select: {
name: true,
action: true,
model: true,
messages: {
select: {
role: true,
content: true,
params: true,
},
orderBy: {
idx: 'asc',
},
},
},
});
}

/**
* get prompt messages by prompt name
* @param name prompt name
Expand Down
89 changes: 89 additions & 0 deletions packages/backend/server/src/plugins/copilot/resolver.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,17 @@ import {
Mutation,
ObjectType,
Parent,
Query,
registerEnumType,
ResolveField,
Resolver,
} from '@nestjs/graphql';
import { AiPromptRole } from '@prisma/client';
import { GraphQLJSON, SafeIntResolver } from 'graphql-scalars';
import GraphQLUpload from 'graphql-upload/GraphQLUpload.mjs';

import { CurrentUser } from '../../core/auth';
import { Admin } from '../../core/common';
import { UserType } from '../../core/user';
import { PermissionService } from '../../core/workspaces/permission';
import {
Expand All @@ -25,6 +28,7 @@ import {
Throttle,
TooManyRequestsException,
} from '../../fundamentals';
import { PromptService } from './prompt';
import { ChatSessionService } from './session';
import { CopilotStorage } from './storage';
import {
Expand Down Expand Up @@ -152,6 +156,40 @@ class CopilotQuotaType {
used!: number;
}

registerEnumType(AiPromptRole, {
name: 'CopilotPromptMessageRole',
});

@InputType('CopilotPromptMessageInput')
@ObjectType()
class CopilotPromptMessageType {
@Field(() => AiPromptRole)
role!: AiPromptRole;

@Field(() => String)
content!: string;

@Field(() => GraphQLJSON, { nullable: true })
params!: Record<string, string> | null;
}

registerEnumType(AvailableModels, { name: 'CopilotModels' });

@ObjectType()
class CopilotPromptType {
@Field(() => String)
name!: string;

@Field(() => AvailableModels)
model!: AvailableModels;

@Field(() => String, { nullable: true })
action!: string | null;

@Field(() => [CopilotPromptMessageType])
messages!: CopilotPromptMessageType[];
}

// ================== Resolver ==================

@ObjectType('Copilot')
Expand Down Expand Up @@ -370,3 +408,54 @@ export class UserCopilotResolver {
return { workspaceId };
}
}

@InputType()
class CreateCopilotPromptInput {
@Field(() => String)
name!: string;

@Field(() => AvailableModels)
model!: AvailableModels;

@Field(() => String, { nullable: true })
action!: string | null;

@Field(() => [CopilotPromptMessageType])
messages!: CopilotPromptMessageType[];
}

@Admin()
@Resolver(() => String)
export class PromptsManagementResolver {
constructor(private readonly promptService: PromptService) {}

@Query(() => [CopilotPromptType], {
description: 'List all copilot prompts',
})
async listCopilotPrompts() {
return this.promptService.list();
}

@Mutation(() => CopilotPromptType, {
description: 'Create a copilot prompt',
})
async createCopilotPrompt(
@Args({ type: () => CreateCopilotPromptInput, name: 'input' })
input: CreateCopilotPromptInput
) {
await this.promptService.set(input.name, input.model, input.messages);
return this.promptService.get(input.name);
}

@Mutation(() => CopilotPromptType, {
description: 'Update a copilot prompt',
})
async updateCopilotPrompt(
@Args('name') name: string,
@Args('messages', { type: () => [CopilotPromptMessageType] })
messages: CopilotPromptMessageType[]
) {
await this.promptService.update(name, messages);
return this.promptService.get(name);
}
}
54 changes: 54 additions & 0 deletions packages/backend/server/src/schema.gql
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,44 @@ type CopilotHistories {
tokens: Int!
}

enum CopilotModels {
DallE3
Gpt4Omni
Gpt4TurboPreview
Gpt4VisionPreview
Gpt35Turbo
TextEmbedding3Large
TextEmbedding3Small
TextEmbeddingAda002
TextModerationLatest
TextModerationStable
}

input CopilotPromptMessageInput {
content: String!
params: JSON
role: CopilotPromptMessageRole!
}

enum CopilotPromptMessageRole {
assistant
system
user
}

type CopilotPromptMessageType {
content: String!
params: JSON
role: CopilotPromptMessageRole!
}

type CopilotPromptType {
action: String
messages: [CopilotPromptMessageType!]!
model: CopilotModels!
name: String!
}

type CopilotQuota {
limit: SafeInt
used: SafeInt!
Expand Down Expand Up @@ -63,6 +101,13 @@ input CreateCheckoutSessionInput {
successCallbackLink: String!
}

input CreateCopilotPromptInput {
action: String
messages: [CopilotPromptMessageInput!]!
model: CopilotModels!
name: String!
}

type CredentialsRequirementType {
password: PasswordLimitsType!
}
Expand Down Expand Up @@ -206,6 +251,9 @@ type Mutation {
"""Create a chat message"""
createCopilotMessage(options: CreateChatMessageInput!): String!

"""Create a copilot prompt"""
createCopilotPrompt(input: CreateCopilotPromptInput!): CopilotPromptType!

"""Create a chat session"""
createCopilotSession(options: CreateChatSessionInput!): String!

Expand Down Expand Up @@ -238,6 +286,9 @@ type Mutation {
setBlob(blob: Upload!, workspaceId: String!): String!
setWorkspaceExperimentalFeature(enable: Boolean!, feature: FeatureType!, workspaceId: String!): Boolean!
sharePage(pageId: String!, workspaceId: String!): Boolean! @deprecated(reason: "renamed to publishPage")

"""Update a copilot prompt"""
updateCopilotPrompt(messages: [CopilotPromptMessageInput!]!, name: String!): CopilotPromptType!
updateProfile(input: UpdateUserInput!): UserType!

"""update server runtime configurable setting"""
Expand Down Expand Up @@ -296,6 +347,9 @@ type Query {

"""List blobs of workspace"""
listBlobs(workspaceId: String!): [String!]! @deprecated(reason: "use `workspace.blobs` instead")

"""List all copilot prompts"""
listCopilotPrompts: [CopilotPromptType!]!
listWorkspaceFeatures(feature: FeatureType!): [WorkspaceType!]!
prices: [SubscriptionPrice!]!

Expand Down
6 changes: 3 additions & 3 deletions packages/backend/server/tests/copilot.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,13 @@ test.beforeEach(async t => {
test('should be able to manage prompt', async t => {
const { prompt } = t.context;

t.is((await prompt.list()).length, 0, 'should have no prompt');
t.is((await prompt.listNames()).length, 0, 'should have no prompt');

await prompt.set('test', 'test', [
{ role: 'system', content: 'hello' },
{ role: 'user', content: 'hello' },
]);
t.is((await prompt.list()).length, 1, 'should have one prompt');
t.is((await prompt.listNames()).length, 1, 'should have one prompt');
t.is(
(await prompt.get('test'))!.finish({}).length,
2,
Expand All @@ -98,7 +98,7 @@ test('should be able to manage prompt', async t => {
);

await prompt.delete('test');
t.is((await prompt.list()).length, 0, 'should have no prompt');
t.is((await prompt.listNames()).length, 0, 'should have no prompt');
t.is(await prompt.get('test'), null, 'should not have the prompt');
});

Expand Down

0 comments on commit 4b30fbc

Please sign in to comment.