Skip to content

Commit

Permalink
split cpu eval CI by dtype (pytorch#554)
Browse files Browse the repository at this point in the history
* split cpu eval CI by dtype

* fix

* differentiate names with checks

* keep one name the same as old

* fix
  • Loading branch information
metascroy authored and malfet committed Jul 17, 2024
1 parent 05bd844 commit 6109e08
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 3 deletions.
22 changes: 20 additions & 2 deletions .ci/scripts/validate.sh
Original file line number Diff line number Diff line change
Expand Up @@ -255,10 +255,11 @@ function eval_model() {
function eval_model_sanity_check() {
local CHECKPOINT_PATH="$1"
local TARGET_DEVICE="${2:-cpu}"
local DTYPES="$3"
local MODEL_DIR="${CHECKPOINT_PATH%/*}"
local MODEL_NAME=$(basename "$CHECKPOINT_PATH" | sed 's/\.[^.]*$//')

for DTYPE in float32 bfloat16 float16; do
for DTYPE in $DTYPES; do
echo ""############### Run eval with torch.compile for dtype $DTYPE "###############"
echo ""
echo "******************************************"
Expand Down Expand Up @@ -320,7 +321,8 @@ function run_eval(){
}

# Run the eval sanity check for the dtypes selected by the caller.
#
# Globals read:
#   CHECKPOINT_PATH  path of the checkpoint to evaluate (set from "$1").
#   TARGET_DEVICE    device name forwarded to eval_model_sanity_check.
#   DTYPES           space-separated list of dtype names; the argument
#                    dispatch sets this before calling this function.
#
# Exits the script with status 1 if the sanity check fails.
function run_eval_sanity_check(){
    # Log the selection so a CI log shows which dtypes this job covered.
    echo "Passing DTYPES=$DTYPES"
    # Single call that forwards the dtype list. The older two-argument call
    # is intentionally not kept: it would run the check a second time
    # without any dtype selection.
    eval_model_sanity_check "$CHECKPOINT_PATH" "$TARGET_DEVICE" "$DTYPES" || exit 1
}

CHECKPOINT_PATH="$1"
Expand Down Expand Up @@ -365,6 +367,22 @@ if [ "$#" -gt 2 ]; then
;;
# Unsuffixed variant: runs the sanity check for all three dtypes in one
# invocation. DTYPES is a global read by run_eval_sanity_check.
"eval_sanity_check")
echo "arg:$arg"
DTYPES="bfloat16 float16 float32"
run_eval_sanity_check || exit 1
;;
# Per-dtype variants: each restricts DTYPES to a single dtype so that a CI
# job can cover exactly one dtype.
"eval_sanity_check-bfloat16")
echo "arg:$arg"
DTYPES="bfloat16"
run_eval_sanity_check || exit 1
;;
"eval_sanity_check-float16")
echo "arg:$arg"
DTYPES="float16"
run_eval_sanity_check || exit 1
;;
"eval_sanity_check-float32")
echo "arg:$arg"
DTYPES="float32"
run_eval_sanity_check || exit 1
;;
*)
Expand Down
74 changes: 73 additions & 1 deletion .github/workflows/pull.yml
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,79 @@ jobs:
python3 -c 'import torch;print(f"torch: {torch.__version__, torch.version.git_version}")'
pushd ${TORCHCHAT_ROOT}
bash .ci/scripts/convert_checkpoint.sh ${REPO_NAME}
bash .ci/scripts/validate.sh "./checkpoints/${REPO_NAME}/model.pth" "cpu" "eval_sanity_check"
bash .ci/scripts/validate.sh "./checkpoints/${REPO_NAME}/model.pth" "cpu" "eval_sanity_check-bfloat16"
# CPU eval sanity check restricted to float16. The dtype is selected by the
# "eval_sanity_check-float16" argument passed to .ci/scripts/validate.sh.
# Runs once per entry of the matrix emitted by the gather-models-cpu job.
test-cpu-eval-sanity-check-float16:
  name: test-cpu-eval-sanity-check-float16 (${{ matrix.platform }}, ${{ matrix.model_name }})
  needs: gather-models-cpu
  strategy:
    matrix: ${{ fromJSON(needs.gather-models-cpu.outputs.models) }}
    # Let the remaining matrix entries finish when one of them fails.
    fail-fast: false
  runs-on: ${{ matrix.runner }}
  env:
    TORCHCHAT_ROOT: ${{ github.workspace }}
    REPO_NAME: ${{ matrix.repo_name }}
  steps:
    - name: Checkout repo
      uses: actions/checkout@v4
    - name: Setup Python
      uses: actions/setup-python@v5
      with:
        python-version: '3.11'
    - name: Print machine info
      run: |
        echo "$(uname -a)"
    - name: Install dependencies
      run: |
        pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu
        pip3 install -r requirements.txt
        pip3 list
        python3 -c 'import torch;print(f"torch: {torch.__version__, torch.version.git_version}")'
    - name: Download checkpoints
      env:
        # Passed through the environment instead of being interpolated into
        # the script text, so the value is never re-parsed as shell syntax.
        RESOURCES: ${{ matrix.resources }}
      run: |
        bash "${TORCHCHAT_ROOT}/.ci/scripts/wget_checkpoint.sh" "${REPO_NAME}" "${RESOURCES}"
    - name: Run validation
      run: |
        python3 -c 'import torch;print(f"torch: {torch.__version__, torch.version.git_version}")'
        pushd "${TORCHCHAT_ROOT}"
        bash .ci/scripts/convert_checkpoint.sh "${REPO_NAME}"
        bash .ci/scripts/validate.sh "./checkpoints/${REPO_NAME}/model.pth" "cpu" "eval_sanity_check-float16"
