Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement model switching for custom prompts #1269

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/LLMProviders/chatModelManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -390,4 +390,10 @@ export default class ChatModelManager {
const settings = getSettings();
return settings.activeModels.find((model) => model.name === modelName);
}

/**
 * Resolves the currently selected chat model.
 *
 * @returns the active CustomModel, or null when no model key is set or the
 * key's model name does not match any active model.
 */
getCurrentModel(): CustomModel | null {
  const key = getModelKey();
  if (!key) {
    return null;
  }
  // Key format appears to be "<name>|<provider>" — take the name part.
  // TODO(review): confirm the key layout against getModelKey's producer.
  const [modelName] = key.split("|");
  return this.findModelByName(modelName) || null;
}
}
10 changes: 8 additions & 2 deletions src/commands/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ export function registerBuiltInCommands(plugin: CopilotPlugin) {
(plugin as any).removeCommand(id);
});

const promptProcessor = CustomPromptProcessor.getInstance(plugin.app.vault);
const promptProcessor = CustomPromptProcessor.getInstance(plugin.app);

addEditorCommand(plugin, COMMAND_IDS.FIX_GRAMMAR, (editor) => {
processInlineEditCommand(plugin, editor, COMMAND_IDS.FIX_GRAMMAR);
Expand Down Expand Up @@ -249,7 +249,12 @@ export function registerBuiltInCommands(plugin: CopilotPlugin) {
new Notice(`No prompt found with the title "${promptTitle}".`);
return;
}
plugin.processCustomPrompt(COMMAND_IDS.APPLY_CUSTOM_PROMPT, prompt.content);
plugin.processCustomPrompt(
COMMAND_IDS.APPLY_CUSTOM_PROMPT,
prompt.content,
prompt.model,
prompt.isTemporaryModel
);
} catch (err) {
console.error(err);
new Notice("An error occurred.");
Expand All @@ -260,6 +265,7 @@ export function registerBuiltInCommands(plugin: CopilotPlugin) {
addCommand(plugin, COMMAND_IDS.APPLY_ADHOC_PROMPT, async () => {
const modal = new AdhocPromptModal(plugin.app, async (adhocPrompt: string) => {
try {
// For ad-hoc prompts, we don't support model switching
plugin.processCustomPrompt(COMMAND_IDS.APPLY_ADHOC_PROMPT, adhocPrompt);
} catch (err) {
console.error(err);
Expand Down
4 changes: 2 additions & 2 deletions src/components/Chat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ const Chat: React.FC<ChatProps> = ({
setLoadingMessage(LOADING_MESSAGES.DEFAULT);

// First, process the original user message for custom prompts
const customPromptProcessor = CustomPromptProcessor.getInstance(app.vault);
const customPromptProcessor = CustomPromptProcessor.getInstance(app);
let processedUserMessage = await customPromptProcessor.processCustomPrompt(
inputMessage || "",
"",
Expand Down Expand Up @@ -482,7 +482,7 @@ ${chatContent}`;
};
};

const customPromptProcessor = CustomPromptProcessor.getInstance(app.vault);
const customPromptProcessor = CustomPromptProcessor.getInstance(app);
// eslint-disable-next-line react-hooks/exhaustive-deps
useEffect(
createEffect(COMMAND_IDS.APPLY_CUSTOM_PROMPT, async (selectedText, customPrompt) => {
Expand Down
23 changes: 22 additions & 1 deletion src/components/chat-components/ChatInput.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import React, {
} from "react";
import { useDropzone } from "react-dropzone";
import ContextControl from "./ContextControl";
import ChatModelManager from "@/LLMProviders/chatModelManager";

interface ChatInputProps {
inputMessage: string;
Expand Down Expand Up @@ -199,14 +200,34 @@ const ChatInput = forwardRef<{ focus: () => void }, ChatInputProps>(
};

const showCustomPromptModal = async () => {
const customPromptProcessor = CustomPromptProcessor.getInstance(app.vault);
const customPromptProcessor = CustomPromptProcessor.getInstance(app);
const prompts = await customPromptProcessor.getAllPrompts();
const promptTitles = prompts.map((prompt) => prompt.title);

new ListPromptModal(app, promptTitles, async (promptTitle: string) => {
const selectedPrompt = prompts.find((prompt) => prompt.title === promptTitle);
if (selectedPrompt) {
customPromptProcessor.recordPromptUsage(selectedPrompt.title);

// If the prompt specifies a model, try to switch to it
if (selectedPrompt.model) {
try {
const modelInstance = ChatModelManager.getInstance().findModelByName(
selectedPrompt.model
);
if (modelInstance) {
await ChatModelManager.getInstance().setChatModel(modelInstance);
} else {
new Notice(`Model "${selectedPrompt.model}" not found. Using current model.`);
}
} catch (error) {
console.error("Error switching model:", error);
new Notice(
`Failed to switch to model "${selectedPrompt.model}". Using current model.`
);
}
}

setInputMessage(selectedPrompt.content);
}
}).open();
Expand Down
20 changes: 14 additions & 6 deletions src/customPromptProcessor.test.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import { CustomPrompt, CustomPromptProcessor } from "@/customPromptProcessor";
import { extractNoteFiles, getFileContent, getNotesFromPath } from "@/utils";
import { Notice, TFile, Vault } from "obsidian";
import { Notice, TFile, Vault, App } from "obsidian";

// Mock Obsidian
jest.mock("obsidian", () => ({
Notice: jest.fn(),
TFile: jest.fn(),
Vault: jest.fn(),
App: jest.fn(),
}));

// Mock the utility functions
Expand All @@ -21,22 +22,29 @@ jest.mock("@/utils", () => ({

describe("CustomPromptProcessor", () => {
let processor: CustomPromptProcessor;
let mockVault: Vault;
let mockApp: App;
let mockActiveNote: TFile;

beforeEach(() => {
// Reset mocks before each test
jest.clearAllMocks();

// Create mock objects
mockVault = {} as Vault;
mockApp = {
vault: {} as Vault,
metadataCache: {
getFileCache: jest.fn().mockReturnValue({
frontmatter: {},
}),
},
} as unknown as App;
mockActiveNote = {
path: "path/to/active/note.md",
basename: "Active Note",
} as TFile;

// Create an instance of CustomPromptProcessor with mocked dependencies
processor = CustomPromptProcessor.getInstance(mockVault);
processor = CustomPromptProcessor.getInstance(mockApp);
});

it("should add 1 context and selectedText", async () => {
Expand Down Expand Up @@ -106,7 +114,7 @@ describe("CustomPromptProcessor", () => {

expect(result).toContain("This is the active note: {activenote}");
expect(result).toContain("Content of the active note");
expect(getFileContent).toHaveBeenCalledWith(mockActiveNote, mockVault);
expect(getFileContent).toHaveBeenCalledWith(mockActiveNote, mockApp.vault);
});

it("should handle {activeNote} when no active note is provided", async () => {
Expand Down Expand Up @@ -299,7 +307,7 @@ describe("CustomPromptProcessor", () => {

expect(result).toContain("Summarize this: {selectedText}");
expect(result).toContain("selectedText (entire active note):\n\n Content of the active note");
expect(getFileContent).toHaveBeenCalledWith(mockActiveNote, mockVault);
expect(getFileContent).toHaveBeenCalledWith(mockActiveNote, mockApp.vault);
});

it("should not duplicate active note content when both {} and {activeNote} are present", async () => {
Expand Down
53 changes: 45 additions & 8 deletions src/customPromptProcessor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,32 +9,60 @@ import {
getNotesFromTags,
processVariableNameForNotePath,
} from "@/utils";
import { normalizePath, Notice, TFile, Vault } from "obsidian";
import { normalizePath, Notice, TFile, Vault, App } from "obsidian";

/**
 * A user-authored prompt loaded from a Markdown note in the custom prompts
 * folder (one note per prompt; the title is the note's basename).
 */
export interface CustomPrompt {
// Note basename used to identify the prompt.
title: string;
// Prompt body with any leading frontmatter block stripped.
content: string;
// Optional model name taken from the note's `model` or `temp-model` frontmatter.
model?: string;
// True when the model came from `temp-model` — the caller switches back
// to the previous model after processing.
isTemporaryModel?: boolean;
}

export class CustomPromptProcessor {
private static instance: CustomPromptProcessor;
private static instance: CustomPromptProcessor | null = null;
private vault: Vault;
private usageStrategy: TimestampUsageStrategy;
private app: App;

private constructor(private vault: Vault) {
// Private: instances are obtained through getInstance() (singleton).
private constructor(app: App) {
this.app = app;
// Keep a direct vault reference for the frequent read/list operations.
this.vault = app.vault;
// Records per-prompt usage timestamps — presumably for recency-based
// ordering of prompts; confirm against the strategy implementation.
this.usageStrategy = new TimestampUsageStrategy();
}

// Vault-relative folder where custom prompt notes live. Read from settings
// on every access so a settings change takes effect without recreating the
// singleton.
get customPromptsFolder(): string {
return getSettings().customPromptsFolder;
}

static getInstance(vault: Vault): CustomPromptProcessor {
/**
 * Returns the shared processor, creating it on first use.
 * Note: `app` is only consulted when the singleton is first constructed;
 * later calls ignore it.
 */
static getInstance(app: App): CustomPromptProcessor {
  if (CustomPromptProcessor.instance === null) {
    CustomPromptProcessor.instance = new CustomPromptProcessor(app);
  }
  return CustomPromptProcessor.instance;
}

/**
 * Extracts model-selection frontmatter from a prompt note.
 *
 * Frontmatter values come from Obsidian's metadata cache (no YAML parsing
 * here). `temp-model` takes precedence over `model` and marks the switch
 * as temporary.
 *
 * @param file - the prompt note (used to look up cached metadata)
 * @param content - the note's raw text, possibly starting with frontmatter
 * @returns the extracted model fields and the content without frontmatter
 */
private parseFrontMatter(
  file: TFile,
  content: string
): { frontmatter: { model?: string; isTemporaryModel?: boolean }; content: string } {
  // Get the cached frontmatter from Obsidian's metadata cache
  const cache = this.app.metadataCache.getFileCache(file);
  const frontmatter = (cache?.frontmatter as Record<string, unknown>) || {};

  // Handle both model and temp-model fields; temp-model wins.
  const tempModel = frontmatter["temp-model"] as string | undefined;
  const model = tempModel || (frontmatter.model as string | undefined);
  const isTemporaryModel = "temp-model" in frontmatter;

  // Strip the leading frontmatter block. Tolerate CRLF line endings and a
  // closing `---` on the file's last line (no trailing newline) — the
  // previous pattern /^---\n[\s\S]*?\n---\n/ missed both cases and would
  // leave the frontmatter in the prompt content.
  const contentWithoutFrontmatter = content
    .replace(/^---\r?\n[\s\S]*?\r?\n---(\r?\n|$)/, "")
    .trim();

  return {
    frontmatter: { model, isTemporaryModel },
    content: contentWithoutFrontmatter,
  };
}

// Marks the prompt with this title as just-used so usage-based ranking
// stays current. Delegates to the timestamp strategy.
recordPromptUsage(title: string) {
this.usageStrategy.recordUsage(title);
}
Expand All @@ -47,10 +75,13 @@ export class CustomPromptProcessor {

const prompts: CustomPrompt[] = [];
for (const file of files) {
const content = await this.vault.read(file);
const rawContent = await this.vault.read(file);
const { frontmatter, content } = this.parseFrontMatter(file, rawContent);
prompts.push({
title: file.basename,
content: content,
model: frontmatter.model,
isTemporaryModel: frontmatter.isTemporaryModel,
});
}

Expand All @@ -64,8 +95,14 @@ export class CustomPromptProcessor {
const filePath = `${this.customPromptsFolder}/${title}.md`;
const file = this.vault.getAbstractFileByPath(filePath);
if (file instanceof TFile) {
const content = await this.vault.read(file);
return { title, content };
const rawContent = await this.vault.read(file);
const { frontmatter, content } = this.parseFrontMatter(file, rawContent);
return {
title,
content,
model: frontmatter.model,
isTemporaryModel: frontmatter.isTemporaryModel,
};
}
return null;
}
Expand Down
48 changes: 46 additions & 2 deletions src/main.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import {
WorkspaceLeaf,
} from "obsidian";
import { IntentAnalyzer } from "./LLMProviders/intentAnalyzer";
import ChatModelManager from "@/LLMProviders/chatModelManager";

export default class CopilotPlugin extends Plugin {
// A chat history that stores the messages sent and received
Expand All @@ -43,6 +44,7 @@ export default class CopilotPlugin extends Plugin {
vectorStoreManager: VectorStoreManager;
fileParserManager: FileParserManager;
settingsUnsubscriber?: () => void;
private chatModelManager: ChatModelManager;

async onload(): Promise<void> {
await this.loadSettings();
Expand Down Expand Up @@ -75,6 +77,8 @@ export default class CopilotPlugin extends Plugin {

this.initActiveLeafChangeHandler();

this.chatModelManager = ChatModelManager.getInstance();

this.addRibbonIcon("message-square", "Open Copilot Chat", (evt: MouseEvent) => {
this.activateView();
});
Expand Down Expand Up @@ -202,9 +206,49 @@ export default class CopilotPlugin extends Plugin {
} as Partial<Editor> as Editor;
}

processCustomPrompt(eventType: string, customPrompt: string) {
async processCustomPrompt(
eventType: string,
customPrompt: string,
model?: string,
isTemporary?: boolean
) {
const editor = this.getCurrentEditorOrDummy();
this.processText(editor, eventType, customPrompt, false);

// If a model is specified, switch to it
let originalModel: CustomModel | null = null;
if (model) {
try {
// Only store the original model if this is a temporary switch
if (isTemporary) {
originalModel = this.chatModelManager.getCurrentModel();
}
const targetModel = this.chatModelManager.findModelByName(model);
if (targetModel) {
await this.chatModelManager.setChatModel(targetModel);
} else {
new Notice(`Model "${model}" not found. Using current model.`);
}
} catch (error) {
console.error("Error switching model:", error);
new Notice(`Failed to switch to model "${model}". Using current model.`);
}
}

try {
await this.processText(editor, eventType, customPrompt, false);
} finally {
// Restore the original model only if this was a temporary switch
if (originalModel && isTemporary) {
try {
await this.chatModelManager.setChatModel(originalModel);
} catch (error) {
console.error("Error restoring original model:", error);
new Notice(
"Failed to restore original model. You may need to manually reset your model in settings."
);
}
}
}
}

toggleView() {
Expand Down