From ad693b2bb201b4d9280139e70a2930358e779366 Mon Sep 17 00:00:00 2001 From: PeriniM Date: Sun, 12 Jan 2025 12:45:49 +0100 Subject: [PATCH] fix: ollama tokenizer limited to 1024 tokens + ollama structured output + fix browser backend --- examples/local_models/smart_scraper_ollama.py | 4 +-- .../smart_scraper_schema_ollama.py | 29 ++++++++++++------- pyproject.toml | 3 +- scrapegraphai/docloaders/chromium.py | 4 +-- scrapegraphai/graphs/abstract_graph.py | 5 ++-- scrapegraphai/nodes/generate_answer_node.py | 10 ++++--- .../nodes/generate_answer_node_k_level.py | 13 +++++++-- scrapegraphai/nodes/parse_node.py | 4 +-- scrapegraphai/utils/split_text_into_chunks.py | 14 ++++----- scrapegraphai/utils/tokenizer.py | 28 +++--------------- .../utils/tokenizers/tokenizer_openai.py | 6 ++-- uv.lock | 4 +-- 12 files changed, 55 insertions(+), 69 deletions(-) diff --git a/examples/local_models/smart_scraper_ollama.py b/examples/local_models/smart_scraper_ollama.py index 61294eaf..b08dceb9 100644 --- a/examples/local_models/smart_scraper_ollama.py +++ b/examples/local_models/smart_scraper_ollama.py @@ -15,7 +15,7 @@ "temperature": 0, "format": "json", # Ollama needs the format to be specified explicitly # "base_url": "http://localhost:11434", # set ollama URL arbitrarily - "model_tokens": 1024, + "model_tokens": 4096, }, "verbose": True, "headless": False, @@ -25,7 +25,7 @@ # Create the SmartScraperGraph instance and run it # ************************************************ smart_scraper_graph = SmartScraperGraph( - prompt="Find some information about what does the company do, the name and a contact email.", + prompt="Find some information about what does the company do and the list of founders.", source="https://scrapegraphai.com/", config=graph_config, ) diff --git a/examples/local_models/smart_scraper_schema_ollama.py b/examples/local_models/smart_scraper_schema_ollama.py index 5a5b3cea..ae3ec849 100644 --- a/examples/local_models/smart_scraper_schema_ollama.py +++ b/examples/local_models/smart_scraper_schema_ollama.py @@ -1,12 +1,15 @@ -""" +""" Basic example of scraping pipeline using SmartScraper with schema """ + import json -from typing import List + from pydantic import BaseModel, Field + from scrapegraphai.graphs import SmartScraperGraph from scrapegraphai.utils import prettify_exec_info + # ************************************************ # Define the configuration for the graph # ************************************************ @@ -14,18 +17,15 @@ class Project(BaseModel): title: str = Field(description="The title of the project") description: str = Field(description="The description of the project") + class Projects(BaseModel): - projects: List[Project] + projects: list[Project] + graph_config = { - "llm": { - "model": "ollama/llama3.1", - "temperature": 0, - "format": "json", # Ollama needs the format to be specified explicitly - # "base_url": "http://localhost:11434", # set ollama URL arbitrarily - }, + "llm": {"model": "ollama/llama3.2", "temperature": 0, "model_tokens": 4096}, "verbose": True, - "headless": False + "headless": False, } # ************************************************ @@ -36,8 +36,15 @@ class Projects(BaseModel): prompt="List me all the projects with their description", source="https://perinim.github.io/projects/", schema=Projects, - config=graph_config + config=graph_config, ) result = smart_scraper_graph.run() print(json.dumps(result, indent=4)) + +# ************************************************ +# Get graph execution info +# ************************************************ + +graph_exec_info = smart_scraper_graph.get_execution_info() +print(prettify_exec_info(graph_exec_info)) diff --git a/pyproject.toml b/pyproject.toml index 79684ba4..5ed6568c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,8 +30,7 @@ dependencies = [ "googlesearch-python>=1.2.5", "async-timeout>=4.0.3", "simpleeval>=1.0.0", - "jsonschema>=4.23.0", - "transformers>=4.46.3", + "jsonschema>=4.23.0" ] readme = "README.md" diff --git a/scrapegraphai/docloaders/chromium.py b/scrapegraphai/docloaders/chromium.py index 2c4f142d..1d252d0d 100644 --- a/scrapegraphai/docloaders/chromium.py +++ b/scrapegraphai/docloaders/chromium.py @@ -61,7 +61,6 @@ def __init__( dynamic_import(backend, message) - self.backend = backend self.browser_config = kwargs self.headless = headless self.proxy = parse_or_search_proxy(proxy) if proxy else None @@ -69,7 +68,8 @@ def __init__( self.load_state = load_state self.requires_js_support = requires_js_support self.storage_state = storage_state - self.browser_name = browser_name + self.backend = kwargs.get("backend", backend) + self.browser_name = kwargs.get("browser_name", browser_name) self.retry_limit = kwargs.get("retry_limit", retry_limit) self.timeout = kwargs.get("timeout", timeout) diff --git a/scrapegraphai/graphs/abstract_graph.py b/scrapegraphai/graphs/abstract_graph.py index 812aaf80..a56c9954 100644 --- a/scrapegraphai/graphs/abstract_graph.py +++ b/scrapegraphai/graphs/abstract_graph.py @@ -203,8 +203,9 @@ def _create_llm(self, llm_config: dict) -> object: ] except KeyError: print( - f"""Model {llm_params['model_provider']}/{llm_params['model']} not found, - using default token size (8192)""" + f"""Max input tokens for model {llm_params['model_provider']}/{llm_params['model']} not found, + please specify the model_tokens parameter in the llm section of the graph configuration. + Using default token size: 8192""" ) self.model_token = 8192 else: diff --git a/scrapegraphai/nodes/generate_answer_node.py b/scrapegraphai/nodes/generate_answer_node.py index 688300cf..58ec7772 100644 --- a/scrapegraphai/nodes/generate_answer_node.py +++ b/scrapegraphai/nodes/generate_answer_node.py @@ -10,7 +10,7 @@ from langchain_community.chat_models import ChatOllama from langchain_core.output_parsers import JsonOutputParser from langchain_core.runnables import RunnableParallel -from langchain_openai import AzureChatOpenAI, ChatOpenAI +from langchain_openai import ChatOpenAI from requests.exceptions import Timeout from tqdm import tqdm @@ -59,7 +59,10 @@ def __init__( self.llm_model = node_config["llm_model"] if isinstance(node_config["llm_model"], ChatOllama): - self.llm_model.format = "json" + if node_config.get("schema", None) is None: + self.llm_model.format = "json" + else: + self.llm_model.format = self.node_config["schema"].model_json_schema() self.verbose = node_config.get("verbose", False) self.force = node_config.get("force", False) @@ -123,8 +126,7 @@ def execute(self, state: dict) -> dict: format_instructions = "" if ( - isinstance(self.llm_model, (ChatOpenAI, AzureChatOpenAI)) - and not self.script_creator + not self.script_creator or self.force and not self.script_creator or self.is_md_scraper diff --git a/scrapegraphai/nodes/generate_answer_node_k_level.py b/scrapegraphai/nodes/generate_answer_node_k_level.py index ffea4c37..daef9d02 100644 --- a/scrapegraphai/nodes/generate_answer_node_k_level.py +++ b/scrapegraphai/nodes/generate_answer_node_k_level.py @@ -6,10 +6,11 @@ from langchain.prompts import PromptTemplate from langchain_aws import ChatBedrock +from langchain_community.chat_models import ChatOllama from langchain_core.output_parsers import JsonOutputParser from langchain_core.runnables import RunnableParallel from langchain_mistralai import ChatMistralAI -from langchain_openai import AzureChatOpenAI, ChatOpenAI +from langchain_openai import ChatOpenAI from tqdm import tqdm from ..prompts import ( @@ -55,6 +56,13 @@ def __init__( super().__init__(node_name, "node", input, output, 2, node_config) self.llm_model = node_config["llm_model"] + + if isinstance(node_config["llm_model"], ChatOllama): + if node_config.get("schema", None) is None: + self.llm_model.format = "json" + else: + self.llm_model.format = self.node_config["schema"].model_json_schema() + self.embedder_model = node_config.get("embedder_model", None) self.verbose = node_config.get("verbose", False) self.force = node_config.get("force", False) @@ -92,8 +100,7 @@ def execute(self, state: dict) -> dict: format_instructions = "" if ( - isinstance(self.llm_model, (ChatOpenAI, AzureChatOpenAI)) - and not self.script_creator + not self.script_creator or self.force and not self.script_creator or self.is_md_scraper diff --git a/scrapegraphai/nodes/parse_node.py b/scrapegraphai/nodes/parse_node.py index c73dbb40..fbc9ba31 100644 --- a/scrapegraphai/nodes/parse_node.py +++ b/scrapegraphai/nodes/parse_node.py @@ -96,7 +96,6 @@ def execute(self, state: dict) -> dict: chunks = split_text_into_chunks( text=docs_transformed.page_content, chunk_size=self.chunk_size - 250, - model=self.llm_model, ) else: docs_transformed = docs_transformed[0] @@ -115,11 +114,10 @@ def execute(self, state: dict) -> dict: chunks = split_text_into_chunks( text=docs_transformed.page_content, chunk_size=chunk_size, - model=self.llm_model, ) else: chunks = split_text_into_chunks( - text=docs_transformed, chunk_size=chunk_size, model=self.llm_model + text=docs_transformed, chunk_size=chunk_size ) state.update({self.output[0]: chunks}) diff --git a/scrapegraphai/utils/split_text_into_chunks.py b/scrapegraphai/utils/split_text_into_chunks.py index a470152c..36f05bc8 100644 --- a/scrapegraphai/utils/split_text_into_chunks.py +++ b/scrapegraphai/utils/split_text_into_chunks.py @@ -4,14 +4,10 @@ from typing import List -from langchain_core.language_models.chat_models import BaseChatModel - from .tokenizer import num_tokens_calculus -def split_text_into_chunks( - text: str, chunk_size: int, model: BaseChatModel, use_semchunk=True -) -> List[str]: +def split_text_into_chunks(text: str, chunk_size: int, use_semchunk=True) -> List[str]: """ Splits the text into chunks based on the number of tokens. @@ -27,9 +23,9 @@ def split_text_into_chunks( from semchunk import chunk def count_tokens(text): - return num_tokens_calculus(text, model) + return num_tokens_calculus(text) - chunk_size = min(chunk_size - 500, int(chunk_size * 0.9)) + chunk_size = min(chunk_size, int(chunk_size * 0.9)) chunks = chunk( text=text, chunk_size=chunk_size, token_counter=count_tokens, memoize=False @@ -37,7 +33,7 @@ def count_tokens(text): return chunks else: - tokens = num_tokens_calculus(text, model) + tokens = num_tokens_calculus(text) if tokens <= chunk_size: return [text] @@ -48,7 +44,7 @@ def count_tokens(text): words = text.split() for word in words: - word_tokens = num_tokens_calculus(word, model) + word_tokens = num_tokens_calculus(word) if current_length + word_tokens > chunk_size: chunks.append(" ".join(current_chunk)) current_chunk = [word] diff --git a/scrapegraphai/utils/tokenizer.py b/scrapegraphai/utils/tokenizer.py index 1d72c0d5..b7847bb2 100644 --- a/scrapegraphai/utils/tokenizer.py +++ b/scrapegraphai/utils/tokenizer.py @@ -2,35 +2,15 @@ Module for counting tokens and splitting text into chunks """ -from langchain_core.language_models.chat_models import BaseChatModel -from langchain_mistralai import ChatMistralAI -from langchain_ollama import ChatOllama -from langchain_openai import ChatOpenAI +from .tokenizers.tokenizer_openai import num_tokens_openai -def num_tokens_calculus(string: str, llm_model: BaseChatModel) -> int: +def num_tokens_calculus(string: str) -> int: """ Returns the number of tokens in a text string. """ - if isinstance(llm_model, ChatOpenAI): - from .tokenizers.tokenizer_openai import num_tokens_openai - num_tokens_fn = num_tokens_openai + num_tokens_fn = num_tokens_openai - elif isinstance(llm_model, ChatMistralAI): - from .tokenizers.tokenizer_mistral import num_tokens_mistral - - num_tokens_fn = num_tokens_mistral - - elif isinstance(llm_model, ChatOllama): - from .tokenizers.tokenizer_ollama import num_tokens_ollama - - num_tokens_fn = num_tokens_ollama - - else: - from .tokenizers.tokenizer_openai import num_tokens_openai - - num_tokens_fn = num_tokens_openai - - num_tokens = num_tokens_fn(string, llm_model) + num_tokens = num_tokens_fn(string) return num_tokens diff --git a/scrapegraphai/utils/tokenizers/tokenizer_openai.py b/scrapegraphai/utils/tokenizers/tokenizer_openai.py index 603e93c8..0c3b2c2e 100644 --- a/scrapegraphai/utils/tokenizers/tokenizer_openai.py +++ b/scrapegraphai/utils/tokenizers/tokenizer_openai.py @@ -3,19 +3,17 @@ """ import tiktoken -from langchain_core.language_models.chat_models import BaseChatModel from ..logging import get_logger -def num_tokens_openai(text: str, llm_model: BaseChatModel) -> int: +def num_tokens_openai(text: str) -> int: """ Estimate the number of tokens in a given text using OpenAI's tokenization method, adjusted for different OpenAI models. Args: text (str): The text to be tokenized and counted. - llm_model (BaseChatModel): The specific OpenAI model to adjust tokenization. Returns: int: The number of tokens in the text. @@ -25,7 +23,7 @@ def num_tokens_openai(text: str, llm_model: BaseChatModel) -> int: logger.debug(f"Counting tokens for text of {len(text)} characters") - encoding = tiktoken.encoding_for_model("gpt-4") + encoding = tiktoken.encoding_for_model("gpt-4o") num_tokens = len(encoding.encode(text)) return num_tokens diff --git a/uv.lock b/uv.lock index ff6f13f2..ef0623ed 100644 --- a/uv.lock +++ b/uv.lock @@ -3429,7 +3429,7 @@ wheels = [ [[package]] name = "scrapegraphai" -version = "1.35.0b2" +version = "1.35.0" source = { editable = "." } dependencies = [ { name = "async-timeout", version = "4.0.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.12'" }, @@ -3452,7 +3452,6 @@ dependencies = [ { name = "simpleeval" }, { name = "tiktoken" }, { name = "tqdm" }, - { name = "transformers" }, { name = "undetected-playwright" }, ] @@ -3516,7 +3515,6 @@ requires-dist = [ { name = "surya-ocr", marker = "extra == 'ocr'", specifier = ">=0.5.0" }, { name = "tiktoken", specifier = ">=0.7" }, { name = "tqdm", specifier = ">=4.66.4" }, - { name = "transformers", specifier = ">=4.46.3" }, { name = "undetected-playwright", specifier = ">=0.3.0" }, ]