-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathvisualizer.py
114 lines (94 loc) · 3.91 KB
/
visualizer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import base64
import io
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import plotly.express as px
import json
import openai
import os
import re
from datetime import datetime
# Render with the non-interactive Agg backend: create_plot only ever draws into
# an in-memory PNG buffer via plt.savefig, so no GUI/display is needed.
matplotlib.use(backend="Agg")
def convert_ndarray(obj):
    """Recursively convert *obj* into JSON-serializable builtin Python values.

    Conversions applied:

    * ``numpy.ndarray`` -> nested ``list`` (the elements are converted too, so
      object arrays holding ``datetime`` values are handled);
    * NumPy scalars (``np.int64``, ``np.bool_``, ...) -> the matching builtin;
    * ``dict`` -> ``dict`` with converted values (keys are left untouched);
    * ``list`` / ``tuple`` -> ``list`` of converted elements;
    * ``datetime`` -> ISO-8601 string;
    * anything else is returned unchanged.

    Args:
        obj: Value to convert, e.g. a piece of a Plotly ``fig.to_dict()``.

    Returns:
        The converted value.
    """
    if isinstance(obj, np.ndarray):
        # tolist() yields builtins, but object arrays can still contain datetime
        # instances, so run the result through the converter again.
        return convert_ndarray(obj.tolist())
    if isinstance(obj, np.generic):
        # NumPy scalars such as np.int64 / np.bool_ are rejected by json.dumps.
        return convert_ndarray(obj.item())
    if isinstance(obj, dict):
        return {key: convert_ndarray(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [convert_ndarray(item) for item in obj]
    if isinstance(obj, datetime):
        return obj.isoformat()
    return obj
def create_plot(python_code):
    """Execute generated plotting code and package the figure it produces.

    The snippet is run with ``exec`` in a fresh namespace. Afterwards:

    * if any Matplotlib figure is open (plain Matplotlib, seaborn, pandas
      plotting, ...), the current figure is rendered to PNG and returned as
      ``{"type": "image", "data": "data:image/png;base64,..."}``;
    * otherwise, if the snippet bound a Plotly figure to the name ``fig``, its
      traces and layout are returned as
      ``{"type": "plotly", "data": [...], "layout": {...}}``;
    * otherwise ``{"type": "text", "data": python_code}`` is returned.

    Args:
        python_code: Source code of the plotting snippet.

    Returns:
        A JSON-friendly result dict as described above.

    Raises:
        Any exception raised while executing ``python_code`` propagates
        unchanged (Matplotlib figures are still closed first).
    """
    # SECURITY NOTE: python_code is model-generated and exec() runs it with full
    # interpreter privileges. Sandbox it if users can influence the prompt.
    #
    # One dict serves as both globals and locals. With separate dicts, names the
    # snippet binds at top level (imports, variables) are not visible inside the
    # functions it defines -- nor, before Python 3.12, inside its comprehensions --
    # so perfectly ordinary code fails with NameError.
    namespace = {}
    try:
        exec(python_code, namespace)

        if plt.get_fignums():
            buffer = io.BytesIO()
            plt.savefig(buffer, format="png")
            encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
            return {"type": "image", "data": f"data:image/png;base64,{encoded}"}

        fig = namespace.get("fig")
        # Duck-type a Plotly figure (works for plotly.express and graph_objects);
        # a Matplotlib Figure or a DataFrame does not have to_plotly_json().
        if fig is not None and hasattr(fig, "to_plotly_json"):
            fig_dict = fig.to_dict()  # build the dict once, not once per key
            return {
                "type": "plotly",
                "data": convert_ndarray(fig_dict["data"]),
                "layout": convert_ndarray(fig_dict["layout"]),
            }

        return {"type": "text", "data": python_code}
    finally:
        # Close every pyplot figure -- including those left open when the
        # snippet raised -- so stale figures cannot leak into the next call.
        # NOTE(review): pyplot state is process-global; concurrent callers
        # sharing it are not safe either way.
        plt.close("all")
def process_response(response_text):
    """Run the Python code fenced in *response_text* and return its rendering.

    Every ```python fenced block is extracted (any surrounding prose is
    dropped), ``plt.show()`` / ``fig.show()`` calls are neutralised, and the
    resulting code is passed to ``create_plot``.

    Args:
        response_text: Text that may contain ```python fenced code.

    Returns:
        * ``{"type": "text", "data": response_text}`` when there is no
          ```python fence;
        * the dict returned by ``create_plot`` on success;
        * ``{"type": "error", "data": "Error during execution: ..."}`` if
          execution raised.
    """
    if "```python" not in response_text:
        return {"type": "text", "data": response_text}

    blocks = re.findall(r"```python(.*?)```", response_text, re.DOTALL)
    if blocks:
        # Execute only what is inside the fences; prose around them would be
        # a SyntaxError. Multiple blocks are run one after another.
        code = "\n".join(block.strip() for block in blocks)
    else:
        # Unterminated fence (e.g. a truncated reply): just strip the markers.
        code = response_text.replace("```python", "").replace("```", "").strip()

    # Replace show() calls with the no-op expression None instead of deleting
    # them, so a call that was the only statement of a block (or a lambda body)
    # still leaves valid syntax behind.
    code = code.replace("plt.show()", "None").replace("fig.show()", "None")

    try:
        return create_plot(code)
    except Exception as e:
        return {"type": "error", "data": f"Error during execution: {e}"}
# Chart types the model may pick from; interpolated into the system prompt.
# NOTE(review): "Word Cloud", "Venn Diagram" and "Flow Chart" typically need
# packages that are not in valid_packages -- confirm they are achievable.
valid_visualizers = [
    "Line Plot",
    "Bar Chart",
    "Histogram",
    "Scatter Plot",
    "Box Plot",
    "Heatmap",
    "Pie Chart",
    "Area Chart",
    "Violin Plot",
    "Pair Plot",
    "Timeline Plot",
    "Word Cloud",
    "Venn Diagram",
    "Flow Chart",
]

# Packages the generated code is allowed to use (also listed in the prompt).
# These are distribution names; scikit-learn is imported as sklearn.
valid_packages = [
    'matplotlib',
    'plotly.express',
    'pandas',
    'numpy',
    'scikit-learn',
    'seaborn',
    'statsmodels',
]
# Read the API key from the environment once. It is stored on the module-level
# openai.api_key attribute and also handed to the client explicitly; an unset
# variable yields None, in which case OpenAI() falls back to OPENAI_API_KEY itself.
openai.api_key = os.getenv("OPENAI_API_KEY")
client = openai.OpenAI(api_key=openai.api_key)
# System prompt sent with every chat completion in get_visualizations_json.
# Adjacent string literals are concatenated with no separator, so each fragment
# must end with a space (the original ran sentences together, e.g. "things.Summarize").
# The last sentence states the contract the parsing code relies on: the code is
# extracted from a ```python fence, and a Plotly figure is looked up under `fig`.
message = (
    "You are an assistant that can only do two things. "
    "Summarize the following text, fixing any spelling mistakes and making it a complete sentence. "
    "Do not mention that this is a prompt or request. Use only the information provided: content "
    "and say it as a description of visualization that will be generated. "
    "Do not mention about the python code. "
    "Second. You must return Python code that properly generates one of these types of graphs: "
    f"{', '.join(valid_visualizers)} and you can only use these packages for creating the graph "
    f"{', '.join(valid_packages)}. "
    "You must choose the most appropriate visualization to create given the following prompt. "
    "If possible create interactive plotly graphs, but we want a mix of plotly and other types of graphs. "
    "Put the code in a single ```python fenced code block and assign any plotly figure "
    "to a variable named fig."
)
def get_visualizations_json(data):
    """Ask the chat model for a summary plus plotting code and render the result.

    Args:
        data: Sequence whose first element holds the user prompt, either as a
            single string or as a sequence of parts joined with single spaces.
            (NOTE(review): the exact shape callers pass is not visible here --
            confirm.)

    Returns:
        A one-element list containing the dict produced by ``process_response``
        (``type``/``data``, plus ``layout`` for Plotly output), extended with a
        ``summary`` key holding the model text that precedes the code fence
        (the whole reply if there is no ```python fence).

    Raises:
        RuntimeError: if the generated code failed to execute.
        IndexError: if *data* is empty.
        Errors raised by the OpenAI client propagate unchanged.
    """
    prompt_parts = data[0]
    if isinstance(prompt_parts, str):
        # ' '.join() on a bare string would put a space between every character.
        prompt = prompt_parts
    else:
        # str() lets non-string parts (numbers, etc.) through instead of a TypeError.
        prompt = " ".join(str(part) for part in prompt_parts)

    completion = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {"role": "system", "content": message},
            {"role": "user", "content": prompt},
        ],
    )
    # The message content is optional in the API response (e.g. on a refusal);
    # treat a missing body as an empty reply instead of crashing re.search.
    response = completion.choices[0].message.content or ""

    # Split the reply into the prose summary before the code and the fenced code.
    code_match = re.search(r"```python(.*?)```", response, re.DOTALL)
    if code_match:
        summary = response[:code_match.start()].strip()
        code = code_match.group(0).strip()
    else:
        summary = response
        code = ""

    visualization_result = process_response(code)
    if visualization_result["type"] == "error":
        # RuntimeError subclasses Exception, so existing `except Exception` still works.
        raise RuntimeError(visualization_result["data"])
    visualization_result["summary"] = summary
    return [visualization_result]