Skip to content

Commit

Permalink
add ai module
Browse files Browse the repository at this point in the history
  • Loading branch information
yujonglee committed Jun 22, 2024
1 parent ab75f74 commit 2a65408
Show file tree
Hide file tree
Showing 5 changed files with 138 additions and 8 deletions.
7 changes: 7 additions & 0 deletions core/config/runtime.exs
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,10 @@ if config_env() == :prod do
end

config :nostrum, :token, System.get_env("DISCORD_BOT_TOKEN")

if config_env() != :test do
config :canary, :openai_api_base, System.fetch_env!("OPENAI_API_BASE")
config :canary, :openai_api_key, System.fetch_env!("OPENAI_API_KEY")
config :canary, :chat_completion_model, System.fetch_env!("CHAT_COMPLETION_MODEL")
config :canary, :text_embedding_model, System.fetch_env!("TEXT_EMBEDDING_MODEL")
end
10 changes: 3 additions & 7 deletions core/lib/canary.ex
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
defmodule Canary do
@moduledoc """
Canary keeps the contexts that define your domain
and business logic.
Contexts are also responsible for managing your data, regardless
if it comes from the database, an external API or others.
"""
def rest_client(opts \\ []) do
Req.new(opts)
end
end
124 changes: 124 additions & 0 deletions core/lib/canary/ai.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
defmodule Canary.AI do
use Retry

defp client() do
proxy_url = Application.fetch_env!(:canary, :openai_api_base)
proxy_key = Application.fetch_env!(:canary, :openai_api_key)

Canary.rest_client(
base_url: proxy_url,
headers: [{"Authorization", "Bearer #{proxy_key}"}]
)
end

def embedding(request) do
resp =
retry with: exponential_backoff() |> randomize |> cap(1_000) |> expiry(4_000) do
client()
|> Req.post(
url: "/v1/embeddings",
json: request
)
end

case resp do
{:ok, %{status: 200, body: %{"data" => data}}} ->
{:ok, data |> Enum.map(& &1["embedding"])}

{:ok, data} ->
{:error, data}

{:error, error} ->
{:error, error}
end
end

def chat(request, opts \\ []) do
opts = Keyword.merge([callback: fn data -> IO.inspect(data) end], opts)
into = if request[:stream], do: get_handler(opts[:callback]), else: nil

request =
request
|> Map.update!(:messages, &trim/1)
|> Map.update(:tools, nil, &trim/1)
|> Map.reject(fn {_k, v} -> is_nil(v) end)

resp =
retry with: exponential_backoff() |> randomize |> cap(1_000) |> expiry(4_000) do
client()
|> Req.post(
url: "/v1/chat/completions",
json: request,
into: into
)
end

case resp do
{:ok, %{body: %{"choices" => [%{"finish_reason" => "tool_calls", "message" => message}]}}} ->
{:ok, parse_tool_calls(message)}

{:ok, %{body: %{"choices" => [%{"delta" => delta}]}}} ->
{:ok, delta["content"]}

{:ok, %{body: %{"choices" => [%{"message" => message}]}}} ->
tool_calls = parse_tool_calls(message)

if tool_calls != [] do
{:ok, tool_calls}
else
{:ok, message["content"]}
end

{:ok, %{body: body}} ->
{:ok, body}

{:error, error} ->
{:error, error}
end
end

defp parse_tool_calls(message) do
message
|> get_in([Access.key("tool_calls", [])])
|> Enum.map(fn %{"function" => f} ->
%{
name: f["name"],
args: Jason.decode!(f["arguments"])
}
end)
end

defp get_handler(callback) do
fn {:data, data}, acc ->
Enum.each(parse(data), callback)
{:cont, acc}
end
end

defp parse(chunk) do
chunk
|> String.split("data: ")
|> Enum.map(&String.trim/1)
|> Enum.map(&decode/1)
|> Enum.reject(&is_nil/1)
end

defp decode(""), do: nil
defp decode("[DONE]"), do: nil

defp decode(data) do
case Jason.decode(data) do
{:ok, r} -> r
_ -> nil
end
end

def trim(data) when is_list(data), do: Enum.map(data, &trim/1)

def trim(data) when is_map(data) do
data |> Map.new(fn {k, v} -> {k, trim(v)} end)
end

