From f0ad3a7252f0f6be4c2de36206e1dc13367d540e Mon Sep 17 00:00:00 2001 From: andthattoo Date: Tue, 14 May 2024 17:34:11 +0300 Subject: [PATCH 01/18] function calling abilities, multiple tools and base traits for various function calling standards. --- Cargo.lock | 592 +++++++++++++++++- Cargo.toml | 8 +- src/error.rs | 16 + src/functions/mod.rs | 53 ++ src/functions/pipelines/mod.rs | 2 + src/functions/pipelines/nous_hermes/mod.rs | 2 + .../pipelines/nous_hermes/parsers.rs | 0 .../pipelines/nous_hermes/prompts.rs | 0 src/functions/pipelines/openai/mod.rs | 92 +++ src/functions/pipelines/openai/parsers.rs | 23 + src/functions/pipelines/openai/prompts.rs | 29 + src/functions/pipelines/openai/request.rs | 66 ++ src/functions/tools/mod.rs | 59 ++ src/functions/tools/scraper.rs | 69 ++ src/functions/tools/search_ddg.rs | 109 ++++ src/functions/tools/weather.rs | 33 + src/lib.rs | 2 + 17 files changed, 1145 insertions(+), 10 deletions(-) create mode 100644 src/functions/mod.rs create mode 100644 src/functions/pipelines/mod.rs create mode 100644 src/functions/pipelines/nous_hermes/mod.rs create mode 100644 src/functions/pipelines/nous_hermes/parsers.rs create mode 100644 src/functions/pipelines/nous_hermes/prompts.rs create mode 100644 src/functions/pipelines/openai/mod.rs create mode 100644 src/functions/pipelines/openai/parsers.rs create mode 100644 src/functions/pipelines/openai/prompts.rs create mode 100644 src/functions/pipelines/openai/request.rs create mode 100644 src/functions/tools/mod.rs create mode 100644 src/functions/tools/scraper.rs create mode 100644 src/functions/tools/search_ddg.rs create mode 100644 src/functions/tools/weather.rs diff --git a/Cargo.lock b/Cargo.lock index 9f61ef8..49ff6cf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,6 +17,51 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" +[[package]] +name = "ahash" +version = "0.8.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" +dependencies = [ + "cfg-if", + "getrandom", + "once_cell", + "version_check", + "zerocopy", +] + +[[package]] +name = "aho-corasick" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" +dependencies = [ + "memchr", +] + +[[package]] +name = "async-trait" +version = "0.1.80" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c6fa2087f2753a7da8cc1c0dbfcf89579dd57458e36769de5ac750b4671737ca" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.59", +] + +[[package]] +name = "auto_enums" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1899bfcfd9340ceea3533ea157360ba8fa864354eccbceab58e1006ecab35393" +dependencies = [ + "derive_utils", + "proc-macro2", + "quote", + "syn 2.0.59", +] + [[package]] name = "autocfg" version = "1.1.0" @@ -62,6 +107,12 @@ version = "3.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f30e7476521f6f8af1a1c4c0b8cc94f0bee37d91763d0ca2665f299b6cd8aec" +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + [[package]] name = "bytes" version = "1.5.0" @@ -99,6 +150,78 @@ version = 
"0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa" +[[package]] +name = "cssparser" +version = "0.31.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b3df4f93e5fbbe73ec01ec8d3f68bba73107993a5b1e7519273c32db9b0d5be" +dependencies = [ + "cssparser-macros", + "dtoa-short", + "itoa", + "phf 0.11.2", + "smallvec", +] + +[[package]] +name = "cssparser-macros" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13b588ba4ac1a99f7f2964d24b3d896ddc6bf847ee3855dbd4366f058cfcd331" +dependencies = [ + "quote", + "syn 2.0.59", +] + +[[package]] +name = "derive_more" +version = "0.99.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fb810d30a7c1953f91334de7244731fc3f3c10d7fe163338a35b9f640960321" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "derive_utils" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61bb5a1014ce6dfc2a378578509abe775a5aa06bff584a547555d9efdb81b926" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.59", +] + +[[package]] +name = "dtoa" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dcbb2bf8e87535c23f7a8a321e364ce21462d0ff10cb6407820e8e96dfff6653" + +[[package]] +name = "dtoa-short" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbaceec3c6e4211c79e7b1800fb9680527106beb2f9c51904a3210c03a448c74" +dependencies = [ + "dtoa", +] + +[[package]] +name = "ego-tree" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a68a4904193147e0a8dec3314640e6db742afd5f6e634f428a6af230d9b3591" + +[[package]] +name = "either" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a47c1c47d2f5964e29c61246e81db715514cd532db6b5116a25ea3c03d6780a2" + [[package]] name = "errno" version = "0.3.5" @@ -145,6 +268,16 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "futf" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df420e2e84819663797d1ec6544b13c5be84629e7bb00dc960d6917db2987843" +dependencies = [ + "mac", + "new_debug_unreachable", +] + [[package]] name = "futures-channel" version = "0.3.29" @@ -174,7 +307,7 @@ checksum = "53b153fd91e4b0147f4aced87be237c98248656bb01050b96bf3ee89220a8ddb" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.59", ] [[package]] @@ -206,6 +339,24 @@ dependencies = [ "slab", ] +[[package]] +name = "fxhash" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c31b6d751ae2c7f11320402d34e41349dd1016f8d5d45e48c4312bc8625af50c" +dependencies = [ + "byteorder", +] + +[[package]] +name = "getopts" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "14dbbfd5c71d70241ecf9e6f13737f7b5ce823821063188d7e46c41d371eebd5" +dependencies = [ + "unicode-width", +] + [[package]] name = "getrandom" version = "0.2.11" @@ -223,12 +374,32 @@ version = "0.28.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6fb8d784f27acf97159b40fc4db5ecd8aa23b9ad5ef69cdd136d3bc80665f0c0" +[[package]] +name = "heck" +version = "0.4.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" + [[package]] name = "hermit-abi" version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d77f7ec81a6d05a3abb01ab6eb7590f6083d08449fe5a1c8b1e620283546ccb7" +[[package]] +name = "html5ever" +version = "0.26.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bea68cab48b8459f17cf1c944c67ddc572d272d9f2b274140f223ecb1da4a3b7" +dependencies = [ + "log", + "mac", + "markup5ever", + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "http" version = "1.1.0" @@ -357,6 +528,15 @@ version = "2.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3" +[[package]] +name = "itertools" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.9" @@ -406,6 +586,26 @@ version = "0.4.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" +[[package]] +name = "mac" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c41e0c4fef86961ac6d6f8a82609f55f31b05e4fce149ac5710e439df7619ba4" + +[[package]] +name = "markup5ever" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a2629bb1404f3d34c2e921f21fd34ba00b206124c81f65c50b43b6aaefeb016" +dependencies = [ + "log", + "phf 0.10.1", + "phf_codegen", + "string_cache", + "string_cache_codegen", + "tendril", +] + [[package]] name = "memchr" version = "2.6.4" @@ -456,6 +656,12 @@ dependencies = [ "tempfile", ] +[[package]] +name = "new_debug_unreachable" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "650eef8c711430f1a879fdd01d4745a7deea475becfb90269c06775983bbf086" + [[package]] name = "num_cpus" version = "1.16.0" @@ -479,11 +685,15 @@ dependencies = [ name = "ollama-rs" version = "0.1.9" dependencies = [ + "async-trait", "base64", + "log", "ollama-rs", "reqwest", + "scraper", "serde", "serde_json", + "text-splitter", "tokio", "tokio-stream", "url", @@ -491,9 +701,9 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.18.0" +version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" +checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" [[package]] name = "openssl" @@ -518,7 +728,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.59", ] [[package]] @@ -568,6 +778,86 @@ version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" +[[package]] +name = "phf" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fabbf1ead8a5bcbc20f5f8b939ee3f5b0f6f281b6ad3468b84656b658b455259" +dependencies = [ + "phf_shared 0.10.0", +] + +[[package]] +name = "phf" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"ade2d8b8f33c7333b51bcf0428d37e217e9f32192ae4772156f65063b8ce03dc" +dependencies = [ + "phf_macros", + "phf_shared 0.11.2", +] + +[[package]] +name = "phf_codegen" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fb1c3a8bc4dd4e5cfce29b44ffc14bedd2ee294559a294e2a4d4c9e9a6a13cd" +dependencies = [ + "phf_generator 0.10.0", + "phf_shared 0.10.0", +] + +[[package]] +name = "phf_generator" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d5285893bb5eb82e6aaf5d59ee909a06a16737a8970984dd7746ba9283498d6" +dependencies = [ + "phf_shared 0.10.0", + "rand", +] + +[[package]] +name = "phf_generator" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48e4cc64c2ad9ebe670cb8fd69dd50ae301650392e81c05f9bfcb2d5bdbc24b0" +dependencies = [ + "phf_shared 0.11.2", + "rand", +] + +[[package]] +name = "phf_macros" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3444646e286606587e49f3bcf1679b8cef1dc2c5ecc29ddacaffc305180d464b" +dependencies = [ + "phf_generator 0.11.2", + "phf_shared 0.11.2", + "proc-macro2", + "quote", + "syn 2.0.59", +] + +[[package]] +name = "phf_shared" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6796ad771acdc0123d2a88dc428b5e38ef24456743ddb1744ed628f9815c096" +dependencies = [ + "siphasher", +] + +[[package]] +name = "phf_shared" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90fcb95eef784c2ac79119d1dd819e162b5da872ce6f3c3abe1e8ca1c082f72b" +dependencies = [ + "siphasher", +] + [[package]] name = "pin-project" version = "1.1.5" @@ -585,7 +875,7 @@ checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.59", ] [[package]] @@ -606,6 +896,18 @@ version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "26072860ba924cbfa98ea39c8c19b4dd6a4a25423dbdf219c1eca91aa0cf6964" +[[package]] +name = "ppv-lite86" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" + +[[package]] +name = "precomputed-hash" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "925383efa346730478fb4838dbe9137d2a47675ad789c546d150a6e1dd4ab31c" + [[package]] name = "proc-macro2" version = "1.0.81" @@ -624,6 +926,36 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] + [[package]] name = "redox_syscall" version = "0.4.1" @@ -633,6 +965,35 @@ dependencies = [ "bitflags 1.3.2", ] +[[package]] +name = "regex" 
+version = "1.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c117dbdfde9c8308975b6a18d71f3f385c89461f7b3fb054288ecf2a2058ba4c" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "86b83b8b9847f9bf95ef68afb0b8e6cdb80f498442f5179a29fad448fcc1eaea" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adad44e29e4c806119491a7f06f03de4d1af22c3a680dd47f1e6e179439d1f56" + [[package]] name = "reqwest" version = "0.12.4" @@ -753,6 +1114,12 @@ dependencies = [ "untrusted", ] +[[package]] +name = "rustversion" +version = "1.0.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "092474d1a01ea8278f69e6a358998405fae5b8b963ddaeb2b0b04a128bf1dfb0" + [[package]] name = "ryu" version = "1.0.15" @@ -774,6 +1141,22 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "scraper" +version = "0.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b80b33679ff7a0ea53d37f3b39de77ea0c75b12c5805ac43ec0c33b3051af1b" +dependencies = [ + "ahash", + "cssparser", + "ego-tree", + "getopts", + "html5ever", + "once_cell", + "selectors", + "tendril", +] + [[package]] name = "security-framework" version = "2.9.2" @@ -797,6 +1180,25 @@ dependencies = [ "libc", ] +[[package]] +name = "selectors" +version = "0.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4eb30575f3638fc8f6815f448d50cb1a2e255b0897985c8c59f4d37b72a07b06" +dependencies = [ + "bitflags 2.4.1", + "cssparser", + "derive_more", + "fxhash", + "log", + "new_debug_unreachable", + "phf 0.10.1", + "phf_codegen", + "precomputed-hash", + "servo_arc", + "smallvec", +] + [[package]] name = "serde" version = "1.0.198" @@ -814,7 +1216,7 @@ checksum = "e88edab869b01783ba905e7d0153f9fc1a6505a96e4ad3018011eedb838566d9" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.59", ] [[package]] @@ -840,6 +1242,15 @@ dependencies = [ "serde", ] +[[package]] +name = "servo_arc" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d036d71a959e00c77a63538b90a6c2390969f9772b096ea837205c6bd0491a44" +dependencies = [ + "stable_deref_trait", +] + [[package]] name = "signal-hook-registry" version = "1.4.1" @@ -849,6 +1260,12 @@ dependencies = [ "libc", ] +[[package]] +name = "siphasher" +version = "0.3.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38b58827f4464d87d377d175e90bf58eb00fd8716ff0a62f80356b5e61555d0d" + [[package]] name = "slab" version = "0.4.9" @@ -880,12 +1297,77 @@ version = "0.9.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" +[[package]] +name = "stable_deref_trait" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" + +[[package]] +name = "string_cache" +version = "0.8.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"f91138e76242f575eb1d3b38b4f1362f10d3a43f47d182a5b359af488a02293b" +dependencies = [ + "new_debug_unreachable", + "once_cell", + "parking_lot", + "phf_shared 0.10.0", + "precomputed-hash", + "serde", +] + +[[package]] +name = "string_cache_codegen" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6bb30289b722be4ff74a408c3cc27edeaad656e06cb1fe8fa9231fa59c728988" +dependencies = [ + "phf_generator 0.10.0", + "phf_shared 0.10.0", + "proc-macro2", + "quote", +] + +[[package]] +name = "strum" +version = "0.26.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d8cec3501a5194c432b2b7976db6b7d10ec95c253208b45f83f7136aa985e29" +dependencies = [ + "strum_macros", +] + +[[package]] +name = "strum_macros" +version = "0.26.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c6cf59daf282c0a494ba14fd21610a0325f9f90ec9d1231dea26bcb1d696c946" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "rustversion", + "syn 2.0.59", +] + [[package]] name = "subtle" version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc" +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + [[package]] name = "syn" version = "2.0.59" @@ -916,6 +1398,54 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "tendril" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d24a120c5fc464a3458240ee02c299ebcb9d67b5249c8848b09d639dca8d7bb0" +dependencies = [ + "futf", + "mac", + "utf-8", +] + +[[package]] +name = "text-splitter" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3634ff66852bfbf7e8e987735bac08168daa5d42dc39a0df7d05fc83eaa3fe4" +dependencies = [ + "ahash", + "auto_enums", + "either", + "itertools", + "once_cell", + "regex", + "strum", + "thiserror", + "unicode-segmentation", +] + +[[package]] +name = "thiserror" +version = "1.0.60" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "579e9083ca58dd9dcf91a9923bb9054071b9ebbd800b342194c9feb0ee89fc18" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.60" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2470041c06ec3ac1ab38d0356a6119054dedaea53e12fbefc0de730a1c08524" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.59", +] + [[package]] name = "tinyvec" version = "1.6.0" @@ -958,7 +1488,7 @@ checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.59", ] [[package]] @@ -1082,6 +1612,18 @@ dependencies = [ "tinyvec", ] +[[package]] +name = "unicode-segmentation" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d4c87d22b6e3f4a18d4d40ef354e97c90fcb14dd91d7dc0aa9d8a1172ebf7202" + +[[package]] +name = "unicode-width" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68f5e5f3158ecfd4b8ff6fe086db7c8467a2dfdac97fe420f2b7c4aa97af66d6" + [[package]] name = "untrusted" version = "0.9.0" @@ -1099,12 +1641,24 @@ dependencies = [ "percent-encoding", ] 
+[[package]]
+name = "utf-8"
+version = "0.7.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
+
 [[package]]
 name = "vcpkg"
 version = "0.2.15"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426"
 
+[[package]]
+name = "version_check"
+version = "0.9.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
+
 [[package]]
 name = "want"
 version = "0.3.1"

@@ -1141,7 +1695,7 @@ dependencies = [
  "once_cell",
  "proc-macro2",
  "quote",
- "syn",
+ "syn 2.0.59",
  "wasm-bindgen-shared",
 ]

@@ -1175,7 +1729,7 @@ checksum = "c5353b8dab669f5e10f5bd76df26a9360c748f054f862ff5f3f8aae0c7fb3907"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn",
+ "syn 2.0.59",
  "wasm-bindgen-backend",
  "wasm-bindgen-shared",
 ]

@@ -1294,6 +1848,26 @@ dependencies = [
  "windows-sys",
 ]

+[[package]]
+name = "zerocopy"
+version = "0.7.34"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ae87e3fcd617500e5d106f0380cf7b77f3c6092aae37191433159dda23cfb087"
+dependencies = [
+ "zerocopy-derive",
+]
+
+[[package]]
+name = "zerocopy-derive"
+version = "0.7.34"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "15e934569e47891f7d9411f1a451d947a60e000ab3bd24fbb970f000387d1b3b"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.59",
+]
+
 [[package]]
 name = "zeroize"
 version = "1.7.0"
diff --git a/Cargo.toml b/Cargo.toml
index 880a015..ae318ca 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -10,20 +10,26 @@ readme = "README.md"
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
 
 [dependencies]
-reqwest = { version = "0.12.4", default-features = false }
+reqwest = { version = "0.12.4", default-features = false, features=["json"] }
 serde = { version = "1", features = ["derive"] }
 serde_json = "1"
 tokio = { version = "1", features = ["full"], optional = true }
 tokio-stream = { version = "0.1.15", optional = true }
 url = "2"
+log = "0.4"
+scraper = {version = "0.19.0" , optional = true }
+text-splitter = {version = "0.13.1", optional = true }
+async-trait = { version = "0.1.73", optional = true }
 
 [features]
 default = ["reqwest/default-tls"]
 stream = ["tokio-stream", "reqwest/stream", "tokio"]
 rustls = ["reqwest/rustls-tls"]
 chat-history = []
+function-calling = ["scraper", "text-splitter", "async-trait"]
 
 [dev-dependencies]
 tokio = { version = "1", features = ["full"] }
 ollama-rs = { path = ".", features = ["stream", "chat-history"] }
 base64 = "0.22.0"
+
diff --git a/src/error.rs b/src/error.rs
index 66f4e8c..436bc9d 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -34,3 +34,19 @@ impl From<String> for OllamaError {
         Self { message }
     }
 }
+
+impl From<Box<dyn std::error::Error>> for OllamaError {
+    fn from(error: Box<dyn std::error::Error>) -> Self {
+        Self {
+            message: error.to_string(),
+        }
+    }
+}
+
+impl From<serde_json::Error> for OllamaError {
+    fn from(error: serde_json::Error) -> Self {
+        Self {
+            message: error.to_string(),
+        }
+    }
+}
diff --git a/src/functions/mod.rs b/src/functions/mod.rs
new file mode 100644
index 0000000..d40df22
--- /dev/null
+++ b/src/functions/mod.rs
@@ -0,0 +1,53 @@
+pub mod tools;
+pub mod pipelines;
+
+pub use tools::WeatherTool;
+pub use tools::Scraper;
+pub use tools::DDGSearcher;
+
+use async_trait::async_trait;
+use serde_json::{Value, json};
+use std::error::Error;
+use
crate::generation::chat::ChatMessage;
+
+
+pub trait FunctionCallBase: Send + Sync {
+    fn name(&self) -> String;
+}
+
+#[async_trait]
+pub trait FunctionCall: FunctionCallBase {
+    async fn call(&self, params: Value) -> Result<Value, Box<dyn Error>>;
+}
+
+pub struct DefaultFunctionCall {}
+
+impl FunctionCallBase for DefaultFunctionCall {
+    fn name(&self) -> String {
+        "default_function".to_string()
+    }
+}
+
+
+pub fn convert_to_ollama_tool(tool: &dyn crate::generation::functions::tools::Tool) -> Value {
+    let schema = tool.parameters();
+    json!({
+        "name": tool.name(),
+        "properties": schema["properties"],
+        "required": schema["required"]
+    })
+}
+
+
+pub fn parse_response(message: &ChatMessage) -> Result<String, String> {
+    let content = &message.content;
+    let value: Value = serde_json::from_str(content).map_err(|e| e.to_string())?;
+
+    if let Some(function_call) = value.get("function_call") {
+        if let Some(arguments) = function_call.get("arguments") {
+            return Ok(arguments.to_string());
+        }
+        return Err("`arguments` missing from `function_call`".to_string());
+    }
+    Err("`function_call` missing from `content`".to_string())
+}
diff --git a/src/functions/pipelines/mod.rs b/src/functions/pipelines/mod.rs
new file mode 100644
index 0000000..727e688
--- /dev/null
+++ b/src/functions/pipelines/mod.rs
@@ -0,0 +1,2 @@
+pub mod openai;
+pub mod nous_hermes;
\ No newline at end of file
diff --git a/src/functions/pipelines/nous_hermes/mod.rs b/src/functions/pipelines/nous_hermes/mod.rs
new file mode 100644
index 0000000..6bf9d61
--- /dev/null
+++ b/src/functions/pipelines/nous_hermes/mod.rs
@@ -0,0 +1,2 @@
+pub mod prompts;
+pub mod parsers;
\ No newline at end of file
diff --git a/src/functions/pipelines/nous_hermes/parsers.rs b/src/functions/pipelines/nous_hermes/parsers.rs
new file mode 100644
index 0000000..e69de29
diff --git a/src/functions/pipelines/nous_hermes/prompts.rs b/src/functions/pipelines/nous_hermes/prompts.rs
new file mode 100644
index 0000000..e69de29
diff --git a/src/functions/pipelines/openai/mod.rs b/src/functions/pipelines/openai/mod.rs
new file mode 100644
index 0000000..0ef6f78
--- /dev/null
+++ b/src/functions/pipelines/openai/mod.rs
@@ -0,0 +1,92 @@
+pub mod prompts;
+pub mod parsers;
+pub mod request;
+
+pub use prompts::{DEFAULT_SYSTEM_TEMPLATE, DEFAULT_RESPONSE_FUNCTION};
+pub use request::FunctionCallRequest;
+pub use parsers::{generate_system_message, parse_response};
+
+use std::sync::Arc;
+use async_trait::async_trait;
+use serde_json::{json, Value};
+use std::error::Error;
+use crate::generation::functions::{FunctionCall, FunctionCallBase};
+use crate::generation::chat::{ChatMessage, ChatMessageResponse};
+use crate::generation::chat::request::{ChatMessageRequest};
+use crate::generation::functions::tools::Tool;
+use crate::error::OllamaError;
+
+
+pub struct OpenAIFunctionCall {
+    pub name: String,
+}
+
+impl OpenAIFunctionCall {
+    pub fn new(name: &str) -> Self {
+        OpenAIFunctionCall {
+            name: name.to_string(),
+        }
+    }
+}
+
+impl FunctionCallBase for OpenAIFunctionCall {
+    fn name(&self) -> String {
+        "openai".to_string()
+    }
+}
+
+#[async_trait]
+impl FunctionCall for OpenAIFunctionCall {
+    async fn call(&self, params: Value) -> Result<Value, Box<dyn Error>> {
+        // Simulate a function call by returning a simple JSON value
+        Ok(json!({ "result": format!("Function {} called with params: {}", self.name, params) }))
+    }
+}
+
+
+impl crate::Ollama {
+    pub async fn function_call_with_history(
+        &self,
+        request: ChatMessageRequest,
+        tool: Arc<dyn Tool>,
+    ) -> Result<ChatMessageResponse, OllamaError> {
+        let function_call = OpenAIFunctionCall::new(&tool.name());
+        let
params = tool.parameters();
+        let result = function_call.call(params).await?;
+        Ok(ChatMessageResponse {
+            model: request.model_name,
+            created_at: "".to_string(),
+            message: Some(ChatMessage::assistant(result.to_string())),
+            done: true,
+            final_data: None,
+        })
+    }
+
+    pub async fn function_call(
+        &self,
+        request: ChatMessageRequest,
+    ) -> crate::error::Result<ChatMessageResponse> {
+        let mut request = request;
+        request.stream = false;
+
+        let url = format!("{}api/chat", self.url_str());
+        let serialized = serde_json::to_string(&request).map_err(|e| e.to_string())?;
+        let res = self
+            .reqwest_client
+            .post(url)
+            .body(serialized)
+            .send()
+            .await
+            .map_err(|e| e.to_string())?;
+
+        if !res.status().is_success() {
+            return Err(res.text().await.unwrap_or_else(|e| e.to_string()).into());
+        }
+
+        let bytes = res.bytes().await.map_err(|e| e.to_string())?;
+        let res =
+            serde_json::from_slice::<ChatMessageResponse>(&bytes).map_err(|e| e.to_string())?;
+
+        Ok(res)
+    }
+}
\ No newline at end of file
diff --git a/src/functions/pipelines/openai/parsers.rs b/src/functions/pipelines/openai/parsers.rs
new file mode 100644
index 0000000..980ed95
--- /dev/null
+++ b/src/functions/pipelines/openai/parsers.rs
@@ -0,0 +1,23 @@
+use crate::generation::chat::ChatMessage;
+use serde_json::Value;
+use crate::generation::functions::pipelines::openai::DEFAULT_SYSTEM_TEMPLATE;
+use crate::generation::functions::tools::Tool;
+
+pub fn parse_response(message: &ChatMessage) -> Result<Value, String> {
+    let content = &message.content;
+    let value: Value = serde_json::from_str(content).map_err(|e| e.to_string())?;
+
+    if let Some(function_call) = value.get("function_call") {
+        Ok(function_call.clone())
+    } else {
+        Ok(value)
+    }
+}
+
+pub fn generate_system_message(tools: &[&dyn Tool]) -> ChatMessage {
+    let tools_info: Vec<Value> = tools.iter().map(|tool| tool.parameters()).collect();
+    let tools_json = serde_json::to_string(&tools_info).unwrap();
+    let system_message_content = DEFAULT_SYSTEM_TEMPLATE.replace("{tools}", &tools_json);
+    ChatMessage::system(system_message_content)
+}
+
diff --git a/src/functions/pipelines/openai/prompts.rs b/src/functions/pipelines/openai/prompts.rs
new file mode 100644
index 0000000..94e0749
--- /dev/null
+++ b/src/functions/pipelines/openai/prompts.rs
@@ -0,0 +1,29 @@
+pub const DEFAULT_SYSTEM_TEMPLATE: &str = r#"
+You have access to the following tools:
+
+{tools}
+
+You must always select one of the above tools and respond with only a JSON object matching the following schema:
+
+{
+    "tool": <name of the selected tool>,
+    "tool_input": <parameters for the selected tool, matching the tool's JSON schema>
+}
+"#;
+
+pub const DEFAULT_RESPONSE_FUNCTION: &str = r#"
+{
+    "name": "__conversational_response",
+    "description": "Respond conversationally if no other tools should be called for a given query.",
+    "parameters": {
+        "type": "object",
+        "properties": {
+            "response": {
+                "type": "string",
+                "description": "Conversational response to the user."
+            }
+        },
+        "required": ["response"]
+    }
+}
+"#;
diff --git a/src/functions/pipelines/openai/request.rs b/src/functions/pipelines/openai/request.rs
new file mode 100644
index 0000000..301b6ae
--- /dev/null
+++ b/src/functions/pipelines/openai/request.rs
@@ -0,0 +1,66 @@
+use serde_json::Value;
+use std::sync::Arc;
+use crate::generation::chat::{ChatMessage, ChatMessageResponse};
+use crate::generation::chat::request::{ChatMessageRequest};
+use crate::generation::functions::pipelines::openai::DEFAULT_SYSTEM_TEMPLATE;
+use crate::generation::functions::tools::Tool;
+use crate::Ollama;
+use crate::error::OllamaError;
+
+#[derive(Clone)]
+pub struct FunctionCallRequest {
+    model_name: String,
+    tools: Vec<Arc<dyn Tool>>,
+}
+
+impl FunctionCallRequest {
+    pub fn new(model_name: &str, tools: Vec<Arc<dyn Tool>>) -> Self {
+        FunctionCallRequest {
+            model_name: model_name.to_string(),
+            tools,
+        }
+    }
+
+    pub async fn send(&self, ollama: &mut Ollama, input: &str) -> Result<ChatMessageResponse, OllamaError> {
+        let system_message = self.get_system_message();
+        ollama.send_chat_messages_with_history(
+            ChatMessageRequest::new(self.model_name.clone(), vec![system_message.clone()]),
+            "default".to_string(),
+        ).await?;
+
+        let user_message = ChatMessage::user(input.to_string());
+
+        let result = ollama
+            .send_chat_messages_with_history(
+                ChatMessageRequest::new(self.model_name.clone(), vec![user_message]),
+                "default".to_string(),
+            ).await?;
+
+        let response_content = result.message.clone().unwrap().content;
+        let response_value: Value = match serde_json::from_str(&response_content) {
+            Ok(value) => value,
+            Err(e) => return Err(OllamaError::from(e.to_string())),
+        };
+
+        if let Some(function_call) = response_value.get("function_call") {
+            if let Some(tool_name) = function_call.get("tool").and_then(Value::as_str) {
+                if let Some(tool) = self.tools.iter().find(|t| t.name() == tool_name) {
+                    let result = ollama.function_call_with_history(
+                        ChatMessageRequest::new(self.model_name.clone(), vec![ChatMessage::user(tool_name.to_string())]),
+                        tool.clone(),
+                    ).await?;
+                    return Ok(result);
+                }
+            }
+        }
+
+        Ok(result)
+    }
+
+    pub fn get_system_message(&self) -> ChatMessage {
+        let tools_info: Vec<Value> = self.tools.iter().map(|tool| tool.parameters()).collect();
+        let tools_json = serde_json::to_string(&tools_info).unwrap();
+        let system_message_content = DEFAULT_SYSTEM_TEMPLATE.replace("{tools}", &tools_json);
+        ChatMessage::system(system_message_content)
+    }
+}
diff --git a/src/functions/tools/mod.rs b/src/functions/tools/mod.rs
new file mode 100644
index 0000000..d824fc9
--- /dev/null
+++ b/src/functions/tools/mod.rs
@@ -0,0 +1,59 @@
+pub mod search_ddg;
+pub mod weather;
+pub mod scraper;
+
+pub use self::weather::WeatherTool;
+pub use self::scraper::Scraper;
+pub use self::search_ddg::DDGSearcher;
+
+use async_trait::async_trait;
+use serde_json::{json, Value};
+use std::error::Error;
+use std::string::String;
+
+#[async_trait]
+pub trait Tool: Send + Sync {
+    /// Returns the name of the tool.
+    fn name(&self) -> String;
+
+    /// Provides a description of what the tool does and when to use it.
+    fn description(&self) -> String;
+
+    /// These are the parameters for an OpenAI-like function call.
+    fn parameters(&self) -> Value {
+        json!({
+            "type": "object",
+            "properties": {
+                "input": {
+                    "type": "string",
+                    "description": self.description()
+                }
+            },
+            "required": ["input"]
+        })
+    }
+
+    /// Processes an input string and executes the tool's functionality, returning a `Result`.
+    async fn call(&self, input: &str) -> Result<String, Box<dyn Error>> {
+        let input = self.parse_input(input).await;
+        self.run(input).await
+    }
+
+    /// Executes the core functionality of the tool.
+    async fn run(&self, input: Value) -> Result<String, Box<dyn Error>>;
+
+    /// Parses the input string.
+    async fn parse_input(&self, input: &str) -> Value {
+        log::info!("Using default implementation: {}", input);
+        match serde_json::from_str::<Value>(input) {
+            Ok(input) => {
+                if input["input"].is_string() {
+                    Value::String(input["input"].as_str().unwrap().to_string())
+                } else {
+                    Value::String(input.to_string())
+                }
+            }
+            Err(_) => Value::String(input.to_string()),
+        }
+    }
+}
\ No newline at end of file
diff --git a/src/functions/tools/scraper.rs b/src/functions/tools/scraper.rs
new file mode 100644
index 0000000..2e8b4c0
--- /dev/null
+++ b/src/functions/tools/scraper.rs
@@ -0,0 +1,69 @@
+use reqwest::Client;
+use scraper::{Html, Selector};
+use std::env;
+use text_splitter::TextSplitter;
+
+use std::error::Error;
+use serde_json::{Value, json};
+use crate::generation::functions::tools::Tool;
+use async_trait::async_trait;
+
+pub struct Scraper {}
+
+
+#[async_trait]
+impl Tool for Scraper {
+    fn name(&self) -> String {
+        "Website Scraper".to_string()
+    }
+
+    fn description(&self) -> String {
+        "Scrapes text content from websites and splits it into manageable chunks.".to_string()
+    }
+
+    fn parameters(&self) -> Value {
+        json!({
+            "type": "object",
+            "properties": {
+                "website": {
+                    "type": "string",
+                    "description": "The URL of the website to scrape"
+                }
+            },
+            "required": ["website"]
+        })
+    }
+
+    async fn run(&self, input: Value) -> Result<String, Box<dyn Error>> {
+        let website = input["website"].as_str().ok_or("Website URL is required")?;
+        let browserless_token = env::var("BROWSERLESS_TOKEN").expect("BROWSERLESS_TOKEN must be set");
+        let url = format!("http://0.0.0.0:3000/content?token={}", browserless_token);
+        let payload = json!({
+            "url": website
+        });
+        let client = Client::new();
+        let response = client
+            .post(&url)
+            .header("cache-control", "no-cache")
+            .header("content-type", "application/json")
+            .json(&payload)
+            .send()
+            .await?;
+
+        let response_text = response.text().await?;
+        let document = Html::parse_document(&response_text);
+        let selector = Selector::parse("p, h1, h2, h3, h4, h5, h6").unwrap();
+        let elements: Vec<String> = document
+            .select(&selector)
+            .map(|el| el.text().collect::<String>())
+            .collect();
+        let body = elements.join(" ");
+
+        let splitter = TextSplitter::new(1000);
+        let chunks = splitter.chunks(&body);
+        let sentences: Vec<String> = chunks.map(|s| s.to_string()).collect();
+        let sentences = sentences.join("\n \n");
+        Ok(sentences)
+    }
+}
+
diff --git a/src/functions/tools/search_ddg.rs b/src/functions/tools/search_ddg.rs
new file mode 100644
index 0000000..9407300
--- /dev/null
+++ b/src/functions/tools/search_ddg.rs
@@ -0,0 +1,109 @@
+use reqwest;
+
+use url::Url;
+
+use scraper::{Html, Selector};
+use std::error::Error;
+
+use crate::generation::functions::tools::Tool;
+use serde::{Deserialize, Serialize};
+use serde_json::{Value, json};
+use async_trait::async_trait;
+
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct SearchResult {
+    title: String,
+    link: String,
+    snippet: String,
+}
+
+impl SearchResult {
+    fn extract_domain(url: &str) -> Option<String> {
+        Url::parse(url).ok()?.domain().map(|d| d.to_string())
+    }
+}
+
+pub struct DDGSearcher {
+    pub client: reqwest::Client,
+    pub base_url: String,
+}
+
+impl DDGSearcher {
+    pub fn new() -> Self {
+        DDGSearcher {
+            client: reqwest::Client::new(),
+            base_url:
"https://duckduckgo.com".to_string(), + } + } + + pub async fn search(&self, query: &str) -> Result, Box> { + let url = format!("{}/html/?q={}", self.base_url, query); + let resp = self.client.get(&url).send().await?; + let body = resp.text().await?; + let document = Html::parse_document(&body); + + let result_selector = Selector::parse(".web-result").unwrap(); + let result_title_selector = Selector::parse(".result__a").unwrap(); + let result_url_selector = Selector::parse(".result__url").unwrap(); + let result_snippet_selector = Selector::parse(".result__snippet").unwrap(); + + let results = document.select(&result_selector).map(|result| { + + let title = result.select(&result_title_selector).next().unwrap().text().collect::>().join(""); + let link = result.select(&result_url_selector).next().unwrap().text().collect::>().join("").trim().to_string(); + let snippet = result.select(&result_snippet_selector).next().unwrap().text().collect::>().join(""); + + SearchResult { + title, + link, + //url: String::from(url.value().attr("href").unwrap()), + snippet, + } + }).collect::>(); + + Ok(results) + } +} + +#[async_trait] +impl Tool for DDGSearcher { + fn name(&self) -> String { + "DDG Searcher".to_string() + } + + fn description(&self) -> String { + "Searches the web using DuckDuckGo's HTML interface.".to_string() + } + + fn parameters(&self) -> Value { + json!({ + "description": "This tool lets you search the web using DuckDuckGo. The input should be a search query.", + "type": "object", + "properties": { + "query": { + "description": "The search query to send to DuckDuckGo", + "type": "string" + } + }, + "required": ["query"] + }) + } + + async fn call(&self, input: &str) -> Result> { + let input_value = self.parse_input(input).await; + self.run(input_value).await + } + + async fn run(&self, input: Value) -> Result> { + let query = input.as_str().ok_or("Input should be a string")?; + let results = self.search(query).await?; + let results_json = serde_json::to_string(&results)?; + Ok(results_json) + } + + async fn parse_input(&self, input: &str) -> Value { + // Use default implementation provided in the Tool trait + Tool::parse_input(self, input).await + } +} \ No newline at end of file diff --git a/src/functions/tools/weather.rs b/src/functions/tools/weather.rs new file mode 100644 index 0000000..2671677 --- /dev/null +++ b/src/functions/tools/weather.rs @@ -0,0 +1,33 @@ +use async_trait::async_trait; +use serde_json::{json, Value}; +use std::error::Error; +use crate::generation::functions::tools::Tool; + +pub struct WeatherTool; + +#[async_trait] +impl Tool for WeatherTool { + fn name(&self) -> String { + "WeatherTool".to_string() + } + + fn description(&self) -> String { + "Get the current weather in a given location.".to_string() + } + + async fn run(&self, input: Value) -> Result> { + let location = input.as_str().ok_or("Input should be a string")?; + let unit = "fahrenheit"; // Default unit + let result = if location.to_lowercase().contains("tokyo") { + json!({"location": "Tokyo", "temperature": "10", "unit": unit}) + } else if location.to_lowercase().contains("san francisco") { + json!({"location": "San Francisco", "temperature": "72", "unit": unit}) + } else if location.to_lowercase().contains("paris") { + json!({"location": "Paris", "temperature": "22", "unit": unit}) + } else { + json!({"location": location, "temperature": "unknown"}) + }; + + Ok(result.to_string()) + } +} diff --git a/src/lib.rs b/src/lib.rs index 18ec517..9365de4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,6 
+3,8 @@ pub mod generation;
 #[cfg(feature = "chat-history")]
 pub mod history;
 pub mod models;
+#[cfg(feature = "function-calling")]
+pub mod functions;
 
 use url::Url;

From 51ec2fd97189f049c652c13326fefd871f5d9275 Mon Sep 17 00:00:00 2001
From: andthattoo
Date: Fri, 17 May 2024 13:19:55 +0300
Subject: [PATCH 02/18] implemented the ollama.function() standard for function calling implemented base traits removed unnecessary code & import

---
 Cargo.toml                                | 10 +-
 src/functions/mod.rs                      | 88 +++++++--------
 src/functions/pipelines/mod.rs            | 15 ++-
 src/functions/pipelines/openai/mod.rs     | 90 +--------------
 src/functions/pipelines/openai/parsers.rs | 23 ----
 src/functions/pipelines/openai/prompts.rs | 3 +
 src/functions/pipelines/openai/request.rs | 127 +++++++++++++---------
 src/functions/requests.rs                 | 39 +++++++
 src/functions/tools/search_ddg.rs         | 15 +--
 src/generation.rs                         | 1 +
 10 files changed, 186 insertions(+), 225 deletions(-)
 delete mode 100644 src/functions/pipelines/openai/parsers.rs
 create mode 100644 src/functions/requests.rs

diff --git a/Cargo.toml b/Cargo.toml
index ae318ca..078c31c 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -15,18 +15,18 @@ serde = { version = "1", features = ["derive"] }
 serde_json = "1"
 tokio = { version = "1", features = ["full"], optional = true }
 tokio-stream = { version = "0.1.15", optional = true }
+async-trait = { version = "0.1.73" } # Remove optional = true
 url = "2"
-log = "0.4"
-scraper = {version = "0.19.0" , optional = true }
-text-splitter = {version = "0.13.1", optional = true }
-async-trait = { version = "0.1.73", optional = true }
+log = "0.4" # Add this line
+scraper = {version = "0.19.0" , optional = true } # Add scraper dependency
+text-splitter = {version = "0.13.1", optional = true } # Add text_splitter dependency
 
 [features]
 default = ["reqwest/default-tls"]
 stream = ["tokio-stream", "reqwest/stream", "tokio"]
 rustls = ["reqwest/rustls-tls"]
 chat-history = []
-function-calling = ["scraper", "text-splitter", "async-trait"]
+function-calling = ["scraper", "text-splitter", "chat-history"]
 
 [dev-dependencies]
 tokio = { version = "1", features = ["full"] }
diff --git a/src/functions/mod.rs b/src/functions/mod.rs
index d40df22..0f2fcbc 100644
--- a/src/functions/mod.rs
+++ b/src/functions/mod.rs
@@ -1,53 +1,53 @@
 pub mod tools;
 pub mod pipelines;
+pub mod request;
 
 pub use tools::WeatherTool;
 pub use tools::Scraper;
 pub use tools::DDGSearcher;
-
-use async_trait::async_trait;
-use serde_json::{Value, json};
-use std::error::Error;
-use crate::generation::chat::ChatMessage;
-
-
-pub trait FunctionCallBase: Send + Sync {
-    fn name(&self) -> String;
-}
-
-#[async_trait]
-pub trait FunctionCall: FunctionCallBase {
-    async fn call(&self, params: Value) -> Result<Value, Box<dyn Error>>;
-}
-
-pub struct DefaultFunctionCall {}
-
-impl FunctionCallBase for DefaultFunctionCall {
-    fn name(&self) -> String {
-        "default_function".to_string()
+pub use crate::generation::functions::request::FunctionCallRequest;
+pub use crate::generation::functions::pipelines::openai::request::OpenAIFunctionCall;
+
+use crate::generation::chat::ChatMessageResponse;
+use crate::generation::chat::request::{ChatMessageRequest};
+use crate::generation::functions::tools::Tool;
+use crate::error::OllamaError;
+use std::sync::Arc;
+use crate::generation::functions::pipelines::RequestParserBase;
+
+
+/*impl Ollama {
+    pub fn new_with_function_calling() -> Self{
+        let m: Ollama = Ollama::new_default_with_history(30);
+        m.add_assistant_response("default".to_string(), "openai".to_string());
+        return m
+    }
-}
-
-
-pub fn
convert_to_ollama_tool(tool: &dyn crate::generation::functions::tools::Tool) -> Value {
-    let schema = tool.parameters();
-    json!({
-        "name": tool.name(),
-        "properties": schema["properties"],
-        "required": schema["required"]
-    })
-}
-
-
-pub fn parse_response(message: &ChatMessage) -> Result<String, String> {
-    let content = &message.content;
-    let value: Value = serde_json::from_str(content).map_err(|e| e.to_string())?;
-
-    if let Some(function_call) = value.get("function_call") {
-        if let Some(arguments) = function_call.get("arguments") {
-            return Ok(arguments.to_string());
-        }
-        return Err("`arguments` missing from `function_call`".to_string());
+}*/
+
+
+#[cfg(feature = "function-calling")]
+impl crate::Ollama {
+
+    pub async fn send_function_call_with_history(
+        &mut self,
+        request: FunctionCallRequest,
+        parser: Arc<dyn RequestParserBase>,
+    ) -> Result<ChatMessageResponse, OllamaError> {
+        //let system_message = request.chat.messages.first().unwrap().clone();
+        let system_prompt = parser.get_system_message(&request.tools).await; //TODO: Check if system prompt is added
+        self.send_chat_messages_with_history(
+            ChatMessageRequest::new(request.chat.model_name.clone(), vec![system_prompt.clone()]),
+            "default".to_string(),
+        ).await?;
+
+        let result = self
+            .send_chat_messages_with_history(
+                ChatMessageRequest::new(request.chat.model_name.clone(), request.chat.messages),
+                "default".to_string(),
+            ).await?;
+
+        let response_content: String = result.message.clone().unwrap().content;
+        let result = parser.parse(&response_content, request.chat.model_name.clone(), request.tools).await?;
+        return Ok(result);
     }
-    Err("`function_call` missing from `content`".to_string())
 }
diff --git a/src/functions/pipelines/mod.rs b/src/functions/pipelines/mod.rs
index 727e688..ab6129e 100644
--- a/src/functions/pipelines/mod.rs
+++ b/src/functions/pipelines/mod.rs
@@ -1,2 +1,15 @@
+use crate::generation::chat::{ChatMessage, ChatMessageResponse};
+use crate::error::OllamaError;
+use crate::generation::functions::tools::Tool;
+use async_trait::async_trait;
+use std::sync::Arc;
+
 pub mod openai;
-pub mod nous_hermes;
\ No newline at end of file
+pub mod nous_hermes;
+
+
+#[async_trait]
+pub trait RequestParserBase {
+    async fn parse(&self, input: &str, model_name: String, tools: Vec<Arc<dyn Tool>>) -> Result<ChatMessageResponse, OllamaError>;
+    async fn get_system_message(&self, tools: &[Arc<dyn Tool>]) -> ChatMessage;
+}
diff --git a/src/functions/pipelines/openai/mod.rs b/src/functions/pipelines/openai/mod.rs
index 0ef6f78..bc9e7a4 100644
--- a/src/functions/pipelines/openai/mod.rs
+++ b/src/functions/pipelines/openai/mod.rs
@@ -1,92 +1,4 @@
 pub mod prompts;
-pub mod parsers;
 pub mod request;
 
-pub use prompts::{DEFAULT_SYSTEM_TEMPLATE, DEFAULT_RESPONSE_FUNCTION};
-pub use request::FunctionCallRequest;
-pub use parsers::{generate_system_message, parse_response};
-
-use std::sync::Arc;
-use async_trait::async_trait;
-use serde_json::{json, Value};
-use std::error::Error;
-use crate::generation::functions::{FunctionCall, FunctionCallBase};
-use crate::generation::chat::{ChatMessage, ChatMessageResponse};
-use crate::generation::chat::request::{ChatMessageRequest};
-use crate::generation::functions::tools::Tool;
-use crate::error::OllamaError;
-
-
-pub struct OpenAIFunctionCall {
-    pub name: String,
-}
-
-impl OpenAIFunctionCall {
-    pub fn new(name: &str) -> Self {
-        OpenAIFunctionCall {
-            name: name.to_string(),
-        }
-    }
-}
-
-impl FunctionCallBase for OpenAIFunctionCall {
-    fn name(&self) -> String {
-        "openai".to_string()
-    }
-}
-
-#[async_trait]
-impl FunctionCall for OpenAIFunctionCall {
-    async fn call(&self, params: Value) ->
Result<Value, Box<dyn Error>> {
-        // Simulate a function call by returning a simple JSON value
-        Ok(json!({ "result": format!("Function {} called with params: {}", self.name, params) }))
-    }
-}
-
-
-impl crate::Ollama {
-    pub async fn function_call_with_history(
-        &self,
-        request: ChatMessageRequest,
-        tool: Arc<dyn Tool>,
-    ) -> Result<ChatMessageResponse, OllamaError> {
-        let function_call = OpenAIFunctionCall::new(&tool.name());
-        let params = tool.parameters();
-        let result = function_call.call(params).await?;
-        Ok(ChatMessageResponse {
-            model: request.model_name,
-            created_at: "".to_string(),
-            message: Some(ChatMessage::assistant(result.to_string())),
-            done: true,
-            final_data: None,
-        })
-    }
-
-    pub async fn function_call(
-        &self,
-        request: ChatMessageRequest,
-    ) -> crate::error::Result<ChatMessageResponse> {
-        let mut request = request;
-        request.stream = false;
-
-        let url = format!("{}api/chat", self.url_str());
-        let serialized = serde_json::to_string(&request).map_err(|e| e.to_string())?;
-        let res = self
-            .reqwest_client
-            .post(url)
-            .body(serialized)
-            .send()
-            .await
-            .map_err(|e| e.to_string())?;
-
-        if !res.status().is_success() {
-            return Err(res.text().await.unwrap_or_else(|e| e.to_string()).into());
-        }
-
-        let bytes = res.bytes().await.map_err(|e| e.to_string())?;
-        let res =
-            serde_json::from_slice::<ChatMessageResponse>(&bytes).map_err(|e| e.to_string())?;
-
-        Ok(res)
-    }
-}
\ No newline at end of file
+pub use prompts::{DEFAULT_SYSTEM_TEMPLATE, DEFAULT_RESPONSE_FUNCTION};
\ No newline at end of file
diff --git a/src/functions/pipelines/openai/parsers.rs b/src/functions/pipelines/openai/parsers.rs
deleted file mode 100644
index 980ed95..0000000
--- a/src/functions/pipelines/openai/parsers.rs
+++ /dev/null
@@ -1,23 +0,0 @@
-use crate::generation::chat::ChatMessage;
-use serde_json::Value;
-use crate::generation::functions::pipelines::openai::DEFAULT_SYSTEM_TEMPLATE;
-use crate::generation::functions::tools::Tool;
-
-pub fn parse_response(message: &ChatMessage) -> Result<Value, String> {
-    let content = &message.content;
-    let value: Value = serde_json::from_str(content).map_err(|e| e.to_string())?;
-
-    if let Some(function_call) = value.get("function_call") {
-        Ok(function_call.clone())
-    } else {
-        Ok(value)
-    }
-}
-
-pub fn generate_system_message(tools: &[&dyn Tool]) -> ChatMessage {
-    let tools_info: Vec<Value> = tools.iter().map(|tool| tool.parameters()).collect();
-    let tools_json = serde_json::to_string(&tools_info).unwrap();
-    let system_message_content = DEFAULT_SYSTEM_TEMPLATE.replace("{tools}", &tools_json);
-    ChatMessage::system(system_message_content)
-}
-
diff --git a/src/functions/pipelines/openai/prompts.rs b/src/functions/pipelines/openai/prompts.rs
index 94e0749..e858ca9 100644
--- a/src/functions/pipelines/openai/prompts.rs
+++ b/src/functions/pipelines/openai/prompts.rs
@@ -1,8 +1,11 @@
 pub const DEFAULT_SYSTEM_TEMPLATE: &str = r#"
+You are a function calling AI agent with self-recursion.
+You can call only one function at a time and analyse the data you get from the function response.
 You have access to the following tools:
 
 {tools}
 
+Don't make assumptions about what values to plug into function arguments.
 You must always select one of the above tools and respond with only a JSON object matching the following schema:
 
 {
     "tool": <name of the selected tool>,
     "tool_input": <parameters for the selected tool, matching the tool's JSON schema>
 }
diff --git a/src/functions/pipelines/openai/request.rs b/src/functions/pipelines/openai/request.rs
index 301b6ae..d306d94 100644
--- a/src/functions/pipelines/openai/request.rs
+++ b/src/functions/pipelines/openai/request.rs
@@ -1,66 +1,91 @@
-use serde_json::Value;
-use std::sync::Arc;
-use crate::generation::chat::{ChatMessage, ChatMessageResponse};
-use crate::generation::chat::request::{ChatMessageRequest};
-use crate::generation::functions::pipelines::openai::DEFAULT_SYSTEM_TEMPLATE;
-use crate::generation::functions::tools::Tool;
-use crate::Ollama;
-use crate::error::OllamaError;
+use serde::Serialize;
 
-#[derive(Clone)]
-pub struct FunctionCallRequest {
-    model_name: String,
-    tools: Vec<Arc<dyn Tool>>,
+use crate::generation::{
+    images::Image,
+    options::GenerationOptions,
+    parameters::{FormatType, KeepAlive},
+};
+
+use super::GenerationContext;
+
+/// A generation request to Ollama.
+#[derive(Debug, Clone, Serialize)]
+pub struct GenerationRequest {
+    #[serde(rename = "model")]
+    pub model_name: String,
+    pub prompt: String,
+    pub images: Vec<Image>,
+    pub options: Option<GenerationOptions>,
+    pub system: Option<String>,
+    pub template: Option<String>,
+    pub context: Option<GenerationContext>,
+    pub format: Option<FormatType>,
+    pub keep_alive: Option<KeepAlive>,
+    pub(crate) stream: bool,
 }
 
-impl FunctionCallRequest {
-    pub fn new(model_name: &str, tools: Vec<Arc<dyn Tool>>) -> Self {
-        FunctionCallRequest {
-            model_name: model_name.to_string(),
-            tools,
+impl GenerationRequest {
+    pub fn new(model_name: String, prompt: String) -> Self {
+        Self {
+            model_name,
+            prompt,
+            images: Vec::new(),
+            options: None,
+            system: None,
+            template: None,
+            context: None,
+            format: None,
+            keep_alive: None,
+            // Stream value will be overwritten by Ollama::generate_stream() and Ollama::generate() methods
+            stream: false,
         }
     }
 
-    pub async fn send(&self, ollama: &mut Ollama, input: &str) -> Result<ChatMessageResponse, OllamaError> {
-        let system_message = self.get_system_message();
-        ollama.send_chat_messages_with_history(
-            ChatMessageRequest::new(self.model_name.clone(), vec![system_message.clone()]),
-            "default".to_string(),
-        ).await?;
+    /// A list of images to be used with the prompt
+    pub fn images(mut self, images: Vec<Image>) -> Self {
+        self.images = images;
+        self
+    }
+
+    /// Add an image to be used with the prompt
+    pub fn add_image(mut self, image: Image) -> Self {
+        self.images.push(image);
+        self
+    }
 
-        let user_message = ChatMessage::user(input.to_string());
+    /// Additional model parameters listed in the documentation for the Modelfile
+    pub fn options(mut self, options: GenerationOptions) -> Self {
+        self.options = Some(options);
+        self
+    }
 
-        let result = ollama
-            .send_chat_messages_with_history(
-                ChatMessageRequest::new(self.model_name.clone(), vec![user_message]),
-                "default".to_string(),
-            ).await?;
+    /// System prompt (overrides what is defined in the Modelfile)
+    pub fn system(mut self, system: String) -> Self {
+        self.system = Some(system);
+        self
+    }
 
-        let response_content = result.message.clone().unwrap().content;
-        let response_value: Value = match serde_json::from_str(&response_content) {
-            Ok(value) => value,
-            Err(e) => return Err(OllamaError::from(e.to_string())),
-        };
+    /// The full prompt or prompt template (overrides what is defined in the Modelfile)
+    pub fn template(mut self, template: String) -> Self {
+        self.template = Some(template);
+        self
+    }
 
-        if let Some(function_call) = response_value.get("function_call") {
-            if let Some(tool_name) = function_call.get("tool").and_then(Value::as_str) {
-                if let Some(tool) = self.tools.iter().find(|t| t.name() == tool_name) {
-                    let result = ollama.function_call_with_history(
-                        ChatMessageRequest::new(self.model_name.clone(), vec![ChatMessage::user(tool_name.to_string())]),
-                        tool.clone(),
-                    ).await?;
-                    return Ok(result);
-                }
-            }
-        }
+    /// The context parameter returned from a previous request to /generate, this can be used to keep a short conversational memory
+    pub fn context(mut self, context: GenerationContext) -> Self {
+        self.context = Some(context);
+        self
+    }
 
-        Ok(result)
+    // The format to return a response in. Currently the only accepted value is `json`
+    pub fn format(mut self, format: FormatType) -> Self {
+        self.format = Some(format);
+        self
     }
 
-    pub fn get_system_message(&self) -> ChatMessage {
-        let tools_info: Vec<Value> = self.tools.iter().map(|tool| tool.parameters()).collect();
-        let tools_json = serde_json::to_string(&tools_info).unwrap();
-        let system_message_content = DEFAULT_SYSTEM_TEMPLATE.replace("{tools}", &tools_json);
-        ChatMessage::system(system_message_content)
+    /// Used to control how long a model stays loaded in memory, by default models are unloaded after 5 minutes of inactivity
+    pub fn keep_alive(mut self, keep_alive: KeepAlive) -> Self {
+        self.keep_alive = Some(keep_alive);
+        self
     }
 }
diff --git a/src/functions/requests.rs b/src/functions/requests.rs
new file mode 100644
index 0000000..11fd439
--- /dev/null
+++ b/src/functions/requests.rs
@@ -0,0 +1,39 @@
+use crate::generation::chat::request::ChatMessageRequest;
+use std::sync::Arc;
+use crate::generation::chat::ChatMessage;
+use crate::generation::{options::GenerationOptions, parameters::FormatType};
+use crate::generation::functions::Tool;
+
+#[derive(Clone)]
+pub struct FunctionCallRequest {
+    pub chat: ChatMessageRequest,
+    pub tools: Vec<Arc<dyn Tool>>
+}
+
+impl FunctionCallRequest {
+    pub fn new(model_name: String, tools: Vec<Arc<dyn Tool>>, messages: Vec<ChatMessage>) -> Self {
+        let chat = ChatMessageRequest::new(model_name, messages);
+        Self {
+            chat,
+            tools
+        }
+    }
+
+    /// Additional model parameters listed in the documentation for the Modelfile
+    pub fn options(mut self, options: GenerationOptions) -> Self {
+        self.chat.options = Some(options);
+        self
+    }
+
+    /// The full prompt or prompt template (overrides what is defined in the Modelfile)
+    pub fn template(mut self, template: String) -> Self {
+        self.chat.template = Some(template);
+        self
+    }
+
+    // The format to return a response in.
+#[derive(Clone)]
+pub struct FunctionCallRequest {
+    pub chat: ChatMessageRequest,
+    pub tools: Vec<Arc<dyn Tool>>
+}
+
+impl FunctionCallRequest {
+    pub fn new(model_name: String, tools: Vec<Arc<dyn Tool>>, messages: Vec<ChatMessage>) -> Self {
+        let chat = ChatMessageRequest::new(model_name, messages);
+        Self {
+            chat,
+            tools
+        }
+    }
+
+    /// Additional model parameters listed in the documentation for the Modelfile
+    pub fn options(mut self, options: GenerationOptions) -> Self {
+        self.chat.options = Some(options);
+        self
+    }
+
+    /// The full prompt or prompt template (overrides what is defined in the Modelfile)
+    pub fn template(mut self, template: String) -> Self {
+        self.chat.template = Some(template);
+        self
+    }
+
+    // The format to return a response in. Currently the only accepted value is `json`
+    pub fn format(mut self, format: FormatType) -> Self {
+        self.chat.format = Some(format);
+        self
+    }
+}
\ No newline at end of file
diff --git a/src/functions/tools/search_ddg.rs b/src/functions/tools/search_ddg.rs
index 9407300..67ddfae 100644
--- a/src/functions/tools/search_ddg.rs
+++ b/src/functions/tools/search_ddg.rs
@@ -1,7 +1,5 @@
use reqwest;

-use url::Url;
-
use scraper::{Html, Selector};

use std::error::Error;
@@ -18,12 +16,6 @@ pub struct SearchResult {
    snippet: String,
}

-impl SearchResult {
-    fn extract_domain(url: &str) -> Option<String> {
-        Url::parse(url).ok()?.domain().map(|d| d.to_string())
-    }
-}
-
pub struct DDGSearcher {
    pub client: reqwest::Client,
    pub base_url: String,
@@ -82,8 +74,8 @@ impl Tool for DDGSearcher {
                "type": "object",
                "properties": {
                    "query": {
-                        "description": "The search query to send to DuckDuckGo",
-                        "type": "string"
+                        "type": "string",
+                        "description": "The search query to send to DuckDuckGo"
                    }
                },
                "required": ["query"]
@@ -96,14 +88,13 @@ impl Tool for DDGSearcher {
    }

    async fn run(&self, input: Value) -> Result<String, Box<dyn Error>> {
-        let query = input.as_str().ok_or("Input should be a string")?;
+        let query = input["query"].as_str().unwrap();
        let results = self.search(query).await?;
        let results_json = serde_json::to_string(&results)?;
        Ok(results_json)
    }

    async fn parse_input(&self, input: &str) -> Value {
-        // Use default implementation provided in the Tool trait
        Tool::parse_input(self, input).await
    }
}
\ No newline at end of file
diff --git a/src/generation.rs b/src/generation.rs
index e739bcd..b47ec20 100644
--- a/src/generation.rs
+++ b/src/generation.rs
@@ -4,3 +4,4 @@ pub mod embeddings;
pub mod images;
pub mod options;
pub mod parameters;
+pub mod functions;
From cfc4457d711ec51ced7aa228c2d93dd2c7d57c48 Mon Sep 17 00:00:00 2001
From: andthattoo
Date: Fri, 17 May 2024 14:46:22 +0300
Subject: [PATCH 03/18] OpenAIFunctionCall imported Added NousHermes prompts.rs

---
 src/functions/mod.rs                          | 54 +++++--
 .../pipelines/nous_hermes/prompts.rs          | 40 +++++
 src/functions/pipelines/openai/request.rs     | 148 +++++++++---------
 src/functions/{requests.rs => request.rs}     |  0
 src/functions/tools/mod.rs                    |  2 -
 src/lib.rs                                    |  2 -
 6 files changed, 147 insertions(+), 99 deletions(-)
 rename src/functions/{requests.rs => request.rs} (100%)

diff --git a/src/functions/mod.rs b/src/functions/mod.rs
index 0f2fcbc..df5e669 100644
--- a/src/functions/mod.rs
+++ b/src/functions/mod.rs
@@ -2,13 +2,12 @@ pub mod tools;
pub mod pipelines;
pub mod request;

-pub use tools::WeatherTool;
pub use tools::Scraper;
pub use tools::DDGSearcher;
pub use crate::generation::functions::request::FunctionCallRequest;
pub use crate::generation::functions::pipelines::openai::request::OpenAIFunctionCall;

-use crate::generation::chat::ChatMessageResponse;
+use crate::generation::chat::{ChatMessage, ChatMessageResponse};
use crate::generation::chat::request::{ChatMessageRequest};
use crate::generation::functions::tools::Tool;
use crate::error::OllamaError;
@@ -16,29 +15,28 @@ use std::sync::Arc;
use crate::generation::functions::pipelines::RequestParserBase;

-/*impl Ollama {
-    pub fn new_with_function_calling() -> Self{
-        let m: Ollama = Ollama::new_default_with_history(30);
-        m.add_assistant_response("default".to_string(), "openai".to_string());
-        return m
-    }
-}*/
-
-
#[cfg(feature = "function-calling")]
impl crate::Ollama {
+    pub async fn check_system_message(&self, messages: &Vec<ChatMessage>, system_prompt: &str) -> bool {
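+        // Compares the first message in history against the parser's system prompt; assumes the history is non-empty.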
+        let system_message = messages.first().unwrap().clone();
+        return system_message.content == system_prompt
+    }
+
+    #[cfg(feature = "chat-history")]
    pub async fn send_function_call_with_history(
        &mut self,
        request: FunctionCallRequest,
        parser: Arc<dyn RequestParserBase>,
    ) -> Result<ChatMessageResponse, OllamaError> {
-        //let system_message = request.chat.messages.first().unwrap().clone();
-        let system_prompt = parser.get_system_message(&request.tools).await; //TODO: Check if system prompt is added
-        self.send_chat_messages_with_history(
-            ChatMessageRequest::new(request.chat.model_name.clone(), vec![system_prompt.clone()]),
-            "default".to_string(),
-        ).await?;
+
+        let system_prompt = parser.get_system_message(&request.tools).await;
+        if request.chat.messages.len() == 0{ // If there are no messages in the chat, add a system prompt
+            self.send_chat_messages_with_history(
+                ChatMessageRequest::new(request.chat.model_name.clone(), vec![system_prompt.clone()]),
+                "default".to_string(),
+            ).await?;
+        }

        let result = self
            .send_chat_messages_with_history(
@@ -50,4 +48,26 @@ impl crate::Ollama {
        let result = parser.parse(&response_content, request.chat.model_name.clone(), request.tools).await?;
        return Ok(result);
    }
+
+
+    pub async fn send_function_call(
+        &self,
+        request: FunctionCallRequest,
+        parser: Arc<dyn RequestParserBase>,
+    ) -> crate::error::Result<ChatMessageResponse> {
+        let mut request = request;
+
+        request.chat.stream = false;
+        let system_prompt = parser.get_system_message(&request.tools).await;
+        let model_name = request.chat.model_name.clone();
+
+        //Make sure the first message in chat is the system prompt
+        if !self.check_system_message(&request.chat.messages, &system_prompt.content).await {
+            request.chat.messages.insert(0, system_prompt);
+        }
+        let result = self.send_chat_messages(request.chat).await?;
+        let response_content: String = result.message.clone().unwrap().content;
+        let result = parser.parse(&response_content, model_name, request.tools).await?;
+        return Ok(result);
+    }
}
diff --git a/src/functions/pipelines/nous_hermes/prompts.rs b/src/functions/pipelines/nous_hermes/prompts.rs
index e69de29..f237a7b 100644
--- a/src/functions/pipelines/nous_hermes/prompts.rs
+++ b/src/functions/pipelines/nous_hermes/prompts.rs
@@ -0,0 +1,40 @@
+pub const DEFAULT_SYSTEM_TEMPLATE: &str = r#"
+Role: |
+  You are a function calling AI agent with self-recursion.
+  You can call only one function at a time and analyse data you get from function response.
+  You are provided with function signatures within <tools></tools> XML tags.
+  The current date is: {date}.
+Objective: |
+  You may use agentic frameworks for reasoning and planning to help with user query.
+  Please call a function and wait for function results to be provided to you in the next iteration.
+  Don't make assumptions about what values to plug into function arguments.
+  Once you have called a function, results will be fed back to you within <tool_response></tool_response> XML tags.
+  Don't make assumptions about tool results if <tool_response> XML tags are not present since function hasn't been executed yet.
+  Analyze the data once you get the results and call another function.
+  At each iteration please continue adding your analysis to the previous summary.
+  Your final response should directly answer the user query with an analysis or summary of the results of function calls.
+Tools: |
+  Here are the available tools:
+  <tools> {tools} </tools>
+  If the provided function signatures don't have the function you must call, you may write executable python code in markdown syntax and call code_interpreter() function as follows:
+  <tool_call>
+  {{"arguments": {{"code_markdown": <python-code>, "name": "code_interpreter"}}}}
+  </tool_call>
+  Make sure that the json object above with code markdown block is parseable with json.loads() and the XML block with XML ElementTree.
+Examples: |
+  Here are some example usages of functions:
+  {examples}
+Schema: |
+  Use the following pydantic model json schema for each tool call you will make:
+  {schema}
+Instructions: |
+  At the very first turn you don't have <tool_results> so you shouldn't make up the results.
+  Please keep a running summary with analysis of previous function results and summaries from previous iterations.
+  Do not stop calling functions until the task has been accomplished or you've reached the max iteration of 10.
+  Calling multiple functions at once can overload the system and increase cost so please call one function at a time.
+  If you plan to continue with analysis, always call another function.
+  For each function call return a valid json object (using double quotes) with function name and arguments within <tool_call></tool_call> XML tags as follows:
+  <tool_call>
+  {{"arguments": <args-dict>, "name": <function-name>}}
+  </tool_call>
+"#;
\ No newline at end of file
diff --git a/src/functions/pipelines/openai/request.rs b/src/functions/pipelines/openai/request.rs
index d306d94..6b35615 100644
--- a/src/functions/pipelines/openai/request.rs
+++ b/src/functions/pipelines/openai/request.rs
@@ -1,91 +1,83 @@
-use serde::Serialize;
+use serde_json::Value;
+use std::sync::Arc;
+use crate::generation::chat::{ChatMessage, ChatMessageResponse};
+use crate::generation::functions::pipelines::openai::DEFAULT_SYSTEM_TEMPLATE;
+use crate::generation::functions::tools::Tool;
+use crate::error::OllamaError;
+use crate::generation::functions::pipelines::RequestParserBase;
+use serde_json::json;
+use serde::{Deserialize, Serialize};
+use async_trait::async_trait;

-use crate::generation::{
-    images::Image,
-    options::GenerationOptions,
-    parameters::{FormatType, KeepAlive},
-};
-
-use super::GenerationContext;
-
-/// A generation request to Ollama.
-#[derive(Debug, Clone, Serialize)]
-pub struct GenerationRequest {
-    #[serde(rename = "model")]
-    pub model_name: String,
-    pub prompt: String,
-    pub images: Vec<Image>,
-    pub options: Option<GenerationOptions>,
-    pub system: Option<String>,
-    pub template: Option<String>,
-    pub context: Option<GenerationContext>,
-    pub format: Option<FormatType>,
-    pub keep_alive: Option<KeepAlive>,
-    pub(crate) stream: bool,
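+/// Flattens a tool's JSON schema into the {name, properties, required} object that is embedded in the system prompt.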
+pub fn convert_to_ollama_tool(tool: &Arc<dyn Tool>) -> Value {
+    let schema = tool.parameters();
+    json!({
+        "name": tool.name(),
+        "properties": schema["properties"],
+        "required": schema["required"]
+    })
}

-impl GenerationRequest {
-    pub fn new(model_name: String, prompt: String) -> Self {
-        Self {
-            model_name,
-            prompt,
-            images: Vec::new(),
-            options: None,
-            system: None,
-            template: None,
-            context: None,
-            format: None,
-            keep_alive: None,
-            // Stream value will be overwritten by Ollama::generate_stream() and Ollama::generate() methods
-            stream: false,
-        }
-    }
-
-    /// A list of images to be used with the prompt
-    pub fn images(mut self, images: Vec<Image>) -> Self {
-        self.images = images;
-        self
-    }
-
-    /// Add an image to be used with the prompt
-    pub fn add_image(mut self, image: Image) -> Self {
-        self.images.push(image);
-        self
-    }
+#[derive(Debug, Deserialize, Serialize)]
+pub struct OpenAIFunctionCallSignature {
+    pub tool: String, //name of the tool
+    pub tool_input: Value,
+}

-    /// Additional model parameters listed in the documentation for the Modelfile
-    pub fn options(mut self, options: GenerationOptions) -> Self {
-        self.options = Some(options);
-        self
-    }
+pub struct OpenAIFunctionCall {}

-    /// System prompt (overrides what is defined in the Modelfile)
-    pub fn system(mut self, system: String) -> Self {
-        self.system = Some(system);
-        self
-    }
+impl OpenAIFunctionCall {

-    /// The full prompt or prompt template (overrides what is defined in the Modelfile)
-    pub fn template(mut self, template: String) -> Self {
-        self.template = Some(template);
-        self
-    }
+    pub async fn function_call_with_history(
+        &self,
+        model_name: String,
+        tool_params: Value,
+        tool: Arc<dyn Tool>,
+    ) -> Result<ChatMessageResponse, OllamaError> {

-    /// The context parameter returned from a previous request to /generate, this can be used to keep a short conversational memory
-    pub fn context(mut self, context: GenerationContext) -> Self {
-        self.context = Some(context);
-        self
+        let result = tool.run(tool_params).await;
+        return match result {
+            Ok(result) => {
+                Ok(ChatMessageResponse {
+                    model: model_name.clone(),
+                    created_at: "".to_string(),
+                    message: Some(ChatMessage::assistant(result.to_string())),
+                    done: true,
+                    final_data: None,
+                })
+            },
+            Err(e) => Err(OllamaError::from(e))
+        };
    }
+}

-    // The format to return a response in. Currently the only accepted value is `json`
-    pub fn format(mut self, format: FormatType) -> Self {
-        self.format = Some(format);
-        self
+#[async_trait]
+impl RequestParserBase for OpenAIFunctionCall {
+    async fn parse(&self, input: &str, model_name: String, tools: Vec<Arc<dyn Tool>>) -> Result<ChatMessageResponse, OllamaError> {
+        let response_value: Result<OpenAIFunctionCallSignature, serde_json::Error> = serde_json::from_str(input);
+        match response_value {
+            Ok(response) => {
+                if let Some(tool) = tools.iter().find(|t| t.name() == response.tool) {
+                    let tool_params = response.tool_input;
+                    let result = self.function_call_with_history(model_name.clone(),
+                                                                 tool_params.clone(),
+                                                                 tool.clone(),
+                    ).await?;
+                    return Ok(result);
+                } else {
+                    return Err(OllamaError::from("Tool not found".to_string()));
+                }
+            },
+            Err(e) => {
+                return Err(OllamaError::from(e));
+            }
+        }
    }

-    /// Used to control how long a model stays loaded in memory, by default models are unloaded after 5 minutes of inactivity
-    pub fn keep_alive(mut self, keep_alive: KeepAlive) -> Self {
-        self.keep_alive = Some(keep_alive);
-        self
+    async fn get_system_message(&self, tools: &[Arc<dyn Tool>]) -> ChatMessage { // Corrected here to use a slice
+        let tools_info: Vec<Value> = tools.iter().map(|tool| convert_to_ollama_tool(tool)).collect();
+        let tools_json = serde_json::to_string(&tools_info).unwrap();
+        let system_message_content = DEFAULT_SYSTEM_TEMPLATE.replace("{tools}", &tools_json);
+        ChatMessage::system(system_message_content)
    }
}
diff --git a/src/functions/requests.rs b/src/functions/request.rs
similarity index 100%
rename from src/functions/requests.rs
rename to src/functions/request.rs
diff --git a/src/functions/tools/mod.rs b/src/functions/tools/mod.rs
index d824fc9..6b29b62 100644
--- a/src/functions/tools/mod.rs
+++ b/src/functions/tools/mod.rs
@@ -1,8 +1,6 @@
pub mod search_ddg;
-pub mod weather;
pub mod scraper;

-pub use self::weather::WeatherTool;
pub use self::scraper::Scraper;
pub use self::search_ddg::DDGSearcher;
diff --git a/src/lib.rs b/src/lib.rs
index 9365de4..18ec517 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -3,8 +3,6 @@ pub mod generation;
#[cfg(feature = "chat-history")]
pub mod history;
pub mod models;
-#[cfg(feature = "function-calling")]
-pub mod functions;

use url::Url;
From 3b3144a681f64c1f8b354e1bce7e34a97262e4ee Mon Sep 17 00:00:00 2001
From: andthattoo
Date: Fri, 17 May 2024 18:43:50 +0300
Subject: [PATCH 04/18] added nous function calling

---
 src/functions/pipelines/nous_hermes/mod.rs    | 25 ++++-
 .../pipelines/nous_hermes/parsers.rs          |  0
 .../pipelines/nous_hermes/request.rs          | 95 +++++++++++++++
 3 files changed, 119 insertions(+), 1 deletion(-)
 delete mode 100644 src/functions/pipelines/nous_hermes/parsers.rs
 create mode 100644 src/functions/pipelines/nous_hermes/request.rs

diff --git a/src/functions/pipelines/nous_hermes/mod.rs b/src/functions/pipelines/nous_hermes/mod.rs
index 6bf9d61..de21b03 100644
--- a/src/functions/pipelines/nous_hermes/mod.rs
+++ b/src/functions/pipelines/nous_hermes/mod.rs
@@ -1,2 +1,25 @@
pub mod prompts;
-pub mod parsers;
\ No newline at end of file
+pub mod request;
+
+pub use prompts::DEFAULT_SYSTEM_TEMPLATE;
+
+use serde_json::{json, Value};
+use std::collections::HashMap;
+use serde_json::Map;
+use std::sync::Arc;
+use crate::generation::functions::Tool;
+
+pub fn convert_to_openai_tool(tool: Arc<dyn Tool>) -> HashMap<String, Value> {
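+    // Wrap the tool's name/description/parameters in OpenAI's {"type": "function", "function": {...}} envelope.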
+    let mut function = HashMap::new();
+    function.insert("name".to_string(), Value::String(tool.name()));
+    function.insert("description".to_string(), Value::String(tool.description()));
+    function.insert("parameters".to_string(), tool.parameters());
+
+    let mut result = HashMap::new();
+    result.insert("type".to_string(), Value::String("function".to_string()));
+
+    let mapp: Map<String, Value> = function.into_iter().collect();
+    result.insert("function".to_string(), Value::Object(mapp));
+
+    result
+}
diff --git a/src/functions/pipelines/nous_hermes/parsers.rs b/src/functions/pipelines/nous_hermes/parsers.rs
deleted file mode 100644
index e69de29..0000000
diff --git a/src/functions/pipelines/nous_hermes/request.rs b/src/functions/pipelines/nous_hermes/request.rs
new file mode 100644
index 0000000..3e8500b
--- /dev/null
+++ b/src/functions/pipelines/nous_hermes/request.rs
@@ -0,0 +1,95 @@
+use serde_json::{json, Map, Value};
+use std::collections::HashMap;
+use std::sync::Arc;
+use async_trait::async_trait;
+use crate::generation::chat::{ChatMessage, ChatMessageResponse};
+use crate::generation::functions::pipelines::nous_hermes::DEFAULT_SYSTEM_TEMPLATE;
+use crate::generation::functions::tools::Tool;
+use crate::error::OllamaError;
+use crate::generation::functions::pipelines::RequestParserBase;
+use serde::{Deserialize, Serialize};
+
+pub fn convert_to_openai_tool(tool: Arc<dyn Tool>) -> Value {
+    let mut function = HashMap::new();
+    function.insert("name".to_string(), Value::String(tool.name()));
+    function.insert("description".to_string(), Value::String(tool.description()));
+    function.insert("parameters".to_string(), tool.parameters());
+
+    let mut result = HashMap::new();
+    result.insert("type".to_string(), Value::String("function".to_string()));
+
+    let mapp: Map<String, Value> = function.into_iter().collect();
+    result.insert("function".to_string(), Value::Object(mapp));
+
+    json!(result)
+}
+
+#[derive(Debug, Deserialize, Serialize)]
+pub struct NousFunctionCallSignature {
+    pub tool: String, // name of the tool
+    pub tool_input: Value,
+}
+
+pub struct NousFunctionCall {}
+
+impl NousFunctionCall {
+    pub async fn function_call_with_history(
+        &self,
+        model_name: String,
+        tool_params: Value,
+        tool: Arc<dyn Tool>,
+    ) -> Result<ChatMessageResponse, OllamaError> {
+        let result = tool.run(tool_params).await;
+        match result {
+            Ok(result) => Ok(ChatMessageResponse {
+                model: model_name.clone(),
+                created_at: "".to_string(),
+                message: Some(ChatMessage::assistant(result)),
+                done: true,
+                final_data: None,
+            }),
+            Err(e) => Err(OllamaError::from(e)),
+        }
+    }
+
+    pub fn format_tool_response(&self, function_response: &str) -> String {
+        format!("\n{}\n\n", function_response)
+    }
+}
+
+#[async_trait]
+impl RequestParserBase for NousFunctionCall {
+    async fn parse(
+        &self,
+        input: &str,
+        model_name: String,
+        tools: Vec<Arc<dyn Tool>>,
+    ) -> Result<ChatMessageResponse, OllamaError> {
+        let response_value: Result<NousFunctionCallSignature, serde_json::Error> = serde_json::from_str(input);
+        match response_value {
+            Ok(response) => {
+                if let Some(tool) = tools.iter().find(|t| t.name() == response.tool) {
+                    let tool_params = response.tool_input;
+                    let result = self
+                        .function_call_with_history(
+                            model_name.clone(),
+                            tool_params.clone(),
+                            tool.clone(),
+                        )
+                        .await?;
+                    return Ok(result);
+                } else {
+                    return Err(OllamaError::from("Tool not found".to_string()));
+                }
+            }
+            Err(e) => return Err(OllamaError::from(e)),
+        }
+    }
+
+    async fn get_system_message(&self, tools: &[Arc<dyn Tool>]) -> ChatMessage {
+        let tools_info: Vec<Value> = tools.iter().map(|tool| convert_to_openai_tool(tool.clone())).collect();
+        let tools_json = serde_json::to_string(&tools_info).unwrap();
+        let system_message_content = DEFAULT_SYSTEM_TEMPLATE.replace("{tools}", &tools_json);
+        ChatMessage::system(system_message_content)
+    }
+}
From 42be20580120211bcc00d13166f5de1f5f895b59 Mon Sep 17 00:00:00 2001
From: andthattoo
Date: Fri, 17 May 2024 19:16:01 +0300
Subject: [PATCH 05/18] fixed regex issue

---
 src/functions/pipelines/nous_hermes/mod.rs    | 23 +--------
 .../pipelines/nous_hermes/request.rs          | 51 ++++++++++++-------
 2 files changed, 33 insertions(+), 41 deletions(-)

diff --git a/src/functions/pipelines/nous_hermes/mod.rs b/src/functions/pipelines/nous_hermes/mod.rs
index de21b03..df6fa06 100644
--- a/src/functions/pipelines/nous_hermes/mod.rs
+++ b/src/functions/pipelines/nous_hermes/mod.rs
@@ -1,25 +1,4 @@
pub mod prompts;
pub mod request;

-pub use prompts::DEFAULT_SYSTEM_TEMPLATE;
-
-use serde_json::{json, Value};
-use std::collections::HashMap;
-use serde_json::Map;
-use std::sync::Arc;
-use crate::generation::functions::Tool;
-
-pub fn convert_to_openai_tool(tool: Arc<dyn Tool>) -> HashMap<String, Value> {
-    let mut function = HashMap::new();
-    function.insert("name".to_string(), Value::String(tool.name()));
-    function.insert("description".to_string(), Value::String(tool.description()));
-    function.insert("parameters".to_string(), tool.parameters());
-
-    let mut result = HashMap::new();
-    result.insert("type".to_string(), Value::String("function".to_string()));
-
-    let mapp: Map<String, Value> = function.into_iter().collect();
-    result.insert("function".to_string(), Value::Object(mapp));
-
-    result
-}
+pub use prompts::DEFAULT_SYSTEM_TEMPLATE;
\ No newline at end of file
diff --git a/src/functions/pipelines/nous_hermes/request.rs b/src/functions/pipelines/nous_hermes/request.rs
index 3e8500b..4e03d8b 100644
--- a/src/functions/pipelines/nous_hermes/request.rs
+++ b/src/functions/pipelines/nous_hermes/request.rs
@@ -8,6 +8,7 @@ use crate::generation::functions::tools::Tool;
use crate::error::OllamaError;
use crate::generation::functions::pipelines::RequestParserBase;
use serde::{Deserialize, Serialize};
+use regex::Regex;

pub fn convert_to_openai_tool(tool: Arc<dyn Tool>) -> Value {
    let mut function = HashMap::new();
@@ -26,8 +27,8 @@ pub fn convert_to_openai_tool(tool: Arc<dyn Tool>) -> Value {

#[derive(Debug, Deserialize, Serialize)]
pub struct NousFunctionCallSignature {
-    pub tool: String, // name of the tool
-    pub tool_input: Value,
+    pub name: String,
+    pub arguments: Value,
}

pub struct NousFunctionCall {}
@@ -53,7 +54,12 @@ impl NousFunctionCall {
    }

    pub fn format_tool_response(&self, function_response: &str) -> String {
-        format!("\n{}\n\n", function_response)
+        format!("\n{}\n\n", function_response)
+    }
+
+    pub fn extract_tool_response(&self, content: &str) -> Option<String> {
+        let re = Regex::new(r"(?s)<tool_call>(.*?)</tool_call>").unwrap();
+        re.captures(content).and_then(|cap| cap.get(1).map(|m| m.as_str().to_string()))
    }
}

@@ -65,24 +71,31 @@ impl RequestParserBase for NousFunctionCall {
        model_name: String,
        tools: Vec<Arc<dyn Tool>>,
    ) -> Result<ChatMessageResponse, OllamaError> {
-        let response_value: Result<NousFunctionCallSignature, serde_json::Error> = serde_json::from_str(input);
-        match response_value {
-            Ok(response) => {
-                if let Some(tool) = tools.iter().find(|t| t.name() == response.tool) {
-                    let tool_params = response.tool_input;
-                    let result = self
-                        .function_call_with_history(
-                            model_name.clone(),
-                            tool_params.clone(),
-                            tool.clone(),
-                        )
-                        .await?;
-                    return Ok(result);
-                } else {
-                    return Err(OllamaError::from("Tool not found".to_string()));
+        //Extract between <tool_call> and </tool_call>
+        let tool_response = self.extract_tool_response(input);
+        match tool_response {
+            Some(tool_response_str) => {
+                let response_value: Result<NousFunctionCallSignature, serde_json::Error> = serde_json::from_str(&tool_response_str);
+                match response_value {
+                    Ok(response) => {
+                        if let Some(tool) = tools.iter().find(|t| t.name() == response.name) {
+                            let tool_params = response.arguments;
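+                            // Run the matched tool with the arguments parsed from the model's <tool_call> block.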
+                            let result = self
+                                .function_call_with_history(
+                                    model_name.clone(),
+                                    tool_params.clone(),
+                                    tool.clone(),
+                                )
+                                .await?;
+                            return Ok(result);
+                        } else {
+                            return Err(OllamaError::from("Tool not found".to_string()));
+                        }
+                    }
+                    Err(e) => return Err(OllamaError::from(e)),
                }
            }
-            Err(e) => return Err(OllamaError::from(e)),
+            None => return Err(OllamaError::from("Tool response not found".to_string())),
        }
    }
From 3afff90c71dbefabcacfef7d573d47e4f911401b Mon Sep 17 00:00:00 2001
From: andthattoo
Date: Fri, 17 May 2024 19:17:35 +0300
Subject: [PATCH 06/18] function call added

---
 tests/function_call.rs | 60 ++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 60 insertions(+)
 create mode 100644 tests/function_call.rs

diff --git a/tests/function_call.rs b/tests/function_call.rs
new file mode 100644
index 0000000..fe0f709
--- /dev/null
+++ b/tests/function_call.rs
@@ -0,0 +1,60 @@
+use ollama_rs::{
+    generation::functions::tools::{Scraper, DDGSearcher},
+    generation::functions::{FunctionCallRequest, OpenAIFunctionCall},
+    generation::chat::ChatMessage,
+    Ollama,
+};
+use tokio::io::{stdout, AsyncWriteExt};
+use std::sync::Arc;
+use log::info;
+use env_logger;
+use std::env;
+use ollama_rs::generation::functions::pipelines::nous_hermes::request::NousFunctionCall;
+
+#[tokio::main]
+async fn main() -> Result<(), Box<dyn std::error::Error>> {
+    let mut ollama = Ollama::new_default_with_history(30);
+    let scraper_tool = Arc::new(Scraper {});
+    let ddg_search_tool = Arc::new(DDGSearcher::new());
+    //adrienbrault/nous-hermes2pro:Q8_0 "openhermes:latest"
+    let mut stdout = stdout();
+
+    env::set_var("RUST_LOG", "info");
+    env_logger::init();
+
+    loop {
+        stdout.write_all(b"\n> ").await?;
+        stdout.flush().await?;
+
+        let mut input = String::new();
+        std::io::stdin().read_line(&mut input)?;
+
+        let input = input.trim_end();
+        if input.eq_ignore_ascii_case("exit") {
+            break;
+        }
+
+        let user_message = ChatMessage::user(input.to_string());
+
+
+        let parser = Arc::new(NousFunctionCall {});
+        let result = ollama.send_function_call(
+            FunctionCallRequest::new(
+                "adrienbrault/nous-hermes2pro:Q8_0".to_string(),
+                vec![scraper_tool.clone(), ddg_search_tool.clone()],
+                vec![user_message.clone()]
+            ),
+            parser.clone()).await?;
+
+        if let Some(message) = result.message {
+            stdout.write_all(message.content.as_bytes()).await?;
+        }
+
+        stdout.flush().await?;
+    }
+
+    // Display whole history of messages
+    dbg!(&ollama.get_messages_history("default".to_string()));
+
+    Ok(())
+}
From d0f109f2c8937f6da75d272ee40b7ef8a09932a5 Mon Sep 17 00:00:00 2001
From: andthattoo
Date: Sat, 18 May 2024 01:07:44 +0300
Subject: [PATCH 07/18] successfully builds for cargo build --features
 "function-calling"

---
 Cargo.lock                                    |  1 +
 Cargo.toml                                    |  3 +-
 src/functions/tools/weather.rs                | 33 -------------------
 src/generation.rs                             |  1 +
 src/{ => generation}/functions/mod.rs         |  0
 .../functions/pipelines/mod.rs                |  0
 .../functions/pipelines/nous_hermes/mod.rs    |  0
 .../pipelines/nous_hermes/prompts.rs          |  0
 .../pipelines/nous_hermes/request.rs          |  1 +
 .../functions/pipelines/openai/mod.rs         |  0
 .../functions/pipelines/openai/prompts.rs     |  0
 .../functions/pipelines/openai/request.rs     |  0
 src/{ => generation}/functions/request.rs     |  0
 src/{ => generation}/functions/tools/mod.rs   |  0
 .../functions/tools/scraper.rs                |  0
 .../functions/tools/search_ddg.rs             |  0
 tests/function_call.rs                        |  2 +-
 17 files changed, 6 insertions(+), 35 deletions(-)
 delete mode 100644 src/functions/tools/weather.rs
 rename src/{ => generation}/functions/mod.rs (100%)
 rename src/{ => generation}/functions/pipelines/mod.rs (100%)
 rename src/{ =>
generation}/functions/pipelines/nous_hermes/mod.rs (100%) rename src/{ => generation}/functions/pipelines/nous_hermes/prompts.rs (100%) rename src/{ => generation}/functions/pipelines/nous_hermes/request.rs (99%) rename src/{ => generation}/functions/pipelines/openai/mod.rs (100%) rename src/{ => generation}/functions/pipelines/openai/prompts.rs (100%) rename src/{ => generation}/functions/pipelines/openai/request.rs (100%) rename src/{ => generation}/functions/request.rs (100%) rename src/{ => generation}/functions/tools/mod.rs (100%) rename src/{ => generation}/functions/tools/scraper.rs (100%) rename src/{ => generation}/functions/tools/search_ddg.rs (100%) diff --git a/Cargo.lock b/Cargo.lock index 49ff6cf..dbc31fd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -689,6 +689,7 @@ dependencies = [ "base64", "log", "ollama-rs", + "regex", "reqwest", "scraper", "serde", diff --git a/Cargo.toml b/Cargo.toml index 078c31c..03d930e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,13 +20,14 @@ url = "2" log = "0.4" # Add this line scraper = {version = "0.19.0" , optional = true } # Add scraper dependency text-splitter = {version = "0.13.1", optional = true } # Add text_splitter dependency +regex = {version = "1.9.3", optional = true } # Add regex dependency [features] default = ["reqwest/default-tls"] stream = ["tokio-stream", "reqwest/stream", "tokio"] rustls = ["reqwest/rustls-tls"] chat-history = [] -function-calling = ["scraper", "text-splitter", "chat-history"] +function-calling = ["scraper", "text-splitter", "regex", "chat-history"] [dev-dependencies] tokio = { version = "1", features = ["full"] } diff --git a/src/functions/tools/weather.rs b/src/functions/tools/weather.rs deleted file mode 100644 index 2671677..0000000 --- a/src/functions/tools/weather.rs +++ /dev/null @@ -1,33 +0,0 @@ -use async_trait::async_trait; -use serde_json::{json, Value}; -use std::error::Error; -use crate::generation::functions::tools::Tool; - -pub struct WeatherTool; - -#[async_trait] -impl Tool for WeatherTool { - fn name(&self) -> String { - "WeatherTool".to_string() - } - - fn description(&self) -> String { - "Get the current weather in a given location.".to_string() - } - - async fn run(&self, input: Value) -> Result> { - let location = input.as_str().ok_or("Input should be a string")?; - let unit = "fahrenheit"; // Default unit - let result = if location.to_lowercase().contains("tokyo") { - json!({"location": "Tokyo", "temperature": "10", "unit": unit}) - } else if location.to_lowercase().contains("san francisco") { - json!({"location": "San Francisco", "temperature": "72", "unit": unit}) - } else if location.to_lowercase().contains("paris") { - json!({"location": "Paris", "temperature": "22", "unit": unit}) - } else { - json!({"location": location, "temperature": "unknown"}) - }; - - Ok(result.to_string()) - } -} diff --git a/src/generation.rs b/src/generation.rs index b47ec20..9b8f682 100644 --- a/src/generation.rs +++ b/src/generation.rs @@ -4,4 +4,5 @@ pub mod embeddings; pub mod images; pub mod options; pub mod parameters; +#[cfg(feature = "function-calling")] pub mod functions; diff --git a/src/functions/mod.rs b/src/generation/functions/mod.rs similarity index 100% rename from src/functions/mod.rs rename to src/generation/functions/mod.rs diff --git a/src/functions/pipelines/mod.rs b/src/generation/functions/pipelines/mod.rs similarity index 100% rename from src/functions/pipelines/mod.rs rename to src/generation/functions/pipelines/mod.rs diff --git a/src/functions/pipelines/nous_hermes/mod.rs 
b/src/generation/functions/pipelines/nous_hermes/mod.rs similarity index 100% rename from src/functions/pipelines/nous_hermes/mod.rs rename to src/generation/functions/pipelines/nous_hermes/mod.rs diff --git a/src/functions/pipelines/nous_hermes/prompts.rs b/src/generation/functions/pipelines/nous_hermes/prompts.rs similarity index 100% rename from src/functions/pipelines/nous_hermes/prompts.rs rename to src/generation/functions/pipelines/nous_hermes/prompts.rs diff --git a/src/functions/pipelines/nous_hermes/request.rs b/src/generation/functions/pipelines/nous_hermes/request.rs similarity index 99% rename from src/functions/pipelines/nous_hermes/request.rs rename to src/generation/functions/pipelines/nous_hermes/request.rs index 4e03d8b..e1fb84c 100644 --- a/src/functions/pipelines/nous_hermes/request.rs +++ b/src/generation/functions/pipelines/nous_hermes/request.rs @@ -10,6 +10,7 @@ use crate::generation::functions::pipelines::RequestParserBase; use serde::{Deserialize, Serialize}; use regex::Regex; + pub fn convert_to_openai_tool(tool: Arc) -> Value { let mut function = HashMap::new(); function.insert("name".to_string(), Value::String(tool.name())); diff --git a/src/functions/pipelines/openai/mod.rs b/src/generation/functions/pipelines/openai/mod.rs similarity index 100% rename from src/functions/pipelines/openai/mod.rs rename to src/generation/functions/pipelines/openai/mod.rs diff --git a/src/functions/pipelines/openai/prompts.rs b/src/generation/functions/pipelines/openai/prompts.rs similarity index 100% rename from src/functions/pipelines/openai/prompts.rs rename to src/generation/functions/pipelines/openai/prompts.rs diff --git a/src/functions/pipelines/openai/request.rs b/src/generation/functions/pipelines/openai/request.rs similarity index 100% rename from src/functions/pipelines/openai/request.rs rename to src/generation/functions/pipelines/openai/request.rs diff --git a/src/functions/request.rs b/src/generation/functions/request.rs similarity index 100% rename from src/functions/request.rs rename to src/generation/functions/request.rs diff --git a/src/functions/tools/mod.rs b/src/generation/functions/tools/mod.rs similarity index 100% rename from src/functions/tools/mod.rs rename to src/generation/functions/tools/mod.rs diff --git a/src/functions/tools/scraper.rs b/src/generation/functions/tools/scraper.rs similarity index 100% rename from src/functions/tools/scraper.rs rename to src/generation/functions/tools/scraper.rs diff --git a/src/functions/tools/search_ddg.rs b/src/generation/functions/tools/search_ddg.rs similarity index 100% rename from src/functions/tools/search_ddg.rs rename to src/generation/functions/tools/search_ddg.rs diff --git a/tests/function_call.rs b/tests/function_call.rs index fe0f709..253fbcb 100644 --- a/tests/function_call.rs +++ b/tests/function_call.rs @@ -11,7 +11,7 @@ use env_logger; use std::env; use ollama_rs::generation::functions::pipelines::nous_hermes::request::NousFunctionCall; -#[tokio::main] +#[tokio::test] async fn main() -> Result<(), Box> { let mut ollama = Ollama::new_default_with_history(30); let scraper_tool = Arc::new(Scraper {}); From 4851eb3e0d35813a4a84ae54305abf261d46be6f Mon Sep 17 00:00:00 2001 From: andthattoo Date: Sat, 18 May 2024 01:16:07 +0300 Subject: [PATCH 08/18] removed env-logger, add it with verbose option --- src/generation/functions/mod.rs | 4 +++- src/generation/functions/pipelines/nous_hermes/request.rs | 1 - tests/function_call.rs | 6 ------ 3 files changed, 3 insertions(+), 8 deletions(-) diff --git 
a/src/generation/functions/mod.rs b/src/generation/functions/mod.rs index df5e669..b9c4ace 100644 --- a/src/generation/functions/mod.rs +++ b/src/generation/functions/mod.rs @@ -14,7 +14,6 @@ use crate::error::OllamaError; use std::sync::Arc; use crate::generation::functions::pipelines::RequestParserBase; - #[cfg(feature = "function-calling")] impl crate::Ollama { @@ -44,7 +43,9 @@ impl crate::Ollama { "default".to_string(), ).await?; + let response_content: String = result.message.clone().unwrap().content; + let result = parser.parse(&response_content, request.chat.model_name.clone(), request.tools).await?; return Ok(result); } @@ -67,6 +68,7 @@ impl crate::Ollama { } let result = self.send_chat_messages(request.chat).await?; let response_content: String = result.message.clone().unwrap().content; + let result = parser.parse(&response_content, model_name, request.tools).await?; return Ok(result); } diff --git a/src/generation/functions/pipelines/nous_hermes/request.rs b/src/generation/functions/pipelines/nous_hermes/request.rs index e1fb84c..4e03d8b 100644 --- a/src/generation/functions/pipelines/nous_hermes/request.rs +++ b/src/generation/functions/pipelines/nous_hermes/request.rs @@ -10,7 +10,6 @@ use crate::generation::functions::pipelines::RequestParserBase; use serde::{Deserialize, Serialize}; use regex::Regex; - pub fn convert_to_openai_tool(tool: Arc) -> Value { let mut function = HashMap::new(); function.insert("name".to_string(), Value::String(tool.name())); diff --git a/tests/function_call.rs b/tests/function_call.rs index 253fbcb..792f952 100644 --- a/tests/function_call.rs +++ b/tests/function_call.rs @@ -6,9 +6,6 @@ use ollama_rs::{ }; use tokio::io::{stdout, AsyncWriteExt}; use std::sync::Arc; -use log::info; -use env_logger; -use std::env; use ollama_rs::generation::functions::pipelines::nous_hermes::request::NousFunctionCall; #[tokio::test] @@ -19,9 +16,6 @@ async fn main() -> Result<(), Box> { //adrienbrault/nous-hermes2pro:Q8_0 "openhermes:latest" let mut stdout = stdout(); - env::set_var("RUST_LOG", "info"); - env_logger::init(); - loop { stdout.write_all(b"\n> ").await?; stdout.flush().await?; From eb4b6f9e219ad6ee18e8c00288fa06da9acdc6aa Mon Sep 17 00:00:00 2001 From: andthattoo Date: Sat, 18 May 2024 01:17:59 +0300 Subject: [PATCH 09/18] cargo fmt commit --- src/generation.rs | 4 +- src/generation/functions/mod.rs | 57 +++++++++++------ src/generation/functions/pipelines/mod.rs | 12 ++-- .../functions/pipelines/nous_hermes/mod.rs | 2 +- .../pipelines/nous_hermes/prompts.rs | 2 +- .../pipelines/nous_hermes/request.rs | 25 +++++--- .../functions/pipelines/openai/mod.rs | 2 +- .../functions/pipelines/openai/request.rs | 63 +++++++++++-------- src/generation/functions/request.rs | 13 ++-- src/generation/functions/tools/mod.rs | 4 +- src/generation/functions/tools/scraper.rs | 13 ++-- src/generation/functions/tools/search_ddg.rs | 54 +++++++++++----- tests/function_call.rs | 26 ++++---- 13 files changed, 165 insertions(+), 112 deletions(-) diff --git a/src/generation.rs b/src/generation.rs index 9b8f682..a29c4ae 100644 --- a/src/generation.rs +++ b/src/generation.rs @@ -1,8 +1,8 @@ pub mod chat; pub mod completion; pub mod embeddings; +#[cfg(feature = "function-calling")] +pub mod functions; pub mod images; pub mod options; pub mod parameters; -#[cfg(feature = "function-calling")] -pub mod functions; diff --git a/src/generation/functions/mod.rs b/src/generation/functions/mod.rs index b9c4ace..525c838 100644 --- a/src/generation/functions/mod.rs +++ 
b/src/generation/functions/mod.rs @@ -1,25 +1,28 @@ -pub mod tools; pub mod pipelines; pub mod request; +pub mod tools; -pub use tools::Scraper; -pub use tools::DDGSearcher; -pub use crate::generation::functions::request::FunctionCallRequest; pub use crate::generation::functions::pipelines::openai::request::OpenAIFunctionCall; +pub use crate::generation::functions::request::FunctionCallRequest; +pub use tools::DDGSearcher; +pub use tools::Scraper; +use crate::error::OllamaError; +use crate::generation::chat::request::ChatMessageRequest; use crate::generation::chat::{ChatMessage, ChatMessageResponse}; -use crate::generation::chat::request::{ChatMessageRequest}; +use crate::generation::functions::pipelines::RequestParserBase; use crate::generation::functions::tools::Tool; -use crate::error::OllamaError; use std::sync::Arc; -use crate::generation::functions::pipelines::RequestParserBase; #[cfg(feature = "function-calling")] impl crate::Ollama { - - pub async fn check_system_message(&self, messages: &Vec, system_prompt: &str) -> bool { + pub async fn check_system_message( + &self, + messages: &Vec, + system_prompt: &str, + ) -> bool { let system_message = messages.first().unwrap().clone(); - return system_message.content == system_prompt + return system_message.content == system_prompt; } #[cfg(feature = "chat-history")] @@ -28,29 +31,38 @@ impl crate::Ollama { request: FunctionCallRequest, parser: Arc, ) -> Result { - let system_prompt = parser.get_system_message(&request.tools).await; - if request.chat.messages.len() == 0{ // If there are no messages in the chat, add a system prompt + if request.chat.messages.len() == 0 { + // If there are no messages in the chat, add a system prompt self.send_chat_messages_with_history( - ChatMessageRequest::new(request.chat.model_name.clone(), vec![system_prompt.clone()]), + ChatMessageRequest::new( + request.chat.model_name.clone(), + vec![system_prompt.clone()], + ), "default".to_string(), - ).await?; + ) + .await?; } let result = self .send_chat_messages_with_history( ChatMessageRequest::new(request.chat.model_name.clone(), request.chat.messages), "default".to_string(), - ).await?; - + ) + .await?; let response_content: String = result.message.clone().unwrap().content; - let result = parser.parse(&response_content, request.chat.model_name.clone(), request.tools).await?; + let result = parser + .parse( + &response_content, + request.chat.model_name.clone(), + request.tools, + ) + .await?; return Ok(result); } - pub async fn send_function_call( &self, request: FunctionCallRequest, @@ -63,13 +75,18 @@ impl crate::Ollama { let model_name = request.chat.model_name.clone(); //Make sure the first message in chat is the system prompt - if !self.check_system_message(&request.chat.messages, &system_prompt.content).await { + if !self + .check_system_message(&request.chat.messages, &system_prompt.content) + .await + { request.chat.messages.insert(0, system_prompt); } let result = self.send_chat_messages(request.chat).await?; let response_content: String = result.message.clone().unwrap().content; - let result = parser.parse(&response_content, model_name, request.tools).await?; + let result = parser + .parse(&response_content, model_name, request.tools) + .await?; return Ok(result); } } diff --git a/src/generation/functions/pipelines/mod.rs b/src/generation/functions/pipelines/mod.rs index ab6129e..ee758b2 100644 --- a/src/generation/functions/pipelines/mod.rs +++ b/src/generation/functions/pipelines/mod.rs @@ -1,15 +1,19 @@ -use 
crate::generation::chat::{ChatMessage, ChatMessageResponse}; use crate::error::OllamaError; +use crate::generation::chat::{ChatMessage, ChatMessageResponse}; use crate::generation::functions::tools::Tool; use async_trait::async_trait; use std::sync::Arc; -pub mod openai; pub mod nous_hermes; - +pub mod openai; #[async_trait] pub trait RequestParserBase { - async fn parse(&self, input: &str, model_name:String, tools: Vec>) -> Result; + async fn parse( + &self, + input: &str, + model_name: String, + tools: Vec>, + ) -> Result; async fn get_system_message(&self, tools: &[Arc]) -> ChatMessage; } diff --git a/src/generation/functions/pipelines/nous_hermes/mod.rs b/src/generation/functions/pipelines/nous_hermes/mod.rs index df6fa06..23c05af 100644 --- a/src/generation/functions/pipelines/nous_hermes/mod.rs +++ b/src/generation/functions/pipelines/nous_hermes/mod.rs @@ -1,4 +1,4 @@ pub mod prompts; pub mod request; -pub use prompts::DEFAULT_SYSTEM_TEMPLATE; \ No newline at end of file +pub use prompts::DEFAULT_SYSTEM_TEMPLATE; diff --git a/src/generation/functions/pipelines/nous_hermes/prompts.rs b/src/generation/functions/pipelines/nous_hermes/prompts.rs index f237a7b..f1242ff 100644 --- a/src/generation/functions/pipelines/nous_hermes/prompts.rs +++ b/src/generation/functions/pipelines/nous_hermes/prompts.rs @@ -37,4 +37,4 @@ Instructions: | {{"arguments": , "name": }} -"#; \ No newline at end of file +"#; diff --git a/src/generation/functions/pipelines/nous_hermes/request.rs b/src/generation/functions/pipelines/nous_hermes/request.rs index 4e03d8b..e49df6d 100644 --- a/src/generation/functions/pipelines/nous_hermes/request.rs +++ b/src/generation/functions/pipelines/nous_hermes/request.rs @@ -1,14 +1,14 @@ -use serde_json::{json, Map, Value}; -use std::collections::HashMap; -use std::sync::Arc; -use async_trait::async_trait; +use crate::error::OllamaError; use crate::generation::chat::{ChatMessage, ChatMessageResponse}; use crate::generation::functions::pipelines::nous_hermes::DEFAULT_SYSTEM_TEMPLATE; -use crate::generation::functions::tools::Tool; -use crate::error::OllamaError; use crate::generation::functions::pipelines::RequestParserBase; -use serde::{Deserialize, Serialize}; +use crate::generation::functions::tools::Tool; +use async_trait::async_trait; use regex::Regex; +use serde::{Deserialize, Serialize}; +use serde_json::{json, Map, Value}; +use std::collections::HashMap; +use std::sync::Arc; pub fn convert_to_openai_tool(tool: Arc) -> Value { let mut function = HashMap::new(); @@ -59,7 +59,8 @@ impl NousFunctionCall { pub fn extract_tool_response(&self, content: &str) -> Option { let re = Regex::new(r"(?s)(.*?)").unwrap(); - re.captures(content).and_then(|cap| cap.get(1).map(|m| m.as_str().to_string())) + re.captures(content) + .and_then(|cap| cap.get(1).map(|m| m.as_str().to_string())) } } @@ -75,7 +76,8 @@ impl RequestParserBase for NousFunctionCall { let tool_response = self.extract_tool_response(input); match tool_response { Some(tool_response_str) => { - let response_value: Result = serde_json::from_str(&tool_response_str); + let response_value: Result = + serde_json::from_str(&tool_response_str); match response_value { Ok(response) => { if let Some(tool) = tools.iter().find(|t| t.name() == response.name) { @@ -100,7 +102,10 @@ impl RequestParserBase for NousFunctionCall { } async fn get_system_message(&self, tools: &[Arc]) -> ChatMessage { - let tools_info: Vec = tools.iter().map(|tool| convert_to_openai_tool(tool.clone())).collect(); + let tools_info: Vec = tools + .iter() + 
.map(|tool| convert_to_openai_tool(tool.clone())) + .collect(); let tools_json = serde_json::to_string(&tools_info).unwrap(); let system_message_content = DEFAULT_SYSTEM_TEMPLATE.replace("{tools}", &tools_json); ChatMessage::system(system_message_content) diff --git a/src/generation/functions/pipelines/openai/mod.rs b/src/generation/functions/pipelines/openai/mod.rs index bc9e7a4..1c37516 100644 --- a/src/generation/functions/pipelines/openai/mod.rs +++ b/src/generation/functions/pipelines/openai/mod.rs @@ -1,4 +1,4 @@ pub mod prompts; pub mod request; -pub use prompts::{DEFAULT_SYSTEM_TEMPLATE ,DEFAULT_RESPONSE_FUNCTION}; \ No newline at end of file +pub use prompts::{DEFAULT_RESPONSE_FUNCTION, DEFAULT_SYSTEM_TEMPLATE}; diff --git a/src/generation/functions/pipelines/openai/request.rs b/src/generation/functions/pipelines/openai/request.rs index 6b35615..bb3bbe3 100644 --- a/src/generation/functions/pipelines/openai/request.rs +++ b/src/generation/functions/pipelines/openai/request.rs @@ -1,13 +1,13 @@ -use serde_json::Value; -use std::sync::Arc; +use crate::error::OllamaError; use crate::generation::chat::{ChatMessage, ChatMessageResponse}; use crate::generation::functions::pipelines::openai::DEFAULT_SYSTEM_TEMPLATE; -use crate::generation::functions::tools::Tool; -use crate::error::OllamaError; use crate::generation::functions::pipelines::RequestParserBase; -use serde_json::json; -use serde::{Deserialize, Serialize}; +use crate::generation::functions::tools::Tool; use async_trait::async_trait; +use serde::{Deserialize, Serialize}; +use serde_json::json; +use serde_json::Value; +use std::sync::Arc; pub fn convert_to_ollama_tool(tool: &Arc) -> Value { let schema = tool.parameters(); @@ -27,55 +27,64 @@ pub struct OpenAIFunctionCallSignature { pub struct OpenAIFunctionCall {} impl OpenAIFunctionCall { - pub async fn function_call_with_history( &self, model_name: String, tool_params: Value, tool: Arc, ) -> Result { - let result = tool.run(tool_params).await; return match result { - Ok(result) => { - Ok(ChatMessageResponse { - model: model_name.clone(), - created_at: "".to_string(), - message: Some(ChatMessage::assistant(result.to_string())), - done: true, - final_data: None, - }) - }, - Err(e) => Err(OllamaError::from(e)) + Ok(result) => Ok(ChatMessageResponse { + model: model_name.clone(), + created_at: "".to_string(), + message: Some(ChatMessage::assistant(result.to_string())), + done: true, + final_data: None, + }), + Err(e) => Err(OllamaError::from(e)), }; } } #[async_trait] impl RequestParserBase for OpenAIFunctionCall { - async fn parse(&self, input: &str, model_name: String, tools: Vec>) -> Result { - let response_value: Result = serde_json::from_str(input); + async fn parse( + &self, + input: &str, + model_name: String, + tools: Vec>, + ) -> Result { + let response_value: Result = + serde_json::from_str(input); match response_value { Ok(response) => { if let Some(tool) = tools.iter().find(|t| t.name() == response.tool) { let tool_params = response.tool_input; - let result = self.function_call_with_history(model_name.clone(), - tool_params.clone(), - tool.clone(), - ).await?; + let result = self + .function_call_with_history( + model_name.clone(), + tool_params.clone(), + tool.clone(), + ) + .await?; return Ok(result); } else { return Err(OllamaError::from("Tool not found".to_string())); } - }, + } Err(e) => { return Err(OllamaError::from(e)); } } } - async fn get_system_message(&self, tools: &[Arc]) -> ChatMessage { // Corrected here to use a slice - let tools_info: Vec = 
tools.iter().map(|tool| convert_to_ollama_tool(tool)).collect(); + async fn get_system_message(&self, tools: &[Arc]) -> ChatMessage { + // Corrected here to use a slice + let tools_info: Vec = tools + .iter() + .map(|tool| convert_to_ollama_tool(tool)) + .collect(); let tools_json = serde_json::to_string(&tools_info).unwrap(); let system_message_content = DEFAULT_SYSTEM_TEMPLATE.replace("{tools}", &tools_json); ChatMessage::system(system_message_content) diff --git a/src/generation/functions/request.rs b/src/generation/functions/request.rs index 11fd439..44da42e 100644 --- a/src/generation/functions/request.rs +++ b/src/generation/functions/request.rs @@ -1,22 +1,19 @@ use crate::generation::chat::request::ChatMessageRequest; -use std::sync::Arc; use crate::generation::chat::ChatMessage; -use crate::generation::{options::GenerationOptions, parameters::FormatType}; use crate::generation::functions::Tool; +use crate::generation::{options::GenerationOptions, parameters::FormatType}; +use std::sync::Arc; #[derive(Clone)] pub struct FunctionCallRequest { pub chat: ChatMessageRequest, - pub tools: Vec> + pub tools: Vec>, } impl FunctionCallRequest { pub fn new(model_name: String, tools: Vec>, messages: Vec) -> Self { let chat = ChatMessageRequest::new(model_name, messages); - Self { - chat, - tools - } + Self { chat, tools } } /// Additional model parameters listed in the documentation for the Modelfile @@ -36,4 +33,4 @@ impl FunctionCallRequest { self.chat.format = Some(format); self } -} \ No newline at end of file +} diff --git a/src/generation/functions/tools/mod.rs b/src/generation/functions/tools/mod.rs index 6b29b62..2f54814 100644 --- a/src/generation/functions/tools/mod.rs +++ b/src/generation/functions/tools/mod.rs @@ -1,5 +1,5 @@ -pub mod search_ddg; pub mod scraper; +pub mod search_ddg; pub use self::scraper::Scraper; pub use self::search_ddg::DDGSearcher; @@ -54,4 +54,4 @@ pub trait Tool: Send + Sync { Err(_) => Value::String(input.to_string()), } } -} \ No newline at end of file +} diff --git a/src/generation/functions/tools/scraper.rs b/src/generation/functions/tools/scraper.rs index 2e8b4c0..6e15eec 100644 --- a/src/generation/functions/tools/scraper.rs +++ b/src/generation/functions/tools/scraper.rs @@ -3,14 +3,13 @@ use scraper::{Html, Selector}; use std::env; use text_splitter::TextSplitter; -use std::error::Error; -use serde_json::{Value, json}; use crate::generation::functions::tools::Tool; use async_trait::async_trait; +use serde_json::{json, Value}; +use std::error::Error; pub struct Scraper {} - #[async_trait] impl Tool for Scraper { fn name(&self) -> String { @@ -36,11 +35,12 @@ impl Tool for Scraper { async fn run(&self, input: Value) -> Result> { let website = input["website"].as_str().ok_or("Website URL is required")?; - let browserless_token = env::var("BROWSERLESS_TOKEN").expect("BROWSERLESS_TOKEN must be set"); + let browserless_token = + env::var("BROWSERLESS_TOKEN").expect("BROWSERLESS_TOKEN must be set"); let url = format!("http://0.0.0.0:3000/content?token={}", browserless_token); let payload = json!({ - "url": website - }); + "url": website + }); let client = Client::new(); let response = client .post(&url) @@ -66,4 +66,3 @@ impl Tool for Scraper { Ok(sentences) } } - diff --git a/src/generation/functions/tools/search_ddg.rs b/src/generation/functions/tools/search_ddg.rs index 67ddfae..ee56c77 100644 --- a/src/generation/functions/tools/search_ddg.rs +++ b/src/generation/functions/tools/search_ddg.rs @@ -4,10 +4,9 @@ use scraper::{Html, Selector}; use 
std::error::Error;
use crate::generation::functions::tools::Tool;
-use serde::{Deserialize, Serialize};
-use serde_json::{Value, json};
use async_trait::async_trait;
-
+use serde::{Deserialize, Serialize};
+use serde_json::{json, Value};

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResult {
@@ -40,19 +39,40 @@ impl DDGSearcher {
        let result_url_selector = Selector::parse(".result__url").unwrap();
        let result_snippet_selector = Selector::parse(".result__snippet").unwrap();

-        let results = document.select(&result_selector).map(|result| {
-
-            let title = result.select(&result_title_selector).next().unwrap().text().collect::<Vec<_>>().join("");
-            let link = result.select(&result_url_selector).next().unwrap().text().collect::<Vec<_>>().join("").trim().to_string();
-            let snippet = result.select(&result_snippet_selector).next().unwrap().text().collect::<Vec<_>>().join("");
-
-            SearchResult {
-                title,
-                link,
-                //url: String::from(url.value().attr("href").unwrap()),
-                snippet,
-            }
-        }).collect::<Vec<SearchResult>>();
+        let results = document
+            .select(&result_selector)
+            .map(|result| {
+                let title = result
+                    .select(&result_title_selector)
+                    .next()
+                    .unwrap()
+                    .text()
+                    .collect::<Vec<_>>()
+                    .join("");
+                let link = result
+                    .select(&result_url_selector)
+                    .next()
+                    .unwrap()
+                    .text()
+                    .collect::<Vec<_>>()
+                    .join("")
+                    .trim()
+                    .to_string();
+                let snippet = result
+                    .select(&result_snippet_selector)
+                    .next()
+                    .unwrap()
+                    .text()
+                    .collect::<Vec<_>>()
+                    .join("");
+
+                SearchResult {
+                    title,
+                    link,
+                    snippet,
+                }
+            })
+            .collect::<Vec<SearchResult>>();

        Ok(results)
    }
@@ -97,4 +117,4 @@ impl Tool for DDGSearcher {
    async fn parse_input(&self, input: &str) -> Value {
        Tool::parse_input(self, input).await
    }
-}
\ No newline at end of file
+}
diff --git a/tests/function_call.rs b/tests/function_call.rs
index 792f952..1a67298 100644
--- a/tests/function_call.rs
+++ b/tests/function_call.rs
@@ -1,12 +1,12 @@
+use ollama_rs::generation::functions::pipelines::nous_hermes::request::NousFunctionCall;
use ollama_rs::{
-    generation::functions::tools::{Scraper, DDGSearcher},
-    generation::functions::{FunctionCallRequest, OpenAIFunctionCall},
    generation::chat::ChatMessage,
+    generation::functions::tools::{DDGSearcher, Scraper},
+    generation::functions::{FunctionCallRequest, OpenAIFunctionCall},
    Ollama,
};
-use tokio::io::{stdout, AsyncWriteExt};
use std::sync::Arc;
-use ollama_rs::generation::functions::pipelines::nous_hermes::request::NousFunctionCall;
+use tokio::io::{stdout, AsyncWriteExt};

#[tokio::test]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
@@ -30,15 +30,17 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {

        let user_message = ChatMessage::user(input.to_string());

-
        let parser = Arc::new(NousFunctionCall {});
-        let result = ollama.send_function_call(
-            FunctionCallRequest::new(
-                "adrienbrault/nous-hermes2pro:Q8_0".to_string(),
-                vec![scraper_tool.clone(), ddg_search_tool.clone()],
-                vec![user_message.clone()]
-            ),
-            parser.clone()).await?;
+        let result = ollama
+            .send_function_call(
+                FunctionCallRequest::new(
+                    "adrienbrault/nous-hermes2pro:Q8_0".to_string(),
+                    vec![scraper_tool.clone(), ddg_search_tool.clone()],
+                    vec![user_message.clone()],
+                ),
+                parser.clone(),
+            )
+            .await?;

        if let Some(message) = result.message {
            stdout.write_all(message.content.as_bytes()).await?;
From 846cb513869903a5a78891ee52e6ea45bd30a6c3 Mon Sep 17 00:00:00 2001
From: andthattoo
Date: Sat, 18 May 2024 13:39:28 +0300
Subject: [PATCH 10/18] added fallback and error handling. Function calling
 structs now return `ChatMessageResponse` if it's a tool error.
That being:
- Error parsing tool call
- No such tool
- Validation error for call format (parameter errors etc.)

---
 src/generation/chat/mod.rs                    | 13 ++-
 src/generation/functions/mod.rs               | 79 +++++++++----------
 src/generation/functions/pipelines/mod.rs     | 13 ++-
 .../pipelines/nous_hermes/request.rs          | 71 ++++++++++-------
 .../functions/pipelines/openai/request.rs     | 79 ++++++++++---------
 src/generation/functions/request.rs           | 13 +--
 src/generation/functions/tools/mod.rs         |  4 +-
 src/generation/functions/tools/scraper.rs     | 35 +++-----
 src/generation/functions/tools/search_ddg.rs  | 54 ++++---------
 tests/function_call.rs                        | 55 ++++---------
 10 files changed, 186 insertions(+), 230 deletions(-)

diff --git a/src/generation/chat/mod.rs b/src/generation/chat/mod.rs
index 32fb337..7a0c3a2 100644
--- a/src/generation/chat/mod.rs
+++ b/src/generation/chat/mod.rs
@@ -1,11 +1,8 @@
use serde::{Deserialize, Serialize};

use crate::Ollama;
-
pub mod request;
-
use request::ChatMessageRequest;
-
use super::images::Image;

#[cfg(feature = "chat-history")]
use crate::history::MessagesHistory;

#[cfg(feature = "stream")]
/// A stream of `ChatMessageResponse` objects
pub type ChatMessageResponseStream =
-    std::pin::Pin<Box<dyn tokio_stream::Stream<Item = Result<ChatMessageResponse, ()>> + Send>>;
+std::pin::Pin<Box<dyn tokio_stream::Stream<Item = Result<ChatMessageResponse, ()>> + Send>>;

impl Ollama {
    #[cfg(feature = "stream")]
@@ -210,6 +207,11 @@ impl ChatMessage {
        Self::new(MessageRole::System, content)
    }

+    #[cfg(feature = "function-calling")]
+    pub fn tool(content: String) -> Self {
+        Self::new(MessageRole::Tool, content)
+    }
+
    pub fn with_images(mut self, images: Vec<Image>) -> Self {
        self.images = Some(images);
        self
@@ -233,4 +235,7 @@ pub enum MessageRole {
    Assistant,
    #[serde(rename = "system")]
    System,
+    #[cfg(feature = "function-calling")]
+    #[serde(rename = "tool")]
+    Tool,
}
diff --git a/src/generation/functions/mod.rs b/src/generation/functions/mod.rs
index 525c838..8899c47 100644
--- a/src/generation/functions/mod.rs
+++ b/src/generation/functions/mod.rs
@@ -1,28 +1,26 @@
+pub mod tools;
pub mod pipelines;
pub mod request;
-pub mod tools;

-pub use crate::generation::functions::pipelines::openai::request::OpenAIFunctionCall;
-pub use crate::generation::functions::request::FunctionCallRequest;
-pub use tools::DDGSearcher;
pub use tools::Scraper;
+pub use tools::DDGSearcher;
+pub use crate::generation::functions::request::FunctionCallRequest;
+pub use crate::generation::functions::pipelines::openai::request::OpenAIFunctionCall;
+pub use crate::generation::functions::pipelines::nous_hermes::request::NousFunctionCall;

-use crate::error::OllamaError;
-use crate::generation::chat::request::ChatMessageRequest;
use crate::generation::chat::{ChatMessage, ChatMessageResponse};
+use crate::generation::chat::request::{ChatMessageRequest};
-use crate::generation::functions::pipelines::RequestParserBase;
use crate::generation::functions::tools::Tool;
+use crate::error::OllamaError;
use std::sync::Arc;
+use crate::generation::functions::pipelines::RequestParserBase;

#[cfg(feature = "function-calling")]
impl crate::Ollama {
-    pub async fn check_system_message(
-        &self,
-        messages: &Vec<ChatMessage>,
-        system_prompt: &str,
-    ) -> bool {
+
+    pub async fn check_system_message(&self, messages: &Vec<ChatMessage>, system_prompt: &str) -> bool {
        let system_message = messages.first().unwrap().clone();
-        return system_message.content == system_prompt;
+        return system_message.content == system_prompt
    }

    #[cfg(feature = "chat-history")]
@@ -31,43 +29,41 @@ impl crate::Ollama {
        request: FunctionCallRequest,
        parser: Arc<dyn RequestParserBase>,
    ) -> Result<ChatMessageResponse, OllamaError> {
+
        let system_prompt = parser.get_system_message(&request.tools).await;
parser.get_system_message(&request.tools).await; - if request.chat.messages.len() == 0 { - // If there are no messages in the chat, add a system prompt + if request.chat.messages.len() == 0{ // If there are no messages in the chat, add a system prompt self.send_chat_messages_with_history( - ChatMessageRequest::new( - request.chat.model_name.clone(), - vec![system_prompt.clone()], - ), + ChatMessageRequest::new(request.chat.model_name.clone(), vec![system_prompt.clone()]), "default".to_string(), - ) - .await?; + ).await?; } let result = self .send_chat_messages_with_history( ChatMessageRequest::new(request.chat.model_name.clone(), request.chat.messages), "default".to_string(), - ) - .await?; + ).await?; + let response_content: String = result.message.clone().unwrap().content; - let result = parser - .parse( - &response_content, - request.chat.model_name.clone(), - request.tools, - ) - .await?; - return Ok(result); + let result = parser.parse(&response_content, request.chat.model_name.clone(), request.tools).await; + match result { + Ok(r) => { + return Ok(r); + }, + Err(e) => { + return Ok(e); + } + } } + pub async fn send_function_call( &self, request: FunctionCallRequest, parser: Arc, - ) -> crate::error::Result { + ) -> Result { let mut request = request; request.chat.stream = false; @@ -75,18 +71,21 @@ impl crate::Ollama { let model_name = request.chat.model_name.clone(); //Make sure the first message in chat is the system prompt - if !self - .check_system_message(&request.chat.messages, &system_prompt.content) - .await - { + if !self.check_system_message(&request.chat.messages, &system_prompt.content).await { request.chat.messages.insert(0, system_prompt); } + let result = self.send_chat_messages(request.chat).await?; let response_content: String = result.message.clone().unwrap().content; - let result = parser - .parse(&response_content, model_name, request.tools) - .await?; - return Ok(result); + let result = parser.parse(&response_content, model_name, request.tools).await; + match result { + Ok(r) => { + return Ok(r); + }, + Err(e) => { + return Ok(e); + } + } } } diff --git a/src/generation/functions/pipelines/mod.rs b/src/generation/functions/pipelines/mod.rs index ee758b2..325af72 100644 --- a/src/generation/functions/pipelines/mod.rs +++ b/src/generation/functions/pipelines/mod.rs @@ -1,19 +1,16 @@ -use crate::error::OllamaError; use crate::generation::chat::{ChatMessage, ChatMessageResponse}; +use crate::error::OllamaError; use crate::generation::functions::tools::Tool; use async_trait::async_trait; use std::sync::Arc; -pub mod nous_hermes; pub mod openai; +pub mod nous_hermes; + #[async_trait] pub trait RequestParserBase { - async fn parse( - &self, - input: &str, - model_name: String, - tools: Vec>, - ) -> Result; + async fn parse(&self, input: &str, model_name:String, tools: Vec>) -> Result; async fn get_system_message(&self, tools: &[Arc]) -> ChatMessage; + fn error_handler(&self, error: OllamaError) -> ChatMessageResponse; } diff --git a/src/generation/functions/pipelines/nous_hermes/request.rs b/src/generation/functions/pipelines/nous_hermes/request.rs index e49df6d..09c0242 100644 --- a/src/generation/functions/pipelines/nous_hermes/request.rs +++ b/src/generation/functions/pipelines/nous_hermes/request.rs @@ -1,14 +1,14 @@ -use crate::error::OllamaError; +use serde_json::{json, Map, Value}; +use std::collections::HashMap; +use std::sync::Arc; +use async_trait::async_trait; use crate::generation::chat::{ChatMessage, ChatMessageResponse}; use 
crate::generation::functions::pipelines::nous_hermes::DEFAULT_SYSTEM_TEMPLATE; -use crate::generation::functions::pipelines::RequestParserBase; use crate::generation::functions::tools::Tool; -use async_trait::async_trait; -use regex::Regex; +use crate::error::OllamaError; +use crate::generation::functions::pipelines::RequestParserBase; use serde::{Deserialize, Serialize}; -use serde_json::{json, Map, Value}; -use std::collections::HashMap; -use std::sync::Arc; +use regex::Regex; pub fn convert_to_openai_tool(tool: Arc) -> Value { let mut function = HashMap::new(); @@ -39,28 +39,28 @@ impl NousFunctionCall { model_name: String, tool_params: Value, tool: Arc, - ) -> Result { + ) -> Result { let result = tool.run(tool_params).await; match result { - Ok(result) => Ok(ChatMessageResponse { - model: model_name.clone(), - created_at: "".to_string(), - message: Some(ChatMessage::assistant(result)), - done: true, - final_data: None, - }), - Err(e) => Err(OllamaError::from(e)), + Ok(result) => + Ok(ChatMessageResponse { + model: model_name.clone(), + created_at: "".to_string(), + message: Some(ChatMessage::tool(self.format_tool_response(&result))), + done: true, + final_data: None, + }), + Err(e) => Err(self.error_handler(OllamaError::from(e))), } } pub fn format_tool_response(&self, function_response: &str) -> String { - format!("\n{}\n\n", function_response) + format!("\n{}\n\n", function_response) } pub fn extract_tool_response(&self, content: &str) -> Option { let re = Regex::new(r"(?s)(.*?)").unwrap(); - re.captures(content) - .and_then(|cap| cap.get(1).map(|m| m.as_str().to_string())) + re.captures(content).and_then(|cap| cap.get(1).map(|m| m.as_str().to_string())) } } @@ -71,13 +71,12 @@ impl RequestParserBase for NousFunctionCall { input: &str, model_name: String, tools: Vec>, - ) -> Result { + ) -> Result { //Extract between and let tool_response = self.extract_tool_response(input); match tool_response { Some(tool_response_str) => { - let response_value: Result = - serde_json::from_str(&tool_response_str); + let response_value: Result = serde_json::from_str(&tool_response_str); match response_value { Ok(response) => { if let Some(tool) = tools.iter().find(|t| t.name() == response.name) { @@ -88,26 +87,38 @@ impl RequestParserBase for NousFunctionCall { tool_params.clone(), tool.clone(), ) - .await?; + .await?; //Error is also returned as String for LLM feedback return Ok(result); } else { - return Err(OllamaError::from("Tool not found".to_string())); + return Err(self.error_handler(OllamaError::from("Tool not found".to_string()))); } } - Err(e) => return Err(OllamaError::from(e)), + Err(e) => return Err(self.error_handler(OllamaError::from(e))), } } - None => return Err(OllamaError::from("Tool response not found".to_string())), + None => return Err(self.error_handler(OllamaError::from("Tool call not found".to_string()))), } } async fn get_system_message(&self, tools: &[Arc]) -> ChatMessage { - let tools_info: Vec = tools - .iter() - .map(|tool| convert_to_openai_tool(tool.clone())) - .collect(); + let tools_info: Vec = tools.iter().map(|tool| convert_to_openai_tool(tool.clone())).collect(); let tools_json = serde_json::to_string(&tools_info).unwrap(); let system_message_content = DEFAULT_SYSTEM_TEMPLATE.replace("{tools}", &tools_json); ChatMessage::system(system_message_content) } + + fn error_handler(&self, error: OllamaError) -> ChatMessageResponse { + let error_message = format!( + "\nThere was an error parsing function calls\n Here's the error stack trace: {}\nPlease call the function 
again with correct syntax", + error.to_string() + ); + + ChatMessageResponse { + model: "".to_string(), + created_at: "".to_string(), + message: Some(ChatMessage::tool(error_message)), + done: true, + final_data: None, + } + } } diff --git a/src/generation/functions/pipelines/openai/request.rs b/src/generation/functions/pipelines/openai/request.rs index bb3bbe3..dc82df6 100644 --- a/src/generation/functions/pipelines/openai/request.rs +++ b/src/generation/functions/pipelines/openai/request.rs @@ -1,13 +1,13 @@ -use crate::error::OllamaError; +use serde_json::Value; +use std::sync::Arc; use crate::generation::chat::{ChatMessage, ChatMessageResponse}; use crate::generation::functions::pipelines::openai::DEFAULT_SYSTEM_TEMPLATE; -use crate::generation::functions::pipelines::RequestParserBase; use crate::generation::functions::tools::Tool; -use async_trait::async_trait; -use serde::{Deserialize, Serialize}; +use crate::error::OllamaError; +use crate::generation::functions::pipelines::RequestParserBase; use serde_json::json; -use serde_json::Value; -use std::sync::Arc; +use serde::{Deserialize, Serialize}; +use async_trait::async_trait; pub fn convert_to_ollama_tool(tool: &Arc) -> Value { let schema = tool.parameters(); @@ -27,66 +27,67 @@ pub struct OpenAIFunctionCallSignature { pub struct OpenAIFunctionCall {} impl OpenAIFunctionCall { + pub async fn function_call_with_history( &self, model_name: String, tool_params: Value, tool: Arc, - ) -> Result { + ) -> Result { + let result = tool.run(tool_params).await; return match result { - Ok(result) => Ok(ChatMessageResponse { - model: model_name.clone(), - created_at: "".to_string(), - message: Some(ChatMessage::assistant(result.to_string())), - done: true, - final_data: None, - }), - Err(e) => Err(OllamaError::from(e)), + Ok(result) => { + Ok(ChatMessageResponse { + model: model_name.clone(), + created_at: "".to_string(), + message: Some(ChatMessage::assistant(result.to_string())), + done: true, + final_data: None, + }) + }, + Err(e) => Err(self.error_handler(OllamaError::from(e))), }; } } #[async_trait] impl RequestParserBase for OpenAIFunctionCall { - async fn parse( - &self, - input: &str, - model_name: String, - tools: Vec>, - ) -> Result { - let response_value: Result = - serde_json::from_str(input); + async fn parse(&self, input: &str, model_name: String, tools: Vec>) -> Result { + let response_value: Result = serde_json::from_str(input); match response_value { Ok(response) => { if let Some(tool) = tools.iter().find(|t| t.name() == response.tool) { let tool_params = response.tool_input; - let result = self - .function_call_with_history( - model_name.clone(), - tool_params.clone(), - tool.clone(), - ) - .await?; + let result = self.function_call_with_history(model_name.clone(), + tool_params.clone(), + tool.clone(), + ).await?; return Ok(result); } else { - return Err(OllamaError::from("Tool not found".to_string())); + return Err(self.error_handler(OllamaError::from("Tool not found".to_string()))); } - } + }, Err(e) => { - return Err(OllamaError::from(e)); + return Err(self.error_handler(OllamaError::from(e))); } } } - async fn get_system_message(&self, tools: &[Arc]) -> ChatMessage { - // Corrected here to use a slice - let tools_info: Vec = tools - .iter() - .map(|tool| convert_to_ollama_tool(tool)) - .collect(); + async fn get_system_message(&self, tools: &[Arc]) -> ChatMessage { // Corrected here to use a slice + let tools_info: Vec = tools.iter().map(|tool| convert_to_ollama_tool(tool)).collect(); let tools_json = 
serde_json::to_string(&tools_info).unwrap(); let system_message_content = DEFAULT_SYSTEM_TEMPLATE.replace("{tools}", &tools_json); ChatMessage::system(system_message_content) } + + fn error_handler(&self, error: OllamaError) -> ChatMessageResponse { + ChatMessageResponse { + model: "".to_string(), + created_at: "".to_string(), + message: Some(ChatMessage::tool(error.to_string())), + done: true, + final_data: None, + } + } } diff --git a/src/generation/functions/request.rs b/src/generation/functions/request.rs index 44da42e..11fd439 100644 --- a/src/generation/functions/request.rs +++ b/src/generation/functions/request.rs @@ -1,19 +1,22 @@ use crate::generation::chat::request::ChatMessageRequest; +use std::sync::Arc; use crate::generation::chat::ChatMessage; -use crate::generation::functions::Tool; use crate::generation::{options::GenerationOptions, parameters::FormatType}; -use std::sync::Arc; +use crate::generation::functions::Tool; #[derive(Clone)] pub struct FunctionCallRequest { pub chat: ChatMessageRequest, - pub tools: Vec>, + pub tools: Vec> } impl FunctionCallRequest { pub fn new(model_name: String, tools: Vec>, messages: Vec) -> Self { let chat = ChatMessageRequest::new(model_name, messages); - Self { chat, tools } + Self { + chat, + tools + } } /// Additional model parameters listed in the documentation for the Modelfile @@ -33,4 +36,4 @@ impl FunctionCallRequest { self.chat.format = Some(format); self } -} +} \ No newline at end of file diff --git a/src/generation/functions/tools/mod.rs b/src/generation/functions/tools/mod.rs index 2f54814..6b29b62 100644 --- a/src/generation/functions/tools/mod.rs +++ b/src/generation/functions/tools/mod.rs @@ -1,5 +1,5 @@ -pub mod scraper; pub mod search_ddg; +pub mod scraper; pub use self::scraper::Scraper; pub use self::search_ddg::DDGSearcher; @@ -54,4 +54,4 @@ pub trait Tool: Send + Sync { Err(_) => Value::String(input.to_string()), } } -} +} \ No newline at end of file diff --git a/src/generation/functions/tools/scraper.rs b/src/generation/functions/tools/scraper.rs index 6e15eec..2e8c61c 100644 --- a/src/generation/functions/tools/scraper.rs +++ b/src/generation/functions/tools/scraper.rs @@ -1,12 +1,9 @@ use reqwest::Client; use scraper::{Html, Selector}; -use std::env; -use text_splitter::TextSplitter; - +use std::error::Error; +use serde_json::{Value, json}; use crate::generation::functions::tools::Tool; use async_trait::async_trait; -use serde_json::{json, Value}; -use std::error::Error; pub struct Scraper {} @@ -35,34 +32,20 @@ impl Tool for Scraper { async fn run(&self, input: Value) -> Result> { let website = input["website"].as_str().ok_or("Website URL is required")?; - let browserless_token = - env::var("BROWSERLESS_TOKEN").expect("BROWSERLESS_TOKEN must be set"); - let url = format!("http://0.0.0.0:3000/content?token={}", browserless_token); - let payload = json!({ - "url": website - }); let client = Client::new(); - let response = client - .post(&url) - .header("cache-control", "no-cache") - .header("content-type", "application/json") - .json(&payload) - .send() - .await?; + let response = client.get(website).send().await?.text().await?; - let response_text = response.text().await?; - let document = Html::parse_document(&response_text); + let document = Html::parse_document(&response); let selector = Selector::parse("p, h1, h2, h3, h4, h5, h6").unwrap(); let elements: Vec = document .select(&selector) - .map(|el| el.text().collect::()) + .map(|el| el.text().collect::>().join(" ")) .collect(); let body = elements.join(" "); - 
let splitter = TextSplitter::new(1000); - let chunks = splitter.chunks(&body); - let sentences: Vec = chunks.map(|s| s.to_string()).collect(); - let sentences = sentences.join("\n \n"); - Ok(sentences) + let sentences: Vec = body.split(". ").map(|s| s.to_string()).collect(); + let formatted_content = sentences.join("\n\n"); + + Ok(formatted_content) } } diff --git a/src/generation/functions/tools/search_ddg.rs b/src/generation/functions/tools/search_ddg.rs index ee56c77..67ddfae 100644 --- a/src/generation/functions/tools/search_ddg.rs +++ b/src/generation/functions/tools/search_ddg.rs @@ -4,9 +4,10 @@ use scraper::{Html, Selector}; use std::error::Error; use crate::generation::functions::tools::Tool; -use async_trait::async_trait; use serde::{Deserialize, Serialize}; -use serde_json::{json, Value}; +use serde_json::{Value, json}; +use async_trait::async_trait; + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct SearchResult { @@ -39,40 +40,19 @@ impl DDGSearcher { let result_url_selector = Selector::parse(".result__url").unwrap(); let result_snippet_selector = Selector::parse(".result__snippet").unwrap(); - let results = document - .select(&result_selector) - .map(|result| { - let title = result - .select(&result_title_selector) - .next() - .unwrap() - .text() - .collect::>() - .join(""); - let link = result - .select(&result_url_selector) - .next() - .unwrap() - .text() - .collect::>() - .join("") - .trim() - .to_string(); - let snippet = result - .select(&result_snippet_selector) - .next() - .unwrap() - .text() - .collect::>() - .join(""); - - SearchResult { - title, - link, - snippet, - } - }) - .collect::>(); + let results = document.select(&result_selector).map(|result| { + + let title = result.select(&result_title_selector).next().unwrap().text().collect::>().join(""); + let link = result.select(&result_url_selector).next().unwrap().text().collect::>().join("").trim().to_string(); + let snippet = result.select(&result_snippet_selector).next().unwrap().text().collect::>().join(""); + + SearchResult { + title, + link, + //url: String::from(url.value().attr("href").unwrap()), + snippet, + } + }).collect::>(); Ok(results) } @@ -117,4 +97,4 @@ impl Tool for DDGSearcher { async fn parse_input(&self, input: &str) -> Value { Tool::parse_input(self, input).await } -} +} \ No newline at end of file diff --git a/tests/function_call.rs b/tests/function_call.rs index 1a67298..f029dea 100644 --- a/tests/function_call.rs +++ b/tests/function_call.rs @@ -1,56 +1,33 @@ -use ollama_rs::generation::functions::pipelines::nous_hermes::request::NousFunctionCall; use ollama_rs::{ generation::chat::ChatMessage, generation::functions::tools::{DDGSearcher, Scraper}, - generation::functions::{FunctionCallRequest, OpenAIFunctionCall}, + generation::functions::{FunctionCallRequest, NousFunctionCall}, Ollama, }; use std::sync::Arc; use tokio::io::{stdout, AsyncWriteExt}; #[tokio::test] -async fn main() -> Result<(), Box> { +async fn test_send_function_call() { let mut ollama = Ollama::new_default_with_history(30); let scraper_tool = Arc::new(Scraper {}); let ddg_search_tool = Arc::new(DDGSearcher::new()); - //adrienbrault/nous-hermes2pro:Q8_0 "openhermes:latest" - let mut stdout = stdout(); - loop { - stdout.write_all(b"\n> ").await?; - stdout.flush().await?; + let query = "".to_string(); + let user_message = ChatMessage::user(query.to_string()); - let mut input = String::new(); - std::io::stdin().read_line(&mut input)?; + let parser = Arc::new(NousFunctionCall {}); + let result = ollama + 
.send_function_call( + FunctionCallRequest::new( + "adrienbrault/nous-hermes2pro:Q8_0".to_string(), + vec![scraper_tool.clone(), ddg_search_tool.clone()], + vec![user_message.clone()], + ), + parser.clone(), + ) + .await?; - let input = input.trim_end(); - if input.eq_ignore_ascii_case("exit") { - break; - } + dbg!(&result); - let user_message = ChatMessage::user(input.to_string()); - - let parser = Arc::new(NousFunctionCall {}); - let result = ollama - .send_function_call( - FunctionCallRequest::new( - "adrienbrault/nous-hermes2pro:Q8_0".to_string(), - vec![scraper_tool.clone(), ddg_search_tool.clone()], - vec![user_message.clone()], - ), - parser.clone(), - ) - .await?; - - if let Some(message) = result.message { - stdout.write_all(message.content.as_bytes()).await?; - } - - stdout.flush().await?; - } - - // Display whole history of messages - dbg!(&ollama.get_messages_history("default".to_string())); - - Ok(()) } From 9b942ad57ad6b172764a57e39632277414da44e1 Mon Sep 17 00:00:00 2001 From: andthattoo Date: Sat, 18 May 2024 14:04:07 +0300 Subject: [PATCH 11/18] removed-? --- tests/function_call.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/function_call.rs b/tests/function_call.rs index f029dea..c7e66c3 100644 --- a/tests/function_call.rs +++ b/tests/function_call.rs @@ -26,7 +26,7 @@ async fn test_send_function_call() { ), parser.clone(), ) - .await?; + .await.unwrap(); dbg!(&result); From 16fec9681a98466220c4ca35f107d7fcb32577d8 Mon Sep 17 00:00:00 2001 From: erhant Date: Sat, 18 May 2024 15:00:31 +0300 Subject: [PATCH 12/18] added docs, cleaned comments, slight formatting --- Cargo.toml | 13 ++-- README.md | 18 ++++++ src/generation/functions/mod.rs | 63 ++++++++++++------- src/generation/functions/pipelines/mod.rs | 12 ++-- .../pipelines/nous_hermes/request.rs | 52 +++++++++------ .../functions/pipelines/openai/request.rs | 61 ++++++++++-------- src/generation/functions/request.rs | 13 ++-- src/generation/functions/tools/mod.rs | 6 +- src/generation/functions/tools/scraper.rs | 12 +++- src/generation/functions/tools/search_ddg.rs | 55 +++++++++++----- tests/function_call.rs | 34 +++++----- 11 files changed, 214 insertions(+), 125 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 03d930e..cca9f52 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,17 +10,17 @@ readme = "README.md" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -reqwest = { version = "0.12.4", default-features = false, features=["json"] } +reqwest = { version = "0.12.4", default-features = false, features = ["json"] } serde = { version = "1", features = ["derive"] } serde_json = "1" tokio = { version = "1", features = ["full"], optional = true } tokio-stream = { version = "0.1.15", optional = true } -async-trait = { version = "0.1.73" } # Remove optional = true +async-trait = { version = "0.1.73" } url = "2" -log = "0.4" # Add this line -scraper = {version = "0.19.0" , optional = true } # Add scraper dependency -text-splitter = {version = "0.13.1", optional = true } # Add text_splitter dependency -regex = {version = "1.9.3", optional = true } # Add regex dependency +log = "0.4" +scraper = { version = "0.19.0", optional = true } +text-splitter = { version = "0.13.1", optional = true } +regex = { version = "1.9.3", optional = true } [features] default = ["reqwest/default-tls"] @@ -33,4 +33,3 @@ function-calling = ["scraper", "text-splitter", "regex", "chat-history"] tokio = { version = "1", features = ["full"] } ollama-rs = { 
path = ".", features = ["stream", "chat-history"] } base64 = "0.22.0" - diff --git a/README.md b/README.md index c768403..f55e3c0 100644 --- a/README.md +++ b/README.md @@ -147,3 +147,21 @@ let res = ollama.generate_embeddings("llama2:latest".to_string(), prompt, None). ``` _Returns a `GenerateEmbeddingsResponse` struct containing the embeddings (a vector of floats)._ + +### Make a function call + +```rust +let tools = vec![Arc::new(Scraper::new())]; +let parser = Arc::new(NousFunctionCall::new()); +let message = ChatMessage::user("What is the current oil price?".to_string()); +let res = ollama.send_function_call( + FunctionCallRequest::new( + "adrienbrault/nous-hermes2pro:Q8_0".to_string(), + tools, + vec![message], + ), + parser, + ).await.unwrap(); +``` + +_Uses the given tools (such as searching the web) to find an answer, returns a `ChatMessageResponse` with the answer to the question._ diff --git a/src/generation/functions/mod.rs b/src/generation/functions/mod.rs index 8899c47..b4392a8 100644 --- a/src/generation/functions/mod.rs +++ b/src/generation/functions/mod.rs @@ -1,26 +1,29 @@ -pub mod tools; pub mod pipelines; pub mod request; +pub mod tools; -pub use tools::Scraper; -pub use tools::DDGSearcher; -pub use crate::generation::functions::request::FunctionCallRequest; -pub use crate::generation::functions::pipelines::openai::request::OpenAIFunctionCall; pub use crate::generation::functions::pipelines::nous_hermes::request::NousFunctionCall; +pub use crate::generation::functions::pipelines::openai::request::OpenAIFunctionCall; +pub use crate::generation::functions::request::FunctionCallRequest; +pub use tools::DDGSearcher; +pub use tools::Scraper; +use crate::error::OllamaError; +use crate::generation::chat::request::ChatMessageRequest; use crate::generation::chat::{ChatMessage, ChatMessageResponse}; -use crate::generation::chat::request::{ChatMessageRequest}; +use crate::generation::functions::pipelines::RequestParserBase; use crate::generation::functions::tools::Tool; -use crate::error::OllamaError; use std::sync::Arc; -use crate::generation::functions::pipelines::RequestParserBase; #[cfg(feature = "function-calling")] impl crate::Ollama { - - pub async fn check_system_message(&self, messages: &Vec, system_prompt: &str) -> bool { + pub async fn check_system_message( + &self, + messages: &Vec, + system_prompt: &str, + ) -> bool { let system_message = messages.first().unwrap().clone(); - return system_message.content == system_prompt + return system_message.content == system_prompt; } #[cfg(feature = "chat-history")] @@ -29,36 +32,45 @@ impl crate::Ollama { request: FunctionCallRequest, parser: Arc, ) -> Result { - let system_prompt = parser.get_system_message(&request.tools).await; - if request.chat.messages.len() == 0{ // If there are no messages in the chat, add a system prompt + if request.chat.messages.len() == 0 { + // If there are no messages in the chat, add a system prompt self.send_chat_messages_with_history( - ChatMessageRequest::new(request.chat.model_name.clone(), vec![system_prompt.clone()]), + ChatMessageRequest::new( + request.chat.model_name.clone(), + vec![system_prompt.clone()], + ), "default".to_string(), - ).await?; + ) + .await?; } let result = self .send_chat_messages_with_history( ChatMessageRequest::new(request.chat.model_name.clone(), request.chat.messages), "default".to_string(), - ).await?; - + ) + .await?; let response_content: String = result.message.clone().unwrap().content; - let result = parser.parse(&response_content, 
request.chat.model_name.clone(), request.tools).await; + let result = parser + .parse( + &response_content, + request.chat.model_name.clone(), + request.tools, + ) + .await; match result { Ok(r) => { return Ok(r); - }, + } Err(e) => { return Ok(e); } } } - pub async fn send_function_call( &self, request: FunctionCallRequest, @@ -71,18 +83,23 @@ impl crate::Ollama { let model_name = request.chat.model_name.clone(); //Make sure the first message in chat is the system prompt - if !self.check_system_message(&request.chat.messages, &system_prompt.content).await { + if !self + .check_system_message(&request.chat.messages, &system_prompt.content) + .await + { request.chat.messages.insert(0, system_prompt); } let result = self.send_chat_messages(request.chat).await?; let response_content: String = result.message.clone().unwrap().content; - let result = parser.parse(&response_content, model_name, request.tools).await; + let result = parser + .parse(&response_content, model_name, request.tools) + .await; match result { Ok(r) => { return Ok(r); - }, + } Err(e) => { return Ok(e); } diff --git a/src/generation/functions/pipelines/mod.rs b/src/generation/functions/pipelines/mod.rs index 325af72..d90e8fa 100644 --- a/src/generation/functions/pipelines/mod.rs +++ b/src/generation/functions/pipelines/mod.rs @@ -1,16 +1,20 @@ -use crate::generation::chat::{ChatMessage, ChatMessageResponse}; use crate::error::OllamaError; +use crate::generation::chat::{ChatMessage, ChatMessageResponse}; use crate::generation::functions::tools::Tool; use async_trait::async_trait; use std::sync::Arc; -pub mod openai; pub mod nous_hermes; - +pub mod openai; #[async_trait] pub trait RequestParserBase { - async fn parse(&self, input: &str, model_name:String, tools: Vec>) -> Result; + async fn parse( + &self, + input: &str, + model_name: String, + tools: Vec>, + ) -> Result; async fn get_system_message(&self, tools: &[Arc]) -> ChatMessage; fn error_handler(&self, error: OllamaError) -> ChatMessageResponse; } diff --git a/src/generation/functions/pipelines/nous_hermes/request.rs b/src/generation/functions/pipelines/nous_hermes/request.rs index 09c0242..64e4a88 100644 --- a/src/generation/functions/pipelines/nous_hermes/request.rs +++ b/src/generation/functions/pipelines/nous_hermes/request.rs @@ -1,14 +1,14 @@ -use serde_json::{json, Map, Value}; -use std::collections::HashMap; -use std::sync::Arc; -use async_trait::async_trait; +use crate::error::OllamaError; use crate::generation::chat::{ChatMessage, ChatMessageResponse}; use crate::generation::functions::pipelines::nous_hermes::DEFAULT_SYSTEM_TEMPLATE; -use crate::generation::functions::tools::Tool; -use crate::error::OllamaError; use crate::generation::functions::pipelines::RequestParserBase; -use serde::{Deserialize, Serialize}; +use crate::generation::functions::tools::Tool; +use async_trait::async_trait; use regex::Regex; +use serde::{Deserialize, Serialize}; +use serde_json::{json, Map, Value}; +use std::collections::HashMap; +use std::sync::Arc; pub fn convert_to_openai_tool(tool: Arc) -> Value { let mut function = HashMap::new(); @@ -34,6 +34,10 @@ pub struct NousFunctionCallSignature { pub struct NousFunctionCall {} impl NousFunctionCall { + pub fn new() -> Self { + Self {} + } + pub async fn function_call_with_history( &self, model_name: String, @@ -42,14 +46,13 @@ impl NousFunctionCall { ) -> Result { let result = tool.run(tool_params).await; match result { - Ok(result) => - Ok(ChatMessageResponse { - model: model_name.clone(), - created_at: "".to_string(), - message: 
Some(ChatMessage::tool(self.format_tool_response(&result))), - done: true, - final_data: None, - }), + Ok(result) => Ok(ChatMessageResponse { + model: model_name.clone(), + created_at: "".to_string(), + message: Some(ChatMessage::tool(self.format_tool_response(&result))), + done: true, + final_data: None, + }), Err(e) => Err(self.error_handler(OllamaError::from(e))), } } @@ -60,7 +63,8 @@ impl NousFunctionCall { pub fn extract_tool_response(&self, content: &str) -> Option { let re = Regex::new(r"(?s)(.*?)").unwrap(); - re.captures(content).and_then(|cap| cap.get(1).map(|m| m.as_str().to_string())) + re.captures(content) + .and_then(|cap| cap.get(1).map(|m| m.as_str().to_string())) } } @@ -76,7 +80,8 @@ impl RequestParserBase for NousFunctionCall { let tool_response = self.extract_tool_response(input); match tool_response { Some(tool_response_str) => { - let response_value: Result = serde_json::from_str(&tool_response_str); + let response_value: Result = + serde_json::from_str(&tool_response_str); match response_value { Ok(response) => { if let Some(tool) = tools.iter().find(|t| t.name() == response.name) { @@ -90,18 +95,25 @@ impl RequestParserBase for NousFunctionCall { .await?; //Error is also returned as String for LLM feedback return Ok(result); } else { - return Err(self.error_handler(OllamaError::from("Tool not found".to_string()))); + return Err( + self.error_handler(OllamaError::from("Tool not found".to_string())) + ); } } Err(e) => return Err(self.error_handler(OllamaError::from(e))), } } - None => return Err(self.error_handler(OllamaError::from("Tool call not found".to_string()))), + None => { + return Err(self.error_handler(OllamaError::from("Tool call not found".to_string()))) + } } } async fn get_system_message(&self, tools: &[Arc]) -> ChatMessage { - let tools_info: Vec = tools.iter().map(|tool| convert_to_openai_tool(tool.clone())).collect(); + let tools_info: Vec = tools + .iter() + .map(|tool| convert_to_openai_tool(tool.clone())) + .collect(); let tools_json = serde_json::to_string(&tools_info).unwrap(); let system_message_content = DEFAULT_SYSTEM_TEMPLATE.replace("{tools}", &tools_json); ChatMessage::system(system_message_content) diff --git a/src/generation/functions/pipelines/openai/request.rs b/src/generation/functions/pipelines/openai/request.rs index dc82df6..cd025b4 100644 --- a/src/generation/functions/pipelines/openai/request.rs +++ b/src/generation/functions/pipelines/openai/request.rs @@ -1,13 +1,13 @@ -use serde_json::Value; -use std::sync::Arc; +use crate::error::OllamaError; use crate::generation::chat::{ChatMessage, ChatMessageResponse}; use crate::generation::functions::pipelines::openai::DEFAULT_SYSTEM_TEMPLATE; -use crate::generation::functions::tools::Tool; -use crate::error::OllamaError; use crate::generation::functions::pipelines::RequestParserBase; -use serde_json::json; -use serde::{Deserialize, Serialize}; +use crate::generation::functions::tools::Tool; use async_trait::async_trait; +use serde::{Deserialize, Serialize}; +use serde_json::json; +use serde_json::Value; +use std::sync::Arc; pub fn convert_to_ollama_tool(tool: &Arc) -> Value { let schema = tool.parameters(); @@ -27,25 +27,21 @@ pub struct OpenAIFunctionCallSignature { pub struct OpenAIFunctionCall {} impl OpenAIFunctionCall { - pub async fn function_call_with_history( &self, model_name: String, tool_params: Value, tool: Arc, ) -> Result { - let result = tool.run(tool_params).await; return match result { - Ok(result) => { - Ok(ChatMessageResponse { - model: model_name.clone(), - 
created_at: "".to_string(), - message: Some(ChatMessage::assistant(result.to_string())), - done: true, - final_data: None, - }) - }, + Ok(result) => Ok(ChatMessageResponse { + model: model_name.clone(), + created_at: "".to_string(), + message: Some(ChatMessage::assistant(result.to_string())), + done: true, + final_data: None, + }), Err(e) => Err(self.error_handler(OllamaError::from(e))), }; } @@ -53,29 +49,42 @@ impl OpenAIFunctionCall { #[async_trait] impl RequestParserBase for OpenAIFunctionCall { - async fn parse(&self, input: &str, model_name: String, tools: Vec>) -> Result { - let response_value: Result = serde_json::from_str(input); + async fn parse( + &self, + input: &str, + model_name: String, + tools: Vec>, + ) -> Result { + let response_value: Result = + serde_json::from_str(input); match response_value { Ok(response) => { if let Some(tool) = tools.iter().find(|t| t.name() == response.tool) { let tool_params = response.tool_input; - let result = self.function_call_with_history(model_name.clone(), - tool_params.clone(), - tool.clone(), - ).await?; + let result = self + .function_call_with_history( + model_name.clone(), + tool_params.clone(), + tool.clone(), + ) + .await?; return Ok(result); } else { return Err(self.error_handler(OllamaError::from("Tool not found".to_string()))); } - }, + } Err(e) => { return Err(self.error_handler(OllamaError::from(e))); } } } - async fn get_system_message(&self, tools: &[Arc]) -> ChatMessage { // Corrected here to use a slice - let tools_info: Vec = tools.iter().map(|tool| convert_to_ollama_tool(tool)).collect(); + async fn get_system_message(&self, tools: &[Arc]) -> ChatMessage { + // Corrected here to use a slice + let tools_info: Vec = tools + .iter() + .map(|tool| convert_to_ollama_tool(tool)) + .collect(); let tools_json = serde_json::to_string(&tools_info).unwrap(); let system_message_content = DEFAULT_SYSTEM_TEMPLATE.replace("{tools}", &tools_json); ChatMessage::system(system_message_content) diff --git a/src/generation/functions/request.rs b/src/generation/functions/request.rs index 11fd439..44da42e 100644 --- a/src/generation/functions/request.rs +++ b/src/generation/functions/request.rs @@ -1,22 +1,19 @@ use crate::generation::chat::request::ChatMessageRequest; -use std::sync::Arc; use crate::generation::chat::ChatMessage; -use crate::generation::{options::GenerationOptions, parameters::FormatType}; use crate::generation::functions::Tool; +use crate::generation::{options::GenerationOptions, parameters::FormatType}; +use std::sync::Arc; #[derive(Clone)] pub struct FunctionCallRequest { pub chat: ChatMessageRequest, - pub tools: Vec> + pub tools: Vec>, } impl FunctionCallRequest { pub fn new(model_name: String, tools: Vec>, messages: Vec) -> Self { let chat = ChatMessageRequest::new(model_name, messages); - Self { - chat, - tools - } + Self { chat, tools } } /// Additional model parameters listed in the documentation for the Modelfile @@ -36,4 +33,4 @@ impl FunctionCallRequest { self.chat.format = Some(format); self } -} \ No newline at end of file +} diff --git a/src/generation/functions/tools/mod.rs b/src/generation/functions/tools/mod.rs index 6b29b62..f49643f 100644 --- a/src/generation/functions/tools/mod.rs +++ b/src/generation/functions/tools/mod.rs @@ -1,5 +1,5 @@ -pub mod search_ddg; pub mod scraper; +pub mod search_ddg; pub use self::scraper::Scraper; pub use self::search_ddg::DDGSearcher; @@ -17,7 +17,7 @@ pub trait Tool: Send + Sync { /// Provides a description of what the tool does and when to use it. 
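/// Pipelines serialize this text into the tool JSON embedded in the system prompt, so keep it concise.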
fn description(&self) -> String; - /// This are the parameters for OpenAI-like function call. + /// Returns the parameters for OpenAI-like function call. fn parameters(&self) -> Value { json!({ "type": "object", @@ -54,4 +54,4 @@ pub trait Tool: Send + Sync { Err(_) => Value::String(input.to_string()), } } -} \ No newline at end of file +} diff --git a/src/generation/functions/tools/scraper.rs b/src/generation/functions/tools/scraper.rs index 2e8c61c..1c7f5ac 100644 --- a/src/generation/functions/tools/scraper.rs +++ b/src/generation/functions/tools/scraper.rs @@ -1,12 +1,18 @@ +use crate::generation::functions::tools::Tool; +use async_trait::async_trait; use reqwest::Client; use scraper::{Html, Selector}; +use serde_json::{json, Value}; use std::error::Error; -use serde_json::{Value, json}; -use crate::generation::functions::tools::Tool; -use async_trait::async_trait; pub struct Scraper {} +impl Scraper { + pub fn new() -> Self { + Self {} + } +} + #[async_trait] impl Tool for Scraper { fn name(&self) -> String { diff --git a/src/generation/functions/tools/search_ddg.rs b/src/generation/functions/tools/search_ddg.rs index 67ddfae..c6f17be 100644 --- a/src/generation/functions/tools/search_ddg.rs +++ b/src/generation/functions/tools/search_ddg.rs @@ -4,10 +4,9 @@ use scraper::{Html, Selector}; use std::error::Error; use crate::generation::functions::tools::Tool; -use serde::{Deserialize, Serialize}; -use serde_json::{Value, json}; use async_trait::async_trait; - +use serde::{Deserialize, Serialize}; +use serde_json::{json, Value}; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct SearchResult { @@ -40,19 +39,41 @@ impl DDGSearcher { let result_url_selector = Selector::parse(".result__url").unwrap(); let result_snippet_selector = Selector::parse(".result__snippet").unwrap(); - let results = document.select(&result_selector).map(|result| { - - let title = result.select(&result_title_selector).next().unwrap().text().collect::>().join(""); - let link = result.select(&result_url_selector).next().unwrap().text().collect::>().join("").trim().to_string(); - let snippet = result.select(&result_snippet_selector).next().unwrap().text().collect::>().join(""); - - SearchResult { - title, - link, - //url: String::from(url.value().attr("href").unwrap()), - snippet, - } - }).collect::>(); + let results = document + .select(&result_selector) + .map(|result| { + let title = result + .select(&result_title_selector) + .next() + .unwrap() + .text() + .collect::>() + .join(""); + let link = result + .select(&result_url_selector) + .next() + .unwrap() + .text() + .collect::>() + .join("") + .trim() + .to_string(); + let snippet = result + .select(&result_snippet_selector) + .next() + .unwrap() + .text() + .collect::>() + .join(""); + + SearchResult { + title, + link, + //url: String::from(url.value().attr("href").unwrap()), + snippet, + } + }) + .collect::>(); Ok(results) } @@ -97,4 +118,4 @@ impl Tool for DDGSearcher { async fn parse_input(&self, input: &str) -> Value { Tool::parse_input(self, input).await } -} \ No newline at end of file +} diff --git a/tests/function_call.rs b/tests/function_call.rs index c7e66c3..ff0d836 100644 --- a/tests/function_call.rs +++ b/tests/function_call.rs @@ -1,3 +1,5 @@ +// #![cfg(feature = "function-calling")] + use ollama_rs::{ generation::chat::ChatMessage, generation::functions::tools::{DDGSearcher, Scraper}, @@ -5,29 +7,33 @@ use ollama_rs::{ Ollama, }; use std::sync::Arc; -use tokio::io::{stdout, AsyncWriteExt}; #[tokio::test] async fn test_send_function_call() { 
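+    // NOTE: integration test; it needs a running Ollama server (with the MODEL
+    // below already pulled) and drives one full tool-call round trip through
+    // the parser constructed below.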
- let mut ollama = Ollama::new_default_with_history(30); - let scraper_tool = Arc::new(Scraper {}); - let ddg_search_tool = Arc::new(DDGSearcher::new()); + /// Model to be used, make sure it is tailored towards "function calling", such as: + /// - openhermes:latest + /// - adrienbrault/nous-hermes2pro:Q8_0 + const MODEL: &str = "adrienbrault/nous-hermes2pro:Q8_0"; - let query = "".to_string(); - let user_message = ChatMessage::user(query.to_string()); + const PROMPT: &str = ""; + let user_message = ChatMessage::user(PROMPT.to_string()); - let parser = Arc::new(NousFunctionCall {}); + let scraper_tool = Arc::new(Scraper::new()); + let ddg_search_tool = Arc::new(DDGSearcher::new()); + let parser = Arc::new(NousFunctionCall::new()); + + let ollama = Ollama::new_default_with_history(30); let result = ollama .send_function_call( FunctionCallRequest::new( - "adrienbrault/nous-hermes2pro:Q8_0".to_string(), - vec![scraper_tool.clone(), ddg_search_tool.clone()], - vec![user_message.clone()], + MODEL.to_string(), + vec![scraper_tool, ddg_search_tool], + vec![user_message], ), - parser.clone(), + parser, ) .await .unwrap(); - - dbg!(&result); + assert!(result.done); } From 3df10009c143c981e430a611dc7b538b2ace5062 Mon Sep 17 00:00:00 2001 From: andthattoo Date: Mon, 20 May 2024 03:09:41 +0300 Subject: [PATCH 13/18] - Added stock price tool - RequestParserBase trait now has `format_query` and `format_response` for greater flexibility - Updated the prompt of NousHermes - Tool errors are now ChatMessageResponses: they are not Ollama errors, and can be passed to LLMs as messages for further processing. Error handler classes added. - Removed ChatMessage::Tool; Ollama doesn't accept it yet, but it may be added in the future. - send_function_call_with_history now uses Ollama's set_system_response - Fixed a few bugs, added JSON parsing guardrails --- src/generation/chat/mod.rs | 12 +-- src/generation/functions/mod.rs | 76 +++++++-------- src/generation/functions/pipelines/mod.rs | 6 ++ .../pipelines/nous_hermes/prompts.rs | 37 +++++++- .../pipelines/nous_hermes/request.rs | 44 ++++++--- .../functions/pipelines/openai/request.rs | 2 +- src/generation/functions/tools/finance.rs | 92 +++++++++++++++++++ src/generation/functions/tools/mod.rs | 2 + tests/function_call.rs | 63 ++++++++++++- 9 files changed, 266 insertions(+), 68 deletions(-) create mode 100644 src/generation/functions/tools/finance.rs diff --git a/src/generation/chat/mod.rs index 7a0c3a2..6f9e437 100644 --- a/src/generation/chat/mod.rs +++ b/src/generation/chat/mod.rs @@ -2,8 +2,8 @@ use serde::{Deserialize, Serialize}; use crate::Ollama; pub mod request; -use request::ChatMessageRequest; use super::images::Image; +use request::ChatMessageRequest; #[cfg(feature = "chat-history")] use crate::history::MessagesHistory; @@ -11,7 +11,7 @@ use crate::history::MessagesHistory; #[cfg(feature = "stream")] /// A stream of `ChatMessageResponse` objects pub type ChatMessageResponseStream = -std::pin::Pin> + Send>>; + std::pin::Pin> + Send>>; impl Ollama { #[cfg(feature = "stream")] @@ -207,11 +207,6 @@ impl ChatMessage { Self::new(MessageRole::System, content) } - #[cfg(feature = "function-calling")] - pub fn tool(content: String) -> Self { - Self::new(MessageRole::Tool, content) - } - pub fn with_images(mut self, images: Vec) -> Self { self.images = Some(images); self @@ -235,7 +230,4 @@ pub enum MessageRole { Assistant, #[serde(rename = "system")] System, - #[cfg(feature = "function-calling")] - #[serde(rename = 
"tool")] - Tool, } diff --git a/src/generation/functions/mod.rs b/src/generation/functions/mod.rs index b4392a8..3999de8 100644 --- a/src/generation/functions/mod.rs +++ b/src/generation/functions/mod.rs @@ -17,58 +17,68 @@ use std::sync::Arc; #[cfg(feature = "function-calling")] impl crate::Ollama { - pub async fn check_system_message( - &self, - messages: &Vec, - system_prompt: &str, - ) -> bool { + fn has_system_prompt(&self, messages: &Vec, system_prompt: &str) -> bool { let system_message = messages.first().unwrap().clone(); return system_message.content == system_prompt; } + fn has_system_prompt_history(&mut self) -> bool { + return self.get_messages_history("default".to_string()).is_some(); + } + #[cfg(feature = "chat-history")] pub async fn send_function_call_with_history( &mut self, request: FunctionCallRequest, parser: Arc, ) -> Result { - let system_prompt = parser.get_system_message(&request.tools).await; - if request.chat.messages.len() == 0 { - // If there are no messages in the chat, add a system prompt - self.send_chat_messages_with_history( - ChatMessageRequest::new( - request.chat.model_name.clone(), - vec![system_prompt.clone()], - ), - "default".to_string(), - ) - .await?; + let mut request = request; + + if !self.has_system_prompt_history() { + let system_prompt = parser.get_system_message(&request.tools).await; + self.set_system_response("default".to_string(), system_prompt.content); + + //format input + let formatted_query = ChatMessage::user( + parser.format_query(&request.chat.messages.first().unwrap().content), + ); + //replace with formatted_query with previous chat_message + request.chat.messages.remove(0); + request.chat.messages.insert(0, formatted_query); } - let result = self + let tool_call_result = self .send_chat_messages_with_history( ChatMessageRequest::new(request.chat.model_name.clone(), request.chat.messages), "default".to_string(), ) .await?; - let response_content: String = result.message.clone().unwrap().content; - + let tool_call_content: String = tool_call_result.message.clone().unwrap().content; let result = parser .parse( - &response_content, + &tool_call_content, request.chat.model_name.clone(), request.tools, ) .await; - match result { + + return match result { Ok(r) => { - return Ok(r); + self.add_assistant_response( + "default".to_string(), + r.message.clone().unwrap().content, + ); + Ok(r) } Err(e) => { - return Ok(e); + self.add_assistant_response( + "default".to_string(), + e.message.clone().unwrap().content, + ); + Ok(e) } - } + }; } pub async fn send_function_call( @@ -83,26 +93,18 @@ impl crate::Ollama { let model_name = request.chat.model_name.clone(); //Make sure the first message in chat is the system prompt - if !self - .check_system_message(&request.chat.messages, &system_prompt.content) - .await - { + if !self.has_system_prompt(&request.chat.messages, &system_prompt.content) { request.chat.messages.insert(0, system_prompt); } - let result = self.send_chat_messages(request.chat).await?; let response_content: String = result.message.clone().unwrap().content; let result = parser .parse(&response_content, model_name, request.tools) .await; - match result { - Ok(r) => { - return Ok(r); - } - Err(e) => { - return Ok(e); - } - } + return match result { + Ok(r) => Ok(r), + Err(e) => Ok(e), + }; } } diff --git a/src/generation/functions/pipelines/mod.rs b/src/generation/functions/pipelines/mod.rs index d90e8fa..057e347 100644 --- a/src/generation/functions/pipelines/mod.rs +++ b/src/generation/functions/pipelines/mod.rs @@ -15,6 +15,12 
@@ pub trait RequestParserBase { model_name: String, tools: Vec>, ) -> Result; + fn format_query(&self, input: &str) -> String { + input.to_string() + } + fn format_response(&self, response: &str) -> String { + response.to_string() + } async fn get_system_message(&self, tools: &[Arc]) -> ChatMessage; fn error_handler(&self, error: OllamaError) -> ChatMessageResponse; } diff --git a/src/generation/functions/pipelines/nous_hermes/prompts.rs b/src/generation/functions/pipelines/nous_hermes/prompts.rs index f1242ff..fc0080b 100644 --- a/src/generation/functions/pipelines/nous_hermes/prompts.rs +++ b/src/generation/functions/pipelines/nous_hermes/prompts.rs @@ -16,17 +16,44 @@ Objective: | Tools: | Here are the available tools: {tools} - If the provided function signatures doesn't have the function you must call, you may write executable python code in markdown syntax and call code_interpreter() function as follows: + If the provided function signatures doesn't have the function you must call, you may write executable rust code in markdown syntax and call code_interpreter() function as follows: - {{"arguments": {{"code_markdown": , "name": "code_interpreter"}}}} + {"arguments": {"code_markdown": , "name": "code_interpreter"}} Make sure that the json object above with code markdown block is parseable with json.loads() and the XML block with XML ElementTree. Examples: | Here are some example usage of functions: - {examples} + [ + { + "example": "```\nSYSTEM: You are a helpful assistant who has access to functions. Use them if required\n[\n {\n \"name\": \"calculate_distance\",\n \"description\": \"Calculate the distance between two locations\",\n \"parameters\": {\n \"type\": \"object\",\n \"properties\": {\n \"origin\": {\n \"type\": \"string\",\n \"description\": \"The starting location\"\n },\n \"destination\": {\n \"type\": \"string\",\n \"description\": \"The destination location\"\n },\n \"mode\": {\n \"type\": \"string\",\n \"description\": \"The mode of transportation\"\n }\n },\n \"required\": [\n \"origin\",\n \"destination\",\n \"mode\"\n ]\n }\n },\n {\n \"name\": \"generate_password\",\n \"description\": \"Generate a random password\",\n \"parameters\": {\n \"type\": \"object\",\n \"properties\": {\n \"length\": {\n \"type\": \"integer\",\n \"description\": \"The length of the password\"\n }\n },\n \"required\": [\n \"length\"\n ]\n }\n }\n]\n\n\nUSER: Hi, I need to know the distance from New York to Los Angeles by car.\nASSISTANT:\n\n{\"arguments\": {\"origin\": \"New York\",\n \"destination\": \"Los Angeles\", \"mode\": \"car\"}, \"name\": \"calculate_distance\"}\n\n```\n" + }, + { + "example": "```\nSYSTEM: You are a helpful assistant with access to functions. 
Use them if required\n[\n {\n \"name\": \"calculate_distance\",\n \"description\": \"Calculate the distance between two locations\",\n \"parameters\": {\n \"type\": \"object\",\n \"properties\": {\n \"origin\": {\n \"type\": \"string\",\n \"description\": \"The starting location\"\n },\n \"destination\": {\n \"type\": \"string\",\n \"description\": \"The destination location\"\n },\n \"mode\": {\n \"type\": \"string\",\n \"description\": \"The mode of transportation\"\n }\n },\n \"required\": [\n \"origin\",\n \"destination\",\n \"mode\"\n ]\n }\n },\n {\n \"name\": \"generate_password\",\n \"description\": \"Generate a random password\",\n \"parameters\": {\n \"type\": \"object\",\n \"properties\": {\n \"length\": {\n \"type\": \"integer\",\n \"description\": \"The length of the password\"\n }\n },\n \"required\": [\n \"length\"\n ]\n }\n }\n]\n\n\nUSER: Can you help me generate a random password with a length of 8 characters?\nASSISTANT:\n\n{\"arguments\": {\"length\": 8}, \"name\": \"generate_password\"}\n\n```" + } +] Schema: | Use the following pydantic model json schema for each tool call you will make: - {schema} + { + "name": "tool name", + "description": "tool description", + "parameters": { + "type": "object", + "properties": { + "parameter1": { + "type": "string", + "description": "parameter description" + }, + "parameter2": { + "type": "string", + "description": "parameter description" + } + }, + "required": [ + "parameter1", + "parameter2" + ] + } + } Instructions: | At the very first turn you don't have so you shouldn't not make up the results. Please keep a running summary with analysis of previous function results and summaries from previous iterations. @@ -35,6 +62,6 @@ Instructions: | If you plan to continue with analysis, always call another function. 
For each function call return a valid json object (using doulbe quotes) with function name and arguments within XML tags as follows: - {{"arguments": , "name": }} + {"arguments": , "name": } "#; diff --git a/src/generation/functions/pipelines/nous_hermes/request.rs b/src/generation/functions/pipelines/nous_hermes/request.rs index 64e4a88..cda684b 100644 --- a/src/generation/functions/pipelines/nous_hermes/request.rs +++ b/src/generation/functions/pipelines/nous_hermes/request.rs @@ -49,7 +49,7 @@ impl NousFunctionCall { Ok(result) => Ok(ChatMessageResponse { model: model_name.clone(), created_at: "".to_string(), - message: Some(ChatMessage::tool(self.format_tool_response(&result))), + message: Some(ChatMessage::assistant(self.format_tool_response(&result))), done: true, final_data: None, }), @@ -61,10 +61,19 @@ impl NousFunctionCall { format!("\n{}\n\n", function_response) } - pub fn extract_tool_response(&self, content: &str) -> Option { + pub fn extract_tool_call(&self, content: &str) -> Option { let re = Regex::new(r"(?s)(.*?)").unwrap(); - re.captures(content) - .and_then(|cap| cap.get(1).map(|m| m.as_str().to_string())) + if let Some(captures) = re.captures(content) { + if let Some(matched) = captures.get(1) { + let result = matched + .as_str() + .replace("\n", "") + .replace("{{", "{") + .replace("}}", "}"); + return Some(result); + } + } + None } } @@ -76,8 +85,8 @@ impl RequestParserBase for NousFunctionCall { model_name: String, tools: Vec>, ) -> Result { - //Extract between and - let tool_response = self.extract_tool_response(input); + //Extract between and + let tool_response = self.extract_tool_call(input); match tool_response { Some(tool_response_str) => { let response_value: Result = @@ -95,20 +104,33 @@ impl RequestParserBase for NousFunctionCall { .await?; //Error is also returned as String for LLM feedback return Ok(result); } else { - return Err( - self.error_handler(OllamaError::from("Tool not found".to_string())) - ); + return Err(self.error_handler(OllamaError::from( + "Tool name not found".to_string(), + ))); } } Err(e) => return Err(self.error_handler(OllamaError::from(e))), } } None => { - return Err(self.error_handler(OllamaError::from("Tool call not found".to_string()))) + return Err(self.error_handler(OllamaError::from( + "Error while extracting tags.".to_string(), + ))) } } } + fn format_query(&self, input: &str) -> String { + format!( + "{}\nThis is the first turn and you don't have to analyze yet", + input + ) + } + + fn format_response(&self, response: &str) -> String { + format!("Agent iteration to assist with user query: {}", response) + } + async fn get_system_message(&self, tools: &[Arc]) -> ChatMessage { let tools_info: Vec = tools .iter() @@ -128,7 +150,7 @@ impl RequestParserBase for NousFunctionCall { ChatMessageResponse { model: "".to_string(), created_at: "".to_string(), - message: Some(ChatMessage::tool(error_message)), + message: Some(ChatMessage::assistant(error_message)), done: true, final_data: None, } diff --git a/src/generation/functions/pipelines/openai/request.rs b/src/generation/functions/pipelines/openai/request.rs index cd025b4..43cdc18 100644 --- a/src/generation/functions/pipelines/openai/request.rs +++ b/src/generation/functions/pipelines/openai/request.rs @@ -94,7 +94,7 @@ impl RequestParserBase for OpenAIFunctionCall { ChatMessageResponse { model: "".to_string(), created_at: "".to_string(), - message: Some(ChatMessage::tool(error.to_string())), + message: Some(ChatMessage::assistant(error.to_string())), done: true, final_data: None, 
} diff --git a/src/generation/functions/tools/finance.rs b/src/generation/functions/tools/finance.rs new file mode 100644 index 0000000..e5447e7 --- /dev/null +++ b/src/generation/functions/tools/finance.rs @@ -0,0 +1,92 @@ +use crate::generation::functions::tools::Tool; +use async_trait::async_trait; +use reqwest::Client; +use scraper::{Html, Selector}; +use serde_json::{json, Value}; +use std::collections::HashMap; +use std::error::Error; + +pub struct StockScraper { + base_url: String, + language: String, +} + +impl StockScraper { + pub fn new() -> Self { + StockScraper { + base_url: "https://www.google.com/finance".to_string(), + language: "en".to_string(), + } + } + + // Changed to an async function + pub async fn scrape( + &self, + exchange: &str, + ticker: &str, + ) -> Result, Box> { + let target_url = format!( + "{}/quote/{}:{}?hl={}", + self.base_url, ticker, exchange, self.language + ); + let client = Client::new(); + let response = client.get(&target_url).send().await?; // Make the request asynchronously + let content = response.text().await?; // Asynchronously get the text of the response + let document = Html::parse_document(&content); + + let items_selector = Selector::parse("div.gyFHrc").unwrap(); + let desc_selector = Selector::parse("div.mfs7Fc").unwrap(); + let value_selector = Selector::parse("div.P6K39c").unwrap(); + + let mut stock_description = HashMap::new(); + + for item in document.select(&items_selector) { + if let Some(item_description) = item.select(&desc_selector).next() { + if let Some(item_value) = item.select(&value_selector).next() { + stock_description.insert( + item_description.text().collect::>().join(""), + item_value.text().collect::>().join(""), + ); + } + } + } + + Ok(stock_description) + } +} + +#[async_trait] +impl Tool for StockScraper { + fn name(&self) -> String { + "Stock Scraper".to_string() + } + + fn description(&self) -> String { + "Scrapes stock information from Google Finance.".to_string() + } + + fn parameters(&self) -> Value { + json!({ + "type": "object", + "properties": { + "exchange": { + "type": "string", + "description": "The stock exchange market identifier code (MIC)" + }, + "ticker": { + "type": "string", + "description": "The ticker symbol of the stock" + } + }, + "required": ["exchange", "ticker"] + }) + } + + async fn run(&self, input: Value) -> Result> { + let exchange = input["exchange"].as_str().ok_or("Exchange is required")?; + let ticker = input["ticker"].as_str().ok_or("Ticker is required")?; + + let result = self.scrape(exchange, ticker).await?; + Ok(serde_json::to_string(&result)?) 
+ } +} diff --git a/src/generation/functions/tools/mod.rs b/src/generation/functions/tools/mod.rs index f49643f..b5f4223 100644 --- a/src/generation/functions/tools/mod.rs +++ b/src/generation/functions/tools/mod.rs @@ -1,6 +1,8 @@ +pub mod finance; pub mod scraper; pub mod search_ddg; +pub use self::finance::StockScraper; pub use self::scraper::Scraper; pub use self::search_ddg::DDGSearcher; diff --git a/tests/function_call.rs b/tests/function_call.rs index ff0d836..6a7e377 100644 --- a/tests/function_call.rs +++ b/tests/function_call.rs @@ -2,7 +2,7 @@ use ollama_rs::{ generation::chat::ChatMessage, - generation::functions::tools::{DDGSearcher, Scraper}, + generation::functions::tools::{DDGSearcher, Scraper, StockScraper}, generation::functions::{FunctionCallRequest, NousFunctionCall}, Ollama, }; @@ -11,11 +11,11 @@ use std::sync::Arc; #[tokio::test] async fn test_send_function_call() { /// Model to be used, make sure it is tailored towards "function calling", such as: - /// - openhermes:latest - /// - adrienbrault/nous-hermes2pro:Q8_0 + /// - OpenAIFunctionCall: not model specific, degraded performance + /// - NousFunctionCall: adrienbrault/nous-hermes2pro:Q8_0 const MODEL: &str = "adrienbrault/nous-hermes2pro:Q8_0"; - const PROMPT: &str = ""; + const PROMPT: &str = "Aside from the Apple Remote, what other device can control the program Apple Remote was originally designed to interact with?"; let user_message = ChatMessage::user(PROMPT.to_string()); let scraper_tool = Arc::new(Scraper::new()); @@ -37,3 +37,58 @@ async fn test_send_function_call() { assert!(result.done); } + +#[tokio::test] +async fn test_send_function_call_with_history() { + /// Model to be used, make sure it is tailored towards "function calling", such as: + /// - OpenAIFunctionCall: not model specific, degraded performance + /// - NousFunctionCall: adrienbrault/nous-hermes2pro:Q8_0 + const MODEL: &str = "adrienbrault/nous-hermes2pro:Q8_0"; + + const PROMPT: &str = "Aside from the Apple Remote, what other device can control the program Apple Remote was originally designed to interact with?"; + let user_message = ChatMessage::user(PROMPT.to_string()); + + let scraper_tool = Arc::new(Scraper::new()); + let ddg_search_tool = Arc::new(DDGSearcher::new()); + let parser = Arc::new(NousFunctionCall::new()); + + let mut ollama = Ollama::new_default_with_history(30); + let result = ollama + .send_function_call_with_history( + FunctionCallRequest::new( + MODEL.to_string(), + vec![scraper_tool, ddg_search_tool], + vec![user_message], + ), + parser, + ) + .await + .unwrap(); + + assert!(result.done); +} + +#[tokio::test] +async fn test_send_function_call_finance() { + /// Model to be used, make sure it is tailored towards "function calling", such as: + /// - OpenAIFunctionCall: not model specific, degraded performance + /// - NousFunctionCall: adrienbrault/nous-hermes2pro:Q8_0 + const MODEL: &str = "adrienbrault/nous-hermes2pro:Q8_0"; + + const PROMPT: &str = "What are the current risk factors to $APPL?"; + let user_message = ChatMessage::user(PROMPT.to_string()); + + let stock_scraper = Arc::new(StockScraper::new()); + let parser = Arc::new(NousFunctionCall::new()); + + let ollama = Ollama::new_default_with_history(30); + let result = ollama + .send_function_call( + FunctionCallRequest::new(MODEL.to_string(), vec![stock_scraper], vec![user_message]), + parser, + ) + .await + .unwrap(); + + assert!(result.done); +} From c6c649c3a9333e59dde857da9e56566c91be6c72 Mon Sep 17 00:00:00 2001 From: andthattoo Date: Tue, 21 May 2024 
From c6c649c3a9333e59dde857da9e56566c91be6c72 Mon Sep 17 00:00:00 2001
From: andthattoo
Date: Tue, 21 May 2024 00:32:46 +0300
Subject: [PATCH 14/18] fixed clippy errors

---
 src/generation/functions/mod.rs                      | 12 ++++++------
 .../functions/pipelines/nous_hermes/request.rs       | 10 ++++++++--
 src/generation/functions/pipelines/openai/request.rs |  6 +++---
 src/generation/functions/tools/finance.rs            |  6 ++++++
 src/generation/functions/tools/scraper.rs            |  7 +++++++
 src/generation/functions/tools/search_ddg.rs         |  6 ++++++
 6 files changed, 36 insertions(+), 11 deletions(-)

diff --git a/src/generation/functions/mod.rs b/src/generation/functions/mod.rs
index 3999de8..d48c714 100644
--- a/src/generation/functions/mod.rs
+++ b/src/generation/functions/mod.rs
@@ -17,9 +17,9 @@ use std::sync::Arc;
 #[cfg(feature = "function-calling")]
 impl crate::Ollama {
-    fn has_system_prompt(&self, messages: &Vec<ChatMessage>, system_prompt: &str) -> bool {
+    fn has_system_prompt(&self, messages: &[ChatMessage], system_prompt: &str) -> bool {
         let system_message = messages.first().unwrap().clone();
-        return system_message.content == system_prompt;
+        system_message.content == system_prompt
     }
 
     fn has_system_prompt_history(&mut self) -> bool {
@@ -63,7 +63,7 @@
         )
         .await;
 
-        return match result {
+        match result {
             Ok(r) => {
                 self.add_assistant_response(
                     "default".to_string(),
@@ -78,7 +78,7 @@
                 );
                 Ok(e)
             }
-        };
+        }
     }
 
     pub async fn send_function_call(
@@ -102,9 +102,9 @@
         let result = parser
             .parse(&response_content, model_name, request.tools)
             .await;
-        return match result {
+        match result {
             Ok(r) => Ok(r),
             Err(e) => Ok(e),
-        };
+        }
     }
 }
diff --git a/src/generation/functions/pipelines/nous_hermes/request.rs b/src/generation/functions/pipelines/nous_hermes/request.rs
index cda684b..764ab8b 100644
--- a/src/generation/functions/pipelines/nous_hermes/request.rs
+++ b/src/generation/functions/pipelines/nous_hermes/request.rs
@@ -33,6 +33,12 @@ pub struct NousFunctionCallSignature {
 
 pub struct NousFunctionCall {}
 
+impl Default for NousFunctionCall {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
 impl NousFunctionCall {
     pub fn new() -> Self {
         Self {}
@@ -67,7 +73,7 @@
             if let Some(matched) = captures.get(1) {
                 let result = matched
                     .as_str()
-                    .replace("\n", "")
+                    .replace('\n', "")
                     .replace("{{", "{")
                     .replace("}}", "}");
                 return Some(result);
@@ -144,7 +150,7 @@
     fn error_handler(&self, error: OllamaError) -> ChatMessageResponse {
         let error_message = format!(
             "\nThere was an error parsing function calls\n Here's the error stack trace: {}\nPlease call the function again with correct syntax",
-            error.to_string()
+            error
         );
 
         ChatMessageResponse {
diff --git a/src/generation/functions/pipelines/openai/request.rs b/src/generation/functions/pipelines/openai/request.rs
index 43cdc18..e435110 100644
--- a/src/generation/functions/pipelines/openai/request.rs
+++ b/src/generation/functions/pipelines/openai/request.rs
@@ -34,7 +34,7 @@ impl OpenAIFunctionCall {
         tool: Arc<dyn Tool>,
     ) -> Result<ChatMessageResponse, ChatMessageResponse> {
         let result = tool.run(tool_params).await;
-        return match result {
+        match result {
             Ok(result) => Ok(ChatMessageResponse {
                 model: model_name.clone(),
                 created_at: "".to_string(),
@@ -43,7 +43,7 @@
                 final_data: None,
             }),
             Err(e) => Err(self.error_handler(OllamaError::from(e))),
-        };
+        }
     }
 }
 
@@ -83,7 +83,7 @@ impl RequestParserBase for OpenAIFunctionCall {
         // Corrected here to use a slice
         let tools_info: Vec<Value> = tools
             .iter()
-            .map(|tool| convert_to_ollama_tool(tool))
+            .map(convert_to_ollama_tool)
             .collect();
         let tools_json = serde_json::to_string(&tools_info).unwrap();
         let system_message_content = DEFAULT_SYSTEM_TEMPLATE.replace("{tools}", &tools_json);
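
The three Default impls in the hunks below all follow the same mechanical pattern, which clippy's new_without_default lint asks for whenever a type exposes an argument-less new(). A minimal sketch of the pattern, using a hypothetical type name:

    pub struct MyTool {}

    impl MyTool {
        pub fn new() -> Self {
            Self {}
        }
    }

    // Lets callers write MyTool::default() and silences clippy::new_without_default.
    impl Default for MyTool {
        fn default() -> Self {
            Self::new()
        }
    }
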
diff --git a/src/generation/functions/tools/finance.rs b/src/generation/functions/tools/finance.rs
index e5447e7..9161af4 100644
--- a/src/generation/functions/tools/finance.rs
+++ b/src/generation/functions/tools/finance.rs
@@ -11,6 +11,12 @@ pub struct StockScraper {
     language: String,
 }
 
+impl Default for StockScraper {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
 impl StockScraper {
     pub fn new() -> Self {
         StockScraper {
diff --git a/src/generation/functions/tools/scraper.rs b/src/generation/functions/tools/scraper.rs
index 1c7f5ac..c0e29c9 100644
--- a/src/generation/functions/tools/scraper.rs
+++ b/src/generation/functions/tools/scraper.rs
@@ -7,12 +7,19 @@ use std::error::Error;
 
 pub struct Scraper {}
 
+impl Default for Scraper {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
 impl Scraper {
     pub fn new() -> Self {
         Self {}
     }
 }
 
+
 #[async_trait]
 impl Tool for Scraper {
     fn name(&self) -> String {
diff --git a/src/generation/functions/tools/search_ddg.rs b/src/generation/functions/tools/search_ddg.rs
index c6f17be..6793bb2 100644
--- a/src/generation/functions/tools/search_ddg.rs
+++ b/src/generation/functions/tools/search_ddg.rs
@@ -20,6 +20,12 @@ pub struct DDGSearcher {
     pub base_url: String,
 }
 
+impl Default for DDGSearcher {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
 impl DDGSearcher {
     pub fn new() -> Self {
         DDGSearcher {

From ea7201df028cc24828fd2abbefb6bbd6753d04eb Mon Sep 17 00:00:00 2001
From: andthattoo
Date: Tue, 21 May 2024 22:13:57 +0300
Subject: [PATCH 15/18] cargo fmt --all

---
 src/generation/functions/pipelines/openai/request.rs | 6 +-----
 src/generation/functions/tools/scraper.rs            | 1 -
 2 files changed, 1 insertion(+), 6 deletions(-)

diff --git a/src/generation/functions/pipelines/openai/request.rs b/src/generation/functions/pipelines/openai/request.rs
index e435110..1918266 100644
--- a/src/generation/functions/pipelines/openai/request.rs
+++ b/src/generation/functions/pipelines/openai/request.rs
@@ -80,11 +80,7 @@ impl RequestParserBase for OpenAIFunctionCall {
     }
 
     async fn get_system_message(&self, tools: &[Arc<dyn Tool>]) -> ChatMessage {
-        // Corrected here to use a slice
-        let tools_info: Vec<Value> = tools
-            .iter()
-            .map(convert_to_ollama_tool)
-            .collect();
+        let tools_info: Vec<Value> = tools.iter().map(convert_to_ollama_tool).collect();
         let tools_json = serde_json::to_string(&tools_info).unwrap();
         let system_message_content = DEFAULT_SYSTEM_TEMPLATE.replace("{tools}", &tools_json);
         ChatMessage::system(system_message_content)
diff --git a/src/generation/functions/tools/scraper.rs b/src/generation/functions/tools/scraper.rs
index c0e29c9..2398e72 100644
--- a/src/generation/functions/tools/scraper.rs
+++ b/src/generation/functions/tools/scraper.rs
@@ -19,7 +19,6 @@ impl Scraper {
     }
 }
 
-
 #[async_trait]
 impl Tool for Scraper {
     fn name(&self) -> String {

From d6a2304b0db6af0bdb43d9f385fdc66e1d16597a Mon Sep 17 00:00:00 2001
From: andthattoo
Date: Wed, 22 May 2024 14:35:07 +0300
Subject: [PATCH 16/18] StockScraper added to pub use

---
 src/generation/functions/mod.rs | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/generation/functions/mod.rs b/src/generation/functions/mod.rs
index d48c714..7188d2b 100644
--- a/src/generation/functions/mod.rs
+++ b/src/generation/functions/mod.rs
@@ -7,6 +7,7 @@ pub use crate::generation::functions::pipelines::openai::request::OpenAIFunctionCall;
 pub use crate::generation::functions::request::FunctionCallRequest;
 pub use tools::DDGSearcher;
 pub use tools::Scraper;
+pub use tools::StockScraper;
 
 use crate::error::OllamaError;
 use crate::generation::chat::request::ChatMessageRequest;

From 9ad7b37e8e2f04decd50536ef7539fe99407013d Mon Sep 17 00:00:00 2001
From: andthattoo
Date: Wed, 22 May 2024 18:45:39 +0300
Subject: [PATCH 17/18] id is now a param for `send_function_call_with_history`

---
 src/generation/functions/mod.rs | 15 +++++----------
 tests/function_call.rs          |  5 +++--
 2 files changed, 8 insertions(+), 12 deletions(-)

diff --git a/src/generation/functions/mod.rs b/src/generation/functions/mod.rs
index 7188d2b..34d1a32 100644
--- a/src/generation/functions/mod.rs
+++ b/src/generation/functions/mod.rs
@@ -32,12 +32,13 @@ impl crate::Ollama {
         &mut self,
         request: FunctionCallRequest,
         parser: Arc<dyn RequestParserBase>,
+        id: String,
     ) -> Result<ChatMessageResponse, OllamaError> {
         let mut request = request;
 
         if !self.has_system_prompt_history() {
             let system_prompt = parser.get_system_message(&request.tools).await;
-            self.set_system_response("default".to_string(), system_prompt.content);
+            self.set_system_response(id.clone(), system_prompt.content);
 
             //format input
             let formatted_query = ChatMessage::user(
@@ -51,7 +52,7 @@
         let tool_call_result = self
             .send_chat_messages_with_history(
                 ChatMessageRequest::new(request.chat.model_name.clone(), request.chat.messages),
-                "default".to_string(),
+                id.clone(),
             )
             .await?;
 
@@ -66,17 +67,11 @@
         match result {
             Ok(r) => {
-                self.add_assistant_response(
-                    "default".to_string(),
-                    r.message.clone().unwrap().content,
-                );
+                self.add_assistant_response(id.clone(), r.message.clone().unwrap().content);
                 Ok(r)
             }
             Err(e) => {
-                self.add_assistant_response(
-                    "default".to_string(),
-                    e.message.clone().unwrap().content,
-                );
+                self.add_assistant_response(id.clone(), e.message.clone().unwrap().content);
                 Ok(e)
             }
         }
diff --git a/tests/function_call.rs b/tests/function_call.rs
index 6a7e377..8b8e7f7 100644
--- a/tests/function_call.rs
+++ b/tests/function_call.rs
@@ -22,7 +22,7 @@ async fn test_send_function_call() {
     let ddg_search_tool = Arc::new(DDGSearcher::new());
     let parser = Arc::new(NousFunctionCall::new());
 
-    let ollama = Ollama::new_default_with_history(30);
+    let ollama = Ollama::default();
     let result = ollama
         .send_function_call(
             FunctionCallRequest::new(
@@ -61,6 +61,7 @@
                 vec![user_message],
             ),
             parser,
+            "default".to_string(),
         )
         .await
         .unwrap();
@@ -82,7 +82,7 @@
     let stock_scraper = Arc::new(StockScraper::new());
     let parser = Arc::new(NousFunctionCall::new());
 
-    let ollama = Ollama::new_default_with_history(30);
+    let ollama = Ollama::default();
     let result = ollama
         .send_function_call(
             FunctionCallRequest::new(MODEL.to_string(), vec![stock_scraper], vec![user_message]),
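
A sketch of what the new id parameter enables: one client holding several independent conversation histories, each with its own system prompt. Model name and prompt are illustrative, and a reachable Ollama server is assumed:

    use ollama_rs::{
        generation::chat::ChatMessage,
        generation::functions::tools::DDGSearcher,
        generation::functions::{FunctionCallRequest, NousFunctionCall},
        Ollama,
    };
    use std::sync::Arc;

    #[tokio::main]
    async fn main() {
        let tool = Arc::new(DDGSearcher::new());
        let parser = Arc::new(NousFunctionCall::new());
        let mut ollama = Ollama::new_default_with_history(30);

        // Each id gets its own message history inside the same client.
        for id in ["session-a", "session-b"] {
            let message = ChatMessage::user("What is the latest stable Rust release?".to_string());
            let request = FunctionCallRequest::new(
                "adrienbrault/nous-hermes2pro:Q8_0".to_string(),
                vec![tool.clone()],
                vec![message],
            );
            let result = ollama
                .send_function_call_with_history(request, parser.clone(), id.to_string())
                .await
                .unwrap();
            assert!(result.done);
        }
    }
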
From bb7cdff9bd2987703e7a263df1012ab76418fc1d Mon Sep 17 00:00:00 2001
From: andthattoo
Date: Tue, 28 May 2024 01:58:14 +0300
Subject: [PATCH 18/18] OpenAIFunctionCall attributes changed to "name" and
 "arguments" instead of "tool" and "tool_input". Testing showed it worked
 better with different models.

---
 .../functions/pipelines/nous_hermes/request.rs |  7 +--
 .../functions/pipelines/openai/prompts.rs      |  4 +-
 .../functions/pipelines/openai/request.rs      | 45 ++++++++++++-------
 3 files changed, 34 insertions(+), 22 deletions(-)

diff --git a/src/generation/functions/pipelines/nous_hermes/request.rs b/src/generation/functions/pipelines/nous_hermes/request.rs
index 764ab8b..f1c30a8 100644
--- a/src/generation/functions/pipelines/nous_hermes/request.rs
+++ b/src/generation/functions/pipelines/nous_hermes/request.rs
@@ -10,7 +10,7 @@ use serde_json::{json, Map, Value};
 use std::collections::HashMap;
 use std::sync::Arc;
 
-pub fn convert_to_openai_tool(tool: Arc<dyn Tool>) -> Value {
+pub fn convert_to_openai_tool(tool: &Arc<dyn Tool>) -> Value {
     let mut function = HashMap::new();
     function.insert("name".to_string(), Value::String(tool.name()));
     function.insert("description".to_string(), Value::String(tool.description()));
@@ -138,10 +138,7 @@
     }
 
     async fn get_system_message(&self, tools: &[Arc<dyn Tool>]) -> ChatMessage {
-        let tools_info: Vec<Value> = tools
-            .iter()
-            .map(|tool| convert_to_openai_tool(tool.clone()))
-            .collect();
+        let tools_info: Vec<Value> = tools.iter().map(convert_to_openai_tool).collect();
         let tools_json = serde_json::to_string(&tools_info).unwrap();
         let system_message_content = DEFAULT_SYSTEM_TEMPLATE.replace("{tools}", &tools_json);
         ChatMessage::system(system_message_content)
diff --git a/src/generation/functions/pipelines/openai/prompts.rs b/src/generation/functions/pipelines/openai/prompts.rs
index e858ca9..d497aa0 100644
--- a/src/generation/functions/pipelines/openai/prompts.rs
+++ b/src/generation/functions/pipelines/openai/prompts.rs
@@ -9,8 +9,8 @@
 Don't make assumptions about what values to plug into function arguments.
 You must always select one of the above tools and respond with only a JSON object matching the following schema:
 
 {
-  "tool": <name of the selected tool>,
-  "tool_input": <parameters for the selected tool, matching the tool's JSON schema>
+  "name": <name of the selected tool>,
+  "arguments": <parameters for the selected tool, matching the tool's JSON schema>
 }
 "#;
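
With the updated template, a well-formed model reply is now plain JSON shaped like the following (values illustrative); this is the string that clean_tool_call and parse in the next diff consume. Models sometimes wrap the object in a ```json fence, which clean_tool_call strips before parsing:

    {
        "name": "Stock Scraper",
        "arguments": {
            "exchange": "NASDAQ",
            "ticker": "AAPL"
        }
    }
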
diff --git a/src/generation/functions/pipelines/openai/request.rs b/src/generation/functions/pipelines/openai/request.rs
index 1918266..4a42cd3 100644
--- a/src/generation/functions/pipelines/openai/request.rs
+++ b/src/generation/functions/pipelines/openai/request.rs
@@ -5,23 +5,29 @@ use crate::generation::functions::pipelines::RequestParserBase;
 use crate::generation::functions::tools::Tool;
 use async_trait::async_trait;
 use serde::{Deserialize, Serialize};
-use serde_json::json;
-use serde_json::Value;
+use serde_json::{json, Map, Value};
+use std::collections::HashMap;
 use std::sync::Arc;
 
-pub fn convert_to_ollama_tool(tool: &Arc<dyn Tool>) -> Value {
-    let schema = tool.parameters();
-    json!({
-        "name": tool.name(),
-        "properties": schema["properties"],
-        "required": schema["required"]
-    })
+pub fn convert_to_openai_tool(tool: &Arc<dyn Tool>) -> Value {
+    let mut function = HashMap::new();
+    function.insert("name".to_string(), Value::String(tool.name()));
+    function.insert("description".to_string(), Value::String(tool.description()));
+    function.insert("parameters".to_string(), tool.parameters());
+
+    let mut result = HashMap::new();
+    result.insert("type".to_string(), Value::String("function".to_string()));
+
+    let mapp: Map<String, Value> = function.into_iter().collect();
+    result.insert("function".to_string(), Value::Object(mapp));
+
+    json!(result)
 }
 
 #[derive(Debug, Deserialize, Serialize)]
 pub struct OpenAIFunctionCallSignature {
-    pub tool: String, //name of the tool
-    pub tool_input: Value,
+    pub name: String, //name of the tool
+    pub arguments: Value,
 }
 
 pub struct OpenAIFunctionCall {}
 
@@ -45,6 +51,15 @@ impl OpenAIFunctionCall {
             Err(e) => Err(self.error_handler(OllamaError::from(e))),
         }
     }
+
+    fn clean_tool_call(&self, json_str: &str) -> String {
+        json_str
+            .trim()
+            .trim_start_matches("```json")
+            .trim_end_matches("```")
+            .trim()
+            .to_string()
+    }
 }
 
 #[async_trait]
@@ -56,11 +71,11 @@ impl RequestParserBase for OpenAIFunctionCall {
         tools: Vec<Arc<dyn Tool>>,
     ) -> Result<ChatMessageResponse, ChatMessageResponse> {
         let response_value: Result<OpenAIFunctionCallSignature, serde_json::Error> =
-            serde_json::from_str(input);
+            serde_json::from_str(&self.clean_tool_call(input));
         match response_value {
             Ok(response) => {
-                if let Some(tool) = tools.iter().find(|t| t.name() == response.tool) {
-                    let tool_params = response.tool_input;
+                if let Some(tool) = tools.iter().find(|t| t.name() == response.name) {
+                    let tool_params = response.arguments;
                     let result = self
                         .function_call_with_history(
                             model_name.clone(),
@@ -80,7 +95,7 @@
     }
 
     async fn get_system_message(&self, tools: &[Arc<dyn Tool>]) -> ChatMessage {
-        let tools_info: Vec<Value> = tools.iter().map(convert_to_ollama_tool).collect();
+        let tools_info: Vec<Value> = tools.iter().map(convert_to_openai_tool).collect();
         let tools_json = serde_json::to_string(&tools_info).unwrap();
         let system_message_content = DEFAULT_SYSTEM_TEMPLATE.replace("{tools}", &tools_json);
         ChatMessage::system(system_message_content)