
Merge pull request #754 from ekinsenler/cond_node_refactor
feat: add conditional node to the smart_scraper_graph
VinciGit00 authored Oct 16, 2024
2 parents e0fc457 + eaa83ed commit aaa011c
Showing 8 changed files with 137 additions and 60 deletions.
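In short: when the new reattempt flag is set in the graph configuration, SmartScraperGraph routes the generated answer through a ConditionalNode and, if the answer is empty or "NA", regenerates it once with a second GenerateAnswerNode driven by the new REGEN_ADDITIONAL_INFO prompt. The hard-coded if/elif graph construction in _create_graph is replaced by a graph_variation_config lookup keyed on (html_mode, reasoning, reattempt), and BaseGraph now treats a ConditionalNode without a false branch as the end of the run when its condition fails.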
38 changes: 38 additions & 0 deletions examples/extras/cond_smartscraper_usage.py
@@ -0,0 +1,38 @@
"""
Basic example of a scraping pipeline using SmartScraperGraph with Groq and the new reattempt option
"""

import os
import json
from dotenv import load_dotenv
from scrapegraphai.graphs import SmartScraperGraph

load_dotenv()

# ************************************************
# Define the configuration for the graph
# ************************************************

graph_config = {
"llm": {
"api_key": os.getenv("GROQ_APIKEY"),
"model": "groq/gemma-7b-it",
},
"verbose": True,
"headless": True,
"reattempt": True,  # Allow the graph to reattempt answer generation if the first answer is empty or "NA"
}

# *******************************************************
# Create the SmartScraperGraph instance and run it
# *******************************************************

smart_scraper_graph = SmartScraperGraph(
prompt="Who is Marco Perini?",
source="https://perinim.github.io/",
schema=None,
config=graph_config
)

result = smart_scraper_graph.run()
print(json.dumps(result, indent=4))
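With reattempt enabled, the answer produced by GenerateAnswerNode is checked by the new ConditionalNode (see smart_scraper_graph.py below): if it is empty or "NA", a second GenerateAnswerNode regenerates it using the REGEN_ADDITIONAL_INFO prompt; otherwise the graph finishes and the answer is returned as usual.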
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -38,7 +38,8 @@ dependencies = [
"async-timeout>=4.0.3",
"transformers>=4.44.2",
"googlesearch-python>=1.2.5",
"simpleeval>=1.0.0"
"simpleeval>=1.0.0",
"async_timeout>=4.0.3"
]

license = "MIT"
3 changes: 2 additions & 1 deletion requirements.txt
@@ -18,4 +18,5 @@ undetected-playwright>=0.3.0
semchunk>=1.0.1
langchain-ollama>=0.1.3
simpleeval>=0.9.13
googlesearch-python>=1.2.5
googlesearch-python>=1.2.5
async_timeout>=4.0.3
7 changes: 6 additions & 1 deletion scrapegraphai/graphs/base_graph.py
@@ -95,7 +95,10 @@ def _set_conditional_node_edges(self):
raise ValueError(f"ConditionalNode '{node.node_name}' must have exactly two outgoing edges.")
# Assign true_node_name and false_node_name
node.true_node_name = outgoing_edges[0][1].node_name
node.false_node_name = outgoing_edges[1][1].node_name
try:
node.false_node_name = outgoing_edges[1][1].node_name
except (IndexError, AttributeError):  # single outgoing edge, or an explicit (cond_node, None) edge
node.false_node_name = None

def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
"""
@@ -221,6 +224,8 @@ def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
node_names = {node.node_name for node in self.nodes}
if result in node_names:
current_node_name = result
elif result is None:
current_node_name = None
else:
raise ValueError(f"Conditional Node returned a node name '{result}' that does not exist in the graph")
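Together with the relaxed check in conditional_node.py below, this lets a ConditionalNode be wired with only a true branch: when the condition fails and no false branch exists, the node yields None and the traversal simply stops. A minimal, dependency-free sketch of that routing rule (illustrative only; the condition is shown as a plain callable and the node name "RegenerateAnswer" is hypothetical, whereas the real node evaluates a condition string against the state):

from typing import Callable, Optional

def route(condition: Callable[[dict], bool], state: dict,
          true_node: str, false_node: Optional[str] = None) -> Optional[str]:
    """Return the next node name, or None when the condition fails and no false branch is wired."""
    if condition(state):
        return true_node
    return false_node  # None -> current_node_name becomes None above and the run ends

def needs_retry(state: dict) -> bool:
    # Same predicate as the condition string in smart_scraper_graph.py: 'not answer or answer=="NA"'
    answer = state.get("answer")
    return not answer or answer == "NA"

print(route(needs_retry, {"answer": "NA"}, "RegenerateAnswer"))                            # RegenerateAnswer (retry)
print(route(needs_retry, {"answer": "Marco Perini is a researcher"}, "RegenerateAnswer"))  # None (finish)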

138 changes: 83 additions & 55 deletions scrapegraphai/graphs/smart_scraper_graph.py
@@ -2,16 +2,17 @@
SmartScraperGraph Module
"""
from typing import Optional
import logging
from pydantic import BaseModel
from .base_graph import BaseGraph
from .abstract_graph import AbstractGraph
from ..nodes import (
FetchNode,
ParseNode,
ReasoningNode,
GenerateAnswerNode
GenerateAnswerNode,
ConditionalNode
)
from ..prompts import REGEN_ADDITIONAL_INFO

class SmartScraperGraph(AbstractGraph):
"""
@@ -89,6 +90,28 @@ def _create_graph(self) -> BaseGraph:
}
)

cond_node = None
regen_node = None
if self.config.get("reattempt") is True:
cond_node = ConditionalNode(
input="answer",
output=["answer"],
node_name="ConditionalNode",
node_config={
"key_name": "answer",
"condition": 'not answer or answer=="NA"',
}
)
regen_node = GenerateAnswerNode(
input="user_prompt & answer",
output=["answer"],
node_config={
"llm_model": self.llm_model,
"additional_info": REGEN_ADDITIONAL_INFO,
"schema": self.schema,
}
)

if self.config.get("html_mode") is False:
parse_node = ParseNode(
input="doc",
@@ -99,6 +122,7 @@ def _create_graph(self) -> BaseGraph:
}
)

reasoning_node = None
if self.config.get("reasoning"):
reasoning_node = ReasoningNode(
input="user_prompt & (relevant_chunks | parsed_doc | doc)",
@@ -109,68 +133,72 @@ def _create_graph(self) -> BaseGraph:
"schema": self.schema,
}
)

# Define the graph variation configurations
# (html_mode, reasoning, reattempt)
graph_variation_config = {
(False, True, False): {
"nodes": [fetch_node, parse_node, reasoning_node, generate_answer_node],
"edges": [(fetch_node, parse_node), (parse_node, reasoning_node), (reasoning_node, generate_answer_node)]
},
(True, True, False): {
"nodes": [fetch_node, reasoning_node, generate_answer_node],
"edges": [(fetch_node, reasoning_node), (reasoning_node, generate_answer_node)]
},
(True, False, False): {
"nodes": [fetch_node, generate_answer_node],
"edges": [(fetch_node, generate_answer_node)]
},
(False, False, False): {
"nodes": [fetch_node, parse_node, generate_answer_node],
"edges": [(fetch_node, parse_node), (parse_node, generate_answer_node)]
},
(False, True, True): {
"nodes": [fetch_node, parse_node, reasoning_node, generate_answer_node, cond_node, regen_node],
"edges": [(fetch_node, parse_node), (parse_node, reasoning_node), (reasoning_node, generate_answer_node),
(generate_answer_node, cond_node), (cond_node, regen_node), (cond_node, None)]
},
(True, True, True): {
"nodes": [fetch_node, reasoning_node, generate_answer_node, cond_node, regen_node],
"edges": [(fetch_node, reasoning_node), (reasoning_node, generate_answer_node),
(generate_answer_node, cond_node), (cond_node, regen_node), (cond_node, None)]
},
(True, False, True): {
"nodes": [fetch_node, generate_answer_node, cond_node, regen_node],
"edges": [(fetch_node, generate_answer_node), (generate_answer_node, cond_node),
(cond_node, regen_node), (cond_node, None)]
},
(False, False, True): {
"nodes": [fetch_node, parse_node, generate_answer_node, cond_node, regen_node],
"edges": [(fetch_node, parse_node), (parse_node, generate_answer_node),
(generate_answer_node, cond_node), (cond_node, regen_node), (cond_node, None)]
}
}

if self.config.get("html_mode") is False and self.config.get("reasoning") is True:

return BaseGraph(
nodes=[
fetch_node,
parse_node,
reasoning_node,
generate_answer_node,
],
edges=[
(fetch_node, parse_node),
(parse_node, reasoning_node),
(reasoning_node, generate_answer_node)
],
entry_point=fetch_node,
graph_name=self.__class__.__name__
)

elif self.config.get("html_mode") is True and self.config.get("reasoning") is True:

return BaseGraph(
nodes=[
fetch_node,
reasoning_node,
generate_answer_node,
],
edges=[
(fetch_node, reasoning_node),
(reasoning_node, generate_answer_node)
],
entry_point=fetch_node,
graph_name=self.__class__.__name__
)

elif self.config.get("html_mode") is True and self.config.get("reasoning") is False:
return BaseGraph(
nodes=[
fetch_node,
generate_answer_node,
],
edges=[
(fetch_node, generate_answer_node)
],
entry_point=fetch_node,
graph_name=self.__class__.__name__
)

return BaseGraph(
nodes=[
fetch_node,
parse_node,
generate_answer_node,
],
edges=[
(fetch_node, parse_node),
(parse_node, generate_answer_node)
],
entry_point=fetch_node,
graph_name=self.__class__.__name__
)

# Get the current conditions
html_mode = self.config.get("html_mode", False)
reasoning = self.config.get("reasoning", False)
reattempt = self.config.get("reattempt", False)

# Retrieve the appropriate graph configuration
config = graph_variation_config.get((html_mode, reasoning, reattempt))

if config:
return BaseGraph(
nodes=config["nodes"],
edges=config["edges"],
entry_point=fetch_node,
graph_name=self.__class__.__name__
)

# Default return if no conditions match
return BaseGraph(
nodes=[fetch_node, parse_node, generate_answer_node],
edges=[(fetch_node, parse_node), (parse_node, generate_answer_node)],
entry_point=fetch_node,
graph_name=self.__class__.__name__
)

def run(self) -> str:
"""
Executes the scraping process and returns the answer to the prompt.
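For orientation, a minimal sketch of how the three configuration flags select a row of graph_variation_config above; the flag values here are illustrative, and the llm entry is taken from the example script:

import os

graph_config = {
    "llm": {"api_key": os.getenv("GROQ_APIKEY"), "model": "groq/gemma-7b-it"},
    "html_mode": False,   # False: a ParseNode is inserted after FetchNode
    "reasoning": False,   # True: a ReasoningNode runs before GenerateAnswerNode
    "reattempt": True,    # True (new): ConditionalNode + regeneration GenerateAnswerNode are appended
}

key = (
    graph_config.get("html_mode", False),
    graph_config.get("reasoning", False),
    graph_config.get("reattempt", False),
)
# key == (False, False, True) selects the variation:
# fetch -> parse -> generate_answer -> cond -> regen (answer empty or "NA") or end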
2 changes: 1 addition & 1 deletion scrapegraphai/nodes/conditional_node.py
@@ -61,7 +61,7 @@ def execute(self, state: dict) -> dict:
str: The name of the next node to execute based on the presence of the key.
"""

if self.true_node_name is None or self.false_node_name is None:
if self.true_node_name is None:
raise ValueError("ConditionalNode's next nodes are not set properly.")

if self.condition:
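With this relaxed check, only the true branch is mandatory: when the condition evaluates falsy and no false branch was wired, the node returns None as the next node, which the new handling in base_graph.py above interprets as the end of the run.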
2 changes: 1 addition & 1 deletion scrapegraphai/prompts/__init__.py
@@ -5,7 +5,7 @@
from .generate_answer_node_prompts import (TEMPLATE_CHUNKS,
TEMPLATE_NO_CHUNKS,
TEMPLATE_MERGE, TEMPLATE_CHUNKS_MD,
TEMPLATE_NO_CHUNKS_MD, TEMPLATE_MERGE_MD)
TEMPLATE_NO_CHUNKS_MD, TEMPLATE_MERGE_MD, REGEN_ADDITIONAL_INFO)
from .generate_answer_node_csv_prompts import (TEMPLATE_CHUKS_CSV,
TEMPLATE_NO_CHUKS_CSV,
TEMPLATE_MERGE_CSV)
4 changes: 4 additions & 0 deletions scrapegraphai/prompts/generate_answer_node_prompts.py
@@ -86,3 +86,7 @@
USER QUESTION: {question}\n
WEBSITE CONTENT: {context}\n
"""

REGEN_ADDITIONAL_INFO = """
You are a scraper and you have just failed to scrape the requested information from a website. \n
I want you to try again and provide the missing information. \n"""
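This prompt is passed as the additional_info of the regeneration GenerateAnswerNode that smart_scraper_graph.py creates when reattempt is enabled.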
