Skip to content

Commit

Permalink
Bart (apache#14)
Browse files Browse the repository at this point in the history
  • Loading branch information
heheda12345 authored Oct 25, 2023
2 parents cc4ef4d + 97214e2 commit aac8598
Show file tree
Hide file tree
Showing 20 changed files with 2,154 additions and 87 deletions.
11 changes: 10 additions & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,20 @@ jobs:
source ~/venv/frontend-env/bin/activate
pip install --force-reinstall .
cd scripts && BUILD_DIR=~/frontend ./compile_longobj.sh && cd ..
- name: Test with pytest
- name: install dependency
run: |
source /opt/spack/share/spack/setup-env.sh
spack load [email protected] /jb4mlxg
spack load [email protected]%gcc@=11.3.0
source ~/venv/frontend-env/bin/activate
pip install --upgrade -r requirements.txt -f https://download.pytorch.org/whl/torch_stable.html
if ! pip show transformers &> /dev/null; then
pip install transformers==v4.29.1
fi
- name: Test with pytest
run: |
source /opt/spack/share/spack/setup-env.sh
spack load [email protected] /jb4mlxg
spack load [email protected]%gcc@=11.3.0
source ~/venv/frontend-env/bin/activate
srun --exclusive ./scripts/pytest_with_preload.sh -vs test
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@
.vscode
build
__pycache__
*.so
*.so
test/simple.py
4 changes: 4 additions & 0 deletions frontend/c_api.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -80,4 +80,8 @@ def parse_rangeiterobject(obj: Any) -> Tuple[int, int, int, int]:


def make_rangeiterobject(start: int, stop: int, step: int) -> Any:
pass


def get_from_freevars(frame: FrameType, idx: int) -> Any:
pass
24 changes: 23 additions & 1 deletion frontend/csrc/frame_evaluation.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#define PY_SSIZE_T_CLEAN
#include "csrc.h"
#include <Python.h>
#include <cellobject.h>
#include <frameobject.h>
#include <map>
#include <object.h>
Expand Down Expand Up @@ -466,6 +467,24 @@ static PyObject *is_bound_method(PyObject *self, PyObject *args) {
}
}

static PyObject *get_from_freevars(PyObject *self, PyObject *args) {
PyObject *frame;
int index;
if (!PyArg_ParseTuple(args, "Oi", &frame, &index)) {
PRINT_PYERR;
PyErr_SetString(PyExc_TypeError,
"invalid parameter in get_from_freevars");
return NULL;
}
PyFrameObject *f = (PyFrameObject *)frame;
PyObject *value = f->f_localsplus[index + f->f_code->co_nlocals];
if (value == NULL) {
value = null_object;
}
Py_INCREF(value);
return value;
}

static PyMethodDef _methods[] = {
{"set_eval_frame", set_eval_frame, METH_VARARGS, NULL},
{"set_skip_files", set_skip_files, METH_VARARGS, NULL},
Expand All @@ -491,6 +510,7 @@ static PyMethodDef _methods[] = {
METH_VARARGS, NULL},
{"get_code_map", get_code_map, METH_VARARGS, NULL},
{"is_bound_method", is_bound_method, METH_VARARGS, NULL},
{"get_from_freevars", get_from_freevars, METH_VARARGS, NULL},
{"parse_rangeiterobject", frontend_csrc::parse_rangeiterobject,
METH_VARARGS, NULL},
{"make_rangeiterobject", frontend_csrc::make_rangeiterobject, METH_VARARGS,
Expand All @@ -512,5 +532,7 @@ PyMODINIT_FUNC PyInit_c_api(void) {
CHECK(result == 0);
Py_INCREF(Py_None);
set_eval_frame_callback(Py_None);
return PyModule_Create(&_module);

PyObject *m = PyModule_Create(&_module);
return m;
}
Loading

0 comments on commit aac8598

Please sign in to comment.