Skip to content

Commit

Permalink
fix: ollama tokenizer limited to 1024 tokens + ollama structured outp…
Browse files Browse the repository at this point in the history
…ut + fix browser backend
  • Loading branch information
PeriniM committed Jan 12, 2025
1 parent 1a01912 commit ad693b2
Show file tree
Hide file tree
Showing 12 changed files with 55 additions and 69 deletions.
4 changes: 2 additions & 2 deletions examples/local_models/smart_scraper_ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"temperature": 0,
"format": "json", # Ollama needs the format to be specified explicitly
# "base_url": "http://localhost:11434", # set ollama URL arbitrarily
"model_tokens": 1024,
"model_tokens": 4096,
},
"verbose": True,
"headless": False,
Expand All @@ -25,7 +25,7 @@
# Create the SmartScraperGraph instance and run it
# ************************************************
smart_scraper_graph = SmartScraperGraph(
prompt="Find some information about what does the company do, the name and a contact email.",
prompt="Find some information about what does the company do and the list of founders.",
source="https://scrapegraphai.com/",
config=graph_config,
)
Expand Down
29 changes: 18 additions & 11 deletions examples/local_models/smart_scraper_schema_ollama.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,31 @@
"""
"""
Basic example of scraping pipeline using SmartScraper with schema
"""

import json
from typing import List

from pydantic import BaseModel, Field

from scrapegraphai.graphs import SmartScraperGraph
from scrapegraphai.utils import prettify_exec_info


# ************************************************
# Define the configuration for the graph
# ************************************************
class Project(BaseModel):
title: str = Field(description="The title of the project")
description: str = Field(description="The description of the project")


class Projects(BaseModel):
projects: List[Project]
projects: list[Project]


graph_config = {
"llm": {
"model": "ollama/llama3.1",
"temperature": 0,
"format": "json", # Ollama needs the format to be specified explicitly
# "base_url": "http://localhost:11434", # set ollama URL arbitrarily
},
"llm": {"model": "ollama/llama3.2", "temperature": 0, "model_tokens": 4096},
"verbose": True,
"headless": False
"headless": False,
}

# ************************************************
Expand All @@ -36,8 +36,15 @@ class Projects(BaseModel):
prompt="List me all the projects with their description",
source="https://perinim.github.io/projects/",
schema=Projects,
config=graph_config
config=graph_config,
)

result = smart_scraper_graph.run()
print(json.dumps(result, indent=4))

# ************************************************
# Get graph execution info
# ************************************************

graph_exec_info = smart_scraper_graph.get_execution_info()
print(prettify_exec_info(graph_exec_info))
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ dependencies = [
"googlesearch-python>=1.2.5",
"async-timeout>=4.0.3",
"simpleeval>=1.0.0",
"jsonschema>=4.23.0",
"transformers>=4.46.3",
"jsonschema>=4.23.0"
]

readme = "README.md"
Expand Down
4 changes: 2 additions & 2 deletions scrapegraphai/docloaders/chromium.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,15 @@ def __init__(

dynamic_import(backend, message)

self.backend = backend
self.browser_config = kwargs
self.headless = headless
self.proxy = parse_or_search_proxy(proxy) if proxy else None
self.urls = urls
self.load_state = load_state
self.requires_js_support = requires_js_support
self.storage_state = storage_state
self.browser_name = browser_name
self.backend = kwargs.get("backend", backend)
self.browser_name = kwargs.get("browser_name", browser_name)
self.retry_limit = kwargs.get("retry_limit", retry_limit)
self.timeout = kwargs.get("timeout", timeout)

Expand Down
5 changes: 3 additions & 2 deletions scrapegraphai/graphs/abstract_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,9 @@ def _create_llm(self, llm_config: dict) -> object:
]
except KeyError:
print(
f"""Model {llm_params['model_provider']}/{llm_params['model']} not found,
using default token size (8192)"""
f"""Max input tokens for model {llm_params['model_provider']}/{llm_params['model']} not found,
please specify the model_tokens parameter in the llm section of the graph configuration.
Using default token size: 8192"""
)
self.model_token = 8192
else:
Expand Down
10 changes: 6 additions & 4 deletions scrapegraphai/nodes/generate_answer_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from langchain_community.chat_models import ChatOllama
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.runnables import RunnableParallel
from langchain_openai import AzureChatOpenAI, ChatOpenAI
from langchain_openai import ChatOpenAI
from requests.exceptions import Timeout
from tqdm import tqdm

Expand Down Expand Up @@ -59,7 +59,10 @@ def __init__(
self.llm_model = node_config["llm_model"]

if isinstance(node_config["llm_model"], ChatOllama):
self.llm_model.format = "json"
if node_config.get("schema", None) is None:
self.llm_model.format = "json"
else:
self.llm_model.format = self.node_config["schema"].model_json_schema()

self.verbose = node_config.get("verbose", False)
self.force = node_config.get("force", False)
Expand Down Expand Up @@ -123,8 +126,7 @@ def execute(self, state: dict) -> dict:
format_instructions = ""

if (
isinstance(self.llm_model, (ChatOpenAI, AzureChatOpenAI))
and not self.script_creator
not self.script_creator
or self.force
and not self.script_creator
or self.is_md_scraper
Expand Down
13 changes: 10 additions & 3 deletions scrapegraphai/nodes/generate_answer_node_k_level.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@

from langchain.prompts import PromptTemplate
from langchain_aws import ChatBedrock
from langchain_community.chat_models import ChatOllama
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.runnables import RunnableParallel
from langchain_mistralai import ChatMistralAI
from langchain_openai import AzureChatOpenAI, ChatOpenAI
from langchain_openai import ChatOpenAI
from tqdm import tqdm

from ..prompts import (
Expand Down Expand Up @@ -55,6 +56,13 @@ def __init__(
super().__init__(node_name, "node", input, output, 2, node_config)

self.llm_model = node_config["llm_model"]

if isinstance(node_config["llm_model"], ChatOllama):
if node_config.get("schema", None) is None:
self.llm_model.format = "json"
else:
self.llm_model.format = self.node_config["schema"].model_json_schema()

self.embedder_model = node_config.get("embedder_model", None)
self.verbose = node_config.get("verbose", False)
self.force = node_config.get("force", False)
Expand Down Expand Up @@ -92,8 +100,7 @@ def execute(self, state: dict) -> dict:
format_instructions = ""

if (
isinstance(self.llm_model, (ChatOpenAI, AzureChatOpenAI))
and not self.script_creator
not self.script_creator
or self.force
and not self.script_creator
or self.is_md_scraper
Expand Down
4 changes: 1 addition & 3 deletions scrapegraphai/nodes/parse_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ def execute(self, state: dict) -> dict:
chunks = split_text_into_chunks(
text=docs_transformed.page_content,
chunk_size=self.chunk_size - 250,
model=self.llm_model,
)
else:
docs_transformed = docs_transformed[0]
Expand All @@ -115,11 +114,10 @@ def execute(self, state: dict) -> dict:
chunks = split_text_into_chunks(
text=docs_transformed.page_content,
chunk_size=chunk_size,
model=self.llm_model,
)
else:
chunks = split_text_into_chunks(
text=docs_transformed, chunk_size=chunk_size, model=self.llm_model
text=docs_transformed, chunk_size=chunk_size
)

state.update({self.output[0]: chunks})
Expand Down
14 changes: 5 additions & 9 deletions scrapegraphai/utils/split_text_into_chunks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,10 @@

from typing import List

from langchain_core.language_models.chat_models import BaseChatModel

from .tokenizer import num_tokens_calculus


def split_text_into_chunks(
text: str, chunk_size: int, model: BaseChatModel, use_semchunk=True
) -> List[str]:
def split_text_into_chunks(text: str, chunk_size: int, use_semchunk=True) -> List[str]:
"""
Splits the text into chunks based on the number of tokens.
Expand All @@ -27,17 +23,17 @@ def split_text_into_chunks(
from semchunk import chunk

def count_tokens(text):
return num_tokens_calculus(text, model)
return num_tokens_calculus(text)

chunk_size = min(chunk_size - 500, int(chunk_size * 0.9))
chunk_size = min(chunk_size, int(chunk_size * 0.9))

chunks = chunk(
text=text, chunk_size=chunk_size, token_counter=count_tokens, memoize=False
)
return chunks

else:
tokens = num_tokens_calculus(text, model)
tokens = num_tokens_calculus(text)

if tokens <= chunk_size:
return [text]
Expand All @@ -48,7 +44,7 @@ def count_tokens(text):

words = text.split()
for word in words:
word_tokens = num_tokens_calculus(word, model)
word_tokens = num_tokens_calculus(word)
if current_length + word_tokens > chunk_size:
chunks.append(" ".join(current_chunk))
current_chunk = [word]
Expand Down
28 changes: 4 additions & 24 deletions scrapegraphai/utils/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,35 +2,15 @@
Module for counting tokens and splitting text into chunks
"""

from langchain_core.language_models.chat_models import BaseChatModel
from langchain_mistralai import ChatMistralAI
from langchain_ollama import ChatOllama
from langchain_openai import ChatOpenAI
from .tokenizers.tokenizer_openai import num_tokens_openai


def num_tokens_calculus(string: str, llm_model: BaseChatModel) -> int:
def num_tokens_calculus(string: str) -> int:
"""
Returns the number of tokens in a text string.
"""
if isinstance(llm_model, ChatOpenAI):
from .tokenizers.tokenizer_openai import num_tokens_openai

num_tokens_fn = num_tokens_openai
num_tokens_fn = num_tokens_openai

elif isinstance(llm_model, ChatMistralAI):
from .tokenizers.tokenizer_mistral import num_tokens_mistral

num_tokens_fn = num_tokens_mistral

elif isinstance(llm_model, ChatOllama):
from .tokenizers.tokenizer_ollama import num_tokens_ollama

num_tokens_fn = num_tokens_ollama

else:
from .tokenizers.tokenizer_openai import num_tokens_openai

num_tokens_fn = num_tokens_openai

num_tokens = num_tokens_fn(string, llm_model)
num_tokens = num_tokens_fn(string)
return num_tokens
6 changes: 2 additions & 4 deletions scrapegraphai/utils/tokenizers/tokenizer_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,17 @@
"""

import tiktoken
from langchain_core.language_models.chat_models import BaseChatModel

from ..logging import get_logger


def num_tokens_openai(text: str, llm_model: BaseChatModel) -> int:
def num_tokens_openai(text: str) -> int:
"""
Estimate the number of tokens in a given text using OpenAI's tokenization method,
adjusted for different OpenAI models.
Args:
text (str): The text to be tokenized and counted.
llm_model (BaseChatModel): The specific OpenAI model to adjust tokenization.
Returns:
int: The number of tokens in the text.
Expand All @@ -25,7 +23,7 @@ def num_tokens_openai(text: str, llm_model: BaseChatModel) -> int:

logger.debug(f"Counting tokens for text of {len(text)} characters")

encoding = tiktoken.encoding_for_model("gpt-4")
encoding = tiktoken.encoding_for_model("gpt-4o")

num_tokens = len(encoding.encode(text))
return num_tokens
4 changes: 1 addition & 3 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit ad693b2

Please sign in to comment.