Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Mermaid.js #1054

Merged
merged 9 commits into from
Jan 15, 2025
15 changes: 11 additions & 4 deletions src/interface/obsidian/src/chat_view.ts
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,8 @@ export class KhojChatView extends KhojPaneView {
inferredQueries?: string[],
conversationId?: string,
images?: string[],
excalidrawDiagram?: string
excalidrawDiagram?: string,
mermaidjsDiagram?: string
) {
if (!message) return;

Expand All @@ -496,8 +497,9 @@ export class KhojChatView extends KhojPaneView {
intentType?.includes("text-to-image") ||
intentType === "excalidraw" ||
(images && images.length > 0) ||
mermaidjsDiagram ||
excalidrawDiagram) {
let imageMarkdown = this.generateImageMarkdown(message, intentType ?? "", inferredQueries, conversationId, images, excalidrawDiagram);
let imageMarkdown = this.generateImageMarkdown(message, intentType ?? "", inferredQueries, conversationId, images, excalidrawDiagram, mermaidjsDiagram);
chatMessageEl = this.renderMessage(chatEl, imageMarkdown, sender, dt);
} else {
chatMessageEl = this.renderMessage(chatEl, message, sender, dt);
Expand All @@ -517,7 +519,7 @@ export class KhojChatView extends KhojPaneView {
chatMessageBodyEl.appendChild(this.createReferenceSection(references));
}

generateImageMarkdown(message: string, intentType: string, inferredQueries?: string[], conversationId?: string, images?: string[], excalidrawDiagram?: string): string {
generateImageMarkdown(message: string, intentType: string, inferredQueries?: string[], conversationId?: string, images?: string[], excalidrawDiagram?: string, mermaidjsDiagram?: string): string {
let imageMarkdown = "";
if (intentType === "text-to-image") {
imageMarkdown = `![](data:image/png;base64,${message})`;
Expand All @@ -529,6 +531,8 @@ export class KhojChatView extends KhojPaneView {
const domain = this.setting.khojUrl.endsWith("/") ? this.setting.khojUrl : `${this.setting.khojUrl}/`;
const redirectMessage = `Hey, I'm not ready to show you diagrams yet here. But you can view it in ${domain}chat?conversationId=${conversationId}`;
imageMarkdown = redirectMessage;
} else if (mermaidjsDiagram) {
imageMarkdown = "```mermaid\n" + mermaidjsDiagram + "\n```";
} else if (images && images.length > 0) {
for (let image of images) {
if (image.startsWith("https://")) {
Expand Down Expand Up @@ -908,6 +912,7 @@ export class KhojChatView extends KhojPaneView {
chatBodyEl.dataset.conversationId ?? "",
chatLog.images,
chatLog.excalidrawDiagram,
chatLog.mermaidjsDiagram,
);
// push the user messages to the chat history
if (chatLog.by === "you") {
Expand Down Expand Up @@ -1012,7 +1017,7 @@ export class KhojChatView extends KhojPaneView {
}

handleJsonResponse(jsonData: any): void {
if (jsonData.image || jsonData.detail || jsonData.images || jsonData.excalidrawDiagram) {
if (jsonData.image || jsonData.detail || jsonData.images || jsonData.mermaidjsDiagram) {
this.chatMessageState.rawResponse = this.handleImageResponse(jsonData, this.chatMessageState.rawResponse);
} else if (jsonData.response) {
this.chatMessageState.rawResponse = jsonData.response;
Expand Down Expand Up @@ -1407,6 +1412,8 @@ export class KhojChatView extends KhojPaneView {
const domain = this.setting.khojUrl.endsWith("/") ? this.setting.khojUrl : `${this.setting.khojUrl}/`;
const redirectMessage = `Hey, I'm not ready to show you diagrams yet here. But you can view it in ${domain}`;
rawResponse += redirectMessage;
} else if (imageJson.mermaidjsDiagram) {
rawResponse += imageJson.mermaidjsDiagram;
}

// If response has detail field, response is an error message.
Expand Down
2 changes: 1 addition & 1 deletion src/interface/obsidian/styles.css
Original file line number Diff line number Diff line change
Expand Up @@ -858,4 +858,4 @@ img.copy-icon {
100% {
transform: rotate(360deg);
}
}
}
6 changes: 3 additions & 3 deletions src/interface/web/app/common/chatFunctions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ export interface MessageMetadata {

export interface GeneratedAssetsData {
images: string[];
excalidrawDiagram: string;
mermaidjsDiagram: string;
files: AttachedFileText[];
}

Expand Down Expand Up @@ -114,8 +114,8 @@ export function processMessageChunk(
currentMessage.generatedImages = generatedAssets.images;
}

if (generatedAssets.excalidrawDiagram) {
currentMessage.generatedExcalidrawDiagram = generatedAssets.excalidrawDiagram;
if (generatedAssets.mermaidjsDiagram) {
currentMessage.generatedMermaidjsDiagram = generatedAssets.mermaidjsDiagram;
}

if (generatedAssets.files) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ export default function ChatHistory(props: ChatHistoryProps) {
conversationId: props.conversationId,
images: message.generatedImages,
queryFiles: message.generatedFiles,
excalidrawDiagram: message.generatedExcalidrawDiagram,
mermaidjsDiagram: message.generatedMermaidjsDiagram,
turnId: messageTurnId,
}}
conversationId={props.conversationId}
Expand Down
11 changes: 11 additions & 0 deletions src/interface/web/app/components/chatMessage/chatMessage.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ import { DialogTitle } from "@radix-ui/react-dialog";
import { convertBytesToText } from "@/app/common/utils";
import { ScrollArea } from "@/components/ui/scroll-area";
import { getIconFromFilename } from "@/app/common/iconUtils";
import Mermaid from "../mermaid/mermaid";

const md = new markdownIt({
html: true,
Expand Down Expand Up @@ -164,6 +165,7 @@ export interface SingleChatMessage {
turnId?: string;
queryFiles?: AttachedFileText[];
excalidrawDiagram?: string;
mermaidjsDiagram?: string;
}

export interface StreamMessage {
Expand All @@ -182,9 +184,11 @@ export interface StreamMessage {
turnId?: string;
queryFiles?: AttachedFileText[];
excalidrawDiagram?: string;
mermaidjsDiagram?: string;
generatedFiles?: AttachedFileText[];
generatedImages?: string[];
generatedExcalidrawDiagram?: string;
generatedMermaidjsDiagram?: string;
}

export interface ChatHistoryData {
Expand Down Expand Up @@ -271,6 +275,7 @@ interface ChatMessageProps {
turnId?: string;
generatedImage?: string;
excalidrawDiagram?: string;
mermaidjsDiagram?: string;
generatedFiles?: AttachedFileText[];
}

Expand Down Expand Up @@ -358,6 +363,7 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
const [isPlaying, setIsPlaying] = useState<boolean>(false);
const [interrupted, setInterrupted] = useState<boolean>(false);
const [excalidrawData, setExcalidrawData] = useState<string>("");
const [mermaidjsData, setMermaidjsData] = useState<string>("");

const interruptedRef = useRef<boolean>(false);
const messageRef = useRef<HTMLDivElement>(null);
Expand Down Expand Up @@ -401,6 +407,10 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
setExcalidrawData(props.chatMessage.excalidrawDiagram);
}

if (props.chatMessage.mermaidjsDiagram) {
setMermaidjsData(props.chatMessage.mermaidjsDiagram);
}

// Replace LaTeX delimiters with placeholders
message = message
.replace(/\\\(/g, "LEFTPAREN")
Expand Down Expand Up @@ -718,6 +728,7 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
dangerouslySetInnerHTML={{ __html: markdownRendered }}
/>
{excalidrawData && <ExcalidrawComponent data={excalidrawData} />}
{mermaidjsData && <Mermaid chart={mermaidjsData} />}
</div>
<div className={styles.teaserReferencesContainer}>
<TeaserReferencesSection
Expand Down
47 changes: 35 additions & 12 deletions src/interface/web/app/components/loading/loading.tsx
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
import { SidebarInset, SidebarProvider, SidebarTrigger } from "@/components/ui/sidebar";
import { CircleNotch } from "@phosphor-icons/react";
import { AppSidebar } from "../appSidebar/appSidebar";
import { Separator } from "@/components/ui/separator";
import { useIsMobileWidth } from "@/app/common/utils";
import { KhojLogoType } from "../logo/khojLogo";

interface LoadingProps {
className?: string;
Expand All @@ -7,21 +12,39 @@ interface LoadingProps {
}

export default function Loading(props: LoadingProps) {
const isMobileWidth = useIsMobileWidth();

return (
// NOTE: We can display usage tips here for casual learning moments.
<div
className={
props.className ||
"bg-background opacity-50 flex items-center justify-center h-screen"
}
>
<div>
{props.message || "Loading"}{" "}
<span>
<CircleNotch className="inline animate-spin h-5 w-5" />
</span>
<SidebarProvider>
<AppSidebar conversationId={""} />
<SidebarInset>
<header className="flex h-16 shrink-0 items-center gap-2 border-b px-4">
<SidebarTrigger className="-ml-1" />
<Separator orientation="vertical" className="mr-2 h-4" />
{isMobileWidth ? (
<a className="p-0 no-underline" href="/">
<KhojLogoType className="h-auto w-16" />
</a>
) : (
<h2 className="text-lg">Ask Anything</h2>
)}
</header>
</SidebarInset>
<div
className={
props.className ||
"bg-background opacity-50 flex items-center justify-center h-full w-full fixed top-0 left-0 z-50"
}
>
<div>
{props.message || "Loading"}{" "}
<span>
<CircleNotch className="inline animate-spin h-5 w-5" />
</span>
</div>
</div>
</div>
</SidebarProvider>
);
}

Expand Down
173 changes: 173 additions & 0 deletions src/interface/web/app/components/mermaid/mermaid.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
import React, { useEffect, useState, useRef } from "react";
import mermaid from "mermaid";
import { Download, Info } from "@phosphor-icons/react";
import { Button } from "@/components/ui/button";

interface MermaidProps {
chart: string;
}

const Mermaid: React.FC<MermaidProps> = ({ chart }) => {
const [mermaidError, setMermaidError] = useState<string | null>(null);
const [mermaidId] = useState(`mermaid-chart-${Math.random().toString(12).substring(7)}`);
const elementRef = useRef<HTMLDivElement>(null);

useEffect(() => {
mermaid.initialize({
startOnLoad: false,
});

mermaid.parseError = (error) => {
console.error("Mermaid errors:", error);
// Extract error message from error object
// Parse error message safely
let errorMessage;
try {
errorMessage = typeof error === "string" ? JSON.parse(error) : error;
} catch (e) {
errorMessage = error?.toString() || "Unknown error";
}

console.log("Mermaid error message:", errorMessage);

if (errorMessage.str !== "element is null") {
setMermaidError(
"Something went wrong while rendering the diagram. Please try again later or downvote the message if the issue persists.",
);
} else {
setMermaidError(null);
}
};

mermaid.contentLoaded();
}, []);

const handleExport = async () => {
if (!elementRef.current) return;

try {
// Get SVG element
const svgElement = elementRef.current.querySelector("svg");
if (!svgElement) throw new Error("No SVG found");

// Get SVG viewBox dimensions
const viewBox = svgElement.getAttribute("viewBox")?.split(" ").map(Number) || [
0, 0, 0, 0,
];
const [, , viewBoxWidth, viewBoxHeight] = viewBox;

// Create canvas with viewBox dimensions
const canvas = document.createElement("canvas");
const scale = 2; // For better resolution
canvas.width = viewBoxWidth * scale;
canvas.height = viewBoxHeight * scale;
const ctx = canvas.getContext("2d");
if (!ctx) throw new Error("Failed to get canvas context");

// Convert SVG to data URL
const svgData = new XMLSerializer().serializeToString(svgElement);
const svgBlob = new Blob([svgData], { type: "image/svg+xml;charset=utf-8" });
const svgUrl = URL.createObjectURL(svgBlob);

// Create and load image
const img = new Image();
img.src = svgUrl;

await new Promise((resolve, reject) => {
img.onload = () => {
// Scale context for better resolution
ctx.scale(scale, scale);
ctx.drawImage(img, 0, 0, viewBoxWidth, viewBoxHeight);

canvas.toBlob((blob) => {
if (!blob) {
reject(new Error("Failed to create blob"));
return;
}

const url = URL.createObjectURL(blob);
const a = document.createElement("a");
a.href = url;
a.download = `mermaid-diagram-${Date.now()}.png`;
a.click();

// Cleanup
URL.revokeObjectURL(url);
URL.revokeObjectURL(svgUrl);
resolve(true);
}, "image/png");
};

img.onerror = () => reject(new Error("Failed to load SVG"));
});
} catch (error) {
console.error("Error exporting diagram:", error);
setMermaidError("Failed to export diagram");
}
};

useEffect(() => {
if (elementRef.current) {
elementRef.current.removeAttribute("data-processed");

mermaid
.run({
nodes: [elementRef.current],
})
.then(() => {
setMermaidError(null);
})
.catch((error) => {
let errorMessage;
try {
errorMessage = typeof error === "string" ? JSON.parse(error) : error;
} catch (e) {
errorMessage = error?.toString() || "Unknown error";
}

console.log("Mermaid error message:", errorMessage);

if (errorMessage.str !== "element is null") {
setMermaidError(
"Something went wrong while rendering the diagram. Please try again later or downvote the message if the issue persists.",
);
} else {
setMermaidError(null);
}
});
}
}, [chart]);

return (
<div>
{mermaidError ? (
<div className="flex items-center gap-2 bg-red-100 border border-red-500 rounded-md p-3 mt-3 text-red-900 text-sm">
<Info className="w-12 h-12" />
<span>Error rendering diagram: {mermaidError}</span>
</div>
) : (
<div
id={mermaidId}
ref={elementRef}
className="mermaid"
style={{
width: "auto",
height: "auto",
boxSizing: "border-box",
overflow: "auto",
}}
>
{chart}
</div>
)}
{!mermaidError && (
<Button onClick={handleExport} variant={"secondary"} className="mt-3">
<Download className="w-5 h-5" />
Export as PNG
</Button>
)}
</div>
);
};

export default Mermaid;
Loading
Loading