-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[DEV] Change compile timing of stan models
Now stan models will be compiled when packages are installed. Along with this manner, tests ways are also changed.
- Loading branch information
1 parent
7aa3707
commit f9452f6
Showing
13 changed files
with
470 additions
and
497 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -128,4 +128,5 @@ Network Trash Folder | |
Temporary Items | ||
.apdisk | ||
|
||
|
||
# Compiled models of stan | ||
sphere/stan_models | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,8 @@ | ||
language: python | ||
python: | ||
- "3.4" | ||
- "3.5" | ||
- "3.6" | ||
install: | ||
- pip install . | ||
- pip install nose | ||
- pip install --upgrade pip | ||
- pip install -U -r requirements.txt | ||
script: | ||
- nosetests -v | ||
- python setup.py develop test |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
matplotlib==2.2.2 | ||
numpy==1.14.3 | ||
pandas==0.22.0 | ||
pystan==2.17.1.0 | ||
scipy==1.1.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,29 +3,118 @@ | |
# Author: Shinya Suzuki | ||
# Created: 2017-03-18 | ||
|
||
from pathlib import Path | ||
import pickle | ||
import sys | ||
from pkg_resources import ( | ||
normalize_path, | ||
working_set, | ||
add_activation_listener, | ||
require, | ||
) | ||
try: | ||
from setuptools import setup, find_packages | ||
from setuptools.command.build_py import build_py | ||
from setuptools.command.develop import develop | ||
from setuptools.command.test import test as test_command | ||
except ImportError: | ||
raise ImportError("Please install setuptools.") | ||
|
||
SETUP_DIR = Path(__file__).parent.resolve() | ||
MODELS_DIR = SETUP_DIR / "stan" | ||
MODELS_TARGET_DIR = Path("sphere/stan_models") | ||
|
||
|
||
def take_package_name(name): | ||
if name.startswith("-e"): | ||
return name[name.find("=")+1:name.rfind("-")] | ||
else: | ||
return name.strip() | ||
|
||
|
||
def load_requires_from_file(filepath): | ||
with open(filepath) as fp: | ||
return [take_package_name(pkg_name) for pkg_name in fp.readlines()] | ||
|
||
|
||
def compile_stan_models(target_dir, models_dir=MODELS_DIR): | ||
from pystan import StanModel | ||
model_path_list = MODELS_DIR.glob("*.stan") | ||
for model_path in model_path_list: | ||
model_type = model_path.stem | ||
target_name = model_type + ".pkl" | ||
target_path = target_dir / target_name | ||
with open(model_path) as f: | ||
model_code = f.read() | ||
model = StanModel(model_code=model_code) | ||
with open(target_path, "wb") as f: | ||
pickle.dump(model, f, protocol=pickle.HIGHEST_PROTOCOL) | ||
|
||
|
||
class BuildPyCommand(build_py): | ||
"""Custom build command to pre-compile Stan models.""" | ||
def run(self): | ||
if not self.dry_run: | ||
build_lib_path = Path(self.build_lib).resolve() | ||
target_dir = build_lib_path / MODELS_TARGET_DIR | ||
self.mkpath(str(target_dir)) | ||
compile_stan_models(target_dir) | ||
build_py.run(self) | ||
|
||
|
||
class DevelopCommand(develop): | ||
"""Custom develop command to pre-compile Stan models in-place.""" | ||
|
||
def run(self): | ||
if not self.dry_run: | ||
build_lib_path = Path(self.setup_path).resolve() | ||
target_dir = build_lib_path / MODELS_TARGET_DIR | ||
self.mkpath(str(target_dir)) | ||
compile_stan_models(target_dir) | ||
develop.run(self) | ||
|
||
|
||
class TestCommand(test_command): | ||
"""We must run tests on the build directory, not source.""" | ||
|
||
def with_project_on_sys_path(self, func): | ||
# Ensure metadata is up-to-date | ||
self.reinitialize_command('build_py', inplace=0) | ||
self.run_command('build_py') | ||
bpy_cmd = self.get_finalized_command("build_py") | ||
build_path = normalize_path(bpy_cmd.build_lib) | ||
|
||
# Build extensions | ||
self.reinitialize_command('egg_info', egg_base=build_path) | ||
self.run_command('egg_info') | ||
|
||
self.reinitialize_command('build_ext', inplace=0) | ||
self.run_command('build_ext') | ||
|
||
ei_cmd = self.get_finalized_command("egg_info") | ||
|
||
old_path = sys.path[:] | ||
old_modules = sys.modules.copy() | ||
|
||
try: | ||
sys.path.insert(0, normalize_path(ei_cmd.egg_base)) | ||
working_set.__init__() | ||
add_activation_listener(lambda dist: dist.activate()) | ||
require('%s==%s' % (ei_cmd.egg_name, ei_cmd.egg_version)) | ||
func() | ||
finally: | ||
sys.path[:] = old_path | ||
sys.modules.clear() | ||
sys.modules.update(old_modules) | ||
working_set.__init__() | ||
|
||
|
||
setup( | ||
name="sphere", | ||
version='1.0.0', | ||
packages=find_packages(exclude=[ | ||
"tests" | ||
]), | ||
|
||
install_requires=[ | ||
"numpy>=1.13.1", | ||
"pandas>=0.20.3", | ||
"matplotlib>=2.0.2", | ||
"pystan>=2.16.0.0", | ||
"scipy>=1.0.0" | ||
], | ||
extras_require={ | ||
"test": ["nose"] | ||
}, | ||
test_suite="nosetests", | ||
packages=find_packages(), | ||
|
||
install_requires=load_requires_from_file("requirements.txt"), | ||
|
||
author="Shinya SUZUKI", | ||
author_email="[email protected]", | ||
|
@@ -53,5 +142,11 @@ | |
"sphere_filter= sphere.sphere_filter:main_wrapper", | ||
"sphere_symtest= sphere.sphere_symtest:main_wrapper" | ||
] | ||
} | ||
}, | ||
cmdclass={ | ||
"build_py": BuildPyCommand, | ||
"develop": DevelopCommand, | ||
"test": TestCommand | ||
}, | ||
test_suite="tests" | ||
) |
Oops, something went wrong.