Skip to content

Commit

Permalink
🐛 fix: fix some ai provider known issues (lobehub#5361)
Browse files Browse the repository at this point in the history
* fix provider url

* improve fetch model list issue

* fix builtin model sort and displayName

* fix user enabled models

* fix model name

* fix model displayName name
  • Loading branch information
arvinxx authored Jan 9, 2025
1 parent 4a49bc7 commit b2775b5
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 53 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { FlexboxProps } from 'react-layout-kit';

import { isServerMode } from '@/const/version';
import { DiscoverProviderItem } from '@/types/discover';

const useStyles = createStyles(({ css }) => ({
Expand All @@ -25,13 +26,13 @@ interface ProviderConfigProps extends FlexboxProps {
identifier: string;
}

const ProviderConfig = memo<ProviderConfigProps>(({ data }) => {
const ProviderConfig = memo<ProviderConfigProps>(({ data, identifier }) => {
const { styles } = useStyles();
const { t } = useTranslation('discover');

const router = useRouter();
const openSettings = () => {
router.push('/settings/llm');
router.push(!isServerMode ? '/settings/llm' : `/settings/provider/${identifier}`);
};

const icon = <Icon icon={SquareArrowOutUpRight} size={{ fontSize: 16 }} />;
Expand Down
13 changes: 8 additions & 5 deletions src/database/repositories/aiInfra/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,16 @@ export class AiInfraRepos {
.map<EnabledAiModel & { enabled?: boolean | null }>((item) => {
const user = allModels.find((m) => m.id === item.id && m.providerId === provider.id);

const enabled = !!user ? user.enabled : item.enabled;

return {
...item,
abilities: item.abilities || {},
enabled,
abilities: !!user ? user.abilities : item.abilities || {},
config: !!user ? user.config : item.config,
contextWindowTokens: !!user ? user.contextWindowTokens : item.contextWindowTokens,
displayName: user?.displayName || item.displayName,
enabled: !!user ? user.enabled : item.enabled,
id: item.id,
providerId: provider.id,
sort: !!user ? user.sort : undefined,
type: item.type,
};
})
.filter((i) => i.enabled);
Expand Down
2 changes: 1 addition & 1 deletion src/database/server/models/__tests__/aiModel.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ describe('AiModelModel', () => {

const allModels = await aiProviderModel.query();
expect(allModels).toHaveLength(2);
expect(allModels.find((m) => m.id === 'existing-model')?.displayName).toBe('Updated Name');
expect(allModels.find((m) => m.id === 'existing-model')?.displayName).toBe('Old Name');
expect(allModels.find((m) => m.id === 'new-model')?.displayName).toBe('New Model');
});
});
Expand Down
59 changes: 14 additions & 45 deletions src/database/server/models/aiModel.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import { and, asc, desc, eq, inArray } from 'drizzle-orm/expressions';
import pMap from 'p-map';

import { LobeChatDatabase } from '@/database/type';
import {
Expand Down Expand Up @@ -131,51 +130,21 @@ export class AiModelModel {
};

batchUpdateAiModels = async (providerId: string, models: AiProviderModelListItem[]) => {
return this.db.transaction(async (trx) => {
const records = models.map(({ id, ...model }) => ({
...model,
id,
providerId,
updatedAt: new Date(),
userId: this.userId,
}));
const records = models.map(({ id, ...model }) => ({
...model,
id,
providerId,
updatedAt: new Date(),
userId: this.userId,
}));

// 第一步:尝试插入所有记录,忽略冲突
const insertedRecords = await trx
.insert(aiModels)
.values(records)
.onConflictDoNothing({
target: [aiModels.id, aiModels.userId, aiModels.providerId],
})
.returning();
// 第二步:找出需要更新的记录(即插入时发生冲突的记录)
// 找出未能插入的记录(需要更新的记录)
const insertedIds = new Set(insertedRecords.map((r) => r.id));
const recordsToUpdate = records.filter((r) => !insertedIds.has(r.id));

// 第三步:更新已存在的记录
if (recordsToUpdate.length > 0) {
await pMap(
recordsToUpdate,
async (record) => {
await trx
.update(aiModels)
.set({
...record,
updatedAt: new Date(),
})
.where(
and(
eq(aiModels.id, record.id),
eq(aiModels.userId, this.userId),
eq(aiModels.providerId, providerId),
),
);
},
{ concurrency: 10 }, // 限制并发数为 10
);
}
});
return this.db
.insert(aiModels)
.values(records)
.onConflictDoNothing({
target: [aiModels.id, aiModels.userId, aiModels.providerId],
})
.returning();
};

batchToggleAiModels = async (providerId: string, models: string[], enabled: boolean) => {
Expand Down

0 comments on commit b2775b5

Please sign in to comment.