Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🐛 fix: fix provider known issues #5361

Merged
merged 6 commits into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { FlexboxProps } from 'react-layout-kit';

import { isServerMode } from '@/const/version';
import { DiscoverProviderItem } from '@/types/discover';

const useStyles = createStyles(({ css }) => ({
Expand All @@ -25,13 +26,13 @@ interface ProviderConfigProps extends FlexboxProps {
identifier: string;
}

const ProviderConfig = memo<ProviderConfigProps>(({ data }) => {
const ProviderConfig = memo<ProviderConfigProps>(({ data, identifier }) => {
const { styles } = useStyles();
const { t } = useTranslation('discover');

const router = useRouter();
const openSettings = () => {
router.push('/settings/llm');
router.push(!isServerMode ? '/settings/llm' : `/settings/provider/${identifier}`);
};

const icon = <Icon icon={SquareArrowOutUpRight} size={{ fontSize: 16 }} />;
Expand Down
13 changes: 8 additions & 5 deletions src/database/repositories/aiInfra/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,16 @@ export class AiInfraRepos {
.map<EnabledAiModel & { enabled?: boolean | null }>((item) => {
const user = allModels.find((m) => m.id === item.id && m.providerId === provider.id);

const enabled = !!user ? user.enabled : item.enabled;

return {
...item,
abilities: item.abilities || {},
enabled,
abilities: !!user ? user.abilities : item.abilities || {},
config: !!user ? user.config : item.config,
contextWindowTokens: !!user ? user.contextWindowTokens : item.contextWindowTokens,
displayName: user?.displayName || item.displayName,
enabled: !!user ? user.enabled : item.enabled,
id: item.id,
providerId: provider.id,
sort: !!user ? user.sort : undefined,
type: item.type,
};
})
.filter((i) => i.enabled);
Expand Down
2 changes: 1 addition & 1 deletion src/database/server/models/__tests__/aiModel.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ describe('AiModelModel', () => {

const allModels = await aiProviderModel.query();
expect(allModels).toHaveLength(2);
expect(allModels.find((m) => m.id === 'existing-model')?.displayName).toBe('Updated Name');
expect(allModels.find((m) => m.id === 'existing-model')?.displayName).toBe('Old Name');
expect(allModels.find((m) => m.id === 'new-model')?.displayName).toBe('New Model');
});
});
Expand Down
59 changes: 14 additions & 45 deletions src/database/server/models/aiModel.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import { and, asc, desc, eq, inArray } from 'drizzle-orm/expressions';
import pMap from 'p-map';

import { LobeChatDatabase } from '@/database/type';
import {
Expand Down Expand Up @@ -131,51 +130,21 @@ export class AiModelModel {
};

batchUpdateAiModels = async (providerId: string, models: AiProviderModelListItem[]) => {
return this.db.transaction(async (trx) => {
const records = models.map(({ id, ...model }) => ({
...model,
id,
providerId,
updatedAt: new Date(),
userId: this.userId,
}));
const records = models.map(({ id, ...model }) => ({
...model,
id,
providerId,
updatedAt: new Date(),
userId: this.userId,
}));

// 第一步:尝试插入所有记录,忽略冲突
const insertedRecords = await trx
.insert(aiModels)
.values(records)
.onConflictDoNothing({
target: [aiModels.id, aiModels.userId, aiModels.providerId],
})
.returning();
// 第二步:找出需要更新的记录(即插入时发生冲突的记录)
// 找出未能插入的记录(需要更新的记录)
const insertedIds = new Set(insertedRecords.map((r) => r.id));
const recordsToUpdate = records.filter((r) => !insertedIds.has(r.id));

// 第三步:更新已存在的记录
if (recordsToUpdate.length > 0) {
await pMap(
recordsToUpdate,
async (record) => {
await trx
.update(aiModels)
.set({
...record,
updatedAt: new Date(),
})
.where(
and(
eq(aiModels.id, record.id),
eq(aiModels.userId, this.userId),
eq(aiModels.providerId, providerId),
),
);
},
{ concurrency: 10 }, // 限制并发数为 10
);
}
});
return this.db
.insert(aiModels)
.values(records)
.onConflictDoNothing({
target: [aiModels.id, aiModels.userId, aiModels.providerId],
})
.returning();
};

batchToggleAiModels = async (providerId: string, models: string[], enabled: boolean) => {
Expand Down
Loading