✨ Support generator handlers #320

Merged · 2 commits · Mar 9, 2023
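In short: handlers registered through select() may now be generator functions, sync or async. Each yielded dict is emitted as its own ScrapedData record, with element_index taken from the yield order. A minimal sketch of the feature, modeled on the test fixtures below (the Scraper import path and the example URL are assumptions, not part of this diff):

from typing import Dict, Iterator

import parsel

from dude import Scraper  # assumed public import; the tests receive a Scraper fixture

scraper = Scraper()


@scraper.select(css="body")
def links(selector: parsel.Selector) -> Iterator[Dict]:
    # A generator handler: every yield becomes one saved record.
    for group in selector.css(".custom-group"):
        yield {
            "title": group.css(".title::text").get(),
            "url": group.css(".url::attr(href)").get(),
        }


scraper.run(urls=["https://example.com"], parser="parsel")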
36 changes: 31 additions & 5 deletions dude/base.py
@@ -1,10 +1,12 @@
 import asyncio
 import collections
+import inspect
 import itertools
 import logging
 import time
 from abc import ABC, abstractmethod
 from pathlib import Path
+from types import GeneratorType
 from typing import (
     Any,
     AsyncIterable,
@@ -178,7 +180,7 @@ def wrapper(func: Callable) -> Union[Callable, Coroutine]:
             sel = Selector(selector=selector, css=css, xpath=xpath, text=text, regex=regex)
             assert sel, "Any of selector, css, xpath, text and regex params should be present."

-            if asyncio.iscoroutinefunction(func):
+            if asyncio.iscoroutinefunction(func) or inspect.isasyncgenfunction(func):
                 self.has_async = True

             rule = Rule(
@@ -536,18 +538,29 @@ def extract_all(self, page_number: int, **kwargs: Any) -> Iterable[ScrapedData]:
         for page_url, group_index, group_id, element_index, element, handler in collected_elements:
             data = handler(element)

+            if isinstance(data, GeneratorType):
+                for index, d in enumerate(data):
+                    yield ScrapedData(
+                        page_number=page_number,
+                        page_url=page_url,
+                        group_id=group_id,
+                        group_index=group_index,
+                        element_index=index,
+                        data=d,
+                    )
+                continue
+
             if not data:
                 continue

-            scraped_data = ScrapedData(
+            yield ScrapedData(
                 page_number=page_number,
                 page_url=page_url,
                 group_id=group_id,
                 group_index=group_index,
                 element_index=element_index,
                 data=data,
             )
-            yield scraped_data

     async def extract_all_async(self, page_number: int, **kwargs: Any) -> AsyncIterable[ScrapedData]:
         """
@@ -557,20 +570,33 @@ async def extract_all_async(self, page_number: int, **kwargs: Any) -> AsyncIterable[ScrapedData]:
         collected_elements = [element async for element in self.collect_elements_async(**kwargs)]

         for page_url, group_index, group_id, element_index, element, handler in collected_elements:
+            if inspect.isasyncgenfunction(handler):
+                index = 0
+                async for data in handler(element):
+                    yield ScrapedData(
+                        page_number=page_number,
+                        page_url=page_url,
+                        group_id=group_id,
+                        group_index=group_index,
+                        element_index=index,
+                        data=data,
+                    )
+                    index += 1
+                continue
+
             data = await handler(element)

             if not data:
                 continue

-            scraped_data = ScrapedData(
+            yield ScrapedData(
                 page_number=page_number,
                 page_url=page_url,
                 group_id=group_id,
                 group_index=group_index,
                 element_index=element_index,
                 data=data,
             )
-            yield scraped_data

     def get_scraping_rules(self, url: str) -> Iterable[Rule]:
         return filter(rule_filter(url), self.rules)
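A note on the asymmetry between the two hunks above: the sync path detects a generator only after calling the handler, by checking the returned object, whereas the async path has to inspect the handler itself before calling it, since the result of an async generator function cannot be awaited. A small stdlib-only illustration (handler names are hypothetical):

import asyncio
import inspect
from types import GeneratorType


def sync_handler(element):
    yield {"value": element}


async def async_handler(element):
    yield {"value": element}


# The sync path detects the *returned object*: calling a generator function
# produces a generator, which isinstance() can identify.
assert isinstance(sync_handler("x"), GeneratorType)

# The async path must detect the *function itself*: an async generator
# function is not a coroutine function, so `await handler(element)` would fail.
assert inspect.isasyncgenfunction(async_handler)
assert not asyncio.iscoroutinefunction(async_handler)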
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "pydude"
-version = "0.23.0"
+version = "0.24.0"
 repository = "https://github.com/roniemartinez/dude"
 description = "dude uncomplicated data extraction"
 authors = ["Ronie Martinez <[email protected]>"]
34 changes: 34 additions & 0 deletions tests/conftest.py
@@ -135,6 +135,40 @@ def expected_data(base_url: str) -> List[Dict]:
     ]


+@pytest.fixture()
+def expected_generator_data(base_url: str) -> List[Dict]:
+    is_integer = IsInteger()
+    return [
+        {
+            "_page_number": 1,
+            "_page_url": base_url,
+            "_group_id": is_integer,
+            "_group_index": 0,
+            "_element_index": 0,
+            "url": "url-1.html",
+            "title": "Title 1",
+        },
+        {
+            "_page_number": 1,
+            "_page_url": base_url,
+            "_group_id": is_integer,
+            "_group_index": 0,
+            "_element_index": 1,
+            "url": "url-2.html",
+            "title": "Title 2",
+        },
+        {
+            "_page_number": 1,
+            "_page_url": base_url,
+            "_group_id": is_integer,
+            "_group_index": 0,
+            "_element_index": 2,
+            "url": "url-3.html",
+            "title": "Title 3",
+        },
+    ]
+
+
 @pytest.fixture()
 def expected_browser_data(file_url: str) -> List[Dict]:
     is_integer = IsInteger()
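Worth noting about the fixture above: with a generator handler, _element_index follows the yield order (0, 1, 2) rather than the element's position among matched elements, matching the enumerate()/counter logic in base.py. A tiny standalone illustration (the handler below is hypothetical, not library code):

from types import GeneratorType


def handler(element: str):
    # One matched element fans out into three yielded records.
    for i in (1, 2, 3):
        yield {"url": f"url-{i}.html", "title": f"Title {i}"}


data = handler("body")
assert isinstance(data, GeneratorType)

# element_index is the enumerate() index of each yielded item.
assert [(index, d["title"]) for index, d in enumerate(data)] == [
    (0, "Title 1"),
    (1, "Title 2"),
    (2, "Title 3"),
]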
60 changes: 59 additions & 1 deletion tests/test_parsel.py
@@ -1,4 +1,4 @@
-from typing import Dict, List
+from typing import AsyncIterator, Dict, Iterator, List
 from unittest import mock
 from urllib.parse import urljoin

@@ -88,6 +88,30 @@ def url(selector: parsel.Selector) -> Dict:
         return {"url": selector.get()}


+@pytest.fixture()
+def parsel_generator(scraper_application: Scraper) -> None:
+    @scraper_application.select(css="body")
+    def generator(selector: parsel.Selector) -> Iterator[Dict]:
+        group: parsel.Selector
+        for group in selector.css(".custom-group"):
+            yield {
+                "title": group.css(".title::text").get(),
+                "url": group.css(".url::attr(href)").get(),
+            }
+
+
+@pytest.fixture()
+def async_parsel_generator(scraper_application: Scraper) -> None:
+    @scraper_application.select(css="body")
+    async def generator(selector: parsel.Selector) -> AsyncIterator[Dict]:
+        group: parsel.Selector
+        for group in selector.css(".custom-group"):
+            yield {
+                "title": group.css(".title::text").get(),
+                "url": group.css(".url::attr(href)").get(),
+            }
+
+
 def test_full_flow_parsel(
     scraper_application: Scraper,
     parsel_css: None,
@@ -224,3 +248,37 @@ def test_full_flow_parsel_regex(
     scraper_application.run(urls=[base_url], pages=2, format="custom", parser="parsel")

     mock_database.save.assert_called_with(expected_data)
+
+
+def test_full_flow_parsel_generator(
+    scraper_application: Scraper,
+    parsel_generator: None,
+    expected_generator_data: List[Dict],
+    base_url: str,
+    scraper_save: None,
+    mock_database: mock.MagicMock,
+    mock_httpx: Router,
+) -> None:
+    assert scraper_application.has_async is False
+    assert len(scraper_application.rules) == 1
+
+    scraper_application.run(urls=[base_url], pages=2, format="custom", parser="parsel")
+
+    mock_database.save.assert_called_with(expected_generator_data)
+
+
+def test_full_flow_parsel_async_generator(
+    scraper_application: Scraper,
+    async_parsel_generator: None,
+    expected_generator_data: List[Dict],
+    base_url: str,
+    scraper_save: None,
+    mock_database: mock.MagicMock,
+    mock_httpx: Router,
+) -> None:
+    assert len(scraper_application.rules) == 1
+    assert scraper_application.has_async is True
+
+    scraper_application.run(urls=[base_url], pages=2, format="custom", parser="parsel")
+
+    mock_database.save.assert_called_with(expected_generator_data)