Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Install jaxlib with poetry, cleanup #3

Merged
merged 5 commits into from
Jun 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,11 @@ A (soon-to-be) collection of tools for generating [dalle-mini](https://github.co
Install the dependencies, then try out the CLI. Try `python generate.py --help` for more.

```sh
# Install poetry
curl -sSL https://install.python-poetry.org | python3 -
# If you installed poetry 1.1.x before, uninstall first
curl -sSL https://install.python-poetry.org | python3 - --uninstall

# Install poetry 1.2.x preview
curl -sSL https://install.python-poetry.org | python3 - --preview

# Create virtual env for this project, install requirements
poetry install
Expand Down
66 changes: 47 additions & 19 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 6 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[build-system]
requires = ["poetry>=1.1"]
requires = ["poetry-core"]
build-backend = "poetry.masonry.api"

[tool.poetry]
Expand All @@ -14,15 +14,17 @@ fire = "^0.4.0"
Flask = "^2.1.2"
ipywidgets = "^7.7.1"
pySqsListener = "^0.8.10"
python = ">=3.10.4,<3.11"
python = ">=3.10,<3.11"
python-slugify = "^6.1.2"
tokenizers = "~=0.11.6"
vqgan-jax = {git = "https://github.com/patil-suraj/vqgan-jax.git", rev = "main"}
slack-sdk = "^3.17.2"
slack-bolt = "^1.14.0"
tqdm = "^4.64.0"
jax = "^0.3.13"
jaxlib = "^0.3.10"
jaxlib = [
{url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.3.10+cuda11.cudnn82-cp310-none-manylinux2014_x86_64.whl", markers = "platform_machine == 'linux'"},
{url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-0.3.10-cp310-none-macosx_11_0_arm64.whl", markers="platform_machine == 'arm64'"}
]

[tool.poetry.dev-dependencies]
black = "~22.3.0"
Expand Down
6 changes: 5 additions & 1 deletion server.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,11 @@ def output(path):
return redirect(f"/output/{path}")

return render_template(
"template.html", prompt=prompt, imgs=imgs, expected_img_count=expectedimgs, show_links=True
"template.html",
prompt=prompt,
imgs=imgs,
expected_img_count=expectedimgs,
show_links=True,
)

else:
Expand Down
11 changes: 8 additions & 3 deletions sitegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,11 @@ def get_dir_details(path):
with open(f"{path}/prompt.txt", "r") as rp:
prompt = rp.read()

imgs = [ os.path.basename(x) for x in glob.glob(f"{path}/[!f]*.png") ] #a small hack to ignore "final.png"
return ( prompt, imgs )
imgs = [
os.path.basename(x) for x in glob.glob(f"{path}/[!f]*.png")
] # a small hack to ignore "final.png"
return (prompt, imgs)


def generate_index(path, show_links=False):
tl = jinja2.FileSystemLoader(searchpath="./templates")
Expand All @@ -53,7 +56,9 @@ def generate_index(path, show_links=False):
if imgs is None or len(imgs) == 0:
return None

return template.render(prompt=prompt, imgs=imgs, expected_img_count=len(imgs), show_links=show_links)
return template.render(
prompt=prompt, imgs=imgs, expected_img_count=len(imgs), show_links=show_links
)


if __name__ == "__main__":
Expand Down
104 changes: 61 additions & 43 deletions slackbot.py
Original file line number Diff line number Diff line change
@@ -1,58 +1,76 @@
#!/usr/bin/env python

import os
import time
from tqdm import tqdm
from request import send as send_queue_request

from slack_bolt import App
from slack_bolt.adapter.socket_mode import SocketModeHandler
from tqdm import tqdm

from request import send as send_queue_request

# import time


SLACK_BOT_TOKEN = os.environ['SLACK_BOT_TOKEN']
SLACK_APP_TOKEN = os.environ['SLACK_APP_TOKEN']
SLACK_BOT_TOKEN = os.environ["SLACK_BOT_TOKEN"]
SLACK_APP_TOKEN = os.environ["SLACK_APP_TOKEN"]

app = App(token=SLACK_BOT_TOKEN)


@app.event("app_mention")
def mention_handler(body, say, logger):
event = body["event"]
thread_ts = event.get("thread_ts", None) or event["ts"]
prompt = event["text"].replace(event["text"].split(' ')[0].strip(), "").strip() # i feel dirty but its late
print(f"Generating {prompt=}")
# start = time.time()
rundir = send_queue_request(prompt)
say(text=f'On it!', thread_ts=thread_ts)

max_t = 8000000 # my 2080Ti can generate from SQS to final image in: 1672800 ticks
for i in tqdm(range(max_t)):
if i % 100000 == 0:
print(f"Checking {rundir} {i}/{max_t}")

if os.path.exists(f'output/{rundir}/final.png'):
img = f'https://dalle-mini-tools.xeb.ai/output/{rundir}/final.png'
print(f"Found {img}")
say(img)
# say(img, thread_ts=thread_ts)
# end = time.time()
# duration = (start - end)
# # if duration >= 86400:
# # days = int(duration / 86400)
# elapsed = time.strftime("%H hours, %M minutes, %S seconds", time.gmtime(duration))
# say(f'Took me {elapsed}.', thread_ts=thread_ts)
return

say(text=f'...I gave up.', thread_ts=thread_ts)
def mention_handler_app_mention(body, say, logger):
event = body["event"]
thread_ts = event.get("thread_ts", None) or event["ts"]
prompt = (
event["text"].replace(event["text"].split(" ")[0].strip(), "").strip()
) # i feel dirty but its late
print(f"Generating {prompt=}")
# start = time.time()
rundir = send_queue_request(prompt)
say(text="On it!", thread_ts=thread_ts)

max_t = 8000000 # my 2080Ti can generate from SQS to final image in: 1672800 ticks
for i in tqdm(range(max_t)):
if i % 100000 == 0:
print(f"Checking {rundir} {i}/{max_t}")

if os.path.exists(f"output/{rundir}/final.png"):
img = f"https://dalle-mini-tools.xeb.ai/output/{rundir}/final.png"
print(f"Found {img}")
say(img)

# say(img, thread_ts=thread_ts)
# end = time.time()
# duration = start - end
# # if duration >= 86400:
# # days = int(duration / 86400)
# elapsed = time.strftime(
# "%H hours, %M minutes, %S seconds", time.gmtime(duration)
# )
# say(f"Took me {elapsed}.", thread_ts=thread_ts)
return

say(text="...I gave up.", thread_ts=thread_ts)


@app.event("message")
def mention_handler(body, say):
event = body["event"]
thread_ts = event.get("thread_ts", None) or event["ts"]
if "text" not in event:
return
def mention_handler_message(body, say):
event = body["event"]
thread_ts = event.get("thread_ts", None) or event["ts"]
if "text" not in event:
return

message = event["text"].strip()
if "generation station" in message.lower():
say(
text=(
"What was that? Did you want to generate an image? Just mention me"
" (@ImageGen) and tell me what you want."
),
thread_ts=thread_ts,
)

message = event["text"].strip()
if "generation station" in message.lower():
say(text='What was that? Did you want to generate an image? Just mention me (@ImageGen) and tell me what you want.', thread_ts=thread_ts)

if __name__ == "__main__":
handler = SocketModeHandler(app, SLACK_APP_TOKEN)
handler.start()
handler = SocketModeHandler(app, SLACK_APP_TOKEN)
handler.start()
22 changes: 11 additions & 11 deletions worker.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#!/usr/bin/env python

import os
import fire
import subprocess
from generate import Generator

import fire
from sqs_listener import SqsListener

from generate import Generator
Expand All @@ -13,20 +13,19 @@ class ImgGenListener(SqsListener):
def init_model(self, output_dir, clip_scores, postprocess):
self.generator = Generator(output_dir, clip_scores)
self.postprocess = postprocess
print(f"Initialized model")
print("Initialized model")

def postprocessing(self, run_name):
if not self.postprocess:
print(f"Postprocessing not enabled, skipping...")
print("Postprocessing not enabled, skipping...")

if os.path.exists("postprocess.sh"):
cmds = [ "./postprocess.sh", run_name ]
cmds = ["./postprocess.sh", run_name]
p = subprocess.Popen(cmds, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
out, err = p.communicate()
if p.returncode != 0:
print("-"**5)
print("-" ** 5)
print(f"Exception\n{err=}\n\n{out=}")


def handle_message(self, body, attr, msg_attr):
print(f"Processing {body=} {attr=} {msg_attr=}")
Expand All @@ -36,12 +35,13 @@ def handle_message(self, body, attr, msg_attr):
self.postprocessing(run_name)
print(f"Processed! {body=}")


def main(
output_dir="output",
clip_scores=False,
output_dir="output",
clip_scores=False,
postprocess=True,
queue_name="dalle-mini-tools",
error_queue="dalle-mini-tools_errors",
queue_name="dalle-mini-tools",
error_queue="dalle-mini-tools_errors",
region_name="us-east-1",
interval=1,
):
Expand Down