# CPU eval sanity check restricted to float32. The dtype is selected by the
# "eval_sanity_check-float32" argument passed to .ci/scripts/validate.sh.
# Runs once per entry of the matrix emitted by the gather-models-cpu job.
test-cpu-eval-sanity-check-float32:
  name: test-cpu-eval-sanity-check-float32 (${{ matrix.platform }}, ${{ matrix.model_name }})
  needs: gather-models-cpu
  strategy:
    matrix: ${{ fromJSON(needs.gather-models-cpu.outputs.models) }}
    # Let the remaining matrix entries finish when one of them fails.
    fail-fast: false
  runs-on: ${{ matrix.runner }}
  env:
    TORCHCHAT_ROOT: ${{ github.workspace }}
    REPO_NAME: ${{ matrix.repo_name }}
  steps:
    - name: Checkout repo
      uses: actions/checkout@v4
    - name: Setup Python
      uses: actions/setup-python@v5
      with:
        python-version: '3.11'
    - name: Print machine info
      run: |
        echo "$(uname -a)"
    - name: Install dependencies
      run: |
        pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu
        pip3 install -r requirements.txt
        pip3 list
        python3 -c 'import torch;print(f"torch: {torch.__version__, torch.version.git_version}")'
    - name: Download checkpoints
      env:
        # Passed through the environment instead of being interpolated into
        # the script text, so the value is never re-parsed as shell syntax.
        RESOURCES: ${{ matrix.resources }}
      run: |
        bash "${TORCHCHAT_ROOT}/.ci/scripts/wget_checkpoint.sh" "${REPO_NAME}" "${RESOURCES}"
    - name: Run validation
      run: |
        python3 -c 'import torch;print(f"torch: {torch.__version__, torch.version.git_version}")'
        pushd "${TORCHCHAT_ROOT}"
        bash .ci/scripts/convert_checkpoint.sh "${REPO_NAME}"
        bash .ci/scripts/validate.sh "./checkpoints/${REPO_NAME}/model.pth" "cpu" "eval_sanity_check-float32"
gather-models-gpu:
runs-on: ubuntu-22.04
Expand Down

0 comments on commit 6109e08

Please sign in to comment.