-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfrontend.py
493 lines (416 loc) · 17.4 KB
/
frontend.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
# Imports
import streamlit as st
import streamlit.components.v1 as components
import requests
from bs4 import BeautifulSoup
from snowflake.snowpark.session import Session
import json
import uuid
import base64
from PIL import Image
import os
from io import BytesIO
from PIL import Image
import openai
import time
from textblob import TextBlob
import itertools
import io
import nltk
import pandas as pd
import numpy as np
import plotly.express as px
nltk.download('punkt')
nltk.download('averaged_perceptron_tagger')
# Environment connections
connection_parameters = json.load(open('connection.json'))
session = Session.builder.configs(connection_parameters).create()
api_key = st.secrets["OPEN_AI_API"]
openai.api_key = api_key
amzn_api_key = "5EB311E772024A59BB861F16A6DC7B67"
api_url = 'https://4t6zjqbvra77cljo42ody4xtmi0smplq.lambda-url.us-east-1.on.aws/'
# Local variables
progress_bar = st.progress(0)
interests = []
posts = []
ad_links = []
user_exists = False
usern = ""
likes = []
st.session_state['start_idx'] = 0
st.session_state['end_idx'] = 10
st.session_state['ads'] = []
# Download image function
if 'not_uploaded' not in st.session_state:
st.session_state.not_uploaded = False
def download_image(image_url, save_folder):
response = requests.get(image_url)
image = Image.open(BytesIO(response.content)).convert('RGB')
# os.mkdir(save_folder)
file_path = os.path.join(save_folder, 'image.jpg')
image.save(file_path)
print("Image Downloaded to local")
return file_path
# Get caption -> caption string
def get_caption_snowpark(image_path):
session.file.put("image.jpg", '@dash_models',
overwrite=True, auto_compress=False)
session.add_packages(["transformers", "Pillow"])
print("Uploaded Image to Snowflake")
# Import Model & Image to Snowpark
directory = ['config.json',
'events.out.tfevents.1633443513.t1v-n-bb5dfd23-w-0.8655.0.v2',
'flax_model.msgpack',
'jaxlib-0.4.7+cuda11.cudnn86-cp39-cp39-manylinux2014_x86_64.whl',
'jaxlib-0.4.7-cp311-cp311-manylinux2014_x86_64.whl',
'preprocessor_config.json',
'pytorch_model.bin',
'special_tokens_map (1).json',
'special_tokens_map.json',
'tokenizer.json',
'tokenizer_config.json',
'vocab.json']
progress_bar.progress(40)
for filename in directory:
session.add_import('@dash_models/'+filename)
session.add_import('@dash_models/image.jpg')
print("Imported Image into Snowflake model")
print("Running model")
# Call Model with UDF
progress_bar.progress(50)
predicted_label = session.sql(
'''SELECT image_caption_generator()''').collect()
generated_caption = predicted_label[0][0]
print("Generated Caption - ", generated_caption)
progress_bar.progress(70)
return generated_caption
# Nouns extract -> nouns list
def extract_nouns(caption):
text = caption
blob = TextBlob(text)
nouns = [w for (w, pos) in blob.pos_tags if pos[0] == 'N']
print("Extracted nouns", nouns)
return nouns, caption
# Product recommendations extract -> recommends list
def get_product_recommendations(nouns, caption):
concatinated_noun_str = {'nouns': ','.join(nouns)}
emotion_recognition = caption
prompt_instructions = f"Now Give some potential market product recommendations for the following items or the emotion you recognized: '{concatinated_noun_str}'"
result_instructions = ", Result Intructions : give a single string seperated by commas, type nothing else."
product_recommendations = []
response = openai.Completion.create(
prompt=emotion_recognition + prompt_instructions + result_instructions,
engine="text-davinci-003",
max_tokens=1024,
n=1,
stop=None
)
product_recommendations = response['choices'][0]['text']
product_recommendations = product_recommendations.strip().split(',')
print("Generated Recommends - ", product_recommendations)
return product_recommendations
# Function connector -> nouns list , caption string , product_recommendations list
def generate_recommends(image_url):
save_folder = 'images'
image_path = 'image.jpg' #download_image(image_url, save_folder)
caption = get_caption_snowpark(image_path)
nouns, caption = extract_nouns(caption)
product_recommendations = get_product_recommendations(nouns, caption)
os.remove(image_path)
return nouns, caption, product_recommendations
# Function to extract Product Title
# Ad generate -> ads list on product recommends
# set up the request parameters
# Amazon Scrapping
def find_product_by_search_keyword(search_term):
params = {
'api_key': amzn_api_key,
'type': 'search',
'amazon_domain': 'amazon.com',
'search_term': search_term,
'sort_by': 'price_low_to_high',
'page': 1
}
api_result = requests.get('https://api.rainforestapi.com/request', params)
return json.dumps(api_result.json())
# Ad generate -> ads list on product recommends
def get_product_ads(products):
output_ads = []
ctr = 0
ix = 0
# Loop over product recommends
for product in products:
if(ix>3):
break
ix+=1
results = find_product_by_search_keyword(product)
res_json = json.loads(results)
expected_output = res_json["search_results"][:1]
emp_li = []
ctr+=5
progress_bar.progress(70+ctr)
print("Products links for ", product)
for li in expected_output:
try:
output_ads.append({"title": li['title'], "price": li['price']['raw'],
"product_image": li['image'], "link": li['link']})
except:
pass
print("Ads generated", output_ads)
return output_ads
# Function connector -> nouns list , caption string , ads list based on s3 image_url
def generate_product_ads_for_url(image_url):
print("Initiated request")
nouns, caption, product_recommendations = generate_recommends(image_url)
output_ads = get_product_ads(product_recommendations)
progress_bar.progress(100)
return nouns, caption, product_recommendations, output_ads
# Get User
def get_user(username):
data = {'method': 'LOGIN_USER', 'userpk': 'user#'+username}
res = requests.post(api_url, json=data)
if res.status_code == 200:
interests.append(res.json()['interests'])
return True
else:
return False
# Register User
def register_user(username):
if (get_user(username)):
st.error('Username Taken')
return False
else:
data = {'method': 'REGISTER_USER', 'userpk': 'user#'+username}
res = requests.post(api_url, json=data)
if res.status_code == 200:
st.success('User Added')
return True
# Get Posts paginated
def get_posts(start_idx, end_idx):
posts = st.session_state.get('posts')
return posts[start_idx:end_idx]
# List Posts
def load_posts():
data = {'method': 'LIST_POSTS'}
res = requests.post(api_url, json=data)
posts = []
for post in res.json():
posts.append(post)
try:
sorted_list = sorted(posts, key=lambda x: x["uploadedAt"],reverse=True)
st.session_state['posts'] = sorted_list
print("Sorting")
return sorted_list
except:
st.session_state['posts'] = posts
return posts
return posts
# Update Posts
def update_post(caption, pr, ads, postsk):
data = {
"method": "POST_RECOGNITION",
"postsk": postsk,
"caption": caption,
"ad_links": ads,
"recommends": pr
}
res = requests.post(api_url, json=data)
print("Updated Post")
# Append to ads
def gen_ads_on_likes(post):
print(post['ad_links']['N'])
st.session_state['ads'] = post['ad_links']['N']
print("Updated ads")
# Upload Image
def upload_file(image, filename, username):
data = {'method': 'UPLOAD_POST', 'image': image, 'image_filename': filename,
'postsk': filename, 'recognitions': [], 'username': username}
res = requests.post(api_url, json=data)
st.success("Image Uploaded")
progress_bar.progress(20)
post_item = res.json()
recognitions, caption, pr, ads = generate_product_ads_for_url(
post_item['image_link'])
print("Image recognitions and recommendations", recognitions, caption)
st.session_state['ads'] = ads
update_post(caption, pr, ads, filename)
load_posts()
print("Loaded posts")
# Convert base64
def get_base64(image):
# Open the image with Pillow
img = Image.open(image)
img = img.convert('RGB')
# Convert the image to a bytes object
with BytesIO() as output:
img.save(output, format='JPEG')
contents = output.getvalue()
# Encode the bytes object as a base64 string and return it
encoded = base64.b64encode(contents).decode('utf-8')
return encoded
# Main app
def main_app():
st.title('SocialLens')
st.write('Your story, in pixels')
page = st.sidebar.radio(
'Navigation', ['Sign Up', 'Log In', 'Home','Data Analysis'])
if page == 'Sign Up':
new_username = st.text_input('New Username')
if st.button('Sign Up'):
register_user(new_username)
elif page == 'Log In':
username = st.text_input('Username')
if st.button('Log In'):
if get_user(username):
st.session_state['user_exists'] = True
st.session_state['username'] = username
st.success('Logged in successfully')
st.sidebar.success(f'Logged in as {username}')
else:
st.session_state['user_exists'] = False
st.error('Invalid username')
elif page == 'Home':
start_idx = st.session_state.get('start_idx')
end_idx = st.session_state.get('end_idx')
username = st.session_state.get('username')
if st.session_state.get('user_exists'):
st.success(f'Welcome, {username}!')
# if st.button("Make a Post"):
uploaded_file = st.file_uploader(
"Choose an image file", accept_multiple_files=False, type=['jpg', 'jpeg', 'png'])
if uploaded_file is not None and not(st.session_state.get("not_uploaded")):
img_data = uploaded_file.read()
img = Image.open(io.BytesIO(img_data)).convert("RGB")
converted_img_data = io.BytesIO()
img.save(converted_img_data, format="JPEG")
file_path = os.path.join('image.jpg')
with open(file_path, "wb") as f:
f.write(converted_img_data.getvalue())
print("here")
file_name = 'img_' + str(uuid.uuid4()) + ".jpg"
print(username)
upload_file(get_base64(
uploaded_file), file_name, username)
st.progress(0)
st.session_state.not_uploaded = True
st.write('Here are some recent posts:')
load_posts()
posts_to_display = get_posts(start_idx, end_idx)
for post in posts_to_display:
st.image(post['image_link'], use_column_width=True)
st.write(f'Posted by @{post["username"]}')
# st.write(f'Likes: {post["likes"]}')
if st.button(f'❤️', key=post['sk']):
if st.session_state.get('user_exists'):
likes.append(post)
gen_ads_on_likes(post)
# st.session_state['ads'].append(post.ads)
st.success('Post liked!')
else:
st.error('Please log in to like posts.')
# Implement infinite scroll
if len(st.session_state.get('posts')) > end_idx:
if st.button('Load More'):
st.session_state['start_idx'] += 3
st.session_state['end_idx'] += 3
else:
st.warning('Scroll down to load more posts...')
# Implement ads row after scrolling past a certain number of posts
if end_idx >= 2:
st.write('Here are some ads:')
ads = st.session_state.ads
for ad in ads:
col1, col2 = st.columns([1, 3])
with col1:
st.image(ad['product_image'], use_column_width=True)
with col2:
st.write(f'Title: {ad["title"]}')
st.write(f'Price: {ad["price"]}')
ad_link = ad["link"]
st.write(f'''
<a target="_self" href={ad_link} target="_blank" rel="noopener noreferrer">
<button>
Go To Amazon
</button>
</a>
''',
unsafe_allow_html=True)
elif page == "Data Analysis":
st.title("Possible Data Insights")
# st.write("Generated some data for these graphs")
np.random.seed(3)
ads = ['Pet Ads', 'Fashion Ads', 'Food Ads',
'Travel Gear Ads', 'Tech Gadgets Ads']
ages = ['18-24', '25-34', '35-44', '45-54', '55+']
genders = ['Male', 'Female']
locations = ['USA', 'Europe', 'Asia', 'Africa', 'South America', 'Australia']
categories = ['Pet', 'Tech', 'Food', 'Travel', 'Fashion']
months = pd.date_range('2022-01-01', periods=12, freq='M')
clicks = np.random.randint(1, 100, size=(len(months), len(ads)))
age_labels = np.random.choice(ages, size=len(months))
gender_labels = np.random.choice(genders, size=len(months))
location_labels = np.random.choice(locations, size=len(months))
category_labels = np.random.choice(categories, size=len(months))
month_labels = months.strftime('%Y-%m')
df = pd.DataFrame(clicks, columns=ads)
df['Age'] = age_labels
df['Gender'] = gender_labels
df['Location'] = location_labels
df['Category'] = category_labels
df['Month'] = month_labels
st.header("Gender Analysis on single Ad Category")
st.plotly_chart(px.box(df, x='Location', y='Pet Ads',
color='Gender'), use_container_width=True)
heatmap_df = pd.DataFrame({
'Hour': np.random.randint(7, 24, size=1000),
'Clicks': np.random.randint(1, 100, size=1000),
'Location': np.random.choice(locations, size=1000),
'Age': np.random.choice(ages, size=1000),
'Gender': np.random.choice(genders, size=1000),
'Category': np.random.choice(categories, size=1000)
})
st.header("Clicks Analysis based on Gender & Location")
bubble_df = pd.DataFrame({
'Hour': np.random.randint(7, 24, size=100),
'Clicks': np.random.randint(1, 100, size=100),
'Location': np.random.choice(locations, size=100),
'Age': np.random.choice(ages, size=100),
'Gender': np.random.choice(genders, size=100),
'Category': np.random.choice(categories, size=100)
})
bubble_chart = px.scatter(bubble_df, x='Location', y='Hour',
size='Clicks', color='Category', hover_data=['Gender'])
st.plotly_chart(bubble_chart, use_container_width=True)
st.header("Clicks Analysis based on Ad Category & Time of Day")
scatter_df = pd.DataFrame({
'Clicks': np.random.randint(1, 100, size=100),
'Time': np.random.randint(8, 24, size=100),
'Age': np.random.randint(18, 65, size=100),
'Ad_Category': np.random.choice(['gadgets', 'pet', 'food'], size=100)
})
fig = px.scatter(scatter_df, x='Time', y='Clicks', color='Ad_Category',
size='Age', hover_data=['Ad_Category', 'Age'])
st.plotly_chart(fig)
# Stacked Bar Chart
st.header("Clicks Analysis with Category Competetion")
stacked_bar_df = pd.DataFrame({
'Clicks': np.random.randint(1, 100, size=100),
'Age': np.random.randint(18, 65, size=100),
'Gender': np.random.choice(['M', 'F'], size=100),
'Ad_Category': np.random.choice(['gadgets', 'pet', 'food'], size=100)
})
stacked_bar_df = stacked_bar_df.groupby(['Age', 'Gender', 'Ad_Category'])['Clicks'].sum().reset_index()
fig = px.bar(stacked_bar_df,
x='Age',
y='Clicks',
color='Ad_Category',
barmode='stack',
facet_row='Gender')
st.plotly_chart(fig)
st.header("Total Ad Category Weightage")
data = {'Ad_Category': ads,
'Weightage': [7, 43, 25, 10, 15]}
df = pd.DataFrame(data)
fig = px.pie(df, values='Weightage', names='Ad_Category')
st.plotly_chart(fig)
main_app()