Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added a bunch of checks for errors and are more careful about inputs #4

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ dependencies = [
"PyYAML>=6.0.1",
"slack_bolt>=1.18.1",
"solace_ai_connector>=0.1.3",
"prettytable>=3.10.0",
]

[project.urls]
Expand Down
34 changes: 24 additions & 10 deletions src/solace_ai_connector_slack/components/slack_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,10 @@ def run(self):
self.stop_event.wait()

def handle_channel_event(self, event):
# Just return if it is a message deleted event or message changed event
if event.get("subtype") in ["message_deleted", "message_changed"]:
return

# For now, just do the normal handling
channel_name = self.get_channel_name(event.get("channel"))
event["channel_name"] = channel_name
Expand Down Expand Up @@ -267,8 +271,14 @@ def handle_event(self, event):
except Exception as e:
log.error("Error getting team domain: %s", e)

user_email = self.get_user_email(event["user"])
(text, mention_emails) = self.process_text_for_mentions(event["text"])
# Determine the thread_ts to put in the message
if event.get("channel_type") == "im" and event.get("subtype") == "app_mention":
thread_ts = event.get("ts")
else:
thread_ts = None

user_email = self.get_user_email(event.get("user"))
(text, mention_emails) = self.process_text_for_mentions(event.get("text", ""))
payload = {
"text": text,
"files": files,
Expand All @@ -278,11 +288,11 @@ def handle_event(self, event):
"mentions": mention_emails,
"type": event.get("type"),
"client_msg_id": event.get("client_msg_id"),
"ts": event.get("thread_ts"),
"channel": event.get("channel"),
"channel_name": event.get("channel_name", ""),
"subtype": event.get("subtype"),
"event_ts": event.get("event_ts"),
"ts": event.get("ts"),
"thread_ts": thread_ts,
"channel_type": event.get("channel_type"),
"user_id": event.get("user"),
}
Expand All @@ -291,10 +301,10 @@ def handle_event(self, event):
"team_id": event.get("team"),
"type": event.get("type"),
"client_msg_id": event.get("client_msg_id"),
"ts": event.get("thread_ts"),
"channel": event.get("channel"),
"subtype": event.get("subtype"),
"event_ts": event.get("event_ts"),
"ts": event.get("ts"),
"thread_ts": thread_ts,
"channel_type": event.get("channel_type"),
"user_id": event.get("user"),
"input_type": "slack",
Expand All @@ -304,7 +314,7 @@ def handle_event(self, event):
ack_msg_ts = self.app.client.chat_postMessage(
channel=event["channel"],
text=self.acknowledgement_message,
thread_ts=event.get("thread_ts"),
thread_ts=thread_ts,
).get("ts")
user_properties["ack_msg_ts"] = ack_msg_ts

Expand All @@ -319,8 +329,12 @@ def download_file_as_base64_string(self, file_url):
return base64_string

def get_user_email(self, user_id):
response = self.app.client.users_info(user=user_id)
return response["user"]["profile"].get("email", user_id)
try:
response = self.app.client.users_info(user=user_id)
return response["user"]["profile"].get("email", user_id)
except Exception as e:
log.error("Error getting user email: %s", e)
return user_id

def process_text_for_mentions(self, text):
mention_emails = []
Expand Down Expand Up @@ -423,7 +437,6 @@ def handle_new_channel_join(self, event):
def register_handlers(self):
@self.app.event("message")
def handle_chat_message(event):
print("Got message event: ", event, event.get("channel_type"))
if event.get("channel_type") == "im":
self.handle_event(event)
elif event.get("channel_type") == "channel":
Expand All @@ -435,6 +448,7 @@ def handle_chat_message(event):
def handle_app_mention(event):
print("Got app_mention event: ", event)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need this print? If so, can't we use leveled log instead?

event["channel_type"] = "im"
event["subtype"] = "app_mention"
event["channel_name"] = self.get_channel_name(event.get("channel"))
self.handle_event(event)

Expand Down
32 changes: 30 additions & 2 deletions src/solace_ai_connector_slack/components/slack_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import re
from datetime import datetime

from prettytable import PrettyTable

from solace_ai_connector.common.log import log
from .slack_base import SlackBase
Expand Down Expand Up @@ -126,7 +127,7 @@ def invoke(self, message, data):
last_streamed_chunk = content.get("last_streamed_chunk")
uuid = content.get("uuid")
channel = message_info.get("channel")
thread_ts = message_info.get("ts")
thread_ts = message_info.get("thread_ts")
ack_msg_ts = message_info.get("ack_msg_ts")

if not channel:
Expand All @@ -152,7 +153,7 @@ def send_message(self, message):
messages = message.get_data("previous:text")
stream = message.get_data("previous:stream")
files = message.get_data("previous:files") or []
thread_ts = message.get_data("previous:ts")
thread_ts = message.get_data("previous:thread_ts")
ack_msg_ts = message.get_data("previous:ack_msg_ts")
first_streamed_chunk = message.get_data("previous:first_streamed_chunk")
last_streamed_chunk = message.get_data("previous:last_streamed_chunk")
Expand Down Expand Up @@ -245,6 +246,12 @@ def fix_markdown(self, message):
message = re.sub(r"```[a-z]+\n", "```", message)
# Fix bold
message = re.sub(r"\*\*(.*?)\*\*", r"*\1*", message)
# Fix headings - make them bold instead
message = re.sub(r"^(#+) (.*)$", r"*\2*", message, flags=re.MULTILINE)

# Reformat a table to be Slack compatible
message = self.convert_markdown_tables(message)

return message

def get_streaming_state(self, uuid):
Expand Down Expand Up @@ -272,3 +279,24 @@ def age_out_streaming_state(self, age=60):
for uuid, state in list(self.streaming_state.items()):
if (now - state["create_time"]).total_seconds() > age:
del self.streaming_state[uuid]

def convert_markdown_tables(self, message):
def markdown_to_fixed_width(match):
table_str = match.group(0)
rows = [
line.strip().split("|")
for line in table_str.split("\n")
if line.strip()
]
headers = [cell.strip() for cell in rows[0] if cell.strip()]

pt = PrettyTable()
pt.field_names = headers

for row in rows[2:]:
pt.add_row([cell.strip() for cell in row if cell.strip()])

return f"\n```\n{pt.get_string()}\n```\n"

pattern = r"\|.*\|[\n\r]+\|[-:| ]+\|[\n\r]+((?:\|.*\|[\n\r]+)+)"
return re.sub(pattern, markdown_to_fixed_width, message)