diff --git a/Cargo.lock b/Cargo.lock index 9f61ef8..dbc31fd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,6 +17,51 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" +[[package]] +name = "ahash" +version = "0.8.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" +dependencies = [ + "cfg-if", + "getrandom", + "once_cell", + "version_check", + "zerocopy", +] + +[[package]] +name = "aho-corasick" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" +dependencies = [ + "memchr", +] + +[[package]] +name = "async-trait" +version = "0.1.80" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c6fa2087f2753a7da8cc1c0dbfcf89579dd57458e36769de5ac750b4671737ca" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.59", +] + +[[package]] +name = "auto_enums" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1899bfcfd9340ceea3533ea157360ba8fa864354eccbceab58e1006ecab35393" +dependencies = [ + "derive_utils", + "proc-macro2", + "quote", + "syn 2.0.59", +] + [[package]] name = "autocfg" version = "1.1.0" @@ -62,6 +107,12 @@ version = "3.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f30e7476521f6f8af1a1c4c0b8cc94f0bee37d91763d0ca2665f299b6cd8aec" +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + [[package]] name = "bytes" version = "1.5.0" @@ -99,6 +150,78 @@ version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa" +[[package]] +name = "cssparser" +version = "0.31.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b3df4f93e5fbbe73ec01ec8d3f68bba73107993a5b1e7519273c32db9b0d5be" +dependencies = [ + "cssparser-macros", + "dtoa-short", + "itoa", + "phf 0.11.2", + "smallvec", +] + +[[package]] +name = "cssparser-macros" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13b588ba4ac1a99f7f2964d24b3d896ddc6bf847ee3855dbd4366f058cfcd331" +dependencies = [ + "quote", + "syn 2.0.59", +] + +[[package]] +name = "derive_more" +version = "0.99.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fb810d30a7c1953f91334de7244731fc3f3c10d7fe163338a35b9f640960321" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "derive_utils" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61bb5a1014ce6dfc2a378578509abe775a5aa06bff584a547555d9efdb81b926" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.59", +] + +[[package]] +name = "dtoa" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dcbb2bf8e87535c23f7a8a321e364ce21462d0ff10cb6407820e8e96dfff6653" + +[[package]] +name = "dtoa-short" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbaceec3c6e4211c79e7b1800fb9680527106beb2f9c51904a3210c03a448c74" +dependencies = [ 
+ "dtoa", +] + +[[package]] +name = "ego-tree" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a68a4904193147e0a8dec3314640e6db742afd5f6e634f428a6af230d9b3591" + +[[package]] +name = "either" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a47c1c47d2f5964e29c61246e81db715514cd532db6b5116a25ea3c03d6780a2" + [[package]] name = "errno" version = "0.3.5" @@ -145,6 +268,16 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "futf" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df420e2e84819663797d1ec6544b13c5be84629e7bb00dc960d6917db2987843" +dependencies = [ + "mac", + "new_debug_unreachable", +] + [[package]] name = "futures-channel" version = "0.3.29" @@ -174,7 +307,7 @@ checksum = "53b153fd91e4b0147f4aced87be237c98248656bb01050b96bf3ee89220a8ddb" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.59", ] [[package]] @@ -206,6 +339,24 @@ dependencies = [ "slab", ] +[[package]] +name = "fxhash" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c31b6d751ae2c7f11320402d34e41349dd1016f8d5d45e48c4312bc8625af50c" +dependencies = [ + "byteorder", +] + +[[package]] +name = "getopts" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "14dbbfd5c71d70241ecf9e6f13737f7b5ce823821063188d7e46c41d371eebd5" +dependencies = [ + "unicode-width", +] + [[package]] name = "getrandom" version = "0.2.11" @@ -223,12 +374,32 @@ version = "0.28.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6fb8d784f27acf97159b40fc4db5ecd8aa23b9ad5ef69cdd136d3bc80665f0c0" +[[package]] +name = "heck" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" + [[package]] name = "hermit-abi" version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d77f7ec81a6d05a3abb01ab6eb7590f6083d08449fe5a1c8b1e620283546ccb7" +[[package]] +name = "html5ever" +version = "0.26.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bea68cab48b8459f17cf1c944c67ddc572d272d9f2b274140f223ecb1da4a3b7" +dependencies = [ + "log", + "mac", + "markup5ever", + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "http" version = "1.1.0" @@ -357,6 +528,15 @@ version = "2.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3" +[[package]] +name = "itertools" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.9" @@ -406,6 +586,26 @@ version = "0.4.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" +[[package]] +name = "mac" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c41e0c4fef86961ac6d6f8a82609f55f31b05e4fce149ac5710e439df7619ba4" + +[[package]] +name = "markup5ever" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"7a2629bb1404f3d34c2e921f21fd34ba00b206124c81f65c50b43b6aaefeb016" +dependencies = [ + "log", + "phf 0.10.1", + "phf_codegen", + "string_cache", + "string_cache_codegen", + "tendril", +] + [[package]] name = "memchr" version = "2.6.4" @@ -456,6 +656,12 @@ dependencies = [ "tempfile", ] +[[package]] +name = "new_debug_unreachable" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "650eef8c711430f1a879fdd01d4745a7deea475becfb90269c06775983bbf086" + [[package]] name = "num_cpus" version = "1.16.0" @@ -479,11 +685,16 @@ dependencies = [ name = "ollama-rs" version = "0.1.9" dependencies = [ + "async-trait", "base64", + "log", "ollama-rs", + "regex", "reqwest", + "scraper", "serde", "serde_json", + "text-splitter", "tokio", "tokio-stream", "url", @@ -491,9 +702,9 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.18.0" +version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" +checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" [[package]] name = "openssl" @@ -518,7 +729,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.59", ] [[package]] @@ -568,6 +779,86 @@ version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" +[[package]] +name = "phf" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fabbf1ead8a5bcbc20f5f8b939ee3f5b0f6f281b6ad3468b84656b658b455259" +dependencies = [ + "phf_shared 0.10.0", +] + +[[package]] +name = "phf" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ade2d8b8f33c7333b51bcf0428d37e217e9f32192ae4772156f65063b8ce03dc" +dependencies = [ + "phf_macros", + "phf_shared 0.11.2", +] + +[[package]] +name = "phf_codegen" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fb1c3a8bc4dd4e5cfce29b44ffc14bedd2ee294559a294e2a4d4c9e9a6a13cd" +dependencies = [ + "phf_generator 0.10.0", + "phf_shared 0.10.0", +] + +[[package]] +name = "phf_generator" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d5285893bb5eb82e6aaf5d59ee909a06a16737a8970984dd7746ba9283498d6" +dependencies = [ + "phf_shared 0.10.0", + "rand", +] + +[[package]] +name = "phf_generator" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48e4cc64c2ad9ebe670cb8fd69dd50ae301650392e81c05f9bfcb2d5bdbc24b0" +dependencies = [ + "phf_shared 0.11.2", + "rand", +] + +[[package]] +name = "phf_macros" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3444646e286606587e49f3bcf1679b8cef1dc2c5ecc29ddacaffc305180d464b" +dependencies = [ + "phf_generator 0.11.2", + "phf_shared 0.11.2", + "proc-macro2", + "quote", + "syn 2.0.59", +] + +[[package]] +name = "phf_shared" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6796ad771acdc0123d2a88dc428b5e38ef24456743ddb1744ed628f9815c096" +dependencies = [ + "siphasher", +] + +[[package]] +name = "phf_shared" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"90fcb95eef784c2ac79119d1dd819e162b5da872ce6f3c3abe1e8ca1c082f72b" +dependencies = [ + "siphasher", +] + [[package]] name = "pin-project" version = "1.1.5" @@ -585,7 +876,7 @@ checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.59", ] [[package]] @@ -606,6 +897,18 @@ version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "26072860ba924cbfa98ea39c8c19b4dd6a4a25423dbdf219c1eca91aa0cf6964" +[[package]] +name = "ppv-lite86" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" + +[[package]] +name = "precomputed-hash" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "925383efa346730478fb4838dbe9137d2a47675ad789c546d150a6e1dd4ab31c" + [[package]] name = "proc-macro2" version = "1.0.81" @@ -624,6 +927,36 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] + [[package]] name = "redox_syscall" version = "0.4.1" @@ -633,6 +966,35 @@ dependencies = [ "bitflags 1.3.2", ] +[[package]] +name = "regex" +version = "1.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c117dbdfde9c8308975b6a18d71f3f385c89461f7b3fb054288ecf2a2058ba4c" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "86b83b8b9847f9bf95ef68afb0b8e6cdb80f498442f5179a29fad448fcc1eaea" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adad44e29e4c806119491a7f06f03de4d1af22c3a680dd47f1e6e179439d1f56" + [[package]] name = "reqwest" version = "0.12.4" @@ -753,6 +1115,12 @@ dependencies = [ "untrusted", ] +[[package]] +name = "rustversion" +version = "1.0.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "092474d1a01ea8278f69e6a358998405fae5b8b963ddaeb2b0b04a128bf1dfb0" + [[package]] name = "ryu" version = "1.0.15" @@ -774,6 +1142,22 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "scraper" +version = "0.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b80b33679ff7a0ea53d37f3b39de77ea0c75b12c5805ac43ec0c33b3051af1b" +dependencies = [ + "ahash", + "cssparser", + "ego-tree", + "getopts", + "html5ever", + "once_cell", + "selectors", + "tendril", +] + 
[[package]] name = "security-framework" version = "2.9.2" @@ -797,6 +1181,25 @@ dependencies = [ "libc", ] +[[package]] +name = "selectors" +version = "0.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4eb30575f3638fc8f6815f448d50cb1a2e255b0897985c8c59f4d37b72a07b06" +dependencies = [ + "bitflags 2.4.1", + "cssparser", + "derive_more", + "fxhash", + "log", + "new_debug_unreachable", + "phf 0.10.1", + "phf_codegen", + "precomputed-hash", + "servo_arc", + "smallvec", +] + [[package]] name = "serde" version = "1.0.198" @@ -814,7 +1217,7 @@ checksum = "e88edab869b01783ba905e7d0153f9fc1a6505a96e4ad3018011eedb838566d9" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.59", ] [[package]] @@ -840,6 +1243,15 @@ dependencies = [ "serde", ] +[[package]] +name = "servo_arc" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d036d71a959e00c77a63538b90a6c2390969f9772b096ea837205c6bd0491a44" +dependencies = [ + "stable_deref_trait", +] + [[package]] name = "signal-hook-registry" version = "1.4.1" @@ -849,6 +1261,12 @@ dependencies = [ "libc", ] +[[package]] +name = "siphasher" +version = "0.3.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38b58827f4464d87d377d175e90bf58eb00fd8716ff0a62f80356b5e61555d0d" + [[package]] name = "slab" version = "0.4.9" @@ -880,12 +1298,77 @@ version = "0.9.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" +[[package]] +name = "stable_deref_trait" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" + +[[package]] +name = "string_cache" +version = "0.8.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f91138e76242f575eb1d3b38b4f1362f10d3a43f47d182a5b359af488a02293b" +dependencies = [ + "new_debug_unreachable", + "once_cell", + "parking_lot", + "phf_shared 0.10.0", + "precomputed-hash", + "serde", +] + +[[package]] +name = "string_cache_codegen" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6bb30289b722be4ff74a408c3cc27edeaad656e06cb1fe8fa9231fa59c728988" +dependencies = [ + "phf_generator 0.10.0", + "phf_shared 0.10.0", + "proc-macro2", + "quote", +] + +[[package]] +name = "strum" +version = "0.26.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d8cec3501a5194c432b2b7976db6b7d10ec95c253208b45f83f7136aa985e29" +dependencies = [ + "strum_macros", +] + +[[package]] +name = "strum_macros" +version = "0.26.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c6cf59daf282c0a494ba14fd21610a0325f9f90ec9d1231dea26bcb1d696c946" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "rustversion", + "syn 2.0.59", +] + [[package]] name = "subtle" version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc" +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + [[package]] name = "syn" version = "2.0.59" @@ -916,6 +1399,54 @@ dependencies = [ "windows-sys", ] +[[package]] 
+name = "tendril" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d24a120c5fc464a3458240ee02c299ebcb9d67b5249c8848b09d639dca8d7bb0" +dependencies = [ + "futf", + "mac", + "utf-8", +] + +[[package]] +name = "text-splitter" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3634ff66852bfbf7e8e987735bac08168daa5d42dc39a0df7d05fc83eaa3fe4" +dependencies = [ + "ahash", + "auto_enums", + "either", + "itertools", + "once_cell", + "regex", + "strum", + "thiserror", + "unicode-segmentation", +] + +[[package]] +name = "thiserror" +version = "1.0.60" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "579e9083ca58dd9dcf91a9923bb9054071b9ebbd800b342194c9feb0ee89fc18" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.60" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2470041c06ec3ac1ab38d0356a6119054dedaea53e12fbefc0de730a1c08524" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.59", +] + [[package]] name = "tinyvec" version = "1.6.0" @@ -958,7 +1489,7 @@ checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.59", ] [[package]] @@ -1082,6 +1613,18 @@ dependencies = [ "tinyvec", ] +[[package]] +name = "unicode-segmentation" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d4c87d22b6e3f4a18d4d40ef354e97c90fcb14dd91d7dc0aa9d8a1172ebf7202" + +[[package]] +name = "unicode-width" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68f5e5f3158ecfd4b8ff6fe086db7c8467a2dfdac97fe420f2b7c4aa97af66d6" + [[package]] name = "untrusted" version = "0.9.0" @@ -1099,12 +1642,24 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "utf-8" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" + [[package]] name = "vcpkg" version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" +[[package]] +name = "version_check" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" + [[package]] name = "want" version = "0.3.1" @@ -1141,7 +1696,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn", + "syn 2.0.59", "wasm-bindgen-shared", ] @@ -1175,7 +1730,7 @@ checksum = "c5353b8dab669f5e10f5bd76df26a9360c748f054f862ff5f3f8aae0c7fb3907" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.59", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -1294,6 +1849,26 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "zerocopy" +version = "0.7.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae87e3fcd617500e5d106f0380cf7b77f3c6092aae37191433159dda23cfb087" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.7.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15e934569e47891f7d9411f1a451d947a60e000ab3bd24fbb970f000387d1b3b" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.59", +] + [[package]] name = "zeroize" version = "1.7.0" diff --git a/Cargo.toml 
b/Cargo.toml index 880a015..cca9f52 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,18 +10,24 @@ readme = "README.md" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -reqwest = { version = "0.12.4", default-features = false } +reqwest = { version = "0.12.4", default-features = false, features = ["json"] } serde = { version = "1", features = ["derive"] } serde_json = "1" tokio = { version = "1", features = ["full"], optional = true } tokio-stream = { version = "0.1.15", optional = true } +async-trait = { version = "0.1.73" } url = "2" +log = "0.4" +scraper = { version = "0.19.0", optional = true } +text-splitter = { version = "0.13.1", optional = true } +regex = { version = "1.9.3", optional = true } [features] default = ["reqwest/default-tls"] stream = ["tokio-stream", "reqwest/stream", "tokio"] rustls = ["reqwest/rustls-tls"] chat-history = [] +function-calling = ["scraper", "text-splitter", "regex", "chat-history"] [dev-dependencies] tokio = { version = "1", features = ["full"] } diff --git a/README.md b/README.md index c768403..f55e3c0 100644 --- a/README.md +++ b/README.md @@ -147,3 +147,21 @@ let res = ollama.generate_embeddings("llama2:latest".to_string(), prompt, None). ``` _Returns a `GenerateEmbeddingsResponse` struct containing the embeddings (a vector of floats)._ + +### Make a function call + +```rust +let tools = vec![Arc::new(Scraper::new())]; +let parser = Arc::new(NousFunctionCall::new()); +let message = ChatMessage::user("What is the current oil price?".to_string()); +let res = ollama.send_function_call( + FunctionCallRequest::new( + "adrienbrault/nous-hermes2pro:Q8_0".to_string(), + tools, + vec![message], + ), + parser, + ).await.unwrap(); +``` + +_Uses the given tools (such as searching the web) to find an answer and returns a `ChatMessageResponse` containing the answer to the question._ diff --git a/src/error.rs b/src/error.rs index 66f4e8c..436bc9d 100644 --- a/src/error.rs +++ b/src/error.rs @@ -34,3 +34,19 @@ impl From<String> for OllamaError { Self { message } } } + +impl From<Box<dyn std::error::Error>> for OllamaError { + fn from(error: Box<dyn std::error::Error>) -> Self { + Self { + message: error.to_string(), + } + } +} + +impl From<serde_json::Error> for OllamaError { + fn from(error: serde_json::Error) -> Self { + Self { + message: error.to_string(), + } + } +} diff --git a/src/generation.rs b/src/generation.rs index e739bcd..a29c4ae 100644 --- a/src/generation.rs +++ b/src/generation.rs @@ -1,6 +1,8 @@ pub mod chat; pub mod completion; pub mod embeddings; +#[cfg(feature = "function-calling")] +pub mod functions; pub mod images; pub mod options; pub mod parameters; diff --git a/src/generation/chat/mod.rs b/src/generation/chat/mod.rs index 32fb337..6f9e437 100644 --- a/src/generation/chat/mod.rs +++ b/src/generation/chat/mod.rs @@ -1,12 +1,9 @@ use serde::{Deserialize, Serialize}; use crate::Ollama; - pub mod request; - -use request::ChatMessageRequest; - use super::images::Image; +use request::ChatMessageRequest; #[cfg(feature = "chat-history")] use crate::history::MessagesHistory; diff --git a/src/generation/functions/mod.rs b/src/generation/functions/mod.rs new file mode 100644 index 0000000..34d1a32 --- /dev/null +++ b/src/generation/functions/mod.rs @@ -0,0 +1,106 @@ +pub mod pipelines; +pub mod request; +pub mod tools; + +pub use crate::generation::functions::pipelines::nous_hermes::request::NousFunctionCall; +pub use crate::generation::functions::pipelines::openai::request::OpenAIFunctionCall; +pub use crate::generation::functions::request::FunctionCallRequest; +pub use
tools::DDGSearcher; +pub use tools::Scraper; +pub use tools::StockScraper; + +use crate::error::OllamaError; +use crate::generation::chat::request::ChatMessageRequest; +use crate::generation::chat::{ChatMessage, ChatMessageResponse}; +use crate::generation::functions::pipelines::RequestParserBase; +use crate::generation::functions::tools::Tool; +use std::sync::Arc; + +#[cfg(feature = "function-calling")] +impl crate::Ollama { + fn has_system_prompt(&self, messages: &[ChatMessage], system_prompt: &str) -> bool { + let system_message = messages.first().unwrap().clone(); + system_message.content == system_prompt + } + + fn has_system_prompt_history(&mut self) -> bool { + return self.get_messages_history("default".to_string()).is_some(); + } + + #[cfg(feature = "chat-history")] + pub async fn send_function_call_with_history( + &mut self, + request: FunctionCallRequest, + parser: Arc, + id: String, + ) -> Result { + let mut request = request; + + if !self.has_system_prompt_history() { + let system_prompt = parser.get_system_message(&request.tools).await; + self.set_system_response(id.clone(), system_prompt.content); + + //format input + let formatted_query = ChatMessage::user( + parser.format_query(&request.chat.messages.first().unwrap().content), + ); + //replace with formatted_query with previous chat_message + request.chat.messages.remove(0); + request.chat.messages.insert(0, formatted_query); + } + + let tool_call_result = self + .send_chat_messages_with_history( + ChatMessageRequest::new(request.chat.model_name.clone(), request.chat.messages), + id.clone(), + ) + .await?; + + let tool_call_content: String = tool_call_result.message.clone().unwrap().content; + let result = parser + .parse( + &tool_call_content, + request.chat.model_name.clone(), + request.tools, + ) + .await; + + match result { + Ok(r) => { + self.add_assistant_response(id.clone(), r.message.clone().unwrap().content); + Ok(r) + } + Err(e) => { + self.add_assistant_response(id.clone(), e.message.clone().unwrap().content); + Ok(e) + } + } + } + + pub async fn send_function_call( + &self, + request: FunctionCallRequest, + parser: Arc, + ) -> Result { + let mut request = request; + + request.chat.stream = false; + let system_prompt = parser.get_system_message(&request.tools).await; + let model_name = request.chat.model_name.clone(); + + //Make sure the first message in chat is the system prompt + if !self.has_system_prompt(&request.chat.messages, &system_prompt.content) { + request.chat.messages.insert(0, system_prompt); + } + let result = self.send_chat_messages(request.chat).await?; + let response_content: String = result.message.clone().unwrap().content; + + let result = parser + .parse(&response_content, model_name, request.tools) + .await; + match result { + Ok(r) => Ok(r), + Err(e) => Ok(e), + } + } +} diff --git a/src/generation/functions/pipelines/mod.rs b/src/generation/functions/pipelines/mod.rs new file mode 100644 index 0000000..057e347 --- /dev/null +++ b/src/generation/functions/pipelines/mod.rs @@ -0,0 +1,26 @@ +use crate::error::OllamaError; +use crate::generation::chat::{ChatMessage, ChatMessageResponse}; +use crate::generation::functions::tools::Tool; +use async_trait::async_trait; +use std::sync::Arc; + +pub mod nous_hermes; +pub mod openai; + +#[async_trait] +pub trait RequestParserBase { + async fn parse( + &self, + input: &str, + model_name: String, + tools: Vec>, + ) -> Result; + fn format_query(&self, input: &str) -> String { + input.to_string() + } + fn format_response(&self, response: &str) -> String { 
+ response.to_string() + } + async fn get_system_message(&self, tools: &[Arc<dyn Tool>]) -> ChatMessage; + fn error_handler(&self, error: OllamaError) -> ChatMessageResponse; +} diff --git a/src/generation/functions/pipelines/nous_hermes/mod.rs b/src/generation/functions/pipelines/nous_hermes/mod.rs new file mode 100644 index 0000000..23c05af --- /dev/null +++ b/src/generation/functions/pipelines/nous_hermes/mod.rs @@ -0,0 +1,4 @@ +pub mod prompts; +pub mod request; + +pub use prompts::DEFAULT_SYSTEM_TEMPLATE; diff --git a/src/generation/functions/pipelines/nous_hermes/prompts.rs b/src/generation/functions/pipelines/nous_hermes/prompts.rs new file mode 100644 index 0000000..fc0080b --- /dev/null +++ b/src/generation/functions/pipelines/nous_hermes/prompts.rs @@ -0,0 +1,67 @@ +pub const DEFAULT_SYSTEM_TEMPLATE: &str = r#" +Role: | + You are a function calling AI agent with self-recursion. + You can call only one function at a time and analyse data you get from function response. + You are provided with function signatures within <tools></tools> XML tags. + The current date is: {date}. +Objective: | + You may use agentic frameworks for reasoning and planning to help with user query. + Please call a function and wait for function results to be provided to you in the next iteration. + Don't make assumptions about what values to plug into function arguments. + Once you have called a function, results will be fed back to you within <tool_response></tool_response> XML tags. + Don't make assumptions about tool results if <tool_response> XML tags are not present since function hasn't been executed yet. + Analyze the data once you get the results and call another function. + At each iteration please continue adding your analysis to the previous summary. + Your final response should directly answer the user query with an analysis or summary of the results of function calls. +Tools: | + Here are the available tools: + {tools} + If the provided function signatures don't include the function you must call, you may write executable rust code in markdown syntax and call code_interpreter() function as follows: + <tool_call> + {"arguments": {"code_markdown": <rust-code>, "name": "code_interpreter"}} + </tool_call> + Make sure that the json object above with code markdown block is parseable with json.loads() and the XML block with XML ElementTree. +Examples: | + Here are some example usages of functions: + [ + { + "example": "```\nSYSTEM: You are a helpful assistant who has access to functions.
Use them if required\n[\n {\n \"name\": \"calculate_distance\",\n \"description\": \"Calculate the distance between two locations\",\n \"parameters\": {\n \"type\": \"object\",\n \"properties\": {\n \"origin\": {\n \"type\": \"string\",\n \"description\": \"The starting location\"\n },\n \"destination\": {\n \"type\": \"string\",\n \"description\": \"The destination location\"\n },\n \"mode\": {\n \"type\": \"string\",\n \"description\": \"The mode of transportation\"\n }\n },\n \"required\": [\n \"origin\",\n \"destination\",\n \"mode\"\n ]\n }\n },\n {\n \"name\": \"generate_password\",\n \"description\": \"Generate a random password\",\n \"parameters\": {\n \"type\": \"object\",\n \"properties\": {\n \"length\": {\n \"type\": \"integer\",\n \"description\": \"The length of the password\"\n }\n },\n \"required\": [\n \"length\"\n ]\n }\n }\n]\n\n\nUSER: Hi, I need to know the distance from New York to Los Angeles by car.\nASSISTANT:\n\n{\"arguments\": {\"origin\": \"New York\",\n \"destination\": \"Los Angeles\", \"mode\": \"car\"}, \"name\": \"calculate_distance\"}\n\n```\n" + }, + { + "example": "```\nSYSTEM: You are a helpful assistant with access to functions. Use them if required\n[\n {\n \"name\": \"calculate_distance\",\n \"description\": \"Calculate the distance between two locations\",\n \"parameters\": {\n \"type\": \"object\",\n \"properties\": {\n \"origin\": {\n \"type\": \"string\",\n \"description\": \"The starting location\"\n },\n \"destination\": {\n \"type\": \"string\",\n \"description\": \"The destination location\"\n },\n \"mode\": {\n \"type\": \"string\",\n \"description\": \"The mode of transportation\"\n }\n },\n \"required\": [\n \"origin\",\n \"destination\",\n \"mode\"\n ]\n }\n },\n {\n \"name\": \"generate_password\",\n \"description\": \"Generate a random password\",\n \"parameters\": {\n \"type\": \"object\",\n \"properties\": {\n \"length\": {\n \"type\": \"integer\",\n \"description\": \"The length of the password\"\n }\n },\n \"required\": [\n \"length\"\n ]\n }\n }\n]\n\n\nUSER: Can you help me generate a random password with a length of 8 characters?\nASSISTANT:\n\n{\"arguments\": {\"length\": 8}, \"name\": \"generate_password\"}\n\n```" + } +] +Schema: | + Use the following pydantic model json schema for each tool call you will make: + { + "name": "tool name", + "description": "tool description", + "parameters": { + "type": "object", + "properties": { + "parameter1": { + "type": "string", + "description": "parameter description" + }, + "parameter2": { + "type": "string", + "description": "parameter description" + } + }, + "required": [ + "parameter1", + "parameter2" + ] + } + } +Instructions: | + At the very first turn you don't have so you shouldn't not make up the results. + Please keep a running summary with analysis of previous function results and summaries from previous iterations. + Do not stop calling functions until the task has been accomplished or you've reached max iteration of 10. + Calling multiple functions at once can overload the system and increase cost so call one function at a time please. + If you plan to continue with analysis, always call another function. 
+ For each function call return a valid json object (using double quotes) with function name and arguments within <tool_call></tool_call> XML tags as follows: + <tool_call> + {"arguments": <args-dict>, "name": <function-name>} + </tool_call> +"#; diff --git a/src/generation/functions/pipelines/nous_hermes/request.rs b/src/generation/functions/pipelines/nous_hermes/request.rs new file mode 100644 index 0000000..f1c30a8 --- /dev/null +++ b/src/generation/functions/pipelines/nous_hermes/request.rs @@ -0,0 +1,161 @@ +use crate::error::OllamaError; +use crate::generation::chat::{ChatMessage, ChatMessageResponse}; +use crate::generation::functions::pipelines::nous_hermes::DEFAULT_SYSTEM_TEMPLATE; +use crate::generation::functions::pipelines::RequestParserBase; +use crate::generation::functions::tools::Tool; +use async_trait::async_trait; +use regex::Regex; +use serde::{Deserialize, Serialize}; +use serde_json::{json, Map, Value}; +use std::collections::HashMap; +use std::sync::Arc; + +pub fn convert_to_openai_tool(tool: &Arc<dyn Tool>) -> Value { + let mut function = HashMap::new(); + function.insert("name".to_string(), Value::String(tool.name())); + function.insert("description".to_string(), Value::String(tool.description())); + function.insert("parameters".to_string(), tool.parameters()); + + let mut result = HashMap::new(); + result.insert("type".to_string(), Value::String("function".to_string())); + + let mapp: Map<String, Value> = function.into_iter().collect(); + result.insert("function".to_string(), Value::Object(mapp)); + + json!(result) +} + +#[derive(Debug, Deserialize, Serialize)] +pub struct NousFunctionCallSignature { + pub name: String, + pub arguments: Value, +} + +pub struct NousFunctionCall {} + +impl Default for NousFunctionCall { + fn default() -> Self { + Self::new() + } +} + +impl NousFunctionCall { + pub fn new() -> Self { + Self {} + } + + pub async fn function_call_with_history( + &self, + model_name: String, + tool_params: Value, + tool: Arc<dyn Tool>, + ) -> Result<ChatMessageResponse, ChatMessageResponse> { + let result = tool.run(tool_params).await; + match result { + Ok(result) => Ok(ChatMessageResponse { + model: model_name.clone(), + created_at: "".to_string(), + message: Some(ChatMessage::assistant(self.format_tool_response(&result))), + done: true, + final_data: None, + }), + Err(e) => Err(self.error_handler(OllamaError::from(e))), + } + } + + pub fn format_tool_response(&self, function_response: &str) -> String { + format!("<tool_response>\n{}\n</tool_response>\n", function_response) + } + + pub fn extract_tool_call(&self, content: &str) -> Option<String> { + let re = Regex::new(r"(?s)<tool_call>(.*?)</tool_call>").unwrap(); + if let Some(captures) = re.captures(content) { + if let Some(matched) = captures.get(1) { + let result = matched + .as_str() + .replace('\n', "") + .replace("{{", "{") + .replace("}}", "}"); + return Some(result); + } + } + None + } +} + +#[async_trait] +impl RequestParserBase for NousFunctionCall { + async fn parse( + &self, + input: &str, + model_name: String, + tools: Vec<Arc<dyn Tool>>, + ) -> Result<ChatMessageResponse, ChatMessageResponse> { + //Extract between <tool_call> and </tool_call> + let tool_response = self.extract_tool_call(input); + match tool_response { + Some(tool_response_str) => { + let response_value: Result<NousFunctionCallSignature, serde_json::Error> = + serde_json::from_str(&tool_response_str); + match response_value { + Ok(response) => { + if let Some(tool) = tools.iter().find(|t| t.name() == response.name) { + let tool_params = response.arguments; + let result = self + .function_call_with_history( + model_name.clone(), + tool_params.clone(), + tool.clone(), + ) + .await?; //Error is also returned as String for LLM feedback + return Ok(result); + } else { + return Err(self.error_handler(OllamaError::from( + "Tool name not found".to_string(), +
))); + } + } + Err(e) => return Err(self.error_handler(OllamaError::from(e))), + } + } + None => { + return Err(self.error_handler(OllamaError::from( + "Error while extracting tags.".to_string(), + ))) + } + } + } + + fn format_query(&self, input: &str) -> String { + format!( + "{}\nThis is the first turn and you don't have to analyze yet", + input + ) + } + + fn format_response(&self, response: &str) -> String { + format!("Agent iteration to assist with user query: {}", response) + } + + async fn get_system_message(&self, tools: &[Arc]) -> ChatMessage { + let tools_info: Vec = tools.iter().map(convert_to_openai_tool).collect(); + let tools_json = serde_json::to_string(&tools_info).unwrap(); + let system_message_content = DEFAULT_SYSTEM_TEMPLATE.replace("{tools}", &tools_json); + ChatMessage::system(system_message_content) + } + + fn error_handler(&self, error: OllamaError) -> ChatMessageResponse { + let error_message = format!( + "\nThere was an error parsing function calls\n Here's the error stack trace: {}\nPlease call the function again with correct syntax", + error + ); + + ChatMessageResponse { + model: "".to_string(), + created_at: "".to_string(), + message: Some(ChatMessage::assistant(error_message)), + done: true, + final_data: None, + } + } +} diff --git a/src/generation/functions/pipelines/openai/mod.rs b/src/generation/functions/pipelines/openai/mod.rs new file mode 100644 index 0000000..1c37516 --- /dev/null +++ b/src/generation/functions/pipelines/openai/mod.rs @@ -0,0 +1,4 @@ +pub mod prompts; +pub mod request; + +pub use prompts::{DEFAULT_RESPONSE_FUNCTION, DEFAULT_SYSTEM_TEMPLATE}; diff --git a/src/generation/functions/pipelines/openai/prompts.rs b/src/generation/functions/pipelines/openai/prompts.rs new file mode 100644 index 0000000..d497aa0 --- /dev/null +++ b/src/generation/functions/pipelines/openai/prompts.rs @@ -0,0 +1,32 @@ +pub const DEFAULT_SYSTEM_TEMPLATE: &str = r#" +You are a function calling AI agent with self-recursion. +You can call only one function at a time and analyse data you get from function response. +You have access to the following tools: + +{tools} + +Don't make assumptions about what values to plug into function arguments. +You must always select one of the above tools and respond with only a JSON object matching the following schema: + +{ + "name": , + "arguments": +} +"#; + +pub const DEFAULT_RESPONSE_FUNCTION: &str = r#" +{ + "name": "__conversational_response", + "description": "Respond conversationally if no other tools should be called for a given query.", + "parameters": { + "type": "object", + "properties": { + "response": { + "type": "string", + "description": "Conversational response to the user." 
+ } + }, + "required": ["response"] + } +} +"#; diff --git a/src/generation/functions/pipelines/openai/request.rs b/src/generation/functions/pipelines/openai/request.rs new file mode 100644 index 0000000..4a42cd3 --- /dev/null +++ b/src/generation/functions/pipelines/openai/request.rs @@ -0,0 +1,113 @@ +use crate::error::OllamaError; +use crate::generation::chat::{ChatMessage, ChatMessageResponse}; +use crate::generation::functions::pipelines::openai::DEFAULT_SYSTEM_TEMPLATE; +use crate::generation::functions::pipelines::RequestParserBase; +use crate::generation::functions::tools::Tool; +use async_trait::async_trait; +use serde::{Deserialize, Serialize}; +use serde_json::{json, Map, Value}; +use std::collections::HashMap; +use std::sync::Arc; + +pub fn convert_to_openai_tool(tool: &Arc) -> Value { + let mut function = HashMap::new(); + function.insert("name".to_string(), Value::String(tool.name())); + function.insert("description".to_string(), Value::String(tool.description())); + function.insert("parameters".to_string(), tool.parameters()); + + let mut result = HashMap::new(); + result.insert("type".to_string(), Value::String("function".to_string())); + + let mapp: Map = function.into_iter().collect(); + result.insert("function".to_string(), Value::Object(mapp)); + + json!(result) +} + +#[derive(Debug, Deserialize, Serialize)] +pub struct OpenAIFunctionCallSignature { + pub name: String, //name of the tool + pub arguments: Value, +} + +pub struct OpenAIFunctionCall {} + +impl OpenAIFunctionCall { + pub async fn function_call_with_history( + &self, + model_name: String, + tool_params: Value, + tool: Arc, + ) -> Result { + let result = tool.run(tool_params).await; + match result { + Ok(result) => Ok(ChatMessageResponse { + model: model_name.clone(), + created_at: "".to_string(), + message: Some(ChatMessage::assistant(result.to_string())), + done: true, + final_data: None, + }), + Err(e) => Err(self.error_handler(OllamaError::from(e))), + } + } + + fn clean_tool_call(&self, json_str: &str) -> String { + json_str + .trim() + .trim_start_matches("```json") + .trim_end_matches("```") + .trim() + .to_string() + } +} + +#[async_trait] +impl RequestParserBase for OpenAIFunctionCall { + async fn parse( + &self, + input: &str, + model_name: String, + tools: Vec>, + ) -> Result { + let response_value: Result = + serde_json::from_str(&self.clean_tool_call(input)); + match response_value { + Ok(response) => { + if let Some(tool) = tools.iter().find(|t| t.name() == response.name) { + let tool_params = response.arguments; + let result = self + .function_call_with_history( + model_name.clone(), + tool_params.clone(), + tool.clone(), + ) + .await?; + return Ok(result); + } else { + return Err(self.error_handler(OllamaError::from("Tool not found".to_string()))); + } + } + Err(e) => { + return Err(self.error_handler(OllamaError::from(e))); + } + } + } + + async fn get_system_message(&self, tools: &[Arc]) -> ChatMessage { + let tools_info: Vec = tools.iter().map(convert_to_openai_tool).collect(); + let tools_json = serde_json::to_string(&tools_info).unwrap(); + let system_message_content = DEFAULT_SYSTEM_TEMPLATE.replace("{tools}", &tools_json); + ChatMessage::system(system_message_content) + } + + fn error_handler(&self, error: OllamaError) -> ChatMessageResponse { + ChatMessageResponse { + model: "".to_string(), + created_at: "".to_string(), + message: Some(ChatMessage::assistant(error.to_string())), + done: true, + final_data: None, + } + } +} diff --git a/src/generation/functions/request.rs 
b/src/generation/functions/request.rs new file mode 100644 index 0000000..44da42e --- /dev/null +++ b/src/generation/functions/request.rs @@ -0,0 +1,36 @@ +use crate::generation::chat::request::ChatMessageRequest; +use crate::generation::chat::ChatMessage; +use crate::generation::functions::Tool; +use crate::generation::{options::GenerationOptions, parameters::FormatType}; +use std::sync::Arc; + +#[derive(Clone)] +pub struct FunctionCallRequest { + pub chat: ChatMessageRequest, + pub tools: Vec>, +} + +impl FunctionCallRequest { + pub fn new(model_name: String, tools: Vec>, messages: Vec) -> Self { + let chat = ChatMessageRequest::new(model_name, messages); + Self { chat, tools } + } + + /// Additional model parameters listed in the documentation for the Modelfile + pub fn options(mut self, options: GenerationOptions) -> Self { + self.chat.options = Some(options); + self + } + + /// The full prompt or prompt template (overrides what is defined in the Modelfile) + pub fn template(mut self, template: String) -> Self { + self.chat.template = Some(template); + self + } + + // The format to return a response in. Currently the only accepted value is `json` + pub fn format(mut self, format: FormatType) -> Self { + self.chat.format = Some(format); + self + } +} diff --git a/src/generation/functions/tools/finance.rs b/src/generation/functions/tools/finance.rs new file mode 100644 index 0000000..9161af4 --- /dev/null +++ b/src/generation/functions/tools/finance.rs @@ -0,0 +1,98 @@ +use crate::generation::functions::tools::Tool; +use async_trait::async_trait; +use reqwest::Client; +use scraper::{Html, Selector}; +use serde_json::{json, Value}; +use std::collections::HashMap; +use std::error::Error; + +pub struct StockScraper { + base_url: String, + language: String, +} + +impl Default for StockScraper { + fn default() -> Self { + Self::new() + } +} + +impl StockScraper { + pub fn new() -> Self { + StockScraper { + base_url: "https://www.google.com/finance".to_string(), + language: "en".to_string(), + } + } + + // Changed to an async function + pub async fn scrape( + &self, + exchange: &str, + ticker: &str, + ) -> Result, Box> { + let target_url = format!( + "{}/quote/{}:{}?hl={}", + self.base_url, ticker, exchange, self.language + ); + let client = Client::new(); + let response = client.get(&target_url).send().await?; // Make the request asynchronously + let content = response.text().await?; // Asynchronously get the text of the response + let document = Html::parse_document(&content); + + let items_selector = Selector::parse("div.gyFHrc").unwrap(); + let desc_selector = Selector::parse("div.mfs7Fc").unwrap(); + let value_selector = Selector::parse("div.P6K39c").unwrap(); + + let mut stock_description = HashMap::new(); + + for item in document.select(&items_selector) { + if let Some(item_description) = item.select(&desc_selector).next() { + if let Some(item_value) = item.select(&value_selector).next() { + stock_description.insert( + item_description.text().collect::>().join(""), + item_value.text().collect::>().join(""), + ); + } + } + } + + Ok(stock_description) + } +} + +#[async_trait] +impl Tool for StockScraper { + fn name(&self) -> String { + "Stock Scraper".to_string() + } + + fn description(&self) -> String { + "Scrapes stock information from Google Finance.".to_string() + } + + fn parameters(&self) -> Value { + json!({ + "type": "object", + "properties": { + "exchange": { + "type": "string", + "description": "The stock exchange market identifier code (MIC)" + }, + "ticker": { + "type": 
"string", + "description": "The ticker symbol of the stock" + } + }, + "required": ["exchange", "ticker"] + }) + } + + async fn run(&self, input: Value) -> Result> { + let exchange = input["exchange"].as_str().ok_or("Exchange is required")?; + let ticker = input["ticker"].as_str().ok_or("Ticker is required")?; + + let result = self.scrape(exchange, ticker).await?; + Ok(serde_json::to_string(&result)?) + } +} diff --git a/src/generation/functions/tools/mod.rs b/src/generation/functions/tools/mod.rs new file mode 100644 index 0000000..b5f4223 --- /dev/null +++ b/src/generation/functions/tools/mod.rs @@ -0,0 +1,59 @@ +pub mod finance; +pub mod scraper; +pub mod search_ddg; + +pub use self::finance::StockScraper; +pub use self::scraper::Scraper; +pub use self::search_ddg::DDGSearcher; + +use async_trait::async_trait; +use serde_json::{json, Value}; +use std::error::Error; +use std::string::String; + +#[async_trait] +pub trait Tool: Send + Sync { + /// Returns the name of the tool. + fn name(&self) -> String; + + /// Provides a description of what the tool does and when to use it. + fn description(&self) -> String; + + /// Returns the parameters for OpenAI-like function call. + fn parameters(&self) -> Value { + json!({ + "type": "object", + "properties": { + "input": { + "type": "string", + "description": self.description() + } + }, + "required": ["input"] + }) + } + + /// Processes an input string and executes the tool's functionality, returning a `Result`. + async fn call(&self, input: &str) -> Result> { + let input = self.parse_input(input).await; + self.run(input).await + } + + /// Executes the core functionality of the tool. + async fn run(&self, input: Value) -> Result>; + + /// Parses the input string. + async fn parse_input(&self, input: &str) -> Value { + log::info!("Using default implementation: {}", input); + match serde_json::from_str::(input) { + Ok(input) => { + if input["input"].is_string() { + Value::String(input["input"].as_str().unwrap().to_string()) + } else { + Value::String(input.to_string()) + } + } + Err(_) => Value::String(input.to_string()), + } + } +} diff --git a/src/generation/functions/tools/scraper.rs b/src/generation/functions/tools/scraper.rs new file mode 100644 index 0000000..2398e72 --- /dev/null +++ b/src/generation/functions/tools/scraper.rs @@ -0,0 +1,63 @@ +use crate::generation::functions::tools::Tool; +use async_trait::async_trait; +use reqwest::Client; +use scraper::{Html, Selector}; +use serde_json::{json, Value}; +use std::error::Error; + +pub struct Scraper {} + +impl Default for Scraper { + fn default() -> Self { + Self::new() + } +} + +impl Scraper { + pub fn new() -> Self { + Self {} + } +} + +#[async_trait] +impl Tool for Scraper { + fn name(&self) -> String { + "Website Scraper".to_string() + } + + fn description(&self) -> String { + "Scrapes text content from websites and splits it into manageable chunks.".to_string() + } + + fn parameters(&self) -> Value { + json!({ + "type": "object", + "properties": { + "website": { + "type": "string", + "description": "The URL of the website to scrape" + } + }, + "required": ["website"] + }) + } + + async fn run(&self, input: Value) -> Result> { + let website = input["website"].as_str().ok_or("Website URL is required")?; + let client = Client::new(); + let response = client.get(website).send().await?.text().await?; + + let document = Html::parse_document(&response); + let selector = Selector::parse("p, h1, h2, h3, h4, h5, h6").unwrap(); + let elements: Vec = document + .select(&selector) + .map(|el| 
el.text().collect::>().join(" ")) + .collect(); + let body = elements.join(" "); + + let sentences: Vec = body.split(". ").map(|s| s.to_string()).collect(); + let formatted_content = sentences.join("\n\n"); + + Ok(formatted_content) + } +} diff --git a/src/generation/functions/tools/search_ddg.rs b/src/generation/functions/tools/search_ddg.rs new file mode 100644 index 0000000..6793bb2 --- /dev/null +++ b/src/generation/functions/tools/search_ddg.rs @@ -0,0 +1,127 @@ +use reqwest; + +use scraper::{Html, Selector}; +use std::error::Error; + +use crate::generation::functions::tools::Tool; +use async_trait::async_trait; +use serde::{Deserialize, Serialize}; +use serde_json::{json, Value}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SearchResult { + title: String, + link: String, + snippet: String, +} + +pub struct DDGSearcher { + pub client: reqwest::Client, + pub base_url: String, +} + +impl Default for DDGSearcher { + fn default() -> Self { + Self::new() + } +} + +impl DDGSearcher { + pub fn new() -> Self { + DDGSearcher { + client: reqwest::Client::new(), + base_url: "https://duckduckgo.com".to_string(), + } + } + + pub async fn search(&self, query: &str) -> Result, Box> { + let url = format!("{}/html/?q={}", self.base_url, query); + let resp = self.client.get(&url).send().await?; + let body = resp.text().await?; + let document = Html::parse_document(&body); + + let result_selector = Selector::parse(".web-result").unwrap(); + let result_title_selector = Selector::parse(".result__a").unwrap(); + let result_url_selector = Selector::parse(".result__url").unwrap(); + let result_snippet_selector = Selector::parse(".result__snippet").unwrap(); + + let results = document + .select(&result_selector) + .map(|result| { + let title = result + .select(&result_title_selector) + .next() + .unwrap() + .text() + .collect::>() + .join(""); + let link = result + .select(&result_url_selector) + .next() + .unwrap() + .text() + .collect::>() + .join("") + .trim() + .to_string(); + let snippet = result + .select(&result_snippet_selector) + .next() + .unwrap() + .text() + .collect::>() + .join(""); + + SearchResult { + title, + link, + //url: String::from(url.value().attr("href").unwrap()), + snippet, + } + }) + .collect::>(); + + Ok(results) + } +} + +#[async_trait] +impl Tool for DDGSearcher { + fn name(&self) -> String { + "DDG Searcher".to_string() + } + + fn description(&self) -> String { + "Searches the web using DuckDuckGo's HTML interface.".to_string() + } + + fn parameters(&self) -> Value { + json!({ + "description": "This tool lets you search the web using DuckDuckGo. 
The input should be a search query.", + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "The search query to send to DuckDuckGo" + } + }, + "required": ["query"] + }) + } + + async fn call(&self, input: &str) -> Result> { + let input_value = self.parse_input(input).await; + self.run(input_value).await + } + + async fn run(&self, input: Value) -> Result> { + let query = input["query"].as_str().unwrap(); + let results = self.search(query).await?; + let results_json = serde_json::to_string(&results)?; + Ok(results_json) + } + + async fn parse_input(&self, input: &str) -> Value { + Tool::parse_input(self, input).await + } +} diff --git a/tests/function_call.rs b/tests/function_call.rs new file mode 100644 index 0000000..8b8e7f7 --- /dev/null +++ b/tests/function_call.rs @@ -0,0 +1,95 @@ +// #![cfg(feature = "function-calling")] + +use ollama_rs::{ + generation::chat::ChatMessage, + generation::functions::tools::{DDGSearcher, Scraper, StockScraper}, + generation::functions::{FunctionCallRequest, NousFunctionCall}, + Ollama, +}; +use std::sync::Arc; + +#[tokio::test] +async fn test_send_function_call() { + /// Model to be used, make sure it is tailored towards "function calling", such as: + /// - OpenAIFunctionCall: not model specific, degraded performance + /// - NousFunctionCall: adrienbrault/nous-hermes2pro:Q8_0 + const MODEL: &str = "adrienbrault/nous-hermes2pro:Q8_0"; + + const PROMPT: &str = "Aside from the Apple Remote, what other device can control the program Apple Remote was originally designed to interact with?"; + let user_message = ChatMessage::user(PROMPT.to_string()); + + let scraper_tool = Arc::new(Scraper::new()); + let ddg_search_tool = Arc::new(DDGSearcher::new()); + let parser = Arc::new(NousFunctionCall::new()); + + let ollama = Ollama::default(); + let result = ollama + .send_function_call( + FunctionCallRequest::new( + MODEL.to_string(), + vec![scraper_tool, ddg_search_tool], + vec![user_message], + ), + parser, + ) + .await + .unwrap(); + + assert!(result.done); +} + +#[tokio::test] +async fn test_send_function_call_with_history() { + /// Model to be used, make sure it is tailored towards "function calling", such as: + /// - OpenAIFunctionCall: not model specific, degraded performance + /// - NousFunctionCall: adrienbrault/nous-hermes2pro:Q8_0 + const MODEL: &str = "adrienbrault/nous-hermes2pro:Q8_0"; + + const PROMPT: &str = "Aside from the Apple Remote, what other device can control the program Apple Remote was originally designed to interact with?"; + let user_message = ChatMessage::user(PROMPT.to_string()); + + let scraper_tool = Arc::new(Scraper::new()); + let ddg_search_tool = Arc::new(DDGSearcher::new()); + let parser = Arc::new(NousFunctionCall::new()); + + let mut ollama = Ollama::new_default_with_history(30); + let result = ollama + .send_function_call_with_history( + FunctionCallRequest::new( + MODEL.to_string(), + vec![scraper_tool, ddg_search_tool], + vec![user_message], + ), + parser, + "default".to_string(), + ) + .await + .unwrap(); + + assert!(result.done); +} + +#[tokio::test] +async fn test_send_function_call_finance() { + /// Model to be used, make sure it is tailored towards "function calling", such as: + /// - OpenAIFunctionCall: not model specific, degraded performance + /// - NousFunctionCall: adrienbrault/nous-hermes2pro:Q8_0 + const MODEL: &str = "adrienbrault/nous-hermes2pro:Q8_0"; + + const PROMPT: &str = "What are the current risk factors to $APPL?"; + let user_message = 
ChatMessage::user(PROMPT.to_string()); + + let stock_scraper = Arc::new(StockScraper::new()); + let parser = Arc::new(NousFunctionCall::new()); + + let ollama = Ollama::default(); + let result = ollama + .send_function_call( + FunctionCallRequest::new(MODEL.to_string(), vec![stock_scraper], vec![user_message]), + parser, + ) + .await + .unwrap(); + + assert!(result.done); +}
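
The `Tool` trait added in `src/generation/functions/tools/mod.rs` is the extension point for user-defined tools: `name`, `description` and `run` are required, while `parameters`, `call` and `parse_input` have default implementations. Below is a minimal sketch of a custom tool, assuming `run` returns `Result<String, Box<dyn Error>>` as the built-in tools use it; the `WordCounter` type and its behaviour are purely illustrative and not part of this PR.

```rust
use async_trait::async_trait;
use ollama_rs::generation::functions::tools::Tool;
use serde_json::{json, Value};
use std::error::Error;

// Hypothetical tool, shown only to illustrate the trait surface.
pub struct WordCounter {}

#[async_trait]
impl Tool for WordCounter {
    fn name(&self) -> String {
        "Word Counter".to_string()
    }

    fn description(&self) -> String {
        "Counts the number of words in the provided text.".to_string()
    }

    fn parameters(&self) -> Value {
        json!({
            "type": "object",
            "properties": {
                "text": {
                    "type": "string",
                    "description": "The text whose words should be counted"
                }
            },
            "required": ["text"]
        })
    }

    async fn run(&self, input: Value) -> Result<String, Box<dyn Error>> {
        // The parsers pass the model's `arguments` object through unchanged,
        // so the parameter is read the same way the built-in tools read theirs.
        let text = input["text"].as_str().ok_or("Text is required")?;
        Ok(text.split_whitespace().count().to_string())
    }
}
```

A tool like this can then be handed to `send_function_call` in an `Arc`, exactly like the built-in `Scraper`, `DDGSearcher` and `StockScraper`.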
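Besides `NousFunctionCall`, the module also re-exports an `OpenAIFunctionCall` parser that expects a bare JSON object rather than `<tool_call>` tags. A hedged usage sketch follows, assuming the `function-calling` feature is enabled and an Ollama server is reachable at its default address; the model name is only an example, and the parser is built as `OpenAIFunctionCall {}` since the diff defines the struct without a `new()` constructor.

```rust
use ollama_rs::{
    generation::chat::ChatMessage,
    generation::functions::tools::DDGSearcher,
    generation::functions::{FunctionCallRequest, OpenAIFunctionCall},
    Ollama,
};
use std::sync::Arc;

#[tokio::main]
async fn main() {
    let ollama = Ollama::default();
    // Any model that reliably emits JSON can be used; this one is an example.
    let request = FunctionCallRequest::new(
        "adrienbrault/nous-hermes2pro:Q8_0".to_string(),
        vec![Arc::new(DDGSearcher::new())],
        vec![ChatMessage::user("What is the current oil price?".to_string())],
    );
    let parser = Arc::new(OpenAIFunctionCall {});
    let response = ollama.send_function_call(request, parser).await.unwrap();
    println!("{}", response.message.unwrap().content);
}
```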