diff --git a/README.md b/README.md index bb4facb..7dc9df1 100644 --- a/README.md +++ b/README.md @@ -77,7 +77,7 @@ pip install . ## 命令行使用示例 ```shell -export CSG_TOKEN=3b77c98077b415ca381ded189b86d5df226e3776 +export CSG_TOKEN=your_access_token # 模型下载 csghub-cli download wanghh2000/myprivate1 @@ -100,19 +100,19 @@ csghub-cli upload wanghh2000/myds1 abc/4.txt abc/5.txt -t dataset ```python from pycsghub.snapshot_download import snapshot_download -token = "3b77c98077b415ca381ded189b86d5df226e3776" +token = "your_access_token" endpoint = "https://hub.opencsg.com" repo_type = "model" repo_id = 'OpenCSG/csg-wukong-1B' cache_dir = '/Users/hhwang/temp/' -result = snapshot_download(repo_id, cache_dir=cache_dir, endpoint=endpoint, token=token, repo_type=repotype) +result = snapshot_download(repo_id, repo_type=repo_type, cache_dir=cache_dir, endpoint=endpoint, token=token,) ``` ### 数据集下载 ```python from pycsghub.snapshot_download import snapshot_download -token = "3b77c98077b415ca381ded189b86d5df226e3776" +token = "your_access_token" endpoint = "https://hub.opencsg.com" repo_id = 'AIWizards/tmmluplus' @@ -127,7 +127,7 @@ result = snapshot_download(repo_id, repo_type=repo_type, cache_dir=cache_dir, en ```python from pycsghub.file_download import file_download -token = "3b77c98077b415ca381ded189b86d5df226e3776" +token = "your_access_token" endpoint = "https://hub.opencsg.com" repo_type = "model" @@ -140,7 +140,7 @@ result = file_download(repo_id, file_name='README.md', cache_dir=cache_dir, endp ```python from pycsghub.file_download import http_get -token = "3b77c98077b415ca381ded189b86d5df226e3776" +token = "your_access_token" url = "https://hub.opencsg.com/api/v1/models/OpenCSG/csg-wukong-1B/resolve/tokenizer.model" local_dir = '/home/test/' @@ -155,7 +155,7 @@ http_get(url=url, token=token, local_dir=local_dir, file_name=file_name, headers ```python from pycsghub.file_upload import http_upload_file -token = "3b77c98077b415ca381ded189b86d5df226e3776" +token = 
"your_access_token" endpoint = "https://hub.opencsg.com" repo_type = "model" @@ -168,7 +168,7 @@ result = http_upload_file(repo_id, endpoint=endpoint, token=token, repo_type='mo ```python from pycsghub.file_upload import http_upload_file -token = "3b77c98077b415ca381ded189b86d5df226e3776" +token = "your_access_token" endpoint = "https://hub.opencsg.com" repo_type = "model" diff --git a/README_EN.md b/README_EN.md index 9066ee4..9f1f944 100644 --- a/README_EN.md +++ b/README_EN.md @@ -46,7 +46,7 @@ After installation, you can begin using the SDK to connect to your CSGHub server import os from pycsghub.repo_reader import AutoModelForCausalLM, AutoTokenizer -os.environ['CSG_TOKEN'] = '3b77c98077b415ca381ded189b86d5df226e3776' +os.environ['CSG_TOKEN'] = 'your_access_token' mid = 'OpenCSG/csg-wukong-1B' model = AutoModelForCausalLM.from_pretrained(mid) @@ -77,7 +77,7 @@ pip install . ## Use cases of command line ```shell -export CSG_TOKEN=3b77c98077b415ca381ded189b86d5df226e3776 +export CSG_TOKEN=your_access_token # download model csghub-cli download wanghh2000/myprivate1 @@ -102,7 +102,7 @@ For more detailed instructions, including API documentation and usage examples, ```python from pycsghub.snapshot_download import snapshot_download -token = "3b77c98077b415ca381ded189b86d5df226e3776" +token = "your_access_token" endpoint = "https://hub.opencsg.com" repo_id = 'OpenCSG/csg-wukong-1B' @@ -127,7 +127,7 @@ Use `http_get` function to download single file ```python from pycsghub.file_download import http_get -token = "3b77c98077b415ca381ded189b86d5df226e3776" +token = "your_access_token" url = "https://hub.opencsg.com/api/v1/models/OpenCSG/csg-wukong-1B/resolve/tokenizer.model" local_dir = '/home/test/' @@ -141,7 +141,7 @@ use `file_download` function to download single file from a repository ```python from pycsghub.file_download import file_download -token = "3b77c98077b415ca381ded189b86d5df226e3776" +token = "your_access_token" endpoint = "https://hub.opencsg.com" repo_id 
= 'OpenCSG/csg-wukong-1B' @@ -154,7 +154,7 @@ result = file_download(repo_id, file_name='README.md', cache_dir=cache_dir, endp ```python from pycsghub.file_upload import http_upload_file -token = "3b77c98077b415ca381ded189b86d5df226e3776" +token = "your_access_token" endpoint = "https://hub.opencsg.com" repo_type = "model" @@ -167,7 +167,7 @@ result = http_upload_file(repo_id, endpoint=endpoint, token=token, repo_type='mo ```python from pycsghub.file_upload import http_upload_file -token = "3b77c98077b415ca381ded189b86d5df226e3776" +token = "your_access_token" endpoint = "https://hub.opencsg.com" repo_type = "model" diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 0000000..bdd541c --- /dev/null +++ b/examples/README.md @@ -0,0 +1,25 @@ +# Examples + +我们提供了大量示例脚本,用于通过 CSGHub SDK 与 CSGHub 服务器进行交互。 + +虽然我们努力展示尽可能多的用例,但预计它们不会在您的特定问题上开箱即用,并且您需要更改几行代码以适应您的需求。为了帮助您,大多数示例完全公开了数据的预处理,允许您根据需要进行调整和编辑。 + +## Important note + +**Important** + +为了确保您能够成功运行最新版本的示例脚本,您需要**从源代码安装库**。为此,请在新虚拟环境中执行以下步骤: + +```shell +git clone https://github.com/OpenCSGs/csghub-sdk.git +cd csghub-sdk +pip install . +``` + +运行示例脚本前,请先设置必要的环境变量如下。 + +```shell +export HF_ENDPOINT="https://hub.opencsg.com" +``` + +你可以根据自己的需求调整脚本。 diff --git a/examples/README_EN.md b/examples/README_EN.md new file mode 100644 index 0000000..d2e9bd3 --- /dev/null +++ b/examples/README_EN.md @@ -0,0 +1,25 @@ +# Examples + +We host a wide range of example scripts that use the CSGHub SDK to interact with the CSGHub server. + +While we strive to present as many use cases as possible, it is expected that they won't work out-of-the-box on your specific problem and that you will be required to change a few lines of code to adapt them to your needs. To help you with that, most of the examples fully expose the preprocessing of the data, allowing you to tweak and edit them as required. 
+ +## Important note + +**Important** + +To make sure you can successfully run the latest versions of the example scripts, you have to **install the library from source**. To do this, execute the following steps in a new virtual environment: + +```shell +git clone https://github.com/OpenCSGs/csghub-sdk.git +cd csghub-sdk +pip install . +``` + +Before running the example script, please set the necessary environment variables as follows. + +```shell +export HF_ENDPOINT="https://hub.opencsg.com" +``` + +You can also adapt the script to your own needs. diff --git a/examples/download_dataset.py b/examples/download_dataset.py new file mode 100644 index 0000000..2788c1a --- /dev/null +++ b/examples/download_dataset.py @@ -0,0 +1,9 @@ +from pycsghub.snapshot_download import snapshot_download +# token = "your access token" +token = None + +endpoint = "https://hub.opencsg.com" +repo_id = 'OpenDataLab/CodeExp' +repo_type = "dataset" +cache_dir = '/Users/hhwang/temp/' +result = snapshot_download(repo_id, repo_type=repo_type, cache_dir=cache_dir, endpoint=endpoint, token=token) diff --git a/examples/download_file.py b/examples/download_file.py new file mode 100644 index 0000000..7c28088 --- /dev/null +++ b/examples/download_file.py @@ -0,0 +1,9 @@ +from pycsghub.file_download import file_download +# token = "your access token" +token = None + +endpoint = "https://hub.opencsg.com" +repo_type = "model" +repo_id = 'OpenCSG/csg-wukong-1B' +cache_dir = '/Users/hhwang/temp/' +result = file_download(repo_id, file_name='README.md', cache_dir=cache_dir, endpoint=endpoint, token=token, repo_type=repo_type) diff --git a/examples/download_model.py b/examples/download_model.py new file mode 100644 index 0000000..96d9e3a --- /dev/null +++ b/examples/download_model.py @@ -0,0 +1,9 @@ +from pycsghub.snapshot_download import snapshot_download +# token = "your access token" +token = None + +endpoint = "https://hub.opencsg.com" +repo_type = "model" +repo_id = 'OpenCSG/csg-wukong-1B' +cache_dir = 
'/Users/hhwang/temp/' +result = snapshot_download(repo_id, repo_type=repo_type, cache_dir=cache_dir, endpoint=endpoint, token=token) diff --git a/examples/load_dataset.py b/examples/load_dataset.py new file mode 100644 index 0000000..487ff32 --- /dev/null +++ b/examples/load_dataset.py @@ -0,0 +1,11 @@ +# from datasets.load import load_dataset +from pycsghub.repo_reader import load_dataset + +dsPath = "wanghh2000/glue" +dsName = "mrpc" + +# access_token = "your_access_token" +access_token = None + +raw_datasets = load_dataset(path=dsPath, name=dsName, token=access_token) +print('raw_datasets', raw_datasets) diff --git a/examples/run_finetune_bert.py b/examples/run_finetune_bert.py new file mode 100644 index 0000000..7e7b7f5 --- /dev/null +++ b/examples/run_finetune_bert.py @@ -0,0 +1,65 @@ +from typing import Any +import pandas as pd + +from transformers import DataCollatorWithPadding +from transformers import TrainingArguments +from transformers import Trainer + +from pycsghub.repo_reader import load_dataset +from pycsghub.repo_reader import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig + +model_id_or_path = "wanghh2000/bert-base-uncased" +tokenizer = AutoTokenizer.from_pretrained(model_id_or_path, trust_remote_code=True) +model = AutoModelForSequenceClassification.from_pretrained(model_id_or_path) + +dsPath = "wanghh2000/glue" +dsName = "mrpc" +# access_token = "your_access_token" +access_token = None +raw_datasets = load_dataset(dsPath, dsName, token=access_token) + +def get_data_proprocess() -> Any: + def preprocess_function(examples: pd.DataFrame): + ret = tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=100) + ret = {**examples, **ret} + return pd.DataFrame.from_dict(ret) + return preprocess_function + +train_dataset = raw_datasets["train"].select(range(20)).map(get_data_proprocess(), batched=True) +eval_dataset = raw_datasets["validation"].select(range(20)).map(get_data_proprocess(), batched=True) + +def 
data_collator() -> Any: + data_collator = DataCollatorWithPadding(tokenizer=tokenizer) + return data_collator + +outputDir = "/Users/hhwang/temp/ff" +args = TrainingArguments( + outputDir, + evaluation_strategy="steps", + save_strategy="steps", + logging_strategy="steps", + logging_steps = 2, + save_steps = 10, + eval_steps = 2, + learning_rate=2e-5, + per_device_train_batch_size=4, + per_device_eval_batch_size=4, + num_train_epochs=2, + weight_decay=0.01, + push_to_hub=False, + disable_tqdm=False, # declutter the output a little + use_cpu=True, # you need to explicitly set no_cuda if you want CPUs + remove_unused_columns=True, +) + +trainer = Trainer( + model, + args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + tokenizer=tokenizer, +) + +trainResult = trainer.train() +trainer.save_model() +print(f"save model to {outputDir}") diff --git a/examples/run_wukong_inference.py b/examples/run_wukong_inference.py new file mode 100644 index 0000000..48cc3b8 --- /dev/null +++ b/examples/run_wukong_inference.py @@ -0,0 +1,10 @@ +import os +from pycsghub.repo_reader import AutoModelForCausalLM, AutoTokenizer + +mid = 'OpenCSG/csg-wukong-1B' +model = AutoModelForCausalLM.from_pretrained(mid) +tokenizer = AutoTokenizer.from_pretrained(mid) + +inputs = tokenizer.encode("Write a short story", return_tensors="pt") +outputs = model.generate(inputs) +print('result: ',tokenizer.batch_decode(outputs)) diff --git a/examples/upload_file.py b/examples/upload_file.py new file mode 100644 index 0000000..434f304 --- /dev/null +++ b/examples/upload_file.py @@ -0,0 +1,8 @@ +from pycsghub.file_upload import http_upload_file + +token = "your_access_token" + +endpoint = "https://hub.opencsg.com" +repo_type = "model" +repo_id = 'wanghh2000/myprivate1' +result = http_upload_file(repo_id, endpoint=endpoint, token=token, repo_type='model', file_path='README.md') diff --git a/pycsghub/file_download.py b/pycsghub/file_download.py index ce2ea58..4ceea39 100644 --- 
a/pycsghub/file_download.py +++ b/pycsghub/file_download.py @@ -79,7 +79,7 @@ def file_download( " online, set 'local_files_only' to False.") return cache.get_root_location() else: - download_endpoint = endpoint if endpoint is not None else get_endpoint() + download_endpoint = get_endpoint(endpoint=endpoint) # make headers # todo need to add cookies? repo_info = utils.get_repo_info(repo_id=repo_id, @@ -205,10 +205,9 @@ def http_get(*, if __name__ == '__main__': - token = "f3a7b9c1d6e5f8e2a1b5d4f9e6a2b8d7c3a4e2b1d9f6e7a8d2c5a7b4c1e3f5b8a1d4f9" + \ - "b7d6e2f8a5d3b1e7f9c6a8b2d1e4f7d5b6e9f2a4b3c8e1d7f995hd82hf" + token = "your_access_token" - url = "https://hub-stg.opencsg.com/api/v1/models/wayne0019/lwfmodel/resolve/lfsfile.bin" + url = "https://hub.opencsg.com/api/v1/models/wayne0019/lwfmodel/resolve/lfsfile.bin" local_dir = '/home/test/' file_name = 'test.txt' headers = None diff --git a/pycsghub/repo_reader/__init__.py b/pycsghub/repo_reader/__init__.py index 5131aa5..52c7ac5 100644 --- a/pycsghub/repo_reader/__init__.py +++ b/pycsghub/repo_reader/__init__.py @@ -1 +1,2 @@ -from .model.huggingface.model_auto import * \ No newline at end of file +from .model.huggingface.model_auto import * +from .dataset.huggingface.load import * \ No newline at end of file diff --git a/pycsghub/repo_reader/dataset/huggingface/load.py b/pycsghub/repo_reader/dataset/huggingface/load.py new file mode 100644 index 0000000..13b254c --- /dev/null +++ b/pycsghub/repo_reader/dataset/huggingface/load.py @@ -0,0 +1,69 @@ +from typing import Dict, Mapping, Optional, Sequence, Union +import datasets +from datasets.splits import Split +from datasets.features import Features +from datasets.download.download_config import DownloadConfig +from datasets.download.download_manager import DownloadMode +from datasets.utils.info_utils import VerificationMode +from datasets.utils.version import Version +from datasets.iterable_dataset import IterableDataset +from datasets.dataset_dict import DatasetDict, 
IterableDatasetDict +from datasets.arrow_dataset import Dataset +from pycsghub.snapshot_download import snapshot_download +from pycsghub.utils import get_token_to_send +from pycsghub.constants import REPO_TYPE_DATASET + +def load_dataset( + path: str, + name: Optional[str] = None, + data_dir: Optional[str] = None, + data_files: Optional[Union[str, Sequence[str], Mapping[str, Union[str, Sequence[str]]]]] = None, + split: Optional[Union[str, Split]] = None, + cache_dir: Optional[str] = None, + features: Optional[Features] = None, + download_config: Optional[DownloadConfig] = None, + download_mode: Optional[Union[DownloadMode, str]] = None, + verification_mode: Optional[Union[VerificationMode, str]] = None, + ignore_verifications="deprecated", + keep_in_memory: Optional[bool] = None, + save_infos: bool = False, + revision: Optional[Union[str, Version]] = None, + token: Optional[Union[bool, str]] = None, + use_auth_token="deprecated", + task="deprecated", + streaming: bool = False, + num_proc: Optional[int] = None, + storage_options: Optional[Dict] = None, + trust_remote_code: bool = None, + **config_kwargs, +) -> Union[DatasetDict, Dataset, IterableDatasetDict, IterableDataset]: + if token is None: + try: + token = get_token_to_send(None) + except Exception: + pass + localPath = snapshot_download(path, repo_type=REPO_TYPE_DATASET, cache_dir=cache_dir, token=token) + return datasets.load.load_dataset( + path=localPath, + name=name, + data_dir=data_dir, + data_files=data_files, + split=split, + cache_dir=cache_dir, + features=features, + download_config=download_config, + download_mode=download_mode, + verification_mode=verification_mode, + ignore_verifications=ignore_verifications, + keep_in_memory=keep_in_memory, + save_infos=save_infos, + revision=revision, + token=token, + use_auth_token=use_auth_token, + task=task, + streaming=streaming, + num_proc=num_proc, + storage_options=storage_options, + trust_remote_code=trust_remote_code, + **config_kwargs + ) \ No newline 
at end of file diff --git a/pycsghub/snapshot_download.py b/pycsghub/snapshot_download.py index 54b676b..adebf9b 100644 --- a/pycsghub/snapshot_download.py +++ b/pycsghub/snapshot_download.py @@ -55,7 +55,7 @@ def snapshot_download( " online, set 'local_files_only' to False.") return cache.get_root_location() else: - download_endpoint = endpoint if endpoint is not None else get_endpoint() + download_endpoint = get_endpoint(endpoint=endpoint) # make headers # todo need to add cookies? repo_info = utils.get_repo_info(repo_id, @@ -74,15 +74,12 @@ def snapshot_download( ) ) - with tempfile.TemporaryDirectory( - dir=temporary_cache_dir) as temp_cache_dir: + with tempfile.TemporaryDirectory(dir=temporary_cache_dir) as temp_cache_dir: for repo_file in repo_files: repo_file_info = pack_repo_file_info(repo_file, revision) if cache.exists(repo_file_info): file_name = os.path.basename(repo_file_info['Path']) - print( - f'File {file_name} already in cache, skip downloading!' - ) + print(f"File {file_name} already in cache '{cache.get_root_location()}', skip downloading!") continue # get download url @@ -103,7 +100,8 @@ def snapshot_download( # todo using hash to check file integrity temp_file = os.path.join(temp_cache_dir, repo_file) - cache.put_file(repo_file_info, temp_file) - + savedFile = cache.put_file(repo_file_info, temp_file) + print(f"Saved file to '{savedFile}'") + cache.save_model_version(revision_info={'Revision': revision}) return os.path.join(cache.get_root_location()) diff --git a/pycsghub/test/snapshot_download_test.py b/pycsghub/test/snapshot_download_test.py index 76411e5..51971a6 100644 --- a/pycsghub/test/snapshot_download_test.py +++ b/pycsghub/test/snapshot_download_test.py @@ -9,7 +9,7 @@ def test_something(self): self.assertEqual(True, False) # add assertion here def test_snapshot_download(self): - token = ("4e5b97a59c1f8a954954971bf1cdbf3ce61a35sd5") + token = ("your_access_token") endpoint = "https://hub.opencsg.com" repo_id = 'OpenCSG/csg-wukong-1B' 
cache_dir = '/home/test4' @@ -20,9 +20,8 @@ def test_snapshot_download(self): print(result) def test_singlefile_download(self): - token = ("f3a7b9c1d6e5f8e2a1b5d4f9e6a2b8d7c3a4e2b1d9f6e7a8d2c5a7b4c1e3f5b8a1d4f" - "9b7d6e2f8a5d3b1e7f9c6a8b2d1e4f7d5b6e9f2a4b3c8e1d7f995hd82hf") - endpoint = "https://hub-stg.opencsg.com" + token = ("your_access_token") + endpoint = "https://hub.opencsg.com" repo_id = 'wayne0019/lwfmodel' cache_dir = '/home/test6' result = file_download(repo_id, @@ -33,7 +32,7 @@ def test_singlefile_download(self): print(result) def test_singlefile_download_not_exist(self): - token = ("4e5b97a59c1f8a954954971bf1cdbf3ce61a35d5") + token = ("your_access_token") endpoint = "https://hub.opencsg.com" repo_id = 'OpenCSG/csg-wukong-1B' cache_dir = '/home/test5' @@ -47,7 +46,7 @@ def test_singlefile_download_not_exist(self): self.assertEqual(str(e), "file wolegequ.hehe not in repo wayne0019/lwfmodel") def test_snapshot_download(self): - token = ("4e5b97a59c1f8a954954971bf1cdbf3ce61a35sd5") + token = ("your_access_token") endpoint = "https://hub.opencsg.com" repo_id = 'AIWizards/tmmluplus' cache_dir = '~/Downloads/' diff --git a/pycsghub/test/utils_test.py b/pycsghub/test/utils_test.py index aea873f..d4ca215 100644 --- a/pycsghub/test/utils_test.py +++ b/pycsghub/test/utils_test.py @@ -2,9 +2,8 @@ from pycsghub.utils import model_info class MyTestCase(unittest.TestCase): - token = "f3a7b9c1d6e5f8e2a1b5d4f9e6a2b8d7c3a4e2b1d9f6e7a8d2c5a7b4c1e3f5b8a1d4f9" + \ - "b7d6e2f8a5d3b1e7f9c6a8b2d1e4f7d5b6e9f2a4b3c8e1d7f995hd82hf" - endpoint = "https://hub-stg.opencsg.com" + token = "your_access_token" + endpoint = "https://hub.opencsg.com" repo_id = 'wayne0019/lwfmodel' def test_something(self): self.assertEqual(True, False) # add assertion here diff --git a/pycsghub/utils.py b/pycsghub/utils.py index bdc6ebb..3cad08f 100644 --- a/pycsghub/utils.py +++ b/pycsghub/utils.py @@ -76,8 +76,7 @@ def get_cache_dir(model_id: Optional[str] = None, repo_type: Optional[str] = Non 
sub_dir = 'hub' if repo_type == "dataset": sub_dir = 'dataset' - base_path = os.getenv('CSGHUB_CACHE', - os.path.join(default_cache_dir, sub_dir)) + base_path = os.getenv('CSGHUB_CACHE', os.path.join(default_cache_dir, sub_dir)) return base_path if model_id is None else os.path.join( base_path, model_id + '/') @@ -205,18 +204,11 @@ def dataset_info( """ headers = build_csg_headers(token=token) - path = ( - f"{endpoint}/hf/api/datasets/{repo_id}" - if revision is None - else (f"{endpoint}/hf/api/datasets/{repo_id}/revision/{quote(revision, safe='')}") - ) + path = get_repo_meta_path(repo_type=REPO_TYPE_DATASET, repo_id=repo_id, revision=revision, endpoint=endpoint) params = {} if files_metadata: params["blobs"] = True - r = requests.get(path, - headers=headers, - timeout=timeout, - params=params) + r = requests.get(path, headers=headers, timeout=timeout, params=params) r.raise_for_status() data = r.json() return DatasetInfo(**data) @@ -270,18 +262,11 @@ def space_info( """ headers = build_csg_headers(token=token) - path = ( - f"{endpoint}/hf/api/spaces/{repo_id}" - if revision is None - else (f"{endpoint}/hf/api/spaces/{repo_id}/revision/{quote(revision, safe='')}") - ) + path = get_repo_meta_path(repo_type=REPO_TYPE_SPACE, repo_id=repo_id, revision=revision, endpoint=endpoint) params = {} if files_metadata: params["blobs"] = True - r = requests.get(path, - headers=headers, - timeout=timeout, - params=params) + r = requests.get(path, headers=headers, timeout=timeout, params=params) r.raise_for_status() data = r.json() return SpaceInfo(**data) @@ -339,29 +324,27 @@ def model_info( """ headers = build_csg_headers(token=token) - path = ( - f"{endpoint}/hf/api/models/{repo_id}/revision/main" - if revision is None - else f"{endpoint}/hf/api/models/{repo_id}/revision/{quote(revision, safe='')}" - ) + path = get_repo_meta_path(repo_type=REPO_TYPE_MODEL, repo_id=repo_id, revision=revision, endpoint=endpoint) params = {} if securityStatus: params["securityStatus"] = True if 
files_metadata: params["blobs"] = True - r = requests.get(path, - headers=headers, - timeout=timeout, - params=params) + r = requests.get(path, headers=headers, timeout=timeout, params=params) r.raise_for_status() data = r.json() return ModelInfo(**data) - -def get_endpoint(): - csghub_domain = os.getenv('CSGHUB_DOMAIN', DEFAULT_CSGHUB_DOMAIN) - return csghub_domain - +def get_repo_meta_path(repo_type: str, repo_id: str, revision: Optional[str] = None, endpoint: Optional[str] = None) -> str: + if repo_type == REPO_TYPE_MODEL or repo_type == REPO_TYPE_DATASET or repo_type == REPO_TYPE_SPACE: + path = ( + f"{endpoint}/hf/api/{repo_type}s/{repo_id}/revision/main" + if revision is None + else f"{endpoint}/hf/api/{repo_type}s/{repo_id}/revision/{quote(revision, safe='')}" + ) + else: + raise ValueError("repo_type must be one of 'model', 'dataset' or 'space'") + return path def get_file_download_url(model_id: str, file_path: str, @@ -378,18 +361,31 @@ def get_file_download_url(model_id: str, Returns: str: The file url. """ - file_path = urllib.parse.quote_plus(file_path) - revision = urllib.parse.quote_plus(revision) + file_path = urllib.parse.quote(file_path) + revision = urllib.parse.quote(revision) download_url_template = '{endpoint}/hf/{model_id}/resolve/{revision}/{file_path}' if repo_type == REPO_TYPE_DATASET: download_url_template = '{endpoint}/hf/datasets/{model_id}/resolve/{revision}/{file_path}' return download_url_template.format( - endpoint=endpoint if endpoint is not None else get_endpoint(), + endpoint=endpoint, model_id=model_id, revision=revision, file_path=file_path, ) +def get_endpoint(endpoint: Optional[str] = None): + """Format endpoint to remove trailing slash and add a leading slash if not present. + Args: + endpoint (str): The endpoint url. + + Returns: + str: The formatted endpoint url. 
+ """ + csghub_domain = os.getenv('CSGHUB_DOMAIN', DEFAULT_CSGHUB_DOMAIN) + corrent_endpoint = endpoint if endpoint is not None else csghub_domain + if corrent_endpoint.endswith('/'): + corrent_endpoint = corrent_endpoint[:-1] + return corrent_endpoint def file_integrity_validation(file_path, expected_sha256) -> None: @@ -409,7 +405,6 @@ def file_integrity_validation(file_path, msg = 'File %s integrity check failed, the download may be incomplete, please try again.' % file_path raise FileIntegrityError(msg) - def compute_hash(file_path) -> str: BUFFER_SIZE = 1024 * 64 # 64k buffer size sha256_hash = hashlib.sha256() diff --git a/setup.py b/setup.py index dfcc4c7..752c03a 100644 --- a/setup.py +++ b/setup.py @@ -40,12 +40,13 @@ "Sphinx==7.3.7", "thread==2.0.3", "tornado==6.4", - "tqdm==4.66.2", + "tqdm==4.66.3", "torch", "transformers==4.40.1", "trove_classifiers==2024.5.22", "truststore==0.9.1", "urllib3_secure_extra==0.1.0", + "datasets==2.20.0" ], python_requires=">=3.10", ) \ No newline at end of file