Add fund data as an example #292

Merged
merged 13 commits on Mar 19, 2021
48 changes: 24 additions & 24 deletions scripts/data_collector/base.py
@@ -46,7 +46,7 @@ def __init__(
Parameters
----------
save_dir: str
stock save dir
instrument save dir
max_workers: int
workers, default 4
max_collector_count: int
@@ -77,11 +77,11 @@ def __init__(
self.start_datetime = self.normalize_start_datetime(start)
self.end_datetime = self.normalize_end_datetime(end)

self.stock_list = sorted(set(self.get_stock_list()))
self.instrument_list = sorted(set(self.get_instrument_list()))

if limit_nums is not None:
try:
self.stock_list = self.stock_list[: int(limit_nums)]
self.instrument_list = self.instrument_list[: int(limit_nums)]
except Exception as e:
logger.warning(f"Cannot use limit_nums={limit_nums}, the parameter will be ignored")

@@ -108,8 +108,8 @@ def min_numbers_trading(self):
raise NotImplementedError("rewrite min_numbers_trading")

@abc.abstractmethod
def get_stock_list(self):
raise NotImplementedError("rewrite get_stock_list")
def get_instrument_list(self):
raise NotImplementedError("rewrite get_instrument_list")

@abc.abstractmethod
def normalize_symbol(self, symbol: str):
@@ -158,27 +158,27 @@ def _simple_collector(self, symbol: str):
return _result

def save_instrument(self, symbol, df: pd.DataFrame):
"""save stock data to file
"""save instrument data to file

Parameters
----------
symbol: str
stock code
instrument code
df : pd.DataFrame
df.columns must contain "symbol" and "datetime"
"""
if df.empty:
if df is None or df.empty:
logger.warning(f"{symbol} is empty")
return

symbol = self.normalize_symbol(symbol)
symbol = code_to_fname(symbol)
stock_path = self.save_dir.joinpath(f"{symbol}.csv")
instrument_path = self.save_dir.joinpath(f"{symbol}.csv")
df["symbol"] = symbol
if stock_path.exists():
_old_df = pd.read_csv(stock_path)
if instrument_path.exists():
_old_df = pd.read_csv(instrument_path)
df = _old_df.append(df, sort=False)
df.to_csv(stock_path, index=False)
df.to_csv(instrument_path, index=False)

def cache_small_data(self, symbol, df):
if len(df) <= self.min_numbers_trading:
@@ -191,38 +191,38 @@ def cache_small_data(self, symbol, df):
self.mini_symbol_map.pop(symbol)
return self.NORMAL_FLAG

def _collector(self, stock_list):
def _collector(self, instrument_list):

error_symbol = []
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
with tqdm(total=len(stock_list)) as p_bar:
for _symbol, _result in zip(stock_list, executor.map(self._simple_collector, stock_list)):
with tqdm(total=len(instrument_list)) as p_bar:
for _symbol, _result in zip(instrument_list, executor.map(self._simple_collector, instrument_list)):
if _result != self.NORMAL_FLAG:
error_symbol.append(_symbol)
p_bar.update()
print(error_symbol)
logger.info(f"error symbol nums: {len(error_symbol)}")
logger.info(f"current get symbol nums: {len(stock_list)}")
logger.info(f"current get symbol nums: {len(instrument_list)}")
error_symbol.extend(self.mini_symbol_map.keys())
return sorted(set(error_symbol))

def collector_data(self):
"""collector data"""
logger.info("start collector data......")
stock_list = self.stock_list
instrument_list = self.instrument_list
for i in range(self.max_collector_count):
if not stock_list:
if not instrument_list:
break
logger.info(f"getting data: {i+1}")
stock_list = self._collector(stock_list)
instrument_list = self._collector(instrument_list)
logger.info(f"{i+1} finish.")
for _symbol, _df_list in self.mini_symbol_map.items():
self.save_instrument(
_symbol, pd.concat(_df_list, sort=False).drop_duplicates(["date"]).sort_values(["date"])
)
if self.mini_symbol_map:
logger.warning(f"less than {self.min_numbers_trading} stock list: {list(self.mini_symbol_map.keys())}")
logger.info(f"total {len(self.stock_list)}, error: {len(set(stock_list))}")
logger.warning(f"less than {self.min_numbers_trading} instrument list: {list(self.mini_symbol_map.keys())}")
logger.info(f"total {len(self.instrument_list)}, error: {len(set(instrument_list))}")


class BaseNormalize(abc.ABC):
@@ -386,9 +386,9 @@ def download_data(
Examples
---------
# get daily data
$ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
$ python collector.py download_data --source_dir ~/.qlib/instrument_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
# get 1m data
$ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1m
$ python collector.py download_data --source_dir ~/.qlib/instrument_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1m
"""

_class = getattr(self._cur_module, self.collector_class_name) # type: Type[BaseCollector]
@@ -416,7 +416,7 @@ def normalize_data(self, date_field_name: str = "date", symbol_field_name: str =

Examples
---------
$ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN --interval 1d
$ python collector.py normalize_data --source_dir ~/.qlib/instrument_data/source --normalize_dir ~/.qlib/instrument_data/normalize --region CN --interval 1d
"""
_class = getattr(self._cur_module, self.normalize_class_name)
yc = Normalize(
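For readers skimming this diff, the renamed hooks (`get_instrument_list`, `normalize_symbol`) are what a concrete collector must provide. Below is a toy, hypothetical sketch that only mirrors those hooks; it does not import the project's real `BaseCollector`, and the `get_data` signature is assumed for illustration.

```python
import pandas as pd

# Hypothetical mock of the hooks renamed in this diff; not the project's
# BaseCollector, just an illustration of the expected shapes.
class ToyFundCollector:
    def get_instrument_list(self):
        # A real collector would query a data source; codes are hard-coded here.
        return ["000001", "000003"]

    def normalize_symbol(self, symbol: str):
        # Fund codes are already plain digits, so nothing to normalize.
        return symbol

    def get_data(self, symbol, interval, start, end):
        # save_instrument (see its docstring above) expects "symbol" and "datetime" columns.
        dates = pd.date_range(start, end, freq="D")
        return pd.DataFrame({"datetime": dates, "symbol": symbol, "value": 1.0})
```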
51 changes: 51 additions & 0 deletions scripts/data_collector/fund/README.md
@@ -0,0 +1,51 @@
# Collect Fund Data

> *Please pay **ATTENTION** that the data is collected from [天天基金网](https://fund.eastmoney.com/) and might not be perfect. We recommend that users prepare their own data if they have a high-quality dataset. For more information, users can refer to the [related document](https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format).*

## Requirements

```bash
pip install -r requirements.txt
```

## Collect Data


### CN Data

#### 1d from East Money

```bash

# download from eastmoney.com
python collector.py download_data --source_dir ~/.qlib/fund_data/source/cn_1d --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d

# normalize
python collector.py normalize_data --source_dir ~/.qlib/fund_data/source/cn_1d --normalize_dir ~/.qlib/fund_data/source/cn_1d_nor --region CN --interval 1d --date_field_name FSRQ

# dump data
cd qlib/scripts
python dump_bin.py dump_all --csv_path ~/.qlib/fund_data/source/cn_1d_nor --qlib_dir ~/.qlib/qlib_data/cn_fund_data --freq day --date_field_name FSRQ --include_fields DWJZ,LJJZ

```
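
Before running `dump_bin.py`, it can help to sanity-check one normalized file. A minimal sketch with pandas, assuming the normalized CSVs sit under `~/.qlib/fund_data/source/cn_1d_nor` and carry the FSRQ, DWJZ and LJJZ columns that the dump command above expects:

```python
from pathlib import Path

import pandas as pd

# Illustrative check: load one normalized fund CSV and confirm the columns
# that dump_bin.py is told to use (FSRQ as date field, DWJZ/LJJZ as values).
normalize_dir = Path("~/.qlib/fund_data/source/cn_1d_nor").expanduser()
sample_file = next(normalize_dir.glob("*.csv"))  # any normalized file will do

df = pd.read_csv(sample_file)
missing = {"FSRQ", "DWJZ", "LJJZ"} - set(df.columns)
if missing:
    raise ValueError(f"{sample_file.name} is missing expected columns: {missing}")

# FSRQ should parse as dates; sort before eyeballing the latest rows.
df["FSRQ"] = pd.to_datetime(df["FSRQ"])
print(df.sort_values("FSRQ").tail())
```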

### Using Data

```python
import qlib
from qlib.data import D

qlib.init(provider_uri="~/.qlib/qlib_data/cn_fund_data")
df = D.features(D.instruments(market="all"), ["$DWJZ", "$LJJZ"], freq="day")
```
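
The frame above has one row per fund per day. As a follow-up sketch (not part of the original example), the unit net value `$DWJZ` can be turned into simple day-over-day returns with plain pandas, assuming the usual `(instrument, datetime)` MultiIndex returned by `D.features`:

```python
# Continues from the df returned by D.features above.
# Group by fund so day-over-day changes never mix two instruments.
returns = df["$DWJZ"].groupby(level="instrument").pct_change()
print(returns.dropna().head())
```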


### Help
```bash
python collector.py collector_data --help
```

## Parameters

- interval: 1d
- region: CN