def trim(data) when is_binary(data), do: String.trim(data)
def trim(data), do: data
end
4 changes: 3 additions & 1 deletion core/mix.exs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@ defmodule Canary.MixProject do
{:ash, "~> 3.0"},
{:ash_authentication, "~> 4.0"},
{:ash_authentication_phoenix, "~> 2.0"},
{:ash_postgres, "~> 2.0"}
{:ash_postgres, "~> 2.0"},
{:req, "~> 0.5.0"},
{:retry, "~> 0.18"}
]
end

Expand Down
1 change: 1 addition & 0 deletions core/mix.lock
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
"postgrex": {:hex, :postgrex, "0.18.0", "f34664101eaca11ff24481ed4c378492fed2ff416cd9b06c399e90f321867d7e", [:mix], [{:db_connection, "~> 2.1", [hex: :db_connection, repo: "hexpm", optional: false]}, {:decimal, "~> 1.5 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: false]}, {:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: true]}, {:table, "~> 0.1.0", [hex: :table, repo: "hexpm", optional: true]}], "hexpm", "a042989ba1bc1cca7383ebb9e461398e3f89f868c92ce6671feb7ef132a252d1"},
"reactor": {:hex, :reactor, "0.8.4", "344d02ba4a0010763851f4e4aa0ff190ebe7e392e3c27c6cd143dde077b986e7", [:mix], [{:libgraph, "~> 0.16", [hex: :libgraph, repo: "hexpm", optional: false]}, {:spark, "~> 2.0", [hex: :spark, repo: "hexpm", optional: false]}, {:splode, "~> 0.2", [hex: :splode, repo: "hexpm", optional: false]}, {:telemetry, "~> 1.2", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "49c1fd3c786603cec8140ce941c41c7ea72cc4411860ccdee9876c4ca2204f81"},
"req": {:hex, :req, "0.5.0", "6d8a77c25cfc03e06a439fb12ffb51beade53e3fe0e2c5e362899a18b50298b3", [:mix], [{:brotli, "~> 0.3.1", [hex: :brotli, repo: "hexpm", optional: true]}, {:ezstd, "~> 1.0", [hex: :ezstd, repo: "hexpm", optional: true]}, {:finch, "~> 0.17", [hex: :finch, repo: "hexpm", optional: false]}, {:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: false]}, {:mime, "~> 1.6 or ~> 2.0", [hex: :mime, repo: "hexpm", optional: false]}, {:nimble_csv, "~> 1.0", [hex: :nimble_csv, repo: "hexpm", optional: true]}, {:plug, "~> 1.0", [hex: :plug, repo: "hexpm", optional: true]}], "hexpm", "dda04878c1396eebbfdec6db6f3d4ca609e5c8846b7ee88cc56eb9891406f7a3"},
"retry": {:hex, :retry, "0.18.0", "dc58ebe22c95aa00bc2459f9e0c5400e6005541cf8539925af0aa027dc860543", [:mix], [], "hexpm", "9483959cc7bf69c9e576d9dfb2b678b71c045d3e6f39ab7c9aa1489df4492d73"},
"rewrite": {:hex, :rewrite, "0.10.5", "6afadeae0b9d843b27ac6225e88e165884875e0aed333ef4ad3bf36f9c101bed", [:mix], [{:glob_ex, "~> 0.1", [hex: :glob_ex, repo: "hexpm", optional: false]}, {:sourceror, "~> 1.0", [hex: :sourceror, repo: "hexpm", optional: false]}], "hexpm", "51cc347a4269ad3a1e7a2c4122dbac9198302b082f5615964358b4635ebf3d4f"},
"rustler": {:hex, :rustler, "0.32.1", "f4cf5a39f9e85d182c0a3f75fa15b5d0add6542ab0bf9ceac6b4023109ebd3fc", [:mix], [{:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: false]}, {:toml, "~> 0.6", [hex: :toml, repo: "hexpm", optional: false]}], "hexpm", "b96be75526784f86f6587f051bc8d6f4eaff23d6e0f88dbcfe4d5871f52946f7"},
"salsa20": {:hex, :salsa20, "1.0.4", "404cbea1fa8e68a41bcc834c0a2571ac175580fec01cc38cc70c0fb9ffc87e9b", [:mix], [], "hexpm", "745ddcd8cfa563ddb0fd61e7ce48d5146279a2cf7834e1da8441b369fdc58ac6"},
Expand Down

0 comments on commit 2a65408

Please sign in to comment.