Support more chain types, implement vector search powered by huggingface inference api (#34)

* Implement ChainFactory for chain singletons

* Add in-memory vector search powered by huggingface inference api

* Add todo items for unlimited context search
logancyang authored May 25, 2023
1 parent 24defc6 commit ab2b5e5
Showing 10 changed files with 242 additions and 36 deletions.
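Taken together, the changes below do three things: chat requests now dispatch on a chain type, chains are reused as singletons through a new `ChainFactory`, and "chat with your note" runs the note through an in-memory vector store embedded via the HuggingFace inference API. A minimal end-to-end sketch of that retrieval path, standalone and outside the plugin (the env var names and `askNote` function are hypothetical; error handling and streaming omitted):

```ts
import { ConversationalRetrievalQAChain } from "langchain/chains";
import { ChatOpenAI } from "langchain/chat_models/openai";
import { HuggingFaceInferenceEmbeddings } from "langchain/embeddings/hf";
import { RecursiveCharacterTextSplitter } from "langchain/text_splitter";
import { MemoryVectorStore } from "langchain/vectorstores/memory";

async function askNote(noteContent: string, question: string): Promise<string> {
  // 1. Chunk the note into ~1000-character pieces.
  const splitter = new RecursiveCharacterTextSplitter({ chunkSize: 1000 });
  const docs = await splitter.createDocuments([noteContent]);

  // 2. Embed each chunk through the HF inference API; vectors stay in memory.
  const vectorStore = await MemoryVectorStore.fromDocuments(
    docs,
    new HuggingFaceInferenceEmbeddings({ apiKey: process.env.HF_API_KEY }),
  );

  // 3. The retrieval QA chain fetches the relevant chunks, then answers.
  const chain = ConversationalRetrievalQAChain.fromLLM(
    new ChatOpenAI({ openAIApiKey: process.env.OPENAI_API_KEY }),
    vectorStore.asRetriever(),
  );
  const res = await chain.call({ question, chat_history: [] });
  return res.text;
}
```

`setChain` in `src/aiState.ts` below does essentially this with the active note, except it obtains the chain from `ChainFactory` and streams tokens back to the UI.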
2 changes: 1 addition & 1 deletion README.md
@@ -1,5 +1,5 @@
 # 🔍 Copilot for Obsidian
-![GitHub release (latest SemVer)](https://img.shields.io/github/v/release/logancyang/obsidian-copilot?style=for-the-badge&sort=semver) ![Obsidian plugin](https://img.shields.io/endpoint?url=https%3A%2F%2Fscambier.xyz%2Fobsidian-endpoints%2Fcopilot.json&style=for-the-badge)
+![GitHub release (latest SemVer)](https://img.shields.io/github/v/release/logancyang/obsidian-copilot?style=for-the-badge&sort=semver) ![Obsidian Downloads](https://img.shields.io/badge/dynamic/json?logo=obsidian&color=%23483699&label=downloads&query=%24%5B%22copilot%22%5D.downloads&url=https%3A%2F%2Fraw.githubusercontent.com%2Fobsidianmd%2Fobsidian-releases%2Fmaster%2Fcommunity-plugin-stats.json&style=for-the-badge)
 
 
 Copilot for Obsidian is a ChatGPT interface right inside Obsidian. It has a minimalistic design and is straightforward to use.
9 changes: 9 additions & 0 deletions package-lock.json

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions package.json
@@ -33,6 +33,7 @@
     "typescript": "4.7.4"
   },
   "dependencies": {
+    "@huggingface/inference": "^1.8.0",
     "@tabler/icons-react": "^2.14.0",
     "axios": "^1.3.4",
     "esbuild-plugin-svg": "^0.1.0",
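The new `@huggingface/inference` dependency is what `HuggingFaceInferenceEmbeddings` (used in `src/aiState.ts` below) wraps. Roughly, it posts text to a hosted feature-extraction endpoint and gets a vector back; a sketch assuming the v1 client, with an illustrative env var and model choice:

```ts
import { HfInference } from "@huggingface/inference";

async function embed(text: string) {
  const hf = new HfInference(process.env.HF_API_KEY);
  // Returns the embedding vector for the input text.
  return hf.featureExtraction({
    model: "sentence-transformers/all-MiniLM-L6-v2",
    inputs: text,
  });
}
```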
143 changes: 110 additions & 33 deletions src/aiState.ts
@@ -1,7 +1,20 @@
-import { AI_SENDER, DEFAULT_SYSTEM_PROMPT, USER_SENDER } from '@/constants';
+import ChainFactory, {
+  CONVERSATIONAL_RETRIEVAL_QA_CHAIN,
+  LLM_CHAIN,
+} from '@/chainFactory';
+import {
+  AI_SENDER,
+  DEFAULT_SYSTEM_PROMPT,
+  USER_SENDER
+} from '@/constants';
 import { ChatMessage } from '@/sharedState';
-import { ConversationChain } from "langchain/chains";
+import {
+  BaseChain,
+  ConversationChain,
+  ConversationalRetrievalQAChain
+} from "langchain/chains";
 import { ChatOpenAI } from 'langchain/chat_models/openai';
+import { HuggingFaceInferenceEmbeddings } from "langchain/embeddings/hf";
 import { BufferWindowMemory } from "langchain/memory";
 import {
   ChatPromptTemplate,
@@ -10,20 +23,30 @@ import {
   SystemMessagePromptTemplate,
 } from "langchain/prompts";
 import { AIChatMessage, HumanChatMessage, SystemChatMessage } from 'langchain/schema';
+import { RecursiveCharacterTextSplitter } from "langchain/text_splitter";
+import { MemoryVectorStore } from "langchain/vectorstores/memory";
 import { useState } from 'react';
 
 export interface LangChainParams {
   key: string,
+  huggingfaceApiKey: string,
   model: string,
   temperature: number,
   maxTokens: number,
   systemMessage: string,
   chatContextTurns: number,
 }
 
+interface SetChainOptions {
+  prompt?: ChatPromptTemplate;
+  noteContent?: string;
+}
+
 class AIState {
   static chatOpenAI: ChatOpenAI;
-  static chain: ConversationChain;
+  static chain: BaseChain;
+  static retrievalChain: ConversationalRetrievalQAChain;
+  static useChain: string;
   memory: BufferWindowMemory;
   langChainParams: LangChainParams;
@@ -35,22 +58,23 @@ class AIState {
       returnMessages: true,
     });
 
-    this.createNewChain();
+    this.createNewChain(LLM_CHAIN);
   }
 
   clearChatMemory(): void {
     console.log('clearing chat memory');
     this.memory.clear();
-    this.createNewChain();
+    this.createNewChain(LLM_CHAIN);
+    AIState.useChain = LLM_CHAIN;
   }
 
   setModel(newModel: string): void {
     console.log('setting model to', newModel);
     this.langChainParams.model = newModel;
-    this.createNewChain();
+    this.createNewChain(LLM_CHAIN);
   }
 
-  createNewChain(): void {
+  createNewChain(chainType: string): void {
     const {
       key, model, temperature, maxTokens, systemMessage,
     } = this.langChainParams;
@@ -69,17 +93,47 @@
       streaming: true,
     });
 
+    this.setChain(chainType, {prompt: chatPrompt});
+  }
+
+  async setChain(
+    chainType: string,
+    options: SetChainOptions = {},
+  ): Promise<void> {
     // TODO: Use this once https://github.com/hwchase17/langchainjs/issues/1327 is resolved
-    AIState.chain = new ConversationChain({
-      llm: AIState.chatOpenAI,
-      memory: this.memory,
-      prompt: chatPrompt,
-    });
+    if (chainType === LLM_CHAIN && options.prompt) {
+      AIState.chain = ChainFactory.getLLMChain({
+        llm: AIState.chatOpenAI,
+        memory: this.memory,
+        prompt: options.prompt,
+      }) as ConversationChain;
+      AIState.useChain = LLM_CHAIN;
+      console.log('Set chain:', LLM_CHAIN);
+    } else if (chainType === CONVERSATIONAL_RETRIEVAL_QA_CHAIN && options.noteContent) {
+      const textSplitter = new RecursiveCharacterTextSplitter({ chunkSize: 1000 });
+      const docs = await textSplitter.createDocuments([options.noteContent]);
+      console.log('docs:', docs);
+      const vectorStore = await MemoryVectorStore.fromDocuments(
+        docs,
+        new HuggingFaceInferenceEmbeddings({
+          apiKey: this.langChainParams.huggingfaceApiKey,
+        }),
+      );
+      /* Create or retrieve the chain */
+      AIState.retrievalChain = ChainFactory.getRetrievalChain({
+        llm: AIState.chatOpenAI,
+        retriever: vectorStore.asRetriever(),
+      });
+      // Issue where conversational retrieval chain gives rephrased question
+      // when streaming: https://github.com/hwchase17/langchainjs/issues/754#issuecomment-1540257078
+      // Temp workaround triggers CORS issue 'refused to set header user-agent'
+      // Wait for official fix.
+      AIState.useChain = CONVERSATIONAL_RETRIEVAL_QA_CHAIN;
+      console.log('Set chain:', CONVERSATIONAL_RETRIEVAL_QA_CHAIN);
+    }
   }
 
   async countTokens(inputStr: string): Promise<number> {
     // TODO: This is currently falling back to an approximation. Follow up with LangchainJS:
     // https://github.com/hwchase17/langchainjs/issues/985
     return AIState.chatOpenAI.getNumTokens(inputStr);
   }
 
@@ -135,32 +189,55 @@ class AIState {
 
   async runChain(
     userMessage: string,
+    chatContext: ChatMessage[],
     abortController: AbortController,
     updateCurrentAiMessage: (message: string) => void,
     addMessage: (message: ChatMessage) => void,
     debug = false,
   ) {
-    if (debug) {
-      console.log('Chat memory:', this.memory);
-    }
     let fullAIResponse = '';
-    // TODO: chain.call stop signal gives error:
-    // "input values have 2 keys, you must specify an input key or pass only 1 key as input".
-    // Follow up with LangchainJS: https://github.com/hwchase17/langchainjs/issues/1327
-    await AIState.chain.call(
-      {
-        input: userMessage,
-        // signal: abortController.signal,
-      },
-      [
-        {
-          handleLLMNewToken: (token) => {
-            fullAIResponse += token;
-            updateCurrentAiMessage(fullAIResponse);
-          }
-        }
-      ]
-    );
+    switch(AIState.useChain) {
+      case LLM_CHAIN:
+        if (debug) {
+          console.log('Chat memory:', this.memory);
+        }
+        // TODO: chain.call stop signal gives error:
+        // "input values have 2 keys, you must specify an input key or pass only 1 key as input".
+        // Follow up with LangchainJS: https://github.com/hwchase17/langchainjs/issues/1327
+        await AIState.chain.call(
+          {
+            input: userMessage,
+            // signal: abortController.signal,
+          },
+          [
+            {
+              handleLLMNewToken: (token) => {
+                fullAIResponse += token;
+                updateCurrentAiMessage(fullAIResponse);
+              }
+            }
+          ]
+        );
+        break;
+      case CONVERSATIONAL_RETRIEVAL_QA_CHAIN:
+        await AIState.retrievalChain.call(
+          {
+            question: userMessage,
+            chat_history: chatContext,
+          },
+          [
+            {
+              handleLLMNewToken: (token) => {
+                fullAIResponse += token;
+                updateCurrentAiMessage(fullAIResponse);
+              }
+            }
+          ]
+        );
+        break;
+      default:
+        console.error('Chain type not supported:', AIState.useChain);
+    }
 
     addMessage({
       message: fullAIResponse,
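For reference, the retriever handed to the QA chain above is a plain top-k similarity search over the embedded chunks. A standalone sketch of what `vectorStore.asRetriever()` does underneath (the sample texts, metadata, and env var are made up):

```ts
import { HuggingFaceInferenceEmbeddings } from "langchain/embeddings/hf";
import { MemoryVectorStore } from "langchain/vectorstores/memory";

async function demo() {
  const store = await MemoryVectorStore.fromTexts(
    ["Obsidian stores notes as local Markdown files.", "Copilot adds a chat pane to Obsidian."],
    [{ id: 1 }, { id: 2 }],
    new HuggingFaceInferenceEmbeddings({ apiKey: process.env.HF_API_KEY }),
  );
  // k = 1: return only the closest chunk by cosine similarity.
  const hits = await store.similaritySearch("Where do my notes live?", 1);
  console.log(hits[0].pageContent);
}
```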
58 changes: 58 additions & 0 deletions src/chainFactory.ts
@@ -0,0 +1,58 @@
+import { BaseLanguageModel } from "langchain/base_language";
+import {
+  BaseChain,
+  ConversationChain,
+  ConversationalRetrievalQAChain,
+  LLMChainInput,
+} from "langchain/chains";
+import { BaseRetriever } from "langchain/schema";
+
+
+export interface ConversationalRetrievalChainParams {
+  llm: BaseLanguageModel;
+  retriever: BaseRetriever;
+  options?: {
+    questionGeneratorTemplate?: string;
+    qaTemplate?: string;
+    returnSourceDocuments?: boolean;
+  }
+}
+
+// Add new chain types here
+export const LLM_CHAIN = 'llm_chain';
+export const CONVERSATIONAL_RETRIEVAL_QA_CHAIN = 'conversational_retrieval_chain';
+export const SUPPORTED_CHAIN_TYPES = new Set([
+  LLM_CHAIN,
+  CONVERSATIONAL_RETRIEVAL_QA_CHAIN,
+]);
+
+class ChainFactory {
+  private static instances: Map<string, BaseChain> = new Map();
+
+  public static getLLMChain(args: LLMChainInput): BaseChain {
+    let instance = ChainFactory.instances.get(LLM_CHAIN);
+    if (!instance) {
+      instance = new ConversationChain(args as LLMChainInput);
+      console.log('New chain created: ', instance._chainType());
+      ChainFactory.instances.set(LLM_CHAIN, instance);
+    }
+    return instance;
+  }
+
+  public static getRetrievalChain(
+    args: ConversationalRetrievalChainParams
+  ): ConversationalRetrievalQAChain {
+    let instance = ChainFactory.instances.get(CONVERSATIONAL_RETRIEVAL_QA_CHAIN);
+    if (!instance) {
+      const argsRetrieval = args as ConversationalRetrievalChainParams;
+      instance = ConversationalRetrievalQAChain.fromLLM(
+        argsRetrieval.llm, argsRetrieval.retriever, argsRetrieval.options
+      );
+      console.log('New chain created: ', instance._chainType());
+      ChainFactory.instances.set(CONVERSATIONAL_RETRIEVAL_QA_CHAIN, instance);
+    }
+    return instance as ConversationalRetrievalQAChain;
+  }
+}
+
+export default ChainFactory;
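A small usage sketch of the factory (the keys and prompt are placeholders): repeated calls return the same cached instance, which is what makes the chains singletons. Note the cache is keyed by chain type alone, so constructor arguments passed on later calls are ignored once an instance exists.

```ts
import ChainFactory from "@/chainFactory";
import { ChatOpenAI } from "langchain/chat_models/openai";
import { BufferWindowMemory } from "langchain/memory";
import {
  ChatPromptTemplate,
  HumanMessagePromptTemplate,
} from "langchain/prompts";

const llm = new ChatOpenAI({ openAIApiKey: process.env.OPENAI_API_KEY });
const memory = new BufferWindowMemory({ k: 3, returnMessages: true });
const prompt = ChatPromptTemplate.fromPromptMessages([
  HumanMessagePromptTemplate.fromTemplate("{input}"),
]);

const chainA = ChainFactory.getLLMChain({ llm, memory, prompt });
const chainB = ChainFactory.getLLMChain({ llm, memory, prompt });
console.log(chainA === chainB); // true: the factory returned the cached chain
```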
23 changes: 21 additions & 2 deletions src/components/Chat.tsx
@@ -28,10 +28,10 @@ import {
   simplifyPrompt,
   summarizePrompt,
   tocPrompt,
-  useNoteAsContextPrompt
+  useNoteAsContextPrompt,
 } from '@/utils';
 import { EventEmitter } from 'events';
-import { TFile } from 'obsidian';
+import { Notice, TFile } from 'obsidian';
 import React, {
   useContext,
   useEffect,
@@ -124,12 +124,31 @@ const Chat: React.FC<ChatProps> = ({
 
     const file = app.workspace.getActiveFile();
     if (!file) {
+      new Notice('No active note found.');
       console.error('No active note found.');
       return;
     }
     const noteContent = await getFileContent(file);
     const noteName = getFileName(file);
 
+    /* TODO: Make a switch for unlimited context search, on and off. When turned on, this
+       message is shown in both notice and console: Unlimited Context Enabled!
+    */
+    // const activeNoteOnMessage: ChatMessage = {
+    //   sender: AI_SENDER,
+    //   message: `OK please ask me questions about [[${noteName}]]`,
+    //   isVisible: true,
+    // };
+    // addMessage(activeNoteOnMessage);
+
+    // if (noteContent) {
+    //   aiState.setChain(CONVERSATIONAL_RETRIEVAL_QA_CHAIN, { noteContent });
+    // } else {
+    //   new Notice('No note content found.');
+    //   console.error('No note content found.');
+    //   return;
+    // }
+
     // Set the context based on the noteContent
     const prompt = useNoteAsContextPrompt(noteName, noteContent);
 
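The commented-out block above is the activation path for the new TODO: with an "unlimited context" switch turned on, the note would go through the retrieval chain instead of being stuffed into the prompt. One possible shape, hypothetical and not part of this commit (the function name, Notice wording, and structural `aiState` type are illustrative):

```ts
import { Notice } from 'obsidian';
import { CONVERSATIONAL_RETRIEVAL_QA_CHAIN } from '@/chainFactory';

async function enableUnlimitedContext(
  aiState: { setChain(type: string, opts: { noteContent: string }): Promise<void> },
  noteContent: string,
): Promise<void> {
  if (!noteContent) {
    new Notice('No note content found.');
    console.error('No note content found.');
    return;
  }
  new Notice('Unlimited Context Enabled!');
  console.log('Unlimited Context Enabled!');
  await aiState.setChain(CONVERSATIONAL_RETRIEVAL_QA_CHAIN, { noteContent });
}
```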
1 change: 1 addition & 0 deletions src/constants.ts
@@ -6,6 +6,7 @@ export const AI_SENDER = 'ai';
 export const DEFAULT_SYSTEM_PROMPT = 'You are Obsidian Copilot, a helpful assistant that integrates AI to Obsidian note-taking.';
 export const DEFAULT_SETTINGS: CopilotSettings = {
   openAiApiKey: '',
+  huggingfaceApiKey: '',
   defaultModel: 'gpt-3.5-turbo',
   temperature: '0.7',
   maxTokens: '1000',
10 changes: 10 additions & 0 deletions src/langchainStream.ts
@@ -37,7 +37,17 @@ export const getAIResponse = async (
       abortController,
       updateCurrentAiMessage,
       addMessage,
+      debug,
     );
+
+    // await aiState.runChain(
+    //   userMessage.message,
+    //   chatContext,
+    //   abortController,
+    //   updateCurrentAiMessage,
+    //   addMessage,
+    //   debug,
+    // );
   } catch (error) {
     const errorData = error?.response?.data?.error || error;
     const errorCode = errorData?.code || error;
2 remaining files not shown.
