diff --git a/examples/textgrad_examples/evals/textgrad_solution_optimization.py b/examples/textgrad_examples/evals/textgrad_solution_optimization.py
index 92c3a9e..abd2cfc 100644
--- a/examples/textgrad_examples/evals/textgrad_solution_optimization.py
+++ b/examples/textgrad_examples/evals/textgrad_solution_optimization.py
@@ -1,7 +1,7 @@
# This script applies Trace to optimize the workflow in TextGrad's solution_optimization.py.
from opto import trace
-from opto.optimizers import OptoPrime, TextGrad
+from opto.optimizers import OptoPrime, TextGrad, OptoPrimeMulti
import re
import json
@@ -162,6 +162,12 @@ def run_trace_test_time_training(sample):
if args.algo == "textgrad":
# This runs Trace's TextGrad optimizer
optimizer = TextGrad([instance_var], max_tokens=16383)
+ elif args.algo == 'opto_multi':
+ # This runs Trace's OptoPrimeMulti optimizer
+ optimizer = OptoPrimeMulti([instance_var],
+ prompt_symbols={'variables': '#Parameters'},
+ num_responses=3,
+ max_tokens=16383)
else: # This runs Trace's OptoPrime optimizer
optimizer = OptoPrime([instance_var],
prompt_symbols={'variables': '#Parameters'},
@@ -215,7 +221,7 @@ def backfill(regret, maxlen):
args = config()
-assert args.algo in ["textgrad", "trace", "ttextgrad"], "ttextgrad is Trace's implementation textgrad"
+assert args.algo in ["textgrad", "trace", "ttextgrad", 'opto_multi'], "ttextgrad is original implementation textgrad"
llm_engine = tg.get_engine(engine_name=args.engine)
tg.set_backward_engine(llm_engine, override=True)
@@ -228,7 +234,7 @@ def backfill(regret, maxlen):
with concurrent.futures.ThreadPoolExecutor(max_workers=args.num_threads) as executor:
futures = []
for i, sample in enumerate(test_set):
- if args.algo in ["trace", 'textgrad']:
+ if args.algo in ["trace", 'textgrad', 'opto_multi']:
future = executor.submit(run_trace_test_time_training, sample)
else:
future = executor.submit(run_test_time_training, sample)
diff --git a/examples/textgrad_examples/notebooks/textgrad_test_time_loss_for_code_OptoPrimeMulti.ipynb b/examples/textgrad_examples/notebooks/textgrad_test_time_loss_for_code_OptoPrimeMulti.ipynb
new file mode 100644
index 0000000..a5881d0
--- /dev/null
+++ b/examples/textgrad_examples/notebooks/textgrad_test_time_loss_for_code_OptoPrimeMulti.ipynb
@@ -0,0 +1,600 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "-WGqLq5vb7Jm",
+ "outputId": "f18bba80-41ae-4473-f426-4fb8f5746082"
+ },
+ "outputs": [],
+ "source": [
+ "%pip install textgrad\n",
+ "%pip install git+https://github.com/microsoft/Trace.git\n",
+ "%pip install dask[dataframe]\n",
+ "%pip install autogen"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "7MTXRbDhcHAP"
+ },
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import openai\n",
+ "\n",
+ "os.environ['OPENAI_API_KEY'] = \"\"\n",
+ "\n",
+ "OAI_CONFIG_LIST = [ { \"model\": \"gpt-4o-mini\", \"api_key\": os.environ['OPENAI_API_KEY'],}]\n",
+ "\n",
+ "import json; config_file_path = \"/content/config_list.json\"; json.dump(OAI_CONFIG_LIST, open(config_file_path, \"w\")); os.environ['OAI_CONFIG_LIST'] = config_file_path"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "JftiVF1eb0rH"
+ },
+ "outputs": [],
+ "source": [
+ "from opto import trace\n",
+ "from opto.optimizers import OptoPrime, OptoPrimeMulti\n",
+ "\n",
+ "import random\n",
+ "import time"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {
+ "id": "cO6nE2LPb0rH"
+ },
+ "outputs": [],
+ "source": [
+ "# We'll use below utilities to run a python function.\n",
+ "from IPython.core.interactiveshell import InteractiveShell\n",
+ "\n",
+ "def run_function_in_interpreter(func_code):\n",
+ " # raise Exception(\"This function will run the code returned by GPT-4o. Remove this if you'd like to run the code!\")\n",
+ " interpreter = InteractiveShell.instance()\n",
+ "\n",
+ " interpreter.run_cell(func_code, store_history=False, silent=True)\n",
+ "\n",
+ " func_name = func_code.split(\"def \")[1].split(\"(\")[0].strip()\n",
+ " func = interpreter.user_ns[func_name]\n",
+ "\n",
+ " return func\n",
+ "\n",
+ "\n",
+ "\n",
+ "def test_longest_increasing_subsequence(fn):\n",
+ " nums = [10, 22, 9, 33, 21, 50, 41, 60]\n",
+ " assert fn(nums) == 5\n",
+ "\n",
+ " nums = [7, 2, 1, 3, 8, 4, 9, 6, 5]\n",
+ " assert fn(nums) == 4\n",
+ "\n",
+ " nums = [5, 4, 3, 2, 1]\n",
+ " assert fn(nums) == 1\n",
+ "\n",
+ " nums = [1, 2, 3, 4, 5]\n",
+ " assert fn(nums) == 5\n",
+ "\n",
+ " nums = [3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5]\n",
+ " assert fn(nums) == 4\n",
+ "\n",
+ " nums = [10, 9, 2, 5, 3, 7, 101, 18]\n",
+ " assert fn(nums) == 4\n",
+ "\n",
+ " nums = [0, 8, 4, 12, 2, 10, 6, 14, 1, 9, 5, 13, 3, 11, 7, 15]\n",
+ " assert fn(nums) == 6\n",
+ "\n",
+ " nums = [7, 7, 7, 7, 7, 7, 7]\n",
+ " assert fn(nums) == 1\n",
+ "\n",
+ " nums = [20, 25, 47, 35, 56, 68, 98, 101, 212, 301, 415, 500]\n",
+ " assert fn(nums) == 11\n",
+ "\n",
+ " nums = [9, 8, 7, 6, 5, 4, 3, 2, 1, 0]\n",
+ " assert fn(nums) == 1\n",
+ "\n",
+ " print(\"All test cases passed!\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {
+ "id": "OYkVBYtkb0rH"
+ },
+ "outputs": [],
+ "source": [
+ "problem_text = \"\"\"Longest Increasing Subsequence (LIS)\n",
+ "\n",
+ "Problem Statement:\n",
+ "Given a sequence of integers, find the length of the longest subsequence that is strictly increasing. A subsequence is a sequence that can be derived from another sequence by deleting some or no elements without changing the order of the remaining elements.\n",
+ "\n",
+ "Input:\n",
+ "The input consists of a list of integers representing the sequence.\n",
+ "\n",
+ "Output:\n",
+ "The output should be an integer representing the length of the longest increasing subsequence.\"\"\"\n",
+ "\n",
+ "initial_solution = \"\"\"\n",
+ "def longest_increasing_subsequence(nums):\n",
+ " n = len(nums)\n",
+ " dp = [1] * n\n",
+ "\n",
+ " for i in range(1, n):\n",
+ " for j in range(i):\n",
+ " if nums[i] > nums[j]:\n",
+ " dp[i] = max(dp[i], dp[j] + 1)\n",
+ "\n",
+ " max_length = max(dp)\n",
+ " lis = []\n",
+ "\n",
+ " for i in range(n - 1, -1, -1):\n",
+ " if dp[i] == max_length:\n",
+ " lis.append(nums[i])\n",
+ " max_length -= 1\n",
+ "\n",
+ " return len(lis[::-1])\n",
+ "\"\"\"\n",
+ "\n",
+ "# Generate a random test case\n",
+ "def generate_random_test_case(size, min_value, max_value):\n",
+ " return [random.randint(min_value, max_value) for _ in range(size)]\n",
+ "\n",
+ "# Test the function with a random test case\n",
+ "size = 10000 # Adjust the size as needed\n",
+ "min_value = 1\n",
+ "max_value = 1000\n",
+ "\n",
+ "nums = generate_random_test_case(size, min_value, max_value)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "fqTOVqftb0rI",
+ "outputId": "a7506eb0-7cc0-47c7-facc-76f93541117e"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Test Case Size: 10000\n",
+ "Longest Increasing Subsequence Length: 176\n",
+ "Runtime: 12.51021 seconds\n",
+ "All test cases passed!\n"
+ ]
+ }
+ ],
+ "source": [
+ "longest_increasing_subsequence = run_function_in_interpreter(initial_solution)\n",
+ "\n",
+ "start_time = time.time()\n",
+ "lis = longest_increasing_subsequence(nums)\n",
+ "end_time = time.time()\n",
+ "\n",
+ "print(f\"Test Case Size: {size}\")\n",
+ "print(f\"Longest Increasing Subsequence Length: {lis}\")\n",
+ "print(f\"Runtime: {end_time - start_time:.5f} seconds\")\n",
+ "\n",
+ "# Test for all test cases\n",
+ "test_longest_increasing_subsequence(longest_increasing_subsequence)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Tj_xi5Jib0rI"
+ },
+ "source": [
+ "# Trace code"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 645
+ },
+ "id": "wSgBz-Amb0rI",
+ "outputId": "299e048c-d0db-4b9c-c557-f72fd962d4a2"
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "PYDEV DEBUGGER WARNING:\n",
+ "sys.settrace() should not be used when the debugger is being used.\n",
+ "This may cause the debugger to stop working correctly.\n",
+ "If this is needed, please check: \n",
+ "http://pydev.blogspot.com/2007/06/why-cant-pydev-debugger-work-with.html\n",
+ "to see how to restore the debug tracing back correctly.\n",
+ "Call Location:\n",
+ " File \"/usr/local/lib/python3.10/dist-packages/opto/trace/bundle.py\", line 359, in sync_call_fun\n",
+ " sys.settrace(oldtracer)\n",
+ "\n"
+ ]
+ },
+ {
+ "data": {
+ "image/svg+xml": [
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "\n",
+ "code = trace.node(initial_solution, trainable=True)\n",
+ "opt = OptoPrimeMulti([code])\n",
+ "\n",
+ "feedback = \"Think about the problem and the code snippet. Does the code solve the problem? What is the runtime complexity? Improve the runtime complexity of the code.\"\n",
+ "format_string = \"Problem: {problem_text}\\nCurrent Code: {solution}\"\n",
+ "\n",
+ "from opto.trace import operators as ops\n",
+ "problem = ops.format(format_string, problem_text=problem_text, solution=code)\n",
+ "opt.zero_feedback()\n",
+ "\n",
+ "# Let's visualize our computation graph.\n",
+ "problem.backward(feedback, visualize=True)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "BFgB5Ngfb0rJ",
+ "outputId": "6f5c2f0f-693a-4954-9898-acd59addaf08"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Temperatures for responses: [1.3, 0.9750000000000001, 0.65, 0.32499999999999996, 0.0]\n",
+ "LLM responses:\n",
+ " ['{\\n\"reasoning\": \"The #Instruction asks for changes in #Variables based on #Feedback about improving the output. The #Feedback indicates that while the current code solves the problem of finding the length of the longest increasing subsequence, the runtime complexity of the algorithm is O(n^2), which can be optimized to O(n log n) using binary search. Therefore, the suggested improvement involves changing `str0` to implement a more efficient version of the longest increasing subsequence algorithm. A common approach is to utilize a list to track the smallest tail for all subsequences of a given length.\",\\n\"answer\": \"\",\\n\"suggestion\": {\\n \"str0\": \"def longest_increasing_subsequence(nums):\\\\n from bisect import bisect_left\\\\n if not nums:\\\\n return 0\\\\n lis = []\\\\n for x in nums:\\\\n i = bisect_left(lis, x)\\\\n if i == len(lis):\\\\n lis.append(x)\\\\n else:\\\\n lis[i] = x\\\\n return len(lis)\"\\n}\\n}']\n",
+ "LLM responses:\n",
+ " ['{\\n\"reasoning\": \"The instruction asks to improve the output based on the feedback. The feedback indicates that while the provided code for finding the longest increasing subsequence (LIS) works, its runtime complexity can be enhanced. The current solution has a time complexity of O(n^2), which can be improved to O(n log n) using a more efficient algorithm such as binary search with a dynamic array. This involves using a list to keep track of the smallest tail of all increasing subsequences of different lengths and applying binary search to maintain this list when inserting new elements. Therefore, I suggest changing the implementation of the function `str0` to include this optimized approach. The expected result is a new version of the LIS function that operates with better efficiency.\",\\n\"answer\": \"The updated code for the longest increasing subsequence implementation should look like this:\\\\ndef longest_increasing_subsequence(nums):\\\\n from bisect import bisect_left\\\\n lis = []\\\\n for num in nums:\\\\n pos = bisect_left(lis, num)\\\\n if pos == len(lis):\\\\n lis.append(num)\\\\n else:\\\\n lis[pos] = num\\\\n return len(lis)\",\\n\"suggestion\": {\\n \"str0\": \"def longest_increasing_subsequence(nums):\\\\n from bisect import bisect_left\\\\n lis = []\\\\n for num in nums:\\\\n pos = bisect_left(lis, num)\\\\n if pos == len(lis):\\\\n lis.append(num)\\\\n else:\\\\n lis[pos] = num\\\\n return len(lis)\"\\n}}\\n']\n",
+ "LLM responses:\n",
+ " ['{\\n\"reasoning\": \"The instruction asks to improve the output based on the feedback provided. The feedback suggests that while the current code does solve the problem of finding the length of the longest increasing subsequence, its runtime complexity can be improved. The current code has a time complexity of O(n^2) due to the nested loops used for comparing elements. A more efficient approach would be to use a binary search algorithm that can reduce the complexity to O(n log n). Therefore, I need to suggest a change in the implementation of the function \\'longest_increasing_subsequence\\' to improve its runtime complexity.\",\\n\"answer\": \"The improved solution should use a dynamic programming approach with binary search.\",\\n\"suggestion\": {\\n \"str0\": \"def longest_increasing_subsequence(nums):\\\\n from bisect import bisect_left\\\\n subsequence = []\\\\n for num in nums:\\\\n pos = bisect_left(subsequence, num)\\\\n if pos == len(subsequence):\\\\n subsequence.append(num)\\\\n else:\\\\n subsequence[pos] = num\\\\n return len(subsequence)\"\\n}\\n}']\n",
+ "LLM responses:\n",
+ " ['{\\n\"reasoning\": \"The instruction asks to improve the output based on the feedback, which suggests that while the current code solves the problem of finding the length of the longest increasing subsequence (LIS), it could be optimized for better runtime complexity. The current implementation has a time complexity of O(n^2) due to the nested loops. To improve this, we can implement a more efficient algorithm that uses binary search, reducing the time complexity to O(n log n). The suggested changes involve modifying the `str0` variable to include a new implementation of the LIS algorithm that utilizes this optimized approach. This change is expected to enhance the performance of the code without altering its functionality.\", \\n\"answer\": \"The current implementation of the longest increasing subsequence function can be improved for better runtime complexity.\", \\n\"suggestion\": {\\n \"str0\": \"def longest_increasing_subsequence(nums):\\\\n from bisect import bisect_left\\\\n lis = []\\\\n for num in nums:\\\\n pos = bisect_left(lis, num)\\\\n if pos == len(lis):\\\\n lis.append(num)\\\\n else:\\\\n lis[pos] = num\\\\n return len(lis)\"\\n}}\\n']\n",
+ "LLM responses:\n",
+ " ['{\\n\"reasoning\": \"The instruction asks to improve the output based on the feedback provided. The feedback suggests that while the code solves the problem of finding the longest increasing subsequence (LIS), it has a runtime complexity of O(n^2) due to the nested loops. To improve the runtime complexity, we can implement a more efficient algorithm that uses binary search, which can reduce the complexity to O(n log n). Therefore, I will suggest a new implementation for the function \\'longest_increasing_subsequence\\' that utilizes binary search to achieve this improved performance.\",\\n\"answer\": \"The current code does solve the problem of finding the longest increasing subsequence, but it can be optimized for better performance.\",\\n\"suggestion\": {\\n \"str0\": \"def longest_increasing_subsequence(nums):\\\\n from bisect import bisect_left\\\\n lis = []\\\\n for num in nums:\\\\n pos = bisect_left(lis, num)\\\\n if pos == len(lis):\\\\n lis.append(num)\\\\n else:\\\\n lis[pos] = num\\\\n return len(lis)\"\\n}\\n}']\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Let's update the code\n",
+ "opt.step(verbose='output')\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "h9Peebbkc5Mv",
+ "outputId": "31a2bc3c-6835-444b-b42e-aadfcb61e6be"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{: 'def longest_increasing_subsequence(nums):\\n from bisect import bisect_left\\n if not nums:\\n return 0\\n lis = []\\n for x in nums:\\n i = bisect_left(lis, x)\\n if i == len(lis):\\n lis.append(x)\\n else:\\n lis[i] = x\\n return len(lis)'}\n",
+ "\n",
+ "{: 'def longest_increasing_subsequence(nums):\\n from bisect import bisect_left\\n lis = []\\n for num in nums:\\n pos = bisect_left(lis, num)\\n if pos == len(lis):\\n lis.append(num)\\n else:\\n lis[pos] = num\\n return len(lis)'}\n",
+ "\n",
+ "{: 'def longest_increasing_subsequence(nums):\\n from bisect import bisect_left\\n subsequence = []\\n for num in nums:\\n pos = bisect_left(subsequence, num)\\n if pos == len(subsequence):\\n subsequence.append(num)\\n else:\\n subsequence[pos] = num\\n return len(subsequence)'}\n",
+ "\n",
+ "{: 'def longest_increasing_subsequence(nums):\\n from bisect import bisect_left\\n lis = []\\n for num in nums:\\n pos = bisect_left(lis, num)\\n if pos == len(lis):\\n lis.append(num)\\n else:\\n lis[pos] = num\\n return len(lis)'}\n",
+ "\n",
+ "{: 'def longest_increasing_subsequence(nums):\\n from bisect import bisect_left\\n lis = []\\n for num in nums:\\n pos = bisect_left(lis, num)\\n if pos == len(lis):\\n lis.append(num)\\n else:\\n lis[pos] = num\\n return len(lis)'}\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "for c in opt.candidates:\n",
+ " print(f\"{c}\\n\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "RC6AQahriWFc",
+ "outputId": "282e7f14-cf29-4f73-93a1-49de05ae837d"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Candidate 1:\n",
+ " Longest Increasing Subsequence Length: 176\n",
+ " Runtime: 0.00793 seconds\n",
+ " Code: <<>>\n",
+ "\n",
+ "Candidate 2:\n",
+ " Longest Increasing Subsequence Length: 176\n",
+ " Runtime: 0.01296 seconds\n",
+ " Code: <<>>\n",
+ "\n",
+ "Candidate 3:\n",
+ " Longest Increasing Subsequence Length: 176\n",
+ " Runtime: 0.01116 seconds\n",
+ " Code: <<>>\n",
+ "\n",
+ "Candidate 4:\n",
+ " Longest Increasing Subsequence Length: 176\n",
+ " Runtime: 0.01040 seconds\n",
+ " Code: <<>>\n",
+ "\n",
+ "Candidate 5:\n",
+ " Longest Increasing Subsequence Length: 176\n",
+ " Runtime: 0.01898 seconds\n",
+ " Code: <<>>\n",
+ "\n",
+ "Execution Summary:\n",
+ "Candidate 1: Result = 176, Runtime = 0.00793 seconds\n",
+ "Candidate 2: Result = 176, Runtime = 0.01296 seconds\n",
+ "Candidate 3: Result = 176, Runtime = 0.01116 seconds\n",
+ "Candidate 4: Result = 176, Runtime = 0.01040 seconds\n",
+ "Candidate 5: Result = 176, Runtime = 0.01898 seconds\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Test all candidates and log execution times\n",
+ "execution_results = []\n",
+ "\n",
+ "for i, candidate in enumerate(opt.candidates):\n",
+ " if not candidate: # Skip invalid candidates\n",
+ " print(f\"Candidate {i+1}: Skipped (Invalid)\")\n",
+ " continue\n",
+ "\n",
+ " # Extract the function code from the dictionary\n",
+ " func_code = list(candidate.values())[0] # Assumes there's only one key-value pair in the dictionary\n",
+ " if not func_code:\n",
+ " print(f\"Candidate {i+1}: No code found\")\n",
+ " continue\n",
+ "\n",
+ " # Compile and run the function\n",
+ " func = run_function_in_interpreter(func_code) # Extract and run candidate function\n",
+ " try:\n",
+ " start_time = time.time()\n",
+ " result = func(nums) # Test the function\n",
+ " end_time = time.time()\n",
+ "\n",
+ " runtime = end_time - start_time\n",
+ " execution_results.append({\n",
+ " \"candidate\": i + 1,\n",
+ " \"result\": result,\n",
+ " \"runtime\": runtime\n",
+ " })\n",
+ "\n",
+ " func_code_nonl = func_code.replace('\\n',' ')\n",
+ " print(f\"Candidate {i+1}:\")\n",
+ " print(f\" Longest Increasing Subsequence Length: {result}\")\n",
+ " print(f\" Runtime: {runtime:.5f} seconds\")\n",
+ " print(f\" Code: <<<{func_code_nonl}>>>\\n\")\n",
+ "\n",
+ " except Exception as e:\n",
+ " print(f\"Candidate {i+1}: Failed with error: {e}\\n\")\n",
+ "\n",
+ "# Display a summary of all candidate results\n",
+ "print(\"Execution Summary:\")\n",
+ "for res in execution_results:\n",
+ " print(f\"Candidate {res['candidate']}: Result = {res['result']}, Runtime = {res['runtime']:.5f} seconds\")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "J73bk2Ieb0rJ",
+ "outputId": "6b7d059c-03bc-4bae-f764-aaad8513693e"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Longest Increasing Subsequence Length: 176\n",
+ "Runtime: 0.00555 seconds\n",
+ "All test cases passed!\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Hopefully, we should get much better runtime!\n",
+ "longest_increasing_subsequence = run_function_in_interpreter(code.data)\n",
+ "\n",
+ "start_time = time.time()\n",
+ "lis = longest_increasing_subsequence(nums)\n",
+ "end_time = time.time()\n",
+ "\n",
+ "print(f\"Longest Increasing Subsequence Length: {lis}\")\n",
+ "print(f\"Runtime: {end_time - start_time:.5f} seconds\")\n",
+ "\n",
+ "test_longest_increasing_subsequence(longest_increasing_subsequence)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "oKL459B3b0rJ"
+ },
+ "source": [
+ "At this point, OptoPrime in Trace solves the problem. There's no need to further iterate."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "6wgWIOZ9b0rJ",
+ "outputId": "e8c91b2e-7dee-4254-911b-ad91b035a3d7"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "def longest_increasing_subsequence(nums):\n",
+ " from bisect import bisect_left\n",
+ " lis = []\n",
+ " for num in nums:\n",
+ " pos = bisect_left(lis, num)\n",
+ " if pos == len(lis):\n",
+ " lis.append(num)\n",
+ " else:\n",
+ " lis[pos] = num\n",
+ " return len(lis)\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(code.data)"
+ ]
+ }
+ ],
+ "metadata": {
+ "colab": {
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "trace-3.9",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.19"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/opto/optimizers/__init__.py b/opto/optimizers/__init__.py
index 362eff5..74a0bdd 100644
--- a/opto/optimizers/__init__.py
+++ b/opto/optimizers/__init__.py
@@ -1,5 +1,6 @@
from opto.optimizers.optoprime import OptoPrime
+from opto.optimizers.optoprimemulti import OptoPrimeMulti
from opto.optimizers.opro import OPRO
from opto.optimizers.textgrad import TextGrad
-__all__ = ["OPRO", "OptoPrime", "TextGrad"]
\ No newline at end of file
+__all__ = ["OPRO", "OptoPrime", "OptoPrimeMulti", "TextGrad"]
\ No newline at end of file
diff --git a/opto/optimizers/optoprimemulti.py b/opto/optimizers/optoprimemulti.py
new file mode 100644
index 0000000..86461d0
--- /dev/null
+++ b/opto/optimizers/optoprimemulti.py
@@ -0,0 +1,255 @@
+from typing import Any, List, Dict, Union, Tuple, Optional
+import json
+from textwrap import dedent
+
+from opto.trace.propagators import GraphPropagator
+from opto.optimizers.optoprime import OptoPrime
+
+
+class OptoPrimeMulti(OptoPrime):
+ def __init__(self, *args,
+ num_responses: int = 5,
+ temperature_range: Optional[List[float]] = None,
+ selector: Optional[callable] = None,
+ **kwargs):
+ super().__init__(*args, **kwargs)
+ if temperature_range is None:
+ self.temperature_range = [1.3, 0.]
+ self.candidates = [] # Store all candidate solutions
+ self.selected_candidate = None # Store the selected candidate solution
+ self.num_responses = num_responses
+ self.selector = selector
+ self.use_synthesis = False
+
+ def call_llm(
+ self, system_prompt: str, user_prompt: str, verbose: Union[bool, str] = False,
+ max_tokens: int = 4096, num_responses: int = 1, temperature: float = 0.
+ ) -> List[str]:
+ """Call the LLM with a prompt and return multiple responses."""
+ if verbose not in (False, "output"):
+ print("Prompt\n", system_prompt + user_prompt)
+
+ messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}]
+
+ try:
+ response = self.llm.create(
+ messages=messages,
+ response_format={"type": "json_object"},
+ max_tokens=max_tokens,
+ n=num_responses,
+ temperature=temperature,
+ )
+ except Exception as e:
+ if verbose:
+ print(f"ERROR {e}")
+ # Default to returning an empty response list if an error occurs # Error handling improvement
+ return []
+
+ responses = [choice.message.content for choice in response.choices]
+
+ if verbose:
+ print("LLM responses:\n", responses)
+ return responses
+
+ def generate_candidates(
+ self, summary, system_prompt: str, user_prompt: str, verbose: Union[bool, str] = False,
+ mask=None, max_tokens: int = None, num_responses: Optional[int] = None, temperature_range: Optional[List[float]] = None, generation_technique: str = "temperature_variation"
+ ) -> List[str]:
+ """
+ Generate multiple candidates using configurable techniques.
+ Args:
+ summary: The summarized problem instance.
+ system_prompt (str): The system-level prompt.
+ user_prompt (str): The user-level prompt.
+ verbose (bool): Whether to print debug information.
+ mask: Mask for the problem instance.
+ max_tokens (int, optional): Maximum token limit for the LLM responses.
+ num_responses (int): Number of responses to request.
+ temperature_range (List[float]): [max_temperature, min_temperature].
+ generation_technique (str): Technique for generating candidates. Options:
+ - "temperature_variation": Use temperature range for diversity (default).
+ - "self_refinement": Iteratively refine candidates using self-feedback.
+ - "iterative_alternatives": Find new alternative optimal solutions given previous candidates.
+ Returns:
+ List[str]: List of LLM responses as strings.
+ """
+ num_responses = num_responses if num_responses is not None else self.num_responses # Allow overriding num_responses
+ temperature_range = temperature_range if temperature_range is not None else self.temperature_range
+ max_tokens = max_tokens or self.max_tokens # Allow overriding max_tokens
+
+ candidates = []
+
+ # Temperature Variation (Original Logic)
+ if generation_technique == "temperature_variation":
+ self.use_synthesis = True # Enable synthesis for the final selection
+ max_temp, min_temp = max(temperature_range), min(temperature_range)
+ temperatures = [
+ max_temp - i * (max_temp - min_temp) / max(1, num_responses - 1)
+ for i in range(num_responses)
+ ]
+
+ if verbose:
+ print(f"Temperatures for responses: {temperatures}")
+
+ candidates = [
+ self.call_llm(
+ system_prompt=system_prompt,
+ user_prompt=user_prompt,
+ verbose=verbose,
+ max_tokens=max_tokens,
+ num_responses=1,
+ temperature=temp
+ )[0] # Extract the single response
+ for temp in temperatures
+ ]
+
+ # Self-Refinement
+ elif generation_technique == "self_refinement":
+ for _ in range(num_responses):
+ if not candidates: # First candidate, no refinement needed
+ current_prompt = system_prompt
+ else: # Refine the last candidate
+ current_prompt = f"{system_prompt}\nRefine the following solution: {candidates[-1]}"
+
+ candidate = self.call_llm(
+ system_prompt=current_prompt,
+ user_prompt=user_prompt,
+ verbose=verbose,
+ max_tokens=max_tokens,
+ num_responses=1,
+ temperature=0. # Deterministic output
+ )[0]
+ candidates.append(candidate)
+
+ # Iterative Alternatives
+ elif generation_technique == "iterative_alternatives":
+ self.use_synthesis = True # Enable synthesis for the final selection
+ for i in range(num_responses):
+ if not candidates: # First candidate, no alternatives yet
+ current_prompt = system_prompt
+ else: # Generate a new alternative based on previous candidates
+ previous_solutions = "\n".join(
+ f"SOLUTION {idx + 1}: <<<{candidate}>>>"
+ for idx, candidate in enumerate(candidates)
+ )
+ current_prompt = (
+ f"{system_prompt}\nGiven the following solutions, propose a new alternative optimal solution:\n"
+ f"{previous_solutions}\n{user_prompt}"
+ )
+
+ candidate = self.call_llm(
+ system_prompt=current_prompt,
+ user_prompt=user_prompt,
+ verbose=verbose,
+ max_tokens=max_tokens,
+ num_responses=1,
+ temperature=0. # Deterministic output
+ )[0]
+ candidates.append(candidate)
+
+ else:
+ raise ValueError(f"Invalid generation_technique: {generation_technique}. "
+ "Supported options: 'temperature_variation', 'self_refinement', "
+ "'iterative_alternatives'.")
+
+ # Log the generated candidates
+ if self.log is not None:
+ self.log.append({"system_prompt": system_prompt, "user_prompt": user_prompt, "response": candidates})
+ self.summary_log.append({'problem_instance': self.problem_instance(summary), 'summary': summary})
+
+ return candidates
+
+ def select_candidate(self, candidates: List[Dict], use_synthesis: bool = False) -> Dict:
+ """
+ Select the best response based on the responses.
+ Args:
+ candidates (List[Dict]): List of candidate responses as dictionaries.
+ use_synthesis (bool): If True, synthesize an optimal solution from all candidates.
+ Returns:
+ Dict: The selected candidate or an empty dictionary if no candidates exist.
+ """
+ if not candidates:
+ return {}
+
+ # Default behavior: return the last candidate
+ if not use_synthesis:
+ return candidates[-1]
+
+ # Synthesize an optimal solution from all candidates
+ candidate_texts = [f"SOLUTION {i + 1}: <<<{json.dumps(candidate, indent=2)}>>>" for i, candidate in enumerate(candidates)]
+ synthesis_prompt = (
+ "Given the following solutions and the initial question, provide an optimal solution by combining the best elements of each. Follow the same output structure as the candidates.\n\n"
+ "Candidates:\n" + "\n".join(candidate_texts) + "\n\n"
+ "Optimal Solution:\n"
+ )
+
+ # Call the LLM to synthesize the optimal solution
+ synthesized_response = self.call_llm(
+ system_prompt="You are an expert optimizer. Synthesize the best solution from the given candidates.",
+ user_prompt=synthesis_prompt,
+ verbose=False,
+ #max_tokens=??,
+ num_responses=1,
+ temperature=0.3 # Low temperature for deterministic synthesis
+ )
+
+ if synthesized_response:
+ try:
+ return json.loads(synthesized_response[0])
+ except json.JSONDecodeError:
+ # Fallback to the last candidate if synthesis fails
+ return candidates[-1]
+ else:
+ # Fallback to the last candidate if synthesis fails
+ return candidates[-1]
+
+ def _step(
+ self, verbose=False, mask=None, num_responses: Optional[int] = None, temperature_range: Optional[List[float]] = None,
+ selector: callable = None, *args, **kwargs
+ ) -> Dict: # Added type annotation for return value
+ """
+ Perform a single optimization step, storing responses in self.responses and allowing selection.
+ Args:
+ verbose (bool): Whether to print debug information.
+ mask (list, optional): Mask for the problem instance.
+ num_responses (int): Number of responses to request from the LLM.
+ temperature (float): Sampling temperature for the LLM.
+ selector (callable, optional): Function to select the best response.
+ Returns:
+ Dict: The update dictionary based on the selected response.
+ """
+ num_responses = num_responses if num_responses is not None else self.num_responses # Allow overriding num_responses
+ temperature_range = temperature_range if temperature_range is not None else self.temperature_range
+ selector = selector if selector is not None else self.selector
+
+ assert isinstance(self.propagator, GraphPropagator)
+ summary = self.summarize()
+ system_prompt, user_prompt = self.construct_prompt(summary, mask=mask)
+
+ system_prompt = self.replace_symbols(system_prompt, self.prompt_symbols)
+ user_prompt = self.replace_symbols(user_prompt, self.prompt_symbols)
+
+ # Generate candidates
+ responses = self.generate_candidates(
+ summary, system_prompt, user_prompt, verbose=verbose, mask=mask,
+ num_responses=num_responses, temperature_range=temperature_range
+ )
+
+ self.candidates = [] # Clear previous responses
+ for response in responses:
+ if "TERMINATE" in response:
+ self.candidates.append({})
+ continue
+
+ suggestion = self.extract_llm_suggestion(response)
+ update_dict = self.construct_update_dict(suggestion)
+
+ self.candidates.append(update_dict)
+
+ # Select the response using the selector or the default select_candidate method
+ if selector and callable(selector): # Ensure the selector is callable
+ self.selected_candidate = selector(self.candidates)
+ else:
+ self.selected_candidate = self.select_candidate(candidates=self.candidates, use_synthesis=self.use_synthesis)
+
+ return self.selected_candidate