Skip to content

Commit

Permalink
Dataset size on CLI (#345)
Browse files Browse the repository at this point in the history
* init

* lint

* fix

* lint

* lint

* rm print stmt

* Update dataset_utils.py
  • Loading branch information
hitenvidhani authored Sep 14, 2023
1 parent 779c7b9 commit c39b68a
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from prompt2model.dataset_retriever.base import DatasetInfo, DatasetRetriever
from prompt2model.prompt_parser import PromptSpec
from prompt2model.utils import encode_text, retrieve_objects
from prompt2model.utils.dataset_utils import get_dataset_size

datasets.utils.logging.disable_progress_bar()
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -115,10 +116,12 @@ def choose_dataset_by_cli(self, top_datasets: list[DatasetInfo]) -> str | None:
"""
self._print_divider()
print("Here are the datasets I've retrieved for you:")
print("#\tName\tDescription")
print("#\tName\tSize[MB]\tDescription")
for i, d in enumerate(top_datasets):
description_no_spaces = d.description.replace("\n", " ")
print(f"{i+1}):\t{d.name}\t{description_no_spaces}")
description_no_space = d.description.replace("\n", " ")
print(
f"{i+1}):\t{d.name}\t{get_dataset_size(d.name)}\t{description_no_space}"
)

self._print_divider()
print(
Expand Down
33 changes: 33 additions & 0 deletions prompt2model/utils/dataset_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""Util functions for datasets."""

import requests

from prompt2model.utils.logging_utils import get_formatted_logger

logger = get_formatted_logger("dataset_utils")


def query(API_URL, timeout=10.0):
    """Return the response JSON for a URL, or an empty dict on failure.

    Args:
        API_URL: The URL to send a GET request to.
        timeout: Seconds to wait for the server before giving up. This is a
            new, backward-compatible parameter: ``requests.get`` never times
            out by default, so without it a stalled connection would hang
            the caller indefinitely.

    Returns:
        The decoded JSON body on HTTP 200; an empty dict on any HTTP error
        status or request failure (the error is logged, not raised).
    """
    try:
        # Timeout errors are a RequestException subclass, so the except
        # clause below also covers the new timeout behavior.
        response = requests.get(API_URL, timeout=timeout)
        if response.status_code == 200:
            return response.json()
        else:
            logger.error(f"Error occurred in fetching size: {response.status_code}")
    except requests.exceptions.RequestException as e:
        logger.error("Error occurred in making the request: " + str(e))

    return {}


def get_dataset_size(dataset_name):
    """Fetches dataset size for a dataset in MB from hugging face API.

    Args:
        dataset_name: Name of the dataset on the Hugging Face Hub.

    Returns:
        The in-memory dataset size in MB formatted to two decimal places
        (e.g. "1.28"), or "NA" when the API call fails or the response
        carries no size information.
    """
    API_URL = f"https://datasets-server.huggingface.co/size?dataset={dataset_name}"
    data = query(API_URL)
    size_dict = data.get("size", {})
    # BUG FIX: the original used `size_dict is {}`, an identity comparison
    # against a fresh dict literal, which is always False — so a failed API
    # call (query returns {}) raised KeyError instead of returning "NA".
    # A truthiness check handles the empty-dict failure path correctly.
    if not size_dict:
        return "NA"
    return "{:.2f}".format(size_dict["dataset"]["num_bytes_memory"] / 1024 / 1024)
64 changes: 64 additions & 0 deletions prompt2model/utils/dataset_utils_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""Testing dataset utility functions."""
from unittest.mock import patch

from prompt2model.utils import dataset_utils


@patch("prompt2model.utils.dataset_utils.query")
def test_get_dataset_size(mock_query):
    """Test function for get_dataset_size."""
    # Per-split statistics: (split name, parquet bytes, memory bytes, rows).
    split_stats = [
        ("train", 698845, 1074806, 8530),
        ("validation", 90001, 134675, 1066),
        ("test", 92206, 135968, 1066),
    ]
    splits = [
        {
            "dataset": "rotten_tomatoes",
            "config": "default",
            "split": split_name,
            "num_bytes_parquet_files": parquet_bytes,
            "num_bytes_memory": memory_bytes,
            "num_rows": row_count,
            "num_columns": 2,
        }
        for split_name, parquet_bytes, memory_bytes, row_count in split_stats
    ]
    # Canned payload mirroring the datasets-server /size endpoint response.
    fake_response = {
        "size": {
            "dataset": {
                "dataset": "rotten_tomatoes",
                "num_bytes_original_files": 487770,
                "num_bytes_parquet_files": 881052,
                "num_bytes_memory": 1345449,
                "num_rows": 10662,
            },
            "configs": [
                {
                    "dataset": "rotten_tomatoes",
                    "config": "default",
                    "num_bytes_original_files": 487770,
                    "num_bytes_parquet_files": 881052,
                    "num_bytes_memory": 1345449,
                    "num_rows": 10662,
                    "num_columns": 2,
                }
            ],
            "splits": splits,
        },
        "pending": [],
        "failed": [],
        "partial": False,
    }
    mock_query.return_value = fake_response
    # 1345449 bytes / 1024 / 1024 ≈ 1.28 MB, formatted to two decimals.
    assert dataset_utils.get_dataset_size("rotten_tomatoes") == "1.28"

0 comments on commit c39b68a

Please sign in to comment.